mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-09 02:46:52 +08:00
Feat/add retriever rerank (#1560)
Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
@@ -173,6 +173,9 @@ class DatasetService:
|
||||
filtered_data['updated_by'] = user.id
|
||||
filtered_data['updated_at'] = datetime.datetime.now()
|
||||
|
||||
# update Retrieval model
|
||||
filtered_data['retrieval_model'] = data['retrieval_model']
|
||||
|
||||
dataset.query.filter_by(id=dataset_id).update(filtered_data)
|
||||
|
||||
db.session.commit()
|
||||
@@ -473,7 +476,19 @@ class DocumentService:
|
||||
embedding_model.name
|
||||
)
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
if not dataset.retrieval_model:
|
||||
default_retrieval_model = {
|
||||
'search_method': 'semantic_search',
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
}
|
||||
|
||||
dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get('retrieval_model') else default_retrieval_model
|
||||
|
||||
documents = []
|
||||
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
|
||||
@@ -733,6 +748,7 @@ class DocumentService:
|
||||
raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
|
||||
embedding_model = None
|
||||
dataset_collection_binding_id = None
|
||||
retrieval_model = None
|
||||
if document_data['indexing_technique'] == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
@@ -742,6 +758,20 @@ class DocumentService:
|
||||
embedding_model.name
|
||||
)
|
||||
dataset_collection_binding_id = dataset_collection_binding.id
|
||||
if 'retrieval_model' in document_data and document_data['retrieval_model']:
|
||||
retrieval_model = document_data['retrieval_model']
|
||||
else:
|
||||
default_retrieval_model = {
|
||||
'search_method': 'semantic_search',
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
}
|
||||
retrieval_model = default_retrieval_model
|
||||
# save dataset
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
@@ -751,7 +781,8 @@ class DocumentService:
|
||||
created_by=account.id,
|
||||
embedding_model=embedding_model.name if embedding_model else None,
|
||||
embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None,
|
||||
collection_binding_id=dataset_collection_binding_id
|
||||
collection_binding_id=dataset_collection_binding_id,
|
||||
retrieval_model=retrieval_model
|
||||
)
|
||||
|
||||
db.session.add(dataset)
|
||||
@@ -768,7 +799,7 @@ class DocumentService:
|
||||
return dataset, documents, batch
|
||||
|
||||
@classmethod
|
||||
def document_create_args_validate(cls, args: dict):
|
||||
def document_create_args_validate(cls, args: dict):
|
||||
if 'original_document_id' not in args or not args['original_document_id']:
|
||||
DocumentService.data_source_args_validate(args)
|
||||
DocumentService.process_rule_args_validate(args)
|
||||
|
||||
Reference in New Issue
Block a user