fix: xinference chat support (#939)

This commit is contained in:
takatost
2023-08-21 20:44:29 +08:00
committed by GitHub
parent f53242c081
commit e0a48c4972
4 changed files with 204 additions and 16 deletions

View File

@@ -4,7 +4,6 @@ import json
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.replicate_provider import ReplicateProvider
from core.model_providers.providers.xinference_provider import XinferenceProvider
from models.provider import ProviderType, Provider, ProviderModel
@@ -25,7 +24,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):
def test_is_credentials_valid_or_raise_valid(mocker):
mocker.patch('langchain.llms.xinference.Xinference._call',
mocker.patch('core.third_party.langchain.llms.xinference_llm.XinferenceLLM._call',
return_value="abc")
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
@@ -53,8 +52,15 @@ def test_is_credentials_valid_or_raise_invalid():
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_model_credentials(mock_encrypt):
def test_encrypt_model_credentials(mock_encrypt, mocker):
api_key = 'http://127.0.0.1:9997/'
mocker.patch('core.model_providers.providers.xinference_provider.XinferenceProvider._get_extra_credentials',
return_value={
'model_handle_type': 'generate',
'model_format': 'ggmlv3'
})
result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
tenant_id='tenant_id',
model_name='test_model_name',