FEAT: NEW WORKFLOW ENGINE (#3160)

Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: nite-knite <nkCoding@gmail.com>
Co-authored-by: jyong <718720800@qq.com>
This commit is contained in:
takatost
2024-04-08 18:51:46 +08:00
committed by GitHub
parent 2fb9850af5
commit 7753ba2d37
1161 changed files with 103836 additions and 10327 deletions

View File

@@ -11,6 +11,7 @@ class ToolProviderType(Enum):
Enum class for tool provider
"""
BUILT_IN = "built-in"
DATASET_RETRIEVAL = "dataset-retrieval"
APP_BASED = "app-based"
API_BASED = "api-based"
@@ -161,6 +162,8 @@ class ToolIdentity(BaseModel):
author: str = Field(..., description="The author of the tool")
name: str = Field(..., description="The name of the tool")
label: I18nObject = Field(..., description="The label of the tool")
provider: str = Field(..., description="The provider of the tool")
icon: Optional[str] = None
class ToolCredentialsOption(BaseModel):
value: str = Field(..., description="The value of the option")
@@ -326,4 +329,33 @@ class ModelToolProviderConfiguration(BaseModel):
"""
provider: str = Field(..., description="The provider of the model tool")
models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool")
label: I18nObject = Field(..., description="The label of the model tool")
label: I18nObject = Field(..., description="The label of the model tool")
class ToolInvokeMeta(BaseModel):
"""
Tool invoke meta
"""
time_cost: float = Field(..., description="The time cost of the tool invoke")
error: Optional[str] = None
tool_config: Optional[dict] = None
@classmethod
def empty(cls) -> 'ToolInvokeMeta':
"""
Get an empty instance of ToolInvokeMeta
"""
return cls(time_cost=0.0, error=None, tool_config={})
@classmethod
def error_instance(cls, error: str) -> 'ToolInvokeMeta':
"""
Get an instance of ToolInvokeMeta with error
"""
return cls(time_cost=0.0, error=error, tool_config={})
def to_dict(self) -> dict:
return {
'time_cost': self.time_cost,
'error': self.error,
'tool_config': self.tool_config,
}

View File

@@ -8,6 +8,13 @@ from core.tools.entities.tool_entities import ToolProviderCredentials
from core.tools.tool.tool import ToolParameter
class UserTool(BaseModel):
author: str
name: str # identifier
label: I18nObject # label
description: I18nObject
parameters: Optional[list[ToolParameter]]
class UserToolProvider(BaseModel):
class ProviderType(Enum):
BUILTIN = "builtin"
@@ -22,9 +29,11 @@ class UserToolProvider(BaseModel):
icon: str
label: I18nObject # label
type: ProviderType
team_credentials: dict = None
masked_credentials: dict = None
original_credentials: dict = None
is_team_authorization: bool = False
allow_delete: bool = True
tools: list[UserTool] = None
def to_dict(self) -> dict:
return {
@@ -35,17 +44,11 @@ class UserToolProvider(BaseModel):
'icon': self.icon,
'label': self.label.to_dict(),
'type': self.type.value,
'team_credentials': self.team_credentials,
'team_credentials': self.masked_credentials,
'is_team_authorization': self.is_team_authorization,
'allow_delete': self.allow_delete
'allow_delete': self.allow_delete,
'tools': self.tools
}
class UserToolProviderCredentials(BaseModel):
credentials: dict[str, ToolProviderCredentials]
class UserTool(BaseModel):
author: str
name: str # identifier
label: I18nObject # label
description: I18nObject
parameters: Optional[list[ToolParameter]]
credentials: dict[str, ToolProviderCredentials]

View File

@@ -1,3 +1,6 @@
from core.tools.entities.tool_entities import ToolInvokeMeta
class ToolProviderNotFoundError(ValueError):
pass
@@ -17,4 +20,7 @@ class ToolInvokeError(ValueError):
pass
class ToolApiSchemaError(ValueError):
pass
pass
class ToolEngineInvokeError(Exception):
meta: ToolInvokeMeta

View File

@@ -16,6 +16,8 @@ from models.tools import ApiToolProvider
class ApiBasedToolProviderController(ToolProviderController):
provider_id: str
@staticmethod
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiBasedToolProviderController':
credentials_schema = {
@@ -89,9 +91,10 @@ class ApiBasedToolProviderController(ToolProviderController):
'en_US': db_provider.description,
'zh_Hans': db_provider.description
},
'icon': db_provider.icon
'icon': db_provider.icon,
},
'credentials_schema': credentials_schema
'credentials_schema': credentials_schema,
'provider_id': db_provider.id or '',
})
@property
@@ -120,7 +123,8 @@ class ApiBasedToolProviderController(ToolProviderController):
'en_US': tool_bundle.operation_id,
'zh_Hans': tool_bundle.operation_id
},
'icon': tool_bundle.icon if tool_bundle.icon else ''
'icon': self.identity.icon,
'provider': self.provider_id,
},
'description': {
'human': {

View File

@@ -189,7 +189,7 @@ class StableDiffusionTool(BuiltinTool):
if not base_url:
return []
api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models')
response = get(url=api_url, timeout=10)
response = get(url=api_url, timeout=(2, 10))
if response.status_code != 200:
return []
else:

View File

@@ -49,6 +49,7 @@ parameters:
zh_Hans: Lora
pt_BR: Lora
form: form
default: ""
- name: steps
type: number
required: false

View File

@@ -68,6 +68,7 @@ class BuiltinToolProviderController(ToolProviderController):
script_path=path.join(path.dirname(path.realpath(__file__)),
'builtin', provider, 'tools', f'{tool_name}.py'),
parent_type=BuiltinTool)
tool["identity"]["provider"] = provider
tools.append(assistant_tool_class(**tool))
self.tools = tools
@@ -126,7 +127,8 @@ class BuiltinToolProviderController(ToolProviderController):
:return: whether the provider needs credentials
"""
return self.credentials_schema is not None and len(self.credentials_schema) != 0
return self.credentials_schema is not None and \
len(self.credentials_schema) != 0
@property
def app_type(self) -> ToolProviderType:

View File

@@ -8,7 +8,8 @@ import requests
import core.helper.ssrf_proxy as ssrf_proxy
from core.tools.entities.tool_bundle import ApiBasedToolBundle
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
from core.tools.entities.user_entities import UserToolProvider
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
from core.tools.tool.tool import Tool
@@ -34,7 +35,7 @@ class ApiTool(Tool):
api_bundle=self.api_bundle.copy() if self.api_bundle else None,
runtime=Tool.Runtime(**meta)
)
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> str:
"""
validate the credentials for Api tool
@@ -49,6 +50,9 @@ class ApiTool(Tool):
# validate response
return self.validate_and_parse_response(response)
def tool_provider_type(self) -> ToolProviderType:
return UserToolProvider.ProviderType.API
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
headers = {}
credentials = self.runtime.credentials or {}

View File

@@ -1,6 +1,8 @@
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.entities.user_entities import UserToolProvider
from core.tools.model.tool_model_manager import ToolModelManager
from core.tools.tool.tool import Tool
from core.tools.utils.web_reader_tool import get_url
@@ -40,6 +42,9 @@ class BuiltinTool(Tool):
prompt_messages=prompt_messages,
)
def tool_provider_type(self) -> ToolProviderType:
return UserToolProvider.ProviderType.BUILTIN
def get_max_tokens(self) -> int:
"""
get max tokens

View File

@@ -2,11 +2,18 @@ from typing import Any
from langchain.tools import BaseTool
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom
from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
from core.tools.entities.tool_entities import (
ToolDescription,
ToolIdentity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
from core.tools.tool.tool import Tool
@@ -30,7 +37,7 @@ class DatasetRetrieverTool(Tool):
if retrieve_config is None:
return []
feature = DatasetRetrievalFeature()
feature = DatasetRetrieval()
# save original retrieve strategy, and set retrieve strategy to SINGLE
# Agent only support SINGLE mode
@@ -52,7 +59,7 @@ class DatasetRetrieverTool(Tool):
for langchain_tool in langchain_tools:
tool = DatasetRetrieverTool(
langchain_tool=langchain_tool,
identity=ToolIdentity(author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')),
identity=ToolIdentity(provider='', author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')),
parameters=[],
is_team_authorization=True,
description=ToolDescription(
@@ -76,6 +83,9 @@ class DatasetRetrieverTool(Tool):
required=True,
default=''),
]
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.DATASET_RETRIEVAL
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
"""

View File

@@ -11,7 +11,7 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.entities.tool_entities import ModelToolPropertyKey, ToolInvokeMessage
from core.tools.entities.tool_entities import ModelToolPropertyKey, ToolInvokeMessage, ToolProviderType
from core.tools.tool.tool import Tool
VISION_PROMPT = """## Image Recognition Task
@@ -79,6 +79,9 @@ class ModelTool(Tool):
"""
pass
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.BUILT_IN
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
"""
"""

View File

@@ -2,14 +2,14 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Optional, Union
from pydantic import BaseModel
from pydantic import BaseModel, validator
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.tools.entities.tool_entities import (
ToolDescription,
ToolIdentity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
ToolRuntimeImageVariable,
ToolRuntimeVariable,
ToolRuntimeVariablePool,
@@ -22,8 +22,13 @@ class Tool(BaseModel, ABC):
parameters: Optional[list[ToolParameter]] = None
description: ToolDescription = None
is_team_authorization: bool = False
agent_callback: Optional[DifyAgentCallbackHandler] = None
use_callback: bool = False
@validator('parameters', pre=True, always=True)
def set_parameters(cls, v, values):
if not v:
return []
return v
class Runtime(BaseModel):
"""
@@ -45,15 +50,10 @@ class Tool(BaseModel, ABC):
def __init__(self, **data: Any):
super().__init__(**data)
if not self.agent_callback:
self.use_callback = False
else:
self.use_callback = True
class VARIABLE_KEY(Enum):
IMAGE = 'image'
def fork_tool_runtime(self, meta: dict[str, Any], agent_callback: DifyAgentCallbackHandler = None) -> 'Tool':
def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
"""
fork a new tool with meta data
@@ -65,9 +65,16 @@ class Tool(BaseModel, ABC):
parameters=self.parameters.copy() if self.parameters else None,
description=self.description.copy() if self.description else None,
runtime=Tool.Runtime(**meta),
agent_callback=agent_callback
)
@abstractmethod
def tool_provider_type(self) -> ToolProviderType:
"""
get the tool provider type
:return: the tool provider type
"""
def load_variables(self, variables: ToolRuntimeVariablePool):
"""
load variables from database
@@ -174,50 +181,22 @@ class Tool(BaseModel, ABC):
return result
def invoke(self, user_id: str, tool_parameters: Union[dict[str, Any], str]) -> list[ToolInvokeMessage]:
# check if tool_parameters is a string
if isinstance(tool_parameters, str):
# check if this tool has only one parameter
parameters = [parameter for parameter in self.parameters if parameter.form == ToolParameter.ToolParameterForm.LLM]
if parameters and len(parameters) == 1:
tool_parameters = {
parameters[0].name: tool_parameters
}
else:
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
# update tool_parameters
if self.runtime.runtime_parameters:
tool_parameters.update(self.runtime.runtime_parameters)
# hit callback
if self.use_callback:
self.agent_callback.on_tool_start(
tool_name=self.identity.name,
tool_inputs=tool_parameters
)
# try parse tool parameters into the correct type
tool_parameters = self._transform_tool_parameters_type(tool_parameters)
try:
result = self._invoke(
user_id=user_id,
tool_parameters=tool_parameters,
)
except Exception as e:
if self.use_callback:
self.agent_callback.on_tool_error(e)
raise e
result = self._invoke(
user_id=user_id,
tool_parameters=tool_parameters,
)
if not isinstance(result, list):
result = [result]
# hit callback
if self.use_callback:
self.agent_callback.on_tool_end(
tool_name=self.identity.name,
tool_inputs=tool_parameters,
tool_outputs=self._convert_tool_response_to_str(result)
)
return result
def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
@@ -242,6 +221,31 @@ class Tool(BaseModel, ABC):
result += f"tool response: {response.message}."
return result
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
"""
Transform tool parameters type
"""
for parameter in self.parameters:
if parameter.name in tool_parameters:
if parameter.type in [
ToolParameter.ToolParameterType.SECRET_INPUT,
ToolParameter.ToolParameterType.STRING,
ToolParameter.ToolParameterType.SELECT,
] and not isinstance(tool_parameters[parameter.name], str):
tool_parameters[parameter.name] = str(tool_parameters[parameter.name])
elif parameter.type == ToolParameter.ToolParameterType.NUMBER \
and not isinstance(tool_parameters[parameter.name], int | float):
if isinstance(tool_parameters[parameter.name], str):
try:
tool_parameters[parameter.name] = int(tool_parameters[parameter.name])
except ValueError:
tool_parameters[parameter.name] = float(tool_parameters[parameter.name])
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
if not isinstance(tool_parameters[parameter.name], bool):
tool_parameters[parameter.name] = bool(tool_parameters[parameter.name])
return tool_parameters
@abstractmethod
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:

View File

@@ -0,0 +1,273 @@
from copy import deepcopy
from datetime import datetime, timezone
from typing import Union
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter
from core.tools.errors import (
ToolEngineInvokeError,
ToolInvokeError,
ToolNotFoundError,
ToolNotSupportedError,
ToolParameterValidationError,
ToolProviderCredentialValidationError,
ToolProviderNotFoundError,
)
from core.tools.tool.tool import Tool
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from extensions.ext_database import db
from models.model import Message, MessageFile
class ToolEngine:
"""
Tool runtime engine take care of the tool executions.
"""
@staticmethod
def agent_invoke(tool: Tool, tool_parameters: Union[str, dict],
user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom,
agent_tool_callback: DifyAgentCallbackHandler) \
-> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]:
"""
Agent invokes the tool with the given arguments.
"""
# check if arguments is a string
if isinstance(tool_parameters, str):
# check if this tool has only one parameter
parameters = [
parameter for parameter in tool.parameters
if parameter.form == ToolParameter.ToolParameterForm.LLM
]
if parameters and len(parameters) == 1:
tool_parameters = {
parameters[0].name: tool_parameters
}
else:
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
# invoke the tool
try:
# hit the callback handler
agent_tool_callback.on_tool_start(
tool_name=tool.identity.name,
tool_inputs=tool_parameters
)
meta, response = ToolEngine._invoke(tool, tool_parameters, user_id)
response = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=response,
user_id=user_id,
tenant_id=tenant_id,
conversation_id=message.conversation_id
)
# extract binary data from tool invoke message
binary_files = ToolEngine._extract_tool_response_binary(response)
# create message file
message_files = ToolEngine._create_message_files(
tool_messages=binary_files,
agent_message=message,
invoke_from=invoke_from,
user_id=user_id
)
plain_text = ToolEngine._convert_tool_response_to_str(response)
# hit the callback handler
agent_tool_callback.on_tool_end(
tool_name=tool.identity.name,
tool_inputs=tool_parameters,
tool_outputs=plain_text
)
# transform tool invoke message to get LLM friendly message
return plain_text, message_files, meta
except ToolProviderCredentialValidationError as e:
error_response = "Please check your tool provider credentials"
agent_tool_callback.on_tool_error(e)
except (
ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
) as e:
error_response = f"there is not a tool named {tool.identity.name}"
agent_tool_callback.on_tool_error(e)
except (
ToolParameterValidationError
) as e:
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
agent_tool_callback.on_tool_error(e)
except ToolInvokeError as e:
error_response = f"tool invoke error: {e}"
agent_tool_callback.on_tool_error(e)
except ToolEngineInvokeError as e:
meta = e.args[0]
error_response = f"tool invoke error: {meta.error}"
agent_tool_callback.on_tool_error(e)
return error_response, [], meta
except Exception as e:
error_response = f"unknown error: {e}"
agent_tool_callback.on_tool_error(e)
return error_response, [], ToolInvokeMeta.error_instance(error_response)
@staticmethod
def workflow_invoke(tool: Tool, tool_parameters: dict,
user_id: str, workflow_id: str,
workflow_tool_callback: DifyWorkflowCallbackHandler) \
-> list[ToolInvokeMessage]:
"""
Workflow invokes the tool with the given arguments.
"""
try:
# hit the callback handler
workflow_tool_callback.on_tool_start(
tool_name=tool.identity.name,
tool_inputs=tool_parameters
)
response = tool.invoke(user_id, tool_parameters)
# hit the callback handler
workflow_tool_callback.on_tool_end(
tool_name=tool.identity.name,
tool_inputs=tool_parameters,
tool_outputs=response
)
return response
except Exception as e:
workflow_tool_callback.on_tool_error(e)
raise e
@staticmethod
def _invoke(tool: Tool, tool_parameters: dict, user_id: str) \
-> tuple[ToolInvokeMeta, list[ToolInvokeMessage]]:
"""
Invoke the tool with the given arguments.
"""
started_at = datetime.now(timezone.utc)
meta = ToolInvokeMeta(time_cost=0.0, error=None, tool_config={
'tool_name': tool.identity.name,
'tool_provider': tool.identity.provider,
'tool_provider_type': tool.tool_provider_type().value,
'tool_parameters': deepcopy(tool.runtime.runtime_parameters),
'tool_icon': tool.identity.icon
})
try:
response = tool.invoke(user_id, tool_parameters)
except Exception as e:
meta.error = str(e)
raise ToolEngineInvokeError(meta)
finally:
ended_at = datetime.now(timezone.utc)
meta.time_cost = (ended_at - started_at).total_seconds()
return meta, response
@staticmethod
def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
"""
Handle tool response
"""
result = ''
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.TEXT:
result += response.message
elif response.type == ToolInvokeMessage.MessageType.LINK:
result += f"result link: {response.message}. please tell user to check it."
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
else:
result += f"tool response: {response.message}."
return result
@staticmethod
def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
"""
Extract tool response binary
"""
result = []
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
result.append(ToolInvokeMessageBinary(
mimetype=response.meta.get('mime_type', 'octet/stream'),
url=response.message,
save_as=response.save_as,
))
elif response.type == ToolInvokeMessage.MessageType.BLOB:
result.append(ToolInvokeMessageBinary(
mimetype=response.meta.get('mime_type', 'octet/stream'),
url=response.message,
save_as=response.save_as,
))
elif response.type == ToolInvokeMessage.MessageType.LINK:
# check if there is a mime type in meta
if response.meta and 'mime_type' in response.meta:
result.append(ToolInvokeMessageBinary(
mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream',
url=response.message,
save_as=response.save_as,
))
return result
@staticmethod
def _create_message_files(
tool_messages: list[ToolInvokeMessageBinary],
agent_message: Message,
invoke_from: InvokeFrom,
user_id: str
) -> list[tuple[MessageFile, bool]]:
"""
Create message file
:param messages: messages
:return: message files, should save as variable
"""
result = []
for message in tool_messages:
file_type = 'bin'
if 'image' in message.mimetype:
file_type = 'image'
elif 'video' in message.mimetype:
file_type = 'video'
elif 'audio' in message.mimetype:
file_type = 'audio'
elif 'text' in message.mimetype:
file_type = 'text'
elif 'pdf' in message.mimetype:
file_type = 'pdf'
elif 'zip' in message.mimetype:
file_type = 'archive'
# ...
message_file = MessageFile(
message_id=agent_message.id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE.value,
belongs_to='assistant',
url=message.url,
upload_file_id=None,
created_by_role=('account'if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
created_by=user_id,
)
db.session.add(message_file)
db.session.commit()
db.session.refresh(message_file)
result.append((
message_file,
message.save_as
))
db.session.close()
return result

View File

@@ -6,7 +6,7 @@ import os
import time
from collections.abc import Generator
from mimetypes import guess_extension, guess_type
from typing import Union
from typing import Optional, Union
from uuid import uuid4
from flask import current_app
@@ -19,18 +19,19 @@ from models.tools import ToolFile
logger = logging.getLogger(__name__)
class ToolFileManager:
@staticmethod
def sign_file(file_id: str, extension: str) -> str:
def sign_file(tool_file_id: str, extension: str) -> str:
"""
sign file to get a temporary url
"""
base_url = current_app.config.get('FILES_URL')
file_preview_url = f'{base_url}/files/tools/{file_id}{extension}'
file_preview_url = f'{base_url}/files/tools/{tool_file_id}{extension}'
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
secret_key = current_app.config['SECRET_KEY'].encode()
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
@@ -55,10 +56,10 @@ class ToolFileManager:
return current_time - int(timestamp) <= 300 # expired after 5 minutes
@staticmethod
def create_file_by_raw(user_id: str, tenant_id: str,
conversation_id: str, file_binary: bytes,
mimetype: str
) -> ToolFile:
def create_file_by_raw(user_id: str, tenant_id: str,
conversation_id: Optional[str], file_binary: bytes,
mimetype: str
) -> ToolFile:
"""
create file
"""
@@ -74,11 +75,11 @@ class ToolFileManager:
db.session.commit()
return tool_file
@staticmethod
def create_file_by_url(user_id: str, tenant_id: str,
conversation_id: str, file_url: str,
) -> ToolFile:
def create_file_by_url(user_id: str, tenant_id: str,
conversation_id: str, file_url: str,
) -> ToolFile:
"""
create file
"""
@@ -93,26 +94,26 @@ class ToolFileManager:
storage.save(filename, blob)
tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
conversation_id=conversation_id, file_key=filename,
conversation_id=conversation_id, file_key=filename,
mimetype=mimetype, original_url=file_url)
db.session.add(tool_file)
db.session.commit()
return tool_file
@staticmethod
def create_file_by_key(user_id: str, tenant_id: str,
conversation_id: str, file_key: str,
mimetype: str
) -> ToolFile:
def create_file_by_key(user_id: str, tenant_id: str,
conversation_id: str, file_key: str,
mimetype: str
) -> ToolFile:
"""
create file
"""
tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
conversation_id=conversation_id, file_key=file_key, mimetype=mimetype)
return tool_file
@staticmethod
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
"""
@@ -132,7 +133,7 @@ class ToolFileManager:
blob = storage.load_once(tool_file.file_key)
return blob, tool_file.mimetype
@staticmethod
def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]:
"""
@@ -161,25 +162,16 @@ class ToolFileManager:
blob = storage.load_once(tool_file.file_key)
return blob, tool_file.mimetype
@staticmethod
def get_file_generator_by_message_file_id(id: str) -> Union[tuple[Generator, str], None]:
def get_file_generator_by_tool_file_id(tool_file_id: str) -> Union[tuple[Generator, str], None]:
"""
get file binary
:param id: the id of the file
:param tool_file_id: the id of the tool file
:return: the binary of the file, mime type
"""
message_file: MessageFile = db.session.query(MessageFile).filter(
MessageFile.id == id,
).first()
# get tool file id
tool_file_id = message_file.url.split('/')[-1]
# trim extension
tool_file_id = tool_file_id.split('.')[0]
tool_file: ToolFile = db.session.query(ToolFile).filter(
ToolFile.id == tool_file_id,
).first()
@@ -190,7 +182,8 @@ class ToolFileManager:
generator = storage.load_stream(tool_file.file_key)
return generator, tool_file.mimetype
# init tool_file_parser
from core.file.tool_file_parser import tool_file_manager

View File

@@ -1,105 +1,68 @@
import json
import logging
import mimetypes
from collections.abc import Generator
from os import listdir, path
from threading import Lock
from typing import Any, Union
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.entities.application_entities import AgentToolEntity
from core.model_runtime.entities.message_entities import PromptMessage
from flask import current_app
from core.agent.entities import AgentToolEntity
from core.model_runtime.utils.encoders import jsonable_encoder
from core.provider_manager import ProviderManager
from core.tools import *
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.constant import DEFAULT_PROVIDERS
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolInvokeMessage,
ToolParameter,
ToolProviderCredentials,
)
from core.tools.entities.user_entities import UserToolProvider
from core.tools.errors import ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
from core.tools.provider.app_tool_provider import AppBasedToolProviderEntity
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.provider.model_tool_provider import ModelToolProviderController
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.api_tool import ApiTool
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.tools.utils.configuration import (
ModelToolConfigurationManager,
ToolConfigurationManager,
ToolParameterConfigurationManager,
)
from core.tools.utils.encoder import serialize_base_model_dict
from core.utils.module_import_helper import load_single_subclass_from_source
from core.workflow.nodes.tool.entities import ToolEntity
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider
from services.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
_builtin_providers = {}
_builtin_tools_labels = {}
class ToolManager:
@staticmethod
def invoke(
provider: str,
tool_id: str,
tool_name: str,
tool_parameters: dict[str, Any],
credentials: dict[str, Any],
prompt_messages: list[PromptMessage],
) -> list[ToolInvokeMessage]:
"""
invoke the assistant
_builtin_provider_lock = Lock()
_builtin_providers = {}
_builtin_providers_loaded = False
_builtin_tools_labels = {}
:param provider: the name of the provider
:param tool_id: the id of the tool
:param tool_name: the name of the tool, defined in `get_tools`
:param tool_parameters: the parameters of the tool
:param credentials: the credentials of the tool
:param prompt_messages: the prompt messages that the tool can use
:return: the messages that the tool wants to send to the user
"""
provider_entity: ToolProviderController = None
if provider == DEFAULT_PROVIDERS.API_BASED:
provider_entity = ApiBasedToolProviderController()
elif provider == DEFAULT_PROVIDERS.APP_BASED:
provider_entity = AppBasedToolProviderEntity()
if provider_entity is None:
# fetch the provider from .provider.builtin
provider_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
script_path=path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.py'),
parent_type=ToolProviderController)
provider_entity = provider_class()
return provider_entity.invoke(tool_id, tool_name, tool_parameters, credentials, prompt_messages)
@staticmethod
def get_builtin_provider(provider: str) -> BuiltinToolProviderController:
global _builtin_providers
@classmethod
def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController:
"""
get the builtin provider
:param provider: the name of the provider
:return: the provider
"""
if len(_builtin_providers) == 0:
if len(cls._builtin_providers) == 0:
# init the builtin providers
ToolManager.list_builtin_providers()
cls.load_builtin_providers_cache()
if provider not in _builtin_providers:
if provider not in cls._builtin_providers:
raise ToolProviderNotFoundError(f'builtin provider {provider} not found')
return _builtin_providers[provider]
@staticmethod
def get_builtin_tool(provider: str, tool_name: str) -> BuiltinTool:
return cls._builtin_providers[provider]
@classmethod
def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool:
"""
get the builtin tool
@@ -108,14 +71,14 @@ class ToolManager:
:return: the provider, the tool
"""
provider_controller = ToolManager.get_builtin_provider(provider)
provider_controller = cls.get_builtin_provider(provider)
tool = provider_controller.get_tool(tool_name)
return tool
@staticmethod
def get_tool(provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \
-> Union[BuiltinTool, ApiTool]:
@classmethod
def get_tool(cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \
-> Union[BuiltinTool, ApiTool]:
"""
get the tool
@@ -126,20 +89,19 @@ class ToolManager:
:return: the tool
"""
if provider_type == 'builtin':
return ToolManager.get_builtin_tool(provider_id, tool_name)
return cls.get_builtin_tool(provider_id, tool_name)
elif provider_type == 'api':
if tenant_id is None:
raise ValueError('tenant id is required for api provider')
api_provider, _ = ToolManager.get_api_provider_controller(tenant_id, provider_id)
api_provider, _ = cls.get_api_provider_controller(tenant_id, provider_id)
return api_provider.get_tool(tool_name)
elif provider_type == 'app':
raise NotImplementedError('app provider not implemented')
else:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
@staticmethod
def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str,
agent_callback: DifyAgentCallbackHandler = None) \
@classmethod
def get_tool_runtime(cls, provider_type: str, provider_name: str, tool_name: str, tenant_id: str) \
-> Union[BuiltinTool, ApiTool]:
"""
get the tool runtime
@@ -151,15 +113,15 @@ class ToolManager:
:return: the tool
"""
if provider_type == 'builtin':
builtin_tool = ToolManager.get_builtin_tool(provider_name, tool_name)
builtin_tool = cls.get_builtin_tool(provider_name, tool_name)
# check if the builtin tool need credentials
provider_controller = ToolManager.get_builtin_provider(provider_name)
provider_controller = cls.get_builtin_provider(provider_name)
if not provider_controller.need_credentials:
return builtin_tool.fork_tool_runtime(meta={
'tenant_id': tenant_id,
'credentials': {},
}, agent_callback=agent_callback)
})
# get credentials
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
@@ -169,10 +131,10 @@ class ToolManager:
if builtin_provider is None:
raise ToolProviderNotFoundError(f'builtin provider {provider_name} not found')
# decrypt the credentials
credentials = builtin_provider.credentials
controller = ToolManager.get_builtin_provider(provider_name)
controller = cls.get_builtin_provider(provider_name)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
@@ -181,13 +143,13 @@ class ToolManager:
'tenant_id': tenant_id,
'credentials': decrypted_credentials,
'runtime_parameters': {}
}, agent_callback=agent_callback)
})
elif provider_type == 'api':
if tenant_id is None:
raise ValueError('tenant id is required for api provider')
api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name)
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_name)
# decrypt the credentials
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
@@ -201,7 +163,7 @@ class ToolManager:
if tenant_id is None:
raise ValueError('tenant id is required for model provider')
# get model provider
model_provider = ToolManager.get_model_provider(tenant_id, provider_name)
model_provider = cls.get_model_provider(tenant_id, provider_name)
# get tool
model_tool = model_provider.get_tool(tool_name)
@@ -215,59 +177,68 @@ class ToolManager:
else:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
@staticmethod
def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_callback: DifyAgentCallbackHandler) -> Tool:
@classmethod
def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
"""
init runtime parameter
"""
parameter_value = parameters.get(parameter_rule.name)
if not parameter_value:
# get default value
parameter_value = parameter_rule.default
if not parameter_value and parameter_rule.required:
raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")
if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
# check if tool_parameter_config in options
options = list(map(lambda x: x.value, parameter_rule.options))
if parameter_value not in options:
raise ValueError(
f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}")
# convert tool parameter config to correct type
try:
if parameter_rule.type == ToolParameter.ToolParameterType.NUMBER:
# check if tool parameter is integer
if isinstance(parameter_value, int):
parameter_value = parameter_value
elif isinstance(parameter_value, float):
parameter_value = parameter_value
elif isinstance(parameter_value, str):
if '.' in parameter_value:
parameter_value = float(parameter_value)
else:
parameter_value = int(parameter_value)
elif parameter_rule.type == ToolParameter.ToolParameterType.BOOLEAN:
parameter_value = bool(parameter_value)
elif parameter_rule.type not in [ToolParameter.ToolParameterType.SELECT,
ToolParameter.ToolParameterType.STRING]:
parameter_value = str(parameter_value)
elif parameter_rule.type == ToolParameter.ToolParameterType:
parameter_value = str(parameter_value)
except Exception as e:
raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} is not correct type")
return parameter_value
@classmethod
def get_agent_tool_runtime(cls, tenant_id: str, agent_tool: AgentToolEntity) -> Tool:
"""
get the agent tool runtime
"""
tool_entity = ToolManager.get_tool_runtime(
provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, tool_name=agent_tool.tool_name,
tool_entity = cls.get_tool_runtime(
provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id,
tool_name=agent_tool.tool_name,
tenant_id=tenant_id,
agent_callback=agent_callback
)
runtime_parameters = {}
parameters = tool_entity.get_all_runtime_parameters()
for parameter in parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM:
# get tool parameter from form
tool_parameter_config = agent_tool.tool_parameters.get(parameter.name)
if not tool_parameter_config:
# get default value
tool_parameter_config = parameter.default
if not tool_parameter_config and parameter.required:
raise ValueError(f"tool parameter {parameter.name} not found in tool config")
if parameter.type == ToolParameter.ToolParameterType.SELECT:
# check if tool_parameter_config in options
options = list(map(lambda x: x.value, parameter.options))
if tool_parameter_config not in options:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")
# convert tool parameter config to correct type
try:
if parameter.type == ToolParameter.ToolParameterType.NUMBER:
# check if tool parameter is integer
if isinstance(tool_parameter_config, int):
tool_parameter_config = tool_parameter_config
elif isinstance(tool_parameter_config, float):
tool_parameter_config = tool_parameter_config
elif isinstance(tool_parameter_config, str):
if '.' in tool_parameter_config:
tool_parameter_config = float(tool_parameter_config)
else:
tool_parameter_config = int(tool_parameter_config)
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
tool_parameter_config = bool(tool_parameter_config)
elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
tool_parameter_config = str(tool_parameter_config)
elif parameter.type == ToolParameter.ToolParameterType:
tool_parameter_config = str(tool_parameter_config)
except Exception as e:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")
# save tool parameter to tool entity memory
runtime_parameters[parameter.name] = tool_parameter_config
value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
runtime_parameters[parameter.name] = value
# decrypt runtime parameters
encryption_manager = ToolParameterConfigurationManager(
tenant_id=tenant_id,
@@ -280,8 +251,42 @@ class ToolManager:
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity
@staticmethod
def get_builtin_provider_icon(provider: str) -> tuple[str, str]:
@classmethod
def get_workflow_tool_runtime(cls, tenant_id: str, workflow_tool: ToolEntity):
"""
get the workflow tool runtime
"""
tool_entity = cls.get_tool_runtime(
provider_type=workflow_tool.provider_type,
provider_name=workflow_tool.provider_id,
tool_name=workflow_tool.tool_name,
tenant_id=tenant_id,
)
runtime_parameters = {}
parameters = tool_entity.get_all_runtime_parameters()
for parameter in parameters:
# save tool parameter to tool entity memory
if parameter.form == ToolParameter.ToolParameterForm.FORM:
value = cls._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
runtime_parameters[parameter.name] = value
# decrypt runtime parameters
encryption_manager = ToolParameterConfigurationManager(
tenant_id=tenant_id,
tool_runtime=tool_entity,
provider_name=workflow_tool.provider_id,
provider_type=workflow_tool.provider_type,
)
if runtime_parameters:
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity
@classmethod
def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]:
"""
get the absolute path of the icon of the builtin provider
@@ -290,28 +295,39 @@ class ToolManager:
:return: the absolute path of the icon, the mime type of the icon
"""
# get provider
provider_controller = ToolManager.get_builtin_provider(provider)
provider_controller = cls.get_builtin_provider(provider)
absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets', provider_controller.identity.icon)
absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets',
provider_controller.identity.icon)
# check if the icon exists
if not path.exists(absolute_path):
raise ToolProviderNotFoundError(f'builtin provider {provider} icon not found')
# get the mime type
mime_type, _ = mimetypes.guess_type(absolute_path)
mime_type = mime_type or 'application/octet-stream'
return absolute_path, mime_type
@staticmethod
def list_builtin_providers() -> list[BuiltinToolProviderController]:
global _builtin_providers
@classmethod
def list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
# use cache first
if len(_builtin_providers) > 0:
return list(_builtin_providers.values())
if cls._builtin_providers_loaded:
yield from list(cls._builtin_providers.values())
return
builtin_providers: list[BuiltinToolProviderController] = []
with cls._builtin_provider_lock:
if cls._builtin_providers_loaded:
yield from list(cls._builtin_providers.values())
return
yield from cls._list_builtin_providers()
@classmethod
def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
"""
list all the builtin providers
"""
for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
if provider.startswith('__'):
continue
@@ -321,48 +337,61 @@ class ToolManager:
continue
# init provider
provider_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
script_path=path.join(path.dirname(path.realpath(__file__)),
'provider', 'builtin', provider, f'{provider}.py'),
parent_type=BuiltinToolProviderController)
builtin_providers.append(provider_class())
try:
provider_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
script_path=path.join(path.dirname(path.realpath(__file__)),
'provider', 'builtin', provider, f'{provider}.py'),
parent_type=BuiltinToolProviderController)
provider: BuiltinToolProviderController = provider_class()
cls._builtin_providers[provider.identity.name] = provider
for tool in provider.get_tools():
cls._builtin_tools_labels[tool.identity.name] = tool.identity.label
yield provider
# cache the builtin providers
for provider in builtin_providers:
_builtin_providers[provider.identity.name] = provider
for tool in provider.get_tools():
_builtin_tools_labels[tool.identity.name] = tool.identity.label
except Exception as e:
logger.error(f'load builtin provider {provider} error: {e}')
continue
# set builtin providers loaded
cls._builtin_providers_loaded = True
return builtin_providers
@staticmethod
def list_model_providers(tenant_id: str = None) -> list[ModelToolProviderController]:
"""
list all the model providers
@classmethod
def load_builtin_providers_cache(cls):
for _ in cls.list_builtin_providers():
pass
:return: the list of the model providers
"""
tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
# get configurations
model_configurations = ModelToolConfigurationManager.get_all_configuration()
# get all providers
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id).values()
# get model providers
model_providers: list[ModelToolProviderController] = []
for configuration in configurations:
# all the model tool should be configurated
if configuration.provider.provider not in model_configurations:
continue
if not ModelToolProviderController.is_configuration_valid(configuration):
continue
model_providers.append(ModelToolProviderController.from_db(configuration))
@classmethod
def clear_builtin_providers_cache(cls):
cls._builtin_providers = {}
cls._builtin_providers_loaded = False
return model_providers
@staticmethod
def get_model_provider(tenant_id: str, provider_name: str) -> ModelToolProviderController:
# @classmethod
# def list_model_providers(cls, tenant_id: str = None) -> list[ModelToolProviderController]:
# """
# list all the model providers
# :return: the list of the model providers
# """
# tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
# # get configurations
# model_configurations = ModelToolConfigurationManager.get_all_configuration()
# # get all providers
# provider_manager = ProviderManager()
# configurations = provider_manager.get_configurations(tenant_id).values()
# # get model providers
# model_providers: list[ModelToolProviderController] = []
# for configuration in configurations:
# # all the model tool should be configurated
# if configuration.provider.provider not in model_configurations:
# continue
# if not ModelToolProviderController.is_configuration_valid(configuration):
# continue
# model_providers.append(ModelToolProviderController.from_db(configuration))
# return model_providers
@classmethod
def get_model_provider(cls, tenant_id: str, provider_name: str) -> ModelToolProviderController:
"""
get the model provider
@@ -376,11 +405,11 @@ class ToolManager:
configuration = configurations.get(provider_name)
if configuration is None:
raise ToolProviderNotFoundError(f'model provider {provider_name} not found')
return ModelToolProviderController.from_db(configuration)
@staticmethod
def get_tool_label(tool_name: str) -> Union[I18nObject, None]:
@classmethod
def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
"""
get the tool label
@@ -388,148 +417,71 @@ class ToolManager:
:return: the label of the tool
"""
global _builtin_tools_labels
if len(_builtin_tools_labels) == 0:
cls._builtin_tools_labels
if len(cls._builtin_tools_labels) == 0:
# init the builtin providers
ToolManager.list_builtin_providers()
cls.load_builtin_providers_cache()
if tool_name not in _builtin_tools_labels:
if tool_name not in cls._builtin_tools_labels:
return None
return _builtin_tools_labels[tool_name]
@staticmethod
def user_list_providers(
user_id: str,
tenant_id: str,
) -> list[UserToolProvider]:
return cls._builtin_tools_labels[tool_name]
@classmethod
def user_list_providers(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
result_providers: dict[str, UserToolProvider] = {}
# get builtin providers
builtin_providers = ToolManager.list_builtin_providers()
# append builtin providers
for provider in builtin_providers:
result_providers[provider.identity.name] = UserToolProvider(
id=provider.identity.name,
author=provider.identity.author,
name=provider.identity.name,
description=I18nObject(
en_US=provider.identity.description.en_US,
zh_Hans=provider.identity.description.zh_Hans,
),
icon=provider.identity.icon,
label=I18nObject(
en_US=provider.identity.label.en_US,
zh_Hans=provider.identity.label.zh_Hans,
),
type=UserToolProvider.ProviderType.BUILTIN,
team_credentials={},
is_team_authorization=False,
)
# get credentials schema
schema = provider.get_credentials_schema()
for name, value in schema.items():
result_providers[provider.identity.name].team_credentials[name] = \
ToolProviderCredentials.CredentialsType.default(value.type)
# check if the provider need credentials
if not provider.need_credentials:
result_providers[provider.identity.name].is_team_authorization = True
result_providers[provider.identity.name].allow_delete = False
builtin_providers = cls.list_builtin_providers()
# get db builtin providers
db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
filter(BuiltinToolProvider.tenant_id == tenant_id).all()
for db_builtin_provider in db_builtin_providers:
# add provider into providers
credentials = db_builtin_provider.credentials
provider_name = db_builtin_provider.provider
result_providers[provider_name].is_team_authorization = True
# package builtin tool provider controller
controller = ToolManager.get_builtin_provider(provider_name)
find_db_builtin_provider = lambda provider: next(
(x for x in db_builtin_providers if x.provider == provider),
None
)
# init tool configuration
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
# decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
result_providers[provider_name].team_credentials = masked_credentials
# get model tool providers
model_providers = ToolManager.list_model_providers(tenant_id=tenant_id)
# append model providers
for provider in model_providers:
result_providers[f'model_provider.{provider.identity.name}'] = UserToolProvider(
id=provider.identity.name,
author=provider.identity.author,
name=provider.identity.name,
description=I18nObject(
en_US=provider.identity.description.en_US,
zh_Hans=provider.identity.description.zh_Hans,
),
icon=provider.identity.icon,
label=I18nObject(
en_US=provider.identity.label.en_US,
zh_Hans=provider.identity.label.zh_Hans,
),
type=UserToolProvider.ProviderType.MODEL,
team_credentials={},
is_team_authorization=provider.is_active,
# append builtin providers
for provider in builtin_providers:
user_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider,
db_provider=find_db_builtin_provider(provider.identity.name),
decrypt_credentials=False
)
result_providers[provider.identity.name] = user_provider
# # get model tool providers
# model_providers = cls.list_model_providers(tenant_id=tenant_id)
# # append model providers
# for provider in model_providers:
# user_provider = ToolTransformService.model_provider_to_user_provider(
# db_provider=provider,
# )
# result_providers[f'model_provider.{provider.identity.name}'] = user_provider
# get db api providers
db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
filter(ApiToolProvider.tenant_id == tenant_id).all()
for db_api_provider in db_api_providers:
username = 'Anonymous'
try:
username = db_api_provider.user.name
except Exception as e:
logger.error(f'failed to get user name for api provider {db_api_provider.id}: {str(e)}')
# add provider into providers
credentials = db_api_provider.credentials
provider_name = db_api_provider.name
result_providers[provider_name] = UserToolProvider(
id=db_api_provider.id,
author=username,
name=db_api_provider.name,
description=I18nObject(
en_US=db_api_provider.description,
zh_Hans=db_api_provider.description,
),
icon=db_api_provider.icon,
label=I18nObject(
en_US=db_api_provider.name,
zh_Hans=db_api_provider.name,
),
type=UserToolProvider.ProviderType.API,
team_credentials={},
is_team_authorization=True,
)
# package tool provider controller
controller = ApiBasedToolProviderController.from_db(
provider_controller = ToolTransformService.api_provider_to_controller(
db_provider=db_api_provider,
auth_type=ApiProviderAuthType.API_KEY if db_api_provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
)
# init tool configuration
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
# decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
result_providers[provider_name].team_credentials = masked_credentials
user_provider = ToolTransformService.api_provider_to_user_provider(
provider_controller=provider_controller,
db_provider=db_api_provider,
decrypt_credentials=False
)
result_providers[db_api_provider.name] = user_provider
return BuiltinToolProviderSort.sort(list(result_providers.values()))
@staticmethod
def get_api_provider_controller(tenant_id: str, provider_id: str) -> tuple[ApiBasedToolProviderController, dict[str, Any]]:
@classmethod
def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[
ApiBasedToolProviderController, dict[str, Any]]:
"""
get the api provider
@@ -544,16 +496,18 @@ class ToolManager:
if provider is None:
raise ToolProviderNotFoundError(f'api provider {provider_id} not found')
controller = ApiBasedToolProviderController.from_db(
provider, ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
provider,
ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else
ApiProviderAuthType.NONE
)
controller.load_bundled_tools(provider.tools)
return controller, provider.credentials
@staticmethod
def user_get_api_provider(provider: str, tenant_id: str) -> dict:
@classmethod
def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
"""
get api provider
"""
@@ -567,7 +521,7 @@ class ToolManager:
if provider is None:
raise ValueError(f'you have not added provider {provider}')
try:
credentials = json.loads(provider.credentials_str) or {}
except:
@@ -591,7 +545,7 @@ class ToolManager:
"content": "\ud83d\ude01"
}
return json.loads(serialize_base_model_dict({
return jsonable_encoder({
'schema_type': provider.schema_type,
'schema': provider.schema,
'tools': provider.tools,
@@ -599,4 +553,38 @@ class ToolManager:
'description': provider.description,
'credentials': masked_credentials,
'privacy_policy': provider.privacy_policy
}))
})
@classmethod
def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]:
"""
get the tool icon
:param tenant_id: the id of the tenant
:param provider_type: the type of the provider
:param provider_id: the id of the provider
:return:
"""
provider_type = provider_type
provider_id = provider_id
if provider_type == 'builtin':
return (current_app.config.get("CONSOLE_API_URL")
+ "/console/api/workspaces/current/tool-provider/builtin/"
+ provider_id
+ "/icon")
elif provider_type == 'api':
try:
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.id == provider_id
)
return json.loads(provider.icon)
except:
return {
"background": "#252525",
"content": "\ud83d\ude01"
}
else:
raise ValueError(f"provider type {provider_type} not found")
ToolManager.load_builtin_providers_cache()

View File

@@ -1,4 +1,5 @@
import os
from copy import deepcopy
from typing import Any, Union
from pydantic import BaseModel
@@ -25,7 +26,7 @@ class ToolConfigurationManager(BaseModel):
"""
deep copy credentials
"""
return {key: value for key, value in credentials.items()}
return deepcopy(credentials)
def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
"""

View File

@@ -1,21 +0,0 @@
from pydantic import BaseModel
def serialize_base_model_array(l: list[BaseModel]) -> str:
class _BaseModel(BaseModel):
__root__: list[BaseModel]
"""
{"__root__": [BaseModel, BaseModel, ...]}
"""
return _BaseModel(__root__=l).json()
def serialize_base_model_dict(b: dict) -> str:
class _BaseModel(BaseModel):
__root__: dict
"""
{"__root__": {BaseModel}}
"""
return _BaseModel(__root__=b).json()

View File

@@ -0,0 +1,85 @@
import logging
from mimetypes import guess_extension
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_file_manager import ToolFileManager
logger = logging.getLogger(__name__)
class ToolFileMessageTransformer:
@staticmethod
def transform_tool_invoke_messages(messages: list[ToolInvokeMessage],
user_id: str,
tenant_id: str,
conversation_id: str) -> list[ToolInvokeMessage]:
"""
Transform tool message and handle file download
"""
result = []
for message in messages:
if message.type == ToolInvokeMessage.MessageType.TEXT:
result.append(message)
elif message.type == ToolInvokeMessage.MessageType.LINK:
result.append(message)
elif message.type == ToolInvokeMessage.MessageType.IMAGE:
# try to download image
try:
file = ToolFileManager.create_file_by_url(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_url=message.message
)
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
except Exception as e:
logger.exception(e)
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT,
message=f"Failed to download image: {message.message}, you can try to download it yourself.",
meta=message.meta.copy() if message.meta is not None else {},
save_as=message.save_as,
))
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get mime type and save blob to storage
mimetype = message.meta.get('mime_type', 'octet/stream')
# if message is str, encode it to bytes
if isinstance(message.message, str):
message.message = message.message.encode('utf-8')
file = ToolFileManager.create_file_by_raw(
user_id=user_id, tenant_id=tenant_id,
conversation_id=conversation_id,
file_binary=message.message,
mimetype=mimetype
)
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}'
# check if file is image
if 'image' in mimetype:
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
else:
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
else:
result.append(message)
return result