diff --git a/api/core/completion.py b/api/core/completion.py
index 4614f8084..192b16f4f 100644
--- a/api/core/completion.py
+++ b/api/core/completion.py
@@ -130,13 +130,12 @@ class Completion:
fake_response = agent_execute_result.output
# get llm prompt
- prompt_messages, stop_words = cls.get_main_llm_prompt(
+ prompt_messages, stop_words = model_instance.get_prompt(
mode=mode,
- model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt,
- query=query,
inputs=inputs,
- agent_execute_result=agent_execute_result,
+ query=query,
+ context=agent_execute_result.output if agent_execute_result else None,
memory=memory
)
@@ -154,113 +153,6 @@ class Completion:
return response
- @classmethod
- def get_main_llm_prompt(cls, mode: str, model: dict,
- pre_prompt: str, query: str, inputs: dict,
- agent_execute_result: Optional[AgentExecuteResult],
- memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
- Tuple[List[PromptMessage], Optional[List[str]]]:
- if mode == 'completion':
- prompt_template = JinjaPromptTemplate.from_template(
- template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
-
-<context>
-{{context}}
-</context>
-
-When answer to user:
-- If you don't know, just say that you don't know.
-- If you don't know when you are not sure, ask for clarification.
-Avoid mentioning that you obtained the information from the context.
-And answer according to the language of the user's question.
-""" if agent_execute_result else "")
- + (pre_prompt + "\n" if pre_prompt else "")
- + "{{query}}\n"
- )
-
- if agent_execute_result:
- inputs['context'] = agent_execute_result.output
-
- prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
- prompt_content = prompt_template.format(
- query=query,
- **prompt_inputs
- )
-
- return [PromptMessage(content=prompt_content)], None
- else:
- messages: List[BaseMessage] = []
-
- human_inputs = {
- "query": query
- }
-
- human_message_prompt = ""
-
- if pre_prompt:
- pre_prompt_inputs = {k: inputs[k] for k in
- JinjaPromptTemplate.from_template(template=pre_prompt).input_variables
- if k in inputs}
-
- if pre_prompt_inputs:
- human_inputs.update(pre_prompt_inputs)
-
- if agent_execute_result:
- human_inputs['context'] = agent_execute_result.output
- human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
-
-<context>
-{{context}}
-</context>
-
-When answer to user:
-- If you don't know, just say that you don't know.
-- If you don't know when you are not sure, ask for clarification.
-Avoid mentioning that you obtained the information from the context.
-And answer according to the language of the user's question.
-"""
-
- if pre_prompt:
- human_message_prompt += pre_prompt
-
- query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "
-
- if memory:
- # append chat histories
- tmp_human_message = PromptBuilder.to_human_message(
- prompt_content=human_message_prompt + query_prompt,
- inputs=human_inputs
- )
-
- if memory.model_instance.model_rules.max_tokens.max:
- curr_message_tokens = memory.model_instance.get_num_tokens(to_prompt_messages([tmp_human_message]))
- max_tokens = model.get("completion_params").get('max_tokens')
- rest_tokens = memory.model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
- rest_tokens = max(rest_tokens, 0)
- else:
- rest_tokens = 2000
-
- histories = cls.get_history_messages_from_memory(memory, rest_tokens)
- human_message_prompt += "\n\n" if human_message_prompt else ""
- human_message_prompt += "Here is the chat histories between human and assistant, " \
- "inside <histories></histories> XML tags.\n\n<histories>\n"
- human_message_prompt += histories + "\n</histories>"
-
- human_message_prompt += query_prompt
-
- # construct main prompt
- human_message = PromptBuilder.to_human_message(
- prompt_content=human_message_prompt,
- inputs=human_inputs
- )
-
- messages.append(human_message)
-
- for message in messages:
- message.content = re.sub(r'<\|.*?\|>', '', message.content)
-
- return to_prompt_messages(messages), ['\nHuman:', '</histories>']
-
@classmethod
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
max_token_limit: int) -> str:
@@ -307,13 +199,12 @@ And answer according to the language of the user's question.
max_tokens = 0
# get prompt without memory and context
- prompt_messages, _ = cls.get_main_llm_prompt(
+ prompt_messages, _ = model_instance.get_prompt(
mode=mode,
- model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt,
- query=query,
inputs=inputs,
- agent_execute_result=None,
+ query=query,
+ context=None,
memory=None
)
@@ -358,13 +249,12 @@ And answer according to the language of the user's question.
)
# get llm prompt
- old_prompt_messages, _ = cls.get_main_llm_prompt(
- mode="completion",
- model=app_model_config.model_dict,
+ old_prompt_messages, _ = final_model_instance.get_prompt(
+ mode='completion',
pre_prompt=pre_prompt,
- query=message.query,
inputs=message.inputs,
- agent_execute_result=None,
+ query=message.query,
+ context=None,
memory=None
)
diff --git a/api/core/model_providers/models/llm/base.py b/api/core/model_providers/models/llm/base.py
index fe9240706..8662e7327 100644
--- a/api/core/model_providers/models/llm/base.py
+++ b/api/core/model_providers/models/llm/base.py
@@ -1,17 +1,24 @@
+import json
+import os
+import re
from abc import abstractmethod
-from typing import List, Optional, Any, Union
+from typing import List, Optional, Any, Union, Tuple
import decimal
from langchain.callbacks.manager import Callbacks
+from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult
+from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
+from core.prompt.prompt_builder import PromptBuilder
+from core.prompt.prompt_template import JinjaPromptTemplate
from core.third_party.langchain.llms.fake import FakeLLM
import logging
+
logger = logging.getLogger(__name__)
@@ -76,13 +83,14 @@ class BaseLLM(BaseProviderModel):
def price_config(self) -> dict:
def get_or_default():
default_price_config = {
- 'prompt': decimal.Decimal('0'),
- 'completion': decimal.Decimal('0'),
- 'unit': decimal.Decimal('0'),
- 'currency': 'USD'
- }
+ 'prompt': decimal.Decimal('0'),
+ 'completion': decimal.Decimal('0'),
+ 'unit': decimal.Decimal('0'),
+ 'currency': 'USD'
+ }
rules = self.model_provider.get_rules()
- price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
+ price_config = rules['price_config'][
+ self.base_model_name] if 'price_config' in rules else default_price_config
price_config = {
'prompt': decimal.Decimal(price_config['prompt']),
'completion': decimal.Decimal(price_config['completion']),
@@ -90,7 +98,7 @@ class BaseLLM(BaseProviderModel):
'currency': price_config['currency']
}
return price_config
-
+
self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
logger.debug(f"model: {self.name} price_config: {self._price_config}")
@@ -158,7 +166,8 @@ class BaseLLM(BaseProviderModel):
total_tokens = result.llm_output['token_usage']['total_tokens']
else:
prompt_tokens = self.get_num_tokens(messages)
- completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
+ completion_tokens = self.get_num_tokens(
+ [PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
total_tokens = prompt_tokens + completion_tokens
self.model_provider.update_last_used()
@@ -293,6 +302,119 @@ class BaseLLM(BaseProviderModel):
def support_streaming(cls):
return False
+ def get_prompt(self, mode: str,
+ pre_prompt: str, inputs: dict,
+ query: str,
+ context: Optional[str],
+ memory: Optional[BaseChatMemory]) -> \
+ Tuple[List[PromptMessage], Optional[List[str]]]:
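+ """
+ Build the final prompt messages and stop words for this model.
+
+ Prompt layout is data-driven: a JSON rule file, selected by prompt_file_name(mode)
+ and overridable per provider (see the baichuan-aware models below), defines the
+ context, pre-prompt, histories and query sections and their order.
+ """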
+ prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode))
+ prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
+ return [PromptMessage(content=prompt)], stops
+
+ def prompt_file_name(self, mode: str) -> str:
+ if mode == 'completion':
+ return 'common_completion'
+ else:
+ return 'common_chat'
+
+ def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
+ query: str,
+ context: Optional[str],
+ memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
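+ """
+ Render the context, pre-prompt and query sections declared in prompt_rules,
+ trim chat histories to the remaining token budget when memory is provided,
+ and return the assembled prompt together with any rule-defined stop words.
+ """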
+ context_prompt_content = ''
+ if context and 'context_prompt' in prompt_rules:
+ prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
+ context_prompt_content = prompt_template.format(
+ context=context
+ )
+
+ pre_prompt_content = ''
+ if pre_prompt:
+ prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
+ prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
+ pre_prompt_content = prompt_template.format(
+ **prompt_inputs
+ )
+
+ prompt = ''
+ for order in prompt_rules['system_prompt_orders']:
+ if order == 'context_prompt':
+ prompt += context_prompt_content
+ elif order == 'pre_prompt':
+ prompt += (pre_prompt_content + '\n\n') if pre_prompt_content else ''
+
+ query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'
+
+ if memory and 'histories_prompt' in prompt_rules:
+ # append chat histories
+ tmp_human_message = PromptBuilder.to_human_message(
+ prompt_content=prompt + query_prompt,
+ inputs={
+ 'query': query
+ }
+ )
+
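+ # history budget = model context limit - completion max_tokens - tokens already
+ # consumed by the prompt assembled so far (fallback: 2000 tokens)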
+ if self.model_rules.max_tokens.max:
+ curr_message_tokens = self.get_num_tokens(to_prompt_messages([tmp_human_message]))
+ max_tokens = self.model_kwargs.max_tokens
+ rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
+ rest_tokens = max(rest_tokens, 0)
+ else:
+ rest_tokens = 2000
+
+ memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
+ memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+
+ histories = self._get_history_messages_from_memory(memory, rest_tokens)
+ prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
+ histories_prompt_content = prompt_template.format(
+ histories=histories
+ )
+
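+ # rebuild the system section, this time inserting the rendered histories at the
+ # position declared by system_prompt_orders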
+ prompt = ''
+ for order in prompt_rules['system_prompt_orders']:
+ if order == 'context_prompt':
+ prompt += context_prompt_content
+ elif order == 'pre_prompt':
+ prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
+ elif order == 'histories_prompt':
+ prompt += histories_prompt_content
+
+ prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
+ query_prompt_content = prompt_template.format(
+ query=query
+ )
+
+ prompt += query_prompt_content
+
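+ # drop reserved special tokens of the form <|...|> from the assembled prompt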
+ prompt = re.sub(r'<\|.*?\|>', '', prompt)
+
+ stops = prompt_rules.get('stops')
+ if stops is not None and len(stops) == 0:
+ stops = None
+
+ return prompt, stops
+
+ def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
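+ """
+ Load a prompt rule file from core/prompt/generate_prompts/<prompt_name>.json.
+
+ Expected keys (see the JSON files added in this change): context_prompt,
+ histories_prompt, system_prompt_orders, query_prompt, stops, and the
+ human_prefix/assistant_prefix used for chat histories; chat-only keys are
+ omitted in the *_completion rule files.
+ """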
+ # Get the absolute path of the subdirectory
+ prompt_path = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))),
+ 'prompt/generate_prompts')
+
+ json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
+ # Open the JSON file and read its content
+ with open(json_file_path, 'r') as json_file:
+ return json.load(json_file)
+
+ def _get_history_messages_from_memory(self, memory: BaseChatMemory,
+ max_token_limit: int) -> str:
+ """Get memory messages."""
+ memory.max_token_limit = max_token_limit
+ memory_key = memory.memory_variables[0]
+ external_context = memory.load_memory_variables({})
+ return external_context[memory_key]
+
def _get_prompt_from_messages(self, messages: List[PromptMessage],
model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
if not model_mode:
diff --git a/api/core/model_providers/models/llm/huggingface_hub_model.py b/api/core/model_providers/models/llm/huggingface_hub_model.py
index 7e800e3fe..fb381bf64 100644
--- a/api/core/model_providers/models/llm/huggingface_hub_model.py
+++ b/api/core/model_providers/models/llm/huggingface_hub_model.py
@@ -60,6 +60,15 @@ class HuggingfaceHubModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return self._client.get_num_tokens(prompts)
+ def prompt_file_name(self, mode: str) -> str:
+ if 'baichuan' in self.name.lower():
+ if mode == 'completion':
+ return 'baichuan_completion'
+ else:
+ return 'baichuan_chat'
+ else:
+ return super().prompt_file_name(mode)
+
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
self.client.model_kwargs = provider_model_kwargs
diff --git a/api/core/model_providers/models/llm/openllm_model.py b/api/core/model_providers/models/llm/openllm_model.py
index 2f9876a92..217d893c4 100644
--- a/api/core/model_providers/models/llm/openllm_model.py
+++ b/api/core/model_providers/models/llm/openllm_model.py
@@ -49,6 +49,15 @@ class OpenLLMModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
+ def prompt_file_name(self, mode: str) -> str:
+ if 'baichuan' in self.name.lower():
+ if mode == 'completion':
+ return 'baichuan_completion'
+ else:
+ return 'baichuan_chat'
+ else:
+ return super().prompt_file_name(mode)
+
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
pass
diff --git a/api/core/model_providers/models/llm/xinference_model.py b/api/core/model_providers/models/llm/xinference_model.py
index 8af6356c9..a058a601b 100644
--- a/api/core/model_providers/models/llm/xinference_model.py
+++ b/api/core/model_providers/models/llm/xinference_model.py
@@ -59,6 +59,15 @@ class XinferenceModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
+ def prompt_file_name(self, mode: str) -> str:
+ if 'baichuan' in self.name.lower():
+ if mode == 'completion':
+ return 'baichuan_completion'
+ else:
+ return 'baichuan_chat'
+ else:
+ return super().prompt_file_name(mode)
+
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
pass
diff --git a/api/core/prompt/generate_prompts/baichuan_chat.json b/api/core/prompt/generate_prompts/baichuan_chat.json
new file mode 100644
index 000000000..93b9b7254
--- /dev/null
+++ b/api/core/prompt/generate_prompts/baichuan_chat.json
@@ -0,0 +1,13 @@
+{
+ "human_prefix": "用户",
+ "assistant_prefix": "助手",
+ "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n引用材料\n{{context}}\n```\n\n",
+ "histories_prompt": "用户和助手的历史对话内容如下:\n```\n{{histories}}\n```\n\n",
+ "system_prompt_orders": [
+ "context_prompt",
+ "pre_prompt",
+ "histories_prompt"
+ ],
+ "query_prompt": "用户:{{query}}\n助手:",
+ "stops": ["用户:"]
+}
\ No newline at end of file
diff --git a/api/core/prompt/generate_prompts/baichuan_completion.json b/api/core/prompt/generate_prompts/baichuan_completion.json
new file mode 100644
index 000000000..bdc7cf976
--- /dev/null
+++ b/api/core/prompt/generate_prompts/baichuan_completion.json
@@ -0,0 +1,9 @@
+{
+ "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n引用材料\n{{context}}\n```\n",
+ "system_prompt_orders": [
+ "context_prompt",
+ "pre_prompt"
+ ],
+ "query_prompt": "{{query}}",
+ "stops": null
+}
\ No newline at end of file
diff --git a/api/core/prompt/generate_prompts/common_chat.json b/api/core/prompt/generate_prompts/common_chat.json
new file mode 100644
index 000000000..baa000a7a
--- /dev/null
+++ b/api/core/prompt/generate_prompts/common_chat.json
@@ -0,0 +1,13 @@
+{
+ "human_prefix": "Human",
+ "assistant_prefix": "Assistant",
+ "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{context}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
+ "histories_prompt": "Here is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{histories}}\n</histories>\n\n",
+ "system_prompt_orders": [
+ "context_prompt",
+ "pre_prompt",
+ "histories_prompt"
+ ],
+ "query_prompt": "Human: {{query}}\n\nAssistant: ",
+ "stops": ["\nHuman:", "</histories>"]
+}
\ No newline at end of file
diff --git a/api/core/prompt/generate_prompts/common_completion.json b/api/core/prompt/generate_prompts/common_completion.json
new file mode 100644
index 000000000..9e7e8d68e
--- /dev/null
+++ b/api/core/prompt/generate_prompts/common_completion.json
@@ -0,0 +1,9 @@
+{
+ "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{context}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
+ "system_prompt_orders": [
+ "context_prompt",
+ "pre_prompt"
+ ],
+ "query_prompt": "{{query}}",
+ "stops": null
+}
\ No newline at end of file
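Reviewer note: below is a minimal, self-contained sketch of how the new rule files are intended to compose a prompt. It inlines rules mirroring common_chat.json (context/histories text abridged), substitutes variables with plain str.replace instead of JinjaPromptTemplate/PromptBuilder, and skips the token-budget trimming done in BaseLLM._get_prompt_and_stop, so it illustrates the data flow rather than the actual implementation; the build_prompt helper and the sample inputs are invented for this sketch.

import json

# Rules mirroring api/core/prompt/generate_prompts/common_chat.json (prompt text abridged).
COMMON_CHAT_RULES = json.loads("""
{
  "human_prefix": "Human",
  "assistant_prefix": "Assistant",
  "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\\n\\n<context>\\n{{context}}\\n</context>\\n\\n",
  "histories_prompt": "Here is the chat histories between human and assistant, inside <histories></histories> XML tags.\\n\\n<histories>\\n{{histories}}\\n</histories>\\n\\n",
  "system_prompt_orders": ["context_prompt", "pre_prompt", "histories_prompt"],
  "query_prompt": "Human: {{query}}\\n\\nAssistant: ",
  "stops": ["\\nHuman:", "</histories>"]
}
""")


def build_prompt(rules: dict, pre_prompt: str, context: str, histories: str, query: str) -> tuple:
    """Hypothetical helper: assemble the sections in the order declared by system_prompt_orders."""
    sections = {
        'context_prompt': rules['context_prompt'].replace('{{context}}', context) if context else '',
        'pre_prompt': (pre_prompt + '\n') if pre_prompt else '',
        'histories_prompt': rules['histories_prompt'].replace('{{histories}}', histories) if histories else '',
    }
    prompt = ''.join(sections[order] for order in rules['system_prompt_orders'])
    prompt += rules['query_prompt'].replace('{{query}}', query)
    stops = rules.get('stops') or None  # an empty list means "no stop words"
    return prompt, stops


if __name__ == '__main__':
    prompt, stops = build_prompt(
        COMMON_CHAT_RULES,
        pre_prompt='You are a helpful translation assistant.',
        context='Dify supports dozens of model providers.',
        histories='Human: hi\nAssistant: Hello!',
        query='Translate "hello" to French.',
    )
    print(prompt)
    print(stops)  # ['\nHuman:', '</histories>']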