chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -13,11 +13,12 @@ class FileExtraConfig(BaseModel):
"""
File Upload Entity.
"""
image_config: Optional[dict[str, Any]] = None
class FileType(enum.Enum):
IMAGE = 'image'
IMAGE = "image"
@staticmethod
def value_of(value):
@@ -28,9 +29,9 @@ class FileType(enum.Enum):
class FileTransferMethod(enum.Enum):
REMOTE_URL = 'remote_url'
LOCAL_FILE = 'local_file'
TOOL_FILE = 'tool_file'
REMOTE_URL = "remote_url"
LOCAL_FILE = "local_file"
TOOL_FILE = "tool_file"
@staticmethod
def value_of(value):
@@ -39,9 +40,10 @@ class FileTransferMethod(enum.Enum):
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileBelongsTo(enum.Enum):
USER = 'user'
ASSISTANT = 'assistant'
USER = "user"
ASSISTANT = "assistant"
@staticmethod
def value_of(value):
@@ -65,16 +67,16 @@ class FileVar(BaseModel):
def to_dict(self) -> dict:
return {
'__variant': self.__class__.__name__,
'tenant_id': self.tenant_id,
'type': self.type.value,
'transfer_method': self.transfer_method.value,
'url': self.preview_url,
'remote_url': self.url,
'related_id': self.related_id,
'filename': self.filename,
'extension': self.extension,
'mime_type': self.mime_type,
"__variant": self.__class__.__name__,
"tenant_id": self.tenant_id,
"type": self.type.value,
"transfer_method": self.transfer_method.value,
"url": self.preview_url,
"remote_url": self.url,
"related_id": self.related_id,
"filename": self.filename,
"extension": self.extension,
"mime_type": self.mime_type,
}
def to_markdown(self) -> str:
@@ -86,7 +88,7 @@ class FileVar(BaseModel):
if self.type == FileType.IMAGE:
text = f'![{self.filename or ""}]({preview_url})'
else:
text = f'[{self.filename or preview_url}]({preview_url})'
text = f"[{self.filename or preview_url}]({preview_url})"
return text
@@ -115,28 +117,29 @@ class FileVar(BaseModel):
return ImagePromptMessageContent(
data=self.data,
detail=ImagePromptMessageContent.DETAIL.HIGH
if image_config.get("detail") == "high" else ImagePromptMessageContent.DETAIL.LOW
if image_config.get("detail") == "high"
else ImagePromptMessageContent.DETAIL.LOW,
)
def _get_data(self, force_url: bool = False) -> Optional[str]:
from models.model import UploadFile
if self.type == FileType.IMAGE:
if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.url
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
upload_file = (db.session.query(UploadFile)
.filter(
UploadFile.id == self.related_id,
UploadFile.tenant_id == self.tenant_id
).first())
return UploadFileParser.get_image_data(
upload_file=upload_file,
force_url=force_url
upload_file = (
db.session.query(UploadFile)
.filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id)
.first()
)
return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url)
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
extension = self.extension
# add sign url
return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=self.related_id, extension=extension)
return ToolFileParser.get_tool_file_manager().sign_file(
tool_file_id=self.related_id, extension=extension
)
return None

View File

@@ -13,13 +13,13 @@ from services.file_service import IMAGE_EXTENSIONS
class MessageFileParser:
def __init__(self, tenant_id: str, app_id: str) -> None:
self.tenant_id = tenant_id
self.app_id = app_id
def validate_and_transform_files_arg(self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig,
user: Union[Account, EndUser]) -> list[FileVar]:
def validate_and_transform_files_arg(
self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser]
) -> list[FileVar]:
"""
validate and transform files arg
@@ -30,22 +30,22 @@ class MessageFileParser:
"""
for file in files:
if not isinstance(file, dict):
raise ValueError('Invalid file format, must be dict')
if not file.get('type'):
raise ValueError('Missing file type')
FileType.value_of(file.get('type'))
if not file.get('transfer_method'):
raise ValueError('Missing file transfer method')
FileTransferMethod.value_of(file.get('transfer_method'))
if file.get('transfer_method') == FileTransferMethod.REMOTE_URL.value:
if not file.get('url'):
raise ValueError('Missing file url')
if not file.get('url').startswith('http'):
raise ValueError('Invalid file url')
if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'):
raise ValueError('Missing file upload_file_id')
if file.get('transform_method') == FileTransferMethod.TOOL_FILE.value and not file.get('tool_file_id'):
raise ValueError('Missing file tool_file_id')
raise ValueError("Invalid file format, must be dict")
if not file.get("type"):
raise ValueError("Missing file type")
FileType.value_of(file.get("type"))
if not file.get("transfer_method"):
raise ValueError("Missing file transfer method")
FileTransferMethod.value_of(file.get("transfer_method"))
if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value:
if not file.get("url"):
raise ValueError("Missing file url")
if not file.get("url").startswith("http"):
raise ValueError("Invalid file url")
if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"):
raise ValueError("Missing file upload_file_id")
if file.get("transform_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"):
raise ValueError("Missing file tool_file_id")
# transform files to file objs
type_file_objs = self._to_file_objs(files, file_extra_config)
@@ -62,17 +62,17 @@ class MessageFileParser:
continue
# Validate number of files
if len(files) > image_config['number_limits']:
if len(files) > image_config["number_limits"]:
raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")
for file_obj in file_objs:
# Validate transfer method
if file_obj.transfer_method.value not in image_config['transfer_methods']:
raise ValueError(f'Invalid transfer method: {file_obj.transfer_method.value}')
if file_obj.transfer_method.value not in image_config["transfer_methods"]:
raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}")
# Validate file type
if file_obj.type != FileType.IMAGE:
raise ValueError(f'Invalid file type: {file_obj.type}')
raise ValueError(f"Invalid file type: {file_obj.type}")
if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
# check remote url valid and is image
@@ -81,18 +81,21 @@ class MessageFileParser:
raise ValueError(error)
elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
# get upload file from upload_file_id
upload_file = (db.session.query(UploadFile)
.filter(
UploadFile.id == file_obj.related_id,
UploadFile.tenant_id == self.tenant_id,
UploadFile.created_by == user.id,
UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
UploadFile.extension.in_(IMAGE_EXTENSIONS)
).first())
upload_file = (
db.session.query(UploadFile)
.filter(
UploadFile.id == file_obj.related_id,
UploadFile.tenant_id == self.tenant_id,
UploadFile.created_by == user.id,
UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
UploadFile.extension.in_(IMAGE_EXTENSIONS),
)
.first()
)
# check upload file is belong to tenant and user
if not upload_file:
raise ValueError('Invalid upload file')
raise ValueError("Invalid upload file")
new_files.append(file_obj)
@@ -113,8 +116,9 @@ class MessageFileParser:
# return all file objs
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
def _to_file_objs(self, files: list[Union[dict, MessageFile]],
file_extra_config: FileExtraConfig) -> dict[FileType, list[FileVar]]:
def _to_file_objs(
self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig
) -> dict[FileType, list[FileVar]]:
"""
transform files to file objs
@@ -152,23 +156,23 @@ class MessageFileParser:
:return:
"""
if isinstance(file, dict):
transfer_method = FileTransferMethod.value_of(file.get('transfer_method'))
transfer_method = FileTransferMethod.value_of(file.get("transfer_method"))
if transfer_method != FileTransferMethod.TOOL_FILE:
return FileVar(
tenant_id=self.tenant_id,
type=FileType.value_of(file.get('type')),
type=FileType.value_of(file.get("type")),
transfer_method=transfer_method,
url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
extra_config=file_extra_config
url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None,
related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None,
extra_config=file_extra_config,
)
return FileVar(
tenant_id=self.tenant_id,
type=FileType.value_of(file.get('type')),
type=FileType.value_of(file.get("type")),
transfer_method=transfer_method,
url=None,
related_id=file.get('tool_file_id'),
extra_config=file_extra_config
related_id=file.get("tool_file_id"),
extra_config=file_extra_config,
)
else:
return FileVar(
@@ -178,7 +182,7 @@ class MessageFileParser:
transfer_method=FileTransferMethod.value_of(file.transfer_method),
url=file.url,
related_id=file.upload_file_id or None,
extra_config=file_extra_config
extra_config=file_extra_config,
)
def _check_image_remote_url(self, url):
@@ -190,17 +194,17 @@ class MessageFileParser:
def is_s3_presigned_url(url):
try:
parsed_url = urlparse(url)
if 'amazonaws.com' not in parsed_url.netloc:
if "amazonaws.com" not in parsed_url.netloc:
return False
query_params = parse_qs(parsed_url.query)
required_params = ['Signature', 'Expires']
required_params = ["Signature", "Expires"]
for param in required_params:
if param not in query_params:
return False
if not query_params['Expires'][0].isdigit():
if not query_params["Expires"][0].isdigit():
return False
signature = query_params['Signature'][0]
if not re.match(r'^[A-Za-z0-9+/]+={0,2}$', signature):
signature = query_params["Signature"][0]
if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
return False
return True
except Exception:

View File

@@ -1,8 +1,7 @@
tool_file_manager = {
'manager': None
}
tool_file_manager = {"manager": None}
class ToolFileParser:
@staticmethod
def get_tool_file_manager() -> 'ToolFileManager':
return tool_file_manager['manager']
def get_tool_file_manager() -> "ToolFileManager":
return tool_file_manager["manager"]

View File

@@ -9,7 +9,7 @@ from typing import Optional
from configs import dify_config
from extensions.ext_storage import storage
IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
@@ -22,18 +22,18 @@ class UploadFileParser:
if upload_file.extension not in IMAGE_EXTENSIONS:
return None
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == 'url' or force_url:
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url:
return cls.get_signed_temp_image_url(upload_file.id)
else:
# get image file base64
try:
data = storage.load(upload_file.key)
except FileNotFoundError:
logging.error(f'File not found: {upload_file.key}')
logging.error(f"File not found: {upload_file.key}")
return None
encoded_string = base64.b64encode(data).decode('utf-8')
return f'data:{upload_file.mime_type};base64,{encoded_string}'
encoded_string = base64.b64encode(data).decode("utf-8")
return f"data:{upload_file.mime_type};base64,{encoded_string}"
@classmethod
def get_signed_temp_image_url(cls, upload_file_id) -> str:
@@ -44,7 +44,7 @@ class UploadFileParser:
:return:
"""
base_url = dify_config.FILES_URL
image_preview_url = f'{base_url}/files/{upload_file_id}/image-preview'
image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()