feat: query prompt template support in chatflow (#3791)
Co-authored-by: Joel <iamjoel007@gmail.com>
@@ -31,7 +31,8 @@ class AdvancedPromptTransform(PromptTransform):
                    context: Optional[str],
                    memory_config: Optional[MemoryConfig],
                    memory: Optional[TokenBufferMemory],
-                   model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
+                   model_config: ModelConfigWithCredentialsEntity,
+                   query_prompt_template: Optional[str] = None) -> list[PromptMessage]:
         inputs = {key: str(value) for key, value in inputs.items()}

         prompt_messages = []
@@ -53,6 +54,7 @@ class AdvancedPromptTransform(PromptTransform):
                 prompt_template=prompt_template,
                 inputs=inputs,
                 query=query,
+                query_prompt_template=query_prompt_template,
                 files=files,
                 context=context,
                 memory_config=memory_config,
@@ -121,7 +123,8 @@ class AdvancedPromptTransform(PromptTransform):
                                         context: Optional[str],
                                         memory_config: Optional[MemoryConfig],
                                         memory: Optional[TokenBufferMemory],
-                                        model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
+                                        model_config: ModelConfigWithCredentialsEntity,
+                                        query_prompt_template: Optional[str] = None) -> list[PromptMessage]:
         """
         Get chat model prompt messages.
         """
@@ -148,6 +151,20 @@ class AdvancedPromptTransform(PromptTransform):
             elif prompt_item.role == PromptMessageRole.ASSISTANT:
                 prompt_messages.append(AssistantPromptMessage(content=prompt))

+        if query and query_prompt_template:
+            prompt_template = PromptTemplateParser(
+                template=query_prompt_template,
+                with_variable_tmpl=self.with_variable_tmpl
+            )
+            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+            prompt_inputs['#sys.query#'] = query
+
+            prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
+
+            query = prompt_template.format(
+                prompt_inputs
+            )
+
         if memory and memory_config:
             prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)

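Note: the hunk above is the heart of the feature. When both a user query and a query prompt template are present, the template is rendered with its own referenced inputs plus the reserved '#sys.query#' key, and the result replaces the raw query before it is appended as the final user message. Below is a minimal, self-contained sketch of that rendering step; it is a stand-in for Dify's PromptTemplateParser, assuming its default {{variable}} placeholder syntax, and it omits the context-variable step shown in the diff.

import re

# Simplified placeholder pattern; the real parser also supports an
# alternative template style via with_variable_tmpl.
VARIABLE_PATTERN = re.compile(r'\{\{([a-zA-Z0-9_.#]+)\}\}')

def format_query_template(query_prompt_template: str, inputs: dict, query: str) -> str:
    # Mirror the diff: keep only the inputs the template actually
    # references, then bind the raw user query to '#sys.query#'.
    variable_keys = VARIABLE_PATTERN.findall(query_prompt_template)
    prompt_inputs = {k: inputs[k] for k in variable_keys if k in inputs}
    prompt_inputs['#sys.query#'] = query
    # Unknown placeholders are left intact rather than raising.
    return VARIABLE_PATTERN.sub(
        lambda m: str(prompt_inputs.get(m.group(1), m.group(0))),
        query_prompt_template,
    )

print(format_query_template(
    'Answer in {{language}}: {{#sys.query#}}',
    {'language': 'English'},
    'What is RAG?',
))  # -> Answer in English: What is RAG?

Binding the raw query under the reserved '#sys.query#' key, rather than mutating inputs, keeps app-level variables and the per-turn query cleanly separated.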
@@ -40,3 +40,4 @@ class MemoryConfig(BaseModel):

     role_prefix: Optional[RolePrefix] = None
     window: WindowConfig
+    query_prompt_template: Optional[str] = None
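Note: because the new field defaults to None, MemoryConfig payloads written before this commit still validate unchanged. A minimal pydantic sketch of that backward compatibility (WindowConfig is reduced to a placeholder here, and RolePrefix is omitted):

from typing import Optional
from pydantic import BaseModel

class WindowConfig(BaseModel):
    # Placeholder stand-in for the real window settings.
    enabled: bool = False
    size: Optional[int] = None

class MemoryConfig(BaseModel):
    window: WindowConfig
    query_prompt_template: Optional[str] = None  # new in this commit

# An older payload without the field still validates:
legacy = MemoryConfig(window=WindowConfig())
print(legacy.query_prompt_template)  # None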
@@ -74,6 +74,7 @@ class LLMNode(BaseNode):
             node_data=node_data,
             query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value])
             if node_data.memory else None,
+            query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
             inputs=inputs,
             files=files,
             context=context,
@@ -209,6 +210,17 @@ class LLMNode(BaseNode):

             inputs[variable_selector.variable] = variable_value

+        memory = node_data.memory
+        if memory and memory.query_prompt_template:
+            query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
+                                        .extract_variable_selectors())
+            for variable_selector in query_variable_selectors:
+                variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
+                if variable_value is None:
+                    raise ValueError(f'Variable {variable_selector.variable} not found')
+
+                inputs[variable_selector.variable] = variable_value
+
         return inputs

     def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
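Note: at runtime the node now also resolves every variable the query prompt template references, failing fast when one is missing from the pool. A simplified, self-contained sketch of that lookup; the {{#node_id.variable#}} selector syntax and the dict-backed pool are assumptions standing in for VariableTemplateParser and VariablePool:

import re

# Loose approximation of the workflow selector syntax; the real parser
# is stricter about identifier shape and nesting depth.
SELECTOR_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)+)#\}\}')

def extract_value_selectors(template: str) -> list[list[str]]:
    # Each selector becomes a path like ['sys', 'query'].
    return [match.split('.') for match in SELECTOR_PATTERN.findall(template)]

# A dict standing in for the VariablePool, keyed by selector paths:
pool = {('sys', 'query'): 'What is RAG?', ('start', 'language'): 'English'}

template = 'Answer in {{#start.language#}}: {{#sys.query#}}'
for selector in extract_value_selectors(template):
    value = pool.get(tuple(selector))
    if value is None:
        # Mirrors the diff: fail fast on a missing referenced variable.
        raise ValueError(f'Variable {".".join(selector)} not found')
    print(selector, '->', value)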
@@ -302,7 +314,8 @@ class LLMNode(BaseNode):

         return None

-    def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+    def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[
+        ModelInstance, ModelConfigWithCredentialsEntity]:
         """
         Fetch model config
         :param node_data_model: node data model
@@ -407,6 +420,7 @@ class LLMNode(BaseNode):

     def _fetch_prompt_messages(self, node_data: LLMNodeData,
                                query: Optional[str],
+                               query_prompt_template: Optional[str],
                                inputs: dict[str, str],
                                files: list[FileVar],
                                context: Optional[str],
@@ -417,6 +431,7 @@ class LLMNode(BaseNode):
         Fetch prompt messages
         :param node_data: node data
         :param query: query
+        :param query_prompt_template: query prompt template
         :param inputs: inputs
         :param files: files
         :param context: context
@@ -433,7 +448,8 @@ class LLMNode(BaseNode):
             context=context,
             memory_config=node_data.memory,
             memory=memory,
-            model_config=model_config
+            model_config=model_config,
+            query_prompt_template=query_prompt_template,
         )
         stop = model_config.stop

@@ -539,6 +555,13 @@ class LLMNode(BaseNode):
         for variable_selector in variable_selectors:
             variable_mapping[variable_selector.variable] = variable_selector.value_selector

+        memory = node_data.memory
+        if memory and memory.query_prompt_template:
+            query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
+                                        .extract_variable_selectors())
+            for variable_selector in query_variable_selectors:
+                variable_mapping[variable_selector.variable] = variable_selector.value_selector
+
         if node_data.context.enabled:
             variable_mapping['#context#'] = node_data.context.variable_selector

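Note: the same selector extraction runs again here at mapping time, so every variable the query template references is declared before the node executes. A compact sketch of building that mapping; the key shape (the selector's full template name) is an assumption about how the real parser keys entries:

import re

SELECTOR_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)+)#\}\}')

template = 'Answer in {{#start.language#}}: {{#sys.query#}}'
variable_mapping: dict[str, list[str]] = {}
for dotted in SELECTOR_PATTERN.findall(template):
    # Assumed key shape: '#start.language#' -> ['start', 'language'].
    variable_mapping['#' + dotted + '#'] = dotted.split('.')

print(variable_mapping)
# {'#start.language#': ['start', 'language'], '#sys.query#': ['sys', 'query']}

Declaring these selectors up front lets the workflow engine validate dependencies and wire upstream node outputs before the LLM node runs, instead of discovering a missing variable mid-execution.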