mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-09 10:56:52 +08:00
feat: universal chat in explore (#649)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
This commit is contained in:
63
api/core/tool/provider/base.py
Normal file
63
api/core/tool/provider/base.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import base64
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs import rsa
|
||||
from models.account import Tenant
|
||||
from models.tool import ToolProvider, ToolProviderName
|
||||
|
||||
|
||||
class BaseToolProvider(ABC):
|
||||
def __init__(self, tenant_id: str):
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
@abstractmethod
|
||||
def get_provider_name(self) -> ToolProviderName:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def credentials_to_func_kwargs(self) -> Optional[dict]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def credentials_validate(self, credentials: dict):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_provider(self, must_enabled: bool = False) -> Optional[ToolProvider]:
|
||||
"""
|
||||
Returns the Provider instance for the given tenant_id and tool_name.
|
||||
"""
|
||||
query = db.session.query(ToolProvider).filter(
|
||||
ToolProvider.tenant_id == self.tenant_id,
|
||||
ToolProvider.tool_name == self.get_provider_name().value
|
||||
)
|
||||
|
||||
if must_enabled:
|
||||
query = query.filter(ToolProvider.is_enabled == True)
|
||||
|
||||
return query.first()
|
||||
|
||||
def encrypt_token(self, token) -> str:
|
||||
tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
|
||||
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
|
||||
return base64.b64encode(encrypted_token).decode()
|
||||
|
||||
def decrypt_token(self, token: str, obfuscated: bool = False) -> str:
|
||||
token = rsa.decrypt(base64.b64decode(token), self.tenant_id)
|
||||
|
||||
if obfuscated:
|
||||
return self._obfuscated_token(token)
|
||||
|
||||
return token
|
||||
|
||||
def _obfuscated_token(self, token: str) -> str:
|
||||
return token[:6] + '*' * (len(token) - 8) + token[-2:]
|
||||
2
api/core/tool/provider/errors.py
Normal file
2
api/core/tool/provider/errors.py
Normal file
@@ -0,0 +1,2 @@
|
||||
class ToolValidateFailedError(Exception):
|
||||
description = "Tool Provider Validate failed"
|
||||
77
api/core/tool/provider/serpapi_provider.py
Normal file
77
api/core/tool/provider/serpapi_provider.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.tool.provider.base import BaseToolProvider
|
||||
from core.tool.provider.errors import ToolValidateFailedError
|
||||
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper
|
||||
from models.tool import ToolProviderName
|
||||
|
||||
|
||||
class SerpAPIToolProvider(BaseToolProvider):
|
||||
def get_provider_name(self) -> ToolProviderName:
|
||||
"""
|
||||
Returns the name of the provider.
|
||||
|
||||
:return:
|
||||
"""
|
||||
return ToolProviderName.SERPAPI
|
||||
|
||||
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
|
||||
"""
|
||||
Returns the credentials for SerpAPI as a dictionary.
|
||||
|
||||
:param obfuscated: obfuscate credentials if True
|
||||
:return:
|
||||
"""
|
||||
tool_provider = self.get_provider(must_enabled=True)
|
||||
if not tool_provider:
|
||||
return None
|
||||
|
||||
credentials = tool_provider.credentials
|
||||
if not credentials:
|
||||
return None
|
||||
|
||||
if credentials.get('api_key'):
|
||||
credentials['api_key'] = self.decrypt_token(credentials.get('api_key'), obfuscated)
|
||||
|
||||
return credentials
|
||||
|
||||
def credentials_to_func_kwargs(self) -> Optional[dict]:
|
||||
"""
|
||||
Returns the credentials function kwargs as a dictionary.
|
||||
|
||||
:return:
|
||||
"""
|
||||
credentials = self.get_credentials()
|
||||
if not credentials:
|
||||
return None
|
||||
|
||||
return {
|
||||
'serpapi_api_key': credentials.get('api_key')
|
||||
}
|
||||
|
||||
def credentials_validate(self, credentials: dict):
|
||||
"""
|
||||
Validates the given credentials.
|
||||
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
if 'api_key' not in credentials or not credentials.get('api_key'):
|
||||
raise ToolValidateFailedError("SerpAPI api_key is required.")
|
||||
|
||||
api_key = credentials.get('api_key')
|
||||
|
||||
try:
|
||||
OptimizedSerpAPIWrapper(serpapi_api_key=api_key).run(query='test')
|
||||
except Exception as e:
|
||||
raise ToolValidateFailedError("SerpAPI api_key is invalid. {}".format(e))
|
||||
|
||||
def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
|
||||
"""
|
||||
Encrypts the given credentials.
|
||||
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
credentials['api_key'] = self.encrypt_token(credentials.get('api_key'))
|
||||
return credentials
|
||||
43
api/core/tool/provider/tool_provider_service.py
Normal file
43
api/core/tool/provider/tool_provider_service.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.tool.provider.base import BaseToolProvider
|
||||
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
|
||||
|
||||
|
||||
class ToolProviderService:
|
||||
|
||||
def __init__(self, tenant_id: str, provider_name: str):
|
||||
self.provider = self._init_provider(tenant_id, provider_name)
|
||||
|
||||
def _init_provider(self, tenant_id: str, provider_name: str) -> BaseToolProvider:
|
||||
if provider_name == 'serpapi':
|
||||
return SerpAPIToolProvider(tenant_id)
|
||||
else:
|
||||
raise Exception('tool provider {} not found'.format(provider_name))
|
||||
|
||||
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
|
||||
"""
|
||||
Returns the credentials for Tool as a dictionary.
|
||||
|
||||
:param obfuscated:
|
||||
:return:
|
||||
"""
|
||||
return self.provider.get_credentials(obfuscated)
|
||||
|
||||
def credentials_validate(self, credentials: dict):
|
||||
"""
|
||||
Validates the given credentials.
|
||||
|
||||
:param credentials:
|
||||
:raises: ValidateFailedError
|
||||
"""
|
||||
return self.provider.credentials_validate(credentials)
|
||||
|
||||
def encrypt_credentials(self, credentials: dict):
|
||||
"""
|
||||
Encrypts the given credentials.
|
||||
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
return self.provider.encrypt_credentials(credentials)
|
||||
Reference in New Issue
Block a user