chore: refurish python code by applying Pylint linter rules (#8322)

This commit is contained in:
Bowen Liang
2024-09-13 22:42:08 +08:00
committed by GitHub
parent 1ab81b4972
commit a1104ab97e
126 changed files with 253 additions and 272 deletions

View File

@@ -68,7 +68,7 @@ class AppToolProviderEntity(ToolProviderController):
label = input_form[form_type]["label"]
variable_name = input_form[form_type]["variable_name"]
options = input_form[form_type].get("options", [])
if form_type == "paragraph" or form_type == "text-input":
if form_type in {"paragraph", "text-input"}:
tool["parameters"].append(
ToolParameter(
name=variable_name,

View File

@@ -168,7 +168,7 @@ class AIPPTGenerateTool(BuiltinTool):
pass
elif event == "close":
break
elif event == "error" or event == "filter":
elif event in {"error", "filter"}:
raise Exception(f"Failed to generate outline: {data}")
return outline
@@ -213,7 +213,7 @@ class AIPPTGenerateTool(BuiltinTool):
pass
elif event == "close":
break
elif event == "error" or event == "filter":
elif event in {"error", "filter"}:
raise Exception(f"Failed to generate content: {data}")
return content

View File

@@ -39,11 +39,11 @@ class DallE3Tool(BuiltinTool):
n = tool_parameters.get("n", 1)
# get quality
quality = tool_parameters.get("quality", "standard")
if quality not in ["standard", "hd"]:
if quality not in {"standard", "hd"}:
return self.create_text_message("Invalid quality")
# get style
style = tool_parameters.get("style", "vivid")
if style not in ["natural", "vivid"]:
if style not in {"natural", "vivid"}:
return self.create_text_message("Invalid style")
# set extra body
seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))

View File

@@ -14,7 +14,7 @@ class SimpleCode(BuiltinTool):
language = tool_parameters.get("language", CodeLanguage.PYTHON3)
code = tool_parameters.get("code", "")
if language not in [CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]:
if language not in {CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT}:
raise ValueError(f"Only python3 and javascript are supported, not {language}")
result = CodeExecutor.execute_code(language, "", code)

View File

@@ -34,11 +34,11 @@ class CogView3Tool(BuiltinTool):
n = tool_parameters.get("n", 1)
# get quality
quality = tool_parameters.get("quality", "standard")
if quality not in ["standard", "hd"]:
if quality not in {"standard", "hd"}:
return self.create_text_message("Invalid quality")
# get style
style = tool_parameters.get("style", "vivid")
if style not in ["natural", "vivid"]:
if style not in {"natural", "vivid"}:
return self.create_text_message("Invalid style")
# set extra body
seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))

View File

@@ -49,11 +49,11 @@ class DallE3Tool(BuiltinTool):
n = tool_parameters.get("n", 1)
# get quality
quality = tool_parameters.get("quality", "standard")
if quality not in ["standard", "hd"]:
if quality not in {"standard", "hd"}:
return self.create_text_message("Invalid quality")
# get style
style = tool_parameters.get("style", "vivid")
if style not in ["natural", "vivid"]:
if style not in {"natural", "vivid"}:
return self.create_text_message("Invalid style")
# call openapi dalle3

View File

@@ -133,9 +133,9 @@ class GetWorksheetFieldsTool(BuiltinTool):
def _extract_options(self, control: dict) -> list:
options = []
if control["type"] in [9, 10, 11]:
if control["type"] in {9, 10, 11}:
options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])])
elif control["type"] in [28, 36]:
elif control["type"] in {28, 36}:
itemnames = control["advancedSetting"].get("itemnames")
if itemnames and itemnames.startswith("[{"):
try:

View File

@@ -183,11 +183,11 @@ class ListWorksheetRecordsTool(BuiltinTool):
type_id = field.get("typeId")
if type_id == 10:
value = value if isinstance(value, str) else "".join(value)
elif type_id in [28, 36]:
elif type_id in {28, 36}:
value = field.get("options", {}).get(value, value)
elif type_id in [26, 27, 48, 14]:
elif type_id in {26, 27, 48, 14}:
value = self.process_value(value)
elif type_id in [35, 29]:
elif type_id in {35, 29}:
value = self.parse_cascade_or_associated(field, value)
elif type_id == 40:
value = self.parse_location(value)

View File

@@ -35,7 +35,7 @@ class NovitaAiModelQueryTool(BuiltinTool):
models_data=[],
headers=headers,
params=params,
recursive=not (result_type == "first sd_name" or result_type == "first name sd_name pair"),
recursive=result_type not in {"first sd_name", "first name sd_name pair"},
)
result_str = ""

View File

@@ -38,7 +38,7 @@ class SearchAPI:
return {
"engine": "google",
"q": query,
**{key: value for key, value in kwargs.items() if value not in [None, ""]},
**{key: value for key, value in kwargs.items() if value not in {None, ""}},
}
@staticmethod

View File

@@ -38,7 +38,7 @@ class SearchAPI:
return {
"engine": "google_jobs",
"q": query,
**{key: value for key, value in kwargs.items() if value not in [None, ""]},
**{key: value for key, value in kwargs.items() if value not in {None, ""}},
}
@staticmethod

View File

@@ -38,7 +38,7 @@ class SearchAPI:
return {
"engine": "google_news",
"q": query,
**{key: value for key, value in kwargs.items() if value not in [None, ""]},
**{key: value for key, value in kwargs.items() if value not in {None, ""}},
}
@staticmethod

View File

@@ -38,7 +38,7 @@ class SearchAPI:
"engine": "youtube_transcripts",
"video_id": video_id,
"lang": language or "en",
**{key: value for key, value in kwargs.items() if value not in [None, ""]},
**{key: value for key, value in kwargs.items() if value not in {None, ""}},
}
@staticmethod

View File

@@ -214,7 +214,7 @@ class Spider:
return requests.delete(url, headers=headers, stream=stream)
def _handle_error(self, response, action):
if response.status_code in [402, 409, 500]:
if response.status_code in {402, 409, 500}:
error_message = response.json().get("error", "Unknown error occurred")
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}")
else:

View File

@@ -32,7 +32,7 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
model = tool_parameters.get("model", "core")
if model in ["sd3", "sd3-turbo"]:
if model in {"sd3", "sd3-turbo"}:
payload["model"] = tool_parameters.get("model")
if model != "sd3-turbo":

View File

@@ -38,7 +38,7 @@ class VannaTool(BuiltinTool):
vn = VannaDefault(model=model, api_key=api_key)
db_type = tool_parameters.get("db_type", "")
if db_type in ["Postgres", "MySQL", "Hive", "ClickHouse"]:
if db_type in {"Postgres", "MySQL", "Hive", "ClickHouse"}:
if not db_name:
return self.create_text_message("Please input database name")
if not username:

View File

@@ -19,7 +19,7 @@ from core.tools.utils.yaml_utils import load_yaml_file
class BuiltinToolProviderController(ToolProviderController):
def __init__(self, **data: Any) -> None:
if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP:
if self.provider_type in {ToolProviderType.API, ToolProviderType.APP}:
super().__init__(**data)
return

View File

@@ -153,10 +153,10 @@ class ToolProviderController(BaseModel, ABC):
# check type
credential_schema = credentials_need_to_validate[credential_name]
if (
credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT
or credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT
):
if credential_schema in {
ToolProviderCredentials.CredentialsType.SECRET_INPUT,
ToolProviderCredentials.CredentialsType.TEXT_INPUT,
}:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
@@ -184,11 +184,11 @@ class ToolProviderController(BaseModel, ABC):
if credential_schema.default is not None:
default_value = credential_schema.default
# parse default value into the correct type
if (
credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT
or credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT
or credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT
):
if credential_schema.type in {
ToolProviderCredentials.CredentialsType.SECRET_INPUT,
ToolProviderCredentials.CredentialsType.TEXT_INPUT,
ToolProviderCredentials.CredentialsType.SELECT,
}:
default_value = str(default_value)
credentials[credential_name] = default_value

View File

@@ -5,7 +5,7 @@ from urllib.parse import urlencode
import httpx
import core.helper.ssrf_proxy as ssrf_proxy
from core.helper import ssrf_proxy
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
@@ -191,7 +191,7 @@ class ApiTool(Tool):
else:
body = body
if method in ("get", "head", "post", "put", "delete", "patch"):
if method in {"get", "head", "post", "put", "delete", "patch"}:
response = getattr(ssrf_proxy, method)(
url,
params=params,
@@ -224,9 +224,9 @@ class ApiTool(Tool):
elif option["type"] == "string":
return str(value)
elif option["type"] == "boolean":
if str(value).lower() in ["true", "1"]:
if str(value).lower() in {"true", "1"}:
return True
elif str(value).lower() in ["false", "0"]:
elif str(value).lower() in {"false", "0"}:
return False
else:
continue # Not a boolean, try next option

View File

@@ -189,10 +189,7 @@ class ToolEngine:
result += response.message
elif response.type == ToolInvokeMessage.MessageType.LINK:
result += f"result link: {response.message}. please tell user to check it."
elif (
response.type == ToolInvokeMessage.MessageType.IMAGE_LINK
or response.type == ToolInvokeMessage.MessageType.IMAGE
):
elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
result += (
"image has been created and sent to user already, you do not need to create it,"
" just tell the user to check it now."
@@ -212,10 +209,7 @@ class ToolEngine:
result = []
for response in tool_response:
if (
response.type == ToolInvokeMessage.MessageType.IMAGE_LINK
or response.type == ToolInvokeMessage.MessageType.IMAGE
):
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
mimetype = None
if response.meta.get("mime_type"):
mimetype = response.meta.get("mime_type")
@@ -297,7 +291,7 @@ class ToolEngine:
belongs_to="assistant",
url=message.url,
upload_file_id=None,
created_by_role=("account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"),
created_by_role=("account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"),
created_by=user_id,
)

View File

@@ -19,7 +19,7 @@ class ToolFileMessageTransformer:
result = []
for message in messages:
if message.type == ToolInvokeMessage.MessageType.TEXT or message.type == ToolInvokeMessage.MessageType.LINK:
if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
result.append(message)
elif message.type == ToolInvokeMessage.MessageType.IMAGE:
# try to download image

View File

@@ -165,7 +165,7 @@ class ApiBasedToolSchemaParser:
elif "schema" in parameter and "type" in parameter["schema"]:
typ = parameter["schema"]["type"]
if typ == "integer" or typ == "number":
if typ in {"integer", "number"}:
return ToolParameter.ToolParameterType.NUMBER
elif typ == "boolean":
return ToolParameter.ToolParameterType.BOOLEAN

View File

@@ -313,7 +313,7 @@ def normalize_whitespace(text):
def is_leaf(element):
return element.name in ["p", "li"]
return element.name in {"p", "li"}
def is_text(element):