feat/enhance the multi-modal support (#8818)

Author: -LAN-
Date: 2024-10-21 10:43:49 +08:00
Committed by: GitHub
Parent: 7a1d6fe509
Commit: e61752bd3a
267 changed files with 6263 additions and 3523 deletions

View File

@@ -0,0 +1,3 @@
from .enums import NodeType
__all__ = ["NodeType"]

View File

@@ -0,0 +1,4 @@
from .answer_node import AnswerNode
from .entities import AnswerStreamGenerateRoute
__all__ = ["AnswerStreamGenerateRoute", "AnswerNode"]

View File

@@ -1,7 +1,8 @@
from collections.abc import Mapping, Sequence
from typing import Any, cast
-from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.variables import ArrayFileSegment, FileSegment
+from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
@@ -9,12 +10,13 @@ from core.workflow.nodes.answer.entities import (
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
-from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.base import BaseNode
+from core.workflow.nodes.enums import NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models.workflow import WorkflowNodeExecutionStatus
-class AnswerNode(BaseNode):
+class AnswerNode(BaseNode[AnswerNodeData]):
_node_data_cls = AnswerNodeData
_node_type: NodeType = NodeType.ANSWER
@@ -23,30 +25,35 @@ class AnswerNode(BaseNode):
Run node
:return:
"""
-node_data = self.node_data
-node_data = cast(AnswerNodeData, node_data)
# generate routes
-generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data)
+generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self.node_data)
answer = ""
+files = []
for part in generate_routes:
if part.type == GenerateRouteChunk.ChunkType.VAR:
part = cast(VarGenerateRouteChunk, part)
value_selector = part.value_selector
-value = self.graph_runtime_state.variable_pool.get(value_selector)
-if value:
-answer += value.markdown
+variable = self.graph_runtime_state.variable_pool.get(value_selector)
+if variable:
+if isinstance(variable, FileSegment):
+files.append(variable.value)
+elif isinstance(variable, ArrayFileSegment):
+files.extend(variable.value)
+answer += variable.markdown
else:
part = cast(TextGenerateRouteChunk, part)
answer += part.text
-return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer})
+return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer, "files": files})
@classmethod
def _extract_variable_selector_to_variable_mapping(
-cls, graph_config: Mapping[str, Any], node_id: str, node_data: AnswerNodeData
+cls,
+*,
+graph_config: Mapping[str, Any],
+node_id: str,
+node_data: AnswerNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -55,9 +62,6 @@ class AnswerNode(BaseNode):
:param node_data: node data
:return:
"""
-node_data = node_data
-node_data = cast(AnswerNodeData, node_data)
variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()
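
Net effect of this change to AnswerNode: any FileSegment or ArrayFileSegment referenced by the answer template is now collected into a separate "files" output next to the rendered text. A minimal sketch of the accumulation logic in isolation, using stand-in dataclasses rather than the real core.variables segment types:

from dataclasses import dataclass, field

@dataclass
class FileSegment:       # stand-in for core.variables.FileSegment
    value: object
    markdown: str = ""

@dataclass
class ArrayFileSegment:  # stand-in for core.variables.ArrayFileSegment
    value: list = field(default_factory=list)
    markdown: str = ""

def collect(variables):
    answer, files = "", []
    for variable in variables:
        if isinstance(variable, FileSegment):
            files.append(variable.value)    # single file -> one entry
        elif isinstance(variable, ArrayFileSegment):
            files.extend(variable.value)    # file array -> flattened
        answer += variable.markdown         # text rendering is unchanged
    return answer, files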

View File

@@ -1,5 +1,4 @@
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
-from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
AnswerStreamGenerateRoute,
@@ -7,6 +6,7 @@ from core.workflow.nodes.answer.entities import (
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
+from core.workflow.nodes.enums import NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser

View File

@@ -1,8 +1,8 @@
import logging
from collections.abc import Generator
-from typing import Optional, cast
+from typing import cast
-from core.file.file_obj import FileVar
+from core.file import FILE_MODEL_IDENTITY, File
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
@@ -203,7 +203,7 @@ class AnswerStreamProcessor(StreamProcessor):
return files
@classmethod
-def _get_file_var_from_value(cls, value: dict | list) -> Optional[dict]:
+def _get_file_var_from_value(cls, value: dict | list):
"""
Get file var from value
:param value: variable value
@@ -213,9 +213,9 @@ class AnswerStreamProcessor(StreamProcessor):
return None
if isinstance(value, dict):
if "__variant" in value and value["__variant"] == FileVar.__name__:
if "dify_model_identity" in value and value["dify_model_identity"] == FILE_MODEL_IDENTITY:
return value
-elif isinstance(value, FileVar):
+elif isinstance(value, File):
return value.to_dict()
return None

View File

@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import BaseModel, Field
-from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.nodes.base import BaseNodeData
class AnswerNodeData(BaseNodeData):

View File

@@ -0,0 +1,4 @@
from .entities import BaseIterationNodeData, BaseIterationState, BaseNodeData
from .node import BaseNode
__all__ = ["BaseNode", "BaseNodeData", "BaseIterationNodeData", "BaseIterationState"]

View File

@@ -0,0 +1,24 @@
from abc import ABC
from typing import Optional
from pydantic import BaseModel
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
class BaseIterationNodeData(BaseNodeData):
start_node_id: Optional[str] = None
class BaseIterationState(BaseModel):
    iteration_node_id: str
    index: int
    inputs: dict

    class MetaData(BaseModel):
        pass

    metadata: MetaData

View File

@@ -1,17 +1,27 @@
-from abc import ABC, abstractmethod
+import logging
+from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
-from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.graph_engine.entities.event import InNodeEvent
-from core.workflow.graph_engine.entities.graph import Graph
-from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
-from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
-from core.workflow.nodes.event import RunCompletedEvent, RunEvent
+from core.workflow.entities.node_entities import NodeRunResult
+from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
+from models.workflow import WorkflowNodeExecutionStatus
+from .entities import BaseNodeData
+if TYPE_CHECKING:
+    from core.workflow.graph_engine.entities.event import InNodeEvent
+    from core.workflow.graph_engine.entities.graph import Graph
+    from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
+    from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+logger = logging.getLogger(__name__)
+GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
-class BaseNode(ABC):
+class BaseNode(Generic[GenericNodeData]):
_node_data_cls: type[BaseNodeData]
_node_type: NodeType
@@ -19,9 +29,9 @@ class BaseNode(ABC):
self,
id: str,
config: Mapping[str, Any],
-graph_init_params: GraphInitParams,
-graph: Graph,
-graph_runtime_state: GraphRuntimeState,
+graph_init_params: "GraphInitParams",
+graph: "Graph",
+graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
) -> None:
@@ -45,22 +55,25 @@ class BaseNode(ABC):
raise ValueError("Node ID is required.")
self.node_id = node_id
-self.node_data = self._node_data_cls(**config.get("data", {}))
+self.node_data: GenericNodeData = cast(GenericNodeData, self._node_data_cls(**config.get("data", {})))
@abstractmethod
-def _run(self) -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]:
+def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
"""
Run node
:return:
"""
raise NotImplementedError
-def run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
-"""
-Run node entry
-:return:
-"""
-result = self._run()
+def run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
+try:
+result = self._run()
+except Exception as e:
+logger.error(f"Node {self.node_id} failed to run: {e}")
+result = NodeRunResult(
+status=WorkflowNodeExecutionStatus.FAILED,
+error=str(e),
+)
if isinstance(result, NodeRunResult):
yield RunCompletedEvent(run_result=result)
@@ -69,7 +82,10 @@ class BaseNode(ABC):
@classmethod
def extract_variable_selector_to_variable_mapping(
-cls, graph_config: Mapping[str, Any], config: dict
+cls,
+*,
+graph_config: Mapping[str, Any],
+config: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -83,12 +99,16 @@ class BaseNode(ABC):
node_data = cls._node_data_cls(**config.get("data", {}))
return cls._extract_variable_selector_to_variable_mapping(
-graph_config=graph_config, node_id=node_id, node_data=node_data
+graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
-cls, graph_config: Mapping[str, Any], node_id: str, node_data: BaseNodeData
+cls,
+*,
+graph_config: Mapping[str, Any],
+node_id: str,
+node_data: GenericNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
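
Two things change in BaseNode here: the `Generic[GenericNodeData]` parameter gives subclasses a correctly typed `self.node_data` without per-method casts, and `run()` now converts any uncaught `_run()` exception into a FAILED NodeRunResult instead of letting it propagate. A rough sketch of what a subclass looks like under this scheme (MyNodeData and MyNode are hypothetical, not part of the commit):

from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode, BaseNodeData
from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus

class MyNodeData(BaseNodeData):
    greeting: str

class MyNode(BaseNode[MyNodeData]):
    _node_data_cls = MyNodeData
    _node_type = NodeType.CODE  # placeholder node type, for illustration only

    def _run(self) -> NodeRunResult:
        # self.node_data is typed as MyNodeData here; no cast() needed
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={"text": self.node_data.greeting},
        )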

View File

@@ -0,0 +1,3 @@
from .code_node import CodeNode
__all__ = ["CodeNode"]

View File

@@ -1,18 +1,19 @@
from collections.abc import Mapping, Sequence
-from typing import Any, Optional, Union, cast
+from typing import Any, Optional, Union
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
-from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.nodes.base_node import BaseNode
+from core.workflow.entities.node_entities import NodeRunResult
+from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.code.entities import CodeNodeData
+from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
-class CodeNode(BaseNode):
+class CodeNode(BaseNode[CodeNodeData]):
_node_data_cls = CodeNodeData
_node_type = NodeType.CODE
@@ -33,20 +34,13 @@ class CodeNode(BaseNode):
return code_provider.get_default_config()
def _run(self) -> NodeRunResult:
"""
Run code
:return:
"""
node_data = self.node_data
node_data = cast(CodeNodeData, node_data)
# Get code language
code_language = node_data.code_language
code = node_data.code
code_language = self.node_data.code_language
code = self.node_data.code
# Get variables
variables = {}
-for variable_selector in node_data.variables:
+for variable_selector in self.node_data.variables:
variable = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
@@ -60,7 +54,7 @@ class CodeNode(BaseNode):
)
# Transform result
-result = self._transform_result(result, node_data.outputs)
+result = self._transform_result(result, self.node_data.outputs)
except (CodeExecutionError, ValueError) as e:
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
@@ -316,7 +310,11 @@ class CodeNode(BaseNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
-cls, graph_config: Mapping[str, Any], node_id: str, node_data: CodeNodeData
+cls,
+*,
+graph_config: Mapping[str, Any],
+node_id: str,
+node_data: CodeNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping

View File

@@ -3,8 +3,8 @@ from typing import Literal, Optional
from pydantic import BaseModel
from core.helper.code_executor.code_executor import CodeLanguage
-from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
+from core.workflow.nodes.base import BaseNodeData
class CodeNodeData(BaseNodeData):

View File

@@ -0,0 +1,4 @@
from .entities import DocumentExtractorNodeData
from .node import DocumentExtractorNode
__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData"]

View File

@@ -0,0 +1,7 @@
from collections.abc import Sequence
from core.workflow.nodes.base import BaseNodeData
class DocumentExtractorNodeData(BaseNodeData):
variable_selector: Sequence[str]

View File

@@ -0,0 +1,14 @@
class DocumentExtractorError(Exception):
"""Base exception for errors related to the DocumentExtractorNode."""
class FileDownloadError(DocumentExtractorError):
"""Exception raised when there's an error downloading a file."""
class UnsupportedFileTypeError(DocumentExtractorError):
"""Exception raised when trying to extract text from an unsupported file type."""
class TextExtractionError(DocumentExtractorError):
"""Exception raised when there's an error during text extraction from a file."""

View File

@@ -0,0 +1,244 @@
import csv
import io
import docx
import pandas as pd
import pypdfium2
from unstructured.partition.email import partition_email
from unstructured.partition.epub import partition_epub
from unstructured.partition.msg import partition_msg
from unstructured.partition.ppt import partition_ppt
from unstructured.partition.pptx import partition_pptx
from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment
from core.variables.segments import FileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
from .entities import DocumentExtractorNodeData
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
"""
Extracts text content from various file types.
Supports plain text, PDF, DOC/DOCX, CSV, Excel, PPT/PPTX, EPUB, EML, and MSG files.
"""
_node_data_cls = DocumentExtractorNodeData
_node_type = NodeType.DOCUMENT_EXTRACTOR
def _run(self):
variable_selector = self.node_data.variable_selector
variable = self.graph_runtime_state.variable_pool.get(variable_selector)
if variable is None:
error_message = f"File variable not found for selector: {variable_selector}"
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message)
if variable.value and not isinstance(variable, ArrayFileSegment | FileSegment):
error_message = f"Variable {variable_selector} is not an ArrayFileSegment"
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message)
value = variable.value
inputs = {"variable_selector": variable_selector}
process_data = {"documents": value if isinstance(value, list) else [value]}
try:
if isinstance(value, list):
extracted_text_list = list(map(_extract_text_from_file, value))
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs={"text": extracted_text_list},
)
elif isinstance(value, File):
extracted_text = _extract_text_from_file(value)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs={"text": extracted_text},
)
else:
raise DocumentExtractorError(f"Unsupported variable type: {type(value)}")
except DocumentExtractorError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
inputs=inputs,
process_data=process_data,
)
def _extract_text(*, file_content: bytes, mime_type: str) -> str:
"""Extract text from a file based on its MIME type."""
if mime_type.startswith("text/plain") or mime_type in {"text/html", "text/htm", "text/markdown", "text/xml"}:
return _extract_text_from_plain_text(file_content)
elif mime_type == "application/pdf":
return _extract_text_from_pdf(file_content)
elif mime_type in {
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/msword",
}:
return _extract_text_from_doc(file_content)
elif mime_type == "text/csv":
return _extract_text_from_csv(file_content)
elif mime_type in {
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"application/vnd.ms-excel",
}:
return _extract_text_from_excel(file_content)
elif mime_type == "application/vnd.ms-powerpoint":
return _extract_text_from_ppt(file_content)
elif mime_type == "application/vnd.openxmlformats-officedocument.presentationml.presentation":
return _extract_text_from_pptx(file_content)
elif mime_type == "application/epub+zip":
return _extract_text_from_epub(file_content)
elif mime_type == "message/rfc822":
return _extract_text_from_eml(file_content)
elif mime_type == "application/vnd.ms-outlook":
return _extract_text_from_msg(file_content)
else:
raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}")
def _extract_text_from_plain_text(file_content: bytes) -> str:
try:
return file_content.decode("utf-8")
except UnicodeDecodeError as e:
raise TextExtractionError("Failed to decode plain text file") from e
def _extract_text_from_pdf(file_content: bytes) -> str:
try:
pdf_file = io.BytesIO(file_content)
pdf_document = pypdfium2.PdfDocument(pdf_file, autoclose=True)
text = ""
for page in pdf_document:
text_page = page.get_textpage()
text += text_page.get_text_range()
text_page.close()
page.close()
return text
except Exception as e:
raise TextExtractionError(f"Failed to extract text from PDF: {str(e)}") from e
def _extract_text_from_doc(file_content: bytes) -> str:
try:
doc_file = io.BytesIO(file_content)
doc = docx.Document(doc_file)
return "\n".join([paragraph.text for paragraph in doc.paragraphs])
except Exception as e:
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
def _download_file_content(file: File) -> bytes:
"""Download the content of a file based on its transfer method."""
try:
if file.transfer_method == FileTransferMethod.REMOTE_URL:
if file.remote_url is None:
raise FileDownloadError("Missing URL for remote file")
response = ssrf_proxy.get(file.remote_url)
response.raise_for_status()
return response.content
elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
return file_manager.download(file)
else:
raise ValueError(f"Unsupported transfer method: {file.transfer_method}")
except Exception as e:
raise FileDownloadError(f"Error downloading file: {str(e)}") from e
def _extract_text_from_file(file: File):
if file.mime_type is None:
raise UnsupportedFileTypeError("Unable to determine file type: MIME type is missing")
file_content = _download_file_content(file)
extracted_text = _extract_text(file_content=file_content, mime_type=file.mime_type)
return extracted_text
def _extract_text_from_csv(file_content: bytes) -> str:
try:
csv_file = io.StringIO(file_content.decode("utf-8"))
csv_reader = csv.reader(csv_file)
rows = list(csv_reader)
if not rows:
return ""
# Create markdown table
markdown_table = "| " + " | ".join(rows[0]) + " |\n"
markdown_table += "| " + " | ".join(["---"] * len(rows[0])) + " |\n"
for row in rows[1:]:
markdown_table += "| " + " | ".join(row) + " |\n"
return markdown_table.strip()
except Exception as e:
raise TextExtractionError(f"Failed to extract text from CSV: {str(e)}") from e
def _extract_text_from_excel(file_content: bytes) -> str:
"""Extract text from an Excel file using pandas."""
try:
df = pd.read_excel(io.BytesIO(file_content))
# Drop rows where all elements are NaN
df.dropna(how="all", inplace=True)
# Convert DataFrame to markdown table
markdown_table = df.to_markdown(index=False)
return markdown_table
except Exception as e:
raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e
def _extract_text_from_ppt(file_content: bytes) -> str:
try:
with io.BytesIO(file_content) as file:
elements = partition_ppt(file=file)
return "\n".join([getattr(element, "text", "") for element in elements])
except Exception as e:
raise TextExtractionError(f"Failed to extract text from PPT: {str(e)}") from e
def _extract_text_from_pptx(file_content: bytes) -> str:
try:
with io.BytesIO(file_content) as file:
elements = partition_pptx(file=file)
return "\n".join([getattr(element, "text", "") for element in elements])
except Exception as e:
raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e
def _extract_text_from_epub(file_content: bytes) -> str:
try:
with io.BytesIO(file_content) as file:
elements = partition_epub(file=file)
return "\n".join([str(element) for element in elements])
except Exception as e:
raise TextExtractionError(f"Failed to extract text from EPUB: {str(e)}") from e
def _extract_text_from_eml(file_content: bytes) -> str:
try:
with io.BytesIO(file_content) as file:
elements = partition_email(file=file)
return "\n".join([str(element) for element in elements])
except Exception as e:
raise TextExtractionError(f"Failed to extract text from EML: {str(e)}") from e
def _extract_text_from_msg(file_content: bytes) -> str:
try:
with io.BytesIO(file_content) as file:
elements = partition_msg(file=file)
return "\n".join([str(element) for element in elements])
except Exception as e:
raise TextExtractionError(f"Failed to extract text from MSG: {str(e)}") from e

View File

@@ -0,0 +1,4 @@
from .end_node import EndNode
from .entities import EndStreamParam
__all__ = ["EndStreamParam", "EndNode"]

View File

@@ -1,13 +1,14 @@
from collections.abc import Mapping, Sequence
-from typing import Any, cast
+from typing import Any
-from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.nodes.base_node import BaseNode
+from core.workflow.entities.node_entities import NodeRunResult
+from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.end.entities import EndNodeData
+from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
-class EndNode(BaseNode):
+class EndNode(BaseNode[EndNodeData]):
_node_data_cls = EndNodeData
_node_type = NodeType.END
@@ -16,20 +17,27 @@ class EndNode(BaseNode):
Run node
:return:
"""
-node_data = self.node_data
-node_data = cast(EndNodeData, node_data)
-output_variables = node_data.outputs
+output_variables = self.node_data.outputs
outputs = {}
for variable_selector in output_variables:
-value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
+variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
+value = variable.to_object() if variable is not None else None
outputs[variable_selector.variable] = value
-return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=outputs, outputs=outputs)
+return NodeRunResult(
+status=WorkflowNodeExecutionStatus.SUCCEEDED,
+inputs=outputs,
+outputs=outputs,
+)
@classmethod
def _extract_variable_selector_to_variable_mapping(
-cls, graph_config: Mapping[str, Any], node_id: str, node_data: EndNodeData
+cls,
+*,
+graph_config: Mapping[str, Any],
+node_id: str,
+node_data: EndNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping

View File

@@ -1,5 +1,5 @@
-from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam
+from core.workflow.nodes.enums import NodeType
class EndStreamGeneratorRouter:

View File

@@ -1,7 +1,7 @@
from pydantic import BaseModel, Field
-from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
+from core.workflow.nodes.base import BaseNodeData
class EndNodeData(BaseNodeData):

View File

@@ -0,0 +1,24 @@
from enum import Enum
class NodeType(str, Enum):
START = "start"
END = "end"
ANSWER = "answer"
LLM = "llm"
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
IF_ELSE = "if-else"
CODE = "code"
TEMPLATE_TRANSFORM = "template-transform"
QUESTION_CLASSIFIER = "question-classifier"
HTTP_REQUEST = "http-request"
TOOL = "tool"
VARIABLE_AGGREGATOR = "variable-aggregator"
VARIABLE_ASSIGNER = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LOOP = "loop"
ITERATION = "iteration"
ITERATION_START = "iteration-start" # Fake start node for iteration.
PARAMETER_EXTRACTOR = "parameter-extractor"
CONVERSATION_VARIABLE_ASSIGNER = "assigner"
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"

View File

@@ -0,0 +1,10 @@
from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from .types import NodeEvent
__all__ = [
"RunCompletedEvent",
"RunRetrieverResourceEvent",
"RunStreamChunkEvent",
"NodeEvent",
"ModelInvokeCompletedEvent",
]

View File

@@ -1,5 +1,6 @@
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.node_entities import NodeRunResult
@@ -17,4 +18,11 @@ class RunRetrieverResourceEvent(BaseModel):
context: str = Field(..., description="context")
RunEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent
class ModelInvokeCompletedEvent(BaseModel):
"""
Model invoke completed
"""
text: str
usage: LLMUsage
finish_reason: str | None = None

View File

@@ -0,0 +1,3 @@
from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
NodeEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent | ModelInvokeCompletedEvent
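
Since NodeEvent is an ordinary union, code consuming a node's run() generator can dispatch on the concrete event classes. A rough sketch of a consumer loop (the graph-engine wiring around it is assumed, not shown in this commit excerpt):

from collections.abc import Generator
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent, RunStreamChunkEvent

def drain(events: Generator[NodeEvent, None, None]):
    for event in events:
        if isinstance(event, RunCompletedEvent):
            return event.run_result   # terminal result for this node
        if isinstance(event, RunStreamChunkEvent):
            pass                      # hypothetical: forward streamed chunks downstream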

View File

@@ -0,0 +1,4 @@
from .entities import BodyData, HttpRequestNodeAuthorization, HttpRequestNodeBody, HttpRequestNodeData
from .node import HttpRequestNode
__all__ = ["HttpRequestNodeData", "HttpRequestNodeAuthorization", "HttpRequestNodeBody", "BodyData", "HttpRequestNode"]

View File

@@ -1,15 +1,25 @@
-from typing import Literal, Optional, Union
+from collections.abc import Sequence
+from typing import Literal, Optional
-from pydantic import BaseModel, ValidationInfo, field_validator
+import httpx
+from pydantic import BaseModel, Field, ValidationInfo, field_validator
from configs import dify_config
-from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.nodes.base import BaseNodeData
+NON_FILE_CONTENT_TYPES = (
+"application/json",
+"application/xml",
+"text/html",
+"text/plain",
+"application/x-www-form-urlencoded",
+)
class HttpRequestNodeAuthorizationConfig(BaseModel):
-type: Literal[None, "basic", "bearer", "custom"]
-api_key: Union[None, str] = None
-header: Union[None, str] = None
+type: Literal["basic", "bearer", "custom"]
+api_key: str
+header: str = ""
class HttpRequestNodeAuthorization(BaseModel):
@@ -31,9 +41,16 @@ class HttpRequestNodeAuthorization(BaseModel):
return v
+class BodyData(BaseModel):
+key: str = ""
+type: Literal["file", "text"]
+value: str = ""
+file: Sequence[str] = Field(default_factory=list)
class HttpRequestNodeBody(BaseModel):
type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json"]
data: Union[None, str] = None
type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json", "binary"]
data: Sequence[BodyData] = Field(default_factory=list)
class HttpRequestNodeTimeout(BaseModel):
@@ -54,3 +71,51 @@ class HttpRequestNodeData(BaseNodeData):
params: str
body: Optional[HttpRequestNodeBody] = None
timeout: Optional[HttpRequestNodeTimeout] = None
+class Response:
+headers: dict[str, str]
+response: httpx.Response
+def __init__(self, response: httpx.Response):
+self.response = response
+self.headers = dict(response.headers)
+@property
+def is_file(self):
+content_type = self.content_type
+content_disposition = self.response.headers.get("Content-Disposition", "")
+return "attachment" in content_disposition or (
+not any(non_file in content_type for non_file in NON_FILE_CONTENT_TYPES)
+and any(file_type in content_type for file_type in ("application/", "image/", "audio/", "video/"))
+)
+@property
+def content_type(self) -> str:
+return self.headers.get("Content-Type", "")
+@property
+def text(self) -> str:
+return self.response.text
+@property
+def content(self) -> bytes:
+return self.response.content
+@property
+def status_code(self) -> int:
+return self.response.status_code
+@property
+def size(self) -> int:
+return len(self.content)
+@property
+def readable_size(self) -> str:
+if self.size < 1024:
+return f"{self.size} bytes"
+elif self.size < 1024 * 1024:
+return f"{(self.size / 1024):.2f} KB"
+else:
+return f"{(self.size / 1024 / 1024):.2f} MB"

View File

@@ -0,0 +1,327 @@
import json
from collections.abc import Mapping, Sequence
from copy import deepcopy
from random import randint
from typing import Any, Literal
from urllib.parse import urlencode, urlparse
import httpx
from configs import dify_config
from core.file import file_manager
from core.helper import ssrf_proxy
from core.workflow.entities.variable_pool import VariablePool
from .entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeData,
HttpRequestNodeTimeout,
Response,
)
BODY_TYPE_TO_CONTENT_TYPE = {
"json": "application/json",
"x-www-form-urlencoded": "application/x-www-form-urlencoded",
"form-data": "multipart/form-data",
"raw-text": "text/plain",
}
class Executor:
method: Literal["get", "head", "post", "put", "delete", "patch"]
url: str
params: Mapping[str, str] | None
content: str | bytes | None
data: Mapping[str, Any] | None
files: Mapping[str, bytes] | None
json: Any
headers: dict[str, str]
auth: HttpRequestNodeAuthorization
timeout: HttpRequestNodeTimeout
boundary: str
def __init__(
self,
*,
node_data: HttpRequestNodeData,
timeout: HttpRequestNodeTimeout,
variable_pool: VariablePool,
):
# If authorization API key is present, convert the API key using the variable pool
if node_data.authorization.type == "api-key":
if node_data.authorization.config is None:
raise ValueError("authorization config is required")
node_data.authorization.config.api_key = variable_pool.convert_template(
node_data.authorization.config.api_key
).text
self.url: str = node_data.url
self.method = node_data.method
self.auth = node_data.authorization
self.timeout = timeout
self.params = None
self.headers = {}
self.content = None
self.files = None
self.data = None
self.json = None
# init template
self.variable_pool = variable_pool
self.node_data = node_data
self._initialize()
def _initialize(self):
self._init_url()
self._init_params()
self._init_headers()
self._init_body()
def _init_url(self):
self.url = self.variable_pool.convert_template(self.node_data.url).text
def _init_params(self):
params = self.variable_pool.convert_template(self.node_data.params).text
self.params = _plain_text_to_dict(params)
def _init_headers(self):
headers = self.variable_pool.convert_template(self.node_data.headers).text
self.headers = _plain_text_to_dict(headers)
body = self.node_data.body
if body is None:
return
if "content-type" not in (k.lower() for k in self.headers) and body.type in BODY_TYPE_TO_CONTENT_TYPE:
self.headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type]
if body.type == "form-data":
self.boundary = f"----WebKitFormBoundary{_generate_random_string(16)}"
self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}"
def _init_body(self):
body = self.node_data.body
if body is not None:
data = body.data
match body.type:
case "none":
self.content = ""
case "raw-text":
self.content = self.variable_pool.convert_template(data[0].value).text
case "json":
json_object = json.loads(data[0].value)
self.json = self._parse_object_contains_variables(json_object)
case "binary":
file_selector = data[0].file
file_variable = self.variable_pool.get_file(file_selector)
if file_variable is None:
raise ValueError(f"cannot fetch file with selector {file_selector}")
file = file_variable.value
self.content = file_manager.download(file)
case "x-www-form-urlencoded":
form_data = {
self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template(
item.value
).text
for item in data
}
self.data = form_data
case "form-data":
form_data = {
self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template(
item.value
).text
for item in filter(lambda item: item.type == "text", data)
}
file_selectors = {
self.variable_pool.convert_template(item.key).text: item.file
for item in filter(lambda item: item.type == "file", data)
}
files = {k: self.variable_pool.get_file(selector) for k, selector in file_selectors.items()}
files = {k: v for k, v in files.items() if v is not None}
files = {k: variable.value for k, variable in files.items()}
files = {k: file_manager.download(v) for k, v in files.items() if v.related_id is not None}
self.data = form_data
self.files = files
def _assembling_headers(self) -> dict[str, Any]:
authorization = deepcopy(self.auth)
headers = deepcopy(self.headers) or {}
if self.auth.type == "api-key":
if self.auth.config is None:
raise ValueError("self.authorization config is required")
if authorization.config is None:
raise ValueError("authorization config is required")
if self.auth.config.api_key is None:
raise ValueError("api_key is required")
if not authorization.config.header:
authorization.config.header = "Authorization"
if self.auth.config.type == "bearer":
headers[authorization.config.header] = f"Bearer {authorization.config.api_key}"
elif self.auth.config.type == "basic":
headers[authorization.config.header] = f"Basic {authorization.config.api_key}"
elif self.auth.config.type == "custom":
headers[authorization.config.header] = authorization.config.api_key or ""
return headers
def _validate_and_parse_response(self, response: httpx.Response) -> Response:
executor_response = Response(response)
threshold_size = (
dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE
if executor_response.is_file
else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE
)
if executor_response.size > threshold_size:
raise ValueError(
f'{"File" if executor_response.is_file else "Text"} size is too large,'
f' max size is {threshold_size / 1024 / 1024:.2f} MB,'
f' but current size is {executor_response.readable_size}.'
)
return executor_response
def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response:
"""
do http request depending on api bundle
"""
if self.method not in {"get", "head", "post", "put", "delete", "patch"}:
raise ValueError(f"Invalid http method {self.method}")
request_args = {
"url": self.url,
"data": self.data,
"files": self.files,
"json": self.json,
"content": self.content,
"headers": headers,
"params": self.params,
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
"follow_redirects": True,
}
response = getattr(ssrf_proxy, self.method)(**request_args)
return response
def invoke(self) -> Response:
# assemble headers
headers = self._assembling_headers()
# do http request
response = self._do_http_request(headers)
# validate response
return self._validate_and_parse_response(response)
def to_log(self):
url_parts = urlparse(self.url)
path = url_parts.path or "/"
# Add query parameters
if self.params:
query_string = urlencode(self.params)
path += f"?{query_string}"
elif url_parts.query:
path += f"?{url_parts.query}"
raw = f"{self.method.upper()} {path} HTTP/1.1\r\n"
raw += f"Host: {url_parts.netloc}\r\n"
headers = self._assembling_headers()
for k, v in headers.items():
if self.auth.type == "api-key":
authorization_header = "Authorization"
if self.auth.config and self.auth.config.header:
authorization_header = self.auth.config.header
if k.lower() == authorization_header.lower():
raw += f'{k}: {"*" * len(v)}\r\n'
continue
raw += f"{k}: {v}\r\n"
body = ""
if self.files:
boundary = self.boundary
for k, v in self.files.items():
body += f"--{boundary}\r\n"
body += f'Content-Disposition: form-data; name="{k}"\r\n\r\n'
body += f"{v[1]}\r\n"
body += f"--{boundary}--\r\n"
elif self.node_data.body:
if self.content:
if isinstance(self.content, str):
body = self.content
elif isinstance(self.content, bytes):
body = self.content.decode("utf-8", errors="replace")
elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
body = urlencode(self.data)
elif self.data and self.node_data.body.type == "form-data":
boundary = self.boundary
for key, value in self.data.items():
body += f"--{boundary}\r\n"
body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
body += f"{value}\r\n"
body += f"--{boundary}--\r\n"
elif self.json:
body = json.dumps(self.json)
elif self.node_data.body.type == "raw-text":
body = self.node_data.body.data[0].value
if body:
raw += f"Content-Length: {len(body)}\r\n"
raw += "\r\n" # Empty line between headers and body
raw += body
return raw
def _parse_object_contains_variables(self, obj: str | dict | list, /) -> Mapping[str, Any] | Sequence[Any] | str:
if isinstance(obj, dict):
return {k: self._parse_object_contains_variables(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [self._parse_object_contains_variables(v) for v in obj]
elif isinstance(obj, str):
return self.variable_pool.convert_template(obj).text
def _plain_text_to_dict(text: str, /) -> dict[str, str]:
"""
Convert a string of key-value pairs to a dictionary.
Each line in the input string represents a key-value pair.
Keys and values are separated by ':'.
Empty values are allowed.
Examples:
'aa:bb\n cc:dd' -> {'aa': 'bb', 'cc': 'dd'}
'aa:\n cc:dd\n' -> {'aa': '', 'cc': 'dd'}
'aa\n cc : dd' -> {'aa': '', 'cc': 'dd'}
Args:
text (str): The input string to convert.
Returns:
dict[str, str]: A dictionary of key-value pairs.
"""
return {
key.strip(): (value[0].strip() if value else "")
for line in text.splitlines()
if line.strip()
for key, *value in [line.split(":", 1)]
}
def _generate_random_string(n: int) -> str:
"""
Generate a random string of lowercase ASCII letters.
Args:
n (int): The length of the random string to generate.
Returns:
str: A random string of lowercase ASCII letters with length n.
Example:
>>> _generate_random_string(5)
'abcde'
"""
return "".join([chr(randint(97, 122)) for _ in range(n)])

View File

@@ -1,343 +0,0 @@
import json
from copy import deepcopy
from random import randint
from typing import Any, Optional, Union
from urllib.parse import urlencode
import httpx
from configs import dify_config
from core.helper import ssrf_proxy
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.http_request.entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeBody,
HttpRequestNodeData,
HttpRequestNodeTimeout,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
class HttpExecutorResponse:
headers: dict[str, str]
response: httpx.Response
def __init__(self, response: httpx.Response):
self.response = response
self.headers = dict(response.headers) if isinstance(self.response, httpx.Response) else {}
@property
def is_file(self) -> bool:
"""
check if response is file
"""
content_type = self.get_content_type()
file_content_types = ["image", "audio", "video"]
return any(v in content_type for v in file_content_types)
def get_content_type(self) -> str:
return self.headers.get("content-type", "")
def extract_file(self) -> tuple[str, bytes]:
"""
extract file from response if content type is file related
"""
if self.is_file:
return self.get_content_type(), self.body
return "", b""
@property
def content(self) -> str:
if isinstance(self.response, httpx.Response):
return self.response.text
else:
raise ValueError(f"Invalid response type {type(self.response)}")
@property
def body(self) -> bytes:
if isinstance(self.response, httpx.Response):
return self.response.content
else:
raise ValueError(f"Invalid response type {type(self.response)}")
@property
def status_code(self) -> int:
if isinstance(self.response, httpx.Response):
return self.response.status_code
else:
raise ValueError(f"Invalid response type {type(self.response)}")
@property
def size(self) -> int:
return len(self.body)
@property
def readable_size(self) -> str:
if self.size < 1024:
return f"{self.size} bytes"
elif self.size < 1024 * 1024:
return f"{(self.size / 1024):.2f} KB"
else:
return f"{(self.size / 1024 / 1024):.2f} MB"
class HttpExecutor:
server_url: str
method: str
authorization: HttpRequestNodeAuthorization
params: dict[str, Any]
headers: dict[str, Any]
body: Union[None, str]
files: Union[None, dict[str, Any]]
boundary: str
variable_selectors: list[VariableSelector]
timeout: HttpRequestNodeTimeout
def __init__(
self,
node_data: HttpRequestNodeData,
timeout: HttpRequestNodeTimeout,
variable_pool: Optional[VariablePool] = None,
):
self.server_url = node_data.url
self.method = node_data.method
self.authorization = node_data.authorization
self.timeout = timeout
self.params = {}
self.headers = {}
self.body = None
self.files = None
# init template
self.variable_selectors = []
self._init_template(node_data, variable_pool)
@staticmethod
def _is_json_body(body: HttpRequestNodeBody):
"""
check if body is json
"""
if body and body.type == "json" and body.data:
try:
json.loads(body.data)
return True
except:
return False
return False
@staticmethod
def _to_dict(convert_text: str):
"""
Convert the string like `aa:bb\n cc:dd` to dict `{aa:bb, cc:dd}`
"""
kv_paris = convert_text.split("\n")
result = {}
for kv in kv_paris:
if not kv.strip():
continue
kv = kv.split(":", maxsplit=1)
if len(kv) == 1:
k, v = kv[0], ""
else:
k, v = kv
result[k.strip()] = v
return result
def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None):
# extract all template in url
self.server_url, server_url_variable_selectors = self._format_template(node_data.url, variable_pool)
# extract all template in params
params, params_variable_selectors = self._format_template(node_data.params, variable_pool)
self.params = self._to_dict(params)
# extract all template in headers
headers, headers_variable_selectors = self._format_template(node_data.headers, variable_pool)
self.headers = self._to_dict(headers)
# extract all template in body
body_data_variable_selectors = []
if node_data.body:
# check if it's a valid JSON
is_valid_json = self._is_json_body(node_data.body)
body_data = node_data.body.data or ""
if body_data:
body_data, body_data_variable_selectors = self._format_template(body_data, variable_pool, is_valid_json)
content_type_is_set = any(key.lower() == "content-type" for key in self.headers)
if node_data.body.type == "json" and not content_type_is_set:
self.headers["Content-Type"] = "application/json"
elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set:
self.headers["Content-Type"] = "application/x-www-form-urlencoded"
if node_data.body.type in {"form-data", "x-www-form-urlencoded"}:
body = self._to_dict(body_data)
if node_data.body.type == "form-data":
self.files = {k: ("", v) for k, v in body.items()}
random_str = lambda n: "".join([chr(randint(97, 122)) for _ in range(n)])
self.boundary = f"----WebKitFormBoundary{random_str(16)}"
self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}"
else:
self.body = urlencode(body)
elif node_data.body.type in {"json", "raw-text"}:
self.body = body_data
elif node_data.body.type == "none":
self.body = ""
self.variable_selectors = (
server_url_variable_selectors
+ params_variable_selectors
+ headers_variable_selectors
+ body_data_variable_selectors
)
def _assembling_headers(self) -> dict[str, Any]:
authorization = deepcopy(self.authorization)
headers = deepcopy(self.headers) or {}
if self.authorization.type == "api-key":
if self.authorization.config is None:
raise ValueError("self.authorization config is required")
if authorization.config is None:
raise ValueError("authorization config is required")
if self.authorization.config.api_key is None:
raise ValueError("api_key is required")
if not authorization.config.header:
authorization.config.header = "Authorization"
if self.authorization.config.type == "bearer":
headers[authorization.config.header] = f"Bearer {authorization.config.api_key}"
elif self.authorization.config.type == "basic":
headers[authorization.config.header] = f"Basic {authorization.config.api_key}"
elif self.authorization.config.type == "custom":
headers[authorization.config.header] = authorization.config.api_key
return headers
def _validate_and_parse_response(self, response: httpx.Response) -> HttpExecutorResponse:
"""
validate the response
"""
if isinstance(response, httpx.Response):
executor_response = HttpExecutorResponse(response)
else:
raise ValueError(f"Invalid response type {type(response)}")
threshold_size = (
dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE
if executor_response.is_file
else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE
)
if executor_response.size > threshold_size:
raise ValueError(
f'{"File" if executor_response.is_file else "Text"} size is too large,'
f' max size is {threshold_size / 1024 / 1024:.2f} MB,'
f' but current size is {executor_response.readable_size}.'
)
return executor_response
def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response:
"""
do http request depending on api bundle
"""
kwargs = {
"url": self.server_url,
"headers": headers,
"params": self.params,
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
"follow_redirects": True,
}
if self.method in {"get", "head", "post", "put", "delete", "patch"}:
response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs)
else:
raise ValueError(f"Invalid http method {self.method}")
return response
def invoke(self) -> HttpExecutorResponse:
"""
invoke http request
"""
# assemble headers
headers = self._assembling_headers()
# do http request
response = self._do_http_request(headers)
# validate response
return self._validate_and_parse_response(response)
def to_raw_request(self) -> str:
"""
convert to raw request
"""
server_url = self.server_url
if self.params:
server_url += f"?{urlencode(self.params)}"
raw_request = f"{self.method.upper()} {server_url} HTTP/1.1\n"
headers = self._assembling_headers()
for k, v in headers.items():
# get authorization header
if self.authorization.type == "api-key":
authorization_header = "Authorization"
if self.authorization.config and self.authorization.config.header:
authorization_header = self.authorization.config.header
if k.lower() == authorization_header.lower():
raw_request += f'{k}: {"*" * len(v)}\n'
continue
raw_request += f"{k}: {v}\n"
raw_request += "\n"
# if files, use multipart/form-data with boundary
if self.files:
boundary = self.boundary
raw_request += f"--{boundary}"
for k, v in self.files.items():
raw_request += f'\nContent-Disposition: form-data; name="{k}"\n\n'
raw_request += f"{v[1]}\n"
raw_request += f"--{boundary}"
raw_request += "--"
else:
raw_request += self.body or ""
return raw_request
def _format_template(
self, template: str, variable_pool: Optional[VariablePool], escape_quotes: bool = False
) -> tuple[str, list[VariableSelector]]:
"""
format template
"""
variable_template_parser = VariableTemplateParser(template=template)
variable_selectors = variable_template_parser.extract_variable_selectors()
if variable_pool:
variable_value_mapping = {}
for variable_selector in variable_selectors:
variable = variable_pool.get_any(variable_selector.value_selector)
if variable is None:
raise ValueError(f"Variable {variable_selector.variable} not found")
if escape_quotes and isinstance(variable, str):
value = variable.replace('"', '\\"').replace("\n", "\\n")
else:
value = variable
variable_value_mapping[variable_selector.variable] = value
return variable_template_parser.format(variable_value_mapping), variable_selectors
else:
return template, variable_selectors

View File

@@ -1,165 +0,0 @@
import logging
from collections.abc import Mapping, Sequence
from mimetypes import guess_extension
from os import path
from typing import Any, cast
from configs import dify_config
from core.app.segments import parser
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.http_request.entities import (
HttpRequestNodeData,
HttpRequestNodeTimeout,
)
from core.workflow.nodes.http_request.http_executor import HttpExecutor, HttpExecutorResponse
from models.workflow import WorkflowNodeExecutionStatus
HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
read=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
write=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
)
class HttpRequestNode(BaseNode):
_node_data_cls = HttpRequestNodeData
_node_type = NodeType.HTTP_REQUEST
@classmethod
def get_default_config(cls, filters: dict | None = None) -> dict:
return {
"type": "http-request",
"config": {
"method": "get",
"authorization": {
"type": "no-auth",
},
"body": {"type": "none"},
"timeout": {
**HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(),
"max_connect_timeout": dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
"max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
"max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
},
},
}
def _run(self) -> NodeRunResult:
node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data)
# TODO: Switch to use segment directly
if node_data.authorization.config and node_data.authorization.config.api_key:
node_data.authorization.config.api_key = parser.convert_template(
template=node_data.authorization.config.api_key, variable_pool=self.graph_runtime_state.variable_pool
).text
# init http executor
http_executor = None
try:
http_executor = HttpExecutor(
node_data=node_data,
timeout=self._get_request_timeout(node_data),
variable_pool=self.graph_runtime_state.variable_pool,
)
# invoke http executor
response = http_executor.invoke()
except Exception as e:
process_data = {}
if http_executor:
process_data = {
"request": http_executor.to_raw_request(),
}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
process_data=process_data,
)
files = self.extract_files(http_executor.server_url, response)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"status_code": response.status_code,
"body": response.content if not files else "",
"headers": response.headers,
"files": files,
},
process_data={
"request": http_executor.to_raw_request(),
},
)
@staticmethod
def _get_request_timeout(node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout:
timeout = node_data.timeout
if timeout is None:
return HTTP_REQUEST_DEFAULT_TIMEOUT
timeout.connect = timeout.connect or HTTP_REQUEST_DEFAULT_TIMEOUT.connect
timeout.read = timeout.read or HTTP_REQUEST_DEFAULT_TIMEOUT.read
timeout.write = timeout.write or HTTP_REQUEST_DEFAULT_TIMEOUT.write
return timeout
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: HttpRequestNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
try:
http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT)
variable_selectors = http_executor.variable_selectors
variable_mapping = {}
for variable_selector in variable_selectors:
variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector
return variable_mapping
except Exception as e:
logging.exception(f"Failed to extract variable selector to variable mapping: {e}")
return {}
def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]:
"""
Extract files from response
"""
files = []
mimetype, file_binary = response.extract_file()
if mimetype:
# extract filename from url
filename = path.basename(url)
# extract extension if possible
extension = guess_extension(mimetype) or ".bin"
tool_file = ToolFileManager.create_file_by_raw(
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,
file_binary=file_binary,
mimetype=mimetype,
)
files.append(
FileVar(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=tool_file.id,
filename=filename,
extension=extension,
mime_type=mimetype,
)
)
return files

View File

@@ -0,0 +1,174 @@
import logging
from collections.abc import Mapping, Sequence
from mimetypes import guess_extension
from os import path
from typing import Any
from configs import dify_config
from core.file import File, FileTransferMethod, FileType
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.utils import variable_template_parser
from models.workflow import WorkflowNodeExecutionStatus
from .entities import (
HttpRequestNodeData,
HttpRequestNodeTimeout,
Response,
)
HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
read=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
write=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
)
logger = logging.getLogger(__name__)
class HttpRequestNode(BaseNode[HttpRequestNodeData]):
_node_data_cls = HttpRequestNodeData
_node_type = NodeType.HTTP_REQUEST
@classmethod
def get_default_config(cls, filters: dict | None = None) -> dict:
return {
"type": "http-request",
"config": {
"method": "get",
"authorization": {
"type": "no-auth",
},
"body": {"type": "none"},
"timeout": {
**HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(),
"max_connect_timeout": dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
"max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
"max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
},
},
}
def _run(self) -> NodeRunResult:
process_data = {}
try:
http_executor = Executor(
node_data=self.node_data,
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
)
process_data["request"] = http_executor.to_log()
response = http_executor.invoke()
files = self.extract_files(url=http_executor.url, response=response)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"status_code": response.status_code,
"body": response.text if not files else "",
"headers": response.headers,
"files": files,
},
process_data={
"request": http_executor.to_log(),
},
)
except Exception as e:
logger.warning(f"http request node {self.node_id} failed to run: {e}")
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
process_data=process_data,
)
@staticmethod
def _get_request_timeout(node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout:
timeout = node_data.timeout
if timeout is None:
return HTTP_REQUEST_DEFAULT_TIMEOUT
timeout.connect = timeout.connect or HTTP_REQUEST_DEFAULT_TIMEOUT.connect
timeout.read = timeout.read or HTTP_REQUEST_DEFAULT_TIMEOUT.read
timeout.write = timeout.write or HTTP_REQUEST_DEFAULT_TIMEOUT.write
return timeout
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: HttpRequestNodeData,
) -> Mapping[str, Sequence[str]]:
selectors: list[VariableSelector] = []
selectors += variable_template_parser.extract_selectors_from_template(node_data.headers)
selectors += variable_template_parser.extract_selectors_from_template(node_data.params)
if node_data.body:
body_type = node_data.body.type
data = node_data.body.data
match body_type:
case "binary":
selector = data[0].file
selectors.append(VariableSelector(variable="#" + ".".join(selector) + "#", value_selector=selector))
case "json" | "raw-text":
selectors += variable_template_parser.extract_selectors_from_template(data[0].key)
selectors += variable_template_parser.extract_selectors_from_template(data[0].value)
case "x-www-form-urlencoded":
for item in data:
selectors += variable_template_parser.extract_selectors_from_template(item.key)
selectors += variable_template_parser.extract_selectors_from_template(item.value)
case "form-data":
for item in data:
selectors += variable_template_parser.extract_selectors_from_template(item.key)
if item.type == "text":
selectors += variable_template_parser.extract_selectors_from_template(item.value)
elif item.type == "file":
selectors.append(
VariableSelector(variable="#" + ".".join(item.file) + "#", value_selector=item.file)
)
mapping = {}
for selector in selectors:
mapping[node_id + "." + selector.variable] = selector.value_selector
return mapping
def extract_files(self, url: str, response: Response) -> list[File]:
"""
Extract files from response
"""
files = []
content_type = response.content_type
content = response.content
if content_type:
# extract filename from url
filename = path.basename(url)
# extract extension if possible
extension = guess_extension(content_type) or ".bin"
tool_file = ToolFileManager.create_file_by_raw(
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,
file_binary=content,
mimetype=content_type,
)
files.append(
File(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=tool_file.id,
filename=filename,
extension=extension,
mime_type=content_type,
)
)
return files
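The `_get_request_timeout` fallback above returns the full default timeout when none is configured, and otherwise fills each missing field individually. A minimal runnable sketch of that logic, assuming a simplified pydantic model; the default values are illustrative, not the real dify_config ones.

from typing import Optional

from pydantic import BaseModel


class Timeout(BaseModel):
    connect: Optional[int] = None
    read: Optional[int] = None
    write: Optional[int] = None


DEFAULT_TIMEOUT = Timeout(connect=10, read=60, write=20)


def resolve_timeout(timeout: Optional[Timeout]) -> Timeout:
    # No timeout at all -> full defaults; otherwise fill each missing field.
    if timeout is None:
        return DEFAULT_TIMEOUT
    timeout.connect = timeout.connect or DEFAULT_TIMEOUT.connect
    timeout.read = timeout.read or DEFAULT_TIMEOUT.read
    timeout.write = timeout.write or DEFAULT_TIMEOUT.write
    return timeout


assert resolve_timeout(None).read == 60
assert resolve_timeout(Timeout(read=5)).connect == 10  # gaps fall back per field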

View File

@@ -0,0 +1,3 @@
from .if_else_node import IfElseNode
__all__ = ["IfElseNode"]

View File

@@ -1,8 +1,8 @@
from typing import Literal, Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.nodes.base import BaseNodeData
from core.workflow.utils.condition.entities import Condition
@@ -21,6 +21,6 @@ class IfElseNodeData(BaseNodeData):
conditions: list[Condition]
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = None
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
cases: Optional[list[Case]] = None

View File

@@ -1,14 +1,19 @@
from collections.abc import Mapping, Sequence
from typing import Any, cast
from typing import Any, Literal
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from typing_extensions import deprecated
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.condition.processor import ConditionProcessor
from models.workflow import WorkflowNodeExecutionStatus
class IfElseNode(BaseNode):
class IfElseNode(BaseNode[IfElseNodeData]):
_node_data_cls = IfElseNodeData
_node_type = NodeType.IF_ELSE
@@ -17,9 +22,6 @@ class IfElseNode(BaseNode):
Run node
:return:
"""
node_data = self.node_data
node_data = cast(IfElseNodeData, node_data)
node_inputs: dict[str, list] = {"conditions": []}
process_datas: dict[str, list] = {"condition_results": []}
@@ -30,15 +32,14 @@ class IfElseNode(BaseNode):
condition_processor = ConditionProcessor()
try:
# Check if the new cases structure is used
if node_data.cases:
for case in node_data.cases:
input_conditions, group_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool, conditions=case.conditions
if self.node_data.cases:
for case in self.node_data.cases:
input_conditions, group_result, final_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=case.conditions,
operator=case.logical_operator,
)
# Apply the logical operator for the current case
final_result = all(group_result) if case.logical_operator == "and" else any(group_result)
process_datas["condition_results"].append(
{
"group": case.model_dump(),
@@ -53,13 +54,15 @@ class IfElseNode(BaseNode):
break
else:
# TODO: Update database then remove this
# Fallback to old structure if cases are not defined
input_conditions, group_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool, conditions=node_data.conditions
input_conditions, group_result, final_result = _should_not_use_old_function(
condition_processor=condition_processor,
variable_pool=self.graph_runtime_state.variable_pool,
conditions=self.node_data.conditions or [],
operator=self.node_data.logical_operator or "and",
)
final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result)
selected_case_id = "true" if final_result else "false"
process_datas["condition_results"].append(
@@ -87,7 +90,11 @@ class IfElseNode(BaseNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: IfElseNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IfElseNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -97,3 +104,18 @@ class IfElseNode(BaseNode):
:return:
"""
return {}
@deprecated("This function is deprecated. You should use the new cases structure.")
def _should_not_use_old_function(
*,
condition_processor: ConditionProcessor,
variable_pool: VariablePool,
conditions: list[Condition],
operator: Literal["and", "or"],
):
return condition_processor.process_conditions(
variable_pool=variable_pool,
conditions=conditions,
operator=operator,
)
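The case-selection flow above evaluates cases in order and stops at the first case whose conditions pass, falling through to the else branch otherwise. A standalone first-match-wins sketch; `Case` and the pre-evaluated condition results are simplified stand-ins, not Dify's actual entities.

from dataclasses import dataclass, field
from typing import Literal, Optional


@dataclass
class Case:
    case_id: str
    logical_operator: Literal["and", "or"]
    results: list[bool] = field(default_factory=list)  # per-condition outcomes


def select_case(cases: list[Case]) -> Optional[str]:
    for case in cases:
        final = all(case.results) if case.logical_operator == "and" else any(case.results)
        if final:
            return case.case_id  # first passing case wins; later cases are skipped
    return None  # no case matched -> the "false"/else branch


assert select_case([Case("a", "and", [True, False]), Case("b", "or", [False, True])]) == "b"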

View File

@@ -0,0 +1,5 @@
from .entities import IterationNodeData
from .iteration_node import IterationNode
from .iteration_start_node import IterationStartNode
__all__ = ["IterationNode", "IterationNodeData", "IterationStartNode"]

View File

@@ -1,6 +1,8 @@
from typing import Any, Optional
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState, BaseNodeData
from pydantic import Field
from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData
class IterationNodeData(BaseIterationNodeData):
@@ -26,7 +28,7 @@ class IterationState(BaseIterationState):
Iteration State.
"""
outputs: list[Any] = None
outputs: list[Any] = Field(default_factory=list)
current_output: Optional[Any] = None
class MetaData(BaseIterationState.MetaData):

View File

@@ -5,7 +5,7 @@ from typing import Any, cast
from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
BaseNodeEvent,
@@ -20,15 +20,16 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import IterationNodeData
from models.workflow import WorkflowNodeExecutionStatus
logger = logging.getLogger(__name__)
class IterationNode(BaseNode):
class IterationNode(BaseNode[IterationNodeData]):
"""
Iteration Node.
"""
@@ -36,11 +37,10 @@ class IterationNode(BaseNode):
_node_data_cls = IterationNodeData
_node_type = NodeType.ITERATION
def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
Run the node.
"""
self.node_data = cast(IterationNodeData, self.node_data)
iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
if not iterator_list_segment:
@@ -177,7 +177,7 @@ class IterationNode(BaseNode):
# remove all nodes outputs from variable pool
for node_id in iteration_graph.node_ids:
variable_pool.remove_node(node_id)
variable_pool.remove([node_id])
# move to next iteration
current_index = variable_pool.get([self.node_id, "index"])
@@ -247,7 +247,11 @@ class IterationNode(BaseNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IterationNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -273,15 +277,13 @@ class IterationNode(BaseNode):
# variable selector to variable mapping
try:
# Get node class
from core.workflow.nodes.node_mapping import node_classes
from core.workflow.nodes.node_mapping import node_type_classes_mapping
node_type = NodeType.value_of(sub_node_config.get("data", {}).get("type"))
node_cls = node_classes.get(node_type)
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
node_cls = node_type_classes_mapping.get(node_type)
if not node_cls:
continue
node_cls = cast(BaseNode, node_cls)
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config, config=sub_node_config
)

View File

@@ -1,8 +1,9 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData
from models.workflow import WorkflowNodeExecutionStatus

View File

@@ -0,0 +1,3 @@
from .knowledge_retrieval_node import KnowledgeRetrievalNode
__all__ = ["KnowledgeRetrievalNode"]

View File

@@ -2,7 +2,7 @@ from typing import Any, Literal, Optional
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.nodes.base import BaseNodeData
class RerankingModelConfig(BaseModel):

View File

@@ -14,8 +14,9 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
@@ -32,15 +33,13 @@ default_retrieval_model = {
}
class KnowledgeRetrievalNode(BaseNode):
class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
_node_data_cls = KnowledgeRetrievalNodeData
node_type = NodeType.KNOWLEDGE_RETRIEVAL
_node_type = NodeType.KNOWLEDGE_RETRIEVAL
def _run(self) -> NodeRunResult:
node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
# extract variables
variable = self.graph_runtime_state.variable_pool.get_any(node_data.query_variable_selector)
variable = self.graph_runtime_state.variable_pool.get_any(self.node_data.query_variable_selector)
query = variable
variables = {"query": query}
if not query:
@@ -49,7 +48,7 @@ class KnowledgeRetrievalNode(BaseNode):
)
# retrieve knowledge
try:
results = self._fetch_dataset_retriever(node_data=node_data, query=query)
results = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
outputs = {"result": results}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
@@ -244,7 +243,11 @@ class KnowledgeRetrievalNode(BaseNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: KnowledgeRetrievalNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: KnowledgeRetrievalNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping

View File

@@ -0,0 +1,3 @@
from .node import ListOperatorNode
__all__ = ["ListOperatorNode"]

View File

@@ -0,0 +1,56 @@
from collections.abc import Sequence
from typing import Literal
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
_Condition = Literal[
# string conditions
"contains",
"startswith",
"endswith",
"is",
"in",
"empty",
"not contains",
"not is",
"not in",
"not empty",
# number conditions
"=",
"!=",
"<",
">",
"",
"",
]
class FilterCondition(BaseModel):
key: str = ""
comparison_operator: _Condition = "contains"
value: str | Sequence[str] = ""
class FilterBy(BaseModel):
enabled: bool = False
conditions: Sequence[FilterCondition] = Field(default_factory=list)
class OrderBy(BaseModel):
enabled: bool = False
key: str = ""
value: Literal["asc", "desc"] = "asc"
class Limit(BaseModel):
enabled: bool = False
size: int = -1
class ListOperatorNodeData(BaseNodeData):
variable: Sequence[str] = Field(default_factory=list)
filter_by: FilterBy
order_by: OrderBy
limit: Limit

View File

@@ -0,0 +1,259 @@
from collections.abc import Callable, Sequence
from typing import Literal
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
from .entities import ListOperatorNodeData
class ListOperatorNode(BaseNode[ListOperatorNodeData]):
_node_data_cls = ListOperatorNodeData
_node_type = NodeType.LIST_OPERATOR
def _run(self):
inputs = {}
process_data = {}
outputs = {}
variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable)
if variable is None:
error_message = f"Variable not found for selector: {self.node_data.variable}"
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
)
if variable.value and not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
error_message = (
f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
"or ArrayStringSegment"
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
)
if isinstance(variable, ArrayFileSegment):
process_data["variable"] = [item.to_dict() for item in variable.value]
else:
process_data["variable"] = variable.value
# Filter
if self.node_data.filter_by.enabled:
for condition in self.node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment):
if not isinstance(condition.value, str):
raise ValueError(f"Invalid filter value: {condition.value}")
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
if not isinstance(condition.value, str):
raise ValueError(f"Invalid filter value: {condition.value}")
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value))
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
if isinstance(condition.value, str):
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
else:
value = condition.value
filter_func = _get_file_filter_func(
key=condition.key,
condition=condition.comparison_operator,
value=value,
)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
# Order
if self.node_data.order_by.enabled:
if isinstance(variable, ArrayStringSegment):
result = _order_string(order=self.node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
result = _order_number(order=self.node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
result = _order_file(
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
)
variable = variable.model_copy(update={"value": result})
# Slice
if self.node_data.limit.enabled:
result = variable.value[: self.node_data.limit.size]
variable = variable.model_copy(update={"value": result})
outputs = {
"result": variable.value,
"first_record": variable.value[0] if variable.value else None,
"last_record": variable.value[-1] if variable.value else None,
}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
match key:
case "size":
return lambda x: x.size
case _:
raise ValueError(f"Invalid key: {key}")
def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
match key:
case "name":
return lambda x: x.filename or ""
case "type":
return lambda x: x.type
case "extension":
return lambda x: x.extension or ""
case "mimetype":
return lambda x: x.mime_type or ""
case "transfer_method":
return lambda x: x.transfer_method
case "urL":
return lambda x: x.remote_url or ""
case _:
raise ValueError(f"Invalid key: {key}")
def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]:
match condition:
case "contains":
return _contains(value)
case "startswith":
return _startswith(value)
case "endswith":
return _endswith(value)
case "is":
return _is(value)
case "in":
return _in(value)
case "empty":
return lambda x: x == ""
case "not contains":
return lambda x: not _contains(value)(x)
case "not is":
return lambda x: not _is(value)(x)
case "not in":
return lambda x: not _in(value)(x)
case "not empty":
return lambda x: x != ""
case _:
raise ValueError(f"Invalid condition: {condition}")
def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]:
match condition:
case "in":
return _in(value)
case "not in":
return lambda x: not _in(value)(x)
case _:
raise ValueError(f"Invalid condition: {condition}")
def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]:
match condition:
case "=":
return _eq(value)
case "!=":
return _ne(value)
case "<":
return _lt(value)
case "":
return _le(value)
case ">":
return _gt(value)
case "":
return _ge(value)
case _:
raise ValueError(f"Invalid condition: {condition}")
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
extract_func = _get_file_extract_string_func(key=key)
return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
if key in {"type", "transfer_method"} and isinstance(value, Sequence):
extract_func = _get_file_extract_string_func(key=key)
return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x))
elif key == "size" and isinstance(value, str):
extract_func = _get_file_extract_number_func(key=key)
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x))
else:
raise ValueError(f"Invalid key: {key}")
def _contains(value: str):
return lambda x: value in x
def _startswith(value: str):
return lambda x: x.startswith(value)
def _endswith(value: str):
return lambda x: x.endswith(value)
def _is(value: str):
return lambda x: x == value
def _in(value: str | Sequence[str]):
return lambda x: x in value
def _eq(value: int | float):
return lambda x: x == value
def _ne(value: int | float):
return lambda x: x != value
def _lt(value: int | float):
return lambda x: x < value
def _le(value: int | float):
return lambda x: x <= value
def _gt(value: int | float):
return lambda x: x > value
def _ge(value: int | float):
return lambda x: x >= value
def _order_number(*, order: Literal["asc", "desc"], array: Sequence[int | float]):
return sorted(array, key=lambda x: x, reverse=order == "desc")
def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]):
return sorted(array, key=lambda x: x, reverse=order == "desc")
def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]):
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "urL"}:
extract_func = _get_file_extract_string_func(key=order_by)
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
elif order_by == "size":
extract_func = _get_file_extract_number_func(key=order_by)
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
else:
raise ValueError(f"Invalid order key: {order_by}")

View File

@@ -0,0 +1,17 @@
from .entities import (
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
ModelConfig,
VisionConfig,
)
from .node import LLMNode
__all__ = [
"LLMNode",
"LLMNodeChatModelMessage",
"LLMNodeCompletionModelPromptTemplate",
"LLMNodeData",
"ModelConfig",
"VisionConfig",
]

View File

@@ -1,17 +1,15 @@
from typing import Any, Literal, Optional, Union
from collections.abc import Sequence
from typing import Any, Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field
from core.model_runtime.entities import ImagePromptMessageContent
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.nodes.base import BaseNodeData
class ModelConfig(BaseModel):
"""
Model Config.
"""
provider: str
name: str
mode: str
@@ -19,62 +17,36 @@ class ModelConfig(BaseModel):
class ContextConfig(BaseModel):
"""
Context Config.
"""
enabled: bool
variable_selector: Optional[list[str]] = None
class VisionConfigOptions(BaseModel):
variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"])
detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH
class VisionConfig(BaseModel):
"""
Vision Config.
"""
class Configs(BaseModel):
"""
Configs.
"""
detail: Literal["low", "high"]
enabled: bool
configs: Optional[Configs] = None
enabled: bool = False
configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions)
class PromptConfig(BaseModel):
"""
Prompt Config.
"""
jinja2_variables: Optional[list[VariableSelector]] = None
class LLMNodeChatModelMessage(ChatModelMessage):
"""
LLM Node Chat Model Message.
"""
jinja2_text: Optional[str] = None
class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
"""
LLM Node Completion Model Prompt Template.
"""
jinja2_text: Optional[str] = None
class LLMNodeData(BaseNodeData):
"""
LLM Node Data.
"""
model: ModelConfig
prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate]
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
prompt_config: Optional[PromptConfig] = None
memory: Optional[MemoryConfig] = None
context: ContextConfig
vision: VisionConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
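The switch to `Field(default_factory=...)` means a node config that omits `vision` entirely still deserializes to a fully populated default. A standalone sketch with simplified stand-in models (detail is a plain str here, not the ImagePromptMessageContent enum).

from collections.abc import Sequence

from pydantic import BaseModel, Field


class VisionOptions(BaseModel):
    variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"])
    detail: str = "high"


class Vision(BaseModel):
    enabled: bool = False
    configs: VisionOptions = Field(default_factory=VisionOptions)


# Omitting "vision" in a node config now yields a fully populated default.
assert Vision().configs.variable_selector == ["sys", "files"]
assert Vision.model_validate({}).enabled is False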

View File

@@ -1,39 +1,40 @@
import json
from collections.abc import Generator, Mapping, Sequence
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Optional, cast
from pydantic import BaseModel
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import (
from core.model_runtime.entities import (
AudioPromptMessageContent,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
TextPromptMessageContent,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.variables import ArrayAnySegment, ArrayFileSegment, FileSegment
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.entities import (
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
ModelConfig,
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import (
ModelInvokeCompletedEvent,
NodeEvent,
RunCompletedEvent,
RunRetrieverResourceEvent,
RunStreamChunkEvent,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
@@ -41,44 +42,34 @@ from models.model import Conversation
from models.provider import Provider, ProviderType
from models.workflow import WorkflowNodeExecutionStatus
from .entities import (
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
ModelConfig,
)
if TYPE_CHECKING:
from core.file.file_obj import FileVar
from core.file.models import File
class ModelInvokeCompleted(BaseModel):
"""
Model invoke completed
"""
text: str
usage: LLMUsage
finish_reason: Optional[str] = None
class LLMNode(BaseNode):
class LLMNode(BaseNode[LLMNodeData]):
_node_data_cls = LLMNodeData
_node_type = NodeType.LLM
def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
"""
Run node
:return:
"""
node_data = cast(LLMNodeData, deepcopy(self.node_data))
variable_pool = self.graph_runtime_state.variable_pool
def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]:
node_inputs = None
process_data = None
try:
# init messages template
node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template)
self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
# fetch variables and fetch values from variable pool
inputs = self._fetch_inputs(node_data, variable_pool)
inputs = self._fetch_inputs(node_data=self.node_data)
# fetch jinja2 inputs
jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool)
jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data)
# merge inputs
inputs.update(jinja_inputs)
@@ -86,13 +77,17 @@ class LLMNode(BaseNode):
node_inputs = {}
# fetch files
files = self._fetch_files(node_data, variable_pool)
files = (
self._fetch_files(selector=self.node_data.vision.configs.variable_selector)
if self.node_data.vision.enabled
else []
)
if files:
node_inputs["#files#"] = [file.to_dict() for file in files]
# fetch context value
generator = self._fetch_context(node_data, variable_pool)
generator = self._fetch_context(node_data=self.node_data)
context = None
for event in generator:
if isinstance(event, RunRetrieverResourceEvent):
@@ -103,21 +98,30 @@ class LLMNode(BaseNode):
node_inputs["#context#"] = context # type: ignore
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.model)
model_instance, model_config = self._fetch_model_config(self.node_data.model)
# fetch memory
memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
memory = self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance)
# fetch prompt messages
if self.node_data.memory:
query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
if not query:
raise ValueError("Query not found")
query = query.text
else:
query = None
prompt_messages, stop = self._fetch_prompt_messages(
node_data=node_data,
query=variable_pool.get_any(["sys", SystemVariableKey.QUERY.value]) if node_data.memory else None,
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
system_query=query,
inputs=inputs,
files=files,
context=context,
memory=memory,
model_config=model_config,
vision_detail=self.node_data.vision.configs.detail,
prompt_template=self.node_data.prompt_template,
memory_config=self.node_data.memory,
)
process_data = {
@@ -131,7 +135,7 @@ class LLMNode(BaseNode):
# handle invoke result
generator = self._invoke_llm(
node_data_model=node_data.model,
node_data_model=self.node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
@@ -143,7 +147,7 @@ class LLMNode(BaseNode):
for event in generator:
if isinstance(event, RunStreamChunkEvent):
yield event
elif isinstance(event, ModelInvokeCompleted):
elif isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
usage = event.usage
finish_reason = event.finish_reason
@@ -182,15 +186,7 @@ class LLMNode(BaseNode):
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None,
) -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
"""
Invoke large language model
:param node_data_model: node data model
:param model_instance: model instance
:param prompt_messages: prompt messages
:param stop: stop
:return:
"""
) -> Generator[NodeEvent, None, None]:
db.session.close()
invoke_result = model_instance.invoke_llm(
@@ -207,20 +203,13 @@ class LLMNode(BaseNode):
usage = LLMUsage.empty_usage()
for event in generator:
yield event
if isinstance(event, ModelInvokeCompleted):
if isinstance(event, ModelInvokeCompletedEvent):
usage = event.usage
# deduct quota
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
def _handle_invoke_result(
self, invoke_result: LLMResult | Generator
) -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
"""
Handle invoke result
:param invoke_result: invoke result
:return:
"""
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
if isinstance(invoke_result, LLMResult):
return
@@ -250,18 +239,11 @@ class LLMNode(BaseNode):
if not usage:
usage = LLMUsage.empty_usage()
yield ModelInvokeCompleted(text=full_text, usage=usage, finish_reason=finish_reason)
yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason)
def _transform_chat_messages(
self, messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
"""
Transform chat messages
:param messages: chat messages
:return:
"""
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
if messages.edition_type == "jinja2" and messages.jinja2_text:
messages.text = messages.jinja2_text
@@ -274,13 +256,7 @@ class LLMNode(BaseNode):
return messages
def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
"""
Fetch jinja inputs
:param node_data: node data
:param variable_pool: variable pool
:return:
"""
def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
variables = {}
if not node_data.prompt_config:
@@ -288,7 +264,7 @@ class LLMNode(BaseNode):
for variable_selector in node_data.prompt_config.jinja2_variables or []:
variable = variable_selector.variable
value = variable_pool.get_any(variable_selector.value_selector)
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
def parse_dict(d: dict) -> str:
"""
@@ -330,13 +306,7 @@ class LLMNode(BaseNode):
return variables
def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
"""
Fetch inputs
:param node_data: node data
:param variable_pool: variable pool
:return:
"""
def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
inputs = {}
prompt_template = node_data.prompt_template
@@ -350,7 +320,7 @@ class LLMNode(BaseNode):
variable_selectors = variable_template_parser.extract_variable_selectors()
for variable_selector in variable_selectors:
variable_value = variable_pool.get_any(variable_selector.value_selector)
variable_value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
if variable_value is None:
raise ValueError(f"Variable {variable_selector.variable} not found")
@@ -362,7 +332,7 @@ class LLMNode(BaseNode):
template=memory.query_prompt_template
).extract_variable_selectors()
for variable_selector in query_variable_selectors:
variable_value = variable_pool.get_any(variable_selector.value_selector)
variable_value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
if variable_value is None:
raise ValueError(f"Variable {variable_selector.variable} not found")
@@ -370,36 +340,28 @@ class LLMNode(BaseNode):
return inputs
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]:
"""
Fetch files
:param node_data: node data
:param variable_pool: variable pool
:return:
"""
if not node_data.vision.enabled:
def _fetch_files(self, *, selector: Sequence[str]) -> Sequence["File"]:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is None:
return []
files = variable_pool.get_any(["sys", SystemVariableKey.FILES.value])
if not files:
if isinstance(variable, FileSegment):
return [variable.value]
if isinstance(variable, ArrayFileSegment):
return variable.value
# FIXME: Temporary fix for empty array,
# all variables added to variable pool should be a Segment instance.
if isinstance(variable, ArrayAnySegment) and len(variable.value) == 0:
return []
raise ValueError(f"Invalid variable type: {type(variable)}")
return files
def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]:
"""
Fetch context
:param node_data: node data
:param variable_pool: variable pool
:return:
"""
def _fetch_context(self, node_data: LLMNodeData):
if not node_data.context.enabled:
return
if not node_data.context.variable_selector:
return
context_value = variable_pool.get_any(node_data.context.variable_selector)
context_value = self.graph_runtime_state.variable_pool.get_any(node_data.context.variable_selector)
if context_value:
if isinstance(context_value, str):
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value)
@@ -424,11 +386,6 @@ class LLMNode(BaseNode):
)
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
"""
Convert to original retriever resource, temp.
:param context_dict: context dict
:return:
"""
if (
"metadata" in context_dict
and "_source" in context_dict["metadata"]
@@ -451,6 +408,7 @@ class LLMNode(BaseNode):
"segment_position": metadata.get("segment_position"),
"index_node_hash": metadata.get("segment_index_node_hash"),
"content": context_dict.get("content"),
"page": metadata.get("page"),
}
return source
@@ -460,11 +418,6 @@ class LLMNode(BaseNode):
def _fetch_model_config(
self, node_data_model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
"""
Fetch model config
:param node_data_model: node data model
:return:
"""
model_name = node_data_model.name
provider_name = node_data_model.provider
@@ -523,19 +476,15 @@ class LLMNode(BaseNode):
)
def _fetch_memory(
self, node_data_memory: Optional[MemoryConfig], variable_pool: VariablePool, model_instance: ModelInstance
self, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
) -> Optional[TokenBufferMemory]:
"""
Fetch memory
:param node_data_memory: node data memory
:param variable_pool: variable pool
:return:
"""
if not node_data_memory:
return None
# get conversation id
conversation_id = variable_pool.get_any(["sys", SystemVariableKey.CONVERSATION_ID.value])
conversation_id = self.graph_runtime_state.variable_pool.get_any(
["sys", SystemVariableKey.CONVERSATION_ID.value]
)
if conversation_id is None:
return None
@@ -555,43 +504,31 @@ class LLMNode(BaseNode):
def _fetch_prompt_messages(
self,
node_data: LLMNodeData,
query: Optional[str],
query_prompt_template: Optional[str],
inputs: dict[str, str],
files: list["FileVar"],
context: Optional[str],
memory: Optional[TokenBufferMemory],
*,
system_query: str | None = None,
inputs: dict[str, str] | None = None,
files: Sequence["File"],
context: str | None = None,
memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity,
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
memory_config: MemoryConfig | None = None,
vision_detail: ImagePromptMessageContent.DETAIL,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
"""
Fetch prompt messages
:param node_data: node data
:param query: query
:param query_prompt_template: query prompt template
:param inputs: inputs
:param files: files
:param context: context
:param memory: memory
:param model_config: model config
:return:
"""
inputs = inputs or {}
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
prompt_messages = prompt_transform.get_prompt(
prompt_template=node_data.prompt_template,
prompt_template=prompt_template,
inputs=inputs,
query=query or "",
query=system_query or "",
files=files,
context=context,
memory_config=node_data.memory,
memory_config=memory_config,
memory=memory,
model_config=model_config,
query_prompt_template=query_prompt_template,
)
stop = model_config.stop
vision_enabled = node_data.vision.enabled
vision_detail = node_data.vision.configs.detail if node_data.vision.configs else None
filtered_prompt_messages = []
for prompt_message in prompt_messages:
if prompt_message.is_empty():
@@ -599,17 +536,13 @@ class LLMNode(BaseNode):
if not isinstance(prompt_message.content, str):
prompt_message_content = []
for content_item in prompt_message.content:
if (
vision_enabled
and content_item.type == PromptMessageContentType.IMAGE
and isinstance(content_item, ImagePromptMessageContent)
):
# Override vision config if LLM node has vision config
if vision_detail:
content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail)
for content_item in prompt_message.content or []:
if isinstance(content_item, ImagePromptMessageContent):
# Override vision config if LLM node has vision config,
# because vision detail comes from the FileUpload feature configuration.
content_item.detail = vision_detail
prompt_message_content.append(content_item)
elif content_item.type == PromptMessageContentType.TEXT:
elif isinstance(content_item, TextPromptMessageContent | AudioPromptMessageContent):
prompt_message_content.append(content_item)
if len(prompt_message_content) > 1:
@@ -631,13 +564,6 @@ class LLMNode(BaseNode):
@classmethod
def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
"""
Deduct LLM quota
:param tenant_id: tenant id
:param model_instance: model instance
:param usage: usage
:return:
"""
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
@@ -668,7 +594,7 @@ class LLMNode(BaseNode):
else:
used_quota = 1
if used_quota is not None:
if used_quota is not None and system_configuration.current_quota_type is not None:
db.session.query(Provider).filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_instance.provider,
@@ -680,27 +606,28 @@ class LLMNode(BaseNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: LLMNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: LLMNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
prompt_template = node_data.prompt_template
variable_selectors = []
if isinstance(prompt_template, list):
if isinstance(prompt_template, list) and all(
isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template
):
for prompt in prompt_template:
if prompt.edition_type != "jinja2":
variable_template_parser = VariableTemplateParser(template=prompt.text)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
else:
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
if prompt_template.edition_type != "jinja2":
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
variable_selectors = variable_template_parser.extract_variable_selectors()
else:
raise ValueError(f"Invalid prompt template type: {type(prompt_template)}")
variable_mapping = {}
for variable_selector in variable_selectors:
@@ -745,11 +672,6 @@ class LLMNode(BaseNode):
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""
Get default config of node.
:param filters: filter by node config parameters.
:return:
"""
return {
"type": "llm",
"config": {

View File

@@ -1,4 +1,4 @@
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState
class LoopNodeData(BaseIterationNodeData):

View File

@@ -1,12 +1,12 @@
from typing import Any
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.loop.entities import LoopNodeData, LoopState
from core.workflow.utils.condition.entities import Condition
class LoopNode(BaseNode):
class LoopNode(BaseNode[LoopNodeData]):
"""
Loop Node.
"""

View File

@@ -1,22 +1,24 @@
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.iteration.iteration_start_node import IterationStartNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
from core.workflow.nodes.answer import AnswerNode
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.code import CodeNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode
from core.workflow.nodes.end import EndNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.http_request import HttpRequestNode
from core.workflow.nodes.if_else import IfElseNode
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
from core.workflow.nodes.list_operator import ListOperatorNode
from core.workflow.nodes.llm import LLMNode
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from core.workflow.nodes.question_classifier import QuestionClassifierNode
from core.workflow.nodes.start import StartNode
from core.workflow.nodes.template_transform import TemplateTransformNode
from core.workflow.nodes.tool import ToolNode
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from core.workflow.nodes.variable_assigner import VariableAssignerNode
node_classes = {
node_type_classes_mapping: dict[NodeType, type[BaseNode]] = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.ANSWER: AnswerNode,
@@ -34,4 +36,6 @@ node_classes = {
NodeType.ITERATION_START: IterationStartNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
NodeType.DOCUMENT_EXTRACTOR: DocumentExtractorNode,
NodeType.LIST_OPERATOR: ListOperatorNode,
}
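The renamed `node_type_classes_mapping` is an enum-keyed registry; callers resolve a raw type string via the `NodeType` constructor (replacing the removed `value_of` helper) and then look the class up, as in the IterationNode hunk earlier in this commit. A toy version with hypothetical node types and stand-in classes.

from enum import Enum


class NodeType(str, Enum):
    START = "start"
    LLM = "llm"


class StartNode: ...


class LLMNode: ...


node_type_classes_mapping: dict[NodeType, type] = {
    NodeType.START: StartNode,
    NodeType.LLM: LLMNode,
}

# NodeType("llm") replaces the removed NodeType.value_of helper: the enum
# constructor resolves a raw value to its member (or raises ValueError).
assert node_type_classes_mapping[NodeType("llm")] is LLMNode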

View File

@@ -0,0 +1,3 @@
from .parameter_extractor_node import ParameterExtractorNode
__all__ = ["ParameterExtractorNode"]

View File

@@ -1,20 +1,10 @@
from typing import Any, Literal, Optional
from pydantic import BaseModel, field_validator
from pydantic import BaseModel, Field, field_validator
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
class ModelConfig(BaseModel):
"""
Model Config.
"""
provider: str
name: str
mode: str
completion_params: dict[str, Any] = {}
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm import ModelConfig, VisionConfig
class ParameterConfig(BaseModel):
@@ -49,6 +39,7 @@ class ParameterExtractorNodeData(BaseNodeData):
instruction: Optional[str] = None
memory: Optional[MemoryConfig] = None
reasoning_mode: Literal["function_call", "prompt"]
vision: VisionConfig = Field(default_factory=VisionConfig)
@field_validator("reasoning_mode", mode="before")
@classmethod
@@ -64,7 +55,7 @@ class ParameterExtractorNodeData(BaseNodeData):
parameters = {"type": "object", "properties": {}, "required": []}
for parameter in self.parameters:
parameter_schema = {"description": parameter.description}
parameter_schema: dict[str, Any] = {"description": parameter.description}
if parameter.type in {"string", "select"}:
parameter_schema["type"] = "string"

View File

@@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
@@ -22,12 +23,16 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from core.workflow.nodes.parameter_extractor.prompts import (
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.llm import LLMNode, ModelConfig
from core.workflow.utils import variable_template_parser
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus
from .entities import ParameterExtractorNodeData
from .prompts import (
CHAT_EXAMPLE,
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
COMPLETION_GENERATE_JSON_PROMPT,
@@ -36,9 +41,6 @@ from core.workflow.nodes.parameter_extractor.prompts import (
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT,
FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus
class ParameterExtractorNode(LLMNode):
@@ -65,33 +67,39 @@ class ParameterExtractorNode(LLMNode):
}
}
def _run(self) -> NodeRunResult:
def _run(self):
"""
Run the node.
"""
node_data = cast(ParameterExtractorNodeData, self.node_data)
variable = self.graph_runtime_state.variable_pool.get_any(node_data.query)
if not variable:
raise ValueError("Input variable content not found or is empty")
query = variable
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else ""
inputs = {
"query": query,
"parameters": jsonable_encoder(node_data.parameters),
"instruction": jsonable_encoder(node_data.instruction),
}
files = (
self._fetch_files(
selector=node_data.vision.configs.variable_selector,
)
if node_data.vision.enabled
else []
)
model_instance, model_config = self._fetch_model_config(node_data.model)
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
raise ValueError("Model is not a Large Language Model")
llm_model = model_instance.model_type_instance
model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
model_schema = llm_model.get_model_schema(
model=model_config.model,
credentials=model_config.credentials,
)
if not model_schema:
raise ValueError("Model schema not found")
# fetch memory
memory = self._fetch_memory(node_data.memory, self.graph_runtime_state.variable_pool, model_instance)
memory = self._fetch_memory(
node_data_memory=node_data.memory,
model_instance=model_instance,
)
if (
set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}
@@ -99,15 +107,33 @@ class ParameterExtractorNode(LLMNode):
):
# use function call
prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
node_data, query, self.graph_runtime_state.variable_pool, model_config, memory
node_data=node_data,
query=query,
variable_pool=self.graph_runtime_state.variable_pool,
model_config=model_config,
memory=memory,
files=files,
)
else:
# use prompt engineering
prompt_messages = self._generate_prompt_engineering_prompt(
node_data, query, self.graph_runtime_state.variable_pool, model_config, memory
data=node_data,
query=query,
variable_pool=self.graph_runtime_state.variable_pool,
model_config=model_config,
memory=memory,
files=files,
)
prompt_message_tools = []
inputs = {
"query": query,
"files": [f.to_dict() for f in files],
"parameters": jsonable_encoder(node_data.parameters),
"instruction": jsonable_encoder(node_data.instruction),
}
process_data = {
"model_mode": model_config.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
@@ -119,7 +145,7 @@ class ParameterExtractorNode(LLMNode):
}
try:
text, usage, tool_call = self._invoke_llm(
text, usage, tool_call = self._invoke(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
@@ -150,12 +176,12 @@ class ParameterExtractorNode(LLMNode):
error = "Failed to extract result from function call or text response, using empty result."
try:
result = self._validate_result(node_data, result)
result = self._validate_result(data=node_data, result=result or {})
except Exception as e:
error = str(e)
# transform result into standard format
result = self._transform_result(node_data, result)
result = self._transform_result(data=node_data, result=result or {})
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -170,7 +196,7 @@ class ParameterExtractorNode(LLMNode):
llm_usage=usage,
)
def _invoke_llm(
def _invoke(
self,
node_data_model: ModelConfig,
model_instance: ModelInstance,
@@ -178,14 +204,6 @@ class ParameterExtractorNode(LLMNode):
tools: list[PromptMessageTool],
stop: list[str],
) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
"""
Invoke large language model
:param node_data_model: node data model
:param model_instance: model instance
:param prompt_messages: prompt messages
:param stop: stop
:return:
"""
db.session.close()
invoke_result = model_instance.invoke_llm(
@@ -202,6 +220,9 @@ class ParameterExtractorNode(LLMNode):
raise ValueError(f"Invalid invoke result: {invoke_result}")
text = invoke_result.message.content
if not isinstance(text, str):
raise ValueError(f"Invalid text content type: {type(text)}. Expected str.")
usage = invoke_result.usage
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
@@ -217,6 +238,7 @@ class ParameterExtractorNode(LLMNode):
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
files: Sequence[File],
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
"""
Generate function call prompt.
@@ -234,7 +256,7 @@ class ParameterExtractorNode(LLMNode):
prompt_template=prompt_template,
inputs={},
query="",
files=[],
files=files,
context="",
memory_config=node_data.memory,
memory=None,
@@ -296,6 +318,7 @@ class ParameterExtractorNode(LLMNode):
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
files: Sequence[File],
) -> list[PromptMessage]:
"""
Generate prompt engineering prompt.
@@ -303,9 +326,23 @@ class ParameterExtractorNode(LLMNode):
model_mode = ModelMode.value_of(data.model.mode)
if model_mode == ModelMode.COMPLETION:
return self._generate_prompt_engineering_completion_prompt(data, query, variable_pool, model_config, memory)
return self._generate_prompt_engineering_completion_prompt(
node_data=data,
query=query,
variable_pool=variable_pool,
model_config=model_config,
memory=memory,
files=files,
)
elif model_mode == ModelMode.CHAT:
return self._generate_prompt_engineering_chat_prompt(data, query, variable_pool, model_config, memory)
return self._generate_prompt_engineering_chat_prompt(
node_data=data,
query=query,
variable_pool=variable_pool,
model_config=model_config,
memory=memory,
files=files,
)
else:
raise ValueError(f"Invalid model mode: {model_mode}")
@@ -316,20 +353,23 @@ class ParameterExtractorNode(LLMNode):
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
files: Sequence[File],
) -> list[PromptMessage]:
"""
Generate completion prompt.
"""
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "")
rest_token = self._calculate_rest_token(
node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
)
prompt_template = self._get_prompt_engineering_prompt_template(
node_data, query, variable_pool, memory, rest_token
node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token
)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={"structure": json.dumps(node_data.get_parameter_json_schema())},
query="",
files=[],
files=files,
context="",
memory_config=node_data.memory,
memory=memory,
@@ -345,27 +385,30 @@ class ParameterExtractorNode(LLMNode):
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
files: Sequence[File],
) -> list[PromptMessage]:
"""
Generate chat prompt.
"""
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "")
rest_token = self._calculate_rest_token(
node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
)
prompt_template = self._get_prompt_engineering_prompt_template(
node_data,
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
node_data=node_data,
query=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
structure=json.dumps(node_data.get_parameter_json_schema()), text=query
),
variable_pool,
memory,
rest_token,
variable_pool=variable_pool,
memory=memory,
max_token_limit=rest_token,
)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
query="",
files=[],
files=files,
context="",
memory_config=node_data.memory,
memory=None,
@@ -425,10 +468,11 @@ class ParameterExtractorNode(LLMNode):
raise ValueError(f"Invalid `string` value for parameter {parameter.name}")
if parameter.type.startswith("array"):
if not isinstance(result.get(parameter.name), list):
parameters = result.get(parameter.name)
if not isinstance(parameters, list):
raise ValueError(f"Invalid `array` value for parameter {parameter.name}")
nested_type = parameter.type[6:-1]
for item in result.get(parameter.name):
for item in parameters:
if nested_type == "number" and not isinstance(item, int | float):
raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
if nested_type == "string" and not isinstance(item, str):
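
The hunk above binds `result.get(parameter.name)` to a local before narrowing it, which satisfies the type checker and avoids a second dict lookup. A standalone sketch of the validation pattern, with `Parameter` and the type strings as simplified stand-ins rather than Dify's entities:

from dataclasses import dataclass

@dataclass
class Parameter:
    name: str
    type: str  # e.g. "array[number]" or "array[string]"

def validate_array(result: dict, parameter: Parameter) -> None:
    parameters = result.get(parameter.name)  # bind once, then narrow the type
    if not isinstance(parameters, list):
        raise ValueError(f"Invalid `array` value for parameter {parameter.name}")
    nested_type = parameter.type[6:-1]  # "array[number]" -> "number"
    for item in parameters:
        if nested_type == "number" and not isinstance(item, int | float):
            raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
        if nested_type == "string" and not isinstance(item, str):
            raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}")

validate_array({"tags": ["a", "b"]}, Parameter(name="tags", type="array[string]"))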
@@ -565,18 +609,6 @@ class ParameterExtractorNode(LLMNode):
return result
def _render_instruction(self, instruction: str, variable_pool: VariablePool) -> str:
"""
Render instruction.
"""
variable_template_parser = VariableTemplateParser(instruction)
inputs = {}
for selector in variable_template_parser.extract_variable_selectors():
variable = variable_pool.get_any(selector.value_selector)
inputs[selector.variable] = variable
return variable_template_parser.format(inputs)
def _get_function_calling_prompt_template(
self,
node_data: ParameterExtractorNodeData,
@@ -588,9 +620,9 @@ class ParameterExtractorNode(LLMNode):
model_mode = ModelMode.value_of(node_data.model.mode)
input_text = query
memory_str = ""
instruction = self._render_instruction(node_data.instruction or "", variable_pool)
instruction = variable_pool.convert_template(node_data.instruction or "").text
if memory:
if memory and node_data.memory and node_data.memory.window:
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
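
The handwritten `_render_instruction` helper (deleted above) gives way to the variable pool's own `convert_template(...).text`. Both render `{{#node_id.variable#}}` references against pooled values; a rough sketch of the idea, where the regex and the flat-dict pool are assumptions rather than the real API:

import re

TEMPLATE_RE = re.compile(r"\{\{#([a-zA-Z0-9_.]+)#\}\}")

def render_template(template: str, pool: dict[str, str]) -> str:
    # Substitute each {{#node_id.variable#}} reference with its pool value.
    return TEMPLATE_RE.sub(lambda m: pool.get(m.group(1), ""), template)

assert render_template("Summarize: {{#start.text#}}", {"start.text": "hello"}) == "Summarize: hello"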
@@ -611,13 +643,13 @@ class ParameterExtractorNode(LLMNode):
variable_pool: VariablePool,
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
):
model_mode = ModelMode.value_of(node_data.model.mode)
input_text = query
memory_str = ""
instruction = self._render_instruction(node_data.instruction or "", variable_pool)
instruction = variable_pool.convert_template(node_data.instruction or "").text
if memory:
if memory and node_data.memory and node_data.memory.window:
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
@@ -691,7 +723,7 @@ class ParameterExtractorNode(LLMNode):
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
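
`use_template` is optional on a parameter rule, so passing it straight to `parameters.get(...)` can hand the lookup `None`; coalescing to the empty string keeps the call well-typed without changing the runtime result. A sketch of the fallback chain, with the rule class as a stand-in:

from dataclasses import dataclass
from typing import Optional

@dataclass
class ParameterRule:
    name: str
    use_template: Optional[str] = None

def resolve_max_tokens(parameters: dict, rule: ParameterRule) -> int:
    # `use_template` may be None; coalescing to "" keeps the lookup type-safe.
    return (parameters.get(rule.name) or parameters.get(rule.use_template or "")) or 0

assert resolve_max_tokens({"max_tokens": 512}, ParameterRule("max_tokens")) == 512
assert resolve_max_tokens({}, ParameterRule("max_new_tokens")) == 0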
@@ -712,7 +744,11 @@ class ParameterExtractorNode(LLMNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: ParameterExtractorNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ParameterExtractorNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -721,11 +757,11 @@ class ParameterExtractorNode(LLMNode):
:param node_data: node data
:return:
"""
variable_mapping = {"query": node_data.query}
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
if node_data.instruction:
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
for selector in variable_template_parser.extract_variable_selectors():
selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction)
for selector in selectors:
variable_mapping[selector.variable] = selector.value_selector
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
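
As elsewhere in this commit, the extraction hook becomes keyword-only and the mapping is explicitly typed; its keys are prefixed with the owning node id. A small sketch of the contract, shapes illustrative:

from collections.abc import Mapping, Sequence

def extract_mapping(*, node_id: str, query: Sequence[str]) -> Mapping[str, Sequence[str]]:
    variable_mapping: dict[str, Sequence[str]] = {"query": query}
    # Prefix every key with the owning node id, mirroring the hunk above.
    return {f"{node_id}.{key}": value for key, value in variable_mapping.items()}

print(extract_mapping(node_id="pe_1", query=["sys", "query"]))
# {'pe_1.query': ['sys', 'query']}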

View File

@@ -0,0 +1,4 @@
from .entities import QuestionClassifierNodeData
from .question_classifier_node import QuestionClassifierNode
__all__ = ["QuestionClassifierNodeData", "QuestionClassifierNode"]

View File

@@ -1,39 +1,21 @@
from typing import Any, Optional
from typing import Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
class ModelConfig(BaseModel):
"""
Model Config.
"""
provider: str
name: str
mode: str
completion_params: dict[str, Any] = {}
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm import ModelConfig, VisionConfig
class ClassConfig(BaseModel):
"""
Class Config.
"""
id: str
name: str
class QuestionClassifierNodeData(BaseNodeData):
"""
Question Classifier Node Data.
"""
query_variable_selector: list[str]
type: str = "question-classifier"
model: ModelConfig
classes: list[ClassConfig]
instruction: Optional[str] = None
memory: Optional[MemoryConfig] = None
vision: VisionConfig = Field(default_factory=VisionConfig)
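
`Field(default_factory=VisionConfig)` constructs a fresh `VisionConfig` for each node-data instance and keeps the default in step with `VisionConfig`'s own field defaults. A minimal pydantic sketch; the `enabled` field is an assumed shape:

from pydantic import BaseModel, Field

class VisionConfig(BaseModel):
    enabled: bool = False  # assumed shape; the real config also carries detail, selector, etc.

class NodeData(BaseModel):
    vision: VisionConfig = Field(default_factory=VisionConfig)

data = NodeData()           # vision is constructed fresh for this instance
print(data.vision.enabled)  # False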

View File

@@ -1,25 +1,30 @@
import json
import logging
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.llm.llm_node import LLMNode, ModelInvokeCompleted
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
from core.workflow.nodes.question_classifier.template_prompts import (
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import ModelInvokeCompletedEvent
from core.workflow.nodes.llm import (
LLMNode,
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.workflow import WorkflowNodeExecutionStatus
from .entities import QuestionClassifierNodeData
from .template_prompts import (
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2,
QUESTION_CLASSIFIER_COMPLETION_PROMPT,
@@ -28,46 +33,77 @@ from core.workflow.nodes.question_classifier.template_prompts import (
QUESTION_CLASSIFIER_USER_PROMPT_2,
QUESTION_CLASSIFIER_USER_PROMPT_3,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.workflow import WorkflowNodeExecutionStatus
if TYPE_CHECKING:
from core.file import File
class QuestionClassifierNode(LLMNode):
_node_data_cls = QuestionClassifierNodeData
node_type = NodeType.QUESTION_CLASSIFIER
_node_type = NodeType.QUESTION_CLASSIFIER
def _run(self) -> NodeRunResult:
node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data)
node_data = cast(QuestionClassifierNodeData, node_data)
def _run(self):
node_data = cast(QuestionClassifierNodeData, self.node_data)
variable_pool = self.graph_runtime_state.variable_pool
# extract variables
variable = variable_pool.get(node_data.query_variable_selector)
variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None
query = variable.value if variable else None
variables = {"query": query}
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.model)
# fetch memory
memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
memory = self._fetch_memory(
node_data_memory=node_data.memory,
model_instance=model_instance,
)
# fetch instruction
instruction = self._format_instruction(node_data.instruction, variable_pool) if node_data.instruction else ""
node_data.instruction = instruction
node_data.instruction = node_data.instruction or ""
node_data.instruction = variable_pool.convert_template(node_data.instruction).text
files: Sequence[File] = (
self._fetch_files(
selector=node_data.vision.configs.variable_selector,
)
if node_data.vision.enabled
else []
)
# fetch prompt messages
prompt_messages, stop = self._fetch_prompt(
node_data=node_data, context="", query=query, memory=memory, model_config=model_config
rest_token = self._calculate_rest_token(
node_data=node_data,
query=query or "",
model_config=model_config,
context="",
)
prompt_template = self._get_prompt_template(
node_data=node_data,
query=query or "",
memory=memory,
max_token_limit=rest_token,
)
prompt_messages, stop = self._fetch_prompt_messages(
prompt_template=prompt_template,
system_query=query,
memory=memory,
model_config=model_config,
files=files,
vision_detail=node_data.vision.configs.detail,
)
# handle invoke result
generator = self._invoke_llm(
node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
)
result_text = ""
usage = LLMUsage.empty_usage()
finish_reason = None
for event in generator:
if isinstance(event, ModelInvokeCompleted):
if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
usage = event.usage
finish_reason = event.finish_reason
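
The node now consumes `_invoke_llm` as a stream of events and folds it down to the terminal `ModelInvokeCompletedEvent` for text, usage, and finish reason. The consumption pattern in isolation, with the event class as a stand-in:

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ModelInvokeCompletedEvent:  # stand-in for core.workflow.nodes.event
    text: str
    usage: dict = field(default_factory=dict)
    finish_reason: Optional[str] = None

def consume(generator):
    result_text, usage, finish_reason = "", {}, None
    for event in generator:
        if isinstance(event, ModelInvokeCompletedEvent):
            result_text, usage, finish_reason = event.text, event.usage, event.finish_reason
    return result_text, usage, finish_reason

text, _, _ = consume(iter([ModelInvokeCompletedEvent(text="category_1")]))
assert text == "category_1"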
@@ -129,7 +165,11 @@ class QuestionClassifierNode(LLMNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: QuestionClassifierNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: QuestionClassifierNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -159,40 +199,6 @@ class QuestionClassifierNode(LLMNode):
"""
return {"type": "question-classifier", "config": {"instructions": ""}}
def _fetch_prompt(
self,
node_data: QuestionClassifierNodeData,
query: str,
context: Optional[str],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
"""
Fetch prompt
:param node_data: node data
:param query: inputs
:param context: context
:param memory: memory
:param model_config: model config
:return:
"""
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(node_data, query, model_config, context)
prompt_template = self._get_prompt_template(node_data, query, memory, rest_token)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
query="",
files=[],
context=context,
memory_config=node_data.memory,
memory=None,
model_config=model_config,
)
stop = model_config.stop
return prompt_messages, stop
def _calculate_rest_token(
self,
node_data: QuestionClassifierNodeData,
@@ -229,7 +235,7 @@ class QuestionClassifierNode(LLMNode):
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
@@ -243,7 +249,7 @@ class QuestionClassifierNode(LLMNode):
query: str,
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000,
) -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]:
):
model_mode = ModelMode.value_of(node_data.model.mode)
classes = node_data.classes
categories = []
@@ -255,31 +261,32 @@ class QuestionClassifierNode(LLMNode):
memory_str = ""
if memory:
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
max_token_limit=max_token_limit,
message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
)
prompt_messages = []
prompt_messages: list[LLMNodeChatModelMessage] = []
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
system_prompt_messages = LLMNodeChatModelMessage(
role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
)
prompt_messages.append(system_prompt_messages)
user_prompt_message_1 = ChatModelMessage(
user_prompt_message_1 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_1
)
prompt_messages.append(user_prompt_message_1)
assistant_prompt_message_1 = ChatModelMessage(
assistant_prompt_message_1 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1
)
prompt_messages.append(assistant_prompt_message_1)
user_prompt_message_2 = ChatModelMessage(
user_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2
)
prompt_messages.append(user_prompt_message_2)
assistant_prompt_message_2 = ChatModelMessage(
assistant_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2
)
prompt_messages.append(assistant_prompt_message_2)
user_prompt_message_3 = ChatModelMessage(
user_prompt_message_3 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER,
text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(
input_text=input_text,
@@ -290,7 +297,7 @@ class QuestionClassifierNode(LLMNode):
prompt_messages.append(user_prompt_message_3)
return prompt_messages
elif model_mode == ModelMode.COMPLETION:
return CompletionModelPromptTemplate(
return LLMNodeCompletionModelPromptTemplate(
text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(
histories=memory_str,
input_text=input_text,
@@ -302,23 +309,3 @@ class QuestionClassifierNode(LLMNode):
else:
raise ValueError(f"Model mode {model_mode} not supported.")
def _format_instruction(self, instruction: str, variable_pool: VariablePool) -> str:
inputs = {}
variable_selectors = []
variable_template_parser = VariableTemplateParser(template=instruction)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
for variable_selector in variable_selectors:
variable = variable_pool.get(variable_selector.value_selector)
variable_value = variable.value if variable else None
if variable_value is None:
raise ValueError(f"Variable {variable_selector.variable} not found")
inputs[variable_selector.variable] = variable_value
prompt_template = PromptTemplateParser(template=instruction, with_variable_tmpl=True)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
instruction = prompt_template.format(prompt_inputs)
return instruction
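
The chat branch above assembles a fixed few-shot scaffold: a system message carrying the history, two worked user/assistant exchanges, then the real query. A compact sketch of that assembly, where the message class and prompt texts stand in for the `LLMNodeChatModelMessage` templates:

from dataclasses import dataclass

@dataclass
class ChatMessage:  # stand-in for LLMNodeChatModelMessage
    role: str
    text: str

def build_few_shot(system_text: str, shots: list[tuple[str, str]], user_text: str) -> list[ChatMessage]:
    messages = [ChatMessage("system", system_text)]
    for user_shot, assistant_shot in shots:  # two worked examples in the real node
        messages.append(ChatMessage("user", user_shot))
        messages.append(ChatMessage("assistant", assistant_shot))
    messages.append(ChatMessage("user", user_text))
    return messages

prompts = build_few_shot("Histories: ...", [("q1", "a1"), ("q2", "a2")], "classify this")
assert [m.role for m in prompts] == ["system", "user", "assistant", "user", "assistant", "user"]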

View File

@@ -0,0 +1,3 @@
from .start_node import StartNode
__all__ = ["StartNode"]

View File

@@ -3,7 +3,7 @@ from collections.abc import Sequence
from pydantic import Field
from core.app.app_config.entities import VariableEntity
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.nodes.base import BaseNodeData
class StartNodeData(BaseNodeData):

View File

@@ -1,25 +1,24 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes.base_node import BaseNode
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.start.entities import StartNodeData
from models.workflow import WorkflowNodeExecutionStatus
class StartNode(BaseNode):
class StartNode(BaseNode[StartNodeData]):
_node_data_cls = StartNodeData
_node_type = NodeType.START
def _run(self) -> NodeRunResult:
"""
Run node
:return:
"""
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables
# TODO: System variables should be directly accessible, no need for special handling
# Set system variables as node outputs.
for var in system_inputs:
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
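
The start node exposes system variables as namespaced node outputs alongside user inputs. Roughly, with the pool shapes simplified to plain dicts:

SYSTEM_VARIABLE_NODE_ID = "sys"  # mirrors core.workflow.constants

def start_node_inputs(user_inputs: dict, system_inputs: dict) -> dict:
    node_inputs = dict(user_inputs)
    # Set system variables as node outputs under the "sys." namespace.
    for var, value in system_inputs.items():
        node_inputs[f"{SYSTEM_VARIABLE_NODE_ID}.{var}"] = value
    return node_inputs

print(start_node_inputs({"name": "Ada"}, {"query": "hi", "files": []}))
# {'name': 'Ada', 'sys.query': 'hi', 'sys.files': []}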
@@ -27,13 +26,10 @@ class StartNode(BaseNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: StartNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: StartNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {}

View File

@@ -0,0 +1,3 @@
from .template_transform_node import TemplateTransformNode
__all__ = ["TemplateTransformNode"]

View File

@@ -1,5 +1,5 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.nodes.base import BaseNodeData
class TemplateTransformNodeData(BaseNodeData):

View File

@@ -1,17 +1,18 @@
import os
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, Optional
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
from models.workflow import WorkflowNodeExecutionStatus
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))
class TemplateTransformNode(BaseNode):
class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
_node_data_cls = TemplateTransformNodeData
_node_type = NodeType.TEMPLATE_TRANSFORM
@@ -28,22 +29,16 @@ class TemplateTransformNode(BaseNode):
}
def _run(self) -> NodeRunResult:
"""
Run node
"""
node_data = self.node_data
node_data: TemplateTransformNodeData = cast(self._node_data_cls, node_data)
# Get variables
variables = {}
for variable_selector in node_data.variables:
for variable_selector in self.node_data.variables:
variable_name = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
variables[variable_name] = value
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=node_data.template, inputs=variables
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
)
except CodeExecutionError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
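
This file follows the commit-wide pattern: `BaseNode` is parameterized by its node-data class, so `self.node_data` is already correctly typed and the `cast(self._node_data_cls, ...)` boilerplate disappears. A sketch of the generics idiom, names illustrative:

from typing import Generic, TypeVar
from pydantic import BaseModel

class BaseNodeData(BaseModel):
    title: str = ""

NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)

class BaseNode(Generic[NodeDataT]):
    def __init__(self, node_data: NodeDataT) -> None:
        self.node_data = node_data  # typed as NodeDataT: no cast() at use sites

class TemplateTransformNodeData(BaseNodeData):
    template: str = ""

class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
    def run(self) -> str:
        return self.node_data.template  # the checker knows this attribute exists

node = TemplateTransformNode(TemplateTransformNodeData(template="{{ greeting }}"))
assert node.run() == "{{ greeting }}"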

View File

@@ -0,0 +1,3 @@
from .tool_node import ToolNode
__all__ = ["ToolNode"]

View File

@@ -3,7 +3,7 @@ from typing import Any, Literal, Union
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.nodes.base import BaseNodeData
class ToolEntity(BaseModel):
@@ -51,7 +51,4 @@ class ToolNodeData(BaseNodeData, ToolEntity):
raise ValueError("value must be a string, int, float, or bool")
return typ
"""
Tool Node Schema
"""
tool_parameters: dict[str, ToolInput]

View File

@@ -1,24 +1,28 @@
from collections.abc import Mapping, Sequence
from os import path
from typing import Any, cast
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.file.models import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models import WorkflowNodeExecutionStatus
from extensions.ext_database import db
from models import ToolFile
from models.workflow import WorkflowNodeExecutionStatus
class ToolNode(BaseNode):
class ToolNode(BaseNode[ToolNodeData]):
"""
Tool Node
"""
@@ -27,37 +31,38 @@ class ToolNode(BaseNode):
_node_type = NodeType.TOOL
def _run(self) -> NodeRunResult:
"""
Run the tool node
"""
node_data = cast(ToolNodeData, self.node_data)
# fetch tool icon
tool_info = {"provider_type": node_data.provider_type, "provider_id": node_data.provider_id}
tool_info = {
"provider_type": self.node_data.provider_type,
"provider_id": self.node_data.provider_id,
}
# get tool runtime
try:
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, node_data, self.invoke_from
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info,
},
error=f"Failed to get tool runtime: {str(e)}",
)
# get parameters
tool_parameters = tool_runtime.get_runtime_parameters() or []
parameters = self._generate_parameters(
tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
)
parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=node_data,
node_data=self.node_data,
for_log=True,
)
@@ -74,7 +79,9 @@ class ToolNode(BaseNode):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info,
},
error=f"Failed to invoke tool: {str(e)}",
)
@@ -83,8 +90,14 @@ class ToolNode(BaseNode):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": plain_text, "files": files, "json": json},
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
outputs={
"text": plain_text,
"files": files,
"json": json,
},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info,
},
inputs=parameters_for_log,
)
@@ -116,29 +129,25 @@ class ToolNode(BaseNode):
if not parameter:
result[parameter_name] = None
continue
if parameter.type == ToolParameter.ToolParameterType.FILE:
result[parameter_name] = [v.to_dict() for v in self._fetch_files(variable_pool)]
tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
if variable is None:
raise ValueError(f"variable {tool_input.value} not exists")
parameter_value = variable.value
elif tool_input.type in {"mixed", "constant"}:
segment_group = variable_pool.convert_template(str(tool_input.value))
parameter_value = segment_group.log if for_log else segment_group.text
else:
tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == "variable":
# TODO: check if the variable exists in the variable pool
parameter_value = variable_pool.get(tool_input.value).value
else:
segment_group = parser.convert_template(
template=str(tool_input.value),
variable_pool=variable_pool,
)
parameter_value = segment_group.log if for_log else segment_group.text
result[parameter_name] = parameter_value
raise ValueError(f"unknown tool input type '{tool_input.type}'")
result[parameter_name] = parameter_value
return result
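
The rewrite flattens `_generate_parameters` into a single variable/mixed/constant dispatch and raises on unknown input types instead of falling through. A simplified sketch; `ToolInput` and the dict pool are stand-ins, and the real mixed/constant branch renders a variable template rather than calling `str`:

from dataclasses import dataclass
from typing import Any

@dataclass
class ToolInput:  # stand-in for the workflow tool-input entity
    type: str  # "variable" | "mixed" | "constant"
    value: Any

def resolve_parameter(tool_input: ToolInput, pool: dict) -> Any:
    if tool_input.type == "variable":
        if tool_input.value not in pool:
            raise ValueError(f"variable {tool_input.value} does not exist")
        return pool[tool_input.value]
    elif tool_input.type in {"mixed", "constant"}:
        return str(tool_input.value)  # the real code renders a variable template here
    raise ValueError(f"unknown tool input type '{tool_input.type}'")

pool = {"node.text": "hello"}
assert resolve_parameter(ToolInput("variable", "node.text"), pool) == "hello"
assert resolve_parameter(ToolInput("constant", 42), pool) == "42"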
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar], list[dict]]:
def _convert_tool_messages(
self,
messages: list[ToolInvokeMessage],
):
"""
Convert ToolInvokeMessages into tuple[plain_text, files, json]
"""
@@ -156,50 +165,86 @@ class ToolNode(BaseNode):
return plain_text, files, json
def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[FileVar]:
def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[File]:
"""
Extract tool response binary
"""
result = []
for response in tool_response:
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
url = response.message
ext = path.splitext(url)[1]
mimetype = response.meta.get("mime_type", "image/jpeg")
filename = response.save_as or url.split("/")[-1]
url = str(response.message) if response.message else None
ext = path.splitext(url)[1] if url else ".bin"
tool_file_id = str(url).split("/")[-1].split(".")[0]
transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
# get tool file id
tool_file_id = url.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists")
result.append(
FileVar(
File(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=transfer_method,
url=url,
related_id=tool_file_id,
filename=filename,
remote_url=url,
related_id=tool_file.id,
filename=tool_file.name,
extension=ext,
mime_type=mimetype,
mime_type=tool_file.mimetype,
size=tool_file.size,
)
)
elif response.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
tool_file_id = response.message.split("/")[-1].split(".")[0]
tool_file_id = str(response.message).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists")
result.append(
FileVar(
File(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=tool_file_id,
filename=response.save_as,
related_id=tool_file.id,
filename=tool_file.name,
extension=path.splitext(response.save_as)[1],
mime_type=response.meta.get("mime_type", "application/octet-stream"),
mime_type=tool_file.mimetype,
size=tool_file.size,
)
)
elif response.type == ToolInvokeMessage.MessageType.LINK:
pass # TODO:
url = str(response.message)
transfer_method = FileTransferMethod.TOOL_FILE
tool_file_id = url.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists")
if "." in url:
extension = "." + url.split("/")[-1].split(".")[1]
else:
extension = ".bin"
file = File(
tenant_id=self.tenant_id,
type=FileType(response.save_as),
transfer_method=transfer_method,
remote_url=url,
filename=tool_file.name,
related_id=tool_file.id,
extension=extension,
mime_type=tool_file.mimetype,
size=tool_file.size,
)
result.append(file)
elif response.type == ToolInvokeMessage.MessageType.FILE:
assert response.meta is not None
result.append(response.meta["file"])
return result
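
All three URL-based branches derive the tool file id from the final path segment and fall back to `.bin` when no extension can be found, then pull name, mimetype, and size from the matching `ToolFile` row. The string handling in isolation, as a sketch:

def parse_tool_file_url(url: str) -> tuple[str, str]:
    last_segment = url.split("/")[-1]                  # e.g. "abc123.png"
    tool_file_id = last_segment.split(".")[0]          # "abc123"
    if "." in last_segment:
        extension = "." + last_segment.split(".")[-1]  # ".png"
    else:
        extension = ".bin"                             # no extension in the URL
    return tool_file_id, extension

assert parse_tool_file_url("https://host/files/abc123.png") == ("abc123", ".png")
assert parse_tool_file_url("https://host/files/abc123") == ("abc123", ".bin")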
@@ -218,12 +263,16 @@ class ToolNode(BaseNode):
]
)
def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]) -> list[dict]:
def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]):
return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON]
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: ToolNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ToolNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -236,7 +285,7 @@ class ToolNode(BaseNode):
for parameter_name in node_data.tool_parameters:
input = node_data.tool_parameters[parameter_name]
if input.type == "mixed":
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":

View File

@@ -0,0 +1,3 @@
from .variable_aggregator_node import VariableAggregatorNode
__all__ = ["VariableAggregatorNode"]

View File

@@ -2,7 +2,7 @@ from typing import Literal, Optional
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.nodes.base import BaseNodeData
class AdvancedSettings(BaseModel):

View File

@@ -1,24 +1,24 @@
from collections.abc import Mapping, Sequence
from typing import Any, cast
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
from models.workflow import WorkflowNodeExecutionStatus
class VariableAggregatorNode(BaseNode):
class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_AGGREGATOR
def _run(self) -> NodeRunResult:
node_data = cast(VariableAssignerNodeData, self.node_data)
# Get variables
outputs = {}
inputs = {}
if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled:
for selector in node_data.variables:
if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
for selector in self.node_data.variables:
variable = self.graph_runtime_state.variable_pool.get_any(selector)
if variable is not None:
outputs = {"output": variable}
@@ -26,7 +26,7 @@ class VariableAggregatorNode(BaseNode):
inputs = {".".join(selector[1:]): variable}
break
else:
for group in node_data.advanced_settings.groups:
for group in self.node_data.advanced_settings.groups:
for selector in group.variables:
variable = self.graph_runtime_state.variable_pool.get_any(selector)

View File

@@ -1,40 +1,38 @@
from typing import cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.segments import SegmentType, Variable, factory
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.variables import SegmentType, Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode, BaseNodeData
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models import ConversationVariable, WorkflowNodeExecutionStatus
from factories import variable_factory
from models import ConversationVariable
from models.workflow import WorkflowNodeExecutionStatus
from .exc import VariableAssignerNodeError
from .node_data import VariableAssignerData, WriteMode
class VariableAssignerNode(BaseNode):
class VariableAssignerNode(BaseNode[VariableAssignerData]):
_node_data_cls: type[BaseNodeData] = VariableAssignerData
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
def _run(self) -> NodeRunResult:
data = cast(VariableAssignerData, self.node_data)
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector)
original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableAssignerNodeError("assigned variable not found")
match data.write_mode:
match self.node_data.write_mode:
case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError("input value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.value})
case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError("input value not found")
updated_value = original_variable.value + [income_value.value]
@@ -45,10 +43,10 @@ class VariableAssignerNode(BaseNode):
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _:
raise VariableAssignerNodeError(f"unsupported write mode: {data.write_mode}")
raise VariableAssignerNodeError(f"unsupported write mode: {self.node_data.write_mode}")
# Over write the variable.
self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable)
self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable)
# TODO: Move database operation to the pipeline.
# Update conversation variable.
@@ -80,12 +78,12 @@ def update_conversation_variable(conversation_id: str, variable: Variable):
def get_zero_value(t: SegmentType):
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
return factory.build_segment([])
return variable_factory.build_segment([])
case SegmentType.OBJECT:
return factory.build_segment({})
return variable_factory.build_segment({})
case SegmentType.STRING:
return factory.build_segment("")
return variable_factory.build_segment("")
case SegmentType.NUMBER:
return factory.build_segment(0)
return variable_factory.build_segment(0)
case _:
raise VariableAssignerNodeError(f"unsupported variable type: {t}")
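
`get_zero_value` now routes through `factories.variable_factory` rather than the old `core.app.segments.factory`. Stripped of the segment wrapper, the dispatch reduces to the following; the enum members are assumed and the real function returns built segments, not raw values:

from enum import Enum

class SegmentType(str, Enum):  # members assumed; mirrors core.variables
    STRING = "string"
    NUMBER = "number"
    OBJECT = "object"
    ARRAY_STRING = "array[string]"

def get_zero_value(t: SegmentType):
    match t:
        case SegmentType.ARRAY_STRING:
            return []  # the real code wraps this via variable_factory.build_segment
        case SegmentType.OBJECT:
            return {}
        case SegmentType.STRING:
            return ""
        case SegmentType.NUMBER:
            return 0
        case _:
            raise ValueError(f"unsupported variable type: {t}")

assert get_zero_value(SegmentType.STRING) == ""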

View File

@@ -2,7 +2,7 @@ from collections.abc import Sequence
from enum import Enum
from typing import Optional
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.nodes.base import BaseNodeData
class WriteMode(str, Enum):