mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-22 09:16:53 +08:00
chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -44,32 +44,29 @@ class ApiModeration(Moderation):
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
|
||||
if self.config['inputs_config']['enabled']:
|
||||
params = ModerationInputParams(
|
||||
app_id=self.app_id,
|
||||
inputs=inputs,
|
||||
query=query
|
||||
)
|
||||
if self.config["inputs_config"]["enabled"]:
|
||||
params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query)
|
||||
|
||||
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump())
|
||||
return ModerationInputsResult(**result)
|
||||
|
||||
return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
|
||||
return ModerationInputsResult(
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
|
||||
if self.config['outputs_config']['enabled']:
|
||||
params = ModerationOutputParams(
|
||||
app_id=self.app_id,
|
||||
text=text
|
||||
)
|
||||
if self.config["outputs_config"]["enabled"]:
|
||||
params = ModerationOutputParams(app_id=self.app_id, text=text)
|
||||
|
||||
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump())
|
||||
return ModerationOutputsResult(**result)
|
||||
|
||||
return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
|
||||
return ModerationOutputsResult(
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict:
|
||||
extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id"))
|
||||
@@ -80,9 +77,10 @@ class ApiModeration(Moderation):
|
||||
|
||||
@staticmethod
|
||||
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
|
||||
extension = db.session.query(APIBasedExtension).filter(
|
||||
APIBasedExtension.tenant_id == tenant_id,
|
||||
APIBasedExtension.id == api_based_extension_id
|
||||
).first()
|
||||
extension = (
|
||||
db.session.query(APIBasedExtension)
|
||||
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
return extension
|
||||
|
||||
@@ -8,8 +8,8 @@ from core.extension.extensible import Extensible, ExtensionModule
|
||||
|
||||
|
||||
class ModerationAction(Enum):
|
||||
DIRECT_OUTPUT = 'direct_output'
|
||||
OVERRIDDEN = 'overridden'
|
||||
DIRECT_OUTPUT = "direct_output"
|
||||
OVERRIDDEN = "overridden"
|
||||
|
||||
|
||||
class ModerationInputsResult(BaseModel):
|
||||
@@ -31,6 +31,7 @@ class Moderation(Extensible, ABC):
|
||||
"""
|
||||
The base class of moderation.
|
||||
"""
|
||||
|
||||
module: ExtensionModule = ExtensionModule.MODERATION
|
||||
|
||||
def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None:
|
||||
|
||||
@@ -13,13 +13,14 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class InputModeration:
|
||||
def check(
|
||||
self, app_id: str,
|
||||
self,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
app_config: AppConfig,
|
||||
inputs: dict,
|
||||
query: str,
|
||||
message_id: str,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> tuple[bool, dict, str]:
|
||||
"""
|
||||
Process sensitive_word_avoidance.
|
||||
@@ -39,10 +40,7 @@ class InputModeration:
|
||||
moderation_type = sensitive_word_avoidance_config.type
|
||||
|
||||
moderation_factory = ModerationFactory(
|
||||
name=moderation_type,
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
config=sensitive_word_avoidance_config.config
|
||||
name=moderation_type, app_id=app_id, tenant_id=tenant_id, config=sensitive_word_avoidance_config.config
|
||||
)
|
||||
|
||||
with measure_time() as timer:
|
||||
@@ -55,7 +53,7 @@ class InputModeration:
|
||||
message_id=message_id,
|
||||
moderation_result=moderation_result,
|
||||
inputs=inputs,
|
||||
timer=timer
|
||||
timer=timer,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -25,31 +25,35 @@ class KeywordsModeration(Moderation):
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
|
||||
if self.config['inputs_config']['enabled']:
|
||||
preset_response = self.config['inputs_config']['preset_response']
|
||||
if self.config["inputs_config"]["enabled"]:
|
||||
preset_response = self.config["inputs_config"]["preset_response"]
|
||||
|
||||
if query:
|
||||
inputs['query__'] = query
|
||||
inputs["query__"] = query
|
||||
|
||||
# Filter out empty values
|
||||
keywords_list = [keyword for keyword in self.config['keywords'].split('\n') if keyword]
|
||||
keywords_list = [keyword for keyword in self.config["keywords"].split("\n") if keyword]
|
||||
|
||||
flagged = self._is_violated(inputs, keywords_list)
|
||||
|
||||
return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
|
||||
return ModerationInputsResult(
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
|
||||
if self.config['outputs_config']['enabled']:
|
||||
if self.config["outputs_config"]["enabled"]:
|
||||
# Filter out empty values
|
||||
keywords_list = [keyword for keyword in self.config['keywords'].split('\n') if keyword]
|
||||
keywords_list = [keyword for keyword in self.config["keywords"].split("\n") if keyword]
|
||||
|
||||
flagged = self._is_violated({'text': text}, keywords_list)
|
||||
preset_response = self.config['outputs_config']['preset_response']
|
||||
flagged = self._is_violated({"text": text}, keywords_list)
|
||||
preset_response = self.config["outputs_config"]["preset_response"]
|
||||
|
||||
return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
|
||||
return ModerationOutputsResult(
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
|
||||
for value in inputs.values():
|
||||
|
||||
@@ -21,37 +21,36 @@ class OpenAIModeration(Moderation):
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
|
||||
if self.config['inputs_config']['enabled']:
|
||||
preset_response = self.config['inputs_config']['preset_response']
|
||||
if self.config["inputs_config"]["enabled"]:
|
||||
preset_response = self.config["inputs_config"]["preset_response"]
|
||||
|
||||
if query:
|
||||
inputs['query__'] = query
|
||||
inputs["query__"] = query
|
||||
flagged = self._is_violated(inputs)
|
||||
|
||||
return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
|
||||
return ModerationInputsResult(
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
|
||||
if self.config['outputs_config']['enabled']:
|
||||
flagged = self._is_violated({'text': text})
|
||||
preset_response = self.config['outputs_config']['preset_response']
|
||||
if self.config["outputs_config"]["enabled"]:
|
||||
flagged = self._is_violated({"text": text})
|
||||
preset_response = self.config["outputs_config"]["preset_response"]
|
||||
|
||||
return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
|
||||
return ModerationOutputsResult(
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
def _is_violated(self, inputs: dict):
|
||||
text = '\n'.join(str(inputs.values()))
|
||||
text = "\n".join(str(inputs.values()))
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
provider="openai",
|
||||
model_type=ModelType.MODERATION,
|
||||
model="text-moderation-stable"
|
||||
tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="text-moderation-stable"
|
||||
)
|
||||
|
||||
openai_moderation = model_instance.invoke_moderation(
|
||||
text=text
|
||||
)
|
||||
openai_moderation = model_instance.invoke_moderation(text=text)
|
||||
|
||||
return openai_moderation
|
||||
|
||||
@@ -29,7 +29,7 @@ class OutputModeration(BaseModel):
|
||||
|
||||
thread: Optional[threading.Thread] = None
|
||||
thread_running: bool = True
|
||||
buffer: str = ''
|
||||
buffer: str = ""
|
||||
is_final_chunk: bool = False
|
||||
final_output: Optional[str] = None
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
@@ -50,11 +50,7 @@ class OutputModeration(BaseModel):
|
||||
self.buffer = completion
|
||||
self.is_final_chunk = True
|
||||
|
||||
result = self.moderation(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
moderation_buffer=completion
|
||||
)
|
||||
result = self.moderation(tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=completion)
|
||||
|
||||
if not result or not result.flagged:
|
||||
return completion
|
||||
@@ -65,21 +61,19 @@ class OutputModeration(BaseModel):
|
||||
final_output = result.text
|
||||
|
||||
if public_event:
|
||||
self.queue_manager.publish(
|
||||
QueueMessageReplaceEvent(
|
||||
text=final_output
|
||||
),
|
||||
PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE)
|
||||
|
||||
return final_output
|
||||
|
||||
def start_thread(self) -> threading.Thread:
|
||||
buffer_size = dify_config.MODERATION_BUFFER_SIZE
|
||||
thread = threading.Thread(target=self.worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'buffer_size': buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE
|
||||
})
|
||||
thread = threading.Thread(
|
||||
target=self.worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(),
|
||||
"buffer_size": buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE,
|
||||
},
|
||||
)
|
||||
|
||||
thread.start()
|
||||
|
||||
@@ -104,9 +98,7 @@ class OutputModeration(BaseModel):
|
||||
current_length = buffer_length
|
||||
|
||||
result = self.moderation(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
moderation_buffer=moderation_buffer
|
||||
tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=moderation_buffer
|
||||
)
|
||||
|
||||
if not result or not result.flagged:
|
||||
@@ -116,16 +108,11 @@ class OutputModeration(BaseModel):
|
||||
final_output = result.preset_response
|
||||
self.final_output = final_output
|
||||
else:
|
||||
final_output = result.text + self.buffer[len(moderation_buffer):]
|
||||
final_output = result.text + self.buffer[len(moderation_buffer) :]
|
||||
|
||||
# trigger replace event
|
||||
if self.thread_running:
|
||||
self.queue_manager.publish(
|
||||
QueueMessageReplaceEvent(
|
||||
text=final_output
|
||||
),
|
||||
PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE)
|
||||
|
||||
if result.action == ModerationAction.DIRECT_OUTPUT:
|
||||
break
|
||||
@@ -133,10 +120,7 @@ class OutputModeration(BaseModel):
|
||||
def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
|
||||
try:
|
||||
moderation_factory = ModerationFactory(
|
||||
name=self.rule.type,
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
config=self.rule.config
|
||||
name=self.rule.type, app_id=app_id, tenant_id=tenant_id, config=self.rule.config
|
||||
)
|
||||
|
||||
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
|
||||
|
||||
Reference in New Issue
Block a user