feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -1,3 +1,5 @@
from typing import Optional
from pydantic import BaseModel
from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor
@@ -43,6 +45,8 @@ class ApiModeration(Moderation):
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config is None:
raise ValueError("The config is not set.")
if self.config["inputs_config"]["enabled"]:
params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query)
@@ -57,6 +61,8 @@ class ApiModeration(Moderation):
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""
if self.config is None:
raise ValueError("The config is not set.")
if self.config["outputs_config"]["enabled"]:
params = ModerationOutputParams(app_id=self.app_id, text=text)
@@ -69,14 +75,18 @@ class ApiModeration(Moderation):
)
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict:
extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id"))
if self.config is None:
raise ValueError("The config is not set.")
extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", ""))
if not extension:
raise ValueError("API-based Extension not found. Please check it again.")
requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key))
result = requestor.request(extension_point, params)
return result
@staticmethod
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
extension = (
db.session.query(APIBasedExtension)
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)