Feat/delete single dataset retrival (#6570)

Jyong
2024-07-24 12:50:11 +08:00
committed by GitHub
parent 0fb741f269
commit e4bb943fe5
22 changed files with 651 additions and 115 deletions


@@ -1,4 +1,6 @@
+import math
 import threading
+from collections import Counter
 from typing import Optional, cast
 
 from flask import Flask, current_app
@@ -14,9 +16,10 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
 from core.ops.utils import measure_time
+from core.rag.data_post_processor.data_post_processor import DataPostProcessor
+from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.models.document import Document
-from core.rag.rerank.rerank import RerankRunner
 from core.rag.retrieval.retrival_methods import RetrievalMethod
 from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
 from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
@@ -132,8 +135,9 @@ class DatasetRetrieval:
                     app_id, tenant_id, user_id, user_from,
                     available_datasets, query, retrieve_config.top_k,
                     retrieve_config.score_threshold,
-                    retrieve_config.reranking_model.get('reranking_provider_name'),
-                    retrieve_config.reranking_model.get('reranking_model_name'),
+                    retrieve_config.rerank_mode,
+                    retrieve_config.reranking_model,
+                    retrieve_config.weights,
                     message_id,
                 )
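The call site no longer flattens the reranking config into provider/model strings; it forwards the rerank mode, the whole reranking_model dict, and the new weights setting. For reference, the keys the removed lines were reading imply a reranking_model dict of roughly this shape (the rerank_mode string is an assumed example value, consistent with the model-based branch later in this diff):

    # 'reranking_provider_name' / 'reranking_model_name' are confirmed by the
    # removed lines above; the provider/model values are illustrative only.
    rerank_mode = 'reranking_model'
    reranking_model = {
        'reranking_provider_name': 'cohere',
        'reranking_model_name': 'rerank-english-v2.0',
    }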
@@ -272,7 +276,8 @@ class DatasetRetrieval:
                 retrival_method=retrival_method, dataset_id=dataset.id,
                 query=query,
                 top_k=top_k, score_threshold=score_threshold,
-                reranking_model=reranking_model
+                reranking_model=reranking_model,
+                weights=retrieval_model_config.get('weights', None),
             )
             self._on_query(query, [dataset_id], app_id, user_from, user_id)
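On the single-dataset path, weights is read straight off the dataset's retrieval model config with a None default, so older configs keep working. The weights schema itself is not shown in this hunk; a plausible weighted-score config, with nested key names that are assumptions for illustration, might look like:

    # Only the top-level 'weights' key is confirmed by this diff; the nested
    # vector/keyword settings below are illustrative assumptions.
    retrieval_model_config = {
        'search_method': 'hybrid_search',
        'reranking_enable': False,
        'top_k': 4,
        'score_threshold_enabled': True,
        'score_threshold': 0.5,
        'weights': {
            'vector_setting': {'vector_weight': 0.7},
            'keyword_setting': {'keyword_weight': 0.3},
        },
    }

    weights = retrieval_model_config.get('weights', None)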
@@ -292,14 +297,18 @@ class DatasetRetrieval:
                           query: str,
                           top_k: int,
                           score_threshold: float,
-                          reranking_provider_name: str,
-                          reranking_model_name: str,
+                          reranking_mode: str,
+                          reranking_model: Optional[dict] = None,
+                          weights: Optional[dict] = None,
+                          reranking_enable: bool = True,
                           message_id: Optional[str] = None,
                           ):
         threads = []
         all_documents = []
         dataset_ids = [dataset.id for dataset in available_datasets]
+        index_type = None
         for dataset in available_datasets:
+            index_type = dataset.indexing_technique
             retrieval_thread = threading.Thread(target=self._retriever, kwargs={
                 'flask_app': current_app._get_current_object(),
                 'dataset_id': dataset.id,
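multiple_retrieve still fans retrieval out with one thread per dataset and joins them all before any post-processing; the new index_type assignment just records the indexing technique of the datasets as they are scheduled. A minimal, self-contained sketch of that fan-out pattern (illustrative names, not Dify's API):

    import threading

    def fan_out_retrieve(dataset_ids: list[str], query: str) -> list[str]:
        # Stand-in for multiple_retrieve's threading structure, with a trivial
        # worker in place of DatasetRetrieval._retriever.
        all_documents: list[str] = []
        lock = threading.Lock()

        def worker(dataset_id: str) -> None:
            results = [f"{dataset_id}: match for {query!r}"]  # placeholder retrieval
            with lock:  # guard the shared result list across workers
                all_documents.extend(results)

        threads = [threading.Thread(target=worker, args=(d,)) for d in dataset_ids]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        return all_documents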
@@ -311,23 +320,24 @@ class DatasetRetrieval:
             retrieval_thread.start()
         for thread in threads:
             thread.join()
 
-        # do rerank for searched documents
-        model_manager = ModelManager()
-        rerank_model_instance = model_manager.get_model_instance(
-            tenant_id=tenant_id,
-            provider=reranking_provider_name,
-            model_type=ModelType.RERANK,
-            model=reranking_model_name
-        )
-        rerank_runner = RerankRunner(rerank_model_instance)
+        if reranking_enable:
+            # do rerank for searched documents
+            data_post_processor = DataPostProcessor(tenant_id, reranking_mode,
+                                                    reranking_model, weights, False)
-        with measure_time() as timer:
-            all_documents = rerank_runner.run(
-                query, all_documents,
-                score_threshold,
-                top_k
-            )
+            with measure_time() as timer:
+                all_documents = data_post_processor.invoke(
+                    query=query,
+                    documents=all_documents,
+                    score_threshold=score_threshold,
+                    top_n=top_k
+                )
+        else:
+            if index_type == "economy":
+                all_documents = self.calculate_keyword_score(query, all_documents, top_k)
+            elif index_type == "high_quality":
+                all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)
 
         self._on_query(query, dataset_ids, app_id, user_from, user_id)
 
         if all_documents:
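Stripped of Dify's classes, the new branch reduces to a three-way dispatch: rerank through DataPostProcessor when reranking is enabled, otherwise score locally according to the index type. A simplified sketch, with callables standing in for DataPostProcessor.invoke, calculate_keyword_score, and calculate_vector_score:

    from collections.abc import Callable
    from typing import Optional

    def post_process(documents: list,
                     reranking_enable: bool,
                     index_type: Optional[str],
                     rerank: Callable[[list], list],
                     keyword_score: Callable[[list], list],
                     vector_score: Callable[[list], list]) -> list:
        if reranking_enable:
            # model-based rerank or weighted-score fusion (DataPostProcessor)
            return rerank(documents)
        if index_type == "economy":
            return keyword_score(documents)  # TF-IDF over extracted keywords
        if index_type == "high_quality":
            return vector_score(documents)   # threshold + sort on vector scores
        return documents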
@@ -420,7 +430,8 @@ class DatasetRetrieval:
                     score_threshold=retrieval_model['score_threshold']
                     if retrieval_model['score_threshold_enabled'] else None,
                     reranking_model=retrieval_model['reranking_model']
-                    if retrieval_model['reranking_enable'] else None
+                    if retrieval_model['reranking_enable'] else None,
+                    weights=retrieval_model.get('weights', None),
                 )
                 all_documents.extend(documents)
@@ -513,3 +524,94 @@ class DatasetRetrieval:
             tools.append(tool)
 
         return tools
+
+    def calculate_keyword_score(self, query: str, documents: list[Document], top_k: int) -> list[Document]:
+        """
+        Calculate keyword-based TF-IDF scores for documents against the query.
+        :param query: search query
+        :param documents: documents for reranking
+        :param top_k: number of top-scored documents to return
+        :return: documents sorted by TF-IDF cosine similarity to the query
+        """
+        keyword_table_handler = JiebaKeywordTableHandler()
+        query_keywords = keyword_table_handler.extract_keywords(query, None)
+        documents_keywords = []
+        for document in documents:
+            # get the document keywords
+            document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
+            document.metadata['keywords'] = document_keywords
+            documents_keywords.append(document_keywords)
+
+        # count query keywords (term frequency)
+        query_keyword_counts = Counter(query_keywords)
+
+        # total number of documents
+        total_documents = len(documents)
+
+        # calculate the IDF of every keyword across all documents
+        all_keywords = set()
+        for document_keywords in documents_keywords:
+            all_keywords.update(document_keywords)
+
+        keyword_idf = {}
+        for keyword in all_keywords:
+            # count the documents containing this keyword
+            doc_count_containing_keyword = sum(1 for doc_keywords in documents_keywords if keyword in doc_keywords)
+            # smoothed IDF
+            keyword_idf[keyword] = math.log((1 + total_documents) / (1 + doc_count_containing_keyword)) + 1
+
+        query_tfidf = {}
+        for keyword, count in query_keyword_counts.items():
+            tf = count
+            idf = keyword_idf.get(keyword, 0)
+            query_tfidf[keyword] = tf * idf
+
+        # calculate every document's TF-IDF vector
+        documents_tfidf = []
+        for document_keywords in documents_keywords:
+            document_keyword_counts = Counter(document_keywords)
+            document_tfidf = {}
+            for keyword, count in document_keyword_counts.items():
+                tf = count
+                idf = keyword_idf.get(keyword, 0)
+                document_tfidf[keyword] = tf * idf
+            documents_tfidf.append(document_tfidf)
+
+        def cosine_similarity(vec1, vec2):
+            intersection = set(vec1.keys()) & set(vec2.keys())
+            numerator = sum(vec1[x] * vec2[x] for x in intersection)
+
+            sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
+            sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
+            denominator = math.sqrt(sum1) * math.sqrt(sum2)
+
+            if not denominator:
+                return 0.0
+            else:
+                return float(numerator) / denominator
+
+        similarities = []
+        for document_tfidf in documents_tfidf:
+            similarity = cosine_similarity(query_tfidf, document_tfidf)
+            similarities.append(similarity)
+
+        for document, score in zip(documents, similarities):
+            # attach the keyword score to the document metadata
+            document.metadata['score'] = score
+
+        documents = sorted(documents, key=lambda x: x.metadata['score'], reverse=True)
+        return documents[:top_k] if top_k else documents
+
+    def calculate_vector_score(self, all_documents: list[Document],
+                               top_k: int, score_threshold: float) -> list[Document]:
+        filter_documents = []
+        for document in all_documents:
+            if document.metadata['score'] >= score_threshold:
+                filter_documents.append(document)
+        if not filter_documents:
+            return []
+        filter_documents = sorted(filter_documents, key=lambda x: x.metadata['score'], reverse=True)
+        return filter_documents[:top_k] if top_k else filter_documents
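To make the new keyword scoring concrete, here is a self-contained toy run of the same TF-IDF and cosine math on plain strings, with whitespace tokenization standing in for jieba keyword extraction:

    import math
    from collections import Counter

    def tfidf_cosine_scores(query: str, docs: list[str]) -> list[float]:
        # Toy version of calculate_keyword_score's math on whitespace tokens.
        docs_keywords = [d.split() for d in docs]
        total = len(docs)
        all_kw = {kw for kws in docs_keywords for kw in kws}
        # same smoothed IDF formula as the diff
        idf = {kw: math.log((1 + total) / (1 + sum(kw in kws for kws in docs_keywords))) + 1
               for kw in all_kw}

        def tfidf(tokens: list[str]) -> dict[str, float]:
            return {kw: tf * idf.get(kw, 0) for kw, tf in Counter(tokens).items()}

        def cosine(a: dict[str, float], b: dict[str, float]) -> float:
            num = sum(a[k] * b[k] for k in a.keys() & b.keys())
            den = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
            return num / den if den else 0.0

        q = tfidf(query.split())
        return [cosine(q, tfidf(kws)) for kws in docs_keywords]

    scores = tfidf_cosine_scores("vector database", ["vector database index", "keyword table"])
    assert scores[0] > scores[1]  # the document sharing query terms scores higher

calculate_vector_score, by contrast, introduces no new math: it drops documents whose metadata['score'] falls below the threshold, sorts the survivors in descending score order, and truncates to top_k.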