chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -25,25 +25,25 @@ from .variables import (
)
__all__ = [
'IntegerVariable',
'FloatVariable',
'ObjectVariable',
'SecretVariable',
'StringVariable',
'ArrayAnyVariable',
'Variable',
'SegmentType',
'SegmentGroup',
'Segment',
'NoneSegment',
'NoneVariable',
'IntegerSegment',
'FloatSegment',
'ObjectSegment',
'ArrayAnySegment',
'StringSegment',
'ArrayStringVariable',
'ArrayNumberVariable',
'ArrayObjectVariable',
'ArraySegment',
"IntegerVariable",
"FloatVariable",
"ObjectVariable",
"SecretVariable",
"StringVariable",
"ArrayAnyVariable",
"Variable",
"SegmentType",
"SegmentGroup",
"Segment",
"NoneSegment",
"NoneVariable",
"IntegerSegment",
"FloatSegment",
"ObjectSegment",
"ArrayAnySegment",
"StringSegment",
"ArrayStringVariable",
"ArrayNumberVariable",
"ArrayObjectVariable",
"ArraySegment",
]

View File

@@ -28,12 +28,12 @@ from .variables import (
def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
if (value_type := mapping.get('value_type')) is None:
raise VariableError('missing value type')
if not mapping.get('name'):
raise VariableError('missing name')
if (value := mapping.get('value')) is None:
raise VariableError('missing value')
if (value_type := mapping.get("value_type")) is None:
raise VariableError("missing value type")
if not mapping.get("name"):
raise VariableError("missing name")
if (value := mapping.get("value")) is None:
raise VariableError("missing value")
match value_type:
case SegmentType.STRING:
result = StringVariable.model_validate(mapping)
@@ -44,7 +44,7 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
case SegmentType.NUMBER if isinstance(value, float):
result = FloatVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise VariableError(f'invalid number value {value}')
raise VariableError(f"invalid number value {value}")
case SegmentType.OBJECT if isinstance(value, dict):
result = ObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_STRING if isinstance(value, list):
@@ -54,9 +54,9 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
result = ArrayObjectVariable.model_validate(mapping)
case _:
raise VariableError(f'not supported value type {value_type}')
raise VariableError(f"not supported value type {value_type}")
if result.size > dify_config.MAX_VARIABLE_SIZE:
raise VariableError(f'variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}')
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
return result
@@ -73,4 +73,4 @@ def build_segment(value: Any, /) -> Segment:
return ObjectSegment(value=value)
if isinstance(value, list):
return ArrayAnySegment(value=value)
raise ValueError(f'not supported value {value}')
raise ValueError(f"not supported value {value}")

View File

@@ -4,14 +4,14 @@ from core.workflow.entities.variable_pool import VariablePool
from . import SegmentGroup, factory
VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}')
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
def convert_template(*, template: str, variable_pool: VariablePool):
parts = re.split(VARIABLE_PATTERN, template)
segments = []
for part in filter(lambda x: x, parts):
if '.' in part and (value := variable_pool.get(part.split('.'))):
if "." in part and (value := variable_pool.get(part.split("."))):
segments.append(value)
else:
segments.append(factory.build_segment(part))

View File

@@ -8,15 +8,15 @@ class SegmentGroup(Segment):
@property
def text(self):
return ''.join([segment.text for segment in self.value])
return "".join([segment.text for segment in self.value])
@property
def log(self):
return ''.join([segment.log for segment in self.value])
return "".join([segment.log for segment in self.value])
@property
def markdown(self):
return ''.join([segment.markdown for segment in self.value])
return "".join([segment.markdown for segment in self.value])
def to_object(self):
return [segment.to_object() for segment in self.value]

View File

@@ -14,13 +14,13 @@ class Segment(BaseModel):
value_type: SegmentType
value: Any
@field_validator('value_type')
@field_validator("value_type")
def validate_value_type(cls, value):
"""
This validator checks if the provided value is equal to the default value of the 'value_type' field.
If the value is different, a ValueError is raised.
"""
if value != cls.model_fields['value_type'].default:
if value != cls.model_fields["value_type"].default:
raise ValueError("Cannot modify 'value_type'")
return value
@@ -50,15 +50,15 @@ class NoneSegment(Segment):
@property
def text(self) -> str:
return 'null'
return "null"
@property
def log(self) -> str:
return 'null'
return "null"
@property
def markdown(self) -> str:
return 'null'
return "null"
class StringSegment(Segment):
@@ -76,24 +76,21 @@ class IntegerSegment(Segment):
value: int
class ObjectSegment(Segment):
value_type: SegmentType = SegmentType.OBJECT
value: Mapping[str, Any]
@property
def text(self) -> str:
return json.dumps(self.model_dump()['value'], ensure_ascii=False)
return json.dumps(self.model_dump()["value"], ensure_ascii=False)
@property
def log(self) -> str:
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)
@property
def markdown(self) -> str:
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)
class ArraySegment(Segment):
@@ -101,11 +98,11 @@ class ArraySegment(Segment):
def markdown(self) -> str:
items = []
for item in self.value:
if hasattr(item, 'to_markdown'):
if hasattr(item, "to_markdown"):
items.append(item.to_markdown())
else:
items.append(str(item))
return '\n'.join(items)
return "\n".join(items)
class ArrayAnySegment(ArraySegment):
@@ -126,4 +123,3 @@ class ArrayNumberSegment(ArraySegment):
class ArrayObjectSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_OBJECT
value: Sequence[Mapping[str, Any]]

View File

@@ -2,14 +2,14 @@ from enum import Enum
class SegmentType(str, Enum):
NONE = 'none'
NUMBER = 'number'
STRING = 'string'
SECRET = 'secret'
ARRAY_ANY = 'array[any]'
ARRAY_STRING = 'array[string]'
ARRAY_NUMBER = 'array[number]'
ARRAY_OBJECT = 'array[object]'
OBJECT = 'object'
NONE = "none"
NUMBER = "number"
STRING = "string"
SECRET = "secret"
ARRAY_ANY = "array[any]"
ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]"
ARRAY_OBJECT = "array[object]"
OBJECT = "object"
GROUP = 'group'
GROUP = "group"

View File

@@ -23,11 +23,11 @@ class Variable(Segment):
"""
id: str = Field(
default='',
default="",
description="Unique identity for variable. It's only used by environment variables now.",
)
name: str
description: str = Field(default='', description='Description of the variable.')
description: str = Field(default="", description="Description of the variable.")
class StringVariable(StringSegment, Variable):
@@ -62,7 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable):
pass
class SecretVariable(StringVariable):
value_type: SegmentType = SegmentType.SECRET