chore(api/core): apply ruff reformatting (#7624)

Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions
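The diff below is mechanical output: it comes from running ruff's formatter over `api/core`, not from hand edits. As a rough sketch of the patterns involved (an illustrative snippet assuming the repo's formatter settings, not a hunk copied from this diff): single-quoted strings become double-quoted, backslash continuations and hand-wrapped argument lists are re-wrapped against the line-length limit, and trailing commas are added where calls stay exploded.

```python
# Illustrative before/after sketch of the kind of rewrite `ruff format` applies.
# Both halves are valid Python; the variable names mirror ones in the diff below.

# Before: single quotes and a hand-exploded literal.
llm_usage = {
    'usage': None
}
final_answer = ''

# After: double quotes, collapsed onto one line since it fits the length limit.
llm_usage = {"usage": None}
final_answer = ""
```

Because the formatter is deterministic, re-running it over the tree should be a no-op once this commit lands.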

View File

@@ -25,17 +25,19 @@ from models.model import Message
 class CotAgentRunner(BaseAgentRunner, ABC):
     _is_first_iteration = True
-    _ignore_observation_providers = ['wenxin']
+    _ignore_observation_providers = ["wenxin"]
     _historic_prompt_messages: list[PromptMessage] = None
     _agent_scratchpad: list[AgentScratchpadUnit] = None
     _instruction: str = None
     _query: str = None
     _prompt_messages_tools: list[PromptMessage] = None

-    def run(self, message: Message,
-            query: str,
-            inputs: dict[str, str],
-            ) -> Union[Generator, LLMResult]:
+    def run(
+        self,
+        message: Message,
+        query: str,
+        inputs: dict[str, str],
+    ) -> Union[Generator, LLMResult]:
         """
         Run Cot agent application
         """
@@ -46,17 +48,16 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         trace_manager = app_generate_entity.trace_manager

         # check model mode
-        if 'Observation' not in app_generate_entity.model_conf.stop:
+        if "Observation" not in app_generate_entity.model_conf.stop:
             if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
-                app_generate_entity.model_conf.stop.append('Observation')
+                app_generate_entity.model_conf.stop.append("Observation")

         app_config = self.app_config

         # init instruction
         inputs = inputs or {}
         instruction = app_config.prompt_template.simple_prompt_template
-        self._instruction = self._fill_in_inputs_from_external_data_tools(
-            instruction, inputs)
+        self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)

         iteration_step = 1
         max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
@@ -65,16 +66,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         tool_instances, self._prompt_messages_tools = self._init_prompt_tools()

         function_call_state = True
-        llm_usage = {
-            'usage': None
-        }
-        final_answer = ''
+        llm_usage = {"usage": None}
+        final_answer = ""

         def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
-            if not final_llm_usage_dict['usage']:
-                final_llm_usage_dict['usage'] = usage
+            if not final_llm_usage_dict["usage"]:
+                final_llm_usage_dict["usage"] = usage
             else:
-                llm_usage = final_llm_usage_dict['usage']
+                llm_usage = final_llm_usage_dict["usage"]
                 llm_usage.prompt_tokens += usage.prompt_tokens
                 llm_usage.completion_tokens += usage.completion_tokens
                 llm_usage.prompt_price += usage.prompt_price
@@ -94,17 +93,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             message_file_ids = []
             agent_thought = self.create_agent_thought(
-                message_id=message.id,
-                message='',
-                tool_name='',
-                tool_input='',
-                messages_ids=message_file_ids
+                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
             )

             if iteration_step > 1:
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )

             # recalc llm max tokens
             prompt_messages = self._organize_prompt_messages()
@@ -125,21 +120,20 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 raise ValueError("failed to invoke llm")

             usage_dict = {}
-            react_chunks = CotAgentOutputParser.handle_react_stream_output(
-                chunks, usage_dict)
+            react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
             scratchpad = AgentScratchpadUnit(
-                agent_response='',
-                thought='',
-                action_str='',
-                observation='',
+                agent_response="",
+                thought="",
+                action_str="",
+                observation="",
                 action=None,
             )

             # publish agent thought if it's first iteration
             if iteration_step == 1:
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )

             for chunk in react_chunks:
                 if isinstance(chunk, AgentScratchpadUnit.Action):
@@ -154,61 +148,51 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     yield LLMResultChunk(
                         model=self.model_config.model,
                         prompt_messages=prompt_messages,
-                        system_fingerprint='',
-                        delta=LLMResultChunkDelta(
-                            index=0,
-                            message=AssistantPromptMessage(
-                                content=chunk
-                            ),
-                            usage=None
-                        )
+                        system_fingerprint="",
+                        delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
                     )

-            scratchpad.thought = scratchpad.thought.strip(
-            ) or 'I am thinking about how to help you'
+            scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
             self._agent_scratchpad.append(scratchpad)

             # get llm usage
-            if 'usage' in usage_dict:
-                increase_usage(llm_usage, usage_dict['usage'])
+            if "usage" in usage_dict:
+                increase_usage(llm_usage, usage_dict["usage"])
             else:
-                usage_dict['usage'] = LLMUsage.empty_usage()
+                usage_dict["usage"] = LLMUsage.empty_usage()

             self.save_agent_thought(
                 agent_thought=agent_thought,
-                tool_name=scratchpad.action.action_name if scratchpad.action else '',
-                tool_input={
-                    scratchpad.action.action_name: scratchpad.action.action_input
-                } if scratchpad.action else {},
+                tool_name=scratchpad.action.action_name if scratchpad.action else "",
+                tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
                 tool_invoke_meta={},
                 thought=scratchpad.thought,
-                observation='',
+                observation="",
                 answer=scratchpad.agent_response,
                 messages_ids=[],
-                llm_usage=usage_dict['usage']
+                llm_usage=usage_dict["usage"],
             )

             if not scratchpad.is_final():
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )

             if not scratchpad.action:
                 # failed to extract action, return final answer directly
-                final_answer = ''
+                final_answer = ""
             else:
                 if scratchpad.action.action_name.lower() == "final answer":
                     # action is final answer, return final answer directly
                     try:
                         if isinstance(scratchpad.action.action_input, dict):
-                            final_answer = json.dumps(
-                                scratchpad.action.action_input)
+                            final_answer = json.dumps(scratchpad.action.action_input)
                         elif isinstance(scratchpad.action.action_input, str):
                             final_answer = scratchpad.action.action_input
                         else:
-                            final_answer = f'{scratchpad.action.action_input}'
+                            final_answer = f"{scratchpad.action.action_input}"
                     except json.JSONDecodeError:
-                        final_answer = f'{scratchpad.action.action_input}'
+                        final_answer = f"{scratchpad.action.action_input}"
                 else:
                     function_call_state = True
                     # action is tool call, invoke tool
@@ -224,21 +208,18 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     self.save_agent_thought(
                         agent_thought=agent_thought,
                         tool_name=scratchpad.action.action_name,
-                        tool_input={
-                            scratchpad.action.action_name: scratchpad.action.action_input},
+                        tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
                         thought=scratchpad.thought,
-                        observation={
-                            scratchpad.action.action_name: tool_invoke_response},
-                        tool_invoke_meta={
-                            scratchpad.action.action_name: tool_invoke_meta.to_dict()},
+                        observation={scratchpad.action.action_name: tool_invoke_response},
+                        tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
                         answer=scratchpad.agent_response,
                         messages_ids=message_file_ids,
-                        llm_usage=usage_dict['usage']
+                        llm_usage=usage_dict["usage"],
                     )
-                    self.queue_manager.publish(QueueAgentThoughtEvent(
-                        agent_thought_id=agent_thought.id
-                    ), PublishFrom.APPLICATION_MANAGER)
+                    self.queue_manager.publish(
+                        QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                    )

             # update prompt tool message
             for prompt_tool in self._prompt_messages_tools:
@@ -250,44 +231,45 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             model=model_instance.model,
             prompt_messages=prompt_messages,
             delta=LLMResultChunkDelta(
-                index=0,
-                message=AssistantPromptMessage(
-                    content=final_answer
-                ),
-                usage=llm_usage['usage']
+                index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
             ),
-            system_fingerprint=''
+            system_fingerprint="",
         )

         # save agent thought
         self.save_agent_thought(
             agent_thought=agent_thought,
-            tool_name='',
+            tool_name="",
             tool_input={},
             tool_invoke_meta={},
             thought=final_answer,
             observation={},
             answer=final_answer,
-            messages_ids=[]
+            messages_ids=[],
         )

         self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
-        self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
-            model=model_instance.model,
-            prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(
-                content=final_answer
+        self.queue_manager.publish(
+            QueueMessageEndEvent(
+                llm_result=LLMResult(
+                    model=model_instance.model,
+                    prompt_messages=prompt_messages,
+                    message=AssistantPromptMessage(content=final_answer),
+                    usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
+                    system_fingerprint="",
+                )
             ),
-            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
-            system_fingerprint=''
-        )), PublishFrom.APPLICATION_MANAGER)
+            PublishFrom.APPLICATION_MANAGER,
+        )

-    def _handle_invoke_action(self, action: AgentScratchpadUnit.Action,
-                              tool_instances: dict[str, Tool],
-                              message_file_ids: list[str],
-                              trace_manager: Optional[TraceQueueManager] = None
-                              ) -> tuple[str, ToolInvokeMeta]:
+    def _handle_invoke_action(
+        self,
+        action: AgentScratchpadUnit.Action,
+        tool_instances: dict[str, Tool],
+        message_file_ids: list[str],
+        trace_manager: Optional[TraceQueueManager] = None,
+    ) -> tuple[str, ToolInvokeMeta]:
         """
         handle invoke action
         :param action: action
@@ -326,13 +308,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         # publish files
         for message_file_id, save_as in message_files:
             if save_as:
-                self.variables_pool.set_file(
-                    tool_name=tool_call_name, value=message_file_id, name=save_as)
+                self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)

             # publish message file
-            self.queue_manager.publish(QueueMessageFileEvent(
-                message_file_id=message_file_id
-            ), PublishFrom.APPLICATION_MANAGER)
+            self.queue_manager.publish(
+                QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
+            )
             # add message file ids
             message_file_ids.append(message_file_id)
@@ -342,10 +323,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         """
         convert dict to action
         """
-        return AgentScratchpadUnit.Action(
-            action_name=action['action'],
-            action_input=action['action_input']
-        )
+        return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])

     def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
         """
@@ -353,7 +331,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         """
         for key, value in inputs.items():
             try:
-                instruction = instruction.replace(f'{{{{{key}}}}}', str(value))
+                instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
             except Exception as e:
                 continue
@@ -370,14 +348,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
     @abstractmethod
     def _organize_prompt_messages(self) -> list[PromptMessage]:
         """
-        organize prompt messages
+        organize prompt messages
         """

     def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
         """
-        format assistant message
+        format assistant message
         """
-        message = ''
+        message = ""
         for scratchpad in agent_scratchpad:
             if scratchpad.is_final():
                 message += f"Final Answer: {scratchpad.agent_response}"
@@ -390,9 +368,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         return message

-    def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _organize_historic_prompt_messages(
+        self, current_session_messages: list[PromptMessage] = None
+    ) -> list[PromptMessage]:
         """
-        organize historic prompt messages
+        organize historic prompt messages
         """
         result: list[PromptMessage] = []
         scratchpads: list[AgentScratchpadUnit] = []
@@ -403,8 +383,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 if not current_scratchpad:
                     current_scratchpad = AgentScratchpadUnit(
                         agent_response=message.content,
-                        thought=message.content or 'I am thinking about how to help you',
-                        action_str='',
+                        thought=message.content or "I am thinking about how to help you",
+                        action_str="",
                         action=None,
                         observation=None,
                     )
@@ -413,12 +393,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 try:
                     current_scratchpad.action = AgentScratchpadUnit.Action(
                         action_name=message.tool_calls[0].function.name,
-                        action_input=json.loads(
-                            message.tool_calls[0].function.arguments)
-                    )
-                    current_scratchpad.action_str = json.dumps(
-                        current_scratchpad.action.to_dict()
+                        action_input=json.loads(message.tool_calls[0].function.arguments),
                     )
+                    current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
                 except:
                     pass
             elif isinstance(message, ToolPromptMessage):
@@ -426,23 +403,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 current_scratchpad.observation = message.content
             elif isinstance(message, UserPromptMessage):
                 if scratchpads:
-                    result.append(AssistantPromptMessage(
-                        content=self._format_assistant_message(scratchpads)
-                    ))
+                    result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
                     scratchpads = []
                     current_scratchpad = None

                 result.append(message)

         if scratchpads:
-            result.append(AssistantPromptMessage(
-                content=self._format_assistant_message(scratchpads)
-            ))
+            result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))

         historic_prompts = AgentHistoryPromptTransform(
             model_config=self.model_config,
             prompt_messages=current_session_messages or [],
             history_messages=result,
-            memory=self.memory
+            memory=self.memory,
         ).get_prompt()
         return historic_prompts

View File

@@ -19,14 +19,15 @@ class CotChatAgentRunner(CotAgentRunner):
         prompt_entity = self.app_config.agent.prompt
         first_prompt = prompt_entity.first_prompt

-        system_prompt = first_prompt \
-            .replace("{{instruction}}", self._instruction) \
-            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
-            .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
+        system_prompt = (
+            first_prompt.replace("{{instruction}}", self._instruction)
+            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
+            .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
+        )

         return SystemPromptMessage(content=system_prompt)

-    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
         """
         Organize user query
         """
@@ -43,7 +44,7 @@ class CotChatAgentRunner(CotAgentRunner):
     def _organize_prompt_messages(self) -> list[PromptMessage]:
         """
-        Organize
+        Organize
         """
         # organize system prompt
         system_message = self._organize_system_prompt()
@@ -53,7 +54,7 @@ class CotChatAgentRunner(CotAgentRunner):
         if not agent_scratchpad:
             assistant_messages = []
         else:
-            assistant_message = AssistantPromptMessage(content='')
+            assistant_message = AssistantPromptMessage(content="")
             for unit in agent_scratchpad:
                 if unit.is_final():
                     assistant_message.content += f"Final Answer: {unit.agent_response}"
@@ -71,18 +72,15 @@ class CotChatAgentRunner(CotAgentRunner):
         if assistant_messages:
             # organize historic prompt messages
-            historic_messages = self._organize_historic_prompt_messages([
-                system_message,
-                *query_messages,
-                *assistant_messages,
-                UserPromptMessage(content='continue')
-            ])
+            historic_messages = self._organize_historic_prompt_messages(
+                [system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
+            )
             messages = [
                 system_message,
                 *historic_messages,
                 *query_messages,
                 *assistant_messages,
-                UserPromptMessage(content='continue')
+                UserPromptMessage(content="continue"),
             ]
         else:
             # organize historic prompt messages

View File

@@ -13,10 +13,12 @@ class CotCompletionAgentRunner(CotAgentRunner):
         prompt_entity = self.app_config.agent.prompt
         first_prompt = prompt_entity.first_prompt

-        system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \
-            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
-            .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
+        system_prompt = (
+            first_prompt.replace("{{instruction}}", self._instruction)
+            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
+            .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
+        )

         return system_prompt

     def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
@@ -46,7 +48,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
         # organize current assistant messages
         agent_scratchpad = self._agent_scratchpad
-        assistant_prompt = ''
+        assistant_prompt = ""
         for unit in agent_scratchpad:
             if unit.is_final():
                 assistant_prompt += f"Final Answer: {unit.agent_response}"
@@ -61,9 +63,10 @@ class CotCompletionAgentRunner(CotAgentRunner):
         query_prompt = f"Question: {self._query}"

         # join all messages
-        prompt = system_prompt \
-            .replace("{{historic_messages}}", historic_prompt) \
-            .replace("{{agent_scratchpad}}", assistant_prompt) \
+        prompt = (
+            system_prompt.replace("{{historic_messages}}", historic_prompt)
+            .replace("{{agent_scratchpad}}", assistant_prompt)
             .replace("{{query}}", query_prompt)
+        )

-        return [UserPromptMessage(content=prompt)]
+        return [UserPromptMessage(content=prompt)]

View File

@@ -8,6 +8,7 @@ class AgentToolEntity(BaseModel):
     """
     Agent Tool Entity.
     """
+
     provider_type: Literal["builtin", "api", "workflow"]
     provider_id: str
     tool_name: str
@@ -18,6 +19,7 @@ class AgentPromptEntity(BaseModel):
     """
     Agent Prompt Entity.
     """
+
     first_prompt: str
     next_iteration: str
@@ -31,6 +33,7 @@ class AgentScratchpadUnit(BaseModel):
         """
         Action Entity.
         """
+
         action_name: str
         action_input: Union[dict, str]
@@ -39,8 +42,8 @@ class AgentScratchpadUnit(BaseModel):
             Convert to dictionary.
             """
             return {
-                'action': self.action_name,
-                'action_input': self.action_input,
+                "action": self.action_name,
+                "action_input": self.action_input,
             }

     agent_response: Optional[str] = None
@@ -54,10 +57,10 @@ class AgentScratchpadUnit(BaseModel):
         Check if the scratchpad unit is final.
         """
         return self.action is None or (
-            'final' in self.action.action_name.lower() and
-            'answer' in self.action.action_name.lower()
+            "final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower()
         )


 class AgentEntity(BaseModel):
     """
     Agent Entity.
@@ -67,8 +70,9 @@ class AgentEntity(BaseModel):
         """
         Agent Strategy.
         """
-        CHAIN_OF_THOUGHT = 'chain-of-thought'
-        FUNCTION_CALLING = 'function-calling'
+
+        CHAIN_OF_THOUGHT = "chain-of-thought"
+        FUNCTION_CALLING = "function-calling"

     provider: str
     model: str

View File

@@ -24,11 +24,9 @@ from models.model import Message
 logger = logging.getLogger(__name__)

-class FunctionCallAgentRunner(BaseAgentRunner):
-    def run(self,
-            message: Message, query: str, **kwargs: Any
-    ) -> Generator[LLMResultChunk, None, None]:
+class FunctionCallAgentRunner(BaseAgentRunner):
+    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
         """
         Run FunctionCall agent application
         """
@@ -45,19 +43,17 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         # continue to run until there is not any tool call
         function_call_state = True
-        llm_usage = {
-            'usage': None
-        }
-        final_answer = ''
+        llm_usage = {"usage": None}
+        final_answer = ""

         # get tracing instance
         trace_manager = app_generate_entity.trace_manager

         def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
-            if not final_llm_usage_dict['usage']:
-                final_llm_usage_dict['usage'] = usage
+            if not final_llm_usage_dict["usage"]:
+                final_llm_usage_dict["usage"] = usage
             else:
-                llm_usage = final_llm_usage_dict['usage']
+                llm_usage = final_llm_usage_dict["usage"]
                 llm_usage.prompt_tokens += usage.prompt_tokens
                 llm_usage.completion_tokens += usage.completion_tokens
                 llm_usage.prompt_price += usage.prompt_price
@@ -75,11 +71,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             message_file_ids = []
             agent_thought = self.create_agent_thought(
-                message_id=message.id,
-                message='',
-                tool_name='',
-                tool_input='',
-                messages_ids=message_file_ids
+                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
             )

             # recalc llm max tokens
@@ -99,11 +91,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             tool_calls: list[tuple[str, str, dict[str, Any]]] = []

             # save full response
-            response = ''
+            response = ""

             # save tool call names and inputs
-            tool_call_names = ''
-            tool_call_inputs = ''
+            tool_call_names = ""
+            tool_call_inputs = ""

             current_llm_usage = None
@@ -111,24 +103,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 is_first_chunk = True
                 for chunk in chunks:
                     if is_first_chunk:
-                        self.queue_manager.publish(QueueAgentThoughtEvent(
-                            agent_thought_id=agent_thought.id
-                        ), PublishFrom.APPLICATION_MANAGER)
+                        self.queue_manager.publish(
+                            QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                        )
                         is_first_chunk = False

                     # check if there is any tool call
                     if self.check_tool_calls(chunk):
                         function_call_state = True
                         tool_calls.extend(self.extract_tool_calls(chunk))
-                        tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
+                        tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                         try:
-                            tool_call_inputs = json.dumps({
-                                tool_call[1]: tool_call[2] for tool_call in tool_calls
-                            }, ensure_ascii=False)
+                            tool_call_inputs = json.dumps(
+                                {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
+                            )
                         except json.JSONDecodeError as e:
                             # ensure ascii to avoid encoding error
-                            tool_call_inputs = json.dumps({
-                                tool_call[1]: tool_call[2] for tool_call in tool_calls
-                            })
+                            tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                     if chunk.delta.message and chunk.delta.message.content:
                         if isinstance(chunk.delta.message.content, list):
@@ -148,16 +138,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 if self.check_blocking_tool_calls(result):
                     function_call_state = True
                     tool_calls.extend(self.extract_blocking_tool_calls(result))
-                    tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
+                    tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                     try:
-                        tool_call_inputs = json.dumps({
-                            tool_call[1]: tool_call[2] for tool_call in tool_calls
-                        }, ensure_ascii=False)
+                        tool_call_inputs = json.dumps(
+                            {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
+                        )
                     except json.JSONDecodeError as e:
                         # ensure ascii to avoid encoding error
-                        tool_call_inputs = json.dumps({
-                            tool_call[1]: tool_call[2] for tool_call in tool_calls
-                        })
+                        tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                 if result.usage:
                     increase_usage(llm_usage, result.usage)
@@ -171,12 +159,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                     response += result.message.content

                 if not result.message.content:
-                    result.message.content = ''
+                    result.message.content = ""

-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )

                 yield LLMResultChunk(
                     model=model_instance.model,
                     prompt_messages=result.prompt_messages,
@@ -185,32 +173,29 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         index=0,
                         message=result.message,
                         usage=result.usage,
-                    )
+                    ),
                 )

-            assistant_message = AssistantPromptMessage(
-                content='',
-                tool_calls=[]
-            )
+            assistant_message = AssistantPromptMessage(content="", tool_calls=[])
             if tool_calls:
-                assistant_message.tool_calls=[
+                assistant_message.tool_calls = [
                     AssistantPromptMessage.ToolCall(
                         id=tool_call[0],
-                        type='function',
+                        type="function",
                         function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                            name=tool_call[1],
-                            arguments=json.dumps(tool_call[2], ensure_ascii=False)
-                        )
-                    ) for tool_call in tool_calls
+                            name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
+                        ),
+                    )
+                    for tool_call in tool_calls
                 ]
             else:
                 assistant_message.content = response

             self._current_thoughts.append(assistant_message)

             # save thought
             self.save_agent_thought(
-                agent_thought=agent_thought,
+                agent_thought=agent_thought,
                 tool_name=tool_call_names,
                 tool_input=tool_call_inputs,
                 thought=response,
@@ -218,13 +203,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 observation=None,
                 answer=response,
                 messages_ids=[],
-                llm_usage=current_llm_usage
+                llm_usage=current_llm_usage,
             )
-            self.queue_manager.publish(QueueAgentThoughtEvent(
-                agent_thought_id=agent_thought.id
-            ), PublishFrom.APPLICATION_MANAGER)
-
-            final_answer += response + '\n'
+            self.queue_manager.publish(
+                QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+            )
+
+            final_answer += response + "\n"

             # call tools
             tool_responses = []
@@ -235,7 +220,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         "tool_call_id": tool_call_id,
                         "tool_call_name": tool_call_name,
                         "tool_response": f"there is not a tool named {tool_call_name}",
-                        "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict()
+                        "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
                     }
                 else:
                     # invoke tool
@@ -255,50 +240,49 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                             self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)

                         # publish message file
-                        self.queue_manager.publish(QueueMessageFileEvent(
-                            message_file_id=message_file_id
-                        ), PublishFrom.APPLICATION_MANAGER)
+                        self.queue_manager.publish(
+                            QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
+                        )
                         # add message file ids
                         message_file_ids.append(message_file_id)

                     tool_response = {
                         "tool_call_id": tool_call_id,
                         "tool_call_name": tool_call_name,
                         "tool_response": tool_invoke_response,
-                        "meta": tool_invoke_meta.to_dict()
+                        "meta": tool_invoke_meta.to_dict(),
                     }

                 tool_responses.append(tool_response)
-                if tool_response['tool_response'] is not None:
+                if tool_response["tool_response"] is not None:
                     self._current_thoughts.append(
                         ToolPromptMessage(
-                            content=tool_response['tool_response'],
+                            content=tool_response["tool_response"],
                             tool_call_id=tool_call_id,
                             name=tool_call_name,
                         )
-                    )
+                    )

             if len(tool_responses) > 0:
                 # save agent thought
                 self.save_agent_thought(
-                    agent_thought=agent_thought,
+                    agent_thought=agent_thought,
                     tool_name=None,
                     tool_input=None,
-                    thought=None,
+                    thought=None,
                     tool_invoke_meta={
-                        tool_response['tool_call_name']: tool_response['meta']
-                        for tool_response in tool_responses
+                        tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
                     },
                     observation={
-                        tool_response['tool_call_name']: tool_response['tool_response']
+                        tool_response["tool_call_name"]: tool_response["tool_response"]
                         for tool_response in tool_responses
                     },
                     answer=None,
-                    messages_ids=message_file_ids
+                    messages_ids=message_file_ids,
                 )
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )

             # update prompt tool
             for prompt_tool in prompt_messages_tools:
@@ -308,15 +292,18 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
-        self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
-            model=model_instance.model,
-            prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(
-                content=final_answer
+        self.queue_manager.publish(
+            QueueMessageEndEvent(
+                llm_result=LLMResult(
+                    model=model_instance.model,
+                    prompt_messages=prompt_messages,
+                    message=AssistantPromptMessage(content=final_answer),
+                    usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
+                    system_fingerprint="",
+                )
             ),
-            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
-            system_fingerprint=''
-        )), PublishFrom.APPLICATION_MANAGER)
+            PublishFrom.APPLICATION_MANAGER,
+        )

     def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
         """
@@ -325,7 +312,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         if llm_result_chunk.delta.message.tool_calls:
             return True
         return False

     def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
         """
         Check if there is any blocking tool call in llm result
@@ -334,7 +321,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             return True
         return False

-    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
+    def extract_tool_calls(
+        self, llm_result_chunk: LLMResultChunk
+    ) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
         """
         Extract tool calls from llm result chunk
@@ -344,17 +333,19 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         tool_calls = []
         for prompt_message in llm_result_chunk.delta.message.tool_calls:
             args = {}
-            if prompt_message.function.arguments != '':
+            if prompt_message.function.arguments != "":
                 args = json.loads(prompt_message.function.arguments)

-            tool_calls.append((
-                prompt_message.id,
-                prompt_message.function.name,
-                args,
-            ))
+            tool_calls.append(
+                (
+                    prompt_message.id,
+                    prompt_message.function.name,
+                    args,
+                )
+            )

         return tool_calls

     def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
         """
         Extract blocking tool calls from llm result
@@ -365,18 +356,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         tool_calls = []
         for prompt_message in llm_result.message.tool_calls:
             args = {}
-            if prompt_message.function.arguments != '':
+            if prompt_message.function.arguments != "":
                 args = json.loads(prompt_message.function.arguments)

-            tool_calls.append((
-                prompt_message.id,
-                prompt_message.function.name,
-                args,
-            ))
+            tool_calls.append(
+                (
+                    prompt_message.id,
+                    prompt_message.function.name,
+                    args,
+                )
+            )

         return tool_calls

-    def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _init_system_message(
+        self, prompt_template: str, prompt_messages: list[PromptMessage] = None
+    ) -> list[PromptMessage]:
         """
         Initialize system message
         """
@@ -384,13 +379,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             return [
                 SystemPromptMessage(content=prompt_template),
             ]

         if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
             prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))

         return prompt_messages

-    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
         """
         Organize user query
         """
@@ -404,7 +399,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             prompt_messages.append(UserPromptMessage(content=query))

         return prompt_messages

     def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
         """
         As for now, gpt supports both fc and vision at the first iteration.
@@ -415,17 +410,21 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         for prompt_message in prompt_messages:
             if isinstance(prompt_message, UserPromptMessage):
                 if isinstance(prompt_message.content, list):
-                    prompt_message.content = '\n'.join([
-                        content.data if content.type == PromptMessageContentType.TEXT else
-                        '[image]' if content.type == PromptMessageContentType.IMAGE else
-                        '[file]'
-                        for content in prompt_message.content
-                    ])
+                    prompt_message.content = "\n".join(
+                        [
+                            content.data
+                            if content.type == PromptMessageContentType.TEXT
+                            else "[image]"
+                            if content.type == PromptMessageContentType.IMAGE
+                            else "[file]"
+                            for content in prompt_message.content
+                        ]
+                    )

         return prompt_messages

     def _organize_prompt_messages(self):
-        prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
+        prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
         self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
         query_prompt_messages = self._organize_user_query(self.query, [])
@@ -433,14 +432,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             model_config=self.model_config,
             prompt_messages=[*query_prompt_messages, *self._current_thoughts],
             history_messages=self.history_prompt_messages,
-            memory=self.memory
+            memory=self.memory,
         ).get_prompt()

-        prompt_messages = [
-            *self.history_prompt_messages,
-            *query_prompt_messages,
-            *self._current_thoughts
-        ]
+        prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
         if len(self._current_thoughts) != 0:
             # clear messages after the first iteration
             prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)

View File

@@ -9,8 +9,9 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
 class CotAgentOutputParser:
     @classmethod
-    def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict) -> \
-            Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
+    def handle_react_stream_output(
+        cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
+    ) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
         def parse_action(json_str):
             try:
                 action = json.loads(json_str)
@@ -22,7 +23,7 @@ class CotAgentOutputParser:
                     action = action[0]

                 for key, value in action.items():
-                    if 'input' in key.lower():
+                    if "input" in key.lower():
                         action_input = value
                     else:
                         action_name = value
@@ -33,37 +34,37 @@ class CotAgentOutputParser:
                         action_input=action_input,
                     )
                 else:
-                    return json_str or ''
+                    return json_str or ""
             except:
-                return json_str or ''
+                return json_str or ""

         def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
-            code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
+            code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL)
             if not code_blocks:
                 return
             for block in code_blocks:
-                json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
+                json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
                 yield parse_action(json_text)

-        code_block_cache = ''
+        code_block_cache = ""
         code_block_delimiter_count = 0
         in_code_block = False
-        json_cache = ''
+        json_cache = ""
         json_quote_count = 0
         in_json = False
         got_json = False

-        action_cache = ''
-        action_str = 'action:'
+        action_cache = ""
+        action_str = "action:"
         action_idx = 0

-        thought_cache = ''
-        thought_str = 'thought:'
+        thought_cache = ""
+        thought_str = "thought:"
         thought_idx = 0

         for response in llm_response:
             if response.delta.usage:
-                usage_dict['usage'] = response.delta.usage
+                usage_dict["usage"] = response.delta.usage
             response = response.delta.message.content
             if not isinstance(response, str):
                 continue
@@ -72,24 +73,24 @@ class CotAgentOutputParser:
             index = 0
             while index < len(response):
                 steps = 1
-                delta = response[index:index+steps]
-                last_character = response[index-1] if index > 0 else ''
+                delta = response[index : index + steps]
+                last_character = response[index - 1] if index > 0 else ""

-                if delta == '`':
+                if delta == "`":
                     code_block_cache += delta
                     code_block_delimiter_count += 1
                 else:
                     if not in_code_block:
                         if code_block_delimiter_count > 0:
                             yield code_block_cache
-                            code_block_cache = ''
+                            code_block_cache = ""
                     else:
                         code_block_cache += delta
                     code_block_delimiter_count = 0

                 if not in_code_block and not in_json:
                     if delta.lower() == action_str[action_idx] and action_idx == 0:
-                        if last_character not in ['\n', ' ', '']:
+                        if last_character not in ["\n", " ", ""]:
                             index += steps
                             yield delta
                             continue
@@ -97,7 +98,7 @@ class CotAgentOutputParser:
                         action_cache += delta
                         action_idx += 1
                         if action_idx == len(action_str):
-                            action_cache = ''
+                            action_cache = ""
                             action_idx = 0
                         index += steps
                         continue
@@ -105,18 +106,18 @@ class CotAgentOutputParser:
                         action_cache += delta
                         action_idx += 1
                         if action_idx == len(action_str):
-                            action_cache = ''
+                            action_cache = ""
                             action_idx = 0
                         index += steps
                         continue
                     else:
                         if action_cache:
                             yield action_cache
-                            action_cache = ''
+                            action_cache = ""
                             action_idx = 0

                     if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
-                        if last_character not in ['\n', ' ', '']:
+                        if last_character not in ["\n", " ", ""]:
                             index += steps
                             yield delta
                             continue
@@ -124,7 +125,7 @@ class CotAgentOutputParser:
                         thought_cache += delta
                         thought_idx += 1
                         if thought_idx == len(thought_str):
-                            thought_cache = ''
+                            thought_cache = ""
                             thought_idx = 0
                         index += steps
                         continue
@@ -132,31 +133,31 @@ class CotAgentOutputParser:
                         thought_cache += delta
                         thought_idx += 1
                         if thought_idx == len(thought_str):
-                            thought_cache = ''
+                            thought_cache = ""
                             thought_idx = 0
                         index += steps
                         continue
                     else:
                         if thought_cache:
                             yield thought_cache
-                            thought_cache = ''
+                            thought_cache = ""
                             thought_idx = 0

                 if code_block_delimiter_count == 3:
                     if in_code_block:
                         yield from extra_json_from_code_block(code_block_cache)
-                        code_block_cache = ''
+                        code_block_cache = ""

                     in_code_block = not in_code_block
                     code_block_delimiter_count = 0

                 if not in_code_block:
                     # handle single json
-                    if delta == '{':
+                    if delta == "{":
                         json_quote_count += 1
                         in_json = True
                         json_cache += delta
-                    elif delta == '}':
+                    elif delta == "}":
                         json_cache += delta
                         if json_quote_count > 0:
                             json_quote_count -= 1
@@ -172,12 +173,12 @@ class CotAgentOutputParser:
                 if got_json:
                     got_json = False
                     yield parse_action(json_cache)
-                    json_cache = ''
+                    json_cache = ""
                     json_quote_count = 0
                     in_json = False

                 if not in_code_block and not in_json:
-                    yield delta.replace('`', '')
+                    yield delta.replace("`", "")
                 index += steps
@@ -186,4 +187,3 @@ class CotAgentOutputParser:
         if json_cache:
             yield parse_action(json_cache)
-

View File

@@ -91,14 +91,14 @@ Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use
 ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""

 REACT_PROMPT_TEMPLATES = {
-    'english': {
-        'chat': {
-            'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
-            'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES
+    "english": {
+        "chat": {
+            "prompt": ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
+            "agent_scratchpad": ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES,
         },
-        'completion': {
-            'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
-            'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES
-        }
+        "completion": {
+            "prompt": ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
+            "agent_scratchpad": ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES,
+        },
     }
 }