feat: Parallel Execution of Nodes in Workflows (#8192)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Yi <yxiaoisme@gmail.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
takatost
2024-09-10 15:23:16 +08:00
committed by GitHub
parent 5da0182800
commit dabfd74622
156 changed files with 11158 additions and 5605 deletions

View File

@@ -4,12 +4,10 @@ import os
import threading
import uuid
from collections.abc import Generator
from typing import Literal, Union, overload
from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
import contexts
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@@ -20,20 +18,15 @@ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGe
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser, Message
from models.workflow import ConversationVariable, Workflow
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@@ -60,13 +53,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
) -> dict: ...
def generate(
self, app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
):
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
) -> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
@@ -154,7 +148,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
node_id: str,
user: Account,
args: dict,
stream: bool = True):
stream: bool = True) \
-> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
@@ -171,16 +166,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
if args.get('inputs') is None:
raise ValueError('inputs is required')
extras = {
"auto_generate_conversation_name": False
}
# get conversation
conversation = None
conversation_id = args.get('conversation_id')
if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(
app_model=app_model,
@@ -191,14 +176,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
conversation_id=conversation.id if conversation else None,
conversation_id=None,
inputs={},
query='',
files=[],
user_id=user.id,
stream=stream,
invoke_from=InvokeFrom.DEBUGGER,
extras=extras,
extras={
"auto_generate_conversation_name": False
},
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id,
inputs=args['inputs']
@@ -211,17 +198,28 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
conversation=conversation,
conversation=None,
stream=stream
)
def _generate(self, *,
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Conversation | None = None,
stream: bool = True):
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Optional[Conversation] = None,
stream: bool = True) \
-> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
:param workflow: Workflow
:param user: account or end user
:param invoke_from: invoke from source
:param application_generate_entity: application generate entity
:param conversation: conversation
:param stream: is stream
"""
is_first_conversation = False
if not conversation:
is_first_conversation = True
@@ -236,7 +234,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# update conversation features
conversation.override_model_configs = workflow.features
db.session.commit()
# db.session.refresh(conversation)
db.session.refresh(conversation)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@@ -248,67 +246,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id=message.id
)
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
# Create conversation variables if they don't exist.
conversation_variables = [
ConversationVariable.from_variable(
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in conversation_variables]
session.commit()
# Increment dialogue count.
conversation.dialogue_count += 1
conversation_id = conversation.id
conversation_dialogue_count = conversation.dialogue_count
db.session.commit()
db.session.refresh(conversation)
inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files
user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id
# Create a variable pool.
system_inputs = {
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: conversation_id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
contexts.workflow_variable_pool.set(variable_pool)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'flask_app': current_app._get_current_object(), # type: ignore
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
'context': contextvars.copy_context(),
})
@@ -334,6 +277,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
def _generate_worker(self, flask_app: Flask,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
context: contextvars.Context) -> None:
"""
@@ -349,28 +293,19 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
var.set(val)
with flask_app.app_context():
try:
runner = AdvancedChatAppRunner()
if application_generate_entity.single_iteration_run:
single_iteration_run = application_generate_entity.single_iteration_run
runner.single_iteration_run(
app_id=application_generate_entity.app_config.app_id,
workflow_id=application_generate_entity.app_config.workflow_id,
queue_manager=queue_manager,
inputs=single_iteration_run.inputs,
node_id=single_iteration_run.node_id,
user_id=application_generate_entity.user_id
)
else:
# get message
message = self._get_message(message_id)
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
# chatbot app
runner = AdvancedChatAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
message=message
)
# chatbot app
runner = AdvancedChatAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
runner.run()
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:

View File

@@ -1,49 +1,67 @@
import logging
import os
import time
from collections.abc import Mapping
from typing import Any, Optional, cast
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent,
QueueStopEvent,
QueueTextChunkEvent,
)
from core.moderation.base import ModerationException
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models import App, Message, Workflow
from models.model import App, Conversation, EndUser, Message
from models.workflow import ConversationVariable, WorkflowType
logger = logging.getLogger(__name__)
class AdvancedChatAppRunner(AppRunner):
class AdvancedChatAppRunner(WorkflowBasedAppRunner):
"""
AdvancedChat Application Runner
"""
def run(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
message: Message,
def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message
) -> None:
"""
Run application
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param conversation: conversation
:param message: message
"""
super().__init__(queue_manager)
self.application_generate_entity = application_generate_entity
self.conversation = conversation
self.message = message
def run(self) -> None:
"""
Run application
:return:
"""
app_config = application_generate_entity.app_config
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
@@ -54,101 +72,133 @@ class AdvancedChatAppRunner(AppRunner):
if not workflow:
raise ValueError('Workflow not initialized')
inputs = application_generate_entity.inputs
query = application_generate_entity.query
user_id = None
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
# moderation
if self.handle_input_moderation(
queue_manager=queue_manager,
app_record=app_record,
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id,
):
return
workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
message=message,
query=query,
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
):
return
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs
)
else:
inputs = self.application_generate_entity.inputs
query = self.application_generate_entity.query
files = self.application_generate_entity.files
# moderation
if self.handle_input_moderation(
app_record=app_record,
app_generate_entity=self.application_generate_entity,
inputs=inputs,
query=query,
message_id=self.message.id
):
return
# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
message=self.message,
query=query,
app_generate_entity=self.application_generate_entity
):
return
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
# Create conversation variables if they don't exist.
conversation_variables = [
ConversationVariable.from_variable(
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in conversation_variables]
session.commit()
# Increment dialogue count.
self.conversation.dialogue_count += 1
conversation_dialogue_count = self.conversation.dialogue_count
db.session.commit()
# Create a variable pool.
system_inputs = {
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: self.conversation.id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
}
# init variable pool
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
# init graph
graph = self._init_graph(graph_config=workflow.graph_dict)
db.session.close()
workflow_callbacks: list[WorkflowCallback] = [
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
]
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
workflow=workflow,
user_id=application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
)
generator = workflow_entry.run(
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth,
)
def single_iteration_run(
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
) -> None:
"""
Single iteration run
"""
app_record = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError('App not found')
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
if not workflow:
raise ValueError('Workflow not initialized')
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
for event in generator:
self._handle_event(workflow_entry, event)
def handle_input_moderation(
self,
queue_manager: AppQueueManager,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str,
self,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str
) -> bool:
"""
Handle input moderation
:param queue_manager: application queue manager
:param app_record: app record
:param app_generate_entity: application generate entity
:param inputs: inputs
@@ -167,30 +217,23 @@ class AdvancedChatAppRunner(AppRunner):
message_id=message_id,
)
except ModerationException as e:
self._stream_output(
queue_manager=queue_manager,
self._complete_with_stream_output(
text=str(e),
stream=app_generate_entity.stream,
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION,
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
)
return True
return False
def handle_annotation_reply(
self,
app_record: App,
message: Message,
query: str,
queue_manager: AppQueueManager,
app_generate_entity: AdvancedChatAppGenerateEntity,
) -> bool:
def handle_annotation_reply(self, app_record: App,
message: Message,
query: str,
app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
"""
Handle annotation reply
:param app_record: app record
:param message: message
:param query: query
:param queue_manager: application queue manager
:param app_generate_entity: application generate entity
"""
# annotation reply
@@ -203,37 +246,32 @@ class AdvancedChatAppRunner(AppRunner):
)
if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER
self._publish_event(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id)
)
self._stream_output(
queue_manager=queue_manager,
self._complete_with_stream_output(
text=annotation_reply.content,
stream=app_generate_entity.stream,
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY,
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
)
return True
return False
def _stream_output(
self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy
) -> None:
def _complete_with_stream_output(self,
text: str,
stopped_by: QueueStopEvent.StopBy) -> None:
"""
Direct output
:param queue_manager: application queue manager
:param text: text
:param stream: stream
:return:
"""
if stream:
index = 0
for token in text:
queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER)
index += 1
time.sleep(0.01)
else:
queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER)
self._publish_event(
QueueTextChunkEvent(
text=text
)
)
queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER)
self._publish_event(
QueueStopEvent(stopped_by=stopped_by)
)

View File

@@ -2,9 +2,8 @@ import json
import logging
import time
from collections.abc import Generator
from typing import Any, Optional, Union, cast
from typing import Any, Optional, Union
import contexts
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -22,6 +21,9 @@ from core.app.entities.queue_entities import (
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
@@ -31,34 +33,28 @@ from core.app.entities.queue_entities import (
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
AdvancedChatTaskState,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ChatflowStreamGenerateRoute,
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
StreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.file.file_obj import FileVar
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from events.message_event import message_was_created
from extensions.ext_database import db
from models.account import Account
from models.model import Conversation, EndUser, Message
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowRunStatus,
)
@@ -69,16 +65,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_task_state: AdvancedChatTaskState
_task_state: WorkflowTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow: Workflow
_user: Union[Account, EndUser]
# Deprecated
_workflow_system_variables: dict[SystemVariableKey, Any]
_iteration_nested_relations: dict[str, list[str]]
def __init__(
self, application_generate_entity: AdvancedChatAppGenerateEntity,
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
@@ -106,7 +101,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._workflow = workflow
self._conversation = conversation
self._message = message
# Deprecated
self._workflow_system_variables = {
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
@@ -114,12 +108,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
SystemVariableKey.USER_ID: user_id,
}
self._task_state = AdvancedChatTaskState(
usage=LLMUsage.empty_usage()
)
self._task_state = WorkflowTaskState()
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
self._stream_generate_routes = self._get_stream_generate_routes()
self._conversation_name_generate_thread = None
def process(self):
@@ -140,6 +130,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
if self._stream:
return self._to_stream_response(generator)
else:
@@ -199,17 +190,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
publisher = None
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
@@ -220,9 +212,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# timeout
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not publisher:
if not tts_publisher:
break
audio_trunk = publisher.checkAndGetAudio()
audio_trunk = tts_publisher.checkAndGetAudio()
if audio_trunk is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
@@ -240,34 +232,34 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
def _process_stream_response(
self,
publisher: AppGeneratorTTSPublisher,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
:return:
"""
for message in self._queue_manager.listen():
if (message.event
and getattr(message.event, 'metadata', None)
and message.event.metadata.get('is_answer_previous_node', False)
and publisher):
publisher.publish(message=message)
elif (hasattr(message.event, 'execution_metadata')
and message.event.execution_metadata
and message.event.execution_metadata.get('is_answer_previous_node', False)
and publisher):
publisher.publish(message=message)
event = message.event
# init fake graph runtime state
graph_runtime_state = None
workflow_run = None
if isinstance(event, QueueErrorEvent):
for queue_message in self._queue_manager.listen():
event = queue_message.event
if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event, self._message)
yield self._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
workflow_run = self._handle_workflow_start()
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
# init workflow run
workflow_run = self._handle_workflow_run_start()
self._refetch_message()
self._message.workflow_run_id = workflow_run.id
db.session.commit()
@@ -279,133 +271,242 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
workflow_run=workflow_run
)
elif isinstance(event, QueueNodeStartedEvent):
workflow_node_execution = self._handle_node_start(event)
if not workflow_run:
raise Exception('Workflow run not initialized.')
# search stream_generate_routes if node id is answer start at node
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes:
self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id]
# reset current route position to 0
self._task_state.current_stream_generate_state.current_route_position = 0
workflow_node_execution = self._handle_node_execution_start(
workflow_run=workflow_run,
event=event
)
# generate stream outputs when node started
yield from self._generate_stream_outputs_when_node_started()
yield self._workflow_node_start_to_stream_response(
response = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
workflow_node_execution = self._handle_node_finished(event)
# stream outputs when node finished
generator = self._generate_stream_outputs_when_node_finished()
if generator:
yield from generator
if response:
yield response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
yield self._workflow_node_finish_to_stream_response(
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
if isinstance(event, QueueNodeFailedEvent):
yield from self._handle_iteration_exception(
task_id=self._application_generate_entity.task_id,
error=f'Child node failed: {event.error}'
)
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
if isinstance(event, QueueIterationNextEvent):
# clear ran node execution infos of current iteration
iteration_relations = self._iteration_nested_relations.get(event.node_id)
if iteration_relations:
for node_id in iteration_relations:
self._task_state.ran_node_execution_infos.pop(node_id, None)
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
self._handle_iteration_operation(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(
event, conversation_id=self._conversation.id, trace_manager=trace_manager
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
if workflow_run:
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationNextEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=json.dumps(event.outputs) if event.outputs else None,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
self._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(),
PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED,
error=event.error,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break
elif isinstance(event, QueueStopEvent):
if workflow_run and graph_runtime_state:
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.STOPPED,
error=event.get_stop_reason(),
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
if workflow_run.status == WorkflowRunStatus.FAILED.value:
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break
if isinstance(event, QueueStopEvent):
# Save message
self._save_message()
yield self._message_end_to_stream_response()
break
else:
self._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(),
PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
# Save message
self._save_message()
self._save_message(graph_runtime_state=graph_runtime_state)
yield self._message_end_to_stream_response()
break
elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event)
self._refetch_message()
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
db.session.commit()
db.session.refresh(self._message)
db.session.close()
elif isinstance(event, QueueAnnotationReplyEvent):
self._handle_annotation_reply(event)
self._refetch_message()
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
db.session.commit()
db.session.refresh(self._message)
db.session.close()
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
if delta_text is None:
continue
if not self._is_stream_out_support(
event=event
):
continue
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(delta_text)
if should_direct_answer:
continue
# only publish tts message at text chunk streaming
if tts_publisher:
tts_publisher.publish(message=queue_message)
self._task_state.answer += delta_text
yield self._message_to_stream_response(delta_text, self._message.id)
elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation
yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
# Save message
self._save_message(graph_runtime_state=graph_runtime_state)
yield self._message_end_to_stream_response()
else:
continue
if publisher:
publisher.publish(None)
# publish None when task finished
if tts_publisher:
tts_publisher.publish(None)
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(self) -> None:
def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
"""
Save message.
:return:
"""
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._refetch_message()
self._message.answer = self._task_state.answer
self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
if self._task_state.metadata and self._task_state.metadata.get('usage'):
usage = LLMUsage(**self._task_state.metadata['usage'])
if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage
self._message.message_tokens = usage.prompt_tokens
self._message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit
@@ -432,7 +533,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
extras = {}
if self._task_state.metadata:
extras['metadata'] = self._task_state.metadata
extras['metadata'] = self._task_state.metadata.copy()
if 'annotation_reply' in extras['metadata']:
del extras['metadata']['annotation_reply']
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
@@ -440,323 +544,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
**extras
)
def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
"""
Get stream generate routes.
:return:
"""
# find all answer nodes
graph = self._workflow.graph_dict
answer_node_configs = [
node for node in graph['nodes']
if node.get('data', {}).get('type') == NodeType.ANSWER.value
]
# parse stream output node value selectors of answer nodes
stream_generate_routes = {}
for node_config in answer_node_configs:
# get generate route for stream output
answer_node_id = node_config['id']
generate_route = AnswerNode.extract_generate_route_selectors(node_config)
start_node_ids = self._get_answer_start_at_node_ids(graph, answer_node_id)
if not start_node_ids:
continue
for start_node_id in start_node_ids:
stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute(
answer_node_id=answer_node_id,
generate_route=generate_route
)
return stream_generate_routes
def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
-> list[str]:
"""
Get answer start at node id.
:param graph: graph
:param target_node_id: target node ID
:return:
"""
nodes = graph.get('nodes')
edges = graph.get('edges')
# fetch all ingoing edges from source node
ingoing_edges = []
for edge in edges:
if edge.get('target') == target_node_id:
ingoing_edges.append(edge)
if not ingoing_edges:
# check if it's the first node in the iteration
target_node = next((node for node in nodes if node.get('id') == target_node_id), None)
if not target_node:
return []
node_iteration_id = target_node.get('data', {}).get('iteration_id')
# get iteration start node id
for node in nodes:
if node.get('id') == node_iteration_id:
if node.get('data', {}).get('start_node_id') == target_node_id:
return [target_node_id]
return []
start_node_ids = []
for ingoing_edge in ingoing_edges:
source_node_id = ingoing_edge.get('source')
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
if not source_node:
continue
node_type = source_node.get('data', {}).get('type')
node_iteration_id = source_node.get('data', {}).get('iteration_id')
iteration_start_node_id = None
if node_iteration_id:
iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
if node_type in [
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value,
NodeType.ITERATION.value,
NodeType.LOOP.value
]:
start_node_id = target_node_id
start_node_ids.append(start_node_id)
elif node_type == NodeType.START.value or \
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
sub_start_node_ids = self._get_answer_start_at_node_ids(graph, source_node_id)
if sub_start_node_ids:
start_node_ids.extend(sub_start_node_ids)
return start_node_ids
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
"""
Get iteration nested relations.
:param graph: graph
:return:
"""
nodes = graph.get('nodes')
iteration_ids = [node.get('id') for node in nodes
if node.get('data', {}).get('type') in [
NodeType.ITERATION.value,
NodeType.LOOP.value,
]]
return {
iteration_id: [
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
] for iteration_id in iteration_ids
}
def _generate_stream_outputs_when_node_started(self) -> Generator:
"""
Generate stream outputs.
:return:
"""
if self._task_state.current_stream_generate_state:
route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:
]
for route_chunk in route_chunks:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text)
if should_direct_answer:
continue
self._task_state.answer += route_chunk.text
yield self._message_to_stream_response(route_chunk.text, self._message.id)
else:
break
self._task_state.current_stream_generate_state.current_route_position += 1
# all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route
):
self._task_state.current_stream_generate_state = None
def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]:
"""
Generate stream outputs.
:return:
"""
if not self._task_state.current_stream_generate_state:
return
route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:]
for route_chunk in route_chunks:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
self._task_state.answer += route_chunk.text
yield self._message_to_stream_response(route_chunk.text, self._message.id)
else:
value = None
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
if not value_selector:
self._task_state.current_stream_generate_state.current_route_position += 1
continue
route_chunk_node_id = value_selector[0]
if route_chunk_node_id == 'sys':
# system variable
value = contexts.workflow_variable_pool.get().get(value_selector)
if value:
value = value.text
elif route_chunk_node_id in self._iteration_nested_relations:
# it's a iteration variable
if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
continue
iteration_state = self._iteration_state.current_iterations[route_chunk_node_id]
iterator = iteration_state.inputs
if not iterator:
continue
iterator_selector = iterator.get('iterator_selector', [])
if value_selector[1] == 'index':
value = iteration_state.current_index
elif value_selector[1] == 'item':
value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len(
iterator_selector
) else None
else:
# check chunk node id is before current node id or equal to current node id
if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
break
latest_node_execution_info = self._task_state.latest_node_execution_info
# get route chunk node execution info
route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
if (route_chunk_node_execution_info.node_type == NodeType.LLM
and latest_node_execution_info.node_type == NodeType.LLM):
# only LLM support chunk stream output
self._task_state.current_stream_generate_state.current_route_position += 1
continue
# get route chunk node execution
route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id
).first()
outputs = route_chunk_node_execution.outputs_dict
# get value from outputs
value = None
for key in value_selector[1:]:
if not value:
value = outputs.get(key) if outputs else None
else:
value = value.get(key)
if value is not None:
text = ''
if isinstance(value, str | int | float):
text = str(value)
elif isinstance(value, FileVar):
# convert file to markdown
text = value.to_markdown()
elif isinstance(value, dict):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
if file_vars:
file_var = file_vars[0]
try:
file_var_obj = FileVar(**file_var)
# convert file to markdown
text = file_var_obj.to_markdown()
except Exception as e:
logger.error(f'Error creating file var: {e}')
if not text:
# other types
text = json.dumps(value, ensure_ascii=False)
elif isinstance(value, list):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
for file_var in file_vars:
try:
file_var_obj = FileVar(**file_var)
except Exception as e:
logger.error(f'Error creating file var: {e}')
continue
# convert file to markdown
text = file_var_obj.to_markdown() + ' '
text = text.strip()
if not text and value:
# other types
text = json.dumps(value, ensure_ascii=False)
if text:
self._task_state.answer += text
yield self._message_to_stream_response(text, self._message.id)
self._task_state.current_stream_generate_state.current_route_position += 1
# all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route
):
self._task_state.current_stream_generate_state = None
def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
"""
Is stream out support
:param event: queue text chunk event
:return:
"""
if not event.metadata:
return True
if 'node_id' not in event.metadata:
return True
node_type = event.metadata.get('node_type')
stream_output_value_selector = event.metadata.get('value_selector')
if not stream_output_value_selector:
return False
if not self._task_state.current_stream_generate_state:
return False
route_chunk = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position]
if route_chunk.type != 'var':
return False
if node_type != NodeType.LLM:
# only LLM support chunk stream output
return False
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
# check chunk node id is before current node id or equal to current node id
if value_selector != stream_output_value_selector:
return False
return True
def _handle_output_moderation_chunk(self, text: str) -> bool:
"""
Handle output moderation chunk.
@@ -782,3 +569,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._output_moderation_handler.append_new_token(text)
return False
def _refetch_message(self) -> None:
"""
Refetch message.
:return:
"""
message = db.session.query(Message).filter(Message.id == self._message.id).first()
if message:
self._message = message

View File

@@ -1,203 +0,0 @@
from typing import Any, Optional
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow
class WorkflowEventTriggerCallback(WorkflowCallback):
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager
def on_workflow_run_started(self) -> None:
"""
Workflow run started
"""
self._queue_manager.publish(
QueueWorkflowStartedEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_succeeded(self) -> None:
"""
Workflow run succeeded
"""
self._queue_manager.publish(
QueueWorkflowSucceededEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_failed(self, error: str) -> None:
"""
Workflow run failed
"""
self._queue_manager.publish(
QueueWorkflowFailedEvent(
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_started(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> None:
"""
Workflow node execute started
"""
self._queue_manager.publish(
QueueNodeStartedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
node_run_index=node_run_index,
predecessor_node_id=predecessor_node_id
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_succeeded(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> None:
"""
Workflow node execute succeeded
"""
self._queue_manager.publish(
QueueNodeSucceededEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
process_data=process_data,
outputs=outputs,
execution_metadata=execution_metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_failed(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
error: str,
inputs: Optional[dict] = None,
outputs: Optional[dict] = None,
process_data: Optional[dict] = None) -> None:
"""
Workflow node execute failed
"""
self._queue_manager.publish(
QueueNodeFailedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
outputs=outputs,
process_data=process_data,
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
"""
Publish text chunk
"""
self._queue_manager.publish(
QueueTextChunkEvent(
text=text,
metadata={
"node_id": node_id,
**metadata
}
), PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""
Publish iteration started
"""
self._queue_manager.publish(
QueueIterationStartEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
node_data=node_data,
inputs=inputs,
predecessor_node_id=predecessor_node_id,
metadata=metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[Any]) -> None:
"""
Publish iteration next
"""
self._queue_manager._publish(
QueueIterationNextEvent(
node_id=node_id,
node_type=node_type,
index=index,
node_run_index=node_run_index,
output=output
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
"""
Publish iteration completed
"""
self._queue_manager._publish(
QueueIterationCompletedEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
outputs=outputs
),
PublishFrom.APPLICATION_MANAGER
)
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
"""
self._queue_manager.publish(
event,
PublishFrom.APPLICATION_MANAGER
)

View File

@@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC):
def convert(cls, response: Union[
AppBlockingResponse,
Generator[AppStreamResponse, Any, None]
], invoke_from: InvokeFrom):
], invoke_from: InvokeFrom) -> dict[str, Any] | Generator[str, Any, None]:
if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)

View File

@@ -1,6 +1,6 @@
import time
from collections.abc import Generator
from typing import TYPE_CHECKING, Optional, Union
from collections.abc import Generator, Mapping
from typing import TYPE_CHECKING, Any, Optional, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -347,7 +347,7 @@ class AppRunner:
self, app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: dict,
inputs: Mapping[str, Any],
query: str,
message_id: str,
) -> tuple[bool, dict, str]:

View File

@@ -4,7 +4,7 @@ import os
import threading
import uuid
from collections.abc import Generator
from typing import Literal, Union, overload
from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@@ -40,6 +40,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
args: dict,
invoke_from: InvokeFrom,
stream: Literal[True] = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None
) -> Generator[str, None, None]: ...
@overload
@@ -50,16 +52,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
args: dict,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None
) -> dict: ...
def generate(
self, app_model: App,
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None
):
"""
Generate App response.
@@ -71,6 +77,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
:param call_depth: call depth
:param workflow_thread_pool_id: workflow thread pool id
"""
inputs = args['inputs']
@@ -118,16 +125,19 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
stream=stream,
workflow_thread_pool_id=workflow_thread_pool_id
)
def _generate(
self, app_model: App,
self, *,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[str, None, None]]:
workflow_thread_pool_id: Optional[str] = None
) -> dict[str, Any] | Generator[str, None, None]:
"""
Generate App response.
@@ -137,6 +147,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param stream: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
# init queue manager
queue_manager = WorkflowAppQueueManager(
@@ -148,10 +159,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'flask_app': current_app._get_current_object(), # type: ignore
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'context': contextvars.copy_context()
'context': contextvars.copy_context(),
'workflow_thread_pool_id': workflow_thread_pool_id
})
worker_thread.start()
@@ -175,7 +187,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
node_id: str,
user: Account,
args: dict,
stream: bool = True):
stream: bool = True) -> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
@@ -192,10 +204,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
if args.get('inputs') is None:
raise ValueError('inputs is required')
extras = {
"auto_generate_conversation_name": False
}
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(
app_model=app_model,
@@ -211,7 +219,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
user_id=user.id,
stream=stream,
invoke_from=InvokeFrom.DEBUGGER,
extras=extras,
extras={
"auto_generate_conversation_name": False
},
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id,
inputs=args['inputs']
@@ -231,12 +241,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
def _generate_worker(self, flask_app: Flask,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context) -> None:
context: contextvars.Context,
workflow_thread_pool_id: Optional[str] = None) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
for var, val in context.items():
@@ -244,22 +256,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
with flask_app.app_context():
try:
# workflow app
runner = WorkflowAppRunner()
if application_generate_entity.single_iteration_run:
single_iteration_run = application_generate_entity.single_iteration_run
runner.single_iteration_run(
app_id=application_generate_entity.app_config.app_id,
workflow_id=application_generate_entity.app_config.workflow_id,
queue_manager=queue_manager,
inputs=single_iteration_run.inputs,
node_id=single_iteration_run.node_id,
user_id=application_generate_entity.user_id
)
else:
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager
)
runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id
)
runner.run()
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
@@ -271,14 +274,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true':
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
db.session.close()
def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,

View File

@@ -4,46 +4,61 @@ from typing import Optional, cast
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App, EndUser
from models.workflow import Workflow
from models.workflow import WorkflowType
logger = logging.getLogger(__name__)
class WorkflowAppRunner:
class WorkflowAppRunner(WorkflowBasedAppRunner):
"""
Workflow Application Runner
"""
def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None:
def __init__(
self,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
workflow_thread_pool_id: Optional[str] = None
) -> None:
"""
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id
"""
self.application_generate_entity = application_generate_entity
self.queue_manager = queue_manager
self.workflow_thread_pool_id = workflow_thread_pool_id
def run(self) -> None:
"""
Run application
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:return:
"""
app_config = application_generate_entity.app_config
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id
user_id = self.application_generate_entity.user_id
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
@@ -53,80 +68,64 @@ class WorkflowAppRunner:
if not workflow:
raise ValueError('Workflow not initialized')
inputs = application_generate_entity.inputs
files = application_generate_entity.files
db.session.close()
workflow_callbacks: list[WorkflowCallback] = [
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
]
workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# Create a variable pool.
system_inputs = {
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
)
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs
)
else:
inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files
# Create a variable pool.
system_inputs = {
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
)
# init graph
graph = self._init_graph(graph_config=workflow.graph_dict)
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
workflow=workflow,
user_id=application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth,
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
thread_pool_id=self.workflow_thread_pool_id
)
def single_iteration_run(
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
) -> None:
"""
Single iteration run
"""
app_record = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError('App not found')
if not app_record.workflow_id:
raise ValueError('Workflow not initialized')
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
if not workflow:
raise ValueError('Workflow not initialized')
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
generator = workflow_entry.run(
callbacks=workflow_callbacks
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
for event in generator:
self._handle_event(workflow_entry, event)

View File

@@ -1,3 +1,4 @@
import json
import logging
import time
from collections.abc import Generator
@@ -15,10 +16,12 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueStopEvent,
QueueTextChunkEvent,
@@ -32,19 +35,16 @@ from core.app.entities.task_entities import (
MessageAudioStreamResponse,
StreamResponse,
TextChunkStreamResponse,
TextReplaceStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStreamGenerateNodes,
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser
@@ -52,8 +52,8 @@ from models.workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowNodeExecution,
WorkflowRun,
WorkflowRunStatus,
)
logger = logging.getLogger(__name__)
@@ -68,7 +68,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
_iteration_nested_relations: dict[str, list[str]]
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
@@ -96,11 +95,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
SystemVariableKey.USER_ID: user_id
}
self._task_state = WorkflowTaskState(
iteration_nested_node_ids=[]
)
self._stream_generate_nodes = self._get_stream_generate_nodes()
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
self._task_state = WorkflowTaskState()
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@@ -129,23 +124,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, WorkflowFinishStreamResponse):
workflow_run = db.session.query(WorkflowRun).filter(
WorkflowRun.id == self._task_state.workflow_run_id).first()
response = WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=workflow_run.id,
workflow_run_id=stream_response.data.id,
data=WorkflowAppBlockingResponse.Data(
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
status=workflow_run.status,
outputs=workflow_run.outputs_dict,
error=workflow_run.error,
elapsed_time=workflow_run.elapsed_time,
total_tokens=workflow_run.total_tokens,
total_steps=workflow_run.total_steps,
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp())
id=stream_response.data.id,
workflow_id=stream_response.data.workflow_id,
status=stream_response.data.status,
outputs=stream_response.data.outputs,
error=stream_response.data.error,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
created_at=int(stream_response.data.created_at),
finished_at=int(stream_response.data.finished_at)
)
)
@@ -161,9 +153,13 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
To stream response.
:return:
"""
workflow_run_id = None
for stream_response in generator:
if isinstance(stream_response, WorkflowStartStreamResponse):
workflow_run_id = stream_response.workflow_run_id
yield WorkflowAppStreamResponse(
workflow_run_id=self._task_state.workflow_run_id,
workflow_run_id=workflow_run_id,
stream_response=stream_response
)
@@ -178,17 +174,18 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
publisher = None
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
@@ -198,9 +195,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
start_listener_time = time.time()
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not publisher:
if not tts_publisher:
break
audio_trunk = publisher.checkAndGetAudio()
audio_trunk = tts_publisher.checkAndGetAudio()
if audio_trunk is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
@@ -218,69 +215,159 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
def _process_stream_response(
self,
publisher: AppGeneratorTTSPublisher,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
:return:
"""
for message in self._queue_manager.listen():
if publisher:
publisher.publish(message=message)
event = message.event
graph_runtime_state = None
workflow_run = None
if isinstance(event, QueueErrorEvent):
for queue_message in self._queue_manager.listen():
event = queue_message.event
if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event)
yield self._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
workflow_run = self._handle_workflow_start()
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
# init workflow run
workflow_run = self._handle_workflow_run_start()
yield self._workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
elif isinstance(event, QueueNodeStartedEvent):
workflow_node_execution = self._handle_node_start(event)
if not workflow_run:
raise Exception('Workflow run not initialized.')
# search stream_generate_routes if node id is answer start at node
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_nodes:
self._task_state.current_stream_generate_state = self._stream_generate_nodes[event.node_id]
workflow_node_execution = self._handle_node_execution_start(
workflow_run=workflow_run,
event=event
)
# generate stream outputs when node started
yield from self._generate_stream_outputs_when_node_started()
yield self._workflow_node_start_to_stream_response(
response = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
workflow_node_execution = self._handle_node_finished(event)
yield self._workflow_node_finish_to_stream_response(
if response:
yield response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
if isinstance(event, QueueNodeFailedEvent):
yield from self._handle_iteration_exception(
task_id=self._application_generate_entity.task_id,
error=f'Child node failed: {event.error}'
)
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
if isinstance(event, QueueIterationNextEvent):
# clear ran node execution infos of current iteration
iteration_relations = self._iteration_nested_relations.get(event.node_id)
if iteration_relations:
for node_id in iteration_relations:
self._task_state.ran_node_execution_infos.pop(node_id, None)
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
self._handle_iteration_operation(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(
event, trace_manager=trace_manager
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationNextEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=json.dumps(event.outputs) if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs else None,
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
@@ -295,22 +382,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if delta_text is None:
continue
if not self._is_stream_out_support(
event=event
):
continue
# only publish tts message at text chunk streaming
if tts_publisher:
tts_publisher.publish(message=queue_message)
self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(delta_text)
elif isinstance(event, QueueMessageReplaceEvent):
yield self._text_replace_to_stream_response(event.text)
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
else:
continue
if publisher:
publisher.publish(None)
if tts_publisher:
tts_publisher.publish(None)
def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
@@ -329,15 +411,15 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# not save log for debugging
return
workflow_app_log = WorkflowAppLog(
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
workflow_id=workflow_run.workflow_id,
workflow_run_id=workflow_run.id,
created_from=created_from.value,
created_by_role=('account' if isinstance(self._user, Account) else 'end_user'),
created_by=self._user.id,
)
workflow_app_log = WorkflowAppLog()
workflow_app_log.tenant_id = workflow_run.tenant_id
workflow_app_log.app_id = workflow_run.app_id
workflow_app_log.workflow_id = workflow_run.workflow_id
workflow_app_log.workflow_run_id = workflow_run.id
workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = 'account' if isinstance(self._user, Account) else 'end_user'
workflow_app_log.created_by = self._user.id
db.session.add(workflow_app_log)
db.session.commit()
db.session.close()
@@ -354,180 +436,3 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
)
return response
def _text_replace_to_stream_response(self, text: str) -> TextReplaceStreamResponse:
"""
Text replace to stream response.
:param text: text
:return:
"""
return TextReplaceStreamResponse(
task_id=self._application_generate_entity.task_id,
text=TextReplaceStreamResponse.Data(text=text)
)
def _get_stream_generate_nodes(self) -> dict[str, WorkflowStreamGenerateNodes]:
"""
Get stream generate nodes.
:return:
"""
# find all answer nodes
graph = self._workflow.graph_dict
end_node_configs = [
node for node in graph['nodes']
if node.get('data', {}).get('type') == NodeType.END.value
]
# parse stream output node value selectors of end nodes
stream_generate_routes = {}
for node_config in end_node_configs:
# get generate route for stream output
end_node_id = node_config['id']
generate_nodes = EndNode.extract_generate_nodes(graph, node_config)
start_node_ids = self._get_end_start_at_node_ids(graph, end_node_id)
if not start_node_ids:
continue
for start_node_id in start_node_ids:
stream_generate_routes[start_node_id] = WorkflowStreamGenerateNodes(
end_node_id=end_node_id,
stream_node_ids=generate_nodes
)
return stream_generate_routes
def _get_end_start_at_node_ids(self, graph: dict, target_node_id: str) \
-> list[str]:
"""
Get end start at node id.
:param graph: graph
:param target_node_id: target node ID
:return:
"""
nodes = graph.get('nodes')
edges = graph.get('edges')
# fetch all ingoing edges from source node
ingoing_edges = []
for edge in edges:
if edge.get('target') == target_node_id:
ingoing_edges.append(edge)
if not ingoing_edges:
return []
start_node_ids = []
for ingoing_edge in ingoing_edges:
source_node_id = ingoing_edge.get('source')
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
if not source_node:
continue
node_type = source_node.get('data', {}).get('type')
node_iteration_id = source_node.get('data', {}).get('iteration_id')
iteration_start_node_id = None
if node_iteration_id:
iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
if node_type in [
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value
]:
start_node_id = target_node_id
start_node_ids.append(start_node_id)
elif node_type == NodeType.START.value or \
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
sub_start_node_ids = self._get_end_start_at_node_ids(graph, source_node_id)
if sub_start_node_ids:
start_node_ids.extend(sub_start_node_ids)
return start_node_ids
def _generate_stream_outputs_when_node_started(self) -> Generator:
"""
Generate stream outputs.
:return:
"""
if self._task_state.current_stream_generate_state:
stream_node_ids = self._task_state.current_stream_generate_state.stream_node_ids
for node_id, node_execution_info in self._task_state.ran_node_execution_infos.items():
if node_id not in stream_node_ids:
continue
node_execution_info = self._task_state.ran_node_execution_infos[node_id]
# get chunk node execution
route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == node_execution_info.workflow_node_execution_id).first()
if not route_chunk_node_execution:
continue
outputs = route_chunk_node_execution.outputs_dict
if not outputs:
continue
# get value from outputs
text = outputs.get('text')
if text:
self._task_state.answer += text
yield self._text_chunk_to_stream_response(text)
db.session.close()
def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
"""
Is stream out support
:param event: queue text chunk event
:return:
"""
if not event.metadata:
return False
if 'node_id' not in event.metadata:
return False
node_id = event.metadata.get('node_id')
node_type = event.metadata.get('node_type')
stream_output_value_selector = event.metadata.get('value_selector')
if not stream_output_value_selector:
return False
if not self._task_state.current_stream_generate_state:
return False
if node_id not in self._task_state.current_stream_generate_state.stream_node_ids:
return False
if node_type != NodeType.LLM:
# only LLM support chunk stream output
return False
return True
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
"""
Get iteration nested relations.
:param graph: graph
:return:
"""
nodes = graph.get('nodes')
iteration_ids = [node.get('id') for node in nodes
if node.get('data', {}).get('type') in [
NodeType.ITERATION.value,
NodeType.LOOP.value,
]]
return {
iteration_id: [
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
] for iteration_id in iteration_ids
}

View File

@@ -1,200 +0,0 @@
from typing import Any, Optional
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow
class WorkflowEventTriggerCallback(WorkflowCallback):
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager
def on_workflow_run_started(self) -> None:
"""
Workflow run started
"""
self._queue_manager.publish(
QueueWorkflowStartedEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_succeeded(self) -> None:
"""
Workflow run succeeded
"""
self._queue_manager.publish(
QueueWorkflowSucceededEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_failed(self, error: str) -> None:
"""
Workflow run failed
"""
self._queue_manager.publish(
QueueWorkflowFailedEvent(
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_started(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> None:
"""
Workflow node execute started
"""
self._queue_manager.publish(
QueueNodeStartedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
node_run_index=node_run_index,
predecessor_node_id=predecessor_node_id
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_succeeded(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> None:
"""
Workflow node execute succeeded
"""
self._queue_manager.publish(
QueueNodeSucceededEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
process_data=process_data,
outputs=outputs,
execution_metadata=execution_metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_failed(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
error: str,
inputs: Optional[dict] = None,
outputs: Optional[dict] = None,
process_data: Optional[dict] = None) -> None:
"""
Workflow node execute failed
"""
self._queue_manager.publish(
QueueNodeFailedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
outputs=outputs,
process_data=process_data,
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
"""
Publish text chunk
"""
self._queue_manager.publish(
QueueTextChunkEvent(
text=text,
metadata={
"node_id": node_id,
**metadata
}
), PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""
Publish iteration started
"""
self._queue_manager.publish(
QueueIterationStartEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
node_data=node_data,
inputs=inputs,
predecessor_node_id=predecessor_node_id,
metadata=metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[Any]) -> None:
"""
Publish iteration next
"""
self._queue_manager.publish(
QueueIterationNextEvent(
node_id=node_id,
node_type=node_type,
index=index,
node_run_index=node_run_index,
output=output
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
"""
Publish iteration completed
"""
self._queue_manager.publish(
QueueIterationCompletedEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
outputs=outputs
),
PublishFrom.APPLICATION_MANAGER
)
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
"""
pass

View File

@@ -0,0 +1,379 @@
from collections.abc import Mapping
from typing import Any, Optional, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueueRetrieverResourcesEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.nodes.node_mapping import node_classes
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App
from models.workflow import Workflow
class WorkflowBasedAppRunner(AppRunner):
def __init__(self, queue_manager: AppQueueManager):
self.queue_manager = queue_manager
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
"""
Init graph
"""
if 'nodes' not in graph_config or 'edges' not in graph_config:
raise ValueError('nodes or edges not found in workflow graph')
if not isinstance(graph_config.get('nodes'), list):
raise ValueError('nodes in workflow graph must be a list')
if not isinstance(graph_config.get('edges'), list):
raise ValueError('edges in workflow graph must be a list')
# init graph
graph = Graph.init(
graph_config=graph_config
)
if not graph:
raise ValueError('graph not found in workflow')
return graph
def _get_graph_and_variable_pool_of_single_iteration(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single iteration
"""
# fetch workflow graph
graph_config = workflow.graph_dict
if not graph_config:
raise ValueError('workflow graph not found')
graph_config = cast(dict[str, Any], graph_config)
if 'nodes' not in graph_config or 'edges' not in graph_config:
raise ValueError('nodes or edges not found in workflow graph')
if not isinstance(graph_config.get('nodes'), list):
raise ValueError('nodes in workflow graph must be a list')
if not isinstance(graph_config.get('edges'), list):
raise ValueError('edges in workflow graph must be a list')
# filter nodes only in iteration
node_configs = [
node for node in graph_config.get('nodes', [])
if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id
]
graph_config['nodes'] = node_configs
node_ids = [node.get('id') for node in node_configs]
# filter edges only in iteration
edge_configs = [
edge for edge in graph_config.get('edges', [])
if (edge.get('source') is None or edge.get('source') in node_ids)
and (edge.get('target') is None or edge.get('target') in node_ids)
]
graph_config['edges'] = edge_configs
# init graph
graph = Graph.init(
graph_config=graph_config,
root_node_id=node_id
)
if not graph:
raise ValueError('graph not found in workflow')
# fetch node config from node id
iteration_node_config = None
for node in node_configs:
if node.get('id') == node_id:
iteration_node_config = node
break
if not iteration_node_config:
raise ValueError('iteration node id not found in workflow graph')
# Get node class
node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type)
node_cls = cast(type[BaseNode], node_cls)
# init variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
environment_variables=workflow.environment_variables,
)
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict,
config=iteration_node_config
)
except NotImplementedError:
variable_mapping = {}
WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
node_type=node_type,
node_data=IterationNodeData(**iteration_node_config.get('data', {}))
)
return graph, variable_pool
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None:
"""
Handle event
:param workflow_entry: workflow entry
:param event: event
"""
if isinstance(event, GraphRunStartedEvent):
self._publish_event(
QueueWorkflowStartedEvent(
graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state
)
)
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(
QueueWorkflowSucceededEvent(outputs=event.outputs)
)
elif isinstance(event, GraphRunFailedEvent):
self._publish_event(
QueueWorkflowFailedEvent(error=event.error)
)
elif isinstance(event, NodeRunStartedEvent):
self._publish_event(
QueueNodeStartedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, NodeRunSucceededEvent):
self._publish_event(
QueueNodeSucceededEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result else {},
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, NodeRunFailedEvent):
self._publish_event(
QueueNodeFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result
and event.route_node_state.node_run_result.error
else "Unknown error",
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
self._publish_event(
QueueTextChunkEvent(
text=event.chunk_content,
from_variable_selector=event.from_variable_selector,
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, NodeRunRetrieverResourceEvent):
self._publish_event(
QueueRetrieverResourcesEvent(
retriever_resources=event.retriever_resources,
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, ParallelBranchRunStartedEvent):
self._publish_event(
QueueParallelBranchRunStartedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, ParallelBranchRunSucceededEvent):
self._publish_event(
QueueParallelBranchRunSucceededEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, ParallelBranchRunFailedEvent):
self._publish_event(
QueueParallelBranchRunFailedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
error=event.error
)
)
elif isinstance(event, IterationRunStartedEvent):
self._publish_event(
QueueIterationStartEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id,
metadata=event.metadata
)
)
elif isinstance(event, IterationRunNextEvent):
self._publish_event(
QueueIterationNextEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_iteration_output,
)
)
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
self._publish_event(
QueueIterationCompletedEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, IterationRunFailedEvent) else None
)
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
def _publish_event(self, event: AppQueueEvent) -> None:
self.queue_manager.publish(
event,
PublishFrom.APPLICATION_MANAGER
)

View File

@@ -1,10 +1,24 @@
from typing import Optional
from core.app.entities.queue_entities import AppQueueEvent
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
@@ -20,127 +34,203 @@ class WorkflowLoggingCallback(WorkflowCallback):
def __init__(self) -> None:
self.current_node_id = None
def on_workflow_run_started(self) -> None:
"""
Workflow run started
"""
self.print_text("\n[on_workflow_run_started]", color='pink')
def on_event(
self,
event: GraphEngineEvent
) -> None:
if isinstance(event, GraphRunStartedEvent):
self.print_text("\n[GraphRunStartedEvent]", color='pink')
elif isinstance(event, GraphRunSucceededEvent):
self.print_text("\n[GraphRunSucceededEvent]", color='green')
elif isinstance(event, GraphRunFailedEvent):
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color='red')
elif isinstance(event, NodeRunStartedEvent):
self.on_workflow_node_execute_started(
event=event
)
elif isinstance(event, NodeRunSucceededEvent):
self.on_workflow_node_execute_succeeded(
event=event
)
elif isinstance(event, NodeRunFailedEvent):
self.on_workflow_node_execute_failed(
event=event
)
elif isinstance(event, NodeRunStreamChunkEvent):
self.on_node_text_chunk(
event=event
)
elif isinstance(event, ParallelBranchRunStartedEvent):
self.on_workflow_parallel_started(
event=event
)
elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
self.on_workflow_parallel_completed(
event=event
)
elif isinstance(event, IterationRunStartedEvent):
self.on_workflow_iteration_started(
event=event
)
elif isinstance(event, IterationRunNextEvent):
self.on_workflow_iteration_next(
event=event
)
elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
self.on_workflow_iteration_completed(
event=event
)
else:
self.print_text(f"\n[{event.__class__.__name__}]", color='blue')
def on_workflow_run_succeeded(self) -> None:
"""
Workflow run succeeded
"""
self.print_text("\n[on_workflow_run_succeeded]", color='green')
def on_workflow_run_failed(self, error: str) -> None:
"""
Workflow run failed
"""
self.print_text("\n[on_workflow_run_failed]", color='red')
def on_workflow_node_execute_started(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> None:
def on_workflow_node_execute_started(
self,
event: NodeRunStartedEvent
) -> None:
"""
Workflow node execute started
"""
self.print_text("\n[on_workflow_node_execute_started]", color='yellow')
self.print_text(f"Node ID: {node_id}", color='yellow')
self.print_text(f"Type: {node_type.value}", color='yellow')
self.print_text(f"Index: {node_run_index}", color='yellow')
if predecessor_node_id:
self.print_text(f"Predecessor Node ID: {predecessor_node_id}", color='yellow')
self.print_text("\n[NodeRunStartedEvent]", color='yellow')
self.print_text(f"Node ID: {event.node_id}", color='yellow')
self.print_text(f"Node Title: {event.node_data.title}", color='yellow')
self.print_text(f"Type: {event.node_type.value}", color='yellow')
def on_workflow_node_execute_succeeded(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> None:
def on_workflow_node_execute_succeeded(
self,
event: NodeRunSucceededEvent
) -> None:
"""
Workflow node execute succeeded
"""
self.print_text("\n[on_workflow_node_execute_succeeded]", color='green')
self.print_text(f"Node ID: {node_id}", color='green')
self.print_text(f"Type: {node_type.value}", color='green')
self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='green')
self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='green')
self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='green')
self.print_text(f"Metadata: {jsonable_encoder(execution_metadata) if execution_metadata else ''}",
color='green')
route_node_state = event.route_node_state
def on_workflow_node_execute_failed(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
error: str,
inputs: Optional[dict] = None,
outputs: Optional[dict] = None,
process_data: Optional[dict] = None) -> None:
self.print_text("\n[NodeRunSucceededEvent]", color='green')
self.print_text(f"Node ID: {event.node_id}", color='green')
self.print_text(f"Node Title: {event.node_data.title}", color='green')
self.print_text(f"Type: {event.node_type.value}", color='green')
if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result
self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color='green')
self.print_text(
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color='green')
self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color='green')
self.print_text(
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
color='green')
def on_workflow_node_execute_failed(
self,
event: NodeRunFailedEvent
) -> None:
"""
Workflow node execute failed
"""
self.print_text("\n[on_workflow_node_execute_failed]", color='red')
self.print_text(f"Node ID: {node_id}", color='red')
self.print_text(f"Type: {node_type.value}", color='red')
self.print_text(f"Error: {error}", color='red')
self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='red')
self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='red')
self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='red')
route_node_state = event.route_node_state
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
self.print_text("\n[NodeRunFailedEvent]", color='red')
self.print_text(f"Node ID: {event.node_id}", color='red')
self.print_text(f"Node Title: {event.node_data.title}", color='red')
self.print_text(f"Type: {event.node_type.value}", color='red')
if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result
self.print_text(f"Error: {node_run_result.error}", color='red')
self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color='red')
self.print_text(
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color='red')
self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color='red')
def on_node_text_chunk(
self,
event: NodeRunStreamChunkEvent
) -> None:
"""
Publish text chunk
"""
if not self.current_node_id or self.current_node_id != node_id:
self.current_node_id = node_id
self.print_text('\n[on_node_text_chunk]')
self.print_text(f"Node ID: {node_id}")
self.print_text(f"Metadata: {jsonable_encoder(metadata) if metadata else ''}")
route_node_state = event.route_node_state
if not self.current_node_id or self.current_node_id != route_node_state.node_id:
self.current_node_id = route_node_state.node_id
self.print_text('\n[NodeRunStreamChunkEvent]')
self.print_text(f"Node ID: {route_node_state.node_id}")
self.print_text(text, color="pink", end="")
node_run_result = route_node_state.node_run_result
if node_run_result:
self.print_text(
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}")
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
self.print_text(event.chunk_content, color="pink", end="")
def on_workflow_parallel_started(
self,
event: ParallelBranchRunStartedEvent
) -> None:
"""
Publish parallel started
"""
self.print_text("\n[ParallelBranchRunStartedEvent]", color='blue')
self.print_text(f"Parallel ID: {event.parallel_id}", color='blue')
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color='blue')
if event.in_iteration_id:
self.print_text(f"Iteration ID: {event.in_iteration_id}", color='blue')
def on_workflow_parallel_completed(
self,
event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
) -> None:
"""
Publish parallel completed
"""
if isinstance(event, ParallelBranchRunSucceededEvent):
color = 'blue'
elif isinstance(event, ParallelBranchRunFailedEvent):
color = 'red'
self.print_text("\n[ParallelBranchRunSucceededEvent]" if isinstance(event, ParallelBranchRunSucceededEvent) else "\n[ParallelBranchRunFailedEvent]", color=color)
self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
if event.in_iteration_id:
self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)
if isinstance(event, ParallelBranchRunFailedEvent):
self.print_text(f"Error: {event.error}", color=color)
def on_workflow_iteration_started(
self,
event: IterationRunStartedEvent
) -> None:
"""
Publish iteration started
"""
self.print_text("\n[on_workflow_iteration_started]", color='blue')
self.print_text(f"Node ID: {node_id}", color='blue')
self.print_text("\n[IterationRunStartedEvent]", color='blue')
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[dict]) -> None:
def on_workflow_iteration_next(
self,
event: IterationRunNextEvent
) -> None:
"""
Publish iteration next
"""
self.print_text("\n[on_workflow_iteration_next]", color='blue')
self.print_text("\n[IterationRunNextEvent]", color='blue')
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
self.print_text(f"Iteration Index: {event.index}", color='blue')
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
def on_workflow_iteration_completed(
self,
event: IterationRunSucceededEvent | IterationRunFailedEvent
) -> None:
"""
Publish iteration completed
"""
self.print_text("\n[on_workflow_iteration_completed]", color='blue')
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
"""
self.print_text("\n[on_workflow_event]", color='blue')
self.print_text(f"Event: {jsonable_encoder(event)}", color='blue')
self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue')
self.print_text(f"Node ID: {event.iteration_id}", color='blue')
def print_text(
self, text: str, color: Optional[str] = None, end: str = "\n"