chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -1,4 +1,5 @@
""" Provide the input parameters type for the cogview provider class """
"""Provide the input parameters type for the cogview provider class"""
from typing import Any
from core.tools.errors import ToolProviderCredentialValidationError
@@ -7,7 +8,8 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
class COGVIEWProvider(BuiltinToolProviderController):
""" cogview provider """
"""cogview provider"""
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
CogView3Tool().fork_tool_runtime(
@@ -15,13 +17,12 @@ class COGVIEWProvider(BuiltinToolProviderController):
"credentials": credentials,
}
).invoke(
user_id='',
user_id="",
tool_parameters={
"prompt": "一个城市在水晶瓶中欢快生活的场景,水彩画风格,展现出微观与珠宝般的美丽。",
"size": "square",
"n": 1
"n": 1,
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e)) from e

View File

@@ -7,43 +7,42 @@ from core.tools.tool.builtin_tool import BuiltinTool
class CogView3Tool(BuiltinTool):
""" CogView3 Tool """
"""CogView3 Tool"""
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
Invoke CogView3 tool
"""
client = ZhipuAI(
base_url=self.runtime.credentials['zhipuai_base_url'],
api_key=self.runtime.credentials['zhipuai_api_key'],
base_url=self.runtime.credentials["zhipuai_base_url"],
api_key=self.runtime.credentials["zhipuai_api_key"],
)
size_mapping = {
'square': '1024x1024',
'vertical': '1024x1792',
'horizontal': '1792x1024',
"square": "1024x1024",
"vertical": "1024x1792",
"horizontal": "1792x1024",
}
# prompt
prompt = tool_parameters.get('prompt', '')
prompt = tool_parameters.get("prompt", "")
if not prompt:
return self.create_text_message('Please input prompt')
return self.create_text_message("Please input prompt")
# get size
size = size_mapping[tool_parameters.get('size', 'square')]
size = size_mapping[tool_parameters.get("size", "square")]
# get n
n = tool_parameters.get('n', 1)
n = tool_parameters.get("n", 1)
# get quality
quality = tool_parameters.get('quality', 'standard')
if quality not in ['standard', 'hd']:
return self.create_text_message('Invalid quality')
quality = tool_parameters.get("quality", "standard")
if quality not in ["standard", "hd"]:
return self.create_text_message("Invalid quality")
# get style
style = tool_parameters.get('style', 'vivid')
if style not in ['natural', 'vivid']:
return self.create_text_message('Invalid style')
style = tool_parameters.get("style", "vivid")
if style not in ["natural", "vivid"]:
return self.create_text_message("Invalid style")
# set extra body
seed_id = tool_parameters.get('seed_id', self._generate_random_id(8))
extra_body = {'seed': seed_id}
seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))
extra_body = {"seed": seed_id}
response = client.images.generations(
prompt=prompt,
model="cogview-3",
@@ -52,18 +51,22 @@ class CogView3Tool(BuiltinTool):
extra_body=extra_body,
style=style,
quality=quality,
response_format='b64_json'
response_format="b64_json",
)
result = []
for image in response.data:
result.append(self.create_image_message(image=image.url))
result.append(self.create_json_message({
"url": image.url,
}))
result.append(
self.create_json_message(
{
"url": image.url,
}
)
)
return result
@staticmethod
def _generate_random_id(length=8):
characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
random_id = ''.join(random.choices(characters, k=length))
characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
random_id = "".join(random.choices(characters, k=length))
return random_id