chore(api/services): apply ruff reformatting (#7599)

Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
Bowen Liang
2024-08-26 13:43:57 +08:00
committed by GitHub
parent 979422cdc6
commit 17fd773a30
49 changed files with 2630 additions and 2655 deletions

View File

@@ -29,111 +29,107 @@ class ApiToolManageService:
@staticmethod
def parser_api_schema(schema: str) -> list[ApiToolBundle]:
"""
parse api schema to tool bundle
parse api schema to tool bundle
"""
try:
warnings = {}
try:
tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
except Exception as e:
raise ValueError(f'invalid schema: {str(e)}')
raise ValueError(f"invalid schema: {str(e)}")
credentials_schema = [
ToolProviderCredentials(
name='auth_type',
name="auth_type",
type=ToolProviderCredentials.CredentialsType.SELECT,
required=True,
default='none',
default="none",
options=[
ToolCredentialsOption(value='none', label=I18nObject(
en_US='None',
zh_Hans=''
)),
ToolCredentialsOption(value='api_key', label=I18nObject(
en_US='Api Key',
zh_Hans='Api Key'
)),
ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="")),
ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
],
placeholder=I18nObject(
en_US='Select auth type',
zh_Hans='选择认证方式'
)
placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
),
ToolProviderCredentials(
name='api_key_header',
name="api_key_header",
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
required=False,
placeholder=I18nObject(
en_US='Enter api key header',
zh_Hans='输入 api key headerX-API-KEY'
),
default='api_key',
help=I18nObject(
en_US='HTTP header name for api key',
zh_Hans='HTTP 头部字段名,用于传递 api key'
)
placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key headerX-API-KEY"),
default="api_key",
help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"),
),
ToolProviderCredentials(
name='api_key_value',
name="api_key_value",
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
required=False,
placeholder=I18nObject(
en_US='Enter api key',
zh_Hans='输入 api key'
),
default=''
placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"),
default="",
),
]
return jsonable_encoder({
'schema_type': schema_type,
'parameters_schema': tool_bundles,
'credentials_schema': credentials_schema,
'warning': warnings
})
return jsonable_encoder(
{
"schema_type": schema_type,
"parameters_schema": tool_bundles,
"credentials_schema": credentials_schema,
"warning": warnings,
}
)
except Exception as e:
raise ValueError(f'invalid schema: {str(e)}')
raise ValueError(f"invalid schema: {str(e)}")
@staticmethod
def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]:
"""
convert schema to tool bundles
convert schema to tool bundles
:return: the list of tool bundles, description
:return: the list of tool bundles, description
"""
try:
tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
return tool_bundles
except Exception as e:
raise ValueError(f'invalid schema: {str(e)}')
raise ValueError(f"invalid schema: {str(e)}")
@staticmethod
def create_api_tool_provider(
user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict,
schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str]
user_id: str,
tenant_id: str,
provider_name: str,
icon: dict,
credentials: dict,
schema_type: str,
schema: str,
privacy_policy: str,
custom_disclaimer: str,
labels: list[str],
):
"""
create api tool provider
create api tool provider
"""
if schema_type not in [member.value for member in ApiProviderSchemaType]:
raise ValueError(f'invalid schema type {schema}')
raise ValueError(f"invalid schema type {schema}")
# check if the provider exists
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
).first()
provider: ApiToolProvider = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
.first()
)
if provider is not None:
raise ValueError(f'provider {provider_name} already exists')
raise ValueError(f"provider {provider_name} already exists")
# parse openapi to tool bundle
extra_info = {}
# extra info like description will be set here
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
if len(tool_bundles) > 100:
raise ValueError('the number of apis should be less than 100')
raise ValueError("the number of apis should be less than 100")
# create db provider
db_provider = ApiToolProvider(
@@ -142,19 +138,19 @@ class ApiToolManageService:
name=provider_name,
icon=json.dumps(icon),
schema=schema,
description=extra_info.get('description', ''),
description=extra_info.get("description", ""),
schema_type_str=schema_type,
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
credentials_str={},
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer
custom_disclaimer=custom_disclaimer,
)
if 'auth_type' not in credentials:
raise ValueError('auth_type is required')
if "auth_type" not in credentials:
raise ValueError("auth_type is required")
# get auth type, none or api key
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
# create provider entity
provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
@@ -172,14 +168,12 @@ class ApiToolManageService:
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
return { 'result': 'success' }
return {"result": "success"}
@staticmethod
def get_api_tool_provider_remote_schema(
user_id: str, tenant_id: str, url: str
):
def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str):
"""
get api tool provider remote schema
get api tool provider remote schema
"""
headers = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
@@ -189,84 +183,98 @@ class ApiToolManageService:
try:
response = get(url, headers=headers, timeout=10)
if response.status_code != 200:
raise ValueError(f'Got status code {response.status_code}')
raise ValueError(f"Got status code {response.status_code}")
schema = response.text
# try to parse schema, avoid SSRF attack
ApiToolManageService.parser_api_schema(schema)
except Exception as e:
logger.error(f"parse api schema error: {str(e)}")
raise ValueError('invalid schema, please check the url you provided')
return {
'schema': schema
}
raise ValueError("invalid schema, please check the url you provided")
return {"schema": schema}
@staticmethod
def list_api_tool_provider_tools(
user_id: str, tenant_id: str, provider: str
) -> list[UserTool]:
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
"""
list api tool provider tools
list api tool provider tools
"""
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider,
).first()
provider: ApiToolProvider = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider,
)
.first()
)
if provider is None:
raise ValueError(f'you have not added provider {provider}')
raise ValueError(f"you have not added provider {provider}")
controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
labels = ToolLabelManager.get_tool_labels(controller)
return [
ToolTransformService.tool_to_user_tool(
tool_bundle,
labels=labels,
) for tool_bundle in provider.tools
)
for tool_bundle in provider.tools
]
@staticmethod
def update_api_tool_provider(
user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict,
schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str]
user_id: str,
tenant_id: str,
provider_name: str,
original_provider: str,
icon: dict,
credentials: dict,
schema_type: str,
schema: str,
privacy_policy: str,
custom_disclaimer: str,
labels: list[str],
):
"""
update api tool provider
update api tool provider
"""
if schema_type not in [member.value for member in ApiProviderSchemaType]:
raise ValueError(f'invalid schema type {schema}')
raise ValueError(f"invalid schema type {schema}")
# check if the provider exists
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == original_provider,
).first()
provider: ApiToolProvider = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == original_provider,
)
.first()
)
if provider is None:
raise ValueError(f'api provider {provider_name} does not exists')
raise ValueError(f"api provider {provider_name} does not exists")
# parse openapi to tool bundle
extra_info = {}
# extra info like description will be set here
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
# update db provider
provider.name = provider_name
provider.icon = json.dumps(icon)
provider.schema = schema
provider.description = extra_info.get('description', '')
provider.description = extra_info.get("description", "")
provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
provider.privacy_policy = privacy_policy
provider.custom_disclaimer = custom_disclaimer
if 'auth_type' not in credentials:
raise ValueError('auth_type is required')
if "auth_type" not in credentials:
raise ValueError("auth_type is required")
# get auth type, none or api key
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
# create provider entity
provider_controller = ApiToolProviderController.from_db(provider, auth_type)
@@ -295,84 +303,91 @@ class ApiToolManageService:
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
return { 'result': 'success' }
return {"result": "success"}
@staticmethod
def delete_api_tool_provider(
user_id: str, tenant_id: str, provider_name: str
):
def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str):
"""
delete tool provider
delete tool provider
"""
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
).first()
provider: ApiToolProvider = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
.first()
)
if provider is None:
raise ValueError(f'you have not added provider {provider_name}')
raise ValueError(f"you have not added provider {provider_name}")
db.session.delete(provider)
db.session.commit()
return { 'result': 'success' }
return {"result": "success"}
@staticmethod
def get_api_tool_provider(
user_id: str, tenant_id: str, provider: str
):
def get_api_tool_provider(user_id: str, tenant_id: str, provider: str):
"""
get api tool provider
get api tool provider
"""
return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id)
@staticmethod
def test_api_tool_preview(
tenant_id: str,
tenant_id: str,
provider_name: str,
tool_name: str,
credentials: dict,
parameters: dict,
schema_type: str,
schema: str
tool_name: str,
credentials: dict,
parameters: dict,
schema_type: str,
schema: str,
):
"""
test api tool before adding api tool provider
test api tool before adding api tool provider
"""
if schema_type not in [member.value for member in ApiProviderSchemaType]:
raise ValueError(f'invalid schema type {schema_type}')
raise ValueError(f"invalid schema type {schema_type}")
try:
tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
except Exception as e:
raise ValueError('invalid schema')
raise ValueError("invalid schema")
# get tool bundle
tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None)
if tool_bundle is None:
raise ValueError(f'invalid tool name {tool_name}')
db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
).first()
raise ValueError(f"invalid tool name {tool_name}")
db_provider: ApiToolProvider = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
.first()
)
if not db_provider:
# create a fake db provider
db_provider = ApiToolProvider(
tenant_id='', user_id='', name='', icon='',
tenant_id="",
user_id="",
name="",
icon="",
schema=schema,
description='',
description="",
schema_type_str=ApiProviderSchemaType.OPENAPI.value,
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
credentials_str=json.dumps(credentials),
)
if 'auth_type' not in credentials:
raise ValueError('auth_type is required')
if "auth_type" not in credentials:
raise ValueError("auth_type is required")
# get auth type, none or api key
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
# create provider entity
provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
@@ -381,10 +396,7 @@ class ApiToolManageService:
# decrypt credentials
if db_provider.id:
tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
provider_controller=provider_controller
)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
# check if the credential has changed, save the original credential
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
@@ -396,27 +408,27 @@ class ApiToolManageService:
provider_controller.validate_credentials_format(credentials)
# get tool
tool = provider_controller.get_tool(tool_name)
tool = tool.fork_tool_runtime(runtime={
'credentials': credentials,
'tenant_id': tenant_id,
})
tool = tool.fork_tool_runtime(
runtime={
"credentials": credentials,
"tenant_id": tenant_id,
}
)
result = tool.validate_credentials(credentials, parameters)
except Exception as e:
return { 'error': str(e) }
return { 'result': result or 'empty response' }
return {"error": str(e)}
return {"result": result or "empty response"}
@staticmethod
def list_api_tools(
user_id: str, tenant_id: str
) -> list[UserToolProvider]:
def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
"""
list api tools
list api tools
"""
# get all api providers
db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id
).all() or []
db_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or []
)
result: list[UserToolProvider] = []
@@ -425,26 +437,21 @@ class ApiToolManageService:
provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
labels = ToolLabelManager.get_tool_labels(provider_controller)
user_provider = ToolTransformService.api_provider_to_user_provider(
provider_controller,
db_provider=provider,
decrypt_credentials=True
provider_controller, db_provider=provider, decrypt_credentials=True
)
user_provider.labels = labels
# add icon
ToolTransformService.repack_provider(user_provider)
tools = provider_controller.get_tools(
user_id=user_id, tenant_id=tenant_id
)
tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
for tool in tools:
user_provider.tools.append(ToolTransformService.tool_to_user_tool(
tenant_id=tenant_id,
tool=tool,
credentials=user_provider.original_credentials,
labels=labels
))
user_provider.tools.append(
ToolTransformService.tool_to_user_tool(
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
)
)
result.append(user_provider)

View File

@@ -20,21 +20,25 @@ logger = logging.getLogger(__name__)
class BuiltinToolManageService:
@staticmethod
def list_builtin_tool_provider_tools(
user_id: str, tenant_id: str, provider: str
) -> list[UserTool]:
def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
"""
list builtin tool provider tools
list builtin tool provider tools
"""
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
tools = provider_controller.get_tools()
tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
tool_provider_configurations = ToolConfigurationManager(
tenant_id=tenant_id, provider_controller=provider_controller
)
# check if user has added the provider
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
).first()
builtin_provider: BuiltinToolProvider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
)
.first()
)
credentials = {}
if builtin_provider is not None:
@@ -44,47 +48,47 @@ class BuiltinToolManageService:
result = []
for tool in tools:
result.append(ToolTransformService.tool_to_user_tool(
tool=tool,
credentials=credentials,
tenant_id=tenant_id,
labels=ToolLabelManager.get_tool_labels(provider_controller)
))
result.append(
ToolTransformService.tool_to_user_tool(
tool=tool,
credentials=credentials,
tenant_id=tenant_id,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
)
return result
@staticmethod
def list_builtin_provider_credentials_schema(
provider_name
):
def list_builtin_provider_credentials_schema(provider_name):
"""
list builtin provider credentials schema
list builtin provider credentials schema
:return: the list of tool providers
:return: the list of tool providers
"""
provider = ToolManager.get_builtin_provider(provider_name)
return jsonable_encoder([
v for _, v in (provider.credentials_schema or {}).items()
])
return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()])
@staticmethod
def update_builtin_tool_provider(
user_id: str, tenant_id: str, provider_name: str, credentials: dict
):
def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
"""
update builtin tool provider
update builtin tool provider
"""
# get if the provider exists
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
).first()
provider: BuiltinToolProvider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
)
.first()
)
try:
# get provider
provider_controller = ToolManager.get_builtin_provider(provider_name)
if not provider_controller.need_credentials:
raise ValueError(f'provider {provider_name} does not need credentials')
raise ValueError(f"provider {provider_name} does not need credentials")
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
# get original credentials if exists
if provider is not None:
@@ -121,19 +125,21 @@ class BuiltinToolManageService:
# delete cache
tool_configuration.delete_tool_credentials_cache()
return {'result': 'success'}
return {"result": "success"}
@staticmethod
def get_builtin_tool_provider_credentials(
user_id: str, tenant_id: str, provider: str
):
def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str):
"""
get builtin tool provider credentials
get builtin tool provider credentials
"""
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
).first()
provider: BuiltinToolProvider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
)
.first()
)
if provider is None:
return {}
@@ -145,19 +151,21 @@ class BuiltinToolManageService:
return credentials
@staticmethod
def delete_builtin_tool_provider(
user_id: str, tenant_id: str, provider_name: str
):
def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str):
"""
delete tool provider
delete tool provider
"""
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
).first()
provider: BuiltinToolProvider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
)
.first()
)
if provider is None:
raise ValueError(f'you have not added provider {provider_name}')
raise ValueError(f"you have not added provider {provider_name}")
db.session.delete(provider)
db.session.commit()
@@ -167,38 +175,36 @@ class BuiltinToolManageService:
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration.delete_tool_credentials_cache()
return {'result': 'success'}
return {"result": "success"}
@staticmethod
def get_builtin_tool_provider_icon(
provider: str
):
def get_builtin_tool_provider_icon(provider: str):
"""
get tool provider icon and it's mimetype
get tool provider icon and it's mimetype
"""
icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider)
with open(icon_path, 'rb') as f:
with open(icon_path, "rb") as f:
icon_bytes = f.read()
return icon_bytes, mime_type
@staticmethod
def list_builtin_tools(
user_id: str, tenant_id: str
) -> list[UserToolProvider]:
def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
"""
list builtin tools
list builtin tools
"""
# get all builtin providers
provider_controllers = ToolManager.list_builtin_providers()
# get all user added providers
db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id
).all() or []
db_providers: list[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
)
# find provider
find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
find_provider = lambda provider: next(
filter(lambda db_provider: db_provider.provider == provider, db_providers), None
)
result: list[UserToolProvider] = []
@@ -209,7 +215,7 @@ class BuiltinToolManageService:
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
data=provider_controller,
name_func=lambda x: x.identity.name
name_func=lambda x: x.identity.name,
):
continue
@@ -217,7 +223,7 @@ class BuiltinToolManageService:
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
db_provider=find_provider(provider_controller.identity.name),
decrypt_credentials=True
decrypt_credentials=True,
)
# add icon
@@ -225,12 +231,14 @@ class BuiltinToolManageService:
tools = provider_controller.get_tools()
for tool in tools:
user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool(
tenant_id=tenant_id,
tool=tool,
credentials=user_builtin_provider.original_credentials,
labels=ToolLabelManager.get_tool_labels(provider_controller)
))
user_builtin_provider.tools.append(
ToolTransformService.tool_to_user_tool(
tenant_id=tenant_id,
tool=tool,
credentials=user_builtin_provider.original_credentials,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
)
result.append(user_builtin_provider)
except Exception as e:

View File

@@ -5,4 +5,4 @@ from core.tools.entities.values import default_tool_labels
class ToolLabelsService:
@classmethod
def list_tool_labels(cls) -> list[ToolLabel]:
return default_tool_labels
return default_tool_labels

View File

@@ -11,13 +11,11 @@ class ToolCommonService:
@staticmethod
def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None):
"""
list tool providers
list tool providers
:return: the list of tool providers
:return: the list of tool providers
"""
providers = ToolManager.user_list_providers(
user_id, tenant_id, typ
)
providers = ToolManager.user_list_providers(user_id, tenant_id, typ)
# add icon
for provider in providers:
@@ -26,4 +24,3 @@ class ToolCommonService:
result = [provider.to_dict() for provider in providers]
return result

View File

@@ -22,46 +22,39 @@ from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvi
logger = logging.getLogger(__name__)
class ToolTransformService:
@staticmethod
def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str) -> Union[str, dict]:
"""
get tool provider icon url
get tool provider icon url
"""
url_prefix = (dify_config.CONSOLE_API_URL
+ "/console/api/workspaces/current/tool-provider/")
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/"
if provider_type == ToolProviderType.BUILT_IN.value:
return url_prefix + 'builtin/' + provider_name + '/icon'
return url_prefix + "builtin/" + provider_name + "/icon"
elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]:
try:
return json.loads(icon)
except:
return {
"background": "#252525",
"content": "\ud83d\ude01"
}
return ''
return {"background": "#252525", "content": "\ud83d\ude01"}
return ""
@staticmethod
def repack_provider(provider: Union[dict, UserToolProvider]):
"""
repack provider
repack provider
:param provider: the provider dict
:param provider: the provider dict
"""
if isinstance(provider, dict) and 'icon' in provider:
provider['icon'] = ToolTransformService.get_tool_provider_icon_url(
provider_type=provider['type'],
provider_name=provider['name'],
icon=provider['icon']
if isinstance(provider, dict) and "icon" in provider:
provider["icon"] = ToolTransformService.get_tool_provider_icon_url(
provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
)
elif isinstance(provider, UserToolProvider):
provider.icon = ToolTransformService.get_tool_provider_icon_url(
provider_type=provider.type.value,
provider_name=provider.name,
icon=provider.icon
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
)
@staticmethod
@@ -92,14 +85,13 @@ class ToolTransformService:
masked_credentials={},
is_team_authorization=False,
tools=[],
labels=provider_controller.tool_labels
labels=provider_controller.tool_labels,
)
# get credentials schema
schema = provider_controller.get_credentials_schema()
for name, value in schema.items():
result.masked_credentials[name] = \
ToolProviderCredentials.CredentialsType.default(value.type)
result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type)
# check if the provider need credentials
if not provider_controller.need_credentials:
@@ -113,8 +105,7 @@ class ToolTransformService:
# init tool configuration
tool_configuration = ToolConfigurationManager(
tenant_id=db_provider.tenant_id,
provider_controller=provider_controller
tenant_id=db_provider.tenant_id, provider_controller=provider_controller
)
# decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
@@ -124,7 +115,7 @@ class ToolTransformService:
result.original_credentials = decrypted_credentials
return result
@staticmethod
def api_provider_to_controller(
db_provider: ApiToolProvider,
@@ -135,25 +126,23 @@ class ToolTransformService:
# package tool provider controller
controller = ApiToolProviderController.from_db(
db_provider=db_provider,
auth_type=ApiProviderAuthType.API_KEY if db_provider.credentials['auth_type'] == 'api_key' else
ApiProviderAuthType.NONE
auth_type=ApiProviderAuthType.API_KEY
if db_provider.credentials["auth_type"] == "api_key"
else ApiProviderAuthType.NONE,
)
return controller
@staticmethod
def workflow_provider_to_controller(
db_provider: WorkflowToolProvider
) -> WorkflowToolProviderController:
def workflow_provider_to_controller(db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
"""
convert provider controller to provider
"""
return WorkflowToolProviderController.from_db(db_provider)
@staticmethod
def workflow_provider_to_user_provider(
provider_controller: WorkflowToolProviderController,
labels: list[str] = None
provider_controller: WorkflowToolProviderController, labels: list[str] = None
):
"""
convert provider controller to user provider
@@ -175,7 +164,7 @@ class ToolTransformService:
masked_credentials={},
is_team_authorization=True,
tools=[],
labels=labels or []
labels=labels or [],
)
@staticmethod
@@ -183,16 +172,16 @@ class ToolTransformService:
provider_controller: ApiToolProviderController,
db_provider: ApiToolProvider,
decrypt_credentials: bool = True,
labels: list[str] = None
labels: list[str] = None,
) -> UserToolProvider:
"""
convert provider controller to user provider
"""
username = 'Anonymous'
username = "Anonymous"
try:
username = db_provider.user.name
except Exception as e:
logger.error(f'failed to get user name for api provider {db_provider.id}: {str(e)}')
logger.error(f"failed to get user name for api provider {db_provider.id}: {str(e)}")
# add provider into providers
credentials = db_provider.credentials
result = UserToolProvider(
@@ -212,14 +201,13 @@ class ToolTransformService:
masked_credentials={},
is_team_authorization=True,
tools=[],
labels=labels or []
labels=labels or [],
)
if decrypt_credentials:
# init tool configuration
tool_configuration = ToolConfigurationManager(
tenant_id=db_provider.tenant_id,
provider_controller=provider_controller
tenant_id=db_provider.tenant_id, provider_controller=provider_controller
)
# decrypt the credentials and mask the credentials
@@ -229,23 +217,25 @@ class ToolTransformService:
result.masked_credentials = masked_credentials
return result
@staticmethod
def tool_to_user_tool(
tool: Union[ApiToolBundle, WorkflowTool, Tool],
credentials: dict = None,
tool: Union[ApiToolBundle, WorkflowTool, Tool],
credentials: dict = None,
tenant_id: str = None,
labels: list[str] = None
labels: list[str] = None,
) -> UserTool:
"""
convert tool to user tool
"""
if isinstance(tool, Tool):
# fork tool runtime
tool = tool.fork_tool_runtime(runtime={
'credentials': credentials,
'tenant_id': tenant_id,
})
tool = tool.fork_tool_runtime(
runtime={
"credentials": credentials,
"tenant_id": tenant_id,
}
)
# get tool parameters
parameters = tool.parameters or []
@@ -270,20 +260,14 @@ class ToolTransformService:
label=tool.identity.label,
description=tool.description.human,
parameters=current_parameters,
labels=labels
labels=labels,
)
if isinstance(tool, ApiToolBundle):
return UserTool(
author=tool.author,
name=tool.operation_id,
label=I18nObject(
en_US=tool.operation_id,
zh_Hans=tool.operation_id
),
description=I18nObject(
en_US=tool.summary or '',
zh_Hans=tool.summary or ''
),
label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
parameters=tool.parameters,
labels=labels
)
labels=labels,
)

View File

@@ -19,10 +19,21 @@ class WorkflowToolManageService:
"""
Service class for managing workflow tools.
"""
@classmethod
def create_workflow_tool(cls, user_id: str, tenant_id: str, workflow_app_id: str, name: str,
label: str, icon: dict, description: str,
parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict:
def create_workflow_tool(
cls,
user_id: str,
tenant_id: str,
workflow_app_id: str,
name: str,
label: str,
icon: dict,
description: str,
parameters: list[dict],
privacy_policy: str = "",
labels: list[str] = None,
) -> dict:
"""
Create a workflow tool.
:param user_id: the user id
@@ -38,27 +49,28 @@ class WorkflowToolManageService:
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique
existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
# name or app_id
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id)
).first()
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
.filter(
WorkflowToolProvider.tenant_id == tenant_id,
# name or app_id
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
)
.first()
)
if existing_workflow_tool_provider is not None:
raise ValueError(f'Tool with name {name} or app_id {workflow_app_id} already exists')
app: App = db.session.query(App).filter(
App.id == workflow_app_id,
App.tenant_id == tenant_id
).first()
raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")
app: App = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
if app is None:
raise ValueError(f'App {workflow_app_id} not found')
raise ValueError(f"App {workflow_app_id} not found")
workflow: Workflow = app.workflow
if workflow is None:
raise ValueError(f'Workflow not found for app {workflow_app_id}')
raise ValueError(f"Workflow not found for app {workflow_app_id}")
workflow_tool_provider = WorkflowToolProvider(
tenant_id=tenant_id,
user_id=user_id,
@@ -76,19 +88,26 @@ class WorkflowToolManageService:
WorkflowToolProviderController.from_db(workflow_tool_provider)
except Exception as e:
raise ValueError(str(e))
db.session.add(workflow_tool_provider)
db.session.commit()
return {
'result': 'success'
}
return {"result": "success"}
@classmethod
def update_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str,
name: str, label: str, icon: dict, description: str,
parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict:
def update_workflow_tool(
cls,
user_id: str,
tenant_id: str,
workflow_tool_id: str,
name: str,
label: str,
icon: dict,
description: str,
parameters: list[dict],
privacy_policy: str = "",
labels: list[str] = None,
) -> dict:
"""
Update a workflow tool.
:param user_id: the user id
@@ -106,35 +125,39 @@ class WorkflowToolManageService:
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique
existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.name == name,
WorkflowToolProvider.id != workflow_tool_id
).first()
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
.filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.name == name,
WorkflowToolProvider.id != workflow_tool_id,
)
.first()
)
if existing_workflow_tool_provider is not None:
raise ValueError(f'Tool with name {name} already exists')
workflow_tool_provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.id == workflow_tool_id
).first()
raise ValueError(f"Tool with name {name} already exists")
workflow_tool_provider: WorkflowToolProvider = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
)
if workflow_tool_provider is None:
raise ValueError(f'Tool {workflow_tool_id} not found')
app: App = db.session.query(App).filter(
App.id == workflow_tool_provider.app_id,
App.tenant_id == tenant_id
).first()
raise ValueError(f"Tool {workflow_tool_id} not found")
app: App = (
db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
)
if app is None:
raise ValueError(f'App {workflow_tool_provider.app_id} not found')
raise ValueError(f"App {workflow_tool_provider.app_id} not found")
workflow: Workflow = app.workflow
if workflow is None:
raise ValueError(f'Workflow not found for app {workflow_tool_provider.app_id}')
raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")
workflow_tool_provider.name = name
workflow_tool_provider.label = label
workflow_tool_provider.icon = json.dumps(icon)
@@ -154,13 +177,10 @@ class WorkflowToolManageService:
if labels is not None:
ToolLabelManager.update_tool_labels(
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider),
labels
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
)
return {
'result': 'success'
}
return {"result": "success"}
@classmethod
def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
@@ -170,9 +190,7 @@ class WorkflowToolManageService:
:param tenant_id: the tenant id
:return: the list of tools
"""
db_tools = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id
).all()
db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
tools = []
for provider in db_tools:
@@ -188,14 +206,12 @@ class WorkflowToolManageService:
for tool in tools:
user_tool_provider = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=tool,
labels=labels.get(tool.provider_id, [])
provider_controller=tool, labels=labels.get(tool.provider_id, [])
)
ToolTransformService.repack_provider(user_tool_provider)
user_tool_provider.tools = [
ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0],
labels=labels.get(tool.provider_id, [])
tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, [])
)
]
result.append(user_tool_provider)
@@ -211,15 +227,12 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
"""
db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.id == workflow_tool_id
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
).delete()
db.session.commit()
return {
'result': 'success'
}
return {"result": "success"}
@classmethod
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
@@ -230,40 +243,37 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the tool
"""
db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.id == workflow_tool_id
).first()
db_tool: WorkflowToolProvider = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
)
if db_tool is None:
raise ValueError(f'Tool {workflow_tool_id} not found')
workflow_app: App = db.session.query(App).filter(
App.id == db_tool.app_id,
App.tenant_id == tenant_id
).first()
raise ValueError(f"Tool {workflow_tool_id} not found")
workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
if workflow_app is None:
raise ValueError(f'App {db_tool.app_id} not found')
raise ValueError(f"App {db_tool.app_id} not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
return {
'name': db_tool.name,
'label': db_tool.label,
'workflow_tool_id': db_tool.id,
'workflow_app_id': db_tool.app_id,
'icon': json.loads(db_tool.icon),
'description': db_tool.description,
'parameters': jsonable_encoder(db_tool.parameter_configurations),
'tool': ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0],
labels=ToolLabelManager.get_tool_labels(tool)
"name": db_tool.name,
"label": db_tool.label,
"workflow_tool_id": db_tool.id,
"workflow_app_id": db_tool.app_id,
"icon": json.loads(db_tool.icon),
"description": db_tool.description,
"parameters": jsonable_encoder(db_tool.parameter_configurations),
"tool": ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
),
'synced': workflow_app.workflow.version == db_tool.version,
'privacy_policy': db_tool.privacy_policy,
"synced": workflow_app.workflow.version == db_tool.version,
"privacy_policy": db_tool.privacy_policy,
}
@classmethod
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
"""
@@ -273,40 +283,37 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the tool
"""
db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == workflow_app_id
).first()
db_tool: WorkflowToolProvider = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.first()
)
if db_tool is None:
raise ValueError(f'Tool {workflow_app_id} not found')
workflow_app: App = db.session.query(App).filter(
App.id == db_tool.app_id,
App.tenant_id == tenant_id
).first()
raise ValueError(f"Tool {workflow_app_id} not found")
workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
if workflow_app is None:
raise ValueError(f'App {db_tool.app_id} not found')
raise ValueError(f"App {db_tool.app_id} not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
return {
'name': db_tool.name,
'label': db_tool.label,
'workflow_tool_id': db_tool.id,
'workflow_app_id': db_tool.app_id,
'icon': json.loads(db_tool.icon),
'description': db_tool.description,
'parameters': jsonable_encoder(db_tool.parameter_configurations),
'tool': ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0],
labels=ToolLabelManager.get_tool_labels(tool)
"name": db_tool.name,
"label": db_tool.label,
"workflow_tool_id": db_tool.id,
"workflow_app_id": db_tool.app_id,
"icon": json.loads(db_tool.icon),
"description": db_tool.description,
"parameters": jsonable_encoder(db_tool.parameter_configurations),
"tool": ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
),
'synced': workflow_app.workflow.version == db_tool.version,
'privacy_policy': db_tool.privacy_policy
"synced": workflow_app.workflow.version == db_tool.version,
"privacy_policy": db_tool.privacy_policy,
}
@classmethod
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]:
"""
@@ -316,19 +323,19 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the list of tools
"""
db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.id == workflow_tool_id
).first()
db_tool: WorkflowToolProvider = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
)
if db_tool is None:
raise ValueError(f'Tool {workflow_tool_id} not found')
raise ValueError(f"Tool {workflow_tool_id} not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
return [
ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0],
labels=ToolLabelManager.get_tool_labels(tool)
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
)
]
]