mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-09 10:56:52 +08:00
Initial commit
This commit is contained in:
83
api/core/tool/dataset_tool_builder.py
Normal file
83
api/core/tool/dataset_tool_builder.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.callbacks import CallbackManager
|
||||
from llama_index.langchain_helpers.agents import IndexToolConfig
|
||||
|
||||
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
||||
from core.index.keyword_table_index import KeywordTableIndex
|
||||
from core.index.vector_index import VectorIndex
|
||||
from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE
|
||||
from core.tool.llama_index_tool import EnhanceLlamaIndexTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class DatasetToolBuilder:
    """Factory for tools that query the index of a single dataset."""

    @classmethod
    def build_dataset_tool(cls, tenant_id: str, dataset_id: str,
                           response_mode: str = "no_synthesizer",
                           callback_handler: Optional[DatasetToolCallbackHandler] = None):
        """Build an ``EnhanceLlamaIndexTool`` for the given dataset.

        Args:
            tenant_id: Tenant the dataset must belong to.
            dataset_id: Id of the dataset to look up.
            response_mode: Forwarded to the index query as ``response_mode``.
            callback_handler: Optional handler added to the tool's callback
                manager. When omitted, only ``DifyStdOutCallbackHandler`` is
                registered.

        Returns:
            The configured tool, or ``None`` when no matching dataset exists
            or the dataset has no query index.
        """
        # get dataset from dataset id
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == tenant_id,
            Dataset.id == dataset_id
        ).first()

        if not dataset:
            return None

        if dataset.indexing_technique == "economy":
            # use keyword table query
            index = KeywordTableIndex(dataset=dataset).query_index

            if not index:
                return None

            query_kwargs = {
                "mode": "default",
                "response_mode": response_mode,
                "query_keyword_extract_template": QUERY_KEYWORD_EXTRACT_TEMPLATE,
                "max_keywords_per_query": 5,
                # If num_chunks_per_query is too large,
                # it will slow down the synthesis process due to multiple iterations of refinement.
                "num_chunks_per_query": 2
            }
        else:
            index = VectorIndex(dataset=dataset).query_index

            if not index:
                return None

            query_kwargs = {
                "mode": "default",
                "response_mode": response_mode,
                # If top_k is too large,
                # it will slow down the synthesis process due to multiple iterations of refinement.
                "similarity_top_k": 2
            }

        # fulfill description when it is empty
        description = dataset.description
        if not description:
            description = 'useful for when you want to answer queries about the ' + dataset.name

        # callback_handler defaults to None; registering None as a handler
        # would leave the callback manager with an entry that is not a
        # handler object, so only include it when one was supplied.
        handlers = [DifyStdOutCallbackHandler()]
        if callback_handler is not None:
            handlers.insert(0, callback_handler)

        # NOTE: `return_direct` (whether to return LLM results directly or
        # process the output data with an Output Parser) is not set here.
        index_tool_config = IndexToolConfig(
            index=index,
            name=f"dataset-{dataset_id}",
            description=description,
            index_query_kwargs=query_kwargs,
            tool_kwargs={
                "callback_manager": CallbackManager(handlers)
            },
        )

        index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset_id)

        return EnhanceLlamaIndexTool.from_tool_config(
            tool_config=index_tool_config,
            callback_handler=index_callback_handler
        )
|
||||
43
api/core/tool/llama_index_tool.py
Normal file
43
api/core/tool/llama_index_tool.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from typing import Dict
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from llama_index.indices.base import BaseGPTIndex
|
||||
from llama_index.langchain_helpers.agents import IndexToolConfig
|
||||
from pydantic import Field
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler
|
||||
|
||||
|
||||
class EnhanceLlamaIndexTool(BaseTool):
    """Tool for querying a LlamaIndex.

    After each query the response object is passed to
    ``callback_handler.on_tool_end`` before being returned as a string.
    """

    # NOTE: name/description still needs to be set
    # Index that receives the queries.
    index: BaseGPTIndex
    # Extra keyword arguments forwarded to every index.query / index.aquery call.
    query_kwargs: Dict = Field(default_factory=dict)
    # Stored from the tool config; not consulted by _run/_arun in this class.
    return_sources: bool = False
    # Receives the raw query response via on_tool_end.
    callback_handler: IndexToolCallbackHandler

    @classmethod
    def from_tool_config(cls, tool_config: IndexToolConfig,
                         callback_handler: IndexToolCallbackHandler) -> "EnhanceLlamaIndexTool":
        """Create a tool from a tool config.

        Args:
            tool_config: Supplies the index, name, description, query kwargs
                and any extra tool kwargs.
            callback_handler: Handler notified with each query response.

        Returns:
            A new ``EnhanceLlamaIndexTool`` instance.
        """
        # Work on a shallow copy so extracting "return_sources" does not
        # mutate the caller's config (it may be reused to build more tools).
        tool_kwargs = dict(tool_config.tool_kwargs)
        return_sources = tool_kwargs.pop("return_sources", False)
        return cls(
            index=tool_config.index,
            callback_handler=callback_handler,
            name=tool_config.name,
            description=tool_config.description,
            return_sources=return_sources,
            query_kwargs=tool_config.index_query_kwargs,
            **tool_kwargs,
        )

    def _run(self, tool_input: str) -> str:
        """Query the index synchronously and return the response as text."""
        response = self.index.query(tool_input, **self.query_kwargs)
        # Hand the raw response object to the handler before stringifying it.
        self.callback_handler.on_tool_end(response)
        return str(response)

    async def _arun(self, tool_input: str) -> str:
        """Query the index asynchronously and return the response as text."""
        response = await self.index.aquery(tool_input, **self.query_kwargs)
        # Hand the raw response object to the handler before stringifying it.
        self.callback_handler.on_tool_end(response)
        return str(response)
|
||||
Reference in New Issue
Block a user