feat: Parallel Execution of Nodes in Workflows (#8192)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Yi <yxiaoisme@gmail.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
takatost
2024-09-10 15:23:16 +08:00
committed by GitHub
parent 5da0182800
commit dabfd74622
156 changed files with 11158 additions and 5605 deletions

View File

@@ -0,0 +1,31 @@
from abc import ABC, abstractmethod
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
class RunConditionHandler(ABC):
    """Base class for handlers that decide whether an edge's run condition is satisfied."""

    def __init__(self, init_params: GraphInitParams, graph: Graph, condition: RunCondition):
        """
        :param init_params: graph init params
        :param graph: graph the edge belongs to
        :param condition: run condition attached to the edge
        """
        self.init_params = init_params
        self.graph = graph
        self.condition = condition

    @abstractmethod
    def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
        """
        Check if the condition can be executed

        :param graph_runtime_state: graph runtime state
        :param previous_route_node_state: previous route node state
        :return: bool
        """
        raise NotImplementedError

View File

@@ -0,0 +1,28 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
class BranchIdentifyRunConditionHandler(RunConditionHandler):
    """Checks that the branch taken by the previous node matches the edge's branch identify."""

    def check(self,
              graph_runtime_state: GraphRuntimeState,
              previous_route_node_state: RouteNodeState) -> bool:
        """
        Check if the condition can be executed

        :param graph_runtime_state: graph runtime state
        :param previous_route_node_state: previous route node state
        :return: bool
        :raises ValueError: if the condition is missing its mandatory branch_identify
        """
        if not self.condition.branch_identify:
            # misconfiguration: this handler is only valid for branch_identify
            # conditions (ValueError is more precise than a bare Exception and
            # still caught by callers handling Exception)
            raise ValueError("Branch identify is required")

        run_result = previous_route_node_state.node_run_result
        if not run_result or not run_result.edge_source_handle:
            # previous node has no result yet, or did not take a branch
            return False

        return self.condition.branch_identify == run_result.edge_source_handle

View File

@@ -0,0 +1,32 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.utils.condition.processor import ConditionProcessor
class ConditionRunConditionHandlerHandler(RunConditionHandler):
    """Evaluates the edge's variable conditions against the runtime variable pool."""

    def check(self,
              graph_runtime_state: GraphRuntimeState,
              previous_route_node_state: RouteNodeState
              ) -> bool:
        """
        Check if the condition can be executed

        :param graph_runtime_state: graph runtime state
        :param previous_route_node_state: previous route node state
        :return: bool
        """
        # no conditions configured -> the edge is always runnable
        if not self.condition.conditions:
            return True

        # evaluate the conditions against the variable pool; the edge can only
        # run when every condition in the group holds
        processor = ConditionProcessor()
        _, group_result = processor.process_conditions(
            variable_pool=graph_runtime_state.variable_pool,
            conditions=self.condition.conditions
        )
        return all(group_result)

View File

@@ -0,0 +1,35 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.run_condition import RunCondition
class ConditionManager:
    """Factory that picks the run-condition handler matching a RunCondition's type."""

    @staticmethod
    def get_condition_handler(
            init_params: GraphInitParams,
            graph: Graph,
            run_condition: RunCondition
    ) -> RunConditionHandler:
        """
        Get condition handler

        :param init_params: init params
        :param graph: graph
        :param run_condition: run condition
        :return: condition handler
        """
        # "branch_identify" edges compare the taken branch; every other type is
        # treated as a variable-condition edge
        handler_cls = (
            BranchIdentifyRunConditionHandler
            if run_condition.type == "branch_identify"
            else ConditionRunConditionHandlerHandler
        )
        return handler_cls(
            init_params=init_params,
            graph=graph,
            condition=run_condition
        )

View File

@@ -0,0 +1,163 @@
from datetime import datetime
from typing import Any, Optional
from pydantic import BaseModel, Field
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
class GraphEngineEvent(BaseModel):
    """Base class for every event emitted by the graph engine."""
    pass
###########################################
# Graph Events
###########################################
class BaseGraphEvent(GraphEngineEvent):
    """Base class for graph-level run events."""
    pass
class GraphRunStartedEvent(BaseGraphEvent):
    """Emitted when a graph run starts."""
    pass
class GraphRunSucceededEvent(BaseGraphEvent):
    """Emitted when a graph run finishes successfully, carrying the final outputs."""
    outputs: Optional[dict[str, Any]] = None
    """outputs"""
class GraphRunFailedEvent(BaseGraphEvent):
    """Emitted when a graph run fails."""
    error: str = Field(..., description="failed reason")
###########################################
# Node Events
###########################################
class BaseNodeEvent(GraphEngineEvent):
    """Base class for node-level run events; carries node identity plus parallel/iteration context."""
    id: str = Field(..., description="node execution id")
    node_id: str = Field(..., description="node id")
    node_type: NodeType = Field(..., description="node type")
    node_data: BaseNodeData = Field(..., description="node data")
    route_node_state: RouteNodeState = Field(..., description="route node state")
    parallel_id: Optional[str] = None
    """parallel id if node is in parallel"""
    parallel_start_node_id: Optional[str] = None
    """parallel start node id if node is in parallel"""
    parent_parallel_id: Optional[str] = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id if node is in parallel"""
    in_iteration_id: Optional[str] = None
    """iteration id if node is in iteration"""
class NodeRunStartedEvent(BaseNodeEvent):
    """Emitted when a node starts running."""
    predecessor_node_id: Optional[str] = None
    """predecessor node id"""
class NodeRunStreamChunkEvent(BaseNodeEvent):
    """Emitted for each streamed output chunk produced while a node runs."""
    chunk_content: str = Field(..., description="chunk content")
    from_variable_selector: Optional[list[str]] = None
    """from variable selector"""
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
    """Emitted when a node produces retriever resources (e.g. retrieved context documents)."""
    retriever_resources: list[dict] = Field(..., description="retriever resources")
    context: str = Field(..., description="context")
class NodeRunSucceededEvent(BaseNodeEvent):
    """Emitted when a node run succeeds; result lives in route_node_state."""
    pass
class NodeRunFailedEvent(BaseNodeEvent):
    """Emitted when a node run fails."""
    error: str = Field(..., description="error")
###########################################
# Parallel Branch Events
###########################################
class BaseParallelBranchEvent(GraphEngineEvent):
    """Base class for events about a single branch of a parallel section."""
    parallel_id: str = Field(..., description="parallel id")
    """parallel id"""
    parallel_start_node_id: str = Field(..., description="parallel start node id")
    """parallel start node id"""
    parent_parallel_id: Optional[str] = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id if node is in parallel"""
    in_iteration_id: Optional[str] = None
    """iteration id if node is in iteration"""
class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
    """Emitted when a parallel branch starts running."""
    pass
class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent):
    """Emitted when a parallel branch finishes successfully."""
    pass
class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
    """Emitted when a parallel branch fails."""
    error: str = Field(..., description="failed reason")
###########################################
# Iteration Events
###########################################
class BaseIterationEvent(GraphEngineEvent):
    """Base class for iteration-node events."""
    iteration_id: str = Field(..., description="iteration node execution id")
    iteration_node_id: str = Field(..., description="iteration node id")
    iteration_node_type: NodeType = Field(..., description="node type, iteration or loop")
    iteration_node_data: BaseNodeData = Field(..., description="node data")
    parallel_id: Optional[str] = None
    """parallel id if node is in parallel"""
    parallel_start_node_id: Optional[str] = None
    """parallel start node id if node is in parallel"""
    parent_parallel_id: Optional[str] = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id if node is in parallel"""
class IterationRunStartedEvent(BaseIterationEvent):
    """Emitted when an iteration run starts."""
    start_at: datetime = Field(..., description="start at")
    inputs: Optional[dict[str, Any]] = None
    metadata: Optional[dict[str, Any]] = None
    predecessor_node_id: Optional[str] = None
class IterationRunNextEvent(BaseIterationEvent):
    """Emitted before each iteration step, carrying the step index and previous step output."""
    index: int = Field(..., description="index")
    pre_iteration_output: Optional[Any] = Field(None, description="pre iteration output")
class IterationRunSucceededEvent(BaseIterationEvent):
    """Emitted when an iteration run completes successfully."""
    start_at: datetime = Field(..., description="start at")
    inputs: Optional[dict[str, Any]] = None
    outputs: Optional[dict[str, Any]] = None
    metadata: Optional[dict[str, Any]] = None
    steps: int = 0
class IterationRunFailedEvent(BaseIterationEvent):
    """Emitted when an iteration run fails."""
    start_at: datetime = Field(..., description="start at")
    inputs: Optional[dict[str, Any]] = None
    outputs: Optional[dict[str, Any]] = None
    metadata: Optional[dict[str, Any]] = None
    steps: int = 0
    error: str = Field(..., description="failed reason")
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent

View File

@@ -0,0 +1,692 @@
import uuid
from collections.abc import Mapping
from typing import Any, Optional, cast
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter
from core.workflow.nodes.end.entities import EndStreamParam
class GraphEdge(BaseModel):
    """Directed edge between two nodes, optionally guarded by a run condition."""
    source_node_id: str = Field(..., description="source node id")
    target_node_id: str = Field(..., description="target node id")
    run_condition: Optional[RunCondition] = None
    """run condition"""
class GraphParallel(BaseModel):
    """A parallel section of the graph: branches fanning out from one node, possibly nested."""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
    start_from_node_id: str = Field(..., description="start from node id")
    parent_parallel_id: Optional[str] = None
    """parent parallel id"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id"""
    end_to_node_id: Optional[str] = None
    """end to node id"""
class Graph(BaseModel):
    """Parsed workflow graph: nodes, edges, parallel sections and stream routing params."""

    root_node_id: str = Field(..., description="root node id of the graph")
    node_ids: list[str] = Field(default_factory=list, description="graph node ids")
    node_id_config_mapping: dict[str, dict] = Field(
        # bug fix: was `default_factory=list`, which would default a dict-typed
        # field to a list; dict is the correct factory
        default_factory=dict,
        description="node configs mapping (node id: node config)"
    )
    edge_mapping: dict[str, list[GraphEdge]] = Field(
        default_factory=dict,
        description="graph edge mapping (source node id: edges)"
    )
    reverse_edge_mapping: dict[str, list[GraphEdge]] = Field(
        default_factory=dict,
        description="reverse graph edge mapping (target node id: edges)"
    )
    parallel_mapping: dict[str, GraphParallel] = Field(
        default_factory=dict,
        description="graph parallel mapping (parallel id: parallel)"
    )
    node_parallel_mapping: dict[str, str] = Field(
        default_factory=dict,
        description="graph node parallel mapping (node id: parallel id)"
    )
    answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(
        ...,
        description="answer stream generate routes"
    )
    end_stream_param: EndStreamParam = Field(
        ...,
        description="end stream param"
    )
@classmethod
def init(cls,
         graph_config: Mapping[str, Any],
         root_node_id: Optional[str] = None) -> "Graph":
    """
    Init graph

    Builds a Graph from a raw workflow config: parses edges (deduplicating
    same source->target pairs and attaching branch run conditions), finds the
    root node, validates connectivity, collects reachable node ids, detects
    parallel sections (max 3 nesting levels) and initializes answer/end stream
    routing params.

    :param graph_config: graph config (expects 'edges' and 'nodes' lists)
    :param root_node_id: root node id
    :return: graph
    :raises ValueError: no nodes, unknown root node, cycle, or parallel nesting > 3
    """
    # edge configs
    edge_configs = graph_config.get('edges')
    if edge_configs is None:
        edge_configs = []

    edge_configs = cast(list, edge_configs)

    # reorganize edges mapping
    edge_mapping: dict[str, list[GraphEdge]] = {}
    reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
    target_edge_ids = set()  # node ids that have at least one incoming edge
    for edge_config in edge_configs:
        source_node_id = edge_config.get('source')
        if not source_node_id:
            continue

        if source_node_id not in edge_mapping:
            edge_mapping[source_node_id] = []

        target_node_id = edge_config.get('target')
        if not target_node_id:
            continue

        if target_node_id not in reverse_edge_mapping:
            reverse_edge_mapping[target_node_id] = []

        # is target node id in source node id edge mapping
        # (skip duplicate edges between the same source/target pair)
        if any(graph_edge.target_node_id == target_node_id for graph_edge in edge_mapping[source_node_id]):
            continue

        target_edge_ids.add(target_node_id)

        # parse run condition: any non-default sourceHandle marks a branch edge
        run_condition = None
        if edge_config.get('sourceHandle') and edge_config.get('sourceHandle') != 'source':
            run_condition = RunCondition(
                type='branch_identify',
                branch_identify=edge_config.get('sourceHandle')
            )

        graph_edge = GraphEdge(
            source_node_id=source_node_id,
            target_node_id=target_node_id,
            run_condition=run_condition
        )

        edge_mapping[source_node_id].append(graph_edge)
        reverse_edge_mapping[target_node_id].append(graph_edge)

    # node configs
    node_configs = graph_config.get('nodes')
    if not node_configs:
        raise ValueError("Graph must have at least one node")

    node_configs = cast(list, node_configs)

    # fetch nodes that have no predecessor node
    root_node_configs = []
    all_node_id_config_mapping: dict[str, dict] = {}
    for node_config in node_configs:
        node_id = node_config.get('id')
        if not node_id:
            continue

        if node_id not in target_edge_ids:
            root_node_configs.append(node_config)

        all_node_id_config_mapping[node_id] = node_config

    root_node_ids = [node_config.get('id') for node_config in root_node_configs]

    # fetch root node
    if not root_node_id:
        # if no root node id, use the START type node as root node
        root_node_id = next((node_config.get("id") for node_config in root_node_configs
                             if node_config.get('data', {}).get('type', '') == NodeType.START.value), None)

    if not root_node_id or root_node_id not in root_node_ids:
        raise ValueError(f"Root node id {root_node_id} not found in the graph")

    # Check whether it is connected to the previous node (cycle detection)
    cls._check_connected_to_previous_node(
        route=[root_node_id],
        edge_mapping=edge_mapping
    )

    # fetch all node ids from root node (only nodes reachable from the root are kept)
    node_ids = [root_node_id]
    cls._recursively_add_node_ids(
        node_ids=node_ids,
        edge_mapping=edge_mapping,
        node_id=root_node_id
    )

    node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids}

    # init parallel mapping
    parallel_mapping: dict[str, GraphParallel] = {}
    node_parallel_mapping: dict[str, str] = {}
    cls._recursively_add_parallels(
        edge_mapping=edge_mapping,
        reverse_edge_mapping=reverse_edge_mapping,
        start_node_id=root_node_id,
        parallel_mapping=parallel_mapping,
        node_parallel_mapping=node_parallel_mapping
    )

    # Check if it exceeds N layers of parallel
    for parallel in parallel_mapping.values():
        if parallel.parent_parallel_id:
            cls._check_exceed_parallel_limit(
                parallel_mapping=parallel_mapping,
                level_limit=3,
                parent_parallel_id=parallel.parent_parallel_id
            )

    # init answer stream generate routes
    answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
        node_id_config_mapping=node_id_config_mapping,
        reverse_edge_mapping=reverse_edge_mapping
    )

    # init end stream param
    end_stream_param = EndStreamGeneratorRouter.init(
        node_id_config_mapping=node_id_config_mapping,
        reverse_edge_mapping=reverse_edge_mapping,
        node_parallel_mapping=node_parallel_mapping
    )

    # init graph
    graph = cls(
        root_node_id=root_node_id,
        node_ids=node_ids,
        node_id_config_mapping=node_id_config_mapping,
        edge_mapping=edge_mapping,
        reverse_edge_mapping=reverse_edge_mapping,
        parallel_mapping=parallel_mapping,
        node_parallel_mapping=node_parallel_mapping,
        answer_stream_generate_routes=answer_stream_generate_routes,
        end_stream_param=end_stream_param
    )

    return graph
def add_extra_edge(self, source_node_id: str,
                   target_node_id: str,
                   run_condition: Optional[RunCondition] = None) -> None:
    """
    Add extra edge to the graph (no-op when either node is unknown or the
    edge already exists).

    :param source_node_id: source node id
    :param target_node_id: target node id
    :param run_condition: run condition
    """
    # both endpoints must belong to this graph
    if source_node_id not in self.node_ids or target_node_id not in self.node_ids:
        return

    edges = self.edge_mapping.setdefault(source_node_id, [])

    # idempotent: keep at most one edge per source/target pair
    if any(edge.target_node_id == target_node_id for edge in edges):
        return

    edges.append(GraphEdge(
        source_node_id=source_node_id,
        target_node_id=target_node_id,
        run_condition=run_condition
    ))
def get_leaf_node_ids(self) -> list[str]:
    """
    Get leaf node ids of the graph: nodes with no outgoing edges, or whose
    single outgoing edge loops back to the root node.

    :return: leaf node ids
    """
    def _is_leaf(node_id: str) -> bool:
        # no entry in the edge mapping at all -> leaf
        if node_id not in self.edge_mapping:
            return True
        edges = self.edge_mapping[node_id]
        # a lone back-edge to the root also counts as a leaf
        return len(edges) == 1 and edges[0].target_node_id == self.root_node_id

    return [node_id for node_id in self.node_ids if _is_leaf(node_id)]
@classmethod
def _recursively_add_node_ids(cls,
                              node_ids: list[str],
                              edge_mapping: dict[str, list[GraphEdge]],
                              node_id: str) -> None:
    """
    Recursively add node ids: depth-first walk from node_id appending every
    unseen target node id to node_ids (mutated in place).

    :param node_ids: node ids collected so far (also acts as the visited set)
    :param edge_mapping: edge mapping
    :param node_id: node id to expand
    """
    for edge in edge_mapping.get(node_id, []):
        target = edge.target_node_id
        if target in node_ids:
            continue
        node_ids.append(target)
        cls._recursively_add_node_ids(
            node_ids=node_ids,
            edge_mapping=edge_mapping,
            node_id=target
        )
@classmethod
def _check_connected_to_previous_node(
        cls,
        route: list[str],
        edge_mapping: dict[str, list[GraphEdge]]
) -> None:
    """
    Check whether it is connected to the previous node.

    Walks every path extending ``route`` and raises when a path loops back
    onto a node already on it (i.e. a cycle reachable from the root).
    NOTE(review): this explores every distinct path, which can be expensive
    on graphs with heavy fan-out — presumably acceptable for workflow-sized
    graphs.

    :param route: node ids on the current path; the last entry is the walk head
    :param edge_mapping: edge mapping (source node id: edges)
    :raises ValueError: when a cycle is detected
    """
    last_node_id = route[-1]

    for graph_edge in edge_mapping.get(last_node_id, []):
        if not graph_edge.target_node_id:
            continue

        # target already on the current path -> cycle
        if graph_edge.target_node_id in route:
            raise ValueError(f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph.")

        new_route = route[:]
        new_route.append(graph_edge.target_node_id)
        cls._check_connected_to_previous_node(
            route=new_route,
            edge_mapping=edge_mapping,
        )
@classmethod
def _recursively_add_parallels(
        cls,
        edge_mapping: dict[str, list[GraphEdge]],
        reverse_edge_mapping: dict[str, list[GraphEdge]],
        start_node_id: str,
        parallel_mapping: dict[str, GraphParallel],
        node_parallel_mapping: dict[str, str],
        parent_parallel: Optional[GraphParallel] = None
) -> None:
    """
    Recursively add parallel ids.

    A node with more than one outgoing edge opens a parallel section: every
    unconditioned edge is a branch, and conditioned edges form branches only
    when several edges share the same condition hash. Fills parallel_mapping
    and node_parallel_mapping in place, then recurses into every target node.

    :param edge_mapping: edge mapping
    :param reverse_edge_mapping: reverse edge mapping (target node id: edges)
    :param start_node_id: start from node id
    :param parallel_mapping: parallel mapping (mutated in place)
    :param node_parallel_mapping: node parallel mapping (mutated in place)
    :param parent_parallel: parallel section enclosing start_node_id, if any
    """
    target_node_edges = edge_mapping.get(start_node_id, [])
    parallel = None
    if len(target_node_edges) > 1:
        # fetch all node ids in current parallels
        parallel_branch_node_ids = []
        condition_edge_mappings = {}
        for graph_edge in target_node_edges:
            if graph_edge.run_condition is None:
                # unconditioned edge -> always a parallel branch
                parallel_branch_node_ids.append(graph_edge.target_node_id)
            else:
                # group conditioned edges by condition hash
                condition_hash = graph_edge.run_condition.hash
                if not condition_hash in condition_edge_mappings:
                    condition_edge_mappings[condition_hash] = []

                condition_edge_mappings[condition_hash].append(graph_edge)

        # several edges sharing one condition run in parallel with each other
        for _, graph_edges in condition_edge_mappings.items():
            if len(graph_edges) > 1:
                for graph_edge in graph_edges:
                    parallel_branch_node_ids.append(graph_edge.target_node_id)

        # any target node id in node_parallel_mapping
        if parallel_branch_node_ids:
            parent_parallel_id = parent_parallel.id if parent_parallel else None

            parallel = GraphParallel(
                start_from_node_id=start_node_id,
                parent_parallel_id=parent_parallel.id if parent_parallel else None,
                parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None
            )
            parallel_mapping[parallel.id] = parallel

            in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
                edge_mapping=edge_mapping,
                reverse_edge_mapping=reverse_edge_mapping,
                parallel_branch_node_ids=parallel_branch_node_ids
            )

            # collect all branches node ids; when nested, only nodes already
            # assigned to the parent parallel may be claimed by this one
            parallel_node_ids = []
            for _, node_ids in in_branch_node_ids.items():
                for node_id in node_ids:
                    in_parent_parallel = True
                    if parent_parallel_id:
                        in_parent_parallel = False
                        for parallel_node_id, parallel_id in node_parallel_mapping.items():
                            if parallel_id == parent_parallel_id and parallel_node_id == node_id:
                                in_parent_parallel = True
                                break

                    if in_parent_parallel:
                        parallel_node_ids.append(node_id)
                        node_parallel_mapping[node_id] = parallel.id

            # find where the parallel section converges: single-edge nodes
            # whose target lies outside the section
            outside_parallel_target_node_ids = set()
            for node_id in parallel_node_ids:
                if node_id == parallel.start_from_node_id:
                    continue

                node_edges = edge_mapping.get(node_id)
                if not node_edges:
                    continue

                if len(node_edges) > 1:
                    continue

                target_node_id = node_edges[0].target_node_id
                if target_node_id in parallel_node_ids:
                    continue

                if parent_parallel_id:
                    # NOTE(review): this rebinds the `parent_parallel` parameter
                    # from the mapping — presumably to refresh end_to_node_id;
                    # the rebound value is also used by the recursion below
                    parent_parallel = parallel_mapping.get(parent_parallel_id)
                    if not parent_parallel:
                        continue

                if (
                    (node_parallel_mapping.get(target_node_id) and node_parallel_mapping.get(target_node_id) == parent_parallel_id)
                    or (parent_parallel and parent_parallel.end_to_node_id and target_node_id == parent_parallel.end_to_node_id)
                    or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
                ):
                    outside_parallel_target_node_ids.add(target_node_id)

            if len(outside_parallel_target_node_ids) == 1:
                if parent_parallel and parent_parallel.end_to_node_id and parallel.end_to_node_id == parent_parallel.end_to_node_id:
                    parallel.end_to_node_id = None
                else:
                    parallel.end_to_node_id = outside_parallel_target_node_ids.pop()

    # recurse into every outgoing edge, carrying the enclosing parallel (or
    # climbing to the grandparent once the parent's merge node is passed)
    for graph_edge in target_node_edges:
        current_parallel = None
        if parallel:
            current_parallel = parallel
        elif parent_parallel:
            if not parent_parallel.end_to_node_id or (parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id):
                current_parallel = parent_parallel
            else:
                # fetch parent parallel's parent parallel
                parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id
                if parent_parallel_parent_parallel_id:
                    parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id)
                    if (
                        parent_parallel_parent_parallel
                        and (
                            not parent_parallel_parent_parallel.end_to_node_id
                            or (parent_parallel_parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id)
                        )
                    ):
                        current_parallel = parent_parallel_parent_parallel

        cls._recursively_add_parallels(
            edge_mapping=edge_mapping,
            reverse_edge_mapping=reverse_edge_mapping,
            start_node_id=graph_edge.target_node_id,
            parallel_mapping=parallel_mapping,
            node_parallel_mapping=node_parallel_mapping,
            parent_parallel=current_parallel
        )
@classmethod
def _check_exceed_parallel_limit(
        cls,
        parallel_mapping: dict[str, GraphParallel],
        level_limit: int,
        parent_parallel_id: str,
        current_level: int = 1
) -> None:
    """
    Check if it exceeds N layers of parallel: walk up the parent chain and
    raise once the nesting depth passes level_limit.

    :param parallel_mapping: parallel mapping (parallel id: parallel)
    :param level_limit: maximum allowed nesting depth
    :param parent_parallel_id: id of the enclosing parallel to start from
    :param current_level: depth counted so far
    :raises ValueError: when the nesting depth exceeds level_limit
    """
    # iterative form of the parent-chain walk
    parent = parallel_mapping.get(parent_parallel_id)
    while parent is not None:
        current_level += 1
        if current_level > level_limit:
            raise ValueError(f"Exceeds {level_limit} layers of parallel")

        if not parent.parent_parallel_id:
            break
        parent = parallel_mapping.get(parent.parent_parallel_id)
@classmethod
def _recursively_add_parallel_node_ids(cls,
                                       branch_node_ids: list[str],
                                       edge_mapping: dict[str, list[GraphEdge]],
                                       merge_node_id: str,
                                       start_node_id: str) -> None:
    """
    Recursively add node ids belonging to one branch, stopping at the merge
    node (branch_node_ids is mutated in place).

    :param branch_node_ids: in branch node ids
    :param edge_mapping: edge mapping
    :param merge_node_id: merge node id
    :param start_node_id: start node id
    """
    for edge in edge_mapping.get(start_node_id, []):
        target = edge.target_node_id
        # stop at the merge node; skip nodes already collected
        if target == merge_node_id or target in branch_node_ids:
            continue
        branch_node_ids.append(target)
        cls._recursively_add_parallel_node_ids(
            branch_node_ids=branch_node_ids,
            edge_mapping=edge_mapping,
            merge_node_id=merge_node_id,
            start_node_id=target
        )
@classmethod
def _fetch_all_node_ids_in_parallels(cls,
                                     edge_mapping: dict[str, list[GraphEdge]],
                                     reverse_edge_mapping: dict[str, list[GraphEdge]],
                                     parallel_branch_node_ids: list[str]) -> dict[str, list[str]]:
    """
    Fetch all node ids in parallels.

    For each branch start node, walks its full downstream route, detects
    where branches merge back together, and returns only the node ids that
    belong to each branch up to (but excluding) its merge node.

    :param edge_mapping: edge mapping (source node id: edges)
    :param reverse_edge_mapping: reverse edge mapping (target node id: edges)
    :param parallel_branch_node_ids: first node id of each parallel branch
    :return: mapping of branch start node id -> node ids inside that branch
    """
    # full downstream route of every branch (branch start id -> node ids)
    routes_node_ids: dict[str, list[str]] = {}
    for parallel_branch_node_id in parallel_branch_node_ids:
        routes_node_ids[parallel_branch_node_id] = [parallel_branch_node_id]

        # fetch routes node ids
        cls._recursively_fetch_routes(
            edge_mapping=edge_mapping,
            start_node_id=parallel_branch_node_id,
            routes_node_ids=routes_node_ids[parallel_branch_node_id]
        )

    # fetch leaf node ids from routes node ids
    leaf_node_ids: dict[str, list[str]] = {}
    # candidate merge node id -> branches whose routes pass through it
    merge_branch_node_ids: dict[str, list[str]] = {}
    for branch_node_id, node_ids in routes_node_ids.items():
        for node_id in node_ids:
            if node_id not in edge_mapping or len(edge_mapping[node_id]) == 0:
                if branch_node_id not in leaf_node_ids:
                    leaf_node_ids[branch_node_id] = []

                leaf_node_ids[branch_node_id].append(node_id)

            # a node shared by another branch's route with >1 incoming edges
            # is a potential merge point of this parallel section
            for branch_node_id2, inner_route2 in routes_node_ids.items():
                if (
                    branch_node_id != branch_node_id2
                    and node_id in inner_route2
                    and len(reverse_edge_mapping.get(node_id, [])) > 1
                    and cls._is_node_in_routes(
                        reverse_edge_mapping=reverse_edge_mapping,
                        start_node_id=node_id,
                        routes_node_ids=routes_node_ids
                    )
                ):
                    if node_id not in merge_branch_node_ids:
                        merge_branch_node_ids[node_id] = []

                    if branch_node_id2 not in merge_branch_node_ids[node_id]:
                        merge_branch_node_ids[node_id].append(branch_node_id2)

    # sorted merge_branch_node_ids by branch_node_ids length desc
    merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True))

    # when two merge candidates cover the same branch set, keep only the
    # earlier one on the path
    duplicate_end_node_ids = {}
    for node_id, branch_node_ids in merge_branch_node_ids.items():
        for node_id2, branch_node_ids2 in merge_branch_node_ids.items():
            if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2):
                if (node_id, node_id2) not in duplicate_end_node_ids and (node_id2, node_id) not in duplicate_end_node_ids:
                    duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids

    for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
        # check which node is after
        if cls._is_node2_after_node1(
            node1_id=node_id,
            node2_id=node_id2,
            edge_mapping=edge_mapping
        ):
            if node_id in merge_branch_node_ids:
                del merge_branch_node_ids[node_id2]
        elif cls._is_node2_after_node1(
            node1_id=node_id2,
            node2_id=node_id,
            edge_mapping=edge_mapping
        ):
            if node_id2 in merge_branch_node_ids:
                del merge_branch_node_ids[node_id]

    # branch start node id -> the merge node ending that branch
    branches_merge_node_ids: dict[str, str] = {}
    for node_id, branch_node_ids in merge_branch_node_ids.items():
        if len(branch_node_ids) <= 1:
            continue

        for branch_node_id in branch_node_ids:
            if branch_node_id in branches_merge_node_ids:
                continue

            branches_merge_node_ids[branch_node_id] = node_id

    in_branch_node_ids: dict[str, list[str]] = {}
    for branch_node_id, node_ids in routes_node_ids.items():
        in_branch_node_ids[branch_node_id] = []
        if branch_node_id not in branches_merge_node_ids:
            # all node ids in current branch is in this thread
            in_branch_node_ids[branch_node_id].append(branch_node_id)
            in_branch_node_ids[branch_node_id].extend(node_ids)
        else:
            merge_node_id = branches_merge_node_ids[branch_node_id]
            if merge_node_id != branch_node_id:
                in_branch_node_ids[branch_node_id].append(branch_node_id)

            # fetch all node ids from branch_node_id and merge_node_id
            cls._recursively_add_parallel_node_ids(
                branch_node_ids=in_branch_node_ids[branch_node_id],
                edge_mapping=edge_mapping,
                merge_node_id=merge_node_id,
                start_node_id=branch_node_id
            )

    return in_branch_node_ids
@classmethod
def _recursively_fetch_routes(cls,
                              edge_mapping: dict[str, list[GraphEdge]],
                              start_node_id: str,
                              routes_node_ids: list[str]) -> None:
    """
    Recursively fetch route: append every node id reachable from
    start_node_id to routes_node_ids (mutated in place, duplicates skipped).

    :param edge_mapping: edge mapping
    :param start_node_id: node to expand
    :param routes_node_ids: collected node ids (also acts as the visited set)
    """
    for edge in edge_mapping.get(start_node_id, []):
        # find next node ids
        if edge.target_node_id in routes_node_ids:
            continue
        routes_node_ids.append(edge.target_node_id)
        cls._recursively_fetch_routes(
            edge_mapping=edge_mapping,
            start_node_id=edge.target_node_id,
            routes_node_ids=routes_node_ids
        )
@classmethod
def _is_node_in_routes(cls,
                       reverse_edge_mapping: dict[str, list[GraphEdge]],
                       start_node_id: str,
                       routes_node_ids: dict[str, list[str]]) -> bool:
    """
    Recursively check if the node is in the routes.

    Looks for a common predecessor that feeds every branch start; returns
    True as soon as one is found.

    NOTE(review): the function returns True inside the loop the moment a
    common predecessor is found, so the trailing per-edge check only runs
    when no such predecessor exists — in which case the raise fires first.
    Looks intentional-but-fragile; verify against callers before changing.

    :param reverse_edge_mapping: reverse edge mapping (target node id: edges)
    :param start_node_id: candidate merge node id
    :param routes_node_ids: mapping of branch start node id -> route node ids
    :return: bool
    """
    if start_node_id not in reverse_edge_mapping:
        return False

    all_routes_node_ids = set()
    # predecessor node id -> branch start nodes it feeds directly
    parallel_start_node_ids: dict[str, list[str]] = {}
    for branch_node_id, node_ids in routes_node_ids.items():
        for node_id in node_ids:
            all_routes_node_ids.add(node_id)

        if branch_node_id in reverse_edge_mapping:
            for graph_edge in reverse_edge_mapping[branch_node_id]:
                if graph_edge.source_node_id not in parallel_start_node_ids:
                    parallel_start_node_ids[graph_edge.source_node_id] = []

                parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id)

    parallel_start_node_id = None
    for p_start_node_id, branch_node_ids in parallel_start_node_ids.items():
        if set(branch_node_ids) == set(routes_node_ids.keys()):
            # this predecessor feeds every branch -> it is the parallel start
            parallel_start_node_id = p_start_node_id
            return True

    if not parallel_start_node_id:
        raise Exception("Parallel start node id not found")

    for graph_edge in reverse_edge_mapping[start_node_id]:
        if graph_edge.source_node_id not in all_routes_node_ids or graph_edge.source_node_id != parallel_start_node_id:
            return False

    return True
@classmethod
def _is_node2_after_node1(
        cls,
        node1_id: str,
        node2_id: str,
        edge_mapping: dict[str, list[GraphEdge]],
        visited: Optional[set[str]] = None
) -> bool:
    """
    is node2 after node1 (i.e. reachable from node1 via outgoing edges)

    :param node1_id: starting node id
    :param node2_id: node id searched for downstream of node1
    :param edge_mapping: edge mapping (source node id: edges)
    :param visited: internal visited set; callers omit it
    :return: bool
    """
    if node1_id not in edge_mapping:
        return False

    # bug fix: without a visited set, a cycle in edge_mapping caused
    # unbounded recursion (RecursionError); re-visiting a node can never
    # discover a new path, so it is safe to prune
    if visited is None:
        visited = set()
    if node1_id in visited:
        return False
    visited.add(node1_id)

    for graph_edge in edge_mapping[node1_id]:
        if graph_edge.target_node_id == node2_id:
            return True

        if cls._is_node2_after_node1(
                node1_id=graph_edge.target_node_id,
                node2_id=node2_id,
                edge_mapping=edge_mapping,
                visited=visited
        ):
            return True

    return False

View File

@@ -0,0 +1,21 @@
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import UserFrom
from models.workflow import WorkflowType
class GraphInitParams(BaseModel):
    """Immutable parameters a graph run is initialized with (tenant, app, workflow, caller)."""
    # init params
    tenant_id: str = Field(..., description="tenant / workspace id")
    app_id: str = Field(..., description="app id")
    workflow_type: WorkflowType = Field(..., description="workflow type")
    workflow_id: str = Field(..., description="workflow id")
    graph_config: Mapping[str, Any] = Field(..., description="graph config")
    user_id: str = Field(..., description="user id")
    user_from: UserFrom = Field(..., description="user from, account or end-user")
    invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
    call_depth: int = Field(..., description="call depth")

View File

@@ -0,0 +1,27 @@
from typing import Any
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
class GraphRuntimeState(BaseModel):
    """Mutable state accumulated while a graph runs (variables, usage, outputs, node states).

    NOTE(review): `llm_usage` and `outputs` use plain defaults rather than
    default_factory — presumably relying on pydantic copying defaults per
    instance; confirm instances do not share state.
    """
    variable_pool: VariablePool = Field(..., description="variable pool")
    """variable pool"""
    start_at: float = Field(..., description="start time")
    """start time"""
    total_tokens: int = 0
    """total tokens"""
    llm_usage: LLMUsage = LLMUsage.empty_usage()
    """llm usage info"""
    outputs: dict[str, Any] = {}
    """outputs"""
    node_run_steps: int = 0
    """node run steps"""
    node_run_state: RuntimeRouteState = RuntimeRouteState()
    """node run state"""

View File

@@ -0,0 +1,13 @@
from typing import Optional
from pydantic import BaseModel
from core.workflow.graph_engine.entities.graph import GraphParallel
class NextGraphNode(BaseModel):
    """The next node to run, paired with the parallel section it belongs to (if any)."""
    node_id: str
    """next node id"""
    parallel: Optional[GraphParallel] = None
    """parallel"""

View File

@@ -0,0 +1,21 @@
import hashlib
from typing import Literal, Optional
from pydantic import BaseModel
from core.workflow.utils.condition.entities import Condition
class RunCondition(BaseModel):
    """Condition attached to a graph edge that gates whether the edge is traversed."""

    type: Literal["branch_identify", "condition"]
    """condition type"""

    branch_identify: Optional[str] = None
    """branch identify like: sourceHandle, required when type is branch_identify"""

    conditions: Optional[list[Condition]] = None
    """conditions to run the node, required when type is condition"""

    @property
    def hash(self) -> str:
        """Stable digest of this condition, used to group edges sharing the same condition."""
        serialized = self.model_dump_json().encode()
        return hashlib.sha256(serialized).hexdigest()

View File

@@ -0,0 +1,111 @@
import uuid
from datetime import datetime, timezone
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunResult
from models.workflow import WorkflowNodeExecutionStatus
class RouteNodeState(BaseModel):
    """Execution state of one node occurrence on a route through the graph."""

    class Status(Enum):
        RUNNING = "running"
        SUCCESS = "success"
        FAILED = "failed"
        PAUSED = "paused"

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    """node state id"""

    node_id: str
    """node id"""

    node_run_result: Optional[NodeRunResult] = None
    """node run result"""

    status: Status = Status.RUNNING
    """node status"""

    start_at: datetime
    """start time"""

    paused_at: Optional[datetime] = None
    """paused time"""

    finished_at: Optional[datetime] = None
    """finished time"""

    failed_reason: Optional[str] = None
    """failed reason"""

    paused_by: Optional[str] = None
    """paused by"""

    index: int = 1

    def set_finished(self, run_result: NodeRunResult) -> None:
        """Transition this state to a terminal status based on the node's run result.

        :param run_result: run result
        :raises Exception: if already finished, or the result status is not terminal
        """
        terminal_statuses = (RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED)
        if self.status in terminal_statuses:
            raise Exception(f"Route state {self.id} already finished")

        result_status = run_result.status
        if result_status == WorkflowNodeExecutionStatus.SUCCEEDED:
            self.status = RouteNodeState.Status.SUCCESS
        elif result_status == WorkflowNodeExecutionStatus.FAILED:
            self.status = RouteNodeState.Status.FAILED
            self.failed_reason = run_result.error
        else:
            raise Exception(f"Invalid route status {result_status}")

        self.node_run_result = run_result
        # naive UTC datetime, matching the rest of the codebase's convention
        self.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
class RuntimeRouteState(BaseModel):
    """Registry of node states plus the directed routes between them for one graph run."""

    routes: dict[str, list[str]] = Field(
        default_factory=dict,
        description="graph state routes (source_node_state_id: target_node_state_id)"
    )

    node_state_mapping: dict[str, RouteNodeState] = Field(
        default_factory=dict,
        description="node state mapping (route_node_state_id: route_node_state)"
    )

    def create_node_state(self, node_id: str) -> RouteNodeState:
        """Create, register, and return a fresh RUNNING state for the given node.

        :param node_id: node id
        :return: the newly created node state
        """
        started = datetime.now(timezone.utc).replace(tzinfo=None)
        new_state = RouteNodeState(node_id=node_id, start_at=started)
        self.node_state_mapping[new_state.id] = new_state
        return new_state

    def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None:
        """Record an edge from one node state to another.

        :param source_node_state_id: source node state id
        :param target_node_state_id: target node state id
        """
        self.routes.setdefault(source_node_state_id, []).append(target_node_state_id)

    def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) \
            -> list[RouteNodeState]:
        """Resolve the target node states reachable from a source node state.

        :param source_node_state_id: source node state id
        :return: target node states, in insertion order
        """
        target_ids = self.routes.get(source_node_state_id, [])
        return [self.node_state_mapping[state_id] for state_id in target_ids]

View File

@@ -0,0 +1,716 @@
import logging
import queue
import time
import uuid
from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Optional
from flask import Flask, current_app
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import (
NodeRunMetadataKey,
NodeType,
UserFrom,
)
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
BaseIterationEvent,
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph, GraphEdge
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import node_classes
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
logger = logging.getLogger(__name__)
class GraphEngineThreadPool(ThreadPoolExecutor):
    """Thread pool that caps the total number of tasks submitted over its lifetime.

    Bounds how much parallel work a single workflow run (and nested runs that
    share the pool via ``thread_pool_id``) may spawn.
    """

    def __init__(self, max_workers=None, thread_name_prefix='',
                 initializer=None, initargs=(), max_submit_count=100) -> None:
        """
        :param max_submit_count: lifetime cap on the number of submitted tasks
        """
        super().__init__(max_workers, thread_name_prefix, initializer, initargs)
        self.max_submit_count = max_submit_count
        self.submit_count = 0

    def submit(self, fn, *args, **kwargs):
        # NOTE(review): the counter is incremented without a lock; under heavy
        # concurrent submission the cap may be enforced slightly late — confirm
        # whether strict accounting is required here.
        self.submit_count += 1
        self.check_is_full()

        return super().submit(fn, *args, **kwargs)

    def check_is_full(self) -> None:
        """Raise ValueError when the lifetime submit cap has been exceeded."""
        # (removed a leftover debug print() that wrote to stdout on every submit)
        if self.submit_count > self.max_submit_count:
            raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.")
class GraphEngine:
    """Executes a workflow graph, yielding GraphEngineEvents as nodes run.

    Parallel branches are executed on a shared per-run thread pool that is
    registered in ``workflow_thread_pool_mapping`` so nested engines
    (e.g. sub-workflows) can reuse it via ``thread_pool_id``.
    """

    # process-wide registry of per-run thread pools, keyed by thread_pool_id
    workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {}

    def __init__(
            self,
            tenant_id: str,
            app_id: str,
            workflow_type: WorkflowType,
            workflow_id: str,
            user_id: str,
            user_from: UserFrom,
            invoke_from: InvokeFrom,
            call_depth: int,
            graph: Graph,
            graph_config: Mapping[str, Any],
            variable_pool: VariablePool,
            max_execution_steps: int,
            max_execution_time: int,
            thread_pool_id: Optional[str] = None
    ) -> None:
        thread_pool_max_submit_count = 100
        thread_pool_max_workers = 10

        # init thread pool: reuse the parent run's pool when thread_pool_id is
        # given (nested graph run), otherwise create and register a new one.
        # This engine owns (and later removes) the pool only when it created it.
        if thread_pool_id:
            if thread_pool_id not in GraphEngine.workflow_thread_pool_mapping:
                # fixed: previous message was copy-pasted from the submit-cap error
                raise ValueError(f"Thread pool {thread_pool_id} not found.")

            self.thread_pool_id = thread_pool_id
            self.thread_pool = GraphEngine.workflow_thread_pool_mapping[thread_pool_id]
            self.is_main_thread_pool = False
        else:
            self.thread_pool = GraphEngineThreadPool(max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count)
            self.thread_pool_id = str(uuid.uuid4())
            self.is_main_thread_pool = True
            GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool

        self.graph = graph
        self.init_params = GraphInitParams(
            tenant_id=tenant_id,
            app_id=app_id,
            workflow_type=workflow_type,
            workflow_id=workflow_id,
            graph_config=graph_config,
            user_id=user_id,
            user_from=user_from,
            invoke_from=invoke_from,
            call_depth=call_depth
        )

        self.graph_runtime_state = GraphRuntimeState(
            variable_pool=variable_pool,
            start_at=time.perf_counter()
        )

        self.max_execution_steps = max_execution_steps
        self.max_execution_time = max_execution_time

    def run(self) -> Generator[GraphEngineEvent, None, None]:
        """Run the whole graph, yielding engine events; terminates with a
        GraphRunSucceededEvent or GraphRunFailedEvent."""
        # trigger graph run start event
        yield GraphRunStartedEvent()

        try:
            # chat workflows stream via answer nodes; others via end nodes
            stream_processor_cls: type[AnswerStreamProcessor | EndStreamProcessor]
            if self.init_params.workflow_type == WorkflowType.CHAT:
                stream_processor_cls = AnswerStreamProcessor
            else:
                stream_processor_cls = EndStreamProcessor

            stream_processor = stream_processor_cls(
                graph=self.graph,
                variable_pool=self.graph_runtime_state.variable_pool
            )

            # run graph
            generator = stream_processor.process(
                self._run(start_node_id=self.graph.root_node_id)
            )

            for item in generator:
                try:
                    yield item
                    if isinstance(item, NodeRunFailedEvent):
                        yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or 'Unknown error.')
                        return
                    elif isinstance(item, NodeRunSucceededEvent):
                        if item.node_type == NodeType.END:
                            # end node outputs become the graph outputs
                            self.graph_runtime_state.outputs = (item.route_node_state.node_run_result.outputs
                                                                if item.route_node_state.node_run_result
                                                                and item.route_node_state.node_run_result.outputs
                                                                else {})
                        elif item.node_type == NodeType.ANSWER:
                            # concatenate multiple answer nodes, newline-separated
                            if "answer" not in self.graph_runtime_state.outputs:
                                self.graph_runtime_state.outputs["answer"] = ""

                            self.graph_runtime_state.outputs["answer"] += "\n" + (item.route_node_state.node_run_result.outputs.get("answer", "")
                                                                                  if item.route_node_state.node_run_result
                                                                                  and item.route_node_state.node_run_result.outputs
                                                                                  else "")

                            self.graph_runtime_state.outputs["answer"] = self.graph_runtime_state.outputs["answer"].strip()
                except Exception as e:
                    logger.exception(f"Graph run failed: {str(e)}")
                    yield GraphRunFailedEvent(error=str(e))
                    return

            # trigger graph run success event
            yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs)
        except GraphRunFailedError as e:
            yield GraphRunFailedEvent(error=e.error)
            return
        except Exception as e:
            logger.exception("Unknown Error when graph running")
            yield GraphRunFailedEvent(error=str(e))
            raise e
        finally:
            # only the engine that created the pool unregisters it
            if self.is_main_thread_pool and self.thread_pool_id in GraphEngine.workflow_thread_pool_mapping:
                del GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id]

    def _run(
            self,
            start_node_id: str,
            in_parallel_id: Optional[str] = None,
            parent_parallel_id: Optional[str] = None,
            parent_parallel_start_node_id: Optional[str] = None
    ) -> Generator[GraphEngineEvent, None, None]:
        """Walk the graph sequentially from start_node_id, fanning out to
        _run_parallel_branches when multiple unconditioned edges leave a node.

        :param start_node_id: node to start from
        :param in_parallel_id: parallel group this walk belongs to, if any
        :param parent_parallel_id: enclosing parallel group, for nested parallels
        :param parent_parallel_start_node_id: start node of the enclosing parallel
        """
        parallel_start_node_id = None
        if in_parallel_id:
            parallel_start_node_id = start_node_id

        next_node_id = start_node_id
        previous_route_node_state: Optional[RouteNodeState] = None
        while True:
            # max steps reached
            if self.graph_runtime_state.node_run_steps > self.max_execution_steps:
                raise GraphRunFailedError('Max steps {} reached.'.format(self.max_execution_steps))

            # or max execution time reached
            if self._is_timed_out(
                    start_at=self.graph_runtime_state.start_at,
                    max_execution_time=self.max_execution_time
            ):
                raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time))

            # init route node state
            route_node_state = self.graph_runtime_state.node_run_state.create_node_state(
                node_id=next_node_id
            )

            # get node config
            node_id = route_node_state.node_id
            node_config = self.graph.node_id_config_mapping.get(node_id)
            if not node_config:
                raise GraphRunFailedError(f'Node {node_id} config not found.')

            # convert to specific node
            node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
            node_cls = node_classes.get(node_type)
            if not node_cls:
                raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')

            previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None

            # init workflow run state
            node_instance = node_cls(  # type: ignore
                id=route_node_state.id,
                config=node_config,
                graph_init_params=self.init_params,
                graph=self.graph,
                graph_runtime_state=self.graph_runtime_state,
                previous_node_id=previous_node_id,
                thread_pool_id=self.thread_pool_id
            )

            try:
                # run node
                generator = self._run_node(
                    node_instance=node_instance,
                    route_node_state=route_node_state,
                    parallel_id=in_parallel_id,
                    parallel_start_node_id=parallel_start_node_id,
                    parent_parallel_id=parent_parallel_id,
                    parent_parallel_start_node_id=parent_parallel_start_node_id
                )

                for item in generator:
                    if isinstance(item, NodeRunStartedEvent):
                        self.graph_runtime_state.node_run_steps += 1
                        item.route_node_state.index = self.graph_runtime_state.node_run_steps

                    yield item

                self.graph_runtime_state.node_run_state.node_state_mapping[route_node_state.id] = route_node_state

                # append route
                if previous_route_node_state:
                    self.graph_runtime_state.node_run_state.add_route(
                        source_node_state_id=previous_route_node_state.id,
                        target_node_state_id=route_node_state.id
                    )
            except Exception as e:
                route_node_state.status = RouteNodeState.Status.FAILED
                route_node_state.failed_reason = str(e)
                yield NodeRunFailedEvent(
                    error=str(e),
                    id=node_instance.id,
                    node_id=next_node_id,
                    node_type=node_type,
                    node_data=node_instance.node_data,
                    route_node_state=route_node_state,
                    parallel_id=in_parallel_id,
                    parallel_start_node_id=parallel_start_node_id,
                    parent_parallel_id=parent_parallel_id,
                    parent_parallel_start_node_id=parent_parallel_start_node_id
                )
                raise e

            # an END node terminates this walk regardless of outgoing edges
            if (self.graph.node_id_config_mapping[next_node_id]
                    .get("data", {}).get("type", "").lower() == NodeType.END.value):
                break

            previous_route_node_state = route_node_state

            # get next node ids
            edge_mappings = self.graph.edge_mapping.get(next_node_id)
            if not edge_mappings:
                break

            if len(edge_mappings) == 1:
                edge = edge_mappings[0]

                if edge.run_condition:
                    result = ConditionManager.get_condition_handler(
                        init_params=self.init_params,
                        graph=self.graph,
                        run_condition=edge.run_condition,
                    ).check(
                        graph_runtime_state=self.graph_runtime_state,
                        previous_route_node_state=previous_route_node_state
                    )

                    if not result:
                        break

                next_node_id = edge.target_node_id
            else:
                final_node_id = None

                if any(edge.run_condition for edge in edge_mappings):
                    # if nodes has run conditions, get node id which branch to take
                    # based on the run condition results; edges sharing a condition
                    # (same hash) form one parallel group
                    condition_edge_mappings = {}
                    for edge in edge_mappings:
                        if edge.run_condition:
                            run_condition_hash = edge.run_condition.hash
                            if run_condition_hash not in condition_edge_mappings:
                                condition_edge_mappings[run_condition_hash] = []

                            condition_edge_mappings[run_condition_hash].append(edge)

                    for _, sub_edge_mappings in condition_edge_mappings.items():
                        if len(sub_edge_mappings) == 0:
                            continue

                        edge = sub_edge_mappings[0]

                        result = ConditionManager.get_condition_handler(
                            init_params=self.init_params,
                            graph=self.graph,
                            run_condition=edge.run_condition,
                        ).check(
                            graph_runtime_state=self.graph_runtime_state,
                            previous_route_node_state=previous_route_node_state,
                        )

                        if not result:
                            continue

                        if len(sub_edge_mappings) == 1:
                            final_node_id = edge.target_node_id
                        else:
                            parallel_generator = self._run_parallel_branches(
                                edge_mappings=sub_edge_mappings,
                                in_parallel_id=in_parallel_id,
                                parallel_start_node_id=parallel_start_node_id
                            )

                            # the generator yields events, then (optionally) the
                            # id of the node where the parallel converges
                            for item in parallel_generator:
                                if isinstance(item, str):
                                    final_node_id = item
                                else:
                                    yield item

                        break

                    if not final_node_id:
                        break

                    next_node_id = final_node_id
                else:
                    parallel_generator = self._run_parallel_branches(
                        edge_mappings=edge_mappings,
                        in_parallel_id=in_parallel_id,
                        parallel_start_node_id=parallel_start_node_id
                    )

                    for item in parallel_generator:
                        if isinstance(item, str):
                            final_node_id = item
                        else:
                            yield item

                    if not final_node_id:
                        break

                    next_node_id = final_node_id

            # leaving the current parallel group ends this branch's walk
            if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
                break

    def _run_parallel_branches(
            self,
            edge_mappings: list[GraphEdge],
            in_parallel_id: Optional[str] = None,
            parallel_start_node_id: Optional[str] = None,
    ) -> Generator[GraphEngineEvent | str, None, None]:
        """Run the target nodes of the given edges in parallel on the thread pool.

        Yields GraphEngineEvents as branches progress; finally yields the
        converge node id (a str) if the parallel has one.
        """
        # if nodes has no run conditions, parallel run all nodes
        parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
        if not parallel_id:
            node_id = edge_mappings[0].target_node_id
            node_config = self.graph.node_id_config_mapping.get(node_id)
            if not node_config:
                raise GraphRunFailedError(f'Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches.')

            node_title = node_config.get('data', {}).get('title')
            raise GraphRunFailedError(f'Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches.')

        parallel = self.graph.parallel_mapping.get(parallel_id)
        if not parallel:
            raise GraphRunFailedError(f'Parallel {parallel_id} not found.')

        # run parallel nodes, run in new thread and use queue to get results
        q: queue.Queue = queue.Queue()

        # Create a list to store the threads
        futures = []

        # new thread
        for edge in edge_mappings:
            if (
                    edge.target_node_id not in self.graph.node_parallel_mapping
                    or self.graph.node_parallel_mapping.get(edge.target_node_id, '') != parallel_id
            ):
                continue

            futures.append(
                self.thread_pool.submit(self._run_parallel_node, **{
                    'flask_app': current_app._get_current_object(),  # type: ignore[attr-defined]
                    'q': q,
                    'parallel_id': parallel_id,
                    'parallel_start_node_id': edge.target_node_id,
                    'parent_parallel_id': in_parallel_id,
                    'parent_parallel_start_node_id': parallel_start_node_id,
                })
            )

        # drain the queue until every branch succeeded (sentinel None) or one failed
        succeeded_count = 0
        while True:
            try:
                event = q.get(timeout=1)
                if event is None:
                    break

                yield event
                if event.parallel_id == parallel_id:
                    if isinstance(event, ParallelBranchRunSucceededEvent):
                        succeeded_count += 1
                        if succeeded_count == len(futures):
                            q.put(None)

                        continue
                    elif isinstance(event, ParallelBranchRunFailedEvent):
                        raise GraphRunFailedError(event.error)
            except queue.Empty:
                continue

        # wait all threads
        wait(futures)

        # get final node id
        final_node_id = parallel.end_to_node_id
        if final_node_id:
            yield final_node_id

    def _run_parallel_node(
            self,
            flask_app: Flask,
            q: queue.Queue,
            parallel_id: str,
            parallel_start_node_id: str,
            parent_parallel_id: Optional[str] = None,
            parent_parallel_start_node_id: Optional[str] = None,
    ) -> None:
        """
        Run parallel nodes

        Executes one parallel branch inside a worker thread, pushing all events
        (and a terminal ParallelBranchRun*Event) onto the shared queue.
        """
        with flask_app.app_context():
            try:
                q.put(ParallelBranchRunStartedEvent(
                    parallel_id=parallel_id,
                    parallel_start_node_id=parallel_start_node_id,
                    parent_parallel_id=parent_parallel_id,
                    parent_parallel_start_node_id=parent_parallel_start_node_id
                ))

                # run node
                generator = self._run(
                    start_node_id=parallel_start_node_id,
                    in_parallel_id=parallel_id,
                    parent_parallel_id=parent_parallel_id,
                    parent_parallel_start_node_id=parent_parallel_start_node_id
                )

                for item in generator:
                    q.put(item)

                # trigger graph run success event
                q.put(ParallelBranchRunSucceededEvent(
                    parallel_id=parallel_id,
                    parallel_start_node_id=parallel_start_node_id,
                    parent_parallel_id=parent_parallel_id,
                    parent_parallel_start_node_id=parent_parallel_start_node_id
                ))
            except GraphRunFailedError as e:
                q.put(ParallelBranchRunFailedEvent(
                    parallel_id=parallel_id,
                    parallel_start_node_id=parallel_start_node_id,
                    parent_parallel_id=parent_parallel_id,
                    parent_parallel_start_node_id=parent_parallel_start_node_id,
                    error=e.error
                ))
            except Exception as e:
                logger.exception("Unknown Error when generating in parallel")
                q.put(ParallelBranchRunFailedEvent(
                    parallel_id=parallel_id,
                    parallel_start_node_id=parallel_start_node_id,
                    parent_parallel_id=parent_parallel_id,
                    parent_parallel_start_node_id=parent_parallel_start_node_id,
                    error=str(e)
                ))
            finally:
                # worker threads must release their scoped DB session
                db.session.remove()

    def _run_node(
            self,
            node_instance: BaseNode,
            route_node_state: RouteNodeState,
            parallel_id: Optional[str] = None,
            parallel_start_node_id: Optional[str] = None,
            parent_parallel_id: Optional[str] = None,
            parent_parallel_start_node_id: Optional[str] = None,
    ) -> Generator[GraphEngineEvent, None, None]:
        """
        Run node

        Runs a single node instance, translating its internal Run*Events into
        engine-level NodeRun*Events and updating route state / runtime totals.
        """
        # trigger node run start event
        yield NodeRunStartedEvent(
            id=node_instance.id,
            node_id=node_instance.node_id,
            node_type=node_instance.node_type,
            node_data=node_instance.node_data,
            route_node_state=route_node_state,
            predecessor_node_id=node_instance.previous_node_id,
            parallel_id=parallel_id,
            parallel_start_node_id=parallel_start_node_id,
            parent_parallel_id=parent_parallel_id,
            parent_parallel_start_node_id=parent_parallel_start_node_id
        )

        db.session.close()

        try:
            # run node
            generator = node_instance.run()
            for item in generator:
                if isinstance(item, GraphEngineEvent):
                    if isinstance(item, BaseIterationEvent):
                        # add parallel info to iteration event
                        item.parallel_id = parallel_id
                        item.parallel_start_node_id = parallel_start_node_id
                        item.parent_parallel_id = parent_parallel_id
                        item.parent_parallel_start_node_id = parent_parallel_start_node_id

                    yield item
                else:
                    if isinstance(item, RunCompletedEvent):
                        run_result = item.run_result
                        route_node_state.set_finished(run_result=run_result)

                        if run_result.status == WorkflowNodeExecutionStatus.FAILED:
                            yield NodeRunFailedEvent(
                                error=route_node_state.failed_reason or 'Unknown error.',
                                id=node_instance.id,
                                node_id=node_instance.node_id,
                                node_type=node_instance.node_type,
                                node_data=node_instance.node_data,
                                route_node_state=route_node_state,
                                parallel_id=parallel_id,
                                parallel_start_node_id=parallel_start_node_id,
                                parent_parallel_id=parent_parallel_id,
                                parent_parallel_start_node_id=parent_parallel_start_node_id
                            )
                        elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                            if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
                                # plus state total_tokens
                                self.graph_runtime_state.total_tokens += int(
                                    run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)  # type: ignore[arg-type]
                                )

                            if run_result.llm_usage:
                                # use the latest usage
                                self.graph_runtime_state.llm_usage += run_result.llm_usage

                            # append node output variables to variable pool
                            if run_result.outputs:
                                for variable_key, variable_value in run_result.outputs.items():
                                    # append variables to variable pool recursively
                                    self._append_variables_recursively(
                                        node_id=node_instance.node_id,
                                        variable_key_list=[variable_key],
                                        variable_value=variable_value
                                    )

                            # add parallel info to run result metadata
                            if parallel_id and parallel_start_node_id:
                                if not run_result.metadata:
                                    run_result.metadata = {}

                                run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
                                run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
                                if parent_parallel_id and parent_parallel_start_node_id:
                                    run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
                                    run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = parent_parallel_start_node_id

                            yield NodeRunSucceededEvent(
                                id=node_instance.id,
                                node_id=node_instance.node_id,
                                node_type=node_instance.node_type,
                                node_data=node_instance.node_data,
                                route_node_state=route_node_state,
                                parallel_id=parallel_id,
                                parallel_start_node_id=parallel_start_node_id,
                                parent_parallel_id=parent_parallel_id,
                                parent_parallel_start_node_id=parent_parallel_start_node_id
                            )

                        break
                    elif isinstance(item, RunStreamChunkEvent):
                        yield NodeRunStreamChunkEvent(
                            id=node_instance.id,
                            node_id=node_instance.node_id,
                            node_type=node_instance.node_type,
                            node_data=node_instance.node_data,
                            chunk_content=item.chunk_content,
                            from_variable_selector=item.from_variable_selector,
                            route_node_state=route_node_state,
                            parallel_id=parallel_id,
                            parallel_start_node_id=parallel_start_node_id,
                            parent_parallel_id=parent_parallel_id,
                            parent_parallel_start_node_id=parent_parallel_start_node_id
                        )
                    elif isinstance(item, RunRetrieverResourceEvent):
                        yield NodeRunRetrieverResourceEvent(
                            id=node_instance.id,
                            node_id=node_instance.node_id,
                            node_type=node_instance.node_type,
                            node_data=node_instance.node_data,
                            retriever_resources=item.retriever_resources,
                            context=item.context,
                            route_node_state=route_node_state,
                            parallel_id=parallel_id,
                            parallel_start_node_id=parallel_start_node_id,
                            parent_parallel_id=parent_parallel_id,
                            parent_parallel_start_node_id=parent_parallel_start_node_id
                        )
        except GenerateTaskStoppedException:
            # trigger node run failed event
            route_node_state.status = RouteNodeState.Status.FAILED
            route_node_state.failed_reason = "Workflow stopped."
            yield NodeRunFailedEvent(
                error="Workflow stopped.",
                id=node_instance.id,
                node_id=node_instance.node_id,
                node_type=node_instance.node_type,
                node_data=node_instance.node_data,
                route_node_state=route_node_state,
                parallel_id=parallel_id,
                parallel_start_node_id=parallel_start_node_id,
                parent_parallel_id=parent_parallel_id,
                parent_parallel_start_node_id=parent_parallel_start_node_id
            )
            return
        except Exception as e:
            logger.exception(f"Node {node_instance.node_data.title} run failed: {str(e)}")
            raise e
        finally:
            db.session.close()

    def _append_variables_recursively(self,
                                      node_id: str,
                                      variable_key_list: list[str],
                                      variable_value: VariableValue):
        """
        Append variables recursively
        :param node_id: node id
        :param variable_key_list: variable key list
        :param variable_value: variable value
        :return:
        """
        self.graph_runtime_state.variable_pool.add(
            [node_id] + variable_key_list,
            variable_value
        )

        # if variable_value is a dict, then recursively append variables
        if isinstance(variable_value, dict):
            for key, value in variable_value.items():
                # construct new key list
                new_key_list = variable_key_list + [key]
                self._append_variables_recursively(
                    node_id=node_id,
                    variable_key_list=new_key_list,
                    variable_value=value
                )

    def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
        """
        Check timeout
        :param start_at: start time (a time.perf_counter() reading)
        :param max_execution_time: max execution time in seconds
        :return:
        """
        return time.perf_counter() - start_at > max_execution_time
class GraphRunFailedError(Exception):
    """Raised internally to abort a graph run; carries the user-facing error message."""

    def __init__(self, error: str):
        # Pass the message to Exception so str(e) and tracebacks show it
        # (previously str(e) was empty because Exception.__init__ got no args).
        super().__init__(error)
        self.error = error