mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-09 10:56:52 +08:00
feat: Parallel Execution of Nodes in Workflows (#8192)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: Yi <yxiaoisme@gmail.com> Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
0
api/core/workflow/graph_engine/__init__.py
Normal file
0
api/core/workflow/graph_engine/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
|
||||
|
||||
class RunConditionHandler(ABC):
    """
    Abstract base for objects that decide whether a run condition is met.

    Each handler keeps the graph init params, the graph and the run condition
    it was created with; subclasses supply the actual decision in ``check``.
    """

    def __init__(
        self,
        init_params: GraphInitParams,
        graph: Graph,
        condition: RunCondition,
    ):
        # Stored untouched; this base class never reads them itself.
        self.condition = condition
        self.graph = graph
        self.init_params = init_params

    @abstractmethod
    def check(
        self,
        graph_runtime_state: GraphRuntimeState,
        previous_route_node_state: RouteNodeState,
    ) -> bool:
        """
        Tell whether the condition can be executed.

        :param graph_runtime_state: graph runtime state
        :param previous_route_node_state: previous route node state
        :return: bool
        """
        raise NotImplementedError
|
||||
@@ -0,0 +1,28 @@
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
|
||||
|
||||
class BranchIdentifyRunConditionHandler(RunConditionHandler):
    """
    Run-condition handler that compares the condition's branch identifier
    with the edge source handle recorded on the previous node's run result.
    """

    def check(self,
              graph_runtime_state: GraphRuntimeState,
              previous_route_node_state: RouteNodeState) -> bool:
        """
        Check if the condition can be executed

        :param graph_runtime_state: graph runtime state (not read by this handler)
        :param previous_route_node_state: previous route node state
        :return: True only when the previous node produced a run result whose
            edge source handle equals the condition's branch identifier
        :raises ValueError: if the condition carries no branch identifier
        """
        expected_branch = self.condition.branch_identify
        if not expected_branch:
            # ValueError is still an Exception, so existing broad handlers keep working.
            raise ValueError("Branch identify is required")

        run_result = previous_route_node_state.node_run_result
        # No result yet, or the node did not select any handle: nothing can match.
        if not run_result or not run_result.edge_source_handle:
            return False

        return expected_branch == run_result.edge_source_handle
|
||||
@@ -0,0 +1,32 @@
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
|
||||
|
||||
class ConditionRunConditionHandlerHandler(RunConditionHandler):
    """Run-condition handler that evaluates a list of conditions against the variable pool."""

    def check(
        self,
        graph_runtime_state: GraphRuntimeState,
        previous_route_node_state: RouteNodeState,
    ) -> bool:
        """
        Tell whether the condition can be executed.

        :param graph_runtime_state: graph runtime state
        :param previous_route_node_state: previous route node state
        :return: bool
        """
        # Nothing to evaluate means the edge is always allowed.
        if not self.condition.conditions:
            return True

        # Only the per-condition results are needed; the first item of the pair is ignored.
        processor = ConditionProcessor()
        _, group_result = processor.process_conditions(
            variable_pool=graph_runtime_state.variable_pool,
            conditions=self.condition.conditions,
        )

        # Every individual condition has to hold.
        return all(group_result)
|
||||
@@ -0,0 +1,35 @@
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
|
||||
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
|
||||
|
||||
class ConditionManager:
    """Factory that picks the handler class matching a run condition's type."""

    @staticmethod
    def get_condition_handler(
        init_params: GraphInitParams,
        graph: Graph,
        run_condition: RunCondition,
    ) -> RunConditionHandler:
        """
        Build the handler for ``run_condition``.

        :param init_params: init params
        :param graph: graph
        :param run_condition: run condition
        :return: condition handler
        """
        # "branch_identify" gets its dedicated handler; every other type is
        # evaluated as a list of conditions.
        handler_cls = (
            BranchIdentifyRunConditionHandler
            if run_condition.type == "branch_identify"
            else ConditionRunConditionHandlerHandler
        )
        return handler_cls(
            init_params=init_params,
            graph=graph,
            condition=run_condition,
        )
|
||||
0
api/core/workflow/graph_engine/entities/__init__.py
Normal file
0
api/core/workflow/graph_engine/entities/__init__.py
Normal file
163
api/core/workflow/graph_engine/entities/event.py
Normal file
163
api/core/workflow/graph_engine/entities/event.py
Normal file
@@ -0,0 +1,163 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
|
||||
|
||||
class GraphEngineEvent(BaseModel):
    """Root model that every event class in this module derives from."""
|
||||
|
||||
|
||||
###########################################
|
||||
# Graph Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseGraphEvent(GraphEngineEvent):
    """Common parent of the graph-run events; declares no fields."""
|
||||
|
||||
|
||||
class GraphRunStartedEvent(BaseGraphEvent):
    """Graph-run "started" event; carries no fields of its own."""
|
||||
|
||||
|
||||
class GraphRunSucceededEvent(BaseGraphEvent):
    """Graph-run "succeeded" event with an optional outputs mapping."""

    # outputs
    outputs: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class GraphRunFailedEvent(BaseGraphEvent):
    """Graph-run "failed" event; the failure reason is mandatory."""

    error: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
###########################################
|
||||
# Node Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseNodeEvent(GraphEngineEvent):
    """Fields shared by all node-level events."""

    id: str = Field(..., description="node execution id")
    node_id: str = Field(..., description="node id")
    node_type: NodeType = Field(..., description="node type")
    node_data: BaseNodeData = Field(..., description="node data")
    route_node_state: RouteNodeState = Field(..., description="route node state")

    # parallel id if node is in parallel
    parallel_id: Optional[str] = None
    # parallel start node id if node is in parallel
    parallel_start_node_id: Optional[str] = None
    # parent parallel id if node is in parallel
    parent_parallel_id: Optional[str] = None
    # parent parallel start node id if node is in parallel
    parent_parallel_start_node_id: Optional[str] = None
    # iteration id if node is in iteration
    in_iteration_id: Optional[str] = None
|
||||
|
||||
|
||||
class NodeRunStartedEvent(BaseNodeEvent):
    """Node "started" event; may name the predecessor node."""

    # predecessor node id
    predecessor_node_id: Optional[str] = None
|
||||
|
||||
|
||||
class NodeRunStreamChunkEvent(BaseNodeEvent):
    """Node event carrying one chunk of streamed content."""

    chunk_content: str = Field(..., description="chunk content")
    # from variable selector
    from_variable_selector: Optional[list[str]] = None
|
||||
|
||||
|
||||
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
    """Node event carrying retriever resources together with their context string."""

    retriever_resources: list[dict] = Field(..., description="retriever resources")
    context: str = Field(..., description="context")
|
||||
|
||||
|
||||
class NodeRunSucceededEvent(BaseNodeEvent):
    """Node "succeeded" event; adds nothing beyond the base node fields."""
|
||||
|
||||
|
||||
class NodeRunFailedEvent(BaseNodeEvent):
    """Node "failed" event; the error text is mandatory."""

    error: str = Field(..., description="error")
|
||||
|
||||
|
||||
###########################################
|
||||
# Parallel Branch Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseParallelBranchEvent(GraphEngineEvent):
    """Fields shared by all parallel-branch events."""

    # parallel id
    parallel_id: str = Field(..., description="parallel id")
    # parallel start node id
    parallel_start_node_id: str = Field(..., description="parallel start node id")
    # parent parallel id if node is in parallel
    parent_parallel_id: Optional[str] = None
    # parent parallel start node id if node is in parallel
    parent_parallel_start_node_id: Optional[str] = None
    # iteration id if node is in iteration
    in_iteration_id: Optional[str] = None
|
||||
|
||||
|
||||
class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
    """Parallel-branch "started" event; adds no fields."""
|
||||
|
||||
|
||||
class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent):
    """Parallel-branch "succeeded" event; adds no fields."""
|
||||
|
||||
|
||||
class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
    """Parallel-branch "failed" event; the failure reason is mandatory."""

    error: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
###########################################
|
||||
# Iteration Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseIterationEvent(GraphEngineEvent):
    """Fields shared by all iteration events."""

    iteration_id: str = Field(..., description="iteration node execution id")
    iteration_node_id: str = Field(..., description="iteration node id")
    iteration_node_type: NodeType = Field(..., description="node type, iteration or loop")
    iteration_node_data: BaseNodeData = Field(..., description="node data")

    # parallel id if node is in parallel
    parallel_id: Optional[str] = None
    # parallel start node id if node is in parallel
    parallel_start_node_id: Optional[str] = None
    # parent parallel id if node is in parallel
    parent_parallel_id: Optional[str] = None
    # parent parallel start node id if node is in parallel
    parent_parallel_start_node_id: Optional[str] = None
|
||||
|
||||
|
||||
class IterationRunStartedEvent(BaseIterationEvent):
    """Iteration "started" event: required start time plus optional inputs, metadata and predecessor."""

    start_at: datetime = Field(..., description="start at")
    inputs: Optional[dict[str, Any]] = None
    metadata: Optional[dict[str, Any]] = None
    predecessor_node_id: Optional[str] = None
|
||||
|
||||
|
||||
class IterationRunNextEvent(BaseIterationEvent):
    """Iteration "next" event: required index plus the optional previous iteration output."""

    index: int = Field(..., description="index")
    pre_iteration_output: Optional[Any] = Field(None, description="pre iteration output")
|
||||
|
||||
|
||||
class IterationRunSucceededEvent(BaseIterationEvent):
    """Iteration "succeeded" event: start time, optional inputs/outputs/metadata and a step count."""

    start_at: datetime = Field(..., description="start at")
    inputs: Optional[dict[str, Any]] = None
    outputs: Optional[dict[str, Any]] = None
    metadata: Optional[dict[str, Any]] = None
    # defaults to zero when not supplied
    steps: int = 0
|
||||
|
||||
|
||||
class IterationRunFailedEvent(BaseIterationEvent):
    """Iteration "failed" event: same payload as the succeeded event plus a mandatory failure reason."""

    start_at: datetime = Field(..., description="start at")
    inputs: Optional[dict[str, Any]] = None
    outputs: Optional[dict[str, Any]] = None
    metadata: Optional[dict[str, Any]] = None
    # defaults to zero when not supplied
    steps: int = 0
    error: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
# Union of the three base event types declared above: node, parallel-branch and iteration events.
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent
|
||||
692
api/core/workflow/graph_engine/entities/graph.py
Normal file
692
api/core/workflow/graph_engine/entities/graph.py
Normal file
@@ -0,0 +1,692 @@
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
|
||||
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
|
||||
from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter
|
||||
from core.workflow.nodes.end.entities import EndStreamParam
|
||||
|
||||
|
||||
class GraphEdge(BaseModel):
    """Directed edge between two nodes, optionally guarded by a run condition."""

    source_node_id: str = Field(..., description="source node id")
    target_node_id: str = Field(..., description="target node id")
    # run condition
    run_condition: Optional[RunCondition] = None
|
||||
|
||||
|
||||
class GraphParallel(BaseModel):
    """Parallel section of the graph, identified by a random uuid unless an id is given."""

    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
    start_from_node_id: str = Field(..., description="start from node id")
    # parent parallel id
    parent_parallel_id: Optional[str] = None
    # parent parallel start node id
    parent_parallel_start_node_id: Optional[str] = None
    # end to node id
    end_to_node_id: Optional[str] = None
|
||||
|
||||
|
||||
class Graph(BaseModel):
    """
    Workflow graph model.

    Holds the root node id, the node ids and their configs, the forward and
    reverse edge mappings, the parallel mappings, and the stream routing data
    for answer and end nodes.
    """

    root_node_id: str = Field(..., description="root node id of the graph")
    node_ids: list[str] = Field(default_factory=list, description="graph node ids")
    # The default must be a dict to match the declared type (it was `list` before).
    node_id_config_mapping: dict[str, dict] = Field(
        default_factory=dict,
        description="node configs mapping (node id: node config)"
    )
    edge_mapping: dict[str, list[GraphEdge]] = Field(
        default_factory=dict,
        description="graph edge mapping (source node id: edges)"
    )
    reverse_edge_mapping: dict[str, list[GraphEdge]] = Field(
        default_factory=dict,
        description="reverse graph edge mapping (target node id: edges)"
    )
    parallel_mapping: dict[str, GraphParallel] = Field(
        default_factory=dict,
        description="graph parallel mapping (parallel id: parallel)"
    )
    node_parallel_mapping: dict[str, str] = Field(
        default_factory=dict,
        description="graph node parallel mapping (node id: parallel id)"
    )
    answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(
        ...,
        description="answer stream generate routes"
    )
    end_stream_param: EndStreamParam = Field(
        ...,
        description="end stream param"
    )
|
||||
|
||||
@classmethod
def init(cls,
         graph_config: Mapping[str, Any],
         root_node_id: Optional[str] = None) -> "Graph":
    """
    Build a graph from a raw graph config.

    :param graph_config: graph config; its 'edges' and 'nodes' entries are read
    :param root_node_id: root node id; when omitted, the START-type node that
        has no incoming edge is used
    :return: graph
    :raises ValueError: if there are no nodes, the root node cannot be
        resolved, or one of the structural checks fails
    """
    # edge configs (a missing 'edges' entry means "no edges")
    raw_edges = graph_config.get('edges')
    if raw_edges is None:
        raw_edges = []

    raw_edges = cast(list, raw_edges)

    # index the edges by source node and by target node
    edge_mapping: dict[str, list[GraphEdge]] = {}
    reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
    target_edge_ids = set()
    for edge_config in raw_edges:
        source_node_id = edge_config.get('source')
        if not source_node_id:
            continue

        outgoing_edges = edge_mapping.setdefault(source_node_id, [])

        target_node_id = edge_config.get('target')
        if not target_node_id:
            continue

        incoming_edges = reverse_edge_mapping.setdefault(target_node_id, [])

        # only the first edge for a given (source, target) pair is kept
        if any(known_edge.target_node_id == target_node_id for known_edge in outgoing_edges):
            continue

        target_edge_ids.add(target_node_id)

        # a source handle other than the plain 'source' becomes a branch condition
        source_handle = edge_config.get('sourceHandle')
        run_condition = None
        if source_handle and source_handle != 'source':
            run_condition = RunCondition(
                type='branch_identify',
                branch_identify=source_handle
            )

        new_edge = GraphEdge(
            source_node_id=source_node_id,
            target_node_id=target_node_id,
            run_condition=run_condition
        )

        outgoing_edges.append(new_edge)
        incoming_edges.append(new_edge)

    # node configs
    node_configs = graph_config.get('nodes')
    if not node_configs:
        raise ValueError("Graph must have at least one node")

    node_configs = cast(list, node_configs)

    # collect every node config by id, and separately those without a predecessor
    root_node_configs = []
    all_node_id_config_mapping: dict[str, dict] = {}
    for node_config in node_configs:
        current_id = node_config.get('id')
        if not current_id:
            continue

        if current_id not in target_edge_ids:
            root_node_configs.append(node_config)

        all_node_id_config_mapping[current_id] = node_config

    root_node_ids = [root_config.get('id') for root_config in root_node_configs]

    # without an explicit root, fall back to the START type node
    if not root_node_id:
        root_node_id = next(
            (
                root_config.get("id")
                for root_config in root_node_configs
                if root_config.get('data', {}).get('type', '') == NodeType.START.value
            ),
            None
        )

    if not root_node_id or root_node_id not in root_node_ids:
        raise ValueError(f"Root node id {root_node_id} not found in the graph")

    # reject routes that lead back to a node already on the route
    cls._check_connected_to_previous_node(
        route=[root_node_id],
        edge_mapping=edge_mapping
    )

    # every node id reachable from the root, root first
    node_ids = [root_node_id]
    cls._recursively_add_node_ids(
        node_ids=node_ids,
        edge_mapping=edge_mapping,
        node_id=root_node_id
    )

    node_id_config_mapping = {
        reachable_id: all_node_id_config_mapping[reachable_id]
        for reachable_id in node_ids
    }

    # parallel mappings are filled in place by the helper
    parallel_mapping: dict[str, GraphParallel] = {}
    node_parallel_mapping: dict[str, str] = {}
    cls._recursively_add_parallels(
        edge_mapping=edge_mapping,
        reverse_edge_mapping=reverse_edge_mapping,
        start_node_id=root_node_id,
        parallel_mapping=parallel_mapping,
        node_parallel_mapping=node_parallel_mapping
    )

    # nested parallels are checked against the layer limit
    for parallel in parallel_mapping.values():
        if not parallel.parent_parallel_id:
            continue

        cls._check_exceed_parallel_limit(
            parallel_mapping=parallel_mapping,
            level_limit=3,
            parent_parallel_id=parallel.parent_parallel_id
        )

    # answer stream generate routes
    answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
        node_id_config_mapping=node_id_config_mapping,
        reverse_edge_mapping=reverse_edge_mapping
    )

    # end stream param
    end_stream_param = EndStreamGeneratorRouter.init(
        node_id_config_mapping=node_id_config_mapping,
        reverse_edge_mapping=reverse_edge_mapping,
        node_parallel_mapping=node_parallel_mapping
    )

    # assemble the graph
    return cls(
        root_node_id=root_node_id,
        node_ids=node_ids,
        node_id_config_mapping=node_id_config_mapping,
        edge_mapping=edge_mapping,
        reverse_edge_mapping=reverse_edge_mapping,
        parallel_mapping=parallel_mapping,
        node_parallel_mapping=node_parallel_mapping,
        answer_stream_generate_routes=answer_stream_generate_routes,
        end_stream_param=end_stream_param
    )
|
||||
|
||||
def add_extra_edge(self, source_node_id: str,
                   target_node_id: str,
                   run_condition: Optional[RunCondition] = None) -> None:
    """
    Append an additional edge to the forward edge mapping.

    Nothing happens when either node is unknown to the graph or when the
    source already has an edge to the target.

    :param source_node_id: source node id
    :param target_node_id: target node id
    :param run_condition: run condition
    """
    known_node_ids = self.node_ids
    if not (source_node_id in known_node_ids and target_node_id in known_node_ids):
        return

    outgoing_edges = self.edge_mapping.setdefault(source_node_id, [])

    # skip duplicates of an existing source -> target edge
    if any(existing.target_node_id == target_node_id for existing in outgoing_edges):
        return

    outgoing_edges.append(
        GraphEdge(
            source_node_id=source_node_id,
            target_node_id=target_node_id,
            run_condition=run_condition
        )
    )
|
||||
|
||||
def get_leaf_node_ids(self) -> list[str]:
    """
    Get leaf node ids of the graph

    A node is a leaf when it has no outgoing edges (its entry in the edge
    mapping is missing or empty), or when its single outgoing edge points
    back at the root node.

    :return: leaf node ids, in the order of ``node_ids``
    """
    leaf_node_ids = []
    for node_id in self.node_ids:
        # An empty edge list has to count as "no edges" too: such entries can
        # exist for a source whose only edge config had no target.
        outgoing_edges = self.edge_mapping.get(node_id)
        if not outgoing_edges:
            leaf_node_ids.append(node_id)
        elif (len(outgoing_edges) == 1
                and outgoing_edges[0].target_node_id == self.root_node_id):
            leaf_node_ids.append(node_id)

    return leaf_node_ids
|
||||
|
||||
@classmethod
def _recursively_add_node_ids(cls,
                              node_ids: list[str],
                              edge_mapping: dict[str, list[GraphEdge]],
                              node_id: str) -> None:
    """
    Depth-first walk that appends every not-yet-seen successor to ``node_ids``.

    :param node_ids: node ids collected so far (extended in place)
    :param edge_mapping: edge mapping
    :param node_id: node id whose outgoing edges are followed
    """
    outgoing_edges = edge_mapping.get(node_id, [])
    for edge in outgoing_edges:
        successor_id = edge.target_node_id
        if successor_id not in node_ids:
            node_ids.append(successor_id)
            cls._recursively_add_node_ids(
                node_ids=node_ids,
                edge_mapping=edge_mapping,
                node_id=successor_id
            )
|
||||
|
||||
@classmethod
def _check_connected_to_previous_node(
    cls,
    route: list[str],
    edge_mapping: dict[str, list[GraphEdge]]
) -> None:
    """
    Walk every route from the last node of ``route`` and fail on a back-reference.

    :param route: node ids visited so far, in order
    :param edge_mapping: edge mapping
    :raises ValueError: if an edge targets a node that is already on the route
    """
    tail_node_id = route[-1]

    for edge in edge_mapping.get(tail_node_id, []):
        successor_id = edge.target_node_id
        if not successor_id:
            continue

        if successor_id in route:
            raise ValueError(f"Node {edge.source_node_id} is connected to the previous node, please check the graph.")

        # each branch continues with its own copy of the route
        cls._check_connected_to_previous_node(
            route=[*route, successor_id],
            edge_mapping=edge_mapping,
        )
|
||||
|
||||
@classmethod
def _recursively_add_parallels(
    cls,
    edge_mapping: dict[str, list[GraphEdge]],
    reverse_edge_mapping: dict[str, list[GraphEdge]],
    start_node_id: str,
    parallel_mapping: dict[str, GraphParallel],
    node_parallel_mapping: dict[str, str],
    parent_parallel: Optional[GraphParallel] = None
) -> None:
    """
    Recursively add parallel ids

    Starting at ``start_node_id``, registers a new parallel whenever the node
    fans out into parallel branches, assigns the branch nodes to it, and then
    recurses into every outgoing edge. ``parallel_mapping`` and
    ``node_parallel_mapping`` are filled in place.

    :param edge_mapping: edge mapping
    :param reverse_edge_mapping: reverse edge mapping
    :param start_node_id: start from node id
    :param parallel_mapping: parallel mapping
    :param node_parallel_mapping: node parallel mapping
    :param parent_parallel: parent parallel
    """
    target_node_edges = edge_mapping.get(start_node_id, [])
    parallel = None
    if len(target_node_edges) > 1:
        # fetch all node ids in current parallels
        parallel_branch_node_ids = []
        condition_edge_mappings = {}
        for graph_edge in target_node_edges:
            if graph_edge.run_condition is None:
                parallel_branch_node_ids.append(graph_edge.target_node_id)
            else:
                # group conditional edges by the hash of their run condition
                condition_hash = graph_edge.run_condition.hash
                condition_edge_mappings.setdefault(condition_hash, []).append(graph_edge)

        # conditional edges only count as branches when several share one condition
        for graph_edges in condition_edge_mappings.values():
            if len(graph_edges) > 1:
                for graph_edge in graph_edges:
                    parallel_branch_node_ids.append(graph_edge.target_node_id)

        # any target node id in node_parallel_mapping
        if parallel_branch_node_ids:
            parent_parallel_id = parent_parallel.id if parent_parallel else None

            parallel = GraphParallel(
                start_from_node_id=start_node_id,
                parent_parallel_id=parent_parallel_id,
                parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None
            )
            parallel_mapping[parallel.id] = parallel

            in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
                edge_mapping=edge_mapping,
                reverse_edge_mapping=reverse_edge_mapping,
                parallel_branch_node_ids=parallel_branch_node_ids
            )

            # collect all branches node ids
            parallel_node_ids = []
            for node_ids in in_branch_node_ids.values():
                for node_id in node_ids:
                    # With a parent parallel, only nodes currently mapped to that
                    # parent are taken over; without one, every node is taken.
                    in_parent_parallel = (
                        not parent_parallel_id
                        or node_parallel_mapping.get(node_id) == parent_parallel_id
                    )

                    if in_parent_parallel:
                        parallel_node_ids.append(node_id)
                        node_parallel_mapping[node_id] = parallel.id

            # targets reached by a single edge leaving the collected nodes
            outside_parallel_target_node_ids = set()
            for node_id in parallel_node_ids:
                if node_id == parallel.start_from_node_id:
                    continue

                node_edges = edge_mapping.get(node_id)
                if not node_edges:
                    continue

                if len(node_edges) > 1:
                    continue

                target_node_id = node_edges[0].target_node_id
                if target_node_id in parallel_node_ids:
                    continue

                if parent_parallel_id:
                    # re-resolve the parent from the mapping; skip if it is not registered
                    parent_parallel = parallel_mapping.get(parent_parallel_id)
                    if not parent_parallel:
                        continue

                if (
                    (node_parallel_mapping.get(target_node_id) and node_parallel_mapping.get(target_node_id) == parent_parallel_id)
                    or (parent_parallel and parent_parallel.end_to_node_id and target_node_id == parent_parallel.end_to_node_id)
                    or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
                ):
                    outside_parallel_target_node_ids.add(target_node_id)

            # an end node is only recorded when exactly one outside target was found
            if len(outside_parallel_target_node_ids) == 1:
                if parent_parallel and parent_parallel.end_to_node_id and parallel.end_to_node_id == parent_parallel.end_to_node_id:
                    parallel.end_to_node_id = None
                else:
                    parallel.end_to_node_id = outside_parallel_target_node_ids.pop()

    for graph_edge in target_node_edges:
        # choose the parallel that the recursion into this edge runs under
        current_parallel = None
        if parallel:
            current_parallel = parallel
        elif parent_parallel:
            if not parent_parallel.end_to_node_id or (parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id):
                current_parallel = parent_parallel
            else:
                # fetch parent parallel's parent parallel
                parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id
                if parent_parallel_parent_parallel_id:
                    parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id)
                    if (
                        parent_parallel_parent_parallel
                        and (
                            not parent_parallel_parent_parallel.end_to_node_id
                            or (parent_parallel_parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id)
                        )
                    ):
                        current_parallel = parent_parallel_parent_parallel

        cls._recursively_add_parallels(
            edge_mapping=edge_mapping,
            reverse_edge_mapping=reverse_edge_mapping,
            start_node_id=graph_edge.target_node_id,
            parallel_mapping=parallel_mapping,
            node_parallel_mapping=node_parallel_mapping,
            parent_parallel=current_parallel
        )
|
||||
|
||||
@classmethod
def _check_exceed_parallel_limit(
    cls,
    parallel_mapping: dict[str, GraphParallel],
    level_limit: int,
    parent_parallel_id: str,
    current_level: int = 1
) -> None:
    """
    Follow the chain of parent parallels and fail once it is deeper than the limit.

    :param parallel_mapping: parallel mapping (parallel id: parallel)
    :param level_limit: maximum number of parallel layers allowed
    :param parent_parallel_id: id of the parent parallel to start from
    :param current_level: level already reached before looking at the parent
    :raises ValueError: if the nesting exceeds ``level_limit``
    """
    depth = current_level
    ancestor_id = parent_parallel_id
    while True:
        ancestor = parallel_mapping.get(ancestor_id)
        if not ancestor:
            # unknown parent: the chain ends here
            return

        depth += 1
        if depth > level_limit:
            raise ValueError(f"Exceeds {level_limit} layers of parallel")

        ancestor_id = ancestor.parent_parallel_id
        if not ancestor_id:
            return
|
||||
|
||||
@classmethod
def _recursively_add_parallel_node_ids(cls,
                                       branch_node_ids: list[str],
                                       edge_mapping: dict[str, list[GraphEdge]],
                                       merge_node_id: str,
                                       start_node_id: str) -> None:
    """
    Depth-first walk that gathers the nodes of a branch, stopping at the merge node.

    :param branch_node_ids: in branch node ids (extended in place)
    :param edge_mapping: edge mapping
    :param merge_node_id: merge node id; never added and never walked through
    :param start_node_id: start node id
    """
    for edge in edge_mapping.get(start_node_id, []):
        candidate_id = edge.target_node_id
        # skip the merge node itself and anything already collected
        if candidate_id == merge_node_id or candidate_id in branch_node_ids:
            continue

        branch_node_ids.append(candidate_id)
        cls._recursively_add_parallel_node_ids(
            branch_node_ids=branch_node_ids,
            edge_mapping=edge_mapping,
            merge_node_id=merge_node_id,
            start_node_id=candidate_id
        )
|
||||
|
||||
@classmethod
def _fetch_all_node_ids_in_parallels(cls,
                                     edge_mapping: dict[str, list[GraphEdge]],
                                     reverse_edge_mapping: dict[str, list[GraphEdge]],
                                     parallel_branch_node_ids: list[str]) -> dict[str, list[str]]:
    """
    Fetch all node ids in parallels

    For each branch start node, work out which node ids belong to that branch.
    A branch that shares a merge node with other branches only contains the
    nodes between its start node and that merge node.

    :param edge_mapping: source node id -> outgoing edges
    :param reverse_edge_mapping: target node id -> incoming edges
    :param parallel_branch_node_ids: start node id of every parallel branch
    :return: branch start node id -> node ids belonging to that branch
    :raises Exception: propagated from ``_is_node_in_routes``
    """
    routes_node_ids: dict[str, list[str]] = {}
    for parallel_branch_node_id in parallel_branch_node_ids:
        routes_node_ids[parallel_branch_node_id] = [parallel_branch_node_id]

        # fetch routes node ids
        cls._recursively_fetch_routes(
            edge_mapping=edge_mapping,
            start_node_id=parallel_branch_node_id,
            routes_node_ids=routes_node_ids[parallel_branch_node_id]
        )

    # find the nodes reached by more than one branch (merge node id -> branch ids)
    merge_branch_node_ids: dict[str, list[str]] = {}
    for branch_node_id, node_ids in routes_node_ids.items():
        for node_id in node_ids:
            for branch_node_id2, inner_route2 in routes_node_ids.items():
                if (
                    branch_node_id != branch_node_id2
                    and node_id in inner_route2
                    and len(reverse_edge_mapping.get(node_id, [])) > 1
                    and cls._is_node_in_routes(
                        reverse_edge_mapping=reverse_edge_mapping,
                        start_node_id=node_id,
                        routes_node_ids=routes_node_ids
                    )
                ):
                    merged_branches = merge_branch_node_ids.setdefault(node_id, [])
                    if branch_node_id2 not in merged_branches:
                        merged_branches.append(branch_node_id2)

    # sorted merge_branch_node_ids by branch_node_ids length desc
    merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True))

    # pairs of merge nodes that join exactly the same set of branches
    duplicate_end_node_ids = {}
    for node_id, branch_node_ids in merge_branch_node_ids.items():
        for node_id2, branch_node_ids2 in merge_branch_node_ids.items():
            if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2):
                if (
                    (node_id, node_id2) not in duplicate_end_node_ids
                    and (node_id2, node_id) not in duplicate_end_node_ids
                ):
                    duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids

    for node_id, node_id2 in duplicate_end_node_ids:
        # An earlier pair may already have removed one of the two nodes, so only
        # prune while both are still present; deleting unconditionally raised
        # KeyError when the later node had been dropped before.
        if node_id not in merge_branch_node_ids or node_id2 not in merge_branch_node_ids:
            continue

        # check which node is after, and drop the later one
        if cls._is_node2_after_node1(
            node1_id=node_id,
            node2_id=node_id2,
            edge_mapping=edge_mapping
        ):
            del merge_branch_node_ids[node_id2]
        elif cls._is_node2_after_node1(
            node1_id=node_id2,
            node2_id=node_id,
            edge_mapping=edge_mapping
        ):
            del merge_branch_node_ids[node_id]

    # branch start node id -> merge node id (first match wins, larger merges first)
    branches_merge_node_ids: dict[str, str] = {}
    for node_id, branch_node_ids in merge_branch_node_ids.items():
        if len(branch_node_ids) <= 1:
            continue

        for branch_node_id in branch_node_ids:
            if branch_node_id in branches_merge_node_ids:
                continue

            branches_merge_node_ids[branch_node_id] = node_id

    in_branch_node_ids: dict[str, list[str]] = {}
    for branch_node_id, node_ids in routes_node_ids.items():
        in_branch_node_ids[branch_node_id] = []
        if branch_node_id not in branches_merge_node_ids:
            # all node ids in current branch is in this thread
            in_branch_node_ids[branch_node_id].append(branch_node_id)
            in_branch_node_ids[branch_node_id].extend(node_ids)
        else:
            merge_node_id = branches_merge_node_ids[branch_node_id]
            if merge_node_id != branch_node_id:
                in_branch_node_ids[branch_node_id].append(branch_node_id)

            # fetch all node ids from branch_node_id and merge_node_id
            cls._recursively_add_parallel_node_ids(
                branch_node_ids=in_branch_node_ids[branch_node_id],
                edge_mapping=edge_mapping,
                merge_node_id=merge_node_id,
                start_node_id=branch_node_id
            )

    return in_branch_node_ids
|
||||
|
||||
@classmethod
def _recursively_fetch_routes(cls,
                              edge_mapping: dict[str, list[GraphEdge]],
                              start_node_id: str,
                              routes_node_ids: list[str]) -> None:
    """
    Recursively collect the ids of every node reachable from a start node

    Newly discovered node ids are appended to ``routes_node_ids`` in place;
    ids that are already listed are not followed again.

    :param edge_mapping: source node id -> outgoing edges
    :param start_node_id: node id whose outgoing edges are followed
    :param routes_node_ids: node ids collected so far (mutated)
    """
    for outgoing_edge in edge_mapping.get(start_node_id, []):
        target_id = outgoing_edge.target_node_id
        if target_id in routes_node_ids:
            continue

        # find next node ids
        routes_node_ids.append(target_id)
        cls._recursively_fetch_routes(
            edge_mapping=edge_mapping,
            start_node_id=target_id,
            routes_node_ids=routes_node_ids
        )
|
||||
|
||||
@classmethod
def _is_node_in_routes(cls,
                       reverse_edge_mapping: dict[str, list[GraphEdge]],
                       start_node_id: str,
                       routes_node_ids: dict[str, list[str]]) -> bool:
    """
    Check whether the branches in ``routes_node_ids`` fan out from one node

    :param reverse_edge_mapping: target node id -> incoming edges
    :param start_node_id: node id being checked
    :param routes_node_ids: branch start node id -> node ids on that route
    :return: False when ``start_node_id`` has no incoming edges; True when
        some node has edges to exactly the set of branch start nodes
    :raises Exception: when no such common source node exists
    """
    if start_node_id not in reverse_edge_mapping:
        return False

    # source node id -> branch start nodes it points to
    parallel_start_node_ids: dict[str, list[str]] = {}
    for branch_node_id in routes_node_ids:
        for graph_edge in reverse_edge_mapping.get(branch_node_id, []):
            parallel_start_node_ids.setdefault(graph_edge.source_node_id, []).append(branch_node_id)

    expected_branch_node_ids = set(routes_node_ids)
    for branch_node_ids in parallel_start_node_ids.values():
        if set(branch_node_ids) == expected_branch_node_ids:
            return True

    # NOTE(review): the previous version had an additional per-edge check on
    # ``start_node_id`` after this point, but it could never run because the
    # loop above returned as soon as a start node was found and this raise
    # fired otherwise. Behavior is unchanged; confirm whether the stricter
    # check was intended.
    raise Exception("Parallel start node id not found")
|
||||
|
||||
@classmethod
def _is_node2_after_node1(
    cls,
    node1_id: str,
    node2_id: str,
    edge_mapping: dict[str, list[GraphEdge]]
) -> bool:
    """
    Check whether node2 can be reached from node1 by following edges

    Every node is expanded at most once, so the walk stays linear in the
    number of edges and terminates even if the edges form a cycle.

    :param node1_id: node id the walk starts from
    :param node2_id: node id being looked for
    :param edge_mapping: source node id -> outgoing edges
    :return: True if node2 is reachable through at least one edge from node1
    """
    visited: set[str] = set()
    pending = [node1_id]
    while pending:
        current_id = pending.pop()
        for graph_edge in edge_mapping.get(current_id, []):
            target_id = graph_edge.target_node_id
            if target_id == node2_id:
                return True

            if target_id not in visited:
                visited.add(target_id)
                pending.append(target_id)

    return False
|
||||
21
api/core/workflow/graph_engine/entities/graph_init_params.py
Normal file
21
api/core/workflow/graph_engine/entities/graph_init_params.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
class GraphInitParams(BaseModel):
    """
    Parameters a graph run is initialised with.

    All fields are required; each one is declared with ``Field(...)`` and
    carries its own description.
    """
    # init params
    tenant_id: str = Field(..., description="tenant / workspace id")
    app_id: str = Field(..., description="app id")
    workflow_type: WorkflowType = Field(..., description="workflow type")
    workflow_id: str = Field(..., description="workflow id")
    graph_config: Mapping[str, Any] = Field(..., description="graph config")
    user_id: str = Field(..., description="user id")
    user_from: UserFrom = Field(..., description="user from, account or end-user")
    invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
    call_depth: int = Field(..., description="call depth")
|
||||
@@ -0,0 +1,27 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
|
||||
|
||||
|
||||
class GraphRuntimeState(BaseModel):
    """
    State that changes while a graph is running.

    Only ``variable_pool`` and ``start_at`` are required; the remaining
    fields start from the defaults declared below.
    """
    # NOTE(review): several defaults below are objects created once at class
    # definition time ({} / empty_usage() / RuntimeRouteState()); this presumably
    # relies on pydantic copying defaults per instance - confirm.
    variable_pool: VariablePool = Field(..., description="variable pool")
    """variable pool"""

    start_at: float = Field(..., description="start time")
    """start time"""
    total_tokens: int = 0
    """total tokens"""
    llm_usage: LLMUsage = LLMUsage.empty_usage()
    """llm usage info"""
    outputs: dict[str, Any] = {}
    """outputs"""

    node_run_steps: int = 0
    """node run steps"""

    node_run_state: RuntimeRouteState = RuntimeRouteState()
    """node run state"""
|
||||
13
api/core/workflow/graph_engine/entities/next_graph_node.py
Normal file
13
api/core/workflow/graph_engine/entities/next_graph_node.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.graph_engine.entities.graph import GraphParallel
|
||||
|
||||
|
||||
class NextGraphNode(BaseModel):
    """Id of the next node to run, optionally together with a parallel."""
    node_id: str
    """next node id"""

    # left as None when no parallel is given
    parallel: Optional[GraphParallel] = None
    """parallel"""
|
||||
21
api/core/workflow/graph_engine/entities/run_condition.py
Normal file
21
api/core/workflow/graph_engine/entities/run_condition.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import hashlib
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
|
||||
class RunCondition(BaseModel):
    """Condition that decides whether a node is run."""

    # condition type
    type: Literal["branch_identify", "condition"]

    # branch identify like: sourceHandle, required when type is branch_identify
    branch_identify: Optional[str] = None

    # conditions to run the node, required when type is condition
    conditions: Optional[list[Condition]] = None

    @property
    def hash(self) -> str:
        """Return the SHA-256 hex digest of this condition's JSON dump."""
        serialized = self.model_dump_json()
        digest = hashlib.sha256(serialized.encode())
        return digest.hexdigest()
|
||||
111
api/core/workflow/graph_engine/entities/runtime_route_state.py
Normal file
111
api/core/workflow/graph_engine/entities/runtime_route_state.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class RouteNodeState(BaseModel):
    """State of one run of a node: status, timestamps and result."""

    class Status(Enum):
        RUNNING = "running"
        SUCCESS = "success"
        FAILED = "failed"
        PAUSED = "paused"

    # node state id
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))

    # node id
    node_id: str

    # node run result
    node_run_result: Optional[NodeRunResult] = None

    # node status
    status: Status = Status.RUNNING

    # start time
    start_at: datetime

    # paused time
    paused_at: Optional[datetime] = None

    # finished time
    finished_at: Optional[datetime] = None

    # failed reason
    failed_reason: Optional[str] = None

    # paused by
    paused_by: Optional[str] = None

    index: int = 1

    def set_finished(self, run_result: NodeRunResult) -> None:
        """
        Mark the node state as finished

        Stores the run result, derives the final status from it and records
        the finish time as a naive UTC datetime.

        :param run_result: run result
        :raises Exception: if the state is already finished, or the run result
            status is neither succeeded nor failed
        """
        finished_statuses = (RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED)
        if self.status in finished_statuses:
            raise Exception(f"Route state {self.id} already finished")

        if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
            self.status = RouteNodeState.Status.SUCCESS
        elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
            self.status = RouteNodeState.Status.FAILED
            self.failed_reason = run_result.error
        else:
            raise Exception(f"Invalid route status {run_result.status}")

        self.node_run_result = run_result
        now_utc = datetime.now(timezone.utc)
        self.finished_at = now_utc.replace(tzinfo=None)
|
||||
|
||||
|
||||
class RuntimeRouteState(BaseModel):
    """Node states created during a run and the routes between them."""

    routes: dict[str, list[str]] = Field(
        default_factory=dict,
        description="graph state routes (source_node_state_id: target_node_state_id)"
    )

    node_state_mapping: dict[str, RouteNodeState] = Field(
        default_factory=dict,
        description="node state mapping (route_node_state_id: route_node_state)"
    )

    def create_node_state(self, node_id: str) -> RouteNodeState:
        """
        Create a node state, register it and return it

        The start time is the current UTC time stored as a naive datetime.

        :param node_id: node id
        :return: the newly registered node state
        """
        started_at = datetime.now(timezone.utc).replace(tzinfo=None)
        node_state = RouteNodeState(node_id=node_id, start_at=started_at)
        self.node_state_mapping[node_state.id] = node_state
        return node_state

    def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None:
        """
        Record a route from one node state to another

        :param source_node_state_id: source node state id
        :param target_node_state_id: target node state id
        """
        targets = self.routes.setdefault(source_node_state_id, [])
        targets.append(target_node_state_id)

    def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) \
            -> list[RouteNodeState]:
        """
        Get the node states that a source node state routes to

        :param source_node_state_id: source node state id
        :return: target node states, in the order the routes were added
        """
        target_state_ids = self.routes.get(source_node_state_id, [])
        target_states = []
        for target_state_id in target_state_ids:
            target_states.append(self.node_state_mapping[target_state_id])

        return target_states
|
||||
716
api/core/workflow/graph_engine/graph_engine.py
Normal file
716
api/core/workflow/graph_engine/graph_engine.py
Normal file
@@ -0,0 +1,716 @@
|
||||
import logging
|
||||
import queue
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor, wait
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import (
|
||||
NodeRunMetadataKey,
|
||||
NodeType,
|
||||
UserFrom,
|
||||
)
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseIterationEvent,
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ParallelBranchRunFailedEvent,
|
||||
ParallelBranchRunStartedEvent,
|
||||
ParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph, GraphEdge
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.node_mapping import node_classes
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GraphEngineThreadPool(ThreadPoolExecutor):
    """
    Thread pool that limits how many tasks may be submitted over its lifetime.

    Every call to ``submit`` is counted, including the ones that are rejected.
    """

    def __init__(self, max_workers=None, thread_name_prefix='',
                 initializer=None, initargs=(), max_submit_count=100) -> None:
        """
        :param max_workers: passed through to ``ThreadPoolExecutor``
        :param thread_name_prefix: passed through to ``ThreadPoolExecutor``
        :param initializer: passed through to ``ThreadPoolExecutor``
        :param initargs: passed through to ``ThreadPoolExecutor``
        :param max_submit_count: number of submissions accepted before
            ``submit`` starts raising
        """
        super().__init__(max_workers, thread_name_prefix, initializer, initargs)
        self.max_submit_count = max_submit_count
        self.submit_count = 0

    def submit(self, fn, *args, **kwargs):
        """
        Count the submission and schedule ``fn`` if the limit is not exceeded

        :raises ValueError: when more than ``max_submit_count`` submissions
            have been made
        :return: the future returned by ``ThreadPoolExecutor.submit``
        """
        self.submit_count += 1
        self.check_is_full()

        return super().submit(fn, *args, **kwargs)

    def check_is_full(self) -> None:
        """
        Raise if the submission counter has passed the limit

        :raises ValueError: when ``submit_count`` is greater than ``max_submit_count``
        """
        # diagnostics go through logging instead of being printed to stdout
        logging.getLogger(__name__).debug(
            "submit_count: %s, max_submit_count: %s", self.submit_count, self.max_submit_count
        )
        if self.submit_count > self.max_submit_count:
            raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.")
|
||||
|
||||
|
||||
class GraphEngine:
|
||||
workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {}
|
||||
|
||||
def __init__(
    self,
    tenant_id: str,
    app_id: str,
    workflow_type: WorkflowType,
    workflow_id: str,
    user_id: str,
    user_from: UserFrom,
    invoke_from: InvokeFrom,
    call_depth: int,
    graph: Graph,
    graph_config: Mapping[str, Any],
    variable_pool: VariablePool,
    max_execution_steps: int,
    max_execution_time: int,
    thread_pool_id: Optional[str] = None
) -> None:
    """
    Set up the thread pool, init params and runtime state of the engine

    Without ``thread_pool_id`` a new pool is created and registered in
    ``GraphEngine.workflow_thread_pool_mapping`` under a fresh uuid; with it,
    the already registered pool of that id is reused.

    :param graph: graph to run
    :param variable_pool: variable pool stored in the runtime state
    :param max_execution_steps: step limit stored on the engine
    :param max_execution_time: time limit stored on the engine
    :param thread_pool_id: id of a registered thread pool to reuse
    :raises ValueError: if ``thread_pool_id`` is given but not registered

    The remaining arguments are passed unchanged into ``GraphInitParams``.
    """
    thread_pool_max_submit_count = 100
    thread_pool_max_workers = 10

    # init thread pool
    pool_registry = GraphEngine.workflow_thread_pool_mapping
    if not thread_pool_id:
        # no pool given: this engine owns a new one
        self.thread_pool = GraphEngineThreadPool(
            max_workers=thread_pool_max_workers,
            max_submit_count=thread_pool_max_submit_count
        )
        self.thread_pool_id = str(uuid.uuid4())
        self.is_main_thread_pool = True
        pool_registry[self.thread_pool_id] = self.thread_pool
    else:
        if thread_pool_id not in pool_registry:
            # NOTE(review): the message talks about the submit limit although the
            # check is for an unregistered pool id - confirm this is intended.
            raise ValueError(f"Max submit count {thread_pool_max_submit_count} of workflow thread pool reached.")

        self.thread_pool_id = thread_pool_id
        self.thread_pool = pool_registry[thread_pool_id]
        self.is_main_thread_pool = False

    self.graph = graph
    self.init_params = GraphInitParams(
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_type=workflow_type,
        workflow_id=workflow_id,
        graph_config=graph_config,
        user_id=user_id,
        user_from=user_from,
        invoke_from=invoke_from,
        call_depth=call_depth
    )

    # the run clock starts when the engine is constructed
    self.graph_runtime_state = GraphRuntimeState(
        variable_pool=variable_pool,
        start_at=time.perf_counter()
    )

    self.max_execution_steps = max_execution_steps
    self.max_execution_time = max_execution_time
|
||||
|
||||
def run(self) -> Generator[GraphEngineEvent, None, None]:
    """
    Run the graph from its root node and yield the resulting events

    The stream starts with ``GraphRunStartedEvent`` and ends with either
    ``GraphRunSucceededEvent`` or ``GraphRunFailedEvent``. Outputs of END and
    ANSWER nodes are collected into ``graph_runtime_state.outputs`` on the way.
    When this engine created its own thread pool, the pool is removed from
    ``GraphEngine.workflow_thread_pool_mapping`` once the run ends.

    :raises Exception: unexpected errors are re-raised after the failure
        event has been yielded
    """
    # trigger graph run start event
    yield GraphRunStartedEvent()

    try:
        stream_processor_cls: type[AnswerStreamProcessor | EndStreamProcessor]
        stream_processor_cls = (
            AnswerStreamProcessor
            if self.init_params.workflow_type == WorkflowType.CHAT
            else EndStreamProcessor
        )

        stream_processor = stream_processor_cls(
            graph=self.graph,
            variable_pool=self.graph_runtime_state.variable_pool
        )

        # run graph
        graph_events = self._run(start_node_id=self.graph.root_node_id)
        generator = stream_processor.process(graph_events)

        for item in generator:
            try:
                yield item
                if isinstance(item, NodeRunFailedEvent):
                    # a failed node fails the whole run
                    yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or 'Unknown error.')
                    return

                if not isinstance(item, NodeRunSucceededEvent):
                    continue

                if item.node_type == NodeType.END:
                    run_result = item.route_node_state.node_run_result
                    self.graph_runtime_state.outputs = (
                        run_result.outputs if run_result and run_result.outputs else {}
                    )
                elif item.node_type == NodeType.ANSWER:
                    run_result = item.route_node_state.node_run_result
                    outputs = self.graph_runtime_state.outputs
                    if "answer" not in outputs:
                        outputs["answer"] = ""

                    # answers of successive ANSWER nodes are joined with newlines
                    answer_part = (
                        run_result.outputs.get("answer", "")
                        if run_result and run_result.outputs
                        else ""
                    )
                    outputs["answer"] += "\n" + answer_part
                    outputs["answer"] = outputs["answer"].strip()
            except Exception as e:
                logger.exception(f"Graph run failed: {str(e)}")
                yield GraphRunFailedEvent(error=str(e))
                return

        # trigger graph run success event
        yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs)
    except GraphRunFailedError as e:
        yield GraphRunFailedEvent(error=e.error)
        return
    except Exception as e:
        logger.exception("Unknown Error when graph running")
        yield GraphRunFailedEvent(error=str(e))
        raise e
    finally:
        # only the engine that created the pool unregisters it
        pool_registry = GraphEngine.workflow_thread_pool_mapping
        if self.is_main_thread_pool and self.thread_pool_id in pool_registry:
            del pool_registry[self.thread_pool_id]
|
||||
|
||||
def _run(
    self,
    start_node_id: str,
    in_parallel_id: Optional[str] = None,
    parent_parallel_id: Optional[str] = None,
    parent_parallel_start_node_id: Optional[str] = None
) -> Generator[GraphEngineEvent, None, None]:
    """
    Run nodes one after another starting at ``start_node_id``

    After each node the outgoing edges decide what happens next: a single edge
    is followed (if its run condition passes), conditional edges select the
    first group whose condition passes, and several edges sharing one outcome
    are run as parallel branches. The loop ends at an END node, when no edge
    can be followed, or when the next node leaves the parallel this run is in.

    :param start_node_id: id of the first node to run
    :param in_parallel_id: id of the parallel this run belongs to, if any
    :param parent_parallel_id: id of the enclosing parallel, if any
    :param parent_parallel_start_node_id: start node id of the enclosing parallel branch
    :raises GraphRunFailedError: on step or time limit, or missing node config / class
    """
    # inside a parallel, the first node of this run is the branch start node
    parallel_start_node_id = start_node_id if in_parallel_id else None

    next_node_id = start_node_id
    previous_route_node_state: Optional[RouteNodeState] = None
    while True:
        # max steps reached
        if self.graph_runtime_state.node_run_steps > self.max_execution_steps:
            raise GraphRunFailedError('Max steps {} reached.'.format(self.max_execution_steps))

        # or max execution time reached
        timed_out = self._is_timed_out(
            start_at=self.graph_runtime_state.start_at,
            max_execution_time=self.max_execution_time
        )
        if timed_out:
            raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time))

        # init route node state
        route_node_state = self.graph_runtime_state.node_run_state.create_node_state(
            node_id=next_node_id
        )

        # get node config
        node_id = route_node_state.node_id
        node_config = self.graph.node_id_config_mapping.get(node_id)
        if not node_config:
            raise GraphRunFailedError(f'Node {node_id} config not found.')

        # convert to specific node
        node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
        node_cls = node_classes.get(node_type)
        if not node_cls:
            raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')

        previous_node_id = None
        if previous_route_node_state:
            previous_node_id = previous_route_node_state.node_id

        # init workflow run state
        node_instance = node_cls(  # type: ignore
            id=route_node_state.id,
            config=node_config,
            graph_init_params=self.init_params,
            graph=self.graph,
            graph_runtime_state=self.graph_runtime_state,
            previous_node_id=previous_node_id,
            thread_pool_id=self.thread_pool_id
        )

        try:
            # run node
            node_events = self._run_node(
                node_instance=node_instance,
                route_node_state=route_node_state,
                parallel_id=in_parallel_id,
                parallel_start_node_id=parallel_start_node_id,
                parent_parallel_id=parent_parallel_id,
                parent_parallel_start_node_id=parent_parallel_start_node_id
            )

            for node_event in node_events:
                if isinstance(node_event, NodeRunStartedEvent):
                    # the step counter doubles as the index of the node run
                    self.graph_runtime_state.node_run_steps += 1
                    node_event.route_node_state.index = self.graph_runtime_state.node_run_steps

                yield node_event

            self.graph_runtime_state.node_run_state.node_state_mapping[route_node_state.id] = route_node_state

            # append route
            if previous_route_node_state:
                self.graph_runtime_state.node_run_state.add_route(
                    source_node_state_id=previous_route_node_state.id,
                    target_node_state_id=route_node_state.id
                )
        except Exception as e:
            route_node_state.status = RouteNodeState.Status.FAILED
            route_node_state.failed_reason = str(e)
            yield NodeRunFailedEvent(
                error=str(e),
                id=node_instance.id,
                node_id=next_node_id,
                node_type=node_type,
                node_data=node_instance.node_data,
                route_node_state=route_node_state,
                parallel_id=in_parallel_id,
                parallel_start_node_id=parallel_start_node_id,
                parent_parallel_id=parent_parallel_id,
                parent_parallel_start_node_id=parent_parallel_start_node_id
            )
            raise e

        # It may not be necessary, but it is necessary. :)
        finished_node_config = self.graph.node_id_config_mapping[next_node_id]
        finished_node_type = finished_node_config.get("data", {}).get("type", "")
        if finished_node_type.lower() == NodeType.END.value:
            break

        previous_route_node_state = route_node_state

        # get next node ids
        edge_mappings = self.graph.edge_mapping.get(next_node_id)
        if not edge_mappings:
            break

        if len(edge_mappings) == 1:
            edge = edge_mappings[0]

            if edge.run_condition:
                condition_met = ConditionManager.get_condition_handler(
                    init_params=self.init_params,
                    graph=self.graph,
                    run_condition=edge.run_condition,
                ).check(
                    graph_runtime_state=self.graph_runtime_state,
                    previous_route_node_state=previous_route_node_state
                )

                if not condition_met:
                    break

            next_node_id = edge.target_node_id
        else:
            final_node_id = None

            if any(edge.run_condition for edge in edge_mappings):
                # if nodes has run conditions, get node id which branch to take based on the run condition results
                # edges are grouped by the hash of their run condition
                condition_edge_mappings = {}
                for edge in edge_mappings:
                    if not edge.run_condition:
                        continue

                    condition_edge_mappings.setdefault(edge.run_condition.hash, []).append(edge)

                for sub_edge_mappings in condition_edge_mappings.values():
                    if not sub_edge_mappings:
                        continue

                    edge = sub_edge_mappings[0]

                    condition_met = ConditionManager.get_condition_handler(
                        init_params=self.init_params,
                        graph=self.graph,
                        run_condition=edge.run_condition,
                    ).check(
                        graph_runtime_state=self.graph_runtime_state,
                        previous_route_node_state=previous_route_node_state,
                    )

                    if not condition_met:
                        continue

                    if len(sub_edge_mappings) == 1:
                        final_node_id = edge.target_node_id
                    else:
                        # several edges share the passing condition: run them in parallel
                        branch_items = self._run_parallel_branches(
                            edge_mappings=sub_edge_mappings,
                            in_parallel_id=in_parallel_id,
                            parallel_start_node_id=parallel_start_node_id
                        )

                        for branch_item in branch_items:
                            if isinstance(branch_item, str):
                                final_node_id = branch_item
                            else:
                                yield branch_item

                    # only the first passing condition group is taken
                    break
            else:
                branch_items = self._run_parallel_branches(
                    edge_mappings=edge_mappings,
                    in_parallel_id=in_parallel_id,
                    parallel_start_node_id=parallel_start_node_id
                )

                for branch_item in branch_items:
                    # a plain string is the id of the node to continue with
                    if isinstance(branch_item, str):
                        final_node_id = branch_item
                    else:
                        yield branch_item

            if not final_node_id:
                break

            next_node_id = final_node_id

        # stop once the next node no longer belongs to the parallel of this run
        if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
            break
|
||||
|
||||
def _run_parallel_branches(
    self,
    edge_mappings: list[GraphEdge],
    in_parallel_id: Optional[str] = None,
    parallel_start_node_id: Optional[str] = None,
) -> Generator[GraphEngineEvent | str, None, None]:
    """
    Run the target nodes of the given edges as parallel branches

    Each branch is submitted to the engine's thread pool and reports its
    events through a shared queue; the events are yielded as they arrive.
    After every branch has succeeded, the id of the node to continue with
    (the parallel's ``end_to_node_id``) is yielded as a plain string, if set.

    :param edge_mappings: edges whose target nodes start the branches
    :param in_parallel_id: id of the parallel the caller is running in, if any
    :param parallel_start_node_id: start node id of the caller's branch, if any
    :raises GraphRunFailedError: if the parallel cannot be resolved or a
        branch of this parallel reports a failure
    """
    # if nodes has no run conditions, parallel run all nodes
    node_id = edge_mappings[0].target_node_id
    parallel_id = self.graph.node_parallel_mapping.get(node_id)
    if not parallel_id:
        node_config = self.graph.node_id_config_mapping.get(node_id)
        if not node_config:
            raise GraphRunFailedError(f'Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches.')

        node_title = node_config.get('data', {}).get('title')
        raise GraphRunFailedError(f'Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches.')

    parallel = self.graph.parallel_mapping.get(parallel_id)
    if not parallel:
        raise GraphRunFailedError(f'Parallel {parallel_id} not found.')

    # run parallel nodes, run in new thread and use queue to get results
    q: queue.Queue = queue.Queue()

    # one future per submitted branch
    futures = []
    for edge in edge_mappings:
        # skip targets that do not belong to this parallel
        if self.graph.node_parallel_mapping.get(edge.target_node_id, '') != parallel_id:
            continue

        future = self.thread_pool.submit(
            self._run_parallel_node,
            flask_app=current_app._get_current_object(),  # type: ignore[attr-defined]
            q=q,
            parallel_id=parallel_id,
            parallel_start_node_id=edge.target_node_id,
            parent_parallel_id=in_parallel_id,
            parent_parallel_start_node_id=parallel_start_node_id,
        )
        futures.append(future)

    succeeded_count = 0
    while True:
        try:
            event = q.get(timeout=1)
        except queue.Empty:
            # nothing arrived within the timeout, keep polling
            continue

        # None is the sentinel queued below once every branch has succeeded
        if event is None:
            break

        yield event
        if event.parallel_id != parallel_id:
            continue

        if isinstance(event, ParallelBranchRunSucceededEvent):
            succeeded_count += 1
            if succeeded_count == len(futures):
                q.put(None)
        elif isinstance(event, ParallelBranchRunFailedEvent):
            raise GraphRunFailedError(event.error)

    # wait all threads
    wait(futures)

    # get final node id
    final_node_id = parallel.end_to_node_id
    if final_node_id:
        yield final_node_id
|
||||
|
||||
def _run_parallel_node(
    self,
    flask_app: Flask,
    q: queue.Queue,
    parallel_id: str,
    parallel_start_node_id: str,
    parent_parallel_id: Optional[str] = None,
    parent_parallel_start_node_id: Optional[str] = None,
) -> None:
    """
    Run one branch of a parallel and forward everything it emits to a queue

    The method is submitted to the thread pool by the caller, so it pushes the
    Flask application context itself and releases the database session when done.

    :param flask_app: Flask application whose context is pushed for the run
    :param q: queue that receives the branch events
    :param parallel_id: id of the parallel this branch belongs to
    :param parallel_start_node_id: id of the first node of this branch
    :param parent_parallel_id: id of the enclosing parallel, if any
    :param parent_parallel_start_node_id: start node id of the enclosing parallel branch, if any
    :return: None, results are only communicated through the queue
    """
    # identifiers shared by every branch-level event emitted below
    branch_info = {
        'parallel_id': parallel_id,
        'parallel_start_node_id': parallel_start_node_id,
        'parent_parallel_id': parent_parallel_id,
        'parent_parallel_start_node_id': parent_parallel_start_node_id,
    }

    with flask_app.app_context():
        try:
            q.put(ParallelBranchRunStartedEvent(**branch_info))

            # walk the branch and relay each event as soon as it is produced
            branch_events = self._run(
                start_node_id=parallel_start_node_id,
                in_parallel_id=parallel_id,
                parent_parallel_id=parent_parallel_id,
                parent_parallel_start_node_id=parent_parallel_start_node_id,
            )
            for branch_event in branch_events:
                q.put(branch_event)

            # the branch ran to its end without raising
            q.put(ParallelBranchRunSucceededEvent(**branch_info))
        except GraphRunFailedError as run_error:
            # graph-level failure: pass its message on unchanged
            q.put(ParallelBranchRunFailedEvent(error=run_error.error, **branch_info))
        except Exception as unexpected_error:
            # anything else is logged with its traceback and reported as a branch failure
            logger.exception("Unknown Error when generating in parallel")
            q.put(ParallelBranchRunFailedEvent(error=str(unexpected_error), **branch_info))
        finally:
            db.session.remove()
|
||||
|
||||
def _run_node(
    self,
    node_instance: BaseNode,
    route_node_state: RouteNodeState,
    parallel_id: Optional[str] = None,
    parallel_start_node_id: Optional[str] = None,
    parent_parallel_id: Optional[str] = None,
    parent_parallel_start_node_id: Optional[str] = None,
) -> Generator[GraphEngineEvent, None, None]:
    """
    Execute a single node and translate what it emits into graph engine events

    :param node_instance: node to execute
    :param route_node_state: route state of this execution, marked finished or failed here
    :param parallel_id: id of the parallel the node runs in, if any
    :param parallel_start_node_id: start node id of the parallel branch the node runs in, if any
    :param parent_parallel_id: id of the enclosing parallel, if any
    :param parent_parallel_start_node_id: start node id of the enclosing parallel branch, if any
    :return: generator of graph engine events
    """
    # parallel identifiers attached to every event of this node run
    parallel_info = {
        'parallel_id': parallel_id,
        'parallel_start_node_id': parallel_start_node_id,
        'parent_parallel_id': parent_parallel_id,
        'parent_parallel_start_node_id': parent_parallel_start_node_id,
    }

    def node_event_fields() -> dict:
        # read the node attributes at emit time, exactly when each event is built
        return {
            'id': node_instance.id,
            'node_id': node_instance.node_id,
            'node_type': node_instance.node_type,
            'node_data': node_instance.node_data,
            'route_node_state': route_node_state,
            **parallel_info,
        }

    # announce the start of the node run
    yield NodeRunStartedEvent(
        **node_event_fields(),
        predecessor_node_id=node_instance.previous_node_id,
    )

    db.session.close()

    try:
        for item in node_instance.run():
            if isinstance(item, GraphEngineEvent):
                # events that are already graph engine events are passed through;
                # iteration events additionally get tagged with the parallel info
                if isinstance(item, BaseIterationEvent):
                    for field_name, field_value in parallel_info.items():
                        setattr(item, field_name, field_value)

                yield item
                continue

            if isinstance(item, RunCompletedEvent):
                run_result = item.run_result
                route_node_state.set_finished(run_result=run_result)

                if run_result.status == WorkflowNodeExecutionStatus.FAILED:
                    yield NodeRunFailedEvent(
                        error=route_node_state.failed_reason or 'Unknown error.',
                        **node_event_fields(),
                    )
                elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                    # add the tokens consumed by this node to the run total
                    metadata = run_result.metadata
                    token_count = metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) if metadata else None
                    if token_count:
                        self.graph_runtime_state.total_tokens += int(token_count)  # type: ignore[arg-type]

                    # accumulate the LLM usage of this node
                    if run_result.llm_usage:
                        self.graph_runtime_state.llm_usage += run_result.llm_usage

                    # publish the node outputs (including nested values) to the variable pool
                    if run_result.outputs:
                        for output_key, output_value in run_result.outputs.items():
                            self._append_variables_recursively(
                                node_id=node_instance.node_id,
                                variable_key_list=[output_key],
                                variable_value=output_value
                            )

                    # record which parallel (and parent parallel) produced this result
                    if parallel_id and parallel_start_node_id:
                        if not run_result.metadata:
                            run_result.metadata = {}

                        run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
                        run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
                        if parent_parallel_id and parent_parallel_start_node_id:
                            run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
                            run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = \
                                parent_parallel_start_node_id

                    yield NodeRunSucceededEvent(**node_event_fields())

                # a completion event ends the node run, whatever its status
                break

            if isinstance(item, RunStreamChunkEvent):
                yield NodeRunStreamChunkEvent(
                    chunk_content=item.chunk_content,
                    from_variable_selector=item.from_variable_selector,
                    **node_event_fields(),
                )
            elif isinstance(item, RunRetrieverResourceEvent):
                yield NodeRunRetrieverResourceEvent(
                    retriever_resources=item.retriever_resources,
                    context=item.context,
                    **node_event_fields(),
                )
    except GenerateTaskStoppedException:
        # the workflow was stopped: report the node as failed and end quietly
        route_node_state.status = RouteNodeState.Status.FAILED
        route_node_state.failed_reason = "Workflow stopped."
        yield NodeRunFailedEvent(
            error="Workflow stopped.",
            **node_event_fields(),
        )
        return
    except Exception as e:
        # log with traceback here, the caller decides how to surface the failure
        logger.exception(f"Node {node_instance.node_data.title} run failed: {str(e)}")
        raise e
    finally:
        db.session.close()
|
||||
|
||||
def _append_variables_recursively(self,
|
||||
node_id: str,
|
||||
variable_key_list: list[str],
|
||||
variable_value: VariableValue):
|
||||
"""
|
||||
Append variables recursively
|
||||
:param node_id: node id
|
||||
:param variable_key_list: variable key list
|
||||
:param variable_value: variable value
|
||||
:return:
|
||||
"""
|
||||
self.graph_runtime_state.variable_pool.add(
|
||||
[node_id] + variable_key_list,
|
||||
variable_value
|
||||
)
|
||||
|
||||
# if variable_value is a dict, then recursively append variables
|
||||
if isinstance(variable_value, dict):
|
||||
for key, value in variable_value.items():
|
||||
# construct new key list
|
||||
new_key_list = variable_key_list + [key]
|
||||
self._append_variables_recursively(
|
||||
node_id=node_id,
|
||||
variable_key_list=new_key_list,
|
||||
variable_value=value
|
||||
)
|
||||
|
||||
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
|
||||
"""
|
||||
Check timeout
|
||||
:param start_at: start time
|
||||
:param max_execution_time: max execution time
|
||||
:return:
|
||||
"""
|
||||
return time.perf_counter() - start_at > max_execution_time
|
||||
|
||||
|
||||
class GraphRunFailedError(Exception):
    """
    Raised when a graph run cannot continue

    The failure reason is available both as the ``error`` attribute and as the
    regular exception message (``str(exc)`` / ``exc.args``).
    """

    def __init__(self, error: str):
        """
        :param error: failure reason
        """
        # forward the message to Exception so args and str() are populated
        # even when the error is passed by keyword
        super().__init__(error)
        self.error = error
|
||||
Reference in New Issue
Block a user