feat(api/workflow): Add Conversation.dialogue_count (#7275)

This commit is contained in:
-LAN-
2024-08-15 10:53:05 +08:00
committed by GitHub
parent 8f5d8397f9
commit 32dc963556
29 changed files with 205 additions and 259 deletions

View File

@@ -4,13 +4,14 @@ from typing import Any, Optional
from pydantic import BaseModel
from models.workflow import WorkflowNodeExecutionStatus
from models import WorkflowNodeExecutionStatus
class NodeType(Enum):
"""
Node Types.
"""
START = 'start'
END = 'end'
ANSWER = 'answer'
@@ -44,33 +45,11 @@ class NodeType(Enum):
raise ValueError(f'invalid node type value {value}')
class SystemVariable(Enum):
"""
System Variables.
"""
QUERY = 'query'
FILES = 'files'
CONVERSATION_ID = 'conversation_id'
USER_ID = 'user_id'
@classmethod
def value_of(cls, value: str) -> 'SystemVariable':
"""
Get value of given system variable.
:param value: system variable value
:return: system variable
"""
for system_variable in cls:
if system_variable.value == value:
return system_variable
raise ValueError(f'invalid system variable value {value}')
class NodeRunMetadataKey(Enum):
"""
Node Run Metadata Key.
"""
TOTAL_TOKENS = 'total_tokens'
TOTAL_PRICE = 'total_price'
CURRENCY = 'currency'
@@ -83,6 +62,7 @@ class NodeRunResult(BaseModel):
"""
Node Run Result.
"""
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Optional[Mapping[str, Any]] = None # node inputs

View File

@@ -6,7 +6,7 @@ from typing_extensions import deprecated
from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.enums import SystemVariable
VariableValue = Union[str, int, float, dict, list, FileVar]

View File

@@ -0,0 +1,25 @@
from enum import Enum
class SystemVariable(str, Enum):
"""
System Variables.
"""
QUERY = 'query'
FILES = 'files'
CONVERSATION_ID = 'conversation_id'
USER_ID = 'user_id'
DIALOGUE_COUNT = 'dialogue_count'
@classmethod
def value_of(cls, value: str):
"""
Get value of given system variable.
:param value: system variable value
:return: system variable
"""
for system_variable in cls:
if system_variable.value == value:
return system_variable
raise ValueError(f'invalid system variable value {value}')

View File

@@ -23,8 +23,9 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.llm.entities import (
LLMNodeChatModelMessage,
@@ -201,8 +202,8 @@ class LLMNode(BaseNode):
usage = LLMUsage.empty_usage()
return full_text, usage
def _transform_chat_messages(self,
def _transform_chat_messages(self,
messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
"""
@@ -249,13 +250,13 @@ class LLMNode(BaseNode):
# check if it's a context structure
if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
return d['content']
# else, parse the dict
try:
return json.dumps(d, ensure_ascii=False)
except Exception:
return str(d)
if isinstance(value, str):
value = value
elif isinstance(value, list):

View File

@@ -2,19 +2,20 @@ from collections.abc import Mapping, Sequence
from os import path
from typing import Any, cast
from core.app.segments import parser
from core.app.segments import ArrayAnyVariable, parser
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models.workflow import WorkflowNodeExecutionStatus
from models import WorkflowNodeExecutionStatus
class ToolNode(BaseNode):
@@ -140,9 +141,9 @@ class ToolNode(BaseNode):
return result
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
# FIXME: ensure this is a ArrayVariable contains FileVariable.
variable = variable_pool.get(['sys', SystemVariable.FILES.value])
return [file_var.value for file_var in variable.value] if variable else []
assert isinstance(variable, ArrayAnyVariable)
return list(variable.value) if variable else []
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]):
"""

View File

@@ -3,6 +3,7 @@ import time
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
import contexts
from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -97,7 +98,7 @@ class WorkflowEngineManager:
invoke_from: InvokeFrom,
callbacks: Sequence[WorkflowCallback],
call_depth: int = 0,
variable_pool: VariablePool,
variable_pool: VariablePool | None = None,
) -> None:
"""
:param workflow: Workflow instance
@@ -128,6 +129,8 @@ class WorkflowEngineManager:
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
# init workflow run state
if not variable_pool:
variable_pool = contexts.workflow_variable_pool.get()
workflow_run_state = WorkflowRunState(
workflow=workflow,
start_at=time.perf_counter(),