feat: support LLM jinja2 template prompt (#3968)

Author: Yeuoly
Co-authored-by: Joel <iamjoel007@gmail.com>
Committed: 2024-05-10 18:08:32 +08:00 (via GitHub)
Parent: 897e07f639
Commit: 8578ee0864

8 changed files with 325 additions and 41 deletions

core/workflow/nodes/llm/entities.py

@@ -4,6 +4,7 @@ from pydantic import BaseModel
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
 from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.variable_entities import VariableSelector
 
 
 class ModelConfig(BaseModel):
@@ -37,13 +38,31 @@ class VisionConfig(BaseModel):
     enabled: bool
     configs: Optional[Configs] = None
 
 
+class PromptConfig(BaseModel):
+    """
+    Prompt Config.
+    """
+    jinja2_variables: Optional[list[VariableSelector]] = None
+
+
+class LLMNodeChatModelMessage(ChatModelMessage):
+    """
+    LLM Node Chat Model Message.
+    """
+    jinja2_text: Optional[str] = None
+
+
+class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
+    """
+    LLM Node Completion Model Prompt Template.
+    """
+    jinja2_text: Optional[str] = None
+
+
 class LLMNodeData(BaseNodeData):
     """
     LLM Node Data.
     """
     model: ModelConfig
-    prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
+    prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate]
+    prompt_config: Optional[PromptConfig] = None
     memory: Optional[MemoryConfig] = None
     context: ContextConfig
     vision: VisionConfig
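
Note: a minimal sketch (not part of the diff) of how the new entities compose. The role/text/edition_type fields on ChatModelMessage are inferred from their usage elsewhere in this commit; the selector values are hypothetical.

from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, PromptConfig

# a chat message whose prompt is authored as a jinja2 template
message = LLMNodeChatModelMessage(
    role='user',
    text='',                                 # used when edition_type == 'basic'
    edition_type='jinja2',                   # selects the jinja2 rendering path
    jinja2_text='Summarize: {{ article }}',  # jinja2 source kept alongside the basic text
)

# declares which variable-pool value feeds each jinja2 template name
prompt_config = PromptConfig(
    jinja2_variables=[
        VariableSelector(variable='article', value_selector=['upstream_node', 'text']),
    ],
)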

core/workflow/nodes/llm/llm_node.py

@@ -1,4 +1,6 @@
+import json
 from collections.abc import Generator
+from copy import deepcopy
 from typing import Optional, cast
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@@ -17,11 +19,15 @@ from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
-from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
-from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
+from core.workflow.nodes.llm.entities import (
+    LLMNodeChatModelMessage,
+    LLMNodeCompletionModelPromptTemplate,
+    LLMNodeData,
+    ModelConfig,
+)
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from extensions.ext_database import db
 from models.model import Conversation
@@ -39,16 +45,24 @@ class LLMNode(BaseNode):
         :param variable_pool: variable pool
         :return:
         """
-        node_data = self.node_data
-        node_data = cast(self._node_data_cls, node_data)
+        node_data = cast(LLMNodeData, deepcopy(self.node_data))
 
         node_inputs = None
         process_data = None
 
         try:
+            # init messages template
+            node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template)
+
             # fetch variables and fetch values from variable pool
             inputs = self._fetch_inputs(node_data, variable_pool)
 
+            # fetch jinja2 inputs
+            jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool)
+
+            # merge inputs
+            inputs.update(jinja_inputs)
+
             node_inputs = {}
 
             # fetch files
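
A note on the merge above: inputs.update(jinja_inputs) is a plain dict update, so on a name collision the jinja2-fetched value wins. A runnable sketch with hypothetical values:

inputs = {'query': 'hello', 'article': 'value from basic variables'}
jinja_inputs = {'article': 'value from jinja2_variables'}
inputs.update(jinja_inputs)
# the jinja2 value replaces the basic one for the shared key
assert inputs == {'query': 'hello', 'article': 'value from jinja2_variables'}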
@@ -183,6 +197,86 @@ class LLMNode(BaseNode):
             usage = LLMUsage.empty_usage()
 
         return full_text, usage
 
+    def _transform_chat_messages(self,
+        messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
+    ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
+        """
+        Transform chat messages
+
+        :param messages: chat messages
+        :return:
+        """
+        if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
+            if messages.edition_type == 'jinja2':
+                messages.text = messages.jinja2_text
+
+            return messages
+
+        for message in messages:
+            if message.edition_type == 'jinja2':
+                message.text = message.jinja2_text
+
+        return messages
+
+    def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
+        """
+        Fetch jinja inputs
+        :param node_data: node data
+        :param variable_pool: variable pool
+        :return:
+        """
+        variables = {}
+        if not node_data.prompt_config:
+            return variables
+
+        for variable_selector in node_data.prompt_config.jinja2_variables or []:
+            variable = variable_selector.variable
+            value = variable_pool.get_variable_value(
+                variable_selector=variable_selector.value_selector
+            )
+
+            def parse_dict(d: dict) -> str:
+                """
+                Parse dict into string
+                """
+                # check if it's a context structure
+                if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
+                    return d['content']
+
+                # else, parse the dict
+                try:
+                    return json.dumps(d, ensure_ascii=False)
+                except Exception:
+                    return str(d)
+
+            if isinstance(value, str):
+                value = value
+            elif isinstance(value, list):
+                result = ''
+                for item in value:
+                    if isinstance(item, dict):
+                        result += parse_dict(item)
+                    elif isinstance(item, str):
+                        result += item
+                    elif isinstance(item, int | float):
+                        result += str(item)
+                    else:
+                        result += str(item)
+                    result += '\n'
+                value = result.strip()
+            elif isinstance(value, dict):
+                value = parse_dict(value)
+            elif isinstance(value, int | float):
+                value = str(value)
+            else:
+                value = str(value)
+
+            variables[variable] = value
+
+        return variables
+
     def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
         """
@@ -531,25 +625,25 @@ class LLMNode(BaseNode):
         db.session.commit()
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]:
         """
         Extract variable selector to variable mapping
         :param node_data: node data
         :return:
         """
-        node_data = node_data
-        node_data = cast(cls._node_data_cls, node_data)
 
         prompt_template = node_data.prompt_template
 
         variable_selectors = []
         if isinstance(prompt_template, list):
             for prompt in prompt_template:
-                variable_template_parser = VariableTemplateParser(template=prompt.text)
-                variable_selectors.extend(variable_template_parser.extract_variable_selectors())
+                if prompt.edition_type != 'jinja2':
+                    variable_template_parser = VariableTemplateParser(template=prompt.text)
+                    variable_selectors.extend(variable_template_parser.extract_variable_selectors())
         else:
-            variable_template_parser = VariableTemplateParser(template=prompt_template.text)
-            variable_selectors = variable_template_parser.extract_variable_selectors()
+            if prompt_template.edition_type != 'jinja2':
+                variable_template_parser = VariableTemplateParser(template=prompt_template.text)
+                variable_selectors = variable_template_parser.extract_variable_selectors()
 
         variable_mapping = {}
         for variable_selector in variable_selectors:
@@ -571,6 +665,22 @@ class LLMNode(BaseNode):
         if node_data.memory:
             variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]
 
+        if node_data.prompt_config:
+            enable_jinja = False
+
+            if isinstance(prompt_template, list):
+                for prompt in prompt_template:
+                    if prompt.edition_type == 'jinja2':
+                        enable_jinja = True
+                        break
+            else:
+                if prompt_template.edition_type == 'jinja2':
+                    enable_jinja = True
+
+            if enable_jinja:
+                for variable_selector in node_data.prompt_config.jinja2_variables or []:
+                    variable_mapping[variable_selector.variable] = variable_selector.value_selector
+
         return variable_mapping
 
     @classmethod
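
The net effect of the scan above, as a runnable sketch (stand-in objects, hypothetical selector): jinja2 variables join the mapping only if at least one prompt is in jinja2 mode.

from types import SimpleNamespace

prompts = [SimpleNamespace(edition_type='basic'), SimpleNamespace(edition_type='jinja2')]
enable_jinja = any(p.edition_type == 'jinja2' for p in prompts)
assert enable_jinja
# with enable_jinja set, a VariableSelector(variable='article',
# value_selector=['upstream_node', 'text']) contributes
# {'article': ['upstream_node', 'text']} to variable_mapping, alongside the
# usual '#context#' / '#sys.query#' entries.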
@@ -588,7 +698,8 @@
                 "prompts": [
                     {
                         "role": "system",
-                        "text": "You are a helpful AI assistant."
+                        "text": "You are a helpful AI assistant.",
+                        "edition_type": "basic"
                     }
                 ]
             },
@@ -600,7 +711,8 @@
             "prompt": {
                 "text": "Here is the chat histories between human and assistant, inside "
                         "<histories></histories> XML tags.\n\n<histories>\n{{"
-                        "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:"
+                        "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
+                "edition_type": "basic"
             },
             "stop": ["Human:"]
         }
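
For contrast with the "basic" defaults above, a jinja2-mode prompt entry would carry its template in jinja2_text and declare its inputs under prompt_config. A hypothetical node-config fragment; the field layout is inferred from the entities in this commit, not taken from the default config:

jinja2_prompt_example = {
    "prompts": [
        {
            "role": "user",
            "text": "",
            "edition_type": "jinja2",
            "jinja2_text": "Answer using the context:\n{{ context }}\n\nQ: {{ question }}",
        }
    ],
    "prompt_config": {
        "jinja2_variables": [
            {"variable": "context", "value_selector": ["knowledge_node", "result"]},
            {"variable": "question", "value_selector": ["sys", "query"]},
        ]
    },
}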