mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-24 18:23:07 +08:00
Feat: support azure openai for temporary (#101)
This commit is contained in:
@@ -4,9 +4,14 @@ from langchain.callbacks import CallbackManager
|
||||
from langchain.llms.fake import FakeListLLM
|
||||
|
||||
from core.constant import llm_constant
|
||||
from core.llm.error import ProviderTokenNotInitError
|
||||
from core.llm.provider.base import BaseProvider
|
||||
from core.llm.provider.llm_provider_service import LLMProviderService
|
||||
from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
|
||||
from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
|
||||
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
|
||||
from core.llm.streamable_open_ai import StreamableOpenAI
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class LLMBuilder:
|
||||
@@ -31,16 +36,23 @@ class LLMBuilder:
|
||||
if model_name == 'fake':
|
||||
return FakeListLLM(responses=[])
|
||||
|
||||
provider = cls.get_default_provider(tenant_id)
|
||||
|
||||
mode = cls.get_mode_by_model(model_name)
|
||||
if mode == 'chat':
|
||||
# llm_cls = StreamableAzureChatOpenAI
|
||||
llm_cls = StreamableChatOpenAI
|
||||
if provider == 'openai':
|
||||
llm_cls = StreamableChatOpenAI
|
||||
else:
|
||||
llm_cls = StreamableAzureChatOpenAI
|
||||
elif mode == 'completion':
|
||||
llm_cls = StreamableOpenAI
|
||||
if provider == 'openai':
|
||||
llm_cls = StreamableOpenAI
|
||||
else:
|
||||
llm_cls = StreamableAzureOpenAI
|
||||
else:
|
||||
raise ValueError(f"model name {model_name} is not supported.")
|
||||
|
||||
model_credentials = cls.get_model_credentials(tenant_id, model_name)
|
||||
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
|
||||
|
||||
return llm_cls(
|
||||
model_name=model_name,
|
||||
@@ -86,18 +98,31 @@ class LLMBuilder:
|
||||
raise ValueError(f"model name {model_name} is not supported.")
|
||||
|
||||
@classmethod
|
||||
def get_model_credentials(cls, tenant_id: str, model_name: str) -> dict:
|
||||
def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict:
|
||||
"""
|
||||
Returns the API credentials for the given tenant_id and model_name, based on the model's provider.
|
||||
Raises an exception if the model_name is not found or if the provider is not found.
|
||||
"""
|
||||
if not model_name:
|
||||
raise Exception('model name not found')
|
||||
#
|
||||
# if model_name not in llm_constant.models:
|
||||
# raise Exception('model {} not found'.format(model_name))
|
||||
|
||||
if model_name not in llm_constant.models:
|
||||
raise Exception('model {} not found'.format(model_name))
|
||||
|
||||
model_provider = llm_constant.models[model_name]
|
||||
# model_provider = llm_constant.models[model_name]
|
||||
|
||||
provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
|
||||
return provider_service.get_credentials(model_name)
|
||||
|
||||
@classmethod
|
||||
def get_default_provider(cls, tenant_id: str) -> str:
|
||||
provider = BaseProvider.get_valid_provider(tenant_id)
|
||||
if not provider:
|
||||
raise ProviderTokenNotInitError()
|
||||
|
||||
if provider.provider_type == ProviderType.SYSTEM.value:
|
||||
provider_name = 'openai'
|
||||
else:
|
||||
provider_name = provider.provider_name
|
||||
|
||||
return provider_name
|
||||
|
||||
@@ -36,10 +36,9 @@ class AzureProvider(BaseProvider):
|
||||
"""
|
||||
Returns the API credentials for Azure OpenAI as a dictionary.
|
||||
"""
|
||||
encrypted_config = self.get_provider_api_key(model_id=model_id)
|
||||
config = json.loads(encrypted_config)
|
||||
config = self.get_provider_api_key(model_id=model_id)
|
||||
config['openai_api_type'] = 'azure'
|
||||
config['deployment_name'] = model_id
|
||||
config['deployment_name'] = model_id.replace('.', '')
|
||||
return config
|
||||
|
||||
def get_provider_name(self):
|
||||
@@ -51,12 +50,11 @@ class AzureProvider(BaseProvider):
|
||||
"""
|
||||
try:
|
||||
config = self.get_provider_api_key()
|
||||
config = json.loads(config)
|
||||
except:
|
||||
config = {
|
||||
'openai_api_type': 'azure',
|
||||
'openai_api_version': '2023-03-15-preview',
|
||||
'openai_api_base': 'https://foo.microsoft.com/bar',
|
||||
'openai_api_base': 'https://<your-domain-prefix>.openai.azure.com/',
|
||||
'openai_api_key': ''
|
||||
}
|
||||
|
||||
@@ -65,7 +63,7 @@ class AzureProvider(BaseProvider):
|
||||
config = {
|
||||
'openai_api_type': 'azure',
|
||||
'openai_api_version': '2023-03-15-preview',
|
||||
'openai_api_base': 'https://foo.microsoft.com/bar',
|
||||
'openai_api_base': 'https://<your-domain-prefix>.openai.azure.com/',
|
||||
'openai_api_key': ''
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ class BaseProvider(ABC):
|
||||
def __init__(self, tenant_id: str):
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> str:
|
||||
def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str | dict]:
|
||||
"""
|
||||
Returns the decrypted API key for the given tenant_id and provider_name.
|
||||
If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
|
||||
@@ -43,23 +43,35 @@ class BaseProvider(ABC):
|
||||
Returns the Provider instance for the given tenant_id and provider_name.
|
||||
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
|
||||
"""
|
||||
providers = db.session.query(Provider).filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_name == self.get_provider_name().value
|
||||
).order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()
|
||||
return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom)
|
||||
|
||||
@classmethod
|
||||
def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[Provider]:
|
||||
"""
|
||||
Returns the Provider instance for the given tenant_id and provider_name.
|
||||
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
|
||||
"""
|
||||
query = db.session.query(Provider).filter(
|
||||
Provider.tenant_id == tenant_id
|
||||
)
|
||||
|
||||
if provider_name:
|
||||
query = query.filter(Provider.provider_name == provider_name)
|
||||
|
||||
providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()
|
||||
|
||||
custom_provider = None
|
||||
system_provider = None
|
||||
|
||||
for provider in providers:
|
||||
if provider.provider_type == ProviderType.CUSTOM.value:
|
||||
if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
|
||||
custom_provider = provider
|
||||
elif provider.provider_type == ProviderType.SYSTEM.value:
|
||||
elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
|
||||
system_provider = provider
|
||||
|
||||
if custom_provider and custom_provider.is_valid and custom_provider.encrypted_config:
|
||||
if custom_provider:
|
||||
return custom_provider
|
||||
elif system_provider and system_provider.is_valid:
|
||||
elif system_provider:
|
||||
return system_provider
|
||||
else:
|
||||
return None
|
||||
@@ -80,7 +92,7 @@ class BaseProvider(ABC):
|
||||
try:
|
||||
config = self.get_provider_api_key()
|
||||
except:
|
||||
config = 'THIS-IS-A-MOCK-TOKEN'
|
||||
config = ''
|
||||
|
||||
if obfuscated:
|
||||
return self.obfuscated_token(config)
|
||||
|
||||
@@ -1,12 +1,50 @@
|
||||
import requests
|
||||
from langchain.schema import BaseMessage, ChatResult, LLMResult
|
||||
from langchain.chat_models import AzureChatOpenAI
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
|
||||
|
||||
|
||||
class StreamableAzureChatOpenAI(AzureChatOpenAI):
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
try:
|
||||
import openai
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import openai python package. "
|
||||
"Please install it with `pip install openai`."
|
||||
)
|
||||
try:
|
||||
values["client"] = openai.ChatCompletion
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
"`openai` has no `ChatCompletion` attribute, this is likely "
|
||||
"due to an old version of the openai package. Try upgrading it "
|
||||
"with `pip install --upgrade openai`."
|
||||
)
|
||||
if values["n"] < 1:
|
||||
raise ValueError("n must be at least 1.")
|
||||
if values["n"] > 1 and values["streaming"]:
|
||||
raise ValueError("n must be 1 when streaming.")
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
return {
|
||||
**super()._default_params,
|
||||
"engine": self.deployment_name,
|
||||
"api_type": self.openai_api_type,
|
||||
"api_base": self.openai_api_base,
|
||||
"api_version": self.openai_api_version,
|
||||
"api_key": self.openai_api_key,
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}
|
||||
|
||||
def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
|
||||
"""Get the number of tokens in a list of messages.
|
||||
|
||||
|
||||
64
api/core/llm/streamable_azure_open_ai.py
Normal file
64
api/core/llm/streamable_azure_open_ai.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import os
|
||||
|
||||
from langchain.llms import AzureOpenAI
|
||||
from langchain.schema import LLMResult
|
||||
from typing import Optional, List, Dict, Mapping, Any
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
|
||||
|
||||
|
||||
class StreamableAzureOpenAI(AzureOpenAI):
|
||||
openai_api_type: str = "azure"
|
||||
openai_api_version: str = ""
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
try:
|
||||
import openai
|
||||
|
||||
values["client"] = openai.Completion
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import openai python package. "
|
||||
"Please install it with `pip install openai`."
|
||||
)
|
||||
if values["streaming"] and values["n"] > 1:
|
||||
raise ValueError("Cannot stream results when n > 1.")
|
||||
if values["streaming"] and values["best_of"] > 1:
|
||||
raise ValueError("Cannot stream results when best_of > 1.")
|
||||
return values
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> Dict[str, Any]:
|
||||
return {**super()._invocation_params, **{
|
||||
"api_type": self.openai_api_type,
|
||||
"api_base": self.openai_api_base,
|
||||
"api_version": self.openai_api_version,
|
||||
"api_key": self.openai_api_key,
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
return {**super()._identifying_params, **{
|
||||
"api_type": self.openai_api_type,
|
||||
"api_base": self.openai_api_base,
|
||||
"api_version": self.openai_api_version,
|
||||
"api_key": self.openai_api_key,
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}}
|
||||
|
||||
@handle_llm_exceptions
|
||||
def generate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
return super().generate(prompts, stop)
|
||||
|
||||
@handle_llm_exceptions_async
|
||||
async def agenerate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
return await super().agenerate(prompts, stop)
|
||||
@@ -1,12 +1,52 @@
|
||||
import os
|
||||
|
||||
from langchain.schema import BaseMessage, ChatResult, LLMResult
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
|
||||
|
||||
|
||||
class StreamableChatOpenAI(ChatOpenAI):
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
try:
|
||||
import openai
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import openai python package. "
|
||||
"Please install it with `pip install openai`."
|
||||
)
|
||||
try:
|
||||
values["client"] = openai.ChatCompletion
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
"`openai` has no `ChatCompletion` attribute, this is likely "
|
||||
"due to an old version of the openai package. Try upgrading it "
|
||||
"with `pip install --upgrade openai`."
|
||||
)
|
||||
if values["n"] < 1:
|
||||
raise ValueError("n must be at least 1.")
|
||||
if values["n"] > 1 and values["streaming"]:
|
||||
raise ValueError("n must be 1 when streaming.")
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
return {
|
||||
**super()._default_params,
|
||||
"api_type": 'openai',
|
||||
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
|
||||
"api_version": None,
|
||||
"api_key": self.openai_api_key,
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}
|
||||
|
||||
def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
|
||||
"""Get the number of tokens in a list of messages.
|
||||
|
||||
|
||||
@@ -1,12 +1,54 @@
|
||||
import os
|
||||
|
||||
from langchain.schema import LLMResult
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Dict, Any, Mapping
|
||||
from langchain import OpenAI
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
|
||||
|
||||
|
||||
class StreamableOpenAI(OpenAI):
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
try:
|
||||
import openai
|
||||
|
||||
values["client"] = openai.Completion
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import openai python package. "
|
||||
"Please install it with `pip install openai`."
|
||||
)
|
||||
if values["streaming"] and values["n"] > 1:
|
||||
raise ValueError("Cannot stream results when n > 1.")
|
||||
if values["streaming"] and values["best_of"] > 1:
|
||||
raise ValueError("Cannot stream results when best_of > 1.")
|
||||
return values
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> Dict[str, Any]:
|
||||
return {**super()._invocation_params, **{
|
||||
"api_type": 'openai',
|
||||
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
|
||||
"api_version": None,
|
||||
"api_key": self.openai_api_key,
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
return {**super()._identifying_params, **{
|
||||
"api_type": 'openai',
|
||||
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
|
||||
"api_version": None,
|
||||
"api_key": self.openai_api_key,
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}}
|
||||
|
||||
|
||||
@handle_llm_exceptions
|
||||
def generate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
|
||||
Reference in New Issue
Block a user