feat: upgrade langchain (#430)

Co-authored-by: jyong <718720800@qq.com>
This commit is contained in:
John Wang
2023-06-25 16:49:14 +08:00
committed by GitHub
parent 1dee5de9b4
commit 3241e4015b
91 changed files with 2703 additions and 3153 deletions

View File

@@ -0,0 +1,87 @@
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.tools import BaseTool
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset
class DatasetTool(BaseTool):
"""Tool for querying a Dataset."""
dataset: Dataset
k: int = 2
def _run(self, tool_input: str) -> str:
if self.dataset.indexing_technique == "economy":
# use keyword table query
kw_table_index = KeywordTableIndex(
dataset=self.dataset,
config=KeywordTableConfig(
max_keywords_per_chunk=5
)
)
documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
else:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
model_name='text-embedding-ada-002'
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
**model_credentials
))
vector_index = VectorIndex(
dataset=self.dataset,
config=current_app.config,
embeddings=embeddings
)
documents = vector_index.search(
tool_input,
search_type='similarity',
search_kwargs={
'k': self.k
}
)
hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
hit_callback.on_tool_end(documents)
return str("\n".join([document.page_content for document in documents]))
async def _arun(self, tool_input: str) -> str:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
model_name='text-embedding-ada-002'
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
**model_credentials
))
vector_index = VectorIndex(
dataset=self.dataset,
config=current_app.config,
embeddings=embeddings
)
documents = await vector_index.asearch(
tool_input,
search_type='similarity',
search_kwargs={
'k': 10
}
)
hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
hit_callback.on_tool_end(documents)
return str("\n".join([document.page_content for document in documents]))

View File

@@ -1,73 +0,0 @@
from typing import Optional
from langchain.callbacks import CallbackManager
from llama_index.langchain_helpers.agents import IndexToolConfig
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex
from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE
from core.tool.llama_index_tool import EnhanceLlamaIndexTool
from models.dataset import Dataset
class DatasetToolBuilder:
@classmethod
def build_dataset_tool(cls, dataset: Dataset,
response_mode: str = "no_synthesizer",
callback_handler: Optional[DatasetToolCallbackHandler] = None):
if dataset.indexing_technique == "economy":
# use keyword table query
index = KeywordTableIndex(dataset=dataset).query_index
if not index:
return None
query_kwargs = {
"mode": "default",
"response_mode": response_mode,
"query_keyword_extract_template": QUERY_KEYWORD_EXTRACT_TEMPLATE,
"max_keywords_per_query": 5,
# If num_chunks_per_query is too large,
# it will slow down the synthesis process due to multiple iterations of refinement.
"num_chunks_per_query": 2
}
else:
index = VectorIndex(dataset=dataset).query_index
if not index:
return None
query_kwargs = {
"mode": "default",
"response_mode": response_mode,
# If top_k is too large,
# it will slow down the synthesis process due to multiple iterations of refinement.
"similarity_top_k": 2
}
# fulfill description when it is empty
description = dataset.description
if not description:
description = 'useful for when you want to answer queries about the ' + dataset.name
index_tool_config = IndexToolConfig(
index=index,
name=f"dataset-{dataset.id}",
description=description,
index_query_kwargs=query_kwargs,
tool_kwargs={
"callback_manager": CallbackManager([callback_handler, DifyStdOutCallbackHandler()])
},
# tool_kwargs={"return_direct": True},
# return_direct: Whether to return LLM results directly or process the output data with an Output Parser
)
index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset.id)
return EnhanceLlamaIndexTool.from_tool_config(
tool_config=index_tool_config,
callback_handler=index_callback_handler
)

View File

@@ -1,43 +0,0 @@
from typing import Dict
from langchain.tools import BaseTool
from llama_index.indices.base import BaseGPTIndex
from llama_index.langchain_helpers.agents import IndexToolConfig
from pydantic import Field
from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler
class EnhanceLlamaIndexTool(BaseTool):
"""Tool for querying a LlamaIndex."""
# NOTE: name/description still needs to be set
index: BaseGPTIndex
query_kwargs: Dict = Field(default_factory=dict)
return_sources: bool = False
callback_handler: IndexToolCallbackHandler
@classmethod
def from_tool_config(cls, tool_config: IndexToolConfig,
callback_handler: IndexToolCallbackHandler) -> "EnhanceLlamaIndexTool":
"""Create a tool from a tool config."""
return_sources = tool_config.tool_kwargs.pop("return_sources", False)
return cls(
index=tool_config.index,
callback_handler=callback_handler,
name=tool_config.name,
description=tool_config.description,
return_sources=return_sources,
query_kwargs=tool_config.index_query_kwargs,
**tool_config.tool_kwargs,
)
def _run(self, tool_input: str) -> str:
response = self.index.query(tool_input, **self.query_kwargs)
self.callback_handler.on_tool_end(response)
return str(response)
async def _arun(self, tool_input: str) -> str:
response = await self.index.aquery(tool_input, **self.query_kwargs)
self.callback_handler.on_tool_end(response)
return str(response)