mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 03:16:51 +08:00
Initial commit
This commit is contained in:
2
api/services/__init__.py
Normal file
2
api/services/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import services.errors
|
||||
382
api/services/account_service.py
Normal file
382
api/services/account_service.py
Normal file
@@ -0,0 +1,382 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import base64
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from flask import session
|
||||
from sqlalchemy import func
|
||||
|
||||
from events.tenant_event import tenant_was_created
|
||||
from services.errors.account import AccountLoginError, CurrentPasswordIncorrectError, LinkAccountIntegrateError, \
|
||||
TenantNotFound, AccountNotLinkTenantError, InvalidActionError, CannotOperateSelfError, MemberNotInTenantError, \
|
||||
RoleAlreadyAssignedError, NoPermissionError, AccountRegisterError, AccountAlreadyInTenantError
|
||||
from libs.helper import get_remote_ip
|
||||
from libs.password import compare_password, hash_password
|
||||
from libs.rsa import generate_key_pair
|
||||
from models.account import *
|
||||
|
||||
|
||||
class AccountService:
    """Account-level operations: login, password and profile management."""

    @staticmethod
    def load_user(account_id: int) -> Account:
        # todo: used by flask_login
        pass

    @staticmethod
    def authenticate(email: str, password: str) -> Account:
        """Authenticate an account with email and password.

        Raises AccountLoginError when the credentials are invalid or the
        account is banned/closed.
        """
        account = Account.query.filter_by(email=email).first()
        if not account:
            raise AccountLoginError('Invalid email or password.')

        if account.status in (AccountStatus.BANNED.value, AccountStatus.CLOSED.value):
            raise AccountLoginError('Account is banned or closed.')

        # Verify the password BEFORE mutating any account state: the previous
        # version flipped PENDING -> ACTIVE and committed before the password
        # check, so a failed login attempt could activate a pending account.
        if account.password is None or not compare_password(password, account.password, account.password_salt):
            raise AccountLoginError('Invalid email or password.')

        if account.status == AccountStatus.PENDING.value:
            account.status = AccountStatus.ACTIVE.value
            account.initialized_at = datetime.utcnow()
            db.session.commit()

        return account

    @staticmethod
    def update_account_password(account, password, new_password):
        """Update account password after verifying the current one.

        Raises CurrentPasswordIncorrectError when `password` does not match.
        """
        # todo: split validation and update
        if account.password and not compare_password(password, account.password, account.password_salt):
            raise CurrentPasswordIncorrectError("Current password is incorrect.")
        # Re-use the existing salt; only the hash is replaced.
        password_hashed = hash_password(new_password, account.password_salt)
        base64_password_hashed = base64.b64encode(password_hashed).decode()
        account.password = base64_password_hashed
        db.session.commit()
        return account

    @staticmethod
    def create_account(email: str, name: str, password: str = None,
                       interface_language: str = 'en-US', interface_theme: str = 'light',
                       timezone: str = 'America/New_York', ) -> Account:
        """Create and persist a new account.

        When `password` is given it is stored as a salted hash; otherwise the
        account has no password (e.g. OAuth-only or invited accounts).
        """
        account = Account()
        account.email = email
        account.name = name

        if password:
            # generate a fresh random password salt
            salt = secrets.token_bytes(16)
            base64_salt = base64.b64encode(salt).decode()

            # hash the password with the salt
            password_hashed = hash_password(password, salt)
            base64_password_hashed = base64.b64encode(password_hashed).decode()

            account.password = base64_password_hashed
            account.password_salt = base64_salt

        account.interface_language = interface_language
        account.interface_theme = interface_theme

        # Simplified-Chinese UI defaults to the Shanghai timezone.
        if interface_language == 'zh-Hans':
            account.timezone = 'Asia/Shanghai'
        else:
            account.timezone = timezone

        db.session.add(account)
        db.session.commit()
        return account

    @staticmethod
    def link_account_integrate(provider: str, open_id: str, account: Account) -> None:
        """Link a third-party (provider/open_id) identity to the account.

        Raises LinkAccountIntegrateError when persisting the link fails.
        """
        try:
            # Query whether there is an existing binding record for the same provider
            account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(account_id=account.id,
                                                                                             provider=provider).first()

            if account_integrate:
                # If it exists, update the record
                account_integrate.open_id = open_id
                account_integrate.encrypted_token = ""  # todo
                account_integrate.updated_at = datetime.utcnow()
            else:
                # If it does not exist, create a new record
                account_integrate = AccountIntegrate(account_id=account.id, provider=provider, open_id=open_id,
                                                     encrypted_token="")
                db.session.add(account_integrate)

            db.session.commit()
            logging.info(f'Account {account.id} linked {provider} account {open_id}.')
        except Exception as e:
            logging.exception(f'Failed to link {provider} account {open_id} to Account {account.id}')
            raise LinkAccountIntegrateError('Failed to link account.') from e

    @staticmethod
    def close_account(account: Account) -> None:
        """todo: Close account"""
        account.status = AccountStatus.CLOSED.value
        db.session.commit()

    @staticmethod
    def update_account(account, **kwargs):
        """Update account fields given as keyword arguments.

        Raises AttributeError for any field name the Account model lacks.
        """
        for field, value in kwargs.items():
            if hasattr(account, field):
                setattr(account, field, value)
            else:
                raise AttributeError(f"Invalid field: {field}")

        db.session.commit()
        return account

    @staticmethod
    def update_last_login(account: Account, request) -> None:
        """Update last login time and ip"""
        account.last_login_at = datetime.utcnow()
        account.last_login_ip = get_remote_ip(request)
        db.session.add(account)
        db.session.commit()
        logging.info(f'Account {account.id} logged in successfully.')
||||
class TenantService:
    """Tenant (workspace) management: creation, membership and roles."""

    @staticmethod
    def create_tenant(name: str) -> Tenant:
        """Create a tenant and generate its RSA key pair."""
        tenant = Tenant(name=name)

        db.session.add(tenant)
        db.session.commit()

        # The key pair is keyed on the tenant id, so it can only be generated
        # after the first commit has assigned the id.
        tenant.encrypt_public_key = generate_key_pair(tenant.id)
        db.session.commit()
        return tenant

    @staticmethod
    def create_tenant_member(tenant: Tenant, account: Account, role: str = 'normal') -> TenantAccountJoin:
        """Add the account to the tenant with the given role.

        A tenant may have at most one owner; attempting to add a second
        raises Exception.
        """
        if role == TenantAccountJoinRole.OWNER.value:
            if TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER]):
                logging.error(f'Tenant {tenant.id} has already an owner.')
                raise Exception('Tenant already has an owner.')

        ta = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=role
        )
        db.session.add(ta)
        db.session.commit()
        return ta

    @staticmethod
    def get_join_tenants(account: Account) -> List[Tenant]:
        """Return all tenants the account has joined."""
        return db.session.query(Tenant).join(
            TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id
        ).filter(TenantAccountJoin.account_id == account.id).all()

    @staticmethod
    def get_current_tenant_by_account(account: Account):
        """Return the account's current tenant with its `.role` attribute set.

        Raises TenantNotFound when there is no current tenant or the account
        is not a member of it.
        """
        tenant = account.current_tenant
        if not tenant:
            raise TenantNotFound("Tenant not found.")

        ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
        if ta:
            tenant.role = ta.role
        else:
            raise TenantNotFound("Tenant not found for the account.")
        return tenant

    @staticmethod
    def switch_tenant(account: Account, tenant_id: int = None) -> None:
        """Switch the current workspace for the account.

        With no tenant_id, the account's first membership is used.
        Raises AccountNotLinkTenantError when no matching membership exists.
        """
        if not tenant_id:
            tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id).first()
        else:
            tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first()

        # Check if the tenant exists and the account is a member of the tenant
        if not tenant_account_join:
            raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")

        # Set the current tenant for the account and mirror it in the session
        account.current_tenant_id = tenant_account_join.tenant_id
        session['workspace_id'] = account.current_tenant.id

    @staticmethod
    def get_tenant_members(tenant: Tenant) -> List[Account]:
        """Return tenant members, each with its `.role` attribute set."""
        query = (
            db.session.query(Account, TenantAccountJoin.role)
            .select_from(Account)
            .join(
                TenantAccountJoin, Account.id == TenantAccountJoin.account_id
            )
            .filter(TenantAccountJoin.tenant_id == tenant.id)
        )

        updated_accounts = []
        for account, role in query:
            account.role = role
            updated_accounts.append(account)

        return updated_accounts

    @staticmethod
    def has_roles(tenant: Tenant, roles: List[TenantAccountJoinRole]) -> bool:
        """Check whether the tenant has any member holding one of `roles`."""
        if not all(isinstance(role, TenantAccountJoinRole) for role in roles):
            raise ValueError('all roles must be TenantAccountJoinRole')

        return db.session.query(TenantAccountJoin).filter(
            TenantAccountJoin.tenant_id == tenant.id,
            TenantAccountJoin.role.in_([role.value for role in roles])
        ).first() is not None

    @staticmethod
    def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoinRole]:
        """Return the account's role in the tenant, or None if not a member."""
        join = db.session.query(TenantAccountJoin).filter(
            TenantAccountJoin.tenant_id == tenant.id,
            TenantAccountJoin.account_id == account.id
        ).first()
        return join.role if join else None

    @staticmethod
    def get_tenant_count() -> int:
        """Get tenant count"""
        return db.session.query(func.count(Tenant.id)).scalar()

    @staticmethod
    def check_member_permission(tenant: Tenant, operator: Account, member: Account, action: str) -> None:
        """Validate that `operator` may perform `action` on `member`.

        Raises (never returns a truthy value — callers must not truth-test it):
        - InvalidActionError for unknown actions,
        - CannotOperateSelfError when operator == member,
        - NoPermissionError when the operator's role is insufficient.
        """
        perms = {
            'add': ['owner', 'admin'],
            'remove': ['owner'],
            'update': ['owner']
        }
        if action not in ['add', 'remove', 'update']:
            raise InvalidActionError("Invalid action.")

        if operator.id == member.id:
            raise CannotOperateSelfError("Cannot operate self.")

        ta_operator = TenantAccountJoin.query.filter_by(
            tenant_id=tenant.id,
            account_id=operator.id
        ).first()

        if not ta_operator or ta_operator.role not in perms[action]:
            raise NoPermissionError(f'No permission to {action} member.')

    @staticmethod
    def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
        """Remove a member from the tenant.

        Bug fix: the previous `operator.id == account.id and check_member_permission(...)`
        short-circuited, so the permission check never ran for non-self
        removals (and its None return made the condition always falsy).
        check_member_permission raises on failure and also covers the
        self-removal case via CannotOperateSelfError.
        """
        TenantService.check_member_permission(tenant, operator, account, 'remove')

        ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
        if not ta:
            raise MemberNotInTenantError("Member not in tenant.")

        db.session.delete(ta)
        db.session.commit()

    @staticmethod
    def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None:
        """Update a member's role; transferring ownership demotes the old owner."""
        TenantService.check_member_permission(tenant, operator, member, 'update')

        target_member_join = TenantAccountJoin.query.filter_by(
            tenant_id=tenant.id,
            account_id=member.id
        ).first()

        # Guard against a missing membership instead of an AttributeError.
        if not target_member_join:
            raise MemberNotInTenantError("Member not in tenant.")

        if target_member_join.role == new_role:
            raise RoleAlreadyAssignedError("The provided role is already assigned to the member.")

        if new_role == 'owner':
            # Find the current owner and change their role to 'admin'
            current_owner_join = TenantAccountJoin.query.filter_by(
                tenant_id=tenant.id,
                role='owner'
            ).first()
            if current_owner_join:
                current_owner_join.role = 'admin'

        # Update the role of the target member
        target_member_join.role = new_role
        db.session.commit()

    @staticmethod
    def dissolve_tenant(tenant: Tenant, operator: Account) -> None:
        """Dissolve the tenant: remove all memberships, then the tenant itself.

        Bug fix: the previous version called
        check_member_permission(tenant, operator, operator, 'remove'), which
        always raised CannotOperateSelfError (operator == member) and whose
        None return was then truth-tested — the method could never succeed.
        Enforce the owner role directly instead.
        """
        ta_operator = TenantAccountJoin.query.filter_by(
            tenant_id=tenant.id,
            account_id=operator.id
        ).first()
        if not ta_operator or ta_operator.role != 'owner':
            raise NoPermissionError('No permission to dissolve tenant.')
        db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
        db.session.delete(tenant)
        db.session.commit()
||||
class RegisterService:
    """Self-registration and member invitation flows."""

    @staticmethod
    def register(email, name, password: str = None, open_id: str = None, provider: str = None) -> Account:
        """Register an account with its own workspace.

        Creates the account (active immediately), optionally links an OAuth
        identity, creates a personal tenant and makes the account its owner.
        Raises AccountRegisterError on any failure.
        """
        # Note: the docstring previously sat *after* this statement, so it was
        # a no-op string expression rather than the function's docstring.
        db.session.begin_nested()
        try:
            account = AccountService.create_account(email, name, password)
            account.status = AccountStatus.ACTIVE.value
            account.initialized_at = datetime.utcnow()

            if open_id is not None or provider is not None:
                AccountService.link_account_integrate(provider, open_id, account)

            tenant = TenantService.create_tenant(f"{account.name}'s Workspace")

            TenantService.create_tenant_member(tenant, account, role='owner')
            account.current_tenant = tenant

            db.session.commit()
        except Exception as e:
            # todo: the intermediate helpers commit eagerly, so this rollback
            # cannot undo everything ("do not work" in the original).
            db.session.rollback()
            logging.error(f'Register failed: {e}')
            raise AccountRegisterError(f'Registration failed: {e}') from e

        tenant_was_created.send(tenant)

        return account

    @staticmethod
    def invite_new_member(tenant: Tenant, email: str, role: str = 'normal',
                          inviter: Account = None) -> TenantAccountJoin:
        """Invite an account (existing or new) into the tenant.

        Bug fix: the inviter's permission was previously only checked when the
        invited email already had an account, so anyone could invite brand-new
        emails; it also crashed when `inviter` was None. The check now runs on
        both paths whenever an inviter is given.
        """
        account = Account.query.filter_by(email=email).first()

        if not account:
            # Create a pending account for the new email.
            name = email.split('@')[0]
            account = AccountService.create_account(email, name)
            account.status = AccountStatus.PENDING.value
            db.session.commit()
        else:
            ta = TenantAccountJoin.query.filter_by(
                tenant_id=tenant.id,
                account_id=account.id
            ).first()
            if ta:
                raise AccountAlreadyInTenantError("Account already in tenant.")

        if inviter:
            TenantService.check_member_permission(tenant, inviter, account, 'add')

        ta = TenantService.create_tenant_member(tenant, account, role)
        return ta
292
api/services/app_model_config_service.py
Normal file
292
api/services/app_model_config_service.py
Normal file
@@ -0,0 +1,292 @@
|
||||
import re
|
||||
import uuid
|
||||
|
||||
from core.constant import llm_constant
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.errors.account import NoPermissionError
|
||||
|
||||
|
||||
class AppModelConfigService:
    """Validates and normalizes app model configuration dicts in place."""

    @staticmethod
    def is_dataset_exists(account: Account, dataset_id: str) -> bool:
        # verify that the dataset ID exists and belongs to the account's
        # current tenant
        dataset = DatasetService.get_dataset(dataset_id)

        if not dataset:
            return False

        if dataset.tenant_id != account.current_tenant_id:
            return False

        return True

    @staticmethod
    def validate_model_completion_params(cp: dict, model_name: str) -> dict:
        """Validate model.completion_params, filling in defaults, and return a
        dict containing only the recognized parameters.

        Raises ValueError on any out-of-range or mistyped value.
        """
        # 6. model.completion_params
        if not isinstance(cp, dict):
            raise ValueError("model.completion_params must be of object type")

        # max_tokens: default 512, bounded by the model's context length
        if 'max_tokens' not in cp:
            cp["max_tokens"] = 512

        if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \
                llm_constant.max_context_token_length[model_name]:
            raise ValueError(
                "max_tokens must be an integer greater than 0 and not exceeding the maximum value of the corresponding model")

        # temperature: default 1, range [0, 2]
        if 'temperature' not in cp:
            cp["temperature"] = 1

        if not isinstance(cp["temperature"], (float, int)) or cp["temperature"] < 0 or cp["temperature"] > 2:
            raise ValueError("temperature must be a float between 0 and 2")

        # top_p: default 1, range [0, 2]
        if 'top_p' not in cp:
            cp["top_p"] = 1

        if not isinstance(cp["top_p"], (float, int)) or cp["top_p"] < 0 or cp["top_p"] > 2:
            raise ValueError("top_p must be a float between 0 and 2")

        # presence_penalty: default 0, range [-2, 2]
        if 'presence_penalty' not in cp:
            cp["presence_penalty"] = 0

        if not isinstance(cp["presence_penalty"], (float, int)) or cp["presence_penalty"] < -2 or cp["presence_penalty"] > 2:
            raise ValueError("presence_penalty must be a float between -2 and 2")

        # frequency_penalty: default 0, range [-2, 2]
        # (the original comment here wrongly said "presence_penalty")
        if 'frequency_penalty' not in cp:
            cp["frequency_penalty"] = 0

        if not isinstance(cp["frequency_penalty"], (float, int)) or cp["frequency_penalty"] < -2 or cp["frequency_penalty"] > 2:
            raise ValueError("frequency_penalty must be a float between -2 and 2")

        # Filter out extra parameters
        filtered_cp = {
            "max_tokens": cp["max_tokens"],
            "temperature": cp["temperature"],
            "top_p": cp["top_p"],
            "presence_penalty": cp["presence_penalty"],
            "frequency_penalty": cp["frequency_penalty"]
        }

        return filtered_cp

    @staticmethod
    def validate_configuration(account: Account, config: dict, mode: str) -> dict:
        """Validate a full app model configuration.

        Mutates `config` in place to fill defaults, validates every section,
        and returns a filtered copy containing only the recognized keys.
        Raises ValueError on the first invalid field.
        """
        # opening_statement: optional string, defaults to ""
        if 'opening_statement' not in config or not config["opening_statement"]:
            config["opening_statement"] = ""

        if not isinstance(config["opening_statement"], str):
            raise ValueError("opening_statement must be of string type")

        # suggested_questions: optional list of strings, defaults to []
        if 'suggested_questions' not in config or not config["suggested_questions"]:
            config["suggested_questions"] = []

        if not isinstance(config["suggested_questions"], list):
            raise ValueError("suggested_questions must be of list type")

        for question in config["suggested_questions"]:
            if not isinstance(question, str):
                raise ValueError("Elements in suggested_questions list must be of string type")

        # suggested_questions_after_answer: {"enabled": bool}, default disabled
        if 'suggested_questions_after_answer' not in config or not config["suggested_questions_after_answer"]:
            config["suggested_questions_after_answer"] = {
                "enabled": False
            }

        if not isinstance(config["suggested_questions_after_answer"], dict):
            raise ValueError("suggested_questions_after_answer must be of dict type")

        if "enabled" not in config["suggested_questions_after_answer"] or not config["suggested_questions_after_answer"]["enabled"]:
            config["suggested_questions_after_answer"]["enabled"] = False

        if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):
            raise ValueError("enabled in suggested_questions_after_answer must be of boolean type")

        # more_like_this: {"enabled": bool}, default disabled
        if 'more_like_this' not in config or not config["more_like_this"]:
            config["more_like_this"] = {
                "enabled": False
            }

        if not isinstance(config["more_like_this"], dict):
            raise ValueError("more_like_this must be of dict type")

        if "enabled" not in config["more_like_this"] or not config["more_like_this"]["enabled"]:
            config["more_like_this"]["enabled"] = False

        if not isinstance(config["more_like_this"]["enabled"], bool):
            raise ValueError("enabled in more_like_this must be of boolean type")

        # model: required object
        if 'model' not in config:
            raise ValueError("model is required")

        if not isinstance(config["model"], dict):
            raise ValueError("model must be of object type")

        # model.provider: only "openai" is supported here
        if 'provider' not in config["model"] or config["model"]["provider"] != "openai":
            raise ValueError("model.provider must be 'openai'")

        # model.name: must be valid for the app mode (chat/completion)
        if 'name' not in config["model"]:
            raise ValueError("model.name is required")

        if config["model"]["name"] not in llm_constant.models_by_mode[mode]:
            raise ValueError("model.name must be in the specified model list")

        # model.completion_params: delegated to the dedicated validator
        if 'completion_params' not in config["model"]:
            raise ValueError("model.completion_params is required")

        config["model"]["completion_params"] = AppModelConfigService.validate_model_completion_params(
            config["model"]["completion_params"],
            config["model"]["name"]
        )

        # user_input_form: optional list of single-key form-item objects
        if "user_input_form" not in config or not config["user_input_form"]:
            config["user_input_form"] = []

        if not isinstance(config["user_input_form"], list):
            raise ValueError("user_input_form must be a list of objects")

        variables = []
        for item in config["user_input_form"]:
            key = list(item.keys())[0]
            if key not in ["text-input", "select"]:
                raise ValueError("Keys in user_input_form list can only be 'text-input' or 'select'")

            form_item = item[key]
            if 'label' not in form_item:
                raise ValueError("label is required in user_input_form")

            if not isinstance(form_item["label"], str):
                raise ValueError("label in user_input_form must be of string type")

            if 'variable' not in form_item:
                raise ValueError("variable is required in user_input_form")

            if not isinstance(form_item["variable"], str):
                raise ValueError("variable in user_input_form must be of string type")

            # variable names: CJK/latin/digits/underscore/emoji, max 100 chars,
            # may not start with a digit
            pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
            if pattern.match(form_item["variable"]) is None:
                raise ValueError("variable in user_input_form must be a string, "
                                 "and cannot start with a number")

            variables.append(form_item["variable"])

            if 'required' not in form_item or not form_item["required"]:
                form_item["required"] = False

            if not isinstance(form_item["required"], bool):
                raise ValueError("required in user_input_form must be of boolean type")

            if key == "select":
                if 'options' not in form_item or not form_item["options"]:
                    form_item["options"] = []

                if not isinstance(form_item["options"], list):
                    raise ValueError("options in user_input_form must be a list of strings")

                if "default" in form_item and form_item['default'] \
                        and form_item["default"] not in form_item["options"]:
                    raise ValueError("default value in user_input_form must be in the options list")

        # pre_prompt: optional template; its {{variables}} must be declared above
        if "pre_prompt" not in config or not config["pre_prompt"]:
            config["pre_prompt"] = ""

        if not isinstance(config["pre_prompt"], str):
            raise ValueError("pre_prompt must be of string type")

        template_vars = re.findall(r"\{\{(\w+)\}\}", config["pre_prompt"])
        for var in template_vars:
            if var not in variables:
                raise ValueError("Template variables in pre_prompt must be defined in user_input_form")

        # agent_mode: {"enabled": bool, "tools": [...]}, default disabled
        if "agent_mode" not in config or not config["agent_mode"]:
            config["agent_mode"] = {
                "enabled": False,
                "tools": []
            }

        if not isinstance(config["agent_mode"], dict):
            raise ValueError("agent_mode must be of object type")

        if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]:
            config["agent_mode"]["enabled"] = False

        if not isinstance(config["agent_mode"]["enabled"], bool):
            raise ValueError("enabled in agent_mode must be of boolean type")

        if "tools" not in config["agent_mode"] or not config["agent_mode"]["tools"]:
            config["agent_mode"]["tools"] = []

        if not isinstance(config["agent_mode"]["tools"], list):
            raise ValueError("tools in agent_mode must be a list of objects")

        for tool in config["agent_mode"]["tools"]:
            key = list(tool.keys())[0]
            if key not in ["sensitive-word-avoidance", "dataset"]:
                raise ValueError("Keys in agent_mode.tools list can only be 'sensitive-word-avoidance' or 'dataset'")

            tool_item = tool[key]

            if "enabled" not in tool_item or not tool_item["enabled"]:
                tool_item["enabled"] = False

            if not isinstance(tool_item["enabled"], bool):
                raise ValueError("enabled in agent_mode.tools must be of boolean type")

            if key == "sensitive-word-avoidance":
                if "words" not in tool_item or not tool_item["words"]:
                    tool_item["words"] = ""

                if not isinstance(tool_item["words"], str):
                    raise ValueError("words in sensitive-word-avoidance must be of string type")

                if "canned_response" not in tool_item or not tool_item["canned_response"]:
                    tool_item["canned_response"] = ""

                if not isinstance(tool_item["canned_response"], str):
                    raise ValueError("canned_response in sensitive-word-avoidance must be of string type")
            elif key == "dataset":
                if 'id' not in tool_item:
                    raise ValueError("id is required in dataset")

                try:
                    uuid.UUID(tool_item["id"])
                except ValueError:
                    raise ValueError("id in dataset must be of UUID type")

                # the dataset must exist and belong to the account's tenant
                if not AppModelConfigService.is_dataset_exists(account, tool_item["id"]):
                    raise ValueError("Dataset ID does not exist, please check your permission.")

        # Filter out extra parameters
        filtered_config = {
            "opening_statement": config["opening_statement"],
            "suggested_questions": config["suggested_questions"],
            "suggested_questions_after_answer": config["suggested_questions_after_answer"],
            "more_like_this": config["more_like_this"],
            "model": {
                "provider": config["model"]["provider"],
                "name": config["model"]["name"],
                "completion_params": config["model"]["completion_params"]
            },
            "user_input_form": config["user_input_form"],
            "pre_prompt": config["pre_prompt"],
            "agent_mode": config["agent_mode"]
        }

        return filtered_config
||||
506
api/services/completion_service.py
Normal file
506
api/services/completion_service.py
Normal file
@@ -0,0 +1,506 @@
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Generator, Union, Any
|
||||
|
||||
from flask import current_app, Flask
|
||||
from redis.client import PubSub
|
||||
from sqlalchemy import and_
|
||||
|
||||
from core.completion import Completion
|
||||
from core.conversation_message_task import PubHandler, ConversationTaskStoppedException
|
||||
from core.llm.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, \
|
||||
LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.completion import CompletionStoppedError
|
||||
from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
|
||||
class CompletionService:
|
||||
|
||||
@classmethod
|
||||
def completion(cls, app_model: App, user: Union[Account | EndUser], args: Any,
|
||||
from_source: str, streaming: bool = True,
|
||||
is_model_config_override: bool = False) -> Union[dict | Generator]:
|
||||
# is streaming mode
|
||||
inputs = args['inputs']
|
||||
query = args['query']
|
||||
conversation_id = args['conversation_id'] if 'conversation_id' in args else None
|
||||
|
||||
conversation = None
|
||||
if conversation_id:
|
||||
conversation_filter = [
|
||||
Conversation.id == args['conversation_id'],
|
||||
Conversation.app_id == app_model.id,
|
||||
Conversation.status == 'normal'
|
||||
]
|
||||
|
||||
if from_source == 'console':
|
||||
conversation_filter.append(Conversation.from_account_id == user.id)
|
||||
else:
|
||||
conversation_filter.append(Conversation.from_end_user_id == user.id if user else None)
|
||||
|
||||
conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first()
|
||||
|
||||
if not conversation:
|
||||
raise ConversationNotExistsError()
|
||||
|
||||
if conversation.status != 'normal':
|
||||
raise ConversationCompletedError()
|
||||
|
||||
if not conversation.override_model_configs:
|
||||
app_model_config = db.session.query(AppModelConfig).get(conversation.app_model_config_id)
|
||||
|
||||
if not app_model_config:
|
||||
raise AppModelConfigBrokenError()
|
||||
else:
|
||||
conversation_override_model_configs = json.loads(conversation.override_model_configs)
|
||||
app_model_config = AppModelConfig(
|
||||
id=conversation.app_model_config_id,
|
||||
app_id=app_model.id,
|
||||
provider="",
|
||||
model_id="",
|
||||
configs="",
|
||||
opening_statement=conversation_override_model_configs['opening_statement'],
|
||||
suggested_questions=json.dumps(conversation_override_model_configs['suggested_questions']),
|
||||
model=json.dumps(conversation_override_model_configs['model']),
|
||||
user_input_form=json.dumps(conversation_override_model_configs['user_input_form']),
|
||||
pre_prompt=conversation_override_model_configs['pre_prompt'],
|
||||
agent_mode=json.dumps(conversation_override_model_configs['agent_mode']),
|
||||
)
|
||||
|
||||
if is_model_config_override:
|
||||
# build new app model config
|
||||
if 'model' not in args['model_config']:
|
||||
raise ValueError('model_config.model is required')
|
||||
|
||||
if 'completion_params' not in args['model_config']['model']:
|
||||
raise ValueError('model_config.model.completion_params is required')
|
||||
|
||||
completion_params = AppModelConfigService.validate_model_completion_params(
|
||||
cp=args['model_config']['model']['completion_params'],
|
||||
model_name=app_model_config.model_dict["name"]
|
||||
)
|
||||
|
||||
app_model_config_model = app_model_config.model_dict
|
||||
app_model_config_model['completion_params'] = completion_params
|
||||
|
||||
app_model_config = AppModelConfig(
|
||||
id=app_model_config.id,
|
||||
app_id=app_model.id,
|
||||
provider="",
|
||||
model_id="",
|
||||
configs="",
|
||||
opening_statement=app_model_config.opening_statement,
|
||||
suggested_questions=app_model_config.suggested_questions,
|
||||
model=json.dumps(app_model_config_model),
|
||||
user_input_form=app_model_config.user_input_form,
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
agent_mode=app_model_config.agent_mode,
|
||||
)
|
||||
else:
|
||||
if app_model.app_model_config_id is None:
|
||||
raise AppModelConfigBrokenError()
|
||||
|
||||
app_model_config = app_model.app_model_config
|
||||
|
||||
if not app_model_config:
|
||||
raise AppModelConfigBrokenError()
|
||||
|
||||
if is_model_config_override:
|
||||
if not isinstance(user, Account):
|
||||
raise Exception("Only account can override model config")
|
||||
|
||||
# validate config
|
||||
model_config = AppModelConfigService.validate_configuration(
|
||||
account=user,
|
||||
config=args['model_config'],
|
||||
mode=app_model.mode
|
||||
)
|
||||
|
||||
app_model_config = AppModelConfig(
|
||||
id=app_model_config.id,
|
||||
app_id=app_model.id,
|
||||
provider="",
|
||||
model_id="",
|
||||
configs="",
|
||||
opening_statement=model_config['opening_statement'],
|
||||
suggested_questions=json.dumps(model_config['suggested_questions']),
|
||||
suggested_questions_after_answer=json.dumps(model_config['suggested_questions_after_answer']),
|
||||
more_like_this=json.dumps(model_config['more_like_this']),
|
||||
model=json.dumps(model_config['model']),
|
||||
user_input_form=json.dumps(model_config['user_input_form']),
|
||||
pre_prompt=model_config['pre_prompt'],
|
||||
agent_mode=json.dumps(model_config['agent_mode']),
|
||||
)
|
||||
|
||||
# clean input by app_model_config form rules
|
||||
inputs = cls.get_cleaned_inputs(inputs, app_model_config)
|
||||
|
||||
generate_task_id = str(uuid.uuid4())
|
||||
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.subscribe(PubHandler.generate_channel_name(user, generate_task_id))
|
||||
|
||||
user = cls.get_real_user_instead_of_proxy_obj(user)
|
||||
|
||||
generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'generate_task_id': generate_task_id,
|
||||
'app_model': app_model,
|
||||
'app_model_config': app_model_config,
|
||||
'query': query,
|
||||
'inputs': inputs,
|
||||
'user': user,
|
||||
'conversation': conversation,
|
||||
'streaming': streaming,
|
||||
'is_model_config_override': is_model_config_override
|
||||
})
|
||||
|
||||
generate_worker_thread.start()
|
||||
|
||||
# wait for 5 minutes to close the thread
|
||||
cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
|
||||
|
||||
return cls.compact_response(pubsub, streaming)
|
||||
|
||||
@classmethod
def get_real_user_instead_of_proxy_obj(cls, user: Union[Account, EndUser]):
    """Reload *user* from the database and return a session-bound instance.

    The object handed to worker threads may be a detached or proxy object
    (e.g. flask-login's current_user); fetching it again by primary key
    yields a real model instance usable in this session.

    :raises Exception: when *user* is neither an Account nor an EndUser.
    """
    for model_cls in (Account, EndUser):
        if isinstance(user, model_cls):
            return db.session.query(model_cls).get(user.id)

    raise Exception("Unknown user type")
@classmethod
def generate_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App, app_model_config: AppModelConfig,
                    query: str, inputs: dict, user: Union[Account, EndUser],
                    conversation: Conversation, streaming: bool, is_model_config_override: bool):
    """Background-thread entry point that runs one completion task.

    Runs off the request thread, so it pushes its own Flask app context.
    It never re-raises: results and errors are published to the task's
    pubsub channel via PubHandler instead.
    """
    with flask_app.app_context():
        try:
            if conversation:
                # Re-fetch the conversation: the instance passed in was loaded by
                # the request thread's session and is detached in this thread.
                conversation = db.session.query(Conversation).filter_by(id=conversation.id).first()

            # run
            Completion.generate(
                task_id=generate_task_id,
                app=app_model,
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
                user=user,
                conversation=conversation,
                streaming=streaming,
                is_override=is_model_config_override,
            )
        except ConversationTaskStoppedException:
            # The client explicitly stopped the task; nothing to publish.
            pass
        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
                LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
                ModelCurrentlyNotSupportError) as e:
            # Known provider/quota failures: roll back and forward them as-is.
            db.session.rollback()
            PubHandler.pub_error(user, generate_task_id, e)
        except LLMAuthorizationError:
            # Replace the original message so provider API-key details never leak.
            db.session.rollback()
            PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
        except Exception as e:
            db.session.rollback()
            logging.exception("Unknown Error in completion")
            PubHandler.pub_error(user, generate_task_id, e)
@classmethod
def countdown_and_close(cls, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
    """Start a watchdog that stops the worker if it runs longer than 5 minutes.

    The watchdog polls once per second; if *worker_thread* is still alive when
    the timeout elapses, it signals the task to stop via PubHandler and closes
    the pubsub connection best-effort.

    :return: the started watchdog thread.
    """
    # wait for 5 minutes to close the thread
    timeout = 300

    def close_pubsub():
        sleep_iterations = 0
        while sleep_iterations < timeout and worker_thread.is_alive():
            time.sleep(1)
            sleep_iterations += 1

        if worker_thread.is_alive():
            PubHandler.stop(user, generate_task_id)
            try:
                pubsub.close()
            except Exception:
                # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
                # are not swallowed; closing an already-closed connection is fine.
                pass

    countdown_thread = threading.Thread(target=close_pubsub)
    countdown_thread.start()

    return countdown_thread
@classmethod
def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
                            message_id: str, streaming: bool = True) -> Union[dict, Generator]:
    """Regenerate an alternative answer for an existing message ("more like this").

    Validates ownership of the message, that the app has the feature enabled,
    then runs generation on a worker thread and returns the response drained
    from the task's pubsub channel (dict or SSE generator).

    :raises ValueError: when *user* is falsy.
    :raises MessageNotExistsError: when the message is not visible to this user.
    :raises MoreLikeThisDisabledError: when the feature is off for the app.
    :raises AppModelConfigBrokenError: when no usable model config exists.
    """
    if not user:
        raise ValueError('user cannot be None')

    # The message must belong to this app AND to the requesting user
    # (end-users come from the 'api' source, console accounts from 'console').
    message = db.session.query(Message).filter(
        Message.id == message_id,
        Message.app_id == app_model.id,
        Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
        Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
        Message.from_account_id == (user.id if isinstance(user, Account) else None),
    ).first()

    if not message:
        raise MessageNotExistsError()

    current_app_model_config = app_model.app_model_config
    more_like_this = current_app_model_config.more_like_this_dict

    if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
        raise MoreLikeThisDisabledError()

    app_model_config = message.app_model_config

    # Prefer the per-message config override recorded at generation time.
    if message.override_model_configs:
        override_model_configs = json.loads(message.override_model_configs)
        pre_prompt = override_model_configs.get("pre_prompt", '')
    elif app_model_config:
        pre_prompt = app_model_config.pre_prompt
    else:
        raise AppModelConfigBrokenError()

    generate_task_id = str(uuid.uuid4())

    # Subscribe BEFORE starting the worker so no published event is missed.
    pubsub = redis_client.pubsub()
    pubsub.subscribe(PubHandler.generate_channel_name(user, generate_task_id))

    user = cls.get_real_user_instead_of_proxy_obj(user)

    generate_worker_thread = threading.Thread(target=cls.generate_more_like_this_worker, kwargs={
        'flask_app': current_app._get_current_object(),
        'generate_task_id': generate_task_id,
        'app_model': app_model,
        'app_model_config': app_model_config,
        'message': message,
        'pre_prompt': pre_prompt,
        'user': user,
        'streaming': streaming
    })

    generate_worker_thread.start()

    # Watchdog: stop the worker after 5 minutes if it has not finished.
    cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)

    return cls.compact_response(pubsub, streaming)
@classmethod
def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App,
                                   app_model_config: AppModelConfig, message: Message, pre_prompt: str,
                                   user: Union[Account, EndUser], streaming: bool):
    """Background-thread entry point for "more like this" generation.

    Mirrors generate_worker: runs inside its own app context and publishes
    results/errors to the task's pubsub channel instead of raising.
    """
    with flask_app.app_context():
        try:
            # run
            Completion.generate_more_like_this(
                task_id=generate_task_id,
                app=app_model,
                user=user,
                message=message,
                pre_prompt=pre_prompt,
                app_model_config=app_model_config,
                streaming=streaming
            )
        except ConversationTaskStoppedException:
            # The client explicitly stopped the task; nothing to publish.
            pass
        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
                LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
                ModelCurrentlyNotSupportError) as e:
            # Known provider/quota failures: roll back and forward them as-is.
            db.session.rollback()
            PubHandler.pub_error(user, generate_task_id, e)
        except LLMAuthorizationError:
            # Replace the original message so provider API-key details never leak.
            db.session.rollback()
            PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
        except Exception as e:
            db.session.rollback()
            logging.exception("Unknown Error in completion")
            PubHandler.pub_error(user, generate_task_id, e)
@classmethod
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: "AppModelConfig"):
    """Validate and normalize user inputs against the app's input-form config.

    For each configured form variable: missing/empty values either raise (when
    required) or fall back to the declared default (else ""); 'select' fields
    must match one of the configured options; other field types are length-
    checked against 'max_length' when configured.

    :param user_inputs: raw inputs from the request; None is treated as {}.
    :param app_model_config: config object exposing `user_input_form_list`.
    :return: dict containing exactly the configured variables.
    :raises ValueError: on a missing required field, an invalid select option,
        or a value exceeding its max_length.
    """
    if user_inputs is None:
        user_inputs = {}

    filtered_inputs = {}

    # Filter input variables from form configuration, handle required fields,
    # default values, and option values.
    input_form_config = app_model_config.user_input_form_list
    for config in input_form_config:
        # Each entry is a single-item dict: {field_type: field_config}.
        input_config = list(config.values())[0]
        variable = input_config["variable"]

        input_type = list(config.keys())[0]

        if variable not in user_inputs or not user_inputs[variable]:
            if "required" in input_config and input_config["required"]:
                raise ValueError(f"{variable} is required in input form")
            else:
                filtered_inputs[variable] = input_config["default"] if "default" in input_config else ""
                continue

        value = user_inputs[variable]

        if input_type == "select":
            options = input_config["options"] if "options" in input_config else []
            if value not in options:
                raise ValueError(f"{variable} in input form must be one of the following: {options}")
        else:
            # BUGFIX: the original tested `'max_length' in variable` (the
            # variable NAME, a str) and indexed the string, so the length
            # limit was never enforced. The limit lives in input_config.
            if 'max_length' in input_config:
                max_length = input_config['max_length']
                if len(value) > max_length:
                    raise ValueError(f'{variable} in input form must be less than {max_length} characters')

        filtered_inputs[variable] = value

    return filtered_inputs
@classmethod
def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict, Generator]:
    """Drain the task's pubsub channel into a client-facing response.

    Blocking (non-streaming) mode returns the first published message payload
    as a dict; streaming mode returns a generator yielding SSE "data:" lines
    until an 'end' event arrives. Error payloads are re-raised via
    handle_error(). In both modes the channel is unsubscribed on exit.
    """
    generate_channel = list(pubsub.channels.keys())[0].decode('utf-8')
    if not streaming:
        try:
            for message in pubsub.listen():
                if message["type"] == "message":
                    result = message["data"].decode('utf-8')
                    result = json.loads(result)
                    if result.get('error'):
                        cls.handle_error(result)

                    # Non-streaming: the first data message is the whole answer.
                    return cls.get_message_response_data(result.get('data'))
        except ValueError as e:
            if e.args[0] != "I/O operation on closed file.":  # ignore this error
                raise CompletionStoppedError()
            else:
                # NOTE(review): this branch logs and re-raises the "closed file"
                # ValueError even though the comment above says to ignore it,
                # while other ValueErrors become CompletionStoppedError. The
                # streaming branch below does the opposite — confirm intent.
                logging.exception(e)
                raise
        finally:
            try:
                pubsub.unsubscribe(generate_channel)
            except ConnectionError:
                # Connection already gone; nothing left to unsubscribe.
                pass
    else:
        def generate() -> Generator:
            # Streaming: forward each event as an SSE "data:" frame.
            try:
                for message in pubsub.listen():
                    if message["type"] == "message":
                        result = message["data"].decode('utf-8')
                        result = json.loads(result)
                        if result.get('error'):
                            cls.handle_error(result)

                        event = result.get('event')
                        if event == "end":
                            logging.debug("{} finished".format(generate_channel))
                            break

                        if event == 'message':
                            yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
                        elif event == 'chain':
                            yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
                        elif event == 'agent_thought':
                            yield "data: " + json.dumps(cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
            except ValueError as e:
                if e.args[0] != "I/O operation on closed file.":  # ignore this error
                    logging.exception(e)
                    raise
            finally:
                try:
                    pubsub.unsubscribe(generate_channel)
                except ConnectionError:
                    pass

        return generate()
@classmethod
def get_message_response_data(cls, data: dict):
    """Shape a raw 'message' pubsub payload into the client-facing event dict.

    Adds a conversation_id field only for chat-mode apps.
    """
    payload = dict(
        event='message',
        task_id=data.get('task_id'),
        id=data.get('message_id'),
        answer=data.get('text'),
        created_at=int(time.time()),
    )

    if data.get('mode') == 'chat':
        payload['conversation_id'] = data.get('conversation_id')

    return payload
@classmethod
def get_chain_response_data(cls, data: dict):
    """Shape a raw 'chain' pubsub payload into the client-facing event dict.

    Adds a conversation_id field only for chat-mode apps.
    """
    payload = {'event': 'chain'}
    payload['id'] = data.get('chain_id')
    payload['task_id'] = data.get('task_id')
    payload['message_id'] = data.get('message_id')
    payload['type'] = data.get('type')
    payload['input'] = data.get('input')
    payload['output'] = data.get('output')
    payload['created_at'] = int(time.time())

    if data.get('mode') == 'chat':
        payload['conversation_id'] = data.get('conversation_id')

    return payload
@classmethod
def get_agent_thought_response_data(cls, data: dict):
    """Shape a raw 'agent_thought' pubsub payload into the client-facing event dict.

    The answer field is blanked while a thought is present; conversation_id
    is added only for chat-mode apps.
    """
    thought = data.get('thought')
    payload = {
        'event': 'agent_thought',
        'id': data.get('agent_thought_id'),
        'chain_id': data.get('chain_id'),
        'task_id': data.get('task_id'),
        'message_id': data.get('message_id'),
        'position': data.get('position'),
        'thought': thought,
        'tool': data.get('tool'),  # TODO: replace with the real dataset object
        'tool_input': data.get('tool_input'),
        'observation': data.get('observation'),
        'answer': '' if thought else data.get('answer'),
        'created_at': int(time.time()),
    }

    if data.get('mode') == 'chat':
        payload['conversation_id'] = data.get('conversation_id')

    return payload
@classmethod
def handle_error(cls, result: dict):
    """Translate a pubsub error payload back into an exception and raise it.

    Known provider errors are re-raised with their original description;
    authorization failures are rewritten so API-key details never reach the
    client; anything unrecognized becomes a plain Exception.
    """
    logging.debug("error: %s", result)
    error = result.get('error')
    description = result.get('description')

    # Serialized error name -> concrete exception class.
    llm_errors = {
        'LLMBadRequestError': LLMBadRequestError,
        'LLMAPIConnectionError': LLMAPIConnectionError,
        'LLMAPIUnavailableError': LLMAPIUnavailableError,
        'LLMRateLimitError': LLMRateLimitError,
        'ProviderTokenNotInitError': ProviderTokenNotInitError,
        'QuotaExceededError': QuotaExceededError,
        'ModelCurrentlyNotSupportError': ModelCurrentlyNotSupportError
    }

    exc_class = llm_errors.get(error)
    if exc_class is not None:
        raise exc_class(description)
    if error == 'LLMAuthorizationError':
        raise LLMAuthorizationError('Incorrect API key provided')
    raise Exception(description)
94
api/services/conversation_service.py
Normal file
94
api/services/conversation_service.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from typing import Union, Optional
|
||||
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import Conversation, App, EndUser
|
||||
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
|
||||
|
||||
|
||||
class ConversationService:
    """Read/rename/delete operations on an app's conversations, always scoped
    to the app and to the requesting user (console Account or API EndUser)."""

    @classmethod
    def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
                              last_id: Optional[str], limit: int,
                              include_ids: Optional[list] = None, exclude_ids: Optional[list] = None) -> InfiniteScrollPagination:
        """Cursor-style pagination of conversations, newest first.

        :param last_id: id of the last conversation on the previous page; None for page one.
        :param include_ids: when given, restrict results to these conversation ids.
        :param exclude_ids: when given, drop these conversation ids.
        :raises LastConversationNotExistsError: if last_id matches no visible conversation.
        """
        if not user:
            # Anonymous callers see an empty page rather than an error.
            return InfiniteScrollPagination(data=[], limit=limit, has_more=False)

        # Scope to this app and this user; end-users come from the 'api'
        # source, console accounts from 'console'.
        base_query = db.session.query(Conversation).filter(
            Conversation.app_id == app_model.id,
            Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'),
            Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
            Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
        )

        if include_ids is not None:
            base_query = base_query.filter(Conversation.id.in_(include_ids))

        if exclude_ids is not None:
            base_query = base_query.filter(~Conversation.id.in_(exclude_ids))

        if last_id:
            last_conversation = base_query.filter(
                Conversation.id == last_id,
            ).first()

            if not last_conversation:
                raise LastConversationNotExistsError()

            # Everything strictly older than the cursor row.
            conversations = base_query.filter(
                Conversation.created_at < last_conversation.created_at,
                Conversation.id != last_conversation.id
            ).order_by(Conversation.created_at.desc()).limit(limit).all()
        else:
            conversations = base_query.order_by(Conversation.created_at.desc()).limit(limit).all()

        has_more = False
        if len(conversations) == limit:
            # Full page: run a count to see whether anything older remains.
            current_page_first_conversation = conversations[-1]
            rest_count = base_query.filter(
                Conversation.created_at < current_page_first_conversation.created_at,
                Conversation.id != current_page_first_conversation.id
            ).count()

            if rest_count > 0:
                has_more = True

        return InfiniteScrollPagination(
            data=conversations,
            limit=limit,
            has_more=has_more
        )

    @classmethod
    def rename(cls, app_model: App, conversation_id: str,
               user: Optional[Union[Account, EndUser]], name: str):
        """Rename a conversation the user owns and return the updated row."""
        conversation = cls.get_conversation(app_model, conversation_id, user)

        conversation.name = name
        db.session.commit()

        return conversation

    @classmethod
    def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
        """Fetch one conversation scoped to app and user.

        :raises ConversationNotExistsError: when no matching row is visible.
        """
        conversation = db.session.query(Conversation) \
            .filter(
            Conversation.id == conversation_id,
            Conversation.app_id == app_model.id,
            Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'),
            Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
            Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
        ).first()

        if not conversation:
            raise ConversationNotExistsError()

        return conversation

    @classmethod
    def delete(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
        """Hard-delete a conversation the user owns."""
        conversation = cls.get_conversation(app_model, conversation_id, user)

        db.session.delete(conversation)
        db.session.commit()
521
api/services/dataset_service.py
Normal file
521
api/services/dataset_service.py
Normal file
@@ -0,0 +1,521 @@
|
||||
import json
|
||||
import logging
|
||||
import datetime
|
||||
import time
|
||||
import random
|
||||
from typing import Optional
|
||||
from extensions.ext_redis import redis_client
|
||||
from flask_login import current_user
|
||||
|
||||
from core.index.index_builder import IndexBuilder
|
||||
from events.dataset_event import dataset_was_deleted
|
||||
from events.document_event import document_was_deleted
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin
|
||||
from models.model import UploadFile
|
||||
from services.errors.account import NoPermissionError
|
||||
from services.errors.dataset import DatasetNameDuplicateError
|
||||
from services.errors.document import DocumentIndexingError
|
||||
from services.errors.file import FileNotExistsError
|
||||
from tasks.document_indexing_task import document_indexing_task
|
||||
|
||||
|
||||
class DatasetService:
    """CRUD and permission helpers for datasets (knowledge bases)."""

    @staticmethod
    def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None):
        """Paginate a tenant's datasets visible to *user*.

        Visibility: datasets shared with the whole team, plus (when a user is
        given) datasets the user created.
        :return: (items, total) tuple.
        """
        if user:
            permission_filter = db.or_(Dataset.created_by == user.id,
                                       Dataset.permission == 'all_team_members')
        else:
            permission_filter = Dataset.permission == 'all_team_members'
        datasets = Dataset.query.filter(
            db.and_(Dataset.provider == provider, Dataset.tenant_id == tenant_id, permission_filter)) \
            .paginate(
            page=page,
            per_page=per_page,
            max_per_page=100,
            error_out=False
        )

        return datasets.items, datasets.total

    @staticmethod
    def get_process_rules(dataset_id):
        """Return the dataset's latest indexing rule, or the service defaults."""
        # get the latest process rule
        dataset_process_rule = db.session.query(DatasetProcessRule). \
            filter(DatasetProcessRule.dataset_id == dataset_id). \
            order_by(DatasetProcessRule.created_at.desc()). \
            limit(1). \
            one_or_none()
        if dataset_process_rule:
            mode = dataset_process_rule.mode
            rules = dataset_process_rule.rules_dict
        else:
            # Fall back to the library-wide defaults.
            mode = DocumentService.DEFAULT_RULES['mode']
            rules = DocumentService.DEFAULT_RULES['rules']
        return {
            'mode': mode,
            'rules': rules
        }

    @staticmethod
    def get_datasets_by_ids(ids, tenant_id):
        """Fetch the given dataset ids within a tenant; returns (items, total)."""
        datasets = Dataset.query.filter(Dataset.id.in_(ids),
                                        Dataset.tenant_id == tenant_id).paginate(
            page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
        return datasets.items, datasets.total

    @staticmethod
    def create_empty_dataset(tenant_id: str, name: str, indexing_technique: Optional[str], account: Account):
        """Create a dataset with no documents.

        :raises DatasetNameDuplicateError: if the tenant already has a dataset
            with this name.
        """
        # check if dataset name already exists
        if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
            raise DatasetNameDuplicateError(
                f'Dataset with name {name} already exists.')

        dataset = Dataset(name=name, indexing_technique=indexing_technique, data_source_type='upload_file')
        # dataset = Dataset(name=name, provider=provider, config=config)
        dataset.created_by = account.id
        dataset.updated_by = account.id
        dataset.tenant_id = tenant_id
        db.session.add(dataset)
        db.session.commit()
        return dataset

    @staticmethod
    def get_dataset(dataset_id):
        """Return the dataset by id, or None when it does not exist."""
        dataset = Dataset.query.filter_by(
            id=dataset_id
        ).first()
        if dataset is None:
            return None
        else:
            return dataset

    @staticmethod
    def update_dataset(dataset_id, data, user):
        """Apply a partial update to a dataset the user may access.

        None values are dropped except 'description', which may be cleared.
        NOTE(review): the returned `dataset` object was fetched before the
        bulk UPDATE, so its attributes may be stale until refreshed; also
        `datetime.now()` records naive local time — confirm against the
        column's expected timezone.
        """
        dataset = DatasetService.get_dataset(dataset_id)
        DatasetService.check_dataset_permission(dataset, user)

        filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}

        filtered_data['updated_by'] = user.id
        filtered_data['updated_at'] = datetime.datetime.now()

        dataset.query.filter_by(id=dataset_id).update(filtered_data)

        db.session.commit()

        return dataset

    @staticmethod
    def delete_dataset(dataset_id, user):
        """Delete a dataset; returns False when it does not exist, True on success.

        Emits dataset_was_deleted so listeners can tear down derived resources.
        """
        # todo: cannot delete dataset if it is being processed

        dataset = DatasetService.get_dataset(dataset_id)

        if dataset is None:
            return False

        DatasetService.check_dataset_permission(dataset, user)

        dataset_was_deleted.send(dataset)

        db.session.delete(dataset)
        db.session.commit()
        return True

    @staticmethod
    def check_dataset_permission(dataset, user):
        """Raise NoPermissionError unless *user* may access *dataset*.

        Rules: the dataset must belong to the user's current tenant, and an
        'only_me' dataset is visible only to its creator.
        """
        if dataset.tenant_id != user.current_tenant_id:
            logging.debug(
                f'User {user.id} does not have permission to access dataset {dataset.id}')
            raise NoPermissionError(
                'You do not have permission to access this dataset.')
        if dataset.permission == 'only_me' and dataset.created_by != user.id:
            logging.debug(
                f'User {user.id} does not have permission to access dataset {dataset.id}')
            raise NoPermissionError(
                'You do not have permission to access this dataset.')

    @staticmethod
    def get_dataset_queries(dataset_id: str, page: int, per_page: int):
        """Paginate the dataset's query history, newest first; returns (items, total)."""
        dataset_queries = DatasetQuery.query.filter_by(dataset_id=dataset_id) \
            .order_by(db.desc(DatasetQuery.created_at)) \
            .paginate(
            page=page, per_page=per_page, max_per_page=100, error_out=False
        )
        return dataset_queries.items, dataset_queries.total

    @staticmethod
    def get_related_apps(dataset_id: str):
        """Return all app-dataset join rows for the dataset, newest first."""
        return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \
            .order_by(db.desc(AppDatasetJoin.created_at)).all()
class DocumentService:
|
||||
# Fallback indexing process rule used when a dataset has no stored
# DatasetProcessRule (see DatasetService.get_process_rules).
DEFAULT_RULES = {
    'mode': 'custom',
    'rules': {
        'pre_processing_rules': [
            {'id': 'remove_extra_spaces', 'enabled': True},
            {'id': 'remove_urls_emails', 'enabled': False}
        ],
        'segmentation': {
            'delimiter': '\n',
            'max_tokens': 500
        }
    }
}

# Per doc_type: the allowed metadata field names and their expected Python
# types, used to validate user-supplied document metadata.
DOCUMENT_METADATA_SCHEMA = {
    "book": {
        "title": str,
        "language": str,
        "author": str,
        "publisher": str,
        "publication_date": str,
        "isbn": str,
        "category": str,
    },
    "web_page": {
        "title": str,
        "url": str,
        "language": str,
        "publish_date": str,
        "author/publisher": str,
        "topic/keywords": str,
        "description": str,
    },
    "paper": {
        "title": str,
        "language": str,
        "author": str,
        "publish_date": str,
        "journal/conference_name": str,
        "volume/issue/page_numbers": str,
        "doi": str,
        "topic/keywords": str,
        "abstract": str,
    },
    "social_media_post": {
        "platform": str,
        "author/username": str,
        "publish_date": str,
        "post_url": str,
        "topic/tags": str,
    },
    "wikipedia_entry": {
        "title": str,
        "language": str,
        "web_page_url": str,
        "last_edit_date": str,
        "editor/contributor": str,
        "summary/introduction": str,
    },
    "personal_document": {
        "title": str,
        "author": str,
        "creation_date": str,
        "last_modified_date": str,
        "document_type": str,
        "tags/category": str,
    },
    "business_document": {
        "title": str,
        "author": str,
        "creation_date": str,
        "last_modified_date": str,
        "document_type": str,
        "department/team": str,
    },
    "im_chat_log": {
        "chat_platform": str,
        "chat_participants/group_name": str,
        "start_date": str,
        "end_date": str,
        "summary": str,
    },
    "synced_from_notion": {
        "title": str,
        "language": str,
        "author/creator": str,
        "creation_date": str,
        "last_modified_date": str,
        "notion_page_link": str,
        "category/tags": str,
        "description": str,
    },
    "synced_from_github": {
        "repository_name": str,
        "repository_description": str,
        "repository_owner/organization": str,
        "code_filename": str,
        "code_file_path": str,
        "programming_language": str,
        "github_link": str,
        "open_source_license": str,
        "commit_date": str,
        "commit_author": str
    }
}
@staticmethod
def get_document(dataset_id: str, document_id: str) -> Optional[Document]:
    """Fetch one document scoped to its dataset; None when absent."""
    return db.session.query(Document).filter(
        Document.id == document_id,
        Document.dataset_id == dataset_id
    ).first()
@staticmethod
def get_document_file_detail(file_id: str):
    """Look up the UploadFile row by id; None when missing."""
    upload_file_query = db.session.query(UploadFile).filter(UploadFile.id == file_id)
    return upload_file_query.one_or_none()
@staticmethod
def check_archived(document):
    """Return True when *document* is archived, else False.

    Idiom fix: collapses the `if x: return True / else: return False`
    ladder; bool() preserves the strict True/False return values.
    """
    return bool(document.archived)
@staticmethod
def delete_document(document):
    """Delete a document that is not currently mid-indexing.

    :raises DocumentIndexingError: when an indexing pipeline stage is running.
    """
    if document.indexing_status in ["parsing", "cleaning", "splitting", "indexing"]:
        raise DocumentIndexingError()

    # Trigger the document_was_deleted signal so listeners can clean up
    # derived resources (e.g. vector index segments) for this document.
    document_was_deleted.send(document.id, dataset_id=document.dataset_id)

    db.session.delete(document)
    db.session.commit()
@staticmethod
def pause_document(document):
    """Pause a document whose indexing is in progress.

    Persists the pause metadata and sets a Redis flag the indexing worker
    polls so it stops early.
    :raises DocumentIndexingError: when the document is not in an
        in-progress indexing state.
    """
    if document.indexing_status not in ["waiting", "parsing", "cleaning", "splitting", "indexing"]:
        raise DocumentIndexingError()
    # update document to be paused
    document.is_paused = True
    document.paused_by = current_user.id
    document.paused_at = datetime.datetime.utcnow()

    db.session.add(document)
    db.session.commit()
    # set document paused flag (setnx: only creates it if not already set)
    indexing_cache_key = 'document_{}_is_paused'.format(document.id)
    redis_client.setnx(indexing_cache_key, "True")
@staticmethod
def recover_document(document):
    """Resume indexing of a previously paused document.

    Clears the pause metadata and the Redis pause flag, then re-queues the
    async indexing task.
    :raises DocumentIndexingError: when the document is not paused.
    """
    if not document.is_paused:
        raise DocumentIndexingError()

    # Clear the pause bookkeeping. The original code assigned time.time()
    # (a float epoch) to paused_at, which pause_document() fills with a
    # datetime — a type mismatch for the same column; on recovery the
    # pause metadata should simply be reset.
    document.is_paused = False
    document.paused_by = None
    document.paused_at = None

    db.session.add(document)
    db.session.commit()
    # delete paused flag so the worker no longer sees the document as paused
    indexing_cache_key = 'document_{}_is_paused'.format(document.id)
    redis_client.delete(indexing_cache_key)
    # trigger async task
    document_indexing_task.delay(document.dataset_id, document.id)
@staticmethod
def get_documents_position(dataset_id):
    """Return the next 1-based position for a new document in the dataset.

    Uses COUNT instead of loading every Document row just to take len();
    the result is identical (count + 1, which is 1 for an empty dataset).
    """
    return Document.query.filter_by(dataset_id=dataset_id).count() + 1
@staticmethod
def save_document_with_dataset_id(dataset: Dataset, document_data: dict,
                                  account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
                                  created_from: str = 'web'):
    """Create a Document in *dataset* from validated request data and queue indexing.

    :param dataset_process_rule: reuse an existing rule; when None one is
        created from document_data['process_rule'].
    :raises ValueError: when the dataset has no indexing technique and the
        request supplies none (or an invalid one).
    :raises FileNotExistsError: when an upload_file source references a
        missing file.
    :return: the persisted Document.
    """
    if not dataset.indexing_technique:
        if 'indexing_technique' not in document_data \
                or document_data['indexing_technique'] not in Dataset.INDEXING_TECHNIQUE_LIST:
            raise ValueError("Indexing technique is required")

        dataset.indexing_technique = document_data["indexing_technique"]

    if dataset.indexing_technique == 'high_quality':
        # Fails fast when the tenant's embedding provider is not configured.
        IndexBuilder.get_default_service_context(dataset.tenant_id)

    # save process rule
    if not dataset_process_rule:
        process_rule = document_data["process_rule"]
        if process_rule["mode"] == "custom":
            dataset_process_rule = DatasetProcessRule(
                dataset_id=dataset.id,
                mode=process_rule["mode"],
                rules=json.dumps(process_rule["rules"]),
                created_by=account.id
            )
        elif process_rule["mode"] == "automatic":
            dataset_process_rule = DatasetProcessRule(
                dataset_id=dataset.id,
                mode=process_rule["mode"],
                rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
                created_by=account.id
            )
        # NOTE(review): if mode is neither 'custom' nor 'automatic',
        # dataset_process_rule stays None and `.id` below would raise
        # AttributeError; presumably upstream validation prevents this —
        # confirm against document_create_args_validate.
        db.session.add(dataset_process_rule)
        db.session.commit()

    file_name = ''
    data_source_info = {}
    if document_data["data_source"]["type"] == "upload_file":
        file_id = document_data["data_source"]["info"]
        file = db.session.query(UploadFile).filter(
            UploadFile.tenant_id == dataset.tenant_id,
            UploadFile.id == file_id
        ).first()

        # raise error if file not found
        if not file:
            raise FileNotExistsError()

        file_name = file.name
        data_source_info = {
            "upload_file_id": file_id,
        }

    # save document
    position = DocumentService.get_documents_position(dataset.id)
    document = Document(
        tenant_id=dataset.tenant_id,
        dataset_id=dataset.id,
        position=position,
        data_source_type=document_data["data_source"]["type"],
        data_source_info=json.dumps(data_source_info),
        dataset_process_rule_id=dataset_process_rule.id,
        # batch tag: timestamp + random suffix groups documents created together
        batch=time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)),
        name=file_name,
        created_from=created_from,
        created_by=account.id,
        # created_api_request_id = db.Column(UUID, nullable=True)
    )

    db.session.add(document)
    db.session.commit()

    # trigger async task
    document_indexing_task.delay(document.dataset_id, document.id)

    return document
@staticmethod
def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account):
    """Create a brand-new dataset named after the document, then save the
    document into it.

    :return: (dataset, document) tuple.
    """
    # save dataset
    dataset = Dataset(
        tenant_id=tenant_id,
        name='',
        data_source_type=document_data["data_source"]["type"],
        indexing_technique=document_data["indexing_technique"],
        created_by=account.id
    )

    db.session.add(dataset)
    # flush (not commit) so dataset.id is assigned for the document row
    db.session.flush()

    document = DocumentService.save_document_with_dataset_id(dataset, document_data, account)

    # Name the dataset after the document, truncated with an ellipsis.
    cut_length = 18
    cut_name = document.name[:cut_length]
    dataset.name = cut_name + '...' if len(document.name) > cut_length else cut_name
    dataset.description = 'useful for when you want to answer queries about the ' + document.name
    db.session.commit()

    return dataset, document
||||
@classmethod
def document_create_args_validate(cls, args: dict):
    """Validate the request payload for creating a document.

    Checks the ``data_source`` section, then the ``process_rule`` section,
    normalising ``process_rule['rules']`` in place (deduplicating
    pre-processing rules, emptying rules in 'automatic' mode).

    Raises:
        ValueError: on the first invalid or missing field.
    """
    cls._validate_data_source(args)
    cls._validate_process_rule(args)

@classmethod
def _validate_data_source(cls, args: dict):
    """Validate args['data_source']: a dict with a known 'type' (and 'info' for uploads)."""
    if 'data_source' not in args or not args['data_source']:
        raise ValueError("Data source is required")

    if not isinstance(args['data_source'], dict):
        raise ValueError("Data source is invalid")

    if 'type' not in args['data_source'] or not args['data_source']['type']:
        raise ValueError("Data source type is required")

    if args['data_source']['type'] not in Document.DATA_SOURCES:
        raise ValueError("Data source type is invalid")

    # Upload sources additionally need the file info payload.
    if args['data_source']['type'] == 'upload_file':
        if 'info' not in args['data_source'] or not args['data_source']['info']:
            raise ValueError("Data source info is required")

@classmethod
def _validate_process_rule(cls, args: dict):
    """Validate args['process_rule'] and normalise its 'rules' in place."""
    if 'process_rule' not in args or not args['process_rule']:
        raise ValueError("Process rule is required")

    if not isinstance(args['process_rule'], dict):
        raise ValueError("Process rule is invalid")

    if 'mode' not in args['process_rule'] or not args['process_rule']['mode']:
        raise ValueError("Process rule mode is required")

    if args['process_rule']['mode'] not in DatasetProcessRule.MODES:
        raise ValueError("Process rule mode is invalid")

    if args['process_rule']['mode'] == 'automatic':
        # Automatic mode ignores any caller-supplied rules.
        args['process_rule']['rules'] = {}
        return

    if 'rules' not in args['process_rule'] or not args['process_rule']['rules']:
        raise ValueError("Process rule rules is required")

    if not isinstance(args['process_rule']['rules'], dict):
        raise ValueError("Process rule rules is invalid")

    cls._validate_pre_processing_rules(args['process_rule']['rules'])
    cls._validate_segmentation(args['process_rule']['rules'])

@classmethod
def _validate_pre_processing_rules(cls, rules: dict):
    """Validate and deduplicate rules['pre_processing_rules'] in place (last id wins)."""
    if 'pre_processing_rules' not in rules \
            or rules['pre_processing_rules'] is None:
        raise ValueError("Process rule pre_processing_rules is required")

    if not isinstance(rules['pre_processing_rules'], list):
        raise ValueError("Process rule pre_processing_rules is invalid")

    unique_pre_processing_rule_dicts = {}
    for pre_processing_rule in rules['pre_processing_rules']:
        if 'id' not in pre_processing_rule or not pre_processing_rule['id']:
            raise ValueError("Process rule pre_processing_rules id is required")

        if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES:
            raise ValueError("Process rule pre_processing_rules id is invalid")

        if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None:
            raise ValueError("Process rule pre_processing_rules enabled is required")

        if not isinstance(pre_processing_rule['enabled'], bool):
            raise ValueError("Process rule pre_processing_rules enabled is invalid")

        # Later duplicates override earlier ones; dict preserves insertion order.
        unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule

    rules['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values())

@classmethod
def _validate_segmentation(cls, rules: dict):
    """Validate rules['segmentation']: non-empty 'separator' (str) and 'max_tokens' (int)."""
    if 'segmentation' not in rules \
            or rules['segmentation'] is None:
        raise ValueError("Process rule segmentation is required")

    if not isinstance(rules['segmentation'], dict):
        raise ValueError("Process rule segmentation is invalid")

    if 'separator' not in rules['segmentation'] \
            or not rules['segmentation']['separator']:
        raise ValueError("Process rule segmentation separator is required")

    if not isinstance(rules['segmentation']['separator'], str):
        raise ValueError("Process rule segmentation separator is invalid")

    # NOTE(review): the falsy check also rejects max_tokens == 0; presumably
    # intentional since a zero-token segment size is unusable — confirm.
    if 'max_tokens' not in rules['segmentation'] \
            or not rules['segmentation']['max_tokens']:
        raise ValueError("Process rule segmentation max_tokens is required")

    if not isinstance(rules['segmentation']['max_tokens'], int):
        raise ValueError("Process rule segmentation max_tokens is invalid")
|
||||
7
api/services/errors/__init__.py
Normal file
7
api/services/errors/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
# Aggregate module: the wildcard import below pulls in every error
# submodule named in __all__, so callers can reach them via this package.
__all__ = [
    'base',
    'conversation',
    'message',
    'index',
    'app_model_config',
    'account',
    'document',
    'dataset',
    'app',
    'completion',
]

from . import *
|
||||
53
api/services/errors/account.py
Normal file
53
api/services/errors/account.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from services.errors.base import BaseServiceError
|
||||
|
||||
|
||||
class AccountNotFound(BaseServiceError):
    """Raised when no account matches the lookup criteria."""


class AccountRegisterError(BaseServiceError):
    """Raised when account registration cannot be completed."""


class AccountLoginError(BaseServiceError):
    """Raised when login fails (bad credentials or unusable account)."""


class AccountNotLinkTenantError(BaseServiceError):
    """Raised when the account has no linked tenant."""


class CurrentPasswordIncorrectError(BaseServiceError):
    """Raised when the supplied current password does not match."""


class LinkAccountIntegrateError(BaseServiceError):
    """Raised when linking an account to an external integration fails."""


class TenantNotFound(BaseServiceError):
    """Raised when no tenant matches the lookup criteria."""


class AccountAlreadyInTenantError(BaseServiceError):
    """Raised when adding an account that is already a tenant member."""


class InvalidActionError(BaseServiceError):
    """Raised when the requested member/tenant action is not allowed."""


class CannotOperateSelfError(BaseServiceError):
    """Raised when an operator tries to apply a member action to themselves."""


class NoPermissionError(BaseServiceError):
    """Raised when the operator lacks permission for the action."""


class MemberNotInTenantError(BaseServiceError):
    """Raised when the target member does not belong to the tenant."""


class RoleAlreadyAssignedError(BaseServiceError):
    """Raised when assigning a role the member already has."""
|
||||
2
api/services/errors/app.py
Normal file
2
api/services/errors/app.py
Normal file
@@ -0,0 +1,2 @@
|
||||
class MoreLikeThisDisabledError(Exception):
    """Raised when the 'more like this' feature is requested but disabled for the app."""
|
||||
5
api/services/errors/app_model_config.py
Normal file
5
api/services/errors/app_model_config.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from services.errors.base import BaseServiceError
|
||||
|
||||
|
||||
class AppModelConfigBrokenError(BaseServiceError):
    """Raised when an app's stored model configuration cannot be parsed/used."""
|
||||
3
api/services/errors/base.py
Normal file
3
api/services/errors/base.py
Normal file
@@ -0,0 +1,3 @@
|
||||
class BaseServiceError(Exception):
    """Common base class for service-layer errors.

    Stores an optional human-readable ``description`` and forwards it to
    ``Exception.__init__`` so that ``str(error)`` carries the message
    (the original implementation left the exception message empty).
    """

    def __init__(self, description: str = None):
        # Forward only when present, so str(err) is '' (not 'None')
        # when no description is given — matching Exception() defaults.
        if description is not None:
            super().__init__(description)
        else:
            super().__init__()
        self.description = description
|
||||
5
api/services/errors/completion.py
Normal file
5
api/services/errors/completion.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from services.errors.base import BaseServiceError
|
||||
|
||||
|
||||
class CompletionStoppedError(BaseServiceError):
    """Raised when a completion run is stopped before it finishes."""
|
||||
13
api/services/errors/conversation.py
Normal file
13
api/services/errors/conversation.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from services.errors.base import BaseServiceError
|
||||
|
||||
|
||||
class LastConversationNotExistsError(BaseServiceError):
    """Raised when the pagination cursor conversation cannot be found."""


class ConversationNotExistsError(BaseServiceError):
    """Raised when no conversation matches the lookup criteria."""


class ConversationCompletedError(Exception):
    """Raised when acting on a conversation that has already completed.

    NOTE(review): unlike its siblings this derives from plain ``Exception``
    rather than ``BaseServiceError`` — confirm whether that is intentional
    before unifying the hierarchy.
    """
|
||||
5
api/services/errors/dataset.py
Normal file
5
api/services/errors/dataset.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from services.errors.base import BaseServiceError
|
||||
|
||||
|
||||
class DatasetNameDuplicateError(BaseServiceError):
    """Raised when creating/renaming a dataset to a name that already exists."""
|
||||
5
api/services/errors/document.py
Normal file
5
api/services/errors/document.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from services.errors.base import BaseServiceError
|
||||
|
||||
|
||||
class DocumentIndexingError(BaseServiceError):
    """Raised when an operation conflicts with a document's indexing state."""
|
||||
5
api/services/errors/file.py
Normal file
5
api/services/errors/file.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from services.errors.base import BaseServiceError
|
||||
|
||||
|
||||
class FileNotExistsError(BaseServiceError):
    """Raised when a referenced uploaded file cannot be found."""
|
||||
5
api/services/errors/index.py
Normal file
5
api/services/errors/index.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from services.errors.base import BaseServiceError
|
||||
|
||||
|
||||
class IndexNotInitializedError(BaseServiceError):
    """Raised when querying a dataset whose vector index is not built yet."""
|
||||
17
api/services/errors/message.py
Normal file
17
api/services/errors/message.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from services.errors.base import BaseServiceError
|
||||
|
||||
|
||||
class FirstMessageNotExistsError(BaseServiceError):
    """Raised when the 'first_id' pagination cursor message cannot be found."""


class LastMessageNotExistsError(BaseServiceError):
    """Raised when the 'last_id' pagination cursor message cannot be found."""


class MessageNotExistsError(BaseServiceError):
    """Raised when no message matches the lookup criteria."""


class SuggestedQuestionsAfterAnswerDisabledError(BaseServiceError):
    """Raised when suggested-questions-after-answer is requested but disabled."""
|
||||
130
api/services/hit_testing_service.py
Normal file
130
api/services/hit_testing_service.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from llama_index.data_structs.node_v2 import NodeWithScore
|
||||
from llama_index.indices.query.schema import QueryBundle
|
||||
from llama_index.indices.vector_store import GPTVectorStoreIndexQuery
|
||||
from sklearn.manifold import TSNE
|
||||
|
||||
from core.docstore.empty_docstore import EmptyDocumentStore
|
||||
from core.index.vector_index import VectorIndex
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DocumentSegment, DatasetQuery
|
||||
from services.errors.index import IndexNotInitializedError
|
||||
|
||||
|
||||
class HitTestingService:
    """Run ad-hoc retrieval ("hit testing") queries against a dataset's vector index."""

    @classmethod
    def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
        """Retrieve up to ``limit`` segments for ``query``, logging the query.

        Raises:
            IndexNotInitializedError: when the dataset has no vector index.
        """
        vector_index = VectorIndex(dataset=dataset).query_index
        if not vector_index:
            raise IndexNotInitializedError()

        # Query through an empty docstore so only vector-store data is used.
        index_query = GPTVectorStoreIndexQuery(
            index_struct=vector_index.index_struct,
            service_context=vector_index.service_context,
            vector_store=vector_index.query_context.get('vector_store'),
            docstore=EmptyDocumentStore(),
            response_synthesizer=None,
            similarity_top_k=limit
        )

        query_bundle = QueryBundle(
            query_str=query,
            custom_embedding_strs=[query],
        )
        query_bundle.embedding = vector_index.service_context.embed_model.get_agg_embedding_from_queries(
            query_bundle.embedding_strs
        )

        # Time only the retrieval call itself.
        started_at = time.perf_counter()
        nodes = index_query.retrieve(query_bundle=query_bundle)
        elapsed = time.perf_counter() - started_at
        logging.debug(f"Hit testing retrieve in {elapsed:0.4f} seconds")

        # Record the query for dataset analytics.
        db.session.add(DatasetQuery(
            dataset_id=dataset.id,
            content=query,
            source='hit_testing',
            created_by_role='account',
            created_by=account.id
        ))
        db.session.commit()

        return cls.compact_retrieve_response(dataset, query_bundle, nodes)

    @classmethod
    def compact_retrieve_response(cls, dataset: Dataset, query_bundle: QueryBundle, nodes: List[NodeWithScore]):
        """Map retrieved nodes to dataset segments plus 2-D t-SNE positions."""
        # The query embedding goes first so its position can be split off below.
        all_embeddings = [query_bundle.embedding] + [node.node.embedding for node in nodes]

        positions = cls.get_tsne_positions_from_embeddings(all_embeddings)
        query_position = positions.pop(0)  # first entry belongs to the query itself

        records = []
        for idx, node in enumerate(nodes):
            segment = db.session.query(DocumentSegment).filter(
                DocumentSegment.dataset_id == dataset.id,
                DocumentSegment.enabled == True,
                DocumentSegment.status == 'completed',
                DocumentSegment.index_node_id == node.node.doc_id
            ).first()

            # Skip nodes whose segment is missing/disabled; the position index
            # still tracks the node index so pairings stay aligned.
            if segment:
                records.append({
                    "segment": segment,
                    "score": node.score,
                    "tsne_position": positions[idx]
                })

        return {
            "query": {
                "content": query_bundle.query_str,
                "tsne_position": query_position,
            },
            "records": records
        }

    @classmethod
    def get_tsne_positions_from_embeddings(cls, embeddings: list):
        """Project embeddings to 2-D via t-SNE; returns [{'x': .., 'y': ..}, ...]."""
        count = len(embeddings)
        if count <= 1:
            # t-SNE needs at least two samples; place a lone point at the origin.
            return [{'x': 0, 'y': 0}]

        matrix = np.array(embeddings).reshape(count, -1)

        # t-SNE requires perplexity < number of samples; clamp accordingly.
        perplexity = count / 2 + 1
        if perplexity >= count:
            perplexity = max(count - 1, 1)

        projection = TSNE(
            n_components=2, perplexity=perplexity, early_exaggeration=12.0
        ).fit_transform(matrix)

        return [{'x': float(point[0]), 'y': float(point[1])} for point in projection]
|
||||
212
api/services/message_service.py
Normal file
212
api/services/message_service.py
Normal file
@@ -0,0 +1,212 @@
|
||||
from typing import Optional, Union, List
|
||||
|
||||
from core.completion import Completion
|
||||
from core.generator.llm_generator import LLMGenerator
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, EndUser, Message, MessageFeedback
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.message import FirstMessageNotExistsError, MessageNotExistsError, LastMessageNotExistsError, \
|
||||
SuggestedQuestionsAfterAnswerDisabledError
|
||||
|
||||
|
||||
class MessageService:
    """Query, paginate and annotate chat messages for an app.

    Note: the original annotations used ``Optional[Union[Account | EndUser]]``
    — a PEP 604 union nested inside a one-argument ``Union[...]`` — which is
    redundant/incorrect; they are written as ``Union[Account, EndUser]`` here.
    """

    @classmethod
    def pagination_by_first_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
                               conversation_id: str, first_id: Optional[str], limit: int) -> InfiniteScrollPagination:
        """Page backwards through one conversation's history.

        ``first_id`` is the oldest message the client already holds; the page
        returned contains up to ``limit`` messages older than it, in ascending
        creation order. Empty page when ``user`` or ``conversation_id`` is falsy.

        Raises:
            FirstMessageNotExistsError: if ``first_id`` is not in the conversation.
        """
        if not user:
            return InfiniteScrollPagination(data=[], limit=limit, has_more=False)

        if not conversation_id:
            return InfiniteScrollPagination(data=[], limit=limit, has_more=False)

        # Also validates that the conversation belongs to this app/user.
        conversation = ConversationService.get_conversation(
            app_model=app_model,
            user=user,
            conversation_id=conversation_id
        )

        if first_id:
            first_message = db.session.query(Message) \
                .filter(Message.conversation_id == conversation.id, Message.id == first_id).first()

            if not first_message:
                raise FirstMessageNotExistsError()

            history_messages = db.session.query(Message).filter(
                Message.conversation_id == conversation.id,
                Message.created_at < first_message.created_at,
                Message.id != first_message.id
            ) \
                .order_by(Message.created_at.desc()).limit(limit).all()
        else:
            history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \
                .order_by(Message.created_at.desc()).limit(limit).all()

        # A full page may still be the last one; probe for older rows.
        has_more = False
        if len(history_messages) == limit:
            current_page_first_message = history_messages[-1]
            rest_count = db.session.query(Message).filter(
                Message.conversation_id == conversation.id,
                Message.created_at < current_page_first_message.created_at,
                Message.id != current_page_first_message.id
            ).count()

            if rest_count > 0:
                has_more = True

        # Query runs newest-first; return the page oldest-first.
        history_messages = list(reversed(history_messages))

        return InfiniteScrollPagination(
            data=history_messages,
            limit=limit,
            has_more=has_more
        )

    @classmethod
    def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
                              last_id: Optional[str], limit: int, conversation_id: Optional[str] = None,
                              include_ids: Optional[list] = None) -> InfiniteScrollPagination:
        """Page newest-first through messages, optionally scoped to a
        conversation and/or an explicit id allow-list.

        Raises:
            LastMessageNotExistsError: if ``last_id`` does not match the scope.
        """
        if not user:
            return InfiniteScrollPagination(data=[], limit=limit, has_more=False)

        base_query = db.session.query(Message)

        if conversation_id is not None:
            # Also validates that the conversation belongs to this app/user.
            conversation = ConversationService.get_conversation(
                app_model=app_model,
                user=user,
                conversation_id=conversation_id
            )

            base_query = base_query.filter(Message.conversation_id == conversation.id)

        if include_ids is not None:
            base_query = base_query.filter(Message.id.in_(include_ids))

        if last_id:
            last_message = base_query.filter(Message.id == last_id).first()

            if not last_message:
                raise LastMessageNotExistsError()

            history_messages = base_query.filter(
                Message.created_at < last_message.created_at,
                Message.id != last_message.id
            ).order_by(Message.created_at.desc()).limit(limit).all()
        else:
            history_messages = base_query.order_by(Message.created_at.desc()).limit(limit).all()

        # A full page may still be the last one; probe for older rows.
        has_more = False
        if len(history_messages) == limit:
            current_page_first_message = history_messages[-1]
            rest_count = base_query.filter(
                Message.created_at < current_page_first_message.created_at,
                Message.id != current_page_first_message.id
            ).count()

            if rest_count > 0:
                has_more = True

        return InfiniteScrollPagination(
            data=history_messages,
            limit=limit,
            has_more=has_more
        )

    @classmethod
    def create_feedback(cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]],
                        rating: Optional[str]) -> MessageFeedback:
        """Create, update or delete feedback on a message.

        Behavior matrix: rating=None + existing feedback -> delete;
        rating + existing -> update; rating=None + none -> ValueError;
        rating + none -> create.
        """
        if not user:
            raise ValueError('user cannot be None')

        message = cls.get_message(
            app_model=app_model,
            user=user,
            message_id=message_id
        )

        feedback = message.user_feedback

        if not rating and feedback:
            db.session.delete(feedback)
        elif rating and feedback:
            feedback.rating = rating
        elif not rating and not feedback:
            raise ValueError('rating cannot be None when feedback not exists')
        else:
            feedback = MessageFeedback(
                app_id=app_model.id,
                conversation_id=message.conversation_id,
                message_id=message.id,
                rating=rating,
                from_source=('user' if isinstance(user, EndUser) else 'admin'),
                from_end_user_id=(user.id if isinstance(user, EndUser) else None),
                from_account_id=(user.id if isinstance(user, Account) else None),
            )
            db.session.add(feedback)

        db.session.commit()

        return feedback

    @classmethod
    def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
        """Fetch one message scoped to this app and this user's source channel.

        Raises:
            MessageNotExistsError: when no matching message is found.
        """
        message = db.session.query(Message).filter(
            Message.id == message_id,
            Message.app_id == app_model.id,
            Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
            Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
            Message.from_account_id == (user.id if isinstance(user, Account) else None),
        ).first()

        if not message:
            raise MessageNotExistsError()

        return message

    @classmethod
    def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Union[Account, EndUser]],
                                             message_id: str, check_enabled: bool = True) -> List[Message]:
        """Generate follow-up question suggestions from recent history.

        NOTE(review): the declared return type ``List[Message]`` looks wrong —
        ``LLMGenerator.generate_suggested_questions_after_answer`` presumably
        yields question strings; confirm before changing the annotation.

        Raises:
            SuggestedQuestionsAfterAnswerDisabledError: if the feature is off
                and ``check_enabled`` is True.
        """
        if not user:
            raise ValueError('user cannot be None')

        app_model_config = app_model.app_model_config
        suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict

        if check_enabled and suggested_questions_after_answer.get("enabled", False) is False:
            raise SuggestedQuestionsAfterAnswerDisabledError()

        message = cls.get_message(
            app_model=app_model,
            user=user,
            message_id=message_id
        )

        conversation = ConversationService.get_conversation(
            app_model=app_model,
            conversation_id=message.conversation_id,
            user=user
        )

        # Get memory of conversation (read-only).
        memory = Completion.get_memory_from_conversation(
            tenant_id=app_model.tenant_id,
            app_model_config=app_model.app_model_config,
            conversation=conversation,
            max_token_limit=3000,
            message_limit=3,
            return_messages=False,
            memory_key="histories"
        )

        external_context = memory.load_memory_variables({})

        questions = LLMGenerator.generate_suggested_questions_after_answer(
            tenant_id=app_model.tenant_id,
            **external_context
        )

        return questions
|
||||
|
||||
96
api/services/provider_service.py
Normal file
96
api/services/provider_service.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from typing import Union
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from core.llm.provider.llm_provider_service import LLMProviderService
|
||||
from models.account import Tenant
|
||||
from models.provider import *
|
||||
|
||||
|
||||
class ProviderService:
    """Service helpers around model-provider records and credentials.

    Note: the original annotations used ``Union[dict | str]`` (a PEP 604
    union nested inside a one-argument ``Union[...]``); they are written
    as ``Union[dict, str]`` here.
    """

    @staticmethod
    def init_supported_provider(tenant, edition):
        """Ensure the tenant has CUSTOM provider rows for OpenAI and Azure OpenAI.

        Missing rows are created invalid (no credentials yet). Commits only
        when something was actually added.
        """
        providers = Provider.query.filter_by(tenant_id=tenant.id).all()

        # Names of CUSTOM-type providers that already exist for this tenant.
        existing_custom_names = {
            provider.provider_name
            for provider in providers
            if provider.provider_type == ProviderType.CUSTOM.value
        }

        # TODO: The cloud version needs to construct the data of the SYSTEM type

        created_any = False
        for provider_name in (ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value):
            if provider_name in existing_custom_names:
                continue
            db.session.add(Provider(
                tenant_id=tenant.id,
                provider_name=provider_name,
                provider_type=ProviderType.CUSTOM.value,
                is_valid=False
            ))
            created_any = True

        if created_any:
            db.session.commit()

    @staticmethod
    def get_obfuscated_api_key(tenant, provider_name: ProviderName):
        """Return the tenant's provider configs with secret values obfuscated."""
        llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
        return llm_provider_service.get_provider_configs(obfuscated=True)

    @staticmethod
    def get_token_type(tenant, provider_name: ProviderName):
        """Return the credential/token type the provider expects."""
        llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
        return llm_provider_service.get_token_type()

    @staticmethod
    def validate_provider_configs(tenant, provider_name: ProviderName, configs: Union[dict, str]):
        """Validate raw provider credentials/configs; delegates to the provider service."""
        llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
        return llm_provider_service.config_validate(configs)

    @staticmethod
    def get_encrypted_token(tenant, provider_name: ProviderName, configs: Union[dict, str]):
        """Encrypt provider credentials for persistent storage."""
        llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
        return llm_provider_service.get_encrypted_token(configs)

    @staticmethod
    def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value,
                               is_valid: bool = True):
        """Create a SYSTEM (trial-quota) provider row for the tenant.

        No-op outside the CLOUD edition or when the row already exists.
        """
        if current_app.config['EDITION'] != 'CLOUD':
            return

        provider = db.session.query(Provider).filter(
            Provider.tenant_id == tenant.id,
            Provider.provider_name == provider_name,
            Provider.provider_type == ProviderType.SYSTEM.value
        ).one_or_none()

        if not provider:
            provider = Provider(
                tenant_id=tenant.id,
                provider_name=provider_name,
                provider_type=ProviderType.SYSTEM.value,
                quota_type=ProviderQuotaType.TRIAL.value,
                quota_limit=200,  # trial message quota
                encrypted_config='',
                is_valid=is_valid,
            )
            db.session.add(provider)
            db.session.commit()
|
||||
66
api/services/saved_message_service.py
Normal file
66
api/services/saved_message_service.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from typing import Optional
|
||||
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, EndUser
|
||||
from models.web import SavedMessage
|
||||
from services.message_service import MessageService
|
||||
|
||||
|
||||
class SavedMessageService:
    """Manage an end user's saved ("bookmarked") messages for an app."""

    @classmethod
    def pagination_by_last_id(cls, app_model: App, end_user: Optional[EndUser],
                              last_id: Optional[str], limit: int) -> InfiniteScrollPagination:
        """Page through only the messages this end user has saved.

        NOTE(review): ``end_user`` is annotated Optional but ``end_user.id``
        is dereferenced unconditionally — confirm callers never pass None.
        """
        saved_rows = db.session.query(SavedMessage).filter(
            SavedMessage.app_id == app_model.id,
            SavedMessage.created_by == end_user.id
        ).order_by(SavedMessage.created_at.desc()).all()

        return MessageService.pagination_by_last_id(
            app_model=app_model,
            user=end_user,
            last_id=last_id,
            limit=limit,
            include_ids=[row.message_id for row in saved_rows]
        )

    @classmethod
    def save(cls, app_model: App, user: Optional[EndUser], message_id: str):
        """Bookmark ``message_id`` for ``user``; no-op if already saved."""
        existing = db.session.query(SavedMessage).filter(
            SavedMessage.app_id == app_model.id,
            SavedMessage.message_id == message_id,
            SavedMessage.created_by == user.id
        ).first()

        if existing:
            return

        # Validates the message exists and belongs to this app/user.
        message = MessageService.get_message(
            app_model=app_model,
            user=user,
            message_id=message_id
        )

        db.session.add(SavedMessage(
            app_id=app_model.id,
            message_id=message.id,
            created_by=user.id
        ))
        db.session.commit()

    @classmethod
    def delete(cls, app_model: App, user: Optional[EndUser], message_id: str):
        """Remove the bookmark if present; silently no-op otherwise."""
        saved_message = db.session.query(SavedMessage).filter(
            SavedMessage.app_id == app_model.id,
            SavedMessage.message_id == message_id,
            SavedMessage.created_by == user.id
        ).first()

        if saved_message is None:
            return

        db.session.delete(saved_message)
        db.session.commit()
|
||||
74
api/services/web_conversation_service.py
Normal file
74
api/services/web_conversation_service.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, EndUser
|
||||
from models.web import PinnedConversation
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
|
||||
class WebConversationService:
    """Conversation listing and pinning for the web (end-user) app."""

    @classmethod
    def pagination_by_last_id(cls, app_model: App, end_user: Optional[EndUser],
                              last_id: Optional[str], limit: int, pinned: Optional[bool] = None) -> InfiniteScrollPagination:
        """List conversations, optionally filtered to pinned or unpinned only."""
        include_ids = None
        exclude_ids = None

        if pinned is not None:
            pinned_rows = db.session.query(PinnedConversation).filter(
                PinnedConversation.app_id == app_model.id,
                PinnedConversation.created_by == end_user.id
            ).order_by(PinnedConversation.created_at.desc()).all()
            pinned_ids = [row.conversation_id for row in pinned_rows]
            # pinned=True restricts to these ids; pinned=False excludes them.
            if pinned:
                include_ids = pinned_ids
            else:
                exclude_ids = pinned_ids

        return ConversationService.pagination_by_last_id(
            app_model=app_model,
            user=end_user,
            last_id=last_id,
            limit=limit,
            include_ids=include_ids,
            exclude_ids=exclude_ids
        )

    @classmethod
    def pin(cls, app_model: App, conversation_id: str, user: Optional[EndUser]):
        """Pin a conversation for ``user``; no-op if already pinned."""
        existing = db.session.query(PinnedConversation).filter(
            PinnedConversation.app_id == app_model.id,
            PinnedConversation.conversation_id == conversation_id,
            PinnedConversation.created_by == user.id
        ).first()

        if existing:
            return

        # Validates the conversation exists and belongs to this app/user.
        conversation = ConversationService.get_conversation(
            app_model=app_model,
            conversation_id=conversation_id,
            user=user
        )

        db.session.add(PinnedConversation(
            app_id=app_model.id,
            conversation_id=conversation.id,
            created_by=user.id
        ))
        db.session.commit()

    @classmethod
    def unpin(cls, app_model: App, conversation_id: str, user: Optional[EndUser]):
        """Remove the pin if present; silently no-op otherwise."""
        pinned_conversation = db.session.query(PinnedConversation).filter(
            PinnedConversation.app_id == app_model.id,
            PinnedConversation.conversation_id == conversation_id,
            PinnedConversation.created_by == user.id
        ).first()

        if pinned_conversation is None:
            return

        db.session.delete(pinned_conversation)
        db.session.commit()
|
||||
49
api/services/workspace_service.py
Normal file
49
api/services/workspace_service.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from extensions.ext_database import db
|
||||
from models.account import Tenant
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
|
||||
class WorkspaceService:
    """Assemble workspace (tenant) info payloads, including trial status."""

    @classmethod
    def get_tenant_info(cls, tenant: Tenant):
        """Return tenant details plus its providers and trial state.

        NOTE(review): the key is spelled 'in_trail' (not 'in_trial'); API
        consumers presumably depend on the existing spelling, so it is
        kept as-is.
        """
        tenant_info = {
            'id': tenant.id,
            'name': tenant.name,
            'plan': tenant.plan,
            'status': tenant.status,
            'created_at': tenant.created_at,
            'providers': [],
            'in_trail': False,
            'trial_end_reason': 'using_custom',
        }

        providers = db.session.query(Provider).filter(
            Provider.tenant_id == tenant.id
        ).all()
        tenant_info['providers'] = providers

        # Pick the last valid CUSTOM provider with stored credentials, and the
        # last valid SYSTEM provider encountered.
        custom_provider = None
        system_provider = None
        for provider in providers:
            if provider.provider_type == ProviderType.CUSTOM.value:
                if provider.is_valid and provider.encrypted_config:
                    custom_provider = provider
            elif provider.provider_type == ProviderType.SYSTEM.value:
                if provider.is_valid:
                    system_provider = provider

        # Trial state only applies when the tenant runs on the system quota
        # (no usable custom provider configured).
        if system_provider and not custom_provider:
            quota_used = system_provider.quota_used if system_provider.quota_used is not None else 0
            quota_limit = system_provider.quota_limit if system_provider.quota_limit is not None else 0

            if quota_used >= quota_limit:
                tenant_info['trial_end_reason'] = 'trial_exceeded'
            else:
                tenant_info['in_trail'] = True
                tenant_info['trial_end_reason'] = None

        return tenant_info
|
||||
Reference in New Issue
Block a user