feat: Add OceanBase hybrid search features (#16652)

Co-authored-by: 李远军 <4842@9ji.com>
Co-authored-by: yourchanges <yourchanges@gmail.com>
This commit is contained in:
hsiong
2025-03-25 14:32:00 +08:00
committed by GitHub
parent c4bb07184d
commit 6157f57872
7 changed files with 121 additions and 22 deletions

View File

@@ -31,6 +31,7 @@ class OceanBaseVectorConfig(BaseModel):
user: str
password: str
database: str
enable_hybrid_search: bool = False
@model_validator(mode="before")
@classmethod
@@ -57,6 +58,7 @@ class OceanBaseVector(BaseVector):
password=self._config.password,
db_name=self._config.database,
)
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
def get_type(self) -> str:
return VectorType.OCEANBASE
@@ -98,6 +100,16 @@ class OceanBaseVector(BaseVector):
columns=cols,
vidxs=vidx_params,
)
try:
if self._hybrid_search_enabled:
self._client.perform_raw_text_sql(f"""ALTER TABLE {self._collection_name}
ADD FULLTEXT INDEX fulltext_index_for_col_text (text) WITH PARSER ik""")
except Exception as e:
raise Exception(
"Failed to add fulltext index to the target table, your OceanBase version must be 4.3.5.1 or above "
+ "to support fulltext index and vector index in the same table",
e,
)
vals = []
params = self._client.perform_raw_text_sql("SHOW PARAMETERS LIKE '%ob_vector_memory_limit_percentage%'")
for row in params:
@@ -116,6 +128,27 @@ class OceanBaseVector(BaseVector):
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _check_hybrid_search_support(self) -> bool:
"""
Check if the current OceanBase version supports hybrid search.
Returns True if the version is >= 4.3.5.1, otherwise False.
"""
if not self._config.enable_hybrid_search:
return False
try:
from packaging import version
# return OceanBase_CE 4.3.5.1 (r101000042025031818-bxxxx) (Built Mar 18 2025 18:13:36)
result = self._client.perform_raw_text_sql("SELECT @@version_comment AS version")
ob_full_version = result.fetchone()[0]
ob_version = ob_full_version.split()[1]
logger.debug("Current OceanBase version is %s", ob_version)
return version.parse(ob_version).base_version >= version.parse("4.3.5.1").base_version
except Exception as e:
logger.warning(f"Failed to check OceanBase version: {str(e)}. Disabling hybrid search.")
return False
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
ids = self._get_uuids(documents)
for id, doc, emb in zip(ids, documents, embeddings):
@@ -130,7 +163,7 @@ class OceanBaseVector(BaseVector):
)
def text_exists(self, id: str) -> bool:
cur = self._client.get(table_name=self._collection_name, id=id)
cur = self._client.get(table_name=self._collection_name, ids=id)
return bool(cur.rowcount != 0)
def delete_by_ids(self, ids: list[str]) -> None:
@@ -139,9 +172,12 @@ class OceanBaseVector(BaseVector):
self._client.delete(table_name=self._collection_name, ids=ids)
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
from sqlalchemy import text
cur = self._client.get(
table_name=self._collection_name,
where_clause=f"metadata->>'$.{key}' = '{value}'",
ids=None,
where_clause=[text(f"metadata->>'$.{key}' = '{value}'")],
output_column_name=["id"],
)
return [row[0] for row in cur]
@@ -151,36 +187,84 @@ class OceanBaseVector(BaseVector):
self.delete_by_ids(ids)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return []
if not self._hybrid_search_enabled:
return []
try:
top_k = kwargs.get("top_k", 5)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" AND metadata->>'$.document_id' IN ({document_ids})"
full_sql = f"""SELECT metadata, text, MATCH (text) AGAINST (:query) AS score
FROM {self._collection_name}
WHERE MATCH (text) AGAINST (:query) > 0
{where_clause}
ORDER BY score DESC
LIMIT {top_k}"""
with self._client.engine.connect() as conn:
with conn.begin():
from sqlalchemy import text
result = conn.execute(text(full_sql), {"query": query})
rows = result.fetchall()
docs = []
for row in rows:
metadata_str, _text, score = row
try:
metadata = json.loads(metadata_str)
except json.JSONDecodeError:
print(f"Invalid JSON metadata: {metadata_str}")
metadata = {}
metadata["score"] = score
docs.append(Document(page_content=_text, metadata=metadata))
return docs
except Exception as e:
logger.warning(f"Failed to fulltext search: {str(e)}.")
return []
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = None
_where_clause = None
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f"metadata->>'$.document_id' in ({document_ids})"
from sqlalchemy import text
_where_clause = [text(where_clause)]
ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
if ef_search != self._hnsw_ef_search:
self._client.set_ob_hnsw_ef_search(ef_search)
self._hnsw_ef_search = ef_search
topk = kwargs.get("top_k", 10)
cur = self._client.ann_search(
table_name=self._collection_name,
vec_column_name="vector",
vec_data=query_vector,
topk=topk,
distance_func=func.l2_distance,
output_column_names=["text", "metadata"],
with_dist=True,
where_clause=where_clause,
)
try:
cur = self._client.ann_search(
table_name=self._collection_name,
vec_column_name="vector",
vec_data=query_vector,
topk=topk,
distance_func=func.l2_distance,
output_column_names=["text", "metadata"],
with_dist=True,
where_clause=_where_clause,
)
except Exception as e:
raise Exception("Failed to search by vector. ", e)
docs = []
for text, metadata, distance in cur:
for _text, metadata, distance in cur:
metadata = json.loads(metadata)
metadata["score"] = 1 - distance / math.sqrt(2)
docs.append(
Document(
page_content=text,
page_content=_text,
metadata=metadata,
)
)