feat: mypy for all type check (#10921)

Author: yihong
Date: 2024-12-24 18:38:51 +08:00
Committed by: GitHub
Parent: c91e8b1737
Commit: 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -17,12 +17,19 @@ from models.dataset import Dataset
class AnalyticdbVector(BaseVector):
def __init__(
self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig
self,
collection_name: str,
api_config: AnalyticdbVectorOpenAPIConfig | None,
sql_config: AnalyticdbVectorBySqlConfig | None,
):
super().__init__(collection_name)
if api_config is not None:
self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config)
self.analyticdb_vector: AnalyticdbVectorOpenAPI | AnalyticdbVectorBySql = AnalyticdbVectorOpenAPI(
collection_name, api_config
)
else:
if sql_config is None:
raise ValueError("Either api_config or sql_config must be provided")
self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)
def get_type(self) -> str:
@@ -33,8 +40,8 @@ class AnalyticdbVector(BaseVector):
self.analyticdb_vector._create_collection_if_not_exists(dimension)
self.analyticdb_vector.add_texts(texts, embeddings)
def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self.analyticdb_vector.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
self.analyticdb_vector.add_texts(documents, embeddings)
def text_exists(self, id: str) -> bool:
return self.analyticdb_vector.text_exists(id)
@@ -68,13 +75,13 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
if dify_config.ANALYTICDB_HOST is None:
# implemented through OpenAPI
apiConfig = AnalyticdbVectorOpenAPIConfig(
access_key_id=dify_config.ANALYTICDB_KEY_ID,
access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
region_id=dify_config.ANALYTICDB_REGION_ID,
instance_id=dify_config.ANALYTICDB_INSTANCE_ID,
account=dify_config.ANALYTICDB_ACCOUNT,
account_password=dify_config.ANALYTICDB_PASSWORD,
namespace=dify_config.ANALYTICDB_NAMESPACE,
access_key_id=dify_config.ANALYTICDB_KEY_ID or "",
access_key_secret=dify_config.ANALYTICDB_KEY_SECRET or "",
region_id=dify_config.ANALYTICDB_REGION_ID or "",
instance_id=dify_config.ANALYTICDB_INSTANCE_ID or "",
account=dify_config.ANALYTICDB_ACCOUNT or "",
account_password=dify_config.ANALYTICDB_PASSWORD or "",
namespace=dify_config.ANALYTICDB_NAMESPACE or "",
namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
)
sqlConfig = None
@@ -83,11 +90,11 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
sqlConfig = AnalyticdbVectorBySqlConfig(
host=dify_config.ANALYTICDB_HOST,
port=dify_config.ANALYTICDB_PORT,
account=dify_config.ANALYTICDB_ACCOUNT,
account_password=dify_config.ANALYTICDB_PASSWORD,
account=dify_config.ANALYTICDB_ACCOUNT or "",
account_password=dify_config.ANALYTICDB_PASSWORD or "",
min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
namespace=dify_config.ANALYTICDB_NAMESPACE,
namespace=dify_config.ANALYTICDB_NAMESPACE or "",
)
apiConfig = None
return AnalyticdbVector(
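A note on the pattern above, illustrated with a hypothetical sketch (class and parameter names are mine, not the commit's): when two branches of __init__ assign different concrete types to the same attribute, mypy infers the type from the first assignment unless the union is annotated explicitly, and an Optional parameter must be checked at runtime before use.

    from typing import Optional

    class ApiClient: ...
    class SqlClient: ...

    class Store:
        def __init__(self, api_cfg: Optional[str], sql_cfg: Optional[str]):
            if api_cfg is not None:
                # Annotate the union on the first assignment; otherwise
                # mypy infers ApiClient and rejects the branch below.
                self.client: ApiClient | SqlClient = ApiClient()
            else:
                if sql_cfg is None:
                    raise ValueError("Either api_cfg or sql_cfg must be provided")
                self.client = SqlClient()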

View File

@@ -1,5 +1,5 @@
import json
from typing import Any
from typing import Any, Optional
from pydantic import BaseModel, model_validator
@@ -20,7 +20,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
account: str
account_password: str
namespace: str = "dify"
namespace_password: str = (None,)
namespace_password: Optional[str] = None
metrics: str = "cosine"
read_timeout: int = 60000
@@ -55,8 +55,8 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
class AnalyticdbVectorOpenAPI:
def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
try:
from alibabacloud_gpdb20160503.client import Client
from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_gpdb20160503.client import Client # type: ignore
from alibabacloud_tea_openapi import models as open_api_models # type: ignore
except:
raise ImportError(_import_err_msg)
self._collection_name = collection_name.lower()
@@ -77,7 +77,7 @@ class AnalyticdbVectorOpenAPI:
redis_client.set(database_exist_cache_key, 1, ex=3600)
def _initialize_vector_database(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models # type: ignore
request = gpdb_20160503_models.InitVectorDatabaseRequest(
dbinstance_id=self.config.instance_id,
@@ -89,7 +89,7 @@ class AnalyticdbVectorOpenAPI:
def _create_namespace_if_not_exists(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
from Tea.exceptions import TeaException # type: ignore
try:
request = gpdb_20160503_models.DescribeNamespaceRequest(
@@ -159,17 +159,18 @@ class AnalyticdbVectorOpenAPI:
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
for doc, embedding in zip(documents, embeddings, strict=True):
metadata = {
"ref_doc_id": doc.metadata["doc_id"],
"page_content": doc.page_content,
"metadata_": json.dumps(doc.metadata),
}
rows.append(
gpdb_20160503_models.UpsertCollectionDataRequestRows(
vector=embedding,
metadata=metadata,
if doc.metadata is not None:
metadata = {
"ref_doc_id": doc.metadata["doc_id"],
"page_content": doc.page_content,
"metadata_": json.dumps(doc.metadata),
}
rows.append(
gpdb_20160503_models.UpsertCollectionDataRequestRows(
vector=embedding,
metadata=metadata,
)
)
)
request = gpdb_20160503_models.UpsertCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
@@ -258,7 +259,7 @@ class AnalyticdbVectorOpenAPI:
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -290,7 +291,7 @@ class AnalyticdbVectorOpenAPI:
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents
def delete(self) -> None:
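The `(None,)` fix above is worth calling out: a trailing comma makes the default a one-element tuple, which is a latent runtime bug as well as a type error. A minimal sketch of the corrected declaration, assuming only pydantic:

    from typing import Optional

    from pydantic import BaseModel

    class Config(BaseModel):
        namespace: str = "dify"
        # Old: namespace_password: str = (None,)  -- the default was a tuple.
        namespace_password: Optional[str] = None

The `# type: ignore` comments added to the alibabacloud and Tea imports suppress the missing-stubs error for that line only, which this commit prefers over disabling checks module-wide.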

View File

@@ -3,8 +3,8 @@ import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras
import psycopg2.pool
import psycopg2.extras # type: ignore
import psycopg2.pool # type: ignore
from pydantic import BaseModel, model_validator
from core.rag.models.document import Document
@@ -75,6 +75,7 @@ class AnalyticdbVectorBySql:
@contextmanager
def _get_cursor(self):
assert self.pool is not None, "Connection pool is not initialized"
conn = self.pool.getconn()
cur = conn.cursor()
try:
@@ -156,16 +157,17 @@ class AnalyticdbVectorBySql:
VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
"""
for i, doc in enumerate(documents):
values.append(
(
id_prefix + str(i),
doc.metadata.get("doc_id", str(uuid.uuid4())),
embeddings[i],
doc.page_content,
json.dumps(doc.metadata),
doc.page_content,
if doc.metadata is not None:
values.append(
(
id_prefix + str(i),
doc.metadata.get("doc_id", str(uuid.uuid4())),
embeddings[i],
doc.page_content,
json.dumps(doc.metadata),
doc.page_content,
)
)
)
with self._get_cursor() as cur:
psycopg2.extras.execute_batch(cur, sql, values)

View File

@@ -5,13 +5,13 @@ from typing import Any
import numpy as np
from pydantic import BaseModel, model_validator
from pymochow import MochowClient
from pymochow.auth.bce_credentials import BceCredentials
from pymochow.configuration import Configuration
from pymochow.exception import ServerError
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row
from pymochow import MochowClient # type: ignore
from pymochow.auth.bce_credentials import BceCredentials # type: ignore
from pymochow.configuration import Configuration # type: ignore
from pymochow.exception import ServerError # type: ignore
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex # type: ignore
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -75,7 +75,7 @@ class BaiduVector(BaseVector):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
metadatas = [doc.metadata for doc in documents if doc.metadata is not None]
total_count = len(documents)
batch_size = 1000
@@ -84,6 +84,8 @@ class BaiduVector(BaseVector):
for start in range(0, total_count, batch_size):
end = min(start + batch_size, total_count)
rows = []
assert len(metadatas) == total_count, "metadatas length should be equal to total_count"
# FIXME do you need this assert?
for i in range(start, end, 1):
row = Row(
id=metadatas[i].get("doc_id", str(uuid.uuid4())),
@@ -136,7 +138,7 @@ class BaiduVector(BaseVector):
# baidu vector database doesn't support bm25 search on current version
return []
def _get_search_res(self, res, score_threshold):
def _get_search_res(self, res, score_threshold) -> list[Document]:
docs = []
for row in res.rows:
row_data = row.get("row", {})
@@ -276,11 +278,11 @@ class BaiduVectorFactory(AbstractVectorFactory):
return BaiduVector(
collection_name=collection_name,
config=BaiduConfig(
endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT,
endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT or "",
connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS,
account=dify_config.BAIDU_VECTOR_DB_ACCOUNT,
api_key=dify_config.BAIDU_VECTOR_DB_API_KEY,
database=dify_config.BAIDU_VECTOR_DB_DATABASE,
account=dify_config.BAIDU_VECTOR_DB_ACCOUNT or "",
api_key=dify_config.BAIDU_VECTOR_DB_API_KEY or "",
database=dify_config.BAIDU_VECTOR_DB_DATABASE or "",
shard=dify_config.BAIDU_VECTOR_DB_SHARD,
replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
),
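The assert and FIXME above exist because filtering None metadata out of the list can leave it shorter than documents, breaking the index alignment the loop relies on. One alternative, a sketch of my own rather than the commit's approach, keeps the lists aligned by substituting an empty dict:

    from typing import Any, Optional

    def normalize(metadatas: list[Optional[dict[str, Any]]]) -> list[dict[str, Any]]:
        # Replacing None with {} preserves length and index alignment,
        # so no runtime assert is needed afterwards.
        return [m if m is not None else {} for m in metadatas]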

View File

@@ -71,11 +71,13 @@ class ChromaVector(BaseVector):
metadatas = [d.metadata for d in documents]
collection = self._client.get_or_create_collection(self._collection_name)
collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas)
# FIXME: chromadb using numpy array, fix the type error later
collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore
def delete_by_metadata_field(self, key: str, value: str):
collection = self._client.get_or_create_collection(self._collection_name)
collection.delete(where={key: {"$eq": value}})
# FIXME: fix the type error later
collection.delete(where={key: {"$eq": value}}) # type: ignore
def delete(self):
self._client.delete_collection(self._collection_name)
@@ -94,15 +96,19 @@ class ChromaVector(BaseVector):
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
score_threshold = float(kwargs.get("score_threshold") or 0.0)
ids: list[str] = results["ids"][0]
documents: list[str] = results["documents"][0]
metadatas: dict[str, Any] = results["metadatas"][0]
distances: list[float] = results["distances"][0]
# Check if results contain data
if not results["ids"] or not results["documents"] or not results["metadatas"] or not results["distances"]:
return []
ids = results["ids"][0]
documents = results["documents"][0]
metadatas = results["metadatas"][0]
distances = results["distances"][0]
docs = []
for index in range(len(ids)):
distance = distances[index]
metadata = metadatas[index]
metadata = dict(metadatas[index])
if distance >= score_threshold:
metadata["score"] = distance
doc = Document(
@@ -111,7 +117,7 @@ class ChromaVector(BaseVector):
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -133,7 +139,7 @@ class ChromaVectorFactory(AbstractVectorFactory):
return ChromaVector(
collection_name=collection_name,
config=ChromaConfig(
host=dify_config.CHROMA_HOST,
host=dify_config.CHROMA_HOST or "",
port=dify_config.CHROMA_PORT,
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,
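The Chroma changes make two null-safety moves: return early when the query result lists are empty, and let the final sort tolerate documents whose metadata is None. A standalone sketch of the None-safe sort key, using a hypothetical Doc class:

    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class Doc:
        page_content: str
        metadata: Optional[dict[str, Any]] = None

    docs = [Doc("a", {"score": 0.9}), Doc("b"), Doc("c", {"score": 0.4})]
    # A document without metadata sorts as score 0 instead of raising.
    docs = sorted(docs, key=lambda d: d.metadata["score"] if d.metadata else 0, reverse=True)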

View File

@@ -5,14 +5,14 @@ import uuid
from datetime import timedelta
from typing import Any
from couchbase import search
from couchbase.auth import PasswordAuthenticator
from couchbase.cluster import Cluster
from couchbase.management.search import SearchIndex
from couchbase import search # type: ignore
from couchbase.auth import PasswordAuthenticator # type: ignore
from couchbase.cluster import Cluster # type: ignore
from couchbase.management.search import SearchIndex # type: ignore
# needed for options -- cluster, timeout, SQL++ (N1QL) query, etc.
from couchbase.options import ClusterOptions, SearchOptions
from couchbase.vector_search import VectorQuery, VectorSearch
from couchbase.options import ClusterOptions, SearchOptions # type: ignore
from couchbase.vector_search import VectorQuery, VectorSearch # type: ignore
from flask import current_app
from pydantic import BaseModel, model_validator
@@ -231,7 +231,7 @@ class CouchbaseVector(BaseVector):
# Pass the id as a parameter to the query
result = self._cluster.query(query, named_parameters={"doc_id": id}).execute()
for row in result:
return row["count"] > 0
return bool(row["count"] > 0)
return False # Return False if no rows are returned
def delete_by_ids(self, ids: list[str]) -> None:
@@ -369,10 +369,10 @@ class CouchbaseVectorFactory(AbstractVectorFactory):
return CouchbaseVector(
collection_name=collection_name,
config=CouchbaseConfig(
connection_string=config.get("COUCHBASE_CONNECTION_STRING"),
user=config.get("COUCHBASE_USER"),
password=config.get("COUCHBASE_PASSWORD"),
bucket_name=config.get("COUCHBASE_BUCKET_NAME"),
scope_name=config.get("COUCHBASE_SCOPE_NAME"),
connection_string=config.get("COUCHBASE_CONNECTION_STRING", ""),
user=config.get("COUCHBASE_USER", ""),
password=config.get("COUCHBASE_PASSWORD", ""),
bucket_name=config.get("COUCHBASE_BUCKET_NAME", ""),
scope_name=config.get("COUCHBASE_SCOPE_NAME", ""),
),
)
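The `bool(...)` wrapper above is a recurring mypy idiom in this commit: a comparison on a value typed Any is itself Any, so an explicit conversion is needed to satisfy a declared `-> bool`. A minimal sketch:

    from typing import Any

    def has_rows(row: dict[str, Any]) -> bool:
        # row["count"] is Any, so `row["count"] > 0` is also Any;
        # bool() narrows it to the declared return type.
        return bool(row["count"] > 0)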

View File

@@ -1,7 +1,7 @@
import json
import logging
import math
from typing import Any, Optional
from typing import Any, Optional, cast
from urllib.parse import urlparse
import requests
@@ -70,7 +70,7 @@ class ElasticSearchVector(BaseVector):
def _get_version(self) -> str:
info = self._client.info()
return info["version"]["number"]
return cast(str, info["version"]["number"])
def _check_version(self):
if self._version < "8.0.0":
@@ -135,7 +135,8 @@ class ElasticSearchVector(BaseVector):
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
doc.metadata["score"] = score
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
return docs
@@ -156,12 +157,15 @@ class ElasticSearchVector(BaseVector):
return docs
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
metadatas = [d.metadata for d in texts]
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas)
self.add_texts(texts, embeddings, **kwargs)
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
self,
embeddings: list[list[float]],
metadatas: Optional[list[dict[Any, Any]]] = None,
index_params: Optional[dict] = None,
):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
@@ -208,10 +212,10 @@ class ElasticSearchVectorFactory(AbstractVectorFactory):
return ElasticSearchVector(
index_name=collection_name,
config=ElasticSearchConfig(
host=config.get("ELASTICSEARCH_HOST"),
port=config.get("ELASTICSEARCH_PORT"),
username=config.get("ELASTICSEARCH_USERNAME"),
password=config.get("ELASTICSEARCH_PASSWORD"),
host=config.get("ELASTICSEARCH_HOST", "localhost"),
port=config.get("ELASTICSEARCH_PORT", 9200),
username=config.get("ELASTICSEARCH_USERNAME", ""),
password=config.get("ELASTICSEARCH_PASSWORD", ""),
),
attributes=[],
)
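`cast()` above is the static counterpart to such conversions: it asserts a type to mypy without any runtime check or copy. A minimal sketch:

    from typing import Any, cast

    def version_number(info: dict[str, Any]) -> str:
        # cast() changes nothing at runtime; it only tells the checker
        # to treat the Any value as str.
        return cast(str, info["version"]["number"])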

View File

@@ -42,7 +42,7 @@ class LindormVectorStoreConfig(BaseModel):
return values
def to_opensearch_params(self) -> dict[str, Any]:
params = {"hosts": self.hosts}
params: dict[str, Any] = {"hosts": self.hosts}
if self.username and self.password:
params["http_auth"] = (self.username, self.password)
return params
@@ -53,7 +53,7 @@ class LindormVectorStore(BaseVector):
self._routing = None
self._routing_field = None
if using_ugc:
routing_value: str = kwargs.get("routing_value")
routing_value: str | None = kwargs.get("routing_value")
if routing_value is None:
raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
self._routing = routing_value.lower()
@@ -87,14 +87,15 @@ class LindormVectorStore(BaseVector):
"_id": uuids[i],
}
}
action_values = {
action_values: dict[str, Any] = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
}
if self._using_ugc:
action_header["index"]["routing"] = self._routing
action_values[self._routing_field] = self._routing
if self._routing_field is not None:
action_values[self._routing_field] = self._routing
actions.append(action_header)
actions.append(action_values)
response = self._client.bulk(actions)
@@ -105,7 +106,9 @@ class LindormVectorStore(BaseVector):
self.refresh()
def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}}
query: dict[str, Any] = {
"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}
}
if self._using_ugc:
query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}})
response = self._client.search(index=self._collection_name, body=query)
@@ -191,7 +194,8 @@ class LindormVectorStore(BaseVector):
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
if score > score_threshold:
doc.metadata["score"] = score
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
return docs
@@ -366,6 +370,7 @@ def default_text_search_query(
routing_field: Optional[str] = None,
**kwargs,
) -> dict:
query_clause: dict[str, Any] = {}
if routing is not None:
query_clause = {
"bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]}
@@ -386,7 +391,7 @@ def default_text_search_query(
else:
must = [query_clause]
boolean_query = {"must": must}
boolean_query: dict[str, Any] = {"must": must}
if must_not:
if not isinstance(must_not, list):
@@ -426,7 +431,7 @@ def default_vector_search_query(
filter_type = "post_filter" if filter_type is None else filter_type
if not isinstance(filters, list):
raise RuntimeError(f"unexpected filter with {type(filters)}")
final_ext = {"lvector": {}}
final_ext: dict[str, Any] = {"lvector": {}}
if min_score != "0.0":
final_ext["lvector"]["min_score"] = min_score
if ef_search:
@@ -438,7 +443,7 @@ def default_vector_search_query(
if client_refactor:
final_ext["lvector"]["client_refactor"] = client_refactor
search_query = {
search_query: dict[str, Any] = {
"size": k,
"_source": True, # force return '_source'
"query": {"knn": {vector_field: {"vector": query_vector, "k": k}}},
@@ -446,8 +451,8 @@ def default_vector_search_query(
if filters is not None:
# when using filter, transform filter from List[Dict] to Dict as valid format
filters = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
search_query["query"]["knn"][vector_field]["filter"] = filters # filter should be Dict
filter_dict = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
search_query["query"]["knn"][vector_field]["filter"] = filter_dict # filter should be Dict
if filter_type:
final_ext["lvector"]["filter_type"] = filter_type
@@ -459,17 +464,19 @@ def default_vector_search_query(
class LindormVectorStoreFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
lindorm_config = LindormVectorStoreConfig(
hosts=dify_config.LINDORM_URL,
hosts=dify_config.LINDORM_URL or "",
username=dify_config.LINDORM_USERNAME,
password=dify_config.LINDORM_PASSWORD,
using_ugc=dify_config.USING_UGC_INDEX,
)
using_ugc = dify_config.USING_UGC_INDEX
if using_ugc is None:
raise ValueError("USING_UGC_INDEX is not set")
routing_value = None
if dataset.index_struct:
# if an existed record's index_struct_dict doesn't contain using_ugc field,
# it actually stores in the normal index format
stored_in_ugc = dataset.index_struct_dict.get("using_ugc", False)
stored_in_ugc: bool = dataset.index_struct_dict.get("using_ugc", False)
using_ugc = stored_in_ugc
if stored_in_ugc:
dimension = dataset.index_struct_dict["dimension"]
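Several hunks in this file annotate freshly built dicts as dict[str, Any]. Without the annotation, mypy infers the value type from the first literal and rejects later assignments of a different type. A minimal sketch with hypothetical values:

    from typing import Any

    # Inferred as dict[str, list[str]] without the annotation, which
    # would make the tuple assignment below a type error.
    params: dict[str, Any] = {"hosts": ["http://localhost:9200"]}
    params["http_auth"] = ("user", "password")

The `filters` to `filter_dict` rename in default_vector_search_query serves the same goal: rebinding an existing name to a value of a different type confuses mypy, so the commit introduces a fresh variable instead.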

View File

@@ -3,8 +3,8 @@ import logging
from typing import Any, Optional
from pydantic import BaseModel, model_validator
from pymilvus import MilvusClient, MilvusException
from pymilvus.milvus_client import IndexParams
from pymilvus import MilvusClient, MilvusException # type: ignore
from pymilvus.milvus_client import IndexParams # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@@ -54,14 +54,14 @@ class MilvusVector(BaseVector):
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = "Session"
self._fields = []
self._fields: list[str] = []
def get_type(self) -> str:
return VectorType.MILVUS
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
metadatas = [d.metadata for d in texts]
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)
@@ -161,8 +161,8 @@ class MilvusVector(BaseVector):
return
# Grab the existing collection if it exists
if not self._client.has_collection(self._collection_name):
from pymilvus import CollectionSchema, DataType, FieldSchema
from pymilvus.orm.types import infer_dtype_bydata
from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore
from pymilvus.orm.types import infer_dtype_bydata # type: ignore
# Determine embedding dim
dim = len(embeddings[0])
@@ -217,10 +217,10 @@ class MilvusVectorFactory(AbstractVectorFactory):
return MilvusVector(
collection_name=collection_name,
config=MilvusConfig(
uri=dify_config.MILVUS_URI,
token=dify_config.MILVUS_TOKEN,
user=dify_config.MILVUS_USER,
password=dify_config.MILVUS_PASSWORD,
database=dify_config.MILVUS_DATABASE,
uri=dify_config.MILVUS_URI or "",
token=dify_config.MILVUS_TOKEN or "",
user=dify_config.MILVUS_USER or "",
password=dify_config.MILVUS_PASSWORD or "",
database=dify_config.MILVUS_DATABASE or "",
),
)
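The factory changes here, and in most of the other factories, use `value or ""` to narrow an Optional[str] config field to plain str. A sketch with a hypothetical stand-in for the config field; note that this treats an explicitly empty value the same as an unset one:

    from typing import Optional

    MILVUS_URI: Optional[str] = None  # stand-in for the dify_config field

    # `x or ""` collapses both None and "" to "", giving mypy a str.
    uri: str = MILVUS_URI or ""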

View File

@@ -74,15 +74,16 @@ class MyScaleVector(BaseVector):
columns = ["id", "text", "vector", "metadata"]
values = []
for i, doc in enumerate(documents):
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
row = (
doc_id,
self.escape_str(doc.page_content),
embeddings[i],
json.dumps(doc.metadata) if doc.metadata else {},
)
values.append(str(row))
ids.append(doc_id)
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
row = (
doc_id,
self.escape_str(doc.page_content),
embeddings[i],
json.dumps(doc.metadata) if doc.metadata else {},
)
values.append(str(row))
ids.append(doc_id)
sql = f"""
INSERT INTO {self._config.database}.{self._collection_name}
({",".join(columns)}) VALUES {",".join(values)}

View File

@@ -4,7 +4,7 @@ import math
from typing import Any
from pydantic import BaseModel, model_validator
from pyobvector import VECTOR, ObVecClient
from pyobvector import VECTOR, ObVecClient # type: ignore
from sqlalchemy import JSON, Column, String, func
from sqlalchemy.dialects.mysql import LONGTEXT
@@ -131,7 +131,7 @@ class OceanBaseVector(BaseVector):
def text_exists(self, id: str) -> bool:
cur = self._client.get(table_name=self._collection_name, id=id)
return cur.rowcount != 0
return bool(cur.rowcount != 0)
def delete_by_ids(self, ids: list[str]) -> None:
self._client.delete(table_name=self._collection_name, ids=ids)

View File

@@ -66,7 +66,7 @@ class OpenSearchVector(BaseVector):
return VectorType.OPENSEARCH
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
metadatas = [d.metadata for d in texts]
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas)
self.add_texts(texts, embeddings)
@@ -244,7 +244,7 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
open_search_config = OpenSearchConfig(
host=dify_config.OPENSEARCH_HOST,
host=dify_config.OPENSEARCH_HOST or "localhost",
port=dify_config.OPENSEARCH_PORT,
user=dify_config.OPENSEARCH_USER,
password=dify_config.OPENSEARCH_PASSWORD,

View File

@@ -5,7 +5,7 @@ import uuid
from contextlib import contextmanager
from typing import Any
import jieba.posseg as pseg
import jieba.posseg as pseg # type: ignore
import numpy
import oracledb
from pydantic import BaseModel, model_validator
@@ -88,12 +88,11 @@ class OracleVector(BaseVector):
def numpy_converter_out(self, value):
if value.typecode == "b":
dtype = numpy.int8
return numpy.array(value, copy=False, dtype=numpy.int8)
elif value.typecode == "f":
dtype = numpy.float32
return numpy.array(value, copy=False, dtype=numpy.float32)
else:
dtype = numpy.float64
return numpy.array(value, copy=False, dtype=dtype)
return numpy.array(value, copy=False, dtype=numpy.float64)
def output_type_handler(self, cursor, metadata):
if metadata.type_code is oracledb.DB_TYPE_VECTOR:
@@ -135,17 +134,18 @@ class OracleVector(BaseVector):
values = []
pks = []
for i, doc in enumerate(documents):
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
# array.array("f", embeddings[i]),
numpy.array(embeddings[i]),
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
# array.array("f", embeddings[i]),
numpy.array(embeddings[i]),
)
)
)
# print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
with self._get_cursor() as cur:
cur.executemany(
@@ -201,8 +201,8 @@ class OracleVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# lazy import
import nltk
from nltk.corpus import stopwords
import nltk # type: ignore
from nltk.corpus import stopwords # type: ignore
top_k = kwargs.get("top_k", 5)
# just not implement fetch by score_threshold now, may be later
@@ -285,10 +285,10 @@ class OracleVectorFactory(AbstractVectorFactory):
return OracleVector(
collection_name=collection_name,
config=OracleVectorConfig(
host=dify_config.ORACLE_HOST,
host=dify_config.ORACLE_HOST or "localhost",
port=dify_config.ORACLE_PORT,
user=dify_config.ORACLE_USER,
password=dify_config.ORACLE_PASSWORD,
database=dify_config.ORACLE_DATABASE,
user=dify_config.ORACLE_USER or "system",
password=dify_config.ORACLE_PASSWORD or "oracle",
database=dify_config.ORACLE_DATABASE or "orcl",
),
)

View File

@@ -4,7 +4,7 @@ from typing import Any
from uuid import UUID, uuid4
from numpy import ndarray
from pgvecto_rs.sqlalchemy import VECTOR
from pgvecto_rs.sqlalchemy import VECTOR # type: ignore
from pydantic import BaseModel, model_validator
from sqlalchemy import Float, String, create_engine, insert, select, text
from sqlalchemy import text as sql_text
@@ -58,7 +58,7 @@ class PGVectoRS(BaseVector):
with Session(self._client) as session:
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
session.commit()
self._fields = []
self._fields: list[str] = []
class _Table(CollectionORM):
__tablename__ = collection_name
@@ -222,11 +222,11 @@ class PGVectoRSFactory(AbstractVectorFactory):
return PGVectoRS(
collection_name=collection_name,
config=PgvectoRSConfig(
host=dify_config.PGVECTO_RS_HOST,
port=dify_config.PGVECTO_RS_PORT,
user=dify_config.PGVECTO_RS_USER,
password=dify_config.PGVECTO_RS_PASSWORD,
database=dify_config.PGVECTO_RS_DATABASE,
host=dify_config.PGVECTO_RS_HOST or "localhost",
port=dify_config.PGVECTO_RS_PORT or 5432,
user=dify_config.PGVECTO_RS_USER or "postgres",
password=dify_config.PGVECTO_RS_PASSWORD or "",
database=dify_config.PGVECTO_RS_DATABASE or "postgres",
),
dim=dim,
)

View File

@@ -3,8 +3,8 @@ import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras
import psycopg2.pool
import psycopg2.extras # type: ignore
import psycopg2.pool # type: ignore
from pydantic import BaseModel, model_validator
from configs import dify_config
@@ -98,16 +98,17 @@ class PGVector(BaseVector):
values = []
pks = []
for i, doc in enumerate(documents):
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
embeddings[i],
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
embeddings[i],
)
)
)
with self._get_cursor() as cur:
psycopg2.extras.execute_values(
cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
@@ -216,11 +217,11 @@ class PGVectorFactory(AbstractVectorFactory):
return PGVector(
collection_name=collection_name,
config=PGVectorConfig(
host=dify_config.PGVECTOR_HOST,
host=dify_config.PGVECTOR_HOST or "localhost",
port=dify_config.PGVECTOR_PORT,
user=dify_config.PGVECTOR_USER,
password=dify_config.PGVECTOR_PASSWORD,
database=dify_config.PGVECTOR_DATABASE,
user=dify_config.PGVECTOR_USER or "postgres",
password=dify_config.PGVECTOR_PASSWORD or "",
database=dify_config.PGVECTOR_DATABASE or "postgres",
min_connection=dify_config.PGVECTOR_MIN_CONNECTION,
max_connection=dify_config.PGVECTOR_MAX_CONNECTION,
),

View File

@@ -51,6 +51,8 @@ class QdrantConfig(BaseModel):
if self.endpoint and self.endpoint.startswith("path:"):
path = self.endpoint.replace("path:", "")
if not os.path.isabs(path):
if not self.root_path:
raise ValueError("Root path is not set")
path = os.path.join(self.root_path, path)
return {"path": path}
@@ -149,9 +151,12 @@ class QdrantVector(BaseVector):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
added_ids = []
for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
# Filter out None values from metadatas list to match expected type
filtered_metadatas = [m for m in metadatas if m is not None]
for batch_ids, points in self._generate_rest_batches(
texts, embeddings, filtered_metadatas, uuids, 64, self._group_id
):
self._client.upsert(collection_name=self._collection_name, points=points)
added_ids.extend(batch_ids)
@@ -194,7 +199,7 @@ class QdrantVector(BaseVector):
batch_metadatas,
Field.CONTENT_KEY.value,
Field.METADATA_KEY.value,
group_id,
group_id or "", # Ensure group_id is never None
Field.GROUP_KEY.value,
),
)
@@ -337,18 +342,20 @@ class QdrantVector(BaseVector):
)
docs = []
for result in results:
if result.payload is None:
continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result.score > score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value),
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -432,9 +439,9 @@ class QdrantVectorFactory(AbstractVectorFactory):
collection_name=collection_name,
group_id=dataset.id,
config=QdrantConfig(
endpoint=dify_config.QDRANT_URL,
endpoint=dify_config.QDRANT_URL or "",
api_key=dify_config.QDRANT_API_KEY,
root_path=current_app.config.root_path,
root_path=str(current_app.config.root_path),
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.QDRANT_GRPC_PORT,
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
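The new `if result.payload is None: continue` guard skips hits without a payload instead of calling .get on None. A standalone sketch of the same shape, with hypothetical hit data:

    hits = [{"payload": None}, {"payload": {"metadata": {"doc_id": "1"}}}]
    docs = []
    for hit in hits:
        payload = hit["payload"]
        if payload is None:
            continue  # skip rather than dereference a missing payload
        docs.append(payload.get("metadata") or {})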

View File

@@ -3,7 +3,7 @@ import uuid
from typing import Any, Optional
from pydantic import BaseModel, model_validator
from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
from sqlalchemy import Column, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.dialects.postgresql import JSON, TEXT
from sqlalchemy.orm import Session
@@ -58,14 +58,14 @@ class RelytVector(BaseVector):
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
)
self.client = create_engine(self._url)
self._fields = []
self._fields: list[str] = []
self._group_id = group_id
def get_type(self) -> str:
return VectorType.RELYT
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_params = {}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> None:
index_params: dict[str, Any] = {}
metadatas = [d.metadata for d in texts]
self.create_collection(len(embeddings[0]))
self.embedding_dimension = len(embeddings[0])
@@ -107,10 +107,10 @@ class RelytVector(BaseVector):
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
from pgvecto_rs.sqlalchemy import VECTOR
from pgvecto_rs.sqlalchemy import VECTOR # type: ignore
ids = [str(uuid.uuid1()) for _ in documents]
metadatas = [d.metadata for d in documents]
metadatas = [d.metadata for d in documents if d.metadata is not None]
for metadata in metadatas:
metadata["group_id"] = self._group_id
texts = [d.page_content for d in documents]
@@ -242,10 +242,6 @@ class RelytVector(BaseVector):
filter: Optional[dict] = None,
) -> list[tuple[Document, float]]:
# Add the filter if provided
try:
from sqlalchemy.engine import Row
except ImportError:
raise ImportError("Could not import Row from sqlalchemy.engine. Please 'pip install sqlalchemy>=1.4'.")
filter_condition = ""
if filter is not None:
@@ -275,7 +271,7 @@ class RelytVector(BaseVector):
# Execute the query and fetch the results
with self.client.connect() as conn:
results: Sequence[Row] = conn.execute(sql_text(sql_query), params).fetchall()
results = conn.execute(sql_text(sql_query), params).fetchall()
documents_with_scores = [
(
@@ -307,11 +303,11 @@ class RelytVectorFactory(AbstractVectorFactory):
return RelytVector(
collection_name=collection_name,
config=RelytConfig(
host=dify_config.RELYT_HOST,
host=dify_config.RELYT_HOST or "localhost",
port=dify_config.RELYT_PORT,
user=dify_config.RELYT_USER,
password=dify_config.RELYT_PASSWORD,
database=dify_config.RELYT_DATABASE,
user=dify_config.RELYT_USER or "",
password=dify_config.RELYT_PASSWORD or "",
database=dify_config.RELYT_DATABASE or "default",
),
group_id=dataset.id,
)

View File

@@ -2,10 +2,10 @@ import json
from typing import Any, Optional
from pydantic import BaseModel
from tcvectordb import VectorDBClient
from tcvectordb.model import document, enum
from tcvectordb.model import index as vdb_index
from tcvectordb.model.document import Filter
from tcvectordb import VectorDBClient # type: ignore
from tcvectordb.model import document, enum # type: ignore
from tcvectordb.model import index as vdb_index # type: ignore
from tcvectordb.model.document import Filter # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -25,8 +25,8 @@ class TencentConfig(BaseModel):
database: Optional[str]
index_type: str = "HNSW"
metric_type: str = "L2"
shard: int = (1,)
replicas: int = (2,)
shard: int = 1
replicas: int = 2
def to_tencent_params(self):
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
@@ -120,15 +120,15 @@ class TencentVector(BaseVector):
metadatas = [doc.metadata for doc in documents]
total_count = len(embeddings)
docs = []
for id in range(0, total_count):
for i in range(0, total_count):
if metadatas is None:
continue
metadata = json.dumps(metadatas[id])
metadata = metadatas[i] or {}
doc = document.Document(
id=metadatas[id]["doc_id"],
vector=embeddings[id],
text=texts[id],
metadata=metadata,
id=metadata.get("doc_id"),
vector=embeddings[i],
text=texts[i],
metadata=json.dumps(metadata),
)
docs.append(doc)
self._db.collection(self._collection_name).upsert(docs, self._client_config.timeout)
@@ -159,8 +159,8 @@ class TencentVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return []
def _get_search_res(self, res, score_threshold):
docs = []
def _get_search_res(self, res: list | None, score_threshold: float) -> list[Document]:
docs: list[Document] = []
if res is None or len(res) == 0:
return docs
@@ -193,7 +193,7 @@ class TencentVectorFactory(AbstractVectorFactory):
return TencentVector(
collection_name=collection_name,
config=TencentConfig(
url=dify_config.TENCENT_VECTOR_DB_URL,
url=dify_config.TENCENT_VECTOR_DB_URL or "",
api_key=dify_config.TENCENT_VECTOR_DB_API_KEY,
timeout=dify_config.TENCENT_VECTOR_DB_TIMEOUT,
username=dify_config.TENCENT_VECTOR_DB_USERNAME,

View File

@@ -54,7 +54,10 @@ class TidbOnQdrantConfig(BaseModel):
if self.endpoint and self.endpoint.startswith("path:"):
path = self.endpoint.replace("path:", "")
if not os.path.isabs(path):
path = os.path.join(self.root_path, path)
if self.root_path:
path = os.path.join(self.root_path, path)
else:
raise ValueError("root_path is required")
return {"path": path}
else:
@@ -157,7 +160,7 @@ class TidbOnQdrantVector(BaseVector):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
metadatas = [d.metadata for d in documents if d.metadata is not None]
added_ids = []
for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
@@ -203,7 +206,7 @@ class TidbOnQdrantVector(BaseVector):
batch_metadatas,
Field.CONTENT_KEY.value,
Field.METADATA_KEY.value,
group_id,
group_id or "",
Field.GROUP_KEY.value,
),
)
@@ -334,18 +337,20 @@ class TidbOnQdrantVector(BaseVector):
)
docs = []
for result in results:
if result.payload is None:
continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold") or 0.0
if result.score > score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value),
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -427,12 +432,12 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
else:
new_cluster = TidbService.create_tidb_serverless_cluster(
dify_config.TIDB_PROJECT_ID,
dify_config.TIDB_API_URL,
dify_config.TIDB_IAM_API_URL,
dify_config.TIDB_PUBLIC_KEY,
dify_config.TIDB_PRIVATE_KEY,
dify_config.TIDB_REGION,
dify_config.TIDB_PROJECT_ID or "",
dify_config.TIDB_API_URL or "",
dify_config.TIDB_IAM_API_URL or "",
dify_config.TIDB_PUBLIC_KEY or "",
dify_config.TIDB_PRIVATE_KEY or "",
dify_config.TIDB_REGION or "",
)
new_tidb_auth_binding = TidbAuthBinding(
cluster_id=new_cluster["cluster_id"],
@@ -464,9 +469,9 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
collection_name=collection_name,
group_id=dataset.id,
config=TidbOnQdrantConfig(
endpoint=dify_config.TIDB_ON_QDRANT_URL,
endpoint=dify_config.TIDB_ON_QDRANT_URL or "",
api_key=TIDB_ON_QDRANT_API_KEY,
root_path=config.root_path,
root_path=str(config.root_path),
timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT,
prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED,

View File

@@ -146,7 +146,7 @@ class TidbService:
iam_url: str,
public_key: str,
private_key: str,
) -> list[dict]:
):
"""
Update the status of a new TiDB Serverless cluster.
:param project_id: The project ID of the TiDB Cloud project (required).
@@ -159,7 +159,6 @@ class TidbService:
:return: The response from the API.
"""
clusters = []
tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
cluster_ids = [item.cluster_id for item in tidb_serverless_list]
params = {"clusterIds": cluster_ids, "view": "BASIC"}
@@ -169,7 +168,6 @@ class TidbService:
if response.status_code == 200:
response_data = response.json()
cluster_infos = []
for item in response_data["clusters"]:
state = item["state"]
userPrefix = item["userPrefix"]
@@ -236,16 +234,17 @@ class TidbService:
cluster_infos = []
for item in response_data["clusters"]:
cache_key = f"tidb_serverless_cluster_password:{item['displayName']}"
password = redis_client.get(cache_key)
if not password:
cached_password = redis_client.get(cache_key)
if not cached_password:
continue
cluster_info = {
"cluster_id": item["clusterId"],
"cluster_name": item["displayName"],
"account": "root",
"password": password.decode("utf-8"),
"password": cached_password.decode("utf-8"),
}
cluster_infos.append(cluster_info)
return cluster_infos
else:
response.raise_for_status()
return [] # FIXME for mypy, This line will not be reached as raise_for_status() will raise an exception
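The trailing `return []` exists purely for the type checker, and the FIXME is fair: raise_for_status() raises only for 4xx/5xx responses, so the line is genuinely reachable for other non-200 codes. A sketch of the shape, with a hypothetical endpoint helper:

    import requests

    def list_clusters(url: str) -> list[dict]:
        response = requests.get(url, timeout=10)
        if response.status_code == 200:
            return list(response.json().get("clusters", []))
        response.raise_for_status()
        # mypy cannot prove the call above always raises (it does not,
        # e.g. for 3xx), so an explicit return keeps the signature sound.
        return []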

View File

@@ -49,7 +49,7 @@ class TiDBVector(BaseVector):
return VectorType.TIDB_VECTOR
def _table(self, dim: int) -> Table:
from tidb_vector.sqlalchemy import VectorType
from tidb_vector.sqlalchemy import VectorType # type: ignore
return Table(
self._collection_name,
@@ -241,11 +241,11 @@ class TiDBVectorFactory(AbstractVectorFactory):
return TiDBVector(
collection_name=collection_name,
config=TiDBVectorConfig(
host=dify_config.TIDB_VECTOR_HOST,
port=dify_config.TIDB_VECTOR_PORT,
user=dify_config.TIDB_VECTOR_USER,
password=dify_config.TIDB_VECTOR_PASSWORD,
database=dify_config.TIDB_VECTOR_DATABASE,
host=dify_config.TIDB_VECTOR_HOST or "",
port=dify_config.TIDB_VECTOR_PORT or 0,
user=dify_config.TIDB_VECTOR_USER or "",
password=dify_config.TIDB_VECTOR_PASSWORD or "",
database=dify_config.TIDB_VECTOR_DATABASE or "",
program_name=dify_config.APPLICATION_NAME,
),
)

View File

@@ -51,15 +51,16 @@ class BaseVector(ABC):
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts.copy():
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
if text.metadata and "doc_id" in text.metadata:
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
return texts
def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata["doc_id"] for text in texts]
return [text.metadata["doc_id"] for text in texts if text.metadata and "doc_id" in text.metadata]
@property
def collection_name(self):
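The guards above combine a truthiness check on metadata with a key-membership test; both are needed because metadata may be None and, even when present, may lack doc_id. A sketch on a hypothetical Doc type:

    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class Doc:
        metadata: Optional[dict[str, Any]] = None

    def get_uuids(texts: list[Doc]) -> list[str]:
        # Both conditions keep the comprehension total and typed str:
        # metadata may be None, and "doc_id" may be absent.
        return [t.metadata["doc_id"] for t in texts if t.metadata and "doc_id" in t.metadata]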

View File

@@ -193,10 +193,13 @@ class Vector:
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts.copy():
if text.metadata is None:
continue
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
if doc_id:
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
return texts

View File

@@ -2,7 +2,7 @@ import json
from typing import Any
from pydantic import BaseModel
from volcengine.viking_db import (
from volcengine.viking_db import ( # type: ignore
Data,
DistanceType,
Field,
@@ -121,11 +121,12 @@ class VikingDBVector(BaseVector):
for i, page_content in enumerate(page_contents):
metadata = {}
if metadatas is not None:
for key, val in metadatas[i].items():
for key, val in (metadatas[i] or {}).items():
metadata[key] = val
# FIXME: fix the type of metadata later
doc = Data(
{
vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"],
vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], # type: ignore
vdb_Field.VECTOR.value: embeddings[i] if embeddings else None,
vdb_Field.CONTENT_KEY.value: page_content,
vdb_Field.METADATA_KEY.value: json.dumps(metadata),
@@ -178,7 +179,7 @@ class VikingDBVector(BaseVector):
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(results, score_threshold)
def _get_search_res(self, results, score_threshold):
def _get_search_res(self, results, score_threshold) -> list[Document]:
if len(results) == 0:
return []
@@ -191,7 +192,7 @@ class VikingDBVector(BaseVector):
metadata["score"] = result.score
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:

View File

@@ -3,7 +3,7 @@ import json
from typing import Any, Optional
import requests
import weaviate
import weaviate # type: ignore
from pydantic import BaseModel, model_validator
from configs import dify_config
@@ -107,7 +107,8 @@ class WeaviateVector(BaseVector):
for i, text in enumerate(texts):
data_properties = {Field.TEXT_KEY.value: text}
if metadatas is not None:
for key, val in metadatas[i].items():
# metadata maybe None
for key, val in (metadatas[i] or {}).items():
data_properties[key] = self._json_serializable(val)
batch.add_data_object(
@@ -208,10 +209,11 @@ class WeaviateVector(BaseVector):
score_threshold = float(kwargs.get("score_threshold") or 0.0)
# check score threshold
if score > score_threshold:
doc.metadata["score"] = score
docs.append(doc)
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -275,7 +277,7 @@ class WeaviateVectorFactory(AbstractVectorFactory):
return WeaviateVector(
collection_name=collection_name,
config=WeaviateConfig(
endpoint=dify_config.WEAVIATE_ENDPOINT,
endpoint=dify_config.WEAVIATE_ENDPOINT or "",
api_key=dify_config.WEAVIATE_API_KEY,
batch_size=dify_config.WEAVIATE_BATCH_SIZE,
),
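The `(metadatas[i] or {})` idiom above turns iteration over possibly-None metadata into a no-op instead of a crash. A minimal runnable sketch:

    from typing import Any, Optional

    metadatas: list[Optional[dict[str, Any]]] = [{"doc_id": "1"}, None]
    for i, text in enumerate(["a", "b"]):
        # When the entry is None, `or {}` yields an empty iteration.
        for key, val in (metadatas[i] or {}).items():
            print(i, key, val)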