feat: optimize db connection when llm invoking (#2774)

This commit is contained in:
takatost
2024-03-10 15:48:31 +08:00
committed by GitHub
parent 8b1e35d7dc
commit f073dca22a
5 changed files with 37 additions and 5 deletions

View File

@@ -114,6 +114,7 @@ class BaseAssistantApplicationRunner(AppRunner):
self.agent_thought_count = db.session.query(MessageAgentThought).filter(
MessageAgentThought.message_id == self.message.id,
).count()
db.session.close()
# check if model supports stream tool call
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
@@ -341,13 +342,16 @@ class BaseAssistantApplicationRunner(AppRunner):
created_by=self.user_id,
)
db.session.add(message_file)
db.session.commit()
db.session.refresh(message_file)
result.append((
message_file,
message.save_as
))
db.session.commit()
db.session.close()
return result
def create_agent_thought(self, message_id: str, message: str,
@@ -384,6 +388,8 @@ class BaseAssistantApplicationRunner(AppRunner):
db.session.add(thought)
db.session.commit()
db.session.refresh(thought)
db.session.close()
self.agent_thought_count += 1
@@ -401,6 +407,10 @@ class BaseAssistantApplicationRunner(AppRunner):
"""
Save agent thought
"""
agent_thought = db.session.query(MessageAgentThought).filter(
MessageAgentThought.id == agent_thought.id
).first()
if thought is not None:
agent_thought.thought = thought
@@ -451,6 +461,7 @@ class BaseAssistantApplicationRunner(AppRunner):
agent_thought.tool_labels_str = json.dumps(labels)
db.session.commit()
db.session.close()
def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
"""
@@ -523,9 +534,14 @@ class BaseAssistantApplicationRunner(AppRunner):
"""
convert tool variables to db variables
"""
db_variables = db.session.query(ToolConversationVariables).filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
).first()
db_variables.updated_at = datetime.utcnow()
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
db.session.commit()
db.session.close()
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
@@ -581,4 +597,6 @@ class BaseAssistantApplicationRunner(AppRunner):
if message.answer:
result.append(AssistantPromptMessage(content=message.answer))
db.session.close()
return result