feat: mypy for all type check (#10921)
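This commit enables mypy type checking across the vector-database integrations. The diff below repeats a small set of fixes in every backend: targeted "# type: ignore" on untyped third-party imports, explicit annotations for empty containers, Optional defaults instead of accidental tuples like (None,), "or ''" fallbacks when passing possibly-unset config values into typed constructors, and "is not None" guards before touching doc.metadata. As a rough, self-contained sketch of those patterns (ExampleConfig and ExampleVector are illustrative names, not part of the commit):

from typing import Any, Optional

from pydantic import BaseModel

# Third-party clients without type stubs are imported with a targeted
# ignore in the real diff, e.g. "import psycopg2.extras  # type: ignore".


class ExampleConfig(BaseModel):
    host: str
    # Optional settings get a real Optional annotation instead of a broken
    # default such as "namespace_password: str = (None,)".
    password: Optional[str] = None


class ExampleVector:
    def __init__(self, collection_name: str, config: ExampleConfig) -> None:
        self._collection_name = collection_name
        self._config = config
        # Empty containers get explicit annotations so mypy knows their
        # element types.
        self._fields: list[str] = []

    def add_texts(self, documents: list[dict[str, Any]], embeddings: list[list[float]]) -> None:
        for doc, _embedding in zip(documents, embeddings, strict=True):
            metadata: Optional[dict[str, Any]] = doc.get("metadata")
            # Optional metadata is guarded before it is read or written.
            if metadata is not None:
                metadata.setdefault("doc_id", "unknown")


# Factory code coalesces possibly-unset settings into concrete values,
# mirroring the "dify_config.SOME_SETTING or ''" changes below.
raw_host: Optional[str] = None
config = ExampleConfig(host=raw_host or "", password=None)
vector = ExampleVector("example_collection", config)
vector.add_texts([{"metadata": {"doc_id": "1"}}], [[0.1, 0.2]])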
@@ -17,12 +17,19 @@ from models.dataset import Dataset
class AnalyticdbVector(BaseVector):
    def __init__(
        self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig
        self,
        collection_name: str,
        api_config: AnalyticdbVectorOpenAPIConfig | None,
        sql_config: AnalyticdbVectorBySqlConfig | None,
    ):
        super().__init__(collection_name)
        if api_config is not None:
            self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config)
            self.analyticdb_vector: AnalyticdbVectorOpenAPI | AnalyticdbVectorBySql = AnalyticdbVectorOpenAPI(
                collection_name, api_config
            )
        else:
            if sql_config is None:
                raise ValueError("Either api_config or sql_config must be provided")
            self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)

    def get_type(self) -> str:
@@ -33,8 +40,8 @@ class AnalyticdbVector(BaseVector):
        self.analyticdb_vector._create_collection_if_not_exists(dimension)
        self.analyticdb_vector.add_texts(texts, embeddings)

    def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        self.analyticdb_vector.add_texts(texts, embeddings)
    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        self.analyticdb_vector.add_texts(documents, embeddings)

    def text_exists(self, id: str) -> bool:
        return self.analyticdb_vector.text_exists(id)
@@ -68,13 +75,13 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
        if dify_config.ANALYTICDB_HOST is None:
            # implemented through OpenAPI
            apiConfig = AnalyticdbVectorOpenAPIConfig(
                access_key_id=dify_config.ANALYTICDB_KEY_ID,
                access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
                region_id=dify_config.ANALYTICDB_REGION_ID,
                instance_id=dify_config.ANALYTICDB_INSTANCE_ID,
                account=dify_config.ANALYTICDB_ACCOUNT,
                account_password=dify_config.ANALYTICDB_PASSWORD,
                namespace=dify_config.ANALYTICDB_NAMESPACE,
                access_key_id=dify_config.ANALYTICDB_KEY_ID or "",
                access_key_secret=dify_config.ANALYTICDB_KEY_SECRET or "",
                region_id=dify_config.ANALYTICDB_REGION_ID or "",
                instance_id=dify_config.ANALYTICDB_INSTANCE_ID or "",
                account=dify_config.ANALYTICDB_ACCOUNT or "",
                account_password=dify_config.ANALYTICDB_PASSWORD or "",
                namespace=dify_config.ANALYTICDB_NAMESPACE or "",
                namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
            )
            sqlConfig = None
@@ -83,11 +90,11 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
            sqlConfig = AnalyticdbVectorBySqlConfig(
                host=dify_config.ANALYTICDB_HOST,
                port=dify_config.ANALYTICDB_PORT,
                account=dify_config.ANALYTICDB_ACCOUNT,
                account_password=dify_config.ANALYTICDB_PASSWORD,
                account=dify_config.ANALYTICDB_ACCOUNT or "",
                account_password=dify_config.ANALYTICDB_PASSWORD or "",
                min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
                max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
                namespace=dify_config.ANALYTICDB_NAMESPACE,
                namespace=dify_config.ANALYTICDB_NAMESPACE or "",
            )
            apiConfig = None
        return AnalyticdbVector(
@@ -1,5 +1,5 @@
import json
from typing import Any
from typing import Any, Optional

from pydantic import BaseModel, model_validator

@@ -20,7 +20,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
    account: str
    account_password: str
    namespace: str = "dify"
    namespace_password: str = (None,)
    namespace_password: Optional[str] = None
    metrics: str = "cosine"
    read_timeout: int = 60000

@@ -55,8 +55,8 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
class AnalyticdbVectorOpenAPI:
    def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
        try:
            from alibabacloud_gpdb20160503.client import Client
            from alibabacloud_tea_openapi import models as open_api_models
            from alibabacloud_gpdb20160503.client import Client  # type: ignore
            from alibabacloud_tea_openapi import models as open_api_models  # type: ignore
        except:
            raise ImportError(_import_err_msg)
        self._collection_name = collection_name.lower()
@@ -77,7 +77,7 @@ class AnalyticdbVectorOpenAPI:
            redis_client.set(database_exist_cache_key, 1, ex=3600)

    def _initialize_vector_database(self) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models  # type: ignore

        request = gpdb_20160503_models.InitVectorDatabaseRequest(
            dbinstance_id=self.config.instance_id,
@@ -89,7 +89,7 @@ class AnalyticdbVectorOpenAPI:

    def _create_namespace_if_not_exists(self) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        from Tea.exceptions import TeaException
        from Tea.exceptions import TeaException  # type: ignore

        try:
            request = gpdb_20160503_models.DescribeNamespaceRequest(
@@ -159,17 +159,18 @@ class AnalyticdbVectorOpenAPI:

        rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
        for doc, embedding in zip(documents, embeddings, strict=True):
            metadata = {
                "ref_doc_id": doc.metadata["doc_id"],
                "page_content": doc.page_content,
                "metadata_": json.dumps(doc.metadata),
            }
            rows.append(
                gpdb_20160503_models.UpsertCollectionDataRequestRows(
                    vector=embedding,
                    metadata=metadata,
            if doc.metadata is not None:
                metadata = {
                    "ref_doc_id": doc.metadata["doc_id"],
                    "page_content": doc.page_content,
                    "metadata_": json.dumps(doc.metadata),
                }
                rows.append(
                    gpdb_20160503_models.UpsertCollectionDataRequestRows(
                        vector=embedding,
                        metadata=metadata,
                    )
                )
            )
        request = gpdb_20160503_models.UpsertCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
@@ -258,7 +259,7 @@ class AnalyticdbVectorOpenAPI:
                    metadata=metadata,
                )
                documents.append(doc)
        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
        documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
        return documents

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -290,7 +291,7 @@ class AnalyticdbVectorOpenAPI:
                    metadata=metadata,
                )
                documents.append(doc)
        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
        documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
        return documents

    def delete(self) -> None:
@@ -3,8 +3,8 @@ import uuid
from contextlib import contextmanager
from typing import Any

import psycopg2.extras
import psycopg2.pool
import psycopg2.extras  # type: ignore
import psycopg2.pool  # type: ignore
from pydantic import BaseModel, model_validator

from core.rag.models.document import Document
@@ -75,6 +75,7 @@ class AnalyticdbVectorBySql:

    @contextmanager
    def _get_cursor(self):
        assert self.pool is not None, "Connection pool is not initialized"
        conn = self.pool.getconn()
        cur = conn.cursor()
        try:
@@ -156,16 +157,17 @@ class AnalyticdbVectorBySql:
            VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
        """
        for i, doc in enumerate(documents):
            values.append(
                (
                    id_prefix + str(i),
                    doc.metadata.get("doc_id", str(uuid.uuid4())),
                    embeddings[i],
                    doc.page_content,
                    json.dumps(doc.metadata),
                    doc.page_content,
            if doc.metadata is not None:
                values.append(
                    (
                        id_prefix + str(i),
                        doc.metadata.get("doc_id", str(uuid.uuid4())),
                        embeddings[i],
                        doc.page_content,
                        json.dumps(doc.metadata),
                        doc.page_content,
                    )
                )
            )
        with self._get_cursor() as cur:
            psycopg2.extras.execute_batch(cur, sql, values)
@@ -5,13 +5,13 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pymochow import MochowClient
|
||||
from pymochow.auth.bce_credentials import BceCredentials
|
||||
from pymochow.configuration import Configuration
|
||||
from pymochow.exception import ServerError
|
||||
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState
|
||||
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
|
||||
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row
|
||||
from pymochow import MochowClient # type: ignore
|
||||
from pymochow.auth.bce_credentials import BceCredentials # type: ignore
|
||||
from pymochow.configuration import Configuration # type: ignore
|
||||
from pymochow.exception import ServerError # type: ignore
|
||||
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore
|
||||
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex # type: ignore
|
||||
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
@@ -75,7 +75,7 @@ class BaiduVector(BaseVector):
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
texts = [doc.page_content for doc in documents]
|
||||
metadatas = [doc.metadata for doc in documents]
|
||||
metadatas = [doc.metadata for doc in documents if doc.metadata is not None]
|
||||
total_count = len(documents)
|
||||
batch_size = 1000
|
||||
|
||||
@@ -84,6 +84,8 @@ class BaiduVector(BaseVector):
|
||||
for start in range(0, total_count, batch_size):
|
||||
end = min(start + batch_size, total_count)
|
||||
rows = []
|
||||
assert len(metadatas) == total_count, "metadatas length should be equal to total_count"
|
||||
# FIXME do you need this assert?
|
||||
for i in range(start, end, 1):
|
||||
row = Row(
|
||||
id=metadatas[i].get("doc_id", str(uuid.uuid4())),
|
||||
@@ -136,7 +138,7 @@ class BaiduVector(BaseVector):
|
||||
# baidu vector database doesn't support bm25 search on current version
|
||||
return []
|
||||
|
||||
def _get_search_res(self, res, score_threshold):
|
||||
def _get_search_res(self, res, score_threshold) -> list[Document]:
|
||||
docs = []
|
||||
for row in res.rows:
|
||||
row_data = row.get("row", {})
|
||||
@@ -276,11 +278,11 @@ class BaiduVectorFactory(AbstractVectorFactory):
|
||||
return BaiduVector(
|
||||
collection_name=collection_name,
|
||||
config=BaiduConfig(
|
||||
endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT,
|
||||
endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT or "",
|
||||
connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS,
|
||||
account=dify_config.BAIDU_VECTOR_DB_ACCOUNT,
|
||||
api_key=dify_config.BAIDU_VECTOR_DB_API_KEY,
|
||||
database=dify_config.BAIDU_VECTOR_DB_DATABASE,
|
||||
account=dify_config.BAIDU_VECTOR_DB_ACCOUNT or "",
|
||||
api_key=dify_config.BAIDU_VECTOR_DB_API_KEY or "",
|
||||
database=dify_config.BAIDU_VECTOR_DB_DATABASE or "",
|
||||
shard=dify_config.BAIDU_VECTOR_DB_SHARD,
|
||||
replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
|
||||
),
|
||||
|
||||
@@ -71,11 +71,13 @@ class ChromaVector(BaseVector):
|
||||
metadatas = [d.metadata for d in documents]
|
||||
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas)
|
||||
# FIXME: chromadb using numpy array, fix the type error later
|
||||
collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
collection.delete(where={key: {"$eq": value}})
|
||||
# FIXME: fix the type error later
|
||||
collection.delete(where={key: {"$eq": value}}) # type: ignore
|
||||
|
||||
def delete(self):
|
||||
self._client.delete_collection(self._collection_name)
|
||||
@@ -94,15 +96,19 @@ class ChromaVector(BaseVector):
|
||||
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
|
||||
ids: list[str] = results["ids"][0]
|
||||
documents: list[str] = results["documents"][0]
|
||||
metadatas: dict[str, Any] = results["metadatas"][0]
|
||||
distances: list[float] = results["distances"][0]
|
||||
# Check if results contain data
|
||||
if not results["ids"] or not results["documents"] or not results["metadatas"] or not results["distances"]:
|
||||
return []
|
||||
|
||||
ids = results["ids"][0]
|
||||
documents = results["documents"][0]
|
||||
metadatas = results["metadatas"][0]
|
||||
distances = results["distances"][0]
|
||||
|
||||
docs = []
|
||||
for index in range(len(ids)):
|
||||
distance = distances[index]
|
||||
metadata = metadatas[index]
|
||||
metadata = dict(metadatas[index])
|
||||
if distance >= score_threshold:
|
||||
metadata["score"] = distance
|
||||
doc = Document(
|
||||
@@ -111,7 +117,7 @@ class ChromaVector(BaseVector):
|
||||
)
|
||||
docs.append(doc)
|
||||
# Sort the documents by score in descending order
|
||||
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
|
||||
docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
@@ -133,7 +139,7 @@ class ChromaVectorFactory(AbstractVectorFactory):
|
||||
return ChromaVector(
|
||||
collection_name=collection_name,
|
||||
config=ChromaConfig(
|
||||
host=dify_config.CHROMA_HOST,
|
||||
host=dify_config.CHROMA_HOST or "",
|
||||
port=dify_config.CHROMA_PORT,
|
||||
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
|
||||
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,
|
||||
|
||||
@@ -5,14 +5,14 @@ import uuid
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from couchbase import search
|
||||
from couchbase.auth import PasswordAuthenticator
|
||||
from couchbase.cluster import Cluster
|
||||
from couchbase.management.search import SearchIndex
|
||||
from couchbase import search # type: ignore
|
||||
from couchbase.auth import PasswordAuthenticator # type: ignore
|
||||
from couchbase.cluster import Cluster # type: ignore
|
||||
from couchbase.management.search import SearchIndex # type: ignore
|
||||
|
||||
# needed for options -- cluster, timeout, SQL++ (N1QL) query, etc.
|
||||
from couchbase.options import ClusterOptions, SearchOptions
|
||||
from couchbase.vector_search import VectorQuery, VectorSearch
|
||||
from couchbase.options import ClusterOptions, SearchOptions # type: ignore
|
||||
from couchbase.vector_search import VectorQuery, VectorSearch # type: ignore
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
@@ -231,7 +231,7 @@ class CouchbaseVector(BaseVector):
|
||||
# Pass the id as a parameter to the query
|
||||
result = self._cluster.query(query, named_parameters={"doc_id": id}).execute()
|
||||
for row in result:
|
||||
return row["count"] > 0
|
||||
return bool(row["count"] > 0)
|
||||
return False # Return False if no rows are returned
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
@@ -369,10 +369,10 @@ class CouchbaseVectorFactory(AbstractVectorFactory):
|
||||
return CouchbaseVector(
|
||||
collection_name=collection_name,
|
||||
config=CouchbaseConfig(
|
||||
connection_string=config.get("COUCHBASE_CONNECTION_STRING"),
|
||||
user=config.get("COUCHBASE_USER"),
|
||||
password=config.get("COUCHBASE_PASSWORD"),
|
||||
bucket_name=config.get("COUCHBASE_BUCKET_NAME"),
|
||||
scope_name=config.get("COUCHBASE_SCOPE_NAME"),
|
||||
connection_string=config.get("COUCHBASE_CONNECTION_STRING", ""),
|
||||
user=config.get("COUCHBASE_USER", ""),
|
||||
password=config.get("COUCHBASE_PASSWORD", ""),
|
||||
bucket_name=config.get("COUCHBASE_BUCKET_NAME", ""),
|
||||
scope_name=config.get("COUCHBASE_SCOPE_NAME", ""),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
@@ -70,7 +70,7 @@ class ElasticSearchVector(BaseVector):
|
||||
|
||||
def _get_version(self) -> str:
|
||||
info = self._client.info()
|
||||
return info["version"]["number"]
|
||||
return cast(str, info["version"]["number"])
|
||||
|
||||
def _check_version(self):
|
||||
if self._version < "8.0.0":
|
||||
@@ -135,7 +135,8 @@ class ElasticSearchVector(BaseVector):
|
||||
for doc, score in docs_and_scores:
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if score > score_threshold:
|
||||
doc.metadata["score"] = score
|
||||
if doc.metadata is not None:
|
||||
doc.metadata["score"] = score
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
@@ -156,12 +157,15 @@ class ElasticSearchVector(BaseVector):
|
||||
return docs
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
metadatas = [d.metadata for d in texts]
|
||||
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
|
||||
self.create_collection(embeddings, metadatas)
|
||||
self.add_texts(texts, embeddings, **kwargs)
|
||||
|
||||
def create_collection(
|
||||
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
||||
self,
|
||||
embeddings: list[list[float]],
|
||||
metadatas: Optional[list[dict[Any, Any]]] = None,
|
||||
index_params: Optional[dict] = None,
|
||||
):
|
||||
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
@@ -208,10 +212,10 @@ class ElasticSearchVectorFactory(AbstractVectorFactory):
|
||||
return ElasticSearchVector(
|
||||
index_name=collection_name,
|
||||
config=ElasticSearchConfig(
|
||||
host=config.get("ELASTICSEARCH_HOST"),
|
||||
port=config.get("ELASTICSEARCH_PORT"),
|
||||
username=config.get("ELASTICSEARCH_USERNAME"),
|
||||
password=config.get("ELASTICSEARCH_PASSWORD"),
|
||||
host=config.get("ELASTICSEARCH_HOST", "localhost"),
|
||||
port=config.get("ELASTICSEARCH_PORT", 9200),
|
||||
username=config.get("ELASTICSEARCH_USERNAME", ""),
|
||||
password=config.get("ELASTICSEARCH_PASSWORD", ""),
|
||||
),
|
||||
attributes=[],
|
||||
)
|
||||
|
||||
@@ -42,7 +42,7 @@ class LindormVectorStoreConfig(BaseModel):
|
||||
return values
|
||||
|
||||
def to_opensearch_params(self) -> dict[str, Any]:
|
||||
params = {"hosts": self.hosts}
|
||||
params: dict[str, Any] = {"hosts": self.hosts}
|
||||
if self.username and self.password:
|
||||
params["http_auth"] = (self.username, self.password)
|
||||
return params
|
||||
@@ -53,7 +53,7 @@ class LindormVectorStore(BaseVector):
|
||||
self._routing = None
|
||||
self._routing_field = None
|
||||
if using_ugc:
|
||||
routing_value: str = kwargs.get("routing_value")
|
||||
routing_value: str | None = kwargs.get("routing_value")
|
||||
if routing_value is None:
|
||||
raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
|
||||
self._routing = routing_value.lower()
|
||||
@@ -87,14 +87,15 @@ class LindormVectorStore(BaseVector):
|
||||
"_id": uuids[i],
|
||||
}
|
||||
}
|
||||
action_values = {
|
||||
action_values: dict[str, Any] = {
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
|
||||
Field.METADATA_KEY.value: documents[i].metadata,
|
||||
}
|
||||
if self._using_ugc:
|
||||
action_header["index"]["routing"] = self._routing
|
||||
action_values[self._routing_field] = self._routing
|
||||
if self._routing_field is not None:
|
||||
action_values[self._routing_field] = self._routing
|
||||
actions.append(action_header)
|
||||
actions.append(action_values)
|
||||
response = self._client.bulk(actions)
|
||||
@@ -105,7 +106,9 @@ class LindormVectorStore(BaseVector):
|
||||
self.refresh()
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
query = {"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}}
|
||||
query: dict[str, Any] = {
|
||||
"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}
|
||||
}
|
||||
if self._using_ugc:
|
||||
query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}})
|
||||
response = self._client.search(index=self._collection_name, body=query)
|
||||
@@ -191,7 +194,8 @@ class LindormVectorStore(BaseVector):
|
||||
for doc, score in docs_and_scores:
|
||||
score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
|
||||
if score > score_threshold:
|
||||
doc.metadata["score"] = score
|
||||
if doc.metadata is not None:
|
||||
doc.metadata["score"] = score
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
@@ -366,6 +370,7 @@ def default_text_search_query(
|
||||
routing_field: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
query_clause: dict[str, Any] = {}
|
||||
if routing is not None:
|
||||
query_clause = {
|
||||
"bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]}
|
||||
@@ -386,7 +391,7 @@ def default_text_search_query(
|
||||
else:
|
||||
must = [query_clause]
|
||||
|
||||
boolean_query = {"must": must}
|
||||
boolean_query: dict[str, Any] = {"must": must}
|
||||
|
||||
if must_not:
|
||||
if not isinstance(must_not, list):
|
||||
@@ -426,7 +431,7 @@ def default_vector_search_query(
|
||||
filter_type = "post_filter" if filter_type is None else filter_type
|
||||
if not isinstance(filters, list):
|
||||
raise RuntimeError(f"unexpected filter with {type(filters)}")
|
||||
final_ext = {"lvector": {}}
|
||||
final_ext: dict[str, Any] = {"lvector": {}}
|
||||
if min_score != "0.0":
|
||||
final_ext["lvector"]["min_score"] = min_score
|
||||
if ef_search:
|
||||
@@ -438,7 +443,7 @@ def default_vector_search_query(
|
||||
if client_refactor:
|
||||
final_ext["lvector"]["client_refactor"] = client_refactor
|
||||
|
||||
search_query = {
|
||||
search_query: dict[str, Any] = {
|
||||
"size": k,
|
||||
"_source": True, # force return '_source'
|
||||
"query": {"knn": {vector_field: {"vector": query_vector, "k": k}}},
|
||||
@@ -446,8 +451,8 @@ def default_vector_search_query(
|
||||
|
||||
if filters is not None:
|
||||
# when using filter, transform filter from List[Dict] to Dict as valid format
|
||||
filters = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
|
||||
search_query["query"]["knn"][vector_field]["filter"] = filters # filter should be Dict
|
||||
filter_dict = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
|
||||
search_query["query"]["knn"][vector_field]["filter"] = filter_dict # filter should be Dict
|
||||
if filter_type:
|
||||
final_ext["lvector"]["filter_type"] = filter_type
|
||||
|
||||
@@ -459,17 +464,19 @@ def default_vector_search_query(
|
||||
class LindormVectorStoreFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
|
||||
lindorm_config = LindormVectorStoreConfig(
|
||||
hosts=dify_config.LINDORM_URL,
|
||||
hosts=dify_config.LINDORM_URL or "",
|
||||
username=dify_config.LINDORM_USERNAME,
|
||||
password=dify_config.LINDORM_PASSWORD,
|
||||
using_ugc=dify_config.USING_UGC_INDEX,
|
||||
)
|
||||
using_ugc = dify_config.USING_UGC_INDEX
|
||||
if using_ugc is None:
|
||||
raise ValueError("USING_UGC_INDEX is not set")
|
||||
routing_value = None
|
||||
if dataset.index_struct:
|
||||
# if an existed record's index_struct_dict doesn't contain using_ugc field,
|
||||
# it actually stores in the normal index format
|
||||
stored_in_ugc = dataset.index_struct_dict.get("using_ugc", False)
|
||||
stored_in_ugc: bool = dataset.index_struct_dict.get("using_ugc", False)
|
||||
using_ugc = stored_in_ugc
|
||||
if stored_in_ugc:
|
||||
dimension = dataset.index_struct_dict["dimension"]
|
||||
|
||||
@@ -3,8 +3,8 @@ import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pymilvus import MilvusClient, MilvusException
|
||||
from pymilvus.milvus_client import IndexParams
|
||||
from pymilvus import MilvusClient, MilvusException # type: ignore
|
||||
from pymilvus.milvus_client import IndexParams # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
@@ -54,14 +54,14 @@ class MilvusVector(BaseVector):
|
||||
self._client_config = config
|
||||
self._client = self._init_client(config)
|
||||
self._consistency_level = "Session"
|
||||
self._fields = []
|
||||
self._fields: list[str] = []
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.MILVUS
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
|
||||
metadatas = [d.metadata for d in texts]
|
||||
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
|
||||
self.create_collection(embeddings, metadatas, index_params)
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
@@ -161,8 +161,8 @@ class MilvusVector(BaseVector):
|
||||
return
|
||||
# Grab the existing collection if it exists
|
||||
if not self._client.has_collection(self._collection_name):
|
||||
from pymilvus import CollectionSchema, DataType, FieldSchema
|
||||
from pymilvus.orm.types import infer_dtype_bydata
|
||||
from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore
|
||||
from pymilvus.orm.types import infer_dtype_bydata # type: ignore
|
||||
|
||||
# Determine embedding dim
|
||||
dim = len(embeddings[0])
|
||||
@@ -217,10 +217,10 @@ class MilvusVectorFactory(AbstractVectorFactory):
|
||||
return MilvusVector(
|
||||
collection_name=collection_name,
|
||||
config=MilvusConfig(
|
||||
uri=dify_config.MILVUS_URI,
|
||||
token=dify_config.MILVUS_TOKEN,
|
||||
user=dify_config.MILVUS_USER,
|
||||
password=dify_config.MILVUS_PASSWORD,
|
||||
database=dify_config.MILVUS_DATABASE,
|
||||
uri=dify_config.MILVUS_URI or "",
|
||||
token=dify_config.MILVUS_TOKEN or "",
|
||||
user=dify_config.MILVUS_USER or "",
|
||||
password=dify_config.MILVUS_PASSWORD or "",
|
||||
database=dify_config.MILVUS_DATABASE or "",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -74,15 +74,16 @@ class MyScaleVector(BaseVector):
|
||||
columns = ["id", "text", "vector", "metadata"]
|
||||
values = []
|
||||
for i, doc in enumerate(documents):
|
||||
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
|
||||
row = (
|
||||
doc_id,
|
||||
self.escape_str(doc.page_content),
|
||||
embeddings[i],
|
||||
json.dumps(doc.metadata) if doc.metadata else {},
|
||||
)
|
||||
values.append(str(row))
|
||||
ids.append(doc_id)
|
||||
if doc.metadata is not None:
|
||||
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
|
||||
row = (
|
||||
doc_id,
|
||||
self.escape_str(doc.page_content),
|
||||
embeddings[i],
|
||||
json.dumps(doc.metadata) if doc.metadata else {},
|
||||
)
|
||||
values.append(str(row))
|
||||
ids.append(doc_id)
|
||||
sql = f"""
|
||||
INSERT INTO {self._config.database}.{self._collection_name}
|
||||
({",".join(columns)}) VALUES {",".join(values)}
|
||||
|
||||
@@ -4,7 +4,7 @@ import math
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pyobvector import VECTOR, ObVecClient
|
||||
from pyobvector import VECTOR, ObVecClient # type: ignore
|
||||
from sqlalchemy import JSON, Column, String, func
|
||||
from sqlalchemy.dialects.mysql import LONGTEXT
|
||||
|
||||
@@ -131,7 +131,7 @@ class OceanBaseVector(BaseVector):
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
cur = self._client.get(table_name=self._collection_name, id=id)
|
||||
return cur.rowcount != 0
|
||||
return bool(cur.rowcount != 0)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
self._client.delete(table_name=self._collection_name, ids=ids)
|
||||
|
||||
@@ -66,7 +66,7 @@ class OpenSearchVector(BaseVector):
|
||||
return VectorType.OPENSEARCH
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
metadatas = [d.metadata for d in texts]
|
||||
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
|
||||
self.create_collection(embeddings, metadatas)
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
@@ -244,7 +244,7 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
|
||||
|
||||
open_search_config = OpenSearchConfig(
|
||||
host=dify_config.OPENSEARCH_HOST,
|
||||
host=dify_config.OPENSEARCH_HOST or "localhost",
|
||||
port=dify_config.OPENSEARCH_PORT,
|
||||
user=dify_config.OPENSEARCH_USER,
|
||||
password=dify_config.OPENSEARCH_PASSWORD,
|
||||
|
||||
@@ -5,7 +5,7 @@ import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import jieba.posseg as pseg
|
||||
import jieba.posseg as pseg # type: ignore
|
||||
import numpy
|
||||
import oracledb
|
||||
from pydantic import BaseModel, model_validator
|
||||
@@ -88,12 +88,11 @@ class OracleVector(BaseVector):
|
||||
|
||||
def numpy_converter_out(self, value):
|
||||
if value.typecode == "b":
|
||||
dtype = numpy.int8
|
||||
return numpy.array(value, copy=False, dtype=numpy.int8)
|
||||
elif value.typecode == "f":
|
||||
dtype = numpy.float32
|
||||
return numpy.array(value, copy=False, dtype=numpy.float32)
|
||||
else:
|
||||
dtype = numpy.float64
|
||||
return numpy.array(value, copy=False, dtype=dtype)
|
||||
return numpy.array(value, copy=False, dtype=numpy.float64)
|
||||
|
||||
def output_type_handler(self, cursor, metadata):
|
||||
if metadata.type_code is oracledb.DB_TYPE_VECTOR:
|
||||
@@ -135,17 +134,18 @@ class OracleVector(BaseVector):
|
||||
values = []
|
||||
pks = []
|
||||
for i, doc in enumerate(documents):
|
||||
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
|
||||
pks.append(doc_id)
|
||||
values.append(
|
||||
(
|
||||
doc_id,
|
||||
doc.page_content,
|
||||
json.dumps(doc.metadata),
|
||||
# array.array("f", embeddings[i]),
|
||||
numpy.array(embeddings[i]),
|
||||
if doc.metadata is not None:
|
||||
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
|
||||
pks.append(doc_id)
|
||||
values.append(
|
||||
(
|
||||
doc_id,
|
||||
doc.page_content,
|
||||
json.dumps(doc.metadata),
|
||||
# array.array("f", embeddings[i]),
|
||||
numpy.array(embeddings[i]),
|
||||
)
|
||||
)
|
||||
)
|
||||
# print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
|
||||
with self._get_cursor() as cur:
|
||||
cur.executemany(
|
||||
@@ -201,8 +201,8 @@ class OracleVector(BaseVector):
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
# lazy import
|
||||
import nltk
|
||||
from nltk.corpus import stopwords
|
||||
import nltk # type: ignore
|
||||
from nltk.corpus import stopwords # type: ignore
|
||||
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
# just not implement fetch by score_threshold now, may be later
|
||||
@@ -285,10 +285,10 @@ class OracleVectorFactory(AbstractVectorFactory):
|
||||
return OracleVector(
|
||||
collection_name=collection_name,
|
||||
config=OracleVectorConfig(
|
||||
host=dify_config.ORACLE_HOST,
|
||||
host=dify_config.ORACLE_HOST or "localhost",
|
||||
port=dify_config.ORACLE_PORT,
|
||||
user=dify_config.ORACLE_USER,
|
||||
password=dify_config.ORACLE_PASSWORD,
|
||||
database=dify_config.ORACLE_DATABASE,
|
||||
user=dify_config.ORACLE_USER or "system",
|
||||
password=dify_config.ORACLE_PASSWORD or "oracle",
|
||||
database=dify_config.ORACLE_DATABASE or "orcl",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from numpy import ndarray
|
||||
from pgvecto_rs.sqlalchemy import VECTOR
|
||||
from pgvecto_rs.sqlalchemy import VECTOR # type: ignore
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import Float, String, create_engine, insert, select, text
|
||||
from sqlalchemy import text as sql_text
|
||||
@@ -58,7 +58,7 @@ class PGVectoRS(BaseVector):
|
||||
with Session(self._client) as session:
|
||||
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
|
||||
session.commit()
|
||||
self._fields = []
|
||||
self._fields: list[str] = []
|
||||
|
||||
class _Table(CollectionORM):
|
||||
__tablename__ = collection_name
|
||||
@@ -222,11 +222,11 @@ class PGVectoRSFactory(AbstractVectorFactory):
|
||||
return PGVectoRS(
|
||||
collection_name=collection_name,
|
||||
config=PgvectoRSConfig(
|
||||
host=dify_config.PGVECTO_RS_HOST,
|
||||
port=dify_config.PGVECTO_RS_PORT,
|
||||
user=dify_config.PGVECTO_RS_USER,
|
||||
password=dify_config.PGVECTO_RS_PASSWORD,
|
||||
database=dify_config.PGVECTO_RS_DATABASE,
|
||||
host=dify_config.PGVECTO_RS_HOST or "localhost",
|
||||
port=dify_config.PGVECTO_RS_PORT or 5432,
|
||||
user=dify_config.PGVECTO_RS_USER or "postgres",
|
||||
password=dify_config.PGVECTO_RS_PASSWORD or "",
|
||||
database=dify_config.PGVECTO_RS_DATABASE or "postgres",
|
||||
),
|
||||
dim=dim,
|
||||
)
|
||||
|
||||
@@ -3,8 +3,8 @@ import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import psycopg2.extras
|
||||
import psycopg2.pool
|
||||
import psycopg2.extras # type: ignore
|
||||
import psycopg2.pool # type: ignore
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
@@ -98,16 +98,17 @@ class PGVector(BaseVector):
|
||||
values = []
|
||||
pks = []
|
||||
for i, doc in enumerate(documents):
|
||||
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
|
||||
pks.append(doc_id)
|
||||
values.append(
|
||||
(
|
||||
doc_id,
|
||||
doc.page_content,
|
||||
json.dumps(doc.metadata),
|
||||
embeddings[i],
|
||||
if doc.metadata is not None:
|
||||
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
|
||||
pks.append(doc_id)
|
||||
values.append(
|
||||
(
|
||||
doc_id,
|
||||
doc.page_content,
|
||||
json.dumps(doc.metadata),
|
||||
embeddings[i],
|
||||
)
|
||||
)
|
||||
)
|
||||
with self._get_cursor() as cur:
|
||||
psycopg2.extras.execute_values(
|
||||
cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
|
||||
@@ -216,11 +217,11 @@ class PGVectorFactory(AbstractVectorFactory):
|
||||
return PGVector(
|
||||
collection_name=collection_name,
|
||||
config=PGVectorConfig(
|
||||
host=dify_config.PGVECTOR_HOST,
|
||||
host=dify_config.PGVECTOR_HOST or "localhost",
|
||||
port=dify_config.PGVECTOR_PORT,
|
||||
user=dify_config.PGVECTOR_USER,
|
||||
password=dify_config.PGVECTOR_PASSWORD,
|
||||
database=dify_config.PGVECTOR_DATABASE,
|
||||
user=dify_config.PGVECTOR_USER or "postgres",
|
||||
password=dify_config.PGVECTOR_PASSWORD or "",
|
||||
database=dify_config.PGVECTOR_DATABASE or "postgres",
|
||||
min_connection=dify_config.PGVECTOR_MIN_CONNECTION,
|
||||
max_connection=dify_config.PGVECTOR_MAX_CONNECTION,
|
||||
),
|
||||
|
||||
@@ -51,6 +51,8 @@ class QdrantConfig(BaseModel):
|
||||
if self.endpoint and self.endpoint.startswith("path:"):
|
||||
path = self.endpoint.replace("path:", "")
|
||||
if not os.path.isabs(path):
|
||||
if not self.root_path:
|
||||
raise ValueError("Root path is not set")
|
||||
path = os.path.join(self.root_path, path)
|
||||
|
||||
return {"path": path}
|
||||
@@ -149,9 +151,12 @@ class QdrantVector(BaseVector):
|
||||
uuids = self._get_uuids(documents)
|
||||
texts = [d.page_content for d in documents]
|
||||
metadatas = [d.metadata for d in documents]
|
||||
|
||||
added_ids = []
|
||||
for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
|
||||
# Filter out None values from metadatas list to match expected type
|
||||
filtered_metadatas = [m for m in metadatas if m is not None]
|
||||
for batch_ids, points in self._generate_rest_batches(
|
||||
texts, embeddings, filtered_metadatas, uuids, 64, self._group_id
|
||||
):
|
||||
self._client.upsert(collection_name=self._collection_name, points=points)
|
||||
added_ids.extend(batch_ids)
|
||||
|
||||
@@ -194,7 +199,7 @@ class QdrantVector(BaseVector):
|
||||
batch_metadatas,
|
||||
Field.CONTENT_KEY.value,
|
||||
Field.METADATA_KEY.value,
|
||||
group_id,
|
||||
group_id or "", # Ensure group_id is never None
|
||||
Field.GROUP_KEY.value,
|
||||
),
|
||||
)
|
||||
@@ -337,18 +342,20 @@ class QdrantVector(BaseVector):
|
||||
)
|
||||
docs = []
|
||||
for result in results:
|
||||
if result.payload is None:
|
||||
continue
|
||||
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
|
||||
# duplicate check score threshold
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if result.score > score_threshold:
|
||||
metadata["score"] = result.score
|
||||
doc = Document(
|
||||
page_content=result.payload.get(Field.CONTENT_KEY.value),
|
||||
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
|
||||
metadata=metadata,
|
||||
)
|
||||
docs.append(doc)
|
||||
# Sort the documents by score in descending order
|
||||
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
|
||||
docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
@@ -432,9 +439,9 @@ class QdrantVectorFactory(AbstractVectorFactory):
|
||||
collection_name=collection_name,
|
||||
group_id=dataset.id,
|
||||
config=QdrantConfig(
|
||||
endpoint=dify_config.QDRANT_URL,
|
||||
endpoint=dify_config.QDRANT_URL or "",
|
||||
api_key=dify_config.QDRANT_API_KEY,
|
||||
root_path=current_app.config.root_path,
|
||||
root_path=str(current_app.config.root_path),
|
||||
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
|
||||
grpc_port=dify_config.QDRANT_GRPC_PORT,
|
||||
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
|
||||
|
||||
@@ -3,7 +3,7 @@ import uuid
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
|
||||
from sqlalchemy import Column, String, Table, create_engine, insert
|
||||
from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.dialects.postgresql import JSON, TEXT
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -58,14 +58,14 @@ class RelytVector(BaseVector):
|
||||
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
|
||||
)
|
||||
self.client = create_engine(self._url)
|
||||
self._fields = []
|
||||
self._fields: list[str] = []
|
||||
self._group_id = group_id
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.RELYT
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
index_params = {}
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> None:
|
||||
index_params: dict[str, Any] = {}
|
||||
metadatas = [d.metadata for d in texts]
|
||||
self.create_collection(len(embeddings[0]))
|
||||
self.embedding_dimension = len(embeddings[0])
|
||||
@@ -107,10 +107,10 @@ class RelytVector(BaseVector):
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
from pgvecto_rs.sqlalchemy import VECTOR
|
||||
from pgvecto_rs.sqlalchemy import VECTOR # type: ignore
|
||||
|
||||
ids = [str(uuid.uuid1()) for _ in documents]
|
||||
metadatas = [d.metadata for d in documents]
|
||||
metadatas = [d.metadata for d in documents if d.metadata is not None]
|
||||
for metadata in metadatas:
|
||||
metadata["group_id"] = self._group_id
|
||||
texts = [d.page_content for d in documents]
|
||||
@@ -242,10 +242,6 @@ class RelytVector(BaseVector):
|
||||
filter: Optional[dict] = None,
|
||||
) -> list[tuple[Document, float]]:
|
||||
# Add the filter if provided
|
||||
try:
|
||||
from sqlalchemy.engine import Row
|
||||
except ImportError:
|
||||
raise ImportError("Could not import Row from sqlalchemy.engine. Please 'pip install sqlalchemy>=1.4'.")
|
||||
|
||||
filter_condition = ""
|
||||
if filter is not None:
|
||||
@@ -275,7 +271,7 @@ class RelytVector(BaseVector):
|
||||
|
||||
# Execute the query and fetch the results
|
||||
with self.client.connect() as conn:
|
||||
results: Sequence[Row] = conn.execute(sql_text(sql_query), params).fetchall()
|
||||
results = conn.execute(sql_text(sql_query), params).fetchall()
|
||||
|
||||
documents_with_scores = [
|
||||
(
|
||||
@@ -307,11 +303,11 @@ class RelytVectorFactory(AbstractVectorFactory):
|
||||
return RelytVector(
|
||||
collection_name=collection_name,
|
||||
config=RelytConfig(
|
||||
host=dify_config.RELYT_HOST,
|
||||
host=dify_config.RELYT_HOST or "localhost",
|
||||
port=dify_config.RELYT_PORT,
|
||||
user=dify_config.RELYT_USER,
|
||||
password=dify_config.RELYT_PASSWORD,
|
||||
database=dify_config.RELYT_DATABASE,
|
||||
user=dify_config.RELYT_USER or "",
|
||||
password=dify_config.RELYT_PASSWORD or "",
|
||||
database=dify_config.RELYT_DATABASE or "default",
|
||||
),
|
||||
group_id=dataset.id,
|
||||
)
|
||||
|
||||
@@ -2,10 +2,10 @@ import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from tcvectordb import VectorDBClient
|
||||
from tcvectordb.model import document, enum
|
||||
from tcvectordb.model import index as vdb_index
|
||||
from tcvectordb.model.document import Filter
|
||||
from tcvectordb import VectorDBClient # type: ignore
|
||||
from tcvectordb.model import document, enum # type: ignore
|
||||
from tcvectordb.model import index as vdb_index # type: ignore
|
||||
from tcvectordb.model.document import Filter # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
@@ -25,8 +25,8 @@ class TencentConfig(BaseModel):
|
||||
database: Optional[str]
|
||||
index_type: str = "HNSW"
|
||||
metric_type: str = "L2"
|
||||
shard: int = (1,)
|
||||
replicas: int = (2,)
|
||||
shard: int = 1
|
||||
replicas: int = 2
|
||||
|
||||
def to_tencent_params(self):
|
||||
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
|
||||
@@ -120,15 +120,15 @@ class TencentVector(BaseVector):
|
||||
metadatas = [doc.metadata for doc in documents]
|
||||
total_count = len(embeddings)
|
||||
docs = []
|
||||
for id in range(0, total_count):
|
||||
for i in range(0, total_count):
|
||||
if metadatas is None:
|
||||
continue
|
||||
metadata = json.dumps(metadatas[id])
|
||||
metadata = metadatas[i] or {}
|
||||
doc = document.Document(
|
||||
id=metadatas[id]["doc_id"],
|
||||
vector=embeddings[id],
|
||||
text=texts[id],
|
||||
metadata=metadata,
|
||||
id=metadata.get("doc_id"),
|
||||
vector=embeddings[i],
|
||||
text=texts[i],
|
||||
metadata=json.dumps(metadata),
|
||||
)
|
||||
docs.append(doc)
|
||||
self._db.collection(self._collection_name).upsert(docs, self._client_config.timeout)
|
||||
@@ -159,8 +159,8 @@ class TencentVector(BaseVector):
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return []
|
||||
|
||||
def _get_search_res(self, res, score_threshold):
|
||||
docs = []
|
||||
def _get_search_res(self, res: list | None, score_threshold: float) -> list[Document]:
|
||||
docs: list[Document] = []
|
||||
if res is None or len(res) == 0:
|
||||
return docs
|
||||
|
||||
@@ -193,7 +193,7 @@ class TencentVectorFactory(AbstractVectorFactory):
|
||||
return TencentVector(
|
||||
collection_name=collection_name,
|
||||
config=TencentConfig(
|
||||
url=dify_config.TENCENT_VECTOR_DB_URL,
|
||||
url=dify_config.TENCENT_VECTOR_DB_URL or "",
|
||||
api_key=dify_config.TENCENT_VECTOR_DB_API_KEY,
|
||||
timeout=dify_config.TENCENT_VECTOR_DB_TIMEOUT,
|
||||
username=dify_config.TENCENT_VECTOR_DB_USERNAME,
|
||||
|
||||
@@ -54,7 +54,10 @@ class TidbOnQdrantConfig(BaseModel):
|
||||
if self.endpoint and self.endpoint.startswith("path:"):
|
||||
path = self.endpoint.replace("path:", "")
|
||||
if not os.path.isabs(path):
|
||||
path = os.path.join(self.root_path, path)
|
||||
if self.root_path:
|
||||
path = os.path.join(self.root_path, path)
|
||||
else:
|
||||
raise ValueError("root_path is required")
|
||||
|
||||
return {"path": path}
|
||||
else:
|
||||
@@ -157,7 +160,7 @@ class TidbOnQdrantVector(BaseVector):
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
uuids = self._get_uuids(documents)
|
||||
texts = [d.page_content for d in documents]
|
||||
metadatas = [d.metadata for d in documents]
|
||||
metadatas = [d.metadata for d in documents if d.metadata is not None]
|
||||
|
||||
added_ids = []
|
||||
for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
|
||||
@@ -203,7 +206,7 @@ class TidbOnQdrantVector(BaseVector):
|
||||
batch_metadatas,
|
||||
Field.CONTENT_KEY.value,
|
||||
Field.METADATA_KEY.value,
|
||||
group_id,
|
||||
group_id or "",
|
||||
Field.GROUP_KEY.value,
|
||||
),
|
||||
)
|
||||
@@ -334,18 +337,20 @@ class TidbOnQdrantVector(BaseVector):
|
||||
)
|
||||
docs = []
|
||||
for result in results:
|
||||
if result.payload is None:
|
||||
continue
|
||||
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
|
||||
# duplicate check score threshold
|
||||
score_threshold = kwargs.get("score_threshold") or 0.0
|
||||
if result.score > score_threshold:
|
||||
metadata["score"] = result.score
|
||||
doc = Document(
|
||||
page_content=result.payload.get(Field.CONTENT_KEY.value),
|
||||
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
|
||||
metadata=metadata,
|
||||
)
|
||||
docs.append(doc)
|
||||
# Sort the documents by score in descending order
|
||||
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
|
||||
docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
@@ -427,12 +432,12 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
|
||||
else:
|
||||
new_cluster = TidbService.create_tidb_serverless_cluster(
|
||||
dify_config.TIDB_PROJECT_ID,
|
||||
dify_config.TIDB_API_URL,
|
||||
dify_config.TIDB_IAM_API_URL,
|
||||
dify_config.TIDB_PUBLIC_KEY,
|
||||
dify_config.TIDB_PRIVATE_KEY,
|
||||
dify_config.TIDB_REGION,
|
||||
dify_config.TIDB_PROJECT_ID or "",
|
||||
dify_config.TIDB_API_URL or "",
|
||||
dify_config.TIDB_IAM_API_URL or "",
|
||||
dify_config.TIDB_PUBLIC_KEY or "",
|
||||
dify_config.TIDB_PRIVATE_KEY or "",
|
||||
dify_config.TIDB_REGION or "",
|
||||
)
|
||||
new_tidb_auth_binding = TidbAuthBinding(
|
||||
cluster_id=new_cluster["cluster_id"],
|
||||
@@ -464,9 +469,9 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
collection_name=collection_name,
|
||||
group_id=dataset.id,
|
||||
config=TidbOnQdrantConfig(
|
||||
endpoint=dify_config.TIDB_ON_QDRANT_URL,
|
||||
endpoint=dify_config.TIDB_ON_QDRANT_URL or "",
|
||||
api_key=TIDB_ON_QDRANT_API_KEY,
|
||||
root_path=config.root_path,
|
||||
root_path=str(config.root_path),
|
||||
timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT,
|
||||
grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT,
|
||||
prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED,
|
||||
|
||||
@@ -146,7 +146,7 @@ class TidbService:
|
||||
iam_url: str,
|
||||
public_key: str,
|
||||
private_key: str,
|
||||
) -> list[dict]:
|
||||
):
|
||||
"""
|
||||
Update the status of a new TiDB Serverless cluster.
|
||||
:param project_id: The project ID of the TiDB Cloud project (required).
|
||||
@@ -159,7 +159,6 @@ class TidbService:
|
||||
|
||||
:return: The response from the API.
|
||||
"""
|
||||
clusters = []
|
||||
tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
|
||||
cluster_ids = [item.cluster_id for item in tidb_serverless_list]
|
||||
params = {"clusterIds": cluster_ids, "view": "BASIC"}
|
||||
@@ -169,7 +168,6 @@ class TidbService:
|
||||
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
cluster_infos = []
|
||||
for item in response_data["clusters"]:
|
||||
state = item["state"]
|
||||
userPrefix = item["userPrefix"]
|
||||
@@ -236,16 +234,17 @@ class TidbService:
|
||||
cluster_infos = []
|
||||
for item in response_data["clusters"]:
|
||||
cache_key = f"tidb_serverless_cluster_password:{item['displayName']}"
|
||||
password = redis_client.get(cache_key)
|
||||
if not password:
|
||||
cached_password = redis_client.get(cache_key)
|
||||
if not cached_password:
|
||||
continue
|
||||
cluster_info = {
|
||||
"cluster_id": item["clusterId"],
|
||||
"cluster_name": item["displayName"],
|
||||
"account": "root",
|
||||
"password": password.decode("utf-8"),
|
||||
"password": cached_password.decode("utf-8"),
|
||||
}
|
||||
cluster_infos.append(cluster_info)
|
||||
return cluster_infos
|
||||
else:
|
||||
response.raise_for_status()
|
||||
return [] # FIXME for mypy, This line will not be reached as raise_for_status() will raise an exception
|
||||
|
||||
@@ -49,7 +49,7 @@ class TiDBVector(BaseVector):
|
||||
return VectorType.TIDB_VECTOR
|
||||
|
||||
def _table(self, dim: int) -> Table:
|
||||
from tidb_vector.sqlalchemy import VectorType
|
||||
from tidb_vector.sqlalchemy import VectorType # type: ignore
|
||||
|
||||
return Table(
|
||||
self._collection_name,
|
||||
@@ -241,11 +241,11 @@ class TiDBVectorFactory(AbstractVectorFactory):
|
||||
return TiDBVector(
|
||||
collection_name=collection_name,
|
||||
config=TiDBVectorConfig(
|
||||
host=dify_config.TIDB_VECTOR_HOST,
|
||||
port=dify_config.TIDB_VECTOR_PORT,
|
||||
user=dify_config.TIDB_VECTOR_USER,
|
||||
password=dify_config.TIDB_VECTOR_PASSWORD,
|
||||
database=dify_config.TIDB_VECTOR_DATABASE,
|
||||
host=dify_config.TIDB_VECTOR_HOST or "",
|
||||
port=dify_config.TIDB_VECTOR_PORT or 0,
|
||||
user=dify_config.TIDB_VECTOR_USER or "",
|
||||
password=dify_config.TIDB_VECTOR_PASSWORD or "",
|
||||
database=dify_config.TIDB_VECTOR_DATABASE or "",
|
||||
program_name=dify_config.APPLICATION_NAME,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -51,15 +51,16 @@ class BaseVector(ABC):
|
||||
|
||||
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
|
||||
for text in texts.copy():
|
||||
doc_id = text.metadata["doc_id"]
|
||||
exists_duplicate_node = self.text_exists(doc_id)
|
||||
if exists_duplicate_node:
|
||||
texts.remove(text)
|
||||
if text.metadata and "doc_id" in text.metadata:
|
||||
doc_id = text.metadata["doc_id"]
|
||||
exists_duplicate_node = self.text_exists(doc_id)
|
||||
if exists_duplicate_node:
|
||||
texts.remove(text)
|
||||
|
||||
return texts
|
||||
|
||||
def _get_uuids(self, texts: list[Document]) -> list[str]:
|
||||
return [text.metadata["doc_id"] for text in texts]
|
||||
return [text.metadata["doc_id"] for text in texts if text.metadata and "doc_id" in text.metadata]
|
||||
|
||||
@property
|
||||
def collection_name(self):
|
||||
|
||||
@@ -193,10 +193,13 @@ class Vector:
|
||||
|
||||
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
|
||||
for text in texts.copy():
|
||||
if text.metadata is None:
|
||||
continue
|
||||
doc_id = text.metadata["doc_id"]
|
||||
exists_duplicate_node = self.text_exists(doc_id)
|
||||
if exists_duplicate_node:
|
||||
texts.remove(text)
|
||||
if doc_id:
|
||||
exists_duplicate_node = self.text_exists(doc_id)
|
||||
if exists_duplicate_node:
|
||||
texts.remove(text)
|
||||
|
||||
return texts
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from volcengine.viking_db import (
|
||||
from volcengine.viking_db import ( # type: ignore
|
||||
Data,
|
||||
DistanceType,
|
||||
Field,
|
||||
@@ -121,11 +121,12 @@ class VikingDBVector(BaseVector):
|
||||
for i, page_content in enumerate(page_contents):
|
||||
metadata = {}
|
||||
if metadatas is not None:
|
||||
for key, val in metadatas[i].items():
|
||||
for key, val in (metadatas[i] or {}).items():
|
||||
metadata[key] = val
|
||||
# FIXME: fix the type of metadata later
|
||||
doc = Data(
|
||||
{
|
||||
vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"],
|
||||
vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], # type: ignore
|
||||
vdb_Field.VECTOR.value: embeddings[i] if embeddings else None,
|
||||
vdb_Field.CONTENT_KEY.value: page_content,
|
||||
vdb_Field.METADATA_KEY.value: json.dumps(metadata),
|
||||
@@ -178,7 +179,7 @@ class VikingDBVector(BaseVector):
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
return self._get_search_res(results, score_threshold)
|
||||
|
||||
def _get_search_res(self, results, score_threshold):
|
||||
def _get_search_res(self, results, score_threshold) -> list[Document]:
|
||||
if len(results) == 0:
|
||||
return []
|
||||
|
||||
@@ -191,7 +192,7 @@ class VikingDBVector(BaseVector):
|
||||
metadata["score"] = result.score
|
||||
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
|
||||
docs.append(doc)
|
||||
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
|
||||
docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
|
||||
@@ -3,7 +3,7 @@ import json
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
import weaviate
|
||||
import weaviate # type: ignore
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
@@ -107,7 +107,8 @@ class WeaviateVector(BaseVector):
|
||||
for i, text in enumerate(texts):
|
||||
data_properties = {Field.TEXT_KEY.value: text}
|
||||
if metadatas is not None:
|
||||
for key, val in metadatas[i].items():
|
||||
# metadata maybe None
|
||||
for key, val in (metadatas[i] or {}).items():
|
||||
data_properties[key] = self._json_serializable(val)
|
||||
|
||||
batch.add_data_object(
|
||||
@@ -208,10 +209,11 @@ class WeaviateVector(BaseVector):
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
# check score threshold
|
||||
if score > score_threshold:
|
||||
doc.metadata["score"] = score
|
||||
docs.append(doc)
|
||||
if doc.metadata is not None:
|
||||
doc.metadata["score"] = score
|
||||
docs.append(doc)
|
||||
# Sort the documents by score in descending order
|
||||
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
|
||||
docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
@@ -275,7 +277,7 @@ class WeaviateVectorFactory(AbstractVectorFactory):
|
||||
return WeaviateVector(
|
||||
collection_name=collection_name,
|
||||
config=WeaviateConfig(
|
||||
endpoint=dify_config.WEAVIATE_ENDPOINT,
|
||||
endpoint=dify_config.WEAVIATE_ENDPOINT or "",
|
||||
api_key=dify_config.WEAVIATE_API_KEY,
|
||||
batch_size=dify_config.WEAVIATE_BATCH_SIZE,
|
||||
),
|
||||
|
||||