Feature/mutil embedding model (#908)
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
api/services/dataset_service.py

@@ -9,6 +9,7 @@ from typing import Optional, List
from flask import current_app
from sqlalchemy import func

from core.index.index import IndexBuilder
from core.model_providers.model_factory import ModelFactory
from extensions.ext_redis import redis_client
from flask_login import current_user

@@ -25,14 +26,16 @@ from services.errors.account import NoPermissionError
from services.errors.dataset import DatasetNameDuplicateError
from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from services.vector_service import VectorService
from tasks.clean_notion_document_task import clean_notion_document_task
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.document_indexing_task import document_indexing_task
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.create_segment_to_index_task import create_segment_to_index_task
from tasks.update_segment_index_task import update_segment_index_task
from tasks.update_segment_keyword_index_task \
    import update_segment_keyword_index_task
from tasks.recover_document_indexing_task import recover_document_indexing_task
from tasks.update_segment_keyword_index_task import update_segment_keyword_index_task
from tasks.delete_segment_from_index_task import delete_segment_from_index_task


class DatasetService:
@@ -88,12 +91,16 @@ class DatasetService:
        if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
            raise DatasetNameDuplicateError(
                f'Dataset with name {name} already exists.')

        embedding_model = ModelFactory.get_embedding_model(
            tenant_id=current_user.current_tenant_id
        )
        dataset = Dataset(name=name, indexing_technique=indexing_technique)
        # dataset = Dataset(name=name, provider=provider, config=config)
        dataset.created_by = account.id
        dataset.updated_by = account.id
        dataset.tenant_id = tenant_id
        dataset.embedding_model_provider = embedding_model.model_provider.provider_name
        dataset.embedding_model = embedding_model.name
        db.session.add(dataset)
        db.session.commit()
        return dataset
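Note: create_dataset now snapshots the tenant's current default embedding model onto the dataset row (provider name plus model name). A minimal sketch of how later code can resolve the pinned model from those stored fields, following the call signature used elsewhere in this commit (the dataset object is assumed to be loaded):

    embedding_model = ModelFactory.get_embedding_model(
        tenant_id=dataset.tenant_id,
        model_provider_name=dataset.embedding_model_provider,
        model_name=dataset.embedding_model
    )
    # token accounting, as used in SegmentService further down
    tokens = embedding_model.get_num_tokens("text to embed")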
@@ -372,7 +379,7 @@ class DocumentService:
        indexing_cache_key = 'document_{}_is_paused'.format(document.id)
        redis_client.delete(indexing_cache_key)
        # trigger async task
        document_indexing_task.delay(document.dataset_id, document.id)
        recover_document_indexing_task.delay(document.dataset_id, document.id)

    @staticmethod
    def get_documents_position(dataset_id):
@@ -450,6 +457,7 @@ class DocumentService:
                document = DocumentService.save_document(dataset, dataset_process_rule.id,
                                                         document_data["data_source"]["type"],
                                                         document_data["doc_form"],
                                                         document_data["doc_language"],
                                                         data_source_info, created_from, position,
                                                         account, file_name, batch)
                db.session.add(document)
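Note: this call now threads document_data["doc_language"] through alongside doc_form. A hypothetical request payload for illustration (values are placeholders, and the dict is trimmed to the keys touched in this diff):

    document_data = {
        "data_source": {"type": "upload_file"},
        "doc_form": "qa_model",
        "doc_language": "English",
        "indexing_technique": "high_quality"
    }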
@@ -495,20 +503,11 @@ class DocumentService:
                    document = DocumentService.save_document(dataset, dataset_process_rule.id,
                                                             document_data["data_source"]["type"],
                                                             document_data["doc_form"],
                                                             document_data["doc_language"],
                                                             data_source_info, created_from, position,
                                                             account, page['page_name'], batch)
                    # if page['type'] == 'database':
                    #     document.splitting_completed_at = datetime.datetime.utcnow()
                    #     document.cleaning_completed_at = datetime.datetime.utcnow()
                    #     document.parsing_completed_at = datetime.datetime.utcnow()
                    #     document.completed_at = datetime.datetime.utcnow()
                    #     document.indexing_status = 'completed'
                    #     document.word_count = 0
                    #     document.tokens = 0
                    #     document.indexing_latency = 0
                    db.session.add(document)
                    db.session.flush()
                    # if page['type'] != 'database':
                    document_ids.append(document.id)
                    documents.append(document)
                    position += 1
@@ -520,15 +519,15 @@ class DocumentService:
            db.session.commit()

        # trigger async task
        # document_index_created.send(dataset.id, document_ids=document_ids)
        document_indexing_task.delay(dataset.id, document_ids)

        return documents, batch

    @staticmethod
    def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
                      data_source_info: dict, created_from: str, position: int, account: Account, name: str,
                      batch: str):
                      document_language: str, data_source_info: dict, created_from: str, position: int,
                      account: Account,
                      name: str, batch: str):
        document = Document(
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
@@ -540,7 +539,8 @@ class DocumentService:
            name=name,
            created_from=created_from,
            created_by=account.id,
            doc_form=document_form
            doc_form=document_form,
            doc_language=document_language
        )
        return document
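Note: save_document gains a document_language parameter between document_form and data_source_info and persists it as doc_language on the Document row. A sketch of a call against the new signature (argument values are placeholders mirroring the call sites above):

    document = DocumentService.save_document(
        dataset, dataset_process_rule.id, 'upload_file', 'text_model', 'English',
        data_source_info, 'web', position, account, 'example.pdf', batch
    )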
@@ -654,13 +654,18 @@ class DocumentService:
        tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
        if documents_count > tenant_document_count:
            raise ValueError(f"over document limit {tenant_document_count}.")
        embedding_model = ModelFactory.get_embedding_model(
            tenant_id=tenant_id
        )
        # save dataset
        dataset = Dataset(
            tenant_id=tenant_id,
            name='',
            data_source_type=document_data["data_source"]["type"],
            indexing_technique=document_data["indexing_technique"],
            created_by=account.id
            created_by=account.id,
            embedding_model=embedding_model.name,
            embedding_model_provider=embedding_model.model_provider.provider_name
        )

        db.session.add(dataset)
@@ -870,13 +875,15 @@ class SegmentService:
                raise ValueError("Answer is required")

    @classmethod
    def create_segment(cls, args: dict, document: Document):
    def create_segment(cls, args: dict, document: Document, dataset: Dataset):
        content = args['content']
        doc_id = str(uuid.uuid4())
        segment_hash = helper.generate_text_hash(content)

        embedding_model = ModelFactory.get_embedding_model(
            tenant_id=document.tenant_id
            tenant_id=dataset.tenant_id,
            model_provider_name=dataset.embedding_model_provider,
            model_name=dataset.embedding_model
        )

        # calc embedding use tokens
@@ -894,6 +901,9 @@ class SegmentService:
            content=content,
            word_count=len(content),
            tokens=tokens,
            status='completed',
            indexing_at=datetime.datetime.utcnow(),
            completed_at=datetime.datetime.utcnow(),
            created_by=current_user.id
        )
        if document.doc_form == 'qa_model':
@@ -901,49 +911,88 @@ class SegmentService:

        db.session.add(segment_document)
        db.session.commit()
        indexing_cache_key = 'segment_{}_indexing'.format(segment_document.id)
        redis_client.setex(indexing_cache_key, 600, 1)
        create_segment_to_index_task.delay(segment_document.id, args['keywords'])
        return segment_document

        # save vector index
        try:
            VectorService.create_segment_vector(args['keywords'], segment_document, dataset)
        except Exception as e:
            logging.exception("create segment index failed")
            segment_document.enabled = False
            segment_document.disabled_at = datetime.datetime.utcnow()
            segment_document.status = 'error'
            segment_document.error = str(e)
            db.session.commit()
        segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first()
        return segment
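Note: create_segment now receives the owning Dataset, resolves that dataset's pinned embedding model for token counting, and indexes the segment synchronously through VectorService in a try/except instead of dispatching create_segment_to_index_task. A hedged usage sketch (objects and values are placeholders):

    segment = SegmentService.create_segment(
        args={'content': 'What is Dify?', 'answer': 'An LLM app platform.', 'keywords': []},
        document=document,
        dataset=dataset
    )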
    @classmethod
    def update_segment(cls, args: dict, segment: DocumentSegment, document: Document):
    def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset):
        indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
        cache_result = redis_client.get(indexing_cache_key)
        if cache_result is not None:
            raise ValueError("Segment is indexing, please try again later")
        content = args['content']
        if segment.content == content:
            if document.doc_form == 'qa_model':
                segment.answer = args['answer']
            if args['keywords']:
                segment.keywords = args['keywords']
            db.session.add(segment)
            db.session.commit()
            # update segment index task
            redis_client.setex(indexing_cache_key, 600, 1)
            update_segment_keyword_index_task.delay(segment.id)
        else:
            segment_hash = helper.generate_text_hash(content)
        try:
            content = args['content']
            if segment.content == content:
                if document.doc_form == 'qa_model':
                    segment.answer = args['answer']
                if args['keywords']:
                    segment.keywords = args['keywords']
                db.session.add(segment)
                db.session.commit()
                # update segment index task
                if args['keywords']:
                    kw_index = IndexBuilder.get_index(dataset, 'economy')
                    # delete from keyword index
                    kw_index.delete_by_ids([segment.index_node_id])
                    # save keyword index
                    kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
            else:
                segment_hash = helper.generate_text_hash(content)

                embedding_model = ModelFactory.get_embedding_model(
                    tenant_id=document.tenant_id
                )
                embedding_model = ModelFactory.get_embedding_model(
                    tenant_id=dataset.tenant_id,
                    model_provider_name=dataset.embedding_model_provider,
                    model_name=dataset.embedding_model
                )

                # calc embedding use tokens
                tokens = embedding_model.get_num_tokens(content)
                segment.content = content
                segment.index_node_hash = segment_hash
                segment.word_count = len(content)
                segment.tokens = tokens
                segment.status = 'updating'
                segment.updated_by = current_user.id
                segment.updated_at = datetime.datetime.utcnow()
                if document.doc_form == 'qa_model':
                    segment.answer = args['answer']
                db.session.add(segment)
                # calc embedding use tokens
                tokens = embedding_model.get_num_tokens(content)
                segment.content = content
                segment.index_node_hash = segment_hash
                segment.word_count = len(content)
                segment.tokens = tokens
                segment.status = 'completed'
                segment.indexing_at = datetime.datetime.utcnow()
                segment.completed_at = datetime.datetime.utcnow()
                segment.updated_by = current_user.id
                segment.updated_at = datetime.datetime.utcnow()
                if document.doc_form == 'qa_model':
                    segment.answer = args['answer']
                db.session.add(segment)
                db.session.commit()
                # update segment vector index
                VectorService.create_segment_vector(args['keywords'], segment, dataset)
        except Exception as e:
            logging.exception("update segment index failed")
            segment.enabled = False
            segment.disabled_at = datetime.datetime.utcnow()
            segment.status = 'error'
            segment.error = str(e)
            db.session.commit()
        # update segment index task
        redis_client.setex(indexing_cache_key, 600, 1)
        update_segment_index_task.delay(segment.id, args['keywords'])
        segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first()
        return segment
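Note: update_segment keeps the Redis indexing guard but replaces the Celery update tasks with inline index maintenance wrapped in try/except: if the content is unchanged it only refreshes keywords in the economy index, otherwise it re-embeds with the dataset's pinned model and rebuilds the indexes through VectorService. A minimal sketch of the keyword-only branch, using the helpers introduced in this commit:

    # refresh keywords for an unchanged segment (economy/keyword index only)
    kw_index = IndexBuilder.get_index(dataset, 'economy')
    kw_index.delete_by_ids([segment.index_node_id])
    kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)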
    @classmethod
    def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset):
        indexing_cache_key = 'segment_{}_delete_indexing'.format(segment.id)
        cache_result = redis_client.get(indexing_cache_key)
        if cache_result is not None:
            raise ValueError("Segment is deleting.")
        # send delete segment index task
        redis_client.setex(indexing_cache_key, 600, 1)
        # enabled segment need to delete index
        if segment.enabled:
            delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
        db.session.delete(segment)
        db.session.commit()
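Note: delete_segment uses the same short-lived Redis flag pattern as create/update to reject concurrent operations, then defers index cleanup to delete_segment_from_index_task only when the segment is enabled. A condensed sketch of the guard:

    cache_key = 'segment_{}_delete_indexing'.format(segment.id)
    if redis_client.get(cache_key) is not None:
        raise ValueError("Segment is deleting.")
    redis_client.setex(cache_key, 600, 1)  # flag expires after 600 seconds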
api/services/hit_testing_service.py

@@ -29,7 +29,9 @@ class HitTestingService:
        }

        embedding_model = ModelFactory.get_embedding_model(
            tenant_id=dataset.tenant_id
            tenant_id=dataset.tenant_id,
            model_provider_name=dataset.embedding_model_provider,
            model_name=dataset.embedding_model
        )

        embeddings = CacheEmbedding(embedding_model)
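Note: hit testing now resolves the dataset's pinned embedding model instead of the tenant default before wrapping it in CacheEmbedding. A hedged sketch of the flow; the embed_query call is an assumption about the CacheEmbedding/LangChain Embeddings interface, not something shown in this diff:

    embedding_model = ModelFactory.get_embedding_model(
        tenant_id=dataset.tenant_id,
        model_provider_name=dataset.embedding_model_provider,
        model_name=dataset.embedding_model
    )
    embeddings = CacheEmbedding(embedding_model)
    # assumption: CacheEmbedding exposes the LangChain Embeddings interface
    query_vector = embeddings.embed_query("user query text")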
api/services/vector_service.py (new file, 69 lines)

@@ -0,0 +1,69 @@
from typing import Optional, List

from langchain.schema import Document

from core.index.index import IndexBuilder
from models.dataset import Dataset, DocumentSegment


class VectorService:

    @classmethod
    def create_segment_vector(cls, keywords: Optional[List[str]], segment: DocumentSegment, dataset: Dataset):
        document = Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            }
        )

        # save vector index
        index = IndexBuilder.get_index(dataset, 'high_quality')
        if index:
            index.add_texts([document], duplicate_check=True)

        # save keyword index
        index = IndexBuilder.get_index(dataset, 'economy')
        if index:
            if keywords and len(keywords) > 0:
                index.create_segment_keywords(segment.index_node_id, keywords)
            else:
                index.add_texts([document])

    @classmethod
    def update_segment_vector(cls, keywords: Optional[List[str]], segment: DocumentSegment, dataset: Dataset):
        # update segment index task
        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
        kw_index = IndexBuilder.get_index(dataset, 'economy')
        # delete from vector index
        if vector_index:
            vector_index.delete_by_ids([segment.index_node_id])

        # delete from keyword index
        kw_index.delete_by_ids([segment.index_node_id])

        # add new index
        document = Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            }
        )

        # save vector index
        if vector_index:
            vector_index.add_texts([document], duplicate_check=True)

        # save keyword index
        if keywords and len(keywords) > 0:
            kw_index.create_segment_keywords(segment.index_node_id, keywords)
        else:
            kw_index.add_texts([document])
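Note: both VectorService helpers are classmethods that take the segment row and its dataset, so call sites need no task-queue or index wiring of their own. A minimal usage sketch (segment and dataset are assumed to be already-loaded ORM objects):

    # index a freshly created segment; keywords are optional
    VectorService.create_segment_vector(['pricing', 'plans'], segment, dataset)

    # re-index a segment after its content changed
    VectorService.update_segment_vector(None, segment, dataset)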