feat: enhance gemini models (#11497)

This commit is contained in:
非法操作
2024-12-17 12:05:13 +08:00
committed by GitHub
parent 56cfdce453
commit 74fdc16bd1
23 changed files with 138 additions and 113 deletions

View File

@@ -1,4 +1,5 @@
from collections.abc import Generator
from unittest.mock import MagicMock
import google.generativeai.types.generation_types as generation_config_types
import pytest
@@ -6,11 +7,10 @@ from _pytest.monkeypatch import MonkeyPatch
from google.ai import generativelanguage as glm
from google.ai.generativelanguage_v1beta.types import content as gag_content
from google.generativeai import GenerativeModel
from google.generativeai.client import _ClientManager, configure
from google.generativeai.types import GenerateContentResponse, content_types, safety_types
from google.generativeai.types.generation_types import BaseGenerateContentResponse
current_api_key = ""
from extensions import ext_redis
class MockGoogleResponseClass:
@@ -57,11 +57,6 @@ class MockGoogleClass:
stream: bool = False,
**kwargs,
) -> GenerateContentResponse:
global current_api_key
if len(current_api_key) < 16:
raise Exception("Invalid API key")
if stream:
return MockGoogleClass.generate_content_stream()
@@ -75,33 +70,29 @@ class MockGoogleClass:
def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
return [MockGoogleResponseCandidateClass()]
def make_client(self: _ClientManager, name: str):
global current_api_key
if name.endswith("_async"):
name = name.split("_")[0]
cls = getattr(glm, name.title() + "ServiceAsyncClient")
else:
cls = getattr(glm, name.title() + "ServiceClient")
def mock_configure(api_key: str):
if len(api_key) < 16:
raise Exception("Invalid API key")
# Attempt to configure using defaults.
if not self.client_config:
configure()
client_options = self.client_config.get("client_options", None)
if client_options:
current_api_key = client_options.api_key
class MockFileState:
def __init__(self):
self.name = "FINISHED"
def nop(self, *args, **kwargs):
pass
original_init = cls.__init__
cls.__init__ = nop
client: glm.GenerativeServiceClient = cls(**self.client_config)
cls.__init__ = original_init
class MockGoogleFile:
def __init__(self, name: str = "mock_file_name"):
self.name = name
self.state = MockFileState()
if not self.default_metadata:
return client
def mock_get_file(name: str) -> MockGoogleFile:
return MockGoogleFile(name)
def mock_upload_file(path: str, mime_type: str) -> MockGoogleFile:
return MockGoogleFile()
@pytest.fixture
@@ -109,8 +100,17 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch):
monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates)
monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content)
monkeypatch.setattr(_ClientManager, "make_client", MockGoogleClass.make_client)
monkeypatch.setattr("google.generativeai.configure", mock_configure)
monkeypatch.setattr("google.generativeai.get_file", mock_get_file)
monkeypatch.setattr("google.generativeai.upload_file", mock_upload_file)
yield
monkeypatch.undo()
@pytest.fixture
def setup_mock_redis() -> None:
ext_redis.redis_client.get = MagicMock(return_value=None)
ext_redis.redis_client.setex = MagicMock(return_value=None)
ext_redis.redis_client.exists = MagicMock(return_value=True)

View File

@@ -13,7 +13,7 @@ from core.model_runtime.entities.message_entities import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguageModel
from tests.integration_tests.model_runtime.__mock.google import setup_google_mock
from tests.integration_tests.model_runtime.__mock.google import setup_google_mock, setup_mock_redis
@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
@@ -95,7 +95,7 @@ def test_invoke_stream_model(setup_google_mock):
@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
def test_invoke_chat_model_with_vision(setup_google_mock):
def test_invoke_chat_model_with_vision(setup_google_mock, setup_mock_redis):
model = GoogleLargeLanguageModel()
result = model.invoke(
@@ -124,7 +124,7 @@ def test_invoke_chat_model_with_vision(setup_google_mock):
@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock):
def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock, setup_mock_redis):
model = GoogleLargeLanguageModel()
result = model.invoke(