mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-25 02:33:00 +08:00
chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -12,10 +12,10 @@ class NovitaAiToolBase:
|
||||
if not loras_str:
|
||||
return []
|
||||
|
||||
loras_ori_list = lora_str.strip().split(';')
|
||||
loras_ori_list = lora_str.strip().split(";")
|
||||
result_list = []
|
||||
for lora_str in loras_ori_list:
|
||||
lora_info = lora_str.strip().split(',')
|
||||
lora_info = lora_str.strip().split(",")
|
||||
lora = Txt2ImgV3LoRA(
|
||||
model_name=lora_info[0].strip(),
|
||||
strength=float(lora_info[1]),
|
||||
@@ -28,43 +28,39 @@ class NovitaAiToolBase:
|
||||
if not embeddings_str:
|
||||
return []
|
||||
|
||||
embeddings_ori_list = embeddings_str.strip().split(';')
|
||||
embeddings_ori_list = embeddings_str.strip().split(";")
|
||||
result_list = []
|
||||
for embedding_str in embeddings_ori_list:
|
||||
embedding = Txt2ImgV3Embedding(
|
||||
model_name=embedding_str.strip()
|
||||
)
|
||||
embedding = Txt2ImgV3Embedding(model_name=embedding_str.strip())
|
||||
result_list.append(embedding)
|
||||
|
||||
return result_list
|
||||
|
||||
def _extract_hires_fix(self, hires_fix_str: str):
|
||||
hires_fix_info = hires_fix_str.strip().split(',')
|
||||
if 'upscaler' in hires_fix_info:
|
||||
hires_fix_info = hires_fix_str.strip().split(",")
|
||||
if "upscaler" in hires_fix_info:
|
||||
hires_fix = Txt2ImgV3HiresFix(
|
||||
target_width=int(hires_fix_info[0]),
|
||||
target_height=int(hires_fix_info[1]),
|
||||
strength=float(hires_fix_info[2]),
|
||||
upscaler=hires_fix_info[3].strip()
|
||||
upscaler=hires_fix_info[3].strip(),
|
||||
)
|
||||
else:
|
||||
hires_fix = Txt2ImgV3HiresFix(
|
||||
target_width=int(hires_fix_info[0]),
|
||||
target_height=int(hires_fix_info[1]),
|
||||
strength=float(hires_fix_info[2])
|
||||
strength=float(hires_fix_info[2]),
|
||||
)
|
||||
|
||||
return hires_fix
|
||||
|
||||
def _extract_refiner(self, switch_at: str):
|
||||
refiner = Txt2ImgV3Refiner(
|
||||
switch_at=float(switch_at)
|
||||
)
|
||||
refiner = Txt2ImgV3Refiner(switch_at=float(switch_at))
|
||||
return refiner
|
||||
|
||||
def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool:
|
||||
"""
|
||||
is hit nsfw
|
||||
is hit nsfw
|
||||
"""
|
||||
if image.nsfw_detection_result is None:
|
||||
return False
|
||||
|
||||
@@ -8,23 +8,27 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
|
||||
class NovitaAIProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
result = NovitaAiTxt2ImgTool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_parameters={
|
||||
'model_name': 'cinenautXLATRUE_cinenautV10_392434.safetensors',
|
||||
'prompt': 'a futuristic city with flying cars',
|
||||
'negative_prompt': '',
|
||||
'width': 128,
|
||||
'height': 128,
|
||||
'image_num': 1,
|
||||
'guidance_scale': 7.5,
|
||||
'seed': -1,
|
||||
'steps': 1,
|
||||
},
|
||||
result = (
|
||||
NovitaAiTxt2ImgTool()
|
||||
.fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
)
|
||||
.invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"model_name": "cinenautXLATRUE_cinenautV10_392434.safetensors",
|
||||
"prompt": "a futuristic city with flying cars",
|
||||
"negative_prompt": "",
|
||||
"width": 128,
|
||||
"height": 128,
|
||||
"image_num": 1,
|
||||
"guidance_scale": 7.5,
|
||||
"seed": -1,
|
||||
"steps": 1,
|
||||
},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
@@ -12,17 +12,18 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class NovitaAiCreateTileTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
invoke tools
|
||||
"""
|
||||
if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'):
|
||||
if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"):
|
||||
raise ToolProviderCredentialValidationError("Novita AI API Key is required.")
|
||||
|
||||
api_key = self.runtime.credentials.get('api_key')
|
||||
api_key = self.runtime.credentials.get("api_key")
|
||||
|
||||
client = NovitaClient(api_key=api_key)
|
||||
param = self._process_parameters(tool_parameters)
|
||||
@@ -30,21 +31,23 @@ class NovitaAiCreateTileTool(BuiltinTool):
|
||||
|
||||
results = []
|
||||
results.append(
|
||||
self.create_blob_message(blob=b64decode(client_result.image_file),
|
||||
meta={'mime_type': f'image/{client_result.image_type}'},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value)
|
||||
self.create_blob_message(
|
||||
blob=b64decode(client_result.image_file),
|
||||
meta={"mime_type": f"image/{client_result.image_type}"},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value,
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
process parameters
|
||||
process parameters
|
||||
"""
|
||||
res_parameters = deepcopy(parameters)
|
||||
|
||||
# delete none and empty
|
||||
keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == '']
|
||||
keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ""]
|
||||
for k in keys_to_delete:
|
||||
del res_parameters[k]
|
||||
|
||||
|
||||
@@ -12,127 +12,137 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class NovitaAiModelQueryTool(BuiltinTool):
|
||||
_model_query_endpoint = 'https://api.novita.ai/v3/model'
|
||||
_model_query_endpoint = "https://api.novita.ai/v3/model"
|
||||
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
invoke tools
|
||||
"""
|
||||
if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'):
|
||||
if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"):
|
||||
raise ToolProviderCredentialValidationError("Novita AI API Key is required.")
|
||||
|
||||
api_key = self.runtime.credentials.get('api_key')
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': "Bearer " + api_key
|
||||
}
|
||||
api_key = self.runtime.credentials.get("api_key")
|
||||
headers = {"Content-Type": "application/json", "Authorization": "Bearer " + api_key}
|
||||
params = self._process_parameters(tool_parameters)
|
||||
result_type = params.get('result_type')
|
||||
del params['result_type']
|
||||
result_type = params.get("result_type")
|
||||
del params["result_type"]
|
||||
|
||||
models_data = self._query_models(
|
||||
models_data=[],
|
||||
headers=headers,
|
||||
params=params,
|
||||
recursive=False if result_type == 'first sd_name' or result_type == 'first name sd_name pair' else True
|
||||
recursive=False if result_type == "first sd_name" or result_type == "first name sd_name pair" else True,
|
||||
)
|
||||
|
||||
result_str = ''
|
||||
if result_type == 'first sd_name':
|
||||
result_str = models_data[0]['sd_name_in_api'] if len(models_data) > 0 else ''
|
||||
elif result_type == 'first name sd_name pair':
|
||||
result_str = json.dumps({'name': models_data[0]['name'], 'sd_name': models_data[0]['sd_name_in_api']}) if len(models_data) > 0 else ''
|
||||
elif result_type == 'sd_name array':
|
||||
sd_name_array = [model['sd_name_in_api'] for model in models_data] if len(models_data) > 0 else []
|
||||
result_str = ""
|
||||
if result_type == "first sd_name":
|
||||
result_str = models_data[0]["sd_name_in_api"] if len(models_data) > 0 else ""
|
||||
elif result_type == "first name sd_name pair":
|
||||
result_str = (
|
||||
json.dumps({"name": models_data[0]["name"], "sd_name": models_data[0]["sd_name_in_api"]})
|
||||
if len(models_data) > 0
|
||||
else ""
|
||||
)
|
||||
elif result_type == "sd_name array":
|
||||
sd_name_array = [model["sd_name_in_api"] for model in models_data] if len(models_data) > 0 else []
|
||||
result_str = json.dumps(sd_name_array)
|
||||
elif result_type == 'name array':
|
||||
name_array = [model['name'] for model in models_data] if len(models_data) > 0 else []
|
||||
elif result_type == "name array":
|
||||
name_array = [model["name"] for model in models_data] if len(models_data) > 0 else []
|
||||
result_str = json.dumps(name_array)
|
||||
elif result_type == 'name sd_name pair array':
|
||||
name_sd_name_pair_array = [{'name': model['name'], 'sd_name': model['sd_name_in_api']}
|
||||
for model in models_data] if len(models_data) > 0 else []
|
||||
elif result_type == "name sd_name pair array":
|
||||
name_sd_name_pair_array = (
|
||||
[{"name": model["name"], "sd_name": model["sd_name_in_api"]} for model in models_data]
|
||||
if len(models_data) > 0
|
||||
else []
|
||||
)
|
||||
result_str = json.dumps(name_sd_name_pair_array)
|
||||
elif result_type == 'whole info array':
|
||||
elif result_type == "whole info array":
|
||||
result_str = json.dumps(models_data)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return self.create_text_message(result_str)
|
||||
|
||||
def _query_models(self, models_data: list, headers: dict[str, Any],
|
||||
params: dict[str, Any], pagination_cursor: str = '', recursive: bool = True) -> list:
|
||||
def _query_models(
|
||||
self,
|
||||
models_data: list,
|
||||
headers: dict[str, Any],
|
||||
params: dict[str, Any],
|
||||
pagination_cursor: str = "",
|
||||
recursive: bool = True,
|
||||
) -> list:
|
||||
"""
|
||||
query models
|
||||
query models
|
||||
"""
|
||||
inside_params = deepcopy(params)
|
||||
|
||||
if pagination_cursor != '':
|
||||
inside_params['pagination.cursor'] = pagination_cursor
|
||||
if pagination_cursor != "":
|
||||
inside_params["pagination.cursor"] = pagination_cursor
|
||||
|
||||
response = ssrf_proxy.get(
|
||||
url=str(URL(self._model_query_endpoint)),
|
||||
headers=headers,
|
||||
params=params,
|
||||
timeout=(10, 60)
|
||||
url=str(URL(self._model_query_endpoint)), headers=headers, params=params, timeout=(10, 60)
|
||||
)
|
||||
|
||||
res_data = response.json()
|
||||
|
||||
models_data.extend(res_data['models'])
|
||||
models_data.extend(res_data["models"])
|
||||
|
||||
res_data_len = len(res_data['models'])
|
||||
if res_data_len == 0 or res_data_len < int(params['pagination.limit']) or recursive is False:
|
||||
res_data_len = len(res_data["models"])
|
||||
if res_data_len == 0 or res_data_len < int(params["pagination.limit"]) or recursive is False:
|
||||
# deduplicate
|
||||
df = DataFrame.from_dict(models_data)
|
||||
df_unique = df.drop_duplicates(subset=['id'])
|
||||
models_data = df_unique.to_dict('records')
|
||||
df_unique = df.drop_duplicates(subset=["id"])
|
||||
models_data = df_unique.to_dict("records")
|
||||
return models_data
|
||||
|
||||
return self._query_models(
|
||||
models_data=models_data,
|
||||
headers=headers,
|
||||
params=inside_params,
|
||||
pagination_cursor=res_data['pagination']['next_cursor']
|
||||
pagination_cursor=res_data["pagination"]["next_cursor"],
|
||||
)
|
||||
|
||||
def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
process parameters
|
||||
process parameters
|
||||
"""
|
||||
process_parameters = deepcopy(parameters)
|
||||
res_parameters = {}
|
||||
|
||||
# delete none or empty
|
||||
keys_to_delete = [k for k, v in process_parameters.items() if v is None or v == '']
|
||||
keys_to_delete = [k for k, v in process_parameters.items() if v is None or v == ""]
|
||||
for k in keys_to_delete:
|
||||
del process_parameters[k]
|
||||
|
||||
if 'query' in process_parameters and process_parameters.get('query') != 'unspecified':
|
||||
res_parameters['filter.query'] = process_parameters['query']
|
||||
if "query" in process_parameters and process_parameters.get("query") != "unspecified":
|
||||
res_parameters["filter.query"] = process_parameters["query"]
|
||||
|
||||
if 'visibility' in process_parameters and process_parameters.get('visibility') != 'unspecified':
|
||||
res_parameters['filter.visibility'] = process_parameters['visibility']
|
||||
if "visibility" in process_parameters and process_parameters.get("visibility") != "unspecified":
|
||||
res_parameters["filter.visibility"] = process_parameters["visibility"]
|
||||
|
||||
if 'source' in process_parameters and process_parameters.get('source') != 'unspecified':
|
||||
res_parameters['filter.source'] = process_parameters['source']
|
||||
if "source" in process_parameters and process_parameters.get("source") != "unspecified":
|
||||
res_parameters["filter.source"] = process_parameters["source"]
|
||||
|
||||
if 'type' in process_parameters and process_parameters.get('type') != 'unspecified':
|
||||
res_parameters['filter.types'] = process_parameters['type']
|
||||
if "type" in process_parameters and process_parameters.get("type") != "unspecified":
|
||||
res_parameters["filter.types"] = process_parameters["type"]
|
||||
|
||||
if 'is_sdxl' in process_parameters:
|
||||
if process_parameters['is_sdxl'] == 'true':
|
||||
res_parameters['filter.is_sdxl'] = True
|
||||
elif process_parameters['is_sdxl'] == 'false':
|
||||
res_parameters['filter.is_sdxl'] = False
|
||||
if "is_sdxl" in process_parameters:
|
||||
if process_parameters["is_sdxl"] == "true":
|
||||
res_parameters["filter.is_sdxl"] = True
|
||||
elif process_parameters["is_sdxl"] == "false":
|
||||
res_parameters["filter.is_sdxl"] = False
|
||||
|
||||
res_parameters['result_type'] = process_parameters.get('result_type', 'first sd_name')
|
||||
res_parameters["result_type"] = process_parameters.get("result_type", "first sd_name")
|
||||
|
||||
res_parameters['pagination.limit'] = 1 \
|
||||
if res_parameters.get('result_type') == 'first sd_name' \
|
||||
or res_parameters.get('result_type') == 'first name sd_name pair'\
|
||||
res_parameters["pagination.limit"] = (
|
||||
1
|
||||
if res_parameters.get("result_type") == "first sd_name"
|
||||
or res_parameters.get("result_type") == "first name sd_name pair"
|
||||
else 100
|
||||
)
|
||||
|
||||
return res_parameters
|
||||
|
||||
@@ -13,17 +13,18 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
invoke tools
|
||||
"""
|
||||
if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'):
|
||||
if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"):
|
||||
raise ToolProviderCredentialValidationError("Novita AI API Key is required.")
|
||||
|
||||
api_key = self.runtime.credentials.get('api_key')
|
||||
api_key = self.runtime.credentials.get("api_key")
|
||||
|
||||
client = NovitaClient(api_key=api_key)
|
||||
param = self._process_parameters(tool_parameters)
|
||||
@@ -32,56 +33,58 @@ class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase):
|
||||
results = []
|
||||
for image_encoded, image in zip(client_result.images_encoded, client_result.images):
|
||||
if self._is_hit_nsfw_detection(image, 0.8):
|
||||
results = self.create_text_message(text='NSFW detected!')
|
||||
results = self.create_text_message(text="NSFW detected!")
|
||||
break
|
||||
|
||||
results.append(
|
||||
self.create_blob_message(blob=b64decode(image_encoded),
|
||||
meta={'mime_type': f'image/{image.image_type}'},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value)
|
||||
self.create_blob_message(
|
||||
blob=b64decode(image_encoded),
|
||||
meta={"mime_type": f"image/{image.image_type}"},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value,
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
process parameters
|
||||
process parameters
|
||||
"""
|
||||
res_parameters = deepcopy(parameters)
|
||||
|
||||
# delete none and empty
|
||||
keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == '']
|
||||
keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ""]
|
||||
for k in keys_to_delete:
|
||||
del res_parameters[k]
|
||||
|
||||
if 'clip_skip' in res_parameters and res_parameters.get('clip_skip') == 0:
|
||||
del res_parameters['clip_skip']
|
||||
if "clip_skip" in res_parameters and res_parameters.get("clip_skip") == 0:
|
||||
del res_parameters["clip_skip"]
|
||||
|
||||
if 'refiner_switch_at' in res_parameters and res_parameters.get('refiner_switch_at') == 0:
|
||||
del res_parameters['refiner_switch_at']
|
||||
if "refiner_switch_at" in res_parameters and res_parameters.get("refiner_switch_at") == 0:
|
||||
del res_parameters["refiner_switch_at"]
|
||||
|
||||
if 'enabled_enterprise_plan' in res_parameters:
|
||||
res_parameters['enterprise_plan'] = {'enabled': res_parameters['enabled_enterprise_plan']}
|
||||
del res_parameters['enabled_enterprise_plan']
|
||||
if "enabled_enterprise_plan" in res_parameters:
|
||||
res_parameters["enterprise_plan"] = {"enabled": res_parameters["enabled_enterprise_plan"]}
|
||||
del res_parameters["enabled_enterprise_plan"]
|
||||
|
||||
if 'nsfw_detection_level' in res_parameters:
|
||||
res_parameters['nsfw_detection_level'] = int(res_parameters['nsfw_detection_level'])
|
||||
if "nsfw_detection_level" in res_parameters:
|
||||
res_parameters["nsfw_detection_level"] = int(res_parameters["nsfw_detection_level"])
|
||||
|
||||
# process loras
|
||||
if 'loras' in res_parameters:
|
||||
res_parameters['loras'] = self._extract_loras(res_parameters.get('loras'))
|
||||
if "loras" in res_parameters:
|
||||
res_parameters["loras"] = self._extract_loras(res_parameters.get("loras"))
|
||||
|
||||
# process embeddings
|
||||
if 'embeddings' in res_parameters:
|
||||
res_parameters['embeddings'] = self._extract_embeddings(res_parameters.get('embeddings'))
|
||||
if "embeddings" in res_parameters:
|
||||
res_parameters["embeddings"] = self._extract_embeddings(res_parameters.get("embeddings"))
|
||||
|
||||
# process hires_fix
|
||||
if 'hires_fix' in res_parameters:
|
||||
res_parameters['hires_fix'] = self._extract_hires_fix(res_parameters.get('hires_fix'))
|
||||
if "hires_fix" in res_parameters:
|
||||
res_parameters["hires_fix"] = self._extract_hires_fix(res_parameters.get("hires_fix"))
|
||||
|
||||
# process refiner
|
||||
if 'refiner_switch_at' in res_parameters:
|
||||
res_parameters['refiner'] = self._extract_refiner(res_parameters.get('refiner_switch_at'))
|
||||
del res_parameters['refiner_switch_at']
|
||||
if "refiner_switch_at" in res_parameters:
|
||||
res_parameters["refiner"] = self._extract_refiner(res_parameters.get("refiner_switch_at"))
|
||||
del res_parameters["refiner_switch_at"]
|
||||
|
||||
return res_parameters
|
||||
|
||||
Reference in New Issue
Block a user