feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -71,11 +71,13 @@ class ChromaVector(BaseVector):
metadatas = [d.metadata for d in documents]
collection = self._client.get_or_create_collection(self._collection_name)
collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas)
# FIXME: chromadb using numpy array, fix the type error later
collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore
def delete_by_metadata_field(self, key: str, value: str):
collection = self._client.get_or_create_collection(self._collection_name)
collection.delete(where={key: {"$eq": value}})
# FIXME: fix the type error later
collection.delete(where={key: {"$eq": value}}) # type: ignore
def delete(self):
self._client.delete_collection(self._collection_name)
@@ -94,15 +96,19 @@ class ChromaVector(BaseVector):
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
score_threshold = float(kwargs.get("score_threshold") or 0.0)
ids: list[str] = results["ids"][0]
documents: list[str] = results["documents"][0]
metadatas: dict[str, Any] = results["metadatas"][0]
distances: list[float] = results["distances"][0]
# Check if results contain data
if not results["ids"] or not results["documents"] or not results["metadatas"] or not results["distances"]:
return []
ids = results["ids"][0]
documents = results["documents"][0]
metadatas = results["metadatas"][0]
distances = results["distances"][0]
docs = []
for index in range(len(ids)):
distance = distances[index]
metadata = metadatas[index]
metadata = dict(metadatas[index])
if distance >= score_threshold:
metadata["score"] = distance
doc = Document(
@@ -111,7 +117,7 @@ class ChromaVector(BaseVector):
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -133,7 +139,7 @@ class ChromaVectorFactory(AbstractVectorFactory):
return ChromaVector(
collection_name=collection_name,
config=ChromaConfig(
host=dify_config.CHROMA_HOST,
host=dify_config.CHROMA_HOST or "",
port=dify_config.CHROMA_PORT,
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,