mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 11:26:52 +08:00
FEAT: NEW WORKFLOW ENGINE (#3160)
Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Yeuoly <admin@srmxy.cn> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: nite-knite <nkCoding@gmail.com> Co-authored-by: jyong <718720800@qq.com>
This commit is contained in:
@@ -11,6 +11,7 @@ class ToolProviderType(Enum):
|
||||
Enum class for tool provider
|
||||
"""
|
||||
BUILT_IN = "built-in"
|
||||
DATASET_RETRIEVAL = "dataset-retrieval"
|
||||
APP_BASED = "app-based"
|
||||
API_BASED = "api-based"
|
||||
|
||||
@@ -161,6 +162,8 @@ class ToolIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the tool")
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
provider: str = Field(..., description="The provider of the tool")
|
||||
icon: Optional[str] = None
|
||||
|
||||
class ToolCredentialsOption(BaseModel):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
@@ -326,4 +329,33 @@ class ModelToolProviderConfiguration(BaseModel):
|
||||
"""
|
||||
provider: str = Field(..., description="The provider of the model tool")
|
||||
models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool")
|
||||
label: I18nObject = Field(..., description="The label of the model tool")
|
||||
label: I18nObject = Field(..., description="The label of the model tool")
|
||||
|
||||
class ToolInvokeMeta(BaseModel):
|
||||
"""
|
||||
Tool invoke meta
|
||||
"""
|
||||
time_cost: float = Field(..., description="The time cost of the tool invoke")
|
||||
error: Optional[str] = None
|
||||
tool_config: Optional[dict] = None
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> 'ToolInvokeMeta':
|
||||
"""
|
||||
Get an empty instance of ToolInvokeMeta
|
||||
"""
|
||||
return cls(time_cost=0.0, error=None, tool_config={})
|
||||
|
||||
@classmethod
|
||||
def error_instance(cls, error: str) -> 'ToolInvokeMeta':
|
||||
"""
|
||||
Get an instance of ToolInvokeMeta with error
|
||||
"""
|
||||
return cls(time_cost=0.0, error=error, tool_config={})
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
'time_cost': self.time_cost,
|
||||
'error': self.error,
|
||||
'tool_config': self.tool_config,
|
||||
}
|
||||
@@ -8,6 +8,13 @@ from core.tools.entities.tool_entities import ToolProviderCredentials
|
||||
from core.tools.tool.tool import ToolParameter
|
||||
|
||||
|
||||
class UserTool(BaseModel):
|
||||
author: str
|
||||
name: str # identifier
|
||||
label: I18nObject # label
|
||||
description: I18nObject
|
||||
parameters: Optional[list[ToolParameter]]
|
||||
|
||||
class UserToolProvider(BaseModel):
|
||||
class ProviderType(Enum):
|
||||
BUILTIN = "builtin"
|
||||
@@ -22,9 +29,11 @@ class UserToolProvider(BaseModel):
|
||||
icon: str
|
||||
label: I18nObject # label
|
||||
type: ProviderType
|
||||
team_credentials: dict = None
|
||||
masked_credentials: dict = None
|
||||
original_credentials: dict = None
|
||||
is_team_authorization: bool = False
|
||||
allow_delete: bool = True
|
||||
tools: list[UserTool] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
@@ -35,17 +44,11 @@ class UserToolProvider(BaseModel):
|
||||
'icon': self.icon,
|
||||
'label': self.label.to_dict(),
|
||||
'type': self.type.value,
|
||||
'team_credentials': self.team_credentials,
|
||||
'team_credentials': self.masked_credentials,
|
||||
'is_team_authorization': self.is_team_authorization,
|
||||
'allow_delete': self.allow_delete
|
||||
'allow_delete': self.allow_delete,
|
||||
'tools': self.tools
|
||||
}
|
||||
|
||||
class UserToolProviderCredentials(BaseModel):
|
||||
credentials: dict[str, ToolProviderCredentials]
|
||||
|
||||
class UserTool(BaseModel):
|
||||
author: str
|
||||
name: str # identifier
|
||||
label: I18nObject # label
|
||||
description: I18nObject
|
||||
parameters: Optional[list[ToolParameter]]
|
||||
credentials: dict[str, ToolProviderCredentials]
|
||||
@@ -1,3 +1,6 @@
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
|
||||
|
||||
class ToolProviderNotFoundError(ValueError):
|
||||
pass
|
||||
|
||||
@@ -17,4 +20,7 @@ class ToolInvokeError(ValueError):
|
||||
pass
|
||||
|
||||
class ToolApiSchemaError(ValueError):
|
||||
pass
|
||||
pass
|
||||
|
||||
class ToolEngineInvokeError(Exception):
|
||||
meta: ToolInvokeMeta
|
||||
@@ -16,6 +16,8 @@ from models.tools import ApiToolProvider
|
||||
|
||||
|
||||
class ApiBasedToolProviderController(ToolProviderController):
|
||||
provider_id: str
|
||||
|
||||
@staticmethod
|
||||
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiBasedToolProviderController':
|
||||
credentials_schema = {
|
||||
@@ -89,9 +91,10 @@ class ApiBasedToolProviderController(ToolProviderController):
|
||||
'en_US': db_provider.description,
|
||||
'zh_Hans': db_provider.description
|
||||
},
|
||||
'icon': db_provider.icon
|
||||
'icon': db_provider.icon,
|
||||
},
|
||||
'credentials_schema': credentials_schema
|
||||
'credentials_schema': credentials_schema,
|
||||
'provider_id': db_provider.id or '',
|
||||
})
|
||||
|
||||
@property
|
||||
@@ -120,7 +123,8 @@ class ApiBasedToolProviderController(ToolProviderController):
|
||||
'en_US': tool_bundle.operation_id,
|
||||
'zh_Hans': tool_bundle.operation_id
|
||||
},
|
||||
'icon': tool_bundle.icon if tool_bundle.icon else ''
|
||||
'icon': self.identity.icon,
|
||||
'provider': self.provider_id,
|
||||
},
|
||||
'description': {
|
||||
'human': {
|
||||
|
||||
@@ -189,7 +189,7 @@ class StableDiffusionTool(BuiltinTool):
|
||||
if not base_url:
|
||||
return []
|
||||
api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models')
|
||||
response = get(url=api_url, timeout=10)
|
||||
response = get(url=api_url, timeout=(2, 10))
|
||||
if response.status_code != 200:
|
||||
return []
|
||||
else:
|
||||
|
||||
@@ -49,6 +49,7 @@ parameters:
|
||||
zh_Hans: Lora
|
||||
pt_BR: Lora
|
||||
form: form
|
||||
default: ""
|
||||
- name: steps
|
||||
type: number
|
||||
required: false
|
||||
|
||||
@@ -68,6 +68,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
script_path=path.join(path.dirname(path.realpath(__file__)),
|
||||
'builtin', provider, 'tools', f'{tool_name}.py'),
|
||||
parent_type=BuiltinTool)
|
||||
tool["identity"]["provider"] = provider
|
||||
tools.append(assistant_tool_class(**tool))
|
||||
|
||||
self.tools = tools
|
||||
@@ -126,7 +127,8 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
|
||||
:return: whether the provider needs credentials
|
||||
"""
|
||||
return self.credentials_schema is not None and len(self.credentials_schema) != 0
|
||||
return self.credentials_schema is not None and \
|
||||
len(self.credentials_schema) != 0
|
||||
|
||||
@property
|
||||
def app_type(self) -> ToolProviderType:
|
||||
|
||||
@@ -8,7 +8,8 @@ import requests
|
||||
|
||||
import core.helper.ssrf_proxy as ssrf_proxy
|
||||
from core.tools.entities.tool_bundle import ApiBasedToolBundle
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
@@ -34,7 +35,7 @@ class ApiTool(Tool):
|
||||
api_bundle=self.api_bundle.copy() if self.api_bundle else None,
|
||||
runtime=Tool.Runtime(**meta)
|
||||
)
|
||||
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> str:
|
||||
"""
|
||||
validate the credentials for Api tool
|
||||
@@ -49,6 +50,9 @@ class ApiTool(Tool):
|
||||
# validate response
|
||||
return self.validate_and_parse_response(response)
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return UserToolProvider.ProviderType.API
|
||||
|
||||
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
headers = {}
|
||||
credentials = self.runtime.credentials or {}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from core.tools.model.tool_model_manager import ToolModelManager
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.utils.web_reader_tool import get_url
|
||||
@@ -40,6 +42,9 @@ class BuiltinTool(Tool):
|
||||
prompt_messages=prompt_messages,
|
||||
)
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return UserToolProvider.ProviderType.BUILTIN
|
||||
|
||||
def get_max_tokens(self) -> int:
|
||||
"""
|
||||
get max tokens
|
||||
|
||||
@@ -2,11 +2,18 @@ from typing import Any
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom
|
||||
from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolIdentity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
|
||||
@@ -30,7 +37,7 @@ class DatasetRetrieverTool(Tool):
|
||||
if retrieve_config is None:
|
||||
return []
|
||||
|
||||
feature = DatasetRetrievalFeature()
|
||||
feature = DatasetRetrieval()
|
||||
|
||||
# save original retrieve strategy, and set retrieve strategy to SINGLE
|
||||
# Agent only support SINGLE mode
|
||||
@@ -52,7 +59,7 @@ class DatasetRetrieverTool(Tool):
|
||||
for langchain_tool in langchain_tools:
|
||||
tool = DatasetRetrieverTool(
|
||||
langchain_tool=langchain_tool,
|
||||
identity=ToolIdentity(author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')),
|
||||
identity=ToolIdentity(provider='', author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')),
|
||||
parameters=[],
|
||||
is_team_authorization=True,
|
||||
description=ToolDescription(
|
||||
@@ -76,6 +83,9 @@ class DatasetRetrieverTool(Tool):
|
||||
required=True,
|
||||
default=''),
|
||||
]
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.DATASET_RETRIEVAL
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
|
||||
@@ -11,7 +11,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.tools.entities.tool_entities import ModelToolPropertyKey, ToolInvokeMessage
|
||||
from core.tools.entities.tool_entities import ModelToolPropertyKey, ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
VISION_PROMPT = """## Image Recognition Task
|
||||
@@ -79,6 +79,9 @@ class ModelTool(Tool):
|
||||
"""
|
||||
pass
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
"""
|
||||
|
||||
@@ -2,14 +2,14 @@ from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolIdentity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
ToolRuntimeImageVariable,
|
||||
ToolRuntimeVariable,
|
||||
ToolRuntimeVariablePool,
|
||||
@@ -22,8 +22,13 @@ class Tool(BaseModel, ABC):
|
||||
parameters: Optional[list[ToolParameter]] = None
|
||||
description: ToolDescription = None
|
||||
is_team_authorization: bool = False
|
||||
agent_callback: Optional[DifyAgentCallbackHandler] = None
|
||||
use_callback: bool = False
|
||||
|
||||
@validator('parameters', pre=True, always=True)
|
||||
def set_parameters(cls, v, values):
|
||||
if not v:
|
||||
return []
|
||||
|
||||
return v
|
||||
|
||||
class Runtime(BaseModel):
|
||||
"""
|
||||
@@ -45,15 +50,10 @@ class Tool(BaseModel, ABC):
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
|
||||
if not self.agent_callback:
|
||||
self.use_callback = False
|
||||
else:
|
||||
self.use_callback = True
|
||||
|
||||
class VARIABLE_KEY(Enum):
|
||||
IMAGE = 'image'
|
||||
|
||||
def fork_tool_runtime(self, meta: dict[str, Any], agent_callback: DifyAgentCallbackHandler = None) -> 'Tool':
|
||||
def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
|
||||
"""
|
||||
fork a new tool with meta data
|
||||
|
||||
@@ -65,9 +65,16 @@ class Tool(BaseModel, ABC):
|
||||
parameters=self.parameters.copy() if self.parameters else None,
|
||||
description=self.description.copy() if self.description else None,
|
||||
runtime=Tool.Runtime(**meta),
|
||||
agent_callback=agent_callback
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
get the tool provider type
|
||||
|
||||
:return: the tool provider type
|
||||
"""
|
||||
|
||||
def load_variables(self, variables: ToolRuntimeVariablePool):
|
||||
"""
|
||||
load variables from database
|
||||
@@ -174,50 +181,22 @@ class Tool(BaseModel, ABC):
|
||||
|
||||
return result
|
||||
|
||||
def invoke(self, user_id: str, tool_parameters: Union[dict[str, Any], str]) -> list[ToolInvokeMessage]:
|
||||
# check if tool_parameters is a string
|
||||
if isinstance(tool_parameters, str):
|
||||
# check if this tool has only one parameter
|
||||
parameters = [parameter for parameter in self.parameters if parameter.form == ToolParameter.ToolParameterForm.LLM]
|
||||
if parameters and len(parameters) == 1:
|
||||
tool_parameters = {
|
||||
parameters[0].name: tool_parameters
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
|
||||
|
||||
def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
|
||||
# update tool_parameters
|
||||
if self.runtime.runtime_parameters:
|
||||
tool_parameters.update(self.runtime.runtime_parameters)
|
||||
|
||||
# hit callback
|
||||
if self.use_callback:
|
||||
self.agent_callback.on_tool_start(
|
||||
tool_name=self.identity.name,
|
||||
tool_inputs=tool_parameters
|
||||
)
|
||||
# try parse tool parameters into the correct type
|
||||
tool_parameters = self._transform_tool_parameters_type(tool_parameters)
|
||||
|
||||
try:
|
||||
result = self._invoke(
|
||||
user_id=user_id,
|
||||
tool_parameters=tool_parameters,
|
||||
)
|
||||
except Exception as e:
|
||||
if self.use_callback:
|
||||
self.agent_callback.on_tool_error(e)
|
||||
raise e
|
||||
result = self._invoke(
|
||||
user_id=user_id,
|
||||
tool_parameters=tool_parameters,
|
||||
)
|
||||
|
||||
if not isinstance(result, list):
|
||||
result = [result]
|
||||
|
||||
# hit callback
|
||||
if self.use_callback:
|
||||
self.agent_callback.on_tool_end(
|
||||
tool_name=self.identity.name,
|
||||
tool_inputs=tool_parameters,
|
||||
tool_outputs=self._convert_tool_response_to_str(result)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
|
||||
@@ -242,6 +221,31 @@ class Tool(BaseModel, ABC):
|
||||
result += f"tool response: {response.message}."
|
||||
|
||||
return result
|
||||
|
||||
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Transform tool parameters type
|
||||
"""
|
||||
for parameter in self.parameters:
|
||||
if parameter.name in tool_parameters:
|
||||
if parameter.type in [
|
||||
ToolParameter.ToolParameterType.SECRET_INPUT,
|
||||
ToolParameter.ToolParameterType.STRING,
|
||||
ToolParameter.ToolParameterType.SELECT,
|
||||
] and not isinstance(tool_parameters[parameter.name], str):
|
||||
tool_parameters[parameter.name] = str(tool_parameters[parameter.name])
|
||||
elif parameter.type == ToolParameter.ToolParameterType.NUMBER \
|
||||
and not isinstance(tool_parameters[parameter.name], int | float):
|
||||
if isinstance(tool_parameters[parameter.name], str):
|
||||
try:
|
||||
tool_parameters[parameter.name] = int(tool_parameters[parameter.name])
|
||||
except ValueError:
|
||||
tool_parameters[parameter.name] = float(tool_parameters[parameter.name])
|
||||
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
|
||||
if not isinstance(tool_parameters[parameter.name], bool):
|
||||
tool_parameters[parameter.name] = bool(tool_parameters[parameter.name])
|
||||
|
||||
return tool_parameters
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
|
||||
273
api/core/tools/tool_engine.py
Normal file
273
api/core/tools/tool_engine.py
Normal file
@@ -0,0 +1,273 @@
|
||||
from copy import deepcopy
|
||||
from datetime import datetime, timezone
|
||||
from typing import Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file.file_obj import FileTransferMethod
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter
|
||||
from core.tools.errors import (
|
||||
ToolEngineInvokeError,
|
||||
ToolInvokeError,
|
||||
ToolNotFoundError,
|
||||
ToolNotSupportedError,
|
||||
ToolParameterValidationError,
|
||||
ToolProviderCredentialValidationError,
|
||||
ToolProviderNotFoundError,
|
||||
)
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message, MessageFile
|
||||
|
||||
|
||||
class ToolEngine:
|
||||
"""
|
||||
Tool runtime engine take care of the tool executions.
|
||||
"""
|
||||
@staticmethod
|
||||
def agent_invoke(tool: Tool, tool_parameters: Union[str, dict],
|
||||
user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom,
|
||||
agent_tool_callback: DifyAgentCallbackHandler) \
|
||||
-> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]:
|
||||
"""
|
||||
Agent invokes the tool with the given arguments.
|
||||
"""
|
||||
# check if arguments is a string
|
||||
if isinstance(tool_parameters, str):
|
||||
# check if this tool has only one parameter
|
||||
parameters = [
|
||||
parameter for parameter in tool.parameters
|
||||
if parameter.form == ToolParameter.ToolParameterForm.LLM
|
||||
]
|
||||
if parameters and len(parameters) == 1:
|
||||
tool_parameters = {
|
||||
parameters[0].name: tool_parameters
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
|
||||
|
||||
# invoke the tool
|
||||
try:
|
||||
# hit the callback handler
|
||||
agent_tool_callback.on_tool_start(
|
||||
tool_name=tool.identity.name,
|
||||
tool_inputs=tool_parameters
|
||||
)
|
||||
|
||||
meta, response = ToolEngine._invoke(tool, tool_parameters, user_id)
|
||||
response = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=response,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=message.conversation_id
|
||||
)
|
||||
|
||||
# extract binary data from tool invoke message
|
||||
binary_files = ToolEngine._extract_tool_response_binary(response)
|
||||
# create message file
|
||||
message_files = ToolEngine._create_message_files(
|
||||
tool_messages=binary_files,
|
||||
agent_message=message,
|
||||
invoke_from=invoke_from,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
plain_text = ToolEngine._convert_tool_response_to_str(response)
|
||||
|
||||
# hit the callback handler
|
||||
agent_tool_callback.on_tool_end(
|
||||
tool_name=tool.identity.name,
|
||||
tool_inputs=tool_parameters,
|
||||
tool_outputs=plain_text
|
||||
)
|
||||
|
||||
# transform tool invoke message to get LLM friendly message
|
||||
return plain_text, message_files, meta
|
||||
except ToolProviderCredentialValidationError as e:
|
||||
error_response = "Please check your tool provider credentials"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except (
|
||||
ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
) as e:
|
||||
error_response = f"there is not a tool named {tool.identity.name}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except (
|
||||
ToolParameterValidationError
|
||||
) as e:
|
||||
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except ToolInvokeError as e:
|
||||
error_response = f"tool invoke error: {e}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except ToolEngineInvokeError as e:
|
||||
meta = e.args[0]
|
||||
error_response = f"tool invoke error: {meta.error}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
return error_response, [], meta
|
||||
except Exception as e:
|
||||
error_response = f"unknown error: {e}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
|
||||
return error_response, [], ToolInvokeMeta.error_instance(error_response)
|
||||
|
||||
@staticmethod
|
||||
def workflow_invoke(tool: Tool, tool_parameters: dict,
|
||||
user_id: str, workflow_id: str,
|
||||
workflow_tool_callback: DifyWorkflowCallbackHandler) \
|
||||
-> list[ToolInvokeMessage]:
|
||||
"""
|
||||
Workflow invokes the tool with the given arguments.
|
||||
"""
|
||||
try:
|
||||
# hit the callback handler
|
||||
workflow_tool_callback.on_tool_start(
|
||||
tool_name=tool.identity.name,
|
||||
tool_inputs=tool_parameters
|
||||
)
|
||||
|
||||
response = tool.invoke(user_id, tool_parameters)
|
||||
|
||||
# hit the callback handler
|
||||
workflow_tool_callback.on_tool_end(
|
||||
tool_name=tool.identity.name,
|
||||
tool_inputs=tool_parameters,
|
||||
tool_outputs=response
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
workflow_tool_callback.on_tool_error(e)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def _invoke(tool: Tool, tool_parameters: dict, user_id: str) \
|
||||
-> tuple[ToolInvokeMeta, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invoke the tool with the given arguments.
|
||||
"""
|
||||
started_at = datetime.now(timezone.utc)
|
||||
meta = ToolInvokeMeta(time_cost=0.0, error=None, tool_config={
|
||||
'tool_name': tool.identity.name,
|
||||
'tool_provider': tool.identity.provider,
|
||||
'tool_provider_type': tool.tool_provider_type().value,
|
||||
'tool_parameters': deepcopy(tool.runtime.runtime_parameters),
|
||||
'tool_icon': tool.identity.icon
|
||||
})
|
||||
try:
|
||||
response = tool.invoke(user_id, tool_parameters)
|
||||
except Exception as e:
|
||||
meta.error = str(e)
|
||||
raise ToolEngineInvokeError(meta)
|
||||
finally:
|
||||
ended_at = datetime.now(timezone.utc)
|
||||
meta.time_cost = (ended_at - started_at).total_seconds()
|
||||
|
||||
return meta, response
|
||||
|
||||
@staticmethod
|
||||
def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
|
||||
"""
|
||||
Handle tool response
|
||||
"""
|
||||
result = ''
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
result += response.message
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
result += f"result link: {response.message}. please tell user to check it."
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
|
||||
else:
|
||||
result += f"tool response: {response.message}."
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
|
||||
"""
|
||||
Extract tool response binary
|
||||
"""
|
||||
result = []
|
||||
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
result.append(ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get('mime_type', 'octet/stream'),
|
||||
url=response.message,
|
||||
save_as=response.save_as,
|
||||
))
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
result.append(ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get('mime_type', 'octet/stream'),
|
||||
url=response.message,
|
||||
save_as=response.save_as,
|
||||
))
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
# check if there is a mime type in meta
|
||||
if response.meta and 'mime_type' in response.meta:
|
||||
result.append(ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream',
|
||||
url=response.message,
|
||||
save_as=response.save_as,
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _create_message_files(
|
||||
tool_messages: list[ToolInvokeMessageBinary],
|
||||
agent_message: Message,
|
||||
invoke_from: InvokeFrom,
|
||||
user_id: str
|
||||
) -> list[tuple[MessageFile, bool]]:
|
||||
"""
|
||||
Create message file
|
||||
|
||||
:param messages: messages
|
||||
:return: message files, should save as variable
|
||||
"""
|
||||
result = []
|
||||
|
||||
for message in tool_messages:
|
||||
file_type = 'bin'
|
||||
if 'image' in message.mimetype:
|
||||
file_type = 'image'
|
||||
elif 'video' in message.mimetype:
|
||||
file_type = 'video'
|
||||
elif 'audio' in message.mimetype:
|
||||
file_type = 'audio'
|
||||
elif 'text' in message.mimetype:
|
||||
file_type = 'text'
|
||||
elif 'pdf' in message.mimetype:
|
||||
file_type = 'pdf'
|
||||
elif 'zip' in message.mimetype:
|
||||
file_type = 'archive'
|
||||
# ...
|
||||
|
||||
message_file = MessageFile(
|
||||
message_id=agent_message.id,
|
||||
type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE.value,
|
||||
belongs_to='assistant',
|
||||
url=message.url,
|
||||
upload_file_id=None,
|
||||
created_by_role=('account'if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
|
||||
created_by=user_id,
|
||||
)
|
||||
|
||||
db.session.add(message_file)
|
||||
db.session.commit()
|
||||
db.session.refresh(message_file)
|
||||
|
||||
result.append((
|
||||
message_file,
|
||||
message.save_as
|
||||
))
|
||||
|
||||
db.session.close()
|
||||
|
||||
return result
|
||||
@@ -6,7 +6,7 @@ import os
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from mimetypes import guess_extension, guess_type
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from flask import current_app
|
||||
@@ -19,18 +19,19 @@ from models.tools import ToolFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolFileManager:
|
||||
@staticmethod
|
||||
def sign_file(file_id: str, extension: str) -> str:
|
||||
def sign_file(tool_file_id: str, extension: str) -> str:
|
||||
"""
|
||||
sign file to get a temporary url
|
||||
"""
|
||||
base_url = current_app.config.get('FILES_URL')
|
||||
file_preview_url = f'{base_url}/files/tools/{file_id}{extension}'
|
||||
file_preview_url = f'{base_url}/files/tools/{tool_file_id}{extension}'
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
|
||||
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = current_app.config['SECRET_KEY'].encode()
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
@@ -55,10 +56,10 @@ class ToolFileManager:
|
||||
return current_time - int(timestamp) <= 300 # expired after 5 minutes
|
||||
|
||||
@staticmethod
|
||||
def create_file_by_raw(user_id: str, tenant_id: str,
|
||||
conversation_id: str, file_binary: bytes,
|
||||
mimetype: str
|
||||
) -> ToolFile:
|
||||
def create_file_by_raw(user_id: str, tenant_id: str,
|
||||
conversation_id: Optional[str], file_binary: bytes,
|
||||
mimetype: str
|
||||
) -> ToolFile:
|
||||
"""
|
||||
create file
|
||||
"""
|
||||
@@ -74,11 +75,11 @@ class ToolFileManager:
|
||||
db.session.commit()
|
||||
|
||||
return tool_file
|
||||
|
||||
|
||||
@staticmethod
|
||||
def create_file_by_url(user_id: str, tenant_id: str,
|
||||
conversation_id: str, file_url: str,
|
||||
) -> ToolFile:
|
||||
def create_file_by_url(user_id: str, tenant_id: str,
|
||||
conversation_id: str, file_url: str,
|
||||
) -> ToolFile:
|
||||
"""
|
||||
create file
|
||||
"""
|
||||
@@ -93,26 +94,26 @@ class ToolFileManager:
|
||||
storage.save(filename, blob)
|
||||
|
||||
tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
|
||||
conversation_id=conversation_id, file_key=filename,
|
||||
conversation_id=conversation_id, file_key=filename,
|
||||
mimetype=mimetype, original_url=file_url)
|
||||
|
||||
|
||||
db.session.add(tool_file)
|
||||
db.session.commit()
|
||||
|
||||
return tool_file
|
||||
|
||||
@staticmethod
|
||||
def create_file_by_key(user_id: str, tenant_id: str,
|
||||
conversation_id: str, file_key: str,
|
||||
mimetype: str
|
||||
) -> ToolFile:
|
||||
def create_file_by_key(user_id: str, tenant_id: str,
|
||||
conversation_id: str, file_key: str,
|
||||
mimetype: str
|
||||
) -> ToolFile:
|
||||
"""
|
||||
create file
|
||||
"""
|
||||
tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
|
||||
conversation_id=conversation_id, file_key=file_key, mimetype=mimetype)
|
||||
return tool_file
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
|
||||
"""
|
||||
@@ -132,7 +133,7 @@ class ToolFileManager:
|
||||
blob = storage.load_once(tool_file.file_key)
|
||||
|
||||
return blob, tool_file.mimetype
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]:
|
||||
"""
|
||||
@@ -161,25 +162,16 @@ class ToolFileManager:
|
||||
blob = storage.load_once(tool_file.file_key)
|
||||
|
||||
return blob, tool_file.mimetype
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_file_generator_by_message_file_id(id: str) -> Union[tuple[Generator, str], None]:
|
||||
def get_file_generator_by_tool_file_id(tool_file_id: str) -> Union[tuple[Generator, str], None]:
|
||||
"""
|
||||
get file binary
|
||||
|
||||
:param id: the id of the file
|
||||
:param tool_file_id: the id of the tool file
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
message_file: MessageFile = db.session.query(MessageFile).filter(
|
||||
MessageFile.id == id,
|
||||
).first()
|
||||
|
||||
# get tool file id
|
||||
tool_file_id = message_file.url.split('/')[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split('.')[0]
|
||||
|
||||
tool_file: ToolFile = db.session.query(ToolFile).filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
).first()
|
||||
@@ -190,7 +182,8 @@ class ToolFileManager:
|
||||
generator = storage.load_stream(tool_file.file_key)
|
||||
|
||||
return generator, tool_file.mimetype
|
||||
|
||||
|
||||
|
||||
# init tool_file_parser
|
||||
from core.file.tool_file_parser import tool_file_manager
|
||||
|
||||
|
||||
@@ -1,105 +1,68 @@
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
from collections.abc import Generator
|
||||
from os import listdir, path
|
||||
from threading import Lock
|
||||
from typing import Any, Union
|
||||
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.entities.application_entities import AgentToolEntity
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from flask import current_app
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.tools import *
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.constant import DEFAULT_PROVIDERS
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderCredentials,
|
||||
)
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
|
||||
from core.tools.provider.app_tool_provider import AppBasedToolProviderEntity
|
||||
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.provider.model_tool_provider import ModelToolProviderController
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.api_tool import ApiTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.utils.configuration import (
|
||||
ModelToolConfigurationManager,
|
||||
ToolConfigurationManager,
|
||||
ToolParameterConfigurationManager,
|
||||
)
|
||||
from core.tools.utils.encoder import serialize_base_model_dict
|
||||
from core.utils.module_import_helper import load_single_subclass_from_source
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider
|
||||
from services.tools_transform_service import ToolTransformService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_builtin_providers = {}
|
||||
_builtin_tools_labels = {}
|
||||
|
||||
class ToolManager:
|
||||
@staticmethod
|
||||
def invoke(
|
||||
provider: str,
|
||||
tool_id: str,
|
||||
tool_name: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
credentials: dict[str, Any],
|
||||
prompt_messages: list[PromptMessage],
|
||||
) -> list[ToolInvokeMessage]:
|
||||
"""
|
||||
invoke the assistant
|
||||
_builtin_provider_lock = Lock()
|
||||
_builtin_providers = {}
|
||||
_builtin_providers_loaded = False
|
||||
_builtin_tools_labels = {}
|
||||
|
||||
:param provider: the name of the provider
|
||||
:param tool_id: the id of the tool
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:param tool_parameters: the parameters of the tool
|
||||
:param credentials: the credentials of the tool
|
||||
:param prompt_messages: the prompt messages that the tool can use
|
||||
|
||||
:return: the messages that the tool wants to send to the user
|
||||
"""
|
||||
provider_entity: ToolProviderController = None
|
||||
if provider == DEFAULT_PROVIDERS.API_BASED:
|
||||
provider_entity = ApiBasedToolProviderController()
|
||||
elif provider == DEFAULT_PROVIDERS.APP_BASED:
|
||||
provider_entity = AppBasedToolProviderEntity()
|
||||
|
||||
if provider_entity is None:
|
||||
# fetch the provider from .provider.builtin
|
||||
provider_class = load_single_subclass_from_source(
|
||||
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
|
||||
script_path=path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.py'),
|
||||
parent_type=ToolProviderController)
|
||||
provider_entity = provider_class()
|
||||
|
||||
return provider_entity.invoke(tool_id, tool_name, tool_parameters, credentials, prompt_messages)
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_provider(provider: str) -> BuiltinToolProviderController:
|
||||
global _builtin_providers
|
||||
@classmethod
|
||||
def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController:
|
||||
"""
|
||||
get the builtin provider
|
||||
|
||||
:param provider: the name of the provider
|
||||
:return: the provider
|
||||
"""
|
||||
if len(_builtin_providers) == 0:
|
||||
if len(cls._builtin_providers) == 0:
|
||||
# init the builtin providers
|
||||
ToolManager.list_builtin_providers()
|
||||
cls.load_builtin_providers_cache()
|
||||
|
||||
if provider not in _builtin_providers:
|
||||
if provider not in cls._builtin_providers:
|
||||
raise ToolProviderNotFoundError(f'builtin provider {provider} not found')
|
||||
|
||||
return _builtin_providers[provider]
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool(provider: str, tool_name: str) -> BuiltinTool:
|
||||
|
||||
return cls._builtin_providers[provider]
|
||||
|
||||
@classmethod
|
||||
def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool:
|
||||
"""
|
||||
get the builtin tool
|
||||
|
||||
@@ -108,14 +71,14 @@ class ToolManager:
|
||||
|
||||
:return: the provider, the tool
|
||||
"""
|
||||
provider_controller = ToolManager.get_builtin_provider(provider)
|
||||
provider_controller = cls.get_builtin_provider(provider)
|
||||
tool = provider_controller.get_tool(tool_name)
|
||||
|
||||
return tool
|
||||
|
||||
@staticmethod
|
||||
def get_tool(provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \
|
||||
-> Union[BuiltinTool, ApiTool]:
|
||||
|
||||
@classmethod
|
||||
def get_tool(cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \
|
||||
-> Union[BuiltinTool, ApiTool]:
|
||||
"""
|
||||
get the tool
|
||||
|
||||
@@ -126,20 +89,19 @@ class ToolManager:
|
||||
:return: the tool
|
||||
"""
|
||||
if provider_type == 'builtin':
|
||||
return ToolManager.get_builtin_tool(provider_id, tool_name)
|
||||
return cls.get_builtin_tool(provider_id, tool_name)
|
||||
elif provider_type == 'api':
|
||||
if tenant_id is None:
|
||||
raise ValueError('tenant id is required for api provider')
|
||||
api_provider, _ = ToolManager.get_api_provider_controller(tenant_id, provider_id)
|
||||
api_provider, _ = cls.get_api_provider_controller(tenant_id, provider_id)
|
||||
return api_provider.get_tool(tool_name)
|
||||
elif provider_type == 'app':
|
||||
raise NotImplementedError('app provider not implemented')
|
||||
else:
|
||||
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
|
||||
|
||||
@staticmethod
|
||||
def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str,
|
||||
agent_callback: DifyAgentCallbackHandler = None) \
|
||||
|
||||
@classmethod
|
||||
def get_tool_runtime(cls, provider_type: str, provider_name: str, tool_name: str, tenant_id: str) \
|
||||
-> Union[BuiltinTool, ApiTool]:
|
||||
"""
|
||||
get the tool runtime
|
||||
@@ -151,15 +113,15 @@ class ToolManager:
|
||||
:return: the tool
|
||||
"""
|
||||
if provider_type == 'builtin':
|
||||
builtin_tool = ToolManager.get_builtin_tool(provider_name, tool_name)
|
||||
builtin_tool = cls.get_builtin_tool(provider_name, tool_name)
|
||||
|
||||
# check if the builtin tool need credentials
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name)
|
||||
provider_controller = cls.get_builtin_provider(provider_name)
|
||||
if not provider_controller.need_credentials:
|
||||
return builtin_tool.fork_tool_runtime(meta={
|
||||
'tenant_id': tenant_id,
|
||||
'credentials': {},
|
||||
}, agent_callback=agent_callback)
|
||||
})
|
||||
|
||||
# get credentials
|
||||
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
|
||||
@@ -169,10 +131,10 @@ class ToolManager:
|
||||
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f'builtin provider {provider_name} not found')
|
||||
|
||||
|
||||
# decrypt the credentials
|
||||
credentials = builtin_provider.credentials
|
||||
controller = ToolManager.get_builtin_provider(provider_name)
|
||||
controller = cls.get_builtin_provider(provider_name)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
||||
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
@@ -181,13 +143,13 @@ class ToolManager:
|
||||
'tenant_id': tenant_id,
|
||||
'credentials': decrypted_credentials,
|
||||
'runtime_parameters': {}
|
||||
}, agent_callback=agent_callback)
|
||||
})
|
||||
|
||||
elif provider_type == 'api':
|
||||
if tenant_id is None:
|
||||
raise ValueError('tenant id is required for api provider')
|
||||
|
||||
api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name)
|
||||
|
||||
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_name)
|
||||
|
||||
# decrypt the credentials
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
|
||||
@@ -201,7 +163,7 @@ class ToolManager:
|
||||
if tenant_id is None:
|
||||
raise ValueError('tenant id is required for model provider')
|
||||
# get model provider
|
||||
model_provider = ToolManager.get_model_provider(tenant_id, provider_name)
|
||||
model_provider = cls.get_model_provider(tenant_id, provider_name)
|
||||
|
||||
# get tool
|
||||
model_tool = model_provider.get_tool(tool_name)
|
||||
@@ -215,59 +177,68 @@ class ToolManager:
|
||||
else:
|
||||
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
|
||||
|
||||
@staticmethod
|
||||
def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_callback: DifyAgentCallbackHandler) -> Tool:
|
||||
@classmethod
|
||||
def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
|
||||
"""
|
||||
init runtime parameter
|
||||
"""
|
||||
parameter_value = parameters.get(parameter_rule.name)
|
||||
if not parameter_value:
|
||||
# get default value
|
||||
parameter_value = parameter_rule.default
|
||||
if not parameter_value and parameter_rule.required:
|
||||
raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")
|
||||
|
||||
if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
|
||||
# check if tool_parameter_config in options
|
||||
options = list(map(lambda x: x.value, parameter_rule.options))
|
||||
if parameter_value not in options:
|
||||
raise ValueError(
|
||||
f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}")
|
||||
|
||||
# convert tool parameter config to correct type
|
||||
try:
|
||||
if parameter_rule.type == ToolParameter.ToolParameterType.NUMBER:
|
||||
# check if tool parameter is integer
|
||||
if isinstance(parameter_value, int):
|
||||
parameter_value = parameter_value
|
||||
elif isinstance(parameter_value, float):
|
||||
parameter_value = parameter_value
|
||||
elif isinstance(parameter_value, str):
|
||||
if '.' in parameter_value:
|
||||
parameter_value = float(parameter_value)
|
||||
else:
|
||||
parameter_value = int(parameter_value)
|
||||
elif parameter_rule.type == ToolParameter.ToolParameterType.BOOLEAN:
|
||||
parameter_value = bool(parameter_value)
|
||||
elif parameter_rule.type not in [ToolParameter.ToolParameterType.SELECT,
|
||||
ToolParameter.ToolParameterType.STRING]:
|
||||
parameter_value = str(parameter_value)
|
||||
elif parameter_rule.type == ToolParameter.ToolParameterType:
|
||||
parameter_value = str(parameter_value)
|
||||
except Exception as e:
|
||||
raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} is not correct type")
|
||||
|
||||
return parameter_value
|
||||
|
||||
@classmethod
|
||||
def get_agent_tool_runtime(cls, tenant_id: str, agent_tool: AgentToolEntity) -> Tool:
|
||||
"""
|
||||
get the agent tool runtime
|
||||
"""
|
||||
tool_entity = ToolManager.get_tool_runtime(
|
||||
provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, tool_name=agent_tool.tool_name,
|
||||
tool_entity = cls.get_tool_runtime(
|
||||
provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id,
|
||||
tool_name=agent_tool.tool_name,
|
||||
tenant_id=tenant_id,
|
||||
agent_callback=agent_callback
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_entity.get_all_runtime_parameters()
|
||||
for parameter in parameters:
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
# get tool parameter from form
|
||||
tool_parameter_config = agent_tool.tool_parameters.get(parameter.name)
|
||||
if not tool_parameter_config:
|
||||
# get default value
|
||||
tool_parameter_config = parameter.default
|
||||
if not tool_parameter_config and parameter.required:
|
||||
raise ValueError(f"tool parameter {parameter.name} not found in tool config")
|
||||
|
||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||
# check if tool_parameter_config in options
|
||||
options = list(map(lambda x: x.value, parameter.options))
|
||||
if tool_parameter_config not in options:
|
||||
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")
|
||||
|
||||
# convert tool parameter config to correct type
|
||||
try:
|
||||
if parameter.type == ToolParameter.ToolParameterType.NUMBER:
|
||||
# check if tool parameter is integer
|
||||
if isinstance(tool_parameter_config, int):
|
||||
tool_parameter_config = tool_parameter_config
|
||||
elif isinstance(tool_parameter_config, float):
|
||||
tool_parameter_config = tool_parameter_config
|
||||
elif isinstance(tool_parameter_config, str):
|
||||
if '.' in tool_parameter_config:
|
||||
tool_parameter_config = float(tool_parameter_config)
|
||||
else:
|
||||
tool_parameter_config = int(tool_parameter_config)
|
||||
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
|
||||
tool_parameter_config = bool(tool_parameter_config)
|
||||
elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
|
||||
tool_parameter_config = str(tool_parameter_config)
|
||||
elif parameter.type == ToolParameter.ToolParameterType:
|
||||
tool_parameter_config = str(tool_parameter_config)
|
||||
except Exception as e:
|
||||
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")
|
||||
|
||||
# save tool parameter to tool entity memory
|
||||
runtime_parameters[parameter.name] = tool_parameter_config
|
||||
|
||||
value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
|
||||
runtime_parameters[parameter.name] = value
|
||||
|
||||
# decrypt runtime parameters
|
||||
encryption_manager = ToolParameterConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
@@ -280,8 +251,42 @@ class ToolManager:
|
||||
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
||||
return tool_entity
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_provider_icon(provider: str) -> tuple[str, str]:
|
||||
@classmethod
|
||||
def get_workflow_tool_runtime(cls, tenant_id: str, workflow_tool: ToolEntity):
|
||||
"""
|
||||
get the workflow tool runtime
|
||||
"""
|
||||
tool_entity = cls.get_tool_runtime(
|
||||
provider_type=workflow_tool.provider_type,
|
||||
provider_name=workflow_tool.provider_id,
|
||||
tool_name=workflow_tool.tool_name,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_entity.get_all_runtime_parameters()
|
||||
|
||||
for parameter in parameters:
|
||||
# save tool parameter to tool entity memory
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
value = cls._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
|
||||
runtime_parameters[parameter.name] = value
|
||||
|
||||
# decrypt runtime parameters
|
||||
encryption_manager = ToolParameterConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
tool_runtime=tool_entity,
|
||||
provider_name=workflow_tool.provider_id,
|
||||
provider_type=workflow_tool.provider_type,
|
||||
)
|
||||
|
||||
if runtime_parameters:
|
||||
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
|
||||
|
||||
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
||||
return tool_entity
|
||||
|
||||
@classmethod
|
||||
def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]:
|
||||
"""
|
||||
get the absolute path of the icon of the builtin provider
|
||||
|
||||
@@ -290,28 +295,39 @@ class ToolManager:
|
||||
:return: the absolute path of the icon, the mime type of the icon
|
||||
"""
|
||||
# get provider
|
||||
provider_controller = ToolManager.get_builtin_provider(provider)
|
||||
provider_controller = cls.get_builtin_provider(provider)
|
||||
|
||||
absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets', provider_controller.identity.icon)
|
||||
absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets',
|
||||
provider_controller.identity.icon)
|
||||
# check if the icon exists
|
||||
if not path.exists(absolute_path):
|
||||
raise ToolProviderNotFoundError(f'builtin provider {provider} icon not found')
|
||||
|
||||
|
||||
# get the mime type
|
||||
mime_type, _ = mimetypes.guess_type(absolute_path)
|
||||
mime_type = mime_type or 'application/octet-stream'
|
||||
|
||||
return absolute_path, mime_type
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_providers() -> list[BuiltinToolProviderController]:
|
||||
global _builtin_providers
|
||||
|
||||
@classmethod
|
||||
def list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
|
||||
# use cache first
|
||||
if len(_builtin_providers) > 0:
|
||||
return list(_builtin_providers.values())
|
||||
if cls._builtin_providers_loaded:
|
||||
yield from list(cls._builtin_providers.values())
|
||||
return
|
||||
|
||||
builtin_providers: list[BuiltinToolProviderController] = []
|
||||
with cls._builtin_provider_lock:
|
||||
if cls._builtin_providers_loaded:
|
||||
yield from list(cls._builtin_providers.values())
|
||||
return
|
||||
|
||||
yield from cls._list_builtin_providers()
|
||||
|
||||
@classmethod
|
||||
def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
|
||||
"""
|
||||
list all the builtin providers
|
||||
"""
|
||||
for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
|
||||
if provider.startswith('__'):
|
||||
continue
|
||||
@@ -321,48 +337,61 @@ class ToolManager:
|
||||
continue
|
||||
|
||||
# init provider
|
||||
provider_class = load_single_subclass_from_source(
|
||||
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
|
||||
script_path=path.join(path.dirname(path.realpath(__file__)),
|
||||
'provider', 'builtin', provider, f'{provider}.py'),
|
||||
parent_type=BuiltinToolProviderController)
|
||||
builtin_providers.append(provider_class())
|
||||
try:
|
||||
provider_class = load_single_subclass_from_source(
|
||||
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
|
||||
script_path=path.join(path.dirname(path.realpath(__file__)),
|
||||
'provider', 'builtin', provider, f'{provider}.py'),
|
||||
parent_type=BuiltinToolProviderController)
|
||||
provider: BuiltinToolProviderController = provider_class()
|
||||
cls._builtin_providers[provider.identity.name] = provider
|
||||
for tool in provider.get_tools():
|
||||
cls._builtin_tools_labels[tool.identity.name] = tool.identity.label
|
||||
yield provider
|
||||
|
||||
# cache the builtin providers
|
||||
for provider in builtin_providers:
|
||||
_builtin_providers[provider.identity.name] = provider
|
||||
for tool in provider.get_tools():
|
||||
_builtin_tools_labels[tool.identity.name] = tool.identity.label
|
||||
except Exception as e:
|
||||
logger.error(f'load builtin provider {provider} error: {e}')
|
||||
continue
|
||||
# set builtin providers loaded
|
||||
cls._builtin_providers_loaded = True
|
||||
|
||||
return builtin_providers
|
||||
|
||||
@staticmethod
|
||||
def list_model_providers(tenant_id: str = None) -> list[ModelToolProviderController]:
|
||||
"""
|
||||
list all the model providers
|
||||
@classmethod
|
||||
def load_builtin_providers_cache(cls):
|
||||
for _ in cls.list_builtin_providers():
|
||||
pass
|
||||
|
||||
:return: the list of the model providers
|
||||
"""
|
||||
tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
|
||||
# get configurations
|
||||
model_configurations = ModelToolConfigurationManager.get_all_configuration()
|
||||
# get all providers
|
||||
provider_manager = ProviderManager()
|
||||
configurations = provider_manager.get_configurations(tenant_id).values()
|
||||
# get model providers
|
||||
model_providers: list[ModelToolProviderController] = []
|
||||
for configuration in configurations:
|
||||
# all the model tool should be configurated
|
||||
if configuration.provider.provider not in model_configurations:
|
||||
continue
|
||||
if not ModelToolProviderController.is_configuration_valid(configuration):
|
||||
continue
|
||||
model_providers.append(ModelToolProviderController.from_db(configuration))
|
||||
@classmethod
|
||||
def clear_builtin_providers_cache(cls):
|
||||
cls._builtin_providers = {}
|
||||
cls._builtin_providers_loaded = False
|
||||
|
||||
return model_providers
|
||||
|
||||
@staticmethod
|
||||
def get_model_provider(tenant_id: str, provider_name: str) -> ModelToolProviderController:
|
||||
# @classmethod
|
||||
# def list_model_providers(cls, tenant_id: str = None) -> list[ModelToolProviderController]:
|
||||
# """
|
||||
# list all the model providers
|
||||
|
||||
# :return: the list of the model providers
|
||||
# """
|
||||
# tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
|
||||
# # get configurations
|
||||
# model_configurations = ModelToolConfigurationManager.get_all_configuration()
|
||||
# # get all providers
|
||||
# provider_manager = ProviderManager()
|
||||
# configurations = provider_manager.get_configurations(tenant_id).values()
|
||||
# # get model providers
|
||||
# model_providers: list[ModelToolProviderController] = []
|
||||
# for configuration in configurations:
|
||||
# # all the model tool should be configurated
|
||||
# if configuration.provider.provider not in model_configurations:
|
||||
# continue
|
||||
# if not ModelToolProviderController.is_configuration_valid(configuration):
|
||||
# continue
|
||||
# model_providers.append(ModelToolProviderController.from_db(configuration))
|
||||
|
||||
# return model_providers
|
||||
|
||||
@classmethod
|
||||
def get_model_provider(cls, tenant_id: str, provider_name: str) -> ModelToolProviderController:
|
||||
"""
|
||||
get the model provider
|
||||
|
||||
@@ -376,11 +405,11 @@ class ToolManager:
|
||||
configuration = configurations.get(provider_name)
|
||||
if configuration is None:
|
||||
raise ToolProviderNotFoundError(f'model provider {provider_name} not found')
|
||||
|
||||
|
||||
return ModelToolProviderController.from_db(configuration)
|
||||
|
||||
@staticmethod
|
||||
def get_tool_label(tool_name: str) -> Union[I18nObject, None]:
|
||||
@classmethod
|
||||
def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
|
||||
"""
|
||||
get the tool label
|
||||
|
||||
@@ -388,148 +417,71 @@ class ToolManager:
|
||||
|
||||
:return: the label of the tool
|
||||
"""
|
||||
global _builtin_tools_labels
|
||||
if len(_builtin_tools_labels) == 0:
|
||||
cls._builtin_tools_labels
|
||||
if len(cls._builtin_tools_labels) == 0:
|
||||
# init the builtin providers
|
||||
ToolManager.list_builtin_providers()
|
||||
cls.load_builtin_providers_cache()
|
||||
|
||||
if tool_name not in _builtin_tools_labels:
|
||||
if tool_name not in cls._builtin_tools_labels:
|
||||
return None
|
||||
|
||||
return _builtin_tools_labels[tool_name]
|
||||
|
||||
@staticmethod
|
||||
def user_list_providers(
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
) -> list[UserToolProvider]:
|
||||
|
||||
return cls._builtin_tools_labels[tool_name]
|
||||
|
||||
@classmethod
|
||||
def user_list_providers(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
result_providers: dict[str, UserToolProvider] = {}
|
||||
|
||||
# get builtin providers
|
||||
builtin_providers = ToolManager.list_builtin_providers()
|
||||
# append builtin providers
|
||||
for provider in builtin_providers:
|
||||
result_providers[provider.identity.name] = UserToolProvider(
|
||||
id=provider.identity.name,
|
||||
author=provider.identity.author,
|
||||
name=provider.identity.name,
|
||||
description=I18nObject(
|
||||
en_US=provider.identity.description.en_US,
|
||||
zh_Hans=provider.identity.description.zh_Hans,
|
||||
),
|
||||
icon=provider.identity.icon,
|
||||
label=I18nObject(
|
||||
en_US=provider.identity.label.en_US,
|
||||
zh_Hans=provider.identity.label.zh_Hans,
|
||||
),
|
||||
type=UserToolProvider.ProviderType.BUILTIN,
|
||||
team_credentials={},
|
||||
is_team_authorization=False,
|
||||
)
|
||||
|
||||
# get credentials schema
|
||||
schema = provider.get_credentials_schema()
|
||||
for name, value in schema.items():
|
||||
result_providers[provider.identity.name].team_credentials[name] = \
|
||||
ToolProviderCredentials.CredentialsType.default(value.type)
|
||||
|
||||
# check if the provider need credentials
|
||||
if not provider.need_credentials:
|
||||
result_providers[provider.identity.name].is_team_authorization = True
|
||||
result_providers[provider.identity.name].allow_delete = False
|
||||
|
||||
builtin_providers = cls.list_builtin_providers()
|
||||
|
||||
# get db builtin providers
|
||||
db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
|
||||
filter(BuiltinToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
for db_builtin_provider in db_builtin_providers:
|
||||
# add provider into providers
|
||||
credentials = db_builtin_provider.credentials
|
||||
provider_name = db_builtin_provider.provider
|
||||
result_providers[provider_name].is_team_authorization = True
|
||||
|
||||
# package builtin tool provider controller
|
||||
controller = ToolManager.get_builtin_provider(provider_name)
|
||||
find_db_builtin_provider = lambda provider: next(
|
||||
(x for x in db_builtin_providers if x.provider == provider),
|
||||
None
|
||||
)
|
||||
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
|
||||
|
||||
result_providers[provider_name].team_credentials = masked_credentials
|
||||
|
||||
# get model tool providers
|
||||
model_providers = ToolManager.list_model_providers(tenant_id=tenant_id)
|
||||
# append model providers
|
||||
for provider in model_providers:
|
||||
result_providers[f'model_provider.{provider.identity.name}'] = UserToolProvider(
|
||||
id=provider.identity.name,
|
||||
author=provider.identity.author,
|
||||
name=provider.identity.name,
|
||||
description=I18nObject(
|
||||
en_US=provider.identity.description.en_US,
|
||||
zh_Hans=provider.identity.description.zh_Hans,
|
||||
),
|
||||
icon=provider.identity.icon,
|
||||
label=I18nObject(
|
||||
en_US=provider.identity.label.en_US,
|
||||
zh_Hans=provider.identity.label.zh_Hans,
|
||||
),
|
||||
type=UserToolProvider.ProviderType.MODEL,
|
||||
team_credentials={},
|
||||
is_team_authorization=provider.is_active,
|
||||
# append builtin providers
|
||||
for provider in builtin_providers:
|
||||
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider,
|
||||
db_provider=find_db_builtin_provider(provider.identity.name),
|
||||
decrypt_credentials=False
|
||||
)
|
||||
|
||||
result_providers[provider.identity.name] = user_provider
|
||||
|
||||
# # get model tool providers
|
||||
# model_providers = cls.list_model_providers(tenant_id=tenant_id)
|
||||
# # append model providers
|
||||
# for provider in model_providers:
|
||||
# user_provider = ToolTransformService.model_provider_to_user_provider(
|
||||
# db_provider=provider,
|
||||
# )
|
||||
# result_providers[f'model_provider.{provider.identity.name}'] = user_provider
|
||||
|
||||
# get db api providers
|
||||
db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
|
||||
filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
|
||||
for db_api_provider in db_api_providers:
|
||||
username = 'Anonymous'
|
||||
try:
|
||||
username = db_api_provider.user.name
|
||||
except Exception as e:
|
||||
logger.error(f'failed to get user name for api provider {db_api_provider.id}: {str(e)}')
|
||||
# add provider into providers
|
||||
credentials = db_api_provider.credentials
|
||||
provider_name = db_api_provider.name
|
||||
result_providers[provider_name] = UserToolProvider(
|
||||
id=db_api_provider.id,
|
||||
author=username,
|
||||
name=db_api_provider.name,
|
||||
description=I18nObject(
|
||||
en_US=db_api_provider.description,
|
||||
zh_Hans=db_api_provider.description,
|
||||
),
|
||||
icon=db_api_provider.icon,
|
||||
label=I18nObject(
|
||||
en_US=db_api_provider.name,
|
||||
zh_Hans=db_api_provider.name,
|
||||
),
|
||||
type=UserToolProvider.ProviderType.API,
|
||||
team_credentials={},
|
||||
is_team_authorization=True,
|
||||
)
|
||||
|
||||
# package tool provider controller
|
||||
controller = ApiBasedToolProviderController.from_db(
|
||||
provider_controller = ToolTransformService.api_provider_to_controller(
|
||||
db_provider=db_api_provider,
|
||||
auth_type=ApiProviderAuthType.API_KEY if db_api_provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
|
||||
)
|
||||
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
||||
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
|
||||
|
||||
result_providers[provider_name].team_credentials = masked_credentials
|
||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=db_api_provider,
|
||||
decrypt_credentials=False
|
||||
)
|
||||
result_providers[db_api_provider.name] = user_provider
|
||||
|
||||
return BuiltinToolProviderSort.sort(list(result_providers.values()))
|
||||
|
||||
@staticmethod
|
||||
def get_api_provider_controller(tenant_id: str, provider_id: str) -> tuple[ApiBasedToolProviderController, dict[str, Any]]:
|
||||
|
||||
@classmethod
|
||||
def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[
|
||||
ApiBasedToolProviderController, dict[str, Any]]:
|
||||
"""
|
||||
get the api provider
|
||||
|
||||
@@ -544,16 +496,18 @@ class ToolManager:
|
||||
|
||||
if provider is None:
|
||||
raise ToolProviderNotFoundError(f'api provider {provider_id} not found')
|
||||
|
||||
|
||||
controller = ApiBasedToolProviderController.from_db(
|
||||
provider, ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
|
||||
provider,
|
||||
ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else
|
||||
ApiProviderAuthType.NONE
|
||||
)
|
||||
controller.load_bundled_tools(provider.tools)
|
||||
|
||||
return controller, provider.credentials
|
||||
|
||||
@staticmethod
|
||||
def user_get_api_provider(provider: str, tenant_id: str) -> dict:
|
||||
|
||||
@classmethod
|
||||
def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
|
||||
"""
|
||||
get api provider
|
||||
"""
|
||||
@@ -567,7 +521,7 @@ class ToolManager:
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f'you have not added provider {provider}')
|
||||
|
||||
|
||||
try:
|
||||
credentials = json.loads(provider.credentials_str) or {}
|
||||
except:
|
||||
@@ -591,7 +545,7 @@ class ToolManager:
|
||||
"content": "\ud83d\ude01"
|
||||
}
|
||||
|
||||
return json.loads(serialize_base_model_dict({
|
||||
return jsonable_encoder({
|
||||
'schema_type': provider.schema_type,
|
||||
'schema': provider.schema,
|
||||
'tools': provider.tools,
|
||||
@@ -599,4 +553,38 @@ class ToolManager:
|
||||
'description': provider.description,
|
||||
'credentials': masked_credentials,
|
||||
'privacy_policy': provider.privacy_policy
|
||||
}))
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]:
|
||||
"""
|
||||
get the tool icon
|
||||
|
||||
:param tenant_id: the id of the tenant
|
||||
:param provider_type: the type of the provider
|
||||
:param provider_id: the id of the provider
|
||||
:return:
|
||||
"""
|
||||
provider_type = provider_type
|
||||
provider_id = provider_id
|
||||
if provider_type == 'builtin':
|
||||
return (current_app.config.get("CONSOLE_API_URL")
|
||||
+ "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
+ provider_id
|
||||
+ "/icon")
|
||||
elif provider_type == 'api':
|
||||
try:
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.id == provider_id
|
||||
)
|
||||
return json.loads(provider.icon)
|
||||
except:
|
||||
return {
|
||||
"background": "#252525",
|
||||
"content": "\ud83d\ude01"
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
|
||||
ToolManager.load_builtin_providers_cache()
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -25,7 +26,7 @@ class ToolConfigurationManager(BaseModel):
|
||||
"""
|
||||
deep copy credentials
|
||||
"""
|
||||
return {key: value for key, value in credentials.items()}
|
||||
return deepcopy(credentials)
|
||||
|
||||
def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def serialize_base_model_array(l: list[BaseModel]) -> str:
|
||||
class _BaseModel(BaseModel):
|
||||
__root__: list[BaseModel]
|
||||
|
||||
"""
|
||||
{"__root__": [BaseModel, BaseModel, ...]}
|
||||
"""
|
||||
return _BaseModel(__root__=l).json()
|
||||
|
||||
def serialize_base_model_dict(b: dict) -> str:
|
||||
class _BaseModel(BaseModel):
|
||||
__root__: dict
|
||||
|
||||
"""
|
||||
{"__root__": {BaseModel}}
|
||||
"""
|
||||
return _BaseModel(__root__=b).json()
|
||||
85
api/core/tools/utils/message_transformer.py
Normal file
85
api/core/tools/utils/message_transformer.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import logging
|
||||
from mimetypes import guess_extension
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ToolFileMessageTransformer:
|
||||
@staticmethod
|
||||
def transform_tool_invoke_messages(messages: list[ToolInvokeMessage],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: str) -> list[ToolInvokeMessage]:
|
||||
"""
|
||||
Transform tool message and handle file download
|
||||
"""
|
||||
result = []
|
||||
|
||||
for message in messages:
|
||||
if message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
result.append(message)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
result.append(message)
|
||||
elif message.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
# try to download image
|
||||
try:
|
||||
file = ToolFileManager.create_file_by_url(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_url=message.message
|
||||
)
|
||||
|
||||
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
|
||||
|
||||
result.append(ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
result.append(ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.TEXT,
|
||||
message=f"Failed to download image: {message.message}, you can try to download it yourself.",
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
save_as=message.save_as,
|
||||
))
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get mime type and save blob to storage
|
||||
mimetype = message.meta.get('mime_type', 'octet/stream')
|
||||
# if message is str, encode it to bytes
|
||||
if isinstance(message.message, str):
|
||||
message.message = message.message.encode('utf-8')
|
||||
|
||||
file = ToolFileManager.create_file_by_raw(
|
||||
user_id=user_id, tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_binary=message.message,
|
||||
mimetype=mimetype
|
||||
)
|
||||
|
||||
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}'
|
||||
|
||||
# check if file is image
|
||||
if 'image' in mimetype:
|
||||
result.append(ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
else:
|
||||
result.append(ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
else:
|
||||
result.append(message)
|
||||
|
||||
return result
|
||||
Reference in New Issue
Block a user