mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 03:16:51 +08:00
feat: enhance gemini models (#11497)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import google.generativeai.types.generation_types as generation_config_types
|
||||
import pytest
|
||||
@@ -6,11 +7,10 @@ from _pytest.monkeypatch import MonkeyPatch
|
||||
from google.ai import generativelanguage as glm
|
||||
from google.ai.generativelanguage_v1beta.types import content as gag_content
|
||||
from google.generativeai import GenerativeModel
|
||||
from google.generativeai.client import _ClientManager, configure
|
||||
from google.generativeai.types import GenerateContentResponse, content_types, safety_types
|
||||
from google.generativeai.types.generation_types import BaseGenerateContentResponse
|
||||
|
||||
current_api_key = ""
|
||||
from extensions import ext_redis
|
||||
|
||||
|
||||
class MockGoogleResponseClass:
|
||||
@@ -57,11 +57,6 @@ class MockGoogleClass:
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> GenerateContentResponse:
|
||||
global current_api_key
|
||||
|
||||
if len(current_api_key) < 16:
|
||||
raise Exception("Invalid API key")
|
||||
|
||||
if stream:
|
||||
return MockGoogleClass.generate_content_stream()
|
||||
|
||||
@@ -75,33 +70,29 @@ class MockGoogleClass:
|
||||
def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
|
||||
return [MockGoogleResponseCandidateClass()]
|
||||
|
||||
def make_client(self: _ClientManager, name: str):
|
||||
global current_api_key
|
||||
|
||||
if name.endswith("_async"):
|
||||
name = name.split("_")[0]
|
||||
cls = getattr(glm, name.title() + "ServiceAsyncClient")
|
||||
else:
|
||||
cls = getattr(glm, name.title() + "ServiceClient")
|
||||
def mock_configure(api_key: str):
|
||||
if len(api_key) < 16:
|
||||
raise Exception("Invalid API key")
|
||||
|
||||
# Attempt to configure using defaults.
|
||||
if not self.client_config:
|
||||
configure()
|
||||
|
||||
client_options = self.client_config.get("client_options", None)
|
||||
if client_options:
|
||||
current_api_key = client_options.api_key
|
||||
class MockFileState:
|
||||
def __init__(self):
|
||||
self.name = "FINISHED"
|
||||
|
||||
def nop(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
original_init = cls.__init__
|
||||
cls.__init__ = nop
|
||||
client: glm.GenerativeServiceClient = cls(**self.client_config)
|
||||
cls.__init__ = original_init
|
||||
class MockGoogleFile:
|
||||
def __init__(self, name: str = "mock_file_name"):
|
||||
self.name = name
|
||||
self.state = MockFileState()
|
||||
|
||||
if not self.default_metadata:
|
||||
return client
|
||||
|
||||
def mock_get_file(name: str) -> MockGoogleFile:
|
||||
return MockGoogleFile(name)
|
||||
|
||||
|
||||
def mock_upload_file(path: str, mime_type: str) -> MockGoogleFile:
|
||||
return MockGoogleFile()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -109,8 +100,17 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch):
|
||||
monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
|
||||
monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates)
|
||||
monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content)
|
||||
monkeypatch.setattr(_ClientManager, "make_client", MockGoogleClass.make_client)
|
||||
monkeypatch.setattr("google.generativeai.configure", mock_configure)
|
||||
monkeypatch.setattr("google.generativeai.get_file", mock_get_file)
|
||||
monkeypatch.setattr("google.generativeai.upload_file", mock_upload_file)
|
||||
|
||||
yield
|
||||
|
||||
monkeypatch.undo()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_mock_redis() -> None:
|
||||
ext_redis.redis_client.get = MagicMock(return_value=None)
|
||||
ext_redis.redis_client.setex = MagicMock(return_value=None)
|
||||
ext_redis.redis_client.exists = MagicMock(return_value=True)
|
||||
|
||||
@@ -13,7 +13,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguageModel
|
||||
from tests.integration_tests.model_runtime.__mock.google import setup_google_mock
|
||||
from tests.integration_tests.model_runtime.__mock.google import setup_google_mock, setup_mock_redis
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
|
||||
@@ -95,7 +95,7 @@ def test_invoke_stream_model(setup_google_mock):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
|
||||
def test_invoke_chat_model_with_vision(setup_google_mock):
|
||||
def test_invoke_chat_model_with_vision(setup_google_mock, setup_mock_redis):
|
||||
model = GoogleLargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
@@ -124,7 +124,7 @@ def test_invoke_chat_model_with_vision(setup_google_mock):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
|
||||
def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock):
|
||||
def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock, setup_mock_redis):
|
||||
model = GoogleLargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
|
||||
Reference in New Issue
Block a user