feat: mypy for all type check (#10921)

Author: yihong
Date: 2024-12-24 18:38:51 +08:00
Committed by: GitHub
Parent: c91e8b1737
Commit: 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -29,6 +29,7 @@ from factories import file_factory
from models.account import Account
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
@@ -145,7 +146,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=file_objs,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=streaming,
@@ -313,6 +314,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError("Message not exists")
# chatbot app
runner = AdvancedChatAppRunner(
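
Note on the None check above: a later hunk in this commit changes `_get_message` to return `Optional[Message]`, so mypy now requires every caller to rule out None before using the result. A minimal, self-contained sketch of that narrowing, with stand-ins for the Message model and the imported error:

    from typing import Optional

    class Message:  # stand-in for models.model.Message
        pass

    class MessageNotExistsError(Exception):  # stand-in for services.errors.message
        pass

    def get_message(message_id: str) -> Optional[Message]:
        return None  # a lookup that may legitimately find nothing, e.g. .first()

    def generate(message_id: str) -> Message:
        message = get_message(message_id)
        if message is None:
            # Raising narrows the type: below this branch mypy treats
            # `message` as Message, not Optional[Message].
            raise MessageNotExistsError("Message not exists")
        return message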

View File

@@ -5,6 +5,7 @@ import queue
import re
import threading
from collections.abc import Iterable
from typing import Optional
from core.app.entities.queue_entities import (
MessageQueueMessage,
@@ -15,6 +16,7 @@ from core.app.entities.queue_entities import (
WorkflowQueueMessage,
)
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import TextPromptMessageContent
from core.model_runtime.entities.model_entities import ModelType
@@ -71,8 +73,9 @@ class AppGeneratorTTSPublisher:
if not voice or voice not in values:
self.voice = self.voices[0].get("value")
self.MAX_SENTENCE = 2
self._last_audio_event = None
self._runtime_thread = threading.Thread(target=self._runtime).start()
self._last_audio_event: Optional[AudioTrunk] = None
# FIXME better way to handle this threading.start
threading.Thread(target=self._runtime).start()
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):
@@ -92,10 +95,21 @@ class AppGeneratorTTSPublisher:
future_queue.put(futures_result)
break
elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
self.msg_text += message.event.chunk.delta.message.content
message_content = message.event.chunk.delta.message.content
if not message_content:
continue
if isinstance(message_content, str):
self.msg_text += message_content
elif isinstance(message_content, list):
for content in message_content:
if not isinstance(content, TextPromptMessageContent):
continue
self.msg_text += content.data
elif isinstance(message.event, QueueTextChunkEvent):
self.msg_text += message.event.text
elif isinstance(message.event, QueueNodeSucceededEvent):
if message.event.outputs is None:
continue
self.msg_text += message.event.outputs.get("output", "")
self.last_message = message
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
@@ -121,11 +135,10 @@ class AppGeneratorTTSPublisher:
if self._last_audio_event and self._last_audio_event.status == "finish":
if self.executor:
self.executor.shutdown(wait=False)
return self.last_message
return self._last_audio_event
audio = self._audio_queue.get_nowait()
if audio and audio.status == "finish":
self.executor.shutdown(wait=False)
self._runtime_thread = None
if audio:
self._last_audio_event = audio
return audio
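
The chunk-handling change above is union narrowing: message content can be a plain string or a list of content parts, so each shape is handled behind an isinstance check. A simplified sketch, assuming a TextPromptMessageContent-like class whose text lives in a `data` field:

    from dataclasses import dataclass

    @dataclass
    class TextPromptMessageContent:  # stand-in for the runtime entity
        data: str

    def append_text(msg_text: str, content: str | list[object]) -> str:
        """Accumulate only the textual parts, mirroring the publisher loop."""
        if not content:
            return msg_text
        if isinstance(content, str):
            return msg_text + content  # `content` is narrowed to str here
        for part in content:  # `content` is narrowed to list[object] here
            if isinstance(part, TextPromptMessageContent):
                msg_text += part.data
        return msg_text

    print(append_text("", [TextPromptMessageContent(data="hello")]))  # hello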

View File

@@ -109,18 +109,18 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
ConversationVariable.conversation_id == self.conversation.id,
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
db_conversation_variables = session.scalars(stmt).all()
if not db_conversation_variables:
# Create conversation variables if they don't exist.
conversation_variables = [
db_conversation_variables = [
ConversationVariable.from_variable(
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
session.add_all(db_conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in conversation_variables]
conversation_variables = [item.to_variable() for item in db_conversation_variables]
session.commit()
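
The rename here is not cosmetic: mypy infers one type per local, so reusing `conversation_variables` first for ORM rows and then for domain variables is flagged as an incompatible assignment. Keeping one name per type is the usual fix. A reduced sketch with hypothetical Row/Variable stand-ins:

    class Variable:  # stand-in for the domain-level variable
        pass

    class Row:  # stand-in for the ConversationVariable ORM entity
        def to_variable(self) -> Variable:
            return Variable()

    def load(db_rows: list[Row]) -> list[Variable]:
        # Rebinding db_rows to the converted list would make mypy report an
        # incompatible assignment (list[Variable] vs list[Row]); a second,
        # separately typed name avoids that.
        variables: list[Variable] = [row.to_variable() for row in db_rows]
        return variables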

View File

@@ -2,6 +2,7 @@ import json
import logging
import time
from collections.abc import Generator, Mapping
from threading import Thread
from typing import Any, Optional, Union
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@@ -64,6 +65,7 @@ from models.enums import CreatedByRole
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowRun,
WorkflowRunStatus,
)
@@ -81,6 +83,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_user: Union[Account, EndUser]
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
_conversation_name_generate_thread: Optional[Thread] = None
def __init__(
self,
@@ -131,7 +134,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._conversation_name_generate_thread = None
self._recorded_files: list[Mapping[str, Any]] = []
def process(self):
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
Process generate task pipeline.
:return:
@@ -262,8 +265,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:return:
"""
# init fake graph runtime state
graph_runtime_state = None
workflow_run = None
graph_runtime_state: Optional[GraphRuntimeState] = None
workflow_run: Optional[WorkflowRun] = None
for queue_message in self._queue_manager.listen():
event = queue_message.event
@@ -315,14 +318,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
response = self._workflow_node_start_to_stream_response(
response_start = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
if response_start:
yield response_start
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
@@ -330,18 +333,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if event.node_type in [NodeType.ANSWER, NodeType.END]:
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
response = self._workflow_node_finish_to_stream_response(
response_finish = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
if response_finish:
yield response_finish
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
response = self._workflow_node_finish_to_stream_response(
response_finish = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@@ -609,7 +612,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
del extras["metadata"]["annotation_reply"]
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras
task_id=self._application_generate_entity.task_id,
id=self._message.id,
files=self._recorded_files,
metadata=extras.get("metadata", {}),
)
def _handle_output_moderation_chunk(self, text: str) -> bool:
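
Replacing `**extras` with explicit keywords is what lets mypy check this call at all: unpacking an untyped dict hides the argument names, while `metadata=extras.get("metadata", {})` is matched against the response model's signature. Roughly, with a simple stand-in constructor:

    from typing import Any

    def make_response(task_id: str, id: str, files: list, metadata: dict[str, Any]) -> dict:
        # Stand-in for MessageEndStreamResponse(...).
        return {"task_id": task_id, "id": id, "files": files, "metadata": metadata}

    extras: dict[str, Any] = {"metadata": {"usage": {}}}

    # make_response(task_id=..., **extras) would type-check blindly: an
    # unexpected or missing key only surfaces at runtime. Explicit keywords
    # are verified statically.
    resp = make_response(
        task_id="t1",
        id="m1",
        files=[],
        metadata=extras.get("metadata", {}),
    )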

View File

@@ -61,7 +61,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
else:
config_dict = override_config_dict
config_dict = override_config_dict or {}
app_mode = AppMode.value_of(app_model.mode)
app_config = AgentChatAppConfig(
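
`override_config_dict` is Optional on this branch, so assigning it through unchanged would let None flow into later lookups; `or {}` collapses the None case into an empty dict. A two-line illustration:

    from typing import Any, Optional

    def resolve_config(override: Optional[dict[str, Any]]) -> dict[str, Any]:
        # `override or {}` is typed dict[str, Any]; the None branch is gone.
        return override or {}

    assert resolve_config(None) == {}
    assert resolve_config({"model": "x"}) == {"model": "x"}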

View File

@@ -23,6 +23,7 @@ from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
from models import Account, App, EndUser
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
@@ -97,7 +98,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
# get conversation
conversation = None
if args.get("conversation_id"):
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user)
# get app model config
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
@@ -153,7 +154,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=file_objs,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=streaming,
@@ -180,7 +181,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
@@ -199,8 +200,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
user=user,
stream=streaming,
)
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
# FIXME: Type hinting issue here, ignore it for now, will fix it later
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) # type: ignore
def _generate_worker(
self,
@@ -224,6 +225,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError("Message not exists")
# chatbot app
runner = AgentChatAppRunner()
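
The `.get("conversation_id", "")` change is about expression-level narrowing: each `.get()` call is a fresh expression, so the earlier `if args.get("conversation_id"):` does not narrow the one inside the call, and a bare `.get()` stays Optional. Supplying a default pins the type. Sketch, assuming a str-valued args mapping:

    def get_conversation_by_user(conversation_id: str) -> str:
        return conversation_id  # stand-in for the real lookup

    args: dict[str, str] = {"conversation_id": "abc"}

    if args.get("conversation_id"):
        # args.get("conversation_id") alone is Optional[str]; the
        # two-argument form has type str and satisfies the parameter.
        conversation = get_conversation_by_user(args.get("conversation_id", ""))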

View File

@@ -173,6 +173,8 @@ class AgentChatAppRunner(AppRunner):
return
agent_entity = app_config.agent
if not agent_entity:
raise ValueError("Agent entity not found")
# load tool variables
tool_conversation_variables = self._load_tool_variables(
@@ -200,14 +202,21 @@ class AgentChatAppRunner(AppRunner):
# change function call strategy based on LLM model
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if not model_schema or not model_schema.features:
raise ValueError("Model schema not found")
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
message = db.session.query(Message).filter(Message.id == message.id).first()
conversation_result = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
if conversation_result is None:
raise ValueError("Conversation not found")
message_result = db.session.query(Message).filter(Message.id == message.id).first()
if message_result is None:
raise ValueError("Message not found")
db.session.close()
runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
# check LLM mode
@@ -225,12 +234,12 @@ class AgentChatAppRunner(AppRunner):
runner = runner_cls(
tenant_id=app_config.tenant_id,
application_generate_entity=application_generate_entity,
conversation=conversation,
conversation=conversation_result,
app_config=app_config,
model_config=application_generate_entity.model_conf,
config=agent_entity,
queue_manager=queue_manager,
message=message,
message=message_result,
user_id=application_generate_entity.user_id,
memory=memory,
prompt_messages=prompt_message,
@@ -257,7 +266,7 @@ class AgentChatAppRunner(AppRunner):
"""
load tool variables from database
"""
tool_variables: ToolConversationVariables = (
tool_variables: ToolConversationVariables | None = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == conversation_id,
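
Two recurring patterns in this hunk: SQLAlchemy's `.first()` returns `Optional[T]`, so each result is guarded before use, and `runner_cls` is annotated with its full union of runner classes before the strategy branches assign to it. In miniature, with stand-in runner classes:

    from typing import Optional

    class CotRunner: ...
    class FnCallRunner: ...

    def first(found: bool) -> Optional[str]:
        """Stand-in for db.session.query(...).first()."""
        return "row" if found else None

    conversation_result = first(True)
    if conversation_result is None:
        raise ValueError("Conversation not found")

    # Declare the union up front; each branch assignment must fit it.
    runner_cls: type[CotRunner] | type[FnCallRunner]
    if len(conversation_result) > 2:
        runner_cls = CotRunner
    else:
        runner_cls = FnCallRunner
    runner = runner_cls()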

View File

@@ -16,7 +16,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -51,8 +51,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
def convert_stream_full_response( # type: ignore[override]
cls,
stream_response: Generator[ChatbotAppStreamResponse, None, None],
) -> Generator[str, None, None]:
"""
Convert stream full response.
@@ -82,8 +83,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk)
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
def convert_stream_simple_response( # type: ignore[override]
cls,
stream_response: Generator[ChatbotAppStreamResponse, None, None],
) -> Generator[str, None, None]:
"""
Convert stream simple response.
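
The `# type: ignore[override]` markers acknowledge a Liskov violation rather than fix it: the base converter accepts a broader response type, and each subclass narrows the parameter, which mypy reports as an unsafe override. A condensed reproduction with hypothetical response classes:

    class AppBlockingResponse: ...
    class ChatbotAppBlockingResponse(AppBlockingResponse): ...

    class AppGenerateResponseConverter:
        @classmethod
        def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict:
            return {}

    class AgentChatConverter(AppGenerateResponseConverter):
        @classmethod
        def convert_blocking_full_response(  # type: ignore[override]
            cls,
            blocking_response: ChatbotAppBlockingResponse,  # narrower than base
        ) -> dict:
            # Narrowing a parameter breaks substitutability: a caller holding
            # the base class may pass any AppBlockingResponse. The ignore
            # suppresses exactly that one error code.
            return {}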

View File

@@ -50,7 +50,7 @@ class AppQueueManager:
# wait for APP_MAX_EXECUTION_TIME seconds to stop listen
listen_timeout = dify_config.APP_MAX_EXECUTION_TIME
start_time = time.time()
last_ping_time = 0
last_ping_time: int | float = 0
while True:
try:
message = self._q.get(timeout=1)
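
`last_ping_time` starts at the int literal 0 but is later reassigned from `time.time()`, which returns float; without the annotation, mypy infers int from the initializer and rejects the reassignment:

    import time

    last_ping_time: int | float = 0  # inferred as plain int without this

    def ping() -> None:
        global last_ping_time
        last_ping_time = time.time()  # float assignment is now allowed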

View File

@@ -1,5 +1,5 @@
import time
from collections.abc import Generator, Mapping
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
@@ -36,8 +36,8 @@ class AppRunner:
app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["File"],
inputs: Mapping[str, str],
files: Sequence["File"],
query: Optional[str] = None,
) -> int:
"""
@@ -64,7 +64,7 @@ class AppRunner:
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
if model_context_tokens is None:
@@ -85,7 +85,7 @@ class AppRunner:
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
raise InvokeBadRequestError(
"Query or prefix prompt is too long, you can reduce the prefix prompt, "
@@ -111,7 +111,7 @@ class AppRunner:
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
if model_context_tokens is None:
@@ -136,8 +136,8 @@ class AppRunner:
app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["File"],
inputs: Mapping[str, str],
files: Sequence["File"],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None,
@@ -156,6 +156,7 @@ class AppRunner:
"""
# get prompt without memory and context
if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
prompt_transform: Union[SimplePromptTransform, AdvancedPromptTransform]
prompt_transform = SimplePromptTransform()
prompt_messages, stop = prompt_transform.get_prompt(
app_mode=AppMode.value_of(app_record.mode),
@@ -171,8 +172,11 @@ class AppRunner:
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
model_mode = ModelMode.value_of(model_config.mode)
prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
if model_mode == ModelMode.COMPLETION:
advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
if not advanced_completion_prompt_template:
raise InvokeBadRequestError("Advanced completion prompt template is required.")
prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt)
if advanced_completion_prompt_template.role_prefix:
@@ -181,6 +185,8 @@ class AppRunner:
assistant=advanced_completion_prompt_template.role_prefix.assistant,
)
else:
if not prompt_template_entity.advanced_chat_prompt_template:
raise InvokeBadRequestError("Advanced chat prompt template is required.")
prompt_template = []
for message in prompt_template_entity.advanced_chat_prompt_template.messages:
prompt_template.append(ChatModelMessage(text=message.text, role=message.role))
@@ -246,7 +252,7 @@ class AppRunner:
def _handle_invoke_result(
self,
invoke_result: Union[LLMResult, Generator],
invoke_result: Union[LLMResult, Generator[Any, None, None]],
queue_manager: AppQueueManager,
stream: bool,
agent: bool = False,
@@ -259,10 +265,12 @@ class AppRunner:
:param agent: agent
:return:
"""
if not stream:
if not stream and isinstance(invoke_result, LLMResult):
self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
else:
elif stream and isinstance(invoke_result, Generator):
self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
else:
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
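
Dispatching on the `stream` flag alone leaves the union parameter uninspected; pairing the flag with an isinstance check narrows `invoke_result` in each branch and turns the impossible combination into an explicit error. Sketch with a trivial LLMResult stand-in:

    from collections.abc import Generator
    from typing import Union

    class LLMResult:  # stand-in for the real result entity
        pass

    def handle(invoke_result: Union[LLMResult, Generator[str, None, None]], stream: bool) -> None:
        if not stream and isinstance(invoke_result, LLMResult):
            print("blocking result")  # invoke_result: LLMResult here
        elif stream and isinstance(invoke_result, Generator):
            for chunk in invoke_result:  # invoke_result: Generator here
                print("chunk:", chunk)
        else:
            raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")

    handle(LLMResult(), stream=False)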
def _handle_invoke_result_direct(
self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
@@ -291,8 +299,8 @@ class AppRunner:
:param agent: agent
:return:
"""
model = None
prompt_messages = []
model: str = ""
prompt_messages: list[PromptMessage] = []
text = ""
usage = None
for result in invoke_result:
@@ -328,13 +336,14 @@ class AppRunner:
def moderation_for_inputs(
self,
*,
app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
query: str | None = None,
message_id: str,
) -> tuple[bool, dict, str]:
) -> tuple[bool, Mapping[str, Any], str]:
"""
Process sensitive_word_avoidance.
:param app_id: app id
@@ -350,7 +359,7 @@ class AppRunner:
app_id=app_id,
tenant_id=tenant_id,
app_config=app_generate_entity.app_config,
inputs=inputs,
inputs=dict(inputs),
query=query or "",
message_id=message_id,
trace_manager=app_generate_entity.trace_manager,
@@ -390,9 +399,9 @@ class AppRunner:
tenant_id: str,
app_id: str,
external_data_tools: list[ExternalDataVariableEntity],
inputs: dict,
inputs: Mapping[str, Any],
query: str,
) -> dict:
) -> Mapping[str, Any]:
"""
Fill in variable inputs from external data tools if exists.
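
Widening parameters from `dict`/`list` to `Mapping`/`Sequence` is this commit's most common signature change: the abstract types are read-only (and covariant in their element types), so callers may pass dicts, tuples, or lists without copies, while the body cannot mutate its arguments. For example:

    from collections.abc import Mapping, Sequence

    def organize_prompt(inputs: Mapping[str, str], files: Sequence[str]) -> int:
        # Mapping/Sequence expose only read operations; mypy would flag
        # inputs["k"] = "v" or files.append(...) inside this body.
        return len(inputs) + len(files)

    organize_prompt({"query": "hi"}, ["a.png", "b.png"])
    organize_prompt({"query": "hi"}, ("a.png", "b.png"))  # a list[str] parameter would reject the tuple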

View File

@@ -24,6 +24,7 @@ from extensions.ext_database import db
from factories import file_factory
from models.account import Account
from models.model import App, EndUser
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
@@ -91,7 +92,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# get conversation
conversation = None
if args.get("conversation_id"):
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user)
# get app model config
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
@@ -104,7 +105,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# validate config
override_model_config_dict = ChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=args.get("model_config")
tenant_id=app_model.tenant_id, config=args.get("model_config", {})
)
# always enable retriever resource in debugger mode
@@ -146,7 +147,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=file_objs,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
invoke_from=invoke_from,
@@ -172,7 +173,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
@@ -216,6 +217,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError("Message not exists")
# chatbot app
runner = ChatAppRunner()

View File

@@ -16,7 +16,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -52,7 +52,8 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
cls,
stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream full response.
@@ -83,7 +84,8 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
cls,
stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream simple response.

View File

@@ -42,7 +42,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
else:
config_dict = override_config_dict
config_dict = override_config_dict or {}
app_mode = AppMode.value_of(app_model.mode)
app_config = CompletionAppConfig(

View File

@@ -83,8 +83,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
query = query.replace("\x00", "")
inputs = args["inputs"]
extras = {}
# get conversation
conversation = None
@@ -99,7 +97,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
# validate config
override_model_config_dict = CompletionAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=args.get("model_config")
tenant_id=app_model.tenant_id, config=args.get("model_config", {})
)
# parse files
@@ -132,11 +130,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=file_objs,
files=list(file_objs),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
extras=extras,
extras={},
trace_manager=trace_manager,
)
@@ -157,7 +155,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"message_id": message.id,
@@ -197,6 +195,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
try:
# get message
message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError()
# chatbot app
runner = CompletionAppRunner()
@@ -231,7 +231,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[str, None, None]]:
) -> Union[Mapping[str, Any], Generator[str, None, None]]:
"""
Generate App response.
@@ -293,7 +293,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
model_conf=ModelConfigConverter.convert(app_config),
inputs=message.inputs,
query=message.query,
files=file_objs,
files=list(file_objs),
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
@@ -317,7 +317,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"message_id": message.id,

View File

@@ -76,7 +76,7 @@ class CompletionAppRunner(AppRunner):
tenant_id=app_config.tenant_id,
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
query=query or "",
message_id=message.id,
)
except ModerationError as e:
@@ -122,7 +122,7 @@ class CompletionAppRunner(AppRunner):
tenant_id=app_record.tenant_id,
model_config=application_generate_entity.model_conf,
config=dataset_config,
query=query,
query=query or "",
invoke_from=application_generate_entity.invoke_from,
show_retrieve_source=app_config.additional_features.show_retrieve_source,
hit_callback=hit_callback,
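
`query` is Optional on this path while moderation and dataset retrieval expect plain str, hence `query or ""`. Note the idiom also maps an empty string to an empty string, so it only suits call sites where "" and None should behave alike:

    from typing import Optional

    def moderate(query: str) -> str:
        return query.strip()  # requires str, not Optional[str]

    def run(query: Optional[str]) -> str:
        return moderate(query or "")  # `query or ""` is typed str

    assert run(None) == ""
    assert run(" hello ") == "hello"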

View File

@@ -16,7 +16,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = CompletionAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict:
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict:
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -51,7 +51,8 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
cls,
stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream full response.
@@ -81,7 +82,8 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
cls,
stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream simple response.

View File

@@ -2,11 +2,11 @@ import json
import logging
from collections.abc import Generator
from datetime import UTC, datetime
from typing import Optional, Union
from typing import Optional, Union, cast
from sqlalchemy import and_
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import (
@@ -42,7 +42,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity,
],
queue_manager: AppQueueManager,
conversation: Conversation,
@@ -144,7 +144,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:conversation conversation
:return:
"""
app_config = application_generate_entity.app_config
app_config: EasyUIBasedAppConfig = cast(EasyUIBasedAppConfig, application_generate_entity.app_config)
# get from source
end_user_id = None
@@ -267,7 +267,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
except KeyError:
pass
return introduction
return introduction or ""
def _get_conversation(self, conversation_id: str):
"""
@@ -282,7 +282,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
return conversation
def _get_message(self, message_id: str) -> Message:
def _get_message(self, message_id: str) -> Optional[Message]:
"""
Get message by message id
:param message_id: message id
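
The `cast` added above asserts what this code path already guarantees: the generate entity declares its config with the base type, but every easy-UI branch actually carries an EasyUIBasedAppConfig. `cast` is a promise to the checker with no runtime effect, so the invariant must hold elsewhere. Sketch with stand-in config classes:

    from typing import cast

    class AppConfig:
        pass

    class EasyUIBasedAppConfig(AppConfig):
        app_model_config_id: str = "cfg-1"

    def get_config() -> AppConfig:
        return EasyUIBasedAppConfig()  # declared base type, actually narrower

    # cast() only instructs mypy; if get_config() ever returned a bare
    # AppConfig, the attribute access below would still fail at runtime.
    app_config = cast(EasyUIBasedAppConfig, get_config())
    print(app_config.app_model_config_id)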

View File

@@ -116,7 +116,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
files=system_files,
files=list(system_files),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,

View File

@@ -17,16 +17,16 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
return blocking_response.to_dict()
return dict(blocking_response.to_dict())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -36,7 +36,8 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
cls,
stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream full response.
@@ -65,7 +66,8 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
cls,
stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream simple response.

View File

@@ -24,6 +24,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
@@ -190,16 +191,15 @@ class WorkflowBasedAppRunner(AppRunner):
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, NodeRunRetryEvent):
node_run_result = event.route_node_state.node_run_result
inputs: Mapping[str, Any] | None = {}
process_data: Mapping[str, Any] | None = {}
outputs: Mapping[str, Any] | None = {}
execution_metadata: Mapping[NodeRunMetadataKey, Any] | None = {}
if node_run_result:
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
else:
inputs = {}
process_data = {}
outputs = {}
execution_metadata = {}
self._publish_event(
QueueNodeRetryEvent(
node_execution_id=event.id,
@@ -289,7 +289,7 @@ class WorkflowBasedAppRunner(AppRunner):
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
@@ -349,7 +349,7 @@ class WorkflowBasedAppRunner(AppRunner):
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
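
The retry handler above folds its if/else into annotated defaults: each local is declared once as `Mapping[...] | None` and seeded with `{}`, so both the assigned and unassigned paths share one declared type and the else branch disappears. Reduced to essentials:

    from collections.abc import Mapping
    from typing import Any, Optional

    class NodeRunResult:  # stand-in for the route node state's result
        inputs: Mapping[str, Any] = {"q": "hi"}
        outputs: Mapping[str, Any] = {"answer": "ok"}

    def collect(node_run_result: Optional[NodeRunResult]) -> Mapping[str, Any]:
        inputs: Mapping[str, Any] | None = {}
        outputs: Mapping[str, Any] | None = {}
        if node_run_result:
            inputs = node_run_result.inputs
            outputs = node_run_result.outputs
        return {"inputs": inputs, "outputs": outputs}

    print(collect(None))
    print(collect(NodeRunResult()))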