diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
index 1ba37dfb3..b2fbfced0 100644
--- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
+++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
@@ -1,10 +1,12 @@
 import copy
 import json
 import logging
+import time
 from typing import Any, Optional
 
 from opensearchpy import OpenSearch
 from pydantic import BaseModel, model_validator
+from tenacity import retry, stop_after_attempt, wait_exponential
 
 from configs import dify_config
 from core.rag.datasource.vdb.field import Field
@@ -77,31 +79,74 @@ class LindormVectorStore(BaseVector):
     def refresh(self):
         self._client.indices.refresh(index=self._collection_name)
 
-    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
-        actions = []
+    def add_texts(
+        self,
+        documents: list[Document],
+        embeddings: list[list[float]],
+        batch_size: int = 64,
+        timeout: int = 60,
+        **kwargs,
+    ):
+        logger.info(f"Total documents to add: {len(documents)}")
         uuids = self._get_uuids(documents)
-        for i in range(len(documents)):
-            action_header = {
-                "index": {
-                    "_index": self.collection_name.lower(),
-                    "_id": uuids[i],
+
+        total_docs = len(documents)
+        num_batches = (total_docs + batch_size - 1) // batch_size
+
+        @retry(
+            stop=stop_after_attempt(3),
+            wait=wait_exponential(multiplier=1, min=4, max=10),
+        )
+        def _bulk_with_retry(actions):
+            try:
+                response = self._client.bulk(actions, timeout=timeout)
+                if response["errors"]:
+                    error_items = [item for item in response["items"] if "error" in item["index"]]
+                    error_msg = f"Bulk indexing had {len(error_items)} errors"
+                    logger.exception(error_msg)
+                    raise Exception(error_msg)
+                return response
+            except Exception:
+                logger.exception("Bulk indexing error")
+                raise
+
+        for batch_num in range(num_batches):
+            start_idx = batch_num * batch_size
+            end_idx = min((batch_num + 1) * batch_size, total_docs)
+
+            actions = []
+            for i in range(start_idx, end_idx):
+                action_header = {
+                    "index": {
+                        "_index": self.collection_name.lower(),
+                        "_id": uuids[i],
+                    }
                 }
-            }
-            action_values: dict[str, Any] = {
-                Field.CONTENT_KEY.value: documents[i].page_content,
-                Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
-                Field.METADATA_KEY.value: documents[i].metadata,
-            }
-            if self._using_ugc:
-                action_header["index"]["routing"] = self._routing
-                if self._routing_field is not None:
-                    action_values[self._routing_field] = self._routing
-            actions.append(action_header)
-            actions.append(action_values)
-        response = self._client.bulk(actions)
-        if response["errors"]:
-            for item in response["items"]:
-                print(f"{item['index']['status']}: {item['index']['error']['type']}")
+                action_values: dict[str, Any] = {
+                    Field.CONTENT_KEY.value: documents[i].page_content,
+                    Field.VECTOR.value: embeddings[i],
+                    Field.METADATA_KEY.value: documents[i].metadata,
+                }
+                if self._using_ugc:
+                    action_header["index"]["routing"] = self._routing
+                    if self._routing_field is not None:
+                        action_values[self._routing_field] = self._routing
+
+                actions.append(action_header)
+                actions.append(action_values)
+
+            logger.info(f"Processing batch {batch_num + 1}/{num_batches} (documents {start_idx + 1} to {end_idx})")
+
+            try:
+                _bulk_with_retry(actions)
+                logger.info(f"Successfully processed batch {batch_num + 1}")
+                # simple latency to avoid too many requests in a short time
+                if batch_num < num_batches - 1:
+                    time.sleep(1)
+
+            except Exception:
+                logger.exception(f"Failed to process batch {batch_num + 1}")
+                raise
 
     def get_ids_by_metadata_field(self, key: str, value: str):
         query: dict[str, Any] = {
@@ -130,7 +175,6 @@ class LindormVectorStore(BaseVector):
             if self._using_ugc:
                 params["routing"] = self._routing
             self._client.delete(index=self._collection_name, id=id, params=params)
-            self.refresh()
         else:
             logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
 
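For context, the retry policy this patch applies to each bulk request (three attempts, exponential backoff clamped between 4 and 10 seconds) can be exercised in isolation with the minimal sketch below, assuming tenacity is installed; flaky_bulk is a hypothetical stand-in for the _bulk_with_retry call and is not part of the patch.

import logging

from tenacity import retry, stop_after_attempt, wait_exponential

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

attempts = {"count": 0}


# Same policy as the patch: give up after 3 attempts, exponential backoff kept within [4, 10] seconds.
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
def flaky_bulk() -> str:
    """Hypothetical stand-in for a bulk indexing call that fails twice, then succeeds."""
    attempts["count"] += 1
    logger.info("bulk attempt %d", attempts["count"])
    if attempts["count"] < 3:
        raise ConnectionError("simulated transient bulk failure")
    return "ok"


if __name__ == "__main__":
    print(flaky_bulk())  # retries twice with backoff waits, then prints "ok"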