feat: mypy for all type check (#10921)

Author: yihong
Date: 2024-12-24 18:38:51 +08:00
Committed by: GitHub
Parent: c91e8b1737
Commit: 56e15d09a9
584 changed files with 3975 additions and 2826 deletions
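
A minimal sketch of the recurring fix pattern in this commit, written against no Dify APIs (the function names below are illustrative only): values that may be None get explicit Optional annotations and are narrowed before use, and previously untyped empty containers get explicit element types so mypy can verify each access.

    from typing import Optional

    def first_prompt(prompts: Optional[list[str]]) -> str:
        # mypy rejects prompts[0] until the None case is handled explicitly.
        if not prompts:
            raise ValueError("prompt configuration is not set")
        return prompts[0]

    def collect_names(names: Optional[list[str]] = None) -> list[str]:
        # Empty containers get explicit annotations instead of a bare `result = []`.
        result: list[str] = []
        for name in names or []:
            result.append(name)
        return result

The same pattern appears throughout the diffs below, e.g. `if not queried_thought: raise ValueError(...)`, `tool.identity.name if tool.identity else "unknown"`, and `message_file_ids: list[str] = []`.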

View File

@@ -1,7 +1,6 @@
import json
import logging
import uuid
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Optional, Union, cast
@@ -53,6 +52,7 @@ logger = logging.getLogger(__name__)
class BaseAgentRunner(AppRunner):
def __init__(
self,
*,
tenant_id: str,
application_generate_entity: AgentChatAppGenerateEntity,
conversation: Conversation,
@@ -66,7 +66,7 @@ class BaseAgentRunner(AppRunner):
prompt_messages: Optional[list[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance | None = None,
model_instance: ModelInstance,
) -> None:
self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity
@@ -117,7 +117,7 @@ class BaseAgentRunner(AppRunner):
features = model_schema.features if model_schema and model_schema.features else []
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
self.files = application_generate_entity.files if ModelFeature.VISION in features else []
self.query = None
self.query: Optional[str] = ""
self._current_thoughts: list[PromptMessage] = []
def _repack_app_generate_entity(
@@ -145,7 +145,7 @@ class BaseAgentRunner(AppRunner):
message_tool = PromptMessageTool(
name=tool.tool_name,
description=tool_entity.description.llm,
description=tool_entity.description.llm if tool_entity.description else "",
parameters={
"type": "object",
"properties": {},
@@ -167,7 +167,7 @@ class BaseAgentRunner(AppRunner):
continue
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options]
enum = [option.value for option in parameter.options] if parameter.options else []
message_tool.parameters["properties"][parameter.name] = {
"type": parameter_type,
@@ -187,8 +187,8 @@ class BaseAgentRunner(AppRunner):
convert dataset retriever tool to prompt message tool
"""
prompt_tool = PromptMessageTool(
name=tool.identity.name,
description=tool.description.llm,
name=tool.identity.name if tool.identity else "unknown",
description=tool.description.llm if tool.description else "",
parameters={
"type": "object",
"properties": {},
@@ -210,14 +210,14 @@ class BaseAgentRunner(AppRunner):
return prompt_tool
def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
"""
Init tools
"""
tool_instances = {}
prompt_messages_tools = []
for tool in self.app_config.agent.tools if self.app_config.agent else []:
for tool in self.app_config.agent.tools or [] if self.app_config.agent else []:
try:
prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
except Exception:
@@ -234,7 +234,8 @@ class BaseAgentRunner(AppRunner):
# save prompt tool
prompt_messages_tools.append(prompt_tool)
# save tool entity
tool_instances[dataset_tool.identity.name] = dataset_tool
if dataset_tool.identity is not None:
tool_instances[dataset_tool.identity.name] = dataset_tool
return tool_instances, prompt_messages_tools
@@ -258,7 +259,7 @@ class BaseAgentRunner(AppRunner):
continue
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options]
enum = [option.value for option in parameter.options] if parameter.options else []
prompt_tool.parameters["properties"][parameter.name] = {
"type": parameter_type,
@@ -322,16 +323,21 @@ class BaseAgentRunner(AppRunner):
tool_name: str,
tool_input: Union[str, dict],
thought: str,
observation: Union[str, dict],
tool_invoke_meta: Union[str, dict],
observation: Union[str, dict, None],
tool_invoke_meta: Union[str, dict, None],
answer: str,
messages_ids: list[str],
llm_usage: LLMUsage = None,
) -> MessageAgentThought:
llm_usage: LLMUsage | None = None,
):
"""
Save agent thought
"""
agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
queried_thought = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
)
if not queried_thought:
raise ValueError(f"Agent thought {agent_thought.id} not found")
agent_thought = queried_thought
if thought is not None:
agent_thought.thought = thought
@@ -404,7 +410,7 @@ class BaseAgentRunner(AppRunner):
"""
convert tool variables to db variables
"""
db_variables = (
queried_variables = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
@@ -412,6 +418,11 @@ class BaseAgentRunner(AppRunner):
.first()
)
if not queried_variables:
return
db_variables = queried_variables
db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None)
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
db.session.commit()
@@ -421,7 +432,7 @@ class BaseAgentRunner(AppRunner):
"""
Organize agent history
"""
result = []
result: list[PromptMessage] = []
# check if there is a system message in the beginning of the conversation
for prompt_message in prompt_messages:
if isinstance(prompt_message, SystemPromptMessage):

View File

@@ -1,7 +1,7 @@
import json
from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import Optional, Union
from collections.abc import Generator, Mapping
from typing import Any, Optional
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentScratchpadUnit
@@ -12,6 +12,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
ToolPromptMessage,
UserPromptMessage,
)
@@ -26,18 +27,18 @@ from models.model import Message
class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True
_ignore_observation_providers = ["wenxin"]
_historic_prompt_messages: list[PromptMessage] = None
_agent_scratchpad: list[AgentScratchpadUnit] = None
_instruction: str = None
_query: str = None
_prompt_messages_tools: list[PromptMessage] = None
_historic_prompt_messages: list[PromptMessage] | None = None
_agent_scratchpad: list[AgentScratchpadUnit] | None = None
_instruction: str = "" # FIXME this must be str for now
_query: str | None = None
_prompt_messages_tools: list[PromptMessageTool] = []
def run(
self,
message: Message,
query: str,
inputs: dict[str, str],
) -> Union[Generator, LLMResult]:
inputs: Mapping[str, str],
) -> Generator:
"""
Run Cot agent application
"""
@@ -57,19 +58,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# init instruction
inputs = inputs or {}
instruction = app_config.prompt_template.simple_prompt_template
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction=instruction or "", inputs=inputs)
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1
# convert tools into ModelRuntime Tool format
tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
function_call_state = True
llm_usage = {"usage": None}
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
final_answer = ""
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
else:
@@ -90,7 +91,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# the last iteration, remove all tools
self._prompt_messages_tools = []
message_file_ids = []
message_file_ids: list[str] = []
agent_thought = self.create_agent_thought(
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
@@ -105,7 +106,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
chunks = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
tools=[],
@@ -115,11 +116,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
callbacks=[],
)
if not isinstance(chunks, Generator):
raise ValueError("Expected streaming response from LLM")
# check llm result
if not chunks:
raise ValueError("failed to invoke llm")
usage_dict = {}
usage_dict: dict[str, Optional[LLMUsage]] = {"usage": None}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response="",
@@ -139,25 +143,30 @@ class CotAgentRunner(BaseAgentRunner, ABC):
if isinstance(chunk, AgentScratchpadUnit.Action):
action = chunk
# detect action
scratchpad.agent_response += json.dumps(chunk.model_dump())
if scratchpad.agent_response is not None:
scratchpad.agent_response += json.dumps(chunk.model_dump())
scratchpad.action_str = json.dumps(chunk.model_dump())
scratchpad.action = action
else:
scratchpad.agent_response += chunk
scratchpad.thought += chunk
if scratchpad.agent_response is not None:
scratchpad.agent_response += chunk
if scratchpad.thought is not None:
scratchpad.thought += chunk
yield LLMResultChunk(
model=self.model_config.model,
prompt_messages=prompt_messages,
system_fingerprint="",
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
)
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
self._agent_scratchpad.append(scratchpad)
if scratchpad.thought is not None:
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
if self._agent_scratchpad is not None:
self._agent_scratchpad.append(scratchpad)
# get llm usage
if "usage" in usage_dict:
increase_usage(llm_usage, usage_dict["usage"])
if usage_dict["usage"] is not None:
increase_usage(llm_usage, usage_dict["usage"])
else:
usage_dict["usage"] = LLMUsage.empty_usage()
@@ -166,9 +175,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
tool_name=scratchpad.action.action_name if scratchpad.action else "",
tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
tool_invoke_meta={},
thought=scratchpad.thought,
thought=scratchpad.thought or "",
observation="",
answer=scratchpad.agent_response,
answer=scratchpad.agent_response or "",
messages_ids=[],
llm_usage=usage_dict["usage"],
)
@@ -209,7 +218,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
agent_thought=agent_thought,
tool_name=scratchpad.action.action_name,
tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
thought=scratchpad.thought,
thought=scratchpad.thought or "",
observation={scratchpad.action.action_name: tool_invoke_response},
tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
answer=scratchpad.agent_response,
@@ -247,8 +256,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
answer=final_answer,
messages_ids=[],
)
self.update_db_variables(self.variables_pool, self.db_variables_pool)
if self.variables_pool is not None and self.db_variables_pool is not None:
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
@@ -307,8 +316,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# publish files
for message_file_id, save_as in message_files:
if save_as:
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
if save_as is not None and self.variables_pool:
# FIXME the save_as type is confusing, it should be a string or not
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=str(save_as))
# publish message file
self.queue_manager.publish(
@@ -325,7 +335,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
"""
return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str:
"""
fill in inputs from external data tools
"""
@@ -376,11 +386,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
"""
result: list[PromptMessage] = []
scratchpads: list[AgentScratchpadUnit] = []
current_scratchpad: AgentScratchpadUnit = None
current_scratchpad: AgentScratchpadUnit | None = None
for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage):
if not current_scratchpad:
if not isinstance(message.content, str | None):
raise NotImplementedError("expected str type")
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content or "I am thinking about how to help you",
@@ -399,8 +411,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
except:
pass
elif isinstance(message, ToolPromptMessage):
if current_scratchpad:
if not current_scratchpad:
continue
if isinstance(message.content, str):
current_scratchpad.observation = message.content
else:
raise NotImplementedError("expected str type")
elif isinstance(message, UserPromptMessage):
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))

View File

@@ -19,7 +19,12 @@ class CotChatAgentRunner(CotAgentRunner):
"""
Organize system prompt
"""
if not self.app_config.agent:
raise ValueError("Agent configuration is not set")
prompt_entity = self.app_config.agent.prompt
if not prompt_entity:
raise ValueError("Agent prompt configuration is not set")
first_prompt = prompt_entity.first_prompt
system_prompt = (
@@ -75,6 +80,7 @@ class CotChatAgentRunner(CotAgentRunner):
assistant_messages = []
else:
assistant_message = AssistantPromptMessage(content="")
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
for unit in agent_scratchpad:
if unit.is_final():
assistant_message.content += f"Final Answer: {unit.agent_response}"

View File

@@ -2,7 +2,12 @@ import json
from typing import Optional
from core.agent.cot_agent_runner import CotAgentRunner
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -11,7 +16,11 @@ class CotCompletionAgentRunner(CotAgentRunner):
"""
Organize instruction prompt
"""
if self.app_config.agent is None:
raise ValueError("Agent configuration is not set")
prompt_entity = self.app_config.agent.prompt
if prompt_entity is None:
raise ValueError("prompt entity is not set")
first_prompt = prompt_entity.first_prompt
system_prompt = (
@@ -33,7 +42,13 @@ class CotCompletionAgentRunner(CotAgentRunner):
if isinstance(message, UserPromptMessage):
historic_prompt += f"Question: {message.content}\n\n"
elif isinstance(message, AssistantPromptMessage):
historic_prompt += message.content + "\n\n"
if isinstance(message.content, str):
historic_prompt += message.content + "\n\n"
elif isinstance(message.content, list):
for content in message.content:
if not isinstance(content, TextPromptMessageContent):
continue
historic_prompt += content.data
return historic_prompt
@@ -50,7 +65,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
# organize current assistant messages
agent_scratchpad = self._agent_scratchpad
assistant_prompt = ""
for unit in agent_scratchpad:
for unit in agent_scratchpad or []:
if unit.is_final():
assistant_prompt += f"Final Answer: {unit.agent_response}"
else:

View File

@@ -78,5 +78,5 @@ class AgentEntity(BaseModel):
model: str
strategy: Strategy
prompt: Optional[AgentPromptEntity] = None
tools: list[AgentToolEntity] = None
tools: list[AgentToolEntity] | None = None
max_iteration: int = 5

View File

@@ -40,6 +40,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
app_generate_entity = self.application_generate_entity
app_config = self.app_config
assert app_config is not None, "app_config is required"
assert app_config.agent is not None, "app_config.agent is required"
# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()
@@ -49,7 +51,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# continue to run until there is not any tool call
function_call_state = True
llm_usage = {"usage": None}
llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()}
final_answer = ""
# get tracing instance
@@ -75,7 +77,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# the last iteration, remove all tools
prompt_messages_tools = []
message_file_ids = []
message_file_ids: list[str] = []
agent_thought = self.create_agent_thought(
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
)
@@ -105,7 +107,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
current_llm_usage = None
if self.stream_tool_call:
if self.stream_tool_call and isinstance(chunks, Generator):
is_first_chunk = True
for chunk in chunks:
if is_first_chunk:
@@ -116,7 +118,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# check if there is any tool call
if self.check_tool_calls(chunk):
function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk))
tool_calls.extend(self.extract_tool_calls(chunk) or [])
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps(
@@ -131,19 +133,19 @@ class FunctionCallAgentRunner(BaseAgentRunner):
for content in chunk.delta.message.content:
response += content.data
else:
response += chunk.delta.message.content
response += str(chunk.delta.message.content)
if chunk.delta.usage:
increase_usage(llm_usage, chunk.delta.usage)
current_llm_usage = chunk.delta.usage
yield chunk
else:
result: LLMResult = chunks
elif not self.stream_tool_call and isinstance(chunks, LLMResult):
result = chunks
# check if there is any tool call
if self.check_blocking_tool_calls(result):
function_call_state = True
tool_calls.extend(self.extract_blocking_tool_calls(result))
tool_calls.extend(self.extract_blocking_tool_calls(result) or [])
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps(
@@ -162,7 +164,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
for content in result.message.content:
response += content.data
else:
response += result.message.content
response += str(result.message.content)
if not result.message.content:
result.message.content = ""
@@ -181,6 +183,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
usage=result.usage,
),
)
else:
raise RuntimeError(f"invalid chunks type: {type(chunks)}")
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
if tool_calls:
@@ -243,7 +247,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# publish files
for message_file_id, save_as in message_files:
if save_as:
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
if self.variables_pool:
self.variables_pool.set_file(
tool_name=tool_call_name, value=message_file_id, name=save_as
)
# publish message file
self.queue_manager.publish(
@@ -263,7 +270,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
if tool_response["tool_response"] is not None:
self._current_thoughts.append(
ToolPromptMessage(
content=tool_response["tool_response"],
content=str(tool_response["tool_response"]),
tool_call_id=tool_call_id,
name=tool_call_name,
)
@@ -273,9 +280,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# save agent thought
self.save_agent_thought(
agent_thought=agent_thought,
tool_name=None,
tool_input=None,
thought=None,
tool_name="",
tool_input="",
thought="",
tool_invoke_meta={
tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
},
@@ -283,7 +290,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_response["tool_call_name"]: tool_response["tool_response"]
for tool_response in tool_responses
},
answer=None,
answer="",
messages_ids=message_file_ids,
)
self.queue_manager.publish(
@@ -296,7 +303,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
iteration_step += 1
self.update_db_variables(self.variables_pool, self.db_variables_pool)
if self.variables_pool and self.db_variables_pool:
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
@@ -389,9 +397,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
return prompt_messages
return prompt_messages or []
def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize user query
"""
@@ -449,7 +457,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
def _organize_prompt_messages(self):
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
query_prompt_messages = self._organize_user_query(self.query, [])
query_prompt_messages = self._organize_user_query(self.query or "", [])
self.history_prompt_messages = AgentHistoryPromptTransform(
model_config=self.model_config,

View File

@@ -38,7 +38,7 @@ class CotAgentOutputParser:
except:
return json_str or ""
def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
def extra_json_from_code_block(code_block) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL)
if not code_blocks:
return
@@ -67,15 +67,15 @@ class CotAgentOutputParser:
for response in llm_response:
if response.delta.usage:
usage_dict["usage"] = response.delta.usage
response = response.delta.message.content
if not isinstance(response, str):
response_content = response.delta.message.content
if not isinstance(response_content, str):
continue
# stream
index = 0
while index < len(response):
while index < len(response_content):
steps = 1
delta = response[index : index + steps]
delta = response_content[index : index + steps]
yield_delta = False
if delta == "`":

View File

@@ -66,6 +66,8 @@ class DatasetConfigManager:
dataset_configs = config.get("dataset_configs")
else:
dataset_configs = {"retrieval_model": "multiple"}
if dataset_configs is None:
return None
query_variable = config.get("dataset_query_variable")
if dataset_configs["retrieval_model"] == "single":

View File

@@ -94,7 +94,7 @@ class ModelConfigManager:
config["model"]["completion_params"]
)
return config, ["model"]
return dict(config), ["model"]
@classmethod
def validate_model_completion_params(cls, cp: dict) -> dict:

View File

@@ -7,10 +7,10 @@ class OpeningStatementConfigManager:
:param config: model config args
"""
# opening statement
opening_statement = config.get("opening_statement")
opening_statement = config.get("opening_statement", "")
# suggested questions
suggested_questions_list = config.get("suggested_questions")
suggested_questions_list = config.get("suggested_questions", [])
return opening_statement, suggested_questions_list

View File

@@ -29,6 +29,7 @@ from factories import file_factory
from models.account import Account
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
@@ -145,7 +146,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=file_objs,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=streaming,
@@ -313,6 +314,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError("Message not exists")
# chatbot app
runner = AdvancedChatAppRunner(

View File

@@ -5,6 +5,7 @@ import queue
import re
import threading
from collections.abc import Iterable
from typing import Optional
from core.app.entities.queue_entities import (
MessageQueueMessage,
@@ -15,6 +16,7 @@ from core.app.entities.queue_entities import (
WorkflowQueueMessage,
)
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import TextPromptMessageContent
from core.model_runtime.entities.model_entities import ModelType
@@ -71,8 +73,9 @@ class AppGeneratorTTSPublisher:
if not voice or voice not in values:
self.voice = self.voices[0].get("value")
self.MAX_SENTENCE = 2
self._last_audio_event = None
self._runtime_thread = threading.Thread(target=self._runtime).start()
self._last_audio_event: Optional[AudioTrunk] = None
# FIXME better way to handle this threading.start
threading.Thread(target=self._runtime).start()
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):
@@ -92,10 +95,21 @@ class AppGeneratorTTSPublisher:
future_queue.put(futures_result)
break
elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
self.msg_text += message.event.chunk.delta.message.content
message_content = message.event.chunk.delta.message.content
if not message_content:
continue
if isinstance(message_content, str):
self.msg_text += message_content
elif isinstance(message_content, list):
for content in message_content:
if not isinstance(content, TextPromptMessageContent):
continue
self.msg_text += content.data
elif isinstance(message.event, QueueTextChunkEvent):
self.msg_text += message.event.text
elif isinstance(message.event, QueueNodeSucceededEvent):
if message.event.outputs is None:
continue
self.msg_text += message.event.outputs.get("output", "")
self.last_message = message
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
@@ -121,11 +135,10 @@ class AppGeneratorTTSPublisher:
if self._last_audio_event and self._last_audio_event.status == "finish":
if self.executor:
self.executor.shutdown(wait=False)
return self.last_message
return self._last_audio_event
audio = self._audio_queue.get_nowait()
if audio and audio.status == "finish":
self.executor.shutdown(wait=False)
self._runtime_thread = None
if audio:
self._last_audio_event = audio
return audio

View File

@@ -109,18 +109,18 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
ConversationVariable.conversation_id == self.conversation.id,
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
db_conversation_variables = session.scalars(stmt).all()
if not db_conversation_variables:
# Create conversation variables if they don't exist.
conversation_variables = [
db_conversation_variables = [
ConversationVariable.from_variable(
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
session.add_all(db_conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in conversation_variables]
conversation_variables = [item.to_variable() for item in db_conversation_variables]
session.commit()

View File

@@ -2,6 +2,7 @@ import json
import logging
import time
from collections.abc import Generator, Mapping
from threading import Thread
from typing import Any, Optional, Union
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@@ -64,6 +65,7 @@ from models.enums import CreatedByRole
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowRun,
WorkflowRunStatus,
)
@@ -81,6 +83,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_user: Union[Account, EndUser]
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
_conversation_name_generate_thread: Optional[Thread] = None
def __init__(
self,
@@ -131,7 +134,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._conversation_name_generate_thread = None
self._recorded_files: list[Mapping[str, Any]] = []
def process(self):
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
Process generate task pipeline.
:return:
@@ -262,8 +265,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:return:
"""
# init fake graph runtime state
graph_runtime_state = None
workflow_run = None
graph_runtime_state: Optional[GraphRuntimeState] = None
workflow_run: Optional[WorkflowRun] = None
for queue_message in self._queue_manager.listen():
event = queue_message.event
@@ -315,14 +318,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
response = self._workflow_node_start_to_stream_response(
response_start = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
if response_start:
yield response_start
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
@@ -330,18 +333,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if event.node_type in [NodeType.ANSWER, NodeType.END]:
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
response = self._workflow_node_finish_to_stream_response(
response_finish = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
if response_finish:
yield response_finish
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
response = self._workflow_node_finish_to_stream_response(
response_finish = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@@ -609,7 +612,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
del extras["metadata"]["annotation_reply"]
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras
task_id=self._application_generate_entity.task_id,
id=self._message.id,
files=self._recorded_files,
metadata=extras.get("metadata", {}),
)
def _handle_output_moderation_chunk(self, text: str) -> bool:

View File

@@ -61,7 +61,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
else:
config_dict = override_config_dict
config_dict = override_config_dict or {}
app_mode = AppMode.value_of(app_model.mode)
app_config = AgentChatAppConfig(

View File

@@ -23,6 +23,7 @@ from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
from models import Account, App, EndUser
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
@@ -97,7 +98,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
# get conversation
conversation = None
if args.get("conversation_id"):
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user)
# get app model config
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
@@ -153,7 +154,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=file_objs,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=streaming,
@@ -180,7 +181,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
@@ -199,8 +200,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
user=user,
stream=streaming,
)
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
# FIXME: Type hinting issue here, ignore it for now, will fix it later
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) # type: ignore
def _generate_worker(
self,
@@ -224,6 +225,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError("Message not exists")
# chatbot app
runner = AgentChatAppRunner()

View File

@@ -173,6 +173,8 @@ class AgentChatAppRunner(AppRunner):
return
agent_entity = app_config.agent
if not agent_entity:
raise ValueError("Agent entity not found")
# load tool variables
tool_conversation_variables = self._load_tool_variables(
@@ -200,14 +202,21 @@ class AgentChatAppRunner(AppRunner):
# change function call strategy based on LLM model
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if not model_schema or not model_schema.features:
raise ValueError("Model schema not found")
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
message = db.session.query(Message).filter(Message.id == message.id).first()
conversation_result = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
if conversation_result is None:
raise ValueError("Conversation not found")
message_result = db.session.query(Message).filter(Message.id == message.id).first()
if message_result is None:
raise ValueError("Message not found")
db.session.close()
runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
# check LLM mode
@@ -225,12 +234,12 @@ class AgentChatAppRunner(AppRunner):
runner = runner_cls(
tenant_id=app_config.tenant_id,
application_generate_entity=application_generate_entity,
conversation=conversation,
conversation=conversation_result,
app_config=app_config,
model_config=application_generate_entity.model_conf,
config=agent_entity,
queue_manager=queue_manager,
message=message,
message=message_result,
user_id=application_generate_entity.user_id,
memory=memory,
prompt_messages=prompt_message,
@@ -257,7 +266,7 @@ class AgentChatAppRunner(AppRunner):
"""
load tool variables from database
"""
tool_variables: ToolConversationVariables = (
tool_variables: ToolConversationVariables | None = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == conversation_id,

View File

@@ -16,7 +16,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -51,8 +51,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
def convert_stream_full_response( # type: ignore[override]
cls,
stream_response: Generator[ChatbotAppStreamResponse, None, None],
) -> Generator[str, None, None]:
"""
Convert stream full response.
@@ -82,8 +83,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk)
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
def convert_stream_simple_response( # type: ignore[override]
cls,
stream_response: Generator[ChatbotAppStreamResponse, None, None],
) -> Generator[str, None, None]:
"""
Convert stream simple response.

View File

@@ -50,7 +50,7 @@ class AppQueueManager:
# wait for APP_MAX_EXECUTION_TIME seconds to stop listen
listen_timeout = dify_config.APP_MAX_EXECUTION_TIME
start_time = time.time()
last_ping_time = 0
last_ping_time: int | float = 0
while True:
try:
message = self._q.get(timeout=1)

View File

@@ -1,5 +1,5 @@
import time
from collections.abc import Generator, Mapping
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
@@ -36,8 +36,8 @@ class AppRunner:
app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["File"],
inputs: Mapping[str, str],
files: Sequence["File"],
query: Optional[str] = None,
) -> int:
"""
@@ -64,7 +64,7 @@ class AppRunner:
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
if model_context_tokens is None:
@@ -85,7 +85,7 @@ class AppRunner:
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
raise InvokeBadRequestError(
"Query or prefix prompt is too long, you can reduce the prefix prompt, "
@@ -111,7 +111,7 @@ class AppRunner:
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
if model_context_tokens is None:
@@ -136,8 +136,8 @@ class AppRunner:
app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["File"],
inputs: Mapping[str, str],
files: Sequence["File"],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None,
@@ -156,6 +156,7 @@ class AppRunner:
"""
# get prompt without memory and context
if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
prompt_transform: Union[SimplePromptTransform, AdvancedPromptTransform]
prompt_transform = SimplePromptTransform()
prompt_messages, stop = prompt_transform.get_prompt(
app_mode=AppMode.value_of(app_record.mode),
@@ -171,8 +172,11 @@ class AppRunner:
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
model_mode = ModelMode.value_of(model_config.mode)
prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
if model_mode == ModelMode.COMPLETION:
advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
if not advanced_completion_prompt_template:
raise InvokeBadRequestError("Advanced completion prompt template is required.")
prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt)
if advanced_completion_prompt_template.role_prefix:
@@ -181,6 +185,8 @@ class AppRunner:
assistant=advanced_completion_prompt_template.role_prefix.assistant,
)
else:
if not prompt_template_entity.advanced_chat_prompt_template:
raise InvokeBadRequestError("Advanced chat prompt template is required.")
prompt_template = []
for message in prompt_template_entity.advanced_chat_prompt_template.messages:
prompt_template.append(ChatModelMessage(text=message.text, role=message.role))
@@ -246,7 +252,7 @@ class AppRunner:
def _handle_invoke_result(
self,
invoke_result: Union[LLMResult, Generator],
invoke_result: Union[LLMResult, Generator[Any, None, None]],
queue_manager: AppQueueManager,
stream: bool,
agent: bool = False,
@@ -259,10 +265,12 @@ class AppRunner:
:param agent: agent
:return:
"""
if not stream:
if not stream and isinstance(invoke_result, LLMResult):
self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
else:
elif stream and isinstance(invoke_result, Generator):
self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
else:
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
def _handle_invoke_result_direct(
self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
@@ -291,8 +299,8 @@ class AppRunner:
:param agent: agent
:return:
"""
model = None
prompt_messages = []
model: str = ""
prompt_messages: list[PromptMessage] = []
text = ""
usage = None
for result in invoke_result:
@@ -328,13 +336,14 @@ class AppRunner:
def moderation_for_inputs(
self,
*,
app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
query: str | None = None,
message_id: str,
) -> tuple[bool, dict, str]:
) -> tuple[bool, Mapping[str, Any], str]:
"""
Process sensitive_word_avoidance.
:param app_id: app id
@@ -350,7 +359,7 @@ class AppRunner:
app_id=app_id,
tenant_id=tenant_id,
app_config=app_generate_entity.app_config,
inputs=inputs,
inputs=dict(inputs),
query=query or "",
message_id=message_id,
trace_manager=app_generate_entity.trace_manager,
@@ -390,9 +399,9 @@ class AppRunner:
tenant_id: str,
app_id: str,
external_data_tools: list[ExternalDataVariableEntity],
inputs: dict,
inputs: Mapping[str, Any],
query: str,
) -> dict:
) -> Mapping[str, Any]:
"""
Fill in variable inputs from external data tools if exists.

View File

@@ -24,6 +24,7 @@ from extensions.ext_database import db
from factories import file_factory
from models.account import Account
from models.model import App, EndUser
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
@@ -91,7 +92,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# get conversation
conversation = None
if args.get("conversation_id"):
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user)
# get app model config
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
@@ -104,7 +105,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# validate config
override_model_config_dict = ChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=args.get("model_config")
tenant_id=app_model.tenant_id, config=args.get("model_config", {})
)
# always enable retriever resource in debugger mode
@@ -146,7 +147,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=file_objs,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
invoke_from=invoke_from,
@@ -172,7 +173,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
@@ -216,6 +217,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError("Message not exists")
# chatbot app
runner = ChatAppRunner()

View File

@@ -16,7 +16,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -52,7 +52,8 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
cls,
stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream full response.
@@ -83,7 +84,8 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
cls,
stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream simple response.

View File

@@ -42,7 +42,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
else:
config_dict = override_config_dict
config_dict = override_config_dict or {}
app_mode = AppMode.value_of(app_model.mode)
app_config = CompletionAppConfig(

View File

@@ -83,8 +83,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
query = query.replace("\x00", "")
inputs = args["inputs"]
extras = {}
# get conversation
conversation = None
@@ -99,7 +97,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
# validate config
override_model_config_dict = CompletionAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=args.get("model_config")
tenant_id=app_model.tenant_id, config=args.get("model_config", {})
)
# parse files
@@ -132,11 +130,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=file_objs,
files=list(file_objs),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
extras=extras,
extras={},
trace_manager=trace_manager,
)
@@ -157,7 +155,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"message_id": message.id,
@@ -197,6 +195,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
try:
# get message
message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError()
# chatbot app
runner = CompletionAppRunner()
@@ -231,7 +231,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[str, None, None]]:
) -> Union[Mapping[str, Any], Generator[str, None, None]]:
"""
Generate App response.
@@ -293,7 +293,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
model_conf=ModelConfigConverter.convert(app_config),
inputs=message.inputs,
query=message.query,
files=file_objs,
files=list(file_objs),
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
@@ -317,7 +317,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"message_id": message.id,

View File

@@ -76,7 +76,7 @@ class CompletionAppRunner(AppRunner):
tenant_id=app_config.tenant_id,
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
query=query or "",
message_id=message.id,
)
except ModerationError as e:
@@ -122,7 +122,7 @@ class CompletionAppRunner(AppRunner):
tenant_id=app_record.tenant_id,
model_config=application_generate_entity.model_conf,
config=dataset_config,
query=query,
query=query or "",
invoke_from=application_generate_entity.invoke_from,
show_retrieve_source=app_config.additional_features.show_retrieve_source,
hit_callback=hit_callback,

View File

@@ -16,7 +16,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = CompletionAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict:
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict:
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -51,7 +51,8 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
cls,
stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream full response.
@@ -81,7 +82,8 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
cls,
stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream simple response.

View File

@@ -2,11 +2,11 @@ import json
import logging
from collections.abc import Generator
from datetime import UTC, datetime
from typing import Optional, Union
from typing import Optional, Union, cast
from sqlalchemy import and_
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import (
@@ -42,7 +42,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity,
],
queue_manager: AppQueueManager,
conversation: Conversation,
@@ -144,7 +144,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:conversation conversation
:return:
"""
app_config = application_generate_entity.app_config
app_config: EasyUIBasedAppConfig = cast(EasyUIBasedAppConfig, application_generate_entity.app_config)
# get from source
end_user_id = None
@@ -267,7 +267,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
except KeyError:
pass
return introduction
return introduction or ""
def _get_conversation(self, conversation_id: str):
"""
@@ -282,7 +282,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
return conversation
def _get_message(self, message_id: str) -> Message:
def _get_message(self, message_id: str) -> Optional[Message]:
"""
Get message by message id
:param message_id: message id

View File

@@ -116,7 +116,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
files=system_files,
files=list(system_files),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,

View File

@@ -17,16 +17,16 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
return blocking_response.to_dict()
return dict(blocking_response.to_dict())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -36,7 +36,8 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
cls,
stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream full response.
@@ -65,7 +66,8 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
cls,
stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream simple response.

View File

@@ -24,6 +24,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
@@ -190,16 +191,15 @@ class WorkflowBasedAppRunner(AppRunner):
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, NodeRunRetryEvent):
node_run_result = event.route_node_state.node_run_result
inputs: Mapping[str, Any] | None = {}
process_data: Mapping[str, Any] | None = {}
outputs: Mapping[str, Any] | None = {}
execution_metadata: Mapping[NodeRunMetadataKey, Any] | None = {}
if node_run_result:
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
else:
inputs = {}
process_data = {}
outputs = {}
execution_metadata = {}
self._publish_event(
QueueNodeRetryEvent(
node_execution_id=event.id,
@@ -289,7 +289,7 @@ class WorkflowBasedAppRunner(AppRunner):
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
@@ -349,7 +349,7 @@ class WorkflowBasedAppRunner(AppRunner):
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
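The retry branch above now annotates its locals up front (Mapping[...] | None) so that the empty-dict defaults and the values read from node_run_result share one declared type, which removes the need for the old else branch. A condensed sketch of that pattern with a simplified result object (the class here is a stand-in, not the real entity):

from collections.abc import Mapping
from typing import Any, Optional


class NodeRunResult:
    def __init__(self, inputs: Mapping[str, Any], outputs: Mapping[str, Any]) -> None:
        self.inputs = inputs
        self.outputs = outputs


def collect_run_data(result: Optional[NodeRunResult]) -> tuple[Mapping[str, Any], Mapping[str, Any]]:
    # Declaring the type once lets both the {} defaults and the Mapping values
    # from the result satisfy the same annotation without an else branch.
    inputs: Mapping[str, Any] | None = {}
    outputs: Mapping[str, Any] | None = {}
    if result:
        inputs = result.inputs
        outputs = result.outputs
    return inputs or {}, outputs or {}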

View File

@@ -5,7 +5,7 @@ from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
from constants import UUID_NIL
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
from core.file import File, FileUploadConfig
from core.model_runtime.entities.model_entities import AIModelEntity
@@ -79,7 +79,7 @@ class AppGenerateEntity(BaseModel):
task_id: str
# app config
app_config: AppConfig
app_config: Any
file_upload_config: Optional[FileUploadConfig] = None
inputs: Mapping[str, Any]

View File

@@ -308,7 +308,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
error: Optional[str] = None
"""single iteration duration map"""

View File

@@ -70,7 +70,7 @@ class StreamResponse(BaseModel):
event: StreamEvent
task_id: str
def to_dict(self) -> dict:
def to_dict(self):
return jsonable_encoder(self)
@@ -474,8 +474,8 @@ class IterationNodeStartStreamResponse(StreamResponse):
title: str
created_at: int
extras: dict = {}
metadata: dict = {}
inputs: dict = {}
metadata: Mapping = {}
inputs: Mapping = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
@@ -526,15 +526,15 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
node_id: str
node_type: str
title: str
outputs: Optional[dict] = None
outputs: Optional[Mapping] = None
created_at: int
extras: Optional[dict] = None
inputs: Optional[dict] = None
inputs: Optional[Mapping] = None
status: WorkflowNodeExecutionStatus
error: Optional[str] = None
elapsed_time: float
total_tokens: int
execution_metadata: Optional[dict] = None
execution_metadata: Optional[Mapping] = None
finished_at: int
steps: int
parallel_id: Optional[str] = None
@@ -628,7 +628,7 @@ class AppBlockingResponse(BaseModel):
task_id: str
def to_dict(self) -> dict:
def to_dict(self):
return jsonable_encoder(self)

View File

@@ -58,7 +58,7 @@ class AnnotationReplyFeature:
query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]}
)
if documents:
if documents and documents[0].metadata:
annotation_id = documents[0].metadata["annotation_id"]
score = documents[0].metadata["score"]
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)

View File

@@ -17,7 +17,7 @@ class RateLimit:
_UNLIMITED_REQUEST_ID = "unlimited_request_id"
_REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
_instance_dict = {}
_instance_dict: dict[str, "RateLimit"] = {}
def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
if client_id not in cls._instance_dict:
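Annotating _instance_dict pins the registry to str keys and RateLimit values instead of the dict[Any, Any] mypy would otherwise infer for an empty literal. A stripped-down sketch of the same per-client singleton registry (simplified, without the quota fields):

class RateLimiter:
    # registry of one instance per client id; the forward reference names the value type
    _instances: dict[str, "RateLimiter"] = {}

    def __new__(cls, client_id: str) -> "RateLimiter":
        if client_id not in cls._instances:
            cls._instances[client_id] = super().__new__(cls)
        return cls._instances[client_id]

    def __init__(self, client_id: str) -> None:
        self.client_id = client_id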

View File

@@ -62,6 +62,7 @@ class BasedGenerateTaskPipeline:
"""
logger.debug("error: %s", event.error)
e = event.error
err: Exception
if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided")
@@ -130,6 +131,7 @@ class BasedGenerateTaskPipeline:
rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config),
queue_manager=self._queue_manager,
)
return None
def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
"""

View File

@@ -2,6 +2,7 @@ import json
import logging
import time
from collections.abc import Generator
from threading import Thread
from typing import Optional, Union, cast
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@@ -103,7 +104,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
)
)
self._conversation_name_generate_thread = None
self._conversation_name_generate_thread: Optional[Thread] = None
def process(
self,
@@ -123,7 +124,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation, self._application_generate_entity.query
self._conversation, self._application_generate_entity.query or ""
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
@@ -146,7 +147,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation.mode == AppMode.COMPLETION.value:
response = CompletionAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
@@ -154,7 +155,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
id=self._message.id,
mode=self._conversation.mode,
message_id=self._message.id,
answer=self._task_state.llm_result.message.content,
answer=cast(str, self._task_state.llm_result.message.content),
created_at=int(self._message.created_at.timestamp()),
**extras,
),
@@ -167,7 +168,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
mode=self._conversation.mode,
conversation_id=self._conversation.id,
message_id=self._message.id,
answer=self._task_state.llm_result.message.content,
answer=cast(str, self._task_state.llm_result.message.content),
created_at=int(self._message.created_at.timestamp()),
**extras,
),
@@ -252,7 +253,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None
self, publisher: Optional[AppGeneratorTTSPublisher], trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
@@ -269,13 +270,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
break
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
if isinstance(event, QueueMessageEndEvent):
self._task_state.llm_result = event.llm_result
if event.llm_result:
self._task_state.llm_result = event.llm_result
else:
self._handle_stop(event)
# handle output moderation
output_moderation_answer = self._handle_output_moderation_when_task_finished(
self._task_state.llm_result.message.content
cast(str, self._task_state.llm_result.message.content)
)
if output_moderation_answer:
self._task_state.llm_result.message.content = output_moderation_answer
@@ -292,7 +294,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if annotation:
self._task_state.llm_result.message.content = annotation.content
elif isinstance(event, QueueAgentThoughtEvent):
yield self._agent_thought_to_stream_response(event)
agent_thought_response = self._agent_thought_to_stream_response(event)
if agent_thought_response is not None:
yield agent_thought_response
elif isinstance(event, QueueMessageFileEvent):
response = self._message_file_to_stream_response(event)
if response:
@@ -307,16 +311,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(delta_text)
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
if should_direct_answer:
continue
self._task_state.llm_result.message.content += delta_text
current_content = cast(str, self._task_state.llm_result.message.content)
current_content += cast(str, delta_text)
self._task_state.llm_result.message.content = current_content
if isinstance(event, QueueLLMChunkEvent):
yield self._message_to_stream_response(delta_text, self._message.id)
yield self._message_to_stream_response(cast(str, delta_text), self._message.id)
else:
yield self._agent_message_to_stream_response(delta_text, self._message.id)
yield self._agent_message_to_stream_response(cast(str, delta_text), self._message.id)
elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
@@ -336,8 +342,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
llm_result = self._task_state.llm_result
usage = llm_result.usage
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
message = db.session.query(Message).filter(Message.id == self._message.id).first()
if not message:
raise Exception(f"Message {self._message.id} not found")
self._message = message
conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
if not conversation:
raise Exception(f"Conversation {self._conversation.id} not found")
self._conversation = conversation
self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
self._model_config.mode, self._task_state.llm_result.prompt_messages
@@ -346,7 +358,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit
self._message.answer = (
PromptTemplateParser.remove_template_variables(llm_result.message.content.strip())
PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
if llm_result.message.content
else ""
)
@@ -374,6 +386,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT}
and hasattr(self._application_generate_entity, "conversation_id")
and self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras,
)
@@ -420,7 +433,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
extras["metadata"] = self._task_state.metadata
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
task_id=self._application_generate_entity.task_id,
id=self._message.id,
metadata=extras.get("metadata", {}),
)
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
@@ -440,7 +455,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
:param event: agent thought event
:return:
"""
agent_thought: MessageAgentThought = (
agent_thought: Optional[MessageAgentThought] = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
)
db.session.refresh(agent_thought)
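Most of the cast(str, ...) calls above exist because prompt message content is a union type (plain text or a list of multimodal content parts), so mypy cannot assume str at these call sites. A compact sketch of why the cast is needed (the content types here are simplified placeholders):

from typing import Union, cast


class TextPromptContent:
    def __init__(self, text: str) -> None:
        self.text = text


MessageContent = Union[str, list[TextPromptContent]]


def as_plain_text(content: MessageContent) -> str:
    # cast() does nothing at runtime; it only narrows the static type, so it is
    # appropriate where the caller knows only the str member can occur.
    return cast(str, content).strip()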

View File

@@ -128,7 +128,7 @@ class MessageCycleManage:
"""
message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()
if message_file:
if message_file and message_file.url is not None:
# get tool file id
tool_file_id = message_file.url.split("/")[-1]
# trim extension
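The added url check is a plain Optional narrowing: once message_file.url is known to be non-None inside the branch, mypy allows the .split() call. A tiny sketch of the same guard:

from typing import Optional


def tool_file_id_from_url(url: Optional[str]) -> Optional[str]:
    # inside the branch mypy narrows url from Optional[str] to str
    if url is not None:
        return url.split("/")[-1]
    return None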

View File

@@ -93,7 +93,7 @@ class WorkflowCycleManage:
)
# handle special values
inputs = WorkflowEntry.handle_special_values(inputs)
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
# init workflow run
with Session(db.engine, expire_on_commit=False) as session:
@@ -192,7 +192,7 @@ class WorkflowCycleManage:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)
outputs = WorkflowEntry.handle_special_values(outputs)
outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
workflow_run.outputs = json.dumps(outputs or {})
@@ -500,7 +500,7 @@ class WorkflowCycleManage:
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
inputs=workflow_run.inputs_dict,
inputs=dict(workflow_run.inputs_dict or {}),
created_at=int(workflow_run.created_at.timestamp()),
),
)
@@ -545,7 +545,7 @@ class WorkflowCycleManage:
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
status=workflow_run.status,
outputs=workflow_run.outputs_dict,
outputs=dict(workflow_run.outputs_dict) if workflow_run.outputs_dict else None,
error=workflow_run.error,
elapsed_time=workflow_run.elapsed_time,
total_tokens=workflow_run.total_tokens,
@@ -553,7 +553,7 @@ class WorkflowCycleManage:
created_by=created_by,
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)),
exceptions_count=workflow_run.exceptions_count,
),
)
@@ -655,7 +655,7 @@ class WorkflowCycleManage:
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
"""
Workflow node finish to stream response.
:param event: queue node succeeded or failed event
@@ -838,7 +838,7 @@ class WorkflowCycleManage:
),
)
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]:
def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from node outputs
:param outputs_dict: node outputs dict
@@ -851,9 +851,11 @@ class WorkflowCycleManage:
# Remove None
files = [file for file in files if file]
# Flatten list
files = [file for sublist in files for file in sublist]
# Flatten the list of sequences into a single list of mappings
flattened_files = [file for sublist in files if sublist for file in sublist]
return files
# Convert to tuple to match Sequence type
return tuple(flattened_files)
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
"""
@@ -891,6 +893,8 @@ class WorkflowCycleManage:
elif isinstance(value, File):
return value.to_dict()
return None
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
"""
Refetch workflow run
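Several signatures above move to read-only Mapping parameters and Sequence returns, converting at the boundaries: dict(...) where mutation or JSON serialization needs a concrete dict, tuple(...) to return an immutable Sequence. A small sketch of the flattening helper's shape (the field handling is illustrative):

from collections.abc import Mapping, Sequence
from typing import Any


def fetch_files_from_node_outputs(outputs: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]:
    # collect per-variable file lists, then flatten them into one sequence
    file_lists = [value for value in outputs.values() if isinstance(value, list)]
    flattened = [item for sublist in file_lists for item in sublist if isinstance(item, Mapping)]
    # a tuple is a valid, immutable Sequence return value
    return tuple(flattened)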

View File

@@ -57,7 +57,7 @@ class DifyAgentCallbackHandler(BaseModel):
self,
tool_name: str,
tool_inputs: Mapping[str, Any],
tool_outputs: Sequence[ToolInvokeMessage],
tool_outputs: Sequence[ToolInvokeMessage] | str,
message_id: Optional[str] = None,
timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None,

View File

@@ -40,17 +40,18 @@ class DatasetIndexToolCallbackHandler:
def on_tool_end(self, documents: list[Document]) -> None:
"""Handle tool end."""
for document in documents:
query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
if document.metadata is not None:
query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
# add hit count to document segment
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
db.session.commit()
db.session.commit()
def return_retriever_resource_info(self, resource: list):
"""Handle return_retriever_resource_info."""

View File

@@ -1,3 +1,4 @@
from collections.abc import Sequence
from enum import Enum
from typing import Optional
@@ -72,7 +73,7 @@ class DefaultModelProviderEntity(BaseModel):
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
supported_model_types: list[ModelType]
supported_model_types: Sequence[ModelType] = []
class DefaultModelEntity(BaseModel):

View File

@@ -40,7 +40,7 @@ from models.provider import (
logger = logging.getLogger(__name__)
original_provider_configurate_methods = {}
original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {}
class ProviderConfiguration(BaseModel):
@@ -99,7 +99,8 @@ class ProviderConfiguration(BaseModel):
continue
restrict_models = quota_configuration.restrict_models
if self.system_configuration.credentials is None:
return None
copy_credentials = self.system_configuration.credentials.copy()
if restrict_models:
for restrict_model in restrict_models:
@@ -124,7 +125,7 @@ class ProviderConfiguration(BaseModel):
return credentials
def get_system_configuration_status(self) -> SystemConfigurationStatus:
def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]:
"""
Get system configuration status.
:return:
@@ -136,6 +137,8 @@ class ProviderConfiguration(BaseModel):
current_quota_configuration = next(
(q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
)
if current_quota_configuration is None:
return None
return (
SystemConfigurationStatus.ACTIVE
@@ -150,7 +153,7 @@ class ProviderConfiguration(BaseModel):
"""
return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
def get_custom_credentials(self, obfuscated: bool = False):
"""
Get custom credentials.
@@ -172,7 +175,7 @@ class ProviderConfiguration(BaseModel):
else [],
)
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]:
"""
Validate custom credentials.
:param credentials: provider credentials
@@ -324,7 +327,7 @@ class ProviderConfiguration(BaseModel):
def custom_model_credentials_validate(
self, model_type: ModelType, model: str, credentials: dict
) -> tuple[ProviderModel, dict]:
) -> tuple[Optional[ProviderModel], dict]:
"""
Validate custom model credentials.
@@ -740,10 +743,10 @@ class ProviderConfiguration(BaseModel):
if model_type:
model_types.append(model_type)
else:
model_types = provider_instance.get_provider_schema().supported_model_types
model_types = list(provider_instance.get_provider_schema().supported_model_types)
# Group model settings by model type and model
model_setting_map = defaultdict(dict)
model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
for model_setting in self.model_settings:
model_setting_map[model_setting.model_type][model_setting.model] = model_setting
@@ -822,54 +825,57 @@ class ProviderConfiguration(BaseModel):
]:
# only customizable model
for restrict_model in restrict_models:
copy_credentials = self.system_configuration.credentials.copy()
if restrict_model.base_model_name:
copy_credentials["base_model_name"] = restrict_model.base_model_name
if self.system_configuration.credentials is not None:
copy_credentials = self.system_configuration.credentials.copy()
if restrict_model.base_model_name:
copy_credentials["base_model_name"] = restrict_model.base_model_name
try:
custom_model_schema = provider_instance.get_model_instance(
restrict_model.model_type
).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
except Exception as ex:
logger.warning(f"get custom model schema failed, {ex}")
continue
try:
custom_model_schema = provider_instance.get_model_instance(
restrict_model.model_type
).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
except Exception as ex:
logger.warning(f"get custom model schema failed, {ex}")
continue
if not custom_model_schema:
continue
if not custom_model_schema:
continue
if custom_model_schema.model_type not in model_types:
continue
if custom_model_schema.model_type not in model_types:
continue
status = ModelStatus.ACTIVE
if (
custom_model_schema.model_type in model_setting_map
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
):
model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
if model_setting.enabled is False:
status = ModelStatus.DISABLED
status = ModelStatus.ACTIVE
if (
custom_model_schema.model_type in model_setting_map
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
):
model_setting = model_setting_map[custom_model_schema.model_type][
custom_model_schema.model
]
if model_setting.enabled is False:
status = ModelStatus.DISABLED
provider_models.append(
ModelWithProviderEntity(
model=custom_model_schema.model,
label=custom_model_schema.label,
model_type=custom_model_schema.model_type,
features=custom_model_schema.features,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=status,
provider_models.append(
ModelWithProviderEntity(
model=custom_model_schema.model,
label=custom_model_schema.label,
model_type=custom_model_schema.model_type,
features=custom_model_schema.features,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=status,
)
)
)
# if llm name not in restricted llm list, remove it
restrict_model_names = [rm.model for rm in restrict_models]
for m in provider_models:
if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
m.status = ModelStatus.NO_PERMISSION
for model in provider_models:
if model.model_type == ModelType.LLM and model.model not in restrict_model_names:
model.status = ModelStatus.NO_PERMISSION
elif not quota_configuration.is_valid:
m.status = ModelStatus.QUOTA_EXCEEDED
model.status = ModelStatus.QUOTA_EXCEEDED
return provider_models
@@ -1043,7 +1049,7 @@ class ProviderConfigurations(BaseModel):
return iter(self.configurations)
def values(self) -> Iterator[ProviderConfiguration]:
return self.configurations.values()
return iter(self.configurations.values())
def get(self, key, default=None):
return self.configurations.get(key, default)
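The defaultdict above needed an explicit annotation because defaultdict(dict) alone infers as defaultdict[Any, dict[Any, Any]]. A reduced sketch of that grouping step (the two classes are placeholders for the real entities):

from collections import defaultdict
from enum import Enum


class ModelType(Enum):
    LLM = "llm"
    TEXT_EMBEDDING = "text-embedding"


class ModelSettings:
    def __init__(self, model: str, model_type: ModelType, enabled: bool = True) -> None:
        self.model = model
        self.model_type = model_type
        self.enabled = enabled


def group_model_settings(settings: list[ModelSettings]) -> dict[ModelType, dict[str, ModelSettings]]:
    # the annotation fixes key and value types that mypy cannot infer from defaultdict(dict)
    grouped: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
    for setting in settings:
        grouped[setting.model_type][setting.model] = setting
    return dict(grouped)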

View File

@@ -1,3 +1,5 @@
from typing import cast
import requests
from configs import dify_config
@@ -5,7 +7,7 @@ from models.api_based_extension import APIBasedExtensionPoint
class APIBasedExtensionRequestor:
timeout: (int, int) = (5, 60)
timeout: tuple[int, int] = (5, 60)
"""timeout for request connect and read"""
def __init__(self, api_endpoint: str, api_key: str) -> None:
@@ -51,4 +53,4 @@ class APIBasedExtensionRequestor:
"request error, status_code: {}, content: {}".format(response.status_code, response.text[:100])
)
return response.json()
return cast(dict, response.json())
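Two fixes above are common mypy cleanups: (int, int) is a runtime tuple of types rather than a valid annotation, and requests' response.json() is typed as Any. A sketch of both in one place (endpoint handling is illustrative; only the requests call itself is assumed):

from typing import cast

import requests


class APIExtensionRequestor:
    # (connect timeout, read timeout) in seconds
    timeout: tuple[int, int] = (5, 60)

    def __init__(self, api_endpoint: str, api_key: str) -> None:
        self.api_endpoint = api_endpoint
        self.api_key = api_key

    def fetch(self) -> dict:
        response = requests.get(
            self.api_endpoint,
            headers={"Authorization": f"Bearer {self.api_key}"},
            timeout=self.timeout,
        )
        # response.json() returns Any; the cast documents the expected dict shape
        return cast(dict, response.json())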

View File

@@ -38,8 +38,8 @@ class Extensible:
@classmethod
def scan_extensions(cls):
extensions: list[ModuleExtension] = []
position_map = {}
extensions = []
position_map: dict[str, int] = {}
# get the path of the current class
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
@@ -58,7 +58,8 @@ class Extensible:
# is builtin extension, builtin extension
# in the front-end page and business logic, there are special treatments.
builtin = False
position = None
# default position is 0; it cannot be None for sort_to_dict_by_position_map
position = 0
if "__builtin__" in file_names:
builtin = True
@@ -89,7 +90,7 @@ class Extensible:
logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
continue
json_data = {}
json_data: dict[str, Any] = {}
if not builtin:
if "schema.json" not in file_names:
logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")

View File

@@ -1,4 +1,6 @@
from core.extension.extensible import ExtensionModule, ModuleExtension
from typing import cast
from core.extension.extensible import Extensible, ExtensionModule, ModuleExtension
from core.external_data_tool.base import ExternalDataTool
from core.moderation.base import Moderation
@@ -10,7 +12,8 @@ class Extension:
def init(self):
for module, module_class in self.module_classes.items():
self.__module_extensions[module.value] = module_class.scan_extensions()
m = cast(Extensible, module_class)
self.__module_extensions[module.value] = m.scan_extensions()
def module_extensions(self, module: str) -> list[ModuleExtension]:
module_extensions = self.__module_extensions.get(module)
@@ -35,7 +38,8 @@ class Extension:
def extension_class(self, module: ExtensionModule, extension_name: str) -> type:
module_extension = self.module_extension(module, extension_name)
return module_extension.extension_class
t: type = module_extension.extension_class
return t
def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None:
module_extension = self.module_extension(module, extension_name)

View File

@@ -48,7 +48,10 @@ class ApiExternalDataTool(ExternalDataTool):
:return: the tool query result
"""
# get params from config
if not self.config:
raise ValueError("config is required, config: {}".format(self.config))
api_based_extension_id = self.config.get("api_based_extension_id")
assert api_based_extension_id is not None, "api_based_extension_id is required"
# get api_based_extension
api_based_extension = (

View File

@@ -1,7 +1,7 @@
import concurrent
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from collections.abc import Mapping
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from typing import Any, Optional
from flask import Flask, current_app
@@ -17,9 +17,9 @@ class ExternalDataFetch:
tenant_id: str,
app_id: str,
external_data_tools: list[ExternalDataVariableEntity],
inputs: dict,
inputs: Mapping[str, Any],
query: str,
) -> dict:
) -> Mapping[str, Any]:
"""
Fill in variable inputs from external data tools if exists.
@@ -30,13 +30,14 @@ class ExternalDataFetch:
:param query: the query
:return: the filled inputs
"""
results = {}
results: dict[str, Any] = {}
inputs = dict(inputs)
with ThreadPoolExecutor() as executor:
futures = {}
for tool in external_data_tools:
future = executor.submit(
future: Future[tuple[str | None, str | None]] = executor.submit(
self._query_external_data_tool,
current_app._get_current_object(),
current_app._get_current_object(), # type: ignore
tenant_id,
app_id,
tool,
@@ -46,9 +47,10 @@ class ExternalDataFetch:
futures[future] = tool
for future in concurrent.futures.as_completed(futures):
for future in as_completed(futures):
tool_variable, result = future.result()
results[tool_variable] = result
if tool_variable is not None:
results[tool_variable] = result
inputs.update(results)
return inputs
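The hunk above annotates each submitted Future with the worker's return type and guards against None keys before writing results. A minimal sketch of the same executor pattern (the worker function is illustrative):

from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from typing import Any, Optional


def query_external_tool(name: str) -> tuple[Optional[str], Optional[str]]:
    # illustrative worker: returns (variable name, value)
    return name, f"value-for-{name}"


def fetch_all(tool_names: list[str]) -> dict[str, Any]:
    results: dict[str, Any] = {}
    with ThreadPoolExecutor() as executor:
        futures: dict[Future[tuple[Optional[str], Optional[str]]], str] = {}
        for name in tool_names:
            future: Future[tuple[Optional[str], Optional[str]]] = executor.submit(query_external_tool, name)
            futures[future] = name
        for future in as_completed(futures):
            variable, value = future.result()
            if variable is not None:
                # the guard keeps the result keys typed as str, not Optional[str]
                results[variable] = value
    return results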
@@ -59,7 +61,7 @@ class ExternalDataFetch:
tenant_id: str,
app_id: str,
external_data_tool: ExternalDataVariableEntity,
inputs: dict,
inputs: Mapping[str, Any],
query: str,
) -> tuple[Optional[str], Optional[str]]:
"""

View File

@@ -1,4 +1,5 @@
from typing import Optional
from collections.abc import Mapping
from typing import Any, Optional, cast
from core.extension.extensible import ExtensionModule
from extensions.ext_code_based_extension import code_based_extension
@@ -23,9 +24,10 @@ class ExternalDataToolFactory:
"""
code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config)
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
extension_class.validate_config(tenant_id, config)
# FIXME mypy issue here, figure out how to fix it
extension_class.validate_config(tenant_id, config) # type: ignore
def query(self, inputs: dict, query: Optional[str] = None) -> str:
def query(self, inputs: Mapping[str, Any], query: Optional[str] = None) -> str:
"""
Query the external data tool.
@@ -33,4 +35,4 @@ class ExternalDataToolFactory:
:param query: the query of chat app
:return: the tool query result
"""
return self.__extension_instance.query(inputs, query)
return cast(str, self.__extension_instance.query(inputs, query))

View File

@@ -1,4 +1,5 @@
import base64
from collections.abc import Mapping
from configs import dify_config
from core.helper import ssrf_proxy
@@ -55,7 +56,7 @@ def to_prompt_message_content(
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
prompt_class_map = {
prompt_class_map: Mapping[FileType, type[MultiModalPromptMessageContent]] = {
FileType.IMAGE: ImagePromptMessageContent,
FileType.AUDIO: AudioPromptMessageContent,
FileType.VIDEO: VideoPromptMessageContent,
@@ -63,7 +64,7 @@ def to_prompt_message_content(
}
try:
return prompt_class_map[f.type](**params)
return prompt_class_map[f.type].model_validate(params)
except KeyError:
raise ValueError(f"file type {f.type} is not supported")
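Switching from prompt_class_map[f.type](**params) to model_validate lets pydantic validate the per-class fields at runtime instead of asking mypy to check a **kwargs call against several different models; the map itself is typed as a Mapping from file type to content class. A reduced pydantic v2 sketch (the models below are placeholders):

from collections.abc import Mapping
from enum import Enum

from pydantic import BaseModel


class FileType(Enum):
    IMAGE = "image"
    AUDIO = "audio"


class PromptMessageContent(BaseModel):
    url: str


class ImagePromptMessageContent(PromptMessageContent):
    detail: str = "low"


class AudioPromptMessageContent(PromptMessageContent):
    format: str = "mp3"


prompt_class_map: Mapping[FileType, type[PromptMessageContent]] = {
    FileType.IMAGE: ImagePromptMessageContent,
    FileType.AUDIO: AudioPromptMessageContent,
}


def to_prompt_message_content(file_type: FileType, params: dict) -> PromptMessageContent:
    try:
        # model_validate accepts a plain dict and validates it against the concrete model
        return prompt_class_map[file_type].model_validate(params)
    except KeyError:
        raise ValueError(f"file type {file_type} is not supported")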

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
if TYPE_CHECKING:
from core.tools.tool_file_manager import ToolFileManager
@@ -9,4 +9,4 @@ tool_file_manager: dict[str, Any] = {"manager": None}
class ToolFileParser:
@staticmethod
def get_tool_file_manager() -> "ToolFileManager":
return tool_file_manager["manager"]
return cast("ToolFileManager", tool_file_manager["manager"])

View File

@@ -38,7 +38,7 @@ class CodeLanguage(StrEnum):
class CodeExecutor:
dependencies_cache = {}
dependencies_cache: dict[str, str] = {}
dependencies_cache_lock = Lock()
code_template_transformers: dict[CodeLanguage, type[TemplateTransformer]] = {
@@ -103,19 +103,19 @@ class CodeExecutor:
)
try:
response = response.json()
response_data = response.json()
except:
raise CodeExecutionError("Failed to parse response")
if (code := response.get("code")) != 0:
raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response.get('message')}")
if (code := response_data.get("code")) != 0:
raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response_data.get('message')}")
response = CodeExecutionResponse(**response)
response_code = CodeExecutionResponse(**response_data)
if response.data.error:
raise CodeExecutionError(response.data.error)
if response_code.data.error:
raise CodeExecutionError(response_code.data.error)
return response.data.stdout or ""
return response_code.data.stdout or ""
@classmethod
def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]):
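The rename from response to response_data and response_code avoids re-binding one variable to three different types (a requests Response, a parsed dict, then a pydantic model), which mypy rejects. A short sketch of the same flow with distinct names (the response models are simplified):

import requests
from pydantic import BaseModel


class CodeExecutionData(BaseModel):
    error: str = ""
    stdout: str = ""


class CodeExecutionResponse(BaseModel):
    code: int
    message: str = ""
    data: CodeExecutionData


def execute(url: str, payload: dict) -> str:
    response = requests.post(url, json=payload, timeout=(10, 60))
    # one name per type: Response -> parsed dict -> validated model
    response_data = response.json()
    if (code := response_data.get("code")) != 0:
        raise RuntimeError(f"got error code {code}: {response_data.get('message')}")
    response_code = CodeExecutionResponse(**response_data)
    if response_code.data.error:
        raise RuntimeError(response_code.data.error)
    return response_code.data.stdout or ""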

View File

@@ -1,9 +1,11 @@
from collections.abc import Mapping
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
class Jinja2Formatter:
@classmethod
def format(cls, template: str, inputs: dict) -> str:
def format(cls, template: str, inputs: Mapping[str, str]) -> str:
"""
Format template
:param template: template
@@ -11,5 +13,4 @@ class Jinja2Formatter:
:return:
"""
result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs)
return result["result"]
return str(result.get("result", ""))

View File

@@ -29,8 +29,7 @@ class TemplateTransformer(ABC):
result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL)
if not result:
raise ValueError("Failed to parse result")
result = result.group(1)
return result
return result.group(1)
@classmethod
def transform_response(cls, response: str) -> Mapping[str, Any]:

View File

@@ -4,7 +4,7 @@ from typing import Any
class LRUCache:
def __init__(self, capacity: int):
self.cache = OrderedDict()
self.cache: OrderedDict[Any, Any] = OrderedDict()
self.capacity = capacity
def get(self, key: Any) -> Any:

View File

@@ -30,7 +30,7 @@ class ProviderCredentialsCache:
except JSONDecodeError:
return None
return cached_provider_credentials
return dict(cached_provider_credentials)
else:
return None

View File

@@ -22,6 +22,7 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str)
provider_name = model_config.provider
if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers:
hosting_openai_config = hosting_configuration.provider_map["openai"]
assert hosting_openai_config is not None
# 2000 text per chunk
length = 2000
@@ -34,8 +35,9 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str)
try:
model_type_instance = OpenAIModerationModel()
# FIXME: for the type hint, is assert or raise ValueError better here?
moderation_result = model_type_instance.invoke(
model="text-moderation-stable", credentials=hosting_openai_config.credentials, text=text_chunk
model="text-moderation-stable", credentials=hosting_openai_config.credentials or {}, text=text_chunk
)
if moderation_result is True:

View File

@@ -14,12 +14,13 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
if existed_spec:
spec = existed_spec
if not spec.loader:
raise Exception(f"Failed to load module {module_name} from {py_file_path}")
raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
else:
# Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
spec = importlib.util.spec_from_file_location(module_name, py_file_path)
# FIXME: mypy does not support the type of spec.loader
spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore
if not spec or not spec.loader:
raise Exception(f"Failed to load module {module_name} from {py_file_path}")
raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
if use_lazy_loader:
# Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
spec.loader = importlib.util.LazyLoader(spec.loader)
@@ -29,7 +30,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
spec.loader.exec_module(module)
return module
except Exception as e:
logging.exception(f"Failed to load module {module_name} from script file '{py_file_path}'")
logging.exception(f"Failed to load module {module_name} from script file '{py_file_path!r}'")
raise e
@@ -57,6 +58,6 @@ def load_single_subclass_from_source(
case 1:
return subclasses[0]
case 0:
raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path}")
raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path!r}")
case _:
raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path}")
raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path!r}")

View File

@@ -33,7 +33,7 @@ class ToolParameterCache:
except JSONDecodeError:
return None
return cached_tool_parameter
return dict(cached_tool_parameter)
else:
return None

View File

@@ -28,7 +28,7 @@ class ToolProviderCredentialsCache:
except JSONDecodeError:
return None
return cached_provider_credentials
return dict(cached_provider_credentials)
else:
return None

View File

@@ -42,7 +42,7 @@ class HostedModerationConfig(BaseModel):
class HostingConfiguration:
provider_map: dict[str, HostingProvider] = {}
moderation_config: HostedModerationConfig = None
moderation_config: Optional[HostedModerationConfig] = None
def init_app(self, app: Flask) -> None:
if dify_config.EDITION != "CLOUD":
@@ -67,7 +67,7 @@ class HostingConfiguration:
"base_model_name": "gpt-35-turbo",
}
quotas = []
quotas: list[HostingQuota] = []
hosted_quota_limit = dify_config.HOSTED_AZURE_OPENAI_QUOTA_LIMIT
trial_quota = TrialHostingQuota(
quota_limit=hosted_quota_limit,
@@ -123,7 +123,7 @@ class HostingConfiguration:
def init_openai(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas = []
quotas: list[HostingQuota] = []
if dify_config.HOSTED_OPENAI_TRIAL_ENABLED:
hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT
@@ -157,7 +157,7 @@ class HostingConfiguration:
@staticmethod
def init_anthropic() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
quotas = []
quotas: list[HostingQuota] = []
if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED:
hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
@@ -187,7 +187,7 @@ class HostingConfiguration:
def init_minimax() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
if dify_config.HOSTED_MINIMAX_ENABLED:
quotas = [FreeHostingQuota()]
quotas: list[HostingQuota] = [FreeHostingQuota()]
return HostingProvider(
enabled=True,
@@ -205,7 +205,7 @@ class HostingConfiguration:
def init_spark() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
if dify_config.HOSTED_SPARK_ENABLED:
quotas = [FreeHostingQuota()]
quotas: list[HostingQuota] = [FreeHostingQuota()]
return HostingProvider(
enabled=True,
@@ -223,7 +223,7 @@ class HostingConfiguration:
def init_zhipuai() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
if dify_config.HOSTED_ZHIPUAI_ENABLED:
quotas = [FreeHostingQuota()]
quotas: list[HostingQuota] = [FreeHostingQuota()]
return HostingProvider(
enabled=True,

View File

@@ -6,10 +6,10 @@ import re
import threading
import time
import uuid
from typing import Optional, cast
from typing import Any, Optional, cast
from flask import Flask, current_app
from flask_login import current_user
from flask_login import current_user # type: ignore
from sqlalchemy.orm.exc import ObjectDeletedError
from configs import dify_config
@@ -62,6 +62,8 @@ class IndexingRunner:
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
if not processing_rule:
raise ValueError("no process rule found")
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
# extract
@@ -120,6 +122,8 @@ class IndexingRunner:
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
if not processing_rule:
raise ValueError("no process rule found")
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
@@ -254,7 +258,7 @@ class IndexingRunner:
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
preview_texts = []
preview_texts: list[str] = []
total_segments = 0
index_type = doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
@@ -285,7 +289,8 @@ class IndexingRunner:
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
try:
storage.delete(image_file.key)
if image_file:
storage.delete(image_file.key)
except Exception:
logging.exception(
"Delete image_files failed while indexing_estimate, \
@@ -379,8 +384,9 @@ class IndexingRunner:
# replace doc id to document model id
text_docs = cast(list[Document], text_docs)
for text_doc in text_docs:
text_doc.metadata["document_id"] = dataset_document.id
text_doc.metadata["dataset_id"] = dataset_document.dataset_id
if text_doc.metadata is not None:
text_doc.metadata["document_id"] = dataset_document.id
text_doc.metadata["dataset_id"] = dataset_document.dataset_id
return text_docs
@@ -400,6 +406,7 @@ class IndexingRunner:
"""
Get the NodeParser object according to the processing rule.
"""
character_splitter: TextSplitter
if processing_rule.mode == "custom":
# The user-defined segmentation rule
rules = json.loads(processing_rule.rules)
@@ -426,9 +433,10 @@ class IndexingRunner:
)
else:
# Automatic segmentation
automatic_rules: dict[str, Any] = dict(DatasetProcessRule.AUTOMATIC_RULES["segmentation"])
character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"],
chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"],
chunk_size=automatic_rules["max_tokens"],
chunk_overlap=automatic_rules["chunk_overlap"],
separators=["\n\n", "", ". ", " ", ""],
embedding_model_instance=embedding_model_instance,
)
@@ -497,8 +505,8 @@ class IndexingRunner:
"""
Split the text documents into nodes.
"""
all_documents = []
all_qa_documents = []
all_documents: list[Document] = []
all_qa_documents: list[Document] = []
for text_doc in text_docs:
# document clean
document_text = self._document_clean(text_doc.page_content, processing_rule)
@@ -509,10 +517,11 @@ class IndexingRunner:
split_documents = []
for document_node in documents:
if document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
if document_node.metadata is not None:
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
document_node.page_content = remove_leading_symbols(page_content)
@@ -529,7 +538,7 @@ class IndexingRunner:
document_format_thread = threading.Thread(
target=self.format_qa_document,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"tenant_id": tenant_id,
"document_node": doc,
"all_qa_documents": all_qa_documents,
@@ -557,11 +566,12 @@ class IndexingRunner:
qa_document = Document(
page_content=result["question"], metadata=document_node.metadata.model_copy()
)
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(result["question"])
qa_document.metadata["answer"] = result["answer"]
qa_document.metadata["doc_id"] = doc_id
qa_document.metadata["doc_hash"] = hash
if qa_document.metadata is not None:
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(result["question"])
qa_document.metadata["answer"] = result["answer"]
qa_document.metadata["doc_id"] = doc_id
qa_document.metadata["doc_hash"] = hash
qa_documents.append(qa_document)
format_documents.extend(qa_documents)
except Exception as e:
@@ -575,7 +585,7 @@ class IndexingRunner:
"""
Split the text documents into nodes.
"""
all_documents = []
all_documents: list[Document] = []
for text_doc in text_docs:
# document clean
document_text = self._document_clean(text_doc.page_content, processing_rule)
@@ -588,11 +598,11 @@ class IndexingRunner:
for document in documents:
if document.page_content is None or not document.page_content.strip():
continue
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content)
document.metadata["doc_id"] = doc_id
document.metadata["doc_hash"] = hash
if document.metadata is not None:
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content)
document.metadata["doc_id"] = doc_id
document.metadata["doc_hash"] = hash
split_documents.append(document)
@@ -648,7 +658,7 @@ class IndexingRunner:
# create keyword index
create_keyword_thread = threading.Thread(
target=self._process_keyword_index,
args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),
args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore
)
create_keyword_thread.start()
if dataset.indexing_technique == "high_quality":
@@ -659,7 +669,7 @@ class IndexingRunner:
futures.append(
executor.submit(
self._process_chunk,
current_app._get_current_object(),
current_app._get_current_object(), # type: ignore
index_processor,
chunk_documents,
dataset,

View File

@@ -1,7 +1,7 @@
import json
import logging
import re
from typing import Optional
from typing import Optional, cast
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
@@ -13,6 +13,7 @@ from core.llm_generator.prompts import (
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
)
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
@@ -44,10 +45,13 @@ class LLMGenerator:
prompts = [UserPromptMessage(content=prompt)]
with measure_time() as timer:
response = model_instance.invoke_llm(
prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
),
)
answer = response.message.content
answer = cast(str, response.message.content)
cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
if cleaned_answer is None:
return ""
@@ -94,11 +98,16 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt)]
try:
response = model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters={"max_tokens": 256, "temperature": 0}, stream=False
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters={"max_tokens": 256, "temperature": 0},
stream=False,
),
)
questions = output_parser.parse(response.message.content)
questions = output_parser.parse(cast(str, response.message.content))
except InvokeError:
questions = []
except Exception as e:
@@ -138,11 +147,14 @@ class LLMGenerator:
)
try:
response = model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
),
)
rule_config["prompt"] = response.message.content
rule_config["prompt"] = cast(str, response.message.content)
except InvokeError as e:
error = str(e)
@@ -178,15 +190,18 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.get("provider") if model_config else None,
model=model_config.get("name") if model_config else None,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
try:
try:
# the first step to generate the task prompt
prompt_content = model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
prompt_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
),
)
except InvokeError as e:
error = str(e)
@@ -195,8 +210,10 @@ class LLMGenerator:
return rule_config
rule_config["prompt"] = prompt_content.message.content
rule_config["prompt"] = cast(str, prompt_content.message.content)
if not isinstance(prompt_content.message.content, str):
raise NotImplementedError("prompt content is not a string")
parameter_generate_prompt = parameter_template.format(
inputs={
"INPUT_TEXT": prompt_content.message.content,
@@ -216,19 +233,25 @@ class LLMGenerator:
statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
try:
parameter_content = model_instance.invoke_llm(
prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False
parameter_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False
),
)
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.content)
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
except InvokeError as e:
error = str(e)
error_step = "generate variables"
try:
statement_content = model_instance.invoke_llm(
prompt_messages=statement_messages, model_parameters=model_parameters, stream=False
statement_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=statement_messages, model_parameters=model_parameters, stream=False
),
)
rule_config["opening_statement"] = statement_content.message.content
rule_config["opening_statement"] = cast(str, statement_content.message.content)
except InvokeError as e:
error = str(e)
error_step = "generate conversation opener"
@@ -267,19 +290,22 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.get("provider") if model_config else None,
model=model_config.get("name") if model_config else None,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
prompt_messages = [UserPromptMessage(content=prompt)]
model_parameters = {"max_tokens": max_tokens, "temperature": 0.01}
try:
response = model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
),
)
generated_code = response.message.content
generated_code = cast(str, response.message.content)
return {"code": generated_code, "language": code_language, "error": ""}
except InvokeError as e:
@@ -303,9 +329,14 @@ class LLMGenerator:
prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]
response = model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters={"temperature": 0.01, "max_tokens": 2000}, stream=False
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters={"temperature": 0.01, "max_tokens": 2000},
stream=False,
),
)
answer = response.message.content
answer = cast(str, response.message.content)
return answer.strip()
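Every cast(LLMResult, model_instance.invoke_llm(...)) above follows from invoke_llm being declared to return Union[LLMResult, Generator]: with stream=False the result is known to be an LLMResult, but the signature cannot express that, so the call sites narrow it explicitly. A compact sketch of the situation (types simplified):

from collections.abc import Generator
from typing import Union, cast


class LLMResult:
    def __init__(self, text: str) -> None:
        self.text = text


def invoke_llm(prompt: str, stream: bool = True) -> Union[LLMResult, Generator[str, None, None]]:
    if stream:
        return (chunk for chunk in prompt.split())
    return LLMResult(prompt.upper())


def generate_once(prompt: str) -> str:
    # stream=False guarantees an LLMResult at runtime, but the declared return
    # type is still the union, so cast() narrows it for mypy
    result = cast(LLMResult, invoke_llm(prompt, stream=False))
    return result.text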

View File

@@ -68,7 +68,7 @@ class TokenBufferMemory:
messages = list(reversed(thread_messages))
prompt_messages = []
prompt_messages: list[PromptMessage] = []
for message in messages:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if files:

View File

@@ -124,17 +124,20 @@ class ModelInstance:
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
return cast(
Union[LLMResult, Generator],
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
),
)
def get_llm_num_tokens(
@@ -151,12 +154,15 @@ class ModelInstance:
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model,
credentials=self.credentials,
prompt_messages=prompt_messages,
tools=tools,
return cast(
int,
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model,
credentials=self.credentials,
prompt_messages=prompt_messages,
tools=tools,
),
)
def invoke_text_embedding(
@@ -174,13 +180,16 @@ class ModelInstance:
raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
texts=texts,
user=user,
input_type=input_type,
return cast(
TextEmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
texts=texts,
user=user,
input_type=input_type,
),
)
def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
@@ -194,11 +203,14 @@ class ModelInstance:
raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model,
credentials=self.credentials,
texts=texts,
return cast(
int,
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model,
credentials=self.credentials,
texts=texts,
),
)
def invoke_rerank(
@@ -223,15 +235,18 @@ class ModelInstance:
raise Exception("Model type instance is not RerankModel")
self.model_type_instance = cast(RerankModel, self.model_type_instance)
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user,
return cast(
RerankResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user,
),
)
def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool:
@@ -246,12 +261,15 @@ class ModelInstance:
raise Exception("Model type instance is not ModerationModel")
self.model_type_instance = cast(ModerationModel, self.model_type_instance)
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
text=text,
user=user,
return cast(
bool,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
text=text,
user=user,
),
)
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str:
@@ -266,12 +284,15 @@ class ModelInstance:
raise Exception("Model type instance is not Speech2TextModel")
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
file=file,
user=user,
return cast(
str,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
file=file,
user=user,
),
)
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]:
@@ -288,17 +309,20 @@ class ModelInstance:
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
content_text=content_text,
user=user,
tenant_id=tenant_id,
voice=voice,
return cast(
Iterable[bytes],
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
content_text=content_text,
user=user,
tenant_id=tenant_id,
voice=voice,
),
)
def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs) -> Any:
"""
Round-robin invoke
:param function: function to invoke
@@ -309,7 +333,7 @@ class ModelInstance:
if not self.load_balancing_manager:
return function(*args, **kwargs)
last_exception = None
last_exception: Union[InvokeRateLimitError, InvokeAuthorizationError, InvokeConnectionError, None] = None
while True:
lb_config = self.load_balancing_manager.fetch_next()
if not lb_config:
@@ -463,7 +487,7 @@ class LBModelManager:
if real_index > max_index:
real_index = 0
config = self._load_balancing_configs[real_index]
config: ModelLoadBalancingConfiguration = self._load_balancing_configs[real_index]
if self.in_cooldown(config):
cooldown_load_balancing_configs.append(config)
@@ -507,8 +531,7 @@ class LBModelManager:
self._tenant_id, self._provider, self._model_type.value, self._model, config.id
)
res = redis_client.exists(cooldown_cache_key)
res = cast(bool, res)
res: bool = redis_client.exists(cooldown_cache_key)
return res
@staticmethod

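The hunks above all apply the same change: `_round_robin_invoke` is now annotated as returning `Any`, so every typed wrapper around it (`get_num_tokens`, `invoke_rerank`, `invoke_moderation`, `invoke_speech2text`, `invoke_tts`) wraps the call in `typing.cast` to recover its declared return type, and the cooldown check binds the redis result to an annotated `bool`. A minimal sketch of the cast pattern, using hypothetical names rather than the project's dispatcher:

```python
from typing import Any, Callable, cast


def round_robin_invoke(function: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    # Stand-in dispatcher: mypy cannot know the concrete return type here.
    return function(*args, **kwargs)


def get_num_tokens(texts: list[str]) -> int:
    # cast() only informs the type checker; it performs no conversion at runtime.
    return cast(int, round_robin_invoke(lambda ts: sum(len(t) for t in ts), texts))
```

Since `cast` is a no-op at runtime, this keeps the load-balancing dispatcher generic while letting each wrapper keep a precise signature.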
View File

@@ -1,7 +1,8 @@
import json
import logging
import sys
from typing import Optional
from collections.abc import Sequence
from typing import Optional, cast
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@@ -20,7 +21,7 @@ class LoggingCallback(Callback):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
@@ -76,7 +77,7 @@ class LoggingCallback(Callback):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
):
@@ -94,7 +95,7 @@ class LoggingCallback(Callback):
:param stream: is stream response
:param user: unique user id
"""
sys.stdout.write(chunk.delta.message.content)
sys.stdout.write(cast(str, chunk.delta.message.content))
sys.stdout.flush()
def on_after_invoke(
@@ -106,7 +107,7 @@ class LoggingCallback(Callback):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
@@ -147,7 +148,7 @@ class LoggingCallback(Callback):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:

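In the logging callback, `stop` widens from `Optional[list[str]]` to `Optional[Sequence[str]]`, and the streamed chunk content goes through `cast(str, ...)` because message content is not guaranteed to be a plain string. A small illustration of why the `Sequence` parameter type is the friendlier one (hypothetical function, not the project's API):

```python
from collections.abc import Sequence
from typing import Optional


def join_stop_words(stop: Optional[Sequence[str]] = None) -> str:
    # A Sequence parameter accepts lists, tuples, or any other read-only sequence.
    return ", ".join(stop) if stop else ""


print(join_stop_words(["###", "DONE"]))  # list
print(join_stop_words(("###",)))         # tuple works too
```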
View File

@@ -3,7 +3,7 @@ from collections.abc import Sequence
from enum import Enum, StrEnum
from typing import Optional
from pydantic import BaseModel, Field, computed_field, field_validator
from pydantic import BaseModel, Field, field_validator
class PromptMessageRole(Enum):
@@ -89,7 +89,6 @@ class MultiModalPromptMessageContent(PromptMessageContent):
url: str = Field(default="", description="the url of multi-modal file")
mime_type: str = Field(default=..., description="the mime type of multi-modal file")
@computed_field(return_type=str)
@property
def data(self):
return self.url or f"data:{self.mime_type};base64,{self.base64_data}"

View File

@@ -1,7 +1,6 @@
import decimal
import os
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Optional
from pydantic import ConfigDict
@@ -36,7 +35,7 @@ class AIModel(ABC):
model_config = ConfigDict(protected_namespaces=())
@abstractmethod
def validate_credentials(self, model: str, credentials: Mapping) -> None:
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
@@ -214,7 +213,7 @@ class AIModel(ABC):
return model_schemas
def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> Optional[AIModelEntity]:
def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
"""
Get model schema by model name and credentials
@@ -236,9 +235,7 @@ class AIModel(ABC):
return None
def get_customizable_model_schema_from_credentials(
self, model: str, credentials: Mapping
) -> Optional[AIModelEntity]:
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema from credentials
@@ -248,7 +245,7 @@ class AIModel(ABC):
"""
return self._get_customizable_model_schema(model, credentials)
def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema and fill in the template
"""
@@ -301,7 +298,7 @@ class AIModel(ABC):
return schema
def get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema

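Here and in the provider helpers later in the diff (`_CommonFireworks`, `_CommonOpenAI`, `OpenAIProvider`), credential parameters move from `Mapping` to plain `dict`, presumably so the abstract signatures line up with the concrete implementations that already take dicts; mismatched parameter types between a base method and its overrides are something mypy reports. A rough sketch of the consistency being enforced (hypothetical classes):

```python
class BaseAIModel:
    def validate_credentials(self, model: str, credentials: dict) -> None:
        # Overrides must accept at least what the base accepts; keeping both sides
        # on `dict` avoids mypy's incompatible-override errors.
        raise NotImplementedError


class ExampleModel(BaseAIModel):
    def validate_credentials(self, model: str, credentials: dict) -> None:
        if "api_key" not in credentials:
            raise ValueError("api_key is required")
```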
View File

@@ -2,7 +2,7 @@ import logging
import re
import time
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from collections.abc import Generator, Sequence
from typing import Optional, Union
from pydantic import ConfigDict
@@ -48,7 +48,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
@@ -291,12 +291,12 @@ if you are not sure about the structure.
content = piece.delta.message.content
piece.delta.message.content = ""
yield piece
piece = content
content_piece = content
else:
yield piece
continue
new_piece: str = ""
for char in piece:
for char in content_piece:
char = str(char)
if state == "normal":
if char == "`":
@@ -350,7 +350,7 @@ if you are not sure about the structure.
piece.delta.message.content = ""
# Yield a piece with cleared content before processing it to maintain the generator structure
yield piece
piece = content
content_piece = content
else:
# Yield pieces without content directly
yield piece
@@ -360,7 +360,7 @@ if you are not sure about the structure.
continue
new_piece: str = ""
for char in piece:
for char in content_piece:
if state == "search_start":
if char == "`":
backtick_count += 1
@@ -535,7 +535,7 @@ if you are not sure about the structure.
return []
def get_model_mode(self, model: str, credentials: Optional[Mapping] = None) -> LLMMode:
def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode:
"""
Get model mode

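The `piece`/`content_piece` renames address mypy's rule that a name keeps one type: `piece` held the chunk object, so rebinding it to the extracted string content was reported as an incompatible assignment; the string now gets its own name. A toy reproduction of the rename (illustrative only):

```python
def first_words(chunks: list[dict[str, str]]) -> list[str]:
    words: list[str] = []
    for chunk in chunks:
        # Rebinding `chunk = chunk["text"]` would change its type from dict to str,
        # which mypy flags; binding the string to a new name keeps both types stable.
        text_piece = chunk["text"]
        if text_piece:
            words.append(text_piece.split()[0])
    return words


print(first_words([{"text": "hello world"}, {"text": "type checks"}]))
```

The same rename pattern shows up again in the Minimax client further down (`response` vs `response_data`, `line` vs `line_str`).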
View File

@@ -104,9 +104,10 @@ class ModelProvider(ABC):
mod = import_module_from_source(
module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path
)
# FIXME "type" has no attribute "__abstractmethods__" ignore it for now fix it later
model_class = next(
filter(
lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__, # type: ignore
get_subclasses_from_module(mod, AIModel),
),
None,

View File

@@ -89,7 +89,8 @@ class TextEmbeddingModel(AIModel):
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE]
content_size: int = model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE]
return content_size
return 1000
@@ -104,6 +105,7 @@ class TextEmbeddingModel(AIModel):
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
return max_chunks
return 1

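`model_properties` values are not statically known to be integers, so returning the lookup directly trips mypy (most likely `warn_return_any`); binding the value to an annotated local such as `content_size: int = ...` first gives the return a concrete type. The same idiom appears in the TTS and moderation models below. A minimal sketch under that assumption:

```python
from typing import Any


def context_size(properties: dict[str, Any]) -> int:
    # Returning properties["context_size"] directly would be "returning Any";
    # the annotated local pins the value to int for the type checker.
    size: int = properties["context_size"]
    return size


print(context_size({"context_size": 4096}))
```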
View File

@@ -2,9 +2,9 @@ from os.path import abspath, dirname, join
from threading import Lock
from typing import Any
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
_tokenizer = None
_tokenizer: Any = None
_lock = Lock()

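`transformers` ships without type stubs, so its import gets `# type: ignore`, and the lazily initialized module-level cache is annotated `Any` instead of letting mypy infer it as permanently `None`. The same `# type: ignore` treatment is applied to `boto3`, `dashscope`, `google.generativeai`, `huggingface_hub`, `tencentcloud`, `nomic`, and `oci` later in this diff. A sketch of the idiom against a hypothetical untyped package:

```python
from threading import Lock
from typing import Any

# An untyped third-party import would carry the ignore comment, e.g.:
# from some_untyped_sdk import Tokenizer  # type: ignore

_tokenizer: Any = None  # annotated so a later assignment of the real object type-checks
_lock = Lock()


def get_tokenizer() -> Any:
    global _tokenizer
    with _lock:
        if _tokenizer is None:
            _tokenizer = object()  # stand-in for the lazily constructed tokenizer
        return _tokenizer
```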
View File

@@ -127,7 +127,8 @@ class TTSModel(AIModel):
if not model_schema or ModelPropertyKey.AUDIO_TYPE not in model_schema.model_properties:
raise ValueError("this model does not support audio type")
return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
audio_type: str = model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
return audio_type
def _get_model_word_limit(self, model: str, credentials: dict) -> int:
"""
@@ -138,8 +139,9 @@ class TTSModel(AIModel):
if not model_schema or ModelPropertyKey.WORD_LIMIT not in model_schema.model_properties:
raise ValueError("this model does not support word limit")
world_limit: int = model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
return world_limit
def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
"""
@@ -150,8 +152,9 @@ class TTSModel(AIModel):
if not model_schema or ModelPropertyKey.MAX_WORKERS not in model_schema.model_properties:
raise ValueError("this model does not support max workers")
workers_limit: int = model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
return workers_limit
@staticmethod
def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"):

View File

@@ -64,10 +64,12 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model)
if not ai_model_entity:
return None
return ai_model_entity.entity
@staticmethod
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
def _get_ai_model_entity(base_model_name: str, model: str) -> Optional[AzureBaseModel]:
for ai_model_entity in SPEECH2TEXT_BASE_MODELS:
if ai_model_entity.base_model_name == base_model_name:
ai_model_entity_copy = copy.deepcopy(ai_model_entity)

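`_get_ai_model_entity` can finish its loop without a match, so its return type becomes `Optional[AzureBaseModel]` and the caller adds an explicit `return None` before dereferencing `.entity`; the TTS model below gets the same guard. The general shape, with hypothetical types:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BaseModelInfo:
    base_model_name: str
    entity: str


KNOWN_MODELS = [BaseModelInfo("whisper-1", "speech2text schema")]


def find_model(base_model_name: str) -> Optional[BaseModelInfo]:
    # May find nothing, hence Optional; callers must handle the None case.
    for info in KNOWN_MODELS:
        if info.base_model_name == base_model_name:
            return info
    return None


def schema_for(base_model_name: str) -> Optional[str]:
    info = find_model(base_model_name)
    if not info:
        return None
    return info.entity
```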
View File

@@ -114,6 +114,8 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model)
if not ai_model_entity:
return None
return ai_model_entity.entity
@staticmethod

View File

@@ -6,9 +6,9 @@ from collections.abc import Generator
from typing import Optional, Union, cast
# 3rd import
import boto3
from botocore.config import Config
from botocore.exceptions import (
import boto3 # type: ignore
from botocore.config import Config # type: ignore
from botocore.exceptions import ( # type: ignore
ClientError,
EndpointConnectionError,
NoRegionError,

View File

@@ -44,7 +44,7 @@ class CohereRerankModel(RerankModel):
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(model=model, docs=docs)
return RerankResult(model=model, docs=[])
# initialize client
client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url"))
@@ -62,7 +62,7 @@ class CohereRerankModel(RerankModel):
# format document
rerank_document = RerankDocument(
index=result.index,
text=result.document.text,
text=result.document.text if result.document else "",
score=result.relevance_score,
)

View File

@@ -1,5 +1,3 @@
from collections.abc import Mapping
import openai
from core.model_runtime.errors.invoke import (
@@ -13,7 +11,7 @@ from core.model_runtime.errors.invoke import (
class _CommonFireworks:
def _to_credential_kwargs(self, credentials: Mapping) -> dict:
def _to_credential_kwargs(self, credentials: dict) -> dict:
"""
Transform credentials to kwargs for model instance

View File

@@ -1,5 +1,4 @@
import time
from collections.abc import Mapping
from typing import Optional, Union
import numpy as np
@@ -93,7 +92,7 @@ class FireworksTextEmbeddingModel(_CommonFireworks, TextEmbeddingModel):
"""
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
def validate_credentials(self, model: str, credentials: Mapping) -> None:
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials

View File

@@ -1,4 +1,4 @@
from dashscope.common.error import (
from dashscope.common.error import ( # type: ignore
AuthenticationError,
InvalidParameter,
RequestFailure,

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Optional
import httpx
@@ -51,7 +51,7 @@ class GiteeAIRerankModel(RerankModel):
base_url = base_url.removesuffix("/")
try:
body = {"model": model, "query": query, "documents": docs}
body: dict[str, Any] = {"model": model, "query": query, "documents": docs}
if top_n is not None:
body["top_n"] = top_n
response = httpx.post(

View File

@@ -24,7 +24,7 @@ class GiteeAIEmbeddingModel(OAICompatEmbeddingModel):
super().validate_credentials(model, credentials)
@staticmethod
def _add_custom_parameters(credentials: dict, model: str) -> None:
def _add_custom_parameters(credentials: dict, model: Optional[str]) -> None:
if model is None:
model = "bge-m3"

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Optional
import requests
@@ -13,9 +13,10 @@ class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel):
Model class for OpenAI text2speech model.
"""
# FIXME this Any return will be better type
def _invoke(
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
) -> any:
) -> Any:
"""
_invoke text2speech model
@@ -47,7 +48,8 @@ class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel):
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
# FIXME this Any return will be better type
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> Any:
"""
_tts_invoke_streaming text2speech model
:param model: model name

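The `-> any` annotations here pointed at the builtin `any()` function, which is not a type, so mypy rejects them; they become `typing.Any`, with FIXME notes that a narrower return type would be better. For example:

```python
from typing import Any


def tts_stream(text: str) -> Any:
    # "-> any" would annotate the return with the builtin any() function;
    # typing.Any is the valid, if deliberately loose, placeholder.
    return text.encode("utf-8")
```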
View File

@@ -7,7 +7,7 @@ from collections.abc import Generator
from typing import Optional, Union
import google.ai.generativelanguage as glm
import google.generativeai as genai
import google.generativeai as genai # type: ignore
import requests
from google.api_core import exceptions
from google.generativeai.types import ContentType, File, GenerateContentResponse

View File

@@ -1,4 +1,4 @@
from huggingface_hub.utils import BadRequestError, HfHubHTTPError
from huggingface_hub.utils import BadRequestError, HfHubHTTPError # type: ignore
from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError

View File

@@ -1,9 +1,9 @@
from collections.abc import Generator
from typing import Optional, Union
from huggingface_hub import InferenceClient
from huggingface_hub.hf_api import HfApi
from huggingface_hub.utils import BadRequestError
from huggingface_hub import InferenceClient # type: ignore
from huggingface_hub.hf_api import HfApi # type: ignore
from huggingface_hub.utils import BadRequestError # type: ignore
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE

View File

@@ -4,7 +4,7 @@ from typing import Optional
import numpy as np
import requests
from huggingface_hub import HfApi, InferenceClient
from huggingface_hub import HfApi, InferenceClient # type: ignore
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject

View File

@@ -3,11 +3,11 @@ import logging
from collections.abc import Generator
from typing import cast
from tencentcloud.common import credential
from tencentcloud.common.exception import TencentCloudSDKException
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
from tencentcloud.common import credential # type: ignore
from tencentcloud.common.exception import TencentCloudSDKException # type: ignore
from tencentcloud.common.profile.client_profile import ClientProfile # type: ignore
from tencentcloud.common.profile.http_profile import HttpProfile # type: ignore
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models # type: ignore
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
@@ -305,7 +305,7 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
elif isinstance(message, ToolPromptMessage):
message_text = f"{tool_prompt} {content}"
elif isinstance(message, SystemPromptMessage):
message_text = content
message_text = content if isinstance(content, str) else ""
else:
raise ValueError(f"Got unknown type {message}")

View File

@@ -3,11 +3,11 @@ import logging
import time
from typing import Optional
from tencentcloud.common import credential
from tencentcloud.common.exception import TencentCloudSDKException
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
from tencentcloud.common import credential # type: ignore
from tencentcloud.common.exception import TencentCloudSDKException # type: ignore
from tencentcloud.common.profile.client_profile import ClientProfile # type: ignore
from tencentcloud.common.profile.http_profile import HttpProfile # type: ignore
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models # type: ignore
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType

View File

@@ -1,11 +1,11 @@
from os.path import abspath, dirname, join
from threading import Lock
from transformers import AutoTokenizer
from transformers import AutoTokenizer # type: ignore
class JinaTokenizer:
_tokenizer = None
_tokenizer: AutoTokenizer | None = None
_lock = Lock()
@classmethod

View File

@@ -40,7 +40,7 @@ class MinimaxChatCompletion:
url = f"https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}"
extra_kwargs = {}
extra_kwargs: dict[str, Any] = {}
if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int:
extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"]
@@ -117,19 +117,19 @@ class MinimaxChatCompletion:
"""
handle chat generate response
"""
response = response.json()
if "base_resp" in response and response["base_resp"]["status_code"] != 0:
code = response["base_resp"]["status_code"]
msg = response["base_resp"]["status_msg"]
response_data = response.json()
if "base_resp" in response_data and response_data["base_resp"]["status_code"] != 0:
code = response_data["base_resp"]["status_code"]
msg = response_data["base_resp"]["status_msg"]
self._handle_error(code, msg)
message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
message = MinimaxMessage(content=response_data["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
message.usage = {
"prompt_tokens": 0,
"completion_tokens": response["usage"]["total_tokens"],
"total_tokens": response["usage"]["total_tokens"],
"completion_tokens": response_data["usage"]["total_tokens"],
"total_tokens": response_data["usage"]["total_tokens"],
}
message.stop_reason = response["choices"][0]["finish_reason"]
message.stop_reason = response_data["choices"][0]["finish_reason"]
return message
def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]:
@@ -139,10 +139,10 @@ class MinimaxChatCompletion:
for line in response.iter_lines():
if not line:
continue
line: str = line.decode("utf-8")
if line.startswith("data: "):
line = line[6:].strip()
data = loads(line)
line_str: str = line.decode("utf-8")
if line_str.startswith("data: "):
line_str = line_str[6:].strip()
data = loads(line_str)
if "base_resp" in data and data["base_resp"]["status_code"] != 0:
code = data["base_resp"]["status_code"]
@@ -162,5 +162,5 @@ class MinimaxChatCompletion:
continue
for choice in choices:
message = choice["delta"]
yield MinimaxMessage(content=message, role=MinimaxMessage.Role.ASSISTANT.value)
message_choice = choice["delta"]
yield MinimaxMessage(content=message_choice, role=MinimaxMessage.Role.ASSISTANT.value)

View File

@@ -41,7 +41,7 @@ class MinimaxChatCompletionPro:
url = f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}"
extra_kwargs = {}
extra_kwargs: dict[str, Any] = {}
if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int:
extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"]
@@ -122,19 +122,19 @@ class MinimaxChatCompletionPro:
"""
handle chat generate response
"""
response = response.json()
if "base_resp" in response and response["base_resp"]["status_code"] != 0:
code = response["base_resp"]["status_code"]
msg = response["base_resp"]["status_msg"]
response_data = response.json()
if "base_resp" in response_data and response_data["base_resp"]["status_code"] != 0:
code = response_data["base_resp"]["status_code"]
msg = response_data["base_resp"]["status_msg"]
self._handle_error(code, msg)
message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
message = MinimaxMessage(content=response_data["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
message.usage = {
"prompt_tokens": 0,
"completion_tokens": response["usage"]["total_tokens"],
"total_tokens": response["usage"]["total_tokens"],
"completion_tokens": response_data["usage"]["total_tokens"],
"total_tokens": response_data["usage"]["total_tokens"],
}
message.stop_reason = response["choices"][0]["finish_reason"]
message.stop_reason = response_data["choices"][0]["finish_reason"]
return message
def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]:
@@ -144,10 +144,10 @@ class MinimaxChatCompletionPro:
for line in response.iter_lines():
if not line:
continue
line: str = line.decode("utf-8")
if line.startswith("data: "):
line = line[6:].strip()
data = loads(line)
line_str: str = line.decode("utf-8")
if line_str.startswith("data: "):
line_str = line_str[6:].strip()
data = loads(line_str)
if "base_resp" in data and data["base_resp"]["status_code"] != 0:
code = data["base_resp"]["status_code"]

View File

@@ -11,9 +11,9 @@ class MinimaxMessage:
role: str = Role.USER.value
content: str
usage: dict[str, int] = None
usage: dict[str, int] | None = None
stop_reason: str = ""
function_call: dict[str, Any] = None
function_call: dict[str, Any] | None = None
def to_dict(self) -> dict[str, Any]:
if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:

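`usage` and `function_call` had a `None` default but were annotated as plain dicts, which mypy rejects; they are now declared `| None`. A minimal sketch of the corrected declaration (simplified class, not the real MinimaxMessage):

```python
from typing import Any


class Message:
    role: str = "user"
    content: str = ""
    # A None default requires an Optional (`| None`) annotation to satisfy mypy.
    usage: dict[str, int] | None = None
    function_call: dict[str, Any] | None = None

    def to_dict(self) -> dict[str, Any]:
        data: dict[str, Any] = {"role": self.role, "content": self.content}
        if self.usage is not None:
            data["usage"] = self.usage
        return data
```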
View File

@@ -2,8 +2,8 @@ import time
from functools import wraps
from typing import Optional
from nomic import embed
from nomic import login as nomic_login
from nomic import embed # type: ignore
from nomic import login as nomic_login # type: ignore
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType

View File

@@ -5,8 +5,8 @@ import logging
from collections.abc import Generator
from typing import Optional, Union
import oci
from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse
import oci # type: ignore
from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse # type: ignore
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (

View File

@@ -4,7 +4,7 @@ import time
from typing import Optional
import numpy as np
import oci
import oci # type: ignore
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType

View File

@@ -61,6 +61,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
headers = {"Content-Type": "application/json"}
endpoint_url = credentials.get("base_url")
assert endpoint_url is not None, "Base URL is required for Ollama API"
if not endpoint_url.endswith("/"):
endpoint_url += "/"

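The added `assert endpoint_url is not None` narrows the value returned by `credentials.get("base_url")`, which may be `None`, down to `str` before it is concatenated, and fails fast with a clear message if the credential is missing. Sketch of the narrowing:

```python
from typing import Optional


def normalize_base_url(credentials: dict[str, str]) -> str:
    endpoint_url: Optional[str] = credentials.get("base_url")
    # dict.get() may return None; the assert narrows the type to str for mypy
    # and raises immediately if the credential was never configured.
    assert endpoint_url is not None, "Base URL is required"
    if not endpoint_url.endswith("/"):
        endpoint_url += "/"
    return endpoint_url


print(normalize_base_url({"base_url": "http://localhost:11434"}))
```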
View File

@@ -1,5 +1,3 @@
from collections.abc import Mapping
import openai
from httpx import Timeout
@@ -14,7 +12,7 @@ from core.model_runtime.errors.invoke import (
class _CommonOpenAI:
def _to_credential_kwargs(self, credentials: Mapping) -> dict:
def _to_credential_kwargs(self, credentials: dict) -> dict:
"""
Transform credentials to kwargs for model instance

View File

@@ -93,7 +93,8 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel):
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK]
max_characters_per_chunk: int = model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK]
return max_characters_per_chunk
return 2000
@@ -108,6 +109,7 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel):
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
return max_chunks
return 1

View File

@@ -1,5 +1,4 @@
import logging
from collections.abc import Mapping
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -9,7 +8,7 @@ logger = logging.getLogger(__name__)
class OpenAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: Mapping) -> None:
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception

Some files were not shown because too many files have changed in this diff.