mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 03:16:51 +08:00
Initial commit
This commit is contained in:
1
api/libs/__init__.py
Normal file
1
api/libs/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
82
api/libs/ecc_aes.py
Normal file
82
api/libs/ecc_aes.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Hash import SHA256
|
||||
from Crypto.PublicKey import ECC
|
||||
from Crypto.Util.Padding import pad, unpad
|
||||
|
||||
|
||||
class ECC_AES:
|
||||
def __init__(self, curve='P-256'):
|
||||
self.curve = curve
|
||||
self._aes_key = None
|
||||
self._private_key = None
|
||||
|
||||
def _derive_aes_key(self, ecc_key, nonce):
|
||||
if not self._aes_key:
|
||||
hasher = SHA256.new()
|
||||
hasher.update(ecc_key.export_key(format='DER') + nonce.encode())
|
||||
self._aes_key = hasher.digest()[:32]
|
||||
return self._aes_key
|
||||
|
||||
def generate_key_pair(self):
|
||||
private_key = ECC.generate(curve=self.curve)
|
||||
public_key = private_key.public_key()
|
||||
|
||||
pem_private = private_key.export_key(format='PEM')
|
||||
pem_public = public_key.export_key(format='PEM')
|
||||
|
||||
return pem_private, pem_public
|
||||
|
||||
def load_private_key(self, private_key_pem):
|
||||
self._private_key = ECC.import_key(private_key_pem)
|
||||
self._aes_key = None
|
||||
|
||||
def encrypt(self, text, nonce):
|
||||
if not self._private_key:
|
||||
raise ValueError("Private key not loaded")
|
||||
|
||||
# Generate AES key using ECC private key and nonce
|
||||
aes_key = self._derive_aes_key(self._private_key, nonce)
|
||||
|
||||
# Encrypt data using AES key
|
||||
cipher = AES.new(aes_key, AES.MODE_ECB)
|
||||
padded_text = pad(text.encode(), AES.block_size)
|
||||
ciphertext = cipher.encrypt(padded_text)
|
||||
|
||||
return ciphertext
|
||||
|
||||
def decrypt(self, ciphertext, nonce):
|
||||
if not self._private_key:
|
||||
raise ValueError("Private key not loaded")
|
||||
|
||||
# Generate AES key using ECC private key and nonce
|
||||
aes_key = self._derive_aes_key(self._private_key, nonce)
|
||||
|
||||
# Decrypt data using AES key
|
||||
cipher = AES.new(aes_key, AES.MODE_ECB)
|
||||
padded_plaintext = cipher.decrypt(ciphertext)
|
||||
plaintext = unpad(padded_plaintext, AES.block_size)
|
||||
|
||||
return plaintext.decode()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ecc_aes = ECC_AES()
|
||||
|
||||
# Generate key pairs for the user
|
||||
private_key, public_key = ecc_aes.generate_key_pair()
|
||||
ecc_aes.load_private_key(private_key)
|
||||
nonce = "THIS-IS-USER-ID"
|
||||
|
||||
print(private_key)
|
||||
|
||||
# Encrypt a message
|
||||
message = "Hello, this is a secret message!"
|
||||
encrypted_message = ecc_aes.encrypt(message, nonce)
|
||||
print(f"Encrypted message: {encrypted_message.hex()}")
|
||||
|
||||
# Decrypt the message
|
||||
decrypted_message = ecc_aes.decrypt(encrypted_message, nonce)
|
||||
print(f"Decrypted message: {decrypted_message}")
|
||||
|
||||
# Check if the original message and decrypted message are the same
|
||||
assert message == decrypted_message, "Original message and decrypted message do not match"
|
||||
17
api/libs/exception.py
Normal file
17
api/libs/exception.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from typing import Optional
|
||||
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
|
||||
class BaseHTTPException(HTTPException):
|
||||
error_code: str = 'unknown'
|
||||
data: Optional[dict] = None
|
||||
|
||||
def __init__(self, description=None, response=None):
|
||||
super().__init__(description, response)
|
||||
|
||||
self.data = {
|
||||
"code": self.error_code,
|
||||
"message": self.description,
|
||||
"status": self.code,
|
||||
}
|
||||
115
api/libs/external_api.py
Normal file
115
api/libs/external_api.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import re
|
||||
import sys
|
||||
|
||||
from flask import got_request_exception, current_app
|
||||
from flask_restful import Api, http_status_message
|
||||
from werkzeug.datastructures import Headers
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
|
||||
class ExternalApi(Api):
|
||||
|
||||
def handle_error(self, e):
|
||||
"""Error handler for the API transforms a raised exception into a Flask
|
||||
response, with the appropriate HTTP status code and body.
|
||||
|
||||
:param e: the raised Exception object
|
||||
:type e: Exception
|
||||
|
||||
"""
|
||||
got_request_exception.send(current_app, exception=e)
|
||||
|
||||
headers = Headers()
|
||||
if isinstance(e, HTTPException):
|
||||
if e.response is not None:
|
||||
resp = e.get_response()
|
||||
return resp
|
||||
|
||||
status_code = e.code
|
||||
default_data = {
|
||||
'code': re.sub(r'(?<!^)(?=[A-Z])', '_', type(e).__name__).lower(),
|
||||
'message': getattr(e, 'description', http_status_message(status_code)),
|
||||
'status': status_code
|
||||
}
|
||||
headers = e.get_response().headers
|
||||
elif isinstance(e, ValueError):
|
||||
status_code = 400
|
||||
default_data = {
|
||||
'code': 'invalid_param',
|
||||
'message': str(e),
|
||||
'status': status_code
|
||||
}
|
||||
else:
|
||||
status_code = 500
|
||||
default_data = {
|
||||
'message': http_status_message(status_code),
|
||||
}
|
||||
|
||||
# Werkzeug exceptions generate a content-length header which is added
|
||||
# to the response in addition to the actual content-length header
|
||||
# https://github.com/flask-restful/flask-restful/issues/534
|
||||
remove_headers = ('Content-Length',)
|
||||
|
||||
for header in remove_headers:
|
||||
headers.pop(header, None)
|
||||
|
||||
data = getattr(e, 'data', default_data)
|
||||
|
||||
error_cls_name = type(e).__name__
|
||||
if error_cls_name in self.errors:
|
||||
custom_data = self.errors.get(error_cls_name, {})
|
||||
custom_data = custom_data.copy()
|
||||
status_code = custom_data.get('status', 500)
|
||||
|
||||
if 'message' in custom_data:
|
||||
custom_data['message'] = custom_data['message'].format(
|
||||
message=str(e.description if hasattr(e, 'description') else e)
|
||||
)
|
||||
data.update(custom_data)
|
||||
|
||||
# record the exception in the logs when we have a server error of status code: 500
|
||||
if status_code and status_code >= 500:
|
||||
exc_info = sys.exc_info()
|
||||
if exc_info[1] is None:
|
||||
exc_info = None
|
||||
current_app.log_exception(exc_info)
|
||||
|
||||
if status_code == 406 and self.default_mediatype is None:
|
||||
# if we are handling NotAcceptable (406), make sure that
|
||||
# make_response uses a representation we support as the
|
||||
# default mediatype (so that make_response doesn't throw
|
||||
# another NotAcceptable error).
|
||||
supported_mediatypes = list(self.representations.keys()) # only supported application/json
|
||||
fallback_mediatype = supported_mediatypes[0] if supported_mediatypes else "text/plain"
|
||||
data = {
|
||||
'code': 'not_acceptable',
|
||||
'message': data.get('message')
|
||||
}
|
||||
resp = self.make_response(
|
||||
data,
|
||||
status_code,
|
||||
headers,
|
||||
fallback_mediatype = fallback_mediatype
|
||||
)
|
||||
elif status_code == 400:
|
||||
if isinstance(data.get('message'), dict):
|
||||
param_key, param_value = list(data.get('message').items())[0]
|
||||
data = {
|
||||
'code': 'invalid_param',
|
||||
'message': param_value,
|
||||
'params': param_key
|
||||
}
|
||||
else:
|
||||
if 'code' not in data:
|
||||
data['code'] = 'unknown'
|
||||
|
||||
resp = self.make_response(data, status_code, headers)
|
||||
else:
|
||||
if 'code' not in data:
|
||||
data['code'] = 'unknown'
|
||||
|
||||
resp = self.make_response(data, status_code, headers)
|
||||
|
||||
if status_code == 401:
|
||||
resp = self.unauthorized(resp)
|
||||
return resp
|
||||
149
api/libs/helper.py
Normal file
149
api/libs/helper.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import re
|
||||
import subprocess
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from zoneinfo import available_timezones
|
||||
import random
|
||||
import string
|
||||
|
||||
from flask_restful import fields
|
||||
|
||||
|
||||
def run(script):
|
||||
return subprocess.getstatusoutput('source /root/.bashrc && ' + script)
|
||||
|
||||
|
||||
class TimestampField(fields.Raw):
|
||||
def format(self, value):
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
def email(email):
|
||||
# Define a regex pattern for email addresses
|
||||
pattern = r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$"
|
||||
# Check if the email matches the pattern
|
||||
if re.match(pattern, email) is not None:
|
||||
return email
|
||||
|
||||
error = ('{email} is not a valid email.'
|
||||
.format(email=email))
|
||||
raise ValueError(error)
|
||||
|
||||
|
||||
def uuid_value(value):
|
||||
if value == '':
|
||||
return str(value)
|
||||
|
||||
try:
|
||||
uuid_obj = uuid.UUID(value)
|
||||
return str(uuid_obj)
|
||||
except ValueError:
|
||||
error = ('{value} is not a valid uuid.'
|
||||
.format(value=value))
|
||||
raise ValueError(error)
|
||||
|
||||
|
||||
def timestamp_value(timestamp):
|
||||
try:
|
||||
int_timestamp = int(timestamp)
|
||||
if int_timestamp < 0:
|
||||
raise ValueError
|
||||
return int_timestamp
|
||||
except ValueError:
|
||||
error = ('{timestamp} is not a valid timestamp.'
|
||||
.format(timestamp=timestamp))
|
||||
raise ValueError(error)
|
||||
|
||||
|
||||
class str_len(object):
|
||||
""" Restrict input to an integer in a range (inclusive) """
|
||||
|
||||
def __init__(self, max_length, argument='argument'):
|
||||
self.max_length = max_length
|
||||
self.argument = argument
|
||||
|
||||
def __call__(self, value):
|
||||
length = len(value)
|
||||
if length > self.max_length:
|
||||
error = ('Invalid {arg}: {val}. {arg} cannot exceed length {length}'
|
||||
.format(arg=self.argument, val=value, length=self.max_length))
|
||||
raise ValueError(error)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class float_range(object):
|
||||
""" Restrict input to an float in a range (inclusive) """
|
||||
def __init__(self, low, high, argument='argument'):
|
||||
self.low = low
|
||||
self.high = high
|
||||
self.argument = argument
|
||||
|
||||
def __call__(self, value):
|
||||
value = _get_float(value)
|
||||
if value < self.low or value > self.high:
|
||||
error = ('Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}'
|
||||
.format(arg=self.argument, val=value, lo=self.low, hi=self.high))
|
||||
raise ValueError(error)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class datetime_string(object):
|
||||
def __init__(self, format, argument='argument'):
|
||||
self.format = format
|
||||
self.argument = argument
|
||||
|
||||
def __call__(self, value):
|
||||
try:
|
||||
datetime.strptime(value, self.format)
|
||||
except ValueError:
|
||||
error = ('Invalid {arg}: {val}. {arg} must be conform to the format {format}'
|
||||
.format(arg=self.argument, val=value, lo=self.format))
|
||||
raise ValueError(error)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def _get_float(value):
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError('{0} is not a valid float'.format(value))
|
||||
|
||||
|
||||
def supported_language(lang):
|
||||
if lang in ['en-US', 'zh-Hans']:
|
||||
return lang
|
||||
|
||||
error = ('{lang} is not a valid language.'
|
||||
.format(lang=lang))
|
||||
raise ValueError(error)
|
||||
|
||||
|
||||
def timezone(timezone_string):
|
||||
if timezone_string and timezone_string in available_timezones():
|
||||
return timezone_string
|
||||
|
||||
error = ('{timezone_string} is not a valid timezone.'
|
||||
.format(timezone_string=timezone_string))
|
||||
raise ValueError(error)
|
||||
|
||||
|
||||
def generate_string(n):
|
||||
letters_digits = string.ascii_letters + string.digits
|
||||
result = ""
|
||||
for i in range(n):
|
||||
result += random.choice(letters_digits)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_remote_ip(request):
|
||||
if request.headers.get('CF-Connecting-IP'):
|
||||
return request.headers.get('Cf-Connecting-Ip')
|
||||
elif request.headers.getlist("X-Forwarded-For"):
|
||||
return request.headers.getlist("X-Forwarded-For")[0]
|
||||
else:
|
||||
return request.remote_addr
|
||||
7
api/libs/infinite_scroll_pagination.py
Normal file
7
api/libs/infinite_scroll_pagination.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
class InfiniteScrollPagination:
|
||||
def __init__(self, data, limit, has_more):
|
||||
self.data = data
|
||||
self.limit = limit
|
||||
self.has_more = has_more
|
||||
136
api/libs/oauth.py
Normal file
136
api/libs/oauth.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthUserInfo:
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
|
||||
|
||||
class OAuth:
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
|
||||
def get_authorization_url(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_raw_user_info(self, token: str):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_user_info(self, token: str) -> OAuthUserInfo:
|
||||
raw_info = self.get_raw_user_info(token)
|
||||
return self._transform_user_info(raw_info)
|
||||
|
||||
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class GitHubOAuth(OAuth):
|
||||
_AUTH_URL = 'https://github.com/login/oauth/authorize'
|
||||
_TOKEN_URL = 'https://github.com/login/oauth/access_token'
|
||||
_USER_INFO_URL = 'https://api.github.com/user'
|
||||
_EMAIL_INFO_URL = 'https://api.github.com/user/emails'
|
||||
|
||||
def get_authorization_url(self):
|
||||
params = {
|
||||
'client_id': self.client_id,
|
||||
'redirect_uri': self.redirect_uri,
|
||||
'scope': 'user:email' # Request only basic user information
|
||||
}
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
data = {
|
||||
'client_id': self.client_id,
|
||||
'client_secret': self.client_secret,
|
||||
'code': code,
|
||||
'redirect_uri': self.redirect_uri
|
||||
}
|
||||
headers = {'Accept': 'application/json'}
|
||||
response = requests.post(self._TOKEN_URL, data=data, headers=headers)
|
||||
|
||||
response_json = response.json()
|
||||
access_token = response_json.get('access_token')
|
||||
|
||||
if not access_token:
|
||||
raise ValueError(f"Error in GitHub OAuth: {response_json}")
|
||||
|
||||
return access_token
|
||||
|
||||
def get_raw_user_info(self, token: str):
|
||||
headers = {'Authorization': f"token {token}"}
|
||||
response = requests.get(self._USER_INFO_URL, headers=headers)
|
||||
response.raise_for_status()
|
||||
user_info = response.json()
|
||||
|
||||
email_response = requests.get(self._EMAIL_INFO_URL, headers=headers)
|
||||
email_info = email_response.json()
|
||||
primary_email = next((email for email in email_info if email['primary'] == True), None)
|
||||
|
||||
return {**user_info, 'email': primary_email['email']}
|
||||
|
||||
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
|
||||
email = raw_info.get('email')
|
||||
if not email:
|
||||
email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com"
|
||||
return OAuthUserInfo(
|
||||
id=str(raw_info['id']),
|
||||
name=raw_info['name'],
|
||||
email=email
|
||||
)
|
||||
|
||||
|
||||
class GoogleOAuth(OAuth):
|
||||
_AUTH_URL = 'https://accounts.google.com/o/oauth2/v2/auth'
|
||||
_TOKEN_URL = 'https://oauth2.googleapis.com/token'
|
||||
_USER_INFO_URL = 'https://www.googleapis.com/oauth2/v3/userinfo'
|
||||
|
||||
def get_authorization_url(self):
|
||||
params = {
|
||||
'client_id': self.client_id,
|
||||
'response_type': 'code',
|
||||
'redirect_uri': self.redirect_uri,
|
||||
'scope': 'openid email'
|
||||
}
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
data = {
|
||||
'client_id': self.client_id,
|
||||
'client_secret': self.client_secret,
|
||||
'code': code,
|
||||
'grant_type': 'authorization_code',
|
||||
'redirect_uri': self.redirect_uri
|
||||
}
|
||||
headers = {'Accept': 'application/json'}
|
||||
response = requests.post(self._TOKEN_URL, data=data, headers=headers)
|
||||
|
||||
response_json = response.json()
|
||||
access_token = response_json.get('access_token')
|
||||
|
||||
if not access_token:
|
||||
raise ValueError(f"Error in Google OAuth: {response_json}")
|
||||
|
||||
return access_token
|
||||
|
||||
def get_raw_user_info(self, token: str):
|
||||
headers = {'Authorization': f"Bearer {token}"}
|
||||
response = requests.get(self._USER_INFO_URL, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
|
||||
return OAuthUserInfo(
|
||||
id=str(raw_info['sub']),
|
||||
name=None,
|
||||
email=raw_info['email']
|
||||
)
|
||||
26
api/libs/password.py
Normal file
26
api/libs/password.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import base64
|
||||
import binascii
|
||||
import hashlib
|
||||
import re
|
||||
|
||||
password_pattern = r"^(?=.*[a-zA-Z])(?=.*\d).{8,}$"
|
||||
|
||||
def valid_password(password):
|
||||
# Define a regex pattern for password rules
|
||||
pattern = password_pattern
|
||||
# Check if the password matches the pattern
|
||||
if re.match(pattern, password) is not None:
|
||||
return password
|
||||
|
||||
raise ValueError('Not a valid password.')
|
||||
|
||||
|
||||
def hash_password(password_str, salt_byte):
|
||||
dk = hashlib.pbkdf2_hmac('sha256', password_str.encode('utf-8'), salt_byte, 10000)
|
||||
return binascii.hexlify(dk)
|
||||
|
||||
|
||||
def compare_password(password_str, password_hashed_base64, salt_base64):
|
||||
# compare password for login
|
||||
return hash_password(password_str, base64.b64decode(salt_base64)) == base64.b64decode(password_hashed_base64)
|
||||
58
api/libs/rsa.py
Normal file
58
api/libs/rsa.py
Normal file
@@ -0,0 +1,58 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import hashlib
|
||||
|
||||
from Crypto.Cipher import PKCS1_OAEP
|
||||
from Crypto.PublicKey import RSA
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
|
||||
# TODO: PKCS1_OAEP is no longer recommended for new systems and protocols. It is recommended to migrate to PKCS1_PSS.
|
||||
|
||||
|
||||
def generate_key_pair(tenant_id):
|
||||
private_key = RSA.generate(2048)
|
||||
public_key = private_key.publickey()
|
||||
|
||||
pem_private = private_key.export_key()
|
||||
pem_public = public_key.export_key()
|
||||
|
||||
filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"
|
||||
|
||||
storage.save(filepath, pem_private)
|
||||
|
||||
return pem_public.decode()
|
||||
|
||||
|
||||
def encrypt(text, public_key):
|
||||
if isinstance(public_key, str):
|
||||
public_key = public_key.encode()
|
||||
|
||||
rsa_key = RSA.import_key(public_key)
|
||||
cipher = PKCS1_OAEP.new(rsa_key)
|
||||
encrypted_text = cipher.encrypt(text.encode())
|
||||
return encrypted_text
|
||||
|
||||
|
||||
def decrypt(encrypted_text, tenant_id):
|
||||
filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"
|
||||
|
||||
cache_key = 'tenant_privkey:{hash}'.format(hash=hashlib.sha3_256(filepath.encode()).hexdigest())
|
||||
private_key = redis_client.get(cache_key)
|
||||
if not private_key:
|
||||
try:
|
||||
private_key = storage.load(filepath)
|
||||
except FileNotFoundError:
|
||||
raise PrivkeyNotFoundError("Private key not found, tenant_id: {tenant_id}".format(tenant_id=tenant_id))
|
||||
|
||||
redis_client.setex(cache_key, 120, private_key)
|
||||
|
||||
rsa_key = RSA.import_key(private_key)
|
||||
cipher = PKCS1_OAEP.new(rsa_key)
|
||||
decrypted_text = cipher.decrypt(encrypted_text)
|
||||
return decrypted_text.decode()
|
||||
|
||||
|
||||
class PrivkeyNotFoundError(Exception):
|
||||
pass
|
||||
Reference in New Issue
Block a user