fix: Xinference embedding credential validation (#1187)

This commit is contained in:
takatost
2023-09-16 03:09:36 +08:00
committed by GitHub
parent ec5f585df4
commit c8bd76cd66
2 changed files with 17 additions and 8 deletions

View File

@@ -2,6 +2,7 @@ import json
from typing import Type
import requests
from langchain.embeddings import XinferenceEmbeddings
from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@@ -97,11 +98,18 @@ class XinferenceProvider(BaseModelProvider):
'model_uid': credentials['model_uid'],
}
llm = XinferenceLLM(
**credential_kwargs
)
if model_type == ModelType.TEXT_GENERATION:
llm = XinferenceLLM(
**credential_kwargs
)
llm("ping")
llm("ping")
elif model_type == ModelType.EMBEDDINGS:
embedding = XinferenceEmbeddings(
**credential_kwargs
)
embedding.embed_query("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@@ -117,8 +125,9 @@ class XinferenceProvider(BaseModelProvider):
:param credentials:
:return:
"""
extra_credentials = cls._get_extra_credentials(credentials)
credentials.update(extra_credentials)
if model_type == ModelType.TEXT_GENERATION:
extra_credentials = cls._get_extra_credentials(credentials)
credentials.update(extra_credentials)
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])