feat: server multi models support (#799)

This commit is contained in:
takatost
2023-08-12 00:57:00 +08:00
committed by GitHub
parent d8b712b325
commit 5fa2161b05
213 changed files with 10556 additions and 2579 deletions

View File

@@ -1,50 +0,0 @@
# -*- coding:utf-8 -*-
import pytest
import flask_migrate
from app import create_app
from extensions.ext_database import db
@pytest.fixture(scope='module')
def test_client():
# Create a Flask app configured for testing
from config import TestConfig
flask_app = create_app(TestConfig())
flask_app.config.from_object('config.TestingConfig')
# Create a test client using the Flask application configured for testing
with flask_app.test_client() as testing_client:
# Establish an application context
with flask_app.app_context():
yield testing_client # this is where the testing happens!
@pytest.fixture(scope='module')
def init_database(test_client):
# Initialize the database
with test_client.application.app_context():
flask_migrate.upgrade()
yield # this is where the testing happens!
# Clean up the database
with test_client.application.app_context():
flask_migrate.downgrade()
@pytest.fixture(scope='module')
def db_session(test_client):
with test_client.application.app_context():
yield db.session
@pytest.fixture(scope='function')
def login_default_user(test_client):
# todo
yield # this is where the testing happens!
test_client.get('/logout', follow_redirects=True)

View File

@@ -0,0 +1,35 @@
# OpenAI API Key
OPENAI_API_KEY=
# Azure OpenAI API Base Endpoint & API Key
AZURE_OPENAI_API_BASE=
AZURE_OPENAI_API_KEY=
# Anthropic API Key
ANTHROPIC_API_KEY=
# Replicate API Key
REPLICATE_API_TOKEN=
# Hugging Face API Key
HUGGINGFACE_API_KEY=
HUGGINGFACE_ENDPOINT_URL=
# Minimax Credentials
MINIMAX_API_KEY=
MINIMAX_GROUP_ID=
# Spark Credentials
SPARK_APP_ID=
SPARK_API_KEY=
SPARK_API_SECRET=
# Tongyi Credentials
TONGYI_DASHSCOPE_API_KEY=
# Wenxin Credentials
WENXIN_API_KEY=
WENXIN_SECRET_KEY=
# ChatGLM Credentials
CHATGLM_API_BASE=

View File

@@ -0,0 +1,19 @@
import os
# Getting the absolute path of the current file's directory
ABS_PATH = os.path.dirname(os.path.abspath(__file__))
# Getting the absolute path of the project's root directory
PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir))
# Loading the .env file if it exists
def _load_env() -> None:
dotenv_path = os.path.join(PROJECT_DIR, "tests", "integration_tests", ".env")
if os.path.exists(dotenv_path):
from dotenv import load_dotenv
load_dotenv(dotenv_path)
_load_env()

View File

@@ -0,0 +1,57 @@
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='azure_openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config='',
is_valid=True,
)
def get_mock_azure_openai_embedding_model(mocker):
model_name = 'text-embedding-ada-002'
valid_openai_api_base = os.environ['AZURE_OPENAI_API_BASE']
valid_openai_api_key = os.environ['AZURE_OPENAI_API_KEY']
openai_provider = AzureOpenAIProvider(provider=get_mock_provider())
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
provider_name='azure_openai',
model_name=model_name,
model_type=ModelType.EMBEDDINGS.value,
encrypted_config=json.dumps({
'openai_api_base': valid_openai_api_base,
'openai_api_key': valid_openai_api_key,
'base_model_name': model_name
}),
is_valid=True,
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
return AzureOpenAIEmbedding(
model_provider=openai_provider,
name=model_name
)
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
return encrypted_openai_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embedding(mock_decrypt, mocker):
embedding_model = get_mock_azure_openai_embedding_model(mocker)
rst = embedding_model.client.embed_query('test')
assert isinstance(rst, list)
assert len(rst) == 1536

View File

@@ -0,0 +1,44 @@
import json
import os
from unittest.mock import patch
from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding
from core.model_providers.providers.minimax_provider import MinimaxProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_group_id, valid_api_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='minimax',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({
'minimax_group_id': valid_group_id,
'minimax_api_key': valid_api_key
}),
is_valid=True,
)
def get_mock_embedding_model():
model_name = 'embo-01'
valid_api_key = os.environ['MINIMAX_API_KEY']
valid_group_id = os.environ['MINIMAX_GROUP_ID']
provider = MinimaxProvider(provider=get_mock_provider(valid_group_id, valid_api_key))
return MinimaxEmbedding(
model_provider=provider,
name=model_name
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embedding(mock_decrypt):
embedding_model = get_mock_embedding_model()
rst = embedding_model.client.embed_query('test')
assert isinstance(rst, list)
assert len(rst) == 1536

View File

@@ -0,0 +1,40 @@
import json
import os
from unittest.mock import patch
from core.model_providers.providers.openai_provider import OpenAIProvider
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from models.provider import Provider, ProviderType
def get_mock_provider(valid_openai_api_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'openai_api_key': valid_openai_api_key}),
is_valid=True,
)
def get_mock_openai_embedding_model():
model_name = 'text-embedding-ada-002'
valid_openai_api_key = os.environ['OPENAI_API_KEY']
openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key))
return OpenAIEmbedding(
model_provider=openai_provider,
name=model_name
)
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
return encrypted_openai_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embedding(mock_decrypt):
embedding_model = get_mock_openai_embedding_model()
rst = embedding_model.client.embed_query('test')
assert isinstance(rst, list)
assert len(rst) == 1536

View File

@@ -0,0 +1,64 @@
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.embedding.replicate_embedding import ReplicateEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.replicate_provider import ReplicateProvider
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='replicate',
provider_type=ProviderType.CUSTOM.value,
encrypted_config='',
is_valid=True,
)
def get_mock_embedding_model(mocker):
model_name = 'replicate/all-mpnet-base-v2'
valid_api_key = os.environ['REPLICATE_API_TOKEN']
model_provider = ReplicateProvider(provider=get_mock_provider())
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
provider_name='replicate',
model_name=model_name,
model_type=ModelType.EMBEDDINGS.value,
encrypted_config=json.dumps({
'replicate_api_token': valid_api_key,
'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305'
}),
is_valid=True,
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
return ReplicateEmbedding(
model_provider=model_provider,
name=model_name
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_documents(mock_decrypt, mocker):
embedding_model = get_mock_embedding_model(mocker)
rst = embedding_model.client.embed_documents(['test', 'test1'])
assert isinstance(rst, list)
assert len(rst) == 2
assert len(rst[0]) == 768
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_query(mock_decrypt, mocker):
embedding_model = get_mock_embedding_model(mocker)
rst = embedding_model.client.embed_query('test')
assert isinstance(rst, list)
assert len(rst) == 768

View File

@@ -0,0 +1,61 @@
import json
import os
from unittest.mock import patch
from langchain.schema import ChatGeneration, AIMessage
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.anthropic_model import AnthropicModel
from core.model_providers.providers.anthropic_provider import AnthropicProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_api_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='anthropic',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'anthropic_api_key': valid_api_key}),
is_valid=True,
)
def get_mock_model(model_name):
model_kwargs = ModelKwargs(
max_tokens=10,
temperature=0
)
valid_api_key = os.environ['ANTHROPIC_API_KEY']
model_provider = AnthropicProvider(provider=get_mock_provider(valid_api_key))
return AnthropicModel(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('claude-2')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst == 6
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt):
model = get_mock_model('claude-2')
messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: ')]
rst = model.run(
messages,
stop=['\nHuman:'],
)
assert len(rst.content) > 0
assert rst.content.strip() == '2'

View File

@@ -0,0 +1,86 @@
import json
import os
from unittest.mock import patch, MagicMock
import pytest
from langchain.schema import ChatGeneration, AIMessage
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='azure_openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config='',
is_valid=True,
)
def get_mock_azure_openai_model(model_name, mocker):
model_kwargs = ModelKwargs(
max_tokens=10,
temperature=0
)
valid_openai_api_base = os.environ['AZURE_OPENAI_API_BASE']
valid_openai_api_key = os.environ['AZURE_OPENAI_API_KEY']
provider = AzureOpenAIProvider(provider=get_mock_provider())
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
provider_name='azure_openai',
model_name=model_name,
model_type=ModelType.TEXT_GENERATION.value,
encrypted_config=json.dumps({
'openai_api_base': valid_openai_api_base,
'openai_api_key': valid_openai_api_key,
'base_model_name': model_name
}),
is_valid=True,
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
return AzureOpenAIModel(
model_provider=provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
return encrypted_openai_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt, mocker):
openai_model = get_mock_azure_openai_model('text-davinci-003', mocker)
rst = openai_model.get_num_tokens([PromptMessage(content='you are a kindness Assistant.')])
assert rst == 6
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_get_num_tokens(mock_decrypt, mocker):
openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker)
rst = openai_model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst == 22
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker)
messages = [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')]
rst = openai_model.run(
messages,
stop=['\nHuman:'],
)
assert len(rst.content) > 0
assert rst.content.strip() == 'n'

View File

@@ -0,0 +1,124 @@
import json
import os
from unittest.mock import patch, MagicMock
from langchain.schema import Generation
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='huggingface_hub',
provider_type=ProviderType.CUSTOM.value,
encrypted_config='',
is_valid=True,
)
def get_mock_model(model_name, huggingfacehub_api_type, mocker):
model_kwargs = ModelKwargs(
max_tokens=10,
temperature=0.01
)
valid_api_key = os.environ['HUGGINGFACE_API_KEY']
endpoint_url = os.environ['HUGGINGFACE_ENDPOINT_URL']
model_provider = HuggingfaceHubProvider(provider=get_mock_provider())
credentials = {
'huggingfacehub_api_type': huggingfacehub_api_type,
'huggingfacehub_api_token': valid_api_key
}
if huggingfacehub_api_type == 'inference_endpoints':
credentials['huggingfacehub_endpoint_url'] = endpoint_url
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
provider_name='huggingface_hub',
model_name=model_name,
model_type=ModelType.TEXT_GENERATION.value,
encrypted_config=json.dumps(credentials),
is_valid=True,
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
return HuggingfaceHubModel(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('huggingface_hub.hf_api.ModelInfo')
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_hosted_inference_api_get_num_tokens(mock_decrypt, mock_model_info, mocker):
mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation')
mocker.patch('langchain.llms.huggingface_hub.HuggingFaceHub._call', return_value="abc")
model = get_mock_model(
'tiiuae/falcon-40b',
'hosted_inference_api',
mocker
)
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst == 5
@patch('huggingface_hub.hf_api.ModelInfo')
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_inference_endpoints_get_num_tokens(mock_decrypt, mock_model_info, mocker):
mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation')
mocker.patch('langchain.llms.huggingface_hub.HuggingFaceHub._call', return_value="abc")
model = get_mock_model(
'',
'inference_endpoints',
mocker
)
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst == 5
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_hosted_inference_api_run(mock_decrypt, mocker):
model = get_mock_model(
'google/flan-t5-base',
'hosted_inference_api',
mocker
)
rst = model.run(
[PromptMessage(content='Human: Are you Really Human? you MUST only answer `y` or `n`? \nAssistant: ')],
stop=['\nHuman:'],
)
assert len(rst.content) > 0
assert rst.content.strip() == 'n'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_inference_endpoints_run(mock_decrypt, mocker):
model = get_mock_model(
'',
'inference_endpoints',
mocker
)
rst = model.run(
[PromptMessage(content='Answer the following yes/no question. Can you write a whole Haiku in a single tweet?')],
)
assert len(rst.content) > 0
assert rst.content.strip() == 'no'

View File

@@ -0,0 +1,64 @@
import json
import os
from unittest.mock import patch
from langchain.schema import ChatGeneration, AIMessage, Generation
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.minimax_model import MinimaxModel
from core.model_providers.providers.minimax_provider import MinimaxProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_group_id, valid_api_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='minimax',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({
'minimax_group_id': valid_group_id,
'minimax_api_key': valid_api_key
}),
is_valid=True,
)
def get_mock_model(model_name):
model_kwargs = ModelKwargs(
max_tokens=10,
temperature=0.01
)
valid_api_key = os.environ['MINIMAX_API_KEY']
valid_group_id = os.environ['MINIMAX_GROUP_ID']
model_provider = MinimaxProvider(provider=get_mock_provider(valid_group_id, valid_api_key))
return MinimaxModel(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('abab5.5-chat')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst == 5
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt):
model = get_mock_model('abab5.5-chat')
rst = model.run(
[PromptMessage(content='Human: Are you a real Human? you MUST only answer `y` or `n`? \nAssistant: ')],
stop=['\nHuman:'],
)
assert len(rst.content) > 0
assert rst.content.strip() == 'n'

View File

@@ -0,0 +1,80 @@
import json
import os
from unittest.mock import patch
from langchain.schema import Generation, ChatGeneration, AIMessage
from core.model_providers.providers.openai_provider import OpenAIProvider
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.openai_model import OpenAIModel
from models.provider import Provider, ProviderType
def get_mock_provider(valid_openai_api_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'openai_api_key': valid_openai_api_key}),
is_valid=True,
)
def get_mock_openai_model(model_name):
model_kwargs = ModelKwargs(
max_tokens=10,
temperature=0
)
model_name = model_name
valid_openai_api_key = os.environ['OPENAI_API_KEY']
openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key))
return OpenAIModel(
model_provider=openai_provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
return encrypted_openai_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt):
openai_model = get_mock_openai_model('text-davinci-003')
rst = openai_model.get_num_tokens([PromptMessage(content='you are a kindness Assistant.')])
assert rst == 6
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_get_num_tokens(mock_decrypt):
openai_model = get_mock_openai_model('gpt-3.5-turbo')
rst = openai_model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst == 22
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt):
openai_model = get_mock_openai_model('text-davinci-003')
rst = openai_model.run(
[PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
stop=['\nHuman:'],
)
assert len(rst.content) > 0
assert rst.content.strip() == 'n'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_run(mock_decrypt):
openai_model = get_mock_openai_model('gpt-3.5-turbo')
messages = [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')]
rst = openai_model.run(
messages,
stop=['\nHuman:'],
)
assert len(rst.content) > 0
assert rst.content.strip() == 'n'

View File

@@ -0,0 +1,73 @@
import json
import os
from unittest.mock import patch, MagicMock
from langchain.schema import Generation
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.replicate_model import ReplicateModel
from core.model_providers.providers.replicate_provider import ReplicateProvider
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='replicate',
provider_type=ProviderType.CUSTOM.value,
encrypted_config='',
is_valid=True,
)
def get_mock_model(model_name, model_version, mocker):
model_kwargs = ModelKwargs(
max_tokens=10,
temperature=0.01
)
valid_api_key = os.environ['REPLICATE_API_TOKEN']
model_provider = ReplicateProvider(provider=get_mock_provider())
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
provider_name='replicate',
model_name=model_name,
model_type=ModelType.TEXT_GENERATION.value,
encrypted_config=json.dumps({
'replicate_api_token': valid_api_key,
'model_version': model_version
}),
is_valid=True,
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
return ReplicateModel(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt, mocker):
model = get_mock_model('a16z-infra/llama-2-13b-chat', '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', mocker)
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst == 7
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
model = get_mock_model('a16z-infra/llama-2-13b-chat', '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', mocker)
messages = [PromptMessage(content='Human: 1+1=? \nAnswer: ')]
rst = model.run(
messages
)
assert len(rst.content) > 0

View File

@@ -0,0 +1,69 @@
import json
import os
from unittest.mock import patch
from langchain.schema import ChatGeneration, AIMessage, Generation
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.minimax_model import MinimaxModel
from core.model_providers.models.llm.spark_model import SparkModel
from core.model_providers.providers.minimax_provider import MinimaxProvider
from core.model_providers.providers.spark_provider import SparkProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_app_id, valid_api_key, valid_api_secret):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='spark',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({
'app_id': valid_app_id,
'api_key': valid_api_key,
'api_secret': valid_api_secret,
}),
is_valid=True,
)
def get_mock_model(model_name):
model_kwargs = ModelKwargs(
max_tokens=10,
temperature=0.01
)
valid_app_id = os.environ['SPARK_APP_ID']
valid_api_key = os.environ['SPARK_API_KEY']
valid_api_secret = os.environ['SPARK_API_SECRET']
model_provider = SparkProvider(provider=get_mock_provider(valid_app_id, valid_api_key, valid_api_secret))
return SparkModel(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('spark')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst == 6
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt):
model = get_mock_model('spark')
messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
rst = model.run(
messages,
stop=['\nHuman:'],
)
assert len(rst.content) > 0
assert rst.content.strip() == '2'

View File

@@ -0,0 +1,61 @@
import json
import os
from unittest.mock import patch
from langchain.schema import ChatGeneration, AIMessage, Generation
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.tongyi_model import TongyiModel
from core.model_providers.providers.tongyi_provider import TongyiProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_api_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='tongyi',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({
'dashscope_api_key': valid_api_key,
}),
is_valid=True,
)
def get_mock_model(model_name):
model_kwargs = ModelKwargs(
max_tokens=10,
temperature=0.01
)
valid_api_key = os.environ['TONGYI_DASHSCOPE_API_KEY']
model_provider = TongyiProvider(provider=get_mock_provider(valid_api_key))
return TongyiModel(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('qwen-v1')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst == 5
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt):
model = get_mock_model('qwen-v1')
rst = model.run(
[PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
stop=['\nHuman:'],
)
assert len(rst.content) > 0

View File

@@ -0,0 +1,63 @@
import json
import os
from unittest.mock import patch
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.wenxin_model import WenxinModel
from core.model_providers.providers.wenxin_provider import WenxinProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_api_key, valid_secret_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='wenxin',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({
'api_key': valid_api_key,
'secret_key': valid_secret_key,
}),
is_valid=True,
)
def get_mock_model(model_name):
model_kwargs = ModelKwargs(
temperature=0.01
)
valid_api_key = os.environ['WENXIN_API_KEY']
valid_secret_key = os.environ['WENXIN_SECRET_KEY']
model_provider = WenxinProvider(provider=get_mock_provider(valid_api_key, valid_secret_key))
return WenxinModel(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('ernie-bot')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst == 5
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt):
model = get_mock_model('ernie-bot')
messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
rst = model.run(
messages,
stop=['\nHuman:'],
)
assert len(rst.content) > 0
assert rst.content.strip() == '2'

View File

@@ -0,0 +1,40 @@
import json
import os
from unittest.mock import patch
from core.model_providers.models.moderation.openai_moderation import OpenAIModeration, DEFAULT_AUDIO_MODEL
from core.model_providers.providers.openai_provider import OpenAIProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_openai_api_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'openai_api_key': valid_openai_api_key}),
is_valid=True,
)
def get_mock_openai_moderation_model():
valid_openai_api_key = os.environ['OPENAI_API_KEY']
openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key))
return OpenAIModeration(
model_provider=openai_provider,
name=DEFAULT_AUDIO_MODEL
)
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
return encrypted_openai_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt):
model = get_mock_openai_moderation_model()
rst = model.run('hello')
assert isinstance(rst, dict)
assert 'id' in rst

View File

@@ -0,0 +1,50 @@
import json
import os
from unittest.mock import patch
from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper
from core.model_providers.providers.openai_provider import OpenAIProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_openai_api_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'openai_api_key': valid_openai_api_key}),
is_valid=True,
)
def get_mock_openai_whisper_model():
model_name = 'whisper-1'
valid_openai_api_key = os.environ['OPENAI_API_KEY']
openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key))
return OpenAIWhisper(
model_provider=openai_provider,
name=model_name
)
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
return encrypted_openai_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt):
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Construct the path to the audio file
audio_file_path = os.path.join(current_dir, 'audio.mp3')
model = get_mock_openai_whisper_model()
# Open the file and get the file object
with open(audio_file_path, 'rb') as audio_file:
rst = model.run(audio_file)
assert isinstance(rst, dict)
assert rst['text'] == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'

View File

@@ -1,75 +0,0 @@
import json
import pytest
from flask import url_for
from models.model import Account
# Sample user data for testing
sample_user_data = {
'name': 'Test User',
'email': 'test@example.com',
'interface_language': 'en-US',
'interface_theme': 'light',
'timezone': 'America/New_York',
'password': 'testpassword',
'new_password': 'newtestpassword',
'repeat_new_password': 'newtestpassword'
}
# Create a test user and log them in
@pytest.fixture(scope='function')
def logged_in_user(client, session):
# Create test user and add them to the database
# Replace this with your actual User model and any required fields
# todo refer to api.controllers.setup.SetupApi.post() to create a user
db_user_data = sample_user_data.copy()
db_user_data['password_salt'] = 'testpasswordsalt'
del db_user_data['new_password']
del db_user_data['repeat_new_password']
test_user = Account(**db_user_data)
session.add(test_user)
session.commit()
# Log in the test user
client.post(url_for('console.loginapi'), data={'email': sample_user_data['email'], 'password': sample_user_data['password']})
return test_user
def test_account_profile(logged_in_user, client):
response = client.get(url_for('console.accountprofileapi'))
assert response.status_code == 200
assert json.loads(response.data)['name'] == sample_user_data['name']
def test_account_name(logged_in_user, client):
new_name = 'New Test User'
response = client.post(url_for('console.accountnameapi'), json={'name': new_name})
assert response.status_code == 200
assert json.loads(response.data)['name'] == new_name
def test_account_interface_language(logged_in_user, client):
new_language = 'zh-CN'
response = client.post(url_for('console.accountinterfacelanguageapi'), json={'interface_language': new_language})
assert response.status_code == 200
assert json.loads(response.data)['interface_language'] == new_language
def test_account_interface_theme(logged_in_user, client):
new_theme = 'dark'
response = client.post(url_for('console.accountinterfacethemeapi'), json={'interface_theme': new_theme})
assert response.status_code == 200
assert json.loads(response.data)['interface_theme'] == new_theme
def test_account_timezone(logged_in_user, client):
new_timezone = 'Asia/Shanghai'
response = client.post(url_for('console.accounttimezoneapi'), json={'timezone': new_timezone})
assert response.status_code == 200
assert json.loads(response.data)['timezone'] == new_timezone
def test_account_password(logged_in_user, client):
response = client.post(url_for('console.accountpasswordapi'), json={
'password': sample_user_data['password'],
'new_password': sample_user_data['new_password'],
'repeat_new_password': sample_user_data['repeat_new_password']
})
assert response.status_code == 200
assert json.loads(response.data)['result'] == 'success'

View File

@@ -1,108 +0,0 @@
import pytest
from app import create_app, db
from flask_login import current_user
from models.model import Account, TenantAccountJoin, Tenant
@pytest.fixture
def client(test_client, db_session):
app = create_app()
app.config["TESTING"] = True
with app.app_context():
db.create_all()
yield test_client
db.drop_all()
def test_login_api_post(client, db_session):
# create a tenant, account, and tenant account join
tenant = Tenant(name="Test Tenant", status="normal")
account = Account(email="test@test.com", name="Test User")
account.password_salt = "uQ7K0/0wUJ7VPhf3qBzwNQ=="
account.password = "A9YpfzjK7c/tOwzamrvpJg=="
db.session.add_all([tenant, account])
db.session.flush()
tenant_account_join = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, is_tenant_owner=True)
db.session.add(tenant_account_join)
db.session.commit()
# login with correct credentials
response = client.post("/login", json={
"email": "test@test.com",
"password": "Abc123456",
"remember_me": True
})
assert response.status_code == 200
assert response.json == {"result": "success"}
assert current_user == account
assert 'tenant_id' in client.session
assert client.session['tenant_id'] == tenant.id
# login with incorrect password
response = client.post("/login", json={
"email": "test@test.com",
"password": "wrong_password",
"remember_me": True
})
assert response.status_code == 401
# login with non-existent account
response = client.post("/login", json={
"email": "non_existent_account@test.com",
"password": "Abc123456",
"remember_me": True
})
assert response.status_code == 401
def test_logout_api_get(client, db_session):
# create a tenant, account, and tenant account join
tenant = Tenant(name="Test Tenant", status="normal")
account = Account(email="test@test.com", name="Test User")
db.session.add_all([tenant, account])
db.session.flush()
tenant_account_join = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, is_tenant_owner=True)
db.session.add(tenant_account_join)
db.session.commit()
# login and check if session variable and current_user are set
with client.session_transaction() as session:
session['tenant_id'] = tenant.id
client.post("/login", json={
"email": "test@test.com",
"password": "Abc123456",
"remember_me": True
})
assert current_user == account
assert 'tenant_id' in client.session
assert client.session['tenant_id'] == tenant.id
# logout and check if session variable and current_user are unset
response = client.get("/logout")
assert response.status_code == 200
assert current_user.is_authenticated is False
assert 'tenant_id' not in client.session
def test_reset_password_api_get(client, db_session):
# create a tenant, account, and tenant account join
tenant = Tenant(name="Test Tenant", status="normal")
account = Account(email="test@test.com", name="Test User")
db.session.add_all([tenant, account])
db.session.flush()
tenant_account_join = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, is_tenant_owner=True)
db.session.add(tenant_account_join)
db.session.commit()
# reset password in cloud edition
app = client.application
app.config["CLOUD_EDITION"] = True
response = client.get("/reset_password")
assert response.status_code == 200
assert response.json == {"result": "success"}
# reset password in non-cloud edition
app.config["CLOUD_EDITION"] = False
response = client.get("/reset_password")
assert response.status_code == 200
assert response.json == {"result": "success"}

View File

@@ -1,80 +0,0 @@
import os
import pytest
from models.model import Account, Tenant, TenantAccountJoin
def test_setup_api_get(test_client,db_session):
response = test_client.get("/setup")
assert response.status_code == 200
assert response.json == {"step": "not_start"}
# create a tenant and check again
tenant = Tenant(name="Test Tenant", status="normal")
db_session.add(tenant)
db_session.commit()
response = test_client.get("/setup")
assert response.status_code == 200
assert response.json == {"step": "step2"}
# create setup file and check again
response = test_client.get("/setup")
assert response.status_code == 200
assert response.json == {"step": "finished"}
def test_setup_api_post(test_client):
response = test_client.post("/setup", json={
"email": "test@test.com",
"name": "Test User",
"password": "Abc123456"
})
assert response.status_code == 200
assert response.json == {"result": "success", "next_step": "step2"}
# check if the tenant, account, and tenant account join records were created
tenant = Tenant.query.first()
assert tenant.name == "Test User's LLM Factory"
assert tenant.status == "normal"
assert tenant.encrypt_public_key
account = Account.query.first()
assert account.email == "test@test.com"
assert account.name == "Test User"
assert account.password_salt
assert account.password
assert TenantAccountJoin.query.filter_by(account_id=account.id, is_tenant_owner=True).count() == 1
# check if password is encrypted correctly
salt = account.password_salt.encode()
password_hashed = account.password.encode()
assert account.password == base64.b64encode(hash_password("Abc123456", salt)).decode()
def test_setup_step2_api_post(test_client,db_session):
# create a tenant, account, and setup file
tenant = Tenant(name="Test Tenant", status="normal")
account = Account(email="test@test.com", name="Test User")
db_session.add_all([tenant, account])
db_session.commit()
# try to set up with incorrect language
response = test_client.post("/setup/step2", json={
"interface_language": "invalid_language",
"timezone": "Asia/Shanghai"
})
assert response.status_code == 400
# set up successfully
response = test_client.post("/setup/step2", json={
"interface_language": "en",
"timezone": "Asia/Shanghai"
})
assert response.status_code == 200
assert response.json == {"result": "success", "next_step": "finished"}
# check if account was updated correctly
account = Account.query.first()
assert account.interface_language == "en"
assert account.timezone == "Asia/Shanghai"
assert account.interface_theme == "light"
assert account.last_login_ip == "127.0.0.1"

View File

@@ -1,22 +0,0 @@
# -*- coding:utf-8 -*-
import pytest
from app import create_app
def test_create_app():
# Test Default(CE) Config
app = create_app()
assert app.config['SECRET_KEY'] is not None
assert app.config['SQLALCHEMY_DATABASE_URI'] is not None
assert app.config['EDITION'] == "SELF_HOSTED"
# Test TestConfig
from config import TestConfig
test_app = create_app(TestConfig())
assert test_app.config['SECRET_KEY'] is not None
assert test_app.config['SQLALCHEMY_DATABASE_URI'] is not None
assert test_app.config['TESTING'] is True

View File

View File

@@ -0,0 +1,44 @@
from typing import Type
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
from core.model_providers.models.llm.openai_model import OpenAIModel
from core.model_providers.providers.base import BaseModelProvider
class FakeModelProvider(BaseModelProvider):
@property
def provider_name(self):
return 'fake'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return [{'id': 'test_model', 'name': 'Test Model'}]
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
return OpenAIModel
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
pass
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
return {}
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
pass
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
return credentials
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
return ModelKwargsRules()
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
return {}

View File

@@ -0,0 +1,123 @@
from typing import List, Optional, Any
import anthropic
import httpx
import pytest
from unittest.mock import patch
import json
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.schema import BaseMessage, ChatResult, ChatGeneration, AIMessage
from core.model_providers.providers.anthropic_provider import AnthropicProvider
from core.model_providers.providers.base import CredentialsValidateFailedError
from models.provider import ProviderType, Provider
PROVIDER_NAME = 'anthropic'
MODEL_PROVIDER_CLASS = AnthropicProvider
VALIDATE_CREDENTIAL_KEY = 'anthropic_api_key'
def mock_chat_generate(messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content='answer'))])
def mock_chat_generate_invalid(messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any):
raise anthropic.APIStatusError('Invalid credentials',
request=httpx._models.Request(
method='POST',
url='https://api.anthropic.com/v1/completions',
),
response=httpx._models.Response(
status_code=401,
),
body=None
)
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
@patch('langchain.chat_models.ChatAnthropic._generate', side_effect=mock_chat_generate)
def test_is_provider_credentials_valid_or_raise_valid(mock_create):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({VALIDATE_CREDENTIAL_KEY: 'valid_key'})
@patch('langchain.chat_models.ChatAnthropic._generate', side_effect=mock_chat_generate_invalid)
def test_is_provider_credentials_valid_or_raise_invalid(mock_create):
# raise CredentialsValidateFailedError if anthropic_api_key is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
# raise CredentialsValidateFailedError if anthropic_api_key is invalid
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({VALIDATE_CREDENTIAL_KEY: 'invalid_key'})
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
api_key = 'valid_key'
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', {VALIDATE_CREDENTIAL_KEY: api_key})
mock_encrypt.assert_called_with('tenant_id', api_key)
assert result[VALIDATE_CREDENTIAL_KEY] == f'encrypted_{api_key}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({VALIDATE_CREDENTIAL_KEY: 'encrypted_valid_key'}),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials()
assert result[VALIDATE_CREDENTIAL_KEY] == 'valid_key'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
api_key = 'valid_key'
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({VALIDATE_CREDENTIAL_KEY: f'encrypted_{api_key}'}),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials(obfuscated=True)
middle_token = result[VALIDATE_CREDENTIAL_KEY][6:-2]
assert len(middle_token) == max(len(api_key) - 8, 0)
assert all(char == '*' for char in middle_token)
@patch('core.model_providers.providers.hosted.hosted_model_providers.anthropic')
def test_get_credentials_hosted(mock_hosted):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.SYSTEM.value,
encrypted_config='',
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
mock_hosted.api_key = 'hosted_key'
result = model_provider.get_provider_credentials()
assert result[VALIDATE_CREDENTIAL_KEY] == 'hosted_key'

View File

@@ -0,0 +1,117 @@
import pytest
from unittest.mock import patch, MagicMock
import json
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
from core.model_providers.providers.base import CredentialsValidateFailedError
from models.provider import ProviderType, Provider, ProviderModel
PROVIDER_NAME = 'azure_openai'
MODEL_PROVIDER_CLASS = AzureOpenAIProvider
VALIDATE_CREDENTIAL = {
'openai_api_base': 'https://xxxx.openai.azure.com/',
'openai_api_key': 'valid_key',
'base_model_name': 'gpt-35-turbo'
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_model_credentials_valid_or_raise(mocker):
mocker.patch('langchain.chat_models.base.BaseChatModel.generate', return_value=None)
# assert True if credentials is valid
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials=VALIDATE_CREDENTIAL
)
def test_is_model_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if credentials is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials={}
)
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_model_credentials(mock_encrypt):
openai_api_key = 'valid_key'
result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
tenant_id='tenant_id',
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials={'openai_api_key': openai_api_key}
)
mock_encrypt.assert_called_with('tenant_id', openai_api_key)
assert result['openai_api_key'] == f'encrypted_{openai_api_key}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_custom(mock_decrypt, mocker):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=None,
is_valid=True,
)
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['openai_api_key'] = 'encrypted_' + encrypted_credential['openai_api_key']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
encrypted_config=json.dumps(encrypted_credential)
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_model_credentials(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION
)
assert result['openai_api_key'] == 'valid_key'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=None,
is_valid=True,
)
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['openai_api_key'] = 'encrypted_' + encrypted_credential['openai_api_key']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
encrypted_config=json.dumps(encrypted_credential)
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_model_credentials(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
obfuscated=True
)
middle_token = result['openai_api_key'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['openai_api_key']) - 8, 0)
assert all(char == '*' for char in middle_token)

View File

@@ -0,0 +1,72 @@
from unittest.mock import MagicMock
import pytest
from core.model_providers.error import QuotaExceededError
from core.model_providers.models.entity.model_params import ModelType
from models.provider import Provider, ProviderType
from tests.unit_tests.model_providers.fake_model_provider import FakeModelProvider
def test_get_supported_model_list(mocker):
mocker.patch.object(
FakeModelProvider,
'get_rules',
return_value={'support_provider_types': ['custom'], 'model_flexibility': 'configurable'}
)
mock_provider_model = MagicMock()
mock_provider_model.model_name = 'test_model'
mock_query = MagicMock()
mock_query.filter.return_value.order_by.return_value.all.return_value = [mock_provider_model]
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
provider = FakeModelProvider(provider=Provider())
result = provider.get_supported_model_list(ModelType.TEXT_GENERATION)
assert result == [{'id': 'test_model', 'name': 'test_model'}]
def test_check_quota_over_limit(mocker):
mocker.patch.object(
FakeModelProvider,
'get_rules',
return_value={'support_provider_types': ['system']}
)
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = None
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
provider = FakeModelProvider(provider=Provider(provider_type=ProviderType.SYSTEM.value))
with pytest.raises(QuotaExceededError):
provider.check_quota_over_limit()
def test_check_quota_not_over_limit(mocker):
mocker.patch.object(
FakeModelProvider,
'get_rules',
return_value={'support_provider_types': ['system']}
)
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = Provider()
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
provider = FakeModelProvider(provider=Provider(provider_type=ProviderType.SYSTEM.value))
assert provider.check_quota_over_limit() is None
def test_check_custom_quota_over_limit(mocker):
mocker.patch.object(
FakeModelProvider,
'get_rules',
return_value={'support_provider_types': ['custom']}
)
provider = FakeModelProvider(provider=Provider(provider_type=ProviderType.CUSTOM.value))
assert provider.check_quota_over_limit() is None

View File

@@ -0,0 +1,89 @@
import pytest
from unittest.mock import patch
import json
from langchain.schema import LLMResult, Generation, AIMessage, ChatResult, ChatGeneration
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.chatglm_provider import ChatGLMProvider
from core.model_providers.providers.spark_provider import SparkProvider
from models.provider import ProviderType, Provider
PROVIDER_NAME = 'chatglm'
MODEL_PROVIDER_CLASS = ChatGLMProvider
VALIDATE_CREDENTIAL = {
'api_base': 'valid_api_base',
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_provider_credentials_valid_or_raise_valid(mocker):
mocker.patch('langchain.llms.chatglm.ChatGLM._call',
return_value="abc")
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
def test_is_provider_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if api_key is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
credential = VALIDATE_CREDENTIAL.copy()
credential['api_base'] = 'invalid_api_base'
# raise CredentialsValidateFailedError if api_key is invalid
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
assert result['api_base'] == f'encrypted_{VALIDATE_CREDENTIAL["api_base"]}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['api_base'] = 'encrypted_' + encrypted_credential['api_base']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials()
assert result['api_base'] == 'valid_api_base'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['api_base'] = 'encrypted_' + encrypted_credential['api_base']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials(obfuscated=True)
middle_token = result['api_base'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_base']) - 8, 0)
assert all(char == '*' for char in middle_token)

View File

@@ -0,0 +1,161 @@
import pytest
from unittest.mock import patch, MagicMock
import json
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
from models.provider import ProviderType, Provider, ProviderModel
PROVIDER_NAME = 'huggingface_hub'
MODEL_PROVIDER_CLASS = HuggingfaceHubProvider
HOSTED_INFERENCE_API_VALIDATE_CREDENTIAL = {
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': 'valid_key'
}
INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL = {
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': 'valid_key',
'huggingfacehub_endpoint_url': 'valid_url'
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
@patch('huggingface_hub.hf_api.ModelInfo')
def test_hosted_inference_api_is_credentials_valid_or_raise_valid(mock_model_info, mocker):
mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation')
mocker.patch('langchain.llms.huggingface_hub.HuggingFaceHub._call', return_value="abc")
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials=HOSTED_INFERENCE_API_VALIDATE_CREDENTIAL
)
@patch('huggingface_hub.hf_api.ModelInfo')
def test_hosted_inference_api_is_credentials_valid_or_raise_invalid(mock_model_info):
mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation')
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials={}
)
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
})
def test_inference_endpoints_is_credentials_valid_or_raise_valid(mocker):
mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None)
mocker.patch('langchain.llms.huggingface_endpoint.HuggingFaceEndpoint._call', return_value="abc")
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials=INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL
)
def test_inference_endpoints_is_credentials_valid_or_raise_invalid(mocker):
mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None)
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials={}
)
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_endpoint_url': 'valid_url'
})
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_model_credentials(mock_encrypt):
api_key = 'valid_key'
result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
tenant_id='tenant_id',
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials=INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL.copy()
)
mock_encrypt.assert_called_with('tenant_id', api_key)
assert result['huggingfacehub_api_token'] == f'encrypted_{api_key}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_custom(mock_decrypt, mocker):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=None,
is_valid=True,
)
encrypted_credential = INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL.copy()
encrypted_credential['huggingfacehub_api_token'] = 'encrypted_' + encrypted_credential['huggingfacehub_api_token']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
encrypted_config=json.dumps(encrypted_credential)
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_model_credentials(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION
)
assert result['huggingfacehub_api_token'] == 'valid_key'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=None,
is_valid=True,
)
encrypted_credential = INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL.copy()
encrypted_credential['huggingfacehub_api_token'] = 'encrypted_' + encrypted_credential['huggingfacehub_api_token']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
encrypted_config=json.dumps(encrypted_credential)
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_model_credentials(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
obfuscated=True
)
middle_token = result['huggingfacehub_api_token'][6:-2]
assert len(middle_token) == max(len(INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL['huggingfacehub_api_token']) - 8, 0)
assert all(char == '*' for char in middle_token)

View File

@@ -0,0 +1,88 @@
import pytest
from unittest.mock import patch
import json
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.minimax_provider import MinimaxProvider
from models.provider import ProviderType, Provider
PROVIDER_NAME = 'minimax'
MODEL_PROVIDER_CLASS = MinimaxProvider
VALIDATE_CREDENTIAL = {
'minimax_group_id': 'fake-group-id',
'minimax_api_key': 'valid_key'
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_provider_credentials_valid_or_raise_valid(mocker):
mocker.patch('langchain.llms.minimax.Minimax._call', return_value='abc')
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
def test_is_provider_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if api_key is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
credential = VALIDATE_CREDENTIAL.copy()
credential['minimax_api_key'] = 'invalid_key'
# raise CredentialsValidateFailedError if api_key is invalid
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
api_key = 'valid_key'
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
mock_encrypt.assert_called_with('tenant_id', api_key)
assert result['minimax_api_key'] == f'encrypted_{api_key}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['minimax_api_key'] = 'encrypted_' + encrypted_credential['minimax_api_key']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials()
assert result['minimax_api_key'] == 'valid_key'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['minimax_api_key'] = 'encrypted_' + encrypted_credential['minimax_api_key']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials(obfuscated=True)
middle_token = result['minimax_api_key'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['minimax_api_key']) - 8, 0)
assert all(char == '*' for char in middle_token)

View File

@@ -0,0 +1,126 @@
import pytest
from unittest.mock import patch, MagicMock
import json
from openai.error import AuthenticationError
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.openai_provider import OpenAIProvider
from models.provider import ProviderType, Provider
PROVIDER_NAME = 'openai'
MODEL_PROVIDER_CLASS = OpenAIProvider
VALIDATE_CREDENTIAL_KEY = 'openai_api_key'
def moderation_side_effect(*args, **kwargs):
if kwargs['api_key'] == 'valid_key':
mock_instance = MagicMock()
mock_instance.request = MagicMock()
return mock_instance, {}
else:
raise AuthenticationError('Invalid credentials')
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
@patch('openai.ChatCompletion.create', side_effect=moderation_side_effect)
def test_is_provider_credentials_valid_or_raise_valid(mock_create):
# assert True if api_key is valid
credentials = {VALIDATE_CREDENTIAL_KEY: 'valid_key'}
assert MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credentials) is None
@patch('openai.ChatCompletion.create', side_effect=moderation_side_effect)
def test_is_provider_credentials_valid_or_raise_invalid(mock_create):
# raise CredentialsValidateFailedError if api_key is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
# raise CredentialsValidateFailedError if api_key is invalid
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({VALIDATE_CREDENTIAL_KEY: 'invalid_key'})
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
api_key = 'valid_key'
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', {VALIDATE_CREDENTIAL_KEY: api_key})
mock_encrypt.assert_called_with('tenant_id', api_key)
assert result[VALIDATE_CREDENTIAL_KEY] == f'encrypted_{api_key}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({VALIDATE_CREDENTIAL_KEY: 'encrypted_valid_key'}),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials()
assert result[VALIDATE_CREDENTIAL_KEY] == 'valid_key'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom_str(mock_decrypt):
"""
Only the OpenAI provider needs to be compatible with the previous case where the encrypted_config was stored as a plain string.
:param mock_decrypt:
:return:
"""
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config='encrypted_valid_key',
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials()
assert result[VALIDATE_CREDENTIAL_KEY] == 'valid_key'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
openai_api_key = 'valid_key'
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({VALIDATE_CREDENTIAL_KEY: f'encrypted_{openai_api_key}'}),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials(obfuscated=True)
middle_token = result[VALIDATE_CREDENTIAL_KEY][6:-2]
assert len(middle_token) == max(len(openai_api_key) - 8, 0)
assert all(char == '*' for char in middle_token)
@patch('core.model_providers.providers.hosted.hosted_model_providers.openai')
def test_get_credentials_hosted(mock_hosted):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.SYSTEM.value,
encrypted_config='',
is_valid=True
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
mock_hosted.api_key = 'hosted_key'
result = model_provider.get_provider_credentials()
assert result[VALIDATE_CREDENTIAL_KEY] == 'hosted_key'

View File

@@ -0,0 +1,125 @@
import pytest
from unittest.mock import patch, MagicMock
import json
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.replicate_provider import ReplicateProvider
from models.provider import ProviderType, Provider, ProviderModel
PROVIDER_NAME = 'replicate'
MODEL_PROVIDER_CLASS = ReplicateProvider
VALIDATE_CREDENTIAL = {
'model_version': 'fake-version',
'replicate_api_token': 'valid_key'
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_credentials_valid_or_raise_valid(mocker):
mock_query = MagicMock()
mock_query.return_value = None
mocker.patch('replicate.model.ModelCollection.get', return_value=mock_query)
mocker.patch('replicate.model.Model.versions', return_value=mock_query)
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials=VALIDATE_CREDENTIAL.copy()
)
def test_is_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if replicate_api_token is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials={}
)
# raise CredentialsValidateFailedError if replicate_api_token is invalid
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials={'replicate_api_token': 'invalid_key'})
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_model_credentials(mock_encrypt):
api_key = 'valid_key'
result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
tenant_id='tenant_id',
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials=VALIDATE_CREDENTIAL.copy()
)
mock_encrypt.assert_called_with('tenant_id', api_key)
assert result['replicate_api_token'] == f'encrypted_{api_key}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_custom(mock_decrypt, mocker):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=None,
is_valid=True,
)
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['replicate_api_token'] = 'encrypted_' + encrypted_credential['replicate_api_token']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
encrypted_config=json.dumps(encrypted_credential)
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_model_credentials(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION
)
assert result['replicate_api_token'] == 'valid_key'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=None,
is_valid=True,
)
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['replicate_api_token'] = 'encrypted_' + encrypted_credential['replicate_api_token']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
encrypted_config=json.dumps(encrypted_credential)
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_model_credentials(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
obfuscated=True
)
middle_token = result['replicate_api_token'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['replicate_api_token']) - 8, 0)
assert all(char == '*' for char in middle_token)

View File

@@ -0,0 +1,97 @@
import pytest
from unittest.mock import patch
import json
from langchain.schema import LLMResult, Generation, AIMessage, ChatResult, ChatGeneration
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.spark_provider import SparkProvider
from models.provider import ProviderType, Provider
PROVIDER_NAME = 'spark'
MODEL_PROVIDER_CLASS = SparkProvider
VALIDATE_CREDENTIAL = {
'app_id': 'valid_app_id',
'api_key': 'valid_key',
'api_secret': 'valid_secret'
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_provider_credentials_valid_or_raise_valid(mocker):
mocker.patch('core.third_party.langchain.llms.spark.ChatSpark._generate',
return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content="abc"))]))
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
def test_is_provider_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if api_key is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
credential = VALIDATE_CREDENTIAL.copy()
credential['api_key'] = 'invalid_key'
# raise CredentialsValidateFailedError if api_key is invalid
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'
assert result['api_secret'] == f'encrypted_{VALIDATE_CREDENTIAL["api_secret"]}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
encrypted_credential['api_secret'] = 'encrypted_' + encrypted_credential['api_secret']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials()
assert result['api_key'] == 'valid_key'
assert result['api_secret'] == 'valid_secret'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
encrypted_credential['api_secret'] = 'encrypted_' + encrypted_credential['api_secret']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials(obfuscated=True)
middle_token = result['api_key'][6:-2]
middle_secret = result['api_secret'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
assert len(middle_secret) == max(len(VALIDATE_CREDENTIAL['api_secret']) - 8, 0)
assert all(char == '*' for char in middle_token)
assert all(char == '*' for char in middle_secret)

View File

@@ -0,0 +1,90 @@
import pytest
from unittest.mock import patch
import json
from langchain.schema import LLMResult, Generation
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.minimax_provider import MinimaxProvider
from core.model_providers.providers.tongyi_provider import TongyiProvider
from models.provider import ProviderType, Provider
PROVIDER_NAME = 'tongyi'
MODEL_PROVIDER_CLASS = TongyiProvider
VALIDATE_CREDENTIAL = {
'dashscope_api_key': 'valid_key'
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_provider_credentials_valid_or_raise_valid(mocker):
mocker.patch('langchain.llms.tongyi.Tongyi._generate', return_value=LLMResult(generations=[[Generation(text="abc")]]))
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
def test_is_provider_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if api_key is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
credential = VALIDATE_CREDENTIAL.copy()
credential['dashscope_api_key'] = 'invalid_key'
# raise CredentialsValidateFailedError if api_key is invalid
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
api_key = 'valid_key'
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
mock_encrypt.assert_called_with('tenant_id', api_key)
assert result['dashscope_api_key'] == f'encrypted_{api_key}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['dashscope_api_key'] = 'encrypted_' + encrypted_credential['dashscope_api_key']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials()
assert result['dashscope_api_key'] == 'valid_key'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['dashscope_api_key'] = 'encrypted_' + encrypted_credential['dashscope_api_key']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials(obfuscated=True)
middle_token = result['dashscope_api_key'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['dashscope_api_key']) - 8, 0)
assert all(char == '*' for char in middle_token)

View File

@@ -0,0 +1,93 @@
import pytest
from unittest.mock import patch
import json
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.wenxin_provider import WenxinProvider
from models.provider import ProviderType, Provider
PROVIDER_NAME = 'wenxin'
MODEL_PROVIDER_CLASS = WenxinProvider
VALIDATE_CREDENTIAL = {
'api_key': 'valid_key',
'secret_key': 'valid_secret'
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_provider_credentials_valid_or_raise_valid(mocker):
mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._call', return_value="abc")
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
def test_is_provider_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if api_key is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
credential = VALIDATE_CREDENTIAL.copy()
credential['api_key'] = 'invalid_key'
# raise CredentialsValidateFailedError if api_key is invalid
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'
assert result['secret_key'] == f'encrypted_{VALIDATE_CREDENTIAL["secret_key"]}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials()
assert result['api_key'] == 'valid_key'
assert result['secret_key'] == 'valid_secret'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials(obfuscated=True)
middle_token = result['api_key'][6:-2]
middle_secret = result['secret_key'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
assert len(middle_secret) == max(len(VALIDATE_CREDENTIAL['secret_key']) - 8, 0)
assert all(char == '*' for char in middle_token)
assert all(char == '*' for char in middle_secret)