mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-09 19:06:51 +08:00
tts models support (#2033)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
This commit is contained in:
@@ -555,6 +555,12 @@ class ApplicationManager:
|
||||
if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
|
||||
properties['speech_to_text'] = True
|
||||
|
||||
# text to speech
|
||||
text_to_speech_dict = copy_app_model_config_dict.get('text_to_speech')
|
||||
if text_to_speech_dict:
|
||||
if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']:
|
||||
properties['text_to_speech'] = True
|
||||
|
||||
# sensitive word avoidance
|
||||
sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
|
||||
if sensitive_word_avoidance_dict:
|
||||
|
||||
@@ -219,6 +219,7 @@ class AppOrchestrationConfigEntity(BaseModel):
|
||||
show_retrieve_source: bool = False
|
||||
more_like_this: bool = False
|
||||
speech_to_text: bool = False
|
||||
text_to_speech: bool = False
|
||||
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
|
||||
|
||||
|
||||
@@ -283,7 +284,6 @@ class ApplicationGenerateEntity(BaseModel):
|
||||
query: Optional[str] = None
|
||||
files: list[FileObj] = []
|
||||
user_id: str
|
||||
|
||||
# extras
|
||||
stream: bool
|
||||
invoke_from: InvokeFrom
|
||||
|
||||
@@ -12,6 +12,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
|
||||
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.provider_manager import ProviderManager
|
||||
|
||||
@@ -144,7 +145,7 @@ class ModelInstance:
|
||||
user=user
|
||||
)
|
||||
|
||||
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None, **params) \
|
||||
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
|
||||
-> str:
|
||||
"""
|
||||
Invoke large language model
|
||||
@@ -161,8 +162,29 @@ class ModelInstance:
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
file=file,
|
||||
user=user
|
||||
)
|
||||
|
||||
def invoke_tts(self, content_text: str, streaming: bool, user: Optional[str] = None) \
|
||||
-> str:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param content_text: text content to be translated
|
||||
:param user: unique user id
|
||||
:param streaming: output is streaming
|
||||
:return: text for given audio file
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TTSModel):
|
||||
raise Exception(f"Model type instance is not TTSModel")
|
||||
|
||||
self.model_type_instance = cast(TTSModel, self.model_type_instance)
|
||||
return self.model_type_instance.invoke(
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
content_text=content_text,
|
||||
user=user,
|
||||
**params
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ class ModelType(Enum):
|
||||
RERANK = "rerank"
|
||||
SPEECH2TEXT = "speech2text"
|
||||
MODERATION = "moderation"
|
||||
# TTS = "tts"
|
||||
TTS = "tts"
|
||||
# TEXT2IMG = "text2img"
|
||||
|
||||
@classmethod
|
||||
@@ -33,6 +33,8 @@ class ModelType(Enum):
|
||||
return cls.RERANK
|
||||
elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value:
|
||||
return cls.SPEECH2TEXT
|
||||
elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value:
|
||||
return cls.TTS
|
||||
elif origin_model_type == cls.MODERATION.value:
|
||||
return cls.MODERATION
|
||||
else:
|
||||
@@ -52,6 +54,8 @@ class ModelType(Enum):
|
||||
return 'reranking'
|
||||
elif self == self.SPEECH2TEXT:
|
||||
return 'speech2text'
|
||||
elif self == self.TTS:
|
||||
return 'tts'
|
||||
elif self == self.MODERATION:
|
||||
return 'moderation'
|
||||
else:
|
||||
@@ -120,6 +124,10 @@ class ModelPropertyKey(Enum):
|
||||
FILE_UPLOAD_LIMIT = "file_upload_limit"
|
||||
SUPPORTED_FILE_EXTENSIONS = "supported_file_extensions"
|
||||
MAX_CHARACTERS_PER_CHUNK = "max_characters_per_chunk"
|
||||
DEFAULT_VOICE = "default_voice"
|
||||
WORD_LIMIT = "word_limit"
|
||||
AUDOI_TYPE = "audio_type"
|
||||
MAX_WORKERS = "max_workers"
|
||||
|
||||
|
||||
class ProviderModel(BaseModel):
|
||||
|
||||
42
api/core/model_runtime/model_providers/__base/tts_model.py
Normal file
42
api/core/model_runtime/model_providers/__base/tts_model.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
|
||||
class TTSModel(AIModel):
|
||||
"""
|
||||
Model class for ttstext model.
|
||||
"""
|
||||
model_type: ModelType = ModelType.TTS
|
||||
|
||||
def invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param streaming: output is streaming
|
||||
:param user: unique user id
|
||||
:return: translated audio file
|
||||
"""
|
||||
try:
|
||||
return self._invoke(model=model, credentials=credentials, user=user, streaming=streaming, content_text=content_text)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param streaming: output is streaming
|
||||
:param user: unique user id
|
||||
:return: translated audio file
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -20,6 +20,7 @@ supported_model_types:
|
||||
- text-embedding
|
||||
- speech2text
|
||||
- moderation
|
||||
- tts
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
- customizable-model
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
model: tts-1-hd
|
||||
model_type: tts
|
||||
model_properties:
|
||||
default_voice: 'alloy'
|
||||
word_limit: 120
|
||||
audio_type: 'mp3'
|
||||
max_workers: 5
|
||||
@@ -0,0 +1,7 @@
|
||||
model: tts-1
|
||||
model_type: tts
|
||||
model_properties:
|
||||
default_voice: 'alloy'
|
||||
word_limit: 120
|
||||
audio_type: 'mp3'
|
||||
max_workers: 5
|
||||
235
api/core/model_runtime/model_providers/openai/tts/tts.py
Normal file
235
api/core/model_runtime/model_providers/openai/tts/tts.py
Normal file
@@ -0,0 +1,235 @@
|
||||
import uuid
|
||||
import hashlib
|
||||
import subprocess
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
from functools import reduce
|
||||
from pydub import AudioSegment
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from core.model_runtime.model_providers.openai._common import _CommonOpenAI
|
||||
|
||||
from typing_extensions import Literal
|
||||
from flask import Response, stream_with_context
|
||||
from openai import OpenAI
|
||||
import concurrent.futures
|
||||
|
||||
|
||||
class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
|
||||
"""
|
||||
Model class for OpenAI Speech to text model.
|
||||
"""
|
||||
def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None) -> any:
|
||||
"""
|
||||
_invoke text2speech model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param streaming: output is streaming
|
||||
:param user: unique user id
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
self._is_ffmpeg_installed()
|
||||
audio_type = self._get_model_audio_type(model, credentials)
|
||||
if streaming:
|
||||
return Response(stream_with_context(self._tts_invoke_streaming(model=model,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
user=user)),
|
||||
status=200, mimetype=f'audio/{audio_type}')
|
||||
else:
|
||||
return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, user=user)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
validate credentials text2speech model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param user: unique user id
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
try:
|
||||
self._tts_invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
content_text='Hello world!',
|
||||
user=user
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _tts_invoke(self, model: str, credentials: dict, content_text: str, user: Optional[str] = None) -> any:
|
||||
"""
|
||||
_tts_invoke text2speech model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param user: unique user id
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
audio_type = self._get_model_audio_type(model, credentials)
|
||||
word_limit = self._get_model_word_limit(model, credentials)
|
||||
max_workers = self._get_model_workers_limit(model, credentials)
|
||||
|
||||
try:
|
||||
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
|
||||
audio_bytes_list = list()
|
||||
|
||||
# Create a thread pool and map the function to the list of sentences
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = [executor.submit(self._process_sentence, sentence, model, credentials) for sentence
|
||||
in sentences]
|
||||
for future in futures:
|
||||
try:
|
||||
audio_bytes_list.append(future.result())
|
||||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
audio_segments = [AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) for audio_bytes in
|
||||
audio_bytes_list if audio_bytes]
|
||||
combined_segment = reduce(lambda x, y: x + y, audio_segments)
|
||||
buffer: BytesIO = BytesIO()
|
||||
combined_segment.export(buffer, format=audio_type)
|
||||
buffer.seek(0)
|
||||
return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
|
||||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
# Todo: To improve the streaming function
|
||||
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, user: Optional[str] = None) -> any:
|
||||
"""
|
||||
_tts_invoke_streaming text2speech model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param user: unique user id
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
# transform credentials to kwargs for model instance
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
voice_name = self._get_model_voice(model, credentials)
|
||||
word_limit = self._get_model_word_limit(model, credentials)
|
||||
audio_type = self._get_model_audio_type(model, credentials)
|
||||
tts_file_id = self._get_file_name(content_text)
|
||||
file_path = f'storage/generate_files/{audio_type}/{tts_file_id}.{audio_type}'
|
||||
try:
|
||||
client = OpenAI(**credentials_kwargs)
|
||||
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
|
||||
for sentence in sentences:
|
||||
response = client.audio.speech.create(model=model, voice=voice_name, input=sentence.strip())
|
||||
response.stream_to_file(file_path)
|
||||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
def _get_model_voice(self, model: str, credentials: dict) -> Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"]:
|
||||
"""
|
||||
Get voice for given tts model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: voice
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.DEFAULT_VOICE in model_schema.model_properties:
|
||||
return model_schema.model_properties[ModelPropertyKey.DEFAULT_VOICE]
|
||||
|
||||
def _get_model_audio_type(self, model: str, credentials: dict) -> str:
|
||||
"""
|
||||
Get audio type for given tts model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: voice
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.AUDOI_TYPE in model_schema.model_properties:
|
||||
return model_schema.model_properties[ModelPropertyKey.AUDOI_TYPE]
|
||||
|
||||
def _get_model_word_limit(self, model: str, credentials: dict) -> int:
|
||||
"""
|
||||
Get audio type for given tts model
|
||||
:return: audio type
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.WORD_LIMIT in model_schema.model_properties:
|
||||
return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
|
||||
|
||||
def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
|
||||
"""
|
||||
Get audio max workers for given tts model
|
||||
:return: audio type
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.MAX_WORKERS in model_schema.model_properties:
|
||||
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
|
||||
|
||||
@staticmethod
|
||||
def _split_text_into_sentences(text: str, limit: int, delimiters=None):
|
||||
if delimiters is None:
|
||||
delimiters = set('。!?;\n')
|
||||
|
||||
buf = []
|
||||
word_count = 0
|
||||
for char in text:
|
||||
buf.append(char)
|
||||
if char in delimiters:
|
||||
if word_count >= limit:
|
||||
yield ''.join(buf)
|
||||
buf = []
|
||||
word_count = 0
|
||||
else:
|
||||
word_count += 1
|
||||
else:
|
||||
word_count += 1
|
||||
|
||||
if buf:
|
||||
yield ''.join(buf)
|
||||
|
||||
@staticmethod
|
||||
def _get_file_name(file_content: str) -> str:
|
||||
hash_object = hashlib.sha256(file_content.encode())
|
||||
hex_digest = hash_object.hexdigest()
|
||||
|
||||
namespace_uuid = uuid.UUID('a5da6ef9-b303-596f-8e88-bf8fa40f4b31')
|
||||
unique_uuid = uuid.uuid5(namespace_uuid, hex_digest)
|
||||
return str(unique_uuid)
|
||||
|
||||
def _process_sentence(self, sentence: str, model: str, credentials: dict):
|
||||
"""
|
||||
_tts_invoke openai text2speech model api
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param sentence: text content to be translated
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
# transform credentials to kwargs for model instance
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
voice_name = self._get_model_voice(model, credentials)
|
||||
|
||||
client = OpenAI(**credentials_kwargs)
|
||||
response = client.audio.speech.create(model=model, voice=voice_name, input=sentence.strip())
|
||||
if isinstance(response.read(), bytes):
|
||||
return response.read()
|
||||
|
||||
@staticmethod
|
||||
def _is_ffmpeg_installed():
|
||||
try:
|
||||
output = subprocess.check_output("ffmpeg -version", shell=True)
|
||||
if "ffmpeg version" in output.decode("utf-8"):
|
||||
return True
|
||||
else:
|
||||
raise InvokeBadRequestError("ffmpeg is not installed")
|
||||
except Exception:
|
||||
raise InvokeBadRequestError("ffmpeg is not installed")
|
||||
Reference in New Issue
Block a user