Feat/blocking function call (#2247)

This commit is contained in:
Yeuoly
2024-01-30 15:25:37 +08:00
committed by GitHub
parent 1ea18a2922
commit 6d5b386394
33 changed files with 429 additions and 94 deletions

View File

@@ -5,7 +5,7 @@ from typing import Union, Generator, Dict, Any, Tuple, List
from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage,\
SystemPromptMessage, AssistantPromptMessage, ToolPromptMessage, PromptMessageTool
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage, LLMResultChunkDelta
from core.model_manager import ModelInstance
from core.application_queue_manager import PublishFrom
@@ -20,8 +20,7 @@ from models.model import Conversation, Message, MessageAgentThought
logger = logging.getLogger(__name__)
class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
def run(self, model_instance: ModelInstance,
conversation: Conversation,
def run(self, conversation: Conversation,
message: Message,
query: str,
) -> Generator[LLMResultChunk, None, None]:
@@ -81,6 +80,8 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price
model_instance = self.model_instance
while function_call_state and iteration_step <= max_iteration_steps:
function_call_state = False
@@ -101,12 +102,12 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
# recale llm max tokens
self.recale_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters,
tools=prompt_messages_tools,
stop=app_orchestration_config.model_config.stop,
stream=True,
stream=self.stream_tool_call,
user=self.user_id,
callbacks=[],
)
@@ -122,11 +123,41 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
current_llm_usage = None
for chunk in chunks:
if self.stream_tool_call:
for chunk in chunks:
# check if there is any tool call
if self.check_tool_calls(chunk):
function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk))
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
}, ensure_ascii=False)
except json.JSONDecodeError as e:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
})
if chunk.delta.message and chunk.delta.message.content:
if isinstance(chunk.delta.message.content, list):
for content in chunk.delta.message.content:
response += content.data
else:
response += chunk.delta.message.content
if chunk.delta.usage:
increase_usage(llm_usage, chunk.delta.usage)
current_llm_usage = chunk.delta.usage
yield chunk
else:
result: LLMResult = chunks
# check if there is any tool call
if self.check_tool_calls(chunk):
if self.check_blocking_tool_calls(result):
function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk))
tool_calls.extend(self.extract_blocking_tool_calls(result))
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps({
@@ -138,18 +169,44 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
tool_call[1]: tool_call[2] for tool_call in tool_calls
})
if chunk.delta.message and chunk.delta.message.content:
if isinstance(chunk.delta.message.content, list):
for content in chunk.delta.message.content:
if result.usage:
increase_usage(llm_usage, result.usage)
current_llm_usage = result.usage
if result.message and result.message.content:
if isinstance(result.message.content, list):
for content in result.message.content:
response += content.data
else:
response += chunk.delta.message.content
response += result.message.content
if chunk.delta.usage:
increase_usage(llm_usage, chunk.delta.usage)
current_llm_usage = chunk.delta.usage
if not result.message.content:
result.message.content = ''
yield chunk
yield LLMResultChunk(
model=model_instance.model,
prompt_messages=result.prompt_messages,
system_fingerprint=result.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=result.message,
usage=result.usage,
)
)
if tool_calls:
prompt_messages.append(AssistantPromptMessage(
content='',
name='',
tool_calls=[AssistantPromptMessage.ToolCall(
id=tool_call[0],
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_call[1],
arguments=json.dumps(tool_call[2], ensure_ascii=False)
)
) for tool_call in tool_calls]
))
# save thought
self.save_agent_thought(
@@ -167,6 +224,12 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
final_answer += response + '\n'
# update prompt messages
if response.strip():
prompt_messages.append(AssistantPromptMessage(
content=response,
))
# call tools
tool_responses = []
for tool_call_id, tool_call_name, tool_call_args in tool_calls:
@@ -256,12 +319,6 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
)
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
# update prompt messages
if response.strip():
prompt_messages.append(AssistantPromptMessage(
content=response,
))
# update prompt tool
for prompt_tool in prompt_messages_tools:
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
@@ -287,6 +344,14 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
if llm_result_chunk.delta.message.tool_calls:
return True
return False
def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
"""
Check if there is any blocking tool call in llm result
"""
if llm_result.message.tool_calls:
return True
return False
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
"""
@@ -304,6 +369,23 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
))
return tool_calls
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
"""
Extract blocking tool calls from llm result
Returns:
List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
"""
tool_calls = []
for prompt_message in llm_result.message.tool_calls:
tool_calls.append((
prompt_message.id,
prompt_message.function.name,
json.loads(prompt_message.function.arguments),
))
return tool_calls
def organize_prompt_messages(self, prompt_template: str,
query: str = None,