Feat: conversation variable & variable assigner node (#7222)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
KVOJJJin
2024-08-13 14:44:10 +08:00
committed by GitHub
parent 8b55bd5828
commit 935e72d449
128 changed files with 3354 additions and 683 deletions

View File

@@ -3,8 +3,9 @@ from typing import Any, Optional
from pydantic import BaseModel
from core.file.file_obj import FileExtraConfig
from core.model_runtime.entities.message_entities import PromptMessageRole
from models.model import AppMode
from models import AppMode
class ModelConfigEntity(BaseModel):
@@ -200,11 +201,6 @@ class TracingConfigEntity(BaseModel):
tracing_provider: str
class FileExtraConfig(BaseModel):
"""
File Upload Entity.
"""
image_config: Optional[dict[str, Any]] = None
class AppAdditionalFeatures(BaseModel):

View File

@@ -1,7 +1,7 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.app.app_config.entities import FileExtraConfig
from core.file.file_obj import FileExtraConfig
class FileUploadConfigManager:

View File

@@ -113,7 +113,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=invoke_from,
@@ -180,7 +179,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
@@ -189,12 +187,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
stream=stream
)
def _generate(self, app_model: App,
def _generate(self, *,
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Conversation = None,
conversation: Conversation | None = None,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
is_first_conversation = False

View File

@@ -4,6 +4,9 @@ import time
from collections.abc import Mapping
from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -17,11 +20,12 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueSto
from core.moderation.base import ModerationException
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow
from models.workflow import ConversationVariable, Workflow
logger = logging.getLogger(__name__)
@@ -31,10 +35,13 @@ class AdvancedChatAppRunner(AppRunner):
AdvancedChat Application Runner
"""
def run(self, application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message) -> None:
def run(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
) -> None:
"""
Run application
:param application_generate_entity: application generate entity
@@ -48,11 +55,11 @@ class AdvancedChatAppRunner(AppRunner):
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")
raise ValueError('App not found')
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
raise ValueError('Workflow not initialized')
inputs = application_generate_entity.inputs
query = application_generate_entity.query
@@ -68,35 +75,66 @@ class AdvancedChatAppRunner(AppRunner):
# moderation
if self.handle_input_moderation(
queue_manager=queue_manager,
app_record=app_record,
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id
queue_manager=queue_manager,
app_record=app_record,
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id,
):
return
# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
message=message,
query=query,
queue_manager=queue_manager,
app_generate_entity=application_generate_entity
app_record=app_record,
message=message,
query=query,
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
):
return
db.session.close()
workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
workflow_callbacks: list[WorkflowCallback] = [
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
]
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
conversation_variables = [
ConversationVariable.from_variable(
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
session.commit()
# Convert database entities to variables
conversation_variables = [item.to_variable() for item in conversation_variables]
# Create a variable pool.
system_inputs = {
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
@@ -106,43 +144,30 @@ class AdvancedChatAppRunner(AppRunner):
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
user_inputs=inputs,
system_inputs={
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id
},
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth
call_depth=application_generate_entity.call_depth,
variable_pool=variable_pool,
)
def single_iteration_run(self, app_id: str, workflow_id: str,
queue_manager: AppQueueManager,
inputs: dict, node_id: str, user_id: str) -> None:
def single_iteration_run(
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
) -> None:
"""
Single iteration run
"""
app_record: App = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError("App not found")
raise ValueError('App not found')
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
workflow_callbacks = [WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
raise ValueError('Workflow not initialized')
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow,
node_id=node_id,
user_id=user_id,
user_inputs=inputs,
callbacks=workflow_callbacks
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
@@ -150,22 +175,25 @@ class AdvancedChatAppRunner(AppRunner):
Get workflow
"""
# fetch workflow by workflow_id
workflow = db.session.query(Workflow).filter(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == workflow_id
).first()
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
def handle_input_moderation(
self, queue_manager: AppQueueManager,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str
self,
queue_manager: AppQueueManager,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str,
) -> bool:
"""
Handle input moderation
@@ -192,17 +220,20 @@ class AdvancedChatAppRunner(AppRunner):
queue_manager=queue_manager,
text=str(e),
stream=app_generate_entity.stream,
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION,
)
return True
return False
def handle_annotation_reply(self, app_record: App,
message: Message,
query: str,
queue_manager: AppQueueManager,
app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
def handle_annotation_reply(
self,
app_record: App,
message: Message,
query: str,
queue_manager: AppQueueManager,
app_generate_entity: AdvancedChatAppGenerateEntity,
) -> bool:
"""
Handle annotation reply
:param app_record: app record
@@ -217,29 +248,27 @@ class AdvancedChatAppRunner(AppRunner):
message=message,
query=query,
user_id=app_generate_entity.user_id,
invoke_from=app_generate_entity.invoke_from
invoke_from=app_generate_entity.invoke_from,
)
if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER
)
self._stream_output(
queue_manager=queue_manager,
text=annotation_reply.content,
stream=app_generate_entity.stream,
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY,
)
return True
return False
def _stream_output(self, queue_manager: AppQueueManager,
text: str,
stream: bool,
stopped_by: QueueStopEvent.StopBy) -> None:
def _stream_output(
self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy
) -> None:
"""
Direct output
:param queue_manager: application queue manager
@@ -250,21 +279,10 @@ class AdvancedChatAppRunner(AppRunner):
if stream:
index = 0
for token in text:
queue_manager.publish(
QueueTextChunkEvent(
text=token
), PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER)
index += 1
time.sleep(0.01)
else:
queue_manager.publish(
QueueTextChunkEvent(
text=text
), PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER)
queue_manager.publish(
QueueStopEvent(stopped_by=stopped_by),
PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER)

View File

@@ -12,6 +12,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
@@ -26,8 +27,7 @@ class WorkflowAppRunner:
Workflow Application Runner
"""
def run(self, application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager) -> None:
def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None:
"""
Run application
:param application_generate_entity: application generate entity
@@ -47,25 +47,36 @@ class WorkflowAppRunner:
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")
raise ValueError('App not found')
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
raise ValueError('Workflow not initialized')
inputs = application_generate_entity.inputs
files = application_generate_entity.files
db.session.close()
workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
workflow_callbacks: list[WorkflowCallback] = [
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
]
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# Create a variable pool.
system_inputs = {
SystemVariable.FILES: files,
SystemVariable.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
)
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
@@ -75,44 +86,33 @@ class WorkflowAppRunner:
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
user_inputs=inputs,
system_inputs={
SystemVariable.FILES: files,
SystemVariable.USER_ID: user_id
},
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth
call_depth=application_generate_entity.call_depth,
variable_pool=variable_pool,
)
def single_iteration_run(self, app_id: str, workflow_id: str,
queue_manager: AppQueueManager,
inputs: dict, node_id: str, user_id: str) -> None:
def single_iteration_run(
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
) -> None:
"""
Single iteration run
"""
app_record: App = db.session.query(App).filter(App.id == app_id).first()
app_record = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError("App not found")
raise ValueError('App not found')
if not app_record.workflow_id:
raise ValueError("Workflow not initialized")
raise ValueError('Workflow not initialized')
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
workflow_callbacks = [WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
raise ValueError('Workflow not initialized')
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow,
node_id=node_id,
user_id=user_id,
user_inputs=inputs,
callbacks=workflow_callbacks
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
@@ -120,11 +120,13 @@ class WorkflowAppRunner:
Get workflow
"""
# fetch workflow by workflow_id
workflow = db.session.query(Workflow).filter(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == workflow_id
).first()
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow

View File

@@ -1,6 +1,7 @@
from .segment_group import SegmentGroup
from .segments import (
ArrayAnySegment,
ArraySegment,
FileSegment,
FloatSegment,
IntegerSegment,
@@ -50,4 +51,5 @@ __all__ = [
'ArrayNumberVariable',
'ArrayObjectVariable',
'ArrayFileVariable',
'ArraySegment',
]

View File

@@ -0,0 +1,2 @@
class VariableError(Exception):
pass

View File

@@ -1,8 +1,10 @@
from collections.abc import Mapping
from typing import Any
from configs import dify_config
from core.file.file_obj import FileVar
from .exc import VariableError
from .segments import (
ArrayAnySegment,
FileSegment,
@@ -29,39 +31,43 @@ from .variables import (
)
def build_variable_from_mapping(m: Mapping[str, Any], /) -> Variable:
if (value_type := m.get('value_type')) is None:
raise ValueError('missing value type')
if not m.get('name'):
raise ValueError('missing name')
if (value := m.get('value')) is None:
raise ValueError('missing value')
def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
if (value_type := mapping.get('value_type')) is None:
raise VariableError('missing value type')
if not mapping.get('name'):
raise VariableError('missing name')
if (value := mapping.get('value')) is None:
raise VariableError('missing value')
match value_type:
case SegmentType.STRING:
return StringVariable.model_validate(m)
result = StringVariable.model_validate(mapping)
case SegmentType.SECRET:
return SecretVariable.model_validate(m)
result = SecretVariable.model_validate(mapping)
case SegmentType.NUMBER if isinstance(value, int):
return IntegerVariable.model_validate(m)
result = IntegerVariable.model_validate(mapping)
case SegmentType.NUMBER if isinstance(value, float):
return FloatVariable.model_validate(m)
result = FloatVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise ValueError(f'invalid number value {value}')
raise VariableError(f'invalid number value {value}')
case SegmentType.FILE:
return FileVariable.model_validate(m)
result = FileVariable.model_validate(mapping)
case SegmentType.OBJECT if isinstance(value, dict):
return ObjectVariable.model_validate(
{**m, 'value': {k: build_variable_from_mapping(v) for k, v in value.items()}}
)
result = ObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_STRING if isinstance(value, list):
return ArrayStringVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]})
result = ArrayStringVariable.model_validate(mapping)
case SegmentType.ARRAY_NUMBER if isinstance(value, list):
return ArrayNumberVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]})
result = ArrayNumberVariable.model_validate(mapping)
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
return ArrayObjectVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]})
result = ArrayObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_FILE if isinstance(value, list):
return ArrayFileVariable.model_validate({**m, 'value': [build_variable_from_mapping(v) for v in value]})
raise ValueError(f'not supported value type {value_type}')
mapping = dict(mapping)
mapping['value'] = [{'value': v} for v in value]
result = ArrayFileVariable.model_validate(mapping)
case _:
raise VariableError(f'not supported value type {value_type}')
if result.size > dify_config.MAX_VARIABLE_SIZE:
raise VariableError(f'variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}')
return result
def build_segment(value: Any, /) -> Segment:
@@ -74,12 +80,9 @@ def build_segment(value: Any, /) -> Segment:
if isinstance(value, float):
return FloatSegment(value=value)
if isinstance(value, dict):
# TODO: Limit the depth of the object
return ObjectSegment(value=value)
if isinstance(value, list):
# TODO: Limit the depth of the array
elements = [build_segment(v) for v in value]
return ArrayAnySegment(value=elements)
return ArrayAnySegment(value=value)
if isinstance(value, FileVar):
return FileSegment(value=value)
raise ValueError(f'not supported value {value}')

View File

@@ -1,4 +1,5 @@
import json
import sys
from collections.abc import Mapping, Sequence
from typing import Any
@@ -37,6 +38,10 @@ class Segment(BaseModel):
def markdown(self) -> str:
return str(self.value)
@property
def size(self) -> int:
return sys.getsizeof(self.value)
def to_object(self) -> Any:
return self.value
@@ -105,28 +110,25 @@ class ArraySegment(Segment):
def markdown(self) -> str:
return '\n'.join(['- ' + item.markdown for item in self.value])
def to_object(self):
return [v.to_object() for v in self.value]
class ArrayAnySegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_ANY
value: Sequence[Segment]
value: Sequence[Any]
class ArrayStringSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_STRING
value: Sequence[StringSegment]
value: Sequence[str]
class ArrayNumberSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_NUMBER
value: Sequence[FloatSegment | IntegerSegment]
value: Sequence[float | int]
class ArrayObjectSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_OBJECT
value: Sequence[ObjectSegment]
value: Sequence[Mapping[str, Any]]
class ArrayFileSegment(ArraySegment):

View File

@@ -1,14 +1,19 @@
import enum
from typing import Optional
from typing import Any, Optional
from pydantic import BaseModel
from core.app.app_config.entities import FileExtraConfig
from core.file.tool_file_parser import ToolFileParser
from core.file.upload_file_parser import UploadFileParser
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from extensions.ext_database import db
from models.model import UploadFile
class FileExtraConfig(BaseModel):
"""
File Upload Entity.
"""
image_config: Optional[dict[str, Any]] = None
class FileType(enum.Enum):
@@ -114,6 +119,7 @@ class FileVar(BaseModel):
)
def _get_data(self, force_url: bool = False) -> Optional[str]:
from models.model import UploadFile
if self.type == FileType.IMAGE:
if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.url

View File

@@ -5,8 +5,7 @@ from urllib.parse import parse_qs, urlparse
import requests
from core.app.app_config.entities import FileExtraConfig
from core.file.file_obj import FileBelongsTo, FileTransferMethod, FileType, FileVar
from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser, MessageFile, UploadFile

View File

@@ -2,7 +2,6 @@ import base64
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
def obfuscated_token(token: str):
@@ -14,6 +13,7 @@ def obfuscated_token(token: str):
def encrypt_token(tenant_id: str, token: str):
from models.account import Tenant
if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
raise ValueError(f'Tenant with id {tenant_id} not found')
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)

View File

@@ -23,10 +23,12 @@ class NodeType(Enum):
HTTP_REQUEST = 'http-request'
TOOL = 'tool'
VARIABLE_AGGREGATOR = 'variable-aggregator'
# TODO: merge this into VARIABLE_AGGREGATOR
VARIABLE_ASSIGNER = 'variable-assigner'
LOOP = 'loop'
ITERATION = 'iteration'
PARAMETER_EXTRACTOR = 'parameter-extractor'
CONVERSATION_VARIABLE_ASSIGNER = 'assigner'
@classmethod
def value_of(cls, value: str) -> 'NodeType':

View File

@@ -13,6 +13,7 @@ VariableValue = Union[str, int, float, dict, list, FileVar]
SYSTEM_VARIABLE_NODE_ID = 'sys'
ENVIRONMENT_VARIABLE_NODE_ID = 'env'
CONVERSATION_VARIABLE_NODE_ID = 'conversation'
class VariablePool:
@@ -21,6 +22,7 @@ class VariablePool:
system_variables: Mapping[SystemVariable, Any],
user_inputs: Mapping[str, Any],
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable] | None = None,
) -> None:
# system variables
# for example:
@@ -44,9 +46,13 @@ class VariablePool:
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
# Add environment variables to the variable pool
for var in environment_variables or []:
for var in environment_variables:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
# Add conversation variables to the variable pool
for var in conversation_variables or []:
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
def add(self, selector: Sequence[str], value: Any, /) -> None:
"""
Adds a variable to the variable pool.

View File

@@ -8,6 +8,7 @@ from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from models import WorkflowNodeExecutionStatus
class UserFrom(Enum):
@@ -91,14 +92,19 @@ class BaseNode(ABC):
:param variable_pool: variable pool
:return:
"""
result = self._run(
variable_pool=variable_pool
)
try:
result = self._run(
variable_pool=variable_pool
)
self.node_run_result = result
return result
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
self.node_run_result = result
return result
def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None:
def publish_text_chunk(self, text: str, value_selector: list[str] | None = None) -> None:
"""
Publish text chunk
:param text: chunk text

View File

@@ -0,0 +1,109 @@
from collections.abc import Sequence
from enum import Enum
from typing import Optional, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.segments import SegmentType, Variable, factory
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from extensions.ext_database import db
from models import ConversationVariable, WorkflowNodeExecutionStatus
class VariableAssignerNodeError(Exception):
pass
class WriteMode(str, Enum):
OVER_WRITE = 'over-write'
APPEND = 'append'
CLEAR = 'clear'
class VariableAssignerData(BaseNodeData):
title: str = 'Variable Assigner'
desc: Optional[str] = 'Assign a value to a variable'
assigned_variable_selector: Sequence[str]
write_mode: WriteMode
input_variable_selector: Sequence[str]
class VariableAssignerNode(BaseNode):
_node_data_cls: type[BaseNodeData] = VariableAssignerData
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
data = cast(VariableAssignerData, self.node_data)
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = variable_pool.get(data.assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableAssignerNodeError('assigned variable not found')
match data.write_mode:
case WriteMode.OVER_WRITE:
income_value = variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_variable = original_variable.model_copy(update={'value': income_value.value})
case WriteMode.APPEND:
income_value = variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_value = original_variable.value + [income_value.value]
updated_variable = original_variable.model_copy(update={'value': updated_value})
case WriteMode.CLEAR:
income_value = get_zero_value(original_variable.value_type)
updated_variable = original_variable.model_copy(update={'value': income_value.to_object()})
case _:
raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
# Over write the variable.
variable_pool.add(data.assigned_variable_selector, updated_variable)
# Update conversation variable.
# TODO: Find a better way to use the database.
conversation_id = variable_pool.get(['sys', 'conversation_id'])
if not conversation_id:
raise VariableAssignerNodeError('conversation_id not found')
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={
'value': income_value.to_object(),
},
)
def update_conversation_variable(conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableAssignerNodeError('conversation variable not found in the database')
row.data = variable.model_dump_json()
session.commit()
def get_zero_value(t: SegmentType):
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
return factory.build_segment([])
case SegmentType.OBJECT:
return factory.build_segment({})
case SegmentType.STRING:
return factory.build_segment('')
case SegmentType.NUMBER:
return factory.build_segment(0)
case _:
raise VariableAssignerNodeError(f'unsupported variable type: {t}')

View File

@@ -4,12 +4,11 @@ from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from configs import dify_config
from core.app.app_config.entities import FileExtraConfig
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
from core.workflow.errors import WorkflowNodeRunFailedError
@@ -30,6 +29,7 @@ from core.workflow.nodes.start.start_node import StartNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
from core.workflow.nodes.variable_assigner import VariableAssignerNode
from extensions.ext_database import db
from models.workflow import (
Workflow,
@@ -51,7 +51,8 @@ node_classes: Mapping[NodeType, type[BaseNode]] = {
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode,
NodeType.ITERATION: IterationNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
}
logger = logging.getLogger(__name__)
@@ -94,10 +95,9 @@ class WorkflowEngineManager:
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
user_inputs: Mapping[str, Any],
system_inputs: Mapping[SystemVariable, Any],
callbacks: Sequence[WorkflowCallback],
call_depth: int = 0
call_depth: int = 0,
variable_pool: VariablePool,
) -> None:
"""
:param workflow: Workflow instance
@@ -122,12 +122,6 @@ class WorkflowEngineManager:
if not isinstance(graph.get('edges'), list):
raise ValueError('edges in workflow graph must be a list')
# init variable pool
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=user_inputs,
environment_variables=workflow.environment_variables,
)
workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
if call_depth > workflow_call_max_depth:
@@ -403,6 +397,7 @@ class WorkflowEngineManager:
system_variables={},
user_inputs={},
environment_variables=workflow.environment_variables,
conversation_variables=workflow.conversation_variables,
)
if node_cls is None:
@@ -468,6 +463,7 @@ class WorkflowEngineManager:
system_variables={},
user_inputs={},
environment_variables=workflow.environment_variables,
conversation_variables=workflow.conversation_variables,
)
# variable selector to variable mapping