mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-23 01:36:55 +08:00
feat/enhance the multi-modal support (#8818)
This commit is contained in:
@@ -1,24 +0,0 @@
|
||||
from abc import ABC
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: Optional[str] = None
|
||||
|
||||
|
||||
class BaseIterationNodeData(BaseNodeData):
|
||||
start_node_id: Optional[str] = None
|
||||
|
||||
|
||||
class BaseIterationState(BaseModel):
|
||||
iteration_node_id: str
|
||||
index: int
|
||||
inputs: dict
|
||||
|
||||
class MetaData(BaseModel):
|
||||
pass
|
||||
|
||||
metadata: MetaData
|
||||
@@ -1,52 +1,14 @@
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from models import WorkflowNodeExecutionStatus
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class NodeType(Enum):
|
||||
"""
|
||||
Node Types.
|
||||
"""
|
||||
|
||||
START = "start"
|
||||
END = "end"
|
||||
ANSWER = "answer"
|
||||
LLM = "llm"
|
||||
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
|
||||
IF_ELSE = "if-else"
|
||||
CODE = "code"
|
||||
TEMPLATE_TRANSFORM = "template-transform"
|
||||
QUESTION_CLASSIFIER = "question-classifier"
|
||||
HTTP_REQUEST = "http-request"
|
||||
TOOL = "tool"
|
||||
VARIABLE_AGGREGATOR = "variable-aggregator"
|
||||
# TODO: merge this into VARIABLE_AGGREGATOR
|
||||
VARIABLE_ASSIGNER = "variable-assigner"
|
||||
LOOP = "loop"
|
||||
ITERATION = "iteration"
|
||||
ITERATION_START = "iteration-start" # fake start node for iteration
|
||||
PARAMETER_EXTRACTOR = "parameter-extractor"
|
||||
CONVERSATION_VARIABLE_ASSIGNER = "assigner"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "NodeType":
|
||||
"""
|
||||
Get value of given node type.
|
||||
|
||||
:param value: node type value
|
||||
:return: node type
|
||||
"""
|
||||
for node_type in cls:
|
||||
if node_type.value == value:
|
||||
return node_type
|
||||
raise ValueError(f"invalid node type value {value}")
|
||||
|
||||
|
||||
class NodeRunMetadataKey(Enum):
|
||||
class NodeRunMetadataKey(str, Enum):
|
||||
"""
|
||||
Node Run Metadata Key.
|
||||
"""
|
||||
@@ -70,7 +32,7 @@ class NodeRunResult(BaseModel):
|
||||
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
|
||||
|
||||
inputs: Optional[dict[str, Any]] = None # node inputs
|
||||
inputs: Optional[Mapping[str, Any]] = None # node inputs
|
||||
process_data: Optional[dict[str, Any]] = None # process data
|
||||
outputs: Optional[dict[str, Any]] = None # node outputs
|
||||
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
|
||||
@@ -79,24 +41,3 @@ class NodeRunResult(BaseModel):
|
||||
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
|
||||
|
||||
error: Optional[str] = None # error message if status is failed
|
||||
|
||||
|
||||
class UserFrom(Enum):
|
||||
"""
|
||||
User from
|
||||
"""
|
||||
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end-user"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "UserFrom":
|
||||
"""
|
||||
Value of
|
||||
:param value: value
|
||||
:return:
|
||||
"""
|
||||
for item in cls:
|
||||
if item.value == value:
|
||||
return item
|
||||
raise ValueError(f"Invalid value: {value}")
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -7,4 +9,4 @@ class VariableSelector(BaseModel):
|
||||
"""
|
||||
|
||||
variable: str
|
||||
value_selector: list[str]
|
||||
value_selector: Sequence[str]
|
||||
|
||||
@@ -1,20 +1,23 @@
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.app.segments import Segment, Variable, factory
|
||||
from core.file.file_obj import FileVar
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.file import File, FileAttribute, file_manager
|
||||
from core.variables import Segment, SegmentGroup, Variable
|
||||
from core.variables.segments import FileSegment
|
||||
from factories import variable_factory
|
||||
|
||||
VariableValue = Union[str, int, float, dict, list, FileVar]
|
||||
from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from ..enums import SystemVariableKey
|
||||
|
||||
VariableValue = Union[str, int, float, dict, list, File]
|
||||
|
||||
|
||||
SYSTEM_VARIABLE_NODE_ID = "sys"
|
||||
ENVIRONMENT_VARIABLE_NODE_ID = "env"
|
||||
CONVERSATION_VARIABLE_NODE_ID = "conversation"
|
||||
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
|
||||
|
||||
|
||||
class VariablePool(BaseModel):
|
||||
@@ -23,46 +26,63 @@ class VariablePool(BaseModel):
|
||||
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
||||
# elements of the selector except the first one.
|
||||
variable_dictionary: dict[str, dict[int, Segment]] = Field(
|
||||
description="Variables mapping", default=defaultdict(dict)
|
||||
description="Variables mapping",
|
||||
default=defaultdict(dict),
|
||||
)
|
||||
|
||||
# TODO: This user inputs is not used for pool.
|
||||
user_inputs: Mapping[str, Any] = Field(
|
||||
description="User inputs",
|
||||
)
|
||||
|
||||
system_variables: Mapping[SystemVariableKey, Any] = Field(
|
||||
description="System variables",
|
||||
)
|
||||
environment_variables: Sequence[Variable] = Field(
|
||||
description="Environment variables.",
|
||||
default_factory=list,
|
||||
)
|
||||
conversation_variables: Sequence[Variable] = Field(
|
||||
description="Conversation variables.",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
environment_variables: Sequence[Variable] = Field(description="Environment variables.", default_factory=list)
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
system_variables: Mapping[SystemVariableKey, Any] | None = None,
|
||||
user_inputs: Mapping[str, Any] | None = None,
|
||||
environment_variables: Sequence[Variable] | None = None,
|
||||
conversation_variables: Sequence[Variable] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
environment_variables = environment_variables or []
|
||||
conversation_variables = conversation_variables or []
|
||||
user_inputs = user_inputs or {}
|
||||
system_variables = system_variables or {}
|
||||
|
||||
conversation_variables: Sequence[Variable] | None = None
|
||||
super().__init__(
|
||||
system_variables=system_variables,
|
||||
user_inputs=user_inputs,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def val_model_after(self):
|
||||
"""
|
||||
Append system variables
|
||||
:return:
|
||||
"""
|
||||
# Add system variables to the variable pool
|
||||
for key, value in self.system_variables.items():
|
||||
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
|
||||
|
||||
# Add environment variables to the variable pool
|
||||
for var in self.environment_variables or []:
|
||||
for var in self.environment_variables:
|
||||
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
|
||||
|
||||
# Add conversation variables to the variable pool
|
||||
for var in self.conversation_variables or []:
|
||||
for var in self.conversation_variables:
|
||||
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
|
||||
|
||||
return self
|
||||
|
||||
def add(self, selector: Sequence[str], value: Any, /) -> None:
|
||||
"""
|
||||
Adds a variable to the variable pool.
|
||||
|
||||
NOTE: You should not add a non-Segment value to the variable pool
|
||||
even if it is allowed now.
|
||||
|
||||
Args:
|
||||
selector (Sequence[str]): The selector for the variable.
|
||||
value (VariableValue): The value of the variable.
|
||||
@@ -82,7 +102,7 @@ class VariablePool(BaseModel):
|
||||
if isinstance(value, Segment):
|
||||
v = value
|
||||
else:
|
||||
v = factory.build_segment(value)
|
||||
v = variable_factory.build_segment(value)
|
||||
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
self.variable_dictionary[selector[0]][hash_key] = v
|
||||
@@ -101,10 +121,19 @@ class VariablePool(BaseModel):
|
||||
ValueError: If the selector is invalid.
|
||||
"""
|
||||
if len(selector) < 2:
|
||||
raise ValueError("Invalid selector")
|
||||
return None
|
||||
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
value = self.variable_dictionary[selector[0]].get(hash_key)
|
||||
|
||||
if value is None:
|
||||
selector, attr = selector[:-1], selector[-1]
|
||||
value = self.get(selector)
|
||||
if isinstance(value, FileSegment):
|
||||
attr = FileAttribute(attr)
|
||||
attr_value = file_manager.get_attr(file=value.value, attr=attr)
|
||||
return variable_factory.build_segment(attr_value)
|
||||
|
||||
return value
|
||||
|
||||
@deprecated("This method is deprecated, use `get` instead.")
|
||||
@@ -145,14 +174,18 @@ class VariablePool(BaseModel):
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
self.variable_dictionary[selector[0]].pop(hash_key, None)
|
||||
|
||||
def remove_node(self, node_id: str, /):
|
||||
"""
|
||||
Remove all variables associated with a given node id.
|
||||
def convert_template(self, template: str, /):
|
||||
parts = VARIABLE_PATTERN.split(template)
|
||||
segments = []
|
||||
for part in filter(lambda x: x, parts):
|
||||
if "." in part and (variable := self.get(part.split("."))):
|
||||
segments.append(variable)
|
||||
else:
|
||||
segments.append(variable_factory.build_segment(part))
|
||||
return SegmentGroup(value=segments)
|
||||
|
||||
Args:
|
||||
node_id (str): The node id to remove.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.variable_dictionary.pop(node_id, None)
|
||||
def get_file(self, selector: Sequence[str], /) -> FileSegment | None:
|
||||
segment = self.get(selector)
|
||||
if isinstance(segment, FileSegment):
|
||||
return segment
|
||||
return None
|
||||
|
||||
@@ -3,12 +3,13 @@ from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationState
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode, UserFrom
|
||||
from core.workflow.nodes.base import BaseIterationState, BaseNode
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
|
||||
from .node_entities import NodeRunResult
|
||||
from .variable_pool import VariablePool
|
||||
|
||||
|
||||
class WorkflowNodeAndResult:
|
||||
node: BaseNode
|
||||
|
||||
Reference in New Issue
Block a user