feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -5,7 +5,7 @@ from core.workflow.utils.condition.processor import ConditionProcessor
class ConditionRunConditionHandlerHandler(RunConditionHandler):
def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState):
"""
Check if the condition can be executed

View File

@@ -1,4 +1,5 @@
import uuid
from collections import defaultdict
from collections.abc import Mapping
from typing import Any, Optional, cast
@@ -310,26 +311,17 @@ class Graph(BaseModel):
parallel = None
if len(target_node_edges) > 1:
# fetch all node ids in current parallels
parallel_branch_node_ids = {}
condition_edge_mappings = {}
parallel_branch_node_ids = defaultdict(list)
condition_edge_mappings = defaultdict(list)
for graph_edge in target_node_edges:
if graph_edge.run_condition is None:
if "default" not in parallel_branch_node_ids:
parallel_branch_node_ids["default"] = []
parallel_branch_node_ids["default"].append(graph_edge.target_node_id)
else:
condition_hash = graph_edge.run_condition.hash
if condition_hash not in condition_edge_mappings:
condition_edge_mappings[condition_hash] = []
condition_edge_mappings[condition_hash].append(graph_edge)
for condition_hash, graph_edges in condition_edge_mappings.items():
if len(graph_edges) > 1:
if condition_hash not in parallel_branch_node_ids:
parallel_branch_node_ids[condition_hash] = []
for graph_edge in graph_edges:
parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id)
@@ -418,7 +410,7 @@ class Graph(BaseModel):
if condition_edge_mappings:
for condition_hash, graph_edges in condition_edge_mappings.items():
for graph_edge in graph_edges:
current_parallel: GraphParallel | None = cls._get_current_parallel(
current_parallel = cls._get_current_parallel(
parallel_mapping=parallel_mapping,
graph_edge=graph_edge,
parallel=condition_parallels.get(condition_hash),

View File

@@ -40,6 +40,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntime
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes import NodeType
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
@@ -66,7 +67,7 @@ class GraphEngineThreadPool(ThreadPoolExecutor):
self.max_submit_count = max_submit_count
self.submit_count = 0
def submit(self, fn, *args, **kwargs):
def submit(self, fn, /, *args, **kwargs):
self.submit_count += 1
self.check_is_full()
@@ -140,7 +141,8 @@ class GraphEngine:
def run(self) -> Generator[GraphEngineEvent, None, None]:
# trigger graph run start event
yield GraphRunStartedEvent()
handle_exceptions = []
handle_exceptions: list[str] = []
stream_processor: StreamProcessor
try:
if self.init_params.workflow_type == WorkflowType.CHAT:
@@ -168,7 +170,7 @@ class GraphEngine:
elif isinstance(item, NodeRunSucceededEvent):
if item.node_type == NodeType.END:
self.graph_runtime_state.outputs = (
item.route_node_state.node_run_result.outputs
dict(item.route_node_state.node_run_result.outputs)
if item.route_node_state.node_run_result
and item.route_node_state.node_run_result.outputs
else {}
@@ -350,7 +352,7 @@ class GraphEngine:
if any(edge.run_condition for edge in edge_mappings):
# if nodes has run conditions, get node id which branch to take based on the run condition results
condition_edge_mappings = {}
condition_edge_mappings: dict[str, list[GraphEdge]] = {}
for edge in edge_mappings:
if edge.run_condition:
run_condition_hash = edge.run_condition.hash
@@ -364,6 +366,9 @@ class GraphEngine:
continue
edge = cast(GraphEdge, sub_edge_mappings[0])
if edge.run_condition is None:
logger.warning(f"Edge {edge.target_node_id} run condition is None")
continue
result = ConditionManager.get_condition_handler(
init_params=self.init_params,
@@ -387,11 +392,11 @@ class GraphEngine:
handle_exceptions=handle_exceptions,
)
for item in parallel_generator:
if isinstance(item, str):
final_node_id = item
for parallel_result in parallel_generator:
if isinstance(parallel_result, str):
final_node_id = parallel_result
else:
yield item
yield parallel_result
break
@@ -413,11 +418,11 @@ class GraphEngine:
handle_exceptions=handle_exceptions,
)
for item in parallel_generator:
if isinstance(item, str):
final_node_id = item
for generated_item in parallel_generator:
if isinstance(generated_item, str):
final_node_id = generated_item
else:
yield item
yield generated_item
if not final_node_id:
break
@@ -653,7 +658,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
error=run_result.error,
error=run_result.error or "Unknown error",
retry_index=retries,
start_at=retry_start_at,
)
@@ -732,20 +737,20 @@ class GraphEngine:
variable_value=variable_value,
)
# add parallel info to run result metadata
if parallel_id and parallel_start_node_id:
if not run_result.metadata:
run_result.metadata = {}
# When setting metadata, convert to dict first
if not run_result.metadata:
run_result.metadata = {}
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = (
parallel_start_node_id
)
if parallel_id and parallel_start_node_id:
metadata_dict = dict(run_result.metadata)
metadata_dict[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
metadata_dict[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
if parent_parallel_id and parent_parallel_start_node_id:
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
parent_parallel_start_node_id
)
run_result.metadata = metadata_dict
yield NodeRunSucceededEvent(
id=node_instance.id,
@@ -869,8 +874,8 @@ class GraphEngine:
variable_pool.add([node_instance.node_id, "error_message"], error_result.error)
variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type)
# add error message to handle_exceptions
handle_exceptions.append(error_result.error)
node_error_args = {
handle_exceptions.append(error_result.error or "")
node_error_args: dict[str, Any] = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": error_result.error,
"inputs": error_result.inputs,