Model Runtime (#1858)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
takatost
2024-01-02 23:42:00 +08:00
committed by GitHub
parent e91dd28a76
commit d069c668f8
807 changed files with 171310 additions and 23806 deletions

View File

@@ -1,4 +1,3 @@
import json
import threading
from typing import Type, Optional, List
@@ -7,12 +6,12 @@ from langchain.tools import BaseTool
from pydantic import Field, BaseModel
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.conversation_message_task import ConversationMessageTask
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rerank.rerank import RerankRunner
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, Document
from services.retrieval_service import RetrievalService
@@ -43,7 +42,7 @@ class DatasetRetrieverTool(BaseTool):
dataset_id: str
top_k: int = 2
score_threshold: Optional[float] = None
conversation_message_task: ConversationMessageTask
hit_callbacks: List[DatasetIndexToolCallbackHandler] = []
return_resource: bool
retriever_from: str
@@ -70,9 +69,12 @@ class DatasetRetrieverTool(BaseTool):
if not dataset:
return ''
for hit_callback in self.hit_callbacks:
hit_callback.on_query(query, dataset.id)
# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
kw_table_index = KeywordTableIndex(
@@ -85,16 +87,16 @@ class DatasetRetrieverTool(BaseTool):
documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
return str("\n".join([document.page_content for document in documents]))
else:
# get embedding model instance
try:
embedding_model = ModelFactory.get_embedding_model(
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except LLMBadRequestError:
return ''
except ProviderTokenNotInitError:
except InvokeAuthorizationError:
return ''
embeddings = CacheEmbedding(embedding_model)
@@ -140,21 +142,34 @@ class DatasetRetrieverTool(BaseTool):
for thread in threads:
thread.join()
# hybrid search: rerank after all documents have been searched
if retrieval_model['search_method'] == 'hybrid_search':
hybrid_rerank = ModelFactory.get_reranking_model(
tenant_id=dataset.tenant_id,
model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
model_name=retrieval_model['reranking_model']['reranking_model_name']
# get rerank model instance
try:
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=retrieval_model['reranking_model']['reranking_provider_name'],
model_type=ModelType.RERANK,
model=retrieval_model['reranking_model']['reranking_model_name']
)
except InvokeAuthorizationError:
return ''
rerank_runner = RerankRunner(rerank_model_instance)
documents = rerank_runner.run(
query=query,
documents=documents,
score_threshold=retrieval_model['score_threshold'] if retrieval_model[
'score_threshold_enabled'] else None,
top_n=self.top_k
)
documents = hybrid_rerank.rerank(query, documents,
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
self.top_k)
else:
documents = []
hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task)
hit_callback.on_tool_end(documents)
for hit_callback in self.hit_callbacks:
hit_callback.on_tool_end(documents)
document_score_list = {}
if dataset.indexing_technique != "economy":
for item in documents:
@@ -212,7 +227,9 @@ class DatasetRetrieverTool(BaseTool):
source['content'] = segment.content
context_list.append(source)
resource_number += 1
hit_callback.return_retriever_resource_info(context_list)
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list))