chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -11,15 +11,14 @@ class SageMakerProvider(BuiltinToolProviderController):
"credentials": credentials,
}
).invoke(
user_id='',
user_id="",
tool_parameters={
"sagemaker_endpoint" : "",
"sagemaker_endpoint": "",
"query": "misaka mikoto",
"candidate_texts" : "hello$$$hello world",
"topk" : 5,
"aws_region" : ""
"candidate_texts": "hello$$$hello world",
"topk": 5,
"aws_region": "",
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@@ -12,6 +12,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class GuardrailParameters(BaseModel):
guardrail_id: str = Field(..., description="The identifier of the guardrail")
guardrail_version: str = Field(..., description="The version of the guardrail")
@@ -19,35 +20,35 @@ class GuardrailParameters(BaseModel):
text: str = Field(..., description="The text to apply the guardrail to")
aws_region: str = Field(..., description="AWS region for the Bedrock client")
class ApplyGuardrailTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
Invoke the ApplyGuardrail tool
"""
try:
# Validate and parse input parameters
params = GuardrailParameters(**tool_parameters)
# Initialize AWS client
bedrock_client = boto3.client('bedrock-runtime', region_name=params.aws_region)
bedrock_client = boto3.client("bedrock-runtime", region_name=params.aws_region)
# Apply guardrail
response = bedrock_client.apply_guardrail(
guardrailIdentifier=params.guardrail_id,
guardrailVersion=params.guardrail_version,
source=params.source,
content=[{"text": {"text": params.text}}]
content=[{"text": {"text": params.text}}],
)
logger.info(f"Raw response from AWS: {json.dumps(response, indent=2)}")
# Check for empty response
if not response:
return self.create_text_message(text="Received empty response from AWS Bedrock.")
# Process the result
action = response.get("action", "No action specified")
outputs = response.get("outputs", [])
@@ -58,9 +59,11 @@ class ApplyGuardrailTool(BuiltinTool):
formatted_assessments = []
for assessment in assessments:
for policy_type, policy_data in assessment.items():
if isinstance(policy_data, dict) and 'topics' in policy_data:
for topic in policy_data['topics']:
formatted_assessments.append(f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']}, Action: {topic['action']}")
if isinstance(policy_data, dict) and "topics" in policy_data:
for topic in policy_data["topics"]:
formatted_assessments.append(
f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']}, Action: {topic['action']}"
)
else:
formatted_assessments.append(f"Policy: {policy_type}, Data: {policy_data}")
@@ -68,19 +71,19 @@ class ApplyGuardrailTool(BuiltinTool):
result += f"Output: {output}\n "
if formatted_assessments:
result += "Assessments:\n " + "\n ".join(formatted_assessments) + "\n "
# result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}"
# result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}"
return self.create_text_message(text=result)
except BotoCoreError as e:
error_message = f'AWS service error: {str(e)}'
error_message = f"AWS service error: {str(e)}"
logger.error(error_message, exc_info=True)
return self.create_text_message(text=error_message)
except json.JSONDecodeError as e:
error_message = f'JSON parsing error: {str(e)}'
error_message = f"JSON parsing error: {str(e)}"
logger.error(error_message, exc_info=True)
return self.create_text_message(text=error_message)
except Exception as e:
error_message = f'An unexpected error occurred: {str(e)}'
error_message = f"An unexpected error occurred: {str(e)}"
logger.error(error_message, exc_info=True)
return self.create_text_message(text=error_message)
return self.create_text_message(text=error_message)

View File

@@ -11,78 +11,81 @@ class LambdaTranslateUtilsTool(BuiltinTool):
lambda_client: Any = None
def _invoke_lambda(self, text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name):
msg = {
"src_content":text_content,
"src_lang": src_lang,
"dest_lang":dest_lang,
msg = {
"src_content": text_content,
"src_lang": src_lang,
"dest_lang": dest_lang,
"dictionary_id": dictionary_name,
"request_type" : request_type,
"model_id" : model_id
"request_type": request_type,
"model_id": model_id,
}
invoke_response = self.lambda_client.invoke(FunctionName=lambda_name,
InvocationType='RequestResponse',
Payload=json.dumps(msg))
response_body = invoke_response['Payload']
invoke_response = self.lambda_client.invoke(
FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg)
)
response_body = invoke_response["Payload"]
response_str = response_body.read().decode("unicode_escape")
return response_str
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
invoke tools
"""
line = 0
try:
if not self.lambda_client:
aws_region = tool_parameters.get('aws_region')
aws_region = tool_parameters.get("aws_region")
if aws_region:
self.lambda_client = boto3.client("lambda", region_name=aws_region)
else:
self.lambda_client = boto3.client("lambda")
line = 1
text_content = tool_parameters.get('text_content', '')
text_content = tool_parameters.get("text_content", "")
if not text_content:
return self.create_text_message('Please input text_content')
return self.create_text_message("Please input text_content")
line = 2
src_lang = tool_parameters.get('src_lang', '')
src_lang = tool_parameters.get("src_lang", "")
if not src_lang:
return self.create_text_message('Please input src_lang')
return self.create_text_message("Please input src_lang")
line = 3
dest_lang = tool_parameters.get('dest_lang', '')
dest_lang = tool_parameters.get("dest_lang", "")
if not dest_lang:
return self.create_text_message('Please input dest_lang')
return self.create_text_message("Please input dest_lang")
line = 4
lambda_name = tool_parameters.get('lambda_name', '')
lambda_name = tool_parameters.get("lambda_name", "")
if not lambda_name:
return self.create_text_message('Please input lambda_name')
return self.create_text_message("Please input lambda_name")
line = 5
request_type = tool_parameters.get('request_type', '')
request_type = tool_parameters.get("request_type", "")
if not request_type:
return self.create_text_message('Please input request_type')
return self.create_text_message("Please input request_type")
line = 6
model_id = tool_parameters.get('model_id', '')
model_id = tool_parameters.get("model_id", "")
if not model_id:
return self.create_text_message('Please input model_id')
return self.create_text_message("Please input model_id")
line = 7
dictionary_name = tool_parameters.get('dictionary_name', '')
dictionary_name = tool_parameters.get("dictionary_name", "")
if not dictionary_name:
return self.create_text_message('Please input dictionary_name')
result = self._invoke_lambda(text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name)
return self.create_text_message("Please input dictionary_name")
result = self._invoke_lambda(
text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name
)
return self.create_text_message(text=result)
except Exception as e:
return self.create_text_message(f'Exception {str(e)}, line : {line}')
return self.create_text_message(f"Exception {str(e)}, line : {line}")

View File

@@ -18,54 +18,53 @@ class LambdaYamlToJsonTool(BuiltinTool):
lambda_client: Any = None
def _invoke_lambda(self, lambda_name: str, yaml_content: str) -> str:
msg = {
"body": yaml_content
}
msg = {"body": yaml_content}
logger.info(json.dumps(msg))
invoke_response = self.lambda_client.invoke(FunctionName=lambda_name,
InvocationType='RequestResponse',
Payload=json.dumps(msg))
response_body = invoke_response['Payload']
invoke_response = self.lambda_client.invoke(
FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg)
)
response_body = invoke_response["Payload"]
response_str = response_body.read().decode("utf-8")
resp_json = json.loads(response_str)
logger.info(resp_json)
if resp_json['statusCode'] != 200:
if resp_json["statusCode"] != 200:
raise Exception(f"Invalid status code: {response_str}")
return resp_json['body']
return resp_json["body"]
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
invoke tools
"""
try:
if not self.lambda_client:
aws_region = tool_parameters.get('aws_region') # todo: move aws_region out, and update client region
aws_region = tool_parameters.get("aws_region") # todo: move aws_region out, and update client region
if aws_region:
self.lambda_client = boto3.client("lambda", region_name=aws_region)
else:
self.lambda_client = boto3.client("lambda")
yaml_content = tool_parameters.get('yaml_content', '')
yaml_content = tool_parameters.get("yaml_content", "")
if not yaml_content:
return self.create_text_message('Please input yaml_content')
return self.create_text_message("Please input yaml_content")
lambda_name = tool_parameters.get('lambda_name', '')
lambda_name = tool_parameters.get("lambda_name", "")
if not lambda_name:
return self.create_text_message('Please input lambda_name')
logger.debug(f'{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}')
return self.create_text_message("Please input lambda_name")
logger.debug(f"{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}")
result = self._invoke_lambda(lambda_name, yaml_content)
logger.debug(result)
return self.create_text_message(result)
except Exception as e:
return self.create_text_message(f'Exception: {str(e)}')
return self.create_text_message(f"Exception: {str(e)}")
console_handler.flush()
console_handler.flush()

View File

@@ -9,37 +9,33 @@ from core.tools.tool.builtin_tool import BuiltinTool
class SageMakerReRankTool(BuiltinTool):
sagemaker_client: Any = None
sagemaker_endpoint:str = None
topk:int = None
sagemaker_endpoint: str = None
topk: int = None
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str):
inputs = [query_input]*len(docs)
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
inputs = [query_input] * len(docs)
response_model = self.sagemaker_client.invoke_endpoint(
EndpointName=rerank_endpoint,
Body=json.dumps(
{
"inputs": inputs,
"docs": docs
}
),
Body=json.dumps({"inputs": inputs, "docs": docs}),
ContentType="application/json",
)
json_str = response_model['Body'].read().decode('utf8')
json_str = response_model["Body"].read().decode("utf8")
json_obj = json.loads(json_str)
scores = json_obj['scores']
scores = json_obj["scores"]
return scores if isinstance(scores, list) else [scores]
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
invoke tools
"""
line = 0
try:
if not self.sagemaker_client:
aws_region = tool_parameters.get('aws_region')
aws_region = tool_parameters.get("aws_region")
if aws_region:
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
else:
@@ -47,25 +43,25 @@ class SageMakerReRankTool(BuiltinTool):
line = 1
if not self.sagemaker_endpoint:
self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint')
self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint")
line = 2
if not self.topk:
self.topk = tool_parameters.get('topk', 5)
self.topk = tool_parameters.get("topk", 5)
line = 3
query = tool_parameters.get('query', '')
query = tool_parameters.get("query", "")
if not query:
return self.create_text_message('Please input query')
return self.create_text_message("Please input query")
line = 4
candidate_texts = tool_parameters.get('candidate_texts')
candidate_texts = tool_parameters.get("candidate_texts")
if not candidate_texts:
return self.create_text_message('Please input candidate_texts')
return self.create_text_message("Please input candidate_texts")
line = 5
candidate_docs = json.loads(candidate_texts)
docs = [ item.get('content') for item in candidate_docs ]
docs = [item.get("content") for item in candidate_docs]
line = 6
scores = self._sagemaker_rerank(query_input=query, docs=docs, rerank_endpoint=self.sagemaker_endpoint)
@@ -75,10 +71,10 @@ class SageMakerReRankTool(BuiltinTool):
candidate_docs[idx]["score"] = scores[idx]
line = 8
sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x['score'], reverse=True)
sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x["score"], reverse=True)
line = 9
return [ self.create_json_message(res) for res in sorted_candidate_docs[:self.topk] ]
return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]]
except Exception as e:
return self.create_text_message(f'Exception {str(e)}, line : {line}')
return self.create_text_message(f"Exception {str(e)}, line : {line}")

View File

@@ -14,82 +14,88 @@ class TTSModelType(Enum):
CloneVoice_CrossLingual = "CloneVoice_CrossLingual"
InstructVoice = "InstructVoice"
class SageMakerTTSTool(BuiltinTool):
sagemaker_client: Any = None
sagemaker_endpoint:str = None
s3_client : Any = None
comprehend_client : Any = None
sagemaker_endpoint: str = None
s3_client: Any = None
comprehend_client: Any = None
def _detect_lang_code(self, content:str, map_dict:dict=None):
map_dict = {
"zh" : "<|zh|>",
"en" : "<|en|>",
"ja" : "<|jp|>",
"zh-TW" : "<|yue|>",
"ko" : "<|ko|>"
}
def _detect_lang_code(self, content: str, map_dict: dict = None):
map_dict = {"zh": "<|zh|>", "en": "<|en|>", "ja": "<|jp|>", "zh-TW": "<|yue|>", "ko": "<|ko|>"}
response = self.comprehend_client.detect_dominant_language(Text=content)
language_code = response['Languages'][0]['LanguageCode']
return map_dict.get(language_code, '<|zh|>')
language_code = response["Languages"][0]["LanguageCode"]
return map_dict.get(language_code, "<|zh|>")
def _build_tts_payload(self, model_type:str, content_text:str, model_role:str, prompt_text:str, prompt_audio:str, instruct_text:str):
def _build_tts_payload(
self,
model_type: str,
content_text: str,
model_role: str,
prompt_text: str,
prompt_audio: str,
instruct_text: str,
):
if model_type == TTSModelType.PresetVoice.value and model_role:
return { "tts_text" : content_text, "role" : model_role }
return {"tts_text": content_text, "role": model_role}
if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio:
return { "tts_text" : content_text, "prompt_text": prompt_text, "prompt_audio" : prompt_audio }
if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}
if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
lang_tag = self._detect_lang_code(content_text)
return { "tts_text" : f"{content_text}", "prompt_audio" : prompt_audio, "lang_tag" : lang_tag }
if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role:
return { "tts_text" : content_text, "role" : model_role, "instruct_text" : instruct_text }
return {"tts_text": f"{content_text}", "prompt_audio": prompt_audio, "lang_tag": lang_tag}
if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role:
return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text}
raise RuntimeError(f"Invalid params for {model_type}")
def _invoke_sagemaker(self, payload:dict, endpoint:str):
def _invoke_sagemaker(self, payload: dict, endpoint: str):
response_model = self.sagemaker_client.invoke_endpoint(
EndpointName=endpoint,
Body=json.dumps(payload),
ContentType="application/json",
)
json_str = response_model['Body'].read().decode('utf8')
json_str = response_model["Body"].read().decode("utf8")
json_obj = json.loads(json_str)
return json_obj
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
invoke tools
"""
try:
if not self.sagemaker_client:
aws_region = tool_parameters.get('aws_region')
aws_region = tool_parameters.get("aws_region")
if aws_region:
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
self.s3_client = boto3.client("s3", region_name=aws_region)
self.comprehend_client = boto3.client('comprehend', region_name=aws_region)
self.comprehend_client = boto3.client("comprehend", region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime")
self.s3_client = boto3.client("s3")
self.comprehend_client = boto3.client('comprehend')
self.comprehend_client = boto3.client("comprehend")
if not self.sagemaker_endpoint:
self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint')
self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint")
tts_text = tool_parameters.get('tts_text')
tts_infer_type = tool_parameters.get('tts_infer_type')
tts_text = tool_parameters.get("tts_text")
tts_infer_type = tool_parameters.get("tts_infer_type")
voice = tool_parameters.get('voice')
mock_voice_audio = tool_parameters.get('mock_voice_audio')
mock_voice_text = tool_parameters.get('mock_voice_text')
voice_instruct_prompt = tool_parameters.get('voice_instruct_prompt')
payload = self._build_tts_payload(tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt)
voice = tool_parameters.get("voice")
mock_voice_audio = tool_parameters.get("mock_voice_audio")
mock_voice_text = tool_parameters.get("mock_voice_text")
voice_instruct_prompt = tool_parameters.get("voice_instruct_prompt")
payload = self._build_tts_payload(
tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt
)
result = self._invoke_sagemaker(payload, self.sagemaker_endpoint)
return self.create_text_message(text=result['s3_presign_url'])
return self.create_text_message(text=result["s3_presign_url"])
except Exception as e:
return self.create_text_message(f'Exception {str(e)}')
return self.create_text_message(f"Exception {str(e)}")