feat: add LocalAI local embedding model support (#1021)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Author: takatost
Date: 2023-08-29 22:22:02 +08:00
Committed by: GitHub
Parent: b5953039de
Commit: 417c19577a
24 changed files with 1144 additions and 7 deletions

View File

@@ -63,6 +63,9 @@ class ModelProviderFactory:
elif provider_name == 'openllm':
from core.model_providers.providers.openllm_provider import OpenLLMProvider
return OpenLLMProvider
+elif provider_name == 'localai':
+from core.model_providers.providers.localai_provider import LocalAIProvider
+return LocalAIProvider
else:
raise NotImplementedError
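
For illustration only, a minimal sketch of exercising this lookup; the module path and the method name get_model_provider_class are assumptions, since neither is visible in this hunk.

from core.model_providers.model_provider_factory import ModelProviderFactory

# Hypothetical call: the enclosing classmethod is not shown in the hunk, so
# get_model_provider_class is assumed here purely for illustration.
provider_class = ModelProviderFactory.get_model_provider_class('localai')
print(provider_class.__name__)  # expected: LocalAIProvider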

View File

@@ -0,0 +1,29 @@
from langchain.embeddings import LocalAIEmbeddings
from replicate.exceptions import ModelError, ReplicateError
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
class LocalAIEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
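# The OpenAI-compatible client requires a non-empty api key, so a placeholder
# ("1") is passed below; LocalAI is addressed purely via the server_url
# credential and typically ignores the key.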
client = LocalAIEmbeddings(
model=name,
openai_api_key="1",
openai_api_base=credentials['server_url'],
)
super().__init__(model_provider, client, name)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ModelError, ReplicateError)):
return LLMBadRequestError(f"LocalAI embedding: {str(ex)}")
else:
return ex
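
For context, a standalone sketch of the LangChain client configured by this wrapper; the endpoint URL and model name below are placeholders, not values taken from this commit.

from langchain.embeddings import LocalAIEmbeddings

# Placeholder endpoint and model name; the wrapper above supplies these from
# the stored provider credentials instead.
embeddings = LocalAIEmbeddings(
    model='bert-embeddings',
    openai_api_key='1',
    openai_api_base='http://localhost:8080',
)

vector = embeddings.embed_query('ping')
print(len(vector))  # dimensionality of the returned embedding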

View File

@@ -0,0 +1,131 @@
import logging
from typing import List, Optional, Any
import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, get_buffer_string
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class LocalAIModel(BaseLLM):
def __init__(self, model_provider: BaseModelProvider,
name: str,
model_kwargs: ModelKwargs,
streaming: bool = False,
callbacks: Callbacks = None):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
if credentials['completion_type'] == 'chat_completion':
self.model_mode = ModelMode.CHAT
else:
self.model_mode = ModelMode.COMPLETION
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
if self.model_mode == ModelMode.COMPLETION:
client = EnhanceOpenAI(
model_name=self.name,
streaming=self.streaming,
callbacks=self.callbacks,
request_timeout=60,
openai_api_key="1",
openai_api_base=self.credentials['server_url'] + '/v1',
**provider_model_kwargs
)
else:
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p')
}
client = EnhanceChatOpenAI(
model_name=self.name,
temperature=provider_model_kwargs.get('temperature'),
max_tokens=provider_model_kwargs.get('max_tokens'),
model_kwargs=extra_model_kwargs,
streaming=self.streaming,
callbacks=self.callbacks,
request_timeout=60,
openai_api_key="1",
openai_api_base=self.credentials['server_url'] + '/v1'
)
return client
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
Run prediction from the given prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
Get the number of tokens in the prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
if isinstance(prompts, str):
return self._client.get_num_tokens(prompts)
else:
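# heuristic for chat prompts: count the tokens of each message rendered via
# get_buffer_string, then discount one token per message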
return max(sum([self._client.get_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
if self.model_mode == ModelMode.COMPLETION:
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
else:
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p')
}
self.client.temperature = provider_model_kwargs.get('temperature')
self.client.max_tokens = provider_model_kwargs.get('max_tokens')
self.client.model_kwargs = extra_model_kwargs
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to LocalAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to LocalAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("LocalAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
return LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:
return ex
@classmethod
def support_streaming(cls):
return True
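
To make the completion_type branching above concrete, a hedged sketch of the credentials dict this model reads; the URL is a placeholder, and note that '/v1' is appended to server_url when the OpenAI-compatible client is built.

# Example credentials as consumed by LocalAIModel (values are placeholders):
credentials = {
    'server_url': 'http://localhost:8080',   # '/v1' is appended for the client
    'completion_type': 'chat_completion',    # or 'completion'
}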

View File

@@ -0,0 +1,164 @@
import json
from typing import Type
from langchain.embeddings import LocalAIEmbeddings
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule
from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
from models.provider import ProviderType
class LocalAIProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'localai'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = LocalAIModel
elif model_type == ModelType.EMBEDDINGS:
model_class = LocalAIEmbedding
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=0.7),
top_p=KwargRule[float](min=0, max=1, default=1),
max_tokens=KwargRule[int](min=10, max=4097, default=16),
)
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
Check that the model credentials are valid; raise CredentialsValidateFailedError otherwise.
:param model_name:
:param model_type:
:param credentials:
"""
if 'server_url' not in credentials:
raise CredentialsValidateFailedError('LocalAI Server URL must be provided.')
try:
if model_type == ModelType.EMBEDDINGS:
model = LocalAIEmbeddings(
model=model_name,
openai_api_key='1',
openai_api_base=credentials['server_url']
)
model.embed_query("ping")
else:
if ('completion_type' not in credentials
or credentials['completion_type'] not in ['completion', 'chat_completion']):
raise CredentialsValidateFailedError('LocalAI Completion Type must be provided.')
if credentials['completion_type'] == 'chat_completion':
model = EnhanceChatOpenAI(
model_name=model_name,
openai_api_key='1',
openai_api_base=credentials['server_url'] + '/v1',
max_tokens=10,
request_timeout=60,
)
model([HumanMessage(content='ping')])
else:
model = EnhanceOpenAI(
model_name=model_name,
openai_api_key='1',
openai_api_base=credentials['server_url'] + '/v1',
max_tokens=10,
request_timeout=60,
)
model('ping')
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
Encrypt model credentials before saving.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
return credentials
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
if self.provider.provider_type != ProviderType.CUSTOM.value:
raise NotImplementedError
provider_model = self._get_provider_model(model_name, model_type)
if not provider_model.encrypted_config:
return {
'server_url': None,
}
credentials = json.loads(provider_model.encrypted_config)
if credentials['server_url']:
credentials['server_url'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['server_url']
)
if obfuscated:
credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])
return credentials
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
return
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
return {}
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
return {}
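
A hedged usage sketch of the credential check defined above; the model name and endpoint are illustrative placeholders.

from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.localai_provider import LocalAIProvider

try:
    # Placeholder model name and endpoint; this raises CredentialsValidateFailedError
    # if the LocalAI server cannot be reached or completion_type is missing.
    LocalAIProvider.is_model_credentials_valid_or_raise(
        model_name='ggml-gpt4all-j',
        model_type=ModelType.TEXT_GENERATION,
        credentials={
            'server_url': 'http://localhost:8080',
            'completion_type': 'chat_completion',
        },
    )
except CredentialsValidateFailedError as e:
    print(f'LocalAI credentials rejected: {e}')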

View File

@@ -10,5 +10,6 @@
"replicate",
"huggingface_hub",
"xinference",
"openllm"
"openllm",
"localai"
]

View File

@@ -0,0 +1,7 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "configurable"
}

View File

@@ -42,7 +42,8 @@ class EnhanceChatOpenAI(ChatOpenAI):
return {
**super()._default_params,
"api_type": 'openai',
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_base": self.openai_api_base if self.openai_api_base
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,

View File

@@ -1,7 +1,10 @@
import os
-from typing import Dict, Any, Mapping, Optional, Union, Tuple
+from typing import Dict, Any, Mapping, Optional, Union, Tuple, List, Iterator
from langchain import OpenAI
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.openai import completion_with_retry, _stream_response_to_generation_chunk
+from langchain.schema.output import GenerationChunk
from pydantic import root_validator
@@ -33,7 +36,8 @@ class EnhanceOpenAI(OpenAI):
def _invocation_params(self) -> Dict[str, Any]:
return {**super()._invocation_params, **{
"api_type": 'openai',
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_base": self.openai_api_base if self.openai_api_base
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
@@ -43,8 +47,33 @@ class EnhanceOpenAI(OpenAI):
def _identifying_params(self) -> Mapping[str, Any]:
return {**super()._identifying_params, **{
"api_type": 'openai',
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_base": self.openai_api_base if self.openai_api_base
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}}
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
params = {**self._invocation_params, **kwargs, "stream": True}
self.get_sub_prompts(params, [prompt], stop) # this mutates params
for stream_resp in completion_with_retry(
self, prompt=prompt, run_manager=run_manager, **params
):
if 'text' in stream_resp["choices"][0]:
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
logprobs=chunk.generation_info["logprobs"]
if chunk.generation_info
else None,
)
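
Finally, a hedged sketch of consuming the new _stream path through the classic LangChain streaming-callback pattern; the model name and URL are placeholders.

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI

llm = EnhanceOpenAI(
    model_name='ggml-gpt4all-j',                 # placeholder model name
    openai_api_key='1',
    openai_api_base='http://localhost:8080/v1',  # placeholder endpoint
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()],
    max_tokens=64,
)

# with streaming=True, generation goes through the overridden _stream above and
# each chunk is pushed to the callback as LocalAI produces it
print(llm('ping'))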