feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -1,7 +1,7 @@
import json
import logging
import math
from typing import Any, Optional
from typing import Any, Optional, cast
from urllib.parse import urlparse
import requests
@@ -70,7 +70,7 @@ class ElasticSearchVector(BaseVector):
def _get_version(self) -> str:
info = self._client.info()
return info["version"]["number"]
return cast(str, info["version"]["number"])
def _check_version(self):
if self._version < "8.0.0":
@@ -135,7 +135,8 @@ class ElasticSearchVector(BaseVector):
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
doc.metadata["score"] = score
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
return docs
@@ -156,12 +157,15 @@ class ElasticSearchVector(BaseVector):
return docs
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
metadatas = [d.metadata for d in texts]
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas)
self.add_texts(texts, embeddings, **kwargs)
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
self,
embeddings: list[list[float]],
metadatas: Optional[list[dict[Any, Any]]] = None,
index_params: Optional[dict] = None,
):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
@@ -208,10 +212,10 @@ class ElasticSearchVectorFactory(AbstractVectorFactory):
return ElasticSearchVector(
index_name=collection_name,
config=ElasticSearchConfig(
host=config.get("ELASTICSEARCH_HOST"),
port=config.get("ELASTICSEARCH_PORT"),
username=config.get("ELASTICSEARCH_USERNAME"),
password=config.get("ELASTICSEARCH_PASSWORD"),
host=config.get("ELASTICSEARCH_HOST", "localhost"),
port=config.get("ELASTICSEARCH_PORT", 9200),
username=config.get("ELASTICSEARCH_USERNAME", ""),
password=config.get("ELASTICSEARCH_PASSWORD", ""),
),
attributes=[],
)