mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-15 13:56:53 +08:00
Model Runtime (#1858)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: Garfield Dai <dai.hai@foxmail.com> Co-authored-by: chenhe <guchenhe@gmail.com> Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
127
api/tests/integration_tests/model_runtime/__mock/google.py
Normal file
127
api/tests/integration_tests/model_runtime/__mock/google.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from google.generativeai import GenerativeModel
|
||||
from google.generativeai.types import GenerateContentResponse
|
||||
from google.generativeai.types.generation_types import BaseGenerateContentResponse
|
||||
import google.generativeai.types.generation_types as generation_config_types
|
||||
import google.generativeai.types.content_types as content_types
|
||||
import google.generativeai.types.safety_types as safety_types
|
||||
from google.generativeai.client import _ClientManager, configure
|
||||
|
||||
from google.ai import generativelanguage as glm
|
||||
|
||||
from typing import Generator, List
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
import pytest
|
||||
|
||||
current_api_key = ''
|
||||
|
||||
class MockGoogleResponseClass(object):
|
||||
_done = False
|
||||
|
||||
def __iter__(self):
|
||||
full_response_text = 'it\'s google!'
|
||||
|
||||
for i in range(0, len(full_response_text) + 1, 1):
|
||||
if i == len(full_response_text):
|
||||
self._done = True
|
||||
yield GenerateContentResponse(
|
||||
done=True,
|
||||
iterator=None,
|
||||
result=glm.GenerateContentResponse({
|
||||
|
||||
}),
|
||||
chunks=[]
|
||||
)
|
||||
else:
|
||||
yield GenerateContentResponse(
|
||||
done=False,
|
||||
iterator=None,
|
||||
result=glm.GenerateContentResponse({
|
||||
|
||||
}),
|
||||
chunks=[]
|
||||
)
|
||||
|
||||
class MockGoogleResponseCandidateClass(object):
|
||||
finish_reason = 'stop'
|
||||
|
||||
class MockGoogleClass(object):
|
||||
@staticmethod
|
||||
def generate_content_sync() -> GenerateContentResponse:
|
||||
return GenerateContentResponse(
|
||||
done=True,
|
||||
iterator=None,
|
||||
result=glm.GenerateContentResponse({
|
||||
|
||||
}),
|
||||
chunks=[]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def generate_content_stream() -> Generator[GenerateContentResponse, None, None]:
|
||||
return MockGoogleResponseClass()
|
||||
|
||||
def generate_content(self: GenerativeModel,
|
||||
contents: content_types.ContentsType,
|
||||
*,
|
||||
generation_config: generation_config_types.GenerationConfigType | None = None,
|
||||
safety_settings: safety_types.SafetySettingOptions | None = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> GenerateContentResponse:
|
||||
global current_api_key
|
||||
|
||||
if len(current_api_key) < 16:
|
||||
raise Exception('Invalid API key')
|
||||
|
||||
if stream:
|
||||
return MockGoogleClass.generate_content_stream()
|
||||
|
||||
return MockGoogleClass.generate_content_sync()
|
||||
|
||||
@property
|
||||
def generative_response_text(self) -> str:
|
||||
return 'it\'s google!'
|
||||
|
||||
@property
|
||||
def generative_response_candidates(self) -> List[MockGoogleResponseCandidateClass]:
|
||||
return [MockGoogleResponseCandidateClass()]
|
||||
|
||||
def make_client(self: _ClientManager, name: str):
|
||||
global current_api_key
|
||||
|
||||
if name.endswith("_async"):
|
||||
name = name.split("_")[0]
|
||||
cls = getattr(glm, name.title() + "ServiceAsyncClient")
|
||||
else:
|
||||
cls = getattr(glm, name.title() + "ServiceClient")
|
||||
|
||||
# Attempt to configure using defaults.
|
||||
if not self.client_config:
|
||||
configure()
|
||||
|
||||
client_options = self.client_config.get("client_options", None)
|
||||
if client_options:
|
||||
current_api_key = client_options.api_key
|
||||
|
||||
def nop(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
original_init = cls.__init__
|
||||
cls.__init__ = nop
|
||||
client: glm.GenerativeServiceClient = cls(**self.client_config)
|
||||
cls.__init__ = original_init
|
||||
|
||||
if not self.default_metadata:
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def setup_google_mock(request, monkeypatch: MonkeyPatch):
|
||||
monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
|
||||
monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates)
|
||||
monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content)
|
||||
monkeypatch.setattr(_ClientManager, "make_client", MockGoogleClass.make_client)
|
||||
|
||||
yield
|
||||
|
||||
monkeypatch.undo()
|
||||
Reference in New Issue
Block a user