mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-15 22:06:52 +08:00
chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -7,15 +7,12 @@ class DIDProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
# Example validation using the D-ID talks tool
|
||||
TalksTool().fork_tool_runtime(
|
||||
runtime={"credentials": credentials}
|
||||
).invoke(
|
||||
user_id='',
|
||||
TalksTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"source_url": "https://www.d-id.com/wp-content/uploads/2023/11/Hero-image-1.png",
|
||||
"text_input": "Hello, welcome to use D-ID tool in Dify",
|
||||
}
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
@@ -12,14 +12,14 @@ logger = logging.getLogger(__name__)
|
||||
class DIDApp:
|
||||
def __init__(self, api_key: str | None = None, base_url: str | None = None):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url or 'https://api.d-id.com'
|
||||
self.base_url = base_url or "https://api.d-id.com"
|
||||
if not self.api_key:
|
||||
raise ValueError('API key is required')
|
||||
raise ValueError("API key is required")
|
||||
|
||||
def _prepare_headers(self, idempotency_key: str | None = None):
|
||||
headers = {'Content-Type': 'application/json', 'Authorization': f'Basic {self.api_key}'}
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Basic {self.api_key}"}
|
||||
if idempotency_key:
|
||||
headers['Idempotency-Key'] = idempotency_key
|
||||
headers["Idempotency-Key"] = idempotency_key
|
||||
return headers
|
||||
|
||||
def _request(
|
||||
@@ -44,44 +44,44 @@ class DIDApp:
|
||||
return None
|
||||
|
||||
def talks(self, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs):
|
||||
endpoint = f'{self.base_url}/talks'
|
||||
endpoint = f"{self.base_url}/talks"
|
||||
headers = self._prepare_headers(idempotency_key)
|
||||
data = kwargs['params']
|
||||
logger.debug(f'Send request to {endpoint=} body={data}')
|
||||
response = self._request('POST', endpoint, data, headers)
|
||||
data = kwargs["params"]
|
||||
logger.debug(f"Send request to {endpoint=} body={data}")
|
||||
response = self._request("POST", endpoint, data, headers)
|
||||
if response is None:
|
||||
raise HTTPError('Failed to initiate D-ID talks after multiple retries')
|
||||
id: str = response['id']
|
||||
raise HTTPError("Failed to initiate D-ID talks after multiple retries")
|
||||
id: str = response["id"]
|
||||
if wait:
|
||||
return self._monitor_job_status(id=id, target='talks', poll_interval=poll_interval)
|
||||
return self._monitor_job_status(id=id, target="talks", poll_interval=poll_interval)
|
||||
return id
|
||||
|
||||
def animations(self, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs):
|
||||
endpoint = f'{self.base_url}/animations'
|
||||
endpoint = f"{self.base_url}/animations"
|
||||
headers = self._prepare_headers(idempotency_key)
|
||||
data = kwargs['params']
|
||||
logger.debug(f'Send request to {endpoint=} body={data}')
|
||||
response = self._request('POST', endpoint, data, headers)
|
||||
data = kwargs["params"]
|
||||
logger.debug(f"Send request to {endpoint=} body={data}")
|
||||
response = self._request("POST", endpoint, data, headers)
|
||||
if response is None:
|
||||
raise HTTPError('Failed to initiate D-ID talks after multiple retries')
|
||||
id: str = response['id']
|
||||
raise HTTPError("Failed to initiate D-ID talks after multiple retries")
|
||||
id: str = response["id"]
|
||||
if wait:
|
||||
return self._monitor_job_status(target='animations', id=id, poll_interval=poll_interval)
|
||||
return self._monitor_job_status(target="animations", id=id, poll_interval=poll_interval)
|
||||
return id
|
||||
|
||||
def check_did_status(self, target: str, id: str):
|
||||
endpoint = f'{self.base_url}/{target}/{id}'
|
||||
endpoint = f"{self.base_url}/{target}/{id}"
|
||||
headers = self._prepare_headers()
|
||||
response = self._request('GET', endpoint, headers=headers)
|
||||
response = self._request("GET", endpoint, headers=headers)
|
||||
if response is None:
|
||||
raise HTTPError(f'Failed to check status for talks {id} after multiple retries')
|
||||
raise HTTPError(f"Failed to check status for talks {id} after multiple retries")
|
||||
return response
|
||||
|
||||
def _monitor_job_status(self, target: str, id: str, poll_interval: int):
|
||||
while True:
|
||||
status = self.check_did_status(target=target, id=id)
|
||||
if status['status'] == 'done':
|
||||
if status["status"] == "done":
|
||||
return status
|
||||
elif status['status'] == 'error' or status['status'] == 'rejected':
|
||||
elif status["status"] == "error" or status["status"] == "rejected":
|
||||
raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error",{}).get("description")}')
|
||||
time.sleep(poll_interval)
|
||||
|
||||
@@ -10,33 +10,33 @@ class AnimationsTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
app = DIDApp(api_key=self.runtime.credentials['did_api_key'], base_url=self.runtime.credentials['base_url'])
|
||||
app = DIDApp(api_key=self.runtime.credentials["did_api_key"], base_url=self.runtime.credentials["base_url"])
|
||||
|
||||
driver_expressions_str = tool_parameters.get('driver_expressions')
|
||||
driver_expressions_str = tool_parameters.get("driver_expressions")
|
||||
driver_expressions = json.loads(driver_expressions_str) if driver_expressions_str else None
|
||||
|
||||
config = {
|
||||
'stitch': tool_parameters.get('stitch', True),
|
||||
'mute': tool_parameters.get('mute'),
|
||||
'result_format': tool_parameters.get('result_format') or 'mp4',
|
||||
"stitch": tool_parameters.get("stitch", True),
|
||||
"mute": tool_parameters.get("mute"),
|
||||
"result_format": tool_parameters.get("result_format") or "mp4",
|
||||
}
|
||||
config = {k: v for k, v in config.items() if v is not None and v != ''}
|
||||
config = {k: v for k, v in config.items() if v is not None and v != ""}
|
||||
|
||||
options = {
|
||||
'source_url': tool_parameters['source_url'],
|
||||
'driver_url': tool_parameters.get('driver_url'),
|
||||
'config': config,
|
||||
"source_url": tool_parameters["source_url"],
|
||||
"driver_url": tool_parameters.get("driver_url"),
|
||||
"config": config,
|
||||
}
|
||||
options = {k: v for k, v in options.items() if v is not None and v != ''}
|
||||
options = {k: v for k, v in options.items() if v is not None and v != ""}
|
||||
|
||||
if not options.get('source_url'):
|
||||
raise ValueError('Source URL is required')
|
||||
if not options.get("source_url"):
|
||||
raise ValueError("Source URL is required")
|
||||
|
||||
if config.get('logo_url'):
|
||||
if not config.get('logo_x'):
|
||||
raise ValueError('Logo X position is required when logo URL is provided')
|
||||
if not config.get('logo_y'):
|
||||
raise ValueError('Logo Y position is required when logo URL is provided')
|
||||
if config.get("logo_url"):
|
||||
if not config.get("logo_x"):
|
||||
raise ValueError("Logo X position is required when logo URL is provided")
|
||||
if not config.get("logo_y"):
|
||||
raise ValueError("Logo Y position is required when logo URL is provided")
|
||||
|
||||
animations_result = app.animations(params=options, wait=True)
|
||||
|
||||
@@ -44,6 +44,6 @@ class AnimationsTool(BuiltinTool):
|
||||
animations_result = json.dumps(animations_result, ensure_ascii=False, indent=4)
|
||||
|
||||
if not animations_result:
|
||||
return self.create_text_message('D-ID animations request failed.')
|
||||
return self.create_text_message("D-ID animations request failed.")
|
||||
|
||||
return self.create_text_message(animations_result)
|
||||
|
||||
@@ -10,49 +10,49 @@ class TalksTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
app = DIDApp(api_key=self.runtime.credentials['did_api_key'], base_url=self.runtime.credentials['base_url'])
|
||||
app = DIDApp(api_key=self.runtime.credentials["did_api_key"], base_url=self.runtime.credentials["base_url"])
|
||||
|
||||
driver_expressions_str = tool_parameters.get('driver_expressions')
|
||||
driver_expressions_str = tool_parameters.get("driver_expressions")
|
||||
driver_expressions = json.loads(driver_expressions_str) if driver_expressions_str else None
|
||||
|
||||
script = {
|
||||
'type': tool_parameters.get('script_type') or 'text',
|
||||
'input': tool_parameters.get('text_input'),
|
||||
'audio_url': tool_parameters.get('audio_url'),
|
||||
'reduce_noise': tool_parameters.get('audio_reduce_noise', False),
|
||||
"type": tool_parameters.get("script_type") or "text",
|
||||
"input": tool_parameters.get("text_input"),
|
||||
"audio_url": tool_parameters.get("audio_url"),
|
||||
"reduce_noise": tool_parameters.get("audio_reduce_noise", False),
|
||||
}
|
||||
script = {k: v for k, v in script.items() if v is not None and v != ''}
|
||||
script = {k: v for k, v in script.items() if v is not None and v != ""}
|
||||
config = {
|
||||
'stitch': tool_parameters.get('stitch', True),
|
||||
'sharpen': tool_parameters.get('sharpen'),
|
||||
'fluent': tool_parameters.get('fluent'),
|
||||
'result_format': tool_parameters.get('result_format') or 'mp4',
|
||||
'pad_audio': tool_parameters.get('pad_audio'),
|
||||
'driver_expressions': driver_expressions,
|
||||
"stitch": tool_parameters.get("stitch", True),
|
||||
"sharpen": tool_parameters.get("sharpen"),
|
||||
"fluent": tool_parameters.get("fluent"),
|
||||
"result_format": tool_parameters.get("result_format") or "mp4",
|
||||
"pad_audio": tool_parameters.get("pad_audio"),
|
||||
"driver_expressions": driver_expressions,
|
||||
}
|
||||
config = {k: v for k, v in config.items() if v is not None and v != ''}
|
||||
config = {k: v for k, v in config.items() if v is not None and v != ""}
|
||||
|
||||
options = {
|
||||
'source_url': tool_parameters['source_url'],
|
||||
'driver_url': tool_parameters.get('driver_url'),
|
||||
'script': script,
|
||||
'config': config,
|
||||
"source_url": tool_parameters["source_url"],
|
||||
"driver_url": tool_parameters.get("driver_url"),
|
||||
"script": script,
|
||||
"config": config,
|
||||
}
|
||||
options = {k: v for k, v in options.items() if v is not None and v != ''}
|
||||
options = {k: v for k, v in options.items() if v is not None and v != ""}
|
||||
|
||||
if not options.get('source_url'):
|
||||
raise ValueError('Source URL is required')
|
||||
if not options.get("source_url"):
|
||||
raise ValueError("Source URL is required")
|
||||
|
||||
if script.get('type') == 'audio':
|
||||
script.pop('input', None)
|
||||
if not script.get('audio_url'):
|
||||
raise ValueError('Audio URL is required for audio script type')
|
||||
if script.get("type") == "audio":
|
||||
script.pop("input", None)
|
||||
if not script.get("audio_url"):
|
||||
raise ValueError("Audio URL is required for audio script type")
|
||||
|
||||
if script.get('type') == 'text':
|
||||
script.pop('audio_url', None)
|
||||
script.pop('reduce_noise', None)
|
||||
if not script.get('input'):
|
||||
raise ValueError('Text input is required for text script type')
|
||||
if script.get("type") == "text":
|
||||
script.pop("audio_url", None)
|
||||
script.pop("reduce_noise", None)
|
||||
if not script.get("input"):
|
||||
raise ValueError("Text input is required for text script type")
|
||||
|
||||
talks_result = app.talks(params=options, wait=True)
|
||||
|
||||
@@ -60,6 +60,6 @@ class TalksTool(BuiltinTool):
|
||||
talks_result = json.dumps(talks_result, ensure_ascii=False, indent=4)
|
||||
|
||||
if not talks_result:
|
||||
return self.create_text_message('D-ID talks request failed.')
|
||||
return self.create_text_message("D-ID talks request failed.")
|
||||
|
||||
return self.create_text_message(talks_result)
|
||||
|
||||
Reference in New Issue
Block a user