chore(api/services): apply ruff reformatting (#7599)

Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
Bowen Liang
2024-08-26 13:43:57 +08:00
committed by GitHub
parent 979422cdc6
commit 17fd773a30
49 changed files with 2630 additions and 2655 deletions

View File

@@ -23,7 +23,6 @@ logger = logging.getLogger(__name__)
class ModelLoadBalancingService:
def __init__(self) -> None:
self.provider_manager = ProviderManager()
@@ -46,10 +45,7 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# Enable model load balancing
provider_configuration.enable_model_load_balancing(
model=model,
model_type=ModelType.value_of(model_type)
)
provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
"""
@@ -70,13 +66,11 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# disable model load balancing
provider_configuration.disable_model_load_balancing(
model=model,
model_type=ModelType.value_of(model_type)
)
provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, model_type: str) \
-> tuple[bool, list[dict]]:
def get_load_balancing_configs(
self, tenant_id: str, provider: str, model: str, model_type: str
) -> tuple[bool, list[dict]]:
"""
Get load balancing configurations.
:param tenant_id: workspace id
@@ -107,20 +101,24 @@ class ModelLoadBalancingService:
is_load_balancing_enabled = True
# Get load balancing configurations
load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
load_balancing_configs = (
db.session.query(LoadBalancingModelConfig)
.filter(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model
).order_by(LoadBalancingModelConfig.created_at).all()
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
.order_by(LoadBalancingModelConfig.created_at)
.all()
)
if provider_configuration.custom_configuration.provider:
# check if the inherit configuration exists,
# inherit is represented for the provider or model custom credentials
inherit_config_exists = False
for load_balancing_config in load_balancing_configs:
if load_balancing_config.name == '__inherit__':
if load_balancing_config.name == "__inherit__":
inherit_config_exists = True
break
@@ -133,7 +131,7 @@ class ModelLoadBalancingService:
else:
# move the inherit configuration to the first
for i, load_balancing_config in enumerate(load_balancing_configs[:]):
if load_balancing_config.name == '__inherit__':
if load_balancing_config.name == "__inherit__":
inherit_config = load_balancing_configs.pop(i)
load_balancing_configs.insert(0, inherit_config)
@@ -151,7 +149,7 @@ class ModelLoadBalancingService:
provider=provider,
model=model,
model_type=model_type,
config_id=load_balancing_config.id
config_id=load_balancing_config.id,
)
try:
@@ -172,32 +170,32 @@ class ModelLoadBalancingService:
if variable in credentials:
try:
credentials[variable] = encrypter.decrypt_token_with_decoding(
credentials.get(variable),
decoding_rsa_key,
decoding_cipher_rsa
credentials.get(variable), decoding_rsa_key, decoding_cipher_rsa
)
except ValueError:
pass
# Obfuscate credentials
credentials = provider_configuration.obfuscated_credentials(
credentials=credentials,
credential_form_schemas=credential_schemas.credential_form_schemas
credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas
)
datas.append({
'id': load_balancing_config.id,
'name': load_balancing_config.name,
'credentials': credentials,
'enabled': load_balancing_config.enabled,
'in_cooldown': in_cooldown,
'ttl': ttl
})
datas.append(
{
"id": load_balancing_config.id,
"name": load_balancing_config.name,
"credentials": credentials,
"enabled": load_balancing_config.enabled,
"in_cooldown": in_cooldown,
"ttl": ttl,
}
)
return is_load_balancing_enabled, datas
def get_load_balancing_config(self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str) \
-> Optional[dict]:
def get_load_balancing_config(
self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str
) -> Optional[dict]:
"""
Get load balancing configuration.
:param tenant_id: workspace id
@@ -219,14 +217,17 @@ class ModelLoadBalancingService:
model_type = ModelType.value_of(model_type)
# Get load balancing configurations
load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \
load_balancing_model_config = (
db.session.query(LoadBalancingModelConfig)
.filter(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id
).first()
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id,
)
.first()
)
if not load_balancing_model_config:
return None
@@ -244,19 +245,19 @@ class ModelLoadBalancingService:
# Obfuscate credentials
credentials = provider_configuration.obfuscated_credentials(
credentials=credentials,
credential_form_schemas=credential_schemas.credential_form_schemas
credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas
)
return {
'id': load_balancing_model_config.id,
'name': load_balancing_model_config.name,
'credentials': credentials,
'enabled': load_balancing_model_config.enabled
"id": load_balancing_model_config.id,
"name": load_balancing_model_config.name,
"credentials": credentials,
"enabled": load_balancing_model_config.enabled,
}
def _init_inherit_config(self, tenant_id: str, provider: str, model: str, model_type: ModelType) \
-> LoadBalancingModelConfig:
def _init_inherit_config(
self, tenant_id: str, provider: str, model: str, model_type: ModelType
) -> LoadBalancingModelConfig:
"""
Initialize the inherit configuration.
:param tenant_id: workspace id
@@ -271,18 +272,16 @@ class ModelLoadBalancingService:
provider_name=provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
name='__inherit__'
name="__inherit__",
)
db.session.add(inherit_config)
db.session.commit()
return inherit_config
def update_load_balancing_configs(self, tenant_id: str,
provider: str,
model: str,
model_type: str,
configs: list[dict]) -> None:
def update_load_balancing_configs(
self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict]
) -> None:
"""
Update load balancing configurations.
:param tenant_id: workspace id
@@ -304,15 +303,18 @@ class ModelLoadBalancingService:
model_type = ModelType.value_of(model_type)
if not isinstance(configs, list):
raise ValueError('Invalid load balancing configs')
raise ValueError("Invalid load balancing configs")
current_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
current_load_balancing_configs = (
db.session.query(LoadBalancingModelConfig)
.filter(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model
).all()
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
.all()
)
# id as key, config as value
current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}
@@ -320,25 +322,25 @@ class ModelLoadBalancingService:
for config in configs:
if not isinstance(config, dict):
raise ValueError('Invalid load balancing config')
raise ValueError("Invalid load balancing config")
config_id = config.get('id')
name = config.get('name')
credentials = config.get('credentials')
enabled = config.get('enabled')
config_id = config.get("id")
name = config.get("name")
credentials = config.get("credentials")
enabled = config.get("enabled")
if not name:
raise ValueError('Invalid load balancing config name')
raise ValueError("Invalid load balancing config name")
if enabled is None:
raise ValueError('Invalid load balancing config enabled')
raise ValueError("Invalid load balancing config enabled")
# is config exists
if config_id:
config_id = str(config_id)
if config_id not in current_load_balancing_configs_dict:
raise ValueError('Invalid load balancing config id: {}'.format(config_id))
raise ValueError("Invalid load balancing config id: {}".format(config_id))
updated_config_ids.add(config_id)
@@ -347,11 +349,11 @@ class ModelLoadBalancingService:
# check duplicate name
for current_load_balancing_config in current_load_balancing_configs:
if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name:
raise ValueError('Load balancing config name {} already exists'.format(name))
raise ValueError("Load balancing config name {} already exists".format(name))
if credentials:
if not isinstance(credentials, dict):
raise ValueError('Invalid load balancing config credentials')
raise ValueError("Invalid load balancing config credentials")
# validate custom provider config
credentials = self._custom_credentials_validate(
@@ -361,7 +363,7 @@ class ModelLoadBalancingService:
model=model,
credentials=credentials,
load_balancing_model_config=load_balancing_config,
validate=False
validate=False,
)
# update load balancing config
@@ -375,19 +377,19 @@ class ModelLoadBalancingService:
self._clear_credentials_cache(tenant_id, config_id)
else:
# create load balancing config
if name == '__inherit__':
raise ValueError('Invalid load balancing config name')
if name == "__inherit__":
raise ValueError("Invalid load balancing config name")
# check duplicate name
for current_load_balancing_config in current_load_balancing_configs:
if current_load_balancing_config.name == name:
raise ValueError('Load balancing config name {} already exists'.format(name))
raise ValueError("Load balancing config name {} already exists".format(name))
if not credentials:
raise ValueError('Invalid load balancing config credentials')
raise ValueError("Invalid load balancing config credentials")
if not isinstance(credentials, dict):
raise ValueError('Invalid load balancing config credentials')
raise ValueError("Invalid load balancing config credentials")
# validate custom provider config
credentials = self._custom_credentials_validate(
@@ -396,7 +398,7 @@ class ModelLoadBalancingService:
model_type=model_type,
model=model,
credentials=credentials,
validate=False
validate=False,
)
# create load balancing config
@@ -406,7 +408,7 @@ class ModelLoadBalancingService:
model_type=model_type.to_origin_model_type(),
model_name=model,
name=name,
encrypted_config=json.dumps(credentials)
encrypted_config=json.dumps(credentials),
)
db.session.add(load_balancing_model_config)
@@ -420,12 +422,15 @@ class ModelLoadBalancingService:
self._clear_credentials_cache(tenant_id, config_id)
def validate_load_balancing_credentials(self, tenant_id: str,
provider: str,
model: str,
model_type: str,
credentials: dict,
config_id: Optional[str] = None) -> None:
def validate_load_balancing_credentials(
self,
tenant_id: str,
provider: str,
model: str,
model_type: str,
credentials: dict,
config_id: Optional[str] = None,
) -> None:
"""
Validate load balancing credentials.
:param tenant_id: workspace id
@@ -450,14 +455,17 @@ class ModelLoadBalancingService:
load_balancing_model_config = None
if config_id:
# Get load balancing config
load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \
load_balancing_model_config = (
db.session.query(LoadBalancingModelConfig)
.filter(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id
).first()
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id,
)
.first()
)
if not load_balancing_model_config:
raise ValueError(f"Load balancing config {config_id} does not exist.")
@@ -469,16 +477,19 @@ class ModelLoadBalancingService:
model_type=model_type,
model=model,
credentials=credentials,
load_balancing_model_config=load_balancing_model_config
load_balancing_model_config=load_balancing_model_config,
)
def _custom_credentials_validate(self, tenant_id: str,
provider_configuration: ProviderConfiguration,
model_type: ModelType,
model: str,
credentials: dict,
load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
validate: bool = True) -> dict:
def _custom_credentials_validate(
self,
tenant_id: str,
provider_configuration: ProviderConfiguration,
model_type: ModelType,
model: str,
credentials: dict,
load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
validate: bool = True,
) -> dict:
"""
Validate custom credentials.
:param tenant_id: workspace id
@@ -521,12 +532,11 @@ class ModelLoadBalancingService:
provider=provider_configuration.provider.provider,
model_type=model_type,
model=model,
credentials=credentials
credentials=credentials,
)
else:
credentials = model_provider_factory.provider_credentials_validate(
provider=provider_configuration.provider.provider,
credentials=credentials
provider=provider_configuration.provider.provider, credentials=credentials
)
for key, value in credentials.items():
@@ -535,8 +545,9 @@ class ModelLoadBalancingService:
return credentials
def _get_credential_schema(self, provider_configuration: ProviderConfiguration) \
-> ModelCredentialSchema | ProviderCredentialSchema:
def _get_credential_schema(
self, provider_configuration: ProviderConfiguration
) -> ModelCredentialSchema | ProviderCredentialSchema:
"""
Get form schemas.
:param provider_configuration: provider configuration
@@ -558,9 +569,7 @@ class ModelLoadBalancingService:
:return:
"""
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id,
identity_id=config_id,
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
tenant_id=tenant_id, identity_id=config_id, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
)
provider_model_credentials_cache.delete()