chore: the consistency of MultiModalPromptMessageContent (#11721)

Authored by 非法操作 on 2024-12-17 15:01:38 +08:00, committed by GitHub
parent 78c3051585
commit c9b4029ce7
14 changed files with 108 additions and 99 deletions
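
The change replaces the per-modality settings MULTIMODAL_SEND_IMAGE_FORMAT and MULTIMODAL_SEND_VIDEO_FORMAT with a single MULTIMODAL_SEND_FORMAT, and rebuilds ImagePromptMessageContent calls around explicit url, format, and mime_type fields instead of a single pre-rendered data string; the test diffs below track both changes.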

View File

@@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch

 import pytest

+from configs import dify_config
 from core.app.app_config.entities import ModelConfigEntity
 from core.file import File, FileTransferMethod, FileType, FileUploadConfig, ImageConfig
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -126,6 +127,7 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):

 def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
     model_config_mock, _, messages, inputs, context = get_chat_model_args
+    dify_config.MULTIMODAL_SEND_FORMAT = "url"

     files = [
         File(
@@ -140,7 +142,9 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
     prompt_transform = AdvancedPromptTransform()
     prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
     with patch("core.file.file_manager.to_prompt_message_content") as mock_get_encoded_string:
-        mock_get_encoded_string.return_value = ImagePromptMessageContent(data=str(files[0].remote_url))
+        mock_get_encoded_string.return_value = ImagePromptMessageContent(
+            url=str(files[0].remote_url), format="jpg", mime_type="image/jpg"
+        )
         prompt_messages = prompt_transform._get_chat_model_prompt_messages(
             prompt_template=messages,
             inputs=inputs,
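
For orientation, a minimal sketch of the content shape these new constructor calls imply. This is not the dify source: the url, format, mime_type, and detail names come from the hunks above, while base64_data and the data property are assumptions about how a "base64" send format and existing data= consumers would be served.

# Minimal sketch only; url/format/mime_type/detail come from the diff above,
# while base64_data and the data property are assumptions.
from pydantic import BaseModel


class ImagePromptMessageContentSketch(BaseModel):
    url: str = ""
    base64_data: str = ""  # assumed counterpart for a "base64" send format
    format: str = ""       # e.g. "jpg", as passed in the tests above
    mime_type: str = ""    # e.g. "image/jpg", as passed in the tests above
    detail: str = "high"   # vision detail, a plain string here for simplicity

    @property
    def data(self) -> str:
        # Assumed accessor: prefer the remote URL, else render a data URI.
        return self.url or f"data:{self.mime_type};base64,{self.base64_data}"


content = ImagePromptMessageContentSketch(
    url="https://example.com/test1.jpg", format="jpg", mime_type="image/jpg"
)
assert content.data == "https://example.com/test1.jpg"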

View File

@@ -18,8 +18,7 @@ from core.model_runtime.entities.message_entities import (
     TextPromptMessageContent,
     UserPromptMessage,
 )
-from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
-from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType
 from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
 from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
@@ -249,8 +248,7 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):

 def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
     # Setup dify config
-    dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
-    dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url"
+    dify_config.MULTIMODAL_SEND_FORMAT = "url"

     # Generate fake values for prompt template
     fake_assistant_prompt = faker.sentence()
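
The hunk above is the configuration half of the change: one flag now selects the transfer format for every modality. A hedged sketch of the branch that flag selects; only "url" appears in this diff, "base64" as the other accepted value is an assumption, and to_content is a hypothetical helper, not dify's API.

# Sketch of a unified send-format switch; to_content is hypothetical.
from typing import Literal

SendFormat = Literal["url", "base64"]  # "base64" is assumed, not shown in the diff


def to_content(send_format: SendFormat, remote_url: str, base64_data: str) -> dict:
    # One branch now serves images, videos, and any other modality alike.
    if send_format == "url":
        return {"url": remote_url, "format": "jpg", "mime_type": "image/jpg"}
    return {"base64_data": base64_data, "format": "jpg", "mime_type": "image/jpg"}


assert to_content("url", "https://example.com/test1.jpg", "")["url"].endswith("test1.jpg")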
@@ -326,9 +324,10 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
                 tenant_id="test",
                 type=FileType.IMAGE,
                 filename="test1.jpg",
-                extension=".jpg",
                 transfer_method=FileTransferMethod.REMOTE_URL,
                 remote_url=fake_remote_url,
+                extension=".jpg",
+                mime_type="image/jpg",
             )
         ],
         vision_enabled=True,
@@ -362,7 +361,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
             UserPromptMessage(
                 content=[
                     TextPromptMessageContent(data=fake_query),
-                    ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
+                    ImagePromptMessageContent(
+                        url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail
+                    ),
                 ]
             ),
         ],
@@ -385,7 +386,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
         expected_messages=[
             UserPromptMessage(
                 content=[
-                    ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
+                    ImagePromptMessageContent(
+                        url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail
+                    ),
                 ]
             ),
         ]
@@ -396,9 +399,10 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
                 tenant_id="test",
                 type=FileType.IMAGE,
                 filename="test1.jpg",
-                extension=".jpg",
                 transfer_method=FileTransferMethod.REMOTE_URL,
                 remote_url=fake_remote_url,
+                extension=".jpg",
+                mime_type="image/jpg",
             )
         },
     ),
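
Taken together, the migration pattern across these tests: every call site that previously passed a single data= string to ImagePromptMessageContent now passes url= together with format= and mime_type=, and File fixtures gain an explicit mime_type alongside extension, so multimodal content carries its media metadata rather than a pre-rendered string.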