feat: server multi models support (#799)

This commit is contained in:
takatost
2023-08-12 00:57:00 +08:00
committed by GitHub
parent d8b712b325
commit 5fa2161b05
213 changed files with 10556 additions and 2579 deletions

View File

@@ -0,0 +1,35 @@
# OpenAI API Key
OPENAI_API_KEY=
# Azure OpenAI API Base Endpoint & API Key
AZURE_OPENAI_API_BASE=
AZURE_OPENAI_API_KEY=
# Anthropic API Key
ANTHROPIC_API_KEY=
# Replicate API Key
REPLICATE_API_TOKEN=
# Hugging Face API Key
HUGGINGFACE_API_KEY=
HUGGINGFACE_ENDPOINT_URL=
# Minimax Credentials
MINIMAX_API_KEY=
MINIMAX_GROUP_ID=
# Spark Credentials
SPARK_APP_ID=
SPARK_API_KEY=
SPARK_API_SECRET=
# Tongyi Credentials
TONGYI_DASHSCOPE_API_KEY=
# Wenxin Credentials
WENXIN_API_KEY=
WENXIN_SECRET_KEY=
# ChatGLM Credentials
CHATGLM_API_BASE=

View File

View File

@@ -0,0 +1,19 @@
import os
# Getting the absolute path of the current file's directory
ABS_PATH = os.path.dirname(os.path.abspath(__file__))
# Getting the absolute path of the project's root directory
PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir))
# Loading the .env file if it exists
def _load_env() -> None:
dotenv_path = os.path.join(PROJECT_DIR, "tests", "integration_tests", ".env")
if os.path.exists(dotenv_path):
from dotenv import load_dotenv
load_dotenv(dotenv_path)
_load_env()

View File

@@ -0,0 +1,57 @@
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='azure_openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config='',
is_valid=True,
)
def get_mock_azure_openai_embedding_model(mocker):
model_name = 'text-embedding-ada-002'
valid_openai_api_base = os.environ['AZURE_OPENAI_API_BASE']
valid_openai_api_key = os.environ['AZURE_OPENAI_API_KEY']
openai_provider = AzureOpenAIProvider(provider=get_mock_provider())
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
provider_name='azure_openai',
model_name=model_name,
model_type=ModelType.EMBEDDINGS.value,
encrypted_config=json.dumps({
'openai_api_base': valid_openai_api_base,
'openai_api_key': valid_openai_api_key,
'base_model_name': model_name
}),
is_valid=True,
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
return AzureOpenAIEmbedding(
model_provider=openai_provider,
name=model_name
)
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
return encrypted_openai_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embedding(mock_decrypt, mocker):
embedding_model = get_mock_azure_openai_embedding_model(mocker)
rst = embedding_model.client.embed_query('test')
assert isinstance(rst, list)
assert len(rst) == 1536

View File

@@ -0,0 +1,44 @@
import json
import os
from unittest.mock import patch
from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding
from core.model_providers.providers.minimax_provider import MinimaxProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_group_id, valid_api_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='minimax',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({
'minimax_group_id': valid_group_id,
'minimax_api_key': valid_api_key
}),
is_valid=True,
)
def get_mock_embedding_model():
model_name = 'embo-01'
valid_api_key = os.environ['MINIMAX_API_KEY']
valid_group_id = os.environ['MINIMAX_GROUP_ID']
provider = MinimaxProvider(provider=get_mock_provider(valid_group_id, valid_api_key))
return MinimaxEmbedding(
model_provider=provider,
name=model_name
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embedding(mock_decrypt):
embedding_model = get_mock_embedding_model()
rst = embedding_model.client.embed_query('test')
assert isinstance(rst, list)
assert len(rst) == 1536

View File

@@ -0,0 +1,40 @@
import json
import os
from unittest.mock import patch
from core.model_providers.providers.openai_provider import OpenAIProvider
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from models.provider import Provider, ProviderType
def get_mock_provider(valid_openai_api_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'openai_api_key': valid_openai_api_key}),
is_valid=True,
)
def get_mock_openai_embedding_model():
model_name = 'text-embedding-ada-002'
valid_openai_api_key = os.environ['OPENAI_API_KEY']
openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key))
return OpenAIEmbedding(
model_provider=openai_provider,
name=model_name
)
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
return encrypted_openai_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embedding(mock_decrypt):
embedding_model = get_mock_openai_embedding_model()
rst = embedding_model.client.embed_query('test')
assert isinstance(rst, list)
assert len(rst) == 1536

View File

@@ -0,0 +1,64 @@
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.embedding.replicate_embedding import ReplicateEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.replicate_provider import ReplicateProvider
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='replicate',
provider_type=ProviderType.CUSTOM.value,
encrypted_config='',
is_valid=True,
)
def get_mock_embedding_model(mocker):
model_name = 'replicate/all-mpnet-base-v2'
valid_api_key = os.environ['REPLICATE_API_TOKEN']
model_provider = ReplicateProvider(provider=get_mock_provider())
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
provider_name='replicate',
model_name=model_name,
model_type=ModelType.EMBEDDINGS.value,
encrypted_config=json.dumps({
'replicate_api_token': valid_api_key,
'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305'
}),
is_valid=True,
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
return ReplicateEmbedding(
model_provider=model_provider,
name=model_name
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_documents(mock_decrypt, mocker):
embedding_model = get_mock_embedding_model(mocker)
rst = embedding_model.client.embed_documents(['test', 'test1'])
assert isinstance(rst, list)
assert len(rst) == 2
assert len(rst[0]) == 768
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_query(mock_decrypt, mocker):
embedding_model = get_mock_embedding_model(mocker)
rst = embedding_model.client.embed_query('test')
assert isinstance(rst, list)
assert len(rst) == 768

View File

@@ -0,0 +1,61 @@
import json
import os
from unittest.mock import patch
from langchain.schema import ChatGeneration, AIMessage
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.anthropic_model import AnthropicModel
from core.model_providers.providers.anthropic_provider import AnthropicProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_api_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='anthropic',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'anthropic_api_key': valid_api_key}),
is_valid=True,
)
def get_mock_model(model_name):
model_kwargs = ModelKwargs(
max_tokens=10,
temperature=0
)
valid_api_key = os.environ['ANTHROPIC_API_KEY']
model_provider = AnthropicProvider(provider=get_mock_provider(valid_api_key))
return AnthropicModel(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('claude-2')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst == 6
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt):
model = get_mock_model('claude-2')
messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: ')]
rst = model.run(
messages,
stop=['\nHuman:'],
)
assert len(rst.content) > 0
assert rst.content.strip() == '2'

View File

@@ -0,0 +1,86 @@
import json
import os
from unittest.mock import patch, MagicMock
import pytest
from langchain.schema import ChatGeneration, AIMessage
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='azure_openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config='',
is_valid=True,
)
def get_mock_azure_openai_model(model_name, mocker):
model_kwargs = ModelKwargs(
max_tokens=10,
temperature=0
)
valid_openai_api_base = os.environ['AZURE_OPENAI_API_BASE']
valid_openai_api_key = os.environ['AZURE_OPENAI_API_KEY']
provider = AzureOpenAIProvider(provider=get_mock_provider())
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
provider_name='azure_openai',
model_name=model_name,
model_type=ModelType.TEXT_GENERATION.value,
encrypted_config=json.dumps({
'openai_api_base': valid_openai_api_base,
'openai_api_key': valid_openai_api_key,
'base_model_name': model_name
}),
is_valid=True,
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
return AzureOpenAIModel(
model_provider=provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
return encrypted_openai_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt, mocker):
openai_model = get_mock_azure_openai_model('text-davinci-003', mocker)
rst = openai_model.get_num_tokens([PromptMessage(content='you are a kindness Assistant.')])
assert rst == 6
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_get_num_tokens(mock_decrypt, mocker):
openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker)
rst = openai_model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst == 22
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker)
messages = [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')]
rst = openai_model.run(
messages,
stop=['\nHuman:'],
)
assert len(rst.content) > 0
assert rst.content.strip() == 'n'

View File

@@ -0,0 +1,124 @@
import json
import os
from unittest.mock import patch, MagicMock
from langchain.schema import Generation
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='huggingface_hub',
provider_type=ProviderType.CUSTOM.value,
encrypted_config='',
is_valid=True,
)
def get_mock_model(model_name, huggingfacehub_api_type, mocker):
model_kwargs = ModelKwargs(
max_tokens=10,
temperature=0.01
)
valid_api_key = os.environ['HUGGINGFACE_API_KEY']
endpoint_url = os.environ['HUGGINGFACE_ENDPOINT_URL']
model_provider = HuggingfaceHubProvider(provider=get_mock_provider())
credentials = {
'huggingfacehub_api_type': huggingfacehub_api_type,
'huggingfacehub_api_token': valid_api_key
}
if huggingfacehub_api_type == 'inference_endpoints':
credentials['huggingfacehub_endpoint_url'] = endpoint_url
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
provider_name='huggingface_hub',
model_name=model_name,
model_type=ModelType.TEXT_GENERATION.value,
encrypted_config=json.dumps(credentials),
is_valid=True,
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
return HuggingfaceHubModel(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('huggingface_hub.hf_api.ModelInfo')
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_hosted_inference_api_get_num_tokens(mock_decrypt, mock_model_info, mocker):
mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation')
mocker.patch('langchain.llms.huggingface_hub.HuggingFaceHub._call', return_value="abc")
model = get_mock_model(
'tiiuae/falcon-40b',
'hosted_inference_api',
mocker
)
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst == 5
@patch('huggingface_hub.hf_api.ModelInfo')
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_inference_endpoints_get_num_tokens(mock_decrypt, mock_model_info, mocker):
mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation')
mocker.patch('langchain.llms.huggingface_hub.HuggingFaceHub._call', return_value="abc")
model = get_mock_model(
'',
'inference_endpoints',
mocker
)
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst == 5
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_hosted_inference_api_run(mock_decrypt, mocker):
model = get_mock_model(
'google/flan-t5-base',
'hosted_inference_api',
mocker
)
rst = model.run(
[PromptMessage(content='Human: Are you Really Human? you MUST only answer `y` or `n`? \nAssistant: ')],
stop=['\nHuman:'],
)
assert len(rst.content) > 0
assert rst.content.strip() == 'n'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_inference_endpoints_run(mock_decrypt, mocker):
model = get_mock_model(
'',
'inference_endpoints',
mocker
)
rst = model.run(
[PromptMessage(content='Answer the following yes/no question. Can you write a whole Haiku in a single tweet?')],
)
assert len(rst.content) > 0
assert rst.content.strip() == 'no'

View File

@@ -0,0 +1,64 @@
import json
import os
from unittest.mock import patch
from langchain.schema import ChatGeneration, AIMessage, Generation
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.minimax_model import MinimaxModel
from core.model_providers.providers.minimax_provider import MinimaxProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_group_id, valid_api_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='minimax',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({
'minimax_group_id': valid_group_id,
'minimax_api_key': valid_api_key
}),
is_valid=True,
)
def get_mock_model(model_name):
model_kwargs = ModelKwargs(
max_tokens=10,
temperature=0.01
)
valid_api_key = os.environ['MINIMAX_API_KEY']
valid_group_id = os.environ['MINIMAX_GROUP_ID']
model_provider = MinimaxProvider(provider=get_mock_provider(valid_group_id, valid_api_key))
return MinimaxModel(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('abab5.5-chat')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst == 5
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt):
model = get_mock_model('abab5.5-chat')
rst = model.run(
[PromptMessage(content='Human: Are you a real Human? you MUST only answer `y` or `n`? \nAssistant: ')],
stop=['\nHuman:'],
)
assert len(rst.content) > 0
assert rst.content.strip() == 'n'

View File

@@ -0,0 +1,80 @@
import json
import os
from unittest.mock import patch
from langchain.schema import Generation, ChatGeneration, AIMessage
from core.model_providers.providers.openai_provider import OpenAIProvider
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.openai_model import OpenAIModel
from models.provider import Provider, ProviderType
def get_mock_provider(valid_openai_api_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'openai_api_key': valid_openai_api_key}),
is_valid=True,
)
def get_mock_openai_model(model_name):
model_kwargs = ModelKwargs(
max_tokens=10,
temperature=0
)
model_name = model_name
valid_openai_api_key = os.environ['OPENAI_API_KEY']
openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key))
return OpenAIModel(
model_provider=openai_provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
return encrypted_openai_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt):
openai_model = get_mock_openai_model('text-davinci-003')
rst = openai_model.get_num_tokens([PromptMessage(content='you are a kindness Assistant.')])
assert rst == 6
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_get_num_tokens(mock_decrypt):
openai_model = get_mock_openai_model('gpt-3.5-turbo')
rst = openai_model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst == 22
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt):
openai_model = get_mock_openai_model('text-davinci-003')
rst = openai_model.run(
[PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
stop=['\nHuman:'],
)
assert len(rst.content) > 0
assert rst.content.strip() == 'n'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_run(mock_decrypt):
openai_model = get_mock_openai_model('gpt-3.5-turbo')
messages = [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')]
rst = openai_model.run(
messages,
stop=['\nHuman:'],
)
assert len(rst.content) > 0
assert rst.content.strip() == 'n'

View File

@@ -0,0 +1,73 @@
import json
import os
from unittest.mock import patch, MagicMock
from langchain.schema import Generation
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.replicate_model import ReplicateModel
from core.model_providers.providers.replicate_provider import ReplicateProvider
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='replicate',
provider_type=ProviderType.CUSTOM.value,
encrypted_config='',
is_valid=True,
)
def get_mock_model(model_name, model_version, mocker):
model_kwargs = ModelKwargs(
max_tokens=10,
temperature=0.01
)
valid_api_key = os.environ['REPLICATE_API_TOKEN']
model_provider = ReplicateProvider(provider=get_mock_provider())
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
provider_name='replicate',
model_name=model_name,
model_type=ModelType.TEXT_GENERATION.value,
encrypted_config=json.dumps({
'replicate_api_token': valid_api_key,
'model_version': model_version
}),
is_valid=True,
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
return ReplicateModel(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt, mocker):
model = get_mock_model('a16z-infra/llama-2-13b-chat', '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', mocker)
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst == 7
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
model = get_mock_model('a16z-infra/llama-2-13b-chat', '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', mocker)
messages = [PromptMessage(content='Human: 1+1=? \nAnswer: ')]
rst = model.run(
messages
)
assert len(rst.content) > 0

View File

@@ -0,0 +1,69 @@
import json
import os
from unittest.mock import patch
from langchain.schema import ChatGeneration, AIMessage, Generation
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.minimax_model import MinimaxModel
from core.model_providers.models.llm.spark_model import SparkModel
from core.model_providers.providers.minimax_provider import MinimaxProvider
from core.model_providers.providers.spark_provider import SparkProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_app_id, valid_api_key, valid_api_secret):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='spark',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({
'app_id': valid_app_id,
'api_key': valid_api_key,
'api_secret': valid_api_secret,
}),
is_valid=True,
)
def get_mock_model(model_name):
model_kwargs = ModelKwargs(
max_tokens=10,
temperature=0.01
)
valid_app_id = os.environ['SPARK_APP_ID']
valid_api_key = os.environ['SPARK_API_KEY']
valid_api_secret = os.environ['SPARK_API_SECRET']
model_provider = SparkProvider(provider=get_mock_provider(valid_app_id, valid_api_key, valid_api_secret))
return SparkModel(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('spark')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst == 6
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt):
model = get_mock_model('spark')
messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
rst = model.run(
messages,
stop=['\nHuman:'],
)
assert len(rst.content) > 0
assert rst.content.strip() == '2'

View File

@@ -0,0 +1,61 @@
import json
import os
from unittest.mock import patch
from langchain.schema import ChatGeneration, AIMessage, Generation
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.tongyi_model import TongyiModel
from core.model_providers.providers.tongyi_provider import TongyiProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_api_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='tongyi',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({
'dashscope_api_key': valid_api_key,
}),
is_valid=True,
)
def get_mock_model(model_name):
model_kwargs = ModelKwargs(
max_tokens=10,
temperature=0.01
)
valid_api_key = os.environ['TONGYI_DASHSCOPE_API_KEY']
model_provider = TongyiProvider(provider=get_mock_provider(valid_api_key))
return TongyiModel(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('qwen-v1')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst == 5
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt):
model = get_mock_model('qwen-v1')
rst = model.run(
[PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
stop=['\nHuman:'],
)
assert len(rst.content) > 0

View File

@@ -0,0 +1,63 @@
import json
import os
from unittest.mock import patch
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.wenxin_model import WenxinModel
from core.model_providers.providers.wenxin_provider import WenxinProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_api_key, valid_secret_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='wenxin',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({
'api_key': valid_api_key,
'secret_key': valid_secret_key,
}),
is_valid=True,
)
def get_mock_model(model_name):
model_kwargs = ModelKwargs(
temperature=0.01
)
valid_api_key = os.environ['WENXIN_API_KEY']
valid_secret_key = os.environ['WENXIN_SECRET_KEY']
model_provider = WenxinProvider(provider=get_mock_provider(valid_api_key, valid_secret_key))
return WenxinModel(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('ernie-bot')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst == 5
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt):
model = get_mock_model('ernie-bot')
messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
rst = model.run(
messages,
stop=['\nHuman:'],
)
assert len(rst.content) > 0
assert rst.content.strip() == '2'

View File

@@ -0,0 +1,40 @@
import json
import os
from unittest.mock import patch
from core.model_providers.models.moderation.openai_moderation import OpenAIModeration, DEFAULT_AUDIO_MODEL
from core.model_providers.providers.openai_provider import OpenAIProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_openai_api_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'openai_api_key': valid_openai_api_key}),
is_valid=True,
)
def get_mock_openai_moderation_model():
valid_openai_api_key = os.environ['OPENAI_API_KEY']
openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key))
return OpenAIModeration(
model_provider=openai_provider,
name=DEFAULT_AUDIO_MODEL
)
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
return encrypted_openai_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt):
model = get_mock_openai_moderation_model()
rst = model.run('hello')
assert isinstance(rst, dict)
assert 'id' in rst

View File

@@ -0,0 +1,50 @@
import json
import os
from unittest.mock import patch
from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper
from core.model_providers.providers.openai_provider import OpenAIProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_openai_api_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'openai_api_key': valid_openai_api_key}),
is_valid=True,
)
def get_mock_openai_whisper_model():
model_name = 'whisper-1'
valid_openai_api_key = os.environ['OPENAI_API_KEY']
openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key))
return OpenAIWhisper(
model_provider=openai_provider,
name=model_name
)
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
return encrypted_openai_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt):
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Construct the path to the audio file
audio_file_path = os.path.join(current_dir, 'audio.mp3')
model = get_mock_openai_whisper_model()
# Open the file and get the file object
with open(audio_file_path, 'rb') as audio_file:
rst = model.run(audio_file)
assert isinstance(rst, dict)
assert rst['text'] == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'