mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-09 10:56:52 +08:00
Feat/add retriever rerank (#1560)
Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
@@ -470,7 +470,16 @@ class AppModelConfigService:
|
||||
|
||||
# dataset_configs
|
||||
if 'dataset_configs' not in config or not config["dataset_configs"]:
|
||||
config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}}
|
||||
config["dataset_configs"] = {'retrieval_model': 'single'}
|
||||
|
||||
if not isinstance(config["dataset_configs"], dict):
|
||||
raise ValueError("dataset_configs must be of object type")
|
||||
|
||||
if config["dataset_configs"]['retrieval_model'] == 'multiple':
|
||||
if not config["dataset_configs"]['reranking_model']:
|
||||
raise ValueError("reranking_model has not been set")
|
||||
if not isinstance(config["dataset_configs"]['reranking_model'], dict):
|
||||
raise ValueError("reranking_model must be of object type")
|
||||
|
||||
if not isinstance(config["dataset_configs"], dict):
|
||||
raise ValueError("dataset_configs must be of object type")
|
||||
|
||||
@@ -173,6 +173,9 @@ class DatasetService:
|
||||
filtered_data['updated_by'] = user.id
|
||||
filtered_data['updated_at'] = datetime.datetime.now()
|
||||
|
||||
# update Retrieval model
|
||||
filtered_data['retrieval_model'] = data['retrieval_model']
|
||||
|
||||
dataset.query.filter_by(id=dataset_id).update(filtered_data)
|
||||
|
||||
db.session.commit()
|
||||
@@ -473,7 +476,19 @@ class DocumentService:
|
||||
embedding_model.name
|
||||
)
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
if not dataset.retrieval_model:
|
||||
default_retrieval_model = {
|
||||
'search_method': 'semantic_search',
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
}
|
||||
|
||||
dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get('retrieval_model') else default_retrieval_model
|
||||
|
||||
documents = []
|
||||
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
|
||||
@@ -733,6 +748,7 @@ class DocumentService:
|
||||
raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
|
||||
embedding_model = None
|
||||
dataset_collection_binding_id = None
|
||||
retrieval_model = None
|
||||
if document_data['indexing_technique'] == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
@@ -742,6 +758,20 @@ class DocumentService:
|
||||
embedding_model.name
|
||||
)
|
||||
dataset_collection_binding_id = dataset_collection_binding.id
|
||||
if 'retrieval_model' in document_data and document_data['retrieval_model']:
|
||||
retrieval_model = document_data['retrieval_model']
|
||||
else:
|
||||
default_retrieval_model = {
|
||||
'search_method': 'semantic_search',
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
}
|
||||
retrieval_model = default_retrieval_model
|
||||
# save dataset
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
@@ -751,7 +781,8 @@ class DocumentService:
|
||||
created_by=account.id,
|
||||
embedding_model=embedding_model.name if embedding_model else None,
|
||||
embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None,
|
||||
collection_binding_id=dataset_collection_binding_id
|
||||
collection_binding_id=dataset_collection_binding_id,
|
||||
retrieval_model=retrieval_model
|
||||
)
|
||||
|
||||
db.session.add(dataset)
|
||||
@@ -768,7 +799,7 @@ class DocumentService:
|
||||
return dataset, documents, batch
|
||||
|
||||
@classmethod
|
||||
def document_create_args_validate(cls, args: dict):
|
||||
def document_create_args_validate(cls, args: dict):
|
||||
if 'original_document_id' not in args or not args['original_document_id']:
|
||||
DocumentService.data_source_args_validate(args)
|
||||
DocumentService.process_rule_args_validate(args)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
@@ -9,16 +11,26 @@ from langchain.schema import Document
|
||||
from sklearn.manifold import TSNE
|
||||
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DocumentSegment, DatasetQuery
|
||||
from services.retrieval_service import RetrievalService
|
||||
|
||||
default_retrieval_model = {
|
||||
'search_method': 'semantic_search',
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
}
|
||||
|
||||
class HitTestingService:
|
||||
@classmethod
|
||||
def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
|
||||
def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict:
|
||||
if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
|
||||
return {
|
||||
"query": {
|
||||
@@ -28,31 +40,68 @@ class HitTestingService:
|
||||
"records": []
|
||||
}
|
||||
|
||||
start = time.perf_counter()
|
||||
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
if not retrieval_model:
|
||||
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
||||
|
||||
# get embedding model
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
vector_index = VectorIndex(
|
||||
dataset=dataset,
|
||||
config=current_app.config,
|
||||
embeddings=embeddings
|
||||
)
|
||||
all_documents = []
|
||||
threads = []
|
||||
|
||||
# retrieval_model source with semantic
|
||||
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
|
||||
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset': dataset,
|
||||
'query': query,
|
||||
'top_k': retrieval_model['top_k'],
|
||||
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
|
||||
'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
|
||||
'all_documents': all_documents,
|
||||
'search_method': retrieval_model['search_method'],
|
||||
'embeddings': embeddings
|
||||
})
|
||||
threads.append(embedding_thread)
|
||||
embedding_thread.start()
|
||||
|
||||
# retrieval source with full text
|
||||
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
|
||||
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset': dataset,
|
||||
'query': query,
|
||||
'search_method': retrieval_model['search_method'],
|
||||
'embeddings': embeddings,
|
||||
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
|
||||
'top_k': retrieval_model['top_k'],
|
||||
'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
|
||||
'all_documents': all_documents
|
||||
})
|
||||
threads.append(full_text_index_thread)
|
||||
full_text_index_thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
if retrieval_model['search_method'] == 'hybrid_search':
|
||||
hybrid_rerank = ModelFactory.get_reranking_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
|
||||
model_name=retrieval_model['reranking_model']['reranking_model_name']
|
||||
)
|
||||
all_documents = hybrid_rerank.rerank(query, all_documents,
|
||||
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
|
||||
retrieval_model['top_k'])
|
||||
|
||||
start = time.perf_counter()
|
||||
documents = vector_index.search(
|
||||
query,
|
||||
search_type='similarity_score_threshold',
|
||||
search_kwargs={
|
||||
'k': 10,
|
||||
'filter': {
|
||||
'group_id': [dataset.id]
|
||||
}
|
||||
}
|
||||
)
|
||||
end = time.perf_counter()
|
||||
logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
|
||||
|
||||
@@ -67,7 +116,7 @@ class HitTestingService:
|
||||
db.session.add(dataset_query)
|
||||
db.session.commit()
|
||||
|
||||
return cls.compact_retrieve_response(dataset, embeddings, query, documents)
|
||||
return cls.compact_retrieve_response(dataset, embeddings, query, all_documents)
|
||||
|
||||
@classmethod
|
||||
def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]):
|
||||
@@ -99,7 +148,7 @@ class HitTestingService:
|
||||
|
||||
record = {
|
||||
"segment": segment,
|
||||
"score": document.metadata['score'],
|
||||
"score": document.metadata.get('score', None),
|
||||
"tsne_position": tsne_position_data[i]
|
||||
}
|
||||
|
||||
@@ -136,3 +185,11 @@ class HitTestingService:
|
||||
tsne_position_data.append({'x': float(data_tsne[i][0]), 'y': float(data_tsne[i][1])})
|
||||
|
||||
return tsne_position_data
|
||||
|
||||
@classmethod
|
||||
def hit_testing_args_check(cls, args):
|
||||
query = args['query']
|
||||
|
||||
if not query or len(query) > 250:
|
||||
raise ValueError('Query is required and cannot exceed 250 characters')
|
||||
|
||||
|
||||
88
api/services/retrieval_service.py
Normal file
88
api/services/retrieval_service.py
Normal file
@@ -0,0 +1,88 @@
|
||||
|
||||
from typing import Optional
|
||||
from flask import current_app, Flask
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from models.dataset import Dataset
|
||||
|
||||
default_retrieval_model = {
|
||||
'search_method': 'semantic_search',
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
}
|
||||
|
||||
|
||||
class RetrievalService:
|
||||
|
||||
@classmethod
|
||||
def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str,
|
||||
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
|
||||
all_documents: list, search_method: str, embeddings: Embeddings):
|
||||
with flask_app.app_context():
|
||||
|
||||
vector_index = VectorIndex(
|
||||
dataset=dataset,
|
||||
config=current_app.config,
|
||||
embeddings=embeddings
|
||||
)
|
||||
|
||||
documents = vector_index.search(
|
||||
query,
|
||||
search_type='similarity_score_threshold',
|
||||
search_kwargs={
|
||||
'k': top_k,
|
||||
'score_threshold': score_threshold,
|
||||
'filter': {
|
||||
'group_id': [dataset.id]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if documents:
|
||||
if reranking_model and search_method == 'semantic_search':
|
||||
rerank = ModelFactory.get_reranking_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=reranking_model['reranking_provider_name'],
|
||||
model_name=reranking_model['reranking_model_name']
|
||||
)
|
||||
all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
|
||||
else:
|
||||
all_documents.extend(documents)
|
||||
|
||||
@classmethod
|
||||
def full_text_index_search(cls, flask_app: Flask, dataset: Dataset, query: str,
|
||||
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
|
||||
all_documents: list, search_method: str, embeddings: Embeddings):
|
||||
with flask_app.app_context():
|
||||
|
||||
vector_index = VectorIndex(
|
||||
dataset=dataset,
|
||||
config=current_app.config,
|
||||
embeddings=embeddings
|
||||
)
|
||||
|
||||
documents = vector_index.search_by_full_text_index(
|
||||
query,
|
||||
search_type='similarity_score_threshold',
|
||||
top_k=top_k
|
||||
)
|
||||
if documents:
|
||||
if reranking_model and search_method == 'full_text_search':
|
||||
rerank = ModelFactory.get_reranking_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=reranking_model['reranking_provider_name'],
|
||||
model_name=reranking_model['reranking_model_name']
|
||||
)
|
||||
all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
|
||||
else:
|
||||
all_documents.extend(documents)
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user