feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -5,7 +5,7 @@ from core.tools.entities.api_entities import UserToolProvider
class BuiltinToolProviderSort:
_position = {}
_position: dict[str, int] = {}
@classmethod
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:

View File

@@ -4,7 +4,7 @@ from hmac import new as hmac_new
from json import loads as json_loads
from threading import Lock
from time import sleep, time
from typing import Any
from typing import Any, Union
from httpx import get, post
from requests import get as requests_get
@@ -21,23 +21,25 @@ class AIPPTGenerateToolAdapter:
"""
_api_base_url = URL("https://co.aippt.cn/api")
_api_token_cache = {}
_style_cache = {}
_api_token_cache: dict[str, dict[str, Union[str, float]]] = {}
_style_cache: dict[str, dict[str, Union[list[dict[str, Any]], float]]] = {}
_api_token_cache_lock = Lock()
_style_cache_lock = Lock()
_api_token_cache_lock: Lock = Lock()
_style_cache_lock: Lock = Lock()
_task = {}
_task: dict[str, Any] = {}
_task_type_map = {
"auto": 1,
"markdown": 7,
}
_tool: BuiltinTool
_tool: BuiltinTool | None
def __init__(self, tool: BuiltinTool = None):
def __init__(self, tool: BuiltinTool | None = None):
self._tool = tool
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
Invokes the AIPPT generate tool with the given user ID and tool parameters.
@@ -68,8 +70,8 @@ class AIPPTGenerateToolAdapter:
)
# get suit
color: str = tool_parameters.get("color")
style: str = tool_parameters.get("style")
color: str = tool_parameters.get("color", "")
style: str = tool_parameters.get("style", "")
if color == "__default__":
color_id = ""
@@ -226,7 +228,7 @@ class AIPPTGenerateToolAdapter:
return ""
def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]:
def _generate_ppt(self, task_id: str, suit_id: int, user_id: str) -> tuple[str, str]:
"""
Generate a ppt
@@ -362,7 +364,9 @@ class AIPPTGenerateToolAdapter:
).decode("utf-8")
@classmethod
def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]:
def _get_styles(
cls, credentials: dict[str, str], user_id: str
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
"""
Get styles
"""
@@ -415,7 +419,7 @@ class AIPPTGenerateToolAdapter:
return colors, styles
def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
def get_styles(self, user_id: str) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
"""
Get styles
@@ -507,7 +511,9 @@ class AIPPTGenerateTool(BuiltinTool):
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
return AIPPTGenerateToolAdapter(self)._invoke(user_id, tool_parameters)
def get_runtime_parameters(self) -> list[ToolParameter]:

View File

@@ -1,7 +1,7 @@
import logging
from typing import Any, Optional
import arxiv
import arxiv # type: ignore
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage

View File

@@ -11,19 +11,21 @@ from services.model_provider_service import ModelProviderService
class TTSTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
provider, model = tool_parameters.get("model").split("#")
voice = tool_parameters.get(f"voice#{provider}#{model}")
provider, model = tool_parameters.get("model", "").split("#")
voice = tool_parameters.get(f"voice#{provider}#{model}", "")
model_manager = ModelManager()
if not self.runtime:
raise ValueError("Runtime is required")
model_instance = model_manager.get_model_instance(
tenant_id=self.runtime.tenant_id,
tenant_id=self.runtime.tenant_id or "",
provider=provider,
model_type=ModelType.TTS,
model=model,
)
tts = model_instance.invoke_tts(
content_text=tool_parameters.get("text"),
content_text=tool_parameters.get("text", ""),
user=user_id,
tenant_id=self.runtime.tenant_id,
tenant_id=self.runtime.tenant_id or "",
voice=voice,
)
buffer = io.BytesIO()
@@ -41,8 +43,11 @@ class TTSTool(BuiltinTool):
]
def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
if not self.runtime:
raise ValueError("Runtime is required")
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts")
tid: str = self.runtime.tenant_id or ""
models = model_provider_service.get_models_by_model_type(tenant_id=tid, model_type="tts")
items = []
for provider_model in models:
provider = provider_model.provider
@@ -62,6 +67,8 @@ class TTSTool(BuiltinTool):
ToolParameter(
name=f"voice#{provider}#{model}",
label=I18nObject(en_US=f"Voice of {model}({provider})"),
human_description=I18nObject(en_US=f"Select a voice for {model} model"),
placeholder=I18nObject(en_US="Select a voice"),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
options=[
@@ -83,6 +90,7 @@ class TTSTool(BuiltinTool):
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
required=True,
placeholder=I18nObject(en_US="Select a model", zh_Hans="选择模型"),
options=options,
),
)

View File

@@ -2,8 +2,8 @@ import json
import logging
from typing import Any, Union
import boto3
from botocore.exceptions import BotoCoreError
import boto3 # type: ignore
from botocore.exceptions import BotoCoreError # type: ignore
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage

View File

@@ -1,7 +1,7 @@
import json
from typing import Any, Union
import boto3
import boto3 # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -2,7 +2,7 @@ import json
import logging
from typing import Any, Union
import boto3
import boto3 # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -2,7 +2,7 @@ import json
import operator
from typing import Any, Union
import boto3
import boto3 # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
@@ -10,8 +10,8 @@ from core.tools.tool.builtin_tool import BuiltinTool
class SageMakerReRankTool(BuiltinTool):
sagemaker_client: Any = None
sagemaker_endpoint: str = None
topk: int = None
sagemaker_endpoint: str | None = None
topk: int | None = None
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
inputs = [query_input] * len(docs)

View File

@@ -2,7 +2,7 @@ import json
from enum import Enum
from typing import Any, Optional, Union
import boto3
import boto3 # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
@@ -17,7 +17,7 @@ class TTSModelType(Enum):
class SageMakerTTSTool(BuiltinTool):
sagemaker_client: Any = None
sagemaker_endpoint: str = None
sagemaker_endpoint: str | None = None
s3_client: Any = None
comprehend_client: Any = None

View File

@@ -1,6 +1,6 @@
from typing import Any, Union
from zhipuai import ZhipuAI
from zhipuai import ZhipuAI # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
from typing import Any, Union
import httpx
from zhipuai import ZhipuAI
from zhipuai import ZhipuAI # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
import random
from typing import Any, Union
from zhipuai import ZhipuAI
from zhipuai import ZhipuAI # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -7,18 +7,22 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class SearchRecordsTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
app_token = tool_parameters.get("app_token")
table_id = tool_parameters.get("table_id")
table_name = tool_parameters.get("table_name")
view_id = tool_parameters.get("view_id")
field_names = tool_parameters.get("field_names")
sort = tool_parameters.get("sort")
filters = tool_parameters.get("filter")
page_token = tool_parameters.get("page_token")
app_token = tool_parameters.get("app_token", "")
table_id = tool_parameters.get("table_id", "")
table_name = tool_parameters.get("table_name", "")
view_id = tool_parameters.get("view_id", "")
field_names = tool_parameters.get("field_names", "")
sort = tool_parameters.get("sort", "")
filters = tool_parameters.get("filter", "")
page_token = tool_parameters.get("page_token", "")
automatic_fields = tool_parameters.get("automatic_fields", False)
user_id_type = tool_parameters.get("user_id_type", "open_id")
page_size = tool_parameters.get("page_size", 20)

View File

@@ -7,14 +7,18 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class UpdateRecordsTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
app_token = tool_parameters.get("app_token")
table_id = tool_parameters.get("table_id")
table_name = tool_parameters.get("table_name")
records = tool_parameters.get("records")
app_token = tool_parameters.get("app_token", "")
table_id = tool_parameters.get("table_id", "")
table_name = tool_parameters.get("table_name", "")
records = tool_parameters.get("records", "")
user_id_type = tool_parameters.get("user_id_type", "open_id")
res = client.update_records(app_token, table_id, table_name, records, user_id_type)

View File

@@ -7,12 +7,16 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class AddEventAttendeesTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
event_id = tool_parameters.get("event_id")
attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email")
event_id = tool_parameters.get("event_id", "")
attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email", "")
need_notification = tool_parameters.get("need_notification", True)
res = client.add_event_attendees(event_id, attendee_phone_or_email, need_notification)

View File

@@ -7,11 +7,15 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class DeleteEventTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
event_id = tool_parameters.get("event_id")
event_id = tool_parameters.get("event_id", "")
need_notification = tool_parameters.get("need_notification", True)
res = client.delete_event(event_id, need_notification)

View File

@@ -7,8 +7,12 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class GetPrimaryCalendarTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
user_id_type = tool_parameters.get("user_id_type", "open_id")

View File

@@ -7,14 +7,18 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class ListEventsTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
start_time = tool_parameters.get("start_time")
end_time = tool_parameters.get("end_time")
page_token = tool_parameters.get("page_token")
page_size = tool_parameters.get("page_size")
start_time = tool_parameters.get("start_time", "")
end_time = tool_parameters.get("end_time", "")
page_token = tool_parameters.get("page_token", "")
page_size = tool_parameters.get("page_size", 50)
res = client.list_events(start_time, end_time, page_token, page_size)

View File

@@ -7,16 +7,20 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class UpdateEventTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
event_id = tool_parameters.get("event_id")
summary = tool_parameters.get("summary")
description = tool_parameters.get("description")
event_id = tool_parameters.get("event_id", "")
summary = tool_parameters.get("summary", "")
description = tool_parameters.get("description", "")
need_notification = tool_parameters.get("need_notification", True)
start_time = tool_parameters.get("start_time")
end_time = tool_parameters.get("end_time")
start_time = tool_parameters.get("start_time", "")
end_time = tool_parameters.get("end_time", "")
auto_record = tool_parameters.get("auto_record", False)
res = client.update_event(event_id, summary, description, need_notification, start_time, end_time, auto_record)

View File

@@ -7,13 +7,17 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class CreateDocumentTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
title = tool_parameters.get("title")
content = tool_parameters.get("content")
folder_token = tool_parameters.get("folder_token")
title = tool_parameters.get("title", "")
content = tool_parameters.get("content", "")
folder_token = tool_parameters.get("folder_token", "")
res = client.create_document(title, content, folder_token)
return self.create_json_message(res)

View File

@@ -7,11 +7,15 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class ListDocumentBlockTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
document_id = tool_parameters.get("document_id")
document_id = tool_parameters.get("document_id", "")
page_token = tool_parameters.get("page_token", "")
user_id_type = tool_parameters.get("user_id_type", "open_id")
page_size = tool_parameters.get("page_size", 500)

View File

@@ -1,7 +1,7 @@
import json
from typing import Any, Union
from jsonpath_ng import parse
from jsonpath_ng import parse # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
import json
from typing import Any, Union
from jsonpath_ng import parse
from jsonpath_ng import parse # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
import json
from typing import Any, Union
from jsonpath_ng import parse
from jsonpath_ng import parse # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
import json
from typing import Any, Union
from jsonpath_ng import parse
from jsonpath_ng import parse # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
import logging
from typing import Any, Union
import numexpr as ne
import numexpr as ne # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,4 +1,4 @@
from novita_client import (
from novita_client import ( # type: ignore
Txt2ImgV3Embedding,
Txt2ImgV3HiresFix,
Txt2ImgV3LoRA,

View File

@@ -2,7 +2,7 @@ from base64 import b64decode
from copy import deepcopy
from typing import Any, Union
from novita_client import (
from novita_client import ( # type: ignore
NovitaClient,
)

View File

@@ -2,7 +2,7 @@ from base64 import b64decode
from copy import deepcopy
from typing import Any, Union
from novita_client import (
from novita_client import ( # type: ignore
NovitaClient,
)

View File

@@ -13,7 +13,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from pydub import AudioSegment
from pydub import AudioSegment # type: ignore
class PodcastAudioGeneratorTool(BuiltinTool):

View File

@@ -2,10 +2,10 @@ import io
import logging
from typing import Any, Union
from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q
from qrcode.image.base import BaseImage
from qrcode.image.pure import PyPNGImage
from qrcode.main import QRCode
from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q # type: ignore
from qrcode.image.base import BaseImage # type: ignore
from qrcode.image.pure import PyPNGImage # type: ignore
from qrcode.main import QRCode # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
from typing import Any, Union
from urllib.parse import parse_qs, urlparse
from youtube_transcript_api import YouTubeTranscriptApi
from youtube_transcript_api import YouTubeTranscriptApi # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -37,7 +37,7 @@ class TwilioAPIWrapper(BaseModel):
def set_validator(cls, values: dict) -> dict:
"""Validate that api key and python package exists in environment."""
try:
from twilio.rest import Client
from twilio.rest import Client # type: ignore
except ImportError:
raise ImportError("Could not import twilio python package. Please install it with `pip install twilio`.")
account_sid = values.get("account_sid")

View File

@@ -1,7 +1,7 @@
from typing import Any
from twilio.base.exceptions import TwilioRestException
from twilio.rest import Client
from twilio.base.exceptions import TwilioRestException # type: ignore
from twilio.rest import Client # type: ignore
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController

View File

@@ -1,6 +1,6 @@
from typing import Any, Union
from vanna.remote import VannaDefault
from vanna.remote import VannaDefault # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolProviderCredentialValidationError
@@ -14,6 +14,9 @@ class VannaTool(BuiltinTool):
"""
invoke tools
"""
# Ensure runtime and credentials
if not self.runtime or not self.runtime.credentials:
raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing")
api_key = self.runtime.credentials.get("api_key", None)
if not api_key:
raise ToolProviderCredentialValidationError("Please input api key")

View File

@@ -1,6 +1,6 @@
from typing import Any, Optional, Union
import wikipedia
import wikipedia # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -3,7 +3,7 @@ from typing import Any, Union
import pandas as pd
from requests.exceptions import HTTPError, ReadTimeout
from yfinance import download
from yfinance import download # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,6 +1,6 @@
from typing import Any, Union
import yfinance
import yfinance # type: ignore
from requests.exceptions import HTTPError, ReadTimeout
from core.tools.entities.tool_entities import ToolInvokeMessage

View File

@@ -1,7 +1,7 @@
from typing import Any, Union
from requests.exceptions import HTTPError, ReadTimeout
from yfinance import Ticker
from yfinance import Ticker # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
from datetime import datetime
from typing import Any, Union
from googleapiclient.discovery import build
from googleapiclient.discovery import build # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool