Feat/delete single dataset retrival (#6570)

This commit is contained in:
Jyong
2024-07-24 12:50:11 +08:00
committed by GitHub
parent 0fb741f269
commit e4bb943fe5
22 changed files with 651 additions and 115 deletions

View File

@@ -62,7 +62,12 @@ class DatasetConfigManager:
return None
# dataset configs
dataset_configs = config.get('dataset_configs', {'retrieval_model': 'single'})
if 'dataset_configs' in config and config.get('dataset_configs'):
dataset_configs = config.get('dataset_configs')
else:
dataset_configs = {
'retrieval_model': 'multiple'
}
query_variable = config.get('dataset_query_variable')
if dataset_configs['retrieval_model'] == 'single':
@@ -83,9 +88,10 @@ class DatasetConfigManager:
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs['retrieval_model']
),
top_k=dataset_configs.get('top_k'),
top_k=dataset_configs.get('top_k', 4),
score_threshold=dataset_configs.get('score_threshold'),
reranking_model=dataset_configs.get('reranking_model')
reranking_model=dataset_configs.get('reranking_model'),
weights=dataset_configs.get('weights')
)
)

View File

@@ -159,7 +159,11 @@ class DatasetRetrieveConfigEntity(BaseModel):
retrieve_strategy: RetrieveStrategy
top_k: Optional[int] = None
score_threshold: Optional[float] = None
rerank_mode: Optional[str] = 'reranking_model'
reranking_model: Optional[dict] = None
weights: Optional[dict] = None
class DatasetEntity(BaseModel):

View File

@@ -5,15 +5,20 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.data_post_processor.reorder import ReorderRunner
from core.rag.models.document import Document
from core.rag.rerank.rerank import RerankRunner
from core.rag.rerank.constants.rerank_mode import RerankMode
from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.rerank.weight_rerank import WeightRerankRunner
class DataPostProcessor:
"""Interface for data post-processing document.
"""
def __init__(self, tenant_id: str, reranking_model: dict, reorder_enabled: bool = False):
self.rerank_runner = self._get_rerank_runner(reranking_model, tenant_id)
def __init__(self, tenant_id: str, reranking_mode: str,
reranking_model: Optional[dict] = None, weights: Optional[dict] = None,
reorder_enabled: bool = False):
self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights)
self.reorder_runner = self._get_reorder_runner(reorder_enabled)
def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
@@ -26,19 +31,37 @@ class DataPostProcessor:
return documents
def _get_rerank_runner(self, reranking_model: dict, tenant_id: str) -> Optional[RerankRunner]:
if reranking_model:
try:
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
provider=reranking_model['reranking_provider_name'],
model_type=ModelType.RERANK,
model=reranking_model['reranking_model_name']
def _get_rerank_runner(self, reranking_mode: str, tenant_id: str, reranking_model: Optional[dict] = None,
weights: Optional[dict] = None) -> Optional[RerankModelRunner | WeightRerankRunner]:
if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights:
return WeightRerankRunner(
tenant_id,
Weights(
weight_type=weights['weight_type'],
vector_setting=VectorSetting(
vector_weight=weights['vector_setting']['vector_weight'],
embedding_provider_name=weights['vector_setting']['embedding_provider_name'],
embedding_model_name=weights['vector_setting']['embedding_model_name'],
),
keyword_setting=KeywordSetting(
keyword_weight=weights['keyword_setting']['keyword_weight'],
)
)
except InvokeAuthorizationError:
return None
return RerankRunner(rerank_model_instance)
)
elif reranking_mode == RerankMode.RERANKING_MODEL.value:
if reranking_model:
try:
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
provider=reranking_model['reranking_provider_name'],
model_type=ModelType.RERANK,
model=reranking_model['reranking_model_name']
)
except InvokeAuthorizationError:
return None
return RerankModelRunner(rerank_model_instance)
return None
return None
def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]:

View File

@@ -1,4 +1,5 @@
import re
from typing import Optional
import jieba
from jieba.analyse import default_tfidf
@@ -11,7 +12,7 @@ class JiebaKeywordTableHandler:
def __init__(self):
default_tfidf.stop_words = STOPWORDS
def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> set[str]:
def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
"""Extract keywords with JIEBA tfidf."""
keywords = jieba.analyse.extract_tags(
sentence=text,

View File

@@ -6,6 +6,7 @@ from flask import Flask, current_app
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.rerank.constants.rerank_mode import RerankMode
from core.rag.retrieval.retrival_methods import RetrievalMethod
from extensions.ext_database import db
from models.dataset import Dataset
@@ -26,13 +27,19 @@ class RetrievalService:
@classmethod
def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None):
top_k: int, score_threshold: Optional[float] = .0,
reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = None,
weights: Optional[dict] = None):
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return []
all_documents = []
keyword_search_documents = []
embedding_search_documents = []
full_text_search_documents = []
hybrid_search_documents = []
threads = []
exceptions = []
# retrieval_model source with keyword
@@ -87,7 +94,8 @@ class RetrievalService:
raise Exception(exception_message)
if retrival_method == RetrievalMethod.HYBRID_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode,
reranking_model, weights, False)
all_documents = data_post_processor.invoke(
query=query,
documents=all_documents,
@@ -143,7 +151,9 @@ class RetrievalService:
if documents:
if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
data_post_processor = DataPostProcessor(str(dataset.tenant_id),
RerankMode.RERANKING_MODEL.value,
reranking_model, None, False)
all_documents.extend(data_post_processor.invoke(
query=query,
documents=documents,
@@ -175,7 +185,9 @@ class RetrievalService:
)
if documents:
if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
data_post_processor = DataPostProcessor(str(dataset.tenant_id),
RerankMode.RERANKING_MODEL.value,
reranking_model, None, False)
all_documents.extend(data_post_processor.invoke(
query=query,
documents=documents,

View File

@@ -396,9 +396,11 @@ class QdrantVector(BaseVector):
documents = []
for result in results:
if result:
documents.append(self._document_from_scored_point(
document = self._document_from_scored_point(
result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
))
)
document.metadata['vector'] = result.vector
documents.append(document)
return documents

View File

View File

@@ -0,0 +1,8 @@
from enum import Enum
class RerankMode(Enum):
RERANKING_MODEL = 'reranking_model'
WEIGHTED_SCORE = 'weighted_score'

View File

@@ -0,0 +1,23 @@
from pydantic import BaseModel
class VectorSetting(BaseModel):
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
class KeywordSetting(BaseModel):
keyword_weight: float
class Weights(BaseModel):
"""Model for weighted rerank."""
weight_type: str
vector_setting: VectorSetting
keyword_setting: KeywordSetting

View File

@@ -4,7 +4,7 @@ from core.model_manager import ModelInstance
from core.rag.models.document import Document
class RerankRunner:
class RerankModelRunner:
def __init__(self, rerank_model_instance: ModelInstance) -> None:
self.rerank_model_instance = rerank_model_instance

View File

@@ -0,0 +1,178 @@
import math
from collections import Counter
from typing import Optional
import numpy as np
from core.embedding.cached_embedding import CacheEmbedding
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.models.document import Document
from core.rag.rerank.entity.weight import VectorSetting, Weights
class WeightRerankRunner:
def __init__(self, tenant_id: str, weights: Weights) -> None:
self.tenant_id = tenant_id
self.weights = weights
def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
"""
Run rerank model
:param query: search query
:param documents: documents for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id if needed
:return:
"""
docs = []
doc_id = []
unique_documents = []
for document in documents:
if document.metadata['doc_id'] not in doc_id:
doc_id.append(document.metadata['doc_id'])
docs.append(document.page_content)
unique_documents.append(document)
documents = unique_documents
rerank_documents = []
query_scores = self._calculate_keyword_score(query, documents)
query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting)
for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores):
# format document
score = self.weights.vector_setting.vector_weight * query_vector_score + \
self.weights.keyword_setting.keyword_weight * query_score
if score_threshold and score < score_threshold:
continue
document.metadata['score'] = score
rerank_documents.append(document)
rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata['score'], reverse=True)
return rerank_documents[:top_n] if top_n else rerank_documents
def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:
"""
Calculate BM25 scores
:param query: search query
:param documents: documents for reranking
:return:
"""
keyword_table_handler = JiebaKeywordTableHandler()
query_keywords = keyword_table_handler.extract_keywords(query, None)
documents_keywords = []
for document in documents:
# get the document keywords
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
document.metadata['keywords'] = document_keywords
documents_keywords.append(document_keywords)
# Counter query keywords(TF)
query_keyword_counts = Counter(query_keywords)
# total documents
total_documents = len(documents)
# calculate all documents' keywords IDF
all_keywords = set()
for document_keywords in documents_keywords:
all_keywords.update(document_keywords)
keyword_idf = {}
for keyword in all_keywords:
# calculate include query keywords' documents
doc_count_containing_keyword = sum(1 for doc_keywords in documents_keywords if keyword in doc_keywords)
# IDF
keyword_idf[keyword] = math.log((1 + total_documents) / (1 + doc_count_containing_keyword)) + 1
query_tfidf = {}
for keyword, count in query_keyword_counts.items():
tf = count
idf = keyword_idf.get(keyword, 0)
query_tfidf[keyword] = tf * idf
# calculate all documents' TF-IDF
documents_tfidf = []
for document_keywords in documents_keywords:
document_keyword_counts = Counter(document_keywords)
document_tfidf = {}
for keyword, count in document_keyword_counts.items():
tf = count
idf = keyword_idf.get(keyword, 0)
document_tfidf[keyword] = tf * idf
documents_tfidf.append(document_tfidf)
def cosine_similarity(vec1, vec2):
intersection = set(vec1.keys()) & set(vec2.keys())
numerator = sum(vec1[x] * vec2[x] for x in intersection)
sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
denominator = math.sqrt(sum1) * math.sqrt(sum2)
if not denominator:
return 0.0
else:
return float(numerator) / denominator
similarities = []
for document_tfidf in documents_tfidf:
similarity = cosine_similarity(query_tfidf, document_tfidf)
similarities.append(similarity)
# for idx, similarity in enumerate(similarities):
# print(f"Document {idx + 1} similarity: {similarity}")
return similarities
def _calculate_cosine(self, tenant_id: str, query: str, documents: list[Document],
vector_setting: VectorSetting) -> list[float]:
"""
Calculate Cosine scores
:param query: search query
:param documents: documents for reranking
:return:
"""
query_vector_scores = []
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=tenant_id,
provider=vector_setting.embedding_provider_name,
model_type=ModelType.TEXT_EMBEDDING,
model=vector_setting.embedding_model_name
)
cache_embedding = CacheEmbedding(embedding_model)
query_vector = cache_embedding.embed_query(query)
for document in documents:
# calculate cosine similarity
if 'score' in document.metadata:
query_vector_scores.append(document.metadata['score'])
else:
content_vector = document.metadata['vector']
# transform to NumPy
vec1 = np.array(query_vector)
vec2 = np.array(document.metadata['vector'])
# calculate dot product
dot_product = np.dot(vec1, vec2)
# calculate norm
norm_vec1 = np.linalg.norm(vec1)
norm_vec2 = np.linalg.norm(vec2)
# calculate cosine similarity
cosine_sim = dot_product / (norm_vec1 * norm_vec2)
query_vector_scores.append(cosine_sim)
return query_vector_scores

View File

@@ -1,4 +1,6 @@
import math
import threading
from collections import Counter
from typing import Optional, cast
from flask import Flask, current_app
@@ -14,9 +16,10 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
from core.ops.utils import measure_time
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document
from core.rag.rerank.rerank import RerankRunner
from core.rag.retrieval.retrival_methods import RetrievalMethod
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
@@ -132,8 +135,9 @@ class DatasetRetrieval:
app_id, tenant_id, user_id, user_from,
available_datasets, query, retrieve_config.top_k,
retrieve_config.score_threshold,
retrieve_config.reranking_model.get('reranking_provider_name'),
retrieve_config.reranking_model.get('reranking_model_name'),
retrieve_config.rerank_mode,
retrieve_config.reranking_model,
retrieve_config.weights,
message_id,
)
@@ -272,7 +276,8 @@ class DatasetRetrieval:
retrival_method=retrival_method, dataset_id=dataset.id,
query=query,
top_k=top_k, score_threshold=score_threshold,
reranking_model=reranking_model
reranking_model=reranking_model,
weights=retrieval_model_config.get('weights', None),
)
self._on_query(query, [dataset_id], app_id, user_from, user_id)
@@ -292,14 +297,18 @@ class DatasetRetrieval:
query: str,
top_k: int,
score_threshold: float,
reranking_provider_name: str,
reranking_model_name: str,
reranking_mode: str,
reranking_model: Optional[dict] = None,
weights: Optional[dict] = None,
reranking_enable: bool = True,
message_id: Optional[str] = None,
):
threads = []
all_documents = []
dataset_ids = [dataset.id for dataset in available_datasets]
index_type = None
for dataset in available_datasets:
index_type = dataset.indexing_technique
retrieval_thread = threading.Thread(target=self._retriever, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset.id,
@@ -311,23 +320,24 @@ class DatasetRetrieval:
retrieval_thread.start()
for thread in threads:
thread.join()
# do rerank for searched documents
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
provider=reranking_provider_name,
model_type=ModelType.RERANK,
model=reranking_model_name
)
rerank_runner = RerankRunner(rerank_model_instance)
if reranking_enable:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode,
reranking_model, weights, False)
with measure_time() as timer:
all_documents = rerank_runner.run(
query, all_documents,
score_threshold,
top_k
)
with measure_time() as timer:
all_documents = data_post_processor.invoke(
query=query,
documents=all_documents,
score_threshold=score_threshold,
top_n=top_k
)
else:
if index_type == "economy":
all_documents = self.calculate_keyword_score(query, all_documents, top_k)
elif index_type == "high_quality":
all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)
self._on_query(query, dataset_ids, app_id, user_from, user_id)
if all_documents:
@@ -420,7 +430,8 @@ class DatasetRetrieval:
score_threshold=retrieval_model['score_threshold']
if retrieval_model['score_threshold_enabled'] else None,
reranking_model=retrieval_model['reranking_model']
if retrieval_model['reranking_enable'] else None
if retrieval_model['reranking_enable'] else None,
weights=retrieval_model.get('weights', None),
)
all_documents.extend(documents)
@@ -513,3 +524,94 @@ class DatasetRetrieval:
tools.append(tool)
return tools
def calculate_keyword_score(self, query: str, documents: list[Document], top_k: int) -> list[Document]:
"""
Calculate keywords scores
:param query: search query
:param documents: documents for reranking
:return:
"""
keyword_table_handler = JiebaKeywordTableHandler()
query_keywords = keyword_table_handler.extract_keywords(query, None)
documents_keywords = []
for document in documents:
# get the document keywords
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
document.metadata['keywords'] = document_keywords
documents_keywords.append(document_keywords)
# Counter query keywords(TF)
query_keyword_counts = Counter(query_keywords)
# total documents
total_documents = len(documents)
# calculate all documents' keywords IDF
all_keywords = set()
for document_keywords in documents_keywords:
all_keywords.update(document_keywords)
keyword_idf = {}
for keyword in all_keywords:
# calculate include query keywords' documents
doc_count_containing_keyword = sum(1 for doc_keywords in documents_keywords if keyword in doc_keywords)
# IDF
keyword_idf[keyword] = math.log((1 + total_documents) / (1 + doc_count_containing_keyword)) + 1
query_tfidf = {}
for keyword, count in query_keyword_counts.items():
tf = count
idf = keyword_idf.get(keyword, 0)
query_tfidf[keyword] = tf * idf
# calculate all documents' TF-IDF
documents_tfidf = []
for document_keywords in documents_keywords:
document_keyword_counts = Counter(document_keywords)
document_tfidf = {}
for keyword, count in document_keyword_counts.items():
tf = count
idf = keyword_idf.get(keyword, 0)
document_tfidf[keyword] = tf * idf
documents_tfidf.append(document_tfidf)
def cosine_similarity(vec1, vec2):
intersection = set(vec1.keys()) & set(vec2.keys())
numerator = sum(vec1[x] * vec2[x] for x in intersection)
sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
denominator = math.sqrt(sum1) * math.sqrt(sum2)
if not denominator:
return 0.0
else:
return float(numerator) / denominator
similarities = []
for document_tfidf in documents_tfidf:
similarity = cosine_similarity(query_tfidf, document_tfidf)
similarities.append(similarity)
for document, score in zip(documents, similarities):
# format document
document.metadata['score'] = score
documents = sorted(documents, key=lambda x: x.metadata['score'], reverse=True)
return documents[:top_k] if top_k else documents
def calculate_vector_score(self, all_documents: list[Document],
top_k: int, score_threshold: float) -> list[Document]:
filter_documents = []
for document in all_documents:
if document.metadata['score'] >= score_threshold:
filter_documents.append(document)
if not filter_documents:
return []
filter_documents = sorted(filter_documents, key=lambda x: x.metadata['score'], reverse=True)
return filter_documents[:top_k] if top_k else filter_documents

View File

View File

@@ -7,7 +7,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.rerank.rerank import RerankRunner
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrival_methods import RetrievalMethod
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
@@ -72,7 +72,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
model=self.reranking_model_name
)
rerank_runner = RerankRunner(rerank_model_instance)
rerank_runner = RerankModelRunner(rerank_model_instance)
all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k)
for hit_callback in self.hit_callbacks:
@@ -180,7 +180,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
score_threshold=retrieval_model['score_threshold']
if retrieval_model['score_threshold_enabled'] else None,
reranking_model=retrieval_model['reranking_model']
if retrieval_model['reranking_enable'] else None
if retrieval_model['reranking_enable'] else None,
weights=retrieval_model.get('weights', None),
)
all_documents.extend(documents)

View File

@@ -78,7 +78,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
score_threshold=retrieval_model['score_threshold']
if retrieval_model['score_threshold_enabled'] else None,
reranking_model=retrieval_model['reranking_model']
if retrieval_model['reranking_enable'] else None
if retrieval_model['reranking_enable'] else None,
weights=retrieval_model.get('weights', None),
)
else:
documents = []

View File

@@ -13,13 +13,41 @@ class RerankingModelConfig(BaseModel):
model: str
class VectorSetting(BaseModel):
"""
Vector Setting.
"""
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
class KeywordSetting(BaseModel):
"""
Keyword Setting.
"""
keyword_weight: float
class WeightedScoreConfig(BaseModel):
"""
Weighted score Config.
"""
weight_type: str
vector_setting: VectorSetting
keyword_setting: KeywordSetting
class MultipleRetrievalConfig(BaseModel):
"""
Multiple Retrieval Config.
"""
top_k: int
score_threshold: Optional[float] = None
reranking_mode: str = 'reranking_model'
reranking_enable: bool = True
reranking_model: RerankingModelConfig
weights: WeightedScoreConfig
class ModelConfig(BaseModel):

View File

@@ -138,13 +138,38 @@ class KnowledgeRetrievalNode(BaseNode):
planning_strategy=planning_strategy
)
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
if node_data.multiple_retrieval_config.reranking_mode == 'reranking_model':
reranking_model = {
'reranking_provider_name': node_data.multiple_retrieval_config.reranking_model['provider'],
'reranking_model_name': node_data.multiple_retrieval_config.reranking_model['name']
}
weights = None
elif node_data.multiple_retrieval_config.reranking_mode == 'weighted_score':
reranking_model = None
weights = {
'weight_type': node_data.multiple_retrieval_config.weights.weight_type,
'vector_setting': {
"vector_weight": node_data.multiple_retrieval_config.weights.vector_setting.vector_weight,
"embedding_provider_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_provider_name,
"embedding_model_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_model_name,
},
'keyword_setting': {
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
}
}
else:
reranking_model = None
weights = None
all_documents = dataset_retrieval.multiple_retrieve(self.app_id, self.tenant_id, self.user_id,
self.user_from.value,
available_datasets, query,
node_data.multiple_retrieval_config.top_k,
node_data.multiple_retrieval_config.score_threshold,
node_data.multiple_retrieval_config.reranking_model.provider,
node_data.multiple_retrieval_config.reranking_model.model)
node_data.multiple_retrieval_config.reranking_mode,
reranking_model,
weights,
node_data.multiple_retrieval_config.reranking_enable,
)
context_list = []
if all_documents: