feat: query prompt template support in chatflow (#3791)

Author: takatost
Co-authored-by: Joel <iamjoel007@gmail.com>
Date: 2024-04-25 18:01:53 +08:00 (committed by GitHub)
Parent: 80b9507e7a
Commit: 12435774ca
12 changed files with 113 additions and 18 deletions
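
The feature lets a chatflow LLM node build its user message from a template instead of the raw sys.query. As a hedged sketch of the data model the hunks below rely on (the class name MemoryConfig and the pydantic base are inferred from the diff, not copied from the changeset's entity files):

from typing import Optional
from pydantic import BaseModel

# Inferred shape of the memory config this diff extends (illustrative only).
class MemoryConfig(BaseModel):
    # When set, the user turn is rendered from this template, e.g.
    # 'Refine the following question: {{#sys.query#}}', instead of
    # passing the raw query string through unchanged.
    query_prompt_template: Optional[str] = None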


@@ -74,6 +74,7 @@ class LLMNode(BaseNode):
             node_data=node_data,
             query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value])
             if node_data.memory else None,
+            query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
             inputs=inputs,
             files=files,
             context=context,
@@ -209,6 +210,17 @@ class LLMNode(BaseNode):
                 inputs[variable_selector.variable] = variable_value
 
+        memory = node_data.memory
+        if memory and memory.query_prompt_template:
+            query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
+                                        .extract_variable_selectors())
+            for variable_selector in query_variable_selectors:
+                variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
+                if variable_value is None:
+                    raise ValueError(f'Variable {variable_selector.variable} not found')
+
+                inputs[variable_selector.variable] = variable_value
+
         return inputs
 
     def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
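
The new block resolves every variable the query prompt template references before the prompt is assembled. Assuming VariableTemplateParser follows Dify's {{#scope.name#}} placeholder syntax, a toy equivalent of the extraction step looks like this (the regex and dataclass are illustrative, not the real implementation):

import re
from dataclasses import dataclass

@dataclass
class VariableSelector:
    variable: str              # e.g. '#sys.query#'
    value_selector: list[str]  # e.g. ['sys', 'query']

def extract_variable_selectors(template: str) -> list[VariableSelector]:
    # Find placeholders of the form {{#scope.name#}} and split them into
    # the selector path used for variable_pool lookups.
    selectors = []
    for match in re.finditer(r'\{\{(#[a-zA-Z0-9_.]+#)\}\}', template):
        key = match.group(1)                # '#sys.query#'
        parts = key.strip('#').split('.')   # ['sys', 'query']
        selectors.append(VariableSelector(variable=key, value_selector=parts))
    return selectors

# extract_variable_selectors('Refine: {{#sys.query#}}')
# -> [VariableSelector(variable='#sys.query#', value_selector=['sys', 'query'])]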
@@ -302,7 +314,8 @@ class LLMNode(BaseNode):
         return None
 
-    def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+    def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[
+        ModelInstance, ModelConfigWithCredentialsEntity]:
         """
         Fetch model config
         :param node_data_model: node data model
@@ -407,6 +420,7 @@ class LLMNode(BaseNode):
     def _fetch_prompt_messages(self, node_data: LLMNodeData,
                                query: Optional[str],
+                               query_prompt_template: Optional[str],
                                inputs: dict[str, str],
                                files: list[FileVar],
                                context: Optional[str],
@@ -417,6 +431,7 @@ class LLMNode(BaseNode):
         Fetch prompt messages
         :param node_data: node data
         :param query: query
+        :param query_prompt_template: query prompt template
         :param inputs: inputs
         :param files: files
         :param context: context
@@ -433,7 +448,8 @@ class LLMNode(BaseNode):
             context=context,
             memory_config=node_data.memory,
             memory=memory,
-            model_config=model_config
+            model_config=model_config,
+            query_prompt_template=query_prompt_template,
         )
 
         stop = model_config.stop
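
With the template forwarded to the prompt transform, the intended effect is that the final user message is the rendered template rather than the bare query. A hedged sketch of that rendering step, guessed from this diff alone (the real logic lives in the prompt transform, which this hunk only parameterizes; render_user_message is a hypothetical helper):

from typing import Optional

def render_user_message(query: Optional[str],
                        query_prompt_template: Optional[str],
                        inputs: dict[str, str]) -> Optional[str]:
    # Assumed fallback behaviour: no template means the raw query is used.
    if not query_prompt_template:
        return query
    rendered = query_prompt_template
    for variable, value in inputs.items():
        # inputs keys carry the '#scope.name#' form, matching the placeholders.
        rendered = rendered.replace('{{' + variable + '}}', value)
    return rendered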
@@ -539,6 +555,13 @@ class LLMNode(BaseNode):
             for variable_selector in variable_selectors:
                 variable_mapping[variable_selector.variable] = variable_selector.value_selector
 
+        memory = node_data.memory
+        if memory and memory.query_prompt_template:
+            query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
+                                        .extract_variable_selectors())
+            for variable_selector in query_variable_selectors:
+                variable_mapping[variable_selector.variable] = variable_selector.value_selector
+
         if node_data.context.enabled:
             variable_mapping['#context#'] = node_data.context.variable_selector
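
This mirrors the input-fetching change: variables referenced only by the query prompt template now appear in the node's variable mapping, so upstream dependency resolution sees them. A worked example of the resulting mapping (selector values assumed for illustration):

# If memory.query_prompt_template == 'Rewrite the question: {{#sys.query#}}',
# the loop above contributes roughly this entry:
variable_mapping = {
    '#sys.query#': ['sys', 'query'],
    # ...plus entries already collected from the prompt template, and
    # '#context#' when node_data.context.enabled is set.
}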