mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 03:16:51 +08:00
feat: support LLM jinja2 template prompt (#3968)
Co-authored-by: Joel <iamjoel007@gmail.com>
This commit is contained in:
@@ -2,6 +2,7 @@ from typing import Optional, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file.file_obj import FileVar
|
||||
from core.helper.code_executor.jinja2_formatter import Jinja2Formatter
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
@@ -80,29 +81,35 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
|
||||
prompt_messages = []
|
||||
|
||||
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
||||
if prompt_template.edition_type == 'basic' or not prompt_template.edition_type:
|
||||
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
||||
|
||||
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
|
||||
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
|
||||
|
||||
if memory and memory_config:
|
||||
role_prefix = memory_config.role_prefix
|
||||
prompt_inputs = self._set_histories_variable(
|
||||
memory=memory,
|
||||
memory_config=memory_config,
|
||||
raw_prompt=raw_prompt,
|
||||
role_prefix=role_prefix,
|
||||
prompt_template=prompt_template,
|
||||
prompt_inputs=prompt_inputs,
|
||||
model_config=model_config
|
||||
if memory and memory_config:
|
||||
role_prefix = memory_config.role_prefix
|
||||
prompt_inputs = self._set_histories_variable(
|
||||
memory=memory,
|
||||
memory_config=memory_config,
|
||||
raw_prompt=raw_prompt,
|
||||
role_prefix=role_prefix,
|
||||
prompt_template=prompt_template,
|
||||
prompt_inputs=prompt_inputs,
|
||||
model_config=model_config
|
||||
)
|
||||
|
||||
if query:
|
||||
prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
|
||||
|
||||
prompt = prompt_template.format(
|
||||
prompt_inputs
|
||||
)
|
||||
else:
|
||||
prompt = raw_prompt
|
||||
prompt_inputs = inputs
|
||||
|
||||
if query:
|
||||
prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
|
||||
|
||||
prompt = prompt_template.format(
|
||||
prompt_inputs
|
||||
)
|
||||
prompt = Jinja2Formatter.format(prompt, prompt_inputs)
|
||||
|
||||
if files:
|
||||
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
|
||||
@@ -135,14 +142,22 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
for prompt_item in raw_prompt_list:
|
||||
raw_prompt = prompt_item.text
|
||||
|
||||
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
||||
if prompt_item.edition_type == 'basic' or not prompt_item.edition_type:
|
||||
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
||||
|
||||
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
|
||||
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
|
||||
|
||||
prompt = prompt_template.format(
|
||||
prompt_inputs
|
||||
)
|
||||
prompt = prompt_template.format(
|
||||
prompt_inputs
|
||||
)
|
||||
elif prompt_item.edition_type == 'jinja2':
|
||||
prompt = raw_prompt
|
||||
prompt_inputs = inputs
|
||||
|
||||
prompt = Jinja2Formatter.format(prompt, prompt_inputs)
|
||||
else:
|
||||
raise ValueError(f'Invalid edition type: {prompt_item.edition_type}')
|
||||
|
||||
if prompt_item.role == PromptMessageRole.USER:
|
||||
prompt_messages.append(UserPromptMessage(content=prompt))
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -11,6 +11,7 @@ class ChatModelMessage(BaseModel):
|
||||
"""
|
||||
text: str
|
||||
role: PromptMessageRole
|
||||
edition_type: Optional[Literal['basic', 'jinja2']]
|
||||
|
||||
|
||||
class CompletionModelPromptTemplate(BaseModel):
|
||||
@@ -18,6 +19,7 @@ class CompletionModelPromptTemplate(BaseModel):
|
||||
Completion Model Prompt Template.
|
||||
"""
|
||||
text: str
|
||||
edition_type: Optional[Literal['basic', 'jinja2']]
|
||||
|
||||
|
||||
class MemoryConfig(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user