mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-16 06:16:53 +08:00
feat: mypy for all type check (#10921)
This commit is contained in:
@@ -71,11 +71,13 @@ class ChromaVector(BaseVector):
|
||||
metadatas = [d.metadata for d in documents]
|
||||
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas)
|
||||
# FIXME: chromadb using numpy array, fix the type error later
|
||||
collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
collection.delete(where={key: {"$eq": value}})
|
||||
# FIXME: fix the type error later
|
||||
collection.delete(where={key: {"$eq": value}}) # type: ignore
|
||||
|
||||
def delete(self):
|
||||
self._client.delete_collection(self._collection_name)
|
||||
@@ -94,15 +96,19 @@ class ChromaVector(BaseVector):
|
||||
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
|
||||
ids: list[str] = results["ids"][0]
|
||||
documents: list[str] = results["documents"][0]
|
||||
metadatas: dict[str, Any] = results["metadatas"][0]
|
||||
distances: list[float] = results["distances"][0]
|
||||
# Check if results contain data
|
||||
if not results["ids"] or not results["documents"] or not results["metadatas"] or not results["distances"]:
|
||||
return []
|
||||
|
||||
ids = results["ids"][0]
|
||||
documents = results["documents"][0]
|
||||
metadatas = results["metadatas"][0]
|
||||
distances = results["distances"][0]
|
||||
|
||||
docs = []
|
||||
for index in range(len(ids)):
|
||||
distance = distances[index]
|
||||
metadata = metadatas[index]
|
||||
metadata = dict(metadatas[index])
|
||||
if distance >= score_threshold:
|
||||
metadata["score"] = distance
|
||||
doc = Document(
|
||||
@@ -111,7 +117,7 @@ class ChromaVector(BaseVector):
|
||||
)
|
||||
docs.append(doc)
|
||||
# Sort the documents by score in descending order
|
||||
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
|
||||
docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
@@ -133,7 +139,7 @@ class ChromaVectorFactory(AbstractVectorFactory):
|
||||
return ChromaVector(
|
||||
collection_name=collection_name,
|
||||
config=ChromaConfig(
|
||||
host=dify_config.CHROMA_HOST,
|
||||
host=dify_config.CHROMA_HOST or "",
|
||||
port=dify_config.CHROMA_PORT,
|
||||
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
|
||||
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,
|
||||
|
||||
Reference in New Issue
Block a user