Introduce Plugins (#13836)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Signed-off-by: -LAN- <laipz8200@outlook.com>
Signed-off-by: xhe <xw897002528@gmail.com>
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: takatost <takatost@gmail.com>
Co-authored-by: kurokobo <kuro664@gmail.com>
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: AkaraChen <akarachen@outlook.com>
Co-authored-by: Yi <yxiaoisme@gmail.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: Hiroshi Fujita <fujita-h@users.noreply.github.com>
Co-authored-by: AkaraChen <85140972+AkaraChen@users.noreply.github.com>
Co-authored-by: NFish <douxc512@gmail.com>
Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com>
Co-authored-by: 非法操作 <hjlarry@163.com>
Co-authored-by: Novice <857526207@qq.com>
Co-authored-by: Hiroki Nagai <82458324+nagaihiroki-git@users.noreply.github.com>
Co-authored-by: Gen Sato <52241300+halogen22@users.noreply.github.com>
Co-authored-by: eux <euxuuu@gmail.com>
Co-authored-by: huangzhuo1949 <167434202+huangzhuo1949@users.noreply.github.com>
Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com>
Co-authored-by: lotsik <lotsik@mail.ru>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: nite-knite <nkCoding@gmail.com>
Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: gakkiyomi <gakkiyomi@aliyun.com>
Co-authored-by: CN-P5 <heibai2006@gmail.com>
Co-authored-by: CN-P5 <heibai2006@qq.com>
Co-authored-by: Chuehnone <1897025+chuehnone@users.noreply.github.com>
Co-authored-by: yihong <zouzou0208@gmail.com>
Co-authored-by: Kevin9703 <51311316+Kevin9703@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Boris Feld <lothiraldan@gmail.com>
Co-authored-by: mbo <himabo@gmail.com>
Co-authored-by: mabo <mabo@aeyes.ai>
Co-authored-by: Warren Chen <warren.chen830@gmail.com>
Co-authored-by: JzoNgKVO <27049666+JzoNgKVO@users.noreply.github.com>
Co-authored-by: jiandanfeng <chenjh3@wangsu.com>
Co-authored-by: zhu-an <70234959+xhdd123321@users.noreply.github.com>
Co-authored-by: zhaoqingyu.1075 <zhaoqingyu.1075@bytedance.com>
Co-authored-by: 海狸大師 <86974027+yenslife@users.noreply.github.com>
Co-authored-by: Xu Song <xusong.vip@gmail.com>
Co-authored-by: rayshaw001 <396301947@163.com>
Co-authored-by: Ding Jiatong <dingjiatong@gmail.com>
Co-authored-by: Bowen Liang <liangbowen@gf.com.cn>
Co-authored-by: JasonVV <jasonwangiii@outlook.com>
Co-authored-by: le0zh <newlight@qq.com>
Co-authored-by: zhuxinliang <zhuxinliang@didiglobal.com>
Co-authored-by: k-zaku <zaku99@outlook.jp>
Co-authored-by: luckylhb90 <luckylhb90@gmail.com>
Co-authored-by: hobo.l <hobo.l@binance.com>
Co-authored-by: jiangbo721 <365065261@qq.com>
Co-authored-by: 刘江波 <jiangbo721@163.com>
Co-authored-by: Shun Miyazawa <34241526+miya@users.noreply.github.com>
Co-authored-by: EricPan <30651140+Egfly@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: sino <sino2322@gmail.com>
Co-authored-by: Jhvcc <37662342+Jhvcc@users.noreply.github.com>
Co-authored-by: lowell <lowell.hu@zkteco.in>
Co-authored-by: Boris Polonsky <BorisPolonsky@users.noreply.github.com>
Co-authored-by: Ademílson Tonato <ademilsonft@outlook.com>
Co-authored-by: Ademílson Tonato <ademilson.tonato@refurbed.com>
Co-authored-by: IWAI, Masaharu <iwaim.sub@gmail.com>
Co-authored-by: Yueh-Po Peng (Yabi) <94939112+y10ab1@users.noreply.github.com>
Co-authored-by: Jason <ggbbddjm@gmail.com>
Co-authored-by: Xin Zhang <sjhpzx@gmail.com>
Co-authored-by: yjc980121 <3898524+yjc980121@users.noreply.github.com>
Co-authored-by: heyszt <36215648+hieheihei@users.noreply.github.com>
Co-authored-by: Abdullah AlOsaimi <osaimiacc@gmail.com>
Co-authored-by: Abdullah AlOsaimi <189027247+osaimi@users.noreply.github.com>
Co-authored-by: Yingchun Lai <laiyingchun@apache.org>
Co-authored-by: Hash Brown <hi@xzd.me>
Co-authored-by: zuodongxu <192560071+zuodongxu@users.noreply.github.com>
Co-authored-by: Masashi Tomooka <tmokmss@users.noreply.github.com>
Co-authored-by: aplio <ryo.091219@gmail.com>
Co-authored-by: Obada Khalili <54270856+obadakhalili@users.noreply.github.com>
Co-authored-by: Nam Vu <zuzoovn@gmail.com>
Co-authored-by: Kei YAMAZAKI <1715090+kei-yamazaki@users.noreply.github.com>
Co-authored-by: TechnoHouse <13776377+deephbz@users.noreply.github.com>
Co-authored-by: Riddhimaan-Senapati <114703025+Riddhimaan-Senapati@users.noreply.github.com>
Co-authored-by: MaFee921 <31881301+2284730142@users.noreply.github.com>
Co-authored-by: te-chan <t-nakanome@sakura-is.co.jp>
Co-authored-by: HQidea <HQidea@users.noreply.github.com>
Co-authored-by: Joshbly <36315710+Joshbly@users.noreply.github.com>
Co-authored-by: xhe <xw897002528@gmail.com>
Co-authored-by: weiwenyan-dev <154779315+weiwenyan-dev@users.noreply.github.com>
Co-authored-by: ex_wenyan.wei <ex_wenyan.wei@tcl.com>
Co-authored-by: engchina <12236799+engchina@users.noreply.github.com>
Co-authored-by: engchina <atjapan2015@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: 呆萌闷油瓶 <253605712@qq.com>
Co-authored-by: Kemal <kemalmeler@outlook.com>
Co-authored-by: Lazy_Frog <4590648+lazyFrogLOL@users.noreply.github.com>
Co-authored-by: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com>
Co-authored-by: Steven sun <98230804+Tuyohai@users.noreply.github.com>
Co-authored-by: steven <sunzwj@digitalchina.com>
Co-authored-by: Kalo Chin <91766386+fdb02983rhy@users.noreply.github.com>
Co-authored-by: Katy Tao <34019945+KatyTao@users.noreply.github.com>
Co-authored-by: depy <42985524+h4ckdepy@users.noreply.github.com>
Co-authored-by: 胡春东 <gycm520@gmail.com>
Co-authored-by: Junjie.M <118170653@qq.com>
Co-authored-by: MuYu <mr.muzea@gmail.com>
Co-authored-by: Naoki Takashima <39912547+takatea@users.noreply.github.com>
Co-authored-by: Summer-Gu <37869445+gubinjie@users.noreply.github.com>
Co-authored-by: Fei He <droxer.he@gmail.com>
Co-authored-by: ybalbert001 <120714773+ybalbert001@users.noreply.github.com>
Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
Co-authored-by: douxc <7553076+douxc@users.noreply.github.com>
Co-authored-by: liuzhenghua <1090179900@qq.com>
Co-authored-by: Wu Jiayang <62842862+Wu-Jiayang@users.noreply.github.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: kimjion <45935338+kimjion@users.noreply.github.com>
Co-authored-by: AugNSo <song.tiankai@icloud.com>
Co-authored-by: llinvokerl <38915183+llinvokerl@users.noreply.github.com>
Co-authored-by: liusurong.lsr <liusurong.lsr@alibaba-inc.com>
Co-authored-by: Vasu Negi <vasu-negi@users.noreply.github.com>
Co-authored-by: Hundredwz <1808096180@qq.com>
Co-authored-by: Xiyuan Chen <52963600+GareArc@users.noreply.github.com>
This commit is contained in:
Yeuoly
2025-02-17 17:05:13 +08:00
committed by GitHub
parent 222df44d21
commit 403e2d58b9
3272 changed files with 66339 additions and 281594 deletions

View File

@@ -0,0 +1 @@
DEFAULT_PLUGIN_ID = "langgenius"

View File

@@ -0,0 +1,42 @@
from enum import StrEnum
class CommonParameterType(StrEnum):
SECRET_INPUT = "secret-input"
TEXT_INPUT = "text-input"
SELECT = "select"
STRING = "string"
NUMBER = "number"
FILE = "file"
FILES = "files"
SYSTEM_FILES = "system-files"
BOOLEAN = "boolean"
APP_SELECTOR = "app-selector"
MODEL_SELECTOR = "model-selector"
TOOLS_SELECTOR = "array[tools]"
# TOOL_SELECTOR = "tool-selector"
class AppSelectorScope(StrEnum):
ALL = "all"
CHAT = "chat"
WORKFLOW = "workflow"
COMPLETION = "completion"
class ModelSelectorScope(StrEnum):
LLM = "llm"
TEXT_EMBEDDING = "text-embedding"
RERANK = "rerank"
TTS = "tts"
SPEECH2TEXT = "speech2text"
MODERATION = "moderation"
VISION = "vision"
class ToolSelectorScope(StrEnum):
ALL = "all"
CUSTOM = "custom"
BUILTIN = "builtin"
WORKFLOW = "workflow"

View File

@@ -2,13 +2,14 @@ import datetime
import json
import logging
from collections import defaultdict
from collections.abc import Iterator
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from typing import Optional
from pydantic import BaseModel, ConfigDict
from constants import HIDDEN_VALUE
from core.entities import DEFAULT_PLUGIN_ID
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from core.entities.provider_entities import (
CustomConfiguration,
@@ -18,16 +19,15 @@ from core.entities.provider_entities import (
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import FetchFrom, ModelType
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.entities.provider_entities import (
ConfigurateMethod,
CredentialFormSchema,
FormType,
ProviderEntity,
)
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_database import db
from models.provider import (
LoadBalancingModelConfig,
@@ -99,9 +99,10 @@ class ProviderConfiguration(BaseModel):
continue
restrict_models = quota_configuration.restrict_models
if self.system_configuration.credentials is None:
return None
copy_credentials = self.system_configuration.credentials.copy()
copy_credentials = (
self.system_configuration.credentials.copy() if self.system_configuration.credentials else {}
)
if restrict_models:
for restrict_model in restrict_models:
if (
@@ -140,6 +141,9 @@ class ProviderConfiguration(BaseModel):
if current_quota_configuration is None:
return None
if not current_quota_configuration:
return SystemConfigurationStatus.UNSUPPORTED
return (
SystemConfigurationStatus.ACTIVE
if current_quota_configuration.is_valid
@@ -153,7 +157,7 @@ class ProviderConfiguration(BaseModel):
"""
return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
def get_custom_credentials(self, obfuscated: bool = False):
def get_custom_credentials(self, obfuscated: bool = False) -> dict | None:
"""
Get custom credentials.
@@ -175,7 +179,7 @@ class ProviderConfiguration(BaseModel):
else [],
)
def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]:
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
"""
Validate custom credentials.
:param credentials: provider credentials
@@ -219,6 +223,7 @@ class ProviderConfiguration(BaseModel):
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
model_provider_factory = ModelProviderFactory(self.tenant_id)
credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider, credentials=credentials
)
@@ -246,13 +251,13 @@ class ProviderConfiguration(BaseModel):
provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit()
else:
provider_record = Provider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(credentials),
is_valid=True,
)
provider_record = Provider()
provider_record.tenant_id = self.tenant_id
provider_record.provider_name = self.provider.provider
provider_record.provider_type = ProviderType.CUSTOM.value
provider_record.encrypted_config = json.dumps(credentials)
provider_record.is_valid = True
db.session.add(provider_record)
db.session.commit()
@@ -327,7 +332,7 @@ class ProviderConfiguration(BaseModel):
def custom_model_credentials_validate(
self, model_type: ModelType, model: str, credentials: dict
) -> tuple[Optional[ProviderModel], dict]:
) -> tuple[ProviderModel | None, dict]:
"""
Validate custom model credentials.
@@ -370,6 +375,7 @@ class ProviderConfiguration(BaseModel):
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
model_provider_factory = ModelProviderFactory(self.tenant_id)
credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
@@ -400,14 +406,13 @@ class ProviderConfiguration(BaseModel):
provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit()
else:
provider_model_record = ProviderModel(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_name=model,
model_type=model_type.to_origin_model_type(),
encrypted_config=json.dumps(credentials),
is_valid=True,
)
provider_model_record = ProviderModel()
provider_model_record.tenant_id = self.tenant_id
provider_model_record.provider_name = self.provider.provider
provider_model_record.model_name = model
provider_model_record.model_type = model_type.to_origin_model_type()
provider_model_record.encrypted_config = json.dumps(credentials)
provider_model_record.is_valid = True
db.session.add(provider_model_record)
db.session.commit()
@@ -474,13 +479,12 @@ class ProviderConfiguration(BaseModel):
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit()
else:
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
enabled=True,
)
model_setting = ProviderModelSetting()
model_setting.tenant_id = self.tenant_id
model_setting.provider_name = self.provider.provider
model_setting.model_type = model_type.to_origin_model_type()
model_setting.model_name = model
model_setting.enabled = True
db.session.add(model_setting)
db.session.commit()
@@ -509,13 +513,12 @@ class ProviderConfiguration(BaseModel):
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit()
else:
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
enabled=False,
)
model_setting = ProviderModelSetting()
model_setting.tenant_id = self.tenant_id
model_setting.provider_name = self.provider.provider
model_setting.model_type = model_type.to_origin_model_type()
model_setting.model_name = model
model_setting.enabled = False
db.session.add(model_setting)
db.session.commit()
@@ -576,13 +579,12 @@ class ProviderConfiguration(BaseModel):
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit()
else:
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
load_balancing_enabled=True,
)
model_setting = ProviderModelSetting()
model_setting.tenant_id = self.tenant_id
model_setting.provider_name = self.provider.provider
model_setting.model_type = model_type.to_origin_model_type()
model_setting.model_name = model
model_setting.load_balancing_enabled = True
db.session.add(model_setting)
db.session.commit()
@@ -611,25 +613,17 @@ class ProviderConfiguration(BaseModel):
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit()
else:
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
load_balancing_enabled=False,
)
model_setting = ProviderModelSetting()
model_setting.tenant_id = self.tenant_id
model_setting.provider_name = self.provider.provider
model_setting.model_type = model_type.to_origin_model_type()
model_setting.model_name = model
model_setting.load_balancing_enabled = False
db.session.add(model_setting)
db.session.commit()
return model_setting
def get_provider_instance(self) -> ModelProvider:
"""
Get provider instance.
:return:
"""
return model_provider_factory.get_provider_instance(self.provider.provider)
def get_model_type_instance(self, model_type: ModelType) -> AIModel:
"""
Get current model type instance.
@@ -637,11 +631,19 @@ class ProviderConfiguration(BaseModel):
:param model_type: model type
:return:
"""
# Get provider instance
provider_instance = self.get_provider_instance()
model_provider_factory = ModelProviderFactory(self.tenant_id)
# Get model instance of LLM
return provider_instance.get_model_instance(model_type)
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
def get_model_schema(self, model_type: ModelType, model: str, credentials: dict) -> AIModelEntity | None:
"""
Get model schema
"""
model_provider_factory = ModelProviderFactory(self.tenant_id)
return model_provider_factory.get_model_schema(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
"""
@@ -668,11 +670,10 @@ class ProviderConfiguration(BaseModel):
if preferred_model_provider:
preferred_model_provider.preferred_provider_type = provider_type.value
else:
preferred_model_provider = TenantPreferredModelProvider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
preferred_provider_type=provider_type.value,
)
preferred_model_provider = TenantPreferredModelProvider()
preferred_model_provider.tenant_id = self.tenant_id
preferred_model_provider.provider_name = self.provider.provider
preferred_model_provider.preferred_provider_type = provider_type.value
db.session.add(preferred_model_provider)
db.session.commit()
@@ -737,13 +738,14 @@ class ProviderConfiguration(BaseModel):
:param only_active: only active models
:return:
"""
provider_instance = self.get_provider_instance()
model_provider_factory = ModelProviderFactory(self.tenant_id)
provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
model_types = []
model_types: list[ModelType] = []
if model_type:
model_types.append(model_type)
else:
model_types = list(provider_instance.get_provider_schema().supported_model_types)
model_types = list(provider_schema.supported_model_types)
# Group model settings by model type and model
model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
@@ -752,11 +754,11 @@ class ProviderConfiguration(BaseModel):
if self.using_provider_type == ProviderType.SYSTEM:
provider_models = self._get_system_provider_models(
model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
)
else:
provider_models = self._get_custom_provider_models(
model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
)
if only_active:
@@ -767,23 +769,26 @@ class ProviderConfiguration(BaseModel):
def _get_system_provider_models(
self,
model_types: list[ModelType],
provider_instance: ModelProvider,
model_types: Sequence[ModelType],
provider_schema: ProviderEntity,
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
) -> list[ModelWithProviderEntity]:
"""
Get system provider models.
:param model_types: model types
:param provider_instance: provider instance
:param provider_schema: provider schema
:param model_setting_map: model setting map
:return:
"""
provider_models = []
for model_type in model_types:
for m in provider_instance.models(model_type):
for m in provider_schema.models:
if m.model_type != model_type:
continue
status = ModelStatus.ACTIVE
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
if m.model in model_setting_map:
model_setting = model_setting_map[m.model_type][m.model]
if model_setting.enabled is False:
status = ModelStatus.DISABLED
@@ -804,7 +809,7 @@ class ProviderConfiguration(BaseModel):
if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = []
for configurate_method in provider_instance.get_provider_schema().configurate_methods:
for configurate_method in provider_schema.configurate_methods:
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
should_use_custom_model = False
@@ -825,18 +830,22 @@ class ProviderConfiguration(BaseModel):
]:
# only customizable model
for restrict_model in restrict_models:
if self.system_configuration.credentials is not None:
copy_credentials = self.system_configuration.credentials.copy()
if restrict_model.base_model_name:
copy_credentials["base_model_name"] = restrict_model.base_model_name
copy_credentials = (
self.system_configuration.credentials.copy()
if self.system_configuration.credentials
else {}
)
if restrict_model.base_model_name:
copy_credentials["base_model_name"] = restrict_model.base_model_name
try:
custom_model_schema = provider_instance.get_model_instance(
restrict_model.model_type
).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
except Exception as ex:
logger.warning(f"get custom model schema failed, {ex}")
continue
try:
custom_model_schema = self.get_model_schema(
model_type=restrict_model.model_type,
model=restrict_model.model,
credentials=copy_credentials,
)
except Exception as ex:
logger.warning(f"get custom model schema failed, {ex}")
if not custom_model_schema:
continue
@@ -881,15 +890,15 @@ class ProviderConfiguration(BaseModel):
def _get_custom_provider_models(
self,
model_types: list[ModelType],
provider_instance: ModelProvider,
model_types: Sequence[ModelType],
provider_schema: ProviderEntity,
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
) -> list[ModelWithProviderEntity]:
"""
Get custom provider models.
:param model_types: model types
:param provider_instance: provider instance
:param provider_schema: provider schema
:param model_setting_map: model setting map
:return:
"""
@@ -903,8 +912,10 @@ class ProviderConfiguration(BaseModel):
if model_type not in self.provider.supported_model_types:
continue
models = provider_instance.models(model_type)
for m in models:
for m in provider_schema.models:
if m.model_type != model_type:
continue
status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
load_balancing_enabled = False
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
@@ -936,10 +947,10 @@ class ProviderConfiguration(BaseModel):
continue
try:
custom_model_schema = provider_instance.get_model_instance(
model_configuration.model_type
).get_customizable_model_schema_from_credentials(
model_configuration.model, model_configuration.credentials
custom_model_schema = self.get_model_schema(
model_type=model_configuration.model_type,
model=model_configuration.model,
credentials=model_configuration.credentials,
)
except Exception as ex:
logger.warning(f"get custom model schema failed, {ex}")
@@ -967,7 +978,7 @@ class ProviderConfiguration(BaseModel):
label=custom_model_schema.label,
model_type=custom_model_schema.model_type,
features=custom_model_schema.features,
fetch_from=custom_model_schema.fetch_from,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
@@ -1040,6 +1051,9 @@ class ProviderConfigurations(BaseModel):
return list(self.values())
def __getitem__(self, key):
if "/" not in key:
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
return self.configurations[key]
def __setitem__(self, key, value):
@@ -1051,8 +1065,11 @@ class ProviderConfigurations(BaseModel):
def values(self) -> Iterator[ProviderConfiguration]:
return iter(self.configurations.values())
def get(self, key, default=None):
return self.configurations.get(key, default)
def get(self, key, default=None) -> ProviderConfiguration | None:
if "/" not in key:
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
return self.configurations.get(key, default) # type: ignore
class ProviderModelBundle(BaseModel):
@@ -1061,7 +1078,6 @@ class ProviderModelBundle(BaseModel):
"""
configuration: ProviderConfiguration
provider_instance: ModelProvider
model_type_instance: AIModel
# pydantic configs

View File

@@ -1,10 +1,34 @@
from enum import Enum
from typing import Optional
from typing import Optional, Union
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from core.entities.parameter_entities import (
AppSelectorScope,
CommonParameterType,
ModelSelectorScope,
ToolSelectorScope,
)
from core.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderQuotaType
from core.tools.entities.common_entities import I18nObject
class ProviderQuotaType(Enum):
PAID = "paid"
"""hosted paid quota"""
FREE = "free"
"""third-party free quota"""
TRIAL = "trial"
"""hosted trial quota"""
@staticmethod
def value_of(value):
for member in ProviderQuotaType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class QuotaUnit(Enum):
@@ -108,3 +132,55 @@ class ModelSettings(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
class BasicProviderConfig(BaseModel):
"""
Base model class for common provider settings like credentials
"""
class Type(Enum):
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
TEXT_INPUT = CommonParameterType.TEXT_INPUT.value
SELECT = CommonParameterType.SELECT.value
BOOLEAN = CommonParameterType.BOOLEAN.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
@classmethod
def value_of(cls, value: str) -> "ProviderConfig.Type":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid mode value {value}")
type: Type = Field(..., description="The type of the credentials")
name: str = Field(..., description="The name of the credentials")
class ProviderConfig(BasicProviderConfig):
"""
Model class for common provider settings like credentials
"""
class Option(BaseModel):
value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option")
scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
required: bool = False
default: Optional[Union[int, str]] = None
options: Optional[list[Option]] = None
label: Optional[I18nObject] = None
help: Optional[I18nObject] = None
url: Optional[str] = None
placeholder: Optional[I18nObject] = None
def to_basic_provider_config(self) -> BasicProviderConfig:
return BasicProviderConfig(type=self.type, name=self.name)