feat/enhance the multi-modal support (#8818)

This commit is contained in:
-LAN-
2024-10-21 10:43:49 +08:00
committed by GitHub
parent 7a1d6fe509
commit e61752bd3a
267 changed files with 6263 additions and 3523 deletions

View File

@@ -1,24 +0,0 @@
from abc import ABC
from typing import Optional
from pydantic import BaseModel
# Common base for node data models: every node carries a title and may carry
# a free-form description.
class BaseNodeData(ABC, BaseModel):
    # Title of the node (required).
    title: str
    # Optional description; defaults to None when not supplied.
    desc: Optional[str] = None
# Node data for iteration nodes: extends BaseNodeData with the id of the node
# the iteration starts from.
class BaseIterationNodeData(BaseNodeData):
    # Optional id of the start node; defaults to None when not supplied.
    start_node_id: Optional[str] = None
# State model for an iteration: which iteration node it belongs to, the
# current index, the inputs, and a metadata model.
class BaseIterationState(BaseModel):
    # Empty metadata model, declared before the field that references it.
    class MetaData(BaseModel):
        pass

    # Id of the iteration node this state belongs to.
    iteration_node_id: str
    # Current index of the iteration.
    index: int
    # Inputs of the iteration.
    inputs: dict
    # Metadata attached to this state.
    metadata: MetaData

View File

@@ -1,52 +1,14 @@
from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMUsage
from models import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionStatus
class NodeType(Enum):
    """
    Node Types.
    """

    START = "start"
    END = "end"
    ANSWER = "answer"
    LLM = "llm"
    KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
    IF_ELSE = "if-else"
    CODE = "code"
    TEMPLATE_TRANSFORM = "template-transform"
    QUESTION_CLASSIFIER = "question-classifier"
    HTTP_REQUEST = "http-request"
    TOOL = "tool"
    VARIABLE_AGGREGATOR = "variable-aggregator"
    # TODO: merge this into VARIABLE_AGGREGATOR
    VARIABLE_ASSIGNER = "variable-assigner"
    LOOP = "loop"
    ITERATION = "iteration"
    ITERATION_START = "iteration-start"  # fake start node for iteration
    PARAMETER_EXTRACTOR = "parameter-extractor"
    CONVERSATION_VARIABLE_ASSIGNER = "assigner"

    @classmethod
    def value_of(cls, value: str) -> "NodeType":
        """
        Get the node type whose value equals the given string.

        :param value: node type value
        :return: node type
        :raises ValueError: if no node type has the given value
        """
        try:
            # Enum's own value lookup replaces the manual scan over members.
            return cls(value)
        except ValueError:
            # Keep the original error message; hide the Enum-internal one.
            raise ValueError(f"invalid node type value {value}") from None
class NodeRunMetadataKey(Enum):
class NodeRunMetadataKey(str, Enum):
"""
Node Run Metadata Key.
"""
@@ -70,7 +32,7 @@ class NodeRunResult(BaseModel):
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Optional[dict[str, Any]] = None # node inputs
inputs: Optional[Mapping[str, Any]] = None # node inputs
process_data: Optional[dict[str, Any]] = None # process data
outputs: Optional[dict[str, Any]] = None # node outputs
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
@@ -79,24 +41,3 @@ class NodeRunResult(BaseModel):
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
error: Optional[str] = None # error message if status is failed
class UserFrom(Enum):
    """
    User from
    """

    ACCOUNT = "account"
    END_USER = "end-user"

    @classmethod
    def value_of(cls, value: str) -> "UserFrom":
        """
        Get the member whose value equals the given string.

        :param value: value
        :return: the matching member
        :raises ValueError: if no member has the given value
        """
        try:
            # Enum's own value lookup replaces the manual scan over members.
            return cls(value)
        except ValueError:
            # Keep the original error message; hide the Enum-internal one.
            raise ValueError(f"Invalid value: {value}") from None

View File

@@ -1,3 +1,5 @@
from collections.abc import Sequence
from pydantic import BaseModel
@@ -7,4 +9,4 @@ class VariableSelector(BaseModel):
"""
variable: str
value_selector: list[str]
value_selector: Sequence[str]

View File

@@ -1,20 +1,23 @@
import re
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Union
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field
from typing_extensions import deprecated
from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar
from core.workflow.enums import SystemVariableKey
from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables.segments import FileSegment
from factories import variable_factory
VariableValue = Union[str, int, float, dict, list, FileVar]
from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from ..enums import SystemVariableKey
VariableValue = Union[str, int, float, dict, list, File]
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
class VariablePool(BaseModel):
@@ -23,46 +26,63 @@ class VariablePool(BaseModel):
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
variable_dictionary: dict[str, dict[int, Segment]] = Field(
description="Variables mapping", default=defaultdict(dict)
description="Variables mapping",
default=defaultdict(dict),
)
# TODO: This user inputs is not used for pool.
user_inputs: Mapping[str, Any] = Field(
description="User inputs",
)
system_variables: Mapping[SystemVariableKey, Any] = Field(
description="System variables",
)
environment_variables: Sequence[Variable] = Field(
description="Environment variables.",
default_factory=list,
)
conversation_variables: Sequence[Variable] = Field(
description="Conversation variables.",
default_factory=list,
)
environment_variables: Sequence[Variable] = Field(description="Environment variables.", default_factory=list)
def __init__(
self,
*,
system_variables: Mapping[SystemVariableKey, Any] | None = None,
user_inputs: Mapping[str, Any] | None = None,
environment_variables: Sequence[Variable] | None = None,
conversation_variables: Sequence[Variable] | None = None,
**kwargs,
):
environment_variables = environment_variables or []
conversation_variables = conversation_variables or []
user_inputs = user_inputs or {}
system_variables = system_variables or {}
conversation_variables: Sequence[Variable] | None = None
super().__init__(
system_variables=system_variables,
user_inputs=user_inputs,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
**kwargs,
)
@model_validator(mode="after")
def val_model_after(self):
"""
Append system variables
:return:
"""
# Add system variables to the variable pool
for key, value in self.system_variables.items():
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
# Add environment variables to the variable pool
for var in self.environment_variables or []:
for var in self.environment_variables:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
# Add conversation variables to the variable pool
for var in self.conversation_variables or []:
for var in self.conversation_variables:
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
return self
def add(self, selector: Sequence[str], value: Any, /) -> None:
"""
Adds a variable to the variable pool.
NOTE: You should not add a non-Segment value to the variable pool
even if it is allowed now.
Args:
selector (Sequence[str]): The selector for the variable.
value (VariableValue): The value of the variable.
@@ -82,7 +102,7 @@ class VariablePool(BaseModel):
if isinstance(value, Segment):
v = value
else:
v = factory.build_segment(value)
v = variable_factory.build_segment(value)
hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[selector[0]][hash_key] = v
@@ -101,10 +121,19 @@ class VariablePool(BaseModel):
ValueError: If the selector is invalid.
"""
if len(selector) < 2:
raise ValueError("Invalid selector")
return None
hash_key = hash(tuple(selector[1:]))
value = self.variable_dictionary[selector[0]].get(hash_key)
if value is None:
selector, attr = selector[:-1], selector[-1]
value = self.get(selector)
if isinstance(value, FileSegment):
attr = FileAttribute(attr)
attr_value = file_manager.get_attr(file=value.value, attr=attr)
return variable_factory.build_segment(attr_value)
return value
@deprecated("This method is deprecated, use `get` instead.")
@@ -145,14 +174,18 @@ class VariablePool(BaseModel):
hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[selector[0]].pop(hash_key, None)
def remove_node(self, node_id: str, /):
"""
Remove all variables associated with a given node id.
def convert_template(self, template: str, /):
    """
    Convert a template string into a SegmentGroup.

    The template is split with VARIABLE_PATTERN. Because the pattern has a
    single capturing group, the split result alternates between literal
    text (even positions) and captured selectors (odd positions). Only
    captured selectors are looked up in the pool, so literal text that
    merely contains a "." is never treated as a variable reference.

    Args:
        template (str): The template text, possibly containing variable
            placeholders matched by VARIABLE_PATTERN.

    Returns:
        SegmentGroup: The resolved variables and the remaining pieces of
        text, in template order. A selector for which ``self.get`` yields
        nothing truthy is passed to ``variable_factory.build_segment`` as
        plain text, as before.
    """
    segments = []
    for index, part in enumerate(VARIABLE_PATTERN.split(template)):
        if not part:
            # Empty pieces appear when placeholders are adjacent or sit at
            # either end of the template.
            continue
        is_selector = index % 2 == 1
        if is_selector and (variable := self.get(part.split("."))):
            segments.append(variable)
        else:
            segments.append(variable_factory.build_segment(part))
    return SegmentGroup(value=segments)
Args:
node_id (str): The node id to remove.
Returns:
None
"""
self.variable_dictionary.pop(node_id, None)
def get_file(self, selector: Sequence[str], /) -> FileSegment | None:
    """
    Look up ``selector`` via ``self.get`` and return the result only when it
    is a FileSegment; otherwise return None.
    """
    candidate = self.get(selector)
    return candidate if isinstance(candidate, FileSegment) else None

View File

@@ -3,12 +3,13 @@ from typing import Optional
from pydantic import BaseModel
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.base_node_data_entities import BaseIterationState
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode, UserFrom
from core.workflow.nodes.base import BaseIterationState, BaseNode
from models.enums import UserFrom
from models.workflow import Workflow, WorkflowType
from .node_entities import NodeRunResult
from .variable_pool import VariablePool
class WorkflowNodeAndResult:
node: BaseNode