Initial commit

This commit is contained in:
John Wang
2023-05-15 08:51:32 +08:00
commit db896255d6
744 changed files with 56028 additions and 0 deletions

View File

@@ -0,0 +1,281 @@
# -*- coding:utf-8 -*-
from flask import request
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, fields, marshal, marshal_with
from werkzeug.exceptions import NotFound, Forbidden
import services
from controllers.console import api
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
from libs.helper import TimestampField
from extensions.ext_database import db
from models.model import UploadFile
from services.dataset_service import DatasetService
dataset_detail_fields = {
'id': fields.String,
'name': fields.String,
'description': fields.String,
'provider': fields.String,
'permission': fields.String,
'data_source_type': fields.String,
'indexing_technique': fields.String,
'app_count': fields.Integer,
'document_count': fields.Integer,
'word_count': fields.Integer,
'created_by': fields.String,
'created_at': TimestampField,
'updated_by': fields.String,
'updated_at': TimestampField,
}
dataset_query_detail_fields = {
"id": fields.String,
"content": fields.String,
"source": fields.String,
"source_app_id": fields.String,
"created_by_role": fields.String,
"created_by": fields.String,
"created_at": TimestampField
}
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError('Name must be between 1 to 40 characters.')
return name
def _validate_description_length(description):
if len(description) > 200:
raise ValueError('Description cannot exceed 200 characters.')
return description
class DatasetListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
ids = request.args.getlist('ids')
provider = request.args.get('provider', default="vendor")
if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
else:
datasets, total = DatasetService.get_datasets(page, limit, provider,
current_user.current_tenant_id, current_user)
response = {
'data': marshal(datasets, dataset_detail_fields),
'has_more': len(datasets) == limit,
'limit': limit,
'total': total,
'page': page
}
return response, 200
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('name', nullable=False, required=True,
help='type is required. Name must be between 1 to 40 characters.',
type=_validate_name)
parser.add_argument('indexing_technique', type=str, location='json',
choices=('high_quality', 'economy'),
help='Invalid indexing technique.')
args = parser.parse_args()
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
try:
dataset = DatasetService.create_empty_dataset(
tenant_id=current_user.current_tenant_id,
name=args['name'],
indexing_technique=args['indexing_technique'],
account=current_user
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return marshal(dataset, dataset_detail_fields), 201
class DatasetApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(
dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
return marshal(dataset, dataset_detail_fields), 200
@setup_required
@login_required
@account_initialization_required
def patch(self, dataset_id):
dataset_id_str = str(dataset_id)
parser = reqparse.RequestParser()
parser.add_argument('name', nullable=False,
help='type is required. Name must be between 1 to 40 characters.',
type=_validate_name)
parser.add_argument('description',
location='json', store_missing=False,
type=_validate_description_length)
parser.add_argument('indexing_technique', type=str, location='json',
choices=('high_quality', 'economy'),
help='Invalid indexing technique.')
parser.add_argument('permission', type=str, location='json', choices=(
'only_me', 'all_team_members'), help='Invalid permission.')
args = parser.parse_args()
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
dataset = DatasetService.update_dataset(
dataset_id_str, args, current_user)
if dataset is None:
raise NotFound("Dataset not found.")
return marshal(dataset, dataset_detail_fields), 200
@setup_required
@login_required
@account_initialization_required
def delete(self, dataset_id):
dataset_id_str = str(dataset_id)
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
if DatasetService.delete_dataset(dataset_id_str, current_user):
return {'result': 'success'}, 204
else:
raise NotFound("Dataset not found.")
class DatasetQueryApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
dataset_queries, total = DatasetService.get_dataset_queries(
dataset_id=dataset.id,
page=page,
per_page=limit
)
response = {
'data': marshal(dataset_queries, dataset_query_detail_fields),
'has_more': len(dataset_queries) == limit,
'limit': limit,
'total': total,
'page': page
}
return response, 200
class DatasetIndexingEstimateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
segment_rule = request.get_json()
file_detail = db.session.query(UploadFile).filter(
UploadFile.tenant_id == current_user.current_tenant_id,
UploadFile.id == segment_rule["file_id"]
).first()
if file_detail is None:
raise NotFound("File not found.")
indexing_runner = IndexingRunner()
response = indexing_runner.indexing_estimate(file_detail, segment_rule['process_rule'])
return response, 200
class DatasetRelatedAppListApi(Resource):
app_detail_kernel_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
}
related_app_list = {
'data': fields.List(fields.Nested(app_detail_kernel_fields)),
'total': fields.Integer,
}
@setup_required
@login_required
@account_initialization_required
@marshal_with(related_app_list)
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
app_dataset_joins = DatasetService.get_related_apps(dataset.id)
related_apps = []
for app_dataset_join in app_dataset_joins:
app_model = app_dataset_join.app
if app_model:
related_apps.append(app_model)
return {
'data': related_apps,
'total': len(related_apps)
}, 200
api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
api.add_resource(DatasetIndexingEstimateApi, '/datasets/file-indexing-estimate')
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')

View File

@@ -0,0 +1,682 @@
# -*- coding:utf-8 -*-
import random
from datetime import datetime
from flask import request
from flask_login import login_required, current_user
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import desc, asc
from werkzeug.exceptions import NotFound, Forbidden
import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DocumentAlreadyFinishedError, InvalidActionError, DocumentIndexingError, \
InvalidMetadataError, ArchivedDocumentImmutableError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
from core.llm.error import ProviderTokenNotInitError
from extensions.ext_redis import redis_client
from libs.helper import TimestampField
from extensions.ext_database import db
from models.dataset import DatasetProcessRule, Dataset
from models.dataset import Document, DocumentSegment
from models.model import UploadFile
from services.dataset_service import DocumentService, DatasetService
from tasks.add_document_to_index_task import add_document_to_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task
dataset_fields = {
'id': fields.String,
'name': fields.String,
'description': fields.String,
'permission': fields.String,
'data_source_type': fields.String,
'indexing_technique': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
}
document_fields = {
'id': fields.String,
'position': fields.Integer,
'data_source_type': fields.String,
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
'dataset_process_rule_id': fields.String,
'name': fields.String,
'created_from': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'tokens': fields.Integer,
'indexing_status': fields.String,
'error': fields.String,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'archived': fields.Boolean,
'display_status': fields.String,
'word_count': fields.Integer,
'hit_count': fields.Integer,
}
class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document:
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
if document.tenant_id != current_user.current_tenant_id:
raise Forbidden('No permission.')
return document
class GetProcessRuleApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
req_data = request.args
document_id = req_data.get('document_id')
if document_id:
# get the latest process rule
document = Document.query.get_or_404(document_id)
dataset = DatasetService.get_dataset(document.dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# get the latest process rule
dataset_process_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.dataset_id == document.dataset_id). \
order_by(DatasetProcessRule.created_at.desc()). \
limit(1). \
one_or_none()
mode = dataset_process_rule.mode
rules = dataset_process_rule.rules_dict
else:
mode = DocumentService.DEFAULT_RULES['mode']
rules = DocumentService.DEFAULT_RULES['rules']
return {
'mode': mode,
'rules': rules
}
class DatasetDocumentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
dataset_id = str(dataset_id)
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
search = request.args.get('search', default=None, type=str)
sort = request.args.get('sort', default='-created_at', type=str)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
query = Document.query.filter_by(
dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
if search:
search = f'%{search}%'
query = query.filter(Document.name.like(search))
if sort.startswith('-'):
sort_logic = desc
sort = sort[1:]
else:
sort_logic = asc
if sort == 'hit_count':
sub_query = db.select(DocumentSegment.document_id,
db.func.sum(DocumentSegment.hit_count).label("total_hit_count")) \
.group_by(DocumentSegment.document_id) \
.subquery()
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id) \
.order_by(sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)))
elif sort == 'created_at':
query = query.order_by(sort_logic(Document.created_at))
else:
query = query.order_by(desc(Document.created_at))
paginated_documents = query.paginate(
page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
response = {
'data': marshal(documents, document_fields),
'has_more': len(documents) == limit,
'limit': limit,
'total': paginated_documents.total,
'page': page
}
return response
@setup_required
@login_required
@account_initialization_required
@marshal_with(document_fields)
def post(self, dataset_id):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
parser = reqparse.RequestParser()
parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
location='json')
parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('duplicate', type=bool, nullable=False, location='json')
args = parser.parse_args()
if not dataset.indexing_technique and not args['indexing_technique']:
raise ValueError('indexing_technique is required.')
# validate args
DocumentService.document_create_args_validate(args)
try:
document = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
return document
class DatasetInitApi(Resource):
dataset_and_document_fields = {
'dataset': fields.Nested(dataset_fields),
'document': fields.Nested(document_fields)
}
@setup_required
@login_required
@account_initialization_required
@marshal_with(dataset_and_document_fields)
def post(self):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, required=True,
nullable=False, location='json')
parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
args = parser.parse_args()
# validate args
DocumentService.document_create_args_validate(args)
try:
dataset, document = DocumentService.save_document_without_dataset_id(
tenant_id=current_user.current_tenant_id,
document_data=args,
account=current_user
)
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
response = {
'dataset': dataset,
'document': document
}
return response
class DocumentIndexingEstimateApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
if document.indexing_status in ['completed', 'error']:
raise DocumentAlreadyFinishedError()
data_process_rule = document.dataset_process_rule
data_process_rule_dict = data_process_rule.to_dict()
response = {
"tokens": 0,
"total_price": 0,
"currency": "USD",
"total_segments": 0,
"preview": []
}
if document.data_source_type == 'upload_file':
data_source_info = document.data_source_info_dict
if data_source_info and 'upload_file_id' in data_source_info:
file_id = data_source_info['upload_file_id']
file = db.session.query(UploadFile).filter(
UploadFile.tenant_id == document.tenant_id,
UploadFile.id == file_id
).first()
# raise error if file not found
if not file:
raise NotFound('File not found.')
indexing_runner = IndexingRunner()
response = indexing_runner.indexing_estimate(file, data_process_rule_dict)
return response
class DocumentIndexingStatusApi(DocumentResource):
document_status_fields = {
'id': fields.String,
'indexing_status': fields.String,
'processing_started_at': TimestampField,
'parsing_completed_at': TimestampField,
'cleaning_completed_at': TimestampField,
'splitting_completed_at': TimestampField,
'completed_at': TimestampField,
'paused_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField,
'completed_segments': fields.Integer,
'total_segments': fields.Integer,
}
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
completed_segments = DocumentSegment.query \
.filter(DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id)) \
.count()
total_segments = DocumentSegment.query \
.filter_by(document_id=str(document_id)) \
.count()
document.completed_segments = completed_segments
document.total_segments = total_segments
return marshal(document, self.document_status_fields)
class DocumentDetailApi(DocumentResource):
METADATA_CHOICES = {'all', 'only', 'without'}
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
metadata = request.args.get('metadata', 'all')
if metadata not in self.METADATA_CHOICES:
raise InvalidMetadataError(f'Invalid metadata value: {metadata}')
if metadata == 'only':
response = {
'id': document.id,
'doc_type': document.doc_type,
'doc_metadata': document.doc_metadata
}
elif metadata == 'without':
process_rules = DatasetService.get_process_rules(dataset_id)
data_source_info = document.data_source_detail_dict
response = {
'id': document.id,
'position': document.position,
'data_source_type': document.data_source_type,
'data_source_info': data_source_info,
'dataset_process_rule_id': document.dataset_process_rule_id,
'dataset_process_rule': process_rules,
'name': document.name,
'created_from': document.created_from,
'created_by': document.created_by,
'created_at': document.created_at.timestamp(),
'tokens': document.tokens,
'indexing_status': document.indexing_status,
'completed_at': int(document.completed_at.timestamp()) if document.completed_at else None,
'updated_at': int(document.updated_at.timestamp()) if document.updated_at else None,
'indexing_latency': document.indexing_latency,
'error': document.error,
'enabled': document.enabled,
'disabled_at': int(document.disabled_at.timestamp()) if document.disabled_at else None,
'disabled_by': document.disabled_by,
'archived': document.archived,
'segment_count': document.segment_count,
'average_segment_length': document.average_segment_length,
'hit_count': document.hit_count,
'display_status': document.display_status
}
else:
process_rules = DatasetService.get_process_rules(dataset_id)
data_source_info = document.data_source_detail_dict_()
response = {
'id': document.id,
'position': document.position,
'data_source_type': document.data_source_type,
'data_source_info': data_source_info,
'dataset_process_rule_id': document.dataset_process_rule_id,
'dataset_process_rule': process_rules,
'name': document.name,
'created_from': document.created_from,
'created_by': document.created_by,
'created_at': document.created_at.timestamp(),
'tokens': document.tokens,
'indexing_status': document.indexing_status,
'completed_at': int(document.completed_at.timestamp())if document.completed_at else None,
'updated_at': int(document.updated_at.timestamp()) if document.updated_at else None,
'indexing_latency': document.indexing_latency,
'error': document.error,
'enabled': document.enabled,
'disabled_at': int(document.disabled_at.timestamp()) if document.disabled_at else None,
'disabled_by': document.disabled_by,
'archived': document.archived,
'doc_type': document.doc_type,
'doc_metadata': document.doc_metadata,
'segment_count': document.segment_count,
'average_segment_length': document.average_segment_length,
'hit_count': document.hit_count,
'display_status': document.display_status
}
return response, 200
class DocumentProcessingApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def patch(self, dataset_id, document_id, action):
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
if action == "pause":
if document.indexing_status != "indexing":
raise InvalidActionError('Document not in indexing state.')
document.paused_by = current_user.id
document.paused_at = datetime.utcnow()
document.is_paused = True
db.session.commit()
elif action == "resume":
if document.indexing_status not in ["paused", "error"]:
raise InvalidActionError('Document not in paused or error state.')
document.paused_by = None
document.paused_at = None
document.is_paused = False
db.session.commit()
else:
raise InvalidActionError()
return {'result': 'success'}, 200
class DocumentDeleteApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def delete(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
try:
DocumentService.delete_document(document)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError('Cannot delete document during indexing.')
return {'result': 'success'}, 204
class DocumentMetadataApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def put(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
req_data = request.get_json()
doc_type = req_data.get('doc_type')
doc_metadata = req_data.get('doc_metadata')
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
if doc_type is None or doc_metadata is None:
raise ValueError('Both doc_type and doc_metadata must be provided.')
if doc_type not in DocumentService.DOCUMENT_METADATA_SCHEMA:
raise ValueError('Invalid doc_type.')
if not isinstance(doc_metadata, dict):
raise ValueError('doc_metadata must be a dictionary.')
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
document.doc_metadata = {}
for key, value_type in metadata_schema.items():
value = doc_metadata.get(key)
if value is not None and isinstance(value, value_type):
document.doc_metadata[key] = value
document.doc_type = doc_type
document.updated_at = datetime.utcnow()
db.session.commit()
return {'result': 'success', 'message': 'Document metadata updated.'}, 200
class DocumentStatusApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def patch(self, dataset_id, document_id, action):
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
indexing_cache_key = 'document_{}_indexing'.format(document.id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
raise InvalidActionError("Document is being indexed, please try again later")
if action == "enable":
if document.enabled:
raise InvalidActionError('Document already enabled.')
document.enabled = True
document.disabled_at = None
document.disabled_by = None
document.updated_at = datetime.utcnow()
db.session.commit()
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)
add_document_to_index_task.delay(document_id)
return {'result': 'success'}, 200
elif action == "disable":
if not document.enabled:
raise InvalidActionError('Document already disabled.')
document.enabled = False
document.disabled_at = datetime.utcnow()
document.disabled_by = current_user.id
document.updated_at = datetime.utcnow()
db.session.commit()
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)
remove_document_from_index_task.delay(document_id)
return {'result': 'success'}, 200
elif action == "archive":
if document.archived:
raise InvalidActionError('Document already archived.')
document.archived = True
document.archived_at = datetime.utcnow()
document.archived_by = current_user.id
document.updated_at = datetime.utcnow()
db.session.commit()
if document.enabled:
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)
remove_document_from_index_task.delay(document_id)
return {'result': 'success'}, 200
else:
raise InvalidActionError()
class DocumentPauseApi(DocumentResource):
def patch(self, dataset_id, document_id):
"""pause document."""
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
document = DocumentService.get_document(dataset.id, document_id)
# 404 if document not found
if document is None:
raise NotFound("Document Not Exists.")
# 403 if document is archived
if DocumentService.check_archived(document):
raise ArchivedDocumentImmutableError()
try:
# pause document
DocumentService.pause_document(document)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError('Cannot pause completed document.')
return {'result': 'success'}, 204
class DocumentRecoverApi(DocumentResource):
def patch(self, dataset_id, document_id):
"""recover document."""
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
document = DocumentService.get_document(dataset.id, document_id)
# 404 if document not found
if document is None:
raise NotFound("Document Not Exists.")
# 403 if document is archived
if DocumentService.check_archived(document):
raise ArchivedDocumentImmutableError()
try:
# pause document
DocumentService.recover_document(document)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError('Document is not in paused status.')
return {'result': 'success'}, 204
api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
api.add_resource(DatasetDocumentListApi,
'/datasets/<uuid:dataset_id>/documents')
api.add_resource(DatasetInitApi,
'/datasets/init')
api.add_resource(DocumentIndexingEstimateApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate')
api.add_resource(DocumentIndexingStatusApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status')
api.add_resource(DocumentDetailApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>')
api.add_resource(DocumentProcessingApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>')
api.add_resource(DocumentDeleteApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>')
api.add_resource(DocumentMetadataApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata')
api.add_resource(DocumentStatusApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>')
api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause')
api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume')

View File

@@ -0,0 +1,203 @@
# -*- coding:utf-8 -*-
from datetime import datetime
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, fields, marshal
from werkzeug.exceptions import NotFound, Forbidden
import services
from controllers.console import api
from controllers.console.datasets.error import InvalidActionError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
from libs.helper import TimestampField
from services.dataset_service import DatasetService, DocumentService
from tasks.add_segment_to_index_task import add_segment_to_index_task
from tasks.remove_segment_from_index_task import remove_segment_from_index_task
segment_fields = {
'id': fields.String,
'position': fields.Integer,
'document_id': fields.String,
'content': fields.String,
'word_count': fields.Integer,
'tokens': fields.Integer,
'keywords': fields.List(fields.String),
'index_node_id': fields.String,
'index_node_hash': fields.String,
'hit_count': fields.Integer,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'status': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'indexing_at': TimestampField,
'completed_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField
}
segment_list_response = {
'data': fields.List(fields.Nested(segment_fields)),
'has_more': fields.Boolean,
'limit': fields.Integer
}
class DatasetDocumentSegmentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
parser = reqparse.RequestParser()
parser.add_argument('last_id', type=str, default=None, location='args')
parser.add_argument('limit', type=int, default=20, location='args')
parser.add_argument('status', type=str,
action='append', default=[], location='args')
parser.add_argument('hit_count_gte', type=int,
default=None, location='args')
parser.add_argument('enabled', type=str, default='all', location='args')
args = parser.parse_args()
last_id = args['last_id']
limit = min(args['limit'], 100)
status_list = args['status']
hit_count_gte = args['hit_count_gte']
query = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
)
if last_id is not None:
last_segment = DocumentSegment.query.get(str(last_id))
if last_segment:
query = query.filter(
DocumentSegment.position > last_segment.position)
else:
return {'data': [], 'has_more': False, 'limit': limit}, 200
if status_list:
query = query.filter(DocumentSegment.status.in_(status_list))
if hit_count_gte is not None:
query = query.filter(DocumentSegment.hit_count >= hit_count_gte)
if args['enabled'].lower() != 'all':
if args['enabled'].lower() == 'true':
query = query.filter(DocumentSegment.enabled == True)
elif args['enabled'].lower() == 'false':
query = query.filter(DocumentSegment.enabled == False)
total = query.count()
segments = query.order_by(DocumentSegment.position).limit(limit + 1).all()
has_more = False
if len(segments) > limit:
has_more = True
segments = segments[:-1]
return {
'data': marshal(segments, segment_fields),
'has_more': has_more,
'limit': limit,
'total': total
}, 200
class DatasetDocumentSegmentApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, dataset_id, segment_id, action):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound('Segment not found.')
document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id)
cache_result = redis_client.get(document_indexing_cache_key)
if cache_result is not None:
raise InvalidActionError("Document is being indexed, please try again later")
indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
raise InvalidActionError("Segment is being indexed, please try again later")
if action == "enable":
if segment.enabled:
raise InvalidActionError("Segment is already enabled.")
segment.enabled = True
segment.disabled_at = None
segment.disabled_by = None
db.session.commit()
# Set cache to prevent indexing the same segment multiple times
redis_client.setex(indexing_cache_key, 600, 1)
add_segment_to_index_task.delay(segment.id)
return {'result': 'success'}, 200
elif action == "disable":
if not segment.enabled:
raise InvalidActionError("Segment is already disabled.")
segment.enabled = False
segment.disabled_at = datetime.utcnow()
segment.disabled_by = current_user.id
db.session.commit()
# Set cache to prevent indexing the same segment multiple times
redis_client.setex(indexing_cache_key, 600, 1)
remove_segment_from_index_task.delay(segment.id)
return {'result': 'success'}, 200
else:
raise InvalidActionError()
api.add_resource(DatasetDocumentSegmentListApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
api.add_resource(DatasetDocumentSegmentApi,
'/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')

View File

@@ -0,0 +1,73 @@
from libs.exception import BaseHTTPException
class NoFileUploadedError(BaseHTTPException):
error_code = 'no_file_uploaded'
description = "No file uploaded."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = 'too_many_files'
description = "Only one file is allowed."
code = 400
class FileTooLargeError(BaseHTTPException):
error_code = 'file_too_large'
description = "File size exceeded. {message}"
code = 413
class UnsupportedFileTypeError(BaseHTTPException):
error_code = 'unsupported_file_type'
description = "File type not allowed."
code = 415
class HighQualityDatasetOnlyError(BaseHTTPException):
error_code = 'high_quality_dataset_only'
description = "High quality dataset only."
code = 400
class DatasetNotInitializedError(BaseHTTPException):
error_code = 'dataset_not_initialized'
description = "Dataset not initialized."
code = 400
class ArchivedDocumentImmutableError(BaseHTTPException):
error_code = 'archived_document_immutable'
description = "Cannot process an archived document."
code = 403
class DatasetNameDuplicateError(BaseHTTPException):
error_code = 'dataset_name_duplicate'
description = "Dataset name already exists."
code = 409
class InvalidActionError(BaseHTTPException):
error_code = 'invalid_action'
description = "Invalid action."
code = 400
class DocumentAlreadyFinishedError(BaseHTTPException):
error_code = 'document_already_finished'
description = "Document already finished."
code = 400
class DocumentIndexingError(BaseHTTPException):
error_code = 'document_indexing'
description = "Document indexing."
code = 400
class InvalidMetadataError(BaseHTTPException):
error_code = 'invalid_metadata'
description = "Invalid metadata."
code = 400

View File

@@ -0,0 +1,147 @@
import datetime
import hashlib
import tempfile
import time
import uuid
from pathlib import Path
from cachetools import TTLCache
from flask import request, current_app
from flask_login import login_required, current_user
from flask_restful import Resource, marshal_with, fields
from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
UnsupportedFileTypeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.index.readers.html_parser import HTMLParser
from core.index.readers.pdf_parser import PDFParser
from extensions.ext_storage import storage
from libs.helper import TimestampField
from extensions.ext_database import db
from models.model import UploadFile
cache = TTLCache(maxsize=None, ttl=30)
FILE_SIZE_LIMIT = 15 * 1024 * 1024 # 15MB
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm']
PREVIEW_WORDS_LIMIT = 3000
class FileApi(Resource):
file_fields = {
'id': fields.String,
'name': fields.String,
'size': fields.Integer,
'extension': fields.String,
'mime_type': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
}
@setup_required
@login_required
@account_initialization_required
@marshal_with(file_fields)
def post(self):
# get file from request
file = request.files['file']
# check file
if 'file' not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
file_content = file.read()
file_size = len(file_content)
if file_size > FILE_SIZE_LIMIT:
message = "({file_size} > {FILE_SIZE_LIMIT})"
raise FileTooLargeError(message)
extension = file.filename.split('.')[-1]
if extension not in ALLOWED_EXTENSIONS:
raise UnsupportedFileTypeError()
# user uuid as file name
file_uuid = str(uuid.uuid4())
file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension
# save file to storage
storage.save(file_key, file_content)
# save file to db
config = current_app.config
upload_file = UploadFile(
tenant_id=current_user.current_tenant_id,
storage_type=config['STORAGE_TYPE'],
key=file_key,
name=file.filename,
size=file_size,
extension=extension,
mime_type=file.mimetype,
created_by=current_user.id,
created_at=datetime.datetime.utcnow(),
used=False,
hash=hashlib.sha3_256(file_content).hexdigest()
)
db.session.add(upload_file)
db.session.commit()
return upload_file, 201
class FilePreviewApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, file_id):
file_id = str(file_id)
key = file_id + request.path
cached_response = cache.get(key)
if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
return cached_response['response']
upload_file = db.session.query(UploadFile) \
.filter(UploadFile.id == file_id) \
.first()
if not upload_file:
raise NotFound("File not found")
# extract text from file
extension = upload_file.extension
if extension not in ALLOWED_EXTENSIONS:
raise UnsupportedFileTypeError()
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
storage.download(upload_file.key, filepath)
if extension == 'pdf':
parser = PDFParser({'upload_file': upload_file})
text = parser.parse_file(Path(filepath))
elif extension in ['html', 'htm']:
# Use BeautifulSoup to extract text
parser = HTMLParser()
text = parser.parse_file(Path(filepath))
else:
# ['txt', 'markdown', 'md']
with open(filepath, "rb") as fp:
data = fp.read()
text = data.decode(encoding='utf-8').strip() if data else ''
text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
return {'content': text}
api.add_resource(FileApi, '/files/upload')
api.add_resource(FilePreviewApi, '/files/<uuid:file_id>/preview')

View File

@@ -0,0 +1,100 @@
import logging
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, marshal, fields
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
import services
from controllers.console import api
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from libs.helper import TimestampField
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
document_fields = {
'id': fields.String,
'data_source_type': fields.String,
'name': fields.String,
'doc_type': fields.String,
}
segment_fields = {
'id': fields.String,
'position': fields.Integer,
'document_id': fields.String,
'content': fields.String,
'word_count': fields.Integer,
'tokens': fields.Integer,
'keywords': fields.List(fields.String),
'index_node_id': fields.String,
'index_node_hash': fields.String,
'hit_count': fields.Integer,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'status': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'indexing_at': TimestampField,
'completed_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField,
'document': fields.Nested(document_fields),
}
hit_testing_record_fields = {
'segment': fields.Nested(segment_fields),
'score': fields.Float,
'tsne_position': fields.Raw
}
class HitTestingApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# only high quality dataset can be used for hit testing
if dataset.indexing_technique != 'high_quality':
raise HighQualityDatasetOnlyError()
parser = reqparse.RequestParser()
parser.add_argument('query', type=str, location='json')
args = parser.parse_args()
query = args['query']
if not query or len(query) > 250:
raise ValueError('Query is required and cannot exceed 250 characters')
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=query,
account=current_user,
limit=10,
)
return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
except services.errors.index.IndexNotInitializedError:
raise DatasetNotInitializedError()
except Exception as e:
logging.exception("Hit testing failed.")
raise InternalServerError(str(e))
api.add_resource(HitTestingApi, '/datasets/<uuid:dataset_id>/hit-testing')