chore(api/core): apply ruff reformatting (#7624)

Author: Bowen Liang
Date: 2024-09-10 17:00:20 +08:00
Committed by: GitHub
Parent: 178730266d
Commit: 2cf1187b32
724 changed files with 21180 additions and 21123 deletions
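
The hunks below are mechanical: ruff's formatter normalizes string literals to double quotes, re-indents docstring bodies, collapses short calls onto one line, adds a trailing comma to argument lists it keeps multi-line, and replaces backslash continuations and aligned continuation lines with parenthesized blocks. A change set like this is typically produced by running "ruff format" over the package; the exact invocation and project settings are not shown in this diff, and the 120-character line limit is an assumption inferred from the width of the reformatted lines (ruff's default is 88). A minimal sketch of the style the formatter converges on, using hypothetical names rather than code from this commit:

    # Hypothetical snippet, not taken from this commit: written the way ruff's
    # formatter emits code (double quotes, magic trailing comma keeping the
    # call exploded, assumed line-length = 120).
    def build_params(width: int, height: int) -> dict:
        # Short literals stay on one line; all strings use double quotes.
        return {"prompt": "a futuristic city with flying cars", "width": width, "height": height}


    params = build_params(
        width=128,
        height=128,
    )
    print(params)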


@@ -12,10 +12,10 @@ class NovitaAiToolBase:
         if not loras_str:
             return []
 
-        loras_ori_list = lora_str.strip().split(';')
+        loras_ori_list = lora_str.strip().split(";")
         result_list = []
         for lora_str in loras_ori_list:
-            lora_info = lora_str.strip().split(',')
+            lora_info = lora_str.strip().split(",")
             lora = Txt2ImgV3LoRA(
                 model_name=lora_info[0].strip(),
                 strength=float(lora_info[1]),
@@ -28,43 +28,39 @@ class NovitaAiToolBase:
         if not embeddings_str:
             return []
 
-        embeddings_ori_list = embeddings_str.strip().split(';')
+        embeddings_ori_list = embeddings_str.strip().split(";")
         result_list = []
         for embedding_str in embeddings_ori_list:
-            embedding = Txt2ImgV3Embedding(
-                model_name=embedding_str.strip()
-            )
+            embedding = Txt2ImgV3Embedding(model_name=embedding_str.strip())
             result_list.append(embedding)
 
         return result_list
 
     def _extract_hires_fix(self, hires_fix_str: str):
-        hires_fix_info = hires_fix_str.strip().split(',')
-        if 'upscaler' in hires_fix_info:
+        hires_fix_info = hires_fix_str.strip().split(",")
+        if "upscaler" in hires_fix_info:
             hires_fix = Txt2ImgV3HiresFix(
                 target_width=int(hires_fix_info[0]),
                 target_height=int(hires_fix_info[1]),
                 strength=float(hires_fix_info[2]),
-                upscaler=hires_fix_info[3].strip()
+                upscaler=hires_fix_info[3].strip(),
             )
         else:
             hires_fix = Txt2ImgV3HiresFix(
                 target_width=int(hires_fix_info[0]),
                 target_height=int(hires_fix_info[1]),
-                strength=float(hires_fix_info[2])
+                strength=float(hires_fix_info[2]),
             )
 
         return hires_fix
 
     def _extract_refiner(self, switch_at: str):
-        refiner = Txt2ImgV3Refiner(
-            switch_at=float(switch_at)
-        )
+        refiner = Txt2ImgV3Refiner(switch_at=float(switch_at))
         return refiner
 
     def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool:
         """
-            is hit nsfw
+        is hit nsfw
         """
         if image.nsfw_detection_result is None:
             return False


@@ -8,23 +8,27 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
 class NovitaAIProvider(BuiltinToolProviderController):
     def _validate_credentials(self, credentials: dict[str, Any]) -> None:
         try:
-            result = NovitaAiTxt2ImgTool().fork_tool_runtime(
-                runtime={
-                    "credentials": credentials,
-                }
-            ).invoke(
-                user_id='',
-                tool_parameters={
-                    'model_name': 'cinenautXLATRUE_cinenautV10_392434.safetensors',
-                    'prompt': 'a futuristic city with flying cars',
-                    'negative_prompt': '',
-                    'width': 128,
-                    'height': 128,
-                    'image_num': 1,
-                    'guidance_scale': 7.5,
-                    'seed': -1,
-                    'steps': 1,
-                },
+            result = (
+                NovitaAiTxt2ImgTool()
+                .fork_tool_runtime(
+                    runtime={
+                        "credentials": credentials,
+                    }
+                )
+                .invoke(
+                    user_id="",
+                    tool_parameters={
+                        "model_name": "cinenautXLATRUE_cinenautV10_392434.safetensors",
+                        "prompt": "a futuristic city with flying cars",
+                        "negative_prompt": "",
+                        "width": 128,
+                        "height": 128,
+                        "image_num": 1,
+                        "guidance_scale": 7.5,
+                        "seed": -1,
+                        "steps": 1,
+                    },
+                )
             )
         except Exception as e:
             raise ToolProviderCredentialValidationError(str(e))


@@ -12,17 +12,18 @@ from core.tools.tool.builtin_tool import BuiltinTool
 
 
 class NovitaAiCreateTileTool(BuiltinTool):
-    def _invoke(self,
-                user_id: str,
-                tool_parameters: dict[str, Any],
-                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
-            invoke tools
+        invoke tools
         """
-        if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'):
+        if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"):
             raise ToolProviderCredentialValidationError("Novita AI API Key is required.")
 
-        api_key = self.runtime.credentials.get('api_key')
+        api_key = self.runtime.credentials.get("api_key")
 
         client = NovitaClient(api_key=api_key)
         param = self._process_parameters(tool_parameters)
@@ -30,21 +31,23 @@ class NovitaAiCreateTileTool(BuiltinTool):
 
         results = []
         results.append(
-            self.create_blob_message(blob=b64decode(client_result.image_file),
-                                     meta={'mime_type': f'image/{client_result.image_type}'},
-                                     save_as=self.VARIABLE_KEY.IMAGE.value)
+            self.create_blob_message(
+                blob=b64decode(client_result.image_file),
+                meta={"mime_type": f"image/{client_result.image_type}"},
+                save_as=self.VARIABLE_KEY.IMAGE.value,
+            )
         )
 
         return results
 
     def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
         """
-            process parameters
+        process parameters
         """
         res_parameters = deepcopy(parameters)
 
         # delete none and empty
-        keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == '']
+        keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ""]
         for k in keys_to_delete:
             del res_parameters[k]
 


@@ -12,127 +12,137 @@ from core.tools.tool.builtin_tool import BuiltinTool
 
 
 class NovitaAiModelQueryTool(BuiltinTool):
-    _model_query_endpoint = 'https://api.novita.ai/v3/model'
+    _model_query_endpoint = "https://api.novita.ai/v3/model"
 
-    def _invoke(self,
-                user_id: str,
-                tool_parameters: dict[str, Any],
-                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
-            invoke tools
+        invoke tools
         """
-        if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'):
+        if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"):
             raise ToolProviderCredentialValidationError("Novita AI API Key is required.")
 
-        api_key = self.runtime.credentials.get('api_key')
-        headers = {
-            'Content-Type': 'application/json',
-            'Authorization': "Bearer " + api_key
-        }
+        api_key = self.runtime.credentials.get("api_key")
+        headers = {"Content-Type": "application/json", "Authorization": "Bearer " + api_key}
 
         params = self._process_parameters(tool_parameters)
-        result_type = params.get('result_type')
-        del params['result_type']
+        result_type = params.get("result_type")
+        del params["result_type"]
 
         models_data = self._query_models(
             models_data=[],
             headers=headers,
             params=params,
-            recursive=False if result_type == 'first sd_name' or result_type == 'first name sd_name pair' else True
+            recursive=False if result_type == "first sd_name" or result_type == "first name sd_name pair" else True,
         )
 
-        result_str = ''
-        if result_type == 'first sd_name':
-            result_str = models_data[0]['sd_name_in_api'] if len(models_data) > 0 else ''
-        elif result_type == 'first name sd_name pair':
-            result_str = json.dumps({'name': models_data[0]['name'], 'sd_name': models_data[0]['sd_name_in_api']}) if len(models_data) > 0 else ''
-        elif result_type == 'sd_name array':
-            sd_name_array = [model['sd_name_in_api'] for model in models_data] if len(models_data) > 0 else []
+        result_str = ""
+        if result_type == "first sd_name":
+            result_str = models_data[0]["sd_name_in_api"] if len(models_data) > 0 else ""
+        elif result_type == "first name sd_name pair":
+            result_str = (
+                json.dumps({"name": models_data[0]["name"], "sd_name": models_data[0]["sd_name_in_api"]})
+                if len(models_data) > 0
+                else ""
+            )
+        elif result_type == "sd_name array":
+            sd_name_array = [model["sd_name_in_api"] for model in models_data] if len(models_data) > 0 else []
             result_str = json.dumps(sd_name_array)
-        elif result_type == 'name array':
-            name_array = [model['name'] for model in models_data] if len(models_data) > 0 else []
+        elif result_type == "name array":
+            name_array = [model["name"] for model in models_data] if len(models_data) > 0 else []
             result_str = json.dumps(name_array)
-        elif result_type == 'name sd_name pair array':
-            name_sd_name_pair_array = [{'name': model['name'], 'sd_name': model['sd_name_in_api']}
-                                       for model in models_data] if len(models_data) > 0 else []
+        elif result_type == "name sd_name pair array":
+            name_sd_name_pair_array = (
+                [{"name": model["name"], "sd_name": model["sd_name_in_api"]} for model in models_data]
+                if len(models_data) > 0
+                else []
+            )
             result_str = json.dumps(name_sd_name_pair_array)
-        elif result_type == 'whole info array':
+        elif result_type == "whole info array":
             result_str = json.dumps(models_data)
         else:
             raise NotImplementedError
 
         return self.create_text_message(result_str)
 
-    def _query_models(self, models_data: list, headers: dict[str, Any],
-                      params: dict[str, Any], pagination_cursor: str = '', recursive: bool = True) -> list:
+    def _query_models(
+        self,
+        models_data: list,
+        headers: dict[str, Any],
+        params: dict[str, Any],
+        pagination_cursor: str = "",
+        recursive: bool = True,
+    ) -> list:
         """
-            query models
+        query models
         """
         inside_params = deepcopy(params)
 
-        if pagination_cursor != '':
-            inside_params['pagination.cursor'] = pagination_cursor
+        if pagination_cursor != "":
+            inside_params["pagination.cursor"] = pagination_cursor
 
         response = ssrf_proxy.get(
-            url=str(URL(self._model_query_endpoint)),
-            headers=headers,
-            params=params,
-            timeout=(10, 60)
+            url=str(URL(self._model_query_endpoint)), headers=headers, params=params, timeout=(10, 60)
         )
         res_data = response.json()
-        models_data.extend(res_data['models'])
+        models_data.extend(res_data["models"])
 
-        res_data_len = len(res_data['models'])
-        if res_data_len == 0 or res_data_len < int(params['pagination.limit']) or recursive is False:
+        res_data_len = len(res_data["models"])
+        if res_data_len == 0 or res_data_len < int(params["pagination.limit"]) or recursive is False:
             # deduplicate
             df = DataFrame.from_dict(models_data)
-            df_unique = df.drop_duplicates(subset=['id'])
-            models_data = df_unique.to_dict('records')
+            df_unique = df.drop_duplicates(subset=["id"])
+            models_data = df_unique.to_dict("records")
             return models_data
 
         return self._query_models(
             models_data=models_data,
             headers=headers,
            params=inside_params,
-            pagination_cursor=res_data['pagination']['next_cursor']
+            pagination_cursor=res_data["pagination"]["next_cursor"],
         )
 
     def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
         """
-            process parameters
+        process parameters
         """
         process_parameters = deepcopy(parameters)
         res_parameters = {}
 
         # delete none or empty
-        keys_to_delete = [k for k, v in process_parameters.items() if v is None or v == '']
+        keys_to_delete = [k for k, v in process_parameters.items() if v is None or v == ""]
         for k in keys_to_delete:
             del process_parameters[k]
 
-        if 'query' in process_parameters and process_parameters.get('query') != 'unspecified':
-            res_parameters['filter.query'] = process_parameters['query']
+        if "query" in process_parameters and process_parameters.get("query") != "unspecified":
+            res_parameters["filter.query"] = process_parameters["query"]
 
-        if 'visibility' in process_parameters and process_parameters.get('visibility') != 'unspecified':
-            res_parameters['filter.visibility'] = process_parameters['visibility']
+        if "visibility" in process_parameters and process_parameters.get("visibility") != "unspecified":
+            res_parameters["filter.visibility"] = process_parameters["visibility"]
 
-        if 'source' in process_parameters and process_parameters.get('source') != 'unspecified':
-            res_parameters['filter.source'] = process_parameters['source']
+        if "source" in process_parameters and process_parameters.get("source") != "unspecified":
+            res_parameters["filter.source"] = process_parameters["source"]
 
-        if 'type' in process_parameters and process_parameters.get('type') != 'unspecified':
-            res_parameters['filter.types'] = process_parameters['type']
+        if "type" in process_parameters and process_parameters.get("type") != "unspecified":
+            res_parameters["filter.types"] = process_parameters["type"]
 
-        if 'is_sdxl' in process_parameters:
-            if process_parameters['is_sdxl'] == 'true':
-                res_parameters['filter.is_sdxl'] = True
-            elif process_parameters['is_sdxl'] == 'false':
-                res_parameters['filter.is_sdxl'] = False
+        if "is_sdxl" in process_parameters:
+            if process_parameters["is_sdxl"] == "true":
+                res_parameters["filter.is_sdxl"] = True
+            elif process_parameters["is_sdxl"] == "false":
+                res_parameters["filter.is_sdxl"] = False
 
-        res_parameters['result_type'] = process_parameters.get('result_type', 'first sd_name')
+        res_parameters["result_type"] = process_parameters.get("result_type", "first sd_name")
 
-        res_parameters['pagination.limit'] = 1 \
-            if res_parameters.get('result_type') == 'first sd_name' \
-            or res_parameters.get('result_type') == 'first name sd_name pair'\
+        res_parameters["pagination.limit"] = (
+            1
+            if res_parameters.get("result_type") == "first sd_name"
+            or res_parameters.get("result_type") == "first name sd_name pair"
             else 100
+        )
 
         return res_parameters


@@ -13,17 +13,18 @@ from core.tools.tool.builtin_tool import BuiltinTool
 
 
 class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase):
-    def _invoke(self,
-                user_id: str,
-                tool_parameters: dict[str, Any],
-                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
-            invoke tools
+        invoke tools
         """
-        if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'):
+        if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"):
             raise ToolProviderCredentialValidationError("Novita AI API Key is required.")
 
-        api_key = self.runtime.credentials.get('api_key')
+        api_key = self.runtime.credentials.get("api_key")
 
         client = NovitaClient(api_key=api_key)
         param = self._process_parameters(tool_parameters)
@@ -32,56 +33,58 @@ class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase):
         results = []
         for image_encoded, image in zip(client_result.images_encoded, client_result.images):
             if self._is_hit_nsfw_detection(image, 0.8):
-                results = self.create_text_message(text='NSFW detected!')
+                results = self.create_text_message(text="NSFW detected!")
                 break
 
             results.append(
-                self.create_blob_message(blob=b64decode(image_encoded),
-                                         meta={'mime_type': f'image/{image.image_type}'},
-                                         save_as=self.VARIABLE_KEY.IMAGE.value)
+                self.create_blob_message(
+                    blob=b64decode(image_encoded),
+                    meta={"mime_type": f"image/{image.image_type}"},
+                    save_as=self.VARIABLE_KEY.IMAGE.value,
+                )
             )
 
         return results
 
     def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
         """
-            process parameters
+        process parameters
         """
         res_parameters = deepcopy(parameters)
 
         # delete none and empty
-        keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == '']
+        keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ""]
         for k in keys_to_delete:
             del res_parameters[k]
 
-        if 'clip_skip' in res_parameters and res_parameters.get('clip_skip') == 0:
-            del res_parameters['clip_skip']
+        if "clip_skip" in res_parameters and res_parameters.get("clip_skip") == 0:
+            del res_parameters["clip_skip"]
 
-        if 'refiner_switch_at' in res_parameters and res_parameters.get('refiner_switch_at') == 0:
-            del res_parameters['refiner_switch_at']
+        if "refiner_switch_at" in res_parameters and res_parameters.get("refiner_switch_at") == 0:
+            del res_parameters["refiner_switch_at"]
 
-        if 'enabled_enterprise_plan' in res_parameters:
-            res_parameters['enterprise_plan'] = {'enabled': res_parameters['enabled_enterprise_plan']}
-            del res_parameters['enabled_enterprise_plan']
+        if "enabled_enterprise_plan" in res_parameters:
+            res_parameters["enterprise_plan"] = {"enabled": res_parameters["enabled_enterprise_plan"]}
+            del res_parameters["enabled_enterprise_plan"]
 
-        if 'nsfw_detection_level' in res_parameters:
-            res_parameters['nsfw_detection_level'] = int(res_parameters['nsfw_detection_level'])
+        if "nsfw_detection_level" in res_parameters:
+            res_parameters["nsfw_detection_level"] = int(res_parameters["nsfw_detection_level"])
 
         # process loras
-        if 'loras' in res_parameters:
-            res_parameters['loras'] = self._extract_loras(res_parameters.get('loras'))
+        if "loras" in res_parameters:
+            res_parameters["loras"] = self._extract_loras(res_parameters.get("loras"))
 
         # process embeddings
-        if 'embeddings' in res_parameters:
-            res_parameters['embeddings'] = self._extract_embeddings(res_parameters.get('embeddings'))
+        if "embeddings" in res_parameters:
+            res_parameters["embeddings"] = self._extract_embeddings(res_parameters.get("embeddings"))
 
         # process hires_fix
-        if 'hires_fix' in res_parameters:
-            res_parameters['hires_fix'] = self._extract_hires_fix(res_parameters.get('hires_fix'))
+        if "hires_fix" in res_parameters:
+            res_parameters["hires_fix"] = self._extract_hires_fix(res_parameters.get("hires_fix"))
 
         # process refiner
-        if 'refiner_switch_at' in res_parameters:
-            res_parameters['refiner'] = self._extract_refiner(res_parameters.get('refiner_switch_at'))
-            del res_parameters['refiner_switch_at']
+        if "refiner_switch_at" in res_parameters:
+            res_parameters["refiner"] = self._extract_refiner(res_parameters.get("refiner_switch_at"))
+            del res_parameters["refiner_switch_at"]
 
         return res_parameters