Fix/ignore economy dataset (#1043)

Co-authored-by: jyong <jyong@dify.ai>
Jyong authored on 2023-08-29 03:37:45 +08:00, committed by GitHub
parent f9bec1edf8
commit a55ba6e614
13 changed files with 320 additions and 205 deletions


@@ -10,6 +10,7 @@ from flask import current_app
from sqlalchemy import func
from core.index.index import IndexBuilder
+from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_redis import redis_client
from flask_login import current_user
@@ -91,16 +92,18 @@ class DatasetService:
if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(
f'Dataset with name {name} already exists.')
-embedding_model = ModelFactory.get_embedding_model(
-tenant_id=current_user.current_tenant_id
-)
+embedding_model = None
+if indexing_technique == 'high_quality':
+embedding_model = ModelFactory.get_embedding_model(
+tenant_id=current_user.current_tenant_id
+)
dataset = Dataset(name=name, indexing_technique=indexing_technique)
# dataset = Dataset(name=name, provider=provider, config=config)
dataset.created_by = account.id
dataset.updated_by = account.id
dataset.tenant_id = tenant_id
-dataset.embedding_model_provider = embedding_model.model_provider.provider_name
-dataset.embedding_model = embedding_model.name
+dataset.embedding_model_provider = embedding_model.model_provider.provider_name if embedding_model else None
+dataset.embedding_model = embedding_model.name if embedding_model else None
db.session.add(dataset)
db.session.commit()
return dataset
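
In practice the new branch means an 'economy' dataset can be created even when no embedding provider is configured, while 'high_quality' still resolves the tenant's default embedding model. A minimal usage sketch, assuming a Flask request context with current_user bound; the keyword names and module path are assumptions, since create_empty_dataset's signature is not shown in this hunk:

from services.dataset_service import DatasetService  # module path assumed
from flask_login import current_user

# Sketch only: no embedding model is looked up for the economy path.
dataset = DatasetService.create_empty_dataset(
    tenant_id=current_user.current_tenant_id,
    name='support-faq',
    indexing_technique='economy',
    account=current_user,
)
assert dataset.embedding_model is None
assert dataset.embedding_model_provider is None
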
@@ -115,6 +118,23 @@ class DatasetService:
else:
return dataset
+@staticmethod
+def check_dataset_model_setting(dataset):
+if dataset.indexing_technique == 'high_quality':
+try:
+ModelFactory.get_embedding_model(
+tenant_id=dataset.tenant_id,
+model_provider_name=dataset.embedding_model_provider,
+model_name=dataset.embedding_model
+)
+except LLMBadRequestError:
+raise ValueError(
+f"No Embedding Model available. Please configure a valid provider "
+f"in Settings -> Model Provider.")
+except ProviderTokenNotInitError as ex:
+raise ValueError(f"The dataset is unavailable due to: "
+f"{ex.description}")
@staticmethod
def update_dataset(dataset_id, data, user):
dataset = DatasetService.get_dataset(dataset_id)
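
A rough sketch of how the new guard is intended to be called from a write path; the route handler around it is hypothetical and not part of this commit:

from flask import abort

dataset = DatasetService.get_dataset(dataset_id)  # dataset_id from the request (sketch)
try:
    # Raises ValueError when a high_quality dataset has no usable
    # embedding provider; economy datasets pass straight through.
    DatasetService.check_dataset_model_setting(dataset)
except ValueError as e:
    abort(400, description=str(e))
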
@@ -124,6 +144,19 @@ class DatasetService:
if data['indexing_technique'] == 'economy':
deal_dataset_vector_index_task.delay(dataset_id, 'remove')
elif data['indexing_technique'] == 'high_quality':
+# check embedding model setting
+try:
+ModelFactory.get_embedding_model(
+tenant_id=current_user.current_tenant_id,
+model_provider_name=dataset.embedding_model_provider,
+model_name=dataset.embedding_model
+)
+except LLMBadRequestError:
+raise ValueError(
+f"No Embedding Model available. Please configure a valid provider "
+f"in Settings -> Model Provider.")
+except ProviderTokenNotInitError as ex:
+raise ValueError(ex.description)
deal_dataset_vector_index_task.delay(dataset_id, 'add')
filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}
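
Switching techniques therefore behaves roughly as sketched below; the payload shape is assumed from the checks above, and both branches end by queuing the Celery task:

# Downgrade: vectors are removed, no provider check is needed.
DatasetService.update_dataset(dataset_id, {'indexing_technique': 'economy'}, current_user)
# -> deal_dataset_vector_index_task.delay(dataset_id, 'remove')

# Upgrade: the embedding provider is validated first, then vectors are rebuilt.
DatasetService.update_dataset(dataset_id, {'indexing_technique': 'high_quality'}, current_user)
# -> deal_dataset_vector_index_task.delay(dataset_id, 'add')
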
@@ -397,23 +430,23 @@ class DocumentService:
# check document limit
if current_app.config['EDITION'] == 'CLOUD':
+if 'original_document_id' not in document_data or not document_data['original_document_id']:
count = 0
if document_data["data_source"]["type"] == "upload_file":
upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
count = len(upload_file_list)
elif document_data["data_source"]["type"] == "notion_import":
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
for notion_info in notion_info_list:
count = count + len(notion_info['pages'])
documents_count = DocumentService.get_tenant_documents_count()
total_count = documents_count + count
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
if total_count > tenant_document_count:
raise ValueError(f"over document limit {tenant_document_count}.")
# if dataset is empty, update dataset data_source_type
if not dataset.data_source_type:
dataset.data_source_type = document_data["data_source"]["type"]
db.session.commit()
if not dataset.indexing_technique:
if 'indexing_technique' not in document_data \
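
The effect of wrapping the block is that the CLOUD document quota is only enforced for genuinely new documents; updating an existing document (original_document_id present) no longer counts against it. The count itself boils down to the following hypothetical helper (not part of the commit, dict shapes taken from the code above):

def _incoming_document_count(document_data: dict) -> int:
    # Number of documents a save request would create.
    source_type = document_data["data_source"]["type"]
    info_list = document_data["data_source"]["info_list"]
    if source_type == "upload_file":
        return len(info_list["file_info_list"]["file_ids"])
    if source_type == "notion_import":
        return sum(len(info["pages"]) for info in info_list["notion_info_list"])
    return 0
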
@@ -421,6 +454,13 @@ class DocumentService:
raise ValueError("Indexing technique is required")
dataset.indexing_technique = document_data["indexing_technique"]
if document_data["indexing_technique"] == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
)
dataset.embedding_model = embedding_model.name
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
documents = []
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
@@ -466,11 +506,11 @@ class DocumentService:
"upload_file_id": file_id,
}
document = DocumentService.build_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, file_name, batch)
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, file_name, batch)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
@@ -512,11 +552,11 @@ class DocumentService:
"type": page['type']
}
document = DocumentService.build_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, page['page_name'], batch)
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, page['page_name'], batch)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
@@ -536,9 +576,9 @@ class DocumentService:
@staticmethod
def build_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
document_language: str, data_source_info: dict, created_from: str, position: int,
account: Account,
name: str, batch: str):
document = Document(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
@@ -567,6 +607,7 @@ class DocumentService:
def update_document_with_dataset_id(dataset: Dataset, document_data: dict,
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = 'web'):
+DatasetService.check_dataset_model_setting(dataset)
document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
if document.display_status != 'available':
raise ValueError("Document is not available")
@@ -674,9 +715,11 @@ class DocumentService:
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
if total_count > tenant_document_count:
raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
-embedding_model = ModelFactory.get_embedding_model(
-tenant_id=tenant_id
-)
+embedding_model = None
+if document_data['indexing_technique'] == 'high_quality':
+embedding_model = ModelFactory.get_embedding_model(
+tenant_id=tenant_id
+)
# save dataset
dataset = Dataset(
tenant_id=tenant_id,
@@ -684,8 +727,8 @@ class DocumentService:
data_source_type=document_data["data_source"]["type"],
indexing_technique=document_data["indexing_technique"],
created_by=account.id,
-embedding_model=embedding_model.name,
-embedding_model_provider=embedding_model.model_provider.provider_name
+embedding_model=embedding_model.name if embedding_model else None,
+embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None
)
db.session.add(dataset)
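
As in create_empty_dataset, an economy dataset row is stored with NULL embedding columns, so downstream readers have to tolerate the missing metadata. A defensive read might look like this sketch (not code from the commit):

embedding_model = None
if dataset.indexing_technique == 'high_quality' and dataset.embedding_model:
    embedding_model = ModelFactory.get_embedding_model(
        tenant_id=dataset.tenant_id,
        model_provider_name=dataset.embedding_model_provider,
        model_name=dataset.embedding_model,
    )
# economy datasets fall through with embedding_model left as None
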
@@ -903,15 +946,15 @@ class SegmentService:
content = args['content']
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
-embedding_model = ModelFactory.get_embedding_model(
-tenant_id=dataset.tenant_id,
-model_provider_name=dataset.embedding_model_provider,
-model_name=dataset.embedding_model
-)
-# calc embedding use tokens
-tokens = embedding_model.get_num_tokens(content)
+tokens = 0
+if dataset.indexing_technique == 'high_quality':
+embedding_model = ModelFactory.get_embedding_model(
+tenant_id=dataset.tenant_id,
+model_provider_name=dataset.embedding_model_provider,
+model_name=dataset.embedding_model
+)
+# calc embedding use tokens
+tokens = embedding_model.get_num_tokens(content)
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document_id == document.id
).scalar()
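
Segment creation now charges embedding tokens only for high_quality datasets; economy segments are stored with tokens = 0 and skip the embedding call entirely. Condensed, the new accounting works as in this sketch (the content value is illustrative):

content = "How do I reset my password?"
tokens = 0
if dataset.indexing_technique == 'high_quality':
    embedding_model = ModelFactory.get_embedding_model(
        tenant_id=dataset.tenant_id,
        model_provider_name=dataset.embedding_model_provider,
        model_name=dataset.embedding_model,
    )
    # token count is what gets persisted on the new segment
    tokens = embedding_model.get_num_tokens(content)
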
@@ -973,15 +1016,16 @@ class SegmentService:
kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
else:
segment_hash = helper.generate_text_hash(content)
+tokens = 0
+if dataset.indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content)
segment.content = content
segment.index_node_hash = segment_hash
segment.word_count = len(content)
@@ -1013,7 +1057,7 @@ class SegmentService:
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
raise ValueError("Segment is deleting.")
# enabled segment need to delete index
if segment.enabled:
# send delete segment index task