chore(api/core): apply ruff reformatting (#7624)

Author: Bowen Liang
Date: 2024-09-10 17:00:20 +08:00
Committed by: GitHub
Parent: 178730266d
Commit: 2cf1187b32
724 changed files with 21180 additions and 21123 deletions
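
Every hunk in this commit follows the same pattern: ruff's formatter normalizes string literals to double quotes, collapses statements that were wrapped even though they fit within the line length, re-wraps statements that exceed it, and adds a trailing comma to argument lists that stay multi-line. A minimal, hypothetical sketch of that transformation follows; the names are illustrative rather than taken from the codebase, and a 120-character line length is assumed, since the project's ruff configuration is not part of this diff.

# Before formatting: single-quoted strings and a call wrapped even though it
# fits comfortably within the assumed 120-character limit.
def client_params_before(host, port):
    return dict(
        hosts=[{'host': host, 'port': port}],
        use_ssl=False
    )


# After running "ruff format": strings use double quotes, and the wrapped call,
# which has no trailing comma and fits on one line, is collapsed. Argument lists
# that remain multi-line (as in many hunks below) gain a trailing comma instead.
def client_params_after(host, port):
    return dict(hosts=[{"host": host, "port": port}], use_ssl=False)

Because these are formatter-only rewrites, the hunks below change layout and quoting but not behavior, which is how a single chore commit can touch 724 files.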

View File

@@ -29,6 +29,7 @@ class AnalyticdbConfig(BaseModel):
namespace_password: str = (None,)
metrics: str = ("cosine",)
read_timeout: int = 60000
def to_analyticdb_client_params(self):
return {
"access_key_id": self.access_key_id,
@@ -37,6 +38,7 @@ class AnalyticdbConfig(BaseModel):
"read_timeout": self.read_timeout,
}
class AnalyticdbVector(BaseVector):
_instance = None
_init = False
@@ -57,9 +59,7 @@ class AnalyticdbVector(BaseVector):
except:
raise ImportError(_import_err_msg)
self.config = config
self._client_config = open_api_models.Config(
user_agent="dify", **config.to_analyticdb_client_params()
)
self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
self._client = Client(self._client_config)
self._initialize()
AnalyticdbVector._init = True
@@ -77,6 +77,7 @@ class AnalyticdbVector(BaseVector):
def _initialize_vector_database(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.InitVectorDatabaseRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
@@ -88,6 +89,7 @@ class AnalyticdbVector(BaseVector):
def _create_namespace_if_not_exists(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
try:
request = gpdb_20160503_models.DescribeNamespaceRequest(
dbinstance_id=self.config.instance_id,
@@ -109,13 +111,12 @@ class AnalyticdbVector(BaseVector):
)
self._client.create_namespace(request)
else:
raise ValueError(
f"failed to create namespace {self.config.namespace}: {e}"
)
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
@@ -149,9 +150,7 @@ class AnalyticdbVector(BaseVector):
)
self._client.create_collection(request)
else:
raise ValueError(
f"failed to create collection {self._collection_name}: {e}"
)
raise ValueError(f"failed to create collection {self._collection_name}: {e}")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def get_type(self) -> str:
@@ -162,10 +161,9 @@ class AnalyticdbVector(BaseVector):
self._create_collection_if_not_exists(dimension)
self.add_texts(texts, embeddings)
def add_texts(
self, documents: list[Document], embeddings: list[list[float]], **kwargs
):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
for doc, embedding in zip(documents, embeddings, strict=True):
metadata = {
@@ -191,6 +189,7 @@ class AnalyticdbVector(BaseVector):
def text_exists(self, id: str) -> bool:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
@@ -202,13 +201,14 @@ class AnalyticdbVector(BaseVector):
vector=None,
content=None,
top_k=1,
filter=f"ref_doc_id='{id}'"
filter=f"ref_doc_id='{id}'",
)
response = self._client.query_collection_data(request)
return len(response.body.matches.match) > 0
def delete_by_ids(self, ids: list[str]) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
ids_str = ",".join(f"'{id}'" for id in ids)
ids_str = f"({ids_str})"
request = gpdb_20160503_models.DeleteCollectionDataRequest(
@@ -224,6 +224,7 @@ class AnalyticdbVector(BaseVector):
def delete_by_metadata_field(self, key: str, value: str) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
@@ -235,15 +236,10 @@ class AnalyticdbVector(BaseVector):
)
self._client.delete_collection_data(request)
def search_by_vector(
self, query_vector: list[float], **kwargs: Any
) -> list[Document]:
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = (
kwargs.get("score_threshold", 0.0)
if kwargs.get("score_threshold", 0.0)
else 0.0
)
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
@@ -270,11 +266,8 @@ class AnalyticdbVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = (
kwargs.get("score_threshold", 0.0)
if kwargs.get("score_threshold", 0.0)
else 0.0
)
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
@@ -304,6 +297,7 @@ class AnalyticdbVector(BaseVector):
def delete(self) -> None:
try:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionRequest(
collection=self._collection_name,
dbinstance_id=self.config.instance_id,
@@ -315,19 +309,16 @@ class AnalyticdbVector(BaseVector):
except Exception as e:
raise e
class AnalyticdbVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings):
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"][
"class_prefix"
]
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)
)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name))
# handle optional params
if dify_config.ANALYTICDB_KEY_ID is None:

View File

@@ -27,21 +27,20 @@ class ChromaConfig(BaseModel):
settings = Settings(
# auth
chroma_client_auth_provider=self.auth_provider,
chroma_client_auth_credentials=self.auth_credentials
chroma_client_auth_credentials=self.auth_credentials,
)
return {
'host': self.host,
'port': self.port,
'ssl': False,
'tenant': self.tenant,
'database': self.database,
'settings': settings,
"host": self.host,
"port": self.port,
"ssl": False,
"tenant": self.tenant,
"database": self.database,
"settings": settings,
}
class ChromaVector(BaseVector):
def __init__(self, collection_name: str, config: ChromaConfig):
super().__init__(collection_name)
self._client_config = config
@@ -58,9 +57,9 @@ class ChromaVector(BaseVector):
self.add_texts(texts, embeddings, **kwargs)
def create_collection(self, collection_name: str):
lock_name = 'vector_indexing_lock_{}'.format(collection_name)
lock_name = "vector_indexing_lock_{}".format(collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
self._client.get_or_create_collection(collection_name)
@@ -76,7 +75,7 @@ class ChromaVector(BaseVector):
def delete_by_metadata_field(self, key: str, value: str):
collection = self._client.get_or_create_collection(self._collection_name)
collection.delete(where={key: {'$eq': value}})
collection.delete(where={key: {"$eq": value}})
def delete(self):
self._client.delete_collection(self._collection_name)
@@ -93,26 +92,26 @@ class ChromaVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
collection = self._client.get_or_create_collection(self._collection_name)
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
ids: list[str] = results['ids'][0]
documents: list[str] = results['documents'][0]
metadatas: dict[str, Any] = results['metadatas'][0]
distances: list[float] = results['distances'][0]
ids: list[str] = results["ids"][0]
documents: list[str] = results["documents"][0]
metadatas: dict[str, Any] = results["metadatas"][0]
distances: list[float] = results["distances"][0]
docs = []
for index in range(len(ids)):
distance = distances[index]
metadata = metadatas[index]
if distance >= score_threshold:
metadata['score'] = distance
metadata["score"] = distance
doc = Document(
page_content=documents[index],
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -123,15 +122,12 @@ class ChromaVector(BaseVector):
class ChromaVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
index_struct_dict = {
"type": VectorType.CHROMA,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": VectorType.CHROMA, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
return ChromaVector(

View File

@@ -26,15 +26,15 @@ class ElasticSearchConfig(BaseModel):
username: str
password: str
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values['host']:
if not values["host"]:
raise ValueError("config HOST is required")
if not values['port']:
if not values["port"]:
raise ValueError("config PORT is required")
if not values['username']:
if not values["username"]:
raise ValueError("config USERNAME is required")
if not values['password']:
if not values["password"]:
raise ValueError("config PASSWORD is required")
return values
@@ -50,10 +50,10 @@ class ElasticSearchVector(BaseVector):
def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
try:
parsed_url = urlparse(config.host)
if parsed_url.scheme in ['http', 'https']:
hosts = f'{config.host}:{config.port}'
if parsed_url.scheme in ["http", "https"]:
hosts = f"{config.host}:{config.port}"
else:
hosts = f'http://{config.host}:{config.port}'
hosts = f"http://{config.host}:{config.port}"
client = Elasticsearch(
hosts=hosts,
basic_auth=(config.username, config.password),
@@ -68,25 +68,27 @@ class ElasticSearchVector(BaseVector):
def _get_version(self) -> str:
info = self._client.info()
return info['version']['number']
return info["version"]["number"]
def _check_version(self):
if self._version < '8.0.0':
if self._version < "8.0.0":
raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
def get_type(self) -> str:
return 'elasticsearch'
return "elasticsearch"
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
for i in range(len(documents)):
self._client.index(index=self._collection_name,
id=uuids[i],
document={
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {}
})
self._client.index(
index=self._collection_name,
id=uuids[i],
document={
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {},
},
)
self._client.indices.refresh(index=self._collection_name)
return uuids
@@ -98,15 +100,9 @@ class ElasticSearchVector(BaseVector):
self._client.delete(index=self._collection_name, id=id)
def delete_by_metadata_field(self, key: str, value: str) -> None:
query_str = {
'query': {
'match': {
f'metadata.{key}': f'{value}'
}
}
}
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
results = self._client.search(index=self._collection_name, body=query_str)
ids = [hit['_id'] for hit in results['hits']['hits']]
ids = [hit["_id"] for hit in results["hits"]["hits"]]
if ids:
self.delete_by_ids(ids)
@@ -115,44 +111,44 @@ class ElasticSearchVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 10)
knn = {
"field": Field.VECTOR.value,
"query_vector": query_vector,
"k": top_k
}
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k}
results = self._client.search(index=self._collection_name, knn=knn, size=top_k)
docs_and_scores = []
for hit in results['hits']['hits']:
for hit in results["hits"]["hits"]:
docs_and_scores.append(
(Document(page_content=hit['_source'][Field.CONTENT_KEY.value],
vector=hit['_source'][Field.VECTOR.value],
metadata=hit['_source'][Field.METADATA_KEY.value]), hit['_score']))
(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
),
hit["_score"],
)
)
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
if score > score_threshold:
doc.metadata['score'] = score
doc.metadata["score"] = score
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str = {
"match": {
Field.CONTENT_KEY.value: query
}
}
query_str = {"match": {Field.CONTENT_KEY.value: query}}
results = self._client.search(index=self._collection_name, query=query_str)
docs = []
for hit in results['hits']['hits']:
docs.append(Document(
page_content=hit['_source'][Field.CONTENT_KEY.value],
vector=hit['_source'][Field.VECTOR.value],
metadata=hit['_source'][Field.METADATA_KEY.value],
))
for hit in results["hits"]["hits"]:
docs.append(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
)
)
return docs
@@ -162,11 +158,11 @@ class ElasticSearchVector(BaseVector):
self.add_texts(texts, embeddings, **kwargs)
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
lock_name = f'vector_indexing_lock_{self._collection_name}'
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f'vector_indexing_{self._collection_name}'
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name} already exists.")
return
@@ -179,14 +175,14 @@ class ElasticSearchVector(BaseVector):
Field.VECTOR.value: { # Make sure the dimension is correct here
"type": "dense_vector",
"dims": dim,
"similarity": "cosine"
"similarity": "cosine",
},
Field.METADATA_KEY.value: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
}
}
},
},
}
}
self._client.indices.create(index=self._collection_name, mappings=mappings)
@@ -197,22 +193,21 @@ class ElasticSearchVector(BaseVector):
class ElasticSearchVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
config = current_app.config
return ElasticSearchVector(
index_name=collection_name,
config=ElasticSearchConfig(
host=config.get('ELASTICSEARCH_HOST'),
port=config.get('ELASTICSEARCH_PORT'),
username=config.get('ELASTICSEARCH_USERNAME'),
password=config.get('ELASTICSEARCH_PASSWORD'),
host=config.get("ELASTICSEARCH_HOST"),
port=config.get("ELASTICSEARCH_PORT"),
username=config.get("ELASTICSEARCH_USERNAME"),
password=config.get("ELASTICSEARCH_PASSWORD"),
),
attributes=[]
attributes=[],
)

View File

@@ -27,44 +27,39 @@ class MilvusConfig(BaseModel):
batch_size: int = 100
database: str = "default"
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values.get('uri'):
if not values.get("uri"):
raise ValueError("config MILVUS_URI is required")
if not values.get('user'):
if not values.get("user"):
raise ValueError("config MILVUS_USER is required")
if not values.get('password'):
if not values.get("password"):
raise ValueError("config MILVUS_PASSWORD is required")
return values
def to_milvus_params(self):
return {
'uri': self.uri,
'token': self.token,
'user': self.user,
'password': self.password,
'db_name': self.database,
"uri": self.uri,
"token": self.token,
"user": self.user,
"password": self.password,
"db_name": self.database,
}
class MilvusVector(BaseVector):
def __init__(self, collection_name: str, config: MilvusConfig):
super().__init__(collection_name)
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = 'Session'
self._consistency_level = "Session"
self._fields = []
def get_type(self) -> str:
return VectorType.MILVUS
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_params = {
'metric_type': 'IP',
'index_type': "HNSW",
'params': {"M": 8, "efConstruction": 64}
}
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
metadatas = [d.metadata for d in texts]
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)
@@ -75,7 +70,7 @@ class MilvusVector(BaseVector):
insert_dict = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata
Field.METADATA_KEY.value: documents[i].metadata,
}
insert_dict_list.append(insert_dict)
# Total insert count
@@ -84,22 +79,20 @@ class MilvusVector(BaseVector):
pks: list[str] = []
for i in range(0, total_count, 1000):
batch_insert_list = insert_dict_list[i:i + 1000]
batch_insert_list = insert_dict_list[i : i + 1000]
# Insert into the collection.
try:
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
pks.extend(ids)
except MilvusException as e:
logger.error(
"Failed to insert batch starting at entity: %s/%s", i, total_count
)
logger.error("Failed to insert batch starting at entity: %s/%s", i, total_count)
raise e
return pks
def get_ids_by_metadata_field(self, key: str, value: str):
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["{key}"] == "{value}"',
output_fields=["id"])
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
)
if result:
return [item["id"] for item in result]
else:
@@ -107,17 +100,15 @@ class MilvusVector(BaseVector):
def delete_by_metadata_field(self, key: str, value: str):
if self._client.has_collection(self._collection_name):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete_by_ids(self, ids: list[str]) -> None:
if self._client.has_collection(self._collection_name):
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] in {ids}',
output_fields=["id"])
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
)
if result:
ids = [item["id"] for item in result]
self._client.delete(collection_name=self._collection_name, pks=ids)
@@ -130,29 +121,28 @@ class MilvusVector(BaseVector):
if not self._client.has_collection(self._collection_name):
return False
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] == "{id}"',
output_fields=["id"])
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["doc_id"] == "{id}"', output_fields=["id"]
)
return len(result) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Set search parameters.
results = self._client.search(collection_name=self._collection_name,
data=[query_vector],
limit=kwargs.get('top_k', 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
results = self._client.search(
collection_name=self._collection_name,
data=[query_vector],
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
# Organize results.
docs = []
for result in results[0]:
metadata = result['entity'].get(Field.METADATA_KEY.value)
metadata['score'] = result['distance']
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
if result['distance'] > score_threshold:
doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value),
metadata=metadata)
metadata = result["entity"].get(Field.METADATA_KEY.value)
metadata["score"] = result["distance"]
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
return docs
@@ -161,11 +151,11 @@ class MilvusVector(BaseVector):
return []
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
# Grab the existing collection if it exists
@@ -180,19 +170,11 @@ class MilvusVector(BaseVector):
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
# Create the text field
fields.append(
FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
)
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535))
# Create the primary key field
fields.append(
FieldSchema(
Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
)
)
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
# Create the vector field, supports binary or float vectors
fields.append(
FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
)
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
# Create the schema for the collection
schema = CollectionSchema(fields)
@@ -208,9 +190,12 @@ class MilvusVector(BaseVector):
# Create the collection
collection_name = self._collection_name
self._client.create_collection(collection_name=collection_name,
schema=schema, index_params=index_params_obj,
consistency_level=self._consistency_level)
self._client.create_collection(
collection_name=collection_name,
schema=schema,
index_params=index_params_obj,
consistency_level=self._consistency_level,
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _init_client(self, config) -> MilvusClient:
@@ -221,13 +206,12 @@ class MilvusVector(BaseVector):
class MilvusVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.MILVUS, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MILVUS, collection_name))
return MilvusVector(
collection_name=collection_name,
@@ -237,5 +221,5 @@ class MilvusVectorFactory(AbstractVectorFactory):
user=dify_config.MILVUS_USER,
password=dify_config.MILVUS_PASSWORD,
database=dify_config.MILVUS_DATABASE,
)
),
)

View File

@@ -31,7 +31,6 @@ class SortOrder(Enum):
class MyScaleVector(BaseVector):
def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"):
super().__init__(collection_name)
self._config = config
@@ -80,7 +79,7 @@ class MyScaleVector(BaseVector):
doc_id,
self.escape_str(doc.page_content),
embeddings[i],
json.dumps(doc.metadata) if doc.metadata else {}
json.dumps(doc.metadata) if doc.metadata else {},
)
values.append(str(row))
ids.append(doc_id)
@@ -101,7 +100,8 @@ class MyScaleVector(BaseVector):
def delete_by_ids(self, ids: list[str]) -> None:
self._client.command(
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}")
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
)
def get_ids_by_metadata_field(self, key: str, value: str):
rows = self._client.query(
@@ -122,9 +122,12 @@ class MyScaleVector(BaseVector):
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get('score_threshold') or 0.0
where_str = f"WHERE dist < {1 - score_threshold}" if \
self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else ""
score_threshold = kwargs.get("score_threshold") or 0.0
where_str = (
f"WHERE dist < {1 - score_threshold}"
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
else ""
)
sql = f"""
SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
@@ -133,7 +136,7 @@ class MyScaleVector(BaseVector):
return [
Document(
page_content=r["text"],
vector=r['vector'],
vector=r["vector"],
metadata=r["metadata"],
)
for r in self._client.query(sql).named_results()
@@ -149,13 +152,12 @@ class MyScaleVector(BaseVector):
class MyScaleVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MyScaleVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
return MyScaleVector(
collection_name=collection_name,

View File

@@ -28,11 +28,11 @@ class OpenSearchConfig(BaseModel):
password: Optional[str] = None
secure: bool = False
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values.get('host'):
if not values.get("host"):
raise ValueError("config OPENSEARCH_HOST is required")
if not values.get('port'):
if not values.get("port"):
raise ValueError("config OPENSEARCH_PORT is required")
return values
@@ -44,19 +44,18 @@ class OpenSearchConfig(BaseModel):
def to_opensearch_params(self) -> dict[str, Any]:
params = {
'hosts': [{'host': self.host, 'port': self.port}],
'use_ssl': self.secure,
'verify_certs': self.secure,
"hosts": [{"host": self.host, "port": self.port}],
"use_ssl": self.secure,
"verify_certs": self.secure,
}
if self.user and self.password:
params['http_auth'] = (self.user, self.password)
params["http_auth"] = (self.user, self.password)
if self.secure:
params['ssl_context'] = self.create_ssl_context()
params["ssl_context"] = self.create_ssl_context()
return params
class OpenSearchVector(BaseVector):
def __init__(self, collection_name: str, config: OpenSearchConfig):
super().__init__(collection_name)
self._client_config = config
@@ -81,7 +80,7 @@ class OpenSearchVector(BaseVector):
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
}
},
}
actions.append(action)
@@ -90,8 +89,8 @@ class OpenSearchVector(BaseVector):
def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
response = self._client.search(index=self._collection_name.lower(), body=query)
if response['hits']['hits']:
return [hit['_id'] for hit in response['hits']['hits']]
if response["hits"]["hits"]:
return [hit["_id"] for hit in response["hits"]["hits"]]
else:
return None
@@ -110,7 +109,7 @@ class OpenSearchVector(BaseVector):
actual_ids = []
for doc_id in ids:
es_ids = self.get_ids_by_metadata_field('doc_id', doc_id)
es_ids = self.get_ids_by_metadata_field("doc_id", doc_id)
if es_ids:
actual_ids.extend(es_ids)
else:
@@ -122,9 +121,9 @@ class OpenSearchVector(BaseVector):
helpers.bulk(self._client, actions)
except BulkIndexError as e:
for error in e.errors:
delete_error = error.get('delete', {})
status = delete_error.get('status')
doc_id = delete_error.get('_id')
delete_error = error.get("delete", {})
status = delete_error.get("status")
doc_id = delete_error.get("_id")
if status == 404:
logger.warning(f"Document not found for deletion: {doc_id}")
@@ -151,15 +150,8 @@ class OpenSearchVector(BaseVector):
raise ValueError("All elements in query_vector should be floats")
query = {
"size": kwargs.get('top_k', 4),
"query": {
"knn": {
Field.VECTOR.value: {
Field.VECTOR.value: query_vector,
"k": kwargs.get('top_k', 4)
}
}
}
"size": kwargs.get("top_k", 4),
"query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}},
}
try:
@@ -169,17 +161,17 @@ class OpenSearchVector(BaseVector):
raise
docs = []
for hit in response['hits']['hits']:
metadata = hit['_source'].get(Field.METADATA_KEY.value, {})
for hit in response["hits"]["hits"]:
metadata = hit["_source"].get(Field.METADATA_KEY.value, {})
# Make sure metadata is a dictionary
if metadata is None:
metadata = {}
metadata['score'] = hit['_score']
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
if hit['_score'] > score_threshold:
doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata)
metadata["score"] = hit["_score"]
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
if hit["_score"] > score_threshold:
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
return docs
@@ -190,32 +182,28 @@ class OpenSearchVector(BaseVector):
response = self._client.search(index=self._collection_name.lower(), body=full_text_query)
docs = []
for hit in response['hits']['hits']:
metadata = hit['_source'].get(Field.METADATA_KEY.value)
vector = hit['_source'].get(Field.VECTOR.value)
page_content = hit['_source'].get(Field.CONTENT_KEY.value)
for hit in response["hits"]["hits"]:
metadata = hit["_source"].get(Field.METADATA_KEY.value)
vector = hit["_source"].get(Field.VECTOR.value)
page_content = hit["_source"].get(Field.CONTENT_KEY.value)
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
docs.append(doc)
return docs
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
lock_name = f'vector_indexing_lock_{self._collection_name.lower()}'
lock_name = f"vector_indexing_lock_{self._collection_name.lower()}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f'vector_indexing_{self._collection_name.lower()}'
collection_exist_cache_key = f"vector_indexing_{self._collection_name.lower()}"
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name.lower()} already exists.")
return
if not self._client.indices.exists(index=self._collection_name.lower()):
index_body = {
"settings": {
"index": {
"knn": True
}
},
"settings": {"index": {"knn": True}},
"mappings": {
"properties": {
Field.CONTENT_KEY.value: {"type": "text"},
@@ -226,20 +214,17 @@ class OpenSearchVector(BaseVector):
"name": "hnsw",
"space_type": "l2",
"engine": "faiss",
"parameters": {
"ef_construction": 64,
"m": 8
}
}
"parameters": {"ef_construction": 64, "m": 8},
},
},
Field.METADATA_KEY.value: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
}
}
},
},
}
}
},
}
self._client.indices.create(index=self._collection_name.lower(), body=index_body)
@@ -248,17 +233,14 @@ class OpenSearchVector(BaseVector):
class OpenSearchVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenSearchVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
open_search_config = OpenSearchConfig(
host=dify_config.OPENSEARCH_HOST,
@@ -268,7 +250,4 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
secure=dify_config.OPENSEARCH_SECURE,
)
return OpenSearchVector(
collection_name=collection_name,
config=open_search_config
)
return OpenSearchVector(collection_name=collection_name, config=open_search_config)

View File

@@ -31,7 +31,7 @@ class OracleVectorConfig(BaseModel):
password: str
database: str
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config ORACLE_HOST is required")
@@ -103,9 +103,16 @@ class OracleVector(BaseVector):
arraysize=cursor.arraysize,
outconverter=self.numpy_converter_out,
)
def _create_connection_pool(self, config: OracleVectorConfig):
return oracledb.create_pool(user=config.user, password=config.password, dsn="{}:{}/{}".format(config.host, config.port, config.database), min=1, max=50, increment=1)
def _create_connection_pool(self, config: OracleVectorConfig):
return oracledb.create_pool(
user=config.user,
password=config.password,
dsn="{}:{}/{}".format(config.host, config.port, config.database),
min=1,
max=50,
increment=1,
)
@contextmanager
def _get_cursor(self):
@@ -136,13 +143,15 @@ class OracleVector(BaseVector):
doc_id,
doc.page_content,
json.dumps(doc.metadata),
#array.array("f", embeddings[i]),
# array.array("f", embeddings[i]),
numpy.array(embeddings[i]),
)
)
#print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
# print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
with self._get_cursor() as cur:
cur.executemany(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values)
cur.executemany(
f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values
)
return pks
def text_exists(self, id: str) -> bool:
@@ -157,7 +166,8 @@ class OracleVector(BaseVector):
for record in cur:
docs.append(Document(page_content=record[1], metadata=record[0]))
return docs
#def get_ids_by_metadata_field(self, key: str, value: str):
# def get_ids_by_metadata_field(self, key: str, value: str):
# with self._get_cursor() as cur:
# cur.execute(f"SELECT id FROM {self.table_name} d WHERE d.meta.{key}='{value}'" )
# idss = []
@@ -184,7 +194,8 @@ class OracleVector(BaseVector):
top_k = kwargs.get("top_k", 5)
with self._get_cursor() as cur:
cur.execute(
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name} ORDER BY distance fetch first {top_k} rows only" ,[numpy.array(query_vector)]
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name} ORDER BY distance fetch first {top_k} rows only",
[numpy.array(query_vector)],
)
docs = []
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
@@ -202,7 +213,7 @@ class OracleVector(BaseVector):
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
if len(query) > 0:
# Check which language the query is in
zh_pattern = re.compile('[\u4e00-\u9fa5]+')
zh_pattern = re.compile("[\u4e00-\u9fa5]+")
match = zh_pattern.search(query)
entities = []
# match: query condition maybe is a chinese sentence, so using Jieba split,else using nltk split
@@ -210,7 +221,15 @@ class OracleVector(BaseVector):
words = pseg.cut(query)
current_entity = ""
for word, pos in words:
if pos == 'nr' or pos == 'Ng' or pos == 'eng' or pos == 'nz' or pos == 'n' or pos == 'ORG' or pos == 'v': # nr: 人名, ns: 地名, nt: 机构名
if (
pos == "nr"
or pos == "Ng"
or pos == "eng"
or pos == "nz"
or pos == "n"
or pos == "ORG"
or pos == "v"
): # nr: 人名, ns: 地名, nt: 机构名
current_entity += word
else:
if current_entity:
@@ -220,22 +239,22 @@ class OracleVector(BaseVector):
entities.append(current_entity)
else:
try:
nltk.data.find('tokenizers/punkt')
nltk.data.find('corpora/stopwords')
nltk.data.find("tokenizers/punkt")
nltk.data.find("corpora/stopwords")
except LookupError:
nltk.download('punkt')
nltk.download('stopwords')
nltk.download("punkt")
nltk.download("stopwords")
print("run download")
e_str = re.sub(r'[^\w ]', '', query)
e_str = re.sub(r"[^\w ]", "", query)
all_tokens = nltk.word_tokenize(e_str)
stop_words = stopwords.words('english')
stop_words = stopwords.words("english")
for token in all_tokens:
if token not in stop_words:
entities.append(token)
with self._get_cursor() as cur:
cur.execute(
f"select meta, text, embedding FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
[" ACCUM ".join(entities)]
[" ACCUM ".join(entities)],
)
docs = []
for record in cur:
@@ -273,8 +292,7 @@ class OracleVectorFactory(AbstractVectorFactory):
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ORACLE, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ORACLE, collection_name))
return OracleVector(
collection_name=collection_name,

View File

@@ -31,27 +31,28 @@ class PgvectoRSConfig(BaseModel):
password: str
database: str
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values['host']:
if not values["host"]:
raise ValueError("config PGVECTO_RS_HOST is required")
if not values['port']:
if not values["port"]:
raise ValueError("config PGVECTO_RS_PORT is required")
if not values['user']:
if not values["user"]:
raise ValueError("config PGVECTO_RS_USER is required")
if not values['password']:
if not values["password"]:
raise ValueError("config PGVECTO_RS_PASSWORD is required")
if not values['database']:
if not values["database"]:
raise ValueError("config PGVECTO_RS_DATABASE is required")
return values
class PGVectoRS(BaseVector):
def __init__(self, collection_name: str, config: PgvectoRSConfig, dim: int):
super().__init__(collection_name)
self._client_config = config
self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
self._url = (
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
)
self._client = create_engine(self._url)
with Session(self._client) as session:
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
@@ -80,9 +81,9 @@ class PGVectoRS(BaseVector):
self.add_texts(texts, embeddings)
def create_collection(self, dimension: int):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
index_name = f"{self._collection_name}_embedding_index"
@@ -133,9 +134,7 @@ class PGVectoRS(BaseVector):
def get_ids_by_metadata_field(self, key: str, value: str):
result = None
with Session(self._client) as session:
select_statement = sql_text(
f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; "
)
select_statement = sql_text(f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; ")
result = session.execute(select_statement).fetchall()
if result:
return [item[0] for item in result]
@@ -143,12 +142,11 @@ class PGVectoRS(BaseVector):
return None
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
with Session(self._client) as session:
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
session.execute(select_statement, {'ids': ids})
session.execute(select_statement, {"ids": ids})
session.commit()
def delete_by_ids(self, ids: list[str]) -> None:
@@ -156,13 +154,13 @@ class PGVectoRS(BaseVector):
select_statement = sql_text(
f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = ANY (:doc_ids); "
)
result = session.execute(select_statement, {'doc_ids': ids}).fetchall()
result = session.execute(select_statement, {"doc_ids": ids}).fetchall()
if result:
ids = [item[0] for item in result]
if ids:
with Session(self._client) as session:
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
session.execute(select_statement, {'ids': ids})
session.execute(select_statement, {"ids": ids})
session.commit()
def delete(self) -> None:
@@ -187,7 +185,7 @@ class PGVectoRS(BaseVector):
query_vector,
).label("distance"),
)
.limit(kwargs.get('top_k', 2))
.limit(kwargs.get("top_k", 2))
.order_by("distance")
)
res = session.execute(stmt)
@@ -198,11 +196,10 @@ class PGVectoRS(BaseVector):
for record, dis in results:
metadata = record.meta
score = 1 - dis
metadata['score'] = score
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
metadata["score"] = score
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
if score > score_threshold:
doc = Document(page_content=record.text,
metadata=metadata)
doc = Document(page_content=record.text, metadata=metadata)
docs.append(doc)
return docs
@@ -225,13 +222,12 @@ class PGVectoRS(BaseVector):
class PGVectoRSFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVectoRS:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
dim = len(embeddings.embed_query("pgvecto_rs"))
return PGVectoRS(
@@ -243,5 +239,5 @@ class PGVectoRSFactory(AbstractVectorFactory):
password=dify_config.PGVECTO_RS_PASSWORD,
database=dify_config.PGVECTO_RS_DATABASE,
),
dim=dim
dim=dim,
)

View File

@@ -24,7 +24,7 @@ class PGVectorConfig(BaseModel):
password: str
database: str
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config PGVECTOR_HOST is required")
@@ -201,8 +201,7 @@ class PGVectorFactory(AbstractVectorFactory):
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))
return PGVector(
collection_name=collection_name,

View File

@@ -48,28 +48,25 @@ class QdrantConfig(BaseModel):
prefer_grpc: bool = False
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith('path:'):
path = self.endpoint.replace('path:', '')
if self.endpoint and self.endpoint.startswith("path:"):
path = self.endpoint.replace("path:", "")
if not os.path.isabs(path):
path = os.path.join(self.root_path, path)
return {
'path': path
}
return {"path": path}
else:
return {
'url': self.endpoint,
'api_key': self.api_key,
'timeout': self.timeout,
'verify': self.endpoint.startswith('https'),
'grpc_port': self.grpc_port,
'prefer_grpc': self.prefer_grpc
"url": self.endpoint,
"api_key": self.api_key,
"timeout": self.timeout,
"verify": self.endpoint.startswith("https"),
"grpc_port": self.grpc_port,
"prefer_grpc": self.prefer_grpc,
}
class QdrantVector(BaseVector):
def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = 'Cosine'):
def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"):
super().__init__(collection_name)
self._client_config = config
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
@@ -80,10 +77,7 @@ class QdrantVector(BaseVector):
return VectorType.QDRANT
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name}
}
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts:
@@ -97,9 +91,9 @@ class QdrantVector(BaseVector):
self.add_texts(texts, embeddings, **kwargs)
def create_collection(self, collection_name: str, vector_size: int):
lock_name = 'vector_indexing_lock_{}'.format(collection_name)
lock_name = "vector_indexing_lock_{}".format(collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
collection_name = collection_name or uuid.uuid4().hex
@@ -110,12 +104,19 @@ class QdrantVector(BaseVector):
all_collection_name.append(collection.name)
if collection_name not in all_collection_name:
from qdrant_client.http import models as rest
vectors_config = rest.VectorParams(
size=vector_size,
distance=rest.Distance[self._distance_func],
)
hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False)
hnsw_config = HnswConfigDiff(
m=0,
payload_m=16,
ef_construct=100,
full_scan_threshold=10000,
max_indexing_threads=0,
on_disk=False,
)
self._client.recreate_collection(
collection_name=collection_name,
vectors_config=vectors_config,
@@ -124,21 +125,24 @@ class QdrantVector(BaseVector):
)
# create group_id payload index
self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
field_schema=PayloadSchemaType.KEYWORD)
self._client.create_payload_index(
collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
)
# create doc_id payload index
self._client.create_payload_index(collection_name, Field.DOC_ID.value,
field_schema=PayloadSchemaType.KEYWORD)
self._client.create_payload_index(
collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
)
# create full text index
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
tokenizer=TokenizerType.MULTILINGUAL,
min_token_len=2,
max_token_len=20,
lowercase=True
lowercase=True,
)
self._client.create_payload_index(
collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
)
self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value,
field_schema=text_index_params)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@@ -147,26 +151,23 @@ class QdrantVector(BaseVector):
metadatas = [d.metadata for d in documents]
added_ids = []
for batch_ids, points in self._generate_rest_batches(
texts, embeddings, metadatas, uuids, 64, self._group_id
):
self._client.upsert(
collection_name=self._collection_name, points=points
)
for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
self._client.upsert(collection_name=self._collection_name, points=points)
added_ids.extend(batch_ids)
return added_ids
def _generate_rest_batches(
self,
texts: Iterable[str],
embeddings: list[list[float]],
metadatas: Optional[list[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
group_id: Optional[str] = None,
self,
texts: Iterable[str],
embeddings: list[list[float]],
metadatas: Optional[list[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
group_id: Optional[str] = None,
) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
from qdrant_client.http import models as rest
texts_iterator = iter(texts)
embeddings_iterator = iter(embeddings)
metadatas_iterator = iter(metadatas or [])
@@ -203,13 +204,13 @@ class QdrantVector(BaseVector):
@classmethod
def _build_payloads(
cls,
texts: Iterable[str],
metadatas: Optional[list[dict]],
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str
cls,
texts: Iterable[str],
metadatas: Optional[list[dict]],
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str,
) -> list[dict]:
payloads = []
for i, text in enumerate(texts):
@@ -219,18 +220,11 @@ class QdrantVector(BaseVector):
"calling .from_texts or .add_texts on Qdrant instance."
)
metadata = metadatas[i] if metadatas is not None else None
payloads.append(
{
content_payload_key: text,
metadata_payload_key: metadata,
group_payload_key: group_id
}
)
payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id})
return payloads
def delete_by_metadata_field(self, key: str, value: str):
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
@@ -248,9 +242,7 @@ class QdrantVector(BaseVector):
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
@@ -275,9 +267,7 @@ class QdrantVector(BaseVector):
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
@@ -288,7 +278,6 @@ class QdrantVector(BaseVector):
raise e
def delete_by_ids(self, ids: list[str]) -> None:
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
@@ -304,9 +293,7 @@ class QdrantVector(BaseVector):
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
@@ -324,15 +311,13 @@ class QdrantVector(BaseVector):
all_collection_name.append(collection.name)
if self._collection_name not in all_collection_name:
return False
response = self._client.retrieve(
collection_name=self._collection_name,
ids=[id]
)
response = self._client.retrieve(collection_name=self._collection_name, ids=[id])
return len(response) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from qdrant_client.http import models
filter = models.Filter(
must=[
models.FieldCondition(
@@ -348,22 +333,22 @@ class QdrantVector(BaseVector):
limit=kwargs.get("top_k", 4),
with_payload=True,
with_vectors=True,
score_threshold=kwargs.get("score_threshold", .0)
score_threshold=kwargs.get("score_threshold", 0.0),
)
docs = []
for result in results:
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
if result.score > score_threshold:
metadata['score'] = result.score
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value),
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -372,6 +357,7 @@ class QdrantVector(BaseVector):
List of documents most similar to the query text and distance for each.
"""
from qdrant_client.http import models
scroll_filter = models.Filter(
must=[
models.FieldCondition(
@@ -381,24 +367,21 @@ class QdrantVector(BaseVector):
models.FieldCondition(
key="page_content",
match=models.MatchText(text=query),
)
),
]
)
response = self._client.scroll(
collection_name=self._collection_name,
scroll_filter=scroll_filter,
limit=kwargs.get('top_k', 2),
limit=kwargs.get("top_k", 2),
with_payload=True,
with_vectors=True
with_vectors=True,
)
results = response[0]
documents = []
for result in results:
if result:
document = self._document_from_scored_point(
result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
)
document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
documents.append(document)
return documents
@@ -410,10 +393,10 @@ class QdrantVector(BaseVector):
@classmethod
def _document_from_scored_point(
cls,
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
cls,
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
) -> Document:
return Document(
page_content=scored_point.payload.get(content_payload_key),
@@ -425,24 +408,25 @@ class QdrantVector(BaseVector):
class QdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector:
if dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
one_or_none()
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
raise ValueError('Dataset Collection Bindings is not exist!')
raise ValueError("Dataset Collection Bindings is not exist!")
else:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
if not dataset.index_struct_dict:
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.QDRANT, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.QDRANT, collection_name))
config = current_app.config
return QdrantVector(
@@ -454,6 +438,6 @@ class QdrantVectorFactory(AbstractVectorFactory):
root_path=config.root_path,
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.QDRANT_GRPC_PORT,
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED
)
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
),
)


@@ -33,28 +33,29 @@ class RelytConfig(BaseModel):
password: str
database: str
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values['host']:
if not values["host"]:
raise ValueError("config RELYT_HOST is required")
if not values['port']:
if not values["port"]:
raise ValueError("config RELYT_PORT is required")
if not values['user']:
if not values["user"]:
raise ValueError("config RELYT_USER is required")
if not values['password']:
if not values["password"]:
raise ValueError("config RELYT_PASSWORD is required")
if not values['database']:
if not values["database"]:
raise ValueError("config RELYT_DATABASE is required")
return values
class RelytVector(BaseVector):
def __init__(self, collection_name: str, config: RelytConfig, group_id: str):
super().__init__(collection_name)
self.embedding_dimension = 1536
self._client_config = config
self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
self._url = (
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
)
self.client = create_engine(self._url)
self._fields = []
self._group_id = group_id
@@ -70,9 +71,9 @@ class RelytVector(BaseVector):
self.add_texts(texts, embeddings)
def create_collection(self, dimension: int):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
index_name = f"{self._collection_name}_embedding_index"
@@ -110,7 +111,7 @@ class RelytVector(BaseVector):
ids = [str(uuid.uuid1()) for _ in documents]
metadatas = [d.metadata for d in documents]
for metadata in metadatas:
metadata['group_id'] = self._group_id
metadata["group_id"] = self._group_id
texts = [d.page_content for d in documents]
# Define the table schema
@@ -127,9 +128,7 @@ class RelytVector(BaseVector):
chunks_table_data = []
with self.client.connect() as conn:
with conn.begin():
for document, metadata, chunk_id, embedding in zip(
texts, metadatas, ids, embeddings
):
for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings):
chunks_table_data.append(
{
"id": chunk_id,
@@ -196,15 +195,13 @@ class RelytVector(BaseVector):
return False
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self.delete_by_uuids(ids)
def delete_by_ids(self, ids: list[str]) -> None:
with Session(self.client) as session:
ids_str = ','.join(f"'{doc_id}'" for doc_id in ids)
ids_str = ",".join(f"'{doc_id}'" for doc_id in ids)
select_statement = sql_text(
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """
)
@@ -228,38 +225,34 @@ class RelytVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
results = self.similarity_search_with_score_by_vector(
k=int(kwargs.get('top_k')),
embedding=query_vector,
filter=kwargs.get('filter')
k=int(kwargs.get("top_k")), embedding=query_vector, filter=kwargs.get("filter")
)
# Organize results.
docs = []
for document, score in results:
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
if 1 - score > score_threshold:
docs.append(document)
return docs
def similarity_search_with_score_by_vector(
self,
embedding: list[float],
k: int = 4,
filter: Optional[dict] = None,
self,
embedding: list[float],
k: int = 4,
filter: Optional[dict] = None,
) -> list[tuple[Document, float]]:
# Add the filter if provided
try:
from sqlalchemy.engine import Row
except ImportError:
raise ImportError(
"Could not import Row from sqlalchemy.engine. "
"Please 'pip install sqlalchemy>=1.4'."
)
raise ImportError("Could not import Row from sqlalchemy.engine. " "Please 'pip install sqlalchemy>=1.4'.")
filter_condition = ""
if filter is not None:
conditions = [
f"metadata->>{key!r} in ({', '.join(map(repr, value))})" if len(value) > 1
f"metadata->>{key!r} in ({', '.join(map(repr, value))})"
if len(value) > 1
else f"metadata->>{key!r} = {value[0]!r}"
for key, value in filter.items()
]
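A standalone sketch of the metadata filter fragments produced by the comprehension above, with hypothetical filter values (how the fragments are combined into the final query lies outside the shown hunk):

    # One single-valued key and one multi-valued key, the two shapes the comprehension handles.
    filter = {"group_id": ["group-1"], "doc_id": ["a", "b"]}

    conditions = [
        f"metadata->>{key!r} in ({', '.join(map(repr, value))})"
        if len(value) > 1
        else f"metadata->>{key!r} = {value[0]!r}"
        for key, value in filter.items()
    ]
    print(conditions)
    # ["metadata->>'group_id' = 'group-1'", "metadata->>'doc_id' in ('a', 'b')"]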
@@ -305,13 +298,12 @@ class RelytVector(BaseVector):
class RelytVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.RELYT, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.RELYT, collection_name))
return RelytVector(
collection_name=collection_name,
@@ -322,5 +314,5 @@ class RelytVectorFactory(AbstractVectorFactory):
password=dify_config.RELYT_PASSWORD,
database=dify_config.RELYT_DATABASE,
),
group_id=dataset.id
group_id=dataset.id,
)


@@ -25,16 +25,11 @@ class TencentConfig(BaseModel):
database: Optional[str]
index_type: str = "HNSW"
metric_type: str = "L2"
shard: int = 1,
replicas: int = 2,
shard: int = (1,)
replicas: int = (2,)
def to_tencent_params(self):
return {
'url': self.url,
'username': self.username,
'key': self.api_key,
'timeout': self.timeout
}
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
class TencentVector(BaseVector):
@@ -61,13 +56,10 @@ class TencentVector(BaseVector):
return self._client.create_database(database_name=self._client_config.database)
def get_type(self) -> str:
return 'tencent'
return "tencent"
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name}
}
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def _has_collection(self) -> bool:
collections = self._db.list_collections()
@@ -77,9 +69,9 @@ class TencentVector(BaseVector):
return False
def _create_collection(self, dimension: int) -> None:
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
@@ -101,9 +93,7 @@ class TencentVector(BaseVector):
raise ValueError("unsupported metric_type")
params = vdb_index.HNSWParams(m=16, efconstruction=200)
index = vdb_index.Index(
vdb_index.FilterIndex(
self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY
),
vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
vdb_index.VectorIndex(
self.field_vector,
dimension,
@@ -111,12 +101,8 @@ class TencentVector(BaseVector):
metric_type,
params,
),
vdb_index.FilterIndex(
self.field_text, enum.FieldType.String, enum.IndexType.FILTER
),
vdb_index.FilterIndex(
self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
),
vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER),
vdb_index.FilterIndex(self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER),
)
self._db.create_collection(
@@ -163,15 +149,14 @@ class TencentVector(BaseVector):
self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value])))
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
res = self._db.collection(self._collection_name).search(vectors=[query_vector],
params=document.HNSWSearchParams(
ef=kwargs.get("ef", 10)),
retrieve_vector=False,
limit=kwargs.get('top_k', 4),
timeout=self._client_config.timeout,
)
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
res = self._db.collection(self._collection_name).search(
vectors=[query_vector],
params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)),
retrieve_vector=False,
limit=kwargs.get("top_k", 4),
timeout=self._client_config.timeout,
)
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
return self._get_search_res(res, score_threshold)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -200,15 +185,13 @@ class TencentVector(BaseVector):
class TencentVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TencentVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.TENCENT, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TENCENT, collection_name))
return TencentVector(
collection_name=collection_name,
@@ -220,5 +203,5 @@ class TencentVectorFactory(AbstractVectorFactory):
database=dify_config.TENCENT_VECTOR_DB_DATABASE,
shard=dify_config.TENCENT_VECTOR_DB_SHARD,
replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS,
)
),
)


@@ -28,47 +28,56 @@ class TiDBVectorConfig(BaseModel):
database: str
program_name: str
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values['host']:
if not values["host"]:
raise ValueError("config TIDB_VECTOR_HOST is required")
if not values['port']:
if not values["port"]:
raise ValueError("config TIDB_VECTOR_PORT is required")
if not values['user']:
if not values["user"]:
raise ValueError("config TIDB_VECTOR_USER is required")
if not values['password']:
if not values["password"]:
raise ValueError("config TIDB_VECTOR_PASSWORD is required")
if not values['database']:
if not values["database"]:
raise ValueError("config TIDB_VECTOR_DATABASE is required")
if not values['program_name']:
if not values["program_name"]:
raise ValueError("config APPLICATION_NAME is required")
return values
class TiDBVector(BaseVector):
def get_type(self) -> str:
return VectorType.TIDB_VECTOR
def _table(self, dim: int) -> Table:
from tidb_vector.sqlalchemy import VectorType
return Table(
self._collection_name,
self._orm_base.metadata,
Column('id', String(36), primary_key=True, nullable=False),
Column("vector", VectorType(dim), nullable=False, comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})"),
Column("id", String(36), primary_key=True, nullable=False),
Column(
"vector",
VectorType(dim),
nullable=False,
comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})",
),
Column("text", TEXT, nullable=False),
Column("meta", JSON, nullable=False),
Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
Column("update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")),
extend_existing=True
Column(
"update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")
),
extend_existing=True,
)
def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = 'cosine'):
def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = "cosine"):
super().__init__(collection_name)
self._client_config = config
self._url = (f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?"
f"ssl_verify_cert=true&ssl_verify_identity=true&program_name={config.program_name}")
self._url = (
f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?"
f"ssl_verify_cert=true&ssl_verify_identity=true&program_name={config.program_name}"
)
self._distance_func = distance_func.lower()
self._engine = create_engine(self._url)
self._orm_base = declarative_base()
@@ -83,9 +92,9 @@ class TiDBVector(BaseVector):
def _create_collection(self, dimension: int):
logger.info("_create_collection, collection_name " + self._collection_name)
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
with Session(self._engine) as session:
@@ -116,9 +125,7 @@ class TiDBVector(BaseVector):
chunks_table_data = []
with self._engine.connect() as conn:
with conn.begin():
for id, text, meta, embedding in zip(
ids, texts, metas, embeddings
):
for id, text, meta, embedding in zip(ids, texts, metas, embeddings):
chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
# Execute the batch insert when the batch size is reached
@@ -133,12 +140,12 @@ class TiDBVector(BaseVector):
return ids
def text_exists(self, id: str) -> bool:
result = self.get_ids_by_metadata_field('doc_id', id)
result = self.get_ids_by_metadata_field("doc_id", id)
return bool(result)
def delete_by_ids(self, ids: list[str]) -> None:
with Session(self._engine) as session:
ids_str = ','.join(f"'{doc_id}'" for doc_id in ids)
ids_str = ",".join(f"'{doc_id}'" for doc_id in ids)
select_statement = sql_text(
f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.doc_id' in ({ids_str}); """
)
@@ -180,20 +187,22 @@ class TiDBVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
filter = kwargs.get('filter')
filter = kwargs.get("filter")
distance = 1 - score_threshold
query_vector_str = ", ".join(format(x) for x in query_vector)
query_vector_str = "[" + query_vector_str + "]"
logger.debug(f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}")
logger.debug(
f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}"
)
docs = []
if self._distance_func == 'l2':
tidb_func = 'Vec_l2_distance'
elif self._distance_func == 'cosine':
tidb_func = 'Vec_Cosine_distance'
if self._distance_func == "l2":
tidb_func = "Vec_l2_distance"
elif self._distance_func == "cosine":
tidb_func = "Vec_Cosine_distance"
else:
tidb_func = 'Vec_Cosine_distance'
tidb_func = "Vec_Cosine_distance"
with Session(self._engine) as session:
select_statement = sql_text(
@@ -208,7 +217,7 @@ class TiDBVector(BaseVector):
results = [(row[0], row[1], row[2]) for row in res]
for meta, text, distance in results:
metadata = json.loads(meta)
metadata['score'] = 1 - distance
metadata["score"] = 1 - distance
docs.append(Document(page_content=text, metadata=metadata))
return docs
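A standalone sketch of two small pieces of the query construction above, with made-up values: serialising the embedding into the bracketed string that gets interpolated into the SQL, and mapping the configured distance function onto a TiDB function name:

    # Serialise the embedding the way search_by_vector does.
    query_vector = [0.1, 0.2, 0.3]
    query_vector_str = ", ".join(format(x) for x in query_vector)
    query_vector_str = "[" + query_vector_str + "]"
    print(query_vector_str)  # [0.1, 0.2, 0.3]

    # Pick the TiDB distance function, defaulting to cosine for unknown values.
    distance_func = "l2"
    if distance_func == "l2":
        tidb_func = "Vec_l2_distance"
    elif distance_func == "cosine":
        tidb_func = "Vec_Cosine_distance"
    else:
        tidb_func = "Vec_Cosine_distance"
    print(tidb_func)  # Vec_l2_distance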
@@ -224,15 +233,13 @@ class TiDBVector(BaseVector):
class TiDBVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))
return TiDBVector(
collection_name=collection_name,


@@ -7,7 +7,6 @@ from core.rag.models.document import Document
class BaseVector(ABC):
def __init__(self, collection_name: str):
self._collection_name = collection_name
@@ -39,18 +38,11 @@ class BaseVector(ABC):
raise NotImplementedError
@abstractmethod
def search_by_vector(
self,
query_vector: list[float],
**kwargs: Any
) -> list[Document]:
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
raise NotImplementedError
@abstractmethod
def search_by_full_text(
self, query: str,
**kwargs: Any
) -> list[Document]:
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
raise NotImplementedError
def delete(self) -> None:
@@ -58,7 +50,7 @@ class BaseVector(ABC):
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts[:]:
doc_id = text.metadata['doc_id']
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
@@ -66,7 +58,7 @@ class BaseVector(ABC):
return texts
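A minimal stand-in sketch of the duplicate filter above, with text_exists replaced by a hypothetical in-memory set so it runs on its own:

    # Documents whose doc_id already exists in the store are dropped before indexing.
    existing_doc_ids = {"doc-1"}  # stand-in for what text_exists() would report
    texts = [{"doc_id": "doc-1"}, {"doc_id": "doc-2"}]

    for text in texts[:]:  # iterate over a copy so in-place removal is safe
        if text["doc_id"] in existing_doc_ids:
            texts.remove(text)
    print(texts)  # [{'doc_id': 'doc-2'}]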
def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata['doc_id'] for text in texts]
return [text.metadata["doc_id"] for text in texts]
@property
def collection_name(self):


@@ -20,17 +20,14 @@ class AbstractVectorFactory(ABC):
@staticmethod
def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> dict:
index_struct_dict = {
"type": vector_type,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}}
return index_struct_dict
class Vector:
def __init__(self, dataset: Dataset, attributes: list = None):
if attributes is None:
attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash', 'page']
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "page"]
self._dataset = dataset
self._embeddings = self._get_embeddings()
self._attributes = attributes
@@ -39,7 +36,7 @@ class Vector:
def _init_vector(self) -> BaseVector:
vector_type = dify_config.VECTOR_STORE
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict['type']
vector_type = self._dataset.index_struct_dict["type"]
if not vector_type:
raise ValueError("Vector store must be specified.")
@@ -52,45 +49,59 @@ class Vector:
match vector_type:
case VectorType.CHROMA:
from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory
return ChromaVectorFactory
case VectorType.MILVUS:
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
return MilvusVectorFactory
case VectorType.MYSCALE:
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory
return MyScaleVectorFactory
case VectorType.PGVECTOR:
from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory
return PGVectorFactory
case VectorType.PGVECTO_RS:
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory
return PGVectoRSFactory
case VectorType.QDRANT:
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory
return QdrantVectorFactory
case VectorType.RELYT:
from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory
return RelytVectorFactory
case VectorType.ELASTICSEARCH:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
return ElasticSearchVectorFactory
case VectorType.TIDB_VECTOR:
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
return TiDBVectorFactory
case VectorType.WEAVIATE:
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory
return WeaviateVectorFactory
case VectorType.TENCENT:
from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory
return TencentVectorFactory
case VectorType.ORACLE:
from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory
return OracleVectorFactory
case VectorType.OPENSEARCH:
from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory
return OpenSearchVectorFactory
case VectorType.ANALYTICDB:
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
return AnalyticdbVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")
@@ -98,22 +109,14 @@ class Vector:
def create(self, texts: list = None, **kwargs):
if texts:
embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
self._vector_processor.create(
texts=texts,
embeddings=embeddings,
**kwargs
)
self._vector_processor.create(texts=texts, embeddings=embeddings, **kwargs)
def add_texts(self, documents: list[Document], **kwargs):
if kwargs.get('duplicate_check', False):
if kwargs.get("duplicate_check", False):
documents = self._filter_duplicate_texts(documents)
embeddings = self._embeddings.embed_documents([document.page_content for document in documents])
self._vector_processor.create(
texts=documents,
embeddings=embeddings,
**kwargs
)
self._vector_processor.create(texts=documents, embeddings=embeddings, **kwargs)
def text_exists(self, id: str) -> bool:
return self._vector_processor.text_exists(id)
@@ -124,24 +127,18 @@ class Vector:
def delete_by_metadata_field(self, key: str, value: str) -> None:
self._vector_processor.delete_by_metadata_field(key, value)
def search_by_vector(
self, query: str,
**kwargs: Any
) -> list[Document]:
def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]:
query_vector = self._embeddings.embed_query(query)
return self._vector_processor.search_by_vector(query_vector, **kwargs)
def search_by_full_text(
self, query: str,
**kwargs: Any
) -> list[Document]:
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._vector_processor.search_by_full_text(query, **kwargs)
def delete(self) -> None:
self._vector_processor.delete()
# delete collection redis cache
if self._vector_processor.collection_name:
collection_exist_cache_key = 'vector_indexing_{}'.format(self._vector_processor.collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._vector_processor.collection_name)
redis_client.delete(collection_exist_cache_key)
def _get_embeddings(self) -> Embeddings:
@@ -151,14 +148,13 @@ class Vector:
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=self._dataset.embedding_model
model=self._dataset.embedding_model,
)
return CacheEmbedding(embedding_model)
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts[:]:
doc_id = text.metadata['doc_id']
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)


@@ -2,17 +2,17 @@ from enum import Enum
class VectorType(str, Enum):
ANALYTICDB = 'analyticdb'
CHROMA = 'chroma'
MILVUS = 'milvus'
MYSCALE = 'myscale'
PGVECTOR = 'pgvector'
PGVECTO_RS = 'pgvecto-rs'
QDRANT = 'qdrant'
RELYT = 'relyt'
TIDB_VECTOR = 'tidb_vector'
WEAVIATE = 'weaviate'
OPENSEARCH = 'opensearch'
TENCENT = 'tencent'
ORACLE = 'oracle'
ELASTICSEARCH = 'elasticsearch'
ANALYTICDB = "analyticdb"
CHROMA = "chroma"
MILVUS = "milvus"
MYSCALE = "myscale"
PGVECTOR = "pgvector"
PGVECTO_RS = "pgvecto-rs"
QDRANT = "qdrant"
RELYT = "relyt"
TIDB_VECTOR = "tidb_vector"
WEAVIATE = "weaviate"
OPENSEARCH = "opensearch"
TENCENT = "tencent"
ORACLE = "oracle"
ELASTICSEARCH = "elasticsearch"


@@ -22,15 +22,14 @@ class WeaviateConfig(BaseModel):
api_key: Optional[str] = None
batch_size: int = 100
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
if not values["endpoint"]:
raise ValueError("config WEAVIATE_ENDPOINT is required")
return values
class WeaviateVector(BaseVector):
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
super().__init__(collection_name)
self._client = self._init_client(config)
@@ -43,10 +42,7 @@ class WeaviateVector(BaseVector):
try:
client = weaviate.Client(
url=config.endpoint,
auth_client_secret=auth_config,
timeout_config=(5, 60),
startup_period=None
url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None
)
except requests.exceptions.ConnectionError:
raise ConnectionError("Vector database connection error")
@@ -68,10 +64,10 @@ class WeaviateVector(BaseVector):
def get_collection_name(self, dataset: Dataset) -> str:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
if not class_prefix.endswith("_Node"):
# original class_prefix
class_prefix += '_Node'
class_prefix += "_Node"
return class_prefix
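A standalone sketch of the suffix normalisation above, with a hypothetical legacy prefix:

    class_prefix = "Vector_index_2d3e4f"  # hypothetical value from index_struct_dict
    if not class_prefix.endswith("_Node"):
        class_prefix += "_Node"  # legacy prefixes gain the _Node suffix
    print(class_prefix)  # Vector_index_2d3e4f_Node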
@@ -79,10 +75,7 @@ class WeaviateVector(BaseVector):
return Dataset.gen_collection_name_by_id(dataset_id)
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name}
}
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
# create collection
@@ -91,9 +84,9 @@ class WeaviateVector(BaseVector):
self.add_texts(texts, embeddings)
def _create_collection(self):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
schema = self._default_schema(self._collection_name)
@@ -129,17 +122,9 @@ class WeaviateVector(BaseVector):
# check whether the index already exists
schema = self._default_schema(self._collection_name)
if self._client.schema.contains(schema):
where_filter = {
"operator": "Equal",
"path": [key],
"valueText": value
}
where_filter = {"operator": "Equal", "path": [key], "valueText": value}
self._client.batch.delete_objects(
class_name=self._collection_name,
where=where_filter,
output='minimal'
)
self._client.batch.delete_objects(class_name=self._collection_name, where=where_filter, output="minimal")
def delete(self):
# check whether the index already exists
@@ -154,11 +139,19 @@ class WeaviateVector(BaseVector):
# check whether the index already exists
if not self._client.schema.contains(schema):
return False
result = self._client.query.get(collection_name).with_additional(["id"]).with_where({
"path": ["doc_id"],
"operator": "Equal",
"valueText": id,
}).with_limit(1).do()
result = (
self._client.query.get(collection_name)
.with_additional(["id"])
.with_where(
{
"path": ["doc_id"],
"operator": "Equal",
"valueText": id,
}
)
.with_limit(1)
.do()
)
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
@@ -211,13 +204,13 @@ class WeaviateVector(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
# check score threshold
if score > score_threshold:
doc.metadata['score'] = score
doc.metadata["score"] = score
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -240,15 +233,15 @@ class WeaviateVector(BaseVector):
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
query_obj = query_obj.with_additional(["vector"])
properties = ['text']
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do()
properties = ["text"]
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 2)).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
additional = res.pop('_additional')
docs.append(Document(page_content=text, vector=additional['vector'], metadata=res))
additional = res.pop("_additional")
docs.append(Document(page_content=text, vector=additional["vector"], metadata=res))
return docs
def _default_schema(self, index_name: str) -> dict:
@@ -271,20 +264,19 @@ class WeaviateVector(BaseVector):
class WeaviateVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
return WeaviateVector(
collection_name=collection_name,
config=WeaviateConfig(
endpoint=dify_config.WEAVIATE_ENDPOINT,
api_key=dify_config.WEAVIATE_API_KEY,
batch_size=dify_config.WEAVIATE_BATCH_SIZE
batch_size=dify_config.WEAVIATE_BATCH_SIZE,
),
attributes=attributes
attributes=attributes,
)