mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 03:16:51 +08:00
Fix/ignore economy dataset (#1043)
Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
@@ -10,6 +10,7 @@ from flask import current_app
|
||||
from sqlalchemy import func
|
||||
|
||||
from core.index.index import IndexBuilder
|
||||
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_redis import redis_client
|
||||
from flask_login import current_user
|
||||
@@ -91,16 +92,18 @@ class DatasetService:
|
||||
if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
|
||||
raise DatasetNameDuplicateError(
|
||||
f'Dataset with name {name} already exists.')
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
embedding_model = None
|
||||
if indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
dataset = Dataset(name=name, indexing_technique=indexing_technique)
|
||||
# dataset = Dataset(name=name, provider=provider, config=config)
|
||||
dataset.created_by = account.id
|
||||
dataset.updated_by = account.id
|
||||
dataset.tenant_id = tenant_id
|
||||
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
|
||||
dataset.embedding_model = embedding_model.name
|
||||
dataset.embedding_model_provider = embedding_model.model_provider.provider_name if embedding_model else None
|
||||
dataset.embedding_model = embedding_model.name if embedding_model else None
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
return dataset
|
||||
@@ -115,6 +118,23 @@ class DatasetService:
|
||||
else:
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def check_dataset_model_setting(dataset):
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ValueError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(f"The dataset in unavailable, due to: "
|
||||
f"{ex.description}")
|
||||
|
||||
@staticmethod
|
||||
def update_dataset(dataset_id, data, user):
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@@ -124,6 +144,19 @@ class DatasetService:
|
||||
if data['indexing_technique'] == 'economy':
|
||||
deal_dataset_vector_index_task.delay(dataset_id, 'remove')
|
||||
elif data['indexing_technique'] == 'high_quality':
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ValueError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(ex.description)
|
||||
deal_dataset_vector_index_task.delay(dataset_id, 'add')
|
||||
filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}
|
||||
|
||||
@@ -397,23 +430,23 @@ class DocumentService:
|
||||
|
||||
# check document limit
|
||||
if current_app.config['EDITION'] == 'CLOUD':
|
||||
count = 0
|
||||
if document_data["data_source"]["type"] == "upload_file":
|
||||
upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
|
||||
count = len(upload_file_list)
|
||||
elif document_data["data_source"]["type"] == "notion_import":
|
||||
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
|
||||
for notion_info in notion_info_list:
|
||||
count = count + len(notion_info['pages'])
|
||||
documents_count = DocumentService.get_tenant_documents_count()
|
||||
total_count = documents_count + count
|
||||
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
|
||||
if total_count > tenant_document_count:
|
||||
raise ValueError(f"over document limit {tenant_document_count}.")
|
||||
if 'original_document_id' not in document_data or not document_data['original_document_id']:
|
||||
count = 0
|
||||
if document_data["data_source"]["type"] == "upload_file":
|
||||
upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
|
||||
count = len(upload_file_list)
|
||||
elif document_data["data_source"]["type"] == "notion_import":
|
||||
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
|
||||
for notion_info in notion_info_list:
|
||||
count = count + len(notion_info['pages'])
|
||||
documents_count = DocumentService.get_tenant_documents_count()
|
||||
total_count = documents_count + count
|
||||
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
|
||||
if total_count > tenant_document_count:
|
||||
raise ValueError(f"over document limit {tenant_document_count}.")
|
||||
# if dataset is empty, update dataset data_source_type
|
||||
if not dataset.data_source_type:
|
||||
dataset.data_source_type = document_data["data_source"]["type"]
|
||||
db.session.commit()
|
||||
|
||||
if not dataset.indexing_technique:
|
||||
if 'indexing_technique' not in document_data \
|
||||
@@ -421,6 +454,13 @@ class DocumentService:
|
||||
raise ValueError("Indexing technique is required")
|
||||
|
||||
dataset.indexing_technique = document_data["indexing_technique"]
|
||||
if document_data["indexing_technique"] == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id
|
||||
)
|
||||
dataset.embedding_model = embedding_model.name
|
||||
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
|
||||
|
||||
|
||||
documents = []
|
||||
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
|
||||
@@ -466,11 +506,11 @@ class DocumentService:
|
||||
"upload_file_id": file_id,
|
||||
}
|
||||
document = DocumentService.build_document(dataset, dataset_process_rule.id,
|
||||
document_data["data_source"]["type"],
|
||||
document_data["doc_form"],
|
||||
document_data["doc_language"],
|
||||
data_source_info, created_from, position,
|
||||
account, file_name, batch)
|
||||
document_data["data_source"]["type"],
|
||||
document_data["doc_form"],
|
||||
document_data["doc_language"],
|
||||
data_source_info, created_from, position,
|
||||
account, file_name, batch)
|
||||
db.session.add(document)
|
||||
db.session.flush()
|
||||
document_ids.append(document.id)
|
||||
@@ -512,11 +552,11 @@ class DocumentService:
|
||||
"type": page['type']
|
||||
}
|
||||
document = DocumentService.build_document(dataset, dataset_process_rule.id,
|
||||
document_data["data_source"]["type"],
|
||||
document_data["doc_form"],
|
||||
document_data["doc_language"],
|
||||
data_source_info, created_from, position,
|
||||
account, page['page_name'], batch)
|
||||
document_data["data_source"]["type"],
|
||||
document_data["doc_form"],
|
||||
document_data["doc_language"],
|
||||
data_source_info, created_from, position,
|
||||
account, page['page_name'], batch)
|
||||
db.session.add(document)
|
||||
db.session.flush()
|
||||
document_ids.append(document.id)
|
||||
@@ -536,9 +576,9 @@ class DocumentService:
|
||||
|
||||
@staticmethod
|
||||
def build_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
|
||||
document_language: str, data_source_info: dict, created_from: str, position: int,
|
||||
account: Account,
|
||||
name: str, batch: str):
|
||||
document_language: str, data_source_info: dict, created_from: str, position: int,
|
||||
account: Account,
|
||||
name: str, batch: str):
|
||||
document = Document(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
@@ -567,6 +607,7 @@ class DocumentService:
|
||||
def update_document_with_dataset_id(dataset: Dataset, document_data: dict,
|
||||
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
|
||||
created_from: str = 'web'):
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
|
||||
if document.display_status != 'available':
|
||||
raise ValueError("Document is not available")
|
||||
@@ -674,9 +715,11 @@ class DocumentService:
|
||||
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
|
||||
if total_count > tenant_document_count:
|
||||
raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
embedding_model = None
|
||||
if document_data['indexing_technique'] == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
# save dataset
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
@@ -684,8 +727,8 @@ class DocumentService:
|
||||
data_source_type=document_data["data_source"]["type"],
|
||||
indexing_technique=document_data["indexing_technique"],
|
||||
created_by=account.id,
|
||||
embedding_model=embedding_model.name,
|
||||
embedding_model_provider=embedding_model.model_provider.provider_name
|
||||
embedding_model=embedding_model.name if embedding_model else None,
|
||||
embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None
|
||||
)
|
||||
|
||||
db.session.add(dataset)
|
||||
@@ -903,15 +946,15 @@ class SegmentService:
|
||||
content = args['content']
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(content)
|
||||
tokens = 0
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(content)
|
||||
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
|
||||
DocumentSegment.document_id == document.id
|
||||
).scalar()
|
||||
@@ -973,15 +1016,16 @@ class SegmentService:
|
||||
kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
|
||||
else:
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
tokens = 0
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(content)
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(content)
|
||||
segment.content = content
|
||||
segment.index_node_hash = segment_hash
|
||||
segment.word_count = len(content)
|
||||
@@ -1013,7 +1057,7 @@ class SegmentService:
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
raise ValueError("Segment is deleting.")
|
||||
|
||||
|
||||
# enabled segment need to delete index
|
||||
if segment.enabled:
|
||||
# send delete segment index task
|
||||
|
||||
Reference in New Issue
Block a user