Feat: conversation variable & variable assigner node (#7222)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
KVOJJJin
2024-08-13 14:44:10 +08:00
committed by GitHub
parent 8b55bd5828
commit 935e72d449
128 changed files with 3354 additions and 683 deletions

View File

@@ -8,6 +8,7 @@ from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from models import WorkflowNodeExecutionStatus
class UserFrom(Enum):
@@ -91,14 +92,19 @@ class BaseNode(ABC):
:param variable_pool: variable pool
:return:
"""
result = self._run(
variable_pool=variable_pool
)
try:
result = self._run(
variable_pool=variable_pool
)
self.node_run_result = result
return result
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
self.node_run_result = result
return result
def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None:
def publish_text_chunk(self, text: str, value_selector: list[str] | None = None) -> None:
"""
Publish text chunk
:param text: chunk text

View File

@@ -0,0 +1,109 @@
from collections.abc import Sequence
from enum import Enum
from typing import Optional, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.segments import SegmentType, Variable, factory
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from extensions.ext_database import db
from models import ConversationVariable, WorkflowNodeExecutionStatus
class VariableAssignerNodeError(Exception):
pass
class WriteMode(str, Enum):
OVER_WRITE = 'over-write'
APPEND = 'append'
CLEAR = 'clear'
class VariableAssignerData(BaseNodeData):
title: str = 'Variable Assigner'
desc: Optional[str] = 'Assign a value to a variable'
assigned_variable_selector: Sequence[str]
write_mode: WriteMode
input_variable_selector: Sequence[str]
class VariableAssignerNode(BaseNode):
_node_data_cls: type[BaseNodeData] = VariableAssignerData
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
data = cast(VariableAssignerData, self.node_data)
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = variable_pool.get(data.assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableAssignerNodeError('assigned variable not found')
match data.write_mode:
case WriteMode.OVER_WRITE:
income_value = variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_variable = original_variable.model_copy(update={'value': income_value.value})
case WriteMode.APPEND:
income_value = variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_value = original_variable.value + [income_value.value]
updated_variable = original_variable.model_copy(update={'value': updated_value})
case WriteMode.CLEAR:
income_value = get_zero_value(original_variable.value_type)
updated_variable = original_variable.model_copy(update={'value': income_value.to_object()})
case _:
raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
# Over write the variable.
variable_pool.add(data.assigned_variable_selector, updated_variable)
# Update conversation variable.
# TODO: Find a better way to use the database.
conversation_id = variable_pool.get(['sys', 'conversation_id'])
if not conversation_id:
raise VariableAssignerNodeError('conversation_id not found')
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={
'value': income_value.to_object(),
},
)
def update_conversation_variable(conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableAssignerNodeError('conversation variable not found in the database')
row.data = variable.model_dump_json()
session.commit()
def get_zero_value(t: SegmentType):
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
return factory.build_segment([])
case SegmentType.OBJECT:
return factory.build_segment({})
case SegmentType.STRING:
return factory.build_segment('')
case SegmentType.NUMBER:
return factory.build_segment(0)
case _:
raise VariableAssignerNodeError(f'unsupported variable type: {t}')