Feature/mutil embedding model (#908)

Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
This commit is contained in:
Jyong
2023-08-18 17:37:31 +08:00
committed by GitHub
parent 4420281d96
commit db7156dafd
54 changed files with 1704 additions and 278 deletions

View File

@@ -9,6 +9,7 @@ from typing import Optional, List
from flask import current_app
from sqlalchemy import func
from core.index.index import IndexBuilder
from core.model_providers.model_factory import ModelFactory
from extensions.ext_redis import redis_client
from flask_login import current_user
@@ -25,14 +26,16 @@ from services.errors.account import NoPermissionError
from services.errors.dataset import DatasetNameDuplicateError
from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from services.vector_service import VectorService
from tasks.clean_notion_document_task import clean_notion_document_task
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.document_indexing_task import document_indexing_task
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.create_segment_to_index_task import create_segment_to_index_task
from tasks.update_segment_index_task import update_segment_index_task
from tasks.update_segment_keyword_index_task\
import update_segment_keyword_index_task
from tasks.recover_document_indexing_task import recover_document_indexing_task
from tasks.update_segment_keyword_index_task import update_segment_keyword_index_task
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
class DatasetService:
@@ -88,12 +91,16 @@ class DatasetService:
if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(
f'Dataset with name {name} already exists.')
embedding_model = ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
)
dataset = Dataset(name=name, indexing_technique=indexing_technique)
# dataset = Dataset(name=name, provider=provider, config=config)
dataset.created_by = account.id
dataset.updated_by = account.id
dataset.tenant_id = tenant_id
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
dataset.embedding_model = embedding_model.name
db.session.add(dataset)
db.session.commit()
return dataset
@@ -372,7 +379,7 @@ class DocumentService:
indexing_cache_key = 'document_{}_is_paused'.format(document.id)
redis_client.delete(indexing_cache_key)
# trigger async task
document_indexing_task.delay(document.dataset_id, document.id)
recover_document_indexing_task.delay(document.dataset_id, document.id)
@staticmethod
def get_documents_position(dataset_id):
@@ -450,6 +457,7 @@ class DocumentService:
document = DocumentService.save_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, file_name, batch)
db.session.add(document)
@@ -495,20 +503,11 @@ class DocumentService:
document = DocumentService.save_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, page['page_name'], batch)
# if page['type'] == 'database':
# document.splitting_completed_at = datetime.datetime.utcnow()
# document.cleaning_completed_at = datetime.datetime.utcnow()
# document.parsing_completed_at = datetime.datetime.utcnow()
# document.completed_at = datetime.datetime.utcnow()
# document.indexing_status = 'completed'
# document.word_count = 0
# document.tokens = 0
# document.indexing_latency = 0
db.session.add(document)
db.session.flush()
# if page['type'] != 'database':
document_ids.append(document.id)
documents.append(document)
position += 1
@@ -520,15 +519,15 @@ class DocumentService:
db.session.commit()
# trigger async task
#document_index_created.send(dataset.id, document_ids=document_ids)
document_indexing_task.delay(dataset.id, document_ids)
return documents, batch
@staticmethod
def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
data_source_info: dict, created_from: str, position: int, account: Account, name: str,
batch: str):
document_language: str, data_source_info: dict, created_from: str, position: int,
account: Account,
name: str, batch: str):
document = Document(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
@@ -540,7 +539,8 @@ class DocumentService:
name=name,
created_from=created_from,
created_by=account.id,
doc_form=document_form
doc_form=document_form,
doc_language=document_language
)
return document
@@ -654,13 +654,18 @@ class DocumentService:
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
if documents_count > tenant_document_count:
raise ValueError(f"over document limit {tenant_document_count}.")
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
# save dataset
dataset = Dataset(
tenant_id=tenant_id,
name='',
data_source_type=document_data["data_source"]["type"],
indexing_technique=document_data["indexing_technique"],
created_by=account.id
created_by=account.id,
embedding_model=embedding_model.name,
embedding_model_provider=embedding_model.model_provider.provider_name
)
db.session.add(dataset)
@@ -870,13 +875,15 @@ class SegmentService:
raise ValueError("Answer is required")
@classmethod
def create_segment(cls, args: dict, document: Document):
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
content = args['content']
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
embedding_model = ModelFactory.get_embedding_model(
tenant_id=document.tenant_id
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
# calc embedding use tokens
@@ -894,6 +901,9 @@ class SegmentService:
content=content,
word_count=len(content),
tokens=tokens,
status='completed',
indexing_at=datetime.datetime.utcnow(),
completed_at=datetime.datetime.utcnow(),
created_by=current_user.id
)
if document.doc_form == 'qa_model':
@@ -901,49 +911,88 @@ class SegmentService:
db.session.add(segment_document)
db.session.commit()
indexing_cache_key = 'segment_{}_indexing'.format(segment_document.id)
redis_client.setex(indexing_cache_key, 600, 1)
create_segment_to_index_task.delay(segment_document.id, args['keywords'])
return segment_document
# save vector index
try:
VectorService.create_segment_vector(args['keywords'], segment_document, dataset)
except Exception as e:
logging.exception("create segment index failed")
segment_document.enabled = False
segment_document.disabled_at = datetime.datetime.utcnow()
segment_document.status = 'error'
segment_document.error = str(e)
db.session.commit()
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first()
return segment
@classmethod
def update_segment(cls, args: dict, segment: DocumentSegment, document: Document):
def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset):
indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
raise ValueError("Segment is indexing, please try again later")
content = args['content']
if segment.content == content:
if document.doc_form == 'qa_model':
segment.answer = args['answer']
if args['keywords']:
segment.keywords = args['keywords']
db.session.add(segment)
db.session.commit()
# update segment index task
redis_client.setex(indexing_cache_key, 600, 1)
update_segment_keyword_index_task.delay(segment.id)
else:
segment_hash = helper.generate_text_hash(content)
try:
content = args['content']
if segment.content == content:
if document.doc_form == 'qa_model':
segment.answer = args['answer']
if args['keywords']:
segment.keywords = args['keywords']
db.session.add(segment)
db.session.commit()
# update segment index task
if args['keywords']:
kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from keyword index
kw_index.delete_by_ids([segment.index_node_id])
# save keyword index
kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
else:
segment_hash = helper.generate_text_hash(content)
embedding_model = ModelFactory.get_embedding_model(
tenant_id=document.tenant_id
)
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content)
segment.content = content
segment.index_node_hash = segment_hash
segment.word_count = len(content)
segment.tokens = tokens
segment.status = 'updating'
segment.updated_by = current_user.id
segment.updated_at = datetime.datetime.utcnow()
if document.doc_form == 'qa_model':
segment.answer = args['answer']
db.session.add(segment)
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content)
segment.content = content
segment.index_node_hash = segment_hash
segment.word_count = len(content)
segment.tokens = tokens
segment.status = 'completed'
segment.indexing_at = datetime.datetime.utcnow()
segment.completed_at = datetime.datetime.utcnow()
segment.updated_by = current_user.id
segment.updated_at = datetime.datetime.utcnow()
if document.doc_form == 'qa_model':
segment.answer = args['answer']
db.session.add(segment)
db.session.commit()
# update segment vector index
VectorService.create_segment_vector(args['keywords'], segment, dataset)
except Exception as e:
logging.exception("update segment index failed")
segment.enabled = False
segment.disabled_at = datetime.datetime.utcnow()
segment.status = 'error'
segment.error = str(e)
db.session.commit()
# update segment index task
redis_client.setex(indexing_cache_key, 600, 1)
update_segment_index_task.delay(segment.id, args['keywords'])
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first()
return segment
@classmethod
def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset):
indexing_cache_key = 'segment_{}_delete_indexing'.format(segment.id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
raise ValueError("Segment is deleting.")
# send delete segment index task
redis_client.setex(indexing_cache_key, 600, 1)
# enabled segment need to delete index
if segment.enabled:
delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
db.session.delete(segment)
db.session.commit()

View File

@@ -29,7 +29,9 @@ class HitTestingService:
}
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
embeddings = CacheEmbedding(embedding_model)

View File

@@ -0,0 +1,69 @@
from typing import Optional, List
from langchain.schema import Document
from core.index.index import IndexBuilder
from models.dataset import Dataset, DocumentSegment
class VectorService:
@classmethod
def create_segment_vector(cls, keywords: Optional[List[str]], segment: DocumentSegment, dataset: Dataset):
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
# save vector index
index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
index.add_texts([document], duplicate_check=True)
# save keyword index
index = IndexBuilder.get_index(dataset, 'economy')
if index:
if keywords and len(keywords) > 0:
index.create_segment_keywords(segment.index_node_id, keywords)
else:
index.add_texts([document])
@classmethod
def update_segment_vector(cls, keywords: Optional[List[str]], segment: DocumentSegment, dataset: Dataset):
# update segment index task
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from vector index
if vector_index:
vector_index.delete_by_ids([segment.index_node_id])
# delete from keyword index
kw_index.delete_by_ids([segment.index_node_id])
# add new index
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
# save vector index
if vector_index:
vector_index.add_texts([document], duplicate_check=True)
# save keyword index
if keywords and len(keywords) > 0:
kw_index.create_segment_keywords(segment.index_node_id, keywords)
else:
kw_index.add_texts([document])