feat: universal chat in explore (#649)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
This commit is contained in:
John Wang
2023-07-27 13:08:57 +08:00
committed by GitHub
parent 94b54b7ca9
commit 4fdb37771a
64 changed files with 3186 additions and 858 deletions

View File

@@ -1,10 +1,12 @@
import json
import logging
import time
from typing import Any, Dict, List, Union, Optional
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask
@@ -20,6 +22,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self.conversation_message_task = conversation_message_task
self._agent_loops = []
self._current_loop = None
self._message_agent_thought = None
self.current_chain = None
@property
@@ -29,6 +32,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
def clear_agent_loops(self) -> None:
self._agent_loops = []
self._current_loop = None
self._message_agent_thought = None
@property
def always_verbose(self) -> bool:
@@ -61,9 +65,21 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
# kwargs={}
if self._current_loop and self._current_loop.status == 'llm_started':
self._current_loop.status = 'llm_end'
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
self._current_loop.completion = response.generations[0][0].text
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
if response.llm_output:
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
completion_generation = response.generations[0][0]
if isinstance(completion_generation, ChatGeneration):
completion_message = completion_generation.message
if 'function_call' in completion_message.additional_kwargs:
self._current_loop.completion \
= json.dumps({'function_call': completion_message.additional_kwargs['function_call']})
else:
self._current_loop.completion = response.generations[0][0].text
else:
self._current_loop.completion = completion_generation.text
if response.llm_output:
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
@@ -71,6 +87,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
logging.error(error)
self._agent_loops = []
self._current_loop = None
self._message_agent_thought = None
def on_tool_start(
self,
@@ -89,15 +106,29 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
) -> Any:
"""Run on agent action."""
tool = action.tool
tool_input = action.tool_input
action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
thought = action.log[:action_name_position].strip() if action.log else ''
tool_input = json.dumps({"query": action.tool_input}
if isinstance(action.tool_input, str) else action.tool_input)
completion = None
if isinstance(action, openai_functions_agent.base._FunctionsAgentAction) \
or isinstance(action, openai_functions_multi_agent.base._FunctionsAgentAction):
thought = action.log.strip()
completion = json.dumps({'function_call': action.message_log[0].additional_kwargs['function_call']})
else:
action_name_position = action.log.index("Action:") if action.log else -1
thought = action.log[:action_name_position].strip() if action.log else ''
if self._current_loop and self._current_loop.status == 'llm_end':
self._current_loop.status = 'agent_action'
self._current_loop.thought = thought
self._current_loop.tool_name = tool
self._current_loop.tool_input = tool_input
if completion is not None:
self._current_loop.completion = completion
self._message_agent_thought = self.conversation_message_task.on_agent_start(
self.current_chain,
self._current_loop
)
def on_tool_end(
self,
@@ -120,10 +151,13 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.completed_at = time.perf_counter()
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop)
self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_name, self._current_loop
)
self._agent_loops.append(self._current_loop)
self._current_loop = None
self._message_agent_thought = None
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
@@ -132,6 +166,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
logging.error(error)
self._agent_loops = []
self._current_loop = None
self._message_agent_thought = None
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
@@ -141,10 +176,18 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.completed = True
self._current_loop.completed_at = time.perf_counter()
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self._current_loop.thought = '[DONE]'
self._message_agent_thought = self.conversation_message_task.on_agent_start(
self.current_chain,
self._current_loop
)
self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop)
self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_name, self._current_loop
)
self._agent_loops.append(self._current_loop)
self._current_loop = None
self._message_agent_thought = None
elif not self._current_loop and self._agent_loops:
self._agent_loops[-1].status = 'agent_finish'

View File

@@ -1,3 +1,4 @@
import json
import logging
from typing import Any, Dict, List, Union, Optional
@@ -43,9 +44,11 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
input_str: str,
**kwargs: Any,
) -> None:
tool_name = serialized.get('name')
dataset_id = tool_name[len("dataset-"):]
self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=input_str))
# tool_name = serialized.get('name')
input_dict = json.loads(input_str.replace("'", "\""))
dataset_id = input_dict.get('dataset_id')
query = input_dict.get('query')
self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))
def on_tool_end(
self,

View File

@@ -10,9 +10,9 @@ class AgentLoop(BaseModel):
tool_output: str = None
prompt: str = None
prompt_tokens: int = None
prompt_tokens: int = 0
completion: str = None
completion_tokens: int = None
completion_tokens: int = 0
latency: float = None

View File

@@ -1,20 +1,18 @@
import logging
import time
from typing import Any, Dict, List, Union, Optional
from typing import Any, Dict, List, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage, BaseMessage
from langchain.schema import LLMResult, BaseMessage, BaseLanguageModel
from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
class LLMCallbackHandler(BaseCallbackHandler):
raise_error: bool = True
def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
def __init__(self, llm: BaseLanguageModel,
conversation_message_task: ConversationMessageTask):
self.llm = llm
self.llm_message = LLMMessage()

View File

@@ -20,15 +20,13 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
self._current_chain_result = None
self._current_chain_message = None
self.conversation_message_task = conversation_message_task
self.agent_loop_gather_callback_handler = AgentLoopGatherCallbackHandler(
llm_constant.agent_model_name,
conversation_message_task
)
self.agent_callback = None
def clear_chain_results(self) -> None:
self._current_chain_result = None
self._current_chain_message = None
self.agent_loop_gather_callback_handler.current_chain = None
if self.agent_callback:
self.agent_callback.current_chain = None
@property
def always_verbose(self) -> bool:
@@ -58,7 +56,8 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
started_at=time.perf_counter()
)
self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
if self.agent_callback:
self.agent_callback.current_chain = self._current_chain_message
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""