mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-25 02:33:00 +08:00
feat: mypy for all type check (#10921)
This commit is contained in:
@@ -5,13 +5,13 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pymochow import MochowClient
|
||||
from pymochow.auth.bce_credentials import BceCredentials
|
||||
from pymochow.configuration import Configuration
|
||||
from pymochow.exception import ServerError
|
||||
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState
|
||||
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
|
||||
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row
|
||||
from pymochow import MochowClient # type: ignore
|
||||
from pymochow.auth.bce_credentials import BceCredentials # type: ignore
|
||||
from pymochow.configuration import Configuration # type: ignore
|
||||
from pymochow.exception import ServerError # type: ignore
|
||||
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore
|
||||
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex # type: ignore
|
||||
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
@@ -75,7 +75,7 @@ class BaiduVector(BaseVector):
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
texts = [doc.page_content for doc in documents]
|
||||
metadatas = [doc.metadata for doc in documents]
|
||||
metadatas = [doc.metadata for doc in documents if doc.metadata is not None]
|
||||
total_count = len(documents)
|
||||
batch_size = 1000
|
||||
|
||||
@@ -84,6 +84,8 @@ class BaiduVector(BaseVector):
|
||||
for start in range(0, total_count, batch_size):
|
||||
end = min(start + batch_size, total_count)
|
||||
rows = []
|
||||
assert len(metadatas) == total_count, "metadatas length should be equal to total_count"
|
||||
# FIXME do you need this assert?
|
||||
for i in range(start, end, 1):
|
||||
row = Row(
|
||||
id=metadatas[i].get("doc_id", str(uuid.uuid4())),
|
||||
@@ -136,7 +138,7 @@ class BaiduVector(BaseVector):
|
||||
# baidu vector database doesn't support bm25 search on current version
|
||||
return []
|
||||
|
||||
def _get_search_res(self, res, score_threshold):
|
||||
def _get_search_res(self, res, score_threshold) -> list[Document]:
|
||||
docs = []
|
||||
for row in res.rows:
|
||||
row_data = row.get("row", {})
|
||||
@@ -276,11 +278,11 @@ class BaiduVectorFactory(AbstractVectorFactory):
|
||||
return BaiduVector(
|
||||
collection_name=collection_name,
|
||||
config=BaiduConfig(
|
||||
endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT,
|
||||
endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT or "",
|
||||
connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS,
|
||||
account=dify_config.BAIDU_VECTOR_DB_ACCOUNT,
|
||||
api_key=dify_config.BAIDU_VECTOR_DB_API_KEY,
|
||||
database=dify_config.BAIDU_VECTOR_DB_DATABASE,
|
||||
account=dify_config.BAIDU_VECTOR_DB_ACCOUNT or "",
|
||||
api_key=dify_config.BAIDU_VECTOR_DB_API_KEY or "",
|
||||
database=dify_config.BAIDU_VECTOR_DB_DATABASE or "",
|
||||
shard=dify_config.BAIDU_VECTOR_DB_SHARD,
|
||||
replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user