mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 03:16:51 +08:00
feat: backend model load balancing support (#4927)
This commit is contained in:
@@ -16,6 +16,7 @@ class ModelStatus(Enum):
|
||||
NO_CONFIGURE = "no-configure"
|
||||
QUOTA_EXCEEDED = "quota-exceeded"
|
||||
NO_PERMISSION = "no-permission"
|
||||
DISABLED = "disabled"
|
||||
|
||||
|
||||
class SimpleModelProviderEntity(BaseModel):
|
||||
@@ -43,12 +44,19 @@ class SimpleModelProviderEntity(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class ModelWithProviderEntity(ProviderModel):
|
||||
class ProviderModelWithStatusEntity(ProviderModel):
|
||||
"""
|
||||
Model class for model response.
|
||||
"""
|
||||
status: ModelStatus
|
||||
load_balancing_enabled: bool = False
|
||||
|
||||
|
||||
class ModelWithProviderEntity(ProviderModelWithStatusEntity):
|
||||
"""
|
||||
Model with provider entity.
|
||||
"""
|
||||
provider: SimpleModelProviderEntity
|
||||
status: ModelStatus
|
||||
|
||||
|
||||
class DefaultModelProviderEntity(BaseModel):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
@@ -8,7 +9,12 @@ from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
||||
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration, SystemConfigurationStatus
|
||||
from core.entities.provider_entities import (
|
||||
CustomConfiguration,
|
||||
ModelSettings,
|
||||
SystemConfiguration,
|
||||
SystemConfigurationStatus,
|
||||
)
|
||||
from core.helper import encrypter
|
||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
from core.model_runtime.entities.model_entities import FetchFrom, ModelType
|
||||
@@ -22,7 +28,14 @@ from core.model_runtime.model_providers import model_provider_factory
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
from extensions.ext_database import db
|
||||
from models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider
|
||||
from models.provider import (
|
||||
LoadBalancingModelConfig,
|
||||
Provider,
|
||||
ProviderModel,
|
||||
ProviderModelSetting,
|
||||
ProviderType,
|
||||
TenantPreferredModelProvider,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -39,6 +52,7 @@ class ProviderConfiguration(BaseModel):
|
||||
using_provider_type: ProviderType
|
||||
system_configuration: SystemConfiguration
|
||||
custom_configuration: CustomConfiguration
|
||||
model_settings: list[ModelSettings]
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@@ -62,6 +76,14 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
if self.model_settings:
|
||||
# check if model is disabled by admin
|
||||
for model_setting in self.model_settings:
|
||||
if (model_setting.model_type == model_type
|
||||
and model_setting.model == model):
|
||||
if not model_setting.enabled:
|
||||
raise ValueError(f'Model {model} is disabled.')
|
||||
|
||||
if self.using_provider_type == ProviderType.SYSTEM:
|
||||
restrict_models = []
|
||||
for quota_configuration in self.system_configuration.quota_configurations:
|
||||
@@ -80,15 +102,17 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
return copy_credentials
|
||||
else:
|
||||
credentials = None
|
||||
if self.custom_configuration.models:
|
||||
for model_configuration in self.custom_configuration.models:
|
||||
if model_configuration.model_type == model_type and model_configuration.model == model:
|
||||
return model_configuration.credentials
|
||||
credentials = model_configuration.credentials
|
||||
break
|
||||
|
||||
if self.custom_configuration.provider:
|
||||
return self.custom_configuration.provider.credentials
|
||||
else:
|
||||
return None
|
||||
credentials = self.custom_configuration.provider.credentials
|
||||
|
||||
return credentials
|
||||
|
||||
def get_system_configuration_status(self) -> SystemConfigurationStatus:
|
||||
"""
|
||||
@@ -130,7 +154,7 @@ class ProviderConfiguration(BaseModel):
|
||||
return credentials
|
||||
|
||||
# Obfuscate credentials
|
||||
return self._obfuscated_credentials(
|
||||
return self.obfuscated_credentials(
|
||||
credentials=credentials,
|
||||
credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
|
||||
if self.provider.provider_credential_schema else []
|
||||
@@ -151,7 +175,7 @@ class ProviderConfiguration(BaseModel):
|
||||
).first()
|
||||
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self._extract_secret_variables(
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
self.provider.provider_credential_schema.credential_form_schemas
|
||||
if self.provider.provider_credential_schema else []
|
||||
)
|
||||
@@ -274,7 +298,7 @@ class ProviderConfiguration(BaseModel):
|
||||
return credentials
|
||||
|
||||
# Obfuscate credentials
|
||||
return self._obfuscated_credentials(
|
||||
return self.obfuscated_credentials(
|
||||
credentials=credentials,
|
||||
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
|
||||
if self.provider.model_credential_schema else []
|
||||
@@ -302,7 +326,7 @@ class ProviderConfiguration(BaseModel):
|
||||
).first()
|
||||
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self._extract_secret_variables(
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
self.provider.model_credential_schema.credential_form_schemas
|
||||
if self.provider.model_credential_schema else []
|
||||
)
|
||||
@@ -402,6 +426,160 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
provider_model_credentials_cache.delete()
|
||||
|
||||
def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
|
||||
"""
|
||||
Enable model.
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
model_setting = db.session.query(ProviderModelSetting) \
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model
|
||||
).first()
|
||||
|
||||
if model_setting:
|
||||
model_setting.enabled = True
|
||||
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
else:
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_name=model,
|
||||
enabled=True
|
||||
)
|
||||
db.session.add(model_setting)
|
||||
db.session.commit()
|
||||
|
||||
return model_setting
|
||||
|
||||
def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
|
||||
"""
|
||||
Disable model.
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
model_setting = db.session.query(ProviderModelSetting) \
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model
|
||||
).first()
|
||||
|
||||
if model_setting:
|
||||
model_setting.enabled = False
|
||||
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
else:
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_name=model,
|
||||
enabled=False
|
||||
)
|
||||
db.session.add(model_setting)
|
||||
db.session.commit()
|
||||
|
||||
return model_setting
|
||||
|
||||
def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
|
||||
"""
|
||||
Get provider model setting.
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
return db.session.query(ProviderModelSetting) \
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model
|
||||
).first()
|
||||
|
||||
def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
|
||||
"""
|
||||
Enable model load balancing.
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \
|
||||
.filter(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model
|
||||
).count()
|
||||
|
||||
if load_balancing_config_count <= 1:
|
||||
raise ValueError('Model load balancing configuration must be more than 1.')
|
||||
|
||||
model_setting = db.session.query(ProviderModelSetting) \
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model
|
||||
).first()
|
||||
|
||||
if model_setting:
|
||||
model_setting.load_balancing_enabled = True
|
||||
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
else:
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_name=model,
|
||||
load_balancing_enabled=True
|
||||
)
|
||||
db.session.add(model_setting)
|
||||
db.session.commit()
|
||||
|
||||
return model_setting
|
||||
|
||||
def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
|
||||
"""
|
||||
Disable model load balancing.
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
model_setting = db.session.query(ProviderModelSetting) \
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model
|
||||
).first()
|
||||
|
||||
if model_setting:
|
||||
model_setting.load_balancing_enabled = False
|
||||
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
else:
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_name=model,
|
||||
load_balancing_enabled=False
|
||||
)
|
||||
db.session.add(model_setting)
|
||||
db.session.commit()
|
||||
|
||||
return model_setting
|
||||
|
||||
def get_provider_instance(self) -> ModelProvider:
|
||||
"""
|
||||
Get provider instance.
|
||||
@@ -453,7 +631,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
db.session.commit()
|
||||
|
||||
def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
|
||||
def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
|
||||
"""
|
||||
Extract secret input form variables.
|
||||
|
||||
@@ -467,7 +645,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
return secret_input_form_variables
|
||||
|
||||
def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
|
||||
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
|
||||
"""
|
||||
Obfuscated credentials.
|
||||
|
||||
@@ -476,7 +654,7 @@ class ProviderConfiguration(BaseModel):
|
||||
:return:
|
||||
"""
|
||||
# Get provider credential secret variables
|
||||
credential_secret_variables = self._extract_secret_variables(
|
||||
credential_secret_variables = self.extract_secret_variables(
|
||||
credential_form_schemas
|
||||
)
|
||||
|
||||
@@ -522,15 +700,22 @@ class ProviderConfiguration(BaseModel):
|
||||
else:
|
||||
model_types = provider_instance.get_provider_schema().supported_model_types
|
||||
|
||||
# Group model settings by model type and model
|
||||
model_setting_map = defaultdict(dict)
|
||||
for model_setting in self.model_settings:
|
||||
model_setting_map[model_setting.model_type][model_setting.model] = model_setting
|
||||
|
||||
if self.using_provider_type == ProviderType.SYSTEM:
|
||||
provider_models = self._get_system_provider_models(
|
||||
model_types=model_types,
|
||||
provider_instance=provider_instance
|
||||
provider_instance=provider_instance,
|
||||
model_setting_map=model_setting_map
|
||||
)
|
||||
else:
|
||||
provider_models = self._get_custom_provider_models(
|
||||
model_types=model_types,
|
||||
provider_instance=provider_instance
|
||||
provider_instance=provider_instance,
|
||||
model_setting_map=model_setting_map
|
||||
)
|
||||
|
||||
if only_active:
|
||||
@@ -541,18 +726,27 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
def _get_system_provider_models(self,
|
||||
model_types: list[ModelType],
|
||||
provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
|
||||
provider_instance: ModelProvider,
|
||||
model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
|
||||
-> list[ModelWithProviderEntity]:
|
||||
"""
|
||||
Get system provider models.
|
||||
|
||||
:param model_types: model types
|
||||
:param provider_instance: provider instance
|
||||
:param model_setting_map: model setting map
|
||||
:return:
|
||||
"""
|
||||
provider_models = []
|
||||
for model_type in model_types:
|
||||
provider_models.extend(
|
||||
[
|
||||
for m in provider_instance.models(model_type):
|
||||
status = ModelStatus.ACTIVE
|
||||
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
|
||||
model_setting = model_setting_map[m.model_type][m.model]
|
||||
if model_setting.enabled is False:
|
||||
status = ModelStatus.DISABLED
|
||||
|
||||
provider_models.append(
|
||||
ModelWithProviderEntity(
|
||||
model=m.model,
|
||||
label=m.label,
|
||||
@@ -562,11 +756,9 @@ class ProviderConfiguration(BaseModel):
|
||||
model_properties=m.model_properties,
|
||||
deprecated=m.deprecated,
|
||||
provider=SimpleModelProviderEntity(self.provider),
|
||||
status=ModelStatus.ACTIVE
|
||||
status=status
|
||||
)
|
||||
for m in provider_instance.models(model_type)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
if self.provider.provider not in original_provider_configurate_methods:
|
||||
original_provider_configurate_methods[self.provider.provider] = []
|
||||
@@ -586,7 +778,8 @@ class ProviderConfiguration(BaseModel):
|
||||
break
|
||||
|
||||
if should_use_custom_model:
|
||||
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
||||
if original_provider_configurate_methods[self.provider.provider] == [
|
||||
ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
||||
# only customizable model
|
||||
for restrict_model in restrict_models:
|
||||
copy_credentials = self.system_configuration.credentials.copy()
|
||||
@@ -611,6 +804,13 @@ class ProviderConfiguration(BaseModel):
|
||||
if custom_model_schema.model_type not in model_types:
|
||||
continue
|
||||
|
||||
status = ModelStatus.ACTIVE
|
||||
if (custom_model_schema.model_type in model_setting_map
|
||||
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
|
||||
model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
|
||||
if model_setting.enabled is False:
|
||||
status = ModelStatus.DISABLED
|
||||
|
||||
provider_models.append(
|
||||
ModelWithProviderEntity(
|
||||
model=custom_model_schema.model,
|
||||
@@ -621,7 +821,7 @@ class ProviderConfiguration(BaseModel):
|
||||
model_properties=custom_model_schema.model_properties,
|
||||
deprecated=custom_model_schema.deprecated,
|
||||
provider=SimpleModelProviderEntity(self.provider),
|
||||
status=ModelStatus.ACTIVE
|
||||
status=status
|
||||
)
|
||||
)
|
||||
|
||||
@@ -632,16 +832,20 @@ class ProviderConfiguration(BaseModel):
|
||||
m.status = ModelStatus.NO_PERMISSION
|
||||
elif not quota_configuration.is_valid:
|
||||
m.status = ModelStatus.QUOTA_EXCEEDED
|
||||
|
||||
return provider_models
|
||||
|
||||
def _get_custom_provider_models(self,
|
||||
model_types: list[ModelType],
|
||||
provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
|
||||
provider_instance: ModelProvider,
|
||||
model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
|
||||
-> list[ModelWithProviderEntity]:
|
||||
"""
|
||||
Get custom provider models.
|
||||
|
||||
:param model_types: model types
|
||||
:param provider_instance: provider instance
|
||||
:param model_setting_map: model setting map
|
||||
:return:
|
||||
"""
|
||||
provider_models = []
|
||||
@@ -656,6 +860,16 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
models = provider_instance.models(model_type)
|
||||
for m in models:
|
||||
status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
|
||||
load_balancing_enabled = False
|
||||
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
|
||||
model_setting = model_setting_map[m.model_type][m.model]
|
||||
if model_setting.enabled is False:
|
||||
status = ModelStatus.DISABLED
|
||||
|
||||
if len(model_setting.load_balancing_configs) > 1:
|
||||
load_balancing_enabled = True
|
||||
|
||||
provider_models.append(
|
||||
ModelWithProviderEntity(
|
||||
model=m.model,
|
||||
@@ -666,7 +880,8 @@ class ProviderConfiguration(BaseModel):
|
||||
model_properties=m.model_properties,
|
||||
deprecated=m.deprecated,
|
||||
provider=SimpleModelProviderEntity(self.provider),
|
||||
status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
|
||||
status=status,
|
||||
load_balancing_enabled=load_balancing_enabled
|
||||
)
|
||||
)
|
||||
|
||||
@@ -690,6 +905,17 @@ class ProviderConfiguration(BaseModel):
|
||||
if not custom_model_schema:
|
||||
continue
|
||||
|
||||
status = ModelStatus.ACTIVE
|
||||
load_balancing_enabled = False
|
||||
if (custom_model_schema.model_type in model_setting_map
|
||||
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
|
||||
model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
|
||||
if model_setting.enabled is False:
|
||||
status = ModelStatus.DISABLED
|
||||
|
||||
if len(model_setting.load_balancing_configs) > 1:
|
||||
load_balancing_enabled = True
|
||||
|
||||
provider_models.append(
|
||||
ModelWithProviderEntity(
|
||||
model=custom_model_schema.model,
|
||||
@@ -700,7 +926,8 @@ class ProviderConfiguration(BaseModel):
|
||||
model_properties=custom_model_schema.model_properties,
|
||||
deprecated=custom_model_schema.deprecated,
|
||||
provider=SimpleModelProviderEntity(self.provider),
|
||||
status=ModelStatus.ACTIVE
|
||||
status=status,
|
||||
load_balancing_enabled=load_balancing_enabled
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -72,3 +72,22 @@ class CustomConfiguration(BaseModel):
|
||||
"""
|
||||
provider: Optional[CustomProviderConfiguration] = None
|
||||
models: list[CustomModelConfiguration] = []
|
||||
|
||||
|
||||
class ModelLoadBalancingConfiguration(BaseModel):
|
||||
"""
|
||||
Class for model load balancing configuration.
|
||||
"""
|
||||
id: str
|
||||
name: str
|
||||
credentials: dict
|
||||
|
||||
|
||||
class ModelSettings(BaseModel):
|
||||
"""
|
||||
Model class for model settings.
|
||||
"""
|
||||
model: str
|
||||
model_type: ModelType
|
||||
enabled: bool = True
|
||||
load_balancing_configs: list[ModelLoadBalancingConfiguration] = []
|
||||
|
||||
Reference in New Issue
Block a user