feat(vdb): add HNSW vector index for TiDB vector store with TiFlash (#12043)

This commit is contained in:
Bowen Liang
2025-02-12 13:53:51 +08:00
committed by GitHub
parent 786550bdc9
commit 0751ad1eeb
11 changed files with 211 additions and 45 deletions

View File

@@ -9,6 +9,7 @@ from sqlalchemy import text as sql_text
from sqlalchemy.orm import Session, declarative_base
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -54,14 +55,13 @@ class TiDBVector(BaseVector):
return Table(
self._collection_name,
self._orm_base.metadata,
Column("id", String(36), primary_key=True, nullable=False),
Column(Field.PRIMARY_KEY.value, String(36), primary_key=True, nullable=False),
Column(
"vector",
Field.VECTOR.value,
VectorType(dim),
nullable=False,
comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})",
),
Column("text", TEXT, nullable=False),
Column(Field.TEXT_KEY.value, TEXT, nullable=False),
Column("meta", JSON, nullable=False),
Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
Column(
@@ -96,6 +96,7 @@ class TiDBVector(BaseVector):
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
tidb_dist_func = self._get_distance_func()
with Session(self._engine) as session:
session.begin()
create_statement = sql_text(f"""
@@ -104,14 +105,14 @@ class TiDBVector(BaseVector):
text TEXT NOT NULL,
meta JSON NOT NULL,
doc_id VARCHAR(64) AS (JSON_UNQUOTE(JSON_EXTRACT(meta, '$.doc_id'))) STORED,
KEY (doc_id),
vector VECTOR<FLOAT>({dimension}) NOT NULL COMMENT "hnsw(distance={self._distance_func})",
vector VECTOR<FLOAT>({dimension}) NOT NULL,
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
KEY (doc_id),
VECTOR INDEX idx_vector (({tidb_dist_func}(vector))) USING HNSW
);
""")
session.execute(create_statement)
# tidb vector not support 'CREATE/ADD INDEX' now
session.commit()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
@@ -194,23 +195,30 @@ class TiDBVector(BaseVector):
)
docs = []
if self._distance_func == "l2":
tidb_func = "Vec_l2_distance"
elif self._distance_func == "cosine":
tidb_func = "Vec_Cosine_distance"
else:
tidb_func = "Vec_Cosine_distance"
tidb_dist_func = self._get_distance_func()
with Session(self._engine) as session:
select_statement = sql_text(
f"""SELECT meta, text, distance FROM (
SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
FROM {self._collection_name}
ORDER BY distance
LIMIT {top_k}
) t WHERE distance < {distance};"""
select_statement = sql_text(f"""
SELECT meta, text, distance
FROM (
SELECT
meta,
text,
{tidb_dist_func}(vector, :query_vector_str) AS distance
FROM {self._collection_name}
ORDER BY distance ASC
LIMIT :top_k
) t
WHERE distance <= :distance
""")
res = session.execute(
select_statement,
params={
"query_vector_str": query_vector_str,
"distance": distance,
"top_k": top_k,
},
)
res = session.execute(select_statement)
results = [(row[0], row[1], row[2]) for row in res]
for meta, text, distance in results:
metadata = json.loads(meta)
@@ -227,6 +235,16 @@ class TiDBVector(BaseVector):
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
session.commit()
def _get_distance_func(self) -> str:
match self._distance_func:
case "l2":
tidb_dist_func = "VEC_L2_DISTANCE"
case "cosine":
tidb_dist_func = "VEC_COSINE_DISTANCE"
case _:
tidb_dist_func = "VEC_COSINE_DISTANCE"
return tidb_dist_func
class TiDBVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector:

View File

@@ -0,0 +1,59 @@
import time
import pymysql
def check_tiflash_ready() -> bool:
try:
connection = pymysql.connect(
host="localhost",
port=4000,
user="root",
password="",
)
with connection.cursor() as cursor:
# Doc reference:
# https://docs.pingcap.com/zh/tidb/stable/information-schema-cluster-hardware
select_tiflash_query = """
SELECT * FROM information_schema.cluster_hardware
WHERE TYPE='tiflash'
LIMIT 1;
"""
cursor.execute(select_tiflash_query)
result = cursor.fetchall()
return result is not None and len(result) > 0
except Exception as e:
print(f"TiFlash is not ready. Exception: {e}")
return False
finally:
if connection:
connection.close()
def main():
max_attempts = 30
retry_interval_seconds = 2
is_tiflash_ready = False
for attempt in range(max_attempts):
try:
is_tiflash_ready = check_tiflash_ready()
except Exception as e:
print(f"TiFlash is not ready. Exception: {e}")
is_tiflash_ready = False
if is_tiflash_ready:
break
else:
print(f"Attempt {attempt + 1} failedretry in {retry_interval_seconds} seconds...")
time.sleep(retry_interval_seconds)
if is_tiflash_ready:
print("TiFlash is ready in TiDB.")
else:
print(f"TiFlash is not ready in TiDB after {max_attempts} attempting checks.")
exit(1)
if __name__ == "__main__":
main()