mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-09 02:46:52 +08:00
feat/enhance the multi-modal support (#8818)
This commit is contained in:
@@ -1,29 +1,55 @@
|
||||
from enum import Enum
|
||||
from .account import Account, AccountIntegrate, InvitationCode, Tenant
|
||||
from .dataset import Dataset, DatasetProcessRule, Document, DocumentSegment
|
||||
from .model import (
|
||||
ApiToken,
|
||||
App,
|
||||
AppMode,
|
||||
Conversation,
|
||||
EndUser,
|
||||
FileUploadConfig,
|
||||
InstalledApp,
|
||||
Message,
|
||||
MessageAnnotation,
|
||||
MessageFile,
|
||||
RecommendedApp,
|
||||
Site,
|
||||
UploadFile,
|
||||
)
|
||||
from .source import DataSourceOauthBinding
|
||||
from .tools import ToolFile
|
||||
from .workflow import (
|
||||
ConversationVariable,
|
||||
Workflow,
|
||||
WorkflowAppLog,
|
||||
WorkflowRun,
|
||||
)
|
||||
|
||||
from .model import App, AppMode, Message
|
||||
from .types import StringUUID
|
||||
from .workflow import ConversationVariable, Workflow, WorkflowNodeExecutionStatus
|
||||
|
||||
__all__ = ["ConversationVariable", "StringUUID", "AppMode", "WorkflowNodeExecutionStatus", "Workflow", "App", "Message"]
|
||||
|
||||
|
||||
class CreatedByRole(Enum):
|
||||
"""
|
||||
Enum class for createdByRole
|
||||
"""
|
||||
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end_user"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "CreatedByRole":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for role in cls:
|
||||
if role.value == value:
|
||||
return role
|
||||
raise ValueError(f"invalid createdByRole value {value}")
|
||||
__all__ = [
|
||||
"ConversationVariable",
|
||||
"Document",
|
||||
"Dataset",
|
||||
"DatasetProcessRule",
|
||||
"DocumentSegment",
|
||||
"DataSourceOauthBinding",
|
||||
"AppMode",
|
||||
"Workflow",
|
||||
"App",
|
||||
"Message",
|
||||
"EndUser",
|
||||
"MessageFile",
|
||||
"UploadFile",
|
||||
"Account",
|
||||
"WorkflowAppLog",
|
||||
"WorkflowRun",
|
||||
"Site",
|
||||
"InstalledApp",
|
||||
"RecommendedApp",
|
||||
"ApiToken",
|
||||
"AccountIntegrate",
|
||||
"InvitationCode",
|
||||
"Tenant",
|
||||
"Conversation",
|
||||
"MessageAnnotation",
|
||||
"FileUploadConfig",
|
||||
"ToolFile",
|
||||
]
|
||||
|
||||
@@ -560,7 +560,7 @@ class DocumentSegment(db.Model):
|
||||
)
|
||||
|
||||
def get_sign_content(self):
|
||||
pattern = r"/files/([a-f0-9\-]+)/image-preview"
|
||||
pattern = r"/files/([a-f0-9\-]+)/file-preview"
|
||||
text = self.content
|
||||
matches = re.finditer(pattern, text)
|
||||
signed_urls = []
|
||||
@@ -568,7 +568,7 @@ class DocumentSegment(db.Model):
|
||||
upload_file_id = match.group(1)
|
||||
nonce = os.urandom(16).hex()
|
||||
timestamp = str(int(time.time()))
|
||||
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||
data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
16
api/models/enums.py
Normal file
16
api/models/enums.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class CreatedByRole(str, Enum):
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end_user"
|
||||
|
||||
|
||||
class UserFrom(str, Enum):
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end-user"
|
||||
|
||||
|
||||
class WorkflowRunTriggeredFrom(str, Enum):
|
||||
DEBUGGING = "debugging"
|
||||
APP_RUN = "app-run"
|
||||
@@ -1,24 +1,37 @@
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from flask import request
|
||||
from flask_login import UserMixin
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import Float, func, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileExtraConfig, FileTransferMethod, FileType
|
||||
from core.file import helpers as file_helpers
|
||||
from core.file.tool_file_parser import ToolFileParser
|
||||
from core.file.upload_file_parser import UploadFileParser
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import generate_string
|
||||
from models.enums import CreatedByRole
|
||||
|
||||
from .account import Account, Tenant
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class FileUploadConfig(BaseModel):
|
||||
enabled: bool = Field(default=False)
|
||||
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
|
||||
allowed_extensions: Sequence[str] = Field(default_factory=list)
|
||||
allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
|
||||
number_limits: int = Field(default=0, gt=0, le=10)
|
||||
|
||||
|
||||
class DifySetup(db.Model):
|
||||
__tablename__ = "dify_setups"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
|
||||
@@ -27,7 +40,7 @@ class DifySetup(db.Model):
|
||||
setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
|
||||
|
||||
class AppMode(Enum):
|
||||
class AppMode(str, Enum):
|
||||
COMPLETION = "completion"
|
||||
WORKFLOW = "workflow"
|
||||
CHAT = "chat"
|
||||
@@ -59,7 +72,7 @@ class App(db.Model):
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id"))
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying"))
|
||||
mode = db.Column(db.String(255), nullable=False)
|
||||
@@ -530,7 +543,7 @@ class Conversation(db.Model):
|
||||
mode = db.Column(db.String(255), nullable=False)
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
summary = db.Column(db.Text)
|
||||
inputs = db.Column(db.JSON)
|
||||
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
|
||||
introduction = db.Column(db.Text)
|
||||
system_instruction = db.Column(db.Text)
|
||||
system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
@@ -552,6 +565,28 @@ class Conversation(db.Model):
|
||||
|
||||
is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
inputs = self._inputs.copy()
|
||||
for key, value in inputs.items():
|
||||
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
inputs[key] = File.model_validate(value)
|
||||
elif isinstance(value, list) and all(
|
||||
isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value
|
||||
):
|
||||
inputs[key] = [File.model_validate(item) for item in value]
|
||||
return inputs
|
||||
|
||||
@inputs.setter
|
||||
def inputs(self, value: Mapping[str, Any]):
|
||||
inputs = dict(value)
|
||||
for k, v in inputs.items():
|
||||
if isinstance(v, File):
|
||||
inputs[k] = v.model_dump()
|
||||
elif isinstance(v, list) and all(isinstance(item, File) for item in v):
|
||||
inputs[k] = [item.model_dump() for item in v]
|
||||
self._inputs = inputs
|
||||
|
||||
@property
|
||||
def model_config(self):
|
||||
model_config = {}
|
||||
@@ -700,13 +735,13 @@ class Message(db.Model):
|
||||
model_id = db.Column(db.String(255), nullable=True)
|
||||
override_model_configs = db.Column(db.Text)
|
||||
conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=False)
|
||||
inputs = db.Column(db.JSON)
|
||||
query = db.Column(db.Text, nullable=False)
|
||||
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
|
||||
query: Mapped[str] = db.Column(db.Text, nullable=False)
|
||||
message = db.Column(db.JSON, nullable=False)
|
||||
message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
|
||||
message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
||||
answer = db.Column(db.Text, nullable=False)
|
||||
answer: Mapped[str] = db.Column(db.Text, nullable=False)
|
||||
answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
|
||||
answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
||||
@@ -717,15 +752,37 @@ class Message(db.Model):
|
||||
status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
|
||||
error = db.Column(db.Text)
|
||||
message_metadata = db.Column(db.Text)
|
||||
invoke_from = db.Column(db.String(255), nullable=True)
|
||||
invoke_from: Mapped[Optional[str]] = db.Column(db.String(255), nullable=True)
|
||||
from_source = db.Column(db.String(255), nullable=False)
|
||||
from_end_user_id = db.Column(StringUUID)
|
||||
from_account_id = db.Column(StringUUID)
|
||||
from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID)
|
||||
from_account_id: Mapped[Optional[str]] = db.Column(StringUUID)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
workflow_run_id = db.Column(StringUUID)
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
inputs = self._inputs.copy()
|
||||
for key, value in inputs.items():
|
||||
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
inputs[key] = File.model_validate(value)
|
||||
elif isinstance(value, list) and all(
|
||||
isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value
|
||||
):
|
||||
inputs[key] = [File.model_validate(item) for item in value]
|
||||
return inputs
|
||||
|
||||
@inputs.setter
|
||||
def inputs(self, value: Mapping[str, Any]):
|
||||
inputs = dict(value)
|
||||
for k, v in inputs.items():
|
||||
if isinstance(v, File):
|
||||
inputs[k] = v.model_dump()
|
||||
elif isinstance(v, list) and all(isinstance(item, File) for item in v):
|
||||
inputs[k] = [item.model_dump() for item in v]
|
||||
self._inputs = inputs
|
||||
|
||||
@property
|
||||
def re_sign_file_url_answer(self) -> str:
|
||||
if not self.answer:
|
||||
@@ -772,19 +829,29 @@ class Message(db.Model):
|
||||
sign_url = ToolFileParser.get_tool_file_manager().sign_file(
|
||||
tool_file_id=tool_file_id, extension=extension
|
||||
)
|
||||
else:
|
||||
elif "file-preview" in url:
|
||||
# get upload file id
|
||||
upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview?\?timestamp="
|
||||
upload_file_id_pattern = r"\/files\/([\w-]+)\/file-preview?\?timestamp="
|
||||
result = re.search(upload_file_id_pattern, url)
|
||||
if not result:
|
||||
continue
|
||||
|
||||
upload_file_id = result.group(1)
|
||||
|
||||
if not upload_file_id:
|
||||
continue
|
||||
|
||||
sign_url = UploadFileParser.get_signed_temp_image_url(upload_file_id)
|
||||
sign_url = file_helpers.get_signed_file_url(upload_file_id)
|
||||
elif "image-preview" in url:
|
||||
# image-preview is deprecated, use file-preview instead
|
||||
upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview?\?timestamp="
|
||||
result = re.search(upload_file_id_pattern, url)
|
||||
if not result:
|
||||
continue
|
||||
upload_file_id = result.group(1)
|
||||
if not upload_file_id:
|
||||
continue
|
||||
sign_url = file_helpers.get_signed_file_url(upload_file_id)
|
||||
else:
|
||||
continue
|
||||
|
||||
re_sign_file_url_answer = re_sign_file_url_answer.replace(url, sign_url)
|
||||
|
||||
@@ -870,50 +937,71 @@ class Message(db.Model):
|
||||
|
||||
@property
|
||||
def message_files(self):
|
||||
return db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all()
|
||||
from factories import file_factory
|
||||
|
||||
@property
|
||||
def files(self):
|
||||
message_files = self.message_files
|
||||
message_files = db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all()
|
||||
current_app = db.session.query(App).filter(App.id == self.app_id).first()
|
||||
if not current_app:
|
||||
raise ValueError(f"App {self.app_id} not found")
|
||||
|
||||
files = []
|
||||
files: list[File] = []
|
||||
for message_file in message_files:
|
||||
url = message_file.url
|
||||
if message_file.type == "image":
|
||||
if message_file.transfer_method == "local_file":
|
||||
upload_file = (
|
||||
db.session.query(UploadFile).filter(UploadFile.id == message_file.upload_file_id).first()
|
||||
)
|
||||
|
||||
url = UploadFileParser.get_image_data(upload_file=upload_file, force_url=True)
|
||||
if message_file.transfer_method == "tool_file":
|
||||
# get tool file id
|
||||
tool_file_id = message_file.url.split("/")[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split(".")[0]
|
||||
|
||||
# get extension
|
||||
if "." in message_file.url:
|
||||
extension = f'.{message_file.url.split(".")[-1]}'
|
||||
if len(extension) > 10:
|
||||
extension = ".bin"
|
||||
else:
|
||||
extension = ".bin"
|
||||
# add sign url
|
||||
url = ToolFileParser.get_tool_file_manager().sign_file(
|
||||
tool_file_id=tool_file_id, extension=extension
|
||||
)
|
||||
|
||||
files.append(
|
||||
{
|
||||
if message_file.transfer_method == "local_file":
|
||||
if message_file.upload_file_id is None:
|
||||
raise ValueError(f"MessageFile {message_file.id} is a local file but has no upload_file_id")
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping={
|
||||
"id": message_file.id,
|
||||
"upload_file_id": message_file.upload_file_id,
|
||||
"transfer_method": message_file.transfer_method,
|
||||
"type": message_file.type,
|
||||
},
|
||||
tenant_id=current_app.tenant_id,
|
||||
user_id=self.from_account_id or self.from_end_user_id or "",
|
||||
role=CreatedByRole(message_file.created_by_role),
|
||||
config=FileExtraConfig(),
|
||||
)
|
||||
elif message_file.transfer_method == "remote_url":
|
||||
if message_file.url is None:
|
||||
raise ValueError(f"MessageFile {message_file.id} is a remote url but has no url")
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping={
|
||||
"id": message_file.id,
|
||||
"type": message_file.type,
|
||||
"transfer_method": message_file.transfer_method,
|
||||
"url": message_file.url,
|
||||
},
|
||||
tenant_id=current_app.tenant_id,
|
||||
user_id=self.from_account_id or self.from_end_user_id or "",
|
||||
role=CreatedByRole(message_file.created_by_role),
|
||||
config=FileExtraConfig(),
|
||||
)
|
||||
elif message_file.transfer_method == "tool_file":
|
||||
mapping = {
|
||||
"id": message_file.id,
|
||||
"type": message_file.type,
|
||||
"url": url,
|
||||
"belongs_to": message_file.belongs_to or "user",
|
||||
"transfer_method": message_file.transfer_method,
|
||||
"tool_file_id": message_file.upload_file_id,
|
||||
}
|
||||
)
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=current_app.tenant_id,
|
||||
user_id=self.from_account_id or self.from_end_user_id or "",
|
||||
role=CreatedByRole(message_file.created_by_role),
|
||||
config=FileExtraConfig(),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"MessageFile {message_file.id} has an invalid transfer_method {message_file.transfer_method}"
|
||||
)
|
||||
files.append(file)
|
||||
|
||||
return files
|
||||
result = [
|
||||
{"belongs_to": message_file.belongs_to, **file.to_dict()}
|
||||
for (file, message_file) in zip(files, message_files)
|
||||
]
|
||||
|
||||
return result
|
||||
|
||||
@property
|
||||
def workflow_run(self):
|
||||
@@ -1003,16 +1091,39 @@ class MessageFile(db.Model):
|
||||
db.Index("message_file_created_by_idx", "created_by"),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
message_id = db.Column(StringUUID, nullable=False)
|
||||
type = db.Column(db.String(255), nullable=False)
|
||||
transfer_method = db.Column(db.String(255), nullable=False)
|
||||
url = db.Column(db.Text, nullable=True)
|
||||
belongs_to = db.Column(db.String(255), nullable=True)
|
||||
upload_file_id = db.Column(StringUUID, nullable=True)
|
||||
created_by_role = db.Column(db.String(255), nullable=False)
|
||||
created_by = db.Column(StringUUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
message_id: str,
|
||||
type: FileType,
|
||||
transfer_method: FileTransferMethod,
|
||||
url: str | None = None,
|
||||
belongs_to: Literal["user", "assistant"] | None = None,
|
||||
upload_file_id: str | None = None,
|
||||
created_by_role: CreatedByRole,
|
||||
created_by: str,
|
||||
):
|
||||
self.message_id = message_id
|
||||
self.type = type
|
||||
self.transfer_method = transfer_method
|
||||
self.url = url
|
||||
self.belongs_to = belongs_to
|
||||
self.upload_file_id = upload_file_id
|
||||
self.created_by_role = created_by_role
|
||||
self.created_by = created_by
|
||||
|
||||
id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
message_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
transfer_method: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
url: Mapped[Optional[str]] = db.Column(db.Text, nullable=True)
|
||||
belongs_to: Mapped[Optional[str]] = db.Column(db.String(255), nullable=True)
|
||||
upload_file_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True)
|
||||
created_by_role: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
created_by: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = db.Column(
|
||||
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
|
||||
|
||||
class MessageAnnotation(db.Model):
|
||||
@@ -1250,21 +1361,58 @@ class UploadFile(db.Model):
|
||||
db.Index("upload_file_tenant_idx", "tenant_id"),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
storage_type = db.Column(db.String(255), nullable=False)
|
||||
key = db.Column(db.String(255), nullable=False)
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
size = db.Column(db.Integer, nullable=False)
|
||||
extension = db.Column(db.String(255), nullable=False)
|
||||
mime_type = db.Column(db.String(255), nullable=True)
|
||||
created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying"))
|
||||
created_by = db.Column(StringUUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
used = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
used_by = db.Column(StringUUID, nullable=True)
|
||||
used_at = db.Column(db.DateTime, nullable=True)
|
||||
hash = db.Column(db.String(255), nullable=True)
|
||||
id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
storage_type: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
key: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
name: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
size: Mapped[int] = db.Column(db.Integer, nullable=False)
|
||||
extension: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
mime_type: Mapped[str] = db.Column(db.String(255), nullable=True)
|
||||
created_by_role: Mapped[str] = db.Column(
|
||||
db.String(255), nullable=False, server_default=db.text("'account'::character varying")
|
||||
)
|
||||
created_by: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = db.Column(
|
||||
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
used: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
used_by: Mapped[str | None] = db.Column(StringUUID, nullable=True)
|
||||
used_at: Mapped[datetime | None] = db.Column(db.DateTime, nullable=True)
|
||||
hash: Mapped[str | None] = db.Column(db.String(255), nullable=True)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
storage_type: str,
|
||||
key: str,
|
||||
name: str,
|
||||
size: int,
|
||||
extension: str,
|
||||
mime_type: str,
|
||||
created_by_role: str,
|
||||
created_by: str,
|
||||
created_at: datetime,
|
||||
used: bool,
|
||||
used_by: str | None = None,
|
||||
used_at: datetime | None = None,
|
||||
hash: str | None = None,
|
||||
) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.storage_type = storage_type
|
||||
self.key = key
|
||||
self.name = name
|
||||
self.size = size
|
||||
self.extension = extension
|
||||
self.mime_type = mime_type
|
||||
self.created_by_role = created_by_role
|
||||
self.created_by = created_by
|
||||
self.created_at = created_at
|
||||
self.used = used
|
||||
self.used_by = used_by
|
||||
self.used_at = used_at
|
||||
self.hash = hash
|
||||
|
||||
|
||||
class ApiRequest(db.Model):
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
@@ -101,7 +103,7 @@ class ApiToolProvider(db.Model):
|
||||
icon = db.Column(db.String(255), nullable=False)
|
||||
# original schema
|
||||
schema = db.Column(db.Text, nullable=False)
|
||||
schema_type_str = db.Column(db.String(40), nullable=False)
|
||||
schema_type_str: Mapped[str] = db.Column(db.String(40), nullable=False)
|
||||
# who created this tool
|
||||
user_id = db.Column(StringUUID, nullable=False)
|
||||
# tenant id
|
||||
@@ -133,11 +135,11 @@ class ApiToolProvider(db.Model):
|
||||
return json.loads(self.credentials_str)
|
||||
|
||||
@property
|
||||
def user(self) -> Account:
|
||||
def user(self) -> Account | None:
|
||||
return db.session.query(Account).filter(Account.id == self.user_id).first()
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant:
|
||||
def tenant(self) -> Tenant | None:
|
||||
return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
|
||||
|
||||
|
||||
@@ -203,11 +205,11 @@ class WorkflowToolProvider(db.Model):
|
||||
return ApiProviderSchemaType.value_of(self.schema_type_str)
|
||||
|
||||
@property
|
||||
def user(self) -> Account:
|
||||
def user(self) -> Account | None:
|
||||
return db.session.query(Account).filter(Account.id == self.user_id).first()
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant:
|
||||
def tenant(self) -> Tenant | None:
|
||||
return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
|
||||
|
||||
@property
|
||||
@@ -215,7 +217,7 @@ class WorkflowToolProvider(db.Model):
|
||||
return [WorkflowToolParameterConfiguration(**config) for config in json.loads(self.parameter_configuration)]
|
||||
|
||||
@property
|
||||
def app(self) -> App:
|
||||
def app(self) -> App | None:
|
||||
return db.session.query(App).filter(App.id == self.app_id).first()
|
||||
|
||||
|
||||
@@ -288,27 +290,39 @@ class ToolConversationVariables(db.Model):
|
||||
|
||||
|
||||
class ToolFile(db.Model):
|
||||
"""
|
||||
store the file created by agent
|
||||
"""
|
||||
|
||||
__tablename__ = "tool_files"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tool_file_pkey"),
|
||||
# add index for conversation_id
|
||||
db.Index("tool_file_conversation_id_idx", "conversation_id"),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
# conversation user id
|
||||
user_id = db.Column(StringUUID, nullable=False)
|
||||
# tenant id
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
# conversation id
|
||||
conversation_id = db.Column(StringUUID, nullable=True)
|
||||
# file key
|
||||
file_key = db.Column(db.String(255), nullable=False)
|
||||
# mime type
|
||||
mimetype = db.Column(db.String(255), nullable=False)
|
||||
# original url
|
||||
original_url = db.Column(db.String(2048), nullable=True)
|
||||
user_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
conversation_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True)
|
||||
file_key: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
mimetype: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
original_url: Mapped[Optional[str]] = db.Column(db.String(2048), nullable=True)
|
||||
name: Mapped[str] = mapped_column(default="")
|
||||
size: Mapped[int] = mapped_column(default=-1)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
file_key: str,
|
||||
mimetype: str,
|
||||
original_url: Optional[str] = None,
|
||||
name: str,
|
||||
size: int,
|
||||
):
|
||||
self.user_id = user_id
|
||||
self.tenant_id = tenant_id
|
||||
self.conversation_id = conversation_id
|
||||
self.file_key = file_key
|
||||
self.mimetype = mimetype
|
||||
self.original_url = original_url
|
||||
self.name = name
|
||||
self.size = size
|
||||
|
||||
@@ -5,41 +5,21 @@ from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
import contexts
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.app.segments import SecretVariable, Variable, factory
|
||||
from core.helper import encrypter
|
||||
from core.variables import SecretVariable, Variable
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from libs import helper
|
||||
from models.enums import CreatedByRole
|
||||
|
||||
from .account import Account
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class CreatedByRole(Enum):
|
||||
"""
|
||||
Created By Role Enum
|
||||
"""
|
||||
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end_user"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "CreatedByRole":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid created by role value {value}")
|
||||
|
||||
|
||||
class WorkflowType(Enum):
|
||||
"""
|
||||
Workflow Type Enum
|
||||
@@ -114,23 +94,23 @@ class Workflow(db.Model):
|
||||
db.Index("workflow_version_idx", "tenant_id", "app_id", "version"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
version: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
graph: Mapped[str] = db.Column(db.Text)
|
||||
features: Mapped[str] = db.Column(db.Text)
|
||||
created_by: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = db.Column(
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
version: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
graph: Mapped[str] = mapped_column(db.Text)
|
||||
_features: Mapped[str] = mapped_column("features")
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
updated_by: Mapped[str] = db.Column(StringUUID)
|
||||
updated_at: Mapped[datetime] = db.Column(db.DateTime)
|
||||
_environment_variables: Mapped[str] = db.Column(
|
||||
updated_by: Mapped[str] = mapped_column(StringUUID)
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime)
|
||||
_environment_variables: Mapped[str] = mapped_column(
|
||||
"environment_variables", db.Text, nullable=False, server_default="{}"
|
||||
)
|
||||
_conversation_variables: Mapped[str] = db.Column(
|
||||
_conversation_variables: Mapped[str] = mapped_column(
|
||||
"conversation_variables", db.Text, nullable=False, server_default="{}"
|
||||
)
|
||||
|
||||
@@ -169,6 +149,34 @@ class Workflow(db.Model):
|
||||
def graph_dict(self) -> Mapping[str, Any]:
|
||||
return json.loads(self.graph) if self.graph else {}
|
||||
|
||||
@property
|
||||
def features(self) -> str:
|
||||
"""
|
||||
Convert old features structure to new features structure.
|
||||
"""
|
||||
if not self._features:
|
||||
return self._features
|
||||
|
||||
features = json.loads(self._features)
|
||||
if features.get("file_upload", {}).get("image", {}).get("enabled", False):
|
||||
image_enabled = True
|
||||
image_number_limits = int(features["file_upload"]["image"].get("number_limits", 1))
|
||||
image_transfer_methods = features["file_upload"]["image"].get(
|
||||
"transfer_methods", ["remote_url", "local_file"]
|
||||
)
|
||||
features["file_upload"]["enabled"] = image_enabled
|
||||
features["file_upload"]["number_limits"] = image_number_limits
|
||||
features["file_upload"]["allowed_upload_methods"] = image_transfer_methods
|
||||
features["file_upload"]["allowed_file_types"] = ["image"]
|
||||
features["file_upload"]["allowed_extensions"] = []
|
||||
del features["file_upload"]["image"]
|
||||
self._features = json.dumps(features)
|
||||
return self._features
|
||||
|
||||
@features.setter
|
||||
def features(self, value: str) -> None:
|
||||
self._features = value
|
||||
|
||||
@property
|
||||
def features_dict(self) -> Mapping[str, Any]:
|
||||
return json.loads(self.features) if self.features else {}
|
||||
@@ -227,7 +235,7 @@ class Workflow(db.Model):
|
||||
tenant_id = contexts.tenant_id.get()
|
||||
|
||||
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables)
|
||||
results = [factory.build_variable_from_mapping(v) for v in environment_variables_dict.values()]
|
||||
results = [variable_factory.build_variable_from_mapping(v) for v in environment_variables_dict.values()]
|
||||
|
||||
# decrypt secret variables value
|
||||
decrypt_func = (
|
||||
@@ -240,6 +248,10 @@ class Workflow(db.Model):
|
||||
|
||||
@environment_variables.setter
|
||||
def environment_variables(self, value: Sequence[Variable]):
|
||||
if not value:
|
||||
self._environment_variables = "{}"
|
||||
return
|
||||
|
||||
tenant_id = contexts.tenant_id.get()
|
||||
|
||||
value = list(value)
|
||||
@@ -288,7 +300,7 @@ class Workflow(db.Model):
|
||||
self._conversation_variables = "{}"
|
||||
|
||||
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
|
||||
results = [factory.build_variable_from_mapping(v) for v in variables_dict.values()]
|
||||
results = [variable_factory.build_variable_from_mapping(v) for v in variables_dict.values()]
|
||||
return results
|
||||
|
||||
@conversation_variables.setter
|
||||
@@ -299,28 +311,6 @@ class Workflow(db.Model):
|
||||
)
|
||||
|
||||
|
||||
class WorkflowRunTriggeredFrom(Enum):
|
||||
"""
|
||||
Workflow Run Triggered From Enum
|
||||
"""
|
||||
|
||||
DEBUGGING = "debugging"
|
||||
APP_RUN = "app-run"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "WorkflowRunTriggeredFrom":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid workflow run triggered from value {value}")
|
||||
|
||||
|
||||
class WorkflowRunStatus(Enum):
|
||||
"""
|
||||
Workflow Run Status Enum
|
||||
@@ -401,7 +391,7 @@ class WorkflowRun(db.Model):
|
||||
graph = db.Column(db.Text)
|
||||
inputs = db.Column(db.Text)
|
||||
status = db.Column(db.String(255), nullable=False)
|
||||
outputs = db.Column(db.Text)
|
||||
outputs: Mapped[str] = db.Column(db.Text)
|
||||
error = db.Column(db.Text)
|
||||
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
|
||||
total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
@@ -413,27 +403,27 @@ class WorkflowRun(db.Model):
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
created_by_role = CreatedByRole.value_of(self.created_by_role)
|
||||
created_by_role = CreatedByRole(self.created_by_role)
|
||||
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
|
||||
|
||||
@property
|
||||
def created_by_end_user(self):
|
||||
from models.model import EndUser
|
||||
|
||||
created_by_role = CreatedByRole.value_of(self.created_by_role)
|
||||
created_by_role = CreatedByRole(self.created_by_role)
|
||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
|
||||
|
||||
@property
|
||||
def graph_dict(self):
|
||||
return json.loads(self.graph) if self.graph else None
|
||||
return json.loads(self.graph) if self.graph else {}
|
||||
|
||||
@property
|
||||
def inputs_dict(self):
|
||||
return json.loads(self.inputs) if self.inputs else None
|
||||
def inputs_dict(self) -> Mapping[str, Any]:
|
||||
return json.loads(self.inputs) if self.inputs else {}
|
||||
|
||||
@property
|
||||
def outputs_dict(self):
|
||||
return json.loads(self.outputs) if self.outputs else None
|
||||
def outputs_dict(self) -> Mapping[str, Any]:
|
||||
return json.loads(self.outputs) if self.outputs else {}
|
||||
|
||||
@property
|
||||
def message(self) -> Optional["Message"]:
|
||||
@@ -640,14 +630,14 @@ class WorkflowNodeExecution(db.Model):
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
created_by_role = CreatedByRole.value_of(self.created_by_role)
|
||||
created_by_role = CreatedByRole(self.created_by_role)
|
||||
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
|
||||
|
||||
@property
|
||||
def created_by_end_user(self):
|
||||
from models.model import EndUser
|
||||
|
||||
created_by_role = CreatedByRole.value_of(self.created_by_role)
|
||||
created_by_role = CreatedByRole(self.created_by_role)
|
||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
|
||||
|
||||
@property
|
||||
@@ -672,7 +662,7 @@ class WorkflowNodeExecution(db.Model):
|
||||
|
||||
extras = {}
|
||||
if self.execution_metadata_dict:
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.nodes import NodeType
|
||||
|
||||
if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict:
|
||||
tool_info = self.execution_metadata_dict["tool_info"]
|
||||
@@ -759,14 +749,14 @@ class WorkflowAppLog(db.Model):
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
created_by_role = CreatedByRole.value_of(self.created_by_role)
|
||||
created_by_role = CreatedByRole(self.created_by_role)
|
||||
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
|
||||
|
||||
@property
|
||||
def created_by_end_user(self):
|
||||
from models.model import EndUser
|
||||
|
||||
created_by_role = CreatedByRole.value_of(self.created_by_role)
|
||||
created_by_role = CreatedByRole(self.created_by_role)
|
||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
|
||||
|
||||
|
||||
@@ -800,4 +790,4 @@ class ConversationVariable(db.Model):
|
||||
|
||||
def to_variable(self) -> Variable:
|
||||
mapping = json.loads(self.data)
|
||||
return factory.build_variable_from_mapping(mapping)
|
||||
return variable_factory.build_variable_from_mapping(mapping)
|
||||
|
||||
Reference in New Issue
Block a user