feat: record price unit in messages (#919)

This commit is contained in:
takatost
2023-08-19 18:51:40 +08:00
committed by GitHub
parent 920fb6d0e1
commit 0a0d63457d
5 changed files with 88 additions and 4 deletions

View File

@@ -10,6 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM
@@ -68,6 +69,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.status = 'llm_end'
if response.llm_output:
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
else:
self._current_loop.prompt_tokens = self.model_instant.get_num_tokens(
[PromptMessage(content=self._current_loop.prompt)]
)
completion_generation = response.generations[0][0]
if isinstance(completion_generation, ChatGeneration):
completion_message = completion_generation.message
@@ -81,6 +86,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
if response.llm_output:
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
else:
self._current_loop.completion_tokens = self.model_instant.get_num_tokens(
[PromptMessage(content=self._current_loop.completion)]
)
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any