feat: Add caching mechanism for plugin model schemas (#14898)

This commit is contained in:
Yeuoly
2025-03-04 18:02:06 +08:00
committed by GitHub
parent 330dc2fd44
commit 4668c4996a
3 changed files with 70 additions and 19 deletions

View File

@@ -1,8 +1,11 @@
import decimal
import hashlib
from threading import Lock
from typing import Optional
from pydantic import BaseModel, ConfigDict, Field
import contexts
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from core.model_runtime.entities.model_entities import (
@@ -139,15 +142,35 @@ class AIModel(BaseModel):
:return: model schema
"""
plugin_model_manager = PluginModelManager()
return plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model_type=self.model_type.value,
model=model,
credentials=credentials or {},
)
cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
# sort credentials
sorted_credentials = sorted(credentials.items()) if credentials else []
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
try:
contexts.plugin_model_schemas.get()
except LookupError:
contexts.plugin_model_schemas.set({})
contexts.plugin_model_schema_lock.set(Lock())
with contexts.plugin_model_schema_lock.get():
if cache_key in contexts.plugin_model_schemas.get():
return contexts.plugin_model_schemas.get()[cache_key]
schema = plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model_type=self.model_type.value,
model=model,
credentials=credentials or {},
)
if schema:
contexts.plugin_model_schemas.get()[cache_key] = schema
return schema
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""

View File

@@ -1,3 +1,4 @@
import hashlib
import logging
import os
from collections.abc import Sequence
@@ -206,17 +207,35 @@ class ModelProviderFactory:
Get model schema
"""
plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
model_schema = self.plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
)
cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}"
# sort credentials
sorted_credentials = sorted(credentials.items()) if credentials else []
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
return model_schema
try:
contexts.plugin_model_schemas.get()
except LookupError:
contexts.plugin_model_schemas.set({})
contexts.plugin_model_schema_lock.set(Lock())
with contexts.plugin_model_schema_lock.get():
if cache_key in contexts.plugin_model_schemas.get():
return contexts.plugin_model_schemas.get()[cache_key]
schema = self.plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials or {},
)
if schema:
contexts.plugin_model_schemas.get()[cache_key] = schema
return schema
def get_models(
self,