FEAT: NEW WORKFLOW ENGINE (#3160)

Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: nite-knite <nkCoding@gmail.com>
Co-authored-by: jyong <718720800@qq.com>
This commit is contained in:
takatost
2024-04-08 18:51:46 +08:00
committed by GitHub
parent 2fb9850af5
commit 7753ba2d37
1161 changed files with 103836 additions and 10327 deletions

View File

@@ -9,7 +9,7 @@ import pandas as pd
from flask import Flask, current_app
from werkzeug.datastructures import FileStorage
from core.generator.llm_generator import LLMGenerator
from core.llm_generator.llm_generator import LLMGenerator
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector

View File

View File

View File

@@ -0,0 +1,59 @@
import time
from collections.abc import Mapping
from typing import Any, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult
class FakeLLM(SimpleChatModel):
"""Fake ChatModel for testing purposes."""
streaming: bool = False
"""Whether to stream the results or not."""
response: str
@property
def _llm_type(self) -> str:
return "fake-chat-model"
def _call(
self,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""First try to lookup in queries, else return 'foo' or 'bar'."""
return self.response
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {"response": self.response}
def get_num_tokens(self, text: str) -> int:
return 0
def _generate(
self,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
if self.streaming:
for token in output_str:
if run_manager:
run_manager.on_llm_new_token(token)
time.sleep(0.01)
message = AIMessage(content=output_str)
generation = ChatGeneration(message=message)
llm_output = {"token_usage": {
'prompt_tokens': 0,
'completion_tokens': 0,
'total_tokens': 0,
}}
return ChatResult(generations=[generation], llm_output=llm_output)

View File

@@ -0,0 +1,46 @@
from typing import Any, Optional
from langchain import LLMChain as LCLLMChain
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import Generation, LLMResult
from langchain.schema.language_model import BaseLanguageModel
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance
from core.rag.retrieval.agent.fake_llm import FakeLLM
class LLMChain(LCLLMChain):
model_config: ModelConfigWithCredentialsEntity
"""The language model instance to use."""
llm: BaseLanguageModel = FakeLLM(response="")
parameters: dict[str, Any] = {}
def generate(
self,
input_list: list[dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
messages = prompts[0].to_messages()
prompt_messages = lc_messages_to_prompt_messages(messages)
model_instance = ModelInstance(
provider_model_bundle=self.model_config.provider_model_bundle,
model=self.model_config.model,
)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
stream=False,
stop=stop,
model_parameters=self.parameters
)
generations = [
[Generation(text=result.message.content)]
]
return LLMResult(generations=generations)

View File

@@ -0,0 +1,179 @@
from collections.abc import Sequence
from typing import Any, Optional, Union
from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, AIMessage, SystemMessage
from langchain.tools import BaseTool
from pydantic import root_validator
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.rag.retrieval.agent.fake_llm import FakeLLM
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
"""
An Multi Dataset Retrieve Agent driven by Router.
"""
model_config: ModelConfigWithCredentialsEntity
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@root_validator
def validate_llm(cls, values: dict) -> dict:
return values
def should_use_agent(self, query: str):
"""
return should use agent
:param query:
:return:
"""
return True
def plan(
self,
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
if len(self.tools) == 0:
return AgentFinish(return_values={"output": ''}, log='')
elif len(self.tools) == 1:
tool = next(iter(self.tools))
rst = tool.run(tool_input={'query': kwargs['input']})
# output = ''
# rst_json = json.loads(rst)
# for item in rst_json:
# output += f'{item["content"]}\n'
return AgentFinish(return_values={"output": rst}, log=rst)
if intermediate_steps:
_, observation = intermediate_steps[-1]
return AgentFinish(return_values={"output": observation}, log=observation)
try:
agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
if isinstance(agent_decision, AgentAction):
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
else:
agent_decision.return_values['output'] = ''
return agent_decision
except Exception as e:
raise e
def real_plan(
self,
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
prompt_messages = lc_messages_to_prompt_messages(messages)
model_instance = ModelInstance(
provider_model_bundle=self.model_config.provider_model_bundle,
model=self.model_config.model,
)
tools = []
for function in self.functions:
tool = PromptMessageTool(
**function
)
tools.append(tool)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=tools,
stream=False,
model_parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
ai_message = AIMessage(
content=result.message.content or "",
additional_kwargs={
'function_call': {
'id': result.message.tool_calls[0].id,
**result.message.tool_calls[0].function.dict()
} if result.message.tool_calls else None
}
)
agent_decision = _parse_ai_message(ai_message)
return agent_decision
async def aplan(
self,
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
raise NotImplementedError()
@classmethod
def from_llm_and_tools(
cls,
model_config: ModelConfigWithCredentialsEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
**kwargs: Any,
) -> BaseSingleActionAgent:
prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls(
model_config=model_config,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)

View File

@@ -0,0 +1,29 @@
import json
import re
from typing import Union
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser as LCStructuredChatOutputParser
from langchain.agents.structured_chat.output_parser import logger
from langchain.schema import AgentAction, AgentFinish, OutputParserException
class StructuredChatOutputParser(LCStructuredChatOutputParser):
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
try:
action_match = re.search(r"```(\w*)\n?({.*?)```", text, re.DOTALL)
if action_match is not None:
response = json.loads(action_match.group(2).strip(), strict=False)
if isinstance(response, list):
# gpt turbo frequently ignores the directive to emit a single action
logger.warning("Got multiple action responses: %s", response)
response = response[0]
if response["action"] == "Final Answer":
return AgentFinish({"output": response["action_input"]}, text)
else:
return AgentAction(
response["action"], response.get("action_input", {}), text
)
else:
return AgentFinish({"output": text}, text)
except Exception as e:
raise OutputParserException(f"Could not parse LLM output: {text}")

View File

@@ -0,0 +1,259 @@
import re
from collections.abc import Sequence
from typing import Any, Optional, Union, cast
from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.rag.retrieval.agent.llm_chain import LLMChain
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}}}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
"action": "Final Answer",
"action_input": "Final response to human"
}}}}
```"""
class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
dataset_tools: Sequence[BaseTool]
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def should_use_agent(self, query: str):
"""
return should use agent
Using the ReACT mode to determine whether an agent is needed is costly,
so it's better to just use an Agent for reasoning, which is cheaper.
:param query:
:return:
"""
return True
def plan(
self,
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
if len(self.dataset_tools) == 0:
return AgentFinish(return_values={"output": ''}, log='')
elif len(self.dataset_tools) == 1:
tool = next(iter(self.dataset_tools))
rst = tool.run(tool_input={'query': kwargs['input']})
return AgentFinish(return_values={"output": rst}, log=rst)
if intermediate_steps:
_, observation = intermediate_steps[-1]
return AgentFinish(return_values={"output": observation}, log=observation)
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
raise e
try:
agent_decision = self.output_parser.parse(full_output)
if isinstance(agent_decision, AgentAction):
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
elif isinstance(tool_inputs, str):
agent_decision.tool_input = kwargs['input']
else:
agent_decision.return_values['output'] = ''
return agent_decision
except OutputParserException:
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
"I don't know how to respond to that."}, "")
@classmethod
def create_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
tool_strings = []
for tool in tools:
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
formatted_tools = "\n".join(tool_strings)
unique_tool_names = set(tool.name for tool in tools)
tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
_memory_prompts = memory_prompts or []
messages = [
SystemMessagePromptTemplate.from_template(template),
*_memory_prompts,
HumanMessagePromptTemplate.from_template(human_message_template),
]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
@classmethod
def create_completion_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[list[str]] = None,
) -> PromptTemplate:
"""Create prompt in the style of the zero shot agent.
Args:
tools: List of tools the agent will have access to, used to format the
prompt.
prefix: String to put before the list of tools.
input_variables: List of input variables the final prompt will expect.
Returns:
A PromptTemplate with the template assembled from the pieces here.
"""
suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""
tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
tool_names = ", ".join([tool.name for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
return PromptTemplate(template=template, input_variables=input_variables)
def _construct_scratchpad(
self, intermediate_steps: list[tuple[AgentAction, str]]
) -> str:
agent_scratchpad = ""
for action, observation in intermediate_steps:
agent_scratchpad += action.log
agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
if not isinstance(agent_scratchpad, str):
raise ValueError("agent_scratchpad should be of type string.")
if agent_scratchpad:
llm_chain = cast(LLMChain, self.llm_chain)
if llm_chain.model_config.mode == "chat":
return (
f"This was your previous work "
f"(but I haven't seen any of it! I only see what "
f"you return as final answer):\n{agent_scratchpad}"
)
else:
return agent_scratchpad
else:
return agent_scratchpad
@classmethod
def from_llm_and_tools(
cls,
model_config: ModelConfigWithCredentialsEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
if model_config.mode == "chat":
prompt = cls.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
else:
prompt = cls.create_completion_prompt(
tools,
prefix=prefix,
format_instructions=format_instructions,
input_variables=input_variables
)
llm_chain = LLMChain(
model_config=model_config,
prompt=prompt,
callback_manager=callback_manager,
parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser
return cls(
llm_chain=llm_chain,
allowed_tools=tool_names,
output_parser=_output_parser,
dataset_tools=tools,
**kwargs,
)

View File

@@ -0,0 +1,117 @@
import logging
from typing import Optional, Union
from langchain.agents import AgentExecutor as LCAgentExecutor
from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent
from langchain.callbacks.manager import Callbacks
from langchain.tools import BaseTool
from pydantic import BaseModel, Extra
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.message_entities import prompt_messages_to_lc_messages
from core.helper import moderation
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.errors.invoke import InvokeError
from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
class AgentConfiguration(BaseModel):
strategy: PlanningStrategy
model_config: ModelConfigWithCredentialsEntity
tools: list[BaseTool]
summary_model_config: Optional[ModelConfigWithCredentialsEntity] = None
memory: Optional[TokenBufferMemory] = None
callbacks: Callbacks = None
max_iterations: int = 6
max_execution_time: Optional[float] = None
early_stopping_method: str = "generate"
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
class AgentExecuteResult(BaseModel):
strategy: PlanningStrategy
output: Optional[str]
configuration: AgentConfiguration
class AgentExecutor:
def __init__(self, configuration: AgentConfiguration):
self.configuration = configuration
self.agent = self._init_agent()
def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
if self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools
if isinstance(t, DatasetRetrieverTool)
or isinstance(t, DatasetMultiRetrieverTool)]
agent = MultiDatasetRouterAgent.from_llm_and_tools(
model_config=self.configuration.model_config,
tools=self.configuration.tools,
extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
if self.configuration.memory else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
self.configuration.tools = [t for t in self.configuration.tools
if isinstance(t, DatasetRetrieverTool)
or isinstance(t, DatasetMultiRetrieverTool)]
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
model_config=self.configuration.model_config,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
verbose=True
)
else:
raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
return agent
def should_use_agent(self, query: str) -> bool:
return self.agent.should_use_agent(query)
def run(self, query: str) -> AgentExecuteResult:
moderation_result = moderation.check_moderation(
self.configuration.model_config,
query
)
if moderation_result:
return AgentExecuteResult(
output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
strategy=self.configuration.strategy,
configuration=self.configuration
)
agent_executor = LCAgentExecutor.from_agent_and_tools(
agent=self.agent,
tools=self.configuration.tools,
max_iterations=self.configuration.max_iterations,
max_execution_time=self.configuration.max_execution_time,
early_stopping_method=self.configuration.early_stopping_method,
callbacks=self.configuration.callbacks
)
try:
output = agent_executor.run(input=query)
except InvokeError as ex:
raise ex
except Exception as ex:
logging.exception("agent_executor run failed")
output = None
return AgentExecuteResult(
output=output,
strategy=self.configuration.strategy,
configuration=self.configuration
)

View File

@@ -0,0 +1,182 @@
from typing import Optional, cast
from langchain.tools import BaseTool
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from models.dataset import Dataset
class DatasetRetrieval:
def retrieve(self, tenant_id: str,
model_config: ModelConfigWithCredentialsEntity,
config: DatasetEntity,
query: str,
invoke_from: InvokeFrom,
show_retrieve_source: bool,
hit_callback: DatasetIndexToolCallbackHandler,
memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
"""
Retrieve dataset.
:param tenant_id: tenant id
:param model_config: model config
:param config: dataset config
:param query: query
:param invoke_from: invoke from
:param show_retrieve_source: show retrieve source
:param hit_callback: hit callback
:param memory: memory
:return:
"""
dataset_ids = config.dataset_ids
retrieve_config = config.retrieve_config
# check model is support tool calling
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# get model schema
model_schema = model_type_instance.get_model_schema(
model=model_config.model,
credentials=model_config.credentials
)
if not model_schema:
return None
planning_strategy = PlanningStrategy.REACT_ROUTER
features = model_schema.features
if features:
if ModelFeature.TOOL_CALL in features \
or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.ROUTER
dataset_retriever_tools = self.to_dataset_retriever_tool(
tenant_id=tenant_id,
dataset_ids=dataset_ids,
retrieve_config=retrieve_config,
return_resource=show_retrieve_source,
invoke_from=invoke_from,
hit_callback=hit_callback
)
if len(dataset_retriever_tools) == 0:
return None
agent_configuration = AgentConfiguration(
strategy=planning_strategy,
model_config=model_config,
tools=dataset_retriever_tools,
memory=memory,
max_iterations=10,
max_execution_time=400.0,
early_stopping_method="generate"
)
agent_executor = AgentExecutor(agent_configuration)
should_use_agent = agent_executor.should_use_agent(query)
if not should_use_agent:
return None
result = agent_executor.run(query)
return result.output
def to_dataset_retriever_tool(self, tenant_id: str,
dataset_ids: list[str],
retrieve_config: DatasetRetrieveConfigEntity,
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler) \
-> Optional[list[BaseTool]]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param tenant_id: tenant id
:param dataset_ids: dataset ids
:param retrieve_config: retrieve config
:param return_resource: return resource
:param invoke_from: invoke from
:param hit_callback: hit callback
"""
tools = []
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
# pass if dataset is not available
if not dataset:
continue
# pass if dataset is not available
if (dataset and dataset.available_document_count == 0
and dataset.available_document_count == 0):
continue
available_datasets.append(dataset)
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
# get retrieval model config
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
}
for dataset in available_datasets:
retrieval_model_config = dataset.retrieval_model \
if dataset.retrieval_model else default_retrieval_model
# get top k
top_k = retrieval_model_config['top_k']
# get score threshold
score_threshold = None
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
top_k=top_k,
score_threshold=score_threshold,
hit_callbacks=[hit_callback],
return_resource=return_resource,
retriever_from=invoke_from.to_source()
)
tools.append(tool)
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=[dataset.id for dataset in available_datasets],
tenant_id=tenant_id,
top_k=retrieve_config.top_k or 2,
score_threshold=retrieve_config.score_threshold,
hit_callbacks=[hit_callback],
return_resource=return_resource,
retriever_from=invoke_from.to_source(),
reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'),
reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name')
)
tools.append(tool)
return tools