feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -1,3 +1,5 @@
from typing import cast
import requests
from configs import dify_config
@@ -5,7 +7,7 @@ from models.api_based_extension import APIBasedExtensionPoint
class APIBasedExtensionRequestor:
timeout: (int, int) = (5, 60)
timeout: tuple[int, int] = (5, 60)
"""timeout for request connect and read"""
def __init__(self, api_endpoint: str, api_key: str) -> None:
@@ -51,4 +53,4 @@ class APIBasedExtensionRequestor:
"request error, status_code: {}, content: {}".format(response.status_code, response.text[:100])
)
return response.json()
return cast(dict, response.json())

View File

@@ -38,8 +38,8 @@ class Extensible:
@classmethod
def scan_extensions(cls):
extensions: list[ModuleExtension] = []
position_map = {}
extensions = []
position_map: dict[str, int] = {}
# get the path of the current class
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
@@ -58,7 +58,8 @@ class Extensible:
# is builtin extension, builtin extension
# in the front-end page and business logic, there are special treatments.
builtin = False
position = None
# default position is 0 can not be None for sort_to_dict_by_position_map
position = 0
if "__builtin__" in file_names:
builtin = True
@@ -89,7 +90,7 @@ class Extensible:
logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
continue
json_data = {}
json_data: dict[str, Any] = {}
if not builtin:
if "schema.json" not in file_names:
logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")

View File

@@ -1,4 +1,6 @@
from core.extension.extensible import ExtensionModule, ModuleExtension
from typing import cast
from core.extension.extensible import Extensible, ExtensionModule, ModuleExtension
from core.external_data_tool.base import ExternalDataTool
from core.moderation.base import Moderation
@@ -10,7 +12,8 @@ class Extension:
def init(self):
for module, module_class in self.module_classes.items():
self.__module_extensions[module.value] = module_class.scan_extensions()
m = cast(Extensible, module_class)
self.__module_extensions[module.value] = m.scan_extensions()
def module_extensions(self, module: str) -> list[ModuleExtension]:
module_extensions = self.__module_extensions.get(module)
@@ -35,7 +38,8 @@ class Extension:
def extension_class(self, module: ExtensionModule, extension_name: str) -> type:
module_extension = self.module_extension(module, extension_name)
return module_extension.extension_class
t: type = module_extension.extension_class
return t
def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None:
module_extension = self.module_extension(module, extension_name)