refactor: remove unused codes, move core/agent module into dataset retrieval feature (#2614)

This commit is contained in:
takatost
2024-02-28 23:32:47 +08:00
committed by GitHub
parent d44b05a9e5
commit dd961985f0
29 changed files with 41 additions and 2016 deletions

View File

@@ -4,7 +4,7 @@ from langchain.tools import BaseTool
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom
from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
from core.tools.tool.tool import Tool
@@ -15,12 +15,12 @@ class DatasetRetrieverTool(Tool):
@staticmethod
def get_dataset_tools(tenant_id: str,
dataset_ids: list[str],
retrieve_config: DatasetRetrieveConfigEntity,
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler
) -> list['DatasetRetrieverTool']:
dataset_ids: list[str],
retrieve_config: DatasetRetrieveConfigEntity,
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler
) -> list['DatasetRetrieverTool']:
"""
get dataset tool
"""
@@ -46,7 +46,7 @@ class DatasetRetrieverTool(Tool):
)
# restore retrieve strategy
retrieve_config.retrieve_strategy = original_retriever_mode
# convert langchain tools to Tools
tools = []
for langchain_tool in langchain_tools:
@@ -60,7 +60,7 @@ class DatasetRetrieverTool(Tool):
llm=langchain_tool.description),
runtime=DatasetRetrieverTool.Runtime()
)
tools.append(tool)
return tools
@@ -68,13 +68,13 @@ class DatasetRetrieverTool(Tool):
def get_runtime_parameters(self) -> list[ToolParameter]:
return [
ToolParameter(name='query',
label=I18nObject(en_US='', zh_Hans=''),
human_description=I18nObject(en_US='', zh_Hans=''),
type=ToolParameter.ToolParameterType.STRING,
form=ToolParameter.ToolParameterForm.LLM,
llm_description='Query for the dataset to be used to retrieve the dataset.',
required=True,
default=''),
label=I18nObject(en_US='', zh_Hans=''),
human_description=I18nObject(en_US='', zh_Hans=''),
type=ToolParameter.ToolParameterType.STRING,
form=ToolParameter.ToolParameterForm.LLM,
llm_description='Query for the dataset to be used to retrieve the dataset.',
required=True,
default=''),
]
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
@@ -84,7 +84,7 @@ class DatasetRetrieverTool(Tool):
query = tool_parameters.get('query', None)
if not query:
return self.create_text_message(text='please input query')
# invoke dataset retriever tool
result = self.langchain_tool._run(query=query)
@@ -94,4 +94,4 @@ class DatasetRetrieverTool(Tool):
"""
validate the credentials for dataset retriever tool
"""
pass
pass