feat: support binding context var (#1227)

Co-authored-by: Joel <iamjoel007@gmail.com>
This commit is contained in:
Garfield Dai
2023-09-27 14:53:22 +08:00
committed by GitHub
parent 59236b789f
commit 18c710c906
44 changed files with 711 additions and 77 deletions

View File

@@ -647,6 +647,76 @@ def update_app_model_configs(batch_size):
pbar.update(len(data_batch))
@click.command('migrate_default_input_to_dataset_query_variable')
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
def migrate_default_input_to_dataset_query_variable(batch_size):
click.secho("Starting...", fg='green')
total_records = db.session.query(AppModelConfig) \
.join(App, App.app_model_config_id == AppModelConfig.id) \
.filter(App.mode == 'completion') \
.filter(AppModelConfig.dataset_query_variable == None) \
.count()
if total_records == 0:
click.secho("No data to migrate.", fg='green')
return
num_batches = (total_records + batch_size - 1) // batch_size
with tqdm(total=total_records, desc="Migrating Data") as pbar:
for i in range(num_batches):
offset = i * batch_size
limit = min(batch_size, total_records - offset)
click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')
data_batch = db.session.query(AppModelConfig) \
.join(App, App.app_model_config_id == AppModelConfig.id) \
.filter(App.mode == 'completion') \
.filter(AppModelConfig.dataset_query_variable == None) \
.order_by(App.created_at) \
.offset(offset).limit(limit).all()
if not data_batch:
click.secho("No more data to migrate.", fg='green')
break
try:
click.secho(f"Migrating {len(data_batch)} records...", fg='green')
for data in data_batch:
config = AppModelConfig.to_dict(data)
tools = config["agent_mode"]["tools"]
dataset_exists = "dataset" in str(tools)
if not dataset_exists:
continue
user_input_form = config.get("user_input_form", [])
for form in user_input_form:
paragraph = form.get('paragraph')
if paragraph \
and paragraph.get('variable') == 'query':
data.dataset_query_variable = 'query'
break
if paragraph \
and paragraph.get('variable') == 'default_input':
data.dataset_query_variable = 'default_input'
break
db.session.commit()
except Exception as e:
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
fg='red')
continue
click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
pbar.update(len(data_batch))
def register_commands(app):
app.cli.add_command(reset_password)
@@ -660,3 +730,4 @@ def register_commands(app):
app.cli.add_command(update_qdrant_indexes)
app.cli.add_command(update_app_model_configs)
app.cli.add_command(normalization_collections)
app.cli.add_command(migrate_default_input_to_dataset_query_variable)

View File

@@ -34,6 +34,7 @@ model_config_fields = {
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
'model': fields.Raw(attribute='model_dict'),
'user_input_form': fields.Raw(attribute='user_input_form_list'),
'dataset_query_variable': fields.String,
'pre_prompt': fields.String,
'agent_mode': fields.Raw(attribute='agent_mode_dict'),
}
@@ -162,7 +163,8 @@ class AppListApi(Resource):
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
account=current_user,
config=model_config_dict
config=model_config_dict,
mode=args['mode']
)
app = App(

View File

@@ -31,7 +31,8 @@ class ModelConfigResource(Resource):
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
account=current_user,
config=request.json
config=request.json,
mode=app_model.mode
)
new_app_model_config = AppModelConfig(

View File

@@ -108,12 +108,14 @@ class Completion:
retriever_from=retriever_from
)
query_for_agent = cls.get_query_for_agent(app, app_model_config, query, inputs)
# run agent executor
agent_execute_result = None
if agent_executor:
should_use_agent = agent_executor.should_use_agent(query)
if query_for_agent and agent_executor:
should_use_agent = agent_executor.should_use_agent(query_for_agent)
if should_use_agent:
agent_execute_result = agent_executor.run(query)
agent_execute_result = agent_executor.run(query_for_agent)
# When no extra pre prompt is specified,
# the output of the agent can be used directly as the main output content without calling LLM again
@@ -142,6 +144,13 @@ class Completion:
logging.warning(f'ChunkedEncodingError: {e}')
conversation_message_task.end()
return
@classmethod
def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
if app.mode != 'completion':
return query
return inputs.get(app_model_config.dataset_query_variable, "")
@classmethod
def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,

View File

@@ -0,0 +1,31 @@
"""add dataset query variable at app model configs.
Revision ID: ab23c11305d4
Revises: 6e2cfb077b04
Create Date: 2023-09-26 12:22:59.044088
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'ab23c11305d4'
down_revision = '6e2cfb077b04'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('dataset_query_variable', sa.String(length=255), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.drop_column('dataset_query_variable')
# ### end Alembic commands ###

View File

@@ -88,6 +88,7 @@ class AppModelConfig(db.Model):
more_like_this = db.Column(db.Text)
model = db.Column(db.Text)
user_input_form = db.Column(db.Text)
dataset_query_variable = db.Column(db.String(255))
pre_prompt = db.Column(db.Text)
agent_mode = db.Column(db.Text)
sensitive_word_avoidance = db.Column(db.Text)
@@ -152,6 +153,7 @@ class AppModelConfig(db.Model):
"sensitive_word_avoidance": self.sensitive_word_avoidance_dict,
"model": self.model_dict,
"user_input_form": self.user_input_form_list,
"dataset_query_variable": self.dataset_query_variable,
"pre_prompt": self.pre_prompt,
"agent_mode": self.agent_mode_dict
}
@@ -170,6 +172,7 @@ class AppModelConfig(db.Model):
if model_config.get('sensitive_word_avoidance') else None
self.model = json.dumps(model_config['model'])
self.user_input_form = json.dumps(model_config['user_input_form'])
self.dataset_query_variable = model_config['dataset_query_variable']
self.pre_prompt = model_config['pre_prompt']
self.agent_mode = json.dumps(model_config['agent_mode'])
self.retriever_resource = json.dumps(model_config['retriever_resource']) \
@@ -191,6 +194,7 @@ class AppModelConfig(db.Model):
sensitive_word_avoidance=self.sensitive_word_avoidance,
model=self.model,
user_input_form=self.user_input_form,
dataset_query_variable=self.dataset_query_variable,
pre_prompt=self.pre_prompt,
agent_mode=self.agent_mode
)

View File

@@ -81,7 +81,7 @@ class AppModelConfigService:
return filtered_cp
@staticmethod
def validate_configuration(tenant_id: str, account: Account, config: dict) -> dict:
def validate_configuration(tenant_id: str, account: Account, config: dict, mode: str) -> dict:
# opening_statement
if 'opening_statement' not in config or not config["opening_statement"]:
config["opening_statement"] = ""
@@ -335,6 +335,9 @@ class AppModelConfigService:
if not AppModelConfigService.is_dataset_exists(account, tool_item["id"]):
raise ValueError("Dataset ID does not exist, please check your permission.")
# dataset_query_variable
AppModelConfigService.is_dataset_query_variable_valid(config, mode)
# Filter out extra parameters
filtered_config = {
@@ -351,8 +354,25 @@ class AppModelConfigService:
"completion_params": config["model"]["completion_params"]
},
"user_input_form": config["user_input_form"],
"dataset_query_variable": config["dataset_query_variable"],
"pre_prompt": config["pre_prompt"],
"agent_mode": config["agent_mode"]
}
return filtered_config
@staticmethod
def is_dataset_query_variable_valid(config: dict, mode: str) -> None:
# Only check when mode is completion
if mode != 'completion':
return
agent_mode = config.get("agent_mode", {})
tools = agent_mode.get("tools", [])
dataset_exists = "dataset" in str(tools)
dataset_query_variable = config.get("dataset_query_variable")
if dataset_exists and not dataset_query_variable:
raise ValueError("Dataset query variable is required when dataset is exist")

View File

@@ -117,7 +117,8 @@ class CompletionService:
model_config = AppModelConfigService.validate_configuration(
tenant_id=app_model.tenant_id,
account=user,
config=args['model_config']
config=args['model_config'],
mode=app_model.mode
)
app_model_config = AppModelConfig(