chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -44,32 +44,29 @@ class ApiModeration(Moderation):
flagged = False
preset_response = ""
if self.config['inputs_config']['enabled']:
params = ModerationInputParams(
app_id=self.app_id,
inputs=inputs,
query=query
)
if self.config["inputs_config"]["enabled"]:
params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query)
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump())
return ModerationInputsResult(**result)
return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
return ModerationInputsResult(
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""
if self.config['outputs_config']['enabled']:
params = ModerationOutputParams(
app_id=self.app_id,
text=text
)
if self.config["outputs_config"]["enabled"]:
params = ModerationOutputParams(app_id=self.app_id, text=text)
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump())
return ModerationOutputsResult(**result)
return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
return ModerationOutputsResult(
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict:
extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id"))
@@ -80,9 +77,10 @@ class ApiModeration(Moderation):
@staticmethod
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
extension = db.session.query(APIBasedExtension).filter(
APIBasedExtension.tenant_id == tenant_id,
APIBasedExtension.id == api_based_extension_id
).first()
extension = (
db.session.query(APIBasedExtension)
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
)
return extension

View File

@@ -8,8 +8,8 @@ from core.extension.extensible import Extensible, ExtensionModule
class ModerationAction(Enum):
DIRECT_OUTPUT = 'direct_output'
OVERRIDDEN = 'overridden'
DIRECT_OUTPUT = "direct_output"
OVERRIDDEN = "overridden"
class ModerationInputsResult(BaseModel):
@@ -31,6 +31,7 @@ class Moderation(Extensible, ABC):
"""
The base class of moderation.
"""
module: ExtensionModule = ExtensionModule.MODERATION
def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None:

View File

@@ -13,13 +13,14 @@ logger = logging.getLogger(__name__)
class InputModeration:
def check(
self, app_id: str,
self,
app_id: str,
tenant_id: str,
app_config: AppConfig,
inputs: dict,
query: str,
message_id: str,
trace_manager: Optional[TraceQueueManager] = None
trace_manager: Optional[TraceQueueManager] = None,
) -> tuple[bool, dict, str]:
"""
Process sensitive_word_avoidance.
@@ -39,10 +40,7 @@ class InputModeration:
moderation_type = sensitive_word_avoidance_config.type
moderation_factory = ModerationFactory(
name=moderation_type,
app_id=app_id,
tenant_id=tenant_id,
config=sensitive_word_avoidance_config.config
name=moderation_type, app_id=app_id, tenant_id=tenant_id, config=sensitive_word_avoidance_config.config
)
with measure_time() as timer:
@@ -55,7 +53,7 @@ class InputModeration:
message_id=message_id,
moderation_result=moderation_result,
inputs=inputs,
timer=timer
timer=timer,
)
)

View File

@@ -25,31 +25,35 @@ class KeywordsModeration(Moderation):
flagged = False
preset_response = ""
if self.config['inputs_config']['enabled']:
preset_response = self.config['inputs_config']['preset_response']
if self.config["inputs_config"]["enabled"]:
preset_response = self.config["inputs_config"]["preset_response"]
if query:
inputs['query__'] = query
inputs["query__"] = query
# Filter out empty values
keywords_list = [keyword for keyword in self.config['keywords'].split('\n') if keyword]
keywords_list = [keyword for keyword in self.config["keywords"].split("\n") if keyword]
flagged = self._is_violated(inputs, keywords_list)
return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
return ModerationInputsResult(
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""
if self.config['outputs_config']['enabled']:
if self.config["outputs_config"]["enabled"]:
# Filter out empty values
keywords_list = [keyword for keyword in self.config['keywords'].split('\n') if keyword]
keywords_list = [keyword for keyword in self.config["keywords"].split("\n") if keyword]
flagged = self._is_violated({'text': text}, keywords_list)
preset_response = self.config['outputs_config']['preset_response']
flagged = self._is_violated({"text": text}, keywords_list)
preset_response = self.config["outputs_config"]["preset_response"]
return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
return ModerationOutputsResult(
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
for value in inputs.values():

View File

@@ -21,37 +21,36 @@ class OpenAIModeration(Moderation):
flagged = False
preset_response = ""
if self.config['inputs_config']['enabled']:
preset_response = self.config['inputs_config']['preset_response']
if self.config["inputs_config"]["enabled"]:
preset_response = self.config["inputs_config"]["preset_response"]
if query:
inputs['query__'] = query
inputs["query__"] = query
flagged = self._is_violated(inputs)
return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
return ModerationInputsResult(
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""
if self.config['outputs_config']['enabled']:
flagged = self._is_violated({'text': text})
preset_response = self.config['outputs_config']['preset_response']
if self.config["outputs_config"]["enabled"]:
flagged = self._is_violated({"text": text})
preset_response = self.config["outputs_config"]["preset_response"]
return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
return ModerationOutputsResult(
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
def _is_violated(self, inputs: dict):
text = '\n'.join(str(inputs.values()))
text = "\n".join(str(inputs.values()))
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id,
provider="openai",
model_type=ModelType.MODERATION,
model="text-moderation-stable"
tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="text-moderation-stable"
)
openai_moderation = model_instance.invoke_moderation(
text=text
)
openai_moderation = model_instance.invoke_moderation(text=text)
return openai_moderation

View File

@@ -29,7 +29,7 @@ class OutputModeration(BaseModel):
thread: Optional[threading.Thread] = None
thread_running: bool = True
buffer: str = ''
buffer: str = ""
is_final_chunk: bool = False
final_output: Optional[str] = None
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -50,11 +50,7 @@ class OutputModeration(BaseModel):
self.buffer = completion
self.is_final_chunk = True
result = self.moderation(
tenant_id=self.tenant_id,
app_id=self.app_id,
moderation_buffer=completion
)
result = self.moderation(tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=completion)
if not result or not result.flagged:
return completion
@@ -65,21 +61,19 @@ class OutputModeration(BaseModel):
final_output = result.text
if public_event:
self.queue_manager.publish(
QueueMessageReplaceEvent(
text=final_output
),
PublishFrom.TASK_PIPELINE
)
self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE)
return final_output
def start_thread(self) -> threading.Thread:
buffer_size = dify_config.MODERATION_BUFFER_SIZE
thread = threading.Thread(target=self.worker, kwargs={
'flask_app': current_app._get_current_object(),
'buffer_size': buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE
})
thread = threading.Thread(
target=self.worker,
kwargs={
"flask_app": current_app._get_current_object(),
"buffer_size": buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE,
},
)
thread.start()
@@ -104,9 +98,7 @@ class OutputModeration(BaseModel):
current_length = buffer_length
result = self.moderation(
tenant_id=self.tenant_id,
app_id=self.app_id,
moderation_buffer=moderation_buffer
tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=moderation_buffer
)
if not result or not result.flagged:
@@ -116,16 +108,11 @@ class OutputModeration(BaseModel):
final_output = result.preset_response
self.final_output = final_output
else:
final_output = result.text + self.buffer[len(moderation_buffer):]
final_output = result.text + self.buffer[len(moderation_buffer) :]
# trigger replace event
if self.thread_running:
self.queue_manager.publish(
QueueMessageReplaceEvent(
text=final_output
),
PublishFrom.TASK_PIPELINE
)
self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE)
if result.action == ModerationAction.DIRECT_OUTPUT:
break
@@ -133,10 +120,7 @@ class OutputModeration(BaseModel):
def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
try:
moderation_factory = ModerationFactory(
name=self.rule.type,
app_id=app_id,
tenant_id=tenant_id,
config=self.rule.config
name=self.rule.type, app_id=app_id, tenant_id=tenant_id, config=self.rule.config
)
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)