mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 03:16:51 +08:00
Introduce Plugins (#13836)
Signed-off-by: yihong0618 <zouzou0208@gmail.com> Signed-off-by: -LAN- <laipz8200@outlook.com> Signed-off-by: xhe <xw897002528@gmail.com> Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: takatost <takatost@gmail.com> Co-authored-by: kurokobo <kuro664@gmail.com> Co-authored-by: Novice Lee <novicelee@NoviPro.local> Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: AkaraChen <akarachen@outlook.com> Co-authored-by: Yi <yxiaoisme@gmail.com> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: Hiroshi Fujita <fujita-h@users.noreply.github.com> Co-authored-by: AkaraChen <85140972+AkaraChen@users.noreply.github.com> Co-authored-by: NFish <douxc512@gmail.com> Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Co-authored-by: 非法操作 <hjlarry@163.com> Co-authored-by: Novice <857526207@qq.com> Co-authored-by: Hiroki Nagai <82458324+nagaihiroki-git@users.noreply.github.com> Co-authored-by: Gen Sato <52241300+halogen22@users.noreply.github.com> Co-authored-by: eux <euxuuu@gmail.com> Co-authored-by: huangzhuo1949 <167434202+huangzhuo1949@users.noreply.github.com> Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com> Co-authored-by: lotsik <lotsik@mail.ru> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: nite-knite <nkCoding@gmail.com> Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: gakkiyomi <gakkiyomi@aliyun.com> Co-authored-by: CN-P5 <heibai2006@gmail.com> Co-authored-by: CN-P5 <heibai2006@qq.com> Co-authored-by: Chuehnone <1897025+chuehnone@users.noreply.github.com> Co-authored-by: yihong <zouzou0208@gmail.com> Co-authored-by: Kevin9703 <51311316+Kevin9703@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: Boris Feld <lothiraldan@gmail.com> Co-authored-by: mbo <himabo@gmail.com> Co-authored-by: mabo <mabo@aeyes.ai> Co-authored-by: Warren Chen <warren.chen830@gmail.com> Co-authored-by: JzoNgKVO <27049666+JzoNgKVO@users.noreply.github.com> Co-authored-by: jiandanfeng <chenjh3@wangsu.com> Co-authored-by: zhu-an <70234959+xhdd123321@users.noreply.github.com> Co-authored-by: zhaoqingyu.1075 <zhaoqingyu.1075@bytedance.com> Co-authored-by: 海狸大師 <86974027+yenslife@users.noreply.github.com> Co-authored-by: Xu Song <xusong.vip@gmail.com> Co-authored-by: rayshaw001 <396301947@163.com> Co-authored-by: Ding Jiatong <dingjiatong@gmail.com> Co-authored-by: Bowen Liang <liangbowen@gf.com.cn> Co-authored-by: JasonVV <jasonwangiii@outlook.com> Co-authored-by: le0zh <newlight@qq.com> Co-authored-by: zhuxinliang <zhuxinliang@didiglobal.com> Co-authored-by: k-zaku <zaku99@outlook.jp> Co-authored-by: luckylhb90 <luckylhb90@gmail.com> Co-authored-by: hobo.l <hobo.l@binance.com> Co-authored-by: jiangbo721 <365065261@qq.com> Co-authored-by: 刘江波 <jiangbo721@163.com> Co-authored-by: Shun Miyazawa <34241526+miya@users.noreply.github.com> Co-authored-by: EricPan <30651140+Egfly@users.noreply.github.com> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: sino <sino2322@gmail.com> Co-authored-by: Jhvcc <37662342+Jhvcc@users.noreply.github.com> Co-authored-by: lowell <lowell.hu@zkteco.in> Co-authored-by: Boris Polonsky <BorisPolonsky@users.noreply.github.com> Co-authored-by: Ademílson Tonato <ademilsonft@outlook.com> Co-authored-by: Ademílson Tonato <ademilson.tonato@refurbed.com> Co-authored-by: IWAI, Masaharu <iwaim.sub@gmail.com> Co-authored-by: Yueh-Po Peng (Yabi) <94939112+y10ab1@users.noreply.github.com> Co-authored-by: Jason <ggbbddjm@gmail.com> Co-authored-by: Xin Zhang <sjhpzx@gmail.com> Co-authored-by: yjc980121 <3898524+yjc980121@users.noreply.github.com> Co-authored-by: heyszt <36215648+hieheihei@users.noreply.github.com> Co-authored-by: Abdullah AlOsaimi <osaimiacc@gmail.com> Co-authored-by: Abdullah AlOsaimi <189027247+osaimi@users.noreply.github.com> Co-authored-by: Yingchun Lai <laiyingchun@apache.org> Co-authored-by: Hash Brown <hi@xzd.me> Co-authored-by: zuodongxu <192560071+zuodongxu@users.noreply.github.com> Co-authored-by: Masashi Tomooka <tmokmss@users.noreply.github.com> Co-authored-by: aplio <ryo.091219@gmail.com> Co-authored-by: Obada Khalili <54270856+obadakhalili@users.noreply.github.com> Co-authored-by: Nam Vu <zuzoovn@gmail.com> Co-authored-by: Kei YAMAZAKI <1715090+kei-yamazaki@users.noreply.github.com> Co-authored-by: TechnoHouse <13776377+deephbz@users.noreply.github.com> Co-authored-by: Riddhimaan-Senapati <114703025+Riddhimaan-Senapati@users.noreply.github.com> Co-authored-by: MaFee921 <31881301+2284730142@users.noreply.github.com> Co-authored-by: te-chan <t-nakanome@sakura-is.co.jp> Co-authored-by: HQidea <HQidea@users.noreply.github.com> Co-authored-by: Joshbly <36315710+Joshbly@users.noreply.github.com> Co-authored-by: xhe <xw897002528@gmail.com> Co-authored-by: weiwenyan-dev <154779315+weiwenyan-dev@users.noreply.github.com> Co-authored-by: ex_wenyan.wei <ex_wenyan.wei@tcl.com> Co-authored-by: engchina <12236799+engchina@users.noreply.github.com> Co-authored-by: engchina <atjapan2015@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: 呆萌闷油瓶 <253605712@qq.com> Co-authored-by: Kemal <kemalmeler@outlook.com> Co-authored-by: Lazy_Frog <4590648+lazyFrogLOL@users.noreply.github.com> Co-authored-by: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com> Co-authored-by: Steven sun <98230804+Tuyohai@users.noreply.github.com> Co-authored-by: steven <sunzwj@digitalchina.com> Co-authored-by: Kalo Chin <91766386+fdb02983rhy@users.noreply.github.com> Co-authored-by: Katy Tao <34019945+KatyTao@users.noreply.github.com> Co-authored-by: depy <42985524+h4ckdepy@users.noreply.github.com> Co-authored-by: 胡春东 <gycm520@gmail.com> Co-authored-by: Junjie.M <118170653@qq.com> Co-authored-by: MuYu <mr.muzea@gmail.com> Co-authored-by: Naoki Takashima <39912547+takatea@users.noreply.github.com> Co-authored-by: Summer-Gu <37869445+gubinjie@users.noreply.github.com> Co-authored-by: Fei He <droxer.he@gmail.com> Co-authored-by: ybalbert001 <120714773+ybalbert001@users.noreply.github.com> Co-authored-by: Yuanbo Li <ybalbert@amazon.com> Co-authored-by: douxc <7553076+douxc@users.noreply.github.com> Co-authored-by: liuzhenghua <1090179900@qq.com> Co-authored-by: Wu Jiayang <62842862+Wu-Jiayang@users.noreply.github.com> Co-authored-by: Your Name <you@example.com> Co-authored-by: kimjion <45935338+kimjion@users.noreply.github.com> Co-authored-by: AugNSo <song.tiankai@icloud.com> Co-authored-by: llinvokerl <38915183+llinvokerl@users.noreply.github.com> Co-authored-by: liusurong.lsr <liusurong.lsr@alibaba-inc.com> Co-authored-by: Vasu Negi <vasu-negi@users.noreply.github.com> Co-authored-by: Hundredwz <1808096180@qq.com> Co-authored-by: Xiyuan Chen <52963600+GareArc@users.noreply.github.com>
This commit is contained in:
@@ -10,6 +10,7 @@ from typing import Any, Optional, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
@@ -101,7 +102,7 @@ class AccountService:
|
||||
|
||||
@staticmethod
|
||||
def load_user(user_id: str) -> None | Account:
|
||||
account = Account.query.filter_by(id=user_id).first()
|
||||
account = db.session.query(Account).filter_by(id=user_id).first()
|
||||
if not account:
|
||||
return None
|
||||
|
||||
@@ -146,7 +147,7 @@ class AccountService:
|
||||
def authenticate(email: str, password: str, invite_token: Optional[str] = None) -> Account:
|
||||
"""authenticate account with email and password"""
|
||||
|
||||
account = Account.query.filter_by(email=email).first()
|
||||
account = db.session.query(Account).filter_by(email=email).first()
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
|
||||
@@ -924,11 +925,14 @@ class RegisterService:
|
||||
|
||||
@classmethod
|
||||
def invite_new_member(
|
||||
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Optional[Account] = None
|
||||
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account | None = None
|
||||
) -> str:
|
||||
if not inviter:
|
||||
raise ValueError("Inviter is required")
|
||||
|
||||
"""Invite new member"""
|
||||
account = Account.query.filter_by(email=email).first()
|
||||
assert inviter is not None, "Inviter must be provided."
|
||||
with Session(db.engine) as session:
|
||||
account = session.query(Account).filter_by(email=email).first()
|
||||
|
||||
if not account:
|
||||
TenantService.check_member_permission(tenant, inviter, None, "add")
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
import pytz
|
||||
from flask_login import current_user # type: ignore
|
||||
|
||||
import contexts
|
||||
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
|
||||
from core.plugin.manager.agent import PluginAgentManager
|
||||
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
@@ -16,7 +20,10 @@ class AgentService:
|
||||
"""
|
||||
Service to get agent logs
|
||||
"""
|
||||
conversation: Optional[Conversation] = (
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
conversation: Conversation | None = (
|
||||
db.session.query(Conversation)
|
||||
.filter(
|
||||
Conversation.id == conversation_id,
|
||||
@@ -59,6 +66,10 @@ class AgentService:
|
||||
|
||||
timezone = pytz.timezone(current_user.timezone)
|
||||
|
||||
app_model_config = app_model.app_model_config
|
||||
if not app_model_config:
|
||||
raise ValueError("App model config not found")
|
||||
|
||||
result = {
|
||||
"meta": {
|
||||
"status": "success",
|
||||
@@ -66,16 +77,16 @@ class AgentService:
|
||||
"start_time": message.created_at.astimezone(timezone).isoformat(),
|
||||
"elapsed_time": message.provider_response_latency,
|
||||
"total_tokens": message.answer_tokens + message.message_tokens,
|
||||
"agent_mode": app_model.app_model_config.agent_mode_dict.get("strategy", "react"),
|
||||
"agent_mode": app_model_config.agent_mode_dict.get("strategy", "react"),
|
||||
"iterations": len(agent_thoughts),
|
||||
},
|
||||
"iterations": [],
|
||||
"files": message.message_files,
|
||||
}
|
||||
|
||||
agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict())
|
||||
agent_config = AgentConfigManager.convert(app_model_config.to_dict())
|
||||
if not agent_config:
|
||||
return result
|
||||
raise ValueError("Agent config not found")
|
||||
|
||||
agent_tools = agent_config.tools or []
|
||||
|
||||
@@ -89,7 +100,7 @@ class AgentService:
|
||||
tool_labels = agent_thought.tool_labels
|
||||
tool_meta = agent_thought.tool_meta
|
||||
tool_inputs = agent_thought.tool_inputs_dict
|
||||
tool_outputs = agent_thought.tool_outputs_dict
|
||||
tool_outputs = agent_thought.tool_outputs_dict or {}
|
||||
tool_calls = []
|
||||
for tool in tools:
|
||||
tool_name = tool
|
||||
@@ -144,3 +155,22 @@ class AgentService:
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def list_agent_providers(cls, user_id: str, tenant_id: str):
|
||||
"""
|
||||
List agent providers
|
||||
"""
|
||||
manager = PluginAgentManager()
|
||||
return manager.fetch_agent_strategy_providers(tenant_id)
|
||||
|
||||
@classmethod
|
||||
def get_agent_provider(cls, user_id: str, tenant_id: str, provider_name: str):
|
||||
"""
|
||||
Get agent provider
|
||||
"""
|
||||
manager = PluginAgentManager()
|
||||
try:
|
||||
return manager.fetch_agent_strategy_provider(tenant_id, provider_name)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(str(e)) from e
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
@@ -7,22 +8,34 @@ from uuid import uuid4
|
||||
|
||||
import yaml # type: ignore
|
||||
from packaging import version
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import PluginDependency
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
|
||||
from core.workflow.nodes.llm.entities import LLMNodeData
|
||||
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
||||
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from events.app_event import app_model_config_was_updated, app_was_created
|
||||
from extensions.ext_redis import redis_client
|
||||
from factories import variable_factory
|
||||
from models import Account, App, AppMode
|
||||
from models.model import AppModelConfig
|
||||
from models.workflow import Workflow
|
||||
from services.plugin.dependencies_analysis import DependenciesAnalysisService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
|
||||
IMPORT_INFO_REDIS_EXPIRY = 180 # 3 minutes
|
||||
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
|
||||
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
|
||||
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
CURRENT_DSL_VERSION = "0.1.5"
|
||||
|
||||
|
||||
@@ -47,6 +60,10 @@ class Import(BaseModel):
|
||||
error: str = ""
|
||||
|
||||
|
||||
class CheckDependenciesResult(BaseModel):
|
||||
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
|
||||
|
||||
|
||||
def _check_version_compatibility(imported_version: str) -> ImportStatus:
|
||||
"""Determine import status based on version comparison"""
|
||||
try:
|
||||
@@ -76,6 +93,11 @@ class PendingData(BaseModel):
|
||||
app_id: str | None
|
||||
|
||||
|
||||
class CheckDependenciesPendingData(BaseModel):
|
||||
dependencies: list[PluginDependency]
|
||||
app_id: str | None
|
||||
|
||||
|
||||
class AppDslService:
|
||||
def __init__(self, session: Session):
|
||||
self._session = session
|
||||
@@ -113,7 +135,6 @@ class AppDslService:
|
||||
error="yaml_url is required when import_mode is yaml-url",
|
||||
)
|
||||
try:
|
||||
max_size = 10 * 1024 * 1024 # 10MB
|
||||
parsed_url = urlparse(yaml_url)
|
||||
if (
|
||||
parsed_url.scheme == "https"
|
||||
@@ -126,7 +147,7 @@ class AppDslService:
|
||||
response.raise_for_status()
|
||||
content = response.content.decode()
|
||||
|
||||
if len(content) > max_size:
|
||||
if len(content) > DSL_MAX_SIZE:
|
||||
return Import(
|
||||
id=import_id,
|
||||
status=ImportStatus.FAILED,
|
||||
@@ -208,7 +229,7 @@ class AppDslService:
|
||||
|
||||
# If major version mismatch, store import info in Redis
|
||||
if status == ImportStatus.PENDING:
|
||||
panding_data = PendingData(
|
||||
pending_data = PendingData(
|
||||
import_mode=import_mode,
|
||||
yaml_content=content,
|
||||
name=name,
|
||||
@@ -221,7 +242,7 @@ class AppDslService:
|
||||
redis_client.setex(
|
||||
f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}",
|
||||
IMPORT_INFO_REDIS_EXPIRY,
|
||||
panding_data.model_dump_json(),
|
||||
pending_data.model_dump_json(),
|
||||
)
|
||||
|
||||
return Import(
|
||||
@@ -231,6 +252,22 @@ class AppDslService:
|
||||
imported_dsl_version=imported_version,
|
||||
)
|
||||
|
||||
# Extract dependencies
|
||||
dependencies = data.get("dependencies", [])
|
||||
check_dependencies_pending_data = None
|
||||
if dependencies:
|
||||
check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]
|
||||
elif imported_version <= "0.1.5":
|
||||
if "workflow" in data:
|
||||
graph = data.get("workflow", {}).get("graph", {})
|
||||
dependencies_list = self._extract_dependencies_from_workflow_graph(graph)
|
||||
else:
|
||||
dependencies_list = self._extract_dependencies_from_model_config(data.get("model_config", {}))
|
||||
|
||||
check_dependencies_pending_data = DependenciesAnalysisService.generate_latest_dependencies(
|
||||
dependencies_list
|
||||
)
|
||||
|
||||
# Create or update app
|
||||
app = self._create_or_update_app(
|
||||
app=app,
|
||||
@@ -241,6 +278,7 @@ class AppDslService:
|
||||
icon_type=icon_type,
|
||||
icon=icon,
|
||||
icon_background=icon_background,
|
||||
dependencies=check_dependencies_pending_data,
|
||||
)
|
||||
|
||||
return Import(
|
||||
@@ -325,6 +363,29 @@ class AppDslService:
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
def check_dependencies(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
) -> CheckDependenciesResult:
|
||||
"""Check dependencies"""
|
||||
# Get dependencies from Redis
|
||||
redis_key = f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app_model.id}"
|
||||
dependencies = redis_client.get(redis_key)
|
||||
if not dependencies:
|
||||
return CheckDependenciesResult()
|
||||
|
||||
# Extract dependencies
|
||||
dependencies = CheckDependenciesPendingData.model_validate_json(dependencies)
|
||||
|
||||
# Get leaked dependencies
|
||||
leaked_dependencies = DependenciesAnalysisService.get_leaked_dependencies(
|
||||
tenant_id=app_model.tenant_id, dependencies=dependencies.dependencies
|
||||
)
|
||||
return CheckDependenciesResult(
|
||||
leaked_dependencies=leaked_dependencies,
|
||||
)
|
||||
|
||||
def _create_or_update_app(
|
||||
self,
|
||||
*,
|
||||
@@ -336,6 +397,7 @@ class AppDslService:
|
||||
icon_type: Optional[str] = None,
|
||||
icon: Optional[str] = None,
|
||||
icon_background: Optional[str] = None,
|
||||
dependencies: Optional[list[PluginDependency]] = None,
|
||||
) -> App:
|
||||
"""Create a new app or update an existing one."""
|
||||
app_data = data.get("app", {})
|
||||
@@ -384,6 +446,14 @@ class AppDslService:
|
||||
self._session.commit()
|
||||
app_was_created.send(app, account=account)
|
||||
|
||||
# save dependencies
|
||||
if dependencies:
|
||||
redis_client.setex(
|
||||
f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app.id}",
|
||||
IMPORT_INFO_REDIS_EXPIRY,
|
||||
CheckDependenciesPendingData(app_id=app.id, dependencies=dependencies).model_dump_json(),
|
||||
)
|
||||
|
||||
# Initialize app based on mode
|
||||
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow_data = data.get("workflow")
|
||||
@@ -479,6 +549,13 @@ class AppDslService:
|
||||
raise ValueError("Missing draft workflow configuration, please check.")
|
||||
|
||||
export_data["workflow"] = workflow.to_dict(include_secret=include_secret)
|
||||
dependencies = cls._extract_dependencies_from_workflow(workflow)
|
||||
export_data["dependencies"] = [
|
||||
jsonable_encoder(d.model_dump())
|
||||
for d in DependenciesAnalysisService.generate_dependencies(
|
||||
tenant_id=app_model.tenant_id, dependencies=dependencies
|
||||
)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None:
|
||||
@@ -492,3 +569,154 @@ class AppDslService:
|
||||
raise ValueError("Missing app configuration, please check.")
|
||||
|
||||
export_data["model_config"] = app_model_config.to_dict()
|
||||
dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict())
|
||||
export_data["dependencies"] = [
|
||||
jsonable_encoder(d.model_dump())
|
||||
for d in DependenciesAnalysisService.generate_dependencies(
|
||||
tenant_id=app_model.tenant_id, dependencies=dependencies
|
||||
)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]:
|
||||
"""
|
||||
Extract dependencies from workflow
|
||||
:param workflow: Workflow instance
|
||||
:return: dependencies list format like ["langgenius/google"]
|
||||
"""
|
||||
graph = workflow.graph_dict
|
||||
dependencies = cls._extract_dependencies_from_workflow_graph(graph)
|
||||
return dependencies
|
||||
|
||||
@classmethod
|
||||
def _extract_dependencies_from_workflow_graph(cls, graph: Mapping) -> list[str]:
|
||||
"""
|
||||
Extract dependencies from workflow graph
|
||||
:param graph: Workflow graph
|
||||
:return: dependencies list format like ["langgenius/google"]
|
||||
"""
|
||||
dependencies = []
|
||||
for node in graph.get("nodes", []):
|
||||
try:
|
||||
typ = node.get("data", {}).get("type")
|
||||
match typ:
|
||||
case NodeType.TOOL.value:
|
||||
tool_entity = ToolNodeData(**node["data"])
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id),
|
||||
)
|
||||
case NodeType.LLM.value:
|
||||
llm_entity = LLMNodeData(**node["data"])
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider),
|
||||
)
|
||||
case NodeType.QUESTION_CLASSIFIER.value:
|
||||
question_classifier_entity = QuestionClassifierNodeData(**node["data"])
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(
|
||||
question_classifier_entity.model.provider
|
||||
),
|
||||
)
|
||||
case NodeType.PARAMETER_EXTRACTOR.value:
|
||||
parameter_extractor_entity = ParameterExtractorNodeData(**node["data"])
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(
|
||||
parameter_extractor_entity.model.provider
|
||||
),
|
||||
)
|
||||
case NodeType.KNOWLEDGE_RETRIEVAL.value:
|
||||
knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"])
|
||||
if knowledge_retrieval_entity.retrieval_mode == "multiple":
|
||||
if knowledge_retrieval_entity.multiple_retrieval_config:
|
||||
if (
|
||||
knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode
|
||||
== "reranking_model"
|
||||
):
|
||||
if knowledge_retrieval_entity.multiple_retrieval_config.reranking_model:
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(
|
||||
knowledge_retrieval_entity.multiple_retrieval_config.reranking_model.provider
|
||||
),
|
||||
)
|
||||
elif (
|
||||
knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode
|
||||
== "weighted_score"
|
||||
):
|
||||
if knowledge_retrieval_entity.multiple_retrieval_config.weights:
|
||||
vector_setting = (
|
||||
knowledge_retrieval_entity.multiple_retrieval_config.weights.vector_setting
|
||||
)
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(
|
||||
vector_setting.embedding_provider_name
|
||||
),
|
||||
)
|
||||
elif knowledge_retrieval_entity.retrieval_mode == "single":
|
||||
model_config = knowledge_retrieval_entity.single_retrieval_config
|
||||
if model_config:
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(
|
||||
model_config.model.provider
|
||||
),
|
||||
)
|
||||
case _:
|
||||
# TODO: Handle default case or unknown node types
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception("Error extracting node dependency", exc_info=e)
|
||||
|
||||
return dependencies
|
||||
|
||||
@classmethod
|
||||
def _extract_dependencies_from_model_config(cls, model_config: Mapping) -> list[str]:
|
||||
"""
|
||||
Extract dependencies from model config
|
||||
:param model_config: model config dict
|
||||
:return: dependencies list format like ["langgenius/google"]
|
||||
"""
|
||||
dependencies = []
|
||||
|
||||
try:
|
||||
# completion model
|
||||
model_dict = model_config.get("model", {})
|
||||
if model_dict:
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(model_dict.get("provider", ""))
|
||||
)
|
||||
|
||||
# reranking model
|
||||
dataset_configs = model_config.get("dataset_configs", {})
|
||||
if dataset_configs:
|
||||
for dataset_config in dataset_configs.get("datasets", {}).get("datasets", []):
|
||||
if dataset_config.get("reranking_model"):
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(
|
||||
dataset_config.get("reranking_model", {})
|
||||
.get("reranking_provider_name", {})
|
||||
.get("provider")
|
||||
)
|
||||
)
|
||||
|
||||
# tools
|
||||
agent_configs = model_config.get("agent_mode", {})
|
||||
if agent_configs:
|
||||
for agent_config in agent_configs.get("tools", []):
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_tool_dependency(agent_config.get("provider_id"))
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error extracting model config dependency", exc_info=e)
|
||||
|
||||
return dependencies
|
||||
|
||||
@classmethod
|
||||
def get_leaked_dependencies(cls, tenant_id: str, dsl_dependencies: list[dict]) -> list[PluginDependency]:
|
||||
"""
|
||||
Returns the leaked dependencies in current workspace
|
||||
"""
|
||||
dependencies = [PluginDependency(**dep) for dep in dsl_dependencies]
|
||||
if not dependencies:
|
||||
return []
|
||||
|
||||
return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)
|
||||
|
||||
@@ -43,66 +43,62 @@ class AppGenerateService:
|
||||
request_id = rate_limit.enter(request_id)
|
||||
if app_model.mode == AppMode.COMPLETION.value:
|
||||
return rate_limit.generate(
|
||||
generator=CompletionAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=streaming,
|
||||
CompletionAppGenerator.convert_to_event_stream(
|
||||
CompletionAppGenerator().generate(
|
||||
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
|
||||
),
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
|
||||
generator = AgentChatAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=streaming,
|
||||
)
|
||||
return rate_limit.generate(
|
||||
generator=generator,
|
||||
request_id=request_id,
|
||||
AgentChatAppGenerator.convert_to_event_stream(
|
||||
AgentChatAppGenerator().generate(
|
||||
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
|
||||
),
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.CHAT.value:
|
||||
return rate_limit.generate(
|
||||
generator=ChatAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=streaming,
|
||||
ChatAppGenerator.convert_to_event_stream(
|
||||
ChatAppGenerator().generate(
|
||||
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
|
||||
),
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
workflow = cls._get_workflow(app_model, invoke_from)
|
||||
return rate_limit.generate(
|
||||
generator=AdvancedChatAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=streaming,
|
||||
AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
AdvancedChatAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=streaming,
|
||||
),
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
workflow = cls._get_workflow(app_model, invoke_from)
|
||||
generator = WorkflowAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=streaming,
|
||||
call_depth=0,
|
||||
workflow_thread_pool_id=None,
|
||||
)
|
||||
return rate_limit.generate(
|
||||
generator=generator,
|
||||
request_id=request_id,
|
||||
WorkflowAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=streaming,
|
||||
call_depth=0,
|
||||
workflow_thread_pool_id=None,
|
||||
),
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||
@@ -126,18 +122,17 @@ class AppGenerateService:
|
||||
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator().single_iteration_generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user=user,
|
||||
args=args,
|
||||
streaming=streaming,
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
AdvancedChatAppGenerator().single_iteration_generate(
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
)
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return WorkflowAppGenerator().single_iteration_generate(
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().single_iteration_generate(
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||
|
||||
@@ -9,9 +9,11 @@ from typing import Any, Optional
|
||||
|
||||
from flask_login import current_user # type: ignore
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
@@ -268,7 +270,15 @@ class DatasetService:
|
||||
external_knowledge_api_id = data.get("external_knowledge_api_id", None)
|
||||
if not external_knowledge_api_id:
|
||||
raise ValueError("External knowledge api id is required.")
|
||||
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(dataset_id=dataset_id).first()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
external_knowledge_binding = (
|
||||
session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
|
||||
)
|
||||
|
||||
if not external_knowledge_binding:
|
||||
raise ValueError("External knowledge binding not found.")
|
||||
|
||||
if (
|
||||
external_knowledge_binding.external_knowledge_id != external_knowledge_id
|
||||
or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id
|
||||
@@ -316,8 +326,19 @@ class DatasetService:
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(ex.description)
|
||||
else:
|
||||
# add default plugin id to both setting sets, to make sure the plugin model provider is consistent
|
||||
plugin_model_provider = dataset.embedding_model_provider
|
||||
if "/" not in plugin_model_provider:
|
||||
plugin_model_provider = f"{DEFAULT_PLUGIN_ID}/{plugin_model_provider}/{plugin_model_provider}"
|
||||
|
||||
new_plugin_model_provider = data["embedding_model_provider"]
|
||||
if "/" not in new_plugin_model_provider:
|
||||
new_plugin_model_provider = (
|
||||
f"{DEFAULT_PLUGIN_ID}/{new_plugin_model_provider}/{new_plugin_model_provider}"
|
||||
)
|
||||
|
||||
if (
|
||||
data["embedding_model_provider"] != dataset.embedding_model_provider
|
||||
new_plugin_model_provider != plugin_model_provider
|
||||
or data["embedding_model"] != dataset.embedding_model
|
||||
):
|
||||
action = "update"
|
||||
@@ -1468,7 +1489,7 @@ class SegmentService:
|
||||
model=dataset.embedding_model,
|
||||
)
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
|
||||
lock_name = "add_segment_lock_document_id_{}".format(document.id)
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
max_position = (
|
||||
@@ -1545,9 +1566,12 @@ class SegmentService:
|
||||
if dataset.indexing_technique == "high_quality" and embedding_model:
|
||||
# calc embedding use tokens
|
||||
if document.doc_form == "qa_model":
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment_item["answer"]])
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(
|
||||
texts=[content + segment_item["answer"]]
|
||||
)[0]
|
||||
else:
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
|
||||
|
||||
segment_document = DocumentSegment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
dataset_id=document.dataset_id,
|
||||
@@ -1695,9 +1719,9 @@ class SegmentService:
|
||||
|
||||
# calc embedding use tokens
|
||||
if document.doc_form == "qa_model":
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0]
|
||||
else:
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
|
||||
segment.content = content
|
||||
segment.index_node_hash = segment_hash
|
||||
segment.word_count = len(content)
|
||||
|
||||
@@ -8,7 +8,7 @@ from core.entities.model_entities import (
|
||||
ModelWithProviderEntity,
|
||||
ProviderModelWithStatusEntity,
|
||||
)
|
||||
from core.entities.provider_entities import QuotaConfiguration
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaConfiguration
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import (
|
||||
@@ -18,7 +18,7 @@ from core.model_runtime.entities.provider_entities import (
|
||||
ProviderHelpEntity,
|
||||
SimpleProviderEntity,
|
||||
)
|
||||
from models.provider import ProviderQuotaType, ProviderType
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class CustomConfigurationStatus(Enum):
|
||||
@@ -53,6 +53,7 @@ class ProviderResponse(BaseModel):
|
||||
Model class for provider response.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
provider: str
|
||||
label: I18nObject
|
||||
description: Optional[I18nObject] = None
|
||||
@@ -74,7 +75,9 @@ class ProviderResponse(BaseModel):
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
|
||||
url_prefix = (
|
||||
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
|
||||
)
|
||||
if self.icon_small is not None:
|
||||
self.icon_small = I18nObject(
|
||||
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||
@@ -91,6 +94,7 @@ class ProviderWithModelsResponse(BaseModel):
|
||||
Model class for provider with models response.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
provider: str
|
||||
label: I18nObject
|
||||
icon_small: Optional[I18nObject] = None
|
||||
@@ -101,7 +105,9 @@ class ProviderWithModelsResponse(BaseModel):
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
|
||||
url_prefix = (
|
||||
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
|
||||
)
|
||||
if self.icon_small is not None:
|
||||
self.icon_small = I18nObject(
|
||||
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||
@@ -118,10 +124,14 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
|
||||
Simple provider entity response.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
|
||||
url_prefix = (
|
||||
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
|
||||
)
|
||||
if self.icon_small is not None:
|
||||
self.icon_small = I18nObject(
|
||||
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||
@@ -146,13 +156,14 @@ class DefaultModelResponse(BaseModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class ModelWithProviderEntityResponse(ModelWithProviderEntity):
|
||||
class ModelWithProviderEntityResponse(ProviderModelWithStatusEntity):
|
||||
"""
|
||||
Model with provider entity.
|
||||
"""
|
||||
|
||||
# FIXME type error ignore here
|
||||
provider: SimpleProviderEntityResponse # type: ignore
|
||||
provider: SimpleProviderEntityResponse
|
||||
|
||||
def __init__(self, model: ModelWithProviderEntity) -> None:
|
||||
super().__init__(**model.model_dump())
|
||||
def __init__(self, tenant_id: str, model: ModelWithProviderEntity) -> None:
|
||||
dump_model = model.model_dump()
|
||||
dump_model["provider"]["tenant_id"] = tenant_id
|
||||
super().__init__(**dump_model)
|
||||
|
||||
@@ -58,6 +58,8 @@ class SystemFeatureModel(BaseModel):
|
||||
sso_enforced_for_web: bool = False
|
||||
sso_enforced_for_web_protocol: str = ""
|
||||
enable_web_sso_switch_component: bool = False
|
||||
enable_marketplace: bool = True
|
||||
max_plugin_package_size: int = dify_config.PLUGIN_MAX_PACKAGE_SIZE
|
||||
enable_email_code_login: bool = False
|
||||
enable_email_password_login: bool = True
|
||||
enable_social_oauth_login: bool = False
|
||||
@@ -90,6 +92,9 @@ class FeatureService:
|
||||
|
||||
cls._fulfill_params_from_enterprise(system_features)
|
||||
|
||||
if dify_config.MARKETPLACE_ENABLED:
|
||||
system_features.enable_marketplace = True
|
||||
|
||||
return system_features
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -14,7 +14,7 @@ from core.model_runtime.entities.provider_entities import (
|
||||
ModelCredentialSchema,
|
||||
ProviderCredentialSchema,
|
||||
)
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.provider_manager import ProviderManager
|
||||
from extensions.ext_database import db
|
||||
from models.provider import LoadBalancingModelConfig
|
||||
@@ -527,6 +527,7 @@ class ModelLoadBalancingService:
|
||||
credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key])
|
||||
|
||||
if validate:
|
||||
model_provider_factory = ModelProviderFactory(tenant_id)
|
||||
if isinstance(credential_schemas, ModelCredentialSchema):
|
||||
credentials = model_provider_factory.model_credentials_validate(
|
||||
provider=provider_configuration.provider.provider,
|
||||
|
||||
@@ -1,16 +1,9 @@
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, cast
|
||||
|
||||
import requests
|
||||
from flask import current_app
|
||||
from typing import Optional
|
||||
|
||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity
|
||||
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.provider_manager import ProviderManager
|
||||
from models.provider import ProviderType
|
||||
from services.entities.model_provider_entities import (
|
||||
@@ -54,6 +47,7 @@ class ModelProviderService:
|
||||
continue
|
||||
|
||||
provider_response = ProviderResponse(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_configuration.provider.provider,
|
||||
label=provider_configuration.provider.label,
|
||||
description=provider_configuration.provider.description,
|
||||
@@ -97,10 +91,11 @@ class ModelProviderService:
|
||||
|
||||
# Get provider available models
|
||||
return [
|
||||
ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider)
|
||||
ModelWithProviderEntityResponse(tenant_id=tenant_id, model=model)
|
||||
for model in provider_configurations.get_models(provider=provider)
|
||||
]
|
||||
|
||||
def get_provider_credentials(self, tenant_id: str, provider: str):
|
||||
def get_provider_credentials(self, tenant_id: str, provider: str) -> Optional[dict]:
|
||||
"""
|
||||
get provider credentials.
|
||||
"""
|
||||
@@ -168,7 +163,7 @@ class ModelProviderService:
|
||||
# Remove custom provider credentials.
|
||||
provider_configuration.delete_custom_credentials()
|
||||
|
||||
def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str):
|
||||
def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> Optional[dict]:
|
||||
"""
|
||||
get model credentials.
|
||||
|
||||
@@ -302,6 +297,7 @@ class ModelProviderService:
|
||||
|
||||
providers_with_models.append(
|
||||
ProviderWithModelsResponse(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
label=first_model.provider.label,
|
||||
icon_small=first_model.provider.icon_small,
|
||||
@@ -343,18 +339,17 @@ class ModelProviderService:
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Get model instance of LLM
|
||||
model_type_instance = provider_configuration.get_model_type_instance(ModelType.LLM)
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
# fetch credentials
|
||||
credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model)
|
||||
|
||||
if not credentials:
|
||||
return []
|
||||
|
||||
# Call get_parameter_rules method of model instance to get model parameter rules
|
||||
return list(model_type_instance.get_parameter_rules(model=model, credentials=credentials))
|
||||
model_schema = provider_configuration.get_model_schema(
|
||||
model_type=ModelType.LLM, model=model, credentials=credentials
|
||||
)
|
||||
|
||||
return model_schema.parameter_rules if model_schema else []
|
||||
|
||||
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]:
|
||||
"""
|
||||
@@ -365,13 +360,15 @@ class ModelProviderService:
|
||||
:return:
|
||||
"""
|
||||
model_type_enum = ModelType.value_of(model_type)
|
||||
result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum)
|
||||
|
||||
try:
|
||||
result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum)
|
||||
return (
|
||||
DefaultModelResponse(
|
||||
model=result.model,
|
||||
model_type=result.model_type,
|
||||
provider=SimpleProviderEntityResponse(
|
||||
tenant_id=tenant_id,
|
||||
provider=result.provider.provider,
|
||||
label=result.provider.label,
|
||||
icon_small=result.provider.icon_small,
|
||||
@@ -383,7 +380,7 @@ class ModelProviderService:
|
||||
else None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(f"get_default_model_of_model_type error: {e}")
|
||||
logger.debug(f"get_default_model_of_model_type error: {e}")
|
||||
return None
|
||||
|
||||
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
|
||||
@@ -402,55 +399,21 @@ class ModelProviderService:
|
||||
)
|
||||
|
||||
def get_model_provider_icon(
|
||||
self, provider: str, icon_type: str, lang: str
|
||||
self, tenant_id: str, provider: str, icon_type: str, lang: str
|
||||
) -> tuple[Optional[bytes], Optional[str]]:
|
||||
"""
|
||||
get model provider icon.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param icon_type: icon type (icon_small or icon_large)
|
||||
:param lang: language (zh_Hans or en_US)
|
||||
:return:
|
||||
"""
|
||||
provider_instance = model_provider_factory.get_provider_instance(provider)
|
||||
provider_schema = provider_instance.get_provider_schema()
|
||||
file_name: str | None = None
|
||||
model_provider_factory = ModelProviderFactory(tenant_id)
|
||||
byte_data, mime_type = model_provider_factory.get_provider_icon(provider, icon_type, lang)
|
||||
|
||||
if icon_type.lower() == "icon_small":
|
||||
if not provider_schema.icon_small:
|
||||
raise ValueError(f"Provider {provider} does not have small icon.")
|
||||
|
||||
if lang.lower() == "zh_hans":
|
||||
file_name = provider_schema.icon_small.zh_Hans
|
||||
else:
|
||||
file_name = provider_schema.icon_small.en_US
|
||||
else:
|
||||
if not provider_schema.icon_large:
|
||||
raise ValueError(f"Provider {provider} does not have large icon.")
|
||||
|
||||
if lang.lower() == "zh_hans":
|
||||
file_name = provider_schema.icon_large.zh_Hans
|
||||
else:
|
||||
file_name = provider_schema.icon_large.en_US
|
||||
if not file_name:
|
||||
return None, None
|
||||
|
||||
root_path = current_app.root_path
|
||||
provider_instance_path = os.path.dirname(
|
||||
os.path.join(root_path, provider_instance.__class__.__module__.replace(".", "/"))
|
||||
)
|
||||
file_path = os.path.join(provider_instance_path, "_assets")
|
||||
file_path = os.path.join(file_path, file_name)
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
return None, None
|
||||
|
||||
mimetype, _ = mimetypes.guess_type(file_path)
|
||||
mimetype = mimetype or "application/octet-stream"
|
||||
|
||||
# read binary from file
|
||||
byte_data = Path(file_path).read_bytes()
|
||||
return byte_data, mimetype
|
||||
return byte_data, mime_type
|
||||
|
||||
def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None:
|
||||
"""
|
||||
@@ -516,48 +479,3 @@ class ModelProviderService:
|
||||
|
||||
# Enable model
|
||||
provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type))
|
||||
|
||||
def free_quota_submit(self, tenant_id: str, provider: str):
|
||||
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
|
||||
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "")
|
||||
api_url = api_base_url + "/api/v1/providers/apply"
|
||||
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||
response = requests.post(api_url, headers=headers, json={"workspace_id": tenant_id, "provider_name": provider})
|
||||
if not response.ok:
|
||||
logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
|
||||
raise ValueError(f"Error: {response.status_code} ")
|
||||
|
||||
if response.json()["code"] != "success":
|
||||
raise ValueError(f"error: {response.json()['message']}")
|
||||
|
||||
rst = response.json()
|
||||
|
||||
if rst["type"] == "redirect":
|
||||
return {"type": rst["type"], "redirect_url": rst["redirect_url"]}
|
||||
else:
|
||||
return {"type": rst["type"], "result": "success"}
|
||||
|
||||
def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]):
|
||||
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
|
||||
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "")
|
||||
api_url = api_base_url + "/api/v1/providers/qualification-verify"
|
||||
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||
json_data = {"workspace_id": tenant_id, "provider_name": provider}
|
||||
if token:
|
||||
json_data["token"] = token
|
||||
response = requests.post(api_url, headers=headers, json=json_data)
|
||||
if not response.ok:
|
||||
logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
|
||||
raise ValueError(f"Error: {response.status_code} ")
|
||||
|
||||
rst = response.json()
|
||||
if rst["code"] != "success":
|
||||
raise ValueError(f"error: {rst['message']}")
|
||||
|
||||
data = rst["data"]
|
||||
if data["qualified"] is True:
|
||||
return {"result": "success", "provider_name": provider, "flag": True}
|
||||
else:
|
||||
return {"result": "success", "provider_name": provider, "flag": False, "reason": data["reason"]}
|
||||
|
||||
0
api/services/plugin/__init__.py
Normal file
0
api/services/plugin/__init__.py
Normal file
185
api/services/plugin/data_migration.py
Normal file
185
api/services/plugin/data_migration.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import click
|
||||
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from models.engine import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginDataMigration:
|
||||
@classmethod
|
||||
def migrate(cls) -> None:
|
||||
cls.migrate_db_records("providers", "provider_name") # large table
|
||||
cls.migrate_db_records("provider_models", "provider_name")
|
||||
cls.migrate_db_records("provider_orders", "provider_name")
|
||||
cls.migrate_db_records("tenant_default_models", "provider_name")
|
||||
cls.migrate_db_records("tenant_preferred_model_providers", "provider_name")
|
||||
cls.migrate_db_records("provider_model_settings", "provider_name")
|
||||
cls.migrate_db_records("load_balancing_model_configs", "provider_name")
|
||||
cls.migrate_datasets()
|
||||
cls.migrate_db_records("embeddings", "provider_name") # large table
|
||||
cls.migrate_db_records("dataset_collection_bindings", "provider_name")
|
||||
cls.migrate_db_records("tool_builtin_providers", "provider")
|
||||
|
||||
@classmethod
|
||||
def migrate_datasets(cls) -> None:
|
||||
table_name = "datasets"
|
||||
provider_column_name = "embedding_model_provider"
|
||||
|
||||
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
|
||||
|
||||
processed_count = 0
|
||||
failed_ids = []
|
||||
while True:
|
||||
sql = f"""select id, {provider_column_name} as provider_name, retrieval_model from {table_name}
|
||||
where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
|
||||
limit 1000"""
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql))
|
||||
|
||||
current_iter_count = 0
|
||||
for i in rs:
|
||||
record_id = str(i.id)
|
||||
provider_name = str(i.provider_name)
|
||||
retrieval_model = i.retrieval_model
|
||||
print(type(retrieval_model))
|
||||
|
||||
if record_id in failed_ids:
|
||||
continue
|
||||
|
||||
retrieval_model_changed = False
|
||||
if retrieval_model:
|
||||
if (
|
||||
"reranking_model" in retrieval_model
|
||||
and "reranking_provider_name" in retrieval_model["reranking_model"]
|
||||
and retrieval_model["reranking_model"]["reranking_provider_name"]
|
||||
and "/" not in retrieval_model["reranking_model"]["reranking_provider_name"]
|
||||
):
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Migrating {table_name} {record_id} "
|
||||
f"(reranking_provider_name: "
|
||||
f"{retrieval_model['reranking_model']['reranking_provider_name']})",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
retrieval_model["reranking_model"]["reranking_provider_name"] = (
|
||||
f"{DEFAULT_PLUGIN_ID}/{retrieval_model['reranking_model']['reranking_provider_name']}/{retrieval_model['reranking_model']['reranking_provider_name']}"
|
||||
)
|
||||
retrieval_model_changed = True
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# update provider name append with "langgenius/{provider_name}/{provider_name}"
|
||||
params = {"record_id": record_id}
|
||||
update_retrieval_model_sql = ""
|
||||
if retrieval_model and retrieval_model_changed:
|
||||
update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
|
||||
params["retrieval_model"] = json.dumps(retrieval_model)
|
||||
|
||||
sql = f"""update {table_name}
|
||||
set {provider_column_name} =
|
||||
concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
|
||||
{update_retrieval_model_sql}
|
||||
where id = :record_id"""
|
||||
conn.execute(db.text(sql), params)
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
failed_ids.append(record_id)
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
logger.exception(
|
||||
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})"
|
||||
)
|
||||
continue
|
||||
|
||||
current_iter_count += 1
|
||||
processed_count += 1
|
||||
|
||||
if not current_iter_count:
|
||||
break
|
||||
|
||||
click.echo(
|
||||
click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None:
|
||||
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
|
||||
|
||||
processed_count = 0
|
||||
failed_ids = []
|
||||
while True:
|
||||
sql = f"""select id, {provider_column_name} as provider_name from {table_name}
|
||||
where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
|
||||
limit 1000"""
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql))
|
||||
|
||||
current_iter_count = 0
|
||||
for i in rs:
|
||||
current_iter_count += 1
|
||||
processed_count += 1
|
||||
record_id = str(i.id)
|
||||
provider_name = str(i.provider_name)
|
||||
|
||||
if record_id in failed_ids:
|
||||
continue
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# update provider name append with "langgenius/{provider_name}/{provider_name}"
|
||||
sql = f"""update {table_name}
|
||||
set {provider_column_name} =
|
||||
concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
|
||||
where id = :record_id"""
|
||||
conn.execute(db.text(sql), {"record_id": record_id})
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
failed_ids.append(record_id)
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
logger.exception(
|
||||
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})"
|
||||
)
|
||||
continue
|
||||
|
||||
if not current_iter_count:
|
||||
break
|
||||
|
||||
click.echo(
|
||||
click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
|
||||
)
|
||||
128
api/services/plugin/dependencies_analysis.py
Normal file
128
api/services/plugin/dependencies_analysis.py
Normal file
@@ -0,0 +1,128 @@
|
||||
from core.helper import marketplace
|
||||
from core.plugin.entities.plugin import GenericProviderID, PluginDependency, PluginInstallationSource
|
||||
from core.plugin.manager.plugin import PluginInstallationManager
|
||||
|
||||
|
||||
class DependenciesAnalysisService:
|
||||
@classmethod
|
||||
def analyze_tool_dependency(cls, tool_id: str) -> str:
|
||||
"""
|
||||
Analyze the dependency of a tool.
|
||||
|
||||
Convert the tool id to the plugin_id
|
||||
"""
|
||||
try:
|
||||
tool_provider_id = GenericProviderID(tool_id)
|
||||
if tool_id in ["jina", "siliconflow"]:
|
||||
tool_provider_id.plugin_name = tool_provider_id.plugin_name + "_tool"
|
||||
return tool_provider_id.plugin_id
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@classmethod
|
||||
def analyze_model_provider_dependency(cls, model_provider_id: str) -> str:
|
||||
"""
|
||||
Analyze the dependency of a model provider.
|
||||
|
||||
Convert the model provider id to the plugin_id
|
||||
"""
|
||||
try:
|
||||
generic_provider_id = GenericProviderID(model_provider_id)
|
||||
if model_provider_id == "google":
|
||||
generic_provider_id.plugin_name = "gemini"
|
||||
|
||||
return generic_provider_id.plugin_id
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@classmethod
|
||||
def get_leaked_dependencies(cls, tenant_id: str, dependencies: list[PluginDependency]) -> list[PluginDependency]:
|
||||
"""
|
||||
Check dependencies, returns the leaked dependencies in current workspace
|
||||
"""
|
||||
required_plugin_unique_identifiers = []
|
||||
for dependency in dependencies:
|
||||
required_plugin_unique_identifiers.append(dependency.value.plugin_unique_identifier)
|
||||
|
||||
manager = PluginInstallationManager()
|
||||
|
||||
# get leaked dependencies
|
||||
missing_plugins = manager.fetch_missing_dependencies(tenant_id, required_plugin_unique_identifiers)
|
||||
missing_plugin_unique_identifiers = {plugin.plugin_unique_identifier: plugin for plugin in missing_plugins}
|
||||
|
||||
leaked_dependencies = []
|
||||
for dependency in dependencies:
|
||||
unique_identifier = dependency.value.plugin_unique_identifier
|
||||
if unique_identifier in missing_plugin_unique_identifiers:
|
||||
leaked_dependencies.append(
|
||||
PluginDependency(
|
||||
type=dependency.type,
|
||||
value=dependency.value,
|
||||
current_identifier=missing_plugin_unique_identifiers[unique_identifier].current_identifier,
|
||||
)
|
||||
)
|
||||
|
||||
return leaked_dependencies
|
||||
|
||||
@classmethod
|
||||
def generate_dependencies(cls, tenant_id: str, dependencies: list[str]) -> list[PluginDependency]:
|
||||
"""
|
||||
Generate dependencies through the list of plugin ids
|
||||
"""
|
||||
dependencies = list(set(dependencies))
|
||||
manager = PluginInstallationManager()
|
||||
plugins = manager.fetch_plugin_installation_by_ids(tenant_id, dependencies)
|
||||
result = []
|
||||
for plugin in plugins:
|
||||
if plugin.source == PluginInstallationSource.Github:
|
||||
result.append(
|
||||
PluginDependency(
|
||||
type=PluginDependency.Type.Github,
|
||||
value=PluginDependency.Github(
|
||||
repo=plugin.meta["repo"],
|
||||
version=plugin.meta["version"],
|
||||
package=plugin.meta["package"],
|
||||
github_plugin_unique_identifier=plugin.plugin_unique_identifier,
|
||||
),
|
||||
)
|
||||
)
|
||||
elif plugin.source == PluginInstallationSource.Marketplace:
|
||||
result.append(
|
||||
PluginDependency(
|
||||
type=PluginDependency.Type.Marketplace,
|
||||
value=PluginDependency.Marketplace(
|
||||
marketplace_plugin_unique_identifier=plugin.plugin_unique_identifier
|
||||
),
|
||||
)
|
||||
)
|
||||
elif plugin.source == PluginInstallationSource.Package:
|
||||
result.append(
|
||||
PluginDependency(
|
||||
type=PluginDependency.Type.Package,
|
||||
value=PluginDependency.Package(plugin_unique_identifier=plugin.plugin_unique_identifier),
|
||||
)
|
||||
)
|
||||
elif plugin.source == PluginInstallationSource.Remote:
|
||||
raise ValueError(
|
||||
f"You used a remote plugin: {plugin.plugin_unique_identifier} in the app, please remove it first"
|
||||
" if you want to export the DSL."
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin source: {plugin.source}")
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def generate_latest_dependencies(cls, dependencies: list[str]) -> list[PluginDependency]:
|
||||
"""
|
||||
Generate the latest version of dependencies
|
||||
"""
|
||||
dependencies = list(set(dependencies))
|
||||
deps = marketplace.batch_fetch_plugin_manifests(dependencies)
|
||||
return [
|
||||
PluginDependency(
|
||||
type=PluginDependency.Type.Marketplace,
|
||||
value=PluginDependency.Marketplace(marketplace_plugin_unique_identifier=dep.latest_package_identifier),
|
||||
)
|
||||
for dep in deps
|
||||
]
|
||||
66
api/services/plugin/endpoint_service.py
Normal file
66
api/services/plugin/endpoint_service.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from core.plugin.manager.endpoint import PluginEndpointManager
|
||||
|
||||
|
||||
class EndpointService:
|
||||
@classmethod
|
||||
def create_endpoint(cls, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict):
|
||||
return PluginEndpointManager().create_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
name=name,
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def list_endpoints(cls, tenant_id: str, user_id: str, page: int, page_size: int):
|
||||
return PluginEndpointManager().list_endpoints(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def list_endpoints_for_single_plugin(cls, tenant_id: str, user_id: str, plugin_id: str, page: int, page_size: int):
|
||||
return PluginEndpointManager().list_endpoints_for_single_plugin(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def update_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict):
|
||||
return PluginEndpointManager().update_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
endpoint_id=endpoint_id,
|
||||
name=name,
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def delete_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str):
|
||||
return PluginEndpointManager().delete_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
endpoint_id=endpoint_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def enable_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str):
|
||||
return PluginEndpointManager().enable_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
endpoint_id=endpoint_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def disable_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str):
|
||||
return PluginEndpointManager().disable_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
endpoint_id=endpoint_id,
|
||||
)
|
||||
0
api/services/plugin/github_service.py
Normal file
0
api/services/plugin/github_service.py
Normal file
523
api/services/plugin/plugin_migration.py
Normal file
523
api/services/plugin/plugin_migration.py
Normal file
@@ -0,0 +1,523 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import click
|
||||
import tqdm
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.helper import marketplace
|
||||
from core.plugin.entities.plugin import PluginInstallationSource
|
||||
from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus
|
||||
from core.plugin.manager.plugin import PluginInstallationManager
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from models.account import Tenant
|
||||
from models.engine import db
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from models.tools import BuiltinToolProvider
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
excluded_providers = ["time", "audio", "code", "webscraper"]
|
||||
|
||||
|
||||
class PluginMigration:
|
||||
@classmethod
|
||||
def extract_plugins(cls, filepath: str, workers: int) -> None:
|
||||
"""
|
||||
Migrate plugin.
|
||||
"""
|
||||
import concurrent.futures
|
||||
from threading import Lock
|
||||
|
||||
click.echo(click.style("Migrating models/tools to new plugin Mechanism", fg="white"))
|
||||
ended_at = datetime.datetime.now()
|
||||
started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
|
||||
current_time = started_at
|
||||
|
||||
with Session(db.engine) as session:
|
||||
total_tenant_count = session.query(Tenant.id).count()
|
||||
|
||||
click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
|
||||
|
||||
handled_tenant_count = 0
|
||||
file_lock = Lock()
|
||||
counter_lock = Lock()
|
||||
|
||||
thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=workers)
|
||||
|
||||
def process_tenant(flask_app: Flask, tenant_id: str) -> None:
|
||||
with flask_app.app_context():
|
||||
nonlocal handled_tenant_count
|
||||
try:
|
||||
plugins = cls.extract_installed_plugin_ids(tenant_id)
|
||||
# Use lock when writing to file
|
||||
with file_lock:
|
||||
with open(filepath, "a") as f:
|
||||
f.write(json.dumps({"tenant_id": tenant_id, "plugins": plugins}) + "\n")
|
||||
|
||||
# Use lock when updating counter
|
||||
with counter_lock:
|
||||
nonlocal handled_tenant_count
|
||||
handled_tenant_count += 1
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{datetime.datetime.now()}] "
|
||||
f"Processed {handled_tenant_count} tenants "
|
||||
f"({(handled_tenant_count / total_tenant_count) * 100:.1f}%), "
|
||||
f"{handled_tenant_count}/{total_tenant_count}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to process tenant {tenant_id}")
|
||||
|
||||
futures = []
|
||||
|
||||
while current_time < ended_at:
|
||||
click.echo(click.style(f"Current time: {current_time}, Started at: {datetime.datetime.now()}", fg="white"))
|
||||
# Initial interval of 1 day, will be dynamically adjusted based on tenant count
|
||||
interval = datetime.timedelta(days=1)
|
||||
# Process tenants in this batch
|
||||
with Session(db.engine) as session:
|
||||
# Calculate tenant count in next batch with current interval
|
||||
# Try different intervals until we find one with a reasonable tenant count
|
||||
test_intervals = [
|
||||
datetime.timedelta(days=1),
|
||||
datetime.timedelta(hours=12),
|
||||
datetime.timedelta(hours=6),
|
||||
datetime.timedelta(hours=3),
|
||||
datetime.timedelta(hours=1),
|
||||
]
|
||||
|
||||
for test_interval in test_intervals:
|
||||
tenant_count = (
|
||||
session.query(Tenant.id)
|
||||
.filter(Tenant.created_at.between(current_time, current_time + test_interval))
|
||||
.count()
|
||||
)
|
||||
if tenant_count <= 100:
|
||||
interval = test_interval
|
||||
break
|
||||
else:
|
||||
# If all intervals have too many tenants, use minimum interval
|
||||
interval = datetime.timedelta(hours=1)
|
||||
|
||||
# Adjust interval to target ~100 tenants per batch
|
||||
if tenant_count > 0:
|
||||
# Scale interval based on ratio to target count
|
||||
interval = min(
|
||||
datetime.timedelta(days=1), # Max 1 day
|
||||
max(
|
||||
datetime.timedelta(hours=1), # Min 1 hour
|
||||
interval * (100 / tenant_count), # Scale to target 100
|
||||
),
|
||||
)
|
||||
|
||||
batch_end = min(current_time + interval, ended_at)
|
||||
|
||||
rs = (
|
||||
session.query(Tenant.id)
|
||||
.filter(Tenant.created_at.between(current_time, batch_end))
|
||||
.order_by(Tenant.created_at)
|
||||
)
|
||||
|
||||
tenants = []
|
||||
for row in rs:
|
||||
tenant_id = str(row.id)
|
||||
try:
|
||||
tenants.append(tenant_id)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to process tenant {tenant_id}")
|
||||
continue
|
||||
|
||||
futures.append(
|
||||
thread_pool.submit(
|
||||
process_tenant,
|
||||
current_app._get_current_object(), # type: ignore[attr-defined]
|
||||
tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
current_time = batch_end
|
||||
|
||||
# wait for all threads to finish
|
||||
for future in futures:
|
||||
future.result()
|
||||
|
||||
@classmethod
|
||||
def extract_installed_plugin_ids(cls, tenant_id: str) -> Sequence[str]:
|
||||
"""
|
||||
Extract installed plugin ids.
|
||||
"""
|
||||
tools = cls.extract_tool_tables(tenant_id)
|
||||
models = cls.extract_model_tables(tenant_id)
|
||||
workflows = cls.extract_workflow_tables(tenant_id)
|
||||
apps = cls.extract_app_tables(tenant_id)
|
||||
|
||||
return list({*tools, *models, *workflows, *apps})
|
||||
|
||||
@classmethod
|
||||
def extract_model_tables(cls, tenant_id: str) -> Sequence[str]:
|
||||
"""
|
||||
Extract model tables.
|
||||
|
||||
"""
|
||||
models: list[str] = []
|
||||
table_pairs = [
|
||||
("providers", "provider_name"),
|
||||
("provider_models", "provider_name"),
|
||||
("provider_orders", "provider_name"),
|
||||
("tenant_default_models", "provider_name"),
|
||||
("tenant_preferred_model_providers", "provider_name"),
|
||||
("provider_model_settings", "provider_name"),
|
||||
("load_balancing_model_configs", "provider_name"),
|
||||
]
|
||||
|
||||
for table, column in table_pairs:
|
||||
models.extend(cls.extract_model_table(tenant_id, table, column))
|
||||
|
||||
# duplicate models
|
||||
models = list(set(models))
|
||||
|
||||
return models
|
||||
|
||||
@classmethod
|
||||
def extract_model_table(cls, tenant_id: str, table: str, column: str) -> Sequence[str]:
|
||||
"""
|
||||
Extract model table.
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
rs = session.execute(
|
||||
db.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id}
|
||||
)
|
||||
result = []
|
||||
for row in rs:
|
||||
provider_name = str(row[0])
|
||||
if provider_name and "/" not in provider_name:
|
||||
if provider_name == "google":
|
||||
provider_name = "gemini"
|
||||
|
||||
result.append(DEFAULT_PLUGIN_ID + "/" + provider_name)
|
||||
elif provider_name:
|
||||
result.append(provider_name)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def extract_tool_tables(cls, tenant_id: str) -> Sequence[str]:
|
||||
"""
|
||||
Extract tool tables.
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
rs = session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
|
||||
result = []
|
||||
for row in rs:
|
||||
if "/" not in row.provider:
|
||||
result.append(DEFAULT_PLUGIN_ID + "/" + row.provider)
|
||||
else:
|
||||
result.append(row.provider)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _handle_builtin_tool_provider(cls, provider_name: str) -> str:
|
||||
"""
|
||||
Handle builtin tool provider.
|
||||
"""
|
||||
if provider_name == "jina":
|
||||
provider_name = "jina_tool"
|
||||
elif provider_name == "siliconflow":
|
||||
provider_name = "siliconflow_tool"
|
||||
elif provider_name == "stepfun":
|
||||
provider_name = "stepfun_tool"
|
||||
|
||||
if "/" not in provider_name:
|
||||
return DEFAULT_PLUGIN_ID + "/" + provider_name
|
||||
else:
|
||||
return provider_name
|
||||
|
||||
@classmethod
|
||||
def extract_workflow_tables(cls, tenant_id: str) -> Sequence[str]:
|
||||
"""
|
||||
Extract workflow tables, only ToolNode is required.
|
||||
"""
|
||||
|
||||
with Session(db.engine) as session:
|
||||
rs = session.query(Workflow).filter(Workflow.tenant_id == tenant_id).all()
|
||||
result = []
|
||||
for row in rs:
|
||||
graph = row.graph_dict
|
||||
# get nodes
|
||||
nodes = graph.get("nodes", [])
|
||||
|
||||
for node in nodes:
|
||||
data = node.get("data", {})
|
||||
if data.get("type") == "tool":
|
||||
provider_name = data.get("provider_name")
|
||||
provider_type = data.get("provider_type")
|
||||
if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN.value:
|
||||
provider_name = cls._handle_builtin_tool_provider(provider_name)
|
||||
result.append(provider_name)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def extract_app_tables(cls, tenant_id: str) -> Sequence[str]:
|
||||
"""
|
||||
Extract app tables.
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
apps = session.query(App).filter(App.tenant_id == tenant_id).all()
|
||||
if not apps:
|
||||
return []
|
||||
|
||||
agent_app_model_config_ids = [
|
||||
app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value
|
||||
]
|
||||
|
||||
rs = session.query(AppModelConfig).filter(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
|
||||
result = []
|
||||
for row in rs:
|
||||
agent_config = row.agent_mode_dict
|
||||
if "tools" in agent_config and isinstance(agent_config["tools"], list):
|
||||
for tool in agent_config["tools"]:
|
||||
if isinstance(tool, dict):
|
||||
try:
|
||||
tool_entity = AgentToolEntity(**tool)
|
||||
if (
|
||||
tool_entity.provider_type == ToolProviderType.BUILT_IN.value
|
||||
and tool_entity.provider_id not in excluded_providers
|
||||
):
|
||||
result.append(cls._handle_builtin_tool_provider(tool_entity.provider_id))
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"Failed to process tool {tool}")
|
||||
continue
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> Optional[str]:
|
||||
"""
|
||||
Fetch plugin unique identifier using plugin id.
|
||||
"""
|
||||
plugin_manifest = marketplace.batch_fetch_plugin_manifests([plugin_id])
|
||||
if not plugin_manifest:
|
||||
return None
|
||||
|
||||
return plugin_manifest[0].latest_package_identifier
|
||||
|
||||
@classmethod
|
||||
def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str) -> None:
|
||||
"""
|
||||
Extract unique plugins.
|
||||
"""
|
||||
Path(output_file).write_text(json.dumps(cls.extract_unique_plugins(extracted_plugins)))
|
||||
|
||||
@classmethod
|
||||
def extract_unique_plugins(cls, extracted_plugins: str) -> Mapping[str, Any]:
|
||||
plugins: dict[str, str] = {}
|
||||
plugin_ids = []
|
||||
plugin_not_exist = []
|
||||
logger.info(f"Extracting unique plugins from {extracted_plugins}")
|
||||
with open(extracted_plugins) as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
new_plugin_ids = data.get("plugins", [])
|
||||
for plugin_id in new_plugin_ids:
|
||||
if plugin_id not in plugin_ids:
|
||||
plugin_ids.append(plugin_id)
|
||||
|
||||
def fetch_plugin(plugin_id):
|
||||
try:
|
||||
unique_identifier = cls._fetch_plugin_unique_identifier(plugin_id)
|
||||
if unique_identifier:
|
||||
plugins[plugin_id] = unique_identifier
|
||||
else:
|
||||
plugin_not_exist.append(plugin_id)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to fetch plugin unique identifier for {plugin_id}")
|
||||
plugin_not_exist.append(plugin_id)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
list(tqdm.tqdm(executor.map(fetch_plugin, plugin_ids), total=len(plugin_ids)))
|
||||
|
||||
return {"plugins": plugins, "plugin_not_exist": plugin_not_exist}
|
||||
|
||||
@classmethod
|
||||
def install_plugins(cls, extracted_plugins: str, output_file: str) -> None:
|
||||
"""
|
||||
Install plugins.
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
|
||||
plugins = cls.extract_unique_plugins(extracted_plugins)
|
||||
not_installed = []
|
||||
plugin_install_failed = []
|
||||
|
||||
# use a fake tenant id to install all the plugins
|
||||
fake_tenant_id = uuid4().hex
|
||||
logger.info(f"Installing {len(plugins['plugins'])} plugin instances for fake tenant {fake_tenant_id}")
|
||||
|
||||
thread_pool = ThreadPoolExecutor(max_workers=40)
|
||||
|
||||
response = cls.handle_plugin_instance_install(fake_tenant_id, plugins["plugins"])
|
||||
if response.get("failed"):
|
||||
plugin_install_failed.extend(response.get("failed", []))
|
||||
|
||||
def install(tenant_id: str, plugin_ids: list[str]) -> None:
|
||||
logger.info(f"Installing {len(plugin_ids)} plugins for tenant {tenant_id}")
|
||||
# at most 64 plugins one batch
|
||||
for i in range(0, len(plugin_ids), 64):
|
||||
batch_plugin_ids = plugin_ids[i : i + 64]
|
||||
batch_plugin_identifiers = [plugins["plugins"][plugin_id] for plugin_id in batch_plugin_ids]
|
||||
manager.install_from_identifiers(
|
||||
tenant_id,
|
||||
batch_plugin_identifiers,
|
||||
PluginInstallationSource.Marketplace,
|
||||
metas=[
|
||||
{
|
||||
"plugin_unique_identifier": identifier,
|
||||
}
|
||||
for identifier in batch_plugin_identifiers
|
||||
],
|
||||
)
|
||||
|
||||
with open(extracted_plugins) as f:
|
||||
"""
|
||||
Read line by line, and install plugins for each tenant.
|
||||
"""
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
tenant_id = data.get("tenant_id")
|
||||
plugin_ids = data.get("plugins", [])
|
||||
current_not_installed = {
|
||||
"tenant_id": tenant_id,
|
||||
"plugin_not_exist": [],
|
||||
}
|
||||
# get plugin unique identifier
|
||||
for plugin_id in plugin_ids:
|
||||
unique_identifier = plugins.get(plugin_id)
|
||||
if unique_identifier:
|
||||
current_not_installed["plugin_not_exist"].append(plugin_id)
|
||||
|
||||
if current_not_installed["plugin_not_exist"]:
|
||||
not_installed.append(current_not_installed)
|
||||
|
||||
thread_pool.submit(install, tenant_id, plugin_ids)
|
||||
|
||||
thread_pool.shutdown(wait=True)
|
||||
|
||||
logger.info("Uninstall plugins")
|
||||
|
||||
# get installation
|
||||
try:
|
||||
installation = manager.list_plugins(fake_tenant_id)
|
||||
while installation:
|
||||
for plugin in installation:
|
||||
manager.uninstall(fake_tenant_id, plugin.installation_id)
|
||||
|
||||
installation = manager.list_plugins(fake_tenant_id)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to get installation for tenant {fake_tenant_id}")
|
||||
|
||||
Path(output_file).write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"not_installed": not_installed,
|
||||
"plugin_install_failed": plugin_install_failed,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def handle_plugin_instance_install(
|
||||
cls, tenant_id: str, plugin_identifiers_map: Mapping[str, str]
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Install plugins for a tenant.
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
|
||||
# download all the plugins and upload
|
||||
thread_pool = ThreadPoolExecutor(max_workers=10)
|
||||
futures = []
|
||||
|
||||
for plugin_id, plugin_identifier in plugin_identifiers_map.items():
|
||||
|
||||
def download_and_upload(tenant_id, plugin_id, plugin_identifier):
|
||||
plugin_package = marketplace.download_plugin_pkg(plugin_identifier)
|
||||
if not plugin_package:
|
||||
raise Exception(f"Failed to download plugin {plugin_identifier}")
|
||||
|
||||
# upload
|
||||
manager.upload_pkg(tenant_id, plugin_package, verify_signature=True)
|
||||
|
||||
futures.append(thread_pool.submit(download_and_upload, tenant_id, plugin_id, plugin_identifier))
|
||||
|
||||
# Wait for all downloads to complete
|
||||
for future in futures:
|
||||
future.result() # This will raise any exceptions that occurred
|
||||
|
||||
thread_pool.shutdown(wait=True)
|
||||
success = []
|
||||
failed = []
|
||||
|
||||
reverse_map = {v: k for k, v in plugin_identifiers_map.items()}
|
||||
|
||||
# at most 8 plugins one batch
|
||||
for i in range(0, len(plugin_identifiers_map), 8):
|
||||
batch_plugin_ids = list(plugin_identifiers_map.keys())[i : i + 8]
|
||||
batch_plugin_identifiers = [plugin_identifiers_map[plugin_id] for plugin_id in batch_plugin_ids]
|
||||
|
||||
try:
|
||||
response = manager.install_from_identifiers(
|
||||
tenant_id=tenant_id,
|
||||
identifiers=batch_plugin_identifiers,
|
||||
source=PluginInstallationSource.Marketplace,
|
||||
metas=[
|
||||
{
|
||||
"plugin_unique_identifier": identifier,
|
||||
}
|
||||
for identifier in batch_plugin_identifiers
|
||||
],
|
||||
)
|
||||
except Exception:
|
||||
# add to failed
|
||||
failed.extend(batch_plugin_identifiers)
|
||||
continue
|
||||
|
||||
if response.all_installed:
|
||||
success.extend(batch_plugin_identifiers)
|
||||
continue
|
||||
|
||||
task_id = response.task_id
|
||||
done = False
|
||||
while not done:
|
||||
status = manager.fetch_plugin_installation_task(tenant_id, task_id)
|
||||
if status.status in [PluginInstallTaskStatus.Failed, PluginInstallTaskStatus.Success]:
|
||||
for plugin in status.plugins:
|
||||
if plugin.status == PluginInstallTaskStatus.Success:
|
||||
success.append(reverse_map[plugin.plugin_unique_identifier])
|
||||
else:
|
||||
failed.append(reverse_map[plugin.plugin_unique_identifier])
|
||||
logger.error(
|
||||
f"Failed to install plugin {plugin.plugin_unique_identifier}, error: {plugin.message}"
|
||||
)
|
||||
|
||||
done = True
|
||||
else:
|
||||
time.sleep(1)
|
||||
|
||||
return {"success": success, "failed": failed}
|
||||
34
api/services/plugin/plugin_permission_service.py
Normal file
34
api/services/plugin/plugin_permission_service.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantPluginPermission
|
||||
|
||||
|
||||
class PluginPermissionService:
|
||||
@staticmethod
|
||||
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
|
||||
with Session(db.engine) as session:
|
||||
return session.query(TenantPluginPermission).filter(TenantPluginPermission.tenant_id == tenant_id).first()
|
||||
|
||||
@staticmethod
|
||||
def change_permission(
|
||||
tenant_id: str,
|
||||
install_permission: TenantPluginPermission.InstallPermission,
|
||||
debug_permission: TenantPluginPermission.DebugPermission,
|
||||
):
|
||||
with Session(db.engine) as session:
|
||||
permission = (
|
||||
session.query(TenantPluginPermission).filter(TenantPluginPermission.tenant_id == tenant_id).first()
|
||||
)
|
||||
if not permission:
|
||||
permission = TenantPluginPermission(
|
||||
tenant_id=tenant_id, install_permission=install_permission, debug_permission=debug_permission
|
||||
)
|
||||
|
||||
session.add(permission)
|
||||
else:
|
||||
permission.install_permission = install_permission
|
||||
permission.debug_permission = debug_permission
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
301
api/services/plugin/plugin_service.py
Normal file
301
api/services/plugin/plugin_service.py
Normal file
@@ -0,0 +1,301 @@
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from mimetypes import guess_type
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import marketplace
|
||||
from core.helper.download import download_with_size_limit
|
||||
from core.helper.marketplace import download_plugin_pkg
|
||||
from core.plugin.entities.bundle import PluginBundleDependency
|
||||
from core.plugin.entities.plugin import (
|
||||
GenericProviderID,
|
||||
PluginDeclaration,
|
||||
PluginEntity,
|
||||
PluginInstallation,
|
||||
PluginInstallationSource,
|
||||
)
|
||||
from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginUploadResponse
|
||||
from core.plugin.manager.asset import PluginAssetManager
|
||||
from core.plugin.manager.debugging import PluginDebuggingManager
|
||||
from core.plugin.manager.plugin import PluginInstallationManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginService:
|
||||
@staticmethod
|
||||
def get_debugging_key(tenant_id: str) -> str:
|
||||
"""
|
||||
get the debugging key of the tenant
|
||||
"""
|
||||
manager = PluginDebuggingManager()
|
||||
return manager.get_debugging_key(tenant_id)
|
||||
|
||||
@staticmethod
|
||||
def list(tenant_id: str) -> list[PluginEntity]:
|
||||
"""
|
||||
list all plugins of the tenant
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
plugins = manager.list_plugins(tenant_id)
|
||||
plugin_ids = [plugin.plugin_id for plugin in plugins if plugin.source == PluginInstallationSource.Marketplace]
|
||||
try:
|
||||
manifests = {
|
||||
manifest.plugin_id: manifest for manifest in marketplace.batch_fetch_plugin_manifests(plugin_ids)
|
||||
}
|
||||
except Exception:
|
||||
manifests = {}
|
||||
logger.exception("failed to fetch plugin manifests")
|
||||
|
||||
for plugin in plugins:
|
||||
if plugin.source == PluginInstallationSource.Marketplace:
|
||||
if plugin.plugin_id in manifests:
|
||||
# set latest_version
|
||||
plugin.latest_version = manifests[plugin.plugin_id].latest_version
|
||||
plugin.latest_unique_identifier = manifests[plugin.plugin_id].latest_package_identifier
|
||||
|
||||
return plugins
|
||||
|
||||
@staticmethod
|
||||
def list_installations_from_ids(tenant_id: str, ids: Sequence[str]) -> Sequence[PluginInstallation]:
|
||||
"""
|
||||
List plugin installations from ids
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.fetch_plugin_installation_by_ids(tenant_id, ids)
|
||||
|
||||
@staticmethod
|
||||
def get_asset(tenant_id: str, asset_file: str) -> tuple[bytes, str]:
|
||||
"""
|
||||
get the asset file of the plugin
|
||||
"""
|
||||
manager = PluginAssetManager()
|
||||
# guess mime type
|
||||
mime_type, _ = guess_type(asset_file)
|
||||
return manager.fetch_asset(tenant_id, asset_file), mime_type or "application/octet-stream"
|
||||
|
||||
@staticmethod
|
||||
def check_plugin_unique_identifier(tenant_id: str, plugin_unique_identifier: str) -> bool:
|
||||
"""
|
||||
check if the plugin unique identifier is already installed by other tenant
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.fetch_plugin_by_identifier(tenant_id, plugin_unique_identifier)
|
||||
|
||||
@staticmethod
|
||||
def fetch_plugin_manifest(tenant_id: str, plugin_unique_identifier: str) -> PluginDeclaration:
|
||||
"""
|
||||
Fetch plugin manifest
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
|
||||
|
||||
@staticmethod
|
||||
def fetch_install_tasks(tenant_id: str, page: int, page_size: int) -> Sequence[PluginInstallTask]:
|
||||
"""
|
||||
Fetch plugin installation tasks
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.fetch_plugin_installation_tasks(tenant_id, page, page_size)
|
||||
|
||||
@staticmethod
|
||||
def fetch_install_task(tenant_id: str, task_id: str) -> PluginInstallTask:
|
||||
manager = PluginInstallationManager()
|
||||
return manager.fetch_plugin_installation_task(tenant_id, task_id)
|
||||
|
||||
@staticmethod
|
||||
def delete_install_task(tenant_id: str, task_id: str) -> bool:
|
||||
"""
|
||||
Delete a plugin installation task
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.delete_plugin_installation_task(tenant_id, task_id)
|
||||
|
||||
@staticmethod
|
||||
def delete_all_install_task_items(
|
||||
tenant_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Delete all plugin installation task items
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.delete_all_plugin_installation_task_items(tenant_id)
|
||||
|
||||
@staticmethod
|
||||
def delete_install_task_item(tenant_id: str, task_id: str, identifier: str) -> bool:
|
||||
"""
|
||||
Delete a plugin installation task item
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.delete_plugin_installation_task_item(tenant_id, task_id, identifier)
|
||||
|
||||
@staticmethod
|
||||
def upgrade_plugin_with_marketplace(
|
||||
tenant_id: str, original_plugin_unique_identifier: str, new_plugin_unique_identifier: str
|
||||
):
|
||||
"""
|
||||
Upgrade plugin with marketplace
|
||||
"""
|
||||
if original_plugin_unique_identifier == new_plugin_unique_identifier:
|
||||
raise ValueError("you should not upgrade plugin with the same plugin")
|
||||
|
||||
# check if plugin pkg is already downloaded
|
||||
manager = PluginInstallationManager()
|
||||
|
||||
try:
|
||||
manager.fetch_plugin_manifest(tenant_id, new_plugin_unique_identifier)
|
||||
# already downloaded, skip, and record install event
|
||||
marketplace.record_install_plugin_event(new_plugin_unique_identifier)
|
||||
except Exception:
|
||||
# plugin not installed, download and upload pkg
|
||||
pkg = download_plugin_pkg(new_plugin_unique_identifier)
|
||||
manager.upload_pkg(tenant_id, pkg, verify_signature=False)
|
||||
|
||||
return manager.upgrade_plugin(
|
||||
tenant_id,
|
||||
original_plugin_unique_identifier,
|
||||
new_plugin_unique_identifier,
|
||||
PluginInstallationSource.Marketplace,
|
||||
{
|
||||
"plugin_unique_identifier": new_plugin_unique_identifier,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def upgrade_plugin_with_github(
|
||||
tenant_id: str,
|
||||
original_plugin_unique_identifier: str,
|
||||
new_plugin_unique_identifier: str,
|
||||
repo: str,
|
||||
version: str,
|
||||
package: str,
|
||||
):
|
||||
"""
|
||||
Upgrade plugin with github
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.upgrade_plugin(
|
||||
tenant_id,
|
||||
original_plugin_unique_identifier,
|
||||
new_plugin_unique_identifier,
|
||||
PluginInstallationSource.Github,
|
||||
{
|
||||
"repo": repo,
|
||||
"version": version,
|
||||
"package": package,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def upload_pkg(tenant_id: str, pkg: bytes, verify_signature: bool = False) -> PluginUploadResponse:
|
||||
"""
|
||||
Upload plugin package files
|
||||
|
||||
returns: plugin_unique_identifier
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.upload_pkg(tenant_id, pkg, verify_signature)
|
||||
|
||||
@staticmethod
|
||||
def upload_pkg_from_github(
|
||||
tenant_id: str, repo: str, version: str, package: str, verify_signature: bool = False
|
||||
) -> PluginUploadResponse:
|
||||
"""
|
||||
Install plugin from github release package files,
|
||||
returns plugin_unique_identifier
|
||||
"""
|
||||
pkg = download_with_size_limit(
|
||||
f"https://github.com/{repo}/releases/download/{version}/{package}", dify_config.PLUGIN_MAX_PACKAGE_SIZE
|
||||
)
|
||||
|
||||
manager = PluginInstallationManager()
|
||||
return manager.upload_pkg(
|
||||
tenant_id,
|
||||
pkg,
|
||||
verify_signature,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def upload_bundle(
|
||||
tenant_id: str, bundle: bytes, verify_signature: bool = False
|
||||
) -> Sequence[PluginBundleDependency]:
|
||||
"""
|
||||
Upload a plugin bundle and return the dependencies.
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.upload_bundle(tenant_id, bundle, verify_signature)
|
||||
|
||||
@staticmethod
|
||||
def install_from_local_pkg(tenant_id: str, plugin_unique_identifiers: Sequence[str]):
|
||||
manager = PluginInstallationManager()
|
||||
return manager.install_from_identifiers(
|
||||
tenant_id,
|
||||
plugin_unique_identifiers,
|
||||
PluginInstallationSource.Package,
|
||||
[{}],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def install_from_github(tenant_id: str, plugin_unique_identifier: str, repo: str, version: str, package: str):
|
||||
"""
|
||||
Install plugin from github release package files,
|
||||
returns plugin_unique_identifier
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.install_from_identifiers(
|
||||
tenant_id,
|
||||
[plugin_unique_identifier],
|
||||
PluginInstallationSource.Github,
|
||||
[
|
||||
{
|
||||
"repo": repo,
|
||||
"version": version,
|
||||
"package": package,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def install_from_marketplace_pkg(
|
||||
tenant_id: str, plugin_unique_identifiers: Sequence[str], verify_signature: bool = False
|
||||
):
|
||||
"""
|
||||
Install plugin from marketplace package files,
|
||||
returns installation task id
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
|
||||
# check if already downloaded
|
||||
for plugin_unique_identifier in plugin_unique_identifiers:
|
||||
try:
|
||||
manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
|
||||
# already downloaded, skip
|
||||
except Exception:
|
||||
# plugin not installed, download and upload pkg
|
||||
pkg = download_plugin_pkg(plugin_unique_identifier)
|
||||
manager.upload_pkg(tenant_id, pkg, verify_signature)
|
||||
|
||||
return manager.install_from_identifiers(
|
||||
tenant_id,
|
||||
plugin_unique_identifiers,
|
||||
PluginInstallationSource.Marketplace,
|
||||
[
|
||||
{
|
||||
"plugin_unique_identifier": plugin_unique_identifier,
|
||||
}
|
||||
for plugin_unique_identifier in plugin_unique_identifiers
|
||||
],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
|
||||
manager = PluginInstallationManager()
|
||||
return manager.uninstall(tenant_id, plugin_installation_id)
|
||||
|
||||
@staticmethod
|
||||
def check_tools_existence(tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]:
|
||||
"""
|
||||
Check if the tools exist
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.check_tools_existence(tenant_id, provider_ids)
|
||||
@@ -1,24 +1,24 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from httpx import get
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ApiProviderSchemaType,
|
||||
ToolCredentialsOption,
|
||||
ToolProviderCredentials,
|
||||
)
|
||||
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolConfigurationManager
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
||||
from core.tools.utils.parser import ApiBasedToolSchemaParser
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider
|
||||
@@ -41,28 +41,28 @@ class ApiToolManageService:
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
|
||||
credentials_schema = [
|
||||
ToolProviderCredentials(
|
||||
ProviderConfig(
|
||||
name="auth_type",
|
||||
type=ToolProviderCredentials.CredentialsType.SELECT,
|
||||
type=ProviderConfig.Type.SELECT,
|
||||
required=True,
|
||||
default="none",
|
||||
options=[
|
||||
ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
|
||||
ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
|
||||
ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
|
||||
ProviderConfig.Option(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
|
||||
],
|
||||
placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
|
||||
),
|
||||
ToolProviderCredentials(
|
||||
ProviderConfig(
|
||||
name="api_key_header",
|
||||
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
||||
type=ProviderConfig.Type.TEXT_INPUT,
|
||||
required=False,
|
||||
placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"),
|
||||
default="api_key",
|
||||
help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"),
|
||||
),
|
||||
ToolProviderCredentials(
|
||||
ProviderConfig(
|
||||
name="api_key_value",
|
||||
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
||||
type=ProviderConfig.Type.TEXT_INPUT,
|
||||
required=False,
|
||||
placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"),
|
||||
default="",
|
||||
@@ -84,17 +84,14 @@ class ApiToolManageService:
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def convert_schema_to_tool_bundles(
|
||||
schema: str, extra_info: Optional[dict] = None
|
||||
) -> tuple[list[ApiToolBundle], str]:
|
||||
def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]:
|
||||
"""
|
||||
convert schema to tool bundles
|
||||
|
||||
:return: the list of tool bundles, description
|
||||
"""
|
||||
try:
|
||||
tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
|
||||
return tool_bundles
|
||||
return ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
|
||||
except Exception as e:
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
|
||||
@@ -167,8 +164,14 @@ class ApiToolManageService:
|
||||
provider_controller.load_bundled_tools(tool_bundles)
|
||||
|
||||
# encrypt credentials
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=list(provider_controller.get_credentials_schema()),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
|
||||
encrypted_credentials = tool_configuration.encrypt(credentials)
|
||||
db_provider.credentials_str = json.dumps(encrypted_credentials)
|
||||
|
||||
db.session.add(db_provider)
|
||||
@@ -198,18 +201,18 @@ class ApiToolManageService:
|
||||
|
||||
# try to parse schema, avoid SSRF attack
|
||||
ApiToolManageService.parser_api_schema(schema)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("parse api schema error")
|
||||
raise ValueError("invalid schema, please check the url you provided")
|
||||
|
||||
return {"schema": schema}
|
||||
|
||||
@staticmethod
|
||||
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[UserTool]:
|
||||
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[ToolApiEntity]:
|
||||
"""
|
||||
list api tool provider tools
|
||||
"""
|
||||
provider = (
|
||||
provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
@@ -225,8 +228,9 @@ class ApiToolManageService:
|
||||
labels = ToolLabelManager.get_tool_labels(controller)
|
||||
|
||||
return [
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tool_bundle,
|
||||
tenant_id=tenant_id,
|
||||
labels=labels,
|
||||
)
|
||||
for tool_bundle in provider.tools
|
||||
@@ -293,16 +297,21 @@ class ApiToolManageService:
|
||||
provider_controller.load_bundled_tools(tool_bundles)
|
||||
|
||||
# get original credentials if exists
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=list(provider_controller.get_credentials_schema()),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
|
||||
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
|
||||
original_credentials = tool_configuration.decrypt(provider.credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
|
||||
# check if the credential has changed, save the original credential
|
||||
for name, value in credentials.items():
|
||||
if name in masked_credentials and value == masked_credentials[name]:
|
||||
credentials[name] = original_credentials[name]
|
||||
|
||||
credentials = tool_configuration.encrypt_tool_credentials(credentials)
|
||||
credentials = tool_configuration.encrypt(credentials)
|
||||
provider.credentials_str = json.dumps(credentials)
|
||||
|
||||
db.session.add(provider)
|
||||
@@ -363,7 +372,7 @@ class ApiToolManageService:
|
||||
|
||||
try:
|
||||
tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise ValueError("invalid schema")
|
||||
|
||||
# get tool bundle
|
||||
@@ -407,8 +416,13 @@ class ApiToolManageService:
|
||||
|
||||
# decrypt credentials
|
||||
if db_provider.id:
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=list(provider_controller.get_credentials_schema()),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
||||
# check if the credential has changed, save the original credential
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
||||
for name, value in credentials.items():
|
||||
@@ -419,20 +433,20 @@ class ApiToolManageService:
|
||||
provider_controller.validate_credentials_format(credentials)
|
||||
# get tool
|
||||
tool = provider_controller.get_tool(tool_name)
|
||||
runtime_tool = tool.fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
tool = tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
credentials=credentials,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
result = runtime_tool.validate_credentials(credentials, parameters)
|
||||
result = tool.validate_credentials(credentials, parameters)
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
return {"result": result or "empty response"}
|
||||
|
||||
@staticmethod
|
||||
def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||
"""
|
||||
list api tools
|
||||
"""
|
||||
@@ -441,7 +455,7 @@ class ApiToolManageService:
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or []
|
||||
)
|
||||
|
||||
result: list[UserToolProvider] = []
|
||||
result: list[ToolProviderApiEntity] = []
|
||||
|
||||
for provider in db_providers:
|
||||
# convert provider controller to user provider
|
||||
@@ -453,13 +467,13 @@ class ApiToolManageService:
|
||||
user_provider.labels = labels
|
||||
|
||||
# add icon
|
||||
ToolTransformService.repack_provider(user_provider)
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_provider)
|
||||
|
||||
tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
|
||||
tools = provider_controller.get_tools(tenant_id=tenant_id)
|
||||
|
||||
for tool in tools or []:
|
||||
user_provider.tools.append(
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
|
||||
)
|
||||
)
|
||||
|
||||
@@ -2,19 +2,19 @@ import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
|
||||
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
||||
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolConfigurationManager
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
||||
from extensions.ext_database import db
|
||||
from models.tools import BuiltinToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
@@ -24,36 +24,38 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class BuiltinToolManageService:
|
||||
@staticmethod
|
||||
def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
|
||||
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
|
||||
"""
|
||||
list builtin tool provider tools
|
||||
|
||||
:param user_id: the id of the user
|
||||
:param tenant_id: the id of the tenant
|
||||
:param provider: the name of the provider
|
||||
|
||||
:return: the list of tools
|
||||
"""
|
||||
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
tools = provider_controller.get_tools()
|
||||
|
||||
tool_provider_configurations = ToolConfigurationManager(
|
||||
tenant_id=tenant_id, provider_controller=provider_controller
|
||||
tool_provider_configurations = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
# check if user has added the provider
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
|
||||
|
||||
credentials = {}
|
||||
if builtin_provider is not None:
|
||||
# get credentials
|
||||
credentials = builtin_provider.credentials
|
||||
credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)
|
||||
credentials = tool_provider_configurations.decrypt(credentials)
|
||||
|
||||
result: list[UserTool] = []
|
||||
result: list[ToolApiEntity] = []
|
||||
for tool in tools or []:
|
||||
result.append(
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tool=tool,
|
||||
credentials=credentials,
|
||||
tenant_id=tenant_id,
|
||||
@@ -64,14 +66,47 @@ class BuiltinToolManageService:
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_provider_credentials_schema(provider_name):
|
||||
def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str):
|
||||
"""
|
||||
get builtin tool provider info
|
||||
"""
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
tool_provider_configurations = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
# check if user has added the provider
|
||||
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
|
||||
|
||||
credentials = {}
|
||||
if builtin_provider is not None:
|
||||
# get credentials
|
||||
credentials = builtin_provider.credentials
|
||||
credentials = tool_provider_configurations.decrypt(credentials)
|
||||
|
||||
entity = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=builtin_provider,
|
||||
decrypt_credentials=True,
|
||||
)
|
||||
|
||||
entity.original_credentials = {}
|
||||
|
||||
return entity
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
|
||||
"""
|
||||
list builtin provider credentials schema
|
||||
|
||||
:param provider_name: the name of the provider
|
||||
:param tenant_id: the id of the tenant
|
||||
:return: the list of tool providers
|
||||
"""
|
||||
provider = ToolManager.get_builtin_provider(provider_name)
|
||||
return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()])
|
||||
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
return jsonable_encoder(provider.get_credentials_schema())
|
||||
|
||||
@staticmethod
|
||||
def update_builtin_tool_provider(
|
||||
@@ -81,31 +116,38 @@ class BuiltinToolManageService:
|
||||
update builtin tool provider
|
||||
"""
|
||||
# get if the provider exists
|
||||
stmt = select(BuiltinToolProvider).where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider_name,
|
||||
)
|
||||
provider = session.scalar(stmt)
|
||||
provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
|
||||
|
||||
try:
|
||||
# get provider
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
if not provider_controller.need_credentials:
|
||||
raise ValueError(f"provider {provider_name} does not need credentials")
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
|
||||
# get original credentials if exists
|
||||
if provider is not None:
|
||||
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
|
||||
original_credentials = tool_configuration.decrypt(provider.credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
|
||||
# check if the credential has changed, save the original credential
|
||||
for name, value in credentials.items():
|
||||
if name in masked_credentials and value == masked_credentials[name]:
|
||||
credentials[name] = original_credentials[name]
|
||||
# validate credentials
|
||||
provider_controller.validate_credentials(credentials)
|
||||
provider_controller.validate_credentials(user_id, credentials)
|
||||
# encrypt credentials
|
||||
credentials = tool_configuration.encrypt_tool_credentials(credentials)
|
||||
except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e:
|
||||
credentials = tool_configuration.encrypt(credentials)
|
||||
except (
|
||||
PluginDaemonClientSideError,
|
||||
ToolProviderNotFoundError,
|
||||
ToolNotFoundError,
|
||||
ToolProviderCredentialValidationError,
|
||||
) as e:
|
||||
raise ValueError(str(e))
|
||||
|
||||
if provider is None:
|
||||
@@ -117,14 +159,14 @@ class BuiltinToolManageService:
|
||||
encrypted_credentials=json.dumps(credentials),
|
||||
)
|
||||
|
||||
session.add(provider)
|
||||
|
||||
db.session.add(provider)
|
||||
else:
|
||||
provider.encrypted_credentials = json.dumps(credentials)
|
||||
|
||||
# delete cache
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
db.session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
@@ -132,21 +174,19 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
get builtin tool provider credentials
|
||||
"""
|
||||
provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
|
||||
|
||||
if provider is None:
|
||||
if provider_obj is None:
|
||||
return {}
|
||||
|
||||
provider_controller = ToolManager.get_builtin_provider(provider.provider)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
credentials = tool_configuration.decrypt(provider_obj.credentials)
|
||||
credentials = tool_configuration.mask_tool_credentials(credentials)
|
||||
return credentials
|
||||
|
||||
@@ -155,24 +195,22 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
delete tool provider
|
||||
"""
|
||||
provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
|
||||
|
||||
if provider is None:
|
||||
if provider_obj is None:
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
|
||||
db.session.delete(provider)
|
||||
db.session.delete(provider_obj)
|
||||
db.session.commit()
|
||||
|
||||
# delete cache
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
return {"result": "success"}
|
||||
@@ -182,58 +220,59 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
get tool provider icon and it's mimetype
|
||||
"""
|
||||
icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider)
|
||||
icon_path, mime_type = ToolManager.get_hardcoded_provider_icon(provider)
|
||||
icon_bytes = Path(icon_path).read_bytes()
|
||||
|
||||
return icon_bytes, mime_type
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
def list_builtin_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||
"""
|
||||
list builtin tools
|
||||
"""
|
||||
# get all builtin providers
|
||||
provider_controllers = ToolManager.list_builtin_providers()
|
||||
provider_controllers = ToolManager.list_builtin_providers(tenant_id)
|
||||
|
||||
# get all user added providers
|
||||
db_providers: list[BuiltinToolProvider] = (
|
||||
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
|
||||
)
|
||||
|
||||
# find provider
|
||||
find_provider = lambda provider: next(
|
||||
filter(lambda db_provider: db_provider.provider == provider, db_providers), None
|
||||
)
|
||||
# rewrite db_providers
|
||||
for db_provider in db_providers:
|
||||
db_provider.provider = str(ToolProviderID(db_provider.provider))
|
||||
|
||||
result: list[UserToolProvider] = []
|
||||
# find provider
|
||||
def find_provider(provider):
|
||||
return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
|
||||
|
||||
result: list[ToolProviderApiEntity] = []
|
||||
|
||||
for provider_controller in provider_controllers:
|
||||
try:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
|
||||
data=provider_controller,
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
continue
|
||||
if provider_controller.identity is None:
|
||||
continue
|
||||
|
||||
# convert provider controller to user provider
|
||||
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=find_provider(provider_controller.identity.name),
|
||||
db_provider=find_provider(provider_controller.entity.identity.name),
|
||||
decrypt_credentials=True,
|
||||
)
|
||||
|
||||
# add icon
|
||||
ToolTransformService.repack_provider(user_builtin_provider)
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
|
||||
|
||||
tools = provider_controller.get_tools()
|
||||
for tool in tools or []:
|
||||
user_builtin_provider.tools.append(
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tenant_id=tenant_id,
|
||||
tool=tool,
|
||||
credentials=user_builtin_provider.original_credentials,
|
||||
@@ -246,3 +285,45 @@ class BuiltinToolManageService:
|
||||
raise e
|
||||
|
||||
return BuiltinToolProviderSort.sort(result)
|
||||
|
||||
@staticmethod
|
||||
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
|
||||
try:
|
||||
full_provider_name = provider_name
|
||||
provider_id_entity = GenericProviderID(provider_name)
|
||||
provider_name = provider_id_entity.provider_name
|
||||
if provider_id_entity.organization != "langgenius":
|
||||
provider_obj = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == full_provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
provider_obj = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == provider_name)
|
||||
| (BuiltinToolProvider.provider == full_provider_name),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider_obj is None:
|
||||
return None
|
||||
|
||||
provider_obj.provider = GenericProviderID(provider_obj.provider).to_string()
|
||||
return provider_obj
|
||||
except Exception:
|
||||
# it's an old provider without organization
|
||||
return (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == provider_name),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
from core.tools.entities.api_entities import UserToolProviderTypeLiteral
|
||||
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
@@ -9,17 +9,17 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ToolCommonService:
|
||||
@staticmethod
|
||||
def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None):
|
||||
def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral = None):
|
||||
"""
|
||||
list tool providers
|
||||
|
||||
:return: the list of tool providers
|
||||
"""
|
||||
providers = ToolManager.user_list_providers(user_id, tenant_id, typ)
|
||||
providers = ToolManager.list_providers_from_api(user_id, tenant_id, typ)
|
||||
|
||||
# add icon
|
||||
for provider in providers:
|
||||
ToolTransformService.repack_provider(provider)
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
|
||||
|
||||
result = [provider.to_dict() for provider in providers]
|
||||
|
||||
|
||||
@@ -2,47 +2,57 @@ import json
|
||||
import logging
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolParameter,
|
||||
ToolProviderCredentials,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool.workflow_tool import WorkflowTool
|
||||
from core.tools.utils.configuration import ToolConfigurationManager
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolTransformService:
|
||||
@staticmethod
|
||||
def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str) -> Union[str, dict]:
|
||||
@classmethod
|
||||
def get_plugin_icon_url(cls, tenant_id: str, filename: str) -> str:
|
||||
url_prefix = URL(dify_config.CONSOLE_API_URL) / "console" / "api" / "workspaces" / "current" / "plugin" / "icon"
|
||||
return str(url_prefix % {"tenant_id": tenant_id, "filename": filename})
|
||||
|
||||
@classmethod
|
||||
def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str | dict) -> Union[str, dict]:
|
||||
"""
|
||||
get tool provider icon url
|
||||
"""
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/"
|
||||
url_prefix = URL(dify_config.CONSOLE_API_URL) / "console" / "api" / "workspaces" / "current" / "tool-provider"
|
||||
|
||||
if provider_type == ToolProviderType.BUILT_IN.value:
|
||||
return url_prefix + "builtin/" + provider_name + "/icon"
|
||||
return str(url_prefix / "builtin" / provider_name / "icon")
|
||||
elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
|
||||
try:
|
||||
return cast(dict, json.loads(icon))
|
||||
except:
|
||||
if isinstance(icon, str):
|
||||
return cast(dict, json.loads(icon))
|
||||
return icon
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def repack_provider(provider: Union[dict, UserToolProvider]):
|
||||
def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity]):
|
||||
"""
|
||||
repack provider
|
||||
|
||||
@@ -52,55 +62,52 @@ class ToolTransformService:
|
||||
provider["icon"] = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
|
||||
)
|
||||
elif isinstance(provider, UserToolProvider):
|
||||
provider.icon = cast(
|
||||
str,
|
||||
ToolTransformService.get_tool_provider_icon_url(
|
||||
elif isinstance(provider, ToolProviderApiEntity):
|
||||
if provider.plugin_id:
|
||||
if isinstance(provider.icon, str):
|
||||
provider.icon = ToolTransformService.get_plugin_icon_url(
|
||||
tenant_id=tenant_id, filename=provider.icon
|
||||
)
|
||||
else:
|
||||
provider.icon = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
def builtin_provider_to_user_provider(
|
||||
provider_controller: BuiltinToolProviderController,
|
||||
cls,
|
||||
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
|
||||
db_provider: Optional[BuiltinToolProvider],
|
||||
decrypt_credentials: bool = True,
|
||||
) -> UserToolProvider:
|
||||
) -> ToolProviderApiEntity:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
if provider_controller.identity is None:
|
||||
raise ValueError("provider identity is None")
|
||||
|
||||
result = UserToolProvider(
|
||||
id=provider_controller.identity.name,
|
||||
author=provider_controller.identity.author,
|
||||
name=provider_controller.identity.name,
|
||||
description=I18nObject(
|
||||
en_US=provider_controller.identity.description.en_US,
|
||||
zh_Hans=provider_controller.identity.description.zh_Hans,
|
||||
pt_BR=provider_controller.identity.description.pt_BR,
|
||||
ja_JP=provider_controller.identity.description.ja_JP,
|
||||
),
|
||||
icon=provider_controller.identity.icon,
|
||||
label=I18nObject(
|
||||
en_US=provider_controller.identity.label.en_US,
|
||||
zh_Hans=provider_controller.identity.label.zh_Hans,
|
||||
pt_BR=provider_controller.identity.label.pt_BR,
|
||||
ja_JP=provider_controller.identity.label.ja_JP,
|
||||
),
|
||||
result = ToolProviderApiEntity(
|
||||
id=provider_controller.entity.identity.name,
|
||||
author=provider_controller.entity.identity.author,
|
||||
name=provider_controller.entity.identity.name,
|
||||
description=provider_controller.entity.identity.description,
|
||||
icon=provider_controller.entity.identity.icon,
|
||||
label=provider_controller.entity.identity.label,
|
||||
type=ToolProviderType.BUILT_IN,
|
||||
masked_credentials={},
|
||||
is_team_authorization=False,
|
||||
plugin_id=None,
|
||||
tools=[],
|
||||
labels=provider_controller.tool_labels,
|
||||
)
|
||||
|
||||
if isinstance(provider_controller, PluginToolProviderController):
|
||||
result.plugin_id = provider_controller.plugin_id
|
||||
result.plugin_unique_identifier = provider_controller.plugin_unique_identifier
|
||||
|
||||
# get credentials schema
|
||||
schema = provider_controller.get_credentials_schema()
|
||||
schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()}
|
||||
|
||||
for name, value in schema.items():
|
||||
assert result.masked_credentials is not None, "masked credentials is None"
|
||||
result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(str(value.type))
|
||||
if result.masked_credentials:
|
||||
result.masked_credentials[name] = ""
|
||||
|
||||
# check if the provider need credentials
|
||||
if not provider_controller.need_credentials:
|
||||
@@ -113,12 +120,15 @@ class ToolTransformService:
|
||||
credentials = db_provider.credentials
|
||||
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=db_provider.tenant_id, provider_controller=provider_controller
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
|
||||
decrypted_credentials = tool_configuration.decrypt(data=credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
|
||||
|
||||
result.masked_credentials = masked_credentials
|
||||
result.original_credentials = decrypted_credentials
|
||||
@@ -151,41 +161,35 @@ class ToolTransformService:
|
||||
|
||||
@staticmethod
|
||||
def workflow_provider_to_user_provider(
|
||||
provider_controller: WorkflowToolProviderController, labels: Optional[list[str]] = None
|
||||
provider_controller: WorkflowToolProviderController, labels: list[str] | None = None
|
||||
):
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
if provider_controller.identity is None:
|
||||
raise ValueError("provider identity is None")
|
||||
|
||||
return UserToolProvider(
|
||||
return ToolProviderApiEntity(
|
||||
id=provider_controller.provider_id,
|
||||
author=provider_controller.identity.author,
|
||||
name=provider_controller.identity.name,
|
||||
description=I18nObject(
|
||||
en_US=provider_controller.identity.description.en_US,
|
||||
zh_Hans=provider_controller.identity.description.zh_Hans,
|
||||
),
|
||||
icon=provider_controller.identity.icon,
|
||||
label=I18nObject(
|
||||
en_US=provider_controller.identity.label.en_US,
|
||||
zh_Hans=provider_controller.identity.label.zh_Hans,
|
||||
),
|
||||
author=provider_controller.entity.identity.author,
|
||||
name=provider_controller.entity.identity.name,
|
||||
description=provider_controller.entity.identity.description,
|
||||
icon=provider_controller.entity.identity.icon,
|
||||
label=provider_controller.entity.identity.label,
|
||||
type=ToolProviderType.WORKFLOW,
|
||||
masked_credentials={},
|
||||
is_team_authorization=True,
|
||||
plugin_id=None,
|
||||
plugin_unique_identifier=None,
|
||||
tools=[],
|
||||
labels=labels or [],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
def api_provider_to_user_provider(
|
||||
cls,
|
||||
provider_controller: ApiToolProviderController,
|
||||
db_provider: ApiToolProvider,
|
||||
decrypt_credentials: bool = True,
|
||||
labels: Optional[list[str]] = None,
|
||||
) -> UserToolProvider:
|
||||
labels: list[str] | None = None,
|
||||
) -> ToolProviderApiEntity:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
@@ -193,12 +197,16 @@ class ToolTransformService:
|
||||
if db_provider.user is None:
|
||||
raise ValueError(f"user is None for api provider {db_provider.id}")
|
||||
try:
|
||||
username = db_provider.user.name
|
||||
except Exception as e:
|
||||
user = db_provider.user
|
||||
if not user:
|
||||
raise ValueError("user not found")
|
||||
|
||||
username = user.name
|
||||
except Exception:
|
||||
logger.exception(f"failed to get user name for api provider {db_provider.id}")
|
||||
# add provider into providers
|
||||
credentials = db_provider.credentials
|
||||
result = UserToolProvider(
|
||||
result = ToolProviderApiEntity(
|
||||
id=db_provider.id,
|
||||
author=username,
|
||||
name=db_provider.name,
|
||||
@@ -212,6 +220,8 @@ class ToolTransformService:
|
||||
zh_Hans=db_provider.name,
|
||||
),
|
||||
type=ToolProviderType.API,
|
||||
plugin_id=None,
|
||||
plugin_unique_identifier=None,
|
||||
masked_credentials={},
|
||||
is_team_authorization=True,
|
||||
tools=[],
|
||||
@@ -220,39 +230,42 @@ class ToolTransformService:
|
||||
|
||||
if decrypt_credentials:
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=db_provider.tenant_id, provider_controller=provider_controller
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
|
||||
decrypted_credentials = tool_configuration.decrypt(data=credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
|
||||
|
||||
result.masked_credentials = masked_credentials
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def tool_to_user_tool(
|
||||
def convert_tool_entity_to_api_entity(
|
||||
tool: Union[ApiToolBundle, WorkflowTool, Tool],
|
||||
credentials: Optional[dict] = None,
|
||||
tenant_id: Optional[str] = None,
|
||||
labels: Optional[list[str]] = None,
|
||||
) -> UserTool:
|
||||
tenant_id: str,
|
||||
credentials: dict | None = None,
|
||||
labels: list[str] | None = None,
|
||||
) -> ToolApiEntity:
|
||||
"""
|
||||
convert tool to user tool
|
||||
"""
|
||||
if isinstance(tool, Tool):
|
||||
# fork tool runtime
|
||||
tool = tool.fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
runtime=ToolRuntime(
|
||||
credentials=credentials or {},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
# get tool parameters
|
||||
parameters = tool.parameters or []
|
||||
parameters = tool.entity.parameters or []
|
||||
# get tool runtime parameters
|
||||
runtime_parameters = tool.get_runtime_parameters()
|
||||
# override parameters
|
||||
@@ -268,23 +281,21 @@ class ToolTransformService:
|
||||
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
current_parameters.append(runtime_parameter)
|
||||
|
||||
if tool.identity is None:
|
||||
raise ValueError("tool identity is None")
|
||||
|
||||
return UserTool(
|
||||
author=tool.identity.author,
|
||||
name=tool.identity.name,
|
||||
label=tool.identity.label,
|
||||
description=tool.description.human if tool.description else "", # type: ignore
|
||||
return ToolApiEntity(
|
||||
author=tool.entity.identity.author,
|
||||
name=tool.entity.identity.name,
|
||||
label=tool.entity.identity.label,
|
||||
description=tool.entity.description.human if tool.entity.description else I18nObject(en_US=""),
|
||||
output_schema=tool.entity.output_schema,
|
||||
parameters=current_parameters,
|
||||
labels=labels,
|
||||
labels=labels or [],
|
||||
)
|
||||
if isinstance(tool, ApiToolBundle):
|
||||
return UserTool(
|
||||
return ToolApiEntity(
|
||||
author=tool.author,
|
||||
name=tool.operation_id or "",
|
||||
label=I18nObject(en_US=tool.operation_id or "", zh_Hans=tool.operation_id or ""),
|
||||
label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
|
||||
description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
|
||||
parameters=tool.parameters,
|
||||
labels=labels,
|
||||
labels=labels or [],
|
||||
)
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import or_
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.model import App
|
||||
from models.tools import WorkflowToolProvider
|
||||
@@ -36,7 +36,7 @@ class WorkflowToolManageService:
|
||||
description: str,
|
||||
parameters: list[Mapping[str, Any]],
|
||||
privacy_policy: str = "",
|
||||
labels: Optional[list[str]] = None,
|
||||
labels: list[str] | None = None,
|
||||
) -> dict:
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
@@ -54,11 +54,12 @@ class WorkflowToolManageService:
|
||||
if existing_workflow_tool_provider is not None:
|
||||
raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")
|
||||
|
||||
app = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
|
||||
app: App | None = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
|
||||
|
||||
if app is None:
|
||||
raise ValueError(f"App {workflow_app_id} not found")
|
||||
|
||||
workflow = app.workflow
|
||||
workflow: Workflow | None = app.workflow
|
||||
if workflow is None:
|
||||
raise ValueError(f"Workflow not found for app {workflow_app_id}")
|
||||
|
||||
@@ -101,7 +102,7 @@ class WorkflowToolManageService:
|
||||
description: str,
|
||||
parameters: list[Mapping[str, Any]],
|
||||
privacy_policy: str = "",
|
||||
labels: Optional[list[str]] = None,
|
||||
labels: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Update a workflow tool.
|
||||
@@ -133,7 +134,7 @@ class WorkflowToolManageService:
|
||||
if existing_workflow_tool_provider is not None:
|
||||
raise ValueError(f"Tool with name {name} already exists")
|
||||
|
||||
workflow_tool_provider: Optional[WorkflowToolProvider] = (
|
||||
workflow_tool_provider: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
@@ -142,14 +143,14 @@ class WorkflowToolManageService:
|
||||
if workflow_tool_provider is None:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
app: Optional[App] = (
|
||||
app: App | None = (
|
||||
db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
|
||||
)
|
||||
|
||||
if app is None:
|
||||
raise ValueError(f"App {workflow_tool_provider.app_id} not found")
|
||||
|
||||
workflow: Optional[Workflow] = app.workflow
|
||||
workflow: Workflow | None = app.workflow
|
||||
if workflow is None:
|
||||
raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")
|
||||
|
||||
@@ -178,7 +179,7 @@ class WorkflowToolManageService:
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||
"""
|
||||
List workflow tools.
|
||||
:param user_id: the user id
|
||||
@@ -187,11 +188,11 @@ class WorkflowToolManageService:
|
||||
"""
|
||||
db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
tools = []
|
||||
tools: list[WorkflowToolProviderController] = []
|
||||
for provider in db_tools:
|
||||
try:
|
||||
tools.append(ToolTransformService.workflow_provider_to_controller(provider))
|
||||
except:
|
||||
except Exception:
|
||||
# skip deleted tools
|
||||
pass
|
||||
|
||||
@@ -203,12 +204,13 @@ class WorkflowToolManageService:
|
||||
user_tool_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=tool, labels=labels.get(tool.provider_id, [])
|
||||
)
|
||||
ToolTransformService.repack_provider(user_tool_provider)
|
||||
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
|
||||
if to_user_tool is None or len(to_user_tool) == 0:
|
||||
continue
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_tool_provider)
|
||||
user_tool_provider.tools = [
|
||||
ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=labels.get(tool.provider_id, []))
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tool=tool.get_tools(tenant_id)[0],
|
||||
labels=labels.get(tool.provider_id, []),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
]
|
||||
result.append(user_tool_provider)
|
||||
|
||||
@@ -239,42 +241,12 @@ class WorkflowToolManageService:
|
||||
:param workflow_app_id: the workflow app id
|
||||
:return: the tool
|
||||
"""
|
||||
db_tool: Optional[WorkflowToolProvider] = (
|
||||
db_tool: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if db_tool is None:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
workflow_app: Optional[App] = (
|
||||
db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
|
||||
)
|
||||
|
||||
if workflow_app is None:
|
||||
raise ValueError(f"App {db_tool.app_id} not found")
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
|
||||
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
|
||||
if to_user_tool is None or len(to_user_tool) == 0:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
return {
|
||||
"name": db_tool.name,
|
||||
"label": db_tool.label,
|
||||
"workflow_tool_id": db_tool.id,
|
||||
"workflow_app_id": db_tool.app_id,
|
||||
"icon": json.loads(db_tool.icon),
|
||||
"description": db_tool.description,
|
||||
"parameters": jsonable_encoder(db_tool.parameter_configurations),
|
||||
"tool": ToolTransformService.tool_to_user_tool(
|
||||
to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool)
|
||||
),
|
||||
"synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False,
|
||||
"privacy_policy": db_tool.privacy_policy,
|
||||
}
|
||||
return cls._get_workflow_tool(tenant_id, db_tool)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
|
||||
@@ -285,26 +257,38 @@ class WorkflowToolManageService:
|
||||
:param workflow_app_id: the workflow app id
|
||||
:return: the tool
|
||||
"""
|
||||
db_tool: Optional[WorkflowToolProvider] = (
|
||||
db_tool: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
|
||||
.first()
|
||||
)
|
||||
return cls._get_workflow_tool(tenant_id, db_tool)
|
||||
|
||||
@classmethod
|
||||
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None) -> dict:
|
||||
"""
|
||||
Get a workflow tool.
|
||||
:db_tool: the database tool
|
||||
:return: the tool
|
||||
"""
|
||||
if db_tool is None:
|
||||
raise ValueError(f"Tool {workflow_app_id} not found")
|
||||
raise ValueError("Tool not found")
|
||||
|
||||
workflow_app: Optional[App] = (
|
||||
db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
|
||||
workflow_app: App | None = (
|
||||
db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first()
|
||||
)
|
||||
|
||||
if workflow_app is None:
|
||||
raise ValueError(f"App {db_tool.app_id} not found")
|
||||
|
||||
workflow = workflow_app.workflow
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
|
||||
if to_user_tool is None or len(to_user_tool) == 0:
|
||||
raise ValueError(f"Tool {workflow_app_id} not found")
|
||||
workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
|
||||
if len(workflow_tools) == 0:
|
||||
raise ValueError(f"Tool {db_tool.id} not found")
|
||||
|
||||
return {
|
||||
"name": db_tool.name,
|
||||
@@ -314,15 +298,17 @@ class WorkflowToolManageService:
|
||||
"icon": json.loads(db_tool.icon),
|
||||
"description": db_tool.description,
|
||||
"parameters": jsonable_encoder(db_tool.parameter_configurations),
|
||||
"tool": ToolTransformService.tool_to_user_tool(
|
||||
to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool)
|
||||
"tool": ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tool=tool.get_tools(db_tool.tenant_id)[0],
|
||||
labels=ToolLabelManager.get_tool_labels(tool),
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
"synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False,
|
||||
"synced": workflow.version == db_tool.version,
|
||||
"privacy_policy": db_tool.privacy_policy,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[UserTool]:
|
||||
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[ToolApiEntity]:
|
||||
"""
|
||||
List workflow tool provider tools.
|
||||
:param user_id: the user id
|
||||
@@ -330,7 +316,7 @@ class WorkflowToolManageService:
|
||||
:param workflow_app_id: the workflow app id
|
||||
:return: the list of tools
|
||||
"""
|
||||
db_tool: Optional[WorkflowToolProvider] = (
|
||||
db_tool: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
@@ -340,8 +326,14 @@ class WorkflowToolManageService:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
|
||||
if to_user_tool is None or len(to_user_tool) == 0:
|
||||
workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
|
||||
if len(workflow_tools) == 0:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
return [ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool))]
|
||||
return [
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tool=tool.get_tools(db_tool.tenant_id)[0],
|
||||
labels=ToolLabelManager.get_tool_labels(tool),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
import contexts
|
||||
from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
@@ -119,6 +121,9 @@ class WorkflowRunService:
|
||||
"""
|
||||
workflow_run = self.get_workflow_run(app_model, run_id)
|
||||
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
if not workflow_run:
|
||||
return []
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Generator, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import desc
|
||||
@@ -13,11 +13,12 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.variables import Variable
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
from core.workflow.nodes.base.node import BaseNode
|
||||
from core.workflow.nodes.enums import ErrorStrategy
|
||||
from core.workflow.nodes.event import RunCompletedEvent
|
||||
from core.workflow.nodes.event.types import NodeEvent
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
||||
@@ -246,14 +247,69 @@ class WorkflowService:
|
||||
# run draft workflow node
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
node_instance, generator = WorkflowEntry.single_step_run(
|
||||
workflow_node_execution = self._handle_node_run_result(
|
||||
getter=lambda: WorkflowEntry.single_step_run(
|
||||
workflow=draft_workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
user_id=account.id,
|
||||
)
|
||||
node_instance = cast(BaseNode[BaseNodeData], node_instance)
|
||||
),
|
||||
start_at=start_at,
|
||||
tenant_id=app_model.tenant_id,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
workflow_node_execution.app_id = app_model.id
|
||||
workflow_node_execution.created_by = account.id
|
||||
workflow_node_execution.workflow_id = draft_workflow.id
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def run_free_workflow_node(
|
||||
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Run draft workflow node
|
||||
"""
|
||||
# run draft workflow node
|
||||
start_at = time.perf_counter()
|
||||
|
||||
workflow_node_execution = self._handle_node_run_result(
|
||||
getter=lambda: WorkflowEntry.run_free_node(
|
||||
node_id=node_id,
|
||||
node_data=node_data,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
user_inputs=user_inputs,
|
||||
),
|
||||
start_at=start_at,
|
||||
tenant_id=tenant_id,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_node_run_result(
|
||||
self,
|
||||
getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
|
||||
start_at: float,
|
||||
tenant_id: str,
|
||||
node_id: str,
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Handle node run result
|
||||
|
||||
:param getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]]
|
||||
:param start_at: float
|
||||
:param tenant_id: str
|
||||
:param node_id: str
|
||||
"""
|
||||
try:
|
||||
node_instance, generator = getter()
|
||||
|
||||
node_run_result: NodeRunResult | None = None
|
||||
for event in generator:
|
||||
if isinstance(event, RunCompletedEvent):
|
||||
@@ -303,9 +359,7 @@ class WorkflowService:
|
||||
|
||||
workflow_node_execution = WorkflowNodeExecution()
|
||||
workflow_node_execution.id = str(uuid4())
|
||||
workflow_node_execution.tenant_id = app_model.tenant_id
|
||||
workflow_node_execution.app_id = app_model.id
|
||||
workflow_node_execution.workflow_id = draft_workflow.id
|
||||
workflow_node_execution.tenant_id = tenant_id
|
||||
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value
|
||||
workflow_node_execution.index = 1
|
||||
workflow_node_execution.node_id = node_id
|
||||
@@ -313,7 +367,6 @@ class WorkflowService:
|
||||
workflow_node_execution.title = node_instance.node_data.title
|
||||
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value
|
||||
workflow_node_execution.created_by = account.id
|
||||
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
if run_succeeded and node_run_result:
|
||||
@@ -342,9 +395,6 @@ class WorkflowService:
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App:
|
||||
|
||||
Reference in New Issue
Block a user