feat: backend model load balancing support (#4927)

This commit is contained in:
takatost
2024-06-05 00:13:04 +08:00
committed by GitHub
parent 52ec152dd3
commit d1dbbc1e33
47 changed files with 2191 additions and 256 deletions

View File

@@ -1,7 +1,7 @@
import os.path
from core.helper.position_helper import get_position_map, sort_by_position_map
from core.tools.entities.api_entities import UserToolProvider
from core.utils.position_helper import get_position_map, sort_by_position_map
class BuiltinToolProviderSort:

View File

@@ -2,6 +2,7 @@ from abc import abstractmethod
from os import listdir, path
from typing import Any
from core.helper.module_import_helper import load_single_subclass_from_source
from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
from core.tools.errors import (
@@ -14,7 +15,6 @@ from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from core.tools.utils.yaml_utils import load_yaml_file
from core.utils.module_import_helper import load_single_subclass_from_source
class BuiltinToolProviderController(ToolProviderController):
@@ -82,7 +82,7 @@ class BuiltinToolProviderController(ToolProviderController):
return {}
return self.credentials_schema.copy()
def get_tools(self) -> list[Tool]:
"""
returns a list of tools that the provider can provide
@@ -127,7 +127,7 @@ class BuiltinToolProviderController(ToolProviderController):
:return: type of the provider
"""
return ToolProviderType.BUILT_IN
@property
def tool_labels(self) -> list[str]:
"""
@@ -137,7 +137,7 @@ class BuiltinToolProviderController(ToolProviderController):
"""
label_enums = self._get_tool_labels()
return [default_tool_label_dict[label].name for label in label_enums]
def _get_tool_labels(self) -> list[ToolLabelEnum]:
"""
returns the labels of the provider

View File

@@ -10,6 +10,7 @@ from flask import current_app
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
from core.tools.entities.common_entities import I18nObject
@@ -31,7 +32,6 @@ from core.tools.utils.configuration import (
ToolParameterConfigurationManager,
)
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from core.utils.module_import_helper import load_single_subclass_from_source
from core.workflow.nodes.tool.entities import ToolEntity
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
@@ -102,10 +102,10 @@ class ToolManager:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
@classmethod
def get_tool_runtime(cls, provider_type: str,
def get_tool_runtime(cls, provider_type: str,
provider_id: str,
tool_name: str,
tenant_id: str,
tool_name: str,
tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
-> Union[BuiltinTool, ApiTool]:
@@ -222,7 +222,7 @@ class ToolManager:
get the agent tool runtime
"""
tool_entity = cls.get_tool_runtime(
provider_type=agent_tool.provider_type,
provider_type=agent_tool.provider_type,
provider_id=agent_tool.provider_id,
tool_name=agent_tool.tool_name,
tenant_id=tenant_id,
@@ -235,7 +235,7 @@ class ToolManager:
# check file types
if parameter.type == ToolParameter.ToolParameterType.FILE:
raise ValueError(f"file type parameter {parameter.name} not supported in agent")
if parameter.form == ToolParameter.ToolParameterForm.FORM:
# save tool parameter to tool entity memory
value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
@@ -403,7 +403,7 @@ class ToolManager:
# get builtin providers
builtin_providers = cls.list_builtin_providers()
# get db builtin providers
db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
filter(BuiltinToolProvider.tenant_id == tenant_id).all()
@@ -428,7 +428,7 @@ class ToolManager:
if 'api' in filters:
db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
filter(ApiToolProvider.tenant_id == tenant_id).all()
api_provider_controllers = [{
'provider': provider,
'controller': ToolTransformService.api_provider_to_controller(provider)
@@ -450,7 +450,7 @@ class ToolManager:
# get workflow providers
workflow_providers: list[WorkflowToolProvider] = db.session.query(WorkflowToolProvider). \
filter(WorkflowToolProvider.tenant_id == tenant_id).all()
workflow_provider_controllers = []
for provider in workflow_providers:
try:
@@ -460,7 +460,7 @@ class ToolManager:
except Exception as e:
# app has been deleted
pass
labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
for provider_controller in workflow_provider_controllers:

View File

@@ -73,10 +73,8 @@ class ModelInvocationUtils:
if not model_instance:
raise InvokeModelError('Model not found')
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
# get tokens
tokens = llm_model.get_num_tokens(model_instance.model, model_instance.credentials, prompt_messages)
tokens = model_instance.get_llm_num_tokens(prompt_messages)
return tokens
@@ -108,13 +106,8 @@ class ModelInvocationUtils:
tenant_id=tenant_id, model_type=ModelType.LLM,
)
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
# get model credentials
model_credentials = model_instance.credentials
# get prompt tokens
prompt_tokens = llm_model.get_num_tokens(model_instance.model, model_credentials, prompt_messages)
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
model_parameters = {
'temperature': 0.8,
@@ -144,9 +137,7 @@ class ModelInvocationUtils:
db.session.commit()
try:
response: LLMResult = llm_model.invoke(
model=model_instance.model,
credentials=model_credentials,
response: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[], stop=[], stream=False, user=user_id, callbacks=[]
@@ -176,4 +167,4 @@ class ModelInvocationUtils:
db.session.commit()
return response
return response