From d339403e89df80a9d84e64c0c67b386166cb8e74 Mon Sep 17 00:00:00 2001
From: Yongtao Huang <99629139+hyongtao-db@users.noreply.github.com>
Date: Wed, 19 Mar 2025 11:24:57 +0800
Subject: [PATCH] Chore: optimize the code of PromptTransform (#16143)

---
 api/core/prompt/simple_prompt_transform.py        | 16 ++++++++--------
 .../test_agent_history_prompt_transform.py        |  2 --
 .../core/prompt/test_simple_prompt_transform.py   |  1 -
 3 files changed, 8 insertions(+), 11 deletions(-)

diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py
index 421b14e0d..ad56d84cb 100644
--- a/api/core/prompt/simple_prompt_transform.py
+++ b/api/core/prompt/simple_prompt_transform.py
@@ -93,7 +93,7 @@ class SimplePromptTransform(PromptTransform):
 
         return prompt_messages, stops
 
-    def get_prompt_str_and_rules(
+    def _get_prompt_str_and_rules(
         self,
         app_mode: AppMode,
         model_config: ModelConfigWithCredentialsEntity,
@@ -184,7 +184,7 @@ class SimplePromptTransform(PromptTransform):
         prompt_messages: list[PromptMessage] = []
 
         # get prompt
-        prompt, _ = self.get_prompt_str_and_rules(
+        prompt, _ = self._get_prompt_str_and_rules(
             app_mode=app_mode,
             model_config=model_config,
             pre_prompt=pre_prompt,
@@ -209,9 +209,9 @@ class SimplePromptTransform(PromptTransform):
         )
 
         if query:
-            prompt_messages.append(self.get_last_user_message(query, files, image_detail_config))
+            prompt_messages.append(self._get_last_user_message(query, files, image_detail_config))
         else:
-            prompt_messages.append(self.get_last_user_message(prompt, files, image_detail_config))
+            prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config))
 
         return prompt_messages, None
 
@@ -228,7 +228,7 @@ class SimplePromptTransform(PromptTransform):
         image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
     ) -> tuple[list[PromptMessage], Optional[list[str]]]:
         # get prompt
-        prompt, prompt_rules = self.get_prompt_str_and_rules(
+        prompt, prompt_rules = self._get_prompt_str_and_rules(
             app_mode=app_mode,
             model_config=model_config,
             pre_prompt=pre_prompt,
@@ -254,7 +254,7 @@ class SimplePromptTransform(PromptTransform):
         )
 
         # get prompt
-        prompt, prompt_rules = self.get_prompt_str_and_rules(
+        prompt, prompt_rules = self._get_prompt_str_and_rules(
             app_mode=app_mode,
             model_config=model_config,
             pre_prompt=pre_prompt,
@@ -268,9 +268,9 @@ class SimplePromptTransform(PromptTransform):
         if stops is not None and len(stops) == 0:
             stops = None
 
-        return [self.get_last_user_message(prompt, files, image_detail_config)], stops
+        return [self._get_last_user_message(prompt, files, image_detail_config)], stops
 
-    def get_last_user_message(
+    def _get_last_user_message(
         self,
         prompt: str,
         files: Sequence["File"],
diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py
index 0fd176e65..d157a41d2 100644
--- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py
+++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py
@@ -64,12 +64,10 @@ def test_get_prompt():
 
     transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
     result = transform.get_prompt()
-    assert len(result) <= max_token_limit
     assert len(result) == 4
 
     max_token_limit = 20
     transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
     result = transform.get_prompt()
-    assert len(result) <= max_token_limit
     assert len(result) == 12
 
diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py
index c32fc2bc3..c822ecbe7 100644
--- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py
+++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py
@@ -84,7 +84,6 @@ def test_get_baichuan_completion_app_prompt_template_with_pcq():
         query_in_prompt=True,
         with_memory_prompt=False,
     )
-    print(prompt_template["prompt_template"].template)
     prompt_rules = prompt_template["prompt_rules"]
     assert prompt_template["prompt_template"].template == (
         prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"]