Feat/environment variables in workflow (#6515)

Co-authored-by: JzoNg <jzongcode@gmail.com>
This commit is contained in:
-LAN-
2024-07-22 15:29:39 +08:00
committed by GitHub
parent 87d583f454
commit 5e6fc58db3
146 changed files with 2486 additions and 746 deletions

View File

@@ -1,3 +1,4 @@
from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional
@@ -82,9 +83,9 @@ class NodeRunResult(BaseModel):
"""
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Optional[dict] = None # node inputs
inputs: Optional[Mapping[str, Any]] = None # node inputs
process_data: Optional[dict] = None # process data
outputs: Optional[dict] = None # node outputs
outputs: Optional[Mapping[str, Any]] = None # node outputs
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches

View File

@@ -1,101 +1,141 @@
from enum import Enum
from typing import Any, Optional, Union
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Union
from typing_extensions import deprecated
from core.app.segments import ArrayVariable, ObjectVariable, Variable, factory
from core.file.file_obj import FileVar
from core.workflow.entities.node_entities import SystemVariable
VariableValue = Union[str, int, float, dict, list, FileVar]
class ValueType(Enum):
"""
Value Type Enum
"""
STRING = "string"
NUMBER = "number"
OBJECT = "object"
ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]"
ARRAY_OBJECT = "array[object]"
ARRAY_FILE = "array[file]"
FILE = "file"
SYSTEM_VARIABLE_NODE_ID = 'sys'
ENVIRONMENT_VARIABLE_NODE_ID = 'env'
class VariablePool:
def __init__(self, system_variables: dict[SystemVariable, Any],
user_inputs: dict) -> None:
def __init__(
self,
system_variables: Mapping[SystemVariable, Any],
user_inputs: Mapping[str, Any],
environment_variables: Sequence[Variable],
) -> None:
# system variables
# for example:
# {
# 'query': 'abc',
# 'files': []
# }
self.variables_mapping = {}
# Varaible dictionary is a dictionary for looking up variables by their selector.
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
self._variable_dictionary: dict[str, dict[int, Variable]] = defaultdict(dict)
# TODO: This user inputs is not used for pool.
self.user_inputs = user_inputs
# Add system variables to the variable pool
self.system_variables = system_variables
for system_variable, value in system_variables.items():
self.append_variable('sys', [system_variable.value], value)
for key, value in system_variables.items():
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None:
# Add environment variables to the variable pool
for var in environment_variables or []:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
def add(self, selector: Sequence[str], value: Any, /) -> None:
"""
Append variable
:param node_id: node id
:param variable_key_list: variable key list, like: ['result', 'text']
:param value: value
:return:
Adds a variable to the variable pool.
Args:
selector (Sequence[str]): The selector for the variable.
value (VariableValue): The value of the variable.
Raises:
ValueError: If the selector is invalid.
Returns:
None
"""
if node_id not in self.variables_mapping:
self.variables_mapping[node_id] = {}
if len(selector) < 2:
raise ValueError('Invalid selector')
variable_key_list_hash = hash(tuple(variable_key_list))
if value is None:
return
self.variables_mapping[node_id][variable_key_list_hash] = value
if not isinstance(value, Variable):
v = factory.build_anonymous_variable(value)
else:
v = value
def get_variable_value(self, variable_selector: list[str],
target_value_type: Optional[ValueType] = None) -> Optional[VariableValue]:
hash_key = hash(tuple(selector[1:]))
self._variable_dictionary[selector[0]][hash_key] = v
def get(self, selector: Sequence[str], /) -> Variable | None:
"""
Get variable
:param variable_selector: include node_id and variables
:param target_value_type: target value type
:return:
Retrieves the value from the variable pool based on the given selector.
Args:
selector (Sequence[str]): The selector used to identify the variable.
Returns:
Any: The value associated with the given selector.
Raises:
ValueError: If the selector is invalid.
"""
if len(variable_selector) < 2:
raise ValueError('Invalid value selector')
node_id = variable_selector[0]
if node_id not in self.variables_mapping:
return None
# fetch variable keys, pop node_id
variable_key_list = variable_selector[1:]
variable_key_list_hash = hash(tuple(variable_key_list))
value = self.variables_mapping[node_id].get(variable_key_list_hash)
if target_value_type:
if target_value_type == ValueType.STRING:
return str(value)
elif target_value_type == ValueType.NUMBER:
return int(value)
elif target_value_type == ValueType.OBJECT:
if not isinstance(value, dict):
raise ValueError('Invalid value type: object')
elif target_value_type in [ValueType.ARRAY_STRING,
ValueType.ARRAY_NUMBER,
ValueType.ARRAY_OBJECT,
ValueType.ARRAY_FILE]:
if not isinstance(value, list):
raise ValueError(f'Invalid value type: {target_value_type.value}')
if len(selector) < 2:
raise ValueError('Invalid selector')
hash_key = hash(tuple(selector[1:]))
value = self._variable_dictionary[selector[0]].get(hash_key)
return value
def clear_node_variables(self, node_id: str) -> None:
@deprecated('This method is deprecated, use `get` instead.')
def get_any(self, selector: Sequence[str], /) -> Any | None:
"""
Clear node variables
:param node_id: node id
:return:
Retrieves the value from the variable pool based on the given selector.
Args:
selector (Sequence[str]): The selector used to identify the variable.
Returns:
Any: The value associated with the given selector.
Raises:
ValueError: If the selector is invalid.
"""
if node_id in self.variables_mapping:
self.variables_mapping.pop(node_id)
if len(selector) < 2:
raise ValueError('Invalid selector')
hash_key = hash(tuple(selector[1:]))
value = self._variable_dictionary[selector[0]].get(hash_key)
if value is None:
return value
if isinstance(value, ArrayVariable):
return [element.value for element in value.value]
if isinstance(value, ObjectVariable):
return {k: v.value for k, v in value.value.items()}
return value.value if value else None
def remove(self, selector: Sequence[str], /):
"""
Remove variables from the variable pool based on the given selector.
Args:
selector (Sequence[str]): A sequence of strings representing the selector.
Returns:
None
"""
if not selector:
return
if len(selector) == 1:
self._variable_dictionary[selector[0]] = {}
return
hash_key = hash(tuple(selector[1:]))
self._variable_dictionary[selector[0]].pop(hash_key, None)