mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-19 07:46:53 +08:00
feat: backend model load balancing support (#4927)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import os.path
|
||||
|
||||
from core.helper.position_helper import get_position_map, sort_by_position_map
|
||||
from core.tools.entities.api_entities import UserToolProvider
|
||||
from core.utils.position_helper import get_position_map, sort_by_position_map
|
||||
|
||||
|
||||
class BuiltinToolProviderSort:
|
||||
|
||||
@@ -2,6 +2,7 @@ from abc import abstractmethod
|
||||
from os import listdir, path
|
||||
from typing import Any
|
||||
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
|
||||
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
|
||||
from core.tools.errors import (
|
||||
@@ -14,7 +15,6 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
from core.utils.module_import_helper import load_single_subclass_from_source
|
||||
|
||||
|
||||
class BuiltinToolProviderController(ToolProviderController):
|
||||
@@ -82,7 +82,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
return {}
|
||||
|
||||
return self.credentials_schema.copy()
|
||||
|
||||
|
||||
def get_tools(self) -> list[Tool]:
|
||||
"""
|
||||
returns a list of tools that the provider can provide
|
||||
@@ -127,7 +127,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
:return: type of the provider
|
||||
"""
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
|
||||
@property
|
||||
def tool_labels(self) -> list[str]:
|
||||
"""
|
||||
@@ -137,7 +137,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
label_enums = self._get_tool_labels()
|
||||
return [default_tool_label_dict[label].name for label in label_enums]
|
||||
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
"""
|
||||
returns the labels of the provider
|
||||
|
||||
@@ -10,6 +10,7 @@ from flask import current_app
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
@@ -31,7 +32,6 @@ from core.tools.utils.configuration import (
|
||||
ToolParameterConfigurationManager,
|
||||
)
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
from core.utils.module_import_helper import load_single_subclass_from_source
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
@@ -102,10 +102,10 @@ class ToolManager:
|
||||
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
|
||||
|
||||
@classmethod
|
||||
def get_tool_runtime(cls, provider_type: str,
|
||||
def get_tool_runtime(cls, provider_type: str,
|
||||
provider_id: str,
|
||||
tool_name: str,
|
||||
tenant_id: str,
|
||||
tool_name: str,
|
||||
tenant_id: str,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
|
||||
-> Union[BuiltinTool, ApiTool]:
|
||||
@@ -222,7 +222,7 @@ class ToolManager:
|
||||
get the agent tool runtime
|
||||
"""
|
||||
tool_entity = cls.get_tool_runtime(
|
||||
provider_type=agent_tool.provider_type,
|
||||
provider_type=agent_tool.provider_type,
|
||||
provider_id=agent_tool.provider_id,
|
||||
tool_name=agent_tool.tool_name,
|
||||
tenant_id=tenant_id,
|
||||
@@ -235,7 +235,7 @@ class ToolManager:
|
||||
# check file types
|
||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
||||
raise ValueError(f"file type parameter {parameter.name} not supported in agent")
|
||||
|
||||
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
# save tool parameter to tool entity memory
|
||||
value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
|
||||
@@ -403,7 +403,7 @@ class ToolManager:
|
||||
|
||||
# get builtin providers
|
||||
builtin_providers = cls.list_builtin_providers()
|
||||
|
||||
|
||||
# get db builtin providers
|
||||
db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
|
||||
filter(BuiltinToolProvider.tenant_id == tenant_id).all()
|
||||
@@ -428,7 +428,7 @@ class ToolManager:
|
||||
if 'api' in filters:
|
||||
db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
|
||||
filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
|
||||
api_provider_controllers = [{
|
||||
'provider': provider,
|
||||
'controller': ToolTransformService.api_provider_to_controller(provider)
|
||||
@@ -450,7 +450,7 @@ class ToolManager:
|
||||
# get workflow providers
|
||||
workflow_providers: list[WorkflowToolProvider] = db.session.query(WorkflowToolProvider). \
|
||||
filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
|
||||
workflow_provider_controllers = []
|
||||
for provider in workflow_providers:
|
||||
try:
|
||||
@@ -460,7 +460,7 @@ class ToolManager:
|
||||
except Exception as e:
|
||||
# app has been deleted
|
||||
pass
|
||||
|
||||
|
||||
labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
|
||||
|
||||
for provider_controller in workflow_provider_controllers:
|
||||
|
||||
@@ -73,10 +73,8 @@ class ModelInvocationUtils:
|
||||
if not model_instance:
|
||||
raise InvokeModelError('Model not found')
|
||||
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
|
||||
# get tokens
|
||||
tokens = llm_model.get_num_tokens(model_instance.model, model_instance.credentials, prompt_messages)
|
||||
tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
return tokens
|
||||
|
||||
@@ -108,13 +106,8 @@ class ModelInvocationUtils:
|
||||
tenant_id=tenant_id, model_type=ModelType.LLM,
|
||||
)
|
||||
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
|
||||
# get model credentials
|
||||
model_credentials = model_instance.credentials
|
||||
|
||||
# get prompt tokens
|
||||
prompt_tokens = llm_model.get_num_tokens(model_instance.model, model_credentials, prompt_messages)
|
||||
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
model_parameters = {
|
||||
'temperature': 0.8,
|
||||
@@ -144,9 +137,7 @@ class ModelInvocationUtils:
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
response: LLMResult = llm_model.invoke(
|
||||
model=model_instance.model,
|
||||
credentials=model_credentials,
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=[], stop=[], stream=False, user=user_id, callbacks=[]
|
||||
@@ -176,4 +167,4 @@ class ModelInvocationUtils:
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return response
|
||||
return response
|
||||
|
||||
Reference in New Issue
Block a user