chore: refurbish python code by applying Pylint linter rules (#8322)
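Scope of the cleanup, as reflected in the hunks below (the rule codes are the Pylint/Ruff names that appear to match each pattern): chained `==` comparisons and list/tuple membership tests become set literals (R1714 consider-using-in, PLR6201 literal-membership); `import a.b as b` becomes `from a import b` (PLR0402); redundant `X as X` import aliases are merged away (C0414); trailing bare `return` statements are dropped (R1711); and direct `__enter__()` calls in the TTS streaming paths are kept but suppressed with `# noqa:PLC2801`.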
@@ -27,17 +27,17 @@ class ModelType(Enum):
         :return: model type
         """
-        if origin_model_type == "text-generation" or origin_model_type == cls.LLM.value:
+        if origin_model_type in {"text-generation", cls.LLM.value}:
             return cls.LLM
-        elif origin_model_type == "embeddings" or origin_model_type == cls.TEXT_EMBEDDING.value:
+        elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}:
             return cls.TEXT_EMBEDDING
-        elif origin_model_type == "reranking" or origin_model_type == cls.RERANK.value:
+        elif origin_model_type in {"reranking", cls.RERANK.value}:
             return cls.RERANK
-        elif origin_model_type == "speech2text" or origin_model_type == cls.SPEECH2TEXT.value:
+        elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}:
             return cls.SPEECH2TEXT
-        elif origin_model_type == "tts" or origin_model_type == cls.TTS.value:
+        elif origin_model_type in {"tts", cls.TTS.value}:
             return cls.TTS
-        elif origin_model_type == "text2img" or origin_model_type == cls.TEXT2IMG.value:
+        elif origin_model_type in {"text2img", cls.TEXT2IMG.value}:
            return cls.TEXT2IMG
         elif origin_model_type == cls.MODERATION.value:
             return cls.MODERATION
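Most hunks in this commit follow the pattern above: a chain of `==` comparisons, or a membership test against a list/tuple literal, becomes a single membership test against a set literal. A minimal runnable sketch of the before/after (function and values hypothetical, not taken from the codebase):

def is_llm(origin_model_type: str) -> bool:
    # Before (flagged): chained equality tests, Pylint R1714 "consider-using-in";
    # a list literal in the membership test would also trip Ruff's PLR6201.
    #     return origin_model_type == "text-generation" or origin_model_type == "llm"
    # After: one set-literal membership test. CPython folds constant set
    # literals into frozensets at compile time, so this is a hash lookup
    # rather than a linear scan.
    return origin_model_type in {"text-generation", "llm"}

assert is_llm("text-generation") and not is_llm("embeddings")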
@@ -494,7 +494,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                     mime_type = data_split[0].replace("data:", "")
                     base64_data = data_split[1]

-                    if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+                    if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
                         raise ValueError(
                             f"Unsupported image type {mime_type}, "
                             f"only support image/jpeg, image/png, image/gif, and image/webp"
@@ -85,14 +85,14 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
                     for i in range(len(sentences))
                 ]
                 for future in futures:
-                    yield from future.result().__enter__().iter_bytes(1024)
+                    yield from future.result().__enter__().iter_bytes(1024)  # noqa:PLC2801

             else:
                 response = client.audio.speech.with_streaming_response.create(
                     model=model, voice=voice, response_format="mp3", input=content_text.strip()
                 )

-                yield from response.__enter__().iter_bytes(1024)
+                yield from response.__enter__().iter_bytes(1024)  # noqa:PLC2801
         except Exception as ex:
             raise InvokeBadRequestError(str(ex))
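The TTS hunks are the one place where the commit suppresses a rule instead of rewriting: the streaming response's `__enter__()` is called directly so the body can keep yielding chunks, and Ruff's PLC2801 (unnecessary-dunder-call) is silenced with `# noqa`. The textbook fix is a `with` block, which also guarantees `__exit__` runs on completion or error; a rough sketch of that alternative, with a hypothetical stand-in for the streaming-response helper:

from contextlib import contextmanager

@contextmanager
def fake_streaming_response():
    # Stand-in for client.audio.speech.with_streaming_response.create(...);
    # illustration only, not the real client API.
    yield [b"chunk-1", b"chunk-2"]

def stream_audio():
    # `with` keeps the response open for the duration of the yield loop and
    # closes it afterwards, which a bare __enter__() call does not guarantee.
    with fake_streaming_response() as chunks:
        yield from chunks

assert list(stream_audio()) == [b"chunk-1", b"chunk-2"]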
@@ -454,7 +454,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                     base64_data = data_split[1]
                     image_content = base64.b64decode(base64_data)

-                    if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+                    if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
                         raise ValueError(
                             f"Unsupported image type {mime_type}, "
                             f"only support image/jpeg, image/png, image/gif, and image/webp"
@@ -886,16 +886,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel):

         if error_code == "AccessDeniedException":
             return InvokeAuthorizationError(error_msg)
-        elif error_code in ["ResourceNotFoundException", "ValidationException"]:
+        elif error_code in {"ResourceNotFoundException", "ValidationException"}:
             return InvokeBadRequestError(error_msg)
-        elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
+        elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}:
             return InvokeRateLimitError(error_msg)
-        elif error_code in [
+        elif error_code in {
             "ModelTimeoutException",
             "ModelErrorException",
             "InternalServerException",
             "ModelNotReadyException",
-        ]:
+        }:
             return InvokeServerUnavailableError(error_msg)
         elif error_code == "ModelStreamErrorException":
             return InvokeConnectionError(error_msg)
@@ -186,16 +186,16 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):

         if error_code == "AccessDeniedException":
             return InvokeAuthorizationError(error_msg)
-        elif error_code in ["ResourceNotFoundException", "ValidationException"]:
+        elif error_code in {"ResourceNotFoundException", "ValidationException"}:
             return InvokeBadRequestError(error_msg)
-        elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
+        elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}:
             return InvokeRateLimitError(error_msg)
-        elif error_code in [
+        elif error_code in {
             "ModelTimeoutException",
             "ModelErrorException",
             "InternalServerException",
             "ModelNotReadyException",
-        ]:
+        }:
             return InvokeServerUnavailableError(error_msg)
         elif error_code == "ModelStreamErrorException":
             return InvokeConnectionError(error_msg)
@@ -6,10 +6,10 @@ from collections.abc import Generator
 from typing import Optional, Union, cast

 import google.ai.generativelanguage as glm
-import google.api_core.exceptions as exceptions
 import google.generativeai as genai
-import google.generativeai.client as client
 import requests
+from google.api_core import exceptions
+from google.generativeai import client
 from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory
 from google.generativeai.types.content_types import to_part
 from PIL import Image
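The Gemini and Vertex AI import hunks apply Pylint's consider-using-from-import (Ruff PLR0402): `import a.b.c as c` and `from a.b import c` bind the same module object, and the `from` form is the preferred spelling. A quick runnable check using only the standard library:

# Both spellings bind the same module object; only the syntax differs.
import os.path as manual_alias
from os import path as from_import

assert manual_alias is from_import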
@@ -77,7 +77,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
         if "huggingfacehub_api_type" not in credentials:
             raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.")

-        if credentials["huggingfacehub_api_type"] not in ("inference_endpoints", "hosted_inference_api"):
+        if credentials["huggingfacehub_api_type"] not in {"inference_endpoints", "hosted_inference_api"}:
             raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.")

         if "huggingfacehub_api_token" not in credentials:
@@ -94,7 +94,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
                 credentials["huggingfacehub_api_token"], model
             )

-            if credentials["task_type"] not in ("text2text-generation", "text-generation"):
+            if credentials["task_type"] not in {"text2text-generation", "text-generation"}:
                 raise CredentialsValidateFailedError(
                     "Huggingface Hub Task Type must be one of text2text-generation, text-generation."
                 )
@@ -75,7 +75,7 @@ class TeiHelper:
         if len(model_type.keys()) < 1:
             raise RuntimeError("model_type is empty")
         model_type = list(model_type.keys())[0]
-        if model_type not in ["embedding", "reranker"]:
+        if model_type not in {"embedding", "reranker"}:
             raise RuntimeError(f"invalid model_type: {model_type}")

         max_input_length = response_json.get("max_input_length", 512)
@@ -100,9 +100,9 @@ class MinimaxChatCompletion:
         return self._handle_chat_generate_response(response)

     def _handle_error(self, code: int, msg: str):
-        if code == 1000 or code == 1001 or code == 1013 or code == 1027:
+        if code in {1000, 1001, 1013, 1027}:
             raise InternalServerError(msg)
-        elif code == 1002 or code == 1039:
+        elif code in {1002, 1039}:
             raise RateLimitReachedError(msg)
         elif code == 1004:
             raise InvalidAuthenticationError(msg)
@@ -105,9 +105,9 @@ class MinimaxChatCompletionPro:
         return self._handle_chat_generate_response(response)

     def _handle_error(self, code: int, msg: str):
-        if code == 1000 or code == 1001 or code == 1013 or code == 1027:
+        if code in {1000, 1001, 1013, 1027}:
             raise InternalServerError(msg)
-        elif code == 1002 or code == 1039:
+        elif code in {1002, 1039}:
             raise RateLimitReachedError(msg)
         elif code == 1004:
             raise InvalidAuthenticationError(msg)
@@ -114,7 +114,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
             raise CredentialsValidateFailedError("Invalid api key")

     def _handle_error(self, code: int, msg: str):
-        if code == 1000 or code == 1001:
+        if code in {1000, 1001}:
             raise InternalServerError(msg)
         elif code == 1002:
             raise RateLimitReachedError(msg)
@@ -125,7 +125,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         model_mode = self.get_model_mode(base_model, credentials)

         # transform response format
-        if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]:
+        if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
             stop = stop or []
             if model_mode == LLMMode.CHAT:
                 # chat model
@@ -89,14 +89,14 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
                     for i in range(len(sentences))
                 ]
                 for future in futures:
-                    yield from future.result().__enter__().iter_bytes(1024)
+                    yield from future.result().__enter__().iter_bytes(1024)  # noqa:PLC2801

             else:
                 response = client.audio.speech.with_streaming_response.create(
                     model=model, voice=voice, response_format="mp3", input=content_text.strip()
                 )

-                yield from response.__enter__().iter_bytes(1024)
+                yield from response.__enter__().iter_bytes(1024)  # noqa:PLC2801
         except Exception as ex:
             raise InvokeBadRequestError(str(ex))
@@ -12,7 +12,6 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
         credentials["endpoint_url"] = "https://openrouter.ai/api/v1"
         credentials["mode"] = self.get_model_mode(model).value
         credentials["function_calling_type"] = "tool_call"
-        return

     def _invoke(
         self,
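This hunk (and the Signer one further down) drops a bare `return` that sat at the end of a function, per Pylint's useless-return (R1711): falling off the end of a Python function already returns None. Minimal illustration (function hypothetical):

def setup(credentials: dict) -> None:
    credentials["function_calling_type"] = "tool_call"
    # A trailing bare `return` here would change nothing: the function
    # returns None either way, so useless-return (R1711) removes it.

assert setup({"endpoint_url": "https://openrouter.ai/api/v1"}) is None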
@@ -154,7 +154,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
        )

        for key, value in input_properties:
-            if key not in ["system_prompt", "prompt"] and "stop" not in key:
+            if key not in {"system_prompt", "prompt"} and "stop" not in key:
                value_type = value.get("type")

                if not value_type:
@@ -86,7 +86,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
        )

        for input_property in input_properties:
-            if input_property[0] in ("text", "texts", "inputs"):
+            if input_property[0] in {"text", "texts", "inputs"}:
                text_input_key = input_property[0]
                return text_input_key
@@ -96,7 +96,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
    def _generate_embeddings_by_text_input_key(
        client: ReplicateClient, replicate_model_version: str, text_input_key: str, texts: list[str]
    ) -> list[list[float]]:
-        if text_input_key in ("text", "inputs"):
+        if text_input_key in {"text", "inputs"}:
            embeddings = []
            for text in texts:
                result = client.run(replicate_model_version, input={text_input_key: text})
@@ -89,7 +89,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
        :param tools: tools for tool calling
        :return:
        """
-        if model in ["qwen-turbo-chat", "qwen-plus-chat"]:
+        if model in {"qwen-turbo-chat", "qwen-plus-chat"}:
            model = model.replace("-chat", "")
        if model == "farui-plus":
            model = "qwen-farui-plus"
@@ -157,7 +157,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):

        mode = self.get_model_mode(model, credentials)

-        if model in ["qwen-turbo-chat", "qwen-plus-chat"]:
+        if model in {"qwen-turbo-chat", "qwen-plus-chat"}:
            model = model.replace("-chat", "")

        extra_model_kwargs = {}
@@ -201,7 +201,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
        :param prompt_messages: prompt messages
        :return: llm response
        """
-        if response.status_code != 200 and response.status_code != HTTPStatus.OK:
+        if response.status_code not in {200, HTTPStatus.OK}:
            raise ServiceUnavailableError(response.message)
        # transform assistant message to prompt message
        assistant_prompt_message = AssistantPromptMessage(
@@ -240,7 +240,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
        full_text = ""
        tool_calls = []
        for index, response in enumerate(responses):
-            if response.status_code != 200 and response.status_code != HTTPStatus.OK:
+            if response.status_code not in {200, HTTPStatus.OK}:
                raise ServiceUnavailableError(
                    f"Failed to invoke model {model}, status code: {response.status_code}, "
                    f"message: {response.message}"
@@ -93,7 +93,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel):
        """
        Code block mode wrapper for invoking large language model
        """
-        if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]:
+        if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
            stop = stop or []
            self._transform_chat_json_prompts(
                model=model,
@@ -5,7 +5,6 @@ import logging
 from collections.abc import Generator
 from typing import Optional, Union, cast

-import google.api_core.exceptions as exceptions
 import google.auth.transport.requests
 import vertexai.generative_models as glm
 from anthropic import AnthropicVertex, Stream
@@ -17,6 +16,7 @@ from anthropic.types import (
     MessageStopEvent,
     MessageStreamEvent,
 )
+from google.api_core import exceptions
 from google.cloud import aiplatform
 from google.oauth2 import service_account
 from PIL import Image
@@ -346,7 +346,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
                     mime_type = data_split[0].replace("data:", "")
                     base64_data = data_split[1]

-                    if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+                    if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
                         raise ValueError(
                             f"Unsupported image type {mime_type}, "
                             f"only support image/jpeg, image/png, image/gif, and image/webp"
@@ -96,7 +96,6 @@ class Signer:
        signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service)
        sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str))
        request.headers["Authorization"] = Signer.build_auth_header_v4(sign, md, credentials)
-        return

    @staticmethod
    def hashed_canonical_request_v4(request, meta):
@@ -105,7 +104,7 @@ class Signer:

        signed_headers = {}
        for key in request.headers:
-            if key in ["Content-Type", "Content-Md5", "Host"] or key.startswith("X-"):
+            if key in {"Content-Type", "Content-Md5", "Host"} or key.startswith("X-"):
                signed_headers[key.lower()] = request.headers[key]

        if "host" in signed_headers:
@@ -69,7 +69,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
        """
        Code block mode wrapper for invoking large language model
        """
-        if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]:
+        if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
            response_format = model_parameters["response_format"]
            stop = stop or []
            self._transform_json_prompts(
@@ -103,7 +103,7 @@ class XinferenceHelper:
            model_handle_type = "embedding"
        elif response_json.get("model_type") == "audio":
            model_handle_type = "audio"
-            if model_family and model_family in ["ChatTTS", "CosyVoice", "FishAudio"]:
+            if model_family and model_family in {"ChatTTS", "CosyVoice", "FishAudio"}:
                model_ability.append("text-to-audio")
            else:
                model_ability.append("audio-to-text")
@@ -186,10 +186,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
        new_prompt_messages: list[PromptMessage] = []
        for prompt_message in prompt_messages:
            copy_prompt_message = prompt_message.copy()
-            if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]:
+            if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL}:
                if isinstance(copy_prompt_message.content, list):
                    # check if model is 'glm-4v'
-                    if model not in ("glm-4v", "glm-4v-plus"):
+                    if model not in {"glm-4v", "glm-4v-plus"}:
                        # not support list message
                        continue
                    # get image and
@@ -209,10 +209,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                ):
                    new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
                else:
-                    if (
-                        copy_prompt_message.role == PromptMessageRole.USER
-                        or copy_prompt_message.role == PromptMessageRole.TOOL
-                    ):
+                    if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.TOOL}:
                        new_prompt_messages.append(copy_prompt_message)
                    elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
                        new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
@@ -226,7 +223,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
            else:
                new_prompt_messages.append(copy_prompt_message)

-        if model == "glm-4v" or model == "glm-4v-plus":
+        if model in {"glm-4v", "glm-4v-plus"}:
            params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
        else:
            params = {"model": model, "messages": [], **model_parameters}
@@ -270,11 +267,11 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
            # chatglm model
            for prompt_message in new_prompt_messages:
                # merge system message to user message
-                if (
-                    prompt_message.role == PromptMessageRole.SYSTEM
-                    or prompt_message.role == PromptMessageRole.TOOL
-                    or prompt_message.role == PromptMessageRole.USER
-                ):
+                if prompt_message.role in {
+                    PromptMessageRole.SYSTEM,
+                    PromptMessageRole.TOOL,
+                    PromptMessageRole.USER,
+                }:
                    if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user":
                        params["messages"][-1]["content"] += "\n\n" + prompt_message.content
                    else:
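The ZhipuAI hunks show the same set-literal rewrite applied to enum members, which works because Enum members are hashable: a five-line parenthesized `or` chain collapses into one membership test. A self-contained sketch, where Role is a hypothetical stand-in for PromptMessageRole:

from enum import Enum

class Role(Enum):
    SYSTEM = "system"
    USER = "user"
    TOOL = "tool"
    ASSISTANT = "assistant"

def merges_into_user_message(role: Role) -> bool:
    # Enum members hash by identity, so set membership is both correct and
    # O(1), replacing the multi-line `or` chain from the diff.
    return role in {Role.SYSTEM, Role.TOOL, Role.USER}

assert merges_into_user_message(Role.TOOL)
assert not merges_into_user_message(Role.ASSISTANT)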
@@ -1,5 +1,4 @@
 from __future__ import annotations

-from .fine_tuning_job import FineTuningJob as FineTuningJob
-from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob
-from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent
+from .fine_tuning_job import FineTuningJob, ListOfFineTuningJob
+from .fine_tuning_job_event import FineTuningJobEvent
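This hunk clears Pylint's useless-import-alias (C0414) by merging the imports and dropping the redundant `X as X` aliases. One caveat worth knowing: under PEP 484, the `X as X` spelling marks an intentional re-export for strict type checkers (e.g. mypy with --no-implicit-reexport), so the cleanup is only behavior-preserving when implicit re-exports are allowed. Illustration with standard-library imports:

# Alias form: flagged as useless at runtime, but doubles as the PEP 484
# explicit re-export marker in typed packages.
from collections import OrderedDict as OrderedDict

# Plain form after the cleanup: same runtime binding, one line per module.
from collections import defaultdict, deque

assert OrderedDict is not None and defaultdict is not None and deque is not None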
@@ -75,7 +75,7 @@ class CommonValidator:
        if not isinstance(value, str):
            raise ValueError(f"Variable {credential_form_schema.variable} should be string")

-        if credential_form_schema.type in [FormType.SELECT, FormType.RADIO]:
+        if credential_form_schema.type in {FormType.SELECT, FormType.RADIO}:
            # If the value is in options, no validation is performed
            if credential_form_schema.options:
                if value not in [option.value for option in credential_form_schema.options]:
@@ -83,7 +83,7 @@ class CommonValidator:

        if credential_form_schema.type == FormType.SWITCH:
            # If the value is not in ['true', 'false'], an exception is thrown
-            if value.lower() not in ["true", "false"]:
+            if value.lower() not in {"true", "false"}:
                raise ValueError(f"Variable {credential_form_schema.variable} should be true or false")

            value = True if value.lower() == "true" else False