mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-15 22:06:52 +08:00
feat: mypy for all type check (#10921)
This commit is contained in:
@@ -29,6 +29,7 @@ from factories import file_factory
|
||||
from models.account import Account
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from models.workflow import Workflow
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -145,7 +146,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
files=list(file_objs),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
@@ -313,6 +314,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
# get conversation and message
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
message = self._get_message(message_id)
|
||||
if message is None:
|
||||
raise MessageNotExistsError("Message not exists")
|
||||
|
||||
# chatbot app
|
||||
runner = AdvancedChatAppRunner(
|
||||
|
||||
@@ -5,6 +5,7 @@ import queue
|
||||
import re
|
||||
import threading
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
from core.app.entities.queue_entities import (
|
||||
MessageQueueMessage,
|
||||
@@ -15,6 +16,7 @@ from core.app.entities.queue_entities import (
|
||||
WorkflowQueueMessage,
|
||||
)
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
|
||||
@@ -71,8 +73,9 @@ class AppGeneratorTTSPublisher:
|
||||
if not voice or voice not in values:
|
||||
self.voice = self.voices[0].get("value")
|
||||
self.MAX_SENTENCE = 2
|
||||
self._last_audio_event = None
|
||||
self._runtime_thread = threading.Thread(target=self._runtime).start()
|
||||
self._last_audio_event: Optional[AudioTrunk] = None
|
||||
# FIXME better way to handle this threading.start
|
||||
threading.Thread(target=self._runtime).start()
|
||||
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
|
||||
|
||||
def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):
|
||||
@@ -92,10 +95,21 @@ class AppGeneratorTTSPublisher:
|
||||
future_queue.put(futures_result)
|
||||
break
|
||||
elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
|
||||
self.msg_text += message.event.chunk.delta.message.content
|
||||
message_content = message.event.chunk.delta.message.content
|
||||
if not message_content:
|
||||
continue
|
||||
if isinstance(message_content, str):
|
||||
self.msg_text += message_content
|
||||
elif isinstance(message_content, list):
|
||||
for content in message_content:
|
||||
if not isinstance(content, TextPromptMessageContent):
|
||||
continue
|
||||
self.msg_text += content.data
|
||||
elif isinstance(message.event, QueueTextChunkEvent):
|
||||
self.msg_text += message.event.text
|
||||
elif isinstance(message.event, QueueNodeSucceededEvent):
|
||||
if message.event.outputs is None:
|
||||
continue
|
||||
self.msg_text += message.event.outputs.get("output", "")
|
||||
self.last_message = message
|
||||
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
|
||||
@@ -121,11 +135,10 @@ class AppGeneratorTTSPublisher:
|
||||
if self._last_audio_event and self._last_audio_event.status == "finish":
|
||||
if self.executor:
|
||||
self.executor.shutdown(wait=False)
|
||||
return self.last_message
|
||||
return self._last_audio_event
|
||||
audio = self._audio_queue.get_nowait()
|
||||
if audio and audio.status == "finish":
|
||||
self.executor.shutdown(wait=False)
|
||||
self._runtime_thread = None
|
||||
if audio:
|
||||
self._last_audio_event = audio
|
||||
return audio
|
||||
|
||||
@@ -109,18 +109,18 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
ConversationVariable.conversation_id == self.conversation.id,
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
conversation_variables = session.scalars(stmt).all()
|
||||
if not conversation_variables:
|
||||
db_conversation_variables = session.scalars(stmt).all()
|
||||
if not db_conversation_variables:
|
||||
# Create conversation variables if they don't exist.
|
||||
conversation_variables = [
|
||||
db_conversation_variables = [
|
||||
ConversationVariable.from_variable(
|
||||
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
|
||||
)
|
||||
for variable in workflow.conversation_variables
|
||||
]
|
||||
session.add_all(conversation_variables)
|
||||
session.add_all(db_conversation_variables)
|
||||
# Convert database entities to variables.
|
||||
conversation_variables = [item.to_variable() for item in conversation_variables]
|
||||
conversation_variables = [item.to_variable() for item in db_conversation_variables]
|
||||
|
||||
session.commit()
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator, Mapping
|
||||
from threading import Thread
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
@@ -64,6 +65,7 @@ from models.enums import CreatedByRole
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
@@ -81,6 +83,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
_user: Union[Account, EndUser]
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
||||
_conversation_name_generate_thread: Optional[Thread] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -131,7 +134,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
self._conversation_name_generate_thread = None
|
||||
self._recorded_files: list[Mapping[str, Any]] = []
|
||||
|
||||
def process(self):
|
||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
"""
|
||||
Process generate task pipeline.
|
||||
:return:
|
||||
@@ -262,8 +265,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
:return:
|
||||
"""
|
||||
# init fake graph runtime state
|
||||
graph_runtime_state = None
|
||||
workflow_run = None
|
||||
graph_runtime_state: Optional[GraphRuntimeState] = None
|
||||
workflow_run: Optional[WorkflowRun] = None
|
||||
|
||||
for queue_message in self._queue_manager.listen():
|
||||
event = queue_message.event
|
||||
@@ -315,14 +318,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
|
||||
|
||||
response = self._workflow_node_start_to_stream_response(
|
||||
response_start = self._workflow_node_start_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
if response_start:
|
||||
yield response_start
|
||||
elif isinstance(event, QueueNodeSucceededEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_success(event)
|
||||
|
||||
@@ -330,18 +333,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if event.node_type in [NodeType.ANSWER, NodeType.END]:
|
||||
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
|
||||
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
response_finish = self._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
if response_finish:
|
||||
yield response_finish
|
||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
|
||||
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
response_finish = self._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
@@ -609,7 +612,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
del extras["metadata"]["annotation_reply"]
|
||||
|
||||
return MessageEndStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=self._message.id,
|
||||
files=self._recorded_files,
|
||||
metadata=extras.get("metadata", {}),
|
||||
)
|
||||
|
||||
def _handle_output_moderation_chunk(self, text: str) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user