mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-24 18:23:07 +08:00
Feature/mutil embedding model (#908)
Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
This commit is contained in:
@@ -10,13 +10,15 @@ from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from libs.helper import TimestampField
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DocumentSegment, Document
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
dataset_detail_fields = {
|
||||
'id': fields.String,
|
||||
@@ -33,6 +35,9 @@ dataset_detail_fields = {
|
||||
'created_at': TimestampField,
|
||||
'updated_by': fields.String,
|
||||
'updated_at': TimestampField,
|
||||
'embedding_model': fields.String,
|
||||
'embedding_model_provider': fields.String,
|
||||
'embedding_available': fields.Boolean
|
||||
}
|
||||
|
||||
dataset_query_detail_fields = {
|
||||
@@ -74,8 +79,22 @@ class DatasetListApi(Resource):
|
||||
datasets, total = DatasetService.get_datasets(page, limit, provider,
|
||||
current_user.current_tenant_id, current_user)
|
||||
|
||||
# check embedding setting
|
||||
provider_service = ProviderService()
|
||||
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value)
|
||||
# if len(valid_model_list) == 0:
|
||||
# raise ProviderNotInitializeError(
|
||||
# f"No Embedding Model available. Please configure a valid provider "
|
||||
# f"in the Settings -> Model Provider.")
|
||||
model_names = [item['model_name'] for item in valid_model_list]
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
for item in data:
|
||||
if item['embedding_model'] in model_names:
|
||||
item['embedding_available'] = True
|
||||
else:
|
||||
item['embedding_available'] = False
|
||||
response = {
|
||||
'data': marshal(datasets, dataset_detail_fields),
|
||||
'data': data,
|
||||
'has_more': len(datasets) == limit,
|
||||
'limit': limit,
|
||||
'total': total,
|
||||
@@ -99,7 +118,6 @@ class DatasetListApi(Resource):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
@@ -233,6 +251,8 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
@@ -250,11 +270,14 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
|
||||
try:
|
||||
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
|
||||
args['process_rule'], args['doc_form'])
|
||||
args['process_rule'], args['doc_form'],
|
||||
args['doc_language'], args['dataset_id'])
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
elif args['info_list']['data_source_type'] == 'notion_import':
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
@@ -262,11 +285,14 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
try:
|
||||
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
|
||||
args['info_list']['notion_info_list'],
|
||||
args['process_rule'], args['doc_form'])
|
||||
args['process_rule'], args['doc_form'],
|
||||
args['doc_language'], args['dataset_id'])
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
else:
|
||||
raise ValueError('Data source type not support')
|
||||
return response, 200
|
||||
|
||||
@@ -274,6 +274,7 @@ class DatasetDocumentListApi(Resource):
|
||||
parser.add_argument('duplicate', type=bool, nullable=False, location='json')
|
||||
parser.add_argument('original_document_id', type=str, required=False, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
if not dataset.indexing_technique and not args['indexing_technique']:
|
||||
@@ -282,14 +283,19 @@ class DatasetDocumentListApi(Resource):
|
||||
# validate args
|
||||
DocumentService.document_create_args_validate(args)
|
||||
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
|
||||
@@ -328,6 +334,7 @@ class DatasetInitApi(Resource):
|
||||
parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@@ -406,11 +413,13 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
|
||||
try:
|
||||
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
|
||||
data_process_rule_dict)
|
||||
data_process_rule_dict, None, dataset_id)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
return response
|
||||
|
||||
@@ -473,22 +482,27 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
indexing_runner = IndexingRunner()
|
||||
try:
|
||||
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
|
||||
data_process_rule_dict)
|
||||
data_process_rule_dict, None, dataset_id)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
elif dataset.data_source_type:
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
elif dataset.data_source_type == 'notion_import':
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
try:
|
||||
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
|
||||
info_list,
|
||||
data_process_rule_dict)
|
||||
data_process_rule_dict,
|
||||
None, dataset_id)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
else:
|
||||
raise ValueError('Data source type not support')
|
||||
return response
|
||||
@@ -575,7 +589,8 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
|
||||
if document.is_paused:
|
||||
document.indexing_status = 'paused'
|
||||
return marshal(document, self.document_status_fields)
|
||||
|
||||
|
||||
@@ -832,6 +847,22 @@ class DocumentStatusApi(DocumentResource):
|
||||
|
||||
remove_document_from_index_task.delay(document_id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
elif action == "un_archive":
|
||||
if not document.archived:
|
||||
raise InvalidActionError('Document is not archived.')
|
||||
|
||||
document.archived = False
|
||||
document.archived_at = None
|
||||
document.archived_by = None
|
||||
document.updated_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
# Set cache to prevent indexing the same document multiple times
|
||||
redis_client.setex(indexing_cache_key, 600, 1)
|
||||
|
||||
add_document_to_index_task.delay(document_id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
else:
|
||||
raise InvalidActionError()
|
||||
|
||||
@@ -1,15 +1,20 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, reqparse, fields, marshal
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import InvalidActionError
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import DocumentSegment
|
||||
@@ -17,7 +22,9 @@ from models.dataset import DocumentSegment
|
||||
from libs.helper import TimestampField
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
from tasks.enable_segment_to_index_task import enable_segment_to_index_task
|
||||
from tasks.remove_segment_from_index_task import remove_segment_from_index_task
|
||||
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
|
||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||
import pandas as pd
|
||||
|
||||
segment_fields = {
|
||||
'id': fields.String,
|
||||
@@ -152,6 +159,20 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
segment = DocumentSegment.query.filter(
|
||||
DocumentSegment.id == str(segment_id),
|
||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||
@@ -197,7 +218,7 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
# Set cache to prevent indexing the same segment multiple times
|
||||
redis_client.setex(indexing_cache_key, 600, 1)
|
||||
|
||||
remove_segment_from_index_task.delay(segment.id)
|
||||
disable_segment_from_index_task.delay(segment.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
else:
|
||||
@@ -222,6 +243,19 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
@@ -233,7 +267,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.create_segment(args, document)
|
||||
segment = SegmentService.create_segment(args, document, dataset)
|
||||
return {
|
||||
'data': marshal(segment, segment_fields),
|
||||
'doc_form': document.doc_form
|
||||
@@ -245,6 +279,61 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, dataset_id, document_id, segment_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = DocumentSegment.query.filter(
|
||||
DocumentSegment.id == str(segment_id),
|
||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||
).first()
|
||||
if not segment:
|
||||
raise NotFound('Segment not found.')
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.update_segment(args, segment, document, dataset)
|
||||
return {
|
||||
'data': marshal(segment, segment_fields),
|
||||
'doc_form': document.doc_form
|
||||
}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, dataset_id, document_id, segment_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@@ -270,17 +359,88 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.update_segment(args, segment, document)
|
||||
SegmentService.delete_segment(segment, document, dataset)
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id, document_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# get file from request
|
||||
file = request.files['file']
|
||||
# check file
|
||||
if 'file' not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
# check file type
|
||||
if not file.filename.endswith('.csv'):
|
||||
raise ValueError("Invalid file type. Only CSV files are allowed")
|
||||
|
||||
try:
|
||||
# Skip the first row
|
||||
df = pd.read_csv(file)
|
||||
result = []
|
||||
for index, row in df.iterrows():
|
||||
if document.doc_form == 'qa_model':
|
||||
data = {'content': row[0], 'answer': row[1]}
|
||||
else:
|
||||
data = {'content': row[0]}
|
||||
result.append(data)
|
||||
if len(result) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
# async job
|
||||
job_id = str(uuid.uuid4())
|
||||
indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id))
|
||||
# send batch add segments task
|
||||
redis_client.setnx(indexing_cache_key, 'waiting')
|
||||
batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id,
|
||||
current_user.current_tenant_id, current_user.id)
|
||||
except Exception as e:
|
||||
return {'error': str(e)}, 500
|
||||
return {
|
||||
'data': marshal(segment, segment_fields),
|
||||
'doc_form': document.doc_form
|
||||
'job_id': job_id,
|
||||
'job_status': 'waiting'
|
||||
}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, job_id):
|
||||
job_id = str(job_id)
|
||||
indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is None:
|
||||
raise ValueError("The job is not exist.")
|
||||
|
||||
return {
|
||||
'job_id': job_id,
|
||||
'job_status': cache_result.decode()
|
||||
}, 200
|
||||
|
||||
|
||||
@@ -292,3 +452,6 @@ api.add_resource(DatasetDocumentSegmentAddApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
|
||||
api.add_resource(DatasetDocumentSegmentUpdateApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
|
||||
api.add_resource(DatasetDocumentSegmentBatchImportApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import',
|
||||
'/datasets/batch_import_status/<uuid:job_id>')
|
||||
|
||||
@@ -11,7 +11,8 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
|
||||
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
|
||||
LLMBadRequestError
|
||||
from libs.helper import TimestampField
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
@@ -102,6 +103,10 @@ class HitTestingApi(Resource):
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ValueError as e:
|
||||
raise ValueError(str(e))
|
||||
except Exception as e:
|
||||
|
||||
@@ -10,10 +10,10 @@ from models.dataset import Dataset, DocumentSegment
|
||||
|
||||
class DatesetDocumentStore:
|
||||
def __init__(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
user_id: str,
|
||||
document_id: Optional[str] = None,
|
||||
self,
|
||||
dataset: Dataset,
|
||||
user_id: str,
|
||||
document_id: Optional[str] = None,
|
||||
):
|
||||
self._dataset = dataset
|
||||
self._user_id = user_id
|
||||
@@ -59,7 +59,7 @@ class DatesetDocumentStore:
|
||||
return output
|
||||
|
||||
def add_documents(
|
||||
self, docs: Sequence[Document], allow_update: bool = True
|
||||
self, docs: Sequence[Document], allow_update: bool = True
|
||||
) -> None:
|
||||
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
|
||||
DocumentSegment.document_id == self._document_id
|
||||
@@ -69,7 +69,9 @@ class DatesetDocumentStore:
|
||||
max_position = 0
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=self._dataset.tenant_id
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
model_provider_name=self._dataset.embedding_model_provider,
|
||||
model_name=self._dataset.embedding_model
|
||||
)
|
||||
|
||||
for doc in docs:
|
||||
@@ -123,7 +125,7 @@ class DatesetDocumentStore:
|
||||
return result is not None
|
||||
|
||||
def get_document(
|
||||
self, doc_id: str, raise_error: bool = True
|
||||
self, doc_id: str, raise_error: bool = True
|
||||
) -> Optional[Document]:
|
||||
document_segment = self.get_document_segment(doc_id)
|
||||
|
||||
|
||||
@@ -179,8 +179,8 @@ class LLMGenerator:
|
||||
return rule_config
|
||||
|
||||
@classmethod
|
||||
def generate_qa_document(cls, tenant_id: str, query):
|
||||
prompt = GENERATOR_QA_PROMPT
|
||||
def generate_qa_document(cls, tenant_id: str, query, document_language: str):
|
||||
prompt = GENERATOR_QA_PROMPT.format(language=document_language)
|
||||
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@@ -15,7 +15,9 @@ class IndexBuilder:
|
||||
return None
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
@@ -67,14 +67,6 @@ class IndexingRunner:
|
||||
dataset_document=dataset_document,
|
||||
processing_rule=processing_rule
|
||||
)
|
||||
# new_documents = []
|
||||
# for document in documents:
|
||||
# response = LLMGenerator.generate_qa_document(dataset.tenant_id, document.page_content)
|
||||
# document_qa_list = self.format_split_text(response)
|
||||
# for result in document_qa_list:
|
||||
# document = Document(page_content=result['question'], metadata={'source': result['answer']})
|
||||
# new_documents.append(document)
|
||||
# build index
|
||||
self._build_index(
|
||||
dataset=dataset,
|
||||
dataset_document=dataset_document,
|
||||
@@ -225,14 +217,25 @@ class IndexingRunner:
|
||||
db.session.commit()
|
||||
|
||||
def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
|
||||
doc_form: str = None) -> dict:
|
||||
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
|
||||
"""
|
||||
Estimate the indexing for the document.
|
||||
"""
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
if dataset_id:
|
||||
dataset = Dataset.query.filter_by(
|
||||
id=dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise ValueError('Dataset not found.')
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
else:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
tokens = 0
|
||||
preview_texts = []
|
||||
total_segments = 0
|
||||
@@ -263,14 +266,13 @@ class IndexingRunner:
|
||||
|
||||
tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
|
||||
|
||||
text_generation_model = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
if doc_form and doc_form == 'qa_model':
|
||||
text_generation_model = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
if len(preview_texts) > 0:
|
||||
# qa model document
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language)
|
||||
document_qa_list = self.format_split_text(response)
|
||||
return {
|
||||
"total_segments": total_segments * 20,
|
||||
@@ -289,13 +291,26 @@ class IndexingRunner:
|
||||
"preview": preview_texts
|
||||
}
|
||||
|
||||
def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
|
||||
def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
|
||||
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
|
||||
"""
|
||||
Estimate the indexing for the document.
|
||||
"""
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
if dataset_id:
|
||||
dataset = Dataset.query.filter_by(
|
||||
id=dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise ValueError('Dataset not found.')
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
else:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
# load data from notion
|
||||
tokens = 0
|
||||
@@ -344,14 +359,13 @@ class IndexingRunner:
|
||||
|
||||
tokens += embedding_model.get_num_tokens(document.page_content)
|
||||
|
||||
text_generation_model = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
if doc_form and doc_form == 'qa_model':
|
||||
text_generation_model = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
if len(preview_texts) > 0:
|
||||
# qa model document
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language)
|
||||
document_qa_list = self.format_split_text(response)
|
||||
return {
|
||||
"total_segments": total_segments * 20,
|
||||
@@ -458,7 +472,8 @@ class IndexingRunner:
|
||||
splitter=splitter,
|
||||
processing_rule=processing_rule,
|
||||
tenant_id=dataset.tenant_id,
|
||||
document_form=dataset_document.doc_form
|
||||
document_form=dataset_document.doc_form,
|
||||
document_language=dataset_document.doc_language
|
||||
)
|
||||
|
||||
# save node to document segment
|
||||
@@ -494,7 +509,8 @@ class IndexingRunner:
|
||||
return documents
|
||||
|
||||
def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
|
||||
processing_rule: DatasetProcessRule, tenant_id: str, document_form: str) -> List[Document]:
|
||||
processing_rule: DatasetProcessRule, tenant_id: str,
|
||||
document_form: str, document_language: str) -> List[Document]:
|
||||
"""
|
||||
Split the text documents into nodes.
|
||||
"""
|
||||
@@ -523,8 +539,9 @@ class IndexingRunner:
|
||||
sub_documents = all_documents[i:i + 10]
|
||||
for doc in sub_documents:
|
||||
document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
|
||||
'flask_app': current_app._get_current_object(), 'tenant_id': tenant_id, 'document_node': doc,
|
||||
'all_qa_documents': all_qa_documents})
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents,
|
||||
'document_language': document_language})
|
||||
threads.append(document_format_thread)
|
||||
document_format_thread.start()
|
||||
for thread in threads:
|
||||
@@ -532,14 +549,14 @@ class IndexingRunner:
|
||||
return all_qa_documents
|
||||
return all_documents
|
||||
|
||||
def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents):
|
||||
def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
|
||||
format_documents = []
|
||||
if document_node.page_content is None or not document_node.page_content.strip():
|
||||
return
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
# qa model document
|
||||
response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content)
|
||||
response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
|
||||
document_qa_list = self.format_split_text(response)
|
||||
qa_documents = []
|
||||
for result in document_qa_list:
|
||||
@@ -641,7 +658,9 @@ class IndexingRunner:
|
||||
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
# chunk nodes by chunk size
|
||||
@@ -722,6 +741,32 @@ class IndexingRunner:
|
||||
DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
|
||||
db.session.commit()
|
||||
|
||||
def batch_add_segments(self, segments: List[DocumentSegment], dataset: Dataset):
|
||||
"""
|
||||
Batch add segments index processing
|
||||
"""
|
||||
documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
)
|
||||
documents.append(document)
|
||||
# save vector index
|
||||
index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
if index:
|
||||
index.add_texts(documents, duplicate_check=True)
|
||||
|
||||
# save keyword index
|
||||
index = IndexBuilder.get_index(dataset, 'economy')
|
||||
if index:
|
||||
index.add_texts(documents)
|
||||
|
||||
|
||||
class DocumentIsPausedException(Exception):
|
||||
pass
|
||||
|
||||
@@ -44,13 +44,13 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
|
||||
)
|
||||
|
||||
GENERATOR_QA_PROMPT = (
|
||||
"Please respond according to the language of the user's input text. If the text is in language [A], you must also reply in language [A].\n"
|
||||
'The user will send a long text. Please think step by step.'
|
||||
'Step 1: Understand and summarize the main content of this text.\n'
|
||||
'Step 2: What key information or concepts are mentioned in this text?\n'
|
||||
'Step 3: Decompose or combine multiple pieces of information and concepts.\n'
|
||||
'Step 4: Generate 20 questions and answers based on these key information and concepts.'
|
||||
'The questions should be clear and detailed, and the answers should be detailed and complete.\n'
|
||||
"Answer in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
|
||||
"Answer must be the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
|
||||
)
|
||||
|
||||
RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
|
||||
|
||||
@@ -9,6 +9,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
@@ -70,10 +71,17 @@ class DatasetRetrieverTool(BaseTool):
|
||||
documents = kw_table_index.search(query, search_kwargs={'k': self.k})
|
||||
return str("\n".join([document.page_content for document in documents]))
|
||||
else:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id
|
||||
)
|
||||
|
||||
try:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
return ''
|
||||
except ProviderTokenNotInitError:
|
||||
return ''
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
vector_index = VectorIndex(
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
"""add_qa_document_language
|
||||
|
||||
Revision ID: 2c8af9671032
|
||||
Revises: 8d2d099ceb74
|
||||
Create Date: 2023-08-01 18:57:27.294973
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '2c8af9671032'
|
||||
down_revision = '5022897aaceb'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('documents', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('doc_language', sa.String(length=255), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('documents', schema=None) as batch_op:
|
||||
batch_op.drop_column('doc_language')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,34 @@
|
||||
"""add_dataset_model_name
|
||||
|
||||
Revision ID: e8883b0148c9
|
||||
Revises: 2c8af9671032
|
||||
Create Date: 2023-08-15 20:54:58.936787
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'e8883b0148c9'
|
||||
down_revision = '2c8af9671032'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('embedding_model', sa.String(length=255), server_default=sa.text("'text-embedding-ada-002'::character varying"), nullable=False))
|
||||
batch_op.add_column(sa.Column('embedding_model_provider', sa.String(length=255), server_default=sa.text("'openai'::character varying"), nullable=False))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.drop_column('embedding_model_provider')
|
||||
batch_op.drop_column('embedding_model')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -36,6 +36,10 @@ class Dataset(db.Model):
|
||||
updated_by = db.Column(UUID, nullable=True)
|
||||
updated_at = db.Column(db.DateTime, nullable=False,
|
||||
server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
embedding_model = db.Column(db.String(
|
||||
255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying"))
|
||||
embedding_model_provider = db.Column(db.String(
|
||||
255), nullable=False, server_default=db.text("'openai'::character varying"))
|
||||
|
||||
@property
|
||||
def dataset_keyword_table(self):
|
||||
@@ -209,6 +213,7 @@ class Document(db.Model):
|
||||
doc_metadata = db.Column(db.JSON, nullable=True)
|
||||
doc_form = db.Column(db.String(
|
||||
255), nullable=False, server_default=db.text("'text_model'::character varying"))
|
||||
doc_language = db.Column(db.String(255), nullable=True)
|
||||
|
||||
DATA_SOURCES = ['upload_file', 'notion_import']
|
||||
|
||||
|
||||
@@ -47,4 +47,5 @@ websocket-client~=1.6.1
|
||||
dashscope~=1.5.0
|
||||
huggingface_hub~=0.16.4
|
||||
transformers~=4.31.0
|
||||
stripe~=5.5.0
|
||||
stripe~=5.5.0
|
||||
pandas==1.5.3
|
||||
@@ -9,6 +9,7 @@ from typing import Optional, List
|
||||
from flask import current_app
|
||||
from sqlalchemy import func
|
||||
|
||||
from core.index.index import IndexBuilder
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_redis import redis_client
|
||||
from flask_login import current_user
|
||||
@@ -25,14 +26,16 @@ from services.errors.account import NoPermissionError
|
||||
from services.errors.dataset import DatasetNameDuplicateError
|
||||
from services.errors.document import DocumentIndexingError
|
||||
from services.errors.file import FileNotExistsError
|
||||
from services.vector_service import VectorService
|
||||
from tasks.clean_notion_document_task import clean_notion_document_task
|
||||
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
|
||||
from tasks.document_indexing_task import document_indexing_task
|
||||
from tasks.document_indexing_update_task import document_indexing_update_task
|
||||
from tasks.create_segment_to_index_task import create_segment_to_index_task
|
||||
from tasks.update_segment_index_task import update_segment_index_task
|
||||
from tasks.update_segment_keyword_index_task\
|
||||
import update_segment_keyword_index_task
|
||||
from tasks.recover_document_indexing_task import recover_document_indexing_task
|
||||
from tasks.update_segment_keyword_index_task import update_segment_keyword_index_task
|
||||
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
|
||||
|
||||
|
||||
class DatasetService:
|
||||
@@ -88,12 +91,16 @@ class DatasetService:
|
||||
if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
|
||||
raise DatasetNameDuplicateError(
|
||||
f'Dataset with name {name} already exists.')
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
dataset = Dataset(name=name, indexing_technique=indexing_technique)
|
||||
# dataset = Dataset(name=name, provider=provider, config=config)
|
||||
dataset.created_by = account.id
|
||||
dataset.updated_by = account.id
|
||||
dataset.tenant_id = tenant_id
|
||||
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
|
||||
dataset.embedding_model = embedding_model.name
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
return dataset
|
||||
@@ -372,7 +379,7 @@ class DocumentService:
|
||||
indexing_cache_key = 'document_{}_is_paused'.format(document.id)
|
||||
redis_client.delete(indexing_cache_key)
|
||||
# trigger async task
|
||||
document_indexing_task.delay(document.dataset_id, document.id)
|
||||
recover_document_indexing_task.delay(document.dataset_id, document.id)
|
||||
|
||||
@staticmethod
|
||||
def get_documents_position(dataset_id):
|
||||
@@ -450,6 +457,7 @@ class DocumentService:
|
||||
document = DocumentService.save_document(dataset, dataset_process_rule.id,
|
||||
document_data["data_source"]["type"],
|
||||
document_data["doc_form"],
|
||||
document_data["doc_language"],
|
||||
data_source_info, created_from, position,
|
||||
account, file_name, batch)
|
||||
db.session.add(document)
|
||||
@@ -495,20 +503,11 @@ class DocumentService:
|
||||
document = DocumentService.save_document(dataset, dataset_process_rule.id,
|
||||
document_data["data_source"]["type"],
|
||||
document_data["doc_form"],
|
||||
document_data["doc_language"],
|
||||
data_source_info, created_from, position,
|
||||
account, page['page_name'], batch)
|
||||
# if page['type'] == 'database':
|
||||
# document.splitting_completed_at = datetime.datetime.utcnow()
|
||||
# document.cleaning_completed_at = datetime.datetime.utcnow()
|
||||
# document.parsing_completed_at = datetime.datetime.utcnow()
|
||||
# document.completed_at = datetime.datetime.utcnow()
|
||||
# document.indexing_status = 'completed'
|
||||
# document.word_count = 0
|
||||
# document.tokens = 0
|
||||
# document.indexing_latency = 0
|
||||
db.session.add(document)
|
||||
db.session.flush()
|
||||
# if page['type'] != 'database':
|
||||
document_ids.append(document.id)
|
||||
documents.append(document)
|
||||
position += 1
|
||||
@@ -520,15 +519,15 @@ class DocumentService:
|
||||
db.session.commit()
|
||||
|
||||
# trigger async task
|
||||
#document_index_created.send(dataset.id, document_ids=document_ids)
|
||||
document_indexing_task.delay(dataset.id, document_ids)
|
||||
|
||||
return documents, batch
|
||||
|
||||
@staticmethod
|
||||
def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
|
||||
data_source_info: dict, created_from: str, position: int, account: Account, name: str,
|
||||
batch: str):
|
||||
document_language: str, data_source_info: dict, created_from: str, position: int,
|
||||
account: Account,
|
||||
name: str, batch: str):
|
||||
document = Document(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
@@ -540,7 +539,8 @@ class DocumentService:
|
||||
name=name,
|
||||
created_from=created_from,
|
||||
created_by=account.id,
|
||||
doc_form=document_form
|
||||
doc_form=document_form,
|
||||
doc_language=document_language
|
||||
)
|
||||
return document
|
||||
|
||||
@@ -654,13 +654,18 @@ class DocumentService:
|
||||
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
|
||||
if documents_count > tenant_document_count:
|
||||
raise ValueError(f"over document limit {tenant_document_count}.")
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
# save dataset
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
name='',
|
||||
data_source_type=document_data["data_source"]["type"],
|
||||
indexing_technique=document_data["indexing_technique"],
|
||||
created_by=account.id
|
||||
created_by=account.id,
|
||||
embedding_model=embedding_model.name,
|
||||
embedding_model_provider=embedding_model.model_provider.provider_name
|
||||
)
|
||||
|
||||
db.session.add(dataset)
|
||||
@@ -870,13 +875,15 @@ class SegmentService:
|
||||
raise ValueError("Answer is required")
|
||||
|
||||
@classmethod
|
||||
def create_segment(cls, args: dict, document: Document):
|
||||
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
|
||||
content = args['content']
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=document.tenant_id
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
# calc embedding use tokens
|
||||
@@ -894,6 +901,9 @@ class SegmentService:
|
||||
content=content,
|
||||
word_count=len(content),
|
||||
tokens=tokens,
|
||||
status='completed',
|
||||
indexing_at=datetime.datetime.utcnow(),
|
||||
completed_at=datetime.datetime.utcnow(),
|
||||
created_by=current_user.id
|
||||
)
|
||||
if document.doc_form == 'qa_model':
|
||||
@@ -901,49 +911,88 @@ class SegmentService:
|
||||
|
||||
db.session.add(segment_document)
|
||||
db.session.commit()
|
||||
indexing_cache_key = 'segment_{}_indexing'.format(segment_document.id)
|
||||
redis_client.setex(indexing_cache_key, 600, 1)
|
||||
create_segment_to_index_task.delay(segment_document.id, args['keywords'])
|
||||
return segment_document
|
||||
|
||||
# save vector index
|
||||
try:
|
||||
VectorService.create_segment_vector(args['keywords'], segment_document, dataset)
|
||||
except Exception as e:
|
||||
logging.exception("create segment index failed")
|
||||
segment_document.enabled = False
|
||||
segment_document.disabled_at = datetime.datetime.utcnow()
|
||||
segment_document.status = 'error'
|
||||
segment_document.error = str(e)
|
||||
db.session.commit()
|
||||
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first()
|
||||
return segment
|
||||
|
||||
@classmethod
|
||||
def update_segment(cls, args: dict, segment: DocumentSegment, document: Document):
|
||||
def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset):
|
||||
indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
raise ValueError("Segment is indexing, please try again later")
|
||||
content = args['content']
|
||||
if segment.content == content:
|
||||
if document.doc_form == 'qa_model':
|
||||
segment.answer = args['answer']
|
||||
if args['keywords']:
|
||||
segment.keywords = args['keywords']
|
||||
db.session.add(segment)
|
||||
db.session.commit()
|
||||
# update segment index task
|
||||
redis_client.setex(indexing_cache_key, 600, 1)
|
||||
update_segment_keyword_index_task.delay(segment.id)
|
||||
else:
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
try:
|
||||
content = args['content']
|
||||
if segment.content == content:
|
||||
if document.doc_form == 'qa_model':
|
||||
segment.answer = args['answer']
|
||||
if args['keywords']:
|
||||
segment.keywords = args['keywords']
|
||||
db.session.add(segment)
|
||||
db.session.commit()
|
||||
# update segment index task
|
||||
if args['keywords']:
|
||||
kw_index = IndexBuilder.get_index(dataset, 'economy')
|
||||
# delete from keyword index
|
||||
kw_index.delete_by_ids([segment.index_node_id])
|
||||
# save keyword index
|
||||
kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
|
||||
else:
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=document.tenant_id
|
||||
)
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(content)
|
||||
segment.content = content
|
||||
segment.index_node_hash = segment_hash
|
||||
segment.word_count = len(content)
|
||||
segment.tokens = tokens
|
||||
segment.status = 'updating'
|
||||
segment.updated_by = current_user.id
|
||||
segment.updated_at = datetime.datetime.utcnow()
|
||||
if document.doc_form == 'qa_model':
|
||||
segment.answer = args['answer']
|
||||
db.session.add(segment)
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(content)
|
||||
segment.content = content
|
||||
segment.index_node_hash = segment_hash
|
||||
segment.word_count = len(content)
|
||||
segment.tokens = tokens
|
||||
segment.status = 'completed'
|
||||
segment.indexing_at = datetime.datetime.utcnow()
|
||||
segment.completed_at = datetime.datetime.utcnow()
|
||||
segment.updated_by = current_user.id
|
||||
segment.updated_at = datetime.datetime.utcnow()
|
||||
if document.doc_form == 'qa_model':
|
||||
segment.answer = args['answer']
|
||||
db.session.add(segment)
|
||||
db.session.commit()
|
||||
# update segment vector index
|
||||
VectorService.create_segment_vector(args['keywords'], segment, dataset)
|
||||
except Exception as e:
|
||||
logging.exception("update segment index failed")
|
||||
segment.enabled = False
|
||||
segment.disabled_at = datetime.datetime.utcnow()
|
||||
segment.status = 'error'
|
||||
segment.error = str(e)
|
||||
db.session.commit()
|
||||
# update segment index task
|
||||
redis_client.setex(indexing_cache_key, 600, 1)
|
||||
update_segment_index_task.delay(segment.id, args['keywords'])
|
||||
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first()
|
||||
return segment
|
||||
|
||||
@classmethod
|
||||
def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset):
|
||||
indexing_cache_key = 'segment_{}_delete_indexing'.format(segment.id)
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
raise ValueError("Segment is deleting.")
|
||||
# send delete segment index task
|
||||
redis_client.setex(indexing_cache_key, 600, 1)
|
||||
# enabled segment need to delete index
|
||||
if segment.enabled:
|
||||
delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
|
||||
db.session.delete(segment)
|
||||
db.session.commit()
|
||||
|
||||
@@ -29,7 +29,9 @@ class HitTestingService:
|
||||
}
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
69
api/services/vector_service.py
Normal file
69
api/services/vector_service.py
Normal file
@@ -0,0 +1,69 @@
|
||||
|
||||
from typing import Optional, List
|
||||
|
||||
from langchain.schema import Document
|
||||
|
||||
from core.index.index import IndexBuilder
|
||||
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
|
||||
|
||||
class VectorService:
|
||||
|
||||
@classmethod
|
||||
def create_segment_vector(cls, keywords: Optional[List[str]], segment: DocumentSegment, dataset: Dataset):
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
)
|
||||
|
||||
# save vector index
|
||||
index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
if index:
|
||||
index.add_texts([document], duplicate_check=True)
|
||||
|
||||
# save keyword index
|
||||
index = IndexBuilder.get_index(dataset, 'economy')
|
||||
if index:
|
||||
if keywords and len(keywords) > 0:
|
||||
index.create_segment_keywords(segment.index_node_id, keywords)
|
||||
else:
|
||||
index.add_texts([document])
|
||||
|
||||
@classmethod
|
||||
def update_segment_vector(cls, keywords: Optional[List[str]], segment: DocumentSegment, dataset: Dataset):
|
||||
# update segment index task
|
||||
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
kw_index = IndexBuilder.get_index(dataset, 'economy')
|
||||
# delete from vector index
|
||||
if vector_index:
|
||||
vector_index.delete_by_ids([segment.index_node_id])
|
||||
|
||||
# delete from keyword index
|
||||
kw_index.delete_by_ids([segment.index_node_id])
|
||||
|
||||
# add new index
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
)
|
||||
|
||||
# save vector index
|
||||
if vector_index:
|
||||
vector_index.add_texts([document], duplicate_check=True)
|
||||
|
||||
# save keyword index
|
||||
if keywords and len(keywords) > 0:
|
||||
kw_index.create_segment_keywords(segment.index_node_id, keywords)
|
||||
else:
|
||||
kw_index.add_texts([document])
|
||||
95
api/tasks/batch_create_segment_to_index_task.py
Normal file
95
api/tasks/batch_create_segment_to_index_task.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional, List
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import func
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.index.index import IndexBuilder
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs import helper
|
||||
from models.dataset import DocumentSegment, Dataset, Document
|
||||
|
||||
|
||||
@shared_task(queue='dataset')
|
||||
def batch_create_segment_to_index_task(job_id: str, content: List, dataset_id: str, document_id: str,
|
||||
tenant_id: str, user_id: str):
|
||||
"""
|
||||
Async batch create segment to index
|
||||
:param job_id:
|
||||
:param content:
|
||||
:param dataset_id:
|
||||
:param document_id:
|
||||
:param tenant_id:
|
||||
:param user_id:
|
||||
|
||||
Usage: batch_create_segment_to_index_task.delay(segment_id)
|
||||
"""
|
||||
logging.info(click.style('Start batch create segment jobId: {}'.format(job_id), fg='green'))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
|
||||
|
||||
try:
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise ValueError('Dataset not exist.')
|
||||
|
||||
dataset_document = db.session.query(Document).filter(Document.id == document_id).first()
|
||||
if not dataset_document:
|
||||
raise ValueError('Document not exist.')
|
||||
|
||||
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
|
||||
raise ValueError('Document is not available.')
|
||||
document_segments = []
|
||||
for segment in content:
|
||||
content = segment['content']
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(content)
|
||||
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
|
||||
DocumentSegment.document_id == dataset_document.id
|
||||
).scalar()
|
||||
segment_document = DocumentSegment(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
document_id=document_id,
|
||||
index_node_id=doc_id,
|
||||
index_node_hash=segment_hash,
|
||||
position=max_position + 1 if max_position else 1,
|
||||
content=content,
|
||||
word_count=len(content),
|
||||
tokens=tokens,
|
||||
created_by=user_id,
|
||||
indexing_at=datetime.datetime.utcnow(),
|
||||
status='completed',
|
||||
completed_at=datetime.datetime.utcnow()
|
||||
)
|
||||
if dataset_document.doc_form == 'qa_model':
|
||||
segment_document.answer = segment['answer']
|
||||
db.session.add(segment_document)
|
||||
document_segments.append(segment_document)
|
||||
# add index to db
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.batch_add_segments(document_segments, dataset)
|
||||
db.session.commit()
|
||||
redis_client.setex(indexing_cache_key, 600, 'completed')
|
||||
end_at = time.perf_counter()
|
||||
logging.info(click.style('Segment batch created job: {} latency: {}'.format(job_id, end_at - start_at), fg='green'))
|
||||
except Exception as e:
|
||||
logging.exception("Segments batch created index failed:{}".format(str(e)))
|
||||
redis_client.setex(indexing_cache_key, 600, 'error')
|
||||
58
api/tasks/delete_segment_from_index_task.py
Normal file
58
api/tasks/delete_segment_from_index_task.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.index.index import IndexBuilder
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import DocumentSegment, Dataset, Document
|
||||
|
||||
|
||||
@shared_task(queue='dataset')
|
||||
def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_id: str, document_id: str):
|
||||
"""
|
||||
Async Remove segment from index
|
||||
:param segment_id:
|
||||
:param index_node_id:
|
||||
:param dataset_id:
|
||||
:param document_id:
|
||||
|
||||
Usage: delete_segment_from_index_task.delay(segment_id)
|
||||
"""
|
||||
logging.info(click.style('Start delete segment from index: {}'.format(segment_id), fg='green'))
|
||||
start_at = time.perf_counter()
|
||||
indexing_cache_key = 'segment_{}_delete_indexing'.format(segment_id)
|
||||
try:
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
logging.info(click.style('Segment {} has no dataset, pass.'.format(segment_id), fg='cyan'))
|
||||
return
|
||||
|
||||
dataset_document = db.session.query(Document).filter(Document.id == document_id).first()
|
||||
if not dataset_document:
|
||||
logging.info(click.style('Segment {} has no document, pass.'.format(segment_id), fg='cyan'))
|
||||
return
|
||||
|
||||
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
|
||||
logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment_id), fg='cyan'))
|
||||
return
|
||||
|
||||
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
kw_index = IndexBuilder.get_index(dataset, 'economy')
|
||||
|
||||
# delete from vector index
|
||||
if vector_index:
|
||||
vector_index.delete_by_ids([index_node_id])
|
||||
|
||||
# delete from keyword index
|
||||
kw_index.delete_by_ids([index_node_id])
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logging.info(click.style('Segment deleted from index: {} latency: {}'.format(segment_id, end_at - start_at), fg='green'))
|
||||
except Exception:
|
||||
logging.exception("delete segment from index failed")
|
||||
finally:
|
||||
redis_client.delete(indexing_cache_key)
|
||||
@@ -12,14 +12,14 @@ from models.dataset import DocumentSegment
|
||||
|
||||
|
||||
@shared_task(queue='dataset')
|
||||
def remove_segment_from_index_task(segment_id: str):
|
||||
def disable_segment_from_index_task(segment_id: str):
|
||||
"""
|
||||
Async Remove segment from index
|
||||
Async disable segment from index
|
||||
:param segment_id:
|
||||
|
||||
Usage: remove_segment_from_index.delay(segment_id)
|
||||
Usage: disable_segment_from_index_task.delay(segment_id)
|
||||
"""
|
||||
logging.info(click.style('Start remove segment from index: {}'.format(segment_id), fg='green'))
|
||||
logging.info(click.style('Start disable segment from index: {}'.format(segment_id), fg='green'))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
|
||||
@@ -52,17 +52,6 @@ def update_segment_keyword_index_task(segment_id: str):
|
||||
# delete from keyword index
|
||||
kw_index.delete_by_ids([segment.index_node_id])
|
||||
|
||||
# add new index
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
)
|
||||
|
||||
# save keyword index
|
||||
index = IndexBuilder.get_index(dataset, 'economy')
|
||||
if index:
|
||||
|
||||
Reference in New Issue
Block a user