feat: support LLM understand video (#9828)

This commit is contained in:
非法操作
2024-11-08 13:22:52 +08:00
committed by GitHub
parent c9f785e00f
commit 033ab5490b
9 changed files with 69 additions and 20 deletions

View File

@@ -12,11 +12,13 @@ from .message_entities import (
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
VideoPromptMessageContent,
)
from .model_entities import ModelPropertyKey
__all__ = [
"ImagePromptMessageContent",
"VideoPromptMessageContent",
"PromptMessage",
"PromptMessageRole",
"LLMUsage",

View File

@@ -56,6 +56,7 @@ class PromptMessageContentType(Enum):
TEXT = "text"
IMAGE = "image"
AUDIO = "audio"
VIDEO = "video"
class PromptMessageContent(BaseModel):
@@ -75,6 +76,12 @@ class TextPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.TEXT
# Prompt content part that carries a video; its `type` tag defaults to
# PromptMessageContentType.VIDEO so consumers can branch on the content kind.
class VideoPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.VIDEO
# Video payload. Described as base64 data, but NOTE(review): provider code may
# also receive a plain URL in this field and branch on a "data:" prefix -
# confirm which forms callers actually produce.
data: str = Field(..., description="Base64 encoded video data")
# Video format identifier; the set of accepted values is not visible here.
format: str = Field(..., description="Video format")
class AudioPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.AUDIO
data: str = Field(..., description="Base64 encoded audio data")

View File

@@ -29,6 +29,7 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
VideoPromptMessageContent,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
@@ -431,6 +432,14 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
sub_message_dict = {"image": image_url}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.VIDEO:
message_content = cast(VideoPromptMessageContent, message_content)
video_url = message_content.data
if message_content.data.startswith("data:"):
raise InvokeError("not support base64, please set MULTIMODAL_SEND_VIDEO_FORMAT to url")
sub_message_dict = {"video": video_url}
sub_messages.append(sub_message_dict)
# re-sort sub_messages so that text entries always come last
sub_messages = sorted(sub_messages, key=lambda x: "text" in x)

View File

@@ -313,21 +313,35 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
return params
def _construct_glm_4v_messages(self, prompt_message: Union[str, list[PromptMessageContent]]) -> list[dict]:
    """Convert prompt content into a list of typed message parts.

    A non-list prompt (normally a plain string) becomes a single text part.
    For a list, each item is mapped by its content type:

    * IMAGE -> ``{"type": "image_url", "image_url": {"url": ...}}``
    * VIDEO -> ``{"type": "video_url", "video_url": {"url": ...}}``
    * anything else -> ``{"type": "text", "text": item.data}``

    Image and video payloads are passed through ``_remove_base64_header`` so
    that a ``data:...;base64,`` prefix is stripped before sending.

    :param prompt_message: plain text, or a list of prompt content items
    :return: list of message-part dicts, in the same order as the input items
    """
    # Anything that is not a list is sent as one text part.
    if not isinstance(prompt_message, list):
        return [{"type": "text", "text": prompt_message}]

    sub_messages = []
    for item in prompt_message:
        if item.type == PromptMessageContentType.IMAGE:
            part_key = "image_url"
        elif item.type == PromptMessageContentType.VIDEO:
            part_key = "video_url"
        else:
            # NOTE(review): every other content type is forwarded as text,
            # including any non-text types - confirm that is intended.
            sub_messages.append({"type": "text", "text": item.data})
            continue

        # Media parts share one shape: the type name doubles as the payload key.
        sub_messages.append(
            {
                "type": part_key,
                part_key: {"url": self._remove_base64_header(item.data)},
            }
        )
    return sub_messages
def _remove_base64_header(self, file_content: str) -> str:
if file_content.startswith("data:"):
data_split = file_content.split(";base64,")
return data_split[1]
def _remove_image_header(self, image: str) -> str:
if image.startswith("data:image"):
return image.split(",")[1]
return image
return file_content
def _handle_generate_response(
self,