refactor: optimize database usage (#12071)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2024-12-25 16:24:52 +08:00
committed by GitHub
parent b281a80150
commit 83ea931e3c
12 changed files with 574 additions and 561 deletions

View File

@@ -9,6 +9,8 @@ from typing import Any, Optional, Union
from uuid import UUID, uuid4
from flask import current_app
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import (
@@ -329,15 +331,15 @@ class TraceTask:
):
self.trace_type = trace_type
self.message_id = message_id
self.workflow_run = workflow_run
self.workflow_run_id = workflow_run.id if workflow_run else None
self.conversation_id = conversation_id
self.user_id = user_id
self.timer = timer
self.kwargs = kwargs
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
self.app_id = None
self.kwargs = kwargs
def execute(self):
return self.preprocess()
@@ -345,19 +347,23 @@ class TraceTask:
preprocess_map = {
TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs),
TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
self.workflow_run, self.conversation_id, self.user_id
workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
),
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
message_id=self.message_id, timer=self.timer, **self.kwargs
),
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id),
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(self.message_id, self.timer, **self.kwargs),
TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace(
self.message_id, self.timer, **self.kwargs
message_id=self.message_id, timer=self.timer, **self.kwargs
),
TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace(
self.message_id, self.timer, **self.kwargs
message_id=self.message_id, timer=self.timer, **self.kwargs
),
TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(
message_id=self.message_id, timer=self.timer, **self.kwargs
),
TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(self.message_id, self.timer, **self.kwargs),
TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
self.conversation_id, self.timer, **self.kwargs
conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
),
}
@@ -367,86 +373,100 @@ class TraceTask:
def conversation_trace(self, **kwargs):
return kwargs
def workflow_trace(self, workflow_run: WorkflowRun | None, conversation_id, user_id):
if not workflow_run:
raise ValueError("Workflow run not found")
def workflow_trace(
self,
*,
workflow_run_id: str | None,
conversation_id: str | None,
user_id: str | None,
):
if not workflow_run_id:
return {}
db.session.merge(workflow_run)
db.session.refresh(workflow_run)
with Session(db.engine) as session:
workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
workflow_run = session.scalars(workflow_run_stmt).first()
if not workflow_run:
raise ValueError("Workflow run not found")
workflow_id = workflow_run.workflow_id
tenant_id = workflow_run.tenant_id
workflow_run_id = workflow_run.id
workflow_run_elapsed_time = workflow_run.elapsed_time
workflow_run_status = workflow_run.status
workflow_run_inputs = workflow_run.inputs_dict
workflow_run_outputs = workflow_run.outputs_dict
workflow_run_version = workflow_run.version
error = workflow_run.error or ""
workflow_id = workflow_run.workflow_id
tenant_id = workflow_run.tenant_id
workflow_run_id = workflow_run.id
workflow_run_elapsed_time = workflow_run.elapsed_time
workflow_run_status = workflow_run.status
workflow_run_inputs = workflow_run.inputs_dict
workflow_run_outputs = workflow_run.outputs_dict
workflow_run_version = workflow_run.version
error = workflow_run.error or ""
total_tokens = workflow_run.total_tokens
total_tokens = workflow_run.total_tokens
file_list = workflow_run_inputs.get("sys.file") or []
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
file_list = workflow_run_inputs.get("sys.file") or []
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
# get workflow_app_log_id
workflow_app_log_data = (
db.session.query(WorkflowAppLog)
.filter_by(tenant_id=tenant_id, app_id=workflow_run.app_id, workflow_run_id=workflow_run.id)
.first()
)
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
# get message_id
message_data = (
db.session.query(Message.id)
.filter_by(conversation_id=conversation_id, workflow_run_id=workflow_run_id)
.first()
)
message_id = str(message_data.id) if message_data else None
# get workflow_app_log_id
workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
WorkflowAppLog.tenant_id == tenant_id,
WorkflowAppLog.app_id == workflow_run.app_id,
WorkflowAppLog.workflow_run_id == workflow_run.id,
)
workflow_app_log_id = session.scalar(workflow_app_log_data_stmt)
# get message_id
message_id = None
if conversation_id:
message_data_stmt = select(Message.id).where(
Message.conversation_id == conversation_id,
Message.workflow_run_id == workflow_run_id,
)
message_id = session.scalar(message_data_stmt)
metadata = {
"workflow_id": workflow_id,
"conversation_id": conversation_id,
"workflow_run_id": workflow_run_id,
"tenant_id": tenant_id,
"elapsed_time": workflow_run_elapsed_time,
"status": workflow_run_status,
"version": workflow_run_version,
"total_tokens": total_tokens,
"file_list": file_list,
"triggered_form": workflow_run.triggered_from,
"user_id": user_id,
}
workflow_trace_info = WorkflowTraceInfo(
workflow_data=workflow_run.to_dict(),
conversation_id=conversation_id,
workflow_id=workflow_id,
tenant_id=tenant_id,
workflow_run_id=workflow_run_id,
workflow_run_elapsed_time=workflow_run_elapsed_time,
workflow_run_status=workflow_run_status,
workflow_run_inputs=workflow_run_inputs,
workflow_run_outputs=workflow_run_outputs,
workflow_run_version=workflow_run_version,
error=error,
total_tokens=total_tokens,
file_list=file_list,
query=query,
metadata=metadata,
workflow_app_log_id=workflow_app_log_id,
message_id=message_id,
start_time=workflow_run.created_at,
end_time=workflow_run.finished_at,
)
metadata = {
"workflow_id": workflow_id,
"conversation_id": conversation_id,
"workflow_run_id": workflow_run_id,
"tenant_id": tenant_id,
"elapsed_time": workflow_run_elapsed_time,
"status": workflow_run_status,
"version": workflow_run_version,
"total_tokens": total_tokens,
"file_list": file_list,
"triggered_form": workflow_run.triggered_from,
"user_id": user_id,
}
workflow_trace_info = WorkflowTraceInfo(
workflow_data=workflow_run.to_dict(),
conversation_id=conversation_id,
workflow_id=workflow_id,
tenant_id=tenant_id,
workflow_run_id=workflow_run_id,
workflow_run_elapsed_time=workflow_run_elapsed_time,
workflow_run_status=workflow_run_status,
workflow_run_inputs=workflow_run_inputs,
workflow_run_outputs=workflow_run_outputs,
workflow_run_version=workflow_run_version,
error=error,
total_tokens=total_tokens,
file_list=file_list,
query=query,
metadata=metadata,
workflow_app_log_id=workflow_app_log_id,
message_id=message_id,
start_time=workflow_run.created_at,
end_time=workflow_run.finished_at,
)
return workflow_trace_info
def message_trace(self, message_id):
def message_trace(self, message_id: str | None):
if not message_id:
return {}
message_data = get_message_data(message_id)
if not message_data:
return {}
conversation_mode = db.session.query(Conversation.mode).filter_by(id=message_data.conversation_id).first()
conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id)
conversation_mode = db.session.scalars(conversation_mode_stmt).all()
if not conversation_mode or len(conversation_mode) == 0:
return {}
conversation_mode = conversation_mode[0]
created_at = message_data.created_at
inputs = message_data.message

View File

@@ -18,7 +18,7 @@ def filter_none_values(data: dict):
return new_data
def get_message_data(message_id):
def get_message_data(message_id: str):
return db.session.query(Message).filter(Message.id == message_id).first()