mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-15 22:06:52 +08:00
feat/enhance the multi-modal support (#8818)
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
from .enums import NodeType
|
||||
|
||||
__all__ = ["NodeType"]
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
from .answer_node import AnswerNode
|
||||
from .entities import AnswerStreamGenerateRoute
|
||||
|
||||
__all__ = ["AnswerStreamGenerateRoute", "AnswerNode"]
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.variables import ArrayFileSegment, FileSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
|
||||
from core.workflow.nodes.answer.entities import (
|
||||
AnswerNodeData,
|
||||
@@ -9,12 +10,13 @@ from core.workflow.nodes.answer.entities import (
|
||||
TextGenerateRouteChunk,
|
||||
VarGenerateRouteChunk,
|
||||
)
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class AnswerNode(BaseNode):
|
||||
class AnswerNode(BaseNode[AnswerNodeData]):
|
||||
_node_data_cls = AnswerNodeData
|
||||
_node_type: NodeType = NodeType.ANSWER
|
||||
|
||||
@@ -23,30 +25,35 @@ class AnswerNode(BaseNode):
|
||||
Run node
|
||||
:return:
|
||||
"""
|
||||
node_data = self.node_data
|
||||
node_data = cast(AnswerNodeData, node_data)
|
||||
|
||||
# generate routes
|
||||
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data)
|
||||
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self.node_data)
|
||||
|
||||
answer = ""
|
||||
files = []
|
||||
for part in generate_routes:
|
||||
if part.type == GenerateRouteChunk.ChunkType.VAR:
|
||||
part = cast(VarGenerateRouteChunk, part)
|
||||
value_selector = part.value_selector
|
||||
value = self.graph_runtime_state.variable_pool.get(value_selector)
|
||||
|
||||
if value:
|
||||
answer += value.markdown
|
||||
variable = self.graph_runtime_state.variable_pool.get(value_selector)
|
||||
if variable:
|
||||
if isinstance(variable, FileSegment):
|
||||
files.append(variable.value)
|
||||
elif isinstance(variable, ArrayFileSegment):
|
||||
files.extend(variable.value)
|
||||
answer += variable.markdown
|
||||
else:
|
||||
part = cast(TextGenerateRouteChunk, part)
|
||||
answer += part.text
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer})
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer, "files": files})
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: AnswerNodeData
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: AnswerNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -55,9 +62,6 @@ class AnswerNode(BaseNode):
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
node_data = node_data
|
||||
node_data = cast(AnswerNodeData, node_data)
|
||||
|
||||
variable_template_parser = VariableTemplateParser(template=node_data.answer)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.nodes.answer.entities import (
|
||||
AnswerNodeData,
|
||||
AnswerStreamGenerateRoute,
|
||||
@@ -7,6 +6,7 @@ from core.workflow.nodes.answer.entities import (
|
||||
TextGenerateRouteChunk,
|
||||
VarGenerateRouteChunk,
|
||||
)
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, cast
|
||||
from typing import cast
|
||||
|
||||
from core.file.file_obj import FileVar
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
@@ -203,7 +203,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
return files
|
||||
|
||||
@classmethod
|
||||
def _get_file_var_from_value(cls, value: dict | list) -> Optional[dict]:
|
||||
def _get_file_var_from_value(cls, value: dict | list):
|
||||
"""
|
||||
Get file var from value
|
||||
:param value: variable value
|
||||
@@ -213,9 +213,9 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
return None
|
||||
|
||||
if isinstance(value, dict):
|
||||
if "__variant" in value and value["__variant"] == FileVar.__name__:
|
||||
if "dify_model_identity" in value and value["dify_model_identity"] == FILE_MODEL_IDENTITY:
|
||||
return value
|
||||
elif isinstance(value, FileVar):
|
||||
elif isinstance(value, File):
|
||||
return value.to_dict()
|
||||
|
||||
return None
|
||||
|
||||
@@ -2,7 +2,7 @@ from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class AnswerNodeData(BaseNodeData):
|
||||
|
||||
4
api/core/workflow/nodes/base/__init__.py
Normal file
4
api/core/workflow/nodes/base/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .entities import BaseIterationNodeData, BaseIterationState, BaseNodeData
|
||||
from .node import BaseNode
|
||||
|
||||
__all__ = ["BaseNode", "BaseNodeData", "BaseIterationNodeData", "BaseIterationState"]
|
||||
24
api/core/workflow/nodes/base/entities.py
Normal file
24
api/core/workflow/nodes/base/entities.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from abc import ABC
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
    """Common configuration fields shared by every workflow node's data model."""

    title: str
    desc: Optional[str] = None


class BaseIterationNodeData(BaseNodeData):
    """Node data for iteration-style nodes; records the loop's entry node id."""

    start_node_id: Optional[str] = None


class BaseIterationState(BaseModel):
    """State tracked while an iteration node is running."""

    iteration_node_id: str
    index: int
    inputs: dict

    class MetaData(BaseModel):
        """Per-iteration metadata container (currently empty)."""

        pass

    # NOTE(review): MetaData is assumed to be nested inside BaseIterationState
    # based on hunk ordering — confirm against the upstream file.
    metadata: MetaData
|
||||
@@ -1,17 +1,27 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import BaseNodeData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
|
||||
|
||||
|
||||
class BaseNode(ABC):
|
||||
class BaseNode(Generic[GenericNodeData]):
|
||||
_node_data_cls: type[BaseNodeData]
|
||||
_node_type: NodeType
|
||||
|
||||
@@ -19,9 +29,9 @@ class BaseNode(ABC):
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: GraphInitParams,
|
||||
graph: Graph,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph: "Graph",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
previous_node_id: Optional[str] = None,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
) -> None:
|
||||
@@ -45,22 +55,25 @@ class BaseNode(ABC):
|
||||
raise ValueError("Node ID is required.")
|
||||
|
||||
self.node_id = node_id
|
||||
self.node_data = self._node_data_cls(**config.get("data", {}))
|
||||
self.node_data: GenericNodeData = cast(GenericNodeData, self._node_data_cls(**config.get("data", {})))
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]:
|
||||
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
|
||||
"""
|
||||
Run node
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
|
||||
"""
|
||||
Run node entry
|
||||
:return:
|
||||
"""
|
||||
result = self._run()
|
||||
def run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
|
||||
try:
|
||||
result = self._run()
|
||||
except Exception as e:
|
||||
logger.error(f"Node {self.node_id} failed to run: {e}")
|
||||
result = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
if isinstance(result, NodeRunResult):
|
||||
yield RunCompletedEvent(run_result=result)
|
||||
@@ -69,7 +82,10 @@ class BaseNode(ABC):
|
||||
|
||||
@classmethod
|
||||
def extract_variable_selector_to_variable_mapping(
|
||||
cls, graph_config: Mapping[str, Any], config: dict
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
config: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -83,12 +99,16 @@ class BaseNode(ABC):
|
||||
|
||||
node_data = cls._node_data_cls(**config.get("data", {}))
|
||||
return cls._extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, node_id=node_id, node_data=node_data
|
||||
graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: BaseNodeData
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: GenericNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -0,0 +1,3 @@
|
||||
from .code_node import CodeNode
|
||||
|
||||
__all__ = ["CodeNode"]
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, Union, cast
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.code.entities import CodeNodeData
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class CodeNode(BaseNode):
|
||||
class CodeNode(BaseNode[CodeNodeData]):
|
||||
_node_data_cls = CodeNodeData
|
||||
_node_type = NodeType.CODE
|
||||
|
||||
@@ -33,20 +34,13 @@ class CodeNode(BaseNode):
|
||||
return code_provider.get_default_config()
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run code
|
||||
:return:
|
||||
"""
|
||||
node_data = self.node_data
|
||||
node_data = cast(CodeNodeData, node_data)
|
||||
|
||||
# Get code language
|
||||
code_language = node_data.code_language
|
||||
code = node_data.code
|
||||
code_language = self.node_data.code_language
|
||||
code = self.node_data.code
|
||||
|
||||
# Get variables
|
||||
variables = {}
|
||||
for variable_selector in node_data.variables:
|
||||
for variable_selector in self.node_data.variables:
|
||||
variable = variable_selector.variable
|
||||
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
|
||||
|
||||
@@ -60,7 +54,7 @@ class CodeNode(BaseNode):
|
||||
)
|
||||
|
||||
# Transform result
|
||||
result = self._transform_result(result, node_data.outputs)
|
||||
result = self._transform_result(result, self.node_data.outputs)
|
||||
except (CodeExecutionError, ValueError) as e:
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
|
||||
|
||||
@@ -316,7 +310,11 @@ class CodeNode(BaseNode):
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: CodeNodeData
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: CodeNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
|
||||
@@ -3,8 +3,8 @@ from typing import Literal, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class CodeNodeData(BaseNodeData):
|
||||
|
||||
4
api/core/workflow/nodes/document_extractor/__init__.py
Normal file
4
api/core/workflow/nodes/document_extractor/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .entities import DocumentExtractorNodeData
|
||||
from .node import DocumentExtractorNode
|
||||
|
||||
__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData"]
|
||||
7
api/core/workflow/nodes/document_extractor/entities.py
Normal file
7
api/core/workflow/nodes/document_extractor/entities.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class DocumentExtractorNodeData(BaseNodeData):
    """Configuration for DocumentExtractorNode: the selector of the file variable to read."""

    variable_selector: Sequence[str]
|
||||
14
api/core/workflow/nodes/document_extractor/exc.py
Normal file
14
api/core/workflow/nodes/document_extractor/exc.py
Normal file
@@ -0,0 +1,14 @@
|
||||
class DocumentExtractorError(Exception):
    """Root of the DocumentExtractorNode exception hierarchy.

    Catching this type covers every error the extractor raises.
    """


class FileDownloadError(DocumentExtractorError):
    """Raised when a file's content cannot be downloaded."""


class UnsupportedFileTypeError(DocumentExtractorError):
    """Raised when text extraction is requested for an unsupported file type."""


class TextExtractionError(DocumentExtractorError):
    """Raised when extracting text from a file's content fails."""
|
||||
244
api/core/workflow/nodes/document_extractor/node.py
Normal file
244
api/core/workflow/nodes/document_extractor/node.py
Normal file
@@ -0,0 +1,244 @@
|
||||
import csv
|
||||
import io
|
||||
|
||||
import docx
|
||||
import pandas as pd
|
||||
import pypdfium2
|
||||
from unstructured.partition.email import partition_email
|
||||
from unstructured.partition.epub import partition_epub
|
||||
from unstructured.partition.msg import partition_msg
|
||||
from unstructured.partition.ppt import partition_ppt
|
||||
from unstructured.partition.pptx import partition_pptx
|
||||
|
||||
from core.file import File, FileTransferMethod, file_manager
|
||||
from core.helper import ssrf_proxy
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.variables.segments import FileSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import DocumentExtractorNodeData
|
||||
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
|
||||
|
||||
|
||||
class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
    """
    Extracts text content from various file types.
    Supports plain text, PDF, and DOC/DOCX files.
    """

    _node_data_cls = DocumentExtractorNodeData
    _node_type = NodeType.DOCUMENT_EXTRACTOR

    def _run(self):
        # Resolve the configured file variable from the pool.
        selector = self.node_data.variable_selector
        variable = self.graph_runtime_state.variable_pool.get(selector)

        if variable is None:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=f"File variable not found for selector: {selector}",
            )
        # Non-empty values must be file segments (single file or array of files).
        if variable.value and not isinstance(variable, ArrayFileSegment | FileSegment):
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=f"Variable {selector} is not an ArrayFileSegment",
            )

        value = variable.value
        inputs = {"variable_selector": selector}
        process_data = {"documents": value if isinstance(value, list) else [value]}

        try:
            if isinstance(value, list):
                # Array of files: extract each one, output a list of texts.
                texts = [_extract_text_from_file(document) for document in value]
                return NodeRunResult(
                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
                    inputs=inputs,
                    process_data=process_data,
                    outputs={"text": texts},
                )
            elif isinstance(value, File):
                # Single file: output one text string.
                return NodeRunResult(
                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
                    inputs=inputs,
                    process_data=process_data,
                    outputs={"text": _extract_text_from_file(value)},
                )
            else:
                raise DocumentExtractorError(f"Unsupported variable type: {type(value)}")
        except DocumentExtractorError as e:
            # Download/extraction failures surface as a failed node run, not a crash.
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                inputs=inputs,
                process_data=process_data,
            )
|
||||
|
||||
|
||||
def _extract_text(*, file_content: bytes, mime_type: str) -> str:
    """Extract text from a file's bytes, dispatching on its MIME type.

    Raises:
        UnsupportedFileTypeError: if no extractor exists for *mime_type*.
    """
    # Anything text-like is decoded directly as plain text.
    if mime_type.startswith("text/plain") or mime_type in {"text/html", "text/htm", "text/markdown", "text/xml"}:
        return _extract_text_from_plain_text(file_content)

    # Exact-match dispatch table for the remaining document formats.
    extractors = {
        "application/pdf": _extract_text_from_pdf,
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": _extract_text_from_doc,
        "application/msword": _extract_text_from_doc,
        "text/csv": _extract_text_from_csv,
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": _extract_text_from_excel,
        "application/vnd.ms-excel": _extract_text_from_excel,
        "application/vnd.ms-powerpoint": _extract_text_from_ppt,
        "application/vnd.openxmlformats-officedocument.presentationml.presentation": _extract_text_from_pptx,
        "application/epub+zip": _extract_text_from_epub,
        "message/rfc822": _extract_text_from_eml,
        "application/vnd.ms-outlook": _extract_text_from_msg,
    }
    extractor = extractors.get(mime_type)
    if extractor is None:
        raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}")
    return extractor(file_content)
|
||||
|
||||
|
||||
def _extract_text_from_plain_text(file_content: bytes) -> str:
|
||||
try:
|
||||
return file_content.decode("utf-8")
|
||||
except UnicodeDecodeError as e:
|
||||
raise TextExtractionError("Failed to decode plain text file") from e
|
||||
|
||||
|
||||
def _extract_text_from_pdf(file_content: bytes) -> str:
    """Concatenate the text of every page of a PDF document.

    Raises:
        TextExtractionError: on any parsing failure.
    """
    try:
        document = pypdfium2.PdfDocument(io.BytesIO(file_content), autoclose=True)
        parts = []
        for page in document:
            text_page = page.get_textpage()
            parts.append(text_page.get_text_range())
            # Release native resources eagerly; pypdfium2 objects hold C handles.
            text_page.close()
            page.close()
        return "".join(parts)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from PDF: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_doc(file_content: bytes) -> str:
    """Join the text of all paragraphs in a DOC/DOCX document with newlines.

    Raises:
        TextExtractionError: on any parsing failure.
    """
    try:
        document = docx.Document(io.BytesIO(file_content))
        return "\n".join(paragraph.text for paragraph in document.paragraphs)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
|
||||
|
||||
|
||||
def _download_file_content(file: File) -> bytes:
    """Download the content of a file based on its transfer method.

    :param file: file whose bytes should be fetched
    :return: raw file content
    :raises FileDownloadError: if the transfer method is unsupported, the
        remote URL is missing, or the download itself fails
    """
    # Bug fix: the original wrapped its whole body in `try/except Exception`,
    # so its own guard exceptions (missing URL, unsupported transfer method)
    # were re-wrapped as "Error downloading file: ...", garbling the message.
    # Scope the blanket handler to the actual download calls only.
    if file.transfer_method == FileTransferMethod.REMOTE_URL:
        if file.remote_url is None:
            raise FileDownloadError("Missing URL for remote file")
        try:
            response = ssrf_proxy.get(file.remote_url)
            response.raise_for_status()
            return response.content
        except Exception as e:
            raise FileDownloadError(f"Error downloading file: {str(e)}") from e
    elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
        try:
            return file_manager.download(file)
        except Exception as e:
            raise FileDownloadError(f"Error downloading file: {str(e)}") from e
    else:
        # Still a FileDownloadError so callers catching DocumentExtractorError
        # keep working, but without the misleading "Error downloading" prefix.
        raise FileDownloadError(f"Unsupported transfer method: {file.transfer_method}")
|
||||
|
||||
|
||||
def _extract_text_from_file(file: File):
    """Download a file's content and extract its text according to its MIME type.

    Raises:
        UnsupportedFileTypeError: if the file carries no MIME type.
    """
    if file.mime_type is None:
        raise UnsupportedFileTypeError("Unable to determine file type: MIME type is missing")
    content = _download_file_content(file)
    return _extract_text(file_content=content, mime_type=file.mime_type)
|
||||
|
||||
|
||||
def _extract_text_from_csv(file_content: bytes) -> str:
|
||||
try:
|
||||
csv_file = io.StringIO(file_content.decode("utf-8"))
|
||||
csv_reader = csv.reader(csv_file)
|
||||
rows = list(csv_reader)
|
||||
|
||||
if not rows:
|
||||
return ""
|
||||
|
||||
# Create markdown table
|
||||
markdown_table = "| " + " | ".join(rows[0]) + " |\n"
|
||||
markdown_table += "| " + " | ".join(["---"] * len(rows[0])) + " |\n"
|
||||
for row in rows[1:]:
|
||||
markdown_table += "| " + " | ".join(row) + " |\n"
|
||||
|
||||
return markdown_table.strip()
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from CSV: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_excel(file_content: bytes) -> str:
    """Extract text from an Excel file using pandas, rendered as a markdown table.

    Only the first sheet is read (pandas default).

    Raises:
        TextExtractionError: on any parsing failure.
    """
    try:
        frame = pd.read_excel(io.BytesIO(file_content))
        # Discard rows where every cell is NaN before rendering.
        frame = frame.dropna(how="all")
        return frame.to_markdown(index=False)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_ppt(file_content: bytes) -> str:
    """Extract slide text from a legacy PPT file via unstructured.

    Raises:
        TextExtractionError: on any parsing failure.
    """
    try:
        with io.BytesIO(file_content) as buffer:
            elements = partition_ppt(file=buffer)
        # Elements without a text attribute contribute an empty line.
        return "\n".join(getattr(element, "text", "") for element in elements)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from PPT: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_pptx(file_content: bytes) -> str:
    """Extract slide text from a PPTX file via unstructured.

    Raises:
        TextExtractionError: on any parsing failure.
    """
    try:
        with io.BytesIO(file_content) as buffer:
            elements = partition_pptx(file=buffer)
        # Elements without a text attribute contribute an empty line.
        return "\n".join(getattr(element, "text", "") for element in elements)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_epub(file_content: bytes) -> str:
    """Extract text from an EPUB e-book via unstructured.

    Raises:
        TextExtractionError: on any parsing failure.
    """
    try:
        with io.BytesIO(file_content) as buffer:
            elements = partition_epub(file=buffer)
        return "\n".join(str(element) for element in elements)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from EPUB: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_eml(file_content: bytes) -> str:
    """Extract text from an RFC 822 email (.eml) via unstructured.

    Raises:
        TextExtractionError: on any parsing failure.
    """
    try:
        with io.BytesIO(file_content) as buffer:
            elements = partition_email(file=buffer)
        return "\n".join(str(element) for element in elements)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from EML: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_msg(file_content: bytes) -> str:
    """Extract text from an Outlook .msg file via unstructured.

    Raises:
        TextExtractionError: on any parsing failure.
    """
    try:
        with io.BytesIO(file_content) as buffer:
            elements = partition_msg(file=buffer)
        return "\n".join(str(element) for element in elements)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from MSG: {str(e)}") from e
|
||||
@@ -0,0 +1,4 @@
|
||||
from .end_node import EndNode
|
||||
from .entities import EndStreamParam
|
||||
|
||||
__all__ = ["EndStreamParam", "EndNode"]
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class EndNode(BaseNode):
|
||||
class EndNode(BaseNode[EndNodeData]):
|
||||
_node_data_cls = EndNodeData
|
||||
_node_type = NodeType.END
|
||||
|
||||
@@ -16,20 +17,27 @@ class EndNode(BaseNode):
|
||||
Run node
|
||||
:return:
|
||||
"""
|
||||
node_data = self.node_data
|
||||
node_data = cast(EndNodeData, node_data)
|
||||
output_variables = node_data.outputs
|
||||
output_variables = self.node_data.outputs
|
||||
|
||||
outputs = {}
|
||||
for variable_selector in output_variables:
|
||||
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
|
||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||
value = variable.to_object() if variable is not None else None
|
||||
outputs[variable_selector.variable] = value
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=outputs, outputs=outputs)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=outputs,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: EndNodeData
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: EndNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
|
||||
|
||||
class EndStreamGeneratorRouter:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class EndNodeData(BaseNodeData):
|
||||
|
||||
24
api/core/workflow/nodes/enums.py
Normal file
24
api/core/workflow/nodes/enums.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class NodeType(str, Enum):
    """Identifier for each workflow node type.

    Values are the serialized names stored in workflow definitions, so they
    must never change once released.
    """

    START = "start"
    END = "end"
    ANSWER = "answer"
    LLM = "llm"
    KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
    IF_ELSE = "if-else"
    CODE = "code"
    TEMPLATE_TRANSFORM = "template-transform"
    QUESTION_CLASSIFIER = "question-classifier"
    HTTP_REQUEST = "http-request"
    TOOL = "tool"
    VARIABLE_AGGREGATOR = "variable-aggregator"
    VARIABLE_ASSIGNER = "variable-assigner"  # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
    LOOP = "loop"
    ITERATION = "iteration"
    ITERATION_START = "iteration-start"  # Fake start node for iteration.
    PARAMETER_EXTRACTOR = "parameter-extractor"
    CONVERSATION_VARIABLE_ASSIGNER = "assigner"
    DOCUMENT_EXTRACTOR = "document-extractor"
    LIST_OPERATOR = "list-operator"
|
||||
10
api/core/workflow/nodes/event/__init__.py
Normal file
10
api/core/workflow/nodes/event/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from .types import NodeEvent
|
||||
|
||||
__all__ = [
|
||||
"RunCompletedEvent",
|
||||
"RunRetrieverResourceEvent",
|
||||
"RunStreamChunkEvent",
|
||||
"NodeEvent",
|
||||
"ModelInvokeCompletedEvent",
|
||||
]
|
||||
@@ -1,5 +1,6 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
|
||||
|
||||
@@ -17,4 +18,11 @@ class RunRetrieverResourceEvent(BaseModel):
|
||||
context: str = Field(..., description="context")
|
||||
|
||||
|
||||
RunEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent
|
||||
class ModelInvokeCompletedEvent(BaseModel):
    """
    Model invoke completed
    """

    # Full generated text of the invocation.
    text: str
    # Token/cost accounting reported by the model runtime.
    usage: LLMUsage
    # Provider-reported stop reason, when available.
    finish_reason: str | None = None
|
||||
3
api/core/workflow/nodes/event/types.py
Normal file
3
api/core/workflow/nodes/event/types.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
|
||||
# Union of every event a node may emit while running.
NodeEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent | ModelInvokeCompletedEvent
|
||||
@@ -0,0 +1,4 @@
|
||||
from .entities import BodyData, HttpRequestNodeAuthorization, HttpRequestNodeBody, HttpRequestNodeData
|
||||
from .node import HttpRequestNode
|
||||
|
||||
__all__ = ["HttpRequestNodeData", "HttpRequestNodeAuthorization", "HttpRequestNodeBody", "BodyData", "HttpRequestNode"]
|
||||
|
||||
@@ -1,15 +1,25 @@
|
||||
from typing import Literal, Optional, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ValidationInfo, field_validator
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
NON_FILE_CONTENT_TYPES = (
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"text/html",
|
||||
"text/plain",
|
||||
"application/x-www-form-urlencoded",
|
||||
)
|
||||
|
||||
|
||||
class HttpRequestNodeAuthorizationConfig(BaseModel):
|
||||
type: Literal[None, "basic", "bearer", "custom"]
|
||||
api_key: Union[None, str] = None
|
||||
header: Union[None, str] = None
|
||||
type: Literal["basic", "bearer", "custom"]
|
||||
api_key: str
|
||||
header: str = ""
|
||||
|
||||
|
||||
class HttpRequestNodeAuthorization(BaseModel):
|
||||
@@ -31,9 +41,16 @@ class HttpRequestNodeAuthorization(BaseModel):
|
||||
return v
|
||||
|
||||
|
||||
class BodyData(BaseModel):
|
||||
key: str = ""
|
||||
type: Literal["file", "text"]
|
||||
value: str = ""
|
||||
file: Sequence[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class HttpRequestNodeBody(BaseModel):
|
||||
type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json"]
|
||||
data: Union[None, str] = None
|
||||
type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json", "binary"]
|
||||
data: Sequence[BodyData] = Field(default_factory=list)
|
||||
|
||||
|
||||
class HttpRequestNodeTimeout(BaseModel):
|
||||
@@ -54,3 +71,51 @@ class HttpRequestNodeData(BaseNodeData):
|
||||
params: str
|
||||
body: Optional[HttpRequestNodeBody] = None
|
||||
timeout: Optional[HttpRequestNodeTimeout] = None
|
||||
|
||||
|
||||
class Response:
|
||||
headers: dict[str, str]
|
||||
response: httpx.Response
|
||||
|
||||
def __init__(self, response: httpx.Response):
|
||||
self.response = response
|
||||
self.headers = dict(response.headers)
|
||||
|
||||
@property
|
||||
def is_file(self):
|
||||
content_type = self.content_type
|
||||
content_disposition = self.response.headers.get("Content-Disposition", "")
|
||||
|
||||
return "attachment" in content_disposition or (
|
||||
not any(non_file in content_type for non_file in NON_FILE_CONTENT_TYPES)
|
||||
and any(file_type in content_type for file_type in ("application/", "image/", "audio/", "video/"))
|
||||
)
|
||||
|
||||
@property
|
||||
def content_type(self) -> str:
|
||||
return self.headers.get("Content-Type", "")
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
return self.response.text
|
||||
|
||||
@property
|
||||
def content(self) -> bytes:
|
||||
return self.response.content
|
||||
|
||||
@property
|
||||
def status_code(self) -> int:
|
||||
return self.response.status_code
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return len(self.content)
|
||||
|
||||
@property
|
||||
def readable_size(self) -> str:
|
||||
if self.size < 1024:
|
||||
return f"{self.size} bytes"
|
||||
elif self.size < 1024 * 1024:
|
||||
return f"{(self.size / 1024):.2f} KB"
|
||||
else:
|
||||
return f"{(self.size / 1024 / 1024):.2f} MB"
|
||||
|
||||
327
api/core/workflow/nodes/http_request/executor.py
Normal file
327
api/core/workflow/nodes/http_request/executor.py
Normal file
@@ -0,0 +1,327 @@
|
||||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from random import randint
|
||||
from typing import Any, Literal
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import file_manager
|
||||
from core.helper import ssrf_proxy
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
from .entities import (
|
||||
HttpRequestNodeAuthorization,
|
||||
HttpRequestNodeData,
|
||||
HttpRequestNodeTimeout,
|
||||
Response,
|
||||
)
|
||||
|
||||
BODY_TYPE_TO_CONTENT_TYPE = {
|
||||
"json": "application/json",
|
||||
"x-www-form-urlencoded": "application/x-www-form-urlencoded",
|
||||
"form-data": "multipart/form-data",
|
||||
"raw-text": "text/plain",
|
||||
}
|
||||
|
||||
|
||||
class Executor:
|
||||
method: Literal["get", "head", "post", "put", "delete", "patch"]
|
||||
url: str
|
||||
params: Mapping[str, str] | None
|
||||
content: str | bytes | None
|
||||
data: Mapping[str, Any] | None
|
||||
files: Mapping[str, bytes] | None
|
||||
json: Any
|
||||
headers: dict[str, str]
|
||||
auth: HttpRequestNodeAuthorization
|
||||
timeout: HttpRequestNodeTimeout
|
||||
|
||||
boundary: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
node_data: HttpRequestNodeData,
|
||||
timeout: HttpRequestNodeTimeout,
|
||||
variable_pool: VariablePool,
|
||||
):
|
||||
# If authorization API key is present, convert the API key using the variable pool
|
||||
if node_data.authorization.type == "api-key":
|
||||
if node_data.authorization.config is None:
|
||||
raise ValueError("authorization config is required")
|
||||
node_data.authorization.config.api_key = variable_pool.convert_template(
|
||||
node_data.authorization.config.api_key
|
||||
).text
|
||||
|
||||
self.url: str = node_data.url
|
||||
self.method = node_data.method
|
||||
self.auth = node_data.authorization
|
||||
self.timeout = timeout
|
||||
self.params = None
|
||||
self.headers = {}
|
||||
self.content = None
|
||||
self.files = None
|
||||
self.data = None
|
||||
self.json = None
|
||||
|
||||
# init template
|
||||
self.variable_pool = variable_pool
|
||||
self.node_data = node_data
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self):
|
||||
self._init_url()
|
||||
self._init_params()
|
||||
self._init_headers()
|
||||
self._init_body()
|
||||
|
||||
def _init_url(self):
|
||||
self.url = self.variable_pool.convert_template(self.node_data.url).text
|
||||
|
||||
def _init_params(self):
|
||||
params = self.variable_pool.convert_template(self.node_data.params).text
|
||||
self.params = _plain_text_to_dict(params)
|
||||
|
||||
def _init_headers(self):
|
||||
headers = self.variable_pool.convert_template(self.node_data.headers).text
|
||||
self.headers = _plain_text_to_dict(headers)
|
||||
|
||||
body = self.node_data.body
|
||||
if body is None:
|
||||
return
|
||||
if "content-type" not in (k.lower() for k in self.headers) and body.type in BODY_TYPE_TO_CONTENT_TYPE:
|
||||
self.headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type]
|
||||
if body.type == "form-data":
|
||||
self.boundary = f"----WebKitFormBoundary{_generate_random_string(16)}"
|
||||
self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}"
|
||||
|
||||
def _init_body(self):
|
||||
body = self.node_data.body
|
||||
if body is not None:
|
||||
data = body.data
|
||||
match body.type:
|
||||
case "none":
|
||||
self.content = ""
|
||||
case "raw-text":
|
||||
self.content = self.variable_pool.convert_template(data[0].value).text
|
||||
case "json":
|
||||
json_object = json.loads(data[0].value)
|
||||
self.json = self._parse_object_contains_variables(json_object)
|
||||
case "binary":
|
||||
file_selector = data[0].file
|
||||
file_variable = self.variable_pool.get_file(file_selector)
|
||||
if file_variable is None:
|
||||
raise ValueError(f"cannot fetch file with selector {file_selector}")
|
||||
file = file_variable.value
|
||||
self.content = file_manager.download(file)
|
||||
case "x-www-form-urlencoded":
|
||||
form_data = {
|
||||
self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template(
|
||||
item.value
|
||||
).text
|
||||
for item in data
|
||||
}
|
||||
self.data = form_data
|
||||
case "form-data":
|
||||
form_data = {
|
||||
self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template(
|
||||
item.value
|
||||
).text
|
||||
for item in filter(lambda item: item.type == "text", data)
|
||||
}
|
||||
file_selectors = {
|
||||
self.variable_pool.convert_template(item.key).text: item.file
|
||||
for item in filter(lambda item: item.type == "file", data)
|
||||
}
|
||||
files = {k: self.variable_pool.get_file(selector) for k, selector in file_selectors.items()}
|
||||
files = {k: v for k, v in files.items() if v is not None}
|
||||
files = {k: variable.value for k, variable in files.items()}
|
||||
files = {k: file_manager.download(v) for k, v in files.items() if v.related_id is not None}
|
||||
|
||||
self.data = form_data
|
||||
self.files = files
|
||||
|
||||
def _assembling_headers(self) -> dict[str, Any]:
|
||||
authorization = deepcopy(self.auth)
|
||||
headers = deepcopy(self.headers) or {}
|
||||
if self.auth.type == "api-key":
|
||||
if self.auth.config is None:
|
||||
raise ValueError("self.authorization config is required")
|
||||
if authorization.config is None:
|
||||
raise ValueError("authorization config is required")
|
||||
|
||||
if self.auth.config.api_key is None:
|
||||
raise ValueError("api_key is required")
|
||||
|
||||
if not authorization.config.header:
|
||||
authorization.config.header = "Authorization"
|
||||
|
||||
if self.auth.config.type == "bearer":
|
||||
headers[authorization.config.header] = f"Bearer {authorization.config.api_key}"
|
||||
elif self.auth.config.type == "basic":
|
||||
headers[authorization.config.header] = f"Basic {authorization.config.api_key}"
|
||||
elif self.auth.config.type == "custom":
|
||||
headers[authorization.config.header] = authorization.config.api_key or ""
|
||||
|
||||
return headers
|
||||
|
||||
def _validate_and_parse_response(self, response: httpx.Response) -> Response:
|
||||
executor_response = Response(response)
|
||||
|
||||
threshold_size = (
|
||||
dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE
|
||||
if executor_response.is_file
|
||||
else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE
|
||||
)
|
||||
if executor_response.size > threshold_size:
|
||||
raise ValueError(
|
||||
f'{"File" if executor_response.is_file else "Text"} size is too large,'
|
||||
f' max size is {threshold_size / 1024 / 1024:.2f} MB,'
|
||||
f' but current size is {executor_response.readable_size}.'
|
||||
)
|
||||
|
||||
return executor_response
|
||||
|
||||
def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response:
|
||||
"""
|
||||
do http request depending on api bundle
|
||||
"""
|
||||
if self.method not in {"get", "head", "post", "put", "delete", "patch"}:
|
||||
raise ValueError(f"Invalid http method {self.method}")
|
||||
|
||||
request_args = {
|
||||
"url": self.url,
|
||||
"data": self.data,
|
||||
"files": self.files,
|
||||
"json": self.json,
|
||||
"content": self.content,
|
||||
"headers": headers,
|
||||
"params": self.params,
|
||||
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
|
||||
"follow_redirects": True,
|
||||
}
|
||||
|
||||
response = getattr(ssrf_proxy, self.method)(**request_args)
|
||||
return response
|
||||
|
||||
def invoke(self) -> Response:
|
||||
# assemble headers
|
||||
headers = self._assembling_headers()
|
||||
# do http request
|
||||
response = self._do_http_request(headers)
|
||||
# validate response
|
||||
return self._validate_and_parse_response(response)
|
||||
|
||||
def to_log(self):
|
||||
url_parts = urlparse(self.url)
|
||||
path = url_parts.path or "/"
|
||||
|
||||
# Add query parameters
|
||||
if self.params:
|
||||
query_string = urlencode(self.params)
|
||||
path += f"?{query_string}"
|
||||
elif url_parts.query:
|
||||
path += f"?{url_parts.query}"
|
||||
|
||||
raw = f"{self.method.upper()} {path} HTTP/1.1\r\n"
|
||||
raw += f"Host: {url_parts.netloc}\r\n"
|
||||
|
||||
headers = self._assembling_headers()
|
||||
for k, v in headers.items():
|
||||
if self.auth.type == "api-key":
|
||||
authorization_header = "Authorization"
|
||||
if self.auth.config and self.auth.config.header:
|
||||
authorization_header = self.auth.config.header
|
||||
if k.lower() == authorization_header.lower():
|
||||
raw += f'{k}: {"*" * len(v)}\r\n'
|
||||
continue
|
||||
raw += f"{k}: {v}\r\n"
|
||||
|
||||
body = ""
|
||||
if self.files:
|
||||
boundary = self.boundary
|
||||
for k, v in self.files.items():
|
||||
body += f"--{boundary}\r\n"
|
||||
body += f'Content-Disposition: form-data; name="{k}"\r\n\r\n'
|
||||
body += f"{v[1]}\r\n"
|
||||
body += f"--{boundary}--\r\n"
|
||||
elif self.node_data.body:
|
||||
if self.content:
|
||||
if isinstance(self.content, str):
|
||||
body = self.content
|
||||
elif isinstance(self.content, bytes):
|
||||
body = self.content.decode("utf-8", errors="replace")
|
||||
elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
|
||||
body = urlencode(self.data)
|
||||
elif self.data and self.node_data.body.type == "form-data":
|
||||
boundary = self.boundary
|
||||
for key, value in self.data.items():
|
||||
body += f"--{boundary}\r\n"
|
||||
body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
|
||||
body += f"{value}\r\n"
|
||||
body += f"--{boundary}--\r\n"
|
||||
elif self.json:
|
||||
body = json.dumps(self.json)
|
||||
elif self.node_data.body.type == "raw-text":
|
||||
body = self.node_data.body.data[0].value
|
||||
if body:
|
||||
raw += f"Content-Length: {len(body)}\r\n"
|
||||
raw += "\r\n" # Empty line between headers and body
|
||||
raw += body
|
||||
|
||||
return raw
|
||||
|
||||
def _parse_object_contains_variables(self, obj: str | dict | list, /) -> Mapping[str, Any] | Sequence[Any] | str:
|
||||
if isinstance(obj, dict):
|
||||
return {k: self._parse_object_contains_variables(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [self._parse_object_contains_variables(v) for v in obj]
|
||||
elif isinstance(obj, str):
|
||||
return self.variable_pool.convert_template(obj).text
|
||||
|
||||
|
||||
def _plain_text_to_dict(text: str, /) -> dict[str, str]:
|
||||
"""
|
||||
Convert a string of key-value pairs to a dictionary.
|
||||
|
||||
Each line in the input string represents a key-value pair.
|
||||
Keys and values are separated by ':'.
|
||||
Empty values are allowed.
|
||||
|
||||
Examples:
|
||||
'aa:bb\n cc:dd' -> {'aa': 'bb', 'cc': 'dd'}
|
||||
'aa:\n cc:dd\n' -> {'aa': '', 'cc': 'dd'}
|
||||
'aa\n cc : dd' -> {'aa': '', 'cc': 'dd'}
|
||||
|
||||
Args:
|
||||
convert_text (str): The input string to convert.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: A dictionary of key-value pairs.
|
||||
"""
|
||||
return {
|
||||
key.strip(): (value[0].strip() if value else "")
|
||||
for line in text.splitlines()
|
||||
if line.strip()
|
||||
for key, *value in [line.split(":", 1)]
|
||||
}
|
||||
|
||||
|
||||
def _generate_random_string(n: int) -> str:
|
||||
"""
|
||||
Generate a random string of lowercase ASCII letters.
|
||||
|
||||
Args:
|
||||
n (int): The length of the random string to generate.
|
||||
|
||||
Returns:
|
||||
str: A random string of lowercase ASCII letters with length n.
|
||||
|
||||
Example:
|
||||
>>> _generate_random_string(5)
|
||||
'abcde'
|
||||
"""
|
||||
return "".join([chr(randint(97, 122)) for _ in range(n)])
|
||||
@@ -1,343 +0,0 @@
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from random import randint
|
||||
from typing import Any, Optional, Union
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.http_request.entities import (
|
||||
HttpRequestNodeAuthorization,
|
||||
HttpRequestNodeBody,
|
||||
HttpRequestNodeData,
|
||||
HttpRequestNodeTimeout,
|
||||
)
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
|
||||
|
||||
class HttpExecutorResponse:
|
||||
headers: dict[str, str]
|
||||
response: httpx.Response
|
||||
|
||||
def __init__(self, response: httpx.Response):
|
||||
self.response = response
|
||||
self.headers = dict(response.headers) if isinstance(self.response, httpx.Response) else {}
|
||||
|
||||
@property
|
||||
def is_file(self) -> bool:
|
||||
"""
|
||||
check if response is file
|
||||
"""
|
||||
content_type = self.get_content_type()
|
||||
file_content_types = ["image", "audio", "video"]
|
||||
|
||||
return any(v in content_type for v in file_content_types)
|
||||
|
||||
def get_content_type(self) -> str:
|
||||
return self.headers.get("content-type", "")
|
||||
|
||||
def extract_file(self) -> tuple[str, bytes]:
|
||||
"""
|
||||
extract file from response if content type is file related
|
||||
"""
|
||||
if self.is_file:
|
||||
return self.get_content_type(), self.body
|
||||
|
||||
return "", b""
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
if isinstance(self.response, httpx.Response):
|
||||
return self.response.text
|
||||
else:
|
||||
raise ValueError(f"Invalid response type {type(self.response)}")
|
||||
|
||||
@property
|
||||
def body(self) -> bytes:
|
||||
if isinstance(self.response, httpx.Response):
|
||||
return self.response.content
|
||||
else:
|
||||
raise ValueError(f"Invalid response type {type(self.response)}")
|
||||
|
||||
@property
|
||||
def status_code(self) -> int:
|
||||
if isinstance(self.response, httpx.Response):
|
||||
return self.response.status_code
|
||||
else:
|
||||
raise ValueError(f"Invalid response type {type(self.response)}")
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return len(self.body)
|
||||
|
||||
@property
|
||||
def readable_size(self) -> str:
|
||||
if self.size < 1024:
|
||||
return f"{self.size} bytes"
|
||||
elif self.size < 1024 * 1024:
|
||||
return f"{(self.size / 1024):.2f} KB"
|
||||
else:
|
||||
return f"{(self.size / 1024 / 1024):.2f} MB"
|
||||
|
||||
|
||||
class HttpExecutor:
|
||||
server_url: str
|
||||
method: str
|
||||
authorization: HttpRequestNodeAuthorization
|
||||
params: dict[str, Any]
|
||||
headers: dict[str, Any]
|
||||
body: Union[None, str]
|
||||
files: Union[None, dict[str, Any]]
|
||||
boundary: str
|
||||
variable_selectors: list[VariableSelector]
|
||||
timeout: HttpRequestNodeTimeout
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_data: HttpRequestNodeData,
|
||||
timeout: HttpRequestNodeTimeout,
|
||||
variable_pool: Optional[VariablePool] = None,
|
||||
):
|
||||
self.server_url = node_data.url
|
||||
self.method = node_data.method
|
||||
self.authorization = node_data.authorization
|
||||
self.timeout = timeout
|
||||
self.params = {}
|
||||
self.headers = {}
|
||||
self.body = None
|
||||
self.files = None
|
||||
|
||||
# init template
|
||||
self.variable_selectors = []
|
||||
self._init_template(node_data, variable_pool)
|
||||
|
||||
@staticmethod
|
||||
def _is_json_body(body: HttpRequestNodeBody):
|
||||
"""
|
||||
check if body is json
|
||||
"""
|
||||
if body and body.type == "json" and body.data:
|
||||
try:
|
||||
json.loads(body.data)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _to_dict(convert_text: str):
|
||||
"""
|
||||
Convert the string like `aa:bb\n cc:dd` to dict `{aa:bb, cc:dd}`
|
||||
"""
|
||||
kv_paris = convert_text.split("\n")
|
||||
result = {}
|
||||
for kv in kv_paris:
|
||||
if not kv.strip():
|
||||
continue
|
||||
|
||||
kv = kv.split(":", maxsplit=1)
|
||||
if len(kv) == 1:
|
||||
k, v = kv[0], ""
|
||||
else:
|
||||
k, v = kv
|
||||
result[k.strip()] = v
|
||||
return result
|
||||
|
||||
def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None):
|
||||
# extract all template in url
|
||||
self.server_url, server_url_variable_selectors = self._format_template(node_data.url, variable_pool)
|
||||
|
||||
# extract all template in params
|
||||
params, params_variable_selectors = self._format_template(node_data.params, variable_pool)
|
||||
self.params = self._to_dict(params)
|
||||
|
||||
# extract all template in headers
|
||||
headers, headers_variable_selectors = self._format_template(node_data.headers, variable_pool)
|
||||
self.headers = self._to_dict(headers)
|
||||
|
||||
# extract all template in body
|
||||
body_data_variable_selectors = []
|
||||
if node_data.body:
|
||||
# check if it's a valid JSON
|
||||
is_valid_json = self._is_json_body(node_data.body)
|
||||
|
||||
body_data = node_data.body.data or ""
|
||||
if body_data:
|
||||
body_data, body_data_variable_selectors = self._format_template(body_data, variable_pool, is_valid_json)
|
||||
|
||||
content_type_is_set = any(key.lower() == "content-type" for key in self.headers)
|
||||
if node_data.body.type == "json" and not content_type_is_set:
|
||||
self.headers["Content-Type"] = "application/json"
|
||||
elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set:
|
||||
self.headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
|
||||
if node_data.body.type in {"form-data", "x-www-form-urlencoded"}:
|
||||
body = self._to_dict(body_data)
|
||||
|
||||
if node_data.body.type == "form-data":
|
||||
self.files = {k: ("", v) for k, v in body.items()}
|
||||
random_str = lambda n: "".join([chr(randint(97, 122)) for _ in range(n)])
|
||||
self.boundary = f"----WebKitFormBoundary{random_str(16)}"
|
||||
|
||||
self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}"
|
||||
else:
|
||||
self.body = urlencode(body)
|
||||
elif node_data.body.type in {"json", "raw-text"}:
|
||||
self.body = body_data
|
||||
elif node_data.body.type == "none":
|
||||
self.body = ""
|
||||
|
||||
self.variable_selectors = (
|
||||
server_url_variable_selectors
|
||||
+ params_variable_selectors
|
||||
+ headers_variable_selectors
|
||||
+ body_data_variable_selectors
|
||||
)
|
||||
|
||||
def _assembling_headers(self) -> dict[str, Any]:
|
||||
authorization = deepcopy(self.authorization)
|
||||
headers = deepcopy(self.headers) or {}
|
||||
if self.authorization.type == "api-key":
|
||||
if self.authorization.config is None:
|
||||
raise ValueError("self.authorization config is required")
|
||||
if authorization.config is None:
|
||||
raise ValueError("authorization config is required")
|
||||
|
||||
if self.authorization.config.api_key is None:
|
||||
raise ValueError("api_key is required")
|
||||
|
||||
if not authorization.config.header:
|
||||
authorization.config.header = "Authorization"
|
||||
|
||||
if self.authorization.config.type == "bearer":
|
||||
headers[authorization.config.header] = f"Bearer {authorization.config.api_key}"
|
||||
elif self.authorization.config.type == "basic":
|
||||
headers[authorization.config.header] = f"Basic {authorization.config.api_key}"
|
||||
elif self.authorization.config.type == "custom":
|
||||
headers[authorization.config.header] = authorization.config.api_key
|
||||
|
||||
return headers
|
||||
|
||||
def _validate_and_parse_response(self, response: httpx.Response) -> HttpExecutorResponse:
|
||||
"""
|
||||
validate the response
|
||||
"""
|
||||
if isinstance(response, httpx.Response):
|
||||
executor_response = HttpExecutorResponse(response)
|
||||
else:
|
||||
raise ValueError(f"Invalid response type {type(response)}")
|
||||
|
||||
threshold_size = (
|
||||
dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE
|
||||
if executor_response.is_file
|
||||
else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE
|
||||
)
|
||||
if executor_response.size > threshold_size:
|
||||
raise ValueError(
|
||||
f'{"File" if executor_response.is_file else "Text"} size is too large,'
|
||||
f' max size is {threshold_size / 1024 / 1024:.2f} MB,'
|
||||
f' but current size is {executor_response.readable_size}.'
|
||||
)
|
||||
|
||||
return executor_response
|
||||
|
||||
def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response:
|
||||
"""
|
||||
do http request depending on api bundle
|
||||
"""
|
||||
kwargs = {
|
||||
"url": self.server_url,
|
||||
"headers": headers,
|
||||
"params": self.params,
|
||||
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
|
||||
"follow_redirects": True,
|
||||
}
|
||||
|
||||
if self.method in {"get", "head", "post", "put", "delete", "patch"}:
|
||||
response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Invalid http method {self.method}")
|
||||
return response
|
||||
|
||||
def invoke(self) -> HttpExecutorResponse:
|
||||
"""
|
||||
invoke http request
|
||||
"""
|
||||
# assemble headers
|
||||
headers = self._assembling_headers()
|
||||
|
||||
# do http request
|
||||
response = self._do_http_request(headers)
|
||||
|
||||
# validate response
|
||||
return self._validate_and_parse_response(response)
|
||||
|
||||
def to_raw_request(self) -> str:
|
||||
"""
|
||||
convert to raw request
|
||||
"""
|
||||
server_url = self.server_url
|
||||
if self.params:
|
||||
server_url += f"?{urlencode(self.params)}"
|
||||
|
||||
raw_request = f"{self.method.upper()} {server_url} HTTP/1.1\n"
|
||||
|
||||
headers = self._assembling_headers()
|
||||
for k, v in headers.items():
|
||||
# get authorization header
|
||||
if self.authorization.type == "api-key":
|
||||
authorization_header = "Authorization"
|
||||
if self.authorization.config and self.authorization.config.header:
|
||||
authorization_header = self.authorization.config.header
|
||||
|
||||
if k.lower() == authorization_header.lower():
|
||||
raw_request += f'{k}: {"*" * len(v)}\n'
|
||||
continue
|
||||
|
||||
raw_request += f"{k}: {v}\n"
|
||||
|
||||
raw_request += "\n"
|
||||
|
||||
# if files, use multipart/form-data with boundary
|
||||
if self.files:
|
||||
boundary = self.boundary
|
||||
raw_request += f"--{boundary}"
|
||||
for k, v in self.files.items():
|
||||
raw_request += f'\nContent-Disposition: form-data; name="{k}"\n\n'
|
||||
raw_request += f"{v[1]}\n"
|
||||
raw_request += f"--{boundary}"
|
||||
raw_request += "--"
|
||||
else:
|
||||
raw_request += self.body or ""
|
||||
|
||||
return raw_request
|
||||
|
||||
def _format_template(
|
||||
self, template: str, variable_pool: Optional[VariablePool], escape_quotes: bool = False
|
||||
) -> tuple[str, list[VariableSelector]]:
|
||||
"""
|
||||
format template
|
||||
"""
|
||||
variable_template_parser = VariableTemplateParser(template=template)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
if variable_pool:
|
||||
variable_value_mapping = {}
|
||||
for variable_selector in variable_selectors:
|
||||
variable = variable_pool.get_any(variable_selector.value_selector)
|
||||
if variable is None:
|
||||
raise ValueError(f"Variable {variable_selector.variable} not found")
|
||||
if escape_quotes and isinstance(variable, str):
|
||||
value = variable.replace('"', '\\"').replace("\n", "\\n")
|
||||
else:
|
||||
value = variable
|
||||
variable_value_mapping[variable_selector.variable] = value
|
||||
|
||||
return variable_template_parser.format(variable_value_mapping), variable_selectors
|
||||
else:
|
||||
return template, variable_selectors
|
||||
@@ -1,165 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from mimetypes import guess_extension
|
||||
from os import path
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.segments import parser
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.http_request.entities import (
|
||||
HttpRequestNodeData,
|
||||
HttpRequestNodeTimeout,
|
||||
)
|
||||
from core.workflow.nodes.http_request.http_executor import HttpExecutor, HttpExecutorResponse
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
|
||||
connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
|
||||
read=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
|
||||
write=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
|
||||
)
|
||||
|
||||
|
||||
class HttpRequestNode(BaseNode):
|
||||
_node_data_cls = HttpRequestNodeData
|
||||
_node_type = NodeType.HTTP_REQUEST
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: dict | None = None) -> dict:
|
||||
return {
|
||||
"type": "http-request",
|
||||
"config": {
|
||||
"method": "get",
|
||||
"authorization": {
|
||||
"type": "no-auth",
|
||||
},
|
||||
"body": {"type": "none"},
|
||||
"timeout": {
|
||||
**HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(),
|
||||
"max_connect_timeout": dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
|
||||
"max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
|
||||
"max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data)
|
||||
# TODO: Switch to use segment directly
|
||||
if node_data.authorization.config and node_data.authorization.config.api_key:
|
||||
node_data.authorization.config.api_key = parser.convert_template(
|
||||
template=node_data.authorization.config.api_key, variable_pool=self.graph_runtime_state.variable_pool
|
||||
).text
|
||||
|
||||
# init http executor
|
||||
http_executor = None
|
||||
try:
|
||||
http_executor = HttpExecutor(
|
||||
node_data=node_data,
|
||||
timeout=self._get_request_timeout(node_data),
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
)
|
||||
|
||||
# invoke http executor
|
||||
response = http_executor.invoke()
|
||||
except Exception as e:
|
||||
process_data = {}
|
||||
if http_executor:
|
||||
process_data = {
|
||||
"request": http_executor.to_raw_request(),
|
||||
}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
process_data=process_data,
|
||||
)
|
||||
|
||||
files = self.extract_files(http_executor.server_url, response)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
"status_code": response.status_code,
|
||||
"body": response.content if not files else "",
|
||||
"headers": response.headers,
|
||||
"files": files,
|
||||
},
|
||||
process_data={
|
||||
"request": http_executor.to_raw_request(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_request_timeout(node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout:
|
||||
timeout = node_data.timeout
|
||||
if timeout is None:
|
||||
return HTTP_REQUEST_DEFAULT_TIMEOUT
|
||||
|
||||
timeout.connect = timeout.connect or HTTP_REQUEST_DEFAULT_TIMEOUT.connect
|
||||
timeout.read = timeout.read or HTTP_REQUEST_DEFAULT_TIMEOUT.read
|
||||
timeout.write = timeout.write or HTTP_REQUEST_DEFAULT_TIMEOUT.write
|
||||
return timeout
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: HttpRequestNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT)
|
||||
|
||||
variable_selectors = http_executor.variable_selectors
|
||||
|
||||
variable_mapping = {}
|
||||
for variable_selector in variable_selectors:
|
||||
variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
return variable_mapping
|
||||
except Exception as e:
|
||||
logging.exception(f"Failed to extract variable selector to variable mapping: {e}")
|
||||
return {}
|
||||
|
||||
def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]:
    """
    Extract files from response.

    If the response carries a recognizable file payload (per
    ``response.extract_file()``), persist it through ToolFileManager and
    return a single-element list of FileVar; otherwise return an empty list.
    """
    files = []
    mimetype, file_binary = response.extract_file()

    if mimetype:
        # extract filename from url
        filename = path.basename(url)
        # extract extension if possible
        extension = guess_extension(mimetype) or ".bin"

        # Persist the raw bytes as a tool file owned by this node's user/tenant.
        # conversation_id is None: the file is not tied to any conversation.
        tool_file = ToolFileManager.create_file_by_raw(
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            conversation_id=None,
            file_binary=file_binary,
            mimetype=mimetype,
        )

        # NOTE(review): type is hardcoded to FileType.IMAGE regardless of the
        # actual mimetype — confirm whether non-image responses should be
        # classified differently.
        files.append(
            FileVar(
                tenant_id=self.tenant_id,
                type=FileType.IMAGE,
                transfer_method=FileTransferMethod.TOOL_FILE,
                related_id=tool_file.id,
                filename=filename,
                extension=extension,
                mime_type=mimetype,
            )
        )

    return files
|
||||
174
api/core/workflow/nodes/http_request/node.py
Normal file
174
api/core/workflow/nodes/http_request/node.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from mimetypes import guess_extension
|
||||
from os import path
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.http_request.executor import Executor
|
||||
from core.workflow.utils import variable_template_parser
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import (
|
||||
HttpRequestNodeData,
|
||||
HttpRequestNodeTimeout,
|
||||
Response,
|
||||
)
|
||||
|
||||
HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
|
||||
connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
|
||||
read=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
|
||||
write=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
_node_data_cls = HttpRequestNodeData
|
||||
_node_type = NodeType.HTTP_REQUEST
|
||||
|
||||
@classmethod
def get_default_config(cls, filters: dict | None = None) -> dict:
    """Default node configuration presented by the workflow editor.

    The *filters* parameter is part of the shared node interface and is
    not used by this node type.
    """
    # Timeout block: the default timeout values plus the configured maxima.
    timeout_config = HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump()
    timeout_config.update(
        {
            "max_connect_timeout": dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
            "max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
            "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
        }
    )
    return {
        "type": "http-request",
        "config": {
            "method": "get",
            "authorization": {
                "type": "no-auth",
            },
            "body": {"type": "none"},
            "timeout": timeout_config,
        },
    }
|
||||
|
||||
def _run(self) -> NodeRunResult:
    """Execute the HTTP request and map the response to node outputs.

    On success, outputs contain status_code / body / headers / files; the
    body is blanked when a file was extracted, so a response is surfaced
    either as text or as a file, never both. Any exception yields a FAILED
    result instead of propagating.
    """
    # Captured before invoke() so the request log survives a failure below.
    process_data = {}
    try:
        http_executor = Executor(
            node_data=self.node_data,
            timeout=self._get_request_timeout(self.node_data),
            variable_pool=self.graph_runtime_state.variable_pool,
        )
        process_data["request"] = http_executor.to_log()

        response = http_executor.invoke()
        files = self.extract_files(url=http_executor.url, response=response)
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={
                "status_code": response.status_code,
                # Body is suppressed when the response was extracted as a file.
                "body": response.text if not files else "",
                "headers": response.headers,
                "files": files,
            },
            process_data={
                "request": http_executor.to_log(),
            },
        )
    except Exception as e:
        # Broad catch is deliberate: an HTTP node failure becomes a FAILED
        # node result (with the partial process_data) rather than a crash.
        logger.warning(f"http request node {self.node_id} failed to run: {e}")
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED,
            error=str(e),
            process_data=process_data,
        )
|
||||
|
||||
@staticmethod
def _get_request_timeout(node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout:
    """Return the node's timeout with unset fields backfilled from the default."""
    timeout = node_data.timeout
    if timeout is None:
        # Nothing configured on the node: use the module-level default as-is.
        return HTTP_REQUEST_DEFAULT_TIMEOUT

    # Backfill each falsy field in place from the default timeout.
    if not timeout.connect:
        timeout.connect = HTTP_REQUEST_DEFAULT_TIMEOUT.connect
    if not timeout.read:
        timeout.read = HTTP_REQUEST_DEFAULT_TIMEOUT.read
    if not timeout.write:
        timeout.write = HTTP_REQUEST_DEFAULT_TIMEOUT.write
    return timeout
|
||||
|
||||
@classmethod
def _extract_variable_selector_to_variable_mapping(
    cls,
    *,
    graph_config: Mapping[str, Any],
    node_id: str,
    node_data: HttpRequestNodeData,
) -> Mapping[str, Sequence[str]]:
    """Collect every variable selector referenced by this node's templates.

    Scans headers, params and (depending on body type) the body entries,
    returning a mapping of "<node_id>.<variable>" -> value selector.
    """
    selectors: list[VariableSelector] = []
    selectors += variable_template_parser.extract_selectors_from_template(node_data.headers)
    selectors += variable_template_parser.extract_selectors_from_template(node_data.params)
    if node_data.body:
        body_type = node_data.body.type
        data = node_data.body.data
        match body_type:
            case "binary":
                # Binary bodies reference exactly one file selector.
                selector = data[0].file
                selectors.append(VariableSelector(variable="#" + ".".join(selector) + "#", value_selector=selector))
            case "json" | "raw-text":
                # NOTE(review): only the first body entry is scanned here —
                # presumably json/raw-text bodies have a single entry; confirm.
                selectors += variable_template_parser.extract_selectors_from_template(data[0].key)
                selectors += variable_template_parser.extract_selectors_from_template(data[0].value)
            case "x-www-form-urlencoded":
                for item in data:
                    selectors += variable_template_parser.extract_selectors_from_template(item.key)
                    selectors += variable_template_parser.extract_selectors_from_template(item.value)
            case "form-data":
                for item in data:
                    selectors += variable_template_parser.extract_selectors_from_template(item.key)
                    if item.type == "text":
                        selectors += variable_template_parser.extract_selectors_from_template(item.value)
                    elif item.type == "file":
                        # File parts contribute their selector directly rather
                        # than being parsed out of a template string.
                        selectors.append(
                            VariableSelector(variable="#" + ".".join(item.file) + "#", value_selector=item.file)
                        )

    mapping = {}
    for selector in selectors:
        mapping[node_id + "." + selector.variable] = selector.value_selector

    return mapping
|
||||
|
||||
def extract_files(self, url: str, response: Response) -> list[File]:
    """
    Extract files from response.

    When the response exposes a content type, the raw content is persisted
    through ToolFileManager and wrapped in a single File; otherwise an
    empty list is returned.
    """
    files = []
    content_type = response.content_type
    content = response.content

    if content_type:
        # extract filename from url
        filename = path.basename(url)
        # extract extension if possible
        extension = guess_extension(content_type) or ".bin"

        # Persist the raw bytes as a tool file owned by this node's user/tenant.
        # conversation_id is None: the file is not tied to any conversation.
        tool_file = ToolFileManager.create_file_by_raw(
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            conversation_id=None,
            file_binary=content,
            mimetype=content_type,
        )

        # NOTE(review): type is hardcoded to FileType.IMAGE regardless of
        # content_type — confirm whether non-image payloads need a different
        # FileType.
        files.append(
            File(
                tenant_id=self.tenant_id,
                type=FileType.IMAGE,
                transfer_method=FileTransferMethod.TOOL_FILE,
                related_id=tool_file.id,
                filename=filename,
                extension=extension,
                mime_type=content_type,
            )
        )

    return files
|
||||
@@ -0,0 +1,3 @@
|
||||
from .if_else_node import IfElseNode
|
||||
|
||||
__all__ = ["IfElseNode"]
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
|
||||
@@ -21,6 +21,6 @@ class IfElseNodeData(BaseNodeData):
|
||||
conditions: list[Condition]
|
||||
|
||||
logical_operator: Optional[Literal["and", "or"]] = "and"
|
||||
conditions: Optional[list[Condition]] = None
|
||||
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
|
||||
|
||||
cases: Optional[list[Case]] = None
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
from typing import Any, Literal
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.if_else.entities import IfElseNodeData
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class IfElseNode(BaseNode):
|
||||
class IfElseNode(BaseNode[IfElseNodeData]):
|
||||
_node_data_cls = IfElseNodeData
|
||||
_node_type = NodeType.IF_ELSE
|
||||
|
||||
@@ -17,9 +22,6 @@ class IfElseNode(BaseNode):
|
||||
Run node
|
||||
:return:
|
||||
"""
|
||||
node_data = self.node_data
|
||||
node_data = cast(IfElseNodeData, node_data)
|
||||
|
||||
node_inputs: dict[str, list] = {"conditions": []}
|
||||
|
||||
process_datas: dict[str, list] = {"condition_results": []}
|
||||
@@ -30,15 +32,14 @@ class IfElseNode(BaseNode):
|
||||
condition_processor = ConditionProcessor()
|
||||
try:
|
||||
# Check if the new cases structure is used
|
||||
if node_data.cases:
|
||||
for case in node_data.cases:
|
||||
input_conditions, group_result = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool, conditions=case.conditions
|
||||
if self.node_data.cases:
|
||||
for case in self.node_data.cases:
|
||||
input_conditions, group_result, final_result = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=case.conditions,
|
||||
operator=case.logical_operator,
|
||||
)
|
||||
|
||||
# Apply the logical operator for the current case
|
||||
final_result = all(group_result) if case.logical_operator == "and" else any(group_result)
|
||||
|
||||
process_datas["condition_results"].append(
|
||||
{
|
||||
"group": case.model_dump(),
|
||||
@@ -53,13 +54,15 @@ class IfElseNode(BaseNode):
|
||||
break
|
||||
|
||||
else:
|
||||
# TODO: Update database then remove this
|
||||
# Fallback to old structure if cases are not defined
|
||||
input_conditions, group_result = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool, conditions=node_data.conditions
|
||||
input_conditions, group_result, final_result = _should_not_use_old_function(
|
||||
condition_processor=condition_processor,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=self.node_data.conditions or [],
|
||||
operator=self.node_data.logical_operator or "and",
|
||||
)
|
||||
|
||||
final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result)
|
||||
|
||||
selected_case_id = "true" if final_result else "false"
|
||||
|
||||
process_datas["condition_results"].append(
|
||||
@@ -87,7 +90,11 @@ class IfElseNode(BaseNode):
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: IfElseNodeData
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: IfElseNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -97,3 +104,18 @@ class IfElseNode(BaseNode):
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
||||
|
||||
@deprecated("This function is deprecated. You should use the new cases structure.")
def _should_not_use_old_function(
    *,
    condition_processor: ConditionProcessor,
    variable_pool: VariablePool,
    conditions: list[Condition],
    operator: Literal["and", "or"],
):
    # Thin compatibility shim for the legacy if/else node layout (flat
    # `conditions` + `logical_operator` instead of `cases`). It delegates
    # unchanged to ConditionProcessor.process_conditions; the wrapper exists
    # only to carry the @deprecated marker.
    return condition_processor.process_conditions(
        variable_pool=variable_pool,
        conditions=conditions,
        operator=operator,
    )
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
from .entities import IterationNodeData
|
||||
from .iteration_node import IterationNode
|
||||
from .iteration_start_node import IterationStartNode
|
||||
|
||||
__all__ = ["IterationNode", "IterationNodeData", "IterationStartNode"]
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState, BaseNodeData
|
||||
from pydantic import Field
|
||||
|
||||
from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData
|
||||
|
||||
|
||||
class IterationNodeData(BaseIterationNodeData):
|
||||
@@ -26,7 +28,7 @@ class IterationState(BaseIterationState):
|
||||
Iteration State.
|
||||
"""
|
||||
|
||||
outputs: list[Any] = None
|
||||
outputs: list[Any] = Field(default_factory=list)
|
||||
current_output: Optional[Any] = None
|
||||
|
||||
class MetaData(BaseIterationState.MetaData):
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseGraphEvent,
|
||||
BaseNodeEvent,
|
||||
@@ -20,15 +20,16 @@ from core.workflow.graph_engine.entities.event import (
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from core.workflow.nodes.iteration.entities import IterationNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IterationNode(BaseNode):
|
||||
class IterationNode(BaseNode[IterationNodeData]):
|
||||
"""
|
||||
Iteration Node.
|
||||
"""
|
||||
@@ -36,11 +37,10 @@ class IterationNode(BaseNode):
|
||||
_node_data_cls = IterationNodeData
|
||||
_node_type = NodeType.ITERATION
|
||||
|
||||
def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
self.node_data = cast(IterationNodeData, self.node_data)
|
||||
iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
|
||||
|
||||
if not iterator_list_segment:
|
||||
@@ -177,7 +177,7 @@ class IterationNode(BaseNode):
|
||||
|
||||
# remove all nodes outputs from variable pool
|
||||
for node_id in iteration_graph.node_ids:
|
||||
variable_pool.remove_node(node_id)
|
||||
variable_pool.remove([node_id])
|
||||
|
||||
# move to next iteration
|
||||
current_index = variable_pool.get([self.node_id, "index"])
|
||||
@@ -247,7 +247,11 @@ class IterationNode(BaseNode):
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: IterationNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -273,15 +277,13 @@ class IterationNode(BaseNode):
|
||||
# variable selector to variable mapping
|
||||
try:
|
||||
# Get node class
|
||||
from core.workflow.nodes.node_mapping import node_classes
|
||||
from core.workflow.nodes.node_mapping import node_type_classes_mapping
|
||||
|
||||
node_type = NodeType.value_of(sub_node_config.get("data", {}).get("type"))
|
||||
node_cls = node_classes.get(node_type)
|
||||
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
|
||||
node_cls = node_type_classes_mapping.get(node_type)
|
||||
if not node_cls:
|
||||
continue
|
||||
|
||||
node_cls = cast(BaseNode, node_cls)
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, config=sub_node_config
|
||||
)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
|
||||
__all__ = ["KnowledgeRetrievalNode"]
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class RerankingModelConfig(BaseModel):
|
||||
|
||||
@@ -14,8 +14,9 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
@@ -32,15 +33,13 @@ default_retrieval_model = {
|
||||
}
|
||||
|
||||
|
||||
class KnowledgeRetrievalNode(BaseNode):
|
||||
class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
_node_data_cls = KnowledgeRetrievalNodeData
|
||||
node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
_node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
|
||||
|
||||
# extract variables
|
||||
variable = self.graph_runtime_state.variable_pool.get_any(node_data.query_variable_selector)
|
||||
variable = self.graph_runtime_state.variable_pool.get_any(self.node_data.query_variable_selector)
|
||||
query = variable
|
||||
variables = {"query": query}
|
||||
if not query:
|
||||
@@ -49,7 +48,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
)
|
||||
# retrieve knowledge
|
||||
try:
|
||||
results = self._fetch_dataset_retriever(node_data=node_data, query=query)
|
||||
results = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
|
||||
outputs = {"result": results}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
|
||||
@@ -244,7 +243,11 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: KnowledgeRetrievalNodeData
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: KnowledgeRetrievalNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
|
||||
3
api/core/workflow/nodes/list_operator/__init__.py
Normal file
3
api/core/workflow/nodes/list_operator/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .node import ListOperatorNode
|
||||
|
||||
__all__ = ["ListOperatorNode"]
|
||||
56
api/core/workflow/nodes/list_operator/entities.py
Normal file
56
api/core/workflow/nodes/list_operator/entities.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
# Closed set of comparison-operator names accepted by FilterCondition.
# The first group applies to strings, the second to numbers (note the
# unicode "≥"/"≤" spellings used for >= and <=).
_Condition = Literal[
    # string conditions
    "contains",
    "startswith",
    "endswith",
    "is",
    "in",
    "empty",
    "not contains",
    "not is",
    "not in",
    "not empty",
    # number conditions
    "=",
    "!=",
    "<",
    ">",
    "≥",
    "≤",
]


class FilterCondition(BaseModel):
    # Attribute to compare (used when filtering file arrays, e.g. "name").
    key: str = ""
    comparison_operator: _Condition = "contains"
    # Right-hand side of the comparison; a sequence is used for
    # set-membership style conditions.
    value: str | Sequence[str] = ""


class FilterBy(BaseModel):
    # When disabled, the node skips the filtering step entirely.
    enabled: bool = False
    conditions: Sequence[FilterCondition] = Field(default_factory=list)


class OrderBy(BaseModel):
    # When disabled, the node skips the ordering step entirely.
    enabled: bool = False
    # Attribute to order by (used for file arrays only).
    key: str = ""
    value: Literal["asc", "desc"] = "asc"


class Limit(BaseModel):
    # When disabled, the node skips the slicing step entirely.
    enabled: bool = False
    # Maximum number of items to keep; -1 is the unconfigured default.
    size: int = -1


class ListOperatorNodeData(BaseNodeData):
    # Selector of the array variable this node operates on.
    variable: Sequence[str] = Field(default_factory=list)
    filter_by: FilterBy
    order_by: OrderBy
    limit: Limit
||||
259
api/core/workflow/nodes/list_operator/node.py
Normal file
259
api/core/workflow/nodes/list_operator/node.py
Normal file
@@ -0,0 +1,259 @@
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Literal
|
||||
|
||||
from core.file import File
|
||||
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import ListOperatorNodeData
|
||||
|
||||
|
||||
class ListOperatorNode(BaseNode[ListOperatorNodeData]):
    """Workflow node that filters, orders and slices an array variable.

    Supports string, number and file arrays; each stage rebuilds the segment
    immutably via ``model_copy`` so the variable pool's original is untouched.
    """

    _node_data_cls = ListOperatorNodeData
    _node_type = NodeType.LIST_OPERATOR

    def _run(self):
        inputs = {}
        process_data = {}
        outputs = {}

        variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable)
        if variable is None:
            error_message = f"Variable not found for selector: {self.node_data.variable}"
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
            )
        # NOTE(review): the type check is gated on `variable.value` being
        # truthy, so an *empty* segment of an unsupported type passes through
        # to the stages below — confirm whether that is intended.
        if variable.value and not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
            error_message = (
                f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
                "or ArrayStringSegment"
            )
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
            )

        # Record the pre-operation value; files are serialized for logging.
        if isinstance(variable, ArrayFileSegment):
            process_data["variable"] = [item.to_dict() for item in variable.value]
        else:
            process_data["variable"] = variable.value

        # Filter: conditions are applied sequentially, each narrowing the array.
        if self.node_data.filter_by.enabled:
            for condition in self.node_data.filter_by.conditions:
                if isinstance(variable, ArrayStringSegment):
                    if not isinstance(condition.value, str):
                        raise ValueError(f"Invalid filter value: {condition.value}")
                    # Condition values may themselves be templates referencing
                    # other variables; resolve them against the pool first.
                    value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
                    filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value)
                    result = list(filter(filter_func, variable.value))
                    variable = variable.model_copy(update={"value": result})
                elif isinstance(variable, ArrayNumberSegment):
                    if not isinstance(condition.value, str):
                        raise ValueError(f"Invalid filter value: {condition.value}")
                    value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
                    filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value))
                    result = list(filter(filter_func, variable.value))
                    variable = variable.model_copy(update={"value": result})
                elif isinstance(variable, ArrayFileSegment):
                    # File filters accept either a template string or a
                    # sequence (for set-membership conditions).
                    if isinstance(condition.value, str):
                        value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
                    else:
                        value = condition.value
                    filter_func = _get_file_filter_func(
                        key=condition.key,
                        condition=condition.comparison_operator,
                        value=value,
                    )
                    result = list(filter(filter_func, variable.value))
                    variable = variable.model_copy(update={"value": result})

        # Order
        if self.node_data.order_by.enabled:
            if isinstance(variable, ArrayStringSegment):
                result = _order_string(order=self.node_data.order_by.value, array=variable.value)
                variable = variable.model_copy(update={"value": result})
            elif isinstance(variable, ArrayNumberSegment):
                result = _order_number(order=self.node_data.order_by.value, array=variable.value)
                variable = variable.model_copy(update={"value": result})
            elif isinstance(variable, ArrayFileSegment):
                result = _order_file(
                    order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
                )
                variable = variable.model_copy(update={"value": result})

        # Slice
        if self.node_data.limit.enabled:
            result = variable.value[: self.node_data.limit.size]
            variable = variable.model_copy(update={"value": result})

        outputs = {
            "result": variable.value,
            "first_record": variable.value[0] if variable.value else None,
            "last_record": variable.value[-1] if variable.value else None,
        }
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=inputs,
            process_data=process_data,
            outputs=outputs,
        )
|
||||
|
||||
|
||||
def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
    """Return an accessor reading a numeric attribute from a File.

    Only "size" is supported; any other key raises ValueError.
    """
    if key == "size":
        return lambda file: file.size
    raise ValueError(f"Invalid key: {key}")
|
||||
|
||||
|
||||
def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
|
||||
match key:
|
||||
case "name":
|
||||
return lambda x: x.filename or ""
|
||||
case "type":
|
||||
return lambda x: x.type
|
||||
case "extension":
|
||||
return lambda x: x.extension or ""
|
||||
case "mimetype":
|
||||
return lambda x: x.mime_type or ""
|
||||
case "transfer_method":
|
||||
return lambda x: x.transfer_method
|
||||
case "urL":
|
||||
return lambda x: x.remote_url or ""
|
||||
case _:
|
||||
raise ValueError(f"Invalid key: {key}")
|
||||
|
||||
|
||||
def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]:
|
||||
match condition:
|
||||
case "contains":
|
||||
return _contains(value)
|
||||
case "startswith":
|
||||
return _startswith(value)
|
||||
case "endswith":
|
||||
return _endswith(value)
|
||||
case "is":
|
||||
return _is(value)
|
||||
case "in":
|
||||
return _in(value)
|
||||
case "empty":
|
||||
return lambda x: x == ""
|
||||
case "not contains":
|
||||
return lambda x: not _contains(value)(x)
|
||||
case "not is":
|
||||
return lambda x: not _is(value)(x)
|
||||
case "not in":
|
||||
return lambda x: not _in(value)(x)
|
||||
case "not empty":
|
||||
return lambda x: x != ""
|
||||
case _:
|
||||
raise ValueError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]:
    """Build a set-membership predicate ("in" / "not in") against *value*."""
    if condition == "in":
        return lambda candidate: candidate in value
    if condition == "not in":
        return lambda candidate: candidate not in value
    raise ValueError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]:
    """Map a numeric comparison operator to a predicate.

    Note the unicode "≥"/"≤" spellings for >= and <=. Raises ValueError for
    an unknown condition.
    """
    comparators = {
        "=": lambda x: x == value,
        "!=": lambda x: x != value,
        "<": lambda x: x < value,
        "≤": lambda x: x <= value,
        ">": lambda x: x > value,
        "≥": lambda x: x >= value,
    }
    try:
        return comparators[condition]
    except KeyError:
        raise ValueError(f"Invalid condition: {condition}") from None
|
||||
|
||||
|
||||
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
    """Compose an attribute extractor with a comparison predicate for Files.

    String-valued attributes use string conditions, "type"/"transfer_method"
    use sequence membership, and "size" uses numeric comparison.
    """
    # NOTE(review): _get_file_extract_string_func (as written) only matches
    # the spellings "mimetype" and "urL", so the "mime_type" / "url" keys
    # accepted here appear to raise ValueError inside the extractor —
    # confirm against its key set.
    if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
        extract_func = _get_file_extract_string_func(key=key)
        return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
    # NOTE(review): str is itself a Sequence, so a plain-string value with
    # key "type"/"transfer_method" takes this branch — verify that is intended.
    if key in {"type", "transfer_method"} and isinstance(value, Sequence):
        extract_func = _get_file_extract_string_func(key=key)
        return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x))
    elif key == "size" and isinstance(value, str):
        extract_func = _get_file_extract_number_func(key=key)
        return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x))
    else:
        raise ValueError(f"Invalid key: {key}")
|
||||
|
||||
|
||||
def _contains(value: str):
|
||||
return lambda x: value in x
|
||||
|
||||
|
||||
def _startswith(value: str):
|
||||
return lambda x: x.startswith(value)
|
||||
|
||||
|
||||
def _endswith(value: str):
|
||||
return lambda x: x.endswith(value)
|
||||
|
||||
|
||||
def _is(value: str):
|
||||
return lambda x: x is value
|
||||
|
||||
|
||||
def _in(value: str | Sequence[str]):
|
||||
return lambda x: x in value
|
||||
|
||||
|
||||
def _eq(value: int | float):
|
||||
return lambda x: x == value
|
||||
|
||||
|
||||
def _ne(value: int | float):
|
||||
return lambda x: x != value
|
||||
|
||||
|
||||
def _lt(value: int | float):
|
||||
return lambda x: x < value
|
||||
|
||||
|
||||
def _le(value: int | float):
|
||||
return lambda x: x <= value
|
||||
|
||||
|
||||
def _gt(value: int | float):
|
||||
return lambda x: x > value
|
||||
|
||||
|
||||
def _ge(value: int | float):
|
||||
return lambda x: x >= value
|
||||
|
||||
|
||||
def _order_number(*, order: Literal["asc", "desc"], array: Sequence[int | float]):
|
||||
return sorted(array, key=lambda x: x, reverse=order == "desc")
|
||||
|
||||
|
||||
def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]):
|
||||
return sorted(array, key=lambda x: x, reverse=order == "desc")
|
||||
|
||||
|
||||
def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]):
|
||||
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "urL"}:
|
||||
extract_func = _get_file_extract_string_func(key=order_by)
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
|
||||
elif order_by == "size":
|
||||
extract_func = _get_file_extract_number_func(key=order_by)
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
|
||||
else:
|
||||
raise ValueError(f"Invalid order key: {order_by}")
|
||||
@@ -0,0 +1,17 @@
|
||||
from .entities import (
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeCompletionModelPromptTemplate,
|
||||
LLMNodeData,
|
||||
ModelConfig,
|
||||
VisionConfig,
|
||||
)
|
||||
from .node import LLMNode
|
||||
|
||||
__all__ = [
|
||||
"LLMNode",
|
||||
"LLMNodeChatModelMessage",
|
||||
"LLMNodeCompletionModelPromptTemplate",
|
||||
"LLMNodeData",
|
||||
"ModelConfig",
|
||||
"VisionConfig",
|
||||
]
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.model_runtime.entities import ImagePromptMessageContent
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
"""
|
||||
Model Config.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
name: str
|
||||
mode: str
|
||||
@@ -19,62 +17,36 @@ class ModelConfig(BaseModel):
|
||||
|
||||
|
||||
class ContextConfig(BaseModel):
|
||||
"""
|
||||
Context Config.
|
||||
"""
|
||||
|
||||
enabled: bool
|
||||
variable_selector: Optional[list[str]] = None
|
||||
|
||||
|
||||
class VisionConfigOptions(BaseModel):
|
||||
variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"])
|
||||
detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH
|
||||
|
||||
|
||||
class VisionConfig(BaseModel):
|
||||
"""
|
||||
Vision Config.
|
||||
"""
|
||||
|
||||
class Configs(BaseModel):
|
||||
"""
|
||||
Configs.
|
||||
"""
|
||||
|
||||
detail: Literal["low", "high"]
|
||||
|
||||
enabled: bool
|
||||
configs: Optional[Configs] = None
|
||||
enabled: bool = False
|
||||
configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions)
|
||||
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
"""
|
||||
Prompt Config.
|
||||
"""
|
||||
|
||||
jinja2_variables: Optional[list[VariableSelector]] = None
|
||||
|
||||
|
||||
class LLMNodeChatModelMessage(ChatModelMessage):
|
||||
"""
|
||||
LLM Node Chat Model Message.
|
||||
"""
|
||||
|
||||
jinja2_text: Optional[str] = None
|
||||
|
||||
|
||||
class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
|
||||
"""
|
||||
LLM Node Chat Model Prompt Template.
|
||||
"""
|
||||
|
||||
jinja2_text: Optional[str] = None
|
||||
|
||||
|
||||
class LLMNodeData(BaseNodeData):
|
||||
"""
|
||||
LLM Node Data.
|
||||
"""
|
||||
|
||||
model: ModelConfig
|
||||
prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate]
|
||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
prompt_config: Optional[PromptConfig] = None
|
||||
memory: Optional[MemoryConfig] = None
|
||||
context: ContextConfig
|
||||
vision: VisionConfig
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
|
||||
@@ -1,39 +1,40 @@
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
from core.model_runtime.entities import (
|
||||
AudioPromptMessageContent,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.variables import ArrayAnySegment, ArrayFileSegment, FileSegment
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeCompletionModelPromptTemplate,
|
||||
LLMNodeData,
|
||||
ModelConfig,
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event import (
|
||||
ModelInvokeCompletedEvent,
|
||||
NodeEvent,
|
||||
RunCompletedEvent,
|
||||
RunRetrieverResourceEvent,
|
||||
RunStreamChunkEvent,
|
||||
)
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from extensions.ext_database import db
|
||||
@@ -41,44 +42,34 @@ from models.model import Conversation
|
||||
from models.provider import Provider, ProviderType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import (
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeCompletionModelPromptTemplate,
|
||||
LLMNodeData,
|
||||
ModelConfig,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.file_obj import FileVar
|
||||
from core.file.models import File
|
||||
|
||||
|
||||
class ModelInvokeCompleted(BaseModel):
|
||||
"""
|
||||
Model invoke completed
|
||||
"""
|
||||
|
||||
text: str
|
||||
usage: LLMUsage
|
||||
finish_reason: Optional[str] = None
|
||||
|
||||
|
||||
class LLMNode(BaseNode):
|
||||
class LLMNode(BaseNode[LLMNodeData]):
|
||||
_node_data_cls = LLMNodeData
|
||||
_node_type = NodeType.LLM
|
||||
|
||||
def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
|
||||
"""
|
||||
Run node
|
||||
:return:
|
||||
"""
|
||||
node_data = cast(LLMNodeData, deepcopy(self.node_data))
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
node_inputs = None
|
||||
process_data = None
|
||||
|
||||
try:
|
||||
# init messages template
|
||||
node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template)
|
||||
self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
|
||||
|
||||
# fetch variables and fetch values from variable pool
|
||||
inputs = self._fetch_inputs(node_data, variable_pool)
|
||||
inputs = self._fetch_inputs(node_data=self.node_data)
|
||||
|
||||
# fetch jinja2 inputs
|
||||
jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool)
|
||||
jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data)
|
||||
|
||||
# merge inputs
|
||||
inputs.update(jinja_inputs)
|
||||
@@ -86,13 +77,17 @@ class LLMNode(BaseNode):
|
||||
node_inputs = {}
|
||||
|
||||
# fetch files
|
||||
files = self._fetch_files(node_data, variable_pool)
|
||||
files = (
|
||||
self._fetch_files(selector=self.node_data.vision.configs.variable_selector)
|
||||
if self.node_data.vision.enabled
|
||||
else []
|
||||
)
|
||||
|
||||
if files:
|
||||
node_inputs["#files#"] = [file.to_dict() for file in files]
|
||||
|
||||
# fetch context value
|
||||
generator = self._fetch_context(node_data, variable_pool)
|
||||
generator = self._fetch_context(node_data=self.node_data)
|
||||
context = None
|
||||
for event in generator:
|
||||
if isinstance(event, RunRetrieverResourceEvent):
|
||||
@@ -103,21 +98,30 @@ class LLMNode(BaseNode):
|
||||
node_inputs["#context#"] = context # type: ignore
|
||||
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(node_data.model)
|
||||
model_instance, model_config = self._fetch_model_config(self.node_data.model)
|
||||
|
||||
# fetch memory
|
||||
memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
|
||||
memory = self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance)
|
||||
|
||||
# fetch prompt messages
|
||||
if self.node_data.memory:
|
||||
query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
|
||||
if not query:
|
||||
raise ValueError("Query not found")
|
||||
query = query.text
|
||||
else:
|
||||
query = None
|
||||
|
||||
prompt_messages, stop = self._fetch_prompt_messages(
|
||||
node_data=node_data,
|
||||
query=variable_pool.get_any(["sys", SystemVariableKey.QUERY.value]) if node_data.memory else None,
|
||||
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
|
||||
system_query=query,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
vision_detail=self.node_data.vision.configs.detail,
|
||||
prompt_template=self.node_data.prompt_template,
|
||||
memory_config=self.node_data.memory,
|
||||
)
|
||||
|
||||
process_data = {
|
||||
@@ -131,7 +135,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
# handle invoke result
|
||||
generator = self._invoke_llm(
|
||||
node_data_model=node_data.model,
|
||||
node_data_model=self.node_data.model,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
@@ -143,7 +147,7 @@ class LLMNode(BaseNode):
|
||||
for event in generator:
|
||||
if isinstance(event, RunStreamChunkEvent):
|
||||
yield event
|
||||
elif isinstance(event, ModelInvokeCompleted):
|
||||
elif isinstance(event, ModelInvokeCompletedEvent):
|
||||
result_text = event.text
|
||||
usage = event.usage
|
||||
finish_reason = event.finish_reason
|
||||
@@ -182,15 +186,7 @@ class LLMNode(BaseNode):
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: list[PromptMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
) -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
|
||||
"""
|
||||
Invoke large language model
|
||||
:param node_data_model: node data model
|
||||
:param model_instance: model instance
|
||||
:param prompt_messages: prompt messages
|
||||
:param stop: stop
|
||||
:return:
|
||||
"""
|
||||
) -> Generator[NodeEvent, None, None]:
|
||||
db.session.close()
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
@@ -207,20 +203,13 @@ class LLMNode(BaseNode):
|
||||
usage = LLMUsage.empty_usage()
|
||||
for event in generator:
|
||||
yield event
|
||||
if isinstance(event, ModelInvokeCompleted):
|
||||
if isinstance(event, ModelInvokeCompletedEvent):
|
||||
usage = event.usage
|
||||
|
||||
# deduct quota
|
||||
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
|
||||
def _handle_invoke_result(
|
||||
self, invoke_result: LLMResult | Generator
|
||||
) -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
|
||||
"""
|
||||
Handle invoke result
|
||||
:param invoke_result: invoke result
|
||||
:return:
|
||||
"""
|
||||
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
|
||||
if isinstance(invoke_result, LLMResult):
|
||||
return
|
||||
|
||||
@@ -250,18 +239,11 @@ class LLMNode(BaseNode):
|
||||
if not usage:
|
||||
usage = LLMUsage.empty_usage()
|
||||
|
||||
yield ModelInvokeCompleted(text=full_text, usage=usage, finish_reason=finish_reason)
|
||||
yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason)
|
||||
|
||||
def _transform_chat_messages(
|
||||
self, messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
|
||||
"""
|
||||
Transform chat messages
|
||||
|
||||
:param messages: chat messages
|
||||
:return:
|
||||
"""
|
||||
|
||||
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
|
||||
) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
|
||||
if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
|
||||
if messages.edition_type == "jinja2" and messages.jinja2_text:
|
||||
messages.text = messages.jinja2_text
|
||||
@@ -274,13 +256,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
return messages
|
||||
|
||||
def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
|
||||
"""
|
||||
Fetch jinja inputs
|
||||
:param node_data: node data
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
|
||||
variables = {}
|
||||
|
||||
if not node_data.prompt_config:
|
||||
@@ -288,7 +264,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
||||
variable = variable_selector.variable
|
||||
value = variable_pool.get_any(variable_selector.value_selector)
|
||||
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
|
||||
|
||||
def parse_dict(d: dict) -> str:
|
||||
"""
|
||||
@@ -330,13 +306,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
return variables
|
||||
|
||||
def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
|
||||
"""
|
||||
Fetch inputs
|
||||
:param node_data: node data
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
|
||||
inputs = {}
|
||||
prompt_template = node_data.prompt_template
|
||||
|
||||
@@ -350,7 +320,7 @@ class LLMNode(BaseNode):
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
for variable_selector in variable_selectors:
|
||||
variable_value = variable_pool.get_any(variable_selector.value_selector)
|
||||
variable_value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
|
||||
if variable_value is None:
|
||||
raise ValueError(f"Variable {variable_selector.variable} not found")
|
||||
|
||||
@@ -362,7 +332,7 @@ class LLMNode(BaseNode):
|
||||
template=memory.query_prompt_template
|
||||
).extract_variable_selectors()
|
||||
for variable_selector in query_variable_selectors:
|
||||
variable_value = variable_pool.get_any(variable_selector.value_selector)
|
||||
variable_value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
|
||||
if variable_value is None:
|
||||
raise ValueError(f"Variable {variable_selector.variable} not found")
|
||||
|
||||
@@ -370,36 +340,28 @@ class LLMNode(BaseNode):
|
||||
|
||||
return inputs
|
||||
|
||||
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]:
|
||||
"""
|
||||
Fetch files
|
||||
:param node_data: node data
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
if not node_data.vision.enabled:
|
||||
def _fetch_files(self, *, selector: Sequence[str]) -> Sequence["File"]:
|
||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
if variable is None:
|
||||
return []
|
||||
|
||||
files = variable_pool.get_any(["sys", SystemVariableKey.FILES.value])
|
||||
if not files:
|
||||
if isinstance(variable, FileSegment):
|
||||
return [variable.value]
|
||||
if isinstance(variable, ArrayFileSegment):
|
||||
return variable.value
|
||||
# FIXME: Temporary fix for empty array,
|
||||
# all variables added to variable pool should be a Segment instance.
|
||||
if isinstance(variable, ArrayAnySegment) and len(variable.value) == 0:
|
||||
return []
|
||||
raise ValueError(f"Invalid variable type: {type(variable)}")
|
||||
|
||||
return files
|
||||
|
||||
def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]:
|
||||
"""
|
||||
Fetch context
|
||||
:param node_data: node data
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
def _fetch_context(self, node_data: LLMNodeData):
|
||||
if not node_data.context.enabled:
|
||||
return
|
||||
|
||||
if not node_data.context.variable_selector:
|
||||
return
|
||||
|
||||
context_value = variable_pool.get_any(node_data.context.variable_selector)
|
||||
context_value = self.graph_runtime_state.variable_pool.get_any(node_data.context.variable_selector)
|
||||
if context_value:
|
||||
if isinstance(context_value, str):
|
||||
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value)
|
||||
@@ -424,11 +386,6 @@ class LLMNode(BaseNode):
|
||||
)
|
||||
|
||||
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
|
||||
"""
|
||||
Convert to original retriever resource, temp.
|
||||
:param context_dict: context dict
|
||||
:return:
|
||||
"""
|
||||
if (
|
||||
"metadata" in context_dict
|
||||
and "_source" in context_dict["metadata"]
|
||||
@@ -451,6 +408,7 @@ class LLMNode(BaseNode):
|
||||
"segment_position": metadata.get("segment_position"),
|
||||
"index_node_hash": metadata.get("segment_index_node_hash"),
|
||||
"content": context_dict.get("content"),
|
||||
"page": metadata.get("page"),
|
||||
}
|
||||
|
||||
return source
|
||||
@@ -460,11 +418,6 @@ class LLMNode(BaseNode):
|
||||
def _fetch_model_config(
|
||||
self, node_data_model: ModelConfig
|
||||
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
"""
|
||||
Fetch model config
|
||||
:param node_data_model: node data model
|
||||
:return:
|
||||
"""
|
||||
model_name = node_data_model.name
|
||||
provider_name = node_data_model.provider
|
||||
|
||||
@@ -523,19 +476,15 @@ class LLMNode(BaseNode):
|
||||
)
|
||||
|
||||
def _fetch_memory(
|
||||
self, node_data_memory: Optional[MemoryConfig], variable_pool: VariablePool, model_instance: ModelInstance
|
||||
self, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
|
||||
) -> Optional[TokenBufferMemory]:
|
||||
"""
|
||||
Fetch memory
|
||||
:param node_data_memory: node data memory
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
if not node_data_memory:
|
||||
return None
|
||||
|
||||
# get conversation id
|
||||
conversation_id = variable_pool.get_any(["sys", SystemVariableKey.CONVERSATION_ID.value])
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get_any(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID.value]
|
||||
)
|
||||
if conversation_id is None:
|
||||
return None
|
||||
|
||||
@@ -555,43 +504,31 @@ class LLMNode(BaseNode):
|
||||
|
||||
def _fetch_prompt_messages(
|
||||
self,
|
||||
node_data: LLMNodeData,
|
||||
query: Optional[str],
|
||||
query_prompt_template: Optional[str],
|
||||
inputs: dict[str, str],
|
||||
files: list["FileVar"],
|
||||
context: Optional[str],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
*,
|
||||
system_query: str | None = None,
|
||||
inputs: dict[str, str] | None = None,
|
||||
files: Sequence["File"],
|
||||
context: str | None = None,
|
||||
memory: TokenBufferMemory | None = None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
|
||||
memory_config: MemoryConfig | None = None,
|
||||
vision_detail: ImagePromptMessageContent.DETAIL,
|
||||
) -> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
"""
|
||||
Fetch prompt messages
|
||||
:param node_data: node data
|
||||
:param query: query
|
||||
:param query_prompt_template: query prompt template
|
||||
:param inputs: inputs
|
||||
:param files: files
|
||||
:param context: context
|
||||
:param memory: memory
|
||||
:param model_config: model config
|
||||
:return:
|
||||
"""
|
||||
inputs = inputs or {}
|
||||
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=node_data.prompt_template,
|
||||
prompt_template=prompt_template,
|
||||
inputs=inputs,
|
||||
query=query or "",
|
||||
query=system_query or "",
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=node_data.memory,
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
query_prompt_template=query_prompt_template,
|
||||
)
|
||||
stop = model_config.stop
|
||||
|
||||
vision_enabled = node_data.vision.enabled
|
||||
vision_detail = node_data.vision.configs.detail if node_data.vision.configs else None
|
||||
filtered_prompt_messages = []
|
||||
for prompt_message in prompt_messages:
|
||||
if prompt_message.is_empty():
|
||||
@@ -599,17 +536,13 @@ class LLMNode(BaseNode):
|
||||
|
||||
if not isinstance(prompt_message.content, str):
|
||||
prompt_message_content = []
|
||||
for content_item in prompt_message.content:
|
||||
if (
|
||||
vision_enabled
|
||||
and content_item.type == PromptMessageContentType.IMAGE
|
||||
and isinstance(content_item, ImagePromptMessageContent)
|
||||
):
|
||||
# Override vision config if LLM node has vision config
|
||||
if vision_detail:
|
||||
content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail)
|
||||
for content_item in prompt_message.content or []:
|
||||
if isinstance(content_item, ImagePromptMessageContent):
|
||||
# Override vision config if LLM node has vision config,
|
||||
# cuz vision detail is related to the configuration from FileUpload feature.
|
||||
content_item.detail = vision_detail
|
||||
prompt_message_content.append(content_item)
|
||||
elif content_item.type == PromptMessageContentType.TEXT:
|
||||
elif isinstance(content_item, TextPromptMessageContent | AudioPromptMessageContent):
|
||||
prompt_message_content.append(content_item)
|
||||
|
||||
if len(prompt_message_content) > 1:
|
||||
@@ -631,13 +564,6 @@ class LLMNode(BaseNode):
|
||||
|
||||
@classmethod
|
||||
def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
|
||||
"""
|
||||
Deduct LLM quota
|
||||
:param tenant_id: tenant id
|
||||
:param model_instance: model instance
|
||||
:param usage: usage
|
||||
:return:
|
||||
"""
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
@@ -668,7 +594,7 @@ class LLMNode(BaseNode):
|
||||
else:
|
||||
used_quota = 1
|
||||
|
||||
if used_quota is not None:
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
db.session.query(Provider).filter(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == model_instance.provider,
|
||||
@@ -680,27 +606,28 @@ class LLMNode(BaseNode):
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: LLMNodeData
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: LLMNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
prompt_template = node_data.prompt_template
|
||||
|
||||
variable_selectors = []
|
||||
if isinstance(prompt_template, list):
|
||||
if isinstance(prompt_template, list) and all(
|
||||
isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template
|
||||
):
|
||||
for prompt in prompt_template:
|
||||
if prompt.edition_type != "jinja2":
|
||||
variable_template_parser = VariableTemplateParser(template=prompt.text)
|
||||
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
||||
else:
|
||||
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
||||
if prompt_template.edition_type != "jinja2":
|
||||
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
else:
|
||||
raise ValueError(f"Invalid prompt template type: {type(prompt_template)}")
|
||||
|
||||
variable_mapping = {}
|
||||
for variable_selector in variable_selectors:
|
||||
@@ -745,11 +672,6 @@ class LLMNode(BaseNode):
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters.
|
||||
:return:
|
||||
"""
|
||||
return {
|
||||
"type": "llm",
|
||||
"config": {
|
||||
@@ -1,4 +1,4 @@
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
|
||||
from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState
|
||||
|
||||
|
||||
class LoopNodeData(BaseIterationNodeData):
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.loop.entities import LoopNodeData, LoopState
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
|
||||
class LoopNode(BaseNode):
|
||||
class LoopNode(BaseNode[LoopNodeData]):
|
||||
"""
|
||||
Loop Node.
|
||||
"""
|
||||
|
||||
@@ -1,22 +1,24 @@
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
|
||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
||||
from core.workflow.nodes.iteration.iteration_start_node import IterationStartNode
|
||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
|
||||
from core.workflow.nodes.answer import AnswerNode
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.code import CodeNode
|
||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode
|
||||
from core.workflow.nodes.end import EndNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.http_request import HttpRequestNode
|
||||
from core.workflow.nodes.if_else import IfElseNode
|
||||
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
|
||||
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.list_operator import ListOperatorNode
|
||||
from core.workflow.nodes.llm import LLMNode
|
||||
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from core.workflow.nodes.start import StartNode
|
||||
from core.workflow.nodes.template_transform import TemplateTransformNode
|
||||
from core.workflow.nodes.tool import ToolNode
|
||||
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||
from core.workflow.nodes.variable_assigner import VariableAssignerNode
|
||||
|
||||
node_classes = {
|
||||
node_type_classes_mapping: dict[NodeType, type[BaseNode]] = {
|
||||
NodeType.START: StartNode,
|
||||
NodeType.END: EndNode,
|
||||
NodeType.ANSWER: AnswerNode,
|
||||
@@ -34,4 +36,6 @@ node_classes = {
|
||||
NodeType.ITERATION_START: IterationStartNode,
|
||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
|
||||
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
|
||||
NodeType.DOCUMENT_EXTRACTOR: DocumentExtractorNode,
|
||||
NodeType.LIST_OPERATOR: ListOperatorNode,
|
||||
}
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .parameter_extractor_node import ParameterExtractorNode
|
||||
|
||||
__all__ = ["ParameterExtractorNode"]
|
||||
|
||||
@@ -1,20 +1,10 @@
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
"""
|
||||
Model Config.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
name: str
|
||||
mode: str
|
||||
completion_params: dict[str, Any] = {}
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.llm import ModelConfig, VisionConfig
|
||||
|
||||
|
||||
class ParameterConfig(BaseModel):
|
||||
@@ -49,6 +39,7 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
instruction: Optional[str] = None
|
||||
memory: Optional[MemoryConfig] = None
|
||||
reasoning_mode: Literal["function_call", "prompt"]
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
|
||||
@field_validator("reasoning_mode", mode="before")
|
||||
@classmethod
|
||||
@@ -64,7 +55,7 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
parameters = {"type": "object", "properties": {}, "required": []}
|
||||
|
||||
for parameter in self.parameters:
|
||||
parameter_schema = {"description": parameter.description}
|
||||
parameter_schema: dict[str, Any] = {"description": parameter.description}
|
||||
|
||||
if parameter.type in {"string", "select"}:
|
||||
parameter_schema["type"] = "string"
|
||||
|
||||
@@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file import File
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
@@ -22,12 +23,16 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.llm.entities import ModelConfig
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
||||
from core.workflow.nodes.parameter_extractor.prompts import (
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.llm import LLMNode, ModelConfig
|
||||
from core.workflow.utils import variable_template_parser
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import ParameterExtractorNodeData
|
||||
from .prompts import (
|
||||
CHAT_EXAMPLE,
|
||||
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
|
||||
COMPLETION_GENERATE_JSON_PROMPT,
|
||||
@@ -36,9 +41,6 @@ from core.workflow.nodes.parameter_extractor.prompts import (
|
||||
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT,
|
||||
FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE,
|
||||
)
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class ParameterExtractorNode(LLMNode):
|
||||
@@ -65,33 +67,39 @@ class ParameterExtractorNode(LLMNode):
|
||||
}
|
||||
}
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
def _run(self):
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
node_data = cast(ParameterExtractorNodeData, self.node_data)
|
||||
variable = self.graph_runtime_state.variable_pool.get_any(node_data.query)
|
||||
if not variable:
|
||||
raise ValueError("Input variable content not found or is empty")
|
||||
query = variable
|
||||
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
|
||||
query = variable.text if variable else ""
|
||||
|
||||
inputs = {
|
||||
"query": query,
|
||||
"parameters": jsonable_encoder(node_data.parameters),
|
||||
"instruction": jsonable_encoder(node_data.instruction),
|
||||
}
|
||||
files = (
|
||||
self._fetch_files(
|
||||
selector=node_data.vision.configs.variable_selector,
|
||||
)
|
||||
if node_data.vision.enabled
|
||||
else []
|
||||
)
|
||||
|
||||
model_instance, model_config = self._fetch_model_config(node_data.model)
|
||||
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
|
||||
raise ValueError("Model is not a Large Language Model")
|
||||
|
||||
llm_model = model_instance.model_type_instance
|
||||
model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
|
||||
model_schema = llm_model.get_model_schema(
|
||||
model=model_config.model,
|
||||
credentials=model_config.credentials,
|
||||
)
|
||||
if not model_schema:
|
||||
raise ValueError("Model schema not found")
|
||||
|
||||
# fetch memory
|
||||
memory = self._fetch_memory(node_data.memory, self.graph_runtime_state.variable_pool, model_instance)
|
||||
memory = self._fetch_memory(
|
||||
node_data_memory=node_data.memory,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
if (
|
||||
set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}
|
||||
@@ -99,15 +107,33 @@ class ParameterExtractorNode(LLMNode):
|
||||
):
|
||||
# use function call
|
||||
prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
|
||||
node_data, query, self.graph_runtime_state.variable_pool, model_config, memory
|
||||
node_data=node_data,
|
||||
query=query,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
model_config=model_config,
|
||||
memory=memory,
|
||||
files=files,
|
||||
)
|
||||
else:
|
||||
# use prompt engineering
|
||||
prompt_messages = self._generate_prompt_engineering_prompt(
|
||||
node_data, query, self.graph_runtime_state.variable_pool, model_config, memory
|
||||
data=node_data,
|
||||
query=query,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
model_config=model_config,
|
||||
memory=memory,
|
||||
files=files,
|
||||
)
|
||||
|
||||
prompt_message_tools = []
|
||||
|
||||
inputs = {
|
||||
"query": query,
|
||||
"files": [f.to_dict() for f in files],
|
||||
"parameters": jsonable_encoder(node_data.parameters),
|
||||
"instruction": jsonable_encoder(node_data.instruction),
|
||||
}
|
||||
|
||||
process_data = {
|
||||
"model_mode": model_config.mode,
|
||||
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
@@ -119,7 +145,7 @@ class ParameterExtractorNode(LLMNode):
|
||||
}
|
||||
|
||||
try:
|
||||
text, usage, tool_call = self._invoke_llm(
|
||||
text, usage, tool_call = self._invoke(
|
||||
node_data_model=node_data.model,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
@@ -150,12 +176,12 @@ class ParameterExtractorNode(LLMNode):
|
||||
error = "Failed to extract result from function call or text response, using empty result."
|
||||
|
||||
try:
|
||||
result = self._validate_result(node_data, result)
|
||||
result = self._validate_result(data=node_data, result=result or {})
|
||||
except Exception as e:
|
||||
error = str(e)
|
||||
|
||||
# transform result into standard format
|
||||
result = self._transform_result(node_data, result)
|
||||
result = self._transform_result(data=node_data, result=result or {})
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
@@ -170,7 +196,7 @@ class ParameterExtractorNode(LLMNode):
|
||||
llm_usage=usage,
|
||||
)
|
||||
|
||||
def _invoke_llm(
|
||||
def _invoke(
|
||||
self,
|
||||
node_data_model: ModelConfig,
|
||||
model_instance: ModelInstance,
|
||||
@@ -178,14 +204,6 @@ class ParameterExtractorNode(LLMNode):
|
||||
tools: list[PromptMessageTool],
|
||||
stop: list[str],
|
||||
) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
|
||||
"""
|
||||
Invoke large language model
|
||||
:param node_data_model: node data model
|
||||
:param model_instance: model instance
|
||||
:param prompt_messages: prompt messages
|
||||
:param stop: stop
|
||||
:return:
|
||||
"""
|
||||
db.session.close()
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
@@ -202,6 +220,9 @@ class ParameterExtractorNode(LLMNode):
|
||||
raise ValueError(f"Invalid invoke result: {invoke_result}")
|
||||
|
||||
text = invoke_result.message.content
|
||||
if not isinstance(text, str):
|
||||
raise ValueError(f"Invalid text content type: {type(text)}. Expected str.")
|
||||
|
||||
usage = invoke_result.usage
|
||||
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
|
||||
|
||||
@@ -217,6 +238,7 @@ class ParameterExtractorNode(LLMNode):
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
files: Sequence[File],
|
||||
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
|
||||
"""
|
||||
Generate function call prompt.
|
||||
@@ -234,7 +256,7 @@ class ParameterExtractorNode(LLMNode):
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query="",
|
||||
files=[],
|
||||
files=files,
|
||||
context="",
|
||||
memory_config=node_data.memory,
|
||||
memory=None,
|
||||
@@ -296,6 +318,7 @@ class ParameterExtractorNode(LLMNode):
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
files: Sequence[File],
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Generate prompt engineering prompt.
|
||||
@@ -303,9 +326,23 @@ class ParameterExtractorNode(LLMNode):
|
||||
model_mode = ModelMode.value_of(data.model.mode)
|
||||
|
||||
if model_mode == ModelMode.COMPLETION:
|
||||
return self._generate_prompt_engineering_completion_prompt(data, query, variable_pool, model_config, memory)
|
||||
return self._generate_prompt_engineering_completion_prompt(
|
||||
node_data=data,
|
||||
query=query,
|
||||
variable_pool=variable_pool,
|
||||
model_config=model_config,
|
||||
memory=memory,
|
||||
files=files,
|
||||
)
|
||||
elif model_mode == ModelMode.CHAT:
|
||||
return self._generate_prompt_engineering_chat_prompt(data, query, variable_pool, model_config, memory)
|
||||
return self._generate_prompt_engineering_chat_prompt(
|
||||
node_data=data,
|
||||
query=query,
|
||||
variable_pool=variable_pool,
|
||||
model_config=model_config,
|
||||
memory=memory,
|
||||
files=files,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid model mode: {model_mode}")
|
||||
|
||||
@@ -316,20 +353,23 @@ class ParameterExtractorNode(LLMNode):
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
files: Sequence[File],
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Generate completion prompt.
|
||||
"""
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "")
|
||||
rest_token = self._calculate_rest_token(
|
||||
node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
|
||||
)
|
||||
prompt_template = self._get_prompt_engineering_prompt_template(
|
||||
node_data, query, variable_pool, memory, rest_token
|
||||
node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token
|
||||
)
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={"structure": json.dumps(node_data.get_parameter_json_schema())},
|
||||
query="",
|
||||
files=[],
|
||||
files=files,
|
||||
context="",
|
||||
memory_config=node_data.memory,
|
||||
memory=memory,
|
||||
@@ -345,27 +385,30 @@ class ParameterExtractorNode(LLMNode):
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
files: Sequence[File],
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Generate chat prompt.
|
||||
"""
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "")
|
||||
rest_token = self._calculate_rest_token(
|
||||
node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
|
||||
)
|
||||
prompt_template = self._get_prompt_engineering_prompt_template(
|
||||
node_data,
|
||||
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
|
||||
node_data=node_data,
|
||||
query=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
|
||||
structure=json.dumps(node_data.get_parameter_json_schema()), text=query
|
||||
),
|
||||
variable_pool,
|
||||
memory,
|
||||
rest_token,
|
||||
variable_pool=variable_pool,
|
||||
memory=memory,
|
||||
max_token_limit=rest_token,
|
||||
)
|
||||
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query="",
|
||||
files=[],
|
||||
files=files,
|
||||
context="",
|
||||
memory_config=node_data.memory,
|
||||
memory=None,
|
||||
@@ -425,10 +468,11 @@ class ParameterExtractorNode(LLMNode):
|
||||
raise ValueError(f"Invalid `string` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type.startswith("array"):
|
||||
if not isinstance(result.get(parameter.name), list):
|
||||
parameters = result.get(parameter.name)
|
||||
if not isinstance(parameters, list):
|
||||
raise ValueError(f"Invalid `array` value for parameter {parameter.name}")
|
||||
nested_type = parameter.type[6:-1]
|
||||
for item in result.get(parameter.name):
|
||||
for item in parameters:
|
||||
if nested_type == "number" and not isinstance(item, int | float):
|
||||
raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
|
||||
if nested_type == "string" and not isinstance(item, str):
|
||||
@@ -565,18 +609,6 @@ class ParameterExtractorNode(LLMNode):
|
||||
|
||||
return result
|
||||
|
||||
def _render_instruction(self, instruction: str, variable_pool: VariablePool) -> str:
|
||||
"""
|
||||
Render instruction.
|
||||
"""
|
||||
variable_template_parser = VariableTemplateParser(instruction)
|
||||
inputs = {}
|
||||
for selector in variable_template_parser.extract_variable_selectors():
|
||||
variable = variable_pool.get_any(selector.value_selector)
|
||||
inputs[selector.variable] = variable
|
||||
|
||||
return variable_template_parser.format(inputs)
|
||||
|
||||
def _get_function_calling_prompt_template(
|
||||
self,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
@@ -588,9 +620,9 @@ class ParameterExtractorNode(LLMNode):
|
||||
model_mode = ModelMode.value_of(node_data.model.mode)
|
||||
input_text = query
|
||||
memory_str = ""
|
||||
instruction = self._render_instruction(node_data.instruction or "", variable_pool)
|
||||
instruction = variable_pool.convert_template(node_data.instruction or "").text
|
||||
|
||||
if memory:
|
||||
if memory and node_data.memory and node_data.memory.window:
|
||||
memory_str = memory.get_history_prompt_text(
|
||||
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
|
||||
)
|
||||
@@ -611,13 +643,13 @@ class ParameterExtractorNode(LLMNode):
|
||||
variable_pool: VariablePool,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
max_token_limit: int = 2000,
|
||||
) -> list[ChatModelMessage]:
|
||||
):
|
||||
model_mode = ModelMode.value_of(node_data.model.mode)
|
||||
input_text = query
|
||||
memory_str = ""
|
||||
instruction = self._render_instruction(node_data.instruction or "", variable_pool)
|
||||
instruction = variable_pool.convert_template(node_data.instruction or "").text
|
||||
|
||||
if memory:
|
||||
if memory and node_data.memory and node_data.memory.window:
|
||||
memory_str = memory.get_history_prompt_text(
|
||||
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
|
||||
)
|
||||
@@ -691,7 +723,7 @@ class ParameterExtractorNode(LLMNode):
|
||||
):
|
||||
max_tokens = (
|
||||
model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template)
|
||||
or model_config.parameters.get(parameter_rule.use_template or "")
|
||||
) or 0
|
||||
|
||||
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
|
||||
@@ -712,7 +744,11 @@ class ParameterExtractorNode(LLMNode):
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: ParameterExtractorNodeData
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -721,11 +757,11 @@ class ParameterExtractorNode(LLMNode):
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
variable_mapping = {"query": node_data.query}
|
||||
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
|
||||
|
||||
if node_data.instruction:
|
||||
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
|
||||
for selector in variable_template_parser.extract_variable_selectors():
|
||||
selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction)
|
||||
for selector in selectors:
|
||||
variable_mapping[selector.variable] = selector.value_selector
|
||||
|
||||
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
from .entities import QuestionClassifierNodeData
|
||||
from .question_classifier_node import QuestionClassifierNode
|
||||
|
||||
__all__ = ["QuestionClassifierNodeData", "QuestionClassifierNode"]
|
||||
|
||||
@@ -1,39 +1,21 @@
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
"""
|
||||
Model Config.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
name: str
|
||||
mode: str
|
||||
completion_params: dict[str, Any] = {}
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.llm import ModelConfig, VisionConfig
|
||||
|
||||
|
||||
class ClassConfig(BaseModel):
|
||||
"""
|
||||
Class Config.
|
||||
"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class QuestionClassifierNodeData(BaseNodeData):
|
||||
"""
|
||||
Knowledge retrieval Node Data.
|
||||
"""
|
||||
|
||||
query_variable_selector: list[str]
|
||||
type: str = "question-classifier"
|
||||
model: ModelConfig
|
||||
classes: list[ClassConfig]
|
||||
instruction: Optional[str] = None
|
||||
memory: Optional[MemoryConfig] = None
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
|
||||
@@ -1,25 +1,30 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode, ModelInvokeCompleted
|
||||
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||
from core.workflow.nodes.question_classifier.template_prompts import (
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event import ModelInvokeCompletedEvent
|
||||
from core.workflow.nodes.llm import (
|
||||
LLMNode,
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeCompletionModelPromptTemplate,
|
||||
)
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import QuestionClassifierNodeData
|
||||
from .template_prompts import (
|
||||
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,
|
||||
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2,
|
||||
QUESTION_CLASSIFIER_COMPLETION_PROMPT,
|
||||
@@ -28,46 +33,77 @@ from core.workflow.nodes.question_classifier.template_prompts import (
|
||||
QUESTION_CLASSIFIER_USER_PROMPT_2,
|
||||
QUESTION_CLASSIFIER_USER_PROMPT_3,
|
||||
)
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file import File
|
||||
|
||||
|
||||
class QuestionClassifierNode(LLMNode):
|
||||
_node_data_cls = QuestionClassifierNodeData
|
||||
node_type = NodeType.QUESTION_CLASSIFIER
|
||||
_node_type = NodeType.QUESTION_CLASSIFIER
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data)
|
||||
node_data = cast(QuestionClassifierNodeData, node_data)
|
||||
def _run(self):
|
||||
node_data = cast(QuestionClassifierNodeData, self.node_data)
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
# extract variables
|
||||
variable = variable_pool.get(node_data.query_variable_selector)
|
||||
variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None
|
||||
query = variable.value if variable else None
|
||||
variables = {"query": query}
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(node_data.model)
|
||||
# fetch memory
|
||||
memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
|
||||
memory = self._fetch_memory(
|
||||
node_data_memory=node_data.memory,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
# fetch instruction
|
||||
instruction = self._format_instruction(node_data.instruction, variable_pool) if node_data.instruction else ""
|
||||
node_data.instruction = instruction
|
||||
node_data.instruction = node_data.instruction or ""
|
||||
node_data.instruction = variable_pool.convert_template(node_data.instruction).text
|
||||
|
||||
files: Sequence[File] = (
|
||||
self._fetch_files(
|
||||
selector=node_data.vision.configs.variable_selector,
|
||||
)
|
||||
if node_data.vision.enabled
|
||||
else []
|
||||
)
|
||||
|
||||
# fetch prompt messages
|
||||
prompt_messages, stop = self._fetch_prompt(
|
||||
node_data=node_data, context="", query=query, memory=memory, model_config=model_config
|
||||
rest_token = self._calculate_rest_token(
|
||||
node_data=node_data,
|
||||
query=query or "",
|
||||
model_config=model_config,
|
||||
context="",
|
||||
)
|
||||
prompt_template = self._get_prompt_template(
|
||||
node_data=node_data,
|
||||
query=query or "",
|
||||
memory=memory,
|
||||
max_token_limit=rest_token,
|
||||
)
|
||||
prompt_messages, stop = self._fetch_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
system_query=query,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
files=files,
|
||||
vision_detail=node_data.vision.configs.detail,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
generator = self._invoke_llm(
|
||||
node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop
|
||||
node_data_model=node_data.model,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
result_text = ""
|
||||
usage = LLMUsage.empty_usage()
|
||||
finish_reason = None
|
||||
for event in generator:
|
||||
if isinstance(event, ModelInvokeCompleted):
|
||||
if isinstance(event, ModelInvokeCompletedEvent):
|
||||
result_text = event.text
|
||||
usage = event.usage
|
||||
finish_reason = event.finish_reason
|
||||
@@ -129,7 +165,11 @@ class QuestionClassifierNode(LLMNode):
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: QuestionClassifierNodeData
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: QuestionClassifierNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -159,40 +199,6 @@ class QuestionClassifierNode(LLMNode):
|
||||
"""
|
||||
return {"type": "question-classifier", "config": {"instructions": ""}}
|
||||
|
||||
def _fetch_prompt(
|
||||
self,
|
||||
node_data: QuestionClassifierNodeData,
|
||||
query: str,
|
||||
context: Optional[str],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
) -> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
"""
|
||||
Fetch prompt
|
||||
:param node_data: node data
|
||||
:param query: inputs
|
||||
:param context: context
|
||||
:param memory: memory
|
||||
:param model_config: model config
|
||||
:return:
|
||||
"""
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
rest_token = self._calculate_rest_token(node_data, query, model_config, context)
|
||||
prompt_template = self._get_prompt_template(node_data, query, memory, rest_token)
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query="",
|
||||
files=[],
|
||||
context=context,
|
||||
memory_config=node_data.memory,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
)
|
||||
stop = model_config.stop
|
||||
|
||||
return prompt_messages, stop
|
||||
|
||||
def _calculate_rest_token(
|
||||
self,
|
||||
node_data: QuestionClassifierNodeData,
|
||||
@@ -229,7 +235,7 @@ class QuestionClassifierNode(LLMNode):
|
||||
):
|
||||
max_tokens = (
|
||||
model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template)
|
||||
or model_config.parameters.get(parameter_rule.use_template or "")
|
||||
) or 0
|
||||
|
||||
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
|
||||
@@ -243,7 +249,7 @@ class QuestionClassifierNode(LLMNode):
|
||||
query: str,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
max_token_limit: int = 2000,
|
||||
) -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]:
|
||||
):
|
||||
model_mode = ModelMode.value_of(node_data.model.mode)
|
||||
classes = node_data.classes
|
||||
categories = []
|
||||
@@ -255,31 +261,32 @@ class QuestionClassifierNode(LLMNode):
|
||||
memory_str = ""
|
||||
if memory:
|
||||
memory_str = memory.get_history_prompt_text(
|
||||
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
|
||||
max_token_limit=max_token_limit,
|
||||
message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
|
||||
)
|
||||
prompt_messages = []
|
||||
prompt_messages: list[LLMNodeChatModelMessage] = []
|
||||
if model_mode == ModelMode.CHAT:
|
||||
system_prompt_messages = ChatModelMessage(
|
||||
system_prompt_messages = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
|
||||
)
|
||||
prompt_messages.append(system_prompt_messages)
|
||||
user_prompt_message_1 = ChatModelMessage(
|
||||
user_prompt_message_1 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_1
|
||||
)
|
||||
prompt_messages.append(user_prompt_message_1)
|
||||
assistant_prompt_message_1 = ChatModelMessage(
|
||||
assistant_prompt_message_1 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1
|
||||
)
|
||||
prompt_messages.append(assistant_prompt_message_1)
|
||||
user_prompt_message_2 = ChatModelMessage(
|
||||
user_prompt_message_2 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2
|
||||
)
|
||||
prompt_messages.append(user_prompt_message_2)
|
||||
assistant_prompt_message_2 = ChatModelMessage(
|
||||
assistant_prompt_message_2 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2
|
||||
)
|
||||
prompt_messages.append(assistant_prompt_message_2)
|
||||
user_prompt_message_3 = ChatModelMessage(
|
||||
user_prompt_message_3 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.USER,
|
||||
text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(
|
||||
input_text=input_text,
|
||||
@@ -290,7 +297,7 @@ class QuestionClassifierNode(LLMNode):
|
||||
prompt_messages.append(user_prompt_message_3)
|
||||
return prompt_messages
|
||||
elif model_mode == ModelMode.COMPLETION:
|
||||
return CompletionModelPromptTemplate(
|
||||
return LLMNodeCompletionModelPromptTemplate(
|
||||
text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(
|
||||
histories=memory_str,
|
||||
input_text=input_text,
|
||||
@@ -302,23 +309,3 @@ class QuestionClassifierNode(LLMNode):
|
||||
|
||||
else:
|
||||
raise ValueError(f"Model mode {model_mode} not support.")
|
||||
|
||||
def _format_instruction(self, instruction: str, variable_pool: VariablePool) -> str:
|
||||
inputs = {}
|
||||
|
||||
variable_selectors = []
|
||||
variable_template_parser = VariableTemplateParser(template=instruction)
|
||||
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
||||
for variable_selector in variable_selectors:
|
||||
variable = variable_pool.get(variable_selector.value_selector)
|
||||
variable_value = variable.value if variable else None
|
||||
if variable_value is None:
|
||||
raise ValueError(f"Variable {variable_selector.variable} not found")
|
||||
|
||||
inputs[variable_selector.variable] = variable_value
|
||||
|
||||
prompt_template = PromptTemplateParser(template=instruction, with_variable_tmpl=True)
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
||||
|
||||
instruction = prompt_template.format(prompt_inputs)
|
||||
return instruction
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .start_node import StartNode
|
||||
|
||||
__all__ = ["StartNode"]
|
||||
|
||||
@@ -3,7 +3,7 @@ from collections.abc import Sequence
|
||||
from pydantic import Field
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class StartNodeData(BaseNodeData):
|
||||
|
||||
@@ -1,25 +1,24 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class StartNode(BaseNode):
|
||||
class StartNode(BaseNode[StartNodeData]):
|
||||
_node_data_cls = StartNodeData
|
||||
_node_type = NodeType.START
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run node
|
||||
:return:
|
||||
"""
|
||||
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
|
||||
system_inputs = self.graph_runtime_state.variable_pool.system_variables
|
||||
|
||||
# TODO: System variables should be directly accessible, no need for special handling
|
||||
# Set system variables as node outputs.
|
||||
for var in system_inputs:
|
||||
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
|
||||
|
||||
@@ -27,13 +26,10 @@ class StartNode(BaseNode):
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: StartNodeData
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: StartNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .template_transform_node import TemplateTransformNode
|
||||
|
||||
__all__ = ["TemplateTransformNode"]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class TemplateTransformNodeData(BaseNodeData):
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import os
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))
|
||||
|
||||
|
||||
class TemplateTransformNode(BaseNode):
|
||||
class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
|
||||
_node_data_cls = TemplateTransformNodeData
|
||||
_node_type = NodeType.TEMPLATE_TRANSFORM
|
||||
|
||||
@@ -28,22 +29,16 @@ class TemplateTransformNode(BaseNode):
|
||||
}
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run node
|
||||
"""
|
||||
node_data = self.node_data
|
||||
node_data: TemplateTransformNodeData = cast(self._node_data_cls, node_data)
|
||||
|
||||
# Get variables
|
||||
variables = {}
|
||||
for variable_selector in node_data.variables:
|
||||
for variable_selector in self.node_data.variables:
|
||||
variable_name = variable_selector.variable
|
||||
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
|
||||
variables[variable_name] = value
|
||||
# Run code
|
||||
try:
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JINJA2, code=node_data.template, inputs=variables
|
||||
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
|
||||
)
|
||||
except CodeExecutionError as e:
|
||||
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .tool_node import ToolNode
|
||||
|
||||
__all__ = ["ToolNode"]
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Literal, Union
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class ToolEntity(BaseModel):
|
||||
@@ -51,7 +51,4 @@ class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
raise ValueError("value must be a string, int, float, or bool")
|
||||
return typ
|
||||
|
||||
"""
|
||||
Tool Node Schema
|
||||
"""
|
||||
tool_parameters: dict[str, ToolInput]
|
||||
|
||||
@@ -1,24 +1,28 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from os import path
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.file.models import File, FileTransferMethod, FileType
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from models import WorkflowNodeExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from models import ToolFile
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class ToolNode(BaseNode):
|
||||
class ToolNode(BaseNode[ToolNodeData]):
|
||||
"""
|
||||
Tool Node
|
||||
"""
|
||||
@@ -27,37 +31,38 @@ class ToolNode(BaseNode):
|
||||
_node_type = NodeType.TOOL
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the tool node
|
||||
"""
|
||||
|
||||
node_data = cast(ToolNodeData, self.node_data)
|
||||
|
||||
# fetch tool icon
|
||||
tool_info = {"provider_type": node_data.provider_type, "provider_id": node_data.provider_id}
|
||||
tool_info = {
|
||||
"provider_type": self.node_data.provider_type,
|
||||
"provider_id": self.node_data.provider_id,
|
||||
}
|
||||
|
||||
# get tool runtime
|
||||
try:
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(
|
||||
self.tenant_id, self.app_id, self.node_id, node_data, self.invoke_from
|
||||
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
|
||||
)
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info,
|
||||
},
|
||||
error=f"Failed to get tool runtime: {str(e)}",
|
||||
)
|
||||
|
||||
# get parameters
|
||||
tool_parameters = tool_runtime.get_runtime_parameters() or []
|
||||
parameters = self._generate_parameters(
|
||||
tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
)
|
||||
parameters_for_log = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=node_data,
|
||||
node_data=self.node_data,
|
||||
for_log=True,
|
||||
)
|
||||
|
||||
@@ -74,7 +79,9 @@ class ToolNode(BaseNode):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info,
|
||||
},
|
||||
error=f"Failed to invoke tool: {str(e)}",
|
||||
)
|
||||
|
||||
@@ -83,8 +90,14 @@ class ToolNode(BaseNode):
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"text": plain_text, "files": files, "json": json},
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||
outputs={
|
||||
"text": plain_text,
|
||||
"files": files,
|
||||
"json": json,
|
||||
},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
)
|
||||
|
||||
@@ -116,29 +129,25 @@ class ToolNode(BaseNode):
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
||||
result[parameter_name] = [v.to_dict() for v in self._fetch_files(variable_pool)]
|
||||
tool_input = node_data.tool_parameters[parameter_name]
|
||||
if tool_input.type == "variable":
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
if variable is None:
|
||||
raise ValueError(f"variable {tool_input.value} not exists")
|
||||
parameter_value = variable.value
|
||||
elif tool_input.type in {"mixed", "constant"}:
|
||||
segment_group = variable_pool.convert_template(str(tool_input.value))
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
else:
|
||||
tool_input = node_data.tool_parameters[parameter_name]
|
||||
if tool_input.type == "variable":
|
||||
# TODO: check if the variable exists in the variable pool
|
||||
parameter_value = variable_pool.get(tool_input.value).value
|
||||
else:
|
||||
segment_group = parser.convert_template(
|
||||
template=str(tool_input.value),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
result[parameter_name] = parameter_value
|
||||
raise ValueError(f"unknown tool input type '{tool_input.type}'")
|
||||
result[parameter_name] = parameter_value
|
||||
|
||||
return result
|
||||
|
||||
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
|
||||
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
|
||||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar], list[dict]]:
|
||||
def _convert_tool_messages(
|
||||
self,
|
||||
messages: list[ToolInvokeMessage],
|
||||
):
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
@@ -156,50 +165,86 @@ class ToolNode(BaseNode):
|
||||
|
||||
return plain_text, files, json
|
||||
|
||||
def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[FileVar]:
|
||||
def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[File]:
|
||||
"""
|
||||
Extract tool response binary
|
||||
"""
|
||||
result = []
|
||||
|
||||
for response in tool_response:
|
||||
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
||||
url = response.message
|
||||
ext = path.splitext(url)[1]
|
||||
mimetype = response.meta.get("mime_type", "image/jpeg")
|
||||
filename = response.save_as or url.split("/")[-1]
|
||||
url = str(response.message) if response.message else None
|
||||
ext = path.splitext(url)[1] if url else ".bin"
|
||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
|
||||
# get tool file id
|
||||
tool_file_id = url.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ValueError(f"tool file {tool_file_id} not exists")
|
||||
|
||||
result.append(
|
||||
FileVar(
|
||||
File(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=transfer_method,
|
||||
url=url,
|
||||
related_id=tool_file_id,
|
||||
filename=filename,
|
||||
remote_url=url,
|
||||
related_id=tool_file.id,
|
||||
filename=tool_file.name,
|
||||
extension=ext,
|
||||
mime_type=mimetype,
|
||||
mime_type=tool_file.mimetype,
|
||||
size=tool_file.size,
|
||||
)
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
tool_file_id = response.message.split("/")[-1].split(".")[0]
|
||||
tool_file_id = str(response.message).split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ValueError(f"tool file {tool_file_id} not exists")
|
||||
result.append(
|
||||
FileVar(
|
||||
File(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id=tool_file_id,
|
||||
filename=response.save_as,
|
||||
related_id=tool_file.id,
|
||||
filename=tool_file.name,
|
||||
extension=path.splitext(response.save_as)[1],
|
||||
mime_type=response.meta.get("mime_type", "application/octet-stream"),
|
||||
mime_type=tool_file.mimetype,
|
||||
size=tool_file.size,
|
||||
)
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
pass # TODO:
|
||||
url = str(response.message)
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
tool_file_id = url.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ValueError(f"tool file {tool_file_id} not exists")
|
||||
if "." in url:
|
||||
extension = "." + url.split("/")[-1].split(".")[1]
|
||||
else:
|
||||
extension = ".bin"
|
||||
file = File(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType(response.save_as),
|
||||
transfer_method=transfer_method,
|
||||
remote_url=url,
|
||||
filename=tool_file.name,
|
||||
related_id=tool_file.id,
|
||||
extension=extension,
|
||||
mime_type=tool_file.mimetype,
|
||||
size=tool_file.size,
|
||||
)
|
||||
result.append(file)
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.FILE:
|
||||
assert response.meta is not None
|
||||
result.append(response.meta["file"])
|
||||
|
||||
return result
|
||||
|
||||
@@ -218,12 +263,16 @@ class ToolNode(BaseNode):
|
||||
]
|
||||
)
|
||||
|
||||
def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]) -> list[dict]:
|
||||
def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]):
|
||||
return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON]
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: ToolNodeData
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: ToolNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -236,7 +285,7 @@ class ToolNode(BaseNode):
|
||||
for parameter_name in node_data.tool_parameters:
|
||||
input = node_data.tool_parameters[parameter_name]
|
||||
if input.type == "mixed":
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .variable_aggregator_node import VariableAggregatorNode
|
||||
|
||||
__all__ = ["VariableAggregatorNode"]
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class AdvancedSettings(BaseModel):
|
||||
|
||||
@@ -1,24 +1,24 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class VariableAggregatorNode(BaseNode):
|
||||
class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
|
||||
_node_data_cls = VariableAssignerNodeData
|
||||
_node_type = NodeType.VARIABLE_AGGREGATOR
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
node_data = cast(VariableAssignerNodeData, self.node_data)
|
||||
# Get variables
|
||||
outputs = {}
|
||||
inputs = {}
|
||||
|
||||
if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled:
|
||||
for selector in node_data.variables:
|
||||
if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
|
||||
for selector in self.node_data.variables:
|
||||
variable = self.graph_runtime_state.variable_pool.get_any(selector)
|
||||
if variable is not None:
|
||||
outputs = {"output": variable}
|
||||
@@ -26,7 +26,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
inputs = {".".join(selector[1:]): variable}
|
||||
break
|
||||
else:
|
||||
for group in node_data.advanced_settings.groups:
|
||||
for group in self.node_data.advanced_settings.groups:
|
||||
for selector in group.variables:
|
||||
variable = self.graph_runtime_state.variable_pool.get_any(selector)
|
||||
|
||||
|
||||
@@ -1,40 +1,38 @@
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.segments import SegmentType, Variable, factory
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode, BaseNodeData
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import ConversationVariable, WorkflowNodeExecutionStatus
|
||||
from factories import variable_factory
|
||||
from models import ConversationVariable
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .exc import VariableAssignerNodeError
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
||||
|
||||
class VariableAssignerNode(BaseNode):
|
||||
class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
||||
_node_data_cls: type[BaseNodeData] = VariableAssignerData
|
||||
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
data = cast(VariableAssignerData, self.node_data)
|
||||
|
||||
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
||||
original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector)
|
||||
original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector)
|
||||
if not isinstance(original_variable, Variable):
|
||||
raise VariableAssignerNodeError("assigned variable not found")
|
||||
|
||||
match data.write_mode:
|
||||
match self.node_data.write_mode:
|
||||
case WriteMode.OVER_WRITE:
|
||||
income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
|
||||
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||
if not income_value:
|
||||
raise VariableAssignerNodeError("input value not found")
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.value})
|
||||
|
||||
case WriteMode.APPEND:
|
||||
income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
|
||||
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||
if not income_value:
|
||||
raise VariableAssignerNodeError("input value not found")
|
||||
updated_value = original_variable.value + [income_value.value]
|
||||
@@ -45,10 +43,10 @@ class VariableAssignerNode(BaseNode):
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
|
||||
|
||||
case _:
|
||||
raise VariableAssignerNodeError(f"unsupported write mode: {data.write_mode}")
|
||||
raise VariableAssignerNodeError(f"unsupported write mode: {self.node_data.write_mode}")
|
||||
|
||||
# Over write the variable.
|
||||
self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable)
|
||||
self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable)
|
||||
|
||||
# TODO: Move database operation to the pipeline.
|
||||
# Update conversation variable.
|
||||
@@ -80,12 +78,12 @@ def update_conversation_variable(conversation_id: str, variable: Variable):
|
||||
def get_zero_value(t: SegmentType):
|
||||
match t:
|
||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
|
||||
return factory.build_segment([])
|
||||
return variable_factory.build_segment([])
|
||||
case SegmentType.OBJECT:
|
||||
return factory.build_segment({})
|
||||
return variable_factory.build_segment({})
|
||||
case SegmentType.STRING:
|
||||
return factory.build_segment("")
|
||||
return variable_factory.build_segment("")
|
||||
case SegmentType.NUMBER:
|
||||
return factory.build_segment(0)
|
||||
return variable_factory.build_segment(0)
|
||||
case _:
|
||||
raise VariableAssignerNodeError(f"unsupported variable type: {t}")
|
||||
|
||||
@@ -2,7 +2,7 @@ from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class WriteMode(str, Enum):
|
||||
|
||||
Reference in New Issue
Block a user