mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-11 11:56:53 +08:00
Model Runtime (#1858)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: Garfield Dai <dai.hai@foxmail.com> Co-authored-by: chenhe <guchenhe@gmail.com> Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
@@ -1,18 +1,27 @@
|
||||
import logging
|
||||
import random
|
||||
|
||||
import openai
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.model_providers.providers.hosted import hosted_config, hosted_model_providers
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel
|
||||
from extensions.ext_hosting_provider import hosting_configuration
|
||||
from models.provider import ProviderType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_moderation(model_config: ModelConfigEntity, text: str) -> bool:
|
||||
moderation_config = hosting_configuration.moderation_config
|
||||
if (moderation_config and moderation_config.enabled is True
|
||||
and 'openai' in hosting_configuration.provider_map
|
||||
and hosting_configuration.provider_map['openai'].enabled is True
|
||||
):
|
||||
using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type
|
||||
provider_name = model_config.provider
|
||||
if using_provider_type == ProviderType.SYSTEM \
|
||||
and provider_name in moderation_config.providers:
|
||||
hosting_openai_config = hosting_configuration.provider_map['openai']
|
||||
|
||||
def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
|
||||
if hosted_config.moderation.enabled is True and hosted_model_providers.openai:
|
||||
if model_provider.provider.provider_type == ProviderType.SYSTEM.value \
|
||||
and model_provider.provider_name in hosted_config.moderation.providers:
|
||||
# 2000 text per chunk
|
||||
length = 2000
|
||||
text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
|
||||
@@ -23,14 +32,17 @@ def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
|
||||
text_chunk = random.choice(text_chunks)
|
||||
|
||||
try:
|
||||
moderation_result = openai.Moderation.create(input=text_chunk,
|
||||
api_key=hosted_model_providers.openai.api_key)
|
||||
model_type_instance = OpenAIModerationModel()
|
||||
moderation_result = model_type_instance.invoke(
|
||||
model='text-moderation-stable',
|
||||
credentials=hosting_openai_config.credentials,
|
||||
text=text_chunk
|
||||
)
|
||||
|
||||
if moderation_result is True:
|
||||
return True
|
||||
except Exception as ex:
|
||||
logging.exception(ex)
|
||||
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
|
||||
logger.exception(ex)
|
||||
raise InvokeBadRequestError('Rate limit exceeded, please try again later.')
|
||||
|
||||
for result in moderation_result.results:
|
||||
if result['flagged'] is True:
|
||||
return False
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user