mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-09 19:06:51 +08:00
Feat/delete single dataset retrival (#6570)
This commit is contained in:
@@ -62,7 +62,12 @@ class DatasetConfigManager:
|
||||
return None
|
||||
|
||||
# dataset configs
|
||||
dataset_configs = config.get('dataset_configs', {'retrieval_model': 'single'})
|
||||
if 'dataset_configs' in config and config.get('dataset_configs'):
|
||||
dataset_configs = config.get('dataset_configs')
|
||||
else:
|
||||
dataset_configs = {
|
||||
'retrieval_model': 'multiple'
|
||||
}
|
||||
query_variable = config.get('dataset_query_variable')
|
||||
|
||||
if dataset_configs['retrieval_model'] == 'single':
|
||||
@@ -83,9 +88,10 @@ class DatasetConfigManager:
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||
dataset_configs['retrieval_model']
|
||||
),
|
||||
top_k=dataset_configs.get('top_k'),
|
||||
top_k=dataset_configs.get('top_k', 4),
|
||||
score_threshold=dataset_configs.get('score_threshold'),
|
||||
reranking_model=dataset_configs.get('reranking_model')
|
||||
reranking_model=dataset_configs.get('reranking_model'),
|
||||
weights=dataset_configs.get('weights')
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -159,7 +159,11 @@ class DatasetRetrieveConfigEntity(BaseModel):
|
||||
retrieve_strategy: RetrieveStrategy
|
||||
top_k: Optional[int] = None
|
||||
score_threshold: Optional[float] = None
|
||||
rerank_mode: Optional[str] = 'reranking_model'
|
||||
reranking_model: Optional[dict] = None
|
||||
weights: Optional[dict] = None
|
||||
|
||||
|
||||
|
||||
|
||||
class DatasetEntity(BaseModel):
|
||||
|
||||
@@ -5,15 +5,20 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.rag.data_post_processor.reorder import ReorderRunner
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.rerank.rerank import RerankRunner
|
||||
from core.rag.rerank.constants.rerank_mode import RerankMode
|
||||
from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights
|
||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||
from core.rag.rerank.weight_rerank import WeightRerankRunner
|
||||
|
||||
|
||||
class DataPostProcessor:
|
||||
"""Interface for data post-processing document.
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_id: str, reranking_model: dict, reorder_enabled: bool = False):
|
||||
self.rerank_runner = self._get_rerank_runner(reranking_model, tenant_id)
|
||||
def __init__(self, tenant_id: str, reranking_mode: str,
|
||||
reranking_model: Optional[dict] = None, weights: Optional[dict] = None,
|
||||
reorder_enabled: bool = False):
|
||||
self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights)
|
||||
self.reorder_runner = self._get_reorder_runner(reorder_enabled)
|
||||
|
||||
def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
|
||||
@@ -26,19 +31,37 @@ class DataPostProcessor:
|
||||
|
||||
return documents
|
||||
|
||||
def _get_rerank_runner(self, reranking_model: dict, tenant_id: str) -> Optional[RerankRunner]:
|
||||
if reranking_model:
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=reranking_model['reranking_provider_name'],
|
||||
model_type=ModelType.RERANK,
|
||||
model=reranking_model['reranking_model_name']
|
||||
def _get_rerank_runner(self, reranking_mode: str, tenant_id: str, reranking_model: Optional[dict] = None,
|
||||
weights: Optional[dict] = None) -> Optional[RerankModelRunner | WeightRerankRunner]:
|
||||
if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights:
|
||||
return WeightRerankRunner(
|
||||
tenant_id,
|
||||
Weights(
|
||||
weight_type=weights['weight_type'],
|
||||
vector_setting=VectorSetting(
|
||||
vector_weight=weights['vector_setting']['vector_weight'],
|
||||
embedding_provider_name=weights['vector_setting']['embedding_provider_name'],
|
||||
embedding_model_name=weights['vector_setting']['embedding_model_name'],
|
||||
),
|
||||
keyword_setting=KeywordSetting(
|
||||
keyword_weight=weights['keyword_setting']['keyword_weight'],
|
||||
)
|
||||
)
|
||||
except InvokeAuthorizationError:
|
||||
return None
|
||||
return RerankRunner(rerank_model_instance)
|
||||
)
|
||||
elif reranking_mode == RerankMode.RERANKING_MODEL.value:
|
||||
if reranking_model:
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=reranking_model['reranking_provider_name'],
|
||||
model_type=ModelType.RERANK,
|
||||
model=reranking_model['reranking_model_name']
|
||||
)
|
||||
except InvokeAuthorizationError:
|
||||
return None
|
||||
return RerankModelRunner(rerank_model_instance)
|
||||
return None
|
||||
return None
|
||||
|
||||
def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
import jieba
|
||||
from jieba.analyse import default_tfidf
|
||||
@@ -11,7 +12,7 @@ class JiebaKeywordTableHandler:
|
||||
def __init__(self):
|
||||
default_tfidf.stop_words = STOPWORDS
|
||||
|
||||
def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> set[str]:
|
||||
def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
|
||||
"""Extract keywords with JIEBA tfidf."""
|
||||
keywords = jieba.analyse.extract_tags(
|
||||
sentence=text,
|
||||
|
||||
@@ -6,6 +6,7 @@ from flask import Flask, current_app
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.rerank.constants.rerank_mode import RerankMode
|
||||
from core.rag.retrieval.retrival_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
@@ -26,13 +27,19 @@ class RetrievalService:
|
||||
|
||||
@classmethod
|
||||
def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
|
||||
top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None):
|
||||
top_k: int, score_threshold: Optional[float] = .0,
|
||||
reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = None,
|
||||
weights: Optional[dict] = None):
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
|
||||
return []
|
||||
all_documents = []
|
||||
keyword_search_documents = []
|
||||
embedding_search_documents = []
|
||||
full_text_search_documents = []
|
||||
hybrid_search_documents = []
|
||||
threads = []
|
||||
exceptions = []
|
||||
# retrieval_model source with keyword
|
||||
@@ -87,7 +94,8 @@ class RetrievalService:
|
||||
raise Exception(exception_message)
|
||||
|
||||
if retrival_method == RetrievalMethod.HYBRID_SEARCH.value:
|
||||
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
|
||||
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode,
|
||||
reranking_model, weights, False)
|
||||
all_documents = data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=all_documents,
|
||||
@@ -143,7 +151,9 @@ class RetrievalService:
|
||||
|
||||
if documents:
|
||||
if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH.value:
|
||||
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
|
||||
data_post_processor = DataPostProcessor(str(dataset.tenant_id),
|
||||
RerankMode.RERANKING_MODEL.value,
|
||||
reranking_model, None, False)
|
||||
all_documents.extend(data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=documents,
|
||||
@@ -175,7 +185,9 @@ class RetrievalService:
|
||||
)
|
||||
if documents:
|
||||
if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH.value:
|
||||
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
|
||||
data_post_processor = DataPostProcessor(str(dataset.tenant_id),
|
||||
RerankMode.RERANKING_MODEL.value,
|
||||
reranking_model, None, False)
|
||||
all_documents.extend(data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=documents,
|
||||
|
||||
@@ -396,9 +396,11 @@ class QdrantVector(BaseVector):
|
||||
documents = []
|
||||
for result in results:
|
||||
if result:
|
||||
documents.append(self._document_from_scored_point(
|
||||
document = self._document_from_scored_point(
|
||||
result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
|
||||
))
|
||||
)
|
||||
document.metadata['vector'] = result.vector
|
||||
documents.append(document)
|
||||
|
||||
return documents
|
||||
|
||||
|
||||
0
api/core/rag/docstore/__init__.py
Normal file
0
api/core/rag/docstore/__init__.py
Normal file
8
api/core/rag/rerank/constants/rerank_mode.py
Normal file
8
api/core/rag/rerank/constants/rerank_mode.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class RerankMode(Enum):
|
||||
|
||||
RERANKING_MODEL = 'reranking_model'
|
||||
WEIGHTED_SCORE = 'weighted_score'
|
||||
|
||||
23
api/core/rag/rerank/entity/weight.py
Normal file
23
api/core/rag/rerank/entity/weight.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class VectorSetting(BaseModel):
|
||||
vector_weight: float
|
||||
|
||||
embedding_provider_name: str
|
||||
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class KeywordSetting(BaseModel):
|
||||
keyword_weight: float
|
||||
|
||||
|
||||
class Weights(BaseModel):
|
||||
"""Model for weighted rerank."""
|
||||
|
||||
weight_type: str
|
||||
|
||||
vector_setting: VectorSetting
|
||||
|
||||
keyword_setting: KeywordSetting
|
||||
@@ -4,7 +4,7 @@ from core.model_manager import ModelInstance
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class RerankRunner:
|
||||
class RerankModelRunner:
|
||||
def __init__(self, rerank_model_instance: ModelInstance) -> None:
|
||||
self.rerank_model_instance = rerank_model_instance
|
||||
|
||||
178
api/core/rag/rerank/weight_rerank.py
Normal file
178
api/core/rag/rerank/weight_rerank.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import math
|
||||
from collections import Counter
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.rerank.entity.weight import VectorSetting, Weights
|
||||
|
||||
|
||||
class WeightRerankRunner:
|
||||
|
||||
def __init__(self, tenant_id: str, weights: Weights) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.weights = weights
|
||||
|
||||
def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
|
||||
"""
|
||||
Run rerank model
|
||||
:param query: search query
|
||||
:param documents: documents for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id if needed
|
||||
|
||||
:return:
|
||||
"""
|
||||
docs = []
|
||||
doc_id = []
|
||||
unique_documents = []
|
||||
for document in documents:
|
||||
if document.metadata['doc_id'] not in doc_id:
|
||||
doc_id.append(document.metadata['doc_id'])
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
|
||||
documents = unique_documents
|
||||
|
||||
rerank_documents = []
|
||||
query_scores = self._calculate_keyword_score(query, documents)
|
||||
|
||||
query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting)
|
||||
for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores):
|
||||
# format document
|
||||
score = self.weights.vector_setting.vector_weight * query_vector_score + \
|
||||
self.weights.keyword_setting.keyword_weight * query_score
|
||||
if score_threshold and score < score_threshold:
|
||||
continue
|
||||
document.metadata['score'] = score
|
||||
rerank_documents.append(document)
|
||||
rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata['score'], reverse=True)
|
||||
return rerank_documents[:top_n] if top_n else rerank_documents
|
||||
|
||||
def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:
|
||||
"""
|
||||
Calculate BM25 scores
|
||||
:param query: search query
|
||||
:param documents: documents for reranking
|
||||
|
||||
:return:
|
||||
"""
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
query_keywords = keyword_table_handler.extract_keywords(query, None)
|
||||
documents_keywords = []
|
||||
for document in documents:
|
||||
# get the document keywords
|
||||
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
|
||||
document.metadata['keywords'] = document_keywords
|
||||
documents_keywords.append(document_keywords)
|
||||
|
||||
# Counter query keywords(TF)
|
||||
query_keyword_counts = Counter(query_keywords)
|
||||
|
||||
# total documents
|
||||
total_documents = len(documents)
|
||||
|
||||
# calculate all documents' keywords IDF
|
||||
all_keywords = set()
|
||||
for document_keywords in documents_keywords:
|
||||
all_keywords.update(document_keywords)
|
||||
|
||||
keyword_idf = {}
|
||||
for keyword in all_keywords:
|
||||
# calculate include query keywords' documents
|
||||
doc_count_containing_keyword = sum(1 for doc_keywords in documents_keywords if keyword in doc_keywords)
|
||||
# IDF
|
||||
keyword_idf[keyword] = math.log((1 + total_documents) / (1 + doc_count_containing_keyword)) + 1
|
||||
|
||||
query_tfidf = {}
|
||||
|
||||
for keyword, count in query_keyword_counts.items():
|
||||
tf = count
|
||||
idf = keyword_idf.get(keyword, 0)
|
||||
query_tfidf[keyword] = tf * idf
|
||||
|
||||
# calculate all documents' TF-IDF
|
||||
documents_tfidf = []
|
||||
for document_keywords in documents_keywords:
|
||||
document_keyword_counts = Counter(document_keywords)
|
||||
document_tfidf = {}
|
||||
for keyword, count in document_keyword_counts.items():
|
||||
tf = count
|
||||
idf = keyword_idf.get(keyword, 0)
|
||||
document_tfidf[keyword] = tf * idf
|
||||
documents_tfidf.append(document_tfidf)
|
||||
|
||||
def cosine_similarity(vec1, vec2):
|
||||
intersection = set(vec1.keys()) & set(vec2.keys())
|
||||
numerator = sum(vec1[x] * vec2[x] for x in intersection)
|
||||
|
||||
sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
|
||||
sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
|
||||
denominator = math.sqrt(sum1) * math.sqrt(sum2)
|
||||
|
||||
if not denominator:
|
||||
return 0.0
|
||||
else:
|
||||
return float(numerator) / denominator
|
||||
|
||||
similarities = []
|
||||
for document_tfidf in documents_tfidf:
|
||||
similarity = cosine_similarity(query_tfidf, document_tfidf)
|
||||
similarities.append(similarity)
|
||||
|
||||
# for idx, similarity in enumerate(similarities):
|
||||
# print(f"Document {idx + 1} similarity: {similarity}")
|
||||
|
||||
return similarities
|
||||
|
||||
def _calculate_cosine(self, tenant_id: str, query: str, documents: list[Document],
|
||||
vector_setting: VectorSetting) -> list[float]:
|
||||
"""
|
||||
Calculate Cosine scores
|
||||
:param query: search query
|
||||
:param documents: documents for reranking
|
||||
|
||||
:return:
|
||||
"""
|
||||
query_vector_scores = []
|
||||
|
||||
model_manager = ModelManager()
|
||||
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=vector_setting.embedding_provider_name,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=vector_setting.embedding_model_name
|
||||
|
||||
)
|
||||
cache_embedding = CacheEmbedding(embedding_model)
|
||||
query_vector = cache_embedding.embed_query(query)
|
||||
for document in documents:
|
||||
# calculate cosine similarity
|
||||
if 'score' in document.metadata:
|
||||
query_vector_scores.append(document.metadata['score'])
|
||||
else:
|
||||
content_vector = document.metadata['vector']
|
||||
# transform to NumPy
|
||||
vec1 = np.array(query_vector)
|
||||
vec2 = np.array(document.metadata['vector'])
|
||||
|
||||
# calculate dot product
|
||||
dot_product = np.dot(vec1, vec2)
|
||||
|
||||
# calculate norm
|
||||
norm_vec1 = np.linalg.norm(vec1)
|
||||
norm_vec2 = np.linalg.norm(vec2)
|
||||
|
||||
# calculate cosine similarity
|
||||
cosine_sim = dot_product / (norm_vec1 * norm_vec2)
|
||||
query_vector_scores.append(cosine_sim)
|
||||
|
||||
return query_vector_scores
|
||||
@@ -1,4 +1,6 @@
|
||||
import math
|
||||
import threading
|
||||
from collections import Counter
|
||||
from typing import Optional, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
@@ -14,9 +16,10 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
|
||||
from core.ops.utils import measure_time
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.rerank.rerank import RerankRunner
|
||||
from core.rag.retrieval.retrival_methods import RetrievalMethod
|
||||
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
|
||||
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
|
||||
@@ -132,8 +135,9 @@ class DatasetRetrieval:
|
||||
app_id, tenant_id, user_id, user_from,
|
||||
available_datasets, query, retrieve_config.top_k,
|
||||
retrieve_config.score_threshold,
|
||||
retrieve_config.reranking_model.get('reranking_provider_name'),
|
||||
retrieve_config.reranking_model.get('reranking_model_name'),
|
||||
retrieve_config.rerank_mode,
|
||||
retrieve_config.reranking_model,
|
||||
retrieve_config.weights,
|
||||
message_id,
|
||||
)
|
||||
|
||||
@@ -272,7 +276,8 @@ class DatasetRetrieval:
|
||||
retrival_method=retrival_method, dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k, score_threshold=score_threshold,
|
||||
reranking_model=reranking_model
|
||||
reranking_model=reranking_model,
|
||||
weights=retrieval_model_config.get('weights', None),
|
||||
)
|
||||
self._on_query(query, [dataset_id], app_id, user_from, user_id)
|
||||
|
||||
@@ -292,14 +297,18 @@ class DatasetRetrieval:
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_provider_name: str,
|
||||
reranking_model_name: str,
|
||||
reranking_mode: str,
|
||||
reranking_model: Optional[dict] = None,
|
||||
weights: Optional[dict] = None,
|
||||
reranking_enable: bool = True,
|
||||
message_id: Optional[str] = None,
|
||||
):
|
||||
threads = []
|
||||
all_documents = []
|
||||
dataset_ids = [dataset.id for dataset in available_datasets]
|
||||
index_type = None
|
||||
for dataset in available_datasets:
|
||||
index_type = dataset.indexing_technique
|
||||
retrieval_thread = threading.Thread(target=self._retriever, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': dataset.id,
|
||||
@@ -311,23 +320,24 @@ class DatasetRetrieval:
|
||||
retrieval_thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
# do rerank for searched documents
|
||||
model_manager = ModelManager()
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=reranking_provider_name,
|
||||
model_type=ModelType.RERANK,
|
||||
model=reranking_model_name
|
||||
)
|
||||
|
||||
rerank_runner = RerankRunner(rerank_model_instance)
|
||||
if reranking_enable:
|
||||
# do rerank for searched documents
|
||||
data_post_processor = DataPostProcessor(tenant_id, reranking_mode,
|
||||
reranking_model, weights, False)
|
||||
|
||||
with measure_time() as timer:
|
||||
all_documents = rerank_runner.run(
|
||||
query, all_documents,
|
||||
score_threshold,
|
||||
top_k
|
||||
)
|
||||
with measure_time() as timer:
|
||||
all_documents = data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=all_documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k
|
||||
)
|
||||
else:
|
||||
if index_type == "economy":
|
||||
all_documents = self.calculate_keyword_score(query, all_documents, top_k)
|
||||
elif index_type == "high_quality":
|
||||
all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)
|
||||
self._on_query(query, dataset_ids, app_id, user_from, user_id)
|
||||
|
||||
if all_documents:
|
||||
@@ -420,7 +430,8 @@ class DatasetRetrieval:
|
||||
score_threshold=retrieval_model['score_threshold']
|
||||
if retrieval_model['score_threshold_enabled'] else None,
|
||||
reranking_model=retrieval_model['reranking_model']
|
||||
if retrieval_model['reranking_enable'] else None
|
||||
if retrieval_model['reranking_enable'] else None,
|
||||
weights=retrieval_model.get('weights', None),
|
||||
)
|
||||
|
||||
all_documents.extend(documents)
|
||||
@@ -513,3 +524,94 @@ class DatasetRetrieval:
|
||||
tools.append(tool)
|
||||
|
||||
return tools
|
||||
|
||||
def calculate_keyword_score(self, query: str, documents: list[Document], top_k: int) -> list[Document]:
|
||||
"""
|
||||
Calculate keywords scores
|
||||
:param query: search query
|
||||
:param documents: documents for reranking
|
||||
|
||||
:return:
|
||||
"""
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
query_keywords = keyword_table_handler.extract_keywords(query, None)
|
||||
documents_keywords = []
|
||||
for document in documents:
|
||||
# get the document keywords
|
||||
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
|
||||
document.metadata['keywords'] = document_keywords
|
||||
documents_keywords.append(document_keywords)
|
||||
|
||||
# Counter query keywords(TF)
|
||||
query_keyword_counts = Counter(query_keywords)
|
||||
|
||||
# total documents
|
||||
total_documents = len(documents)
|
||||
|
||||
# calculate all documents' keywords IDF
|
||||
all_keywords = set()
|
||||
for document_keywords in documents_keywords:
|
||||
all_keywords.update(document_keywords)
|
||||
|
||||
keyword_idf = {}
|
||||
for keyword in all_keywords:
|
||||
# calculate include query keywords' documents
|
||||
doc_count_containing_keyword = sum(1 for doc_keywords in documents_keywords if keyword in doc_keywords)
|
||||
# IDF
|
||||
keyword_idf[keyword] = math.log((1 + total_documents) / (1 + doc_count_containing_keyword)) + 1
|
||||
|
||||
query_tfidf = {}
|
||||
|
||||
for keyword, count in query_keyword_counts.items():
|
||||
tf = count
|
||||
idf = keyword_idf.get(keyword, 0)
|
||||
query_tfidf[keyword] = tf * idf
|
||||
|
||||
# calculate all documents' TF-IDF
|
||||
documents_tfidf = []
|
||||
for document_keywords in documents_keywords:
|
||||
document_keyword_counts = Counter(document_keywords)
|
||||
document_tfidf = {}
|
||||
for keyword, count in document_keyword_counts.items():
|
||||
tf = count
|
||||
idf = keyword_idf.get(keyword, 0)
|
||||
document_tfidf[keyword] = tf * idf
|
||||
documents_tfidf.append(document_tfidf)
|
||||
|
||||
def cosine_similarity(vec1, vec2):
|
||||
intersection = set(vec1.keys()) & set(vec2.keys())
|
||||
numerator = sum(vec1[x] * vec2[x] for x in intersection)
|
||||
|
||||
sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
|
||||
sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
|
||||
denominator = math.sqrt(sum1) * math.sqrt(sum2)
|
||||
|
||||
if not denominator:
|
||||
return 0.0
|
||||
else:
|
||||
return float(numerator) / denominator
|
||||
|
||||
similarities = []
|
||||
for document_tfidf in documents_tfidf:
|
||||
similarity = cosine_similarity(query_tfidf, document_tfidf)
|
||||
similarities.append(similarity)
|
||||
|
||||
for document, score in zip(documents, similarities):
|
||||
# format document
|
||||
document.metadata['score'] = score
|
||||
documents = sorted(documents, key=lambda x: x.metadata['score'], reverse=True)
|
||||
return documents[:top_k] if top_k else documents
|
||||
|
||||
def calculate_vector_score(self, all_documents: list[Document],
|
||||
top_k: int, score_threshold: float) -> list[Document]:
|
||||
filter_documents = []
|
||||
for document in all_documents:
|
||||
if document.metadata['score'] >= score_threshold:
|
||||
filter_documents.append(document)
|
||||
if not filter_documents:
|
||||
return []
|
||||
filter_documents = sorted(filter_documents, key=lambda x: x.metadata['score'], reverse=True)
|
||||
return filter_documents[:top_k] if top_k else filter_documents
|
||||
|
||||
|
||||
|
||||
|
||||
0
api/core/rag/splitter/__init__.py
Normal file
0
api/core/rag/splitter/__init__.py
Normal file
@@ -7,7 +7,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.rerank.rerank import RerankRunner
|
||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||
from core.rag.retrieval.retrival_methods import RetrievalMethod
|
||||
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
@@ -72,7 +72,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
model=self.reranking_model_name
|
||||
)
|
||||
|
||||
rerank_runner = RerankRunner(rerank_model_instance)
|
||||
rerank_runner = RerankModelRunner(rerank_model_instance)
|
||||
all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k)
|
||||
|
||||
for hit_callback in self.hit_callbacks:
|
||||
@@ -180,7 +180,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
score_threshold=retrieval_model['score_threshold']
|
||||
if retrieval_model['score_threshold_enabled'] else None,
|
||||
reranking_model=retrieval_model['reranking_model']
|
||||
if retrieval_model['reranking_enable'] else None
|
||||
if retrieval_model['reranking_enable'] else None,
|
||||
weights=retrieval_model.get('weights', None),
|
||||
)
|
||||
|
||||
all_documents.extend(documents)
|
||||
@@ -78,7 +78,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
score_threshold=retrieval_model['score_threshold']
|
||||
if retrieval_model['score_threshold_enabled'] else None,
|
||||
reranking_model=retrieval_model['reranking_model']
|
||||
if retrieval_model['reranking_enable'] else None
|
||||
if retrieval_model['reranking_enable'] else None,
|
||||
weights=retrieval_model.get('weights', None),
|
||||
)
|
||||
else:
|
||||
documents = []
|
||||
|
||||
@@ -13,13 +13,41 @@ class RerankingModelConfig(BaseModel):
|
||||
model: str
|
||||
|
||||
|
||||
class VectorSetting(BaseModel):
|
||||
"""
|
||||
Vector Setting.
|
||||
"""
|
||||
vector_weight: float
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class KeywordSetting(BaseModel):
|
||||
"""
|
||||
Keyword Setting.
|
||||
"""
|
||||
keyword_weight: float
|
||||
|
||||
|
||||
class WeightedScoreConfig(BaseModel):
|
||||
"""
|
||||
Weighted score Config.
|
||||
"""
|
||||
weight_type: str
|
||||
vector_setting: VectorSetting
|
||||
keyword_setting: KeywordSetting
|
||||
|
||||
|
||||
class MultipleRetrievalConfig(BaseModel):
|
||||
"""
|
||||
Multiple Retrieval Config.
|
||||
"""
|
||||
top_k: int
|
||||
score_threshold: Optional[float] = None
|
||||
reranking_mode: str = 'reranking_model'
|
||||
reranking_enable: bool = True
|
||||
reranking_model: RerankingModelConfig
|
||||
weights: WeightedScoreConfig
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
|
||||
@@ -138,13 +138,38 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
planning_strategy=planning_strategy
|
||||
)
|
||||
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
|
||||
if node_data.multiple_retrieval_config.reranking_mode == 'reranking_model':
|
||||
reranking_model = {
|
||||
'reranking_provider_name': node_data.multiple_retrieval_config.reranking_model['provider'],
|
||||
'reranking_model_name': node_data.multiple_retrieval_config.reranking_model['name']
|
||||
}
|
||||
weights = None
|
||||
elif node_data.multiple_retrieval_config.reranking_mode == 'weighted_score':
|
||||
reranking_model = None
|
||||
weights = {
|
||||
'weight_type': node_data.multiple_retrieval_config.weights.weight_type,
|
||||
'vector_setting': {
|
||||
"vector_weight": node_data.multiple_retrieval_config.weights.vector_setting.vector_weight,
|
||||
"embedding_provider_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_provider_name,
|
||||
"embedding_model_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_model_name,
|
||||
},
|
||||
'keyword_setting': {
|
||||
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
|
||||
}
|
||||
}
|
||||
else:
|
||||
reranking_model = None
|
||||
weights = None
|
||||
all_documents = dataset_retrieval.multiple_retrieve(self.app_id, self.tenant_id, self.user_id,
|
||||
self.user_from.value,
|
||||
available_datasets, query,
|
||||
node_data.multiple_retrieval_config.top_k,
|
||||
node_data.multiple_retrieval_config.score_threshold,
|
||||
node_data.multiple_retrieval_config.reranking_model.provider,
|
||||
node_data.multiple_retrieval_config.reranking_model.model)
|
||||
node_data.multiple_retrieval_config.reranking_mode,
|
||||
reranking_model,
|
||||
weights,
|
||||
node_data.multiple_retrieval_config.reranking_enable,
|
||||
)
|
||||
|
||||
context_list = []
|
||||
if all_documents:
|
||||
|
||||
Reference in New Issue
Block a user