Mirror of http://112.124.100.131/huang.ze/ebiz-dify-ai.git (synced 2025-12-09 02:46:52 +08:00)
feat: mypy for all type check (#10921)
@@ -66,6 +66,8 @@ class DatasetConfigManager:
             dataset_configs = config.get("dataset_configs")
         else:
             dataset_configs = {"retrieval_model": "multiple"}
+        if dataset_configs is None:
+            return None
         query_variable = config.get("dataset_query_variable")

         if dataset_configs["retrieval_model"] == "single":
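Note on the hunk above: `dict.get` without a default is inferred as Optional, so the new `is None` early return is what lets mypy index `dataset_configs` below it. A minimal standalone sketch of the pattern (names hypothetical):

```python
from typing import Optional


def retrieval_mode(config: dict) -> Optional[str]:
    # get() without a default may return None
    dataset_configs: Optional[dict] = config.get("dataset_configs")
    if dataset_configs is None:
        return None  # the early return narrows the type below
    # from here on, mypy treats dataset_configs as a dict
    return dataset_configs["retrieval_model"]
```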
@@ -94,7 +94,7 @@ class ModelConfigManager:
             config["model"]["completion_params"]
         )

-        return config, ["model"]
+        return dict(config), ["model"]

     @classmethod
     def validate_model_completion_params(cls, cp: dict) -> dict:
@@ -7,10 +7,10 @@ class OpeningStatementConfigManager:
         :param config: model config args
         """
         # opening statement
-        opening_statement = config.get("opening_statement")
+        opening_statement = config.get("opening_statement", "")

         # suggested questions
-        suggested_questions_list = config.get("suggested_questions")
+        suggested_questions_list = config.get("suggested_questions", [])

         return opening_statement, suggested_questions_list
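Note: supplying a default to `dict.get` is the other recurring fix in this commit; it changes the inferred type from `Optional[T]` to `T`, so callers need no None handling. A hedged sketch:

```python
config: dict[str, str] = {}

maybe_statement = config.get("opening_statement")  # inferred as str | None
statement = config.get("opening_statement", "")    # inferred as str
```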
@@ -29,6 +29,7 @@ from factories import file_factory
 from models.account import Account
 from models.model import App, Conversation, EndUser, Message
 from models.workflow import Workflow
+from services.errors.message import MessageNotExistsError

 logger = logging.getLogger(__name__)

@@ -145,7 +146,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                 user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
             ),
             query=query,
-            files=file_objs,
+            files=list(file_objs),
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
             user_id=user.id,
             stream=streaming,

@@ -313,6 +314,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             # get conversation and message
             conversation = self._get_conversation(conversation_id)
             message = self._get_message(message_id)
+            if message is None:
+                raise MessageNotExistsError("Message not exists")

             # chatbot app
             runner = AdvancedChatAppRunner(
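Note: `file_objs` comes from the file factory as an immutable `Sequence`, while the generate entity declares a mutable `list` field, hence the `list(...)` copies in these hunks. An illustrative sketch of why mypy insists on the conversion (names hypothetical):

```python
from collections.abc import Sequence


def build_files() -> Sequence[str]:
    return ("a.png", "b.png")  # a tuple is a valid Sequence


def consume(files: list[str]) -> None:
    files.append("c.png")  # list promises mutability; Sequence does not


file_objs = build_files()
# consume(file_objs)      # mypy error: Sequence[str] is not list[str]
consume(list(file_objs))  # an explicit copy satisfies the declared type
```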
@@ -5,6 +5,7 @@ import queue
 import re
 import threading
 from collections.abc import Iterable
+from typing import Optional

 from core.app.entities.queue_entities import (
     MessageQueueMessage,

@@ -15,6 +16,7 @@ from core.app.entities.queue_entities import (
     WorkflowQueueMessage,
 )
 from core.model_manager import ModelInstance, ModelManager
+from core.model_runtime.entities.message_entities import TextPromptMessageContent
 from core.model_runtime.entities.model_entities import ModelType

@@ -71,8 +73,9 @@ class AppGeneratorTTSPublisher:
         if not voice or voice not in values:
             self.voice = self.voices[0].get("value")
         self.MAX_SENTENCE = 2
-        self._last_audio_event = None
-        self._runtime_thread = threading.Thread(target=self._runtime).start()
+        self._last_audio_event: Optional[AudioTrunk] = None
+        # FIXME better way to handle this threading.start
+        threading.Thread(target=self._runtime).start()
         self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)

     def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):

@@ -92,10 +95,21 @@ class AppGeneratorTTSPublisher:
                     future_queue.put(futures_result)
                     break
                 elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
-                    self.msg_text += message.event.chunk.delta.message.content
+                    message_content = message.event.chunk.delta.message.content
+                    if not message_content:
+                        continue
+                    if isinstance(message_content, str):
+                        self.msg_text += message_content
+                    elif isinstance(message_content, list):
+                        for content in message_content:
+                            if not isinstance(content, TextPromptMessageContent):
+                                continue
+                            self.msg_text += content.data
                 elif isinstance(message.event, QueueTextChunkEvent):
                     self.msg_text += message.event.text
                 elif isinstance(message.event, QueueNodeSucceededEvent):
+                    if message.event.outputs is None:
+                        continue
                     self.msg_text += message.event.outputs.get("output", "")
                 self.last_message = message
                 sentence_arr, text_tmp = self._extract_sentence(self.msg_text)

@@ -121,11 +135,10 @@ class AppGeneratorTTSPublisher:
         if self._last_audio_event and self._last_audio_event.status == "finish":
             if self.executor:
                 self.executor.shutdown(wait=False)
-            return self.last_message
+            return self._last_audio_event
         audio = self._audio_queue.get_nowait()
         if audio and audio.status == "finish":
             self.executor.shutdown(wait=False)
-            self._runtime_thread = None
         if audio:
             self._last_audio_event = audio
         return audio
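Note: `delta.message.content` can be a plain string or a list of content parts, so the publisher now narrows the union before concatenating. A simplified sketch of the same narrowing (classes hypothetical):

```python
from typing import Union


class TextPart:
    def __init__(self, data: str) -> None:
        self.data = data


Content = Union[str, list[TextPart]]


def to_text(content: Content) -> str:
    # isinstance checks narrow the union for mypy
    if isinstance(content, str):
        return content
    return "".join(part.data for part in content)
```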
@@ -109,18 +109,18 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             ConversationVariable.conversation_id == self.conversation.id,
         )
         with Session(db.engine) as session:
-            conversation_variables = session.scalars(stmt).all()
-            if not conversation_variables:
+            db_conversation_variables = session.scalars(stmt).all()
+            if not db_conversation_variables:
                 # Create conversation variables if they don't exist.
-                conversation_variables = [
+                db_conversation_variables = [
                     ConversationVariable.from_variable(
                         app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
                     )
                     for variable in workflow.conversation_variables
                 ]
-                session.add_all(conversation_variables)
+                session.add_all(db_conversation_variables)
             # Convert database entities to variables.
-            conversation_variables = [item.to_variable() for item in conversation_variables]
+            conversation_variables = [item.to_variable() for item in db_conversation_variables]

             session.commit()
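Note: the rename above exists because one name previously held ORM rows and then domain variables; mypy assigns a single type per variable, so reusing the name is an error. Tiny illustration:

```python
rows: list[int] = [1, 2, 3]

# rows = [str(r) for r in rows]  # mypy error: list[str] assigned to list[int]

labels: list[str] = [str(r) for r in rows]  # distinct name, distinct type
```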
@@ -2,6 +2,7 @@ import json
 import logging
 import time
+from collections.abc import Generator, Mapping
 from threading import Thread
 from typing import Any, Optional, Union

 from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME

@@ -64,6 +65,7 @@ from models.enums import CreatedByRole
 from models.workflow import (
     Workflow,
     WorkflowNodeExecution,
+    WorkflowRun,
     WorkflowRunStatus,
 )

@@ -81,6 +83,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
     _user: Union[Account, EndUser]
     _workflow_system_variables: dict[SystemVariableKey, Any]
     _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
+    _conversation_name_generate_thread: Optional[Thread] = None

     def __init__(
         self,

@@ -131,7 +134,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         self._conversation_name_generate_thread = None
         self._recorded_files: list[Mapping[str, Any]] = []

-    def process(self):
+    def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
         """
         Process generate task pipeline.
         :return:

@@ -262,8 +265,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         :return:
         """
         # init fake graph runtime state
-        graph_runtime_state = None
-        workflow_run = None
+        graph_runtime_state: Optional[GraphRuntimeState] = None
+        workflow_run: Optional[WorkflowRun] = None

         for queue_message in self._queue_manager.listen():
             event = queue_message.event

@@ -315,14 +318,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

                 workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)

-                response = self._workflow_node_start_to_stream_response(
+                response_start = self._workflow_node_start_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution,
                 )

-                if response:
-                    yield response
+                if response_start:
+                    yield response_start
             elif isinstance(event, QueueNodeSucceededEvent):
                 workflow_node_execution = self._handle_workflow_node_execution_success(event)

@@ -330,18 +333,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if event.node_type in [NodeType.ANSWER, NodeType.END]:
                     self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))

-                response = self._workflow_node_finish_to_stream_response(
+                response_finish = self._workflow_node_finish_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution,
                 )

-                if response:
-                    yield response
+                if response_finish:
+                    yield response_finish
             elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
                 workflow_node_execution = self._handle_workflow_node_execution_failed(event)

-                response = self._workflow_node_finish_to_stream_response(
+                response_finish = self._workflow_node_finish_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution,

@@ -609,7 +612,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             del extras["metadata"]["annotation_reply"]

         return MessageEndStreamResponse(
-            task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras
+            task_id=self._application_generate_entity.task_id,
+            id=self._message.id,
+            files=self._recorded_files,
+            metadata=extras.get("metadata", {}),
         )

     def _handle_output_moderation_chunk(self, text: str) -> bool:
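Note: unpacking `**extras` into a typed constructor hides the keywords from mypy; passing `metadata` explicitly lets every field be checked. A sketch with a stand-in pydantic model (not the real response class):

```python
from typing import Any

from pydantic import BaseModel


class MessageEnd(BaseModel):
    task_id: str
    id: str
    metadata: dict[str, Any] = {}


extras: dict[str, Any] = {"metadata": {"usage": {}}}

# MessageEnd(task_id="t", id="m", **extras)  # kwargs invisible to mypy
resp = MessageEnd(task_id="t", id="m", metadata=extras.get("metadata", {}))
```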
@@ -61,7 +61,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
             app_model_config_dict = app_model_config.to_dict()
             config_dict = app_model_config_dict.copy()
         else:
-            config_dict = override_config_dict
+            config_dict = override_config_dict or {}

         app_mode = AppMode.value_of(app_model.mode)
         app_config = AgentChatAppConfig(

@@ -23,6 +23,7 @@ from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
 from factories import file_factory
 from models import Account, App, EndUser
+from services.errors.message import MessageNotExistsError

 logger = logging.getLogger(__name__)

@@ -97,7 +98,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
         # get conversation
         conversation = None
         if args.get("conversation_id"):
-            conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
+            conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user)

         # get app model config
         app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)

@@ -153,7 +154,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
                 user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
             ),
             query=query,
-            files=file_objs,
+            files=list(file_objs),
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
             user_id=user.id,
             stream=streaming,

@@ -180,7 +181,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
         worker_thread = threading.Thread(
             target=self._generate_worker,
             kwargs={
-                "flask_app": current_app._get_current_object(),
+                "flask_app": current_app._get_current_object(),  # type: ignore
                 "application_generate_entity": application_generate_entity,
                 "queue_manager": queue_manager,
                 "conversation_id": conversation.id,

@@ -199,8 +200,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             user=user,
             stream=streaming,
         )

-        return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
+        # FIXME: Type hinting issue here, ignore it for now, will fix it later
+        return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)  # type: ignore

     def _generate_worker(
         self,

@@ -224,6 +225,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             # get conversation and message
             conversation = self._get_conversation(conversation_id)
             message = self._get_message(message_id)
+            if message is None:
+                raise MessageNotExistsError("Message not exists")

             # chatbot app
             runner = AgentChatAppRunner()
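Note: `_get_message` now returns `Optional[Message]` (see the hunk further down), so each generator worker must prove the message exists before using it. The raise-on-None pattern in miniature (a dict stands in for the DB):

```python
from typing import Optional


class MessageNotExistsError(Exception):
    pass


def get_message(message_id: str, store: dict[str, str]) -> Optional[str]:
    return store.get(message_id)


message = get_message("m1", {"m1": "hello"})
if message is None:
    raise MessageNotExistsError("Message not exists")
print(message.upper())  # mypy now treats message as str
```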
@@ -173,6 +173,8 @@ class AgentChatAppRunner(AppRunner):
             return

         agent_entity = app_config.agent
+        if not agent_entity:
+            raise ValueError("Agent entity not found")

         # load tool variables
         tool_conversation_variables = self._load_tool_variables(

@@ -200,14 +202,21 @@ class AgentChatAppRunner(AppRunner):
         # change function call strategy based on LLM model
         llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
         model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
+        if not model_schema or not model_schema.features:
+            raise ValueError("Model schema not found")

         if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
             agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING

-        conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
-        message = db.session.query(Message).filter(Message.id == message.id).first()
+        conversation_result = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
+        if conversation_result is None:
+            raise ValueError("Conversation not found")
+        message_result = db.session.query(Message).filter(Message.id == message.id).first()
+        if message_result is None:
+            raise ValueError("Message not found")
         db.session.close()

+        runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
         # start agent runner
         if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
             # check LLM mode

@@ -225,12 +234,12 @@ class AgentChatAppRunner(AppRunner):
         runner = runner_cls(
             tenant_id=app_config.tenant_id,
             application_generate_entity=application_generate_entity,
-            conversation=conversation,
+            conversation=conversation_result,
             app_config=app_config,
             model_config=application_generate_entity.model_conf,
             config=agent_entity,
             queue_manager=queue_manager,
-            message=message,
+            message=message_result,
             user_id=application_generate_entity.user_id,
             memory=memory,
             prompt_messages=prompt_message,

@@ -257,7 +266,7 @@ class AgentChatAppRunner(AppRunner):
         """
         load tool variables from database
         """
-        tool_variables: ToolConversationVariables = (
+        tool_variables: ToolConversationVariables | None = (
             db.session.query(ToolConversationVariables)
             .filter(
                 ToolConversationVariables.conversation_id == conversation_id,
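Note: `runner_cls` is declared as a union of class objects before the branches, so mypy can verify that every branch assigns a compatible class and that the later call site is well typed. Minimal sketch:

```python
class FunctionCallRunner: ...
class CotChatRunner: ...


use_function_calling = True

# declare the union up front; each branch is checked against it
runner_cls: type[FunctionCallRunner] | type[CotChatRunner]
if use_function_calling:
    runner_cls = FunctionCallRunner
else:
    runner_cls = CotChatRunner

runner = runner_cls()
```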
@@ -16,7 +16,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
     _blocking_response_type = ChatbotAppBlockingResponse

     @classmethod
-    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
+    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:  # type: ignore[override]
         """
         Convert blocking full response.
         :param blocking_response: blocking response

@@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         return response

     @classmethod
-    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
+    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:  # type: ignore[override]
         """
         Convert blocking simple response.
         :param blocking_response: blocking response

@@ -51,8 +51,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         return response

     @classmethod
-    def convert_stream_full_response(
-        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+    def convert_stream_full_response(  # type: ignore[override]
+        cls,
+        stream_response: Generator[ChatbotAppStreamResponse, None, None],
     ) -> Generator[str, None, None]:
         """
         Convert stream full response.

@@ -82,8 +83,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         yield json.dumps(response_chunk)

     @classmethod
-    def convert_stream_simple_response(
-        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+    def convert_stream_simple_response(  # type: ignore[override]
+        cls,
+        stream_response: Generator[ChatbotAppStreamResponse, None, None],
     ) -> Generator[str, None, None]:
         """
         Convert stream simple response.
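Note: these converters narrow an inherited method's parameter type, which violates the substitution rule mypy enforces, so the commit opts for targeted `# type: ignore[override]` comments rather than restructuring the hierarchy. The shape of the problem:

```python
class Response: ...
class ChatResponse(Response): ...


class BaseConverter:
    @classmethod
    def convert(cls, response: Response) -> dict:
        return {}


class ChatConverter(BaseConverter):
    @classmethod
    def convert(cls, response: ChatResponse) -> dict:  # type: ignore[override]
        # narrowing the parameter breaks substitutability, so the
        # known-safe violation is silenced per line
        return {}
```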
@@ -50,7 +50,7 @@ class AppQueueManager:
         # wait for APP_MAX_EXECUTION_TIME seconds to stop listen
         listen_timeout = dify_config.APP_MAX_EXECUTION_TIME
         start_time = time.time()
-        last_ping_time = 0
+        last_ping_time: int | float = 0
         while True:
             try:
                 message = self._q.get(timeout=1)

@@ -1,5 +1,5 @@
 import time
-from collections.abc import Generator, Mapping
+from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Optional, Union

 from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
@@ -36,8 +36,8 @@ class AppRunner:
         app_record: App,
         model_config: ModelConfigWithCredentialsEntity,
         prompt_template_entity: PromptTemplateEntity,
-        inputs: dict[str, str],
-        files: list["File"],
+        inputs: Mapping[str, str],
+        files: Sequence["File"],
         query: Optional[str] = None,
     ) -> int:
         """

@@ -64,7 +64,7 @@ class AppRunner:
         ):
             max_tokens = (
                 model_config.parameters.get(parameter_rule.name)
-                or model_config.parameters.get(parameter_rule.use_template)
+                or model_config.parameters.get(parameter_rule.use_template or "")
             ) or 0

         if model_context_tokens is None:

@@ -85,7 +85,7 @@ class AppRunner:

         prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)

-        rest_tokens = model_context_tokens - max_tokens - prompt_tokens
+        rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens
         if rest_tokens < 0:
             raise InvokeBadRequestError(
                 "Query or prefix prompt is too long, you can reduce the prefix prompt, "

@@ -111,7 +111,7 @@ class AppRunner:
         ):
             max_tokens = (
                 model_config.parameters.get(parameter_rule.name)
-                or model_config.parameters.get(parameter_rule.use_template)
+                or model_config.parameters.get(parameter_rule.use_template or "")
             ) or 0

         if model_context_tokens is None:

@@ -136,8 +136,8 @@ class AppRunner:
         app_record: App,
         model_config: ModelConfigWithCredentialsEntity,
         prompt_template_entity: PromptTemplateEntity,
-        inputs: dict[str, str],
-        files: list["File"],
+        inputs: Mapping[str, str],
+        files: Sequence["File"],
         query: Optional[str] = None,
         context: Optional[str] = None,
         memory: Optional[TokenBufferMemory] = None,
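Note: switching parameters from `dict`/`list` to `Mapping`/`Sequence` makes the functions more permissive for callers and promises no mutation of the arguments. Sketch:

```python
from collections.abc import Mapping, Sequence


def organize(inputs: Mapping[str, str], files: Sequence[str]) -> int:
    # read-only access; mypy rejects inputs["k"] = "v" here
    return len(inputs) + len(files)


# callers may now pass dicts, tuples, or other conforming types
organize({"name": "dify"}, ("a.png",))
```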
@@ -156,6 +156,7 @@ class AppRunner:
         """
         # get prompt without memory and context
         if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
+            prompt_transform: Union[SimplePromptTransform, AdvancedPromptTransform]
             prompt_transform = SimplePromptTransform()
             prompt_messages, stop = prompt_transform.get_prompt(
                 app_mode=AppMode.value_of(app_record.mode),

@@ -171,8 +172,11 @@ class AppRunner:
             memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))

             model_mode = ModelMode.value_of(model_config.mode)
+            prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
             if model_mode == ModelMode.COMPLETION:
                 advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
+                if not advanced_completion_prompt_template:
+                    raise InvokeBadRequestError("Advanced completion prompt template is required.")
                 prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt)

                 if advanced_completion_prompt_template.role_prefix:

@@ -181,6 +185,8 @@ class AppRunner:
                         assistant=advanced_completion_prompt_template.role_prefix.assistant,
                     )
             else:
+                if not prompt_template_entity.advanced_chat_prompt_template:
+                    raise InvokeBadRequestError("Advanced chat prompt template is required.")
                 prompt_template = []
                 for message in prompt_template_entity.advanced_chat_prompt_template.messages:
                     prompt_template.append(ChatModelMessage(text=message.text, role=message.role))

@@ -246,7 +252,7 @@ class AppRunner:

     def _handle_invoke_result(
         self,
-        invoke_result: Union[LLMResult, Generator],
+        invoke_result: Union[LLMResult, Generator[Any, None, None]],
         queue_manager: AppQueueManager,
         stream: bool,
         agent: bool = False,

@@ -259,10 +265,12 @@ class AppRunner:
         :param agent: agent
         :return:
         """
-        if not stream:
+        if not stream and isinstance(invoke_result, LLMResult):
             self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
-        else:
+        elif stream and isinstance(invoke_result, Generator):
             self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
+        else:
+            raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")

     def _handle_invoke_result_direct(
         self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
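Note: `_handle_invoke_result` now narrows `Union[LLMResult, Generator]` with isinstance before dispatching, and fails loudly on anything else. The same shape in isolation (stand-in classes):

```python
from collections.abc import Generator
from typing import Union


class LLMResult:
    text = "done"


def handle(result: Union[LLMResult, Generator[str, None, None]], stream: bool) -> None:
    # isinstance narrows the union, giving each branch a precise type
    if not stream and isinstance(result, LLMResult):
        print(result.text)
    elif stream and isinstance(result, Generator):
        for chunk in result:
            print(chunk, end="")
    else:
        raise NotImplementedError(f"unsupported invoke result type: {type(result)}")


def chunks() -> Generator[str, None, None]:
    yield "hi"


handle(chunks(), stream=True)
```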
@@ -291,8 +299,8 @@ class AppRunner:
         :param agent: agent
         :return:
         """
-        model = None
-        prompt_messages = []
+        model: str = ""
+        prompt_messages: list[PromptMessage] = []
         text = ""
         usage = None
         for result in invoke_result:

@@ -328,13 +336,14 @@ class AppRunner:

     def moderation_for_inputs(
         self,
         *,
         app_id: str,
         tenant_id: str,
         app_generate_entity: AppGenerateEntity,
         inputs: Mapping[str, Any],
-        query: str,
+        query: str | None = None,
         message_id: str,
-    ) -> tuple[bool, dict, str]:
+    ) -> tuple[bool, Mapping[str, Any], str]:
         """
         Process sensitive_word_avoidance.
         :param app_id: app id

@@ -350,7 +359,7 @@ class AppRunner:
             app_id=app_id,
             tenant_id=tenant_id,
             app_config=app_generate_entity.app_config,
-            inputs=inputs,
+            inputs=dict(inputs),
             query=query or "",
             message_id=message_id,
             trace_manager=app_generate_entity.trace_manager,

@@ -390,9 +399,9 @@ class AppRunner:
         tenant_id: str,
         app_id: str,
         external_data_tools: list[ExternalDataVariableEntity],
-        inputs: dict,
+        inputs: Mapping[str, Any],
         query: str,
-    ) -> dict:
+    ) -> Mapping[str, Any]:
         """
         Fill in variable inputs from external data tools if exists.
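Note: initializing accumulators with typed empty values (`""`, an annotated `[]`) instead of `None` spares mypy the Optional bookkeeping inside the streaming loop. Tiny illustration (the model name is hypothetical):

```python
model: str = ""
chunks: list[str] = []

for delta in ("Hel", "lo"):
    model = model or "example-model"  # hypothetical name
    chunks.append(delta)

text = "".join(chunks)
```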
@@ -24,6 +24,7 @@ from extensions.ext_database import db
 from factories import file_factory
 from models.account import Account
 from models.model import App, EndUser
+from services.errors.message import MessageNotExistsError

 logger = logging.getLogger(__name__)

@@ -91,7 +92,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
         # get conversation
         conversation = None
         if args.get("conversation_id"):
-            conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
+            conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user)

         # get app model config
         app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)

@@ -104,7 +105,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):

             # validate config
             override_model_config_dict = ChatAppConfigManager.config_validate(
-                tenant_id=app_model.tenant_id, config=args.get("model_config")
+                tenant_id=app_model.tenant_id, config=args.get("model_config", {})
             )

             # always enable retriever resource in debugger mode

@@ -146,7 +147,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
                 user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
             ),
             query=query,
-            files=file_objs,
+            files=list(file_objs),
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
             user_id=user.id,
             invoke_from=invoke_from,

@@ -172,7 +173,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
         worker_thread = threading.Thread(
             target=self._generate_worker,
             kwargs={
-                "flask_app": current_app._get_current_object(),
+                "flask_app": current_app._get_current_object(),  # type: ignore
                 "application_generate_entity": application_generate_entity,
                 "queue_manager": queue_manager,
                 "conversation_id": conversation.id,

@@ -216,6 +217,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             # get conversation and message
             conversation = self._get_conversation(conversation_id)
             message = self._get_message(message_id)
+            if message is None:
+                raise MessageNotExistsError("Message not exists")

             # chatbot app
             runner = ChatAppRunner()

@@ -16,7 +16,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
     _blocking_response_type = ChatbotAppBlockingResponse

     @classmethod
-    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
+    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:  # type: ignore[override]
         """
         Convert blocking full response.
         :param blocking_response: blocking response

@@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         return response

     @classmethod
-    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
+    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:  # type: ignore[override]
         """
         Convert blocking simple response.
         :param blocking_response: blocking response

@@ -52,7 +52,8 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):

     @classmethod
     def convert_stream_full_response(
-        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+        cls,
+        stream_response: Generator[ChatbotAppStreamResponse, None, None],  # type: ignore[override]
     ) -> Generator[str, None, None]:
         """
         Convert stream full response.

@@ -83,7 +84,8 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):

     @classmethod
     def convert_stream_simple_response(
-        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+        cls,
+        stream_response: Generator[ChatbotAppStreamResponse, None, None],  # type: ignore[override]
     ) -> Generator[str, None, None]:
         """
         Convert stream simple response.
@@ -42,7 +42,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
             app_model_config_dict = app_model_config.to_dict()
             config_dict = app_model_config_dict.copy()
         else:
-            config_dict = override_config_dict
+            config_dict = override_config_dict or {}

         app_mode = AppMode.value_of(app_model.mode)
         app_config = CompletionAppConfig(

@@ -83,8 +83,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         query = query.replace("\x00", "")
         inputs = args["inputs"]

-        extras = {}
-
         # get conversation
         conversation = None

@@ -99,7 +97,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):

             # validate config
             override_model_config_dict = CompletionAppConfigManager.config_validate(
-                tenant_id=app_model.tenant_id, config=args.get("model_config")
+                tenant_id=app_model.tenant_id, config=args.get("model_config", {})
             )

         # parse files

@@ -132,11 +130,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
                 user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
             ),
             query=query,
-            files=file_objs,
+            files=list(file_objs),
             user_id=user.id,
             stream=streaming,
             invoke_from=invoke_from,
-            extras=extras,
+            extras={},
             trace_manager=trace_manager,
         )

@@ -157,7 +155,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         worker_thread = threading.Thread(
             target=self._generate_worker,
             kwargs={
-                "flask_app": current_app._get_current_object(),
+                "flask_app": current_app._get_current_object(),  # type: ignore
                 "application_generate_entity": application_generate_entity,
                 "queue_manager": queue_manager,
                 "message_id": message.id,

@@ -197,6 +195,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         try:
             # get message
             message = self._get_message(message_id)
+            if message is None:
+                raise MessageNotExistsError()

             # chatbot app
             runner = CompletionAppRunner()

@@ -231,7 +231,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         user: Union[Account, EndUser],
         invoke_from: InvokeFrom,
         stream: bool = True,
-    ) -> Union[dict, Generator[str, None, None]]:
+    ) -> Union[Mapping[str, Any], Generator[str, None, None]]:
         """
         Generate App response.

@@ -293,7 +293,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             model_conf=ModelConfigConverter.convert(app_config),
             inputs=message.inputs,
             query=message.query,
-            files=file_objs,
+            files=list(file_objs),
             user_id=user.id,
             stream=stream,
             invoke_from=invoke_from,

@@ -317,7 +317,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         worker_thread = threading.Thread(
             target=self._generate_worker,
             kwargs={
-                "flask_app": current_app._get_current_object(),
+                "flask_app": current_app._get_current_object(),  # type: ignore
                 "application_generate_entity": application_generate_entity,
                 "queue_manager": queue_manager,
                 "message_id": message.id,

@@ -76,7 +76,7 @@ class CompletionAppRunner(AppRunner):
                 tenant_id=app_config.tenant_id,
                 app_generate_entity=application_generate_entity,
                 inputs=inputs,
-                query=query,
+                query=query or "",
                 message_id=message.id,
             )
         except ModerationError as e:

@@ -122,7 +122,7 @@ class CompletionAppRunner(AppRunner):
                 tenant_id=app_record.tenant_id,
                 model_config=application_generate_entity.model_conf,
                 config=dataset_config,
-                query=query,
+                query=query or "",
                 invoke_from=application_generate_entity.invoke_from,
                 show_retrieve_source=app_config.additional_features.show_retrieve_source,
                 hit_callback=hit_callback,

@@ -16,7 +16,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
     _blocking_response_type = CompletionAppBlockingResponse

     @classmethod
-    def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict:
+    def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict:  # type: ignore[override]
         """
         Convert blocking full response.
         :param blocking_response: blocking response

@@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
         return response

     @classmethod
-    def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict:
+    def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict:  # type: ignore[override]
         """
         Convert blocking simple response.
         :param blocking_response: blocking response

@@ -51,7 +51,8 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):

     @classmethod
     def convert_stream_full_response(
-        cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
+        cls,
+        stream_response: Generator[CompletionAppStreamResponse, None, None],  # type: ignore[override]
     ) -> Generator[str, None, None]:
         """
         Convert stream full response.

@@ -81,7 +82,8 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):

     @classmethod
     def convert_stream_simple_response(
-        cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
+        cls,
+        stream_response: Generator[CompletionAppStreamResponse, None, None],  # type: ignore[override]
     ) -> Generator[str, None, None]:
         """
         Convert stream simple response.
@@ -2,11 +2,11 @@ import json
 import logging
 from collections.abc import Generator
 from datetime import UTC, datetime
-from typing import Optional, Union
+from typing import Optional, Union, cast

 from sqlalchemy import and_

-from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
+from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
 from core.app.apps.base_app_generator import BaseAppGenerator
 from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError
 from core.app.entities.app_invoke_entities import (

@@ -42,7 +42,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             ChatAppGenerateEntity,
             CompletionAppGenerateEntity,
-            AgentChatAppGenerateEntity,
             AdvancedChatAppGenerateEntity,
+            AgentChatAppGenerateEntity,
         ],
         queue_manager: AppQueueManager,
         conversation: Conversation,

@@ -144,7 +144,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
         :conversation conversation
         :return:
         """
-        app_config = application_generate_entity.app_config
+        app_config: EasyUIBasedAppConfig = cast(EasyUIBasedAppConfig, application_generate_entity.app_config)

         # get from source
         end_user_id = None

@@ -267,7 +267,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             except KeyError:
                 pass

-        return introduction
+        return introduction or ""

     def _get_conversation(self, conversation_id: str):
         """

@@ -282,7 +282,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):

         return conversation

-    def _get_message(self, message_id: str) -> Message:
+    def _get_message(self, message_id: str) -> Optional[Message]:
         """
         Get message by message id
         :param message_id: message id
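Note: since `AppGenerateEntity.app_config` is now typed `Any` (see the entities hunk below), message-based generators `cast` it back to the EasyUI config they require; `cast` is a no-op at runtime and only informs the checker. Sketch:

```python
from typing import Any, cast


class EasyUIBasedAppConfig:
    tenant_id = "t-0000"  # hypothetical value


app_config_any: Any = EasyUIBasedAppConfig()

# cast() has no runtime effect; it tells mypy which type to assume
app_config = cast(EasyUIBasedAppConfig, app_config_any)
print(app_config.tenant_id)
```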
@@ -116,7 +116,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
             inputs=self._prepare_user_inputs(
                 user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
             ),
-            files=system_files,
+            files=list(system_files),
             user_id=user.id,
             stream=streaming,
             invoke_from=invoke_from,

@@ -17,16 +17,16 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
     _blocking_response_type = WorkflowAppBlockingResponse

     @classmethod
-    def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:
+    def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:  # type: ignore[override]
         """
         Convert blocking full response.
         :param blocking_response: blocking response
         :return:
         """
-        return blocking_response.to_dict()
+        return dict(blocking_response.to_dict())

     @classmethod
-    def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:
+    def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:  # type: ignore[override]
         """
         Convert blocking simple response.
         :param blocking_response: blocking response

@@ -36,7 +36,8 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):

     @classmethod
     def convert_stream_full_response(
-        cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
+        cls,
+        stream_response: Generator[WorkflowAppStreamResponse, None, None],  # type: ignore[override]
     ) -> Generator[str, None, None]:
         """
         Convert stream full response.

@@ -65,7 +66,8 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):

     @classmethod
     def convert_stream_simple_response(
-        cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
+        cls,
+        stream_response: Generator[WorkflowAppStreamResponse, None, None],  # type: ignore[override]
     ) -> Generator[str, None, None]:
         """
         Convert stream simple response.
@@ -24,6 +24,7 @@ from core.app.entities.queue_entities import (
     QueueWorkflowStartedEvent,
     QueueWorkflowSucceededEvent,
 )
+from core.workflow.entities.node_entities import NodeRunMetadataKey
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine.entities.event import (
     GraphEngineEvent,

@@ -190,16 +191,15 @@ class WorkflowBasedAppRunner(AppRunner):
             self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
         elif isinstance(event, NodeRunRetryEvent):
             node_run_result = event.route_node_state.node_run_result
+            inputs: Mapping[str, Any] | None = {}
+            process_data: Mapping[str, Any] | None = {}
+            outputs: Mapping[str, Any] | None = {}
+            execution_metadata: Mapping[NodeRunMetadataKey, Any] | None = {}
             if node_run_result:
                 inputs = node_run_result.inputs
                 process_data = node_run_result.process_data
                 outputs = node_run_result.outputs
                 execution_metadata = node_run_result.metadata
-            else:
-                inputs = {}
-                process_data = {}
-                outputs = {}
-                execution_metadata = {}
             self._publish_event(
                 QueueNodeRetryEvent(
                     node_execution_id=event.id,

@@ -289,7 +289,7 @@ class WorkflowBasedAppRunner(AppRunner):
                     process_data=event.route_node_state.node_run_result.process_data
                     if event.route_node_state.node_run_result
                     else {},
-                    outputs=event.route_node_state.node_run_result.outputs
+                    outputs=event.route_node_state.node_run_result.outputs or {}
                     if event.route_node_state.node_run_result
                     else {},
                     error=event.route_node_state.node_run_result.error

@@ -349,7 +349,7 @@ class WorkflowBasedAppRunner(AppRunner):
                     process_data=event.route_node_state.node_run_result.process_data
                     if event.route_node_state.node_run_result
                     else {},
-                    outputs=event.route_node_state.node_run_result.outputs
+                    outputs=event.route_node_state.node_run_result.outputs or {}
                     if event.route_node_state.node_run_result
                     else {},
                     execution_metadata=event.route_node_state.node_run_result.metadata
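Note: pre-declaring the four mappings with explicit `Mapping[...] | None` annotations gives each name a single declared type and makes the old `else` branch redundant. The pattern reduced to one variable:

```python
from collections.abc import Mapping
from typing import Any, Optional


def collect(result: Optional[dict[str, Any]]) -> Mapping[str, Any]:
    outputs: Mapping[str, Any] = {}  # declared once, with a default
    if result:
        outputs = result  # the branch only overwrites
    return outputs
```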
@@ -5,7 +5,7 @@ from typing import Any, Optional
 from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator

 from constants import UUID_NIL
-from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
+from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
 from core.entities.provider_configuration import ProviderModelBundle
 from core.file import File, FileUploadConfig
 from core.model_runtime.entities.model_entities import AIModelEntity

@@ -79,7 +79,7 @@ class AppGenerateEntity(BaseModel):
     task_id: str

     # app config
-    app_config: AppConfig
+    app_config: Any
     file_upload_config: Optional[FileUploadConfig] = None

     inputs: Mapping[str, Any]

@@ -308,7 +308,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
     inputs: Optional[Mapping[str, Any]] = None
     process_data: Optional[Mapping[str, Any]] = None
     outputs: Optional[Mapping[str, Any]] = None
-    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
+    execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None

     error: Optional[str] = None
     """single iteration duration map"""

@@ -70,7 +70,7 @@ class StreamResponse(BaseModel):
     event: StreamEvent
     task_id: str

-    def to_dict(self) -> dict:
+    def to_dict(self):
         return jsonable_encoder(self)

@@ -474,8 +474,8 @@ class IterationNodeStartStreamResponse(StreamResponse):
         title: str
         created_at: int
         extras: dict = {}
-        metadata: dict = {}
-        inputs: dict = {}
+        metadata: Mapping = {}
+        inputs: Mapping = {}
         parallel_id: Optional[str] = None
         parallel_start_node_id: Optional[str] = None

@@ -526,15 +526,15 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
         node_id: str
         node_type: str
         title: str
-        outputs: Optional[dict] = None
+        outputs: Optional[Mapping] = None
         created_at: int
         extras: Optional[dict] = None
-        inputs: Optional[dict] = None
+        inputs: Optional[Mapping] = None
         status: WorkflowNodeExecutionStatus
         error: Optional[str] = None
         elapsed_time: float
         total_tokens: int
-        execution_metadata: Optional[dict] = None
+        execution_metadata: Optional[Mapping] = None
         finished_at: int
         steps: int
         parallel_id: Optional[str] = None

@@ -628,7 +628,7 @@ class AppBlockingResponse(BaseModel):

     task_id: str

-    def to_dict(self) -> dict:
+    def to_dict(self):
         return jsonable_encoder(self)
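Note: several response models swap `dict` fields for `Mapping`, which still accepts plain dicts but signals that consumers should not mutate the payload. A stand-in model showing pydantic handles such fields normally:

```python
from collections.abc import Mapping
from typing import Any, Optional

from pydantic import BaseModel


class NodeData(BaseModel):
    inputs: Mapping[str, Any] = {}
    outputs: Optional[Mapping[str, Any]] = None


data = NodeData(inputs={"q": "hi"})
print(data.inputs["q"])
```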
@@ -58,7 +58,7 @@ class AnnotationReplyFeature:
                 query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]}
             )

-            if documents:
+            if documents and documents[0].metadata:
                 annotation_id = documents[0].metadata["annotation_id"]
                 score = documents[0].metadata["score"]
                 annotation = AppAnnotationService.get_annotation_by_id(annotation_id)

@@ -17,7 +17,7 @@ class RateLimit:
     _UNLIMITED_REQUEST_ID = "unlimited_request_id"
     _REQUEST_MAX_ALIVE_TIME = 10 * 60  # 10 minutes
     _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60  # recalculate request_count from request_detail every 5 minutes
-    _instance_dict = {}
+    _instance_dict: dict[str, "RateLimit"] = {}

     def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
         if client_id not in cls._instance_dict:

@@ -62,6 +62,7 @@ class BasedGenerateTaskPipeline:
         """
         logger.debug("error: %s", event.error)
         e = event.error
+        err: Exception

         if isinstance(e, InvokeAuthorizationError):
             err = InvokeAuthorizationError("Incorrect API key provided")

@@ -130,6 +131,7 @@ class BasedGenerateTaskPipeline:
                 rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config),
                 queue_manager=self._queue_manager,
             )
+        return None

     def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
         """
@@ -2,6 +2,7 @@ import json
 import logging
 import time
 from collections.abc import Generator
 from threading import Thread
+from typing import Optional, Union, cast

 from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME

@@ -103,7 +104,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
             )
         )

-        self._conversation_name_generate_thread = None
+        self._conversation_name_generate_thread: Optional[Thread] = None

     def process(
         self,

@@ -123,7 +124,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
             # start generate conversation name thread
             self._conversation_name_generate_thread = self._generate_conversation_name(
-                self._conversation, self._application_generate_entity.query
+                self._conversation, self._application_generate_entity.query or ""
             )

         generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)

@@ -146,7 +147,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
         if self._task_state.metadata:
             extras["metadata"] = self._task_state.metadata

+        response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
         if self._conversation.mode == AppMode.COMPLETION.value:
             response = CompletionAppBlockingResponse(
                 task_id=self._application_generate_entity.task_id,

@@ -154,7 +155,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                 id=self._message.id,
                 mode=self._conversation.mode,
                 message_id=self._message.id,
-                answer=self._task_state.llm_result.message.content,
+                answer=cast(str, self._task_state.llm_result.message.content),
                 created_at=int(self._message.created_at.timestamp()),
                 **extras,
             ),

@@ -167,7 +168,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                 mode=self._conversation.mode,
                 conversation_id=self._conversation.id,
                 message_id=self._message.id,
-                answer=self._task_state.llm_result.message.content,
+                answer=cast(str, self._task_state.llm_result.message.content),
                 created_at=int(self._message.created_at.timestamp()),
                 **extras,
             ),

@@ -252,7 +253,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
             yield MessageAudioEndStreamResponse(audio="", task_id=task_id)

     def _process_stream_response(
-        self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None
+        self, publisher: Optional[AppGeneratorTTSPublisher], trace_manager: Optional[TraceQueueManager] = None
     ) -> Generator[StreamResponse, None, None]:
         """
         Process stream response.

@@ -269,13 +270,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                 break
             elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
                 if isinstance(event, QueueMessageEndEvent):
-                    self._task_state.llm_result = event.llm_result
+                    if event.llm_result:
+                        self._task_state.llm_result = event.llm_result
                 else:
                     self._handle_stop(event)

                 # handle output moderation
                 output_moderation_answer = self._handle_output_moderation_when_task_finished(
-                    self._task_state.llm_result.message.content
+                    cast(str, self._task_state.llm_result.message.content)
                 )
                 if output_moderation_answer:
                     self._task_state.llm_result.message.content = output_moderation_answer

@@ -292,7 +294,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                 if annotation:
                     self._task_state.llm_result.message.content = annotation.content
             elif isinstance(event, QueueAgentThoughtEvent):
-                yield self._agent_thought_to_stream_response(event)
+                agent_thought_response = self._agent_thought_to_stream_response(event)
+                if agent_thought_response is not None:
+                    yield agent_thought_response
             elif isinstance(event, QueueMessageFileEvent):
                 response = self._message_file_to_stream_response(event)
                 if response:

@@ -307,16 +311,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                     self._task_state.llm_result.prompt_messages = chunk.prompt_messages

                 # handle output moderation chunk
-                should_direct_answer = self._handle_output_moderation_chunk(delta_text)
+                should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
                 if should_direct_answer:
                     continue

-                self._task_state.llm_result.message.content += delta_text
+                current_content = cast(str, self._task_state.llm_result.message.content)
+                current_content += cast(str, delta_text)
+                self._task_state.llm_result.message.content = current_content

                 if isinstance(event, QueueLLMChunkEvent):
-                    yield self._message_to_stream_response(delta_text, self._message.id)
+                    yield self._message_to_stream_response(cast(str, delta_text), self._message.id)
                 else:
-                    yield self._agent_message_to_stream_response(delta_text, self._message.id)
+                    yield self._agent_message_to_stream_response(cast(str, delta_text), self._message.id)
             elif isinstance(event, QueueMessageReplaceEvent):
                 yield self._message_replace_to_stream_response(answer=event.text)
             elif isinstance(event, QueuePingEvent):

@@ -336,8 +342,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         llm_result = self._task_state.llm_result
         usage = llm_result.usage

-        self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
-        self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
+        message = db.session.query(Message).filter(Message.id == self._message.id).first()
+        if not message:
+            raise Exception(f"Message {self._message.id} not found")
+        self._message = message
+        conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
+        if not conversation:
+            raise Exception(f"Conversation {self._conversation.id} not found")
+        self._conversation = conversation

         self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
             self._model_config.mode, self._task_state.llm_result.prompt_messages

@@ -346,7 +358,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         self._message.message_unit_price = usage.prompt_unit_price
         self._message.message_price_unit = usage.prompt_price_unit
         self._message.answer = (
-            PromptTemplateParser.remove_template_variables(llm_result.message.content.strip())
+            PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
             if llm_result.message.content
             else ""
         )

@@ -374,6 +386,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
             application_generate_entity=self._application_generate_entity,
             conversation=self._conversation,
             is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT}
+            and hasattr(self._application_generate_entity, "conversation_id")
             and self._application_generate_entity.conversation_id is None,
             extras=self._application_generate_entity.extras,
         )

@@ -420,7 +433,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
             extras["metadata"] = self._task_state.metadata

         return MessageEndStreamResponse(
-            task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
+            task_id=self._application_generate_entity.task_id,
+            id=self._message.id,
+            metadata=extras.get("metadata", {}),
         )

     def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:

@@ -440,7 +455,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         :param event: agent thought event
         :return:
         """
-        agent_thought: MessageAgentThought = (
+        agent_thought: Optional[MessageAgentThought] = (
             db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
         )
         db.session.refresh(agent_thought)
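Note: assigning `query(...).first()` straight to a non-Optional attribute fails under mypy; binding to a local, checking for None, then assigning keeps the attribute's type intact. Reduced sketch (a dict lookup stands in for the DB query):

```python
from typing import Optional


class Pipeline:
    def __init__(self) -> None:
        self.message: str = "m-1"

    def refresh(self, lookup: dict[str, str]) -> None:
        message: Optional[str] = lookup.get(self.message)
        if not message:
            raise Exception(f"Message {self.message} not found")
        self.message = message  # narrowed to str; attribute stays non-Optional
```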
@@ -128,7 +128,7 @@ class MessageCycleManage:
         """
         message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()

-        if message_file:
+        if message_file and message_file.url is not None:
             # get tool file id
             tool_file_id = message_file.url.split("/")[-1]
             # trim extension
@@ -93,7 +93,7 @@ class WorkflowCycleManage:
         )

         # handle special values
-        inputs = WorkflowEntry.handle_special_values(inputs)
+        inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})

         # init workflow run
         with Session(db.engine, expire_on_commit=False) as session:

@@ -192,7 +192,7 @@ class WorkflowCycleManage:
         """
         workflow_run = self._refetch_workflow_run(workflow_run.id)

-        outputs = WorkflowEntry.handle_special_values(outputs)
+        outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)

         workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
         workflow_run.outputs = json.dumps(outputs or {})

@@ -500,7 +500,7 @@ class WorkflowCycleManage:
                 id=workflow_run.id,
                 workflow_id=workflow_run.workflow_id,
                 sequence_number=workflow_run.sequence_number,
-                inputs=workflow_run.inputs_dict,
+                inputs=dict(workflow_run.inputs_dict or {}),
                 created_at=int(workflow_run.created_at.timestamp()),
             ),
         )

@@ -545,7 +545,7 @@ class WorkflowCycleManage:
                 workflow_id=workflow_run.workflow_id,
                 sequence_number=workflow_run.sequence_number,
                 status=workflow_run.status,
-                outputs=workflow_run.outputs_dict,
+                outputs=dict(workflow_run.outputs_dict) if workflow_run.outputs_dict else None,
                 error=workflow_run.error,
                 elapsed_time=workflow_run.elapsed_time,
                 total_tokens=workflow_run.total_tokens,

@@ -553,7 +553,7 @@ class WorkflowCycleManage:
                 created_by=created_by,
                 created_at=int(workflow_run.created_at.timestamp()),
                 finished_at=int(workflow_run.finished_at.timestamp()),
-                files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
+                files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)),
                 exceptions_count=workflow_run.exceptions_count,
             ),
         )

@@ -655,7 +655,7 @@ class WorkflowCycleManage:
         event: QueueNodeRetryEvent,
         task_id: str,
         workflow_node_execution: WorkflowNodeExecution,
-    ) -> Optional[NodeFinishStreamResponse]:
+    ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
         """
         Workflow node finish to stream response.
         :param event: queue node succeeded or failed event

@@ -838,7 +838,7 @@ class WorkflowCycleManage:
             ),
         )

-    def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]:
+    def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]:
         """
         Fetch files from node outputs
         :param outputs_dict: node outputs dict

@@ -851,9 +851,11 @@ class WorkflowCycleManage:
         # Remove None
         files = [file for file in files if file]
-        # Flatten list
-        files = [file for sublist in files for file in sublist]
+        # Flatten the list of sequences into a single list of mappings
+        flattened_files = [file for sublist in files if sublist for file in sublist]

-        return files
+        # Convert to tuple to match Sequence type
+        return tuple(flattened_files)

     def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
         """

@@ -891,6 +893,8 @@ class WorkflowCycleManage:
         elif isinstance(value, File):
             return value.to_dict()

+        return None
+
     def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
         """
         Refetch workflow run
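Note: the flattening rewrite above gives the result a fresh name (so its element type can differ from `files`) and returns a tuple, an immutable value that satisfies the declared `Sequence` return type. Condensed version:

```python
from collections.abc import Sequence


def flatten(groups: list[Sequence[str]]) -> Sequence[str]:
    # a new name gets a fresh inferred type; tuple() satisfies Sequence
    flattened = [item for group in groups if group for item in group]
    return tuple(flattened)


print(flatten([("a.png",), (), ("b.png", "c.png")]))
```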