Mirror of http://112.124.100.131/huang.ze/ebiz-dify-ai.git, synced 2025-12-24 10:13:01 +08:00
Feat/blocking function call (#2247)
@@ -19,58 +19,86 @@ class MockXinferenceClass(object):
             raise RuntimeError('404 Not Found')
 
         if 'generate' == model_uid:
-            return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         if 'chat' == model_uid:
-            return RESTfulChatModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         if 'embedding' == model_uid:
-            return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         if 'rerank' == model_uid:
-            return RESTfulRerankModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         raise RuntimeError('404 Not Found')
 
     def get(self: Session, url: str, **kwargs):
-        if '/v1/models/' in url:
-            response = Response()
+        response = Response()
+        if 'v1/models/' in url:
             # get model uid
             model_uid = url.split('/')[-1]
             if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
                 model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
                 response.status_code = 404
                 raise ConnectionError('404 Not Found')
                 return response
 
             # check if url is valid
             if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
                 response.status_code = 404
                 raise ConnectionError('404 Not Found')
                 return response
 
             if model_uid in ['generate', 'chat']:
                 response.status_code = 200
                 response._content = b'''{
     "model_type": "LLM",
     "address": "127.0.0.1:43877",
     "accelerators": [
         "0",
         "1"
     ],
     "model_name": "chatglm3-6b",
     "model_lang": [
         "en"
     ],
     "model_ability": [
         "generate",
         "chat"
     ],
     "model_description": "latest chatglm3",
     "model_format": "pytorch",
     "model_size_in_billions": 7,
     "quantization": "none",
     "model_hub": "huggingface",
     "revision": null,
     "context_length": 2048,
     "replica": 1
 }'''
                 return response
 
             elif model_uid == 'embedding':
                 response.status_code = 200
                 response._content = b'''{
     "model_type": "embedding",
     "address": "127.0.0.1:43877",
     "accelerators": [
         "0",
         "1"
     ],
     "model_name": "bge",
     "model_lang": [
         "en"
     ],
     "revision": null,
     "max_tokens": 512
 }'''
                 return response
 
+        elif 'v1/cluster/auth' in url:
+            response.status_code = 200
+            response._content = b'''{
+    "model_type": "LLM",
+    "address": "127.0.0.1:43877",
+    "accelerators": [
+        "0",
+        "1"
+    ],
+    "model_name": "chatglm3-6b",
+    "model_lang": [
+        "en"
+    ],
+    "model_ability": [
+        "generate",
+        "chat"
+    ],
+    "model_description": "latest chatglm3",
+    "model_format": "pytorch",
+    "model_size_in_billions": 7,
+    "quantization": "none",
+    "model_hub": "huggingface",
+    "revision": null,
+    "context_length": 2048,
+    "replica": 1,
+    "auth": true
+}'''
+            return response
 
     def _check_cluster_authenticated(self):
         self._cluster_authed = True
 
     def rerank(self: RESTfulRerankModelHandle, documents: List[str], query: str, top_n: int) -> dict:
         # check if self._model_uid is a valid uuid
         if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
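For context, a minimal sketch of how the patched Session.get behaves once the monkeypatches from setup_xinference_mock (second hunk, below) are active; it is a usage illustration only, MOCK_SWITCH=true is assumed, and the base URL is a placeholder:

from requests import Session

# With the mock applied, a GET on /v1/models/chat (or /v1/models/generate) returns
# the canned chatglm3-6b metadata defined above instead of hitting a real server.
resp = Session().get('http://127.0.0.1:9997/v1/models/chat')  # placeholder host/port
assert resp.status_code == 200
assert resp.json()['model_ability'] == ['generate', 'chat']   # from the stubbed JSON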
@@ -133,6 +161,7 @@ MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
 def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
     if MOCK:
         monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model)
+        monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated)
         monkeypatch.setattr(Session, 'get', MockXinferenceClass.get)
         monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding)
         monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)
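For context, a hedged sketch of a test that consumes this fixture; the pytest fixture registration (e.g. via a conftest.py) and the xinference_client import path are assumptions and are not part of this diff:

import os

import pytest
from xinference_client.client.restful.restful_client import Client

def test_get_model_returns_mocked_chat_handle(setup_xinference_mock):
    # only meaningful when the mock is switched on
    if os.getenv('MOCK_SWITCH', 'false').lower() != 'true':
        pytest.skip('set MOCK_SWITCH=true to run against the mock')
    # the patched _check_cluster_authenticated lets the constructor run without a live server
    client = Client('http://127.0.0.1:9997')  # placeholder endpoint
    handle = client.get_model('chat')
    # the patched get_model builds the handle with empty auth headers
    assert type(handle).__name__ == 'RESTfulChatModelHandle'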