Feat/dataset service api (#1245)

Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
This commit is contained in:
Jyong
2023-09-27 16:06:32 +08:00
committed by GitHub
parent 54ff03c35d
commit 46154c6705
43 changed files with 1636 additions and 906 deletions

View File

@@ -96,7 +96,7 @@ class DatasetService:
embedding_model = None
if indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
tenant_id=tenant_id
)
dataset = Dataset(name=name, indexing_technique=indexing_technique)
# dataset = Dataset(name=name, provider=provider, config=config)
@@ -477,6 +477,7 @@ class DocumentService:
)
dataset.collection_binding_id = dataset_collection_binding.id
documents = []
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
if 'original_document_id' in document_data and document_data["original_document_id"]:
@@ -626,6 +627,9 @@ class DocumentService:
document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
if document.display_status != 'available':
raise ValueError("Document is not available")
# update document name
if 'name' in document_data and document_data['name']:
document.name = document_data['name']
# save process rule
if 'process_rule' in document_data and document_data['process_rule']:
process_rule = document_data["process_rule"]
@@ -767,7 +771,7 @@ class DocumentService:
return dataset, documents, batch
@classmethod
def document_create_args_validate(cls, args: dict):
def document_create_args_validate(cls, args: dict):
if 'original_document_id' not in args or not args['original_document_id']:
DocumentService.data_source_args_validate(args)
DocumentService.process_rule_args_validate(args)
@@ -1014,6 +1018,66 @@ class SegmentService:
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first()
return segment
@classmethod
def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
embedding_model = None
if dataset.indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document_id == document.id
).scalar()
pre_segment_data_list = []
segment_data_list = []
for segment_item in segments:
content = segment_item['content']
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == 'high_quality' and embedding_model:
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content)
segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
status='completed',
indexing_at=datetime.datetime.utcnow(),
completed_at=datetime.datetime.utcnow(),
created_by=current_user.id
)
if document.doc_form == 'qa_model':
segment_document.answer = segment_item['answer']
db.session.add(segment_document)
segment_data_list.append(segment_document)
pre_segment_data = {
'segment': segment_document,
'keywords': segment_item['keywords']
}
pre_segment_data_list.append(pre_segment_data)
try:
# save vector index
VectorService.multi_create_segment_vector(pre_segment_data_list, dataset)
except Exception as e:
logging.exception("create segment index failed")
for segment_document in segment_data_list:
segment_document.enabled = False
segment_document.disabled_at = datetime.datetime.utcnow()
segment_document.status = 'error'
segment_document.error = str(e)
db.session.commit()
return segment_data_list
@classmethod
def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset):
indexing_cache_key = 'segment_{}_indexing'.format(segment.id)