feat: server xinference support (#927)

This commit is contained in:
takatost
2023-08-20 17:46:41 +08:00
committed by GitHub
parent 8c991b5b26
commit da3f10a55e
18 changed files with 456 additions and 17 deletions

View File

@@ -57,6 +57,9 @@ class ModelProviderFactory:
elif provider_name == 'huggingface_hub':
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
return HuggingfaceHubProvider
elif provider_name == 'xinference':
from core.model_providers.providers.xinference_provider import XinferenceProvider
return XinferenceProvider
else:
raise NotImplementedError

View File

@@ -0,0 +1,69 @@
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.llms import Xinference
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class XinferenceModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
def _init_client(self) -> Any:
self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
client = Xinference(
**self.credentials,
)
client.callbacks = self.callbacks
return client
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate(
[prompts],
stop,
callbacks,
generate_config={
"stop": stop,
"stream": self.streaming,
**self.provider_model_kwargs,
}
)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
pass
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Xinference: {str(ex)}")
@classmethod
def support_streaming(cls):
return True

View File

@@ -0,0 +1,141 @@
import json
from typing import Type
from langchain.llms import Xinference
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from models.provider import ProviderType
class XinferenceProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'xinference'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = XinferenceModel
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](alias='max_token', min=10, max=4000, default=256),
)
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
if 'server_url' not in credentials:
raise CredentialsValidateFailedError('Xinference Server URL must be provided.')
if 'model_uid' not in credentials:
raise CredentialsValidateFailedError('Xinference Model UID must be provided.')
try:
credential_kwargs = {
'server_url': credentials['server_url'],
'model_uid': credentials['model_uid'],
}
llm = Xinference(
**credential_kwargs
)
llm("ping", generate_config={'max_tokens': 10})
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
return credentials
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
if self.provider.provider_type != ProviderType.CUSTOM.value:
raise NotImplementedError
provider_model = self._get_provider_model(model_name, model_type)
if not provider_model.encrypted_config:
return {
'server_url': None,
'model_uid': None,
}
credentials = json.loads(provider_model.encrypted_config)
if credentials['server_url']:
credentials['server_url'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['server_url']
)
if obfuscated:
credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])
return credentials
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
return
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
return {}
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
return {}

View File

@@ -8,5 +8,6 @@
"wenxin",
"chatglm",
"replicate",
"huggingface_hub"
"huggingface_hub",
"xinference"
]

View File

@@ -0,0 +1,7 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "configurable"
}