mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-15 13:56:53 +08:00
chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -11,15 +11,14 @@ class SageMakerProvider(BuiltinToolProviderController):
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"sagemaker_endpoint" : "",
|
||||
"sagemaker_endpoint": "",
|
||||
"query": "misaka mikoto",
|
||||
"candidate_texts" : "hello$$$hello world",
|
||||
"topk" : 5,
|
||||
"aws_region" : ""
|
||||
"candidate_texts": "hello$$$hello world",
|
||||
"topk": 5,
|
||||
"aws_region": "",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
@@ -12,6 +12,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GuardrailParameters(BaseModel):
|
||||
guardrail_id: str = Field(..., description="The identifier of the guardrail")
|
||||
guardrail_version: str = Field(..., description="The version of the guardrail")
|
||||
@@ -19,35 +20,35 @@ class GuardrailParameters(BaseModel):
|
||||
text: str = Field(..., description="The text to apply the guardrail to")
|
||||
aws_region: str = Field(..., description="AWS region for the Bedrock client")
|
||||
|
||||
|
||||
class ApplyGuardrailTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invoke the ApplyGuardrail tool
|
||||
"""
|
||||
try:
|
||||
# Validate and parse input parameters
|
||||
params = GuardrailParameters(**tool_parameters)
|
||||
|
||||
|
||||
# Initialize AWS client
|
||||
bedrock_client = boto3.client('bedrock-runtime', region_name=params.aws_region)
|
||||
bedrock_client = boto3.client("bedrock-runtime", region_name=params.aws_region)
|
||||
|
||||
# Apply guardrail
|
||||
response = bedrock_client.apply_guardrail(
|
||||
guardrailIdentifier=params.guardrail_id,
|
||||
guardrailVersion=params.guardrail_version,
|
||||
source=params.source,
|
||||
content=[{"text": {"text": params.text}}]
|
||||
content=[{"text": {"text": params.text}}],
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"Raw response from AWS: {json.dumps(response, indent=2)}")
|
||||
|
||||
# Check for empty response
|
||||
if not response:
|
||||
return self.create_text_message(text="Received empty response from AWS Bedrock.")
|
||||
|
||||
|
||||
# Process the result
|
||||
action = response.get("action", "No action specified")
|
||||
outputs = response.get("outputs", [])
|
||||
@@ -58,9 +59,11 @@ class ApplyGuardrailTool(BuiltinTool):
|
||||
formatted_assessments = []
|
||||
for assessment in assessments:
|
||||
for policy_type, policy_data in assessment.items():
|
||||
if isinstance(policy_data, dict) and 'topics' in policy_data:
|
||||
for topic in policy_data['topics']:
|
||||
formatted_assessments.append(f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']}, Action: {topic['action']}")
|
||||
if isinstance(policy_data, dict) and "topics" in policy_data:
|
||||
for topic in policy_data["topics"]:
|
||||
formatted_assessments.append(
|
||||
f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']}, Action: {topic['action']}"
|
||||
)
|
||||
else:
|
||||
formatted_assessments.append(f"Policy: {policy_type}, Data: {policy_data}")
|
||||
|
||||
@@ -68,19 +71,19 @@ class ApplyGuardrailTool(BuiltinTool):
|
||||
result += f"Output: {output}\n "
|
||||
if formatted_assessments:
|
||||
result += "Assessments:\n " + "\n ".join(formatted_assessments) + "\n "
|
||||
# result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}"
|
||||
# result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}"
|
||||
|
||||
return self.create_text_message(text=result)
|
||||
|
||||
except BotoCoreError as e:
|
||||
error_message = f'AWS service error: {str(e)}'
|
||||
error_message = f"AWS service error: {str(e)}"
|
||||
logger.error(error_message, exc_info=True)
|
||||
return self.create_text_message(text=error_message)
|
||||
except json.JSONDecodeError as e:
|
||||
error_message = f'JSON parsing error: {str(e)}'
|
||||
error_message = f"JSON parsing error: {str(e)}"
|
||||
logger.error(error_message, exc_info=True)
|
||||
return self.create_text_message(text=error_message)
|
||||
except Exception as e:
|
||||
error_message = f'An unexpected error occurred: {str(e)}'
|
||||
error_message = f"An unexpected error occurred: {str(e)}"
|
||||
logger.error(error_message, exc_info=True)
|
||||
return self.create_text_message(text=error_message)
|
||||
return self.create_text_message(text=error_message)
|
||||
|
||||
@@ -11,78 +11,81 @@ class LambdaTranslateUtilsTool(BuiltinTool):
|
||||
lambda_client: Any = None
|
||||
|
||||
def _invoke_lambda(self, text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name):
|
||||
msg = {
|
||||
"src_content":text_content,
|
||||
"src_lang": src_lang,
|
||||
"dest_lang":dest_lang,
|
||||
msg = {
|
||||
"src_content": text_content,
|
||||
"src_lang": src_lang,
|
||||
"dest_lang": dest_lang,
|
||||
"dictionary_id": dictionary_name,
|
||||
"request_type" : request_type,
|
||||
"model_id" : model_id
|
||||
"request_type": request_type,
|
||||
"model_id": model_id,
|
||||
}
|
||||
|
||||
invoke_response = self.lambda_client.invoke(FunctionName=lambda_name,
|
||||
InvocationType='RequestResponse',
|
||||
Payload=json.dumps(msg))
|
||||
response_body = invoke_response['Payload']
|
||||
invoke_response = self.lambda_client.invoke(
|
||||
FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg)
|
||||
)
|
||||
response_body = invoke_response["Payload"]
|
||||
|
||||
response_str = response_body.read().decode("unicode_escape")
|
||||
|
||||
return response_str
|
||||
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
invoke tools
|
||||
"""
|
||||
line = 0
|
||||
try:
|
||||
if not self.lambda_client:
|
||||
aws_region = tool_parameters.get('aws_region')
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
if aws_region:
|
||||
self.lambda_client = boto3.client("lambda", region_name=aws_region)
|
||||
else:
|
||||
self.lambda_client = boto3.client("lambda")
|
||||
|
||||
line = 1
|
||||
text_content = tool_parameters.get('text_content', '')
|
||||
text_content = tool_parameters.get("text_content", "")
|
||||
if not text_content:
|
||||
return self.create_text_message('Please input text_content')
|
||||
|
||||
return self.create_text_message("Please input text_content")
|
||||
|
||||
line = 2
|
||||
src_lang = tool_parameters.get('src_lang', '')
|
||||
src_lang = tool_parameters.get("src_lang", "")
|
||||
if not src_lang:
|
||||
return self.create_text_message('Please input src_lang')
|
||||
|
||||
return self.create_text_message("Please input src_lang")
|
||||
|
||||
line = 3
|
||||
dest_lang = tool_parameters.get('dest_lang', '')
|
||||
dest_lang = tool_parameters.get("dest_lang", "")
|
||||
if not dest_lang:
|
||||
return self.create_text_message('Please input dest_lang')
|
||||
|
||||
return self.create_text_message("Please input dest_lang")
|
||||
|
||||
line = 4
|
||||
lambda_name = tool_parameters.get('lambda_name', '')
|
||||
lambda_name = tool_parameters.get("lambda_name", "")
|
||||
if not lambda_name:
|
||||
return self.create_text_message('Please input lambda_name')
|
||||
|
||||
return self.create_text_message("Please input lambda_name")
|
||||
|
||||
line = 5
|
||||
request_type = tool_parameters.get('request_type', '')
|
||||
request_type = tool_parameters.get("request_type", "")
|
||||
if not request_type:
|
||||
return self.create_text_message('Please input request_type')
|
||||
|
||||
return self.create_text_message("Please input request_type")
|
||||
|
||||
line = 6
|
||||
model_id = tool_parameters.get('model_id', '')
|
||||
model_id = tool_parameters.get("model_id", "")
|
||||
if not model_id:
|
||||
return self.create_text_message('Please input model_id')
|
||||
return self.create_text_message("Please input model_id")
|
||||
|
||||
line = 7
|
||||
dictionary_name = tool_parameters.get('dictionary_name', '')
|
||||
dictionary_name = tool_parameters.get("dictionary_name", "")
|
||||
if not dictionary_name:
|
||||
return self.create_text_message('Please input dictionary_name')
|
||||
|
||||
result = self._invoke_lambda(text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name)
|
||||
return self.create_text_message("Please input dictionary_name")
|
||||
|
||||
result = self._invoke_lambda(
|
||||
text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name
|
||||
)
|
||||
|
||||
return self.create_text_message(text=result)
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f'Exception {str(e)}, line : {line}')
|
||||
return self.create_text_message(f"Exception {str(e)}, line : {line}")
|
||||
|
||||
@@ -18,54 +18,53 @@ class LambdaYamlToJsonTool(BuiltinTool):
|
||||
lambda_client: Any = None
|
||||
|
||||
def _invoke_lambda(self, lambda_name: str, yaml_content: str) -> str:
|
||||
msg = {
|
||||
"body": yaml_content
|
||||
}
|
||||
msg = {"body": yaml_content}
|
||||
logger.info(json.dumps(msg))
|
||||
|
||||
invoke_response = self.lambda_client.invoke(FunctionName=lambda_name,
|
||||
InvocationType='RequestResponse',
|
||||
Payload=json.dumps(msg))
|
||||
response_body = invoke_response['Payload']
|
||||
invoke_response = self.lambda_client.invoke(
|
||||
FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg)
|
||||
)
|
||||
response_body = invoke_response["Payload"]
|
||||
|
||||
response_str = response_body.read().decode("utf-8")
|
||||
resp_json = json.loads(response_str)
|
||||
|
||||
logger.info(resp_json)
|
||||
if resp_json['statusCode'] != 200:
|
||||
if resp_json["statusCode"] != 200:
|
||||
raise Exception(f"Invalid status code: {response_str}")
|
||||
|
||||
return resp_json['body']
|
||||
return resp_json["body"]
|
||||
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
invoke tools
|
||||
"""
|
||||
try:
|
||||
if not self.lambda_client:
|
||||
aws_region = tool_parameters.get('aws_region') # todo: move aws_region out, and update client region
|
||||
aws_region = tool_parameters.get("aws_region") # todo: move aws_region out, and update client region
|
||||
if aws_region:
|
||||
self.lambda_client = boto3.client("lambda", region_name=aws_region)
|
||||
else:
|
||||
self.lambda_client = boto3.client("lambda")
|
||||
|
||||
yaml_content = tool_parameters.get('yaml_content', '')
|
||||
yaml_content = tool_parameters.get("yaml_content", "")
|
||||
if not yaml_content:
|
||||
return self.create_text_message('Please input yaml_content')
|
||||
return self.create_text_message("Please input yaml_content")
|
||||
|
||||
lambda_name = tool_parameters.get('lambda_name', '')
|
||||
lambda_name = tool_parameters.get("lambda_name", "")
|
||||
if not lambda_name:
|
||||
return self.create_text_message('Please input lambda_name')
|
||||
logger.debug(f'{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}')
|
||||
|
||||
return self.create_text_message("Please input lambda_name")
|
||||
logger.debug(f"{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}")
|
||||
|
||||
result = self._invoke_lambda(lambda_name, yaml_content)
|
||||
logger.debug(result)
|
||||
|
||||
|
||||
return self.create_text_message(result)
|
||||
except Exception as e:
|
||||
return self.create_text_message(f'Exception: {str(e)}')
|
||||
return self.create_text_message(f"Exception: {str(e)}")
|
||||
|
||||
console_handler.flush()
|
||||
console_handler.flush()
|
||||
|
||||
@@ -9,37 +9,33 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
class SageMakerReRankTool(BuiltinTool):
|
||||
sagemaker_client: Any = None
|
||||
sagemaker_endpoint:str = None
|
||||
topk:int = None
|
||||
sagemaker_endpoint: str = None
|
||||
topk: int = None
|
||||
|
||||
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str):
|
||||
inputs = [query_input]*len(docs)
|
||||
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
|
||||
inputs = [query_input] * len(docs)
|
||||
response_model = self.sagemaker_client.invoke_endpoint(
|
||||
EndpointName=rerank_endpoint,
|
||||
Body=json.dumps(
|
||||
{
|
||||
"inputs": inputs,
|
||||
"docs": docs
|
||||
}
|
||||
),
|
||||
Body=json.dumps({"inputs": inputs, "docs": docs}),
|
||||
ContentType="application/json",
|
||||
)
|
||||
json_str = response_model['Body'].read().decode('utf8')
|
||||
json_str = response_model["Body"].read().decode("utf8")
|
||||
json_obj = json.loads(json_str)
|
||||
scores = json_obj['scores']
|
||||
scores = json_obj["scores"]
|
||||
return scores if isinstance(scores, list) else [scores]
|
||||
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
invoke tools
|
||||
"""
|
||||
line = 0
|
||||
try:
|
||||
if not self.sagemaker_client:
|
||||
aws_region = tool_parameters.get('aws_region')
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
if aws_region:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
|
||||
else:
|
||||
@@ -47,25 +43,25 @@ class SageMakerReRankTool(BuiltinTool):
|
||||
|
||||
line = 1
|
||||
if not self.sagemaker_endpoint:
|
||||
self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint')
|
||||
self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint")
|
||||
|
||||
line = 2
|
||||
if not self.topk:
|
||||
self.topk = tool_parameters.get('topk', 5)
|
||||
self.topk = tool_parameters.get("topk", 5)
|
||||
|
||||
line = 3
|
||||
query = tool_parameters.get('query', '')
|
||||
query = tool_parameters.get("query", "")
|
||||
if not query:
|
||||
return self.create_text_message('Please input query')
|
||||
|
||||
return self.create_text_message("Please input query")
|
||||
|
||||
line = 4
|
||||
candidate_texts = tool_parameters.get('candidate_texts')
|
||||
candidate_texts = tool_parameters.get("candidate_texts")
|
||||
if not candidate_texts:
|
||||
return self.create_text_message('Please input candidate_texts')
|
||||
|
||||
return self.create_text_message("Please input candidate_texts")
|
||||
|
||||
line = 5
|
||||
candidate_docs = json.loads(candidate_texts)
|
||||
docs = [ item.get('content') for item in candidate_docs ]
|
||||
docs = [item.get("content") for item in candidate_docs]
|
||||
|
||||
line = 6
|
||||
scores = self._sagemaker_rerank(query_input=query, docs=docs, rerank_endpoint=self.sagemaker_endpoint)
|
||||
@@ -75,10 +71,10 @@ class SageMakerReRankTool(BuiltinTool):
|
||||
candidate_docs[idx]["score"] = scores[idx]
|
||||
|
||||
line = 8
|
||||
sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x['score'], reverse=True)
|
||||
sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x["score"], reverse=True)
|
||||
|
||||
line = 9
|
||||
return [ self.create_json_message(res) for res in sorted_candidate_docs[:self.topk] ]
|
||||
|
||||
return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]]
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f'Exception {str(e)}, line : {line}')
|
||||
return self.create_text_message(f"Exception {str(e)}, line : {line}")
|
||||
|
||||
@@ -14,82 +14,88 @@ class TTSModelType(Enum):
|
||||
CloneVoice_CrossLingual = "CloneVoice_CrossLingual"
|
||||
InstructVoice = "InstructVoice"
|
||||
|
||||
|
||||
class SageMakerTTSTool(BuiltinTool):
|
||||
sagemaker_client: Any = None
|
||||
sagemaker_endpoint:str = None
|
||||
s3_client : Any = None
|
||||
comprehend_client : Any = None
|
||||
sagemaker_endpoint: str = None
|
||||
s3_client: Any = None
|
||||
comprehend_client: Any = None
|
||||
|
||||
def _detect_lang_code(self, content:str, map_dict:dict=None):
|
||||
map_dict = {
|
||||
"zh" : "<|zh|>",
|
||||
"en" : "<|en|>",
|
||||
"ja" : "<|jp|>",
|
||||
"zh-TW" : "<|yue|>",
|
||||
"ko" : "<|ko|>"
|
||||
}
|
||||
def _detect_lang_code(self, content: str, map_dict: dict = None):
|
||||
map_dict = {"zh": "<|zh|>", "en": "<|en|>", "ja": "<|jp|>", "zh-TW": "<|yue|>", "ko": "<|ko|>"}
|
||||
|
||||
response = self.comprehend_client.detect_dominant_language(Text=content)
|
||||
language_code = response['Languages'][0]['LanguageCode']
|
||||
return map_dict.get(language_code, '<|zh|>')
|
||||
language_code = response["Languages"][0]["LanguageCode"]
|
||||
return map_dict.get(language_code, "<|zh|>")
|
||||
|
||||
def _build_tts_payload(self, model_type:str, content_text:str, model_role:str, prompt_text:str, prompt_audio:str, instruct_text:str):
|
||||
def _build_tts_payload(
|
||||
self,
|
||||
model_type: str,
|
||||
content_text: str,
|
||||
model_role: str,
|
||||
prompt_text: str,
|
||||
prompt_audio: str,
|
||||
instruct_text: str,
|
||||
):
|
||||
if model_type == TTSModelType.PresetVoice.value and model_role:
|
||||
return { "tts_text" : content_text, "role" : model_role }
|
||||
return {"tts_text": content_text, "role": model_role}
|
||||
if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio:
|
||||
return { "tts_text" : content_text, "prompt_text": prompt_text, "prompt_audio" : prompt_audio }
|
||||
if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
|
||||
return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}
|
||||
if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
|
||||
lang_tag = self._detect_lang_code(content_text)
|
||||
return { "tts_text" : f"{content_text}", "prompt_audio" : prompt_audio, "lang_tag" : lang_tag }
|
||||
if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role:
|
||||
return { "tts_text" : content_text, "role" : model_role, "instruct_text" : instruct_text }
|
||||
return {"tts_text": f"{content_text}", "prompt_audio": prompt_audio, "lang_tag": lang_tag}
|
||||
if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role:
|
||||
return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text}
|
||||
|
||||
raise RuntimeError(f"Invalid params for {model_type}")
|
||||
|
||||
def _invoke_sagemaker(self, payload:dict, endpoint:str):
|
||||
def _invoke_sagemaker(self, payload: dict, endpoint: str):
|
||||
response_model = self.sagemaker_client.invoke_endpoint(
|
||||
EndpointName=endpoint,
|
||||
Body=json.dumps(payload),
|
||||
ContentType="application/json",
|
||||
)
|
||||
json_str = response_model['Body'].read().decode('utf8')
|
||||
json_str = response_model["Body"].read().decode("utf8")
|
||||
json_obj = json.loads(json_str)
|
||||
return json_obj
|
||||
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
invoke tools
|
||||
"""
|
||||
try:
|
||||
if not self.sagemaker_client:
|
||||
aws_region = tool_parameters.get('aws_region')
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
if aws_region:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
|
||||
self.s3_client = boto3.client("s3", region_name=aws_region)
|
||||
self.comprehend_client = boto3.client('comprehend', region_name=aws_region)
|
||||
self.comprehend_client = boto3.client("comprehend", region_name=aws_region)
|
||||
else:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime")
|
||||
self.s3_client = boto3.client("s3")
|
||||
self.comprehend_client = boto3.client('comprehend')
|
||||
self.comprehend_client = boto3.client("comprehend")
|
||||
|
||||
if not self.sagemaker_endpoint:
|
||||
self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint')
|
||||
self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint")
|
||||
|
||||
tts_text = tool_parameters.get('tts_text')
|
||||
tts_infer_type = tool_parameters.get('tts_infer_type')
|
||||
tts_text = tool_parameters.get("tts_text")
|
||||
tts_infer_type = tool_parameters.get("tts_infer_type")
|
||||
|
||||
voice = tool_parameters.get('voice')
|
||||
mock_voice_audio = tool_parameters.get('mock_voice_audio')
|
||||
mock_voice_text = tool_parameters.get('mock_voice_text')
|
||||
voice_instruct_prompt = tool_parameters.get('voice_instruct_prompt')
|
||||
payload = self._build_tts_payload(tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt)
|
||||
voice = tool_parameters.get("voice")
|
||||
mock_voice_audio = tool_parameters.get("mock_voice_audio")
|
||||
mock_voice_text = tool_parameters.get("mock_voice_text")
|
||||
voice_instruct_prompt = tool_parameters.get("voice_instruct_prompt")
|
||||
payload = self._build_tts_payload(
|
||||
tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt
|
||||
)
|
||||
|
||||
result = self._invoke_sagemaker(payload, self.sagemaker_endpoint)
|
||||
|
||||
return self.create_text_message(text=result['s3_presign_url'])
|
||||
|
||||
return self.create_text_message(text=result["s3_presign_url"])
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f'Exception {str(e)}')
|
||||
return self.create_text_message(f"Exception {str(e)}")
|
||||
|
||||
Reference in New Issue
Block a user