mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-24 18:23:07 +08:00
Initial commit
This commit is contained in:
23
api/core/llm/provider/anthropic_provider.py
Normal file
23
api/core/llm/provider/anthropic_provider.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.llm.provider.base import BaseProvider
|
||||
from models.provider import ProviderName
|
||||
|
||||
|
||||
class AnthropicProvider(BaseProvider):
|
||||
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
|
||||
credentials = self.get_credentials(model_id)
|
||||
# todo
|
||||
return []
|
||||
|
||||
def get_credentials(self, model_id: Optional[str] = None) -> dict:
|
||||
"""
|
||||
Returns the API credentials for Azure OpenAI as a dictionary, for the given tenant_id.
|
||||
The dictionary contains keys: azure_api_type, azure_api_version, azure_api_base, and azure_api_key.
|
||||
"""
|
||||
return {
|
||||
'anthropic_api_key': self.get_provider_api_key(model_id=model_id)
|
||||
}
|
||||
|
||||
def get_provider_name(self):
|
||||
return ProviderName.ANTHROPIC
|
||||
105
api/core/llm/provider/azure_provider.py
Normal file
105
api/core/llm/provider/azure_provider.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import json
|
||||
from typing import Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.llm.provider.base import BaseProvider
|
||||
from models.provider import ProviderName
|
||||
|
||||
|
||||
class AzureProvider(BaseProvider):
|
||||
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
|
||||
credentials = self.get_credentials(model_id)
|
||||
url = "{}/openai/deployments?api-version={}".format(
|
||||
credentials.get('openai_api_base'),
|
||||
credentials.get('openai_api_version')
|
||||
)
|
||||
|
||||
headers = {
|
||||
"api-key": credentials.get('openai_api_key'),
|
||||
"content-type": "application/json; charset=utf-8"
|
||||
}
|
||||
|
||||
response = requests.get(url, headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
return [{
|
||||
'id': deployment['id'],
|
||||
'name': '{} ({})'.format(deployment['id'], deployment['model'])
|
||||
} for deployment in result['data'] if deployment['status'] == 'succeeded']
|
||||
else:
|
||||
# TODO: optimize in future
|
||||
raise Exception('Failed to get deployments from Azure OpenAI. Status code: {}'.format(response.status_code))
|
||||
|
||||
def get_credentials(self, model_id: Optional[str] = None) -> dict:
|
||||
"""
|
||||
Returns the API credentials for Azure OpenAI as a dictionary.
|
||||
"""
|
||||
encrypted_config = self.get_provider_api_key(model_id=model_id)
|
||||
config = json.loads(encrypted_config)
|
||||
config['openai_api_type'] = 'azure'
|
||||
config['deployment_name'] = model_id
|
||||
return config
|
||||
|
||||
def get_provider_name(self):
|
||||
return ProviderName.AZURE_OPENAI
|
||||
|
||||
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
|
||||
"""
|
||||
Returns the provider configs.
|
||||
"""
|
||||
try:
|
||||
config = self.get_provider_api_key()
|
||||
config = json.loads(config)
|
||||
except:
|
||||
config = {
|
||||
'openai_api_type': 'azure',
|
||||
'openai_api_version': '2023-03-15-preview',
|
||||
'openai_api_base': 'https://foo.microsoft.com/bar',
|
||||
'openai_api_key': ''
|
||||
}
|
||||
|
||||
if obfuscated:
|
||||
if not config.get('openai_api_key'):
|
||||
config = {
|
||||
'openai_api_type': 'azure',
|
||||
'openai_api_version': '2023-03-15-preview',
|
||||
'openai_api_base': 'https://foo.microsoft.com/bar',
|
||||
'openai_api_key': ''
|
||||
}
|
||||
|
||||
config['openai_api_key'] = self.obfuscated_token(config.get('openai_api_key'))
|
||||
return config
|
||||
|
||||
return config
|
||||
|
||||
def get_token_type(self):
|
||||
# TODO: change to dict when implemented
|
||||
return lambda value: value
|
||||
|
||||
def config_validate(self, config: Union[dict | str]):
|
||||
"""
|
||||
Validates the given config.
|
||||
"""
|
||||
# TODO: implement
|
||||
pass
|
||||
|
||||
def get_encrypted_token(self, config: Union[dict | str]):
|
||||
"""
|
||||
Returns the encrypted token.
|
||||
"""
|
||||
return json.dumps({
|
||||
'openai_api_type': 'azure',
|
||||
'openai_api_version': '2023-03-15-preview',
|
||||
'openai_api_base': config['openai_api_base'],
|
||||
'openai_api_key': self.encrypt_token(config['openai_api_key'])
|
||||
})
|
||||
|
||||
def get_decrypted_token(self, token: str):
|
||||
"""
|
||||
Returns the decrypted token.
|
||||
"""
|
||||
config = json.loads(token)
|
||||
config['openai_api_key'] = self.decrypt_token(config['openai_api_key'])
|
||||
return config
|
||||
124
api/core/llm/provider/base.py
Normal file
124
api/core/llm/provider/base.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import base64
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
|
||||
from core import hosted_llm_credentials
|
||||
from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
|
||||
from extensions.ext_database import db
|
||||
from libs import rsa
|
||||
from models.account import Tenant
|
||||
from models.provider import Provider, ProviderType, ProviderName
|
||||
|
||||
|
||||
class BaseProvider(ABC):
|
||||
def __init__(self, tenant_id: str):
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> str:
|
||||
"""
|
||||
Returns the decrypted API key for the given tenant_id and provider_name.
|
||||
If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
|
||||
If the provider is not found or not valid, raises a ProviderTokenNotInitError.
|
||||
"""
|
||||
provider = self.get_provider(prefer_custom)
|
||||
if not provider:
|
||||
raise ProviderTokenNotInitError()
|
||||
|
||||
if provider.provider_type == ProviderType.SYSTEM.value:
|
||||
quota_used = provider.quota_used if provider.quota_used is not None else 0
|
||||
quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
|
||||
|
||||
if model_id and model_id == 'gpt-4':
|
||||
raise ModelCurrentlyNotSupportError()
|
||||
|
||||
if quota_used >= quota_limit:
|
||||
raise QuotaExceededError()
|
||||
|
||||
return self.get_hosted_credentials()
|
||||
else:
|
||||
return self.get_decrypted_token(provider.encrypted_config)
|
||||
|
||||
def get_provider(self, prefer_custom: bool) -> Optional[Provider]:
|
||||
"""
|
||||
Returns the Provider instance for the given tenant_id and provider_name.
|
||||
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
|
||||
"""
|
||||
providers = db.session.query(Provider).filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_name == self.get_provider_name().value
|
||||
).order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()
|
||||
|
||||
custom_provider = None
|
||||
system_provider = None
|
||||
|
||||
for provider in providers:
|
||||
if provider.provider_type == ProviderType.CUSTOM.value:
|
||||
custom_provider = provider
|
||||
elif provider.provider_type == ProviderType.SYSTEM.value:
|
||||
system_provider = provider
|
||||
|
||||
if custom_provider and custom_provider.is_valid and custom_provider.encrypted_config:
|
||||
return custom_provider
|
||||
elif system_provider and system_provider.is_valid:
|
||||
return system_provider
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_hosted_credentials(self) -> str:
|
||||
if self.get_provider_name() != ProviderName.OPENAI:
|
||||
raise ProviderTokenNotInitError()
|
||||
|
||||
if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
|
||||
raise ProviderTokenNotInitError()
|
||||
|
||||
return hosted_llm_credentials.openai.api_key
|
||||
|
||||
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
|
||||
"""
|
||||
Returns the provider configs.
|
||||
"""
|
||||
try:
|
||||
config = self.get_provider_api_key()
|
||||
except:
|
||||
config = 'THIS-IS-A-MOCK-TOKEN'
|
||||
|
||||
if obfuscated:
|
||||
return self.obfuscated_token(config)
|
||||
|
||||
return config
|
||||
|
||||
def obfuscated_token(self, token: str):
|
||||
return token[:6] + '*' * (len(token) - 8) + token[-2:]
|
||||
|
||||
def get_token_type(self):
|
||||
return str
|
||||
|
||||
def get_encrypted_token(self, config: Union[dict | str]):
|
||||
return self.encrypt_token(config)
|
||||
|
||||
def get_decrypted_token(self, token: str):
|
||||
return self.decrypt_token(token)
|
||||
|
||||
def encrypt_token(self, token):
|
||||
tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
|
||||
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
|
||||
return base64.b64encode(encrypted_token).decode()
|
||||
|
||||
def decrypt_token(self, token):
|
||||
return rsa.decrypt(base64.b64decode(token), self.tenant_id)
|
||||
|
||||
@abstractmethod
|
||||
def get_provider_name(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_credentials(self, model_id: Optional[str] = None) -> dict:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def config_validate(self, config: str):
|
||||
raise NotImplementedError
|
||||
2
api/core/llm/provider/errors.py
Normal file
2
api/core/llm/provider/errors.py
Normal file
@@ -0,0 +1,2 @@
|
||||
class ValidateFailedError(Exception):
|
||||
description = "Provider Validate failed"
|
||||
22
api/core/llm/provider/huggingface_provider.py
Normal file
22
api/core/llm/provider/huggingface_provider.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.llm.provider.base import BaseProvider
|
||||
from models.provider import ProviderName
|
||||
|
||||
|
||||
class HuggingfaceProvider(BaseProvider):
|
||||
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
|
||||
credentials = self.get_credentials(model_id)
|
||||
# todo
|
||||
return []
|
||||
|
||||
def get_credentials(self, model_id: Optional[str] = None) -> dict:
|
||||
"""
|
||||
Returns the API credentials for Huggingface as a dictionary, for the given tenant_id.
|
||||
"""
|
||||
return {
|
||||
'huggingface_api_key': self.get_provider_api_key(model_id=model_id)
|
||||
}
|
||||
|
||||
def get_provider_name(self):
|
||||
return ProviderName.HUGGINGFACEHUB
|
||||
53
api/core/llm/provider/llm_provider_service.py
Normal file
53
api/core/llm/provider/llm_provider_service.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.llm.provider.anthropic_provider import AnthropicProvider
|
||||
from core.llm.provider.azure_provider import AzureProvider
|
||||
from core.llm.provider.base import BaseProvider
|
||||
from core.llm.provider.huggingface_provider import HuggingfaceProvider
|
||||
from core.llm.provider.openai_provider import OpenAIProvider
|
||||
from models.provider import Provider
|
||||
|
||||
|
||||
class LLMProviderService:
|
||||
|
||||
def __init__(self, tenant_id: str, provider_name: str):
|
||||
self.provider = self.init_provider(tenant_id, provider_name)
|
||||
|
||||
def init_provider(self, tenant_id: str, provider_name: str) -> BaseProvider:
|
||||
if provider_name == 'openai':
|
||||
return OpenAIProvider(tenant_id)
|
||||
elif provider_name == 'azure_openai':
|
||||
return AzureProvider(tenant_id)
|
||||
elif provider_name == 'anthropic':
|
||||
return AnthropicProvider(tenant_id)
|
||||
elif provider_name == 'huggingface':
|
||||
return HuggingfaceProvider(tenant_id)
|
||||
else:
|
||||
raise Exception('provider {} not found'.format(provider_name))
|
||||
|
||||
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
|
||||
return self.provider.get_models(model_id)
|
||||
|
||||
def get_credentials(self, model_id: Optional[str] = None) -> dict:
|
||||
return self.provider.get_credentials(model_id)
|
||||
|
||||
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
|
||||
return self.provider.get_provider_configs(obfuscated)
|
||||
|
||||
def get_provider_db_record(self, prefer_custom: bool = False) -> Optional[Provider]:
|
||||
return self.provider.get_provider(prefer_custom)
|
||||
|
||||
def config_validate(self, config: Union[dict | str]):
|
||||
"""
|
||||
Validates the given config.
|
||||
|
||||
:param config:
|
||||
:raises: ValidateFailedError
|
||||
"""
|
||||
return self.provider.config_validate(config)
|
||||
|
||||
def get_token_type(self):
|
||||
return self.provider.get_token_type()
|
||||
|
||||
def get_encrypted_token(self, config: Union[dict | str]):
|
||||
return self.provider.get_encrypted_token(config)
|
||||
44
api/core/llm/provider/openai_provider.py
Normal file
44
api/core/llm/provider/openai_provider.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
import openai
|
||||
from openai.error import AuthenticationError, OpenAIError
|
||||
|
||||
from core.llm.moderation import Moderation
|
||||
from core.llm.provider.base import BaseProvider
|
||||
from core.llm.provider.errors import ValidateFailedError
|
||||
from models.provider import ProviderName
|
||||
|
||||
|
||||
class OpenAIProvider(BaseProvider):
|
||||
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
|
||||
credentials = self.get_credentials(model_id)
|
||||
response = openai.Model.list(**credentials)
|
||||
|
||||
return [{
|
||||
'id': model['id'],
|
||||
'name': model['id'],
|
||||
} for model in response['data']]
|
||||
|
||||
def get_credentials(self, model_id: Optional[str] = None) -> dict:
|
||||
"""
|
||||
Returns the credentials for the given tenant_id and provider_name.
|
||||
"""
|
||||
return {
|
||||
'openai_api_key': self.get_provider_api_key(model_id=model_id)
|
||||
}
|
||||
|
||||
def get_provider_name(self):
|
||||
return ProviderName.OPENAI
|
||||
|
||||
def config_validate(self, config: Union[dict | str]):
|
||||
"""
|
||||
Validates the given config.
|
||||
"""
|
||||
try:
|
||||
Moderation(self.get_provider_name().value, config).moderate('test')
|
||||
except (AuthenticationError, OpenAIError) as ex:
|
||||
raise ValidateFailedError(str(ex))
|
||||
except Exception as ex:
|
||||
logging.exception('OpenAI config validation failed')
|
||||
raise ex
|
||||
Reference in New Issue
Block a user