chore(api/core): apply ruff reformatting (#7624)

Author: Bowen Liang
Date: 2024-09-10 17:00:20 +08:00
Committed by: GitHub
Parent: 178730266d
Commit: 2cf1187b32
724 changed files with 21180 additions and 21123 deletions


@@ -15,9 +15,11 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
 logger = logging.getLogger(__name__)
 
+
 class CodeExecutionException(Exception):
     pass
 
+
 class CodeExecutionResponse(BaseModel):
     class Data(BaseModel):
         stdout: Optional[str] = None
@@ -29,9 +31,9 @@ class CodeExecutionResponse(BaseModel):
 class CodeLanguage(str, Enum):
-    PYTHON3 = 'python3'
-    JINJA2 = 'jinja2'
-    JAVASCRIPT = 'javascript'
+    PYTHON3 = "python3"
+    JINJA2 = "jinja2"
+    JAVASCRIPT = "javascript"
 
 
 class CodeExecutor:
@@ -45,63 +47,65 @@ class CodeExecutor:
     }
 
     code_language_to_running_language = {
-        CodeLanguage.JAVASCRIPT: 'nodejs',
+        CodeLanguage.JAVASCRIPT: "nodejs",
         CodeLanguage.JINJA2: CodeLanguage.PYTHON3,
         CodeLanguage.PYTHON3: CodeLanguage.PYTHON3,
     }
 
-    supported_dependencies_languages: set[CodeLanguage] = {
-        CodeLanguage.PYTHON3
-    }
+    supported_dependencies_languages: set[CodeLanguage] = {CodeLanguage.PYTHON3}
 
     @classmethod
-    def execute_code(cls,
-                     language: CodeLanguage,
-                     preload: str,
-                     code: str) -> str:
+    def execute_code(cls, language: CodeLanguage, preload: str, code: str) -> str:
         """
         Execute code
         :param language: code language
         :param code: code
         :return:
         """
-        url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / 'v1' / 'sandbox' / 'run'
+        url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / "v1" / "sandbox" / "run"
 
-        headers = {
-            'X-Api-Key': dify_config.CODE_EXECUTION_API_KEY
-        }
+        headers = {"X-Api-Key": dify_config.CODE_EXECUTION_API_KEY}
 
         data = {
-            'language': cls.code_language_to_running_language.get(language),
-            'code': code,
-            'preload': preload,
-            'enable_network': True
+            "language": cls.code_language_to_running_language.get(language),
+            "code": code,
+            "preload": preload,
+            "enable_network": True,
         }
 
         try:
-            response = post(str(url), json=data, headers=headers,
-                            timeout=Timeout(
-                                connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT,
-                                read=dify_config.CODE_EXECUTION_READ_TIMEOUT,
-                                write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT,
-                                pool=None))
+            response = post(
+                str(url),
+                json=data,
+                headers=headers,
+                timeout=Timeout(
+                    connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT,
+                    read=dify_config.CODE_EXECUTION_READ_TIMEOUT,
+                    write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT,
+                    pool=None,
+                ),
+            )
             if response.status_code == 503:
-                raise CodeExecutionException('Code execution service is unavailable')
+                raise CodeExecutionException("Code execution service is unavailable")
             elif response.status_code != 200:
-                raise Exception(f'Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running')
+                raise Exception(
+                    f"Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running"
+                )
         except CodeExecutionException as e:
             raise e
         except Exception as e:
-            raise CodeExecutionException('Failed to execute code, which is likely a network issue,'
-                                         ' please check if the sandbox service is running.'
-                                         f' ( Error: {str(e)} )')
+            raise CodeExecutionException(
+                "Failed to execute code, which is likely a network issue,"
+                " please check if the sandbox service is running."
+                f" ( Error: {str(e)} )"
+            )
 
         try:
             response = response.json()
         except:
-            raise CodeExecutionException('Failed to parse response')
+            raise CodeExecutionException("Failed to parse response")
 
-        if (code := response.get('code')) != 0:
+        if (code := response.get("code")) != 0:
             raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}")
 
         response = CodeExecutionResponse(**response)
@@ -109,7 +113,7 @@ class CodeExecutor:
         if response.data.error:
             raise CodeExecutionException(response.data.error)
 
-        return response.data.stdout or ''
+        return response.data.stdout or ""
 
     @classmethod
     def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict) -> dict:
@@ -122,7 +126,7 @@ class CodeExecutor:
"""
template_transformer = cls.code_template_transformers.get(language)
if not template_transformer:
raise CodeExecutionException(f'Unsupported language {language}')
raise CodeExecutionException(f"Unsupported language {language}")
runner, preload = template_transformer.transform_caller(code, inputs)
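For context on the API touched above, here is a minimal usage sketch of CodeExecutor.execute_code as its reformatted signature appears in this diff; the module path is assumed from the core.helper.code_executor package named in the hunk header and may differ in the actual tree.

# Usage sketch only; the import path is an assumption.
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage

try:
    stdout = CodeExecutor.execute_code(
        language=CodeLanguage.PYTHON3,
        preload="",  # no preload script for this example
        code='print("hello from the sandbox")',
    )
    print(stdout)
except CodeExecutionException as err:
    # Raised when the sandbox is unavailable, the request fails, or the response carries a non-zero code.
    print(f"execution failed: {err}")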


@@ -26,23 +26,9 @@ class CodeNodeProvider(BaseModel):
         return {
             "type": "code",
             "config": {
-                "variables": [
-                    {
-                        "variable": "arg1",
-                        "value_selector": []
-                    },
-                    {
-                        "variable": "arg2",
-                        "value_selector": []
-                    }
-                ],
+                "variables": [{"variable": "arg1", "value_selector": []}, {"variable": "arg2", "value_selector": []}],
                 "code_language": cls.get_language(),
                 "code": cls.get_default_code(),
-                "outputs": {
-                    "result": {
-                        "type": "string",
-                        "children": None
-                    }
-                }
-            }
+                "outputs": {"result": {"type": "string", "children": None}},
+            },
         }


@@ -18,4 +18,5 @@ class JavascriptCodeProvider(CodeNodeProvider):
                     result: arg1 + arg2
                 }
             }
-            """)
+            """
+        )


@@ -21,5 +21,6 @@ class NodeJsTemplateTransformer(TemplateTransformer):
             var output_json = JSON.stringify(output_obj)
             var result = `<<RESULT>>${{output_json}}<<RESULT>>`
             console.log(result)
-            """)
+            """
+        )
 
         return runner_script


@@ -10,8 +10,6 @@ class Jinja2Formatter:
         :param inputs: inputs
         :return:
         """
-        result = CodeExecutor.execute_workflow_code_template(
-            language=CodeLanguage.JINJA2, code=template, inputs=inputs
-        )
+        result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs)
 
-        return result['result']
+        return result["result"]


@@ -11,9 +11,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
         :param response: response
         :return:
         """
-        return {
-            'result': cls.extract_result_str_from_response(response)
-        }
+        return {"result": cls.extract_result_str_from_response(response)}
 
     @classmethod
     def get_runner_script(cls) -> str:


@@ -17,4 +17,5 @@ class Python3CodeProvider(CodeNodeProvider):
                 return {
                     "result": arg1 + arg2,
                 }
-            """)
+            """
+        )


@@ -5,9 +5,9 @@ from base64 import b64encode
 class TemplateTransformer(ABC):
-    _code_placeholder: str = '{{code}}'
-    _inputs_placeholder: str = '{{inputs}}'
-    _result_tag: str = '<<RESULT>>'
+    _code_placeholder: str = "{{code}}"
+    _inputs_placeholder: str = "{{inputs}}"
+    _result_tag: str = "<<RESULT>>"
 
     @classmethod
     def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]:
@@ -24,9 +24,9 @@ class TemplateTransformer(ABC):
     @classmethod
     def extract_result_str_from_response(cls, response: str) -> str:
-        result = re.search(rf'{cls._result_tag}(.*){cls._result_tag}', response, re.DOTALL)
+        result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL)
         if not result:
-            raise ValueError('Failed to parse result')
+            raise ValueError("Failed to parse result")
         result = result.group(1)
         return result
@@ -50,7 +50,7 @@ class TemplateTransformer(ABC):
     @classmethod
     def serialize_inputs(cls, inputs: dict) -> str:
         inputs_json_str = json.dumps(inputs, ensure_ascii=False).encode()
-        input_base64_encoded = b64encode(inputs_json_str).decode('utf-8')
+        input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
         return input_base64_encoded
 
     @classmethod
@@ -67,4 +67,4 @@ class TemplateTransformer(ABC):
"""
Get preload script
"""
return ''
return ""


@@ -8,14 +8,15 @@ def obfuscated_token(token: str):
     if not token:
         return token
     if len(token) <= 8:
-        return '*' * 20
-    return token[:6] + '*' * 12 + token[-2:]
+        return "*" * 20
+    return token[:6] + "*" * 12 + token[-2:]
 
 
 def encrypt_token(tenant_id: str, token: str):
     from models.account import Tenant
 
     if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
-        raise ValueError(f'Tenant with id {tenant_id} not found')
+        raise ValueError(f"Tenant with id {tenant_id} not found")
 
     encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
     return base64.b64encode(encrypted_token).decode()
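A worked example of the masking rule in obfuscated_token above; the token values are illustrative only.

def obfuscated_token(token: str):
    # Same rule as shown in the diff: short tokens are fully masked,
    # longer ones keep the first 6 and last 2 characters.
    if not token:
        return token
    if len(token) <= 8:
        return "*" * 20
    return token[:6] + "*" * 12 + token[-2:]

print(obfuscated_token("sk-1234567890abcdef"))  # -> sk-123************ef
print(obfuscated_token("short"))                # -> ********************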


@@ -25,7 +25,7 @@ class ProviderCredentialsCache:
         cached_provider_credentials = redis_client.get(self.cache_key)
         if cached_provider_credentials:
             try:
-                cached_provider_credentials = cached_provider_credentials.decode('utf-8')
+                cached_provider_credentials = cached_provider_credentials.decode("utf-8")
                 cached_provider_credentials = json.loads(cached_provider_credentials)
             except JSONDecodeError:
                 return None


@@ -12,19 +12,20 @@ logger = logging.getLogger(__name__)
 def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) -> bool:
     moderation_config = hosting_configuration.moderation_config
 
-    if (moderation_config and moderation_config.enabled is True
-            and 'openai' in hosting_configuration.provider_map
-            and hosting_configuration.provider_map['openai'].enabled is True
+    if (
+        moderation_config
+        and moderation_config.enabled is True
+        and "openai" in hosting_configuration.provider_map
+        and hosting_configuration.provider_map["openai"].enabled is True
     ):
         using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type
         provider_name = model_config.provider
-        if using_provider_type == ProviderType.SYSTEM \
-                and provider_name in moderation_config.providers:
-            hosting_openai_config = hosting_configuration.provider_map['openai']
+        if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers:
+            hosting_openai_config = hosting_configuration.provider_map["openai"]
 
             # 2000 text per chunk
             length = 2000
-            text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
+            text_chunks = [text[i : i + length] for i in range(0, len(text), length)]
 
             if len(text_chunks) == 0:
                 return True
@@ -34,15 +35,13 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str)
             try:
                 model_type_instance = OpenAIModerationModel()
                 moderation_result = model_type_instance.invoke(
-                    model='text-moderation-stable',
-                    credentials=hosting_openai_config.credentials,
-                    text=text_chunk
+                    model="text-moderation-stable", credentials=hosting_openai_config.credentials, text=text_chunk
                 )
 
                 if moderation_result is True:
                     return True
             except Exception as ex:
                 logger.exception(ex)
-                raise InvokeBadRequestError('Rate limit exceeded, please try again later.')
+                raise InvokeBadRequestError("Rate limit exceeded, please try again later.")
 
     return False


@@ -37,8 +37,9 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]
"""
Get all the subclasses of the parent type from the module
"""
classes = [x for _, x in vars(mod).items()
if isinstance(x, type) and x != parent_type and issubclass(x, parent_type)]
classes = [
x for _, x in vars(mod).items() if isinstance(x, type) and x != parent_type and issubclass(x, parent_type)
]
return classes
@@ -56,6 +57,6 @@ def load_single_subclass_from_source(
         case 1:
             return subclasses[0]
         case 0:
-            raise Exception(f'Missing subclass of {parent_type.__name__} in {script_path}')
+            raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path}")
         case _:
-            raise Exception(f'Multiple subclasses of {parent_type.__name__} in {script_path}')
+            raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path}")


@@ -73,10 +73,10 @@ def pin_position_map(original_position_map: dict[str, int], pin_list: list[str])
 def is_filtered(
-        include_set: set[str],
-        exclude_set: set[str],
-        data: Any,
-        name_func: Callable[[Any], str],
+    include_set: set[str],
+    exclude_set: set[str],
+    data: Any,
+    name_func: Callable[[Any], str],
 ) -> bool:
     """
     Check if the object should be filtered out.
@@ -102,9 +102,9 @@ def is_filtered(
 def sort_by_position_map(
-        position_map: dict[str, int],
-        data: list[Any],
-        name_func: Callable[[Any], str],
+    position_map: dict[str, int],
+    data: list[Any],
+    name_func: Callable[[Any], str],
 ) -> list[Any]:
     """
     Sort the objects by the position map.
@@ -117,13 +117,13 @@ def sort_by_position_map(
     if not position_map or not data:
         return data
 
-    return sorted(data, key=lambda x: position_map.get(name_func(x), float('inf')))
+    return sorted(data, key=lambda x: position_map.get(name_func(x), float("inf")))
 
 
 def sort_to_dict_by_position_map(
-        position_map: dict[str, int],
-        data: list[Any],
-        name_func: Callable[[Any], str],
+    position_map: dict[str, int],
+    data: list[Any],
+    name_func: Callable[[Any], str],
 ) -> OrderedDict[str, Any]:
     """
     Sort the objects into a ordered dict by the position map.


@@ -1,31 +1,34 @@
"""
Proxy requests to avoid SSRF
"""
import logging
import os
import time
import httpx
SSRF_PROXY_ALL_URL = os.getenv('SSRF_PROXY_ALL_URL', '')
SSRF_PROXY_HTTP_URL = os.getenv('SSRF_PROXY_HTTP_URL', '')
SSRF_PROXY_HTTPS_URL = os.getenv('SSRF_PROXY_HTTPS_URL', '')
SSRF_DEFAULT_MAX_RETRIES = int(os.getenv('SSRF_DEFAULT_MAX_RETRIES', '3'))
SSRF_PROXY_ALL_URL = os.getenv("SSRF_PROXY_ALL_URL", "")
SSRF_PROXY_HTTP_URL = os.getenv("SSRF_PROXY_HTTP_URL", "")
SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "")
SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3"))
proxies = {
'http://': SSRF_PROXY_HTTP_URL,
'https://': SSRF_PROXY_HTTPS_URL
} if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None
proxies = (
{"http://": SSRF_PROXY_HTTP_URL, "https://": SSRF_PROXY_HTTPS_URL}
if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL
else None
)
BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
if "follow_redirects" not in kwargs:
kwargs["follow_redirects"] = allow_redirects
retries = 0
while retries <= max_retries:
try:
@@ -52,24 +55,24 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
 def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
-    return make_request('GET', url, max_retries=max_retries, **kwargs)
+    return make_request("GET", url, max_retries=max_retries, **kwargs)
 
 
 def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
-    return make_request('POST', url, max_retries=max_retries, **kwargs)
+    return make_request("POST", url, max_retries=max_retries, **kwargs)
 
 
 def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
-    return make_request('PUT', url, max_retries=max_retries, **kwargs)
+    return make_request("PUT", url, max_retries=max_retries, **kwargs)
 
 
 def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
-    return make_request('PATCH', url, max_retries=max_retries, **kwargs)
+    return make_request("PATCH", url, max_retries=max_retries, **kwargs)
 
 
 def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
-    return make_request('DELETE', url, max_retries=max_retries, **kwargs)
+    return make_request("DELETE", url, max_retries=max_retries, **kwargs)
 
 
 def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
-    return make_request('HEAD', url, max_retries=max_retries, **kwargs)
+    return make_request("HEAD", url, max_retries=max_retries, **kwargs)


@@ -9,14 +9,11 @@ from extensions.ext_redis import redis_client
 class ToolParameterCacheType(Enum):
     PARAMETER = "tool_parameter"
 
 
 class ToolParameterCache:
-    def __init__(self,
-                 tenant_id: str,
-                 provider: str,
-                 tool_name: str,
-                 cache_type: ToolParameterCacheType,
-                 identity_id: str
-    ):
+    def __init__(
+        self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str
+    ):
         self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}:identity_id:{identity_id}"
 
     def get(self) -> Optional[dict]:
@@ -28,7 +25,7 @@ class ToolParameterCache:
         cached_tool_parameter = redis_client.get(self.cache_key)
         if cached_tool_parameter:
             try:
-                cached_tool_parameter = cached_tool_parameter.decode('utf-8')
+                cached_tool_parameter = cached_tool_parameter.decode("utf-8")
                 cached_tool_parameter = json.loads(cached_tool_parameter)
             except JSONDecodeError:
                 return None
@@ -52,4 +49,4 @@ class ToolParameterCache:
         :return:
         """
-        redis_client.delete(self.cache_key)
\ No newline at end of file
+        redis_client.delete(self.cache_key)


@@ -9,6 +9,7 @@ from extensions.ext_redis import redis_client
 class ToolProviderCredentialsCacheType(Enum):
     PROVIDER = "tool_provider"
 
+
 class ToolProviderCredentialsCache:
     def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
         self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
@@ -22,7 +23,7 @@ class ToolProviderCredentialsCache:
         cached_provider_credentials = redis_client.get(self.cache_key)
         if cached_provider_credentials:
             try:
-                cached_provider_credentials = cached_provider_credentials.decode('utf-8')
+                cached_provider_credentials = cached_provider_credentials.decode("utf-8")
                 cached_provider_credentials = json.loads(cached_provider_credentials)
             except JSONDecodeError:
                 return None
@@ -46,4 +47,4 @@ class ToolProviderCredentialsCache:
         :return:
         """
-        redis_client.delete(self.cache_key)
\ No newline at end of file
+        redis_client.delete(self.cache_key)