tts models support (#2033)

Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
This commit is contained in:
Charlie.Wei
2024-01-24 01:05:37 +08:00
committed by GitHub
parent 27828f44b9
commit 6355e61eb8
86 changed files with 1645 additions and 133 deletions

View File

@@ -555,6 +555,12 @@ class ApplicationManager:
if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
properties['speech_to_text'] = True
# text to speech
text_to_speech_dict = copy_app_model_config_dict.get('text_to_speech')
if text_to_speech_dict:
if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']:
properties['text_to_speech'] = True
# sensitive word avoidance
sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
if sensitive_word_avoidance_dict:

View File

@@ -219,6 +219,7 @@ class AppOrchestrationConfigEntity(BaseModel):
show_retrieve_source: bool = False
more_like_this: bool = False
speech_to_text: bool = False
text_to_speech: bool = False
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
@@ -283,7 +284,6 @@ class ApplicationGenerateEntity(BaseModel):
query: Optional[str] = None
files: list[FileObj] = []
user_id: str
# extras
stream: bool
invoke_from: InvokeFrom

View File

@@ -12,6 +12,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.provider_manager import ProviderManager
@@ -144,7 +145,7 @@ class ModelInstance:
user=user
)
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None, **params) \
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
-> str:
"""
Invoke large language model
@@ -161,8 +162,29 @@ class ModelInstance:
model=self.model,
credentials=self.credentials,
file=file,
user=user
)
def invoke_tts(self, content_text: str, streaming: bool, user: Optional[str] = None) \
-> str:
"""
Invoke large language model
:param content_text: text content to be translated
:param user: unique user id
:param streaming: output is streaming
:return: text for given audio file
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception(f"Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
content_text=content_text,
user=user,
**params
streaming=streaming
)

View File

@@ -15,7 +15,7 @@ class ModelType(Enum):
RERANK = "rerank"
SPEECH2TEXT = "speech2text"
MODERATION = "moderation"
# TTS = "tts"
TTS = "tts"
# TEXT2IMG = "text2img"
@classmethod
@@ -33,6 +33,8 @@ class ModelType(Enum):
return cls.RERANK
elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value:
return cls.SPEECH2TEXT
elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value:
return cls.TTS
elif origin_model_type == cls.MODERATION.value:
return cls.MODERATION
else:
@@ -52,6 +54,8 @@ class ModelType(Enum):
return 'reranking'
elif self == self.SPEECH2TEXT:
return 'speech2text'
elif self == self.TTS:
return 'tts'
elif self == self.MODERATION:
return 'moderation'
else:
@@ -120,6 +124,10 @@ class ModelPropertyKey(Enum):
FILE_UPLOAD_LIMIT = "file_upload_limit"
SUPPORTED_FILE_EXTENSIONS = "supported_file_extensions"
MAX_CHARACTERS_PER_CHUNK = "max_characters_per_chunk"
DEFAULT_VOICE = "default_voice"
WORD_LIMIT = "word_limit"
AUDOI_TYPE = "audio_type"
MAX_WORKERS = "max_workers"
class ProviderModel(BaseModel):

View File

@@ -0,0 +1,42 @@
from abc import abstractmethod
from typing import Optional
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel
class TTSModel(AIModel):
"""
Model class for ttstext model.
"""
model_type: ModelType = ModelType.TTS
def invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param content_text: text content to be translated
:param streaming: output is streaming
:param user: unique user id
:return: translated audio file
"""
try:
return self._invoke(model=model, credentials=credentials, user=user, streaming=streaming, content_text=content_text)
except Exception as e:
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param content_text: text content to be translated
:param streaming: output is streaming
:param user: unique user id
:return: translated audio file
"""
raise NotImplementedError

View File

@@ -20,6 +20,7 @@ supported_model_types:
- text-embedding
- speech2text
- moderation
- tts
configurate_methods:
- predefined-model
- customizable-model

View File

@@ -0,0 +1,7 @@
model: tts-1-hd
model_type: tts
model_properties:
default_voice: 'alloy'
word_limit: 120
audio_type: 'mp3'
max_workers: 5

View File

@@ -0,0 +1,7 @@
model: tts-1
model_type: tts
model_properties:
default_voice: 'alloy'
word_limit: 120
audio_type: 'mp3'
max_workers: 5

View File

@@ -0,0 +1,235 @@
import uuid
import hashlib
import subprocess
from io import BytesIO
from typing import Optional
from functools import reduce
from pydub import AudioSegment
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.openai._common import _CommonOpenAI
from typing_extensions import Literal
from flask import Response, stream_with_context
from openai import OpenAI
import concurrent.futures
class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
"""
Model class for OpenAI Speech to text model.
"""
def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None) -> any:
"""
_invoke text2speech model
:param model: model name
:param credentials: model credentials
:param content_text: text content to be translated
:param streaming: output is streaming
:param user: unique user id
:return: text translated to audio file
"""
self._is_ffmpeg_installed()
audio_type = self._get_model_audio_type(model, credentials)
if streaming:
return Response(stream_with_context(self._tts_invoke_streaming(model=model,
credentials=credentials,
content_text=content_text,
user=user)),
status=200, mimetype=f'audio/{audio_type}')
else:
return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, user=user)
def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
"""
validate credentials text2speech model
:param model: model name
:param credentials: model credentials
:param user: unique user id
:return: text translated to audio file
"""
try:
self._tts_invoke(
model=model,
credentials=credentials,
content_text='Hello world!',
user=user
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _tts_invoke(self, model: str, credentials: dict, content_text: str, user: Optional[str] = None) -> any:
"""
_tts_invoke text2speech model
:param model: model name
:param credentials: model credentials
:param content_text: text content to be translated
:param user: unique user id
:return: text translated to audio file
"""
audio_type = self._get_model_audio_type(model, credentials)
word_limit = self._get_model_word_limit(model, credentials)
max_workers = self._get_model_workers_limit(model, credentials)
try:
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
audio_bytes_list = list()
# Create a thread pool and map the function to the list of sentences
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(self._process_sentence, sentence, model, credentials) for sentence
in sentences]
for future in futures:
try:
audio_bytes_list.append(future.result())
except Exception as ex:
raise InvokeBadRequestError(str(ex))
audio_segments = [AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) for audio_bytes in
audio_bytes_list if audio_bytes]
combined_segment = reduce(lambda x, y: x + y, audio_segments)
buffer: BytesIO = BytesIO()
combined_segment.export(buffer, format=audio_type)
buffer.seek(0)
return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
except Exception as ex:
raise InvokeBadRequestError(str(ex))
# Todo: To improve the streaming function
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, user: Optional[str] = None) -> any:
"""
_tts_invoke_streaming text2speech model
:param model: model name
:param credentials: model credentials
:param content_text: text content to be translated
:param user: unique user id
:return: text translated to audio file
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
voice_name = self._get_model_voice(model, credentials)
word_limit = self._get_model_word_limit(model, credentials)
audio_type = self._get_model_audio_type(model, credentials)
tts_file_id = self._get_file_name(content_text)
file_path = f'storage/generate_files/{audio_type}/{tts_file_id}.{audio_type}'
try:
client = OpenAI(**credentials_kwargs)
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
for sentence in sentences:
response = client.audio.speech.create(model=model, voice=voice_name, input=sentence.strip())
response.stream_to_file(file_path)
except Exception as ex:
raise InvokeBadRequestError(str(ex))
def _get_model_voice(self, model: str, credentials: dict) -> Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"]:
"""
Get voice for given tts model
:param model: model name
:param credentials: model credentials
:return: voice
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.DEFAULT_VOICE in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.DEFAULT_VOICE]
def _get_model_audio_type(self, model: str, credentials: dict) -> str:
"""
Get audio type for given tts model
:param model: model name
:param credentials: model credentials
:return: voice
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.AUDOI_TYPE in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.AUDOI_TYPE]
def _get_model_word_limit(self, model: str, credentials: dict) -> int:
"""
Get audio type for given tts model
:return: audio type
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.WORD_LIMIT in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
"""
Get audio max workers for given tts model
:return: audio type
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.MAX_WORKERS in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
@staticmethod
def _split_text_into_sentences(text: str, limit: int, delimiters=None):
if delimiters is None:
delimiters = set('。!?;\n')
buf = []
word_count = 0
for char in text:
buf.append(char)
if char in delimiters:
if word_count >= limit:
yield ''.join(buf)
buf = []
word_count = 0
else:
word_count += 1
else:
word_count += 1
if buf:
yield ''.join(buf)
@staticmethod
def _get_file_name(file_content: str) -> str:
hash_object = hashlib.sha256(file_content.encode())
hex_digest = hash_object.hexdigest()
namespace_uuid = uuid.UUID('a5da6ef9-b303-596f-8e88-bf8fa40f4b31')
unique_uuid = uuid.uuid5(namespace_uuid, hex_digest)
return str(unique_uuid)
def _process_sentence(self, sentence: str, model: str, credentials: dict):
"""
_tts_invoke openai text2speech model api
:param model: model name
:param credentials: model credentials
:param sentence: text content to be translated
:return: text translated to audio file
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
voice_name = self._get_model_voice(model, credentials)
client = OpenAI(**credentials_kwargs)
response = client.audio.speech.create(model=model, voice=voice_name, input=sentence.strip())
if isinstance(response.read(), bytes):
return response.read()
@staticmethod
def _is_ffmpeg_installed():
try:
output = subprocess.check_output("ffmpeg -version", shell=True)
if "ffmpeg version" in output.decode("utf-8"):
return True
else:
raise InvokeBadRequestError("ffmpeg is not installed")
except Exception:
raise InvokeBadRequestError("ffmpeg is not installed")