mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 03:16:51 +08:00
feat: add Anthropic claude-3 models support (#2684)
This commit is contained in:
@@ -1,52 +1,87 @@
|
||||
import os
|
||||
from time import sleep
|
||||
from typing import Any, Generator, List, Literal, Union
|
||||
from typing import Any, Literal, Union, Iterable
|
||||
|
||||
from anthropic.resources import Messages
|
||||
from anthropic.types.message_delta_event import Delta
|
||||
|
||||
import anthropic
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from anthropic import Anthropic
|
||||
from anthropic._types import NOT_GIVEN, Body, Headers, NotGiven, Query
|
||||
from anthropic.resources.completions import Completions
|
||||
from anthropic.types import Completion, completion_create_params
|
||||
from anthropic import Anthropic, Stream
|
||||
from anthropic.types import MessageParam, Message, MessageStreamEvent, \
|
||||
ContentBlock, MessageStartEvent, Usage, TextDelta, MessageDeltaEvent, MessageStopEvent, ContentBlockDeltaEvent, \
|
||||
MessageDeltaUsage
|
||||
|
||||
MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'
|
||||
|
||||
|
||||
class MockAnthropicClass(object):
|
||||
@staticmethod
|
||||
def mocked_anthropic_chat_create_sync(model: str) -> Completion:
|
||||
return Completion(
|
||||
completion='hello, I\'m a chatbot from anthropic',
|
||||
def mocked_anthropic_chat_create_sync(model: str) -> Message:
|
||||
return Message(
|
||||
id='msg-123',
|
||||
type='message',
|
||||
role='assistant',
|
||||
content=[ContentBlock(text='hello, I\'m a chatbot from anthropic', type='text')],
|
||||
model=model,
|
||||
stop_reason='stop_sequence'
|
||||
stop_reason='stop_sequence',
|
||||
usage=Usage(
|
||||
input_tokens=1,
|
||||
output_tokens=1
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def mocked_anthropic_chat_create_stream(model: str) -> Generator[Completion, None, None]:
|
||||
def mocked_anthropic_chat_create_stream(model: str) -> Stream[MessageStreamEvent]:
|
||||
full_response_text = "hello, I'm a chatbot from anthropic"
|
||||
|
||||
for i in range(0, len(full_response_text) + 1):
|
||||
sleep(0.1)
|
||||
if i == len(full_response_text):
|
||||
yield Completion(
|
||||
completion='',
|
||||
model=model,
|
||||
stop_reason='stop_sequence'
|
||||
)
|
||||
else:
|
||||
yield Completion(
|
||||
completion=full_response_text[i],
|
||||
model=model,
|
||||
stop_reason=''
|
||||
yield MessageStartEvent(
|
||||
type='message_start',
|
||||
message=Message(
|
||||
id='msg-123',
|
||||
content=[],
|
||||
role='assistant',
|
||||
model=model,
|
||||
stop_reason=None,
|
||||
type='message',
|
||||
usage=Usage(
|
||||
input_tokens=1,
|
||||
output_tokens=1
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def mocked_anthropic(self: Completions, *,
|
||||
max_tokens_to_sample: int,
|
||||
model: Union[str, Literal["claude-2.1", "claude-instant-1"]],
|
||||
prompt: str,
|
||||
stream: Literal[True],
|
||||
**kwargs: Any
|
||||
) -> Union[Completion, Generator[Completion, None, None]]:
|
||||
index = 0
|
||||
for i in range(0, len(full_response_text)):
|
||||
sleep(0.1)
|
||||
yield ContentBlockDeltaEvent(
|
||||
type='content_block_delta',
|
||||
delta=TextDelta(text=full_response_text[i], type='text_delta'),
|
||||
index=index
|
||||
)
|
||||
|
||||
index += 1
|
||||
|
||||
yield MessageDeltaEvent(
|
||||
type='message_delta',
|
||||
delta=Delta(
|
||||
stop_reason='stop_sequence'
|
||||
),
|
||||
usage=MessageDeltaUsage(
|
||||
output_tokens=1
|
||||
)
|
||||
)
|
||||
|
||||
yield MessageStopEvent(type='message_stop')
|
||||
|
||||
def mocked_anthropic(self: Messages, *,
|
||||
max_tokens: int,
|
||||
messages: Iterable[MessageParam],
|
||||
model: str,
|
||||
stream: Literal[True],
|
||||
**kwargs: Any
|
||||
) -> Union[Message, Stream[MessageStreamEvent]]:
|
||||
if len(self._client.api_key) < 18:
|
||||
raise anthropic.AuthenticationError('Invalid API key')
|
||||
|
||||
@@ -55,12 +90,13 @@ class MockAnthropicClass(object):
|
||||
else:
|
||||
return MockAnthropicClass.mocked_anthropic_chat_create_sync(model=model)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_anthropic_mock(request, monkeypatch: MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(Completions, 'create', MockAnthropicClass.mocked_anthropic)
|
||||
monkeypatch.setattr(Messages, 'create', MockAnthropicClass.mocked_anthropic)
|
||||
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
||||
monkeypatch.undo()
|
||||
|
||||
@@ -15,14 +15,14 @@ def test_validate_credentials(setup_anthropic_mock):
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='claude-instant-1',
|
||||
model='claude-instant-1.2',
|
||||
credentials={
|
||||
'anthropic_api_key': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='claude-instant-1',
|
||||
model='claude-instant-1.2',
|
||||
credentials={
|
||||
'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
|
||||
}
|
||||
@@ -33,7 +33,7 @@ def test_invoke_model(setup_anthropic_mock):
|
||||
model = AnthropicLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='claude-instant-1',
|
||||
model='claude-instant-1.2',
|
||||
credentials={
|
||||
'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY'),
|
||||
'anthropic_api_url': os.environ.get('ANTHROPIC_API_URL')
|
||||
@@ -49,7 +49,7 @@ def test_invoke_model(setup_anthropic_mock):
|
||||
model_parameters={
|
||||
'temperature': 0.0,
|
||||
'top_p': 1.0,
|
||||
'max_tokens_to_sample': 10
|
||||
'max_tokens': 10
|
||||
},
|
||||
stop=['How'],
|
||||
stream=False,
|
||||
@@ -64,7 +64,7 @@ def test_invoke_stream_model(setup_anthropic_mock):
|
||||
model = AnthropicLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='claude-instant-1',
|
||||
model='claude-instant-1.2',
|
||||
credentials={
|
||||
'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
|
||||
},
|
||||
@@ -78,7 +78,7 @@ def test_invoke_stream_model(setup_anthropic_mock):
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.0,
|
||||
'max_tokens_to_sample': 100
|
||||
'max_tokens': 100
|
||||
},
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
@@ -97,7 +97,7 @@ def test_get_num_tokens():
|
||||
model = AnthropicLargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='claude-instant-1',
|
||||
model='claude-instant-1.2',
|
||||
credentials={
|
||||
'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user