Mirror of http://112.124.100.131/huang.ze/ebiz-dify-ai.git (synced 2025-12-23 17:55:46 +08:00)
Co-authored-by: -LAN- <laipz8200@outlook.com>
@@ -19,40 +19,43 @@ from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage
 
 
 class MockXinferenceClass:
-    def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
-        if not re.match(r'https?:\/\/[^\s\/$.?#].[^\s]*$', self.base_url):
-            raise RuntimeError('404 Not Found')
+    def get_chat_model(
+        self: Client, model_uid: str
+    ) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
+        if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url):
+            raise RuntimeError("404 Not Found")
 
-        if 'generate' == model_uid:
+        if "generate" == model_uid:
             return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
-        if 'chat' == model_uid:
+        if "chat" == model_uid:
             return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
-        if 'embedding' == model_uid:
+        if "embedding" == model_uid:
             return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
-        if 'rerank' == model_uid:
+        if "rerank" == model_uid:
             return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
-        raise RuntimeError('404 Not Found')
+        raise RuntimeError("404 Not Found")
 
     def get(self: Session, url: str, **kwargs):
         response = Response()
-        if 'v1/models/' in url:
+        if "v1/models/" in url:
             # get model uid
-            model_uid = url.split('/')[-1] or ''
-            if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
-                model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
+            model_uid = url.split("/")[-1] or ""
+            if not re.match(
+                r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid
+            ) and model_uid not in ["generate", "chat", "embedding", "rerank"]:
                 response.status_code = 404
-                response._content = b'{}'
+                response._content = b"{}"
                 return response
 
             # check if url is valid
-            if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
+            if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", url):
                 response.status_code = 404
-                response._content = b'{}'
+                response._content = b"{}"
                 return response
 
-            if model_uid in ['generate', 'chat']:
+            if model_uid in ["generate", "chat"]:
                 response.status_code = 200
-                response._content = b'''{
+                response._content = b"""{
     "model_type": "LLM",
     "address": "127.0.0.1:43877",
     "accelerators": [
@@ -75,12 +78,12 @@ class MockXinferenceClass:
     "revision": null,
     "context_length": 2048,
     "replica": 1
-}'''
+}"""
                 return response
 
-            elif model_uid == 'embedding':
+            elif model_uid == "embedding":
                 response.status_code = 200
-                response._content = b'''{
+                response._content = b"""{
     "model_type": "embedding",
     "address": "127.0.0.1:43877",
     "accelerators": [
@@ -93,51 +96,48 @@ class MockXinferenceClass:
     ],
     "revision": null,
     "max_tokens": 512
-}'''
+}"""
                 return response
 
-        elif 'v1/cluster/auth' in url:
+        elif "v1/cluster/auth" in url:
             response.status_code = 200
-            response._content = b'''{
+            response._content = b"""{
     "auth": true
-}'''
+}"""
             return response
 
     def _check_cluster_authenticated(self):
         self._cluster_authed = True
 
-    def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool) -> dict:
+    def rerank(
+        self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool
+    ) -> dict:
         # check if self._model_uid is a valid uuid
-        if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
-            self._model_uid != 'rerank':
-            raise RuntimeError('404 Not Found')
+        if (
+            not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
+            and self._model_uid != "rerank"
+        ):
+            raise RuntimeError("404 Not Found")
 
-        if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._base_url):
-            raise RuntimeError('404 Not Found')
+        if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._base_url):
+            raise RuntimeError("404 Not Found")
 
         if top_n is None:
             top_n = 1
 
         return {
-            'results': [
-                {
-                    'index': i,
-                    'document': doc,
-                    'relevance_score': 0.9
-                }
-                for i, doc in enumerate(documents[:top_n])
+            "results": [
+                {"index": i, "document": doc, "relevance_score": 0.9} for i, doc in enumerate(documents[:top_n])
             ]
         }
 
-    def create_embedding(
-        self: RESTfulGenerateModelHandle,
-        input: Union[str, list[str]],
-        **kwargs
-    ) -> dict:
+    def create_embedding(self: RESTfulGenerateModelHandle, input: Union[str, list[str]], **kwargs) -> dict:
         # check if self._model_uid is a valid uuid
-        if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
-            self._model_uid != 'embedding':
-            raise RuntimeError('404 Not Found')
+        if (
+            not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
+            and self._model_uid != "embedding"
+        ):
+            raise RuntimeError("404 Not Found")
 
         if isinstance(input, str):
             input = [input]
@@ -147,32 +147,27 @@ class MockXinferenceClass:
             object="list",
             model=self._model_uid,
             data=[
-                EmbeddingData(
-                    index=i,
-                    object="embedding",
-                    embedding=[1919.810 for _ in range(768)]
-                )
+                EmbeddingData(index=i, object="embedding", embedding=[1919.810 for _ in range(768)])
                 for i in range(ipt_len)
             ],
-            usage=EmbeddingUsage(
-                prompt_tokens=ipt_len,
-                total_tokens=ipt_len
-            )
+            usage=EmbeddingUsage(prompt_tokens=ipt_len, total_tokens=ipt_len),
         )
 
         return embedding
 
-MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
+
+MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
+
 
 @pytest.fixture
 def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
     if MOCK:
-        monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model)
-        monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated)
-        monkeypatch.setattr(Session, 'get', MockXinferenceClass.get)
-        monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding)
-        monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)
+        monkeypatch.setattr(Client, "get_model", MockXinferenceClass.get_chat_model)
+        monkeypatch.setattr(Client, "_check_cluster_authenticated", MockXinferenceClass._check_cluster_authenticated)
+        monkeypatch.setattr(Session, "get", MockXinferenceClass.get)
+        monkeypatch.setattr(RESTfulEmbeddingModelHandle, "create_embedding", MockXinferenceClass.create_embedding)
+        monkeypatch.setattr(RESTfulRerankModelHandle, "rerank", MockXinferenceClass.rerank)
     yield
 
     if MOCK:
         monkeypatch.undo()
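
For context, a minimal sketch of how a test could consume the setup_xinference_mock fixture shown above. The test name, the base URL, and the Client import path are illustrative assumptions; only the fixture itself, the MOCK_SWITCH toggle, and the generate/chat/embedding/rerank model UIDs come from the mock.

# Hypothetical usage sketch (not part of this commit). With MOCK_SWITCH=true,
# Client.get_model, Session.get and the handle methods are monkeypatched by the
# fixture, so no real Xinference server is needed.
import os

from xinference_client.client.restful.restful_client import Client  # assumed import path


def test_xinference_rerank_mocked(setup_xinference_mock):
    client = Client(base_url=os.getenv("XINFERENCE_SERVER_URL", "http://127.0.0.1:9997"))
    handle = client.get_model("rerank")  # mocked: returns a RESTfulRerankModelHandle
    result = handle.rerank(
        documents=["doc-a", "doc-b"],
        query="which document is relevant?",
        top_n=1,
        return_documents=False,
    )
    assert result["results"][0]["relevance_score"] == 0.9

Run with MOCK_SWITCH=true to exercise the monkeypatched client; with the switch off, the same test would talk to a real Xinference deployment.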