Feat/dataset service api (#1245)
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
@@ -96,7 +96,7 @@ class DatasetService:
         embedding_model = None
         if indexing_technique == 'high_quality':
             embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id
+                tenant_id=tenant_id
             )
         dataset = Dataset(name=name, indexing_technique=indexing_technique)
         # dataset = Dataset(name=name, provider=provider, config=config)
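
Note on the hunk above: the dataset service API authenticates with a dataset API token rather than a console session, so `current_user` is not populated on that path; the tenant has to arrive as an explicit `tenant_id` argument. A minimal sketch of the two kinds of call sites (both illustrative; `api_token.tenant_id` is an assumed attribute, not code from this commit):

# Illustrative call sites only; names below are assumptions for the sketch.
# Console controller: a logged-in user exists, so the tenant comes from
# flask_login's current_user.
dataset = DatasetService.create_empty_dataset(
    tenant_id=current_user.current_tenant_id,
    name=args['name'],
    indexing_technique=args['indexing_technique'],
)

# Service API controller: token-authenticated, no current_user; the tenant
# is resolved from the API token instead.
dataset = DatasetService.create_empty_dataset(
    tenant_id=api_token.tenant_id,  # assumed attribute
    name=args['name'],
    indexing_technique=args['indexing_technique'],
)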
@@ -477,6 +477,7 @@ class DocumentService:
                 )
                 dataset.collection_binding_id = dataset_collection_binding.id

         documents = []
         batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
         if 'original_document_id' in document_data and document_data["original_document_id"]:
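
For reference, the `batch` identifier in the hunk above groups every document created in one call: a 14-digit timestamp plus a random 6-digit suffix. A runnable one-liner showing the format:

# The batch id format used above: time.strftime timestamp + 6 random digits.
import random
import time

batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
print(batch)  # e.g. '20230914153045123456'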
@@ -626,6 +627,9 @@ class DocumentService:
             document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
+            if document.display_status != 'available':
+                raise ValueError("Document is not available")
             # update document name
             if 'name' in document_data and document_data['name']:
                 document.name = document_data['name']
             # save process rule
             if 'process_rule' in document_data and document_data['process_rule']:
                 process_rule = document_data["process_rule"]
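
The added availability guard rejects an update while the original document is still indexing or is disabled/archived. Callers should expect the ValueError; a hedged sketch of surfacing it at an API layer (the handler and error payload shape are assumptions, not part of the commit):

# Sketch only: translate the service-layer guard into an HTTP error.
try:
    DocumentService.save_document_with_dataset_id(dataset, document_data, account)
except ValueError as e:
    # e.g. "Document is not available" while a previous run is indexing
    return {'code': 'invalid_param', 'message': str(e)}, 400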
@@ -767,7 +771,7 @@ class DocumentService:
         return dataset, documents, batch

     @classmethod
     def document_create_args_validate(cls, args: dict):
         if 'original_document_id' not in args or not args['original_document_id']:
             DocumentService.data_source_args_validate(args)
             DocumentService.process_rule_args_validate(args)
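
`document_create_args_validate` only runs the `data_source` and `process_rule` validators for a fresh create; an update identified by `original_document_id` re-uses the original document's source and rule. Roughly, the two accepted shapes look like this (field layout is illustrative, inferred from the calls above rather than copied from the commit):

# Illustrative argument shapes; {...} marks fields elided in this sketch.
create_args = {
    'data_source': {...},    # checked by data_source_args_validate
    'process_rule': {...},   # checked by process_rule_args_validate
}
update_args = {
    'original_document_id': 'existing-document-uuid',
    # data_source / process_rule validation is skipped on this path
}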
@@ -1014,6 +1018,66 @@ class SegmentService:
             segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first()
             return segment

+    @classmethod
+    def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
+        embedding_model = None
+        if dataset.indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=dataset.tenant_id,
+                model_provider_name=dataset.embedding_model_provider,
+                model_name=dataset.embedding_model
+            )
+        max_position = db.session.query(func.max(DocumentSegment.position)).filter(
+            DocumentSegment.document_id == document.id
+        ).scalar()
+        pre_segment_data_list = []
+        segment_data_list = []
+        for segment_item in segments:
+            content = segment_item['content']
+            doc_id = str(uuid.uuid4())
+            segment_hash = helper.generate_text_hash(content)
+            tokens = 0
+            if dataset.indexing_technique == 'high_quality' and embedding_model:
+                # calc embedding use tokens
+                tokens = embedding_model.get_num_tokens(content)
+            segment_document = DocumentSegment(
+                tenant_id=current_user.current_tenant_id,
+                dataset_id=document.dataset_id,
+                document_id=document.id,
+                index_node_id=doc_id,
+                index_node_hash=segment_hash,
+                position=max_position + 1 if max_position else 1,
+                content=content,
+                word_count=len(content),
+                tokens=tokens,
+                status='completed',
+                indexing_at=datetime.datetime.utcnow(),
+                completed_at=datetime.datetime.utcnow(),
+                created_by=current_user.id
+            )
+            if document.doc_form == 'qa_model':
+                segment_document.answer = segment_item['answer']
+            db.session.add(segment_document)
+            segment_data_list.append(segment_document)
+            pre_segment_data = {
+                'segment': segment_document,
+                'keywords': segment_item['keywords']
+            }
+            pre_segment_data_list.append(pre_segment_data)
+
+        try:
+            # save vector index
+            VectorService.multi_create_segment_vector(pre_segment_data_list, dataset)
+        except Exception as e:
+            logging.exception("create segment index failed")
+            for segment_document in segment_data_list:
+                segment_document.enabled = False
+                segment_document.disabled_at = datetime.datetime.utcnow()
+                segment_document.status = 'error'
+                segment_document.error = str(e)
+        db.session.commit()
+        return segment_data_list
+
     @classmethod
     def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset):
         indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
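
The new `multi_create_segment` expects each item to provide `content` and `keywords`, plus `answer` when the document is in QA form (`doc_form == 'qa_model'`). A minimal calling sketch, assuming a request context where `current_user` is bound and `document` / `dataset` are already loaded:

# Calling sketch; loading `document`/`dataset` and binding current_user
# (e.g. inside an authenticated request) are assumed to have happened.
segments = [
    {'content': 'Dify is an LLM application platform.', 'keywords': ['dify', 'llm']},
    {'content': 'Datasets are split into document segments.', 'keywords': ['dataset']},
]
created = SegmentService.multi_create_segment(segments, document, dataset)
# On vector-index failure the segments come back disabled with status 'error'.
print(len(created), 'segments persisted')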