Knowledge optimization (#3755)

Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: JzoNg <jzongcode@gmail.com>
This commit is contained in:
Jyong
2024-04-24 15:02:29 +08:00
committed by GitHub
parent 3cd8e6f5c6
commit f257f2c396
75 changed files with 2756 additions and 266 deletions

View File

@@ -21,11 +21,12 @@ from extensions.ext_database import db
from models.account import Account
from models.model import App, AppMode, AppModelConfig
from models.tools import ApiToolProvider
from services.tag_service import TagService
from services.workflow_service import WorkflowService
class AppService:
def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination:
def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None:
"""
Get app list with pagination
:param tenant_id: tenant id
@@ -49,6 +50,14 @@ class AppService:
if 'name' in args and args['name']:
name = args['name'][:30]
filters.append(App.name.ilike(f'%{name}%'))
if 'tag_ids' in args and args['tag_ids']:
target_ids = TagService.get_target_ids_by_tag_ids('app',
tenant_id,
args['tag_ids'])
if target_ids:
filters.append(App.id.in_(target_ids))
else:
return None
app_models = db.paginate(
db.select(App).where(*filters).order_by(App.created_at.desc()),

View File

@@ -38,28 +38,39 @@ from services.errors.dataset import DatasetNameDuplicateError
from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from services.feature_service import FeatureModel, FeatureService
from services.tag_service import TagService
from services.vector_service import VectorService
from tasks.clean_notion_document_task import clean_notion_document_task
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
from tasks.document_indexing_task import document_indexing_task
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
from tasks.recover_document_indexing_task import recover_document_indexing_task
from tasks.retry_document_indexing_task import retry_document_indexing_task
class DatasetService:
@staticmethod
def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None):
def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None, search=None, tag_ids=None):
if user:
permission_filter = db.or_(Dataset.created_by == user.id,
Dataset.permission == 'all_team_members')
else:
permission_filter = Dataset.permission == 'all_team_members'
datasets = Dataset.query.filter(
query = Dataset.query.filter(
db.and_(Dataset.provider == provider, Dataset.tenant_id == tenant_id, permission_filter)) \
.order_by(Dataset.created_at.desc()) \
.paginate(
.order_by(Dataset.created_at.desc())
if search:
query = query.filter(db.and_(Dataset.name.ilike(f'%{search}%')))
if tag_ids:
target_ids = TagService.get_target_ids_by_tag_ids('knowledge', tenant_id, tag_ids)
if target_ids:
query = query.filter(db.and_(Dataset.id.in_(target_ids)))
else:
return [], 0
datasets = query.paginate(
page=page,
per_page=per_page,
max_per_page=100,
@@ -165,9 +176,36 @@ class DatasetService:
# get embedding model setting
try:
model_manager = ModelManager()
embedding_model = model_manager.get_default_model_instance(
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.TEXT_EMBEDDING
provider=data['embedding_model_provider'],
model_type=ModelType.TEXT_EMBEDDING,
model=data['embedding_model']
)
filtered_data['embedding_model'] = embedding_model.model
filtered_data['embedding_model_provider'] = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider,
embedding_model.model
)
filtered_data['collection_binding_id'] = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
else:
if data['embedding_model_provider'] != dataset.embedding_model_provider or \
data['embedding_model'] != dataset.embedding_model:
action = 'update'
try:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=data['embedding_model_provider'],
model_type=ModelType.TEXT_EMBEDDING,
model=data['embedding_model']
)
filtered_data['embedding_model'] = embedding_model.model
filtered_data['embedding_model_provider'] = embedding_model.provider
@@ -376,6 +414,15 @@ class DocumentService:
return documents
@staticmethod
def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]:
documents = db.session.query(Document).filter(
Document.dataset_id == dataset_id,
Document.indexing_status == 'error' or Document.indexing_status == 'paused'
).all()
return documents
@staticmethod
def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
documents = db.session.query(Document).filter(
@@ -440,6 +487,20 @@ class DocumentService:
# trigger async task
recover_document_indexing_task.delay(document.dataset_id, document.id)
@staticmethod
def retry_document(dataset_id: str, documents: list[Document]):
for document in documents:
# retry document indexing
document.indexing_status = 'waiting'
db.session.add(document)
db.session.commit()
# add retry flag
retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
redis_client.setex(retry_indexing_cache_key, 600, 1)
# trigger async task
document_ids = [document.id for document in documents]
retry_document_indexing_task.delay(dataset_id, document_ids)
@staticmethod
def get_documents_position(dataset_id):
document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
@@ -537,6 +598,7 @@ class DocumentService:
db.session.commit()
position = DocumentService.get_documents_position(dataset.id)
document_ids = []
duplicate_document_ids = []
if document_data["data_source"]["type"] == "upload_file":
upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
for file_id in upload_file_list:
@@ -553,6 +615,28 @@ class DocumentService:
data_source_info = {
"upload_file_id": file_id,
}
# check duplicate
if document_data.get('duplicate', False):
document = Document.query.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type='upload_file',
enabled=True,
name=file_name
).first()
if document:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = datetime.datetime.utcnow()
document.created_from = created_from
document.doc_form = document_data['doc_form']
document.doc_language = document_data['doc_language']
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = 'waiting'
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
document = DocumentService.build_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
@@ -618,7 +702,10 @@ class DocumentService:
db.session.commit()
# trigger async task
document_indexing_task.delay(dataset.id, document_ids)
if document_ids:
document_indexing_task.delay(dataset.id, document_ids)
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
return documents, batch
@@ -626,7 +713,8 @@ class DocumentService:
def check_documents_upload_quota(count: int, features: FeatureModel):
can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size
if count > can_upload_size:
raise ValueError(f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.')
raise ValueError(
f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.')
@staticmethod
def build_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
@@ -752,7 +840,6 @@ class DocumentService:
db.session.commit()
# trigger async task
document_indexing_update_task.delay(document.dataset_id, document.id)
return document
@staticmethod

161
api/services/tag_service.py Normal file
View File

@@ -0,0 +1,161 @@
import uuid
from flask_login import current_user
from sqlalchemy import func
from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import App, Tag, TagBinding
class TagService:
@staticmethod
def get_tags(tag_type: str, current_tenant_id: str, keyword: str = None) -> list:
query = db.session.query(
Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label('binding_count')
).outerjoin(
TagBinding, Tag.id == TagBinding.tag_id
).filter(
Tag.type == tag_type,
Tag.tenant_id == current_tenant_id
)
if keyword:
query = query.filter(db.and_(Tag.name.ilike(f'%{keyword}%')))
query = query.group_by(
Tag.id
)
results = query.order_by(Tag.created_at.desc()).all()
return results
@staticmethod
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
tags = db.session.query(Tag).filter(
Tag.id.in_(tag_ids),
Tag.tenant_id == current_tenant_id,
Tag.type == tag_type
).all()
if not tags:
return []
tag_ids = [tag.id for tag in tags]
tag_bindings = db.session.query(
TagBinding.target_id
).filter(
TagBinding.tag_id.in_(tag_ids),
TagBinding.tenant_id == current_tenant_id
).all()
if not tag_bindings:
return []
results = [tag_binding.target_id for tag_binding in tag_bindings]
return results
@staticmethod
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
tags = db.session.query(Tag).join(
TagBinding,
Tag.id == TagBinding.tag_id
).filter(
TagBinding.target_id == target_id,
TagBinding.tenant_id == current_tenant_id,
Tag.tenant_id == current_tenant_id,
Tag.type == tag_type
).all()
return tags if tags else []
@staticmethod
def save_tags(args: dict) -> Tag:
tag = Tag(
id=str(uuid.uuid4()),
name=args['name'],
type=args['type'],
created_by=current_user.id,
tenant_id=current_user.current_tenant_id
)
db.session.add(tag)
db.session.commit()
return tag
@staticmethod
def update_tags(args: dict, tag_id: str) -> Tag:
tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
if not tag:
raise NotFound("Tag not found")
tag.name = args['name']
db.session.commit()
return tag
@staticmethod
def get_tag_binding_count(tag_id: str) -> int:
count = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).count()
return count
@staticmethod
def delete_tag(tag_id: str):
tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
if not tag:
raise NotFound("Tag not found")
db.session.delete(tag)
# delete tag binding
tag_bindings = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).all()
if tag_bindings:
for tag_binding in tag_bindings:
db.session.delete(tag_binding)
db.session.commit()
@staticmethod
def save_tag_binding(args):
# check if target exists
TagService.check_target_exists(args['type'], args['target_id'])
# save tag binding
for tag_id in args['tag_ids']:
tag_binding = db.session.query(TagBinding).filter(
TagBinding.tag_id == tag_id,
TagBinding.target_id == args['target_id']
).first()
if tag_binding:
continue
new_tag_binding = TagBinding(
tag_id=tag_id,
target_id=args['target_id'],
tenant_id=current_user.current_tenant_id,
created_by=current_user.id
)
db.session.add(new_tag_binding)
db.session.commit()
@staticmethod
def delete_tag_binding(args):
# check if target exists
TagService.check_target_exists(args['type'], args['target_id'])
# delete tag binding
tag_bindings = db.session.query(TagBinding).filter(
TagBinding.target_id == args['target_id'],
TagBinding.tag_id == (args['tag_id'])
).first()
if tag_bindings:
db.session.delete(tag_bindings)
db.session.commit()
@staticmethod
def check_target_exists(type: str, target_id: str):
if type == 'knowledge':
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == current_user.current_tenant_id,
Dataset.id == target_id
).first()
if not dataset:
raise NotFound("Dataset not found")
elif type == 'app':
app = db.session.query(App).filter(
App.tenant_id == current_user.current_tenant_id,
App.id == target_id
).first()
if not app:
raise NotFound("App not found")
else:
raise NotFound("Invalid binding type")