mirror of http://112.124.100.131/huang.ze/ebiz-dify-ai.git (synced 2025-12-11 20:06:54 +08:00)
Initial commit
api/controllers/console/datasets/datasets.py (281 lines, new file)
@@ -0,0 +1,281 @@
# -*- coding:utf-8 -*-
from flask import request
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, fields, marshal, marshal_with
from werkzeug.exceptions import NotFound, Forbidden

import services
from controllers.console import api
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
from libs.helper import TimestampField
from extensions.ext_database import db
from models.model import UploadFile
from services.dataset_service import DatasetService

dataset_detail_fields = {
    'id': fields.String,
    'name': fields.String,
    'description': fields.String,
    'provider': fields.String,
    'permission': fields.String,
    'data_source_type': fields.String,
    'indexing_technique': fields.String,
    'app_count': fields.Integer,
    'document_count': fields.Integer,
    'word_count': fields.Integer,
    'created_by': fields.String,
    'created_at': TimestampField,
    'updated_by': fields.String,
    'updated_at': TimestampField,
}

dataset_query_detail_fields = {
    "id": fields.String,
    "content": fields.String,
    "source": fields.String,
    "source_app_id": fields.String,
    "created_by_role": fields.String,
    "created_by": fields.String,
    "created_at": TimestampField
}


def _validate_name(name):
    if not name or len(name) < 1 or len(name) > 40:
        raise ValueError('Name must be between 1 and 40 characters.')
    return name


def _validate_description_length(description):
    if len(description) > 200:
        raise ValueError('Description cannot exceed 200 characters.')
    return description


class DatasetListApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        page = request.args.get('page', default=1, type=int)
        limit = request.args.get('limit', default=20, type=int)
        ids = request.args.getlist('ids')
        provider = request.args.get('provider', default="vendor")
        if ids:
            datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
        else:
            datasets, total = DatasetService.get_datasets(page, limit, provider,
                                                          current_user.current_tenant_id, current_user)

        response = {
            'data': marshal(datasets, dataset_detail_fields),
            'has_more': len(datasets) == limit,
            'limit': limit,
            'total': total,
            'page': page
        }
        return response, 200

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument('name', nullable=False, required=True,
                            help='Name is required. Name must be between 1 and 40 characters.',
                            type=_validate_name)
        parser.add_argument('indexing_technique', type=str, location='json',
                            choices=('high_quality', 'economy'),
                            help='Invalid indexing technique.')
        args = parser.parse_args()

        # The current user's role in the tenant must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        try:
            dataset = DatasetService.create_empty_dataset(
                tenant_id=current_user.current_tenant_id,
                name=args['name'],
                indexing_technique=args['indexing_technique'],
                account=current_user
            )
        except services.errors.dataset.DatasetNameDuplicateError:
            raise DatasetNameDuplicateError()

        return marshal(dataset, dataset_detail_fields), 201


class DatasetApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(
                dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        return marshal(dataset, dataset_detail_fields), 200

    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, dataset_id):
        dataset_id_str = str(dataset_id)

        parser = reqparse.RequestParser()
        parser.add_argument('name', nullable=False,
                            help='Name must be between 1 and 40 characters.',
                            type=_validate_name)
        parser.add_argument('description',
                            location='json', store_missing=False,
                            type=_validate_description_length)
        parser.add_argument('indexing_technique', type=str, location='json',
                            choices=('high_quality', 'economy'),
                            help='Invalid indexing technique.')
        parser.add_argument('permission', type=str, location='json', choices=(
            'only_me', 'all_team_members'), help='Invalid permission.')
        args = parser.parse_args()

        # The current user's role in the tenant must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        dataset = DatasetService.update_dataset(
            dataset_id_str, args, current_user)

        if dataset is None:
            raise NotFound("Dataset not found.")

        return marshal(dataset, dataset_detail_fields), 200

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, dataset_id):
        dataset_id_str = str(dataset_id)

        # The current user's role in the tenant must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        if DatasetService.delete_dataset(dataset_id_str, current_user):
            return {'result': 'success'}, 204
        else:
            raise NotFound("Dataset not found.")


class DatasetQueryApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        page = request.args.get('page', default=1, type=int)
        limit = request.args.get('limit', default=20, type=int)

        dataset_queries, total = DatasetService.get_dataset_queries(
            dataset_id=dataset.id,
            page=page,
            per_page=limit
        )

        response = {
            'data': marshal(dataset_queries, dataset_query_detail_fields),
            'has_more': len(dataset_queries) == limit,
            'limit': limit,
            'total': total,
            'page': page
        }
        return response, 200


class DatasetIndexingEstimateApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        segment_rule = request.get_json()
        file_detail = db.session.query(UploadFile).filter(
            UploadFile.tenant_id == current_user.current_tenant_id,
            UploadFile.id == segment_rule["file_id"]
        ).first()

        if file_detail is None:
            raise NotFound("File not found.")

        indexing_runner = IndexingRunner()
        response = indexing_runner.indexing_estimate(file_detail, segment_rule['process_rule'])
        return response, 200


class DatasetRelatedAppListApi(Resource):
    app_detail_kernel_fields = {
        'id': fields.String,
        'name': fields.String,
        'mode': fields.String,
        'icon': fields.String,
        'icon_background': fields.String,
    }

    related_app_list = {
        'data': fields.List(fields.Nested(app_detail_kernel_fields)),
        'total': fields.Integer,
    }

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(related_app_list)
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        app_dataset_joins = DatasetService.get_related_apps(dataset.id)

        related_apps = []
        for app_dataset_join in app_dataset_joins:
            app_model = app_dataset_join.app
            if app_model:
                related_apps.append(app_model)

        return {
            'data': related_apps,
            'total': len(related_apps)
        }, 200


api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
api.add_resource(DatasetIndexingEstimateApi, '/datasets/file-indexing-estimate')
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
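For reference, a minimal client-side sketch of how the dataset routes registered above could be exercised. The base URL http://localhost:5001/console/api and the session cookie are assumptions, not part of this commit; login_required only implies some authenticated session.

# Hypothetical usage sketch, not part of the commit: calls the /datasets
# routes registered above. BASE_URL and the session cookie are assumptions.
import requests

BASE_URL = 'http://localhost:5001/console/api'  # assumed mount point
session = requests.Session()
session.cookies.set('session', '<authenticated-session-cookie>')  # assumed auth

# POST /datasets -> 201 with dataset_detail_fields on success,
# 409 dataset_name_duplicate if the name already exists
resp = session.post(f'{BASE_URL}/datasets',
                    json={'name': 'my-dataset', 'indexing_technique': 'economy'})
print(resp.status_code, resp.json())

# GET /datasets?page=1&limit=20 -> paginated list; has_more mirrors len(data) == limit
resp = session.get(f'{BASE_URL}/datasets', params={'page': 1, 'limit': 20})
print(resp.json()['total'])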
api/controllers/console/datasets/datasets_document.py (682 lines, new file)
@@ -0,0 +1,682 @@
# -*- coding:utf-8 -*-
import random
from datetime import datetime

from flask import request
from flask_login import login_required, current_user
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import desc, asc
from werkzeug.exceptions import NotFound, Forbidden

import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DocumentAlreadyFinishedError, InvalidActionError, DocumentIndexingError, \
    InvalidMetadataError, ArchivedDocumentImmutableError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
from core.llm.error import ProviderTokenNotInitError
from extensions.ext_redis import redis_client
from libs.helper import TimestampField
from extensions.ext_database import db
from models.dataset import DatasetProcessRule, Dataset
from models.dataset import Document, DocumentSegment
from models.model import UploadFile
from services.dataset_service import DocumentService, DatasetService
from tasks.add_document_to_index_task import add_document_to_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task

dataset_fields = {
    'id': fields.String,
    'name': fields.String,
    'description': fields.String,
    'permission': fields.String,
    'data_source_type': fields.String,
    'indexing_technique': fields.String,
    'created_by': fields.String,
    'created_at': TimestampField,
}

document_fields = {
    'id': fields.String,
    'position': fields.Integer,
    'data_source_type': fields.String,
    'data_source_info': fields.Raw(attribute='data_source_info_dict'),
    'dataset_process_rule_id': fields.String,
    'name': fields.String,
    'created_from': fields.String,
    'created_by': fields.String,
    'created_at': TimestampField,
    'tokens': fields.Integer,
    'indexing_status': fields.String,
    'error': fields.String,
    'enabled': fields.Boolean,
    'disabled_at': TimestampField,
    'disabled_by': fields.String,
    'archived': fields.Boolean,
    'display_status': fields.String,
    'word_count': fields.Integer,
    'hit_count': fields.Integer,
}


class DocumentResource(Resource):
    def get_document(self, dataset_id: str, document_id: str) -> Document:
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound('Dataset not found.')

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        document = DocumentService.get_document(dataset_id, document_id)

        if not document:
            raise NotFound('Document not found.')

        if document.tenant_id != current_user.current_tenant_id:
            raise Forbidden('No permission.')

        return document


class GetProcessRuleApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        req_data = request.args

        document_id = req_data.get('document_id')
        if document_id:
            # get the latest process rule
            document = Document.query.get_or_404(document_id)

            dataset = DatasetService.get_dataset(document.dataset_id)

            if not dataset:
                raise NotFound('Dataset not found.')

            try:
                DatasetService.check_dataset_permission(dataset, current_user)
            except services.errors.account.NoPermissionError as e:
                raise Forbidden(str(e))

            # get the latest process rule
            dataset_process_rule = db.session.query(DatasetProcessRule). \
                filter(DatasetProcessRule.dataset_id == document.dataset_id). \
                order_by(DatasetProcessRule.created_at.desc()). \
                limit(1). \
                one_or_none()
            mode = dataset_process_rule.mode
            rules = dataset_process_rule.rules_dict
        else:
            mode = DocumentService.DEFAULT_RULES['mode']
            rules = DocumentService.DEFAULT_RULES['rules']

        return {
            'mode': mode,
            'rules': rules
        }


class DatasetDocumentListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        dataset_id = str(dataset_id)
        page = request.args.get('page', default=1, type=int)
        limit = request.args.get('limit', default=20, type=int)
        search = request.args.get('search', default=None, type=str)
        sort = request.args.get('sort', default='-created_at', type=str)

        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound('Dataset not found.')

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        query = Document.query.filter_by(
            dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)

        if search:
            search = f'%{search}%'
            query = query.filter(Document.name.like(search))

        if sort.startswith('-'):
            sort_logic = desc
            sort = sort[1:]
        else:
            sort_logic = asc

        if sort == 'hit_count':
            sub_query = db.select(DocumentSegment.document_id,
                                  db.func.sum(DocumentSegment.hit_count).label("total_hit_count")) \
                .group_by(DocumentSegment.document_id) \
                .subquery()

            query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id) \
                .order_by(sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)))
        elif sort == 'created_at':
            query = query.order_by(sort_logic(Document.created_at))
        else:
            query = query.order_by(desc(Document.created_at))

        paginated_documents = query.paginate(
            page=page, per_page=limit, max_per_page=100, error_out=False)
        documents = paginated_documents.items

        response = {
            'data': marshal(documents, document_fields),
            'has_more': len(documents) == limit,
            'limit': limit,
            'total': paginated_documents.total,
            'page': page
        }

        return response

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(document_fields)
    def post(self, dataset_id):
        dataset_id = str(dataset_id)

        dataset = DatasetService.get_dataset(dataset_id)

        if not dataset:
            raise NotFound('Dataset not found.')

        # The current user's role in the tenant must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        parser = reqparse.RequestParser()
        parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
                            location='json')
        parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
        parser.add_argument('duplicate', type=bool, nullable=False, location='json')
        args = parser.parse_args()

        if not dataset.indexing_technique and not args['indexing_technique']:
            raise ValueError('indexing_technique is required.')

        # validate args
        DocumentService.document_create_args_validate(args)

        try:
            document = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
        except ProviderTokenNotInitError:
            raise ProviderNotInitializeError()

        return document


class DatasetInitApi(Resource):
    dataset_and_document_fields = {
        'dataset': fields.Nested(dataset_fields),
        'document': fields.Nested(document_fields)
    }

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(dataset_and_document_fields)
    def post(self):
        # The current user's role in the tenant must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, required=True,
                            nullable=False, location='json')
        parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
        args = parser.parse_args()

        # validate args
        DocumentService.document_create_args_validate(args)

        try:
            dataset, document = DocumentService.save_document_without_dataset_id(
                tenant_id=current_user.current_tenant_id,
                document_data=args,
                account=current_user
            )
        except ProviderTokenNotInitError:
            raise ProviderNotInitializeError()

        response = {
            'dataset': dataset,
            'document': document
        }

        return response


class DocumentIndexingEstimateApi(DocumentResource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id, document_id):
        dataset_id = str(dataset_id)
        document_id = str(document_id)
        document = self.get_document(dataset_id, document_id)

        if document.indexing_status in ['completed', 'error']:
            raise DocumentAlreadyFinishedError()

        data_process_rule = document.dataset_process_rule
        data_process_rule_dict = data_process_rule.to_dict()

        response = {
            "tokens": 0,
            "total_price": 0,
            "currency": "USD",
            "total_segments": 0,
            "preview": []
        }

        if document.data_source_type == 'upload_file':
            data_source_info = document.data_source_info_dict
            if data_source_info and 'upload_file_id' in data_source_info:
                file_id = data_source_info['upload_file_id']

                file = db.session.query(UploadFile).filter(
                    UploadFile.tenant_id == document.tenant_id,
                    UploadFile.id == file_id
                ).first()

                # raise error if file not found
                if not file:
                    raise NotFound('File not found.')

                indexing_runner = IndexingRunner()
                response = indexing_runner.indexing_estimate(file, data_process_rule_dict)

        return response


class DocumentIndexingStatusApi(DocumentResource):
    document_status_fields = {
        'id': fields.String,
        'indexing_status': fields.String,
        'processing_started_at': TimestampField,
        'parsing_completed_at': TimestampField,
        'cleaning_completed_at': TimestampField,
        'splitting_completed_at': TimestampField,
        'completed_at': TimestampField,
        'paused_at': TimestampField,
        'error': fields.String,
        'stopped_at': TimestampField,
        'completed_segments': fields.Integer,
        'total_segments': fields.Integer,
    }

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id, document_id):
        dataset_id = str(dataset_id)
        document_id = str(document_id)
        document = self.get_document(dataset_id, document_id)

        completed_segments = DocumentSegment.query \
            .filter(DocumentSegment.completed_at.isnot(None),
                    DocumentSegment.document_id == str(document_id)) \
            .count()
        total_segments = DocumentSegment.query \
            .filter_by(document_id=str(document_id)) \
            .count()

        document.completed_segments = completed_segments
        document.total_segments = total_segments

        return marshal(document, self.document_status_fields)


class DocumentDetailApi(DocumentResource):
    METADATA_CHOICES = {'all', 'only', 'without'}

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id, document_id):
        dataset_id = str(dataset_id)
        document_id = str(document_id)
        document = self.get_document(dataset_id, document_id)

        metadata = request.args.get('metadata', 'all')
        if metadata not in self.METADATA_CHOICES:
            raise InvalidMetadataError(f'Invalid metadata value: {metadata}')

        if metadata == 'only':
            response = {
                'id': document.id,
                'doc_type': document.doc_type,
                'doc_metadata': document.doc_metadata
            }
        elif metadata == 'without':
            process_rules = DatasetService.get_process_rules(dataset_id)
            data_source_info = document.data_source_detail_dict
            response = {
                'id': document.id,
                'position': document.position,
                'data_source_type': document.data_source_type,
                'data_source_info': data_source_info,
                'dataset_process_rule_id': document.dataset_process_rule_id,
                'dataset_process_rule': process_rules,
                'name': document.name,
                'created_from': document.created_from,
                'created_by': document.created_by,
                'created_at': document.created_at.timestamp(),
                'tokens': document.tokens,
                'indexing_status': document.indexing_status,
                'completed_at': int(document.completed_at.timestamp()) if document.completed_at else None,
                'updated_at': int(document.updated_at.timestamp()) if document.updated_at else None,
                'indexing_latency': document.indexing_latency,
                'error': document.error,
                'enabled': document.enabled,
                'disabled_at': int(document.disabled_at.timestamp()) if document.disabled_at else None,
                'disabled_by': document.disabled_by,
                'archived': document.archived,
                'segment_count': document.segment_count,
                'average_segment_length': document.average_segment_length,
                'hit_count': document.hit_count,
                'display_status': document.display_status
            }
        else:
            process_rules = DatasetService.get_process_rules(dataset_id)
            data_source_info = document.data_source_detail_dict_()
            response = {
                'id': document.id,
                'position': document.position,
                'data_source_type': document.data_source_type,
                'data_source_info': data_source_info,
                'dataset_process_rule_id': document.dataset_process_rule_id,
                'dataset_process_rule': process_rules,
                'name': document.name,
                'created_from': document.created_from,
                'created_by': document.created_by,
                'created_at': document.created_at.timestamp(),
                'tokens': document.tokens,
                'indexing_status': document.indexing_status,
                'completed_at': int(document.completed_at.timestamp()) if document.completed_at else None,
                'updated_at': int(document.updated_at.timestamp()) if document.updated_at else None,
                'indexing_latency': document.indexing_latency,
                'error': document.error,
                'enabled': document.enabled,
                'disabled_at': int(document.disabled_at.timestamp()) if document.disabled_at else None,
                'disabled_by': document.disabled_by,
                'archived': document.archived,
                'doc_type': document.doc_type,
                'doc_metadata': document.doc_metadata,
                'segment_count': document.segment_count,
                'average_segment_length': document.average_segment_length,
                'hit_count': document.hit_count,
                'display_status': document.display_status
            }

        return response, 200


class DocumentProcessingApi(DocumentResource):
    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, dataset_id, document_id, action):
        dataset_id = str(dataset_id)
        document_id = str(document_id)
        document = self.get_document(dataset_id, document_id)

        # The current user's role in the tenant must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        if action == "pause":
            if document.indexing_status != "indexing":
                raise InvalidActionError('Document not in indexing state.')

            document.paused_by = current_user.id
            document.paused_at = datetime.utcnow()
            document.is_paused = True
            db.session.commit()

        elif action == "resume":
            if document.indexing_status not in ["paused", "error"]:
                raise InvalidActionError('Document not in paused or error state.')

            document.paused_by = None
            document.paused_at = None
            document.is_paused = False
            db.session.commit()
        else:
            raise InvalidActionError()

        return {'result': 'success'}, 200


class DocumentDeleteApi(DocumentResource):
    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, dataset_id, document_id):
        dataset_id = str(dataset_id)
        document_id = str(document_id)
        document = self.get_document(dataset_id, document_id)

        try:
            DocumentService.delete_document(document)
        except services.errors.document.DocumentIndexingError:
            raise DocumentIndexingError('Cannot delete document during indexing.')

        return {'result': 'success'}, 204


class DocumentMetadataApi(DocumentResource):
    @setup_required
    @login_required
    @account_initialization_required
    def put(self, dataset_id, document_id):
        dataset_id = str(dataset_id)
        document_id = str(document_id)
        document = self.get_document(dataset_id, document_id)

        req_data = request.get_json()

        doc_type = req_data.get('doc_type')
        doc_metadata = req_data.get('doc_metadata')

        # The current user's role in the tenant must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        if doc_type is None or doc_metadata is None:
            raise ValueError('Both doc_type and doc_metadata must be provided.')

        if doc_type not in DocumentService.DOCUMENT_METADATA_SCHEMA:
            raise ValueError('Invalid doc_type.')

        if not isinstance(doc_metadata, dict):
            raise ValueError('doc_metadata must be a dictionary.')

        metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]

        document.doc_metadata = {}

        # keep only schema keys whose values match the expected type
        for key, value_type in metadata_schema.items():
            value = doc_metadata.get(key)
            if value is not None and isinstance(value, value_type):
                document.doc_metadata[key] = value

        document.doc_type = doc_type
        document.updated_at = datetime.utcnow()
        db.session.commit()

        return {'result': 'success', 'message': 'Document metadata updated.'}, 200


class DocumentStatusApi(DocumentResource):
    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, dataset_id, document_id, action):
        dataset_id = str(dataset_id)
        document_id = str(document_id)
        document = self.get_document(dataset_id, document_id)

        # The current user's role in the tenant must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        indexing_cache_key = 'document_{}_indexing'.format(document.id)
        cache_result = redis_client.get(indexing_cache_key)
        if cache_result is not None:
            raise InvalidActionError("Document is being indexed, please try again later")

        if action == "enable":
            if document.enabled:
                raise InvalidActionError('Document already enabled.')

            document.enabled = True
            document.disabled_at = None
            document.disabled_by = None
            document.updated_at = datetime.utcnow()
            db.session.commit()

            # Set cache to prevent indexing the same document multiple times
            redis_client.setex(indexing_cache_key, 600, 1)

            add_document_to_index_task.delay(document_id)

            return {'result': 'success'}, 200

        elif action == "disable":
            if not document.enabled:
                raise InvalidActionError('Document already disabled.')

            document.enabled = False
            document.disabled_at = datetime.utcnow()
            document.disabled_by = current_user.id
            document.updated_at = datetime.utcnow()
            db.session.commit()

            # Set cache to prevent indexing the same document multiple times
            redis_client.setex(indexing_cache_key, 600, 1)

            remove_document_from_index_task.delay(document_id)

            return {'result': 'success'}, 200

        elif action == "archive":
            if document.archived:
                raise InvalidActionError('Document already archived.')

            document.archived = True
            document.archived_at = datetime.utcnow()
            document.archived_by = current_user.id
            document.updated_at = datetime.utcnow()
            db.session.commit()

            if document.enabled:
                # Set cache to prevent indexing the same document multiple times
                redis_client.setex(indexing_cache_key, 600, 1)

                remove_document_from_index_task.delay(document_id)

            return {'result': 'success'}, 200
        else:
            raise InvalidActionError()


class DocumentPauseApi(DocumentResource):
    def patch(self, dataset_id, document_id):
        """pause document."""
        dataset_id = str(dataset_id)
        document_id = str(document_id)

        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound('Dataset not found.')

        document = DocumentService.get_document(dataset.id, document_id)

        # 404 if document not found
        if document is None:
            raise NotFound("Document Not Exists.")

        # 403 if document is archived
        if DocumentService.check_archived(document):
            raise ArchivedDocumentImmutableError()

        try:
            # pause document
            DocumentService.pause_document(document)
        except services.errors.document.DocumentIndexingError:
            raise DocumentIndexingError('Cannot pause completed document.')

        return {'result': 'success'}, 204


class DocumentRecoverApi(DocumentResource):
    def patch(self, dataset_id, document_id):
        """recover document."""
        dataset_id = str(dataset_id)
        document_id = str(document_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound('Dataset not found.')
        document = DocumentService.get_document(dataset.id, document_id)

        # 404 if document not found
        if document is None:
            raise NotFound("Document Not Exists.")

        # 403 if document is archived
        if DocumentService.check_archived(document):
            raise ArchivedDocumentImmutableError()
        try:
            # recover document
            DocumentService.recover_document(document)
        except services.errors.document.DocumentIndexingError:
            raise DocumentIndexingError('Document is not in paused status.')

        return {'result': 'success'}, 204


api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
api.add_resource(DatasetDocumentListApi,
                 '/datasets/<uuid:dataset_id>/documents')
api.add_resource(DatasetInitApi,
                 '/datasets/init')
api.add_resource(DocumentIndexingEstimateApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate')
api.add_resource(DocumentIndexingStatusApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status')
api.add_resource(DocumentDetailApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>')
api.add_resource(DocumentProcessingApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>')
api.add_resource(DocumentDeleteApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>')
api.add_resource(DocumentMetadataApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata')
api.add_resource(DocumentStatusApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>')
api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause')
api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume')
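For reference, a minimal sketch of the create-then-poll flow implied by DatasetDocumentListApi.post and DocumentIndexingStatusApi above. The base URL, the session auth, and the exact field names inside data_source are assumptions not defined in this commit; only the route paths and the status fields come from the code above.

# Hypothetical usage sketch, not part of the commit: create a document in a
# dataset and poll its indexing status. data_source's inner shape is assumed.
import time
import requests

BASE_URL = 'http://localhost:5001/console/api'  # assumed mount point
session = requests.Session()  # assumed to carry an authenticated session cookie

dataset_id = '<dataset-uuid>'
payload = {
    'indexing_technique': 'high_quality',
    'data_source': {'type': 'upload_file', 'info': ['<upload-file-uuid>']},  # assumed shape
    'process_rule': {'mode': 'automatic', 'rules': {}},
}
doc = session.post(f'{BASE_URL}/datasets/{dataset_id}/documents', json=payload).json()

# GET .../indexing-status returns indexing_status plus
# completed_segments / total_segments counters
while True:
    status = session.get(
        f"{BASE_URL}/datasets/{dataset_id}/documents/{doc['id']}/indexing-status").json()
    if status['indexing_status'] in ('completed', 'error'):
        break
    time.sleep(2)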
api/controllers/console/datasets/datasets_segments.py (203 lines, new file)
@@ -0,0 +1,203 @@
# -*- coding:utf-8 -*-
from datetime import datetime

from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, fields, marshal
from werkzeug.exceptions import NotFound, Forbidden

import services
from controllers.console import api
from controllers.console.datasets.error import InvalidActionError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment

from libs.helper import TimestampField
from services.dataset_service import DatasetService, DocumentService
from tasks.add_segment_to_index_task import add_segment_to_index_task
from tasks.remove_segment_from_index_task import remove_segment_from_index_task

segment_fields = {
    'id': fields.String,
    'position': fields.Integer,
    'document_id': fields.String,
    'content': fields.String,
    'word_count': fields.Integer,
    'tokens': fields.Integer,
    'keywords': fields.List(fields.String),
    'index_node_id': fields.String,
    'index_node_hash': fields.String,
    'hit_count': fields.Integer,
    'enabled': fields.Boolean,
    'disabled_at': TimestampField,
    'disabled_by': fields.String,
    'status': fields.String,
    'created_by': fields.String,
    'created_at': TimestampField,
    'indexing_at': TimestampField,
    'completed_at': TimestampField,
    'error': fields.String,
    'stopped_at': TimestampField
}

segment_list_response = {
    'data': fields.List(fields.Nested(segment_fields)),
    'has_more': fields.Boolean,
    'limit': fields.Integer
}


class DatasetDocumentSegmentListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id, document_id):
        dataset_id = str(dataset_id)
        document_id = str(document_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound('Dataset not found.')

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        document = DocumentService.get_document(dataset_id, document_id)

        if not document:
            raise NotFound('Document not found.')

        parser = reqparse.RequestParser()
        parser.add_argument('last_id', type=str, default=None, location='args')
        parser.add_argument('limit', type=int, default=20, location='args')
        parser.add_argument('status', type=str,
                            action='append', default=[], location='args')
        parser.add_argument('hit_count_gte', type=int,
                            default=None, location='args')
        parser.add_argument('enabled', type=str, default='all', location='args')
        args = parser.parse_args()

        last_id = args['last_id']
        limit = min(args['limit'], 100)
        status_list = args['status']
        hit_count_gte = args['hit_count_gte']

        query = DocumentSegment.query.filter(
            DocumentSegment.document_id == str(document_id),
            DocumentSegment.tenant_id == current_user.current_tenant_id
        )

        if last_id is not None:
            last_segment = DocumentSegment.query.get(str(last_id))
            if last_segment:
                query = query.filter(
                    DocumentSegment.position > last_segment.position)
            else:
                return {'data': [], 'has_more': False, 'limit': limit}, 200

        if status_list:
            query = query.filter(DocumentSegment.status.in_(status_list))

        if hit_count_gte is not None:
            query = query.filter(DocumentSegment.hit_count >= hit_count_gte)

        if args['enabled'].lower() != 'all':
            if args['enabled'].lower() == 'true':
                query = query.filter(DocumentSegment.enabled == True)
            elif args['enabled'].lower() == 'false':
                query = query.filter(DocumentSegment.enabled == False)

        total = query.count()
        # fetch one extra row; its presence tells us whether another page exists
        segments = query.order_by(DocumentSegment.position).limit(limit + 1).all()

        has_more = False
        if len(segments) > limit:
            has_more = True
            segments = segments[:-1]

        return {
            'data': marshal(segments, segment_fields),
            'has_more': has_more,
            'limit': limit,
            'total': total
        }, 200


class DatasetDocumentSegmentApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, dataset_id, segment_id, action):
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound('Dataset not found.')

        # The current user's role in the tenant must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        segment = DocumentSegment.query.filter(
            DocumentSegment.id == str(segment_id),
            DocumentSegment.tenant_id == current_user.current_tenant_id
        ).first()

        if not segment:
            raise NotFound('Segment not found.')

        document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id)
        cache_result = redis_client.get(document_indexing_cache_key)
        if cache_result is not None:
            raise InvalidActionError("Document is being indexed, please try again later")

        indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
        cache_result = redis_client.get(indexing_cache_key)
        if cache_result is not None:
            raise InvalidActionError("Segment is being indexed, please try again later")

        if action == "enable":
            if segment.enabled:
                raise InvalidActionError("Segment is already enabled.")

            segment.enabled = True
            segment.disabled_at = None
            segment.disabled_by = None
            db.session.commit()

            # Set cache to prevent indexing the same segment multiple times
            redis_client.setex(indexing_cache_key, 600, 1)

            add_segment_to_index_task.delay(segment.id)

            return {'result': 'success'}, 200
        elif action == "disable":
            if not segment.enabled:
                raise InvalidActionError("Segment is already disabled.")

            segment.enabled = False
            segment.disabled_at = datetime.utcnow()
            segment.disabled_by = current_user.id
            db.session.commit()

            # Set cache to prevent indexing the same segment multiple times
            redis_client.setex(indexing_cache_key, 600, 1)

            remove_segment_from_index_task.delay(segment.id)

            return {'result': 'success'}, 200
        else:
            raise InvalidActionError()


api.add_resource(DatasetDocumentSegmentListApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
api.add_resource(DatasetDocumentSegmentApi,
                 '/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
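DatasetDocumentSegmentListApi pages by cursor rather than offset: it orders by position, fetches limit + 1 rows, and uses the extra row only to set has_more, so a client passes the id of the last segment it saw as last_id. A sketch of walking the full list under the same assumed base URL and session auth as above:

# Hypothetical usage sketch, not part of the commit: iterate all segments of
# a document via the last_id cursor exposed by the endpoint above.
import requests

BASE_URL = 'http://localhost:5001/console/api'  # assumed mount point
session = requests.Session()  # assumed authenticated

def iter_segments(dataset_id, document_id, limit=100):
    last_id = None
    while True:
        params = {'limit': limit}
        if last_id:
            params['last_id'] = last_id
        page = session.get(
            f'{BASE_URL}/datasets/{dataset_id}/documents/{document_id}/segments',
            params=params).json()
        yield from page['data']
        if not page['has_more']:
            break
        last_id = page['data'][-1]['id']  # cursor = last position seen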
api/controllers/console/datasets/error.py (73 lines, new file)
@@ -0,0 +1,73 @@
from libs.exception import BaseHTTPException


class NoFileUploadedError(BaseHTTPException):
    error_code = 'no_file_uploaded'
    description = "No file uploaded."
    code = 400


class TooManyFilesError(BaseHTTPException):
    error_code = 'too_many_files'
    description = "Only one file is allowed."
    code = 400


class FileTooLargeError(BaseHTTPException):
    error_code = 'file_too_large'
    description = "File size exceeded. {message}"
    code = 413


class UnsupportedFileTypeError(BaseHTTPException):
    error_code = 'unsupported_file_type'
    description = "File type not allowed."
    code = 415


class HighQualityDatasetOnlyError(BaseHTTPException):
    error_code = 'high_quality_dataset_only'
    description = "High quality dataset only."
    code = 400


class DatasetNotInitializedError(BaseHTTPException):
    error_code = 'dataset_not_initialized'
    description = "Dataset not initialized."
    code = 400


class ArchivedDocumentImmutableError(BaseHTTPException):
    error_code = 'archived_document_immutable'
    description = "Cannot process an archived document."
    code = 403


class DatasetNameDuplicateError(BaseHTTPException):
    error_code = 'dataset_name_duplicate'
    description = "Dataset name already exists."
    code = 409


class InvalidActionError(BaseHTTPException):
    error_code = 'invalid_action'
    description = "Invalid action."
    code = 400


class DocumentAlreadyFinishedError(BaseHTTPException):
    error_code = 'document_already_finished'
    description = "Document already finished."
    code = 400


class DocumentIndexingError(BaseHTTPException):
    error_code = 'document_indexing'
    description = "Document is being indexed."
    code = 400


class InvalidMetadataError(BaseHTTPException):
    error_code = 'invalid_metadata'
    description = "Invalid metadata."
    code = 400
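These classes only declare an error_code, a human-readable description, and an HTTP status; the serialization into a response body lives in libs.exception.BaseHTTPException, which is not part of this commit. A hedged sketch of how a client might branch on them, assuming a {'code', 'message', 'status'} JSON shape (an assumption, not confirmed here):

# Hypothetical sketch, not part of the commit. The response JSON shape is an
# assumption about libs.exception.BaseHTTPException; only the error_code
# strings and HTTP statuses above are defined in this file.
import requests

resp = requests.post('http://localhost:5001/console/api/datasets',  # assumed URL
                     json={'name': 'already-taken'})
if resp.status_code == 409 and resp.json().get('code') == 'dataset_name_duplicate':
    print('pick another dataset name')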
api/controllers/console/datasets/file.py (147 lines, new file)
@@ -0,0 +1,147 @@
import datetime
import hashlib
import tempfile
import time
import uuid
from pathlib import Path

from cachetools import TTLCache
from flask import request, current_app
from flask_login import login_required, current_user
from flask_restful import Resource, marshal_with, fields
from werkzeug.exceptions import NotFound

from controllers.console import api
from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
    UnsupportedFileTypeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.index.readers.html_parser import HTMLParser
from core.index.readers.pdf_parser import PDFParser
from extensions.ext_storage import storage
from libs.helper import TimestampField
from extensions.ext_database import db
from models.model import UploadFile

cache = TTLCache(maxsize=None, ttl=30)

FILE_SIZE_LIMIT = 15 * 1024 * 1024  # 15MB
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm']
PREVIEW_WORDS_LIMIT = 3000


class FileApi(Resource):
    file_fields = {
        'id': fields.String,
        'name': fields.String,
        'size': fields.Integer,
        'extension': fields.String,
        'mime_type': fields.String,
        'created_by': fields.String,
        'created_at': TimestampField,
    }

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(file_fields)
    def post(self):
        # check the request before touching request.files['file'],
        # otherwise a missing file raises KeyError instead of a clean 400
        if 'file' not in request.files:
            raise NoFileUploadedError()

        if len(request.files) > 1:
            raise TooManyFilesError()

        # get file from request
        file = request.files['file']

        file_content = file.read()
        file_size = len(file_content)

        if file_size > FILE_SIZE_LIMIT:
            message = f"({file_size} > {FILE_SIZE_LIMIT})"
            raise FileTooLargeError(message)

        extension = file.filename.split('.')[-1]
        if extension not in ALLOWED_EXTENSIONS:
            raise UnsupportedFileTypeError()

        # use a uuid as the stored file name
        file_uuid = str(uuid.uuid4())
        file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension

        # save file to storage
        storage.save(file_key, file_content)

        # save file to db
        config = current_app.config
        upload_file = UploadFile(
            tenant_id=current_user.current_tenant_id,
            storage_type=config['STORAGE_TYPE'],
            key=file_key,
            name=file.filename,
            size=file_size,
            extension=extension,
            mime_type=file.mimetype,
            created_by=current_user.id,
            created_at=datetime.datetime.utcnow(),
            used=False,
            hash=hashlib.sha3_256(file_content).hexdigest()
        )

        db.session.add(upload_file)
        db.session.commit()

        return upload_file, 201


class FilePreviewApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, file_id):
        file_id = str(file_id)

        key = file_id + request.path
        cached_response = cache.get(key)
        if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
            return cached_response['response']

        upload_file = db.session.query(UploadFile) \
            .filter(UploadFile.id == file_id) \
            .first()

        if not upload_file:
            raise NotFound("File not found")

        # extract text from file
        extension = upload_file.extension
        if extension not in ALLOWED_EXTENSIONS:
            raise UnsupportedFileTypeError()

        with tempfile.TemporaryDirectory() as temp_dir:
            suffix = Path(upload_file.key).suffix
            filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
            storage.download(upload_file.key, filepath)

            if extension == 'pdf':
                parser = PDFParser({'upload_file': upload_file})
                text = parser.parse_file(Path(filepath))
            elif extension in ['html', 'htm']:
                # Use BeautifulSoup to extract text
                parser = HTMLParser()
                text = parser.parse_file(Path(filepath))
            else:
                # ['txt', 'markdown', 'md']
                with open(filepath, "rb") as fp:
                    data = fp.read()
                text = data.decode(encoding='utf-8').strip() if data else ''

        text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
        return {'content': text}


api.add_resource(FileApi, '/files/upload')
api.add_resource(FilePreviewApi, '/files/<uuid:file_id>/preview')
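For reference, a minimal upload-then-preview sketch against the two routes above. The multipart field must be named 'file', the size limit is FILE_SIZE_LIMIT (15MB), and only the extensions in ALLOWED_EXTENSIONS are accepted; the base URL and session auth are assumptions.

# Hypothetical usage sketch, not part of the commit: upload a file and
# fetch its text preview. BASE_URL and session auth are assumptions.
import requests

BASE_URL = 'http://localhost:5001/console/api'  # assumed mount point
session = requests.Session()  # assumed authenticated

with open('notes.md', 'rb') as f:
    upload = session.post(f'{BASE_URL}/files/upload', files={'file': f}).json()

# the preview is the first PREVIEW_WORDS_LIMIT characters of the extracted text
preview = session.get(f"{BASE_URL}/files/{upload['id']}/preview").json()
print(preview['content'][:200])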
api/controllers/console/datasets/hit_testing.py (100 lines, new file)
@@ -0,0 +1,100 @@
import logging

from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, marshal, fields
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden

import services
from controllers.console import api
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from libs.helper import TimestampField
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService

document_fields = {
    'id': fields.String,
    'data_source_type': fields.String,
    'name': fields.String,
    'doc_type': fields.String,
}

segment_fields = {
    'id': fields.String,
    'position': fields.Integer,
    'document_id': fields.String,
    'content': fields.String,
    'word_count': fields.Integer,
    'tokens': fields.Integer,
    'keywords': fields.List(fields.String),
    'index_node_id': fields.String,
    'index_node_hash': fields.String,
    'hit_count': fields.Integer,
    'enabled': fields.Boolean,
    'disabled_at': TimestampField,
    'disabled_by': fields.String,
    'status': fields.String,
    'created_by': fields.String,
    'created_at': TimestampField,
    'indexing_at': TimestampField,
    'completed_at': TimestampField,
    'error': fields.String,
    'stopped_at': TimestampField,
    'document': fields.Nested(document_fields),
}

hit_testing_record_fields = {
    'segment': fields.Nested(segment_fields),
    'score': fields.Float,
    'tsne_position': fields.Raw
}


class HitTestingApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, dataset_id):
        dataset_id_str = str(dataset_id)

        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        # only high quality dataset can be used for hit testing
        if dataset.indexing_technique != 'high_quality':
            raise HighQualityDatasetOnlyError()

        parser = reqparse.RequestParser()
        parser.add_argument('query', type=str, location='json')
        args = parser.parse_args()

        query = args['query']

        if not query or len(query) > 250:
            raise ValueError('Query is required and cannot exceed 250 characters')

        try:
            response = HitTestingService.retrieve(
                dataset=dataset,
                query=query,
                account=current_user,
                limit=10,
            )

            return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
        except services.errors.index.IndexNotInitializedError:
            raise DatasetNotInitializedError()
        except Exception as e:
            logging.exception("Hit testing failed.")
            raise InternalServerError(str(e))


api.add_resource(HitTestingApi, '/datasets/<uuid:dataset_id>/hit-testing')
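For reference, a minimal sketch of running a retrieval hit test against a high_quality dataset through the route above. Base URL and session auth are assumptions; the response fields (records with segment, score, tsne_position) come from hit_testing_record_fields above.

# Hypothetical usage sketch, not part of the commit: query hit testing and
# print the top matches. BASE_URL and session auth are assumptions.
import requests

BASE_URL = 'http://localhost:5001/console/api'  # assumed mount point
session = requests.Session()  # assumed authenticated

resp = session.post(f'{BASE_URL}/datasets/<dataset-uuid>/hit-testing',
                    json={'query': 'What is our refund policy?'})
for record in resp.json()['records']:
    # each record carries the matched segment, a similarity score, and a
    # t-SNE position used for visualization
    print(record['score'], record['segment']['content'][:80])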