Initial commit

John Wang
2023-05-15 08:51:32 +08:00
commit db896255d6
744 changed files with 56028 additions and 0 deletions

api/services/__init__.py Normal file

@@ -0,0 +1,2 @@
# -*- coding:utf-8 -*-
import services.errors

api/services/account_service.py Normal file

@@ -0,0 +1,382 @@
# -*- coding:utf-8 -*-
import base64
import logging
import secrets
from datetime import datetime
from typing import Optional
from flask import session
from sqlalchemy import func
from events.tenant_event import tenant_was_created
from services.errors.account import AccountLoginError, CurrentPasswordIncorrectError, LinkAccountIntegrateError, \
TenantNotFound, AccountNotLinkTenantError, InvalidActionError, CannotOperateSelfError, MemberNotInTenantError, \
RoleAlreadyAssignedError, NoPermissionError, AccountRegisterError, AccountAlreadyInTenantError
from libs.helper import get_remote_ip
from libs.password import compare_password, hash_password
from libs.rsa import generate_key_pair
from models.account import *
class AccountService:
@staticmethod
def load_user(account_id: int) -> Account:
# todo: used by flask_login
pass
@staticmethod
def authenticate(email: str, password: str) -> Account:
"""authenticate account with email and password"""
account = Account.query.filter_by(email=email).first()
if not account:
raise AccountLoginError('Invalid email or password.')
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
raise AccountLoginError('Account is banned or closed.')
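# a pending (invited) account is activated on its first successful login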
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
account.initialized_at = datetime.utcnow()
db.session.commit()
if account.password is None or not compare_password(password, account.password, account.password_salt):
raise AccountLoginError('Invalid email or password.')
return account
@staticmethod
def update_account_password(account, password, new_password):
"""update account password"""
# todo: split validation and update
if account.password and not compare_password(password, account.password, account.password_salt):
raise CurrentPasswordIncorrectError("Current password is incorrect.")
password_hashed = hash_password(new_password, account.password_salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
db.session.commit()
return account
@staticmethod
def create_account(email: str, name: str, password: str = None,
interface_language: str = 'en-US', interface_theme: str = 'light',
timezone: str = 'America/New_York', ) -> Account:
"""create account"""
account = Account()
account.email = email
account.name = name
if password:
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
# encrypt password with salt
password_hashed = hash_password(password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
account.interface_language = interface_language
account.interface_theme = interface_theme
if interface_language == 'zh-Hans':
account.timezone = 'Asia/Shanghai'
else:
account.timezone = timezone
db.session.add(account)
db.session.commit()
return account
@staticmethod
def link_account_integrate(provider: str, open_id: str, account: Account) -> None:
"""Link account integrate"""
try:
# Query whether there is an existing binding record for the same provider
account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(account_id=account.id,
provider=provider).first()
if account_integrate:
# If it exists, update the record
account_integrate.open_id = open_id
account_integrate.encrypted_token = "" # todo
account_integrate.updated_at = datetime.utcnow()
else:
# If it does not exist, create a new record
account_integrate = AccountIntegrate(account_id=account.id, provider=provider, open_id=open_id,
encrypted_token="")
db.session.add(account_integrate)
db.session.commit()
logging.info(f'Account {account.id} linked {provider} account {open_id}.')
except Exception as e:
logging.exception(f'Failed to link {provider} account {open_id} to Account {account.id}')
raise LinkAccountIntegrateError('Failed to link account.') from e
@staticmethod
def close_account(account: Account) -> None:
"""todo: Close account"""
account.status = AccountStatus.CLOSED.value
db.session.commit()
@staticmethod
def update_account(account, **kwargs):
"""Update account fields"""
for field, value in kwargs.items():
if hasattr(account, field):
setattr(account, field, value)
else:
raise AttributeError(f"Invalid field: {field}")
db.session.commit()
return account
@staticmethod
def update_last_login(account: Account, request) -> None:
"""Update last login time and ip"""
account.last_login_at = datetime.utcnow()
account.last_login_ip = get_remote_ip(request)
db.session.add(account)
db.session.commit()
logging.info(f'Account {account.id} logged in successfully.')
class TenantService:
@staticmethod
def create_tenant(name: str) -> Tenant:
"""Create tenant"""
tenant = Tenant(name=name)
db.session.add(tenant)
db.session.commit()
tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.commit()
return tenant
@staticmethod
def create_tenant_member(tenant: Tenant, account: Account, role: str = 'normal') -> TenantAccountJoin:
"""Create tenant member"""
if role == TenantAccountJoinRole.OWNER.value:
if TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER]):
logging.error(f'Tenant {tenant.id} already has an owner.')
raise Exception('Tenant already has an owner.')
ta = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=role
)
db.session.add(ta)
db.session.commit()
return ta
@staticmethod
def get_join_tenants(account: Account) -> List[Tenant]:
"""Get account join tenants"""
return db.session.query(Tenant).join(
TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id
).filter(TenantAccountJoin.account_id == account.id).all()
@staticmethod
def get_current_tenant_by_account(account: Account):
"""Get tenant by account and add the role"""
tenant = account.current_tenant
if not tenant:
raise TenantNotFound("Tenant not found.")
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
if ta:
tenant.role = ta.role
else:
raise TenantNotFound("Tenant not found for the account.")
return tenant
@staticmethod
def switch_tenant(account: Account, tenant_id: int = None) -> None:
"""Switch the current workspace for the account"""
if not tenant_id:
tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id).first()
else:
tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first()
# Check if the tenant exists and the account is a member of the tenant
if not tenant_account_join:
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
# Set the current tenant for the account
account.current_tenant_id = tenant_account_join.tenant_id
session['workspace_id'] = account.current_tenant.id
@staticmethod
def get_tenant_members(tenant: Tenant) -> List[Account]:
"""Get tenant members"""
query = (
db.session.query(Account, TenantAccountJoin.role)
.select_from(Account)
.join(
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
)
.filter(TenantAccountJoin.tenant_id == tenant.id)
)
# Initialize an empty list to store the updated accounts
updated_accounts = []
for account, role in query:
account.role = role
updated_accounts.append(account)
return updated_accounts
@staticmethod
def has_roles(tenant: Tenant, roles: List[TenantAccountJoinRole]) -> bool:
"""Check if user has any of the given roles for a tenant"""
if not all(isinstance(role, TenantAccountJoinRole) for role in roles):
raise ValueError('all roles must be TenantAccountJoinRole')
return db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.tenant_id == tenant.id,
TenantAccountJoin.role.in_([role.value for role in roles])
).first() is not None
@staticmethod
def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoinRole]:
"""Get the role of the current account for a given tenant"""
join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.tenant_id == tenant.id,
TenantAccountJoin.account_id == account.id
).first()
return join.role if join else None
@staticmethod
def get_tenant_count() -> int:
"""Get tenant count"""
return db.session.query(func.count(Tenant.id)).scalar()
@staticmethod
def check_member_permission(tenant: Tenant, operator: Account, member: Account, action: str) -> None:
"""Check member permission"""
perms = {
'add': ['owner', 'admin'],
'remove': ['owner'],
'update': ['owner']
}
if action not in ['add', 'remove', 'update']:
raise InvalidActionError("Invalid action.")
if operator.id == member.id:
raise CannotOperateSelfError("Cannot operate self.")
ta_operator = TenantAccountJoin.query.filter_by(
tenant_id=tenant.id,
account_id=operator.id
).first()
if not ta_operator or ta_operator.role not in perms[action]:
raise NoPermissionError(f'No permission to {action} member.')
@staticmethod
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
"""Remove member from tenant"""
# check_member_permission also raises CannotOperateSelfError when removing oneself
TenantService.check_member_permission(tenant, operator, account, 'remove')
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
if not ta:
raise MemberNotInTenantError("Member not in tenant.")
db.session.delete(ta)
db.session.commit()
@staticmethod
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None:
"""Update member role"""
TenantService.check_member_permission(tenant, operator, member, 'update')
target_member_join = TenantAccountJoin.query.filter_by(
tenant_id=tenant.id,
account_id=member.id
).first()
if not target_member_join:
raise MemberNotInTenantError("Member not in tenant.")
if target_member_join.role == new_role:
raise RoleAlreadyAssignedError("The provided role is already assigned to the member.")
if new_role == 'owner':
# Find the current owner and change their role to 'admin'
current_owner_join = TenantAccountJoin.query.filter_by(
tenant_id=tenant.id,
role='owner'
).first()
current_owner_join.role = 'admin'
# Update the role of the target member
target_member_join.role = new_role
db.session.commit()
@staticmethod
def dissolve_tenant(tenant: Tenant, operator: Account) -> None:
"""Dissolve tenant"""
if not TenantService.check_member_permission(tenant, operator, operator, 'remove'):
raise NoPermissionError('No permission to dissolve tenant.')
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
db.session.delete(tenant)
db.session.commit()
class RegisterService:
@staticmethod
def register(email, name, password: str = None, open_id: str = None, provider: str = None) -> Account:
"""Register account"""
db.session.begin_nested()
try:
account = AccountService.create_account(email, name, password)
account.status = AccountStatus.ACTIVE.value
account.initialized_at = datetime.utcnow()
if open_id is not None and provider is not None:
AccountService.link_account_integrate(provider, open_id, account)
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role='owner')
account.current_tenant = tenant
db.session.commit()
except Exception as e:
db.session.rollback()  # todo: rollback of the nested transaction does not work as expected
logging.error(f'Register failed: {e}')
raise AccountRegisterError(f'Registration failed: {e}') from e
tenant_was_created.send(tenant)
return account
@staticmethod
def invite_new_member(tenant: Tenant, email: str, role: str = 'normal',
inviter: Account = None) -> TenantAccountJoin:
"""Invite new member"""
account = Account.query.filter_by(email=email).first()
if not account:
name = email.split('@')[0]
account = AccountService.create_account(email, name)
account.status = AccountStatus.PENDING.value
db.session.commit()
else:
TenantService.check_member_permission(tenant, inviter, account, 'add')
ta = TenantAccountJoin.query.filter_by(
tenant_id=tenant.id,
account_id=account.id
).first()
if ta:
raise AccountAlreadyInTenantError("Account already in tenant.")
ta = TenantService.create_tenant_member(tenant, account, role)
return ta
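
Usage sketch (illustrative, not part of the commit; assumes a Flask app context, an initialized database, and example credentials):

account = RegisterService.register(email='alice@example.com', name='Alice', password='s3cret-pass')
tenant = TenantService.get_current_tenant_by_account(account)  # default workspace created by register()
members = TenantService.get_tenant_members(tenant)  # each Account carries a transient .role attribute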

api/services/app_model_config_service.py Normal file

@@ -0,0 +1,292 @@
import re
import uuid
from core.constant import llm_constant
from models.account import Account
from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
class AppModelConfigService:
@staticmethod
def is_dataset_exists(account: Account, dataset_id: str) -> bool:
# verify if the dataset ID exists
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
return False
if dataset.tenant_id != account.current_tenant_id:
return False
return True
@staticmethod
def validate_model_completion_params(cp: dict, model_name: str) -> dict:
# 6. model.completion_params
if not isinstance(cp, dict):
raise ValueError("model.completion_params must be of object type")
# max_tokens
if 'max_tokens' not in cp:
cp["max_tokens"] = 512
if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \
llm_constant.max_context_token_length[model_name]:
raise ValueError(
"max_tokens must be an integer greater than 0 and not exceeding the maximum value of the corresponding model")
# temperature
if 'temperature' not in cp:
cp["temperature"] = 1
if not isinstance(cp["temperature"], (float, int)) or cp["temperature"] < 0 or cp["temperature"] > 2:
raise ValueError("temperature must be a float between 0 and 2")
# top_p
if 'top_p' not in cp:
cp["top_p"] = 1
if not isinstance(cp["top_p"], (float, int)) or cp["top_p"] < 0 or cp["top_p"] > 2:
raise ValueError("top_p must be a float between 0 and 2")
# presence_penalty
if 'presence_penalty' not in cp:
cp["presence_penalty"] = 0
if not isinstance(cp["presence_penalty"], (float, int)) or cp["presence_penalty"] < -2 or cp["presence_penalty"] > 2:
raise ValueError("presence_penalty must be a float between -2 and 2")
# frequency_penalty
if 'frequency_penalty' not in cp:
cp["frequency_penalty"] = 0
if not isinstance(cp["frequency_penalty"], (float, int)) or cp["frequency_penalty"] < -2 or cp["frequency_penalty"] > 2:
raise ValueError("frequency_penalty must be a float between -2 and 2")
# Filter out extra parameters
filtered_cp = {
"max_tokens": cp["max_tokens"],
"temperature": cp["temperature"],
"top_p": cp["top_p"],
"presence_penalty": cp["presence_penalty"],
"frequency_penalty": cp["frequency_penalty"]
}
return filtered_cp
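# Illustrative sketch (not part of the commit; the model name is an assumption and must
# be a key of llm_constant.max_context_token_length). Missing keys get defaults and
# unknown keys are dropped:
#   params = AppModelConfigService.validate_model_completion_params({'temperature': 0.7}, 'gpt-3.5-turbo')
#   # -> {'max_tokens': 512, 'temperature': 0.7, 'top_p': 1, 'presence_penalty': 0, 'frequency_penalty': 0}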
@staticmethod
def validate_configuration(account: Account, config: dict, mode: str) -> dict:
# opening_statement
if 'opening_statement' not in config or not config["opening_statement"]:
config["opening_statement"] = ""
if not isinstance(config["opening_statement"], str):
raise ValueError("opening_statement must be of string type")
# suggested_questions
if 'suggested_questions' not in config or not config["suggested_questions"]:
config["suggested_questions"] = []
if not isinstance(config["suggested_questions"], list):
raise ValueError("suggested_questions must be of list type")
for question in config["suggested_questions"]:
if not isinstance(question, str):
raise ValueError("Elements in suggested_questions list must be of string type")
# suggested_questions_after_answer
if 'suggested_questions_after_answer' not in config or not config["suggested_questions_after_answer"]:
config["suggested_questions_after_answer"] = {
"enabled": False
}
if not isinstance(config["suggested_questions_after_answer"], dict):
raise ValueError("suggested_questions_after_answer must be of dict type")
if "enabled" not in config["suggested_questions_after_answer"] or not config["suggested_questions_after_answer"]["enabled"]:
config["suggested_questions_after_answer"]["enabled"] = False
if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):
raise ValueError("enabled in suggested_questions_after_answer must be of boolean type")
# more_like_this
if 'more_like_this' not in config or not config["more_like_this"]:
config["more_like_this"] = {
"enabled": False
}
if not isinstance(config["more_like_this"], dict):
raise ValueError("more_like_this must be of dict type")
if "enabled" not in config["more_like_this"] or not config["more_like_this"]["enabled"]:
config["more_like_this"]["enabled"] = False
if not isinstance(config["more_like_this"]["enabled"], bool):
raise ValueError("enabled in more_like_this must be of boolean type")
# model
if 'model' not in config:
raise ValueError("model is required")
if not isinstance(config["model"], dict):
raise ValueError("model must be of object type")
# model.provider
if 'provider' not in config["model"] or config["model"]["provider"] != "openai":
raise ValueError("model.provider must be 'openai'")
# model.name
if 'name' not in config["model"]:
raise ValueError("model.name is required")
if config["model"]["name"] not in llm_constant.models_by_mode[mode]:
raise ValueError("model.name must be in the specified model list")
# model.completion_params
if 'completion_params' not in config["model"]:
raise ValueError("model.completion_params is required")
config["model"]["completion_params"] = AppModelConfigService.validate_model_completion_params(
config["model"]["completion_params"],
config["model"]["name"]
)
# user_input_form
if "user_input_form" not in config or not config["user_input_form"]:
config["user_input_form"] = []
if not isinstance(config["user_input_form"], list):
raise ValueError("user_input_form must be a list of objects")
variables = []
for item in config["user_input_form"]:
key = list(item.keys())[0]
if key not in ["text-input", "select"]:
raise ValueError("Keys in user_input_form list can only be 'text-input' or 'select'")
form_item = item[key]
if 'label' not in form_item:
raise ValueError("label is required in user_input_form")
if not isinstance(form_item["label"], str):
raise ValueError("label in user_input_form must be of string type")
if 'variable' not in form_item:
raise ValueError("variable is required in user_input_form")
if not isinstance(form_item["variable"], str):
raise ValueError("variable in user_input_form must be of string type")
pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
if pattern.match(form_item["variable"]) is None:
raise ValueError("variable in user_input_form must be a string, "
"and cannot start with a number")
variables.append(form_item["variable"])
if 'required' not in form_item or not form_item["required"]:
form_item["required"] = False
if not isinstance(form_item["required"], bool):
raise ValueError("required in user_input_form must be of boolean type")
if key == "select":
if 'options' not in form_item or not form_item["options"]:
form_item["options"] = []
if not isinstance(form_item["options"], list):
raise ValueError("options in user_input_form must be a list of strings")
if "default" in form_item and form_item['default'] \
and form_item["default"] not in form_item["options"]:
raise ValueError("default value in user_input_form must be in the options list")
# pre_prompt
if "pre_prompt" not in config or not config["pre_prompt"]:
config["pre_prompt"] = ""
if not isinstance(config["pre_prompt"], str):
raise ValueError("pre_prompt must be of string type")
template_vars = re.findall(r"\{\{(\w+)\}\}", config["pre_prompt"])
for var in template_vars:
if var not in variables:
raise ValueError("Template variables in pre_prompt must be defined in user_input_form")
# agent_mode
if "agent_mode" not in config or not config["agent_mode"]:
config["agent_mode"] = {
"enabled": False,
"tools": []
}
if not isinstance(config["agent_mode"], dict):
raise ValueError("agent_mode must be of object type")
if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]:
config["agent_mode"]["enabled"] = False
if not isinstance(config["agent_mode"]["enabled"], bool):
raise ValueError("enabled in agent_mode must be of boolean type")
if "tools" not in config["agent_mode"] or not config["agent_mode"]["tools"]:
config["agent_mode"]["tools"] = []
if not isinstance(config["agent_mode"]["tools"], list):
raise ValueError("tools in agent_mode must be a list of objects")
for tool in config["agent_mode"]["tools"]:
key = list(tool.keys())[0]
if key not in ["sensitive-word-avoidance", "dataset"]:
raise ValueError("Keys in agent_mode.tools list can only be 'sensitive-word-avoidance' or 'dataset'")
tool_item = tool[key]
if "enabled" not in tool_item or not tool_item["enabled"]:
tool_item["enabled"] = False
if not isinstance(tool_item["enabled"], bool):
raise ValueError("enabled in agent_mode.tools must be of boolean type")
if key == "sensitive-word-avoidance":
if "words" not in tool_item or not tool_item["words"]:
tool_item["words"] = ""
if not isinstance(tool_item["words"], str):
raise ValueError("words in sensitive-word-avoidance must be of string type")
if "canned_response" not in tool_item or not tool_item["canned_response"]:
tool_item["canned_response"] = ""
if not isinstance(tool_item["canned_response"], str):
raise ValueError("canned_response in sensitive-word-avoidance must be of string type")
elif key == "dataset":
if 'id' not in tool_item:
raise ValueError("id is required in dataset")
try:
uuid.UUID(tool_item["id"])
except ValueError:
raise ValueError("id in dataset must be of UUID type")
if not AppModelConfigService.is_dataset_exists(account, tool_item["id"]):
raise ValueError("Dataset ID does not exist, please check your permission.")
# Filter out extra parameters
filtered_config = {
"opening_statement": config["opening_statement"],
"suggested_questions": config["suggested_questions"],
"suggested_questions_after_answer": config["suggested_questions_after_answer"],
"more_like_this": config["more_like_this"],
"model": {
"provider": config["model"]["provider"],
"name": config["model"]["name"],
"completion_params": config["model"]["completion_params"]
},
"user_input_form": config["user_input_form"],
"pre_prompt": config["pre_prompt"],
"agent_mode": config["agent_mode"]
}
return filtered_config
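
Usage sketch (illustrative, not part of the commit; the model name is an assumption and must appear in llm_constant.models_by_mode[mode]):

config = {
    'model': {
        'provider': 'openai',
        'name': 'gpt-3.5-turbo',  # assumed entry in the allowed model list
        'completion_params': {}
    }
}
validated = AppModelConfigService.validate_configuration(account, config, mode='chat')
# omitted optional sections (opening_statement, suggested_questions, agent_mode, ...) are defaulted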

api/services/completion_service.py Normal file

@@ -0,0 +1,506 @@
import json
import logging
import threading
import time
import uuid
from typing import Generator, Union, Any
from flask import current_app, Flask
from redis.client import PubSub
from sqlalchemy import and_
from core.completion import Completion
from core.conversation_message_task import PubHandler, ConversationTaskStoppedException
from core.llm.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, \
LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message
from services.app_model_config_service import AppModelConfigService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.completion import CompletionStoppedError
from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError
from services.errors.message import MessageNotExistsError
class CompletionService:
@classmethod
def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any,
from_source: str, streaming: bool = True,
is_model_config_override: bool = False) -> Union[dict, Generator]:
# is streaming mode
inputs = args['inputs']
query = args['query']
conversation_id = args.get('conversation_id')
conversation = None
if conversation_id:
conversation_filter = [
Conversation.id == args['conversation_id'],
Conversation.app_id == app_model.id,
Conversation.status == 'normal'
]
if from_source == 'console':
conversation_filter.append(Conversation.from_account_id == user.id)
else:
conversation_filter.append(Conversation.from_end_user_id == (user.id if user else None))
conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first()
if not conversation:
raise ConversationNotExistsError()
if conversation.status != 'normal':
raise ConversationCompletedError()
if not conversation.override_model_configs:
app_model_config = db.session.query(AppModelConfig).get(conversation.app_model_config_id)
if not app_model_config:
raise AppModelConfigBrokenError()
else:
conversation_override_model_configs = json.loads(conversation.override_model_configs)
app_model_config = AppModelConfig(
id=conversation.app_model_config_id,
app_id=app_model.id,
provider="",
model_id="",
configs="",
opening_statement=conversation_override_model_configs['opening_statement'],
suggested_questions=json.dumps(conversation_override_model_configs['suggested_questions']),
model=json.dumps(conversation_override_model_configs['model']),
user_input_form=json.dumps(conversation_override_model_configs['user_input_form']),
pre_prompt=conversation_override_model_configs['pre_prompt'],
agent_mode=json.dumps(conversation_override_model_configs['agent_mode']),
)
if is_model_config_override:
# build new app model config
if 'model' not in args['model_config']:
raise ValueError('model_config.model is required')
if 'completion_params' not in args['model_config']['model']:
raise ValueError('model_config.model.completion_params is required')
completion_params = AppModelConfigService.validate_model_completion_params(
cp=args['model_config']['model']['completion_params'],
model_name=app_model_config.model_dict["name"]
)
app_model_config_model = app_model_config.model_dict
app_model_config_model['completion_params'] = completion_params
app_model_config = AppModelConfig(
id=app_model_config.id,
app_id=app_model.id,
provider="",
model_id="",
configs="",
opening_statement=app_model_config.opening_statement,
suggested_questions=app_model_config.suggested_questions,
model=json.dumps(app_model_config_model),
user_input_form=app_model_config.user_input_form,
pre_prompt=app_model_config.pre_prompt,
agent_mode=app_model_config.agent_mode,
)
else:
if app_model.app_model_config_id is None:
raise AppModelConfigBrokenError()
app_model_config = app_model.app_model_config
if not app_model_config:
raise AppModelConfigBrokenError()
if is_model_config_override:
if not isinstance(user, Account):
raise Exception("Only account can override model config")
# validate config
model_config = AppModelConfigService.validate_configuration(
account=user,
config=args['model_config'],
mode=app_model.mode
)
app_model_config = AppModelConfig(
id=app_model_config.id,
app_id=app_model.id,
provider="",
model_id="",
configs="",
opening_statement=model_config['opening_statement'],
suggested_questions=json.dumps(model_config['suggested_questions']),
suggested_questions_after_answer=json.dumps(model_config['suggested_questions_after_answer']),
more_like_this=json.dumps(model_config['more_like_this']),
model=json.dumps(model_config['model']),
user_input_form=json.dumps(model_config['user_input_form']),
pre_prompt=model_config['pre_prompt'],
agent_mode=json.dumps(model_config['agent_mode']),
)
# clean input by app_model_config form rules
inputs = cls.get_cleaned_inputs(inputs, app_model_config)
generate_task_id = str(uuid.uuid4())
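# run the completion in a background thread and stream results back over a per-task Redis pub/sub channel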
pubsub = redis_client.pubsub()
pubsub.subscribe(PubHandler.generate_channel_name(user, generate_task_id))
user = cls.get_real_user_instead_of_proxy_obj(user)
generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'generate_task_id': generate_task_id,
'app_model': app_model,
'app_model_config': app_model_config,
'query': query,
'inputs': inputs,
'user': user,
'conversation': conversation,
'streaming': streaming,
'is_model_config_override': is_model_config_override
})
generate_worker_thread.start()
# wait for 5 minutes to close the thread
cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
return cls.compact_response(pubsub, streaming)
@classmethod
def get_real_user_instead_of_proxy_obj(cls, user: Union[Account, EndUser]):
if isinstance(user, Account):
user = db.session.query(Account).get(user.id)
elif isinstance(user, EndUser):
user = db.session.query(EndUser).get(user.id)
else:
raise Exception("Unknown user type")
return user
@classmethod
def generate_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App, app_model_config: AppModelConfig,
query: str, inputs: dict, user: Union[Account, EndUser],
conversation: Conversation, streaming: bool, is_model_config_override: bool):
with flask_app.app_context():
try:
if conversation:
# re-fetch the conversation, as the object was detached from the original session
conversation = db.session.query(Conversation).filter_by(id=conversation.id).first()
# run
Completion.generate(
task_id=generate_task_id,
app=app_model,
app_model_config=app_model_config,
query=query,
inputs=inputs,
user=user,
conversation=conversation,
streaming=streaming,
is_override=is_model_config_override,
)
except ConversationTaskStoppedException:
pass
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
ModelCurrentlyNotSupportError) as e:
db.session.rollback()
PubHandler.pub_error(user, generate_task_id, e)
except LLMAuthorizationError:
db.session.rollback()
PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
except Exception as e:
db.session.rollback()
logging.exception("Unknown Error in completion")
PubHandler.pub_error(user, generate_task_id, e)
@classmethod
def countdown_and_close(cls, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
# wait for 5 minutes to close the thread
timeout = 300
def close_pubsub():
sleep_iterations = 0
while sleep_iterations < timeout and worker_thread.is_alive():
time.sleep(1)
sleep_iterations += 1
if worker_thread.is_alive():
PubHandler.stop(user, generate_task_id)
try:
pubsub.close()
except Exception:
pass
countdown_thread = threading.Thread(target=close_pubsub)
countdown_thread.start()
return countdown_thread
@classmethod
def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
message_id: str, streaming: bool = True) -> Union[dict, Generator]:
if not user:
raise ValueError('user cannot be None')
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
).first()
if not message:
raise MessageNotExistsError()
current_app_model_config = app_model.app_model_config
more_like_this = current_app_model_config.more_like_this_dict
if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
raise MoreLikeThisDisabledError()
app_model_config = message.app_model_config
if message.override_model_configs:
override_model_configs = json.loads(message.override_model_configs)
pre_prompt = override_model_configs.get("pre_prompt", '')
elif app_model_config:
pre_prompt = app_model_config.pre_prompt
else:
raise AppModelConfigBrokenError()
generate_task_id = str(uuid.uuid4())
pubsub = redis_client.pubsub()
pubsub.subscribe(PubHandler.generate_channel_name(user, generate_task_id))
user = cls.get_real_user_instead_of_proxy_obj(user)
generate_worker_thread = threading.Thread(target=cls.generate_more_like_this_worker, kwargs={
'flask_app': current_app._get_current_object(),
'generate_task_id': generate_task_id,
'app_model': app_model,
'app_model_config': app_model_config,
'message': message,
'pre_prompt': pre_prompt,
'user': user,
'streaming': streaming
})
generate_worker_thread.start()
cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
return cls.compact_response(pubsub, streaming)
@classmethod
def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App,
app_model_config: AppModelConfig, message: Message, pre_prompt: str,
user: Union[Account, EndUser], streaming: bool):
with flask_app.app_context():
try:
# run
Completion.generate_more_like_this(
task_id=generate_task_id,
app=app_model,
user=user,
message=message,
pre_prompt=pre_prompt,
app_model_config=app_model_config,
streaming=streaming
)
except ConversationTaskStoppedException:
pass
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
ModelCurrentlyNotSupportError) as e:
db.session.rollback()
PubHandler.pub_error(user, generate_task_id, e)
except LLMAuthorizationError:
db.session.rollback()
PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
except Exception as e:
db.session.rollback()
logging.exception("Unknown Error in completion")
PubHandler.pub_error(user, generate_task_id, e)
@classmethod
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
if user_inputs is None:
user_inputs = {}
filtered_inputs = {}
# Filter input variables from form configuration, handle required fields, default values, and option values
input_form_config = app_model_config.user_input_form_list
for config in input_form_config:
input_config = list(config.values())[0]
variable = input_config["variable"]
input_type = list(config.keys())[0]
if variable not in user_inputs or not user_inputs[variable]:
if "required" in input_config and input_config["required"]:
raise ValueError(f"{variable} is required in input form")
else:
filtered_inputs[variable] = input_config["default"] if "default" in input_config else ""
continue
value = user_inputs[variable]
if input_type == "select":
options = input_config["options"] if "options" in input_config else []
if value not in options:
raise ValueError(f"{variable} in input form must be one of the following: {options}")
else:
if 'max_length' in input_config:
max_length = input_config['max_length']
if len(value) > max_length:
raise ValueError(f'{variable} in input form must be less than {max_length} characters')
filtered_inputs[variable] = value
return filtered_inputs
@classmethod
def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict, Generator]:
generate_channel = list(pubsub.channels.keys())[0].decode('utf-8')
if not streaming:
try:
for message in pubsub.listen():
if message["type"] == "message":
result = message["data"].decode('utf-8')
result = json.loads(result)
if result.get('error'):
cls.handle_error(result)
return cls.get_message_response_data(result.get('data'))
except ValueError as e:
if e.args[0] == "I/O operation on closed file.":  # raised when the pubsub is closed by stop/timeout
raise CompletionStoppedError()
else:
logging.exception(e)
raise
finally:
try:
pubsub.unsubscribe(generate_channel)
except ConnectionError:
pass
else:
def generate() -> Generator:
try:
for message in pubsub.listen():
if message["type"] == "message":
result = message["data"].decode('utf-8')
result = json.loads(result)
if result.get('error'):
cls.handle_error(result)
event = result.get('event')
if event == "end":
logging.debug("{} finished".format(generate_channel))
break
if event == 'message':
yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
elif event == 'chain':
yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
elif event == 'agent_thought':
yield "data: " + json.dumps(cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
except ValueError as e:
if e.args[0] != "I/O operation on closed file.": # ignore this error
logging.exception(e)
raise
finally:
try:
pubsub.unsubscribe(generate_channel)
except ConnectionError:
pass
return generate()
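# Consumption sketch (not part of the commit): in streaming mode callers iterate the
# returned generator, which yields SSE-style "data: {...}\n\n" chunks:
#   for chunk in CompletionService.completion(app_model, user, args, from_source='console', streaming=True):
#       print(chunk, end='')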
@classmethod
def get_message_response_data(cls, data: dict):
response_data = {
'event': 'message',
'task_id': data.get('task_id'),
'id': data.get('message_id'),
'answer': data.get('text'),
'created_at': int(time.time())
}
if data.get('mode') == 'chat':
response_data['conversation_id'] = data.get('conversation_id')
return response_data
@classmethod
def get_chain_response_data(cls, data: dict):
response_data = {
'event': 'chain',
'id': data.get('chain_id'),
'task_id': data.get('task_id'),
'message_id': data.get('message_id'),
'type': data.get('type'),
'input': data.get('input'),
'output': data.get('output'),
'created_at': int(time.time())
}
if data.get('mode') == 'chat':
response_data['conversation_id'] = data.get('conversation_id')
return response_data
@classmethod
def get_agent_thought_response_data(cls, data: dict):
response_data = {
'event': 'agent_thought',
'id': data.get('agent_thought_id'),
'chain_id': data.get('chain_id'),
'task_id': data.get('task_id'),
'message_id': data.get('message_id'),
'position': data.get('position'),
'thought': data.get('thought'),
'tool': data.get('tool'), # todo use real dataset obj replace it
'tool_input': data.get('tool_input'),
'observation': data.get('observation'),
'answer': data.get('answer') if not data.get('thought') else '',
'created_at': int(time.time())
}
if data.get('mode') == 'chat':
response_data['conversation_id'] = data.get('conversation_id')
return response_data
@classmethod
def handle_error(cls, result: dict):
logging.debug("error: %s", result)
error = result.get('error')
description = result.get('description')
# handle errors
llm_errors = {
'LLMBadRequestError': LLMBadRequestError,
'LLMAPIConnectionError': LLMAPIConnectionError,
'LLMAPIUnavailableError': LLMAPIUnavailableError,
'LLMRateLimitError': LLMRateLimitError,
'ProviderTokenNotInitError': ProviderTokenNotInitError,
'QuotaExceededError': QuotaExceededError,
'ModelCurrentlyNotSupportError': ModelCurrentlyNotSupportError
}
if error in llm_errors:
raise llm_errors[error](description)
elif error == 'LLMAuthorizationError':
raise LLMAuthorizationError('Incorrect API key provided')
else:
raise Exception(description)

api/services/conversation_service.py Normal file

@@ -0,0 +1,94 @@
from typing import Union, Optional
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db
from models.account import Account
from models.model import Conversation, App, EndUser
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
class ConversationService:
@classmethod
def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
last_id: Optional[str], limit: int,
include_ids: Optional[list] = None, exclude_ids: Optional[list] = None) -> InfiniteScrollPagination:
if not user:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
base_query = db.session.query(Conversation).filter(
Conversation.app_id == app_model.id,
Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'),
Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
)
if include_ids is not None:
base_query = base_query.filter(Conversation.id.in_(include_ids))
if exclude_ids is not None:
base_query = base_query.filter(~Conversation.id.in_(exclude_ids))
if last_id:
last_conversation = base_query.filter(
Conversation.id == last_id,
).first()
if not last_conversation:
raise LastConversationNotExistsError()
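# keyset pagination: fetch only conversations strictly older than the cursor row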
conversations = base_query.filter(
Conversation.created_at < last_conversation.created_at,
Conversation.id != last_conversation.id
).order_by(Conversation.created_at.desc()).limit(limit).all()
else:
conversations = base_query.order_by(Conversation.created_at.desc()).limit(limit).all()
has_more = False
if len(conversations) == limit:
current_page_last_conversation = conversations[-1]
rest_count = base_query.filter(
Conversation.created_at < current_page_last_conversation.created_at,
Conversation.id != current_page_last_conversation.id
).count()
if rest_count > 0:
has_more = True
return InfiniteScrollPagination(
data=conversations,
limit=limit,
has_more=has_more
)
@classmethod
def rename(cls, app_model: App, conversation_id: str,
user: Optional[Union[Account, EndUser]], name: str):
conversation = cls.get_conversation(app_model, conversation_id, user)
conversation.name = name
db.session.commit()
return conversation
@classmethod
def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
conversation = db.session.query(Conversation) \
.filter(
Conversation.id == conversation_id,
Conversation.app_id == app_model.id,
Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'),
Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
).first()
if not conversation:
raise ConversationNotExistsError()
return conversation
@classmethod
def delete(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
conversation = cls.get_conversation(app_model, conversation_id, user)
db.session.delete(conversation)
db.session.commit()
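
Usage sketch (illustrative, not part of the commit; cursor-style pagination over an assumed app_model and user):

page = ConversationService.pagination_by_last_id(app_model, user, last_id=None, limit=20)
if page.has_more and page.data:
    next_page = ConversationService.pagination_by_last_id(app_model, user, last_id=page.data[-1].id, limit=20)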

api/services/dataset_service.py Normal file

@@ -0,0 +1,521 @@
import json
import logging
import datetime
import time
import random
from typing import Optional
from extensions.ext_redis import redis_client
from flask_login import current_user
from core.index.index_builder import IndexBuilder
from events.dataset_event import dataset_was_deleted
from events.document_event import document_was_deleted
from extensions.ext_database import db
from models.account import Account
from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin
from models.model import UploadFile
from services.errors.account import NoPermissionError
from services.errors.dataset import DatasetNameDuplicateError
from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from tasks.document_indexing_task import document_indexing_task
class DatasetService:
@staticmethod
def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None):
if user:
permission_filter = db.or_(Dataset.created_by == user.id,
Dataset.permission == 'all_team_members')
else:
permission_filter = Dataset.permission == 'all_team_members'
datasets = Dataset.query.filter(
db.and_(Dataset.provider == provider, Dataset.tenant_id == tenant_id, permission_filter)) \
.paginate(
page=page,
per_page=per_page,
max_per_page=100,
error_out=False
)
return datasets.items, datasets.total
@staticmethod
def get_process_rules(dataset_id):
# get the latest process rule
dataset_process_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.dataset_id == dataset_id). \
order_by(DatasetProcessRule.created_at.desc()). \
limit(1). \
one_or_none()
if dataset_process_rule:
mode = dataset_process_rule.mode
rules = dataset_process_rule.rules_dict
else:
mode = DocumentService.DEFAULT_RULES['mode']
rules = DocumentService.DEFAULT_RULES['rules']
return {
'mode': mode,
'rules': rules
}
@staticmethod
def get_datasets_by_ids(ids, tenant_id):
datasets = Dataset.query.filter(Dataset.id.in_(ids),
Dataset.tenant_id == tenant_id).paginate(
page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
return datasets.items, datasets.total
@staticmethod
def create_empty_dataset(tenant_id: str, name: str, indexing_technique: Optional[str], account: Account):
# check if dataset name already exists
if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(
f'Dataset with name {name} already exists.')
dataset = Dataset(name=name, indexing_technique=indexing_technique, data_source_type='upload_file')
# dataset = Dataset(name=name, provider=provider, config=config)
dataset.created_by = account.id
dataset.updated_by = account.id
dataset.tenant_id = tenant_id
db.session.add(dataset)
db.session.commit()
return dataset
@staticmethod
def get_dataset(dataset_id):
dataset = Dataset.query.filter_by(
id=dataset_id
).first()
return dataset
@staticmethod
def update_dataset(dataset_id, data, user):
dataset = DatasetService.get_dataset(dataset_id)
DatasetService.check_dataset_permission(dataset, user)
filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}
filtered_data['updated_by'] = user.id
filtered_data['updated_at'] = datetime.datetime.now()
Dataset.query.filter_by(id=dataset_id).update(filtered_data)
db.session.commit()
return dataset
@staticmethod
def delete_dataset(dataset_id, user):
# todo: cannot delete dataset if it is being processed
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
return False
DatasetService.check_dataset_permission(dataset, user)
dataset_was_deleted.send(dataset)
db.session.delete(dataset)
db.session.commit()
return True
@staticmethod
def check_dataset_permission(dataset, user):
if dataset.tenant_id != user.current_tenant_id:
logging.debug(
f'User {user.id} does not have permission to access dataset {dataset.id}')
raise NoPermissionError(
'You do not have permission to access this dataset.')
if dataset.permission == 'only_me' and dataset.created_by != user.id:
logging.debug(
f'User {user.id} does not have permission to access dataset {dataset.id}')
raise NoPermissionError(
'You do not have permission to access this dataset.')
@staticmethod
def get_dataset_queries(dataset_id: str, page: int, per_page: int):
dataset_queries = DatasetQuery.query.filter_by(dataset_id=dataset_id) \
.order_by(db.desc(DatasetQuery.created_at)) \
.paginate(
page=page, per_page=per_page, max_per_page=100, error_out=False
)
return dataset_queries.items, dataset_queries.total
@staticmethod
def get_related_apps(dataset_id: str):
return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \
.order_by(db.desc(AppDatasetJoin.created_at)).all()
class DocumentService:
DEFAULT_RULES = {
'mode': 'custom',
'rules': {
'pre_processing_rules': [
{'id': 'remove_extra_spaces', 'enabled': True},
{'id': 'remove_urls_emails', 'enabled': False}
],
'segmentation': {
'delimiter': '\n',
'max_tokens': 500
}
}
}
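# expected metadata fields (and their value types) for each supported document type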
DOCUMENT_METADATA_SCHEMA = {
"book": {
"title": str,
"language": str,
"author": str,
"publisher": str,
"publication_date": str,
"isbn": str,
"category": str,
},
"web_page": {
"title": str,
"url": str,
"language": str,
"publish_date": str,
"author/publisher": str,
"topic/keywords": str,
"description": str,
},
"paper": {
"title": str,
"language": str,
"author": str,
"publish_date": str,
"journal/conference_name": str,
"volume/issue/page_numbers": str,
"doi": str,
"topic/keywords": str,
"abstract": str,
},
"social_media_post": {
"platform": str,
"author/username": str,
"publish_date": str,
"post_url": str,
"topic/tags": str,
},
"wikipedia_entry": {
"title": str,
"language": str,
"web_page_url": str,
"last_edit_date": str,
"editor/contributor": str,
"summary/introduction": str,
},
"personal_document": {
"title": str,
"author": str,
"creation_date": str,
"last_modified_date": str,
"document_type": str,
"tags/category": str,
},
"business_document": {
"title": str,
"author": str,
"creation_date": str,
"last_modified_date": str,
"document_type": str,
"department/team": str,
},
"im_chat_log": {
"chat_platform": str,
"chat_participants/group_name": str,
"start_date": str,
"end_date": str,
"summary": str,
},
"synced_from_notion": {
"title": str,
"language": str,
"author/creator": str,
"creation_date": str,
"last_modified_date": str,
"notion_page_link": str,
"category/tags": str,
"description": str,
},
"synced_from_github": {
"repository_name": str,
"repository_description": str,
"repository_owner/organization": str,
"code_filename": str,
"code_file_path": str,
"programming_language": str,
"github_link": str,
"open_source_license": str,
"commit_date": str,
"commit_author": str
}
}
@staticmethod
def get_document(dataset_id: str, document_id: str) -> Optional[Document]:
document = db.session.query(Document).filter(
Document.id == document_id,
Document.dataset_id == dataset_id
).first()
return document
@staticmethod
def get_document_file_detail(file_id: str):
file_detail = db.session.query(UploadFile). \
filter(UploadFile.id == file_id). \
one_or_none()
return file_detail
@staticmethod
def check_archived(document):
return document.archived
@staticmethod
def delete_document(document):
if document.indexing_status in ["parsing", "cleaning", "splitting", "indexing"]:
raise DocumentIndexingError()
# trigger document_was_deleted signal
document_was_deleted.send(document.id, dataset_id=document.dataset_id)
db.session.delete(document)
db.session.commit()
@staticmethod
def pause_document(document):
if document.indexing_status not in ["waiting", "parsing", "cleaning", "splitting", "indexing"]:
raise DocumentIndexingError()
# update document to be paused
document.is_paused = True
document.paused_by = current_user.id
document.paused_at = datetime.datetime.utcnow()
db.session.add(document)
db.session.commit()
# set document paused flag
indexing_cache_key = 'document_{}_is_paused'.format(document.id)
redis_client.setnx(indexing_cache_key, "True")
@staticmethod
def recover_document(document):
if not document.is_paused:
raise DocumentIndexingError()
# update document to be recovered
document.is_paused = False
document.paused_by = None
document.paused_at = None
db.session.add(document)
db.session.commit()
# delete paused flag
indexing_cache_key = 'document_{}_is_paused'.format(document.id)
redis_client.delete(indexing_cache_key)
# trigger async task
document_indexing_task.delay(document.dataset_id, document.id)
@staticmethod
def get_documents_position(dataset_id):
# positions are 1-based: next position = number of existing documents + 1
return Document.query.filter_by(dataset_id=dataset_id).count() + 1
@staticmethod
def save_document_with_dataset_id(dataset: Dataset, document_data: dict,
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = 'web'):
if not dataset.indexing_technique:
if 'indexing_technique' not in document_data \
or document_data['indexing_technique'] not in Dataset.INDEXING_TECHNIQUE_LIST:
raise ValueError("Indexing technique is required")
dataset.indexing_technique = document_data["indexing_technique"]
if dataset.indexing_technique == 'high_quality':
IndexBuilder.get_default_service_context(dataset.tenant_id)
# save process rule
if not dataset_process_rule:
process_rule = document_data["process_rule"]
if process_rule["mode"] == "custom":
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule["mode"],
rules=json.dumps(process_rule["rules"]),
created_by=account.id
)
elif process_rule["mode"] == "automatic":
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule["mode"],
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
created_by=account.id
)
db.session.add(dataset_process_rule)
db.session.commit()
file_name = ''
data_source_info = {}
if document_data["data_source"]["type"] == "upload_file":
file_id = document_data["data_source"]["info"]
file = db.session.query(UploadFile).filter(
UploadFile.tenant_id == dataset.tenant_id,
UploadFile.id == file_id
).first()
# raise error if file not found
if not file:
raise FileNotExistsError()
file_name = file.name
data_source_info = {
"upload_file_id": file_id,
}
# save document
position = DocumentService.get_documents_position(dataset.id)
document = Document(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=position,
data_source_type=document_data["data_source"]["type"],
data_source_info=json.dumps(data_source_info),
dataset_process_rule_id=dataset_process_rule.id,
batch=time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)),
name=file_name,
created_from=created_from,
created_by=account.id,
# created_api_request_id = db.Column(UUID, nullable=True)
)
db.session.add(document)
db.session.commit()
# trigger async task
document_indexing_task.delay(document.dataset_id, document.id)
return document
@staticmethod
def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account):
# save dataset
dataset = Dataset(
tenant_id=tenant_id,
name='',
data_source_type=document_data["data_source"]["type"],
indexing_technique=document_data["indexing_technique"],
created_by=account.id
)
db.session.add(dataset)
db.session.flush()
document = DocumentService.save_document_with_dataset_id(dataset, document_data, account)
cut_length = 18
cut_name = document.name[:cut_length]
dataset.name = cut_name + '...' if len(document.name) > cut_length else cut_name
dataset.description = 'useful for when you want to answer queries about the ' + document.name
db.session.commit()
return dataset, document
@classmethod
def document_create_args_validate(cls, args: dict):
if 'data_source' not in args or not args['data_source']:
raise ValueError("Data source is required")
if not isinstance(args['data_source'], dict):
raise ValueError("Data source is invalid")
if 'type' not in args['data_source'] or not args['data_source']['type']:
raise ValueError("Data source type is required")
if args['data_source']['type'] not in Document.DATA_SOURCES:
raise ValueError("Data source type is invalid")
if args['data_source']['type'] == 'upload_file':
if 'info' not in args['data_source'] or not args['data_source']['info']:
raise ValueError("Data source info is required")
if 'process_rule' not in args or not args['process_rule']:
raise ValueError("Process rule is required")
if not isinstance(args['process_rule'], dict):
raise ValueError("Process rule is invalid")
if 'mode' not in args['process_rule'] or not args['process_rule']['mode']:
raise ValueError("Process rule mode is required")
if args['process_rule']['mode'] not in DatasetProcessRule.MODES:
raise ValueError("Process rule mode is invalid")
if args['process_rule']['mode'] == 'automatic':
args['process_rule']['rules'] = {}
else:
if 'rules' not in args['process_rule'] or not args['process_rule']['rules']:
raise ValueError("Process rule rules is required")
if not isinstance(args['process_rule']['rules'], dict):
raise ValueError("Process rule rules is invalid")
if 'pre_processing_rules' not in args['process_rule']['rules'] \
or args['process_rule']['rules']['pre_processing_rules'] is None:
raise ValueError("Process rule pre_processing_rules is required")
if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list):
raise ValueError("Process rule pre_processing_rules is invalid")
unique_pre_processing_rule_dicts = {}
for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']:
if 'id' not in pre_processing_rule or not pre_processing_rule['id']:
raise ValueError("Process rule pre_processing_rules id is required")
if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES:
raise ValueError("Process rule pre_processing_rules id is invalid")
if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None:
raise ValueError("Process rule pre_processing_rules enabled is required")
if not isinstance(pre_processing_rule['enabled'], bool):
raise ValueError("Process rule pre_processing_rules enabled is invalid")
unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule
args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values())
if 'segmentation' not in args['process_rule']['rules'] \
or args['process_rule']['rules']['segmentation'] is None:
raise ValueError("Process rule segmentation is required")
if not isinstance(args['process_rule']['rules']['segmentation'], dict):
raise ValueError("Process rule segmentation is invalid")
if 'separator' not in args['process_rule']['rules']['segmentation'] \
or not args['process_rule']['rules']['segmentation']['separator']:
raise ValueError("Process rule segmentation separator is required")
if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str):
raise ValueError("Process rule segmentation separator is invalid")
if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \
or not args['process_rule']['rules']['segmentation']['max_tokens']:
raise ValueError("Process rule segmentation max_tokens is required")
if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
raise ValueError("Process rule segmentation max_tokens is invalid")

api/services/errors/__init__.py Normal file

@@ -0,0 +1,7 @@
# -*- coding:utf-8 -*-
__all__ = [
'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset',
'app', 'completion'
]
from . import *

api/services/errors/account.py Normal file

@@ -0,0 +1,53 @@
from services.errors.base import BaseServiceError


class AccountNotFound(BaseServiceError):
    pass


class AccountRegisterError(BaseServiceError):
    pass


class AccountLoginError(BaseServiceError):
    pass


class AccountNotLinkTenantError(BaseServiceError):
    pass


class CurrentPasswordIncorrectError(BaseServiceError):
    pass


class LinkAccountIntegrateError(BaseServiceError):
    pass


class TenantNotFound(BaseServiceError):
    pass


class AccountAlreadyInTenantError(BaseServiceError):
    pass


class InvalidActionError(BaseServiceError):
    pass


class CannotOperateSelfError(BaseServiceError):
    pass


class NoPermissionError(BaseServiceError):
    pass


class MemberNotInTenantError(BaseServiceError):
    pass


class RoleAlreadyAssignedError(BaseServiceError):
    pass

View File

@@ -0,0 +1,2 @@
class MoreLikeThisDisabledError(Exception):
    pass

View File

@@ -0,0 +1,5 @@
from services.errors.base import BaseServiceError


class AppModelConfigBrokenError(BaseServiceError):
    pass

View File

@@ -0,0 +1,3 @@
class BaseServiceError(Exception):
    def __init__(self, description: str = None):
        self.description = description
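A minimal sketch of how the description field is meant to be consumed by callers; the try/except below is illustrative and not part of this commit:

# Illustrative only: the optional description travels with the exception
# so a handler can surface it to the client.
from services.errors.base import BaseServiceError

try:
    raise BaseServiceError('something went wrong')
except BaseServiceError as e:
    print(e.description)  # -> 'something went wrong'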

View File

@@ -0,0 +1,5 @@
from services.errors.base import BaseServiceError


class CompletionStoppedError(BaseServiceError):
    pass

View File

@@ -0,0 +1,13 @@
from services.errors.base import BaseServiceError


class LastConversationNotExistsError(BaseServiceError):
    pass


class ConversationNotExistsError(BaseServiceError):
    pass


class ConversationCompletedError(Exception):
    pass

View File

@@ -0,0 +1,5 @@
from services.errors.base import BaseServiceError


class DatasetNameDuplicateError(BaseServiceError):
    pass

View File

@@ -0,0 +1,5 @@
from services.errors.base import BaseServiceError


class DocumentIndexingError(BaseServiceError):
    pass

View File

@@ -0,0 +1,5 @@
from services.errors.base import BaseServiceError


class FileNotExistsError(BaseServiceError):
    pass

View File

@@ -0,0 +1,5 @@
from services.errors.base import BaseServiceError


class IndexNotInitializedError(BaseServiceError):
    pass

View File

@@ -0,0 +1,17 @@
from services.errors.base import BaseServiceError


class FirstMessageNotExistsError(BaseServiceError):
    pass


class LastMessageNotExistsError(BaseServiceError):
    pass


class MessageNotExistsError(BaseServiceError):
    pass


class SuggestedQuestionsAfterAnswerDisabledError(BaseServiceError):
    pass

View File

@@ -0,0 +1,130 @@
import logging
import time
from typing import List

import numpy as np
from llama_index.data_structs.node_v2 import NodeWithScore
from llama_index.indices.query.schema import QueryBundle
from llama_index.indices.vector_store import GPTVectorStoreIndexQuery
from sklearn.manifold import TSNE

from core.docstore.empty_docstore import EmptyDocumentStore
from core.index.vector_index import VectorIndex
from extensions.ext_database import db
from models.account import Account
from models.dataset import Dataset, DocumentSegment, DatasetQuery
from services.errors.index import IndexNotInitializedError


class HitTestingService:
    @classmethod
    def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
        index = VectorIndex(dataset=dataset).query_index

        if not index:
            raise IndexNotInitializedError()

        index_query = GPTVectorStoreIndexQuery(
            index_struct=index.index_struct,
            service_context=index.service_context,
            vector_store=index.query_context.get('vector_store'),
            docstore=EmptyDocumentStore(),
            response_synthesizer=None,
            similarity_top_k=limit
        )

        query_bundle = QueryBundle(
            query_str=query,
            custom_embedding_strs=[query],
        )

        query_bundle.embedding = index.service_context.embed_model.get_agg_embedding_from_queries(
            query_bundle.embedding_strs
        )

        start = time.perf_counter()
        nodes = index_query.retrieve(query_bundle=query_bundle)
        end = time.perf_counter()
        logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")

        dataset_query = DatasetQuery(
            dataset_id=dataset.id,
            content=query,
            source='hit_testing',
            created_by_role='account',
            created_by=account.id
        )

        db.session.add(dataset_query)
        db.session.commit()

        return cls.compact_retrieve_response(dataset, query_bundle, nodes)

    @classmethod
    def compact_retrieve_response(cls, dataset: Dataset, query_bundle: QueryBundle, nodes: List[NodeWithScore]):
        embeddings = [
            query_bundle.embedding
        ]

        for node in nodes:
            embeddings.append(node.node.embedding)

        tsne_position_data = cls.get_tsne_positions_from_embeddings(embeddings)

        query_position = tsne_position_data.pop(0)

        i = 0
        records = []
        for node in nodes:
            index_node_id = node.node.doc_id

            segment = db.session.query(DocumentSegment).filter(
                DocumentSegment.dataset_id == dataset.id,
                DocumentSegment.enabled == True,
                DocumentSegment.status == 'completed',
                DocumentSegment.index_node_id == index_node_id
            ).first()

            if not segment:
                i += 1
                continue

            record = {
                "segment": segment,
                "score": node.score,
                "tsne_position": tsne_position_data[i]
            }

            records.append(record)

            i += 1

        return {
            "query": {
                "content": query_bundle.query_str,
                "tsne_position": query_position,
            },
            "records": records
        }

    @classmethod
    def get_tsne_positions_from_embeddings(cls, embeddings: list):
        embedding_length = len(embeddings)
        if embedding_length <= 1:
            return [{'x': 0, 'y': 0}]

        concatenate_data = np.array(embeddings).reshape(embedding_length, -1)
        # concatenate_data = np.concatenate(embeddings)

        # t-SNE requires perplexity < n_samples, so clamp it for small inputs
        perplexity = embedding_length / 2 + 1
        if perplexity >= embedding_length:
            perplexity = max(embedding_length - 1, 1)

        tsne = TSNE(n_components=2, perplexity=perplexity, early_exaggeration=12.0)
        data_tsne = tsne.fit_transform(concatenate_data)

        tsne_position_data = []
        for i in range(len(data_tsne)):
            tsne_position_data.append({'x': float(data_tsne[i][0]), 'y': float(data_tsne[i][1])})

        return tsne_position_data
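A minimal usage sketch, assuming dataset and account are already-loaded ORM objects and the dataset's vector index has been built elsewhere:

# Hypothetical caller; `dataset` and `account` are assumed to exist.
response = HitTestingService.retrieve(dataset=dataset, query="refund policy", account=account, limit=10)
for record in response['records']:
    print(record['score'], record['tsne_position'])  # similarity score plus 2-D t-SNE coordinates

The first entry of the t-SNE output is the query itself (popped off as query_position), so each record's position can be plotted relative to the query.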

View File

@@ -0,0 +1,212 @@
from typing import Optional, Union, List

from core.completion import Completion
from core.generator.llm_generator import LLMGenerator
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db
from models.account import Account
from models.model import App, EndUser, Message, MessageFeedback
from services.conversation_service import ConversationService
from services.errors.message import FirstMessageNotExistsError, MessageNotExistsError, LastMessageNotExistsError, \
    SuggestedQuestionsAfterAnswerDisabledError


class MessageService:
    @classmethod
    def pagination_by_first_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
                               conversation_id: str, first_id: Optional[str], limit: int) -> InfiniteScrollPagination:
        if not user:
            return InfiniteScrollPagination(data=[], limit=limit, has_more=False)

        if not conversation_id:
            return InfiniteScrollPagination(data=[], limit=limit, has_more=False)

        conversation = ConversationService.get_conversation(
            app_model=app_model,
            user=user,
            conversation_id=conversation_id
        )

        if first_id:
            first_message = db.session.query(Message) \
                .filter(Message.conversation_id == conversation.id, Message.id == first_id).first()

            if not first_message:
                raise FirstMessageNotExistsError()

            history_messages = db.session.query(Message).filter(
                Message.conversation_id == conversation.id,
                Message.created_at < first_message.created_at,
                Message.id != first_message.id
            ).order_by(Message.created_at.desc()).limit(limit).all()
        else:
            history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \
                .order_by(Message.created_at.desc()).limit(limit).all()

        has_more = False
        if len(history_messages) == limit:
            current_page_first_message = history_messages[-1]
            rest_count = db.session.query(Message).filter(
                Message.conversation_id == conversation.id,
                Message.created_at < current_page_first_message.created_at,
                Message.id != current_page_first_message.id
            ).count()

            if rest_count > 0:
                has_more = True

        history_messages = list(reversed(history_messages))

        return InfiniteScrollPagination(
            data=history_messages,
            limit=limit,
            has_more=has_more
        )

    @classmethod
    def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
                              last_id: Optional[str], limit: int, conversation_id: Optional[str] = None,
                              include_ids: Optional[list] = None) -> InfiniteScrollPagination:
        if not user:
            return InfiniteScrollPagination(data=[], limit=limit, has_more=False)

        base_query = db.session.query(Message)

        if conversation_id is not None:
            conversation = ConversationService.get_conversation(
                app_model=app_model,
                user=user,
                conversation_id=conversation_id
            )

            base_query = base_query.filter(Message.conversation_id == conversation.id)

        if include_ids is not None:
            base_query = base_query.filter(Message.id.in_(include_ids))

        if last_id:
            last_message = base_query.filter(Message.id == last_id).first()

            if not last_message:
                raise LastMessageNotExistsError()

            history_messages = base_query.filter(
                Message.created_at < last_message.created_at,
                Message.id != last_message.id
            ).order_by(Message.created_at.desc()).limit(limit).all()
        else:
            history_messages = base_query.order_by(Message.created_at.desc()).limit(limit).all()

        has_more = False
        if len(history_messages) == limit:
            current_page_first_message = history_messages[-1]
            rest_count = base_query.filter(
                Message.created_at < current_page_first_message.created_at,
                Message.id != current_page_first_message.id
            ).count()

            if rest_count > 0:
                has_more = True

        return InfiniteScrollPagination(
            data=history_messages,
            limit=limit,
            has_more=has_more
        )

    @classmethod
    def create_feedback(cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]],
                        rating: Optional[str]) -> MessageFeedback:
        if not user:
            raise ValueError('user cannot be None')

        message = cls.get_message(
            app_model=app_model,
            user=user,
            message_id=message_id
        )

        feedback = message.user_feedback

        if not rating and feedback:
            db.session.delete(feedback)
        elif rating and feedback:
            feedback.rating = rating
        elif not rating and not feedback:
            raise ValueError('rating cannot be None when feedback not exists')
        else:
            feedback = MessageFeedback(
                app_id=app_model.id,
                conversation_id=message.conversation_id,
                message_id=message.id,
                rating=rating,
                from_source=('user' if isinstance(user, EndUser) else 'admin'),
                from_end_user_id=(user.id if isinstance(user, EndUser) else None),
                from_account_id=(user.id if isinstance(user, Account) else None),
            )
            db.session.add(feedback)

        db.session.commit()

        return feedback

    @classmethod
    def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
        message = db.session.query(Message).filter(
            Message.id == message_id,
            Message.app_id == app_model.id,
            Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
            Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
            Message.from_account_id == (user.id if isinstance(user, Account) else None),
        ).first()

        if not message:
            raise MessageNotExistsError()

        return message

    @classmethod
    def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Union[Account, EndUser]],
                                             message_id: str, check_enabled: bool = True) -> List[Message]:
        if not user:
            raise ValueError('user cannot be None')

        app_model_config = app_model.app_model_config
        suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict
        if check_enabled and suggested_questions_after_answer.get("enabled", False) is False:
            raise SuggestedQuestionsAfterAnswerDisabledError()

        message = cls.get_message(
            app_model=app_model,
            user=user,
            message_id=message_id
        )

        conversation = ConversationService.get_conversation(
            app_model=app_model,
            conversation_id=message.conversation_id,
            user=user
        )

        # get memory of conversation (read-only)
        memory = Completion.get_memory_from_conversation(
            tenant_id=app_model.tenant_id,
            app_model_config=app_model.app_model_config,
            conversation=conversation,
            max_token_limit=3000,
            message_limit=3,
            return_messages=False,
            memory_key="histories"
        )

        external_context = memory.load_memory_variables({})

        questions = LLMGenerator.generate_suggested_questions_after_answer(
            tenant_id=app_model.tenant_id,
            **external_context
        )

        return questions
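A sketch of cursor-style paging built on pagination_by_last_id; app_model, user and conversation_id are assumed to come from the surrounding request context:

# Illustrative pagination loop (surrounding objects and handle() are hypothetical).
last_id = None
while True:
    page = MessageService.pagination_by_last_id(
        app_model=app_model,
        user=user,
        last_id=last_id,
        limit=20,
        conversation_id=conversation_id
    )
    for message in page.data:
        handle(message)

    if not page.has_more:
        break

    # messages come back newest-first, so the last item is the next cursor
    last_id = page.data[-1].id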

View File

@@ -0,0 +1,96 @@
from typing import Union

from flask import current_app

from core.llm.provider.llm_provider_service import LLMProviderService
from models.account import Tenant
from models.provider import *


class ProviderService:

    @staticmethod
    def init_supported_provider(tenant, edition):
        """Initialize the model providers, checking whether each supported provider already has a record"""
        providers = Provider.query.filter_by(tenant_id=tenant.id).all()

        openai_provider_exists = False
        azure_openai_provider_exists = False

        # TODO: The cloud version needs to construct the data of the SYSTEM type

        for provider in providers:
            if provider.provider_name == ProviderName.OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value:
                openai_provider_exists = True
            if provider.provider_name == ProviderName.AZURE_OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value:
                azure_openai_provider_exists = True

        # Create default providers if they don't exist
        if not openai_provider_exists:
            openai_provider = Provider(
                tenant_id=tenant.id,
                provider_name=ProviderName.OPENAI.value,
                provider_type=ProviderType.CUSTOM.value,
                is_valid=False
            )
            db.session.add(openai_provider)

        if not azure_openai_provider_exists:
            azure_openai_provider = Provider(
                tenant_id=tenant.id,
                provider_name=ProviderName.AZURE_OPENAI.value,
                provider_type=ProviderType.CUSTOM.value,
                is_valid=False
            )
            db.session.add(azure_openai_provider)

        if not openai_provider_exists or not azure_openai_provider_exists:
            db.session.commit()

    @staticmethod
    def get_obfuscated_api_key(tenant, provider_name: ProviderName):
        llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
        return llm_provider_service.get_provider_configs(obfuscated=True)

    @staticmethod
    def get_token_type(tenant, provider_name: ProviderName):
        llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
        return llm_provider_service.get_token_type()

    @staticmethod
    def validate_provider_configs(tenant, provider_name: ProviderName, configs: Union[dict, str]):
        llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
        return llm_provider_service.config_validate(configs)

    @staticmethod
    def get_encrypted_token(tenant, provider_name: ProviderName, configs: Union[dict, str]):
        llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
        return llm_provider_service.get_encrypted_token(configs)

    @staticmethod
    def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value,
                               is_valid: bool = True):
        if current_app.config['EDITION'] != 'CLOUD':
            return

        provider = db.session.query(Provider).filter(
            Provider.tenant_id == tenant.id,
            Provider.provider_name == provider_name,
            Provider.provider_type == ProviderType.SYSTEM.value
        ).one_or_none()

        if not provider:
            provider = Provider(
                tenant_id=tenant.id,
                provider_name=provider_name,
                provider_type=ProviderType.SYSTEM.value,
                quota_type=ProviderQuotaType.TRIAL.value,
                quota_limit=200,
                encrypted_config='',
                is_valid=is_valid,
            )
            db.session.add(provider)
            db.session.commit()
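A sketch of the intended call site, assuming these hooks run right after a tenant is created; create_system_provider already guards on the edition internally:

# Hypothetical call site after a tenant is created.
ProviderService.init_supported_provider(tenant, current_app.config['EDITION'])
ProviderService.create_system_provider(tenant)  # no-op outside the CLOUD edition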

View File

@@ -0,0 +1,66 @@
from typing import Optional

from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db
from models.model import App, EndUser
from models.web import SavedMessage
from services.message_service import MessageService


class SavedMessageService:
    @classmethod
    def pagination_by_last_id(cls, app_model: App, end_user: Optional[EndUser],
                              last_id: Optional[str], limit: int) -> InfiniteScrollPagination:
        saved_messages = db.session.query(SavedMessage).filter(
            SavedMessage.app_id == app_model.id,
            SavedMessage.created_by == end_user.id
        ).order_by(SavedMessage.created_at.desc()).all()
        message_ids = [sm.message_id for sm in saved_messages]

        return MessageService.pagination_by_last_id(
            app_model=app_model,
            user=end_user,
            last_id=last_id,
            limit=limit,
            include_ids=message_ids
        )

    @classmethod
    def save(cls, app_model: App, user: Optional[EndUser], message_id: str):
        saved_message = db.session.query(SavedMessage).filter(
            SavedMessage.app_id == app_model.id,
            SavedMessage.message_id == message_id,
            SavedMessage.created_by == user.id
        ).first()

        if saved_message:
            return

        message = MessageService.get_message(
            app_model=app_model,
            user=user,
            message_id=message_id
        )

        saved_message = SavedMessage(
            app_id=app_model.id,
            message_id=message.id,
            created_by=user.id
        )

        db.session.add(saved_message)
        db.session.commit()

    @classmethod
    def delete(cls, app_model: App, user: Optional[EndUser], message_id: str):
        saved_message = db.session.query(SavedMessage).filter(
            SavedMessage.app_id == app_model.id,
            SavedMessage.message_id == message_id,
            SavedMessage.created_by == user.id
        ).first()

        if not saved_message:
            return

        db.session.delete(saved_message)
        db.session.commit()
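Illustrative usage, assuming app_model, end_user and message_id come from the request context:

# Save a message for an end user, then page through the saved messages.
SavedMessageService.save(app_model=app_model, user=end_user, message_id=message_id)
pagination = SavedMessageService.pagination_by_last_id(app_model=app_model, end_user=end_user,
                                                       last_id=None, limit=20)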

View File

@@ -0,0 +1,74 @@
from typing import Optional

from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db
from models.model import App, EndUser
from models.web import PinnedConversation
from services.conversation_service import ConversationService


class WebConversationService:
    @classmethod
    def pagination_by_last_id(cls, app_model: App, end_user: Optional[EndUser],
                              last_id: Optional[str], limit: int, pinned: Optional[bool] = None) -> InfiniteScrollPagination:
        include_ids = None
        exclude_ids = None
        if pinned is not None:
            pinned_conversations = db.session.query(PinnedConversation).filter(
                PinnedConversation.app_id == app_model.id,
                PinnedConversation.created_by == end_user.id
            ).order_by(PinnedConversation.created_at.desc()).all()
            pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations]
            if pinned:
                include_ids = pinned_conversation_ids
            else:
                exclude_ids = pinned_conversation_ids

        return ConversationService.pagination_by_last_id(
            app_model=app_model,
            user=end_user,
            last_id=last_id,
            limit=limit,
            include_ids=include_ids,
            exclude_ids=exclude_ids
        )

    @classmethod
    def pin(cls, app_model: App, conversation_id: str, user: Optional[EndUser]):
        pinned_conversation = db.session.query(PinnedConversation).filter(
            PinnedConversation.app_id == app_model.id,
            PinnedConversation.conversation_id == conversation_id,
            PinnedConversation.created_by == user.id
        ).first()

        if pinned_conversation:
            return

        conversation = ConversationService.get_conversation(
            app_model=app_model,
            conversation_id=conversation_id,
            user=user
        )

        pinned_conversation = PinnedConversation(
            app_id=app_model.id,
            conversation_id=conversation.id,
            created_by=user.id
        )

        db.session.add(pinned_conversation)
        db.session.commit()

    @classmethod
    def unpin(cls, app_model: App, conversation_id: str, user: Optional[EndUser]):
        pinned_conversation = db.session.query(PinnedConversation).filter(
            PinnedConversation.app_id == app_model.id,
            PinnedConversation.conversation_id == conversation_id,
            PinnedConversation.created_by == user.id
        ).first()

        if not pinned_conversation:
            return

        db.session.delete(pinned_conversation)
        db.session.commit()
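For example, listing only the conversations the current end user has pinned might look like this (surrounding objects assumed):

# Illustrative: pinned=True narrows results to pinned conversations,
# pinned=False excludes them, and pinned=None returns everything.
pagination = WebConversationService.pagination_by_last_id(app_model=app_model, end_user=end_user,
                                                          last_id=None, limit=20, pinned=True)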

View File

@@ -0,0 +1,49 @@
from extensions.ext_database import db
from models.account import Tenant
from models.provider import Provider, ProviderType


class WorkspaceService:
    @classmethod
    def get_tenant_info(cls, tenant: Tenant):
        tenant_info = {
            'id': tenant.id,
            'name': tenant.name,
            'plan': tenant.plan,
            'status': tenant.status,
            'created_at': tenant.created_at,
            'providers': [],
            'in_trail': False,
            'trial_end_reason': 'using_custom'
        }

        # Get providers
        providers = db.session.query(Provider).filter(
            Provider.tenant_id == tenant.id
        ).all()

        # Add providers to the tenant info
        tenant_info['providers'] = providers

        custom_provider = None
        system_provider = None

        for provider in providers:
            if provider.provider_type == ProviderType.CUSTOM.value:
                if provider.is_valid and provider.encrypted_config:
                    custom_provider = provider
            elif provider.provider_type == ProviderType.SYSTEM.value:
                if provider.is_valid:
                    system_provider = provider

        if system_provider and not custom_provider:
            quota_used = system_provider.quota_used if system_provider.quota_used is not None else 0
            quota_limit = system_provider.quota_limit if system_provider.quota_limit is not None else 0
            if quota_used >= quota_limit:
                tenant_info['trial_end_reason'] = 'trial_exceeded'
            else:
                tenant_info['in_trail'] = True
                tenant_info['trial_end_reason'] = None

        return tenant_info
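For reference, the returned dict has roughly this shape; the values below are illustrative, and 'in_trail' is the key name exactly as spelled in this commit:

# Approximate shape of get_tenant_info's return value (illustrative values).
{
    'id': '...',
    'name': 'My Workspace',
    'plan': 'basic',
    'status': 'normal',
    'created_at': '...',
    'providers': [],            # Provider ORM objects
    'in_trail': True,           # True while the system provider quota lasts
    'trial_end_reason': None,   # or 'using_custom' / 'trial_exceeded'
}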