mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 03:16:51 +08:00
feat/enhance the multi-modal support (#8818)
This commit is contained in:
@@ -20,10 +20,9 @@ from core.tools.entities.tool_entities import (
|
||||
ToolRuntimeVariablePool,
|
||||
)
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.file_obj import FileVar
|
||||
from core.file.models import File
|
||||
|
||||
|
||||
class Tool(BaseModel, ABC):
|
||||
@@ -63,8 +62,12 @@ class Tool(BaseModel, ABC):
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
|
||||
class VariableKey(Enum):
|
||||
class VariableKey(str, Enum):
|
||||
IMAGE = "image"
|
||||
DOCUMENT = "document"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
CUSTOM = "custom"
|
||||
|
||||
def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
|
||||
"""
|
||||
@@ -221,9 +224,7 @@ class Tool(BaseModel, ABC):
|
||||
result = deepcopy(tool_parameters)
|
||||
for parameter in self.parameters or []:
|
||||
if parameter.name in tool_parameters:
|
||||
result[parameter.name] = ToolParameterConverter.cast_parameter_by_type(
|
||||
tool_parameters[parameter.name], parameter.type
|
||||
)
|
||||
result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
|
||||
|
||||
return result
|
||||
|
||||
@@ -295,10 +296,8 @@ class Tool(BaseModel, ABC):
|
||||
"""
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as)
|
||||
|
||||
def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage:
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.FILE_VAR, message="", meta={"file_var": file_var}, save_as=""
|
||||
)
|
||||
def create_file_message(self, file: "File") -> ToolInvokeMessage:
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE, message="", meta={"file": file}, save_as="")
|
||||
|
||||
def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage:
|
||||
"""
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
from copy import deepcopy
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from core.file.file_obj import FileTransferMethod, FileVar
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
from core.tools.tool.tool import Tool
|
||||
from extensions.ext_database import db
|
||||
@@ -45,11 +45,13 @@ class WorkflowTool(Tool):
|
||||
workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version)
|
||||
|
||||
# transform the tool parameters
|
||||
tool_parameters, files = self._transform_args(tool_parameters)
|
||||
tool_parameters, files = self._transform_args(tool_parameters=tool_parameters)
|
||||
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
|
||||
generator = WorkflowAppGenerator()
|
||||
assert self.runtime is not None
|
||||
assert self.runtime.invoke_from is not None
|
||||
result = generator.generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
@@ -74,7 +76,7 @@ class WorkflowTool(Tool):
|
||||
else:
|
||||
outputs, files = self._extract_files(outputs)
|
||||
for file in files:
|
||||
result.append(self.create_file_var_message(file))
|
||||
result.append(self.create_file_message(file))
|
||||
|
||||
result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
|
||||
result.append(self.create_json_message(outputs))
|
||||
@@ -154,22 +156,22 @@ class WorkflowTool(Tool):
|
||||
parameters_result = {}
|
||||
files = []
|
||||
for parameter in parameter_rules:
|
||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
||||
if parameter.type == ToolParameter.ToolParameterType.SYSTEM_FILES:
|
||||
file = tool_parameters.get(parameter.name)
|
||||
if file:
|
||||
try:
|
||||
file_var_list = [FileVar(**f) for f in file]
|
||||
for file_var in file_var_list:
|
||||
file_dict = {
|
||||
"transfer_method": file_var.transfer_method.value,
|
||||
"type": file_var.type.value,
|
||||
file_var_list = [File.model_validate(f) for f in file]
|
||||
for file in file_var_list:
|
||||
file_dict: dict[str, str | None] = {
|
||||
"transfer_method": file.transfer_method.value,
|
||||
"type": file.type.value,
|
||||
}
|
||||
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
file_dict["tool_file_id"] = file_var.related_id
|
||||
elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
file_dict["upload_file_id"] = file_var.related_id
|
||||
elif file_var.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
file_dict["url"] = file_var.preview_url
|
||||
if file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
file_dict["tool_file_id"] = file.related_id
|
||||
elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
file_dict["upload_file_id"] = file.related_id
|
||||
elif file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
file_dict["url"] = file.generate_url()
|
||||
|
||||
files.append(file_dict)
|
||||
except Exception as e:
|
||||
@@ -179,7 +181,7 @@ class WorkflowTool(Tool):
|
||||
|
||||
return parameters_result, files
|
||||
|
||||
def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]:
|
||||
def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]:
|
||||
"""
|
||||
extract files from the result
|
||||
|
||||
@@ -190,17 +192,13 @@ class WorkflowTool(Tool):
|
||||
result = {}
|
||||
for key, value in outputs.items():
|
||||
if isinstance(value, list):
|
||||
has_file = False
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item.get("__variant") == "FileVar":
|
||||
try:
|
||||
files.append(FileVar(**item))
|
||||
has_file = True
|
||||
except Exception as e:
|
||||
pass
|
||||
if has_file:
|
||||
continue
|
||||
if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
file = File.model_validate(item)
|
||||
files.append(file)
|
||||
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
file = File.model_validate(value)
|
||||
files.append(file)
|
||||
|
||||
result[key] = value
|
||||
|
||||
return result, files
|
||||
|
||||
Reference in New Issue
Block a user