mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 03:16:51 +08:00
feat(vdb): add HNSW vector index for TiDB vector store with TiFlash (#12043)
This commit is contained in:
@@ -9,6 +9,7 @@ from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.orm import Session, declarative_base
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
@@ -54,14 +55,13 @@ class TiDBVector(BaseVector):
|
||||
return Table(
|
||||
self._collection_name,
|
||||
self._orm_base.metadata,
|
||||
Column("id", String(36), primary_key=True, nullable=False),
|
||||
Column(Field.PRIMARY_KEY.value, String(36), primary_key=True, nullable=False),
|
||||
Column(
|
||||
"vector",
|
||||
Field.VECTOR.value,
|
||||
VectorType(dim),
|
||||
nullable=False,
|
||||
comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})",
|
||||
),
|
||||
Column("text", TEXT, nullable=False),
|
||||
Column(Field.TEXT_KEY.value, TEXT, nullable=False),
|
||||
Column("meta", JSON, nullable=False),
|
||||
Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
|
||||
Column(
|
||||
@@ -96,6 +96,7 @@ class TiDBVector(BaseVector):
|
||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
tidb_dist_func = self._get_distance_func()
|
||||
with Session(self._engine) as session:
|
||||
session.begin()
|
||||
create_statement = sql_text(f"""
|
||||
@@ -104,14 +105,14 @@ class TiDBVector(BaseVector):
|
||||
text TEXT NOT NULL,
|
||||
meta JSON NOT NULL,
|
||||
doc_id VARCHAR(64) AS (JSON_UNQUOTE(JSON_EXTRACT(meta, '$.doc_id'))) STORED,
|
||||
KEY (doc_id),
|
||||
vector VECTOR<FLOAT>({dimension}) NOT NULL COMMENT "hnsw(distance={self._distance_func})",
|
||||
vector VECTOR<FLOAT>({dimension}) NOT NULL,
|
||||
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
KEY (doc_id),
|
||||
VECTOR INDEX idx_vector (({tidb_dist_func}(vector))) USING HNSW
|
||||
);
|
||||
""")
|
||||
session.execute(create_statement)
|
||||
# tidb vector not support 'CREATE/ADD INDEX' now
|
||||
session.commit()
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
@@ -194,23 +195,30 @@ class TiDBVector(BaseVector):
|
||||
)
|
||||
|
||||
docs = []
|
||||
if self._distance_func == "l2":
|
||||
tidb_func = "Vec_l2_distance"
|
||||
elif self._distance_func == "cosine":
|
||||
tidb_func = "Vec_Cosine_distance"
|
||||
else:
|
||||
tidb_func = "Vec_Cosine_distance"
|
||||
tidb_dist_func = self._get_distance_func()
|
||||
|
||||
with Session(self._engine) as session:
|
||||
select_statement = sql_text(
|
||||
f"""SELECT meta, text, distance FROM (
|
||||
SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
|
||||
FROM {self._collection_name}
|
||||
ORDER BY distance
|
||||
LIMIT {top_k}
|
||||
) t WHERE distance < {distance};"""
|
||||
select_statement = sql_text(f"""
|
||||
SELECT meta, text, distance
|
||||
FROM (
|
||||
SELECT
|
||||
meta,
|
||||
text,
|
||||
{tidb_dist_func}(vector, :query_vector_str) AS distance
|
||||
FROM {self._collection_name}
|
||||
ORDER BY distance ASC
|
||||
LIMIT :top_k
|
||||
) t
|
||||
WHERE distance <= :distance
|
||||
""")
|
||||
res = session.execute(
|
||||
select_statement,
|
||||
params={
|
||||
"query_vector_str": query_vector_str,
|
||||
"distance": distance,
|
||||
"top_k": top_k,
|
||||
},
|
||||
)
|
||||
res = session.execute(select_statement)
|
||||
results = [(row[0], row[1], row[2]) for row in res]
|
||||
for meta, text, distance in results:
|
||||
metadata = json.loads(meta)
|
||||
@@ -227,6 +235,16 @@ class TiDBVector(BaseVector):
|
||||
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
|
||||
session.commit()
|
||||
|
||||
def _get_distance_func(self) -> str:
|
||||
match self._distance_func:
|
||||
case "l2":
|
||||
tidb_dist_func = "VEC_L2_DISTANCE"
|
||||
case "cosine":
|
||||
tidb_dist_func = "VEC_COSINE_DISTANCE"
|
||||
case _:
|
||||
tidb_dist_func = "VEC_COSINE_DISTANCE"
|
||||
return tidb_dist_func
|
||||
|
||||
|
||||
class TiDBVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector:
|
||||
|
||||
Reference in New Issue
Block a user