mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 11:26:52 +08:00
chore: refurish python code by applying Pylint linter rules (#8322)
This commit is contained in:
@@ -68,7 +68,7 @@ class AppToolProviderEntity(ToolProviderController):
|
||||
label = input_form[form_type]["label"]
|
||||
variable_name = input_form[form_type]["variable_name"]
|
||||
options = input_form[form_type].get("options", [])
|
||||
if form_type == "paragraph" or form_type == "text-input":
|
||||
if form_type in {"paragraph", "text-input"}:
|
||||
tool["parameters"].append(
|
||||
ToolParameter(
|
||||
name=variable_name,
|
||||
|
||||
@@ -168,7 +168,7 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
pass
|
||||
elif event == "close":
|
||||
break
|
||||
elif event == "error" or event == "filter":
|
||||
elif event in {"error", "filter"}:
|
||||
raise Exception(f"Failed to generate outline: {data}")
|
||||
|
||||
return outline
|
||||
@@ -213,7 +213,7 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
pass
|
||||
elif event == "close":
|
||||
break
|
||||
elif event == "error" or event == "filter":
|
||||
elif event in {"error", "filter"}:
|
||||
raise Exception(f"Failed to generate content: {data}")
|
||||
|
||||
return content
|
||||
|
||||
@@ -39,11 +39,11 @@ class DallE3Tool(BuiltinTool):
|
||||
n = tool_parameters.get("n", 1)
|
||||
# get quality
|
||||
quality = tool_parameters.get("quality", "standard")
|
||||
if quality not in ["standard", "hd"]:
|
||||
if quality not in {"standard", "hd"}:
|
||||
return self.create_text_message("Invalid quality")
|
||||
# get style
|
||||
style = tool_parameters.get("style", "vivid")
|
||||
if style not in ["natural", "vivid"]:
|
||||
if style not in {"natural", "vivid"}:
|
||||
return self.create_text_message("Invalid style")
|
||||
# set extra body
|
||||
seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))
|
||||
|
||||
@@ -14,7 +14,7 @@ class SimpleCode(BuiltinTool):
|
||||
language = tool_parameters.get("language", CodeLanguage.PYTHON3)
|
||||
code = tool_parameters.get("code", "")
|
||||
|
||||
if language not in [CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]:
|
||||
if language not in {CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT}:
|
||||
raise ValueError(f"Only python3 and javascript are supported, not {language}")
|
||||
|
||||
result = CodeExecutor.execute_code(language, "", code)
|
||||
|
||||
@@ -34,11 +34,11 @@ class CogView3Tool(BuiltinTool):
|
||||
n = tool_parameters.get("n", 1)
|
||||
# get quality
|
||||
quality = tool_parameters.get("quality", "standard")
|
||||
if quality not in ["standard", "hd"]:
|
||||
if quality not in {"standard", "hd"}:
|
||||
return self.create_text_message("Invalid quality")
|
||||
# get style
|
||||
style = tool_parameters.get("style", "vivid")
|
||||
if style not in ["natural", "vivid"]:
|
||||
if style not in {"natural", "vivid"}:
|
||||
return self.create_text_message("Invalid style")
|
||||
# set extra body
|
||||
seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))
|
||||
|
||||
@@ -49,11 +49,11 @@ class DallE3Tool(BuiltinTool):
|
||||
n = tool_parameters.get("n", 1)
|
||||
# get quality
|
||||
quality = tool_parameters.get("quality", "standard")
|
||||
if quality not in ["standard", "hd"]:
|
||||
if quality not in {"standard", "hd"}:
|
||||
return self.create_text_message("Invalid quality")
|
||||
# get style
|
||||
style = tool_parameters.get("style", "vivid")
|
||||
if style not in ["natural", "vivid"]:
|
||||
if style not in {"natural", "vivid"}:
|
||||
return self.create_text_message("Invalid style")
|
||||
|
||||
# call openapi dalle3
|
||||
|
||||
@@ -133,9 +133,9 @@ class GetWorksheetFieldsTool(BuiltinTool):
|
||||
|
||||
def _extract_options(self, control: dict) -> list:
|
||||
options = []
|
||||
if control["type"] in [9, 10, 11]:
|
||||
if control["type"] in {9, 10, 11}:
|
||||
options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])])
|
||||
elif control["type"] in [28, 36]:
|
||||
elif control["type"] in {28, 36}:
|
||||
itemnames = control["advancedSetting"].get("itemnames")
|
||||
if itemnames and itemnames.startswith("[{"):
|
||||
try:
|
||||
|
||||
@@ -183,11 +183,11 @@ class ListWorksheetRecordsTool(BuiltinTool):
|
||||
type_id = field.get("typeId")
|
||||
if type_id == 10:
|
||||
value = value if isinstance(value, str) else "、".join(value)
|
||||
elif type_id in [28, 36]:
|
||||
elif type_id in {28, 36}:
|
||||
value = field.get("options", {}).get(value, value)
|
||||
elif type_id in [26, 27, 48, 14]:
|
||||
elif type_id in {26, 27, 48, 14}:
|
||||
value = self.process_value(value)
|
||||
elif type_id in [35, 29]:
|
||||
elif type_id in {35, 29}:
|
||||
value = self.parse_cascade_or_associated(field, value)
|
||||
elif type_id == 40:
|
||||
value = self.parse_location(value)
|
||||
|
||||
@@ -35,7 +35,7 @@ class NovitaAiModelQueryTool(BuiltinTool):
|
||||
models_data=[],
|
||||
headers=headers,
|
||||
params=params,
|
||||
recursive=not (result_type == "first sd_name" or result_type == "first name sd_name pair"),
|
||||
recursive=result_type not in {"first sd_name", "first name sd_name pair"},
|
||||
)
|
||||
|
||||
result_str = ""
|
||||
|
||||
@@ -38,7 +38,7 @@ class SearchAPI:
|
||||
return {
|
||||
"engine": "google",
|
||||
"q": query,
|
||||
**{key: value for key, value in kwargs.items() if value not in [None, ""]},
|
||||
**{key: value for key, value in kwargs.items() if value not in {None, ""}},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -38,7 +38,7 @@ class SearchAPI:
|
||||
return {
|
||||
"engine": "google_jobs",
|
||||
"q": query,
|
||||
**{key: value for key, value in kwargs.items() if value not in [None, ""]},
|
||||
**{key: value for key, value in kwargs.items() if value not in {None, ""}},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -38,7 +38,7 @@ class SearchAPI:
|
||||
return {
|
||||
"engine": "google_news",
|
||||
"q": query,
|
||||
**{key: value for key, value in kwargs.items() if value not in [None, ""]},
|
||||
**{key: value for key, value in kwargs.items() if value not in {None, ""}},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -38,7 +38,7 @@ class SearchAPI:
|
||||
"engine": "youtube_transcripts",
|
||||
"video_id": video_id,
|
||||
"lang": language or "en",
|
||||
**{key: value for key, value in kwargs.items() if value not in [None, ""]},
|
||||
**{key: value for key, value in kwargs.items() if value not in {None, ""}},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -214,7 +214,7 @@ class Spider:
|
||||
return requests.delete(url, headers=headers, stream=stream)
|
||||
|
||||
def _handle_error(self, response, action):
|
||||
if response.status_code in [402, 409, 500]:
|
||||
if response.status_code in {402, 409, 500}:
|
||||
error_message = response.json().get("error", "Unknown error occurred")
|
||||
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}")
|
||||
else:
|
||||
|
||||
@@ -32,7 +32,7 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
|
||||
|
||||
model = tool_parameters.get("model", "core")
|
||||
|
||||
if model in ["sd3", "sd3-turbo"]:
|
||||
if model in {"sd3", "sd3-turbo"}:
|
||||
payload["model"] = tool_parameters.get("model")
|
||||
|
||||
if model != "sd3-turbo":
|
||||
|
||||
@@ -38,7 +38,7 @@ class VannaTool(BuiltinTool):
|
||||
vn = VannaDefault(model=model, api_key=api_key)
|
||||
|
||||
db_type = tool_parameters.get("db_type", "")
|
||||
if db_type in ["Postgres", "MySQL", "Hive", "ClickHouse"]:
|
||||
if db_type in {"Postgres", "MySQL", "Hive", "ClickHouse"}:
|
||||
if not db_name:
|
||||
return self.create_text_message("Please input database name")
|
||||
if not username:
|
||||
|
||||
@@ -19,7 +19,7 @@ from core.tools.utils.yaml_utils import load_yaml_file
|
||||
|
||||
class BuiltinToolProviderController(ToolProviderController):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP:
|
||||
if self.provider_type in {ToolProviderType.API, ToolProviderType.APP}:
|
||||
super().__init__(**data)
|
||||
return
|
||||
|
||||
|
||||
@@ -153,10 +153,10 @@ class ToolProviderController(BaseModel, ABC):
|
||||
|
||||
# check type
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if (
|
||||
credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT
|
||||
or credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT
|
||||
):
|
||||
if credential_schema in {
|
||||
ToolProviderCredentials.CredentialsType.SECRET_INPUT,
|
||||
ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
||||
}:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||
|
||||
@@ -184,11 +184,11 @@ class ToolProviderController(BaseModel, ABC):
|
||||
if credential_schema.default is not None:
|
||||
default_value = credential_schema.default
|
||||
# parse default value into the correct type
|
||||
if (
|
||||
credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT
|
||||
or credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT
|
||||
or credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT
|
||||
):
|
||||
if credential_schema.type in {
|
||||
ToolProviderCredentials.CredentialsType.SECRET_INPUT,
|
||||
ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
||||
ToolProviderCredentials.CredentialsType.SELECT,
|
||||
}:
|
||||
default_value = str(default_value)
|
||||
|
||||
credentials[credential_name] = default_value
|
||||
|
||||
@@ -5,7 +5,7 @@ from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
|
||||
import core.helper.ssrf_proxy as ssrf_proxy
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||
@@ -191,7 +191,7 @@ class ApiTool(Tool):
|
||||
else:
|
||||
body = body
|
||||
|
||||
if method in ("get", "head", "post", "put", "delete", "patch"):
|
||||
if method in {"get", "head", "post", "put", "delete", "patch"}:
|
||||
response = getattr(ssrf_proxy, method)(
|
||||
url,
|
||||
params=params,
|
||||
@@ -224,9 +224,9 @@ class ApiTool(Tool):
|
||||
elif option["type"] == "string":
|
||||
return str(value)
|
||||
elif option["type"] == "boolean":
|
||||
if str(value).lower() in ["true", "1"]:
|
||||
if str(value).lower() in {"true", "1"}:
|
||||
return True
|
||||
elif str(value).lower() in ["false", "0"]:
|
||||
elif str(value).lower() in {"false", "0"}:
|
||||
return False
|
||||
else:
|
||||
continue # Not a boolean, try next option
|
||||
|
||||
@@ -189,10 +189,7 @@ class ToolEngine:
|
||||
result += response.message
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
result += f"result link: {response.message}. please tell user to check it."
|
||||
elif (
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE_LINK
|
||||
or response.type == ToolInvokeMessage.MessageType.IMAGE
|
||||
):
|
||||
elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
||||
result += (
|
||||
"image has been created and sent to user already, you do not need to create it,"
|
||||
" just tell the user to check it now."
|
||||
@@ -212,10 +209,7 @@ class ToolEngine:
|
||||
result = []
|
||||
|
||||
for response in tool_response:
|
||||
if (
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE_LINK
|
||||
or response.type == ToolInvokeMessage.MessageType.IMAGE
|
||||
):
|
||||
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
||||
mimetype = None
|
||||
if response.meta.get("mime_type"):
|
||||
mimetype = response.meta.get("mime_type")
|
||||
@@ -297,7 +291,7 @@ class ToolEngine:
|
||||
belongs_to="assistant",
|
||||
url=message.url,
|
||||
upload_file_id=None,
|
||||
created_by_role=("account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"),
|
||||
created_by_role=("account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"),
|
||||
created_by=user_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ class ToolFileMessageTransformer:
|
||||
result = []
|
||||
|
||||
for message in messages:
|
||||
if message.type == ToolInvokeMessage.MessageType.TEXT or message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
|
||||
result.append(message)
|
||||
elif message.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
# try to download image
|
||||
|
||||
@@ -165,7 +165,7 @@ class ApiBasedToolSchemaParser:
|
||||
elif "schema" in parameter and "type" in parameter["schema"]:
|
||||
typ = parameter["schema"]["type"]
|
||||
|
||||
if typ == "integer" or typ == "number":
|
||||
if typ in {"integer", "number"}:
|
||||
return ToolParameter.ToolParameterType.NUMBER
|
||||
elif typ == "boolean":
|
||||
return ToolParameter.ToolParameterType.BOOLEAN
|
||||
|
||||
@@ -313,7 +313,7 @@ def normalize_whitespace(text):
|
||||
|
||||
|
||||
def is_leaf(element):
|
||||
return element.name in ["p", "li"]
|
||||
return element.name in {"p", "li"}
|
||||
|
||||
|
||||
def is_text(element):
|
||||
|
||||
Reference in New Issue
Block a user