mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-23 17:55:46 +08:00
feat: mypy for all type check (#10921)
This commit is contained in:
@@ -1,9 +1,14 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolCredentialsOption,
|
||||
ToolDescription,
|
||||
ToolIdentity,
|
||||
ToolProviderCredentials,
|
||||
ToolProviderIdentity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
@@ -64,21 +69,18 @@ class ApiToolProviderController(ToolProviderController):
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"invalid auth type {auth_type}")
|
||||
|
||||
user_name = db_provider.user.name if db_provider.user_id else ""
|
||||
|
||||
user_name = db_provider.user.name if db_provider.user_id and db_provider.user is not None else ""
|
||||
return ApiToolProviderController(
|
||||
**{
|
||||
"identity": {
|
||||
"author": user_name,
|
||||
"name": db_provider.name,
|
||||
"label": {"en_US": db_provider.name, "zh_Hans": db_provider.name},
|
||||
"description": {"en_US": db_provider.description, "zh_Hans": db_provider.description},
|
||||
"icon": db_provider.icon,
|
||||
},
|
||||
"credentials_schema": credentials_schema,
|
||||
"provider_id": db_provider.id or "",
|
||||
}
|
||||
identity=ToolProviderIdentity(
|
||||
author=user_name,
|
||||
name=db_provider.name,
|
||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
||||
description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
|
||||
icon=db_provider.icon,
|
||||
),
|
||||
credentials_schema=credentials_schema,
|
||||
provider_id=db_provider.id or "",
|
||||
tools=None,
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -93,24 +95,22 @@ class ApiToolProviderController(ToolProviderController):
|
||||
:return: the tool
|
||||
"""
|
||||
return ApiTool(
|
||||
**{
|
||||
"api_bundle": tool_bundle,
|
||||
"identity": {
|
||||
"author": tool_bundle.author,
|
||||
"name": tool_bundle.operation_id,
|
||||
"label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id},
|
||||
"icon": self.identity.icon,
|
||||
"provider": self.provider_id,
|
||||
},
|
||||
"description": {
|
||||
"human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""},
|
||||
"llm": tool_bundle.summary or "",
|
||||
},
|
||||
"parameters": tool_bundle.parameters or [],
|
||||
}
|
||||
api_bundle=tool_bundle,
|
||||
identity=ToolIdentity(
|
||||
author=tool_bundle.author,
|
||||
name=tool_bundle.operation_id or "",
|
||||
label=I18nObject(en_US=tool_bundle.operation_id, zh_Hans=tool_bundle.operation_id),
|
||||
icon=self.identity.icon if self.identity else None,
|
||||
provider=self.provider_id,
|
||||
),
|
||||
description=ToolDescription(
|
||||
human=I18nObject(en_US=tool_bundle.summary or "", zh_Hans=tool_bundle.summary or ""),
|
||||
llm=tool_bundle.summary or "",
|
||||
),
|
||||
parameters=tool_bundle.parameters or [],
|
||||
)
|
||||
|
||||
def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]:
|
||||
def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[Tool]:
|
||||
"""
|
||||
load bundled tools
|
||||
|
||||
@@ -121,7 +121,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
|
||||
return self.tools
|
||||
|
||||
def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]:
|
||||
def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
|
||||
"""
|
||||
fetch tools from database
|
||||
|
||||
@@ -131,6 +131,8 @@ class ApiToolProviderController(ToolProviderController):
|
||||
"""
|
||||
if self.tools is not None:
|
||||
return self.tools
|
||||
if self.identity is None:
|
||||
return None
|
||||
|
||||
tools: list[Tool] = []
|
||||
|
||||
@@ -151,7 +153,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
self.tools = tools
|
||||
return tools
|
||||
|
||||
def get_tool(self, tool_name: str) -> ApiTool:
|
||||
def get_tool(self, tool_name: str) -> Tool:
|
||||
"""
|
||||
get tool by name
|
||||
|
||||
@@ -161,7 +163,9 @@ class ApiToolProviderController(ToolProviderController):
|
||||
if self.tools is None:
|
||||
self.get_tools()
|
||||
|
||||
for tool in self.tools:
|
||||
for tool in self.tools or []:
|
||||
if tool.identity is None:
|
||||
continue
|
||||
if tool.identity.name == tool_name:
|
||||
return tool
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ToolProviderType
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.api_tool import ApiTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppModelConfig
|
||||
@@ -20,10 +21,10 @@ class AppToolProviderEntity(ToolProviderController):
|
||||
def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None:
|
||||
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def get_tools(self, user_id: str) -> list[Tool]:
|
||||
def get_tools(self, user_id: str = "", tenant_id: str = "") -> list[Tool]:
|
||||
db_tools: list[PublishedAppTool] = (
|
||||
db.session.query(PublishedAppTool)
|
||||
.filter(
|
||||
@@ -38,7 +39,7 @@ class AppToolProviderEntity(ToolProviderController):
|
||||
tools: list[Tool] = []
|
||||
|
||||
for db_tool in db_tools:
|
||||
tool = {
|
||||
tool: dict[str, Any] = {
|
||||
"identity": {
|
||||
"author": db_tool.author,
|
||||
"name": db_tool.tool_name,
|
||||
@@ -52,7 +53,7 @@ class AppToolProviderEntity(ToolProviderController):
|
||||
"parameters": [],
|
||||
}
|
||||
# get app from db
|
||||
app: App = db_tool.app
|
||||
app: Optional[App] = db_tool.app
|
||||
|
||||
if not app:
|
||||
logger.error(f"app {db_tool.app_id} not found")
|
||||
@@ -79,6 +80,7 @@ class AppToolProviderEntity(ToolProviderController):
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=required,
|
||||
default=default,
|
||||
placeholder=I18nObject(en_US="", zh_Hans=""),
|
||||
)
|
||||
)
|
||||
elif form_type == "select":
|
||||
@@ -92,6 +94,7 @@ class AppToolProviderEntity(ToolProviderController):
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
required=required,
|
||||
default=default,
|
||||
placeholder=I18nObject(en_US="", zh_Hans=""),
|
||||
options=[
|
||||
ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
||||
for option in options
|
||||
@@ -99,5 +102,5 @@ class AppToolProviderEntity(ToolProviderController):
|
||||
)
|
||||
)
|
||||
|
||||
tools.append(Tool(**tool))
|
||||
tools.append(ApiTool(**tool))
|
||||
return tools
|
||||
|
||||
@@ -5,7 +5,7 @@ from core.tools.entities.api_entities import UserToolProvider
|
||||
|
||||
|
||||
class BuiltinToolProviderSort:
|
||||
_position = {}
|
||||
_position: dict[str, int] = {}
|
||||
|
||||
@classmethod
|
||||
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
|
||||
|
||||
@@ -4,7 +4,7 @@ from hmac import new as hmac_new
|
||||
from json import loads as json_loads
|
||||
from threading import Lock
|
||||
from time import sleep, time
|
||||
from typing import Any
|
||||
from typing import Any, Union
|
||||
|
||||
from httpx import get, post
|
||||
from requests import get as requests_get
|
||||
@@ -21,23 +21,25 @@ class AIPPTGenerateToolAdapter:
|
||||
"""
|
||||
|
||||
_api_base_url = URL("https://co.aippt.cn/api")
|
||||
_api_token_cache = {}
|
||||
_style_cache = {}
|
||||
_api_token_cache: dict[str, dict[str, Union[str, float]]] = {}
|
||||
_style_cache: dict[str, dict[str, Union[list[dict[str, Any]], float]]] = {}
|
||||
|
||||
_api_token_cache_lock = Lock()
|
||||
_style_cache_lock = Lock()
|
||||
_api_token_cache_lock: Lock = Lock()
|
||||
_style_cache_lock: Lock = Lock()
|
||||
|
||||
_task = {}
|
||||
_task: dict[str, Any] = {}
|
||||
_task_type_map = {
|
||||
"auto": 1,
|
||||
"markdown": 7,
|
||||
}
|
||||
_tool: BuiltinTool
|
||||
_tool: BuiltinTool | None
|
||||
|
||||
def __init__(self, tool: BuiltinTool = None):
|
||||
def __init__(self, tool: BuiltinTool | None = None):
|
||||
self._tool = tool
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invokes the AIPPT generate tool with the given user ID and tool parameters.
|
||||
|
||||
@@ -68,8 +70,8 @@ class AIPPTGenerateToolAdapter:
|
||||
)
|
||||
|
||||
# get suit
|
||||
color: str = tool_parameters.get("color")
|
||||
style: str = tool_parameters.get("style")
|
||||
color: str = tool_parameters.get("color", "")
|
||||
style: str = tool_parameters.get("style", "")
|
||||
|
||||
if color == "__default__":
|
||||
color_id = ""
|
||||
@@ -226,7 +228,7 @@ class AIPPTGenerateToolAdapter:
|
||||
|
||||
return ""
|
||||
|
||||
def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]:
|
||||
def _generate_ppt(self, task_id: str, suit_id: int, user_id: str) -> tuple[str, str]:
|
||||
"""
|
||||
Generate a ppt
|
||||
|
||||
@@ -362,7 +364,9 @@ class AIPPTGenerateToolAdapter:
|
||||
).decode("utf-8")
|
||||
|
||||
@classmethod
|
||||
def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]:
|
||||
def _get_styles(
|
||||
cls, credentials: dict[str, str], user_id: str
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
"""
|
||||
Get styles
|
||||
"""
|
||||
@@ -415,7 +419,7 @@ class AIPPTGenerateToolAdapter:
|
||||
|
||||
return colors, styles
|
||||
|
||||
def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
|
||||
def get_styles(self, user_id: str) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
"""
|
||||
Get styles
|
||||
|
||||
@@ -507,7 +511,9 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
return AIPPTGenerateToolAdapter(self)._invoke(user_id, tool_parameters)
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
import arxiv
|
||||
import arxiv # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
@@ -11,19 +11,21 @@ from services.model_provider_service import ModelProviderService
|
||||
|
||||
class TTSTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
|
||||
provider, model = tool_parameters.get("model").split("#")
|
||||
voice = tool_parameters.get(f"voice#{provider}#{model}")
|
||||
provider, model = tool_parameters.get("model", "").split("#")
|
||||
voice = tool_parameters.get(f"voice#{provider}#{model}", "")
|
||||
model_manager = ModelManager()
|
||||
if not self.runtime:
|
||||
raise ValueError("Runtime is required")
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
provider=provider,
|
||||
model_type=ModelType.TTS,
|
||||
model=model,
|
||||
)
|
||||
tts = model_instance.invoke_tts(
|
||||
content_text=tool_parameters.get("text"),
|
||||
content_text=tool_parameters.get("text", ""),
|
||||
user=user_id,
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
voice=voice,
|
||||
)
|
||||
buffer = io.BytesIO()
|
||||
@@ -41,8 +43,11 @@ class TTSTool(BuiltinTool):
|
||||
]
|
||||
|
||||
def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
|
||||
if not self.runtime:
|
||||
raise ValueError("Runtime is required")
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts")
|
||||
tid: str = self.runtime.tenant_id or ""
|
||||
models = model_provider_service.get_models_by_model_type(tenant_id=tid, model_type="tts")
|
||||
items = []
|
||||
for provider_model in models:
|
||||
provider = provider_model.provider
|
||||
@@ -62,6 +67,8 @@ class TTSTool(BuiltinTool):
|
||||
ToolParameter(
|
||||
name=f"voice#{provider}#{model}",
|
||||
label=I18nObject(en_US=f"Voice of {model}({provider})"),
|
||||
human_description=I18nObject(en_US=f"Select a voice for {model} model"),
|
||||
placeholder=I18nObject(en_US="Select a voice"),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
options=[
|
||||
@@ -83,6 +90,7 @@ class TTSTool(BuiltinTool):
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
required=True,
|
||||
placeholder=I18nObject(en_US="Select a model", zh_Hans="选择模型"),
|
||||
options=options,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -2,8 +2,8 @@ import json
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import BotoCoreError
|
||||
import boto3 # type: ignore
|
||||
from botocore.exceptions import BotoCoreError # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
import boto3 # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
import boto3 # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import operator
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
import boto3 # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
@@ -10,8 +10,8 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
class SageMakerReRankTool(BuiltinTool):
|
||||
sagemaker_client: Any = None
|
||||
sagemaker_endpoint: str = None
|
||||
topk: int = None
|
||||
sagemaker_endpoint: str | None = None
|
||||
topk: int | None = None
|
||||
|
||||
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
|
||||
inputs = [query_input] * len(docs)
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import boto3
|
||||
import boto3 # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
@@ -17,7 +17,7 @@ class TTSModelType(Enum):
|
||||
|
||||
class SageMakerTTSTool(BuiltinTool):
|
||||
sagemaker_client: Any = None
|
||||
sagemaker_endpoint: str = None
|
||||
sagemaker_endpoint: str | None = None
|
||||
s3_client: Any = None
|
||||
comprehend_client: Any = None
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Union
|
||||
|
||||
from zhipuai import ZhipuAI
|
||||
from zhipuai import ZhipuAI # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Union
|
||||
|
||||
import httpx
|
||||
from zhipuai import ZhipuAI
|
||||
from zhipuai import ZhipuAI # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import random
|
||||
from typing import Any, Union
|
||||
|
||||
from zhipuai import ZhipuAI
|
||||
from zhipuai import ZhipuAI # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -7,18 +7,22 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class SearchRecordsTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
if not self.runtime or not self.runtime.credentials:
|
||||
raise ValueError("Runtime is not set")
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ValueError("app_id and app_secret are required")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
app_token = tool_parameters.get("app_token")
|
||||
table_id = tool_parameters.get("table_id")
|
||||
table_name = tool_parameters.get("table_name")
|
||||
view_id = tool_parameters.get("view_id")
|
||||
field_names = tool_parameters.get("field_names")
|
||||
sort = tool_parameters.get("sort")
|
||||
filters = tool_parameters.get("filter")
|
||||
page_token = tool_parameters.get("page_token")
|
||||
app_token = tool_parameters.get("app_token", "")
|
||||
table_id = tool_parameters.get("table_id", "")
|
||||
table_name = tool_parameters.get("table_name", "")
|
||||
view_id = tool_parameters.get("view_id", "")
|
||||
field_names = tool_parameters.get("field_names", "")
|
||||
sort = tool_parameters.get("sort", "")
|
||||
filters = tool_parameters.get("filter", "")
|
||||
page_token = tool_parameters.get("page_token", "")
|
||||
automatic_fields = tool_parameters.get("automatic_fields", False)
|
||||
user_id_type = tool_parameters.get("user_id_type", "open_id")
|
||||
page_size = tool_parameters.get("page_size", 20)
|
||||
|
||||
@@ -7,14 +7,18 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class UpdateRecordsTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
if not self.runtime or not self.runtime.credentials:
|
||||
raise ValueError("Runtime is not set")
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ValueError("app_id and app_secret are required")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
app_token = tool_parameters.get("app_token")
|
||||
table_id = tool_parameters.get("table_id")
|
||||
table_name = tool_parameters.get("table_name")
|
||||
records = tool_parameters.get("records")
|
||||
app_token = tool_parameters.get("app_token", "")
|
||||
table_id = tool_parameters.get("table_id", "")
|
||||
table_name = tool_parameters.get("table_name", "")
|
||||
records = tool_parameters.get("records", "")
|
||||
user_id_type = tool_parameters.get("user_id_type", "open_id")
|
||||
|
||||
res = client.update_records(app_token, table_id, table_name, records, user_id_type)
|
||||
|
||||
@@ -7,12 +7,16 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class AddEventAttendeesTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
if not self.runtime or not self.runtime.credentials:
|
||||
raise ValueError("Runtime is not set")
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ValueError("app_id and app_secret are required")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
event_id = tool_parameters.get("event_id")
|
||||
attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email")
|
||||
event_id = tool_parameters.get("event_id", "")
|
||||
attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email", "")
|
||||
need_notification = tool_parameters.get("need_notification", True)
|
||||
|
||||
res = client.add_event_attendees(event_id, attendee_phone_or_email, need_notification)
|
||||
|
||||
@@ -7,11 +7,15 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class DeleteEventTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
if not self.runtime or not self.runtime.credentials:
|
||||
raise ValueError("Runtime is not set")
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ValueError("app_id and app_secret are required")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
event_id = tool_parameters.get("event_id")
|
||||
event_id = tool_parameters.get("event_id", "")
|
||||
need_notification = tool_parameters.get("need_notification", True)
|
||||
|
||||
res = client.delete_event(event_id, need_notification)
|
||||
|
||||
@@ -7,8 +7,12 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class GetPrimaryCalendarTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
if not self.runtime or not self.runtime.credentials:
|
||||
raise ValueError("Runtime is not set")
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ValueError("app_id and app_secret are required")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
user_id_type = tool_parameters.get("user_id_type", "open_id")
|
||||
|
||||
@@ -7,14 +7,18 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class ListEventsTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
if not self.runtime or not self.runtime.credentials:
|
||||
raise ValueError("Runtime is not set")
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ValueError("app_id and app_secret are required")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
start_time = tool_parameters.get("start_time")
|
||||
end_time = tool_parameters.get("end_time")
|
||||
page_token = tool_parameters.get("page_token")
|
||||
page_size = tool_parameters.get("page_size")
|
||||
start_time = tool_parameters.get("start_time", "")
|
||||
end_time = tool_parameters.get("end_time", "")
|
||||
page_token = tool_parameters.get("page_token", "")
|
||||
page_size = tool_parameters.get("page_size", 50)
|
||||
|
||||
res = client.list_events(start_time, end_time, page_token, page_size)
|
||||
|
||||
|
||||
@@ -7,16 +7,20 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class UpdateEventTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
if not self.runtime or not self.runtime.credentials:
|
||||
raise ValueError("Runtime is not set")
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ValueError("app_id and app_secret are required")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
event_id = tool_parameters.get("event_id")
|
||||
summary = tool_parameters.get("summary")
|
||||
description = tool_parameters.get("description")
|
||||
event_id = tool_parameters.get("event_id", "")
|
||||
summary = tool_parameters.get("summary", "")
|
||||
description = tool_parameters.get("description", "")
|
||||
need_notification = tool_parameters.get("need_notification", True)
|
||||
start_time = tool_parameters.get("start_time")
|
||||
end_time = tool_parameters.get("end_time")
|
||||
start_time = tool_parameters.get("start_time", "")
|
||||
end_time = tool_parameters.get("end_time", "")
|
||||
auto_record = tool_parameters.get("auto_record", False)
|
||||
|
||||
res = client.update_event(event_id, summary, description, need_notification, start_time, end_time, auto_record)
|
||||
|
||||
@@ -7,13 +7,17 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class CreateDocumentTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
if not self.runtime or not self.runtime.credentials:
|
||||
raise ValueError("Runtime is not set")
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ValueError("app_id and app_secret are required")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
title = tool_parameters.get("title")
|
||||
content = tool_parameters.get("content")
|
||||
folder_token = tool_parameters.get("folder_token")
|
||||
title = tool_parameters.get("title", "")
|
||||
content = tool_parameters.get("content", "")
|
||||
folder_token = tool_parameters.get("folder_token", "")
|
||||
|
||||
res = client.create_document(title, content, folder_token)
|
||||
return self.create_json_message(res)
|
||||
|
||||
@@ -7,11 +7,15 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class ListDocumentBlockTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
if not self.runtime or not self.runtime.credentials:
|
||||
raise ValueError("Runtime is not set")
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ValueError("app_id and app_secret are required")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
document_id = tool_parameters.get("document_id")
|
||||
document_id = tool_parameters.get("document_id", "")
|
||||
page_token = tool_parameters.get("page_token", "")
|
||||
user_id_type = tool_parameters.get("user_id_type", "open_id")
|
||||
page_size = tool_parameters.get("page_size", 500)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
from jsonpath_ng import parse
|
||||
from jsonpath_ng import parse # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
from jsonpath_ng import parse
|
||||
from jsonpath_ng import parse # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
from jsonpath_ng import parse
|
||||
from jsonpath_ng import parse # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
from jsonpath_ng import parse
|
||||
from jsonpath_ng import parse # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
import numexpr as ne
|
||||
import numexpr as ne # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from novita_client import (
|
||||
from novita_client import ( # type: ignore
|
||||
Txt2ImgV3Embedding,
|
||||
Txt2ImgV3HiresFix,
|
||||
Txt2ImgV3LoRA,
|
||||
|
||||
@@ -2,7 +2,7 @@ from base64 import b64decode
|
||||
from copy import deepcopy
|
||||
from typing import Any, Union
|
||||
|
||||
from novita_client import (
|
||||
from novita_client import ( # type: ignore
|
||||
NovitaClient,
|
||||
)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from base64 import b64decode
|
||||
from copy import deepcopy
|
||||
from typing import Any, Union
|
||||
|
||||
from novita_client import (
|
||||
from novita_client import ( # type: ignore
|
||||
NovitaClient,
|
||||
)
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
from pydub import AudioSegment
|
||||
from pydub import AudioSegment # type: ignore
|
||||
|
||||
|
||||
class PodcastAudioGeneratorTool(BuiltinTool):
|
||||
|
||||
@@ -2,10 +2,10 @@ import io
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q
|
||||
from qrcode.image.base import BaseImage
|
||||
from qrcode.image.pure import PyPNGImage
|
||||
from qrcode.main import QRCode
|
||||
from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q # type: ignore
|
||||
from qrcode.image.base import BaseImage # type: ignore
|
||||
from qrcode.image.pure import PyPNGImage # type: ignore
|
||||
from qrcode.main import QRCode # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Union
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from youtube_transcript_api import YouTubeTranscriptApi
|
||||
from youtube_transcript_api import YouTubeTranscriptApi # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -37,7 +37,7 @@ class TwilioAPIWrapper(BaseModel):
|
||||
def set_validator(cls, values: dict) -> dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
try:
|
||||
from twilio.rest import Client
|
||||
from twilio.rest import Client # type: ignore
|
||||
except ImportError:
|
||||
raise ImportError("Could not import twilio python package. Please install it with `pip install twilio`.")
|
||||
account_sid = values.get("account_sid")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any
|
||||
|
||||
from twilio.base.exceptions import TwilioRestException
|
||||
from twilio.rest import Client
|
||||
from twilio.base.exceptions import TwilioRestException # type: ignore
|
||||
from twilio.rest import Client # type: ignore
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Union
|
||||
|
||||
from vanna.remote import VannaDefault
|
||||
from vanna.remote import VannaDefault # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
@@ -14,6 +14,9 @@ class VannaTool(BuiltinTool):
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
# Ensure runtime and credentials
|
||||
if not self.runtime or not self.runtime.credentials:
|
||||
raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing")
|
||||
api_key = self.runtime.credentials.get("api_key", None)
|
||||
if not api_key:
|
||||
raise ToolProviderCredentialValidationError("Please input api key")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import wikipedia
|
||||
import wikipedia # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Union
|
||||
|
||||
import pandas as pd
|
||||
from requests.exceptions import HTTPError, ReadTimeout
|
||||
from yfinance import download
|
||||
from yfinance import download # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Union
|
||||
|
||||
import yfinance
|
||||
import yfinance # type: ignore
|
||||
from requests.exceptions import HTTPError, ReadTimeout
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Union
|
||||
|
||||
from requests.exceptions import HTTPError, ReadTimeout
|
||||
from yfinance import Ticker
|
||||
from yfinance import Ticker # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Union
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.discovery import build # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from abc import abstractmethod
|
||||
from os import listdir, path
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
|
||||
@@ -50,6 +50,8 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
if self.tools:
|
||||
return self.tools
|
||||
if not self.identity:
|
||||
return []
|
||||
|
||||
provider = self.identity.name
|
||||
tool_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, "tools")
|
||||
@@ -86,7 +88,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
|
||||
return self.credentials_schema.copy()
|
||||
|
||||
def get_tools(self) -> list[Tool]:
|
||||
def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
|
||||
"""
|
||||
returns a list of tools that the provider can provide
|
||||
|
||||
@@ -94,11 +96,14 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
return self._get_builtin_tools()
|
||||
|
||||
def get_tool(self, tool_name: str) -> Tool:
|
||||
def get_tool(self, tool_name: str) -> Optional[Tool]:
|
||||
"""
|
||||
returns the tool that the provider can provide
|
||||
"""
|
||||
return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
|
||||
tools = self.get_tools()
|
||||
if tools is None:
|
||||
raise ValueError("tools not found")
|
||||
return next((t for t in tools if t.identity and t.identity.name == tool_name), None)
|
||||
|
||||
def get_parameters(self, tool_name: str) -> list[ToolParameter]:
|
||||
"""
|
||||
@@ -107,10 +112,13 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:return: list of parameters
|
||||
"""
|
||||
tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
|
||||
tools = self.get_tools()
|
||||
if tools is None:
|
||||
raise ToolNotFoundError(f"tool {tool_name} not found")
|
||||
tool = next((t for t in tools if t.identity and t.identity.name == tool_name), None)
|
||||
if tool is None:
|
||||
raise ToolNotFoundError(f"tool {tool_name} not found")
|
||||
return tool.parameters
|
||||
return tool.parameters or []
|
||||
|
||||
@property
|
||||
def need_credentials(self) -> bool:
|
||||
@@ -144,6 +152,8 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
returns the labels of the provider
|
||||
"""
|
||||
if self.identity is None:
|
||||
return []
|
||||
return self.identity.tags or []
|
||||
|
||||
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
|
||||
@@ -159,56 +169,56 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
for parameter in tool_parameters_schema:
|
||||
tool_parameters_need_to_validate[parameter.name] = parameter
|
||||
|
||||
for parameter in tool_parameters:
|
||||
if parameter not in tool_parameters_need_to_validate:
|
||||
raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}")
|
||||
for parameter_name in tool_parameters:
|
||||
if parameter_name not in tool_parameters_need_to_validate:
|
||||
raise ToolParameterValidationError(f"parameter {parameter_name} not found in tool {tool_name}")
|
||||
|
||||
# check type
|
||||
parameter_schema = tool_parameters_need_to_validate[parameter]
|
||||
parameter_schema = tool_parameters_need_to_validate[parameter_name]
|
||||
if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
|
||||
if not isinstance(tool_parameters[parameter], str):
|
||||
raise ToolParameterValidationError(f"parameter {parameter} should be string")
|
||||
if not isinstance(tool_parameters[parameter_name], str):
|
||||
raise ToolParameterValidationError(f"parameter {parameter_name} should be string")
|
||||
|
||||
elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
|
||||
if not isinstance(tool_parameters[parameter], int | float):
|
||||
raise ToolParameterValidationError(f"parameter {parameter} should be number")
|
||||
if not isinstance(tool_parameters[parameter_name], int | float):
|
||||
raise ToolParameterValidationError(f"parameter {parameter_name} should be number")
|
||||
|
||||
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
|
||||
if parameter_schema.min is not None and tool_parameters[parameter_name] < parameter_schema.min:
|
||||
raise ToolParameterValidationError(
|
||||
f"parameter {parameter} should be greater than {parameter_schema.min}"
|
||||
f"parameter {parameter_name} should be greater than {parameter_schema.min}"
|
||||
)
|
||||
|
||||
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
|
||||
if parameter_schema.max is not None and tool_parameters[parameter_name] > parameter_schema.max:
|
||||
raise ToolParameterValidationError(
|
||||
f"parameter {parameter} should be less than {parameter_schema.max}"
|
||||
f"parameter {parameter_name} should be less than {parameter_schema.max}"
|
||||
)
|
||||
|
||||
elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
|
||||
if not isinstance(tool_parameters[parameter], bool):
|
||||
raise ToolParameterValidationError(f"parameter {parameter} should be boolean")
|
||||
if not isinstance(tool_parameters[parameter_name], bool):
|
||||
raise ToolParameterValidationError(f"parameter {parameter_name} should be boolean")
|
||||
|
||||
elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
|
||||
if not isinstance(tool_parameters[parameter], str):
|
||||
raise ToolParameterValidationError(f"parameter {parameter} should be string")
|
||||
if not isinstance(tool_parameters[parameter_name], str):
|
||||
raise ToolParameterValidationError(f"parameter {parameter_name} should be string")
|
||||
|
||||
options = parameter_schema.options
|
||||
if not isinstance(options, list):
|
||||
raise ToolParameterValidationError(f"parameter {parameter} options should be list")
|
||||
raise ToolParameterValidationError(f"parameter {parameter_name} options should be list")
|
||||
|
||||
if tool_parameters[parameter] not in [x.value for x in options]:
|
||||
raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}")
|
||||
if tool_parameters[parameter_name] not in [x.value for x in options]:
|
||||
raise ToolParameterValidationError(f"parameter {parameter_name} should be one of {options}")
|
||||
|
||||
tool_parameters_need_to_validate.pop(parameter)
|
||||
tool_parameters_need_to_validate.pop(parameter_name)
|
||||
|
||||
for parameter in tool_parameters_need_to_validate:
|
||||
parameter_schema = tool_parameters_need_to_validate[parameter]
|
||||
for parameter_name in tool_parameters_need_to_validate:
|
||||
parameter_schema = tool_parameters_need_to_validate[parameter_name]
|
||||
if parameter_schema.required:
|
||||
raise ToolParameterValidationError(f"parameter {parameter} is required")
|
||||
raise ToolParameterValidationError(f"parameter {parameter_name} is required")
|
||||
|
||||
# the parameter is not set currently, set the default value if needed
|
||||
if parameter_schema.default is not None:
|
||||
default_value = parameter_schema.type.cast_value(parameter_schema.default)
|
||||
tool_parameters[parameter] = default_value
|
||||
tool_parameters[parameter_name] = default_value
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
|
||||
@@ -24,10 +24,12 @@ class ToolProviderController(BaseModel, ABC):
|
||||
|
||||
:return: the credentials schema
|
||||
"""
|
||||
if self.credentials_schema is None:
|
||||
return {}
|
||||
return self.credentials_schema.copy()
|
||||
|
||||
@abstractmethod
|
||||
def get_tools(self) -> list[Tool]:
|
||||
def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
|
||||
"""
|
||||
returns a list of tools that the provider can provide
|
||||
|
||||
@@ -36,7 +38,7 @@ class ToolProviderController(BaseModel, ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_tool(self, tool_name: str) -> Tool:
|
||||
def get_tool(self, tool_name: str) -> Optional[Tool]:
|
||||
"""
|
||||
returns a tool that the provider can provide
|
||||
|
||||
@@ -51,10 +53,13 @@ class ToolProviderController(BaseModel, ABC):
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:return: list of parameters
|
||||
"""
|
||||
tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
|
||||
tools = self.get_tools()
|
||||
if tools is None:
|
||||
raise ToolNotFoundError(f"tool {tool_name} not found")
|
||||
tool = next((t for t in tools if t.identity and t.identity.name == tool_name), None)
|
||||
if tool is None:
|
||||
raise ToolNotFoundError(f"tool {tool_name} not found")
|
||||
return tool.parameters
|
||||
return tool.parameters or []
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
@@ -78,55 +83,55 @@ class ToolProviderController(BaseModel, ABC):
|
||||
for parameter in tool_parameters_schema:
|
||||
tool_parameters_need_to_validate[parameter.name] = parameter
|
||||
|
||||
for parameter in tool_parameters:
|
||||
if parameter not in tool_parameters_need_to_validate:
|
||||
raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}")
|
||||
for tool_parameter in tool_parameters:
|
||||
if tool_parameter not in tool_parameters_need_to_validate:
|
||||
raise ToolParameterValidationError(f"parameter {tool_parameter} not found in tool {tool_name}")
|
||||
|
||||
# check type
|
||||
parameter_schema = tool_parameters_need_to_validate[parameter]
|
||||
parameter_schema = tool_parameters_need_to_validate[tool_parameter]
|
||||
if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
|
||||
if not isinstance(tool_parameters[parameter], str):
|
||||
raise ToolParameterValidationError(f"parameter {parameter} should be string")
|
||||
if not isinstance(tool_parameters[tool_parameter], str):
|
||||
raise ToolParameterValidationError(f"parameter {tool_parameter} should be string")
|
||||
|
||||
elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
|
||||
if not isinstance(tool_parameters[parameter], int | float):
|
||||
raise ToolParameterValidationError(f"parameter {parameter} should be number")
|
||||
if not isinstance(tool_parameters[tool_parameter], int | float):
|
||||
raise ToolParameterValidationError(f"parameter {tool_parameter} should be number")
|
||||
|
||||
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
|
||||
if parameter_schema.min is not None and tool_parameters[tool_parameter] < parameter_schema.min:
|
||||
raise ToolParameterValidationError(
|
||||
f"parameter {parameter} should be greater than {parameter_schema.min}"
|
||||
f"parameter {tool_parameter} should be greater than {parameter_schema.min}"
|
||||
)
|
||||
|
||||
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
|
||||
if parameter_schema.max is not None and tool_parameters[tool_parameter] > parameter_schema.max:
|
||||
raise ToolParameterValidationError(
|
||||
f"parameter {parameter} should be less than {parameter_schema.max}"
|
||||
f"parameter {tool_parameter} should be less than {parameter_schema.max}"
|
||||
)
|
||||
|
||||
elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
|
||||
if not isinstance(tool_parameters[parameter], bool):
|
||||
raise ToolParameterValidationError(f"parameter {parameter} should be boolean")
|
||||
if not isinstance(tool_parameters[tool_parameter], bool):
|
||||
raise ToolParameterValidationError(f"parameter {tool_parameter} should be boolean")
|
||||
|
||||
elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
|
||||
if not isinstance(tool_parameters[parameter], str):
|
||||
raise ToolParameterValidationError(f"parameter {parameter} should be string")
|
||||
if not isinstance(tool_parameters[tool_parameter], str):
|
||||
raise ToolParameterValidationError(f"parameter {tool_parameter} should be string")
|
||||
|
||||
options = parameter_schema.options
|
||||
if not isinstance(options, list):
|
||||
raise ToolParameterValidationError(f"parameter {parameter} options should be list")
|
||||
raise ToolParameterValidationError(f"parameter {tool_parameter} options should be list")
|
||||
|
||||
if tool_parameters[parameter] not in [x.value for x in options]:
|
||||
raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}")
|
||||
if tool_parameters[tool_parameter] not in [x.value for x in options]:
|
||||
raise ToolParameterValidationError(f"parameter {tool_parameter} should be one of {options}")
|
||||
|
||||
tool_parameters_need_to_validate.pop(parameter)
|
||||
tool_parameters_need_to_validate.pop(tool_parameter)
|
||||
|
||||
for parameter in tool_parameters_need_to_validate:
|
||||
parameter_schema = tool_parameters_need_to_validate[parameter]
|
||||
for tool_parameter_validate in tool_parameters_need_to_validate:
|
||||
parameter_schema = tool_parameters_need_to_validate[tool_parameter_validate]
|
||||
if parameter_schema.required:
|
||||
raise ToolParameterValidationError(f"parameter {parameter} is required")
|
||||
raise ToolParameterValidationError(f"parameter {tool_parameter_validate} is required")
|
||||
|
||||
# the parameter is not set currently, set the default value if needed
|
||||
if parameter_schema.default is not None:
|
||||
tool_parameters[parameter] = parameter_schema.type.cast_value(parameter_schema.default)
|
||||
tool_parameters[tool_parameter_validate] = parameter_schema.type.cast_value(parameter_schema.default)
|
||||
|
||||
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
@@ -144,6 +149,8 @@ class ToolProviderController(BaseModel, ABC):
|
||||
|
||||
for credential_name in credentials:
|
||||
if credential_name not in credentials_need_to_validate:
|
||||
if self.identity is None:
|
||||
raise ValueError("identity is not set")
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"credential {credential_name} not found in provider {self.identity.name}"
|
||||
)
|
||||
|
||||
@@ -11,6 +11,7 @@ from core.tools.entities.tool_entities import (
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool.workflow_tool import WorkflowTool
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from extensions.ext_database import db
|
||||
@@ -116,6 +117,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
llm_description=parameter.description,
|
||||
required=variable.required,
|
||||
options=options,
|
||||
placeholder=I18nObject(en_US="", zh_Hans=""),
|
||||
)
|
||||
)
|
||||
elif features.file_upload:
|
||||
@@ -128,6 +130,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
llm_description=parameter.description,
|
||||
required=False,
|
||||
form=parameter.form,
|
||||
placeholder=I18nObject(en_US="", zh_Hans=""),
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -157,7 +160,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
label=db_provider.label,
|
||||
)
|
||||
|
||||
def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]:
|
||||
def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
|
||||
"""
|
||||
fetch tools from database
|
||||
|
||||
@@ -168,7 +171,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
if self.tools is not None:
|
||||
return self.tools
|
||||
|
||||
db_providers: WorkflowToolProvider = (
|
||||
db_providers: Optional[WorkflowToolProvider] = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
@@ -179,12 +182,14 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
|
||||
if not db_providers:
|
||||
return []
|
||||
if not db_providers.app:
|
||||
raise ValueError("app not found")
|
||||
|
||||
self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]
|
||||
|
||||
return self.tools
|
||||
|
||||
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
|
||||
def get_tool(self, tool_name: str) -> Optional[Tool]:
|
||||
"""
|
||||
get tool by name
|
||||
|
||||
@@ -195,6 +200,8 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
return None
|
||||
|
||||
for tool in self.tools:
|
||||
if tool.identity is None:
|
||||
continue
|
||||
if tool.identity.name == tool_name:
|
||||
return tool
|
||||
|
||||
|
||||
Reference in New Issue
Block a user