mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-24 18:23:07 +08:00
Model Runtime (#1858)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
@@ -7,11 +7,12 @@ from langchain.tools import BaseTool
 from pydantic import Field, BaseModel
 
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
-from core.conversation_message_task import ConversationMessageTask
 from core.embedding.cached_embedding import CacheEmbedding
 from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
-from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
-from core.model_providers.model_factory import ModelFactory
+from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.rerank.rerank import RerankRunner
 from extensions.ext_database import db
 from models.dataset import Dataset, DocumentSegment, Document
 from services.retrieval_service import RetrievalService
@@ -43,9 +44,9 @@ class DatasetMultiRetrieverTool(BaseTool):
     score_threshold: Optional[float] = None
     reranking_provider_name: str
     reranking_model_name: str
-    conversation_message_task: ConversationMessageTask
     return_resource: bool
     retriever_from: str
+    hit_callbacks: List[DatasetIndexToolCallbackHandler] = []
 
     @classmethod
     def from_dataset(cls, dataset_ids: List[str], tenant_id: str, **kwargs):
@@ -64,22 +65,28 @@ class DatasetMultiRetrieverTool(BaseTool):
                 'flask_app': current_app._get_current_object(),
                 'dataset_id': dataset_id,
                 'query': query,
-                'all_documents': all_documents
+                'all_documents': all_documents,
+                'hit_callbacks': self.hit_callbacks
             })
             threads.append(retrieval_thread)
             retrieval_thread.start()
         for thread in threads:
             thread.join()
         # do rerank for searched documents
-        rerank = ModelFactory.get_reranking_model(
+        model_manager = ModelManager()
+        rerank_model_instance = model_manager.get_model_instance(
             tenant_id=self.tenant_id,
-            model_provider_name=self.reranking_provider_name,
-            model_name=self.reranking_model_name
+            provider=self.reranking_provider_name,
+            model_type=ModelType.RERANK,
+            model=self.reranking_model_name
         )
-        all_documents = rerank.rerank(query, all_documents, self.score_threshold, self.top_k)
 
-        hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task)
-        hit_callback.on_tool_end(all_documents)
+        rerank_runner = RerankRunner(rerank_model_instance)
+        all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k)
+
+        for hit_callback in self.hit_callbacks:
+            hit_callback.on_tool_end(all_documents)
+
         document_score_list = {}
         for item in all_documents:
             if 'score' in item.metadata and item.metadata['score']:
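The hunk above swaps the old ModelFactory.get_reranking_model call for the new model runtime's ModelManager / RerankRunner pair. A minimal sketch of just that call pattern, using only the imports and argument names that appear in the changed lines (the helper function and its signature are illustrative, not part of the commit):

```python
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rerank.rerank import RerankRunner


def rerank_documents(tenant_id, provider, model, query, documents, score_threshold, top_k):
    # Resolve a rerank model instance through the centralized ModelManager,
    # which replaces the removed ModelFactory.get_reranking_model helper.
    model_manager = ModelManager()
    rerank_model_instance = model_manager.get_model_instance(
        tenant_id=tenant_id,
        provider=provider,
        model_type=ModelType.RERANK,
        model=model
    )

    # RerankRunner wraps the instance and exposes a single run() call,
    # mirroring the positional-argument usage shown in the hunk above.
    rerank_runner = RerankRunner(rerank_model_instance)
    return rerank_runner.run(query, documents, score_threshold, top_k)
```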
@@ -139,14 +146,17 @@ class DatasetMultiRetrieverTool(BaseTool):
                     source['content'] = segment.content
                 context_list.append(source)
                 resource_number += 1
-            hit_callback.return_retriever_resource_info(context_list)
+
+            for hit_callback in self.hit_callbacks:
+                hit_callback.return_retriever_resource_info(context_list)
 
         return str("\n".join(document_context_list))
 
     async def _arun(self, tool_input: str) -> str:
         raise NotImplementedError()
 
-    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: List):
+    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: List,
+                   hit_callbacks: List[DatasetIndexToolCallbackHandler]):
         with flask_app.app_context():
             dataset = db.session.query(Dataset).filter(
                 Dataset.tenant_id == self.tenant_id,
@@ -155,6 +165,10 @@ class DatasetMultiRetrieverTool(BaseTool):
 
             if not dataset:
                 return []
+
+            for hit_callback in hit_callbacks:
+                hit_callback.on_query(query, dataset.id)
+
             # get retrieval model , if the model is not setting , using default
             retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
 
@@ -173,10 +187,12 @@ class DatasetMultiRetrieverTool(BaseTool):
             else:
 
                 try:
-                    embedding_model = ModelFactory.get_embedding_model(
+                    model_manager = ModelManager()
+                    embedding_model = model_manager.get_model_instance(
                         tenant_id=dataset.tenant_id,
-                        model_provider_name=dataset.embedding_model_provider,
-                        model_name=dataset.embedding_model
+                        provider=dataset.embedding_model_provider,
+                        model_type=ModelType.TEXT_EMBEDDING,
+                        model=dataset.embedding_model
                     )
                 except LLMBadRequestError:
                     return []
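The embedding lookup migrates the same way: ModelFactory.get_embedding_model becomes ModelManager.get_model_instance with ModelType.TEXT_EMBEDDING. A hedged sketch of that lookup plus the CacheEmbedding wrapper the tools use around it; the helper name and the None return are this sketch's choices (the tool itself returns [] or '' when the model cannot be resolved):

```python
from core.embedding.cached_embedding import CacheEmbedding
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType


def get_dataset_embeddings(dataset):
    # Resolve the dataset's text-embedding model through the new ModelManager.
    try:
        model_manager = ModelManager()
        embedding_model = model_manager.get_model_instance(
            tenant_id=dataset.tenant_id,
            provider=dataset.embedding_model_provider,
            model_type=ModelType.TEXT_EMBEDDING,
            model=dataset.embedding_model
        )
    except (LLMBadRequestError, ProviderTokenNotInitError):
        # The tools bail out quietly when the model is unavailable; this
        # sketch signals the same condition with None.
        return None

    # CacheEmbedding wraps the model instance so embedding calls go through
    # the cached-embedding layer the tools already rely on.
    return CacheEmbedding(embedding_model)
```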
@@ -1,4 +1,3 @@
-import json
 import threading
 from typing import Type, Optional, List
 
@@ -7,12 +6,12 @@ from langchain.tools import BaseTool
 from pydantic import Field, BaseModel
 
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
-from core.conversation_message_task import ConversationMessageTask
 from core.embedding.cached_embedding import CacheEmbedding
 from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
 from core.index.vector_index.vector_index import VectorIndex
-from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
-from core.model_providers.model_factory import ModelFactory
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
+from core.rerank.rerank import RerankRunner
 from extensions.ext_database import db
 from models.dataset import Dataset, DocumentSegment, Document
 from services.retrieval_service import RetrievalService
@@ -43,7 +42,7 @@ class DatasetRetrieverTool(BaseTool):
     dataset_id: str
     top_k: int = 2
     score_threshold: Optional[float] = None
-    conversation_message_task: ConversationMessageTask
+    hit_callbacks: List[DatasetIndexToolCallbackHandler] = []
     return_resource: bool
     retriever_from: str
 
@@ -70,9 +69,12 @@ class DatasetRetrieverTool(BaseTool):
         if not dataset:
             return ''
 
+        for hit_callback in self.hit_callbacks:
+            hit_callback.on_query(query, dataset.id)
+
         # get retrieval model , if the model is not setting , using default
         retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
 
         if dataset.indexing_technique == "economy":
             # use keyword table query
             kw_table_index = KeywordTableIndex(
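In both retriever tools, the single conversation_message_task handler is replaced by a hit_callbacks list, and every event is fanned out to each registered DatasetIndexToolCallbackHandler. The only hooks these hunks call are on_query, on_tool_end, and return_retriever_resource_info. A rough sketch of that fan-out with an illustrative stand-in handler (the real class lives in core.callback_handler.index_tool_callback_handler; its constructor and any other methods are not shown in this diff):

```python
from typing import List


class PrintingHitCallback:
    """Illustrative stand-in exposing only the hooks the retriever tools call."""

    def on_query(self, query: str, dataset_id: str) -> None:
        print(f"searching dataset {dataset_id} for {query!r}")

    def on_tool_end(self, documents: List) -> None:
        print(f"retrieved {len(documents)} documents")

    def return_retriever_resource_info(self, context_list: List[dict]) -> None:
        print(f"cited {len(context_list)} resources")


# The tools fan each event out with a plain loop, exactly as the added
# `for hit_callback in self.hit_callbacks:` lines do.
hit_callbacks: List[PrintingHitCallback] = [PrintingHitCallback()]
for hit_callback in hit_callbacks:
    hit_callback.on_query("what is model runtime?", "dataset-123")
```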
@@ -85,16 +87,16 @@ class DatasetRetrieverTool(BaseTool):
             documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
             return str("\n".join([document.page_content for document in documents]))
         else:
 
+            # get embedding model instance
             try:
-                embedding_model = ModelFactory.get_embedding_model(
+                model_manager = ModelManager()
+                embedding_model = model_manager.get_model_instance(
                     tenant_id=dataset.tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
                 )
             except LLMBadRequestError:
                 return ''
-            except ProviderTokenNotInitError:
+            except InvokeAuthorizationError:
                 return ''
 
             embeddings = CacheEmbedding(embedding_model)
@@ -140,21 +142,34 @@ class DatasetRetrieverTool(BaseTool):
 
             for thread in threads:
                 thread.join()
 
             # hybrid search: rerank after all documents have been searched
             if retrieval_model['search_method'] == 'hybrid_search':
-                hybrid_rerank = ModelFactory.get_reranking_model(
-                    tenant_id=dataset.tenant_id,
-                    model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
-                    model_name=retrieval_model['reranking_model']['reranking_model_name']
-                )
-                documents = hybrid_rerank.rerank(query, documents,
-                                                 retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
-                                                 self.top_k)
+                # get rerank model instance
+                try:
+                    model_manager = ModelManager()
+                    rerank_model_instance = model_manager.get_model_instance(
+                        tenant_id=dataset.tenant_id,
+                        provider=retrieval_model['reranking_model']['reranking_provider_name'],
+                        model_type=ModelType.RERANK,
+                        model=retrieval_model['reranking_model']['reranking_model_name']
+                    )
+                except InvokeAuthorizationError:
+                    return ''
+
+                rerank_runner = RerankRunner(rerank_model_instance)
+                documents = rerank_runner.run(
+                    query=query,
+                    documents=documents,
+                    score_threshold=retrieval_model['score_threshold'] if retrieval_model[
+                        'score_threshold_enabled'] else None,
+                    top_n=self.top_k
+                )
             else:
                 documents = []
 
-            hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task)
-            hit_callback.on_tool_end(documents)
+            for hit_callback in self.hit_callbacks:
+                hit_callback.on_tool_end(documents)
             document_score_list = {}
             if dataset.indexing_technique != "economy":
                 for item in documents:
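The hybrid-search branch above follows the same ModelManager pattern, but guards the lookup with InvokeAuthorizationError and passes the rerank arguments by keyword. A condensed sketch of just that branch, using only the names visible in the hunk; the helper function is illustrative, and it returns an empty list where the tool itself returns '':

```python
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rerank.rerank import RerankRunner


def hybrid_rerank(tenant_id: str, retrieval_model: dict, query: str, documents: list, top_k: int):
    # get rerank model instance; missing or unauthorized credentials abort quietly
    try:
        model_manager = ModelManager()
        rerank_model_instance = model_manager.get_model_instance(
            tenant_id=tenant_id,
            provider=retrieval_model['reranking_model']['reranking_provider_name'],
            model_type=ModelType.RERANK,
            model=retrieval_model['reranking_model']['reranking_model_name']
        )
    except InvokeAuthorizationError:
        return []

    # The score threshold is only applied when the retrieval_model config
    # enables it, matching the conditional in the hunk above.
    rerank_runner = RerankRunner(rerank_model_instance)
    return rerank_runner.run(
        query=query,
        documents=documents,
        score_threshold=retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
        top_n=top_k
    )
```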
@@ -212,7 +227,9 @@ class DatasetRetrieverTool(BaseTool):
                     source['content'] = segment.content
                 context_list.append(source)
                 resource_number += 1
-            hit_callback.return_retriever_resource_info(context_list)
+
+            for hit_callback in self.hit_callbacks:
+                hit_callback.return_retriever_resource_info(context_list)
 
         return str("\n".join(document_context_list))
 
@@ -7,7 +7,7 @@ import subprocess
 import tempfile
 import unicodedata
 from contextlib import contextmanager
-from typing import Type
+from typing import Type, Any
 
 import requests
 from bs4 import BeautifulSoup, NavigableString, Comment, CData
@@ -23,7 +23,7 @@ from regex import regex
 from core.chain.llm_chain import LLMChain
 from core.data_loader import file_extractor
 from core.data_loader.file_extractor import FileExtractor
-from core.model_providers.models.llm.base import BaseLLM
+from core.entities.application_entities import ModelConfigEntity
 
 FULL_TEMPLATE = """
 TITLE: {title}
@@ -67,7 +67,8 @@ class WebReaderTool(BaseTool):
     summary_chunk_overlap: int = 0
     summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
     continue_reading: bool = True
-    model_instance: BaseLLM = None
+    model_config: ModelConfigEntity
+    model_parameters: dict[str, Any]
 
     def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
         try:
@@ -80,7 +81,7 @@ class WebReaderTool(BaseTool):
         except Exception as e:
             return f'Read this website failed, caused by: {str(e)}.'
 
-        if summary and self.model_instance:
+        if summary:
             character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                 chunk_size=self.summary_chunk_tokens,
                 chunk_overlap=self.summary_chunk_overlap,
@@ -117,12 +118,14 @@ class WebReaderTool(BaseTool):
 
     def get_summary_chain(self) -> RefineDocumentsChain:
         initial_chain = LLMChain(
-            model_instance=self.model_instance,
-            prompt=refine_prompts.PROMPT
+            model_config=self.model_config,
+            prompt=refine_prompts.PROMPT,
+            parameters=self.model_parameters
        )
         refine_chain = LLMChain(
-            model_instance=self.model_instance,
-            prompt=refine_prompts.REFINE_PROMPT
+            model_config=self.model_config,
+            prompt=refine_prompts.REFINE_PROMPT,
+            parameters=self.model_parameters
         )
         return RefineDocumentsChain(
             initial_llm_chain=initial_chain,
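WebReaderTool now carries model_config: ModelConfigEntity and model_parameters: dict[str, Any] instead of a prebuilt BaseLLM, and forwards both to the refine-summary LLMChains. A minimal sketch of that construction; it assumes refine_prompts comes from langchain.chains.summarize (the import is not shown in this diff), and the helper function is illustrative rather than part of the commit:

```python
from typing import Any

from langchain.chains.summarize import refine_prompts

from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity


def build_summary_chains(model_config: ModelConfigEntity, model_parameters: dict[str, Any]):
    # Both chains are driven by the app-level model configuration plus an
    # explicit inference-parameter dict; no pre-instantiated BaseLLM is needed.
    initial_chain = LLMChain(
        model_config=model_config,
        prompt=refine_prompts.PROMPT,
        parameters=model_parameters
    )
    refine_chain = LLMChain(
        model_config=model_config,
        prompt=refine_prompts.REFINE_PROMPT,
        parameters=model_parameters
    )
    return initial_chain, refine_chain
```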