Model Runtime (#1858)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
Authored by takatost on 2024-01-02 23:42:00 +08:00, committed by GitHub
parent e91dd28a76
commit d069c668f8
807 changed files with 171,310 additions and 23,806 deletions

@@ -0,0 +1,304 @@
import os
from typing import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \
    LLMResultChunkDelta
from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_hub.llm.llm import HuggingfaceHubLargeLanguageModel
from tests.integration_tests.model_runtime.__mock.huggingface import setup_huggingface_mock
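
# `setup_huggingface_mock` is parametrized indirectly: on each test below,
# pytest forwards the [['none']] argument to the fixture via `request.param`
# before the test body runs.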
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_hosted_inference_api_validate_credentials(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='HuggingFaceH4/zephyr-7b-beta',
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': 'invalid_key'
}
)
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='fake-model',
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': 'invalid_key'
}
)
model.validate_credentials(
model='HuggingFaceH4/zephyr-7b-beta',
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
}
    )


@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_hosted_inference_api_invoke_model(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
response = model.invoke(
model='HuggingFaceH4/zephyr-7b-beta',
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
},
stop=['How'],
stream=False,
user="abc-123"
)
assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0


@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
response = model.invoke(
model='HuggingFaceH4/zephyr-7b-beta',
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
},
stop=['How'],
stream=True,
user="abc-123"
)
assert isinstance(response, Generator)
for chunk in response:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
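
As an aside, the streamed chunks above can be folded back into a single completion by joining the deltas. A minimal, illustrative sketch (the helper below is hypothetical and not part of this commit), assuming its input is the generator returned by model.invoke(..., stream=True):

from typing import Generator

def collect_stream(chunks: Generator) -> str:
    # Hypothetical helper: concatenate each streamed chunk's delta content.
    return ''.join(chunk.delta.message.content or '' for chunk in chunks)
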
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text_generation_validate_credentials(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='openchat/openchat_3.5',
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': 'invalid_key',
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text-generation'
}
)
model.validate_credentials(
model='openchat/openchat_3.5',
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text-generation'
}
    )


@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text_generation_invoke_model(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
response = model.invoke(
model='openchat/openchat_3.5',
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text-generation'
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
},
stop=['How'],
stream=False,
user="abc-123"
)
assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0


@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
response = model.invoke(
model='openchat/openchat_3.5',
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text-generation'
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
},
stop=['How'],
stream=True,
user="abc-123"
)
assert isinstance(response, Generator)
for chunk in response:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0


@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text2text_generation_validate_credentials(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='google/mt5-base',
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': 'invalid_key',
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text2text-generation'
}
)
model.validate_credentials(
model='google/mt5-base',
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text2text-generation'
}
    )


@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text2text_generation_invoke_model(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
response = model.invoke(
model='google/mt5-base',
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text2text-generation'
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
},
stop=['How'],
stream=False,
user="abc-123"
)
assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0


@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text2text_generation_invoke_stream_model(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
response = model.invoke(
model='google/mt5-base',
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text2text-generation'
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
},
stop=['How'],
stream=True,
user="abc-123"
)
assert isinstance(response, Generator)
for chunk in response:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0


def test_get_num_tokens():
model = HuggingfaceHubLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='google/mt5-base',
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text2text-generation'
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
]
)
assert num_tokens == 7
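
The setup_huggingface_mock fixture used throughout this file is imported from tests/integration_tests/model_runtime/__mock/huggingface.py, which is not shown in this hunk. A minimal sketch of the shape such a fixture can take, assuming it patches the huggingface_hub client; the actual implementation in this commit may differ:

import pytest
from unittest.mock import MagicMock, patch

@pytest.fixture
def setup_huggingface_mock(request):
    # Receives ['none'] via the indirect parametrization (request.param).
    modes = request.param
    # Hypothetical patch target; the real fixture's targets are not shown here.
    with patch('huggingface_hub.InferenceClient', MagicMock()):
        yield modes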

@@ -0,0 +1,120 @@
import os

import pytest

from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_hub.text_embedding.text_embedding import \
    HuggingfaceHubTextEmbeddingModel


def test_hosted_inference_api_validate_credentials():
model = HuggingfaceHubTextEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='facebook/bart-base',
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': 'invalid_key',
}
)
model.validate_credentials(
model='facebook/bart-base',
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
}
    )


def test_hosted_inference_api_invoke_model():
model = HuggingfaceHubTextEmbeddingModel()
result = model.invoke(
model='facebook/bart-base',
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
},
texts=[
"hello",
"world"
]
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
    assert result.usage.total_tokens == 2


def test_inference_endpoints_validate_credentials():
model = HuggingfaceHubTextEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='all-MiniLM-L6-v2',
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': 'invalid_key',
'huggingface_namespace': 'Dify-AI',
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
'task_type': 'feature-extraction'
}
)
model.validate_credentials(
model='all-MiniLM-L6-v2',
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingface_namespace': 'Dify-AI',
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
'task_type': 'feature-extraction'
}
    )


def test_inference_endpoints_invoke_model():
model = HuggingfaceHubTextEmbeddingModel()
result = model.invoke(
model='all-MiniLM-L6-v2',
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingface_namespace': 'Dify-AI',
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
'task_type': 'feature-extraction'
},
texts=[
"hello",
"world"
]
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
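    # Unlike the hosted inference API test above, this setup does not report
    # token usage, so the expected total here is zero.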
    assert result.usage.total_tokens == 0


def test_get_num_tokens():
model = HuggingfaceHubTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='all-MiniLM-L6-v2',
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingface_namespace': 'Dify-AI',
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
'task_type': 'feature-extraction'
},
texts=[
"hello",
"world"
]
)
assert num_tokens == 2
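
For completeness, the same embedding model can be driven outside pytest. A minimal usage sketch, using only the calls already exercised by the tests above and assuming HUGGINGFACE_API_KEY is set in the environment:

import os

from core.model_runtime.model_providers.huggingface_hub.text_embedding.text_embedding import \
    HuggingfaceHubTextEmbeddingModel

model = HuggingfaceHubTextEmbeddingModel()
result = model.invoke(
    model='facebook/bart-base',
    credentials={
        'huggingfacehub_api_type': 'hosted_inference_api',
        'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
    },
    texts=['hello', 'world'],
)
# Two input texts should yield two embedding vectors.
print(len(result.embeddings), result.usage.total_tokens)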