From 6c4e6bf1d6ca92716ed05b47c5b46035047c3873 Mon Sep 17 00:00:00 2001
From: Jyong <76649700+JohnJyong@users.noreply.github.com>
Date: Thu, 22 Feb 2024 23:31:57 +0800
Subject: [PATCH] Feat/dify rag (#2528)

Co-authored-by: jyong
---
 api/celerybeat-schedule.db | Bin 0 -> 16384 bytes
 api/config.py | 3 +-
 .../console/datasets/data_source.py | 32 +-
 api/controllers/console/datasets/datasets.py | 74 +-
 .../console/datasets/datasets_document.py | 74 +-
 api/core/data_loader/file_extractor.py | 107 -
 api/core/data_loader/loader/html.py | 34 -
 api/core/data_loader/loader/pdf.py | 55 -
 api/core/docstore/dataset_docstore.py | 2 +-
 api/core/features/annotation_reply.py | 38 +-
 api/core/index/index.py | 51 -
 api/core/index/vector_index/base.py | 305 ---
 .../index/vector_index/milvus_vector_index.py | 165 --
 .../index/vector_index/qdrant_vector_index.py | 229 ---
 api/core/index/vector_index/vector_index.py | 90 -
 .../vector_index/weaviate_vector_index.py | 179 --
 api/core/indexing_runner.py | 380 ++--
 api/core/rag/__init__.py | 0
 api/core/rag/cleaner/clean_processor.py | 38 +
 api/core/rag/cleaner/cleaner_base.py | 12 +
 .../unstructured_extra_whitespace_cleaner.py | 12 +
 ...uctured_group_broken_paragraphs_cleaner.py | 15 +
 .../unstructured_non_ascii_chars_cleaner.py | 12 +
 ...ructured_replace_unicode_quotes_cleaner.py | 11 +
 .../unstructured_translate_text_cleaner.py | 11 +
 api/core/rag/data_post_processor/__init__.py | 0
 .../data_post_processor.py | 49 +
 api/core/rag/data_post_processor/reorder.py | 19 +
 api/core/rag/datasource/__init__.py | 0
 api/core/rag/datasource/entity/embedding.py | 21 +
 api/core/rag/datasource/keyword/__init__.py | 0
 .../rag/datasource/keyword/jieba/__init__.py | 0
 .../datasource/keyword/jieba/jieba.py} | 108 +-
 .../jieba}/jieba_keyword_table_handler.py | 2 +-
 .../datasource/keyword/jieba}/stopwords.py | 0
 .../datasource/keyword/keyword_base.py} | 28 +-
 .../rag/datasource/keyword/keyword_factory.py | 60 +
 api/core/rag/datasource/retrieval_service.py | 165 ++
 api/core/rag/datasource/vdb/__init__.py | 0
 api/core/rag/datasource/vdb/field.py | 10 +
 .../rag/datasource/vdb/milvus/__init__.py | 0
 .../datasource/vdb/milvus/milvus_vector.py | 214 ++
 .../rag/datasource/vdb/qdrant/__init__.py | 0
 .../datasource/vdb/qdrant/qdrant_vector.py | 360 ++++
 api/core/rag/datasource/vdb/vector_base.py | 62 +
 api/core/rag/datasource/vdb/vector_factory.py | 171 ++
 .../rag/datasource/vdb/weaviate/__init__.py | 0
 .../vdb/weaviate/weaviate_vector.py | 235 +++
 api/core/rag/extractor/blod/blod.py | 166 ++
 api/core/rag/extractor/csv_extractor.py | 71 +
 .../rag/extractor/entity/datasource_type.py | 6 +
 .../rag/extractor/entity/extract_setting.py | 36 +
 .../extractor/excel_extractor.py} | 23 +-
 api/core/rag/extractor/extract_processor.py | 139 ++
 api/core/rag/extractor/extractor_base.py | 12 +
 api/core/rag/extractor/helpers.py | 46 +
 .../extractor/html_extractor.py} | 46 +-
 .../extractor/markdown_extractor.py} | 40 +-
 .../extractor/notion_extractor.py} | 56 +-
 api/core/rag/extractor/pdf_extractor.py | 72 +
 api/core/rag/extractor/text_extractor.py | 50 +
 .../unstructured_doc_extractor.py | 61 +
 .../unstructured_eml_extractor.py} | 9 +-
 .../unstructured_markdown_extractor.py} | 8 +-
 .../unstructured_msg_extractor.py} | 8 +-
 .../unstructured_ppt_extractor.py} | 15 +-
 .../unstructured_pptx_extractor.py} | 16 +-
 .../unstructured_text_extractor.py} | 8 +-
 .../unstructured_xml_extractor.py} | 8 +-
 api/core/rag/extractor/word_extractor.py | 62 +
 api/core/rag/index_processor/__init__.py | 0
 .../rag/index_processor/constant/__init__.py | 0
 .../index_processor/constant/index_type.py | 8 +
 .../index_processor/index_processor_base.py | 70 +
 .../index_processor_factory.py | 28 +
 .../rag/index_processor/processor/__init__.py | 0
 .../processor/paragraph_index_processor.py | 92 +
 .../processor/qa_index_processor.py | 161 ++
 api/core/rag/models/__init__.py | 0
 api/core/rag/models/document.py | 16 +
 api/core/tool/web_reader_tool.py | 10 +-
 .../dataset_multi_retriever_tool.py | 87 +-
 .../dataset_retriever_tool.py | 112 +-
 api/core/tools/utils/web_reader_tool.py | 10 +-
 api/core/vector_store/milvus_vector_store.py | 56 -
 api/core/vector_store/qdrant_vector_store.py | 76 -
 api/core/vector_store/vector/milvus.py | 852 --------
 api/core/vector_store/vector/qdrant.py | 1759 -----------------
 api/core/vector_store/vector/weaviate.py | 506 -----
 .../vector_store/weaviate_vector_store.py | 38 -
 .../clean_when_dataset_deleted.py | 2 +-
 .../clean_when_document_deleted.py | 3 +-
 api/models/dataset.py | 8 +
 api/schedule/clean_unused_datasets_task.py | 17 +-
 api/services/dataset_service.py | 35 +-
 api/services/file_service.py | 9 +-
 api/services/hit_testing_service.py | 75 +-
 api/services/retrieval_service.py | 119 --
 api/services/vector_service.py | 85 +-
 api/tasks/add_document_to_index_task.py | 16 +-
 .../add_annotation_to_index_task.py | 8 +-
 .../batch_import_annotations_task.py | 7 +-
 .../delete_annotation_index_task.py | 13 +-
 .../disable_annotation_reply_task.py | 13 +-
 .../enable_annotation_reply_task.py | 20 +-
 .../update_annotation_to_index_task.py | 9 +-
 api/tasks/clean_dataset_task.py | 25 +-
 api/tasks/clean_document_task.py | 18 +-
 api/tasks/clean_notion_document_task.py | 15 +-
 api/tasks/create_segment_to_index_task.py | 19 +-
 api/tasks/deal_dataset_vector_index_task.py | 14 +-
 api/tasks/delete_segment_from_index_task.py | 14 +-
 api/tasks/disable_segment_from_index_task.py | 14 +-
 api/tasks/document_indexing_sync_task.py | 22 +-
 api/tasks/document_indexing_update_task.py | 13 +-
 api/tasks/enable_segment_to_index_task.py | 14 +-
 api/tasks/remove_document_from_index_task.py | 12 +-
 api/tasks/update_segment_index_task.py | 114 --
 .../update_segment_keyword_index_task.py | 68 -
 119 files changed, 3181 insertions(+), 5892 deletions(-)
 create mode 100644 api/celerybeat-schedule.db
 delete mode 100644 api/core/data_loader/loader/html.py
 delete mode 100644 api/core/data_loader/loader/pdf.py
 delete mode 100644 api/core/index/index.py
 delete mode 100644 api/core/index/vector_index/base.py
 delete mode 100644 api/core/index/vector_index/milvus_vector_index.py
 delete mode 100644 api/core/index/vector_index/qdrant_vector_index.py
 delete mode 100644 api/core/index/vector_index/vector_index.py
 delete mode 100644 api/core/index/vector_index/weaviate_vector_index.py
 create mode 100644 api/core/rag/__init__.py
 create mode 100644 api/core/rag/cleaner/clean_processor.py
 create mode 100644 api/core/rag/cleaner/cleaner_base.py
 create mode 100644 api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py
 create mode 100644 api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py
 create mode 100644 api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py
 create mode 100644 api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py
 create mode 100644 api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py
 create mode 100644 api/core/rag/data_post_processor/__init__.py
 create mode 100644 api/core/rag/data_post_processor/data_post_processor.py
 create mode 100644 api/core/rag/data_post_processor/reorder.py
 create mode 100644 api/core/rag/datasource/__init__.py
 create mode 100644 api/core/rag/datasource/entity/embedding.py
 create mode 100644 api/core/rag/datasource/keyword/__init__.py
 create mode 100644 api/core/rag/datasource/keyword/jieba/__init__.py
 rename api/core/{index/keyword_table_index/keyword_table_index.py => rag/datasource/keyword/jieba/jieba.py} (71%)
 rename api/core/{index/keyword_table_index => rag/datasource/keyword/jieba}/jieba_keyword_table_handler.py (93%)
 rename api/core/{index/keyword_table_index => rag/datasource/keyword/jieba}/stopwords.py (100%)
 rename api/core/{index/base.py => rag/datasource/keyword/keyword_base.py} (63%)
 create mode 100644 api/core/rag/datasource/keyword/keyword_factory.py
 create mode 100644 api/core/rag/datasource/retrieval_service.py
 create mode 100644 api/core/rag/datasource/vdb/__init__.py
 create mode 100644 api/core/rag/datasource/vdb/field.py
 create mode 100644 api/core/rag/datasource/vdb/milvus/__init__.py
 create mode 100644 api/core/rag/datasource/vdb/milvus/milvus_vector.py
 create mode 100644 api/core/rag/datasource/vdb/qdrant/__init__.py
 create mode 100644 api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
 create mode 100644 api/core/rag/datasource/vdb/vector_base.py
 create mode 100644 api/core/rag/datasource/vdb/vector_factory.py
 create mode 100644 api/core/rag/datasource/vdb/weaviate/__init__.py
 create mode 100644 api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
 create mode 100644 api/core/rag/extractor/blod/blod.py
 create mode 100644 api/core/rag/extractor/csv_extractor.py
 create mode 100644 api/core/rag/extractor/entity/datasource_type.py
 create mode 100644 api/core/rag/extractor/entity/extract_setting.py
 rename api/core/{data_loader/loader/excel.py => rag/extractor/excel_extractor.py} (66%)
 create mode 100644 api/core/rag/extractor/extract_processor.py
 create mode 100644 api/core/rag/extractor/extractor_base.py
 create mode 100644 api/core/rag/extractor/helpers.py
 rename api/core/{data_loader/loader/csv_loader.py => rag/extractor/html_extractor.py} (61%)
 rename api/core/{data_loader/loader/markdown.py => rag/extractor/markdown_extractor.py} (79%)
 rename api/core/{data_loader/loader/notion.py => rag/extractor/notion_extractor.py} (89%)
 create mode 100644 api/core/rag/extractor/pdf_extractor.py
 create mode 100644 api/core/rag/extractor/text_extractor.py
 create mode 100644 api/core/rag/extractor/unstructured/unstructured_doc_extractor.py
 rename api/core/{data_loader/loader/unstructured/unstructured_eml.py => rag/extractor/unstructured/unstructured_eml_extractor.py} (87%)
 rename api/core/{data_loader/loader/unstructured/unstructured_markdown.py => rag/extractor/unstructured/unstructured_markdown_extractor.py} (85%)
 rename api/core/{data_loader/loader/unstructured/unstructured_msg.py => rag/extractor/unstructured/unstructured_msg_extractor.py} (80%)
 rename api/core/{data_loader/loader/unstructured/unstructured_ppt.py => rag/extractor/unstructured/unstructured_ppt_extractor.py} (78%)
 rename api/core/{data_loader/loader/unstructured/unstructured_pptx.py => rag/extractor/unstructured/unstructured_pptx_extractor.py} (78%)
 rename api/core/{data_loader/loader/unstructured/unstructured_text.py => rag/extractor/unstructured/unstructured_text_extractor.py} (80%)
 rename api/core/{data_loader/loader/unstructured/unstructured_xml.py => rag/extractor/unstructured/unstructured_xml_extractor.py} (81%)
 create mode 100644
api/core/rag/extractor/word_extractor.py create mode 100644 api/core/rag/index_processor/__init__.py create mode 100644 api/core/rag/index_processor/constant/__init__.py create mode 100644 api/core/rag/index_processor/constant/index_type.py create mode 100644 api/core/rag/index_processor/index_processor_base.py create mode 100644 api/core/rag/index_processor/index_processor_factory.py create mode 100644 api/core/rag/index_processor/processor/__init__.py create mode 100644 api/core/rag/index_processor/processor/paragraph_index_processor.py create mode 100644 api/core/rag/index_processor/processor/qa_index_processor.py create mode 100644 api/core/rag/models/__init__.py create mode 100644 api/core/rag/models/document.py delete mode 100644 api/core/vector_store/milvus_vector_store.py delete mode 100644 api/core/vector_store/qdrant_vector_store.py delete mode 100644 api/core/vector_store/vector/milvus.py delete mode 100644 api/core/vector_store/vector/qdrant.py delete mode 100644 api/core/vector_store/vector/weaviate.py delete mode 100644 api/core/vector_store/weaviate_vector_store.py delete mode 100644 api/services/retrieval_service.py delete mode 100644 api/tasks/update_segment_index_task.py delete mode 100644 api/tasks/update_segment_keyword_index_task.py diff --git a/api/celerybeat-schedule.db b/api/celerybeat-schedule.db new file mode 100644 index 0000000000000000000000000000000000000000..b8c01de27bfe7ea04f1dd868cec4935ef336f2b5 GIT binary patch literal 16384 zcmeI%J#W)M7zglk>f!`QTojc7@c|m78dNo6>rydA=}>~icDmRP*lOxV@fkIhEM?+F zx~ENo1C0EEPX#K4Zk%h}inp%k`MMgNm7_ax8R_xN|`DeS_kV2srmv)?kd zVnTMAG0O~jXZ12L`QnGAa$GZ`oyR9}_Q8yK%je{M;jLbjy6|POAOs))0SG_<0uX=z z1Rwwb2teT62-Mj(_lx`4{p7xP-?*>cL367)Xr7z$Q78l;009U<00Izz00bZafio4D z*(VRmKSDFTrmp!T5;3R!Aq7DcKjgmfL*h~-ds|{FA(gYPz8m9>b+|(oz zlF3h^(Edd@H?RIgm^Z6Ln3vLFqkP=;mUX?Y!&dQh8(}+K?Xmv-+WeBQRvRb$J&FTf zY(P5JdAXSdB=Qjg4Oi6}AWu2CL*wcPbKyxZF2{1Hu(=pg2NW2r$3a74k(-tpwo znZ77k90Cx400bZa0SG_<0uX=z1Rwx`g$R88js^h;KmY;|fB*y_009U<00Izz!2d4r E0}denXaE2J literal 0 HcmV?d00001 diff --git a/api/config.py b/api/config.py index 95eabe204..83336e6c4 100644 --- a/api/config.py +++ b/api/config.py @@ -56,6 +56,7 @@ DEFAULTS = { 'BILLING_ENABLED': 'False', 'CAN_REPLACE_LOGO': 'False', 'ETL_TYPE': 'dify', + 'KEYWORD_STORE': 'jieba', 'BATCH_UPLOAD_LIMIT': 20 } @@ -183,7 +184,7 @@ class Config: # Currently, only support: qdrant, milvus, zilliz, weaviate # ------------------------ self.VECTOR_STORE = get_env('VECTOR_STORE') - + self.KEYWORD_STORE = get_env('KEYWORD_STORE') # qdrant settings self.QDRANT_URL = get_env('QDRANT_URL') self.QDRANT_API_KEY = get_env('QDRANT_API_KEY') diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 86fcf704c..c0c345bae 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -9,8 +9,9 @@ from werkzeug.exceptions import NotFound from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.data_loader.loader.notion import NotionLoader from core.indexing_runner import IndexingRunner +from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.extractor.notion_extractor import NotionExtractor from extensions.ext_database import db from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields from libs.login import login_required @@ -173,14 +174,14 @@ class 
DataSourceNotionApi(Resource): if not data_source_binding: raise NotFound('Data source binding not found.') - loader = NotionLoader( - notion_access_token=data_source_binding.access_token, + extractor = NotionExtractor( notion_workspace_id=workspace_id, notion_obj_id=page_id, - notion_page_type=page_type + notion_page_type=page_type, + notion_access_token=data_source_binding.access_token ) - text_docs = loader.load() + text_docs = extractor.extract() return { 'content': "\n".join([doc.page_content for doc in text_docs]) }, 200 @@ -192,11 +193,30 @@ class DataSourceNotionApi(Resource): parser = reqparse.RequestParser() parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json') parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') + parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') + parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json') args = parser.parse_args() # validate args DocumentService.estimate_args_validate(args) + notion_info_list = args['notion_info_list'] + extract_settings = [] + for notion_info in notion_info_list: + workspace_id = notion_info['workspace_id'] + for page in notion_info['pages']: + extract_setting = ExtractSetting( + datasource_type="notion_import", + notion_info={ + "notion_workspace_id": workspace_id, + "notion_obj_id": page['page_id'], + "notion_page_type": page['type'] + }, + document_model=args['doc_form'] + ) + extract_settings.append(extract_setting) indexing_runner = IndexingRunner() - response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule']) + response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, + args['process_rule'], args['doc_form'], + args['doc_language']) return response, 200 diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 2d26d0ecf..f80b4de48 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -15,6 +15,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.indexing_runner import IndexingRunner from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager +from core.rag.extractor.entity.extract_setting import ExtractSetting from extensions.ext_database import db from fields.app_fields import related_app_list from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields @@ -178,9 +179,9 @@ class DatasetApi(Resource): location='json', store_missing=False, type=_validate_description_length) parser.add_argument('indexing_technique', type=str, location='json', - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, - help='Invalid indexing technique.') + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, + help='Invalid indexing technique.') parser.add_argument('permission', type=str, location='json', choices=( 'only_me', 'all_team_members'), help='Invalid permission.') parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.') @@ -258,7 +259,7 @@ class DatasetIndexingEstimateApi(Resource): parser = reqparse.RequestParser() parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json') parser.add_argument('process_rule', type=dict, 
required=True, nullable=True, location='json') - parser.add_argument('indexing_technique', type=str, required=True, + parser.add_argument('indexing_technique', type=str, required=True, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=True, location='json') parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') @@ -268,6 +269,7 @@ class DatasetIndexingEstimateApi(Resource): args = parser.parse_args() # validate args DocumentService.estimate_args_validate(args) + extract_settings = [] if args['info_list']['data_source_type'] == 'upload_file': file_ids = args['info_list']['file_info_list']['file_ids'] file_details = db.session.query(UploadFile).filter( @@ -278,37 +280,44 @@ class DatasetIndexingEstimateApi(Resource): if file_details is None: raise NotFound("File not found.") - indexing_runner = IndexingRunner() - - try: - response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details, - args['process_rule'], args['doc_form'], - args['doc_language'], args['dataset_id'], - args['indexing_technique']) - except LLMBadRequestError: - raise ProviderNotInitializeError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) + if file_details: + for file_detail in file_details: + extract_setting = ExtractSetting( + datasource_type="upload_file", + upload_file=file_detail, + document_model=args['doc_form'] + ) + extract_settings.append(extract_setting) elif args['info_list']['data_source_type'] == 'notion_import': - - indexing_runner = IndexingRunner() - - try: - response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, - args['info_list']['notion_info_list'], - args['process_rule'], args['doc_form'], - args['doc_language'], args['dataset_id'], - args['indexing_technique']) - except LLMBadRequestError: - raise ProviderNotInitializeError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) + notion_info_list = args['info_list']['notion_info_list'] + for notion_info in notion_info_list: + workspace_id = notion_info['workspace_id'] + for page in notion_info['pages']: + extract_setting = ExtractSetting( + datasource_type="notion_import", + notion_info={ + "notion_workspace_id": workspace_id, + "notion_obj_id": page['page_id'], + "notion_page_type": page['type'] + }, + document_model=args['doc_form'] + ) + extract_settings.append(extract_setting) else: raise ValueError('Data source type not support') + indexing_runner = IndexingRunner() + try: + response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, + args['process_rule'], args['doc_form'], + args['doc_language'], args['dataset_id'], + args['indexing_technique']) + except LLMBadRequestError: + raise ProviderNotInitializeError( + "No Embedding Model available. 
Please configure a valid provider " + "in the Settings -> Model Provider.") + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + return response, 200 @@ -508,4 +517,3 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/') api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info') api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting') api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/') - diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 3fb6f16cd..a990ef96e 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -32,6 +32,7 @@ from core.indexing_runner import IndexingRunner from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError +from core.rag.extractor.entity.extract_setting import ExtractSetting from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.document_fields import ( @@ -95,7 +96,7 @@ class GetProcessRuleApi(Resource): req_data = request.args document_id = req_data.get('document_id') - + # get default rules mode = DocumentService.DEFAULT_RULES['mode'] rules = DocumentService.DEFAULT_RULES['rules'] @@ -362,12 +363,18 @@ class DocumentIndexingEstimateApi(DocumentResource): if not file: raise NotFound('File not found.') + extract_setting = ExtractSetting( + datasource_type="upload_file", + upload_file=file, + document_model=document.doc_form + ) + indexing_runner = IndexingRunner() try: - response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file], - data_process_rule_dict, None, - 'English', dataset_id) + response = indexing_runner.indexing_estimate(current_user.current_tenant_id, [extract_setting], + data_process_rule_dict, document.doc_form, + 'English', dataset_id) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. 
Please configure a valid provider " @@ -402,6 +409,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): data_process_rule = documents[0].dataset_process_rule data_process_rule_dict = data_process_rule.to_dict() info_list = [] + extract_settings = [] for document in documents: if document.indexing_status in ['completed', 'error']: raise DocumentAlreadyFinishedError() @@ -424,42 +432,48 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): } info_list.append(notion_info) - if dataset.data_source_type == 'upload_file': - file_details = db.session.query(UploadFile).filter( - UploadFile.tenant_id == current_user.current_tenant_id, - UploadFile.id.in_(info_list) - ).all() + if document.data_source_type == 'upload_file': + file_id = data_source_info['upload_file_id'] + file_detail = db.session.query(UploadFile).filter( + UploadFile.tenant_id == current_user.current_tenant_id, + UploadFile.id == file_id + ).first() - if file_details is None: - raise NotFound("File not found.") + if file_detail is None: + raise NotFound("File not found.") + extract_setting = ExtractSetting( + datasource_type="upload_file", + upload_file=file_detail, + document_model=document.doc_form + ) + extract_settings.append(extract_setting) + + elif document.data_source_type == 'notion_import': + extract_setting = ExtractSetting( + datasource_type="notion_import", + notion_info={ + "notion_workspace_id": data_source_info['notion_workspace_id'], + "notion_obj_id": data_source_info['notion_page_id'], + "notion_page_type": data_source_info['type'] + }, + document_model=document.doc_form + ) + extract_settings.append(extract_setting) + + else: + raise ValueError('Data source type not support') indexing_runner = IndexingRunner() try: - response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details, - data_process_rule_dict, None, - 'English', dataset_id) + response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, + data_process_rule_dict, document.doc_form, + 'English', dataset_id) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider.") except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) - elif dataset.data_source_type == 'notion_import': - - indexing_runner = IndexingRunner() - try: - response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, - info_list, - data_process_rule_dict, - None, 'English', dataset_id) - except LLMBadRequestError: - raise ProviderNotInitializeError( - "No Embedding Model available. 
Please configure a valid provider " - "in the Settings -> Model Provider.") - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) - else: - raise ValueError('Data source type not support') return response diff --git a/api/core/data_loader/file_extractor.py b/api/core/data_loader/file_extractor.py index 4741014c9..e69de29bb 100644 --- a/api/core/data_loader/file_extractor.py +++ b/api/core/data_loader/file_extractor.py @@ -1,107 +0,0 @@ -import tempfile -from pathlib import Path -from typing import Optional, Union - -import requests -from flask import current_app -from langchain.document_loaders import Docx2txtLoader, TextLoader -from langchain.schema import Document - -from core.data_loader.loader.csv_loader import CSVLoader -from core.data_loader.loader.excel import ExcelLoader -from core.data_loader.loader.html import HTMLLoader -from core.data_loader.loader.markdown import MarkdownLoader -from core.data_loader.loader.pdf import PdfLoader -from core.data_loader.loader.unstructured.unstructured_eml import UnstructuredEmailLoader -from core.data_loader.loader.unstructured.unstructured_markdown import UnstructuredMarkdownLoader -from core.data_loader.loader.unstructured.unstructured_msg import UnstructuredMsgLoader -from core.data_loader.loader.unstructured.unstructured_ppt import UnstructuredPPTLoader -from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredPPTXLoader -from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader -from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader -from extensions.ext_storage import storage -from models.model import UploadFile - -SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain'] -USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" - - -class FileExtractor: - @classmethod - def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[list[Document], str]: - with tempfile.TemporaryDirectory() as temp_dir: - suffix = Path(upload_file.key).suffix - file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" - storage.download(upload_file.key, file_path) - - return cls.load_from_file(file_path, return_text, upload_file, is_automatic) - - @classmethod - def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]: - response = requests.get(url, headers={ - "User-Agent": USER_AGENT - }) - - with tempfile.TemporaryDirectory() as temp_dir: - suffix = Path(url).suffix - file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" - with open(file_path, 'wb') as file: - file.write(response.content) - - return cls.load_from_file(file_path, return_text) - - @classmethod - def load_from_file(cls, file_path: str, return_text: bool = False, - upload_file: Optional[UploadFile] = None, - is_automatic: bool = False) -> Union[list[Document], str]: - input_file = Path(file_path) - delimiter = '\n' - file_extension = input_file.suffix.lower() - etl_type = current_app.config['ETL_TYPE'] - unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL'] - if etl_type == 'Unstructured': - if file_extension == '.xlsx': - loader = ExcelLoader(file_path) - elif file_extension == '.pdf': - loader = PdfLoader(file_path, upload_file=upload_file) - elif file_extension in ['.md', '.markdown']: - loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url) if 
is_automatic \ - else MarkdownLoader(file_path, autodetect_encoding=True) - elif file_extension in ['.htm', '.html']: - loader = HTMLLoader(file_path) - elif file_extension in ['.docx']: - loader = Docx2txtLoader(file_path) - elif file_extension == '.csv': - loader = CSVLoader(file_path, autodetect_encoding=True) - elif file_extension == '.msg': - loader = UnstructuredMsgLoader(file_path, unstructured_api_url) - elif file_extension == '.eml': - loader = UnstructuredEmailLoader(file_path, unstructured_api_url) - elif file_extension == '.ppt': - loader = UnstructuredPPTLoader(file_path, unstructured_api_url) - elif file_extension == '.pptx': - loader = UnstructuredPPTXLoader(file_path, unstructured_api_url) - elif file_extension == '.xml': - loader = UnstructuredXmlLoader(file_path, unstructured_api_url) - else: - # txt - loader = UnstructuredTextLoader(file_path, unstructured_api_url) if is_automatic \ - else TextLoader(file_path, autodetect_encoding=True) - else: - if file_extension == '.xlsx': - loader = ExcelLoader(file_path) - elif file_extension == '.pdf': - loader = PdfLoader(file_path, upload_file=upload_file) - elif file_extension in ['.md', '.markdown']: - loader = MarkdownLoader(file_path, autodetect_encoding=True) - elif file_extension in ['.htm', '.html']: - loader = HTMLLoader(file_path) - elif file_extension in ['.docx']: - loader = Docx2txtLoader(file_path) - elif file_extension == '.csv': - loader = CSVLoader(file_path, autodetect_encoding=True) - else: - # txt - loader = TextLoader(file_path, autodetect_encoding=True) - - return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load() diff --git a/api/core/data_loader/loader/html.py b/api/core/data_loader/loader/html.py deleted file mode 100644 index 6a9b48a5b..000000000 --- a/api/core/data_loader/loader/html.py +++ /dev/null @@ -1,34 +0,0 @@ -import logging - -from bs4 import BeautifulSoup -from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document - -logger = logging.getLogger(__name__) - - -class HTMLLoader(BaseLoader): - """Load html files. - - - Args: - file_path: Path to the file to load. - """ - - def __init__( - self, - file_path: str - ): - """Initialize with file path.""" - self._file_path = file_path - - def load(self) -> list[Document]: - return [Document(page_content=self._load_as_text())] - - def _load_as_text(self) -> str: - with open(self._file_path, "rb") as fp: - soup = BeautifulSoup(fp, 'html.parser') - text = soup.get_text() - text = text.strip() if text else '' - - return text diff --git a/api/core/data_loader/loader/pdf.py b/api/core/data_loader/loader/pdf.py deleted file mode 100644 index a3452b367..000000000 --- a/api/core/data_loader/loader/pdf.py +++ /dev/null @@ -1,55 +0,0 @@ -import logging -from typing import Optional - -from langchain.document_loaders import PyPDFium2Loader -from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document - -from extensions.ext_storage import storage -from models.model import UploadFile - -logger = logging.getLogger(__name__) - - -class PdfLoader(BaseLoader): - """Load pdf files. - - - Args: - file_path: Path to the file to load. 
- """ - - def __init__( - self, - file_path: str, - upload_file: Optional[UploadFile] = None - ): - """Initialize with file path.""" - self._file_path = file_path - self._upload_file = upload_file - - def load(self) -> list[Document]: - plaintext_file_key = '' - plaintext_file_exists = False - if self._upload_file: - if self._upload_file.hash: - plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \ - + self._upload_file.hash + '.0625.plaintext' - try: - text = storage.load(plaintext_file_key).decode('utf-8') - plaintext_file_exists = True - return [Document(page_content=text)] - except FileNotFoundError: - pass - documents = PyPDFium2Loader(file_path=self._file_path).load() - text_list = [] - for document in documents: - text_list.append(document.page_content) - text = "\n\n".join(text_list) - - # save plaintext file for caching - if not plaintext_file_exists and plaintext_file_key: - storage.save(plaintext_file_key, text.encode('utf-8')) - - return documents - diff --git a/api/core/docstore/dataset_docstore.py b/api/core/docstore/dataset_docstore.py index 556b3aced..9a051fd4c 100644 --- a/api/core/docstore/dataset_docstore.py +++ b/api/core/docstore/dataset_docstore.py @@ -1,12 +1,12 @@ from collections.abc import Sequence from typing import Any, Optional, cast -from langchain.schema import Document from sqlalchemy import func from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment diff --git a/api/core/features/annotation_reply.py b/api/core/features/annotation_reply.py index bdc5467e6..e1b64cf73 100644 --- a/api/core/features/annotation_reply.py +++ b/api/core/features/annotation_reply.py @@ -1,13 +1,8 @@ import logging from typing import Optional -from flask import current_app - -from core.embedding.cached_embedding import CacheEmbedding from core.entities.application_entities import InvokeFrom -from core.index.vector_index.vector_index import VectorIndex -from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType +from core.rag.datasource.vdb.vector_factory import Vector from extensions.ext_database import db from models.dataset import Dataset from models.model import App, AppAnnotationSetting, Message, MessageAnnotation @@ -45,17 +40,6 @@ class AnnotationReplyFeature: embedding_provider_name = collection_binding_detail.provider_name embedding_model_name = collection_binding_detail.model_name - model_manager = ModelManager() - model_instance = model_manager.get_model_instance( - tenant_id=app_record.tenant_id, - provider=embedding_provider_name, - model_type=ModelType.TEXT_EMBEDDING, - model=embedding_model_name - ) - - # get embedding model - embeddings = CacheEmbedding(model_instance) - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( embedding_provider_name, embedding_model_name, @@ -71,22 +55,14 @@ class AnnotationReplyFeature: collection_binding_id=dataset_collection_binding.id ) - vector_index = VectorIndex( - dataset=dataset, - config=current_app.config, - embeddings=embeddings, - attributes=['doc_id', 'annotation_id', 'app_id'] - ) + vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) - documents = vector_index.search( + documents = vector.search_by_vector( 
query=query, - search_type='similarity_score_threshold', - search_kwargs={ - 'k': 1, - 'score_threshold': score_threshold, - 'filter': { - 'group_id': [dataset.id] - } + k=1, + score_threshold=score_threshold, + filter={ + 'group_id': [dataset.id] } ) diff --git a/api/core/index/index.py b/api/core/index/index.py deleted file mode 100644 index 42971c895..000000000 --- a/api/core/index/index.py +++ /dev/null @@ -1,51 +0,0 @@ -from flask import current_app -from langchain.embeddings import OpenAIEmbeddings - -from core.embedding.cached_embedding import CacheEmbedding -from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex -from core.index.vector_index.vector_index import VectorIndex -from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType -from models.dataset import Dataset - - -class IndexBuilder: - @classmethod - def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False): - if indexing_technique == "high_quality": - if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality': - return None - - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - model_type=ModelType.TEXT_EMBEDDING, - provider=dataset.embedding_model_provider, - model=dataset.embedding_model - ) - - embeddings = CacheEmbedding(embedding_model) - - return VectorIndex( - dataset=dataset, - config=current_app.config, - embeddings=embeddings - ) - elif indexing_technique == "economy": - return KeywordTableIndex( - dataset=dataset, - config=KeywordTableConfig( - max_keywords_per_chunk=10 - ) - ) - else: - raise ValueError('Unknown indexing technique') - - @classmethod - def get_default_high_quality_index(cls, dataset: Dataset): - embeddings = OpenAIEmbeddings(openai_api_key=' ') - return VectorIndex( - dataset=dataset, - config=current_app.config, - embeddings=embeddings - ) diff --git a/api/core/index/vector_index/base.py b/api/core/index/vector_index/base.py deleted file mode 100644 index 36aa1917a..000000000 --- a/api/core/index/vector_index/base.py +++ /dev/null @@ -1,305 +0,0 @@ -import json -import logging -from abc import abstractmethod -from typing import Any, cast - -from langchain.embeddings.base import Embeddings -from langchain.schema import BaseRetriever, Document -from langchain.vectorstores import VectorStore - -from core.index.base import BaseIndex -from extensions.ext_database import db -from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment -from models.dataset import Document as DatasetDocument - - -class BaseVectorIndex(BaseIndex): - - def __init__(self, dataset: Dataset, embeddings: Embeddings): - super().__init__(dataset) - self._embeddings = embeddings - self._vector_store = None - - def get_type(self) -> str: - raise NotImplementedError - - @abstractmethod - def get_index_name(self, dataset: Dataset) -> str: - raise NotImplementedError - - @abstractmethod - def to_index_struct(self) -> dict: - raise NotImplementedError - - @abstractmethod - def _get_vector_store(self) -> VectorStore: - raise NotImplementedError - - @abstractmethod - def _get_vector_store_class(self) -> type: - raise NotImplementedError - - @abstractmethod - def search_by_full_text_index( - self, query: str, - **kwargs: Any - ) -> list[Document]: - raise NotImplementedError - - def search( - self, query: str, - **kwargs: Any - ) -> list[Document]: - vector_store = self._get_vector_store() - 
vector_store = cast(self._get_vector_store_class(), vector_store) - - search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity' - search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {} - - if search_type == 'similarity_score_threshold': - score_threshold = search_kwargs.get("score_threshold") - if (score_threshold is None) or (not isinstance(score_threshold, float)): - search_kwargs['score_threshold'] = .0 - - docs_with_similarity = vector_store.similarity_search_with_relevance_scores( - query, **search_kwargs - ) - - docs = [] - for doc, similarity in docs_with_similarity: - doc.metadata['score'] = similarity - docs.append(doc) - - return docs - - # similarity k - # mmr k, fetch_k, lambda_mult - # similarity_score_threshold k - return vector_store.as_retriever( - search_type=search_type, - search_kwargs=search_kwargs - ).get_relevant_documents(query) - - def get_retriever(self, **kwargs: Any) -> BaseRetriever: - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - return vector_store.as_retriever(**kwargs) - - def add_texts(self, texts: list[Document], **kwargs): - if self._is_origin(): - self.recreate_dataset(self.dataset) - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - if kwargs.get('duplicate_check', False): - texts = self._filter_duplicate_texts(texts) - - uuids = self._get_uuids(texts) - vector_store.add_documents(texts, uuids=uuids) - - def text_exists(self, id: str) -> bool: - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - return vector_store.text_exists(id) - - def delete_by_ids(self, ids: list[str]) -> None: - if self._is_origin(): - self.recreate_dataset(self.dataset) - return - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - for node_id in ids: - vector_store.del_text(node_id) - - def delete_by_group_id(self, group_id: str) -> None: - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - if self.dataset.collection_binding_id: - vector_store.delete_by_group_id(group_id) - else: - vector_store.delete() - - def delete(self) -> None: - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - vector_store.delete() - - def _is_origin(self): - return False - - def recreate_dataset(self, dataset: Dataset): - logging.info(f"Recreating dataset {dataset.id}") - - try: - self.delete() - except Exception as e: - raise e - - dataset_documents = db.session.query(DatasetDocument).filter( - DatasetDocument.dataset_id == dataset.id, - DatasetDocument.indexing_status == 'completed', - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).all() - - documents = [] - for dataset_document in dataset_documents: - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True - ).all() - - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - } - ) - - documents.append(document) - - origin_index_struct = self.dataset.index_struct[:] - self.dataset.index_struct = 
None - - if documents: - try: - self.create(documents) - except Exception as e: - self.dataset.index_struct = origin_index_struct - raise e - - dataset.index_struct = json.dumps(self.to_index_struct()) - - db.session.commit() - - self.dataset = dataset - logging.info(f"Dataset {dataset.id} recreate successfully.") - - def create_qdrant_dataset(self, dataset: Dataset): - logging.info(f"create_qdrant_dataset {dataset.id}") - - try: - self.delete() - except Exception as e: - raise e - - dataset_documents = db.session.query(DatasetDocument).filter( - DatasetDocument.dataset_id == dataset.id, - DatasetDocument.indexing_status == 'completed', - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).all() - - documents = [] - for dataset_document in dataset_documents: - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True - ).all() - - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - } - ) - - documents.append(document) - - if documents: - try: - self.create(documents) - except Exception as e: - raise e - - logging.info(f"Dataset {dataset.id} recreate successfully.") - - def update_qdrant_dataset(self, dataset: Dataset): - logging.info(f"update_qdrant_dataset {dataset.id}") - - segment = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True - ).first() - - if segment: - try: - exist = self.text_exists(segment.index_node_id) - if exist: - index_struct = { - "type": 'qdrant', - "vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']} - } - dataset.index_struct = json.dumps(index_struct) - db.session.commit() - except Exception as e: - raise e - - logging.info(f"Dataset {dataset.id} recreate successfully.") - - def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding): - logging.info(f"restore dataset in_one,_dataset {dataset.id}") - - dataset_documents = db.session.query(DatasetDocument).filter( - DatasetDocument.dataset_id == dataset.id, - DatasetDocument.indexing_status == 'completed', - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).all() - - documents = [] - for dataset_document in dataset_documents: - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True - ).all() - - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - } - ) - - documents.append(document) - - if documents: - try: - self.add_texts(documents) - except Exception as e: - raise e - - logging.info(f"Dataset {dataset.id} recreate successfully.") - - def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding): - logging.info(f"delete original collection: {dataset.id}") - - self.delete() - - dataset.collection_binding_id = dataset_collection_binding.id - db.session.add(dataset) - db.session.commit() - - logging.info(f"Dataset 
{dataset.id} recreate successfully.") diff --git a/api/core/index/vector_index/milvus_vector_index.py b/api/core/index/vector_index/milvus_vector_index.py deleted file mode 100644 index a18cf35a2..000000000 --- a/api/core/index/vector_index/milvus_vector_index.py +++ /dev/null @@ -1,165 +0,0 @@ -from typing import Any, cast - -from langchain.embeddings.base import Embeddings -from langchain.schema import Document -from langchain.vectorstores import VectorStore -from pydantic import BaseModel, root_validator - -from core.index.base import BaseIndex -from core.index.vector_index.base import BaseVectorIndex -from core.vector_store.milvus_vector_store import MilvusVectorStore -from models.dataset import Dataset - - -class MilvusConfig(BaseModel): - host: str - port: int - user: str - password: str - secure: bool = False - batch_size: int = 100 - - @root_validator() - def validate_config(cls, values: dict) -> dict: - if not values['host']: - raise ValueError("config MILVUS_HOST is required") - if not values['port']: - raise ValueError("config MILVUS_PORT is required") - if not values['user']: - raise ValueError("config MILVUS_USER is required") - if not values['password']: - raise ValueError("config MILVUS_PASSWORD is required") - return values - - def to_milvus_params(self): - return { - 'host': self.host, - 'port': self.port, - 'user': self.user, - 'password': self.password, - 'secure': self.secure - } - - -class MilvusVectorIndex(BaseVectorIndex): - def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings): - super().__init__(dataset, embeddings) - self._client_config = config - - def get_type(self) -> str: - return 'milvus' - - def get_index_name(self, dataset: Dataset) -> str: - if self.dataset.index_struct_dict: - class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] - if not class_prefix.endswith('_Node'): - # original class_prefix - class_prefix += '_Node' - - return class_prefix - - dataset_id = dataset.id - return "Vector_index_" + dataset_id.replace("-", "_") + '_Node' - - def to_index_struct(self) -> dict: - return { - "type": self.get_type(), - "vector_store": {"class_prefix": self.get_index_name(self.dataset)} - } - - def create(self, texts: list[Document], **kwargs) -> BaseIndex: - uuids = self._get_uuids(texts) - index_params = { - 'metric_type': 'IP', - 'index_type': "HNSW", - 'params': {"M": 8, "efConstruction": 64} - } - self._vector_store = MilvusVectorStore.from_documents( - texts, - self._embeddings, - collection_name=self.get_index_name(self.dataset), - connection_args=self._client_config.to_milvus_params(), - index_params=index_params - ) - - return self - - def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: - uuids = self._get_uuids(texts) - self._vector_store = MilvusVectorStore.from_documents( - texts, - self._embeddings, - collection_name=collection_name, - ids=uuids, - content_payload_key='page_content' - ) - - return self - - def _get_vector_store(self) -> VectorStore: - """Only for created index.""" - if self._vector_store: - return self._vector_store - - return MilvusVectorStore( - collection_name=self.get_index_name(self.dataset), - embedding_function=self._embeddings, - connection_args=self._client_config.to_milvus_params() - ) - - def _get_vector_store_class(self) -> type: - return MilvusVectorStore - - def delete_by_document_id(self, document_id: str): - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), 
vector_store) - ids = vector_store.get_ids_by_document_id(document_id) - if ids: - vector_store.del_texts({ - 'filter': f'id in {ids}' - }) - - def delete_by_metadata_field(self, key: str, value: str): - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - ids = vector_store.get_ids_by_metadata_field(key, value) - if ids: - vector_store.del_texts({ - 'filter': f'id in {ids}' - }) - - def delete_by_ids(self, doc_ids: list[str]) -> None: - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - ids = vector_store.get_ids_by_doc_ids(doc_ids) - vector_store.del_texts({ - 'filter': f' id in {ids}' - }) - - def delete_by_group_id(self, group_id: str) -> None: - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - vector_store.delete() - - def delete(self) -> None: - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - from qdrant_client.http import models - vector_store.del_texts(models.Filter( - must=[ - models.FieldCondition( - key="group_id", - match=models.MatchValue(value=self.dataset.id), - ), - ], - )) - - def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]: - # milvus/zilliz doesn't support bm25 search - return [] diff --git a/api/core/index/vector_index/qdrant_vector_index.py b/api/core/index/vector_index/qdrant_vector_index.py deleted file mode 100644 index 046260d2f..000000000 --- a/api/core/index/vector_index/qdrant_vector_index.py +++ /dev/null @@ -1,229 +0,0 @@ -import os -from typing import Any, Optional, cast - -import qdrant_client -from langchain.embeddings.base import Embeddings -from langchain.schema import Document -from langchain.vectorstores import VectorStore -from pydantic import BaseModel -from qdrant_client.http.models import HnswConfigDiff - -from core.index.base import BaseIndex -from core.index.vector_index.base import BaseVectorIndex -from core.vector_store.qdrant_vector_store import QdrantVectorStore -from extensions.ext_database import db -from models.dataset import Dataset, DatasetCollectionBinding - - -class QdrantConfig(BaseModel): - endpoint: str - api_key: Optional[str] - timeout: float = 20 - root_path: Optional[str] - - def to_qdrant_params(self): - if self.endpoint and self.endpoint.startswith('path:'): - path = self.endpoint.replace('path:', '') - if not os.path.isabs(path): - path = os.path.join(self.root_path, path) - - return { - 'path': path - } - else: - return { - 'url': self.endpoint, - 'api_key': self.api_key, - 'timeout': self.timeout - } - - -class QdrantVectorIndex(BaseVectorIndex): - def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings): - super().__init__(dataset, embeddings) - self._client_config = config - - def get_type(self) -> str: - return 'qdrant' - - def get_index_name(self, dataset: Dataset) -> str: - if dataset.collection_binding_id: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter(DatasetCollectionBinding.id == dataset.collection_binding_id). 
\ - one_or_none() - if dataset_collection_binding: - return dataset_collection_binding.collection_name - else: - raise ValueError('Dataset Collection Bindings is not exist!') - else: - if self.dataset.index_struct_dict: - class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] - return class_prefix - - dataset_id = dataset.id - return "Vector_index_" + dataset_id.replace("-", "_") + '_Node' - - def to_index_struct(self) -> dict: - return { - "type": self.get_type(), - "vector_store": {"class_prefix": self.get_index_name(self.dataset)} - } - - def create(self, texts: list[Document], **kwargs) -> BaseIndex: - uuids = self._get_uuids(texts) - self._vector_store = QdrantVectorStore.from_documents( - texts, - self._embeddings, - collection_name=self.get_index_name(self.dataset), - ids=uuids, - content_payload_key='page_content', - group_id=self.dataset.id, - group_payload_key='group_id', - hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, - max_indexing_threads=0, on_disk=False), - **self._client_config.to_qdrant_params() - ) - - return self - - def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: - uuids = self._get_uuids(texts) - self._vector_store = QdrantVectorStore.from_documents( - texts, - self._embeddings, - collection_name=collection_name, - ids=uuids, - content_payload_key='page_content', - group_id=self.dataset.id, - group_payload_key='group_id', - hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, - max_indexing_threads=0, on_disk=False), - **self._client_config.to_qdrant_params() - ) - - return self - - def _get_vector_store(self) -> VectorStore: - """Only for created index.""" - if self._vector_store: - return self._vector_store - attributes = ['doc_id', 'dataset_id', 'document_id'] - client = qdrant_client.QdrantClient( - **self._client_config.to_qdrant_params() - ) - - return QdrantVectorStore( - client=client, - collection_name=self.get_index_name(self.dataset), - embeddings=self._embeddings, - content_payload_key='page_content', - group_id=self.dataset.id, - group_payload_key='group_id' - ) - - def _get_vector_store_class(self) -> type: - return QdrantVectorStore - - def delete_by_document_id(self, document_id: str): - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - from qdrant_client.http import models - - vector_store.del_texts(models.Filter( - must=[ - models.FieldCondition( - key="metadata.document_id", - match=models.MatchValue(value=document_id), - ), - ], - )) - - def delete_by_metadata_field(self, key: str, value: str): - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - from qdrant_client.http import models - - vector_store.del_texts(models.Filter( - must=[ - models.FieldCondition( - key=f"metadata.{key}", - match=models.MatchValue(value=value), - ), - ], - )) - - def delete_by_ids(self, ids: list[str]) -> None: - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - from qdrant_client.http import models - for node_id in ids: - vector_store.del_texts(models.Filter( - must=[ - models.FieldCondition( - key="metadata.doc_id", - match=models.MatchValue(value=node_id), - ), - ], - )) - - def delete_by_group_id(self, group_id: str) -> None: - - vector_store = self._get_vector_store() - vector_store = 
cast(self._get_vector_store_class(), vector_store) - - from qdrant_client.http import models - vector_store.del_texts(models.Filter( - must=[ - models.FieldCondition( - key="group_id", - match=models.MatchValue(value=group_id), - ), - ], - )) - - def delete(self) -> None: - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - from qdrant_client.http import models - vector_store.del_texts(models.Filter( - must=[ - models.FieldCondition( - key="group_id", - match=models.MatchValue(value=self.dataset.id), - ), - ], - )) - - def _is_origin(self): - if self.dataset.index_struct_dict: - class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] - if not class_prefix.endswith('_Node'): - # original class_prefix - return True - - return False - - def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]: - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - from qdrant_client.http import models - return vector_store.similarity_search_by_bm25(models.Filter( - must=[ - models.FieldCondition( - key="group_id", - match=models.MatchValue(value=self.dataset.id), - ), - models.FieldCondition( - key="page_content", - match=models.MatchText(text=query), - ) - ], - ), kwargs.get('top_k', 2)) diff --git a/api/core/index/vector_index/vector_index.py b/api/core/index/vector_index/vector_index.py deleted file mode 100644 index ed6e2699d..000000000 --- a/api/core/index/vector_index/vector_index.py +++ /dev/null @@ -1,90 +0,0 @@ -import json - -from flask import current_app -from langchain.embeddings.base import Embeddings - -from core.index.vector_index.base import BaseVectorIndex -from extensions.ext_database import db -from models.dataset import Dataset, Document - - -class VectorIndex: - def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings, - attributes: list = None): - if attributes is None: - attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] - self._dataset = dataset - self._embeddings = embeddings - self._vector_index = self._init_vector_index(dataset, config, embeddings, attributes) - self._attributes = attributes - - def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings, - attributes: list) -> BaseVectorIndex: - vector_type = config.get('VECTOR_STORE') - - if self._dataset.index_struct_dict: - vector_type = self._dataset.index_struct_dict['type'] - - if not vector_type: - raise ValueError("Vector store must be specified.") - - if vector_type == "weaviate": - from core.index.vector_index.weaviate_vector_index import WeaviateConfig, WeaviateVectorIndex - - return WeaviateVectorIndex( - dataset=dataset, - config=WeaviateConfig( - endpoint=config.get('WEAVIATE_ENDPOINT'), - api_key=config.get('WEAVIATE_API_KEY'), - batch_size=int(config.get('WEAVIATE_BATCH_SIZE')) - ), - embeddings=embeddings, - attributes=attributes - ) - elif vector_type == "qdrant": - from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex - - return QdrantVectorIndex( - dataset=dataset, - config=QdrantConfig( - endpoint=config.get('QDRANT_URL'), - api_key=config.get('QDRANT_API_KEY'), - root_path=current_app.root_path, - timeout=config.get('QDRANT_CLIENT_TIMEOUT') - ), - embeddings=embeddings - ) - elif vector_type == "milvus": - from core.index.vector_index.milvus_vector_index import MilvusConfig, MilvusVectorIndex - - return MilvusVectorIndex( - dataset=dataset, - 
config=MilvusConfig( - host=config.get('MILVUS_HOST'), - port=config.get('MILVUS_PORT'), - user=config.get('MILVUS_USER'), - password=config.get('MILVUS_PASSWORD'), - secure=config.get('MILVUS_SECURE'), - ), - embeddings=embeddings - ) - else: - raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") - - def add_texts(self, texts: list[Document], **kwargs): - if not self._dataset.index_struct_dict: - self._vector_index.create(texts, **kwargs) - self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct()) - db.session.commit() - return - - self._vector_index.add_texts(texts, **kwargs) - - def __getattr__(self, name): - if self._vector_index is not None: - method = getattr(self._vector_index, name) - if callable(method): - return method - - raise AttributeError(f"'VectorIndex' object has no attribute '{name}'") - diff --git a/api/core/index/vector_index/weaviate_vector_index.py b/api/core/index/vector_index/weaviate_vector_index.py deleted file mode 100644 index 72a74a039..000000000 --- a/api/core/index/vector_index/weaviate_vector_index.py +++ /dev/null @@ -1,179 +0,0 @@ -from typing import Any, Optional, cast - -import requests -import weaviate -from langchain.embeddings.base import Embeddings -from langchain.schema import Document -from langchain.vectorstores import VectorStore -from pydantic import BaseModel, root_validator - -from core.index.base import BaseIndex -from core.index.vector_index.base import BaseVectorIndex -from core.vector_store.weaviate_vector_store import WeaviateVectorStore -from models.dataset import Dataset - - -class WeaviateConfig(BaseModel): - endpoint: str - api_key: Optional[str] - batch_size: int = 100 - - @root_validator() - def validate_config(cls, values: dict) -> dict: - if not values['endpoint']: - raise ValueError("config WEAVIATE_ENDPOINT is required") - return values - - -class WeaviateVectorIndex(BaseVectorIndex): - - def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list): - super().__init__(dataset, embeddings) - self._client = self._init_client(config) - self._attributes = attributes - - def _init_client(self, config: WeaviateConfig) -> weaviate.Client: - auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key) - - weaviate.connect.connection.has_grpc = False - - try: - client = weaviate.Client( - url=config.endpoint, - auth_client_secret=auth_config, - timeout_config=(5, 60), - startup_period=None - ) - except requests.exceptions.ConnectionError: - raise ConnectionError("Vector database connection error") - - client.batch.configure( - # `batch_size` takes an `int` value to enable auto-batching - # (`None` is used for manual batching) - batch_size=config.batch_size, - # dynamically update the `batch_size` based on import speed - dynamic=True, - # `timeout_retries` takes an `int` value to retry on time outs - timeout_retries=3, - ) - - return client - - def get_type(self) -> str: - return 'weaviate' - - def get_index_name(self, dataset: Dataset) -> str: - if self.dataset.index_struct_dict: - class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] - if not class_prefix.endswith('_Node'): - # original class_prefix - class_prefix += '_Node' - - return class_prefix - - dataset_id = dataset.id - return "Vector_index_" + dataset_id.replace("-", "_") + '_Node' - - def to_index_struct(self) -> dict: - return { - "type": self.get_type(), - "vector_store": {"class_prefix": self.get_index_name(self.dataset)} - } - - def create(self, texts: 
list[Document], **kwargs) -> BaseIndex: - uuids = self._get_uuids(texts) - self._vector_store = WeaviateVectorStore.from_documents( - texts, - self._embeddings, - client=self._client, - index_name=self.get_index_name(self.dataset), - uuids=uuids, - by_text=False - ) - - return self - - def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: - uuids = self._get_uuids(texts) - self._vector_store = WeaviateVectorStore.from_documents( - texts, - self._embeddings, - client=self._client, - index_name=self.get_index_name(self.dataset), - uuids=uuids, - by_text=False - ) - - return self - - - def _get_vector_store(self) -> VectorStore: - """Only for created index.""" - if self._vector_store: - return self._vector_store - - attributes = self._attributes - if self._is_origin(): - attributes = ['doc_id'] - - return WeaviateVectorStore( - client=self._client, - index_name=self.get_index_name(self.dataset), - text_key='text', - embedding=self._embeddings, - attributes=attributes, - by_text=False - ) - - def _get_vector_store_class(self) -> type: - return WeaviateVectorStore - - def delete_by_document_id(self, document_id: str): - if self._is_origin(): - self.recreate_dataset(self.dataset) - return - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - vector_store.del_texts({ - "operator": "Equal", - "path": ["document_id"], - "valueText": document_id - }) - - def delete_by_metadata_field(self, key: str, value: str): - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - vector_store.del_texts({ - "operator": "Equal", - "path": [key], - "valueText": value - }) - - def delete_by_group_id(self, group_id: str): - if self._is_origin(): - self.recreate_dataset(self.dataset) - return - - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - - vector_store.delete() - - def _is_origin(self): - if self.dataset.index_struct_dict: - class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] - if not class_prefix.endswith('_Node'): - # original class_prefix - return True - - return False - - def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]: - vector_store = self._get_vector_store() - vector_store = cast(self._get_vector_store_class(), vector_store) - return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs) - diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 1f80726d5..f11170d88 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -9,20 +9,20 @@ from typing import Optional, cast from flask import Flask, current_app from flask_login import current_user -from langchain.schema import Document from langchain.text_splitter import TextSplitter from sqlalchemy.orm.exc import ObjectDeletedError -from core.data_loader.file_extractor import FileExtractor -from core.data_loader.loader.notion import NotionLoader from core.docstore.dataset_docstore import DatasetDocumentStore from core.errors.error import ProviderTokenNotInitError from core.generator.llm_generator import LLMGenerator -from core.index.index import IndexBuilder from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType, PriceType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from 
core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.models.document import Document from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -31,7 +31,6 @@ from libs import helper from models.dataset import Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument from models.model import UploadFile -from models.source import DataSourceBinding from services.feature_service import FeatureService @@ -57,38 +56,19 @@ class IndexingRunner: processing_rule = db.session.query(DatasetProcessRule). \ filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ first() + index_type = dataset_document.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + # extract + text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) - # load file - text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic') + # transform + documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict()) + # save segment + self._load_segments(dataset, dataset_document, documents) - # get embedding model instance - embedding_model_instance = None - if dataset.indexing_technique == 'high_quality': - if dataset.embedding_model_provider: - embedding_model_instance = self.model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model - ) - else: - embedding_model_instance = self.model_manager.get_default_model_instance( - tenant_id=dataset.tenant_id, - model_type=ModelType.TEXT_EMBEDDING, - ) - - # get splitter - splitter = self._get_splitter(processing_rule, embedding_model_instance) - - # split to documents - documents = self._step_split( - text_docs=text_docs, - splitter=splitter, - dataset=dataset, - dataset_document=dataset_document, - processing_rule=processing_rule - ) - self._build_index( + # load + self._load( + index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents @@ -134,39 +114,19 @@ class IndexingRunner: filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). 
\ first() - # load file - text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic') + index_type = dataset_document.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + # extract + text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) - # get embedding model instance - embedding_model_instance = None - if dataset.indexing_technique == 'high_quality': - if dataset.embedding_model_provider: - embedding_model_instance = self.model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model - ) - else: - embedding_model_instance = self.model_manager.get_default_model_instance( - tenant_id=dataset.tenant_id, - model_type=ModelType.TEXT_EMBEDDING, - ) + # transform + documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict()) + # save segment + self._load_segments(dataset, dataset_document, documents) - # get splitter - splitter = self._get_splitter(processing_rule, embedding_model_instance) - - # split to documents - documents = self._step_split( - text_docs=text_docs, - splitter=splitter, - dataset=dataset, - dataset_document=dataset_document, - processing_rule=processing_rule - ) - - # build index - self._build_index( + # load + self._load( + index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents @@ -220,7 +180,15 @@ class IndexingRunner: documents.append(document) # build index - self._build_index( + # get the process rule + processing_rule = db.session.query(DatasetProcessRule). \ + filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ + first() + + index_type = dataset_document.doc_form + index_processor = IndexProcessorFactory(index_type, processing_rule.to_dict()).init_index_processor() + self._load( + index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents @@ -239,16 +207,16 @@ class IndexingRunner: dataset_document.stopped_at = datetime.datetime.utcnow() db.session.commit() - def file_indexing_estimate(self, tenant_id: str, file_details: list[UploadFile], tmp_processing_rule: dict, - doc_form: str = None, doc_language: str = 'English', dataset_id: str = None, - indexing_technique: str = 'economy') -> dict: + def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict, + doc_form: str = None, doc_language: str = 'English', dataset_id: str = None, + indexing_technique: str = 'economy') -> dict: """ Estimate the indexing for the document. 
""" # check document limit features = FeatureService.get_features(tenant_id) if features.billing.enabled: - count = len(file_details) + count = len(extract_settings) batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT']) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") @@ -284,16 +252,18 @@ class IndexingRunner: total_segments = 0 total_price = 0 currency = 'USD' - for file_detail in file_details: - + index_type = doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + all_text_docs = [] + for extract_setting in extract_settings: + # extract + text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) + all_text_docs.extend(text_docs) processing_rule = DatasetProcessRule( mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) ) - # load data from file - text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic') - # get splitter splitter = self._get_splitter(processing_rule, embedding_model_instance) @@ -305,7 +275,6 @@ class IndexingRunner: ) total_segments += len(documents) - for document in documents: if len(preview_texts) < 5: preview_texts.append(document.page_content) @@ -364,154 +333,8 @@ class IndexingRunner: "preview": preview_texts } - def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict, - doc_form: str = None, doc_language: str = 'English', dataset_id: str = None, - indexing_technique: str = 'economy') -> dict: - """ - Estimate the indexing for the document. - """ - # check document limit - features = FeatureService.get_features(tenant_id) - if features.billing.enabled: - count = len(notion_info_list) - batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT']) - if count > batch_upload_limit: - raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") - - embedding_model_instance = None - if dataset_id: - dataset = Dataset.query.filter_by( - id=dataset_id - ).first() - if not dataset: - raise ValueError('Dataset not found.') - if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality': - if dataset.embedding_model_provider: - embedding_model_instance = self.model_manager.get_model_instance( - tenant_id=tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model - ) - else: - embedding_model_instance = self.model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.TEXT_EMBEDDING, - ) - else: - if indexing_technique == 'high_quality': - embedding_model_instance = self.model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.TEXT_EMBEDDING - ) - # load data from notion - tokens = 0 - preview_texts = [] - total_segments = 0 - total_price = 0 - currency = 'USD' - for notion_info in notion_info_list: - workspace_id = notion_info['workspace_id'] - data_source_binding = DataSourceBinding.query.filter( - db.and_( - DataSourceBinding.tenant_id == current_user.current_tenant_id, - DataSourceBinding.provider == 'notion', - DataSourceBinding.disabled == False, - DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"' - ) - ).first() - if not data_source_binding: - raise ValueError('Data source binding not found.') - - for page in notion_info['pages']: - loader = NotionLoader( - notion_access_token=data_source_binding.access_token, 
- notion_workspace_id=workspace_id, - notion_obj_id=page['page_id'], - notion_page_type=page['type'] - ) - documents = loader.load() - - processing_rule = DatasetProcessRule( - mode=tmp_processing_rule["mode"], - rules=json.dumps(tmp_processing_rule["rules"]) - ) - - # get splitter - splitter = self._get_splitter(processing_rule, embedding_model_instance) - - # split to documents - documents = self._split_to_documents_for_estimate( - text_docs=documents, - splitter=splitter, - processing_rule=processing_rule - ) - total_segments += len(documents) - - embedding_model_type_instance = None - if embedding_model_instance: - embedding_model_type_instance = embedding_model_instance.model_type_instance - embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) - - for document in documents: - if len(preview_texts) < 5: - preview_texts.append(document.page_content) - if indexing_technique == 'high_quality' and embedding_model_type_instance: - tokens += embedding_model_type_instance.get_num_tokens( - model=embedding_model_instance.model, - credentials=embedding_model_instance.credentials, - texts=[document.page_content] - ) - - if doc_form and doc_form == 'qa_model': - model_instance = self.model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.LLM - ) - - model_type_instance = model_instance.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - if len(preview_texts) > 0: - # qa model document - response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], - doc_language) - document_qa_list = self.format_split_text(response) - - price_info = model_type_instance.get_price( - model=model_instance.model, - credentials=model_instance.credentials, - price_type=PriceType.INPUT, - tokens=total_segments * 2000, - ) - - return { - "total_segments": total_segments * 20, - "tokens": total_segments * 2000, - "total_price": '{:f}'.format(price_info.total_amount), - "currency": price_info.currency, - "qa_preview": document_qa_list, - "preview": preview_texts - } - if embedding_model_instance: - embedding_model_type_instance = embedding_model_instance.model_type_instance - embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) - embedding_price_info = embedding_model_type_instance.get_price( - model=embedding_model_instance.model, - credentials=embedding_model_instance.credentials, - price_type=PriceType.INPUT, - tokens=tokens - ) - total_price = '{:f}'.format(embedding_price_info.total_amount) - currency = embedding_price_info.currency - return { - "total_segments": total_segments, - "tokens": tokens, - "total_price": total_price, - "currency": currency, - "preview": preview_texts - } - - def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> list[Document]: + def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \ + -> list[Document]: # load file if dataset_document.data_source_type not in ["upload_file", "notion_import"]: return [] @@ -527,11 +350,27 @@ class IndexingRunner: one_or_none() if file_detail: - text_docs = FileExtractor.load(file_detail, is_automatic=automatic) + extract_setting = ExtractSetting( + datasource_type="upload_file", + upload_file=file_detail, + document_model=dataset_document.doc_form + ) + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) elif dataset_document.data_source_type == 
'notion_import': - loader = NotionLoader.from_document(dataset_document) - text_docs = loader.load() - + if (not data_source_info or 'notion_workspace_id' not in data_source_info + or 'notion_page_id' not in data_source_info): + raise ValueError("no notion import info found") + extract_setting = ExtractSetting( + datasource_type="notion_import", + notion_info={ + "notion_workspace_id": data_source_info['notion_workspace_id'], + "notion_obj_id": data_source_info['notion_page_id'], + "notion_page_type": data_source_info['notion_page_type'], + "document": dataset_document + }, + document_model=dataset_document.doc_form + ) + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) # update document status to splitting self._update_document_index_status( document_id=dataset_document.id, @@ -545,8 +384,6 @@ class IndexingRunner: # replace doc id to document model id text_docs = cast(list[Document], text_docs) for text_doc in text_docs: - # remove invalid symbol - text_doc.page_content = self.filter_string(text_doc.page_content) text_doc.metadata['document_id'] = dataset_document.id text_doc.metadata['dataset_id'] = dataset_document.dataset_id @@ -787,12 +624,12 @@ class IndexingRunner: for q, a in matches if q and a ] - def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]) -> None: + def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset, + dataset_document: DatasetDocument, documents: list[Document]) -> None: """ - Build the index for the document. + insert index and update document/segment status to completed """ - vector_index = IndexBuilder.get_index(dataset, 'high_quality') - keyword_table_index = IndexBuilder.get_index(dataset, 'economy') + embedding_model_instance = None if dataset.indexing_technique == 'high_quality': embedding_model_instance = self.model_manager.get_model_instance( @@ -825,13 +662,8 @@ class IndexingRunner: ) for document in chunk_documents ) - - # save vector index - if vector_index: - vector_index.add_texts(chunk_documents) - - # save keyword index - keyword_table_index.add_texts(chunk_documents) + # load index + index_processor.load(dataset, chunk_documents) document_ids = [document.metadata['doc_id'] for document in chunk_documents] db.session.query(DocumentSegment).filter( @@ -911,14 +743,64 @@ class IndexingRunner: ) documents.append(document) # save vector index - index = IndexBuilder.get_index(dataset, 'high_quality') - if index: - index.add_texts(documents, duplicate_check=True) + index_type = dataset.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + index_processor.load(dataset, documents) - # save keyword index - index = IndexBuilder.get_index(dataset, 'economy') - if index: - index.add_texts(documents) + def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset, + text_docs: list[Document], process_rule: dict) -> list[Document]: + # get embedding model instance + embedding_model_instance = None + if dataset.indexing_technique == 'high_quality': + if dataset.embedding_model_provider: + embedding_model_instance = self.model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model + ) + else: + embedding_model_instance = self.model_manager.get_default_model_instance( + tenant_id=dataset.tenant_id, + model_type=ModelType.TEXT_EMBEDDING, + ) + + documents = index_processor.transform(text_docs, 
embedding_model_instance=embedding_model_instance, + process_rule=process_rule) + + return documents + + def _load_segments(self, dataset, dataset_document, documents): + # save node to document segment + doc_store = DatasetDocumentStore( + dataset=dataset, + user_id=dataset_document.created_by, + document_id=dataset_document.id + ) + + # add document segments + doc_store.add_documents(documents) + + # update document status to indexing + cur_time = datetime.datetime.utcnow() + self._update_document_index_status( + document_id=dataset_document.id, + after_indexing_status="indexing", + extra_update_params={ + DatasetDocument.cleaning_completed_at: cur_time, + DatasetDocument.splitting_completed_at: cur_time, + } + ) + + # update segment status to indexing + self._update_segments_by_document( + dataset_document_id=dataset_document.id, + update_params={ + DocumentSegment.status: "indexing", + DocumentSegment.indexing_at: datetime.datetime.utcnow() + } + ) + pass class DocumentIsPausedException(Exception): diff --git a/api/core/rag/__init__.py b/api/core/rag/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/rag/cleaner/clean_processor.py b/api/core/rag/cleaner/clean_processor.py new file mode 100644 index 000000000..eaad0e0f4 --- /dev/null +++ b/api/core/rag/cleaner/clean_processor.py @@ -0,0 +1,38 @@ +import re + + +class CleanProcessor: + + @classmethod + def clean(cls, text: str, process_rule: dict) -> str: + # default clean + # remove invalid symbol + text = re.sub(r'<\|', '<', text) + text = re.sub(r'\|>', '>', text) + text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text) + # Unicode U+FFFE + text = re.sub('\uFFFE', '', text) + + rules = process_rule['rules'] if process_rule else None + if rules and 'pre_processing_rules' in rules: + pre_processing_rules = rules["pre_processing_rules"] + for pre_processing_rule in pre_processing_rules: + if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True: + # Remove extra spaces + pattern = r'\n{3,}' + text = re.sub(pattern, '\n\n', text) + pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}' + text = re.sub(pattern, ' ', text) + elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True: + # Remove email + pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)' + text = re.sub(pattern, '', text) + + # Remove URL + pattern = r'https?://[^\s]+' + text = re.sub(pattern, '', text) + return text + + def filter_string(self, text): + + return text diff --git a/api/core/rag/cleaner/cleaner_base.py b/api/core/rag/cleaner/cleaner_base.py new file mode 100644 index 000000000..523bd904f --- /dev/null +++ b/api/core/rag/cleaner/cleaner_base.py @@ -0,0 +1,12 @@ +"""Abstract interface for document cleaner implementations.""" +from abc import ABC, abstractmethod + + +class BaseCleaner(ABC): + """Interface for cleaning chunk content. 
+ """ + + @abstractmethod + def clean(self, content: str): + raise NotImplementedError + diff --git a/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py new file mode 100644 index 000000000..6a0b8c904 --- /dev/null +++ b/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py @@ -0,0 +1,12 @@ +"""Abstract interface for document clean implementations.""" +from core.rag.cleaner.cleaner_base import BaseCleaner + + +class UnstructuredExtraWhitespaceCleaner(BaseCleaner): + + def clean(self, content) -> str: + """clean document content.""" + from unstructured.cleaners.core import clean_extra_whitespace + + # Returns "ITEM 1A: RISK FACTORS" + return clean_extra_whitespace(content) diff --git a/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py new file mode 100644 index 000000000..6fc3a408d --- /dev/null +++ b/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py @@ -0,0 +1,15 @@ +"""Abstract interface for document clean implementations.""" +from core.rag.cleaner.cleaner_base import BaseCleaner + + +class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner): + + def clean(self, content) -> str: + """clean document content.""" + import re + + from unstructured.cleaners.core import group_broken_paragraphs + + para_split_re = re.compile(r"(\s*\n\s*){3}") + + return group_broken_paragraphs(content, paragraph_split=para_split_re) diff --git a/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py new file mode 100644 index 000000000..ca1ae8dfd --- /dev/null +++ b/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py @@ -0,0 +1,12 @@ +"""Abstract interface for document clean implementations.""" +from core.rag.cleaner.cleaner_base import BaseCleaner + + +class UnstructuredNonAsciiCharsCleaner(BaseCleaner): + + def clean(self, content) -> str: + """clean document content.""" + from unstructured.cleaners.core import clean_non_ascii_chars + + # Returns "This text contains non-ascii characters!"
+ return clean_non_ascii_chars(content) diff --git a/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py new file mode 100644 index 000000000..974a28fef --- /dev/null +++ b/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py @@ -0,0 +1,11 @@ +"""Abstract interface for document clean implementations.""" +from core.rag.cleaner.cleaner_base import BaseCleaner + + +class UnstructuredReplaceUnicodeQuotesCleaner(BaseCleaner): + + def clean(self, content) -> str: + """Replaces unicode quote characters, such as the \x91 character in a string.""" + + from unstructured.cleaners.core import replace_unicode_quotes + return replace_unicode_quotes(content) diff --git a/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py new file mode 100644 index 000000000..dfaf3a278 --- /dev/null +++ b/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py @@ -0,0 +1,11 @@ +"""Abstract interface for document clean implementations.""" +from core.rag.cleaner.cleaner_base import BaseCleaner + + +class UnstructuredTranslateTextCleaner(BaseCleaner): + + def clean(self, content) -> str: + """clean document content.""" + from unstructured.cleaners.translate import translate_text + + return translate_text(content) diff --git a/api/core/rag/data_post_processor/__init__.py b/api/core/rag/data_post_processor/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py new file mode 100644 index 000000000..bdd69c27b --- /dev/null +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -0,0 +1,49 @@ +from typing import Optional + +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.invoke import InvokeAuthorizationError +from core.rag.data_post_processor.reorder import ReorderRunner +from core.rag.models.document import Document +from core.rerank.rerank import RerankRunner + + +class DataPostProcessor: + """Interface for post-processing retrieved documents. 
+ """ + + def __init__(self, tenant_id: str, reranking_model: dict, reorder_enabled: bool = False): + self.rerank_runner = self._get_rerank_runner(reranking_model, tenant_id) + self.reorder_runner = self._get_reorder_runner(reorder_enabled) + + def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, + top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: + if self.rerank_runner: + documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user) + + if self.reorder_runner: + documents = self.reorder_runner.run(documents) + + return documents + + def _get_rerank_runner(self, reranking_model: dict, tenant_id: str) -> Optional[RerankRunner]: + if reranking_model: + try: + model_manager = ModelManager() + rerank_model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + provider=reranking_model['reranking_provider_name'], + model_type=ModelType.RERANK, + model=reranking_model['reranking_model_name'] + ) + except InvokeAuthorizationError: + return None + return RerankRunner(rerank_model_instance) + return None + + def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]: + if reorder_enabled: + return ReorderRunner() + return None + + diff --git a/api/core/rag/data_post_processor/reorder.py b/api/core/rag/data_post_processor/reorder.py new file mode 100644 index 000000000..284acaf1b --- /dev/null +++ b/api/core/rag/data_post_processor/reorder.py @@ -0,0 +1,19 @@ + +from langchain.schema import Document + + +class ReorderRunner: + + def run(self, documents: list[Document]) -> list[Document]: + # Retrieve elements from odd indices (0, 2, 4, etc.) of the documents list + odd_elements = documents[::2] + + # Retrieve elements from even indices (1, 3, 5, etc.) 
of the documents list + even_elements = documents[1::2] + + # Reverse the list of elements from even indices + even_elements_reversed = even_elements[::-1] + + new_documents = odd_elements + even_elements_reversed + + return new_documents diff --git a/api/core/rag/datasource/__init__.py b/api/core/rag/datasource/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/rag/datasource/entity/embedding.py b/api/core/rag/datasource/entity/embedding.py new file mode 100644 index 000000000..126c1a372 --- /dev/null +++ b/api/core/rag/datasource/entity/embedding.py @@ -0,0 +1,21 @@ +from abc import ABC, abstractmethod + + +class Embeddings(ABC): + """Interface for embedding models.""" + + @abstractmethod + def embed_documents(self, texts: list[str]) -> list[list[float]]: + """Embed search docs.""" + + @abstractmethod + def embed_query(self, text: str) -> list[float]: + """Embed query text.""" + + async def aembed_documents(self, texts: list[str]) -> list[list[float]]: + """Asynchronous Embed search docs.""" + raise NotImplementedError + + async def aembed_query(self, text: str) -> list[float]: + """Asynchronous Embed query text.""" + raise NotImplementedError diff --git a/api/core/rag/datasource/keyword/__init__.py b/api/core/rag/datasource/keyword/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/rag/datasource/keyword/jieba/__init__.py b/api/core/rag/datasource/keyword/jieba/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/index/keyword_table_index/keyword_table_index.py b/api/core/rag/datasource/keyword/jieba/jieba.py similarity index 71% rename from api/core/index/keyword_table_index/keyword_table_index.py rename to api/core/rag/datasource/keyword/jieba/jieba.py index 8bf0b1334..94a692637 100644 --- a/api/core/index/keyword_table_index/keyword_table_index.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -2,11 +2,11 @@ import json from collections import defaultdict from typing import Any, Optional -from langchain.schema import BaseRetriever, Document -from pydantic import BaseModel, Extra, Field +from pydantic import BaseModel -from core.index.base import BaseIndex -from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler +from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler +from core.rag.datasource.keyword.keyword_base import BaseKeyword +from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment @@ -15,59 +15,19 @@ class KeywordTableConfig(BaseModel): max_keywords_per_chunk: int = 10 -class KeywordTableIndex(BaseIndex): - def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()): +class Jieba(BaseKeyword): + def __init__(self, dataset: Dataset): super().__init__(dataset) - self._config = config + self._config = KeywordTableConfig() - def create(self, texts: list[Document], **kwargs) -> BaseIndex: + def create(self, texts: list[Document], **kwargs) -> BaseKeyword: keyword_table_handler = JiebaKeywordTableHandler() - keyword_table = {} + keyword_table = self._get_dataset_keyword_table() for text in texts: keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], 
list(keywords)) - dataset_keyword_table = DatasetKeywordTable( - dataset_id=self.dataset.id, - keyword_table=json.dumps({ - '__type__': 'keyword_table', - '__data__': { - "index_id": self.dataset.id, - "summary": None, - "table": {} - } - }, cls=SetEncoder) - ) - db.session.add(dataset_keyword_table) - db.session.commit() - - self._save_dataset_keyword_table(keyword_table) - - return self - - def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: - keyword_table_handler = JiebaKeywordTableHandler() - keyword_table = {} - for text in texts: - keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) - self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) - keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) - - dataset_keyword_table = DatasetKeywordTable( - dataset_id=self.dataset.id, - keyword_table=json.dumps({ - '__type__': 'keyword_table', - '__data__': { - "index_id": self.dataset.id, - "summary": None, - "table": {} - } - }, cls=SetEncoder) - ) - db.session.add(dataset_keyword_table) - db.session.commit() - self._save_dataset_keyword_table(keyword_table) return self @@ -76,8 +36,13 @@ class KeywordTableIndex(BaseIndex): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() - for text in texts: - keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) + keywords_list = kwargs.get('keywords_list', None) + for i in range(len(texts)): + text = texts[i] + if keywords_list: + keywords = keywords_list[i] + else: + keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) @@ -107,20 +72,13 @@ class KeywordTableIndex(BaseIndex): self._save_dataset_keyword_table(keyword_table) - def delete_by_metadata_field(self, key: str, value: str): - pass - - def get_retriever(self, **kwargs: Any) -> BaseRetriever: - return KeywordTableRetriever(index=self, **kwargs) - def search( self, query: str, **kwargs: Any ) -> list[Document]: keyword_table = self._get_dataset_keyword_table() - search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {} - k = search_kwargs.get('k') if search_kwargs.get('k') else 4 + k = kwargs.get('top_k', 4) sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k) @@ -150,12 +108,6 @@ class KeywordTableIndex(BaseIndex): db.session.delete(dataset_keyword_table) db.session.commit() - def delete_by_group_id(self, group_id: str) -> None: - dataset_keyword_table = self.dataset.dataset_keyword_table - if dataset_keyword_table: - db.session.delete(dataset_keyword_table) - db.session.commit() - def _save_dataset_keyword_table(self, keyword_table): keyword_table_dict = { '__type__': 'keyword_table', @@ -242,6 +194,7 @@ class KeywordTableIndex(BaseIndex): ).first() if document_segment: document_segment.keywords = keywords + db.session.add(document_segment) db.session.commit() def create_segment_keywords(self, node_id: str, keywords: list[str]): @@ -272,31 +225,6 @@ class KeywordTableIndex(BaseIndex): self._save_dataset_keyword_table(keyword_table) -class KeywordTableRetriever(BaseRetriever, BaseModel): - index: KeywordTableIndex - 
search_kwargs: dict = Field(default_factory=dict) - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - def get_relevant_documents(self, query: str) -> list[Document]: - """Get documents relevant for a query. - - Args: - query: string to find relevant documents for - - Returns: - List of relevant documents - """ - return self.index.search(query, **self.search_kwargs) - - async def aget_relevant_documents(self, query: str) -> list[Document]: - raise NotImplementedError("KeywordTableRetriever does not support async") - - class SetEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, set): diff --git a/api/core/index/keyword_table_index/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py similarity index 93% rename from api/core/index/keyword_table_index/jieba_keyword_table_handler.py rename to api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index df93a1903..5f862b8d1 100644 --- a/api/core/index/keyword_table_index/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -3,7 +3,7 @@ import re import jieba from jieba.analyse import default_tfidf -from core.index.keyword_table_index.stopwords import STOPWORDS +from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS class JiebaKeywordTableHandler: diff --git a/api/core/index/keyword_table_index/stopwords.py b/api/core/rag/datasource/keyword/jieba/stopwords.py similarity index 100% rename from api/core/index/keyword_table_index/stopwords.py rename to api/core/rag/datasource/keyword/jieba/stopwords.py diff --git a/api/core/index/base.py b/api/core/rag/datasource/keyword/keyword_base.py similarity index 63% rename from api/core/index/base.py rename to api/core/rag/datasource/keyword/keyword_base.py index f8eb1a134..84a580085 100644 --- a/api/core/index/base.py +++ b/api/core/rag/datasource/keyword/keyword_base.py @@ -3,22 +3,17 @@ from __future__ import annotations from abc import ABC, abstractmethod from typing import Any -from langchain.schema import BaseRetriever, Document - +from core.rag.models.document import Document from models.dataset import Dataset -class BaseIndex(ABC): +class BaseKeyword(ABC): def __init__(self, dataset: Dataset): self.dataset = dataset @abstractmethod - def create(self, texts: list[Document], **kwargs) -> BaseIndex: - raise NotImplementedError - - @abstractmethod - def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: + def create(self, texts: list[Document], **kwargs) -> BaseKeyword: raise NotImplementedError @abstractmethod @@ -34,31 +29,18 @@ class BaseIndex(ABC): raise NotImplementedError @abstractmethod - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_document_id(self, document_id: str) -> None: raise NotImplementedError - @abstractmethod - def delete_by_group_id(self, group_id: str) -> None: + def delete(self) -> None: raise NotImplementedError - @abstractmethod - def delete_by_document_id(self, document_id: str): - raise NotImplementedError - - @abstractmethod - def get_retriever(self, **kwargs: Any) -> BaseRetriever: - raise NotImplementedError - - @abstractmethod def search( self, query: str, **kwargs: Any ) -> list[Document]: raise NotImplementedError - def delete(self) -> None: - raise NotImplementedError - def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts: 
doc_id = text.metadata['doc_id'] diff --git a/api/core/rag/datasource/keyword/keyword_factory.py b/api/core/rag/datasource/keyword/keyword_factory.py new file mode 100644 index 000000000..bccec2071 --- /dev/null +++ b/api/core/rag/datasource/keyword/keyword_factory.py @@ -0,0 +1,60 @@ +from typing import Any, cast + +from flask import current_app + +from core.rag.datasource.keyword.jieba.jieba import Jieba +from core.rag.datasource.keyword.keyword_base import BaseKeyword +from core.rag.models.document import Document +from models.dataset import Dataset + + +class Keyword: + def __init__(self, dataset: Dataset): + self._dataset = dataset + self._keyword_processor = self._init_keyword() + + def _init_keyword(self) -> BaseKeyword: + config = cast(dict, current_app.config) + keyword_type = config.get('KEYWORD_STORE') + + if not keyword_type: + raise ValueError("Keyword store must be specified.") + + if keyword_type == "jieba": + return Jieba( + dataset=self._dataset + ) + else: + raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") + + def create(self, texts: list[Document], **kwargs): + self._keyword_processor.create(texts, **kwargs) + + def add_texts(self, texts: list[Document], **kwargs): + self._keyword_processor.add_texts(texts, **kwargs) + + def text_exists(self, id: str) -> bool: + return self._keyword_processor.text_exists(id) + + def delete_by_ids(self, ids: list[str]) -> None: + self._keyword_processor.delete_by_ids(ids) + + def delete_by_document_id(self, document_id: str) -> None: + self._keyword_processor.delete_by_document_id(document_id) + + def delete(self) -> None: + self._keyword_processor.delete() + + def search( + self, query: str, + **kwargs: Any + ) -> list[Document]: + return self._keyword_processor.search(query, **kwargs) + + def __getattr__(self, name): + if self._keyword_processor is not None: + method = getattr(self._keyword_processor, name) + if callable(method): + return method + + raise AttributeError(f"'Keyword' object has no attribute '{name}'") diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py new file mode 100644 index 000000000..79673ffa8 --- /dev/null +++ b/api/core/rag/datasource/retrieval_service.py @@ -0,0 +1,165 @@ +import threading +from typing import Optional + +from flask import Flask, current_app +from flask_login import current_user + +from core.rag.data_post_processor.data_post_processor import DataPostProcessor +from core.rag.datasource.keyword.keyword_factory import Keyword +from core.rag.datasource.vdb.vector_factory import Vector +from extensions.ext_database import db +from models.dataset import Dataset + +default_retrieval_model = { + 'search_method': 'semantic_search', + 'reranking_enable': False, + 'reranking_model': { + 'reranking_provider_name': '', + 'reranking_model_name': '' + }, + 'top_k': 2, + 'score_threshold_enabled': False +} + + +class RetrievalService: + + @classmethod + def retrieve(cls, retrival_method: str, dataset_id: str, query: str, + top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None): + all_documents = [] + threads = [] + # retrieval_model source with keyword + if retrival_method == 'keyword_search': + keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset_id': dataset_id, + 'query': query, + 'top_k': top_k + }) + threads.append(keyword_thread) + keyword_thread.start() + # retrieval_model source with semantic + if 
retrival_method == 'semantic_search' or retrival_method == 'hybrid_search': + embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset_id': dataset_id, + 'query': query, + 'top_k': top_k, + 'score_threshold': score_threshold, + 'reranking_model': reranking_model, + 'all_documents': all_documents, + 'retrival_method': retrival_method + }) + threads.append(embedding_thread) + embedding_thread.start() + + # retrieval source with full text + if retrival_method == 'full_text_search' or retrival_method == 'hybrid_search': + full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset_id': dataset_id, + 'query': query, + 'retrival_method': retrival_method, + 'score_threshold': score_threshold, + 'top_k': top_k, + 'reranking_model': reranking_model, + 'all_documents': all_documents + }) + threads.append(full_text_index_thread) + full_text_index_thread.start() + + for thread in threads: + thread.join() + + if retrival_method == 'hybrid_search': + data_post_processor = DataPostProcessor(str(current_user.current_tenant_id), reranking_model, False) + all_documents = data_post_processor.invoke( + query=query, + documents=all_documents, + score_threshold=score_threshold, + top_n=top_k + ) + return all_documents + + @classmethod + def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str, + top_k: int, all_documents: list): + with flask_app.app_context(): + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() + + keyword = Keyword( + dataset=dataset + ) + + documents = keyword.search( + query, + k=top_k + ) + all_documents.extend(documents) + + @classmethod + def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str, + top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], + all_documents: list, retrival_method: str): + with flask_app.app_context(): + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() + + vector = Vector( + dataset=dataset + ) + + documents = vector.search_by_vector( + query, + search_type='similarity_score_threshold', + k=top_k, + score_threshold=score_threshold, + filter={ + 'group_id': [dataset.id] + } + ) + + if documents: + if reranking_model and retrival_method == 'semantic_search': + data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) + all_documents.extend(data_post_processor.invoke( + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=len(documents) + )) + else: + all_documents.extend(documents) + + @classmethod + def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, + top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], + all_documents: list, retrival_method: str): + with flask_app.app_context(): + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() + + vector_processor = Vector( + dataset=dataset, + ) + + documents = vector_processor.search_by_full_text( + query, + top_k=top_k + ) + if documents: + if reranking_model and retrival_method == 'full_text_search': + data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) + all_documents.extend(data_post_processor.invoke( + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=len(documents) + )) + else: + all_documents.extend(documents) 
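For reviewers, a minimal usage sketch of the new RetrievalService added above (not part of the patch). It assumes a Flask application context with the database initialised and an existing dataset; the dataset id, query, and reranking model names are placeholders, and 'hybrid_search' would additionally rely on the logged-in user's tenant for reranking.

from core.rag.datasource.retrieval_service import RetrievalService

# Semantic (vector) retrieval with optional reranking; the keyword and full-text
# threads are only started for 'keyword_search' / 'full_text_search' / 'hybrid_search'.
documents = RetrievalService.retrieve(
    retrival_method='semantic_search',
    dataset_id='00000000-0000-0000-0000-000000000000',  # placeholder dataset id
    query='What is the refund policy?',
    top_k=2,
    score_threshold=0.5,
    reranking_model={
        'reranking_provider_name': 'cohere',            # placeholder provider
        'reranking_model_name': 'rerank-english-v2.0',  # placeholder model
    },
)
for doc in documents:
    print(doc.metadata.get('score'), doc.page_content[:80])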
diff --git a/api/core/rag/datasource/vdb/__init__.py b/api/core/rag/datasource/vdb/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/rag/datasource/vdb/field.py b/api/core/rag/datasource/vdb/field.py new file mode 100644 index 000000000..6a594a83c --- /dev/null +++ b/api/core/rag/datasource/vdb/field.py @@ -0,0 +1,10 @@ +from enum import Enum + + +class Field(Enum): + CONTENT_KEY = "page_content" + METADATA_KEY = "metadata" + GROUP_KEY = "group_id" + VECTOR = "vector" + TEXT_KEY = "text" + PRIMARY_KEY = " id" diff --git a/api/core/rag/datasource/vdb/milvus/__init__.py b/api/core/rag/datasource/vdb/milvus/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py new file mode 100644 index 000000000..9a251ede9 --- /dev/null +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -0,0 +1,214 @@ +import logging +from typing import Any, Optional +from uuid import uuid4 + +from pydantic import BaseModel, root_validator +from pymilvus import MilvusClient, MilvusException, connections + +from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.models.document import Document + +logger = logging.getLogger(__name__) + + +class MilvusConfig(BaseModel): + host: str + port: int + user: str + password: str + secure: bool = False + batch_size: int = 100 + + @root_validator() + def validate_config(cls, values: dict) -> dict: + if not values['host']: + raise ValueError("config MILVUS_HOST is required") + if not values['port']: + raise ValueError("config MILVUS_PORT is required") + if not values['user']: + raise ValueError("config MILVUS_USER is required") + if not values['password']: + raise ValueError("config MILVUS_PASSWORD is required") + return values + + def to_milvus_params(self): + return { + 'host': self.host, + 'port': self.port, + 'user': self.user, + 'password': self.password, + 'secure': self.secure + } + + +class MilvusVector(BaseVector): + + def __init__(self, collection_name: str, config: MilvusConfig): + super().__init__(collection_name) + self._client_config = config + self._client = self._init_client(config) + self._consistency_level = 'Session' + self._fields = [] + + def get_type(self) -> str: + return 'milvus' + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + index_params = { + 'metric_type': 'IP', + 'index_type': "HNSW", + 'params': {"M": 8, "efConstruction": 64} + } + metadatas = [d.metadata for d in texts] + + # Grab the existing collection if it exists + from pymilvus import utility + alias = uuid4().hex + if self._client_config.secure: + uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) + else: + uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) + connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password) + if not utility.has_collection(self._collection_name, using=alias): + self.create_collection(embeddings, metadatas, index_params) + self.add_texts(texts, embeddings) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + insert_dict_list = [] + for i in range(len(documents)): + insert_dict = { + Field.CONTENT_KEY.value: documents[i].page_content, + Field.VECTOR.value: embeddings[i], + Field.METADATA_KEY.value: documents[i].metadata + } + 
insert_dict_list.append(insert_dict) + # Total insert count + total_count = len(insert_dict_list) + + pks: list[str] = [] + + for i in range(0, total_count, 1000): + batch_insert_list = insert_dict_list[i:i + 1000] + # Insert into the collection. + try: + ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list) + pks.extend(ids) + except MilvusException as e: + logger.error( + "Failed to insert batch starting at entity: %s/%s", i, total_count + ) + raise e + return pks + + def delete_by_document_id(self, document_id: str): + + ids = self.get_ids_by_metadata_field('document_id', document_id) + if ids: + self._client.delete(collection_name=self._collection_name, pks=ids) + + def get_ids_by_metadata_field(self, key: str, value: str): + result = self._client.query(collection_name=self._collection_name, + filter=f'metadata["{key}"] == "{value}"', + output_fields=["id"]) + if result: + return [item["id"] for item in result] + else: + return None + + def delete_by_metadata_field(self, key: str, value: str): + + ids = self.get_ids_by_metadata_field(key, value) + if ids: + self._client.delete(collection_name=self._collection_name, pks=ids) + + def delete_by_ids(self, doc_ids: list[str]) -> None: + + self._client.delete(collection_name=self._collection_name, pks=doc_ids) + + def delete(self) -> None: + + from pymilvus import utility + utility.drop_collection(self._collection_name, None) + + def text_exists(self, id: str) -> bool: + + result = self._client.query(collection_name=self._collection_name, + filter=f'metadata["doc_id"] == "{id}"', + output_fields=["id"]) + + return len(result) > 0 + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + + # Set search parameters. + results = self._client.search(collection_name=self._collection_name, + data=[query_vector], + limit=kwargs.get('top_k', 4), + output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], + ) + # Organize results. 
+ docs = [] + for result in results[0]: + metadata = result['entity'].get(Field.METADATA_KEY.value) + metadata['score'] = result['distance'] + score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 + if result['distance'] > score_threshold: + doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value), + metadata=metadata) + docs.append(doc) + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + # milvus/zilliz doesn't support bm25 search + return [] + + def create_collection( + self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + ) -> str: + from pymilvus import CollectionSchema, DataType, FieldSchema + from pymilvus.orm.types import infer_dtype_bydata + + # Determine embedding dim + dim = len(embeddings[0]) + fields = [] + if metadatas: + fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) + + # Create the text field + fields.append( + FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535) + ) + # Create the primary key field + fields.append( + FieldSchema( + Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True + ) + ) + # Create the vector field, supports binary or float vectors + fields.append( + FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim) + ) + + # Create the schema for the collection + schema = CollectionSchema(fields) + + for x in schema.fields: + self._fields.append(x.name) + # Since primary field is auto-id, no need to track it + self._fields.remove(Field.PRIMARY_KEY.value) + + # Create the collection + collection_name = self._collection_name + self._client.create_collection_with_schema(collection_name=collection_name, + schema=schema, index_param=index_params, + consistency_level=self._consistency_level) + return collection_name + + def _init_client(self, config) -> MilvusClient: + if config.secure: + uri = "https://" + str(config.host) + ":" + str(config.port) + else: + uri = "http://" + str(config.host) + ":" + str(config.port) + client = MilvusClient(uri=uri, user=config.user, password=config.password) + return client diff --git a/api/core/rag/datasource/vdb/qdrant/__init__.py b/api/core/rag/datasource/vdb/qdrant/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py new file mode 100644 index 000000000..243293122 --- /dev/null +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -0,0 +1,360 @@ +import os +import uuid +from collections.abc import Generator, Iterable, Sequence +from itertools import islice +from typing import TYPE_CHECKING, Any, Optional, Union, cast + +import qdrant_client +from pydantic import BaseModel +from qdrant_client.http import models as rest +from qdrant_client.http.models import ( + FilterSelector, + HnswConfigDiff, + PayloadSchemaType, + TextIndexParams, + TextIndexType, + TokenizerType, +) +from qdrant_client.local.qdrant_local import QdrantLocal + +from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.models.document import Document + +if TYPE_CHECKING: + from qdrant_client import grpc # noqa + from qdrant_client.conversions import common_types + from qdrant_client.http import models as rest + + DictFilter = dict[str, Union[str, int, bool, dict, list]] + MetadataFilter = Union[DictFilter, common_types.Filter] + 
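For orientation only (not part of the patch): a hedged sketch of constructing the QdrantVector defined below on its own. In practice the Vector factory in vector_factory.py derives the collection name and group id from the Dataset record, so the endpoint and names here are placeholder assumptions, and the sketch presumes a reachable Qdrant instance.

from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector

config = QdrantConfig(
    endpoint='http://localhost:6333',  # placeholder; an endpoint starting with 'path:' selects local on-disk mode
    api_key=None,
    root_path='/tmp',                  # only used to resolve relative 'path:' endpoints
)
vector = QdrantVector(
    collection_name='Vector_index_example_Node',  # placeholder collection name
    group_id='example-dataset-id',                # placeholder dataset/group id
    config=config,
)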
+ +class QdrantConfig(BaseModel): + endpoint: str + api_key: Optional[str] + timeout: float = 20 + root_path: Optional[str] + + def to_qdrant_params(self): + if self.endpoint and self.endpoint.startswith('path:'): + path = self.endpoint.replace('path:', '') + if not os.path.isabs(path): + path = os.path.join(self.root_path, path) + + return { + 'path': path + } + else: + return { + 'url': self.endpoint, + 'api_key': self.api_key, + 'timeout': self.timeout + } + + +class QdrantVector(BaseVector): + + def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = 'Cosine'): + super().__init__(collection_name) + self._client_config = config + self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params()) + self._distance_func = distance_func.upper() + self._group_id = group_id + + def get_type(self) -> str: + return 'qdrant' + + def to_index_struct(self) -> dict: + return { + "type": self.get_type(), + "vector_store": {"class_prefix": self._collection_name} + } + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + if texts: + # get embedding vector size + vector_size = len(embeddings[0]) + # get collection name + collection_name = self._collection_name + collection_name = collection_name or uuid.uuid4().hex + all_collection_name = [] + collections_response = self._client.get_collections() + collection_list = collections_response.collections + for collection in collection_list: + all_collection_name.append(collection.name) + if collection_name not in all_collection_name: + # create collection + self.create_collection(collection_name, vector_size) + + self.add_texts(texts, embeddings, **kwargs) + + def create_collection(self, collection_name: str, vector_size: int): + from qdrant_client.http import models as rest + vectors_config = rest.VectorParams( + size=vector_size, + distance=rest.Distance[self._distance_func], + ) + hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, + max_indexing_threads=0, on_disk=False) + self._client.recreate_collection( + collection_name=collection_name, + vectors_config=vectors_config, + hnsw_config=hnsw_config, + timeout=int(self._client_config.timeout), + ) + + # create payload index + self._client.create_payload_index(collection_name, Field.GROUP_KEY.value, + field_schema=PayloadSchemaType.KEYWORD, + field_type=PayloadSchemaType.KEYWORD) + # creat full text index + text_index_params = TextIndexParams( + type=TextIndexType.TEXT, + tokenizer=TokenizerType.MULTILINGUAL, + min_token_len=2, + max_token_len=20, + lowercase=True + ) + self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value, + field_schema=text_index_params) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + uuids = self._get_uuids(documents) + texts = [d.page_content for d in documents] + metadatas = [d.metadata for d in documents] + + added_ids = [] + for batch_ids, points in self._generate_rest_batches( + texts, embeddings, metadatas, uuids, 64, self._group_id + ): + self._client.upsert( + collection_name=self._collection_name, points=points + ) + added_ids.extend(batch_ids) + + return added_ids + + def _generate_rest_batches( + self, + texts: Iterable[str], + embeddings: list[list[float]], + metadatas: Optional[list[dict]] = None, + ids: Optional[Sequence[str]] = None, + batch_size: int = 64, + group_id: Optional[str] = None, + ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]: + from 
qdrant_client.http import models as rest + texts_iterator = iter(texts) + embeddings_iterator = iter(embeddings) + metadatas_iterator = iter(metadatas or []) + ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)]) + while batch_texts := list(islice(texts_iterator, batch_size)): + # Take the corresponding metadata and id for each text in a batch + batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None + batch_ids = list(islice(ids_iterator, batch_size)) + + # Generate the embeddings for all the texts in a batch + batch_embeddings = list(islice(embeddings_iterator, batch_size)) + + points = [ + rest.PointStruct( + id=point_id, + vector=vector, + payload=payload, + ) + for point_id, vector, payload in zip( + batch_ids, + batch_embeddings, + self._build_payloads( + batch_texts, + batch_metadatas, + Field.CONTENT_KEY.value, + Field.METADATA_KEY.value, + group_id, + Field.GROUP_KEY.value, + ), + ) + ] + + yield batch_ids, points + + @classmethod + def _build_payloads( + cls, + texts: Iterable[str], + metadatas: Optional[list[dict]], + content_payload_key: str, + metadata_payload_key: str, + group_id: str, + group_payload_key: str + ) -> list[dict]: + payloads = [] + for i, text in enumerate(texts): + if text is None: + raise ValueError( + "At least one of the texts is None. Please remove it before " + "calling .from_texts or .add_texts on Qdrant instance." + ) + metadata = metadatas[i] if metadatas is not None else None + payloads.append( + { + content_payload_key: text, + metadata_payload_key: metadata, + group_payload_key: group_id + } + ) + + return payloads + + def delete_by_metadata_field(self, key: str, value: str): + + from qdrant_client.http import models + + filter = models.Filter( + must=[ + models.FieldCondition( + key=f"metadata.{key}", + match=models.MatchValue(value=value), + ), + ], + ) + + self._reload_if_needed() + + self._client.delete( + collection_name=self._collection_name, + points_selector=FilterSelector( + filter=filter + ), + ) + + def delete(self): + from qdrant_client.http import models + filter = models.Filter( + must=[ + models.FieldCondition( + key="group_id", + match=models.MatchValue(value=self._group_id), + ), + ], + ) + self._client.delete( + collection_name=self._collection_name, + points_selector=FilterSelector( + filter=filter + ), + ) + + def delete_by_ids(self, ids: list[str]) -> None: + + from qdrant_client.http import models + for node_id in ids: + filter = models.Filter( + must=[ + models.FieldCondition( + key="metadata.doc_id", + match=models.MatchValue(value=node_id), + ), + ], + ) + self._client.delete( + collection_name=self._collection_name, + points_selector=FilterSelector( + filter=filter + ), + ) + + def text_exists(self, id: str) -> bool: + response = self._client.retrieve( + collection_name=self._collection_name, + ids=[id] + ) + + return len(response) > 0 + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + from qdrant_client.http import models + filter = models.Filter( + must=[ + models.FieldCondition( + key="group_id", + match=models.MatchValue(value=self._group_id), + ), + ], + ) + results = self._client.search( + collection_name=self._collection_name, + query_vector=query_vector, + query_filter=filter, + limit=kwargs.get("top_k", 4), + with_payload=True, + with_vectors=True, + score_threshold=kwargs.get("score_threshold", .0) + ) + docs = [] + for result in results: + metadata = result.payload.get(Field.METADATA_KEY.value) or {} + # duplicate check score threshold + 
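            # Note: the search call above already passed score_threshold to Qdrant, so this is a
            # defensive re-check on the client side; when the caller supplies no (or a zero)
            # threshold it falls back to 0.0 and only strictly higher-scoring points are kept.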
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + if result.score > score_threshold: + metadata['score'] = result.score + doc = Document( + page_content=result.payload.get(Field.CONTENT_KEY.value), + metadata=metadata, + ) + docs.append(doc) + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + """Return docs most similar by bm25. + Returns: + List of documents most similar to the query text and distance for each. + """ + from qdrant_client.http import models + scroll_filter = models.Filter( + must=[ + models.FieldCondition( + key="group_id", + match=models.MatchValue(value=self._group_id), + ), + models.FieldCondition( + key="page_content", + match=models.MatchText(text=query), + ) + ] + ) + response = self._client.scroll( + collection_name=self._collection_name, + scroll_filter=scroll_filter, + limit=kwargs.get('top_k', 2), + with_payload=True, + with_vectors=True + + ) + results = response[0] + documents = [] + for result in results: + if result: + documents.append(self._document_from_scored_point( + result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value + )) + + return documents + + def _reload_if_needed(self): + if isinstance(self._client, QdrantLocal): + self._client = cast(QdrantLocal, self._client) + self._client._load() + + @classmethod + def _document_from_scored_point( + cls, + scored_point: Any, + content_payload_key: str, + metadata_payload_key: str, + ) -> Document: + return Document( + page_content=scored_point.payload.get(content_payload_key), + metadata=scored_point.payload.get(metadata_payload_key) or {}, + ) diff --git a/api/core/rag/datasource/vdb/vector_base.py b/api/core/rag/datasource/vdb/vector_base.py new file mode 100644 index 000000000..69ed4ed51 --- /dev/null +++ b/api/core/rag/datasource/vdb/vector_base.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + +from core.rag.models.document import Document + + +class BaseVector(ABC): + + def __init__(self, collection_name: str): + self._collection_name = collection_name + + @abstractmethod + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + raise NotImplementedError + + @abstractmethod + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + raise NotImplementedError + + @abstractmethod + def text_exists(self, id: str) -> bool: + raise NotImplementedError + + @abstractmethod + def delete_by_ids(self, ids: list[str]) -> None: + raise NotImplementedError + + @abstractmethod + def delete_by_metadata_field(self, key: str, value: str) -> None: + raise NotImplementedError + + @abstractmethod + def search_by_vector( + self, + query_vector: list[float], + **kwargs: Any + ) -> list[Document]: + raise NotImplementedError + + @abstractmethod + def search_by_full_text( + self, query: str, + **kwargs: Any + ) -> list[Document]: + raise NotImplementedError + + def delete(self) -> None: + raise NotImplementedError + + def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: + for text in texts: + doc_id = text.metadata['doc_id'] + exists_duplicate_node = self.text_exists(doc_id) + if exists_duplicate_node: + texts.remove(text) + + return texts + + def _get_uuids(self, texts: list[Document]) -> list[str]: + return [text.metadata['doc_id'] for text in texts] diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py new file mode 100644 
index 000000000..dd8fc9304 --- /dev/null +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -0,0 +1,171 @@ +from typing import Any, cast + +from flask import current_app + +from core.embedding.cached_embedding import CacheEmbedding +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.rag.datasource.entity.embedding import Embeddings +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.models.document import Document +from extensions.ext_database import db +from models.dataset import Dataset, DatasetCollectionBinding + + +class Vector: + def __init__(self, dataset: Dataset, attributes: list = None): + if attributes is None: + attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] + self._dataset = dataset + self._embeddings = self._get_embeddings() + self._attributes = attributes + self._vector_processor = self._init_vector() + + def _init_vector(self) -> BaseVector: + config = cast(dict, current_app.config) + vector_type = config.get('VECTOR_STORE') + + if self._dataset.index_struct_dict: + vector_type = self._dataset.index_struct_dict['type'] + + if not vector_type: + raise ValueError("Vector store must be specified.") + + if vector_type == "weaviate": + from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector + if self._dataset.index_struct_dict: + class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] + collection_name = class_prefix + else: + dataset_id = self._dataset.id + collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' + return WeaviateVector( + collection_name=collection_name, + config=WeaviateConfig( + endpoint=config.get('WEAVIATE_ENDPOINT'), + api_key=config.get('WEAVIATE_API_KEY'), + batch_size=int(config.get('WEAVIATE_BATCH_SIZE')) + ), + attributes=self._attributes + ) + elif vector_type == "qdrant": + from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector + if self._dataset.collection_binding_id: + dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ + filter(DatasetCollectionBinding.id == self._dataset.collection_binding_id). 
\ + one_or_none() + if dataset_collection_binding: + collection_name = dataset_collection_binding.collection_name + else: + raise ValueError('Dataset Collection Bindings is not exist!') + else: + if self._dataset.index_struct_dict: + class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] + collection_name = class_prefix + else: + dataset_id = self._dataset.id + collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' + + return QdrantVector( + collection_name=collection_name, + group_id=self._dataset.id, + config=QdrantConfig( + endpoint=config.get('QDRANT_URL'), + api_key=config.get('QDRANT_API_KEY'), + root_path=current_app.root_path, + timeout=config.get('QDRANT_CLIENT_TIMEOUT') + ) + ) + elif vector_type == "milvus": + from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector + if self._dataset.index_struct_dict: + class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] + collection_name = class_prefix + else: + dataset_id = self._dataset.id + collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' + return MilvusVector( + collection_name=collection_name, + config=MilvusConfig( + host=config.get('MILVUS_HOST'), + port=config.get('MILVUS_PORT'), + user=config.get('MILVUS_USER'), + password=config.get('MILVUS_PASSWORD'), + secure=config.get('MILVUS_SECURE'), + ) + ) + else: + raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") + + def create(self, texts: list = None, **kwargs): + if texts: + embeddings = self._embeddings.embed_documents([document.page_content for document in texts]) + self._vector_processor.create( + texts=texts, + embeddings=embeddings, + **kwargs + ) + + def add_texts(self, documents: list[Document], **kwargs): + if kwargs.get('duplicate_check', False): + documents = self._filter_duplicate_texts(documents) + embeddings = self._embeddings.embed_documents([document.page_content for document in documents]) + self._vector_processor.add_texts( + documents=documents, + embeddings=embeddings, + **kwargs + ) + + def text_exists(self, id: str) -> bool: + return self._vector_processor.text_exists(id) + + def delete_by_ids(self, ids: list[str]) -> None: + self._vector_processor.delete_by_ids(ids) + + def delete_by_metadata_field(self, key: str, value: str) -> None: + self._vector_processor.delete_by_metadata_field(key, value) + + def search_by_vector( + self, query: str, + **kwargs: Any + ) -> list[Document]: + query_vector = self._embeddings.embed_query(query) + return self._vector_processor.search_by_vector(query_vector, **kwargs) + + def search_by_full_text( + self, query: str, + **kwargs: Any + ) -> list[Document]: + return self._vector_processor.search_by_full_text(query, **kwargs) + + def delete(self) -> None: + self._vector_processor.delete() + + def _get_embeddings(self) -> Embeddings: + model_manager = ModelManager() + + embedding_model = model_manager.get_model_instance( + tenant_id=self._dataset.tenant_id, + provider=self._dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=self._dataset.embedding_model + + ) + return CacheEmbedding(embedding_model) + + def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: + for text in texts: + doc_id = text.metadata['doc_id'] + exists_duplicate_node = self.text_exists(doc_id) + if exists_duplicate_node: + texts.remove(text) + + return texts + + def __getattr__(self, name): + if self._vector_processor is not None: + method = 
getattr(self._vector_processor, name) + if callable(method): + return method + + raise AttributeError(f"'vector_processor' object has no attribute '{name}'") diff --git a/api/core/rag/datasource/vdb/weaviate/__init__.py b/api/core/rag/datasource/vdb/weaviate/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py new file mode 100644 index 000000000..5c3a810fb --- /dev/null +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -0,0 +1,235 @@ +import datetime +from typing import Any, Optional + +import requests +import weaviate +from pydantic import BaseModel, root_validator + +from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.models.document import Document +from models.dataset import Dataset + + +class WeaviateConfig(BaseModel): + endpoint: str + api_key: Optional[str] + batch_size: int = 100 + + @root_validator() + def validate_config(cls, values: dict) -> dict: + if not values['endpoint']: + raise ValueError("config WEAVIATE_ENDPOINT is required") + return values + + +class WeaviateVector(BaseVector): + + def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list): + super().__init__(collection_name) + self._client = self._init_client(config) + self._attributes = attributes + + def _init_client(self, config: WeaviateConfig) -> weaviate.Client: + auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key) + + weaviate.connect.connection.has_grpc = False + + try: + client = weaviate.Client( + url=config.endpoint, + auth_client_secret=auth_config, + timeout_config=(5, 60), + startup_period=None + ) + except requests.exceptions.ConnectionError: + raise ConnectionError("Vector database connection error") + + client.batch.configure( + # `batch_size` takes an `int` value to enable auto-batching + # (`None` is used for manual batching) + batch_size=config.batch_size, + # dynamically update the `batch_size` based on import speed + dynamic=True, + # `timeout_retries` takes an `int` value to retry on time outs + timeout_retries=3, + ) + + return client + + def get_type(self) -> str: + return 'weaviate' + + def get_collection_name(self, dataset: Dataset) -> str: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + if not class_prefix.endswith('_Node'): + # original class_prefix + class_prefix += '_Node' + + return class_prefix + + dataset_id = dataset.id + return "Vector_index_" + dataset_id.replace("-", "_") + '_Node' + + def to_index_struct(self) -> dict: + return { + "type": self.get_type(), + "vector_store": {"class_prefix": self._collection_name} + } + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + + schema = self._default_schema(self._collection_name) + + # check whether the index already exists + if not self._client.schema.contains(schema): + # create collection + self._client.schema.create_class(schema) + # create vector + self.add_texts(texts, embeddings) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + uuids = self._get_uuids(documents) + texts = [d.page_content for d in documents] + metadatas = [d.metadata for d in documents] + + ids = [] + + with self._client.batch as batch: + for i, text in enumerate(texts): + data_properties = {Field.TEXT_KEY.value: text} + if metadatas is not None: + for key, val in 
metadatas[i].items(): + data_properties[key] = self._json_serializable(val) + + batch.add_data_object( + data_object=data_properties, + class_name=self._collection_name, + uuid=uuids[i], + vector=embeddings[i] if embeddings else None, + ) + ids.append(uuids[i]) + return ids + + def delete_by_metadata_field(self, key: str, value: str): + + where_filter = { + "operator": "Equal", + "path": [key], + "valueText": value + } + + self._client.batch.delete_objects( + class_name=self._collection_name, + where=where_filter, + output='minimal' + ) + + def delete(self): + self._client.schema.delete_class(self._collection_name) + + def text_exists(self, id: str) -> bool: + collection_name = self._collection_name + result = self._client.query.get(collection_name).with_additional(["id"]).with_where({ + "path": ["doc_id"], + "operator": "Equal", + "valueText": id, + }).with_limit(1).do() + + if "errors" in result: + raise ValueError(f"Error during query: {result['errors']}") + + entries = result["data"]["Get"][collection_name] + if len(entries) == 0: + return False + + return True + + def delete_by_ids(self, ids: list[str]) -> None: + self._client.data_object.delete( + ids, + class_name=self._collection_name + ) + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + """Look up similar documents by embedding vector in Weaviate.""" + collection_name = self._collection_name + properties = self._attributes + properties.append(Field.TEXT_KEY.value) + query_obj = self._client.query.get(collection_name, properties) + + vector = {"vector": query_vector} + if kwargs.get("where_filter"): + query_obj = query_obj.with_where(kwargs.get("where_filter")) + result = ( + query_obj.with_near_vector(vector) + .with_limit(kwargs.get("top_k", 4)) + .with_additional(["vector", "distance"]) + .do() + ) + if "errors" in result: + raise ValueError(f"Error during query: {result['errors']}") + + docs_and_scores = [] + for res in result["data"]["Get"][collection_name]: + text = res.pop(Field.TEXT_KEY.value) + score = 1 - res["_additional"]["distance"] + docs_and_scores.append((Document(page_content=text, metadata=res), score)) + + docs = [] + for doc, score in docs_and_scores: + score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + # check score threshold + if score > score_threshold: + doc.metadata['score'] = score + docs.append(doc) + + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + """Return docs using BM25F. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + + Returns: + List of Documents most similar to the query. 
+ """ + collection_name = self._collection_name + content: dict[str, Any] = {"concepts": [query]} + properties = self._attributes + properties.append(Field.TEXT_KEY.value) + if kwargs.get("search_distance"): + content["certainty"] = kwargs.get("search_distance") + query_obj = self._client.query.get(collection_name, properties) + if kwargs.get("where_filter"): + query_obj = query_obj.with_where(kwargs.get("where_filter")) + if kwargs.get("additional"): + query_obj = query_obj.with_additional(kwargs.get("additional")) + properties = ['text'] + result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do() + if "errors" in result: + raise ValueError(f"Error during query: {result['errors']}") + docs = [] + for res in result["data"]["Get"][collection_name]: + text = res.pop(Field.TEXT_KEY.value) + docs.append(Document(page_content=text, metadata=res)) + return docs + + def _default_schema(self, index_name: str) -> dict: + return { + "class": index_name, + "properties": [ + { + "name": "text", + "dataType": ["text"], + } + ], + } + + def _json_serializable(self, value: Any) -> Any: + if isinstance(value, datetime.datetime): + return value.isoformat() + return value diff --git a/api/core/rag/extractor/blod/blod.py b/api/core/rag/extractor/blod/blod.py new file mode 100644 index 000000000..368946b5e --- /dev/null +++ b/api/core/rag/extractor/blod/blod.py @@ -0,0 +1,166 @@ +"""Schema for Blobs and Blob Loaders. + +The goal is to facilitate decoupling of content loading from content parsing code. + +In addition, content loading code should provide a lazy loading interface by default. +""" +from __future__ import annotations + +import contextlib +import mimetypes +from abc import ABC, abstractmethod +from collections.abc import Generator, Iterable, Mapping +from io import BufferedReader, BytesIO +from pathlib import PurePath +from typing import Any, Optional, Union + +from pydantic import BaseModel, root_validator + +PathLike = Union[str, PurePath] + + +class Blob(BaseModel): + """A blob is used to represent raw data by either reference or value. + + Provides an interface to materialize the blob in different representations, and + help to decouple the development of data loaders from the downstream parsing of + the raw data. + + Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob + """ + + data: Union[bytes, str, None] # Raw data + mimetype: Optional[str] = None # Not to be confused with a file extension + encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string + # Location where the original content was found + # Represent location on the local file system + # Useful for situations where downstream code assumes it must work with file paths + # rather than in-memory content. 
+ path: Optional[PathLike] = None + + class Config: + arbitrary_types_allowed = True + frozen = True + + @property + def source(self) -> Optional[str]: + """The source location of the blob as string if known otherwise none.""" + return str(self.path) if self.path else None + + @root_validator(pre=True) + def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]: + """Verify that either data or path is provided.""" + if "data" not in values and "path" not in values: + raise ValueError("Either data or path must be provided") + return values + + def as_string(self) -> str: + """Read data as a string.""" + if self.data is None and self.path: + with open(str(self.path), encoding=self.encoding) as f: + return f.read() + elif isinstance(self.data, bytes): + return self.data.decode(self.encoding) + elif isinstance(self.data, str): + return self.data + else: + raise ValueError(f"Unable to get string for blob {self}") + + def as_bytes(self) -> bytes: + """Read data as bytes.""" + if isinstance(self.data, bytes): + return self.data + elif isinstance(self.data, str): + return self.data.encode(self.encoding) + elif self.data is None and self.path: + with open(str(self.path), "rb") as f: + return f.read() + else: + raise ValueError(f"Unable to get bytes for blob {self}") + + @contextlib.contextmanager + def as_bytes_io(self) -> Generator[Union[BytesIO, BufferedReader], None, None]: + """Read data as a byte stream.""" + if isinstance(self.data, bytes): + yield BytesIO(self.data) + elif self.data is None and self.path: + with open(str(self.path), "rb") as f: + yield f + else: + raise NotImplementedError(f"Unable to convert blob {self}") + + @classmethod + def from_path( + cls, + path: PathLike, + *, + encoding: str = "utf-8", + mime_type: Optional[str] = None, + guess_type: bool = True, + ) -> Blob: + """Load the blob from a path like object. + + Args: + path: path like object to file to be read + encoding: Encoding to use if decoding the bytes into a string + mime_type: if provided, will be set as the mime-type of the data + guess_type: If True, the mimetype will be guessed from the file extension, + if a mime-type was not provided + + Returns: + Blob instance + """ + if mime_type is None and guess_type: + _mimetype = mimetypes.guess_type(path)[0] if guess_type else None + else: + _mimetype = mime_type + # We do not load the data immediately, instead we treat the blob as a + # reference to the underlying data. + return cls(data=None, mimetype=_mimetype, encoding=encoding, path=path) + + @classmethod + def from_data( + cls, + data: Union[str, bytes], + *, + encoding: str = "utf-8", + mime_type: Optional[str] = None, + path: Optional[str] = None, + ) -> Blob: + """Initialize the blob from in-memory data. + + Args: + data: the in-memory data associated with the blob + encoding: Encoding to use if decoding the bytes into a string + mime_type: if provided, will be set as the mime-type of the data + path: if provided, will be set as the source from which the data came + + Returns: + Blob instance + """ + return cls(data=data, mimetype=mime_type, encoding=encoding, path=path) + + def __repr__(self) -> str: + """Define the blob representation.""" + str_repr = f"Blob {id(self)}" + if self.source: + str_repr += f" {self.source}" + return str_repr + + +class BlobLoader(ABC): + """Abstract interface for blob loaders implementation. + + Implementer should be able to load raw content from a datasource system according + to some criteria and return the raw content lazily as a stream of blobs. 
+ """ + + @abstractmethod + def yield_blobs( + self, + ) -> Iterable[Blob]: + """A lazy loader for raw data represented by LangChain's Blob object. + + Returns: + A generator over blobs + """ diff --git a/api/core/rag/extractor/csv_extractor.py b/api/core/rag/extractor/csv_extractor.py new file mode 100644 index 000000000..c391d7ae6 --- /dev/null +++ b/api/core/rag/extractor/csv_extractor.py @@ -0,0 +1,71 @@ +"""Abstract interface for document loader implementations.""" +import csv +from typing import Optional + +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document + + +class CSVExtractor(BaseExtractor): + """Load CSV files. + + + Args: + file_path: Path to the file to load. + """ + + def __init__( + self, + file_path: str, + encoding: Optional[str] = None, + autodetect_encoding: bool = False, + source_column: Optional[str] = None, + csv_args: Optional[dict] = None, + ): + """Initialize with file path.""" + self._file_path = file_path + self._encoding = encoding + self._autodetect_encoding = autodetect_encoding + self.source_column = source_column + self.csv_args = csv_args or {} + + def extract(self) -> list[Document]: + """Load data into document objects.""" + try: + with open(self._file_path, newline="", encoding=self._encoding) as csvfile: + docs = self._read_from_file(csvfile) + except UnicodeDecodeError as e: + if self._autodetect_encoding: + detected_encodings = detect_filze_encodings(self._file_path) + for encoding in detected_encodings: + try: + with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile: + docs = self._read_from_file(csvfile) + break + except UnicodeDecodeError: + continue + else: + raise RuntimeError(f"Error loading {self._file_path}") from e + + return docs + + def _read_from_file(self, csvfile) -> list[Document]: + docs = [] + csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore + for i, row in enumerate(csv_reader): + content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items()) + try: + source = ( + row[self.source_column] + if self.source_column is not None + else '' + ) + except KeyError: + raise ValueError( + f"Source column '{self.source_column}' not found in CSV file." + ) + metadata = {"source": source, "row": i} + doc = Document(page_content=content, metadata=metadata) + docs.append(doc) + + return docs diff --git a/api/core/rag/extractor/entity/datasource_type.py b/api/core/rag/extractor/entity/datasource_type.py new file mode 100644 index 000000000..2c79e7b97 --- /dev/null +++ b/api/core/rag/extractor/entity/datasource_type.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class DatasourceType(Enum): + FILE = "upload_file" + NOTION = "notion_import" diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py new file mode 100644 index 000000000..bc5310f7b --- /dev/null +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -0,0 +1,36 @@ +from pydantic import BaseModel + +from models.dataset import Document +from models.model import UploadFile + + +class NotionInfo(BaseModel): + """ + Notion import info. + """ + notion_workspace_id: str + notion_obj_id: str + notion_page_type: str + document: Document = None + + class Config: + arbitrary_types_allowed = True + + def __init__(self, **data) -> None: + super().__init__(**data) + + +class ExtractSetting(BaseModel): + """ + Model class for provider response. 
+ """ + datasource_type: str + upload_file: UploadFile = None + notion_info: NotionInfo = None + document_model: str = None + + class Config: + arbitrary_types_allowed = True + + def __init__(self, **data) -> None: + super().__init__(**data) diff --git a/api/core/data_loader/loader/excel.py b/api/core/rag/extractor/excel_extractor.py similarity index 66% rename from api/core/data_loader/loader/excel.py rename to api/core/rag/extractor/excel_extractor.py index cddb29854..532391048 100644 --- a/api/core/data_loader/loader/excel.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -1,14 +1,14 @@ -import logging +"""Abstract interface for document loader implementations.""" +from typing import Optional -from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document from openpyxl.reader.excel import load_workbook -logger = logging.getLogger(__name__) +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document -class ExcelLoader(BaseLoader): - """Load xlxs files. +class ExcelExtractor(BaseExtractor): + """Load Excel files. Args: @@ -16,13 +16,18 @@ class ExcelLoader(BaseLoader): """ def __init__( - self, - file_path: str + self, + file_path: str, + encoding: Optional[str] = None, + autodetect_encoding: bool = False ): """Initialize with file path.""" self._file_path = file_path + self._encoding = encoding + self._autodetect_encoding = autodetect_encoding - def load(self) -> list[Document]: + def extract(self) -> list[Document]: + """Load from file path.""" data = [] keys = [] wb = load_workbook(filename=self._file_path, read_only=True) diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py new file mode 100644 index 000000000..7c7dc5bda --- /dev/null +++ b/api/core/rag/extractor/extract_processor.py @@ -0,0 +1,139 @@ +import tempfile +from pathlib import Path +from typing import Union + +import requests +from flask import current_app + +from core.rag.extractor.csv_extractor import CSVExtractor +from core.rag.extractor.entity.datasource_type import DatasourceType +from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.extractor.excel_extractor import ExcelExtractor +from core.rag.extractor.html_extractor import HtmlExtractor +from core.rag.extractor.markdown_extractor import MarkdownExtractor +from core.rag.extractor.notion_extractor import NotionExtractor +from core.rag.extractor.pdf_extractor import PdfExtractor +from core.rag.extractor.text_extractor import TextExtractor +from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor +from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor +from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor +from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor +from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor +from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor +from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor +from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor +from core.rag.extractor.word_extractor import WordExtractor +from core.rag.models.document import Document +from extensions.ext_storage import storage +from models.model import UploadFile + 
+SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain'] +USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" + + +class ExtractProcessor: + @classmethod + def load_from_upload_file(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) \ + -> Union[list[Document], str]: + extract_setting = ExtractSetting( + datasource_type="upload_file", + upload_file=upload_file, + document_model='text_model' + ) + if return_text: + delimiter = '\n' + return delimiter.join([document.page_content for document in cls.extract(extract_setting, is_automatic)]) + else: + return cls.extract(extract_setting, is_automatic) + + @classmethod + def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]: + response = requests.get(url, headers={ + "User-Agent": USER_AGENT + }) + + with tempfile.TemporaryDirectory() as temp_dir: + suffix = Path(url).suffix + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" + with open(file_path, 'wb') as file: + file.write(response.content) + extract_setting = ExtractSetting( + datasource_type="upload_file", + document_model='text_model' + ) + if return_text: + delimiter = '\n' + return delimiter.join([document.page_content for document in cls.extract( + extract_setting=extract_setting, file_path=file_path)]) + else: + return cls.extract(extract_setting=extract_setting, file_path=file_path) + + @classmethod + def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False, + file_path: str = None) -> list[Document]: + if extract_setting.datasource_type == DatasourceType.FILE.value: + with tempfile.TemporaryDirectory() as temp_dir: + if not file_path: + upload_file: UploadFile = extract_setting.upload_file + suffix = Path(upload_file.key).suffix + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" + storage.download(upload_file.key, file_path) + input_file = Path(file_path) + file_extension = input_file.suffix.lower() + etl_type = current_app.config['ETL_TYPE'] + unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL'] + if etl_type == 'Unstructured': + if file_extension == '.xlsx': + extractor = ExcelExtractor(file_path) + elif file_extension == '.pdf': + extractor = PdfExtractor(file_path) + elif file_extension in ['.md', '.markdown']: + extractor = UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic \ + else MarkdownExtractor(file_path, autodetect_encoding=True) + elif file_extension in ['.htm', '.html']: + extractor = HtmlExtractor(file_path) + elif file_extension in ['.docx']: + extractor = UnstructuredWordExtractor(file_path, unstructured_api_url) + elif file_extension == '.csv': + extractor = CSVExtractor(file_path, autodetect_encoding=True) + elif file_extension == '.msg': + extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url) + elif file_extension == '.eml': + extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url) + elif file_extension == '.ppt': + extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url) + elif file_extension == '.pptx': + extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url) + elif file_extension == '.xml': + extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url) + else: + # txt + extractor = UnstructuredTextExtractor(file_path, unstructured_api_url) if is_automatic \ + else TextExtractor(file_path, autodetect_encoding=True) + else: + if 
file_extension == '.xlsx': + extractor = ExcelExtractor(file_path) + elif file_extension == '.pdf': + extractor = PdfExtractor(file_path) + elif file_extension in ['.md', '.markdown']: + extractor = MarkdownExtractor(file_path, autodetect_encoding=True) + elif file_extension in ['.htm', '.html']: + extractor = HtmlExtractor(file_path) + elif file_extension in ['.docx']: + extractor = WordExtractor(file_path) + elif file_extension == '.csv': + extractor = CSVExtractor(file_path, autodetect_encoding=True) + else: + # txt + extractor = TextExtractor(file_path, autodetect_encoding=True) + return extractor.extract() + elif extract_setting.datasource_type == DatasourceType.NOTION.value: + extractor = NotionExtractor( + notion_workspace_id=extract_setting.notion_info.notion_workspace_id, + notion_obj_id=extract_setting.notion_info.notion_obj_id, + notion_page_type=extract_setting.notion_info.notion_page_type, + document_model=extract_setting.notion_info.document + ) + return extractor.extract() + else: + raise ValueError(f"Unsupported datasource type: {extract_setting.datasource_type}") diff --git a/api/core/rag/extractor/extractor_base.py b/api/core/rag/extractor/extractor_base.py new file mode 100644 index 000000000..c490e5933 --- /dev/null +++ b/api/core/rag/extractor/extractor_base.py @@ -0,0 +1,12 @@ +"""Abstract interface for document loader implementations.""" +from abc import ABC, abstractmethod + + +class BaseExtractor(ABC): + """Interface for extract files. + """ + + @abstractmethod + def extract(self): + raise NotImplementedError + diff --git a/api/core/rag/extractor/helpers.py b/api/core/rag/extractor/helpers.py new file mode 100644 index 000000000..0c17a47b3 --- /dev/null +++ b/api/core/rag/extractor/helpers.py @@ -0,0 +1,46 @@ +"""Document loader helpers.""" + +import concurrent.futures +from typing import NamedTuple, Optional, cast + + +class FileEncoding(NamedTuple): + """A file encoding as the NamedTuple.""" + + encoding: Optional[str] + """The encoding of the file.""" + confidence: float + """The confidence of the encoding.""" + language: Optional[str] + """The language of the file.""" + + +def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding]: + """Try to detect the file encoding. + + Returns a list of `FileEncoding` tuples with the detected encodings ordered + by confidence. + + Args: + file_path: The path to the file to detect the encoding for. + timeout: The timeout in seconds for the encoding detection. 
+ """ + import chardet + + def read_and_detect(file_path: str) -> list[dict]: + with open(file_path, "rb") as f: + rawdata = f.read() + return cast(list[dict], chardet.detect_all(rawdata)) + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(read_and_detect, file_path) + try: + encodings = future.result(timeout=timeout) + except concurrent.futures.TimeoutError: + raise TimeoutError( + f"Timeout reached while detecting encoding for {file_path}" + ) + + if all(encoding["encoding"] is None for encoding in encodings): + raise RuntimeError(f"Could not detect encoding for {file_path}") + return [FileEncoding(**enc) for enc in encodings if enc["encoding"] is not None] diff --git a/api/core/data_loader/loader/csv_loader.py b/api/core/rag/extractor/html_extractor.py similarity index 61% rename from api/core/data_loader/loader/csv_loader.py rename to api/core/rag/extractor/html_extractor.py index ce252c157..557ea42b1 100644 --- a/api/core/data_loader/loader/csv_loader.py +++ b/api/core/rag/extractor/html_extractor.py @@ -1,51 +1,55 @@ -import csv -import logging +"""Abstract interface for document loader implementations.""" from typing import Optional -from langchain.document_loaders import CSVLoader as LCCSVLoader -from langchain.document_loaders.helpers import detect_file_encodings -from langchain.schema import Document - -logger = logging.getLogger(__name__) +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.extractor.helpers import detect_file_encodings +from core.rag.models.document import Document -class CSVLoader(LCCSVLoader): +class HtmlExtractor(BaseExtractor): + """Load html files. + + + Args: + file_path: Path to the file to load. + """ + def __init__( self, file_path: str, + encoding: Optional[str] = None, + autodetect_encoding: bool = False, source_column: Optional[str] = None, csv_args: Optional[dict] = None, - encoding: Optional[str] = None, - autodetect_encoding: bool = True, ): - self.file_path = file_path + """Initialize with file path.""" + self._file_path = file_path + self._encoding = encoding + self._autodetect_encoding = autodetect_encoding self.source_column = source_column - self.encoding = encoding self.csv_args = csv_args or {} - self.autodetect_encoding = autodetect_encoding - def load(self) -> list[Document]: + def extract(self) -> list[Document]: """Load data into document objects.""" try: - with open(self.file_path, newline="", encoding=self.encoding) as csvfile: + with open(self._file_path, newline="", encoding=self._encoding) as csvfile: docs = self._read_from_file(csvfile) except UnicodeDecodeError as e: - if self.autodetect_encoding: - detected_encodings = detect_file_encodings(self.file_path) + if self._autodetect_encoding: + detected_encodings = detect_file_encodings(self._file_path) for encoding in detected_encodings: - logger.debug("Trying encoding: ", encoding.encoding) try: - with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile: + with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile: docs = self._read_from_file(csvfile) break except UnicodeDecodeError: continue else: - raise RuntimeError(f"Error loading {self.file_path}") from e + raise RuntimeError(f"Error loading {self._file_path}") from e return docs - def _read_from_file(self, csvfile): + def _read_from_file(self, csvfile) -> list[Document]: docs = [] csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore for i, row in enumerate(csv_reader): diff --git 
a/api/core/data_loader/loader/markdown.py b/api/core/rag/extractor/markdown_extractor.py similarity index 79% rename from api/core/data_loader/loader/markdown.py rename to api/core/rag/extractor/markdown_extractor.py index ecbc6d548..91c687bac 100644 --- a/api/core/data_loader/loader/markdown.py +++ b/api/core/rag/extractor/markdown_extractor.py @@ -1,39 +1,27 @@ -import logging +"""Abstract interface for document loader implementations.""" import re from typing import Optional, cast -from langchain.document_loaders.base import BaseLoader -from langchain.document_loaders.helpers import detect_file_encodings -from langchain.schema import Document - -logger = logging.getLogger(__name__) +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.extractor.helpers import detect_file_encodings +from core.rag.models.document import Document -class MarkdownLoader(BaseLoader): - """Load md files. +class MarkdownExtractor(BaseExtractor): + """Load Markdown files. Args: file_path: Path to the file to load. - - remove_hyperlinks: Whether to remove hyperlinks from the text. - - remove_images: Whether to remove images from the text. - - encoding: File encoding to use. If `None`, the file will be loaded - with the default system encoding. - - autodetect_encoding: Whether to try to autodetect the file encoding - if the specified encoding fails. """ def __init__( - self, - file_path: str, - remove_hyperlinks: bool = True, - remove_images: bool = True, - encoding: Optional[str] = None, - autodetect_encoding: bool = True, + self, + file_path: str, + remove_hyperlinks: bool = True, + remove_images: bool = True, + encoding: Optional[str] = None, + autodetect_encoding: bool = True, ): """Initialize with file path.""" self._file_path = file_path @@ -42,7 +30,8 @@ class MarkdownLoader(BaseLoader): self._encoding = encoding self._autodetect_encoding = autodetect_encoding - def load(self) -> list[Document]: + def extract(self) -> list[Document]: + """Load from file path.""" tups = self.parse_tups(self._file_path) documents = [] for header, value in tups: @@ -113,7 +102,6 @@ class MarkdownLoader(BaseLoader): if self._autodetect_encoding: detected_encodings = detect_file_encodings(filepath) for encoding in detected_encodings: - logger.debug("Trying encoding: ", encoding.encoding) try: with open(filepath, encoding=encoding.encoding) as f: content = f.read() diff --git a/api/core/data_loader/loader/notion.py b/api/core/rag/extractor/notion_extractor.py similarity index 89% rename from api/core/data_loader/loader/notion.py rename to api/core/rag/extractor/notion_extractor.py index f8d883768..5fc60ba80 100644 --- a/api/core/data_loader/loader/notion.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -4,9 +4,10 @@ from typing import Any, Optional import requests from flask import current_app -from langchain.document_loaders.base import BaseLoader +from flask_login import current_user from langchain.schema import Document +from core.rag.extractor.extractor_base import BaseExtractor from extensions.ext_database import db from models.dataset import Document as DocumentModel from models.source import DataSourceBinding @@ -22,52 +23,37 @@ RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}" HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3'] -class NotionLoader(BaseLoader): +class NotionExtractor(BaseExtractor): + def __init__( self, - notion_access_token: str, notion_workspace_id: str, notion_obj_id: str, notion_page_type: str, - document_model: Optional[DocumentModel] = 
None + document_model: Optional[DocumentModel] = None, + notion_access_token: Optional[str] = None ): + self._notion_access_token = None self._document_model = document_model self._notion_workspace_id = notion_workspace_id self._notion_obj_id = notion_obj_id self._notion_page_type = notion_page_type - self._notion_access_token = notion_access_token + if notion_access_token: + self._notion_access_token = notion_access_token + else: + self._notion_access_token = self._get_access_token(current_user.current_tenant_id, + self._notion_workspace_id) + if not self._notion_access_token: + integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN') + if integration_token is None: + raise ValueError( + "Must specify `integration_token` or set environment " + "variable `NOTION_INTEGRATION_TOKEN`." + ) - if not self._notion_access_token: - integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN') - if integration_token is None: - raise ValueError( - "Must specify `integration_token` or set environment " - "variable `NOTION_INTEGRATION_TOKEN`." - ) + self._notion_access_token = integration_token - self._notion_access_token = integration_token - - @classmethod - def from_document(cls, document_model: DocumentModel): - data_source_info = document_model.data_source_info_dict - if not data_source_info or 'notion_page_id' not in data_source_info \ - or 'notion_workspace_id' not in data_source_info: - raise ValueError("no notion page found") - - notion_workspace_id = data_source_info['notion_workspace_id'] - notion_obj_id = data_source_info['notion_page_id'] - notion_page_type = data_source_info['type'] - notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id) - - return cls( - notion_access_token=notion_access_token, - notion_workspace_id=notion_workspace_id, - notion_obj_id=notion_obj_id, - notion_page_type=notion_page_type, - document_model=document_model - ) - - def load(self) -> list[Document]: + def extract(self) -> list[Document]: self.update_last_edited_time( self._document_model ) diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py new file mode 100644 index 000000000..cbb265539 --- /dev/null +++ b/api/core/rag/extractor/pdf_extractor.py @@ -0,0 +1,72 @@ +"""Abstract interface for document loader implementations.""" +from collections.abc import Iterator +from typing import Optional + +from core.rag.extractor.blod.blod import Blob +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document +from extensions.ext_storage import storage + + +class PdfExtractor(BaseExtractor): + """Load pdf files. + + + Args: + file_path: Path to the file to load. 
+ """ + + def __init__( + self, + file_path: str, + file_cache_key: Optional[str] = None + ): + """Initialize with file path.""" + self._file_path = file_path + self._file_cache_key = file_cache_key + + def extract(self) -> list[Document]: + plaintext_file_key = '' + plaintext_file_exists = False + if self._file_cache_key: + try: + text = storage.load(self._file_cache_key).decode('utf-8') + plaintext_file_exists = True + return [Document(page_content=text)] + except FileNotFoundError: + pass + documents = list(self.load()) + text_list = [] + for document in documents: + text_list.append(document.page_content) + text = "\n\n".join(text_list) + + # save plaintext file for caching + if not plaintext_file_exists and plaintext_file_key: + storage.save(plaintext_file_key, text.encode('utf-8')) + + return documents + + def load( + self, + ) -> Iterator[Document]: + """Lazy load given path as pages.""" + blob = Blob.from_path(self._file_path) + yield from self.parse(blob) + + def parse(self, blob: Blob) -> Iterator[Document]: + """Lazily parse the blob.""" + import pypdfium2 + + with blob.as_bytes_io() as file_path: + pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True) + try: + for page_number, page in enumerate(pdf_reader): + text_page = page.get_textpage() + content = text_page.get_text_range() + text_page.close() + page.close() + metadata = {"source": blob.source, "page": page_number} + yield Document(page_content=content, metadata=metadata) + finally: + pdf_reader.close() diff --git a/api/core/rag/extractor/text_extractor.py b/api/core/rag/extractor/text_extractor.py new file mode 100644 index 000000000..ac5d0920c --- /dev/null +++ b/api/core/rag/extractor/text_extractor.py @@ -0,0 +1,50 @@ +"""Abstract interface for document loader implementations.""" +from typing import Optional + +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.extractor.helpers import detect_file_encodings +from core.rag.models.document import Document + + +class TextExtractor(BaseExtractor): + """Load text files. + + + Args: + file_path: Path to the file to load. 
+ """ + + def __init__( + self, + file_path: str, + encoding: Optional[str] = None, + autodetect_encoding: bool = False + ): + """Initialize with file path.""" + self._file_path = file_path + self._encoding = encoding + self._autodetect_encoding = autodetect_encoding + + def extract(self) -> list[Document]: + """Load from file path.""" + text = "" + try: + with open(self._file_path, encoding=self._encoding) as f: + text = f.read() + except UnicodeDecodeError as e: + if self._autodetect_encoding: + detected_encodings = detect_file_encodings(self._file_path) + for encoding in detected_encodings: + try: + with open(self._file_path, encoding=encoding.encoding) as f: + text = f.read() + break + except UnicodeDecodeError: + continue + else: + raise RuntimeError(f"Error loading {self._file_path}") from e + except Exception as e: + raise RuntimeError(f"Error loading {self._file_path}") from e + + metadata = {"source": self._file_path} + return [Document(page_content=text, metadata=metadata)] diff --git a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py new file mode 100644 index 000000000..b37981a30 --- /dev/null +++ b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py @@ -0,0 +1,61 @@ +import logging +import os + +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document + +logger = logging.getLogger(__name__) + + +class UnstructuredWordExtractor(BaseExtractor): + """Loader that uses unstructured to load word documents. + """ + + def __init__( + self, + file_path: str, + api_url: str, + ): + """Initialize with file path.""" + self._file_path = file_path + self._api_url = api_url + + def extract(self) -> list[Document]: + from unstructured.__version__ import __version__ as __unstructured_version__ + from unstructured.file_utils.filetype import FileType, detect_filetype + + unstructured_version = tuple( + [int(x) for x in __unstructured_version__.split(".")] + ) + # check the file extension + try: + import magic # noqa: F401 + + is_doc = detect_filetype(self._file_path) == FileType.DOC + except ImportError: + _, extension = os.path.splitext(str(self._file_path)) + is_doc = extension == ".doc" + + if is_doc and unstructured_version < (0, 4, 11): + raise ValueError( + f"You are on unstructured version {__unstructured_version__}. " + "Partitioning .doc files is only supported in unstructured>=0.4.11. " + "Please upgrade the unstructured package and try again." 
+ ) + + if is_doc: + from unstructured.partition.doc import partition_doc + + elements = partition_doc(filename=self._file_path) + else: + from unstructured.partition.docx import partition_docx + + elements = partition_docx(filename=self._file_path) + + from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0) + documents = [] + for chunk in chunks: + text = chunk.text.strip() + documents.append(Document(page_content=text)) + return documents diff --git a/api/core/data_loader/loader/unstructured/unstructured_eml.py b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py similarity index 87% rename from api/core/data_loader/loader/unstructured/unstructured_eml.py rename to api/core/rag/extractor/unstructured/unstructured_eml_extractor.py index 2fa3aac13..1d92bbbee 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_eml.py +++ b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py @@ -2,13 +2,14 @@ import base64 import logging from bs4 import BeautifulSoup -from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document + +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document logger = logging.getLogger(__name__) -class UnstructuredEmailLoader(BaseLoader): +class UnstructuredEmailExtractor(BaseExtractor): """Load msg files. Args: file_path: Path to the file to load. @@ -23,7 +24,7 @@ class UnstructuredEmailLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> list[Document]: + def extract(self) -> list[Document]: from unstructured.partition.email import partition_email elements = partition_email(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_markdown.py b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py similarity index 85% rename from api/core/data_loader/loader/unstructured/unstructured_markdown.py rename to api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py index 036a2afd2..3ac04ddc1 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_markdown.py +++ b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py @@ -1,12 +1,12 @@ import logging -from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document logger = logging.getLogger(__name__) -class UnstructuredMarkdownLoader(BaseLoader): +class UnstructuredMarkdownExtractor(BaseExtractor): """Load md files. 
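# The unstructured-based extractors in this package all follow the same shape: partition the
# file into elements, optionally regroup them into chunks, and wrap the text into rag Document
# objects. A condensed sketch of that flow for the .docx branch shown above, assuming the
# unstructured package is installed and the file path is a hypothetical placeholder:
#
#     from unstructured.partition.docx import partition_docx
#     from unstructured.chunking.title import chunk_by_title
#     from core.rag.models.document import Document
#
#     elements = partition_docx(filename="example.docx")
#     chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
#     documents = [Document(page_content=chunk.text.strip()) for chunk in chunks]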
@@ -33,7 +33,7 @@ class UnstructuredMarkdownLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> list[Document]: + def extract(self) -> list[Document]: from unstructured.partition.md import partition_md elements = partition_md(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_msg.py b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py similarity index 80% rename from api/core/data_loader/loader/unstructured/unstructured_msg.py rename to api/core/rag/extractor/unstructured/unstructured_msg_extractor.py index 495be328e..d4b72e37e 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_msg.py +++ b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py @@ -1,12 +1,12 @@ import logging -from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document logger = logging.getLogger(__name__) -class UnstructuredMsgLoader(BaseLoader): +class UnstructuredMsgExtractor(BaseExtractor): """Load msg files. @@ -23,7 +23,7 @@ class UnstructuredMsgLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> list[Document]: + def extract(self) -> list[Document]: from unstructured.partition.msg import partition_msg elements = partition_msg(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_ppt.py b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py similarity index 78% rename from api/core/data_loader/loader/unstructured/unstructured_ppt.py rename to api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py index cfac91cc7..cd3aba986 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_ppt.py +++ b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py @@ -1,11 +1,12 @@ import logging -from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document logger = logging.getLogger(__name__) -class UnstructuredPPTLoader(BaseLoader): + +class UnstructuredPPTExtractor(BaseExtractor): """Load msg files. 
@@ -14,15 +15,15 @@ class UnstructuredPPTLoader(BaseLoader): """ def __init__( - self, - file_path: str, - api_url: str + self, + file_path: str, + api_url: str ): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url - def load(self) -> list[Document]: + def extract(self) -> list[Document]: from unstructured.partition.ppt import partition_ppt elements = partition_ppt(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_pptx.py b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py similarity index 78% rename from api/core/data_loader/loader/unstructured/unstructured_pptx.py rename to api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py index 41e3bfcb5..f9667d252 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_pptx.py +++ b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py @@ -1,10 +1,12 @@ import logging -from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document logger = logging.getLogger(__name__) -class UnstructuredPPTXLoader(BaseLoader): + + +class UnstructuredPPTXExtractor(BaseExtractor): """Load msg files. @@ -13,15 +15,15 @@ class UnstructuredPPTXLoader(BaseLoader): """ def __init__( - self, - file_path: str, - api_url: str + self, + file_path: str, + api_url: str ): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url - def load(self) -> list[Document]: + def extract(self) -> list[Document]: from unstructured.partition.pptx import partition_pptx elements = partition_pptx(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_text.py b/api/core/rag/extractor/unstructured/unstructured_text_extractor.py similarity index 80% rename from api/core/data_loader/loader/unstructured/unstructured_text.py rename to api/core/rag/extractor/unstructured/unstructured_text_extractor.py index 09d14fdb1..5af21b2b1 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_text.py +++ b/api/core/rag/extractor/unstructured/unstructured_text_extractor.py @@ -1,12 +1,12 @@ import logging -from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document logger = logging.getLogger(__name__) -class UnstructuredTextLoader(BaseLoader): +class UnstructuredTextExtractor(BaseExtractor): """Load msg files. 
@@ -23,7 +23,7 @@ class UnstructuredTextLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> list[Document]: + def extract(self) -> list[Document]: from unstructured.partition.text import partition_text elements = partition_text(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_xml.py b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py similarity index 81% rename from api/core/data_loader/loader/unstructured/unstructured_xml.py rename to api/core/rag/extractor/unstructured/unstructured_xml_extractor.py index cca6e1b0b..b08ff63a1 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_xml.py +++ b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py @@ -1,12 +1,12 @@ import logging -from langchain.document_loaders.base import BaseLoader -from langchain.schema import Document +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document logger = logging.getLogger(__name__) -class UnstructuredXmlLoader(BaseLoader): +class UnstructuredXmlExtractor(BaseExtractor): """Load msg files. @@ -23,7 +23,7 @@ class UnstructuredXmlLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> list[Document]: + def extract(self) -> list[Document]: from unstructured.partition.xml import partition_xml elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url) diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py new file mode 100644 index 000000000..8e2cd14be --- /dev/null +++ b/api/core/rag/extractor/word_extractor.py @@ -0,0 +1,62 @@ +"""Abstract interface for document loader implementations.""" +import os +import tempfile +from urllib.parse import urlparse + +import requests + +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document + + +class WordExtractor(BaseExtractor): + """Load pdf files. + + + Args: + file_path: Path to the file to load. 
+ """ + + def __init__(self, file_path: str): + """Initialize with file path.""" + self.file_path = file_path + if "~" in self.file_path: + self.file_path = os.path.expanduser(self.file_path) + + # If the file is a web path, download it to a temporary file, and use that + if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path): + r = requests.get(self.file_path) + + if r.status_code != 200: + raise ValueError( + "Check the url of your file; returned status code %s" + % r.status_code + ) + + self.web_path = self.file_path + self.temp_file = tempfile.NamedTemporaryFile() + self.temp_file.write(r.content) + self.file_path = self.temp_file.name + elif not os.path.isfile(self.file_path): + raise ValueError("File path %s is not a valid file or url" % self.file_path) + + def __del__(self) -> None: + if hasattr(self, "temp_file"): + self.temp_file.close() + + def extract(self) -> list[Document]: + """Load given path as single page.""" + import docx2txt + + return [ + Document( + page_content=docx2txt.process(self.file_path), + metadata={"source": self.file_path}, + ) + ] + + @staticmethod + def _is_valid_url(url: str) -> bool: + """Check if the url is valid.""" + parsed = urlparse(url) + return bool(parsed.netloc) and bool(parsed.scheme) diff --git a/api/core/rag/index_processor/__init__.py b/api/core/rag/index_processor/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/rag/index_processor/constant/__init__.py b/api/core/rag/index_processor/constant/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/rag/index_processor/constant/index_type.py b/api/core/rag/index_processor/constant/index_type.py new file mode 100644 index 000000000..e42cc44c6 --- /dev/null +++ b/api/core/rag/index_processor/constant/index_type.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class IndexType(Enum): + PARAGRAPH_INDEX = "text_model" + QA_INDEX = "qa_model" + PARENT_CHILD_INDEX = "parent_child_index" + SUMMARY_INDEX = "summary_index" diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py new file mode 100644 index 000000000..2ad9718ec --- /dev/null +++ b/api/core/rag/index_processor/index_processor_base.py @@ -0,0 +1,70 @@ +"""Abstract interface for document loader implementations.""" +from abc import ABC, abstractmethod +from typing import Optional + +from langchain.text_splitter import TextSplitter + +from core.model_manager import ModelInstance +from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.models.document import Document +from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter +from models.dataset import Dataset, DatasetProcessRule + + +class BaseIndexProcessor(ABC): + """Interface for extract files. 
+ """ + + @abstractmethod + def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: + raise NotImplementedError + + @abstractmethod + def transform(self, documents: list[Document], **kwargs) -> list[Document]: + raise NotImplementedError + + @abstractmethod + def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): + raise NotImplementedError + + def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): + raise NotImplementedError + + @abstractmethod + def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int, + score_threshold: float, reranking_model: dict) -> list[Document]: + raise NotImplementedError + + def _get_splitter(self, processing_rule: dict, + embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: + """ + Get the NodeParser object according to the processing rule. + """ + if processing_rule['mode'] == "custom": + # The user-defined segmentation rule + rules = processing_rule['rules'] + segmentation = rules["segmentation"] + if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000: + raise ValueError("Custom segment length should be between 50 and 1000.") + + separator = segmentation["separator"] + if separator: + separator = separator.replace('\\n', '\n') + + character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( + chunk_size=segmentation["max_tokens"], + chunk_overlap=0, + fixed_separator=separator, + separators=["\n\n", "。", ".", " ", ""], + embedding_model_instance=embedding_model_instance + ) + else: + # Automatic segmentation + character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( + chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'], + chunk_overlap=0, + separators=["\n\n", "。", ".", " ", ""], + embedding_model_instance=embedding_model_instance + ) + + return character_splitter diff --git a/api/core/rag/index_processor/index_processor_factory.py b/api/core/rag/index_processor/index_processor_factory.py new file mode 100644 index 000000000..df43a6491 --- /dev/null +++ b/api/core/rag/index_processor/index_processor_factory.py @@ -0,0 +1,28 @@ +"""Abstract interface for document loader implementations.""" + +from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor +from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor + + +class IndexProcessorFactory: + """IndexProcessorInit. 
+ """ + + def __init__(self, index_type: str): + self._index_type = index_type + + def init_index_processor(self) -> BaseIndexProcessor: + """Init index processor.""" + + if not self._index_type: + raise ValueError("Index type must be specified.") + + if self._index_type == IndexType.PARAGRAPH_INDEX.value: + return ParagraphIndexProcessor() + elif self._index_type == IndexType.QA_INDEX.value: + + return QAIndexProcessor() + else: + raise ValueError(f"Index type {self._index_type} is not supported.") diff --git a/api/core/rag/index_processor/processor/__init__.py b/api/core/rag/index_processor/processor/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py new file mode 100644 index 000000000..3f0467ee2 --- /dev/null +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -0,0 +1,92 @@ +"""Paragraph index processor.""" +import uuid +from typing import Optional + +from core.rag.cleaner.clean_processor import CleanProcessor +from core.rag.datasource.keyword.keyword_factory import Keyword +from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.extractor.extract_processor import ExtractProcessor +from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.models.document import Document +from libs import helper +from models.dataset import Dataset + + +class ParagraphIndexProcessor(BaseIndexProcessor): + + def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: + + text_docs = ExtractProcessor.extract(extract_setting=extract_setting, + is_automatic=kwargs.get('process_rule_mode') == "automatic") + + return text_docs + + def transform(self, documents: list[Document], **kwargs) -> list[Document]: + # Split the text documents into nodes. 
+ splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'), + embedding_model_instance=kwargs.get('embedding_model_instance')) + all_documents = [] + for document in documents: + # document clean + document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule')) + document.page_content = document_text + # parse document to nodes + document_nodes = splitter.split_documents([document]) + split_documents = [] + for document_node in document_nodes: + + if document_node.page_content.strip(): + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(document_node.page_content) + document_node.metadata['doc_id'] = doc_id + document_node.metadata['doc_hash'] = hash + # delete Spliter character + page_content = document_node.page_content + if page_content.startswith(".") or page_content.startswith("。"): + page_content = page_content[1:] + else: + page_content = page_content + document_node.page_content = page_content + split_documents.append(document_node) + all_documents.extend(split_documents) + return all_documents + + def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): + if dataset.indexing_technique == 'high_quality': + vector = Vector(dataset) + vector.create(documents) + if with_keywords: + keyword = Keyword(dataset) + keyword.create(documents) + + def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): + if dataset.indexing_technique == 'high_quality': + vector = Vector(dataset) + if node_ids: + vector.delete_by_ids(node_ids) + else: + vector.delete() + if with_keywords: + keyword = Keyword(dataset) + if node_ids: + keyword.delete_by_ids(node_ids) + else: + keyword.delete() + + def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int, + score_threshold: float, reranking_model: dict) -> list[Document]: + # Set search parameters. + results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query, + top_k=top_k, score_threshold=score_threshold, + reranking_model=reranking_model) + # Organize results. 
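Taken together, a hedged sketch of the extract -> transform -> load flow for this processor (extract_setting, processing_rule and dataset are assumed to be supplied by the indexing runner):

from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor

processor = ParagraphIndexProcessor()
text_docs = processor.extract(extract_setting, process_rule_mode="automatic")
nodes = processor.transform(text_docs,
                            process_rule=processing_rule,
                            embedding_model_instance=None)
processor.load(dataset, nodes, with_keywords=True)
# vectors are written only when dataset.indexing_technique == 'high_quality';
# keywords are written whenever with_keywords is True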
+ docs = [] + for result in results: + metadata = result.metadata + metadata['score'] = result.score + if result.score > score_threshold: + doc = Document(page_content=result.page_content, metadata=metadata) + docs.append(doc) + return docs diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py new file mode 100644 index 000000000..f61c728b4 --- /dev/null +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -0,0 +1,161 @@ +"""Paragraph index processor.""" +import logging +import re +import threading +import uuid +from typing import Optional + +import pandas as pd +from flask import Flask, current_app +from flask_login import current_user +from werkzeug.datastructures import FileStorage + +from core.generator.llm_generator import LLMGenerator +from core.rag.cleaner.clean_processor import CleanProcessor +from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.extractor.extract_processor import ExtractProcessor +from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.models.document import Document +from libs import helper +from models.dataset import Dataset + + +class QAIndexProcessor(BaseIndexProcessor): + def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: + + text_docs = ExtractProcessor.extract(extract_setting=extract_setting, + is_automatic=kwargs.get('process_rule_mode') == "automatic") + return text_docs + + def transform(self, documents: list[Document], **kwargs) -> list[Document]: + splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'), + embedding_model_instance=None) + + # Split the text documents into nodes. 
+ all_documents = [] + all_qa_documents = [] + for document in documents: + # document clean + document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule')) + document.page_content = document_text + + # parse document to nodes + document_nodes = splitter.split_documents([document]) + split_documents = [] + for document_node in document_nodes: + + if document_node.page_content.strip(): + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(document_node.page_content) + document_node.metadata['doc_id'] = doc_id + document_node.metadata['doc_hash'] = hash + # delete Spliter character + page_content = document_node.page_content + if page_content.startswith(".") or page_content.startswith("。"): + page_content = page_content[1:] + else: + page_content = page_content + document_node.page_content = page_content + split_documents.append(document_node) + all_documents.extend(split_documents) + for i in range(0, len(all_documents), 10): + threads = [] + sub_documents = all_documents[i:i + 10] + for doc in sub_documents: + document_format_thread = threading.Thread(target=self._format_qa_document, kwargs={ + 'flask_app': current_app._get_current_object(), + 'tenant_id': current_user.current_tenant.id, + 'document_node': doc, + 'all_qa_documents': all_qa_documents, + 'document_language': kwargs.get('document_language', 'English')}) + threads.append(document_format_thread) + document_format_thread.start() + for thread in threads: + thread.join() + return all_qa_documents + + def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: + + # check file type + if not file.filename.endswith('.csv'): + raise ValueError("Invalid file type. Only CSV files are allowed") + + try: + # Skip the first row + df = pd.read_csv(file) + text_docs = [] + for index, row in df.iterrows(): + data = Document(page_content=row[0], metadata={'answer': row[1]}) + text_docs.append(data) + if len(text_docs) == 0: + raise ValueError("The CSV file is empty.") + + except Exception as e: + raise ValueError(str(e)) + return text_docs + + def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): + if dataset.indexing_technique == 'high_quality': + vector = Vector(dataset) + vector.create(documents) + + def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): + vector = Vector(dataset) + if node_ids: + vector.delete_by_ids(node_ids) + else: + vector.delete() + + def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int, + score_threshold: float, reranking_model: dict): + # Set search parameters. + results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query, + top_k=top_k, score_threshold=score_threshold, + reranking_model=reranking_model) + # Organize results. 
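A sketch of the CSV input format_by_template accepts (header names and rows are invented): pandas treats the first row as the header, and every remaining row contributes column 0 as the question and column 1 as the answer.

import io

from werkzeug.datastructures import FileStorage

from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor

csv_bytes = (b"question,answer\n"
             b"What is Dify?,An LLM application development platform.\n"
             b"Which vector stores are supported?,\"Qdrant, Weaviate and Milvus.\"\n")
file = FileStorage(stream=io.BytesIO(csv_bytes), filename="qa_pairs.csv")
docs = QAIndexProcessor().format_by_template(file)
# docs[0].page_content == "What is Dify?"
# docs[0].metadata["answer"] == "An LLM application development platform."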
+ docs = [] + for result in results: + metadata = result.metadata + metadata['score'] = result.score + if result.score > score_threshold: + doc = Document(page_content=result.page_content, metadata=metadata) + docs.append(doc) + return docs + + def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language): + format_documents = [] + if document_node.page_content is None or not document_node.page_content.strip(): + return + with flask_app.app_context(): + try: + # qa model document + response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language) + document_qa_list = self._format_split_text(response) + qa_documents = [] + for result in document_qa_list: + qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy()) + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(result['question']) + qa_document.metadata['answer'] = result['answer'] + qa_document.metadata['doc_id'] = doc_id + qa_document.metadata['doc_hash'] = hash + qa_documents.append(qa_document) + format_documents.extend(qa_documents) + except Exception as e: + logging.exception(e) + + all_qa_documents.extend(format_documents) + + def _format_split_text(self, text): + regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" + matches = re.findall(regex, text, re.UNICODE) + + return [ + { + "question": q, + "answer": re.sub(r"\n\s*", "\n", a.strip()) + } + for q, a in matches if q and a + ] diff --git a/api/core/rag/models/__init__.py b/api/core/rag/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py new file mode 100644 index 000000000..024a18bad --- /dev/null +++ b/api/core/rag/models/document.py @@ -0,0 +1,16 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class Document(BaseModel): + """Class for storing a piece of text and associated metadata.""" + + page_content: str + + """Arbitrary metadata about the page content (e.g., source, relationships to other + documents, etc.). 
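For clarity, a small worked example of what _format_split_text recovers from a generated QA block (the sample text is invented):

from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor

sample = ("Q1: What is RAG?\n"
          "A1: Retrieval-augmented generation.\n"
          "Q2: Is hybrid search supported?\n"
          "A2: Yes.")
pairs = QAIndexProcessor()._format_split_text(sample)
# pairs == [
#     {"question": "What is RAG?", "answer": "Retrieval-augmented generation."},
#     {"question": "Is hybrid search supported?", "answer": "Yes."},
# ]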
+ """ + metadata: Optional[dict] = Field(default_factory=dict) + + diff --git a/api/core/tool/web_reader_tool.py b/api/core/tool/web_reader_tool.py index 1fac273d4..c96fac7c2 100644 --- a/api/core/tool/web_reader_tool.py +++ b/api/core/tool/web_reader_tool.py @@ -21,9 +21,9 @@ from pydantic import BaseModel, Field from regex import regex from core.chain.llm_chain import LLMChain -from core.data_loader import file_extractor -from core.data_loader.file_extractor import FileExtractor from core.entities.application_entities import ModelConfigEntity +from core.rag.extractor import extract_processor +from core.rag.extractor.extract_processor import ExtractProcessor FULL_TEMPLATE = """ TITLE: {title} @@ -146,7 +146,7 @@ def get_url(url: str) -> str: headers = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" } - supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] + supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10)) @@ -158,8 +158,8 @@ def get_url(url: str) -> str: if main_content_type not in supported_content_types: return "Unsupported content-type [{}] of URL.".format(main_content_type) - if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES: - return FileExtractor.load_from_url(url, return_text=True) + if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: + return ExtractProcessor.load_from_url(url, return_text=True) response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30)) a = extract_using_readabilipy(response.text) diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index 490f8514d..57b6e090c 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -6,15 +6,12 @@ from langchain.tools import BaseTool from pydantic import BaseModel, Field from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.embedding.cached_embedding import CacheEmbedding -from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError -from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType +from core.rag.datasource.retrieval_service import RetrievalService from core.rerank.rerank import RerankRunner from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment -from services.retrieval_service import RetrievalService default_retrieval_model = { 'search_method': 'semantic_search', @@ -174,76 +171,24 @@ class DatasetMultiRetrieverTool(BaseTool): if dataset.indexing_technique == "economy": # use keyword table query - kw_table_index = KeywordTableIndex( - dataset=dataset, - config=KeywordTableConfig( - max_keywords_per_chunk=5 - ) - ) - - documents = kw_table_index.search(query, search_kwargs={'k': self.top_k}) + documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + dataset_id=dataset.id, + query=query, + top_k=self.top_k + ) if documents: all_documents.extend(documents) else: - - try: - model_manager = ModelManager() - embedding_model = 
model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model - ) - except LLMBadRequestError: - return [] - except ProviderTokenNotInitError: - return [] - - embeddings = CacheEmbedding(embedding_model) - - documents = [] - threads = [] if self.top_k > 0: - # retrieval_model source with semantic - if retrieval_model['search_method'] == 'semantic_search' or retrieval_model[ - 'search_method'] == 'hybrid_search': - embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': str(dataset.id), - 'query': query, - 'top_k': self.top_k, - 'score_threshold': self.score_threshold, - 'reranking_model': None, - 'all_documents': documents, - 'search_method': 'hybrid_search', - 'embeddings': embeddings - }) - threads.append(embedding_thread) - embedding_thread.start() - - # retrieval_model source with full text - if retrieval_model['search_method'] == 'full_text_search' or retrieval_model[ - 'search_method'] == 'hybrid_search': - full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, - kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': str(dataset.id), - 'query': query, - 'search_method': 'hybrid_search', - 'embeddings': embeddings, - 'score_threshold': retrieval_model[ - 'score_threshold'] if retrieval_model[ - 'score_threshold_enabled'] else None, - 'top_k': self.top_k, - 'reranking_model': retrieval_model[ - 'reranking_model'] if retrieval_model[ - 'reranking_enable'] else None, - 'all_documents': documents - }) - threads.append(full_text_index_thread) - full_text_index_thread.start() - - for thread in threads: - thread.join() + # retrieval source + documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model['score_threshold'] + if retrieval_model['score_threshold_enabled'] else None, + reranking_model=retrieval_model['reranking_model'] + if retrieval_model['reranking_enable'] else None + ) all_documents.extend(documents) \ No newline at end of file diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index c07cd3a5f..d3ec0fba6 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -1,20 +1,12 @@ -import threading from typing import Optional -from flask import current_app from langchain.tools import BaseTool from pydantic import BaseModel, Field from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.embedding.cached_embedding import CacheEmbedding -from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex -from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.invoke import InvokeAuthorizationError -from core.rerank.rerank import RerankRunner +from core.rag.datasource.retrieval_service import RetrievalService from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment -from services.retrieval_service import RetrievalService default_retrieval_model = { 'search_method': 'semantic_search', @@ -77,94 +69,24 @@ class 
DatasetRetrieverTool(BaseTool): retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model if dataset.indexing_technique == "economy": # use keyword table query - kw_table_index = KeywordTableIndex( - dataset=dataset, - config=KeywordTableConfig( - max_keywords_per_chunk=5 - ) - ) - - documents = kw_table_index.search(query, search_kwargs={'k': self.top_k}) + documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + dataset_id=dataset.id, + query=query, + top_k=self.top_k + ) return str("\n".join([document.page_content for document in documents])) else: - # get embedding model instance - try: - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model - ) - except InvokeAuthorizationError: - return '' - - embeddings = CacheEmbedding(embedding_model) - - documents = [] - threads = [] if self.top_k > 0: - # retrieval source with semantic - if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': - embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': str(dataset.id), - 'query': query, - 'top_k': self.top_k, - 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ - 'score_threshold_enabled'] else None, - 'reranking_model': retrieval_model['reranking_model'] if retrieval_model[ - 'reranking_enable'] else None, - 'all_documents': documents, - 'search_method': retrieval_model['search_method'], - 'embeddings': embeddings - }) - threads.append(embedding_thread) - embedding_thread.start() - - # retrieval_model source with full text - if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': - full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': str(dataset.id), - 'query': query, - 'search_method': retrieval_model['search_method'], - 'embeddings': embeddings, - 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ - 'score_threshold_enabled'] else None, - 'top_k': self.top_k, - 'reranking_model': retrieval_model['reranking_model'] if retrieval_model[ - 'reranking_enable'] else None, - 'all_documents': documents - }) - threads.append(full_text_index_thread) - full_text_index_thread.start() - - for thread in threads: - thread.join() - - # hybrid search: rerank after all documents have been searched - if retrieval_model['search_method'] == 'hybrid_search': - # get rerank model instance - try: - model_manager = ModelManager() - rerank_model_instance = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=retrieval_model['reranking_model']['reranking_provider_name'], - model_type=ModelType.RERANK, - model=retrieval_model['reranking_model']['reranking_model_name'] - ) - except InvokeAuthorizationError: - return '' - - rerank_runner = RerankRunner(rerank_model_instance) - documents = rerank_runner.run( - query=query, - documents=documents, - score_threshold=retrieval_model['score_threshold'] if retrieval_model[ - 'score_threshold_enabled'] else None, - top_n=self.top_k - ) + # retrieval source + documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + 
dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model['score_threshold'] + if retrieval_model['score_threshold_enabled'] else None, + reranking_model=retrieval_model['reranking_model'] + if retrieval_model['reranking_enable'] else None + ) else: documents = [] @@ -234,4 +156,4 @@ class DatasetRetrieverTool(BaseTool): return str("\n".join(document_context_list)) async def _arun(self, tool_input: str) -> str: - raise NotImplementedError() \ No newline at end of file + raise NotImplementedError() diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index 72ff3cdbd..625cbffd2 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -21,9 +21,9 @@ from pydantic import BaseModel, Field from regex import regex from core.chain.llm_chain import LLMChain -from core.data_loader import file_extractor -from core.data_loader.file_extractor import FileExtractor from core.entities.application_entities import ModelConfigEntity +from core.rag.extractor import extract_processor +from core.rag.extractor.extract_processor import ExtractProcessor FULL_TEMPLATE = """ TITLE: {title} @@ -149,7 +149,7 @@ def get_url(url: str, user_agent: str = None) -> str: if user_agent: headers["User-Agent"] = user_agent - supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] + supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10)) @@ -161,8 +161,8 @@ def get_url(url: str, user_agent: str = None) -> str: if main_content_type not in supported_content_types: return "Unsupported content-type [{}] of URL.".format(main_content_type) - if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES: - return FileExtractor.load_from_url(url, return_text=True) + if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: + return ExtractProcessor.load_from_url(url, return_text=True) response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30)) a = extract_using_readabilipy(response.text) diff --git a/api/core/vector_store/milvus_vector_store.py b/api/core/vector_store/milvus_vector_store.py deleted file mode 100644 index 67b958ded..000000000 --- a/api/core/vector_store/milvus_vector_store.py +++ /dev/null @@ -1,56 +0,0 @@ -from core.vector_store.vector.milvus import Milvus - - -class MilvusVectorStore(Milvus): - def del_texts(self, where_filter: dict): - if not where_filter: - raise ValueError('where_filter must not be empty') - - self.col.delete(where_filter.get('filter')) - - def del_text(self, uuid: str) -> None: - expr = f"id == {uuid}" - self.col.delete(expr) - - def text_exists(self, uuid: str) -> bool: - result = self.col.query( - expr=f'metadata["doc_id"] == "{uuid}"', - output_fields=["id"] - ) - - return len(result) > 0 - - def get_ids_by_document_id(self, document_id: str): - result = self.col.query( - expr=f'metadata["document_id"] == "{document_id}"', - output_fields=["id"] - ) - if result: - return [item["id"] for item in result] - else: - return None - - def get_ids_by_metadata_field(self, key: str, value: str): - result = self.col.query( - expr=f'metadata["{key}"] == "{value}"', - output_fields=["id"] - ) - if result: - return [item["id"] for item in result] - else: - return None - - def get_ids_by_doc_ids(self, doc_ids: list): - result = self.col.query( - expr=f'metadata["doc_id"] in {doc_ids}', - output_fields=["id"] 
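Both dataset retriever tools now funnel into the same service; a hedged sketch of the consolidated call (the dataset id and query are placeholders, and the economy keyword-table path simply omits the threshold and reranking arguments):

from core.rag.datasource.retrieval_service import RetrievalService

documents = RetrievalService.retrieve(
    retrival_method="semantic_search",  # taken from retrieval_model['search_method']
    dataset_id="dataset-uuid",          # placeholder
    query="How do I rotate an API key?",
    top_k=4,
    score_threshold=0.5,
    reranking_model=None,
)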
- ) - if result: - return [item["id"] for item in result] - else: - return None - - def delete(self): - from pymilvus import utility - utility.drop_collection(self.collection_name, None, self.alias) - diff --git a/api/core/vector_store/qdrant_vector_store.py b/api/core/vector_store/qdrant_vector_store.py deleted file mode 100644 index 53ad7b2aa..000000000 --- a/api/core/vector_store/qdrant_vector_store.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import Any, cast - -from langchain.schema import Document -from qdrant_client.http.models import Filter, FilterSelector, PointIdsList -from qdrant_client.local.qdrant_local import QdrantLocal - -from core.vector_store.vector.qdrant import Qdrant - - -class QdrantVectorStore(Qdrant): - def del_texts(self, filter: Filter): - if not filter: - raise ValueError('filter must not be empty') - - self._reload_if_needed() - - self.client.delete( - collection_name=self.collection_name, - points_selector=FilterSelector( - filter=filter - ), - ) - - def del_text(self, uuid: str) -> None: - self._reload_if_needed() - - self.client.delete( - collection_name=self.collection_name, - points_selector=PointIdsList( - points=[uuid], - ), - ) - - def text_exists(self, uuid: str) -> bool: - self._reload_if_needed() - - response = self.client.retrieve( - collection_name=self.collection_name, - ids=[uuid] - ) - - return len(response) > 0 - - def delete(self): - self._reload_if_needed() - - self.client.delete_collection(collection_name=self.collection_name) - - def delete_group(self): - self._reload_if_needed() - - self.client.delete_collection(collection_name=self.collection_name) - - @classmethod - def _document_from_scored_point( - cls, - scored_point: Any, - content_payload_key: str, - metadata_payload_key: str, - ) -> Document: - if scored_point.payload.get('doc_id'): - return Document( - page_content=scored_point.payload.get(content_payload_key), - metadata={'doc_id': scored_point.id} - ) - - return Document( - page_content=scored_point.payload.get(content_payload_key), - metadata=scored_point.payload.get(metadata_payload_key) or {}, - ) - - def _reload_if_needed(self): - if isinstance(self.client, QdrantLocal): - self.client = cast(QdrantLocal, self.client) - self.client._load() - diff --git a/api/core/vector_store/vector/milvus.py b/api/core/vector_store/vector/milvus.py deleted file mode 100644 index 9d8695dc5..000000000 --- a/api/core/vector_store/vector/milvus.py +++ /dev/null @@ -1,852 +0,0 @@ -"""Wrapper around the Milvus vector database.""" -from __future__ import annotations - -import logging -from collections.abc import Iterable, Sequence -from typing import Any, Optional, Union -from uuid import uuid4 - -import numpy as np -from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings -from langchain.vectorstores.base import VectorStore -from langchain.vectorstores.utils import maximal_marginal_relevance - -logger = logging.getLogger(__name__) - -DEFAULT_MILVUS_CONNECTION = { - "host": "localhost", - "port": "19530", - "user": "", - "password": "", - "secure": False, -} - - -class Milvus(VectorStore): - """Initialize wrapper around the milvus vector database. 
- - In order to use this you need to have `pymilvus` installed and a - running Milvus - - See the following documentation for how to run a Milvus instance: - https://milvus.io/docs/install_standalone-docker.md - - If looking for a hosted Milvus, take a look at this documentation: - https://zilliz.com/cloud and make use of the Zilliz vectorstore found in - this project, - - IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA. - - Args: - embedding_function (Embeddings): Function used to embed the text. - collection_name (str): Which Milvus collection to use. Defaults to - "LangChainCollection". - connection_args (Optional[dict[str, any]]): The connection args used for - this class comes in the form of a dict. - consistency_level (str): The consistency level to use for a collection. - Defaults to "Session". - index_params (Optional[dict]): Which index params to use. Defaults to - HNSW/AUTOINDEX depending on service. - search_params (Optional[dict]): Which search params to use. Defaults to - default of index. - drop_old (Optional[bool]): Whether to drop the current collection. Defaults - to False. - - The connection args used for this class comes in the form of a dict, - here are a few of the options: - address (str): The actual address of Milvus - instance. Example address: "localhost:19530" - uri (str): The uri of Milvus instance. Example uri: - "http://randomwebsite:19530", - "tcp:foobarsite:19530", - "https://ok.s3.south.com:19530". - host (str): The host of Milvus instance. Default at "localhost", - PyMilvus will fill in the default host if only port is provided. - port (str/int): The port of Milvus instance. Default at 19530, PyMilvus - will fill in the default port if only host is provided. - user (str): Use which user to connect to Milvus instance. If user and - password are provided, we will add related header in every RPC call. - password (str): Required when user is provided. The password - corresponding to the user. - secure (bool): Default is false. If set to true, tls will be enabled. - client_key_path (str): If use tls two-way authentication, need to - write the client.key path. - client_pem_path (str): If use tls two-way authentication, need to - write the client.pem path. - ca_pem_path (str): If use tls two-way authentication, need to write - the ca.pem path. - server_pem_path (str): If use tls one-way authentication, need to - write the server.pem path. - server_name (str): If use tls, need to write the common name. - - Example: - .. code-block:: python - - from langchain import Milvus - from langchain.embeddings import OpenAIEmbeddings - - embedding = OpenAIEmbeddings() - # Connect to a milvus instance on localhost - milvus_store = Milvus( - embedding_function = Embeddings, - collection_name = "LangChainCollection", - drop_old = True, - ) - - Raises: - ValueError: If the pymilvus python package is not installed. - """ - - def __init__( - self, - embedding_function: Embeddings, - collection_name: str = "LangChainCollection", - connection_args: Optional[dict[str, Any]] = None, - consistency_level: str = "Session", - index_params: Optional[dict] = None, - search_params: Optional[dict] = None, - drop_old: Optional[bool] = False, - ): - """Initialize the Milvus vector store.""" - try: - from pymilvus import Collection, utility - except ImportError: - raise ValueError( - "Could not import pymilvus python package. " - "Please install it with `pip install pymilvus`." - ) - - # Default search params when one is not provided. 
- self.default_search_params = { - "IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}}, - "IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}}, - "IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}}, - "HNSW": {"metric_type": "L2", "params": {"ef": 10}}, - "RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}}, - "RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}}, - "RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}}, - "IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}}, - "ANNOY": {"metric_type": "L2", "params": {"search_k": 10}}, - "AUTOINDEX": {"metric_type": "L2", "params": {}}, - } - - self.embedding_func = embedding_function - self.collection_name = collection_name - self.index_params = index_params - self.search_params = search_params - self.consistency_level = consistency_level - - # In order for a collection to be compatible, pk needs to be auto'id and int - self._primary_field = "id" - # In order for compatibility, the text field will need to be called "text" - self._text_field = "page_content" - # In order for compatibility, the vector field needs to be called "vector" - self._vector_field = "vectors" - # In order for compatibility, the metadata field will need to be called "metadata" - self._metadata_field = "metadata" - self.fields: list[str] = [] - # Create the connection to the server - if connection_args is None: - connection_args = DEFAULT_MILVUS_CONNECTION - self.alias = self._create_connection_alias(connection_args) - self.col: Optional[Collection] = None - - # Grab the existing collection if it exists - if utility.has_collection(self.collection_name, using=self.alias): - self.col = Collection( - self.collection_name, - using=self.alias, - ) - # If need to drop old, drop it - if drop_old and isinstance(self.col, Collection): - self.col.drop() - self.col = None - - # Initialize the vector store - self._init() - - @property - def embeddings(self) -> Embeddings: - return self.embedding_func - - def _create_connection_alias(self, connection_args: dict) -> str: - """Create the connection to the Milvus server.""" - from pymilvus import MilvusException, connections - - # Grab the connection arguments that are used for checking existing connection - host: str = connection_args.get("host", None) - port: Union[str, int] = connection_args.get("port", None) - address: str = connection_args.get("address", None) - uri: str = connection_args.get("uri", None) - user = connection_args.get("user", None) - - # Order of use is host/port, uri, address - if host is not None and port is not None: - given_address = str(host) + ":" + str(port) - elif uri is not None: - given_address = uri.split("https://")[1] - elif address is not None: - given_address = address - else: - given_address = None - logger.debug("Missing standard address type for reuse atttempt") - - # User defaults to empty string when getting connection info - if user is not None: - tmp_user = user - else: - tmp_user = "" - - # If a valid address was given, then check if a connection exists - if given_address is not None: - for con in connections.list_connections(): - addr = connections.get_connection_addr(con[0]) - if ( - con[1] - and ("address" in addr) - and (addr["address"] == given_address) - and ("user" in addr) - and (addr["user"] == tmp_user) - ): - logger.debug("Using previous connection: %s", con[0]) - return con[0] - - # Generate a new connection if one doesn't exist - alias = uuid4().hex - try: - connections.connect(alias=alias, **connection_args) - logger.debug("Created 
new connection using: %s", alias) - return alias - except MilvusException as e: - logger.error("Failed to create new connection using: %s", alias) - raise e - - def _init( - self, embeddings: Optional[list] = None, metadatas: Optional[list[dict]] = None - ) -> None: - if embeddings is not None: - self._create_collection(embeddings, metadatas) - self._extract_fields() - self._create_index() - self._create_search_params() - self._load() - - def _create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None - ) -> None: - from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, MilvusException - from pymilvus.orm.types import infer_dtype_bydata - - # Determine embedding dim - dim = len(embeddings[0]) - fields = [] - # Determine metadata schema - # if metadatas: - # # Create FieldSchema for each entry in metadata. - # for key, value in metadatas[0].items(): - # # Infer the corresponding datatype of the metadata - # dtype = infer_dtype_bydata(value) - # # Datatype isn't compatible - # if dtype == DataType.UNKNOWN or dtype == DataType.NONE: - # logger.error( - # "Failure to create collection, unrecognized dtype for key: %s", - # key, - # ) - # raise ValueError(f"Unrecognized datatype for {key}.") - # # Dataype is a string/varchar equivalent - # elif dtype == DataType.VARCHAR: - # fields.append(FieldSchema(key, DataType.VARCHAR, max_length=65_535)) - # else: - # fields.append(FieldSchema(key, dtype)) - if metadatas: - fields.append(FieldSchema(self._metadata_field, DataType.JSON, max_length=65_535)) - - # Create the text field - fields.append( - FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535) - ) - # Create the primary key field - fields.append( - FieldSchema( - self._primary_field, DataType.INT64, is_primary=True, auto_id=True - ) - ) - # Create the vector field, supports binary or float vectors - fields.append( - FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim) - ) - - # Create the schema for the collection - schema = CollectionSchema(fields) - - # Create the collection - try: - self.col = Collection( - name=self.collection_name, - schema=schema, - consistency_level=self.consistency_level, - using=self.alias, - ) - except MilvusException as e: - logger.error( - "Failed to create collection: %s error: %s", self.collection_name, e - ) - raise e - - def _extract_fields(self) -> None: - """Grab the existing fields from the Collection""" - from pymilvus import Collection - - if isinstance(self.col, Collection): - schema = self.col.schema - for x in schema.fields: - self.fields.append(x.name) - # Since primary field is auto-id, no need to track it - self.fields.remove(self._primary_field) - - def _get_index(self) -> Optional[dict[str, Any]]: - """Return the vector index information if it exists""" - from pymilvus import Collection - - if isinstance(self.col, Collection): - for x in self.col.indexes: - if x.field_name == self._vector_field: - return x.to_dict() - return None - - def _create_index(self) -> None: - """Create a index on the collection""" - from pymilvus import Collection, MilvusException - - if isinstance(self.col, Collection) and self._get_index() is None: - try: - # If no index params, use a default HNSW based one - if self.index_params is None: - self.index_params = { - "metric_type": "IP", - "index_type": "HNSW", - "params": {"M": 8, "efConstruction": 64}, - } - - try: - self.col.create_index( - self._vector_field, - index_params=self.index_params, - using=self.alias, - ) - - # If default did not 
work, most likely on Zilliz Cloud - except MilvusException: - # Use AUTOINDEX based index - self.index_params = { - "metric_type": "L2", - "index_type": "AUTOINDEX", - "params": {}, - } - self.col.create_index( - self._vector_field, - index_params=self.index_params, - using=self.alias, - ) - logger.debug( - "Successfully created an index on collection: %s", - self.collection_name, - ) - - except MilvusException as e: - logger.error( - "Failed to create an index on collection: %s", self.collection_name - ) - raise e - - def _create_search_params(self) -> None: - """Generate search params based on the current index type""" - from pymilvus import Collection - - if isinstance(self.col, Collection) and self.search_params is None: - index = self._get_index() - if index is not None: - index_type: str = index["index_param"]["index_type"] - metric_type: str = index["index_param"]["metric_type"] - self.search_params = self.default_search_params[index_type] - self.search_params["metric_type"] = metric_type - - def _load(self) -> None: - """Load the collection if available.""" - from pymilvus import Collection - - if isinstance(self.col, Collection) and self._get_index() is not None: - self.col.load() - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[list[dict]] = None, - timeout: Optional[int] = None, - batch_size: int = 1000, - **kwargs: Any, - ) -> list[str]: - """Insert text data into Milvus. - - Inserting data when the collection has not be made yet will result - in creating a new Collection. The data of the first entity decides - the schema of the new collection, the dim is extracted from the first - embedding and the columns are decided by the first metadata dict. - Metada keys will need to be present for all inserted values. At - the moment there is no None equivalent in Milvus. - - Args: - texts (Iterable[str]): The texts to embed, it is assumed - that they all fit in memory. - metadatas (Optional[List[dict]]): Metadata dicts attached to each of - the texts. Defaults to None. - timeout (Optional[int]): Timeout for each batch insert. Defaults - to None. - batch_size (int, optional): Batch size to use for insertion. - Defaults to 1000. - - Raises: - MilvusException: Failure to add texts - - Returns: - List[str]: The resulting keys for each inserted element. - """ - from pymilvus import Collection, MilvusException - - texts = list(texts) - - try: - embeddings = self.embedding_func.embed_documents(texts) - except NotImplementedError: - embeddings = [self.embedding_func.embed_query(x) for x in texts] - - if len(embeddings) == 0: - logger.debug("Nothing to insert, skipping.") - return [] - - # If the collection hasn't been initialized yet, perform all steps to do so - if not isinstance(self.col, Collection): - self._init(embeddings, metadatas) - - # Dict to hold all insert columns - insert_dict: dict[str, list] = { - self._text_field: texts, - self._vector_field: embeddings, - } - - # Collect the metadata into the insert dict. 
- # if metadatas is not None: - # for d in metadatas: - # for key, value in d.items(): - # if key in self.fields: - # insert_dict.setdefault(key, []).append(value) - if metadatas is not None: - for d in metadatas: - insert_dict.setdefault(self._metadata_field, []).append(d) - - # Total insert count - vectors: list = insert_dict[self._vector_field] - total_count = len(vectors) - - pks: list[str] = [] - - assert isinstance(self.col, Collection) - for i in range(0, total_count, batch_size): - # Grab end index - end = min(i + batch_size, total_count) - # Convert dict to list of lists batch for insertion - insert_list = [insert_dict[x][i:end] for x in self.fields] - # Insert into the collection. - try: - res: Collection - res = self.col.insert(insert_list, timeout=timeout, **kwargs) - pks.extend(res.primary_keys) - except MilvusException as e: - logger.error( - "Failed to insert batch starting at entity: %s/%s", i, total_count - ) - raise e - return pks - - def similarity_search( - self, - query: str, - k: int = 4, - param: Optional[dict] = None, - expr: Optional[str] = None, - timeout: Optional[int] = None, - **kwargs: Any, - ) -> list[Document]: - """Perform a similarity search against the query string. - - Args: - query (str): The text to search. - k (int, optional): How many results to return. Defaults to 4. - param (dict, optional): The search params for the index type. - Defaults to None. - expr (str, optional): Filtering expression. Defaults to None. - timeout (int, optional): How long to wait before timeout error. - Defaults to None. - kwargs: Collection.search() keyword arguments. - - Returns: - List[Document]: Document results for search. - """ - if self.col is None: - logger.debug("No existing collection to search.") - return [] - res = self.similarity_search_with_score( - query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs - ) - return [doc for doc, _ in res] - - def similarity_search_by_vector( - self, - embedding: list[float], - k: int = 4, - param: Optional[dict] = None, - expr: Optional[str] = None, - timeout: Optional[int] = None, - **kwargs: Any, - ) -> list[Document]: - """Perform a similarity search against the query string. - - Args: - embedding (List[float]): The embedding vector to search. - k (int, optional): How many results to return. Defaults to 4. - param (dict, optional): The search params for the index type. - Defaults to None. - expr (str, optional): Filtering expression. Defaults to None. - timeout (int, optional): How long to wait before timeout error. - Defaults to None. - kwargs: Collection.search() keyword arguments. - - Returns: - List[Document]: Document results for search. - """ - if self.col is None: - logger.debug("No existing collection to search.") - return [] - res = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs - ) - return [doc for doc, _ in res] - - def similarity_search_with_score( - self, - query: str, - k: int = 4, - param: Optional[dict] = None, - expr: Optional[str] = None, - timeout: Optional[int] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Perform a search on a query string and return results with score. - - For more information about the search parameters, take a look at the pymilvus - documentation found here: - https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md - - Args: - query (str): The text being searched. - k (int, optional): The amount of results to return. Defaults to 4. 
- param (dict): The search params for the specified index. - Defaults to None. - expr (str, optional): Filtering expression. Defaults to None. - timeout (int, optional): How long to wait before timeout error. - Defaults to None. - kwargs: Collection.search() keyword arguments. - - Returns: - List[float], List[Tuple[Document, any, any]]: - """ - if self.col is None: - logger.debug("No existing collection to search.") - return [] - - # Embed the query text. - embedding = self.embedding_func.embed_query(query) - - res = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs - ) - return res - - def _similarity_search_with_relevance_scores( - self, - query: str, - k: int = 4, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs and relevance scores in the range [0, 1]. - - 0 is dissimilar, 1 is most similar. - - Args: - query: input text - k: Number of Documents to return. Defaults to 4. - **kwargs: kwargs to be passed to similarity search. Should include: - score_threshold: Optional, a floating point value between 0 to 1 to - filter the resulting set of retrieved docs - - Returns: - List of Tuples of (doc, similarity_score) - """ - return self.similarity_search_with_score(query, k, **kwargs) - - def similarity_search_with_score_by_vector( - self, - embedding: list[float], - k: int = 4, - param: Optional[dict] = None, - expr: Optional[str] = None, - timeout: Optional[int] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Perform a search on a query string and return results with score. - - For more information about the search parameters, take a look at the pymilvus - documentation found here: - https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md - - Args: - embedding (List[float]): The embedding vector being searched. - k (int, optional): The amount of results to return. Defaults to 4. - param (dict): The search params for the specified index. - Defaults to None. - expr (str, optional): Filtering expression. Defaults to None. - timeout (int, optional): How long to wait before timeout error. - Defaults to None. - kwargs: Collection.search() keyword arguments. - - Returns: - List[Tuple[Document, float]]: Result doc and score. - """ - if self.col is None: - logger.debug("No existing collection to search.") - return [] - - if param is None: - param = self.search_params - - # Determine result metadata fields. - output_fields = self.fields[:] - output_fields.remove(self._vector_field) - - # Perform the search. - res = self.col.search( - data=[embedding], - anns_field=self._vector_field, - param=param, - limit=k, - expr=expr, - output_fields=output_fields, - timeout=timeout, - **kwargs, - ) - # Organize results. - ret = [] - for result in res[0]: - meta = {x: result.entity.get(x) for x in output_fields} - doc = Document(page_content=meta.pop(self._text_field), metadata=meta.get('metadata')) - pair = (doc, result.score) - ret.append(pair) - - return ret - - def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - param: Optional[dict] = None, - expr: Optional[str] = None, - timeout: Optional[int] = None, - **kwargs: Any, - ) -> list[Document]: - """Perform a search and return results that are reordered by MMR. - - Args: - query (str): The text being searched. - k (int, optional): How many results to give. Defaults to 4. - fetch_k (int, optional): Total results to select k from. - Defaults to 20. 
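The search above requests every stored field except the raw vector and rebuilds documents from the hit payload. The call shape with pymilvus looks roughly as follows (field names, search params and the query embedding are assumptions):

from pymilvus import Collection

col = Collection("example_collection")   # assumed existing, indexed collection
col.load()

query_embedding = [0.0] * 8                                  # assumed query vector, matching the collection dimension
search_params = {"metric_type": "L2", "params": {"ef": 64}}  # assumed to match the index type
output_fields = ["page_content", "metadata"]                 # everything except the vector field

results = col.search(
    data=[query_embedding],
    anns_field="vector",
    param=search_params,
    limit=4,
    expr=None,
    output_fields=output_fields,
)
hits = [
    ({"text": hit.entity.get("page_content"), "metadata": hit.entity.get("metadata")}, hit.distance)
    for hit in results[0]
]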
- lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5 - param (dict, optional): The search params for the specified index. - Defaults to None. - expr (str, optional): Filtering expression. Defaults to None. - timeout (int, optional): How long to wait before timeout error. - Defaults to None. - kwargs: Collection.search() keyword arguments. - - - Returns: - List[Document]: Document results for search. - """ - if self.col is None: - logger.debug("No existing collection to search.") - return [] - - embedding = self.embedding_func.embed_query(query) - - return self.max_marginal_relevance_search_by_vector( - embedding=embedding, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - param=param, - expr=expr, - timeout=timeout, - **kwargs, - ) - - def max_marginal_relevance_search_by_vector( - self, - embedding: list[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - param: Optional[dict] = None, - expr: Optional[str] = None, - timeout: Optional[int] = None, - **kwargs: Any, - ) -> list[Document]: - """Perform a search and return results that are reordered by MMR. - - Args: - embedding (str): The embedding vector being searched. - k (int, optional): How many results to give. Defaults to 4. - fetch_k (int, optional): Total results to select k from. - Defaults to 20. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5 - param (dict, optional): The search params for the specified index. - Defaults to None. - expr (str, optional): Filtering expression. Defaults to None. - timeout (int, optional): How long to wait before timeout error. - Defaults to None. - kwargs: Collection.search() keyword arguments. - - Returns: - List[Document]: Document results for search. - """ - if self.col is None: - logger.debug("No existing collection to search.") - return [] - - if param is None: - param = self.search_params - - # Determine result metadata fields. - output_fields = self.fields[:] - output_fields.remove(self._vector_field) - - # Perform the search. - res = self.col.search( - data=[embedding], - anns_field=self._vector_field, - param=param, - limit=fetch_k, - expr=expr, - output_fields=output_fields, - timeout=timeout, - **kwargs, - ) - # Organize results. - ids = [] - documents = [] - scores = [] - for result in res[0]: - meta = {x: result.entity.get(x) for x in output_fields} - doc = Document(page_content=meta.pop(self._text_field), metadata=meta) - documents.append(doc) - scores.append(result.score) - ids.append(result.id) - - vectors = self.col.query( - expr=f"{self._primary_field} in {ids}", - output_fields=[self._primary_field, self._vector_field], - timeout=timeout, - ) - # Reorganize the results from query to match search order. - vectors = {x[self._primary_field]: x[self._vector_field] for x in vectors} - - ordered_result_embeddings = [vectors[x] for x in ids] - - # Get the new order of results. - new_ordering = maximal_marginal_relevance( - np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult - ) - - # Reorder the values and return. 
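maximal_marginal_relevance above is LangChain's helper; the underlying idea is a greedy loop that keeps picking the candidate most similar to the query while penalising similarity to results already chosen. A self-contained restatement of that loop with numpy (a sketch of the idea, not the library's exact implementation):

import numpy as np

def mmr(query_vec, candidate_vecs, k=4, lambda_mult=0.5):
    """Return indices of up to k candidates chosen by maximal marginal relevance."""
    def cosine(a, b):
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-10))

    query_sims = [cosine(query_vec, c) for c in candidate_vecs]
    selected = []
    while len(selected) < min(k, len(candidate_vecs)):
        best_idx, best_score = -1, -np.inf
        for i, _ in enumerate(candidate_vecs):
            if i in selected:
                continue
            # Relevance to the query, penalised by similarity to already-selected results.
            redundancy = max((cosine(candidate_vecs[i], candidate_vecs[j]) for j in selected), default=0.0)
            score = lambda_mult * query_sims[i] - (1 - lambda_mult) * redundancy
            if score > best_score:
                best_idx, best_score = i, score
        selected.append(best_idx)
    return selected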
- ret = [] - for x in new_ordering: - # Function can return -1 index - if x == -1: - break - else: - ret.append(documents[x]) - return ret - - @classmethod - def from_texts( - cls, - texts: list[str], - embedding: Embeddings, - metadatas: Optional[list[dict]] = None, - collection_name: str = "LangChainCollection", - connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION, - consistency_level: str = "Session", - index_params: Optional[dict] = None, - search_params: Optional[dict] = None, - drop_old: bool = False, - batch_size: int = 100, - ids: Optional[Sequence[str]] = None, - **kwargs: Any, - ) -> Milvus: - """Create a Milvus collection, indexes it with HNSW, and insert data. - - Args: - texts (List[str]): Text data. - embedding (Embeddings): Embedding function. - metadatas (Optional[List[dict]]): Metadata for each text if it exists. - Defaults to None. - collection_name (str, optional): Collection name to use. Defaults to - "LangChainCollection". - connection_args (dict[str, Any], optional): Connection args to use. Defaults - to DEFAULT_MILVUS_CONNECTION. - consistency_level (str, optional): Which consistency level to use. Defaults - to "Session". - index_params (Optional[dict], optional): Which index_params to use. Defaults - to None. - search_params (Optional[dict], optional): Which search params to use. - Defaults to None. - drop_old (Optional[bool], optional): Whether to drop the collection with - that name if it exists. Defaults to False. - batch_size: - How many vectors upload per-request. - Default: 100 - ids: Optional[Sequence[str]] = None, - - Returns: - Milvus: Milvus Vector Store - """ - vector_db = cls( - embedding_function=embedding, - collection_name=collection_name, - connection_args=connection_args, - consistency_level=consistency_level, - index_params=index_params, - search_params=search_params, - drop_old=drop_old, - **kwargs, - ) - vector_db.add_texts(texts=texts, metadatas=metadatas, batch_size=batch_size) - return vector_db diff --git a/api/core/vector_store/vector/qdrant.py b/api/core/vector_store/vector/qdrant.py deleted file mode 100644 index 47e5fe27c..000000000 --- a/api/core/vector_store/vector/qdrant.py +++ /dev/null @@ -1,1759 +0,0 @@ -"""Wrapper around Qdrant vector database.""" -from __future__ import annotations - -import asyncio -import functools -import uuid -import warnings -from collections.abc import Callable, Generator, Iterable, Sequence -from itertools import islice -from operator import itemgetter -from typing import TYPE_CHECKING, Any, Optional, Union - -import numpy as np -from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings -from langchain.vectorstores import VectorStore -from langchain.vectorstores.utils import maximal_marginal_relevance -from qdrant_client.http.models import PayloadSchemaType, TextIndexParams, TextIndexType, TokenizerType - -if TYPE_CHECKING: - from qdrant_client import grpc # noqa - from qdrant_client.conversions import common_types - from qdrant_client.http import models as rest - - DictFilter = dict[str, Union[str, int, bool, dict, list]] - MetadataFilter = Union[DictFilter, common_types.Filter] - - -class QdrantException(Exception): - """Base class for all the Qdrant related exceptions""" - - -def sync_call_fallback(method: Callable) -> Callable: - """ - Decorator to call the synchronous method of the class if the async method is not - implemented. This decorator might be only used for the methods that are defined - as async in the class. 
- """ - - @functools.wraps(method) - async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - try: - return await method(self, *args, **kwargs) - except NotImplementedError: - # If the async method is not implemented, call the synchronous method - # by removing the first letter from the method name. For example, - # if the async method is called ``aaad_texts``, the synchronous method - # will be called ``aad_texts``. - sync_method = functools.partial( - getattr(self, method.__name__[1:]), *args, **kwargs - ) - return await asyncio.get_event_loop().run_in_executor(None, sync_method) - - return wrapper - - -class Qdrant(VectorStore): - """Wrapper around Qdrant vector database. - - To use you should have the ``qdrant-client`` package installed. - - Example: - .. code-block:: python - - from qdrant_client import QdrantClient - from langchain import Qdrant - - client = QdrantClient() - collection_name = "MyCollection" - qdrant = Qdrant(client, collection_name, embedding_function) - """ - - CONTENT_KEY = "page_content" - METADATA_KEY = "metadata" - GROUP_KEY = "group_id" - VECTOR_NAME = None - - def __init__( - self, - client: Any, - collection_name: str, - embeddings: Optional[Embeddings] = None, - content_payload_key: str = CONTENT_KEY, - metadata_payload_key: str = METADATA_KEY, - group_payload_key: str = GROUP_KEY, - group_id: str = None, - distance_strategy: str = "COSINE", - vector_name: Optional[str] = VECTOR_NAME, - embedding_function: Optional[Callable] = None, # deprecated - is_new_collection: bool = False - ): - """Initialize with necessary components.""" - try: - import qdrant_client - except ImportError: - raise ValueError( - "Could not import qdrant-client python package. " - "Please install it with `pip install qdrant-client`." - ) - - if not isinstance(client, qdrant_client.QdrantClient): - raise ValueError( - f"client should be an instance of qdrant_client.QdrantClient, " - f"got {type(client)}" - ) - - if embeddings is None and embedding_function is None: - raise ValueError( - "`embeddings` value can't be None. Pass `Embeddings` instance." - ) - - if embeddings is not None and embedding_function is not None: - raise ValueError( - "Both `embeddings` and `embedding_function` are passed. " - "Use `embeddings` only." - ) - - self._embeddings = embeddings - self._embeddings_function = embedding_function - self.client: qdrant_client.QdrantClient = client - self.collection_name = collection_name - self.content_payload_key = content_payload_key or self.CONTENT_KEY - self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY - self.group_payload_key = group_payload_key or self.GROUP_KEY - self.vector_name = vector_name or self.VECTOR_NAME - self.group_id = group_id - self.is_new_collection= is_new_collection - - if embedding_function is not None: - warnings.warn( - "Using `embedding_function` is deprecated. " - "Pass `Embeddings` instance to `embeddings` instead." - ) - - if not isinstance(embeddings, Embeddings): - warnings.warn( - "`embeddings` should be an instance of `Embeddings`." 
- "Using `embeddings` as `embedding_function` which is deprecated" - ) - self._embeddings_function = embeddings - self._embeddings = None - - self.distance_strategy = distance_strategy.upper() - - @property - def embeddings(self) -> Optional[Embeddings]: - return self._embeddings - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, - batch_size: int = 64, - **kwargs: Any, - ) -> list[str]: - """Run more texts through the embeddings and add to the vectorstore. - - Args: - texts: Iterable of strings to add to the vectorstore. - metadatas: Optional list of metadatas associated with the texts. - ids: - Optional list of ids to associate with the texts. Ids have to be - uuid-like strings. - batch_size: - How many vectors upload per-request. - Default: 64 - group_id: - collection group - - Returns: - List of ids from adding the texts into the vectorstore. - """ - added_ids = [] - for batch_ids, points in self._generate_rest_batches( - texts, metadatas, ids, batch_size - ): - self.client.upsert( - collection_name=self.collection_name, points=points - ) - added_ids.extend(batch_ids) - # if is new collection, create payload index on group_id - if self.is_new_collection: - # create payload index - self.client.create_payload_index(self.collection_name, self.group_payload_key, - field_schema=PayloadSchemaType.KEYWORD, - field_type=PayloadSchemaType.KEYWORD) - # creat full text index - text_index_params = TextIndexParams( - type=TextIndexType.TEXT, - tokenizer=TokenizerType.MULTILINGUAL, - min_token_len=2, - max_token_len=20, - lowercase=True - ) - self.client.create_payload_index(self.collection_name, self.content_payload_key, - field_schema=text_index_params) - return added_ids - - @sync_call_fallback - async def aadd_texts( - self, - texts: Iterable[str], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, - batch_size: int = 64, - **kwargs: Any, - ) -> list[str]: - """Run more texts through the embeddings and add to the vectorstore. - - Args: - texts: Iterable of strings to add to the vectorstore. - metadatas: Optional list of metadatas associated with the texts. - ids: - Optional list of ids to associate with the texts. Ids have to be - uuid-like strings. - batch_size: - How many vectors upload per-request. - Default: 64 - - Returns: - List of ids from adding the texts into the vectorstore. - """ - from qdrant_client import grpc # noqa - from qdrant_client.conversions.conversion import RestToGrpc - - added_ids = [] - for batch_ids, points in self._generate_rest_batches( - texts, metadatas, ids, batch_size - ): - await self.client.async_grpc_points.Upsert( - grpc.UpsertPoints( - collection_name=self.collection_name, - points=[RestToGrpc.convert_point_struct(point) for point in points], - ) - ) - added_ids.extend(batch_ids) - - return added_ids - - def similarity_search( - self, - query: str, - k: int = 4, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs most similar to query. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - offset: - Offset of the first result to return. 
- May be used to paginate results. - Note: large offset values may cause performance issues. - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - - Returns: - List of Documents most similar to the query. - """ - results = self.similarity_search_with_score( - query, - k, - filter=filter, - search_params=search_params, - offset=offset, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - return list(map(itemgetter(0), results)) - - @sync_call_fallback - async def asimilarity_search( - self, - query: str, - k: int = 4, - filter: Optional[MetadataFilter] = None, - **kwargs: Any, - ) -> list[Document]: - """Return docs most similar to query. - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - Returns: - List of Documents most similar to the query. - """ - results = await self.asimilarity_search_with_score(query, k, filter, **kwargs) - return list(map(itemgetter(0), results)) - - def similarity_search_with_score( - self, - query: str, - k: int = 4, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs most similar to query. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - offset: - Offset of the first result to return. - May be used to paginate results. - Note: large offset values may cause performance issues. - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - - Returns: - List of documents most similar to the query text and distance for each. 
- """ - return self.similarity_search_with_score_by_vector( - self._embed_query(query), - k, - filter=filter, - search_params=search_params, - offset=offset, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - - @sync_call_fallback - async def asimilarity_search_with_score( - self, - query: str, - k: int = 4, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs most similar to query. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - offset: - Offset of the first result to return. - May be used to paginate results. - Note: large offset values may cause performance issues. - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - - Returns: - List of documents most similar to the query text and distance for each. - """ - return await self.asimilarity_search_with_score_by_vector( - self._embed_query(query), - k, - filter=filter, - search_params=search_params, - offset=offset, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - - def similarity_search_by_vector( - self, - embedding: list[float], - k: int = 4, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> list[Document]: - """Return docs most similar to embedding vector. - - Args: - embedding: Embedding vector to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - offset: - Offset of the first result to return. - May be used to paginate results. - Note: large offset values may cause performance issues. - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. 
- Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - - Returns: - List of Documents most similar to the query. - """ - results = self.similarity_search_with_score_by_vector( - embedding, - k, - filter=filter, - search_params=search_params, - offset=offset, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - return list(map(itemgetter(0), results)) - - @sync_call_fallback - async def asimilarity_search_by_vector( - self, - embedding: list[float], - k: int = 4, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> list[Document]: - """Return docs most similar to embedding vector. - - Args: - embedding: Embedding vector to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - offset: - Offset of the first result to return. - May be used to paginate results. - Note: large offset values may cause performance issues. - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - - Returns: - List of Documents most similar to the query. - """ - results = await self.asimilarity_search_with_score_by_vector( - embedding, - k, - filter=filter, - search_params=search_params, - offset=offset, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - return list(map(itemgetter(0), results)) - - def similarity_search_with_score_by_vector( - self, - embedding: list[float], - k: int = 4, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs most similar to embedding vector. - - Args: - embedding: Embedding vector to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - offset: - Offset of the first result to return. - May be used to paginate results. - Note: large offset values may cause performance issues. - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. 
- Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - - Returns: - List of documents most similar to the query text and distance for each. - """ - if filter is not None and isinstance(filter, dict): - warnings.warn( - "Using dict as a `filter` is deprecated. Please use qdrant-client " - "filters directly: " - "https://qdrant.tech/documentation/concepts/filtering/", - DeprecationWarning, - ) - qdrant_filter = self._qdrant_filter_from_dict(filter) - else: - qdrant_filter = filter - - query_vector = embedding - if self.vector_name is not None: - query_vector = (self.vector_name, embedding) # type: ignore[assignment] - - results = self.client.search( - collection_name=self.collection_name, - query_vector=query_vector, - query_filter=qdrant_filter, - search_params=search_params, - limit=k, - offset=offset, - with_payload=True, - with_vectors=True, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - return [ - ( - self._document_from_scored_point( - result, self.content_payload_key, self.metadata_payload_key - ), - result.score, - ) - for result in results - ] - - def similarity_search_by_bm25( - self, - filter: Optional[MetadataFilter] = None, - k: int = 4 - ) -> list[Document]: - """Return docs most similar by bm25. - - Args: - embedding: Embedding vector to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - Returns: - List of documents most similar to the query text and distance for each. - """ - response = self.client.scroll( - collection_name=self.collection_name, - scroll_filter=filter, - limit=k, - with_payload=True, - with_vectors=True - - ) - results = response[0] - documents = [] - for result in results: - if result: - documents.append(self._document_from_scored_point( - result, self.content_payload_key, self.metadata_payload_key - )) - - return documents - - @sync_call_fallback - async def asimilarity_search_with_score_by_vector( - self, - embedding: list[float], - k: int = 4, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs most similar to embedding vector. - - Args: - embedding: Embedding vector to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - offset: - Offset of the first result to return. - May be used to paginate results. - Note: large offset values may cause performance issues. - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. 
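Despite its name, similarity_search_by_bm25 above is a filtered scroll: the caller supplies a Qdrant filter (typically a full-text MatchText condition against the indexed page_content field) and the first k matching points come back in storage order rather than ranked by a BM25 score. A sketch of such a call, with the payload keys used above and placeholder values:

from qdrant_client import QdrantClient
from qdrant_client.http import models

client = QdrantClient(url="http://localhost:6333")  # assumed local instance

text_filter = models.Filter(
    must=[
        models.FieldCondition(key="page_content", match=models.MatchText(text="vector database")),
        models.FieldCondition(key="group_id", match=models.MatchValue(value="dataset-123")),
    ]
)
points, _next_offset = client.scroll(
    collection_name="MyCollection",
    scroll_filter=text_filter,
    limit=4,
    with_payload=True,
)
docs = [p.payload.get("page_content") for p in points]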
- Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - - Returns: - List of documents most similar to the query text and distance for each. - """ - from qdrant_client import grpc # noqa - from qdrant_client.conversions.conversion import RestToGrpc - from qdrant_client.http import models as rest - - if filter is not None and isinstance(filter, dict): - warnings.warn( - "Using dict as a `filter` is deprecated. Please use qdrant-client " - "filters directly: " - "https://qdrant.tech/documentation/concepts/filtering/", - DeprecationWarning, - ) - qdrant_filter = self._qdrant_filter_from_dict(filter) - else: - qdrant_filter = filter - - if qdrant_filter is not None and isinstance(qdrant_filter, rest.Filter): - qdrant_filter = RestToGrpc.convert_filter(qdrant_filter) - - response = await self.client.async_grpc_points.Search( - grpc.SearchPoints( - collection_name=self.collection_name, - vector_name=self.vector_name, - vector=embedding, - filter=qdrant_filter, - params=search_params, - limit=k, - offset=offset, - with_payload=grpc.WithPayloadSelector(enable=True), - with_vectors=grpc.WithVectorsSelector(enable=False), - score_threshold=score_threshold, - read_consistency=consistency, - **kwargs, - ) - ) - - return [ - ( - self._document_from_scored_point_grpc( - result, self.content_payload_key, self.metadata_payload_key - ), - result.score, - ) - for result in response.result - ] - - def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - Defaults to 20. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - Returns: - List of Documents selected by maximal marginal relevance. - """ - query_embedding = self._embed_query(query) - return self.max_marginal_relevance_search_by_vector( - query_embedding, k, fetch_k, lambda_mult, **kwargs - ) - - @sync_call_fallback - async def amax_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - Defaults to 20. 
- lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - Returns: - List of Documents selected by maximal marginal relevance. - """ - query_embedding = self._embed_query(query) - return await self.amax_marginal_relevance_search_by_vector( - query_embedding, k, fetch_k, lambda_mult, **kwargs - ) - - def max_marginal_relevance_search_by_vector( - self, - embedding: list[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - Returns: - List of Documents selected by maximal marginal relevance. - """ - results = self.max_marginal_relevance_search_with_score_by_vector( - embedding=embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs - ) - return list(map(itemgetter(0), results)) - - @sync_call_fallback - async def amax_marginal_relevance_search_by_vector( - self, - embedding: list[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected using the maximal marginal relevance. - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - Defaults to 20. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - Returns: - List of Documents selected by maximal marginal relevance and distance for - each. - """ - results = await self.amax_marginal_relevance_search_with_score_by_vector( - embedding, k, fetch_k, lambda_mult, **kwargs - ) - return list(map(itemgetter(0), results)) - - def max_marginal_relevance_search_with_score_by_vector( - self, - embedding: list[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs selected using the maximal marginal relevance. - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - Defaults to 20. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - Returns: - List of Documents selected by maximal marginal relevance and distance for - each. 
- """ - query_vector = embedding - if self.vector_name is not None: - query_vector = (self.vector_name, query_vector) # type: ignore[assignment] - - results = self.client.search( - collection_name=self.collection_name, - query_vector=query_vector, - with_payload=True, - with_vectors=True, - limit=fetch_k, - ) - embeddings = [ - result.vector.get(self.vector_name) # type: ignore[index, union-attr] - if self.vector_name is not None - else result.vector - for result in results - ] - mmr_selected = maximal_marginal_relevance( - np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult - ) - return [ - ( - self._document_from_scored_point( - results[i], self.content_payload_key, self.metadata_payload_key - ), - results[i].score, - ) - for i in mmr_selected - ] - - @sync_call_fallback - async def amax_marginal_relevance_search_with_score_by_vector( - self, - embedding: list[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs selected using the maximal marginal relevance. - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - Defaults to 20. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - Returns: - List of Documents selected by maximal marginal relevance and distance for - each. - """ - from qdrant_client import grpc # noqa - from qdrant_client.conversions.conversion import GrpcToRest - - response = await self.client.async_grpc_points.Search( - grpc.SearchPoints( - collection_name=self.collection_name, - vector_name=self.vector_name, - vector=embedding, - with_payload=grpc.WithPayloadSelector(enable=True), - with_vectors=grpc.WithVectorsSelector(enable=True), - limit=fetch_k, - ) - ) - results = [ - GrpcToRest.convert_vectors(result.vectors) for result in response.result - ] - embeddings: list[list[float]] = [ - result.get(self.vector_name) # type: ignore - if isinstance(result, dict) - else result - for result in results - ] - mmr_selected: list[int] = maximal_marginal_relevance( - np.array(embedding), - embeddings, - k=k, - lambda_mult=lambda_mult, - ) - return [ - ( - self._document_from_scored_point_grpc( - response.result[i], - self.content_payload_key, - self.metadata_payload_key, - ), - response.result[i].score, - ) - for i in mmr_selected - ] - - def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> Optional[bool]: - """Delete by vector ID or other criteria. - - Args: - ids: List of ids to delete. - **kwargs: Other keyword arguments that subclasses might use. - - Returns: - Optional[bool]: True if deletion is successful, - False otherwise, None if not implemented. 
- """ - from qdrant_client.http import models as rest - - result = self.client.delete( - collection_name=self.collection_name, - points_selector=ids, - ) - return result.status == rest.UpdateStatus.COMPLETED - - @classmethod - def from_texts( - cls: type[Qdrant], - texts: list[str], - embedding: Embeddings, - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, - location: Optional[str] = None, - url: Optional[str] = None, - port: Optional[int] = 6333, - grpc_port: int = 6334, - prefer_grpc: bool = False, - https: Optional[bool] = None, - api_key: Optional[str] = None, - prefix: Optional[str] = None, - timeout: Optional[float] = None, - host: Optional[str] = None, - path: Optional[str] = None, - collection_name: Optional[str] = None, - distance_func: str = "Cosine", - content_payload_key: str = CONTENT_KEY, - metadata_payload_key: str = METADATA_KEY, - group_payload_key: str = GROUP_KEY, - group_id: str = None, - vector_name: Optional[str] = VECTOR_NAME, - batch_size: int = 64, - shard_number: Optional[int] = None, - replication_factor: Optional[int] = None, - write_consistency_factor: Optional[int] = None, - on_disk_payload: Optional[bool] = None, - hnsw_config: Optional[common_types.HnswConfigDiff] = None, - optimizers_config: Optional[common_types.OptimizersConfigDiff] = None, - wal_config: Optional[common_types.WalConfigDiff] = None, - quantization_config: Optional[common_types.QuantizationConfig] = None, - init_from: Optional[common_types.InitFrom] = None, - force_recreate: bool = False, - **kwargs: Any, - ) -> Qdrant: - """Construct Qdrant wrapper from a list of texts. - - Args: - texts: A list of texts to be indexed in Qdrant. - embedding: A subclass of `Embeddings`, responsible for text vectorization. - metadatas: - An optional list of metadata. If provided it has to be of the same - length as a list of texts. - ids: - Optional list of ids to associate with the texts. Ids have to be - uuid-like strings. - location: - If `:memory:` - use in-memory Qdrant instance. - If `str` - use it as a `url` parameter. - If `None` - fallback to relying on `host` and `port` parameters. - url: either host or str of "Optional[scheme], host, Optional[port], - Optional[prefix]". Default: `None` - port: Port of the REST API interface. Default: 6333 - grpc_port: Port of the gRPC interface. Default: 6334 - prefer_grpc: - If true - use gPRC interface whenever possible in custom methods. - Default: False - https: If true - use HTTPS(SSL) protocol. Default: None - api_key: API key for authentication in Qdrant Cloud. Default: None - prefix: - If not None - add prefix to the REST URL path. - Example: service/v1 will result in - http://localhost:6333/service/v1/{qdrant-endpoint} for REST API. - Default: None - timeout: - Timeout for REST and gRPC API requests. - Default: 5.0 seconds for REST and unlimited for gRPC - host: - Host name of Qdrant service. If url and host are None, set to - 'localhost'. Default: None - path: - Path in which the vectors will be stored while using local mode. - Default: None - collection_name: - Name of the Qdrant collection to be used. If not provided, - it will be created randomly. Default: None - distance_func: - Distance function. One of: "Cosine" / "Euclid" / "Dot". - Default: "Cosine" - content_payload_key: - A payload key used to store the content of the document. - Default: "page_content" - metadata_payload_key: - A payload key used to store the metadata of the document. 
- Default: "metadata" - group_payload_key: - A payload key used to store the content of the document. - Default: "group_id" - group_id: - collection group id - vector_name: - Name of the vector to be used internally in Qdrant. - Default: None - batch_size: - How many vectors upload per-request. - Default: 64 - shard_number: Number of shards in collection. Default is 1, minimum is 1. - replication_factor: - Replication factor for collection. Default is 1, minimum is 1. - Defines how many copies of each shard will be created. - Have effect only in distributed mode. - write_consistency_factor: - Write consistency factor for collection. Default is 1, minimum is 1. - Defines how many replicas should apply the operation for us to consider - it successful. Increasing this number will make the collection more - resilient to inconsistencies, but will also make it fail if not enough - replicas are available. - Does not have any performance impact. - Have effect only in distributed mode. - on_disk_payload: - If true - point`s payload will not be stored in memory. - It will be read from the disk every time it is requested. - This setting saves RAM by (slightly) increasing the response time. - Note: those payload values that are involved in filtering and are - indexed - remain in RAM. - hnsw_config: Params for HNSW index - optimizers_config: Params for optimizer - wal_config: Params for Write-Ahead-Log - quantization_config: - Params for quantization, if None - quantization will be disabled - init_from: - Use data stored in another collection to initialize this collection - force_recreate: - Force recreating the collection - **kwargs: - Additional arguments passed directly into REST client initialization - - This is a user-friendly interface that: - 1. Creates embeddings, one for each text - 2. Initializes the Qdrant database as an in-memory docstore by default - (and overridable to a remote docstore) - 3. Adds the text embeddings to the Qdrant database - - This is intended to be a quick way to get started. - - Example: - .. 
code-block:: python - - from langchain import Qdrant - from langchain.embeddings import OpenAIEmbeddings - embeddings = OpenAIEmbeddings() - qdrant = Qdrant.from_texts(texts, embeddings, "localhost") - """ - qdrant = cls._construct_instance( - texts, - embedding, - metadatas, - ids, - location, - url, - port, - grpc_port, - prefer_grpc, - https, - api_key, - prefix, - timeout, - host, - path, - collection_name, - distance_func, - content_payload_key, - metadata_payload_key, - group_payload_key, - group_id, - vector_name, - shard_number, - replication_factor, - write_consistency_factor, - on_disk_payload, - hnsw_config, - optimizers_config, - wal_config, - quantization_config, - init_from, - force_recreate, - **kwargs, - ) - qdrant.add_texts(texts, metadatas, ids, batch_size) - return qdrant - - @classmethod - @sync_call_fallback - async def afrom_texts( - cls: type[Qdrant], - texts: list[str], - embedding: Embeddings, - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, - location: Optional[str] = None, - url: Optional[str] = None, - port: Optional[int] = 6333, - grpc_port: int = 6334, - prefer_grpc: bool = False, - https: Optional[bool] = None, - api_key: Optional[str] = None, - prefix: Optional[str] = None, - timeout: Optional[float] = None, - host: Optional[str] = None, - path: Optional[str] = None, - collection_name: Optional[str] = None, - distance_func: str = "Cosine", - content_payload_key: str = CONTENT_KEY, - metadata_payload_key: str = METADATA_KEY, - vector_name: Optional[str] = VECTOR_NAME, - batch_size: int = 64, - shard_number: Optional[int] = None, - replication_factor: Optional[int] = None, - write_consistency_factor: Optional[int] = None, - on_disk_payload: Optional[bool] = None, - hnsw_config: Optional[common_types.HnswConfigDiff] = None, - optimizers_config: Optional[common_types.OptimizersConfigDiff] = None, - wal_config: Optional[common_types.WalConfigDiff] = None, - quantization_config: Optional[common_types.QuantizationConfig] = None, - init_from: Optional[common_types.InitFrom] = None, - force_recreate: bool = False, - **kwargs: Any, - ) -> Qdrant: - """Construct Qdrant wrapper from a list of texts. - - Args: - texts: A list of texts to be indexed in Qdrant. - embedding: A subclass of `Embeddings`, responsible for text vectorization. - metadatas: - An optional list of metadata. If provided it has to be of the same - length as a list of texts. - ids: - Optional list of ids to associate with the texts. Ids have to be - uuid-like strings. - location: - If `:memory:` - use in-memory Qdrant instance. - If `str` - use it as a `url` parameter. - If `None` - fallback to relying on `host` and `port` parameters. - url: either host or str of "Optional[scheme], host, Optional[port], - Optional[prefix]". Default: `None` - port: Port of the REST API interface. Default: 6333 - grpc_port: Port of the gRPC interface. Default: 6334 - prefer_grpc: - If true - use gPRC interface whenever possible in custom methods. - Default: False - https: If true - use HTTPS(SSL) protocol. Default: None - api_key: API key for authentication in Qdrant Cloud. Default: None - prefix: - If not None - add prefix to the REST URL path. - Example: service/v1 will result in - http://localhost:6333/service/v1/{qdrant-endpoint} for REST API. - Default: None - timeout: - Timeout for REST and gRPC API requests. - Default: 5.0 seconds for REST and unlimited for gRPC - host: - Host name of Qdrant service. If url and host are None, set to - 'localhost'. 
Default: None - path: - Path in which the vectors will be stored while using local mode. - Default: None - collection_name: - Name of the Qdrant collection to be used. If not provided, - it will be created randomly. Default: None - distance_func: - Distance function. One of: "Cosine" / "Euclid" / "Dot". - Default: "Cosine" - content_payload_key: - A payload key used to store the content of the document. - Default: "page_content" - metadata_payload_key: - A payload key used to store the metadata of the document. - Default: "metadata" - vector_name: - Name of the vector to be used internally in Qdrant. - Default: None - batch_size: - How many vectors upload per-request. - Default: 64 - shard_number: Number of shards in collection. Default is 1, minimum is 1. - replication_factor: - Replication factor for collection. Default is 1, minimum is 1. - Defines how many copies of each shard will be created. - Have effect only in distributed mode. - write_consistency_factor: - Write consistency factor for collection. Default is 1, minimum is 1. - Defines how many replicas should apply the operation for us to consider - it successful. Increasing this number will make the collection more - resilient to inconsistencies, but will also make it fail if not enough - replicas are available. - Does not have any performance impact. - Have effect only in distributed mode. - on_disk_payload: - If true - point`s payload will not be stored in memory. - It will be read from the disk every time it is requested. - This setting saves RAM by (slightly) increasing the response time. - Note: those payload values that are involved in filtering and are - indexed - remain in RAM. - hnsw_config: Params for HNSW index - optimizers_config: Params for optimizer - wal_config: Params for Write-Ahead-Log - quantization_config: - Params for quantization, if None - quantization will be disabled - init_from: - Use data stored in another collection to initialize this collection - force_recreate: - Force recreating the collection - **kwargs: - Additional arguments passed directly into REST client initialization - - This is a user-friendly interface that: - 1. Creates embeddings, one for each text - 2. Initializes the Qdrant database as an in-memory docstore by default - (and overridable to a remote docstore) - 3. Adds the text embeddings to the Qdrant database - - This is intended to be a quick way to get started. - - Example: - .. 
code-block:: python - - from langchain import Qdrant - from langchain.embeddings import OpenAIEmbeddings - embeddings = OpenAIEmbeddings() - qdrant = await Qdrant.afrom_texts(texts, embeddings, "localhost") - """ - qdrant = cls._construct_instance( - texts, - embedding, - metadatas, - ids, - location, - url, - port, - grpc_port, - prefer_grpc, - https, - api_key, - prefix, - timeout, - host, - path, - collection_name, - distance_func, - content_payload_key, - metadata_payload_key, - vector_name, - shard_number, - replication_factor, - write_consistency_factor, - on_disk_payload, - hnsw_config, - optimizers_config, - wal_config, - quantization_config, - init_from, - force_recreate, - **kwargs, - ) - await qdrant.aadd_texts(texts, metadatas, ids, batch_size) - return qdrant - - @classmethod - def _construct_instance( - cls: type[Qdrant], - texts: list[str], - embedding: Embeddings, - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, - location: Optional[str] = None, - url: Optional[str] = None, - port: Optional[int] = 6333, - grpc_port: int = 6334, - prefer_grpc: bool = False, - https: Optional[bool] = None, - api_key: Optional[str] = None, - prefix: Optional[str] = None, - timeout: Optional[float] = None, - host: Optional[str] = None, - path: Optional[str] = None, - collection_name: Optional[str] = None, - distance_func: str = "Cosine", - content_payload_key: str = CONTENT_KEY, - metadata_payload_key: str = METADATA_KEY, - group_payload_key: str = GROUP_KEY, - group_id: str = None, - vector_name: Optional[str] = VECTOR_NAME, - shard_number: Optional[int] = None, - replication_factor: Optional[int] = None, - write_consistency_factor: Optional[int] = None, - on_disk_payload: Optional[bool] = None, - hnsw_config: Optional[common_types.HnswConfigDiff] = None, - optimizers_config: Optional[common_types.OptimizersConfigDiff] = None, - wal_config: Optional[common_types.WalConfigDiff] = None, - quantization_config: Optional[common_types.QuantizationConfig] = None, - init_from: Optional[common_types.InitFrom] = None, - force_recreate: bool = False, - **kwargs: Any, - ) -> Qdrant: - try: - import qdrant_client - except ImportError: - raise ValueError( - "Could not import qdrant-client python package. " - "Please install it with `pip install qdrant-client`." - ) - from qdrant_client.http import models as rest - - # Just do a single quick embedding to get vector size - partial_embeddings = embedding.embed_documents(texts[:1]) - vector_size = len(partial_embeddings[0]) - collection_name = collection_name or uuid.uuid4().hex - distance_func = distance_func.upper() - is_new_collection = False - client = qdrant_client.QdrantClient( - location=location, - url=url, - port=port, - grpc_port=grpc_port, - prefer_grpc=prefer_grpc, - https=https, - api_key=api_key, - prefix=prefix, - timeout=timeout, - host=host, - path=path, - **kwargs, - ) - all_collection_name = [] - collections_response = client.get_collections() - collection_list = collections_response.collections - for collection in collection_list: - all_collection_name.append(collection.name) - if collection_name not in all_collection_name: - vectors_config = rest.VectorParams( - size=vector_size, - distance=rest.Distance[distance_func], - ) - - # If vector name was provided, we're going to use the named vectors feature - # with just a single vector. 
- if vector_name is not None: - vectors_config = { # type: ignore[assignment] - vector_name: vectors_config, - } - - client.recreate_collection( - collection_name=collection_name, - vectors_config=vectors_config, - shard_number=shard_number, - replication_factor=replication_factor, - write_consistency_factor=write_consistency_factor, - on_disk_payload=on_disk_payload, - hnsw_config=hnsw_config, - optimizers_config=optimizers_config, - wal_config=wal_config, - quantization_config=quantization_config, - init_from=init_from, - timeout=int(timeout), # type: ignore[arg-type] - ) - is_new_collection = True - if force_recreate: - raise ValueError - - # Get the vector configuration of the existing collection and vector, if it - # was specified. If the old configuration does not match the current one, - # an exception is being thrown. - collection_info = client.get_collection(collection_name=collection_name) - current_vector_config = collection_info.config.params.vectors - if isinstance(current_vector_config, dict) and vector_name is not None: - if vector_name not in current_vector_config: - raise QdrantException( - f"Existing Qdrant collection {collection_name} does not " - f"contain vector named {vector_name}. Did you mean one of the " - f"existing vectors: {', '.join(current_vector_config.keys())}? " - f"If you want to recreate the collection, set `force_recreate` " - f"parameter to `True`." - ) - current_vector_config = current_vector_config.get( - vector_name - ) # type: ignore[assignment] - elif isinstance(current_vector_config, dict) and vector_name is None: - raise QdrantException( - f"Existing Qdrant collection {collection_name} uses named vectors. " - f"If you want to reuse it, please set `vector_name` to any of the " - f"existing named vectors: " - f"{', '.join(current_vector_config.keys())}." # noqa - f"If you want to recreate the collection, set `force_recreate` " - f"parameter to `True`." - ) - elif ( - not isinstance(current_vector_config, dict) and vector_name is not None - ): - raise QdrantException( - f"Existing Qdrant collection {collection_name} doesn't use named " - f"vectors. If you want to reuse it, please set `vector_name` to " - f"`None`. If you want to recreate the collection, set " - f"`force_recreate` parameter to `True`." - ) - - # Check if the vector configuration has the same dimensionality. - if current_vector_config.size != vector_size: # type: ignore[union-attr] - raise QdrantException( - f"Existing Qdrant collection is configured for vectors with " - f"{current_vector_config.size} " # type: ignore[union-attr] - f"dimensions. Selected embeddings are {vector_size}-dimensional. " - f"If you want to recreate the collection, set `force_recreate` " - f"parameter to `True`." - ) - - current_distance_func = ( - current_vector_config.distance.name.upper() # type: ignore[union-attr] - ) - if current_distance_func != distance_func: - raise QdrantException( - f"Existing Qdrant collection is configured for " - f"{current_vector_config.distance} " # type: ignore[union-attr] - f"similarity. Please set `distance_func` parameter to " - f"`{distance_func}` if you want to reuse it. If you want to " - f"recreate the collection, set `force_recreate` parameter to " - f"`True`." 
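_construct_instance above derives the vector size from a single probe embedding, creates the collection only if it does not already exist, and otherwise checks that the stored vector size and distance match the current embedding model. The create-or-validate half, roughly (URL, collection name and dimension are placeholders, and unnamed vectors are assumed):

from qdrant_client import QdrantClient
from qdrant_client.http import models as rest

client = QdrantClient(url="http://localhost:6333")  # assumed local instance
collection_name = "MyCollection"
vector_size = 1536  # assumed: len(embedding.embed_documents(texts[:1])[0])

existing = [c.name for c in client.get_collections().collections]
if collection_name not in existing:
    client.recreate_collection(
        collection_name=collection_name,
        vectors_config=rest.VectorParams(size=vector_size, distance=rest.Distance.COSINE),
    )
else:
    info = client.get_collection(collection_name=collection_name)
    params = info.config.params.vectors  # a dict instead when named vectors are configured
    if params.size != vector_size:
        raise ValueError(f"Collection expects {params.size}-dim vectors, embeddings are {vector_size}-dim")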
- ) - qdrant = cls( - client=client, - collection_name=collection_name, - embeddings=embedding, - content_payload_key=content_payload_key, - metadata_payload_key=metadata_payload_key, - distance_strategy=distance_func, - vector_name=vector_name, - group_id=group_id, - group_payload_key=group_payload_key, - is_new_collection=is_new_collection - ) - return qdrant - - def _select_relevance_score_fn(self) -> Callable[[float], float]: - """ - The 'correct' relevance function - may differ depending on a few things, including: - - the distance / similarity metric used by the VectorStore - - the scale of your embeddings (OpenAI's are unit normed. Many others are not!) - - embedding dimensionality - - etc. - """ - - if self.distance_strategy == "COSINE": - return self._cosine_relevance_score_fn - elif self.distance_strategy == "DOT": - return self._max_inner_product_relevance_score_fn - elif self.distance_strategy == "EUCLID": - return self._euclidean_relevance_score_fn - else: - raise ValueError( - "Unknown distance strategy, must be cosine, " - "max_inner_product, or euclidean" - ) - - def _similarity_search_with_relevance_scores( - self, - query: str, - k: int = 4, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs and relevance scores in the range [0, 1]. - - 0 is dissimilar, 1 is most similar. - - Args: - query: input text - k: Number of Documents to return. Defaults to 4. - **kwargs: kwargs to be passed to similarity search. Should include: - score_threshold: Optional, a floating point value between 0 to 1 to - filter the resulting set of retrieved docs - - Returns: - List of Tuples of (doc, similarity_score) - """ - return self.similarity_search_with_score(query, k, **kwargs) - - @classmethod - def _build_payloads( - cls, - texts: Iterable[str], - metadatas: Optional[list[dict]], - content_payload_key: str, - metadata_payload_key: str, - group_id: str, - group_payload_key: str - ) -> list[dict]: - payloads = [] - for i, text in enumerate(texts): - if text is None: - raise ValueError( - "At least one of the texts is None. Please remove it before " - "calling .from_texts or .add_texts on Qdrant instance." 
- ) - metadata = metadatas[i] if metadatas is not None else None - payloads.append( - { - content_payload_key: text, - metadata_payload_key: metadata, - group_payload_key: group_id - } - ) - - return payloads - - @classmethod - def _document_from_scored_point( - cls, - scored_point: Any, - content_payload_key: str, - metadata_payload_key: str, - ) -> Document: - return Document( - page_content=scored_point.payload.get(content_payload_key), - metadata=scored_point.payload.get(metadata_payload_key) or {}, - ) - - @classmethod - def _document_from_scored_point_grpc( - cls, - scored_point: Any, - content_payload_key: str, - metadata_payload_key: str, - ) -> Document: - from qdrant_client.conversions.conversion import grpc_to_payload - - payload = grpc_to_payload(scored_point.payload) - return Document( - page_content=payload[content_payload_key], - metadata=payload.get(metadata_payload_key) or {}, - ) - - def _build_condition(self, key: str, value: Any) -> list[rest.FieldCondition]: - from qdrant_client.http import models as rest - - out = [] - - if isinstance(value, dict): - for _key, value in value.items(): - out.extend(self._build_condition(f"{key}.{_key}", value)) - elif isinstance(value, list): - for _value in value: - if isinstance(_value, dict): - out.extend(self._build_condition(f"{key}[]", _value)) - else: - out.extend(self._build_condition(f"{key}", _value)) - else: - out.append( - rest.FieldCondition( - key=key, - match=rest.MatchValue(value=value), - ) - ) - - return out - - def _qdrant_filter_from_dict( - self, filter: Optional[DictFilter] - ) -> Optional[rest.Filter]: - from qdrant_client.http import models as rest - - if not filter: - return None - - return rest.Filter( - must=[ - condition - for key, value in filter.items() - for condition in self._build_condition(key, value) - ] - ) - - def _embed_query(self, query: str) -> list[float]: - """Embed query text. - - Used to provide backward compatibility with `embedding_function` argument. - - Args: - query: Query text. - - Returns: - List of floats representing the query embedding. - """ - if self.embeddings is not None: - embedding = self.embeddings.embed_query(query) - else: - if self._embeddings_function is not None: - embedding = self._embeddings_function(query) - else: - raise ValueError("Neither of embeddings or embedding_function is set") - return embedding.tolist() if hasattr(embedding, "tolist") else embedding - - def _embed_texts(self, texts: Iterable[str]) -> list[list[float]]: - """Embed search texts. - - Used to provide backward compatibility with `embedding_function` argument. - - Args: - texts: Iterable of texts to embed. - - Returns: - List of floats representing the texts embedding. 
- """ - if self.embeddings is not None: - embeddings = self.embeddings.embed_documents(list(texts)) - if hasattr(embeddings, "tolist"): - embeddings = embeddings.tolist() - elif self._embeddings_function is not None: - embeddings = [] - for text in texts: - embedding = self._embeddings_function(text) - if hasattr(embeddings, "tolist"): - embedding = embedding.tolist() - embeddings.append(embedding) - else: - raise ValueError("Neither of embeddings or embedding_function is set") - - return embeddings - - def _generate_rest_batches( - self, - texts: Iterable[str], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, - batch_size: int = 64, - group_id: Optional[str] = None, - ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]: - from qdrant_client.http import models as rest - - texts_iterator = iter(texts) - metadatas_iterator = iter(metadatas or []) - ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)]) - while batch_texts := list(islice(texts_iterator, batch_size)): - # Take the corresponding metadata and id for each text in a batch - batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None - batch_ids = list(islice(ids_iterator, batch_size)) - - # Generate the embeddings for all the texts in a batch - batch_embeddings = self._embed_texts(batch_texts) - - points = [ - rest.PointStruct( - id=point_id, - vector=vector - if self.vector_name is None - else {self.vector_name: vector}, - payload=payload, - ) - for point_id, vector, payload in zip( - batch_ids, - batch_embeddings, - self._build_payloads( - batch_texts, - batch_metadatas, - self.content_payload_key, - self.metadata_payload_key, - self.group_id, - self.group_payload_key - ), - ) - ] - - yield batch_ids, points diff --git a/api/core/vector_store/vector/weaviate.py b/api/core/vector_store/vector/weaviate.py deleted file mode 100644 index 8ac77152e..000000000 --- a/api/core/vector_store/vector/weaviate.py +++ /dev/null @@ -1,506 +0,0 @@ -"""Wrapper around weaviate vector database.""" -from __future__ import annotations - -import datetime -from collections.abc import Callable, Iterable -from typing import Any, Optional -from uuid import uuid4 - -import numpy as np -from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings -from langchain.utils import get_from_dict_or_env -from langchain.vectorstores.base import VectorStore -from langchain.vectorstores.utils import maximal_marginal_relevance - - -def _default_schema(index_name: str) -> dict: - return { - "class": index_name, - "properties": [ - { - "name": "text", - "dataType": ["text"], - } - ], - } - - -def _create_weaviate_client(**kwargs: Any) -> Any: - client = kwargs.get("client") - if client is not None: - return client - - weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL") - - try: - # the weaviate api key param should not be mandatory - weaviate_api_key = get_from_dict_or_env( - kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None - ) - except ValueError: - weaviate_api_key = None - - try: - import weaviate - except ImportError: - raise ValueError( - "Could not import weaviate python package. 
" - "Please install it with `pip install weaviate-client`" - ) - - auth = ( - weaviate.auth.AuthApiKey(api_key=weaviate_api_key) - if weaviate_api_key is not None - else None - ) - client = weaviate.Client(weaviate_url, auth_client_secret=auth) - - return client - - -def _default_score_normalizer(val: float) -> float: - return 1 - val - - -def _json_serializable(value: Any) -> Any: - if isinstance(value, datetime.datetime): - return value.isoformat() - return value - - -class Weaviate(VectorStore): - """Wrapper around Weaviate vector database. - - To use, you should have the ``weaviate-client`` python package installed. - - Example: - .. code-block:: python - - import weaviate - from langchain.vectorstores import Weaviate - client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...) - weaviate = Weaviate(client, index_name, text_key) - - """ - - def __init__( - self, - client: Any, - index_name: str, - text_key: str, - embedding: Optional[Embeddings] = None, - attributes: Optional[list[str]] = None, - relevance_score_fn: Optional[ - Callable[[float], float] - ] = _default_score_normalizer, - by_text: bool = True, - ): - """Initialize with Weaviate client.""" - try: - import weaviate - except ImportError: - raise ValueError( - "Could not import weaviate python package. " - "Please install it with `pip install weaviate-client`." - ) - if not isinstance(client, weaviate.Client): - raise ValueError( - f"client should be an instance of weaviate.Client, got {type(client)}" - ) - self._client = client - self._index_name = index_name - self._embedding = embedding - self._text_key = text_key - self._query_attrs = [self._text_key] - self.relevance_score_fn = relevance_score_fn - self._by_text = by_text - if attributes is not None: - self._query_attrs.extend(attributes) - - @property - def embeddings(self) -> Optional[Embeddings]: - return self._embedding - - def _select_relevance_score_fn(self) -> Callable[[float], float]: - return ( - self.relevance_score_fn - if self.relevance_score_fn - else _default_score_normalizer - ) - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[list[dict]] = None, - **kwargs: Any, - ) -> list[str]: - """Upload texts with metadata (properties) to Weaviate.""" - from weaviate.util import get_valid_uuid - - ids = [] - embeddings: Optional[list[list[float]]] = None - if self._embedding: - if not isinstance(texts, list): - texts = list(texts) - embeddings = self._embedding.embed_documents(texts) - - with self._client.batch as batch: - for i, text in enumerate(texts): - data_properties = {self._text_key: text} - if metadatas is not None: - for key, val in metadatas[i].items(): - data_properties[key] = _json_serializable(val) - - # Allow for ids (consistent w/ other methods) - # # Or uuids (backwards compatble w/ existing arg) - # If the UUID of one of the objects already exists - # then the existing object will be replaced by the new object. - _id = get_valid_uuid(uuid4()) - if "uuids" in kwargs: - _id = kwargs["uuids"][i] - elif "ids" in kwargs: - _id = kwargs["ids"][i] - - batch.add_data_object( - data_object=data_properties, - class_name=self._index_name, - uuid=_id, - vector=embeddings[i] if embeddings else None, - ) - ids.append(_id) - return ids - - def similarity_search( - self, query: str, k: int = 4, **kwargs: Any - ) -> list[Document]: - """Return docs most similar to query. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - - Returns: - List of Documents most similar to the query. 
- """ - if self._by_text: - return self.similarity_search_by_text(query, k, **kwargs) - else: - if self._embedding is None: - raise ValueError( - "_embedding cannot be None for similarity_search when " - "_by_text=False" - ) - embedding = self._embedding.embed_query(query) - return self.similarity_search_by_vector(embedding, k, **kwargs) - - def similarity_search_by_text( - self, query: str, k: int = 4, **kwargs: Any - ) -> list[Document]: - """Return docs most similar to query. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - - Returns: - List of Documents most similar to the query. - """ - content: dict[str, Any] = {"concepts": [query]} - if kwargs.get("search_distance"): - content["certainty"] = kwargs.get("search_distance") - query_obj = self._client.query.get(self._index_name, self._query_attrs) - if kwargs.get("where_filter"): - query_obj = query_obj.with_where(kwargs.get("where_filter")) - if kwargs.get("additional"): - query_obj = query_obj.with_additional(kwargs.get("additional")) - result = query_obj.with_near_text(content).with_limit(k).do() - if "errors" in result: - raise ValueError(f"Error during query: {result['errors']}") - docs = [] - for res in result["data"]["Get"][self._index_name]: - text = res.pop(self._text_key) - docs.append(Document(page_content=text, metadata=res)) - return docs - - def similarity_search_by_bm25( - self, query: str, k: int = 4, **kwargs: Any - ) -> list[Document]: - """Return docs using BM25F. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - - Returns: - List of Documents most similar to the query. - """ - content: dict[str, Any] = {"concepts": [query]} - if kwargs.get("search_distance"): - content["certainty"] = kwargs.get("search_distance") - query_obj = self._client.query.get(self._index_name, self._query_attrs) - if kwargs.get("where_filter"): - query_obj = query_obj.with_where(kwargs.get("where_filter")) - if kwargs.get("additional"): - query_obj = query_obj.with_additional(kwargs.get("additional")) - properties = ['text'] - result = query_obj.with_bm25(query=query, properties=properties).with_limit(k).do() - if "errors" in result: - raise ValueError(f"Error during query: {result['errors']}") - docs = [] - for res in result["data"]["Get"][self._index_name]: - text = res.pop(self._text_key) - docs.append(Document(page_content=text, metadata=res)) - return docs - - def similarity_search_by_vector( - self, embedding: list[float], k: int = 4, **kwargs: Any - ) -> list[Document]: - """Look up similar documents by embedding vector in Weaviate.""" - vector = {"vector": embedding} - query_obj = self._client.query.get(self._index_name, self._query_attrs) - if kwargs.get("where_filter"): - query_obj = query_obj.with_where(kwargs.get("where_filter")) - if kwargs.get("additional"): - query_obj = query_obj.with_additional(kwargs.get("additional")) - result = query_obj.with_near_vector(vector).with_limit(k).do() - if "errors" in result: - raise ValueError(f"Error during query: {result['errors']}") - docs = [] - for res in result["data"]["Get"][self._index_name]: - text = res.pop(self._text_key) - docs.append(Document(page_content=text, metadata=res)) - return docs - - def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected using the maximal marginal relevance. 
- - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - - Returns: - List of Documents selected by maximal marginal relevance. - """ - if self._embedding is not None: - embedding = self._embedding.embed_query(query) - else: - raise ValueError( - "max_marginal_relevance_search requires a suitable Embeddings object" - ) - - return self.max_marginal_relevance_search_by_vector( - embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs - ) - - def max_marginal_relevance_search_by_vector( - self, - embedding: list[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - - Returns: - List of Documents selected by maximal marginal relevance. - """ - vector = {"vector": embedding} - query_obj = self._client.query.get(self._index_name, self._query_attrs) - if kwargs.get("where_filter"): - query_obj = query_obj.with_where(kwargs.get("where_filter")) - results = ( - query_obj.with_additional("vector") - .with_near_vector(vector) - .with_limit(fetch_k) - .do() - ) - - payload = results["data"]["Get"][self._index_name] - embeddings = [result["_additional"]["vector"] for result in payload] - mmr_selected = maximal_marginal_relevance( - np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult - ) - - docs = [] - for idx in mmr_selected: - text = payload[idx].pop(self._text_key) - payload[idx].pop("_additional") - meta = payload[idx] - docs.append(Document(page_content=text, metadata=meta)) - return docs - - def similarity_search_with_score( - self, query: str, k: int = 4, **kwargs: Any - ) -> list[tuple[Document, float]]: - """ - Return list of documents most similar to the query - text and cosine distance in float for each. - Lower score represents more similarity. 
- """ - if self._embedding is None: - raise ValueError( - "_embedding cannot be None for similarity_search_with_score" - ) - content: dict[str, Any] = {"concepts": [query]} - if kwargs.get("search_distance"): - content["certainty"] = kwargs.get("search_distance") - query_obj = self._client.query.get(self._index_name, self._query_attrs) - - embedded_query = self._embedding.embed_query(query) - if not self._by_text: - vector = {"vector": embedded_query} - result = ( - query_obj.with_near_vector(vector) - .with_limit(k) - .with_additional(["vector", "distance"]) - .do() - ) - else: - result = ( - query_obj.with_near_text(content) - .with_limit(k) - .with_additional(["vector", "distance"]) - .do() - ) - - if "errors" in result: - raise ValueError(f"Error during query: {result['errors']}") - - docs_and_scores = [] - for res in result["data"]["Get"][self._index_name]: - text = res.pop(self._text_key) - score = res["_additional"]["distance"] - docs_and_scores.append((Document(page_content=text, metadata=res), score)) - return docs_and_scores - - @classmethod - def from_texts( - cls: type[Weaviate], - texts: list[str], - embedding: Embeddings, - metadatas: Optional[list[dict]] = None, - **kwargs: Any, - ) -> Weaviate: - """Construct Weaviate wrapper from raw documents. - - This is a user-friendly interface that: - 1. Embeds documents. - 2. Creates a new index for the embeddings in the Weaviate instance. - 3. Adds the documents to the newly created Weaviate index. - - This is intended to be a quick way to get started. - - Example: - .. code-block:: python - - from langchain.vectorstores.weaviate import Weaviate - from langchain.embeddings import OpenAIEmbeddings - embeddings = OpenAIEmbeddings() - weaviate = Weaviate.from_texts( - texts, - embeddings, - weaviate_url="http://localhost:8080" - ) - """ - - client = _create_weaviate_client(**kwargs) - - from weaviate.util import get_valid_uuid - - index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}") - embeddings = embedding.embed_documents(texts) if embedding else None - text_key = "text" - schema = _default_schema(index_name) - attributes = list(metadatas[0].keys()) if metadatas else None - - # check whether the index already exists - if not client.schema.contains(schema): - client.schema.create_class(schema) - - with client.batch as batch: - for i, text in enumerate(texts): - data_properties = { - text_key: text, - } - if metadatas is not None: - for key in metadatas[i].keys(): - data_properties[key] = metadatas[i][key] - - # If the UUID of one of the objects already exists - # then the existing objectwill be replaced by the new object. - if "uuids" in kwargs: - _id = kwargs["uuids"][i] - else: - _id = get_valid_uuid(uuid4()) - - # if an embedding strategy is not provided, we let - # weaviate create the embedding. Note that this will only - # work if weaviate has been installed with a vectorizer module - # like text2vec-contextionary for example - params = { - "uuid": _id, - "data_object": data_properties, - "class_name": index_name, - } - if embeddings is not None: - params["vector"] = embeddings[i] - - batch.add_data_object(**params) - - batch.flush() - - relevance_score_fn = kwargs.get("relevance_score_fn") - by_text: bool = kwargs.get("by_text", False) - - return cls( - client, - index_name, - text_key, - embedding=embedding, - attributes=attributes, - relevance_score_fn=relevance_score_fn, - by_text=by_text, - ) - - def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> None: - """Delete by vector IDs. 
- - Args: - ids: List of ids to delete. - """ - - if ids is None: - raise ValueError("No ids provided to delete.") - - # TODO: Check if this can be done in bulk - for id in ids: - self._client.data_object.delete(uuid=id) diff --git a/api/core/vector_store/weaviate_vector_store.py b/api/core/vector_store/weaviate_vector_store.py deleted file mode 100644 index b5b3d84a9..000000000 --- a/api/core/vector_store/weaviate_vector_store.py +++ /dev/null @@ -1,38 +0,0 @@ -from core.vector_store.vector.weaviate import Weaviate - - -class WeaviateVectorStore(Weaviate): - def del_texts(self, where_filter: dict): - if not where_filter: - raise ValueError('where_filter must not be empty') - - self._client.batch.delete_objects( - class_name=self._index_name, - where=where_filter, - output='minimal' - ) - - def del_text(self, uuid: str) -> None: - self._client.data_object.delete( - uuid, - class_name=self._index_name - ) - - def text_exists(self, uuid: str) -> bool: - result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({ - "path": ["doc_id"], - "operator": "Equal", - "valueText": uuid, - }).with_limit(1).do() - - if "errors" in result: - raise ValueError(f"Error during query: {result['errors']}") - - entries = result["data"]["Get"][self._index_name] - if len(entries) == 0: - return False - - return True - - def delete(self): - self._client.schema.delete_class(self._index_name) diff --git a/api/events/event_handlers/clean_when_dataset_deleted.py b/api/events/event_handlers/clean_when_dataset_deleted.py index 93181ea16..42f1c7061 100644 --- a/api/events/event_handlers/clean_when_dataset_deleted.py +++ b/api/events/event_handlers/clean_when_dataset_deleted.py @@ -6,4 +6,4 @@ from tasks.clean_dataset_task import clean_dataset_task def handle(sender, **kwargs): dataset = sender clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, - dataset.index_struct, dataset.collection_binding_id) + dataset.index_struct, dataset.collection_binding_id, dataset.doc_form) diff --git a/api/events/event_handlers/clean_when_document_deleted.py b/api/events/event_handlers/clean_when_document_deleted.py index d6553b385..d0bec667a 100644 --- a/api/events/event_handlers/clean_when_document_deleted.py +++ b/api/events/event_handlers/clean_when_document_deleted.py @@ -6,4 +6,5 @@ from tasks.clean_document_task import clean_document_task def handle(sender, **kwargs): document_id = sender dataset_id = kwargs.get('dataset_id') - clean_document_task.delay(document_id, dataset_id) + doc_form = kwargs.get('doc_form') + clean_document_task.delay(document_id, dataset_id, doc_form) diff --git a/api/models/dataset.py b/api/models/dataset.py index d31e49f6c..473a796be 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -94,6 +94,14 @@ class Dataset(db.Model): return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \ .filter(Document.dataset_id == self.id).scalar() + @property + def doc_form(self): + document = db.session.query(Document).filter( + Document.dataset_id == self.id).first() + if document: + return document.doc_form + return None + @property def retrieval_model_dict(self): default_retrieval_model = { diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index 5db863fe8..cdcb3121b 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -6,7 +6,7 @@ from flask import current_app from werkzeug.exceptions import NotFound import app -from 
core.index.index import IndexBuilder +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from models.dataset import Dataset, DatasetQuery, Document @@ -41,18 +41,9 @@ def clean_unused_datasets_task(): if not documents or len(documents) == 0: try: # remove index - vector_index = IndexBuilder.get_index(dataset, 'high_quality') - kw_index = IndexBuilder.get_index(dataset, 'economy') - # delete from vector index - if vector_index: - if dataset.collection_binding_id: - vector_index.delete_by_group_id(dataset.id) - else: - if dataset.collection_binding_id: - vector_index.delete_by_group_id(dataset.id) - else: - vector_index.delete() - kw_index.delete() + index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() + index_processor.clean(dataset, None) + # update document update_params = { Document.enabled: False diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 7fef7e5ec..b151ebada 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -11,10 +11,11 @@ from flask_login import current_user from sqlalchemy import func from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError -from core.index.index import IndexBuilder from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.rag.datasource.keyword.keyword_factory import Keyword +from core.rag.models.document import Document as RAGDocument from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted from extensions.ext_database import db @@ -402,7 +403,7 @@ class DocumentService: @staticmethod def delete_document(document): # trigger document_was_deleted signal - document_was_deleted.send(document.id, dataset_id=document.dataset_id) + document_was_deleted.send(document.id, dataset_id=document.dataset_id, doc_form=document.doc_form) db.session.delete(document) db.session.commit() @@ -1060,7 +1061,7 @@ class SegmentService: # save vector index try: - VectorService.create_segment_vector(args['keywords'], segment_document, dataset) + VectorService.create_segments_vector([args['keywords']], [segment_document], dataset) except Exception as e: logging.exception("create segment index failed") segment_document.enabled = False @@ -1087,6 +1088,7 @@ class SegmentService: ).scalar() pre_segment_data_list = [] segment_data_list = [] + keywords_list = [] for segment_item in segments: content = segment_item['content'] doc_id = str(uuid.uuid4()) @@ -1119,15 +1121,13 @@ class SegmentService: segment_document.answer = segment_item['answer'] db.session.add(segment_document) segment_data_list.append(segment_document) - pre_segment_data = { - 'segment': segment_document, - 'keywords': segment_item['keywords'] - } - pre_segment_data_list.append(pre_segment_data) + + pre_segment_data_list.append(segment_document) + keywords_list.append(segment_item['keywords']) try: # save vector index - VectorService.multi_create_segment_vector(pre_segment_data_list, dataset) + VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset) except Exception as e: logging.exception("create segment index failed") for segment_document in segment_data_list: @@ -1157,11 +1157,18 @@ class SegmentService: db.session.commit() # update segment index task if args['keywords']: - kw_index = 
IndexBuilder.get_index(dataset, 'economy') - # delete from keyword index - kw_index.delete_by_ids([segment.index_node_id]) - # save keyword index - kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords) + keyword = Keyword(dataset) + keyword.delete_by_ids([segment.index_node_id]) + document = RAGDocument( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + } + ) + keyword.add_texts([document], keywords_list=[args['keywords']]) else: segment_hash = helper.generate_text_hash(content) tokens = 0 diff --git a/api/services/file_service.py b/api/services/file_service.py index a1c95b091..53dd09023 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -9,8 +9,8 @@ from flask_login import current_user from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound -from core.data_loader.file_extractor import FileExtractor from core.file.upload_file_parser import UploadFileParser +from core.rag.extractor.extract_processor import ExtractProcessor from extensions.ext_database import db from extensions.ext_storage import storage from models.account import Account @@ -32,7 +32,8 @@ class FileService: def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile: extension = file.filename.split('.')[-1] etl_type = current_app.config['ETL_TYPE'] - allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS + allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if etl_type == 'Unstructured' \ + else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if extension.lower() not in allowed_extensions: raise UnsupportedFileTypeError() elif only_image and extension.lower() not in IMAGE_EXTENSIONS: @@ -136,7 +137,7 @@ class FileService: if extension.lower() not in allowed_extensions: raise UnsupportedFileTypeError() - text = FileExtractor.load(upload_file, return_text=True) + text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True) text = text[0:PREVIEW_WORDS_LIMIT] if text else '' return text @@ -164,7 +165,7 @@ class FileService: return generator, upload_file.mime_type @staticmethod - def get_public_image_preview(file_id: str) -> str: + def get_public_image_preview(file_id: str) -> tuple[Generator, str]: upload_file = db.session.query(UploadFile) \ .filter(UploadFile.id == file_id) \ .first() diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index e52527c62..568974b74 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -1,21 +1,18 @@ import logging -import threading import time import numpy as np -from flask import current_app -from langchain.embeddings.base import Embeddings -from langchain.schema import Document from sklearn.manifold import TSNE from core.embedding.cached_embedding import CacheEmbedding from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from core.rerank.rerank import RerankRunner +from core.rag.datasource.entity.embedding import Embeddings +from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.models.document import Document from extensions.ext_database import db from models.account import Account from models.dataset import Dataset, DatasetQuery, DocumentSegment -from services.retrieval_service import 
RetrievalService default_retrieval_model = { 'search_method': 'semantic_search', @@ -28,6 +25,7 @@ default_retrieval_model = { 'score_threshold_enabled': False } + class HitTestingService: @classmethod def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict: @@ -57,61 +55,15 @@ class HitTestingService: embeddings = CacheEmbedding(embedding_model) - all_documents = [] - threads = [] - - # retrieval_model source with semantic - if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': - embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': str(dataset.id), - 'query': query, - 'top_k': retrieval_model['top_k'], - 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None, - 'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None, - 'all_documents': all_documents, - 'search_method': retrieval_model['search_method'], - 'embeddings': embeddings - }) - threads.append(embedding_thread) - embedding_thread.start() - - # retrieval source with full text - if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': - full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': str(dataset.id), - 'query': query, - 'search_method': retrieval_model['search_method'], - 'embeddings': embeddings, - 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None, - 'top_k': retrieval_model['top_k'], - 'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None, - 'all_documents': all_documents - }) - threads.append(full_text_index_thread) - full_text_index_thread.start() - - for thread in threads: - thread.join() - - if retrieval_model['search_method'] == 'hybrid_search': - model_manager = ModelManager() - rerank_model_instance = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=retrieval_model['reranking_model']['reranking_provider_name'], - model_type=ModelType.RERANK, - model=retrieval_model['reranking_model']['reranking_model_name'] - ) - - rerank_runner = RerankRunner(rerank_model_instance) - all_documents = rerank_runner.run( - query=query, - documents=all_documents, - score_threshold=retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None, - top_n=retrieval_model['top_k'], - user=f"account-{account.id}" - ) + all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + dataset_id=dataset.id, + query=query, + top_k=retrieval_model['top_k'], + score_threshold=retrieval_model['score_threshold'] + if retrieval_model['score_threshold_enabled'] else None, + reranking_model=retrieval_model['reranking_model'] + if retrieval_model['reranking_enable'] else None + ) end = time.perf_counter() logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") @@ -203,4 +155,3 @@ class HitTestingService: if not query or len(query) > 250: raise ValueError('Query is required and cannot exceed 250 characters') - diff --git a/api/services/retrieval_service.py b/api/services/retrieval_service.py deleted file mode 100644 index bc8f4ad5b..000000000 --- 
a/api/services/retrieval_service.py +++ /dev/null @@ -1,119 +0,0 @@ -from typing import Optional - -from flask import Flask, current_app -from langchain.embeddings.base import Embeddings - -from core.index.vector_index.vector_index import VectorIndex -from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.invoke import InvokeAuthorizationError -from core.rerank.rerank import RerankRunner -from extensions.ext_database import db -from models.dataset import Dataset - -default_retrieval_model = { - 'search_method': 'semantic_search', - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False -} - - -class RetrievalService: - - @classmethod - def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str, - top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], - all_documents: list, search_method: str, embeddings: Embeddings): - with flask_app.app_context(): - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() - - vector_index = VectorIndex( - dataset=dataset, - config=current_app.config, - embeddings=embeddings - ) - - documents = vector_index.search( - query, - search_type='similarity_score_threshold', - search_kwargs={ - 'k': top_k, - 'score_threshold': score_threshold, - 'filter': { - 'group_id': [dataset.id] - } - } - ) - - if documents: - if reranking_model and search_method == 'semantic_search': - try: - model_manager = ModelManager() - rerank_model_instance = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=reranking_model['reranking_provider_name'], - model_type=ModelType.RERANK, - model=reranking_model['reranking_model_name'] - ) - except InvokeAuthorizationError: - return - - rerank_runner = RerankRunner(rerank_model_instance) - all_documents.extend(rerank_runner.run( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents) - )) - else: - all_documents.extend(documents) - - @classmethod - def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, - top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], - all_documents: list, search_method: str, embeddings: Embeddings): - with flask_app.app_context(): - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() - - vector_index = VectorIndex( - dataset=dataset, - config=current_app.config, - embeddings=embeddings - ) - - documents = vector_index.search_by_full_text_index( - query, - search_type='similarity_score_threshold', - top_k=top_k - ) - if documents: - if reranking_model and search_method == 'full_text_search': - try: - model_manager = ModelManager() - rerank_model_instance = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=reranking_model['reranking_provider_name'], - model_type=ModelType.RERANK, - model=reranking_model['reranking_model_name'] - ) - except InvokeAuthorizationError: - return - - rerank_runner = RerankRunner(rerank_model_instance) - all_documents.extend(rerank_runner.run( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents) - )) - else: - all_documents.extend(documents) diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 1a0447b38..d336162ba 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ 
-1,44 +1,18 @@ - from typing import Optional -from langchain.schema import Document - -from core.index.index import IndexBuilder +from core.rag.datasource.keyword.keyword_factory import Keyword +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.models.document import Document from models.dataset import Dataset, DocumentSegment class VectorService: @classmethod - def create_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - } - ) - - # save vector index - index = IndexBuilder.get_index(dataset, 'high_quality') - if index: - index.add_texts([document], duplicate_check=True) - - # save keyword index - index = IndexBuilder.get_index(dataset, 'economy') - if index: - if keywords and len(keywords) > 0: - index.create_segment_keywords(segment.index_node_id, keywords) - else: - index.add_texts([document]) - - @classmethod - def multi_create_segment_vector(cls, pre_segment_data_list: list, dataset: Dataset): + def create_segments_vector(cls, keywords_list: Optional[list[list[str]]], + segments: list[DocumentSegment], dataset: Dataset): documents = [] - for pre_segment_data in pre_segment_data_list: - segment = pre_segment_data['segment'] + for segment in segments: document = Document( page_content=segment.content, metadata={ @@ -49,30 +23,26 @@ class VectorService: } ) documents.append(document) - - # save vector index - index = IndexBuilder.get_index(dataset, 'high_quality') - if index: - index.add_texts(documents, duplicate_check=True) + if dataset.indexing_technique == 'high_quality': + # save vector index + vector = Vector( + dataset=dataset + ) + vector.add_texts(documents, duplicate_check=True) # save keyword index - keyword_index = IndexBuilder.get_index(dataset, 'economy') - if keyword_index: - keyword_index.multi_create_segment_keywords(pre_segment_data_list) + keyword = Keyword(dataset) + + if keywords_list and len(keywords_list) > 0: + keyword.add_texts(documents, keyword_list=keywords_list) + else: + keyword.add_texts(documents) @classmethod def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): # update segment index task - vector_index = IndexBuilder.get_index(dataset, 'high_quality') - kw_index = IndexBuilder.get_index(dataset, 'economy') - # delete from vector index - if vector_index: - vector_index.delete_by_ids([segment.index_node_id]) - # delete from keyword index - kw_index.delete_by_ids([segment.index_node_id]) - - # add new index + # format new index document = Document( page_content=segment.content, metadata={ @@ -82,13 +52,20 @@ class VectorService: "dataset_id": segment.dataset_id, } ) + if dataset.indexing_technique == 'high_quality': + # update vector index + vector = Vector( + dataset=dataset + ) + vector.delete_by_ids([segment.index_node_id]) + vector.add_texts([document], duplicate_check=True) - # save vector index - if vector_index: - vector_index.add_texts([document], duplicate_check=True) + # update keyword index + keyword = Keyword(dataset) + keyword.delete_by_ids([segment.index_node_id]) # save keyword index if keywords and len(keywords) > 0: - kw_index.create_segment_keywords(segment.index_node_id, keywords) + keyword.add_texts([document], keywords_list=[keywords]) else: - kw_index.add_texts([document]) + 
keyword.add_texts([document]) diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index ae235a2a6..a26ecf552 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -4,10 +4,10 @@ import time import click from celery import shared_task -from langchain.schema import Document from werkzeug.exceptions import NotFound -from core.index.index import IndexBuilder +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Document as DatasetDocument @@ -60,15 +60,9 @@ def add_document_to_index_task(dataset_document_id: str): if not dataset: raise Exception('Document has no dataset') - # save vector index - index = IndexBuilder.get_index(dataset, 'high_quality') - if index: - index.add_texts(documents) - - # save keyword index - index = IndexBuilder.get_index(dataset, 'economy') - if index: - index.add_texts(documents) + index_type = dataset.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + index_processor.load(dataset, documents) end_at = time.perf_counter() logging.info( diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index 61529f9bd..155567da7 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -5,7 +5,7 @@ import click from celery import shared_task from langchain.schema import Document -from core.index.index import IndexBuilder +from core.rag.datasource.vdb.vector_factory import Vector from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -48,9 +48,9 @@ def add_annotation_to_index_task(annotation_id: str, question: str, tenant_id: s "doc_id": annotation_id } ) - index = IndexBuilder.get_index(dataset, 'high_quality') - if index: - index.add_texts([document]) + vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector.create([document], duplicate_check=True) + end_at = time.perf_counter() logging.info( click.style( diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index 5b6c45b4f..1159cfc35 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -6,7 +6,7 @@ from celery import shared_task from langchain.schema import Document from werkzeug.exceptions import NotFound -from core.index.index import IndexBuilder +from core.rag.datasource.vdb.vector_factory import Vector from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -79,9 +79,8 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: collection_binding_id=dataset_collection_binding.id ) - index = IndexBuilder.get_index(dataset, 'high_quality') - if index: - index.add_texts(documents) + vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector.create(documents, duplicate_check=True) db.session.commit() redis_client.setex(indexing_cache_key, 600, 'completed') diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index 852f89951..81155a35e 100644 --- 
a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -4,7 +4,7 @@ import time import click from celery import shared_task -from core.index.index import IndexBuilder +from core.rag.datasource.vdb.vector_factory import Vector from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -30,12 +30,11 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str collection_binding_id=dataset_collection_binding.id ) - vector_index = IndexBuilder.get_default_high_quality_index(dataset) - if vector_index: - try: - vector_index.delete_by_metadata_field('annotation_id', annotation_id) - except Exception: - logging.exception("Delete annotation index failed when annotation deleted.") + try: + vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector.delete_by_metadata_field('annotation_id', annotation_id) + except Exception: + logging.exception("Delete annotation index failed when annotation deleted.") end_at = time.perf_counter() logging.info( click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at), diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index c5f028c72..7b88d7ac5 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -5,7 +5,7 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound -from core.index.index import IndexBuilder +from core.rag.datasource.vdb.vector_factory import Vector from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -48,12 +48,11 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): collection_binding_id=app_annotation_setting.collection_binding_id ) - vector_index = IndexBuilder.get_default_high_quality_index(dataset) - if vector_index: - try: - vector_index.delete_by_metadata_field('app_id', app_id) - except Exception: - logging.exception("Delete doc index failed when dataset deleted.") + try: + vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector.delete_by_metadata_field('app_id', app_id) + except Exception: + logging.exception("Delete annotation index failed when annotation deleted.") redis_client.setex(disable_app_annotation_job_key, 600, 'completed') # delete annotation setting diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index a125dd571..c1f63466f 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -7,7 +7,7 @@ from celery import shared_task from langchain.schema import Document from werkzeug.exceptions import NotFound -from core.index.index import IndexBuilder +from core.rag.datasource.vdb.vector_factory import Vector from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -81,15 +81,15 @@ def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_ } ) documents.append(document) - index = IndexBuilder.get_index(dataset, 'high_quality') - if index: - try: - index.delete_by_metadata_field('app_id', app_id) - except Exception as e: - logging.info( - click.style('Delete annotation index error: {}'.format(str(e)), - fg='red')) - 
index.add_texts(documents) + + vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + try: + vector.delete_by_metadata_field('app_id', app_id) + except Exception as e: + logging.info( + click.style('Delete annotation index error: {}'.format(str(e)), + fg='red')) + vector.add_texts(documents) db.session.commit() redis_client.setex(enable_app_annotation_job_key, 600, 'completed') end_at = time.perf_counter() diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index e632c3a24..7a193d2c5 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -5,7 +5,7 @@ import click from celery import shared_task from langchain.schema import Document -from core.index.index import IndexBuilder +from core.rag.datasource.vdb.vector_factory import Vector from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -49,10 +49,9 @@ def update_annotation_to_index_task(annotation_id: str, question: str, tenant_id "doc_id": annotation_id } ) - index = IndexBuilder.get_index(dataset, 'high_quality') - if index: - index.delete_by_metadata_field('annotation_id', annotation_id) - index.add_texts([document]) + vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector.delete_by_metadata_field('annotation_id', annotation_id) + vector.add_texts([document]) end_at = time.perf_counter() logging.info( click.style( diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 74ebcea15..16e4affc9 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -4,7 +4,7 @@ import time import click from celery import shared_task -from core.index.index import IndexBuilder +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from models.dataset import ( AppDatasetJoin, @@ -18,7 +18,7 @@ from models.dataset import ( @shared_task(queue='dataset') def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, - index_struct: str, collection_binding_id: str): + index_struct: str, collection_binding_id: str, doc_form: str): """ Clean dataset when dataset deleted. 
:param dataset_id: dataset id @@ -26,6 +26,7 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, :param indexing_technique: indexing technique :param index_struct: index struct dict :param collection_binding_id: collection binding id + :param doc_form: dataset form Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct) """ @@ -38,26 +39,14 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, tenant_id=tenant_id, indexing_technique=indexing_technique, index_struct=index_struct, - collection_binding_id=collection_binding_id + collection_binding_id=collection_binding_id, + doc_form=doc_form ) documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all() segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() - kw_index = IndexBuilder.get_index(dataset, 'economy') - - # delete from vector index - if dataset.indexing_technique == 'high_quality': - vector_index = IndexBuilder.get_default_high_quality_index(dataset) - try: - vector_index.delete_by_group_id(dataset.id) - except Exception: - logging.exception("Delete doc index failed when dataset deleted.") - - # delete from keyword index - try: - kw_index.delete() - except Exception: - logging.exception("Delete nodes index failed when dataset deleted.") + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean(dataset, None) for document in documents: db.session.delete(document) diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 76eb1a572..71ebad1da 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -4,17 +4,18 @@ import time import click from celery import shared_task -from core.index.index import IndexBuilder +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment @shared_task(queue='dataset') -def clean_document_task(document_id: str, dataset_id: str): +def clean_document_task(document_id: str, dataset_id: str, doc_form: str): """ Clean document when document deleted. 
     :param document_id: document id
     :param dataset_id: dataset id
+    :param doc_form: doc_form
     Usage: clean_document_task.delay(document_id, dataset_id)
     """
@@ -27,21 +28,12 @@ def clean_document_task(document_id: str, dataset_id: str):
         if not dataset:
             raise Exception('Document has no dataset')
 
-        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
-        kw_index = IndexBuilder.get_index(dataset, 'economy')
-
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
         # check segment is exist
         if segments:
             index_node_ids = [segment.index_node_id for segment in segments]
-
-            # delete from vector index
-            if vector_index:
-                vector_index.delete_by_document_id(document_id)
-
-            # delete from keyword index
-            if index_node_ids:
-                kw_index.delete_by_ids(index_node_ids)
+            index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+            index_processor.clean(dataset, index_node_ids)
 
             for segment in segments:
                 db.session.delete(segment)
diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py
index b85938d7d..9b697b635 100644
--- a/api/tasks/clean_notion_document_task.py
+++ b/api/tasks/clean_notion_document_task.py
@@ -4,7 +4,7 @@ import time
 import click
 from celery import shared_task
 
-from core.index.index import IndexBuilder
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
 
@@ -26,9 +26,8 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
 
         if not dataset:
             raise Exception('Document has no dataset')
-
-        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
-        kw_index = IndexBuilder.get_index(dataset, 'economy')
+        index_type = dataset.doc_form
+        index_processor = IndexProcessorFactory(index_type).init_index_processor()
         for document_id in document_ids:
             document = db.session.query(Document).filter(
                 Document.id == document_id
@@ -38,13 +37,7 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
             segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
             index_node_ids = [segment.index_node_id for segment in segments]
 
-            # delete from vector index
-            if vector_index:
-                vector_index.delete_by_document_id(document_id)
-
-            # delete from keyword index
-            if index_node_ids:
-                kw_index.delete_by_ids(index_node_ids)
+            index_processor.clean(dataset, index_node_ids)
 
             for segment in segments:
                 db.session.delete(segment)
diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py
index 6c492f069..f33a1e91b 100644
--- a/api/tasks/create_segment_to_index_task.py
+++ b/api/tasks/create_segment_to_index_task.py
@@ -5,10 +5,10 @@ from typing import Optional
 
 import click
 from celery import shared_task
-from langchain.schema import Document
 from werkzeug.exceptions import NotFound
 
-from core.index.index import IndexBuilder
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment
@@ -68,18 +68,9 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
             logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
             return
 
-        # save vector index
-        index = IndexBuilder.get_index(dataset, 'high_quality')
-        if index:
-            index.add_texts([document], duplicate_check=True)
-
-        # save keyword index
-        index = IndexBuilder.get_index(dataset, 'economy')
-        if index:
-            if keywords and len(keywords) > 0:
-                index.create_segment_keywords(segment.index_node_id, keywords)
-            else:
-                index.add_texts([document])
+        index_type = dataset.doc_form
+        index_processor = IndexProcessorFactory(index_type).init_index_processor()
+        index_processor.load(dataset, [document])
 
         # update segment to completed
         update_params = {
diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py
index 008f122a8..3827d62fb 100644
--- a/api/tasks/deal_dataset_vector_index_task.py
+++ b/api/tasks/deal_dataset_vector_index_task.py
@@ -3,9 +3,9 @@ import time
 
 import click
 from celery import shared_task
-from langchain.schema import Document
 
-from core.index.index import IndexBuilder
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.dataset import Dataset, DocumentSegment
 from models.dataset import Document as DatasetDocument
@@ -29,10 +29,10 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
 
         if not dataset:
             raise Exception('Dataset not found')
-
+        index_type = dataset.doc_form
+        index_processor = IndexProcessorFactory(index_type).init_index_processor()
         if action == "remove":
-            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
-            index.delete_by_group_id(dataset.id)
+            index_processor.clean(dataset, None, with_keywords=False)
         elif action == "add":
             dataset_documents = db.session.query(DatasetDocument).filter(
                 DatasetDocument.dataset_id == dataset_id,
@@ -42,8 +42,6 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
             ).all()
 
             if dataset_documents:
-                # save vector index
-                index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
                 documents = []
                 for dataset_document in dataset_documents:
                     # delete from vector index
@@ -65,7 +63,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
                     documents.append(document)
 
                 # save vector index
-                index.create(documents)
+                index_processor.load(dataset, documents, with_keywords=False)
 
         end_at = time.perf_counter()
         logging.info(
diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py
index 9c9b00a2f..d79286cf3 100644
--- a/api/tasks/delete_segment_from_index_task.py
+++ b/api/tasks/delete_segment_from_index_task.py
@@ -4,7 +4,7 @@ import time
 import click
 from celery import shared_task
 
-from core.index.index import IndexBuilder
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset, Document
@@ -39,15 +39,9 @@ def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_
             logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment_id), fg='cyan'))
             return
 
-        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
-        kw_index = IndexBuilder.get_index(dataset, 'economy')
-
-        # delete from vector index
-        if vector_index:
-            vector_index.delete_by_ids([index_node_id])
-
-        # delete from keyword index
-        kw_index.delete_by_ids([index_node_id])
+        index_type = dataset_document.doc_form
+        index_processor = IndexProcessorFactory(index_type).init_index_processor()
+        index_processor.clean(dataset, [index_node_id])
 
         end_at = time.perf_counter()
         logging.info(click.style('Segment deleted from index: {} latency: {}'.format(segment_id, end_at - start_at), fg='green'))
diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py
index 97f4fd067..4788bf4e4 100644
--- a/api/tasks/disable_segment_from_index_task.py
+++ b/api/tasks/disable_segment_from_index_task.py
@@ -5,7 +5,7 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound
 
-from core.index.index import IndexBuilder
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment
@@ -48,15 +48,9 @@ def disable_segment_from_index_task(segment_id: str):
             logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
             return
 
-        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
-        kw_index = IndexBuilder.get_index(dataset, 'economy')
-
-        # delete from vector index
-        if vector_index:
-            vector_index.delete_by_ids([segment.index_node_id])
-
-        # delete from keyword index
-        kw_index.delete_by_ids([segment.index_node_id])
+        index_type = dataset_document.doc_form
+        index_processor = IndexProcessorFactory(index_type).init_index_processor()
+        index_processor.clean(dataset, [segment.index_node_id])
 
         end_at = time.perf_counter()
         logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py
index 57f080e3f..84e202970 100644
--- a/api/tasks/document_indexing_sync_task.py
+++ b/api/tasks/document_indexing_sync_task.py
@@ -6,9 +6,9 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound
 
-from core.data_loader.loader.notion import NotionLoader
-from core.index.index import IndexBuilder
 from core.indexing_runner import DocumentIsPausedException, IndexingRunner
+from core.rag.extractor.notion_extractor import NotionExtractor
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
 from models.source import DataSourceBinding
@@ -54,11 +54,11 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
     if not data_source_binding:
         raise ValueError('Data source binding not found.')
 
-    loader = NotionLoader(
-        notion_access_token=data_source_binding.access_token,
+    loader = NotionExtractor(
         notion_workspace_id=workspace_id,
         notion_obj_id=page_id,
-        notion_page_type=page_type
+        notion_page_type=page_type,
+        notion_access_token=data_source_binding.access_token
     )
 
     last_edited_time = loader.get_notion_last_edited_time()
@@ -74,20 +74,14 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
             dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
             if not dataset:
                 raise Exception('Dataset not found')
-
-            vector_index = IndexBuilder.get_index(dataset, 'high_quality')
-            kw_index = IndexBuilder.get_index(dataset, 'economy')
+            index_type = document.doc_form
+            index_processor = IndexProcessorFactory(index_type).init_index_processor()
 
             segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
             index_node_ids = [segment.index_node_id for segment in segments]
 
             # delete from vector index
-            if vector_index:
-                vector_index.delete_by_document_id(document_id)
-
-            # delete from keyword index
-            if index_node_ids:
-                kw_index.delete_by_ids(index_node_ids)
+            index_processor.clean(dataset, index_node_ids)
 
             for segment in segments:
                 db.session.delete(segment)
diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py
index 12014799b..e59c549a6 100644
--- a/api/tasks/document_indexing_update_task.py
+++ b/api/tasks/document_indexing_update_task.py
@@ -6,8 +6,8 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound
 
-from core.index.index import IndexBuilder
 from core.indexing_runner import DocumentIsPausedException, IndexingRunner
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
 
@@ -42,19 +42,14 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
         if not dataset:
             raise Exception('Dataset not found')
 
-        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
-        kw_index = IndexBuilder.get_index(dataset, 'economy')
+        index_type = document.doc_form
+        index_processor = IndexProcessorFactory(index_type).init_index_processor()
 
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
         index_node_ids = [segment.index_node_id for segment in segments]
 
         # delete from vector index
-        if vector_index:
-            vector_index.delete_by_ids(index_node_ids)
-
-        # delete from keyword index
-        if index_node_ids:
-            kw_index.delete_by_ids(index_node_ids)
+        index_processor.clean(dataset, index_node_ids)
 
         for segment in segments:
             db.session.delete(segment)
diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py
index 8dffd0152..a6254a822 100644
--- a/api/tasks/enable_segment_to_index_task.py
+++ b/api/tasks/enable_segment_to_index_task.py
@@ -4,10 +4,10 @@ import time
 
 import click
 from celery import shared_task
-from langchain.schema import Document
 from werkzeug.exceptions import NotFound
 
-from core.index.index import IndexBuilder
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment
@@ -60,15 +60,9 @@ def enable_segment_to_index_task(segment_id: str):
             logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
             return
 
+        index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
         # save vector index
-        index = IndexBuilder.get_index(dataset, 'high_quality')
-        if index:
-            index.add_texts([document], duplicate_check=True)
-
-        # save keyword index
-        index = IndexBuilder.get_index(dataset, 'economy')
-        if index:
-            index.add_texts([document])
+        index_processor.load(dataset, [document])
 
         end_at = time.perf_counter()
         logging.info(click.style('Segment enabled to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py
index a18842a59..514712991 100644
--- a/api/tasks/remove_document_from_index_task.py
+++ b/api/tasks/remove_document_from_index_task.py
@@ -5,7 +5,7 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound
 
-from core.index.index import IndexBuilder
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import Document, DocumentSegment
@@ -37,18 +37,12 @@ def remove_document_from_index_task(document_id: str):
         if not dataset:
             raise Exception('Document has no dataset')
 
-        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
-        kw_index = IndexBuilder.get_index(dataset, 'economy')
+        index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
 
-        # delete from vector index
-        if vector_index:
-            vector_index.delete_by_document_id(document.id)
-
-        # delete from keyword index
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all()
         index_node_ids = [segment.index_node_id for segment in segments]
         if index_node_ids:
-            kw_index.delete_by_ids(index_node_ids)
+            index_processor.clean(dataset, index_node_ids)
 
         end_at = time.perf_counter()
         logging.info(
diff --git a/api/tasks/update_segment_index_task.py b/api/tasks/update_segment_index_task.py
deleted file mode 100644
index 802bac885..000000000
--- a/api/tasks/update_segment_index_task.py
+++ /dev/null
@@ -1,114 +0,0 @@
-import datetime
-import logging
-import time
-from typing import Optional
-
-import click
-from celery import shared_task
-from langchain.schema import Document
-from werkzeug.exceptions import NotFound
-
-from core.index.index import IndexBuilder
-from extensions.ext_database import db
-from extensions.ext_redis import redis_client
-from models.dataset import DocumentSegment
-
-
-@shared_task(queue='dataset')
-def update_segment_index_task(segment_id: str, keywords: Optional[list[str]] = None):
-    """
-    Async update segment index
-    :param segment_id:
-    :param keywords:
-    Usage: update_segment_index_task.delay(segment_id)
-    """
-    logging.info(click.style('Start update segment index: {}'.format(segment_id), fg='green'))
-    start_at = time.perf_counter()
-
-    segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
-    if not segment:
-        raise NotFound('Segment not found')
-
-    if segment.status != 'updating':
-        return
-
-    indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
-
-    try:
-        dataset = segment.dataset
-
-        if not dataset:
-            logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
-            return
-
-        dataset_document = segment.document
-
-        if not dataset_document:
-            logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
-            return
-
-        if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
-            logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
-            return
-
-        # update segment status to indexing
-        update_params = {
-            DocumentSegment.status: "indexing",
-            DocumentSegment.indexing_at: datetime.datetime.utcnow()
-        }
-        DocumentSegment.query.filter_by(id=segment.id).update(update_params)
-        db.session.commit()
-
-        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
-        kw_index = IndexBuilder.get_index(dataset, 'economy')
-
-        # delete from vector index
-        if vector_index:
-            vector_index.delete_by_ids([segment.index_node_id])
-
-        # delete from keyword index
-        kw_index.delete_by_ids([segment.index_node_id])
-
-        # add new index
-        document = Document(
-            page_content=segment.content,
-            metadata={
-                "doc_id": segment.index_node_id,
-                "doc_hash": segment.index_node_hash,
-                "document_id": segment.document_id,
-                "dataset_id": segment.dataset_id,
-            }
-        )
-
-        # save vector index
-        index = IndexBuilder.get_index(dataset, 'high_quality')
-        if index:
-            index.add_texts([document], duplicate_check=True)
-
-        # save keyword index
-        index = IndexBuilder.get_index(dataset, 'economy')
-        if index:
-            if keywords and len(keywords) > 0:
-                index.create_segment_keywords(segment.index_node_id, keywords)
-            else:
-                index.add_texts([document])
-
-        # update segment to completed
-        update_params = {
-            DocumentSegment.status: "completed",
-            DocumentSegment.completed_at: datetime.datetime.utcnow()
-        }
-        DocumentSegment.query.filter_by(id=segment.id).update(update_params)
-        db.session.commit()
-
-        end_at = time.perf_counter()
-        logging.info(click.style('Segment update index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
-    except Exception as e:
-        logging.exception("update segment index failed")
-        segment.enabled = False
-        segment.disabled_at = datetime.datetime.utcnow()
-        segment.status = 'error'
-        segment.error = str(e)
-        db.session.commit()
-    finally:
-        redis_client.delete(indexing_cache_key)
diff --git a/api/tasks/update_segment_keyword_index_task.py b/api/tasks/update_segment_keyword_index_task.py
deleted file mode 100644
index ee88beba9..000000000
--- a/api/tasks/update_segment_keyword_index_task.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import datetime
-import logging
-import time
-
-import click
-from celery import shared_task
-from werkzeug.exceptions import NotFound
-
-from core.index.index import IndexBuilder
-from extensions.ext_database import db
-from extensions.ext_redis import redis_client
-from models.dataset import DocumentSegment
-
-
-@shared_task(queue='dataset')
-def update_segment_keyword_index_task(segment_id: str):
-    """
-    Async update segment index
-    :param segment_id:
-    Usage: update_segment_keyword_index_task.delay(segment_id)
-    """
-    logging.info(click.style('Start update segment keyword index: {}'.format(segment_id), fg='green'))
-    start_at = time.perf_counter()
-
-    segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
-    if not segment:
-        raise NotFound('Segment not found')
-
-    indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
-
-    try:
-        dataset = segment.dataset
-
-        if not dataset:
-            logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
-            return
-
-        dataset_document = segment.document
-
-        if not dataset_document:
-            logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
-            return
-
-        if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
-            logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
-            return
-
-        kw_index = IndexBuilder.get_index(dataset, 'economy')
-
-        # delete from keyword index
-        kw_index.delete_by_ids([segment.index_node_id])
-
-        # save keyword index
-        index = IndexBuilder.get_index(dataset, 'economy')
-        if index:
-            index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
-
-        end_at = time.perf_counter()
-        logging.info(click.style('Segment update index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
-    except Exception as e:
-        logging.exception("update segment index failed")
-        segment.enabled = False
-        segment.disabled_at = datetime.datetime.utcnow()
-        segment.status = 'error'
-        segment.error = str(e)
-        db.session.commit()
-    finally:
-        redis_client.delete(indexing_cache_key)
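
For review context: every task touched above converges on the same replacement pattern. An index processor is built from the dataset's doc_form via IndexProcessorFactory, clean() is called where the old code deleted from the separate vector and keyword indexes, and load() is called where it previously used add_texts(). Below is a minimal sketch of that shared pattern, assuming only the modules introduced by this patch; reindex_segment is a hypothetical helper name used for illustration and is not part of the change set.

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document


def reindex_segment(dataset, segment):
    # Illustrative helper, not part of this patch.
    # doc_form decides which concrete index processor the factory builds.
    index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()

    # clean() drops the segment's node from the underlying indexes,
    # replacing the old vector_index.delete_by_ids() / kw_index.delete_by_ids() pair.
    index_processor.clean(dataset, [segment.index_node_id])

    # load() re-adds the rebuilt document, replacing the old add_texts() calls.
    document = Document(
        page_content=segment.content,
        metadata={
            "doc_id": segment.index_node_id,
            "doc_hash": segment.index_node_hash,
            "document_id": segment.document_id,
            "dataset_id": segment.dataset_id,
        }
    )
    index_processor.load(dataset, [document])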