feat: server multi models support (#799)

This commit is contained in:
takatost
2023-08-12 00:57:00 +08:00
committed by GitHub
parent d8b712b325
commit 5fa2161b05
213 changed files with 10556 additions and 2579 deletions

View File

@@ -6,9 +6,9 @@ from core.callback_handler.entity.agent_loop import AgentLoop
from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.callback_handler.entity.llm_message import LLMMessage
from core.callback_handler.entity.chain_result import ChainResult
from core.constant import llm_constant
from core.llm.llm_builder import LLMBuilder
from core.llm.provider.llm_provider_service import LLMProviderService
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import to_prompt_messages, MessageType
from core.model_providers.models.llm.base import BaseLLM
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from events.message_event import message_was_created
@@ -16,12 +16,11 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DatasetQuery
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain
from models.provider import ProviderType, Provider
class ConversationMessageTask:
def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
inputs: dict, query: str, streaming: bool,
inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
conversation: Optional[Conversation] = None, is_override: bool = False):
self.task_id = task_id
@@ -38,9 +37,12 @@ class ConversationMessageTask:
self.conversation = conversation
self.is_new_conversation = False
self.model_instance = model_instance
self.message = None
self.model_dict = self.app_model_config.model_dict
self.provider_name = self.model_dict.get('provider')
self.model_name = self.model_dict.get('name')
self.mode = app.mode
@@ -56,9 +58,6 @@ class ConversationMessageTask:
)
def init(self):
provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name)
self.model_dict['provider'] = provider_name
override_model_configs = None
if self.is_override:
override_model_configs = {
@@ -89,15 +88,19 @@ class ConversationMessageTask:
if self.app_model_config.pre_prompt:
system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
system_instruction = system_message.content
llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
system_instruction_tokens = llm.get_num_tokens_from_messages([system_message])
model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id,
model_provider_name=self.provider_name,
model_name=self.model_name
)
system_instruction_tokens = model_instance.get_num_tokens(to_prompt_messages([system_message]))
if not self.conversation:
self.is_new_conversation = True
self.conversation = Conversation(
app_id=self.app_model_config.app_id,
app_model_config_id=self.app_model_config.id,
model_provider=self.model_dict.get('provider'),
model_provider=self.provider_name,
model_id=self.model_name,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=self.mode,
@@ -117,7 +120,7 @@ class ConversationMessageTask:
self.message = Message(
app_id=self.app_model_config.app_id,
model_provider=self.model_dict.get('provider'),
model_provider=self.provider_name,
model_id=self.model_name,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
conversation_id=self.conversation.id,
@@ -131,7 +134,7 @@ class ConversationMessageTask:
answer_unit_price=0,
provider_response_latency=0,
total_price=0,
currency=llm_constant.model_currency,
currency=self.model_instance.get_currency(),
from_source=('console' if isinstance(self.user, Account) else 'api'),
from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
from_account_id=(self.user.id if isinstance(self.user, Account) else None),
@@ -145,12 +148,10 @@ class ConversationMessageTask:
self._pub_handler.pub_text(text)
def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
model_name = self.app_model_config.model_dict.get('name')
message_tokens = llm_message.prompt_tokens
answer_tokens = llm_message.completion_tokens
message_unit_price = llm_constant.model_prices[model_name]['prompt']
answer_unit_price = llm_constant.model_prices[model_name]['completion']
message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN)
answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT)
total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)
@@ -163,8 +164,6 @@ class ConversationMessageTask:
self.message.provider_response_latency = llm_message.latency
self.message.total_price = total_price
self.update_provider_quota()
db.session.commit()
message_was_created.send(
@@ -176,20 +175,6 @@ class ConversationMessageTask:
if not by_stopped:
self.end()
def update_provider_quota(self):
llm_provider_service = LLMProviderService(
tenant_id=self.app.tenant_id,
provider_name=self.message.model_provider,
)
provider = llm_provider_service.get_provider_db_record()
if provider and provider.provider_type == ProviderType.SYSTEM.value:
db.session.query(Provider).filter(
Provider.tenant_id == self.app.tenant_id,
Provider.provider_name == provider.provider_name,
Provider.quota_limit > Provider.quota_used
).update({'quota_used': Provider.quota_used + 1})
def init_chain(self, chain_result: ChainResult):
message_chain = MessageChain(
message_id=self.message.id,
@@ -229,10 +214,10 @@ class ConversationMessageTask:
return message_agent_thought
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_name: str,
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
agent_loop: AgentLoop):
agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt']
agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion']
agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN)
agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT)
loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens
@@ -253,7 +238,7 @@ class ConversationMessageTask:
message_agent_thought.latency = agent_loop.latency
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
message_agent_thought.total_price = loop_total_price
message_agent_thought.currency = llm_constant.model_currency
message_agent_thought.currency = agent_model_instant.get_currency()
db.session.flush()
def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):