chore(api/tests): apply ruff reformat #7590 (#7591)

Co-authored-by: -LAN- <laipz8200@outlook.com>
Bowen Liang authored on 2024-08-23 23:52:25 +08:00, committed by GitHub
parent 2da63654e5
commit b035c02f78
155 changed files with 4279 additions and 5925 deletions
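
The change itself is mechanical: ruff's formatter was run over the test tree and the result committed. A minimal sketch of reproducing it locally, assuming ruff is installed (for example via "pip install ruff") and picks up the quote-style and line-length settings from the repository's pyproject.toml:

# Sketch only: assumes ruff is on PATH and this runs from the repository root.
import subprocess

# Rewrite every file under api/tests in place with ruff's formatter.
subprocess.run(["ruff", "format", "api/tests"], check=True)

# Verify the tree is clean; exits non-zero if any file would still be reformatted.
subprocess.run(["ruff", "format", "--check", "api/tests"], check=True)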

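Every hunk below applies the same three rules: single-quoted strings become double-quoted, multi-line argument lists gain trailing commas, and calls short enough to fit the line-length limit collapse onto one line. A before/after pair mirroring the first hunk:

# Before: hand-formatted, single quotes, no trailing commas.
model.validate_credentials(
    model='aaaaa',
    credentials={
        'server_url': '',
        'model_uid': ''
    }
)

# After ruff format: double quotes, short call collapsed onto one line.
model.validate_credentials(model="aaaaa", credentials={"server_url": "", "model_uid": ""})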

@@ -20,92 +20,84 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_moc
 from tests.integration_tests.model_runtime.__mock.xinference import setup_xinference_mock
 
 
-@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True)
+@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True)
 def test_validate_credentials_for_chat_model(setup_openai_mock, setup_xinference_mock):
     model = XinferenceAILargeLanguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='ChatGLM3',
+            model="ChatGLM3",
             credentials={
-                'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-                'model_uid': 'www ' + os.environ.get('XINFERENCE_CHAT_MODEL_UID')
-            }
+                "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+                "model_uid": "www " + os.environ.get("XINFERENCE_CHAT_MODEL_UID"),
+            },
         )
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='aaaaa',
-            credentials={
-                'server_url': '',
-                'model_uid': ''
-            }
-        )
+        model.validate_credentials(model="aaaaa", credentials={"server_url": "", "model_uid": ""})
 
     model.validate_credentials(
-        model='ChatGLM3',
+        model="ChatGLM3",
         credentials={
-            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-            'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
-        }
+            "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+            "model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"),
+        },
     )
 
 
-@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True)
+@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True)
 def test_invoke_chat_model(setup_openai_mock, setup_xinference_mock):
     model = XinferenceAILargeLanguageModel()
 
     response = model.invoke(
-        model='ChatGLM3',
+        model="ChatGLM3",
         credentials={
-            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-            'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
+            "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+            "model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"),
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
+            "temperature": 0.7,
+            "top_p": 1.0,
         },
-        stop=['you'],
+        stop=["you"],
         user="abc-123",
-        stream=False
+        stream=False,
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
 
-@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True)
+@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True)
 def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock):
     model = XinferenceAILargeLanguageModel()
 
     response = model.invoke(
-        model='ChatGLM3',
+        model="ChatGLM3",
         credentials={
-            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-            'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
+            "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+            "model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"),
         },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
+            "temperature": 0.7,
+            "top_p": 1.0,
         },
-        stop=['you'],
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
    )
 
     assert isinstance(response, Generator)
@@ -114,6 +106,8 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock):
         assert isinstance(chunk.delta, LLMResultChunkDelta)
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
+
+"""
+Function calling of xinference does not support stream mode currently
+"""
@@ -168,7 +162,7 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock):
 #     )
 #     assert isinstance(response, Generator)
 
-#     call: LLMResultChunk = None
+#     chunks = []
@@ -241,86 +235,75 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock):
 #     assert response.usage.total_tokens > 0
 #     assert response.message.tool_calls[0].function.name == 'get_current_weather'
 
 
-@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True)
+@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True)
 def test_validate_credentials_for_generation_model(setup_openai_mock, setup_xinference_mock):
     model = XinferenceAILargeLanguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='alapaca',
+            model="alapaca",
             credentials={
-                'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-                'model_uid': 'www ' + os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
-            }
+                "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+                "model_uid": "www " + os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
+            },
         )
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='alapaca',
-            credentials={
-                'server_url': '',
-                'model_uid': ''
-            }
-        )
+        model.validate_credentials(model="alapaca", credentials={"server_url": "", "model_uid": ""})
 
     model.validate_credentials(
-        model='alapaca',
+        model="alapaca",
         credentials={
-            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-            'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
-        }
+            "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+            "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
+        },
     )
 
 
-@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True)
+@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True)
 def test_invoke_generation_model(setup_openai_mock, setup_xinference_mock):
     model = XinferenceAILargeLanguageModel()
 
     response = model.invoke(
-        model='alapaca',
+        model="alapaca",
         credentials={
-            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-            'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
+            "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+            "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='the United States is'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="the United States is")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
+            "temperature": 0.7,
+            "top_p": 1.0,
         },
-        stop=['you'],
+        stop=["you"],
         user="abc-123",
-        stream=False
+        stream=False,
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
 
-@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True)
+@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True)
 def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock):
     model = XinferenceAILargeLanguageModel()
 
     response = model.invoke(
-        model='alapaca',
+        model="alapaca",
         credentials={
-            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-            'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
+            "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+            "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='the United States is'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="the United States is")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
+            "temperature": 0.7,
+            "top_p": 1.0,
         },
-        stop=['you'],
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -330,68 +313,54 @@ def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
 
 def test_get_num_tokens():
     model = XinferenceAILargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='ChatGLM3',
+        model="ChatGLM3",
         credentials={
-            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-            'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
+            "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+            "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
        },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
         tools=[
             PromptMessageTool(
-                name='get_current_weather',
-                description='Get the current weather in a given location',
+                name="get_current_weather",
+                description="Get the current weather in a given location",
                 parameters={
                     "type": "object",
                     "properties": {
-                        "location": {
-                            "type": "string",
-                            "description": "The city and state e.g. San Francisco, CA"
-                        },
-                        "unit": {
-                            "type": "string",
-                            "enum": [
-                                "c",
-                                "f"
-                            ]
-                        }
+                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
+                        "unit": {"type": "string", "enum": ["c", "f"]},
                     },
-                    "required": [
-                        "location"
-                    ]
-                }
+                    "required": ["location"],
+                },
             )
-        ]
+        ],
     )
 
     assert isinstance(num_tokens, int)
     assert num_tokens == 77
 
     num_tokens = model.get_num_tokens(
-        model='ChatGLM3',
+        model="ChatGLM3",
         credentials={
-            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
-            'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
+            "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
+            "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
        },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
     )
 
     assert isinstance(num_tokens, int)
     assert num_tokens == 21
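
A side note on the test pattern above: the parametrize decorators pass indirect=True, which routes each parametrized value into the same-named fixture through request.param rather than handing it directly to the test function. A self-contained sketch with a hypothetical fixture name (not code from this repository):

import pytest

@pytest.fixture
def setup_mode(request):
    # With indirect=True, the parametrized value arrives here as request.param.
    return request.param

@pytest.mark.parametrize("setup_mode", ["chat"], indirect=True)
def test_mode_is_routed_through_fixture(setup_mode):
    assert setup_mode == "chat"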