mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-16 14:26:52 +08:00
Model Runtime (#1858)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: Garfield Dai <dai.hai@foxmail.com> Co-authored-by: chenhe <guchenhe@gmail.com> Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import threading
|
||||
from typing import Type, Optional, List
|
||||
|
||||
@@ -7,12 +6,12 @@ from langchain.tools import BaseTool
|
||||
from pydantic import Field, BaseModel
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.rerank.rerank import RerankRunner
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DocumentSegment, Document
|
||||
from services.retrieval_service import RetrievalService
|
||||
@@ -43,7 +42,7 @@ class DatasetRetrieverTool(BaseTool):
|
||||
dataset_id: str
|
||||
top_k: int = 2
|
||||
score_threshold: Optional[float] = None
|
||||
conversation_message_task: ConversationMessageTask
|
||||
hit_callbacks: List[DatasetIndexToolCallbackHandler] = []
|
||||
return_resource: bool
|
||||
retriever_from: str
|
||||
|
||||
@@ -70,9 +69,12 @@ class DatasetRetrieverTool(BaseTool):
|
||||
|
||||
if not dataset:
|
||||
return ''
|
||||
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.on_query(query, dataset.id)
|
||||
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
||||
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
kw_table_index = KeywordTableIndex(
|
||||
@@ -85,16 +87,16 @@ class DatasetRetrieverTool(BaseTool):
|
||||
documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
|
||||
return str("\n".join([document.page_content for document in documents]))
|
||||
else:
|
||||
|
||||
# get embedding model instance
|
||||
try:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
return ''
|
||||
except ProviderTokenNotInitError:
|
||||
except InvokeAuthorizationError:
|
||||
return ''
|
||||
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
@@ -140,21 +142,34 @@ class DatasetRetrieverTool(BaseTool):
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# hybrid search: rerank after all documents have been searched
|
||||
if retrieval_model['search_method'] == 'hybrid_search':
|
||||
hybrid_rerank = ModelFactory.get_reranking_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
|
||||
model_name=retrieval_model['reranking_model']['reranking_model_name']
|
||||
# get rerank model instance
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=retrieval_model['reranking_model']['reranking_provider_name'],
|
||||
model_type=ModelType.RERANK,
|
||||
model=retrieval_model['reranking_model']['reranking_model_name']
|
||||
)
|
||||
except InvokeAuthorizationError:
|
||||
return ''
|
||||
|
||||
rerank_runner = RerankRunner(rerank_model_instance)
|
||||
documents = rerank_runner.run(
|
||||
query=query,
|
||||
documents=documents,
|
||||
score_threshold=retrieval_model['score_threshold'] if retrieval_model[
|
||||
'score_threshold_enabled'] else None,
|
||||
top_n=self.top_k
|
||||
)
|
||||
documents = hybrid_rerank.rerank(query, documents,
|
||||
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
|
||||
self.top_k)
|
||||
else:
|
||||
documents = []
|
||||
|
||||
hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task)
|
||||
hit_callback.on_tool_end(documents)
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.on_tool_end(documents)
|
||||
document_score_list = {}
|
||||
if dataset.indexing_technique != "economy":
|
||||
for item in documents:
|
||||
@@ -212,7 +227,9 @@ class DatasetRetrieverTool(BaseTool):
|
||||
source['content'] = segment.content
|
||||
context_list.append(source)
|
||||
resource_number += 1
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
|
||||
return str("\n".join(document_context_list))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user