chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -32,10 +32,13 @@ class BasedGenerateTaskPipeline:
_task_state: TaskState
_application_generate_entity: AppGenerateEntity
def __init__(self, application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool) -> None:
def __init__(
self,
application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool,
) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
@@ -61,18 +64,18 @@ class BasedGenerateTaskPipeline:
e = event.error
if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError('Incorrect API key provided')
err = InvokeAuthorizationError("Incorrect API key provided")
elif isinstance(e, InvokeError) or isinstance(e, ValueError):
err = e
else:
err = Exception(e.description if getattr(e, 'description', None) is not None else str(e))
err = Exception(e.description if getattr(e, "description", None) is not None else str(e))
if message:
refetch_message = db.session.query(Message).filter(Message.id == message.id).first()
if refetch_message:
err_desc = self._error_to_desc(err)
refetch_message.status = 'error'
refetch_message.status = "error"
refetch_message.error = err_desc
db.session.commit()
@@ -86,12 +89,14 @@ class BasedGenerateTaskPipeline:
:return:
"""
if isinstance(e, QuotaExceededError):
return ("Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials.")
return (
"Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials."
)
message = getattr(e, 'description', str(e))
message = getattr(e, "description", str(e))
if not message:
message = 'Internal Server Error, please contact support.'
message = "Internal Server Error, please contact support."
return message
@@ -101,10 +106,7 @@ class BasedGenerateTaskPipeline:
:param e: exception
:return:
"""
return ErrorStreamResponse(
task_id=self._application_generate_entity.task_id,
err=e
)
return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e)
def _ping_stream_response(self) -> PingStreamResponse:
"""
@@ -125,11 +127,8 @@ class BasedGenerateTaskPipeline:
return OutputModeration(
tenant_id=app_config.tenant_id,
app_id=app_config.app_id,
rule=ModerationRule(
type=sensitive_word_avoidance.type,
config=sensitive_word_avoidance.config
),
queue_manager=self._queue_manager
rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config),
queue_manager=self._queue_manager,
)
def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
@@ -143,8 +142,7 @@ class BasedGenerateTaskPipeline:
self._output_moderation_handler.stop_thread()
completion = self._output_moderation_handler.moderation_completion(
completion=completion,
public_event=False
completion=completion, public_event=False
)
self._output_moderation_handler = None

View File

@@ -64,23 +64,21 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
"""
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_task_state: EasyUITaskState
_application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity
]
def __init__(self, application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity
],
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool) -> None:
_task_state: EasyUITaskState
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
def __init__(
self,
application_generate_entity: Union[
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity
],
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool,
) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
@@ -101,18 +99,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
model=self._model_config.model,
prompt_messages=[],
message=AssistantPromptMessage(content=""),
usage=LLMUsage.empty_usage()
usage=LLMUsage.empty_usage(),
)
)
self._conversation_name_generate_thread = None
def process(
self,
self,
) -> Union[
ChatbotAppBlockingResponse,
CompletionAppBlockingResponse,
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
]:
"""
Process generate task pipeline.
@@ -125,22 +123,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation,
self._application_generate_entity.query
self._conversation, self._application_generate_entity.query
)
generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> Union[
ChatbotAppBlockingResponse,
CompletionAppBlockingResponse
]:
def _to_blocking_response(
self, generator: Generator[StreamResponse, None, None]
) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]:
"""
Process blocking response.
:return:
@@ -149,11 +143,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {
'usage': jsonable_encoder(self._task_state.llm_result.usage)
}
extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
if self._task_state.metadata:
extras['metadata'] = self._task_state.metadata
extras["metadata"] = self._task_state.metadata
if self._conversation.mode == AppMode.COMPLETION.value:
response = CompletionAppBlockingResponse(
@@ -164,8 +156,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
message_id=self._message.id,
answer=self._task_state.llm_result.message.content,
created_at=int(self._message.created_at.timestamp()),
**extras
)
**extras,
),
)
else:
response = ChatbotAppBlockingResponse(
@@ -177,18 +169,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
message_id=self._message.id,
answer=self._task_state.llm_result.message.content,
created_at=int(self._message.created_at.timestamp()),
**extras
)
**extras,
),
)
return response
else:
continue
raise Exception('Queue listening stopped unexpectedly.')
raise Exception("Queue listening stopped unexpectedly.")
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
-> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
"""
To stream response.
:return:
@@ -198,14 +191,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
yield CompletionAppStreamResponse(
message_id=self._message.id,
created_at=int(self._message.created_at.timestamp()),
stream_response=stream_response
stream_response=stream_response,
)
else:
yield ChatbotAppStreamResponse(
conversation_id=self._conversation.id,
message_id=self._message.id,
created_at=int(self._message.created_at.timestamp()),
stream_response=stream_response
stream_response=stream_response,
)
def _listenAudioMsg(self, publisher, task_id: str):
@@ -217,15 +210,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
def _wrapper_process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
tenant_id = self._application_generate_entity.app_config.tenant_id
task_id = self._application_generate_entity.task_id
publisher = None
text_to_speech_dict = self._app_config.app_model_config_dict.get('text_to_speech')
if text_to_speech_dict and text_to_speech_dict.get('autoPlay') == 'enabled' and text_to_speech_dict.get('enabled'):
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get('voice', None))
text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech")
if (
text_to_speech_dict
and text_to_speech_dict.get("autoPlay") == "enabled"
and text_to_speech_dict.get("enabled")
):
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id)
@@ -250,14 +247,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
break
else:
start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio.audio,
task_id=task_id)
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
yield MessageAudioStreamResponse(audio=audio.audio, task_id=task_id)
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self,
publisher: AppGeneratorTTSPublisher,
trace_manager: Optional[TraceQueueManager] = None
self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
@@ -333,9 +327,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(
self, trace_manager: Optional[TraceQueueManager] = None
) -> None:
def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None:
"""
Save message.
:return:
@@ -347,31 +339,32 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
self._model_config.mode,
self._task_state.llm_result.prompt_messages
self._model_config.mode, self._task_state.llm_result.prompt_messages
)
self._message.message_tokens = usage.prompt_tokens
self._message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit
self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
if llm_result.message.content else ''
self._message.answer = (
PromptTemplateParser.remove_template_variables(llm_result.message.content.strip())
if llm_result.message.content
else ""
)
self._message.answer_tokens = usage.completion_tokens
self._message.answer_unit_price = usage.completion_unit_price
self._message.answer_price_unit = usage.completion_price_unit
self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.total_price = usage.total_price
self._message.currency = usage.currency
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
db.session.commit()
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.MESSAGE_TRACE,
conversation_id=self._conversation.id,
message_id=self._message.id
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id
)
)
@@ -379,11 +372,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._message,
application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.app_config.app_mode in [
AppMode.AGENT_CHAT,
AppMode.CHAT
] and self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras
is_first_message=self._application_generate_entity.app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT]
and self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras,
)
def _handle_stop(self, event: QueueStopEvent) -> None:
@@ -395,22 +386,17 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
model = model_config.model
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
# calculate num tokens
prompt_tokens = 0
if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
prompt_tokens = model_instance.get_llm_num_tokens(
self._task_state.llm_result.prompt_messages
)
prompt_tokens = model_instance.get_llm_num_tokens(self._task_state.llm_result.prompt_messages)
completion_tokens = 0
if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
completion_tokens = model_instance.get_llm_num_tokens(
[self._task_state.llm_result.message]
)
completion_tokens = model_instance.get_llm_num_tokens([self._task_state.llm_result.message])
credentials = model_config.credentials
@@ -418,10 +404,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
model,
credentials,
prompt_tokens,
completion_tokens
model, credentials, prompt_tokens, completion_tokens
)
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
@@ -429,16 +412,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
Message end to stream response.
:return:
"""
self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)
self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage)
extras = {}
if self._task_state.metadata:
extras['metadata'] = self._task_state.metadata
extras["metadata"] = self._task_state.metadata
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message.id,
**extras
task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
)
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
@@ -449,9 +430,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
:return:
"""
return AgentMessageStreamResponse(
task_id=self._application_generate_entity.task_id,
id=message_id,
answer=answer
task_id=self._application_generate_entity.task_id, id=message_id, answer=answer
)
def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]:
@@ -461,9 +440,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
:return:
"""
agent_thought: MessageAgentThought = (
db.session.query(MessageAgentThought)
.filter(MessageAgentThought.id == event.agent_thought_id)
.first()
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
)
db.session.refresh(agent_thought)
db.session.close()
@@ -478,7 +455,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
tool=agent_thought.tool,
tool_labels=agent_thought.tool_labels,
tool_input=agent_thought.tool_input,
message_files=agent_thought.files
message_files=agent_thought.files,
)
return None
@@ -500,15 +477,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
prompt_messages=self._task_state.llm_result.prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
)
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content),
),
)
), PublishFrom.TASK_PIPELINE
),
PublishFrom.TASK_PIPELINE,
)
self._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
PublishFrom.TASK_PIPELINE
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
)
return True
else:

View File

@@ -30,10 +30,7 @@ from services.annotation_service import AppAnnotationService
class MessageCycleManage:
_application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
]
_task_state: Union[EasyUITaskState, WorkflowTaskState]
@@ -49,15 +46,18 @@ class MessageCycleManage:
is_first_message = self._application_generate_entity.conversation_id is None
extras = self._application_generate_entity.extras
auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True)
auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)
if auto_generate_conversation_name and is_first_message:
# start generate thread
thread = Thread(target=self._generate_conversation_name_worker, kwargs={
'flask_app': current_app._get_current_object(), # type: ignore
'conversation_id': conversation.id,
'query': query
})
thread = Thread(
target=self._generate_conversation_name_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"conversation_id": conversation.id,
"query": query,
},
)
thread.start()
@@ -65,17 +65,10 @@ class MessageCycleManage:
return None
def _generate_conversation_name_worker(self,
flask_app: Flask,
conversation_id: str,
query: str):
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context():
# get conversation and message
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id)
.first()
)
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
if not conversation:
return
@@ -105,12 +98,9 @@ class MessageCycleManage:
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
if annotation:
account = annotation.account
self._task_state.metadata['annotation_reply'] = {
'id': annotation.id,
'account': {
'id': annotation.account_id,
'name': account.name if account else 'Dify user'
}
self._task_state.metadata["annotation_reply"] = {
"id": annotation.id,
"account": {"id": annotation.account_id, "name": account.name if account else "Dify user"},
}
return annotation
@@ -124,7 +114,7 @@ class MessageCycleManage:
:return:
"""
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
self._task_state.metadata['retriever_resources'] = event.retriever_resources
self._task_state.metadata["retriever_resources"] = event.retriever_resources
def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
"""
@@ -132,27 +122,23 @@ class MessageCycleManage:
:param event: event
:return:
"""
message_file = (
db.session.query(MessageFile)
.filter(MessageFile.id == event.message_file_id)
.first()
)
message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()
if message_file:
# get tool file id
tool_file_id = message_file.url.split('/')[-1]
tool_file_id = message_file.url.split("/")[-1]
# trim extension
tool_file_id = tool_file_id.split('.')[0]
tool_file_id = tool_file_id.split(".")[0]
# get extension
if '.' in message_file.url:
if "." in message_file.url:
extension = f'.{message_file.url.split(".")[-1]}'
if len(extension) > 10:
extension = '.bin'
extension = ".bin"
else:
extension = '.bin'
extension = ".bin"
# add sign url to local file
if message_file.url.startswith('http'):
if message_file.url.startswith("http"):
url = message_file.url
else:
url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension)
@@ -161,8 +147,8 @@ class MessageCycleManage:
task_id=self._application_generate_entity.task_id,
id=message_file.id,
type=message_file.type,
belongs_to=message_file.belongs_to or 'user',
url=url
belongs_to=message_file.belongs_to or "user",
url=url,
)
return None
@@ -174,11 +160,7 @@ class MessageCycleManage:
:param message_id: message id
:return:
"""
return MessageStreamResponse(
task_id=self._application_generate_entity.task_id,
id=message_id,
answer=answer
)
return MessageStreamResponse(task_id=self._application_generate_entity.task_id, id=message_id, answer=answer)
def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse:
"""
@@ -186,7 +168,4 @@ class MessageCycleManage:
:param answer: answer
:return:
"""
return MessageReplaceStreamResponse(
task_id=self._application_generate_entity.task_id,
answer=answer
)
return MessageReplaceStreamResponse(task_id=self._application_generate_entity.task_id, answer=answer)

View File

@@ -70,14 +70,14 @@ class WorkflowCycleManage:
inputs = {**self._application_generate_entity.inputs}
for key, value in (self._workflow_system_variables or {}).items():
if key.value == 'conversation':
if key.value == "conversation":
continue
inputs[f'sys.{key.value}'] = value
inputs[f"sys.{key.value}"] = value
inputs = WorkflowEntry.handle_special_values(inputs)
triggered_from= (
triggered_from = (
WorkflowRunTriggeredFrom.DEBUGGING
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
else WorkflowRunTriggeredFrom.APP_RUN
@@ -185,20 +185,26 @@ class WorkflowCycleManage:
db.session.commit()
running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value
).all()
running_workflow_node_executions = (
db.session.query(WorkflowNodeExecution)
.filter(
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
)
.all()
)
for workflow_node_execution in running_workflow_node_executions:
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds()
workflow_node_execution.elapsed_time = (
workflow_node_execution.finished_at - workflow_node_execution.created_at
).total_seconds()
db.session.commit()
db.session.refresh(workflow_run)
@@ -216,7 +222,9 @@ class WorkflowCycleManage:
return workflow_run
def _handle_node_execution_start(self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
def _handle_node_execution_start(
self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
) -> WorkflowNodeExecution:
# init workflow node execution
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.tenant_id = workflow_run.tenant_id
@@ -333,16 +341,16 @@ class WorkflowCycleManage:
created_by_account = workflow_run.created_by_account
if created_by_account:
created_by = {
'id': created_by_account.id,
'name': created_by_account.name,
'email': created_by_account.email,
"id": created_by_account.id,
"name": created_by_account.name,
"email": created_by_account.email,
}
else:
created_by_end_user = workflow_run.created_by_end_user
if created_by_end_user:
created_by = {
'id': created_by_end_user.id,
'user': created_by_end_user.session_id,
"id": created_by_end_user.id,
"user": created_by_end_user.session_id,
}
return WorkflowFinishStreamResponse(
@@ -401,7 +409,7 @@ class WorkflowCycleManage:
# extras logic
if event.node_type == NodeType.TOOL:
node_data = cast(ToolNodeData, event.node_data)
response.data.extras['icon'] = ToolManager.get_tool_icon(
response.data.extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
provider_type=node_data.provider_type,
provider_id=node_data.provider_id,
@@ -410,10 +418,10 @@ class WorkflowCycleManage:
return response
def _workflow_node_finish_to_stream_response(
self,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution
self,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
"""
Workflow node finish to stream response.
@@ -424,7 +432,7 @@ class WorkflowCycleManage:
"""
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
return None
return NodeFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
@@ -452,13 +460,10 @@ class WorkflowCycleManage:
iteration_id=event.in_iteration_id,
),
)
def _workflow_parallel_branch_start_to_stream_response(
self,
task_id: str,
workflow_run: WorkflowRun,
event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:
self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:
"""
Workflow parallel branch start to stream response
:param task_id: task id
@@ -476,15 +481,15 @@ class WorkflowCycleManage:
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
created_at=int(time.time()),
)
),
)
def _workflow_parallel_branch_finished_to_stream_response(
self,
task_id: str,
workflow_run: WorkflowRun,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent
) -> ParallelBranchFinishedStreamResponse:
self,
task_id: str,
workflow_run: WorkflowRun,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
) -> ParallelBranchFinishedStreamResponse:
"""
Workflow parallel branch finished to stream response
:param task_id: task id
@@ -501,18 +506,15 @@ class WorkflowCycleManage:
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
status='succeeded' if isinstance(event, QueueParallelBranchRunSucceededEvent) else 'failed',
status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
created_at=int(time.time()),
)
),
)
def _workflow_iteration_start_to_stream_response(
self,
task_id: str,
workflow_run: WorkflowRun,
event: QueueIterationStartEvent
) -> IterationNodeStartStreamResponse:
self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
) -> IterationNodeStartStreamResponse:
"""
Workflow iteration start to stream response
:param task_id: task id
@@ -534,10 +536,12 @@ class WorkflowCycleManage:
metadata=event.metadata or {},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
)
),
)
def _workflow_iteration_next_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent) -> IterationNodeNextStreamResponse:
def _workflow_iteration_next_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
) -> IterationNodeNextStreamResponse:
"""
Workflow iteration next to stream response
:param task_id: task id
@@ -559,10 +563,12 @@ class WorkflowCycleManage:
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
)
),
)
def _workflow_iteration_completed_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent) -> IterationNodeCompletedStreamResponse:
def _workflow_iteration_completed_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
) -> IterationNodeCompletedStreamResponse:
"""
Workflow iteration completed to stream response
:param task_id: task id
@@ -585,13 +591,13 @@ class WorkflowCycleManage:
status=WorkflowNodeExecutionStatus.SUCCEEDED,
error=None,
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0,
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
)
),
)
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
@@ -643,7 +649,7 @@ class WorkflowCycleManage:
return None
if isinstance(value, dict):
if '__variant' in value and value['__variant'] == FileVar.__name__:
if "__variant" in value and value["__variant"] == FileVar.__name__:
return value
elif isinstance(value, FileVar):
return value.to_dict()
@@ -656,11 +662,10 @@ class WorkflowCycleManage:
:param workflow_run_id: workflow run id
:return:
"""
workflow_run = db.session.query(WorkflowRun).filter(
WorkflowRun.id == workflow_run_id).first()
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
if not workflow_run:
raise Exception(f'Workflow run not found: {workflow_run_id}')
raise Exception(f"Workflow run not found: {workflow_run_id}")
return workflow_run
@@ -683,6 +688,6 @@ class WorkflowCycleManage:
)
if not workflow_node_execution:
raise Exception(f'Workflow node execution not found: {node_execution_id}')
raise Exception(f"Workflow node execution not found: {node_execution_id}")
return workflow_node_execution
return workflow_node_execution