feat/enhance the multi-modal support (#8818)

This commit is contained in:
-LAN-
2024-10-21 10:43:49 +08:00
committed by GitHub
parent 7a1d6fe509
commit e61752bd3a
267 changed files with 6263 additions and 3523 deletions

View File

@@ -66,7 +66,7 @@ class DallE3Tool(BuiltinTool):
for image in response.data:
mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
blob_message = self.create_blob_message(
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE.value
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE
)
result.append(blob_message)
return result

View File

@@ -2,7 +2,7 @@ from typing import Any
from duckduckgo_search import DDGS
from core.file.file_obj import FileTransferMethod
from core.file.models import FileTransferMethod
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -0,0 +1,24 @@
<svg width="100" height="100" viewBox="0 0 100 100" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="100" height="100" rx="20" fill="#4A90E2" />
<path
d="M50 25C40.6 25 33 32.6 33 42V58C33 67.4 40.6 75 50 75C59.4 75 67 67.4 67 58V42C67 32.6 59.4 25 50 25ZM61 58C61 64.1 56.1 69 50 69C43.9 69 39 64.1 39 58V42C39 35.9 43.9 31 50 31C56.1 31 61 35.9 61 42V58Z"
fill="white" />
<path d="M50 37C47.2 37 45 39.2 45 42V58C45 60.8 47.2 63 50 63C52.8 63 55 60.8 55 58V42C55 39.2 52.8 37 50 37Z"
fill="white" />
<path
d="M73 49H69V58C69 68.5 60.5 77 50 77C39.5 77 31 68.5 31 58V49H27V58C27 70.7 37.3 81 50 81C62.7 81 73 70.7 73 58V49Z"
fill="white" />
<path d="M50 85C51.1 85 52 84.1 52 83V81H48V83C48 84.1 48.9 85 50 85Z" fill="white" />
<path
d="M35 45C36.1046 45 37 44.1046 37 43C37 41.8954 36.1046 41 35 41C33.8954 41 33 41.8954 33 43C33 44.1046 33.8954 45 35 45Z"
fill="white" />
<path
d="M35 55C36.1046 55 37 54.1046 37 53C37 51.8954 36.1046 51 35 51C33.8954 51 33 51.8954 33 53C33 54.1046 33.8954 55 35 55Z"
fill="white" />
<path
d="M65 45C66.1046 45 67 44.1046 67 43C67 41.8954 66.1046 41 65 41C63.8954 41 63 41.8954 63 43C63 44.1046 63.8954 45 65 45Z"
fill="white" />
<path
d="M65 55C66.1046 55 67 54.1046 67 53C67 51.8954 66.1046 51 65 51C63.8954 51 63 51.8954 63 53C63 54.1046 63.8954 55 65 55Z"
fill="white" />
</svg>

After

Width:  |  Height:  |  Size: 1.4 KiB

View File

@@ -0,0 +1,33 @@
from typing import Any
import openai
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class PodcastGeneratorProvider(BuiltinToolProviderController):
    """Builtin tool provider for the podcast generator.

    Validates the provider credentials: a TTS service selector plus the API key
    for that service. Only the ``"openai"`` service is accepted.
    """

    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        """Check that a supported TTS service and a usable API key were supplied.

        Args:
            credentials: Provider credential mapping; the ``tts_service`` and
                ``api_key`` entries are read.

        Raises:
            ToolProviderCredentialValidationError: If either entry is missing or
                empty, if the service is not ``"openai"``, or if the key is
                rejected during validation.
        """
        tts_service = credentials.get("tts_service")
        api_key = credentials.get("api_key")

        if not tts_service:
            raise ToolProviderCredentialValidationError("TTS service is not specified")

        if not api_key:
            raise ToolProviderCredentialValidationError("API key is missing")

        if tts_service != "openai":
            raise ToolProviderCredentialValidationError(f"Unsupported TTS service: {tts_service}")

        self._validate_openai_credentials(api_key)

    def _validate_openai_credentials(self, api_key: str) -> None:
        """Verify an OpenAI API key by issuing a ``models.list()`` request.

        Args:
            api_key: The OpenAI API key to check.

        Raises:
            ToolProviderCredentialValidationError: If OpenAI rejects the key, or if
                any other error occurs while making the request. The original
                exception is attached as ``__cause__``.
        """
        client = openai.OpenAI(api_key=api_key)
        try:
            # A simple API call is enough to find out whether the key is accepted.
            client.models.list()
        except openai.AuthenticationError as e:
            raise ToolProviderCredentialValidationError("Invalid OpenAI API key") from e
        except Exception as e:
            # Provider boundary: every failure is reported as a validation error,
            # chained so the underlying cause stays visible in tracebacks.
            raise ToolProviderCredentialValidationError(f"Error validating OpenAI API key: {str(e)}") from e

View File

@@ -0,0 +1,34 @@
# Provider manifest for the "podcast_generator" builtin tool provider:
# display identity plus the credential form.
identity:
author: Dify
name: podcast_generator
label:
en_US: Podcast Generator
zh_Hans: 播客生成器
description:
en_US: Generate podcast audio using Text-to-Speech services
zh_Hans: 使用文字转语音服务生成播客音频
icon: icon.svg
# Credential form: two required fields, the TTS service selector and its API key.
credentials_for_provider:
tts_service:
type: select
required: true
label:
en_US: TTS Service
zh_Hans: TTS 服务
placeholder:
en_US: Select a TTS service
zh_Hans: 选择一个 TTS 服务
# "openai" is the only option offered here.
options:
- label:
en_US: OpenAI TTS
zh_Hans: OpenAI TTS
value: openai
api_key:
# NOTE(review): "secret-input" presumably renders as a masked field — confirm against the credential schema.
type: secret-input
required: true
label:
en_US: API Key
zh_Hans: API 密钥
placeholder:
en_US: Enter your TTS service API key
zh_Hans: 输入您的 TTS 服务 API 密钥

View File

@@ -0,0 +1,100 @@
import concurrent.futures
import io
import random
from typing import Any, Literal, Optional, Union
import openai
from pydub import AudioSegment
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolParameterValidationError, ToolProviderCredentialValidationError
from core.tools.tool.builtin_tool import BuiltinTool
class PodcastAudioGeneratorTool(BuiltinTool):
    """Builtin tool that renders a two-host podcast script into one WAV file.

    Each non-blank script line is synthesized with OpenAI's ``tts-1`` model,
    alternating between the two host voices, and the clips are joined with a
    short random pause between consecutive lines.
    """

    @staticmethod
    def _generate_silence(duration: float):
        """Return a silent pydub ``AudioSegment`` lasting ``duration`` seconds."""
        # pydub expresses durations in milliseconds.
        return AudioSegment.silent(duration=int(duration * 1000))

    @staticmethod
    def _generate_audio_segment(
        client: openai.OpenAI,
        line: str,
        voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"],
        index: int,
    ) -> tuple[int, Union[AudioSegment, str], Optional[AudioSegment]]:
        """Synthesize one script line and pick the pause that follows it.

        Args:
            client: OpenAI client used for the speech request.
            line: The script line to speak; surrounding whitespace is stripped.
            voice: Voice identifier passed to the speech request.
            index: Position of the line in the script, returned unchanged so the
                caller can restore ordering.

        Returns:
            ``(index, audio, silence)`` on success, where ``silence`` lasts a
            random 0.1-1.5 seconds. On any failure, ``(index, message, None)``
            where ``message`` is an error string; this method does not raise.
        """
        try:
            response = client.audio.speech.create(model="tts-1", voice=voice, input=line.strip(), response_format="wav")
            audio = AudioSegment.from_wav(io.BytesIO(response.content))
            silence_duration = random.uniform(0.1, 1.5)
            silence = PodcastAudioGeneratorTool._generate_silence(silence_duration)
            return index, audio, silence
        except Exception as e:
            # Errors are returned as data so the caller can turn them into a tool message.
            return index, f"Error generating audio: {str(e)}", None

    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Generate the podcast audio for a script.

        Args:
            user_id: Identifier of the invoking user (not used by this tool).
            tool_parameters: Reads ``script`` (newline-separated lines),
                ``host1_voice`` (even-indexed lines) and ``host2_voice``
                (odd-indexed lines).

        Returns:
            A text message carrying the error if any line fails to synthesize;
            otherwise a list with a success text message and a blob message
            holding the combined WAV bytes.

        Raises:
            ToolParameterValidationError: If a host voice is missing or the
                script contains no non-blank line.
            ToolProviderCredentialValidationError: If the runtime, its
                credentials, or the API key is missing.
        """
        # Extract parameters; an explicit None script is treated like an empty one.
        script = tool_parameters.get("script") or ""
        host1_voice = tool_parameters.get("host1_voice")
        host2_voice = tool_parameters.get("host2_voice")

        # Split the script into lines, dropping blank ones
        script_lines = [line for line in script.split("\n") if line.strip()]

        # Ensure voices are provided
        if not host1_voice or not host2_voice:
            raise ToolParameterValidationError("Host voices are required")

        # Without any spoken line the result would be an empty WAV reported as success.
        if not script_lines:
            raise ToolParameterValidationError("Script is required")

        # Get OpenAI API key from credentials
        if not self.runtime or not self.runtime.credentials:
            raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing")
        api_key = self.runtime.credentials.get("api_key")
        if not api_key:
            raise ToolProviderCredentialValidationError("OpenAI API key is missing")

        # Initialize OpenAI client
        client = openai.OpenAI(api_key=api_key)

        # Synthesize lines concurrently; results are slotted back by index.
        max_workers = 5
        audio_segments: list[Any] = [None] * len(script_lines)
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = []
            for i, line in enumerate(script_lines):
                voice = host1_voice if i % 2 == 0 else host2_voice
                futures.append(executor.submit(self._generate_audio_segment, client, line, voice, i))

            for future in concurrent.futures.as_completed(futures):
                index, audio, silence = future.result()
                if isinstance(audio, str):  # Error occurred
                    # Leaving the `with` block waits for every submitted future, so
                    # cancel the ones still queued instead of paying for requests
                    # whose output would be discarded.
                    for pending in futures:
                        pending.cancel()
                    return self.create_text_message(audio)
                audio_segments[index] = (audio, silence)

        # Combine audio segments in the correct order
        combined_audio = AudioSegment.empty()
        last_index = len(audio_segments) - 1
        for i, (audio, silence) in enumerate(audio_segments):
            if audio:
                combined_audio += audio
            # No trailing pause after the final line.
            if i < last_index and silence:
                combined_audio += silence

        # Export the combined audio to a WAV file in memory
        buffer = io.BytesIO()
        combined_audio.export(buffer, format="wav")
        wav_bytes = buffer.getvalue()

        # Create a blob message with the combined audio
        return [
            self.create_text_message("Audio generated successfully"),
            self.create_blob_message(
                blob=wav_bytes,
                meta={"mime_type": "audio/x-wav"},
                save_as=self.VariableKey.AUDIO,
            ),
        ]

View File

@@ -0,0 +1,95 @@
# Tool manifest for "podcast_audio_generator": identity, descriptions, and
# three required parameters (the script and one voice per host).
identity:
name: podcast_audio_generator
author: Dify
label:
en_US: Podcast Audio Generator
zh_Hans: 播客音频生成器
description:
human:
en_US: Generate a podcast audio file from a script with two alternating voices using OpenAI's TTS service.
zh_Hans: 使用 OpenAI 的 TTS 服务,从包含两个交替声音的脚本生成播客音频文件。
llm: This tool converts a prepared podcast script into an audio file using OpenAI's Text-to-Speech service, with two specified voices for alternating hosts.
parameters:
# The script text: alternating host lines separated by newline characters.
- name: script
type: string
required: true
label:
en_US: Podcast Script
zh_Hans: 播客脚本
human_description:
en_US: A string containing alternating lines for two hosts, separated by newline characters.
zh_Hans: 包含两位主持人交替台词的字符串,每行用换行符分隔。
llm_description: A string representing the script, with alternating lines for two hosts separated by newline characters.
# NOTE(review): "form: llm" presumably means the model supplies this value at call time — confirm against the tool schema.
form: llm
# Voice selector for the first host; six voice identifiers are offered.
- name: host1_voice
type: select
required: true
label:
en_US: Host 1 Voice
zh_Hans: 主持人1 音色
human_description:
en_US: The voice for the first host.
zh_Hans: 第一位主持人的音色。
llm_description: The voice identifier for the first host's voice.
options:
- label:
en_US: Alloy
zh_Hans: Alloy
value: alloy
- label:
en_US: Echo
zh_Hans: Echo
value: echo
- label:
en_US: Fable
zh_Hans: Fable
value: fable
- label:
en_US: Onyx
zh_Hans: Onyx
value: onyx
- label:
en_US: Nova
zh_Hans: Nova
value: nova
- label:
en_US: Shimmer
zh_Hans: Shimmer
value: shimmer
# NOTE(review): "form: form" presumably means the value is set in the tool's configuration form — confirm against the tool schema.
form: form
# Voice selector for the second host; same option list as host1_voice.
- name: host2_voice
type: select
required: true
label:
en_US: Host 2 Voice
zh_Hans: 主持人2 音色
human_description:
en_US: The voice for the second host.
zh_Hans: 第二位主持人的音色。
llm_description: The voice identifier for the second host's voice.
options:
- label:
en_US: Alloy
zh_Hans: Alloy
value: alloy
- label:
en_US: Echo
zh_Hans: Echo
value: echo
- label:
en_US: Fable
zh_Hans: Fable
value: fable
- label:
en_US: Onyx
zh_Hans: Onyx
value: onyx
- label:
en_US: Nova
zh_Hans: Nova
value: nova
- label:
en_US: Shimmer
zh_Hans: Shimmer
value: shimmer
form: form

View File

@@ -13,7 +13,6 @@ from core.tools.errors import (
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from core.tools.utils.yaml_utils import load_yaml_file
@@ -208,9 +207,7 @@ class BuiltinToolProviderController(ToolProviderController):
# the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None:
default_value = ToolParameterConverter.cast_parameter_by_type(
parameter_schema.default, parameter_schema.type
)
default_value = parameter_schema.type.cast_value(parameter_schema.default)
tool_parameters[parameter] = default_value
def validate_credentials(self, credentials: dict[str, Any]) -> None:

View File

@@ -11,7 +11,6 @@ from core.tools.entities.tool_entities import (
)
from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError
from core.tools.tool.tool import Tool
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
class ToolProviderController(BaseModel, ABC):
@@ -127,9 +126,7 @@ class ToolProviderController(BaseModel, ABC):
# the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None:
tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type(
parameter_schema.default, parameter_schema.type
)
tool_parameters[parameter] = parameter_schema.type.cast_value(parameter_schema.default)
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
"""

View File

@@ -1,6 +1,6 @@
from typing import Optional
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.app_config.entities import VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
@@ -23,6 +23,8 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE,
VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES,
}
@@ -36,8 +38,8 @@ class WorkflowToolProviderController(ToolProviderController):
if not app:
raise ValueError("app not found")
controller = WorkflowToolProviderController(
**{
controller = WorkflowToolProviderController.model_validate(
{
"identity": {
"author": db_provider.user.name if db_provider.user_id and db_provider.user else "",
"name": db_provider.label,
@@ -67,7 +69,7 @@ class WorkflowToolProviderController(ToolProviderController):
:param app: the app
:return: the tool
"""
workflow: Workflow = (
workflow = (
db.session.query(Workflow)
.filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
.first()
@@ -76,14 +78,14 @@ class WorkflowToolProviderController(ToolProviderController):
raise ValueError("workflow not found")
# fetch start node
graph: dict = workflow.graph_dict
features_dict: dict = workflow.features_dict
graph = workflow.graph_dict
features_dict = workflow.features_dict
features = WorkflowAppConfigManager.convert_features(config_dict=features_dict, app_mode=AppMode.WORKFLOW)
parameters = db_provider.parameter_configurations
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
def fetch_workflow_variable(variable_name: str) -> VariableEntity:
def fetch_workflow_variable(variable_name: str):
return next(filter(lambda x: x.variable == variable_name, variables), None)
user = db_provider.user
@@ -114,7 +116,6 @@ class WorkflowToolProviderController(ToolProviderController):
llm_description=parameter.description,
required=variable.required,
options=options,
default=variable.default,
)
)
elif features.file_upload:
@@ -123,7 +124,7 @@ class WorkflowToolProviderController(ToolProviderController):
name=parameter.name,
label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name),
human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description),
type=ToolParameter.ToolParameterType.FILE,
type=ToolParameter.ToolParameterType.SYSTEM_FILES,
llm_description=parameter.description,
required=False,
form=parameter.form,