Feat/assistant app (#2086)
Co-authored-by: chenhe <guchenhe@gmail.com> Co-authored-by: Pascal M <11357019+perzeuss@users.noreply.github.com>
169
api/core/tools/provider/api_tool_provider.py
Normal file
@@ -0,0 +1,169 @@
|
||||
from typing import Any, Dict, List
|
||||
from core.tools.entities.tool_entities import ToolProviderType, ApiProviderAuthType, ToolProviderCredentials, ToolCredentialsOption
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiBasedToolBundle
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool.api_tool import ApiTool
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
from models.tools import ApiToolProvider
|
||||
|
||||
class ApiBasedToolProviderController(ToolProviderController):
|
||||
@staticmethod
|
||||
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiBasedToolProviderController':
|
||||
credentials_schema = {
|
||||
'auth_type': ToolProviderCredentials(
|
||||
name='auth_type',
|
||||
required=True,
|
||||
type=ToolProviderCredentials.CredentialsType.SELECT,
|
||||
options=[
|
||||
ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='无')),
|
||||
ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
|
||||
],
|
||||
default='none',
|
||||
help=I18nObject(
|
||||
en_US='The auth type of the api provider',
|
||||
zh_Hans='api provider 的认证类型'
|
||||
)
|
||||
)
|
||||
}
|
||||
if auth_type == ApiProviderAuthType.API_KEY:
|
||||
credentials_schema = {
|
||||
**credentials_schema,
|
||||
'api_key_header': ToolProviderCredentials(
|
||||
name='api_key_header',
|
||||
required=False,
|
||||
default='api_key',
|
||||
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
||||
help=I18nObject(
|
||||
en_US='The header name of the api key',
|
||||
zh_Hans='携带 api key 的 header 名称'
|
||||
)
|
||||
),
|
||||
'api_key_value': ToolProviderCredentials(
|
||||
name='api_key_value',
|
||||
required=True,
|
||||
type=ToolProviderCredentials.CredentialsType.SECRET_INPUT,
|
||||
help=I18nObject(
|
||||
en_US='The api key',
|
||||
zh_Hans='api key的值'
|
||||
)
|
||||
)
|
||||
}
|
||||
elif auth_type == ApiProviderAuthType.NONE:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f'invalid auth type {auth_type}')
|
||||
|
||||
return ApiBasedToolProviderController(**{
|
||||
'identity': {
|
||||
'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
|
||||
'name': db_provider.name,
|
||||
'label': {
|
||||
'en_US': db_provider.name,
|
||||
'zh_Hans': db_provider.name
|
||||
},
|
||||
'description': {
|
||||
'en_US': db_provider.description,
|
||||
'zh_Hans': db_provider.description
|
||||
},
|
||||
'icon': db_provider.icon
|
||||
},
|
||||
'credentials_schema': credentials_schema
|
||||
})
|
||||
|
||||
@property
|
||||
def app_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.API_BASED
|
||||
|
||||
def _validate_credentials(self, tool_name: str, credentials: Dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def validate_parameters(self, tool_name: str, tool_parameters: Dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def _parse_tool_bundle(self, tool_bundle: ApiBasedToolBundle) -> ApiTool:
|
||||
"""
|
||||
parse tool bundle to tool
|
||||
|
||||
:param tool_bundle: the tool bundle
|
||||
:return: the tool
|
||||
"""
|
||||
return ApiTool(**{
|
||||
'api_bundle': tool_bundle,
|
||||
'identity' : {
|
||||
'author': tool_bundle.author,
|
||||
'name': tool_bundle.operation_id,
|
||||
'label': {
|
||||
'en_US': tool_bundle.operation_id,
|
||||
'zh_Hans': tool_bundle.operation_id
|
||||
},
|
||||
'icon': tool_bundle.icon if tool_bundle.icon else ''
|
||||
},
|
||||
'description': {
|
||||
'human': {
|
||||
'en_US': tool_bundle.summary or '',
|
||||
'zh_Hans': tool_bundle.summary or ''
|
||||
},
|
||||
'llm': tool_bundle.summary or ''
|
||||
},
|
||||
'parameters' : tool_bundle.parameters if tool_bundle.parameters else [],
|
||||
})
|
||||
|
||||
def load_bundled_tools(self, tools: List[ApiBasedToolBundle]) -> List[ApiTool]:
|
||||
"""
|
||||
load bundled tools
|
||||
|
||||
:param tools: the bundled tools
|
||||
:return: the tools
|
||||
"""
|
||||
self.tools = [self._parse_tool_bundle(tool) for tool in tools]
|
||||
|
||||
return self.tools
|
||||
|
||||
def get_tools(self, user_id: str, tanent_id: str) -> List[ApiTool]:
|
||||
"""
|
||||
fetch tools from database
|
||||
|
||||
:param user_id: the user id
|
||||
:param tanent_id: the tanent id
|
||||
:return: the tools
|
||||
"""
|
||||
if self.tools is not None:
|
||||
return self.tools
|
||||
|
||||
tools: List[Tool] = []
|
||||
|
||||
# get tanent api providers
|
||||
db_providers: List[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tanent_id,
|
||||
ApiToolProvider.name == self.identity.name
|
||||
).all()
|
||||
|
||||
if db_providers and len(db_providers) != 0:
|
||||
for db_provider in db_providers:
|
||||
for tool in db_provider.tools:
|
||||
assistant_tool = self._parse_tool_bundle(tool)
|
||||
assistant_tool.is_team_authorization = True
|
||||
tools.append(assistant_tool)
|
||||
|
||||
self.tools = tools
|
||||
return tools
|
||||
|
||||
def get_tool(self, tool_name: str) -> ApiTool:
|
||||
"""
|
||||
get tool by name
|
||||
|
||||
:param tool_name: the name of the tool
|
||||
:return: the tool
|
||||
"""
|
||||
if self.tools is None:
|
||||
self.get_tools()
|
||||
|
||||
for tool in self.tools:
|
||||
if tool.identity.name == tool_name:
|
||||
return tool
|
||||
|
||||
raise ValueError(f'tool {tool_name} not found')
|
||||
116
api/core/tools/provider/app_tool_provider.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from typing import Any, Dict, List
|
||||
from core.tools.entities.tool_entities import ToolProviderType, ToolParamter, ToolParamterOption
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.tools import PublishedAppTool
|
||||
from models.model import App, AppModelConfig
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AppBasedToolProviderEntity(ToolProviderController):
|
||||
@property
|
||||
def app_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.APP_BASED
|
||||
|
||||
def _validate_credentials(self, tool_name: str, credentials: Dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def validate_parameters(self, tool_name: str, tool_parameters: Dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def get_tools(self, user_id: str) -> List[Tool]:
|
||||
db_tools: List[PublishedAppTool] = db.session.query(PublishedAppTool).filter(
|
||||
PublishedAppTool.user_id == user_id,
|
||||
).all()
|
||||
|
||||
if not db_tools or len(db_tools) == 0:
|
||||
return []
|
||||
|
||||
tools: List[Tool] = []
|
||||
|
||||
for db_tool in db_tools:
|
||||
tool = {
|
||||
'identity': {
|
||||
'author': db_tool.author,
|
||||
'name': db_tool.tool_name,
|
||||
'label': {
|
||||
'en_US': db_tool.tool_name,
|
||||
'zh_Hans': db_tool.tool_name
|
||||
},
|
||||
'icon': ''
|
||||
},
|
||||
'description': {
|
||||
'human': {
|
||||
'en_US': db_tool.description_i18n.en_US,
|
||||
'zh_Hans': db_tool.description_i18n.zh_Hans
|
||||
},
|
||||
'llm': db_tool.llm_description
|
||||
},
|
||||
'parameters': []
|
||||
}
|
||||
# get app from db
|
||||
app: App = db_tool.app
|
||||
|
||||
if not app:
|
||||
logger.error(f"app {db_tool.app_id} not found")
|
||||
continue
|
||||
|
||||
app_model_config: AppModelConfig = app.app_model_config
|
||||
user_input_form_list = app_model_config.user_input_form_list
|
||||
for input_form in user_input_form_list:
|
||||
# get type
|
||||
form_type = input_form.keys()[0]
|
||||
default = input_form[form_type]['default']
|
||||
required = input_form[form_type]['required']
|
||||
label = input_form[form_type]['label']
|
||||
variable_name = input_form[form_type]['variable_name']
|
||||
options = input_form[form_type].get('options', [])
|
||||
if form_type == 'paragraph' or form_type == 'text-input':
|
||||
tool['parameters'].append(ToolParamter(
|
||||
name=variable_name,
|
||||
label=I18nObject(
|
||||
en_US=label,
|
||||
zh_Hans=label
|
||||
),
|
||||
human_description=I18nObject(
|
||||
en_US=label,
|
||||
zh_Hans=label
|
||||
),
|
||||
llm_description=label,
|
||||
form=ToolParamter.ToolParameterForm.FORM,
|
||||
type=ToolParamter.ToolParameterType.STRING,
|
||||
required=required,
|
||||
default=default
|
||||
))
|
||||
elif form_type == 'select':
|
||||
tool['parameters'].append(ToolParamter(
|
||||
name=variable_name,
|
||||
label=I18nObject(
|
||||
en_US=label,
|
||||
zh_Hans=label
|
||||
),
|
||||
human_description=I18nObject(
|
||||
en_US=label,
|
||||
zh_Hans=label
|
||||
),
|
||||
llm_description=label,
|
||||
form=ToolParamter.ToolParameterForm.FORM,
|
||||
type=ToolParamter.ToolParameterType.SELECT,
|
||||
required=required,
|
||||
default=default,
|
||||
options=[ToolParamterOption(
|
||||
value=option,
|
||||
label=I18nObject(
|
||||
en_US=option,
|
||||
zh_Hans=option
|
||||
)
|
||||
) for option in options]
|
||||
))
|
||||
|
||||
tools.append(Tool(**tool))
|
||||
return tools
|
||||
0
api/core/tools/provider/builtin/__init__.py
Normal file
26
api/core/tools/provider/builtin/_positions.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from typing import List
|
||||
|
||||
position = {
|
||||
'google': 1,
|
||||
'wikipedia': 2,
|
||||
'dalle': 3,
|
||||
'webscraper': 4,
|
||||
'wolframalpha': 5,
|
||||
'chart': 6,
|
||||
'time': 7,
|
||||
'yahoo': 8,
|
||||
'stablediffusion': 9,
|
||||
'vectorizer': 10,
|
||||
'youtube': 11,
|
||||
}
|
||||
|
||||
class BuiltinToolProviderSort:
|
||||
@staticmethod
|
||||
def sort(providers: List[UserToolProvider]) -> List[UserToolProvider]:
|
||||
def sort_compare(provider: UserToolProvider) -> int:
|
||||
return position.get(provider.name, 10000)
|
||||
|
||||
sorted_providers = sorted(providers, key=sort_compare)
|
||||
|
||||
return sorted_providers
|
||||
BIN
api/core/tools/provider/builtin/chart/_assets/icon.png
Normal file
|
After Width: | Height: | Size: 1.3 KiB |
24
api/core/tools/provider/builtin/chart/chart.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from core.tools.provider.builtin.chart.tools.line import LinearChartTool
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
# use a business theme
|
||||
plt.style.use('seaborn-v0_8-darkgrid')
|
||||
|
||||
class ChartProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
LinearChartTool().fork_tool_runtime(
|
||||
meta={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_paramters={
|
||||
"data": "1,3,5,7,9,2,4,6,8,10",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
11
api/core/tools/provider/builtin/chart/chart.yaml
Normal file
@@ -0,0 +1,11 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: chart
|
||||
label:
|
||||
en_US: ChartGenerator
|
||||
zh_Hans: 图表生成
|
||||
description:
|
||||
en_US: Chart Generator is a tool for generating statistical charts like bar chart, line chart, pie chart, etc.
|
||||
zh_Hans: 图表生成是一个用于生成可视化图表的工具,你可以通过它来生成柱状图、折线图、饼图等各类图表
|
||||
icon: icon.png
|
||||
credentails_for_provider:
|
||||
47
api/core/tools/provider/builtin/chart/tools/bar.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
import matplotlib.pyplot as plt
|
||||
import io
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
class BarChartTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
|
||||
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
data = tool_paramters.get('data', '')
|
||||
if not data:
|
||||
return self.create_text_message('Please input data')
|
||||
data = data.split(';')
|
||||
|
||||
# if all data is int, convert to int
|
||||
if all([i.isdigit() for i in data]):
|
||||
data = [int(i) for i in data]
|
||||
else:
|
||||
data = [float(i) for i in data]
|
||||
|
||||
axis = tool_paramters.get('x_axis', None) or None
|
||||
if axis:
|
||||
axis = axis.split(';')
|
||||
if len(axis) != len(data):
|
||||
axis = None
|
||||
|
||||
flg, ax = plt.subplots(figsize=(10, 8))
|
||||
|
||||
if axis:
|
||||
axis = [label[:10] + '...' if len(label) > 10 else label for label in axis]
|
||||
ax.set_xticklabels(axis, rotation=45, ha='right')
|
||||
ax.bar(axis, data)
|
||||
else:
|
||||
ax.bar(range(len(data)), data)
|
||||
|
||||
buf = io.BytesIO()
|
||||
flg.savefig(buf, format='png')
|
||||
buf.seek(0)
|
||||
plt.close(flg)
|
||||
|
||||
return [
|
||||
self.create_text_message('the bar chart is saved as an image.'),
|
||||
self.create_blob_message(blob=buf.read(),
|
||||
meta={'mime_type': 'image/png'})
|
||||
]
|
||||
|
||||
35
api/core/tools/provider/builtin/chart/tools/bar.yaml
Normal file
@@ -0,0 +1,35 @@
|
||||
identity:
|
||||
name: bar_chart
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Bar Chart
|
||||
zh_Hans: 柱状图
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: Bar chart
|
||||
zh_Hans: 柱状图
|
||||
llm: generate a bar chart with input data
|
||||
parameters:
|
||||
- name: data
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: data
|
||||
zh_Hans: 数据
|
||||
human_description:
|
||||
en_US: data for generating bar chart
|
||||
zh_Hans: 用于生成柱状图的数据
|
||||
llm_description: data for generating bar chart, data should be a string contains a list of numbers like "1;2;3;4;5"
|
||||
form: llm
|
||||
- name: x_axis
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: X Axis
|
||||
zh_Hans: x 轴
|
||||
human_description:
|
||||
en_US: X axis for bar chart
|
||||
zh_Hans: 柱状图的 x 轴
|
||||
llm_description: x axis for bar chart, x axis should be a string contains a list of texts like "a;b;c;1;2" in order to match the data
|
||||
form: llm
|
||||
49
api/core/tools/provider/builtin/chart/tools/line.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
import matplotlib.pyplot as plt
|
||||
import io
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
class LinearChartTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_paramters: Dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
data = tool_paramters.get('data', '')
|
||||
if not data:
|
||||
return self.create_text_message('Please input data')
|
||||
data = data.split(';')
|
||||
|
||||
axis = tool_paramters.get('x_axis', None) or None
|
||||
if axis:
|
||||
axis = axis.split(';')
|
||||
if len(axis) != len(data):
|
||||
axis = None
|
||||
|
||||
# if all data is int, convert to int
|
||||
if all([i.isdigit() for i in data]):
|
||||
data = [int(i) for i in data]
|
||||
else:
|
||||
data = [float(i) for i in data]
|
||||
|
||||
flg, ax = plt.subplots(figsize=(10, 8))
|
||||
|
||||
if axis:
|
||||
axis = [label[:10] + '...' if len(label) > 10 else label for label in axis]
|
||||
ax.set_xticklabels(axis, rotation=45, ha='right')
|
||||
ax.plot(axis, data)
|
||||
else:
|
||||
ax.plot(data)
|
||||
|
||||
buf = io.BytesIO()
|
||||
flg.savefig(buf, format='png')
|
||||
buf.seek(0)
|
||||
plt.close(flg)
|
||||
|
||||
return [
|
||||
self.create_text_message('the linear chart is saved as an image.'),
|
||||
self.create_blob_message(blob=buf.read(),
|
||||
meta={'mime_type': 'image/png'})
|
||||
]
|
||||
|
||||
35
api/core/tools/provider/builtin/chart/tools/line.yaml
Normal file
@@ -0,0 +1,35 @@
|
||||
identity:
|
||||
name: line_chart
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Linear Chart
|
||||
zh_Hans: 线性图表
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: linear chart
|
||||
zh_Hans: 线性图表
|
||||
llm: generate a linear chart with input data
|
||||
parameters:
|
||||
- name: data
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: data
|
||||
zh_Hans: 数据
|
||||
human_description:
|
||||
en_US: data for generating linear chart
|
||||
zh_Hans: 用于生成线性图表的数据
|
||||
llm_description: data for generating linear chart, data should be a string contains a list of numbers like "1;2;3;4;5"
|
||||
form: llm
|
||||
- name: x_axis
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: X Axis
|
||||
zh_Hans: x 轴
|
||||
human_description:
|
||||
en_US: X axis for linear chart
|
||||
zh_Hans: 线性图表的 x 轴
|
||||
llm_description: x axis for linear chart, x axis should be a string contains a list of texts like "a;b;c;1;2" in order to match the data
|
||||
form: llm
|
||||
46
api/core/tools/provider/builtin/chart/tools/pie.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
import matplotlib.pyplot as plt
|
||||
import io
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
class PieChartTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_paramters: Dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
data = tool_paramters.get('data', '')
|
||||
if not data:
|
||||
return self.create_text_message('Please input data')
|
||||
data = data.split(';')
|
||||
categories = tool_paramters.get('categories', None) or None
|
||||
|
||||
# if all data is int, convert to int
|
||||
if all([i.isdigit() for i in data]):
|
||||
data = [int(i) for i in data]
|
||||
else:
|
||||
data = [float(i) for i in data]
|
||||
|
||||
flg, ax = plt.subplots()
|
||||
|
||||
if categories:
|
||||
categories = categories.split(';')
|
||||
if len(categories) != len(data):
|
||||
categories = None
|
||||
|
||||
if categories:
|
||||
ax.pie(data, labels=categories)
|
||||
else:
|
||||
ax.pie(data)
|
||||
|
||||
buf = io.BytesIO()
|
||||
flg.savefig(buf, format='png')
|
||||
buf.seek(0)
|
||||
plt.close(flg)
|
||||
|
||||
return [
|
||||
self.create_text_message('the pie chart is saved as an image.'),
|
||||
self.create_blob_message(blob=buf.read(),
|
||||
meta={'mime_type': 'image/png'})
|
||||
]
|
||||
35
api/core/tools/provider/builtin/chart/tools/pie.yaml
Normal file
@@ -0,0 +1,35 @@
|
||||
identity:
|
||||
name: pie_chart
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Pie Chart
|
||||
zh_Hans: 饼图
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: Pie chart
|
||||
zh_Hans: 饼图
|
||||
llm: generate a pie chart with input data
|
||||
parameters:
|
||||
- name: data
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: data
|
||||
zh_Hans: 数据
|
||||
human_description:
|
||||
en_US: data for generating pie chart
|
||||
zh_Hans: 用于生成饼图的数据
|
||||
llm_description: data for generating pie chart, data should be a string contains a list of numbers like "1;2;3;4;5"
|
||||
form: llm
|
||||
- name: categories
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Categories
|
||||
zh_Hans: 分类
|
||||
human_description:
|
||||
en_US: Categories for pie chart
|
||||
zh_Hans: 饼图的分类
|
||||
llm_description: categories for pie chart, categories should be a string contains a list of texts like "a;b;c;1;2" in order to match the data, each category should be split by ";"
|
||||
form: llm
|
||||
0
api/core/tools/provider/builtin/dalle/__init__.py
Normal file
BIN
api/core/tools/provider/builtin/dalle/_assets/icon.png
Normal file
|
After Width: | Height: | Size: 153 KiB |
23
api/core/tools/provider/builtin/dalle/dalle.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.provider.builtin.dalle.tools.dalle2 import DallE2Tool
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
class DALLEProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
|
||||
try:
|
||||
DallE2Tool().fork_tool_runtime(
|
||||
meta={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_paramters={
|
||||
"prompt": "cute girl, blue eyes, white hair, anime style",
|
||||
"size": "small",
|
||||
"n": 1
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
47
api/core/tools/provider/builtin/dalle/dalle.yaml
Normal file
@@ -0,0 +1,47 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: dalle
|
||||
label:
|
||||
en_US: DALL-E
|
||||
zh_Hans: DALL-E 绘画
|
||||
description:
|
||||
en_US: DALL-E art
|
||||
zh_Hans: DALL-E 绘画
|
||||
icon: icon.png
|
||||
credentails_for_provider:
|
||||
openai_api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: OpenAI API key
|
||||
zh_Hans: OpenAI API key
|
||||
help:
|
||||
en_US: Please input your OpenAI API key
|
||||
zh_Hans: 请输入你的 OpenAI API key
|
||||
placeholder:
|
||||
en_US: Please input your OpenAI API key
|
||||
zh_Hans: 请输入你的 OpenAI API key
|
||||
openai_organizaion_id:
|
||||
type: text-input
|
||||
required: false
|
||||
label:
|
||||
en_US: OpenAI organization ID
|
||||
zh_Hans: OpenAI organization ID
|
||||
help:
|
||||
en_US: Please input your OpenAI organization ID
|
||||
zh_Hans: 请输入你的 OpenAI organization ID
|
||||
placeholder:
|
||||
en_US: Please input your OpenAI organization ID
|
||||
zh_Hans: 请输入你的 OpenAI organization ID
|
||||
openai_base_url:
|
||||
type: text-input
|
||||
required: false
|
||||
label:
|
||||
en_US: OpenAI base URL
|
||||
zh_Hans: OpenAI base URL
|
||||
help:
|
||||
en_US: Please input your OpenAI base URL
|
||||
zh_Hans: 请输入你的 OpenAI base URL
|
||||
placeholder:
|
||||
en_US: Please input your OpenAI base URL
|
||||
zh_Hans: 请输入你的 OpenAI base URL
|
||||
66
api/core/tools/provider/builtin/dalle/tools/dalle2.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from typing import Any, Dict, List, Union
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
from base64 import b64decode
|
||||
from os.path import join
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
class DallE2Tool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_paramters: Dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
openai_organization = self.runtime.credentials.get('openai_organizaion_id', None)
|
||||
if not openai_organization:
|
||||
openai_organization = None
|
||||
openai_base_url = self.runtime.credentials.get('openai_base_url', None)
|
||||
if not openai_base_url:
|
||||
openai_base_url = None
|
||||
else:
|
||||
openai_base_url = join(openai_base_url, 'v1')
|
||||
|
||||
client = OpenAI(
|
||||
api_key=self.runtime.credentials['openai_api_key'],
|
||||
base_url=openai_base_url,
|
||||
organization=openai_organization
|
||||
)
|
||||
|
||||
SIZE_MAPPING = {
|
||||
'small': '256x256',
|
||||
'medium': '512x512',
|
||||
'large': '1024x1024',
|
||||
}
|
||||
|
||||
# prompt
|
||||
prompt = tool_paramters.get('prompt', '')
|
||||
if not prompt:
|
||||
return self.create_text_message('Please input prompt')
|
||||
|
||||
# get size
|
||||
size = SIZE_MAPPING[tool_paramters.get('size', 'large')]
|
||||
|
||||
# get n
|
||||
n = tool_paramters.get('n', 1)
|
||||
|
||||
# call openapi dalle2
|
||||
response = client.images.generate(
|
||||
prompt=prompt,
|
||||
model='dall-e-2',
|
||||
size=size,
|
||||
n=n,
|
||||
response_format='b64_json'
|
||||
)
|
||||
|
||||
result = []
|
||||
|
||||
for image in response.data:
|
||||
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
|
||||
meta={ 'mime_type': 'image/png' },
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value))
|
||||
|
||||
return result
|
||||
63
api/core/tools/provider/builtin/dalle/tools/dalle2.yaml
Normal file
@@ -0,0 +1,63 @@
|
||||
identity:
|
||||
name: dalle2
|
||||
author: Dify
|
||||
label:
|
||||
en_US: DALL-E 2
|
||||
zh_Hans: DALL-E 2 绘画
|
||||
description:
|
||||
en_US: DALL-E 2 is a powerful drawing tool that can draw the image you want based on your prompt
|
||||
zh_Hans: DALL-E 2 是一个强大的绘画工具,它可以根据您的提示词绘制出您想要的图像
|
||||
description:
|
||||
human:
|
||||
en_US: DALL-E is a text to image tool
|
||||
zh_Hans: DALL-E 是一个文本到图像的工具
|
||||
llm: DALL-E is a tool used to generate images from text
|
||||
parameters:
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Prompt
|
||||
zh_Hans: 提示词
|
||||
human_description:
|
||||
en_US: Image prompt, you can check the official documentation of DallE 2
|
||||
zh_Hans: 图像提示词,您可以查看DallE 2 的官方文档
|
||||
llm_description: Image prompt of DallE 2, you should describe the image you want to generate as a list of words as possible as detailed
|
||||
form: llm
|
||||
- name: size
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: used for selecting the image size
|
||||
zh_Hans: 用于选择图像大小
|
||||
label:
|
||||
en_US: Image size
|
||||
zh_Hans: 图像大小
|
||||
form: form
|
||||
options:
|
||||
- value: small
|
||||
label:
|
||||
en_US: Small(256x256)
|
||||
zh_Hans: 小(256x256)
|
||||
- value: medium
|
||||
label:
|
||||
en_US: Medium(512x512)
|
||||
zh_Hans: 中(512x512)
|
||||
- value: large
|
||||
label:
|
||||
en_US: Large(1024x1024)
|
||||
zh_Hans: 大(1024x1024)
|
||||
default: large
|
||||
- name: n
|
||||
type: number
|
||||
required: true
|
||||
human_description:
|
||||
en_US: used for selecting the number of images
|
||||
zh_Hans: 用于选择图像数量
|
||||
label:
|
||||
en_US: Number of images
|
||||
zh_Hans: 图像数量
|
||||
form: form
|
||||
default: 1
|
||||
min: 1
|
||||
max: 10
|
||||
74
api/core/tools/provider/builtin/dalle/tools/dalle3.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from typing import Any, Dict, List, Union
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
from base64 import b64decode
|
||||
from os.path import join
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
class DallE3Tool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_paramters: Dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
openai_organization = self.runtime.credentials.get('openai_organizaion_id', None)
|
||||
if not openai_organization:
|
||||
openai_organization = None
|
||||
openai_base_url = self.runtime.credentials.get('openai_base_url', None)
|
||||
if not openai_base_url:
|
||||
openai_base_url = None
|
||||
else:
|
||||
openai_base_url = join(openai_base_url, 'v1')
|
||||
|
||||
client = OpenAI(
|
||||
api_key=self.runtime.credentials['openai_api_key'],
|
||||
base_url=openai_base_url,
|
||||
organization=openai_organization
|
||||
)
|
||||
|
||||
SIZE_MAPPING = {
|
||||
'square': '1024x1024',
|
||||
'vertical': '1024x1792',
|
||||
'horizontal': '1792x1024',
|
||||
}
|
||||
|
||||
# prompt
|
||||
prompt = tool_paramters.get('prompt', '')
|
||||
if not prompt:
|
||||
return self.create_text_message('Please input prompt')
|
||||
# get size
|
||||
size = SIZE_MAPPING[tool_paramters.get('size', 'square')]
|
||||
# get n
|
||||
n = tool_paramters.get('n', 1)
|
||||
# get quality
|
||||
quality = tool_paramters.get('quality', 'standard')
|
||||
if quality not in ['standard', 'hd']:
|
||||
return self.create_text_message('Invalid quality')
|
||||
# get style
|
||||
style = tool_paramters.get('style', 'vivid')
|
||||
if style not in ['natural', 'vivid']:
|
||||
return self.create_text_message('Invalid style')
|
||||
|
||||
# call openapi dalle3
|
||||
response = client.images.generate(
|
||||
prompt=prompt,
|
||||
model='dall-e-3',
|
||||
size=size,
|
||||
n=n,
|
||||
style=style,
|
||||
quality=quality,
|
||||
response_format='b64_json'
|
||||
)
|
||||
|
||||
result = []
|
||||
|
||||
for image in response.data:
|
||||
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
|
||||
meta={ 'mime_type': 'image/png' },
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value))
|
||||
|
||||
return result
|
||||
103
api/core/tools/provider/builtin/dalle/tools/dalle3.yaml
Normal file
@@ -0,0 +1,103 @@
|
||||
identity:
|
||||
name: dalle3
|
||||
author: Dify
|
||||
label:
|
||||
en_US: DALL-E 3
|
||||
zh_Hans: DALL-E 3 绘画
|
||||
description:
|
||||
en_US: DALL-E 3 is a powerful drawing tool that can draw the image you want based on your prompt, compared to DallE 2, DallE 3 has stronger drawing ability, but it will consume more resources
|
||||
zh_Hans: DALL-E 3 是一个强大的绘画工具,它可以根据您的提示词绘制出您想要的图像,相比于DallE 2, DallE 3拥有更强的绘画能力,但会消耗更多的资源
|
||||
description:
|
||||
human:
|
||||
en_US: DALL-E is a text to image tool
|
||||
zh_Hans: DALL-E 是一个文本到图像的工具
|
||||
llm: DALL-E is a tool used to generate images from text
|
||||
parameters:
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Prompt
|
||||
zh_Hans: 提示词
|
||||
human_description:
|
||||
en_US: Image prompt, you can check the official documentation of DallE 3
|
||||
zh_Hans: 图像提示词,您可以查看DallE 3 的官方文档
|
||||
llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed
|
||||
form: llm
|
||||
- name: size
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the image size
|
||||
zh_Hans: 选择图像大小
|
||||
label:
|
||||
en_US: Image size
|
||||
zh_Hans: 图像大小
|
||||
form: form
|
||||
options:
|
||||
- value: square
|
||||
label:
|
||||
en_US: Squre(1024x1024)
|
||||
zh_Hans: 方(1024x1024)
|
||||
- value: vertical
|
||||
label:
|
||||
en_US: Vertical(1024x1792)
|
||||
zh_Hans: 竖屏(1024x1792)
|
||||
- value: horizontal
|
||||
label:
|
||||
en_US: Horizontal(1792x1024)
|
||||
zh_Hans: 横屏(1792x1024)
|
||||
default: square
|
||||
- name: n
|
||||
type: number
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the number of images
|
||||
zh_Hans: 选择图像数量
|
||||
label:
|
||||
en_US: Number of images
|
||||
zh_Hans: 图像数量
|
||||
form: form
|
||||
min: 1
|
||||
max: 1
|
||||
default: 1
|
||||
- name: quality
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the image quality
|
||||
zh_Hans: 选择图像质量
|
||||
label:
|
||||
en_US: Image quality
|
||||
zh_Hans: 图像质量
|
||||
form: form
|
||||
options:
|
||||
- value: standard
|
||||
label:
|
||||
en_US: Standard
|
||||
zh_Hans: 标准
|
||||
- value: hd
|
||||
label:
|
||||
en_US: HD
|
||||
zh_Hans: 高清
|
||||
default: standard
|
||||
- name: style
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the image style
|
||||
zh_Hans: 选择图像风格
|
||||
label:
|
||||
en_US: Image style
|
||||
zh_Hans: 图像风格
|
||||
form: form
|
||||
options:
|
||||
- value: vivid
|
||||
label:
|
||||
en_US: Vivid
|
||||
zh_Hans: 生动
|
||||
- value: natural
|
||||
label:
|
||||
en_US: Natural
|
||||
zh_Hans: 自然
|
||||
default: vivid
|
||||
6
api/core/tools/provider/builtin/google/_assets/icon.svg
Normal file
@@ -0,0 +1,6 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="25" viewBox="0 0 24 25" fill="none">
|
||||
<path d="M22.501 12.7332C22.501 11.8699 22.4296 11.2399 22.2748 10.5865H12.2153V14.4832H18.12C18.001 15.4515 17.3582 16.9099 15.9296 17.8898L15.9096 18.0203L19.0902 20.435L19.3106 20.4565C21.3343 18.6249 22.501 15.9298 22.501 12.7332Z" fill="#4285F4"/>
|
||||
<path d="M12.214 23C15.1068 23 17.5353 22.0666 19.3092 20.4567L15.9282 17.8899C15.0235 18.5083 13.8092 18.9399 12.214 18.9399C9.38069 18.9399 6.97596 17.1083 6.11874 14.5766L5.99309 14.5871L2.68583 17.0954L2.64258 17.2132C4.40446 20.6433 8.0235 23 12.214 23Z" fill="#34A853"/>
|
||||
<path d="M6.12046 14.5766C5.89428 13.9233 5.76337 13.2233 5.76337 12.5C5.76337 11.7766 5.89428 11.0766 6.10856 10.4233L6.10257 10.2841L2.75386 7.7355L2.64429 7.78658C1.91814 9.20993 1.50146 10.8083 1.50146 12.5C1.50146 14.1916 1.91814 15.7899 2.64429 17.2132L6.12046 14.5766Z" fill="#FBBC05"/>
|
||||
<path d="M12.2141 6.05997C14.2259 6.05997 15.583 6.91163 16.3569 7.62335L19.3807 4.73C17.5236 3.03834 15.1069 2 12.2141 2C8.02353 2 4.40447 4.35665 2.64258 7.78662L6.10686 10.4233C6.97598 7.89166 9.38073 6.05997 12.2141 6.05997Z" fill="#EB4335"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.2 KiB |
23
api/core/tools/provider/builtin/google/google.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from core.tools.provider.builtin.google.tools.google_search import GoogleSearchTool
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
class GoogleProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
|
||||
try:
|
||||
GoogleSearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_paramters={
|
||||
"query": "test",
|
||||
"result_type": "link"
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
24
api/core/tools/provider/builtin/google/google.yaml
Normal file
@@ -0,0 +1,24 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: google
|
||||
label:
|
||||
en_US: Google
|
||||
zh_Hans: Google
|
||||
description:
|
||||
en_US: Google
|
||||
zh_Hans: GoogleSearch
|
||||
icon: icon.svg
|
||||
credentails_for_provider:
|
||||
serpapi_api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: SerpApi API key
|
||||
zh_Hans: SerpApi API key
|
||||
placeholder:
|
||||
en_US: Please input your SerpApi API key
|
||||
zh_Hans: 请输入你的 SerpApi API key
|
||||
help:
|
||||
en_US: Get your SerpApi API key from SerpApi
|
||||
zh_Hans: 从 SerpApi 获取您的 SerpApi API key
|
||||
url: https://serpapi.com/manage-api-key
|
||||
163
api/core/tools/provider/builtin/google/tools/google_search.py
Normal file
@@ -0,0 +1,163 @@
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from serpapi import GoogleSearch
|
||||
|
||||
class HiddenPrints:
|
||||
"""Context manager to hide prints."""
|
||||
|
||||
def __enter__(self) -> None:
|
||||
"""Open file to pipe stdout to."""
|
||||
self._original_stdout = sys.stdout
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
def __exit__(self, *_: Any) -> None:
|
||||
"""Close file that stdout was piped to."""
|
||||
sys.stdout.close()
|
||||
sys.stdout = self._original_stdout
|
||||
|
||||
|
||||
class SerpAPI:
|
||||
"""
|
||||
SerpAPI tool provider.
|
||||
"""
|
||||
|
||||
search_engine: Any #: :meta private:
|
||||
serpapi_api_key: str = None
|
||||
|
||||
def __init__(self, api_key: str) -> None:
|
||||
"""Initialize SerpAPI tool provider."""
|
||||
self.serpapi_api_key = api_key
|
||||
self.search_engine = GoogleSearch
|
||||
|
||||
def run(self, query: str, **kwargs: Any) -> str:
|
||||
"""Run query through SerpAPI and parse result."""
|
||||
typ = kwargs.get("result_type", "text")
|
||||
return self._process_response(self.results(query), typ=typ)
|
||||
|
||||
def results(self, query: str) -> dict:
|
||||
"""Run query through SerpAPI and return the raw result."""
|
||||
params = self.get_params(query)
|
||||
with HiddenPrints():
|
||||
search = self.search_engine(params)
|
||||
res = search.get_dict()
|
||||
return res
|
||||
|
||||
def get_params(self, query: str) -> Dict[str, str]:
|
||||
"""Get parameters for SerpAPI."""
|
||||
_params = {
|
||||
"api_key": self.serpapi_api_key,
|
||||
"q": query,
|
||||
}
|
||||
params = {
|
||||
"engine": "google",
|
||||
"google_domain": "google.com",
|
||||
"gl": "us",
|
||||
"hl": "en",
|
||||
**_params
|
||||
}
|
||||
return params
|
||||
|
||||
@staticmethod
|
||||
def _process_response(res: dict, typ: str) -> str:
|
||||
"""Process response from SerpAPI."""
|
||||
if "error" in res.keys():
|
||||
raise ValueError(f"Got error from SerpAPI: {res['error']}")
|
||||
|
||||
if typ == "text":
|
||||
if "answer_box" in res.keys() and type(res["answer_box"]) == list:
|
||||
res["answer_box"] = res["answer_box"][0]
|
||||
if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
|
||||
toret = res["answer_box"]["answer"]
|
||||
elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
|
||||
toret = res["answer_box"]["snippet"]
|
||||
elif (
|
||||
"answer_box" in res.keys()
|
||||
and "snippet_highlighted_words" in res["answer_box"].keys()
|
||||
):
|
||||
toret = res["answer_box"]["snippet_highlighted_words"][0]
|
||||
elif (
|
||||
"sports_results" in res.keys()
|
||||
and "game_spotlight" in res["sports_results"].keys()
|
||||
):
|
||||
toret = res["sports_results"]["game_spotlight"]
|
||||
elif (
|
||||
"shopping_results" in res.keys()
|
||||
and "title" in res["shopping_results"][0].keys()
|
||||
):
|
||||
toret = res["shopping_results"][:3]
|
||||
elif (
|
||||
"knowledge_graph" in res.keys()
|
||||
and "description" in res["knowledge_graph"].keys()
|
||||
):
|
||||
toret = res["knowledge_graph"]["description"]
|
||||
elif "snippet" in res["organic_results"][0].keys():
|
||||
toret = res["organic_results"][0]["snippet"]
|
||||
elif "link" in res["organic_results"][0].keys():
|
||||
toret = res["organic_results"][0]["link"]
|
||||
elif (
|
||||
"images_results" in res.keys()
|
||||
and "thumbnail" in res["images_results"][0].keys()
|
||||
):
|
||||
thumbnails = [item["thumbnail"] for item in res["images_results"][:10]]
|
||||
toret = thumbnails
|
||||
else:
|
||||
toret = "No good search result found"
|
||||
elif typ == "link":
|
||||
if "knowledge_graph" in res.keys() and "title" in res["knowledge_graph"].keys() \
|
||||
and "description_link" in res["knowledge_graph"].keys():
|
||||
toret = res["knowledge_graph"]["description_link"]
|
||||
elif "knowledge_graph" in res.keys() and "see_results_about" in res["knowledge_graph"].keys() \
|
||||
and len(res["knowledge_graph"]["see_results_about"]) > 0:
|
||||
see_result_about = res["knowledge_graph"]["see_results_about"]
|
||||
toret = ""
|
||||
for item in see_result_about:
|
||||
if "name" not in item.keys() or "link" not in item.keys():
|
||||
continue
|
||||
toret += f"[{item['name']}]({item['link']})\n"
|
||||
elif "organic_results" in res.keys() and len(res["organic_results"]) > 0:
|
||||
organic_results = res["organic_results"]
|
||||
toret = ""
|
||||
for item in organic_results:
|
||||
if "title" not in item.keys() or "link" not in item.keys():
|
||||
continue
|
||||
toret += f"[{item['title']}]({item['link']})\n"
|
||||
elif "related_questions" in res.keys() and len(res["related_questions"]) > 0:
|
||||
related_questions = res["related_questions"]
|
||||
toret = ""
|
||||
for item in related_questions:
|
||||
if "question" not in item.keys() or "link" not in item.keys():
|
||||
continue
|
||||
toret += f"[{item['question']}]({item['link']})\n"
|
||||
elif "related_searches" in res.keys() and len(res["related_searches"]) > 0:
|
||||
related_searches = res["related_searches"]
|
||||
toret = ""
|
||||
for item in related_searches:
|
||||
if "query" not in item.keys() or "link" not in item.keys():
|
||||
continue
|
||||
toret += f"[{item['query']}]({item['link']})\n"
|
||||
else:
|
||||
toret = "No good search result found"
|
||||
return toret
|
||||
|
||||
class GoogleSearchTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_paramters: Dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
query = tool_paramters['query']
|
||||
result_type = tool_paramters['result_type']
|
||||
api_key = self.runtime.credentials['serpapi_api_key']
|
||||
result = SerpAPI(api_key).run(query, result_type=result_type)
|
||||
if result_type == 'text':
|
||||
return self.create_text_message(text=result)
|
||||
return self.create_link_message(link=result)
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
identity:
|
||||
name: google_search
|
||||
author: Dify
|
||||
label:
|
||||
en_US: GoogleSearch
|
||||
zh_Hans: 谷歌搜索
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for performing a Google SERP search and extracting snippets and webpages.Input should be a search query.
|
||||
zh_Hans: 一个用于执行 Google SERP 搜索并提取片段和网页的工具。输入应该是一个搜索查询。
|
||||
llm: A tool for performing a Google SERP search and extracting snippets and webpages.Input should be a search query.
|
||||
parameters:
|
||||
- name: query
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Query string
|
||||
zh_Hans: 查询语句
|
||||
human_description:
|
||||
en_US: used for searching
|
||||
zh_Hans: 用于搜索网页内容
|
||||
llm_description: key words for searching
|
||||
form: llm
|
||||
- name: result_type
|
||||
type: select
|
||||
required: true
|
||||
options:
|
||||
- value: text
|
||||
label:
|
||||
en_US: text
|
||||
zh_Hans: 文本
|
||||
- value: link
|
||||
label:
|
||||
en_US: link
|
||||
zh_Hans: 链接
|
||||
default: link
|
||||
label:
|
||||
en_US: Result type
|
||||
zh_Hans: 结果类型
|
||||
human_description:
|
||||
en_US: used for selecting the result type, text or link
|
||||
zh_Hans: 用于选择结果类型,使用文本还是链接进行展示
|
||||
form: form
|
||||
BIN
api/core/tools/provider/builtin/stablediffusion/_assets/icon.png
Normal file
|
After Width: | Height: | Size: 16 KiB |
@@ -0,0 +1,26 @@
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from core.tools.provider.builtin.stablediffusion.tools.stable_diffusion import StableDiffusionTool
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
class StableDiffusionProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
|
||||
try:
|
||||
StableDiffusionTool().fork_tool_runtime(
|
||||
meta={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_paramters={
|
||||
"prompt": "cat",
|
||||
"lora": "",
|
||||
"steps": 1,
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@@ -0,0 +1,29 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: stablediffusion
|
||||
label:
|
||||
en_US: Stable Diffusion
|
||||
zh_Hans: Stable Diffusion
|
||||
description:
|
||||
en_US: Stable Diffusion is a tool for generating images which can be deployed locally.
|
||||
zh_Hans: Stable Diffusion 是一个可以在本地部署的图片生成的工具。
|
||||
icon: icon.png
|
||||
credentails_for_provider:
|
||||
base_url:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_Hans: StableDiffusion服务器的Base URL
|
||||
placeholder:
|
||||
en_US: Please input your StableDiffusion server's Base URL
|
||||
zh_Hans: 请输入你的 StableDiffusion 服务器的 Base URL
|
||||
model:
|
||||
type: text-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Model
|
||||
zh_Hans: 模型
|
||||
placeholder:
|
||||
en_US: Please input your model
|
||||
zh_Hans: 请输入你的模型名称
|
||||
@@ -0,0 +1,244 @@
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter, ToolParamterOption
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
from httpx import post
|
||||
from os.path import join
|
||||
from base64 import b64decode, b64encode
|
||||
from PIL import Image
|
||||
|
||||
import json
|
||||
import io
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
DRAW_TEXT_OPTIONS = {
|
||||
"prompt": "",
|
||||
"negative_prompt": "",
|
||||
"seed": -1,
|
||||
"subseed": -1,
|
||||
"subseed_strength": 0,
|
||||
"seed_resize_from_h": -1,
|
||||
'sampler_index': 'DPM++ SDE Karras',
|
||||
"seed_resize_from_w": -1,
|
||||
"batch_size": 1,
|
||||
"n_iter": 1,
|
||||
"steps": 10,
|
||||
"cfg_scale": 7,
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"restore_faces": False,
|
||||
"do_not_save_samples": False,
|
||||
"do_not_save_grid": False,
|
||||
"eta": 0,
|
||||
"denoising_strength": 0,
|
||||
"s_min_uncond": 0,
|
||||
"s_churn": 0,
|
||||
"s_tmax": 0,
|
||||
"s_tmin": 0,
|
||||
"s_noise": 0,
|
||||
"override_settings": {},
|
||||
"override_settings_restore_afterwards": True,
|
||||
"refiner_switch_at": 0,
|
||||
"disable_extra_networks": False,
|
||||
"comments": {},
|
||||
"enable_hr": False,
|
||||
"firstphase_width": 0,
|
||||
"firstphase_height": 0,
|
||||
"hr_scale": 2,
|
||||
"hr_second_pass_steps": 0,
|
||||
"hr_resize_x": 0,
|
||||
"hr_resize_y": 0,
|
||||
"hr_prompt": "",
|
||||
"hr_negative_prompt": "",
|
||||
"script_args": [],
|
||||
"send_images": True,
|
||||
"save_images": False,
|
||||
"alwayson_scripts": {}
|
||||
}
|
||||
|
||||
class StableDiffusionTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
|
||||
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
# base url
|
||||
base_url = self.runtime.credentials.get('base_url', None)
|
||||
if not base_url:
|
||||
return self.create_text_message('Please input base_url')
|
||||
model = self.runtime.credentials.get('model', None)
|
||||
if not model:
|
||||
return self.create_text_message('Please input model')
|
||||
|
||||
# set model
|
||||
try:
|
||||
url = join(base_url, 'sdapi/v1/options')
|
||||
response = post(url, data=json.dumps({
|
||||
'sd_model_checkpoint': model
|
||||
}))
|
||||
if response.status_code != 200:
|
||||
raise ToolProviderCredentialValidationError('Failed to set model, please tell user to set model')
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError('Failed to set model, please tell user to set model')
|
||||
|
||||
|
||||
# prompt
|
||||
prompt = tool_paramters.get('prompt', '')
|
||||
if not prompt:
|
||||
return self.create_text_message('Please input prompt')
|
||||
|
||||
# get negative prompt
|
||||
negative_prompt = tool_paramters.get('negative_prompt', '')
|
||||
|
||||
# get size
|
||||
width = tool_paramters.get('width', 1024)
|
||||
height = tool_paramters.get('height', 1024)
|
||||
|
||||
# get steps
|
||||
steps = tool_paramters.get('steps', 1)
|
||||
|
||||
# get lora
|
||||
lora = tool_paramters.get('lora', '')
|
||||
|
||||
# get image id
|
||||
image_id = tool_paramters.get('image_id', '')
|
||||
if image_id.strip():
|
||||
image_variable = self.get_default_image_variable()
|
||||
if image_variable:
|
||||
image_binary = self.get_variable_file(image_variable.name)
|
||||
if not image_binary:
|
||||
return self.create_text_message('Image not found, please request user to generate image firstly.')
|
||||
|
||||
# convert image to RGB
|
||||
image = Image.open(io.BytesIO(image_binary))
|
||||
image = image.convert("RGB")
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format="PNG")
|
||||
image_binary = buffer.getvalue()
|
||||
image.close()
|
||||
|
||||
return self.img2img(base_url=base_url,
|
||||
lora=lora,
|
||||
image_binary=image_binary,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
steps=steps)
|
||||
|
||||
return self.text2img(base_url=base_url,
|
||||
lora=lora,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
steps=steps)
|
||||
|
||||
def img2img(self, base_url: str, lora: str, image_binary: bytes,
|
||||
prompt: str, negative_prompt: str,
|
||||
width: int, height: int, steps: int) \
|
||||
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
generate image
|
||||
"""
|
||||
draw_options = {
|
||||
"init_images": [b64encode(image_binary).decode('utf-8')],
|
||||
"prompt": "",
|
||||
"negative_prompt": negative_prompt,
|
||||
"denoising_strength": 0.9,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"cfg_scale": 7,
|
||||
"sampler_name": "Euler a",
|
||||
"restore_faces": False,
|
||||
"steps": steps,
|
||||
"script_args": ["outpainting mk2"]
|
||||
}
|
||||
|
||||
if lora:
|
||||
draw_options['prompt'] = f'{lora},{prompt}'
|
||||
else:
|
||||
draw_options['prompt'] = prompt
|
||||
|
||||
try:
|
||||
url = join(base_url, 'sdapi/v1/img2img')
|
||||
response = post(url, data=json.dumps(draw_options), timeout=120)
|
||||
if response.status_code != 200:
|
||||
return self.create_text_message('Failed to generate image')
|
||||
|
||||
image = response.json()['images'][0]
|
||||
|
||||
return self.create_blob_message(blob=b64decode(image),
|
||||
meta={ 'mime_type': 'image/png' },
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value)
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message('Failed to generate image')
|
||||
|
||||
def text2img(self, base_url: str, lora: str, prompt: str, negative_prompt: str, width: int, height: int, steps: int) \
|
||||
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
generate image
|
||||
"""
|
||||
# copy draw options
|
||||
draw_options = deepcopy(DRAW_TEXT_OPTIONS)
|
||||
|
||||
if lora:
|
||||
draw_options['prompt'] = f'{lora},{prompt}'
|
||||
|
||||
draw_options['width'] = width
|
||||
draw_options['height'] = height
|
||||
draw_options['steps'] = steps
|
||||
draw_options['negative_prompt'] = negative_prompt
|
||||
|
||||
try:
|
||||
url = join(base_url, 'sdapi/v1/txt2img')
|
||||
response = post(url, data=json.dumps(draw_options), timeout=120)
|
||||
if response.status_code != 200:
|
||||
return self.create_text_message('Failed to generate image')
|
||||
|
||||
image = response.json()['images'][0]
|
||||
|
||||
return self.create_blob_message(blob=b64decode(image),
|
||||
meta={ 'mime_type': 'image/png' },
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value)
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message('Failed to generate image')
|
||||
|
||||
|
||||
def get_runtime_parameters(self) -> List[ToolParamter]:
|
||||
parameters = [
|
||||
ToolParamter(name='prompt',
|
||||
label=I18nObject(en_US='Prompt', zh_Hans='Prompt'),
|
||||
human_description=I18nObject(
|
||||
en_US='Image prompt, you can check the official documentation of Stable Diffusion',
|
||||
zh_Hans='图像提示词,您可以查看 Stable Diffusion 的官方文档',
|
||||
),
|
||||
type=ToolParamter.ToolParameterType.STRING,
|
||||
form=ToolParamter.ToolParameterForm.LLM,
|
||||
llm_description='Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.',
|
||||
required=True),
|
||||
]
|
||||
if len(self.list_default_image_variables()) != 0:
|
||||
parameters.append(
|
||||
ToolParamter(name='image_id',
|
||||
label=I18nObject(en_US='image_id', zh_Hans='image_id'),
|
||||
human_description=I18nObject(
|
||||
en_US='Image id of the image you want to generate based on, if you want to generate image based on the default image, you can leave this field empty.',
|
||||
zh_Hans='您想要生成的图像的图像 ID,如果您想要基于默认图像生成图像,则可以将此字段留空。',
|
||||
),
|
||||
type=ToolParamter.ToolParameterType.STRING,
|
||||
form=ToolParamter.ToolParameterForm.LLM,
|
||||
llm_description='Image id of the original image, you can leave this field empty if you want to generate a new image.',
|
||||
required=True,
|
||||
options=[ToolParamterOption(
|
||||
value=i.name,
|
||||
label=I18nObject(en_US=i.name, zh_Hans=i.name)
|
||||
) for i in self.list_default_image_variables()])
|
||||
)
|
||||
|
||||
return parameters
|
||||
@@ -0,0 +1,77 @@
|
||||
identity:
|
||||
name: stable_diffusion
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Stable Diffusion WebUI
|
||||
zh_Hans: Stable Diffusion WebUI
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for generating images which can be deployed locally, you can use stable-diffusion-webui to deploy it.
|
||||
zh_Hans: 一个可以在本地部署的图片生成的工具,您可以使用 stable-diffusion-webui 来部署它。
|
||||
llm: draw the image you want based on your prompt.
|
||||
parameters:
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Prompt
|
||||
zh_Hans: 提示词
|
||||
human_description:
|
||||
en_US: Image prompt, you can check the official documentation of Stable Diffusion
|
||||
zh_Hans: 图像提示词,您可以查看 Stable Diffusion 的官方文档
|
||||
llm_description: Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.
|
||||
form: llm
|
||||
- name: lora
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Lora
|
||||
zh_Hans: Lora
|
||||
human_description:
|
||||
en_US: Lora
|
||||
zh_Hans: Lora
|
||||
form: form
|
||||
- name: steps
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Steps
|
||||
zh_Hans: Steps
|
||||
human_description:
|
||||
en_US: Steps
|
||||
zh_Hans: Steps
|
||||
form: form
|
||||
default: 10
|
||||
- name: width
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Width
|
||||
zh_Hans: Width
|
||||
human_description:
|
||||
en_US: Width
|
||||
zh_Hans: Width
|
||||
form: form
|
||||
default: 1024
|
||||
- name: height
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Height
|
||||
zh_Hans: Height
|
||||
human_description:
|
||||
en_US: Height
|
||||
zh_Hans: Height
|
||||
form: form
|
||||
default: 1024
|
||||
- name: negative_prompt
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Negative prompt
|
||||
zh_Hans: Negative prompt
|
||||
human_description:
|
||||
en_US: Negative prompt
|
||||
zh_Hans: Negative prompt
|
||||
form: form
|
||||
default: bad art, ugly, deformed, watermark, duplicated, discontinuous lines
|
||||
3
api/core/tools/provider/builtin/time/_assets/icon.svg
Normal file
@@ -0,0 +1,3 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M0.666992 8.00008C0.666992 3.94999 3.95024 0.666748 8.00033 0.666748C12.0504 0.666748 15.3337 3.94999 15.3337 8.00008C15.3337 12.0502 12.0504 15.3334 8.00033 15.3334C3.95024 15.3334 0.666992 12.0502 0.666992 8.00008ZM8.66699 4.00008C8.66699 3.63189 8.36852 3.33341 8.00033 3.33341C7.63213 3.33341 7.33366 3.63189 7.33366 4.00008V8.00008C7.33366 8.2526 7.47633 8.48344 7.70218 8.59637L10.3688 9.9297C10.6982 10.0944 11.0986 9.96088 11.2633 9.63156C11.4279 9.30224 11.2945 8.90179 10.9651 8.73713L8.66699 7.58806V4.00008Z" fill="#EC4A0A"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 691 B |
16
api/core/tools/provider/builtin/time/time.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from core.tools.provider.builtin.time.tools.current_time import CurrentTimeTool
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
class WikiPediaProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
|
||||
try:
|
||||
CurrentTimeTool().invoke(
|
||||
user_id='',
|
||||
tool_paramters={},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
11
api/core/tools/provider/builtin/time/time.yaml
Normal file
@@ -0,0 +1,11 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: time
|
||||
label:
|
||||
en_US: CurrentTime
|
||||
zh_Hans: 时间
|
||||
description:
|
||||
en_US: A tool for getting the current time.
|
||||
zh_Hans: 一个用于获取当前时间的工具。
|
||||
icon: icon.svg
|
||||
credentails_for_provider:
|
||||
17
api/core/tools/provider/builtin/time/tools/current_time.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
class CurrentTimeTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_paramters: Dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
return self.create_text_message(f'{datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")}')
|
||||
|
||||
12
api/core/tools/provider/builtin/time/tools/current_time.yaml
Normal file
@@ -0,0 +1,12 @@
|
||||
identity:
|
||||
name: current_time
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Current Time
|
||||
zh_Hans: 获取当前时间
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for getting the current time.
|
||||
zh_Hans: 一个用于获取当前时间的工具。
|
||||
llm: A tool for getting the current time.
|
||||
parameters:
|
||||
BIN
api/core/tools/provider/builtin/vectorizer/_assets/icon.png
Normal file
|
After Width: | Height: | Size: 1.8 KiB |
@@ -0,0 +1 @@
|
||||
VECTORIZER_ICON_PNG = 'iVBORw0KGgoAAAANSUhEUgAAAGAAAABgCAYAAADimHc4AAAACXBIWXMAACxLAAAsSwGlPZapAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAboSURBVHgB7Z09bBxFFMffRoAvcQqbguBUxu4wCUikMCZ0TmQK4NLQJCJOlQIkokgEGhQ7NCFIKEhQuIqNnIaGMxRY2GVwmlggDHS+pIHELmIXMTEULPP3eeXz7e7szO7MvE1ufpKV03nuNn7/mfcxH7tEHo/H42lXgqwG1bGw65+/aTQM6K0gpJdCoi7ypCIMui5s9Qv9R1OVTqrVxoL1jPbpvH4hrIp/rnmj5+YOhTQ++1kwmdZgT9ovRi6EF4Xhv/XGL0Sv6OLXYMu0BokjYOSDcBQfJI8xhKFP/HAlqCW8v5vqubBr8yn6maCexxiIDR376LnWmBBzQZtPEvx+L3mMAleOZKb1/XgM2EOnyWMFZJKt78UEQKpJHisk2TYmgM967JFk2z3kYcULwIwXgBkvADNeAGa8AMw8Qcwc6N55/eAh0cYmGaOzQtR/kOhQX+M6+/c23r+3RlT/i2ipTrSyRqw4F+CwMMbgANHQwG7jRywLw/wqDDNzI79xYPjqa2L262jjtYzaT0QT3xEbsck4MXUakgWOvUx08liy0ZPYEKNhel4Y6AZpgR7/8Tvq1wEQ+sMJN6Nh9kqwy+bWYwAM8elZovNv6xmlU7iLs280RNO9ls51os/h/8eBVQEig8Dt5OXUsNrno2tluZw0cI3qUXKONQHy9sYkVHqnjntLA2LnFTAv1gSA+zBhfIDvkfVO/B4xRgWZn4fbe2WAnGJFAAxn03+I7PtUXdzE90Sjl4ne+6L4d5nCigAyYyHPn7tFdPN30uJwX/qI6jtISkQZFVLdhd9SrtNPTrFSB6QZBAaYntsptpAyfvk+KYOCamVR/XrNtLqepduiFnkh3g4iIw6YLAhlOJmKwB9zaarhApr/MPREjAZVisSU1s/KYsGzhmKXClYEWLm/8xpV7btXhcv5I7lt2vtJFA3q/T07r1HopdG5l5xhxQVdn28YFn8kBJCBOZmiPHio1m5QuJzlu9ntXApgZwSsNYJslvGjtjrfm8Sq4neceFUtz3dZCzwW09Gqo2hreuPN7HZRnNqa1BP1x8lhczVNK+zT0TqkjYAF4e7Okxoo2PZX5K4IrhNpb/P8FTK2S1+TcUq1HpBFmquJYo1qEYU6RVarJE0c2ooL7C5IRwBZ5nJ9joyRtk5hA3YBdHqWzG1gBKgE/bzMaK5LqMIugKrbUDHu59/YWVRBsWhrsYZdANV5HBUXYGNlC9dFBW8LdgH6FQVYUnQvkQgm3NH8YuO7bM4LsWZBfT3qRY9OxRyJgJRz+Ij+FDPEQ1C3GVMiWAVQ7f31u/ncytxi4wdZTbRGgdcHnpYLD/FcwSrAoOKizfKfVAiIF4kBMPK+Opfe1iWsMUB1BJh2BRgBabSNAOiFqkXYbcNFUF9P+u82FGdWTcEmgGrvh0FUppB1kC073muXEaDq/21kIjLxV9tFAC7/n5X6tkUM0PH/dcP+P0v41fvkFBYBVHs/MD0CDmVsOzEdb7JgEYDT/8uq4rpj44NSjwDTc/CyzV1gxbH7Ac4F0PH/S4ZHAOaFZLiY+2nFuQA6/t9kQMTCz1CG66tbWvWS4VwAVf9vugAbel6efqrsYbKBcwFeVNz8ajobyTppw2F84FQAnfl/kwER6wJZcWdBc7e2KZwKoOP/TVakWb0f7md+kVhwOwI0BDCFyq42rt4PSiuAiRGAEXdK4ZQlV+8HTgVwefwHvR7nhbOA0FwBGDgTIM/Z3SLXUj2hOW1wR10eSrs7Ou9eTB3jo/dzuh/gTABdn35c8dhpM3BxOmeTuXs/cDoCdDY4qe7l32pbaZxL1jF+GXo/cLotBcWVTiZU3T7RMn8rHiijW9FgauP4Ef1TLdhHWgacCgAj6tYCqGKjU/DNbqxIkMYZNs7MpxmnLuhmwYJna1dbdzHjY42hDL4/wqkA6HWuDkAngRH0iYVjRkVwnoZO/0gsuLwpkw7OBcAtwlwvfESHxctmfMBSiOG0oStj4HCF7T3+RWARwIU7QK/HbWlqls52mYJtezqMj3v34C5VOveFy8Ll4QoTsJ8Txp0RsW8/Os2im2LCtSC1RIqLw3RldTVplOKkPEYDhMAPqttnune2rzTv5Y+WKdEem2ixkWqZYSeDSUp3qwIYNOrR7cBjcbOORxkvADNeAGa8AMx4AZjxAjATf5Ab0Tp5rJBk2/iD3PAwYo8Vkmyb9CjDGfLYIaCp1rdiAnT8S5PeDVkgoDuVCsWeJxwToHZ163m3Z8hjloDGk54vn5gFbT/5eZw8phifvZz8XPlA9qmRj8JRCumi+OkljzbbrvxM0qPMm9rIqY6FXZubVBUinMbzcP3jbuXA6Mh2kMx07KPJJLfj8Xg8Hg/4H+KfFYb2WM4MAAAAAElFTkSuQmCC'
|
||||
@@ -0,0 +1,74 @@
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter
|
||||
from core.tools.provider.builtin.vectorizer.tools.test_data import VECTORIZER_ICON_PNG
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
from httpx import post
|
||||
from base64 import b64decode
|
||||
|
||||
class VectorizerTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
|
||||
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
api_key_name = self.runtime.credentials.get('api_key_name', None)
|
||||
api_key_value = self.runtime.credentials.get('api_key_value', None)
|
||||
mode = tool_paramters.get('mode', 'test')
|
||||
if mode == 'production':
|
||||
mode = 'preview'
|
||||
|
||||
if not api_key_name or not api_key_value:
|
||||
raise ToolProviderCredentialValidationError('Please input api key name and value')
|
||||
|
||||
image_id = tool_paramters.get('image_id', '')
|
||||
if not image_id:
|
||||
return self.create_text_message('Please input image id')
|
||||
|
||||
if image_id.startswith('__test_'):
|
||||
image_binary = b64decode(VECTORIZER_ICON_PNG)
|
||||
else:
|
||||
image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)
|
||||
if not image_binary:
|
||||
return self.create_text_message('Image not found, please request user to generate image firstly.')
|
||||
|
||||
response = post(
|
||||
'https://vectorizer.ai/api/v1/vectorize',
|
||||
files={
|
||||
'image': image_binary
|
||||
},
|
||||
data={
|
||||
'mode': mode
|
||||
} if mode == 'test' else {},
|
||||
auth=(api_key_name, api_key_value),
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(response.text)
|
||||
|
||||
return [
|
||||
self.create_text_message('the vectorized svg is saved as an image.'),
|
||||
self.create_blob_message(blob=response.content,
|
||||
meta={'mime_type': 'image/svg+xml'})
|
||||
]
|
||||
|
||||
def get_runtime_parameters(self) -> List[ToolParamter]:
|
||||
"""
|
||||
override the runtime parameters
|
||||
"""
|
||||
return [
|
||||
ToolParamter.get_simple_instance(
|
||||
name='image_id',
|
||||
llm_description=f'the image id that you want to vectorize, \
|
||||
and the image id should be specified in \
|
||||
{[i.name for i in self.list_default_image_variables()]}',
|
||||
type=ToolParamter.ToolParameterType.SELECT,
|
||||
required=True,
|
||||
options=[i.name for i in self.list_default_image_variables()]
|
||||
)
|
||||
]
|
||||
|
||||
def is_tool_avaliable(self) -> bool:
|
||||
return len(self.list_default_image_variables()) > 0
|
||||
@@ -0,0 +1,32 @@
|
||||
identity:
|
||||
name: vectorizer
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Vectorizer.AI
|
||||
zh_Hans: Vectorizer.AI
|
||||
description:
|
||||
human:
|
||||
en_US: Convert your PNG and JPG images to SVG vectors quickly and easily. Fully automatically. Using AI.
|
||||
zh_Hans: 一个将 PNG 和 JPG 图像快速轻松地转换为 SVG 矢量图的工具。
|
||||
llm: A tool for converting images to SVG vectors. you should input the image id as the input of this tool. the image id can be got from parameters.
|
||||
parameters:
|
||||
- name: mode
|
||||
type: select
|
||||
required: true
|
||||
options:
|
||||
- value: production
|
||||
label:
|
||||
en_US: production
|
||||
zh_Hans: 生产模式
|
||||
- value: test
|
||||
label:
|
||||
en_US: test
|
||||
zh_Hans: 测试模式
|
||||
default: test
|
||||
label:
|
||||
en_US: Mode
|
||||
zh_Hans: 模式
|
||||
human_description:
|
||||
en_US: It is free to integrate with and test out the API in test mode, no subscription required.
|
||||
zh_Hans: 在测试模式下,可以免费测试API。
|
||||
form: form
|
||||
23
api/core/tools/provider/builtin/vectorizer/vectorizer.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
class VectorizerProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
|
||||
try:
|
||||
VectorizerTool().fork_tool_runtime(
|
||||
meta={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_paramters={
|
||||
"mode": "test",
|
||||
"image_id": "__test_123"
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
36
api/core/tools/provider/builtin/vectorizer/vectorizer.yaml
Normal file
@@ -0,0 +1,36 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: vectorizer
|
||||
label:
|
||||
en_US: Vectorizer.AI
|
||||
zh_Hans: Vectorizer.AI
|
||||
description:
|
||||
en_US: Convert your PNG and JPG images to SVG vectors quickly and easily. Fully automatically. Using AI.
|
||||
zh_Hans: 一个将 PNG 和 JPG 图像快速轻松地转换为 SVG 矢量图的工具。
|
||||
icon: icon.png
|
||||
credentails_for_provider:
|
||||
api_key_name:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Vectorizer.AI API Key name
|
||||
zh_Hans: Vectorizer.AI API Key name
|
||||
placeholder:
|
||||
en_US: Please input your Vectorizer.AI ApiKey name
|
||||
zh_Hans: 请输入你的 Vectorizer.AI ApiKey name
|
||||
help:
|
||||
en_US: Get your Vectorizer.AI API Key from Vectorizer.AI.
|
||||
zh_Hans: 从 Vectorizer.AI 获取您的 Vectorizer.AI API Key。
|
||||
url: https://vectorizer.ai/api
|
||||
api_key_value:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Vectorizer.AI API Key
|
||||
zh_Hans: Vectorizer.AI API Key
|
||||
placeholder:
|
||||
en_US: Please input your Vectorizer.AI ApiKey
|
||||
zh_Hans: 请输入你的 Vectorizer.AI ApiKey
|
||||
help:
|
||||
en_US: Get your Vectorizer.AI API Key from Vectorizer.AI.
|
||||
zh_Hans: 从 Vectorizer.AI 获取您的 Vectorizer.AI API Key。
|
||||
@@ -0,0 +1,3 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="17" viewBox="0 0 16 17" fill="none">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M2.6665 1.16667C1.56193 1.16667 0.666504 2.0621 0.666504 3.16667C0.666504 4.27124 1.56193 5.16667 2.6665 5.16667C2.79161 5.16667 2.91403 5.15519 3.03277 5.13321C2.3808 6.09319 1.99984 7.25211 1.99984 8.5C1.99984 9.7479 2.3808 10.9068 3.03277 11.8668C2.91403 11.8448 2.79161 11.8333 2.6665 11.8333C1.56193 11.8333 0.666504 12.7288 0.666504 13.8333C0.666504 14.9379 1.56193 15.8333 2.6665 15.8333C3.77107 15.8333 4.6665 14.9379 4.6665 13.8333C4.6665 13.7082 4.65502 13.5858 4.63304 13.4671C5.59302 14.119 6.75194 14.5 7.99984 14.5C9.24773 14.5 10.4066 14.119 11.3666 13.4671C11.3447 13.5858 11.3332 13.7082 11.3332 13.8333C11.3332 14.9379 12.2286 15.8333 13.3332 15.8333C14.4377 15.8333 15.3332 14.9379 15.3332 13.8333C15.3332 12.7288 14.4377 11.8333 13.3332 11.8333C13.2081 11.8333 13.0856 11.8448 12.9669 11.8668C13.6189 10.9068 13.9998 9.7479 13.9998 8.5C13.9998 7.25211 13.6189 6.09319 12.9669 5.13321C13.0856 5.15519 13.2081 5.16667 13.3332 5.16667C14.4377 5.16667 15.3332 4.27124 15.3332 3.16667C15.3332 2.0621 14.4377 1.16667 13.3332 1.16667C12.2286 1.16667 11.3332 2.0621 11.3332 3.16667C11.3332 3.29177 11.3447 3.41419 11.3666 3.53293C10.4066 2.88097 9.24773 2.50001 7.99984 2.50001C6.75194 2.50001 5.59302 2.88097 4.63304 3.53293C4.65502 3.41419 4.6665 3.29177 4.6665 3.16667C4.6665 2.0621 3.77107 1.16667 2.6665 1.16667ZM3.38043 7.83334C3.63081 6.08287 4.85262 4.64578 6.48223 4.08565C5.79223 5.22099 5.36488 6.50185 5.23815 7.83334H3.38043ZM6.48228 12.9144C4.85264 12.3543 3.63082 10.9172 3.38043 9.16667H5.23815C5.3649 10.4982 5.79226 11.779 6.48228 12.9144ZM12.6192 9.16667C12.3689 10.9168 11.1475 12.3537 9.5183 12.9141C10.2082 11.7788 10.6355 10.498 10.7622 9.16667H12.6192ZM9.51834 4.08596C11.1475 4.64631 12.3689 6.0832 12.6192 7.83334H10.7622C10.6355 6.50197 10.2082 5.22123 9.51834 4.08596ZM9.4218 7.83334C9.27457 6.52262 8.78381 5.27411 8.00019 4.2145C7.21658 5.27411 6.72582 6.52262 6.57859 7.83334H9.4218ZM6.5786 9.16667C6.72583 10.4774 7.21659 11.7259 8.00019 12.7855C8.7838 11.7259 9.27456 10.4774 9.42179 9.16667H6.5786Z" fill="#DD2590"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 2.2 KiB |
@@ -0,0 +1,28 @@
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
class WebscraperTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_paramters: Dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
try:
|
||||
url = tool_paramters.get('url', '')
|
||||
user_agent = tool_paramters.get('user_agent', '')
|
||||
if not url:
|
||||
return self.create_text_message('Please input url')
|
||||
|
||||
# get webpage
|
||||
result = self.get_url(url, user_agent=user_agent)
|
||||
|
||||
# summarize and return
|
||||
return self.create_text_message(self.summary(user_id=user_id, content=result))
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(str(e))
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
identity:
|
||||
name: webscraper
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Web Scraper
|
||||
zh_Hans: 网页爬虫
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for scraping webpages.
|
||||
zh_Hans: 一个用于爬取网页的工具。
|
||||
llm: A tool for scraping webpages. Input should be a URL.
|
||||
parameters:
|
||||
- name: url
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: URL
|
||||
zh_Hans: 网页链接
|
||||
human_description:
|
||||
en_US: used for linking to webpages
|
||||
zh_Hans: 用于链接到网页
|
||||
llm_description: url for scraping
|
||||
form: llm
|
||||
- name: user_agent
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: User Agent
|
||||
zh_Hans: User Agent
|
||||
human_description:
|
||||
en_US: used for identifying the browser.
|
||||
zh_Hans: 用于识别浏览器。
|
||||
form: form
|
||||
default: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/100.0.1000.0 Safari/537.36
|
||||
23
api/core/tools/provider/builtin/webscraper/webscraper.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from core.tools.provider.builtin.webscraper.tools.webscraper import WebscraperTool
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
class WebscraperProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
|
||||
try:
|
||||
WebscraperTool().fork_tool_runtime(
|
||||
meta={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_paramters={
|
||||
'url': 'https://www.google.com',
|
||||
'user_agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 '
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
11
api/core/tools/provider/builtin/webscraper/webscraper.yaml
Normal file
@@ -0,0 +1,11 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: webscraper
|
||||
label:
|
||||
en_US: WebScraper
|
||||
zh_Hans: 网页抓取
|
||||
description:
|
||||
en_US: Web Scrapper tool kit is used to scrape web
|
||||
zh_Hans: 一个用于抓取网页的工具。
|
||||
icon: icon.svg
|
||||
credentails_for_provider:
|
||||
@@ -0,0 +1,3 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="25" viewBox="0 0 24 25" fill="none">
|
||||
<path d="M22.3627 6.50009H18.3156H18.1783V6.63743V7.07751V7.21484H18.3156H18.5969C18.924 7.21484 19.2189 7.38272 19.386 7.66394C19.553 7.94516 19.5593 8.28448 19.4028 8.57169L14.9027 16.8317L12.8532 11.9459L14.7837 8.40336C15.1832 7.67026 15.95 7.21484 16.7849 7.21484H16.8761H17.0134V7.07751V6.63743V6.50009H16.8761H12.829H12.6917V6.63743V7.07751V7.21484H12.829H13.1102C13.4373 7.21484 13.7323 7.38272 13.8993 7.66394C14.0663 7.94516 14.0726 8.28448 13.9162 8.57169L12.5159 11.1419L11.268 8.16696C11.1776 7.95134 11.1999 7.71594 11.3294 7.52124C11.4589 7.32654 11.6673 7.21484 11.9011 7.21484H12.221H12.3583V7.07751V6.63743V6.50009H12.221H7.3808H7.24347V6.63743V7.07751V7.21484H7.3808H7.44737C8.40218 7.21484 9.25775 7.78379 9.62715 8.66426L11.471 13.0599L9.4161 16.8317L5.78141 8.16696C5.69095 7.95134 5.71334 7.71594 5.8428 7.52124C5.97227 7.32654 6.18065 7.21484 6.41449 7.21484H6.90603H7.04337V7.07751V6.63743V6.50009H6.90603H1.63734H1.5V6.63743V7.07751V7.21484H1.63734H1.96072C2.91554 7.21484 3.77116 7.78379 4.1405 8.66426L8.33049 18.6529C8.40379 18.8276 8.57372 18.9405 8.76347 18.9405C8.93762 18.9405 9.09139 18.849 9.17485 18.6958L9.72141 17.6928L11.8081 13.8635L13.8171 18.6528C13.8904 18.8275 14.0603 18.9404 14.2501 18.9404C14.4242 18.9404 14.578 18.849 14.6614 18.6958L15.208 17.6928L20.2703 8.40327C20.6698 7.67016 21.4366 7.21475 22.2715 7.21475H22.3627H22.5V7.07741V6.63734V6.5H22.3627V6.50009Z" fill="#222A30"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.5 KiB |
@@ -0,0 +1,37 @@
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from langchain import WikipediaAPIWrapper
|
||||
from langchain.tools import WikipediaQueryRun
|
||||
|
||||
class WikipediaInput(BaseModel):
|
||||
query: str = Field(..., description="search query.")
|
||||
|
||||
class WikiPediaSearchTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_paramters: Dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
query = tool_paramters.get('query', '')
|
||||
if not query:
|
||||
return self.create_text_message('Please input query')
|
||||
|
||||
tool = WikipediaQueryRun(
|
||||
name="wikipedia",
|
||||
api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
|
||||
args_schema=WikipediaInput
|
||||
)
|
||||
|
||||
result = tool.run(tool_input={
|
||||
'query': query
|
||||
})
|
||||
|
||||
return self.create_text_message(self.summary(user_id=user_id,content=result))
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
identity:
|
||||
name: wikipedia_search
|
||||
author: Dify
|
||||
label:
|
||||
en_US: WikipediaSearch
|
||||
zh_Hans: 维基百科搜索
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for performing a Wikipedia search and extracting snippets and webpages.
|
||||
zh_Hans: 一个用于执行维基百科搜索并提取片段和网页的工具。
|
||||
llm: A tool for performing a Wikipedia search and extracting snippets and webpages. Input should be a search query.
|
||||
parameters:
|
||||
- name: query
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Query string
|
||||
zh_Hans: 查询语句
|
||||
human_description:
|
||||
en_US: key words for searching
|
||||
zh_Hans: 查询关键词
|
||||
llm_description: key words for searching
|
||||
form: llm
|
||||
20
api/core/tools/provider/builtin/wikipedia/wikipedia.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from core.tools.provider.builtin.wikipedia.tools.wikipedia_search import WikiPediaSearchTool
|
||||
|
||||
class WikiPediaProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
WikiPediaSearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_paramters={
|
||||
"query": "misaka mikoto",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
11
api/core/tools/provider/builtin/wikipedia/wikipedia.yaml
Normal file
@@ -0,0 +1,11 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: wikipedia
|
||||
label:
|
||||
en_US: Wikipedia
|
||||
zh_Hans: 维基百科
|
||||
description:
|
||||
en_US: Wikipedia is a free online encyclopedia, created and edited by volunteers around the world.
|
||||
zh_Hans: 维基百科是一个由全世界的志愿者创建和编辑的免费在线百科全书。
|
||||
icon: icon.svg
|
||||
credentails_for_provider:
|
||||
@@ -0,0 +1,23 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="22" height="23" viewBox="0 0 22 23" fill="none">
|
||||
<path d="M21.4992 14.615L17.0326 15.5683L17.4865 20.0065L13.3037 18.2132L10.9994 22.0669L8.69549 18.2132L4.51225 20.0065L4.96656 15.5678L0.5 14.615L3.542 11.2832L0.5 7.9519L4.96656 6.99854L4.51225 2.55981L8.69549 4.35369L10.9994 0.5L13.3037 4.35369L17.4865 2.56031L17.0326 7.00401L21.4992 7.9519L18.4572 11.2832L21.4992 14.615Z" fill="#F16850"/>
|
||||
<path d="M10.9993 7.23111L8.69495 4.35315L4.51221 2.56026L7.00396 5.84084L10.9993 7.23111Z" fill="#FD694F"/>
|
||||
<path d="M4.96656 6.99847L0.5 7.95183L3.542 11.2831L7.11734 9.9838L4.96656 6.99847Z" fill="#FF3413"/>
|
||||
<path d="M7.00346 5.84037L4.51221 2.55978L4.96602 6.99851L7.11729 9.98384L7.00346 5.84037Z" fill="#DC1D23"/>
|
||||
<path d="M13.3031 4.35369L10.9987 0.5L8.69434 4.35369L10.9987 7.23116L13.3031 4.35369Z" fill="#FF9281"/>
|
||||
<path d="M18.4577 11.2831L21.4997 7.95183L17.0331 6.99847L14.8818 9.9838L18.4577 11.2831Z" fill="#FF8B79"/>
|
||||
<path d="M14.8823 9.98384L17.0331 6.99851L17.4929 2.55978L14.9957 5.84037L14.8818 9.98384H14.8823Z" fill="#FD694F"/>
|
||||
<path d="M14.9954 5.84034L17.4926 2.55975L13.3044 4.35364L11 7.23111L14.9954 5.84034Z" fill="#EF5240"/>
|
||||
<path d="M17.47 13.2694L21.4997 14.6149L18.4577 11.2831L14.8818 9.98383L17.47 13.2694Z" fill="#FF482C"/>
|
||||
<path d="M7.11783 9.98383L3.542 11.2831L0.5 14.6149L4.52965 13.2699L7.11783 9.98383Z" fill="#EC2101"/>
|
||||
<path d="M11 17.8612V22.0664L13.3044 18.2132L13.4008 14.439L11 17.8612Z" fill="#D21C22"/>
|
||||
<path d="M17.4703 13.2693L13.4009 14.4389L17.0334 15.5682L21.4999 14.6149L17.4703 13.2693Z" fill="#C90901"/>
|
||||
<path d="M13.3042 18.2132L17.4874 20.0065L17.0331 15.5683L13.4011 14.439L13.3042 18.2132Z" fill="#EC2101"/>
|
||||
<path d="M4.52965 13.2693L0.5 14.6154L4.96656 15.5632L8.59906 14.4394L4.52965 13.2703V13.2693Z" fill="#B6171E"/>
|
||||
<path d="M8.59912 14.439L8.69555 18.2132L10.9999 22.0669V17.8612L8.59912 14.439Z" fill="#B4151B"/>
|
||||
<path d="M4.96602 15.5623L4.51221 20.006L8.69495 18.2131L8.59852 14.439L4.96602 15.5623Z" fill="#D21C22"/>
|
||||
<path d="M14.882 9.98384L14.9954 5.84036L11 7.23113V11.2608L14.882 9.98384Z" fill="#E63320"/>
|
||||
<path d="M11.0003 7.23113L7.00391 5.84036L7.11773 9.98384L10.9998 11.2608L11.0003 7.23113Z" fill="#FF4527"/>
|
||||
<path d="M8.59912 14.439L10.9999 17.8613L13.4007 14.439L10.9999 11.2608L8.59912 14.439Z" fill="#FF9281"/>
|
||||
<path d="M11 11.2608L13.4008 14.439L17.4702 13.2699L14.882 9.98386L11 11.2608Z" fill="#FD684D"/>
|
||||
<path d="M7.1165 9.9839L4.52832 13.2694L8.59773 14.439L10.9985 11.2608L7.1165 9.9839Z" fill="#FD745C"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 2.5 KiB |
@@ -0,0 +1,77 @@
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolProviderCredentialValidationError, ToolInvokeError
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from httpx import get
|
||||
|
||||
class WolframAlphaTool(BuiltinTool):
|
||||
_base_url = 'https://api.wolframalpha.com/v2/query'
|
||||
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_paramters: Dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
query = tool_paramters.get('query', '')
|
||||
if not query:
|
||||
return self.create_text_message('Please input query')
|
||||
appid = self.runtime.credentials.get('appid', '')
|
||||
if not appid:
|
||||
raise ToolProviderCredentialValidationError('Please input appid')
|
||||
|
||||
params = {
|
||||
'appid': appid,
|
||||
'input': query,
|
||||
'includepodid': 'Result',
|
||||
'format': 'plaintext',
|
||||
'output': 'json'
|
||||
}
|
||||
|
||||
finished = False
|
||||
result = None
|
||||
# try 3 times at most
|
||||
counter = 0
|
||||
|
||||
while not finished and counter < 3:
|
||||
counter += 1
|
||||
try:
|
||||
response = get(self._base_url, params=params, timeout=20)
|
||||
response.raise_for_status()
|
||||
response_data = response.json()
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(str(e))
|
||||
|
||||
if 'success' not in response_data['queryresult'] or response_data['queryresult']['success'] != True:
|
||||
query_result = response_data.get('queryresult', {})
|
||||
if 'error' in query_result and query_result['error']:
|
||||
if 'msg' in query_result['error']:
|
||||
if query_result['error']['msg'] == 'Invalid appid':
|
||||
raise ToolProviderCredentialValidationError('Invalid appid')
|
||||
raise ToolInvokeError('Failed to invoke tool')
|
||||
|
||||
if 'didyoumeans' in response_data['queryresult']:
|
||||
# get the most likely interpretation
|
||||
query = ''
|
||||
max_score = 0
|
||||
for didyoumean in response_data['queryresult']['didyoumeans']:
|
||||
if float(didyoumean['score']) > max_score:
|
||||
query = didyoumean['val']
|
||||
max_score = float(didyoumean['score'])
|
||||
|
||||
params['input'] = query
|
||||
else:
|
||||
finished = True
|
||||
if 'souces' in response_data['queryresult']:
|
||||
return self.create_link_message(response_data['queryresult']['sources']['url'])
|
||||
elif 'pods' in response_data['queryresult']:
|
||||
result = response_data['queryresult']['pods'][0]['subpods'][0]['plaintext']
|
||||
|
||||
if not finished or not result:
|
||||
return self.create_text_message('No result found')
|
||||
|
||||
return self.create_text_message(result)
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
identity:
|
||||
name: wolframalpha
|
||||
author: Dify
|
||||
label:
|
||||
en_US: WolframAlpha
|
||||
zh_Hans: WolframAlpha
|
||||
description:
|
||||
human:
|
||||
en_US: WolframAlpha is a powerful computational knowledge engine.
|
||||
zh_Hans: WolframAlpha 是一个强大的计算知识引擎。
|
||||
llm: WolframAlpha is a powerful computational knowledge engine. one single query can get the answer of a question.
|
||||
parameters:
|
||||
- name: query
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Query string
|
||||
zh_Hans: 计算语句
|
||||
human_description:
|
||||
en_US: used for calculating
|
||||
zh_Hans: 用于计算最终结果
|
||||
llm_description: a single query for calculating
|
||||
form: llm
|
||||
24
api/core/tools/provider/builtin/wolframalpha/wolframalpha.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from core.tools.provider.builtin.wolframalpha.tools.wolframalpha import WolframAlphaTool
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
class GoogleProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
|
||||
try:
|
||||
WolframAlphaTool().fork_tool_runtime(
|
||||
meta={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_paramters={
|
||||
"query": "1+2+....+111",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@@ -0,0 +1,24 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: wolframalpha
|
||||
label:
|
||||
en_US: WolframAlpha
|
||||
zh_Hans: WolframAlpha
|
||||
description:
|
||||
en_US: WolframAlpha is a powerful computational knowledge engine.
|
||||
zh_Hans: WolframAlpha 是一个强大的计算知识引擎。
|
||||
icon: icon.svg
|
||||
credentails_for_provider:
|
||||
appid:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: WolframAlpha AppID
|
||||
zh_Hans: WolframAlpha AppID
|
||||
placeholder:
|
||||
en_US: Please input your WolframAlpha AppID
|
||||
zh_Hans: 请输入你的 WolframAlpha AppID
|
||||
help:
|
||||
en_US: Get your WolframAlpha AppID from WolframAlpha, please use "full results" api access.
|
||||
zh_Hans: 从 WolframAlpha 获取您的 WolframAlpha AppID,请使用 "full results" API。
|
||||
url: https://products.wolframalpha.com/api
|
||||
BIN
api/core/tools/provider/builtin/yahoo/_assets/icon.png
Normal file
|
After Width: | Height: | Size: 7.5 KiB |
69
api/core/tools/provider/builtin/yahoo/tools/analytics.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
from requests.exceptions import HTTPError, ReadTimeout
|
||||
from datetime import datetime
|
||||
|
||||
from yfinance import download
|
||||
import pandas as pd
|
||||
|
||||
class YahooFinanceAnalyticsTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
|
||||
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
symbol = tool_paramters.get('symbol', '')
|
||||
if not symbol:
|
||||
return self.create_text_message('Please input symbol')
|
||||
|
||||
time_range = [None, None]
|
||||
start_date = tool_paramters.get('start_date', '')
|
||||
if start_date:
|
||||
time_range[0] = start_date
|
||||
else:
|
||||
time_range[0] = '1800-01-01'
|
||||
|
||||
end_date = tool_paramters.get('end_date', '')
|
||||
if end_date:
|
||||
time_range[1] = end_date
|
||||
else:
|
||||
time_range[1] = datetime.now().strftime('%Y-%m-%d')
|
||||
|
||||
stock_data = download(symbol, start=time_range[0], end=time_range[1])
|
||||
max_segments = min(15, len(stock_data))
|
||||
rows_per_segment = len(stock_data) // max_segments
|
||||
summary_data = []
|
||||
for i in range(max_segments):
|
||||
start_idx = i * rows_per_segment
|
||||
end_idx = (i + 1) * rows_per_segment if i < max_segments - 1 else len(stock_data)
|
||||
segment_data = stock_data.iloc[start_idx:end_idx]
|
||||
segment_summary = {
|
||||
'Start Date': segment_data.index[0],
|
||||
'End Date': segment_data.index[-1],
|
||||
'Average Close': segment_data['Close'].mean(),
|
||||
'Average Volume': segment_data['Volume'].mean(),
|
||||
'Average Open': segment_data['Open'].mean(),
|
||||
'Average High': segment_data['High'].mean(),
|
||||
'Average Low': segment_data['Low'].mean(),
|
||||
'Average Adj Close': segment_data['Adj Close'].mean(),
|
||||
'Max Close': segment_data['Close'].max(),
|
||||
'Min Close': segment_data['Close'].min(),
|
||||
'Max Volume': segment_data['Volume'].max(),
|
||||
'Min Volume': segment_data['Volume'].min(),
|
||||
'Max Open': segment_data['Open'].max(),
|
||||
'Min Open': segment_data['Open'].min(),
|
||||
'Max High': segment_data['High'].max(),
|
||||
'Min High': segment_data['High'].min(),
|
||||
}
|
||||
|
||||
summary_data.append(segment_summary)
|
||||
|
||||
summary_df = pd.DataFrame(summary_data)
|
||||
|
||||
try:
|
||||
return self.create_text_message(str(summary_df.to_dict()))
|
||||
except (HTTPError, ReadTimeout):
|
||||
return self.create_text_message(f'There is a internet connection problem. Please try again later.')
|
||||
|
||||
46
api/core/tools/provider/builtin/yahoo/tools/analytics.yaml
Normal file
@@ -0,0 +1,46 @@
|
||||
identity:
|
||||
name: yahoo_finance_analytics
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Analytics
|
||||
zh_Hans: 分析
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for get analytics about a ticker from Yahoo Finance.
|
||||
zh_Hans: 一个用于从雅虎财经获取分析数据的工具。
|
||||
llm: A tool for get analytics from Yahoo Finance. Input should be the ticker symbol like AAPL.
|
||||
parameters:
|
||||
- name: symbol
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Ticker symbol
|
||||
zh_Hans: 股票代码
|
||||
human_description:
|
||||
en_US: The ticker symbol of the company you want to analyze.
|
||||
zh_Hans: 你想要搜索的公司的股票代码。
|
||||
llm_description: The ticker symbol of the company you want to analyze.
|
||||
form: llm
|
||||
- name: start_date
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Start date
|
||||
zh_Hans: 开始日期
|
||||
human_description:
|
||||
en_US: The start date of the analytics.
|
||||
zh_Hans: 分析的开始日期。
|
||||
llm_description: The start date of the analytics, the format of the date must be YYYY-MM-DD like 2020-01-01.
|
||||
form: llm
|
||||
- name: end_date
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: End date
|
||||
zh_Hans: 结束日期
|
||||
human_description:
|
||||
en_US: The end date of the analytics.
|
||||
zh_Hans: 分析的结束日期。
|
||||
llm_description: The end date of the analytics, the format of the date must be YYYY-MM-DD like 2024-01-01.
|
||||
form: llm
|
||||
46
api/core/tools/provider/builtin/yahoo/tools/news.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
from requests.exceptions import HTTPError, ReadTimeout
|
||||
|
||||
import yfinance
|
||||
|
||||
class YahooFinanceSearchTickerTool(BuiltinTool):
|
||||
def _invoke(self,user_id: str, tool_paramters: Dict[str, Any]) \
|
||||
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
'''
|
||||
invoke tools
|
||||
'''
|
||||
|
||||
query = tool_paramters.get('symbol', '')
|
||||
if not query:
|
||||
return self.create_text_message('Please input symbol')
|
||||
|
||||
try:
|
||||
return self.run(ticker=query, user_id=user_id)
|
||||
except (HTTPError, ReadTimeout):
|
||||
return self.create_text_message(f'There is a internet connection problem. Please try again later.')
|
||||
|
||||
def run(self, ticker: str, user_id: str) -> ToolInvokeMessage:
|
||||
company = yfinance.Ticker(ticker)
|
||||
try:
|
||||
if company.isin is None:
|
||||
return self.create_text_message(f'Company ticker {ticker} not found.')
|
||||
except (HTTPError, ReadTimeout, ConnectionError):
|
||||
return self.create_text_message(f'Company ticker {ticker} not found.')
|
||||
|
||||
links = []
|
||||
try:
|
||||
links = [n['link'] for n in company.news if n['type'] == 'STORY']
|
||||
except (HTTPError, ReadTimeout, ConnectionError):
|
||||
if not links:
|
||||
return self.create_text_message(f'There is nothing about {ticker} ticker')
|
||||
if not links:
|
||||
return self.create_text_message(f'No news found for company that searched with {ticker} ticker.')
|
||||
|
||||
result = '\n\n'.join([
|
||||
self.get_url(link) for link in links
|
||||
])
|
||||
|
||||
return self.create_text_message(self.summary(user_id=user_id, content=result))
|
||||
24
api/core/tools/provider/builtin/yahoo/tools/news.yaml
Normal file
@@ -0,0 +1,24 @@
|
||||
identity:
|
||||
name: yahoo_finance_news
|
||||
author: Dify
|
||||
label:
|
||||
en_US: News
|
||||
zh_Hans: 新闻
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for get news about a ticker from Yahoo Finance.
|
||||
zh_Hans: 一个用于从雅虎财经获取新闻的工具。
|
||||
llm: A tool for get news from Yahoo Finance. Input should be the ticker symbol like AAPL.
|
||||
parameters:
|
||||
- name: symbol
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Ticker symbol
|
||||
zh_Hans: 股票代码
|
||||
human_description:
|
||||
en_US: The ticker symbol of the company you want to search.
|
||||
zh_Hans: 你想要搜索的公司的股票代码。
|
||||
llm_description: The ticker symbol of the company you want to search.
|
||||
form: llm
|
||||
25
api/core/tools/provider/builtin/yahoo/tools/ticker.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
from requests.exceptions import HTTPError, ReadTimeout
|
||||
|
||||
from yfinance import Ticker
|
||||
|
||||
class YahooFinanceSearchTickerTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
|
||||
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
query = tool_paramters.get('symbol', '')
|
||||
if not query:
|
||||
return self.create_text_message('Please input symbol')
|
||||
|
||||
try:
|
||||
return self.create_text_message(self.run(ticker=query))
|
||||
except (HTTPError, ReadTimeout):
|
||||
return self.create_text_message(f'There is a internet connection problem. Please try again later.')
|
||||
|
||||
def run(self, ticker: str) -> str:
|
||||
return str(Ticker(ticker).info)
|
||||
24
api/core/tools/provider/builtin/yahoo/tools/ticker.yaml
Normal file
@@ -0,0 +1,24 @@
|
||||
identity:
|
||||
name: yahoo_finance_ticker
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Ticker
|
||||
zh_Hans: 股票信息
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for search ticker information from Yahoo Finance.
|
||||
zh_Hans: 一个用于从雅虎财经搜索股票信息的工具。
|
||||
llm: A tool for search ticker information from Yahoo Finance. Input should be the ticker symbol like AAPL.
|
||||
parameters:
|
||||
- name: symbol
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Ticker symbol
|
||||
zh_Hans: 股票代码
|
||||
human_description:
|
||||
en_US: The ticker symbol of the company you want to search.
|
||||
zh_Hans: 你想要搜索的公司的股票代码。
|
||||
llm_description: The ticker symbol of the company you want to search.
|
||||
form: llm
|
||||
20
api/core/tools/provider/builtin/yahoo/yahoo.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from core.tools.provider.builtin.yahoo.tools.ticker import YahooFinanceSearchTickerTool
|
||||
|
||||
class YahooFinanceProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
YahooFinanceSearchTickerTool().fork_tool_runtime(
|
||||
meta={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_paramters={
|
||||
"ticker": "MSFT",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
11
api/core/tools/provider/builtin/yahoo/yahoo.yaml
Normal file
@@ -0,0 +1,11 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: yahoo
|
||||
label:
|
||||
en_US: YahooFinance
|
||||
zh_Hans: 雅虎财经
|
||||
description:
|
||||
en_US: Finance, and Yahoo! get the latest news, stock quotes, and interactive chart with Yahoo!
|
||||
zh_Hans: 雅虎财经,获取并整理出最新的新闻、股票报价等一切你想要的财经信息。
|
||||
icon: icon.png
|
||||
credentails_for_provider:
|
||||
BIN
api/core/tools/provider/builtin/youtube/_assets/icon.png
Normal file
|
After Width: | Height: | Size: 1.0 KiB |
66
api/core/tools/provider/builtin/youtube/tools/videos.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
from datetime import datetime
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
class YoutubeVideosAnalyticsTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
|
||||
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
channel = tool_paramters.get('channel', '')
|
||||
if not channel:
|
||||
return self.create_text_message('Please input symbol')
|
||||
|
||||
time_range = [None, None]
|
||||
start_date = tool_paramters.get('start_date', '')
|
||||
if start_date:
|
||||
time_range[0] = start_date
|
||||
else:
|
||||
time_range[0] = '1800-01-01'
|
||||
|
||||
end_date = tool_paramters.get('end_date', '')
|
||||
if end_date:
|
||||
time_range[1] = end_date
|
||||
else:
|
||||
time_range[1] = datetime.now().strftime('%Y-%m-%d')
|
||||
|
||||
if 'google_api_key' not in self.runtime.credentials or not self.runtime.credentials['google_api_key']:
|
||||
return self.create_text_message('Please input api key')
|
||||
|
||||
youtube = build('youtube', 'v3', developerKey=self.runtime.credentials['google_api_key'])
|
||||
|
||||
# try to get channel id
|
||||
search_results = youtube.search().list(q='mrbeast', type='channel', order='relevance', part='id').execute()
|
||||
channel_id = search_results['items'][0]['id']['channelId']
|
||||
|
||||
start_date, end_date = time_range
|
||||
|
||||
start_date = datetime.strptime(start_date, '%Y-%m-%d').strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
end_date = datetime.strptime(end_date, '%Y-%m-%d').strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
|
||||
# get videos
|
||||
time_range_videos = youtube.search().list(
|
||||
part='snippet', channelId=channel_id, order='date', type='video',
|
||||
publishedAfter=start_date,
|
||||
publishedBefore=end_date
|
||||
).execute()
|
||||
|
||||
def extract_video_data(video_list):
|
||||
data = []
|
||||
for video in video_list['items']:
|
||||
video_id = video['id']['videoId']
|
||||
video_info = youtube.videos().list(part='snippet,statistics', id=video_id).execute()
|
||||
title = video_info['items'][0]['snippet']['title']
|
||||
views = video_info['items'][0]['statistics']['viewCount']
|
||||
data.append({'Title': title, 'Views': views})
|
||||
return data
|
||||
|
||||
summary = extract_video_data(time_range_videos)
|
||||
|
||||
return self.create_text_message(str(summary))
|
||||
|
||||
46
api/core/tools/provider/builtin/youtube/tools/videos.yaml
Normal file
@@ -0,0 +1,46 @@
|
||||
identity:
|
||||
name: youtube_video_statistics
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Video statistics
|
||||
zh_Hans: 视频统计
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for get statistics about a channel's videos.
|
||||
zh_Hans: 一个用于获取油管频道视频统计数据的工具。
|
||||
llm: A tool for get statistics about a channel's videos. Input should be the name of the channel like PewDiePie.
|
||||
parameters:
|
||||
- name: channel
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Channel name
|
||||
zh_Hans: 频道名
|
||||
human_description:
|
||||
en_US: The name of the channel you want to search.
|
||||
zh_Hans: 你想要搜索的油管频道名。
|
||||
llm_description: The name of the channel you want to search.
|
||||
form: llm
|
||||
- name: start_date
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Start date
|
||||
zh_Hans: 开始日期
|
||||
human_description:
|
||||
en_US: The start date of the analytics.
|
||||
zh_Hans: 分析的开始日期。
|
||||
llm_description: The start date of the analytics, the format of the date must be YYYY-MM-DD like 2020-01-01.
|
||||
form: llm
|
||||
- name: end_date
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: End date
|
||||
zh_Hans: 结束日期
|
||||
human_description:
|
||||
en_US: The end date of the analytics.
|
||||
zh_Hans: 分析的结束日期。
|
||||
llm_description: The end date of the analytics, the format of the date must be YYYY-MM-DD like 2024-01-01.
|
||||
form: llm
|
||||
22
api/core/tools/provider/builtin/youtube/youtube.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from core.tools.provider.builtin.youtube.tools.videos import YoutubeVideosAnalyticsTool
|
||||
|
||||
class YahooFinanceProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
YoutubeVideosAnalyticsTool().fork_tool_runtime(
|
||||
meta={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_paramters={
|
||||
"channel": "TOKYO GIRLS COLLECTION",
|
||||
"start_date": "2020-01-01",
|
||||
"end_date": "2024-12-31",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
24
api/core/tools/provider/builtin/youtube/youtube.yaml
Normal file
@@ -0,0 +1,24 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: youtube
|
||||
label:
|
||||
en_US: Youtube
|
||||
zh_Hans: Youtube
|
||||
description:
|
||||
en_US: Youtube
|
||||
zh_Hans: Youtube(油管)是全球最大的视频分享网站,用户可以在上面上传、观看和分享视频。
|
||||
icon: icon.png
|
||||
credentails_for_provider:
|
||||
google_api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Google API key
|
||||
zh_Hans: Google API key
|
||||
placeholder:
|
||||
en_US: Please input your Google API key
|
||||
zh_Hans: 请输入你的 Google API key
|
||||
help:
|
||||
en_US: Get your Google API key from Google
|
||||
zh_Hans: 从 Google 获取您的 Google API key
|
||||
url: https://console.developers.google.com/apis/credentials
|
||||
286
api/core/tools/provider/builtin_tool_provider.py
Normal file
@@ -0,0 +1,286 @@
|
||||
from abc import abstractmethod
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from os import path, listdir
|
||||
from yaml import load, FullLoader
|
||||
|
||||
from core.tools.entities.tool_entities import ToolProviderType, \
|
||||
ToolParamter, ToolProviderCredentials
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.entities.user_entities import UserToolProviderCredentials
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError, \
|
||||
ToolParamterValidationError, ToolProviderCredentialValidationError
|
||||
|
||||
import importlib
|
||||
|
||||
class BuiltinToolProviderController(ToolProviderController):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
if self.app_type == ToolProviderType.API_BASED or self.app_type == ToolProviderType.APP_BASED:
|
||||
super().__init__(**data)
|
||||
return
|
||||
|
||||
# load provider yaml
|
||||
provider = self.__class__.__module__.split('.')[-1]
|
||||
yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml')
|
||||
try:
|
||||
with open(yaml_path, 'r') as f:
|
||||
provider_yaml = load(f.read(), FullLoader)
|
||||
except:
|
||||
raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}')
|
||||
|
||||
if 'credentails_for_provider' in provider_yaml and provider_yaml['credentails_for_provider'] is not None:
|
||||
# set credentials name
|
||||
for credential_name in provider_yaml['credentails_for_provider']:
|
||||
provider_yaml['credentails_for_provider'][credential_name]['name'] = credential_name
|
||||
|
||||
super().__init__(**{
|
||||
'identity': provider_yaml['identity'],
|
||||
'credentials_schema': provider_yaml['credentails_for_provider'] if 'credentails_for_provider' in provider_yaml else None,
|
||||
})
|
||||
|
||||
def _get_bulitin_tools(self) -> List[Tool]:
|
||||
"""
|
||||
returns a list of tools that the provider can provide
|
||||
|
||||
:return: list of tools
|
||||
"""
|
||||
if self.tools:
|
||||
return self.tools
|
||||
|
||||
provider = self.identity.name
|
||||
tool_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, "tools")
|
||||
# get all the yaml files in the tool path
|
||||
tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
|
||||
tools = []
|
||||
for tool_file in tool_files:
|
||||
with open(path.join(tool_path, tool_file), "r") as f:
|
||||
# get tool name
|
||||
tool_name = tool_file.split(".")[0]
|
||||
tool = load(f.read(), FullLoader)
|
||||
# get tool class, import the module
|
||||
py_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, 'tools', f'{tool_name}.py')
|
||||
spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.tools.{tool_name}', py_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
|
||||
# get all the classes in the module
|
||||
classes = [x for _, x in vars(mod).items()
|
||||
if isinstance(x, type) and x not in [BuiltinTool, Tool] and issubclass(x, BuiltinTool)
|
||||
]
|
||||
assistant_tool_class = classes[0]
|
||||
tools.append(assistant_tool_class(**tool))
|
||||
|
||||
self.tools = tools
|
||||
return tools
|
||||
|
||||
def get_credentails_schema(self) -> Dict[str, ToolProviderCredentials]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:return: the credentials schema
|
||||
"""
|
||||
if not self.credentials_schema:
|
||||
return {}
|
||||
|
||||
return self.credentials_schema.copy()
|
||||
|
||||
def user_get_credentails_schema(self) -> UserToolProviderCredentials:
|
||||
"""
|
||||
returns the credentials schema of the provider, this method is used for user
|
||||
|
||||
:return: the credentials schema
|
||||
"""
|
||||
credentials = self.credentials_schema.copy()
|
||||
return UserToolProviderCredentials(credentails=credentials)
|
||||
|
||||
def get_tools(self) -> List[Tool]:
|
||||
"""
|
||||
returns a list of tools that the provider can provide
|
||||
|
||||
:return: list of tools
|
||||
"""
|
||||
return self._get_bulitin_tools()
|
||||
|
||||
def get_tool(self, tool_name: str) -> Tool:
|
||||
"""
|
||||
returns the tool that the provider can provide
|
||||
"""
|
||||
return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
|
||||
|
||||
def get_parameters(self, tool_name: str) -> List[ToolParamter]:
|
||||
"""
|
||||
returns the parameters of the tool
|
||||
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:return: list of parameters
|
||||
"""
|
||||
tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
|
||||
if tool is None:
|
||||
raise ToolNotFoundError(f'tool {tool_name} not found')
|
||||
return tool.parameters
|
||||
|
||||
@property
|
||||
def need_credentials(self) -> bool:
|
||||
"""
|
||||
returns whether the provider needs credentials
|
||||
|
||||
:return: whether the provider needs credentials
|
||||
"""
|
||||
return self.credentials_schema is not None and len(self.credentials_schema) != 0
|
||||
|
||||
@property
|
||||
def app_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: Dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the parameters of the tool and set the default value if needed
|
||||
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:param tool_parameters: the parameters of the tool
|
||||
"""
|
||||
tool_parameters_schema = self.get_parameters(tool_name)
|
||||
|
||||
tool_parameters_need_to_validate: Dict[str, ToolParamter] = {}
|
||||
for parameter in tool_parameters_schema:
|
||||
tool_parameters_need_to_validate[parameter.name] = parameter
|
||||
|
||||
for parameter in tool_parameters:
|
||||
if parameter not in tool_parameters_need_to_validate:
|
||||
raise ToolParamterValidationError(f'parameter {parameter} not found in tool {tool_name}')
|
||||
|
||||
# check type
|
||||
parameter_schema = tool_parameters_need_to_validate[parameter]
|
||||
if parameter_schema.type == ToolParamter.ToolParameterType.STRING:
|
||||
if not isinstance(tool_parameters[parameter], str):
|
||||
raise ToolParamterValidationError(f'parameter {parameter} should be string')
|
||||
|
||||
elif parameter_schema.type == ToolParamter.ToolParameterType.NUMBER:
|
||||
if not isinstance(tool_parameters[parameter], (int, float)):
|
||||
raise ToolParamterValidationError(f'parameter {parameter} should be number')
|
||||
|
||||
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
|
||||
raise ToolParamterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
|
||||
|
||||
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
|
||||
raise ToolParamterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
|
||||
|
||||
elif parameter_schema.type == ToolParamter.ToolParameterType.BOOLEAN:
|
||||
if not isinstance(tool_parameters[parameter], bool):
|
||||
raise ToolParamterValidationError(f'parameter {parameter} should be boolean')
|
||||
|
||||
elif parameter_schema.type == ToolParamter.ToolParameterType.SELECT:
|
||||
if not isinstance(tool_parameters[parameter], str):
|
||||
raise ToolParamterValidationError(f'parameter {parameter} should be string')
|
||||
|
||||
options = parameter_schema.options
|
||||
if not isinstance(options, list):
|
||||
raise ToolParamterValidationError(f'parameter {parameter} options should be list')
|
||||
|
||||
if tool_parameters[parameter] not in [x.value for x in options]:
|
||||
raise ToolParamterValidationError(f'parameter {parameter} should be one of {options}')
|
||||
|
||||
tool_parameters_need_to_validate.pop(parameter)
|
||||
|
||||
for parameter in tool_parameters_need_to_validate:
|
||||
parameter_schema = tool_parameters_need_to_validate[parameter]
|
||||
if parameter_schema.required:
|
||||
raise ToolParamterValidationError(f'parameter {parameter} is required')
|
||||
|
||||
# the parameter is not set currently, set the default value if needed
|
||||
if parameter_schema.default is not None:
|
||||
default_value = parameter_schema.default
|
||||
# parse default value into the correct type
|
||||
if parameter_schema.type == ToolParamter.ToolParameterType.STRING or \
|
||||
parameter_schema.type == ToolParamter.ToolParameterType.SELECT:
|
||||
default_value = str(default_value)
|
||||
elif parameter_schema.type == ToolParamter.ToolParameterType.NUMBER:
|
||||
default_value = float(default_value)
|
||||
elif parameter_schema.type == ToolParamter.ToolParameterType.BOOLEAN:
|
||||
default_value = bool(default_value)
|
||||
|
||||
tool_parameters[parameter] = default_value
|
||||
|
||||
def validate_credentials_format(self, credentials: Dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the format of the credentials of the provider and set the default value if needed
|
||||
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
credentials_schema = self.credentials_schema
|
||||
if credentials_schema is None:
|
||||
return
|
||||
|
||||
credentials_need_to_validate: Dict[str, ToolProviderCredentials] = {}
|
||||
for credential_name in credentials_schema:
|
||||
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
|
||||
|
||||
for credential_name in credentials:
|
||||
if credential_name not in credentials_need_to_validate:
|
||||
raise ToolProviderCredentialValidationError(f'credential {credential_name} not found in provider {self.identity.name}')
|
||||
|
||||
# check type
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
|
||||
credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
|
||||
|
||||
elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
|
||||
|
||||
options = credential_schema.options
|
||||
if not isinstance(options, list):
|
||||
raise ToolProviderCredentialValidationError(f'credential {credential_name} options should be list')
|
||||
|
||||
if credentials[credential_name] not in [x.value for x in options]:
|
||||
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be one of {options}')
|
||||
|
||||
credentials_need_to_validate.pop(credential_name)
|
||||
|
||||
for credential_name in credentials_need_to_validate:
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if credential_schema.required:
|
||||
raise ToolProviderCredentialValidationError(f'credential {credential_name} is required')
|
||||
|
||||
# the credential is not set currently, set the default value if needed
|
||||
if credential_schema.default is not None:
|
||||
default_value = credential_schema.default
|
||||
# parse default value into the correct type
|
||||
if credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
|
||||
credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT or \
|
||||
credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
|
||||
default_value = str(default_value)
|
||||
|
||||
credentials[credential_name] = default_value
|
||||
|
||||
def validate_credentials(self, credentials: Dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
# validate credentials format
|
||||
self.validate_credentials_format(credentials)
|
||||
|
||||
# validate credentials
|
||||
self._validate_credentials(credentials)
|
||||
|
||||
@abstractmethod
|
||||
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
pass
|
||||
218
api/core/tools/provider/tool_provider.py
Normal file
@@ -0,0 +1,218 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.tools.entities.tool_entities import ToolProviderType, \
|
||||
ToolProviderIdentity, ToolParamter, ToolProviderCredentials
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.entities.user_entities import UserToolProviderCredentials
|
||||
from core.tools.errors import ToolNotFoundError, \
|
||||
ToolParamterValidationError, ToolProviderCredentialValidationError
|
||||
|
||||
class ToolProviderController(BaseModel, ABC):
|
||||
identity: Optional[ToolProviderIdentity] = None
|
||||
tools: Optional[List[Tool]] = None
|
||||
credentials_schema: Optional[Dict[str, ToolProviderCredentials]] = None
|
||||
|
||||
def get_credentails_schema(self) -> Dict[str, ToolProviderCredentials]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:return: the credentials schema
|
||||
"""
|
||||
return self.credentials_schema.copy()
|
||||
|
||||
def user_get_credentails_schema(self) -> UserToolProviderCredentials:
|
||||
"""
|
||||
returns the credentials schema of the provider, this method is used for user
|
||||
|
||||
:return: the credentials schema
|
||||
"""
|
||||
credentials = self.credentials_schema.copy()
|
||||
return UserToolProviderCredentials(credentails=credentials)
|
||||
|
||||
@abstractmethod
|
||||
def get_tools(self) -> List[Tool]:
|
||||
"""
|
||||
returns a list of tools that the provider can provide
|
||||
|
||||
:return: list of tools
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_tool(self, tool_name: str) -> Tool:
|
||||
"""
|
||||
returns a tool that the provider can provide
|
||||
|
||||
:return: tool
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_parameters(self, tool_name: str) -> List[ToolParamter]:
|
||||
"""
|
||||
returns the parameters of the tool
|
||||
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:return: list of parameters
|
||||
"""
|
||||
tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
|
||||
if tool is None:
|
||||
raise ToolNotFoundError(f'tool {tool_name} not found')
|
||||
return tool.parameters
|
||||
|
||||
@property
|
||||
def app_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: Dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the parameters of the tool and set the default value if needed
|
||||
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:param tool_parameters: the parameters of the tool
|
||||
"""
|
||||
tool_parameters_schema = self.get_parameters(tool_name)
|
||||
|
||||
tool_parameters_need_to_validate: Dict[str, ToolParamter] = {}
|
||||
for parameter in tool_parameters_schema:
|
||||
tool_parameters_need_to_validate[parameter.name] = parameter
|
||||
|
||||
for parameter in tool_parameters:
|
||||
if parameter not in tool_parameters_need_to_validate:
|
||||
raise ToolParamterValidationError(f'parameter {parameter} not found in tool {tool_name}')
|
||||
|
||||
# check type
|
||||
parameter_schema = tool_parameters_need_to_validate[parameter]
|
||||
if parameter_schema.type == ToolParamter.ToolParameterType.STRING:
|
||||
if not isinstance(tool_parameters[parameter], str):
|
||||
raise ToolParamterValidationError(f'parameter {parameter} should be string')
|
||||
|
||||
elif parameter_schema.type == ToolParamter.ToolParameterType.NUMBER:
|
||||
if not isinstance(tool_parameters[parameter], (int, float)):
|
||||
raise ToolParamterValidationError(f'parameter {parameter} should be number')
|
||||
|
||||
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
|
||||
raise ToolParamterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
|
||||
|
||||
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
|
||||
raise ToolParamterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
|
||||
|
||||
elif parameter_schema.type == ToolParamter.ToolParameterType.BOOLEAN:
|
||||
if not isinstance(tool_parameters[parameter], bool):
|
||||
raise ToolParamterValidationError(f'parameter {parameter} should be boolean')
|
||||
|
||||
elif parameter_schema.type == ToolParamter.ToolParameterType.SELECT:
|
||||
if not isinstance(tool_parameters[parameter], str):
|
||||
raise ToolParamterValidationError(f'parameter {parameter} should be string')
|
||||
|
||||
options = parameter_schema.options
|
||||
if not isinstance(options, list):
|
||||
raise ToolParamterValidationError(f'parameter {parameter} options should be list')
|
||||
|
||||
if tool_parameters[parameter] not in [x.value for x in options]:
|
||||
raise ToolParamterValidationError(f'parameter {parameter} should be one of {options}')
|
||||
|
||||
tool_parameters_need_to_validate.pop(parameter)
|
||||
|
||||
for parameter in tool_parameters_need_to_validate:
|
||||
parameter_schema = tool_parameters_need_to_validate[parameter]
|
||||
if parameter_schema.required:
|
||||
raise ToolParamterValidationError(f'parameter {parameter} is required')
|
||||
|
||||
# the parameter is not set currently, set the default value if needed
|
||||
if parameter_schema.default is not None:
|
||||
default_value = parameter_schema.default
|
||||
# parse default value into the correct type
|
||||
if parameter_schema.type == ToolParamter.ToolParameterType.STRING or \
|
||||
parameter_schema.type == ToolParamter.ToolParameterType.SELECT:
|
||||
default_value = str(default_value)
|
||||
elif parameter_schema.type == ToolParamter.ToolParameterType.NUMBER:
|
||||
default_value = float(default_value)
|
||||
elif parameter_schema.type == ToolParamter.ToolParameterType.BOOLEAN:
|
||||
default_value = bool(default_value)
|
||||
|
||||
tool_parameters[parameter] = default_value
|
||||
|
||||
def validate_credentials_format(self, credentials: Dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the format of the credentials of the provider and set the default value if needed
|
||||
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
credentials_schema = self.credentials_schema
|
||||
if credentials_schema is None:
|
||||
return
|
||||
|
||||
credentials_need_to_validate: Dict[str, ToolProviderCredentials] = {}
|
||||
for credential_name in credentials_schema:
|
||||
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
|
||||
|
||||
for credential_name in credentials:
|
||||
if credential_name not in credentials_need_to_validate:
|
||||
raise ToolProviderCredentialValidationError(f'credential {credential_name} not found in provider {self.identity.name}')
|
||||
|
||||
# check type
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
|
||||
credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
|
||||
|
||||
elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
|
||||
|
||||
options = credential_schema.options
|
||||
if not isinstance(options, list):
|
||||
raise ToolProviderCredentialValidationError(f'credential {credential_name} options should be list')
|
||||
|
||||
if credentials[credential_name] not in [x.value for x in options]:
|
||||
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be one of {options}')
|
||||
|
||||
credentials_need_to_validate.pop(credential_name)
|
||||
|
||||
for credential_name in credentials_need_to_validate:
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if credential_schema.required:
|
||||
raise ToolProviderCredentialValidationError(f'credential {credential_name} is required')
|
||||
|
||||
# the credential is not set currently, set the default value if needed
|
||||
if credential_schema.default is not None:
|
||||
default_value = credential_schema.default
|
||||
# parse default value into the correct type
|
||||
if credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
|
||||
credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT or \
|
||||
credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
|
||||
default_value = str(default_value)
|
||||
|
||||
credentials[credential_name] = default_value
|
||||
|
||||
def validate_credentials(self, credentials: Dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
# validate credentials format
|
||||
self.validate_credentials_format(credentials)
|
||||
|
||||
# validate credentials
|
||||
self._validate_credentials(credentials)
|
||||
|
||||
@abstractmethod
|
||||
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
pass
|
||||