feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -6,7 +6,7 @@ import secrets
import uuid
from datetime import UTC, datetime, timedelta
from hashlib import sha256
from typing import Any, Optional
from typing import Any, Optional, cast
from pydantic import BaseModel
from sqlalchemy import func
@@ -119,7 +119,7 @@ class AccountService:
account.last_active_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
return account
return cast(Account, account)
@staticmethod
def get_account_jwt_token(account: Account) -> str:
@@ -132,7 +132,7 @@ class AccountService:
"sub": "Console API Passport",
}
token = PassportService().issue(payload)
token: str = PassportService().issue(payload)
return token
@staticmethod
@@ -164,7 +164,7 @@ class AccountService:
db.session.commit()
return account
return cast(Account, account)
@staticmethod
def update_account_password(account, password, new_password):
@@ -347,6 +347,8 @@ class AccountService:
language: Optional[str] = "en-US",
):
account_email = account.email if account else email
if account_email is None:
raise ValueError("Email must be provided.")
if cls.reset_password_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import PasswordResetRateLimitExceededError
@@ -377,6 +379,8 @@ class AccountService:
def send_email_code_login_email(
cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
):
if email is None:
raise ValueError("Email must be provided.")
if cls.email_code_login_rate_limiter.is_rate_limited(email):
from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError
@@ -669,7 +673,7 @@ class TenantService:
@staticmethod
def get_tenant_count() -> int:
"""Get tenant count"""
return db.session.query(func.count(Tenant.id)).scalar()
return cast(int, db.session.query(func.count(Tenant.id)).scalar())
@staticmethod
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None:
@@ -733,10 +737,10 @@ class TenantService:
db.session.commit()
@staticmethod
def get_custom_config(tenant_id: str) -> None:
tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).one_or_404()
def get_custom_config(tenant_id: str) -> dict:
tenant = Tenant.query.filter(Tenant.id == tenant_id).one_or_404()
return tenant.custom_config_dict
return cast(dict, tenant.custom_config_dict)
class RegisterService:
@@ -807,7 +811,7 @@ class RegisterService:
account.status = AccountStatus.ACTIVE.value if not status else status.value
account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
if open_id is not None or provider is not None:
if open_id is not None and provider is not None:
AccountService.link_account_integrate(provider, open_id, account)
if FeatureService.get_system_features().is_allow_create_workspace:
@@ -828,10 +832,11 @@ class RegisterService:
@classmethod
def invite_new_member(
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account = None
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Optional[Account] = None
) -> str:
"""Invite new member"""
account = Account.query.filter_by(email=email).first()
assert inviter is not None, "Inviter must be provided."
if not account:
TenantService.check_member_permission(tenant, inviter, None, "add")
@@ -894,7 +899,9 @@ class RegisterService:
redis_client.delete(cls._get_invitation_token_key(token))
@classmethod
def get_invitation_if_token_valid(cls, workspace_id: str, email: str, token: str) -> Optional[dict[str, Any]]:
def get_invitation_if_token_valid(
cls, workspace_id: Optional[str], email: str, token: str
) -> Optional[dict[str, Any]]:
invitation_data = cls._get_invitation_by_token(token, workspace_id, email)
if not invitation_data:
return None
@@ -953,7 +960,7 @@ class RegisterService:
if not data:
return None
invitation = json.loads(data)
invitation: dict = json.loads(data)
return invitation

View File

@@ -48,6 +48,8 @@ class AdvancedPromptTemplateService:
return cls.get_chat_prompt(
copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt
)
# default return empty dict
return {}
@classmethod
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
@@ -91,3 +93,5 @@ class AdvancedPromptTemplateService:
return cls.get_chat_prompt(
copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
)
# default return empty dict
return {}

View File

@@ -1,5 +1,7 @@
from typing import Optional
import pytz
from flask_login import current_user
from flask_login import current_user # type: ignore
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
from core.tools.tool_manager import ToolManager
@@ -14,7 +16,7 @@ class AgentService:
"""
Service to get agent logs
"""
conversation: Conversation = (
conversation: Optional[Conversation] = (
db.session.query(Conversation)
.filter(
Conversation.id == conversation_id,
@@ -26,7 +28,7 @@ class AgentService:
if not conversation:
raise ValueError(f"Conversation not found: {conversation_id}")
message: Message = (
message: Optional[Message] = (
db.session.query(Message)
.filter(
Message.id == message_id,
@@ -72,7 +74,10 @@ class AgentService:
}
agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict())
agent_tools = agent_config.tools
if not agent_config:
return result
agent_tools = agent_config.tools or []
def find_agent_tool(tool_name: str):
for agent_tool in agent_tools:

View File

@@ -1,8 +1,9 @@
import datetime
import uuid
from typing import cast
import pandas as pd
from flask_login import current_user
from flask_login import current_user # type: ignore
from sqlalchemy import or_
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
@@ -71,7 +72,7 @@ class AppAnnotationService:
app_id,
annotation_setting.collection_binding_id,
)
return annotation
return cast(MessageAnnotation, annotation)
@classmethod
def enable_app_annotation(cls, args: dict, app_id: str) -> dict:
@@ -124,8 +125,7 @@ class AppAnnotationService:
raise NotFound("App not found")
if keyword:
annotations = (
db.session.query(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id)
.filter(
or_(
MessageAnnotation.question.ilike("%{}%".format(keyword)),
@@ -137,8 +137,7 @@ class AppAnnotationService:
)
else:
annotations = (
db.session.query(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
)
@@ -327,8 +326,7 @@ class AppAnnotationService:
raise NotFound("Annotation not found")
annotation_hit_histories = (
db.session.query(AppAnnotationHitHistory)
.filter(
AppAnnotationHitHistory.query.filter(
AppAnnotationHitHistory.app_id == app_id,
AppAnnotationHitHistory.annotation_id == annotation_id,
)

View File

@@ -1,7 +1,7 @@
import logging
import uuid
from enum import StrEnum
from typing import Optional
from typing import Optional, cast
from uuid import uuid4
import yaml
@@ -103,7 +103,7 @@ class AppDslService:
raise ValueError(f"Invalid import_mode: {import_mode}")
# Get YAML content
content = ""
content: bytes | str = b""
if mode == ImportMode.YAML_URL:
if not yaml_url:
return Import(
@@ -136,7 +136,7 @@ class AppDslService:
)
try:
content = content.decode("utf-8")
content = cast(bytes, content).decode("utf-8")
except UnicodeDecodeError as e:
return Import(
id=import_id,
@@ -362,6 +362,9 @@ class AppDslService:
app.icon_background = icon_background or app_data.get("icon_background", app.icon_background)
app.updated_by = account.id
else:
if account.current_tenant_id is None:
raise ValueError("Current tenant is not set")
# Create new app
app = App()
app.id = str(uuid4())

View File

@@ -118,7 +118,7 @@ class AppGenerateService:
@staticmethod
def _get_max_active_requests(app_model: App) -> int:
max_active_requests = app_model.max_active_requests
if app_model.max_active_requests is None:
if max_active_requests is None:
max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS)
return max_active_requests
@@ -150,7 +150,7 @@ class AppGenerateService:
message_id: str,
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[dict, Generator]:
) -> Union[Mapping, Generator]:
"""
Generate more like this
:param app_model: app model

View File

@@ -1,9 +1,9 @@
import json
import logging
from datetime import UTC, datetime
from typing import cast
from typing import Optional, cast
from flask_login import current_user
from flask_login import current_user # type: ignore
from flask_sqlalchemy.pagination import Pagination
from configs import dify_config
@@ -83,7 +83,7 @@ class AppService:
# get default model instance
try:
model_instance = model_manager.get_default_model_instance(
tenant_id=account.current_tenant_id, model_type=ModelType.LLM
tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM
)
except (ProviderTokenNotInitError, LLMBadRequestError):
model_instance = None
@@ -100,6 +100,8 @@ class AppService:
else:
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if model_schema is None:
raise ValueError(f"model schema not found for model {model_instance.model}")
default_model_dict = {
"provider": model_instance.provider,
@@ -109,7 +111,7 @@ class AppService:
}
else:
provider, model = model_manager.get_default_provider_model_name(
tenant_id=account.current_tenant_id, model_type=ModelType.LLM
tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM
)
default_model_config["model"]["provider"] = provider
default_model_config["model"]["name"] = model
@@ -314,7 +316,7 @@ class AppService:
"""
app_mode = AppMode.value_of(app_model.mode)
meta = {"tool_icons": {}}
meta: dict = {"tool_icons": {}}
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
@@ -336,7 +338,7 @@ class AppService:
}
)
else:
app_model_config: AppModelConfig = app_model.app_model_config
app_model_config: Optional[AppModelConfig] = app_model.app_model_config
if not app_model_config:
return meta
@@ -352,16 +354,18 @@ class AppService:
keys = list(tool.keys())
if len(keys) >= 4:
# current tool standard
provider_type = tool.get("provider_type")
provider_id = tool.get("provider_id")
tool_name = tool.get("tool_name")
provider_type = tool.get("provider_type", "")
provider_id = tool.get("provider_id", "")
tool_name = tool.get("tool_name", "")
if provider_type == "builtin":
meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon"
elif provider_type == "api":
try:
provider: ApiToolProvider = (
provider: Optional[ApiToolProvider] = (
db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first()
)
if provider is None:
raise ValueError(f"provider not found for tool {tool_name}")
meta["tool_icons"][tool_name] = json.loads(provider.icon)
except:
meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"}

View File

@@ -110,6 +110,8 @@ class AudioService:
voices = model_instance.get_tts_voices()
if voices:
voice = voices[0].get("value")
if not voice:
raise ValueError("Sorry, no voice available.")
else:
raise ValueError("Sorry, no voice available.")
@@ -121,6 +123,8 @@ class AudioService:
if message_id:
message = db.session.query(Message).filter(Message.id == message_id).first()
if message is None:
return None
if message.answer == "" and message.status == "normal":
return None
@@ -130,6 +134,8 @@ class AudioService:
return Response(stream_with_context(response), content_type="audio/mpeg")
return response
else:
if not text:
raise ValueError("Text is required")
response = invoke_tts(text, app_model, voice)
if isinstance(response, Generator):
return Response(stream_with_context(response), content_type="audio/mpeg")

View File

@@ -11,8 +11,8 @@ class FirecrawlAuth(ApiKeyAuthBase):
auth_type = credentials.get("auth_type")
if auth_type != "bearer":
raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer")
self.api_key = credentials.get("config").get("api_key", None)
self.base_url = credentials.get("config").get("base_url", "https://api.firecrawl.dev")
self.api_key = credentials.get("config", {}).get("api_key", None)
self.base_url = credentials.get("config", {}).get("base_url", "https://api.firecrawl.dev")
if not self.api_key:
raise ValueError("No API key provided")

View File

@@ -11,7 +11,7 @@ class JinaAuth(ApiKeyAuthBase):
auth_type = credentials.get("auth_type")
if auth_type != "bearer":
raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")
self.api_key = credentials.get("config").get("api_key", None)
self.api_key = credentials.get("config", {}).get("api_key", None)
if not self.api_key:
raise ValueError("No API key provided")

View File

@@ -11,7 +11,7 @@ class JinaAuth(ApiKeyAuthBase):
auth_type = credentials.get("auth_type")
if auth_type != "bearer":
raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")
self.api_key = credentials.get("config").get("api_key", None)
self.api_key = credentials.get("config", {}).get("api_key", None)
if not self.api_key:
raise ValueError("No API key provided")

View File

@@ -1,4 +1,5 @@
import os
from typing import Optional
import httpx
from tenacity import retry, retry_if_not_exception_type, stop_before_delay, wait_fixed
@@ -58,11 +59,14 @@ class BillingService:
def is_tenant_owner_or_admin(current_user):
tenant_id = current_user.current_tenant_id
join = (
join: Optional[TenantAccountJoin] = (
db.session.query(TenantAccountJoin)
.filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
.first()
)
if not join:
raise ValueError("Tenant account join not found")
if not TenantAccountRole.is_privileged_role(join.role):
raise ValueError("Only team owner or team admin can perform this action")

View File

@@ -72,8 +72,7 @@ class ConversationService:
sort_direction=sort_direction,
reference_conversation=current_page_last_conversation,
)
count_stmt = stmt.where(rest_filter_condition)
count_stmt = select(func.count()).select_from(count_stmt.subquery())
count_stmt = select(func.count()).select_from(stmt.where(rest_filter_condition).subquery())
rest_count = session.scalar(count_stmt) or 0
if rest_count > 0:
has_more = True

View File

@@ -6,7 +6,7 @@ import time
import uuid
from typing import Any, Optional
from flask_login import current_user
from flask_login import current_user # type: ignore
from sqlalchemy import func
from werkzeug.exceptions import NotFound
@@ -186,8 +186,9 @@ class DatasetService:
return dataset
@staticmethod
def get_dataset(dataset_id) -> Dataset:
return Dataset.query.filter_by(id=dataset_id).first()
def get_dataset(dataset_id) -> Optional[Dataset]:
dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first()
return dataset
@staticmethod
def check_dataset_model_setting(dataset):
@@ -228,6 +229,8 @@ class DatasetService:
@staticmethod
def update_dataset(dataset_id, data, user):
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise ValueError("Dataset not found")
DatasetService.check_dataset_permission(dataset, user)
if dataset.provider == "external":
@@ -371,7 +374,13 @@ class DatasetService:
raise NoPermissionError("You do not have permission to access this dataset.")
@staticmethod
def check_dataset_operator_permission(user: Account = None, dataset: Dataset = None):
def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None):
if not dataset:
raise ValueError("Dataset not found")
if not user:
raise ValueError("User not found")
if dataset.permission == DatasetPermissionEnum.ONLY_ME:
if dataset.created_by != user.id:
raise NoPermissionError("You do not have permission to access this dataset.")
@@ -765,6 +774,11 @@ class DocumentService:
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
created_by=account.id,
)
else:
logging.warn(
f"Invalid process rule mode: {process_rule['mode']}, can not find dataset process rule"
)
return
db.session.add(dataset_process_rule)
db.session.commit()
lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
@@ -1009,9 +1023,10 @@ class DocumentService:
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
created_by=account.id,
)
db.session.add(dataset_process_rule)
db.session.commit()
document.dataset_process_rule_id = dataset_process_rule.id
if dataset_process_rule is not None:
db.session.add(dataset_process_rule)
db.session.commit()
document.dataset_process_rule_id = dataset_process_rule.id
# update document data source
if document_data.get("data_source"):
file_name = ""
@@ -1554,7 +1569,7 @@ class SegmentService:
segment.word_count = len(content)
if document.doc_form == "qa_model":
segment.answer = segment_update_entity.answer
segment.word_count += len(segment_update_entity.answer)
segment.word_count += len(segment_update_entity.answer or "")
word_count_change = segment.word_count - word_count_change
if segment_update_entity.keywords:
segment.keywords = segment_update_entity.keywords
@@ -1569,7 +1584,8 @@ class SegmentService:
db.session.add(document)
# update segment index task
if segment_update_entity.enabled:
VectorService.create_segments_vector([segment_update_entity.keywords], [segment], dataset)
keywords = segment_update_entity.keywords or []
VectorService.create_segments_vector([keywords], [segment], dataset)
else:
segment_hash = helper.generate_text_hash(content)
tokens = 0
@@ -1601,7 +1617,7 @@ class SegmentService:
segment.disabled_by = None
if document.doc_form == "qa_model":
segment.answer = segment_update_entity.answer
segment.word_count += len(segment_update_entity.answer)
segment.word_count += len(segment_update_entity.answer or "")
word_count_change = segment.word_count - word_count_change
# update document word count
if word_count_change != 0:
@@ -1619,8 +1635,8 @@ class SegmentService:
segment.status = "error"
segment.error = str(e)
db.session.commit()
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first()
return segment
new_segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first()
return new_segment
@classmethod
def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset):
@@ -1680,6 +1696,8 @@ class DatasetCollectionBindingService:
.order_by(DatasetCollectionBinding.created_at)
.first()
)
if not dataset_collection_binding:
raise ValueError("Dataset collection binding not found")
return dataset_collection_binding

View File

@@ -8,8 +8,8 @@ class EnterpriseRequest:
secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY")
proxies = {
"http": None,
"https": None,
"http": "",
"https": "",
}
@classmethod

View File

@@ -4,7 +4,11 @@ from typing import Optional
from pydantic import BaseModel, ConfigDict
from configs import dify_config
from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity
from core.entities.model_entities import (
ModelWithProviderEntity,
ProviderModelWithStatusEntity,
SimpleModelProviderEntity,
)
from core.entities.provider_entities import QuotaConfiguration
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType
@@ -148,7 +152,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity):
Model with provider entity.
"""
provider: SimpleProviderEntityResponse
provider: SimpleModelProviderEntity
def __init__(self, model: ModelWithProviderEntity) -> None:
super().__init__(**model.model_dump())

View File

@@ -1,7 +1,7 @@
import json
from copy import deepcopy
from datetime import UTC, datetime
from typing import Any, Optional, Union
from typing import Any, Optional, Union, cast
import httpx
import validators
@@ -45,7 +45,10 @@ class ExternalDatasetService:
@staticmethod
def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict) -> ExternalKnowledgeApis:
ExternalDatasetService.check_endpoint_and_api_key(args.get("settings"))
settings = args.get("settings")
if settings is None:
raise ValueError("settings is required")
ExternalDatasetService.check_endpoint_and_api_key(settings)
external_knowledge_api = ExternalKnowledgeApis(
tenant_id=tenant_id,
created_by=user_id,
@@ -86,11 +89,16 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis:
return ExternalKnowledgeApis.query.filter_by(id=external_knowledge_api_id).first()
external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id
).first()
if external_knowledge_api is None:
raise ValueError("api template not found")
return external_knowledge_api
@staticmethod
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id, tenant_id=tenant_id
).first()
if external_knowledge_api is None:
@@ -127,7 +135,7 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ExternalKnowledgeBindings.query.filter_by(
dataset_id=dataset_id, tenant_id=tenant_id
).first()
if not external_knowledge_binding:
@@ -163,8 +171,9 @@ class ExternalDatasetService:
"follow_redirects": True,
}
response = getattr(ssrf_proxy, settings.request_method)(data=json.dumps(settings.params), files=files, **kwargs)
response: httpx.Response = getattr(ssrf_proxy, settings.request_method)(
data=json.dumps(settings.params), files=files, **kwargs
)
return response
@staticmethod
@@ -265,15 +274,15 @@ class ExternalDatasetService:
"knowledge_id": external_knowledge_binding.external_knowledge_id,
}
external_knowledge_api_setting = {
"url": f"{settings.get('endpoint')}/retrieval",
"request_method": "post",
"headers": headers,
"params": request_params,
}
response = ExternalDatasetService.process_external_api(
ExternalKnowledgeApiSetting(**external_knowledge_api_setting), None
ExternalKnowledgeApiSetting(
url=f"{settings.get('endpoint')}/retrieval",
request_method="post",
headers=headers,
params=request_params,
),
None,
)
if response.status_code == 200:
return response.json().get("records", [])
return cast(list[Any], response.json().get("records", []))
return []

View File

@@ -3,7 +3,7 @@ import hashlib
import uuid
from typing import Any, Literal, Union
from flask_login import current_user
from flask_login import current_user # type: ignore
from werkzeug.exceptions import NotFound
from configs import dify_config
@@ -61,14 +61,14 @@ class FileService:
# end_user
current_tenant_id = user.tenant_id
file_key = "upload_files/" + current_tenant_id + "/" + file_uuid + "." + extension
file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." + extension
# save file to storage
storage.save(file_key, content)
# save file to db
upload_file = UploadFile(
tenant_id=current_tenant_id,
tenant_id=current_tenant_id or "",
storage_type=dify_config.STORAGE_TYPE,
key=file_key,
name=filename,

View File

@@ -1,5 +1,6 @@
import logging
import time
from typing import Any
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document
@@ -24,7 +25,7 @@ class HitTestingService:
dataset: Dataset,
query: str,
account: Account,
retrieval_model: dict,
retrieval_model: Any, # FIXME drop this any
external_retrieval_model: dict,
limit: int = 10,
) -> dict:
@@ -68,7 +69,7 @@ class HitTestingService:
db.session.add(dataset_query)
db.session.commit()
return cls.compact_retrieve_response(dataset, query, all_documents)
return dict(cls.compact_retrieve_response(dataset, query, all_documents))
@classmethod
def external_retrieve(
@@ -102,13 +103,16 @@ class HitTestingService:
db.session.add(dataset_query)
db.session.commit()
return cls.compact_external_retrieve_response(dataset, query, all_documents)
return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))
@classmethod
def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]):
records = []
for document in documents:
if document.metadata is None:
continue
index_node_id = document.metadata["doc_id"]
segment = (
@@ -140,7 +144,7 @@ class HitTestingService:
}
@classmethod
def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list):
def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> dict[Any, Any]:
records = []
if dataset.provider == "external":
for document in documents:
@@ -152,11 +156,10 @@ class HitTestingService:
}
records.append(record)
return {
"query": {
"content": query,
},
"query": {"content": query},
"records": records,
}
return {"query": {"content": query}, "records": []}
@classmethod
def hit_testing_args_check(cls, args):

View File

@@ -1,4 +1,4 @@
import boto3
import boto3 # type: ignore
from configs import dify_config

View File

@@ -157,7 +157,7 @@ class MessageService:
user: Optional[Union[Account, EndUser]],
rating: Optional[str],
content: Optional[str],
) -> MessageFeedback:
):
if not user:
raise ValueError("user cannot be None")
@@ -264,6 +264,8 @@ class MessageService:
)
app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
if not app_model_config:
raise ValueError("did not find app model config")
suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict
if suggested_questions_after_answer.get("enabled", False) is False:
@@ -285,7 +287,7 @@ class MessageService:
)
with measure_time() as timer:
questions = LLMGenerator.generate_suggested_questions_after_answer(
questions: list[Message] = LLMGenerator.generate_suggested_questions_after_answer(
tenant_id=app_model.tenant_id, histories=histories
)

View File

@@ -2,7 +2,7 @@ import datetime
import json
import logging
from json import JSONDecodeError
from typing import Optional
from typing import Optional, Union
from constants import HIDDEN_VALUE
from core.entities.provider_configuration import ProviderConfiguration
@@ -88,11 +88,11 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
model_type = ModelType.value_of(model_type)
model_type_enum = ModelType.value_of(model_type)
# Get provider model setting
provider_model_setting = provider_configuration.get_provider_model_setting(
model_type=model_type,
model_type=model_type_enum,
model=model,
)
@@ -106,7 +106,7 @@ class ModelLoadBalancingService:
.filter(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
.order_by(LoadBalancingModelConfig.created_at)
@@ -124,7 +124,7 @@ class ModelLoadBalancingService:
if not inherit_config_exists:
# Initialize the inherit configuration
inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type)
inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type_enum)
# prepend the inherit configuration
load_balancing_configs.insert(0, inherit_config)
@@ -148,7 +148,7 @@ class ModelLoadBalancingService:
tenant_id=tenant_id,
provider=provider,
model=model,
model_type=model_type,
model_type=model_type_enum,
config_id=load_balancing_config.id,
)
@@ -214,7 +214,7 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
model_type = ModelType.value_of(model_type)
model_type_enum = ModelType.value_of(model_type)
# Get load balancing configurations
load_balancing_model_config = (
@@ -222,7 +222,7 @@ class ModelLoadBalancingService:
.filter(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id,
)
@@ -300,7 +300,7 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
model_type = ModelType.value_of(model_type)
model_type_enum = ModelType.value_of(model_type)
if not isinstance(configs, list):
raise ValueError("Invalid load balancing configs")
@@ -310,7 +310,7 @@ class ModelLoadBalancingService:
.filter(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
.all()
@@ -359,7 +359,7 @@ class ModelLoadBalancingService:
credentials = self._custom_credentials_validate(
tenant_id=tenant_id,
provider_configuration=provider_configuration,
model_type=model_type,
model_type=model_type_enum,
model=model,
credentials=credentials,
load_balancing_model_config=load_balancing_config,
@@ -395,7 +395,7 @@ class ModelLoadBalancingService:
credentials = self._custom_credentials_validate(
tenant_id=tenant_id,
provider_configuration=provider_configuration,
model_type=model_type,
model_type=model_type_enum,
model=model,
credentials=credentials,
validate=False,
@@ -405,7 +405,7 @@ class ModelLoadBalancingService:
load_balancing_model_config = LoadBalancingModelConfig(
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
model_type=model_type.to_origin_model_type(),
model_type=model_type_enum.to_origin_model_type(),
model_name=model,
name=name,
encrypted_config=json.dumps(credentials),
@@ -450,7 +450,7 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
model_type = ModelType.value_of(model_type)
model_type_enum = ModelType.value_of(model_type)
load_balancing_model_config = None
if config_id:
@@ -460,7 +460,7 @@ class ModelLoadBalancingService:
.filter(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id,
)
@@ -474,7 +474,7 @@ class ModelLoadBalancingService:
self._custom_credentials_validate(
tenant_id=tenant_id,
provider_configuration=provider_configuration,
model_type=model_type,
model_type=model_type_enum,
model=model,
credentials=credentials,
load_balancing_model_config=load_balancing_model_config,
@@ -547,19 +547,14 @@ class ModelLoadBalancingService:
def _get_credential_schema(
self, provider_configuration: ProviderConfiguration
) -> ModelCredentialSchema | ProviderCredentialSchema:
"""
Get form schemas.
:param provider_configuration: provider configuration
:return:
"""
# Get credential form schemas from model credential schema or provider credential schema
) -> Union[ModelCredentialSchema, ProviderCredentialSchema]:
"""Get form schemas."""
if provider_configuration.provider.model_credential_schema:
credential_schema = provider_configuration.provider.model_credential_schema
return provider_configuration.provider.model_credential_schema
elif provider_configuration.provider.provider_credential_schema:
return provider_configuration.provider.provider_credential_schema
else:
credential_schema = provider_configuration.provider.provider_credential_schema
return credential_schema
raise ValueError("No credential schema found")
def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None:
"""

View File

@@ -7,7 +7,7 @@ from typing import Optional, cast
import requests
from flask import current_app
from core.entities.model_entities import ModelStatus, ProviderModelWithStatusEntity
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -100,23 +100,15 @@ class ModelProviderService:
ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider)
]
def get_provider_credentials(self, tenant_id: str, provider: str) -> dict:
def get_provider_credentials(self, tenant_id: str, provider: str):
"""
get provider credentials.
:param tenant_id:
:param provider:
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider configuration
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
raise ValueError(f"Provider {provider} does not exist.")
# Get provider custom credentials from workspace
return provider_configuration.get_custom_credentials(obfuscated=True)
def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None:
@@ -176,7 +168,7 @@ class ModelProviderService:
# Remove custom provider credentials.
provider_configuration.delete_custom_credentials()
def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> dict:
def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str):
"""
get model credentials.
@@ -287,7 +279,7 @@ class ModelProviderService:
models = provider_configurations.get_models(model_type=ModelType.value_of(model_type))
# Group models by provider
provider_models = {}
provider_models: dict[str, list[ModelWithProviderEntity]] = {}
for model in models:
if model.provider.provider not in provider_models:
provider_models[model.provider.provider] = []
@@ -362,7 +354,7 @@ class ModelProviderService:
return []
# Call get_parameter_rules method of model instance to get model parameter rules
return model_type_instance.get_parameter_rules(model=model, credentials=credentials)
return list(model_type_instance.get_parameter_rules(model=model, credentials=credentials))
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]:
"""
@@ -422,6 +414,7 @@ class ModelProviderService:
"""
provider_instance = model_provider_factory.get_provider_instance(provider)
provider_schema = provider_instance.get_provider_schema()
file_name: str | None = None
if icon_type.lower() == "icon_small":
if not provider_schema.icon_small:
@@ -439,6 +432,8 @@ class ModelProviderService:
file_name = provider_schema.icon_large.zh_Hans
else:
file_name = provider_schema.icon_large.en_US
if not file_name:
return None, None
root_path = current_app.root_path
provider_instance_path = os.path.dirname(
@@ -524,7 +519,7 @@ class ModelProviderService:
def free_quota_submit(self, tenant_id: str, provider: str):
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "")
api_url = api_base_url + "/api/v1/providers/apply"
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
@@ -545,7 +540,7 @@ class ModelProviderService:
def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]):
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "")
api_url = api_base_url + "/api/v1/providers/qualification-verify"
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}

View File

@@ -1,3 +1,5 @@
from typing import Optional
from core.moderation.factory import ModerationFactory, ModerationOutputsResult
from extensions.ext_database import db
from models.model import App, AppModelConfig
@@ -5,7 +7,7 @@ from models.model import App, AppModelConfig
class ModerationService:
def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult:
app_model_config: AppModelConfig = None
app_model_config: Optional[AppModelConfig] = None
app_model_config = (
db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()

View File

@@ -1,3 +1,5 @@
from typing import Optional
from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
from extensions.ext_database import db
from models.model import App, TraceAppConfig
@@ -12,7 +14,7 @@ class OpsService:
:param tracing_provider: tracing provider
:return:
"""
trace_config_data: TraceAppConfig = (
trace_config_data: Optional[TraceAppConfig] = (
db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
@@ -22,7 +24,10 @@ class OpsService:
return None
# decrypt_token and obfuscated_token
tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
tenant = db.session.query(App).filter(App.id == app_id).first()
if not tenant:
return None
tenant_id = tenant.tenant_id
decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(
tenant_id, tracing_provider, trace_config_data.tracing_config
)
@@ -73,8 +78,9 @@ class OpsService:
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["other_keys"],
)
default_config_instance = config_class(**tracing_config)
for key in other_keys:
# FIXME: ignore type error
default_config_instance = config_class(**tracing_config) # type: ignore
for key in other_keys: # type: ignore
if key in tracing_config and tracing_config[key] == "":
tracing_config[key] = getattr(default_config_instance, key, None)
@@ -92,7 +98,7 @@ class OpsService:
project_url = None
# check if trace config already exists
trace_config_data: TraceAppConfig = (
trace_config_data: Optional[TraceAppConfig] = (
db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
@@ -102,7 +108,10 @@ class OpsService:
return None
# get tenant id
tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
tenant = db.session.query(App).filter(App.id == app_id).first()
if not tenant:
return None
tenant_id = tenant.tenant_id
tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config)
if project_url:
tracing_config["project_url"] = project_url
@@ -139,7 +148,10 @@ class OpsService:
return None
# get tenant id
tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
tenant = db.session.query(App).filter(App.id == app_id).first()
if not tenant:
return None
tenant_id = tenant.tenant_id
tracing_config = OpsTraceManager.encrypt_tracing_config(
tenant_id, tracing_provider, tracing_config, current_trace_config.tracing_config
)

View File

@@ -41,7 +41,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
Path(path.join(root_path, "constants", "recommended_apps.json")).read_text(encoding="utf-8")
)
return cls.builtin_data
return cls.builtin_data or {}
@classmethod
def fetch_recommended_apps_from_builtin(cls, language: str) -> dict:
@@ -50,8 +50,8 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
:param language: language
:return:
"""
builtin_data = cls._get_builtin_data()
return builtin_data.get("recommended_apps", {}).get(language)
builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
return builtin_data.get("recommended_apps", {}).get(language, {})
@classmethod
def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]:
@@ -60,5 +60,5 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
:param app_id: App ID
:return:
"""
builtin_data = cls._get_builtin_data()
builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
return builtin_data.get("app_details", {}).get(app_id)

View File

@@ -47,8 +47,8 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
response = requests.get(url, timeout=(3, 10))
if response.status_code != 200:
return None
return response.json()
data: dict = response.json()
return data
@classmethod
def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:
@@ -63,7 +63,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
if response.status_code != 200:
raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}")
result = response.json()
result: dict = response.json()
if "categories" in result:
result["categories"] = sorted(result["categories"])

View File

@@ -33,5 +33,5 @@ class RecommendedAppService:
"""
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
result = retrieval_instance.get_recommend_app_detail(app_id)
result: dict = retrieval_instance.get_recommend_app_detail(app_id)
return result

View File

@@ -13,6 +13,8 @@ class SavedMessageService:
def pagination_by_last_id(
cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int
) -> InfiniteScrollPagination:
if not user:
raise ValueError("User is required")
saved_messages = (
db.session.query(SavedMessage)
.filter(
@@ -31,6 +33,8 @@ class SavedMessageService:
@classmethod
def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
if not user:
return
saved_message = (
db.session.query(SavedMessage)
.filter(
@@ -59,6 +63,8 @@ class SavedMessageService:
@classmethod
def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
if not user:
return
saved_message = (
db.session.query(SavedMessage)
.filter(

View File

@@ -1,7 +1,7 @@
import uuid
from typing import Optional
from flask_login import current_user
from flask_login import current_user # type: ignore
from sqlalchemy import func
from werkzeug.exceptions import NotFound
@@ -21,7 +21,7 @@ class TagService:
if keyword:
query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%")))
query = query.group_by(Tag.id)
results = query.order_by(Tag.created_at.desc()).all()
results: list = query.order_by(Tag.created_at.desc()).all()
return results
@staticmethod

View File

@@ -1,6 +1,7 @@
import json
import logging
from typing import Optional
from collections.abc import Mapping
from typing import Any, Optional, cast
from httpx import get
@@ -28,12 +29,12 @@ logger = logging.getLogger(__name__)
class ApiToolManageService:
@staticmethod
def parser_api_schema(schema: str) -> list[ApiToolBundle]:
def parser_api_schema(schema: str) -> Mapping[str, Any]:
"""
parse api schema to tool bundle
"""
try:
warnings = {}
warnings: dict[str, str] = {}
try:
tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
except Exception as e:
@@ -68,13 +69,16 @@ class ApiToolManageService:
),
]
return jsonable_encoder(
{
"schema_type": schema_type,
"parameters_schema": tool_bundles,
"credentials_schema": credentials_schema,
"warning": warnings,
}
return cast(
Mapping,
jsonable_encoder(
{
"schema_type": schema_type,
"parameters_schema": tool_bundles,
"credentials_schema": credentials_schema,
"warning": warnings,
}
),
)
except Exception as e:
raise ValueError(f"invalid schema: {str(e)}")
@@ -129,7 +133,7 @@ class ApiToolManageService:
raise ValueError(f"provider {provider_name} already exists")
# parse openapi to tool bundle
extra_info = {}
extra_info: dict[str, str] = {}
# extra info like description will be set here
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
@@ -262,9 +266,8 @@ class ApiToolManageService:
if provider is None:
raise ValueError(f"api provider {provider_name} does not exists")
# parse openapi to tool bundle
extra_info = {}
extra_info: dict[str, str] = {}
# extra info like description will be set here
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
@@ -416,7 +419,7 @@ class ApiToolManageService:
provider_controller.validate_credentials_format(credentials)
# get tool
tool = provider_controller.get_tool(tool_name)
tool = tool.fork_tool_runtime(
runtime_tool = tool.fork_tool_runtime(
runtime={
"credentials": credentials,
"tenant_id": tenant_id,
@@ -454,7 +457,7 @@ class ApiToolManageService:
tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
for tool in tools:
for tool in tools or []:
user_provider.tools.append(
ToolTransformService.tool_to_user_tool(
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels

View File

@@ -50,8 +50,8 @@ class BuiltinToolManageService:
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)
result = []
for tool in tools:
result: list[UserTool] = []
for tool in tools or []:
result.append(
ToolTransformService.tool_to_user_tool(
tool=tool,
@@ -217,6 +217,8 @@ class BuiltinToolManageService:
name_func=lambda x: x.identity.name,
):
continue
if provider_controller.identity is None:
continue
# convert provider controller to user provider
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
@@ -229,7 +231,7 @@ class BuiltinToolManageService:
ToolTransformService.repack_provider(user_builtin_provider)
tools = provider_controller.get_tools()
for tool in tools:
for tool in tools or []:
user_builtin_provider.tools.append(
ToolTransformService.tool_to_user_tool(
tenant_id=tenant_id,

View File

@@ -1,6 +1,6 @@
import json
import logging
from typing import Optional, Union
from typing import Optional, Union, cast
from configs import dify_config
from core.tools.entities.api_entities import UserTool, UserToolProvider
@@ -35,7 +35,7 @@ class ToolTransformService:
return url_prefix + "builtin/" + provider_name + "/icon"
elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
try:
return json.loads(icon)
return cast(dict, json.loads(icon))
except:
return {"background": "#252525", "content": "\ud83d\ude01"}
@@ -53,8 +53,11 @@ class ToolTransformService:
provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
)
elif isinstance(provider, UserToolProvider):
provider.icon = ToolTransformService.get_tool_provider_icon_url(
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
provider.icon = cast(
str,
ToolTransformService.get_tool_provider_icon_url(
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
),
)
@staticmethod
@@ -66,6 +69,9 @@ class ToolTransformService:
"""
convert provider controller to user provider
"""
if provider_controller.identity is None:
raise ValueError("provider identity is None")
result = UserToolProvider(
id=provider_controller.identity.name,
author=provider_controller.identity.author,
@@ -93,7 +99,8 @@ class ToolTransformService:
# get credentials schema
schema = provider_controller.get_credentials_schema()
for name, value in schema.items():
result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type)
assert result.masked_credentials is not None, "masked credentials is None"
result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(str(value.type))
# check if the provider need credentials
if not provider_controller.need_credentials:
@@ -149,6 +156,9 @@ class ToolTransformService:
"""
convert provider controller to user provider
"""
if provider_controller.identity is None:
raise ValueError("provider identity is None")
return UserToolProvider(
id=provider_controller.provider_id,
author=provider_controller.identity.author,
@@ -180,6 +190,8 @@ class ToolTransformService:
convert provider controller to user provider
"""
username = "Anonymous"
if db_provider.user is None:
raise ValueError(f"user is None for api provider {db_provider.id}")
try:
username = db_provider.user.name
except Exception as e:
@@ -256,19 +268,25 @@ class ToolTransformService:
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)
if tool.identity is None:
raise ValueError("tool identity is None")
return UserTool(
author=tool.identity.author,
name=tool.identity.name,
label=tool.identity.label,
description=tool.description.human,
description=I18nObject(
en_US=tool.description.human if tool.description else "",
zh_Hans=tool.description.human if tool.description else "",
),
parameters=current_parameters,
labels=labels,
)
if isinstance(tool, ApiToolBundle):
return UserTool(
author=tool.author,
name=tool.operation_id,
label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
name=tool.operation_id or "",
label=I18nObject(en_US=tool.operation_id or "", zh_Hans=tool.operation_id or ""),
description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
parameters=tool.parameters,
labels=labels,

View File

@@ -6,8 +6,10 @@ from typing import Any, Optional
from sqlalchemy import or_
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserToolProvider
from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
from core.tools.tool.tool import Tool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from extensions.ext_database import db
@@ -32,7 +34,7 @@ class WorkflowToolManageService:
label: str,
icon: dict,
description: str,
parameters: Mapping[str, Any],
parameters: list[Mapping[str, Any]],
privacy_policy: str = "",
labels: Optional[list[str]] = None,
) -> dict:
@@ -97,7 +99,7 @@ class WorkflowToolManageService:
label: str,
icon: dict,
description: str,
parameters: list[dict],
parameters: list[Mapping[str, Any]],
privacy_policy: str = "",
labels: Optional[list[str]] = None,
) -> dict:
@@ -131,7 +133,7 @@ class WorkflowToolManageService:
if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} already exists")
workflow_tool_provider: WorkflowToolProvider = (
workflow_tool_provider: Optional[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
@@ -140,14 +142,14 @@ class WorkflowToolManageService:
if workflow_tool_provider is None:
raise ValueError(f"Tool {workflow_tool_id} not found")
app: App = (
app: Optional[App] = (
db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
)
if app is None:
raise ValueError(f"App {workflow_tool_provider.app_id} not found")
workflow: Workflow = app.workflow
workflow: Optional[Workflow] = app.workflow
if workflow is None:
raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")
@@ -193,7 +195,7 @@ class WorkflowToolManageService:
# skip deleted tools
pass
labels = ToolLabelManager.get_tools_labels(tools)
labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)])
result = []
@@ -202,10 +204,11 @@ class WorkflowToolManageService:
provider_controller=tool, labels=labels.get(tool.provider_id, [])
)
ToolTransformService.repack_provider(user_tool_provider)
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
if to_user_tool is None or len(to_user_tool) == 0:
continue
user_tool_provider.tools = [
ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, [])
)
ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=labels.get(tool.provider_id, []))
]
result.append(user_tool_provider)
@@ -236,7 +239,7 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the tool
"""
db_tool: WorkflowToolProvider = (
db_tool: Optional[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
@@ -245,13 +248,19 @@ class WorkflowToolManageService:
if db_tool is None:
raise ValueError(f"Tool {workflow_tool_id} not found")
workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
workflow_app: Optional[App] = (
db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
)
if workflow_app is None:
raise ValueError(f"App {db_tool.app_id} not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
if to_user_tool is None or len(to_user_tool) == 0:
raise ValueError(f"Tool {workflow_tool_id} not found")
return {
"name": db_tool.name,
"label": db_tool.label,
@@ -261,9 +270,9 @@ class WorkflowToolManageService:
"description": db_tool.description,
"parameters": jsonable_encoder(db_tool.parameter_configurations),
"tool": ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool)
),
"synced": workflow_app.workflow.version == db_tool.version,
"synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False,
"privacy_policy": db_tool.privacy_policy,
}
@@ -276,7 +285,7 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the tool
"""
db_tool: WorkflowToolProvider = (
db_tool: Optional[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.first()
@@ -285,12 +294,17 @@ class WorkflowToolManageService:
if db_tool is None:
raise ValueError(f"Tool {workflow_app_id} not found")
workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
workflow_app: Optional[App] = (
db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
)
if workflow_app is None:
raise ValueError(f"App {db_tool.app_id} not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
if to_user_tool is None or len(to_user_tool) == 0:
raise ValueError(f"Tool {workflow_app_id} not found")
return {
"name": db_tool.name,
@@ -301,14 +315,14 @@ class WorkflowToolManageService:
"description": db_tool.description,
"parameters": jsonable_encoder(db_tool.parameter_configurations),
"tool": ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool)
),
"synced": workflow_app.workflow.version == db_tool.version,
"synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False,
"privacy_policy": db_tool.privacy_policy,
}
@classmethod
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]:
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[UserTool]:
"""
List workflow tool provider tools.
:param user_id: the user id
@@ -316,7 +330,7 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the list of tools
"""
db_tool: WorkflowToolProvider = (
db_tool: Optional[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
@@ -326,9 +340,8 @@ class WorkflowToolManageService:
raise ValueError(f"Tool {workflow_tool_id} not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
if to_user_tool is None or len(to_user_tool) == 0:
raise ValueError(f"Tool {workflow_tool_id} not found")
return [
ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
)
]
return [ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool))]

View File

@@ -26,6 +26,8 @@ class WebConversationService:
pinned: Optional[bool] = None,
sort_by="-updated_at",
) -> InfiniteScrollPagination:
if not user:
raise ValueError("User is required")
include_ids = None
exclude_ids = None
if pinned is not None and user:
@@ -59,6 +61,8 @@ class WebConversationService:
@classmethod
def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
if not user:
return
pinned_conversation = (
db.session.query(PinnedConversation)
.filter(
@@ -89,6 +93,8 @@ class WebConversationService:
@classmethod
def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
if not user:
return
pinned_conversation = (
db.session.query(PinnedConversation)
.filter(

View File

@@ -1,8 +1,9 @@
import datetime
import json
from typing import Any
import requests
from flask_login import current_user
from flask_login import current_user # type: ignore
from core.helper import encrypter
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
@@ -23,9 +24,9 @@ class WebsiteService:
@classmethod
def crawl_url(cls, args: dict) -> dict:
provider = args.get("provider")
provider = args.get("provider", "")
url = args.get("url")
options = args.get("options")
options = args.get("options", "")
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
if provider == "firecrawl":
# decrypt api_key
@@ -164,16 +165,18 @@ class WebsiteService:
return crawl_status_data
@classmethod
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None:
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[Any, Any] | None:
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
# decrypt api_key
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
# FIXME data is redefine too many times here, use Any to ease the type checking, fix it later
data: Any
if provider == "firecrawl":
file_key = "website_files/" + job_id + ".txt"
if storage.exists(file_key):
data = storage.load_once(file_key)
if data:
data = json.loads(data.decode("utf-8"))
d = storage.load_once(file_key)
if d:
data = json.loads(d.decode("utf-8"))
else:
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
result = firecrawl_app.check_crawl_status(job_id)
@@ -183,22 +186,17 @@ class WebsiteService:
if data:
for item in data:
if item.get("source_url") == url:
return item
return dict(item)
return None
elif provider == "jinareader":
file_key = "website_files/" + job_id + ".txt"
if storage.exists(file_key):
data = storage.load_once(file_key)
if data:
data = json.loads(data.decode("utf-8"))
elif not job_id:
if not job_id:
response = requests.get(
f"https://r.jina.ai/{url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
)
if response.json().get("code") != 200:
raise ValueError("Failed to crawl")
return response.json().get("data")
return dict(response.json().get("data", {}))
else:
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
response = requests.post(
@@ -218,12 +216,13 @@ class WebsiteService:
data = response.json().get("data", {})
for item in data.get("processed", {}).values():
if item.get("data", {}).get("url") == url:
return item.get("data", {})
return dict(item.get("data", {}))
return None
else:
raise ValueError("Invalid provider")
@classmethod
def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None:
def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict:
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
if provider == "firecrawl":
# decrypt api_key

View File

@@ -1,5 +1,5 @@
import json
from typing import Optional
from typing import Any, Optional
from core.app.app_config.entities import (
DatasetEntity,
@@ -101,7 +101,7 @@ class WorkflowConverter:
app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config)
# init workflow graph
graph = {"nodes": [], "edges": []}
graph: dict[str, Any] = {"nodes": [], "edges": []}
# Convert list:
# - variables -> start
@@ -118,7 +118,7 @@ class WorkflowConverter:
graph["nodes"].append(start_node)
# convert to http request node
external_data_variable_node_mapping = {}
external_data_variable_node_mapping: dict[str, str] = {}
if app_config.external_data_variables:
http_request_nodes, external_data_variable_node_mapping = self._convert_to_http_request_node(
app_model=app_model,
@@ -199,15 +199,16 @@ class WorkflowConverter:
return workflow
def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig:
app_mode = AppMode.value_of(app_model.mode)
if app_mode == AppMode.AGENT_CHAT or app_model.is_agent:
app_mode_enum = AppMode.value_of(app_model.mode)
app_config: EasyUIBasedAppConfig
if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent:
app_model.mode = AppMode.AGENT_CHAT.value
app_config = AgentChatAppConfigManager.get_app_config(
app_model=app_model, app_model_config=app_model_config
)
elif app_mode == AppMode.CHAT:
elif app_mode_enum == AppMode.CHAT:
app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config)
elif app_mode == AppMode.COMPLETION:
elif app_mode_enum == AppMode.COMPLETION:
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model, app_model_config=app_model_config
)
@@ -302,7 +303,7 @@ class WorkflowConverter:
nodes.append(http_request_node)
# append code node for response body parsing
code_node = {
code_node: dict[str, Any] = {
"id": f"code_{index}",
"position": None,
"data": {
@@ -401,6 +402,7 @@ class WorkflowConverter:
)
role_prefix = None
prompts: Any = None
# Chat Model
if model_config.mode == LLMMode.CHAT.value:

View File

@@ -1,3 +1,5 @@
from typing import Optional
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
@@ -92,7 +94,7 @@ class WorkflowRunService:
return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun:
def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]:
"""
Get workflow run detail

View File

@@ -2,7 +2,7 @@ import json
import time
from collections.abc import Sequence
from datetime import UTC, datetime
from typing import Optional, cast
from typing import Any, Optional, cast
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
@@ -242,7 +242,7 @@ class WorkflowService:
raise ValueError("Node run failed with no run result")
# single step debug mode error handling return
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
node_error_args = {
node_error_args: dict[str, Any] = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": node_run_result.error,
"inputs": node_run_result.inputs,
@@ -338,7 +338,7 @@ class WorkflowService:
raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.")
# convert to workflow
new_app = workflow_converter.convert_to_workflow(
new_app: App = workflow_converter.convert_to_workflow(
app_model=app_model,
account=account,
name=args.get("name", "Default Name"),

View File

@@ -1,4 +1,4 @@
from flask_login import current_user
from flask_login import current_user # type: ignore
from configs import dify_config
from extensions.ext_database import db
@@ -29,6 +29,7 @@ class WorkspaceService:
.filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id)
.first()
)
assert tenant_account_join is not None, "TenantAccountJoin not found"
tenant_info["role"] = tenant_account_join.role
can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo