mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-09 10:56:52 +08:00
feat: member invitation and activation (#535)
Co-authored-by: John Wang <takatost@gmail.com>
This commit is contained in:
@@ -8,13 +8,19 @@ EDITION=SELF_HOSTED
|
||||
SECRET_KEY=
|
||||
|
||||
# Console API base URL
|
||||
CONSOLE_URL=http://127.0.0.1:5001
|
||||
CONSOLE_API_URL=http://127.0.0.1:5001
|
||||
|
||||
# Console frontend web base URL
|
||||
CONSOLE_WEB_URL=http://127.0.0.1:3000
|
||||
|
||||
# Service API base URL
|
||||
API_URL=http://127.0.0.1:5001
|
||||
SERVICE_API_URL=http://127.0.0.1:5001
|
||||
|
||||
# Web APP base URL
|
||||
APP_URL=http://127.0.0.1:3000
|
||||
# Web APP API base URL
|
||||
APP_API_URL=http://127.0.0.1:5001
|
||||
|
||||
# Web APP frontend web base URL
|
||||
APP_WEB_URL=http://127.0.0.1:3000
|
||||
|
||||
# celery configuration
|
||||
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
|
||||
@@ -79,6 +85,11 @@ WEAVIATE_BATCH_SIZE=100
|
||||
QDRANT_URL=path:storage/qdrant
|
||||
QDRANT_API_KEY=your-qdrant-api-key
|
||||
|
||||
# Mail configuration, support: resend
|
||||
MAIL_TYPE=
|
||||
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
|
||||
RESEND_API_KEY=
|
||||
|
||||
# Sentry configuration
|
||||
SENTRY_DSN=
|
||||
|
||||
|
||||
@@ -5,9 +5,11 @@ LABEL maintainer="takatost@gmail.com"
|
||||
ENV FLASK_APP app.py
|
||||
ENV EDITION SELF_HOSTED
|
||||
ENV DEPLOY_ENV PRODUCTION
|
||||
ENV CONSOLE_URL http://127.0.0.1:5001
|
||||
ENV API_URL http://127.0.0.1:5001
|
||||
ENV APP_URL http://127.0.0.1:5001
|
||||
ENV CONSOLE_API_URL http://127.0.0.1:5001
|
||||
ENV CONSOLE_WEB_URL http://127.0.0.1:3000
|
||||
ENV SERVICE_API_URL http://127.0.0.1:5001
|
||||
ENV APP_API_URL http://127.0.0.1:5001
|
||||
ENV APP_WEB_URL http://127.0.0.1:3000
|
||||
|
||||
EXPOSE 5001
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import flask_login
|
||||
from flask_cors import CORS
|
||||
|
||||
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
|
||||
ext_database, ext_storage
|
||||
ext_database, ext_storage, ext_mail
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_login import login_manager
|
||||
|
||||
@@ -83,6 +83,7 @@ def initialize_extensions(app):
|
||||
ext_celery.init_app(app)
|
||||
ext_session.init_app(app)
|
||||
ext_login.init_app(app)
|
||||
ext_mail.init_app(app)
|
||||
ext_sentry.init_app(app)
|
||||
|
||||
|
||||
|
||||
@@ -28,9 +28,11 @@ DEFAULTS = {
|
||||
'SESSION_REDIS_USE_SSL': 'False',
|
||||
'OAUTH_REDIRECT_PATH': '/console/api/oauth/authorize',
|
||||
'OAUTH_REDIRECT_INDEX_PATH': '/',
|
||||
'CONSOLE_URL': 'https://cloud.dify.ai',
|
||||
'API_URL': 'https://api.dify.ai',
|
||||
'APP_URL': 'https://udify.app',
|
||||
'CONSOLE_WEB_URL': 'https://cloud.dify.ai',
|
||||
'CONSOLE_API_URL': 'https://cloud.dify.ai',
|
||||
'SERVICE_API_URL': 'https://api.dify.ai',
|
||||
'APP_WEB_URL': 'https://udify.app',
|
||||
'APP_API_URL': 'https://udify.app',
|
||||
'STORAGE_TYPE': 'local',
|
||||
'STORAGE_LOCAL_PATH': 'storage',
|
||||
'CHECK_UPDATE_URL': 'https://updates.dify.ai',
|
||||
@@ -76,6 +78,11 @@ class Config:
|
||||
|
||||
def __init__(self):
|
||||
# app settings
|
||||
self.CONSOLE_API_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_API_URL')
|
||||
self.CONSOLE_WEB_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_WEB_URL')
|
||||
self.SERVICE_API_URL = get_env('API_URL') if get_env('API_URL') else get_env('SERVICE_API_URL')
|
||||
self.APP_WEB_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_WEB_URL')
|
||||
self.APP_API_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_API_URL')
|
||||
self.CONSOLE_URL = get_env('CONSOLE_URL')
|
||||
self.API_URL = get_env('API_URL')
|
||||
self.APP_URL = get_env('APP_URL')
|
||||
@@ -147,10 +154,15 @@ class Config:
|
||||
|
||||
# cors settings
|
||||
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
|
||||
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_URL)
|
||||
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
|
||||
self.WEB_API_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
|
||||
'WEB_API_CORS_ALLOW_ORIGINS', '*')
|
||||
|
||||
# mail settings
|
||||
self.MAIL_TYPE = get_env('MAIL_TYPE')
|
||||
self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
|
||||
self.RESEND_API_KEY = get_env('RESEND_API_KEY')
|
||||
|
||||
# sentry settings
|
||||
self.SENTRY_DSN = get_env('SENTRY_DSN')
|
||||
self.SENTRY_TRACES_SAMPLE_RATE = float(get_env('SENTRY_TRACES_SAMPLE_RATE'))
|
||||
|
||||
@@ -12,7 +12,7 @@ from . import setup, version, apikey, admin
|
||||
from .app import app, site, completion, model_config, statistic, conversation, message, generator, audio
|
||||
|
||||
# Import auth controllers
|
||||
from .auth import login, oauth, data_source_oauth
|
||||
from .auth import login, oauth, data_source_oauth, activate
|
||||
|
||||
# Import datasets controllers
|
||||
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source
|
||||
|
||||
75
api/controllers/console/auth/activate.py
Normal file
75
api/controllers/console/auth/activate.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import base64
|
||||
import secrets
|
||||
from datetime import datetime
|
||||
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, str_len, supported_language, timezone
|
||||
from libs.password import valid_password, hash_password
|
||||
from models.account import AccountStatus, Tenant
|
||||
from services.account_service import RegisterService
|
||||
|
||||
|
||||
class ActivateCheckApi(Resource):
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('workspace_id', type=str, required=True, nullable=False, location='args')
|
||||
parser.add_argument('email', type=email, required=True, nullable=False, location='args')
|
||||
parser.add_argument('token', type=str, required=True, nullable=False, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
account = RegisterService.get_account_if_token_valid(args['workspace_id'], args['email'], args['token'])
|
||||
|
||||
tenant = db.session.query(Tenant).filter(
|
||||
Tenant.id == args['workspace_id'],
|
||||
Tenant.status == 'normal'
|
||||
).first()
|
||||
|
||||
return {'is_valid': account is not None, 'workspace_name': tenant.name}
|
||||
|
||||
|
||||
class ActivateApi(Resource):
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('workspace_id', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('email', type=email, required=True, nullable=False, location='json')
|
||||
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json')
|
||||
parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json')
|
||||
parser.add_argument('interface_language', type=supported_language, required=True, nullable=False,
|
||||
location='json')
|
||||
parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
account = RegisterService.get_account_if_token_valid(args['workspace_id'], args['email'], args['token'])
|
||||
if account is None:
|
||||
raise AlreadyActivateError()
|
||||
|
||||
RegisterService.revoke_token(args['workspace_id'], args['email'], args['token'])
|
||||
|
||||
account.name = args['name']
|
||||
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(args['password'], salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
account.interface_language = args['interface_language']
|
||||
account.timezone = args['timezone']
|
||||
account.interface_theme = 'light'
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
account.initialized_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
api.add_resource(ActivateCheckApi, '/activate/check')
|
||||
api.add_resource(ActivateApi, '/activate')
|
||||
@@ -20,7 +20,7 @@ def get_oauth_providers():
|
||||
client_secret=current_app.config.get(
|
||||
'NOTION_CLIENT_SECRET'),
|
||||
redirect_uri=current_app.config.get(
|
||||
'CONSOLE_URL') + '/console/api/oauth/data-source/callback/notion')
|
||||
'CONSOLE_API_URL') + '/console/api/oauth/data-source/callback/notion')
|
||||
|
||||
OAUTH_PROVIDERS = {
|
||||
'notion': notion_oauth
|
||||
@@ -42,7 +42,7 @@ class OAuthDataSource(Resource):
|
||||
if current_app.config.get('NOTION_INTEGRATION_TYPE') == 'internal':
|
||||
internal_secret = current_app.config.get('NOTION_INTERNAL_SECRET')
|
||||
oauth_provider.save_internal_access_token(internal_secret)
|
||||
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=success')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=success')
|
||||
else:
|
||||
auth_url = oauth_provider.get_authorization_url()
|
||||
return redirect(auth_url)
|
||||
@@ -66,12 +66,12 @@ class OAuthDataSourceCallback(Resource):
|
||||
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
|
||||
return {'error': 'OAuth data source process failed'}, 400
|
||||
|
||||
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=success')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=success')
|
||||
elif 'error' in request.args:
|
||||
error = request.args.get('error')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source={error}')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source={error}')
|
||||
else:
|
||||
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=access_denied')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=access_denied')
|
||||
|
||||
|
||||
class OAuthDataSourceSync(Resource):
|
||||
|
||||
@@ -20,13 +20,13 @@ def get_oauth_providers():
|
||||
client_secret=current_app.config.get(
|
||||
'GITHUB_CLIENT_SECRET'),
|
||||
redirect_uri=current_app.config.get(
|
||||
'CONSOLE_URL') + '/console/api/oauth/authorize/github')
|
||||
'CONSOLE_API_URL') + '/console/api/oauth/authorize/github')
|
||||
|
||||
google_oauth = GoogleOAuth(client_id=current_app.config.get('GOOGLE_CLIENT_ID'),
|
||||
client_secret=current_app.config.get(
|
||||
'GOOGLE_CLIENT_SECRET'),
|
||||
redirect_uri=current_app.config.get(
|
||||
'CONSOLE_URL') + '/console/api/oauth/authorize/google')
|
||||
'CONSOLE_API_URL') + '/console/api/oauth/authorize/google')
|
||||
|
||||
OAUTH_PROVIDERS = {
|
||||
'github': github_oauth,
|
||||
@@ -80,7 +80,7 @@ class OAuthCallback(Resource):
|
||||
flask_login.login_user(account, remember=True)
|
||||
AccountService.update_last_login(account, request)
|
||||
|
||||
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_login=success')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_login=success')
|
||||
|
||||
|
||||
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
|
||||
|
||||
@@ -18,3 +18,9 @@ class AccountNotLinkTenantError(BaseHTTPException):
|
||||
error_code = 'account_not_link_tenant'
|
||||
description = "Account not link tenant."
|
||||
code = 403
|
||||
|
||||
|
||||
class AlreadyActivateError(BaseHTTPException):
|
||||
error_code = 'already_activate'
|
||||
description = "Auth Token is invalid or account already activated, please check again."
|
||||
code = 403
|
||||
|
||||
@@ -6,22 +6,23 @@ from flask import current_app, request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with
|
||||
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.workspace.error import AccountAlreadyInitedError, InvalidInvitationCodeError, \
|
||||
RepeatPasswordNotMatchError
|
||||
RepeatPasswordNotMatchError, CurrentPasswordIncorrectError
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from libs.helper import TimestampField, supported_language, timezone
|
||||
from extensions.ext_database import db
|
||||
from models.account import InvitationCode, AccountIntegrate
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
account_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'avatar': fields.String,
|
||||
'email': fields.String,
|
||||
'is_password_set': fields.Boolean,
|
||||
'interface_language': fields.String,
|
||||
'interface_theme': fields.String,
|
||||
'timezone': fields.String,
|
||||
@@ -194,8 +195,11 @@ class AccountPasswordApi(Resource):
|
||||
if args['new_password'] != args['repeat_new_password']:
|
||||
raise RepeatPasswordNotMatchError()
|
||||
|
||||
AccountService.update_account_password(
|
||||
current_user, args['password'], args['new_password'])
|
||||
try:
|
||||
AccountService.update_account_password(
|
||||
current_user, args['password'], args['new_password'])
|
||||
except ServiceCurrentPasswordIncorrectError:
|
||||
raise CurrentPasswordIncorrectError()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@@ -7,6 +7,12 @@ class RepeatPasswordNotMatchError(BaseHTTPException):
|
||||
code = 400
|
||||
|
||||
|
||||
class CurrentPasswordIncorrectError(BaseHTTPException):
|
||||
error_code = 'current_password_incorrect'
|
||||
description = "Current password is incorrect."
|
||||
code = 400
|
||||
|
||||
|
||||
class ProviderRequestFailedError(BaseHTTPException):
|
||||
error_code = 'provider_request_failed'
|
||||
description = None
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, reqparse, marshal_with, abort, fields, marshal
|
||||
|
||||
@@ -60,7 +60,8 @@ class MemberInviteEmailApi(Resource):
|
||||
inviter = current_user
|
||||
|
||||
try:
|
||||
RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role, inviter=inviter)
|
||||
token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role,
|
||||
inviter=inviter)
|
||||
account = db.session.query(Account, TenantAccountJoin.role).join(
|
||||
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
|
||||
).filter(Account.email == args['email']).first()
|
||||
@@ -78,7 +79,16 @@ class MemberInviteEmailApi(Resource):
|
||||
|
||||
# todo:413
|
||||
|
||||
return {'result': 'success', 'account': account}, 201
|
||||
return {
|
||||
'result': 'success',
|
||||
'account': account,
|
||||
'invite_url': '{}/activate?workspace_id={}&email={}&token={}'.format(
|
||||
current_app.config.get("CONSOLE_WEB_URL"),
|
||||
str(current_user.current_tenant_id),
|
||||
invitee_email,
|
||||
token
|
||||
)
|
||||
}, 201
|
||||
|
||||
|
||||
class MemberCancelInviteApi(Resource):
|
||||
@@ -88,7 +98,7 @@ class MemberCancelInviteApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, member_id):
|
||||
member = Account.query.get(str(member_id))
|
||||
member = db.session.query(Account).filter(Account.id == str(member_id)).first()
|
||||
if not member:
|
||||
abort(404)
|
||||
|
||||
|
||||
61
api/extensions/ext_mail.py
Normal file
61
api/extensions/ext_mail.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from typing import Optional
|
||||
|
||||
import resend
|
||||
from flask import Flask
|
||||
|
||||
|
||||
class Mail:
|
||||
def __init__(self):
|
||||
self._client = None
|
||||
self._default_send_from = None
|
||||
|
||||
def is_inited(self) -> bool:
|
||||
return self._client is not None
|
||||
|
||||
def init_app(self, app: Flask):
|
||||
if app.config.get('MAIL_TYPE'):
|
||||
if app.config.get('MAIL_DEFAULT_SEND_FROM'):
|
||||
self._default_send_from = app.config.get('MAIL_DEFAULT_SEND_FROM')
|
||||
|
||||
if app.config.get('MAIL_TYPE') == 'resend':
|
||||
api_key = app.config.get('RESEND_API_KEY')
|
||||
if not api_key:
|
||||
raise ValueError('RESEND_API_KEY is not set')
|
||||
|
||||
resend.api_key = api_key
|
||||
self._client = resend.Emails
|
||||
else:
|
||||
raise ValueError('Unsupported mail type {}'.format(app.config.get('MAIL_TYPE')))
|
||||
|
||||
def send(self, to: str, subject: str, html: str, from_: Optional[str] = None):
|
||||
if not self._client:
|
||||
raise ValueError('Mail client is not initialized')
|
||||
|
||||
if not from_ and self._default_send_from:
|
||||
from_ = self._default_send_from
|
||||
|
||||
if not from_:
|
||||
raise ValueError('mail from is not set')
|
||||
|
||||
if not to:
|
||||
raise ValueError('mail to is not set')
|
||||
|
||||
if not subject:
|
||||
raise ValueError('mail subject is not set')
|
||||
|
||||
if not html:
|
||||
raise ValueError('mail html is not set')
|
||||
|
||||
self._client.send({
|
||||
"from": from_,
|
||||
"to": to,
|
||||
"subject": subject,
|
||||
"html": html
|
||||
})
|
||||
|
||||
|
||||
def init_app(app: Flask):
|
||||
mail.init_app(app)
|
||||
|
||||
|
||||
mail = Mail()
|
||||
@@ -38,6 +38,10 @@ class Account(UserMixin, db.Model):
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
|
||||
@property
|
||||
def is_password_set(self):
|
||||
return self.password is not None
|
||||
|
||||
@property
|
||||
def current_tenant(self):
|
||||
return self._current_tenant
|
||||
|
||||
@@ -56,7 +56,8 @@ class App(db.Model):
|
||||
|
||||
@property
|
||||
def api_base_url(self):
|
||||
return (current_app.config['API_URL'] if current_app.config['API_URL'] else request.host_url.rstrip('/')) + '/v1'
|
||||
return (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL']
|
||||
else request.host_url.rstrip('/')) + '/v1'
|
||||
|
||||
@property
|
||||
def tenant(self):
|
||||
@@ -515,7 +516,7 @@ class Site(db.Model):
|
||||
|
||||
@property
|
||||
def app_base_url(self):
|
||||
return (current_app.config['APP_URL'] if current_app.config['APP_URL'] else request.host_url.rstrip('/'))
|
||||
return (current_app.config['APP_WEB_URL'] if current_app.config['APP_WEB_URL'] else request.host_url.rstrip('/'))
|
||||
|
||||
|
||||
class ApiToken(db.Model):
|
||||
|
||||
@@ -21,7 +21,7 @@ Authlib==1.2.0
|
||||
boto3~=1.26.123
|
||||
tenacity==8.2.2
|
||||
cachetools~=5.3.0
|
||||
weaviate-client~=3.16.2
|
||||
weaviate-client~=3.21.0
|
||||
qdrant_client~=1.1.6
|
||||
mailchimp-transactional~=1.0.50
|
||||
scikit-learn==1.2.2
|
||||
@@ -33,4 +33,5 @@ openpyxl==3.1.2
|
||||
chardet~=5.1.0
|
||||
docx2txt==0.8
|
||||
pypdfium2==4.16.0
|
||||
pyjwt~=2.6.0
|
||||
resend~=0.5.1
|
||||
pyjwt~=2.6.0
|
||||
|
||||
@@ -2,13 +2,16 @@
|
||||
import base64
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from hashlib import sha256
|
||||
from typing import Optional
|
||||
|
||||
from flask import session
|
||||
from sqlalchemy import func
|
||||
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.errors.account import AccountLoginError, CurrentPasswordIncorrectError, LinkAccountIntegrateError, \
|
||||
TenantNotFound, AccountNotLinkTenantError, InvalidActionError, CannotOperateSelfError, MemberNotInTenantError, \
|
||||
RoleAlreadyAssignedError, NoPermissionError, AccountRegisterError, AccountAlreadyInTenantError
|
||||
@@ -16,6 +19,7 @@ from libs.helper import get_remote_ip
|
||||
from libs.password import compare_password, hash_password
|
||||
from libs.rsa import generate_key_pair
|
||||
from models.account import *
|
||||
from tasks.mail_invite_member_task import send_invite_member_mail_task
|
||||
|
||||
|
||||
class AccountService:
|
||||
@@ -48,12 +52,18 @@ class AccountService:
|
||||
@staticmethod
|
||||
def update_account_password(account, password, new_password):
|
||||
"""update account password"""
|
||||
# todo: split validation and update
|
||||
if account.password and not compare_password(password, account.password, account.password_salt):
|
||||
raise CurrentPasswordIncorrectError("Current password is incorrect.")
|
||||
password_hashed = hash_password(new_password, account.password_salt)
|
||||
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.commit()
|
||||
return account
|
||||
|
||||
@@ -283,8 +293,6 @@ class TenantService:
|
||||
@staticmethod
|
||||
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
|
||||
"""Remove member from tenant"""
|
||||
# todo: check permission
|
||||
|
||||
if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, 'remove'):
|
||||
raise CannotOperateSelfError("Cannot operate self.")
|
||||
|
||||
@@ -293,6 +301,12 @@ class TenantService:
|
||||
raise MemberNotInTenantError("Member not in tenant.")
|
||||
|
||||
db.session.delete(ta)
|
||||
|
||||
account.initialized_at = None
|
||||
account.status = AccountStatus.PENDING.value
|
||||
account.password = None
|
||||
account.password_salt = None
|
||||
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
@@ -332,8 +346,8 @@ class TenantService:
|
||||
|
||||
class RegisterService:
|
||||
|
||||
@staticmethod
|
||||
def register(email, name, password: str = None, open_id: str = None, provider: str = None) -> Account:
|
||||
@classmethod
|
||||
def register(cls, email, name, password: str = None, open_id: str = None, provider: str = None) -> Account:
|
||||
db.session.begin_nested()
|
||||
"""Register account"""
|
||||
try:
|
||||
@@ -359,9 +373,9 @@ class RegisterService:
|
||||
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def invite_new_member(tenant: Tenant, email: str, role: str = 'normal',
|
||||
inviter: Account = None) -> TenantAccountJoin:
|
||||
@classmethod
|
||||
def invite_new_member(cls, tenant: Tenant, email: str, role: str = 'normal',
|
||||
inviter: Account = None) -> str:
|
||||
"""Invite new member"""
|
||||
account = Account.query.filter_by(email=email).first()
|
||||
|
||||
@@ -380,5 +394,71 @@ class RegisterService:
|
||||
if ta:
|
||||
raise AccountAlreadyInTenantError("Account already in tenant.")
|
||||
|
||||
ta = TenantService.create_tenant_member(tenant, account, role)
|
||||
return ta
|
||||
TenantService.create_tenant_member(tenant, account, role)
|
||||
|
||||
token = cls.generate_invite_token(tenant, account)
|
||||
|
||||
# send email
|
||||
send_invite_member_mail_task.delay(
|
||||
to=email,
|
||||
token=cls.generate_invite_token(tenant, account),
|
||||
inviter_name=inviter.name if inviter else 'Dify',
|
||||
workspace_id=tenant.id,
|
||||
workspace_name=tenant.name,
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def generate_invite_token(cls, tenant: Tenant, account: Account) -> str:
|
||||
token = str(uuid.uuid4())
|
||||
email_hash = sha256(account.email.encode()).hexdigest()
|
||||
cache_key = 'member_invite_token:{}, {}:{}'.format(str(tenant.id), email_hash, token)
|
||||
redis_client.setex(cache_key, 3600, str(account.id))
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def revoke_token(cls, workspace_id: str, email: str, token: str):
|
||||
email_hash = sha256(email.encode()).hexdigest()
|
||||
cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token)
|
||||
redis_client.delete(cache_key)
|
||||
|
||||
@classmethod
|
||||
def get_account_if_token_valid(cls, workspace_id: str, email: str, token: str) -> Optional[Account]:
|
||||
tenant = db.session.query(Tenant).filter(
|
||||
Tenant.id == workspace_id,
|
||||
Tenant.status == 'normal'
|
||||
).first()
|
||||
|
||||
if not tenant:
|
||||
return None
|
||||
|
||||
tenant_account = db.session.query(Account, TenantAccountJoin.role).join(
|
||||
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
|
||||
).filter(Account.email == email, TenantAccountJoin.tenant_id == tenant.id).first()
|
||||
|
||||
if not tenant_account:
|
||||
return None
|
||||
|
||||
account_id = cls._get_account_id_by_invite_token(workspace_id, email, token)
|
||||
if not account_id:
|
||||
return None
|
||||
|
||||
account = tenant_account[0]
|
||||
if not account:
|
||||
return None
|
||||
|
||||
if account_id != str(account.id):
|
||||
return None
|
||||
|
||||
return account
|
||||
|
||||
@classmethod
|
||||
def _get_account_id_by_invite_token(cls, workspace_id: str, email: str, token: str) -> Optional[str]:
|
||||
email_hash = sha256(email.encode()).hexdigest()
|
||||
cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token)
|
||||
account_id = redis_client.get(cache_key)
|
||||
if not account_id:
|
||||
return None
|
||||
|
||||
return account_id.decode('utf-8')
|
||||
|
||||
52
api/tasks/mail_invite_member_task.py
Normal file
52
api/tasks/mail_invite_member_task.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from flask import current_app
|
||||
|
||||
from extensions.ext_mail import mail
|
||||
|
||||
|
||||
@shared_task
|
||||
def send_invite_member_mail_task(to: str, token: str, inviter_name: str, workspace_id: str, workspace_name: str):
|
||||
"""
|
||||
Async Send invite member mail
|
||||
:param to
|
||||
:param token
|
||||
:param inviter_name
|
||||
:param workspace_id
|
||||
:param workspace_name
|
||||
|
||||
Usage: send_invite_member_mail_task.delay(to, token, inviter_name, workspace_id, workspace_name)
|
||||
"""
|
||||
if not mail.is_inited():
|
||||
return
|
||||
|
||||
logging.info(click.style('Start send invite member mail to {} in workspace {}'.format(to, workspace_name),
|
||||
fg='green'))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
mail.send(
|
||||
to=to,
|
||||
subject="{} invited you to join {}".format(inviter_name, workspace_name),
|
||||
html="""<p>Hi there,</p>
|
||||
<p>{inviter_name} invited you to join {workspace_name}.</p>
|
||||
<p>Click <a href="{url}">here</a> to join.</p>
|
||||
<p>Thanks,</p>
|
||||
<p>Dify Team</p>""".format(inviter_name=inviter_name, workspace_name=workspace_name,
|
||||
url='{}/activate?workspace_id={}&email={}&token={}'.format(
|
||||
current_app.config.get("CONSOLE_WEB_URL"),
|
||||
workspace_id,
|
||||
to,
|
||||
token)
|
||||
)
|
||||
)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style('Send invite member mail to {} succeeded: latency: {}'.format(to, end_at - start_at),
|
||||
fg='green'))
|
||||
except Exception:
|
||||
logging.exception("Send invite member mail to {} failed".format(to))
|
||||
Reference in New Issue
Block a user