mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-23 17:55:46 +08:00
chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
""" Provide the input parameters type for the cogview provider class """
|
||||
"""Provide the input parameters type for the cogview provider class"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
@@ -7,7 +8,8 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
|
||||
|
||||
|
||||
class COGVIEWProvider(BuiltinToolProviderController):
|
||||
""" cogview provider """
|
||||
"""cogview provider"""
|
||||
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
CogView3Tool().fork_tool_runtime(
|
||||
@@ -15,13 +17,12 @@ class COGVIEWProvider(BuiltinToolProviderController):
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"prompt": "一个城市在水晶瓶中欢快生活的场景,水彩画风格,展现出微观与珠宝般的美丽。",
|
||||
"size": "square",
|
||||
"n": 1
|
||||
"n": 1,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e)) from e
|
||||
|
||||
@@ -7,43 +7,42 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class CogView3Tool(BuiltinTool):
|
||||
""" CogView3 Tool """
|
||||
"""CogView3 Tool"""
|
||||
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invoke CogView3 tool
|
||||
"""
|
||||
client = ZhipuAI(
|
||||
base_url=self.runtime.credentials['zhipuai_base_url'],
|
||||
api_key=self.runtime.credentials['zhipuai_api_key'],
|
||||
base_url=self.runtime.credentials["zhipuai_base_url"],
|
||||
api_key=self.runtime.credentials["zhipuai_api_key"],
|
||||
)
|
||||
size_mapping = {
|
||||
'square': '1024x1024',
|
||||
'vertical': '1024x1792',
|
||||
'horizontal': '1792x1024',
|
||||
"square": "1024x1024",
|
||||
"vertical": "1024x1792",
|
||||
"horizontal": "1792x1024",
|
||||
}
|
||||
# prompt
|
||||
prompt = tool_parameters.get('prompt', '')
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
if not prompt:
|
||||
return self.create_text_message('Please input prompt')
|
||||
return self.create_text_message("Please input prompt")
|
||||
# get size
|
||||
size = size_mapping[tool_parameters.get('size', 'square')]
|
||||
size = size_mapping[tool_parameters.get("size", "square")]
|
||||
# get n
|
||||
n = tool_parameters.get('n', 1)
|
||||
n = tool_parameters.get("n", 1)
|
||||
# get quality
|
||||
quality = tool_parameters.get('quality', 'standard')
|
||||
if quality not in ['standard', 'hd']:
|
||||
return self.create_text_message('Invalid quality')
|
||||
quality = tool_parameters.get("quality", "standard")
|
||||
if quality not in ["standard", "hd"]:
|
||||
return self.create_text_message("Invalid quality")
|
||||
# get style
|
||||
style = tool_parameters.get('style', 'vivid')
|
||||
if style not in ['natural', 'vivid']:
|
||||
return self.create_text_message('Invalid style')
|
||||
style = tool_parameters.get("style", "vivid")
|
||||
if style not in ["natural", "vivid"]:
|
||||
return self.create_text_message("Invalid style")
|
||||
# set extra body
|
||||
seed_id = tool_parameters.get('seed_id', self._generate_random_id(8))
|
||||
extra_body = {'seed': seed_id}
|
||||
seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))
|
||||
extra_body = {"seed": seed_id}
|
||||
response = client.images.generations(
|
||||
prompt=prompt,
|
||||
model="cogview-3",
|
||||
@@ -52,18 +51,22 @@ class CogView3Tool(BuiltinTool):
|
||||
extra_body=extra_body,
|
||||
style=style,
|
||||
quality=quality,
|
||||
response_format='b64_json'
|
||||
response_format="b64_json",
|
||||
)
|
||||
result = []
|
||||
for image in response.data:
|
||||
result.append(self.create_image_message(image=image.url))
|
||||
result.append(self.create_json_message({
|
||||
"url": image.url,
|
||||
}))
|
||||
result.append(
|
||||
self.create_json_message(
|
||||
{
|
||||
"url": image.url,
|
||||
}
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _generate_random_id(length=8):
|
||||
characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
|
||||
random_id = ''.join(random.choices(characters, k=length))
|
||||
characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
random_id = "".join(random.choices(characters, k=length))
|
||||
return random_id
|
||||
|
||||
Reference in New Issue
Block a user