feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -66,6 +66,8 @@ class DatasetConfigManager:
dataset_configs = config.get("dataset_configs")
else:
dataset_configs = {"retrieval_model": "multiple"}
if dataset_configs is None:
return None
query_variable = config.get("dataset_query_variable")
if dataset_configs["retrieval_model"] == "single":

View File

@@ -94,7 +94,7 @@ class ModelConfigManager:
config["model"]["completion_params"]
)
return config, ["model"]
return dict(config), ["model"]
@classmethod
def validate_model_completion_params(cls, cp: dict) -> dict:

View File

@@ -7,10 +7,10 @@ class OpeningStatementConfigManager:
:param config: model config args
"""
# opening statement
opening_statement = config.get("opening_statement")
opening_statement = config.get("opening_statement", "")
# suggested questions
suggested_questions_list = config.get("suggested_questions")
suggested_questions_list = config.get("suggested_questions", [])
return opening_statement, suggested_questions_list

View File

@@ -29,6 +29,7 @@ from factories import file_factory
from models.account import Account
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
@@ -145,7 +146,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=file_objs,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=streaming,
@@ -313,6 +314,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError("Message not exists")
# chatbot app
runner = AdvancedChatAppRunner(

View File

@@ -5,6 +5,7 @@ import queue
import re
import threading
from collections.abc import Iterable
from typing import Optional
from core.app.entities.queue_entities import (
MessageQueueMessage,
@@ -15,6 +16,7 @@ from core.app.entities.queue_entities import (
WorkflowQueueMessage,
)
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import TextPromptMessageContent
from core.model_runtime.entities.model_entities import ModelType
@@ -71,8 +73,9 @@ class AppGeneratorTTSPublisher:
if not voice or voice not in values:
self.voice = self.voices[0].get("value")
self.MAX_SENTENCE = 2
self._last_audio_event = None
self._runtime_thread = threading.Thread(target=self._runtime).start()
self._last_audio_event: Optional[AudioTrunk] = None
# FIXME better way to handle this threading.start
threading.Thread(target=self._runtime).start()
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):
@@ -92,10 +95,21 @@ class AppGeneratorTTSPublisher:
future_queue.put(futures_result)
break
elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
self.msg_text += message.event.chunk.delta.message.content
message_content = message.event.chunk.delta.message.content
if not message_content:
continue
if isinstance(message_content, str):
self.msg_text += message_content
elif isinstance(message_content, list):
for content in message_content:
if not isinstance(content, TextPromptMessageContent):
continue
self.msg_text += content.data
elif isinstance(message.event, QueueTextChunkEvent):
self.msg_text += message.event.text
elif isinstance(message.event, QueueNodeSucceededEvent):
if message.event.outputs is None:
continue
self.msg_text += message.event.outputs.get("output", "")
self.last_message = message
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
@@ -121,11 +135,10 @@ class AppGeneratorTTSPublisher:
if self._last_audio_event and self._last_audio_event.status == "finish":
if self.executor:
self.executor.shutdown(wait=False)
return self.last_message
return self._last_audio_event
audio = self._audio_queue.get_nowait()
if audio and audio.status == "finish":
self.executor.shutdown(wait=False)
self._runtime_thread = None
if audio:
self._last_audio_event = audio
return audio

View File

@@ -109,18 +109,18 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
ConversationVariable.conversation_id == self.conversation.id,
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
db_conversation_variables = session.scalars(stmt).all()
if not db_conversation_variables:
# Create conversation variables if they don't exist.
conversation_variables = [
db_conversation_variables = [
ConversationVariable.from_variable(
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
session.add_all(db_conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in conversation_variables]
conversation_variables = [item.to_variable() for item in db_conversation_variables]
session.commit()

View File

@@ -2,6 +2,7 @@ import json
import logging
import time
from collections.abc import Generator, Mapping
from threading import Thread
from typing import Any, Optional, Union
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@@ -64,6 +65,7 @@ from models.enums import CreatedByRole
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowRun,
WorkflowRunStatus,
)
@@ -81,6 +83,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_user: Union[Account, EndUser]
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
_conversation_name_generate_thread: Optional[Thread] = None
def __init__(
self,
@@ -131,7 +134,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._conversation_name_generate_thread = None
self._recorded_files: list[Mapping[str, Any]] = []
def process(self):
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
Process generate task pipeline.
:return:
@@ -262,8 +265,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:return:
"""
# init fake graph runtime state
graph_runtime_state = None
workflow_run = None
graph_runtime_state: Optional[GraphRuntimeState] = None
workflow_run: Optional[WorkflowRun] = None
for queue_message in self._queue_manager.listen():
event = queue_message.event
@@ -315,14 +318,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
response = self._workflow_node_start_to_stream_response(
response_start = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
if response_start:
yield response_start
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
@@ -330,18 +333,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if event.node_type in [NodeType.ANSWER, NodeType.END]:
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
response = self._workflow_node_finish_to_stream_response(
response_finish = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
if response_finish:
yield response_finish
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
response = self._workflow_node_finish_to_stream_response(
response_finish = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@@ -609,7 +612,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
del extras["metadata"]["annotation_reply"]
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras
task_id=self._application_generate_entity.task_id,
id=self._message.id,
files=self._recorded_files,
metadata=extras.get("metadata", {}),
)
def _handle_output_moderation_chunk(self, text: str) -> bool:

View File

@@ -61,7 +61,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
else:
config_dict = override_config_dict
config_dict = override_config_dict or {}
app_mode = AppMode.value_of(app_model.mode)
app_config = AgentChatAppConfig(

View File

@@ -23,6 +23,7 @@ from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
from models import Account, App, EndUser
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
@@ -97,7 +98,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
# get conversation
conversation = None
if args.get("conversation_id"):
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user)
# get app model config
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
@@ -153,7 +154,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=file_objs,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=streaming,
@@ -180,7 +181,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
@@ -199,8 +200,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
user=user,
stream=streaming,
)
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
# FIXME: Type hinting issue here, ignore it for now, will fix it later
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) # type: ignore
def _generate_worker(
self,
@@ -224,6 +225,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError("Message not exists")
# chatbot app
runner = AgentChatAppRunner()

View File

@@ -173,6 +173,8 @@ class AgentChatAppRunner(AppRunner):
return
agent_entity = app_config.agent
if not agent_entity:
raise ValueError("Agent entity not found")
# load tool variables
tool_conversation_variables = self._load_tool_variables(
@@ -200,14 +202,21 @@ class AgentChatAppRunner(AppRunner):
# change function call strategy based on LLM model
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if not model_schema or not model_schema.features:
raise ValueError("Model schema not found")
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
message = db.session.query(Message).filter(Message.id == message.id).first()
conversation_result = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
if conversation_result is None:
raise ValueError("Conversation not found")
message_result = db.session.query(Message).filter(Message.id == message.id).first()
if message_result is None:
raise ValueError("Message not found")
db.session.close()
runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
# check LLM mode
@@ -225,12 +234,12 @@ class AgentChatAppRunner(AppRunner):
runner = runner_cls(
tenant_id=app_config.tenant_id,
application_generate_entity=application_generate_entity,
conversation=conversation,
conversation=conversation_result,
app_config=app_config,
model_config=application_generate_entity.model_conf,
config=agent_entity,
queue_manager=queue_manager,
message=message,
message=message_result,
user_id=application_generate_entity.user_id,
memory=memory,
prompt_messages=prompt_message,
@@ -257,7 +266,7 @@ class AgentChatAppRunner(AppRunner):
"""
load tool variables from database
"""
tool_variables: ToolConversationVariables = (
tool_variables: ToolConversationVariables | None = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == conversation_id,

View File

@@ -16,7 +16,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -51,8 +51,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
def convert_stream_full_response( # type: ignore[override]
cls,
stream_response: Generator[ChatbotAppStreamResponse, None, None],
) -> Generator[str, None, None]:
"""
Convert stream full response.
@@ -82,8 +83,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk)
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
def convert_stream_simple_response( # type: ignore[override]
cls,
stream_response: Generator[ChatbotAppStreamResponse, None, None],
) -> Generator[str, None, None]:
"""
Convert stream simple response.

View File

@@ -50,7 +50,7 @@ class AppQueueManager:
# wait for APP_MAX_EXECUTION_TIME seconds to stop listen
listen_timeout = dify_config.APP_MAX_EXECUTION_TIME
start_time = time.time()
last_ping_time = 0
last_ping_time: int | float = 0
while True:
try:
message = self._q.get(timeout=1)

View File

@@ -1,5 +1,5 @@
import time
from collections.abc import Generator, Mapping
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
@@ -36,8 +36,8 @@ class AppRunner:
app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["File"],
inputs: Mapping[str, str],
files: Sequence["File"],
query: Optional[str] = None,
) -> int:
"""
@@ -64,7 +64,7 @@ class AppRunner:
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
if model_context_tokens is None:
@@ -85,7 +85,7 @@ class AppRunner:
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
raise InvokeBadRequestError(
"Query or prefix prompt is too long, you can reduce the prefix prompt, "
@@ -111,7 +111,7 @@ class AppRunner:
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
if model_context_tokens is None:
@@ -136,8 +136,8 @@ class AppRunner:
app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["File"],
inputs: Mapping[str, str],
files: Sequence["File"],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None,
@@ -156,6 +156,7 @@ class AppRunner:
"""
# get prompt without memory and context
if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
prompt_transform: Union[SimplePromptTransform, AdvancedPromptTransform]
prompt_transform = SimplePromptTransform()
prompt_messages, stop = prompt_transform.get_prompt(
app_mode=AppMode.value_of(app_record.mode),
@@ -171,8 +172,11 @@ class AppRunner:
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
model_mode = ModelMode.value_of(model_config.mode)
prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
if model_mode == ModelMode.COMPLETION:
advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
if not advanced_completion_prompt_template:
raise InvokeBadRequestError("Advanced completion prompt template is required.")
prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt)
if advanced_completion_prompt_template.role_prefix:
@@ -181,6 +185,8 @@ class AppRunner:
assistant=advanced_completion_prompt_template.role_prefix.assistant,
)
else:
if not prompt_template_entity.advanced_chat_prompt_template:
raise InvokeBadRequestError("Advanced chat prompt template is required.")
prompt_template = []
for message in prompt_template_entity.advanced_chat_prompt_template.messages:
prompt_template.append(ChatModelMessage(text=message.text, role=message.role))
@@ -246,7 +252,7 @@ class AppRunner:
def _handle_invoke_result(
self,
invoke_result: Union[LLMResult, Generator],
invoke_result: Union[LLMResult, Generator[Any, None, None]],
queue_manager: AppQueueManager,
stream: bool,
agent: bool = False,
@@ -259,10 +265,12 @@ class AppRunner:
:param agent: agent
:return:
"""
if not stream:
if not stream and isinstance(invoke_result, LLMResult):
self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
else:
elif stream and isinstance(invoke_result, Generator):
self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
else:
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
def _handle_invoke_result_direct(
self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
@@ -291,8 +299,8 @@ class AppRunner:
:param agent: agent
:return:
"""
model = None
prompt_messages = []
model: str = ""
prompt_messages: list[PromptMessage] = []
text = ""
usage = None
for result in invoke_result:
@@ -328,13 +336,14 @@ class AppRunner:
def moderation_for_inputs(
self,
*,
app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
query: str | None = None,
message_id: str,
) -> tuple[bool, dict, str]:
) -> tuple[bool, Mapping[str, Any], str]:
"""
Process sensitive_word_avoidance.
:param app_id: app id
@@ -350,7 +359,7 @@ class AppRunner:
app_id=app_id,
tenant_id=tenant_id,
app_config=app_generate_entity.app_config,
inputs=inputs,
inputs=dict(inputs),
query=query or "",
message_id=message_id,
trace_manager=app_generate_entity.trace_manager,
@@ -390,9 +399,9 @@ class AppRunner:
tenant_id: str,
app_id: str,
external_data_tools: list[ExternalDataVariableEntity],
inputs: dict,
inputs: Mapping[str, Any],
query: str,
) -> dict:
) -> Mapping[str, Any]:
"""
Fill in variable inputs from external data tools if exists.

View File

@@ -24,6 +24,7 @@ from extensions.ext_database import db
from factories import file_factory
from models.account import Account
from models.model import App, EndUser
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
@@ -91,7 +92,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# get conversation
conversation = None
if args.get("conversation_id"):
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user)
# get app model config
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
@@ -104,7 +105,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# validate config
override_model_config_dict = ChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=args.get("model_config")
tenant_id=app_model.tenant_id, config=args.get("model_config", {})
)
# always enable retriever resource in debugger mode
@@ -146,7 +147,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=file_objs,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
invoke_from=invoke_from,
@@ -172,7 +173,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
@@ -216,6 +217,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError("Message not exists")
# chatbot app
runner = ChatAppRunner()

View File

@@ -16,7 +16,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -52,7 +52,8 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
cls,
stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream full response.
@@ -83,7 +84,8 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
cls,
stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream simple response.

View File

@@ -42,7 +42,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
else:
config_dict = override_config_dict
config_dict = override_config_dict or {}
app_mode = AppMode.value_of(app_model.mode)
app_config = CompletionAppConfig(

View File

@@ -83,8 +83,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
query = query.replace("\x00", "")
inputs = args["inputs"]
extras = {}
# get conversation
conversation = None
@@ -99,7 +97,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
# validate config
override_model_config_dict = CompletionAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=args.get("model_config")
tenant_id=app_model.tenant_id, config=args.get("model_config", {})
)
# parse files
@@ -132,11 +130,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=file_objs,
files=list(file_objs),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
extras=extras,
extras={},
trace_manager=trace_manager,
)
@@ -157,7 +155,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"message_id": message.id,
@@ -197,6 +195,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
try:
# get message
message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError()
# chatbot app
runner = CompletionAppRunner()
@@ -231,7 +231,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[str, None, None]]:
) -> Union[Mapping[str, Any], Generator[str, None, None]]:
"""
Generate App response.
@@ -293,7 +293,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
model_conf=ModelConfigConverter.convert(app_config),
inputs=message.inputs,
query=message.query,
files=file_objs,
files=list(file_objs),
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
@@ -317,7 +317,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"message_id": message.id,

View File

@@ -76,7 +76,7 @@ class CompletionAppRunner(AppRunner):
tenant_id=app_config.tenant_id,
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
query=query or "",
message_id=message.id,
)
except ModerationError as e:
@@ -122,7 +122,7 @@ class CompletionAppRunner(AppRunner):
tenant_id=app_record.tenant_id,
model_config=application_generate_entity.model_conf,
config=dataset_config,
query=query,
query=query or "",
invoke_from=application_generate_entity.invoke_from,
show_retrieve_source=app_config.additional_features.show_retrieve_source,
hit_callback=hit_callback,

View File

@@ -16,7 +16,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = CompletionAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict:
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict:
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -51,7 +51,8 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
cls,
stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream full response.
@@ -81,7 +82,8 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
cls,
stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream simple response.

View File

@@ -2,11 +2,11 @@ import json
import logging
from collections.abc import Generator
from datetime import UTC, datetime
from typing import Optional, Union
from typing import Optional, Union, cast
from sqlalchemy import and_
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import (
@@ -42,7 +42,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity,
],
queue_manager: AppQueueManager,
conversation: Conversation,
@@ -144,7 +144,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:conversation conversation
:return:
"""
app_config = application_generate_entity.app_config
app_config: EasyUIBasedAppConfig = cast(EasyUIBasedAppConfig, application_generate_entity.app_config)
# get from source
end_user_id = None
@@ -267,7 +267,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
except KeyError:
pass
return introduction
return introduction or ""
def _get_conversation(self, conversation_id: str):
"""
@@ -282,7 +282,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
return conversation
def _get_message(self, message_id: str) -> Message:
def _get_message(self, message_id: str) -> Optional[Message]:
"""
Get message by message id
:param message_id: message id

View File

@@ -116,7 +116,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
files=system_files,
files=list(system_files),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,

View File

@@ -17,16 +17,16 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
return blocking_response.to_dict()
return dict(blocking_response.to_dict())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -36,7 +36,8 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
cls,
stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream full response.
@@ -65,7 +66,8 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
cls,
stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream simple response.

View File

@@ -24,6 +24,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
@@ -190,16 +191,15 @@ class WorkflowBasedAppRunner(AppRunner):
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, NodeRunRetryEvent):
node_run_result = event.route_node_state.node_run_result
inputs: Mapping[str, Any] | None = {}
process_data: Mapping[str, Any] | None = {}
outputs: Mapping[str, Any] | None = {}
execution_metadata: Mapping[NodeRunMetadataKey, Any] | None = {}
if node_run_result:
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
else:
inputs = {}
process_data = {}
outputs = {}
execution_metadata = {}
self._publish_event(
QueueNodeRetryEvent(
node_execution_id=event.id,
@@ -289,7 +289,7 @@ class WorkflowBasedAppRunner(AppRunner):
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
@@ -349,7 +349,7 @@ class WorkflowBasedAppRunner(AppRunner):
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata

View File

@@ -5,7 +5,7 @@ from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
from constants import UUID_NIL
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
from core.file import File, FileUploadConfig
from core.model_runtime.entities.model_entities import AIModelEntity
@@ -79,7 +79,7 @@ class AppGenerateEntity(BaseModel):
task_id: str
# app config
app_config: AppConfig
app_config: Any
file_upload_config: Optional[FileUploadConfig] = None
inputs: Mapping[str, Any]

View File

@@ -308,7 +308,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
error: Optional[str] = None
"""single iteration duration map"""

View File

@@ -70,7 +70,7 @@ class StreamResponse(BaseModel):
event: StreamEvent
task_id: str
def to_dict(self) -> dict:
def to_dict(self):
return jsonable_encoder(self)
@@ -474,8 +474,8 @@ class IterationNodeStartStreamResponse(StreamResponse):
title: str
created_at: int
extras: dict = {}
metadata: dict = {}
inputs: dict = {}
metadata: Mapping = {}
inputs: Mapping = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
@@ -526,15 +526,15 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
node_id: str
node_type: str
title: str
outputs: Optional[dict] = None
outputs: Optional[Mapping] = None
created_at: int
extras: Optional[dict] = None
inputs: Optional[dict] = None
inputs: Optional[Mapping] = None
status: WorkflowNodeExecutionStatus
error: Optional[str] = None
elapsed_time: float
total_tokens: int
execution_metadata: Optional[dict] = None
execution_metadata: Optional[Mapping] = None
finished_at: int
steps: int
parallel_id: Optional[str] = None
@@ -628,7 +628,7 @@ class AppBlockingResponse(BaseModel):
task_id: str
def to_dict(self) -> dict:
def to_dict(self):
return jsonable_encoder(self)

View File

@@ -58,7 +58,7 @@ class AnnotationReplyFeature:
query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]}
)
if documents:
if documents and documents[0].metadata:
annotation_id = documents[0].metadata["annotation_id"]
score = documents[0].metadata["score"]
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)

View File

@@ -17,7 +17,7 @@ class RateLimit:
_UNLIMITED_REQUEST_ID = "unlimited_request_id"
_REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
_instance_dict = {}
_instance_dict: dict[str, "RateLimit"] = {}
def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
if client_id not in cls._instance_dict:

View File

@@ -62,6 +62,7 @@ class BasedGenerateTaskPipeline:
"""
logger.debug("error: %s", event.error)
e = event.error
err: Exception
if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided")
@@ -130,6 +131,7 @@ class BasedGenerateTaskPipeline:
rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config),
queue_manager=self._queue_manager,
)
return None
def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
"""

View File

@@ -2,6 +2,7 @@ import json
import logging
import time
from collections.abc import Generator
from threading import Thread
from typing import Optional, Union, cast
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@@ -103,7 +104,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
)
)
self._conversation_name_generate_thread = None
self._conversation_name_generate_thread: Optional[Thread] = None
def process(
self,
@@ -123,7 +124,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation, self._application_generate_entity.query
self._conversation, self._application_generate_entity.query or ""
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
@@ -146,7 +147,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation.mode == AppMode.COMPLETION.value:
response = CompletionAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
@@ -154,7 +155,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
id=self._message.id,
mode=self._conversation.mode,
message_id=self._message.id,
answer=self._task_state.llm_result.message.content,
answer=cast(str, self._task_state.llm_result.message.content),
created_at=int(self._message.created_at.timestamp()),
**extras,
),
@@ -167,7 +168,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
mode=self._conversation.mode,
conversation_id=self._conversation.id,
message_id=self._message.id,
answer=self._task_state.llm_result.message.content,
answer=cast(str, self._task_state.llm_result.message.content),
created_at=int(self._message.created_at.timestamp()),
**extras,
),
@@ -252,7 +253,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None
self, publisher: Optional[AppGeneratorTTSPublisher], trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
@@ -269,13 +270,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
break
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
if isinstance(event, QueueMessageEndEvent):
self._task_state.llm_result = event.llm_result
if event.llm_result:
self._task_state.llm_result = event.llm_result
else:
self._handle_stop(event)
# handle output moderation
output_moderation_answer = self._handle_output_moderation_when_task_finished(
self._task_state.llm_result.message.content
cast(str, self._task_state.llm_result.message.content)
)
if output_moderation_answer:
self._task_state.llm_result.message.content = output_moderation_answer
@@ -292,7 +294,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if annotation:
self._task_state.llm_result.message.content = annotation.content
elif isinstance(event, QueueAgentThoughtEvent):
yield self._agent_thought_to_stream_response(event)
agent_thought_response = self._agent_thought_to_stream_response(event)
if agent_thought_response is not None:
yield agent_thought_response
elif isinstance(event, QueueMessageFileEvent):
response = self._message_file_to_stream_response(event)
if response:
@@ -307,16 +311,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(delta_text)
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
if should_direct_answer:
continue
self._task_state.llm_result.message.content += delta_text
current_content = cast(str, self._task_state.llm_result.message.content)
current_content += cast(str, delta_text)
self._task_state.llm_result.message.content = current_content
if isinstance(event, QueueLLMChunkEvent):
yield self._message_to_stream_response(delta_text, self._message.id)
yield self._message_to_stream_response(cast(str, delta_text), self._message.id)
else:
yield self._agent_message_to_stream_response(delta_text, self._message.id)
yield self._agent_message_to_stream_response(cast(str, delta_text), self._message.id)
elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
@@ -336,8 +342,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
llm_result = self._task_state.llm_result
usage = llm_result.usage
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
message = db.session.query(Message).filter(Message.id == self._message.id).first()
if not message:
raise Exception(f"Message {self._message.id} not found")
self._message = message
conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
if not conversation:
raise Exception(f"Conversation {self._conversation.id} not found")
self._conversation = conversation
self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
self._model_config.mode, self._task_state.llm_result.prompt_messages
@@ -346,7 +358,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit
self._message.answer = (
PromptTemplateParser.remove_template_variables(llm_result.message.content.strip())
PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
if llm_result.message.content
else ""
)
@@ -374,6 +386,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT}
and hasattr(self._application_generate_entity, "conversation_id")
and self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras,
)
@@ -420,7 +433,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
extras["metadata"] = self._task_state.metadata
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
task_id=self._application_generate_entity.task_id,
id=self._message.id,
metadata=extras.get("metadata", {}),
)
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
@@ -440,7 +455,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
:param event: agent thought event
:return:
"""
agent_thought: MessageAgentThought = (
agent_thought: Optional[MessageAgentThought] = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
)
db.session.refresh(agent_thought)

View File

@@ -128,7 +128,7 @@ class MessageCycleManage:
"""
message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()
if message_file:
if message_file and message_file.url is not None:
# get tool file id
tool_file_id = message_file.url.split("/")[-1]
# trim extension

View File

@@ -93,7 +93,7 @@ class WorkflowCycleManage:
)
# handle special values
inputs = WorkflowEntry.handle_special_values(inputs)
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
# init workflow run
with Session(db.engine, expire_on_commit=False) as session:
@@ -192,7 +192,7 @@ class WorkflowCycleManage:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)
outputs = WorkflowEntry.handle_special_values(outputs)
outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
workflow_run.outputs = json.dumps(outputs or {})
@@ -500,7 +500,7 @@ class WorkflowCycleManage:
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
inputs=workflow_run.inputs_dict,
inputs=dict(workflow_run.inputs_dict or {}),
created_at=int(workflow_run.created_at.timestamp()),
),
)
@@ -545,7 +545,7 @@ class WorkflowCycleManage:
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
status=workflow_run.status,
outputs=workflow_run.outputs_dict,
outputs=dict(workflow_run.outputs_dict) if workflow_run.outputs_dict else None,
error=workflow_run.error,
elapsed_time=workflow_run.elapsed_time,
total_tokens=workflow_run.total_tokens,
@@ -553,7 +553,7 @@ class WorkflowCycleManage:
created_by=created_by,
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)),
exceptions_count=workflow_run.exceptions_count,
),
)
@@ -655,7 +655,7 @@ class WorkflowCycleManage:
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
"""
Workflow node finish to stream response.
:param event: queue node succeeded or failed event
@@ -838,7 +838,7 @@ class WorkflowCycleManage:
),
)
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]:
def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from node outputs
:param outputs_dict: node outputs dict
@@ -851,9 +851,11 @@ class WorkflowCycleManage:
# Remove None
files = [file for file in files if file]
# Flatten list
files = [file for sublist in files for file in sublist]
# Flatten the list of sequences into a single list of mappings
flattened_files = [file for sublist in files if sublist for file in sublist]
return files
# Convert to tuple to match Sequence type
return tuple(flattened_files)
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
"""
@@ -891,6 +893,8 @@ class WorkflowCycleManage:
elif isinstance(value, File):
return value.to_dict()
return None
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
"""
Refetch workflow run