mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-12 20:36:53 +08:00
feat: mypy for all type check (#10921)
This commit is contained in:
@@ -5,7 +5,7 @@ from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
|
||||
|
||||
class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
||||
def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
|
||||
def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState):
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
@@ -310,26 +311,17 @@ class Graph(BaseModel):
|
||||
parallel = None
|
||||
if len(target_node_edges) > 1:
|
||||
# fetch all node ids in current parallels
|
||||
parallel_branch_node_ids = {}
|
||||
condition_edge_mappings = {}
|
||||
parallel_branch_node_ids = defaultdict(list)
|
||||
condition_edge_mappings = defaultdict(list)
|
||||
for graph_edge in target_node_edges:
|
||||
if graph_edge.run_condition is None:
|
||||
if "default" not in parallel_branch_node_ids:
|
||||
parallel_branch_node_ids["default"] = []
|
||||
|
||||
parallel_branch_node_ids["default"].append(graph_edge.target_node_id)
|
||||
else:
|
||||
condition_hash = graph_edge.run_condition.hash
|
||||
if condition_hash not in condition_edge_mappings:
|
||||
condition_edge_mappings[condition_hash] = []
|
||||
|
||||
condition_edge_mappings[condition_hash].append(graph_edge)
|
||||
|
||||
for condition_hash, graph_edges in condition_edge_mappings.items():
|
||||
if len(graph_edges) > 1:
|
||||
if condition_hash not in parallel_branch_node_ids:
|
||||
parallel_branch_node_ids[condition_hash] = []
|
||||
|
||||
for graph_edge in graph_edges:
|
||||
parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id)
|
||||
|
||||
@@ -418,7 +410,7 @@ class Graph(BaseModel):
|
||||
if condition_edge_mappings:
|
||||
for condition_hash, graph_edges in condition_edge_mappings.items():
|
||||
for graph_edge in graph_edges:
|
||||
current_parallel: GraphParallel | None = cls._get_current_parallel(
|
||||
current_parallel = cls._get_current_parallel(
|
||||
parallel_mapping=parallel_mapping,
|
||||
graph_edge=graph_edge,
|
||||
parallel=condition_parallels.get(condition_hash),
|
||||
|
||||
@@ -40,6 +40,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntime
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
||||
@@ -66,7 +67,7 @@ class GraphEngineThreadPool(ThreadPoolExecutor):
|
||||
self.max_submit_count = max_submit_count
|
||||
self.submit_count = 0
|
||||
|
||||
def submit(self, fn, *args, **kwargs):
|
||||
def submit(self, fn, /, *args, **kwargs):
|
||||
self.submit_count += 1
|
||||
self.check_is_full()
|
||||
|
||||
@@ -140,7 +141,8 @@ class GraphEngine:
|
||||
def run(self) -> Generator[GraphEngineEvent, None, None]:
|
||||
# trigger graph run start event
|
||||
yield GraphRunStartedEvent()
|
||||
handle_exceptions = []
|
||||
handle_exceptions: list[str] = []
|
||||
stream_processor: StreamProcessor
|
||||
|
||||
try:
|
||||
if self.init_params.workflow_type == WorkflowType.CHAT:
|
||||
@@ -168,7 +170,7 @@ class GraphEngine:
|
||||
elif isinstance(item, NodeRunSucceededEvent):
|
||||
if item.node_type == NodeType.END:
|
||||
self.graph_runtime_state.outputs = (
|
||||
item.route_node_state.node_run_result.outputs
|
||||
dict(item.route_node_state.node_run_result.outputs)
|
||||
if item.route_node_state.node_run_result
|
||||
and item.route_node_state.node_run_result.outputs
|
||||
else {}
|
||||
@@ -350,7 +352,7 @@ class GraphEngine:
|
||||
|
||||
if any(edge.run_condition for edge in edge_mappings):
|
||||
# if nodes has run conditions, get node id which branch to take based on the run condition results
|
||||
condition_edge_mappings = {}
|
||||
condition_edge_mappings: dict[str, list[GraphEdge]] = {}
|
||||
for edge in edge_mappings:
|
||||
if edge.run_condition:
|
||||
run_condition_hash = edge.run_condition.hash
|
||||
@@ -364,6 +366,9 @@ class GraphEngine:
|
||||
continue
|
||||
|
||||
edge = cast(GraphEdge, sub_edge_mappings[0])
|
||||
if edge.run_condition is None:
|
||||
logger.warning(f"Edge {edge.target_node_id} run condition is None")
|
||||
continue
|
||||
|
||||
result = ConditionManager.get_condition_handler(
|
||||
init_params=self.init_params,
|
||||
@@ -387,11 +392,11 @@ class GraphEngine:
|
||||
handle_exceptions=handle_exceptions,
|
||||
)
|
||||
|
||||
for item in parallel_generator:
|
||||
if isinstance(item, str):
|
||||
final_node_id = item
|
||||
for parallel_result in parallel_generator:
|
||||
if isinstance(parallel_result, str):
|
||||
final_node_id = parallel_result
|
||||
else:
|
||||
yield item
|
||||
yield parallel_result
|
||||
|
||||
break
|
||||
|
||||
@@ -413,11 +418,11 @@ class GraphEngine:
|
||||
handle_exceptions=handle_exceptions,
|
||||
)
|
||||
|
||||
for item in parallel_generator:
|
||||
if isinstance(item, str):
|
||||
final_node_id = item
|
||||
for generated_item in parallel_generator:
|
||||
if isinstance(generated_item, str):
|
||||
final_node_id = generated_item
|
||||
else:
|
||||
yield item
|
||||
yield generated_item
|
||||
|
||||
if not final_node_id:
|
||||
break
|
||||
@@ -653,7 +658,7 @@ class GraphEngine:
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
error=run_result.error,
|
||||
error=run_result.error or "Unknown error",
|
||||
retry_index=retries,
|
||||
start_at=retry_start_at,
|
||||
)
|
||||
@@ -732,20 +737,20 @@ class GraphEngine:
|
||||
variable_value=variable_value,
|
||||
)
|
||||
|
||||
# add parallel info to run result metadata
|
||||
if parallel_id and parallel_start_node_id:
|
||||
if not run_result.metadata:
|
||||
run_result.metadata = {}
|
||||
# When setting metadata, convert to dict first
|
||||
if not run_result.metadata:
|
||||
run_result.metadata = {}
|
||||
|
||||
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
|
||||
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = (
|
||||
parallel_start_node_id
|
||||
)
|
||||
if parallel_id and parallel_start_node_id:
|
||||
metadata_dict = dict(run_result.metadata)
|
||||
metadata_dict[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
|
||||
metadata_dict[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
|
||||
if parent_parallel_id and parent_parallel_start_node_id:
|
||||
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
|
||||
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
|
||||
metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
|
||||
metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
|
||||
parent_parallel_start_node_id
|
||||
)
|
||||
run_result.metadata = metadata_dict
|
||||
|
||||
yield NodeRunSucceededEvent(
|
||||
id=node_instance.id,
|
||||
@@ -869,8 +874,8 @@ class GraphEngine:
|
||||
variable_pool.add([node_instance.node_id, "error_message"], error_result.error)
|
||||
variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type)
|
||||
# add error message to handle_exceptions
|
||||
handle_exceptions.append(error_result.error)
|
||||
node_error_args = {
|
||||
handle_exceptions.append(error_result.error or "")
|
||||
node_error_args: dict[str, Any] = {
|
||||
"status": WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
"error": error_result.error,
|
||||
"inputs": error_result.inputs,
|
||||
|
||||
Reference in New Issue
Block a user