feat: optimize error record in agent (#869)

This commit is contained in:
takatost
2023-08-16 15:55:42 +08:00
committed by GitHub
parent f207e180df
commit 2dfb3e95f6
8 changed files with 35 additions and 14 deletions

View File

@@ -59,7 +59,11 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
_, observation = intermediate_steps[-1]
return AgentFinish(return_values={"output": observation}, log=observation)
return super().plan(intermediate_steps, callbacks, **kwargs)
try:
return super().plan(intermediate_steps, callbacks, **kwargs)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
async def aplan(
self,

View File

@@ -50,9 +50,13 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=None
)
try:
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=None
)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
function_call = predicted_message.additional_kwargs.get("function_call", {})

View File

@@ -50,9 +50,13 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=None
)
try:
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=None
)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
function_call = predicted_message.additional_kwargs.get("function_call", {})

View File

@@ -94,7 +94,12 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
return AgentFinish(return_values={"output": rst}, log=rst)
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
try:
return self.output_parser.parse(full_output)

View File

@@ -89,8 +89,8 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
Action specifying what tool to use.
"""
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
messages = []
if prompts:
messages = prompts[0].to_messages()
@@ -99,7 +99,11 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
if rest_tokens < 0:
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
try:
return self.output_parser.parse(full_output)