feat: backend model load balancing support (#4927)
@@ -1,6 +1,6 @@
 import os

-from core.utils.module_import_helper import import_module_from_source, load_single_subclass_from_source
+from core.helper.module_import_helper import import_module_from_source, load_single_subclass_from_source
 from tests.integration_tests.utils.parent_class import ParentClass

@@ -92,7 +92,8 @@ def test_execute_llm(setup_openai_mock):
                 provider=CustomProviderConfiguration(
                     credentials=credentials
                 )
-            )
+            ),
+            model_settings=[]
         ),
         provider_instance=provider_instance,
         model_type_instance=model_type_instance
@@ -206,10 +207,11 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
                 provider=CustomProviderConfiguration(
                     credentials=credentials
                 )
-            )
+            ),
+            model_settings=[]
         ),
         provider_instance=provider_instance,
-        model_type_instance=model_type_instance
+        model_type_instance=model_type_instance,
     )

     model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo')
@@ -42,7 +42,8 @@ def get_mocked_fetch_model_config(
                 provider=CustomProviderConfiguration(
                     credentials=credentials
                 )
-            )
+            ),
+            model_settings=[]
         ),
         provider_instance=provider_instance,
         model_type_instance=model_type_instance
@@ -1,9 +1,10 @@
 from unittest.mock import MagicMock

 from core.app.app_config.entities import ModelConfigEntity
-from core.entities.provider_configuration import ProviderModelBundle
+from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
 from core.model_runtime.entities.message_entities import UserPromptMessage
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule
+from core.model_runtime.entities.provider_entities import ProviderEntity
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.prompt_transform import PromptTransform

@@ -22,8 +23,16 @@ def test__calculate_rest_token():
     large_language_model_mock = MagicMock(spec=LargeLanguageModel)
     large_language_model_mock.get_num_tokens.return_value = 6

+    provider_mock = MagicMock(spec=ProviderEntity)
+    provider_mock.provider = 'openai'
+
+    provider_configuration_mock = MagicMock(spec=ProviderConfiguration)
+    provider_configuration_mock.provider = provider_mock
+    provider_configuration_mock.model_settings = None
+
     provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
     provider_model_bundle_mock.model_type_instance = large_language_model_mock
+    provider_model_bundle_mock.configuration = provider_configuration_mock

     model_config_mock = MagicMock(spec=ModelConfigEntity)
     model_config_mock.model = 'gpt-4'
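
A note on the mocking pattern used above and throughout the new test files: MagicMock(spec=SomeClass) yields a mock that rejects attribute reads not defined on SomeClass, so fixtures like provider_configuration_mock stay in sync with the real entities as they evolve. A self-contained illustration with a toy class (not a Dify entity):

from unittest.mock import MagicMock


class Entity:  # toy stand-in for a Dify entity such as ProviderEntity
    provider = 'openai'


mock = MagicMock(spec=Entity)
mock.provider = 'azure'          # allowed: 'provider' exists on the spec
assert mock.provider == 'azure'

try:
    mock.no_such_attribute       # rejected: not part of the spec
    raise AssertionError('spec should have rejected this attribute')
except AttributeError:
    pass
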
api/tests/unit_tests/core/test_model_manager.py (new file, +77 lines)
@@ -0,0 +1,77 @@
+from unittest.mock import MagicMock
+
+import pytest
+
+from core.entities.provider_entities import ModelLoadBalancingConfiguration
+from core.model_manager import LBModelManager
+from core.model_runtime.entities.model_entities import ModelType
+
+
+@pytest.fixture
+def lb_model_manager():
+    load_balancing_configs = [
+        ModelLoadBalancingConfiguration(
+            id='id1',
+            name='__inherit__',
+            credentials={}
+        ),
+        ModelLoadBalancingConfiguration(
+            id='id2',
+            name='first',
+            credentials={"openai_api_key": "fake_key"}
+        ),
+        ModelLoadBalancingConfiguration(
+            id='id3',
+            name='second',
+            credentials={"openai_api_key": "fake_key"}
+        )
+    ]
+
+    lb_model_manager = LBModelManager(
+        tenant_id='tenant_id',
+        provider='openai',
+        model_type=ModelType.LLM,
+        model='gpt-4',
+        load_balancing_configs=load_balancing_configs,
+        managed_credentials={"openai_api_key": "fake_key"}
+    )
+
+    lb_model_manager.cooldown = MagicMock(return_value=None)
+
+    def is_cooldown(config: ModelLoadBalancingConfiguration):
+        if config.id == 'id1':
+            return True
+        return False
+
+    lb_model_manager.in_cooldown = MagicMock(side_effect=is_cooldown)
+
+    return lb_model_manager
+
+
+def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
+    assert len(lb_model_manager._load_balancing_configs) == 3
+
+    config1 = lb_model_manager._load_balancing_configs[0]
+    config2 = lb_model_manager._load_balancing_configs[1]
+    config3 = lb_model_manager._load_balancing_configs[2]
+
+    assert lb_model_manager.in_cooldown(config1) is True
+    assert lb_model_manager.in_cooldown(config2) is False
+    assert lb_model_manager.in_cooldown(config3) is False
+
+    start_index = 0
+
+    def incr(key):
+        nonlocal start_index
+        start_index += 1
+        return start_index
+
+    mocker.patch('redis.Redis.incr', side_effect=incr)
+    mocker.patch('redis.Redis.set', return_value=None)
+    mocker.patch('redis.Redis.expire', return_value=None)
+
+    config = lb_model_manager.fetch_next()
+    assert config == config2
+
+    config = lb_model_manager.fetch_next()
+    assert config == config3
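
The expectations in test_lb_model_manager_fetch_next imply a round-robin strategy: an index derived from a shared counter (Redis INCR in the patched calls) selects the next config modulo the config count, and configs in cooldown are skipped, which is why id2 and then id3 come back while id1 never does. Below is a minimal self-contained sketch of that selection logic; the class and method names are illustrative stand-ins, not Dify's actual LBModelManager:

from dataclasses import dataclass, field


@dataclass
class LBConfig:  # hypothetical stand-in for ModelLoadBalancingConfiguration
    id: str
    name: str
    credentials: dict = field(default_factory=dict)


class RoundRobinPicker:
    """Illustrative round-robin picker; the counter plays the role of a Redis INCR key."""

    def __init__(self, configs, cooled_down_ids):
        self._configs = configs
        self._cooled_down_ids = cooled_down_ids
        self._counter = 0

    def in_cooldown(self, config):
        return config.id in self._cooled_down_ids

    def fetch_next(self):
        # Try at most len(configs) slots so we never loop forever.
        for _ in range(len(self._configs)):
            self._counter += 1
            config = self._configs[self._counter % len(self._configs)]
            if not self.in_cooldown(config):
                return config
        return None  # every config is cooling down


configs = [LBConfig('id1', '__inherit__'), LBConfig('id2', 'first'), LBConfig('id3', 'second')]
picker = RoundRobinPicker(configs, cooled_down_ids={'id1'})
assert picker.fetch_next().id == 'id2'  # counter=1 -> index 1
assert picker.fetch_next().id == 'id3'  # counter=2 -> index 2
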
api/tests/unit_tests/core/test_provider_manager.py (new file, +183 lines)
@@ -0,0 +1,183 @@
+from core.entities.provider_entities import ModelSettings
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.model_providers import model_provider_factory
+from core.provider_manager import ProviderManager
+from models.provider import LoadBalancingModelConfig, ProviderModelSetting
+
+
+def test__to_model_settings(mocker):
+    # Get all provider entities
+    provider_entities = model_provider_factory.get_providers()
+
+    provider_entity = None
+    for provider in provider_entities:
+        if provider.provider == 'openai':
+            provider_entity = provider
+
+    # Mocking the inputs
+    provider_model_settings = [ProviderModelSetting(
+        id='id',
+        tenant_id='tenant_id',
+        provider_name='openai',
+        model_name='gpt-4',
+        model_type='text-generation',
+        enabled=True,
+        load_balancing_enabled=True
+    )]
+    load_balancing_model_configs = [
+        LoadBalancingModelConfig(
+            id='id1',
+            tenant_id='tenant_id',
+            provider_name='openai',
+            model_name='gpt-4',
+            model_type='text-generation',
+            name='__inherit__',
+            encrypted_config=None,
+            enabled=True
+        ),
+        LoadBalancingModelConfig(
+            id='id2',
+            tenant_id='tenant_id',
+            provider_name='openai',
+            model_name='gpt-4',
+            model_type='text-generation',
+            name='first',
+            encrypted_config='{"openai_api_key": "fake_key"}',
+            enabled=True
+        )
+    ]
+
+    mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"})
+
+    provider_manager = ProviderManager()
+
+    # Running the method
+    result = provider_manager._to_model_settings(
+        provider_entity,
+        provider_model_settings,
+        load_balancing_model_configs
+    )
+
+    # Asserting that the result is as expected
+    assert len(result) == 1
+    assert isinstance(result[0], ModelSettings)
+    assert result[0].model == 'gpt-4'
+    assert result[0].model_type == ModelType.LLM
+    assert result[0].enabled is True
+    assert len(result[0].load_balancing_configs) == 2
+    assert result[0].load_balancing_configs[0].name == '__inherit__'
+    assert result[0].load_balancing_configs[1].name == 'first'
+
+
+def test__to_model_settings_only_one_lb(mocker):
+    # Get all provider entities
+    provider_entities = model_provider_factory.get_providers()
+
+    provider_entity = None
+    for provider in provider_entities:
+        if provider.provider == 'openai':
+            provider_entity = provider
+
+    # Mocking the inputs
+    provider_model_settings = [ProviderModelSetting(
+        id='id',
+        tenant_id='tenant_id',
+        provider_name='openai',
+        model_name='gpt-4',
+        model_type='text-generation',
+        enabled=True,
+        load_balancing_enabled=True
+    )]
+    load_balancing_model_configs = [
+        LoadBalancingModelConfig(
+            id='id1',
+            tenant_id='tenant_id',
+            provider_name='openai',
+            model_name='gpt-4',
+            model_type='text-generation',
+            name='__inherit__',
+            encrypted_config=None,
+            enabled=True
+        )
+    ]
+
+    mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"})
+
+    provider_manager = ProviderManager()
+
+    # Running the method
+    result = provider_manager._to_model_settings(
+        provider_entity,
+        provider_model_settings,
+        load_balancing_model_configs
+    )
+
+    # Asserting that the result is as expected
+    assert len(result) == 1
+    assert isinstance(result[0], ModelSettings)
+    assert result[0].model == 'gpt-4'
+    assert result[0].model_type == ModelType.LLM
+    assert result[0].enabled is True
+    assert len(result[0].load_balancing_configs) == 0
+
+
+def test__to_model_settings_lb_disabled(mocker):
+    # Get all provider entities
+    provider_entities = model_provider_factory.get_providers()
+
+    provider_entity = None
+    for provider in provider_entities:
+        if provider.provider == 'openai':
+            provider_entity = provider
+
+    # Mocking the inputs
+    provider_model_settings = [ProviderModelSetting(
+        id='id',
+        tenant_id='tenant_id',
+        provider_name='openai',
+        model_name='gpt-4',
+        model_type='text-generation',
+        enabled=True,
+        load_balancing_enabled=False
+    )]
+    load_balancing_model_configs = [
+        LoadBalancingModelConfig(
+            id='id1',
+            tenant_id='tenant_id',
+            provider_name='openai',
+            model_name='gpt-4',
+            model_type='text-generation',
+            name='__inherit__',
+            encrypted_config=None,
+            enabled=True
+        ),
+        LoadBalancingModelConfig(
+            id='id2',
+            tenant_id='tenant_id',
+            provider_name='openai',
+            model_name='gpt-4',
+            model_type='text-generation',
+            name='first',
+            encrypted_config='{"openai_api_key": "fake_key"}',
+            enabled=True
+        )
+    ]
+
+    mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"})
+
+    provider_manager = ProviderManager()
+
+    # Running the method
+    result = provider_manager._to_model_settings(
+        provider_entity,
+        provider_model_settings,
+        load_balancing_model_configs
+    )
+
+    # Asserting that the result is as expected
+    assert len(result) == 1
+    assert isinstance(result[0], ModelSettings)
+    assert result[0].model == 'gpt-4'
+    assert result[0].model_type == ModelType.LLM
+    assert result[0].enabled is True
+    assert len(result[0].load_balancing_configs) == 0
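
Read together, the three tests pin down when load balancing actually engages: _to_model_settings attaches the configs only when the model setting has load_balancing_enabled=True and more than one config is present; a lone '__inherit__' entry or a disabled flag yields an empty list. A hypothetical sketch of that decision rule (illustrative names, not the real ProviderManager code):

from dataclasses import dataclass


@dataclass
class LBConfig:  # hypothetical stand-in for LoadBalancingModelConfig
    name: str
    enabled: bool = True


def pick_load_balancing_configs(load_balancing_enabled: bool, configs: list) -> list:
    """Return the configs to balance across, or [] when balancing is moot."""
    if not load_balancing_enabled:
        return []          # mirrors test__to_model_settings_lb_disabled
    enabled = [c for c in configs if c.enabled]
    if len(enabled) <= 1:
        return []          # mirrors test__to_model_settings_only_one_lb
    return enabled         # mirrors test__to_model_settings


assert len(pick_load_balancing_configs(True, [LBConfig('__inherit__'), LBConfig('first')])) == 2
assert pick_load_balancing_configs(True, [LBConfig('__inherit__')]) == []
assert pick_load_balancing_configs(False, [LBConfig('__inherit__'), LBConfig('first')]) == []
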
@@ -2,7 +2,7 @@ from textwrap import dedent

 import pytest

-from core.utils.position_helper import get_position_map
+from core.helper.position_helper import get_position_map


 @pytest.fixture