feat/enhance the multi-modal support (#8818)

This commit is contained in:
-LAN-
2024-10-21 10:43:49 +08:00
committed by GitHub
parent 7a1d6fe509
commit e61752bd3a
267 changed files with 6263 additions and 3523 deletions

View File

@@ -0,0 +1,3 @@
from .tool_node import ToolNode
__all__ = ["ToolNode"]

View File

@@ -3,7 +3,7 @@ from typing import Any, Literal, Union
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.nodes.base import BaseNodeData
class ToolEntity(BaseModel):
@@ -51,7 +51,4 @@ class ToolNodeData(BaseNodeData, ToolEntity):
raise ValueError("value must be a string, int, float, or bool")
return typ
"""
Tool Node Schema
"""
tool_parameters: dict[str, ToolInput]

View File

@@ -1,24 +1,28 @@
from collections.abc import Mapping, Sequence
from os import path
from typing import Any, cast
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.file.models import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models import WorkflowNodeExecutionStatus
from extensions.ext_database import db
from models import ToolFile
from models.workflow import WorkflowNodeExecutionStatus
class ToolNode(BaseNode):
class ToolNode(BaseNode[ToolNodeData]):
"""
Tool Node
"""
@@ -27,37 +31,38 @@ class ToolNode(BaseNode):
_node_type = NodeType.TOOL
def _run(self) -> NodeRunResult:
"""
Run the tool node
"""
node_data = cast(ToolNodeData, self.node_data)
# fetch tool icon
tool_info = {"provider_type": node_data.provider_type, "provider_id": node_data.provider_id}
tool_info = {
"provider_type": self.node_data.provider_type,
"provider_id": self.node_data.provider_id,
}
# get tool runtime
try:
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, node_data, self.invoke_from
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info,
},
error=f"Failed to get tool runtime: {str(e)}",
)
# get parameters
tool_parameters = tool_runtime.get_runtime_parameters() or []
parameters = self._generate_parameters(
tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
)
parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=node_data,
node_data=self.node_data,
for_log=True,
)
@@ -74,7 +79,9 @@ class ToolNode(BaseNode):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info,
},
error=f"Failed to invoke tool: {str(e)}",
)
@@ -83,8 +90,14 @@ class ToolNode(BaseNode):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": plain_text, "files": files, "json": json},
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
outputs={
"text": plain_text,
"files": files,
"json": json,
},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info,
},
inputs=parameters_for_log,
)
@@ -116,29 +129,25 @@ class ToolNode(BaseNode):
if not parameter:
result[parameter_name] = None
continue
if parameter.type == ToolParameter.ToolParameterType.FILE:
result[parameter_name] = [v.to_dict() for v in self._fetch_files(variable_pool)]
tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
if variable is None:
raise ValueError(f"variable {tool_input.value} not exists")
parameter_value = variable.value
elif tool_input.type in {"mixed", "constant"}:
segment_group = variable_pool.convert_template(str(tool_input.value))
parameter_value = segment_group.log if for_log else segment_group.text
else:
tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == "variable":
# TODO: check if the variable exists in the variable pool
parameter_value = variable_pool.get(tool_input.value).value
else:
segment_group = parser.convert_template(
template=str(tool_input.value),
variable_pool=variable_pool,
)
parameter_value = segment_group.log if for_log else segment_group.text
result[parameter_name] = parameter_value
raise ValueError(f"unknown tool input type '{tool_input.type}'")
result[parameter_name] = parameter_value
return result
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar], list[dict]]:
def _convert_tool_messages(
self,
messages: list[ToolInvokeMessage],
):
"""
Convert ToolInvokeMessages into tuple[plain_text, files, json]
"""
@@ -156,50 +165,86 @@ class ToolNode(BaseNode):
return plain_text, files, json
def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[File]:
    """
    Extract tool response binary

    Builds a File for every IMAGE / IMAGE_LINK / BLOB / LINK message by looking up
    the ToolFile row whose id is the last path segment of the message (the text
    before its first "."). For FILE messages the object stored under
    meta["file"] is passed through unchanged. Messages of any other type are
    ignored.

    :param tool_response: messages returned by the tool invocation
    :return: the extracted files, in message order
    :raises ValueError: if a referenced ToolFile row does not exist, or a FILE
        message carries no meta
    """

    def load_tool_file(tool_file_id: str) -> ToolFile:
        # One lookup shared by the IMAGE, BLOB and LINK branches.
        with Session(db.engine) as session:
            tool_file = session.scalar(select(ToolFile).where(ToolFile.id == tool_file_id))
        if tool_file is None:
            raise ValueError(f"tool file {tool_file_id} not exists")
        return tool_file

    result = []
    for response in tool_response:
        if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
            url = str(response.message) if response.message else None
            ext = path.splitext(url)[1] if url else ".bin"
            # meta may be absent (the FILE branch below guards against it too),
            # so fall back to the default transfer method instead of crashing.
            transfer_method = (response.meta or {}).get("transfer_method", FileTransferMethod.TOOL_FILE)
            # str(None) == "None": an empty message reaches the lookup and fails
            # there with the usual "not exists" error.
            tool_file = load_tool_file(str(url).split("/")[-1].split(".")[0])
            result.append(
                File(
                    tenant_id=self.tenant_id,
                    type=FileType.IMAGE,
                    transfer_method=transfer_method,
                    remote_url=url,
                    related_id=tool_file.id,
                    filename=tool_file.name,
                    extension=ext,
                    mime_type=tool_file.mimetype,
                    size=tool_file.size,
                )
            )
        elif response.type == ToolInvokeMessage.MessageType.BLOB:
            tool_file = load_tool_file(str(response.message).split("/")[-1].split(".")[0])
            result.append(
                File(
                    tenant_id=self.tenant_id,
                    # NOTE(review): every blob is tagged as an image here - confirm this is intended.
                    type=FileType.IMAGE,
                    transfer_method=FileTransferMethod.TOOL_FILE,
                    related_id=tool_file.id,
                    filename=tool_file.name,
                    extension=path.splitext(response.save_as)[1],
                    mime_type=tool_file.mimetype,
                    size=tool_file.size,
                )
            )
        elif response.type == ToolInvokeMessage.MessageType.LINK:
            url = str(response.message)
            last_segment = url.split("/")[-1]
            tool_file = load_tool_file(last_segment.split(".")[0])
            # Take the extension from the last path segment only: a "." elsewhere
            # in the URL (e.g. the host name) must not be mistaken for one.
            extension = path.splitext(last_segment)[1] or ".bin"
            result.append(
                File(
                    tenant_id=self.tenant_id,
                    # NOTE(review): save_as is used as the FileType value - verify against the producer.
                    type=FileType(response.save_as),
                    transfer_method=FileTransferMethod.TOOL_FILE,
                    remote_url=url,
                    filename=tool_file.name,
                    related_id=tool_file.id,
                    extension=extension,
                    mime_type=tool_file.mimetype,
                    size=tool_file.size,
                )
            )
        elif response.type == ToolInvokeMessage.MessageType.FILE:
            # Explicit check instead of assert, which is stripped under -O.
            if response.meta is None:
                raise ValueError("file message has no meta")
            result.append(response.meta["file"])

    return result
@@ -218,12 +263,16 @@ class ToolNode(BaseNode):
]
)
def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]):
    """
    Extract tool response json

    Collects the message payload of every JSON-typed message, in order.
    """
    payloads = []
    for item in tool_response:
        if item.type == ToolInvokeMessage.MessageType.JSON:
            payloads.append(item.message)
    return payloads
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: ToolNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ToolNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -236,7 +285,7 @@ class ToolNode(BaseNode):
for parameter_name in node_data.tool_parameters:
input = node_data.tool_parameters[parameter_name]
if input.type == "mixed":
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":