feat: mypy for all type check (#10921)

yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -1,7 +1,8 @@
import json
import logging
import sys
from typing import Optional
from collections.abc import Sequence
from typing import Optional, cast
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@@ -20,7 +21,7 @@ class LoggingCallback(Callback):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
@@ -76,7 +77,7 @@ class LoggingCallback(Callback):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
):
@@ -94,7 +95,7 @@ class LoggingCallback(Callback):
:param stream: is stream response
:param user: unique user id
"""
sys.stdout.write(chunk.delta.message.content)
sys.stdout.write(cast(str, chunk.delta.message.content))
sys.stdout.flush()
def on_after_invoke(
@@ -106,7 +107,7 @@ class LoggingCallback(Callback):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
@@ -147,7 +148,7 @@ class LoggingCallback(Callback):
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
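
The recurring change in this file, widening `stop` from `Optional[list[str]]` to `Optional[Sequence[str]]` and casting the chunk content to `str` before writing it, follows a common mypy pattern. A minimal sketch of both moves (illustrative function, not part of the commit):

from collections.abc import Sequence
from typing import Optional, cast

def stop_at(stop: Optional[Sequence[str]] = None) -> None:
    # Sequence accepts a list, a tuple, or any other read-only sequence,
    # so callers are not forced to build a list.
    for token in stop or ():
        print(token)

stop_at(["\n\n"])           # list[str] is a Sequence[str]
stop_at(("</s>", "<eos>"))  # tuple[str, ...] is too

content: object = "hello"
# cast() only informs the type checker; it does nothing at runtime.
print(cast(str, content))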

View File

@@ -3,7 +3,7 @@ from collections.abc import Sequence
from enum import Enum, StrEnum
from typing import Optional
from pydantic import BaseModel, Field, computed_field, field_validator
from pydantic import BaseModel, Field, field_validator
class PromptMessageRole(Enum):
@@ -89,7 +89,6 @@ class MultiModalPromptMessageContent(PromptMessageContent):
url: str = Field(default="", description="the url of multi-modal file")
mime_type: str = Field(default=..., description="the mime type of multi-modal file")
@computed_field(return_type=str)
@property
def data(self):
return self.url or f"data:{self.mime_type};base64,{self.base64_data}"

View File

@@ -1,7 +1,6 @@
import decimal
import os
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Optional
from pydantic import ConfigDict
@@ -36,7 +35,7 @@ class AIModel(ABC):
model_config = ConfigDict(protected_namespaces=())
@abstractmethod
def validate_credentials(self, model: str, credentials: Mapping) -> None:
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
@@ -214,7 +213,7 @@ class AIModel(ABC):
return model_schemas
def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> Optional[AIModelEntity]:
def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
"""
Get model schema by model name and credentials
@@ -236,9 +235,7 @@ class AIModel(ABC):
return None
def get_customizable_model_schema_from_credentials(
self, model: str, credentials: Mapping
) -> Optional[AIModelEntity]:
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema from credentials
@@ -248,7 +245,7 @@ class AIModel(ABC):
"""
return self._get_customizable_model_schema(model, credentials)
def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema and fill in the template
"""
@@ -301,7 +298,7 @@ class AIModel(ABC):
return schema
def get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema
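
These hunks settle `credentials` on `dict` so the abstract base and its concrete overrides share one signature; mypy rejects an override whose parameter types disagree with the base. A minimal sketch of that check (illustrative class names):

from abc import ABC, abstractmethod

class Base(ABC):
    @abstractmethod
    def validate_credentials(self, model: str, credentials: dict) -> None: ...

class Child(Base):
    # The override must accept at least what the base accepts; keeping the
    # annotations identical is the simplest way to satisfy mypy's check.
    def validate_credentials(self, model: str, credentials: dict) -> None:
        if "api_key" not in credentials:
            raise ValueError("api_key is required")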

View File

@@ -2,7 +2,7 @@ import logging
import re
import time
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from collections.abc import Generator, Sequence
from typing import Optional, Union
from pydantic import ConfigDict
@@ -48,7 +48,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
@@ -291,12 +291,12 @@ if you are not sure about the structure.
content = piece.delta.message.content
piece.delta.message.content = ""
yield piece
piece = content
content_piece = content
else:
yield piece
continue
new_piece: str = ""
for char in piece:
for char in content_piece:
char = str(char)
if state == "normal":
if char == "`":
@@ -350,7 +350,7 @@ if you are not sure about the structure.
piece.delta.message.content = ""
# Yield a piece with cleared content before processing it to maintain the generator structure
yield piece
piece = content
content_piece = content
else:
# Yield pieces without content directly
yield piece
@@ -360,7 +360,7 @@ if you are not sure about the structure.
continue
new_piece: str = ""
for char in piece:
for char in content_piece:
if state == "search_start":
if char == "`":
backtick_count += 1
@@ -535,7 +535,7 @@ if you are not sure about the structure.
return []
def get_model_mode(self, model: str, credentials: Optional[Mapping] = None) -> LLMMode:
def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode:
"""
Get model mode
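
The `piece` to `content_piece` rename avoids rebinding one variable to a second type, which mypy reports as an incompatible assignment. A minimal sketch of the error and the fix (illustrative values):

def emit_chars(pieces: list[int]) -> None:
    for piece in pieces:
        # piece = str(piece)  # mypy: incompatible types in assignment
        content_piece = str(piece)  # a new name keeps both types distinct
        for char in content_piece:
            print(char)

emit_chars([12, 34])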

View File

@@ -104,9 +104,10 @@ class ModelProvider(ABC):
mod = import_module_from_source(
module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path
)
# FIXME "type" has no attribute "__abstractmethods__" ignore it for now fix it later
model_class = next(
filter(
lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__, # type: ignore
get_subclasses_from_module(mod, AIModel),
),
None,

View File

@@ -89,7 +89,8 @@ class TextEmbeddingModel(AIModel):
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE]
content_size: int = model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE]
return content_size
return 1000
@@ -104,6 +105,7 @@ class TextEmbeddingModel(AIModel):
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
return max_chunks
return 1
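
The annotated locals (`content_size: int`, `max_chunks: int`) pin an `Any`-typed dict lookup to `int` before it is returned, satisfying mypy's `warn_return_any` check. A minimal sketch (illustrative dict):

from typing import Any

model_properties: dict[str, Any] = {"context_size": 4096}

def get_context_size() -> int:
    # returning the dict lookup directly would return Any, which
    # warn_return_any flags; the annotated local pins it to int
    context_size: int = model_properties["context_size"]
    return context_size

print(get_context_size())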

View File

@@ -2,9 +2,9 @@ from os.path import abspath, dirname, join
from threading import Lock
from typing import Any
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
_tokenizer = None
_tokenizer: Any = None
_lock = Lock()
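
Two patterns appear here: `# type: ignore` on an import from a package that ships no type stubs, and an explicit `Any` annotation on a module-level cache that starts as `None`. A minimal sketch of both:

from threading import Lock
from typing import Any

# An import from a package without stubs or a py.typed marker triggers
# "missing library stubs or py.typed marker"; the ignore silences it:
# from transformers import GPT2Tokenizer  # type: ignore

_tokenizer: Any = None  # annotated so a real tokenizer can be assigned later
_lock = Lock()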

View File

@@ -127,7 +127,8 @@ class TTSModel(AIModel):
if not model_schema or ModelPropertyKey.AUDIO_TYPE not in model_schema.model_properties:
raise ValueError("this model does not support audio type")
return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
audio_type: str = model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
return audio_type
def _get_model_word_limit(self, model: str, credentials: dict) -> int:
"""
@@ -138,8 +139,9 @@ class TTSModel(AIModel):
if not model_schema or ModelPropertyKey.WORD_LIMIT not in model_schema.model_properties:
raise ValueError("this model does not support word limit")
word_limit: int = model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
return word_limit
def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
"""
@@ -150,8 +152,9 @@ class TTSModel(AIModel):
if not model_schema or ModelPropertyKey.MAX_WORKERS not in model_schema.model_properties:
raise ValueError("this model does not support max workers")
workers_limit: int = model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
return workers_limit
@staticmethod
def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"):

View File

@@ -64,10 +64,12 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model)
if not ai_model_entity:
return None
return ai_model_entity.entity
@staticmethod
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
def _get_ai_model_entity(base_model_name: str, model: str) -> Optional[AzureBaseModel]:
for ai_model_entity in SPEECH2TEXT_BASE_MODELS:
if ai_model_entity.base_model_name == base_model_name:
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
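
Because `_get_ai_model_entity` can fall through its loop and return `None`, its return type becomes `Optional[AzureBaseModel]` and the caller gains an explicit `None` check. A minimal sketch of the same pattern (illustrative names):

from typing import Optional

def find_entity(names: list[str], wanted: str) -> Optional[str]:
    for name in names:
        if name == wanted:
            return name
    return None  # the fall-through is why Optional is the honest annotation

entity = find_entity(["whisper-1"], "tts-1")
if not entity:
    print("not found")  # callers now have to handle None explicitly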

View File

@@ -114,6 +114,8 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model)
if not ai_model_entity:
return None
return ai_model_entity.entity
@staticmethod

View File

@@ -6,9 +6,9 @@ from collections.abc import Generator
from typing import Optional, Union, cast
# third-party imports
import boto3
from botocore.config import Config
from botocore.exceptions import (
import boto3 # type: ignore
from botocore.config import Config # type: ignore
from botocore.exceptions import ( # type: ignore
ClientError,
EndpointConnectionError,
NoRegionError,

View File

@@ -44,7 +44,7 @@ class CohereRerankModel(RerankModel):
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(model=model, docs=docs)
return RerankResult(model=model, docs=[])
# initialize client
client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url"))
@@ -62,7 +62,7 @@ class CohereRerankModel(RerankModel):
# format document
rerank_document = RerankDocument(
index=result.index,
text=result.document.text,
text=result.document.text if result.document else "",
score=result.relevance_score,
)

View File

@@ -1,5 +1,3 @@
from collections.abc import Mapping
import openai
from core.model_runtime.errors.invoke import (
@@ -13,7 +11,7 @@ from core.model_runtime.errors.invoke import (
class _CommonFireworks:
def _to_credential_kwargs(self, credentials: Mapping) -> dict:
def _to_credential_kwargs(self, credentials: dict) -> dict:
"""
Transform credentials to kwargs for model instance

View File

@@ -1,5 +1,4 @@
import time
from collections.abc import Mapping
from typing import Optional, Union
import numpy as np
@@ -93,7 +92,7 @@ class FireworksTextEmbeddingModel(_CommonFireworks, TextEmbeddingModel):
"""
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
def validate_credentials(self, model: str, credentials: Mapping) -> None:
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials

View File

@@ -1,4 +1,4 @@
from dashscope.common.error import (
from dashscope.common.error import ( # type: ignore
AuthenticationError,
InvalidParameter,
RequestFailure,

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Optional
import httpx
@@ -51,7 +51,7 @@ class GiteeAIRerankModel(RerankModel):
base_url = base_url.removesuffix("/")
try:
body = {"model": model, "query": query, "documents": docs}
body: dict[str, Any] = {"model": model, "query": query, "documents": docs}
if top_n is not None:
body["top_n"] = top_n
response = httpx.post(

View File

@@ -24,7 +24,7 @@ class GiteeAIEmbeddingModel(OAICompatEmbeddingModel):
super().validate_credentials(model, credentials)
@staticmethod
def _add_custom_parameters(credentials: dict, model: str) -> None:
def _add_custom_parameters(credentials: dict, model: Optional[str]) -> None:
if model is None:
model = "bge-m3"

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Optional
import requests
@@ -13,9 +13,10 @@ class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel):
Model class for OpenAI text2speech model.
"""
# FIXME: a more precise type than Any would be better for this return value
def _invoke(
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
) -> any:
) -> Any:
"""
_invoke text2speech model
@@ -47,7 +48,8 @@ class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel):
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
# FIXME: a more precise type than Any would be better for this return value
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> Any:
"""
_tts_invoke_streaming text2speech model
:param model: model name

View File

@@ -7,7 +7,7 @@ from collections.abc import Generator
from typing import Optional, Union
import google.ai.generativelanguage as glm
import google.generativeai as genai
import google.generativeai as genai # type: ignore
import requests
from google.api_core import exceptions
from google.generativeai.types import ContentType, File, GenerateContentResponse

View File

@@ -1,4 +1,4 @@
from huggingface_hub.utils import BadRequestError, HfHubHTTPError
from huggingface_hub.utils import BadRequestError, HfHubHTTPError # type: ignore
from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError

View File

@@ -1,9 +1,9 @@
from collections.abc import Generator
from typing import Optional, Union
from huggingface_hub import InferenceClient
from huggingface_hub.hf_api import HfApi
from huggingface_hub.utils import BadRequestError
from huggingface_hub import InferenceClient # type: ignore
from huggingface_hub.hf_api import HfApi # type: ignore
from huggingface_hub.utils import BadRequestError # type: ignore
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE

View File

@@ -4,7 +4,7 @@ from typing import Optional
import numpy as np
import requests
from huggingface_hub import HfApi, InferenceClient
from huggingface_hub import HfApi, InferenceClient # type: ignore
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject

View File

@@ -3,11 +3,11 @@ import logging
from collections.abc import Generator
from typing import cast
from tencentcloud.common import credential
from tencentcloud.common.exception import TencentCloudSDKException
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
from tencentcloud.common import credential # type: ignore
from tencentcloud.common.exception import TencentCloudSDKException # type: ignore
from tencentcloud.common.profile.client_profile import ClientProfile # type: ignore
from tencentcloud.common.profile.http_profile import HttpProfile # type: ignore
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models # type: ignore
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
@@ -305,7 +305,7 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
elif isinstance(message, ToolPromptMessage):
message_text = f"{tool_prompt} {content}"
elif isinstance(message, SystemPromptMessage):
message_text = content
message_text = content if isinstance(content, str) else ""
else:
raise ValueError(f"Got unknown type {message}")

View File

@@ -3,11 +3,11 @@ import logging
import time
from typing import Optional
from tencentcloud.common import credential
from tencentcloud.common.exception import TencentCloudSDKException
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
from tencentcloud.common import credential # type: ignore
from tencentcloud.common.exception import TencentCloudSDKException # type: ignore
from tencentcloud.common.profile.client_profile import ClientProfile # type: ignore
from tencentcloud.common.profile.http_profile import HttpProfile # type: ignore
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models # type: ignore
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType

View File

@@ -1,11 +1,11 @@
from os.path import abspath, dirname, join
from threading import Lock
from transformers import AutoTokenizer
from transformers import AutoTokenizer # type: ignore
class JinaTokenizer:
_tokenizer = None
_tokenizer: AutoTokenizer | None = None
_lock = Lock()
@classmethod

View File

@@ -40,7 +40,7 @@ class MinimaxChatCompletion:
url = f"https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}"
extra_kwargs = {}
extra_kwargs: dict[str, Any] = {}
if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int:
extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"]
@@ -117,19 +117,19 @@ class MinimaxChatCompletion:
"""
handle chat generate response
"""
response = response.json()
if "base_resp" in response and response["base_resp"]["status_code"] != 0:
code = response["base_resp"]["status_code"]
msg = response["base_resp"]["status_msg"]
response_data = response.json()
if "base_resp" in response_data and response_data["base_resp"]["status_code"] != 0:
code = response_data["base_resp"]["status_code"]
msg = response_data["base_resp"]["status_msg"]
self._handle_error(code, msg)
message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
message = MinimaxMessage(content=response_data["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
message.usage = {
"prompt_tokens": 0,
"completion_tokens": response["usage"]["total_tokens"],
"total_tokens": response["usage"]["total_tokens"],
"completion_tokens": response_data["usage"]["total_tokens"],
"total_tokens": response_data["usage"]["total_tokens"],
}
message.stop_reason = response["choices"][0]["finish_reason"]
message.stop_reason = response_data["choices"][0]["finish_reason"]
return message
def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]:
@@ -139,10 +139,10 @@ class MinimaxChatCompletion:
for line in response.iter_lines():
if not line:
continue
line: str = line.decode("utf-8")
if line.startswith("data: "):
line = line[6:].strip()
data = loads(line)
line_str: str = line.decode("utf-8")
if line_str.startswith("data: "):
line_str = line_str[6:].strip()
data = loads(line_str)
if "base_resp" in data and data["base_resp"]["status_code"] != 0:
code = data["base_resp"]["status_code"]
@@ -162,5 +162,5 @@ class MinimaxChatCompletion:
continue
for choice in choices:
message = choice["delta"]
yield MinimaxMessage(content=message, role=MinimaxMessage.Role.ASSISTANT.value)
message_choice = choice["delta"]
yield MinimaxMessage(content=message_choice, role=MinimaxMessage.Role.ASSISTANT.value)
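
The `line` to `line_str` rename is the bytes-to-str variant of the redefinition fix: decoding into a new name keeps each variable at one type. A minimal sketch (illustrative payload):

from json import loads

def parse_sse(lines: list[bytes]) -> None:
    for line in lines:
        # line = line.decode("utf-8")  # mypy: str is incompatible with bytes
        line_str = line.decode("utf-8")  # a fresh name with its own type
        if line_str.startswith("data: "):
            print(loads(line_str[6:].strip()))

parse_sse([b'data: {"reply": "ok"}'])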

View File

@@ -41,7 +41,7 @@ class MinimaxChatCompletionPro:
url = f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}"
extra_kwargs = {}
extra_kwargs: dict[str, Any] = {}
if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int:
extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"]
@@ -122,19 +122,19 @@ class MinimaxChatCompletionPro:
"""
handle chat generate response
"""
response = response.json()
if "base_resp" in response and response["base_resp"]["status_code"] != 0:
code = response["base_resp"]["status_code"]
msg = response["base_resp"]["status_msg"]
response_data = response.json()
if "base_resp" in response_data and response_data["base_resp"]["status_code"] != 0:
code = response_data["base_resp"]["status_code"]
msg = response_data["base_resp"]["status_msg"]
self._handle_error(code, msg)
message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
message = MinimaxMessage(content=response_data["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
message.usage = {
"prompt_tokens": 0,
"completion_tokens": response["usage"]["total_tokens"],
"total_tokens": response["usage"]["total_tokens"],
"completion_tokens": response_data["usage"]["total_tokens"],
"total_tokens": response_data["usage"]["total_tokens"],
}
message.stop_reason = response["choices"][0]["finish_reason"]
message.stop_reason = response_data["choices"][0]["finish_reason"]
return message
def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]:
@@ -144,10 +144,10 @@ class MinimaxChatCompletionPro:
for line in response.iter_lines():
if not line:
continue
line: str = line.decode("utf-8")
if line.startswith("data: "):
line = line[6:].strip()
data = loads(line)
line_str: str = line.decode("utf-8")
if line_str.startswith("data: "):
line_str = line_str[6:].strip()
data = loads(line_str)
if "base_resp" in data and data["base_resp"]["status_code"] != 0:
code = data["base_resp"]["status_code"]

View File

@@ -11,9 +11,9 @@ class MinimaxMessage:
role: str = Role.USER.value
content: str
usage: dict[str, int] = None
usage: dict[str, int] | None = None
stop_reason: str = ""
function_call: dict[str, Any] = None
function_call: dict[str, Any] | None = None
def to_dict(self) -> dict[str, Any]:
if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:
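
With implicit Optional disabled (PEP 484 semantics), `dict[str, int] = None` is an invalid default, hence the `| None` unions. A minimal sketch (illustrative dataclass standing in for MinimaxMessage):

from dataclasses import dataclass
from typing import Any

@dataclass
class Message:
    content: str
    # usage: dict[str, int] = None  # rejected once implicit Optional is off
    usage: dict[str, int] | None = None  # explicit Optional is required
    function_call: dict[str, Any] | None = None

print(Message(content="hi").usage)  # None until the API reports token counts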

View File

@@ -2,8 +2,8 @@ import time
from functools import wraps
from typing import Optional
from nomic import embed
from nomic import login as nomic_login
from nomic import embed # type: ignore
from nomic import login as nomic_login # type: ignore
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType

View File

@@ -5,8 +5,8 @@ import logging
from collections.abc import Generator
from typing import Optional, Union
import oci
from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse
import oci # type: ignore
from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse # type: ignore
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (

View File

@@ -4,7 +4,7 @@ import time
from typing import Optional
import numpy as np
import oci
import oci # type: ignore
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType

View File

@@ -61,6 +61,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
headers = {"Content-Type": "application/json"}
endpoint_url = credentials.get("base_url")
assert endpoint_url is not None, "Base URL is required for Ollama API"
if not endpoint_url.endswith("/"):
endpoint_url += "/"
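
The added `assert` narrows the `Optional[str]` returned by `credentials.get(...)` to `str` before string methods are used on it. A minimal sketch of assert-based narrowing:

def build_url(credentials: dict[str, str]) -> str:
    endpoint_url = credentials.get("base_url")  # inferred as str | None
    assert endpoint_url is not None, "Base URL is required"
    # after the assert, mypy narrows endpoint_url to str
    if not endpoint_url.endswith("/"):
        endpoint_url += "/"
    return endpoint_url

print(build_url({"base_url": "http://localhost:11434"}))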

View File

@@ -1,5 +1,3 @@
from collections.abc import Mapping
import openai
from httpx import Timeout
@@ -14,7 +12,7 @@ from core.model_runtime.errors.invoke import (
class _CommonOpenAI:
def _to_credential_kwargs(self, credentials: Mapping) -> dict:
def _to_credential_kwargs(self, credentials: dict) -> dict:
"""
Transform credentials to kwargs for model instance

View File

@@ -93,7 +93,8 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel):
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK]
max_characters_per_chunk: int = model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK]
return max_characters_per_chunk
return 2000
@@ -108,6 +109,7 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel):
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
return max_chunks
return 1

View File

@@ -1,5 +1,4 @@
import logging
from collections.abc import Mapping
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -9,7 +8,7 @@ logger = logging.getLogger(__name__)
class OpenAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: Mapping) -> None:
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception

View File

@@ -33,6 +33,7 @@ class OAICompatSpeech2TextModel(_CommonOaiApiCompat, Speech2TextModel):
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials.get("endpoint_url")
assert endpoint_url is not None, "endpoint_url is required in credentials"
if not endpoint_url.endswith("/"):
endpoint_url += "/"
endpoint_url = urljoin(endpoint_url, "audio/transcriptions")

View File

@@ -55,6 +55,7 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials.get("endpoint_url")
assert endpoint_url is not None, "endpoint_url is required in credentials"
if not endpoint_url.endswith("/"):
endpoint_url += "/"

View File

@@ -44,6 +44,7 @@ class OAICompatText2SpeechModel(_CommonOaiApiCompat, TTSModel):
# Construct endpoint URL
endpoint_url = credentials.get("endpoint_url")
assert endpoint_url is not None, "endpoint_url is required in credentials"
if not endpoint_url.endswith("/"):
endpoint_url += "/"
endpoint_url = urljoin(endpoint_url, "audio/speech")

View File

@@ -1,7 +1,7 @@
from collections.abc import Generator
from enum import Enum
from json import dumps, loads
from typing import Any, Union
from typing import Any, Optional, Union
from requests import Response, post
from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
@@ -20,7 +20,7 @@ class OpenLLMGenerateMessage:
role: str = Role.USER.value
content: str
usage: dict[str, int] = None
usage: Optional[dict[str, int]] = None
stop_reason: str = ""
def to_dict(self) -> dict[str, Any]:
@@ -165,17 +165,17 @@ class OpenLLMGenerate:
if not line:
continue
line: str = line.decode("utf-8")
if line.startswith("data: "):
line = line[6:].strip()
line_str: str = line.decode("utf-8")
if line_str.startswith("data: "):
line_str = line_str[6:].strip()
if line == "[DONE]":
if line_str == "[DONE]":
return
try:
data = loads(line)
data = loads(line_str)
except Exception as e:
raise InternalServerError(f"Failed to convert response to json: {e} with text: {line}")
raise InternalServerError(f"Failed to convert response to json: {e} with text: {line_str}")
output = data["outputs"]

View File

@@ -53,14 +53,16 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
api_key = credentials.get("api_key")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url: Optional[str]
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
endpoint_url = "https://cloud.perfxlab.cn/v1/"
else:
endpoint_url = credentials.get("endpoint_url")
assert endpoint_url is not None, "endpoint_url is required in credentials"
if not endpoint_url.endswith("/"):
endpoint_url += "/"
assert isinstance(endpoint_url, str)
endpoint_url = urljoin(endpoint_url, "embeddings")
extra_model_kwargs = {}
@@ -142,13 +144,16 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url: Optional[str]
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
endpoint_url = "https://cloud.perfxlab.cn/v1/"
else:
endpoint_url = credentials.get("endpoint_url")
assert endpoint_url is not None, "endpoint_url is required in credentials"
if not endpoint_url.endswith("/"):
endpoint_url += "/"
assert isinstance(endpoint_url, str)
endpoint_url = urljoin(endpoint_url, "embeddings")
payload = {"input": "ping", "model": model}

View File

@@ -1,4 +1,4 @@
from replicate.exceptions import ModelError, ReplicateError
from replicate.exceptions import ModelError, ReplicateError # type: ignore
from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError

View File

@@ -1,9 +1,9 @@
from collections.abc import Generator
from typing import Optional, Union
from replicate import Client as ReplicateClient
from replicate.exceptions import ReplicateError
from replicate.prediction import Prediction
from replicate import Client as ReplicateClient # type: ignore
from replicate.exceptions import ReplicateError # type: ignore
from replicate.prediction import Prediction # type: ignore
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta

View File

@@ -2,11 +2,11 @@ import json
import time
from typing import Optional
from replicate import Client as ReplicateClient
from replicate import Client as ReplicateClient # type: ignore
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
@@ -86,7 +86,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
label=I18nObject(en_US=model),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={"context_size": 4096, "max_chunks": 1},
model_properties={ModelPropertyKey.CONTEXT_SIZE: 4096, ModelPropertyKey.MAX_CHUNKS: 1},
)
return entity

View File

@@ -4,7 +4,7 @@ import re
from collections.abc import Generator, Iterator
from typing import Any, Optional, Union, cast
import boto3
import boto3 # type: ignore
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
@@ -83,7 +83,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
sagemaker_session: Any = None
predictor: Any = None
sagemaker_endpoint: str = None
sagemaker_endpoint: str | None = None
def _handle_chat_generate_response(
self,
@@ -209,8 +209,8 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
:param user: unique user id
:return: full response or stream response chunk generator result
"""
from sagemaker import Predictor, serializers
from sagemaker.session import Session
from sagemaker import Predictor, serializers # type: ignore
from sagemaker.session import Session # type: ignore
if not self.sagemaker_session:
access_key = credentials.get("aws_access_key_id")

View File

@@ -3,7 +3,7 @@ import logging
import operator
from typing import Any, Optional
import boto3
import boto3 # type: ignore
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
@@ -114,6 +114,7 @@ class SageMakerRerankModel(RerankModel):
except Exception as e:
logger.exception(f"Failed to invoke rerank model, model: {model}")
raise InvokeError(f"Failed to invoke rerank model, model: {model}, error: {str(e)}")
def validate_credentials(self, model: str, credentials: dict) -> None:
"""

View File

@@ -2,7 +2,7 @@ import json
import logging
from typing import IO, Any, Optional
import boto3
import boto3 # type: ignore
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
@@ -67,6 +67,7 @@ class SageMakerSpeech2TextModel(Speech2TextModel):
s3_prefix = "dify/speech2text/"
sagemaker_endpoint = credentials.get("sagemaker_endpoint")
bucket = credentials.get("audio_s3_cache_bucket")
assert bucket is not None, "audio_s3_cache_bucket is required in credentials"
s3_presign_url = generate_presigned_url(self.s3_client, file, bucket, s3_prefix)
payload = {"audio_s3_presign_uri": s3_presign_url}

View File

@@ -4,7 +4,7 @@ import logging
import time
from typing import Any, Optional
import boto3
import boto3 # type: ignore
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
@@ -118,6 +118,7 @@ class SageMakerEmbeddingModel(TextEmbeddingModel):
except Exception as e:
logger.exception(f"Failed to invoke text embedding model, model: {model}, line: {line}")
raise InvokeError(str(e))
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""

View File

@@ -5,7 +5,7 @@ import logging
from enum import Enum
from typing import Any, Optional
import boto3
import boto3 # type: ignore
import requests
from core.model_runtime.entities.common_entities import I18nObject

View File

@@ -43,7 +43,7 @@ class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel):
credentials["mode"] = "chat"
credentials["endpoint_url"] = "https://api.siliconflow.cn/v1"
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
return AIModelEntity(
model=model,
label=I18nObject(en_US=model, zh_Hans=model),

View File

@@ -1,6 +1,6 @@
import threading
from collections.abc import Generator
from typing import Optional, Union
from typing import Optional, Union, cast
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
@@ -270,7 +270,7 @@ class SparkLargeLanguageModel(LargeLanguageModel):
elif isinstance(message, AssistantPromptMessage):
message_text = f"{ai_prompt} {content}"
elif isinstance(message, SystemPromptMessage):
message_text = content
message_text = cast(str, content)
else:
raise ValueError(f"Got unknown type {message}")

View File

@@ -12,6 +12,7 @@ from core.model_runtime.entities.model_entities import (
AIModelEntity,
DefaultParameterName,
FetchFrom,
ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
@@ -67,7 +68,7 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
REPETITION_PENALTY = "repetition_penalty"
TOP_K = "top_k"
features = []
features: list[ModelFeature] = []
entity = AIModelEntity(
model=model,

View File

@@ -1,4 +1,4 @@
from dashscope.common.error import (
from dashscope.common.error import ( # type: ignore
AuthenticationError,
InvalidParameter,
RequestFailure,

View File

@@ -7,9 +7,9 @@ from http import HTTPStatus
from pathlib import Path
from typing import Optional, Union, cast
from dashscope import Generation, MultiModalConversation, get_tokenizer
from dashscope.api_entities.dashscope_response import GenerationResponse
from dashscope.common.error import (
from dashscope import Generation, MultiModalConversation, get_tokenizer # type: ignore
from dashscope.api_entities.dashscope_response import GenerationResponse # type: ignore
from dashscope.common.error import ( # type: ignore
AuthenticationError,
InvalidParameter,
RequestFailure,

View File

@@ -1,7 +1,7 @@
from typing import Optional
import dashscope
from dashscope.common.error import (
import dashscope # type: ignore
from dashscope.common.error import ( # type: ignore
AuthenticationError,
InvalidParameter,
RequestFailure,
@@ -51,7 +51,7 @@ class GTERerankModel(RerankModel):
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(model=model, docs=docs)
return RerankResult(model=model, docs=[])
# initialize client
dashscope.api_key = credentials["dashscope_api_key"]
@@ -64,7 +64,7 @@ class GTERerankModel(RerankModel):
return_documents=True,
)
rerank_documents = []
rerank_documents: list[RerankDocument] = []
if not response.output:
return RerankResult(model=model, docs=rerank_documents)
for _, result in enumerate(response.output.results):

View File

@@ -1,7 +1,7 @@
import time
from typing import Optional
import dashscope
import dashscope # type: ignore
import numpy as np
from core.entities.embedding_type import EmbeddingInputType

View File

@@ -2,10 +2,10 @@ import threading
from queue import Queue
from typing import Any, Optional
import dashscope
from dashscope import SpeechSynthesizer
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult
import dashscope # type: ignore
from dashscope import SpeechSynthesizer # type: ignore
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse # type: ignore
from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult # type: ignore
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError

View File

@@ -1,5 +1,3 @@
from collections.abc import Mapping
import openai
from httpx import Timeout
@@ -14,7 +12,7 @@ from core.model_runtime.errors.invoke import (
class _CommonUpstage:
def _to_credential_kwargs(self, credentials: Mapping) -> dict:
def _to_credential_kwargs(self, credentials: dict) -> dict:
"""
Transform credentials to kwargs for model instance

View File

@@ -6,7 +6,7 @@ from openai import OpenAI, Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
from openai.types.chat.chat_completion_message import FunctionCall
from tokenizers import Tokenizer
from tokenizers import Tokenizer # type: ignore
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta

View File

@@ -1,11 +1,10 @@
import base64
import time
from collections.abc import Mapping
from typing import Union
import numpy as np
from openai import OpenAI
from tokenizers import Tokenizer
from tokenizers import Tokenizer # type: ignore
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
@@ -132,7 +131,7 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel):
return total_num_tokens
def validate_credentials(self, model: str, credentials: Mapping) -> None:
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials

View File

@@ -12,4 +12,4 @@ class _CommonVertexAi:
:return: Invoke error mapping
"""
pass
raise NotImplementedError

View File

@@ -6,7 +6,7 @@ import time
from collections.abc import Generator
from typing import TYPE_CHECKING, Optional, Union, cast
import google.auth.transport.requests
import google.auth.transport.requests # type: ignore
import requests
from anthropic import AnthropicVertex, Stream
from anthropic.types import (

View File

@@ -17,14 +17,12 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI
class VesslAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
features = []
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
features=features,
features=[],
model_properties={
ModelPropertyKey.MODE: credentials.get("mode"),
},

View File

@@ -1,8 +1,8 @@
from collections.abc import Generator
from typing import Optional, cast
from volcenginesdkarkruntime import Ark
from volcenginesdkarkruntime.types.chat import (
from volcenginesdkarkruntime import Ark # type: ignore
from volcenginesdkarkruntime.types.chat import ( # type: ignore
ChatCompletion,
ChatCompletionAssistantMessageParam,
ChatCompletionChunk,
@@ -15,10 +15,10 @@ from volcenginesdkarkruntime.types.chat import (
ChatCompletionToolParam,
ChatCompletionUserMessageParam,
)
from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL
from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function
from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse
from volcenginesdkarkruntime.types.shared_params import FunctionDefinition
from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL # type: ignore
from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function # type: ignore
from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse # type: ignore
from volcenginesdkarkruntime.types.shared_params import FunctionDefinition # type: ignore
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,

View File

@@ -152,5 +152,6 @@ ErrorCodeMap = {
def wrap_error(e: MaasError) -> Exception:
if ErrorCodeMap.get(e.code):
return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id)
# FIXME: mypy type error; fix it properly instead of suppressing it with type: ignore
return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id) # type: ignore
return e

View File

@@ -2,7 +2,7 @@ import logging
from collections.abc import Generator
from typing import Optional
from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk
from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk # type: ignore
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta

View File

@@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMMode
@@ -102,7 +104,7 @@ def get_model_config(credentials: dict) -> ModelConfig:
def get_v2_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None):
req_params = {}
req_params: dict[str, Any] = {}
# predefined properties
model_configs = get_model_config(credentials)
if model_configs:
@@ -130,7 +132,7 @@ def get_v2_req_params(credentials: dict, model_parameters: dict, stop: list[str]
def get_v3_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None):
req_params = {}
req_params: dict[str, Any] = {}
# predefined properties
model_configs = get_model_config(credentials)
if model_configs:

View File

@@ -1,7 +1,7 @@
from collections.abc import Generator
from enum import Enum
from json import dumps, loads
from typing import Any, Union
from typing import Any, Optional, Union
from requests import Response, post
@@ -22,7 +22,7 @@ class ErnieMessage:
role: str = Role.USER.value
content: str
usage: dict[str, int] = None
usage: Optional[dict[str, int]] = None
stop_reason: str = ""
def to_dict(self) -> dict[str, Any]:
@@ -135,6 +135,7 @@ class ErnieBotModel(_CommonWenxin):
"""
TODO: implement function calling
"""
raise NotImplementedError("Function calling is not supported yet.")
def _build_chat_request_body(
self,

View File

@@ -1,6 +1,5 @@
import time
from abc import abstractmethod
from collections.abc import Mapping
from json import dumps
from typing import Any, Optional
@@ -23,12 +22,12 @@ from core.model_runtime.model_providers.wenxin.wenxin_errors import (
class TextEmbedding:
@abstractmethod
def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]:
raise NotImplementedError
class WenxinTextEmbedding(_CommonWenxin, TextEmbedding):
def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]:
access_token = self._get_access_token()
url = f"{self.api_bases[model]}?access_token={access_token}"
body = self._build_embed_request_body(model, texts, user)
@@ -50,7 +49,7 @@ class WenxinTextEmbedding(_CommonWenxin, TextEmbedding):
}
return body
def _handle_embed_response(self, model: str, response: Response) -> (list[list[float]], int, int):
def _handle_embed_response(self, model: str, response: Response) -> tuple[list[list[float]], int, int]:
data = response.json()
if "error_code" in data:
code = data["error_code"]
@@ -147,7 +146,7 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel):
return total_num_tokens
def validate_credentials(self, model: str, credentials: Mapping) -> None:
def validate_credentials(self, model: str, credentials: dict) -> None:
api_key = credentials["api_key"]
secret_key = credentials["secret_key"]
try:
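
The old annotation `-> (list[list[float]], int, int)` is a tuple of type objects, not a type, so mypy rejects it; `tuple[list[list[float]], int, int]` is the valid spelling. A minimal sketch (illustrative embedding logic):

def embed_documents(texts: list[str]) -> tuple[list[list[float]], int, int]:
    # returns (embeddings, used_tokens, total_tokens)
    embeddings = [[0.0, 0.0, 0.0] for _ in texts]
    tokens = sum(len(t.split()) for t in texts)
    return embeddings, tokens, tokens

vectors, used, total = embed_documents(["hello world"])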

View File

@@ -17,7 +17,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletio
from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
from openai.types.chat.chat_completion_message import FunctionCall
from openai.types.completion import Completion
from xinference_client.client.restful.restful_client import (
from xinference_client.client.restful.restful_client import ( # type: ignore
Client,
RESTfulChatModelHandle,
RESTfulGenerateModelHandle,

View File

@@ -1,6 +1,6 @@
from typing import Optional
from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle
from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle # type: ignore
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType

View File

@@ -1,6 +1,6 @@
from typing import IO, Optional
from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle
from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle # type: ignore
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType

View File

@@ -1,7 +1,7 @@
import time
from typing import Optional
from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle
from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle # type: ignore
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
@@ -134,7 +134,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
try:
handle = client.get_model(model_uid=model_uid)
except RuntimeError as e:
raise InvokeAuthorizationError(e)
raise InvokeAuthorizationError(str(e))
if not isinstance(handle, RESTfulEmbeddingModelHandle):
raise InvokeBadRequestError(

View File

@@ -1,7 +1,7 @@
import concurrent.futures
from typing import Any, Optional
from xinference_client.client.restful.restful_client import RESTfulAudioModelHandle
from xinference_client.client.restful.restful_client import RESTfulAudioModelHandle # type: ignore
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
@@ -74,11 +74,14 @@ class XinferenceText2SpeechModel(TTSModel):
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
credentials["server_url"] = credentials["server_url"].removesuffix("/")
api_key = credentials.get("api_key")
if api_key is None:
raise CredentialsValidateFailedError("api_key is required")
extra_param = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials["server_url"],
model_uid=credentials["model_uid"],
api_key=credentials.get("api_key"),
api_key=api_key,
)
if "text-to-audio" not in extra_param.model_ability:

View File

@@ -1,6 +1,6 @@
from threading import Lock
from time import time
from typing import Optional
from typing import Any, Optional
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, MissingSchema, Timeout
@@ -39,13 +39,15 @@ class XinferenceModelExtraParameter:
self.model_family = model_family
cache = {}
cache: dict[str, dict[str, Any]] = {}
cache_lock = Lock()
class XinferenceHelper:
@staticmethod
def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter:
def get_xinference_extra_parameter(
server_url: str, model_uid: str, api_key: str | None
) -> XinferenceModelExtraParameter:
XinferenceHelper._clean_cache()
with cache_lock:
if model_uid not in cache:
@@ -66,7 +68,9 @@ class XinferenceHelper:
pass
@staticmethod
def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter:
def _get_xinference_extra_parameter(
server_url: str, model_uid: str, api_key: str | None
) -> XinferenceModelExtraParameter:
"""
get xinference model extra parameter like model_format and model_handle_type
"""

View File

@@ -136,7 +136,7 @@ class YiLargeLanguageModel(OpenAILargeLanguageModel):
parsed_url = urlparse(credentials["endpoint_url"])
credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}"
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
return AIModelEntity(
model=model,
label=I18nObject(en_US=model, zh_Hans=model),

View File

@@ -1,9 +1,9 @@
from collections.abc import Generator
from typing import Optional, Union
from zhipuai import ZhipuAI
from zhipuai.types.chat.chat_completion import Completion
from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk
from zhipuai import ZhipuAI # type: ignore
from zhipuai.types.chat.chat_completion import Completion # type: ignore
from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk # type: ignore
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (

View File

@@ -1,7 +1,7 @@
import time
from typing import Optional
from zhipuai import ZhipuAI
from zhipuai import ZhipuAI # type: ignore
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Union, cast
from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType
@@ -38,7 +38,7 @@ class CommonValidator:
def _validate_credential_form_schema(
self, credential_form_schema: CredentialFormSchema, credentials: dict
) -> Optional[str]:
) -> Union[str, bool, None]:
"""
Validate credential form schema
@@ -47,6 +47,7 @@ class CommonValidator:
:return: validated credential form schema value
"""
# If the variable does not exist in credentials
value: Union[str, bool, None] = None
if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]:
# If required is True, an exception is thrown
if credential_form_schema.required:
@@ -61,7 +62,7 @@ class CommonValidator:
return None
# Get the value corresponding to the variable from credentials
value = credentials[credential_form_schema.variable]
value = cast(str, credentials[credential_form_schema.variable])
# If max_length=0, no validation is performed
if credential_form_schema.max_length:

View File

@@ -129,7 +129,8 @@ def jsonable_encoder(
sqlalchemy_safe=sqlalchemy_safe,
)
if dataclasses.is_dataclass(obj):
obj_dict = dataclasses.asdict(obj)
# FIXME: mypy error; fix it properly instead of suppressing it with type: ignore
obj_dict = dataclasses.asdict(obj) # type: ignore
return jsonable_encoder(
obj_dict,
by_alias=by_alias,

View File

@@ -4,6 +4,7 @@ from pydantic import BaseModel
def dump_model(model: BaseModel) -> dict:
if hasattr(pydantic, "model_dump"):
return pydantic.model_dump(model)
# FIXME: mypy error; fix it properly instead of suppressing it with type: ignore
return pydantic.model_dump(model) # type: ignore
else:
return model.model_dump()