Feat: continue on error (#11458)

Co-authored-by: Novice Lee <novicelee@NovicedeMacBook-Pro.local>
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
This commit is contained in:
Novice
2024-12-11 14:22:42 +08:00
committed by GitHub
parent bec5451f12
commit 79a710ce98
31 changed files with 1211 additions and 80 deletions

View File

@@ -6,7 +6,7 @@ from core.workflow.nodes.answer.entities import (
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
@@ -148,13 +148,18 @@ class AnswerStreamGeneratorRouter:
for edge in reverse_edges:
source_node_id = edge.source_node_id
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
if source_node_type in {
NodeType.ANSWER,
NodeType.IF_ELSE,
NodeType.QUESTION_CLASSIFIER,
NodeType.ITERATION,
NodeType.VARIABLE_ASSIGNER,
}:
source_node_data = node_id_config_mapping[source_node_id].get("data", {})
if (
source_node_type
in {
NodeType.ANSWER,
NodeType.IF_ELSE,
NodeType.QUESTION_CLASSIFIER,
NodeType.ITERATION,
NodeType.VARIABLE_ASSIGNER,
}
or source_node_data.get("error_strategy") == ErrorStrategy.FAIL_BRANCH
):
answer_dependencies[answer_node_id].append(source_node_id)
else:
cls._recursive_fetch_answer_dependencies(

View File

@@ -6,6 +6,7 @@ from core.file import FILE_MODEL_IDENTITY, File
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunExceptionEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
@@ -50,7 +51,7 @@ class AnswerStreamProcessor(StreamProcessor):
for _ in stream_out_answer_node_ids:
yield event
elif isinstance(event, NodeRunSucceededEvent):
elif isinstance(event, NodeRunSucceededEvent | NodeRunExceptionEvent):
yield event
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
# update self.route_position after all stream event finished

View File

@@ -1,14 +1,124 @@
import json
from abc import ABC
from typing import Optional
from enum import StrEnum
from typing import Any, Optional, Union
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
from core.workflow.nodes.base.exc import DefaultValueTypeError
from core.workflow.nodes.enums import ErrorStrategy
class DefaultValueType(StrEnum):
STRING = "string"
NUMBER = "number"
OBJECT = "object"
ARRAY_NUMBER = "array[number]"
ARRAY_STRING = "array[string]"
ARRAY_OBJECT = "array[object]"
ARRAY_FILES = "array[file]"
NumberType = Union[int, float]
class DefaultValue(BaseModel):
value: Any
type: DefaultValueType
key: str
@staticmethod
def _parse_json(value: str) -> Any:
"""Unified JSON parsing handler"""
try:
return json.loads(value)
except json.JSONDecodeError:
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
@staticmethod
def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
"""Unified array type validation"""
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
@staticmethod
def _convert_number(value: str) -> float:
"""Unified number conversion handler"""
try:
return float(value)
except ValueError:
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
@model_validator(mode="after")
def validate_value_type(self) -> "DefaultValue":
if self.type is None:
raise DefaultValueTypeError("type field is required")
# Type validation configuration
type_validators = {
DefaultValueType.STRING: {
"type": str,
"converter": lambda x: x,
},
DefaultValueType.NUMBER: {
"type": NumberType,
"converter": self._convert_number,
},
DefaultValueType.OBJECT: {
"type": dict,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_NUMBER: {
"type": list,
"element_type": NumberType,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_STRING: {
"type": list,
"element_type": str,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_OBJECT: {
"type": list,
"element_type": dict,
"converter": self._parse_json,
},
}
validator = type_validators.get(self.type)
if not validator:
if self.type == DefaultValueType.ARRAY_FILES:
# Handle files type
return self
raise DefaultValueTypeError(f"Unsupported type: {self.type}")
# Handle string input cases
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
self.value = validator["converter"](self.value)
# Validate base type
if not isinstance(self.value, validator["type"]):
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
# Validate array element types
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
return self
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
error_strategy: Optional[ErrorStrategy] = None
default_value: Optional[list[DefaultValue]] = None
version: str = "1"
@property
def default_value_dict(self):
if self.default_value:
return {item.key: item.value for item in self.default_value}
return {}
class BaseIterationNodeData(BaseNodeData):
start_node_id: Optional[str] = None

View File

@@ -0,0 +1,10 @@
class BaseNodeError(Exception):
"""Base class for node errors."""
pass
class DefaultValueTypeError(BaseNodeError):
"""Raised when the default value type is invalid."""
pass

View File

@@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from models.workflow import WorkflowNodeExecutionStatus
@@ -72,10 +72,7 @@ class BaseNode(Generic[GenericNodeData]):
result = self._run()
except Exception as e:
logger.exception(f"Node {self.node_id} failed to run")
result = NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
result = NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=str(e), error_type="SystemError")
if isinstance(result, NodeRunResult):
yield RunCompletedEvent(run_result=result)
@@ -137,3 +134,12 @@ class BaseNode(Generic[GenericNodeData]):
:return:
"""
return self._node_type
@property
def should_continue_on_error(self) -> bool:
"""judge if should continue on error
Returns:
bool: if should continue on error
"""
return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE

View File

@@ -61,7 +61,9 @@ class CodeNode(BaseNode[CodeNodeData]):
# Transform result
result = self._transform_result(result, self.node_data.outputs)
except (CodeExecutionError, CodeNodeError) as e:
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)

View File

@@ -22,3 +22,16 @@ class NodeType(StrEnum):
VARIABLE_ASSIGNER = "assigner"
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"
class ErrorStrategy(StrEnum):
FAIL_BRANCH = "fail-branch"
DEFAULT_VALUE = "default-value"
class FailBranchSourceHandle(StrEnum):
FAILED = "fail-branch"
SUCCESS = "success-branch"
CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]

View File

@@ -21,6 +21,7 @@ from .entities import (
from .exc import (
AuthorizationConfigError,
FileFetchError,
HttpRequestNodeError,
InvalidHttpMethodError,
ResponseSizeError,
)
@@ -208,8 +209,10 @@ class Executor:
"follow_redirects": True,
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
response = getattr(ssrf_proxy, self.method)(**request_args)
try:
response = getattr(ssrf_proxy, self.method)(**request_args)
except ssrf_proxy.MaxRetriesExceededError as e:
raise HttpRequestNodeError(str(e))
return response
def invoke(self) -> Response:

View File

@@ -65,6 +65,21 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
response = http_executor.invoke()
files = self.extract_files(url=http_executor.url, response=response)
if not response.response.is_success and self.should_continue_on_error:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
outputs={
"status_code": response.status_code,
"body": response.text if not files else "",
"headers": response.headers,
"files": files,
},
process_data={
"request": http_executor.to_log(),
},
error=f"Request failed with status code {response.status_code}",
error_type="HTTPResponseCodeError",
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
@@ -83,6 +98,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
process_data=process_data,
error_type=type(e).__name__,
)
@staticmethod

View File

@@ -193,6 +193,7 @@ class LLMNode(BaseNode[LLMNodeData]):
error=str(e),
inputs=node_inputs,
process_data=process_data,
error_type=type(e).__name__,
)
)
return

View File

@@ -139,7 +139,7 @@ class QuestionClassifierNode(LLMNode):
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
}
outputs = {"class_name": category_name}
outputs = {"class_name": category_name, "class_id": category_id}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,

View File

@@ -56,6 +56,7 @@ class ToolNode(BaseNode[ToolNodeData]):
NodeRunMetadataKey.TOOL_INFO: tool_info,
},
error=f"Failed to get tool runtime: {str(e)}",
error_type=type(e).__name__,
)
# get parameters
@@ -89,6 +90,7 @@ class ToolNode(BaseNode[ToolNodeData]):
NodeRunMetadataKey.TOOL_INFO: tool_info,
},
error=f"Failed to invoke tool: {str(e)}",
error_type=type(e).__name__,
)
# convert tool messages