chore(api/core): apply ruff reformatting (#7624)
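This is a formatting-only change: ruff's formatter normalizes string quoting, line wrapping, and trailing commas across api/core, and is intended to be behavior-preserving. A minimal sketch of the kind of rewrite the hunks below show (hypothetical function name, assuming ruff's default Black-compatible 88-column style):

from typing import Optional

# Before the formatter pass: single quotes, hanging indent, no trailing comma.
def get_prompt_messages(prompt_template: str, query: Optional[str],
                        context: Optional[str] = None) -> list[str]:
    return [prompt_template, query or '', context or '']

# After the pass: double quotes, one parameter per line (the one-line form would
# exceed the assumed 88-character limit), and a trailing comma on the last parameter.
def get_prompt_messages(
    prompt_template: str,
    query: Optional[str],
    context: Optional[str] = None,
) -> list[str]:
    return [prompt_template, query or "", context or ""]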
@@ -22,18 +22,22 @@ class AdvancedPromptTransform(PromptTransform):
"""
Advanced Prompt Transform for Workflow LLM Node.
"""

def __init__(self, with_variable_tmpl: bool = False) -> None:
self.with_variable_tmpl = with_variable_tmpl

def get_prompt(self, prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate],
inputs: dict,
query: str,
files: list[FileVar],
context: Optional[str],
memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
query_prompt_template: Optional[str] = None) -> list[PromptMessage]:
def get_prompt(
self,
prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate],
inputs: dict,
query: str,
files: list[FileVar],
context: Optional[str],
memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
query_prompt_template: Optional[str] = None,
) -> list[PromptMessage]:
inputs = {key: str(value) for key, value in inputs.items()}

prompt_messages = []

@@ -48,7 +52,7 @@ class AdvancedPromptTransform(PromptTransform):
context=context,
memory_config=memory_config,
memory=memory,
model_config=model_config
model_config=model_config,
)
elif model_mode == ModelMode.CHAT:
prompt_messages = self._get_chat_model_prompt_messages(

@@ -60,20 +64,22 @@ class AdvancedPromptTransform(PromptTransform):
context=context,
memory_config=memory_config,
memory=memory,
model_config=model_config
model_config=model_config,
)

return prompt_messages

def _get_completion_model_prompt_messages(self,
prompt_template: CompletionModelPromptTemplate,
inputs: dict,
query: Optional[str],
files: list[FileVar],
context: Optional[str],
memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
def _get_completion_model_prompt_messages(
self,
prompt_template: CompletionModelPromptTemplate,
inputs: dict,
query: Optional[str],
files: list[FileVar],
context: Optional[str],
memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
) -> list[PromptMessage]:
"""
Get completion model prompt messages.
"""

@@ -81,7 +87,7 @@ class AdvancedPromptTransform(PromptTransform):

prompt_messages = []

if prompt_template.edition_type == 'basic' or not prompt_template.edition_type:
if prompt_template.edition_type == "basic" or not prompt_template.edition_type:
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

@@ -96,15 +102,13 @@ class AdvancedPromptTransform(PromptTransform):
role_prefix=role_prefix,
prompt_template=prompt_template,
prompt_inputs=prompt_inputs,
model_config=model_config
model_config=model_config,
)

if query:
prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)

prompt = prompt_template.format(
prompt_inputs
)
prompt = prompt_template.format(prompt_inputs)
else:
prompt = raw_prompt
prompt_inputs = inputs

@@ -122,16 +126,18 @@ class AdvancedPromptTransform(PromptTransform):

return prompt_messages

def _get_chat_model_prompt_messages(self,
prompt_template: list[ChatModelMessage],
inputs: dict,
query: Optional[str],
files: list[FileVar],
context: Optional[str],
memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
query_prompt_template: Optional[str] = None) -> list[PromptMessage]:
def _get_chat_model_prompt_messages(
self,
prompt_template: list[ChatModelMessage],
inputs: dict,
query: Optional[str],
files: list[FileVar],
context: Optional[str],
memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
query_prompt_template: Optional[str] = None,
) -> list[PromptMessage]:
"""
Get chat model prompt messages.
"""

@@ -142,22 +148,20 @@ class AdvancedPromptTransform(PromptTransform):
for prompt_item in raw_prompt_list:
raw_prompt = prompt_item.text

if prompt_item.edition_type == 'basic' or not prompt_item.edition_type:
if prompt_item.edition_type == "basic" or not prompt_item.edition_type:
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)

prompt = prompt_template.format(
prompt_inputs
)
elif prompt_item.edition_type == 'jinja2':
prompt = prompt_template.format(prompt_inputs)
elif prompt_item.edition_type == "jinja2":
prompt = raw_prompt
prompt_inputs = inputs

prompt = Jinja2Formatter.format(prompt, prompt_inputs)
else:
raise ValueError(f'Invalid edition type: {prompt_item.edition_type}')
raise ValueError(f"Invalid edition type: {prompt_item.edition_type}")

if prompt_item.role == PromptMessageRole.USER:
prompt_messages.append(UserPromptMessage(content=prompt))

@@ -168,17 +172,14 @@ class AdvancedPromptTransform(PromptTransform):

if query and query_prompt_template:
prompt_template = PromptTemplateParser(
template=query_prompt_template,
with_variable_tmpl=self.with_variable_tmpl
template=query_prompt_template, with_variable_tmpl=self.with_variable_tmpl
)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
prompt_inputs['#sys.query#'] = query
prompt_inputs["#sys.query#"] = query

prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)

query = prompt_template.format(
prompt_inputs
)
query = prompt_template.format(prompt_inputs)

if memory and memory_config:
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)

@@ -203,7 +204,7 @@ class AdvancedPromptTransform(PromptTransform):

last_message.content = prompt_message_contents
else:
prompt_message_contents = [TextPromptMessageContent(data='')] # not for query
prompt_message_contents = [TextPromptMessageContent(data="")] # not for query
for file in files:
prompt_message_contents.append(file.prompt_message_content)

@@ -220,38 +221,39 @@ class AdvancedPromptTransform(PromptTransform):
return prompt_messages

def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict:
if '#context#' in prompt_template.variable_keys:
if "#context#" in prompt_template.variable_keys:
if context:
prompt_inputs['#context#'] = context
prompt_inputs["#context#"] = context
else:
prompt_inputs['#context#'] = ''
prompt_inputs["#context#"] = ""

return prompt_inputs

def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict:
if '#query#' in prompt_template.variable_keys:
if "#query#" in prompt_template.variable_keys:
if query:
prompt_inputs['#query#'] = query
prompt_inputs["#query#"] = query
else:
prompt_inputs['#query#'] = ''
prompt_inputs["#query#"] = ""

return prompt_inputs

def _set_histories_variable(self, memory: TokenBufferMemory,
memory_config: MemoryConfig,
raw_prompt: str,
role_prefix: MemoryConfig.RolePrefix,
prompt_template: PromptTemplateParser,
prompt_inputs: dict,
model_config: ModelConfigWithCredentialsEntity) -> dict:
if '#histories#' in prompt_template.variable_keys:
def _set_histories_variable(
self,
memory: TokenBufferMemory,
memory_config: MemoryConfig,
raw_prompt: str,
role_prefix: MemoryConfig.RolePrefix,
prompt_template: PromptTemplateParser,
prompt_inputs: dict,
model_config: ModelConfigWithCredentialsEntity,
) -> dict:
if "#histories#" in prompt_template.variable_keys:
if memory:
inputs = {'#histories#': '', **prompt_inputs}
inputs = {"#histories#": "", **prompt_inputs}
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
tmp_human_message = UserPromptMessage(
content=prompt_template.format(prompt_inputs)
)
tmp_human_message = UserPromptMessage(content=prompt_template.format(prompt_inputs))

rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)

@@ -260,10 +262,10 @@ class AdvancedPromptTransform(PromptTransform):
memory_config=memory_config,
max_token_limit=rest_tokens,
human_prefix=role_prefix.user,
ai_prefix=role_prefix.assistant
ai_prefix=role_prefix.assistant,
)
prompt_inputs['#histories#'] = histories
prompt_inputs["#histories#"] = histories
else:
prompt_inputs['#histories#'] = ''
prompt_inputs["#histories#"] = ""

return prompt_inputs
@@ -17,12 +17,14 @@ class AgentHistoryPromptTransform(PromptTransform):
"""
History Prompt Transform for Agent App
"""
def __init__(self,
model_config: ModelConfigWithCredentialsEntity,
prompt_messages: list[PromptMessage],
history_messages: list[PromptMessage],
memory: Optional[TokenBufferMemory] = None,
):

def __init__(
self,
model_config: ModelConfigWithCredentialsEntity,
prompt_messages: list[PromptMessage],
history_messages: list[PromptMessage],
memory: Optional[TokenBufferMemory] = None,
):
self.model_config = model_config
self.prompt_messages = prompt_messages
self.history_messages = history_messages

@@ -45,9 +47,7 @@ class AgentHistoryPromptTransform(PromptTransform):
model_type_instance = cast(LargeLanguageModel, model_type_instance)

curr_message_tokens = model_type_instance.get_num_tokens(
self.memory.model_instance.model,
self.memory.model_instance.credentials,
self.history_messages
self.memory.model_instance.model, self.memory.model_instance.credentials, self.history_messages
)
if curr_message_tokens <= max_token_limit:
return self.history_messages

@@ -63,9 +63,7 @@ class AgentHistoryPromptTransform(PromptTransform):
# a message is start with UserPromptMessage
if isinstance(prompt_message, UserPromptMessage):
curr_message_tokens = model_type_instance.get_num_tokens(
self.memory.model_instance.model,
self.memory.model_instance.credentials,
prompt_messages
self.memory.model_instance.model, self.memory.model_instance.credentials, prompt_messages
)
# if current message token is overflow, drop all the prompts in current message and break
if curr_message_tokens > max_token_limit:
@@ -9,27 +9,31 @@ class ChatModelMessage(BaseModel):
"""
Chat Message.
"""

text: str
role: PromptMessageRole
edition_type: Optional[Literal['basic', 'jinja2']] = None
edition_type: Optional[Literal["basic", "jinja2"]] = None


class CompletionModelPromptTemplate(BaseModel):
"""
Completion Model Prompt Template.
"""

text: str
edition_type: Optional[Literal['basic', 'jinja2']] = None
edition_type: Optional[Literal["basic", "jinja2"]] = None


class MemoryConfig(BaseModel):
"""
Memory Config.
"""

class RolePrefix(BaseModel):
"""
Role Prefix.
"""

user: str
assistant: str

@@ -37,6 +41,7 @@ class MemoryConfig(BaseModel):
"""
Window Config.
"""

enabled: bool
size: Optional[int] = None
@@ -7,39 +7,18 @@ CHAT_APP_COMPLETION_PROMPT_CONFIG = {
"prompt": {
"text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant: "
},
"conversation_histories_role": {
"user_prefix": "Human",
"assistant_prefix": "Assistant"
}
"conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
},
"stop": ["Human:"]
"stop": ["Human:"],
}

CHAT_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "system",
"text": "{{#pre_prompt#}}"
}]
}
}
CHAT_APP_CHAT_PROMPT_CONFIG = {"chat_prompt_config": {"prompt": [{"role": "system", "text": "{{#pre_prompt#}}"}]}}

COMPLETION_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "user",
"text": "{{#pre_prompt#}}"
}]
}
}
COMPLETION_APP_CHAT_PROMPT_CONFIG = {"chat_prompt_config": {"prompt": [{"role": "user", "text": "{{#pre_prompt#}}"}]}}

COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {
"prompt": {
"text": "{{#pre_prompt#}}"
}
},
"stop": ["Human:"]
"completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}},
"stop": ["Human:"],
}

BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = {

@@ -47,37 +26,20 @@ BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = {
"prompt": {
"text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}"
},
"conversation_histories_role": {
"user_prefix": "用户",
"assistant_prefix": "助手"
}
"conversation_histories_role": {"user_prefix": "用户", "assistant_prefix": "助手"},
},
"stop": ["用户:"]
"stop": ["用户:"],
}

BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "system",
"text": "{{#pre_prompt#}}"
}]
}
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {"prompt": [{"role": "system", "text": "{{#pre_prompt#}}"}]}
}

BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "user",
"text": "{{#pre_prompt#}}"
}]
}
"chat_prompt_config": {"prompt": [{"role": "user", "text": "{{#pre_prompt#}}"}]}
}

BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {
"prompt": {
"text": "{{#pre_prompt#}}"
}
},
"stop": ["用户:"]
"completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}},
"stop": ["用户:"],
}
@@ -9,75 +9,78 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig


class PromptTransform:
def _append_chat_histories(self, memory: TokenBufferMemory,
memory_config: MemoryConfig,
prompt_messages: list[PromptMessage],
model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
def _append_chat_histories(
self,
memory: TokenBufferMemory,
memory_config: MemoryConfig,
prompt_messages: list[PromptMessage],
model_config: ModelConfigWithCredentialsEntity,
) -> list[PromptMessage]:
rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
prompt_messages.extend(histories)

return prompt_messages

def _calculate_rest_token(self, prompt_messages: list[PromptMessage],
model_config: ModelConfigWithCredentialsEntity) -> int:
def _calculate_rest_token(
self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> int:
rest_tokens = 2000

model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)

curr_message_tokens = model_instance.get_llm_num_tokens(
prompt_messages
)
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)

max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
) or 0

rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)

return rest_tokens

def _get_history_messages_from_memory(self, memory: TokenBufferMemory,
memory_config: MemoryConfig,
max_token_limit: int,
human_prefix: Optional[str] = None,
ai_prefix: Optional[str] = None) -> str:
def _get_history_messages_from_memory(
self,
memory: TokenBufferMemory,
memory_config: MemoryConfig,
max_token_limit: int,
human_prefix: Optional[str] = None,
ai_prefix: Optional[str] = None,
) -> str:
"""Get memory messages."""
kwargs = {
"max_token_limit": max_token_limit
}
kwargs = {"max_token_limit": max_token_limit}

if human_prefix:
kwargs['human_prefix'] = human_prefix
kwargs["human_prefix"] = human_prefix

if ai_prefix:
kwargs['ai_prefix'] = ai_prefix
kwargs["ai_prefix"] = ai_prefix

if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0:
kwargs['message_limit'] = memory_config.window.size
kwargs["message_limit"] = memory_config.window.size

return memory.get_history_prompt_text(
**kwargs
)
return memory.get_history_prompt_text(**kwargs)

def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory,
memory_config: MemoryConfig,
max_token_limit: int) -> list[PromptMessage]:
def _get_history_messages_list_from_memory(
self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
) -> list[PromptMessage]:
"""Get memory messages."""
return memory.get_history_prompt_messages(
max_token_limit=max_token_limit,
message_limit=memory_config.window.size
if (memory_config.window.enabled
and memory_config.window.size is not None
and memory_config.window.size > 0)
else None
if (
memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0
)
else None,
)
@@ -22,11 +22,11 @@ if TYPE_CHECKING:


class ModelMode(enum.Enum):
COMPLETION = 'completion'
CHAT = 'chat'
COMPLETION = "completion"
CHAT = "chat"

@classmethod
def value_of(cls, value: str) -> 'ModelMode':
def value_of(cls, value: str) -> "ModelMode":
"""
Get value of given mode.

@@ -36,7 +36,7 @@ class ModelMode(enum.Enum):
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
raise ValueError(f"invalid mode value {value}")


prompt_file_contents = {}
@@ -47,16 +47,17 @@ class SimplePromptTransform(PromptTransform):
Simple Prompt Transform for Chatbot App Basic Mode.
"""

def get_prompt(self,
app_mode: AppMode,
prompt_template_entity: PromptTemplateEntity,
inputs: dict,
query: str,
files: list["FileVar"],
context: Optional[str],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) -> \
tuple[list[PromptMessage], Optional[list[str]]]:
def get_prompt(
self,
app_mode: AppMode,
prompt_template_entity: PromptTemplateEntity,
inputs: dict,
query: str,
files: list["FileVar"],
context: Optional[str],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
inputs = {key: str(value) for key, value in inputs.items()}

model_mode = ModelMode.value_of(model_config.mode)

@@ -69,7 +70,7 @@ class SimplePromptTransform(PromptTransform):
files=files,
context=context,
memory=memory,
model_config=model_config
model_config=model_config,
)
else:
prompt_messages, stops = self._get_completion_model_prompt_messages(

@@ -80,19 +81,21 @@ class SimplePromptTransform(PromptTransform):
files=files,
context=context,
memory=memory,
model_config=model_config
model_config=model_config,
)

return prompt_messages, stops

def get_prompt_str_and_rules(self, app_mode: AppMode,
model_config: ModelConfigWithCredentialsEntity,
pre_prompt: str,
inputs: dict,
query: Optional[str] = None,
context: Optional[str] = None,
histories: Optional[str] = None,
) -> tuple[str, dict]:
def get_prompt_str_and_rules(
self,
app_mode: AppMode,
model_config: ModelConfigWithCredentialsEntity,
pre_prompt: str,
inputs: dict,
query: Optional[str] = None,
context: Optional[str] = None,
histories: Optional[str] = None,
) -> tuple[str, dict]:
# get prompt template
prompt_template_config = self.get_prompt_template(
app_mode=app_mode,

@@ -101,74 +104,75 @@ class SimplePromptTransform(PromptTransform):
pre_prompt=pre_prompt,
has_context=context is not None,
query_in_prompt=query is not None,
with_memory_prompt=histories is not None
with_memory_prompt=histories is not None,
)

variables = {k: inputs[k] for k in prompt_template_config['custom_variable_keys'] if k in inputs}
variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs}

for v in prompt_template_config['special_variable_keys']:
for v in prompt_template_config["special_variable_keys"]:
# support #context#, #query# and #histories#
if v == '#context#':
variables['#context#'] = context if context else ''
elif v == '#query#':
variables['#query#'] = query if query else ''
elif v == '#histories#':
variables['#histories#'] = histories if histories else ''
if v == "#context#":
variables["#context#"] = context if context else ""
elif v == "#query#":
variables["#query#"] = query if query else ""
elif v == "#histories#":
variables["#histories#"] = histories if histories else ""

prompt_template = prompt_template_config['prompt_template']
prompt_template = prompt_template_config["prompt_template"]
prompt = prompt_template.format(variables)

return prompt, prompt_template_config['prompt_rules']
return prompt, prompt_template_config["prompt_rules"]

def get_prompt_template(self, app_mode: AppMode,
provider: str,
model: str,
pre_prompt: str,
has_context: bool,
query_in_prompt: bool,
with_memory_prompt: bool = False) -> dict:
prompt_rules = self._get_prompt_rule(
app_mode=app_mode,
provider=provider,
model=model
)
def get_prompt_template(
self,
app_mode: AppMode,
provider: str,
model: str,
pre_prompt: str,
has_context: bool,
query_in_prompt: bool,
with_memory_prompt: bool = False,
) -> dict:
prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model)

custom_variable_keys = []
special_variable_keys = []

prompt = ''
for order in prompt_rules['system_prompt_orders']:
if order == 'context_prompt' and has_context:
prompt += prompt_rules['context_prompt']
special_variable_keys.append('#context#')
elif order == 'pre_prompt' and pre_prompt:
prompt += pre_prompt + '\n'
prompt = ""
for order in prompt_rules["system_prompt_orders"]:
if order == "context_prompt" and has_context:
prompt += prompt_rules["context_prompt"]
special_variable_keys.append("#context#")
elif order == "pre_prompt" and pre_prompt:
prompt += pre_prompt + "\n"
pre_prompt_template = PromptTemplateParser(template=pre_prompt)
custom_variable_keys = pre_prompt_template.variable_keys
elif order == 'histories_prompt' and with_memory_prompt:
prompt += prompt_rules['histories_prompt']
special_variable_keys.append('#histories#')
elif order == "histories_prompt" and with_memory_prompt:
prompt += prompt_rules["histories_prompt"]
special_variable_keys.append("#histories#")

if query_in_prompt:
prompt += prompt_rules.get('query_prompt', '{{#query#}}')
special_variable_keys.append('#query#')
prompt += prompt_rules.get("query_prompt", "{{#query#}}")
special_variable_keys.append("#query#")

return {
"prompt_template": PromptTemplateParser(template=prompt),
"custom_variable_keys": custom_variable_keys,
"special_variable_keys": special_variable_keys,
"prompt_rules": prompt_rules
"prompt_rules": prompt_rules,
}
def _get_chat_model_prompt_messages(self, app_mode: AppMode,
pre_prompt: str,
inputs: dict,
query: str,
context: Optional[str],
files: list["FileVar"],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) \
-> tuple[list[PromptMessage], Optional[list[str]]]:
def _get_chat_model_prompt_messages(
self,
app_mode: AppMode,
pre_prompt: str,
inputs: dict,
query: str,
context: Optional[str],
files: list["FileVar"],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
prompt_messages = []

# get prompt

@@ -178,7 +182,7 @@ class SimplePromptTransform(PromptTransform):
pre_prompt=pre_prompt,
inputs=inputs,
query=None,
context=context
context=context,
)

if prompt and query:

@@ -193,7 +197,7 @@ class SimplePromptTransform(PromptTransform):
)
),
prompt_messages=prompt_messages,
model_config=model_config
model_config=model_config,
)

if query:

@@ -203,15 +207,17 @@ class SimplePromptTransform(PromptTransform):

return prompt_messages, None

def _get_completion_model_prompt_messages(self, app_mode: AppMode,
pre_prompt: str,
inputs: dict,
query: str,
context: Optional[str],
files: list["FileVar"],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) \
-> tuple[list[PromptMessage], Optional[list[str]]]:
def _get_completion_model_prompt_messages(
self,
app_mode: AppMode,
pre_prompt: str,
inputs: dict,
query: str,
context: Optional[str],
files: list["FileVar"],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
# get prompt
prompt, prompt_rules = self.get_prompt_str_and_rules(
app_mode=app_mode,

@@ -219,13 +225,11 @@ class SimplePromptTransform(PromptTransform):
pre_prompt=pre_prompt,
inputs=inputs,
query=query,
context=context
context=context,
)

if memory:
tmp_human_message = UserPromptMessage(
content=prompt
)
tmp_human_message = UserPromptMessage(content=prompt)

rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
histories = self._get_history_messages_from_memory(

@@ -236,8 +240,8 @@ class SimplePromptTransform(PromptTransform):
)
),
max_token_limit=rest_tokens,
human_prefix=prompt_rules.get('human_prefix', 'Human'),
ai_prefix=prompt_rules.get('assistant_prefix', 'Assistant')
human_prefix=prompt_rules.get("human_prefix", "Human"),
ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
)

# get prompt

@@ -248,10 +252,10 @@ class SimplePromptTransform(PromptTransform):
inputs=inputs,
query=query,
context=context,
histories=histories
histories=histories,
)

stops = prompt_rules.get('stops')
stops = prompt_rules.get("stops")
if stops is not None and len(stops) == 0:
stops = None

@@ -277,22 +281,18 @@ class SimplePromptTransform(PromptTransform):
:param model: model name
:return:
"""
prompt_file_name = self._prompt_file_name(
app_mode=app_mode,
provider=provider,
model=model
)
prompt_file_name = self._prompt_file_name(app_mode=app_mode, provider=provider, model=model)

# Check if the prompt file is already loaded
if prompt_file_name in prompt_file_contents:
return prompt_file_contents[prompt_file_name]

# Get the absolute path of the subdirectory
prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'prompt_templates')
json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json')
prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates")
json_file_path = os.path.join(prompt_path, f"{prompt_file_name}.json")

# Open the JSON file and read its content
with open(json_file_path, encoding='utf-8') as json_file:
with open(json_file_path, encoding="utf-8") as json_file:
content = json.load(json_file)

# Store the content of the prompt file

@@ -303,21 +303,21 @@ class SimplePromptTransform(PromptTransform):
def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
# baichuan
is_baichuan = False
if provider == 'baichuan':
if provider == "baichuan":
is_baichuan = True
else:
baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"]
if provider in baichuan_supported_providers and 'baichuan' in model.lower():
if provider in baichuan_supported_providers and "baichuan" in model.lower():
is_baichuan = True

if is_baichuan:
if app_mode == AppMode.COMPLETION:
return 'baichuan_completion'
return "baichuan_completion"
else:
return 'baichuan_chat'
return "baichuan_chat"

# common
if app_mode == AppMode.COMPLETION:
return 'common_completion'
return "common_completion"
else:
return 'common_chat'
return "common_chat"
@@ -25,26 +25,29 @@ class PromptMessageUtil:
tool_calls = []
for prompt_message in prompt_messages:
if prompt_message.role == PromptMessageRole.USER:
role = 'user'
role = "user"
elif prompt_message.role == PromptMessageRole.ASSISTANT:
role = 'assistant'
role = "assistant"
if isinstance(prompt_message, AssistantPromptMessage):
tool_calls = [{
'id': tool_call.id,
'type': 'function',
'function': {
'name': tool_call.function.name,
'arguments': tool_call.function.arguments,
tool_calls = [
{
"id": tool_call.id,
"type": "function",
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
},
}
} for tool_call in prompt_message.tool_calls]
for tool_call in prompt_message.tool_calls
]
elif prompt_message.role == PromptMessageRole.SYSTEM:
role = 'system'
role = "system"
elif prompt_message.role == PromptMessageRole.TOOL:
role = 'tool'
role = "tool"
else:
continue

text = ''
text = ""
files = []
if isinstance(prompt_message.content, list):
for content in prompt_message.content:

@@ -53,27 +56,25 @@ class PromptMessageUtil:
text += content.data
else:
content = cast(ImagePromptMessageContent, content)
files.append({
"type": 'image',
"data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
"detail": content.detail.value
})
files.append(
{
"type": "image",
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
"detail": content.detail.value,
}
)
else:
text = prompt_message.content

prompt = {
"role": role,
"text": text,
"files": files
}

prompt = {"role": role, "text": text, "files": files}

if tool_calls:
prompt['tool_calls'] = tool_calls
prompt["tool_calls"] = tool_calls

prompts.append(prompt)
else:
prompt_message = prompt_messages[0]
text = ''
text = ""
files = []
if isinstance(prompt_message.content, list):
for content in prompt_message.content:

@@ -82,21 +83,23 @@ class PromptMessageUtil:
text += content.data
else:
content = cast(ImagePromptMessageContent, content)
files.append({
"type": 'image',
"data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
"detail": content.detail.value
})
files.append(
{
"type": "image",
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
"detail": content.detail.value,
}
)
else:
text = prompt_message.content

params = {
"role": 'user',
"role": "user",
"text": text,
}

if files:
params['files'] = files
params["files"] = files

prompts.append(params)
@@ -38,8 +38,8 @@ class PromptTemplateParser:
return value

prompt = re.sub(self.regex, replacer, self.template)
return re.sub(r'<\|.*?\|>', '', prompt)
return re.sub(r"<\|.*?\|>", "", prompt)

@classmethod
def remove_template_variables(cls, text: str, with_variable_tmpl: bool = False):
return re.sub(WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX, r'{\1}', text)
return re.sub(WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX, r"{\1}", text)