feat: enhance gemini models (#11497)

This commit is contained in:
非法操作
2024-12-17 12:05:13 +08:00
committed by GitHub
parent 56cfdce453
commit 74fdc16bd1
23 changed files with 138 additions and 113 deletions

View File

@@ -101,13 +101,14 @@ class ImagePromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.IMAGE
detail: DETAIL = DETAIL.LOW
format: str = Field("jpg", description="Image format")
class DocumentPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
encode_format: Literal["base64"]
mime_type: str
data: str
format: str = Field(..., description="Document format")
class PromptMessage(ABC, BaseModel):

View File

@@ -526,17 +526,19 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
}
sub_messages.append(sub_message_dict)
elif isinstance(message_content, DocumentPromptMessageContent):
if message_content.mime_type != "application/pdf":
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
if mime_type != "application/pdf":
raise ValueError(
f"Unsupported document type {message_content.mime_type}, "
"only support application/pdf"
f"Unsupported document type {mime_type}, " "only support application/pdf"
)
sub_message_dict = {
"type": "document",
"source": {
"type": message_content.encode_format,
"media_type": message_content.mime_type,
"data": message_content.data,
"media_type": mime_type,
"data": base64_data,
},
}
sub_messages.append(sub_message_dict)

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767

View File

@@ -7,6 +7,9 @@ features:
- vision
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767

View File

@@ -1,29 +1,30 @@
import base64
import io
import json
import os
import tempfile
import time
from collections.abc import Generator
from typing import Optional, Union, cast
from typing import Optional, Union
import google.ai.generativelanguage as glm
import google.generativeai as genai
import requests
from google.api_core import exceptions
from google.generativeai.client import _ClientManager
from google.generativeai.types import ContentType, GenerateContentResponse
from google.generativeai.types import ContentType, File, GenerateContentResponse
from google.generativeai.types.content_types import to_part
from PIL import Image
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
DocumentPromptMessageContent,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
VideoPromptMessageContent,
)
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
@@ -35,21 +36,7 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
GOOGLE_AVAILABLE_MIMETYPE = [
"application/pdf",
"application/x-javascript",
"text/javascript",
"application/x-python",
"text/x-python",
"text/plain",
"text/html",
"text/css",
"text/md",
"text/csv",
"text/xml",
"text/rtf",
]
from extensions.ext_redis import redis_client
class GoogleLargeLanguageModel(LargeLanguageModel):
@@ -201,29 +188,17 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
if stop:
config_kwargs["stop_sequences"] = stop
genai.configure(api_key=credentials["google_api_key"])
google_model = genai.GenerativeModel(model_name=model)
history = []
# hack for gemini-pro-vision, which currently does not support multi-turn chat
if model == "gemini-pro-vision":
last_msg = prompt_messages[-1]
content = self._format_message_to_glm_content(last_msg)
history.append(content)
else:
for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]:
history[-1]["parts"].extend(content["parts"])
else:
history.append(content)
# Create a new ClientManager with tenant's API key
new_client_manager = _ClientManager()
new_client_manager.configure(api_key=credentials["google_api_key"])
new_custom_client = new_client_manager.make_client("generative")
google_model._client = new_custom_client
for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]:
history[-1]["parts"].extend(content["parts"])
else:
history.append(content)
response = google_model.generate_content(
contents=history,
@@ -346,7 +321,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
content = message.content
if isinstance(content, list):
content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT)
if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"
@@ -359,6 +334,44 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
return message_text
def _upload_file_content_to_google(self, message_content: PromptMessageContent) -> File:
    """
    Upload a prompt content item to Google and return the resulting file handle

    The uploaded file name is cached in redis, keyed by the content type and a
    digest of the content data, so that the same content is not uploaded again
    while Google still keeps the file.

    :param message_content: prompt content whose ``data`` is either a ``data:`` URI
        (base64 payload) or a URL pointing to an image / video
    :return: the uploaded (or previously uploaded and still available) Google file
    :raises ValueError: if the content is a URL of an unsupported content type, or
        the URL cannot be fetched
    """
    # Function-scope import: only this helper needs it.
    import hashlib

    # The built-in hash() of a str is salted per process, so it cannot be used as a
    # cache key shared between workers / restarts. Use a stable digest instead.
    digest = hashlib.sha256(message_content.data.encode("utf-8")).hexdigest()
    key = f"{message_content.type.value}:{digest}"

    cached_name = redis_client.get(key)
    if cached_name is not None:
        try:
            return genai.get_file(cached_name.decode())
        except Exception:
            # Best effort only: the cached file may be gone on Google's side,
            # fall through and upload the content again.
            pass

    # Resolve the raw bytes and mime type before creating the temp file, so a
    # failure here cannot leave a stray file behind.
    if message_content.data.startswith("data:"):
        metadata, base64_data = message_content.data.split(",", 1)
        file_content = base64.b64decode(base64_data)
        mime_type = metadata.split(";", 1)[0].split(":")[1]
    else:
        # only ImagePromptMessageContent and VideoPromptMessageContent has url
        if isinstance(message_content, ImagePromptMessageContent):
            prefix = "image/"
        elif isinstance(message_content, VideoPromptMessageContent):
            prefix = "video/"
        else:
            raise ValueError(
                f"Failed to fetch data from url {message_content.data}, "
                f"unsupported content type {message_content.type.value}"
            )
        mime_type = prefix + message_content.format
        try:
            # (connect timeout, read timeout) so a dead host cannot hang the request forever
            response = requests.get(message_content.data, timeout=(10, 60))
            response.raise_for_status()
            file_content = response.content
        except requests.RequestException as ex:
            raise ValueError(f"Failed to fetch data from url {message_content.data}, {ex}") from ex

    # delete=False because the file is re-opened by path for the upload;
    # it is removed explicitly in the finally clause below.
    temp_file = tempfile.NamedTemporaryFile(delete=False)
    try:
        with temp_file:
            temp_file.write(file_content)
        file = genai.upload_file(path=temp_file.name, mime_type=mime_type)
        while file.state.name == "PROCESSING":
            time.sleep(5)
            file = genai.get_file(file.name)
        # google will delete your upload files in 2 days.
        redis_client.setex(key, 47 * 60 * 60, file.name)
        return file
    finally:
        os.unlink(temp_file.name)
def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
"""
Format a single message into glm.Content for Google API
@@ -374,28 +387,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
for c in message.content:
if c.type == PromptMessageContentType.TEXT:
glm_content["parts"].append(to_part(c.data))
elif c.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, c)
if message_content.data.startswith("data:"):
metadata, base64_data = c.data.split(",", 1)
mime_type = metadata.split(";", 1)[0].split(":")[1]
else:
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
glm_content["parts"].append(blob)
elif c.type == PromptMessageContentType.DOCUMENT:
message_content = cast(DocumentPromptMessageContent, c)
if message_content.mime_type not in GOOGLE_AVAILABLE_MIMETYPE:
raise ValueError(f"Unsupported mime type {message_content.mime_type}")
blob = {"inline_data": {"mime_type": message_content.mime_type, "data": message_content.data}}
glm_content["parts"].append(blob)
else:
glm_content["parts"].append(self._upload_file_content_to_google(c))
return glm_content
elif isinstance(message, AssistantPromptMessage):

View File

@@ -920,10 +920,12 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
}
sub_messages.append(sub_message_dict)
elif isinstance(message_content, AudioPromptMessageContent):
data_split = message_content.data.split(";base64,")
base64_data = data_split[1]
sub_message_dict = {
"type": "input_audio",
"input_audio": {
"data": message_content.data,
"data": base64_data,
"format": message_content.format,
},
}