Fix/plugin race condition (#14253)

This commit is contained in:
Yeuoly
2025-02-25 12:20:47 +08:00
committed by GitHub
parent 42b13bd312
commit 490b6d092e
11 changed files with 116 additions and 41 deletions

View File

@@ -1,7 +1,7 @@
from enum import StrEnum
from typing import Any, Optional, Union
from pydantic import BaseModel
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
@@ -14,7 +14,7 @@ class AgentToolEntity(BaseModel):
provider_type: ToolProviderType
provider_id: str
tool_name: str
tool_parameters: dict[str, Any] = {}
tool_parameters: dict[str, Any] = Field(default_factory=dict)
plugin_unique_identifier: str | None = None

View File

@@ -2,9 +2,9 @@ from collections.abc import Mapping
from typing import Any
from core.app.app_config.entities import ModelConfigEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
@@ -61,9 +61,7 @@ class ModelConfigManager:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
if "/" not in config["model"]["provider"]:
config["model"]["provider"] = (
f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}"
)
config["model"]["provider"] = str(ModelProviderID(config["model"]["provider"]))
if config["model"]["provider"] not in model_provider_names:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")

View File

@@ -17,8 +17,8 @@ class ModelConfigEntity(BaseModel):
provider: str
model: str
mode: Optional[str] = None
parameters: dict[str, Any] = {}
stop: list[str] = []
parameters: dict[str, Any] = Field(default_factory=dict)
stop: list[str] = Field(default_factory=list)
class AdvancedChatMessageEntity(BaseModel):
@@ -132,7 +132,7 @@ class ExternalDataVariableEntity(BaseModel):
variable: str
type: str
config: dict[str, Any] = {}
config: dict[str, Any] = Field(default_factory=dict)
class DatasetRetrieveConfigEntity(BaseModel):
@@ -188,7 +188,7 @@ class SensitiveWordAvoidanceEntity(BaseModel):
"""
type: str
config: dict[str, Any] = {}
config: dict[str, Any] = Field(default_factory=dict)
class TextToSpeechEntity(BaseModel):

View File

@@ -63,9 +63,9 @@ class ModelConfigWithCredentialsEntity(BaseModel):
model_schema: AIModelEntity
mode: str
provider_model_bundle: ProviderModelBundle
credentials: dict[str, Any] = {}
parameters: dict[str, Any] = {}
stop: list[str] = []
credentials: dict[str, Any] = Field(default_factory=dict)
parameters: dict[str, Any] = Field(default_factory=dict)
stop: list[str] = Field(default_factory=list)
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@@ -94,7 +94,7 @@ class AppGenerateEntity(BaseModel):
call_depth: int = 0
# extra parameters, like: auto_generate_conversation_name
extras: dict[str, Any] = {}
extras: dict[str, Any] = Field(default_factory=dict)
# tracing instance
trace_manager: Optional[TraceQueueManager] = None

View File

@@ -6,11 +6,10 @@ from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from typing import Optional
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import or_
from constants import HIDDEN_VALUE
from core.entities import DEFAULT_PLUGIN_ID
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from core.entities.provider_entities import (
CustomConfiguration,
@@ -1004,7 +1003,7 @@ class ProviderConfigurations(BaseModel):
"""
tenant_id: str
configurations: dict[str, ProviderConfiguration] = {}
configurations: dict[str, ProviderConfiguration] = Field(default_factory=dict)
def __init__(self, tenant_id: str):
super().__init__(tenant_id=tenant_id)
@@ -1060,7 +1059,7 @@ class ProviderConfigurations(BaseModel):
def __getitem__(self, key):
if "/" not in key:
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
key = str(ModelProviderID(key))
return self.configurations[key]
@@ -1075,7 +1074,7 @@ class ProviderConfigurations(BaseModel):
def get(self, key, default=None) -> ProviderConfiguration | None:
if "/" not in key:
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
key = str(ModelProviderID(key))
return self.configurations.get(key, default) # type: ignore

View File

@@ -41,9 +41,13 @@ class HostedModerationConfig(BaseModel):
class HostingConfiguration:
provider_map: dict[str, HostingProvider] = {}
provider_map: dict[str, HostingProvider]
moderation_config: Optional[HostedModerationConfig] = None
def __init__(self) -> None:
self.provider_map = {}
self.moderation_config = None
def init_app(self, app: Flask) -> None:
if dify_config.EDITION != "CLOUD":
return

View File

@@ -7,7 +7,6 @@ from typing import Optional
from pydantic import BaseModel
import contexts
from core.entities import DEFAULT_PLUGIN_ID
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
@@ -34,9 +33,11 @@ class ModelProviderExtension(BaseModel):
class ModelProviderFactory:
provider_position_map: dict[str, int] = {}
provider_position_map: dict[str, int]
def __init__(self, tenant_id: str) -> None:
self.provider_position_map = {}
self.tenant_id = tenant_id
self.plugin_model_manager = PluginModelManager()
@@ -360,11 +361,5 @@ class ModelProviderFactory:
:param provider: provider name
:return: plugin id and provider name
"""
plugin_id = DEFAULT_PLUGIN_ID
provider_name = provider
if "/" in provider:
# get the plugin_id before provider
plugin_id = "/".join(provider.split("/")[:-1])
provider_name = provider.split("/")[-1]
return str(plugin_id), provider_name
provider_id = ModelProviderID(provider)
return provider_id.plugin_id, provider_id.provider_name