mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 19:36:53 +08:00
chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -32,10 +32,13 @@ class BasedGenerateTaskPipeline:
|
||||
_task_state: TaskState
|
||||
_application_generate_entity: AppGenerateEntity
|
||||
|
||||
def __init__(self, application_generate_entity: AppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: AppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize GenerateTaskPipeline.
|
||||
:param application_generate_entity: application generate entity
|
||||
@@ -61,18 +64,18 @@ class BasedGenerateTaskPipeline:
|
||||
e = event.error
|
||||
|
||||
if isinstance(e, InvokeAuthorizationError):
|
||||
err = InvokeAuthorizationError('Incorrect API key provided')
|
||||
err = InvokeAuthorizationError("Incorrect API key provided")
|
||||
elif isinstance(e, InvokeError) or isinstance(e, ValueError):
|
||||
err = e
|
||||
else:
|
||||
err = Exception(e.description if getattr(e, 'description', None) is not None else str(e))
|
||||
err = Exception(e.description if getattr(e, "description", None) is not None else str(e))
|
||||
|
||||
if message:
|
||||
refetch_message = db.session.query(Message).filter(Message.id == message.id).first()
|
||||
|
||||
if refetch_message:
|
||||
err_desc = self._error_to_desc(err)
|
||||
refetch_message.status = 'error'
|
||||
refetch_message.status = "error"
|
||||
refetch_message.error = err_desc
|
||||
|
||||
db.session.commit()
|
||||
@@ -86,12 +89,14 @@ class BasedGenerateTaskPipeline:
|
||||
:return:
|
||||
"""
|
||||
if isinstance(e, QuotaExceededError):
|
||||
return ("Your quota for Dify Hosted Model Provider has been exhausted. "
|
||||
"Please go to Settings -> Model Provider to complete your own provider credentials.")
|
||||
return (
|
||||
"Your quota for Dify Hosted Model Provider has been exhausted. "
|
||||
"Please go to Settings -> Model Provider to complete your own provider credentials."
|
||||
)
|
||||
|
||||
message = getattr(e, 'description', str(e))
|
||||
message = getattr(e, "description", str(e))
|
||||
if not message:
|
||||
message = 'Internal Server Error, please contact support.'
|
||||
message = "Internal Server Error, please contact support."
|
||||
|
||||
return message
|
||||
|
||||
@@ -101,10 +106,7 @@ class BasedGenerateTaskPipeline:
|
||||
:param e: exception
|
||||
:return:
|
||||
"""
|
||||
return ErrorStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
err=e
|
||||
)
|
||||
return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e)
|
||||
|
||||
def _ping_stream_response(self) -> PingStreamResponse:
|
||||
"""
|
||||
@@ -125,11 +127,8 @@ class BasedGenerateTaskPipeline:
|
||||
return OutputModeration(
|
||||
tenant_id=app_config.tenant_id,
|
||||
app_id=app_config.app_id,
|
||||
rule=ModerationRule(
|
||||
type=sensitive_word_avoidance.type,
|
||||
config=sensitive_word_avoidance.config
|
||||
),
|
||||
queue_manager=self._queue_manager
|
||||
rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config),
|
||||
queue_manager=self._queue_manager,
|
||||
)
|
||||
|
||||
def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
|
||||
@@ -143,8 +142,7 @@ class BasedGenerateTaskPipeline:
|
||||
self._output_moderation_handler.stop_thread()
|
||||
|
||||
completion = self._output_moderation_handler.moderation_completion(
|
||||
completion=completion,
|
||||
public_event=False
|
||||
completion=completion, public_event=False
|
||||
)
|
||||
|
||||
self._output_moderation_handler = None
|
||||
|
||||
@@ -64,23 +64,21 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
"""
|
||||
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
_task_state: EasyUITaskState
|
||||
_application_generate_entity: Union[
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity
|
||||
]
|
||||
|
||||
def __init__(self, application_generate_entity: Union[
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity
|
||||
],
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool) -> None:
|
||||
_task_state: EasyUITaskState
|
||||
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: Union[
|
||||
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity
|
||||
],
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize GenerateTaskPipeline.
|
||||
:param application_generate_entity: application generate entity
|
||||
@@ -101,18 +99,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
model=self._model_config.model,
|
||||
prompt_messages=[],
|
||||
message=AssistantPromptMessage(content=""),
|
||||
usage=LLMUsage.empty_usage()
|
||||
usage=LLMUsage.empty_usage(),
|
||||
)
|
||||
)
|
||||
|
||||
self._conversation_name_generate_thread = None
|
||||
|
||||
def process(
|
||||
self,
|
||||
self,
|
||||
) -> Union[
|
||||
ChatbotAppBlockingResponse,
|
||||
CompletionAppBlockingResponse,
|
||||
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]
|
||||
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
|
||||
]:
|
||||
"""
|
||||
Process generate task pipeline.
|
||||
@@ -125,22 +123,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
|
||||
# start generate conversation name thread
|
||||
self._conversation_name_generate_thread = self._generate_conversation_name(
|
||||
self._conversation,
|
||||
self._application_generate_entity.query
|
||||
self._conversation, self._application_generate_entity.query
|
||||
)
|
||||
|
||||
generator = self._wrapper_process_stream_response(
|
||||
trace_manager=self._application_generate_entity.trace_manager
|
||||
)
|
||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||
if self._stream:
|
||||
return self._to_stream_response(generator)
|
||||
else:
|
||||
return self._to_blocking_response(generator)
|
||||
|
||||
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> Union[
|
||||
ChatbotAppBlockingResponse,
|
||||
CompletionAppBlockingResponse
|
||||
]:
|
||||
def _to_blocking_response(
|
||||
self, generator: Generator[StreamResponse, None, None]
|
||||
) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]:
|
||||
"""
|
||||
Process blocking response.
|
||||
:return:
|
||||
@@ -149,11 +143,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
if isinstance(stream_response, ErrorStreamResponse):
|
||||
raise stream_response.err
|
||||
elif isinstance(stream_response, MessageEndStreamResponse):
|
||||
extras = {
|
||||
'usage': jsonable_encoder(self._task_state.llm_result.usage)
|
||||
}
|
||||
extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
|
||||
if self._task_state.metadata:
|
||||
extras['metadata'] = self._task_state.metadata
|
||||
extras["metadata"] = self._task_state.metadata
|
||||
|
||||
if self._conversation.mode == AppMode.COMPLETION.value:
|
||||
response = CompletionAppBlockingResponse(
|
||||
@@ -164,8 +156,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
message_id=self._message.id,
|
||||
answer=self._task_state.llm_result.message.content,
|
||||
created_at=int(self._message.created_at.timestamp()),
|
||||
**extras
|
||||
)
|
||||
**extras,
|
||||
),
|
||||
)
|
||||
else:
|
||||
response = ChatbotAppBlockingResponse(
|
||||
@@ -177,18 +169,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
message_id=self._message.id,
|
||||
answer=self._task_state.llm_result.message.content,
|
||||
created_at=int(self._message.created_at.timestamp()),
|
||||
**extras
|
||||
)
|
||||
**extras,
|
||||
),
|
||||
)
|
||||
|
||||
return response
|
||||
else:
|
||||
continue
|
||||
|
||||
raise Exception('Queue listening stopped unexpectedly.')
|
||||
raise Exception("Queue listening stopped unexpectedly.")
|
||||
|
||||
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
|
||||
-> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
|
||||
def _to_stream_response(
|
||||
self, generator: Generator[StreamResponse, None, None]
|
||||
) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
|
||||
"""
|
||||
To stream response.
|
||||
:return:
|
||||
@@ -198,14 +191,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
yield CompletionAppStreamResponse(
|
||||
message_id=self._message.id,
|
||||
created_at=int(self._message.created_at.timestamp()),
|
||||
stream_response=stream_response
|
||||
stream_response=stream_response,
|
||||
)
|
||||
else:
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id=self._conversation.id,
|
||||
message_id=self._message.id,
|
||||
created_at=int(self._message.created_at.timestamp()),
|
||||
stream_response=stream_response
|
||||
stream_response=stream_response,
|
||||
)
|
||||
|
||||
def _listenAudioMsg(self, publisher, task_id: str):
|
||||
@@ -217,15 +210,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||
return None
|
||||
|
||||
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
|
||||
Generator[StreamResponse, None, None]:
|
||||
|
||||
def _wrapper_process_stream_response(
|
||||
self, trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
tenant_id = self._application_generate_entity.app_config.tenant_id
|
||||
task_id = self._application_generate_entity.task_id
|
||||
publisher = None
|
||||
text_to_speech_dict = self._app_config.app_model_config_dict.get('text_to_speech')
|
||||
if text_to_speech_dict and text_to_speech_dict.get('autoPlay') == 'enabled' and text_to_speech_dict.get('enabled'):
|
||||
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get('voice', None))
|
||||
text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech")
|
||||
if (
|
||||
text_to_speech_dict
|
||||
and text_to_speech_dict.get("autoPlay") == "enabled"
|
||||
and text_to_speech_dict.get("enabled")
|
||||
):
|
||||
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None))
|
||||
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listenAudioMsg(publisher, task_id)
|
||||
@@ -250,14 +247,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
break
|
||||
else:
|
||||
start_listener_time = time.time()
|
||||
yield MessageAudioStreamResponse(audio=audio.audio,
|
||||
task_id=task_id)
|
||||
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
|
||||
yield MessageAudioStreamResponse(audio=audio.audio, task_id=task_id)
|
||||
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
|
||||
|
||||
def _process_stream_response(
|
||||
self,
|
||||
publisher: AppGeneratorTTSPublisher,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""
|
||||
Process stream response.
|
||||
@@ -333,9 +327,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
if self._conversation_name_generate_thread:
|
||||
self._conversation_name_generate_thread.join()
|
||||
|
||||
def _save_message(
|
||||
self, trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> None:
|
||||
def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None:
|
||||
"""
|
||||
Save message.
|
||||
:return:
|
||||
@@ -347,31 +339,32 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
|
||||
|
||||
self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
self._model_config.mode,
|
||||
self._task_state.llm_result.prompt_messages
|
||||
self._model_config.mode, self._task_state.llm_result.prompt_messages
|
||||
)
|
||||
self._message.message_tokens = usage.prompt_tokens
|
||||
self._message.message_unit_price = usage.prompt_unit_price
|
||||
self._message.message_price_unit = usage.prompt_price_unit
|
||||
self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
|
||||
if llm_result.message.content else ''
|
||||
self._message.answer = (
|
||||
PromptTemplateParser.remove_template_variables(llm_result.message.content.strip())
|
||||
if llm_result.message.content
|
||||
else ""
|
||||
)
|
||||
self._message.answer_tokens = usage.completion_tokens
|
||||
self._message.answer_unit_price = usage.completion_unit_price
|
||||
self._message.answer_price_unit = usage.completion_price_unit
|
||||
self._message.provider_response_latency = time.perf_counter() - self._start_at
|
||||
self._message.total_price = usage.total_price
|
||||
self._message.currency = usage.currency
|
||||
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
|
||||
if self._task_state.metadata else None
|
||||
self._message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.MESSAGE_TRACE,
|
||||
conversation_id=self._conversation.id,
|
||||
message_id=self._message.id
|
||||
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id
|
||||
)
|
||||
)
|
||||
|
||||
@@ -379,11 +372,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
self._message,
|
||||
application_generate_entity=self._application_generate_entity,
|
||||
conversation=self._conversation,
|
||||
is_first_message=self._application_generate_entity.app_config.app_mode in [
|
||||
AppMode.AGENT_CHAT,
|
||||
AppMode.CHAT
|
||||
] and self._application_generate_entity.conversation_id is None,
|
||||
extras=self._application_generate_entity.extras
|
||||
is_first_message=self._application_generate_entity.app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT]
|
||||
and self._application_generate_entity.conversation_id is None,
|
||||
extras=self._application_generate_entity.extras,
|
||||
)
|
||||
|
||||
def _handle_stop(self, event: QueueStopEvent) -> None:
|
||||
@@ -395,22 +386,17 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
model = model_config.model
|
||||
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=model_config.provider_model_bundle,
|
||||
model=model_config.model
|
||||
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
|
||||
)
|
||||
|
||||
# calculate num tokens
|
||||
prompt_tokens = 0
|
||||
if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
|
||||
prompt_tokens = model_instance.get_llm_num_tokens(
|
||||
self._task_state.llm_result.prompt_messages
|
||||
)
|
||||
prompt_tokens = model_instance.get_llm_num_tokens(self._task_state.llm_result.prompt_messages)
|
||||
|
||||
completion_tokens = 0
|
||||
if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
|
||||
completion_tokens = model_instance.get_llm_num_tokens(
|
||||
[self._task_state.llm_result.message]
|
||||
)
|
||||
completion_tokens = model_instance.get_llm_num_tokens([self._task_state.llm_result.message])
|
||||
|
||||
credentials = model_config.credentials
|
||||
|
||||
@@ -418,10 +404,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
|
||||
model,
|
||||
credentials,
|
||||
prompt_tokens,
|
||||
completion_tokens
|
||||
model, credentials, prompt_tokens, completion_tokens
|
||||
)
|
||||
|
||||
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
|
||||
@@ -429,16 +412,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
Message end to stream response.
|
||||
:return:
|
||||
"""
|
||||
self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)
|
||||
self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage)
|
||||
|
||||
extras = {}
|
||||
if self._task_state.metadata:
|
||||
extras['metadata'] = self._task_state.metadata
|
||||
extras["metadata"] = self._task_state.metadata
|
||||
|
||||
return MessageEndStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=self._message.id,
|
||||
**extras
|
||||
task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
|
||||
)
|
||||
|
||||
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
|
||||
@@ -449,9 +430,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
:return:
|
||||
"""
|
||||
return AgentMessageStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=message_id,
|
||||
answer=answer
|
||||
task_id=self._application_generate_entity.task_id, id=message_id, answer=answer
|
||||
)
|
||||
|
||||
def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]:
|
||||
@@ -461,9 +440,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
:return:
|
||||
"""
|
||||
agent_thought: MessageAgentThought = (
|
||||
db.session.query(MessageAgentThought)
|
||||
.filter(MessageAgentThought.id == event.agent_thought_id)
|
||||
.first()
|
||||
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
|
||||
)
|
||||
db.session.refresh(agent_thought)
|
||||
db.session.close()
|
||||
@@ -478,7 +455,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
tool=agent_thought.tool,
|
||||
tool_labels=agent_thought.tool_labels,
|
||||
tool_input=agent_thought.tool_input,
|
||||
message_files=agent_thought.files
|
||||
message_files=agent_thought.files,
|
||||
)
|
||||
|
||||
return None
|
||||
@@ -500,15 +477,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
prompt_messages=self._task_state.llm_result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
|
||||
)
|
||||
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content),
|
||||
),
|
||||
)
|
||||
), PublishFrom.TASK_PIPELINE
|
||||
),
|
||||
PublishFrom.TASK_PIPELINE,
|
||||
)
|
||||
|
||||
self._queue_manager.publish(
|
||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
|
||||
PublishFrom.TASK_PIPELINE
|
||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
return True
|
||||
else:
|
||||
|
||||
@@ -30,10 +30,7 @@ from services.annotation_service import AppAnnotationService
|
||||
|
||||
class MessageCycleManage:
|
||||
_application_generate_entity: Union[
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AdvancedChatAppGenerateEntity
|
||||
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
|
||||
]
|
||||
_task_state: Union[EasyUITaskState, WorkflowTaskState]
|
||||
|
||||
@@ -49,15 +46,18 @@ class MessageCycleManage:
|
||||
|
||||
is_first_message = self._application_generate_entity.conversation_id is None
|
||||
extras = self._application_generate_entity.extras
|
||||
auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True)
|
||||
auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)
|
||||
|
||||
if auto_generate_conversation_name and is_first_message:
|
||||
# start generate thread
|
||||
thread = Thread(target=self._generate_conversation_name_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(), # type: ignore
|
||||
'conversation_id': conversation.id,
|
||||
'query': query
|
||||
})
|
||||
thread = Thread(
|
||||
target=self._generate_conversation_name_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"conversation_id": conversation.id,
|
||||
"query": query,
|
||||
},
|
||||
)
|
||||
|
||||
thread.start()
|
||||
|
||||
@@ -65,17 +65,10 @@ class MessageCycleManage:
|
||||
|
||||
return None
|
||||
|
||||
def _generate_conversation_name_worker(self,
|
||||
flask_app: Flask,
|
||||
conversation_id: str,
|
||||
query: str):
|
||||
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
|
||||
with flask_app.app_context():
|
||||
# get conversation and message
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(Conversation.id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
|
||||
|
||||
if not conversation:
|
||||
return
|
||||
@@ -105,12 +98,9 @@ class MessageCycleManage:
|
||||
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
|
||||
if annotation:
|
||||
account = annotation.account
|
||||
self._task_state.metadata['annotation_reply'] = {
|
||||
'id': annotation.id,
|
||||
'account': {
|
||||
'id': annotation.account_id,
|
||||
'name': account.name if account else 'Dify user'
|
||||
}
|
||||
self._task_state.metadata["annotation_reply"] = {
|
||||
"id": annotation.id,
|
||||
"account": {"id": annotation.account_id, "name": account.name if account else "Dify user"},
|
||||
}
|
||||
|
||||
return annotation
|
||||
@@ -124,7 +114,7 @@ class MessageCycleManage:
|
||||
:return:
|
||||
"""
|
||||
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
|
||||
self._task_state.metadata['retriever_resources'] = event.retriever_resources
|
||||
self._task_state.metadata["retriever_resources"] = event.retriever_resources
|
||||
|
||||
def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
|
||||
"""
|
||||
@@ -132,27 +122,23 @@ class MessageCycleManage:
|
||||
:param event: event
|
||||
:return:
|
||||
"""
|
||||
message_file = (
|
||||
db.session.query(MessageFile)
|
||||
.filter(MessageFile.id == event.message_file_id)
|
||||
.first()
|
||||
)
|
||||
message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()
|
||||
|
||||
if message_file:
|
||||
# get tool file id
|
||||
tool_file_id = message_file.url.split('/')[-1]
|
||||
tool_file_id = message_file.url.split("/")[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split('.')[0]
|
||||
tool_file_id = tool_file_id.split(".")[0]
|
||||
|
||||
# get extension
|
||||
if '.' in message_file.url:
|
||||
if "." in message_file.url:
|
||||
extension = f'.{message_file.url.split(".")[-1]}'
|
||||
if len(extension) > 10:
|
||||
extension = '.bin'
|
||||
extension = ".bin"
|
||||
else:
|
||||
extension = '.bin'
|
||||
extension = ".bin"
|
||||
# add sign url to local file
|
||||
if message_file.url.startswith('http'):
|
||||
if message_file.url.startswith("http"):
|
||||
url = message_file.url
|
||||
else:
|
||||
url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension)
|
||||
@@ -161,8 +147,8 @@ class MessageCycleManage:
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=message_file.id,
|
||||
type=message_file.type,
|
||||
belongs_to=message_file.belongs_to or 'user',
|
||||
url=url
|
||||
belongs_to=message_file.belongs_to or "user",
|
||||
url=url,
|
||||
)
|
||||
|
||||
return None
|
||||
@@ -174,11 +160,7 @@ class MessageCycleManage:
|
||||
:param message_id: message id
|
||||
:return:
|
||||
"""
|
||||
return MessageStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=message_id,
|
||||
answer=answer
|
||||
)
|
||||
return MessageStreamResponse(task_id=self._application_generate_entity.task_id, id=message_id, answer=answer)
|
||||
|
||||
def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse:
|
||||
"""
|
||||
@@ -186,7 +168,4 @@ class MessageCycleManage:
|
||||
:param answer: answer
|
||||
:return:
|
||||
"""
|
||||
return MessageReplaceStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
answer=answer
|
||||
)
|
||||
return MessageReplaceStreamResponse(task_id=self._application_generate_entity.task_id, answer=answer)
|
||||
|
||||
@@ -70,14 +70,14 @@ class WorkflowCycleManage:
|
||||
|
||||
inputs = {**self._application_generate_entity.inputs}
|
||||
for key, value in (self._workflow_system_variables or {}).items():
|
||||
if key.value == 'conversation':
|
||||
if key.value == "conversation":
|
||||
continue
|
||||
|
||||
inputs[f'sys.{key.value}'] = value
|
||||
inputs[f"sys.{key.value}"] = value
|
||||
|
||||
inputs = WorkflowEntry.handle_special_values(inputs)
|
||||
|
||||
triggered_from= (
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom.DEBUGGING
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
|
||||
else WorkflowRunTriggeredFrom.APP_RUN
|
||||
@@ -185,20 +185,26 @@ class WorkflowCycleManage:
|
||||
|
||||
db.session.commit()
|
||||
|
||||
running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
|
||||
WorkflowNodeExecution.app_id == workflow_run.app_id,
|
||||
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
|
||||
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value
|
||||
).all()
|
||||
running_workflow_node_executions = (
|
||||
db.session.query(WorkflowNodeExecution)
|
||||
.filter(
|
||||
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
|
||||
WorkflowNodeExecution.app_id == workflow_run.app_id,
|
||||
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
|
||||
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
for workflow_node_execution in running_workflow_node_executions:
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds()
|
||||
workflow_node_execution.elapsed_time = (
|
||||
workflow_node_execution.finished_at - workflow_node_execution.created_at
|
||||
).total_seconds()
|
||||
db.session.commit()
|
||||
|
||||
db.session.refresh(workflow_run)
|
||||
@@ -216,7 +222,9 @@ class WorkflowCycleManage:
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _handle_node_execution_start(self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
|
||||
def _handle_node_execution_start(
|
||||
self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
|
||||
) -> WorkflowNodeExecution:
|
||||
# init workflow node execution
|
||||
workflow_node_execution = WorkflowNodeExecution()
|
||||
workflow_node_execution.tenant_id = workflow_run.tenant_id
|
||||
@@ -333,16 +341,16 @@ class WorkflowCycleManage:
|
||||
created_by_account = workflow_run.created_by_account
|
||||
if created_by_account:
|
||||
created_by = {
|
||||
'id': created_by_account.id,
|
||||
'name': created_by_account.name,
|
||||
'email': created_by_account.email,
|
||||
"id": created_by_account.id,
|
||||
"name": created_by_account.name,
|
||||
"email": created_by_account.email,
|
||||
}
|
||||
else:
|
||||
created_by_end_user = workflow_run.created_by_end_user
|
||||
if created_by_end_user:
|
||||
created_by = {
|
||||
'id': created_by_end_user.id,
|
||||
'user': created_by_end_user.session_id,
|
||||
"id": created_by_end_user.id,
|
||||
"user": created_by_end_user.session_id,
|
||||
}
|
||||
|
||||
return WorkflowFinishStreamResponse(
|
||||
@@ -401,7 +409,7 @@ class WorkflowCycleManage:
|
||||
# extras logic
|
||||
if event.node_type == NodeType.TOOL:
|
||||
node_data = cast(ToolNodeData, event.node_data)
|
||||
response.data.extras['icon'] = ToolManager.get_tool_icon(
|
||||
response.data.extras["icon"] = ToolManager.get_tool_icon(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
provider_type=node_data.provider_type,
|
||||
provider_id=node_data.provider_id,
|
||||
@@ -410,10 +418,10 @@ class WorkflowCycleManage:
|
||||
return response
|
||||
|
||||
def _workflow_node_finish_to_stream_response(
|
||||
self,
|
||||
event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution
|
||||
self,
|
||||
event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[NodeFinishStreamResponse]:
|
||||
"""
|
||||
Workflow node finish to stream response.
|
||||
@@ -424,7 +432,7 @@ class WorkflowCycleManage:
|
||||
"""
|
||||
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
|
||||
return None
|
||||
|
||||
|
||||
return NodeFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
@@ -452,13 +460,10 @@ class WorkflowCycleManage:
|
||||
iteration_id=event.in_iteration_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _workflow_parallel_branch_start_to_stream_response(
|
||||
self,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
event: QueueParallelBranchRunStartedEvent
|
||||
) -> ParallelBranchStartStreamResponse:
|
||||
self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
|
||||
) -> ParallelBranchStartStreamResponse:
|
||||
"""
|
||||
Workflow parallel branch start to stream response
|
||||
:param task_id: task id
|
||||
@@ -476,15 +481,15 @@ class WorkflowCycleManage:
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
created_at=int(time.time()),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _workflow_parallel_branch_finished_to_stream_response(
|
||||
self,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent
|
||||
) -> ParallelBranchFinishedStreamResponse:
|
||||
self,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
|
||||
) -> ParallelBranchFinishedStreamResponse:
|
||||
"""
|
||||
Workflow parallel branch finished to stream response
|
||||
:param task_id: task id
|
||||
@@ -501,18 +506,15 @@ class WorkflowCycleManage:
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
status='succeeded' if isinstance(event, QueueParallelBranchRunSucceededEvent) else 'failed',
|
||||
status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
|
||||
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
|
||||
created_at=int(time.time()),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_iteration_start_to_stream_response(
|
||||
self,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
event: QueueIterationStartEvent
|
||||
) -> IterationNodeStartStreamResponse:
|
||||
self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
|
||||
) -> IterationNodeStartStreamResponse:
|
||||
"""
|
||||
Workflow iteration start to stream response
|
||||
:param task_id: task id
|
||||
@@ -534,10 +536,12 @@ class WorkflowCycleManage:
|
||||
metadata=event.metadata or {},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_iteration_next_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent) -> IterationNodeNextStreamResponse:
|
||||
def _workflow_iteration_next_to_stream_response(
|
||||
self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
|
||||
) -> IterationNodeNextStreamResponse:
|
||||
"""
|
||||
Workflow iteration next to stream response
|
||||
:param task_id: task id
|
||||
@@ -559,10 +563,12 @@ class WorkflowCycleManage:
|
||||
extras={},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_iteration_completed_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent) -> IterationNodeCompletedStreamResponse:
|
||||
def _workflow_iteration_completed_to_stream_response(
|
||||
self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
|
||||
) -> IterationNodeCompletedStreamResponse:
|
||||
"""
|
||||
Workflow iteration completed to stream response
|
||||
:param task_id: task id
|
||||
@@ -585,13 +591,13 @@ class WorkflowCycleManage:
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
error=None,
|
||||
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
|
||||
total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0,
|
||||
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
|
||||
execution_metadata=event.metadata,
|
||||
finished_at=int(time.time()),
|
||||
steps=event.steps,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
|
||||
@@ -643,7 +649,7 @@ class WorkflowCycleManage:
|
||||
return None
|
||||
|
||||
if isinstance(value, dict):
|
||||
if '__variant' in value and value['__variant'] == FileVar.__name__:
|
||||
if "__variant" in value and value["__variant"] == FileVar.__name__:
|
||||
return value
|
||||
elif isinstance(value, FileVar):
|
||||
return value.to_dict()
|
||||
@@ -656,11 +662,10 @@ class WorkflowCycleManage:
|
||||
:param workflow_run_id: workflow run id
|
||||
:return:
|
||||
"""
|
||||
workflow_run = db.session.query(WorkflowRun).filter(
|
||||
WorkflowRun.id == workflow_run_id).first()
|
||||
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
|
||||
|
||||
if not workflow_run:
|
||||
raise Exception(f'Workflow run not found: {workflow_run_id}')
|
||||
raise Exception(f"Workflow run not found: {workflow_run_id}")
|
||||
|
||||
return workflow_run
|
||||
|
||||
@@ -683,6 +688,6 @@ class WorkflowCycleManage:
|
||||
)
|
||||
|
||||
if not workflow_node_execution:
|
||||
raise Exception(f'Workflow node execution not found: {node_execution_id}')
|
||||
raise Exception(f"Workflow node execution not found: {node_execution_id}")
|
||||
|
||||
return workflow_node_execution
|
||||
return workflow_node_execution
|
||||
|
||||
Reference in New Issue
Block a user