mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-20 00:06:56 +08:00
chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -8,6 +8,7 @@ class StabilityToolProvider(BuiltinToolProviderController, BaseStabilityAuthoriz
|
||||
"""
|
||||
This class is responsible for providing the stability tool.
|
||||
"""
|
||||
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
This method is responsible for validating the credentials.
|
||||
|
||||
@@ -9,26 +9,23 @@ class BaseStabilityAuthorization:
|
||||
"""
|
||||
This method is responsible for validating the credentials.
|
||||
"""
|
||||
api_key = credentials.get('api_key', '')
|
||||
api_key = credentials.get("api_key", "")
|
||||
if not api_key:
|
||||
raise ToolProviderCredentialValidationError('API key is required.')
|
||||
|
||||
raise ToolProviderCredentialValidationError("API key is required.")
|
||||
|
||||
response = requests.get(
|
||||
URL('https://api.stability.ai') / 'v1' / 'user' / 'account',
|
||||
URL("https://api.stability.ai") / "v1" / "user" / "account",
|
||||
headers=self.generate_authorization_headers(credentials),
|
||||
timeout=(5, 30)
|
||||
timeout=(5, 30),
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
raise ToolProviderCredentialValidationError('Invalid API key.')
|
||||
raise ToolProviderCredentialValidationError("Invalid API key.")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def generate_authorization_headers(self, credentials: dict) -> dict[str, str]:
|
||||
"""
|
||||
This method is responsible for generating the authorization headers.
|
||||
"""
|
||||
return {
|
||||
'Authorization': f'Bearer {credentials.get("api_key", "")}'
|
||||
}
|
||||
|
||||
return {"Authorization": f'Bearer {credentials.get("api_key", "")}'}
|
||||
|
||||
@@ -11,10 +11,11 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
|
||||
"""
|
||||
This class is responsible for providing the stable diffusion tool.
|
||||
"""
|
||||
|
||||
model_endpoint_map: dict[str, str] = {
|
||||
'sd3': 'https://api.stability.ai/v2beta/stable-image/generate/sd3',
|
||||
'sd3-turbo': 'https://api.stability.ai/v2beta/stable-image/generate/sd3',
|
||||
'core': 'https://api.stability.ai/v2beta/stable-image/generate/core',
|
||||
"sd3": "https://api.stability.ai/v2beta/stable-image/generate/sd3",
|
||||
"sd3-turbo": "https://api.stability.ai/v2beta/stable-image/generate/sd3",
|
||||
"core": "https://api.stability.ai/v2beta/stable-image/generate/core",
|
||||
}
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
@@ -22,39 +23,34 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
|
||||
Invoke the tool.
|
||||
"""
|
||||
payload = {
|
||||
'prompt': tool_parameters.get('prompt', ''),
|
||||
'aspect_ratio': tool_parameters.get('aspect_ratio', '16:9') or tool_parameters.get('aspect_radio', '16:9'),
|
||||
'mode': 'text-to-image',
|
||||
'seed': tool_parameters.get('seed', 0),
|
||||
'output_format': 'png',
|
||||
"prompt": tool_parameters.get("prompt", ""),
|
||||
"aspect_ratio": tool_parameters.get("aspect_ratio", "16:9") or tool_parameters.get("aspect_radio", "16:9"),
|
||||
"mode": "text-to-image",
|
||||
"seed": tool_parameters.get("seed", 0),
|
||||
"output_format": "png",
|
||||
}
|
||||
|
||||
model = tool_parameters.get('model', 'core')
|
||||
model = tool_parameters.get("model", "core")
|
||||
|
||||
if model in ['sd3', 'sd3-turbo']:
|
||||
payload['model'] = tool_parameters.get('model')
|
||||
if model in ["sd3", "sd3-turbo"]:
|
||||
payload["model"] = tool_parameters.get("model")
|
||||
|
||||
if not model == 'sd3-turbo':
|
||||
payload['negative_prompt'] = tool_parameters.get('negative_prompt', '')
|
||||
if not model == "sd3-turbo":
|
||||
payload["negative_prompt"] = tool_parameters.get("negative_prompt", "")
|
||||
|
||||
response = post(
|
||||
self.model_endpoint_map[tool_parameters.get('model', 'core')],
|
||||
self.model_endpoint_map[tool_parameters.get("model", "core")],
|
||||
headers={
|
||||
'accept': 'image/*',
|
||||
"accept": "image/*",
|
||||
**self.generate_authorization_headers(self.runtime.credentials),
|
||||
},
|
||||
files={
|
||||
key: (None, str(value)) for key, value in payload.items()
|
||||
},
|
||||
timeout=(5, 30)
|
||||
files={key: (None, str(value)) for key, value in payload.items()},
|
||||
timeout=(5, 30),
|
||||
)
|
||||
|
||||
if not response.status_code == 200:
|
||||
raise Exception(response.text)
|
||||
|
||||
|
||||
return self.create_blob_message(
|
||||
blob=response.content, meta={
|
||||
'mime_type': 'image/png'
|
||||
},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value
|
||||
blob=response.content, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user