mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-11 03:46:52 +08:00
Feat/environment variables in workflow (#6515)
Co-authored-by: JzoNg <jzongcode@gmail.com>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -82,9 +83,9 @@ class NodeRunResult(BaseModel):
|
||||
"""
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
|
||||
|
||||
inputs: Optional[dict] = None # node inputs
|
||||
inputs: Optional[Mapping[str, Any]] = None # node inputs
|
||||
process_data: Optional[dict] = None # process data
|
||||
outputs: Optional[dict] = None # node outputs
|
||||
outputs: Optional[Mapping[str, Any]] = None # node outputs
|
||||
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
|
||||
|
||||
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
|
||||
|
||||
@@ -1,101 +1,141 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Union
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.app.segments import ArrayVariable, ObjectVariable, Variable, factory
|
||||
from core.file.file_obj import FileVar
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
|
||||
VariableValue = Union[str, int, float, dict, list, FileVar]
|
||||
|
||||
|
||||
class ValueType(Enum):
|
||||
"""
|
||||
Value Type Enum
|
||||
"""
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
OBJECT = "object"
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_FILE = "array[file]"
|
||||
FILE = "file"
|
||||
SYSTEM_VARIABLE_NODE_ID = 'sys'
|
||||
ENVIRONMENT_VARIABLE_NODE_ID = 'env'
|
||||
|
||||
|
||||
class VariablePool:
|
||||
|
||||
def __init__(self, system_variables: dict[SystemVariable, Any],
|
||||
user_inputs: dict) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
system_variables: Mapping[SystemVariable, Any],
|
||||
user_inputs: Mapping[str, Any],
|
||||
environment_variables: Sequence[Variable],
|
||||
) -> None:
|
||||
# system variables
|
||||
# for example:
|
||||
# {
|
||||
# 'query': 'abc',
|
||||
# 'files': []
|
||||
# }
|
||||
self.variables_mapping = {}
|
||||
|
||||
# Varaible dictionary is a dictionary for looking up variables by their selector.
|
||||
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
||||
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
||||
# elements of the selector except the first one.
|
||||
self._variable_dictionary: dict[str, dict[int, Variable]] = defaultdict(dict)
|
||||
|
||||
# TODO: This user inputs is not used for pool.
|
||||
self.user_inputs = user_inputs
|
||||
|
||||
# Add system variables to the variable pool
|
||||
self.system_variables = system_variables
|
||||
for system_variable, value in system_variables.items():
|
||||
self.append_variable('sys', [system_variable.value], value)
|
||||
for key, value in system_variables.items():
|
||||
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
|
||||
|
||||
def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None:
|
||||
# Add environment variables to the variable pool
|
||||
for var in environment_variables or []:
|
||||
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
|
||||
|
||||
def add(self, selector: Sequence[str], value: Any, /) -> None:
|
||||
"""
|
||||
Append variable
|
||||
:param node_id: node id
|
||||
:param variable_key_list: variable key list, like: ['result', 'text']
|
||||
:param value: value
|
||||
:return:
|
||||
Adds a variable to the variable pool.
|
||||
|
||||
Args:
|
||||
selector (Sequence[str]): The selector for the variable.
|
||||
value (VariableValue): The value of the variable.
|
||||
|
||||
Raises:
|
||||
ValueError: If the selector is invalid.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if node_id not in self.variables_mapping:
|
||||
self.variables_mapping[node_id] = {}
|
||||
if len(selector) < 2:
|
||||
raise ValueError('Invalid selector')
|
||||
|
||||
variable_key_list_hash = hash(tuple(variable_key_list))
|
||||
if value is None:
|
||||
return
|
||||
|
||||
self.variables_mapping[node_id][variable_key_list_hash] = value
|
||||
if not isinstance(value, Variable):
|
||||
v = factory.build_anonymous_variable(value)
|
||||
else:
|
||||
v = value
|
||||
|
||||
def get_variable_value(self, variable_selector: list[str],
|
||||
target_value_type: Optional[ValueType] = None) -> Optional[VariableValue]:
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
self._variable_dictionary[selector[0]][hash_key] = v
|
||||
|
||||
def get(self, selector: Sequence[str], /) -> Variable | None:
|
||||
"""
|
||||
Get variable
|
||||
:param variable_selector: include node_id and variables
|
||||
:param target_value_type: target value type
|
||||
:return:
|
||||
Retrieves the value from the variable pool based on the given selector.
|
||||
|
||||
Args:
|
||||
selector (Sequence[str]): The selector used to identify the variable.
|
||||
|
||||
Returns:
|
||||
Any: The value associated with the given selector.
|
||||
|
||||
Raises:
|
||||
ValueError: If the selector is invalid.
|
||||
"""
|
||||
if len(variable_selector) < 2:
|
||||
raise ValueError('Invalid value selector')
|
||||
|
||||
node_id = variable_selector[0]
|
||||
if node_id not in self.variables_mapping:
|
||||
return None
|
||||
|
||||
# fetch variable keys, pop node_id
|
||||
variable_key_list = variable_selector[1:]
|
||||
|
||||
variable_key_list_hash = hash(tuple(variable_key_list))
|
||||
|
||||
value = self.variables_mapping[node_id].get(variable_key_list_hash)
|
||||
|
||||
if target_value_type:
|
||||
if target_value_type == ValueType.STRING:
|
||||
return str(value)
|
||||
elif target_value_type == ValueType.NUMBER:
|
||||
return int(value)
|
||||
elif target_value_type == ValueType.OBJECT:
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError('Invalid value type: object')
|
||||
elif target_value_type in [ValueType.ARRAY_STRING,
|
||||
ValueType.ARRAY_NUMBER,
|
||||
ValueType.ARRAY_OBJECT,
|
||||
ValueType.ARRAY_FILE]:
|
||||
if not isinstance(value, list):
|
||||
raise ValueError(f'Invalid value type: {target_value_type.value}')
|
||||
if len(selector) < 2:
|
||||
raise ValueError('Invalid selector')
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
value = self._variable_dictionary[selector[0]].get(hash_key)
|
||||
|
||||
return value
|
||||
|
||||
def clear_node_variables(self, node_id: str) -> None:
|
||||
@deprecated('This method is deprecated, use `get` instead.')
|
||||
def get_any(self, selector: Sequence[str], /) -> Any | None:
|
||||
"""
|
||||
Clear node variables
|
||||
:param node_id: node id
|
||||
:return:
|
||||
Retrieves the value from the variable pool based on the given selector.
|
||||
|
||||
Args:
|
||||
selector (Sequence[str]): The selector used to identify the variable.
|
||||
|
||||
Returns:
|
||||
Any: The value associated with the given selector.
|
||||
|
||||
Raises:
|
||||
ValueError: If the selector is invalid.
|
||||
"""
|
||||
if node_id in self.variables_mapping:
|
||||
self.variables_mapping.pop(node_id)
|
||||
if len(selector) < 2:
|
||||
raise ValueError('Invalid selector')
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
value = self._variable_dictionary[selector[0]].get(hash_key)
|
||||
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, ArrayVariable):
|
||||
return [element.value for element in value.value]
|
||||
if isinstance(value, ObjectVariable):
|
||||
return {k: v.value for k, v in value.value.items()}
|
||||
return value.value if value else None
|
||||
|
||||
def remove(self, selector: Sequence[str], /):
|
||||
"""
|
||||
Remove variables from the variable pool based on the given selector.
|
||||
|
||||
Args:
|
||||
selector (Sequence[str]): A sequence of strings representing the selector.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if not selector:
|
||||
return
|
||||
if len(selector) == 1:
|
||||
self._variable_dictionary[selector[0]] = {}
|
||||
return
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
self._variable_dictionary[selector[0]].pop(hash_key, None)
|
||||
|
||||
Reference in New Issue
Block a user