chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -8,6 +8,7 @@ class StabilityToolProvider(BuiltinToolProviderController, BaseStabilityAuthoriz
"""
This class is responsible for providing the stability tool.
"""
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
"""
This method is responsible for validating the credentials.

View File

@@ -9,26 +9,23 @@ class BaseStabilityAuthorization:
"""
This method is responsible for validating the credentials.
"""
api_key = credentials.get('api_key', '')
api_key = credentials.get("api_key", "")
if not api_key:
raise ToolProviderCredentialValidationError('API key is required.')
raise ToolProviderCredentialValidationError("API key is required.")
response = requests.get(
URL('https://api.stability.ai') / 'v1' / 'user' / 'account',
URL("https://api.stability.ai") / "v1" / "user" / "account",
headers=self.generate_authorization_headers(credentials),
timeout=(5, 30)
timeout=(5, 30),
)
if not response.ok:
raise ToolProviderCredentialValidationError('Invalid API key.')
raise ToolProviderCredentialValidationError("Invalid API key.")
return True
def generate_authorization_headers(self, credentials: dict) -> dict[str, str]:
"""
This method is responsible for generating the authorization headers.
"""
return {
'Authorization': f'Bearer {credentials.get("api_key", "")}'
}
return {"Authorization": f'Bearer {credentials.get("api_key", "")}'}

View File

@@ -11,10 +11,11 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
"""
This class is responsible for providing the stable diffusion tool.
"""
model_endpoint_map: dict[str, str] = {
'sd3': 'https://api.stability.ai/v2beta/stable-image/generate/sd3',
'sd3-turbo': 'https://api.stability.ai/v2beta/stable-image/generate/sd3',
'core': 'https://api.stability.ai/v2beta/stable-image/generate/core',
"sd3": "https://api.stability.ai/v2beta/stable-image/generate/sd3",
"sd3-turbo": "https://api.stability.ai/v2beta/stable-image/generate/sd3",
"core": "https://api.stability.ai/v2beta/stable-image/generate/core",
}
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
@@ -22,39 +23,34 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
Invoke the tool.
"""
payload = {
'prompt': tool_parameters.get('prompt', ''),
'aspect_ratio': tool_parameters.get('aspect_ratio', '16:9') or tool_parameters.get('aspect_radio', '16:9'),
'mode': 'text-to-image',
'seed': tool_parameters.get('seed', 0),
'output_format': 'png',
"prompt": tool_parameters.get("prompt", ""),
"aspect_ratio": tool_parameters.get("aspect_ratio", "16:9") or tool_parameters.get("aspect_radio", "16:9"),
"mode": "text-to-image",
"seed": tool_parameters.get("seed", 0),
"output_format": "png",
}
model = tool_parameters.get('model', 'core')
model = tool_parameters.get("model", "core")
if model in ['sd3', 'sd3-turbo']:
payload['model'] = tool_parameters.get('model')
if model in ["sd3", "sd3-turbo"]:
payload["model"] = tool_parameters.get("model")
if not model == 'sd3-turbo':
payload['negative_prompt'] = tool_parameters.get('negative_prompt', '')
if not model == "sd3-turbo":
payload["negative_prompt"] = tool_parameters.get("negative_prompt", "")
response = post(
self.model_endpoint_map[tool_parameters.get('model', 'core')],
self.model_endpoint_map[tool_parameters.get("model", "core")],
headers={
'accept': 'image/*',
"accept": "image/*",
**self.generate_authorization_headers(self.runtime.credentials),
},
files={
key: (None, str(value)) for key, value in payload.items()
},
timeout=(5, 30)
files={key: (None, str(value)) for key, value in payload.items()},
timeout=(5, 30),
)
if not response.status_code == 200:
raise Exception(response.text)
return self.create_blob_message(
blob=response.content, meta={
'mime_type': 'image/png'
},
save_as=self.VARIABLE_KEY.IMAGE.value
blob=response.content, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value
)