feat: use xinference client instead of xinference (#1339)

Author: takatost
Date: 2023-10-13 15:46:09 +08:00
Committed by: GitHub
Parent: 9822f687f7
Commit: 3efaa713da
5 changed files with 83 additions and 14 deletions
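
In practice this change means the LangChain wrappers below stop importing from the full xinference package and instead talk to a running Xinference server through the standalone xinference_client package's RESTful client. A minimal sketch of the client calls the new code relies on, with a placeholder server URL and model UID (only Client, get_model and create_embedding, as used in the diff below, are assumed):

from xinference_client.client.restful.restful_client import Client

# Placeholder endpoint and model UID, for illustration only
client = Client("http://127.0.0.1:9997")
model = client.get_model("my-embedding-model")

# The RESTful embedding call the new XinferenceEmbeddings wrapper builds on
result = model.create_embedding("hello world")
vector = result["data"][0]["embedding"]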


@@ -1,21 +1,54 @@
-from typing import List
+from typing import List, Optional, Any
 
 import numpy as np
-from langchain.embeddings import XinferenceEmbeddings
+from langchain.embeddings.base import Embeddings
+from xinference_client.client.restful.restful_client import Client
 
 
-class XinferenceEmbedding(XinferenceEmbeddings):
+class XinferenceEmbeddings(Embeddings):
+    client: Any
+    server_url: Optional[str]
+    """URL of the xinference server"""
+    model_uid: Optional[str]
+    """UID of the launched model"""
+
+    def __init__(
+        self, server_url: Optional[str] = None, model_uid: Optional[str] = None
+    ):
+        super().__init__()
+
+        if server_url is None:
+            raise ValueError("Please provide server URL")
+
+        if model_uid is None:
+            raise ValueError("Please provide the model UID")
+
+        self.server_url = server_url
+        self.model_uid = model_uid
+        self.client = Client(server_url)
+
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        vectors = super().embed_documents(texts)
+        model = self.client.get_model(self.model_uid)
+        embeddings = [
+            model.create_embedding(text)["data"][0]["embedding"] for text in texts
+        ]
+        vectors = [list(map(float, e)) for e in embeddings]
         normalized_vectors = [(vector / np.linalg.norm(vector)).tolist() for vector in vectors]
         return normalized_vectors
 
     def embed_query(self, text: str) -> List[float]:
-        vector = super().embed_query(text)
+        model = self.client.get_model(self.model_uid)
+        embedding_res = model.create_embedding(text)
+        embedding = embedding_res["data"][0]["embedding"]
+        vector = list(map(float, embedding))
         normalized_vector = (vector / np.linalg.norm(vector)).tolist()
         return normalized_vector
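
The rewritten embedding wrapper now fetches the model handle from the RESTful client on each call and L2-normalizes the returned vectors. A hedged usage sketch, assuming the XinferenceEmbeddings class above is in scope and using placeholder connection values:

# Placeholder server URL and model UID; any launched Xinference embedding model works here
embeddings = XinferenceEmbeddings(
    server_url="http://127.0.0.1:9997",
    model_uid="my-embedding-model",
)

query_vector = embeddings.embed_query("What is Dify?")
doc_vectors = embeddings.embed_documents(["first document", "second document"])

# Because the wrapper normalizes with np.linalg.norm, each vector has unit length
assert abs(sum(v * v for v in query_vector) - 1.0) < 1e-6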


@@ -1,16 +1,53 @@
-from typing import Optional, List, Any, Union, Generator
+from typing import Optional, List, Any, Union, Generator, Mapping
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms import Xinference
+from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
-from xinference.client import (
+from xinference_client.client.restful.restful_client import (
     RESTfulChatglmCppChatModelHandle,
     RESTfulChatModelHandle,
-    RESTfulGenerateModelHandle,
+    RESTfulGenerateModelHandle, Client,
 )
 
 
-class XinferenceLLM(Xinference):
+class XinferenceLLM(LLM):
+    client: Any
+    server_url: Optional[str]
+    """URL of the xinference server"""
+    model_uid: Optional[str]
+    """UID of the launched model"""
+
+    def __init__(
+        self, server_url: Optional[str] = None, model_uid: Optional[str] = None
+    ):
+        super().__init__(
+            **{
+                "server_url": server_url,
+                "model_uid": model_uid,
+            }
+        )
+
+        if self.server_url is None:
+            raise ValueError("Please provide server URL")
+
+        if self.model_uid is None:
+            raise ValueError("Please provide the model UID")
+
+        self.client = Client(server_url)
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "xinference"
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {
+            **{"server_url": self.server_url},
+            **{"model_uid": self.model_uid},
+        }
+
     def _call(
         self,
         prompt: str,