mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-23 09:46:53 +08:00
refactor: streamline initialization of application_generate_entity and task_state in task pipeline classes (#12326)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -58,7 +58,6 @@ from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowAppLog,
|
||||
WorkflowAppLogCreatedFrom,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
@@ -66,16 +65,11 @@ from models.workflow import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
|
||||
class WorkflowAppGenerateTaskPipeline:
|
||||
"""
|
||||
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
_task_state: WorkflowTaskState
|
||||
_application_generate_entity: WorkflowAppGenerateEntity
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
@@ -84,7 +78,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
stream=stream,
|
||||
@@ -101,19 +95,21 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
else:
|
||||
raise ValueError(f"Invalid user type: {type(user)}")
|
||||
|
||||
self._workflow_cycle_manager = WorkflowCycleManage(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_system_variables={
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
SystemVariableKey.USER_ID: user_session_id,
|
||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||
},
|
||||
)
|
||||
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_id = workflow.id
|
||||
self._workflow_features_dict = workflow.features_dict
|
||||
|
||||
self._workflow_system_variables = {
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
SystemVariableKey.USER_ID: user_session_id,
|
||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||
}
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._wip_workflow_node_executions = {}
|
||||
self._workflow_run_id = ""
|
||||
|
||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
@@ -122,7 +118,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
:return:
|
||||
"""
|
||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||
if self._stream:
|
||||
if self._base_task_pipeline._stream:
|
||||
return self._to_stream_response(generator)
|
||||
else:
|
||||
return self._to_blocking_response(generator)
|
||||
@@ -237,29 +233,29 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
"""
|
||||
graph_runtime_state = None
|
||||
|
||||
for queue_message in self._queue_manager.listen():
|
||||
for queue_message in self._base_task_pipeline._queue_manager.listen():
|
||||
event = queue_message.event
|
||||
|
||||
if isinstance(event, QueuePingEvent):
|
||||
yield self._ping_stream_response()
|
||||
yield self._base_task_pipeline._ping_stream_response()
|
||||
elif isinstance(event, QueueErrorEvent):
|
||||
err = self._handle_error(event=event)
|
||||
yield self._error_to_stream_response(err)
|
||||
err = self._base_task_pipeline._handle_error(event=event)
|
||||
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||
break
|
||||
elif isinstance(event, QueueWorkflowStartedEvent):
|
||||
# override graph runtime state
|
||||
graph_runtime_state = event.graph_runtime_state
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# init workflow run
|
||||
workflow_run = self._handle_workflow_run_start(
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
|
||||
session=session,
|
||||
workflow_id=self._workflow_id,
|
||||
user_id=self._user_id,
|
||||
created_by_role=self._created_by_role,
|
||||
)
|
||||
self._workflow_run_id = workflow_run.id
|
||||
start_resp = self._workflow_start_to_stream_response(
|
||||
start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
session.commit()
|
||||
@@ -271,12 +267,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
workflow_node_execution = self._handle_workflow_node_execution_retried(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
|
||||
session=session, workflow_run=workflow_run, event=event
|
||||
)
|
||||
response = self._workflow_node_retry_to_stream_response(
|
||||
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@@ -290,12 +288,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
workflow_node_execution = self._handle_node_execution_start(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
|
||||
session=session, workflow_run=workflow_run, event=event
|
||||
)
|
||||
node_start_response = self._workflow_node_start_to_stream_response(
|
||||
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@@ -306,9 +306,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if node_start_response:
|
||||
yield node_start_response
|
||||
elif isinstance(event, QueueNodeSucceededEvent):
|
||||
with Session(db.engine) as session:
|
||||
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
|
||||
node_success_response = self._workflow_node_finish_to_stream_response(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
|
||||
session=session, event=event
|
||||
)
|
||||
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@@ -319,12 +321,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if node_success_response:
|
||||
yield node_success_response
|
||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
||||
with Session(db.engine) as session:
|
||||
workflow_node_execution = self._handle_workflow_node_execution_failed(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
|
||||
session=session,
|
||||
event=event,
|
||||
)
|
||||
node_failed_response = self._workflow_node_finish_to_stream_response(
|
||||
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@@ -339,13 +341,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
parallel_start_resp = (
|
||||
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
)
|
||||
|
||||
yield parallel_start_resp
|
||||
@@ -354,13 +360,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
parallel_finish_resp = (
|
||||
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
)
|
||||
|
||||
yield parallel_finish_resp
|
||||
@@ -369,9 +379,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
iter_start_resp = self._workflow_iteration_start_to_stream_response(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@@ -384,9 +396,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
iter_next_resp = self._workflow_iteration_next_to_stream_response(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@@ -399,9 +413,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@@ -416,8 +432,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@@ -431,7 +447,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@@ -445,8 +461,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_partial_success(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@@ -461,7 +477,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
session.commit()
|
||||
@@ -473,8 +489,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@@ -492,7 +508,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
session.commit()
|
||||
|
||||
Reference in New Issue
Block a user