mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-23 17:55:46 +08:00
feat/enhance the multi-modal support (#8818)
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
from .tool_node import ToolNode
|
||||
|
||||
__all__ = ["ToolNode"]
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Literal, Union
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class ToolEntity(BaseModel):
|
||||
@@ -51,7 +51,4 @@ class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
raise ValueError("value must be a string, int, float, or bool")
|
||||
return typ
|
||||
|
||||
"""
|
||||
Tool Node Schema
|
||||
"""
|
||||
tool_parameters: dict[str, ToolInput]
|
||||
|
||||
@@ -1,24 +1,28 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from os import path
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.file.models import File, FileTransferMethod, FileType
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from models import WorkflowNodeExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from models import ToolFile
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class ToolNode(BaseNode):
|
||||
class ToolNode(BaseNode[ToolNodeData]):
|
||||
"""
|
||||
Tool Node
|
||||
"""
|
||||
@@ -27,37 +31,38 @@ class ToolNode(BaseNode):
|
||||
_node_type = NodeType.TOOL
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the tool node
|
||||
"""
|
||||
|
||||
node_data = cast(ToolNodeData, self.node_data)
|
||||
|
||||
# fetch tool icon
|
||||
tool_info = {"provider_type": node_data.provider_type, "provider_id": node_data.provider_id}
|
||||
tool_info = {
|
||||
"provider_type": self.node_data.provider_type,
|
||||
"provider_id": self.node_data.provider_id,
|
||||
}
|
||||
|
||||
# get tool runtime
|
||||
try:
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(
|
||||
self.tenant_id, self.app_id, self.node_id, node_data, self.invoke_from
|
||||
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
|
||||
)
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info,
|
||||
},
|
||||
error=f"Failed to get tool runtime: {str(e)}",
|
||||
)
|
||||
|
||||
# get parameters
|
||||
tool_parameters = tool_runtime.get_runtime_parameters() or []
|
||||
parameters = self._generate_parameters(
|
||||
tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
)
|
||||
parameters_for_log = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=node_data,
|
||||
node_data=self.node_data,
|
||||
for_log=True,
|
||||
)
|
||||
|
||||
@@ -74,7 +79,9 @@ class ToolNode(BaseNode):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info,
|
||||
},
|
||||
error=f"Failed to invoke tool: {str(e)}",
|
||||
)
|
||||
|
||||
@@ -83,8 +90,14 @@ class ToolNode(BaseNode):
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"text": plain_text, "files": files, "json": json},
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||
outputs={
|
||||
"text": plain_text,
|
||||
"files": files,
|
||||
"json": json,
|
||||
},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
)
|
||||
|
||||
@@ -116,29 +129,25 @@ class ToolNode(BaseNode):
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
||||
result[parameter_name] = [v.to_dict() for v in self._fetch_files(variable_pool)]
|
||||
tool_input = node_data.tool_parameters[parameter_name]
|
||||
if tool_input.type == "variable":
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
if variable is None:
|
||||
raise ValueError(f"variable {tool_input.value} not exists")
|
||||
parameter_value = variable.value
|
||||
elif tool_input.type in {"mixed", "constant"}:
|
||||
segment_group = variable_pool.convert_template(str(tool_input.value))
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
else:
|
||||
tool_input = node_data.tool_parameters[parameter_name]
|
||||
if tool_input.type == "variable":
|
||||
# TODO: check if the variable exists in the variable pool
|
||||
parameter_value = variable_pool.get(tool_input.value).value
|
||||
else:
|
||||
segment_group = parser.convert_template(
|
||||
template=str(tool_input.value),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
result[parameter_name] = parameter_value
|
||||
raise ValueError(f"unknown tool input type '{tool_input.type}'")
|
||||
result[parameter_name] = parameter_value
|
||||
|
||||
return result
|
||||
|
||||
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
|
||||
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
|
||||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar], list[dict]]:
|
||||
def _convert_tool_messages(
|
||||
self,
|
||||
messages: list[ToolInvokeMessage],
|
||||
):
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
@@ -156,50 +165,86 @@ class ToolNode(BaseNode):
|
||||
|
||||
return plain_text, files, json
|
||||
|
||||
def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[FileVar]:
|
||||
def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[File]:
|
||||
"""
|
||||
Extract tool response binary
|
||||
"""
|
||||
result = []
|
||||
|
||||
for response in tool_response:
|
||||
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
||||
url = response.message
|
||||
ext = path.splitext(url)[1]
|
||||
mimetype = response.meta.get("mime_type", "image/jpeg")
|
||||
filename = response.save_as or url.split("/")[-1]
|
||||
url = str(response.message) if response.message else None
|
||||
ext = path.splitext(url)[1] if url else ".bin"
|
||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
|
||||
# get tool file id
|
||||
tool_file_id = url.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ValueError(f"tool file {tool_file_id} not exists")
|
||||
|
||||
result.append(
|
||||
FileVar(
|
||||
File(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=transfer_method,
|
||||
url=url,
|
||||
related_id=tool_file_id,
|
||||
filename=filename,
|
||||
remote_url=url,
|
||||
related_id=tool_file.id,
|
||||
filename=tool_file.name,
|
||||
extension=ext,
|
||||
mime_type=mimetype,
|
||||
mime_type=tool_file.mimetype,
|
||||
size=tool_file.size,
|
||||
)
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
tool_file_id = response.message.split("/")[-1].split(".")[0]
|
||||
tool_file_id = str(response.message).split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ValueError(f"tool file {tool_file_id} not exists")
|
||||
result.append(
|
||||
FileVar(
|
||||
File(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id=tool_file_id,
|
||||
filename=response.save_as,
|
||||
related_id=tool_file.id,
|
||||
filename=tool_file.name,
|
||||
extension=path.splitext(response.save_as)[1],
|
||||
mime_type=response.meta.get("mime_type", "application/octet-stream"),
|
||||
mime_type=tool_file.mimetype,
|
||||
size=tool_file.size,
|
||||
)
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
pass # TODO:
|
||||
url = str(response.message)
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
tool_file_id = url.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ValueError(f"tool file {tool_file_id} not exists")
|
||||
if "." in url:
|
||||
extension = "." + url.split("/")[-1].split(".")[1]
|
||||
else:
|
||||
extension = ".bin"
|
||||
file = File(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType(response.save_as),
|
||||
transfer_method=transfer_method,
|
||||
remote_url=url,
|
||||
filename=tool_file.name,
|
||||
related_id=tool_file.id,
|
||||
extension=extension,
|
||||
mime_type=tool_file.mimetype,
|
||||
size=tool_file.size,
|
||||
)
|
||||
result.append(file)
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.FILE:
|
||||
assert response.meta is not None
|
||||
result.append(response.meta["file"])
|
||||
|
||||
return result
|
||||
|
||||
@@ -218,12 +263,16 @@ class ToolNode(BaseNode):
|
||||
]
|
||||
)
|
||||
|
||||
def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]) -> list[dict]:
|
||||
def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]):
|
||||
return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON]
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: ToolNodeData
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: ToolNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -236,7 +285,7 @@ class ToolNode(BaseNode):
|
||||
for parameter_name in node_data.tool_parameters:
|
||||
input = node_data.tool_parameters[parameter_name]
|
||||
if input.type == "mixed":
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
|
||||
Reference in New Issue
Block a user