mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 03:16:51 +08:00
Feat/workflow phase2 (#4687)
This commit is contained in:
450
api/services/tools/api_tools_manage_service.py
Normal file
450
api/services/tools/api_tools_manage_service.py
Normal file
@@ -0,0 +1,450 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from httpx import get
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ApiProviderSchemaType,
|
||||
ToolCredentialsOption,
|
||||
ToolProviderCredentials,
|
||||
)
|
||||
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolConfigurationManager
|
||||
from core.tools.utils.parser import ApiBasedToolSchemaParser
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ApiToolManageService:
|
||||
@staticmethod
|
||||
def parser_api_schema(schema: str) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse api schema to tool bundle
|
||||
"""
|
||||
try:
|
||||
warnings = {}
|
||||
try:
|
||||
tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
|
||||
except Exception as e:
|
||||
raise ValueError(f'invalid schema: {str(e)}')
|
||||
|
||||
credentials_schema = [
|
||||
ToolProviderCredentials(
|
||||
name='auth_type',
|
||||
type=ToolProviderCredentials.CredentialsType.SELECT,
|
||||
required=True,
|
||||
default='none',
|
||||
options=[
|
||||
ToolCredentialsOption(value='none', label=I18nObject(
|
||||
en_US='None',
|
||||
zh_Hans='无'
|
||||
)),
|
||||
ToolCredentialsOption(value='api_key', label=I18nObject(
|
||||
en_US='Api Key',
|
||||
zh_Hans='Api Key'
|
||||
)),
|
||||
],
|
||||
placeholder=I18nObject(
|
||||
en_US='Select auth type',
|
||||
zh_Hans='选择认证方式'
|
||||
)
|
||||
),
|
||||
ToolProviderCredentials(
|
||||
name='api_key_header',
|
||||
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
||||
required=False,
|
||||
placeholder=I18nObject(
|
||||
en_US='Enter api key header',
|
||||
zh_Hans='输入 api key header,如:X-API-KEY'
|
||||
),
|
||||
default='api_key',
|
||||
help=I18nObject(
|
||||
en_US='HTTP header name for api key',
|
||||
zh_Hans='HTTP 头部字段名,用于传递 api key'
|
||||
)
|
||||
),
|
||||
ToolProviderCredentials(
|
||||
name='api_key_value',
|
||||
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
||||
required=False,
|
||||
placeholder=I18nObject(
|
||||
en_US='Enter api key',
|
||||
zh_Hans='输入 api key'
|
||||
),
|
||||
default=''
|
||||
),
|
||||
]
|
||||
|
||||
return jsonable_encoder({
|
||||
'schema_type': schema_type,
|
||||
'parameters_schema': tool_bundles,
|
||||
'credentials_schema': credentials_schema,
|
||||
'warning': warnings
|
||||
})
|
||||
except Exception as e:
|
||||
raise ValueError(f'invalid schema: {str(e)}')
|
||||
|
||||
@staticmethod
|
||||
def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]:
|
||||
"""
|
||||
convert schema to tool bundles
|
||||
|
||||
:return: the list of tool bundles, description
|
||||
"""
|
||||
try:
|
||||
tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
|
||||
return tool_bundles
|
||||
except Exception as e:
|
||||
raise ValueError(f'invalid schema: {str(e)}')
|
||||
|
||||
@staticmethod
|
||||
def create_api_tool_provider(
|
||||
user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict,
|
||||
schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str]
|
||||
):
|
||||
"""
|
||||
create api tool provider
|
||||
"""
|
||||
if schema_type not in [member.value for member in ApiProviderSchemaType]:
|
||||
raise ValueError(f'invalid schema type {schema}')
|
||||
|
||||
# check if the provider exists
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
).first()
|
||||
|
||||
if provider is not None:
|
||||
raise ValueError(f'provider {provider_name} already exists')
|
||||
|
||||
# parse openapi to tool bundle
|
||||
extra_info = {}
|
||||
# extra info like description will be set here
|
||||
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
|
||||
|
||||
if len(tool_bundles) > 100:
|
||||
raise ValueError('the number of apis should be less than 100')
|
||||
|
||||
# create db provider
|
||||
db_provider = ApiToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
name=provider_name,
|
||||
icon=json.dumps(icon),
|
||||
schema=schema,
|
||||
description=extra_info.get('description', ''),
|
||||
schema_type_str=schema_type,
|
||||
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
|
||||
credentials_str={},
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer
|
||||
)
|
||||
|
||||
if 'auth_type' not in credentials:
|
||||
raise ValueError('auth_type is required')
|
||||
|
||||
# get auth type, none or api key
|
||||
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
|
||||
|
||||
# create provider entity
|
||||
provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
|
||||
# load tools into provider entity
|
||||
provider_controller.load_bundled_tools(tool_bundles)
|
||||
|
||||
# encrypt credentials
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials)
|
||||
db_provider.credentials_str = json.dumps(encrypted_credentials)
|
||||
|
||||
db.session.add(db_provider)
|
||||
db.session.commit()
|
||||
|
||||
# update labels
|
||||
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
@staticmethod
|
||||
def get_api_tool_provider_remote_schema(
|
||||
user_id: str, tenant_id: str, url: str
|
||||
):
|
||||
"""
|
||||
get api tool provider remote schema
|
||||
"""
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
|
||||
"Accept": "*/*",
|
||||
}
|
||||
|
||||
try:
|
||||
response = get(url, headers=headers, timeout=10)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f'Got status code {response.status_code}')
|
||||
schema = response.text
|
||||
|
||||
# try to parse schema, avoid SSRF attack
|
||||
ApiToolManageService.parser_api_schema(schema)
|
||||
except Exception as e:
|
||||
logger.error(f"parse api schema error: {str(e)}")
|
||||
raise ValueError('invalid schema, please check the url you provided')
|
||||
|
||||
return {
|
||||
'schema': schema
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def list_api_tool_provider_tools(
|
||||
user_id: str, tenant_id: str, provider: str
|
||||
) -> list[UserTool]:
|
||||
"""
|
||||
list api tool provider tools
|
||||
"""
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider,
|
||||
).first()
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f'you have not added provider {provider}')
|
||||
|
||||
controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
|
||||
labels = ToolLabelManager.get_tool_labels(controller)
|
||||
|
||||
return [
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tool_bundle,
|
||||
labels=labels,
|
||||
) for tool_bundle in provider.tools
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def update_api_tool_provider(
|
||||
user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict,
|
||||
schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str]
|
||||
):
|
||||
"""
|
||||
update api tool provider
|
||||
"""
|
||||
if schema_type not in [member.value for member in ApiProviderSchemaType]:
|
||||
raise ValueError(f'invalid schema type {schema}')
|
||||
|
||||
# check if the provider exists
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == original_provider,
|
||||
).first()
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f'api provider {provider_name} does not exists')
|
||||
|
||||
# parse openapi to tool bundle
|
||||
extra_info = {}
|
||||
# extra info like description will be set here
|
||||
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
|
||||
|
||||
# update db provider
|
||||
provider.name = provider_name
|
||||
provider.icon = json.dumps(icon)
|
||||
provider.schema = schema
|
||||
provider.description = extra_info.get('description', '')
|
||||
provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value
|
||||
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
|
||||
provider.privacy_policy = privacy_policy
|
||||
provider.custom_disclaimer = custom_disclaimer
|
||||
|
||||
if 'auth_type' not in credentials:
|
||||
raise ValueError('auth_type is required')
|
||||
|
||||
# get auth type, none or api key
|
||||
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
|
||||
|
||||
# create provider entity
|
||||
provider_controller = ApiToolProviderController.from_db(provider, auth_type)
|
||||
# load tools into provider entity
|
||||
provider_controller.load_bundled_tools(tool_bundles)
|
||||
|
||||
# get original credentials if exists
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
|
||||
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
|
||||
# check if the credential has changed, save the original credential
|
||||
for name, value in credentials.items():
|
||||
if name in masked_credentials and value == masked_credentials[name]:
|
||||
credentials[name] = original_credentials[name]
|
||||
|
||||
credentials = tool_configuration.encrypt_tool_credentials(credentials)
|
||||
provider.credentials_str = json.dumps(credentials)
|
||||
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
|
||||
# delete cache
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
# update labels
|
||||
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
@staticmethod
|
||||
def delete_api_tool_provider(
|
||||
user_id: str, tenant_id: str, provider_name: str
|
||||
):
|
||||
"""
|
||||
delete tool provider
|
||||
"""
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
).first()
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f'you have not added provider {provider_name}')
|
||||
|
||||
db.session.delete(provider)
|
||||
db.session.commit()
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
@staticmethod
|
||||
def get_api_tool_provider(
|
||||
user_id: str, tenant_id: str, provider: str
|
||||
):
|
||||
"""
|
||||
get api tool provider
|
||||
"""
|
||||
return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id)
|
||||
|
||||
@staticmethod
|
||||
def test_api_tool_preview(
|
||||
tenant_id: str,
|
||||
provider_name: str,
|
||||
tool_name: str,
|
||||
credentials: dict,
|
||||
parameters: dict,
|
||||
schema_type: str,
|
||||
schema: str
|
||||
):
|
||||
"""
|
||||
test api tool before adding api tool provider
|
||||
"""
|
||||
if schema_type not in [member.value for member in ApiProviderSchemaType]:
|
||||
raise ValueError(f'invalid schema type {schema_type}')
|
||||
|
||||
try:
|
||||
tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
|
||||
except Exception as e:
|
||||
raise ValueError('invalid schema')
|
||||
|
||||
# get tool bundle
|
||||
tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None)
|
||||
if tool_bundle is None:
|
||||
raise ValueError(f'invalid tool name {tool_name}')
|
||||
|
||||
db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
).first()
|
||||
|
||||
if not db_provider:
|
||||
# create a fake db provider
|
||||
db_provider = ApiToolProvider(
|
||||
tenant_id='', user_id='', name='', icon='',
|
||||
schema=schema,
|
||||
description='',
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI.value,
|
||||
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
|
||||
credentials_str=json.dumps(credentials),
|
||||
)
|
||||
|
||||
if 'auth_type' not in credentials:
|
||||
raise ValueError('auth_type is required')
|
||||
|
||||
# get auth type, none or api key
|
||||
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
|
||||
|
||||
# create provider entity
|
||||
provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
|
||||
# load tools into provider entity
|
||||
provider_controller.load_bundled_tools(tool_bundles)
|
||||
|
||||
# decrypt credentials
|
||||
if db_provider.id:
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
provider_controller=provider_controller
|
||||
)
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
# check if the credential has changed, save the original credential
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
||||
for name, value in credentials.items():
|
||||
if name in masked_credentials and value == masked_credentials[name]:
|
||||
credentials[name] = decrypted_credentials[name]
|
||||
|
||||
try:
|
||||
provider_controller.validate_credentials_format(credentials)
|
||||
# get tool
|
||||
tool = provider_controller.get_tool(tool_name)
|
||||
tool = tool.fork_tool_runtime(runtime={
|
||||
'credentials': credentials,
|
||||
'tenant_id': tenant_id,
|
||||
})
|
||||
result = tool.validate_credentials(credentials, parameters)
|
||||
except Exception as e:
|
||||
return { 'error': str(e) }
|
||||
|
||||
return { 'result': result or 'empty response' }
|
||||
|
||||
@staticmethod
|
||||
def list_api_tools(
|
||||
user_id: str, tenant_id: str
|
||||
) -> list[UserToolProvider]:
|
||||
"""
|
||||
list api tools
|
||||
"""
|
||||
# get all api providers
|
||||
db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id
|
||||
).all() or []
|
||||
|
||||
result: list[UserToolProvider] = []
|
||||
|
||||
for provider in db_providers:
|
||||
# convert provider controller to user provider
|
||||
provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
|
||||
labels = ToolLabelManager.get_tool_labels(provider_controller)
|
||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||
provider_controller,
|
||||
db_provider=provider,
|
||||
decrypt_credentials=True
|
||||
)
|
||||
|
||||
# add icon
|
||||
ToolTransformService.repack_provider(user_provider)
|
||||
|
||||
tools = provider_controller.get_tools(
|
||||
user_id=user_id, tenant_id=tenant_id
|
||||
)
|
||||
|
||||
for tool in tools:
|
||||
user_provider.tools.append(ToolTransformService.tool_to_user_tool(
|
||||
tenant_id=tenant_id,
|
||||
tool=tool,
|
||||
credentials=user_provider.original_credentials,
|
||||
labels=labels
|
||||
))
|
||||
|
||||
result.append(user_provider)
|
||||
|
||||
return result
|
||||
226
api/services/tools/builtin_tools_manage_service.py
Normal file
226
api/services/tools/builtin_tools_manage_service.py
Normal file
@@ -0,0 +1,226 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
||||
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolConfigurationManager
|
||||
from extensions.ext_database import db
|
||||
from models.tools import BuiltinToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BuiltinToolManageService:
|
||||
@staticmethod
|
||||
def list_builtin_tool_provider_tools(
|
||||
user_id: str, tenant_id: str, provider: str
|
||||
) -> list[UserTool]:
|
||||
"""
|
||||
list builtin tool provider tools
|
||||
"""
|
||||
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
|
||||
tools = provider_controller.get_tools()
|
||||
|
||||
tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
# check if user has added the provider
|
||||
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
).first()
|
||||
|
||||
credentials = {}
|
||||
if builtin_provider is not None:
|
||||
# get credentials
|
||||
credentials = builtin_provider.credentials
|
||||
credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)
|
||||
|
||||
result = []
|
||||
for tool in tools:
|
||||
result.append(ToolTransformService.tool_to_user_tool(
|
||||
tool=tool,
|
||||
credentials=credentials,
|
||||
tenant_id=tenant_id,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller)
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_provider_credentials_schema(
|
||||
provider_name
|
||||
):
|
||||
"""
|
||||
list builtin provider credentials schema
|
||||
|
||||
:return: the list of tool providers
|
||||
"""
|
||||
provider = ToolManager.get_builtin_provider(provider_name)
|
||||
return jsonable_encoder([
|
||||
v for _, v in (provider.credentials_schema or {}).items()
|
||||
])
|
||||
|
||||
@staticmethod
|
||||
def update_builtin_tool_provider(
|
||||
user_id: str, tenant_id: str, provider_name: str, credentials: dict
|
||||
):
|
||||
"""
|
||||
update builtin tool provider
|
||||
"""
|
||||
# get if the provider exists
|
||||
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider_name,
|
||||
).first()
|
||||
|
||||
try:
|
||||
# get provider
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name)
|
||||
if not provider_controller.need_credentials:
|
||||
raise ValueError(f'provider {provider_name} does not need credentials')
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
# get original credentials if exists
|
||||
if provider is not None:
|
||||
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
|
||||
# check if the credential has changed, save the original credential
|
||||
for name, value in credentials.items():
|
||||
if name in masked_credentials and value == masked_credentials[name]:
|
||||
credentials[name] = original_credentials[name]
|
||||
# validate credentials
|
||||
provider_controller.validate_credentials(credentials)
|
||||
# encrypt credentials
|
||||
credentials = tool_configuration.encrypt_tool_credentials(credentials)
|
||||
except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e:
|
||||
raise ValueError(str(e))
|
||||
|
||||
if provider is None:
|
||||
# create provider
|
||||
provider = BuiltinToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
encrypted_credentials=json.dumps(credentials),
|
||||
)
|
||||
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
|
||||
else:
|
||||
provider.encrypted_credentials = json.dumps(credentials)
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
|
||||
# delete cache
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_credentials(
|
||||
user_id: str, tenant_id: str, provider: str
|
||||
):
|
||||
"""
|
||||
get builtin tool provider credentials
|
||||
"""
|
||||
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
).first()
|
||||
|
||||
if provider is None:
|
||||
return {}
|
||||
|
||||
provider_controller = ToolManager.get_builtin_provider(provider.provider)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
|
||||
credentials = tool_configuration.mask_tool_credentials(credentials)
|
||||
return credentials
|
||||
|
||||
@staticmethod
|
||||
def delete_builtin_tool_provider(
|
||||
user_id: str, tenant_id: str, provider_name: str
|
||||
):
|
||||
"""
|
||||
delete tool provider
|
||||
"""
|
||||
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider_name,
|
||||
).first()
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f'you have not added provider {provider_name}')
|
||||
|
||||
db.session.delete(provider)
|
||||
db.session.commit()
|
||||
|
||||
# delete cache
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_icon(
|
||||
provider: str
|
||||
):
|
||||
"""
|
||||
get tool provider icon and it's mimetype
|
||||
"""
|
||||
icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider)
|
||||
with open(icon_path, 'rb') as f:
|
||||
icon_bytes = f.read()
|
||||
|
||||
return icon_bytes, mime_type
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_tools(
|
||||
user_id: str, tenant_id: str
|
||||
) -> list[UserToolProvider]:
|
||||
"""
|
||||
list builtin tools
|
||||
"""
|
||||
# get all builtin providers
|
||||
provider_controllers = ToolManager.list_builtin_providers()
|
||||
|
||||
# get all user added providers
|
||||
db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id
|
||||
).all() or []
|
||||
|
||||
# find provider
|
||||
find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
|
||||
|
||||
result: list[UserToolProvider] = []
|
||||
|
||||
for provider_controller in provider_controllers:
|
||||
# convert provider controller to user provider
|
||||
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=find_provider(provider_controller.identity.name),
|
||||
decrypt_credentials=True
|
||||
)
|
||||
|
||||
# add icon
|
||||
ToolTransformService.repack_provider(user_builtin_provider)
|
||||
|
||||
tools = provider_controller.get_tools()
|
||||
for tool in tools:
|
||||
user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool(
|
||||
tenant_id=tenant_id,
|
||||
tool=tool,
|
||||
credentials=user_builtin_provider.original_credentials,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller)
|
||||
))
|
||||
|
||||
result.append(user_builtin_provider)
|
||||
|
||||
return BuiltinToolProviderSort.sort(result)
|
||||
|
||||
8
api/services/tools/tool_labels_service.py
Normal file
8
api/services/tools/tool_labels_service.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from core.tools.entities.tool_entities import ToolLabel
|
||||
from core.tools.entities.values import default_tool_labels
|
||||
|
||||
|
||||
class ToolLabelsService:
|
||||
@classmethod
|
||||
def list_tool_labels(cls) -> list[ToolLabel]:
|
||||
return default_tool_labels
|
||||
29
api/services/tools/tools_manage_service.py
Normal file
29
api/services/tools/tools_manage_service.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import logging
|
||||
|
||||
from core.tools.entities.api_entities import UserToolProviderTypeLiteral
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolCommonService:
|
||||
@staticmethod
|
||||
def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None):
|
||||
"""
|
||||
list tool providers
|
||||
|
||||
:return: the list of tool providers
|
||||
"""
|
||||
providers = ToolManager.user_list_providers(
|
||||
user_id, tenant_id, typ
|
||||
)
|
||||
|
||||
# add icon
|
||||
for provider in providers:
|
||||
ToolTransformService.repack_provider(provider)
|
||||
|
||||
result = [provider.to_dict() for provider in providers]
|
||||
|
||||
return result
|
||||
|
||||
288
api/services/tools/tools_transform_service.py
Normal file
288
api/services/tools/tools_transform_service.py
Normal file
@@ -0,0 +1,288 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolParameter,
|
||||
ToolProviderCredentials,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool.workflow_tool import WorkflowTool
|
||||
from core.tools.utils.configuration import ToolConfigurationManager
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ToolTransformService:
|
||||
@staticmethod
|
||||
def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str) -> Union[str, dict]:
|
||||
"""
|
||||
get tool provider icon url
|
||||
"""
|
||||
url_prefix = (current_app.config.get("CONSOLE_API_URL")
|
||||
+ "/console/api/workspaces/current/tool-provider/")
|
||||
|
||||
if provider_type == ToolProviderType.BUILT_IN.value:
|
||||
return url_prefix + 'builtin/' + provider_name + '/icon'
|
||||
elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]:
|
||||
try:
|
||||
return json.loads(icon)
|
||||
except:
|
||||
return {
|
||||
"background": "#252525",
|
||||
"content": "\ud83d\ude01"
|
||||
}
|
||||
|
||||
return ''
|
||||
|
||||
@staticmethod
|
||||
def repack_provider(provider: Union[dict, UserToolProvider]):
|
||||
"""
|
||||
repack provider
|
||||
|
||||
:param provider: the provider dict
|
||||
"""
|
||||
if isinstance(provider, dict) and 'icon' in provider:
|
||||
provider['icon'] = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider['type'],
|
||||
provider_name=provider['name'],
|
||||
icon=provider['icon']
|
||||
)
|
||||
elif isinstance(provider, UserToolProvider):
|
||||
provider.icon = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider.type.value,
|
||||
provider_name=provider.name,
|
||||
icon=provider.icon
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def builtin_provider_to_user_provider(
|
||||
provider_controller: BuiltinToolProviderController,
|
||||
db_provider: Optional[BuiltinToolProvider],
|
||||
decrypt_credentials: bool = True,
|
||||
) -> UserToolProvider:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
result = UserToolProvider(
|
||||
id=provider_controller.identity.name,
|
||||
author=provider_controller.identity.author,
|
||||
name=provider_controller.identity.name,
|
||||
description=I18nObject(
|
||||
en_US=provider_controller.identity.description.en_US,
|
||||
zh_Hans=provider_controller.identity.description.zh_Hans,
|
||||
),
|
||||
icon=provider_controller.identity.icon,
|
||||
label=I18nObject(
|
||||
en_US=provider_controller.identity.label.en_US,
|
||||
zh_Hans=provider_controller.identity.label.zh_Hans,
|
||||
),
|
||||
type=ToolProviderType.BUILT_IN,
|
||||
masked_credentials={},
|
||||
is_team_authorization=False,
|
||||
tools=[],
|
||||
labels=provider_controller.tool_labels
|
||||
)
|
||||
|
||||
# get credentials schema
|
||||
schema = provider_controller.get_credentials_schema()
|
||||
for name, value in schema.items():
|
||||
result.masked_credentials[name] = \
|
||||
ToolProviderCredentials.CredentialsType.default(value.type)
|
||||
|
||||
# check if the provider need credentials
|
||||
if not provider_controller.need_credentials:
|
||||
result.is_team_authorization = True
|
||||
result.allow_delete = False
|
||||
elif db_provider:
|
||||
result.is_team_authorization = True
|
||||
|
||||
if decrypt_credentials:
|
||||
credentials = db_provider.credentials
|
||||
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
provider_controller=provider_controller
|
||||
)
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
|
||||
|
||||
result.masked_credentials = masked_credentials
|
||||
result.original_credentials = decrypted_credentials
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def api_provider_to_controller(
|
||||
db_provider: ApiToolProvider,
|
||||
) -> ApiToolProviderController:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
# package tool provider controller
|
||||
controller = ApiToolProviderController.from_db(
|
||||
db_provider=db_provider,
|
||||
auth_type=ApiProviderAuthType.API_KEY if db_provider.credentials['auth_type'] == 'api_key' else
|
||||
ApiProviderAuthType.NONE
|
||||
)
|
||||
|
||||
return controller
|
||||
|
||||
@staticmethod
|
||||
def workflow_provider_to_controller(
|
||||
db_provider: WorkflowToolProvider
|
||||
) -> WorkflowToolProviderController:
|
||||
"""
|
||||
convert provider controller to provider
|
||||
"""
|
||||
return WorkflowToolProviderController.from_db(db_provider)
|
||||
|
||||
@staticmethod
|
||||
def workflow_provider_to_user_provider(
|
||||
provider_controller: WorkflowToolProviderController,
|
||||
labels: list[str] = None
|
||||
):
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
return UserToolProvider(
|
||||
id=provider_controller.provider_id,
|
||||
author=provider_controller.identity.author,
|
||||
name=provider_controller.identity.name,
|
||||
description=I18nObject(
|
||||
en_US=provider_controller.identity.description.en_US,
|
||||
zh_Hans=provider_controller.identity.description.zh_Hans,
|
||||
),
|
||||
icon=provider_controller.identity.icon,
|
||||
label=I18nObject(
|
||||
en_US=provider_controller.identity.label.en_US,
|
||||
zh_Hans=provider_controller.identity.label.zh_Hans,
|
||||
),
|
||||
type=ToolProviderType.WORKFLOW,
|
||||
masked_credentials={},
|
||||
is_team_authorization=True,
|
||||
tools=[],
|
||||
labels=labels or []
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def api_provider_to_user_provider(
|
||||
provider_controller: ApiToolProviderController,
|
||||
db_provider: ApiToolProvider,
|
||||
decrypt_credentials: bool = True,
|
||||
labels: list[str] = None
|
||||
) -> UserToolProvider:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
username = 'Anonymous'
|
||||
try:
|
||||
username = db_provider.user.name
|
||||
except Exception as e:
|
||||
logger.error(f'failed to get user name for api provider {db_provider.id}: {str(e)}')
|
||||
# add provider into providers
|
||||
credentials = db_provider.credentials
|
||||
result = UserToolProvider(
|
||||
id=db_provider.id,
|
||||
author=username,
|
||||
name=db_provider.name,
|
||||
description=I18nObject(
|
||||
en_US=db_provider.description,
|
||||
zh_Hans=db_provider.description,
|
||||
),
|
||||
icon=db_provider.icon,
|
||||
label=I18nObject(
|
||||
en_US=db_provider.name,
|
||||
zh_Hans=db_provider.name,
|
||||
),
|
||||
type=ToolProviderType.API,
|
||||
masked_credentials={},
|
||||
is_team_authorization=True,
|
||||
tools=[],
|
||||
labels=labels or []
|
||||
)
|
||||
|
||||
if decrypt_credentials:
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
provider_controller=provider_controller
|
||||
)
|
||||
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
|
||||
|
||||
result.masked_credentials = masked_credentials
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def tool_to_user_tool(
|
||||
tool: Union[ApiToolBundle, WorkflowTool, Tool],
|
||||
credentials: dict = None,
|
||||
tenant_id: str = None,
|
||||
labels: list[str] = None
|
||||
) -> UserTool:
|
||||
"""
|
||||
convert tool to user tool
|
||||
"""
|
||||
if isinstance(tool, Tool):
|
||||
# fork tool runtime
|
||||
tool = tool.fork_tool_runtime(runtime={
|
||||
'credentials': credentials,
|
||||
'tenant_id': tenant_id,
|
||||
})
|
||||
|
||||
# get tool parameters
|
||||
parameters = tool.parameters or []
|
||||
# get tool runtime parameters
|
||||
runtime_parameters = tool.get_runtime_parameters() or []
|
||||
# override parameters
|
||||
current_parameters = parameters.copy()
|
||||
for runtime_parameter in runtime_parameters:
|
||||
found = False
|
||||
for index, parameter in enumerate(current_parameters):
|
||||
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
|
||||
current_parameters[index] = runtime_parameter
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
current_parameters.append(runtime_parameter)
|
||||
|
||||
return UserTool(
|
||||
author=tool.identity.author,
|
||||
name=tool.identity.name,
|
||||
label=tool.identity.label,
|
||||
description=tool.description.human,
|
||||
parameters=current_parameters,
|
||||
labels=labels
|
||||
)
|
||||
if isinstance(tool, ApiToolBundle):
|
||||
return UserTool(
|
||||
author=tool.author,
|
||||
name=tool.operation_id,
|
||||
label=I18nObject(
|
||||
en_US=tool.operation_id,
|
||||
zh_Hans=tool.operation_id
|
||||
),
|
||||
description=I18nObject(
|
||||
en_US=tool.summary or '',
|
||||
zh_Hans=tool.summary or ''
|
||||
),
|
||||
parameters=tool.parameters,
|
||||
labels=labels
|
||||
)
|
||||
326
api/services/tools/workflow_tools_manage_service.py
Normal file
326
api/services/tools/workflow_tools_manage_service.py
Normal file
@@ -0,0 +1,326 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import or_
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserToolProvider
|
||||
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from extensions.ext_database import db
|
||||
from models.model import App
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.workflow import Workflow
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
class WorkflowToolManageService:
|
||||
"""
|
||||
Service class for managing workflow tools.
|
||||
"""
|
||||
@classmethod
|
||||
def create_workflow_tool(cls, user_id: str, tenant_id: str, workflow_app_id: str, name: str,
|
||||
label: str, icon: dict, description: str,
|
||||
parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict:
|
||||
"""
|
||||
Create a workflow tool.
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:param name: the name
|
||||
:param icon: the icon
|
||||
:param description: the description
|
||||
:param parameters: the parameters
|
||||
:param privacy_policy: the privacy policy
|
||||
:return: the created tool
|
||||
"""
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
# check if the name is unique
|
||||
existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
# name or app_id
|
||||
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id)
|
||||
).first()
|
||||
|
||||
if existing_workflow_tool_provider is not None:
|
||||
raise ValueError(f'Tool with name {name} or app_id {workflow_app_id} already exists')
|
||||
|
||||
app: App = db.session.query(App).filter(
|
||||
App.id == workflow_app_id,
|
||||
App.tenant_id == tenant_id
|
||||
).first()
|
||||
|
||||
if app is None:
|
||||
raise ValueError(f'App {workflow_app_id} not found')
|
||||
|
||||
workflow: Workflow = app.workflow
|
||||
if workflow is None:
|
||||
raise ValueError(f'Workflow not found for app {workflow_app_id}')
|
||||
|
||||
workflow_tool_provider = WorkflowToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=workflow_app_id,
|
||||
name=name,
|
||||
label=label,
|
||||
icon=json.dumps(icon),
|
||||
description=description,
|
||||
parameter_configuration=json.dumps(parameters),
|
||||
privacy_policy=privacy_policy,
|
||||
version=workflow.version,
|
||||
)
|
||||
|
||||
try:
|
||||
WorkflowToolProviderController.from_db(workflow_tool_provider)
|
||||
except Exception as e:
|
||||
raise ValueError(str(e))
|
||||
|
||||
db.session.add(workflow_tool_provider)
|
||||
db.session.commit()
|
||||
|
||||
return {
|
||||
'result': 'success'
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def update_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str,
|
||||
name: str, label: str, icon: dict, description: str,
|
||||
parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict:
|
||||
"""
|
||||
Update a workflow tool.
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:param tool: the tool
|
||||
:return: the updated tool
|
||||
"""
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
# check if the name is unique
|
||||
existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.name == name,
|
||||
WorkflowToolProvider.id != workflow_tool_id
|
||||
).first()
|
||||
|
||||
if existing_workflow_tool_provider is not None:
|
||||
raise ValueError(f'Tool with name {name} already exists')
|
||||
|
||||
workflow_tool_provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == workflow_tool_id
|
||||
).first()
|
||||
|
||||
if workflow_tool_provider is None:
|
||||
raise ValueError(f'Tool {workflow_tool_id} not found')
|
||||
|
||||
app: App = db.session.query(App).filter(
|
||||
App.id == workflow_tool_provider.app_id,
|
||||
App.tenant_id == tenant_id
|
||||
).first()
|
||||
|
||||
if app is None:
|
||||
raise ValueError(f'App {workflow_tool_provider.app_id} not found')
|
||||
|
||||
workflow: Workflow = app.workflow
|
||||
if workflow is None:
|
||||
raise ValueError(f'Workflow not found for app {workflow_tool_provider.app_id}')
|
||||
|
||||
workflow_tool_provider.name = name
|
||||
workflow_tool_provider.label = label
|
||||
workflow_tool_provider.icon = json.dumps(icon)
|
||||
workflow_tool_provider.description = description
|
||||
workflow_tool_provider.parameter_configuration = json.dumps(parameters)
|
||||
workflow_tool_provider.privacy_policy = privacy_policy
|
||||
workflow_tool_provider.version = workflow.version
|
||||
workflow_tool_provider.updated_at = datetime.now()
|
||||
|
||||
try:
|
||||
WorkflowToolProviderController.from_db(workflow_tool_provider)
|
||||
except Exception as e:
|
||||
raise ValueError(str(e))
|
||||
|
||||
db.session.add(workflow_tool_provider)
|
||||
db.session.commit()
|
||||
|
||||
if labels is not None:
|
||||
ToolLabelManager.update_tool_labels(
|
||||
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider),
|
||||
labels
|
||||
)
|
||||
|
||||
return {
|
||||
'result': 'success'
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
"""
|
||||
List workflow tools.
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:return: the list of tools
|
||||
"""
|
||||
db_tools = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id
|
||||
).all()
|
||||
|
||||
tools = []
|
||||
for provider in db_tools:
|
||||
try:
|
||||
tools.append(ToolTransformService.workflow_provider_to_controller(provider))
|
||||
except:
|
||||
# skip deleted tools
|
||||
pass
|
||||
|
||||
labels = ToolLabelManager.get_tools_labels(tools)
|
||||
|
||||
result = []
|
||||
|
||||
for tool in tools:
|
||||
user_tool_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=tool,
|
||||
labels=labels.get(tool.provider_id, [])
|
||||
)
|
||||
ToolTransformService.repack_provider(user_tool_provider)
|
||||
user_tool_provider.tools = [
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0],
|
||||
labels=labels.get(tool.provider_id, [])
|
||||
)
|
||||
]
|
||||
result.append(user_tool_provider)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
|
||||
"""
|
||||
Delete a workflow tool.
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:param workflow_app_id: the workflow app id
|
||||
"""
|
||||
db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == workflow_tool_id
|
||||
).delete()
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return {
|
||||
'result': 'success'
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
|
||||
"""
|
||||
Get a workflow tool.
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:param workflow_app_id: the workflow app id
|
||||
:return: the tool
|
||||
"""
|
||||
db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == workflow_tool_id
|
||||
).first()
|
||||
|
||||
if db_tool is None:
|
||||
raise ValueError(f'Tool {workflow_tool_id} not found')
|
||||
|
||||
workflow_app: App = db.session.query(App).filter(
|
||||
App.id == db_tool.app_id,
|
||||
App.tenant_id == tenant_id
|
||||
).first()
|
||||
|
||||
if workflow_app is None:
|
||||
raise ValueError(f'App {db_tool.app_id} not found')
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
|
||||
return {
|
||||
'name': db_tool.name,
|
||||
'label': db_tool.label,
|
||||
'workflow_tool_id': db_tool.id,
|
||||
'workflow_app_id': db_tool.app_id,
|
||||
'icon': json.loads(db_tool.icon),
|
||||
'description': db_tool.description,
|
||||
'parameters': jsonable_encoder(db_tool.parameter_configurations),
|
||||
'tool': ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0],
|
||||
labels=ToolLabelManager.get_tool_labels(tool)
|
||||
),
|
||||
'synced': workflow_app.workflow.version == db_tool.version,
|
||||
'privacy_policy': db_tool.privacy_policy,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
|
||||
"""
|
||||
Get a workflow tool.
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:param workflow_app_id: the workflow app id
|
||||
:return: the tool
|
||||
"""
|
||||
db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.app_id == workflow_app_id
|
||||
).first()
|
||||
|
||||
if db_tool is None:
|
||||
raise ValueError(f'Tool {workflow_app_id} not found')
|
||||
|
||||
workflow_app: App = db.session.query(App).filter(
|
||||
App.id == db_tool.app_id,
|
||||
App.tenant_id == tenant_id
|
||||
).first()
|
||||
|
||||
if workflow_app is None:
|
||||
raise ValueError(f'App {db_tool.app_id} not found')
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
|
||||
return {
|
||||
'name': db_tool.name,
|
||||
'label': db_tool.label,
|
||||
'workflow_tool_id': db_tool.id,
|
||||
'workflow_app_id': db_tool.app_id,
|
||||
'icon': json.loads(db_tool.icon),
|
||||
'description': db_tool.description,
|
||||
'parameters': jsonable_encoder(db_tool.parameter_configurations),
|
||||
'tool': ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0],
|
||||
labels=ToolLabelManager.get_tool_labels(tool)
|
||||
),
|
||||
'synced': workflow_app.workflow.version == db_tool.version,
|
||||
'privacy_policy': db_tool.privacy_policy
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]:
|
||||
"""
|
||||
List workflow tool provider tools.
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:param workflow_app_id: the workflow app id
|
||||
:return: the list of tools
|
||||
"""
|
||||
db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == workflow_tool_id
|
||||
).first()
|
||||
|
||||
if db_tool is None:
|
||||
raise ValueError(f'Tool {workflow_tool_id} not found')
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
|
||||
return [
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0],
|
||||
labels=ToolLabelManager.get_tool_labels(tool)
|
||||
)
|
||||
]
|
||||
Reference in New Issue
Block a user