mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-09 10:56:52 +08:00
Feat: conversation variable & variable assigner node (#7222)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -23,10 +23,12 @@ class NodeType(Enum):
|
||||
HTTP_REQUEST = 'http-request'
|
||||
TOOL = 'tool'
|
||||
VARIABLE_AGGREGATOR = 'variable-aggregator'
|
||||
# TODO: merge this into VARIABLE_AGGREGATOR
|
||||
VARIABLE_ASSIGNER = 'variable-assigner'
|
||||
LOOP = 'loop'
|
||||
ITERATION = 'iteration'
|
||||
PARAMETER_EXTRACTOR = 'parameter-extractor'
|
||||
CONVERSATION_VARIABLE_ASSIGNER = 'assigner'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'NodeType':
|
||||
|
||||
@@ -13,6 +13,7 @@ VariableValue = Union[str, int, float, dict, list, FileVar]
|
||||
|
||||
SYSTEM_VARIABLE_NODE_ID = 'sys'
|
||||
ENVIRONMENT_VARIABLE_NODE_ID = 'env'
|
||||
CONVERSATION_VARIABLE_NODE_ID = 'conversation'
|
||||
|
||||
|
||||
class VariablePool:
|
||||
@@ -21,6 +22,7 @@ class VariablePool:
|
||||
system_variables: Mapping[SystemVariable, Any],
|
||||
user_inputs: Mapping[str, Any],
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable] | None = None,
|
||||
) -> None:
|
||||
# system variables
|
||||
# for example:
|
||||
@@ -44,9 +46,13 @@ class VariablePool:
|
||||
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
|
||||
|
||||
# Add environment variables to the variable pool
|
||||
for var in environment_variables or []:
|
||||
for var in environment_variables:
|
||||
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
|
||||
|
||||
# Add conversation variables to the variable pool
|
||||
for var in conversation_variables or []:
|
||||
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
|
||||
|
||||
def add(self, selector: Sequence[str], value: Any, /) -> None:
|
||||
"""
|
||||
Adds a variable to the variable pool.
|
||||
|
||||
@@ -8,6 +8,7 @@ from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from models import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class UserFrom(Enum):
|
||||
@@ -91,14 +92,19 @@ class BaseNode(ABC):
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
result = self._run(
|
||||
variable_pool=variable_pool
|
||||
)
|
||||
try:
|
||||
result = self._run(
|
||||
variable_pool=variable_pool
|
||||
)
|
||||
self.node_run_result = result
|
||||
return result
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
self.node_run_result = result
|
||||
return result
|
||||
|
||||
def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None:
|
||||
def publish_text_chunk(self, text: str, value_selector: list[str] | None = None) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
:param text: chunk text
|
||||
|
||||
109
api/core/workflow/nodes/variable_assigner/__init__.py
Normal file
109
api/core/workflow/nodes/variable_assigner/__init__.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from typing import Optional, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.segments import SegmentType, Variable, factory
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from extensions.ext_database import db
|
||||
from models import ConversationVariable, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class VariableAssignerNodeError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class WriteMode(str, Enum):
|
||||
OVER_WRITE = 'over-write'
|
||||
APPEND = 'append'
|
||||
CLEAR = 'clear'
|
||||
|
||||
|
||||
class VariableAssignerData(BaseNodeData):
|
||||
title: str = 'Variable Assigner'
|
||||
desc: Optional[str] = 'Assign a value to a variable'
|
||||
assigned_variable_selector: Sequence[str]
|
||||
write_mode: WriteMode
|
||||
input_variable_selector: Sequence[str]
|
||||
|
||||
|
||||
class VariableAssignerNode(BaseNode):
|
||||
_node_data_cls: type[BaseNodeData] = VariableAssignerData
|
||||
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
data = cast(VariableAssignerData, self.node_data)
|
||||
|
||||
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
||||
original_variable = variable_pool.get(data.assigned_variable_selector)
|
||||
if not isinstance(original_variable, Variable):
|
||||
raise VariableAssignerNodeError('assigned variable not found')
|
||||
|
||||
match data.write_mode:
|
||||
case WriteMode.OVER_WRITE:
|
||||
income_value = variable_pool.get(data.input_variable_selector)
|
||||
if not income_value:
|
||||
raise VariableAssignerNodeError('input value not found')
|
||||
updated_variable = original_variable.model_copy(update={'value': income_value.value})
|
||||
|
||||
case WriteMode.APPEND:
|
||||
income_value = variable_pool.get(data.input_variable_selector)
|
||||
if not income_value:
|
||||
raise VariableAssignerNodeError('input value not found')
|
||||
updated_value = original_variable.value + [income_value.value]
|
||||
updated_variable = original_variable.model_copy(update={'value': updated_value})
|
||||
|
||||
case WriteMode.CLEAR:
|
||||
income_value = get_zero_value(original_variable.value_type)
|
||||
updated_variable = original_variable.model_copy(update={'value': income_value.to_object()})
|
||||
|
||||
case _:
|
||||
raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
|
||||
|
||||
# Over write the variable.
|
||||
variable_pool.add(data.assigned_variable_selector, updated_variable)
|
||||
|
||||
# Update conversation variable.
|
||||
# TODO: Find a better way to use the database.
|
||||
conversation_id = variable_pool.get(['sys', 'conversation_id'])
|
||||
if not conversation_id:
|
||||
raise VariableAssignerNodeError('conversation_id not found')
|
||||
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={
|
||||
'value': income_value.to_object(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def update_conversation_variable(conversation_id: str, variable: Variable):
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
row = session.scalar(stmt)
|
||||
if not row:
|
||||
raise VariableAssignerNodeError('conversation variable not found in the database')
|
||||
row.data = variable.model_dump_json()
|
||||
session.commit()
|
||||
|
||||
|
||||
def get_zero_value(t: SegmentType):
|
||||
match t:
|
||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
|
||||
return factory.build_segment([])
|
||||
case SegmentType.OBJECT:
|
||||
return factory.build_segment({})
|
||||
case SegmentType.STRING:
|
||||
return factory.build_segment('')
|
||||
case SegmentType.NUMBER:
|
||||
return factory.build_segment(0)
|
||||
case _:
|
||||
raise VariableAssignerNodeError(f'unsupported variable type: {t}')
|
||||
@@ -4,12 +4,11 @@ from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.app_config.entities import FileExtraConfig
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
@@ -30,6 +29,7 @@ from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
|
||||
from core.workflow.nodes.variable_assigner import VariableAssignerNode
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
@@ -51,7 +51,8 @@ node_classes: Mapping[NodeType, type[BaseNode]] = {
|
||||
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
|
||||
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode,
|
||||
NodeType.ITERATION: IterationNode,
|
||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
|
||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
|
||||
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
|
||||
}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -94,10 +95,9 @@ class WorkflowEngineManager:
|
||||
user_id: str,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
user_inputs: Mapping[str, Any],
|
||||
system_inputs: Mapping[SystemVariable, Any],
|
||||
callbacks: Sequence[WorkflowCallback],
|
||||
call_depth: int = 0
|
||||
call_depth: int = 0,
|
||||
variable_pool: VariablePool,
|
||||
) -> None:
|
||||
"""
|
||||
:param workflow: Workflow instance
|
||||
@@ -122,12 +122,6 @@ class WorkflowEngineManager:
|
||||
if not isinstance(graph.get('edges'), list):
|
||||
raise ValueError('edges in workflow graph must be a list')
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=user_inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
|
||||
workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
|
||||
if call_depth > workflow_call_max_depth:
|
||||
@@ -403,6 +397,7 @@ class WorkflowEngineManager:
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=workflow.conversation_variables,
|
||||
)
|
||||
|
||||
if node_cls is None:
|
||||
@@ -468,6 +463,7 @@ class WorkflowEngineManager:
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=workflow.conversation_variables,
|
||||
)
|
||||
|
||||
# variable selector to variable mapping
|
||||
|
||||
Reference in New Issue
Block a user