refactor(api/core/app/segments): Support more kinds of Segments. (#6706)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2024-07-26 15:03:56 +08:00
committed by GitHub
parent 6b50bb0fe6
commit c6996a48a4
11 changed files with 147 additions and 96 deletions

View File

@@ -4,7 +4,7 @@ from typing import Any, Union
from typing_extensions import deprecated
from core.app.segments import Variable, factory
from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar
from core.workflow.entities.node_entities import SystemVariable
@@ -33,7 +33,7 @@ class VariablePool:
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
self._variable_dictionary: dict[str, dict[int, Variable]] = defaultdict(dict)
self._variable_dictionary: dict[str, dict[int, Segment]] = defaultdict(dict)
# TODO: This user inputs is not used for pool.
self.user_inputs = user_inputs
@@ -67,15 +67,15 @@ class VariablePool:
if value is None:
return
if not isinstance(value, Variable):
v = factory.build_anonymous_variable(value)
else:
if isinstance(value, Segment):
v = value
else:
v = factory.build_segment(value)
hash_key = hash(tuple(selector[1:]))
self._variable_dictionary[selector[0]][hash_key] = v
def get(self, selector: Sequence[str], /) -> Variable | None:
def get(self, selector: Sequence[str], /) -> Segment | None:
"""
Retrieves the value from the variable pool based on the given selector.