refactor: optimize database usage (#12071)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2024-12-25 16:24:52 +08:00
committed by GitHub
parent b281a80150
commit 83ea931e3c
12 changed files with 574 additions and 561 deletions

View File

@@ -3,6 +3,8 @@ import time
from collections.abc import Generator
from typing import Any, Optional, Union
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -50,6 +52,7 @@ from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import SystemVariableKey
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatedByRole
from models.model import EndUser
from models.workflow import (
Workflow,
@@ -68,8 +71,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
@@ -83,25 +84,27 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
user: Union[Account, EndUser],
stream: bool,
) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
:param workflow: workflow
:param queue_manager: queue manager
:param user: user
:param stream: is streamed
"""
super().__init__(application_generate_entity, queue_manager, user, stream)
super().__init__(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
stream=stream,
)
if isinstance(self._user, EndUser):
user_id = self._user.session_id
if isinstance(user, EndUser):
self._user_id = user.session_id
self._created_by_role = CreatedByRole.END_USER
elif isinstance(user, Account):
self._user_id = user.id
self._created_by_role = CreatedByRole.ACCOUNT
else:
user_id = self._user.id
raise ValueError(f"Invalid user type: {type(user)}")
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._workflow = workflow
self._workflow_system_variables = {
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.USER_ID: self._user_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
@@ -115,10 +118,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
Process generate task pipeline.
:return:
"""
db.session.refresh(self._workflow)
db.session.refresh(self._user)
db.session.close()
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
return self._to_stream_response(generator)
@@ -185,7 +184,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
features_dict = self._workflow_features_dict
if (
features_dict.get("text_to_speech")
@@ -242,18 +241,26 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event)
err = self._handle_error(event=event)
yield self._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
# init workflow run
workflow_run = self._handle_workflow_run_start()
yield self._workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
with Session(db.engine) as session:
# init workflow run
workflow_run = self._handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
user_id=self._user_id,
created_by_role=self._created_by_role,
)
start_resp = self._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
yield start_resp
elif isinstance(
event,
QueueNodeRetryEvent,
@@ -350,22 +357,28 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
conversation_id=None,
trace_manager=trace_manager,
)
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_success(
session=session,
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
self._save_workflow_app_log(workflow_run)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
workflow_finish_resp = self._workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
)
session.commit()
yield workflow_finish_resp
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
raise ValueError("workflow run not initialized.")
@@ -373,49 +386,58 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_partial_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
)
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_partial_success(
session=session,
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
self._save_workflow_app_log(workflow_run)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
workflow_finish_resp = self._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
yield workflow_finish_resp
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not workflow_run:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
)
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_failed(
session=session,
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
)
# save workflow app log
self._save_workflow_app_log(workflow_run)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
workflow_finish_resp = self._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
yield workflow_finish_resp
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
if delta_text is None:
@@ -435,7 +457,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if tts_publisher:
tts_publisher.publish(None)
def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
def _save_workflow_app_log(self, *, session: Session, workflow_run: WorkflowRun) -> None:
"""
Save workflow app log.
:return:
@@ -457,12 +479,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
workflow_app_log.workflow_id = workflow_run.workflow_id
workflow_app_log.workflow_run_id = workflow_run.id
workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user"
workflow_app_log.created_by = self._user.id
workflow_app_log.created_by_role = self._created_by_role
workflow_app_log.created_by = self._user_id
db.session.add(workflow_app_log)
db.session.commit()
db.session.close()
session.add(workflow_app_log)
def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: Optional[list[str]] = None