feat: mypy for all type check (#10921)
@@ -6,7 +6,7 @@ import secrets
import uuid
from datetime import UTC, datetime, timedelta
from hashlib import sha256
from typing import Any, Optional
from typing import Any, Optional, cast

from pydantic import BaseModel
from sqlalchemy import func
@@ -119,7 +119,7 @@ class AccountService:
account.last_active_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()

return account
return cast(Account, account)

@staticmethod
def get_account_jwt_token(account: Account) -> str:
@@ -132,7 +132,7 @@ class AccountService:
"sub": "Console API Passport",
}

token = PassportService().issue(payload)
token: str = PassportService().issue(payload)
return token

@staticmethod
@@ -164,7 +164,7 @@ class AccountService:

db.session.commit()

return account
return cast(Account, account)

@staticmethod
def update_account_password(account, password, new_password):
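The hunks above wrap values whose inferred type is too loose in typing.cast before returning them, which is the main pattern this commit uses to satisfy mypy without changing runtime behaviour. A minimal, self-contained sketch of the idea (the Account class and helper below are illustrative stand-ins, not code from this diff):

    from typing import Any, Optional, cast


    class Account:
        def __init__(self, account_id: str) -> None:
            self.id = account_id


    # Stand-in for an ORM query whose result mypy can only see as Any.
    def first_or_none(rows: list[Account], account_id: str) -> Any:
        return next((row for row in rows if row.id == account_id), None)


    def load_account(rows: list[Account], account_id: str) -> Account:
        account: Optional[Account] = first_or_none(rows, account_id)
        if account is None:
            raise ValueError("Account not found")
        # cast() is a no-op at runtime; it only tells the type checker what
        # the value already is after the None check.
        return cast(Account, account)


    print(load_account([Account("a1")], "a1").id)  # prints: a1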
@@ -347,6 +347,8 @@ class AccountService:
language: Optional[str] = "en-US",
):
account_email = account.email if account else email
if account_email is None:
raise ValueError("Email must be provided.")

if cls.reset_password_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import PasswordResetRateLimitExceededError
@@ -377,6 +379,8 @@ class AccountService:
def send_email_code_login_email(
cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
):
if email is None:
raise ValueError("Email must be provided.")
if cls.email_code_login_rate_limiter.is_rate_limited(email):
from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError

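Both hunks above add an explicit None check before the rate limiter is consulted; an early raise is the standard way to narrow an Optional[str] so mypy accepts the later calls that expect a plain str. A small self-contained sketch of the same guard-clause pattern (function and names are illustrative):

    from typing import Optional


    def resolve_email(primary: Optional[str], fallback: Optional[str]) -> str:
        # Either source may be missing; the early raise narrows the combined
        # value from Optional[str] to str for everything that follows.
        email = primary if primary is not None else fallback
        if email is None:
            raise ValueError("Email must be provided.")
        return email


    print(resolve_email(None, "user@example.com"))  # user@example.com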
@@ -669,7 +673,7 @@ class TenantService:
|
||||
@staticmethod
|
||||
def get_tenant_count() -> int:
|
||||
"""Get tenant count"""
|
||||
return db.session.query(func.count(Tenant.id)).scalar()
|
||||
return cast(int, db.session.query(func.count(Tenant.id)).scalar())
|
||||
|
||||
@staticmethod
|
||||
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None:
|
||||
@@ -733,10 +737,10 @@ class TenantService:
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_custom_config(tenant_id: str) -> None:
|
||||
tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).one_or_404()
|
||||
def get_custom_config(tenant_id: str) -> dict:
|
||||
tenant = Tenant.query.filter(Tenant.id == tenant_id).one_or_404()
|
||||
|
||||
return tenant.custom_config_dict
|
||||
return cast(dict, tenant.custom_config_dict)
|
||||
|
||||
|
||||
class RegisterService:
|
||||
@@ -807,7 +811,7 @@ class RegisterService:
|
||||
account.status = AccountStatus.ACTIVE.value if not status else status.value
|
||||
account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
if open_id is not None or provider is not None:
|
||||
if open_id is not None and provider is not None:
|
||||
AccountService.link_account_integrate(provider, open_id, account)
|
||||
|
||||
if FeatureService.get_system_features().is_allow_create_workspace:
|
||||
@@ -828,10 +832,11 @@ class RegisterService:

@classmethod
def invite_new_member(
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account = None
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Optional[Account] = None
) -> str:
"""Invite new member"""
account = Account.query.filter_by(email=email).first()
assert inviter is not None, "Inviter must be provided."

if not account:
TenantService.check_member_permission(tenant, inviter, None, "add")
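A parameter defaulted to None must be annotated Optional[...], and the added assert then narrows it back to a concrete type for the rest of the function. A minimal sketch of that pairing (the Account stub and invite_member function are illustrative, not this project's code):

    from typing import Optional


    class Account:
        def __init__(self, name: str) -> None:
            self.name = name


    def invite_member(email: str, inviter: Optional[Account] = None) -> str:
        # A default of None requires Optional[...] in the annotation; the
        # assert narrows the type back to Account below and fails loudly if
        # a caller really passes nothing.
        assert inviter is not None, "Inviter must be provided."
        return f"{inviter.name} invited {email}"


    print(invite_member("new@example.com", Account("admin")))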
@@ -894,7 +899,9 @@ class RegisterService:
|
||||
redis_client.delete(cls._get_invitation_token_key(token))
|
||||
|
||||
@classmethod
|
||||
def get_invitation_if_token_valid(cls, workspace_id: str, email: str, token: str) -> Optional[dict[str, Any]]:
|
||||
def get_invitation_if_token_valid(
|
||||
cls, workspace_id: Optional[str], email: str, token: str
|
||||
) -> Optional[dict[str, Any]]:
|
||||
invitation_data = cls._get_invitation_by_token(token, workspace_id, email)
|
||||
if not invitation_data:
|
||||
return None
|
||||
@@ -953,7 +960,7 @@ class RegisterService:
|
||||
if not data:
|
||||
return None
|
||||
|
||||
invitation = json.loads(data)
|
||||
invitation: dict = json.loads(data)
|
||||
return invitation
|
||||
|
||||
|
||||
|
||||
@@ -48,6 +48,8 @@ class AdvancedPromptTemplateService:
|
||||
return cls.get_chat_prompt(
|
||||
copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt
|
||||
)
|
||||
# default return empty dict
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
|
||||
@@ -91,3 +93,5 @@ class AdvancedPromptTemplateService:
|
||||
return cls.get_chat_prompt(
|
||||
copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
|
||||
)
|
||||
# default return empty dict
|
||||
return {}
|
||||
|
||||
@@ -1,5 +1,7 @@
from typing import Optional

import pytz
from flask_login import current_user
from flask_login import current_user  # type: ignore

from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
from core.tools.tool_manager import ToolManager
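The recurring "# type: ignore" on flask_login imports throughout this commit is the per-line way to silence mypy for a third-party package that ships no type information. A short sketch of the pattern and a config-level alternative (assumes flask_login is installed, as it is in this repository; the config snippet is an illustrative option, not taken from this diff):

    # flask_login ships without type stubs or a py.typed marker, so strict
    # mypy reports the import as untyped.  The inline comment suppresses the
    # error for this line only.
    from flask_login import current_user  # type: ignore

    # A project-wide alternative is a mypy config entry such as:
    #   [mypy-flask_login.*]
    #   ignore_missing_imports = True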
@@ -14,7 +16,7 @@ class AgentService:
|
||||
"""
|
||||
Service to get agent logs
|
||||
"""
|
||||
conversation: Conversation = (
|
||||
conversation: Optional[Conversation] = (
|
||||
db.session.query(Conversation)
|
||||
.filter(
|
||||
Conversation.id == conversation_id,
|
||||
@@ -26,7 +28,7 @@ class AgentService:
|
||||
if not conversation:
|
||||
raise ValueError(f"Conversation not found: {conversation_id}")
|
||||
|
||||
message: Message = (
|
||||
message: Optional[Message] = (
|
||||
db.session.query(Message)
|
||||
.filter(
|
||||
Message.id == message_id,
|
||||
@@ -72,7 +74,10 @@ class AgentService:
|
||||
}
|
||||
|
||||
agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict())
|
||||
agent_tools = agent_config.tools
|
||||
if not agent_config:
|
||||
return result
|
||||
|
||||
agent_tools = agent_config.tools or []
|
||||
|
||||
def find_agent_tool(tool_name: str):
|
||||
for agent_tool in agent_tools:
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
import pandas as pd
|
||||
from flask_login import current_user
|
||||
from flask_login import current_user # type: ignore
|
||||
from sqlalchemy import or_
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import NotFound
|
||||
@@ -71,7 +72,7 @@ class AppAnnotationService:
|
||||
app_id,
|
||||
annotation_setting.collection_binding_id,
|
||||
)
|
||||
return annotation
|
||||
return cast(MessageAnnotation, annotation)
|
||||
|
||||
@classmethod
|
||||
def enable_app_annotation(cls, args: dict, app_id: str) -> dict:
|
||||
@@ -124,8 +125,7 @@ class AppAnnotationService:
|
||||
raise NotFound("App not found")
|
||||
if keyword:
|
||||
annotations = (
|
||||
db.session.query(MessageAnnotation)
|
||||
.filter(MessageAnnotation.app_id == app_id)
|
||||
MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id)
|
||||
.filter(
|
||||
or_(
|
||||
MessageAnnotation.question.ilike("%{}%".format(keyword)),
|
||||
@@ -137,8 +137,7 @@ class AppAnnotationService:
|
||||
)
|
||||
else:
|
||||
annotations = (
|
||||
db.session.query(MessageAnnotation)
|
||||
.filter(MessageAnnotation.app_id == app_id)
|
||||
MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id)
|
||||
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
|
||||
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
)
|
||||
@@ -327,8 +326,7 @@ class AppAnnotationService:
|
||||
raise NotFound("Annotation not found")
|
||||
|
||||
annotation_hit_histories = (
|
||||
db.session.query(AppAnnotationHitHistory)
|
||||
.filter(
|
||||
AppAnnotationHitHistory.query.filter(
|
||||
AppAnnotationHitHistory.app_id == app_id,
|
||||
AppAnnotationHitHistory.annotation_id == annotation_id,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import uuid
|
||||
from enum import StrEnum
|
||||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import yaml
|
||||
@@ -103,7 +103,7 @@ class AppDslService:
|
||||
raise ValueError(f"Invalid import_mode: {import_mode}")
|
||||
|
||||
# Get YAML content
|
||||
content = ""
|
||||
content: bytes | str = b""
|
||||
if mode == ImportMode.YAML_URL:
|
||||
if not yaml_url:
|
||||
return Import(
|
||||
@@ -136,7 +136,7 @@ class AppDslService:
|
||||
)
|
||||
|
||||
try:
|
||||
content = content.decode("utf-8")
|
||||
content = cast(bytes, content).decode("utf-8")
|
||||
except UnicodeDecodeError as e:
|
||||
return Import(
|
||||
id=import_id,
|
||||
@@ -362,6 +362,9 @@ class AppDslService:
|
||||
app.icon_background = icon_background or app_data.get("icon_background", app.icon_background)
|
||||
app.updated_by = account.id
|
||||
else:
|
||||
if account.current_tenant_id is None:
|
||||
raise ValueError("Current tenant is not set")
|
||||
|
||||
# Create new app
|
||||
app = App()
|
||||
app.id = str(uuid4())
|
||||
|
||||
@@ -118,7 +118,7 @@ class AppGenerateService:
|
||||
@staticmethod
|
||||
def _get_max_active_requests(app_model: App) -> int:
|
||||
max_active_requests = app_model.max_active_requests
|
||||
if app_model.max_active_requests is None:
|
||||
if max_active_requests is None:
|
||||
max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS)
|
||||
return max_active_requests
|
||||
|
||||
@@ -150,7 +150,7 @@ class AppGenerateService:
|
||||
message_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Union[dict, Generator]:
|
||||
) -> Union[Mapping, Generator]:
|
||||
"""
|
||||
Generate more like this
|
||||
:param app_model: app model
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import json
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import cast
|
||||
from typing import Optional, cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_sqlalchemy.pagination import Pagination
|
||||
|
||||
from configs import dify_config
|
||||
@@ -83,7 +83,7 @@ class AppService:
|
||||
# get default model instance
|
||||
try:
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=account.current_tenant_id, model_type=ModelType.LLM
|
||||
tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM
|
||||
)
|
||||
except (ProviderTokenNotInitError, LLMBadRequestError):
|
||||
model_instance = None
|
||||
@@ -100,6 +100,8 @@ class AppService:
|
||||
else:
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
if model_schema is None:
|
||||
raise ValueError(f"model schema not found for model {model_instance.model}")
|
||||
|
||||
default_model_dict = {
|
||||
"provider": model_instance.provider,
|
||||
@@ -109,7 +111,7 @@ class AppService:
|
||||
}
|
||||
else:
|
||||
provider, model = model_manager.get_default_provider_model_name(
|
||||
tenant_id=account.current_tenant_id, model_type=ModelType.LLM
|
||||
tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM
|
||||
)
|
||||
default_model_config["model"]["provider"] = provider
|
||||
default_model_config["model"]["name"] = model
|
||||
@@ -314,7 +316,7 @@ class AppService:
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
||||
meta = {"tool_icons": {}}
|
||||
meta: dict = {"tool_icons": {}}
|
||||
|
||||
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
@@ -336,7 +338,7 @@ class AppService:
|
||||
}
|
||||
)
|
||||
else:
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
app_model_config: Optional[AppModelConfig] = app_model.app_model_config
|
||||
|
||||
if not app_model_config:
|
||||
return meta
|
||||
@@ -352,16 +354,18 @@ class AppService:
|
||||
keys = list(tool.keys())
|
||||
if len(keys) >= 4:
|
||||
# current tool standard
|
||||
provider_type = tool.get("provider_type")
|
||||
provider_id = tool.get("provider_id")
|
||||
tool_name = tool.get("tool_name")
|
||||
provider_type = tool.get("provider_type", "")
|
||||
provider_id = tool.get("provider_id", "")
|
||||
tool_name = tool.get("tool_name", "")
|
||||
if provider_type == "builtin":
|
||||
meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon"
|
||||
elif provider_type == "api":
|
||||
try:
|
||||
provider: ApiToolProvider = (
|
||||
provider: Optional[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first()
|
||||
)
|
||||
if provider is None:
|
||||
raise ValueError(f"provider not found for tool {tool_name}")
|
||||
meta["tool_icons"][tool_name] = json.loads(provider.icon)
|
||||
except:
|
||||
meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@@ -110,6 +110,8 @@ class AudioService:
|
||||
voices = model_instance.get_tts_voices()
|
||||
if voices:
|
||||
voice = voices[0].get("value")
|
||||
if not voice:
|
||||
raise ValueError("Sorry, no voice available.")
|
||||
else:
|
||||
raise ValueError("Sorry, no voice available.")
|
||||
|
||||
@@ -121,6 +123,8 @@ class AudioService:
|
||||
|
||||
if message_id:
|
||||
message = db.session.query(Message).filter(Message.id == message_id).first()
|
||||
if message is None:
|
||||
return None
|
||||
if message.answer == "" and message.status == "normal":
|
||||
return None
|
||||
|
||||
@@ -130,6 +134,8 @@ class AudioService:
|
||||
return Response(stream_with_context(response), content_type="audio/mpeg")
|
||||
return response
|
||||
else:
|
||||
if not text:
|
||||
raise ValueError("Text is required")
|
||||
response = invoke_tts(text, app_model, voice)
|
||||
if isinstance(response, Generator):
|
||||
return Response(stream_with_context(response), content_type="audio/mpeg")
|
||||
|
||||
@@ -11,8 +11,8 @@ class FirecrawlAuth(ApiKeyAuthBase):
auth_type = credentials.get("auth_type")
if auth_type != "bearer":
raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer")
self.api_key = credentials.get("config").get("api_key", None)
self.base_url = credentials.get("config").get("base_url", "https://api.firecrawl.dev")
self.api_key = credentials.get("config", {}).get("api_key", None)
self.base_url = credentials.get("config", {}).get("base_url", "https://api.firecrawl.dev")

if not self.api_key:
raise ValueError("No API key provided")

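Supplying {} as the default for the "config" lookup removes mypy's union-attr error on calling .get() on a possibly-None value, and also avoids an AttributeError at runtime when the key is absent. A self-contained sketch of the pattern (the function name and sample data are illustrative):

    from typing import Any


    def read_api_key(credentials: dict[str, Any]) -> str:
        # .get("config", {}) keeps the chained lookup safe when "config" is
        # missing; in the typed original it also lets mypy accept the
        # second .get() call.
        api_key = credentials.get("config", {}).get("api_key")
        if not api_key:
            raise ValueError("No API key provided")
        return str(api_key)


    print(read_api_key({"config": {"api_key": "secret"}}))  # secret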
@@ -11,7 +11,7 @@ class JinaAuth(ApiKeyAuthBase):
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "bearer":
|
||||
raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")
|
||||
self.api_key = credentials.get("config").get("api_key", None)
|
||||
self.api_key = credentials.get("config", {}).get("api_key", None)
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
@@ -11,7 +11,7 @@ class JinaAuth(ApiKeyAuthBase):
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "bearer":
|
||||
raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")
|
||||
self.api_key = credentials.get("config").get("api_key", None)
|
||||
self.api_key = credentials.get("config", {}).get("api_key", None)
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from tenacity import retry, retry_if_not_exception_type, stop_before_delay, wait_fixed
|
||||
@@ -58,11 +59,14 @@ class BillingService:
|
||||
def is_tenant_owner_or_admin(current_user):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
join = (
|
||||
join: Optional[TenantAccountJoin] = (
|
||||
db.session.query(TenantAccountJoin)
|
||||
.filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not join:
|
||||
raise ValueError("Tenant account join not found")
|
||||
|
||||
if not TenantAccountRole.is_privileged_role(join.role):
|
||||
raise ValueError("Only team owner or team admin can perform this action")
|
||||
|
||||
@@ -72,8 +72,7 @@ class ConversationService:
|
||||
sort_direction=sort_direction,
|
||||
reference_conversation=current_page_last_conversation,
|
||||
)
|
||||
count_stmt = stmt.where(rest_filter_condition)
|
||||
count_stmt = select(func.count()).select_from(count_stmt.subquery())
|
||||
count_stmt = select(func.count()).select_from(stmt.where(rest_filter_condition).subquery())
|
||||
rest_count = session.scalar(count_stmt) or 0
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
|
||||
@@ -6,7 +6,7 @@ import time
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_login import current_user # type: ignore
|
||||
from sqlalchemy import func
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
@@ -186,8 +186,9 @@ class DatasetService:
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def get_dataset(dataset_id) -> Dataset:
|
||||
return Dataset.query.filter_by(id=dataset_id).first()
|
||||
def get_dataset(dataset_id) -> Optional[Dataset]:
|
||||
dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first()
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def check_dataset_model_setting(dataset):
|
||||
@@ -228,6 +229,8 @@ class DatasetService:
|
||||
@staticmethod
|
||||
def update_dataset(dataset_id, data, user):
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
|
||||
DatasetService.check_dataset_permission(dataset, user)
|
||||
if dataset.provider == "external":
|
||||
@@ -371,7 +374,13 @@ class DatasetService:
|
||||
raise NoPermissionError("You do not have permission to access this dataset.")
|
||||
|
||||
@staticmethod
|
||||
def check_dataset_operator_permission(user: Account = None, dataset: Dataset = None):
|
||||
def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None):
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
|
||||
if not user:
|
||||
raise ValueError("User not found")
|
||||
|
||||
if dataset.permission == DatasetPermissionEnum.ONLY_ME:
|
||||
if dataset.created_by != user.id:
|
||||
raise NoPermissionError("You do not have permission to access this dataset.")
|
||||
@@ -765,6 +774,11 @@ class DocumentService:
|
||||
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
|
||||
created_by=account.id,
|
||||
)
|
||||
else:
|
||||
logging.warn(
|
||||
f"Invalid process rule mode: {process_rule['mode']}, can not find dataset process rule"
|
||||
)
|
||||
return
|
||||
db.session.add(dataset_process_rule)
|
||||
db.session.commit()
|
||||
lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
|
||||
@@ -1009,9 +1023,10 @@ class DocumentService:
|
||||
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
|
||||
created_by=account.id,
|
||||
)
|
||||
db.session.add(dataset_process_rule)
|
||||
db.session.commit()
|
||||
document.dataset_process_rule_id = dataset_process_rule.id
|
||||
if dataset_process_rule is not None:
|
||||
db.session.add(dataset_process_rule)
|
||||
db.session.commit()
|
||||
document.dataset_process_rule_id = dataset_process_rule.id
|
||||
# update document data source
|
||||
if document_data.get("data_source"):
|
||||
file_name = ""
|
||||
@@ -1554,7 +1569,7 @@ class SegmentService:
|
||||
segment.word_count = len(content)
|
||||
if document.doc_form == "qa_model":
|
||||
segment.answer = segment_update_entity.answer
|
||||
segment.word_count += len(segment_update_entity.answer)
|
||||
segment.word_count += len(segment_update_entity.answer or "")
|
||||
word_count_change = segment.word_count - word_count_change
|
||||
if segment_update_entity.keywords:
|
||||
segment.keywords = segment_update_entity.keywords
|
||||
@@ -1569,7 +1584,8 @@ class SegmentService:
|
||||
db.session.add(document)
|
||||
# update segment index task
|
||||
if segment_update_entity.enabled:
|
||||
VectorService.create_segments_vector([segment_update_entity.keywords], [segment], dataset)
|
||||
keywords = segment_update_entity.keywords or []
|
||||
VectorService.create_segments_vector([keywords], [segment], dataset)
|
||||
else:
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
tokens = 0
|
||||
@@ -1601,7 +1617,7 @@ class SegmentService:
|
||||
segment.disabled_by = None
|
||||
if document.doc_form == "qa_model":
|
||||
segment.answer = segment_update_entity.answer
|
||||
segment.word_count += len(segment_update_entity.answer)
|
||||
segment.word_count += len(segment_update_entity.answer or "")
|
||||
word_count_change = segment.word_count - word_count_change
|
||||
# update document word count
|
||||
if word_count_change != 0:
|
||||
@@ -1619,8 +1635,8 @@ class SegmentService:
|
||||
segment.status = "error"
|
||||
segment.error = str(e)
|
||||
db.session.commit()
|
||||
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first()
|
||||
return segment
|
||||
new_segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first()
|
||||
return new_segment
|
||||
|
||||
@classmethod
|
||||
def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset):
|
||||
@@ -1680,6 +1696,8 @@ class DatasetCollectionBindingService:
|
||||
.order_by(DatasetCollectionBinding.created_at)
|
||||
.first()
|
||||
)
|
||||
if not dataset_collection_binding:
|
||||
raise ValueError("Dataset collection binding not found")
|
||||
|
||||
return dataset_collection_binding
|
||||
|
||||
|
||||
@@ -8,8 +8,8 @@ class EnterpriseRequest:
|
||||
secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY")
|
||||
|
||||
proxies = {
|
||||
"http": None,
|
||||
"https": None,
|
||||
"http": "",
|
||||
"https": "",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -4,7 +4,11 @@ from typing import Optional
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity
|
||||
from core.entities.model_entities import (
|
||||
ModelWithProviderEntity,
|
||||
ProviderModelWithStatusEntity,
|
||||
SimpleModelProviderEntity,
|
||||
)
|
||||
from core.entities.provider_entities import QuotaConfiguration
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
@@ -148,7 +152,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity):
|
||||
Model with provider entity.
|
||||
"""
|
||||
|
||||
provider: SimpleProviderEntityResponse
|
||||
provider: SimpleModelProviderEntity
|
||||
|
||||
def __init__(self, model: ModelWithProviderEntity) -> None:
|
||||
super().__init__(**model.model_dump())
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
import validators
|
||||
@@ -45,7 +45,10 @@ class ExternalDatasetService:
|
||||
|
||||
@staticmethod
|
||||
def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict) -> ExternalKnowledgeApis:
|
||||
ExternalDatasetService.check_endpoint_and_api_key(args.get("settings"))
|
||||
settings = args.get("settings")
|
||||
if settings is None:
|
||||
raise ValueError("settings is required")
|
||||
ExternalDatasetService.check_endpoint_and_api_key(settings)
|
||||
external_knowledge_api = ExternalKnowledgeApis(
|
||||
tenant_id=tenant_id,
|
||||
created_by=user_id,
|
||||
@@ -86,11 +89,16 @@ class ExternalDatasetService:
|
||||
|
||||
@staticmethod
|
||||
def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis:
|
||||
return ExternalKnowledgeApis.query.filter_by(id=external_knowledge_api_id).first()
|
||||
external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by(
|
||||
id=external_knowledge_api_id
|
||||
).first()
|
||||
if external_knowledge_api is None:
|
||||
raise ValueError("api template not found")
|
||||
return external_knowledge_api
|
||||
|
||||
@staticmethod
|
||||
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
|
||||
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
|
||||
external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by(
|
||||
id=external_knowledge_api_id, tenant_id=tenant_id
|
||||
).first()
|
||||
if external_knowledge_api is None:
|
||||
@@ -127,7 +135,7 @@ class ExternalDatasetService:
|
||||
|
||||
@staticmethod
|
||||
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
|
||||
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
|
||||
external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ExternalKnowledgeBindings.query.filter_by(
|
||||
dataset_id=dataset_id, tenant_id=tenant_id
|
||||
).first()
|
||||
if not external_knowledge_binding:
|
||||
@@ -163,8 +171,9 @@ class ExternalDatasetService:
|
||||
"follow_redirects": True,
|
||||
}
|
||||
|
||||
response = getattr(ssrf_proxy, settings.request_method)(data=json.dumps(settings.params), files=files, **kwargs)
|
||||
|
||||
response: httpx.Response = getattr(ssrf_proxy, settings.request_method)(
|
||||
data=json.dumps(settings.params), files=files, **kwargs
|
||||
)
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
@@ -265,15 +274,15 @@ class ExternalDatasetService:
|
||||
"knowledge_id": external_knowledge_binding.external_knowledge_id,
|
||||
}
|
||||
|
||||
external_knowledge_api_setting = {
|
||||
"url": f"{settings.get('endpoint')}/retrieval",
|
||||
"request_method": "post",
|
||||
"headers": headers,
|
||||
"params": request_params,
|
||||
}
|
||||
response = ExternalDatasetService.process_external_api(
|
||||
ExternalKnowledgeApiSetting(**external_knowledge_api_setting), None
|
||||
ExternalKnowledgeApiSetting(
|
||||
url=f"{settings.get('endpoint')}/retrieval",
|
||||
request_method="post",
|
||||
headers=headers,
|
||||
params=request_params,
|
||||
),
|
||||
None,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return response.json().get("records", [])
|
||||
return cast(list[Any], response.json().get("records", []))
|
||||
return []
|
||||
|
||||
@@ -3,7 +3,7 @@ import hashlib
|
||||
import uuid
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_login import current_user # type: ignore
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
@@ -61,14 +61,14 @@ class FileService:
|
||||
# end_user
|
||||
current_tenant_id = user.tenant_id
|
||||
|
||||
file_key = "upload_files/" + current_tenant_id + "/" + file_uuid + "." + extension
|
||||
file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." + extension
|
||||
|
||||
# save file to storage
|
||||
storage.save(file_key, content)
|
||||
|
||||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_tenant_id or "",
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
key=file_key,
|
||||
name=filename,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.models.document import Document
|
||||
@@ -24,7 +25,7 @@ class HitTestingService:
|
||||
dataset: Dataset,
|
||||
query: str,
|
||||
account: Account,
|
||||
retrieval_model: dict,
|
||||
retrieval_model: Any, # FIXME drop this any
|
||||
external_retrieval_model: dict,
|
||||
limit: int = 10,
|
||||
) -> dict:
|
||||
@@ -68,7 +69,7 @@ class HitTestingService:
|
||||
db.session.add(dataset_query)
|
||||
db.session.commit()
|
||||
|
||||
return cls.compact_retrieve_response(dataset, query, all_documents)
|
||||
return dict(cls.compact_retrieve_response(dataset, query, all_documents))
|
||||
|
||||
@classmethod
|
||||
def external_retrieve(
|
||||
@@ -102,13 +103,16 @@ class HitTestingService:
|
||||
db.session.add(dataset_query)
|
||||
db.session.commit()
|
||||
|
||||
return cls.compact_external_retrieve_response(dataset, query, all_documents)
|
||||
return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))
|
||||
|
||||
@classmethod
|
||||
def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]):
|
||||
records = []
|
||||
|
||||
for document in documents:
|
||||
if document.metadata is None:
|
||||
continue
|
||||
|
||||
index_node_id = document.metadata["doc_id"]
|
||||
|
||||
segment = (
|
||||
@@ -140,7 +144,7 @@ class HitTestingService:
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list):
|
||||
def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> dict[Any, Any]:
|
||||
records = []
|
||||
if dataset.provider == "external":
|
||||
for document in documents:
|
||||
@@ -152,11 +156,10 @@ class HitTestingService:
|
||||
}
|
||||
records.append(record)
|
||||
return {
|
||||
"query": {
|
||||
"content": query,
|
||||
},
|
||||
"query": {"content": query},
|
||||
"records": records,
|
||||
}
|
||||
return {"query": {"content": query}, "records": []}
|
||||
|
||||
@classmethod
|
||||
def hit_testing_args_check(cls, args):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import boto3
|
||||
import boto3 # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
@@ -157,7 +157,7 @@ class MessageService:
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
rating: Optional[str],
|
||||
content: Optional[str],
|
||||
) -> MessageFeedback:
|
||||
):
|
||||
if not user:
|
||||
raise ValueError("user cannot be None")
|
||||
|
||||
@@ -264,6 +264,8 @@ class MessageService:
|
||||
)
|
||||
|
||||
app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
|
||||
if not app_model_config:
|
||||
raise ValueError("did not find app model config")
|
||||
|
||||
suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict
|
||||
if suggested_questions_after_answer.get("enabled", False) is False:
|
||||
@@ -285,7 +287,7 @@ class MessageService:
|
||||
)
|
||||
|
||||
with measure_time() as timer:
|
||||
questions = LLMGenerator.generate_suggested_questions_after_answer(
|
||||
questions: list[Message] = LLMGenerator.generate_suggested_questions_after_answer(
|
||||
tenant_id=app_model.tenant_id, histories=histories
|
||||
)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import datetime
|
||||
import json
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.entities.provider_configuration import ProviderConfiguration
|
||||
@@ -88,11 +88,11 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")

# Convert model type to ModelType
model_type = ModelType.value_of(model_type)
model_type_enum = ModelType.value_of(model_type)

# Get provider model setting
provider_model_setting = provider_configuration.get_provider_model_setting(
model_type=model_type,
model_type=model_type_enum,
model=model,
)

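Throughout this file the converted enum is bound to a new name, model_type_enum, instead of overwriting the str parameter, because mypy rejects re-assigning a variable with a value of a different type. A minimal sketch (the ModelType stub here is illustrative, not the project's real enum):

    from enum import Enum


    class ModelType(Enum):
        LLM = "llm"

        @classmethod
        def value_of(cls, value: str) -> "ModelType":
            return cls(value)


    def describe(model_type: str) -> str:
        # Overwriting the str parameter with a ModelType would change its
        # type, which mypy flags; binding the enum to a new name keeps both
        # values available with stable types.
        model_type_enum = ModelType.value_of(model_type)
        return f"{model_type} -> {model_type_enum.name}"


    print(describe("llm"))  # llm -> LLM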
@@ -106,7 +106,7 @@ class ModelLoadBalancingService:
|
||||
.filter(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
)
|
||||
.order_by(LoadBalancingModelConfig.created_at)
|
||||
@@ -124,7 +124,7 @@ class ModelLoadBalancingService:
|
||||
|
||||
if not inherit_config_exists:
|
||||
# Initialize the inherit configuration
|
||||
inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type)
|
||||
inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type_enum)
|
||||
|
||||
# prepend the inherit configuration
|
||||
load_balancing_configs.insert(0, inherit_config)
|
||||
@@ -148,7 +148,7 @@ class ModelLoadBalancingService:
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=model,
|
||||
model_type=model_type,
|
||||
model_type=model_type_enum,
|
||||
config_id=load_balancing_config.id,
|
||||
)
|
||||
|
||||
@@ -214,7 +214,7 @@ class ModelLoadBalancingService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Convert model type to ModelType
|
||||
model_type = ModelType.value_of(model_type)
|
||||
model_type_enum = ModelType.value_of(model_type)
|
||||
|
||||
# Get load balancing configurations
|
||||
load_balancing_model_config = (
|
||||
@@ -222,7 +222,7 @@ class ModelLoadBalancingService:
|
||||
.filter(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
LoadBalancingModelConfig.id == config_id,
|
||||
)
|
||||
@@ -300,7 +300,7 @@ class ModelLoadBalancingService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Convert model type to ModelType
|
||||
model_type = ModelType.value_of(model_type)
|
||||
model_type_enum = ModelType.value_of(model_type)
|
||||
|
||||
if not isinstance(configs, list):
|
||||
raise ValueError("Invalid load balancing configs")
|
||||
@@ -310,7 +310,7 @@ class ModelLoadBalancingService:
|
||||
.filter(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
)
|
||||
.all()
|
||||
@@ -359,7 +359,7 @@ class ModelLoadBalancingService:
|
||||
credentials = self._custom_credentials_validate(
|
||||
tenant_id=tenant_id,
|
||||
provider_configuration=provider_configuration,
|
||||
model_type=model_type,
|
||||
model_type=model_type_enum,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
load_balancing_model_config=load_balancing_config,
|
||||
@@ -395,7 +395,7 @@ class ModelLoadBalancingService:
|
||||
credentials = self._custom_credentials_validate(
|
||||
tenant_id=tenant_id,
|
||||
provider_configuration=provider_configuration,
|
||||
model_type=model_type,
|
||||
model_type=model_type_enum,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
validate=False,
|
||||
@@ -405,7 +405,7 @@ class ModelLoadBalancingService:
|
||||
load_balancing_model_config = LoadBalancingModelConfig(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type_enum.to_origin_model_type(),
|
||||
model_name=model,
|
||||
name=name,
|
||||
encrypted_config=json.dumps(credentials),
|
||||
@@ -450,7 +450,7 @@ class ModelLoadBalancingService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Convert model type to ModelType
|
||||
model_type = ModelType.value_of(model_type)
|
||||
model_type_enum = ModelType.value_of(model_type)
|
||||
|
||||
load_balancing_model_config = None
|
||||
if config_id:
|
||||
@@ -460,7 +460,7 @@ class ModelLoadBalancingService:
|
||||
.filter(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
LoadBalancingModelConfig.id == config_id,
|
||||
)
|
||||
@@ -474,7 +474,7 @@ class ModelLoadBalancingService:
|
||||
self._custom_credentials_validate(
|
||||
tenant_id=tenant_id,
|
||||
provider_configuration=provider_configuration,
|
||||
model_type=model_type,
|
||||
model_type=model_type_enum,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
load_balancing_model_config=load_balancing_model_config,
|
||||
@@ -547,19 +547,14 @@ class ModelLoadBalancingService:
|
||||
|
||||
def _get_credential_schema(
|
||||
self, provider_configuration: ProviderConfiguration
|
||||
) -> ModelCredentialSchema | ProviderCredentialSchema:
|
||||
"""
|
||||
Get form schemas.
|
||||
:param provider_configuration: provider configuration
|
||||
:return:
|
||||
"""
|
||||
# Get credential form schemas from model credential schema or provider credential schema
|
||||
) -> Union[ModelCredentialSchema, ProviderCredentialSchema]:
|
||||
"""Get form schemas."""
|
||||
if provider_configuration.provider.model_credential_schema:
|
||||
credential_schema = provider_configuration.provider.model_credential_schema
|
||||
return provider_configuration.provider.model_credential_schema
|
||||
elif provider_configuration.provider.provider_credential_schema:
|
||||
return provider_configuration.provider.provider_credential_schema
|
||||
else:
|
||||
credential_schema = provider_configuration.provider.provider_credential_schema
|
||||
|
||||
return credential_schema
|
||||
raise ValueError("No credential schema found")
|
||||
|
||||
def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None:
|
||||
"""
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Optional, cast
|
||||
import requests
|
||||
from flask import current_app
|
||||
|
||||
from core.entities.model_entities import ModelStatus, ProviderModelWithStatusEntity
|
||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity
|
||||
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
@@ -100,23 +100,15 @@ class ModelProviderService:
|
||||
ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider)
|
||||
]
|
||||
|
||||
def get_provider_credentials(self, tenant_id: str, provider: str) -> dict:
|
||||
def get_provider_credentials(self, tenant_id: str, provider: str):
|
||||
"""
|
||||
get provider credentials.
|
||||
|
||||
:param tenant_id:
|
||||
:param provider:
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Get provider custom credentials from workspace
|
||||
return provider_configuration.get_custom_credentials(obfuscated=True)
|
||||
|
||||
def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None:
|
||||
@@ -176,7 +168,7 @@ class ModelProviderService:
|
||||
# Remove custom provider credentials.
|
||||
provider_configuration.delete_custom_credentials()
|
||||
|
||||
def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> dict:
|
||||
def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str):
|
||||
"""
|
||||
get model credentials.
|
||||
|
||||
@@ -287,7 +279,7 @@ class ModelProviderService:
|
||||
models = provider_configurations.get_models(model_type=ModelType.value_of(model_type))
|
||||
|
||||
# Group models by provider
|
||||
provider_models = {}
|
||||
provider_models: dict[str, list[ModelWithProviderEntity]] = {}
|
||||
for model in models:
|
||||
if model.provider.provider not in provider_models:
|
||||
provider_models[model.provider.provider] = []
|
||||
@@ -362,7 +354,7 @@ class ModelProviderService:
|
||||
return []
|
||||
|
||||
# Call get_parameter_rules method of model instance to get model parameter rules
|
||||
return model_type_instance.get_parameter_rules(model=model, credentials=credentials)
|
||||
return list(model_type_instance.get_parameter_rules(model=model, credentials=credentials))
|
||||
|
||||
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]:
|
||||
"""
|
||||
@@ -422,6 +414,7 @@ class ModelProviderService:
|
||||
"""
|
||||
provider_instance = model_provider_factory.get_provider_instance(provider)
|
||||
provider_schema = provider_instance.get_provider_schema()
|
||||
file_name: str | None = None
|
||||
|
||||
if icon_type.lower() == "icon_small":
|
||||
if not provider_schema.icon_small:
|
||||
@@ -439,6 +432,8 @@ class ModelProviderService:
|
||||
file_name = provider_schema.icon_large.zh_Hans
|
||||
else:
|
||||
file_name = provider_schema.icon_large.en_US
|
||||
if not file_name:
|
||||
return None, None
|
||||
|
||||
root_path = current_app.root_path
|
||||
provider_instance_path = os.path.dirname(
|
||||
@@ -524,7 +519,7 @@ class ModelProviderService:
|
||||
|
||||
def free_quota_submit(self, tenant_id: str, provider: str):
|
||||
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
|
||||
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
|
||||
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "")
|
||||
api_url = api_base_url + "/api/v1/providers/apply"
|
||||
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||
@@ -545,7 +540,7 @@ class ModelProviderService:
|
||||
|
||||
def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]):
|
||||
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
|
||||
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
|
||||
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "")
|
||||
api_url = api_base_url + "/api/v1/providers/qualification-verify"
|
||||
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.moderation.factory import ModerationFactory, ModerationOutputsResult
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppModelConfig
|
||||
@@ -5,7 +7,7 @@ from models.model import App, AppModelConfig
|
||||
|
||||
class ModerationService:
|
||||
def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult:
|
||||
app_model_config: AppModelConfig = None
|
||||
app_model_config: Optional[AppModelConfig] = None
|
||||
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, TraceAppConfig
|
||||
@@ -12,7 +14,7 @@ class OpsService:
|
||||
:param tracing_provider: tracing provider
|
||||
:return:
|
||||
"""
|
||||
trace_config_data: TraceAppConfig = (
|
||||
trace_config_data: Optional[TraceAppConfig] = (
|
||||
db.session.query(TraceAppConfig)
|
||||
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
||||
.first()
|
||||
@@ -22,7 +24,10 @@ class OpsService:
|
||||
return None
|
||||
|
||||
# decrypt_token and obfuscated_token
|
||||
tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
|
||||
tenant = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not tenant:
|
||||
return None
|
||||
tenant_id = tenant.tenant_id
|
||||
decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(
|
||||
tenant_id, tracing_provider, trace_config_data.tracing_config
|
||||
)
|
||||
@@ -73,8 +78,9 @@ class OpsService:
|
||||
provider_config_map[tracing_provider]["config_class"],
|
||||
provider_config_map[tracing_provider]["other_keys"],
|
||||
)
|
||||
default_config_instance = config_class(**tracing_config)
|
||||
for key in other_keys:
|
||||
# FIXME: ignore type error
|
||||
default_config_instance = config_class(**tracing_config) # type: ignore
|
||||
for key in other_keys: # type: ignore
|
||||
if key in tracing_config and tracing_config[key] == "":
|
||||
tracing_config[key] = getattr(default_config_instance, key, None)
|
||||
|
||||
@@ -92,7 +98,7 @@ class OpsService:
|
||||
project_url = None
|
||||
|
||||
# check if trace config already exists
|
||||
trace_config_data: TraceAppConfig = (
|
||||
trace_config_data: Optional[TraceAppConfig] = (
|
||||
db.session.query(TraceAppConfig)
|
||||
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
||||
.first()
|
||||
@@ -102,7 +108,10 @@ class OpsService:
return None

# get tenant id
tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
tenant = db.session.query(App).filter(App.id == app_id).first()
if not tenant:
return None
tenant_id = tenant.tenant_id
tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config)
if project_url:
tracing_config["project_url"] = project_url
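Chaining an attribute access directly onto .first() is rejected by mypy because the query may return None, and it can raise AttributeError at runtime for the same reason; the hunk above checks the result first. A self-contained sketch of the pattern with a plain list standing in for the query (names are illustrative):

    from typing import Optional


    class App:
        def __init__(self, app_id: str, tenant_id: str) -> None:
            self.id = app_id
            self.tenant_id = tenant_id


    # Stand-in for db.session.query(App)...first(), which may return None.
    def first_app(apps: list[App], app_id: str) -> Optional[App]:
        return next((app for app in apps if app.id == app_id), None)


    def get_tenant_id(apps: list[App], app_id: str) -> Optional[str]:
        # Check the possibly-None result before touching its attributes.
        app = first_app(apps, app_id)
        if not app:
            return None
        return app.tenant_id


    print(get_tenant_id([App("a1", "t1")], "a1"))  # t1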
@@ -139,7 +148,10 @@ class OpsService:
|
||||
return None
|
||||
|
||||
# get tenant id
|
||||
tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
|
||||
tenant = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not tenant:
|
||||
return None
|
||||
tenant_id = tenant.tenant_id
|
||||
tracing_config = OpsTraceManager.encrypt_tracing_config(
|
||||
tenant_id, tracing_provider, tracing_config, current_trace_config.tracing_config
|
||||
)
|
||||
|
||||
@@ -41,7 +41,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
Path(path.join(root_path, "constants", "recommended_apps.json")).read_text(encoding="utf-8")
|
||||
)
|
||||
|
||||
return cls.builtin_data
|
||||
return cls.builtin_data or {}
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_apps_from_builtin(cls, language: str) -> dict:
|
||||
@@ -50,8 +50,8 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
:param language: language
|
||||
:return:
|
||||
"""
|
||||
builtin_data = cls._get_builtin_data()
|
||||
return builtin_data.get("recommended_apps", {}).get(language)
|
||||
builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
|
||||
return builtin_data.get("recommended_apps", {}).get(language, {})
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]:
|
||||
@@ -60,5 +60,5 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
:param app_id: App ID
|
||||
:return:
|
||||
"""
|
||||
builtin_data = cls._get_builtin_data()
|
||||
builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
|
||||
return builtin_data.get("app_details", {}).get(app_id)
|
||||
|
||||
@@ -47,8 +47,8 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
response = requests.get(url, timeout=(3, 10))
|
||||
if response.status_code != 200:
|
||||
return None
|
||||
|
||||
return response.json()
|
||||
data: dict = response.json()
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:
|
||||
@@ -63,7 +63,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}")
|
||||
|
||||
result = response.json()
|
||||
result: dict = response.json()
|
||||
|
||||
if "categories" in result:
|
||||
result["categories"] = sorted(result["categories"])
|
||||
|
||||
@@ -33,5 +33,5 @@ class RecommendedAppService:
|
||||
"""
|
||||
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
|
||||
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
|
||||
result = retrieval_instance.get_recommend_app_detail(app_id)
|
||||
result: dict = retrieval_instance.get_recommend_app_detail(app_id)
|
||||
return result
|
||||
|
||||
@@ -13,6 +13,8 @@ class SavedMessageService:
|
||||
def pagination_by_last_id(
|
||||
cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int
|
||||
) -> InfiniteScrollPagination:
|
||||
if not user:
|
||||
raise ValueError("User is required")
|
||||
saved_messages = (
|
||||
db.session.query(SavedMessage)
|
||||
.filter(
|
||||
@@ -31,6 +33,8 @@ class SavedMessageService:
|
||||
|
||||
@classmethod
|
||||
def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
|
||||
if not user:
|
||||
return
|
||||
saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
.filter(
|
||||
@@ -59,6 +63,8 @@ class SavedMessageService:
|
||||
|
||||
@classmethod
|
||||
def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
|
||||
if not user:
|
||||
return
|
||||
saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
.filter(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_login import current_user # type: ignore
|
||||
from sqlalchemy import func
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
@@ -21,7 +21,7 @@ class TagService:
|
||||
if keyword:
|
||||
query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%")))
|
||||
query = query.group_by(Tag.id)
|
||||
results = query.order_by(Tag.created_at.desc()).all()
|
||||
results: list = query.order_by(Tag.created_at.desc()).all()
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from httpx import get
|
||||
|
||||
@@ -28,12 +29,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ApiToolManageService:
|
||||
@staticmethod
|
||||
def parser_api_schema(schema: str) -> list[ApiToolBundle]:
|
||||
def parser_api_schema(schema: str) -> Mapping[str, Any]:
|
||||
"""
|
||||
parse api schema to tool bundle
|
||||
"""
|
||||
try:
|
||||
warnings = {}
|
||||
warnings: dict[str, str] = {}
|
||||
try:
|
||||
tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
|
||||
except Exception as e:
|
||||
@@ -68,13 +69,16 @@ class ApiToolManageService:
|
||||
),
|
||||
]
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"schema_type": schema_type,
|
||||
"parameters_schema": tool_bundles,
|
||||
"credentials_schema": credentials_schema,
|
||||
"warning": warnings,
|
||||
}
|
||||
return cast(
|
||||
Mapping,
|
||||
jsonable_encoder(
|
||||
{
|
||||
"schema_type": schema_type,
|
||||
"parameters_schema": tool_bundles,
|
||||
"credentials_schema": credentials_schema,
|
||||
"warning": warnings,
|
||||
}
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
@@ -129,7 +133,7 @@ class ApiToolManageService:
raise ValueError(f"provider {provider_name} already exists")

# parse openapi to tool bundle
extra_info = {}
extra_info: dict[str, str] = {}
# extra info like description will be set here
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)

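The annotation added to the empty dict is the usual fix for mypy's "need type annotation" error on empty container literals, which give the checker nothing to infer an element type from. A tiny sketch (the variable names echo the diff; the values are illustrative):

    # Annotate empty literals up front so mypy knows their element types.
    extra_info: dict[str, str] = {}
    warnings: list[str] = []

    extra_info["description"] = "set later by the parser"
    warnings.append("example warning")
    print(extra_info, warnings)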
@@ -262,9 +266,8 @@ class ApiToolManageService:
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f"api provider {provider_name} does not exists")
|
||||
|
||||
# parse openapi to tool bundle
|
||||
extra_info = {}
|
||||
extra_info: dict[str, str] = {}
|
||||
# extra info like description will be set here
|
||||
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
|
||||
|
||||
@@ -416,7 +419,7 @@ class ApiToolManageService:
|
||||
provider_controller.validate_credentials_format(credentials)
|
||||
# get tool
|
||||
tool = provider_controller.get_tool(tool_name)
|
||||
tool = tool.fork_tool_runtime(
|
||||
runtime_tool = tool.fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
"tenant_id": tenant_id,
|
||||
@@ -454,7 +457,7 @@ class ApiToolManageService:
|
||||
|
||||
tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
|
||||
|
||||
for tool in tools:
|
||||
for tool in tools or []:
|
||||
user_provider.tools.append(
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
|
||||
|
||||
@@ -50,8 +50,8 @@ class BuiltinToolManageService:
|
||||
credentials = builtin_provider.credentials
|
||||
credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)
|
||||
|
||||
result = []
|
||||
for tool in tools:
|
||||
result: list[UserTool] = []
|
||||
for tool in tools or []:
|
||||
result.append(
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tool=tool,
|
||||
@@ -217,6 +217,8 @@ class BuiltinToolManageService:
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
continue
|
||||
if provider_controller.identity is None:
|
||||
continue
|
||||
|
||||
# convert provider controller to user provider
|
||||
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
@@ -229,7 +231,7 @@ class BuiltinToolManageService:
|
||||
ToolTransformService.repack_provider(user_builtin_provider)
|
||||
|
||||
tools = provider_controller.get_tools()
|
||||
for tool in tools:
|
||||
for tool in tools or []:
|
||||
user_builtin_provider.tools.append(
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@@ -1,6 +1,6 @@
import json
import logging
from typing import Optional, Union
from typing import Optional, Union, cast

from configs import dify_config
from core.tools.entities.api_entities import UserTool, UserToolProvider
@@ -35,7 +35,7 @@ class ToolTransformService:
return url_prefix + "builtin/" + provider_name + "/icon"
elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
try:
return json.loads(icon)
return cast(dict, json.loads(icon))
except:
return {"background": "#252525", "content": "\ud83d\ude01"}

@@ -53,8 +53,11 @@ class ToolTransformService:
provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
)
elif isinstance(provider, UserToolProvider):
provider.icon = ToolTransformService.get_tool_provider_icon_url(
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
provider.icon = cast(
str,
ToolTransformService.get_tool_provider_icon_url(
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
),
)

@staticmethod
@@ -66,6 +69,9 @@ class ToolTransformService:
"""
convert provider controller to user provider
"""
if provider_controller.identity is None:
raise ValueError("provider identity is None")

result = UserToolProvider(
id=provider_controller.identity.name,
author=provider_controller.identity.author,
@@ -93,7 +99,8 @@ class ToolTransformService:
# get credentials schema
schema = provider_controller.get_credentials_schema()
for name, value in schema.items():
result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type)
assert result.masked_credentials is not None, "masked credentials is None"
result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(str(value.type))

# check if the provider need credentials
if not provider_controller.need_credentials:
@@ -149,6 +156,9 @@ class ToolTransformService:
"""
convert provider controller to user provider
"""
if provider_controller.identity is None:
raise ValueError("provider identity is None")

return UserToolProvider(
id=provider_controller.provider_id,
author=provider_controller.identity.author,
@@ -180,6 +190,8 @@ class ToolTransformService:
convert provider controller to user provider
"""
username = "Anonymous"
if db_provider.user is None:
raise ValueError(f"user is None for api provider {db_provider.id}")
try:
username = db_provider.user.name
except Exception as e:
@@ -256,19 +268,25 @@ class ToolTransformService:
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)

if tool.identity is None:
raise ValueError("tool identity is None")

return UserTool(
author=tool.identity.author,
name=tool.identity.name,
label=tool.identity.label,
description=tool.description.human,
description=I18nObject(
en_US=tool.description.human if tool.description else "",
zh_Hans=tool.description.human if tool.description else "",
),
parameters=current_parameters,
labels=labels,
)
if isinstance(tool, ApiToolBundle):
return UserTool(
author=tool.author,
name=tool.operation_id,
label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
name=tool.operation_id or "",
label=I18nObject(en_US=tool.operation_id or "", zh_Hans=tool.operation_id or ""),
description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
parameters=tool.parameters,
labels=labels,

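`typing.cast` shows up in several hunks above because helpers such as `json.loads` are typed as returning `Any`; the cast records the expected shape for mypy while being a no-op at runtime, so malformed input still has to be handled separately. An illustrative sketch:

import json
from typing import cast

def parse_icon(icon: str) -> dict:
    try:
        # json.loads returns Any; cast() tells the type checker a dict is
        # expected here but performs no runtime validation.
        return cast(dict, json.loads(icon))
    except (ValueError, TypeError):
        # Fallback icon, mirroring the default used above.
        return {"background": "#252525", "content": "\U0001F601"}

print(parse_icon('{"background": "#fff", "content": "A"}'))
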
@@ -6,8 +6,10 @@ from typing import Any, Optional
from sqlalchemy import or_

from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserToolProvider
from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
from core.tools.tool.tool import Tool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from extensions.ext_database import db
@@ -32,7 +34,7 @@ class WorkflowToolManageService:
label: str,
icon: dict,
description: str,
parameters: Mapping[str, Any],
parameters: list[Mapping[str, Any]],
privacy_policy: str = "",
labels: Optional[list[str]] = None,
) -> dict:
@@ -97,7 +99,7 @@ class WorkflowToolManageService:
label: str,
icon: dict,
description: str,
parameters: list[dict],
parameters: list[Mapping[str, Any]],
privacy_policy: str = "",
labels: Optional[list[str]] = None,
) -> dict:
@@ -131,7 +133,7 @@ class WorkflowToolManageService:
if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} already exists")

workflow_tool_provider: WorkflowToolProvider = (
workflow_tool_provider: Optional[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
@@ -140,14 +142,14 @@ class WorkflowToolManageService:
if workflow_tool_provider is None:
raise ValueError(f"Tool {workflow_tool_id} not found")

app: App = (
app: Optional[App] = (
db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
)

if app is None:
raise ValueError(f"App {workflow_tool_provider.app_id} not found")

workflow: Workflow = app.workflow
workflow: Optional[Workflow] = app.workflow
if workflow is None:
raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")

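SQLAlchemy's `Query.first()` returns `None` when nothing matches, which is why the annotations above move from `WorkflowToolProvider` to `Optional[WorkflowToolProvider]` (and likewise for `App` and `Workflow`) and an explicit `None` check raises before the object is used. A self-contained sketch of the pattern, using a stand-in class instead of the ORM model:

from typing import Optional

class WorkflowToolProviderStub:
    # Hypothetical stand-in so the example runs without a database.
    def __init__(self, name: str) -> None:
        self.name = name

def first_or_none(rows: list[WorkflowToolProviderStub]) -> Optional[WorkflowToolProviderStub]:
    # Mirrors Query.first(): may return None.
    return rows[0] if rows else None

provider = first_or_none([WorkflowToolProviderStub("demo-tool")])
if provider is None:
    raise ValueError("Tool not found")
# After the check, mypy narrows `provider` to WorkflowToolProviderStub.
print(provider.name)
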
@@ -193,7 +195,7 @@ class WorkflowToolManageService:
# skip deleted tools
pass

labels = ToolLabelManager.get_tools_labels(tools)
labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)])

result = []

@@ -202,10 +204,11 @@ class WorkflowToolManageService:
provider_controller=tool, labels=labels.get(tool.provider_id, [])
)
ToolTransformService.repack_provider(user_tool_provider)
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
if to_user_tool is None or len(to_user_tool) == 0:
continue
user_tool_provider.tools = [
ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, [])
)
ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=labels.get(tool.provider_id, []))
]
result.append(user_tool_provider)

@@ -236,7 +239,7 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the tool
"""
db_tool: WorkflowToolProvider = (
db_tool: Optional[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
@@ -245,13 +248,19 @@ class WorkflowToolManageService:
if db_tool is None:
raise ValueError(f"Tool {workflow_tool_id} not found")

workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
workflow_app: Optional[App] = (
db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
)

if workflow_app is None:
raise ValueError(f"App {db_tool.app_id} not found")

tool = ToolTransformService.workflow_provider_to_controller(db_tool)

to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
if to_user_tool is None or len(to_user_tool) == 0:
raise ValueError(f"Tool {workflow_tool_id} not found")

return {
"name": db_tool.name,
"label": db_tool.label,
@@ -261,9 +270,9 @@ class WorkflowToolManageService:
"description": db_tool.description,
"parameters": jsonable_encoder(db_tool.parameter_configurations),
"tool": ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool)
),
"synced": workflow_app.workflow.version == db_tool.version,
"synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False,
"privacy_policy": db_tool.privacy_policy,
}

@@ -276,7 +285,7 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the tool
"""
db_tool: WorkflowToolProvider = (
db_tool: Optional[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.first()
@@ -285,12 +294,17 @@ class WorkflowToolManageService:
if db_tool is None:
raise ValueError(f"Tool {workflow_app_id} not found")

workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
workflow_app: Optional[App] = (
db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
)

if workflow_app is None:
raise ValueError(f"App {db_tool.app_id} not found")

tool = ToolTransformService.workflow_provider_to_controller(db_tool)
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
if to_user_tool is None or len(to_user_tool) == 0:
raise ValueError(f"Tool {workflow_app_id} not found")

return {
"name": db_tool.name,
@@ -301,14 +315,14 @@ class WorkflowToolManageService:
"description": db_tool.description,
"parameters": jsonable_encoder(db_tool.parameter_configurations),
"tool": ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool)
),
"synced": workflow_app.workflow.version == db_tool.version,
"synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False,
"privacy_policy": db_tool.privacy_policy,
}

@classmethod
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]:
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[UserTool]:
"""
List workflow tool provider tools.
:param user_id: the user id
@@ -316,7 +330,7 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the list of tools
"""
db_tool: WorkflowToolProvider = (
db_tool: Optional[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
@@ -326,9 +340,8 @@ class WorkflowToolManageService:
raise ValueError(f"Tool {workflow_tool_id} not found")

tool = ToolTransformService.workflow_provider_to_controller(db_tool)
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
if to_user_tool is None or len(to_user_tool) == 0:
raise ValueError(f"Tool {workflow_tool_id} not found")

return [
ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
)
]
return [ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool))]

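The repeated `to_user_tool` changes above call `get_tools(...)` once, check the result for `None` or emptiness, and only then index element 0, instead of fetching twice and indexing blindly. A compact sketch of that shape (illustrative names):

from typing import Optional

def get_tools(user_id: str, tenant_id: str) -> Optional[list[str]]:
    # Stand-in: may return None or an empty list.
    return ["workflow-tool"]

to_user_tool: Optional[list[str]] = get_tools("user-1", "tenant-1")
if to_user_tool is None or len(to_user_tool) == 0:
    raise ValueError("Tool not found")

# Safe: the list is known to be non-empty and is reused rather than fetched again.
print(to_user_tool[0])
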
@@ -26,6 +26,8 @@ class WebConversationService:
pinned: Optional[bool] = None,
sort_by="-updated_at",
) -> InfiniteScrollPagination:
if not user:
raise ValueError("User is required")
include_ids = None
exclude_ids = None
if pinned is not None and user:
@@ -59,6 +61,8 @@ class WebConversationService:

@classmethod
def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
if not user:
return
pinned_conversation = (
db.session.query(PinnedConversation)
.filter(
@@ -89,6 +93,8 @@ class WebConversationService:

@classmethod
def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
if not user:
return
pinned_conversation = (
db.session.query(PinnedConversation)
.filter(

@@ -1,8 +1,9 @@
import datetime
import json
from typing import Any

import requests
from flask_login import current_user
from flask_login import current_user  # type: ignore

from core.helper import encrypter
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
@@ -23,9 +24,9 @@ class WebsiteService:

@classmethod
def crawl_url(cls, args: dict) -> dict:
provider = args.get("provider")
provider = args.get("provider", "")
url = args.get("url")
options = args.get("options")
options = args.get("options", "")
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
if provider == "firecrawl":
# decrypt api_key
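`flask_login` ships without type stubs, so its import gains a `# type: ignore` comment here (and again in the workspace service below); otherwise mypy reports the module as untyped. Some alternatives, noted as assumptions about typical mypy configuration rather than what this commit does:

from flask_login import current_user  # type: ignore

# Recent mypy releases also accept a narrower error code:
# from flask_login import current_user  # type: ignore[import-untyped]

# Or the ignore can live in mypy configuration instead of the source, e.g. in
# pyproject.toml:
# [[tool.mypy.overrides]]
# module = "flask_login.*"
# ignore_missing_imports = true
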
@@ -164,16 +165,18 @@ class WebsiteService:
return crawl_status_data

@classmethod
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None:
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[Any, Any] | None:
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
# decrypt api_key
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
# FIXME data is redefine too many times here, use Any to ease the type checking, fix it later
data: Any
if provider == "firecrawl":
file_key = "website_files/" + job_id + ".txt"
if storage.exists(file_key):
data = storage.load_once(file_key)
if data:
data = json.loads(data.decode("utf-8"))
d = storage.load_once(file_key)
if d:
data = json.loads(d.decode("utf-8"))
else:
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
result = firecrawl_app.check_crawl_status(job_id)
@@ -183,22 +186,17 @@ class WebsiteService:
if data:
for item in data:
if item.get("source_url") == url:
return item
return dict(item)
return None
elif provider == "jinareader":
file_key = "website_files/" + job_id + ".txt"
if storage.exists(file_key):
data = storage.load_once(file_key)
if data:
data = json.loads(data.decode("utf-8"))
elif not job_id:
if not job_id:
response = requests.get(
f"https://r.jina.ai/{url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
)
if response.json().get("code") != 200:
raise ValueError("Failed to crawl")
return response.json().get("data")
return dict(response.json().get("data", {}))
else:
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
response = requests.post(
@@ -218,12 +216,13 @@ class WebsiteService:
data = response.json().get("data", {})
for item in data.get("processed", {}).values():
if item.get("data", {}).get("url") == url:
return item.get("data", {})
return dict(item.get("data", {}))
return None
else:
raise ValueError("Invalid provider")

@classmethod
def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None:
def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict:
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
if provider == "firecrawl":
# decrypt api_key

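In `get_crawl_url_data` above, the raw bytes from storage now go into their own variable (`d`) and `data` is pre-declared as `Any`, with the commit's own FIXME marking this as a temporary escape hatch for a value that is reused with several different shapes. A small sketch of the same shape:

import json
from typing import Any

def load_cached(blob: bytes | None) -> Any:
    # `data` may end up holding several shapes, so it is declared Any once
    # instead of being redefined from bytes to a parsed JSON structure.
    data: Any = None
    d = blob  # keep the raw bytes under their own name
    if d:
        data = json.loads(d.decode("utf-8"))
    return data

print(load_cached(b'{"source_url": "https://example.com"}'))
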
@@ -1,5 +1,5 @@
import json
from typing import Optional
from typing import Any, Optional

from core.app.app_config.entities import (
DatasetEntity,
@@ -101,7 +101,7 @@ class WorkflowConverter:
app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config)

# init workflow graph
graph = {"nodes": [], "edges": []}
graph: dict[str, Any] = {"nodes": [], "edges": []}

# Convert list:
# - variables -> start
@@ -118,7 +118,7 @@ class WorkflowConverter:
graph["nodes"].append(start_node)

# convert to http request node
external_data_variable_node_mapping = {}
external_data_variable_node_mapping: dict[str, str] = {}
if app_config.external_data_variables:
http_request_nodes, external_data_variable_node_mapping = self._convert_to_http_request_node(
app_model=app_model,
@@ -199,15 +199,16 @@ class WorkflowConverter:
return workflow

def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig:
app_mode = AppMode.value_of(app_model.mode)
if app_mode == AppMode.AGENT_CHAT or app_model.is_agent:
app_mode_enum = AppMode.value_of(app_model.mode)
app_config: EasyUIBasedAppConfig
if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent:
app_model.mode = AppMode.AGENT_CHAT.value
app_config = AgentChatAppConfigManager.get_app_config(
app_model=app_model, app_model_config=app_model_config
)
elif app_mode == AppMode.CHAT:
elif app_mode_enum == AppMode.CHAT:
app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config)
elif app_mode == AppMode.COMPLETION:
elif app_mode_enum == AppMode.COMPLETION:
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model, app_model_config=app_model_config
)
@@ -302,7 +303,7 @@ class WorkflowConverter:
nodes.append(http_request_node)

# append code node for response body parsing
code_node = {
code_node: dict[str, Any] = {
"id": f"code_{index}",
"position": None,
"data": {
@@ -401,6 +402,7 @@ class WorkflowConverter:
)

role_prefix = None
prompts: Any = None

# Chat Model
if model_config.mode == LLMMode.CHAT.value:

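In `_convert_to_app_config`, the variable is now declared once as the base type (`app_config: EasyUIBasedAppConfig`) before the if/elif chain assigns one of several concrete config classes; without the declaration mypy typically binds the variable to the type of the first assignment and flags the sibling branches. A generic sketch with stand-in classes:

class BaseConfig:
    pass

class ChatConfig(BaseConfig):
    pass

class CompletionConfig(BaseConfig):
    pass

def build_config(mode: str) -> BaseConfig:
    # Declaring the base type up front lets each branch assign a different
    # subclass without mypy inferring the first branch's narrower type.
    app_config: BaseConfig
    if mode == "chat":
        app_config = ChatConfig()
    elif mode == "completion":
        app_config = CompletionConfig()
    else:
        raise ValueError(f"unsupported mode: {mode}")
    return app_config

print(type(build_config("chat")).__name__)  # ChatConfig
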
@@ -1,3 +1,5 @@
from typing import Optional

from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
@@ -92,7 +94,7 @@ class WorkflowRunService:

return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)

def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun:
def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]:
"""
Get workflow run detail

@@ -2,7 +2,7 @@ import json
import time
from collections.abc import Sequence
from datetime import UTC, datetime
from typing import Optional, cast
from typing import Any, Optional, cast

from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
@@ -242,7 +242,7 @@ class WorkflowService:
raise ValueError("Node run failed with no run result")
# single step debug mode error handling return
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
node_error_args = {
node_error_args: dict[str, Any] = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": node_run_result.error,
"inputs": node_run_result.inputs,
@@ -338,7 +338,7 @@ class WorkflowService:
raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.")

# convert to workflow
new_app = workflow_converter.convert_to_workflow(
new_app: App = workflow_converter.convert_to_workflow(
app_model=app_model,
account=account,
name=args.get("name", "Default Name"),

@@ -1,4 +1,4 @@
from flask_login import current_user
from flask_login import current_user  # type: ignore

from configs import dify_config
from extensions.ext_database import db
@@ -29,6 +29,7 @@ class WorkspaceService:
.filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id)
.first()
)
assert tenant_account_join is not None, "TenantAccountJoin not found"
tenant_info["role"] = tenant_account_join.role

can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo

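The added `assert tenant_account_join is not None` is the narrowing idiom mypy expects before attribute access on a value that may be `None` (the same idea backs the `masked_credentials` assert in the tool transform service above). A short sketch:

from typing import Optional

def find_role(joins: dict[str, str], tenant_id: str) -> Optional[str]:
    # Stand-in for a .first() query that may return None.
    return joins.get(tenant_id)

tenant_account_join: Optional[str] = find_role({"tenant-1": "owner"}, "tenant-1")
assert tenant_account_join is not None, "TenantAccountJoin not found"
# After the assert, mypy treats the value as str, so attribute access is allowed.
print(tenant_account_join.upper())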