mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-08 10:26:50 +08:00
feat/enhance the multi-modal support (#8818)
This commit is contained in:
@@ -6,6 +6,8 @@ import httpx
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false") == "true"
|
||||
|
||||
|
||||
@@ -24,10 +26,16 @@ class MockedHttp:
|
||||
# get data, files
|
||||
data = kwargs.get("data")
|
||||
files = kwargs.get("files")
|
||||
json = kwargs.get("json")
|
||||
content = kwargs.get("content")
|
||||
if data is not None:
|
||||
resp = dumps(data).encode("utf-8")
|
||||
elif files is not None:
|
||||
resp = dumps(files).encode("utf-8")
|
||||
elif json is not None:
|
||||
resp = dumps(json).encode("utf-8")
|
||||
elif content is not None:
|
||||
resp = content
|
||||
else:
|
||||
resp = b"OK"
|
||||
|
||||
@@ -43,6 +51,6 @@ def setup_http_mock(request, monkeypatch: MonkeyPatch):
|
||||
yield
|
||||
return
|
||||
|
||||
monkeypatch.setattr(httpx, "request", MockedHttp.httpx_request)
|
||||
monkeypatch.setattr(ssrf_proxy, "make_request", MockedHttp.httpx_request)
|
||||
yield
|
||||
monkeypatch.undo()
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import cast
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import NodeRunResult, UserFrom
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
@@ -14,6 +14,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.code.entities import CodeNodeData
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
|
||||
|
||||
|
||||
@@ -5,13 +5,13 @@ from urllib.parse import urlencode
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
|
||||
from core.workflow.nodes.http_request.node import HttpRequestNode
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock
|
||||
|
||||
@@ -211,7 +211,16 @@ def test_json(setup_http_mock):
|
||||
},
|
||||
"headers": "X-Header:123",
|
||||
"params": "A:b",
|
||||
"body": {"type": "json", "data": '{"a": "{{#a.b123.args1#}}"}'},
|
||||
"body": {
|
||||
"type": "json",
|
||||
"data": [
|
||||
{
|
||||
"key": "",
|
||||
"type": "text",
|
||||
"value": '{"a": "{{#a.b123.args1#}}"}',
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
@@ -243,7 +252,21 @@ def test_x_www_form_urlencoded(setup_http_mock):
|
||||
},
|
||||
"headers": "X-Header:123",
|
||||
"params": "A:b",
|
||||
"body": {"type": "x-www-form-urlencoded", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"},
|
||||
"body": {
|
||||
"type": "x-www-form-urlencoded",
|
||||
"data": [
|
||||
{
|
||||
"key": "a",
|
||||
"type": "text",
|
||||
"value": "{{#a.b123.args1#}}",
|
||||
},
|
||||
{
|
||||
"key": "b",
|
||||
"type": "text",
|
||||
"value": "{{#a.b123.args2#}}",
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
@@ -275,7 +298,21 @@ def test_form_data(setup_http_mock):
|
||||
},
|
||||
"headers": "X-Header:123",
|
||||
"params": "A:b",
|
||||
"body": {"type": "form-data", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"},
|
||||
"body": {
|
||||
"type": "form-data",
|
||||
"data": [
|
||||
{
|
||||
"key": "a",
|
||||
"type": "text",
|
||||
"value": "{{#a.b123.args1#}}",
|
||||
},
|
||||
{
|
||||
"key": "b",
|
||||
"type": "text",
|
||||
"value": "{{#a.b123.args2#}}",
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
@@ -310,7 +347,7 @@ def test_none_data(setup_http_mock):
|
||||
},
|
||||
"headers": "X-Header:123",
|
||||
"params": "A:b",
|
||||
"body": {"type": "none", "data": "123123123"},
|
||||
"body": {"type": "none", "data": []},
|
||||
},
|
||||
}
|
||||
)
|
||||
@@ -366,7 +403,21 @@ def test_multi_colons_parse(setup_http_mock):
|
||||
},
|
||||
"params": "Referer:http://example1.com\nRedirect:http://example2.com",
|
||||
"headers": "Referer:http://example3.com\nRedirect:http://example4.com",
|
||||
"body": {"type": "form-data", "data": "Referer:http://example5.com\nRedirect:http://example6.com"},
|
||||
"body": {
|
||||
"type": "form-data",
|
||||
"data": [
|
||||
{
|
||||
"key": "Referer",
|
||||
"type": "text",
|
||||
"value": "http://example5.com",
|
||||
},
|
||||
{
|
||||
"key": "Redirect",
|
||||
"type": "text",
|
||||
"value": "http://example6.com",
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
@@ -377,5 +428,5 @@ def test_multi_colons_parse(setup_http_mock):
|
||||
resp = result.outputs
|
||||
|
||||
assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request", "")
|
||||
assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get("request", "")
|
||||
assert "http://example3.com" == resp.get("headers", {}).get("referer")
|
||||
assert 'form-data; name="Redirect"\r\n\r\nhttp://example6.com' in result.process_data.get("request", "")
|
||||
# assert "http://example3.com" == resp.get("headers", {}).get("referer")
|
||||
|
||||
@@ -13,15 +13,15 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers import ModelProviderFactory
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.event import RunCompletedEvent
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.provider import ProviderType
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
@@ -20,6 +19,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.provider import ProviderType
|
||||
|
||||
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
|
||||
|
||||
@@ -4,13 +4,13 @@ import uuid
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
|
||||
|
||||
|
||||
@@ -2,13 +2,14 @@ import time
|
||||
import uuid
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import NodeRunResult, UserFrom
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
||||
|
||||
57
api/tests/integration_tests/workflow/test_sync_workflow.py
Normal file
57
api/tests/integration_tests/workflow/test_sync_workflow.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
This test file is used to verify the compatibility of Workflow before and after supporting multiple file types.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
from models import Workflow
|
||||
|
||||
OLD_VERSION_WORKFLOW_FEATURES = {
|
||||
"file_upload": {
|
||||
"image": {
|
||||
"enabled": True,
|
||||
"number_limits": 6,
|
||||
"transfer_methods": ["remote_url", "local_file"],
|
||||
}
|
||||
},
|
||||
"opening_statement": "",
|
||||
"retriever_resource": {"enabled": True},
|
||||
"sensitive_word_avoidance": {"enabled": False},
|
||||
"speech_to_text": {"enabled": False},
|
||||
"suggested_questions": [],
|
||||
"suggested_questions_after_answer": {"enabled": False},
|
||||
"text_to_speech": {"enabled": False, "language": "", "voice": ""},
|
||||
}
|
||||
|
||||
NEW_VERSION_WORKFLOW_FEATURES = {
|
||||
"file_upload": {
|
||||
"enabled": True,
|
||||
"allowed_file_types": ["image"],
|
||||
"allowed_extensions": [],
|
||||
"allowed_upload_methods": ["remote_url", "local_file"],
|
||||
"number_limits": 6,
|
||||
},
|
||||
"opening_statement": "",
|
||||
"retriever_resource": {"enabled": True},
|
||||
"sensitive_word_avoidance": {"enabled": False},
|
||||
"speech_to_text": {"enabled": False},
|
||||
"suggested_questions": [],
|
||||
"suggested_questions_after_answer": {"enabled": False},
|
||||
"text_to_speech": {"enabled": False, "language": "", "voice": ""},
|
||||
}
|
||||
|
||||
|
||||
def test_workflow_features():
|
||||
workflow = Workflow(
|
||||
tenant_id="",
|
||||
app_id="",
|
||||
type="",
|
||||
version="",
|
||||
graph="",
|
||||
features=json.dumps(OLD_VERSION_WORKFLOW_FEATURES),
|
||||
created_by="",
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
assert workflow.features_dict == NEW_VERSION_WORKFLOW_FEATURES
|
||||
@@ -2,7 +2,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.segments import (
|
||||
from core.variables import (
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
ArrayStringVariable,
|
||||
@@ -11,43 +11,43 @@ from core.app.segments import (
|
||||
ObjectSegment,
|
||||
SecretVariable,
|
||||
StringVariable,
|
||||
factory,
|
||||
)
|
||||
from core.app.segments.exc import VariableError
|
||||
from core.variables.exc import VariableError
|
||||
from factories import variable_factory
|
||||
|
||||
|
||||
def test_string_variable():
|
||||
test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"}
|
||||
result = factory.build_variable_from_mapping(test_data)
|
||||
result = variable_factory.build_variable_from_mapping(test_data)
|
||||
assert isinstance(result, StringVariable)
|
||||
|
||||
|
||||
def test_integer_variable():
|
||||
test_data = {"value_type": "number", "name": "test_int", "value": 42}
|
||||
result = factory.build_variable_from_mapping(test_data)
|
||||
result = variable_factory.build_variable_from_mapping(test_data)
|
||||
assert isinstance(result, IntegerVariable)
|
||||
|
||||
|
||||
def test_float_variable():
|
||||
test_data = {"value_type": "number", "name": "test_float", "value": 3.14}
|
||||
result = factory.build_variable_from_mapping(test_data)
|
||||
result = variable_factory.build_variable_from_mapping(test_data)
|
||||
assert isinstance(result, FloatVariable)
|
||||
|
||||
|
||||
def test_secret_variable():
|
||||
test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"}
|
||||
result = factory.build_variable_from_mapping(test_data)
|
||||
result = variable_factory.build_variable_from_mapping(test_data)
|
||||
assert isinstance(result, SecretVariable)
|
||||
|
||||
|
||||
def test_invalid_value_type():
|
||||
test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"}
|
||||
with pytest.raises(VariableError):
|
||||
factory.build_variable_from_mapping(test_data)
|
||||
variable_factory.build_variable_from_mapping(test_data)
|
||||
|
||||
|
||||
def test_build_a_blank_string():
|
||||
result = factory.build_variable_from_mapping(
|
||||
result = variable_factory.build_variable_from_mapping(
|
||||
{
|
||||
"value_type": "string",
|
||||
"name": "blank",
|
||||
@@ -59,7 +59,7 @@ def test_build_a_blank_string():
|
||||
|
||||
|
||||
def test_build_a_object_variable_with_none_value():
|
||||
var = factory.build_segment(
|
||||
var = variable_factory.build_segment(
|
||||
{
|
||||
"key1": None,
|
||||
}
|
||||
@@ -79,7 +79,7 @@ def test_object_variable():
|
||||
"key2": 2,
|
||||
},
|
||||
}
|
||||
variable = factory.build_variable_from_mapping(mapping)
|
||||
variable = variable_factory.build_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ObjectSegment)
|
||||
assert isinstance(variable.value["key1"], str)
|
||||
assert isinstance(variable.value["key2"], int)
|
||||
@@ -96,7 +96,7 @@ def test_array_string_variable():
|
||||
"text",
|
||||
],
|
||||
}
|
||||
variable = factory.build_variable_from_mapping(mapping)
|
||||
variable = variable_factory.build_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ArrayStringVariable)
|
||||
assert isinstance(variable.value[0], str)
|
||||
assert isinstance(variable.value[1], str)
|
||||
@@ -113,7 +113,7 @@ def test_array_number_variable():
|
||||
2.0,
|
||||
],
|
||||
}
|
||||
variable = factory.build_variable_from_mapping(mapping)
|
||||
variable = variable_factory.build_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ArrayNumberVariable)
|
||||
assert isinstance(variable.value[0], int)
|
||||
assert isinstance(variable.value[1], float)
|
||||
@@ -136,7 +136,7 @@ def test_array_object_variable():
|
||||
},
|
||||
],
|
||||
}
|
||||
variable = factory.build_variable_from_mapping(mapping)
|
||||
variable = variable_factory.build_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ArrayObjectVariable)
|
||||
assert isinstance(variable.value[0], dict)
|
||||
assert isinstance(variable.value[1], dict)
|
||||
@@ -146,13 +146,13 @@ def test_array_object_variable():
|
||||
assert isinstance(variable.value[1]["key2"], int)
|
||||
|
||||
|
||||
def test_variable_cannot_large_than_5_kb():
|
||||
def test_variable_cannot_large_than_200_kb():
|
||||
with pytest.raises(VariableError):
|
||||
factory.build_variable_from_mapping(
|
||||
variable_factory.build_variable_from_mapping(
|
||||
{
|
||||
"id": str(uuid4()),
|
||||
"value_type": "string",
|
||||
"name": "test_text",
|
||||
"value": "a" * 1024 * 6,
|
||||
"value": "a" * 1024 * 201,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from core.app.segments import SecretVariable, StringSegment, parser
|
||||
from core.helper import encrypter
|
||||
from core.variables import SecretVariable, StringSegment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
|
||||
@@ -13,12 +13,13 @@ def test_segment_group_to_text():
|
||||
environment_variables=[
|
||||
SecretVariable(name="secret_key", value="fake-secret-key"),
|
||||
],
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(("node_id", "custom_query"), "fake-user-query")
|
||||
template = (
|
||||
"Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
|
||||
)
|
||||
segments_group = parser.convert_template(template=template, variable_pool=variable_pool)
|
||||
segments_group = variable_pool.convert_template(template)
|
||||
|
||||
assert segments_group.text == "Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key."
|
||||
assert segments_group.log == (
|
||||
@@ -32,9 +33,10 @@ def test_convert_constant_to_segment_group():
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
template = "Hello, world!"
|
||||
segments_group = parser.convert_template(template=template, variable_pool=variable_pool)
|
||||
segments_group = variable_pool.convert_template(template)
|
||||
assert segments_group.text == "Hello, world!"
|
||||
assert segments_group.log == "Hello, world!"
|
||||
|
||||
@@ -46,9 +48,10 @@ def test_convert_variable_to_segment_group():
|
||||
},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
template = "{{#sys.user_id#}}"
|
||||
segments_group = parser.convert_template(template=template, variable_pool=variable_pool)
|
||||
segments_group = variable_pool.convert_template(template)
|
||||
assert segments_group.text == "fake-user-id"
|
||||
assert segments_group.log == "fake-user-id"
|
||||
assert segments_group.value == [StringSegment(value="fake-user-id")]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.app.segments import (
|
||||
from core.variables import (
|
||||
FloatVariable,
|
||||
IntegerVariable,
|
||||
ObjectVariable,
|
||||
|
||||
@@ -6,7 +6,7 @@ import pytest
|
||||
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request
|
||||
|
||||
|
||||
@patch("httpx.request")
|
||||
@patch("httpx.Client.request")
|
||||
def test_successful_request(mock_request):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
@@ -16,7 +16,7 @@ def test_successful_request(mock_request):
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@patch("httpx.request")
|
||||
@patch("httpx.Client.request")
|
||||
def test_retry_exceed_max_retries(mock_request):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
@@ -29,7 +29,7 @@ def test_retry_exceed_max_retries(mock_request):
|
||||
assert str(e.value) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com"
|
||||
|
||||
|
||||
@patch("httpx.request")
|
||||
@patch("httpx.Client.request")
|
||||
def test_retry_logic_success(mock_request):
|
||||
side_effects = []
|
||||
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar
|
||||
from core.file import File, FileExtraConfig, FileTransferMethod, FileType, ImageConfig
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, UserPromptMessage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessageRole,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
@@ -123,32 +128,30 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
||||
model_config_mock, _, messages, inputs, context = get_chat_model_args
|
||||
|
||||
files = [
|
||||
FileVar(
|
||||
File(
|
||||
id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
url="https://example.com/image1.jpg",
|
||||
extra_config=FileExtraConfig(
|
||||
image_config={
|
||||
"detail": "high",
|
||||
}
|
||||
),
|
||||
remote_url="https://example.com/image1.jpg",
|
||||
_extra_config=FileExtraConfig(image_config=ImageConfig(detail=ImagePromptMessageContent.DETAIL.HIGH)),
|
||||
)
|
||||
]
|
||||
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
|
||||
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
|
||||
prompt_template=messages,
|
||||
inputs=inputs,
|
||||
query=None,
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
with patch("core.file.file_manager.to_prompt_message_content") as mock_get_encoded_string:
|
||||
mock_get_encoded_string.return_value = ImagePromptMessageContent(data=str(files[0].remote_url))
|
||||
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
|
||||
prompt_template=messages,
|
||||
inputs=inputs,
|
||||
query=None,
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
assert len(prompt_messages) == 4
|
||||
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
|
||||
@@ -157,7 +160,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
||||
)
|
||||
assert isinstance(prompt_messages[3].content, list)
|
||||
assert len(prompt_messages[3].content) == 2
|
||||
assert prompt_messages[3].content[1].data == files[0].url
|
||||
assert prompt_messages[3].content[1].data == files[0].remote_url
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
40
api/tests/unit_tests/core/test_file.py
Normal file
40
api/tests/unit_tests/core/test_file.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||
|
||||
|
||||
def test_file_loads_and_dumps():
|
||||
file = File(
|
||||
id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image1.jpg",
|
||||
)
|
||||
|
||||
file_dict = file.model_dump()
|
||||
assert file_dict["dify_model_identity"] == FILE_MODEL_IDENTITY
|
||||
assert file_dict["type"] == file.type.value
|
||||
assert isinstance(file_dict["type"], str)
|
||||
assert file_dict["transfer_method"] == file.transfer_method.value
|
||||
assert isinstance(file_dict["transfer_method"], str)
|
||||
assert "_extra_config" not in file_dict
|
||||
|
||||
file_obj = File.model_validate(file_dict)
|
||||
assert file_obj.id == file.id
|
||||
assert file_obj.tenant_id == file.tenant_id
|
||||
assert file_obj.type == file.type
|
||||
assert file_obj.transfer_method == file.transfer_method
|
||||
assert file_obj.remote_url == file.remote_url
|
||||
|
||||
|
||||
def test_file_to_dict():
|
||||
file = File(
|
||||
id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image1.jpg",
|
||||
)
|
||||
|
||||
file_dict = file.to_dict()
|
||||
assert "_extra_config" not in file_dict
|
||||
assert "url" in file_dict
|
||||
@@ -1,56 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from core.tools.entities.tool_entities import ToolParameter
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
|
||||
|
||||
def test_get_parameter_type():
|
||||
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == "string"
|
||||
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == "string"
|
||||
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == "boolean"
|
||||
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == "number"
|
||||
with pytest.raises(ValueError):
|
||||
ToolParameterConverter.get_parameter_type("unsupported_type")
|
||||
|
||||
|
||||
def test_cast_parameter_by_type():
|
||||
# string
|
||||
assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.STRING) == "test"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == "1"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == "1.0"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == ""
|
||||
|
||||
# secret input
|
||||
assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SECRET_INPUT) == "test"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == "1"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == "1.0"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == ""
|
||||
|
||||
# select
|
||||
assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SELECT) == "test"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == "1"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == "1.0"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == ""
|
||||
|
||||
# boolean
|
||||
true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"]
|
||||
for value in true_values:
|
||||
assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is True
|
||||
|
||||
false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""]
|
||||
for value in false_values:
|
||||
assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is False
|
||||
|
||||
# number
|
||||
assert ToolParameterConverter.cast_parameter_by_type("1", ToolParameter.ToolParameterType.NUMBER) == 1
|
||||
assert ToolParameterConverter.cast_parameter_by_type("1.0", ToolParameter.ToolParameterType.NUMBER) == 1.0
|
||||
assert ToolParameterConverter.cast_parameter_by_type("-1.0", ToolParameter.ToolParameterType.NUMBER) == -1.0
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.NUMBER) == 1
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.NUMBER) == 1.0
|
||||
assert ToolParameterConverter.cast_parameter_by_type(-1.0, ToolParameter.ToolParameterType.NUMBER) == -1.0
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None
|
||||
|
||||
# unknown
|
||||
assert ToolParameterConverter.cast_parameter_by_type("1", "unknown_type") == "1"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, "unknown_type") == "1"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None
|
||||
49
api/tests/unit_tests/core/tools/test_tool_parameter_type.py
Normal file
49
api/tests/unit_tests/core/tools/test_tool_parameter_type.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from core.tools.entities.tool_entities import ToolParameter
|
||||
|
||||
|
||||
def test_get_parameter_type():
|
||||
assert ToolParameter.ToolParameterType.STRING.as_normal_type() == "string"
|
||||
assert ToolParameter.ToolParameterType.SELECT.as_normal_type() == "string"
|
||||
assert ToolParameter.ToolParameterType.SECRET_INPUT.as_normal_type() == "string"
|
||||
assert ToolParameter.ToolParameterType.BOOLEAN.as_normal_type() == "boolean"
|
||||
assert ToolParameter.ToolParameterType.NUMBER.as_normal_type() == "number"
|
||||
assert ToolParameter.ToolParameterType.FILE.as_normal_type() == "file"
|
||||
assert ToolParameter.ToolParameterType.FILES.as_normal_type() == "files"
|
||||
|
||||
|
||||
def test_cast_parameter_by_type():
|
||||
# string
|
||||
assert ToolParameter.ToolParameterType.STRING.cast_value("test") == "test"
|
||||
assert ToolParameter.ToolParameterType.STRING.cast_value(1) == "1"
|
||||
assert ToolParameter.ToolParameterType.STRING.cast_value(1.0) == "1.0"
|
||||
assert ToolParameter.ToolParameterType.STRING.cast_value(None) == ""
|
||||
|
||||
# secret input
|
||||
assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value("test") == "test"
|
||||
assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(1) == "1"
|
||||
assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(1.0) == "1.0"
|
||||
assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(None) == ""
|
||||
|
||||
# select
|
||||
assert ToolParameter.ToolParameterType.SELECT.cast_value("test") == "test"
|
||||
assert ToolParameter.ToolParameterType.SELECT.cast_value(1) == "1"
|
||||
assert ToolParameter.ToolParameterType.SELECT.cast_value(1.0) == "1.0"
|
||||
assert ToolParameter.ToolParameterType.SELECT.cast_value(None) == ""
|
||||
|
||||
# boolean
|
||||
true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"]
|
||||
for value in true_values:
|
||||
assert ToolParameter.ToolParameterType.BOOLEAN.cast_value(value) is True
|
||||
|
||||
false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""]
|
||||
for value in false_values:
|
||||
assert ToolParameter.ToolParameterType.BOOLEAN.cast_value(value) is False
|
||||
|
||||
# number
|
||||
assert ToolParameter.ToolParameterType.NUMBER.cast_value("1") == 1
|
||||
assert ToolParameter.ToolParameterType.NUMBER.cast_value("1.0") == 1.0
|
||||
assert ToolParameter.ToolParameterType.NUMBER.cast_value("-1.0") == -1.0
|
||||
assert ToolParameter.ToolParameterType.NUMBER.cast_value(1) == 1
|
||||
assert ToolParameter.ToolParameterType.NUMBER.cast_value(1.0) == 1.0
|
||||
assert ToolParameter.ToolParameterType.NUMBER.cast_value(-1.0) == -1.0
|
||||
assert ToolParameter.ToolParameterType.NUMBER.cast_value(None) is None
|
||||
@@ -1,7 +1,7 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, UserFrom
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
@@ -18,7 +18,8 @@ from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
||||
@@ -86,7 +87,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
|
||||
{"role": "system", "text": "say hi"},
|
||||
{"role": "user", "text": "{{#start.query#}}"},
|
||||
],
|
||||
"vision": {"configs": {"detail": "high"}, "enabled": False},
|
||||
"vision": {"configs": {"detail": "high", "variable_selector": []}, "enabled": False},
|
||||
},
|
||||
"id": "llm1",
|
||||
},
|
||||
@@ -105,7 +106,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
|
||||
{"role": "system", "text": "say bye"},
|
||||
{"role": "user", "text": "{{#start.query#}}"},
|
||||
],
|
||||
"vision": {"configs": {"detail": "high"}, "enabled": False},
|
||||
"vision": {"configs": {"detail": "high", "variable_selector": []}, "enabled": False},
|
||||
},
|
||||
"id": "llm2",
|
||||
},
|
||||
@@ -124,7 +125,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
|
||||
{"role": "system", "text": "say good morning"},
|
||||
{"role": "user", "text": "{{#start.query#}}"},
|
||||
],
|
||||
"vision": {"configs": {"detail": "high"}, "enabled": False},
|
||||
"vision": {"configs": {"detail": "high", "variable_selector": []}, "enabled": False},
|
||||
},
|
||||
"id": "llm3",
|
||||
},
|
||||
|
||||
@@ -3,7 +3,6 @@ import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
@@ -11,6 +10,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import uuid
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
@@ -14,6 +13,7 @@ from core.workflow.graph_engine.entities.event import (
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
|
||||
|
||||
node_execution_id = str(uuid.uuid4())
|
||||
node_config = graph.node_id_config_mapping[next_node_id]
|
||||
node_type = NodeType.value_of(node_config.get("data", {}).get("type"))
|
||||
node_type = NodeType(node_config.get("data", {}).get("type"))
|
||||
mock_node_data = StartNodeData(**{"title": "demo", "variables": []})
|
||||
|
||||
yield NodeRunStartedEvent(
|
||||
|
||||
@@ -3,7 +3,7 @@ import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import NodeRunResult, UserFrom
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
@@ -12,6 +12,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntime
|
||||
from core.workflow.nodes.event import RunCompletedEvent
|
||||
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
@@ -11,6 +10,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,158 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.variables.variables import StringVariable
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData
|
||||
from core.workflow.nodes.document_extractor.node import (
|
||||
_extract_text_from_doc,
|
||||
_extract_text_from_pdf,
|
||||
_extract_text_from_plain_text,
|
||||
)
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def document_extractor_node():
|
||||
node_data = DocumentExtractorNodeData(
|
||||
title="Test Document Extractor",
|
||||
variable_selector=["node_id", "variable_name"],
|
||||
)
|
||||
return DocumentExtractorNode(
|
||||
id="test_node_id",
|
||||
config={"id": "test_node_id", "data": node_data.model_dump()},
|
||||
graph_init_params=Mock(),
|
||||
graph=Mock(),
|
||||
graph_runtime_state=Mock(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_runtime_state():
|
||||
return Mock()
|
||||
|
||||
|
||||
def test_run_variable_not_found(document_extractor_node, mock_graph_runtime_state):
|
||||
document_extractor_node.graph_runtime_state = mock_graph_runtime_state
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = None
|
||||
|
||||
result = document_extractor_node._run()
|
||||
|
||||
assert isinstance(result, NodeRunResult)
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert result.error is not None
|
||||
assert "File variable not found" in result.error
|
||||
|
||||
|
||||
def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_state):
|
||||
document_extractor_node.graph_runtime_state = mock_graph_runtime_state
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = StringVariable(
|
||||
value="Not an ArrayFileSegment", name="test"
|
||||
)
|
||||
|
||||
result = document_extractor_node._run()
|
||||
|
||||
assert isinstance(result, NodeRunResult)
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert result.error is not None
|
||||
assert "is not an ArrayFileSegment" in result.error
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mime_type", "file_content", "expected_text", "transfer_method"),
|
||||
[
|
||||
("text/plain", b"Hello, world!", ["Hello, world!"], FileTransferMethod.LOCAL_FILE),
|
||||
("application/pdf", b"%PDF-1.5\n%Test PDF content", ["Mocked PDF content"], FileTransferMethod.LOCAL_FILE),
|
||||
(
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
b"PK\x03\x04",
|
||||
["Mocked DOCX content"],
|
||||
FileTransferMethod.LOCAL_FILE,
|
||||
),
|
||||
("text/plain", b"Remote content", ["Remote content"], FileTransferMethod.REMOTE_URL),
|
||||
],
|
||||
)
|
||||
def test_run_extract_text(
|
||||
document_extractor_node,
|
||||
mock_graph_runtime_state,
|
||||
mime_type,
|
||||
file_content,
|
||||
expected_text,
|
||||
transfer_method,
|
||||
monkeypatch,
|
||||
):
|
||||
document_extractor_node.graph_runtime_state = mock_graph_runtime_state
|
||||
|
||||
mock_file = Mock(spec=File)
|
||||
mock_file.mime_type = mime_type
|
||||
mock_file.transfer_method = transfer_method
|
||||
mock_file.related_id = "test_file_id" if transfer_method == FileTransferMethod.LOCAL_FILE else None
|
||||
mock_file.remote_url = "https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None
|
||||
|
||||
mock_array_file_segment = Mock(spec=ArrayFileSegment)
|
||||
mock_array_file_segment.value = [mock_file]
|
||||
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_array_file_segment
|
||||
|
||||
mock_download = Mock(return_value=file_content)
|
||||
mock_ssrf_proxy_get = Mock()
|
||||
mock_ssrf_proxy_get.return_value.content = file_content
|
||||
mock_ssrf_proxy_get.return_value.raise_for_status = Mock()
|
||||
|
||||
monkeypatch.setattr("core.file.file_manager.download", mock_download)
|
||||
monkeypatch.setattr("core.helper.ssrf_proxy.get", mock_ssrf_proxy_get)
|
||||
|
||||
if mime_type == "application/pdf":
|
||||
mock_pdf_extract = Mock(return_value=expected_text[0])
|
||||
monkeypatch.setattr("core.workflow.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract)
|
||||
elif mime_type.startswith("application/vnd.openxmlformats"):
|
||||
mock_docx_extract = Mock(return_value=expected_text[0])
|
||||
monkeypatch.setattr("core.workflow.nodes.document_extractor.node._extract_text_from_doc", mock_docx_extract)
|
||||
|
||||
result = document_extractor_node._run()
|
||||
|
||||
assert isinstance(result, NodeRunResult)
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["text"] == expected_text
|
||||
|
||||
if transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
mock_ssrf_proxy_get.assert_called_once_with("https://example.com/file.txt")
|
||||
elif transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
mock_download.assert_called_once_with(mock_file)
|
||||
|
||||
|
||||
def test_extract_text_from_plain_text():
|
||||
text = _extract_text_from_plain_text(b"Hello, world!")
|
||||
assert text == "Hello, world!"
|
||||
|
||||
|
||||
@patch("pypdfium2.PdfDocument")
|
||||
def test_extract_text_from_pdf(mock_pdf_document):
|
||||
mock_page = Mock()
|
||||
mock_text_page = Mock()
|
||||
mock_text_page.get_text_range.return_value = "PDF content"
|
||||
mock_page.get_textpage.return_value = mock_text_page
|
||||
mock_pdf_document.return_value = [mock_page]
|
||||
text = _extract_text_from_pdf(b"%PDF-1.5\n%Test PDF content")
|
||||
assert text == "PDF content"
|
||||
|
||||
|
||||
@patch("docx.Document")
|
||||
def test_extract_text_from_doc(mock_document):
|
||||
mock_paragraph1 = Mock()
|
||||
mock_paragraph1.text = "Paragraph 1"
|
||||
mock_paragraph2 = Mock()
|
||||
mock_paragraph2.text = "Paragraph 2"
|
||||
mock_document.return_value.paragraphs = [mock_paragraph1, mock_paragraph2]
|
||||
|
||||
text = _extract_text_from_doc(b"PK\x03\x04")
|
||||
assert text == "Paragraph 1\nParagraph 2"
|
||||
|
||||
|
||||
def test_node_type(document_extractor_node):
|
||||
assert document_extractor_node._node_type == NodeType.DOCUMENT_EXTRACTOR
|
||||
@@ -0,0 +1,202 @@
|
||||
import httpx
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import FileVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
||||
from core.workflow.nodes.end import EndStreamParam
|
||||
from core.workflow.nodes.http_request import (
|
||||
BodyData,
|
||||
HttpRequestNode,
|
||||
HttpRequestNodeAuthorization,
|
||||
HttpRequestNodeBody,
|
||||
HttpRequestNodeData,
|
||||
)
|
||||
from core.workflow.nodes.http_request.executor import _plain_text_to_dict
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
||||
def test_plain_text_to_dict():
|
||||
assert _plain_text_to_dict("aa\n cc:") == {"aa": "", "cc": ""}
|
||||
assert _plain_text_to_dict("aa:bb\n cc:dd") == {"aa": "bb", "cc": "dd"}
|
||||
assert _plain_text_to_dict("aa:bb\n cc:dd\n") == {"aa": "bb", "cc": "dd"}
|
||||
assert _plain_text_to_dict("aa:bb\n\n cc : dd\n\n") == {"aa": "bb", "cc": "dd"}
|
||||
|
||||
|
||||
def test_http_request_node_binary_file(monkeypatch):
|
||||
data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="post",
|
||||
url="http://example.org/post",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="binary",
|
||||
data=[
|
||||
BodyData(
|
||||
key="file",
|
||||
type="file",
|
||||
value="",
|
||||
file=["1111", "file"],
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(
|
||||
["1111", "file"],
|
||||
FileVariable(
|
||||
name="file",
|
||||
value=File(
|
||||
tenant_id="1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1111",
|
||||
),
|
||||
),
|
||||
)
|
||||
node = HttpRequestNode(
|
||||
id="1",
|
||||
config={
|
||||
"id": "1",
|
||||
"data": data.model_dump(),
|
||||
},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config={},
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
),
|
||||
graph=Graph(
|
||||
root_node_id="1",
|
||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||
answer_dependencies={},
|
||||
answer_generate_route={},
|
||||
),
|
||||
end_stream_param=EndStreamParam(
|
||||
end_dependencies={},
|
||||
end_stream_variable_selector_mapping={},
|
||||
),
|
||||
),
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.workflow.nodes.http_request.executor.file_manager.download",
|
||||
lambda *args, **kwargs: b"test",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.helper.ssrf_proxy.post",
|
||||
lambda *args, **kwargs: httpx.Response(200, content=kwargs["content"]),
|
||||
)
|
||||
result = node._run()
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["body"] == "test"
|
||||
|
||||
|
||||
def test_http_request_node_form_with_file(monkeypatch):
|
||||
data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="post",
|
||||
url="http://example.org/post",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="form-data",
|
||||
data=[
|
||||
BodyData(
|
||||
key="file",
|
||||
type="file",
|
||||
file=["1111", "file"],
|
||||
),
|
||||
BodyData(
|
||||
key="name",
|
||||
type="text",
|
||||
value="test",
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(
|
||||
["1111", "file"],
|
||||
FileVariable(
|
||||
name="file",
|
||||
value=File(
|
||||
tenant_id="1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1111",
|
||||
),
|
||||
),
|
||||
)
|
||||
node = HttpRequestNode(
|
||||
id="1",
|
||||
config={
|
||||
"id": "1",
|
||||
"data": data.model_dump(),
|
||||
},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config={},
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
),
|
||||
graph=Graph(
|
||||
root_node_id="1",
|
||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||
answer_dependencies={},
|
||||
answer_generate_route={},
|
||||
),
|
||||
end_stream_param=EndStreamParam(
|
||||
end_dependencies={},
|
||||
end_stream_variable_selector_mapping={},
|
||||
),
|
||||
),
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.workflow.nodes.http_request.executor.file_manager.download",
|
||||
lambda *args, **kwargs: b"test",
|
||||
)
|
||||
|
||||
def attr_checker(*args, **kwargs):
|
||||
assert kwargs["data"] == {"name": "test"}
|
||||
assert kwargs["files"] == {"file": b"test"}
|
||||
return httpx.Response(200, content=b"")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.helper.ssrf_proxy.post",
|
||||
attr_checker,
|
||||
)
|
||||
result = node._run()
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["body"] == ""
|
||||
@@ -1,16 +1,20 @@
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.if_else.entities import IfElseNodeData
|
||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||
from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
||||
@@ -111,6 +115,7 @@ def test_execute_if_else_result_true():
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["result"] is True
|
||||
|
||||
|
||||
@@ -191,4 +196,63 @@ def test_execute_if_else_result_false():
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["result"] is False
|
||||
|
||||
|
||||
def test_array_file_contains_file_name():
|
||||
node_data = IfElseNodeData(
|
||||
title="123",
|
||||
logical_operator="and",
|
||||
cases=[
|
||||
IfElseNodeData.Case(
|
||||
case_id="true",
|
||||
logical_operator="and",
|
||||
conditions=[
|
||||
Condition(
|
||||
comparison_operator="contains",
|
||||
variable_selector=["start", "array_contains"],
|
||||
sub_variable_condition=SubVariableCondition(
|
||||
logical_operator="and",
|
||||
conditions=[
|
||||
SubCondition(
|
||||
key="name",
|
||||
comparison_operator="contains",
|
||||
value="ab",
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
node = IfElseNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=Mock(),
|
||||
graph=Mock(),
|
||||
graph_runtime_state=Mock(),
|
||||
config={
|
||||
"id": "if-else",
|
||||
"data": node_data.model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
|
||||
value=[
|
||||
File(
|
||||
tenant_id="1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1",
|
||||
filename="ab",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["result"] is True
|
||||
|
||||
111
api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
Normal file
111
api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.file import File
|
||||
from core.file.models import FileTransferMethod, FileType
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.workflow.nodes.list_operator.entities import FilterBy, FilterCondition, Limit, ListOperatorNodeData, OrderBy
|
||||
from core.workflow.nodes.list_operator.node import ListOperatorNode
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def list_operator_node():
|
||||
config = {
|
||||
"variable": ["test_variable"],
|
||||
"filter_by": FilterBy(
|
||||
enabled=True,
|
||||
conditions=[
|
||||
FilterCondition(key="type", comparison_operator="in", value=[FileType.IMAGE, FileType.DOCUMENT])
|
||||
],
|
||||
),
|
||||
"order_by": OrderBy(enabled=False, value="asc"),
|
||||
"limit": Limit(enabled=False, size=0),
|
||||
"title": "Test Title",
|
||||
}
|
||||
node_data = ListOperatorNodeData(**config)
|
||||
node = ListOperatorNode(
|
||||
id="test_node_id",
|
||||
config={
|
||||
"id": "test_node_id",
|
||||
"data": node_data.model_dump(),
|
||||
},
|
||||
graph_init_params=MagicMock(),
|
||||
graph=MagicMock(),
|
||||
graph_runtime_state=MagicMock(),
|
||||
)
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.variable_pool = MagicMock()
|
||||
return node
|
||||
|
||||
|
||||
def test_filter_files_by_type(list_operator_node):
|
||||
# Setup test data
|
||||
files = [
|
||||
File(
|
||||
filename="image1.jpg",
|
||||
type=FileType.IMAGE,
|
||||
tenant_id="tenant1",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related1",
|
||||
),
|
||||
File(
|
||||
filename="document1.pdf",
|
||||
type=FileType.DOCUMENT,
|
||||
tenant_id="tenant1",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related2",
|
||||
),
|
||||
File(
|
||||
filename="image2.png",
|
||||
type=FileType.IMAGE,
|
||||
tenant_id="tenant1",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related3",
|
||||
),
|
||||
File(
|
||||
filename="audio1.mp3",
|
||||
type=FileType.AUDIO,
|
||||
tenant_id="tenant1",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related4",
|
||||
),
|
||||
]
|
||||
variable = ArrayFileSegment(value=files)
|
||||
list_operator_node.graph_runtime_state.variable_pool.get.return_value = variable
|
||||
|
||||
# Run the node
|
||||
result = list_operator_node._run()
|
||||
|
||||
# Verify the result
|
||||
expected_files = [
|
||||
{
|
||||
"filename": "image1.jpg",
|
||||
"type": FileType.IMAGE,
|
||||
"tenant_id": "tenant1",
|
||||
"transfer_method": FileTransferMethod.LOCAL_FILE,
|
||||
"related_id": "related1",
|
||||
},
|
||||
{
|
||||
"filename": "document1.pdf",
|
||||
"type": FileType.DOCUMENT,
|
||||
"tenant_id": "tenant1",
|
||||
"transfer_method": FileTransferMethod.LOCAL_FILE,
|
||||
"related_id": "related2",
|
||||
},
|
||||
{
|
||||
"filename": "image2.png",
|
||||
"type": FileType.IMAGE,
|
||||
"tenant_id": "tenant1",
|
||||
"transfer_method": FileTransferMethod.LOCAL_FILE,
|
||||
"related_id": "related3",
|
||||
},
|
||||
]
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
for expected_file, result_file in zip(expected_files, result.outputs["result"]):
|
||||
assert expected_file["filename"] == result_file.filename
|
||||
assert expected_file["type"] == result_file.type
|
||||
assert expected_file["tenant_id"] == result_file.tenant_id
|
||||
assert expected_file["transfer_method"] == result_file.transfer_method
|
||||
assert expected_file["related_id"] == result_file.related_id
|
||||
@@ -0,0 +1,67 @@
|
||||
from core.model_runtime.entities import ImagePromptMessageContent
|
||||
from core.workflow.nodes.question_classifier import QuestionClassifierNodeData
|
||||
|
||||
|
||||
def test_init_question_classifier_node_data():
|
||||
data = {
|
||||
"title": "test classifier node",
|
||||
"query_variable_selector": ["id", "name"],
|
||||
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}},
|
||||
"classes": [{"id": "1", "name": "class 1"}],
|
||||
"instruction": "This is a test instruction",
|
||||
"memory": {
|
||||
"role_prefix": {"user": "Human:", "assistant": "AI:"},
|
||||
"window": {"enabled": True, "size": 5},
|
||||
"query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:",
|
||||
},
|
||||
"vision": {"enabled": True, "configs": {"variable_selector": ["image"], "detail": "low"}},
|
||||
}
|
||||
|
||||
node_data = QuestionClassifierNodeData(**data)
|
||||
|
||||
assert node_data.query_variable_selector == ["id", "name"]
|
||||
assert node_data.model.provider == "openai"
|
||||
assert node_data.classes[0].id == "1"
|
||||
assert node_data.instruction == "This is a test instruction"
|
||||
assert node_data.memory is not None
|
||||
assert node_data.memory.role_prefix is not None
|
||||
assert node_data.memory.role_prefix.user == "Human:"
|
||||
assert node_data.memory.role_prefix.assistant == "AI:"
|
||||
assert node_data.memory.window.enabled == True
|
||||
assert node_data.memory.window.size == 5
|
||||
assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:"
|
||||
assert node_data.vision.enabled == True
|
||||
assert node_data.vision.configs.variable_selector == ["image"]
|
||||
assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
|
||||
def test_init_question_classifier_node_data_without_vision_config():
|
||||
data = {
|
||||
"title": "test classifier node",
|
||||
"query_variable_selector": ["id", "name"],
|
||||
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}},
|
||||
"classes": [{"id": "1", "name": "class 1"}],
|
||||
"instruction": "This is a test instruction",
|
||||
"memory": {
|
||||
"role_prefix": {"user": "Human:", "assistant": "AI:"},
|
||||
"window": {"enabled": True, "size": 5},
|
||||
"query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:",
|
||||
},
|
||||
}
|
||||
|
||||
node_data = QuestionClassifierNodeData(**data)
|
||||
|
||||
assert node_data.query_variable_selector == ["id", "name"]
|
||||
assert node_data.model.provider == "openai"
|
||||
assert node_data.classes[0].id == "1"
|
||||
assert node_data.instruction == "This is a test instruction"
|
||||
assert node_data.memory is not None
|
||||
assert node_data.memory.role_prefix is not None
|
||||
assert node_data.memory.role_prefix.user == "Human:"
|
||||
assert node_data.memory.role_prefix.assistant == "AI:"
|
||||
assert node_data.memory.window.enabled == True
|
||||
assert node_data.memory.window.size == 5
|
||||
assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:"
|
||||
assert node_data.vision.enabled == False
|
||||
assert node_data.vision.configs.variable_selector == ["sys", "files"]
|
||||
assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH
|
||||
@@ -4,14 +4,14 @@ from unittest import mock
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.segments import ArrayStringVariable, StringVariable
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.variables import ArrayStringVariable, StringVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
DEFAULT_NODE_ID = "node_id"
|
||||
|
||||
45
api/tests/unit_tests/core/workflow/test_variable_pool.py
Normal file
45
api/tests/unit_tests/core/workflow/test_variable_pool.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import pytest
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import FileSegment, StringSegment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pool():
|
||||
return VariablePool(system_variables={}, user_inputs={})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def file():
|
||||
return File(
|
||||
tenant_id="test_tenant_id",
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="test_related_id",
|
||||
remote_url="test_url",
|
||||
filename="test_file.txt",
|
||||
)
|
||||
|
||||
|
||||
def test_get_file_attribute(pool, file):
|
||||
# Add a FileSegment to the pool
|
||||
pool.add(("node_1", "file_var"), FileSegment(value=file))
|
||||
|
||||
# Test getting the 'name' attribute of the file
|
||||
result = pool.get(("node_1", "file_var", "name"))
|
||||
|
||||
assert result is not None
|
||||
assert result.value == file.filename
|
||||
|
||||
# Test getting a non-existent attribute
|
||||
with pytest.raises(ValueError):
|
||||
pool.get(("node_1", "file_var", "non_existent_attr"))
|
||||
|
||||
|
||||
def test_use_long_selector(pool):
|
||||
pool.add(("node_1", "part_1", "part_2"), StringSegment(value="test_value"))
|
||||
|
||||
result = pool.get(("node_1", "part_1", "part_2"))
|
||||
assert result is not None
|
||||
assert result.value == "test_value"
|
||||
@@ -0,0 +1,28 @@
|
||||
from core.variables import SecretVariable
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.utils import variable_template_parser
|
||||
|
||||
|
||||
def test_extract_selectors_from_template():
|
||||
variable_pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariableKey("user_id"): "fake-user-id",
|
||||
},
|
||||
user_inputs={},
|
||||
environment_variables=[
|
||||
SecretVariable(name="secret_key", value="fake-secret-key"),
|
||||
],
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(("node_id", "custom_query"), "fake-user-query")
|
||||
template = (
|
||||
"Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
|
||||
)
|
||||
selectors = variable_template_parser.extract_selectors_from_template(template)
|
||||
assert selectors == [
|
||||
VariableSelector(variable="#sys.user_id#", value_selector=["sys", "user_id"]),
|
||||
VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]),
|
||||
VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]),
|
||||
]
|
||||
@@ -1,11 +1,12 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.segments import SegmentType, factory
|
||||
from core.variables import SegmentType
|
||||
from factories import variable_factory
|
||||
from models import ConversationVariable
|
||||
|
||||
|
||||
def test_from_variable_and_to_variable():
|
||||
variable = factory.build_variable_from_mapping(
|
||||
variable = variable_factory.build_variable_from_mapping(
|
||||
{
|
||||
"id": str(uuid4()),
|
||||
"name": "name",
|
||||
|
||||
@@ -3,7 +3,7 @@ from uuid import uuid4
|
||||
|
||||
import contexts
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.app.segments import FloatVariable, IntegerVariable, SecretVariable, StringVariable
|
||||
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user