mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 19:36:53 +08:00
chore: apply flake8-comprehensions Ruff rules to improve collection comprehensions (#5652)
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -83,7 +83,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||
max_workers = self._get_model_workers_limit(model, credentials)
|
||||
try:
|
||||
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
|
||||
audio_bytes_list = list()
|
||||
audio_bytes_list = []
|
||||
|
||||
# Create a thread pool and map the function to the list of sentences
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
|
||||
@@ -175,8 +175,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
# - https://docs.anthropic.com/claude/reference/claude-on-amazon-bedrock
|
||||
# - https://github.com/anthropics/anthropic-sdk-python
|
||||
client = AnthropicBedrock(
|
||||
aws_access_key=credentials.get("aws_access_key_id", None),
|
||||
aws_secret_key=credentials.get("aws_secret_access_key", None),
|
||||
aws_access_key=credentials.get("aws_access_key_id"),
|
||||
aws_secret_key=credentials.get("aws_secret_access_key"),
|
||||
aws_region=credentials["aws_region"],
|
||||
)
|
||||
|
||||
@@ -576,7 +576,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
Create payload for bedrock api call depending on model provider
|
||||
"""
|
||||
payload = dict()
|
||||
payload = {}
|
||||
model_prefix = model.split('.')[0]
|
||||
model_name = model.split('.')[1]
|
||||
|
||||
@@ -648,8 +648,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
runtime_client = boto3.client(
|
||||
service_name='bedrock-runtime',
|
||||
config=client_config,
|
||||
aws_access_key_id=credentials.get("aws_access_key_id", None),
|
||||
aws_secret_access_key=credentials.get("aws_secret_access_key", None)
|
||||
aws_access_key_id=credentials.get("aws_access_key_id"),
|
||||
aws_secret_access_key=credentials.get("aws_secret_access_key")
|
||||
)
|
||||
|
||||
model_prefix = model.split('.')[0]
|
||||
|
||||
@@ -49,8 +49,8 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
||||
bedrock_runtime = boto3.client(
|
||||
service_name='bedrock-runtime',
|
||||
config=client_config,
|
||||
aws_access_key_id=credentials.get("aws_access_key_id", None),
|
||||
aws_secret_access_key=credentials.get("aws_secret_access_key", None)
|
||||
aws_access_key_id=credentials.get("aws_access_key_id"),
|
||||
aws_secret_access_key=credentials.get("aws_secret_access_key")
|
||||
)
|
||||
|
||||
embeddings = []
|
||||
@@ -148,7 +148,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Create payload for bedrock api call depending on model provider
|
||||
"""
|
||||
payload = dict()
|
||||
payload = {}
|
||||
|
||||
if model_prefix == "amazon":
|
||||
payload['inputText'] = texts
|
||||
|
||||
@@ -696,12 +696,10 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
en_US=model
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
features=[feature for feature in base_model_schema_features],
|
||||
features=list(base_model_schema_features),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
key: property for key, property in base_model_schema_model_properties.items()
|
||||
},
|
||||
parameter_rules=[rule for rule in base_model_schema_parameters_rules],
|
||||
model_properties=dict(base_model_schema_model_properties.items()),
|
||||
parameter_rules=list(base_model_schema_parameters_rules),
|
||||
pricing=base_model_schema.pricing
|
||||
)
|
||||
|
||||
|
||||
@@ -277,10 +277,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
type='function',
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=part.function_call.name,
|
||||
arguments=json.dumps({
|
||||
key: value
|
||||
for key, value in part.function_call.args.items()
|
||||
})
|
||||
arguments=json.dumps(dict(part.function_call.args.items()))
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
@@ -88,9 +88,9 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
|
||||
def _add_function_call(self, model: str, credentials: dict) -> None:
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
if model_schema and set([
|
||||
if model_schema and {
|
||||
ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL
|
||||
]).intersection(model_schema.features or []):
|
||||
}.intersection(model_schema.features or []):
|
||||
credentials['function_calling_type'] = 'tool_call'
|
||||
|
||||
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
|
||||
|
||||
@@ -100,10 +100,10 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
endpoint_url = credentials['endpoint_url'] if 'endpoint_url' in credentials else None
|
||||
endpoint_url = credentials.get('endpoint_url')
|
||||
if endpoint_url and not endpoint_url.endswith('/'):
|
||||
endpoint_url += '/'
|
||||
server_url = credentials['server_url'] if 'server_url' in credentials else None
|
||||
server_url = credentials.get('server_url')
|
||||
|
||||
# prepare the payload for a simple ping to the model
|
||||
data = {
|
||||
@@ -182,10 +182,10 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
if stream:
|
||||
headers['Accept'] = 'text/event-stream'
|
||||
|
||||
endpoint_url = credentials['endpoint_url'] if 'endpoint_url' in credentials else None
|
||||
endpoint_url = credentials.get('endpoint_url')
|
||||
if endpoint_url and not endpoint_url.endswith('/'):
|
||||
endpoint_url += '/'
|
||||
server_url = credentials['server_url'] if 'server_url' in credentials else None
|
||||
server_url = credentials.get('server_url')
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
|
||||
@@ -1073,12 +1073,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
en_US=model
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
features=[feature for feature in base_model_schema_features],
|
||||
features=list(base_model_schema_features),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
key: property for key, property in base_model_schema_model_properties.items()
|
||||
},
|
||||
parameter_rules=[rule for rule in base_model_schema_parameters_rules],
|
||||
model_properties=dict(base_model_schema_model_properties.items()),
|
||||
parameter_rules=list(base_model_schema_parameters_rules),
|
||||
pricing=base_model_schema.pricing
|
||||
)
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
|
||||
max_workers = self._get_model_workers_limit(model, credentials)
|
||||
try:
|
||||
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
|
||||
audio_bytes_list = list()
|
||||
audio_bytes_list = []
|
||||
|
||||
# Create a thread pool and map the function to the list of sentences
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
|
||||
@@ -275,14 +275,13 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
|
||||
|
||||
@classmethod
|
||||
def _get_parameter_type(cls, param_type: str) -> str:
|
||||
if param_type == 'integer':
|
||||
return 'int'
|
||||
elif param_type == 'number':
|
||||
return 'float'
|
||||
elif param_type == 'boolean':
|
||||
return 'boolean'
|
||||
elif param_type == 'string':
|
||||
return 'string'
|
||||
type_mapping = {
|
||||
'integer': 'int',
|
||||
'number': 'float',
|
||||
'boolean': 'boolean',
|
||||
'string': 'string'
|
||||
}
|
||||
return type_mapping.get(param_type)
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
|
||||
messages = messages.copy() # don't mutate the original list
|
||||
|
||||
@@ -80,7 +80,7 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
|
||||
max_workers = self._get_model_workers_limit(model, credentials)
|
||||
try:
|
||||
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
|
||||
audio_bytes_list = list()
|
||||
audio_bytes_list = []
|
||||
|
||||
# Create a thread pool and map the function to the list of sentences
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
|
||||
@@ -579,10 +579,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
|
||||
type='function',
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=part.function_call.name,
|
||||
arguments=json.dumps({
|
||||
key: value
|
||||
for key, value in part.function_call.args.items()
|
||||
})
|
||||
arguments=json.dumps(dict(part.function_call.args.items()))
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
@@ -102,7 +102,7 @@ class Signer:
|
||||
body_hash = Util.sha256(request.body)
|
||||
request.headers['X-Content-Sha256'] = body_hash
|
||||
|
||||
signed_headers = dict()
|
||||
signed_headers = {}
|
||||
for key in request.headers:
|
||||
if key in ['Content-Type', 'Content-Md5', 'Host'] or key.startswith('X-'):
|
||||
signed_headers[key.lower()] = request.headers[key]
|
||||
|
||||
@@ -150,7 +150,7 @@ class Request:
|
||||
self.headers = OrderedDict()
|
||||
self.query = OrderedDict()
|
||||
self.body = ''
|
||||
self.form = dict()
|
||||
self.form = {}
|
||||
self.connection_timeout = 0
|
||||
self.socket_timeout = 0
|
||||
|
||||
|
||||
@@ -147,7 +147,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
return self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
if is_completion_model:
|
||||
return sum([tokens(str(message.content)) for message in messages])
|
||||
return sum(tokens(str(message.content)) for message in messages)
|
||||
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
|
||||
@@ -18,7 +18,7 @@ class _CommonZhipuaiAI:
|
||||
"""
|
||||
credentials_kwargs = {
|
||||
"api_key": credentials['api_key'] if 'api_key' in credentials else
|
||||
credentials['zhipuai_api_key'] if 'zhipuai_api_key' in credentials else None,
|
||||
credentials.get("zhipuai_api_key"),
|
||||
}
|
||||
|
||||
return credentials_kwargs
|
||||
|
||||
Reference in New Issue
Block a user