Feat: Q&A format segmentation support (#668)

Co-authored-by: jyong <718720800@qq.com>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
This commit is contained in:
KVOJJJin
2023-07-28 20:47:15 +08:00
committed by GitHub
parent aae2fb8a30
commit cf93d8d6e2
52 changed files with 2038 additions and 274 deletions

View File

@@ -220,6 +220,7 @@ class DatasetIndexingEstimateApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
args = parser.parse_args()
# validate args
DocumentService.estimate_args_validate(args)
@@ -234,12 +235,12 @@ class DatasetIndexingEstimateApi(Resource):
raise NotFound("File not found.")
indexing_runner = IndexingRunner()
response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'])
response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'], args['doc_form'])
elif args['info_list']['data_source_type'] == 'notion_import':
indexing_runner = IndexingRunner()
response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'],
args['process_rule'])
args['process_rule'], args['doc_form'])
else:
raise ValueError('Data source type not support')
return response, 200

View File

@@ -60,6 +60,7 @@ document_fields = {
'display_status': fields.String,
'word_count': fields.Integer,
'hit_count': fields.Integer,
'doc_form': fields.String,
}
document_with_segments_fields = {
@@ -86,6 +87,7 @@ document_with_segments_fields = {
'total_segments': fields.Integer
}
class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document:
dataset = DatasetService.get_dataset(dataset_id)
@@ -269,6 +271,7 @@ class DatasetDocumentListApi(Resource):
parser.add_argument('process_rule', type=dict, required=False, location='json')
parser.add_argument('duplicate', type=bool, nullable=False, location='json')
parser.add_argument('original_document_id', type=str, required=False, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
args = parser.parse_args()
if not dataset.indexing_technique and not args['indexing_technique']:
@@ -313,6 +316,7 @@ class DatasetInitApi(Resource):
nullable=False, location='json')
parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
args = parser.parse_args()
# validate args
@@ -488,6 +492,8 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
DocumentSegment.status != 're_segment').count()
document.completed_segments = completed_segments
document.total_segments = total_segments
if document.is_paused:
document.indexing_status = 'paused'
documents_status.append(marshal(document, self.document_status_fields))
data = {
'data': documents_status
@@ -583,7 +589,8 @@ class DocumentDetailApi(DocumentResource):
'segment_count': document.segment_count,
'average_segment_length': document.average_segment_length,
'hit_count': document.hit_count,
'display_status': document.display_status
'display_status': document.display_status,
'doc_form': document.doc_form
}
else:
process_rules = DatasetService.get_process_rules(dataset_id)
@@ -614,7 +621,8 @@ class DocumentDetailApi(DocumentResource):
'segment_count': document.segment_count,
'average_segment_length': document.average_segment_length,
'hit_count': document.hit_count,
'display_status': document.display_status
'display_status': document.display_status,
'doc_form': document.doc_form
}
return response, 200

View File

@@ -15,8 +15,8 @@ from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
from libs.helper import TimestampField
from services.dataset_service import DatasetService, DocumentService
from tasks.add_segment_to_index_task import add_segment_to_index_task
from services.dataset_service import DatasetService, DocumentService, SegmentService
from tasks.enable_segment_to_index_task import enable_segment_to_index_task
from tasks.remove_segment_from_index_task import remove_segment_from_index_task
segment_fields = {
@@ -24,6 +24,7 @@ segment_fields = {
'position': fields.Integer,
'document_id': fields.String,
'content': fields.String,
'answer': fields.String,
'word_count': fields.Integer,
'tokens': fields.Integer,
'keywords': fields.List(fields.String),
@@ -125,6 +126,7 @@ class DatasetDocumentSegmentListApi(Resource):
return {
'data': marshal(segments, segment_fields),
'doc_form': document.doc_form,
'has_more': has_more,
'limit': limit,
'total': total
@@ -180,7 +182,7 @@ class DatasetDocumentSegmentApi(Resource):
# Set cache to prevent indexing the same segment multiple times
redis_client.setex(indexing_cache_key, 600, 1)
add_segment_to_index_task.delay(segment.id)
enable_segment_to_index_task.delay(segment.id)
return {'result': 'success'}, 200
elif action == "disable":
@@ -202,7 +204,92 @@ class DatasetDocumentSegmentApi(Resource):
raise InvalidActionError()
class DatasetDocumentSegmentAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.create_segment(args, document)
return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form
}, 200
class DatasetDocumentSegmentUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound('Segment not found.')
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(args, segment, document)
return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form
}, 200
api.add_resource(DatasetDocumentSegmentListApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
api.add_resource(DatasetDocumentSegmentApi,
'/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
api.add_resource(DatasetDocumentSegmentAddApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
api.add_resource(DatasetDocumentSegmentUpdateApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')

View File

@@ -28,6 +28,7 @@ segment_fields = {
'position': fields.Integer,
'document_id': fields.String,
'content': fields.String,
'answer': fields.String,
'word_count': fields.Integer,
'tokens': fields.Integer,
'keywords': fields.List(fields.String),