mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-25 02:33:00 +08:00
chore: remove Langchain tools import (#3407)
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
@@ -10,6 +8,7 @@ from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rerank.rerank import RerankRunner
|
||||
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
@@ -29,20 +28,15 @@ class DatasetMultiRetrieverToolInput(BaseModel):
|
||||
query: str = Field(..., description="dataset multi retriever and rerank")
|
||||
|
||||
|
||||
class DatasetMultiRetrieverTool(BaseTool):
|
||||
class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
"""Tool for querying multi dataset."""
|
||||
name: str = "dataset_"
|
||||
args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
|
||||
description: str = "dataset multi retriever and rerank. "
|
||||
tenant_id: str
|
||||
dataset_ids: list[str]
|
||||
top_k: int = 2
|
||||
score_threshold: Optional[float] = None
|
||||
reranking_provider_name: str
|
||||
reranking_model_name: str
|
||||
return_resource: bool
|
||||
retriever_from: str
|
||||
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
|
||||
@@ -149,9 +143,6 @@ class DatasetMultiRetrieverTool(BaseTool):
|
||||
|
||||
return str("\n".join(document_context_list))
|
||||
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: list,
|
||||
hit_callbacks: list[DatasetIndexToolCallbackHandler]):
|
||||
with flask_app.app_context():
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
# ABC is taken from the standard library rather than from an unrelated
# third-party package, so this module has no incidental extra dependency.
from abc import ABC, abstractmethod
from typing import Any, Optional

from pydantic import BaseModel

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler


class DatasetRetrieverBaseTool(BaseModel, ABC):
    """Tool for querying a Dataset.

    Abstract pydantic model holding the configuration fields shared by the
    dataset retriever tools. Concrete subclasses must implement ``_run``.
    """

    # Identifier and description of the tool.
    name: str = "dataset"
    description: str = "use this to retrieve a dataset. "
    # Tenant the tool is configured for.
    tenant_id: str
    # Retrieval settings; ``score_threshold`` may be left unset (None).
    top_k: int = 2
    score_threshold: Optional[float] = None
    # Callback handlers attached to this tool.
    hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
    return_resource: bool
    retriever_from: str

    class Config:
        # Needed because ``hit_callbacks`` holds plain (non-pydantic) objects.
        arbitrary_types_allowed = True

    @abstractmethod
    def _run(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Use the tool.

        Subclasses must override this method with the actual retrieval
        logic; the base class provides no implementation.
        """
|
||||
@@ -1,10 +1,8 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
@@ -24,19 +22,13 @@ class DatasetRetrieverToolInput(BaseModel):
|
||||
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
|
||||
|
||||
|
||||
class DatasetRetrieverTool(BaseTool):
|
||||
class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
"""Tool for querying a Dataset."""
|
||||
name: str = "dataset"
|
||||
args_schema: type[BaseModel] = DatasetRetrieverToolInput
|
||||
description: str = "use this to retrieve a dataset. "
|
||||
|
||||
tenant_id: str
|
||||
dataset_id: str
|
||||
top_k: int = 2
|
||||
score_threshold: Optional[float] = None
|
||||
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
|
||||
return_resource: bool
|
||||
retriever_from: str
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dataset(cls, dataset: Dataset, **kwargs):
|
||||
@@ -153,7 +145,4 @@ class DatasetRetrieverTool(BaseTool):
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
|
||||
return str("\n".join(document_context_list))
|
||||
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
raise NotImplementedError()
|
||||
return str("\n".join(document_context_list))
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
@@ -14,11 +12,12 @@ from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
|
||||
class DatasetRetrieverTool(Tool):
|
||||
langchain_tool: BaseTool
|
||||
retrival_tool: DatasetRetrieverBaseTool
|
||||
|
||||
@staticmethod
|
||||
def get_dataset_tools(tenant_id: str,
|
||||
@@ -43,7 +42,7 @@ class DatasetRetrieverTool(Tool):
|
||||
# Agent only support SINGLE mode
|
||||
original_retriever_mode = retrieve_config.retrieve_strategy
|
||||
retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
|
||||
langchain_tools = feature.to_dataset_retriever_tool(
|
||||
retrival_tools = feature.to_dataset_retriever_tool(
|
||||
tenant_id=tenant_id,
|
||||
dataset_ids=dataset_ids,
|
||||
retrieve_config=retrieve_config,
|
||||
@@ -54,17 +53,17 @@ class DatasetRetrieverTool(Tool):
|
||||
# restore retrieve strategy
|
||||
retrieve_config.retrieve_strategy = original_retriever_mode
|
||||
|
||||
# convert langchain tools to Tools
|
||||
# convert retrieval tools to Tools
|
||||
tools = []
|
||||
for langchain_tool in langchain_tools:
|
||||
for retrival_tool in retrival_tools:
|
||||
tool = DatasetRetrieverTool(
|
||||
langchain_tool=langchain_tool,
|
||||
identity=ToolIdentity(provider='', author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')),
|
||||
retrival_tool=retrival_tool,
|
||||
identity=ToolIdentity(provider='', author='', name=retrival_tool.name, label=I18nObject(en_US='', zh_Hans='')),
|
||||
parameters=[],
|
||||
is_team_authorization=True,
|
||||
description=ToolDescription(
|
||||
human=I18nObject(en_US='', zh_Hans=''),
|
||||
llm=langchain_tool.description),
|
||||
llm=retrival_tool.description),
|
||||
runtime=DatasetRetrieverTool.Runtime()
|
||||
)
|
||||
|
||||
@@ -96,7 +95,7 @@ class DatasetRetrieverTool(Tool):
|
||||
return self.create_text_message(text='please input query')
|
||||
|
||||
# invoke dataset retriever tool
|
||||
result = self.langchain_tool._run(query=query)
|
||||
result = self.retrival_tool._run(query=query)
|
||||
|
||||
return self.create_text_message(text=result)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user