Feat/delete single dataset retrival (#6570)

This commit is contained in:
Jyong
2024-07-24 12:50:11 +08:00
committed by GitHub
parent 0fb741f269
commit e4bb943fe5
22 changed files with 651 additions and 115 deletions

View File

@@ -1,4 +1,5 @@
import re
from typing import Optional
import jieba
from jieba.analyse import default_tfidf
@@ -11,7 +12,7 @@ class JiebaKeywordTableHandler:
def __init__(self):
default_tfidf.stop_words = STOPWORDS
def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> set[str]:
def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
"""Extract keywords with JIEBA tfidf."""
keywords = jieba.analyse.extract_tags(
sentence=text,

View File

@@ -6,6 +6,7 @@ from flask import Flask, current_app
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.rerank.constants.rerank_mode import RerankMode
from core.rag.retrieval.retrival_methods import RetrievalMethod
from extensions.ext_database import db
from models.dataset import Dataset
@@ -26,13 +27,19 @@ class RetrievalService:
@classmethod
def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None):
top_k: int, score_threshold: Optional[float] = .0,
reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = None,
weights: Optional[dict] = None):
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return []
all_documents = []
keyword_search_documents = []
embedding_search_documents = []
full_text_search_documents = []
hybrid_search_documents = []
threads = []
exceptions = []
# retrieval_model source with keyword
@@ -87,7 +94,8 @@ class RetrievalService:
raise Exception(exception_message)
if retrival_method == RetrievalMethod.HYBRID_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode,
reranking_model, weights, False)
all_documents = data_post_processor.invoke(
query=query,
documents=all_documents,
@@ -143,7 +151,9 @@ class RetrievalService:
if documents:
if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
data_post_processor = DataPostProcessor(str(dataset.tenant_id),
RerankMode.RERANKING_MODEL.value,
reranking_model, None, False)
all_documents.extend(data_post_processor.invoke(
query=query,
documents=documents,
@@ -175,7 +185,9 @@ class RetrievalService:
)
if documents:
if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
data_post_processor = DataPostProcessor(str(dataset.tenant_id),
RerankMode.RERANKING_MODEL.value,
reranking_model, None, False)
all_documents.extend(data_post_processor.invoke(
query=query,
documents=documents,

View File

@@ -396,9 +396,11 @@ class QdrantVector(BaseVector):
documents = []
for result in results:
if result:
documents.append(self._document_from_scored_point(
document = self._document_from_scored_point(
result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
))
)
document.metadata['vector'] = result.vector
documents.append(document)
return documents