feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -1,9 +1,14 @@
from typing import Optional
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolCredentialsOption,
ToolDescription,
ToolIdentity,
ToolProviderCredentials,
ToolProviderIdentity,
ToolProviderType,
)
from core.tools.provider.tool_provider import ToolProviderController
@@ -64,21 +69,18 @@ class ApiToolProviderController(ToolProviderController):
pass
else:
raise ValueError(f"invalid auth type {auth_type}")
user_name = db_provider.user.name if db_provider.user_id else ""
user_name = db_provider.user.name if db_provider.user_id and db_provider.user is not None else ""
return ApiToolProviderController(
**{
"identity": {
"author": user_name,
"name": db_provider.name,
"label": {"en_US": db_provider.name, "zh_Hans": db_provider.name},
"description": {"en_US": db_provider.description, "zh_Hans": db_provider.description},
"icon": db_provider.icon,
},
"credentials_schema": credentials_schema,
"provider_id": db_provider.id or "",
}
identity=ToolProviderIdentity(
author=user_name,
name=db_provider.name,
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
icon=db_provider.icon,
),
credentials_schema=credentials_schema,
provider_id=db_provider.id or "",
tools=None,
)
@property
@@ -93,24 +95,22 @@ class ApiToolProviderController(ToolProviderController):
:return: the tool
"""
return ApiTool(
**{
"api_bundle": tool_bundle,
"identity": {
"author": tool_bundle.author,
"name": tool_bundle.operation_id,
"label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id},
"icon": self.identity.icon,
"provider": self.provider_id,
},
"description": {
"human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""},
"llm": tool_bundle.summary or "",
},
"parameters": tool_bundle.parameters or [],
}
api_bundle=tool_bundle,
identity=ToolIdentity(
author=tool_bundle.author,
name=tool_bundle.operation_id or "",
label=I18nObject(en_US=tool_bundle.operation_id, zh_Hans=tool_bundle.operation_id),
icon=self.identity.icon if self.identity else None,
provider=self.provider_id,
),
description=ToolDescription(
human=I18nObject(en_US=tool_bundle.summary or "", zh_Hans=tool_bundle.summary or ""),
llm=tool_bundle.summary or "",
),
parameters=tool_bundle.parameters or [],
)
def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]:
def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[Tool]:
"""
load bundled tools
@@ -121,7 +121,7 @@ class ApiToolProviderController(ToolProviderController):
return self.tools
def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]:
def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
"""
fetch tools from database
@@ -131,6 +131,8 @@ class ApiToolProviderController(ToolProviderController):
"""
if self.tools is not None:
return self.tools
if self.identity is None:
return None
tools: list[Tool] = []
@@ -151,7 +153,7 @@ class ApiToolProviderController(ToolProviderController):
self.tools = tools
return tools
def get_tool(self, tool_name: str) -> ApiTool:
def get_tool(self, tool_name: str) -> Tool:
"""
get tool by name
@@ -161,7 +163,9 @@ class ApiToolProviderController(ToolProviderController):
if self.tools is None:
self.get_tools()
for tool in self.tools:
for tool in self.tools or []:
if tool.identity is None:
continue
if tool.identity.name == tool_name:
return tool

View File

@@ -1,9 +1,10 @@
import logging
from typing import Any
from typing import Any, Optional
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ToolProviderType
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.api_tool import ApiTool
from core.tools.tool.tool import Tool
from extensions.ext_database import db
from models.model import App, AppModelConfig
@@ -20,10 +21,10 @@ class AppToolProviderEntity(ToolProviderController):
def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
pass
def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None:
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
pass
def get_tools(self, user_id: str) -> list[Tool]:
def get_tools(self, user_id: str = "", tenant_id: str = "") -> list[Tool]:
db_tools: list[PublishedAppTool] = (
db.session.query(PublishedAppTool)
.filter(
@@ -38,7 +39,7 @@ class AppToolProviderEntity(ToolProviderController):
tools: list[Tool] = []
for db_tool in db_tools:
tool = {
tool: dict[str, Any] = {
"identity": {
"author": db_tool.author,
"name": db_tool.tool_name,
@@ -52,7 +53,7 @@ class AppToolProviderEntity(ToolProviderController):
"parameters": [],
}
# get app from db
app: App = db_tool.app
app: Optional[App] = db_tool.app
if not app:
logger.error(f"app {db_tool.app_id} not found")
@@ -79,6 +80,7 @@ class AppToolProviderEntity(ToolProviderController):
type=ToolParameter.ToolParameterType.STRING,
required=required,
default=default,
placeholder=I18nObject(en_US="", zh_Hans=""),
)
)
elif form_type == "select":
@@ -92,6 +94,7 @@ class AppToolProviderEntity(ToolProviderController):
type=ToolParameter.ToolParameterType.SELECT,
required=required,
default=default,
placeholder=I18nObject(en_US="", zh_Hans=""),
options=[
ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
for option in options
@@ -99,5 +102,5 @@ class AppToolProviderEntity(ToolProviderController):
)
)
tools.append(Tool(**tool))
tools.append(ApiTool(**tool))
return tools

View File

@@ -5,7 +5,7 @@ from core.tools.entities.api_entities import UserToolProvider
class BuiltinToolProviderSort:
_position = {}
_position: dict[str, int] = {}
@classmethod
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:

View File

@@ -4,7 +4,7 @@ from hmac import new as hmac_new
from json import loads as json_loads
from threading import Lock
from time import sleep, time
from typing import Any
from typing import Any, Union
from httpx import get, post
from requests import get as requests_get
@@ -21,23 +21,25 @@ class AIPPTGenerateToolAdapter:
"""
_api_base_url = URL("https://co.aippt.cn/api")
_api_token_cache = {}
_style_cache = {}
_api_token_cache: dict[str, dict[str, Union[str, float]]] = {}
_style_cache: dict[str, dict[str, Union[list[dict[str, Any]], float]]] = {}
_api_token_cache_lock = Lock()
_style_cache_lock = Lock()
_api_token_cache_lock: Lock = Lock()
_style_cache_lock: Lock = Lock()
_task = {}
_task: dict[str, Any] = {}
_task_type_map = {
"auto": 1,
"markdown": 7,
}
_tool: BuiltinTool
_tool: BuiltinTool | None
def __init__(self, tool: BuiltinTool = None):
def __init__(self, tool: BuiltinTool | None = None):
self._tool = tool
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
Invokes the AIPPT generate tool with the given user ID and tool parameters.
@@ -68,8 +70,8 @@ class AIPPTGenerateToolAdapter:
)
# get suit
color: str = tool_parameters.get("color")
style: str = tool_parameters.get("style")
color: str = tool_parameters.get("color", "")
style: str = tool_parameters.get("style", "")
if color == "__default__":
color_id = ""
@@ -226,7 +228,7 @@ class AIPPTGenerateToolAdapter:
return ""
def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]:
def _generate_ppt(self, task_id: str, suit_id: int, user_id: str) -> tuple[str, str]:
"""
Generate a ppt
@@ -362,7 +364,9 @@ class AIPPTGenerateToolAdapter:
).decode("utf-8")
@classmethod
def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]:
def _get_styles(
cls, credentials: dict[str, str], user_id: str
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
"""
Get styles
"""
@@ -415,7 +419,7 @@ class AIPPTGenerateToolAdapter:
return colors, styles
def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
def get_styles(self, user_id: str) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
"""
Get styles
@@ -507,7 +511,9 @@ class AIPPTGenerateTool(BuiltinTool):
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
return AIPPTGenerateToolAdapter(self)._invoke(user_id, tool_parameters)
def get_runtime_parameters(self) -> list[ToolParameter]:

View File

@@ -1,7 +1,7 @@
import logging
from typing import Any, Optional
import arxiv
import arxiv # type: ignore
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage

View File

@@ -11,19 +11,21 @@ from services.model_provider_service import ModelProviderService
class TTSTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
provider, model = tool_parameters.get("model").split("#")
voice = tool_parameters.get(f"voice#{provider}#{model}")
provider, model = tool_parameters.get("model", "").split("#")
voice = tool_parameters.get(f"voice#{provider}#{model}", "")
model_manager = ModelManager()
if not self.runtime:
raise ValueError("Runtime is required")
model_instance = model_manager.get_model_instance(
tenant_id=self.runtime.tenant_id,
tenant_id=self.runtime.tenant_id or "",
provider=provider,
model_type=ModelType.TTS,
model=model,
)
tts = model_instance.invoke_tts(
content_text=tool_parameters.get("text"),
content_text=tool_parameters.get("text", ""),
user=user_id,
tenant_id=self.runtime.tenant_id,
tenant_id=self.runtime.tenant_id or "",
voice=voice,
)
buffer = io.BytesIO()
@@ -41,8 +43,11 @@ class TTSTool(BuiltinTool):
]
def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
if not self.runtime:
raise ValueError("Runtime is required")
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts")
tid: str = self.runtime.tenant_id or ""
models = model_provider_service.get_models_by_model_type(tenant_id=tid, model_type="tts")
items = []
for provider_model in models:
provider = provider_model.provider
@@ -62,6 +67,8 @@ class TTSTool(BuiltinTool):
ToolParameter(
name=f"voice#{provider}#{model}",
label=I18nObject(en_US=f"Voice of {model}({provider})"),
human_description=I18nObject(en_US=f"Select a voice for {model} model"),
placeholder=I18nObject(en_US="Select a voice"),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
options=[
@@ -83,6 +90,7 @@ class TTSTool(BuiltinTool):
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
required=True,
placeholder=I18nObject(en_US="Select a model", zh_Hans="选择模型"),
options=options,
),
)

View File

@@ -2,8 +2,8 @@ import json
import logging
from typing import Any, Union
import boto3
from botocore.exceptions import BotoCoreError
import boto3 # type: ignore
from botocore.exceptions import BotoCoreError # type: ignore
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage

View File

@@ -1,7 +1,7 @@
import json
from typing import Any, Union
import boto3
import boto3 # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -2,7 +2,7 @@ import json
import logging
from typing import Any, Union
import boto3
import boto3 # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -2,7 +2,7 @@ import json
import operator
from typing import Any, Union
import boto3
import boto3 # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
@@ -10,8 +10,8 @@ from core.tools.tool.builtin_tool import BuiltinTool
class SageMakerReRankTool(BuiltinTool):
sagemaker_client: Any = None
sagemaker_endpoint: str = None
topk: int = None
sagemaker_endpoint: str | None = None
topk: int | None = None
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
inputs = [query_input] * len(docs)

View File

@@ -2,7 +2,7 @@ import json
from enum import Enum
from typing import Any, Optional, Union
import boto3
import boto3 # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
@@ -17,7 +17,7 @@ class TTSModelType(Enum):
class SageMakerTTSTool(BuiltinTool):
sagemaker_client: Any = None
sagemaker_endpoint: str = None
sagemaker_endpoint: str | None = None
s3_client: Any = None
comprehend_client: Any = None

View File

@@ -1,6 +1,6 @@
from typing import Any, Union
from zhipuai import ZhipuAI
from zhipuai import ZhipuAI # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
from typing import Any, Union
import httpx
from zhipuai import ZhipuAI
from zhipuai import ZhipuAI # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
import random
from typing import Any, Union
from zhipuai import ZhipuAI
from zhipuai import ZhipuAI # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -7,18 +7,22 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class SearchRecordsTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
app_token = tool_parameters.get("app_token")
table_id = tool_parameters.get("table_id")
table_name = tool_parameters.get("table_name")
view_id = tool_parameters.get("view_id")
field_names = tool_parameters.get("field_names")
sort = tool_parameters.get("sort")
filters = tool_parameters.get("filter")
page_token = tool_parameters.get("page_token")
app_token = tool_parameters.get("app_token", "")
table_id = tool_parameters.get("table_id", "")
table_name = tool_parameters.get("table_name", "")
view_id = tool_parameters.get("view_id", "")
field_names = tool_parameters.get("field_names", "")
sort = tool_parameters.get("sort", "")
filters = tool_parameters.get("filter", "")
page_token = tool_parameters.get("page_token", "")
automatic_fields = tool_parameters.get("automatic_fields", False)
user_id_type = tool_parameters.get("user_id_type", "open_id")
page_size = tool_parameters.get("page_size", 20)

View File

@@ -7,14 +7,18 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class UpdateRecordsTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
app_token = tool_parameters.get("app_token")
table_id = tool_parameters.get("table_id")
table_name = tool_parameters.get("table_name")
records = tool_parameters.get("records")
app_token = tool_parameters.get("app_token", "")
table_id = tool_parameters.get("table_id", "")
table_name = tool_parameters.get("table_name", "")
records = tool_parameters.get("records", "")
user_id_type = tool_parameters.get("user_id_type", "open_id")
res = client.update_records(app_token, table_id, table_name, records, user_id_type)

View File

@@ -7,12 +7,16 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class AddEventAttendeesTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
event_id = tool_parameters.get("event_id")
attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email")
event_id = tool_parameters.get("event_id", "")
attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email", "")
need_notification = tool_parameters.get("need_notification", True)
res = client.add_event_attendees(event_id, attendee_phone_or_email, need_notification)

View File

@@ -7,11 +7,15 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class DeleteEventTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
event_id = tool_parameters.get("event_id")
event_id = tool_parameters.get("event_id", "")
need_notification = tool_parameters.get("need_notification", True)
res = client.delete_event(event_id, need_notification)

View File

@@ -7,8 +7,12 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class GetPrimaryCalendarTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
user_id_type = tool_parameters.get("user_id_type", "open_id")

View File

@@ -7,14 +7,18 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class ListEventsTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
start_time = tool_parameters.get("start_time")
end_time = tool_parameters.get("end_time")
page_token = tool_parameters.get("page_token")
page_size = tool_parameters.get("page_size")
start_time = tool_parameters.get("start_time", "")
end_time = tool_parameters.get("end_time", "")
page_token = tool_parameters.get("page_token", "")
page_size = tool_parameters.get("page_size", 50)
res = client.list_events(start_time, end_time, page_token, page_size)

View File

@@ -7,16 +7,20 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class UpdateEventTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
event_id = tool_parameters.get("event_id")
summary = tool_parameters.get("summary")
description = tool_parameters.get("description")
event_id = tool_parameters.get("event_id", "")
summary = tool_parameters.get("summary", "")
description = tool_parameters.get("description", "")
need_notification = tool_parameters.get("need_notification", True)
start_time = tool_parameters.get("start_time")
end_time = tool_parameters.get("end_time")
start_time = tool_parameters.get("start_time", "")
end_time = tool_parameters.get("end_time", "")
auto_record = tool_parameters.get("auto_record", False)
res = client.update_event(event_id, summary, description, need_notification, start_time, end_time, auto_record)

View File

@@ -7,13 +7,17 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class CreateDocumentTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
title = tool_parameters.get("title")
content = tool_parameters.get("content")
folder_token = tool_parameters.get("folder_token")
title = tool_parameters.get("title", "")
content = tool_parameters.get("content", "")
folder_token = tool_parameters.get("folder_token", "")
res = client.create_document(title, content, folder_token)
return self.create_json_message(res)

View File

@@ -7,11 +7,15 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class ListDocumentBlockTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
document_id = tool_parameters.get("document_id")
document_id = tool_parameters.get("document_id", "")
page_token = tool_parameters.get("page_token", "")
user_id_type = tool_parameters.get("user_id_type", "open_id")
page_size = tool_parameters.get("page_size", 500)

View File

@@ -1,7 +1,7 @@
import json
from typing import Any, Union
from jsonpath_ng import parse
from jsonpath_ng import parse # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
import json
from typing import Any, Union
from jsonpath_ng import parse
from jsonpath_ng import parse # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
import json
from typing import Any, Union
from jsonpath_ng import parse
from jsonpath_ng import parse # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
import json
from typing import Any, Union
from jsonpath_ng import parse
from jsonpath_ng import parse # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
import logging
from typing import Any, Union
import numexpr as ne
import numexpr as ne # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,4 +1,4 @@
from novita_client import (
from novita_client import ( # type: ignore
Txt2ImgV3Embedding,
Txt2ImgV3HiresFix,
Txt2ImgV3LoRA,

View File

@@ -2,7 +2,7 @@ from base64 import b64decode
from copy import deepcopy
from typing import Any, Union
from novita_client import (
from novita_client import ( # type: ignore
NovitaClient,
)

View File

@@ -2,7 +2,7 @@ from base64 import b64decode
from copy import deepcopy
from typing import Any, Union
from novita_client import (
from novita_client import ( # type: ignore
NovitaClient,
)

View File

@@ -13,7 +13,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from pydub import AudioSegment
from pydub import AudioSegment # type: ignore
class PodcastAudioGeneratorTool(BuiltinTool):

View File

@@ -2,10 +2,10 @@ import io
import logging
from typing import Any, Union
from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q
from qrcode.image.base import BaseImage
from qrcode.image.pure import PyPNGImage
from qrcode.main import QRCode
from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q # type: ignore
from qrcode.image.base import BaseImage # type: ignore
from qrcode.image.pure import PyPNGImage # type: ignore
from qrcode.main import QRCode # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
from typing import Any, Union
from urllib.parse import parse_qs, urlparse
from youtube_transcript_api import YouTubeTranscriptApi
from youtube_transcript_api import YouTubeTranscriptApi # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -37,7 +37,7 @@ class TwilioAPIWrapper(BaseModel):
def set_validator(cls, values: dict) -> dict:
"""Validate that api key and python package exists in environment."""
try:
from twilio.rest import Client
from twilio.rest import Client # type: ignore
except ImportError:
raise ImportError("Could not import twilio python package. Please install it with `pip install twilio`.")
account_sid = values.get("account_sid")

View File

@@ -1,7 +1,7 @@
from typing import Any
from twilio.base.exceptions import TwilioRestException
from twilio.rest import Client
from twilio.base.exceptions import TwilioRestException # type: ignore
from twilio.rest import Client # type: ignore
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController

View File

@@ -1,6 +1,6 @@
from typing import Any, Union
from vanna.remote import VannaDefault
from vanna.remote import VannaDefault # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolProviderCredentialValidationError
@@ -14,6 +14,9 @@ class VannaTool(BuiltinTool):
"""
invoke tools
"""
# Ensure runtime and credentials
if not self.runtime or not self.runtime.credentials:
raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing")
api_key = self.runtime.credentials.get("api_key", None)
if not api_key:
raise ToolProviderCredentialValidationError("Please input api key")

View File

@@ -1,6 +1,6 @@
from typing import Any, Optional, Union
import wikipedia
import wikipedia # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -3,7 +3,7 @@ from typing import Any, Union
import pandas as pd
from requests.exceptions import HTTPError, ReadTimeout
from yfinance import download
from yfinance import download # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,6 +1,6 @@
from typing import Any, Union
import yfinance
import yfinance # type: ignore
from requests.exceptions import HTTPError, ReadTimeout
from core.tools.entities.tool_entities import ToolInvokeMessage

View File

@@ -1,7 +1,7 @@
from typing import Any, Union
from requests.exceptions import HTTPError, ReadTimeout
from yfinance import Ticker
from yfinance import Ticker # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
from datetime import datetime
from typing import Any, Union
from googleapiclient.discovery import build
from googleapiclient.discovery import build # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,6 +1,6 @@
from abc import abstractmethod
from os import listdir, path
from typing import Any
from typing import Any, Optional
from core.helper.module_import_helper import load_single_subclass_from_source
from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
@@ -50,6 +50,8 @@ class BuiltinToolProviderController(ToolProviderController):
"""
if self.tools:
return self.tools
if not self.identity:
return []
provider = self.identity.name
tool_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, "tools")
@@ -86,7 +88,7 @@ class BuiltinToolProviderController(ToolProviderController):
return self.credentials_schema.copy()
def get_tools(self) -> list[Tool]:
def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
"""
returns a list of tools that the provider can provide
@@ -94,11 +96,14 @@ class BuiltinToolProviderController(ToolProviderController):
"""
return self._get_builtin_tools()
def get_tool(self, tool_name: str) -> Tool:
def get_tool(self, tool_name: str) -> Optional[Tool]:
"""
returns the tool that the provider can provide
"""
return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
tools = self.get_tools()
if tools is None:
raise ValueError("tools not found")
return next((t for t in tools if t.identity and t.identity.name == tool_name), None)
def get_parameters(self, tool_name: str) -> list[ToolParameter]:
"""
@@ -107,10 +112,13 @@ class BuiltinToolProviderController(ToolProviderController):
:param tool_name: the name of the tool, defined in `get_tools`
:return: list of parameters
"""
tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
tools = self.get_tools()
if tools is None:
raise ToolNotFoundError(f"tool {tool_name} not found")
tool = next((t for t in tools if t.identity and t.identity.name == tool_name), None)
if tool is None:
raise ToolNotFoundError(f"tool {tool_name} not found")
return tool.parameters
return tool.parameters or []
@property
def need_credentials(self) -> bool:
@@ -144,6 +152,8 @@ class BuiltinToolProviderController(ToolProviderController):
"""
returns the labels of the provider
"""
if self.identity is None:
return []
return self.identity.tags or []
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
@@ -159,56 +169,56 @@ class BuiltinToolProviderController(ToolProviderController):
for parameter in tool_parameters_schema:
tool_parameters_need_to_validate[parameter.name] = parameter
for parameter in tool_parameters:
if parameter not in tool_parameters_need_to_validate:
raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}")
for parameter_name in tool_parameters:
if parameter_name not in tool_parameters_need_to_validate:
raise ToolParameterValidationError(f"parameter {parameter_name} not found in tool {tool_name}")
# check type
parameter_schema = tool_parameters_need_to_validate[parameter]
parameter_schema = tool_parameters_need_to_validate[parameter_name]
if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
if not isinstance(tool_parameters[parameter], str):
raise ToolParameterValidationError(f"parameter {parameter} should be string")
if not isinstance(tool_parameters[parameter_name], str):
raise ToolParameterValidationError(f"parameter {parameter_name} should be string")
elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
if not isinstance(tool_parameters[parameter], int | float):
raise ToolParameterValidationError(f"parameter {parameter} should be number")
if not isinstance(tool_parameters[parameter_name], int | float):
raise ToolParameterValidationError(f"parameter {parameter_name} should be number")
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
if parameter_schema.min is not None and tool_parameters[parameter_name] < parameter_schema.min:
raise ToolParameterValidationError(
f"parameter {parameter} should be greater than {parameter_schema.min}"
f"parameter {parameter_name} should be greater than {parameter_schema.min}"
)
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
if parameter_schema.max is not None and tool_parameters[parameter_name] > parameter_schema.max:
raise ToolParameterValidationError(
f"parameter {parameter} should be less than {parameter_schema.max}"
f"parameter {parameter_name} should be less than {parameter_schema.max}"
)
elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
if not isinstance(tool_parameters[parameter], bool):
raise ToolParameterValidationError(f"parameter {parameter} should be boolean")
if not isinstance(tool_parameters[parameter_name], bool):
raise ToolParameterValidationError(f"parameter {parameter_name} should be boolean")
elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
if not isinstance(tool_parameters[parameter], str):
raise ToolParameterValidationError(f"parameter {parameter} should be string")
if not isinstance(tool_parameters[parameter_name], str):
raise ToolParameterValidationError(f"parameter {parameter_name} should be string")
options = parameter_schema.options
if not isinstance(options, list):
raise ToolParameterValidationError(f"parameter {parameter} options should be list")
raise ToolParameterValidationError(f"parameter {parameter_name} options should be list")
if tool_parameters[parameter] not in [x.value for x in options]:
raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}")
if tool_parameters[parameter_name] not in [x.value for x in options]:
raise ToolParameterValidationError(f"parameter {parameter_name} should be one of {options}")
tool_parameters_need_to_validate.pop(parameter)
tool_parameters_need_to_validate.pop(parameter_name)
for parameter in tool_parameters_need_to_validate:
parameter_schema = tool_parameters_need_to_validate[parameter]
for parameter_name in tool_parameters_need_to_validate:
parameter_schema = tool_parameters_need_to_validate[parameter_name]
if parameter_schema.required:
raise ToolParameterValidationError(f"parameter {parameter} is required")
raise ToolParameterValidationError(f"parameter {parameter_name} is required")
# the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None:
default_value = parameter_schema.type.cast_value(parameter_schema.default)
tool_parameters[parameter] = default_value
tool_parameters[parameter_name] = default_value
def validate_credentials(self, credentials: dict[str, Any]) -> None:
"""

View File

@@ -24,10 +24,12 @@ class ToolProviderController(BaseModel, ABC):
:return: the credentials schema
"""
if self.credentials_schema is None:
return {}
return self.credentials_schema.copy()
@abstractmethod
def get_tools(self) -> list[Tool]:
def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
"""
returns a list of tools that the provider can provide
@@ -36,7 +38,7 @@ class ToolProviderController(BaseModel, ABC):
pass
@abstractmethod
def get_tool(self, tool_name: str) -> Tool:
def get_tool(self, tool_name: str) -> Optional[Tool]:
"""
returns a tool that the provider can provide
@@ -51,10 +53,13 @@ class ToolProviderController(BaseModel, ABC):
:param tool_name: the name of the tool, defined in `get_tools`
:return: list of parameters
"""
tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
tools = self.get_tools()
if tools is None:
raise ToolNotFoundError(f"tool {tool_name} not found")
tool = next((t for t in tools if t.identity and t.identity.name == tool_name), None)
if tool is None:
raise ToolNotFoundError(f"tool {tool_name} not found")
return tool.parameters
return tool.parameters or []
@property
def provider_type(self) -> ToolProviderType:
@@ -78,55 +83,55 @@ class ToolProviderController(BaseModel, ABC):
for parameter in tool_parameters_schema:
tool_parameters_need_to_validate[parameter.name] = parameter
for parameter in tool_parameters:
if parameter not in tool_parameters_need_to_validate:
raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}")
for tool_parameter in tool_parameters:
if tool_parameter not in tool_parameters_need_to_validate:
raise ToolParameterValidationError(f"parameter {tool_parameter} not found in tool {tool_name}")
# check type
parameter_schema = tool_parameters_need_to_validate[parameter]
parameter_schema = tool_parameters_need_to_validate[tool_parameter]
if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
if not isinstance(tool_parameters[parameter], str):
raise ToolParameterValidationError(f"parameter {parameter} should be string")
if not isinstance(tool_parameters[tool_parameter], str):
raise ToolParameterValidationError(f"parameter {tool_parameter} should be string")
elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
if not isinstance(tool_parameters[parameter], int | float):
raise ToolParameterValidationError(f"parameter {parameter} should be number")
if not isinstance(tool_parameters[tool_parameter], int | float):
raise ToolParameterValidationError(f"parameter {tool_parameter} should be number")
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
if parameter_schema.min is not None and tool_parameters[tool_parameter] < parameter_schema.min:
raise ToolParameterValidationError(
f"parameter {parameter} should be greater than {parameter_schema.min}"
f"parameter {tool_parameter} should be greater than {parameter_schema.min}"
)
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
if parameter_schema.max is not None and tool_parameters[tool_parameter] > parameter_schema.max:
raise ToolParameterValidationError(
f"parameter {parameter} should be less than {parameter_schema.max}"
f"parameter {tool_parameter} should be less than {parameter_schema.max}"
)
elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
if not isinstance(tool_parameters[parameter], bool):
raise ToolParameterValidationError(f"parameter {parameter} should be boolean")
if not isinstance(tool_parameters[tool_parameter], bool):
raise ToolParameterValidationError(f"parameter {tool_parameter} should be boolean")
elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
if not isinstance(tool_parameters[parameter], str):
raise ToolParameterValidationError(f"parameter {parameter} should be string")
if not isinstance(tool_parameters[tool_parameter], str):
raise ToolParameterValidationError(f"parameter {tool_parameter} should be string")
options = parameter_schema.options
if not isinstance(options, list):
raise ToolParameterValidationError(f"parameter {parameter} options should be list")
raise ToolParameterValidationError(f"parameter {tool_parameter} options should be list")
if tool_parameters[parameter] not in [x.value for x in options]:
raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}")
if tool_parameters[tool_parameter] not in [x.value for x in options]:
raise ToolParameterValidationError(f"parameter {tool_parameter} should be one of {options}")
tool_parameters_need_to_validate.pop(parameter)
tool_parameters_need_to_validate.pop(tool_parameter)
for parameter in tool_parameters_need_to_validate:
parameter_schema = tool_parameters_need_to_validate[parameter]
for tool_parameter_validate in tool_parameters_need_to_validate:
parameter_schema = tool_parameters_need_to_validate[tool_parameter_validate]
if parameter_schema.required:
raise ToolParameterValidationError(f"parameter {parameter} is required")
raise ToolParameterValidationError(f"parameter {tool_parameter_validate} is required")
# the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None:
tool_parameters[parameter] = parameter_schema.type.cast_value(parameter_schema.default)
tool_parameters[tool_parameter_validate] = parameter_schema.type.cast_value(parameter_schema.default)
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
"""
@@ -144,6 +149,8 @@ class ToolProviderController(BaseModel, ABC):
for credential_name in credentials:
if credential_name not in credentials_need_to_validate:
if self.identity is None:
raise ValueError("identity is not set")
raise ToolProviderCredentialValidationError(
f"credential {credential_name} not found in provider {self.identity.name}"
)

View File

@@ -11,6 +11,7 @@ from core.tools.entities.tool_entities import (
ToolProviderType,
)
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.tool import Tool
from core.tools.tool.workflow_tool import WorkflowTool
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from extensions.ext_database import db
@@ -116,6 +117,7 @@ class WorkflowToolProviderController(ToolProviderController):
llm_description=parameter.description,
required=variable.required,
options=options,
placeholder=I18nObject(en_US="", zh_Hans=""),
)
)
elif features.file_upload:
@@ -128,6 +130,7 @@ class WorkflowToolProviderController(ToolProviderController):
llm_description=parameter.description,
required=False,
form=parameter.form,
placeholder=I18nObject(en_US="", zh_Hans=""),
)
)
else:
@@ -157,7 +160,7 @@ class WorkflowToolProviderController(ToolProviderController):
label=db_provider.label,
)
def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]:
def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
"""
fetch tools from database
@@ -168,7 +171,7 @@ class WorkflowToolProviderController(ToolProviderController):
if self.tools is not None:
return self.tools
db_providers: WorkflowToolProvider = (
db_providers: Optional[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider)
.filter(
WorkflowToolProvider.tenant_id == tenant_id,
@@ -179,12 +182,14 @@ class WorkflowToolProviderController(ToolProviderController):
if not db_providers:
return []
if not db_providers.app:
raise ValueError("app not found")
self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]
return self.tools
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
def get_tool(self, tool_name: str) -> Optional[Tool]:
"""
get tool by name
@@ -195,6 +200,8 @@ class WorkflowToolProviderController(ToolProviderController):
return None
for tool in self.tools:
if tool.identity is None:
continue
if tool.identity.name == tool_name:
return tool