Mirror of http://112.124.100.131/huang.ze/ebiz-dify-ai.git (synced 2025-12-11 03:46:52 +08:00)

chore(api/core): apply ruff reformatting (#7624)
@@ -25,17 +25,19 @@ from models.model import Message


 class CotAgentRunner(BaseAgentRunner, ABC):
     _is_first_iteration = True
-    _ignore_observation_providers = ['wenxin']
+    _ignore_observation_providers = ["wenxin"]
     _historic_prompt_messages: list[PromptMessage] = None
     _agent_scratchpad: list[AgentScratchpadUnit] = None
     _instruction: str = None
     _query: str = None
     _prompt_messages_tools: list[PromptMessage] = None

-    def run(self, message: Message,
-            query: str,
-            inputs: dict[str, str],
-            ) -> Union[Generator, LLMResult]:
+    def run(
+        self,
+        message: Message,
+        query: str,
+        inputs: dict[str, str],
+    ) -> Union[Generator, LLMResult]:
         """
         Run Cot agent application
         """
@@ -46,17 +48,16 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         trace_manager = app_generate_entity.trace_manager

         # check model mode
-        if 'Observation' not in app_generate_entity.model_conf.stop:
+        if "Observation" not in app_generate_entity.model_conf.stop:
             if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
-                app_generate_entity.model_conf.stop.append('Observation')
+                app_generate_entity.model_conf.stop.append("Observation")

         app_config = self.app_config

         # init instruction
         inputs = inputs or {}
         instruction = app_config.prompt_template.simple_prompt_template
-        self._instruction = self._fill_in_inputs_from_external_data_tools(
-            instruction, inputs)
+        self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)

         iteration_step = 1
         max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
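Note: the `min(...) + 1` cap above bounds the ReAct loop no matter how large the configured max_iteration is; a quick sanity check in plain Python (illustrative values assumed):

    # at most 5 tool-using iterations plus one closing step
    assert min(10, 5) + 1 == 6  # configured 10, effectively 6 steps
    assert min(3, 5) + 1 == 4   # configured 3, effectively 4 steps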
@@ -65,16 +66,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         tool_instances, self._prompt_messages_tools = self._init_prompt_tools()

         function_call_state = True
-        llm_usage = {
-            'usage': None
-        }
-        final_answer = ''
+        llm_usage = {"usage": None}
+        final_answer = ""

         def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
-            if not final_llm_usage_dict['usage']:
-                final_llm_usage_dict['usage'] = usage
+            if not final_llm_usage_dict["usage"]:
+                final_llm_usage_dict["usage"] = usage
             else:
-                llm_usage = final_llm_usage_dict['usage']
+                llm_usage = final_llm_usage_dict["usage"]
                 llm_usage.prompt_tokens += usage.prompt_tokens
                 llm_usage.completion_tokens += usage.completion_tokens
                 llm_usage.prompt_price += usage.prompt_price
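Note: `increase_usage` mutates a one-key dict rather than reassigning a variable, so the nested function can accumulate usage without `nonlocal`. A minimal, self-contained sketch of the same pattern (the `Usage` dataclass is a hypothetical stand-in for `LLMUsage`):

    from dataclasses import dataclass

    @dataclass
    class Usage:  # hypothetical stand-in for LLMUsage
        prompt_tokens: int = 0
        completion_tokens: int = 0

    llm_usage = {"usage": None}

    def increase_usage(final_llm_usage_dict: dict, usage: Usage) -> None:
        # first chunk: adopt the usage object; later chunks: add into it
        if not final_llm_usage_dict["usage"]:
            final_llm_usage_dict["usage"] = usage
        else:
            total = final_llm_usage_dict["usage"]
            total.prompt_tokens += usage.prompt_tokens
            total.completion_tokens += usage.completion_tokens

    increase_usage(llm_usage, Usage(10, 3))
    increase_usage(llm_usage, Usage(7, 2))
    assert llm_usage["usage"].prompt_tokens == 17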
@@ -94,17 +93,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             message_file_ids = []

             agent_thought = self.create_agent_thought(
-                message_id=message.id,
-                message='',
-                tool_name='',
-                tool_input='',
-                messages_ids=message_file_ids
+                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
             )

             if iteration_step > 1:
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )

             # recalc llm max tokens
             prompt_messages = self._organize_prompt_messages()
@@ -125,21 +120,20 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 raise ValueError("failed to invoke llm")

             usage_dict = {}
-            react_chunks = CotAgentOutputParser.handle_react_stream_output(
-                chunks, usage_dict)
+            react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
             scratchpad = AgentScratchpadUnit(
-                agent_response='',
-                thought='',
-                action_str='',
-                observation='',
+                agent_response="",
+                thought="",
+                action_str="",
+                observation="",
                 action=None,
             )

             # publish agent thought if it's first iteration
             if iteration_step == 1:
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )

             for chunk in react_chunks:
                 if isinstance(chunk, AgentScratchpadUnit.Action):
@@ -154,61 +148,51 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     yield LLMResultChunk(
                         model=self.model_config.model,
                         prompt_messages=prompt_messages,
-                        system_fingerprint='',
-                        delta=LLMResultChunkDelta(
-                            index=0,
-                            message=AssistantPromptMessage(
-                                content=chunk
-                            ),
-                            usage=None
-                        )
+                        system_fingerprint="",
+                        delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
                     )

-            scratchpad.thought = scratchpad.thought.strip(
-            ) or 'I am thinking about how to help you'
+            scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
             self._agent_scratchpad.append(scratchpad)

             # get llm usage
-            if 'usage' in usage_dict:
-                increase_usage(llm_usage, usage_dict['usage'])
+            if "usage" in usage_dict:
+                increase_usage(llm_usage, usage_dict["usage"])
             else:
-                usage_dict['usage'] = LLMUsage.empty_usage()
+                usage_dict["usage"] = LLMUsage.empty_usage()

             self.save_agent_thought(
                 agent_thought=agent_thought,
-                tool_name=scratchpad.action.action_name if scratchpad.action else '',
-                tool_input={
-                    scratchpad.action.action_name: scratchpad.action.action_input
-                } if scratchpad.action else {},
+                tool_name=scratchpad.action.action_name if scratchpad.action else "",
+                tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
                 tool_invoke_meta={},
                 thought=scratchpad.thought,
-                observation='',
+                observation="",
                 answer=scratchpad.agent_response,
                 messages_ids=[],
-                llm_usage=usage_dict['usage']
+                llm_usage=usage_dict["usage"],
             )

             if not scratchpad.is_final():
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )

             if not scratchpad.action:
                 # failed to extract action, return final answer directly
-                final_answer = ''
+                final_answer = ""
             else:
                 if scratchpad.action.action_name.lower() == "final answer":
                     # action is final answer, return final answer directly
                     try:
                         if isinstance(scratchpad.action.action_input, dict):
-                            final_answer = json.dumps(
-                                scratchpad.action.action_input)
+                            final_answer = json.dumps(scratchpad.action.action_input)
                         elif isinstance(scratchpad.action.action_input, str):
                             final_answer = scratchpad.action.action_input
                         else:
-                            final_answer = f'{scratchpad.action.action_input}'
+                            final_answer = f"{scratchpad.action.action_input}"
                     except json.JSONDecodeError:
-                        final_answer = f'{scratchpad.action.action_input}'
+                        final_answer = f"{scratchpad.action.action_input}"
                 else:
                     function_call_state = True
                     # action is tool call, invoke tool
@@ -224,21 +208,18 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     self.save_agent_thought(
                         agent_thought=agent_thought,
                         tool_name=scratchpad.action.action_name,
-                        tool_input={
-                            scratchpad.action.action_name: scratchpad.action.action_input},
+                        tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
                         thought=scratchpad.thought,
-                        observation={
-                            scratchpad.action.action_name: tool_invoke_response},
-                        tool_invoke_meta={
-                            scratchpad.action.action_name: tool_invoke_meta.to_dict()},
+                        observation={scratchpad.action.action_name: tool_invoke_response},
+                        tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
                         answer=scratchpad.agent_response,
                         messages_ids=message_file_ids,
-                        llm_usage=usage_dict['usage']
+                        llm_usage=usage_dict["usage"],
                     )

-                    self.queue_manager.publish(QueueAgentThoughtEvent(
-                        agent_thought_id=agent_thought.id
-                    ), PublishFrom.APPLICATION_MANAGER)
+                    self.queue_manager.publish(
+                        QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                    )

             # update prompt tool message
             for prompt_tool in self._prompt_messages_tools:
@@ -250,44 +231,45 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 model=model_instance.model,
                 prompt_messages=prompt_messages,
                 delta=LLMResultChunkDelta(
-                    index=0,
-                    message=AssistantPromptMessage(
-                        content=final_answer
-                    ),
-                    usage=llm_usage['usage']
+                    index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
                 ),
-                system_fingerprint=''
+                system_fingerprint="",
             )

         # save agent thought
         self.save_agent_thought(
             agent_thought=agent_thought,
-            tool_name='',
+            tool_name="",
             tool_input={},
             tool_invoke_meta={},
             thought=final_answer,
             observation={},
             answer=final_answer,
-            messages_ids=[]
+            messages_ids=[],
         )

         self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
-        self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
-            model=model_instance.model,
-            prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(
-                content=final_answer
-            ),
-            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
-            system_fingerprint=''
-        )), PublishFrom.APPLICATION_MANAGER)
+        self.queue_manager.publish(
+            QueueMessageEndEvent(
+                llm_result=LLMResult(
+                    model=model_instance.model,
+                    prompt_messages=prompt_messages,
+                    message=AssistantPromptMessage(content=final_answer),
+                    usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
+                    system_fingerprint="",
+                )
+            ),
+            PublishFrom.APPLICATION_MANAGER,
+        )

-    def _handle_invoke_action(self, action: AgentScratchpadUnit.Action,
-                              tool_instances: dict[str, Tool],
-                              message_file_ids: list[str],
-                              trace_manager: Optional[TraceQueueManager] = None
-                              ) -> tuple[str, ToolInvokeMeta]:
+    def _handle_invoke_action(
+        self,
+        action: AgentScratchpadUnit.Action,
+        tool_instances: dict[str, Tool],
+        message_file_ids: list[str],
+        trace_manager: Optional[TraceQueueManager] = None,
+    ) -> tuple[str, ToolInvokeMeta]:
         """
         handle invoke action
         :param action: action
@@ -326,13 +308,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         # publish files
         for message_file_id, save_as in message_files:
             if save_as:
-                self.variables_pool.set_file(
-                    tool_name=tool_call_name, value=message_file_id, name=save_as)
+                self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)

             # publish message file
-            self.queue_manager.publish(QueueMessageFileEvent(
-                message_file_id=message_file_id
-            ), PublishFrom.APPLICATION_MANAGER)
+            self.queue_manager.publish(
+                QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
+            )
             # add message file ids
             message_file_ids.append(message_file_id)
@@ -342,10 +323,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         """
         convert dict to action
         """
-        return AgentScratchpadUnit.Action(
-            action_name=action['action'],
-            action_input=action['action_input']
-        )
+        return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])

     def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
         """
@@ -353,7 +331,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         """
         for key, value in inputs.items():
             try:
-                instruction = instruction.replace(f'{{{{{key}}}}}', str(value))
+                instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
             except Exception as e:
                 continue
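Note: the quintuple braces survive the quote change because brace escaping, not quoting, does the work here: in an f-string `{{` and `}}` render as literal braces while the inner `{key}` interpolates, so the pattern produced is `{{key}}`. A small check:

    key = "name"
    assert f"{{{{{key}}}}}" == "{{name}}"

    instruction = "Hello, {{name}}!"
    print(instruction.replace(f"{{{{{key}}}}}", "Dify"))  # -> Hello, Dify!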
@@ -370,14 +348,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
     @abstractmethod
     def _organize_prompt_messages(self) -> list[PromptMessage]:
         """
-        organize prompt messages
+        organize prompt messages
         """

     def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
         """
-        format assistant message
+        format assistant message
         """
-        message = ''
+        message = ""
         for scratchpad in agent_scratchpad:
             if scratchpad.is_final():
                 message += f"Final Answer: {scratchpad.agent_response}"
@@ -390,9 +368,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):

         return message

-    def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _organize_historic_prompt_messages(
+        self, current_session_messages: list[PromptMessage] = None
+    ) -> list[PromptMessage]:
         """
-        organize historic prompt messages
+        organize historic prompt messages
         """
         result: list[PromptMessage] = []
         scratchpads: list[AgentScratchpadUnit] = []
@@ -403,8 +383,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 if not current_scratchpad:
                     current_scratchpad = AgentScratchpadUnit(
                         agent_response=message.content,
-                        thought=message.content or 'I am thinking about how to help you',
-                        action_str='',
+                        thought=message.content or "I am thinking about how to help you",
+                        action_str="",
                         action=None,
                         observation=None,
                     )
@@ -413,12 +393,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     try:
                         current_scratchpad.action = AgentScratchpadUnit.Action(
                             action_name=message.tool_calls[0].function.name,
-                            action_input=json.loads(
-                                message.tool_calls[0].function.arguments)
-                        )
-                        current_scratchpad.action_str = json.dumps(
-                            current_scratchpad.action.to_dict()
+                            action_input=json.loads(message.tool_calls[0].function.arguments),
                         )
+                        current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
                     except:
                         pass
             elif isinstance(message, ToolPromptMessage):
@@ -426,23 +403,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     current_scratchpad.observation = message.content
             elif isinstance(message, UserPromptMessage):
                 if scratchpads:
-                    result.append(AssistantPromptMessage(
-                        content=self._format_assistant_message(scratchpads)
-                    ))
+                    result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
                     scratchpads = []
                     current_scratchpad = None

                 result.append(message)

         if scratchpads:
-            result.append(AssistantPromptMessage(
-                content=self._format_assistant_message(scratchpads)
-            ))
+            result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))

         historic_prompts = AgentHistoryPromptTransform(
             model_config=self.model_config,
             prompt_messages=current_session_messages or [],
             history_messages=result,
-            memory=self.memory
+            memory=self.memory,
         ).get_prompt()
         return historic_prompts
@@ -19,14 +19,15 @@ class CotChatAgentRunner(CotAgentRunner):
         prompt_entity = self.app_config.agent.prompt
         first_prompt = prompt_entity.first_prompt

-        system_prompt = first_prompt \
-            .replace("{{instruction}}", self._instruction) \
-            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
-            .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
+        system_prompt = (
+            first_prompt.replace("{{instruction}}", self._instruction)
+            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
+            .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
+        )

         return SystemPromptMessage(content=system_prompt)
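Note: the parenthesized call chain is behaviorally identical to the old backslash-continued version; only the layout changed. A toy run with assumed template text:

    first_prompt = "{{instruction}} Tools: {{tools}}. Names: {{tool_names}}."
    system_prompt = (
        first_prompt.replace("{{instruction}}", "Be concise.")
        .replace("{{tools}}", '[{"name": "search"}]')
        .replace("{{tool_names}}", "search")
    )
    print(system_prompt)  # Be concise. Tools: [{"name": "search"}]. Names: search.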
-    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
         """
         Organize user query
         """
@@ -43,7 +44,7 @@ class CotChatAgentRunner(CotAgentRunner):

     def _organize_prompt_messages(self) -> list[PromptMessage]:
         """
-        Organize
+        Organize
         """
         # organize system prompt
         system_message = self._organize_system_prompt()
@@ -53,7 +54,7 @@ class CotChatAgentRunner(CotAgentRunner):
         if not agent_scratchpad:
             assistant_messages = []
         else:
-            assistant_message = AssistantPromptMessage(content='')
+            assistant_message = AssistantPromptMessage(content="")
             for unit in agent_scratchpad:
                 if unit.is_final():
                     assistant_message.content += f"Final Answer: {unit.agent_response}"
@@ -71,18 +72,15 @@ class CotChatAgentRunner(CotAgentRunner):

         if assistant_messages:
             # organize historic prompt messages
-            historic_messages = self._organize_historic_prompt_messages([
-                system_message,
-                *query_messages,
-                *assistant_messages,
-                UserPromptMessage(content='continue')
-            ])
+            historic_messages = self._organize_historic_prompt_messages(
+                [system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
+            )
             messages = [
                 system_message,
                 *historic_messages,
                 *query_messages,
                 *assistant_messages,
-                UserPromptMessage(content='continue')
+                UserPromptMessage(content="continue"),
             ]
         else:
             # organize historic prompt messages
@@ -13,10 +13,12 @@ class CotCompletionAgentRunner(CotAgentRunner):
         prompt_entity = self.app_config.agent.prompt
         first_prompt = prompt_entity.first_prompt

-        system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \
-            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
-            .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
+        system_prompt = (
+            first_prompt.replace("{{instruction}}", self._instruction)
+            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
+            .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
+        )

         return system_prompt

     def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
@@ -46,7 +48,7 @@ class CotCompletionAgentRunner(CotAgentRunner):

         # organize current assistant messages
         agent_scratchpad = self._agent_scratchpad
-        assistant_prompt = ''
+        assistant_prompt = ""
         for unit in agent_scratchpad:
             if unit.is_final():
                 assistant_prompt += f"Final Answer: {unit.agent_response}"
@@ -61,9 +63,10 @@ class CotCompletionAgentRunner(CotAgentRunner):
         query_prompt = f"Question: {self._query}"

         # join all messages
-        prompt = system_prompt \
-            .replace("{{historic_messages}}", historic_prompt) \
-            .replace("{{agent_scratchpad}}", assistant_prompt) \
-            .replace("{{query}}", query_prompt)
+        prompt = (
+            system_prompt.replace("{{historic_messages}}", historic_prompt)
+            .replace("{{agent_scratchpad}}", assistant_prompt)
+            .replace("{{query}}", query_prompt)
+        )

-        return [UserPromptMessage(content=prompt)]
+        return [UserPromptMessage(content=prompt)]
@@ -8,6 +8,7 @@ class AgentToolEntity(BaseModel):
     """
     Agent Tool Entity.
     """
+
     provider_type: Literal["builtin", "api", "workflow"]
     provider_id: str
     tool_name: str
@@ -18,6 +19,7 @@ class AgentPromptEntity(BaseModel):
     """
     Agent Prompt Entity.
     """
+
     first_prompt: str
     next_iteration: str

@@ -31,6 +33,7 @@ class AgentScratchpadUnit(BaseModel):
         """
         Action Entity.
         """
+
         action_name: str
         action_input: Union[dict, str]

@@ -39,8 +42,8 @@ class AgentScratchpadUnit(BaseModel):
             Convert to dictionary.
             """
             return {
-                'action': self.action_name,
-                'action_input': self.action_input,
+                "action": self.action_name,
+                "action_input": self.action_input,
             }

     agent_response: Optional[str] = None
@@ -54,10 +57,10 @@ class AgentScratchpadUnit(BaseModel):
         Check if the scratchpad unit is final.
         """
         return self.action is None or (
-            'final' in self.action.action_name.lower() and
-            'answer' in self.action.action_name.lower()
+            "final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower()
         )
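Note: after the reformat, `is_final()` still treats any action whose name contains both "final" and "answer" (case-insensitively) as terminal. A standalone check of that predicate:

    def looks_final(action_name: str) -> bool:
        lowered = action_name.lower()
        return "final" in lowered and "answer" in lowered

    assert looks_final("Final Answer")
    assert looks_final("final_answer")
    assert not looks_final("search")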

 class AgentEntity(BaseModel):
     """
     Agent Entity.
@@ -67,8 +70,9 @@ class AgentEntity(BaseModel):
         """
         Agent Strategy.
         """
-        CHAIN_OF_THOUGHT = 'chain-of-thought'
-        FUNCTION_CALLING = 'function-calling'
+
+        CHAIN_OF_THOUGHT = "chain-of-thought"
+        FUNCTION_CALLING = "function-calling"

     provider: str
     model: str
@@ -24,11 +24,9 @@ from models.model import Message

 logger = logging.getLogger(__name__)

-class FunctionCallAgentRunner(BaseAgentRunner):
-
-    def run(self,
-            message: Message, query: str, **kwargs: Any
-            ) -> Generator[LLMResultChunk, None, None]:
+
+class FunctionCallAgentRunner(BaseAgentRunner):
+    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
         """
         Run FunctionCall agent application
         """
@@ -45,19 +43,17 @@ class FunctionCallAgentRunner(BaseAgentRunner):

         # continue to run until there is not any tool call
         function_call_state = True
-        llm_usage = {
-            'usage': None
-        }
-        final_answer = ''
+        llm_usage = {"usage": None}
+        final_answer = ""

         # get tracing instance
         trace_manager = app_generate_entity.trace_manager

         def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
-            if not final_llm_usage_dict['usage']:
-                final_llm_usage_dict['usage'] = usage
+            if not final_llm_usage_dict["usage"]:
+                final_llm_usage_dict["usage"] = usage
             else:
-                llm_usage = final_llm_usage_dict['usage']
+                llm_usage = final_llm_usage_dict["usage"]
                 llm_usage.prompt_tokens += usage.prompt_tokens
                 llm_usage.completion_tokens += usage.completion_tokens
                 llm_usage.prompt_price += usage.prompt_price
@@ -75,11 +71,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):

             message_file_ids = []
             agent_thought = self.create_agent_thought(
-                message_id=message.id,
-                message='',
-                tool_name='',
-                tool_input='',
-                messages_ids=message_file_ids
+                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
             )

             # recalc llm max tokens
@@ -99,11 +91,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             tool_calls: list[tuple[str, str, dict[str, Any]]] = []

             # save full response
-            response = ''
+            response = ""

             # save tool call names and inputs
-            tool_call_names = ''
-            tool_call_inputs = ''
+            tool_call_names = ""
+            tool_call_inputs = ""

             current_llm_usage = None
@@ -111,24 +103,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 is_first_chunk = True
                 for chunk in chunks:
                     if is_first_chunk:
-                        self.queue_manager.publish(QueueAgentThoughtEvent(
-                            agent_thought_id=agent_thought.id
-                        ), PublishFrom.APPLICATION_MANAGER)
+                        self.queue_manager.publish(
+                            QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                        )
                         is_first_chunk = False
                     # check if there is any tool call
                     if self.check_tool_calls(chunk):
                         function_call_state = True
                         tool_calls.extend(self.extract_tool_calls(chunk))
-                        tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
+                        tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                         try:
-                            tool_call_inputs = json.dumps({
-                                tool_call[1]: tool_call[2] for tool_call in tool_calls
-                            }, ensure_ascii=False)
+                            tool_call_inputs = json.dumps(
+                                {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
+                            )
                         except json.JSONDecodeError as e:
                             # ensure ascii to avoid encoding error
-                            tool_call_inputs = json.dumps({
-                                tool_call[1]: tool_call[2] for tool_call in tool_calls
-                            })
+                            tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                     if chunk.delta.message and chunk.delta.message.content:
                         if isinstance(chunk.delta.message.content, list):
@@ -148,16 +138,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 if self.check_blocking_tool_calls(result):
                     function_call_state = True
                     tool_calls.extend(self.extract_blocking_tool_calls(result))
-                    tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
+                    tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                     try:
-                        tool_call_inputs = json.dumps({
-                            tool_call[1]: tool_call[2] for tool_call in tool_calls
-                        }, ensure_ascii=False)
+                        tool_call_inputs = json.dumps(
+                            {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
+                        )
                     except json.JSONDecodeError as e:
                         # ensure ascii to avoid encoding error
-                        tool_call_inputs = json.dumps({
-                            tool_call[1]: tool_call[2] for tool_call in tool_calls
-                        })
+                        tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                 if result.usage:
                     increase_usage(llm_usage, result.usage)
@@ -171,12 +159,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                     response += result.message.content

                 if not result.message.content:
-                    result.message.content = ''
+                    result.message.content = ""

-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )

                 yield LLMResultChunk(
                     model=model_instance.model,
                     prompt_messages=result.prompt_messages,
@@ -185,32 +173,29 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         index=0,
                         message=result.message,
                         usage=result.usage,
                     ),
                 )

-            assistant_message = AssistantPromptMessage(
-                content='',
-                tool_calls=[]
-            )
+            assistant_message = AssistantPromptMessage(content="", tool_calls=[])
             if tool_calls:
-                assistant_message.tool_calls=[
+                assistant_message.tool_calls = [
                     AssistantPromptMessage.ToolCall(
                         id=tool_call[0],
-                        type='function',
+                        type="function",
                         function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                            name=tool_call[1],
-                            arguments=json.dumps(tool_call[2], ensure_ascii=False)
-                        )
-                    ) for tool_call in tool_calls
+                            name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
+                        ),
+                    )
+                    for tool_call in tool_calls
                 ]
             else:
                 assistant_message.content = response

             self._current_thoughts.append(assistant_message)

             # save thought
             self.save_agent_thought(
-                agent_thought=agent_thought,
+                agent_thought=agent_thought,
                 tool_name=tool_call_names,
                 tool_input=tool_call_inputs,
                 thought=response,
@@ -218,13 +203,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 observation=None,
                 answer=response,
                 messages_ids=[],
-                llm_usage=current_llm_usage
+                llm_usage=current_llm_usage,
             )
-            self.queue_manager.publish(QueueAgentThoughtEvent(
-                agent_thought_id=agent_thought.id
-            ), PublishFrom.APPLICATION_MANAGER)
+            self.queue_manager.publish(
+                QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+            )

-            final_answer += response + '\n'
+            final_answer += response + "\n"

             # call tools
             tool_responses = []
@@ -235,7 +220,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         "tool_call_id": tool_call_id,
                         "tool_call_name": tool_call_name,
                         "tool_response": f"there is not a tool named {tool_call_name}",
-                        "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict()
+                        "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
                     }
                 else:
                     # invoke tool
@@ -255,50 +240,49 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                             self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)

                         # publish message file
-                        self.queue_manager.publish(QueueMessageFileEvent(
-                            message_file_id=message_file_id
-                        ), PublishFrom.APPLICATION_MANAGER)
+                        self.queue_manager.publish(
+                            QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
+                        )
                         # add message file ids
                         message_file_ids.append(message_file_id)

                     tool_response = {
                         "tool_call_id": tool_call_id,
                         "tool_call_name": tool_call_name,
                         "tool_response": tool_invoke_response,
-                        "meta": tool_invoke_meta.to_dict()
+                        "meta": tool_invoke_meta.to_dict(),
                     }

                 tool_responses.append(tool_response)
-                if tool_response['tool_response'] is not None:
+                if tool_response["tool_response"] is not None:
                     self._current_thoughts.append(
                         ToolPromptMessage(
-                            content=tool_response['tool_response'],
+                            content=tool_response["tool_response"],
                             tool_call_id=tool_call_id,
                             name=tool_call_name,
                         )
                     )

             if len(tool_responses) > 0:
                 # save agent thought
                 self.save_agent_thought(
-                    agent_thought=agent_thought,
+                    agent_thought=agent_thought,
                     tool_name=None,
                     tool_input=None,
-                    thought=None,
+                    thought=None,
                     tool_invoke_meta={
-                        tool_response['tool_call_name']: tool_response['meta']
-                        for tool_response in tool_responses
+                        tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
                     },
                     observation={
-                        tool_response['tool_call_name']: tool_response['tool_response']
+                        tool_response["tool_call_name"]: tool_response["tool_response"]
                         for tool_response in tool_responses
                     },
                     answer=None,
-                    messages_ids=message_file_ids
+                    messages_ids=message_file_ids,
                 )
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )

             # update prompt tool
             for prompt_tool in prompt_messages_tools:
@@ -308,15 +292,18 @@ class FunctionCallAgentRunner(BaseAgentRunner):

         self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
-        self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
-            model=model_instance.model,
-            prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(
-                content=final_answer
-            ),
-            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
-            system_fingerprint=''
-        )), PublishFrom.APPLICATION_MANAGER)
+        self.queue_manager.publish(
+            QueueMessageEndEvent(
+                llm_result=LLMResult(
+                    model=model_instance.model,
+                    prompt_messages=prompt_messages,
+                    message=AssistantPromptMessage(content=final_answer),
+                    usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
+                    system_fingerprint="",
+                )
+            ),
+            PublishFrom.APPLICATION_MANAGER,
+        )

     def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
         """
@@ -325,7 +312,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         if llm_result_chunk.delta.message.tool_calls:
             return True
         return False
-
+
     def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
         """
         Check if there is any blocking tool call in llm result
@@ -334,7 +321,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             return True
         return False

-    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
+    def extract_tool_calls(
+        self, llm_result_chunk: LLMResultChunk
+    ) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
         """
         Extract tool calls from llm result chunk

@@ -344,17 +333,19 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         tool_calls = []
         for prompt_message in llm_result_chunk.delta.message.tool_calls:
             args = {}
-            if prompt_message.function.arguments != '':
+            if prompt_message.function.arguments != "":
                 args = json.loads(prompt_message.function.arguments)

-            tool_calls.append((
-                prompt_message.id,
-                prompt_message.function.name,
-                args,
-            ))
+            tool_calls.append(
+                (
+                    prompt_message.id,
+                    prompt_message.function.name,
+                    args,
+                )
+            )

         return tool_calls
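Note: the reflowed `tool_calls.append` still builds `(id, name, args)` triples, with `json.loads` applied only when `arguments` is non-empty. A hedged sketch with stand-in objects (the `_Fn`/`_Call` classes below are hypothetical, mimicking the prompt-message shapes this method assumes):

    import json

    class _Fn:
        def __init__(self, name, arguments):
            self.name, self.arguments = name, arguments

    class _Call:
        def __init__(self, id, function):
            self.id, self.function = id, function

    calls = [
        _Call("call_1", _Fn("weather", '{"city": "Berlin"}')),
        _Call("call_2", _Fn("ping", "")),  # empty arguments -> {} without json.loads
    ]
    tool_calls = []
    for c in calls:
        args = json.loads(c.function.arguments) if c.function.arguments != "" else {}
        tool_calls.append((c.id, c.function.name, args))

    assert tool_calls == [("call_1", "weather", {"city": "Berlin"}), ("call_2", "ping", {})]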

     def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
         """
         Extract blocking tool calls from llm result
@@ -365,18 +356,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         tool_calls = []
         for prompt_message in llm_result.message.tool_calls:
             args = {}
-            if prompt_message.function.arguments != '':
+            if prompt_message.function.arguments != "":
                 args = json.loads(prompt_message.function.arguments)

-            tool_calls.append((
-                prompt_message.id,
-                prompt_message.function.name,
-                args,
-            ))
+            tool_calls.append(
+                (
+                    prompt_message.id,
+                    prompt_message.function.name,
+                    args,
+                )
+            )

         return tool_calls

-    def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _init_system_message(
+        self, prompt_template: str, prompt_messages: list[PromptMessage] = None
+    ) -> list[PromptMessage]:
         """
         Initialize system message
         """
@@ -384,13 +379,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             return [
                 SystemPromptMessage(content=prompt_template),
             ]
-
+
         if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
             prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))

         return prompt_messages

-    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
         """
         Organize user query
         """
@@ -404,7 +399,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         prompt_messages.append(UserPromptMessage(content=query))

         return prompt_messages
-
+
     def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
         """
         As for now, gpt supports both fc and vision at the first iteration.
@@ -415,17 +410,21 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         for prompt_message in prompt_messages:
             if isinstance(prompt_message, UserPromptMessage):
                 if isinstance(prompt_message.content, list):
-                    prompt_message.content = '\n'.join([
-                        content.data if content.type == PromptMessageContentType.TEXT else
-                        '[image]' if content.type == PromptMessageContentType.IMAGE else
-                        '[file]'
-                        for content in prompt_message.content
-                    ])
+                    prompt_message.content = "\n".join(
+                        [
+                            content.data
+                            if content.type == PromptMessageContentType.TEXT
+                            else "[image]"
+                            if content.type == PromptMessageContentType.IMAGE
+                            else "[file]"
+                            for content in prompt_message.content
+                        ]
+                    )

         return prompt_messages
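Note: the reformatted expression is a right-nested conditional, read as `data if TEXT else ("[image]" if IMAGE else "[file]")`. A toy version of the flattening:

    parts = [("text", "hello"), ("image", None), ("file", None)]
    flattened = "\n".join(
        data if kind == "text" else "[image]" if kind == "image" else "[file]"
        for kind, data in parts
    )
    assert flattened == "hello\n[image]\n[file]"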

     def _organize_prompt_messages(self):
-        prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
+        prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
         self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
         query_prompt_messages = self._organize_user_query(self.query, [])

@@ -433,14 +432,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             model_config=self.model_config,
             prompt_messages=[*query_prompt_messages, *self._current_thoughts],
             history_messages=self.history_prompt_messages,
-            memory=self.memory
+            memory=self.memory,
         ).get_prompt()

-        prompt_messages = [
-            *self.history_prompt_messages,
-            *query_prompt_messages,
-            *self._current_thoughts
-        ]
+        prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
         if len(self._current_thoughts) != 0:
             # clear messages after the first iteration
             prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
@@ -9,8 +9,9 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk

 class CotAgentOutputParser:
     @classmethod
-    def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict) -> \
-            Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
+    def handle_react_stream_output(
+        cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
+    ) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
         def parse_action(json_str):
             try:
                 action = json.loads(json_str)
@@ -22,7 +23,7 @@ class CotAgentOutputParser:
                     action = action[0]

                 for key, value in action.items():
-                    if 'input' in key.lower():
+                    if "input" in key.lower():
                         action_input = value
                     else:
                         action_name = value
@@ -33,37 +34,37 @@ class CotAgentOutputParser:
                         action_input=action_input,
                     )
                 else:
-                    return json_str or ''
+                    return json_str or ""
             except:
-                return json_str or ''
-
+                return json_str or ""
+
         def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
-            code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
+            code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL)
             if not code_blocks:
                 return
             for block in code_blocks:
-                json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
+                json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
                 yield parse_action(json_text)
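Note: `extra_json_from_code_block` first grabs fenced blocks with `re.findall`, then strips a leading language tag such as `json` with `re.sub`. A standalone demo on an assumed ReAct-style reply:

    import re

    reply = 'Thought: use a tool\n```json\n{"action": "search", "action_input": "dify"}\n```'
    blocks = re.findall(r"```(.*?)```", reply, re.DOTALL)
    json_text = re.sub(r"^[a-zA-Z]+\n", "", blocks[0].strip(), flags=re.MULTILINE)
    print(json_text)  # {"action": "search", "action_input": "dify"}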

-        code_block_cache = ''
+        code_block_cache = ""
         code_block_delimiter_count = 0
         in_code_block = False
-        json_cache = ''
+        json_cache = ""
         json_quote_count = 0
         in_json = False
         got_json = False

-        action_cache = ''
-        action_str = 'action:'
+        action_cache = ""
+        action_str = "action:"
         action_idx = 0

-        thought_cache = ''
-        thought_str = 'thought:'
+        thought_cache = ""
+        thought_str = "thought:"
         thought_idx = 0

         for response in llm_response:
             if response.delta.usage:
-                usage_dict['usage'] = response.delta.usage
+                usage_dict["usage"] = response.delta.usage
             response = response.delta.message.content
             if not isinstance(response, str):
                 continue
@@ -72,24 +73,24 @@ class CotAgentOutputParser:
             index = 0
             while index < len(response):
                 steps = 1
-                delta = response[index:index+steps]
-                last_character = response[index-1] if index > 0 else ''
+                delta = response[index : index + steps]
+                last_character = response[index - 1] if index > 0 else ""

-                if delta == '`':
+                if delta == "`":
                     code_block_cache += delta
                     code_block_delimiter_count += 1
                 else:
                     if not in_code_block:
                         if code_block_delimiter_count > 0:
                             yield code_block_cache
-                            code_block_cache = ''
+                            code_block_cache = ""
                     else:
                         code_block_cache += delta
                     code_block_delimiter_count = 0

                 if not in_code_block and not in_json:
                     if delta.lower() == action_str[action_idx] and action_idx == 0:
-                        if last_character not in ['\n', ' ', '']:
+                        if last_character not in ["\n", " ", ""]:
                             index += steps
                             yield delta
                             continue
@@ -97,7 +98,7 @@ class CotAgentOutputParser:
                         action_cache += delta
                         action_idx += 1
                         if action_idx == len(action_str):
-                            action_cache = ''
+                            action_cache = ""
                             action_idx = 0
                         index += steps
                         continue
@@ -105,18 +106,18 @@ class CotAgentOutputParser:
                         action_cache += delta
                         action_idx += 1
                         if action_idx == len(action_str):
-                            action_cache = ''
+                            action_cache = ""
                             action_idx = 0
                         index += steps
                         continue
                     else:
                         if action_cache:
                             yield action_cache
-                            action_cache = ''
+                            action_cache = ""
                             action_idx = 0

                     if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
-                        if last_character not in ['\n', ' ', '']:
+                        if last_character not in ["\n", " ", ""]:
                             index += steps
                             yield delta
                             continue
@@ -124,7 +125,7 @@ class CotAgentOutputParser:
                         thought_cache += delta
                         thought_idx += 1
                         if thought_idx == len(thought_str):
-                            thought_cache = ''
+                            thought_cache = ""
                             thought_idx = 0
                         index += steps
                         continue
@@ -132,31 +133,31 @@ class CotAgentOutputParser:
                         thought_cache += delta
                         thought_idx += 1
                         if thought_idx == len(thought_str):
-                            thought_cache = ''
+                            thought_cache = ""
                             thought_idx = 0
                         index += steps
                         continue
                     else:
                         if thought_cache:
                             yield thought_cache
-                            thought_cache = ''
+                            thought_cache = ""
                             thought_idx = 0

                 if code_block_delimiter_count == 3:
                     if in_code_block:
                         yield from extra_json_from_code_block(code_block_cache)
-                        code_block_cache = ''
-
+                        code_block_cache = ""
+
                     in_code_block = not in_code_block
                     code_block_delimiter_count = 0

                 if not in_code_block:
                     # handle single json
-                    if delta == '{':
+                    if delta == "{":
                         json_quote_count += 1
                         in_json = True
                         json_cache += delta
-                    elif delta == '}':
+                    elif delta == "}":
                         json_cache += delta
                         if json_quote_count > 0:
                             json_quote_count -= 1
@@ -172,12 +173,12 @@ class CotAgentOutputParser:
                 if got_json:
                     got_json = False
                     yield parse_action(json_cache)
-                    json_cache = ''
+                    json_cache = ""
                     json_quote_count = 0
                     in_json = False
-
+
                 if not in_code_block and not in_json:
-                    yield delta.replace('`', '')
+                    yield delta.replace("`", "")

                 index += steps

@@ -186,4 +187,3 @@ class CotAgentOutputParser:

         if json_cache:
             yield parse_action(json_cache)
-
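Note: the `{`/`}` counting above is a simple brace-balancing scanner; a compact standalone version of the same idea (not the class's API, just the technique):

    def first_json_object(stream: str) -> str:
        depth, start = 0, None
        for i, ch in enumerate(stream):
            if ch == "{":
                if depth == 0:
                    start = i
                depth += 1
            elif ch == "}" and depth > 0:
                depth -= 1
                if depth == 0:
                    return stream[start : i + 1]
        return ""

    assert first_json_object('noise {"action": "x"} tail') == '{"action": "x"}'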
@@ -91,14 +91,14 @@ Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use
 ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""

 REACT_PROMPT_TEMPLATES = {
-    'english': {
-        'chat': {
-            'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
-            'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES
+    "english": {
+        "chat": {
+            "prompt": ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
+            "agent_scratchpad": ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES,
         },
-        'completion': {
-            'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
-            'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES
-        }
+        "completion": {
+            "prompt": ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
+            "agent_scratchpad": ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES,
+        },
     }
 }