mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 03:16:51 +08:00
feat: mypy for all type check (#10921)
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor
|
||||
@@ -43,6 +45,8 @@ class ApiModeration(Moderation):
|
||||
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
if self.config is None:
|
||||
raise ValueError("The config is not set.")
|
||||
|
||||
if self.config["inputs_config"]["enabled"]:
|
||||
params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query)
|
||||
@@ -57,6 +61,8 @@ class ApiModeration(Moderation):
|
||||
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
if self.config is None:
|
||||
raise ValueError("The config is not set.")
|
||||
|
||||
if self.config["outputs_config"]["enabled"]:
|
||||
params = ModerationOutputParams(app_id=self.app_id, text=text)
|
||||
@@ -69,14 +75,18 @@ class ApiModeration(Moderation):
|
||||
)
|
||||
|
||||
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict:
|
||||
extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id"))
|
||||
if self.config is None:
|
||||
raise ValueError("The config is not set.")
|
||||
extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", ""))
|
||||
if not extension:
|
||||
raise ValueError("API-based Extension not found. Please check it again.")
|
||||
requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key))
|
||||
|
||||
result = requestor.request(extension_point, params)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
|
||||
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
|
||||
extension = (
|
||||
db.session.query(APIBasedExtension)
|
||||
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
|
||||
|
||||
@@ -100,14 +100,14 @@ class Moderation(Extensible, ABC):
|
||||
if not inputs_config.get("preset_response"):
|
||||
raise ValueError("inputs_config.preset_response is required")
|
||||
|
||||
if len(inputs_config.get("preset_response")) > 100:
|
||||
if len(inputs_config.get("preset_response", 0)) > 100:
|
||||
raise ValueError("inputs_config.preset_response must be less than 100 characters")
|
||||
|
||||
if outputs_config_enabled:
|
||||
if not outputs_config.get("preset_response"):
|
||||
raise ValueError("outputs_config.preset_response is required")
|
||||
|
||||
if len(outputs_config.get("preset_response")) > 100:
|
||||
if len(outputs_config.get("preset_response", 0)) > 100:
|
||||
raise ValueError("outputs_config.preset_response must be less than 100 characters")
|
||||
|
||||
|
||||
|
||||
@@ -22,7 +22,8 @@ class ModerationFactory:
|
||||
"""
|
||||
code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config)
|
||||
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
|
||||
extension_class.validate_config(tenant_id, config)
|
||||
# FIXME: mypy error, try to fix it instead of using type: ignore
|
||||
extension_class.validate_config(tenant_id, config) # type: ignore
|
||||
|
||||
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
|
||||
"""
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.app_config.entities import AppConfig
|
||||
from core.moderation.base import ModerationAction, ModerationError
|
||||
@@ -17,11 +18,11 @@ class InputModeration:
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
app_config: AppConfig,
|
||||
inputs: dict,
|
||||
inputs: Mapping[str, Any],
|
||||
query: str,
|
||||
message_id: str,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> tuple[bool, dict, str]:
|
||||
) -> tuple[bool, Mapping[str, Any], str]:
|
||||
"""
|
||||
Process sensitive_word_avoidance.
|
||||
:param app_id: app id
|
||||
@@ -33,6 +34,7 @@ class InputModeration:
|
||||
:param trace_manager: trace manager
|
||||
:return:
|
||||
"""
|
||||
inputs = dict(inputs)
|
||||
if not app_config.sensitive_word_avoidance:
|
||||
return False, inputs, query
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ class KeywordsModeration(Moderation):
|
||||
if not config.get("keywords"):
|
||||
raise ValueError("keywords is required")
|
||||
|
||||
if len(config.get("keywords")) > 10000:
|
||||
if len(config.get("keywords", [])) > 10000:
|
||||
raise ValueError("keywords length must be less than 10000")
|
||||
|
||||
keywords_row_len = config["keywords"].split("\n")
|
||||
@@ -31,6 +31,8 @@ class KeywordsModeration(Moderation):
|
||||
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
if self.config is None:
|
||||
raise ValueError("The config is not set.")
|
||||
|
||||
if self.config["inputs_config"]["enabled"]:
|
||||
preset_response = self.config["inputs_config"]["preset_response"]
|
||||
@@ -50,6 +52,8 @@ class KeywordsModeration(Moderation):
|
||||
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
if self.config is None:
|
||||
raise ValueError("The config is not set.")
|
||||
|
||||
if self.config["outputs_config"]["enabled"]:
|
||||
# Filter out empty values
|
||||
|
||||
@@ -20,6 +20,8 @@ class OpenAIModeration(Moderation):
|
||||
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
if self.config is None:
|
||||
raise ValueError("The config is not set.")
|
||||
|
||||
if self.config["inputs_config"]["enabled"]:
|
||||
preset_response = self.config["inputs_config"]["preset_response"]
|
||||
@@ -35,6 +37,8 @@ class OpenAIModeration(Moderation):
|
||||
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
if self.config is None:
|
||||
raise ValueError("The config is not set.")
|
||||
|
||||
if self.config["outputs_config"]["enabled"]:
|
||||
flagged = self._is_violated({"text": text})
|
||||
|
||||
@@ -70,7 +70,7 @@ class OutputModeration(BaseModel):
|
||||
thread = threading.Thread(
|
||||
target=self.worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(),
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"buffer_size": buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE,
|
||||
},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user