feat/enhance the multi-modal support (#8818)

This commit is contained in:
-LAN-
2024-10-21 10:43:49 +08:00
committed by GitHub
parent 7a1d6fe509
commit e61752bd3a
267 changed files with 6263 additions and 3523 deletions

View File

@@ -1,29 +1,55 @@
from enum import Enum
from .account import Account, AccountIntegrate, InvitationCode, Tenant
from .dataset import Dataset, DatasetProcessRule, Document, DocumentSegment
from .model import (
ApiToken,
App,
AppMode,
Conversation,
EndUser,
FileUploadConfig,
InstalledApp,
Message,
MessageAnnotation,
MessageFile,
RecommendedApp,
Site,
UploadFile,
)
from .source import DataSourceOauthBinding
from .tools import ToolFile
from .workflow import (
ConversationVariable,
Workflow,
WorkflowAppLog,
WorkflowRun,
)
from .model import App, AppMode, Message
from .types import StringUUID
from .workflow import ConversationVariable, Workflow, WorkflowNodeExecutionStatus
__all__ = ["ConversationVariable", "StringUUID", "AppMode", "WorkflowNodeExecutionStatus", "Workflow", "App", "Message"]
class CreatedByRole(Enum):
"""
Enum class for createdByRole
"""
ACCOUNT = "account"
END_USER = "end_user"
@classmethod
def value_of(cls, value: str) -> "CreatedByRole":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for role in cls:
if role.value == value:
return role
raise ValueError(f"invalid createdByRole value {value}")
__all__ = [
"ConversationVariable",
"Document",
"Dataset",
"DatasetProcessRule",
"DocumentSegment",
"DataSourceOauthBinding",
"AppMode",
"Workflow",
"App",
"Message",
"EndUser",
"MessageFile",
"UploadFile",
"Account",
"WorkflowAppLog",
"WorkflowRun",
"Site",
"InstalledApp",
"RecommendedApp",
"ApiToken",
"AccountIntegrate",
"InvitationCode",
"Tenant",
"Conversation",
"MessageAnnotation",
"FileUploadConfig",
"ToolFile",
]

View File

@@ -560,7 +560,7 @@ class DocumentSegment(db.Model):
)
def get_sign_content(self):
pattern = r"/files/([a-f0-9\-]+)/image-preview"
pattern = r"/files/([a-f0-9\-]+)/file-preview"
text = self.content
matches = re.finditer(pattern, text)
signed_urls = []
@@ -568,7 +568,7 @@ class DocumentSegment(db.Model):
upload_file_id = match.group(1)
nonce = os.urandom(16).hex()
timestamp = str(int(time.time()))
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()

16
api/models/enums.py Normal file
View File

@@ -0,0 +1,16 @@
from enum import Enum
class CreatedByRole(str, Enum):
ACCOUNT = "account"
END_USER = "end_user"
class UserFrom(str, Enum):
ACCOUNT = "account"
END_USER = "end-user"
class WorkflowRunTriggeredFrom(str, Enum):
DEBUGGING = "debugging"
APP_RUN = "app-run"

View File

@@ -1,24 +1,37 @@
import json
import re
import uuid
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import Enum
from typing import Optional
from typing import Any, Literal, Optional
from flask import request
from flask_login import UserMixin
from pydantic import BaseModel, Field
from sqlalchemy import Float, func, text
from sqlalchemy.orm import Mapped, mapped_column
from configs import dify_config
from core.file import FILE_MODEL_IDENTITY, File, FileExtraConfig, FileTransferMethod, FileType
from core.file import helpers as file_helpers
from core.file.tool_file_parser import ToolFileParser
from core.file.upload_file_parser import UploadFileParser
from extensions.ext_database import db
from libs.helper import generate_string
from models.enums import CreatedByRole
from .account import Account, Tenant
from .types import StringUUID
class FileUploadConfig(BaseModel):
enabled: bool = Field(default=False)
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
allowed_extensions: Sequence[str] = Field(default_factory=list)
allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
number_limits: int = Field(default=0, gt=0, le=10)
class DifySetup(db.Model):
__tablename__ = "dify_setups"
__table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
@@ -27,7 +40,7 @@ class DifySetup(db.Model):
setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class AppMode(Enum):
class AppMode(str, Enum):
COMPLETION = "completion"
WORKFLOW = "workflow"
CHAT = "chat"
@@ -59,7 +72,7 @@ class App(db.Model):
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id"))
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
name = db.Column(db.String(255), nullable=False)
description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying"))
mode = db.Column(db.String(255), nullable=False)
@@ -530,7 +543,7 @@ class Conversation(db.Model):
mode = db.Column(db.String(255), nullable=False)
name = db.Column(db.String(255), nullable=False)
summary = db.Column(db.Text)
inputs = db.Column(db.JSON)
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
introduction = db.Column(db.Text)
system_instruction = db.Column(db.Text)
system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
@@ -552,6 +565,28 @@ class Conversation(db.Model):
is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
@property
def inputs(self):
inputs = self._inputs.copy()
for key, value in inputs.items():
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
inputs[key] = File.model_validate(value)
elif isinstance(value, list) and all(
isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value
):
inputs[key] = [File.model_validate(item) for item in value]
return inputs
@inputs.setter
def inputs(self, value: Mapping[str, Any]):
inputs = dict(value)
for k, v in inputs.items():
if isinstance(v, File):
inputs[k] = v.model_dump()
elif isinstance(v, list) and all(isinstance(item, File) for item in v):
inputs[k] = [item.model_dump() for item in v]
self._inputs = inputs
@property
def model_config(self):
model_config = {}
@@ -700,13 +735,13 @@ class Message(db.Model):
model_id = db.Column(db.String(255), nullable=True)
override_model_configs = db.Column(db.Text)
conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=False)
inputs = db.Column(db.JSON)
query = db.Column(db.Text, nullable=False)
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
query: Mapped[str] = db.Column(db.Text, nullable=False)
message = db.Column(db.JSON, nullable=False)
message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
answer = db.Column(db.Text, nullable=False)
answer: Mapped[str] = db.Column(db.Text, nullable=False)
answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
@@ -717,15 +752,37 @@ class Message(db.Model):
status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
error = db.Column(db.Text)
message_metadata = db.Column(db.Text)
invoke_from = db.Column(db.String(255), nullable=True)
invoke_from: Mapped[Optional[str]] = db.Column(db.String(255), nullable=True)
from_source = db.Column(db.String(255), nullable=False)
from_end_user_id = db.Column(StringUUID)
from_account_id = db.Column(StringUUID)
from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID)
from_account_id: Mapped[Optional[str]] = db.Column(StringUUID)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
workflow_run_id = db.Column(StringUUID)
@property
def inputs(self):
inputs = self._inputs.copy()
for key, value in inputs.items():
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
inputs[key] = File.model_validate(value)
elif isinstance(value, list) and all(
isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value
):
inputs[key] = [File.model_validate(item) for item in value]
return inputs
@inputs.setter
def inputs(self, value: Mapping[str, Any]):
inputs = dict(value)
for k, v in inputs.items():
if isinstance(v, File):
inputs[k] = v.model_dump()
elif isinstance(v, list) and all(isinstance(item, File) for item in v):
inputs[k] = [item.model_dump() for item in v]
self._inputs = inputs
@property
def re_sign_file_url_answer(self) -> str:
if not self.answer:
@@ -772,19 +829,29 @@ class Message(db.Model):
sign_url = ToolFileParser.get_tool_file_manager().sign_file(
tool_file_id=tool_file_id, extension=extension
)
else:
elif "file-preview" in url:
# get upload file id
upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview?\?timestamp="
upload_file_id_pattern = r"\/files\/([\w-]+)\/file-preview?\?timestamp="
result = re.search(upload_file_id_pattern, url)
if not result:
continue
upload_file_id = result.group(1)
if not upload_file_id:
continue
sign_url = UploadFileParser.get_signed_temp_image_url(upload_file_id)
sign_url = file_helpers.get_signed_file_url(upload_file_id)
elif "image-preview" in url:
# image-preview is deprecated, use file-preview instead
upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview?\?timestamp="
result = re.search(upload_file_id_pattern, url)
if not result:
continue
upload_file_id = result.group(1)
if not upload_file_id:
continue
sign_url = file_helpers.get_signed_file_url(upload_file_id)
else:
continue
re_sign_file_url_answer = re_sign_file_url_answer.replace(url, sign_url)
@@ -870,50 +937,71 @@ class Message(db.Model):
@property
def message_files(self):
return db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all()
from factories import file_factory
@property
def files(self):
message_files = self.message_files
message_files = db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all()
current_app = db.session.query(App).filter(App.id == self.app_id).first()
if not current_app:
raise ValueError(f"App {self.app_id} not found")
files = []
files: list[File] = []
for message_file in message_files:
url = message_file.url
if message_file.type == "image":
if message_file.transfer_method == "local_file":
upload_file = (
db.session.query(UploadFile).filter(UploadFile.id == message_file.upload_file_id).first()
)
url = UploadFileParser.get_image_data(upload_file=upload_file, force_url=True)
if message_file.transfer_method == "tool_file":
# get tool file id
tool_file_id = message_file.url.split("/")[-1]
# trim extension
tool_file_id = tool_file_id.split(".")[0]
# get extension
if "." in message_file.url:
extension = f'.{message_file.url.split(".")[-1]}'
if len(extension) > 10:
extension = ".bin"
else:
extension = ".bin"
# add sign url
url = ToolFileParser.get_tool_file_manager().sign_file(
tool_file_id=tool_file_id, extension=extension
)
files.append(
{
if message_file.transfer_method == "local_file":
if message_file.upload_file_id is None:
raise ValueError(f"MessageFile {message_file.id} is a local file but has no upload_file_id")
file = file_factory.build_from_mapping(
mapping={
"id": message_file.id,
"upload_file_id": message_file.upload_file_id,
"transfer_method": message_file.transfer_method,
"type": message_file.type,
},
tenant_id=current_app.tenant_id,
user_id=self.from_account_id or self.from_end_user_id or "",
role=CreatedByRole(message_file.created_by_role),
config=FileExtraConfig(),
)
elif message_file.transfer_method == "remote_url":
if message_file.url is None:
raise ValueError(f"MessageFile {message_file.id} is a remote url but has no url")
file = file_factory.build_from_mapping(
mapping={
"id": message_file.id,
"type": message_file.type,
"transfer_method": message_file.transfer_method,
"url": message_file.url,
},
tenant_id=current_app.tenant_id,
user_id=self.from_account_id or self.from_end_user_id or "",
role=CreatedByRole(message_file.created_by_role),
config=FileExtraConfig(),
)
elif message_file.transfer_method == "tool_file":
mapping = {
"id": message_file.id,
"type": message_file.type,
"url": url,
"belongs_to": message_file.belongs_to or "user",
"transfer_method": message_file.transfer_method,
"tool_file_id": message_file.upload_file_id,
}
)
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=current_app.tenant_id,
user_id=self.from_account_id or self.from_end_user_id or "",
role=CreatedByRole(message_file.created_by_role),
config=FileExtraConfig(),
)
else:
raise ValueError(
f"MessageFile {message_file.id} has an invalid transfer_method {message_file.transfer_method}"
)
files.append(file)
return files
result = [
{"belongs_to": message_file.belongs_to, **file.to_dict()}
for (file, message_file) in zip(files, message_files)
]
return result
@property
def workflow_run(self):
@@ -1003,16 +1091,39 @@ class MessageFile(db.Model):
db.Index("message_file_created_by_idx", "created_by"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
message_id = db.Column(StringUUID, nullable=False)
type = db.Column(db.String(255), nullable=False)
transfer_method = db.Column(db.String(255), nullable=False)
url = db.Column(db.Text, nullable=True)
belongs_to = db.Column(db.String(255), nullable=True)
upload_file_id = db.Column(StringUUID, nullable=True)
created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
def __init__(
self,
*,
message_id: str,
type: FileType,
transfer_method: FileTransferMethod,
url: str | None = None,
belongs_to: Literal["user", "assistant"] | None = None,
upload_file_id: str | None = None,
created_by_role: CreatedByRole,
created_by: str,
):
self.message_id = message_id
self.type = type
self.transfer_method = transfer_method
self.url = url
self.belongs_to = belongs_to
self.upload_file_id = upload_file_id
self.created_by_role = created_by_role
self.created_by = created_by
id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
message_id: Mapped[str] = db.Column(StringUUID, nullable=False)
type: Mapped[str] = db.Column(db.String(255), nullable=False)
transfer_method: Mapped[str] = db.Column(db.String(255), nullable=False)
url: Mapped[Optional[str]] = db.Column(db.Text, nullable=True)
belongs_to: Mapped[Optional[str]] = db.Column(db.String(255), nullable=True)
upload_file_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True)
created_by_role: Mapped[str] = db.Column(db.String(255), nullable=False)
created_by: Mapped[str] = db.Column(StringUUID, nullable=False)
created_at: Mapped[datetime] = db.Column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)
class MessageAnnotation(db.Model):
@@ -1250,21 +1361,58 @@ class UploadFile(db.Model):
db.Index("upload_file_tenant_idx", "tenant_id"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
storage_type = db.Column(db.String(255), nullable=False)
key = db.Column(db.String(255), nullable=False)
name = db.Column(db.String(255), nullable=False)
size = db.Column(db.Integer, nullable=False)
extension = db.Column(db.String(255), nullable=False)
mime_type = db.Column(db.String(255), nullable=True)
created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying"))
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
used = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
used_by = db.Column(StringUUID, nullable=True)
used_at = db.Column(db.DateTime, nullable=True)
hash = db.Column(db.String(255), nullable=True)
id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
storage_type: Mapped[str] = db.Column(db.String(255), nullable=False)
key: Mapped[str] = db.Column(db.String(255), nullable=False)
name: Mapped[str] = db.Column(db.String(255), nullable=False)
size: Mapped[int] = db.Column(db.Integer, nullable=False)
extension: Mapped[str] = db.Column(db.String(255), nullable=False)
mime_type: Mapped[str] = db.Column(db.String(255), nullable=True)
created_by_role: Mapped[str] = db.Column(
db.String(255), nullable=False, server_default=db.text("'account'::character varying")
)
created_by: Mapped[str] = db.Column(StringUUID, nullable=False)
created_at: Mapped[datetime] = db.Column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)
used: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
used_by: Mapped[str | None] = db.Column(StringUUID, nullable=True)
used_at: Mapped[datetime | None] = db.Column(db.DateTime, nullable=True)
hash: Mapped[str | None] = db.Column(db.String(255), nullable=True)
def __init__(
self,
*,
tenant_id: str,
storage_type: str,
key: str,
name: str,
size: int,
extension: str,
mime_type: str,
created_by_role: str,
created_by: str,
created_at: datetime,
used: bool,
used_by: str | None = None,
used_at: datetime | None = None,
hash: str | None = None,
) -> None:
self.tenant_id = tenant_id
self.storage_type = storage_type
self.key = key
self.name = name
self.size = size
self.extension = extension
self.mime_type = mime_type
self.created_by_role = created_by_role
self.created_by = created_by
self.created_at = created_at
self.used = used
self.used_by = used_by
self.used_at = used_at
self.hash = hash
class ApiRequest(db.Model):

View File

@@ -1,6 +1,8 @@
import json
from typing import Optional
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
@@ -101,7 +103,7 @@ class ApiToolProvider(db.Model):
icon = db.Column(db.String(255), nullable=False)
# original schema
schema = db.Column(db.Text, nullable=False)
schema_type_str = db.Column(db.String(40), nullable=False)
schema_type_str: Mapped[str] = db.Column(db.String(40), nullable=False)
# who created this tool
user_id = db.Column(StringUUID, nullable=False)
# tenant id
@@ -133,11 +135,11 @@ class ApiToolProvider(db.Model):
return json.loads(self.credentials_str)
@property
def user(self) -> Account:
def user(self) -> Account | None:
return db.session.query(Account).filter(Account.id == self.user_id).first()
@property
def tenant(self) -> Tenant:
def tenant(self) -> Tenant | None:
return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
@@ -203,11 +205,11 @@ class WorkflowToolProvider(db.Model):
return ApiProviderSchemaType.value_of(self.schema_type_str)
@property
def user(self) -> Account:
def user(self) -> Account | None:
return db.session.query(Account).filter(Account.id == self.user_id).first()
@property
def tenant(self) -> Tenant:
def tenant(self) -> Tenant | None:
return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
@property
@@ -215,7 +217,7 @@ class WorkflowToolProvider(db.Model):
return [WorkflowToolParameterConfiguration(**config) for config in json.loads(self.parameter_configuration)]
@property
def app(self) -> App:
def app(self) -> App | None:
return db.session.query(App).filter(App.id == self.app_id).first()
@@ -288,27 +290,39 @@ class ToolConversationVariables(db.Model):
class ToolFile(db.Model):
"""
store the file created by agent
"""
__tablename__ = "tool_files"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_file_pkey"),
# add index for conversation_id
db.Index("tool_file_conversation_id_idx", "conversation_id"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# conversation user id
user_id = db.Column(StringUUID, nullable=False)
# tenant id
tenant_id = db.Column(StringUUID, nullable=False)
# conversation id
conversation_id = db.Column(StringUUID, nullable=True)
# file key
file_key = db.Column(db.String(255), nullable=False)
# mime type
mimetype = db.Column(db.String(255), nullable=False)
# original url
original_url = db.Column(db.String(2048), nullable=True)
user_id: Mapped[str] = db.Column(StringUUID, nullable=False)
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
conversation_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True)
file_key: Mapped[str] = db.Column(db.String(255), nullable=False)
mimetype: Mapped[str] = db.Column(db.String(255), nullable=False)
original_url: Mapped[Optional[str]] = db.Column(db.String(2048), nullable=True)
name: Mapped[str] = mapped_column(default="")
size: Mapped[int] = mapped_column(default=-1)
def __init__(
self,
*,
user_id: str,
tenant_id: str,
conversation_id: Optional[str] = None,
file_key: str,
mimetype: str,
original_url: Optional[str] = None,
name: str,
size: int,
):
self.user_id = user_id
self.tenant_id = tenant_id
self.conversation_id = conversation_id
self.file_key = file_key
self.mimetype = mimetype
self.original_url = original_url
self.name = name
self.size = size

View File

@@ -5,41 +5,21 @@ from enum import Enum
from typing import Any, Optional, Union
from sqlalchemy import func
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import Mapped, mapped_column
import contexts
from constants import HIDDEN_VALUE
from core.app.segments import SecretVariable, Variable, factory
from core.helper import encrypter
from core.variables import SecretVariable, Variable
from extensions.ext_database import db
from factories import variable_factory
from libs import helper
from models.enums import CreatedByRole
from .account import Account
from .types import StringUUID
class CreatedByRole(Enum):
"""
Created By Role Enum
"""
ACCOUNT = "account"
END_USER = "end_user"
@classmethod
def value_of(cls, value: str) -> "CreatedByRole":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid created by role value {value}")
class WorkflowType(Enum):
"""
Workflow Type Enum
@@ -114,23 +94,23 @@ class Workflow(db.Model):
db.Index("workflow_version_idx", "tenant_id", "app_id", "version"),
)
id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
app_id: Mapped[str] = db.Column(StringUUID, nullable=False)
type: Mapped[str] = db.Column(db.String(255), nullable=False)
version: Mapped[str] = db.Column(db.String(255), nullable=False)
graph: Mapped[str] = db.Column(db.Text)
features: Mapped[str] = db.Column(db.Text)
created_by: Mapped[str] = db.Column(StringUUID, nullable=False)
created_at: Mapped[datetime] = db.Column(
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(db.String(255), nullable=False)
version: Mapped[str] = mapped_column(db.String(255), nullable=False)
graph: Mapped[str] = mapped_column(db.Text)
_features: Mapped[str] = mapped_column("features")
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)
updated_by: Mapped[str] = db.Column(StringUUID)
updated_at: Mapped[datetime] = db.Column(db.DateTime)
_environment_variables: Mapped[str] = db.Column(
updated_by: Mapped[str] = mapped_column(StringUUID)
updated_at: Mapped[datetime] = mapped_column(db.DateTime)
_environment_variables: Mapped[str] = mapped_column(
"environment_variables", db.Text, nullable=False, server_default="{}"
)
_conversation_variables: Mapped[str] = db.Column(
_conversation_variables: Mapped[str] = mapped_column(
"conversation_variables", db.Text, nullable=False, server_default="{}"
)
@@ -169,6 +149,34 @@ class Workflow(db.Model):
def graph_dict(self) -> Mapping[str, Any]:
return json.loads(self.graph) if self.graph else {}
@property
def features(self) -> str:
"""
Convert old features structure to new features structure.
"""
if not self._features:
return self._features
features = json.loads(self._features)
if features.get("file_upload", {}).get("image", {}).get("enabled", False):
image_enabled = True
image_number_limits = int(features["file_upload"]["image"].get("number_limits", 1))
image_transfer_methods = features["file_upload"]["image"].get(
"transfer_methods", ["remote_url", "local_file"]
)
features["file_upload"]["enabled"] = image_enabled
features["file_upload"]["number_limits"] = image_number_limits
features["file_upload"]["allowed_upload_methods"] = image_transfer_methods
features["file_upload"]["allowed_file_types"] = ["image"]
features["file_upload"]["allowed_extensions"] = []
del features["file_upload"]["image"]
self._features = json.dumps(features)
return self._features
@features.setter
def features(self, value: str) -> None:
self._features = value
@property
def features_dict(self) -> Mapping[str, Any]:
return json.loads(self.features) if self.features else {}
@@ -227,7 +235,7 @@ class Workflow(db.Model):
tenant_id = contexts.tenant_id.get()
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables)
results = [factory.build_variable_from_mapping(v) for v in environment_variables_dict.values()]
results = [variable_factory.build_variable_from_mapping(v) for v in environment_variables_dict.values()]
# decrypt secret variables value
decrypt_func = (
@@ -240,6 +248,10 @@ class Workflow(db.Model):
@environment_variables.setter
def environment_variables(self, value: Sequence[Variable]):
if not value:
self._environment_variables = "{}"
return
tenant_id = contexts.tenant_id.get()
value = list(value)
@@ -288,7 +300,7 @@ class Workflow(db.Model):
self._conversation_variables = "{}"
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
results = [factory.build_variable_from_mapping(v) for v in variables_dict.values()]
results = [variable_factory.build_variable_from_mapping(v) for v in variables_dict.values()]
return results
@conversation_variables.setter
@@ -299,28 +311,6 @@ class Workflow(db.Model):
)
class WorkflowRunTriggeredFrom(Enum):
"""
Workflow Run Triggered From Enum
"""
DEBUGGING = "debugging"
APP_RUN = "app-run"
@classmethod
def value_of(cls, value: str) -> "WorkflowRunTriggeredFrom":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid workflow run triggered from value {value}")
class WorkflowRunStatus(Enum):
"""
Workflow Run Status Enum
@@ -401,7 +391,7 @@ class WorkflowRun(db.Model):
graph = db.Column(db.Text)
inputs = db.Column(db.Text)
status = db.Column(db.String(255), nullable=False)
outputs = db.Column(db.Text)
outputs: Mapped[str] = db.Column(db.Text)
error = db.Column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
@@ -413,27 +403,27 @@ class WorkflowRun(db.Model):
@property
def created_by_account(self):
created_by_role = CreatedByRole.value_of(self.created_by_role)
created_by_role = CreatedByRole(self.created_by_role)
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
@property
def created_by_end_user(self):
from models.model import EndUser
created_by_role = CreatedByRole.value_of(self.created_by_role)
created_by_role = CreatedByRole(self.created_by_role)
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
@property
def graph_dict(self):
return json.loads(self.graph) if self.graph else None
return json.loads(self.graph) if self.graph else {}
@property
def inputs_dict(self):
return json.loads(self.inputs) if self.inputs else None
def inputs_dict(self) -> Mapping[str, Any]:
return json.loads(self.inputs) if self.inputs else {}
@property
def outputs_dict(self):
return json.loads(self.outputs) if self.outputs else None
def outputs_dict(self) -> Mapping[str, Any]:
return json.loads(self.outputs) if self.outputs else {}
@property
def message(self) -> Optional["Message"]:
@@ -640,14 +630,14 @@ class WorkflowNodeExecution(db.Model):
@property
def created_by_account(self):
created_by_role = CreatedByRole.value_of(self.created_by_role)
created_by_role = CreatedByRole(self.created_by_role)
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
@property
def created_by_end_user(self):
from models.model import EndUser
created_by_role = CreatedByRole.value_of(self.created_by_role)
created_by_role = CreatedByRole(self.created_by_role)
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
@property
@@ -672,7 +662,7 @@ class WorkflowNodeExecution(db.Model):
extras = {}
if self.execution_metadata_dict:
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes import NodeType
if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict:
tool_info = self.execution_metadata_dict["tool_info"]
@@ -759,14 +749,14 @@ class WorkflowAppLog(db.Model):
@property
def created_by_account(self):
created_by_role = CreatedByRole.value_of(self.created_by_role)
created_by_role = CreatedByRole(self.created_by_role)
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
@property
def created_by_end_user(self):
from models.model import EndUser
created_by_role = CreatedByRole.value_of(self.created_by_role)
created_by_role = CreatedByRole(self.created_by_role)
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
@@ -800,4 +790,4 @@ class ConversationVariable(db.Model):
def to_variable(self) -> Variable:
mapping = json.loads(self.data)
return factory.build_variable_from_mapping(mapping)
return variable_factory.build_variable_from_mapping(mapping)