mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-15 22:06:52 +08:00
feat: support multi datasets router chain mode (#231)
This commit is contained in:
@@ -1,18 +1,18 @@
|
||||
from typing import Optional, List
|
||||
|
||||
from langchain.callbacks import SharedCallbackManager
|
||||
from langchain.callbacks import SharedCallbackManager, CallbackManager
|
||||
from langchain.chains import SequentialChain
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
|
||||
from core.agent.agent_builder import AgentBuilder
|
||||
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
||||
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
|
||||
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
|
||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
||||
from core.chain.chain_builder import ChainBuilder
|
||||
from core.constant import llm_constant
|
||||
from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
from core.tool.dataset_tool_builder import DatasetToolBuilder
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class MainChainBuilder:
|
||||
@@ -31,8 +31,7 @@ class MainChainBuilder:
|
||||
tenant_id=tenant_id,
|
||||
agent_mode=agent_mode,
|
||||
memory=memory,
|
||||
dataset_tool_callback_handler=DatasetToolCallbackHandler(conversation_message_task),
|
||||
agent_loop_gather_callback_handler=chain_callback_handler.agent_loop_gather_callback_handler
|
||||
conversation_message_task=conversation_message_task
|
||||
)
|
||||
chains += tool_chains
|
||||
|
||||
@@ -59,15 +58,15 @@ class MainChainBuilder:
|
||||
|
||||
@classmethod
|
||||
def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
|
||||
dataset_tool_callback_handler: DatasetToolCallbackHandler,
|
||||
agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
|
||||
conversation_message_task: ConversationMessageTask):
|
||||
# agent mode
|
||||
chains = []
|
||||
if agent_mode and agent_mode.get('enabled'):
|
||||
tools = agent_mode.get('tools', [])
|
||||
|
||||
pre_fixed_chains = []
|
||||
agent_tools = []
|
||||
# agent_tools = []
|
||||
datasets = []
|
||||
for tool in tools:
|
||||
tool_type = list(tool.keys())[0]
|
||||
tool_config = list(tool.values())[0]
|
||||
@@ -76,34 +75,27 @@ class MainChainBuilder:
|
||||
if chain:
|
||||
pre_fixed_chains.append(chain)
|
||||
elif tool_type == "dataset":
|
||||
dataset_tool = DatasetToolBuilder.build_dataset_tool(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=tool_config.get("id"),
|
||||
response_mode='no_synthesizer', # "compact"
|
||||
callback_handler=dataset_tool_callback_handler
|
||||
)
|
||||
# get dataset from dataset id
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == tool_config.get("id")
|
||||
).first()
|
||||
|
||||
if dataset_tool:
|
||||
agent_tools.append(dataset_tool)
|
||||
if dataset:
|
||||
datasets.append(dataset)
|
||||
|
||||
# add pre-fixed chains
|
||||
chains += pre_fixed_chains
|
||||
|
||||
if len(agent_tools) == 1:
|
||||
if len(datasets) > 0:
|
||||
# tool to chain
|
||||
tool_chain = ChainBuilder.to_tool_chain(tool=agent_tools[0], output_key='tool_output')
|
||||
chains.append(tool_chain)
|
||||
elif len(agent_tools) > 1:
|
||||
# build agent config
|
||||
agent_chain = AgentBuilder.to_agent_chain(
|
||||
multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
|
||||
tenant_id=tenant_id,
|
||||
tools=agent_tools,
|
||||
memory=memory,
|
||||
dataset_tool_callback_handler=dataset_tool_callback_handler,
|
||||
agent_loop_gather_callback_handler=agent_loop_gather_callback_handler
|
||||
datasets=datasets,
|
||||
conversation_message_task=conversation_message_task,
|
||||
callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
|
||||
)
|
||||
|
||||
chains.append(agent_chain)
|
||||
chains.append(multi_dataset_router_chain)
|
||||
|
||||
final_output_key = cls.get_chains_output_key(chains)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user