feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -5,13 +5,13 @@ from typing import Any
import numpy as np
from pydantic import BaseModel, model_validator
from pymochow import MochowClient
from pymochow.auth.bce_credentials import BceCredentials
from pymochow.configuration import Configuration
from pymochow.exception import ServerError
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row
from pymochow import MochowClient # type: ignore
from pymochow.auth.bce_credentials import BceCredentials # type: ignore
from pymochow.configuration import Configuration # type: ignore
from pymochow.exception import ServerError # type: ignore
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex # type: ignore
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -75,7 +75,7 @@ class BaiduVector(BaseVector):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
metadatas = [doc.metadata for doc in documents if doc.metadata is not None]
total_count = len(documents)
batch_size = 1000
@@ -84,6 +84,8 @@ class BaiduVector(BaseVector):
for start in range(0, total_count, batch_size):
end = min(start + batch_size, total_count)
rows = []
assert len(metadatas) == total_count, "metadatas length should be equal to total_count"
# FIXME do you need this assert?
for i in range(start, end, 1):
row = Row(
id=metadatas[i].get("doc_id", str(uuid.uuid4())),
@@ -136,7 +138,7 @@ class BaiduVector(BaseVector):
# baidu vector database doesn't support bm25 search on current version
return []
def _get_search_res(self, res, score_threshold):
def _get_search_res(self, res, score_threshold) -> list[Document]:
docs = []
for row in res.rows:
row_data = row.get("row", {})
@@ -276,11 +278,11 @@ class BaiduVectorFactory(AbstractVectorFactory):
return BaiduVector(
collection_name=collection_name,
config=BaiduConfig(
endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT,
endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT or "",
connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS,
account=dify_config.BAIDU_VECTOR_DB_ACCOUNT,
api_key=dify_config.BAIDU_VECTOR_DB_API_KEY,
database=dify_config.BAIDU_VECTOR_DB_DATABASE,
account=dify_config.BAIDU_VECTOR_DB_ACCOUNT or "",
api_key=dify_config.BAIDU_VECTOR_DB_API_KEY or "",
database=dify_config.BAIDU_VECTOR_DB_DATABASE or "",
shard=dify_config.BAIDU_VECTOR_DB_SHARD,
replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
),