mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 11:26:52 +08:00
feat: mypy for all type check (#10921)
This commit is contained in:
@@ -29,6 +29,7 @@ from factories import file_factory
|
||||
from models.account import Account
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from models.workflow import Workflow
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -145,7 +146,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
files=list(file_objs),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
@@ -313,6 +314,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
# get conversation and message
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
message = self._get_message(message_id)
|
||||
if message is None:
|
||||
raise MessageNotExistsError("Message not exists")
|
||||
|
||||
# chatbot app
|
||||
runner = AdvancedChatAppRunner(
|
||||
|
||||
@@ -5,6 +5,7 @@ import queue
|
||||
import re
|
||||
import threading
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
from core.app.entities.queue_entities import (
|
||||
MessageQueueMessage,
|
||||
@@ -15,6 +16,7 @@ from core.app.entities.queue_entities import (
|
||||
WorkflowQueueMessage,
|
||||
)
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
|
||||
@@ -71,8 +73,9 @@ class AppGeneratorTTSPublisher:
|
||||
if not voice or voice not in values:
|
||||
self.voice = self.voices[0].get("value")
|
||||
self.MAX_SENTENCE = 2
|
||||
self._last_audio_event = None
|
||||
self._runtime_thread = threading.Thread(target=self._runtime).start()
|
||||
self._last_audio_event: Optional[AudioTrunk] = None
|
||||
# FIXME better way to handle this threading.start
|
||||
threading.Thread(target=self._runtime).start()
|
||||
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
|
||||
|
||||
def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):
|
||||
@@ -92,10 +95,21 @@ class AppGeneratorTTSPublisher:
|
||||
future_queue.put(futures_result)
|
||||
break
|
||||
elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
|
||||
self.msg_text += message.event.chunk.delta.message.content
|
||||
message_content = message.event.chunk.delta.message.content
|
||||
if not message_content:
|
||||
continue
|
||||
if isinstance(message_content, str):
|
||||
self.msg_text += message_content
|
||||
elif isinstance(message_content, list):
|
||||
for content in message_content:
|
||||
if not isinstance(content, TextPromptMessageContent):
|
||||
continue
|
||||
self.msg_text += content.data
|
||||
elif isinstance(message.event, QueueTextChunkEvent):
|
||||
self.msg_text += message.event.text
|
||||
elif isinstance(message.event, QueueNodeSucceededEvent):
|
||||
if message.event.outputs is None:
|
||||
continue
|
||||
self.msg_text += message.event.outputs.get("output", "")
|
||||
self.last_message = message
|
||||
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
|
||||
@@ -121,11 +135,10 @@ class AppGeneratorTTSPublisher:
|
||||
if self._last_audio_event and self._last_audio_event.status == "finish":
|
||||
if self.executor:
|
||||
self.executor.shutdown(wait=False)
|
||||
return self.last_message
|
||||
return self._last_audio_event
|
||||
audio = self._audio_queue.get_nowait()
|
||||
if audio and audio.status == "finish":
|
||||
self.executor.shutdown(wait=False)
|
||||
self._runtime_thread = None
|
||||
if audio:
|
||||
self._last_audio_event = audio
|
||||
return audio
|
||||
|
||||
@@ -109,18 +109,18 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
ConversationVariable.conversation_id == self.conversation.id,
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
conversation_variables = session.scalars(stmt).all()
|
||||
if not conversation_variables:
|
||||
db_conversation_variables = session.scalars(stmt).all()
|
||||
if not db_conversation_variables:
|
||||
# Create conversation variables if they don't exist.
|
||||
conversation_variables = [
|
||||
db_conversation_variables = [
|
||||
ConversationVariable.from_variable(
|
||||
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
|
||||
)
|
||||
for variable in workflow.conversation_variables
|
||||
]
|
||||
session.add_all(conversation_variables)
|
||||
session.add_all(db_conversation_variables)
|
||||
# Convert database entities to variables.
|
||||
conversation_variables = [item.to_variable() for item in conversation_variables]
|
||||
conversation_variables = [item.to_variable() for item in db_conversation_variables]
|
||||
|
||||
session.commit()
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator, Mapping
|
||||
from threading import Thread
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
@@ -64,6 +65,7 @@ from models.enums import CreatedByRole
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
@@ -81,6 +83,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
_user: Union[Account, EndUser]
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
||||
_conversation_name_generate_thread: Optional[Thread] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -131,7 +134,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
self._conversation_name_generate_thread = None
|
||||
self._recorded_files: list[Mapping[str, Any]] = []
|
||||
|
||||
def process(self):
|
||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
"""
|
||||
Process generate task pipeline.
|
||||
:return:
|
||||
@@ -262,8 +265,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
:return:
|
||||
"""
|
||||
# init fake graph runtime state
|
||||
graph_runtime_state = None
|
||||
workflow_run = None
|
||||
graph_runtime_state: Optional[GraphRuntimeState] = None
|
||||
workflow_run: Optional[WorkflowRun] = None
|
||||
|
||||
for queue_message in self._queue_manager.listen():
|
||||
event = queue_message.event
|
||||
@@ -315,14 +318,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
|
||||
|
||||
response = self._workflow_node_start_to_stream_response(
|
||||
response_start = self._workflow_node_start_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
if response_start:
|
||||
yield response_start
|
||||
elif isinstance(event, QueueNodeSucceededEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_success(event)
|
||||
|
||||
@@ -330,18 +333,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if event.node_type in [NodeType.ANSWER, NodeType.END]:
|
||||
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
|
||||
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
response_finish = self._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
if response_finish:
|
||||
yield response_finish
|
||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
|
||||
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
response_finish = self._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
@@ -609,7 +612,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
del extras["metadata"]["annotation_reply"]
|
||||
|
||||
return MessageEndStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=self._message.id,
|
||||
files=self._recorded_files,
|
||||
metadata=extras.get("metadata", {}),
|
||||
)
|
||||
|
||||
def _handle_output_moderation_chunk(self, text: str) -> bool:
|
||||
|
||||
@@ -61,7 +61,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
||||
app_model_config_dict = app_model_config.to_dict()
|
||||
config_dict = app_model_config_dict.copy()
|
||||
else:
|
||||
config_dict = override_config_dict
|
||||
config_dict = override_config_dict or {}
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
app_config = AgentChatAppConfig(
|
||||
|
||||
@@ -23,6 +23,7 @@ from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import Account, App, EndUser
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -97,7 +98,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
# get conversation
|
||||
conversation = None
|
||||
if args.get("conversation_id"):
|
||||
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
|
||||
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user)
|
||||
|
||||
# get app model config
|
||||
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
|
||||
@@ -153,7 +154,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
files=list(file_objs),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
@@ -180,7 +181,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(),
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"conversation_id": conversation.id,
|
||||
@@ -199,8 +200,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
user=user,
|
||||
stream=streaming,
|
||||
)
|
||||
|
||||
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
# FIXME: Type hinting issue here, ignore it for now, will fix it later
|
||||
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) # type: ignore
|
||||
|
||||
def _generate_worker(
|
||||
self,
|
||||
@@ -224,6 +225,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
# get conversation and message
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
message = self._get_message(message_id)
|
||||
if message is None:
|
||||
raise MessageNotExistsError("Message not exists")
|
||||
|
||||
# chatbot app
|
||||
runner = AgentChatAppRunner()
|
||||
|
||||
@@ -173,6 +173,8 @@ class AgentChatAppRunner(AppRunner):
|
||||
return
|
||||
|
||||
agent_entity = app_config.agent
|
||||
if not agent_entity:
|
||||
raise ValueError("Agent entity not found")
|
||||
|
||||
# load tool variables
|
||||
tool_conversation_variables = self._load_tool_variables(
|
||||
@@ -200,14 +202,21 @@ class AgentChatAppRunner(AppRunner):
|
||||
# change function call strategy based on LLM model
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
if not model_schema or not model_schema.features:
|
||||
raise ValueError("Model schema not found")
|
||||
|
||||
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
|
||||
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
|
||||
conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
|
||||
message = db.session.query(Message).filter(Message.id == message.id).first()
|
||||
conversation_result = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
|
||||
if conversation_result is None:
|
||||
raise ValueError("Conversation not found")
|
||||
message_result = db.session.query(Message).filter(Message.id == message.id).first()
|
||||
if message_result is None:
|
||||
raise ValueError("Message not found")
|
||||
db.session.close()
|
||||
|
||||
runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
|
||||
# start agent runner
|
||||
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||
# check LLM mode
|
||||
@@ -225,12 +234,12 @@ class AgentChatAppRunner(AppRunner):
|
||||
runner = runner_cls(
|
||||
tenant_id=app_config.tenant_id,
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation,
|
||||
conversation=conversation_result,
|
||||
app_config=app_config,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
config=agent_entity,
|
||||
queue_manager=queue_manager,
|
||||
message=message,
|
||||
message=message_result,
|
||||
user_id=application_generate_entity.user_id,
|
||||
memory=memory,
|
||||
prompt_messages=prompt_message,
|
||||
@@ -257,7 +266,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
"""
|
||||
load tool variables from database
|
||||
"""
|
||||
tool_variables: ToolConversationVariables = (
|
||||
tool_variables: ToolConversationVariables | None = (
|
||||
db.session.query(ToolConversationVariables)
|
||||
.filter(
|
||||
ToolConversationVariables.conversation_id == conversation_id,
|
||||
|
||||
@@ -16,7 +16,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = ChatbotAppBlockingResponse
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
@@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
@@ -51,8 +51,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
|
||||
def convert_stream_full_response( # type: ignore[override]
|
||||
cls,
|
||||
stream_response: Generator[ChatbotAppStreamResponse, None, None],
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@@ -82,8 +83,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
yield json.dumps(response_chunk)
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
|
||||
def convert_stream_simple_response( # type: ignore[override]
|
||||
cls,
|
||||
stream_response: Generator[ChatbotAppStreamResponse, None, None],
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
||||
@@ -50,7 +50,7 @@ class AppQueueManager:
|
||||
# wait for APP_MAX_EXECUTION_TIME seconds to stop listen
|
||||
listen_timeout = dify_config.APP_MAX_EXECUTION_TIME
|
||||
start_time = time.time()
|
||||
last_ping_time = 0
|
||||
last_ping_time: int | float = 0
|
||||
while True:
|
||||
try:
|
||||
message = self._q.get(timeout=1)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import time
|
||||
from collections.abc import Generator, Mapping
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
|
||||
@@ -36,8 +36,8 @@ class AppRunner:
|
||||
app_record: App,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict[str, str],
|
||||
files: list["File"],
|
||||
inputs: Mapping[str, str],
|
||||
files: Sequence["File"],
|
||||
query: Optional[str] = None,
|
||||
) -> int:
|
||||
"""
|
||||
@@ -64,7 +64,7 @@ class AppRunner:
|
||||
):
|
||||
max_tokens = (
|
||||
model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template)
|
||||
or model_config.parameters.get(parameter_rule.use_template or "")
|
||||
) or 0
|
||||
|
||||
if model_context_tokens is None:
|
||||
@@ -85,7 +85,7 @@ class AppRunner:
|
||||
|
||||
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
|
||||
rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens
|
||||
if rest_tokens < 0:
|
||||
raise InvokeBadRequestError(
|
||||
"Query or prefix prompt is too long, you can reduce the prefix prompt, "
|
||||
@@ -111,7 +111,7 @@ class AppRunner:
|
||||
):
|
||||
max_tokens = (
|
||||
model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template)
|
||||
or model_config.parameters.get(parameter_rule.use_template or "")
|
||||
) or 0
|
||||
|
||||
if model_context_tokens is None:
|
||||
@@ -136,8 +136,8 @@ class AppRunner:
|
||||
app_record: App,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict[str, str],
|
||||
files: list["File"],
|
||||
inputs: Mapping[str, str],
|
||||
files: Sequence["File"],
|
||||
query: Optional[str] = None,
|
||||
context: Optional[str] = None,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
@@ -156,6 +156,7 @@ class AppRunner:
|
||||
"""
|
||||
# get prompt without memory and context
|
||||
if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
|
||||
prompt_transform: Union[SimplePromptTransform, AdvancedPromptTransform]
|
||||
prompt_transform = SimplePromptTransform()
|
||||
prompt_messages, stop = prompt_transform.get_prompt(
|
||||
app_mode=AppMode.value_of(app_record.mode),
|
||||
@@ -171,8 +172,11 @@ class AppRunner:
|
||||
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
|
||||
|
||||
model_mode = ModelMode.value_of(model_config.mode)
|
||||
prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
|
||||
if model_mode == ModelMode.COMPLETION:
|
||||
advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
|
||||
if not advanced_completion_prompt_template:
|
||||
raise InvokeBadRequestError("Advanced completion prompt template is required.")
|
||||
prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt)
|
||||
|
||||
if advanced_completion_prompt_template.role_prefix:
|
||||
@@ -181,6 +185,8 @@ class AppRunner:
|
||||
assistant=advanced_completion_prompt_template.role_prefix.assistant,
|
||||
)
|
||||
else:
|
||||
if not prompt_template_entity.advanced_chat_prompt_template:
|
||||
raise InvokeBadRequestError("Advanced chat prompt template is required.")
|
||||
prompt_template = []
|
||||
for message in prompt_template_entity.advanced_chat_prompt_template.messages:
|
||||
prompt_template.append(ChatModelMessage(text=message.text, role=message.role))
|
||||
@@ -246,7 +252,7 @@ class AppRunner:
|
||||
|
||||
def _handle_invoke_result(
|
||||
self,
|
||||
invoke_result: Union[LLMResult, Generator],
|
||||
invoke_result: Union[LLMResult, Generator[Any, None, None]],
|
||||
queue_manager: AppQueueManager,
|
||||
stream: bool,
|
||||
agent: bool = False,
|
||||
@@ -259,10 +265,12 @@ class AppRunner:
|
||||
:param agent: agent
|
||||
:return:
|
||||
"""
|
||||
if not stream:
|
||||
if not stream and isinstance(invoke_result, LLMResult):
|
||||
self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
|
||||
else:
|
||||
elif stream and isinstance(invoke_result, Generator):
|
||||
self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
|
||||
else:
|
||||
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
|
||||
|
||||
def _handle_invoke_result_direct(
|
||||
self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
|
||||
@@ -291,8 +299,8 @@ class AppRunner:
|
||||
:param agent: agent
|
||||
:return:
|
||||
"""
|
||||
model = None
|
||||
prompt_messages = []
|
||||
model: str = ""
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
text = ""
|
||||
usage = None
|
||||
for result in invoke_result:
|
||||
@@ -328,13 +336,14 @@ class AppRunner:
|
||||
|
||||
def moderation_for_inputs(
|
||||
self,
|
||||
*,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
app_generate_entity: AppGenerateEntity,
|
||||
inputs: Mapping[str, Any],
|
||||
query: str,
|
||||
query: str | None = None,
|
||||
message_id: str,
|
||||
) -> tuple[bool, dict, str]:
|
||||
) -> tuple[bool, Mapping[str, Any], str]:
|
||||
"""
|
||||
Process sensitive_word_avoidance.
|
||||
:param app_id: app id
|
||||
@@ -350,7 +359,7 @@ class AppRunner:
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
app_config=app_generate_entity.app_config,
|
||||
inputs=inputs,
|
||||
inputs=dict(inputs),
|
||||
query=query or "",
|
||||
message_id=message_id,
|
||||
trace_manager=app_generate_entity.trace_manager,
|
||||
@@ -390,9 +399,9 @@ class AppRunner:
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
external_data_tools: list[ExternalDataVariableEntity],
|
||||
inputs: dict,
|
||||
inputs: Mapping[str, Any],
|
||||
query: str,
|
||||
) -> dict:
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Fill in variable inputs from external data tools if exists.
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.account import Account
|
||||
from models.model import App, EndUser
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -91,7 +92,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
# get conversation
|
||||
conversation = None
|
||||
if args.get("conversation_id"):
|
||||
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
|
||||
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user)
|
||||
|
||||
# get app model config
|
||||
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
|
||||
@@ -104,7 +105,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
|
||||
# validate config
|
||||
override_model_config_dict = ChatAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id, config=args.get("model_config")
|
||||
tenant_id=app_model.tenant_id, config=args.get("model_config", {})
|
||||
)
|
||||
|
||||
# always enable retriever resource in debugger mode
|
||||
@@ -146,7 +147,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
files=list(file_objs),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
user_id=user.id,
|
||||
invoke_from=invoke_from,
|
||||
@@ -172,7 +173,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(),
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"conversation_id": conversation.id,
|
||||
@@ -216,6 +217,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
# get conversation and message
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
message = self._get_message(message_id)
|
||||
if message is None:
|
||||
raise MessageNotExistsError("Message not exists")
|
||||
|
||||
# chatbot app
|
||||
runner = ChatAppRunner()
|
||||
|
||||
@@ -16,7 +16,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = ChatbotAppBlockingResponse
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
@@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
@@ -52,7 +52,8 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
|
||||
cls,
|
||||
stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override]
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@@ -83,7 +84,8 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
|
||||
cls,
|
||||
stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override]
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
||||
@@ -42,7 +42,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
|
||||
app_model_config_dict = app_model_config.to_dict()
|
||||
config_dict = app_model_config_dict.copy()
|
||||
else:
|
||||
config_dict = override_config_dict
|
||||
config_dict = override_config_dict or {}
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
app_config = CompletionAppConfig(
|
||||
|
||||
@@ -83,8 +83,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
query = query.replace("\x00", "")
|
||||
inputs = args["inputs"]
|
||||
|
||||
extras = {}
|
||||
|
||||
# get conversation
|
||||
conversation = None
|
||||
|
||||
@@ -99,7 +97,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
|
||||
# validate config
|
||||
override_model_config_dict = CompletionAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id, config=args.get("model_config")
|
||||
tenant_id=app_model.tenant_id, config=args.get("model_config", {})
|
||||
)
|
||||
|
||||
# parse files
|
||||
@@ -132,11 +130,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
files=list(file_objs),
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
extras=extras,
|
||||
extras={},
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
@@ -157,7 +155,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(),
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"message_id": message.id,
|
||||
@@ -197,6 +195,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
try:
|
||||
# get message
|
||||
message = self._get_message(message_id)
|
||||
if message is None:
|
||||
raise MessageNotExistsError()
|
||||
|
||||
# chatbot app
|
||||
runner = CompletionAppRunner()
|
||||
@@ -231,7 +231,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
user: Union[Account, EndUser],
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
) -> Union[dict, Generator[str, None, None]]:
|
||||
) -> Union[Mapping[str, Any], Generator[str, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@@ -293,7 +293,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
model_conf=ModelConfigConverter.convert(app_config),
|
||||
inputs=message.inputs,
|
||||
query=message.query,
|
||||
files=file_objs,
|
||||
files=list(file_objs),
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
@@ -317,7 +317,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(),
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"message_id": message.id,
|
||||
|
||||
@@ -76,7 +76,7 @@ class CompletionAppRunner(AppRunner):
|
||||
tenant_id=app_config.tenant_id,
|
||||
app_generate_entity=application_generate_entity,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
query=query or "",
|
||||
message_id=message.id,
|
||||
)
|
||||
except ModerationError as e:
|
||||
@@ -122,7 +122,7 @@ class CompletionAppRunner(AppRunner):
|
||||
tenant_id=app_record.tenant_id,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
config=dataset_config,
|
||||
query=query,
|
||||
query=query or "",
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
show_retrieve_source=app_config.additional_features.show_retrieve_source,
|
||||
hit_callback=hit_callback,
|
||||
|
||||
@@ -16,7 +16,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = CompletionAppBlockingResponse
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict:
|
||||
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
@@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict:
|
||||
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
@@ -51,7 +51,8 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
|
||||
cls,
|
||||
stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override]
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@@ -81,7 +82,8 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
|
||||
cls,
|
||||
stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override]
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
||||
@@ -2,11 +2,11 @@ import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from sqlalchemy import and_
|
||||
|
||||
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
|
||||
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
@@ -42,7 +42,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AdvancedChatAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
],
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
@@ -144,7 +144,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
:conversation conversation
|
||||
:return:
|
||||
"""
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config: EasyUIBasedAppConfig = cast(EasyUIBasedAppConfig, application_generate_entity.app_config)
|
||||
|
||||
# get from source
|
||||
end_user_id = None
|
||||
@@ -267,7 +267,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
return introduction
|
||||
return introduction or ""
|
||||
|
||||
def _get_conversation(self, conversation_id: str):
|
||||
"""
|
||||
@@ -282,7 +282,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
|
||||
return conversation
|
||||
|
||||
def _get_message(self, message_id: str) -> Message:
|
||||
def _get_message(self, message_id: str) -> Optional[Message]:
|
||||
"""
|
||||
Get message by message id
|
||||
:param message_id: message id
|
||||
|
||||
@@ -116,7 +116,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
inputs=self._prepare_user_inputs(
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
files=system_files,
|
||||
files=list(system_files),
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
|
||||
@@ -17,16 +17,16 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = WorkflowAppBlockingResponse
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
:return:
|
||||
"""
|
||||
return blocking_response.to_dict()
|
||||
return dict(blocking_response.to_dict())
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
@@ -36,7 +36,8 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
|
||||
cls,
|
||||
stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override]
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@@ -65,7 +66,8 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
|
||||
cls,
|
||||
stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override]
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
||||
@@ -24,6 +24,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
@@ -190,16 +191,15 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
|
||||
elif isinstance(event, NodeRunRetryEvent):
|
||||
node_run_result = event.route_node_state.node_run_result
|
||||
inputs: Mapping[str, Any] | None = {}
|
||||
process_data: Mapping[str, Any] | None = {}
|
||||
outputs: Mapping[str, Any] | None = {}
|
||||
execution_metadata: Mapping[NodeRunMetadataKey, Any] | None = {}
|
||||
if node_run_result:
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
outputs = node_run_result.outputs
|
||||
execution_metadata = node_run_result.metadata
|
||||
else:
|
||||
inputs = {}
|
||||
process_data = {}
|
||||
outputs = {}
|
||||
execution_metadata = {}
|
||||
self._publish_event(
|
||||
QueueNodeRetryEvent(
|
||||
node_execution_id=event.id,
|
||||
@@ -289,7 +289,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
outputs=event.route_node_state.node_run_result.outputs or {}
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
error=event.route_node_state.node_run_result.error
|
||||
@@ -349,7 +349,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
outputs=event.route_node_state.node_run_result.outputs or {}
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
|
||||
Reference in New Issue
Block a user