mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 03:16:51 +08:00
refactor: optimize database usage (#12071)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
@@ -17,9 +20,7 @@ from core.app.entities.task_entities import (
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.moderation.output_moderation import ModerationRule, OutputModeration
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import EndUser, Message
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -36,7 +37,6 @@ class BasedGenerateTaskPipeline:
|
||||
self,
|
||||
application_generate_entity: AppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -48,18 +48,11 @@ class BasedGenerateTaskPipeline:
|
||||
"""
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._queue_manager = queue_manager
|
||||
self._user = user
|
||||
self._start_at = time.perf_counter()
|
||||
self._output_moderation_handler = self._init_output_moderation()
|
||||
self._stream = stream
|
||||
|
||||
def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None):
|
||||
"""
|
||||
Handle error event.
|
||||
:param event: event
|
||||
:param message: message
|
||||
:return:
|
||||
"""
|
||||
def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
|
||||
logger.debug("error: %s", event.error)
|
||||
e = event.error
|
||||
err: Exception
|
||||
@@ -71,16 +64,17 @@ class BasedGenerateTaskPipeline:
|
||||
else:
|
||||
err = Exception(e.description if getattr(e, "description", None) is not None else str(e))
|
||||
|
||||
if message:
|
||||
refetch_message = db.session.query(Message).filter(Message.id == message.id).first()
|
||||
if not message_id or not session:
|
||||
return err
|
||||
|
||||
if refetch_message:
|
||||
err_desc = self._error_to_desc(err)
|
||||
refetch_message.status = "error"
|
||||
refetch_message.error = err_desc
|
||||
|
||||
db.session.commit()
|
||||
stmt = select(Message).where(Message.id == message_id)
|
||||
message = session.scalar(stmt)
|
||||
if not message:
|
||||
return err
|
||||
|
||||
err_desc = self._error_to_desc(err)
|
||||
message.status = "error"
|
||||
message.error = err_desc
|
||||
return err
|
||||
|
||||
def _error_to_desc(self, e: Exception) -> str:
|
||||
|
||||
@@ -5,6 +5,9 @@ from collections.abc import Generator
|
||||
from threading import Thread
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
@@ -55,8 +58,7 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import AppMode, Conversation, EndUser, Message, MessageAgentThought
|
||||
from models.model import AppMode, Conversation, Message, MessageAgentThought
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -77,23 +79,21 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize GenerateTaskPipeline.
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param conversation: conversation
|
||||
:param message: message
|
||||
:param user: user
|
||||
:param stream: stream
|
||||
"""
|
||||
super().__init__(application_generate_entity, queue_manager, user, stream)
|
||||
super().__init__(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
stream=stream,
|
||||
)
|
||||
self._model_config = application_generate_entity.model_conf
|
||||
self._app_config = application_generate_entity.app_config
|
||||
self._conversation = conversation
|
||||
self._message = message
|
||||
|
||||
self._conversation_id = conversation.id
|
||||
self._conversation_mode = conversation.mode
|
||||
|
||||
self._message_id = message.id
|
||||
self._message_created_at = int(message.created_at.timestamp())
|
||||
|
||||
self._task_state = EasyUITaskState(
|
||||
llm_result=LLMResult(
|
||||
@@ -113,18 +113,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
CompletionAppBlockingResponse,
|
||||
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
|
||||
]:
|
||||
"""
|
||||
Process generate task pipeline.
|
||||
:return:
|
||||
"""
|
||||
db.session.refresh(self._conversation)
|
||||
db.session.refresh(self._message)
|
||||
db.session.close()
|
||||
|
||||
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
|
||||
# start generate conversation name thread
|
||||
self._conversation_name_generate_thread = self._generate_conversation_name(
|
||||
self._conversation, self._application_generate_entity.query or ""
|
||||
conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
|
||||
)
|
||||
|
||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||
@@ -148,15 +140,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
if self._task_state.metadata:
|
||||
extras["metadata"] = self._task_state.metadata
|
||||
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
|
||||
if self._conversation.mode == AppMode.COMPLETION.value:
|
||||
if self._conversation_mode == AppMode.COMPLETION.value:
|
||||
response = CompletionAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=CompletionAppBlockingResponse.Data(
|
||||
id=self._message.id,
|
||||
mode=self._conversation.mode,
|
||||
message_id=self._message.id,
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
message_id=self._message_id,
|
||||
answer=cast(str, self._task_state.llm_result.message.content),
|
||||
created_at=int(self._message.created_at.timestamp()),
|
||||
created_at=self._message_created_at,
|
||||
**extras,
|
||||
),
|
||||
)
|
||||
@@ -164,12 +156,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
response = ChatbotAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=ChatbotAppBlockingResponse.Data(
|
||||
id=self._message.id,
|
||||
mode=self._conversation.mode,
|
||||
conversation_id=self._conversation.id,
|
||||
message_id=self._message.id,
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
conversation_id=self._conversation_id,
|
||||
message_id=self._message_id,
|
||||
answer=cast(str, self._task_state.llm_result.message.content),
|
||||
created_at=int(self._message.created_at.timestamp()),
|
||||
created_at=self._message_created_at,
|
||||
**extras,
|
||||
),
|
||||
)
|
||||
@@ -190,15 +182,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
for stream_response in generator:
|
||||
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
|
||||
yield CompletionAppStreamResponse(
|
||||
message_id=self._message.id,
|
||||
created_at=int(self._message.created_at.timestamp()),
|
||||
message_id=self._message_id,
|
||||
created_at=self._message_created_at,
|
||||
stream_response=stream_response,
|
||||
)
|
||||
else:
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id=self._conversation.id,
|
||||
message_id=self._message.id,
|
||||
created_at=int(self._message.created_at.timestamp()),
|
||||
conversation_id=self._conversation_id,
|
||||
message_id=self._message_id,
|
||||
created_at=self._message_created_at,
|
||||
stream_response=stream_response,
|
||||
)
|
||||
|
||||
@@ -265,7 +257,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
event = message.event
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
err = self._handle_error(event, self._message)
|
||||
with Session(db.engine) as session:
|
||||
err = self._handle_error(event=event, session=session, message_id=self._message_id)
|
||||
session.commit()
|
||||
yield self._error_to_stream_response(err)
|
||||
break
|
||||
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
|
||||
@@ -283,10 +277,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
self._task_state.llm_result.message.content = output_moderation_answer
|
||||
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
|
||||
|
||||
# Save message
|
||||
self._save_message(trace_manager)
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
with Session(db.engine) as session:
|
||||
# Save message
|
||||
self._save_message(session=session, trace_manager=trace_manager)
|
||||
session.commit()
|
||||
message_end_resp = self._message_end_to_stream_response()
|
||||
yield message_end_resp
|
||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||
self._handle_retriever_resources(event)
|
||||
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||
@@ -320,9 +316,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
self._task_state.llm_result.message.content = current_content
|
||||
|
||||
if isinstance(event, QueueLLMChunkEvent):
|
||||
yield self._message_to_stream_response(cast(str, delta_text), self._message.id)
|
||||
yield self._message_to_stream_response(
|
||||
answer=cast(str, delta_text),
|
||||
message_id=self._message_id,
|
||||
)
|
||||
else:
|
||||
yield self._agent_message_to_stream_response(cast(str, delta_text), self._message.id)
|
||||
yield self._agent_message_to_stream_response(
|
||||
answer=cast(str, delta_text),
|
||||
message_id=self._message_id,
|
||||
)
|
||||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
yield self._message_replace_to_stream_response(answer=event.text)
|
||||
elif isinstance(event, QueuePingEvent):
|
||||
@@ -334,7 +336,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
if self._conversation_name_generate_thread:
|
||||
self._conversation_name_generate_thread.join()
|
||||
|
||||
def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None:
|
||||
def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None:
|
||||
"""
|
||||
Save message.
|
||||
:return:
|
||||
@@ -342,53 +344,46 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
llm_result = self._task_state.llm_result
|
||||
usage = llm_result.usage
|
||||
|
||||
message = db.session.query(Message).filter(Message.id == self._message.id).first()
|
||||
message_stmt = select(Message).where(Message.id == self._message_id)
|
||||
message = session.scalar(message_stmt)
|
||||
if not message:
|
||||
raise Exception(f"Message {self._message.id} not found")
|
||||
self._message = message
|
||||
conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
|
||||
raise ValueError(f"message {self._message_id} not found")
|
||||
conversation_stmt = select(Conversation).where(Conversation.id == self._conversation_id)
|
||||
conversation = session.scalar(conversation_stmt)
|
||||
if not conversation:
|
||||
raise Exception(f"Conversation {self._conversation.id} not found")
|
||||
self._conversation = conversation
|
||||
raise ValueError(f"Conversation {self._conversation_id} not found")
|
||||
|
||||
self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
self._model_config.mode, self._task_state.llm_result.prompt_messages
|
||||
)
|
||||
self._message.message_tokens = usage.prompt_tokens
|
||||
self._message.message_unit_price = usage.prompt_unit_price
|
||||
self._message.message_price_unit = usage.prompt_price_unit
|
||||
self._message.answer = (
|
||||
message.message_tokens = usage.prompt_tokens
|
||||
message.message_unit_price = usage.prompt_unit_price
|
||||
message.message_price_unit = usage.prompt_price_unit
|
||||
message.answer = (
|
||||
PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
|
||||
if llm_result.message.content
|
||||
else ""
|
||||
)
|
||||
self._message.answer_tokens = usage.completion_tokens
|
||||
self._message.answer_unit_price = usage.completion_unit_price
|
||||
self._message.answer_price_unit = usage.completion_price_unit
|
||||
self._message.provider_response_latency = time.perf_counter() - self._start_at
|
||||
self._message.total_price = usage.total_price
|
||||
self._message.currency = usage.currency
|
||||
self._message.message_metadata = (
|
||||
message.answer_tokens = usage.completion_tokens
|
||||
message.answer_unit_price = usage.completion_unit_price
|
||||
message.answer_price_unit = usage.completion_price_unit
|
||||
message.provider_response_latency = time.perf_counter() - self._start_at
|
||||
message.total_price = usage.total_price
|
||||
message.currency = usage.currency
|
||||
message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id
|
||||
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
|
||||
)
|
||||
)
|
||||
|
||||
message_was_created.send(
|
||||
self._message,
|
||||
message,
|
||||
application_generate_entity=self._application_generate_entity,
|
||||
conversation=self._conversation,
|
||||
is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT}
|
||||
and hasattr(self._application_generate_entity, "conversation_id")
|
||||
and self._application_generate_entity.conversation_id is None,
|
||||
extras=self._application_generate_entity.extras,
|
||||
)
|
||||
|
||||
def _handle_stop(self, event: QueueStopEvent) -> None:
|
||||
@@ -434,7 +429,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
|
||||
return MessageEndStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=self._message.id,
|
||||
id=self._message_id,
|
||||
metadata=extras.get("metadata", {}),
|
||||
)
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ class MessageCycleManage:
|
||||
]
|
||||
_task_state: Union[EasyUITaskState, WorkflowTaskState]
|
||||
|
||||
def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]:
|
||||
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
|
||||
"""
|
||||
Generate conversation name.
|
||||
:param conversation: conversation
|
||||
@@ -56,7 +56,7 @@ class MessageCycleManage:
|
||||
target=self._generate_conversation_name_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"conversation_id": conversation.id,
|
||||
"conversation_id": conversation_id,
|
||||
"query": query,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ from datetime import UTC, datetime
|
||||
from typing import Any, Optional, Union, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
@@ -63,27 +64,34 @@ from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError
|
||||
|
||||
class WorkflowCycleManage:
|
||||
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
|
||||
_workflow: Workflow
|
||||
_user: Union[Account, EndUser]
|
||||
_task_state: WorkflowTaskState
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
||||
|
||||
def _handle_workflow_run_start(self) -> WorkflowRun:
|
||||
max_sequence = (
|
||||
db.session.query(db.func.max(WorkflowRun.sequence_number))
|
||||
.filter(WorkflowRun.tenant_id == self._workflow.tenant_id)
|
||||
.filter(WorkflowRun.app_id == self._workflow.app_id)
|
||||
.scalar()
|
||||
or 0
|
||||
def _handle_workflow_run_start(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
workflow_id: str,
|
||||
user_id: str,
|
||||
created_by_role: CreatedByRole,
|
||||
) -> WorkflowRun:
|
||||
workflow_stmt = select(Workflow).where(Workflow.id == workflow_id)
|
||||
workflow = session.scalar(workflow_stmt)
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found: {workflow_id}")
|
||||
|
||||
max_sequence_stmt = select(func.max(WorkflowRun.sequence_number)).where(
|
||||
WorkflowRun.tenant_id == workflow.tenant_id,
|
||||
WorkflowRun.app_id == workflow.app_id,
|
||||
)
|
||||
max_sequence = session.scalar(max_sequence_stmt) or 0
|
||||
new_sequence_number = max_sequence + 1
|
||||
|
||||
inputs = {**self._application_generate_entity.inputs}
|
||||
for key, value in (self._workflow_system_variables or {}).items():
|
||||
if key.value == "conversation":
|
||||
continue
|
||||
|
||||
inputs[f"sys.{key.value}"] = value
|
||||
|
||||
triggered_from = (
|
||||
@@ -96,33 +104,32 @@ class WorkflowCycleManage:
|
||||
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
|
||||
|
||||
# init workflow run
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = WorkflowRun()
|
||||
system_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID]
|
||||
workflow_run.id = system_id or str(uuid4())
|
||||
workflow_run.tenant_id = self._workflow.tenant_id
|
||||
workflow_run.app_id = self._workflow.app_id
|
||||
workflow_run.sequence_number = new_sequence_number
|
||||
workflow_run.workflow_id = self._workflow.id
|
||||
workflow_run.type = self._workflow.type
|
||||
workflow_run.triggered_from = triggered_from.value
|
||||
workflow_run.version = self._workflow.version
|
||||
workflow_run.graph = self._workflow.graph
|
||||
workflow_run.inputs = json.dumps(inputs)
|
||||
workflow_run.status = WorkflowRunStatus.RUNNING
|
||||
workflow_run.created_by_role = (
|
||||
CreatedByRole.ACCOUNT if isinstance(self._user, Account) else CreatedByRole.END_USER
|
||||
)
|
||||
workflow_run.created_by = self._user.id
|
||||
workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4()))
|
||||
|
||||
session.add(workflow_run)
|
||||
session.commit()
|
||||
workflow_run = WorkflowRun()
|
||||
workflow_run.id = workflow_run_id
|
||||
workflow_run.tenant_id = workflow.tenant_id
|
||||
workflow_run.app_id = workflow.app_id
|
||||
workflow_run.sequence_number = new_sequence_number
|
||||
workflow_run.workflow_id = workflow.id
|
||||
workflow_run.type = workflow.type
|
||||
workflow_run.triggered_from = triggered_from.value
|
||||
workflow_run.version = workflow.version
|
||||
workflow_run.graph = workflow.graph
|
||||
workflow_run.inputs = json.dumps(inputs)
|
||||
workflow_run.status = WorkflowRunStatus.RUNNING
|
||||
workflow_run.created_by_role = created_by_role
|
||||
workflow_run.created_by = user_id
|
||||
workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
session.add(workflow_run)
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _handle_workflow_run_success(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
workflow_run: WorkflowRun,
|
||||
start_at: float,
|
||||
total_tokens: int,
|
||||
@@ -141,7 +148,7 @@ class WorkflowCycleManage:
|
||||
:param conversation_id: conversation id
|
||||
:return:
|
||||
"""
|
||||
workflow_run = self._refetch_workflow_run(workflow_run.id)
|
||||
workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id)
|
||||
|
||||
outputs = WorkflowEntry.handle_special_values(outputs)
|
||||
|
||||
@@ -152,9 +159,6 @@ class WorkflowCycleManage:
|
||||
workflow_run.total_steps = total_steps
|
||||
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_run)
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
@@ -165,12 +169,12 @@ class WorkflowCycleManage:
|
||||
)
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _handle_workflow_run_partial_success(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
workflow_run: WorkflowRun,
|
||||
start_at: float,
|
||||
total_tokens: int,
|
||||
@@ -190,7 +194,7 @@ class WorkflowCycleManage:
|
||||
:param conversation_id: conversation id
|
||||
:return:
|
||||
"""
|
||||
workflow_run = self._refetch_workflow_run(workflow_run.id)
|
||||
workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id)
|
||||
|
||||
outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
|
||||
|
||||
@@ -201,8 +205,6 @@ class WorkflowCycleManage:
|
||||
workflow_run.total_steps = total_steps
|
||||
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_run.exceptions_count = exceptions_count
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_run)
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
@@ -214,12 +216,12 @@ class WorkflowCycleManage:
|
||||
)
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _handle_workflow_run_failed(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
workflow_run: WorkflowRun,
|
||||
start_at: float,
|
||||
total_tokens: int,
|
||||
@@ -240,7 +242,7 @@ class WorkflowCycleManage:
|
||||
:param error: error message
|
||||
:return:
|
||||
"""
|
||||
workflow_run = self._refetch_workflow_run(workflow_run.id)
|
||||
workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id)
|
||||
|
||||
workflow_run.status = status.value
|
||||
workflow_run.error = error
|
||||
@@ -249,21 +251,18 @@ class WorkflowCycleManage:
|
||||
workflow_run.total_steps = total_steps
|
||||
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_run.exceptions_count = exceptions_count
|
||||
db.session.commit()
|
||||
|
||||
running_workflow_node_executions = (
|
||||
db.session.query(WorkflowNodeExecution)
|
||||
.filter(
|
||||
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
|
||||
WorkflowNodeExecution.app_id == workflow_run.app_id,
|
||||
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
|
||||
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
|
||||
)
|
||||
.all()
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
|
||||
WorkflowNodeExecution.app_id == workflow_run.app_id,
|
||||
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
|
||||
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
|
||||
)
|
||||
|
||||
running_workflow_node_executions = session.scalars(stmt).all()
|
||||
|
||||
for workflow_node_execution in running_workflow_node_executions:
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
@@ -271,13 +270,6 @@ class WorkflowCycleManage:
|
||||
workflow_node_execution.elapsed_time = (
|
||||
workflow_node_execution.finished_at - workflow_node_execution.created_at
|
||||
).total_seconds()
|
||||
db.session.commit()
|
||||
|
||||
db.session.close()
|
||||
|
||||
# with Session(db.engine, expire_on_commit=False) as session:
|
||||
# session.add(workflow_run)
|
||||
# session.refresh(workflow_run)
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
@@ -485,14 +477,14 @@ class WorkflowCycleManage:
|
||||
#################################################
|
||||
|
||||
def _workflow_start_to_stream_response(
|
||||
self, task_id: str, workflow_run: WorkflowRun
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
) -> WorkflowStartStreamResponse:
|
||||
"""
|
||||
Workflow start to stream response.
|
||||
:param task_id: task id
|
||||
:param workflow_run: workflow run
|
||||
:return:
|
||||
"""
|
||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
||||
_ = session
|
||||
return WorkflowStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
@@ -506,36 +498,32 @@ class WorkflowCycleManage:
|
||||
)
|
||||
|
||||
def _workflow_finish_to_stream_response(
|
||||
self, task_id: str, workflow_run: WorkflowRun
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
) -> WorkflowFinishStreamResponse:
|
||||
"""
|
||||
Workflow finish to stream response.
|
||||
:param task_id: task id
|
||||
:param workflow_run: workflow run
|
||||
:return:
|
||||
"""
|
||||
# Attach WorkflowRun to an active session so "created_by_role" can be accessed.
|
||||
workflow_run = db.session.merge(workflow_run)
|
||||
|
||||
# Refresh to ensure any expired attributes are fully loaded
|
||||
db.session.refresh(workflow_run)
|
||||
|
||||
created_by = None
|
||||
if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value:
|
||||
created_by_account = workflow_run.created_by_account
|
||||
if created_by_account:
|
||||
if workflow_run.created_by_role == CreatedByRole.ACCOUNT:
|
||||
stmt = select(Account).where(Account.id == workflow_run.created_by)
|
||||
account = session.scalar(stmt)
|
||||
if account:
|
||||
created_by = {
|
||||
"id": created_by_account.id,
|
||||
"name": created_by_account.name,
|
||||
"email": created_by_account.email,
|
||||
"id": account.id,
|
||||
"name": account.name,
|
||||
"email": account.email,
|
||||
}
|
||||
elif workflow_run.created_by_role == CreatedByRole.END_USER:
|
||||
stmt = select(EndUser).where(EndUser.id == workflow_run.created_by)
|
||||
end_user = session.scalar(stmt)
|
||||
if end_user:
|
||||
created_by = {
|
||||
"id": end_user.id,
|
||||
"user": end_user.session_id,
|
||||
}
|
||||
else:
|
||||
created_by_end_user = workflow_run.created_by_end_user
|
||||
if created_by_end_user:
|
||||
created_by = {
|
||||
"id": created_by_end_user.id,
|
||||
"user": created_by_end_user.session_id,
|
||||
}
|
||||
raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}")
|
||||
|
||||
return WorkflowFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
@@ -895,14 +883,14 @@ class WorkflowCycleManage:
|
||||
|
||||
return None
|
||||
|
||||
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
|
||||
def _refetch_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
|
||||
"""
|
||||
Refetch workflow run
|
||||
:param workflow_run_id: workflow run id
|
||||
:return:
|
||||
"""
|
||||
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
|
||||
|
||||
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
|
||||
workflow_run = session.scalar(stmt)
|
||||
if not workflow_run:
|
||||
raise WorkflowRunNotFoundError(workflow_run_id)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user