feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -1,8 +1,9 @@
import re
import sys
from typing import Any
from flask import current_app, got_request_exception
from flask_restful import Api, http_status_message
from flask_restful import Api, http_status_message # type: ignore
from werkzeug.datastructures import Headers
from werkzeug.exceptions import HTTPException
@@ -84,7 +85,7 @@ class ExternalApi(Api):
# record the exception in the logs when we have a server error of status code: 500
if status_code and status_code >= 500:
exc_info = sys.exc_info()
exc_info: Any = sys.exc_info()
if exc_info[1] is None:
exc_info = None
current_app.log_exception(exc_info)
@@ -100,7 +101,7 @@ class ExternalApi(Api):
resp = self.make_response(data, status_code, headers, fallback_mediatype=fallback_mediatype)
elif status_code == 400:
if isinstance(data.get("message"), dict):
param_key, param_value = list(data.get("message").items())[0]
param_key, param_value = list(data.get("message", {}).items())[0]
data = {"code": "invalid_param", "message": param_value, "params": param_key}
else:
if "code" not in data:

View File

@@ -23,7 +23,7 @@ from hashlib import sha1
import Crypto.Hash.SHA1
import Crypto.Util.number
import gmpy2
import gmpy2 # type: ignore
from Crypto import Random
from Crypto.Signature.pss import MGF1
from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes
@@ -191,12 +191,12 @@ class PKCS1OAepCipher:
# Step 3g
one_pos = hLen + db[hLen:].find(b"\x01")
lHash1 = db[:hLen]
invalid = bord(y) | int(one_pos < hLen)
invalid = bord(y) | int(one_pos < hLen) # type: ignore
hash_compare = strxor(lHash1, lHash)
for x in hash_compare:
invalid |= bord(x)
invalid |= bord(x) # type: ignore
for x in db[hLen:one_pos]:
invalid |= bord(x)
invalid |= bord(x) # type: ignore
if invalid != 0:
raise ValueError("Incorrect decryption.")
# Step 4

View File

@@ -13,7 +13,7 @@ from typing import Any, Optional, Union, cast
from zoneinfo import available_timezones
from flask import Response, stream_with_context
from flask_restful import fields
from flask_restful import fields # type: ignore
from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
@@ -248,13 +248,13 @@ class TokenManager:
if token_data_json is None:
logging.warning(f"{token_type} token {token} not found with key {key}")
return None
token_data = json.loads(token_data_json)
token_data: Optional[dict[str, Any]] = json.loads(token_data_json)
return token_data
@classmethod
def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]:
key = cls._get_account_token_key(account_id, token_type)
current_token = redis_client.get(key)
current_token: Optional[str] = redis_client.get(key)
return current_token
@classmethod

View File

@@ -10,6 +10,7 @@ def parse_json_markdown(json_string: str) -> dict:
ends = ["```", "``", "`", "}"]
end_index = -1
start_index = 0
parsed: dict = {}
for s in starts:
start_index = json_string.find(s)
if start_index != -1:

View File

@@ -1,8 +1,9 @@
from functools import wraps
from typing import Any
from flask import current_app, g, has_request_context, request
from flask_login import user_logged_in
from flask_login.config import EXEMPT_METHODS
from flask_login import user_logged_in # type: ignore
from flask_login.config import EXEMPT_METHODS # type: ignore
from werkzeug.exceptions import Unauthorized
from werkzeug.local import LocalProxy
@@ -12,7 +13,7 @@ from models.account import Account, Tenant, TenantAccountJoin
#: A proxy for the current user. If no user is logged in, this will be an
#: anonymous user
current_user = LocalProxy(lambda: _get_user())
current_user: Any = LocalProxy(lambda: _get_user())
def login_required(func):
@@ -79,12 +80,12 @@ def login_required(func):
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account)
user_logged_in.send(current_app._get_current_object(), user=_get_user())
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
pass
elif not current_user.is_authenticated:
return current_app.login_manager.unauthorized()
return current_app.login_manager.unauthorized() # type: ignore
# flask 1.x compatibility
# current_app.ensure_sync is only available in Flask >= 2.0
@@ -98,7 +99,7 @@ def login_required(func):
def _get_user():
if has_request_context():
if "_login_user" not in g:
current_app.login_manager._load_user()
current_app.login_manager._load_user() # type: ignore
return g._login_user

View File

@@ -77,9 +77,9 @@ class GitHubOAuth(OAuth):
email_response = requests.get(self._EMAIL_INFO_URL, headers=headers)
email_info = email_response.json()
primary_email = next((email for email in email_info if email["primary"] == True), None)
primary_email: dict = next((email for email in email_info if email["primary"] == True), {})
return {**user_info, "email": primary_email["email"]}
return {**user_info, "email": primary_email.get("email", "")}
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
email = raw_info.get("email")
@@ -130,4 +130,4 @@ class GoogleOAuth(OAuth):
return response.json()
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
return OAuthUserInfo(id=str(raw_info["sub"]), name=None, email=raw_info["email"])
return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"])

View File

@@ -1,8 +1,9 @@
import datetime
import urllib.parse
from typing import Any
import requests
from flask_login import current_user
from flask_login import current_user # type: ignore
from extensions.ext_database import db
from models.source import DataSourceOauthBinding
@@ -226,7 +227,7 @@ class NotionOAuth(OAuthDataSource):
has_more = True
while has_more:
data = {
data: dict[str, Any] = {
"filter": {"value": "page", "property": "object"},
**({"start_cursor": next_cursor} if next_cursor else {}),
}
@@ -281,7 +282,7 @@ class NotionOAuth(OAuthDataSource):
has_more = True
while has_more:
data = {
data: dict[str, Any] = {
"filter": {"value": "database", "property": "object"},
**({"start_cursor": next_cursor} if next_cursor else {}),
}

View File

@@ -9,8 +9,8 @@ def apply_gevent_threading_patch():
:return:
"""
if not dify_config.DEBUG:
from gevent import monkey
from grpc.experimental import gevent as grpc_gevent
from gevent import monkey # type: ignore
from grpc.experimental import gevent as grpc_gevent # type: ignore
# gevent
monkey.patch_all()