Mirror of http://112.124.100.131/huang.ze/ebiz-dify-ai.git (synced 2025-12-15 22:06:52 +08:00)
chore(api/core): apply ruff reformatting (#7624)
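The diff below is mechanical: it is the output of running ruff's formatter over api/core, which normalizes string literals to double quotes, joins statements that fit within the line limit, expands ones that do not, and adds trailing commas to multi-line calls and literals. The following sketch illustrates the before/after pattern on a self-contained example; the assumed settings (a 120-character line length and the default double-quote style) are inferred from the reformatted lines below, not taken from the repository's configuration.

# Illustrative sketch only: style_before mirrors the pre-format style seen in the
# removed lines (single quotes, manual wrapping); style_after mirrors what
# `ruff format` produces (double quotes, the conditional joined onto one line).
# Assumed config: line-length = 120, quote-style = "double".
def style_before(**kwargs):
    score_threshold = (
        kwargs.get('score_threshold', .0)
        if kwargs.get('score_threshold', .0)
        else 0.0
    )
    return {'threshold': score_threshold}


def style_after(**kwargs):
    score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
    return {"threshold": score_threshold}


# Both spellings are behaviorally identical; only the formatting differs.
assert style_before(score_threshold=0.2) == style_after(score_threshold=0.2)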
@@ -29,6 +29,7 @@ class AnalyticdbConfig(BaseModel):
namespace_password: str = (None,)
metrics: str = ("cosine",)
read_timeout: int = 60000
def to_analyticdb_client_params(self):
return {
"access_key_id": self.access_key_id,

@@ -37,6 +38,7 @@ class AnalyticdbConfig(BaseModel):
"read_timeout": self.read_timeout,
}

class AnalyticdbVector(BaseVector):
_instance = None
_init = False

@@ -57,9 +59,7 @@ class AnalyticdbVector(BaseVector):
except:
raise ImportError(_import_err_msg)
self.config = config
self._client_config = open_api_models.Config(
user_agent="dify", **config.to_analyticdb_client_params()
)
self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
self._client = Client(self._client_config)
self._initialize()
AnalyticdbVector._init = True

@@ -77,6 +77,7 @@ class AnalyticdbVector(BaseVector):
def _initialize_vector_database(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.InitVectorDatabaseRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,

@@ -88,6 +89,7 @@ class AnalyticdbVector(BaseVector):
def _create_namespace_if_not_exists(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
try:
request = gpdb_20160503_models.DescribeNamespaceRequest(
dbinstance_id=self.config.instance_id,

@@ -109,13 +111,12 @@ class AnalyticdbVector(BaseVector):
)
self._client.create_namespace(request)
else:
raise ValueError(
f"failed to create namespace {self.config.namespace}: {e}"
)
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):

@@ -149,9 +150,7 @@ class AnalyticdbVector(BaseVector):
)
self._client.create_collection(request)
else:
raise ValueError(
f"failed to create collection {self._collection_name}: {e}"
)
raise ValueError(f"failed to create collection {self._collection_name}: {e}")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def get_type(self) -> str:

@@ -162,10 +161,9 @@ class AnalyticdbVector(BaseVector):
self._create_collection_if_not_exists(dimension)
self.add_texts(texts, embeddings)
def add_texts(
self, documents: list[Document], embeddings: list[list[float]], **kwargs
):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
for doc, embedding in zip(documents, embeddings, strict=True):
metadata = {

@@ -191,6 +189,7 @@ class AnalyticdbVector(BaseVector):
def text_exists(self, id: str) -> bool:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,

@@ -202,13 +201,14 @@ class AnalyticdbVector(BaseVector):
vector=None,
content=None,
top_k=1,
filter=f"ref_doc_id='{id}'"
filter=f"ref_doc_id='{id}'",
)
response = self._client.query_collection_data(request)
return len(response.body.matches.match) > 0
def delete_by_ids(self, ids: list[str]) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
ids_str = ",".join(f"'{id}'" for id in ids)
ids_str = f"({ids_str})"
request = gpdb_20160503_models.DeleteCollectionDataRequest(

@@ -224,6 +224,7 @@ class AnalyticdbVector(BaseVector):
def delete_by_metadata_field(self, key: str, value: str) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,

@@ -235,15 +236,10 @@ class AnalyticdbVector(BaseVector):
)
self._client.delete_collection_data(request)
def search_by_vector(
self, query_vector: list[float], **kwargs: Any
) -> list[Document]:
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = (
kwargs.get("score_threshold", 0.0)
if kwargs.get("score_threshold", 0.0)
else 0.0
)
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,

@@ -270,11 +266,8 @@ class AnalyticdbVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = (
kwargs.get("score_threshold", 0.0)
if kwargs.get("score_threshold", 0.0)
else 0.0
)
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,

@@ -304,6 +297,7 @@ class AnalyticdbVector(BaseVector):
def delete(self) -> None:
try:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionRequest(
collection=self._collection_name,
dbinstance_id=self.config.instance_id,

@@ -315,19 +309,16 @@ class AnalyticdbVector(BaseVector):
except Exception as e:
raise e

class AnalyticdbVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings):
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"][
"class_prefix"
]
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)
)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name))
# handle optional params
if dify_config.ANALYTICDB_KEY_ID is None:
@@ -27,21 +27,20 @@ class ChromaConfig(BaseModel):
settings = Settings(
# auth
chroma_client_auth_provider=self.auth_provider,
chroma_client_auth_credentials=self.auth_credentials
chroma_client_auth_credentials=self.auth_credentials,
)
return {
'host': self.host,
'port': self.port,
'ssl': False,
'tenant': self.tenant,
'database': self.database,
'settings': settings,
"host": self.host,
"port": self.port,
"ssl": False,
"tenant": self.tenant,
"database": self.database,
"settings": settings,
}

class ChromaVector(BaseVector):
def __init__(self, collection_name: str, config: ChromaConfig):
super().__init__(collection_name)
self._client_config = config

@@ -58,9 +57,9 @@ class ChromaVector(BaseVector):
self.add_texts(texts, embeddings, **kwargs)
def create_collection(self, collection_name: str):
lock_name = 'vector_indexing_lock_{}'.format(collection_name)
lock_name = "vector_indexing_lock_{}".format(collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
self._client.get_or_create_collection(collection_name)

@@ -76,7 +75,7 @@ class ChromaVector(BaseVector):
def delete_by_metadata_field(self, key: str, value: str):
collection = self._client.get_or_create_collection(self._collection_name)
collection.delete(where={key: {'$eq': value}})
collection.delete(where={key: {"$eq": value}})
def delete(self):
self._client.delete_collection(self._collection_name)

@@ -93,26 +92,26 @@ class ChromaVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
collection = self._client.get_or_create_collection(self._collection_name)
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
ids: list[str] = results['ids'][0]
documents: list[str] = results['documents'][0]
metadatas: dict[str, Any] = results['metadatas'][0]
distances: list[float] = results['distances'][0]
ids: list[str] = results["ids"][0]
documents: list[str] = results["documents"][0]
metadatas: dict[str, Any] = results["metadatas"][0]
distances: list[float] = results["distances"][0]
docs = []
for index in range(len(ids)):
distance = distances[index]
metadata = metadatas[index]
if distance >= score_threshold:
metadata['score'] = distance
metadata["score"] = distance
doc = Document(
page_content=documents[index],
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:

@@ -123,15 +122,12 @@ class ChromaVector(BaseVector):
class ChromaVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
index_struct_dict = {
"type": VectorType.CHROMA,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": VectorType.CHROMA, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
return ChromaVector(
@@ -26,15 +26,15 @@ class ElasticSearchConfig(BaseModel):
username: str
password: str
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values['host']:
if not values["host"]:
raise ValueError("config HOST is required")
if not values['port']:
if not values["port"]:
raise ValueError("config PORT is required")
if not values['username']:
if not values["username"]:
raise ValueError("config USERNAME is required")
if not values['password']:
if not values["password"]:
raise ValueError("config PASSWORD is required")
return values

@@ -50,10 +50,10 @@ class ElasticSearchVector(BaseVector):
def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
try:
parsed_url = urlparse(config.host)
if parsed_url.scheme in ['http', 'https']:
hosts = f'{config.host}:{config.port}'
if parsed_url.scheme in ["http", "https"]:
hosts = f"{config.host}:{config.port}"
else:
hosts = f'http://{config.host}:{config.port}'
hosts = f"http://{config.host}:{config.port}"
client = Elasticsearch(
hosts=hosts,
basic_auth=(config.username, config.password),

@@ -68,25 +68,27 @@ class ElasticSearchVector(BaseVector):
def _get_version(self) -> str:
info = self._client.info()
return info['version']['number']
return info["version"]["number"]
def _check_version(self):
if self._version < '8.0.0':
if self._version < "8.0.0":
raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
def get_type(self) -> str:
return 'elasticsearch'
return "elasticsearch"
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
for i in range(len(documents)):
self._client.index(index=self._collection_name,
id=uuids[i],
document={
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {}
})
self._client.index(
index=self._collection_name,
id=uuids[i],
document={
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {},
},
)
self._client.indices.refresh(index=self._collection_name)
return uuids

@@ -98,15 +100,9 @@ class ElasticSearchVector(BaseVector):
self._client.delete(index=self._collection_name, id=id)
def delete_by_metadata_field(self, key: str, value: str) -> None:
query_str = {
'query': {
'match': {
f'metadata.{key}': f'{value}'
}
}
}
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
results = self._client.search(index=self._collection_name, body=query_str)
ids = [hit['_id'] for hit in results['hits']['hits']]
ids = [hit["_id"] for hit in results["hits"]["hits"]]
if ids:
self.delete_by_ids(ids)

@@ -115,44 +111,44 @@ class ElasticSearchVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 10)
knn = {
"field": Field.VECTOR.value,
"query_vector": query_vector,
"k": top_k
}
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k}
results = self._client.search(index=self._collection_name, knn=knn, size=top_k)
docs_and_scores = []
for hit in results['hits']['hits']:
for hit in results["hits"]["hits"]:
docs_and_scores.append(
(Document(page_content=hit['_source'][Field.CONTENT_KEY.value],
vector=hit['_source'][Field.VECTOR.value],
metadata=hit['_source'][Field.METADATA_KEY.value]), hit['_score']))
(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
),
hit["_score"],
)
)
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
if score > score_threshold:
doc.metadata['score'] = score
doc.metadata["score"] = score
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str = {
"match": {
Field.CONTENT_KEY.value: query
}
}
query_str = {"match": {Field.CONTENT_KEY.value: query}}
results = self._client.search(index=self._collection_name, query=query_str)
docs = []
for hit in results['hits']['hits']:
docs.append(Document(
page_content=hit['_source'][Field.CONTENT_KEY.value],
vector=hit['_source'][Field.VECTOR.value],
metadata=hit['_source'][Field.METADATA_KEY.value],
))
for hit in results["hits"]["hits"]:
docs.append(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
)
)
return docs

@@ -162,11 +158,11 @@ class ElasticSearchVector(BaseVector):
self.add_texts(texts, embeddings, **kwargs)
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
lock_name = f'vector_indexing_lock_{self._collection_name}'
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f'vector_indexing_{self._collection_name}'
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name} already exists.")
return

@@ -179,14 +175,14 @@ class ElasticSearchVector(BaseVector):
Field.VECTOR.value: {  # Make sure the dimension is correct here
"type": "dense_vector",
"dims": dim,
"similarity": "cosine"
"similarity": "cosine",
},
Field.METADATA_KEY.value: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"}  # Map doc_id to keyword type
}
}
},
},
}
}
self._client.indices.create(index=self._collection_name, mappings=mappings)

@@ -197,22 +193,21 @@ class ElasticSearchVector(BaseVector):
class ElasticSearchVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
config = current_app.config
return ElasticSearchVector(
index_name=collection_name,
config=ElasticSearchConfig(
host=config.get('ELASTICSEARCH_HOST'),
port=config.get('ELASTICSEARCH_PORT'),
username=config.get('ELASTICSEARCH_USERNAME'),
password=config.get('ELASTICSEARCH_PASSWORD'),
host=config.get("ELASTICSEARCH_HOST"),
port=config.get("ELASTICSEARCH_PORT"),
username=config.get("ELASTICSEARCH_USERNAME"),
password=config.get("ELASTICSEARCH_PASSWORD"),
),
attributes=[]
attributes=[],
)
@@ -27,44 +27,39 @@ class MilvusConfig(BaseModel):
batch_size: int = 100
database: str = "default"
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values.get('uri'):
if not values.get("uri"):
raise ValueError("config MILVUS_URI is required")
if not values.get('user'):
if not values.get("user"):
raise ValueError("config MILVUS_USER is required")
if not values.get('password'):
if not values.get("password"):
raise ValueError("config MILVUS_PASSWORD is required")
return values
def to_milvus_params(self):
return {
'uri': self.uri,
'token': self.token,
'user': self.user,
'password': self.password,
'db_name': self.database,
"uri": self.uri,
"token": self.token,
"user": self.user,
"password": self.password,
"db_name": self.database,
}

class MilvusVector(BaseVector):
def __init__(self, collection_name: str, config: MilvusConfig):
super().__init__(collection_name)
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = 'Session'
self._consistency_level = "Session"
self._fields = []
def get_type(self) -> str:
return VectorType.MILVUS
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_params = {
'metric_type': 'IP',
'index_type': "HNSW",
'params': {"M": 8, "efConstruction": 64}
}
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
metadatas = [d.metadata for d in texts]
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)

@@ -75,7 +70,7 @@ class MilvusVector(BaseVector):
insert_dict = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata
Field.METADATA_KEY.value: documents[i].metadata,
}
insert_dict_list.append(insert_dict)
# Total insert count

@@ -84,22 +79,20 @@ class MilvusVector(BaseVector):
pks: list[str] = []
for i in range(0, total_count, 1000):
batch_insert_list = insert_dict_list[i:i + 1000]
batch_insert_list = insert_dict_list[i : i + 1000]
# Insert into the collection.
try:
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
pks.extend(ids)
except MilvusException as e:
logger.error(
"Failed to insert batch starting at entity: %s/%s", i, total_count
)
logger.error("Failed to insert batch starting at entity: %s/%s", i, total_count)
raise e
return pks
def get_ids_by_metadata_field(self, key: str, value: str):
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["{key}"] == "{value}"',
output_fields=["id"])
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
)
if result:
return [item["id"] for item in result]
else:

@@ -107,17 +100,15 @@ class MilvusVector(BaseVector):
def delete_by_metadata_field(self, key: str, value: str):
if self._client.has_collection(self._collection_name):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete_by_ids(self, ids: list[str]) -> None:
if self._client.has_collection(self._collection_name):
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] in {ids}',
output_fields=["id"])
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
)
if result:
ids = [item["id"] for item in result]
self._client.delete(collection_name=self._collection_name, pks=ids)

@@ -130,29 +121,28 @@ class MilvusVector(BaseVector):
if not self._client.has_collection(self._collection_name):
return False
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] == "{id}"',
output_fields=["id"])
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["doc_id"] == "{id}"', output_fields=["id"]
)
return len(result) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Set search parameters.
results = self._client.search(collection_name=self._collection_name,
data=[query_vector],
limit=kwargs.get('top_k', 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
results = self._client.search(
collection_name=self._collection_name,
data=[query_vector],
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
# Organize results.
docs = []
for result in results[0]:
metadata = result['entity'].get(Field.METADATA_KEY.value)
metadata['score'] = result['distance']
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
if result['distance'] > score_threshold:
doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value),
metadata=metadata)
metadata = result["entity"].get(Field.METADATA_KEY.value)
metadata["score"] = result["distance"]
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
return docs

@@ -161,11 +151,11 @@ class MilvusVector(BaseVector):
return []
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
# Grab the existing collection if it exists

@@ -180,19 +170,11 @@ class MilvusVector(BaseVector):
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
# Create the text field
fields.append(
FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
)
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535))
# Create the primary key field
fields.append(
FieldSchema(
Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
)
)
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
# Create the vector field, supports binary or float vectors
fields.append(
FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
)
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
# Create the schema for the collection
schema = CollectionSchema(fields)

@@ -208,9 +190,12 @@ class MilvusVector(BaseVector):
# Create the collection
collection_name = self._collection_name
self._client.create_collection(collection_name=collection_name,
schema=schema, index_params=index_params_obj,
consistency_level=self._consistency_level)
self._client.create_collection(
collection_name=collection_name,
schema=schema,
index_params=index_params_obj,
consistency_level=self._consistency_level,
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _init_client(self, config) -> MilvusClient:

@@ -221,13 +206,12 @@ class MilvusVector(BaseVector):
class MilvusVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.MILVUS, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MILVUS, collection_name))
return MilvusVector(
collection_name=collection_name,

@@ -237,5 +221,5 @@ class MilvusVectorFactory(AbstractVectorFactory):
user=dify_config.MILVUS_USER,
password=dify_config.MILVUS_PASSWORD,
database=dify_config.MILVUS_DATABASE,
)
),
)
@@ -31,7 +31,6 @@ class SortOrder(Enum):

class MyScaleVector(BaseVector):
def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"):
super().__init__(collection_name)
self._config = config

@@ -80,7 +79,7 @@ class MyScaleVector(BaseVector):
doc_id,
self.escape_str(doc.page_content),
embeddings[i],
json.dumps(doc.metadata) if doc.metadata else {}
json.dumps(doc.metadata) if doc.metadata else {},
)
values.append(str(row))
ids.append(doc_id)

@@ -101,7 +100,8 @@ class MyScaleVector(BaseVector):
def delete_by_ids(self, ids: list[str]) -> None:
self._client.command(
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}")
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
)
def get_ids_by_metadata_field(self, key: str, value: str):
rows = self._client.query(

@@ -122,9 +122,12 @@ class MyScaleVector(BaseVector):
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get('score_threshold') or 0.0
where_str = f"WHERE dist < {1 - score_threshold}" if \
self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else ""
score_threshold = kwargs.get("score_threshold") or 0.0
where_str = (
f"WHERE dist < {1 - score_threshold}"
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
else ""
)
sql = f"""
SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
{where_str} ORDER BY dist {order.value} LIMIT {top_k}

@@ -133,7 +136,7 @@ class MyScaleVector(BaseVector):
return [
Document(
page_content=r["text"],
vector=r['vector'],
vector=r["vector"],
metadata=r["metadata"],
)
for r in self._client.query(sql).named_results()

@@ -149,13 +152,12 @@ class MyScaleVector(BaseVector):
class MyScaleVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MyScaleVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
return MyScaleVector(
collection_name=collection_name,
@@ -28,11 +28,11 @@ class OpenSearchConfig(BaseModel):
password: Optional[str] = None
secure: bool = False
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values.get('host'):
if not values.get("host"):
raise ValueError("config OPENSEARCH_HOST is required")
if not values.get('port'):
if not values.get("port"):
raise ValueError("config OPENSEARCH_PORT is required")
return values

@@ -44,19 +44,18 @@ class OpenSearchConfig(BaseModel):
def to_opensearch_params(self) -> dict[str, Any]:
params = {
'hosts': [{'host': self.host, 'port': self.port}],
'use_ssl': self.secure,
'verify_certs': self.secure,
"hosts": [{"host": self.host, "port": self.port}],
"use_ssl": self.secure,
"verify_certs": self.secure,
}
if self.user and self.password:
params['http_auth'] = (self.user, self.password)
params["http_auth"] = (self.user, self.password)
if self.secure:
params['ssl_context'] = self.create_ssl_context()
params["ssl_context"] = self.create_ssl_context()
return params

class OpenSearchVector(BaseVector):
def __init__(self, collection_name: str, config: OpenSearchConfig):
super().__init__(collection_name)
self._client_config = config

@@ -81,7 +80,7 @@ class OpenSearchVector(BaseVector):
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
}
},
}
actions.append(action)

@@ -90,8 +89,8 @@ class OpenSearchVector(BaseVector):
def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
response = self._client.search(index=self._collection_name.lower(), body=query)
if response['hits']['hits']:
return [hit['_id'] for hit in response['hits']['hits']]
if response["hits"]["hits"]:
return [hit["_id"] for hit in response["hits"]["hits"]]
else:
return None

@@ -110,7 +109,7 @@ class OpenSearchVector(BaseVector):
actual_ids = []
for doc_id in ids:
es_ids = self.get_ids_by_metadata_field('doc_id', doc_id)
es_ids = self.get_ids_by_metadata_field("doc_id", doc_id)
if es_ids:
actual_ids.extend(es_ids)
else:

@@ -122,9 +121,9 @@ class OpenSearchVector(BaseVector):
helpers.bulk(self._client, actions)
except BulkIndexError as e:
for error in e.errors:
delete_error = error.get('delete', {})
status = delete_error.get('status')
doc_id = delete_error.get('_id')
delete_error = error.get("delete", {})
status = delete_error.get("status")
doc_id = delete_error.get("_id")
if status == 404:
logger.warning(f"Document not found for deletion: {doc_id}")

@@ -151,15 +150,8 @@ class OpenSearchVector(BaseVector):
raise ValueError("All elements in query_vector should be floats")
query = {
"size": kwargs.get('top_k', 4),
"query": {
"knn": {
Field.VECTOR.value: {
Field.VECTOR.value: query_vector,
"k": kwargs.get('top_k', 4)
}
}
}
"size": kwargs.get("top_k", 4),
"query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}},
}
try:

@@ -169,17 +161,17 @@ class OpenSearchVector(BaseVector):
raise
docs = []
for hit in response['hits']['hits']:
metadata = hit['_source'].get(Field.METADATA_KEY.value, {})
for hit in response["hits"]["hits"]:
metadata = hit["_source"].get(Field.METADATA_KEY.value, {})
# Make sure metadata is a dictionary
if metadata is None:
metadata = {}
metadata['score'] = hit['_score']
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
if hit['_score'] > score_threshold:
doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata)
metadata["score"] = hit["_score"]
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
if hit["_score"] > score_threshold:
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
return docs

@@ -190,32 +182,28 @@ class OpenSearchVector(BaseVector):
response = self._client.search(index=self._collection_name.lower(), body=full_text_query)
docs = []
for hit in response['hits']['hits']:
metadata = hit['_source'].get(Field.METADATA_KEY.value)
vector = hit['_source'].get(Field.VECTOR.value)
page_content = hit['_source'].get(Field.CONTENT_KEY.value)
for hit in response["hits"]["hits"]:
metadata = hit["_source"].get(Field.METADATA_KEY.value)
vector = hit["_source"].get(Field.VECTOR.value)
page_content = hit["_source"].get(Field.CONTENT_KEY.value)
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
docs.append(doc)
return docs
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
lock_name = f'vector_indexing_lock_{self._collection_name.lower()}'
lock_name = f"vector_indexing_lock_{self._collection_name.lower()}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f'vector_indexing_{self._collection_name.lower()}'
collection_exist_cache_key = f"vector_indexing_{self._collection_name.lower()}"
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name.lower()} already exists.")
return
if not self._client.indices.exists(index=self._collection_name.lower()):
index_body = {
"settings": {
"index": {
"knn": True
}
},
"settings": {"index": {"knn": True}},
"mappings": {
"properties": {
Field.CONTENT_KEY.value: {"type": "text"},

@@ -226,20 +214,17 @@ class OpenSearchVector(BaseVector):
"name": "hnsw",
"space_type": "l2",
"engine": "faiss",
"parameters": {
"ef_construction": 64,
"m": 8
}
}
"parameters": {"ef_construction": 64, "m": 8},
},
},
Field.METADATA_KEY.value: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"}  # Map doc_id to keyword type
}
}
},
},
}
}
},
}
self._client.indices.create(index=self._collection_name.lower(), body=index_body)

@@ -248,17 +233,14 @@ class OpenSearchVector(BaseVector):

class OpenSearchVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenSearchVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
open_search_config = OpenSearchConfig(
host=dify_config.OPENSEARCH_HOST,

@@ -268,7 +250,4 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
secure=dify_config.OPENSEARCH_SECURE,
)
return OpenSearchVector(
collection_name=collection_name,
config=open_search_config
)
return OpenSearchVector(collection_name=collection_name, config=open_search_config)
@@ -31,7 +31,7 @@ class OracleVectorConfig(BaseModel):
password: str
database: str
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config ORACLE_HOST is required")

@@ -103,9 +103,16 @@ class OracleVector(BaseVector):
arraysize=cursor.arraysize,
outconverter=self.numpy_converter_out,
)
def _create_connection_pool(self, config: OracleVectorConfig):
return oracledb.create_pool(user=config.user, password=config.password, dsn="{}:{}/{}".format(config.host, config.port, config.database), min=1, max=50, increment=1)
def _create_connection_pool(self, config: OracleVectorConfig):
return oracledb.create_pool(
user=config.user,
password=config.password,
dsn="{}:{}/{}".format(config.host, config.port, config.database),
min=1,
max=50,
increment=1,
)
@contextmanager
def _get_cursor(self):

@@ -136,13 +143,15 @@ class OracleVector(BaseVector):
doc_id,
doc.page_content,
json.dumps(doc.metadata),
#array.array("f", embeddings[i]),
# array.array("f", embeddings[i]),
numpy.array(embeddings[i]),
)
)
#print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
# print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
with self._get_cursor() as cur:
cur.executemany(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values)
cur.executemany(
f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values
)
return pks
def text_exists(self, id: str) -> bool:

@@ -157,7 +166,8 @@ class OracleVector(BaseVector):
for record in cur:
docs.append(Document(page_content=record[1], metadata=record[0]))
return docs
#def get_ids_by_metadata_field(self, key: str, value: str):
# def get_ids_by_metadata_field(self, key: str, value: str):
# with self._get_cursor() as cur:
# cur.execute(f"SELECT id FROM {self.table_name} d WHERE d.meta.{key}='{value}'" )
# idss = []

@@ -184,7 +194,8 @@ class OracleVector(BaseVector):
top_k = kwargs.get("top_k", 5)
with self._get_cursor() as cur:
cur.execute(
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name} ORDER BY distance fetch first {top_k} rows only" ,[numpy.array(query_vector)]
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name} ORDER BY distance fetch first {top_k} rows only",
[numpy.array(query_vector)],
)
docs = []
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0

@@ -202,7 +213,7 @@ class OracleVector(BaseVector):
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
if len(query) > 0:
# Check which language the query is in
zh_pattern = re.compile('[\u4e00-\u9fa5]+')
zh_pattern = re.compile("[\u4e00-\u9fa5]+")
match = zh_pattern.search(query)
entities = []
# match: query condition maybe is a chinese sentence, so using Jieba split,else using nltk split

@@ -210,7 +221,15 @@ class OracleVector(BaseVector):
words = pseg.cut(query)
current_entity = ""
for word, pos in words:
if pos == 'nr' or pos == 'Ng' or pos == 'eng' or pos == 'nz' or pos == 'n' or pos == 'ORG' or pos == 'v':  # nr: 人名, ns: 地名, nt: 机构名
if (
pos == "nr"
or pos == "Ng"
or pos == "eng"
or pos == "nz"
or pos == "n"
or pos == "ORG"
or pos == "v"
):  # nr: 人名, ns: 地名, nt: 机构名
current_entity += word
else:
if current_entity:

@@ -220,22 +239,22 @@ class OracleVector(BaseVector):
entities.append(current_entity)
else:
try:
nltk.data.find('tokenizers/punkt')
nltk.data.find('corpora/stopwords')
nltk.data.find("tokenizers/punkt")
nltk.data.find("corpora/stopwords")
except LookupError:
nltk.download('punkt')
nltk.download('stopwords')
nltk.download("punkt")
nltk.download("stopwords")
print("run download")
e_str = re.sub(r'[^\w ]', '', query)
e_str = re.sub(r"[^\w ]", "", query)
all_tokens = nltk.word_tokenize(e_str)
stop_words = stopwords.words('english')
stop_words = stopwords.words("english")
for token in all_tokens:
if token not in stop_words:
entities.append(token)
with self._get_cursor() as cur:
cur.execute(
f"select meta, text, embedding FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
[" ACCUM ".join(entities)]
[" ACCUM ".join(entities)],
)
docs = []
for record in cur:

@@ -273,8 +292,7 @@ class OracleVectorFactory(AbstractVectorFactory):
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ORACLE, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ORACLE, collection_name))
return OracleVector(
collection_name=collection_name,
@@ -31,27 +31,28 @@ class PgvectoRSConfig(BaseModel):
password: str
database: str
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values['host']:
if not values["host"]:
raise ValueError("config PGVECTO_RS_HOST is required")
if not values['port']:
if not values["port"]:
raise ValueError("config PGVECTO_RS_PORT is required")
if not values['user']:
if not values["user"]:
raise ValueError("config PGVECTO_RS_USER is required")
if not values['password']:
if not values["password"]:
raise ValueError("config PGVECTO_RS_PASSWORD is required")
if not values['database']:
if not values["database"]:
raise ValueError("config PGVECTO_RS_DATABASE is required")
return values

class PGVectoRS(BaseVector):
def __init__(self, collection_name: str, config: PgvectoRSConfig, dim: int):
super().__init__(collection_name)
self._client_config = config
self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
self._url = (
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
)
self._client = create_engine(self._url)
with Session(self._client) as session:
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))

@@ -80,9 +81,9 @@ class PGVectoRS(BaseVector):
self.add_texts(texts, embeddings)
def create_collection(self, dimension: int):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
index_name = f"{self._collection_name}_embedding_index"

@@ -133,9 +134,7 @@ class PGVectoRS(BaseVector):
def get_ids_by_metadata_field(self, key: str, value: str):
result = None
with Session(self._client) as session:
select_statement = sql_text(
f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; "
)
select_statement = sql_text(f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; ")
result = session.execute(select_statement).fetchall()
if result:
return [item[0] for item in result]

@@ -143,12 +142,11 @@ class PGVectoRS(BaseVector):
return None
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
with Session(self._client) as session:
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
session.execute(select_statement, {'ids': ids})
session.execute(select_statement, {"ids": ids})
session.commit()
def delete_by_ids(self, ids: list[str]) -> None:

@@ -156,13 +154,13 @@ class PGVectoRS(BaseVector):
select_statement = sql_text(
f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = ANY (:doc_ids); "
)
result = session.execute(select_statement, {'doc_ids': ids}).fetchall()
result = session.execute(select_statement, {"doc_ids": ids}).fetchall()
if result:
ids = [item[0] for item in result]
if ids:
with Session(self._client) as session:
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
session.execute(select_statement, {'ids': ids})
session.execute(select_statement, {"ids": ids})
session.commit()
def delete(self) -> None:

@@ -187,7 +185,7 @@ class PGVectoRS(BaseVector):
query_vector,
).label("distance"),
)
.limit(kwargs.get('top_k', 2))
.limit(kwargs.get("top_k", 2))
.order_by("distance")
)
res = session.execute(stmt)

@@ -198,11 +196,10 @@ class PGVectoRS(BaseVector):
for record, dis in results:
metadata = record.meta
score = 1 - dis
metadata['score'] = score
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
metadata["score"] = score
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
if score > score_threshold:
doc = Document(page_content=record.text,
metadata=metadata)
doc = Document(page_content=record.text, metadata=metadata)
docs.append(doc)
return docs

@@ -225,13 +222,12 @@ class PGVectoRS(BaseVector):
class PGVectoRSFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVectoRS:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
dim = len(embeddings.embed_query("pgvecto_rs"))
return PGVectoRS(

@@ -243,5 +239,5 @@ class PGVectoRSFactory(AbstractVectorFactory):
password=dify_config.PGVECTO_RS_PASSWORD,
database=dify_config.PGVECTO_RS_DATABASE,
),
dim=dim
dim=dim,
)
@@ -24,7 +24,7 @@ class PGVectorConfig(BaseModel):
password: str
database: str
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config PGVECTOR_HOST is required")

@@ -201,8 +201,7 @@ class PGVectorFactory(AbstractVectorFactory):
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))
return PGVector(
collection_name=collection_name,
@@ -48,28 +48,25 @@ class QdrantConfig(BaseModel):
|
||||
prefer_grpc: bool = False
|
||||
|
||||
def to_qdrant_params(self):
|
||||
if self.endpoint and self.endpoint.startswith('path:'):
|
||||
path = self.endpoint.replace('path:', '')
|
||||
if self.endpoint and self.endpoint.startswith("path:"):
|
||||
path = self.endpoint.replace("path:", "")
|
||||
if not os.path.isabs(path):
|
||||
path = os.path.join(self.root_path, path)
|
||||
|
||||
return {
|
||||
'path': path
|
||||
}
|
||||
return {"path": path}
|
||||
else:
|
||||
return {
|
||||
'url': self.endpoint,
|
||||
'api_key': self.api_key,
|
||||
'timeout': self.timeout,
|
||||
'verify': self.endpoint.startswith('https'),
|
||||
'grpc_port': self.grpc_port,
|
||||
'prefer_grpc': self.prefer_grpc
|
||||
"url": self.endpoint,
|
||||
"api_key": self.api_key,
|
||||
"timeout": self.timeout,
|
||||
"verify": self.endpoint.startswith("https"),
|
||||
"grpc_port": self.grpc_port,
|
||||
"prefer_grpc": self.prefer_grpc,
|
||||
}
|
||||
|
||||
|
||||
class QdrantVector(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = 'Cosine'):
|
||||
def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
|
||||
@@ -80,10 +77,7 @@ class QdrantVector(BaseVector):
return VectorType.QDRANT

def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name}
}
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}

def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts:
@@ -97,9 +91,9 @@ class QdrantVector(BaseVector):
self.add_texts(texts, embeddings, **kwargs)

def create_collection(self, collection_name: str, vector_size: int):
lock_name = 'vector_indexing_lock_{}'.format(collection_name)
lock_name = "vector_indexing_lock_{}".format(collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
collection_name = collection_name or uuid.uuid4().hex
@@ -110,12 +104,19 @@ class QdrantVector(BaseVector):
all_collection_name.append(collection.name)
if collection_name not in all_collection_name:
from qdrant_client.http import models as rest

vectors_config = rest.VectorParams(
size=vector_size,
distance=rest.Distance[self._distance_func],
)
hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False)
hnsw_config = HnswConfigDiff(
m=0,
payload_m=16,
ef_construct=100,
full_scan_threshold=10000,
max_indexing_threads=0,
on_disk=False,
)
self._client.recreate_collection(
collection_name=collection_name,
vectors_config=vectors_config,
@@ -124,21 +125,24 @@ class QdrantVector(BaseVector):
)

# create group_id payload index
self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
field_schema=PayloadSchemaType.KEYWORD)
self._client.create_payload_index(
collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
)
# create doc_id payload index
self._client.create_payload_index(collection_name, Field.DOC_ID.value,
field_schema=PayloadSchemaType.KEYWORD)
self._client.create_payload_index(
collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
)
# create full text index
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
tokenizer=TokenizerType.MULTILINGUAL,
min_token_len=2,
max_token_len=20,
lowercase=True
lowercase=True,
)
self._client.create_payload_index(
collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
)
self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value,
field_schema=text_index_params)
redis_client.set(collection_exist_cache_key, 1, ex=3600)

def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@@ -147,26 +151,23 @@ class QdrantVector(BaseVector):
metadatas = [d.metadata for d in documents]

added_ids = []
for batch_ids, points in self._generate_rest_batches(
texts, embeddings, metadatas, uuids, 64, self._group_id
):
self._client.upsert(
collection_name=self._collection_name, points=points
)
for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
self._client.upsert(collection_name=self._collection_name, points=points)
added_ids.extend(batch_ids)

return added_ids

def _generate_rest_batches(
self,
texts: Iterable[str],
embeddings: list[list[float]],
metadatas: Optional[list[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
group_id: Optional[str] = None,
self,
texts: Iterable[str],
embeddings: list[list[float]],
metadatas: Optional[list[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
group_id: Optional[str] = None,
) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
from qdrant_client.http import models as rest

texts_iterator = iter(texts)
embeddings_iterator = iter(embeddings)
metadatas_iterator = iter(metadatas or [])
@@ -203,13 +204,13 @@ class QdrantVector(BaseVector):

@classmethod
def _build_payloads(
cls,
texts: Iterable[str],
metadatas: Optional[list[dict]],
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str
cls,
texts: Iterable[str],
metadatas: Optional[list[dict]],
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str,
) -> list[dict]:
payloads = []
for i, text in enumerate(texts):
@@ -219,18 +220,11 @@ class QdrantVector(BaseVector):
"calling .from_texts or .add_texts on Qdrant instance."
)
metadata = metadatas[i] if metadatas is not None else None
payloads.append(
{
content_payload_key: text,
metadata_payload_key: metadata,
group_payload_key: group_id
}
)
payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id})

return payloads

def delete_by_metadata_field(self, key: str, value: str):

from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse

@@ -248,9 +242,7 @@ class QdrantVector(BaseVector):

self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
@@ -275,9 +267,7 @@ class QdrantVector(BaseVector):
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
@@ -288,7 +278,6 @@ class QdrantVector(BaseVector):
raise e

def delete_by_ids(self, ids: list[str]) -> None:

from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse

@@ -304,9 +293,7 @@ class QdrantVector(BaseVector):
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
@@ -324,15 +311,13 @@ class QdrantVector(BaseVector):
all_collection_name.append(collection.name)
if self._collection_name not in all_collection_name:
return False
response = self._client.retrieve(
collection_name=self._collection_name,
ids=[id]
)
response = self._client.retrieve(collection_name=self._collection_name, ids=[id])

return len(response) > 0

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from qdrant_client.http import models

filter = models.Filter(
must=[
models.FieldCondition(
@@ -348,22 +333,22 @@ class QdrantVector(BaseVector):
limit=kwargs.get("top_k", 4),
with_payload=True,
with_vectors=True,
score_threshold=kwargs.get("score_threshold", .0)
score_threshold=kwargs.get("score_threshold", 0.0),
)
docs = []
for result in results:
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
if result.score > score_threshold:
metadata['score'] = result.score
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value),
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
return docs

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -372,6 +357,7 @@ class QdrantVector(BaseVector):
List of documents most similar to the query text and distance for each.
"""
from qdrant_client.http import models

scroll_filter = models.Filter(
must=[
models.FieldCondition(
@@ -381,24 +367,21 @@ class QdrantVector(BaseVector):
models.FieldCondition(
key="page_content",
match=models.MatchText(text=query),
)
),
]
)
response = self._client.scroll(
collection_name=self._collection_name,
scroll_filter=scroll_filter,
limit=kwargs.get('top_k', 2),
limit=kwargs.get("top_k", 2),
with_payload=True,
with_vectors=True

with_vectors=True,
)
results = response[0]
documents = []
for result in results:
if result:
document = self._document_from_scored_point(
result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
)
document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
documents.append(document)

return documents
@@ -410,10 +393,10 @@ class QdrantVector(BaseVector):

@classmethod
def _document_from_scored_point(
cls,
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
cls,
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
) -> Document:
return Document(
page_content=scored_point.payload.get(content_payload_key),
@@ -425,24 +408,25 @@ class QdrantVector(BaseVector):
class QdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector:
if dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
one_or_none()
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
raise ValueError('Dataset Collection Bindings is not exist!')
raise ValueError("Dataset Collection Bindings is not exist!")
else:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)

if not dataset.index_struct_dict:
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.QDRANT, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.QDRANT, collection_name))

config = current_app.config
return QdrantVector(
@@ -454,6 +438,6 @@ class QdrantVectorFactory(AbstractVectorFactory):
root_path=config.root_path,
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.QDRANT_GRPC_PORT,
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED
)
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
),
)

@@ -33,28 +33,29 @@ class RelytConfig(BaseModel):
password: str
database: str

@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values['host']:
if not values["host"]:
raise ValueError("config RELYT_HOST is required")
if not values['port']:
if not values["port"]:
raise ValueError("config RELYT_PORT is required")
if not values['user']:
if not values["user"]:
raise ValueError("config RELYT_USER is required")
if not values['password']:
if not values["password"]:
raise ValueError("config RELYT_PASSWORD is required")
if not values['database']:
if not values["database"]:
raise ValueError("config RELYT_DATABASE is required")
return values


class RelytVector(BaseVector):

def __init__(self, collection_name: str, config: RelytConfig, group_id: str):
super().__init__(collection_name)
self.embedding_dimension = 1536
self._client_config = config
self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
self._url = (
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
)
self.client = create_engine(self._url)
self._fields = []
self._group_id = group_id
@@ -70,9 +71,9 @@ class RelytVector(BaseVector):
self.add_texts(texts, embeddings)

def create_collection(self, dimension: int):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
index_name = f"{self._collection_name}_embedding_index"
@@ -110,7 +111,7 @@ class RelytVector(BaseVector):
ids = [str(uuid.uuid1()) for _ in documents]
metadatas = [d.metadata for d in documents]
for metadata in metadatas:
metadata['group_id'] = self._group_id
metadata["group_id"] = self._group_id
texts = [d.page_content for d in documents]

# Define the table schema
@@ -127,9 +128,7 @@ class RelytVector(BaseVector):
chunks_table_data = []
with self.client.connect() as conn:
with conn.begin():
for document, metadata, chunk_id, embedding in zip(
texts, metadatas, ids, embeddings
):
for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings):
chunks_table_data.append(
{
"id": chunk_id,
@@ -196,15 +195,13 @@ class RelytVector(BaseVector):
return False

def delete_by_metadata_field(self, key: str, value: str):

ids = self.get_ids_by_metadata_field(key, value)
if ids:
self.delete_by_uuids(ids)

def delete_by_ids(self, ids: list[str]) -> None:

with Session(self.client) as session:
ids_str = ','.join(f"'{doc_id}'" for doc_id in ids)
ids_str = ",".join(f"'{doc_id}'" for doc_id in ids)
select_statement = sql_text(
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """
)
@@ -228,38 +225,34 @@ class RelytVector(BaseVector):

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
results = self.similarity_search_with_score_by_vector(
k=int(kwargs.get('top_k')),
embedding=query_vector,
filter=kwargs.get('filter')
k=int(kwargs.get("top_k")), embedding=query_vector, filter=kwargs.get("filter")
)

# Organize results.
docs = []
for document, score in results:
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
if 1 - score > score_threshold:
docs.append(document)
return docs

def similarity_search_with_score_by_vector(
self,
embedding: list[float],
k: int = 4,
filter: Optional[dict] = None,
self,
embedding: list[float],
k: int = 4,
filter: Optional[dict] = None,
) -> list[tuple[Document, float]]:
# Add the filter if provided
try:
from sqlalchemy.engine import Row
except ImportError:
raise ImportError(
"Could not import Row from sqlalchemy.engine. "
"Please 'pip install sqlalchemy>=1.4'."
)
raise ImportError("Could not import Row from sqlalchemy.engine. " "Please 'pip install sqlalchemy>=1.4'.")

filter_condition = ""
if filter is not None:
conditions = [
f"metadata->>{key!r} in ({', '.join(map(repr, value))})" if len(value) > 1
f"metadata->>{key!r} in ({', '.join(map(repr, value))})"
if len(value) > 1
else f"metadata->>{key!r} = {value[0]!r}"
for key, value in filter.items()
]
@@ -305,13 +298,12 @@ class RelytVector(BaseVector):
class RelytVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.RELYT, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.RELYT, collection_name))

return RelytVector(
collection_name=collection_name,
@@ -322,5 +314,5 @@ class RelytVectorFactory(AbstractVectorFactory):
password=dify_config.RELYT_PASSWORD,
database=dify_config.RELYT_DATABASE,
),
group_id=dataset.id
group_id=dataset.id,
)

@@ -25,16 +25,11 @@ class TencentConfig(BaseModel):
database: Optional[str]
index_type: str = "HNSW"
metric_type: str = "L2"
shard: int = 1,
replicas: int = 2,
shard: int = (1,)
replicas: int = (2,)

def to_tencent_params(self):
return {
'url': self.url,
'username': self.username,
'key': self.api_key,
'timeout': self.timeout
}
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}


class TencentVector(BaseVector):
@@ -61,13 +56,10 @@ class TencentVector(BaseVector):
return self._client.create_database(database_name=self._client_config.database)

def get_type(self) -> str:
return 'tencent'
return "tencent"

def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name}
}
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}

def _has_collection(self) -> bool:
collections = self._db.list_collections()
@@ -77,9 +69,9 @@ class TencentVector(BaseVector):
return False

def _create_collection(self, dimension: int) -> None:
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return

@@ -101,9 +93,7 @@ class TencentVector(BaseVector):
raise ValueError("unsupported metric_type")
params = vdb_index.HNSWParams(m=16, efconstruction=200)
index = vdb_index.Index(
vdb_index.FilterIndex(
self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY
),
vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
vdb_index.VectorIndex(
self.field_vector,
dimension,
@@ -111,12 +101,8 @@ class TencentVector(BaseVector):
metric_type,
params,
),
vdb_index.FilterIndex(
self.field_text, enum.FieldType.String, enum.IndexType.FILTER
),
vdb_index.FilterIndex(
self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
),
vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER),
vdb_index.FilterIndex(self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER),
)

self._db.create_collection(
@@ -163,15 +149,14 @@ class TencentVector(BaseVector):
self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value])))

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:

res = self._db.collection(self._collection_name).search(vectors=[query_vector],
params=document.HNSWSearchParams(
ef=kwargs.get("ef", 10)),
retrieve_vector=False,
limit=kwargs.get('top_k', 4),
timeout=self._client_config.timeout,
)
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
res = self._db.collection(self._collection_name).search(
vectors=[query_vector],
params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)),
retrieve_vector=False,
limit=kwargs.get("top_k", 4),
timeout=self._client_config.timeout,
)
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
return self._get_search_res(res, score_threshold)

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -200,15 +185,13 @@ class TencentVector(BaseVector):

class TencentVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TencentVector:

if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.TENCENT, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TENCENT, collection_name))

return TencentVector(
collection_name=collection_name,
@@ -220,5 +203,5 @@ class TencentVectorFactory(AbstractVectorFactory):
database=dify_config.TENCENT_VECTOR_DB_DATABASE,
shard=dify_config.TENCENT_VECTOR_DB_SHARD,
replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS,
)
),
)

@@ -28,47 +28,56 @@ class TiDBVectorConfig(BaseModel):
database: str
program_name: str

@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values['host']:
if not values["host"]:
raise ValueError("config TIDB_VECTOR_HOST is required")
if not values['port']:
if not values["port"]:
raise ValueError("config TIDB_VECTOR_PORT is required")
if not values['user']:
if not values["user"]:
raise ValueError("config TIDB_VECTOR_USER is required")
if not values['password']:
if not values["password"]:
raise ValueError("config TIDB_VECTOR_PASSWORD is required")
if not values['database']:
if not values["database"]:
raise ValueError("config TIDB_VECTOR_DATABASE is required")
if not values['program_name']:
if not values["program_name"]:
raise ValueError("config APPLICATION_NAME is required")
return values


class TiDBVector(BaseVector):

def get_type(self) -> str:
return VectorType.TIDB_VECTOR

def _table(self, dim: int) -> Table:
from tidb_vector.sqlalchemy import VectorType

return Table(
self._collection_name,
self._orm_base.metadata,
Column('id', String(36), primary_key=True, nullable=False),
Column("vector", VectorType(dim), nullable=False, comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})"),
Column("id", String(36), primary_key=True, nullable=False),
Column(
"vector",
VectorType(dim),
nullable=False,
comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})",
),
Column("text", TEXT, nullable=False),
Column("meta", JSON, nullable=False),
Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
Column("update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")),
extend_existing=True
Column(
"update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")
),
extend_existing=True,
)

def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = 'cosine'):
def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = "cosine"):
super().__init__(collection_name)
self._client_config = config
self._url = (f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?"
f"ssl_verify_cert=true&ssl_verify_identity=true&program_name={config.program_name}")
self._url = (
f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?"
f"ssl_verify_cert=true&ssl_verify_identity=true&program_name={config.program_name}"
)
self._distance_func = distance_func.lower()
self._engine = create_engine(self._url)
self._orm_base = declarative_base()
@@ -83,9 +92,9 @@ class TiDBVector(BaseVector):

def _create_collection(self, dimension: int):
logger.info("_create_collection, collection_name " + self._collection_name)
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
with Session(self._engine) as session:
@@ -116,9 +125,7 @@ class TiDBVector(BaseVector):
chunks_table_data = []
with self._engine.connect() as conn:
with conn.begin():
for id, text, meta, embedding in zip(
ids, texts, metas, embeddings
):
for id, text, meta, embedding in zip(ids, texts, metas, embeddings):
chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})

# Execute the batch insert when the batch size is reached
@@ -133,12 +140,12 @@ class TiDBVector(BaseVector):
return ids

def text_exists(self, id: str) -> bool:
result = self.get_ids_by_metadata_field('doc_id', id)
result = self.get_ids_by_metadata_field("doc_id", id)
return bool(result)

def delete_by_ids(self, ids: list[str]) -> None:
with Session(self._engine) as session:
ids_str = ','.join(f"'{doc_id}'" for doc_id in ids)
ids_str = ",".join(f"'{doc_id}'" for doc_id in ids)
select_statement = sql_text(
f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.doc_id' in ({ids_str}); """
)
@@ -180,20 +187,22 @@ class TiDBVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
filter = kwargs.get('filter')
filter = kwargs.get("filter")
distance = 1 - score_threshold

query_vector_str = ", ".join(format(x) for x in query_vector)
query_vector_str = "[" + query_vector_str + "]"
logger.debug(f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}")
logger.debug(
f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}"
)

docs = []
if self._distance_func == 'l2':
tidb_func = 'Vec_l2_distance'
elif self._distance_func == 'cosine':
tidb_func = 'Vec_Cosine_distance'
if self._distance_func == "l2":
tidb_func = "Vec_l2_distance"
elif self._distance_func == "cosine":
tidb_func = "Vec_Cosine_distance"
else:
tidb_func = 'Vec_Cosine_distance'
tidb_func = "Vec_Cosine_distance"

with Session(self._engine) as session:
select_statement = sql_text(
@@ -208,7 +217,7 @@ class TiDBVector(BaseVector):
results = [(row[0], row[1], row[2]) for row in res]
for meta, text, distance in results:
metadata = json.loads(meta)
metadata['score'] = 1 - distance
metadata["score"] = 1 - distance
docs.append(Document(page_content=text, metadata=metadata))
return docs

@@ -224,15 +233,13 @@ class TiDBVector(BaseVector):

class TiDBVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector:

if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))

return TiDBVector(
collection_name=collection_name,

@@ -7,7 +7,6 @@ from core.rag.models.document import Document


class BaseVector(ABC):

def __init__(self, collection_name: str):
self._collection_name = collection_name

@@ -39,18 +38,11 @@ class BaseVector(ABC):
raise NotImplementedError

@abstractmethod
def search_by_vector(
self,
query_vector: list[float],
**kwargs: Any
) -> list[Document]:
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
raise NotImplementedError

@abstractmethod
def search_by_full_text(
self, query: str,
**kwargs: Any
) -> list[Document]:
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
raise NotImplementedError

def delete(self) -> None:
@@ -58,7 +50,7 @@ class BaseVector(ABC):

def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts[:]:
doc_id = text.metadata['doc_id']
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
@@ -66,7 +58,7 @@ class BaseVector(ABC):
return texts

def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata['doc_id'] for text in texts]
return [text.metadata["doc_id"] for text in texts]

@property
def collection_name(self):

@@ -20,17 +20,14 @@ class AbstractVectorFactory(ABC):

@staticmethod
def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> dict:
index_struct_dict = {
"type": vector_type,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}}
return index_struct_dict


class Vector:
def __init__(self, dataset: Dataset, attributes: list = None):
if attributes is None:
attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash', 'page']
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "page"]
self._dataset = dataset
self._embeddings = self._get_embeddings()
self._attributes = attributes
@@ -39,7 +36,7 @@ class Vector:
def _init_vector(self) -> BaseVector:
vector_type = dify_config.VECTOR_STORE
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict['type']
vector_type = self._dataset.index_struct_dict["type"]

if not vector_type:
raise ValueError("Vector store must be specified.")
@@ -52,45 +49,59 @@ class Vector:
match vector_type:
case VectorType.CHROMA:
from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory

return ChromaVectorFactory
case VectorType.MILVUS:
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory

return MilvusVectorFactory
case VectorType.MYSCALE:
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory

return MyScaleVectorFactory
case VectorType.PGVECTOR:
from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory

return PGVectorFactory
case VectorType.PGVECTO_RS:
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory

return PGVectoRSFactory
case VectorType.QDRANT:
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory

return QdrantVectorFactory
case VectorType.RELYT:
from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory

return RelytVectorFactory
case VectorType.ELASTICSEARCH:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory

return ElasticSearchVectorFactory
case VectorType.TIDB_VECTOR:
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory

return TiDBVectorFactory
case VectorType.WEAVIATE:
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory

return WeaviateVectorFactory
case VectorType.TENCENT:
from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory

return TencentVectorFactory
case VectorType.ORACLE:
from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory

return OracleVectorFactory
case VectorType.OPENSEARCH:
from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory

return OpenSearchVectorFactory
case VectorType.ANALYTICDB:
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory

return AnalyticdbVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")
@@ -98,22 +109,14 @@ class Vector:
def create(self, texts: list = None, **kwargs):
if texts:
embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
self._vector_processor.create(
texts=texts,
embeddings=embeddings,
**kwargs
)
self._vector_processor.create(texts=texts, embeddings=embeddings, **kwargs)

def add_texts(self, documents: list[Document], **kwargs):
if kwargs.get('duplicate_check', False):
if kwargs.get("duplicate_check", False):
documents = self._filter_duplicate_texts(documents)

embeddings = self._embeddings.embed_documents([document.page_content for document in documents])
self._vector_processor.create(
texts=documents,
embeddings=embeddings,
**kwargs
)
self._vector_processor.create(texts=documents, embeddings=embeddings, **kwargs)

def text_exists(self, id: str) -> bool:
return self._vector_processor.text_exists(id)
@@ -124,24 +127,18 @@ class Vector:
def delete_by_metadata_field(self, key: str, value: str) -> None:
self._vector_processor.delete_by_metadata_field(key, value)

def search_by_vector(
self, query: str,
**kwargs: Any
) -> list[Document]:
def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]:
query_vector = self._embeddings.embed_query(query)
return self._vector_processor.search_by_vector(query_vector, **kwargs)

def search_by_full_text(
self, query: str,
**kwargs: Any
) -> list[Document]:
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._vector_processor.search_by_full_text(query, **kwargs)

def delete(self) -> None:
self._vector_processor.delete()
# delete collection redis cache
if self._vector_processor.collection_name:
collection_exist_cache_key = 'vector_indexing_{}'.format(self._vector_processor.collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._vector_processor.collection_name)
redis_client.delete(collection_exist_cache_key)

def _get_embeddings(self) -> Embeddings:
@@ -151,14 +148,13 @@ class Vector:
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=self._dataset.embedding_model

model=self._dataset.embedding_model,
)
return CacheEmbedding(embedding_model)

def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts[:]:
doc_id = text.metadata['doc_id']
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)

@@ -2,17 +2,17 @@ from enum import Enum


class VectorType(str, Enum):
ANALYTICDB = 'analyticdb'
CHROMA = 'chroma'
MILVUS = 'milvus'
MYSCALE = 'myscale'
PGVECTOR = 'pgvector'
PGVECTO_RS = 'pgvecto-rs'
QDRANT = 'qdrant'
RELYT = 'relyt'
TIDB_VECTOR = 'tidb_vector'
WEAVIATE = 'weaviate'
OPENSEARCH = 'opensearch'
TENCENT = 'tencent'
ORACLE = 'oracle'
ELASTICSEARCH = 'elasticsearch'
ANALYTICDB = "analyticdb"
CHROMA = "chroma"
MILVUS = "milvus"
MYSCALE = "myscale"
PGVECTOR = "pgvector"
PGVECTO_RS = "pgvecto-rs"
QDRANT = "qdrant"
RELYT = "relyt"
TIDB_VECTOR = "tidb_vector"
WEAVIATE = "weaviate"
OPENSEARCH = "opensearch"
TENCENT = "tencent"
ORACLE = "oracle"
ELASTICSEARCH = "elasticsearch"

@@ -22,15 +22,14 @@ class WeaviateConfig(BaseModel):
api_key: Optional[str] = None
batch_size: int = 100

@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
if not values["endpoint"]:
raise ValueError("config WEAVIATE_ENDPOINT is required")
return values


class WeaviateVector(BaseVector):

def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
super().__init__(collection_name)
self._client = self._init_client(config)
@@ -43,10 +42,7 @@ class WeaviateVector(BaseVector):

try:
client = weaviate.Client(
url=config.endpoint,
auth_client_secret=auth_config,
timeout_config=(5, 60),
startup_period=None
url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None
)
except requests.exceptions.ConnectionError:
raise ConnectionError("Vector database connection error")
@@ -68,10 +64,10 @@ class WeaviateVector(BaseVector):

def get_collection_name(self, dataset: Dataset) -> str:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
if not class_prefix.endswith("_Node"):
# original class_prefix
class_prefix += '_Node'
class_prefix += "_Node"

return class_prefix

@@ -79,10 +75,7 @@ class WeaviateVector(BaseVector):
return Dataset.gen_collection_name_by_id(dataset_id)

def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name}
}
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}

def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
# create collection
@@ -91,9 +84,9 @@ class WeaviateVector(BaseVector):
self.add_texts(texts, embeddings)

def _create_collection(self):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
schema = self._default_schema(self._collection_name)
@@ -129,17 +122,9 @@ class WeaviateVector(BaseVector):
# check whether the index already exists
schema = self._default_schema(self._collection_name)
if self._client.schema.contains(schema):
where_filter = {
"operator": "Equal",
"path": [key],
"valueText": value
}
where_filter = {"operator": "Equal", "path": [key], "valueText": value}

self._client.batch.delete_objects(
class_name=self._collection_name,
where=where_filter,
output='minimal'
)
self._client.batch.delete_objects(class_name=self._collection_name, where=where_filter, output="minimal")

def delete(self):
# check whether the index already exists
@@ -154,11 +139,19 @@ class WeaviateVector(BaseVector):
# check whether the index already exists
if not self._client.schema.contains(schema):
return False
result = self._client.query.get(collection_name).with_additional(["id"]).with_where({
"path": ["doc_id"],
"operator": "Equal",
"valueText": id,
}).with_limit(1).do()
result = (
self._client.query.get(collection_name)
.with_additional(["id"])
.with_where(
{
"path": ["doc_id"],
"operator": "Equal",
"valueText": id,
}
)
.with_limit(1)
.do()
)

if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
@@ -211,13 +204,13 @@ class WeaviateVector(BaseVector):

docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
# check score threshold
if score > score_threshold:
doc.metadata['score'] = score
doc.metadata["score"] = score
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
return docs

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -240,15 +233,15 @@ class WeaviateVector(BaseVector):
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
query_obj = query_obj.with_additional(["vector"])
properties = ['text']
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do()
properties = ["text"]
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 2)).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
additional = res.pop('_additional')
docs.append(Document(page_content=text, vector=additional['vector'], metadata=res))
additional = res.pop("_additional")
docs.append(Document(page_content=text, vector=additional["vector"], metadata=res))
return docs

def _default_schema(self, index_name: str) -> dict:
@@ -271,20 +264,19 @@ class WeaviateVector(BaseVector):
class WeaviateVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))

return WeaviateVector(
collection_name=collection_name,
config=WeaviateConfig(
endpoint=dify_config.WEAVIATE_ENDPOINT,
api_key=dify_config.WEAVIATE_API_KEY,
batch_size=dify_config.WEAVIATE_BATCH_SIZE
batch_size=dify_config.WEAVIATE_BATCH_SIZE,
),
attributes=attributes
attributes=attributes,
)
