mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-15 13:56:53 +08:00
Feat/environment variables in workflow (#6515)
Co-authored-by: JzoNg <jzongcode@gmail.com>
This commit is contained in:
@@ -1,7 +1,5 @@
|
||||
import json
|
||||
from typing import cast
|
||||
|
||||
from core.file.file_obj import FileVar
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
@@ -19,7 +17,7 @@ from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
class AnswerNode(BaseNode):
|
||||
_node_data_cls = AnswerNodeData
|
||||
node_type = NodeType.ANSWER
|
||||
_node_type: NodeType = NodeType.ANSWER
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
"""
|
||||
@@ -28,7 +26,7 @@ class AnswerNode(BaseNode):
|
||||
:return:
|
||||
"""
|
||||
node_data = self.node_data
|
||||
node_data = cast(self._node_data_cls, node_data)
|
||||
node_data = cast(AnswerNodeData, node_data)
|
||||
|
||||
# generate routes
|
||||
generate_routes = self.extract_generate_route_from_node_data(node_data)
|
||||
@@ -38,31 +36,9 @@ class AnswerNode(BaseNode):
|
||||
if part.type == "var":
|
||||
part = cast(VarGenerateRouteChunk, part)
|
||||
value_selector = part.value_selector
|
||||
value = variable_pool.get_variable_value(
|
||||
variable_selector=value_selector
|
||||
)
|
||||
|
||||
text = ''
|
||||
if isinstance(value, str | int | float):
|
||||
text = str(value)
|
||||
elif isinstance(value, dict):
|
||||
# other types
|
||||
text = json.dumps(value, ensure_ascii=False)
|
||||
elif isinstance(value, FileVar):
|
||||
# convert file to markdown
|
||||
text = value.to_markdown()
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, FileVar):
|
||||
text += item.to_markdown() + ' '
|
||||
|
||||
text = text.strip()
|
||||
|
||||
if not text and value:
|
||||
# other types
|
||||
text = json.dumps(value, ensure_ascii=False)
|
||||
|
||||
answer += text
|
||||
value = variable_pool.get(value_selector)
|
||||
if value:
|
||||
answer += value.markdown
|
||||
else:
|
||||
part = cast(TextGenerateRouteChunk, part)
|
||||
answer += part.text
|
||||
@@ -82,7 +58,7 @@ class AnswerNode(BaseNode):
|
||||
:return:
|
||||
"""
|
||||
node_data = cls._node_data_cls(**config.get("data", {}))
|
||||
node_data = cast(cls._node_data_cls, node_data)
|
||||
node_data = cast(AnswerNodeData, node_data)
|
||||
|
||||
return cls.extract_generate_route_from_node_data(node_data)
|
||||
|
||||
@@ -143,7 +119,7 @@ class AnswerNode(BaseNode):
|
||||
:return:
|
||||
"""
|
||||
node_data = node_data
|
||||
node_data = cast(cls._node_data_cls, node_data)
|
||||
node_data = cast(AnswerNodeData, node_data)
|
||||
|
||||
variable_template_parser = VariableTemplateParser(template=node_data.answer)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
@@ -46,7 +47,7 @@ class BaseNode(ABC):
|
||||
node_data: BaseNodeData
|
||||
node_run_result: Optional[NodeRunResult] = None
|
||||
|
||||
callbacks: list[BaseWorkflowCallback]
|
||||
callbacks: Sequence[WorkflowCallback]
|
||||
|
||||
def __init__(self, tenant_id: str,
|
||||
app_id: str,
|
||||
@@ -54,8 +55,8 @@ class BaseNode(ABC):
|
||||
user_id: str,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
config: dict,
|
||||
callbacks: list[BaseWorkflowCallback] = None,
|
||||
config: Mapping[str, Any],
|
||||
callbacks: Sequence[WorkflowCallback] | None = None,
|
||||
workflow_call_depth: int = 0) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.app_id = app_id
|
||||
@@ -65,7 +66,8 @@ class BaseNode(ABC):
|
||||
self.invoke_from = invoke_from
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
|
||||
self.node_id = config.get("id")
|
||||
# TODO: May need to check if key exists.
|
||||
self.node_id = config["id"]
|
||||
if not self.node_id:
|
||||
raise ValueError("Node ID is required.")
|
||||
|
||||
|
||||
@@ -59,11 +59,8 @@ class CodeNode(BaseNode):
|
||||
variables = {}
|
||||
for variable_selector in node_data.variables:
|
||||
variable = variable_selector.variable
|
||||
value = variable_pool.get_variable_value(
|
||||
variable_selector=variable_selector.value_selector
|
||||
)
|
||||
|
||||
variables[variable] = value
|
||||
value = variable_pool.get(variable_selector.value_selector)
|
||||
variables[variable] = value.value if value else None
|
||||
# Run code
|
||||
try:
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
|
||||
@@ -10,7 +10,7 @@ from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
class EndNode(BaseNode):
|
||||
_node_data_cls = EndNodeData
|
||||
node_type = NodeType.END
|
||||
_node_type = NodeType.END
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
"""
|
||||
@@ -19,16 +19,13 @@ class EndNode(BaseNode):
|
||||
:return:
|
||||
"""
|
||||
node_data = self.node_data
|
||||
node_data = cast(self._node_data_cls, node_data)
|
||||
node_data = cast(EndNodeData, node_data)
|
||||
output_variables = node_data.outputs
|
||||
|
||||
outputs = {}
|
||||
for variable_selector in output_variables:
|
||||
value = variable_pool.get_variable_value(
|
||||
variable_selector=variable_selector.value_selector
|
||||
)
|
||||
|
||||
outputs[variable_selector.variable] = value
|
||||
value = variable_pool.get(variable_selector.value_selector)
|
||||
outputs[variable_selector.variable] = value.value if value else None
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
@@ -45,7 +42,7 @@ class EndNode(BaseNode):
|
||||
:return:
|
||||
"""
|
||||
node_data = cls._node_data_cls(**config.get("data", {}))
|
||||
node_data = cast(cls._node_data_cls, node_data)
|
||||
node_data = cast(EndNodeData, node_data)
|
||||
|
||||
return cls.extract_generate_nodes_from_node_data(graph, node_data)
|
||||
|
||||
@@ -57,7 +54,7 @@ class EndNode(BaseNode):
|
||||
:param node_data: node data object
|
||||
:return:
|
||||
"""
|
||||
nodes = graph.get('nodes')
|
||||
nodes = graph.get('nodes', [])
|
||||
node_mapping = {node.get('id'): node for node in nodes}
|
||||
|
||||
variable_selectors = node_data.outputs
|
||||
|
||||
@@ -9,7 +9,7 @@ import httpx
|
||||
import core.helper.ssrf_proxy as ssrf_proxy
|
||||
from configs import dify_config
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.variable_pool import ValueType, VariablePool
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.http_request.entities import (
|
||||
HttpRequestNodeAuthorization,
|
||||
HttpRequestNodeBody,
|
||||
@@ -212,13 +212,11 @@ class HttpExecutor:
|
||||
raise ValueError('self.authorization config is required')
|
||||
if authorization.config is None:
|
||||
raise ValueError('authorization config is required')
|
||||
if authorization.config.type != 'bearer' and authorization.config.header is None:
|
||||
raise ValueError('authorization config header is required')
|
||||
|
||||
if self.authorization.config.api_key is None:
|
||||
raise ValueError('api_key is required')
|
||||
|
||||
if not self.authorization.config.header:
|
||||
if not authorization.config.header:
|
||||
authorization.config.header = 'Authorization'
|
||||
|
||||
if self.authorization.config.type == 'bearer':
|
||||
@@ -335,16 +333,13 @@ class HttpExecutor:
|
||||
if variable_pool:
|
||||
variable_value_mapping = {}
|
||||
for variable_selector in variable_selectors:
|
||||
value = variable_pool.get_variable_value(
|
||||
variable_selector=variable_selector.value_selector, target_value_type=ValueType.STRING
|
||||
)
|
||||
|
||||
if value is None:
|
||||
variable = variable_pool.get(variable_selector.value_selector)
|
||||
if variable is None:
|
||||
raise ValueError(f'Variable {variable_selector.variable} not found')
|
||||
|
||||
if escape_quotes and isinstance(value, str):
|
||||
value = value.replace('"', '\\"')
|
||||
|
||||
if escape_quotes and isinstance(variable.value, str):
|
||||
value = variable.value.replace('"', '\\"')
|
||||
else:
|
||||
value = variable.value
|
||||
variable_value_mapping[variable_selector.variable] = value
|
||||
|
||||
return variable_template_parser.format(variable_value_mapping), variable_selectors
|
||||
|
||||
@@ -3,6 +3,7 @@ from mimetypes import guess_extension
|
||||
from os import path
|
||||
from typing import cast
|
||||
|
||||
from core.app.segments import parser
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
@@ -51,6 +52,9 @@ class HttpRequestNode(BaseNode):
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data)
|
||||
# TODO: Switch to use segment directly
|
||||
if node_data.authorization.config and node_data.authorization.config.api_key:
|
||||
node_data.authorization.config.api_key = parser.convert_template(template=node_data.authorization.config.api_key, variable_pool=variable_pool).text
|
||||
|
||||
# init http executor
|
||||
http_executor = None
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
@@ -11,7 +12,7 @@ from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
class IfElseNode(BaseNode):
|
||||
_node_data_cls = IfElseNodeData
|
||||
node_type = NodeType.IF_ELSE
|
||||
_node_type = NodeType.IF_ELSE
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
"""
|
||||
@@ -20,7 +21,7 @@ class IfElseNode(BaseNode):
|
||||
:return:
|
||||
"""
|
||||
node_data = self.node_data
|
||||
node_data = cast(self._node_data_cls, node_data)
|
||||
node_data = cast(IfElseNodeData, node_data)
|
||||
|
||||
node_inputs = {
|
||||
"conditions": []
|
||||
@@ -138,14 +139,12 @@ class IfElseNode(BaseNode):
|
||||
else:
|
||||
raise ValueError(f"Invalid comparison operator: {comparison_operator}")
|
||||
|
||||
def process_conditions(self, variable_pool: VariablePool, conditions: list[Condition]):
|
||||
def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]):
|
||||
input_conditions = []
|
||||
group_result = []
|
||||
|
||||
for condition in conditions:
|
||||
actual_value = variable_pool.get_variable_value(
|
||||
variable_selector=condition.variable_selector
|
||||
)
|
||||
actual_variable = variable_pool.get_any(condition.variable_selector)
|
||||
|
||||
if condition.value is not None:
|
||||
variable_template_parser = VariableTemplateParser(template=condition.value)
|
||||
@@ -153,9 +152,7 @@ class IfElseNode(BaseNode):
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
if variable_selectors:
|
||||
for variable_selector in variable_selectors:
|
||||
value = variable_pool.get_variable_value(
|
||||
variable_selector=variable_selector.value_selector
|
||||
)
|
||||
value = variable_pool.get_any(variable_selector.value_selector)
|
||||
expected_value = variable_template_parser.format({variable_selector.variable: value})
|
||||
else:
|
||||
expected_value = condition.value
|
||||
@@ -165,13 +162,13 @@ class IfElseNode(BaseNode):
|
||||
comparison_operator = condition.comparison_operator
|
||||
input_conditions.append(
|
||||
{
|
||||
"actual_value": actual_value,
|
||||
"actual_value": actual_variable,
|
||||
"expected_value": expected_value,
|
||||
"comparison_operator": comparison_operator
|
||||
}
|
||||
)
|
||||
|
||||
result = self.evaluate_condition(actual_value, expected_value, comparison_operator)
|
||||
result = self.evaluate_condition(actual_variable, expected_value, comparison_operator)
|
||||
group_result.append(result)
|
||||
|
||||
return input_conditions, group_result
|
||||
|
||||
@@ -20,7 +20,8 @@ class IterationNode(BaseIterationNode):
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
iterator = variable_pool.get_variable_value(cast(IterationNodeData, self.node_data).iterator_selector)
|
||||
self.node_data = cast(IterationNodeData, self.node_data)
|
||||
iterator = variable_pool.get_any(self.node_data.iterator_selector)
|
||||
|
||||
if not isinstance(iterator, list):
|
||||
raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.")
|
||||
@@ -63,15 +64,15 @@ class IterationNode(BaseIterationNode):
|
||||
"""
|
||||
node_data = cast(IterationNodeData, self.node_data)
|
||||
|
||||
variable_pool.append_variable(self.node_id, ['index'], state.index)
|
||||
variable_pool.add((self.node_id, 'index'), state.index)
|
||||
# get the iterator value
|
||||
iterator = variable_pool.get_variable_value(node_data.iterator_selector)
|
||||
iterator = variable_pool.get_any(node_data.iterator_selector)
|
||||
|
||||
if iterator is None or not isinstance(iterator, list):
|
||||
return
|
||||
|
||||
if state.index < len(iterator):
|
||||
variable_pool.append_variable(self.node_id, ['item'], iterator[state.index])
|
||||
variable_pool.add((self.node_id, 'item'), iterator[state.index])
|
||||
|
||||
def _next_iteration(self, variable_pool: VariablePool, state: IterationState):
|
||||
"""
|
||||
@@ -87,7 +88,7 @@ class IterationNode(BaseIterationNode):
|
||||
:return: True if iteration limit is reached, False otherwise
|
||||
"""
|
||||
node_data = cast(IterationNodeData, self.node_data)
|
||||
iterator = variable_pool.get_variable_value(node_data.iterator_selector)
|
||||
iterator = variable_pool.get_any(node_data.iterator_selector)
|
||||
|
||||
if iterator is None or not isinstance(iterator, list):
|
||||
return True
|
||||
@@ -100,9 +101,9 @@ class IterationNode(BaseIterationNode):
|
||||
:param variable_pool: variable pool
|
||||
"""
|
||||
output_selector = cast(IterationNodeData, self.node_data).output_selector
|
||||
output = variable_pool.get_variable_value(output_selector)
|
||||
output = variable_pool.get_any(output_selector)
|
||||
# clear the output for this iteration
|
||||
variable_pool.append_variable(self.node_id, output_selector[1:], None)
|
||||
variable_pool.remove([self.node_id] + output_selector[1:])
|
||||
state.current_output = output
|
||||
if output is not None:
|
||||
state.outputs.append(output)
|
||||
|
||||
@@ -41,7 +41,8 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data)
|
||||
|
||||
# extract variables
|
||||
query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
|
||||
variable = variable_pool.get(node_data.query_variable_selector)
|
||||
query = variable.value if variable else None
|
||||
variables = {
|
||||
'query': query
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
class LLMNode(BaseNode):
|
||||
_node_data_cls = LLMNodeData
|
||||
node_type = NodeType.LLM
|
||||
_node_type = NodeType.LLM
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
"""
|
||||
@@ -90,7 +90,7 @@ class LLMNode(BaseNode):
|
||||
# fetch prompt messages
|
||||
prompt_messages, stop = self._fetch_prompt_messages(
|
||||
node_data=node_data,
|
||||
query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value])
|
||||
query=variable_pool.get_any(['sys', SystemVariable.QUERY.value])
|
||||
if node_data.memory else None,
|
||||
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
|
||||
inputs=inputs,
|
||||
@@ -238,8 +238,8 @@ class LLMNode(BaseNode):
|
||||
|
||||
for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
||||
variable = variable_selector.variable
|
||||
value = variable_pool.get_variable_value(
|
||||
variable_selector=variable_selector.value_selector
|
||||
value = variable_pool.get_any(
|
||||
variable_selector.value_selector
|
||||
)
|
||||
|
||||
def parse_dict(d: dict) -> str:
|
||||
@@ -302,7 +302,7 @@ class LLMNode(BaseNode):
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
for variable_selector in variable_selectors:
|
||||
variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
|
||||
variable_value = variable_pool.get_any(variable_selector.value_selector)
|
||||
if variable_value is None:
|
||||
raise ValueError(f'Variable {variable_selector.variable} not found')
|
||||
|
||||
@@ -313,7 +313,7 @@ class LLMNode(BaseNode):
|
||||
query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
|
||||
.extract_variable_selectors())
|
||||
for variable_selector in query_variable_selectors:
|
||||
variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
|
||||
variable_value = variable_pool.get_any(variable_selector.value_selector)
|
||||
if variable_value is None:
|
||||
raise ValueError(f'Variable {variable_selector.variable} not found')
|
||||
|
||||
@@ -331,7 +331,7 @@ class LLMNode(BaseNode):
|
||||
if not node_data.vision.enabled:
|
||||
return []
|
||||
|
||||
files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value])
|
||||
files = variable_pool.get_any(['sys', SystemVariable.FILES.value])
|
||||
if not files:
|
||||
return []
|
||||
|
||||
@@ -350,7 +350,7 @@ class LLMNode(BaseNode):
|
||||
if not node_data.context.variable_selector:
|
||||
return None
|
||||
|
||||
context_value = variable_pool.get_variable_value(node_data.context.variable_selector)
|
||||
context_value = variable_pool.get_any(node_data.context.variable_selector)
|
||||
if context_value:
|
||||
if isinstance(context_value, str):
|
||||
return context_value
|
||||
@@ -496,7 +496,7 @@ class LLMNode(BaseNode):
|
||||
return None
|
||||
|
||||
# get conversation id
|
||||
conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION_ID.value])
|
||||
conversation_id = variable_pool.get_any(['sys', SystemVariable.CONVERSATION_ID.value])
|
||||
if conversation_id is None:
|
||||
return None
|
||||
|
||||
|
||||
@@ -71,9 +71,10 @@ class ParameterExtractorNode(LLMNode):
|
||||
Run the node.
|
||||
"""
|
||||
node_data = cast(ParameterExtractorNodeData, self.node_data)
|
||||
query = variable_pool.get_variable_value(node_data.query)
|
||||
if not query:
|
||||
variable = variable_pool.get(node_data.query)
|
||||
if not variable:
|
||||
raise ValueError("Input variable content not found or is empty")
|
||||
query = variable.value
|
||||
|
||||
inputs = {
|
||||
'query': query,
|
||||
@@ -564,7 +565,8 @@ class ParameterExtractorNode(LLMNode):
|
||||
variable_template_parser = VariableTemplateParser(instruction)
|
||||
inputs = {}
|
||||
for selector in variable_template_parser.extract_variable_selectors():
|
||||
inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector)
|
||||
variable = variable_pool.get(selector.value_selector)
|
||||
inputs[selector.variable] = variable.value if variable else None
|
||||
|
||||
return variable_template_parser.format(inputs)
|
||||
|
||||
|
||||
@@ -41,7 +41,8 @@ class QuestionClassifierNode(LLMNode):
|
||||
node_data = cast(QuestionClassifierNodeData, node_data)
|
||||
|
||||
# extract variables
|
||||
query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
|
||||
variable = variable_pool.get(node_data.query_variable_selector)
|
||||
query = variable.value if variable else None
|
||||
variables = {
|
||||
'query': query
|
||||
}
|
||||
@@ -294,7 +295,8 @@ class QuestionClassifierNode(LLMNode):
|
||||
variable_template_parser = VariableTemplateParser(template=instruction)
|
||||
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
||||
for variable_selector in variable_selectors:
|
||||
variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
|
||||
variable = variable_pool.get(variable_selector.value_selector)
|
||||
variable_value = variable.value if variable else None
|
||||
if variable_value is None:
|
||||
raise ValueError(f'Variable {variable_selector.variable} not found')
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
class StartNode(BaseNode):
|
||||
_node_data_cls = StartNodeData
|
||||
node_type = NodeType.START
|
||||
_node_type = NodeType.START
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
"""
|
||||
@@ -18,7 +18,7 @@ class StartNode(BaseNode):
|
||||
:return:
|
||||
"""
|
||||
# Get cleaned inputs
|
||||
cleaned_inputs = variable_pool.user_inputs
|
||||
cleaned_inputs = dict(variable_pool.user_inputs)
|
||||
|
||||
for var in variable_pool.system_variables:
|
||||
cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var]
|
||||
|
||||
@@ -44,12 +44,9 @@ class TemplateTransformNode(BaseNode):
|
||||
# Get variables
|
||||
variables = {}
|
||||
for variable_selector in node_data.variables:
|
||||
variable = variable_selector.variable
|
||||
value = variable_pool.get_variable_value(
|
||||
variable_selector=variable_selector.value_selector
|
||||
)
|
||||
|
||||
variables[variable] = value
|
||||
variable_name = variable_selector.variable
|
||||
value = variable_pool.get_any(variable_selector.value_selector)
|
||||
variables[variable_name] = value
|
||||
# Run code
|
||||
try:
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
|
||||
@@ -29,6 +29,7 @@ class ToolEntity(BaseModel):
|
||||
|
||||
class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
class ToolInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
type: Literal['mixed', 'variable', 'constant']
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from os import path
|
||||
from typing import Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.segments import parser
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
@@ -20,6 +21,7 @@ class ToolNode(BaseNode):
|
||||
"""
|
||||
Tool Node
|
||||
"""
|
||||
|
||||
_node_data_cls = ToolNodeData
|
||||
_node_type = NodeType.TOOL
|
||||
|
||||
@@ -50,23 +52,24 @@ class ToolNode(BaseNode):
|
||||
},
|
||||
error=f'Failed to get tool runtime: {str(e)}'
|
||||
)
|
||||
|
||||
|
||||
# get parameters
|
||||
parameters = self._generate_parameters(variable_pool, node_data, tool_runtime)
|
||||
tool_parameters = tool_runtime.get_runtime_parameters() or []
|
||||
parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data)
|
||||
parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data, for_log=True)
|
||||
|
||||
try:
|
||||
messages = ToolEngine.workflow_invoke(
|
||||
tool=tool_runtime,
|
||||
tool_parameters=parameters,
|
||||
user_id=self.user_id,
|
||||
workflow_id=self.workflow_id,
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler(),
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
)
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters,
|
||||
inputs=parameters_for_log,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info
|
||||
},
|
||||
@@ -86,21 +89,34 @@ class ToolNode(BaseNode):
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info
|
||||
},
|
||||
inputs=parameters
|
||||
inputs=parameters_for_log
|
||||
)
|
||||
|
||||
def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData, tool_runtime: Tool) -> dict:
|
||||
def _generate_parameters(
|
||||
self,
|
||||
*,
|
||||
tool_parameters: Sequence[ToolParameter],
|
||||
variable_pool: VariablePool,
|
||||
node_data: ToolNodeData,
|
||||
for_log: bool = False,
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Generate parameters
|
||||
"""
|
||||
tool_parameters = tool_runtime.get_all_runtime_parameters()
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
|
||||
def fetch_parameter(name: str) -> Optional[ToolParameter]:
|
||||
return next((parameter for parameter in tool_parameters if parameter.name == name), None)
|
||||
Args:
|
||||
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
|
||||
variable_pool (VariablePool): The variable pool containing the variables.
|
||||
node_data (ToolNodeData): The data associated with the tool node.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: A dictionary containing the generated parameters.
|
||||
|
||||
"""
|
||||
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
|
||||
|
||||
result = {}
|
||||
for parameter_name in node_data.tool_parameters:
|
||||
parameter = fetch_parameter(parameter_name)
|
||||
parameter = tool_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
continue
|
||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
||||
@@ -108,35 +124,21 @@ class ToolNode(BaseNode):
|
||||
v.to_dict() for v in self._fetch_files(variable_pool)
|
||||
]
|
||||
else:
|
||||
input = node_data.tool_parameters[parameter_name]
|
||||
if input.type == 'mixed':
|
||||
result[parameter_name] = self._format_variable_template(input.value, variable_pool)
|
||||
elif input.type == 'variable':
|
||||
result[parameter_name] = variable_pool.get_variable_value(input.value)
|
||||
elif input.type == 'constant':
|
||||
result[parameter_name] = input.value
|
||||
tool_input = node_data.tool_parameters[parameter_name]
|
||||
segment_group = parser.convert_template(
|
||||
template=str(tool_input.value),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
result[parameter_name] = segment_group.log if for_log else segment_group.text
|
||||
|
||||
return result
|
||||
|
||||
def _format_variable_template(self, template: str, variable_pool: VariablePool) -> str:
|
||||
"""
|
||||
Format variable template
|
||||
"""
|
||||
inputs = {}
|
||||
template_parser = VariableTemplateParser(template)
|
||||
for selector in template_parser.extract_variable_selectors():
|
||||
inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector)
|
||||
|
||||
return template_parser.format(inputs)
|
||||
|
||||
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
|
||||
files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value])
|
||||
if not files:
|
||||
return []
|
||||
|
||||
return files
|
||||
|
||||
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]:
|
||||
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
|
||||
# FIXME: ensure this is a ArrayVariable contains FileVariable.
|
||||
variable = variable_pool.get(['sys', SystemVariable.FILES.value])
|
||||
return [file_var.value for file_var in variable.value] if variable else []
|
||||
|
||||
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]):
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
|
||||
@@ -19,28 +19,27 @@ class VariableAggregatorNode(BaseNode):
|
||||
inputs = {}
|
||||
|
||||
if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled:
|
||||
for variable in node_data.variables:
|
||||
value = variable_pool.get_variable_value(variable)
|
||||
|
||||
if value is not None:
|
||||
for selector in node_data.variables:
|
||||
variable = variable_pool.get(selector)
|
||||
if variable is not None:
|
||||
outputs = {
|
||||
"output": value
|
||||
"output": variable.value
|
||||
}
|
||||
|
||||
inputs = {
|
||||
'.'.join(variable[1:]): value
|
||||
'.'.join(selector[1:]): variable.value
|
||||
}
|
||||
break
|
||||
else:
|
||||
for group in node_data.advanced_settings.groups:
|
||||
for variable in group.variables:
|
||||
value = variable_pool.get_variable_value(variable)
|
||||
for selector in group.variables:
|
||||
variable = variable_pool.get(selector)
|
||||
|
||||
if value is not None:
|
||||
if variable is not None:
|
||||
outputs[group.group_name] = {
|
||||
'output': value
|
||||
'output': variable.value
|
||||
}
|
||||
inputs['.'.join(variable[1:])] = value
|
||||
inputs['.'.join(selector[1:])] = variable.value
|
||||
break
|
||||
|
||||
return NodeRunResult(
|
||||
|
||||
Reference in New Issue
Block a user