mirror of http://112.124.100.131/huang.ze/ebiz-dify-ai.git (synced 2025-12-15 22:06:52 +08:00)
feat: mypy for all type check (#10921)
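Most hunks below repeat a small set of mypy-driven fixes: implicit-Optional defaults made explicit, untyped empty containers annotated, `# type: ignore` added to imports of untyped third-party packages, and `cast()` used where mypy cannot narrow a union. A minimal sketch of the recurring before/after patterns (illustrative only, not code taken verbatim from this commit):

    from typing import Any, Optional, cast

    # Before: a bare None default on a dict-typed attribute fails mypy's
    # implicit-Optional check:
    #     usage: dict[str, int] = None
    # After: the None default is spelled out in the type.
    usage: Optional[dict[str, int]] = None

    # Before: an empty literal gives mypy nothing to infer from:
    #     extra_kwargs = {}
    # After: the annotation allows heterogeneous values to be assigned later.
    extra_kwargs: dict[str, Any] = {}

    def first_text(content: str | list[str]) -> str:
        # cast() asserts to the checker that content is a str here;
        # it has no runtime effect.
        return cast(str, content)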
@@ -1,7 +1,8 @@
 import json
 import logging
 import sys
-from typing import Optional
+from collections.abc import Sequence
+from typing import Optional, cast

 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@@ -20,7 +21,7 @@ class LoggingCallback(Callback):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
@@ -76,7 +77,7 @@ class LoggingCallback(Callback):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ):
@@ -94,7 +95,7 @@ class LoggingCallback(Callback):
         :param stream: is stream response
         :param user: unique user id
         """
-        sys.stdout.write(chunk.delta.message.content)
+        sys.stdout.write(cast(str, chunk.delta.message.content))
         sys.stdout.flush()

     def on_after_invoke(
@@ -106,7 +107,7 @@ class LoggingCallback(Callback):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
@@ -147,7 +148,7 @@ class LoggingCallback(Callback):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
@@ -3,7 +3,7 @@ from collections.abc import Sequence
 from enum import Enum, StrEnum
 from typing import Optional

-from pydantic import BaseModel, Field, computed_field, field_validator
+from pydantic import BaseModel, Field, field_validator


 class PromptMessageRole(Enum):
@@ -89,7 +89,6 @@ class MultiModalPromptMessageContent(PromptMessageContent):
     url: str = Field(default="", description="the url of multi-modal file")
     mime_type: str = Field(default=..., description="the mime type of multi-modal file")

-    @computed_field(return_type=str)
     @property
     def data(self):
         return self.url or f"data:{self.mime_type};base64,{self.base64_data}"
@@ -1,7 +1,6 @@
 import decimal
 import os
 from abc import ABC, abstractmethod
-from collections.abc import Mapping
 from typing import Optional

 from pydantic import ConfigDict
@@ -36,7 +35,7 @@ class AIModel(ABC):
     model_config = ConfigDict(protected_namespaces=())

     @abstractmethod
-    def validate_credentials(self, model: str, credentials: Mapping) -> None:
+    def validate_credentials(self, model: str, credentials: dict) -> None:
         """
         Validate model credentials

@@ -214,7 +213,7 @@ class AIModel(ABC):

         return model_schemas

-    def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> Optional[AIModelEntity]:
+    def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
         """
         Get model schema by model name and credentials

@@ -236,9 +235,7 @@ class AIModel(ABC):

         return None

-    def get_customizable_model_schema_from_credentials(
-        self, model: str, credentials: Mapping
-    ) -> Optional[AIModelEntity]:
+    def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
         """
         Get customizable model schema from credentials

@@ -248,7 +245,7 @@ class AIModel(ABC):
         """
         return self._get_customizable_model_schema(model, credentials)

-    def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
+    def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
         """
         Get customizable model schema and fill in the template
         """
@@ -301,7 +298,7 @@ class AIModel(ABC):

         return schema

-    def get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
         """
         Get customizable model schema

@@ -2,7 +2,7 @@ import logging
 import re
 import time
 from abc import abstractmethod
-from collections.abc import Generator, Mapping, Sequence
+from collections.abc import Generator, Sequence
 from typing import Optional, Union

 from pydantic import ConfigDict
@@ -48,7 +48,7 @@ class LargeLanguageModel(AIModel):
         prompt_messages: list[PromptMessage],
         model_parameters: Optional[dict] = None,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[Sequence[str]] = None,
+        stop: Optional[list[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
@@ -291,12 +291,12 @@ if you are not sure about the structure.
                 content = piece.delta.message.content
                 piece.delta.message.content = ""
                 yield piece
-                piece = content
+                content_piece = content
             else:
                 yield piece
                 continue
             new_piece: str = ""
-            for char in piece:
+            for char in content_piece:
                 char = str(char)
                 if state == "normal":
                     if char == "`":
@@ -350,7 +350,7 @@ if you are not sure about the structure.
                 piece.delta.message.content = ""
                 # Yield a piece with cleared content before processing it to maintain the generator structure
                 yield piece
-                piece = content
+                content_piece = content
             else:
                 # Yield pieces without content directly
                 yield piece
@@ -360,7 +360,7 @@ if you are not sure about the structure.
                 continue

             new_piece: str = ""
-            for char in piece:
+            for char in content_piece:
                 if state == "search_start":
                     if char == "`":
                         backtick_count += 1
@@ -535,7 +535,7 @@ if you are not sure about the structure.

         return []

-    def get_model_mode(self, model: str, credentials: Optional[Mapping] = None) -> LLMMode:
+    def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode:
         """
         Get model mode

@@ -104,9 +104,10 @@ class ModelProvider(ABC):
         mod = import_module_from_source(
             module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path
         )
+        # FIXME "type" has no attribute "__abstractmethods__" ignore it for now fix it later
         model_class = next(
             filter(
-                lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
+                lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,  # type: ignore
                 get_subclasses_from_module(mod, AIModel),
             ),
             None,

@@ -89,7 +89,8 @@ class TextEmbeddingModel(AIModel):
         model_schema = self.get_model_schema(model, credentials)

         if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties:
-            return model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE]
+            content_size: int = model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE]
+            return content_size

         return 1000

@@ -104,6 +105,7 @@ class TextEmbeddingModel(AIModel):
         model_schema = self.get_model_schema(model, credentials)

         if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties:
-            return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
+            max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
+            return max_chunks

         return 1
@@ -2,9 +2,9 @@ from os.path import abspath, dirname, join
 from threading import Lock
 from typing import Any

-from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer
+from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer  # type: ignore

-_tokenizer = None
+_tokenizer: Any = None
 _lock = Lock()


@@ -127,7 +127,8 @@ class TTSModel(AIModel):
         if not model_schema or ModelPropertyKey.AUDIO_TYPE not in model_schema.model_properties:
             raise ValueError("this model does not support audio type")

-        return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
+        audio_type: str = model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
+        return audio_type

     def _get_model_word_limit(self, model: str, credentials: dict) -> int:
         """
@@ -138,8 +139,9 @@ class TTSModel(AIModel):

         if not model_schema or ModelPropertyKey.WORD_LIMIT not in model_schema.model_properties:
             raise ValueError("this model does not support word limit")
+        world_limit: int = model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]

-        return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
+        return world_limit

     def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
         """
@@ -150,8 +152,9 @@ class TTSModel(AIModel):

         if not model_schema or ModelPropertyKey.MAX_WORKERS not in model_schema.model_properties:
             raise ValueError("this model does not support max workers")
+        workers_limit: int = model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]

-        return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
+        return workers_limit

     @staticmethod
     def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"):
@@ -64,10 +64,12 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):

     def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
         ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model)
+        if not ai_model_entity:
+            return None
         return ai_model_entity.entity

     @staticmethod
-    def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
+    def _get_ai_model_entity(base_model_name: str, model: str) -> Optional[AzureBaseModel]:
         for ai_model_entity in SPEECH2TEXT_BASE_MODELS:
             if ai_model_entity.base_model_name == base_model_name:
                 ai_model_entity_copy = copy.deepcopy(ai_model_entity)

@@ -114,6 +114,8 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):

     def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
         ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model)
+        if not ai_model_entity:
+            return None
         return ai_model_entity.entity

     @staticmethod

@@ -6,9 +6,9 @@ from collections.abc import Generator
 from typing import Optional, Union, cast

 # 3rd import
-import boto3
-from botocore.config import Config
-from botocore.exceptions import (
+import boto3  # type: ignore
+from botocore.config import Config  # type: ignore
+from botocore.exceptions import (  # type: ignore
     ClientError,
     EndpointConnectionError,
     NoRegionError,
@@ -44,7 +44,7 @@ class CohereRerankModel(RerankModel):
         :return: rerank result
         """
         if len(docs) == 0:
-            return RerankResult(model=model, docs=docs)
+            return RerankResult(model=model, docs=[])

         # initialize client
         client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url"))
@@ -62,7 +62,7 @@ class CohereRerankModel(RerankModel):
             # format document
             rerank_document = RerankDocument(
                 index=result.index,
-                text=result.document.text,
+                text=result.document.text if result.document else "",
                 score=result.relevance_score,
             )


@@ -1,5 +1,3 @@
-from collections.abc import Mapping
-
 import openai

 from core.model_runtime.errors.invoke import (
@@ -13,7 +11,7 @@ from core.model_runtime.errors.invoke import (


 class _CommonFireworks:
-    def _to_credential_kwargs(self, credentials: Mapping) -> dict:
+    def _to_credential_kwargs(self, credentials: dict) -> dict:
         """
         Transform credentials to kwargs for model instance

@@ -1,5 +1,4 @@
 import time
-from collections.abc import Mapping
 from typing import Optional, Union

 import numpy as np
@@ -93,7 +92,7 @@ class FireworksTextEmbeddingModel(_CommonFireworks, TextEmbeddingModel):
         """
         return sum(self._get_num_tokens_by_gpt2(text) for text in texts)

-    def validate_credentials(self, model: str, credentials: Mapping) -> None:
+    def validate_credentials(self, model: str, credentials: dict) -> None:
         """
         Validate model credentials


@@ -1,4 +1,4 @@
-from dashscope.common.error import (
+from dashscope.common.error import (  # type: ignore
     AuthenticationError,
     InvalidParameter,
     RequestFailure,

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Any, Optional

 import httpx

@@ -51,7 +51,7 @@ class GiteeAIRerankModel(RerankModel):
             base_url = base_url.removesuffix("/")

         try:
-            body = {"model": model, "query": query, "documents": docs}
+            body: dict[str, Any] = {"model": model, "query": query, "documents": docs}
             if top_n is not None:
                 body["top_n"] = top_n
             response = httpx.post(
@@ -24,7 +24,7 @@ class GiteeAIEmbeddingModel(OAICompatEmbeddingModel):
         super().validate_credentials(model, credentials)

     @staticmethod
-    def _add_custom_parameters(credentials: dict, model: str) -> None:
+    def _add_custom_parameters(credentials: dict, model: Optional[str]) -> None:
         if model is None:
             model = "bge-m3"


@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Any, Optional

 import requests

@@ -13,9 +13,10 @@ class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel):
     Model class for OpenAI text2speech model.
     """

+    # FIXME this Any return will be better type
     def _invoke(
         self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
-    ) -> any:
+    ) -> Any:
         """
         _invoke text2speech model

@@ -47,7 +48,8 @@ class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel):
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))

-    def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
+    # FIXME this Any return will be better type
+    def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> Any:
         """
         _tts_invoke_streaming text2speech model
         :param model: model name
@@ -7,7 +7,7 @@ from collections.abc import Generator
 from typing import Optional, Union

 import google.ai.generativelanguage as glm
-import google.generativeai as genai
+import google.generativeai as genai  # type: ignore
 import requests
 from google.api_core import exceptions
 from google.generativeai.types import ContentType, File, GenerateContentResponse

@@ -1,4 +1,4 @@
-from huggingface_hub.utils import BadRequestError, HfHubHTTPError
+from huggingface_hub.utils import BadRequestError, HfHubHTTPError  # type: ignore

 from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError


@@ -1,9 +1,9 @@
 from collections.abc import Generator
 from typing import Optional, Union

-from huggingface_hub import InferenceClient
-from huggingface_hub.hf_api import HfApi
-from huggingface_hub.utils import BadRequestError
+from huggingface_hub import InferenceClient  # type: ignore
+from huggingface_hub.hf_api import HfApi  # type: ignore
+from huggingface_hub.utils import BadRequestError  # type: ignore

 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE

@@ -4,7 +4,7 @@ from typing import Optional

 import numpy as np
 import requests
-from huggingface_hub import HfApi, InferenceClient
+from huggingface_hub import HfApi, InferenceClient  # type: ignore

 from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
@@ -3,11 +3,11 @@ import logging
 from collections.abc import Generator
 from typing import cast

-from tencentcloud.common import credential
-from tencentcloud.common.exception import TencentCloudSDKException
-from tencentcloud.common.profile.client_profile import ClientProfile
-from tencentcloud.common.profile.http_profile import HttpProfile
-from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
+from tencentcloud.common import credential  # type: ignore
+from tencentcloud.common.exception import TencentCloudSDKException  # type: ignore
+from tencentcloud.common.profile.client_profile import ClientProfile  # type: ignore
+from tencentcloud.common.profile.http_profile import HttpProfile  # type: ignore
+from tencentcloud.hunyuan.v20230901 import hunyuan_client, models  # type: ignore

 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
@@ -305,7 +305,7 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
         elif isinstance(message, ToolPromptMessage):
             message_text = f"{tool_prompt} {content}"
         elif isinstance(message, SystemPromptMessage):
-            message_text = content
+            message_text = content if isinstance(content, str) else ""
         else:
             raise ValueError(f"Got unknown type {message}")


@@ -3,11 +3,11 @@ import logging
 import time
 from typing import Optional

-from tencentcloud.common import credential
-from tencentcloud.common.exception import TencentCloudSDKException
-from tencentcloud.common.profile.client_profile import ClientProfile
-from tencentcloud.common.profile.http_profile import HttpProfile
-from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
+from tencentcloud.common import credential  # type: ignore
+from tencentcloud.common.exception import TencentCloudSDKException  # type: ignore
+from tencentcloud.common.profile.client_profile import ClientProfile  # type: ignore
+from tencentcloud.common.profile.http_profile import HttpProfile  # type: ignore
+from tencentcloud.hunyuan.v20230901 import hunyuan_client, models  # type: ignore

 from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType

@@ -1,11 +1,11 @@
 from os.path import abspath, dirname, join
 from threading import Lock

-from transformers import AutoTokenizer
+from transformers import AutoTokenizer  # type: ignore


 class JinaTokenizer:
-    _tokenizer = None
+    _tokenizer: AutoTokenizer | None = None
     _lock = Lock()

     @classmethod
@@ -40,7 +40,7 @@ class MinimaxChatCompletion:

         url = f"https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}"

-        extra_kwargs = {}
+        extra_kwargs: dict[str, Any] = {}

         if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int:
             extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"]
@@ -117,19 +117,19 @@ class MinimaxChatCompletion:
         """
         handle chat generate response
         """
-        response = response.json()
-        if "base_resp" in response and response["base_resp"]["status_code"] != 0:
-            code = response["base_resp"]["status_code"]
-            msg = response["base_resp"]["status_msg"]
+        response_data = response.json()
+        if "base_resp" in response_data and response_data["base_resp"]["status_code"] != 0:
+            code = response_data["base_resp"]["status_code"]
+            msg = response_data["base_resp"]["status_msg"]
             self._handle_error(code, msg)

-        message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
+        message = MinimaxMessage(content=response_data["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
         message.usage = {
             "prompt_tokens": 0,
-            "completion_tokens": response["usage"]["total_tokens"],
-            "total_tokens": response["usage"]["total_tokens"],
+            "completion_tokens": response_data["usage"]["total_tokens"],
+            "total_tokens": response_data["usage"]["total_tokens"],
         }
-        message.stop_reason = response["choices"][0]["finish_reason"]
+        message.stop_reason = response_data["choices"][0]["finish_reason"]
         return message

     def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]:
@@ -139,10 +139,10 @@ class MinimaxChatCompletion:
         for line in response.iter_lines():
             if not line:
                 continue
-            line: str = line.decode("utf-8")
-            if line.startswith("data: "):
-                line = line[6:].strip()
-            data = loads(line)
+            line_str: str = line.decode("utf-8")
+            if line_str.startswith("data: "):
+                line_str = line_str[6:].strip()
+            data = loads(line_str)

             if "base_resp" in data and data["base_resp"]["status_code"] != 0:
                 code = data["base_resp"]["status_code"]
@@ -162,5 +162,5 @@ class MinimaxChatCompletion:
                 continue

             for choice in choices:
-                message = choice["delta"]
-                yield MinimaxMessage(content=message, role=MinimaxMessage.Role.ASSISTANT.value)
+                message_choice = choice["delta"]
+                yield MinimaxMessage(content=message_choice, role=MinimaxMessage.Role.ASSISTANT.value)
@@ -41,7 +41,7 @@ class MinimaxChatCompletionPro:

         url = f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}"

-        extra_kwargs = {}
+        extra_kwargs: dict[str, Any] = {}

         if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int:
             extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"]
@@ -122,19 +122,19 @@ class MinimaxChatCompletionPro:
         """
         handle chat generate response
         """
-        response = response.json()
-        if "base_resp" in response and response["base_resp"]["status_code"] != 0:
-            code = response["base_resp"]["status_code"]
-            msg = response["base_resp"]["status_msg"]
+        response_data = response.json()
+        if "base_resp" in response_data and response_data["base_resp"]["status_code"] != 0:
+            code = response_data["base_resp"]["status_code"]
+            msg = response_data["base_resp"]["status_msg"]
             self._handle_error(code, msg)

-        message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
+        message = MinimaxMessage(content=response_data["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
         message.usage = {
             "prompt_tokens": 0,
-            "completion_tokens": response["usage"]["total_tokens"],
-            "total_tokens": response["usage"]["total_tokens"],
+            "completion_tokens": response_data["usage"]["total_tokens"],
+            "total_tokens": response_data["usage"]["total_tokens"],
         }
-        message.stop_reason = response["choices"][0]["finish_reason"]
+        message.stop_reason = response_data["choices"][0]["finish_reason"]
         return message

     def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]:
@@ -144,10 +144,10 @@ class MinimaxChatCompletionPro:
         for line in response.iter_lines():
             if not line:
                 continue
-            line: str = line.decode("utf-8")
-            if line.startswith("data: "):
-                line = line[6:].strip()
-            data = loads(line)
+            line_str: str = line.decode("utf-8")
+            if line_str.startswith("data: "):
+                line_str = line_str[6:].strip()
+            data = loads(line_str)

             if "base_resp" in data and data["base_resp"]["status_code"] != 0:
                 code = data["base_resp"]["status_code"]
@@ -11,9 +11,9 @@ class MinimaxMessage:

     role: str = Role.USER.value
     content: str
-    usage: dict[str, int] = None
+    usage: dict[str, int] | None = None
     stop_reason: str = ""
-    function_call: dict[str, Any] = None
+    function_call: dict[str, Any] | None = None

     def to_dict(self) -> dict[str, Any]:
         if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:

@@ -2,8 +2,8 @@ import time
 from functools import wraps
 from typing import Optional

-from nomic import embed
-from nomic import login as nomic_login
+from nomic import embed  # type: ignore
+from nomic import login as nomic_login  # type: ignore

 from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType

@@ -5,8 +5,8 @@ import logging
 from collections.abc import Generator
 from typing import Optional, Union

-import oci
-from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse
+import oci  # type: ignore
+from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse  # type: ignore

 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (

@@ -4,7 +4,7 @@ import time
 from typing import Optional

 import numpy as np
-import oci
+import oci  # type: ignore

 from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
@@ -61,6 +61,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
         headers = {"Content-Type": "application/json"}

         endpoint_url = credentials.get("base_url")
+        assert endpoint_url is not None, "Base URL is required for Ollama API"
         if not endpoint_url.endswith("/"):
             endpoint_url += "/"


@@ -1,5 +1,3 @@
-from collections.abc import Mapping
-
 import openai
 from httpx import Timeout

@@ -14,7 +12,7 @@ from core.model_runtime.errors.invoke import (


 class _CommonOpenAI:
-    def _to_credential_kwargs(self, credentials: Mapping) -> dict:
+    def _to_credential_kwargs(self, credentials: dict) -> dict:
         """
         Transform credentials to kwargs for model instance


@@ -93,7 +93,8 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel):
         model_schema = self.get_model_schema(model, credentials)

         if model_schema and ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK in model_schema.model_properties:
-            return model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK]
+            max_characters_per_chunk: int = model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK]
+            return max_characters_per_chunk

         return 2000

@@ -108,6 +109,7 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel):
         model_schema = self.get_model_schema(model, credentials)

         if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties:
-            return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
+            max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
+            return max_chunks

         return 1
@@ -1,5 +1,4 @@
 import logging
-from collections.abc import Mapping

 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -9,7 +8,7 @@ logger = logging.getLogger(__name__)


 class OpenAIProvider(ModelProvider):
-    def validate_provider_credentials(self, credentials: Mapping) -> None:
+    def validate_provider_credentials(self, credentials: dict) -> None:
         """
         Validate provider credentials
         if validate failed, raise exception

@@ -33,6 +33,7 @@ class OAICompatSpeech2TextModel(_CommonOaiApiCompat, Speech2TextModel):
             headers["Authorization"] = f"Bearer {api_key}"

         endpoint_url = credentials.get("endpoint_url")
+        assert endpoint_url is not None, "endpoint_url is required in credentials"
         if not endpoint_url.endswith("/"):
             endpoint_url += "/"
         endpoint_url = urljoin(endpoint_url, "audio/transcriptions")

@@ -55,6+55,7 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
             headers["Authorization"] = f"Bearer {api_key}"

         endpoint_url = credentials.get("endpoint_url")
+        assert endpoint_url is not None, "endpoint_url is required in credentials"
         if not endpoint_url.endswith("/"):
             endpoint_url += "/"


@@ -44,6 +44,7 @@ class OAICompatText2SpeechModel(_CommonOaiApiCompat, TTSModel):

         # Construct endpoint URL
         endpoint_url = credentials.get("endpoint_url")
+        assert endpoint_url is not None, "endpoint_url is required in credentials"
         if not endpoint_url.endswith("/"):
             endpoint_url += "/"
         endpoint_url = urljoin(endpoint_url, "audio/speech")
@@ -1,7 +1,7 @@
 from collections.abc import Generator
 from enum import Enum
 from json import dumps, loads
-from typing import Any, Union
+from typing import Any, Optional, Union

 from requests import Response, post
 from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
@@ -20,7 +20,7 @@ class OpenLLMGenerateMessage:

     role: str = Role.USER.value
     content: str
-    usage: dict[str, int] = None
+    usage: Optional[dict[str, int]] = None
     stop_reason: str = ""

     def to_dict(self) -> dict[str, Any]:
@@ -165,17 +165,17 @@ class OpenLLMGenerate:
             if not line:
                 continue

-            line: str = line.decode("utf-8")
-            if line.startswith("data: "):
-                line = line[6:].strip()
+            line_str: str = line.decode("utf-8")
+            if line_str.startswith("data: "):
+                line_str = line_str[6:].strip()

-            if line == "[DONE]":
+            if line_str == "[DONE]":
                 return

             try:
-                data = loads(line)
+                data = loads(line_str)
             except Exception as e:
-                raise InternalServerError(f"Failed to convert response to json: {e} with text: {line}")
+                raise InternalServerError(f"Failed to convert response to json: {e} with text: {line_str}")

             output = data["outputs"]

@@ -53,14 +53,16 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
         api_key = credentials.get("api_key")
         if api_key:
             headers["Authorization"] = f"Bearer {api_key}"

+        endpoint_url: Optional[str]
         if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
             endpoint_url = "https://cloud.perfxlab.cn/v1/"
         else:
             endpoint_url = credentials.get("endpoint_url")
+            assert endpoint_url is not None, "endpoint_url is required in credentials"
             if not endpoint_url.endswith("/"):
                 endpoint_url += "/"

+        assert isinstance(endpoint_url, str)
         endpoint_url = urljoin(endpoint_url, "embeddings")

         extra_model_kwargs = {}
@@ -142,13 +144,16 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
         if api_key:
             headers["Authorization"] = f"Bearer {api_key}"

+        endpoint_url: Optional[str]
         if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
             endpoint_url = "https://cloud.perfxlab.cn/v1/"
         else:
             endpoint_url = credentials.get("endpoint_url")
+            assert endpoint_url is not None, "endpoint_url is required in credentials"
             if not endpoint_url.endswith("/"):
                 endpoint_url += "/"

+        assert isinstance(endpoint_url, str)
         endpoint_url = urljoin(endpoint_url, "embeddings")

         payload = {"input": "ping", "model": model}
@@ -1,4 +1,4 @@
-from replicate.exceptions import ModelError, ReplicateError
+from replicate.exceptions import ModelError, ReplicateError  # type: ignore

 from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError


@@ -1,9 +1,9 @@
 from collections.abc import Generator
 from typing import Optional, Union

-from replicate import Client as ReplicateClient
-from replicate.exceptions import ReplicateError
-from replicate.prediction import Prediction
+from replicate import Client as ReplicateClient  # type: ignore
+from replicate.exceptions import ReplicateError  # type: ignore
+from replicate.prediction import Prediction  # type: ignore

 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta

@@ -2,11 +2,11 @@ import json
 import time
 from typing import Optional

-from replicate import Client as ReplicateClient
+from replicate import Client as ReplicateClient  # type: ignore

 from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
@@ -86,7 +86,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
             label=I18nObject(en_US=model),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.TEXT_EMBEDDING,
-            model_properties={"context_size": 4096, "max_chunks": 1},
+            model_properties={ModelPropertyKey.CONTEXT_SIZE: 4096, ModelPropertyKey.MAX_CHUNKS: 1},
         )
         return entity

@@ -4,7 +4,7 @@ import re
 from collections.abc import Generator, Iterator
 from typing import Any, Optional, Union, cast

-import boto3
+import boto3  # type: ignore

 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
@@ -83,7 +83,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):

     sagemaker_session: Any = None
     predictor: Any = None
-    sagemaker_endpoint: str = None
+    sagemaker_endpoint: str | None = None

     def _handle_chat_generate_response(
         self,
@@ -209,8 +209,8 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
         :param user: unique user id
         :return: full response or stream response chunk generator result
         """
-        from sagemaker import Predictor, serializers
-        from sagemaker.session import Session
+        from sagemaker import Predictor, serializers  # type: ignore
+        from sagemaker.session import Session  # type: ignore

         if not self.sagemaker_session:
             access_key = credentials.get("aws_access_key_id")

@@ -3,7 +3,7 @@ import logging
 import operator
 from typing import Any, Optional

-import boto3
+import boto3  # type: ignore

 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
@@ -114,6 +114,7 @@ class SageMakerRerankModel(RerankModel):

         except Exception as e:
+            logger.exception(f"Failed to invoke rerank model, model: {model}")
             raise InvokeError(f"Failed to invoke rerank model, model: {model}, error: {str(e)}")

     def validate_credentials(self, model: str, credentials: dict) -> None:
         """

@@ -2,7 +2,7 @@ import json
 import logging
 from typing import IO, Any, Optional

-import boto3
+import boto3  # type: ignore

 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
@@ -67,6 +67,7 @@ class SageMakerSpeech2TextModel(Speech2TextModel):
             s3_prefix = "dify/speech2text/"
             sagemaker_endpoint = credentials.get("sagemaker_endpoint")
             bucket = credentials.get("audio_s3_cache_bucket")
+            assert bucket is not None, "audio_s3_cache_bucket is required in credentials"

             s3_presign_url = generate_presigned_url(self.s3_client, file, bucket, s3_prefix)
             payload = {"audio_s3_presign_uri": s3_presign_url}
@@ -4,7 +4,7 @@ import logging
 import time
 from typing import Any, Optional

-import boto3
+import boto3  # type: ignore

 from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
@@ -118,6 +118,7 @@ class SageMakerEmbeddingModel(TextEmbeddingModel):

         except Exception as e:
+            logger.exception(f"Failed to invoke text embedding model, model: {model}, line: {line}")
             raise InvokeError(str(e))

     def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
         """

@@ -5,7 +5,7 @@ import logging
 from enum import Enum
 from typing import Any, Optional

-import boto3
+import boto3  # type: ignore
 import requests

 from core.model_runtime.entities.common_entities import I18nObject

@@ -43,7 +43,7 @@ class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel):
         credentials["mode"] = "chat"
         credentials["endpoint_url"] = "https://api.siliconflow.cn/v1"

-    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
         return AIModelEntity(
             model=model,
             label=I18nObject(en_US=model, zh_Hans=model),
@@ -1,6 +1,6 @@
 import threading
 from collections.abc import Generator
-from typing import Optional, Union
+from typing import Optional, Union, cast

 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
@@ -270,7 +270,7 @@ class SparkLargeLanguageModel(LargeLanguageModel):
         elif isinstance(message, AssistantPromptMessage):
             message_text = f"{ai_prompt} {content}"
         elif isinstance(message, SystemPromptMessage):
-            message_text = content
+            message_text = cast(str, content)
         else:
             raise ValueError(f"Got unknown type {message}")


@@ -12,6 +12,7 @@ from core.model_runtime.entities.model_entities import (
     AIModelEntity,
     DefaultParameterName,
     FetchFrom,
+    ModelFeature,
     ModelPropertyKey,
     ModelType,
     ParameterRule,
@@ -67,7 +68,7 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
         cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
         REPETITION_PENALTY = "repetition_penalty"
         TOP_K = "top_k"
-        features = []
+        features: list[ModelFeature] = []

         entity = AIModelEntity(
             model=model,
@@ -1,4 +1,4 @@
-from dashscope.common.error import (
+from dashscope.common.error import (  # type: ignore
     AuthenticationError,
     InvalidParameter,
     RequestFailure,

@@ -7,9 +7,9 @@ from http import HTTPStatus
 from pathlib import Path
 from typing import Optional, Union, cast

-from dashscope import Generation, MultiModalConversation, get_tokenizer
-from dashscope.api_entities.dashscope_response import GenerationResponse
-from dashscope.common.error import (
+from dashscope import Generation, MultiModalConversation, get_tokenizer  # type: ignore
+from dashscope.api_entities.dashscope_response import GenerationResponse  # type: ignore
+from dashscope.common.error import (  # type: ignore
     AuthenticationError,
     InvalidParameter,
     RequestFailure,

@@ -1,7 +1,7 @@
 from typing import Optional

-import dashscope
-from dashscope.common.error import (
+import dashscope  # type: ignore
+from dashscope.common.error import (  # type: ignore
     AuthenticationError,
     InvalidParameter,
     RequestFailure,
@@ -51,7 +51,7 @@ class GTERerankModel(RerankModel):
         :return: rerank result
         """
         if len(docs) == 0:
-            return RerankResult(model=model, docs=docs)
+            return RerankResult(model=model, docs=[])

         # initialize client
         dashscope.api_key = credentials["dashscope_api_key"]
@@ -64,7 +64,7 @@ class GTERerankModel(RerankModel):
             return_documents=True,
         )

-        rerank_documents = []
+        rerank_documents: list[RerankDocument] = []
         if not response.output:
             return RerankResult(model=model, docs=rerank_documents)
         for _, result in enumerate(response.output.results):
@@ -1,7 +1,7 @@
 import time
 from typing import Optional

-import dashscope
+import dashscope  # type: ignore
 import numpy as np

 from core.entities.embedding_type import EmbeddingInputType

@@ -2,10 +2,10 @@ import threading
 from queue import Queue
 from typing import Any, Optional

-import dashscope
-from dashscope import SpeechSynthesizer
-from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
-from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult
+import dashscope  # type: ignore
+from dashscope import SpeechSynthesizer  # type: ignore
+from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse  # type: ignore
+from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult  # type: ignore

 from core.model_runtime.errors.invoke import InvokeBadRequestError
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -1,5 +1,3 @@
-from collections.abc import Mapping
-
 import openai
 from httpx import Timeout

@@ -14,7 +12,7 @@ from core.model_runtime.errors.invoke import (


 class _CommonUpstage:
-    def _to_credential_kwargs(self, credentials: Mapping) -> dict:
+    def _to_credential_kwargs(self, credentials: dict) -> dict:
         """
         Transform credentials to kwargs for model instance


@@ -6,7 +6,7 @@ from openai import OpenAI, Stream
 from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
 from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
 from openai.types.chat.chat_completion_message import FunctionCall
-from tokenizers import Tokenizer
+from tokenizers import Tokenizer  # type: ignore

 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta

@@ -1,11 +1,10 @@
 import base64
 import time
-from collections.abc import Mapping
 from typing import Union

 import numpy as np
 from openai import OpenAI
-from tokenizers import Tokenizer
+from tokenizers import Tokenizer  # type: ignore

 from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
@@ -132,7 +131,7 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel):

         return total_num_tokens

-    def validate_credentials(self, model: str, credentials: Mapping) -> None:
+    def validate_credentials(self, model: str, credentials: dict) -> None:
         """
         Validate model credentials

@@ -12,4 +12,4 @@ class _CommonVertexAi:

         :return: Invoke error mapping
         """
-        pass
+        raise NotImplementedError

@@ -6,7 +6,7 @@ import time
 from collections.abc import Generator
 from typing import TYPE_CHECKING, Optional, Union, cast

-import google.auth.transport.requests
+import google.auth.transport.requests  # type: ignore
 import requests
 from anthropic import AnthropicVertex, Stream
 from anthropic.types import (

@@ -17,14 +17,12 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI

 class VesslAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
-        features = []
-
         entity = AIModelEntity(
             model=model,
             label=I18nObject(en_US=model),
             model_type=ModelType.LLM,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
-            features=features,
+            features=[],
             model_properties={
                 ModelPropertyKey.MODE: credentials.get("mode"),
             },
@@ -1,8 +1,8 @@
 from collections.abc import Generator
 from typing import Optional, cast

-from volcenginesdkarkruntime import Ark
-from volcenginesdkarkruntime.types.chat import (
+from volcenginesdkarkruntime import Ark  # type: ignore
+from volcenginesdkarkruntime.types.chat import (  # type: ignore
     ChatCompletion,
     ChatCompletionAssistantMessageParam,
     ChatCompletionChunk,
@@ -15,10 +15,10 @@ from volcenginesdkarkruntime.types.chat import (
     ChatCompletionToolParam,
     ChatCompletionUserMessageParam,
 )
-from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL
-from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function
-from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse
-from volcenginesdkarkruntime.types.shared_params import FunctionDefinition
+from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL  # type: ignore
+from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function  # type: ignore
+from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse  # type: ignore
+from volcenginesdkarkruntime.types.shared_params import FunctionDefinition  # type: ignore

 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,

@@ -152,5 +152,6 @@ ErrorCodeMap = {

 def wrap_error(e: MaasError) -> Exception:
     if ErrorCodeMap.get(e.code):
-        return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id)
+        # FIXME: mypy type error, try to fix it instead of using type: ignore
+        return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id)  # type: ignore
     return e

@@ -2,7 +2,7 @@ import logging
 from collections.abc import Generator
 from typing import Optional

-from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk
+from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk  # type: ignore

 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
@@ -1,3 +1,5 @@
+from typing import Any
+
 from pydantic import BaseModel

 from core.model_runtime.entities.llm_entities import LLMMode
@@ -102,7 +104,7 @@ def get_model_config(credentials: dict) -> ModelConfig:


 def get_v2_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None):
-    req_params = {}
+    req_params: dict[str, Any] = {}
     # predefined properties
     model_configs = get_model_config(credentials)
     if model_configs:
@@ -130,7 +132,7 @@ def get_v3_req_params(credentials: dict, model_parameters: dict, stop: list[str]


 def get_v3_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None):
-    req_params = {}
+    req_params: dict[str, Any] = {}
     # predefined properties
     model_configs = get_model_config(credentials)
     if model_configs:

@@ -1,7 +1,7 @@
 from collections.abc import Generator
 from enum import Enum
 from json import dumps, loads
-from typing import Any, Union
+from typing import Any, Optional, Union

 from requests import Response, post

@@ -22,7 +22,7 @@ class ErnieMessage:

     role: str = Role.USER.value
     content: str
-    usage: dict[str, int] = None
+    usage: Optional[dict[str, int]] = None
     stop_reason: str = ""

     def to_dict(self) -> dict[str, Any]:
@@ -135,6 +135,7 @@ class ErnieBotModel(_CommonWenxin):
         """
         TODO: implement function calling
         """
+        raise NotImplementedError("Function calling is not supported yet.")

     def _build_chat_request_body(
         self,
@@ -1,6 +1,5 @@
 import time
 from abc import abstractmethod
-from collections.abc import Mapping
 from json import dumps
 from typing import Any, Optional

@@ -23,12 +22,12 @@ from core.model_runtime.model_providers.wenxin.wenxin_errors import (

 class TextEmbedding:
     @abstractmethod
-    def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
+    def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]:
         raise NotImplementedError


 class WenxinTextEmbedding(_CommonWenxin, TextEmbedding):
-    def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
+    def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]:
         access_token = self._get_access_token()
         url = f"{self.api_bases[model]}?access_token={access_token}"
         body = self._build_embed_request_body(model, texts, user)
@@ -50,7 +49,7 @@ class WenxinTextEmbedding(_CommonWenxin, TextEmbedding):
         }
         return body

-    def _handle_embed_response(self, model: str, response: Response) -> (list[list[float]], int, int):
+    def _handle_embed_response(self, model: str, response: Response) -> tuple[list[list[float]], int, int]:
         data = response.json()
         if "error_code" in data:
             code = data["error_code"]
@@ -147,7 +146,7 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel):

         return total_num_tokens

-    def validate_credentials(self, model: str, credentials: Mapping) -> None:
+    def validate_credentials(self, model: str, credentials: dict) -> None:
         api_key = credentials["api_key"]
         secret_key = credentials["secret_key"]
         try:
@@ -17,7 +17,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletio
 from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
 from openai.types.chat.chat_completion_message import FunctionCall
 from openai.types.completion import Completion
-from xinference_client.client.restful.restful_client import (
+from xinference_client.client.restful.restful_client import (  # type: ignore
     Client,
     RESTfulChatModelHandle,
     RESTfulGenerateModelHandle,

@@ -1,6 +1,6 @@
 from typing import Optional

-from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle
+from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle  # type: ignore

 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType

@@ -1,6 +1,6 @@
 from typing import IO, Optional

-from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle
+from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle  # type: ignore

 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType

@@ -1,7 +1,7 @@
 import time
 from typing import Optional

-from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle
+from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle  # type: ignore

 from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
@@ -134,7 +134,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         try:
             handle = client.get_model(model_uid=model_uid)
         except RuntimeError as e:
-            raise InvokeAuthorizationError(e)
+            raise InvokeAuthorizationError(str(e))

         if not isinstance(handle, RESTfulEmbeddingModelHandle):
             raise InvokeBadRequestError(
@@ -1,7 +1,7 @@
 import concurrent.futures
 from typing import Any, Optional

-from xinference_client.client.restful.restful_client import RESTfulAudioModelHandle
+from xinference_client.client.restful.restful_client import RESTfulAudioModelHandle  # type: ignore

 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
@@ -74,11 +74,14 @@ class XinferenceText2SpeechModel(TTSModel):
             raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")

         credentials["server_url"] = credentials["server_url"].removesuffix("/")
+        api_key = credentials.get("api_key")
+        if api_key is None:
+            raise CredentialsValidateFailedError("api_key is required")

         extra_param = XinferenceHelper.get_xinference_extra_parameter(
             server_url=credentials["server_url"],
             model_uid=credentials["model_uid"],
-            api_key=credentials.get("api_key"),
+            api_key=api_key,
         )

         if "text-to-audio" not in extra_param.model_ability:

@@ -1,6 +1,6 @@
 from threading import Lock
 from time import time
-from typing import Optional
+from typing import Any, Optional

 from requests.adapters import HTTPAdapter
 from requests.exceptions import ConnectionError, MissingSchema, Timeout
@@ -39,13 +39,15 @@ class XinferenceModelExtraParameter:
         self.model_family = model_family


-cache = {}
+cache: dict[str, dict[str, Any]] = {}
 cache_lock = Lock()


 class XinferenceHelper:
     @staticmethod
-    def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter:
+    def get_xinference_extra_parameter(
+        server_url: str, model_uid: str, api_key: str | None
+    ) -> XinferenceModelExtraParameter:
         XinferenceHelper._clean_cache()
         with cache_lock:
             if model_uid not in cache:
@@ -66,7 +68,9 @@ class XinferenceHelper:
             pass

     @staticmethod
-    def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter:
+    def _get_xinference_extra_parameter(
+        server_url: str, model_uid: str, api_key: str | None
+    ) -> XinferenceModelExtraParameter:
         """
         get xinference model extra parameter like model_format and model_handle_type
         """
@@ -136,7 +136,7 @@ class YiLargeLanguageModel(OpenAILargeLanguageModel):
             parsed_url = urlparse(credentials["endpoint_url"])
             credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}"

-    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
         return AIModelEntity(
             model=model,
             label=I18nObject(en_US=model, zh_Hans=model),

@@ -1,9 +1,9 @@
 from collections.abc import Generator
 from typing import Optional, Union

-from zhipuai import ZhipuAI
-from zhipuai.types.chat.chat_completion import Completion
-from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk
+from zhipuai import ZhipuAI  # type: ignore
+from zhipuai.types.chat.chat_completion import Completion  # type: ignore
+from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk  # type: ignore

 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (

@@ -1,7 +1,7 @@
 import time
 from typing import Optional

-from zhipuai import ZhipuAI
+from zhipuai import ZhipuAI  # type: ignore

 from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Union, cast

 from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType

@@ -38,7 +38,7 @@ class CommonValidator:

     def _validate_credential_form_schema(
         self, credential_form_schema: CredentialFormSchema, credentials: dict
-    ) -> Optional[str]:
+    ) -> Union[str, bool, None]:
         """
         Validate credential form schema

@@ -47,6 +47,7 @@ class CommonValidator:
         :return: validated credential form schema value
         """
         # If the variable does not exist in credentials
+        value: Union[str, bool, None] = None
         if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]:
             # If required is True, an exception is thrown
             if credential_form_schema.required:
@@ -61,7 +62,7 @@ class CommonValidator:
             return None

         # Get the value corresponding to the variable from credentials
-        value = credentials[credential_form_schema.variable]
+        value = cast(str, credentials[credential_form_schema.variable])

         # If max_length=0, no validation is performed
         if credential_form_schema.max_length:
@@ -129,7 +129,8 @@ def jsonable_encoder(
             sqlalchemy_safe=sqlalchemy_safe,
         )
     if dataclasses.is_dataclass(obj):
-        obj_dict = dataclasses.asdict(obj)
+        # FIXME: mypy error, try to fix it instead of using type: ignore
+        obj_dict = dataclasses.asdict(obj)  # type: ignore
        return jsonable_encoder(
            obj_dict,
            by_alias=by_alias,

@@ -4,6 +4,7 @@ from pydantic import BaseModel

 def dump_model(model: BaseModel) -> dict:
     if hasattr(pydantic, "model_dump"):
-        return pydantic.model_dump(model)
+        # FIXME mypy error, try to fix it instead of using type: ignore
+        return pydantic.model_dump(model)  # type: ignore
     else:
         return model.model_dump()
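One recurring fix above deserves a note: the Wenxin embedding methods annotated their returns as parenthesized tuples such as `(list[list[float]], int, int)`, which Python evaluates as a runtime tuple of types rather than a type. A hedged sketch of the `tuple[...]` spelling that type checkers accept (illustrative, with placeholder logic, not code from this commit):

    # Rejected by mypy: the annotation below is a tuple object, not a type.
    #     def embed_documents(...) -> (list[list[float]], int, int): ...
    # Accepted: tuple[...] is a proper generic type.
    def embed_documents(texts: list[str]) -> tuple[list[list[float]], int, int]:
        vectors = [[0.0] for _ in texts]  # placeholder embeddings
        prompt_tokens = used_tokens = len(texts)  # placeholder token counts
        return vectors, prompt_tokens, used_tokens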