mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-11 03:46:52 +08:00
feat/enhance the multi-modal support (#8818)
This commit is contained in:
4
api/core/workflow/nodes/base/__init__.py
Normal file
4
api/core/workflow/nodes/base/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .entities import BaseIterationNodeData, BaseIterationState, BaseNodeData
|
||||
from .node import BaseNode
|
||||
|
||||
__all__ = ["BaseNode", "BaseNodeData", "BaseIterationNodeData", "BaseIterationState"]
|
||||
24
api/core/workflow/nodes/base/entities.py
Normal file
24
api/core/workflow/nodes/base/entities.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from abc import ABC
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: Optional[str] = None
|
||||
|
||||
|
||||
class BaseIterationNodeData(BaseNodeData):
|
||||
start_node_id: Optional[str] = None
|
||||
|
||||
|
||||
class BaseIterationState(BaseModel):
|
||||
iteration_node_id: str
|
||||
index: int
|
||||
inputs: dict
|
||||
|
||||
class MetaData(BaseModel):
|
||||
pass
|
||||
|
||||
metadata: MetaData
|
||||
137
api/core/workflow/nodes/base/node.py
Normal file
137
api/core/workflow/nodes/base/node.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import BaseNodeData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
|
||||
|
||||
|
||||
class BaseNode(Generic[GenericNodeData]):
|
||||
_node_data_cls: type[BaseNodeData]
|
||||
_node_type: NodeType
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph: "Graph",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
previous_node_id: Optional[str] = None,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
) -> None:
|
||||
self.id = id
|
||||
self.tenant_id = graph_init_params.tenant_id
|
||||
self.app_id = graph_init_params.app_id
|
||||
self.workflow_type = graph_init_params.workflow_type
|
||||
self.workflow_id = graph_init_params.workflow_id
|
||||
self.graph_config = graph_init_params.graph_config
|
||||
self.user_id = graph_init_params.user_id
|
||||
self.user_from = graph_init_params.user_from
|
||||
self.invoke_from = graph_init_params.invoke_from
|
||||
self.workflow_call_depth = graph_init_params.call_depth
|
||||
self.graph = graph
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.previous_node_id = previous_node_id
|
||||
self.thread_pool_id = thread_pool_id
|
||||
|
||||
node_id = config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required.")
|
||||
|
||||
self.node_id = node_id
|
||||
self.node_data: GenericNodeData = cast(GenericNodeData, self._node_data_cls(**config.get("data", {})))
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
|
||||
"""
|
||||
Run node
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
|
||||
try:
|
||||
result = self._run()
|
||||
except Exception as e:
|
||||
logger.error(f"Node {self.node_id} failed to run: {e}")
|
||||
result = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
if isinstance(result, NodeRunResult):
|
||||
yield RunCompletedEvent(run_result=result)
|
||||
else:
|
||||
yield from result
|
||||
|
||||
@classmethod
|
||||
def extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
config: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param config: node config
|
||||
:return:
|
||||
"""
|
||||
node_id = config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
|
||||
|
||||
node_data = cls._node_data_cls(**config.get("data", {}))
|
||||
return cls._extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: GenericNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters.
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def node_type(self) -> NodeType:
|
||||
"""
|
||||
Get node type
|
||||
:return:
|
||||
"""
|
||||
return self._node_type
|
||||
Reference in New Issue
Block a user