mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-09 19:06:51 +08:00
feat/enhance the multi-modal support (#8818)
This commit is contained in:
@@ -0,0 +1,19 @@
|
||||
from .constants import FILE_MODEL_IDENTITY
|
||||
from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
|
||||
from .models import (
|
||||
File,
|
||||
FileExtraConfig,
|
||||
ImageConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FileType",
|
||||
"FileExtraConfig",
|
||||
"FileTransferMethod",
|
||||
"FileBelongsTo",
|
||||
"File",
|
||||
"ImageConfig",
|
||||
"FileAttribute",
|
||||
"ArrayFileAttribute",
|
||||
"FILE_MODEL_IDENTITY",
|
||||
]
|
||||
|
||||
1
api/core/file/constants.py
Normal file
1
api/core/file/constants.py
Normal file
@@ -0,0 +1 @@
|
||||
FILE_MODEL_IDENTITY = "__dify__file__"
|
||||
55
api/core/file/enums.py
Normal file
55
api/core/file/enums.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class FileType(str, Enum):
|
||||
IMAGE = "image"
|
||||
DOCUMENT = "document"
|
||||
AUDIO = "audio"
|
||||
VIDEO = "video"
|
||||
CUSTOM = "custom"
|
||||
|
||||
@staticmethod
|
||||
def value_of(value):
|
||||
for member in FileType:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class FileTransferMethod(str, Enum):
|
||||
REMOTE_URL = "remote_url"
|
||||
LOCAL_FILE = "local_file"
|
||||
TOOL_FILE = "tool_file"
|
||||
|
||||
@staticmethod
|
||||
def value_of(value):
|
||||
for member in FileTransferMethod:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class FileBelongsTo(str, Enum):
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
|
||||
@staticmethod
|
||||
def value_of(value):
|
||||
for member in FileBelongsTo:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class FileAttribute(str, Enum):
|
||||
TYPE = "type"
|
||||
SIZE = "size"
|
||||
NAME = "name"
|
||||
MIME_TYPE = "mime_type"
|
||||
TRANSFER_METHOD = "transfer_method"
|
||||
URL = "url"
|
||||
EXTENSION = "extension"
|
||||
|
||||
|
||||
class ArrayFileAttribute(str, Enum):
|
||||
LENGTH = "length"
|
||||
156
api/core/file/file_manager.py
Normal file
156
api/core/file/file_manager.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import base64
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import file_repository
|
||||
from core.helper import ssrf_proxy
|
||||
from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
from . import helpers
|
||||
from .enums import FileAttribute
|
||||
from .models import File, FileTransferMethod, FileType
|
||||
from .tool_file_parser import ToolFileParser
|
||||
|
||||
|
||||
def get_attr(*, file: File, attr: FileAttribute):
|
||||
match attr:
|
||||
case FileAttribute.TYPE:
|
||||
return file.type.value
|
||||
case FileAttribute.SIZE:
|
||||
return file.size
|
||||
case FileAttribute.NAME:
|
||||
return file.filename
|
||||
case FileAttribute.MIME_TYPE:
|
||||
return file.mime_type
|
||||
case FileAttribute.TRANSFER_METHOD:
|
||||
return file.transfer_method.value
|
||||
case FileAttribute.URL:
|
||||
return file.remote_url
|
||||
case FileAttribute.EXTENSION:
|
||||
return file.extension
|
||||
case _:
|
||||
raise ValueError(f"Invalid file attribute: {attr}")
|
||||
|
||||
|
||||
def to_prompt_message_content(f: File, /):
|
||||
"""
|
||||
Convert a File object to an ImagePromptMessageContent object.
|
||||
|
||||
This function takes a File object and converts it to an ImagePromptMessageContent
|
||||
object, which can be used as a prompt for image-based AI models.
|
||||
|
||||
Args:
|
||||
file (File): The File object to convert. Must be of type FileType.IMAGE.
|
||||
|
||||
Returns:
|
||||
ImagePromptMessageContent: An object containing the image data and detail level.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file is not an image or if the file data is missing.
|
||||
|
||||
Note:
|
||||
The detail level of the image prompt is determined by the file's extra_config.
|
||||
If not specified, it defaults to ImagePromptMessageContent.DETAIL.LOW.
|
||||
"""
|
||||
match f.type:
|
||||
case FileType.IMAGE:
|
||||
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
|
||||
data = _to_url(f)
|
||||
else:
|
||||
data = _to_base64_data_string(f)
|
||||
|
||||
if f._extra_config and f._extra_config.image_config and f._extra_config.image_config.detail:
|
||||
detail = f._extra_config.image_config.detail
|
||||
else:
|
||||
detail = ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
return ImagePromptMessageContent(data=data, detail=detail)
|
||||
case FileType.AUDIO:
|
||||
encoded_string = _file_to_encoded_string(f)
|
||||
if f.extension is None:
|
||||
raise ValueError("Missing file extension")
|
||||
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
|
||||
case _:
|
||||
raise ValueError(f"file type {f.type} is not supported")
|
||||
|
||||
|
||||
def download(f: File, /):
|
||||
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
|
||||
return _download_file_content(upload_file.key)
|
||||
|
||||
|
||||
def _download_file_content(path: str, /):
|
||||
"""
|
||||
Download and return the contents of a file as bytes.
|
||||
|
||||
This function loads the file from storage and ensures it's in bytes format.
|
||||
|
||||
Args:
|
||||
path (str): The path to the file in storage.
|
||||
|
||||
Returns:
|
||||
bytes: The contents of the file as a bytes object.
|
||||
|
||||
Raises:
|
||||
ValueError: If the loaded file is not a bytes object.
|
||||
"""
|
||||
data = storage.load(path, stream=False)
|
||||
if not isinstance(data, bytes):
|
||||
raise ValueError(f"file {path} is not a bytes object")
|
||||
return data
|
||||
|
||||
|
||||
def _get_encoded_string(f: File, /):
|
||||
match f.transfer_method:
|
||||
case FileTransferMethod.REMOTE_URL:
|
||||
response = ssrf_proxy.get(f.remote_url)
|
||||
response.raise_for_status()
|
||||
content = response.content
|
||||
encoded_string = base64.b64encode(content).decode("utf-8")
|
||||
return encoded_string
|
||||
case FileTransferMethod.LOCAL_FILE:
|
||||
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
|
||||
data = _download_file_content(upload_file.key)
|
||||
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||
return encoded_string
|
||||
case FileTransferMethod.TOOL_FILE:
|
||||
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
|
||||
data = _download_file_content(tool_file.file_key)
|
||||
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||
return encoded_string
|
||||
case _:
|
||||
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
|
||||
|
||||
|
||||
def _to_base64_data_string(f: File, /):
|
||||
encoded_string = _get_encoded_string(f)
|
||||
return f"data:{f.mime_type};base64,{encoded_string}"
|
||||
|
||||
|
||||
def _file_to_encoded_string(f: File, /):
|
||||
match f.type:
|
||||
case FileType.IMAGE:
|
||||
return _to_base64_data_string(f)
|
||||
case FileType.AUDIO:
|
||||
return _get_encoded_string(f)
|
||||
case _:
|
||||
raise ValueError(f"file type {f.type} is not supported")
|
||||
|
||||
|
||||
def _to_url(f: File, /):
|
||||
if f.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
if f.remote_url is None:
|
||||
raise ValueError("Missing file remote_url")
|
||||
return f.remote_url
|
||||
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if f.related_id is None:
|
||||
raise ValueError("Missing file related_id")
|
||||
return helpers.get_signed_file_url(upload_file_id=f.related_id)
|
||||
elif f.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
# add sign url
|
||||
if f.related_id is None or f.extension is None:
|
||||
raise ValueError("Missing file related_id or extension")
|
||||
return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension)
|
||||
else:
|
||||
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
|
||||
@@ -1,145 +0,0 @@
|
||||
import enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.file.tool_file_parser import ToolFileParser
|
||||
from core.file.upload_file_parser import UploadFileParser
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from extensions.ext_database import db
|
||||
|
||||
|
||||
class FileExtraConfig(BaseModel):
|
||||
"""
|
||||
File Upload Entity.
|
||||
"""
|
||||
|
||||
image_config: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class FileType(enum.Enum):
|
||||
IMAGE = "image"
|
||||
|
||||
@staticmethod
|
||||
def value_of(value):
|
||||
for member in FileType:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class FileTransferMethod(enum.Enum):
|
||||
REMOTE_URL = "remote_url"
|
||||
LOCAL_FILE = "local_file"
|
||||
TOOL_FILE = "tool_file"
|
||||
|
||||
@staticmethod
|
||||
def value_of(value):
|
||||
for member in FileTransferMethod:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class FileBelongsTo(enum.Enum):
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
|
||||
@staticmethod
|
||||
def value_of(value):
|
||||
for member in FileBelongsTo:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class FileVar(BaseModel):
|
||||
id: Optional[str] = None # message file id
|
||||
tenant_id: str
|
||||
type: FileType
|
||||
transfer_method: FileTransferMethod
|
||||
url: Optional[str] = None # remote url
|
||||
related_id: Optional[str] = None
|
||||
extra_config: Optional[FileExtraConfig] = None
|
||||
filename: Optional[str] = None
|
||||
extension: Optional[str] = None
|
||||
mime_type: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"__variant": self.__class__.__name__,
|
||||
"tenant_id": self.tenant_id,
|
||||
"type": self.type.value,
|
||||
"transfer_method": self.transfer_method.value,
|
||||
"url": self.preview_url,
|
||||
"remote_url": self.url,
|
||||
"related_id": self.related_id,
|
||||
"filename": self.filename,
|
||||
"extension": self.extension,
|
||||
"mime_type": self.mime_type,
|
||||
}
|
||||
|
||||
def to_markdown(self) -> str:
|
||||
"""
|
||||
Convert file to markdown
|
||||
:return:
|
||||
"""
|
||||
preview_url = self.preview_url
|
||||
if self.type == FileType.IMAGE:
|
||||
text = f''
|
||||
else:
|
||||
text = f"[{self.filename or preview_url}]({preview_url})"
|
||||
|
||||
return text
|
||||
|
||||
@property
|
||||
def data(self) -> Optional[str]:
|
||||
"""
|
||||
Get image data, file signed url or base64 data
|
||||
depending on config MULTIMODAL_SEND_IMAGE_FORMAT
|
||||
:return:
|
||||
"""
|
||||
return self._get_data()
|
||||
|
||||
@property
|
||||
def preview_url(self) -> Optional[str]:
|
||||
"""
|
||||
Get signed preview url
|
||||
:return:
|
||||
"""
|
||||
return self._get_data(force_url=True)
|
||||
|
||||
@property
|
||||
def prompt_message_content(self) -> ImagePromptMessageContent:
|
||||
if self.type == FileType.IMAGE:
|
||||
image_config = self.extra_config.image_config
|
||||
|
||||
return ImagePromptMessageContent(
|
||||
data=self.data,
|
||||
detail=ImagePromptMessageContent.DETAIL.HIGH
|
||||
if image_config.get("detail") == "high"
|
||||
else ImagePromptMessageContent.DETAIL.LOW,
|
||||
)
|
||||
|
||||
def _get_data(self, force_url: bool = False) -> Optional[str]:
|
||||
from models.model import UploadFile
|
||||
|
||||
if self.type == FileType.IMAGE:
|
||||
if self.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
return self.url
|
||||
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
upload_file = (
|
||||
db.session.query(UploadFile)
|
||||
.filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url)
|
||||
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
extension = self.extension
|
||||
# add sign url
|
||||
return ToolFileParser.get_tool_file_manager().sign_file(
|
||||
tool_file_id=self.related_id, extension=extension
|
||||
)
|
||||
|
||||
return None
|
||||
32
api/core/file/file_repository.py
Normal file
32
api/core/file/file_repository.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models import ToolFile, UploadFile
|
||||
|
||||
from .models import File
|
||||
|
||||
|
||||
def get_upload_file(*, session: Session, file: File):
|
||||
if file.related_id is None:
|
||||
raise ValueError("Missing file related_id")
|
||||
stmt = select(UploadFile).filter(
|
||||
UploadFile.id == file.related_id,
|
||||
UploadFile.tenant_id == file.tenant_id,
|
||||
)
|
||||
record = session.scalar(stmt)
|
||||
if not record:
|
||||
raise ValueError(f"upload file {file.related_id} not found")
|
||||
return record
|
||||
|
||||
|
||||
def get_tool_file(*, session: Session, file: File):
|
||||
if file.related_id is None:
|
||||
raise ValueError("Missing file related_id")
|
||||
stmt = select(ToolFile).filter(
|
||||
ToolFile.id == file.related_id,
|
||||
ToolFile.tenant_id == file.tenant_id,
|
||||
)
|
||||
record = session.scalar(stmt)
|
||||
if not record:
|
||||
raise ValueError(f"tool file {file.related_id} not found")
|
||||
return record
|
||||
48
api/core/file/helpers.py
Normal file
48
api/core/file/helpers.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import time
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
def get_signed_file_url(upload_file_id: str) -> str:
|
||||
url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview"
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
key = dify_config.SECRET_KEY.encode()
|
||||
msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||
sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||
|
||||
|
||||
def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode()
|
||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||
|
||||
# verify signature
|
||||
if sign != recalculated_encoded_sign:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
|
||||
|
||||
def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode()
|
||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||
|
||||
# verify signature
|
||||
if sign != recalculated_encoded_sign:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
@@ -1,243 +0,0 @@
|
||||
import re
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Union
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import requests
|
||||
|
||||
from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import EndUser, MessageFile, UploadFile
|
||||
from services.file_service import IMAGE_EXTENSIONS
|
||||
|
||||
|
||||
class MessageFileParser:
|
||||
def __init__(self, tenant_id: str, app_id: str) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.app_id = app_id
|
||||
|
||||
def validate_and_transform_files_arg(
|
||||
self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser]
|
||||
) -> list[FileVar]:
|
||||
"""
|
||||
validate and transform files arg
|
||||
|
||||
:param files:
|
||||
:param file_extra_config:
|
||||
:param user:
|
||||
:return:
|
||||
"""
|
||||
for file in files:
|
||||
if not isinstance(file, dict):
|
||||
raise ValueError("Invalid file format, must be dict")
|
||||
if not file.get("type"):
|
||||
raise ValueError("Missing file type")
|
||||
FileType.value_of(file.get("type"))
|
||||
if not file.get("transfer_method"):
|
||||
raise ValueError("Missing file transfer method")
|
||||
FileTransferMethod.value_of(file.get("transfer_method"))
|
||||
if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value:
|
||||
if not file.get("url"):
|
||||
raise ValueError("Missing file url")
|
||||
if not file.get("url").startswith("http"):
|
||||
raise ValueError("Invalid file url")
|
||||
if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"):
|
||||
raise ValueError("Missing file upload_file_id")
|
||||
if file.get("transform_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"):
|
||||
raise ValueError("Missing file tool_file_id")
|
||||
|
||||
# transform files to file objs
|
||||
type_file_objs = self._to_file_objs(files, file_extra_config)
|
||||
|
||||
# validate files
|
||||
new_files = []
|
||||
for file_type, file_objs in type_file_objs.items():
|
||||
if file_type == FileType.IMAGE:
|
||||
# parse and validate files
|
||||
image_config = file_extra_config.image_config
|
||||
|
||||
# check if image file feature is enabled
|
||||
if not image_config:
|
||||
continue
|
||||
|
||||
# Validate number of files
|
||||
if len(files) > image_config["number_limits"]:
|
||||
raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")
|
||||
|
||||
for file_obj in file_objs:
|
||||
# Validate transfer method
|
||||
if file_obj.transfer_method.value not in image_config["transfer_methods"]:
|
||||
raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}")
|
||||
|
||||
# Validate file type
|
||||
if file_obj.type != FileType.IMAGE:
|
||||
raise ValueError(f"Invalid file type: {file_obj.type}")
|
||||
|
||||
if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
# check remote url valid and is image
|
||||
result, error = self._check_image_remote_url(file_obj.url)
|
||||
if result is False:
|
||||
raise ValueError(error)
|
||||
elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
# get upload file from upload_file_id
|
||||
upload_file = (
|
||||
db.session.query(UploadFile)
|
||||
.filter(
|
||||
UploadFile.id == file_obj.related_id,
|
||||
UploadFile.tenant_id == self.tenant_id,
|
||||
UploadFile.created_by == user.id,
|
||||
UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
||||
UploadFile.extension.in_(IMAGE_EXTENSIONS),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# check upload file is belong to tenant and user
|
||||
if not upload_file:
|
||||
raise ValueError("Invalid upload file")
|
||||
|
||||
new_files.append(file_obj)
|
||||
|
||||
# return all file objs
|
||||
return new_files
|
||||
|
||||
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig):
|
||||
"""
|
||||
transform message files
|
||||
|
||||
:param files:
|
||||
:param file_extra_config:
|
||||
:return:
|
||||
"""
|
||||
# transform files to file objs
|
||||
type_file_objs = self._to_file_objs(files, file_extra_config)
|
||||
|
||||
# return all file objs
|
||||
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
|
||||
|
||||
def _to_file_objs(
|
||||
self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig
|
||||
) -> dict[FileType, list[FileVar]]:
|
||||
"""
|
||||
transform files to file objs
|
||||
|
||||
:param files:
|
||||
:param file_extra_config:
|
||||
:return:
|
||||
"""
|
||||
type_file_objs: dict[FileType, list[FileVar]] = {
|
||||
# Currently only support image
|
||||
FileType.IMAGE: []
|
||||
}
|
||||
|
||||
if not files:
|
||||
return type_file_objs
|
||||
|
||||
# group by file type and convert file args or message files to FileObj
|
||||
for file in files:
|
||||
if isinstance(file, MessageFile):
|
||||
if file.belongs_to == FileBelongsTo.ASSISTANT.value:
|
||||
continue
|
||||
|
||||
file_obj = self._to_file_obj(file, file_extra_config)
|
||||
if file_obj.type not in type_file_objs:
|
||||
continue
|
||||
|
||||
type_file_objs[file_obj.type].append(file_obj)
|
||||
|
||||
return type_file_objs
|
||||
|
||||
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig):
|
||||
"""
|
||||
transform file to file obj
|
||||
|
||||
:param file:
|
||||
:return:
|
||||
"""
|
||||
if isinstance(file, dict):
|
||||
transfer_method = FileTransferMethod.value_of(file.get("transfer_method"))
|
||||
if transfer_method != FileTransferMethod.TOOL_FILE:
|
||||
return FileVar(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.value_of(file.get("type")),
|
||||
transfer_method=transfer_method,
|
||||
url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None,
|
||||
related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None,
|
||||
extra_config=file_extra_config,
|
||||
)
|
||||
return FileVar(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.value_of(file.get("type")),
|
||||
transfer_method=transfer_method,
|
||||
url=None,
|
||||
related_id=file.get("tool_file_id"),
|
||||
extra_config=file_extra_config,
|
||||
)
|
||||
else:
|
||||
return FileVar(
|
||||
id=file.id,
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.value_of(file.type),
|
||||
transfer_method=FileTransferMethod.value_of(file.transfer_method),
|
||||
url=file.url,
|
||||
related_id=file.upload_file_id or None,
|
||||
extra_config=file_extra_config,
|
||||
)
|
||||
|
||||
def _check_image_remote_url(self, url):
|
||||
try:
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
|
||||
" Chrome/91.0.4472.124 Safari/537.36"
|
||||
}
|
||||
|
||||
def is_s3_presigned_url(url):
|
||||
try:
|
||||
parsed_url = urlparse(url)
|
||||
if "amazonaws.com" not in parsed_url.netloc:
|
||||
return False
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
|
||||
def check_presign_v2(query_params):
|
||||
required_params = ["Signature", "Expires"]
|
||||
for param in required_params:
|
||||
if param not in query_params:
|
||||
return False
|
||||
if not query_params["Expires"][0].isdigit():
|
||||
return False
|
||||
signature = query_params["Signature"][0]
|
||||
if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def check_presign_v4(query_params):
|
||||
required_params = ["X-Amz-Signature", "X-Amz-Expires"]
|
||||
for param in required_params:
|
||||
if param not in query_params:
|
||||
return False
|
||||
if not query_params["X-Amz-Expires"][0].isdigit():
|
||||
return False
|
||||
signature = query_params["X-Amz-Signature"][0]
|
||||
if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
return check_presign_v4(query_params) or check_presign_v2(query_params)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
if is_s3_presigned_url(url):
|
||||
response = requests.get(url, headers=headers, allow_redirects=True)
|
||||
if response.status_code in {200, 304}:
|
||||
return True, ""
|
||||
|
||||
response = requests.head(url, headers=headers, allow_redirects=True)
|
||||
if response.status_code in {200, 304}:
|
||||
return True, ""
|
||||
else:
|
||||
return False, "URL does not exist."
|
||||
except requests.RequestException as e:
|
||||
return False, f"Error checking URL: {e}"
|
||||
140
api/core/file/models.py
Normal file
140
api/core/file/models.py
Normal file
@@ -0,0 +1,140 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
|
||||
from . import helpers
|
||||
from .constants import FILE_MODEL_IDENTITY
|
||||
from .enums import FileTransferMethod, FileType
|
||||
from .tool_file_parser import ToolFileParser
|
||||
|
||||
|
||||
class ImageConfig(BaseModel):
|
||||
"""
|
||||
NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
|
||||
"""
|
||||
|
||||
number_limits: int = 0
|
||||
transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
|
||||
detail: ImagePromptMessageContent.DETAIL | None = None
|
||||
|
||||
|
||||
class FileExtraConfig(BaseModel):
|
||||
"""
|
||||
File Upload Entity.
|
||||
"""
|
||||
|
||||
image_config: Optional[ImageConfig] = None
|
||||
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
|
||||
allowed_extensions: Sequence[str] = Field(default_factory=list)
|
||||
allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
|
||||
number_limits: int = 0
|
||||
|
||||
|
||||
class File(BaseModel):
|
||||
dify_model_identity: str = FILE_MODEL_IDENTITY
|
||||
|
||||
id: Optional[str] = None # message file id
|
||||
tenant_id: str
|
||||
type: FileType
|
||||
transfer_method: FileTransferMethod
|
||||
remote_url: Optional[str] = None # remote url
|
||||
related_id: Optional[str] = None
|
||||
filename: Optional[str] = None
|
||||
extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
|
||||
mime_type: Optional[str] = None
|
||||
size: int = -1
|
||||
_extra_config: FileExtraConfig | None = None
|
||||
|
||||
def to_dict(self) -> Mapping[str, str | int | None]:
|
||||
data = self.model_dump(mode="json")
|
||||
return {
|
||||
**data,
|
||||
"url": self.generate_url(),
|
||||
}
|
||||
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
url = self.generate_url()
|
||||
if self.type == FileType.IMAGE:
|
||||
text = f''
|
||||
else:
|
||||
text = f"[{self.filename or url}]({url})"
|
||||
|
||||
return text
|
||||
|
||||
def generate_url(self) -> Optional[str]:
|
||||
if self.type == FileType.IMAGE:
|
||||
if self.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
return self.remote_url
|
||||
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if self.related_id is None:
|
||||
raise ValueError("Missing file related_id")
|
||||
return helpers.get_signed_file_url(upload_file_id=self.related_id)
|
||||
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
assert self.related_id is not None
|
||||
assert self.extension is not None
|
||||
return ToolFileParser.get_tool_file_manager().sign_file(
|
||||
tool_file_id=self.related_id, extension=self.extension
|
||||
)
|
||||
else:
|
||||
if self.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
return self.remote_url
|
||||
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if self.related_id is None:
|
||||
raise ValueError("Missing file related_id")
|
||||
return helpers.get_signed_file_url(upload_file_id=self.related_id)
|
||||
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
assert self.related_id is not None
|
||||
assert self.extension is not None
|
||||
return ToolFileParser.get_tool_file_manager().sign_file(
|
||||
tool_file_id=self.related_id, extension=self.extension
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_after(self):
|
||||
match self.transfer_method:
|
||||
case FileTransferMethod.REMOTE_URL:
|
||||
if not self.remote_url:
|
||||
raise ValueError("Missing file url")
|
||||
if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"):
|
||||
raise ValueError("Invalid file url")
|
||||
case FileTransferMethod.LOCAL_FILE:
|
||||
if not self.related_id:
|
||||
raise ValueError("Missing file related_id")
|
||||
case FileTransferMethod.TOOL_FILE:
|
||||
if not self.related_id:
|
||||
raise ValueError("Missing file related_id")
|
||||
|
||||
# Validate the extra config.
|
||||
if not self._extra_config:
|
||||
return self
|
||||
|
||||
if self._extra_config.allowed_file_types:
|
||||
if self.type not in self._extra_config.allowed_file_types and self.type != FileType.CUSTOM:
|
||||
raise ValueError(f"Invalid file type: {self.type}")
|
||||
|
||||
if self._extra_config.allowed_extensions and self.extension not in self._extra_config.allowed_extensions:
|
||||
raise ValueError(f"Invalid file extension: {self.extension}")
|
||||
|
||||
if (
|
||||
self._extra_config.allowed_upload_methods
|
||||
and self.transfer_method not in self._extra_config.allowed_upload_methods
|
||||
):
|
||||
raise ValueError(f"Invalid transfer method: {self.transfer_method}")
|
||||
|
||||
match self.type:
|
||||
case FileType.IMAGE:
|
||||
# NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
|
||||
if not self._extra_config.image_config:
|
||||
return self
|
||||
# TODO: skip check if transfer_methods is empty, because many test cases are not setting this field
|
||||
if (
|
||||
self._extra_config.image_config.transfer_methods
|
||||
and self.transfer_method not in self._extra_config.image_config.transfer_methods
|
||||
):
|
||||
raise ValueError(f"Invalid transfer method: {self.transfer_method}")
|
||||
|
||||
return self
|
||||
@@ -1,4 +1,9 @@
|
||||
tool_file_manager = {"manager": None}
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
|
||||
tool_file_manager: dict[str, Any] = {"manager": None}
|
||||
|
||||
|
||||
class ToolFileParser:
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
|
||||
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
||||
|
||||
|
||||
class UploadFileParser:
|
||||
@classmethod
|
||||
def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
|
||||
if not upload_file:
|
||||
return None
|
||||
|
||||
if upload_file.extension not in IMAGE_EXTENSIONS:
|
||||
return None
|
||||
|
||||
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url:
|
||||
return cls.get_signed_temp_image_url(upload_file.id)
|
||||
else:
|
||||
# get image file base64
|
||||
try:
|
||||
data = storage.load(upload_file.key)
|
||||
except FileNotFoundError:
|
||||
logging.error(f"File not found: {upload_file.key}")
|
||||
return None
|
||||
|
||||
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||
return f"data:{upload_file.mime_type};base64,{encoded_string}"
|
||||
|
||||
@classmethod
|
||||
def get_signed_temp_image_url(cls, upload_file_id) -> str:
|
||||
"""
|
||||
get signed url from upload file
|
||||
|
||||
:param upload_file: UploadFile object
|
||||
:return:
|
||||
"""
|
||||
base_url = dify_config.FILES_URL
|
||||
image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode()
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||
|
||||
@classmethod
|
||||
def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
"""
|
||||
verify signature
|
||||
|
||||
:param upload_file_id: file id
|
||||
:param timestamp: timestamp
|
||||
:param nonce: nonce
|
||||
:param sign: signature
|
||||
:return:
|
||||
"""
|
||||
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode()
|
||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||
|
||||
# verify signature
|
||||
if sign != recalculated_encoded_sign:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
Reference in New Issue
Block a user