feat: mypy for all type check (#10921)
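This change threads explicit Optional annotations, runtime narrowing guards, and read-only collection types (Mapping/Sequence) through the app runners and generators so the package passes mypy. The single most repeated pattern in the hunks below is narrowing an Optional value before use. A minimal standalone sketch of that pattern (illustrative names only, plain standard library, not code from this repository):

```python
from typing import Optional

def find_name(names: dict[int, str], key: int) -> str:
    # dict.get returns Optional[str]; returning it directly as str fails mypy.
    queried_name: Optional[str] = names.get(key)
    if queried_name is None:
        # Raising narrows queried_name to str for the code below.
        raise ValueError(f"name {key} not found")
    return queried_name
```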
@@ -1,7 +1,6 @@
 import json
 import logging
 import uuid
 from collections.abc import Mapping, Sequence
 from datetime import UTC, datetime
 from typing import Optional, Union, cast

@@ -53,6 +52,7 @@ logger = logging.getLogger(__name__)
 class BaseAgentRunner(AppRunner):
     def __init__(
         self,
+        *,
         tenant_id: str,
         application_generate_entity: AgentChatAppGenerateEntity,
         conversation: Conversation,

@@ -66,7 +66,7 @@ class BaseAgentRunner(AppRunner):
         prompt_messages: Optional[list[PromptMessage]] = None,
         variables_pool: Optional[ToolRuntimeVariablePool] = None,
         db_variables: Optional[ToolConversationVariables] = None,
-        model_instance: ModelInstance | None = None,
+        model_instance: ModelInstance,
     ) -> None:
         self.tenant_id = tenant_id
         self.application_generate_entity = application_generate_entity

@@ -117,7 +117,7 @@ class BaseAgentRunner(AppRunner):
         features = model_schema.features if model_schema and model_schema.features else []
         self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
         self.files = application_generate_entity.files if ModelFeature.VISION in features else []
-        self.query = None
+        self.query: Optional[str] = ""
         self._current_thoughts: list[PromptMessage] = []

     def _repack_app_generate_entity(

@@ -145,7 +145,7 @@ class BaseAgentRunner(AppRunner):

         message_tool = PromptMessageTool(
             name=tool.tool_name,
-            description=tool_entity.description.llm,
+            description=tool_entity.description.llm if tool_entity.description else "",
             parameters={
                 "type": "object",
                 "properties": {},

@@ -167,7 +167,7 @@ class BaseAgentRunner(AppRunner):
                 continue
             enum = []
             if parameter.type == ToolParameter.ToolParameterType.SELECT:
-                enum = [option.value for option in parameter.options]
+                enum = [option.value for option in parameter.options] if parameter.options else []

             message_tool.parameters["properties"][parameter.name] = {
                 "type": parameter_type,

@@ -187,8 +187,8 @@ class BaseAgentRunner(AppRunner):
         convert dataset retriever tool to prompt message tool
         """
         prompt_tool = PromptMessageTool(
-            name=tool.identity.name,
-            description=tool.description.llm,
+            name=tool.identity.name if tool.identity else "unknown",
+            description=tool.description.llm if tool.description else "",
             parameters={
                 "type": "object",
                 "properties": {},

@@ -210,14 +210,14 @@ class BaseAgentRunner(AppRunner):

         return prompt_tool

-    def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
+    def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
         """
         Init tools
         """
         tool_instances = {}
         prompt_messages_tools = []

-        for tool in self.app_config.agent.tools if self.app_config.agent else []:
+        for tool in self.app_config.agent.tools or [] if self.app_config.agent else []:
             try:
                 prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
             except Exception:

@@ -234,7 +234,8 @@ class BaseAgentRunner(AppRunner):
             # save prompt tool
             prompt_messages_tools.append(prompt_tool)
             # save tool entity
-            tool_instances[dataset_tool.identity.name] = dataset_tool
+            if dataset_tool.identity is not None:
+                tool_instances[dataset_tool.identity.name] = dataset_tool

         return tool_instances, prompt_messages_tools

@@ -258,7 +259,7 @@ class BaseAgentRunner(AppRunner):
                 continue
             enum = []
             if parameter.type == ToolParameter.ToolParameterType.SELECT:
-                enum = [option.value for option in parameter.options]
+                enum = [option.value for option in parameter.options] if parameter.options else []

             prompt_tool.parameters["properties"][parameter.name] = {
                 "type": parameter_type,

@@ -322,16 +323,21 @@ class BaseAgentRunner(AppRunner):
         tool_name: str,
         tool_input: Union[str, dict],
         thought: str,
-        observation: Union[str, dict],
-        tool_invoke_meta: Union[str, dict],
+        observation: Union[str, dict, None],
+        tool_invoke_meta: Union[str, dict, None],
         answer: str,
         messages_ids: list[str],
-        llm_usage: LLMUsage = None,
-    ) -> MessageAgentThought:
+        llm_usage: LLMUsage | None = None,
+    ):
        """
        Save agent thought
        """
-        agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
+        queried_thought = (
+            db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
+        )
+        if not queried_thought:
+            raise ValueError(f"Agent thought {agent_thought.id} not found")
+        agent_thought = queried_thought

         if thought is not None:
             agent_thought.thought = thought

@@ -404,7 +410,7 @@ class BaseAgentRunner(AppRunner):
         """
         convert tool variables to db variables
         """
-        db_variables = (
+        queried_variables = (
             db.session.query(ToolConversationVariables)
             .filter(
                 ToolConversationVariables.conversation_id == self.message.conversation_id,

@@ -412,6 +418,11 @@ class BaseAgentRunner(AppRunner):
             .first()
         )

+        if not queried_variables:
+            return
+
+        db_variables = queried_variables
+
         db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
         db.session.commit()

@@ -421,7 +432,7 @@ class BaseAgentRunner(AppRunner):
         """
         Organize agent history
         """
-        result = []
+        result: list[PromptMessage] = []
         # check if there is a system message in the beginning of the conversation
         for prompt_message in prompt_messages:
             if isinstance(prompt_message, SystemPromptMessage):
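The hunks above repeatedly guard SQLAlchemy's `.first()`, which is typed as returning `Optional[T]`: the result is bound to a fresh name (`queried_thought`, `queried_variables`), checked for None, and only then used. A self-contained sketch of the same shape, with a stand-in `fetch_row` instead of a real session query (hypothetical names, not this repository's API):

```python
from typing import Optional

class Row:
    def __init__(self, row_id: str) -> None:
        self.id = row_id

def fetch_row(row_id: str) -> Optional[Row]:
    # Stand-in for db.session.query(...).first(), which may return None.
    return Row(row_id) if row_id else None

def load_row(row_id: str) -> Row:
    queried = fetch_row(row_id)
    if not queried:
        raise ValueError(f"row {row_id} not found")
    # mypy narrows `queried` to Row here, so attribute access is safe.
    return queried
```

Binding the query result to a new name rather than reusing the old variable keeps each name at a single type, which is also what the `db_variables` → `queried_variables` rename is for.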
@@ -1,7 +1,7 @@
 import json
 from abc import ABC, abstractmethod
-from collections.abc import Generator
-from typing import Optional, Union
+from collections.abc import Generator, Mapping
+from typing import Any, Optional

 from core.agent.base_agent_runner import BaseAgentRunner
 from core.agent.entities import AgentScratchpadUnit

@@ -12,6 +12,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
+    PromptMessageTool,
     ToolPromptMessage,
     UserPromptMessage,
 )

@@ -26,18 +27,18 @@ from models.model import Message
 class CotAgentRunner(BaseAgentRunner, ABC):
     _is_first_iteration = True
     _ignore_observation_providers = ["wenxin"]
-    _historic_prompt_messages: list[PromptMessage] = None
-    _agent_scratchpad: list[AgentScratchpadUnit] = None
-    _instruction: str = None
-    _query: str = None
-    _prompt_messages_tools: list[PromptMessage] = None
+    _historic_prompt_messages: list[PromptMessage] | None = None
+    _agent_scratchpad: list[AgentScratchpadUnit] | None = None
+    _instruction: str = ""  # FIXME this must be str for now
+    _query: str | None = None
+    _prompt_messages_tools: list[PromptMessageTool] = []

     def run(
         self,
         message: Message,
         query: str,
-        inputs: dict[str, str],
-    ) -> Union[Generator, LLMResult]:
+        inputs: Mapping[str, str],
+    ) -> Generator:
         """
         Run Cot agent application
         """

@@ -57,19 +58,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         # init instruction
         inputs = inputs or {}
         instruction = app_config.prompt_template.simple_prompt_template
-        self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
+        self._instruction = self._fill_in_inputs_from_external_data_tools(instruction=instruction or "", inputs=inputs)

         iteration_step = 1
-        max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
+        max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1

         # convert tools into ModelRuntime Tool format
         tool_instances, self._prompt_messages_tools = self._init_prompt_tools()

         function_call_state = True
-        llm_usage = {"usage": None}
+        llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
         final_answer = ""

-        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
+        def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
             if not final_llm_usage_dict["usage"]:
                 final_llm_usage_dict["usage"] = usage
             else:

@@ -90,7 +91,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 # the last iteration, remove all tools
                 self._prompt_messages_tools = []

-            message_file_ids = []
+            message_file_ids: list[str] = []

             agent_thought = self.create_agent_thought(
                 message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids

@@ -105,7 +106,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             prompt_messages = self._organize_prompt_messages()
             self.recalc_llm_max_tokens(self.model_config, prompt_messages)
             # invoke model
-            chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
+            chunks = model_instance.invoke_llm(
                 prompt_messages=prompt_messages,
                 model_parameters=app_generate_entity.model_conf.parameters,
                 tools=[],

@@ -115,11 +116,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 callbacks=[],
             )

+            if not isinstance(chunks, Generator):
+                raise ValueError("Expected streaming response from LLM")
+
             # check llm result
             if not chunks:
                 raise ValueError("failed to invoke llm")

-            usage_dict = {}
+            usage_dict: dict[str, Optional[LLMUsage]] = {"usage": None}
             react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
             scratchpad = AgentScratchpadUnit(
                 agent_response="",

@@ -139,25 +143,30 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 if isinstance(chunk, AgentScratchpadUnit.Action):
                     action = chunk
                     # detect action
-                    scratchpad.agent_response += json.dumps(chunk.model_dump())
+                    if scratchpad.agent_response is not None:
+                        scratchpad.agent_response += json.dumps(chunk.model_dump())
                     scratchpad.action_str = json.dumps(chunk.model_dump())
                     scratchpad.action = action
                 else:
-                    scratchpad.agent_response += chunk
-                    scratchpad.thought += chunk
+                    if scratchpad.agent_response is not None:
+                        scratchpad.agent_response += chunk
+                    if scratchpad.thought is not None:
+                        scratchpad.thought += chunk
                     yield LLMResultChunk(
                         model=self.model_config.model,
                         prompt_messages=prompt_messages,
                         system_fingerprint="",
                         delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
                     )

-            scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
-            self._agent_scratchpad.append(scratchpad)
+            if scratchpad.thought is not None:
+                scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
+            if self._agent_scratchpad is not None:
+                self._agent_scratchpad.append(scratchpad)

             # get llm usage
             if "usage" in usage_dict:
-                increase_usage(llm_usage, usage_dict["usage"])
+                if usage_dict["usage"] is not None:
+                    increase_usage(llm_usage, usage_dict["usage"])
             else:
                 usage_dict["usage"] = LLMUsage.empty_usage()

@@ -166,9 +175,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 tool_name=scratchpad.action.action_name if scratchpad.action else "",
                 tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
                 tool_invoke_meta={},
-                thought=scratchpad.thought,
+                thought=scratchpad.thought or "",
                 observation="",
-                answer=scratchpad.agent_response,
+                answer=scratchpad.agent_response or "",
                 messages_ids=[],
                 llm_usage=usage_dict["usage"],
             )

@@ -209,7 +218,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     agent_thought=agent_thought,
                     tool_name=scratchpad.action.action_name,
                     tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
-                    thought=scratchpad.thought,
+                    thought=scratchpad.thought or "",
                     observation={scratchpad.action.action_name: tool_invoke_response},
                     tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
                     answer=scratchpad.agent_response,

@@ -247,8 +256,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             answer=final_answer,
             messages_ids=[],
         )

-        self.update_db_variables(self.variables_pool, self.db_variables_pool)
+        if self.variables_pool is not None and self.db_variables_pool is not None:
+            self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
         self.queue_manager.publish(
             QueueMessageEndEvent(

@@ -307,8 +316,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):

         # publish files
         for message_file_id, save_as in message_files:
-            if save_as:
-                self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
+            if save_as is not None and self.variables_pool:
+                # FIXME the save_as type is confusing, it should be a string or not
+                self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=str(save_as))

             # publish message file
             self.queue_manager.publish(

@@ -325,7 +335,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         """
         return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])

-    def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
+    def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str:
         """
         fill in inputs from external data tools
         """

@@ -376,11 +386,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         """
         result: list[PromptMessage] = []
         scratchpads: list[AgentScratchpadUnit] = []
-        current_scratchpad: AgentScratchpadUnit = None
+        current_scratchpad: AgentScratchpadUnit | None = None

         for message in self.history_prompt_messages:
             if isinstance(message, AssistantPromptMessage):
                 if not current_scratchpad:
+                    if not isinstance(message.content, str | None):
+                        raise NotImplementedError("expected str type")
                     current_scratchpad = AgentScratchpadUnit(
                         agent_response=message.content,
                         thought=message.content or "I am thinking about how to help you",

@@ -399,8 +411,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     except:
                         pass
             elif isinstance(message, ToolPromptMessage):
-                if current_scratchpad:
+                if not current_scratchpad:
+                    continue
+                if isinstance(message.content, str):
                     current_scratchpad.observation = message.content
+                else:
+                    raise NotImplementedError("expected str type")
             elif isinstance(message, UserPromptMessage):
                 if scratchpads:
                     result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
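`llm_usage` starts as `{"usage": None}` and is later filled with an `LLMUsage`, so its value type has to be declared `Optional[LLMUsage]` for both states to type-check, and `increase_usage` must accept the same dict type. A runnable sketch of the accumulator with a stand-in `Usage` class (illustrative, not the real LLMUsage API):

```python
from typing import Optional

class Usage:
    def __init__(self, tokens: int = 0) -> None:
        self.tokens = tokens

def increase_usage(totals: dict[str, Optional[Usage]], usage: Usage) -> None:
    # The value type must include None because the slot starts empty.
    current = totals["usage"]
    if current is None:
        totals["usage"] = usage
    else:
        totals["usage"] = Usage(current.tokens + usage.tokens)

totals: dict[str, Optional[Usage]] = {"usage": None}
increase_usage(totals, Usage(10))
increase_usage(totals, Usage(5))
```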
@@ -19,7 +19,12 @@ class CotChatAgentRunner(CotAgentRunner):
         """
         Organize system prompt
         """
+        if not self.app_config.agent:
+            raise ValueError("Agent configuration is not set")
+
         prompt_entity = self.app_config.agent.prompt
+        if not prompt_entity:
+            raise ValueError("Agent prompt configuration is not set")
         first_prompt = prompt_entity.first_prompt

         system_prompt = (

@@ -75,6 +80,7 @@ class CotChatAgentRunner(CotAgentRunner):
             assistant_messages = []
         else:
             assistant_message = AssistantPromptMessage(content="")
+            assistant_message.content = ""  # FIXME: type check tell mypy that assistant_message.content is str
             for unit in agent_scratchpad:
                 if unit.is_final():
                     assistant_message.content += f"Final Answer: {unit.agent_response}"
@@ -2,7 +2,12 @@ import json
 from typing import Optional

 from core.agent.cot_agent_runner import CotAgentRunner
-from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
 from core.model_runtime.utils.encoders import jsonable_encoder

@@ -11,7 +16,11 @@ class CotCompletionAgentRunner(CotAgentRunner):
         """
         Organize instruction prompt
         """
+        if self.app_config.agent is None:
+            raise ValueError("Agent configuration is not set")
         prompt_entity = self.app_config.agent.prompt
+        if prompt_entity is None:
+            raise ValueError("prompt entity is not set")
         first_prompt = prompt_entity.first_prompt

         system_prompt = (

@@ -33,7 +42,13 @@ class CotCompletionAgentRunner(CotAgentRunner):
             if isinstance(message, UserPromptMessage):
                 historic_prompt += f"Question: {message.content}\n\n"
             elif isinstance(message, AssistantPromptMessage):
-                historic_prompt += message.content + "\n\n"
+                if isinstance(message.content, str):
+                    historic_prompt += message.content + "\n\n"
+                elif isinstance(message.content, list):
+                    for content in message.content:
+                        if not isinstance(content, TextPromptMessageContent):
+                            continue
+                        historic_prompt += content.data

         return historic_prompt

@@ -50,7 +65,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
         # organize current assistant messages
         agent_scratchpad = self._agent_scratchpad
         assistant_prompt = ""
-        for unit in agent_scratchpad:
+        for unit in agent_scratchpad or []:
             if unit.is_final():
                 assistant_prompt += f"Final Answer: {unit.agent_response}"
             else:
@@ -78,5 +78,5 @@ class AgentEntity(BaseModel):
     model: str
     strategy: Strategy
     prompt: Optional[AgentPromptEntity] = None
-    tools: list[AgentToolEntity] = None
+    tools: list[AgentToolEntity] | None = None
     max_iteration: int = 5
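`tools: list[AgentToolEntity] = None` was always a lie to the type checker: None is not a list. Declaring the union makes the None default legal and forces call sites into guards such as the `for tool in ... or []` seen earlier in this diff. A minimal pydantic sketch (hypothetical field names, assuming a pydantic-style model as `AgentEntity` is):

```python
from typing import Optional
from pydantic import BaseModel

class ToolEntity(BaseModel):
    name: str

class AgentConfig(BaseModel):
    # A plain `tools: list[ToolEntity] = None` fails type checking;
    # the Optional union makes the None default explicit.
    tools: Optional[list[ToolEntity]] = None
    max_iteration: int = 5

config = AgentConfig()
for tool in config.tools or []:  # safe iteration over an Optional list
    print(tool.name)
```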
@@ -40,6 +40,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         app_generate_entity = self.application_generate_entity

         app_config = self.app_config
+        assert app_config is not None, "app_config is required"
+        assert app_config.agent is not None, "app_config.agent is required"

         # convert tools into ModelRuntime Tool format
         tool_instances, prompt_messages_tools = self._init_prompt_tools()

@@ -49,7 +51,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):

         # continue to run until there is not any tool call
         function_call_state = True
-        llm_usage = {"usage": None}
+        llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()}
         final_answer = ""

         # get tracing instance

@@ -75,7 +77,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 # the last iteration, remove all tools
                 prompt_messages_tools = []

-            message_file_ids = []
+            message_file_ids: list[str] = []
             agent_thought = self.create_agent_thought(
                 message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
             )

@@ -105,7 +107,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):

             current_llm_usage = None

-            if self.stream_tool_call:
+            if self.stream_tool_call and isinstance(chunks, Generator):
                 is_first_chunk = True
                 for chunk in chunks:
                     if is_first_chunk:

@@ -116,7 +118,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                     # check if there is any tool call
                     if self.check_tool_calls(chunk):
                         function_call_state = True
-                        tool_calls.extend(self.extract_tool_calls(chunk))
+                        tool_calls.extend(self.extract_tool_calls(chunk) or [])
                         tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                         try:
                             tool_call_inputs = json.dumps(

@@ -131,19 +133,19 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                             for content in chunk.delta.message.content:
                                 response += content.data
                         else:
-                            response += chunk.delta.message.content
+                            response += str(chunk.delta.message.content)

                     if chunk.delta.usage:
                         increase_usage(llm_usage, chunk.delta.usage)
                         current_llm_usage = chunk.delta.usage

                     yield chunk
-            else:
-                result: LLMResult = chunks
+            elif not self.stream_tool_call and isinstance(chunks, LLMResult):
+                result = chunks
                 # check if there is any tool call
                 if self.check_blocking_tool_calls(result):
                     function_call_state = True
-                    tool_calls.extend(self.extract_blocking_tool_calls(result))
+                    tool_calls.extend(self.extract_blocking_tool_calls(result) or [])
                     tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                     try:
                         tool_call_inputs = json.dumps(

@@ -162,7 +164,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         for content in result.message.content:
                             response += content.data
                     else:
-                        response += result.message.content
+                        response += str(result.message.content)

                 if not result.message.content:
                     result.message.content = ""

@@ -181,6 +183,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         usage=result.usage,
                     ),
                 )
+            else:
+                raise RuntimeError(f"invalid chunks type: {type(chunks)}")

             assistant_message = AssistantPromptMessage(content="", tool_calls=[])
             if tool_calls:

@@ -243,7 +247,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 # publish files
                 for message_file_id, save_as in message_files:
                     if save_as:
-                        self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
+                        if self.variables_pool:
+                            self.variables_pool.set_file(
+                                tool_name=tool_call_name, value=message_file_id, name=save_as
+                            )

                     # publish message file
                     self.queue_manager.publish(

@@ -263,7 +270,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 if tool_response["tool_response"] is not None:
                     self._current_thoughts.append(
                         ToolPromptMessage(
-                            content=tool_response["tool_response"],
+                            content=str(tool_response["tool_response"]),
                             tool_call_id=tool_call_id,
                             name=tool_call_name,
                         )

@@ -273,9 +280,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 # save agent thought
                 self.save_agent_thought(
                     agent_thought=agent_thought,
-                    tool_name=None,
-                    tool_input=None,
-                    thought=None,
+                    tool_name="",
+                    tool_input="",
+                    thought="",
                     tool_invoke_meta={
                         tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
                     },

@@ -283,7 +290,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         tool_response["tool_call_name"]: tool_response["tool_response"]
                         for tool_response in tool_responses
                     },
-                    answer=None,
+                    answer="",
                    messages_ids=message_file_ids,
                 )
                 self.queue_manager.publish(

@@ -296,7 +303,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):

             iteration_step += 1

-        self.update_db_variables(self.variables_pool, self.db_variables_pool)
+        if self.variables_pool and self.db_variables_pool:
+            self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
         self.queue_manager.publish(
             QueueMessageEndEvent(

@@ -389,9 +397,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
             prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))

-        return prompt_messages
+        return prompt_messages or []

-    def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+    def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
         """
         Organize user query
         """

@@ -449,7 +457,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
     def _organize_prompt_messages(self):
         prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
         self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
-        query_prompt_messages = self._organize_user_query(self.query, [])
+        query_prompt_messages = self._organize_user_query(self.query or "", [])

         self.history_prompt_messages = AgentHistoryPromptTransform(
             model_config=self.model_config,
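`invoke_llm` can return either a complete `LLMResult` or a streaming generator, and the original `if self.stream_tool_call: ... else:` silently assumed the flag and the runtime type always agreed. The new `isinstance` checks narrow the union in each branch, and the trailing `raise` keeps the dispatch total. The same shape in miniature (standard library only, illustrative names):

```python
from collections.abc import Generator
from typing import Union

def invoke(stream: bool) -> Union[str, Generator[str, None, None]]:
    if stream:
        return (chunk for chunk in ("a", "b"))
    return "ab"

def handle(result: Union[str, Generator[str, None, None]], stream: bool) -> str:
    # Each branch checks both the flag and the runtime type, so mypy can
    # narrow the union; the else keeps mismatches loud instead of silent.
    if stream and isinstance(result, Generator):
        return "".join(result)
    elif not stream and isinstance(result, str):
        return result
    else:
        raise RuntimeError(f"invalid result type: {type(result)}")
```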
@@ -38,7 +38,7 @@ class CotAgentOutputParser:
             except:
                 return json_str or ""

-        def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
+        def extra_json_from_code_block(code_block) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
             code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL)
             if not code_blocks:
                 return

@@ -67,15 +67,15 @@ class CotAgentOutputParser:
         for response in llm_response:
             if response.delta.usage:
                 usage_dict["usage"] = response.delta.usage
-            response = response.delta.message.content
-            if not isinstance(response, str):
+            response_content = response.delta.message.content
+            if not isinstance(response_content, str):
                 continue

             # stream
             index = 0
-            while index < len(response):
+            while index < len(response_content):
                 steps = 1
-                delta = response[index : index + steps]
+                delta = response_content[index : index + steps]
                 yield_delta = False

                 if delta == "`":
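The parser fix is about re-binding: the loop variable `response` (a chunk object) was reassigned to `response.delta.message.content` (a string or content list), giving one name two types, which mypy rejects. Introducing `response_content` keeps both names monomorphic. A compact illustration (stand-in `Chunk` class, not the real entity):

```python
class Chunk:
    def __init__(self, content: object) -> None:
        self.content = content

def collect_text(chunks: list[Chunk]) -> str:
    text = ""
    for chunk in chunks:
        # Re-binding `chunk = chunk.content` would change the variable's
        # type mid-loop; a fresh name keeps both consistently typed.
        chunk_content = chunk.content
        if not isinstance(chunk_content, str):
            continue
        text += chunk_content
    return text
```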
@@ -66,6 +66,8 @@ class DatasetConfigManager:
             dataset_configs = config.get("dataset_configs")
         else:
             dataset_configs = {"retrieval_model": "multiple"}
+        if dataset_configs is None:
+            return None
         query_variable = config.get("dataset_query_variable")

         if dataset_configs["retrieval_model"] == "single":
@@ -94,7 +94,7 @@ class ModelConfigManager:
             config["model"]["completion_params"]
         )

-        return config, ["model"]
+        return dict(config), ["model"]

     @classmethod
     def validate_model_completion_params(cls, cp: dict) -> dict:
@@ -7,10 +7,10 @@ class OpeningStatementConfigManager:
         :param config: model config args
         """
         # opening statement
-        opening_statement = config.get("opening_statement")
+        opening_statement = config.get("opening_statement", "")

         # suggested questions
-        suggested_questions_list = config.get("suggested_questions")
+        suggested_questions_list = config.get("suggested_questions", [])

         return opening_statement, suggested_questions_list
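Supplying a default to `dict.get` is the lightest of these fixes: without one, the result is None when the key is absent and every consumer inherits an Optional. A small sketch (hypothetical config keys):

```python
from typing import Any

def read_config(config: dict[str, Any]) -> tuple[str, list[str]]:
    # With no default, .get() yields None for missing keys and the tuple
    # would be Optional-typed; typed defaults keep the values usable as-is.
    opening_statement: str = config.get("opening_statement", "")
    suggested_questions: list[str] = config.get("suggested_questions", [])
    return opening_statement, suggested_questions

print(read_config({}))  # ('', [])
```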
@@ -29,6 +29,7 @@ from factories import file_factory
 from models.account import Account
 from models.model import App, Conversation, EndUser, Message
 from models.workflow import Workflow
+from services.errors.message import MessageNotExistsError

 logger = logging.getLogger(__name__)

@@ -145,7 +146,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                 user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
             ),
             query=query,
-            files=file_objs,
+            files=list(file_objs),
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
             user_id=user.id,
             stream=streaming,

@@ -313,6 +314,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         # get conversation and message
         conversation = self._get_conversation(conversation_id)
         message = self._get_message(message_id)
+        if message is None:
+            raise MessageNotExistsError("Message not exists")

         # chatbot app
         runner = AdvancedChatAppRunner(
@@ -5,6 +5,7 @@ import queue
 import re
 import threading
 from collections.abc import Iterable
+from typing import Optional

 from core.app.entities.queue_entities import (
     MessageQueueMessage,

@@ -15,6 +16,7 @@ from core.app.entities.queue_entities import (
     WorkflowQueueMessage,
 )
 from core.model_manager import ModelInstance, ModelManager
+from core.model_runtime.entities.message_entities import TextPromptMessageContent
 from core.model_runtime.entities.model_entities import ModelType


@@ -71,8 +73,9 @@ class AppGeneratorTTSPublisher:
         if not voice or voice not in values:
             self.voice = self.voices[0].get("value")
         self.MAX_SENTENCE = 2
-        self._last_audio_event = None
-        self._runtime_thread = threading.Thread(target=self._runtime).start()
+        self._last_audio_event: Optional[AudioTrunk] = None
+        # FIXME better way to handle this threading.start
+        threading.Thread(target=self._runtime).start()
         self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)

     def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):

@@ -92,10 +95,21 @@ class AppGeneratorTTSPublisher:
                     future_queue.put(futures_result)
                     break
                 elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
-                    self.msg_text += message.event.chunk.delta.message.content
+                    message_content = message.event.chunk.delta.message.content
+                    if not message_content:
+                        continue
+                    if isinstance(message_content, str):
+                        self.msg_text += message_content
+                    elif isinstance(message_content, list):
+                        for content in message_content:
+                            if not isinstance(content, TextPromptMessageContent):
+                                continue
+                            self.msg_text += content.data
                 elif isinstance(message.event, QueueTextChunkEvent):
                     self.msg_text += message.event.text
                 elif isinstance(message.event, QueueNodeSucceededEvent):
+                    if message.event.outputs is None:
+                        continue
                     self.msg_text += message.event.outputs.get("output", "")
                 self.last_message = message
                 sentence_arr, text_tmp = self._extract_sentence(self.msg_text)

@@ -121,11 +135,10 @@ class AppGeneratorTTSPublisher:
             if self._last_audio_event and self._last_audio_event.status == "finish":
                 if self.executor:
                     self.executor.shutdown(wait=False)
-                return self.last_message
+                return self._last_audio_event
             audio = self._audio_queue.get_nowait()
             if audio and audio.status == "finish":
                 self.executor.shutdown(wait=False)
-                self._runtime_thread = None
             if audio:
                 self._last_audio_event = audio
             return audio
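The TTS publisher hunk handles the fact that a message's `content` is a union: a plain string, a list of typed content parts, or empty. Appending it to `self.msg_text` without checking would concatenate a list onto a str. A standalone sketch of the same dispatch (stand-in `TextContent`, illustrative only):

```python
from collections.abc import Sequence

class TextContent:
    def __init__(self, data: str) -> None:
        self.data = data

def flatten(content: str | Sequence[object] | None) -> str:
    # Each shape of the union is handled before concatenating onto a str.
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    return "".join(part.data for part in content if isinstance(part, TextContent))

print(flatten([TextContent("hello "), 42, TextContent("world")]))  # hello world
```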
@@ -109,18 +109,18 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             ConversationVariable.conversation_id == self.conversation.id,
         )
         with Session(db.engine) as session:
-            conversation_variables = session.scalars(stmt).all()
-            if not conversation_variables:
+            db_conversation_variables = session.scalars(stmt).all()
+            if not db_conversation_variables:
                 # Create conversation variables if they don't exist.
-                conversation_variables = [
+                db_conversation_variables = [
                     ConversationVariable.from_variable(
                         app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
                     )
                     for variable in workflow.conversation_variables
                 ]
-                session.add_all(conversation_variables)
+                session.add_all(db_conversation_variables)
             # Convert database entities to variables.
-            conversation_variables = [item.to_variable() for item in conversation_variables]
+            conversation_variables = [item.to_variable() for item in db_conversation_variables]

             session.commit()
@@ -2,6 +2,7 @@ import json
 import logging
 import time
+from collections.abc import Generator, Mapping
 from threading import Thread
 from typing import Any, Optional, Union

 from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME

@@ -64,6 +65,7 @@ from models.enums import CreatedByRole
 from models.workflow import (
     Workflow,
     WorkflowNodeExecution,
+    WorkflowRun,
     WorkflowRunStatus,
 )

@@ -81,6 +83,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
     _user: Union[Account, EndUser]
     _workflow_system_variables: dict[SystemVariableKey, Any]
     _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
+    _conversation_name_generate_thread: Optional[Thread] = None

     def __init__(
         self,

@@ -131,7 +134,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         self._conversation_name_generate_thread = None
         self._recorded_files: list[Mapping[str, Any]] = []

-    def process(self):
+    def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
         """
         Process generate task pipeline.
         :return:

@@ -262,8 +265,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         :return:
         """
         # init fake graph runtime state
-        graph_runtime_state = None
-        workflow_run = None
+        graph_runtime_state: Optional[GraphRuntimeState] = None
+        workflow_run: Optional[WorkflowRun] = None

         for queue_message in self._queue_manager.listen():
             event = queue_message.event

@@ -315,14 +318,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

                 workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)

-                response = self._workflow_node_start_to_stream_response(
+                response_start = self._workflow_node_start_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution,
                 )

-                if response:
-                    yield response
+                if response_start:
+                    yield response_start
             elif isinstance(event, QueueNodeSucceededEvent):
                 workflow_node_execution = self._handle_workflow_node_execution_success(event)

@@ -330,18 +333,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if event.node_type in [NodeType.ANSWER, NodeType.END]:
                     self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))

-                response = self._workflow_node_finish_to_stream_response(
+                response_finish = self._workflow_node_finish_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution,
                 )

-                if response:
-                    yield response
+                if response_finish:
+                    yield response_finish
             elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
                 workflow_node_execution = self._handle_workflow_node_execution_failed(event)

-                response = self._workflow_node_finish_to_stream_response(
+                response_finish = self._workflow_node_finish_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution,

@@ -609,7 +612,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             del extras["metadata"]["annotation_reply"]

         return MessageEndStreamResponse(
-            task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras
+            task_id=self._application_generate_entity.task_id,
+            id=self._message.id,
+            files=self._recorded_files,
+            metadata=extras.get("metadata", {}),
         )

     def _handle_output_moderation_chunk(self, text: str) -> bool:
@@ -61,7 +61,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
             app_model_config_dict = app_model_config.to_dict()
             config_dict = app_model_config_dict.copy()
         else:
-            config_dict = override_config_dict
+            config_dict = override_config_dict or {}

         app_mode = AppMode.value_of(app_model.mode)
         app_config = AgentChatAppConfig(
@@ -23,6 +23,7 @@ from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
 from factories import file_factory
 from models import Account, App, EndUser
+from services.errors.message import MessageNotExistsError

 logger = logging.getLogger(__name__)

@@ -97,7 +98,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
         # get conversation
         conversation = None
         if args.get("conversation_id"):
-            conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
+            conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user)

         # get app model config
         app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)

@@ -153,7 +154,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
                 user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
             ),
             query=query,
-            files=file_objs,
+            files=list(file_objs),
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
             user_id=user.id,
             stream=streaming,

@@ -180,7 +181,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
         worker_thread = threading.Thread(
             target=self._generate_worker,
             kwargs={
-                "flask_app": current_app._get_current_object(),
+                "flask_app": current_app._get_current_object(),  # type: ignore
                 "application_generate_entity": application_generate_entity,
                 "queue_manager": queue_manager,
                 "conversation_id": conversation.id,

@@ -199,8 +200,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             user=user,
             stream=streaming,
         )

-        return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
+        # FIXME: Type hinting issue here, ignore it for now, will fix it later
+        return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)  # type: ignore

     def _generate_worker(
         self,

@@ -224,6 +225,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
         # get conversation and message
         conversation = self._get_conversation(conversation_id)
         message = self._get_message(message_id)
+        if message is None:
+            raise MessageNotExistsError("Message not exists")

         # chatbot app
         runner = AgentChatAppRunner()
@@ -173,6 +173,8 @@ class AgentChatAppRunner(AppRunner):
             return

         agent_entity = app_config.agent
+        if not agent_entity:
+            raise ValueError("Agent entity not found")

         # load tool variables
         tool_conversation_variables = self._load_tool_variables(

@@ -200,14 +202,21 @@ class AgentChatAppRunner(AppRunner):
         # change function call strategy based on LLM model
         llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
         model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
+        if not model_schema or not model_schema.features:
+            raise ValueError("Model schema not found")

         if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
             agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING

-        conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
-        message = db.session.query(Message).filter(Message.id == message.id).first()
+        conversation_result = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
+        if conversation_result is None:
+            raise ValueError("Conversation not found")
+        message_result = db.session.query(Message).filter(Message.id == message.id).first()
+        if message_result is None:
+            raise ValueError("Message not found")
         db.session.close()

+        runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
         # start agent runner
         if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
             # check LLM mode

@@ -225,12 +234,12 @@ class AgentChatAppRunner(AppRunner):
         runner = runner_cls(
             tenant_id=app_config.tenant_id,
             application_generate_entity=application_generate_entity,
-            conversation=conversation,
+            conversation=conversation_result,
             app_config=app_config,
             model_config=application_generate_entity.model_conf,
             config=agent_entity,
             queue_manager=queue_manager,
-            message=message,
+            message=message_result,
             user_id=application_generate_entity.user_id,
             memory=memory,
             prompt_messages=prompt_message,

@@ -257,7 +266,7 @@ class AgentChatAppRunner(AppRunner):
         """
         load tool variables from database
         """
-        tool_variables: ToolConversationVariables = (
+        tool_variables: ToolConversationVariables | None = (
             db.session.query(ToolConversationVariables)
             .filter(
                 ToolConversationVariables.conversation_id == conversation_id,
@@ -16,7 +16,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
     _blocking_response_type = ChatbotAppBlockingResponse

     @classmethod
-    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
+    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:  # type: ignore[override]
         """
         Convert blocking full response.
         :param blocking_response: blocking response

@@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         return response

     @classmethod
-    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
+    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:  # type: ignore[override]
         """
         Convert blocking simple response.
         :param blocking_response: blocking response

@@ -51,8 +51,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         return response

     @classmethod
-    def convert_stream_full_response(
-        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+    def convert_stream_full_response(  # type: ignore[override]
+        cls,
+        stream_response: Generator[ChatbotAppStreamResponse, None, None],
     ) -> Generator[str, None, None]:
         """
         Convert stream full response.

@@ -82,8 +83,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         yield json.dumps(response_chunk)

     @classmethod
-    def convert_stream_simple_response(
-        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+    def convert_stream_simple_response(  # type: ignore[override]
+        cls,
+        stream_response: Generator[ChatbotAppStreamResponse, None, None],
     ) -> Generator[str, None, None]:
         """
         Convert stream simple response.
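The converters override base-class methods with narrower parameter types (`ChatbotAppBlockingResponse` instead of the base response type), which violates the Liskov substitution principle, so mypy flags every override. The diff opts for `# type: ignore[override]` at each site rather than reworking the hierarchy. In miniature:

```python
class Base:
    @classmethod
    def convert(cls, response: object) -> dict:
        return {"data": response}

class Chatbot(Base):
    # Narrowing a parameter type in an override is unsound in general
    # (a Base-typed caller could pass any object), hence the mypy error
    # that the targeted ignore suppresses.
    @classmethod
    def convert(cls, response: str) -> dict:  # type: ignore[override]
        return {"data": response.upper()}
```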
@@ -50,7 +50,7 @@ class AppQueueManager:
         # wait for APP_MAX_EXECUTION_TIME seconds to stop listen
         listen_timeout = dify_config.APP_MAX_EXECUTION_TIME
         start_time = time.time()
-        last_ping_time = 0
+        last_ping_time: int | float = 0
         while True:
             try:
                 message = self._q.get(timeout=1)
@@ -1,5 +1,5 @@
 import time
-from collections.abc import Generator, Mapping
+from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Optional, Union

 from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity

@@ -36,8 +36,8 @@ class AppRunner:
         app_record: App,
         model_config: ModelConfigWithCredentialsEntity,
         prompt_template_entity: PromptTemplateEntity,
-        inputs: dict[str, str],
-        files: list["File"],
+        inputs: Mapping[str, str],
+        files: Sequence["File"],
         query: Optional[str] = None,
     ) -> int:
         """

@@ -64,7 +64,7 @@ class AppRunner:
         ):
             max_tokens = (
                 model_config.parameters.get(parameter_rule.name)
-                or model_config.parameters.get(parameter_rule.use_template)
+                or model_config.parameters.get(parameter_rule.use_template or "")
             ) or 0

         if model_context_tokens is None:

@@ -85,7 +85,7 @@ class AppRunner:

         prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)

-        rest_tokens = model_context_tokens - max_tokens - prompt_tokens
+        rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens
         if rest_tokens < 0:
             raise InvokeBadRequestError(
                 "Query or prefix prompt is too long, you can reduce the prefix prompt, "

@@ -111,7 +111,7 @@ class AppRunner:
         ):
             max_tokens = (
                 model_config.parameters.get(parameter_rule.name)
-                or model_config.parameters.get(parameter_rule.use_template)
+                or model_config.parameters.get(parameter_rule.use_template or "")
             ) or 0

         if model_context_tokens is None:

@@ -136,8 +136,8 @@ class AppRunner:
         app_record: App,
         model_config: ModelConfigWithCredentialsEntity,
         prompt_template_entity: PromptTemplateEntity,
-        inputs: dict[str, str],
-        files: list["File"],
+        inputs: Mapping[str, str],
+        files: Sequence["File"],
         query: Optional[str] = None,
         context: Optional[str] = None,
         memory: Optional[TokenBufferMemory] = None,

@@ -156,6 +156,7 @@ class AppRunner:
         """
         # get prompt without memory and context
         if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
+            prompt_transform: Union[SimplePromptTransform, AdvancedPromptTransform]
             prompt_transform = SimplePromptTransform()
             prompt_messages, stop = prompt_transform.get_prompt(
                 app_mode=AppMode.value_of(app_record.mode),

@@ -171,8 +172,11 @@ class AppRunner:
                 memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))

             model_mode = ModelMode.value_of(model_config.mode)
+            prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
             if model_mode == ModelMode.COMPLETION:
                 advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
+                if not advanced_completion_prompt_template:
+                    raise InvokeBadRequestError("Advanced completion prompt template is required.")
                 prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt)

                 if advanced_completion_prompt_template.role_prefix:

@@ -181,6 +185,8 @@ class AppRunner:
                         assistant=advanced_completion_prompt_template.role_prefix.assistant,
                     )
             else:
+                if not prompt_template_entity.advanced_chat_prompt_template:
+                    raise InvokeBadRequestError("Advanced chat prompt template is required.")
                 prompt_template = []
                 for message in prompt_template_entity.advanced_chat_prompt_template.messages:
                     prompt_template.append(ChatModelMessage(text=message.text, role=message.role))

@@ -246,7 +252,7 @@ class AppRunner:

     def _handle_invoke_result(
         self,
-        invoke_result: Union[LLMResult, Generator],
+        invoke_result: Union[LLMResult, Generator[Any, None, None]],
         queue_manager: AppQueueManager,
         stream: bool,
         agent: bool = False,

@@ -259,10 +265,12 @@ class AppRunner:
         :param agent: agent
         :return:
         """
-        if not stream:
+        if not stream and isinstance(invoke_result, LLMResult):
            self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
-        else:
+        elif stream and isinstance(invoke_result, Generator):
            self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
+        else:
+            raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")

     def _handle_invoke_result_direct(
         self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool

@@ -291,8 +299,8 @@ class AppRunner:
         :param agent: agent
         :return:
         """
-        model = None
-        prompt_messages = []
+        model: str = ""
+        prompt_messages: list[PromptMessage] = []
         text = ""
         usage = None
         for result in invoke_result:

@@ -328,13 +336,14 @@ class AppRunner:

     def moderation_for_inputs(
         self,
         *,
         app_id: str,
         tenant_id: str,
         app_generate_entity: AppGenerateEntity,
         inputs: Mapping[str, Any],
-        query: str,
+        query: str | None = None,
         message_id: str,
-    ) -> tuple[bool, dict, str]:
+    ) -> tuple[bool, Mapping[str, Any], str]:
         """
         Process sensitive_word_avoidance.
         :param app_id: app id

@@ -350,7 +359,7 @@ class AppRunner:
             app_id=app_id,
             tenant_id=tenant_id,
             app_config=app_generate_entity.app_config,
-            inputs=inputs,
+            inputs=dict(inputs),
             query=query or "",
             message_id=message_id,
             trace_manager=app_generate_entity.trace_manager,

@@ -390,9 +399,9 @@ class AppRunner:
         tenant_id: str,
         app_id: str,
         external_data_tools: list[ExternalDataVariableEntity],
-        inputs: dict,
+        inputs: Mapping[str, Any],
         query: str,
-    ) -> dict:
+    ) -> Mapping[str, Any]:
         """
         Fill in variable inputs from external data tools if exists.
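Switching parameters from `dict[str, str]`/`list["File"]` to `Mapping[str, str]`/`Sequence["File"]` declares them read-only: the invariant `dict`/`list` types reject useful arguments (a `Mapping` subclass, a tuple), while the abstract types accept anything with the right interface. A short sketch:

```python
from collections.abc import Mapping, Sequence

def render(inputs: Mapping[str, str], files: Sequence[str]) -> str:
    # Mapping/Sequence advertise read-only use, so both a dict and a
    # tuple (or any compatible container) are accepted without casts.
    pairs = ",".join(f"{k}={v}" for k, v in inputs.items())
    return ",".join(files) + ":" + pairs

print(render({"a": "1"}, ("x", "y")))  # x,y:a=1
```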
@@ -24,6 +24,7 @@ from extensions.ext_database import db
 from factories import file_factory
 from models.account import Account
 from models.model import App, EndUser
+from services.errors.message import MessageNotExistsError

 logger = logging.getLogger(__name__)

@@ -91,7 +92,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
         # get conversation
         conversation = None
         if args.get("conversation_id"):
-            conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
+            conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user)

         # get app model config
         app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)

@@ -104,7 +105,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):

         # validate config
         override_model_config_dict = ChatAppConfigManager.config_validate(
-            tenant_id=app_model.tenant_id, config=args.get("model_config")
+            tenant_id=app_model.tenant_id, config=args.get("model_config", {})
         )

         # always enable retriever resource in debugger mode

@@ -146,7 +147,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
                 user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
             ),
             query=query,
-            files=file_objs,
+            files=list(file_objs),
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
             user_id=user.id,
             invoke_from=invoke_from,

@@ -172,7 +173,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
         worker_thread = threading.Thread(
             target=self._generate_worker,
             kwargs={
-                "flask_app": current_app._get_current_object(),
+                "flask_app": current_app._get_current_object(),  # type: ignore
                 "application_generate_entity": application_generate_entity,
                 "queue_manager": queue_manager,
                 "conversation_id": conversation.id,

@@ -216,6 +217,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
         # get conversation and message
         conversation = self._get_conversation(conversation_id)
         message = self._get_message(message_id)
+        if message is None:
+            raise MessageNotExistsError("Message not exists")

         # chatbot app
         runner = ChatAppRunner()
@@ -16,7 +16,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
     _blocking_response_type = ChatbotAppBlockingResponse

     @classmethod
-    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
+    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:  # type: ignore[override]
         """
         Convert blocking full response.
         :param blocking_response: blocking response

@@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         return response

     @classmethod
-    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
+    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:  # type: ignore[override]
         """
         Convert blocking simple response.
         :param blocking_response: blocking response

@@ -52,7 +52,8 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):

     @classmethod
     def convert_stream_full_response(
-        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+        cls,
+        stream_response: Generator[ChatbotAppStreamResponse, None, None],  # type: ignore[override]
     ) -> Generator[str, None, None]:
         """
         Convert stream full response.

@@ -83,7 +84,8 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):

     @classmethod
     def convert_stream_simple_response(
-        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+        cls,
+        stream_response: Generator[ChatbotAppStreamResponse, None, None],  # type: ignore[override]
     ) -> Generator[str, None, None]:
         """
         Convert stream simple response.
@@ -42,7 +42,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
             app_model_config_dict = app_model_config.to_dict()
             config_dict = app_model_config_dict.copy()
         else:
-            config_dict = override_config_dict
+            config_dict = override_config_dict or {}

         app_mode = AppMode.value_of(app_model.mode)
         app_config = CompletionAppConfig(
@@ -83,8 +83,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
query = query.replace("\x00", "")
inputs = args["inputs"]

extras = {}

# get conversation
conversation = None

@@ -99,7 +97,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):

# validate config
override_model_config_dict = CompletionAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=args.get("model_config")
tenant_id=app_model.tenant_id, config=args.get("model_config", {})
)

# parse files
@@ -132,11 +130,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=file_objs,
files=list(file_objs),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
extras=extras,
extras={},
trace_manager=trace_manager,
)

@@ -157,7 +155,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"message_id": message.id,
@@ -197,6 +195,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
try:
# get message
message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError()

# chatbot app
runner = CompletionAppRunner()
@@ -231,7 +231,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[str, None, None]]:
) -> Union[Mapping[str, Any], Generator[str, None, None]]:
"""
Generate App response.

@@ -293,7 +293,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
model_conf=ModelConfigConverter.convert(app_config),
inputs=message.inputs,
query=message.query,
files=file_objs,
files=list(file_objs),
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
@@ -317,7 +317,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"message_id": message.id,

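Note: the return type change from `dict` to `Mapping[str, Any]` above trades mutability for flexibility: any concrete mapping satisfies the annotation, and callers get a read-only view at the type level. A minimal sketch (the function name is illustrative):

    from collections.abc import Mapping
    from typing import Any

    def generate() -> Mapping[str, Any]:
        # a plain dict still satisfies Mapping
        return {"answer": "ok"}

    result = generate()
    # result["answer"] = "edited"  # rejected by mypy: Mapping has no __setitem__
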
@@ -76,7 +76,7 @@ class CompletionAppRunner(AppRunner):
tenant_id=app_config.tenant_id,
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
query=query or "",
message_id=message.id,
)
except ModerationError as e:
@@ -122,7 +122,7 @@ class CompletionAppRunner(AppRunner):
tenant_id=app_record.tenant_id,
model_config=application_generate_entity.model_conf,
config=dataset_config,
query=query,
query=query or "",
invoke_from=application_generate_entity.invoke_from,
show_retrieve_source=app_config.additional_features.show_retrieve_source,
hit_callback=hit_callback,

@@ -16,7 +16,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = CompletionAppBlockingResponse

@classmethod
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict:
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
return response

@classmethod
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict:
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -51,7 +51,8 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):

@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
cls,
stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream full response.
@@ -81,7 +82,8 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):

@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
cls,
stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream simple response.

@@ -2,11 +2,11 @@ import json
import logging
from collections.abc import Generator
from datetime import UTC, datetime
from typing import Optional, Union
from typing import Optional, Union, cast

from sqlalchemy import and_

from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import (
@@ -42,7 +42,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity,
],
queue_manager: AppQueueManager,
conversation: Conversation,
@@ -144,7 +144,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:conversation conversation
:return:
"""
app_config = application_generate_entity.app_config
app_config: EasyUIBasedAppConfig = cast(EasyUIBasedAppConfig, application_generate_entity.app_config)

# get from source
end_user_id = None
@@ -267,7 +267,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
except KeyError:
pass

return introduction
return introduction or ""

def _get_conversation(self, conversation_id: str):
"""
@@ -282,7 +282,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):

return conversation

def _get_message(self, message_id: str) -> Message:
def _get_message(self, message_id: str) -> Optional[Message]:
"""
Get message by message id
:param message_id: message id

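Note: `_get_message` now honestly declares `Optional[Message]`, and the call sites added earlier in this commit (`if message is None: raise MessageNotExistsError(...)`) narrow the value before use. A sketch of the pattern, reusing names from the diff:

    from typing import Optional

    def get_message(message_id: str) -> Optional["Message"]:
        # .first() can return None, so the signature now says so
        return db.session.query(Message).filter(Message.id == message_id).first()

    message = get_message("some-id")
    if message is None:
        raise MessageNotExistsError("Message not exists")
    message.id  # mypy has narrowed `message` to Message here
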
@@ -116,7 +116,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
files=system_files,
files=list(system_files),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,

@@ -17,16 +17,16 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse

@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
return blocking_response.to_dict()
return dict(blocking_response.to_dict())

@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -36,7 +36,8 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):

@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
cls,
stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream full response.
@@ -65,7 +66,8 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):

@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
cls,
stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream simple response.

@@ -24,6 +24,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
@@ -190,16 +191,15 @@ class WorkflowBasedAppRunner(AppRunner):
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, NodeRunRetryEvent):
node_run_result = event.route_node_state.node_run_result
inputs: Mapping[str, Any] | None = {}
process_data: Mapping[str, Any] | None = {}
outputs: Mapping[str, Any] | None = {}
execution_metadata: Mapping[NodeRunMetadataKey, Any] | None = {}
if node_run_result:
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
else:
inputs = {}
process_data = {}
outputs = {}
execution_metadata = {}
self._publish_event(
QueueNodeRetryEvent(
node_execution_id=event.id,
@@ -289,7 +289,7 @@ class WorkflowBasedAppRunner(AppRunner):
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
@@ -349,7 +349,7 @@ class WorkflowBasedAppRunner(AppRunner):
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata

@@ -5,7 +5,7 @@ from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator

from constants import UUID_NIL
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
from core.file import File, FileUploadConfig
from core.model_runtime.entities.model_entities import AIModelEntity
@@ -79,7 +79,7 @@ class AppGenerateEntity(BaseModel):
task_id: str

# app config
app_config: AppConfig
app_config: Any
file_upload_config: Optional[FileUploadConfig] = None

inputs: Mapping[str, Any]

@@ -308,7 +308,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None

error: Optional[str] = None
"""single iteration duration map"""

@@ -70,7 +70,7 @@ class StreamResponse(BaseModel):
event: StreamEvent
task_id: str

def to_dict(self) -> dict:
def to_dict(self):
return jsonable_encoder(self)


@@ -474,8 +474,8 @@ class IterationNodeStartStreamResponse(StreamResponse):
title: str
created_at: int
extras: dict = {}
metadata: dict = {}
inputs: dict = {}
metadata: Mapping = {}
inputs: Mapping = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None

@@ -526,15 +526,15 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
node_id: str
node_type: str
title: str
outputs: Optional[dict] = None
outputs: Optional[Mapping] = None
created_at: int
extras: Optional[dict] = None
inputs: Optional[dict] = None
inputs: Optional[Mapping] = None
status: WorkflowNodeExecutionStatus
error: Optional[str] = None
elapsed_time: float
total_tokens: int
execution_metadata: Optional[dict] = None
execution_metadata: Optional[Mapping] = None
finished_at: int
steps: int
parallel_id: Optional[str] = None
@@ -628,7 +628,7 @@ class AppBlockingResponse(BaseModel):

task_id: str

def to_dict(self) -> dict:
def to_dict(self):
return jsonable_encoder(self)


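Note: switching these Pydantic fields from `dict` to `Mapping` keeps plain dicts valid as inputs while advertising that consumers should not mutate the payloads. An illustrative model, not the real class:

    from collections.abc import Mapping
    from typing import Any, Optional
    from pydantic import BaseModel

    class NodeEvent(BaseModel):
        # Mapping signals read-only intent while still accepting plain dicts
        inputs: Optional[Mapping[str, Any]] = None
        outputs: Optional[Mapping[str, Any]] = None

    event = NodeEvent(inputs={"query": "hi"})
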
@@ -58,7 +58,7 @@ class AnnotationReplyFeature:
query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]}
)

if documents:
if documents and documents[0].metadata:
annotation_id = documents[0].metadata["annotation_id"]
score = documents[0].metadata["score"]
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)

@@ -17,7 +17,7 @@ class RateLimit:
_UNLIMITED_REQUEST_ID = "unlimited_request_id"
_REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
_instance_dict = {}
_instance_dict: dict[str, "RateLimit"] = {}

def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
if client_id not in cls._instance_dict:

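Note: an empty `{}` gives mypy nothing to infer, so the class attribute gets an explicit annotation; the quoted "RateLimit" is a forward reference, required because the class is still being defined at that point. A minimal sketch with an illustrative class:

    class Registry:
        # explicit value type for an empty literal; the string is a forward
        # reference to the class still being defined
        _instances: dict[str, "Registry"] = {}

        def __new__(cls, key: str) -> "Registry":
            if key not in cls._instances:
                cls._instances[key] = super().__new__(cls)
            return cls._instances[key]
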
@@ -62,6 +62,7 @@ class BasedGenerateTaskPipeline:
"""
logger.debug("error: %s", event.error)
e = event.error
err: Exception

if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided")
@@ -130,6 +131,7 @@ class BasedGenerateTaskPipeline:
rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config),
queue_manager=self._queue_manager,
)
return None

def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
"""

@@ -2,6 +2,7 @@ import json
import logging
import time
from collections.abc import Generator
from threading import Thread
from typing import Optional, Union, cast

from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@@ -103,7 +104,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
)
)

self._conversation_name_generate_thread = None
self._conversation_name_generate_thread: Optional[Thread] = None

def process(
self,
@@ -123,7 +124,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation, self._application_generate_entity.query
self._conversation, self._application_generate_entity.query or ""
)

generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
@@ -146,7 +147,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata

response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation.mode == AppMode.COMPLETION.value:
response = CompletionAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
@@ -154,7 +155,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
id=self._message.id,
mode=self._conversation.mode,
message_id=self._message.id,
answer=self._task_state.llm_result.message.content,
answer=cast(str, self._task_state.llm_result.message.content),
created_at=int(self._message.created_at.timestamp()),
**extras,
),
@@ -167,7 +168,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
mode=self._conversation.mode,
conversation_id=self._conversation.id,
message_id=self._message.id,
answer=self._task_state.llm_result.message.content,
answer=cast(str, self._task_state.llm_result.message.content),
created_at=int(self._message.created_at.timestamp()),
**extras,
),
@@ -252,7 +253,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)

def _process_stream_response(
self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None
self, publisher: Optional[AppGeneratorTTSPublisher], trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
@@ -269,13 +270,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
break
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
if isinstance(event, QueueMessageEndEvent):
self._task_state.llm_result = event.llm_result
if event.llm_result:
self._task_state.llm_result = event.llm_result
else:
self._handle_stop(event)

# handle output moderation
output_moderation_answer = self._handle_output_moderation_when_task_finished(
self._task_state.llm_result.message.content
cast(str, self._task_state.llm_result.message.content)
)
if output_moderation_answer:
self._task_state.llm_result.message.content = output_moderation_answer
@@ -292,7 +294,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if annotation:
self._task_state.llm_result.message.content = annotation.content
elif isinstance(event, QueueAgentThoughtEvent):
yield self._agent_thought_to_stream_response(event)
agent_thought_response = self._agent_thought_to_stream_response(event)
if agent_thought_response is not None:
yield agent_thought_response
elif isinstance(event, QueueMessageFileEvent):
response = self._message_file_to_stream_response(event)
if response:
@@ -307,16 +311,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._task_state.llm_result.prompt_messages = chunk.prompt_messages

# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(delta_text)
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
if should_direct_answer:
continue

self._task_state.llm_result.message.content += delta_text
current_content = cast(str, self._task_state.llm_result.message.content)
current_content += cast(str, delta_text)
self._task_state.llm_result.message.content = current_content

if isinstance(event, QueueLLMChunkEvent):
yield self._message_to_stream_response(delta_text, self._message.id)
yield self._message_to_stream_response(cast(str, delta_text), self._message.id)
else:
yield self._agent_message_to_stream_response(delta_text, self._message.id)
yield self._agent_message_to_stream_response(cast(str, delta_text), self._message.id)
elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
@@ -336,8 +342,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
llm_result = self._task_state.llm_result
usage = llm_result.usage

self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
message = db.session.query(Message).filter(Message.id == self._message.id).first()
if not message:
raise Exception(f"Message {self._message.id} not found")
self._message = message
conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
if not conversation:
raise Exception(f"Conversation {self._conversation.id} not found")
self._conversation = conversation

self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
self._model_config.mode, self._task_state.llm_result.prompt_messages
@@ -346,7 +358,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit
self._message.answer = (
PromptTemplateParser.remove_template_variables(llm_result.message.content.strip())
PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
if llm_result.message.content
else ""
)
@@ -374,6 +386,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT}
and hasattr(self._application_generate_entity, "conversation_id")
and self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras,
)
@@ -420,7 +433,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
extras["metadata"] = self._task_state.metadata

return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
task_id=self._application_generate_entity.task_id,
id=self._message.id,
metadata=extras.get("metadata", {}),
)

def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
@@ -440,7 +455,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
:param event: agent thought event
:return:
"""
agent_thought: MessageAgentThought = (
agent_thought: Optional[MessageAgentThought] = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
)
db.session.refresh(agent_thought)

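Note: `PromptMessage.content` is a union (plain text or a list of multimodal content parts), so every place this pipeline treats it as text now narrows it with `cast(str, ...)`. `cast` is an assertion for mypy only and performs no runtime check. A minimal sketch with a simplified stand-in type:

    from typing import Union, cast

    Content = Union[str, list[dict]]  # simplified stand-in for the real union

    def save_answer(content: Content) -> str:
        # this code path only ever sees text content; cast() tells mypy so,
        # without validating anything at runtime
        return cast(str, content).strip()
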
@@ -128,7 +128,7 @@ class MessageCycleManage:
"""
message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()

if message_file:
if message_file and message_file.url is not None:
# get tool file id
tool_file_id = message_file.url.split("/")[-1]
# trim extension

@@ -93,7 +93,7 @@ class WorkflowCycleManage:
)

# handle special values
inputs = WorkflowEntry.handle_special_values(inputs)
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})

# init workflow run
with Session(db.engine, expire_on_commit=False) as session:
@@ -192,7 +192,7 @@ class WorkflowCycleManage:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)

outputs = WorkflowEntry.handle_special_values(outputs)
outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)

workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
workflow_run.outputs = json.dumps(outputs or {})
@@ -500,7 +500,7 @@ class WorkflowCycleManage:
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
inputs=workflow_run.inputs_dict,
inputs=dict(workflow_run.inputs_dict or {}),
created_at=int(workflow_run.created_at.timestamp()),
),
)
@@ -545,7 +545,7 @@ class WorkflowCycleManage:
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
status=workflow_run.status,
outputs=workflow_run.outputs_dict,
outputs=dict(workflow_run.outputs_dict) if workflow_run.outputs_dict else None,
error=workflow_run.error,
elapsed_time=workflow_run.elapsed_time,
total_tokens=workflow_run.total_tokens,
@@ -553,7 +553,7 @@ class WorkflowCycleManage:
created_by=created_by,
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)),
exceptions_count=workflow_run.exceptions_count,
),
)
@@ -655,7 +655,7 @@ class WorkflowCycleManage:
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
"""
Workflow node finish to stream response.
:param event: queue node succeeded or failed event
@@ -838,7 +838,7 @@ class WorkflowCycleManage:
),
)

def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]:
def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from node outputs
:param outputs_dict: node outputs dict
@@ -851,9 +851,11 @@ class WorkflowCycleManage:
# Remove None
files = [file for file in files if file]
# Flatten list
files = [file for sublist in files for file in sublist]
# Flatten the list of sequences into a single list of mappings
flattened_files = [file for sublist in files if sublist for file in sublist]

return files
# Convert to tuple to match Sequence type
return tuple(flattened_files)

def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
"""
@@ -891,6 +893,8 @@ class WorkflowCycleManage:
elif isinstance(value, File):
return value.to_dict()

return None

def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
"""
Refetch workflow run

@@ -57,7 +57,7 @@ class DifyAgentCallbackHandler(BaseModel):
self,
tool_name: str,
tool_inputs: Mapping[str, Any],
tool_outputs: Sequence[ToolInvokeMessage],
tool_outputs: Sequence[ToolInvokeMessage] | str,
message_id: Optional[str] = None,
timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None,

@@ -40,17 +40,18 @@ class DatasetIndexToolCallbackHandler:
def on_tool_end(self, documents: list[Document]) -> None:
"""Handle tool end."""
for document in documents:
query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
if document.metadata is not None:
query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)

if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])

# add hit count to document segment
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
# add hit count to document segment
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)

db.session.commit()
db.session.commit()

def return_retriever_resource_info(self, resource: list):
"""Handle return_retriever_resource_info."""

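Note: `Document.metadata` is declared `Optional`, so the handler above wraps its whole body in `if document.metadata is not None:` instead of indexing into a possibly-None mapping. A minimal sketch of the narrowing:

    from typing import Optional

    def record_hit(metadata: Optional[dict]) -> None:
        if metadata is None:
            return  # nothing to record
        # inside this branch mypy knows metadata is a dict
        doc_id = metadata["doc_id"]
        print(f"hit on {doc_id}")
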
@@ -1,3 +1,4 @@
from collections.abc import Sequence
from enum import Enum
from typing import Optional

@@ -72,7 +73,7 @@ class DefaultModelProviderEntity(BaseModel):
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
supported_model_types: list[ModelType]
supported_model_types: Sequence[ModelType] = []


class DefaultModelEntity(BaseModel):

@@ -40,7 +40,7 @@ from models.provider import (

logger = logging.getLogger(__name__)

original_provider_configurate_methods = {}
original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {}


class ProviderConfiguration(BaseModel):
@@ -99,7 +99,8 @@ class ProviderConfiguration(BaseModel):
continue

restrict_models = quota_configuration.restrict_models

if self.system_configuration.credentials is None:
return None
copy_credentials = self.system_configuration.credentials.copy()
if restrict_models:
for restrict_model in restrict_models:
@@ -124,7 +125,7 @@ class ProviderConfiguration(BaseModel):

return credentials

def get_system_configuration_status(self) -> SystemConfigurationStatus:
def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]:
"""
Get system configuration status.
:return:
@@ -136,6 +137,8 @@ class ProviderConfiguration(BaseModel):
current_quota_configuration = next(
(q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
)
if current_quota_configuration is None:
return None

return (
SystemConfigurationStatus.ACTIVE
@@ -150,7 +153,7 @@ class ProviderConfiguration(BaseModel):
"""
return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0

def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
def get_custom_credentials(self, obfuscated: bool = False):
"""
Get custom credentials.

@@ -172,7 +175,7 @@ class ProviderConfiguration(BaseModel):
else [],
)

def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]:
"""
Validate custom credentials.
:param credentials: provider credentials
@@ -324,7 +327,7 @@ class ProviderConfiguration(BaseModel):

def custom_model_credentials_validate(
self, model_type: ModelType, model: str, credentials: dict
) -> tuple[ProviderModel, dict]:
) -> tuple[Optional[ProviderModel], dict]:
"""
Validate custom model credentials.

@@ -740,10 +743,10 @@ class ProviderConfiguration(BaseModel):
if model_type:
model_types.append(model_type)
else:
model_types = provider_instance.get_provider_schema().supported_model_types
model_types = list(provider_instance.get_provider_schema().supported_model_types)

# Group model settings by model type and model
model_setting_map = defaultdict(dict)
model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
for model_setting in self.model_settings:
model_setting_map[model_setting.model_type][model_setting.model] = model_setting

@@ -822,54 +825,57 @@ class ProviderConfiguration(BaseModel):
]:
# only customizable model
for restrict_model in restrict_models:
copy_credentials = self.system_configuration.credentials.copy()
if restrict_model.base_model_name:
copy_credentials["base_model_name"] = restrict_model.base_model_name
if self.system_configuration.credentials is not None:
copy_credentials = self.system_configuration.credentials.copy()
if restrict_model.base_model_name:
copy_credentials["base_model_name"] = restrict_model.base_model_name

try:
custom_model_schema = provider_instance.get_model_instance(
restrict_model.model_type
).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
except Exception as ex:
logger.warning(f"get custom model schema failed, {ex}")
continue
try:
custom_model_schema = provider_instance.get_model_instance(
restrict_model.model_type
).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
except Exception as ex:
logger.warning(f"get custom model schema failed, {ex}")
continue

if not custom_model_schema:
continue
if not custom_model_schema:
continue

if custom_model_schema.model_type not in model_types:
continue
if custom_model_schema.model_type not in model_types:
continue

status = ModelStatus.ACTIVE
if (
custom_model_schema.model_type in model_setting_map
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
):
model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
if model_setting.enabled is False:
status = ModelStatus.DISABLED
status = ModelStatus.ACTIVE
if (
custom_model_schema.model_type in model_setting_map
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
):
model_setting = model_setting_map[custom_model_schema.model_type][
custom_model_schema.model
]
if model_setting.enabled is False:
status = ModelStatus.DISABLED

provider_models.append(
ModelWithProviderEntity(
model=custom_model_schema.model,
label=custom_model_schema.label,
model_type=custom_model_schema.model_type,
features=custom_model_schema.features,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=status,
provider_models.append(
ModelWithProviderEntity(
model=custom_model_schema.model,
label=custom_model_schema.label,
model_type=custom_model_schema.model_type,
features=custom_model_schema.features,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=status,
)
)
)

|
||||
restrict_model_names = [rm.model for rm in restrict_models]
|
||||
for m in provider_models:
|
||||
if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
|
||||
m.status = ModelStatus.NO_PERMISSION
|
||||
for model in provider_models:
|
||||
if model.model_type == ModelType.LLM and m.model not in restrict_model_names:
|
||||
model.status = ModelStatus.NO_PERMISSION
|
||||
elif not quota_configuration.is_valid:
|
||||
m.status = ModelStatus.QUOTA_EXCEEDED
|
||||
model.status = ModelStatus.QUOTA_EXCEEDED
|
||||
|
||||
return provider_models
|
||||
|
||||
@@ -1043,7 +1049,7 @@ class ProviderConfigurations(BaseModel):
|
||||
return iter(self.configurations)
|
||||
|
||||
def values(self) -> Iterator[ProviderConfiguration]:
|
||||
return self.configurations.values()
|
||||
return iter(self.configurations.values())
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self.configurations.get(key, default)
|
||||
|
||||
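Note: `dict.values()` returns a `ValuesView`, not an `Iterator`, so wrapping it in `iter()` is what makes the declared `Iterator[ProviderConfiguration]` return type true. A minimal sketch with illustrative names:

    from collections.abc import Iterator

    configurations: dict[str, int] = {"openai": 1}

    def values() -> Iterator[int]:
        # iter() adapts the ValuesView to the declared return type
        # without copying the underlying data
        return iter(configurations.values())
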
@@ -1,3 +1,5 @@
from typing import cast

import requests

from configs import dify_config
@@ -5,7 +7,7 @@ from models.api_based_extension import APIBasedExtensionPoint


class APIBasedExtensionRequestor:
timeout: (int, int) = (5, 60)
timeout: tuple[int, int] = (5, 60)
"""timeout for request connect and read"""

def __init__(self, api_endpoint: str, api_key: str) -> None:
@@ -51,4 +53,4 @@ class APIBasedExtensionRequestor:
"request error, status_code: {}, content: {}".format(response.status_code, response.text[:100])
)

return response.json()
return cast(dict, response.json())

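Note: `(int, int)` is a value expression (a tuple of two type objects), not a type, so mypy rejects it as an annotation; `tuple[int, int]` is the correct spelling for a fixed-shape pair. A minimal sketch:

    class Requestor:
        # invalid: timeout: (int, int) = (5, 60)  -- a tuple of types, not a type
        timeout: tuple[int, int] = (5, 60)  # (connect, read) timeouts in seconds
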
@@ -38,8 +38,8 @@ class Extensible:

@classmethod
def scan_extensions(cls):
extensions: list[ModuleExtension] = []
position_map = {}
extensions = []
position_map: dict[str, int] = {}

# get the path of the current class
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
@@ -58,7 +58,8 @@ class Extensible:
# is builtin extension, builtin extension
# in the front-end page and business logic, there are special treatments.
builtin = False
position = None
# default position is 0 can not be None for sort_to_dict_by_position_map
position = 0
if "__builtin__" in file_names:
builtin = True

@@ -89,7 +90,7 @@ class Extensible:
logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
continue

json_data = {}
json_data: dict[str, Any] = {}
if not builtin:
if "schema.json" not in file_names:
logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")

@@ -1,4 +1,6 @@
from core.extension.extensible import ExtensionModule, ModuleExtension
from typing import cast

from core.extension.extensible import Extensible, ExtensionModule, ModuleExtension
from core.external_data_tool.base import ExternalDataTool
from core.moderation.base import Moderation

@@ -10,7 +12,8 @@ class Extension:

def init(self):
for module, module_class in self.module_classes.items():
self.__module_extensions[module.value] = module_class.scan_extensions()
m = cast(Extensible, module_class)
self.__module_extensions[module.value] = m.scan_extensions()

def module_extensions(self, module: str) -> list[ModuleExtension]:
module_extensions = self.__module_extensions.get(module)
@@ -35,7 +38,8 @@ class Extension:

def extension_class(self, module: ExtensionModule, extension_name: str) -> type:
module_extension = self.module_extension(module, extension_name)
return module_extension.extension_class
t: type = module_extension.extension_class
return t

def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None:
module_extension = self.module_extension(module, extension_name)

@@ -48,7 +48,10 @@ class ApiExternalDataTool(ExternalDataTool):
:return: the tool query result
"""
# get params from config
if not self.config:
raise ValueError("config is required, config: {}".format(self.config))
api_based_extension_id = self.config.get("api_based_extension_id")
assert api_based_extension_id is not None, "api_based_extension_id is required"

# get api_based_extension
api_based_extension = (

@@ -1,7 +1,7 @@
import concurrent
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from collections.abc import Mapping
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from typing import Any, Optional

from flask import Flask, current_app

@@ -17,9 +17,9 @@ class ExternalDataFetch:
tenant_id: str,
app_id: str,
external_data_tools: list[ExternalDataVariableEntity],
inputs: dict,
inputs: Mapping[str, Any],
query: str,
) -> dict:
) -> Mapping[str, Any]:
"""
Fill in variable inputs from external data tools if exists.

@@ -30,13 +30,14 @@ class ExternalDataFetch:
:param query: the query
:return: the filled inputs
"""
results = {}
results: dict[str, Any] = {}
inputs = dict(inputs)
with ThreadPoolExecutor() as executor:
futures = {}
for tool in external_data_tools:
future = executor.submit(
future: Future[tuple[str | None, str | None]] = executor.submit(
self._query_external_data_tool,
current_app._get_current_object(),
current_app._get_current_object(), # type: ignore
tenant_id,
app_id,
tool,
@@ -46,9 +47,10 @@ class ExternalDataFetch:

futures[future] = tool

for future in concurrent.futures.as_completed(futures):
for future in as_completed(futures):
tool_variable, result = future.result()
results[tool_variable] = result
if tool_variable is not None:
results[tool_variable] = result

inputs.update(results)
return inputs
@@ -59,7 +61,7 @@ class ExternalDataFetch:
tenant_id: str,
app_id: str,
external_data_tool: ExternalDataVariableEntity,
inputs: dict,
inputs: Mapping[str, Any],
query: str,
) -> tuple[Optional[str], Optional[str]]:
"""

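Note: annotating the submitted future as `Future[tuple[str | None, str | None]]` lets mypy type `future.result()` precisely, which is what forces the `if tool_variable is not None:` guard above. A minimal sketch with an illustrative worker function:

    from concurrent.futures import Future, ThreadPoolExecutor, as_completed

    def fetch(name: str) -> tuple[str | None, str | None]:
        return name, "value"

    with ThreadPoolExecutor() as executor:
        futures: list[Future[tuple[str | None, str | None]]] = [
            executor.submit(fetch, "color")
        ]
        for future in as_completed(futures):
            key, value = future.result()  # typed as (str | None, str | None)
            if key is not None:
                print(key, value)
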
@@ -1,4 +1,5 @@
from typing import Optional
from collections.abc import Mapping
from typing import Any, Optional, cast

from core.extension.extensible import ExtensionModule
from extensions.ext_code_based_extension import code_based_extension
@@ -23,9 +24,10 @@ class ExternalDataToolFactory:
"""
code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config)
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
extension_class.validate_config(tenant_id, config)
# FIXME mypy issue here, figure out how to fix it
extension_class.validate_config(tenant_id, config) # type: ignore

def query(self, inputs: dict, query: Optional[str] = None) -> str:
def query(self, inputs: Mapping[str, Any], query: Optional[str] = None) -> str:
"""
Query the external data tool.

@@ -33,4 +35,4 @@ class ExternalDataToolFactory:
:param query: the query of chat app
:return: the tool query result
"""
return self.__extension_instance.query(inputs, query)
return cast(str, self.__extension_instance.query(inputs, query))

@@ -1,4 +1,5 @@
import base64
from collections.abc import Mapping

from configs import dify_config
from core.helper import ssrf_proxy
@@ -55,7 +56,7 @@ def to_prompt_message_content(
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

prompt_class_map = {
prompt_class_map: Mapping[FileType, type[MultiModalPromptMessageContent]] = {
FileType.IMAGE: ImagePromptMessageContent,
FileType.AUDIO: AudioPromptMessageContent,
FileType.VIDEO: VideoPromptMessageContent,
@@ -63,7 +64,7 @@ def to_prompt_message_content(
}

try:
return prompt_class_map[f.type](**params)
return prompt_class_map[f.type].model_validate(params)
except KeyError:
raise ValueError(f"file type {f.type} is not supported")


@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

if TYPE_CHECKING:
from core.tools.tool_file_manager import ToolFileManager
@@ -9,4 +9,4 @@ tool_file_manager: dict[str, Any] = {"manager": None}
class ToolFileParser:
@staticmethod
def get_tool_file_manager() -> "ToolFileManager":
return tool_file_manager["manager"]
return cast("ToolFileManager", tool_file_manager["manager"])

@@ -38,7 +38,7 @@ class CodeLanguage(StrEnum):


class CodeExecutor:
dependencies_cache = {}
dependencies_cache: dict[str, str] = {}
dependencies_cache_lock = Lock()

code_template_transformers: dict[CodeLanguage, type[TemplateTransformer]] = {
@@ -103,19 +103,19 @@ class CodeExecutor:
)

try:
response = response.json()
response_data = response.json()
except:
raise CodeExecutionError("Failed to parse response")

if (code := response.get("code")) != 0:
raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response.get('message')}")
if (code := response_data.get("code")) != 0:
raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response_data.get('message')}")

response = CodeExecutionResponse(**response)
response_code = CodeExecutionResponse(**response_data)

if response.data.error:
raise CodeExecutionError(response.data.error)
if response_code.data.error:
raise CodeExecutionError(response_code.data.error)

return response.data.stdout or ""
return response_code.data.stdout or ""

@classmethod
def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]):

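Note: mypy infers one type per variable from its first assignment, so reusing `response` for the HTTP response, then the parsed dict, then the parsed model object is flagged as an incompatible redefinition; the commit gives each stage its own name. A minimal sketch (the URL handling is illustrative):

    import requests

    def execute(url: str) -> str:
        response = requests.post(url)    # requests.Response
        response_data = response.json()  # parsed payload gets its own name
        # a third name would hold any validated model built from the payload
        return str(response_data.get("stdout", ""))
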
@@ -1,9 +1,11 @@
from collections.abc import Mapping

from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage


class Jinja2Formatter:
@classmethod
def format(cls, template: str, inputs: dict) -> str:
def format(cls, template: str, inputs: Mapping[str, str]) -> str:
"""
Format template
:param template: template
@@ -11,5 +13,4 @@ class Jinja2Formatter:
:return:
"""
result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs)

return result["result"]
return str(result.get("result", ""))

@@ -29,8 +29,7 @@ class TemplateTransformer(ABC):
result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL)
if not result:
raise ValueError("Failed to parse result")
result = result.group(1)
return result
return result.group(1)

@classmethod
def transform_response(cls, response: str) -> Mapping[str, Any]:

@@ -4,7 +4,7 @@ from typing import Any

class LRUCache:
def __init__(self, capacity: int):
self.cache = OrderedDict()
self.cache: OrderedDict[Any, Any] = OrderedDict()
self.capacity = capacity

def get(self, key: Any) -> Any:

@@ -30,7 +30,7 @@ class ProviderCredentialsCache:
except JSONDecodeError:
return None

return cached_provider_credentials
return dict(cached_provider_credentials)
else:
return None


@@ -22,6 +22,7 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str)
provider_name = model_config.provider
if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers:
hosting_openai_config = hosting_configuration.provider_map["openai"]
assert hosting_openai_config is not None

# 2000 text per chunk
length = 2000
@@ -34,8 +35,9 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str)

try:
model_type_instance = OpenAIModerationModel()
# FIXME, for type hint using assert or raise ValueError is better here?
moderation_result = model_type_instance.invoke(
model="text-moderation-stable", credentials=hosting_openai_config.credentials, text=text_chunk
model="text-moderation-stable", credentials=hosting_openai_config.credentials or {}, text=text_chunk
)

if moderation_result is True:

@@ -14,12 +14,13 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
if existed_spec:
spec = existed_spec
if not spec.loader:
raise Exception(f"Failed to load module {module_name} from {py_file_path}")
raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
else:
# Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
spec = importlib.util.spec_from_file_location(module_name, py_file_path)
# FIXME: mypy does not support the type of spec.loader
spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore
if not spec or not spec.loader:
raise Exception(f"Failed to load module {module_name} from {py_file_path}")
raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
if use_lazy_loader:
# Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
spec.loader = importlib.util.LazyLoader(spec.loader)
@@ -29,7 +30,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
spec.loader.exec_module(module)
return module
except Exception as e:
logging.exception(f"Failed to load module {module_name} from script file '{py_file_path}'")
logging.exception(f"Failed to load module {module_name} from script file '{py_file_path!r}'")
raise e


@@ -57,6 +58,6 @@ def load_single_subclass_from_source(
case 1:
return subclasses[0]
case 0:
raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path}")
raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path!r}")
case _:
raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path}")
raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path!r}")

@@ -33,7 +33,7 @@ class ToolParameterCache:
except JSONDecodeError:
return None

return cached_tool_parameter
return dict(cached_tool_parameter)
else:
return None


@@ -28,7 +28,7 @@ class ToolProviderCredentialsCache:
except JSONDecodeError:
return None

return cached_provider_credentials
return dict(cached_provider_credentials)
else:
return None


@@ -42,7 +42,7 @@ class HostedModerationConfig(BaseModel):

class HostingConfiguration:
provider_map: dict[str, HostingProvider] = {}
moderation_config: HostedModerationConfig = None
moderation_config: Optional[HostedModerationConfig] = None

def init_app(self, app: Flask) -> None:
if dify_config.EDITION != "CLOUD":
@@ -67,7 +67,7 @@ class HostingConfiguration:
"base_model_name": "gpt-35-turbo",
}

quotas = []
quotas: list[HostingQuota] = []
hosted_quota_limit = dify_config.HOSTED_AZURE_OPENAI_QUOTA_LIMIT
trial_quota = TrialHostingQuota(
quota_limit=hosted_quota_limit,
@@ -123,7 +123,7 @@ class HostingConfiguration:

def init_openai(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas = []
quotas: list[HostingQuota] = []

if dify_config.HOSTED_OPENAI_TRIAL_ENABLED:
hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT
@@ -157,7 +157,7 @@ class HostingConfiguration:
@staticmethod
def init_anthropic() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
quotas = []
quotas: list[HostingQuota] = []

if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED:
hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
@@ -187,7 +187,7 @@ class HostingConfiguration:
def init_minimax() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
if dify_config.HOSTED_MINIMAX_ENABLED:
quotas = [FreeHostingQuota()]
quotas: list[HostingQuota] = [FreeHostingQuota()]

return HostingProvider(
enabled=True,
@@ -205,7 +205,7 @@ class HostingConfiguration:
def init_spark() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
if dify_config.HOSTED_SPARK_ENABLED:
quotas = [FreeHostingQuota()]
quotas: list[HostingQuota] = [FreeHostingQuota()]

return HostingProvider(
enabled=True,
@@ -223,7 +223,7 @@ class HostingConfiguration:
def init_zhipuai() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
if dify_config.HOSTED_ZHIPUAI_ENABLED:
quotas = [FreeHostingQuota()]
quotas: list[HostingQuota] = [FreeHostingQuota()]

return HostingProvider(
enabled=True,

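Note: `quotas = [FreeHostingQuota()]` makes mypy infer `list[FreeHostingQuota]`, and because `list` is invariant, that list can neither take a `TrialHostingQuota` later nor be passed where `list[HostingQuota]` is expected; declaring the base element type up front fixes both. A minimal sketch with stand-in classes:

    class Quota: ...
    class FreeQuota(Quota): ...
    class TrialQuota(Quota): ...

    # without the annotation mypy infers list[FreeQuota]
    quotas: list[Quota] = [FreeQuota()]
    quotas.append(TrialQuota())  # fine: the declared element type is the base class
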
@@ -6,10 +6,10 @@ import re
import threading
import time
import uuid
from typing import Optional, cast
from typing import Any, Optional, cast

from flask import Flask, current_app
from flask_login import current_user
from flask_login import current_user # type: ignore
from sqlalchemy.orm.exc import ObjectDeletedError

from configs import dify_config

@@ -62,6 +62,8 @@ class IndexingRunner:
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
if not processing_rule:
raise ValueError("no process rule found")
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
# extract

@@ -120,6 +122,8 @@ class IndexingRunner:
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
if not processing_rule:
raise ValueError("no process rule found")

index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()

@@ -254,7 +258,7 @@ class IndexingRunner:
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
preview_texts = []
preview_texts: list[str] = []
total_segments = 0
index_type = doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()

@@ -285,7 +289,8 @@ class IndexingRunner:
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
try:
storage.delete(image_file.key)
if image_file:
storage.delete(image_file.key)
except Exception:
logging.exception(
"Delete image_files failed while indexing_estimate, \

@@ -379,8 +384,9 @@ class IndexingRunner:
# replace doc id to document model id
text_docs = cast(list[Document], text_docs)
for text_doc in text_docs:
text_doc.metadata["document_id"] = dataset_document.id
text_doc.metadata["dataset_id"] = dataset_document.dataset_id
if text_doc.metadata is not None:
text_doc.metadata["document_id"] = dataset_document.id
text_doc.metadata["dataset_id"] = dataset_document.dataset_id

return text_docs

@@ -400,6 +406,7 @@ class IndexingRunner:
"""
Get the NodeParser object according to the processing rule.
"""
character_splitter: TextSplitter
if processing_rule.mode == "custom":
# The user-defined segmentation rule
rules = json.loads(processing_rule.rules)

@@ -426,9 +433,10 @@ class IndexingRunner:
)
else:
# Automatic segmentation
automatic_rules: dict[str, Any] = dict(DatasetProcessRule.AUTOMATIC_RULES["segmentation"])
character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"],
chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"],
chunk_size=automatic_rules["max_tokens"],
chunk_overlap=automatic_rules["chunk_overlap"],
separators=["\n\n", "。", ". ", " ", ""],
embedding_model_instance=embedding_model_instance,
)

@@ -497,8 +505,8 @@ class IndexingRunner:
"""
Split the text documents into nodes.
"""
all_documents = []
all_qa_documents = []
all_documents: list[Document] = []
all_qa_documents: list[Document] = []
for text_doc in text_docs:
# document clean
document_text = self._document_clean(text_doc.page_content, processing_rule)

@@ -509,10 +517,11 @@ class IndexingRunner:
split_documents = []
for document_node in documents:
if document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
if document_node.metadata is not None:
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
document_node.page_content = remove_leading_symbols(page_content)

@@ -529,7 +538,7 @@ class IndexingRunner:
document_format_thread = threading.Thread(
target=self.format_qa_document,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"tenant_id": tenant_id,
"document_node": doc,
"all_qa_documents": all_qa_documents,

@@ -557,11 +566,12 @@ class IndexingRunner:
qa_document = Document(
page_content=result["question"], metadata=document_node.metadata.model_copy()
)
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(result["question"])
qa_document.metadata["answer"] = result["answer"]
qa_document.metadata["doc_id"] = doc_id
qa_document.metadata["doc_hash"] = hash
if qa_document.metadata is not None:
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(result["question"])
qa_document.metadata["answer"] = result["answer"]
qa_document.metadata["doc_id"] = doc_id
qa_document.metadata["doc_hash"] = hash
qa_documents.append(qa_document)
format_documents.extend(qa_documents)
except Exception as e:

@@ -575,7 +585,7 @@ class IndexingRunner:
"""
Split the text documents into nodes.
"""
all_documents = []
all_documents: list[Document] = []
for text_doc in text_docs:
# document clean
document_text = self._document_clean(text_doc.page_content, processing_rule)

@@ -588,11 +598,11 @@ class IndexingRunner:
for document in documents:
if document.page_content is None or not document.page_content.strip():
continue
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content)

document.metadata["doc_id"] = doc_id
document.metadata["doc_hash"] = hash
if document.metadata is not None:
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content)
document.metadata["doc_id"] = doc_id
document.metadata["doc_hash"] = hash

split_documents.append(document)

@@ -648,7 +658,7 @@ class IndexingRunner:
# create keyword index
create_keyword_thread = threading.Thread(
target=self._process_keyword_index,
args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),
args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore
)
create_keyword_thread.start()
if dataset.indexing_technique == "high_quality":

@@ -659,7 +669,7 @@ class IndexingRunner:
futures.append(
executor.submit(
self._process_chunk,
current_app._get_current_object(),
current_app._get_current_object(), # type: ignore
index_processor,
chunk_documents,
dataset,
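
Note: several of the IndexingRunner hunks above guard metadata writes with "if ... is not None". That is the standard way to satisfy mypy when a field is declared Optional. A minimal sketch under that assumption (the Document stand-in is illustrative, not the real entity):

    from dataclasses import dataclass, field
    from typing import Any, Optional

    @dataclass
    class Document:  # simplified stand-in for the real Document entity
        page_content: str
        metadata: Optional[dict[str, Any]] = field(default_factory=dict)

    doc = Document(page_content="hello")
    # metadata is Optional, so mypy requires a None check before item assignment
    if doc.metadata is not None:
        doc.metadata["doc_id"] = "123"
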
@@ -1,7 +1,7 @@
import json
import logging
import re
from typing import Optional
from typing import Optional, cast

from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser

@@ -13,6 +13,7 @@ from core.llm_generator.prompts import (
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
)
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError

@@ -44,10 +45,13 @@ class LLMGenerator:
prompts = [UserPromptMessage(content=prompt)]

with measure_time() as timer:
response = model_instance.invoke_llm(
prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
),
)
answer = response.message.content
answer = cast(str, response.message.content)
cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
if cleaned_answer is None:
return ""

@@ -94,11 +98,16 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt)]

try:
response = model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters={"max_tokens": 256, "temperature": 0}, stream=False
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters={"max_tokens": 256, "temperature": 0},
stream=False,
),
)

questions = output_parser.parse(response.message.content)
questions = output_parser.parse(cast(str, response.message.content))
except InvokeError:
questions = []
except Exception as e:

@@ -138,11 +147,14 @@ class LLMGenerator:
)

try:
response = model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
),
)

rule_config["prompt"] = response.message.content
rule_config["prompt"] = cast(str, response.message.content)

except InvokeError as e:
error = str(e)

@@ -178,15 +190,18 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.get("provider") if model_config else None,
model=model_config.get("name") if model_config else None,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)

try:
try:
# the first step to generate the task prompt
prompt_content = model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
prompt_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
),
)
except InvokeError as e:
error = str(e)

@@ -195,8 +210,10 @@ class LLMGenerator:

return rule_config

rule_config["prompt"] = prompt_content.message.content
rule_config["prompt"] = cast(str, prompt_content.message.content)

if not isinstance(prompt_content.message.content, str):
raise NotImplementedError("prompt content is not a string")
parameter_generate_prompt = parameter_template.format(
inputs={
"INPUT_TEXT": prompt_content.message.content,

@@ -216,19 +233,25 @@ class LLMGenerator:
statement_messages = [UserPromptMessage(content=statement_generate_prompt)]

try:
parameter_content = model_instance.invoke_llm(
prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False
parameter_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False
),
)
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.content)
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
except InvokeError as e:
error = str(e)
error_step = "generate variables"

try:
statement_content = model_instance.invoke_llm(
prompt_messages=statement_messages, model_parameters=model_parameters, stream=False
statement_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=statement_messages, model_parameters=model_parameters, stream=False
),
)
rule_config["opening_statement"] = statement_content.message.content
rule_config["opening_statement"] = cast(str, statement_content.message.content)
except InvokeError as e:
error = str(e)
error_step = "generate conversation opener"

@@ -267,19 +290,22 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.get("provider") if model_config else None,
model=model_config.get("name") if model_config else None,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)

prompt_messages = [UserPromptMessage(content=prompt)]
model_parameters = {"max_tokens": max_tokens, "temperature": 0.01}

try:
response = model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
),
)

generated_code = response.message.content
generated_code = cast(str, response.message.content)
return {"code": generated_code, "language": code_language, "error": ""}

except InvokeError as e:

@@ -303,9 +329,14 @@ class LLMGenerator:

prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]

response = model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters={"temperature": 0.01, "max_tokens": 2000}, stream=False
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters={"temperature": 0.01, "max_tokens": 2000},
stream=False,
),
)

answer = response.message.content
answer = cast(str, response.message.content)
return answer.strip()
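
Note: every LLMGenerator hunk above wraps the non-streaming invoke in cast(LLMResult, ...). The declared return type of invoke_llm is a Union of a result and a streaming generator, so a stream=False call site narrows the result explicitly. A minimal sketch under that assumption (the stubs are illustrative, not the runtime classes):

    from collections.abc import Generator
    from typing import Union, cast

    class LLMResult:  # stand-in; the real class lives in core.model_runtime
        pass

    def _chunks() -> Generator[str, None, None]:
        yield "chunk"

    def invoke_llm(stream: bool = True) -> Union[LLMResult, Generator]:
        return _chunks() if stream else LLMResult()

    # With stream=False the runtime value is always an LLMResult, but the
    # declared type is a Union, so the call site narrows it for mypy.
    response = cast(LLMResult, invoke_llm(stream=False))
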
@@ -68,7 +68,7 @@ class TokenBufferMemory:

messages = list(reversed(thread_messages))

prompt_messages = []
prompt_messages: list[PromptMessage] = []
for message in messages:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if files:

@@ -124,17 +124,20 @@ class ModelInstance:
raise Exception("Model type instance is not LargeLanguageModel")

self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
return cast(
Union[LLMResult, Generator],
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
),
)

def get_llm_num_tokens(

@@ -151,12 +154,15 @@ class ModelInstance:
raise Exception("Model type instance is not LargeLanguageModel")

self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model,
credentials=self.credentials,
prompt_messages=prompt_messages,
tools=tools,
return cast(
int,
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model,
credentials=self.credentials,
prompt_messages=prompt_messages,
tools=tools,
),
)

def invoke_text_embedding(

@@ -174,13 +180,16 @@ class ModelInstance:
raise Exception("Model type instance is not TextEmbeddingModel")

self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
texts=texts,
user=user,
input_type=input_type,
return cast(
TextEmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
texts=texts,
user=user,
input_type=input_type,
),
)

def get_text_embedding_num_tokens(self, texts: list[str]) -> int:

@@ -194,11 +203,14 @@ class ModelInstance:
raise Exception("Model type instance is not TextEmbeddingModel")

self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model,
credentials=self.credentials,
texts=texts,
return cast(
int,
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model,
credentials=self.credentials,
texts=texts,
),
)

def invoke_rerank(

@@ -223,15 +235,18 @@ class ModelInstance:
raise Exception("Model type instance is not RerankModel")

self.model_type_instance = cast(RerankModel, self.model_type_instance)
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user,
return cast(
RerankResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user,
),
)

def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool:

@@ -246,12 +261,15 @@ class ModelInstance:
raise Exception("Model type instance is not ModerationModel")

self.model_type_instance = cast(ModerationModel, self.model_type_instance)
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
text=text,
user=user,
return cast(
bool,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
text=text,
user=user,
),
)

def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str:

@@ -266,12 +284,15 @@ class ModelInstance:
raise Exception("Model type instance is not Speech2TextModel")

self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
file=file,
user=user,
return cast(
str,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
file=file,
user=user,
),
)

def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]:

@@ -288,17 +309,20 @@ class ModelInstance:
raise Exception("Model type instance is not TTSModel")

self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
content_text=content_text,
user=user,
tenant_id=tenant_id,
voice=voice,
return cast(
Iterable[bytes],
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
content_text=content_text,
user=user,
tenant_id=tenant_id,
voice=voice,
),
)

def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs) -> Any:
"""
Round-robin invoke
:param function: function to invoke

@@ -309,7 +333,7 @@ class ModelInstance:
if not self.load_balancing_manager:
return function(*args, **kwargs)

last_exception = None
last_exception: Union[InvokeRateLimitError, InvokeAuthorizationError, InvokeConnectionError, None] = None
while True:
lb_config = self.load_balancing_manager.fetch_next()
if not lb_config:

@@ -463,7 +487,7 @@ class LBModelManager:
if real_index > max_index:
real_index = 0

config = self._load_balancing_configs[real_index]
config: ModelLoadBalancingConfiguration = self._load_balancing_configs[real_index]

if self.in_cooldown(config):
cooldown_load_balancing_configs.append(config)

@@ -507,8 +531,7 @@ class LBModelManager:
self._tenant_id, self._provider, self._model_type.value, self._model, config.id
)

res = redis_client.exists(cooldown_cache_key)
res = cast(bool, res)
res: bool = redis_client.exists(cooldown_cache_key)
return res

@staticmethod
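
Note: _round_robin_invoke is now declared as returning Any, so every typed wrapper in ModelInstance casts the result back to its concrete return type. A minimal sketch of the idea (the function names are illustrative stubs, not the real implementation):

    from typing import Any, Callable, cast

    def round_robin_invoke(function: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        # The real method retries across load-balanced configs; this stub
        # just forwards the call and therefore also returns Any.
        return function(*args, **kwargs)

    def get_num_tokens(text: str) -> int:
        return len(text.split())

    # Each typed wrapper narrows the Any result to its known return type.
    num_tokens = cast(int, round_robin_invoke(get_num_tokens, "hello world"))
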
@@ -1,7 +1,8 @@
import json
import logging
import sys
from typing import Optional
from collections.abc import Sequence
from typing import Optional, cast

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk

@@ -20,7 +21,7 @@ class LoggingCallback(Callback):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:

@@ -76,7 +77,7 @@ class LoggingCallback(Callback):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
):

@@ -94,7 +95,7 @@ class LoggingCallback(Callback):
:param stream: is stream response
:param user: unique user id
"""
sys.stdout.write(chunk.delta.message.content)
sys.stdout.write(cast(str, chunk.delta.message.content))
sys.stdout.flush()

def on_after_invoke(

@@ -106,7 +107,7 @@ class LoggingCallback(Callback):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:

@@ -147,7 +148,7 @@ class LoggingCallback(Callback):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
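
Note: the LoggingCallback hunks widen the stop parameter from list[str] to Sequence[str]. Accepting a read-only Sequence lets callers pass lists, tuples, or other sequence types without a type error. A minimal sketch (the function name is illustrative):

    from collections.abc import Sequence
    from typing import Optional

    def on_before_invoke(stop: Optional[Sequence[str]] = None) -> None:
        # Sequence is read-only here, so any concrete sequence type is accepted.
        for token in stop or ():
            print(token)

    on_before_invoke(stop=("\n\n", "Human:"))  # a tuple is fine now
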
@@ -3,7 +3,7 @@ from collections.abc import Sequence
from enum import Enum, StrEnum
from typing import Optional

from pydantic import BaseModel, Field, computed_field, field_validator
from pydantic import BaseModel, Field, field_validator


class PromptMessageRole(Enum):

@@ -89,7 +89,6 @@ class MultiModalPromptMessageContent(PromptMessageContent):
url: str = Field(default="", description="the url of multi-modal file")
mime_type: str = Field(default=..., description="the mime type of multi-modal file")

@computed_field(return_type=str)
@property
def data(self):
return self.url or f"data:{self.mime_type};base64,{self.base64_data}"

@@ -1,7 +1,6 @@
import decimal
import os
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Optional

from pydantic import ConfigDict

@@ -36,7 +35,7 @@ class AIModel(ABC):
model_config = ConfigDict(protected_namespaces=())

@abstractmethod
def validate_credentials(self, model: str, credentials: Mapping) -> None:
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials

@@ -214,7 +213,7 @@ class AIModel(ABC):

return model_schemas

def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> Optional[AIModelEntity]:
def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
"""
Get model schema by model name and credentials

@@ -236,9 +235,7 @@ class AIModel(ABC):

return None

def get_customizable_model_schema_from_credentials(
self, model: str, credentials: Mapping
) -> Optional[AIModelEntity]:
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema from credentials

@@ -248,7 +245,7 @@ class AIModel(ABC):
"""
return self._get_customizable_model_schema(model, credentials)

def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema and fill in the template
"""

@@ -301,7 +298,7 @@ class AIModel(ABC):

return schema

def get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema

@@ -2,7 +2,7 @@ import logging
import re
import time
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from collections.abc import Generator, Sequence
from typing import Optional, Union

from pydantic import ConfigDict

@@ -48,7 +48,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,

@@ -291,12 +291,12 @@ if you are not sure about the structure.
content = piece.delta.message.content
piece.delta.message.content = ""
yield piece
piece = content
content_piece = content
else:
yield piece
continue
new_piece: str = ""
for char in piece:
for char in content_piece:
char = str(char)
if state == "normal":
if char == "`":

@@ -350,7 +350,7 @@ if you are not sure about the structure.
piece.delta.message.content = ""
# Yield a piece with cleared content before processing it to maintain the generator structure
yield piece
piece = content
content_piece = content
else:
# Yield pieces without content directly
yield piece

@@ -360,7 +360,7 @@ if you are not sure about the structure.
continue

new_piece: str = ""
for char in piece:
for char in content_piece:
if state == "search_start":
if char == "`":
backtick_count += 1

@@ -535,7 +535,7 @@ if you are not sure about the structure.

return []

def get_model_mode(self, model: str, credentials: Optional[Mapping] = None) -> LLMMode:
def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode:
"""
Get model mode
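
Note: the AIModel hunks change credentials from Mapping to dict across the whole hierarchy at once. mypy verifies that an override stays compatible with its base-class signature, so the base classes and every provider subclass must agree. A minimal sketch of the constraint (the subclass is illustrative, not from this diff):

    from abc import ABC, abstractmethod

    class AIModel(ABC):
        @abstractmethod
        def validate_credentials(self, model: str, credentials: dict) -> None: ...

    class MyProviderModel(AIModel):
        # The override must keep a compatible signature, or mypy reports
        # an incompatible-override error.
        def validate_credentials(self, model: str, credentials: dict) -> None:
            if "api_key" not in credentials:
                raise ValueError("api_key is required")

    MyProviderModel().validate_credentials("some-model", {"api_key": "k"})
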
@@ -104,9 +104,10 @@ class ModelProvider(ABC):
mod = import_module_from_source(
module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path
)
# FIXME "type" has no attribute "__abstractmethods__" ignore it for now fix it later
model_class = next(
filter(
lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__, # type: ignore
get_subclasses_from_module(mod, AIModel),
),
None,

@@ -89,7 +89,8 @@ class TextEmbeddingModel(AIModel):
model_schema = self.get_model_schema(model, credentials)

if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE]
content_size: int = model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE]
return content_size

return 1000

@@ -104,6 +105,7 @@ class TextEmbeddingModel(AIModel):
model_schema = self.get_model_schema(model, credentials)

if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
return max_chunks

return 1

@@ -2,9 +2,9 @@ from os.path import abspath, dirname, join
from threading import Lock
from typing import Any

from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore

_tokenizer = None
_tokenizer: Any = None
_lock = Lock()

@@ -127,7 +127,8 @@ class TTSModel(AIModel):
if not model_schema or ModelPropertyKey.AUDIO_TYPE not in model_schema.model_properties:
raise ValueError("this model does not support audio type")

return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
audio_type: str = model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
return audio_type

def _get_model_word_limit(self, model: str, credentials: dict) -> int:
"""

@@ -138,8 +139,9 @@ class TTSModel(AIModel):

if not model_schema or ModelPropertyKey.WORD_LIMIT not in model_schema.model_properties:
raise ValueError("this model does not support word limit")
world_limit: int = model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]

return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
return world_limit

def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
"""

@@ -150,8 +152,9 @@ class TTSModel(AIModel):

if not model_schema or ModelPropertyKey.MAX_WORKERS not in model_schema.model_properties:
raise ValueError("this model does not support max workers")
workers_limit: int = model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]

return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
return workers_limit

@staticmethod
def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"):
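
Note: the TextEmbeddingModel and TTSModel hunks all replace a direct return of a dict lookup with an annotated local. The model_properties lookup yields Any, so returning it directly would silently defeat the declared int or str return type; binding it to a typed local pins it down without a cast(). A minimal sketch under that assumption (the enum and dict are illustrative stand-ins):

    from enum import Enum
    from typing import Any

    class ModelPropertyKey(Enum):  # stand-in for the runtime enum
        CONTEXT_SIZE = "context_size"

    model_properties: dict[ModelPropertyKey, Any] = {ModelPropertyKey.CONTEXT_SIZE: 8192}

    # The annotated local gives the Any lookup a checked concrete type.
    context_size: int = model_properties[ModelPropertyKey.CONTEXT_SIZE]
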
@@ -64,10 +64,12 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):

def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model)
if not ai_model_entity:
return None
return ai_model_entity.entity

@staticmethod
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
def _get_ai_model_entity(base_model_name: str, model: str) -> Optional[AzureBaseModel]:
for ai_model_entity in SPEECH2TEXT_BASE_MODELS:
if ai_model_entity.base_model_name == base_model_name:
ai_model_entity_copy = copy.deepcopy(ai_model_entity)

@@ -114,6 +114,8 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):

def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model)
if not ai_model_entity:
return None
return ai_model_entity.entity

@staticmethod

@@ -6,9 +6,9 @@ from collections.abc import Generator
from typing import Optional, Union, cast

# 3rd import
import boto3
from botocore.config import Config
from botocore.exceptions import (
import boto3 # type: ignore
from botocore.config import Config # type: ignore
from botocore.exceptions import ( # type: ignore
ClientError,
EndpointConnectionError,
NoRegionError,

@@ -44,7 +44,7 @@ class CohereRerankModel(RerankModel):
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(model=model, docs=docs)
return RerankResult(model=model, docs=[])

# initialize client
client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url"))

@@ -62,7 +62,7 @@ class CohereRerankModel(RerankModel):
# format document
rerank_document = RerankDocument(
index=result.index,
text=result.document.text,
text=result.document.text if result.document else "",
score=result.relevance_score,
)
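
Note: the Azure hunks make _get_ai_model_entity honest about its failure path: the loop can fall through without a match, so the return type becomes Optional and callers must guard for None. A minimal sketch of the pattern (the dataclass and list are illustrative stand-ins):

    import copy
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class AzureBaseModel:  # simplified stand-in
        base_model_name: str

    BASE_MODELS = [AzureBaseModel("whisper-1")]

    def get_ai_model_entity(base_model_name: str) -> Optional[AzureBaseModel]:
        for entity in BASE_MODELS:
            if entity.base_model_name == base_model_name:
                return copy.deepcopy(entity)
        # Falling through the loop is now reflected in the return type.
        return None
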
@@ -1,5 +1,3 @@
from collections.abc import Mapping

import openai

from core.model_runtime.errors.invoke import (

@@ -13,7 +11,7 @@ from core.model_runtime.errors.invoke import (


class _CommonFireworks:
def _to_credential_kwargs(self, credentials: Mapping) -> dict:
def _to_credential_kwargs(self, credentials: dict) -> dict:
"""
Transform credentials to kwargs for model instance

@@ -1,5 +1,4 @@
import time
from collections.abc import Mapping
from typing import Optional, Union

import numpy as np

@@ -93,7 +92,7 @@ class FireworksTextEmbeddingModel(_CommonFireworks, TextEmbeddingModel):
"""
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)

def validate_credentials(self, model: str, credentials: Mapping) -> None:
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials

@@ -1,4 +1,4 @@
from dashscope.common.error import (
from dashscope.common.error import ( # type: ignore
AuthenticationError,
InvalidParameter,
RequestFailure,

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Optional

import httpx

@@ -51,7 +51,7 @@ class GiteeAIRerankModel(RerankModel):
base_url = base_url.removesuffix("/")

try:
body = {"model": model, "query": query, "documents": docs}
body: dict[str, Any] = {"model": model, "query": query, "documents": docs}
if top_n is not None:
body["top_n"] = top_n
response = httpx.post(

@@ -24,7 +24,7 @@ class GiteeAIEmbeddingModel(OAICompatEmbeddingModel):
super().validate_credentials(model, credentials)

@staticmethod
def _add_custom_parameters(credentials: dict, model: str) -> None:
def _add_custom_parameters(credentials: dict, model: Optional[str]) -> None:
if model is None:
model = "bge-m3"

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Optional

import requests

@@ -13,9 +13,10 @@ class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel):
Model class for OpenAI text2speech model.
"""

# FIXME this Any return will be better type
def _invoke(
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
) -> any:
) -> Any:
"""
_invoke text2speech model

@@ -47,7 +48,8 @@ class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel):
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
# FIXME this Any return will be better type
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> Any:
"""
_tts_invoke_streaming text2speech model
:param model: model name

@@ -7,7 +7,7 @@ from collections.abc import Generator
from typing import Optional, Union

import google.ai.generativelanguage as glm
import google.generativeai as genai
import google.generativeai as genai # type: ignore
import requests
from google.api_core import exceptions
from google.generativeai.types import ContentType, File, GenerateContentResponse

@@ -1,4 +1,4 @@
from huggingface_hub.utils import BadRequestError, HfHubHTTPError
from huggingface_hub.utils import BadRequestError, HfHubHTTPError # type: ignore

from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError

@@ -1,9 +1,9 @@
from collections.abc import Generator
from typing import Optional, Union

from huggingface_hub import InferenceClient
from huggingface_hub.hf_api import HfApi
from huggingface_hub.utils import BadRequestError
from huggingface_hub import InferenceClient # type: ignore
from huggingface_hub.hf_api import HfApi # type: ignore
from huggingface_hub.utils import BadRequestError # type: ignore

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE

@@ -4,7 +4,7 @@ from typing import Optional

import numpy as np
import requests
from huggingface_hub import HfApi, InferenceClient
from huggingface_hub import HfApi, InferenceClient # type: ignore

from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
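
Note: the many "# type: ignore" suffixes above target third-party SDKs that ship without stubs or a py.typed marker; the per-import ignore silences only that line while the rest of the file stays strictly checked. A one-line sketch, assuming the package (here boto3, as in the bedrock hunk) is installed:

    import boto3  # type: ignore  # no stubs published; keep the rest of the file typed

An alternative, not used in this diff, is setting ignore_missing_imports per module in the mypy configuration instead of annotating each import site.
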
@@ -3,11 +3,11 @@ import logging
from collections.abc import Generator
from typing import cast

from tencentcloud.common import credential
from tencentcloud.common.exception import TencentCloudSDKException
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
from tencentcloud.common import credential # type: ignore
from tencentcloud.common.exception import TencentCloudSDKException # type: ignore
from tencentcloud.common.profile.client_profile import ClientProfile # type: ignore
from tencentcloud.common.profile.http_profile import HttpProfile # type: ignore
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models # type: ignore

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (

@@ -305,7 +305,7 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
elif isinstance(message, ToolPromptMessage):
message_text = f"{tool_prompt} {content}"
elif isinstance(message, SystemPromptMessage):
message_text = content
message_text = content if isinstance(content, str) else ""
else:
raise ValueError(f"Got unknown type {message}")

@@ -3,11 +3,11 @@ import logging
import time
from typing import Optional

from tencentcloud.common import credential
from tencentcloud.common.exception import TencentCloudSDKException
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
from tencentcloud.common import credential # type: ignore
from tencentcloud.common.exception import TencentCloudSDKException # type: ignore
from tencentcloud.common.profile.client_profile import ClientProfile # type: ignore
from tencentcloud.common.profile.http_profile import HttpProfile # type: ignore
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models # type: ignore

from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType

@@ -1,11 +1,11 @@
from os.path import abspath, dirname, join
from threading import Lock

from transformers import AutoTokenizer
from transformers import AutoTokenizer # type: ignore


class JinaTokenizer:
_tokenizer = None
_tokenizer: AutoTokenizer | None = None
_lock = Lock()

@classmethod

@@ -40,7 +40,7 @@ class MinimaxChatCompletion:

url = f"https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}"

extra_kwargs = {}
extra_kwargs: dict[str, Any] = {}

if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int:
extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"]

@@ -117,19 +117,19 @@ class MinimaxChatCompletion:
"""
handle chat generate response
"""
response = response.json()
if "base_resp" in response and response["base_resp"]["status_code"] != 0:
code = response["base_resp"]["status_code"]
msg = response["base_resp"]["status_msg"]
response_data = response.json()
if "base_resp" in response_data and response_data["base_resp"]["status_code"] != 0:
code = response_data["base_resp"]["status_code"]
msg = response_data["base_resp"]["status_msg"]
self._handle_error(code, msg)

message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
message = MinimaxMessage(content=response_data["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
message.usage = {
"prompt_tokens": 0,
"completion_tokens": response["usage"]["total_tokens"],
"total_tokens": response["usage"]["total_tokens"],
"completion_tokens": response_data["usage"]["total_tokens"],
"total_tokens": response_data["usage"]["total_tokens"],
}
message.stop_reason = response["choices"][0]["finish_reason"]
message.stop_reason = response_data["choices"][0]["finish_reason"]
return message

def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]:

@@ -139,10 +139,10 @@ class MinimaxChatCompletion:
for line in response.iter_lines():
if not line:
continue
line: str = line.decode("utf-8")
if line.startswith("data: "):
line = line[6:].strip()
data = loads(line)
line_str: str = line.decode("utf-8")
if line_str.startswith("data: "):
line_str = line_str[6:].strip()
data = loads(line_str)

if "base_resp" in data and data["base_resp"]["status_code"] != 0:
code = data["base_resp"]["status_code"]

@@ -162,5 +162,5 @@ class MinimaxChatCompletion:
continue

for choice in choices:
message = choice["delta"]
yield MinimaxMessage(content=message, role=MinimaxMessage.Role.ASSISTANT.value)
message_choice = choice["delta"]
yield MinimaxMessage(content=message_choice, role=MinimaxMessage.Role.ASSISTANT.value)

@@ -41,7 +41,7 @@ class MinimaxChatCompletionPro:

url = f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}"

extra_kwargs = {}
extra_kwargs: dict[str, Any] = {}

if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int:
extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"]

@@ -122,19 +122,19 @@ class MinimaxChatCompletionPro:
"""
handle chat generate response
"""
response = response.json()
if "base_resp" in response and response["base_resp"]["status_code"] != 0:
code = response["base_resp"]["status_code"]
msg = response["base_resp"]["status_msg"]
response_data = response.json()
if "base_resp" in response_data and response_data["base_resp"]["status_code"] != 0:
code = response_data["base_resp"]["status_code"]
msg = response_data["base_resp"]["status_msg"]
self._handle_error(code, msg)

message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
message = MinimaxMessage(content=response_data["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
message.usage = {
"prompt_tokens": 0,
"completion_tokens": response["usage"]["total_tokens"],
"total_tokens": response["usage"]["total_tokens"],
"completion_tokens": response_data["usage"]["total_tokens"],
"total_tokens": response_data["usage"]["total_tokens"],
}
message.stop_reason = response["choices"][0]["finish_reason"]
message.stop_reason = response_data["choices"][0]["finish_reason"]
return message

def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]:

@@ -144,10 +144,10 @@ class MinimaxChatCompletionPro:
for line in response.iter_lines():
if not line:
continue
line: str = line.decode("utf-8")
if line.startswith("data: "):
line = line[6:].strip()
data = loads(line)
line_str: str = line.decode("utf-8")
if line_str.startswith("data: "):
line_str = line_str[6:].strip()
data = loads(line_str)

if "base_resp" in data and data["base_resp"]["status_code"] != 0:
code = data["base_resp"]["status_code"]
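
Note: the Minimax hunks repeatedly rename a rebound variable (response -> response_data, line -> line_str). mypy assigns each variable a single type per scope, so rebinding a Response to the dict returned by .json() is rejected; a fresh name keeps both types checked. A minimal sketch (the stand-in class is illustrative, not requests.Response):

    import json

    class FakeResponse:  # minimal stand-in for an HTTP response object
        def json(self) -> dict:
            return json.loads('{"reply": "hi"}')

    response = FakeResponse()
    # A new name instead of rebinding `response` keeps each variable at one type.
    response_data = response.json()
    print(response_data["reply"])
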
@@ -11,9 +11,9 @@ class MinimaxMessage:

role: str = Role.USER.value
content: str
usage: dict[str, int] = None
usage: dict[str, int] | None = None
stop_reason: str = ""
function_call: dict[str, Any] = None
function_call: dict[str, Any] | None = None

def to_dict(self) -> dict[str, Any]:
if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:

@@ -2,8 +2,8 @@ import time
from functools import wraps
from typing import Optional

from nomic import embed
from nomic import login as nomic_login
from nomic import embed # type: ignore
from nomic import login as nomic_login # type: ignore

from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType

@@ -5,8 +5,8 @@ import logging
from collections.abc import Generator
from typing import Optional, Union

import oci
from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse
import oci # type: ignore
from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse # type: ignore

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (

@@ -4,7 +4,7 @@ import time
from typing import Optional

import numpy as np
import oci
import oci # type: ignore

from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType

@@ -61,6 +61,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
headers = {"Content-Type": "application/json"}

endpoint_url = credentials.get("base_url")
assert endpoint_url is not None, "Base URL is required for Ollama API"
if not endpoint_url.endswith("/"):
endpoint_url += "/"

@@ -1,5 +1,3 @@
from collections.abc import Mapping

import openai
from httpx import Timeout

@@ -14,7 +12,7 @@ from core.model_runtime.errors.invoke import (


class _CommonOpenAI:
def _to_credential_kwargs(self, credentials: Mapping) -> dict:
def _to_credential_kwargs(self, credentials: dict) -> dict:
"""
Transform credentials to kwargs for model instance

@@ -93,7 +93,8 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel):
model_schema = self.get_model_schema(model, credentials)

if model_schema and ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK]
max_characters_per_chunk: int = model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK]
return max_characters_per_chunk

return 2000

@@ -108,6 +109,7 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel):
model_schema = self.get_model_schema(model, credentials)

if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
return max_chunks

return 1

@@ -1,5 +1,4 @@
import logging
from collections.abc import Mapping

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError

@@ -9,7 +8,7 @@ logger = logging.getLogger(__name__)


class OpenAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: Mapping) -> None:
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
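
Note: the MinimaxMessage hunk above fixes a classic implicit-Optional error: a field declared dict[str, int] cannot default to None under mypy, so the Optional union is spelled out. A minimal sketch (the dataclass is an illustrative stand-in for MinimaxMessage):

    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class Message:  # illustrative stand-in
        content: str
        # `usage: dict[str, int] = None` is a type error; the Optional
        # union makes the None default explicit and checkable.
        usage: Optional[dict[str, int]] = None
        function_call: Optional[dict[str, Any]] = None

    msg = Message(content="hello")
    assert msg.usage is None
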