Initial commit

This commit is contained in:
John Wang
2023-05-15 08:51:32 +08:00
commit db896255d6
744 changed files with 56028 additions and 0 deletions

1
api/libs/__init__.py Normal file
View File

@@ -0,0 +1 @@
# -*- coding:utf-8 -*-

82
api/libs/ecc_aes.py Normal file
View File

@@ -0,0 +1,82 @@
from Crypto.Cipher import AES
from Crypto.Hash import SHA256
from Crypto.PublicKey import ECC
from Crypto.Util.Padding import pad, unpad
class ECC_AES:
def __init__(self, curve='P-256'):
self.curve = curve
self._aes_key = None
self._private_key = None
def _derive_aes_key(self, ecc_key, nonce):
if not self._aes_key:
hasher = SHA256.new()
hasher.update(ecc_key.export_key(format='DER') + nonce.encode())
self._aes_key = hasher.digest()[:32]
return self._aes_key
def generate_key_pair(self):
private_key = ECC.generate(curve=self.curve)
public_key = private_key.public_key()
pem_private = private_key.export_key(format='PEM')
pem_public = public_key.export_key(format='PEM')
return pem_private, pem_public
def load_private_key(self, private_key_pem):
self._private_key = ECC.import_key(private_key_pem)
self._aes_key = None
def encrypt(self, text, nonce):
if not self._private_key:
raise ValueError("Private key not loaded")
# Generate AES key using ECC private key and nonce
aes_key = self._derive_aes_key(self._private_key, nonce)
# Encrypt data using AES key
cipher = AES.new(aes_key, AES.MODE_ECB)
padded_text = pad(text.encode(), AES.block_size)
ciphertext = cipher.encrypt(padded_text)
return ciphertext
def decrypt(self, ciphertext, nonce):
if not self._private_key:
raise ValueError("Private key not loaded")
# Generate AES key using ECC private key and nonce
aes_key = self._derive_aes_key(self._private_key, nonce)
# Decrypt data using AES key
cipher = AES.new(aes_key, AES.MODE_ECB)
padded_plaintext = cipher.decrypt(ciphertext)
plaintext = unpad(padded_plaintext, AES.block_size)
return plaintext.decode()
if __name__ == '__main__':
ecc_aes = ECC_AES()
# Generate key pairs for the user
private_key, public_key = ecc_aes.generate_key_pair()
ecc_aes.load_private_key(private_key)
nonce = "THIS-IS-USER-ID"
print(private_key)
# Encrypt a message
message = "Hello, this is a secret message!"
encrypted_message = ecc_aes.encrypt(message, nonce)
print(f"Encrypted message: {encrypted_message.hex()}")
# Decrypt the message
decrypted_message = ecc_aes.decrypt(encrypted_message, nonce)
print(f"Decrypted message: {decrypted_message}")
# Check if the original message and decrypted message are the same
assert message == decrypted_message, "Original message and decrypted message do not match"

17
api/libs/exception.py Normal file
View File

@@ -0,0 +1,17 @@
from typing import Optional
from werkzeug.exceptions import HTTPException
class BaseHTTPException(HTTPException):
error_code: str = 'unknown'
data: Optional[dict] = None
def __init__(self, description=None, response=None):
super().__init__(description, response)
self.data = {
"code": self.error_code,
"message": self.description,
"status": self.code,
}

115
api/libs/external_api.py Normal file
View File

@@ -0,0 +1,115 @@
import re
import sys
from flask import got_request_exception, current_app
from flask_restful import Api, http_status_message
from werkzeug.datastructures import Headers
from werkzeug.exceptions import HTTPException
class ExternalApi(Api):
def handle_error(self, e):
"""Error handler for the API transforms a raised exception into a Flask
response, with the appropriate HTTP status code and body.
:param e: the raised Exception object
:type e: Exception
"""
got_request_exception.send(current_app, exception=e)
headers = Headers()
if isinstance(e, HTTPException):
if e.response is not None:
resp = e.get_response()
return resp
status_code = e.code
default_data = {
'code': re.sub(r'(?<!^)(?=[A-Z])', '_', type(e).__name__).lower(),
'message': getattr(e, 'description', http_status_message(status_code)),
'status': status_code
}
headers = e.get_response().headers
elif isinstance(e, ValueError):
status_code = 400
default_data = {
'code': 'invalid_param',
'message': str(e),
'status': status_code
}
else:
status_code = 500
default_data = {
'message': http_status_message(status_code),
}
# Werkzeug exceptions generate a content-length header which is added
# to the response in addition to the actual content-length header
# https://github.com/flask-restful/flask-restful/issues/534
remove_headers = ('Content-Length',)
for header in remove_headers:
headers.pop(header, None)
data = getattr(e, 'data', default_data)
error_cls_name = type(e).__name__
if error_cls_name in self.errors:
custom_data = self.errors.get(error_cls_name, {})
custom_data = custom_data.copy()
status_code = custom_data.get('status', 500)
if 'message' in custom_data:
custom_data['message'] = custom_data['message'].format(
message=str(e.description if hasattr(e, 'description') else e)
)
data.update(custom_data)
# record the exception in the logs when we have a server error of status code: 500
if status_code and status_code >= 500:
exc_info = sys.exc_info()
if exc_info[1] is None:
exc_info = None
current_app.log_exception(exc_info)
if status_code == 406 and self.default_mediatype is None:
# if we are handling NotAcceptable (406), make sure that
# make_response uses a representation we support as the
# default mediatype (so that make_response doesn't throw
# another NotAcceptable error).
supported_mediatypes = list(self.representations.keys()) # only supported application/json
fallback_mediatype = supported_mediatypes[0] if supported_mediatypes else "text/plain"
data = {
'code': 'not_acceptable',
'message': data.get('message')
}
resp = self.make_response(
data,
status_code,
headers,
fallback_mediatype = fallback_mediatype
)
elif status_code == 400:
if isinstance(data.get('message'), dict):
param_key, param_value = list(data.get('message').items())[0]
data = {
'code': 'invalid_param',
'message': param_value,
'params': param_key
}
else:
if 'code' not in data:
data['code'] = 'unknown'
resp = self.make_response(data, status_code, headers)
else:
if 'code' not in data:
data['code'] = 'unknown'
resp = self.make_response(data, status_code, headers)
if status_code == 401:
resp = self.unauthorized(resp)
return resp

149
api/libs/helper.py Normal file
View File

@@ -0,0 +1,149 @@
# -*- coding:utf-8 -*-
import re
import subprocess
import uuid
from datetime import datetime
from zoneinfo import available_timezones
import random
import string
from flask_restful import fields
def run(script):
return subprocess.getstatusoutput('source /root/.bashrc && ' + script)
class TimestampField(fields.Raw):
def format(self, value):
return int(value.timestamp())
def email(email):
# Define a regex pattern for email addresses
pattern = r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$"
# Check if the email matches the pattern
if re.match(pattern, email) is not None:
return email
error = ('{email} is not a valid email.'
.format(email=email))
raise ValueError(error)
def uuid_value(value):
if value == '':
return str(value)
try:
uuid_obj = uuid.UUID(value)
return str(uuid_obj)
except ValueError:
error = ('{value} is not a valid uuid.'
.format(value=value))
raise ValueError(error)
def timestamp_value(timestamp):
try:
int_timestamp = int(timestamp)
if int_timestamp < 0:
raise ValueError
return int_timestamp
except ValueError:
error = ('{timestamp} is not a valid timestamp.'
.format(timestamp=timestamp))
raise ValueError(error)
class str_len(object):
""" Restrict input to an integer in a range (inclusive) """
def __init__(self, max_length, argument='argument'):
self.max_length = max_length
self.argument = argument
def __call__(self, value):
length = len(value)
if length > self.max_length:
error = ('Invalid {arg}: {val}. {arg} cannot exceed length {length}'
.format(arg=self.argument, val=value, length=self.max_length))
raise ValueError(error)
return value
class float_range(object):
""" Restrict input to an float in a range (inclusive) """
def __init__(self, low, high, argument='argument'):
self.low = low
self.high = high
self.argument = argument
def __call__(self, value):
value = _get_float(value)
if value < self.low or value > self.high:
error = ('Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}'
.format(arg=self.argument, val=value, lo=self.low, hi=self.high))
raise ValueError(error)
return value
class datetime_string(object):
def __init__(self, format, argument='argument'):
self.format = format
self.argument = argument
def __call__(self, value):
try:
datetime.strptime(value, self.format)
except ValueError:
error = ('Invalid {arg}: {val}. {arg} must be conform to the format {format}'
.format(arg=self.argument, val=value, lo=self.format))
raise ValueError(error)
return value
def _get_float(value):
try:
return float(value)
except (TypeError, ValueError):
raise ValueError('{0} is not a valid float'.format(value))
def supported_language(lang):
if lang in ['en-US', 'zh-Hans']:
return lang
error = ('{lang} is not a valid language.'
.format(lang=lang))
raise ValueError(error)
def timezone(timezone_string):
if timezone_string and timezone_string in available_timezones():
return timezone_string
error = ('{timezone_string} is not a valid timezone.'
.format(timezone_string=timezone_string))
raise ValueError(error)
def generate_string(n):
letters_digits = string.ascii_letters + string.digits
result = ""
for i in range(n):
result += random.choice(letters_digits)
return result
def get_remote_ip(request):
if request.headers.get('CF-Connecting-IP'):
return request.headers.get('Cf-Connecting-Ip')
elif request.headers.getlist("X-Forwarded-For"):
return request.headers.getlist("X-Forwarded-For")[0]
else:
return request.remote_addr

View File

@@ -0,0 +1,7 @@
# -*- coding:utf-8 -*-
class InfiniteScrollPagination:
def __init__(self, data, limit, has_more):
self.data = data
self.limit = limit
self.has_more = has_more

136
api/libs/oauth.py Normal file
View File

@@ -0,0 +1,136 @@
import urllib.parse
from dataclasses import dataclass
import requests
@dataclass
class OAuthUserInfo:
id: str
name: str
email: str
class OAuth:
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
def get_authorization_url(self):
raise NotImplementedError()
def get_access_token(self, code: str):
raise NotImplementedError()
def get_raw_user_info(self, token: str):
raise NotImplementedError()
def get_user_info(self, token: str) -> OAuthUserInfo:
raw_info = self.get_raw_user_info(token)
return self._transform_user_info(raw_info)
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
raise NotImplementedError()
class GitHubOAuth(OAuth):
_AUTH_URL = 'https://github.com/login/oauth/authorize'
_TOKEN_URL = 'https://github.com/login/oauth/access_token'
_USER_INFO_URL = 'https://api.github.com/user'
_EMAIL_INFO_URL = 'https://api.github.com/user/emails'
def get_authorization_url(self):
params = {
'client_id': self.client_id,
'redirect_uri': self.redirect_uri,
'scope': 'user:email' # Request only basic user information
}
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str):
data = {
'client_id': self.client_id,
'client_secret': self.client_secret,
'code': code,
'redirect_uri': self.redirect_uri
}
headers = {'Accept': 'application/json'}
response = requests.post(self._TOKEN_URL, data=data, headers=headers)
response_json = response.json()
access_token = response_json.get('access_token')
if not access_token:
raise ValueError(f"Error in GitHub OAuth: {response_json}")
return access_token
def get_raw_user_info(self, token: str):
headers = {'Authorization': f"token {token}"}
response = requests.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status()
user_info = response.json()
email_response = requests.get(self._EMAIL_INFO_URL, headers=headers)
email_info = email_response.json()
primary_email = next((email for email in email_info if email['primary'] == True), None)
return {**user_info, 'email': primary_email['email']}
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
email = raw_info.get('email')
if not email:
email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com"
return OAuthUserInfo(
id=str(raw_info['id']),
name=raw_info['name'],
email=email
)
class GoogleOAuth(OAuth):
_AUTH_URL = 'https://accounts.google.com/o/oauth2/v2/auth'
_TOKEN_URL = 'https://oauth2.googleapis.com/token'
_USER_INFO_URL = 'https://www.googleapis.com/oauth2/v3/userinfo'
def get_authorization_url(self):
params = {
'client_id': self.client_id,
'response_type': 'code',
'redirect_uri': self.redirect_uri,
'scope': 'openid email'
}
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str):
data = {
'client_id': self.client_id,
'client_secret': self.client_secret,
'code': code,
'grant_type': 'authorization_code',
'redirect_uri': self.redirect_uri
}
headers = {'Accept': 'application/json'}
response = requests.post(self._TOKEN_URL, data=data, headers=headers)
response_json = response.json()
access_token = response_json.get('access_token')
if not access_token:
raise ValueError(f"Error in Google OAuth: {response_json}")
return access_token
def get_raw_user_info(self, token: str):
headers = {'Authorization': f"Bearer {token}"}
response = requests.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status()
return response.json()
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
return OAuthUserInfo(
id=str(raw_info['sub']),
name=None,
email=raw_info['email']
)

26
api/libs/password.py Normal file
View File

@@ -0,0 +1,26 @@
# -*- coding:utf-8 -*-
import base64
import binascii
import hashlib
import re
password_pattern = r"^(?=.*[a-zA-Z])(?=.*\d).{8,}$"
def valid_password(password):
# Define a regex pattern for password rules
pattern = password_pattern
# Check if the password matches the pattern
if re.match(pattern, password) is not None:
return password
raise ValueError('Not a valid password.')
def hash_password(password_str, salt_byte):
dk = hashlib.pbkdf2_hmac('sha256', password_str.encode('utf-8'), salt_byte, 10000)
return binascii.hexlify(dk)
def compare_password(password_str, password_hashed_base64, salt_base64):
# compare password for login
return hash_password(password_str, base64.b64decode(salt_base64)) == base64.b64decode(password_hashed_base64)

58
api/libs/rsa.py Normal file
View File

@@ -0,0 +1,58 @@
# -*- coding:utf-8 -*-
import hashlib
from Crypto.Cipher import PKCS1_OAEP
from Crypto.PublicKey import RSA
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
# TODO: PKCS1_OAEP is no longer recommended for new systems and protocols. It is recommended to migrate to PKCS1_PSS.
def generate_key_pair(tenant_id):
private_key = RSA.generate(2048)
public_key = private_key.publickey()
pem_private = private_key.export_key()
pem_public = public_key.export_key()
filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"
storage.save(filepath, pem_private)
return pem_public.decode()
def encrypt(text, public_key):
if isinstance(public_key, str):
public_key = public_key.encode()
rsa_key = RSA.import_key(public_key)
cipher = PKCS1_OAEP.new(rsa_key)
encrypted_text = cipher.encrypt(text.encode())
return encrypted_text
def decrypt(encrypted_text, tenant_id):
filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"
cache_key = 'tenant_privkey:{hash}'.format(hash=hashlib.sha3_256(filepath.encode()).hexdigest())
private_key = redis_client.get(cache_key)
if not private_key:
try:
private_key = storage.load(filepath)
except FileNotFoundError:
raise PrivkeyNotFoundError("Private key not found, tenant_id: {tenant_id}".format(tenant_id=tenant_id))
redis_client.setex(cache_key, 120, private_key)
rsa_key = RSA.import_key(private_key)
cipher = PKCS1_OAEP.new(rsa_key)
decrypted_text = cipher.decrypt(encrypted_text)
return decrypted_text.decode()
class PrivkeyNotFoundError(Exception):
pass