Feat/new login (#8120)

Co-authored-by: douxc <douxc512@gmail.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
This commit is contained in:
Joe
2024-10-21 10:03:40 +08:00
committed by GitHub
parent 2c0eaaec3d
commit 4fd2743efa
24 changed files with 1027 additions and 292 deletions

View File

@@ -189,23 +189,39 @@ def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Resp
class TokenManager:
@classmethod
def generate_token(cls, account: Account, token_type: str, additional_data: Optional[dict] = None) -> str:
old_token = cls._get_current_token_for_account(account.id, token_type)
if old_token:
if isinstance(old_token, bytes):
old_token = old_token.decode("utf-8")
cls.revoke_token(old_token, token_type)
def generate_token(
cls,
token_type: str,
account: Optional[Account] = None,
email: Optional[str] = None,
additional_data: Optional[dict] = None,
) -> str:
if account is None and email is None:
raise ValueError("Account or email must be provided")
account_id = account.id if account else None
account_email = account.email if account else email
if account_id:
old_token = cls._get_current_token_for_account(account_id, token_type)
if old_token:
if isinstance(old_token, bytes):
old_token = old_token.decode("utf-8")
cls.revoke_token(old_token, token_type)
token = str(uuid.uuid4())
token_data = {"account_id": account.id, "email": account.email, "token_type": token_type}
token_data = {"account_id": account_id, "email": account_email, "token_type": token_type}
if additional_data:
token_data.update(additional_data)
expiry_hours = current_app.config[f"{token_type.upper()}_TOKEN_EXPIRY_HOURS"]
token_key = cls._get_token_key(token, token_type)
redis_client.setex(token_key, expiry_hours * 60 * 60, json.dumps(token_data))
expiry_time = int(expiry_hours * 60 * 60)
redis_client.setex(token_key, expiry_time, json.dumps(token_data))
if account_id:
cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
return token
@classmethod
@@ -234,9 +250,12 @@ class TokenManager:
return current_token
@classmethod
def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_hours: int):
def _set_current_token_for_account(
cls, account_id: str, token: str, token_type: str, expiry_hours: Union[int, float]
):
key = cls._get_account_token_key(account_id, token_type)
redis_client.setex(key, expiry_hours * 60 * 60, token)
expiry_time = int(expiry_hours * 60 * 60)
redis_client.setex(key, expiry_time, token)
@classmethod
def _get_account_token_key(cls, account_id: str, token_type: str) -> str:

View File

@@ -1,5 +1,6 @@
import urllib.parse
from dataclasses import dataclass
from typing import Optional
import requests
@@ -40,12 +41,14 @@ class GitHubOAuth(OAuth):
_USER_INFO_URL = "https://api.github.com/user"
_EMAIL_INFO_URL = "https://api.github.com/user/emails"
def get_authorization_url(self):
def get_authorization_url(self, invite_token: Optional[str] = None):
params = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": "user:email", # Request only basic user information
}
if invite_token:
params["state"] = invite_token
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str):
@@ -90,13 +93,15 @@ class GoogleOAuth(OAuth):
_TOKEN_URL = "https://oauth2.googleapis.com/token"
_USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
def get_authorization_url(self):
def get_authorization_url(self, invite_token: Optional[str] = None):
params = {
"client_id": self.client_id,
"response_type": "code",
"redirect_uri": self.redirect_uri,
"scope": "openid email",
}
if invite_token:
params["state"] = invite_token
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str):

View File

@@ -13,7 +13,7 @@ def valid_password(password):
if re.match(pattern, password) is not None:
return password
raise ValueError("Not a valid password.")
raise ValueError("Password must contain letters and numbers, and the length must be greater than 8.")
def hash_password(password_str, salt_byte):