feat/enhance the multi-modal support (#8818)
@@ -1,12 +1,15 @@
-from typing import Optional, Union
+from collections.abc import Sequence
+from typing import Optional
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.file.file_obj import FileVar
+from core.file import file_manager
+from core.file.models import File
 from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
 from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_runtime.entities.message_entities import (
+from core.model_runtime.entities import (
     AssistantPromptMessage,
     PromptMessage,
+    PromptMessageContent,
     PromptMessageRole,
     SystemPromptMessage,
     TextPromptMessageContent,
@@ -14,8 +17,8 @@ from core.model_runtime.entities.message_entities import (
 )
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.prompt_transform import PromptTransform
 from core.prompt.simple_prompt_transform import ModelMode
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
+from core.workflow.entities.variable_pool import VariablePool
 
 
 class AdvancedPromptTransform(PromptTransform):
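The import hunks above (apparently core/prompt/advanced_prompt_transform.py) swap the old image-only FileVar for the generic File model plus a file_manager helper, and pull in the PromptMessageContent base type so mixed text/file message parts can share one list type. A minimal sketch of the content hierarchy these imports imply; the class names appear in the diff, but the exact field sets and the pydantic base are assumptions for illustration:

    from enum import Enum

    from pydantic import BaseModel


    class PromptMessageContentType(Enum):
        TEXT = "text"
        IMAGE = "image"
        AUDIO = "audio"


    class PromptMessageContent(BaseModel):
        # common base for one part of a (possibly multi-modal) prompt message
        type: PromptMessageContentType


    class TextPromptMessageContent(PromptMessageContent):
        type: PromptMessageContentType = PromptMessageContentType.TEXT
        data: str  # the rendered prompt text


    class ImagePromptMessageContent(PromptMessageContent):
        class DETAIL(str, Enum):
            LOW = "low"
            HIGH = "high"

        type: PromptMessageContentType = PromptMessageContentType.IMAGE
        data: str                    # URL or base64 payload
        detail: DETAIL = DETAIL.LOW  # the diff reads content.detail.value


    class AudioPromptMessageContent(PromptMessageContent):
        type: PromptMessageContentType = PromptMessageContentType.AUDIO
        data: str    # base64-encoded audio; the diff truncates this for logging
        format: str  # e.g. "mp3" or "wav"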
@@ -28,22 +31,19 @@ class AdvancedPromptTransform(PromptTransform):
     def get_prompt(
         self,
-        prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate],
-        inputs: dict,
+        *,
+        prompt_template: Sequence[ChatModelMessage] | CompletionModelPromptTemplate,
+        inputs: dict[str, str],
         query: str,
-        files: list[FileVar],
+        files: Sequence[File],
         context: Optional[str],
         memory_config: Optional[MemoryConfig],
         memory: Optional[TokenBufferMemory],
         model_config: ModelConfigWithCredentialsEntity,
-        query_prompt_template: Optional[str] = None,
     ) -> list[PromptMessage]:
-        inputs = {key: str(value) for key, value in inputs.items()}
-
         prompt_messages = []
 
-        model_mode = ModelMode.value_of(model_config.mode)
-        if model_mode == ModelMode.COMPLETION:
+        if isinstance(prompt_template, CompletionModelPromptTemplate):
             prompt_messages = self._get_completion_model_prompt_messages(
                 prompt_template=prompt_template,
                 inputs=inputs,
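get_prompt now takes keyword-only arguments (the bare `*`), accepts Sequence rather than concrete list types, declares inputs as dict[str, str] instead of coercing every value with str() at runtime, and dispatches on the type of prompt_template instead of on a separately supplied model mode. A hypothetical call under the new signature (all argument values invented for illustration):

    prompt_messages = transform.get_prompt(
        prompt_template=chat_messages,   # Sequence[ChatModelMessage]
        inputs={"name": "Alice"},        # values must already be strings
        query="What is in this image?",
        files=uploaded_files,            # Sequence[File]
        context=None,
        memory_config=None,
        memory=None,
        model_config=model_config,
    )
    # A positional call such as transform.get_prompt(chat_messages, {...}, ...)
    # now raises TypeError, so argument order can no longer drift silently.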
@@ -54,12 +54,11 @@ class AdvancedPromptTransform(PromptTransform):
                 memory=memory,
                 model_config=model_config,
             )
-        elif model_mode == ModelMode.CHAT:
+        elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template):
             prompt_messages = self._get_chat_model_prompt_messages(
                 prompt_template=prompt_template,
                 inputs=inputs,
                 query=query,
-                query_prompt_template=query_prompt_template,
                 files=files,
                 context=context,
                 memory_config=memory_config,
@@ -74,7 +73,7 @@ class AdvancedPromptTransform(PromptTransform):
         prompt_template: CompletionModelPromptTemplate,
         inputs: dict,
         query: Optional[str],
-        files: list[FileVar],
+        files: Sequence[File],
         context: Optional[str],
         memory_config: Optional[MemoryConfig],
         memory: Optional[TokenBufferMemory],
@@ -88,10 +87,10 @@ class AdvancedPromptTransform(PromptTransform):
         prompt_messages = []
 
         if prompt_template.edition_type == "basic" or not prompt_template.edition_type:
-            prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
-            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+            parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+            prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
 
-            prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
+            prompt_inputs = self._set_context_variable(context, parser, prompt_inputs)
 
             if memory and memory_config:
                 role_prefix = memory_config.role_prefix
@@ -100,15 +99,15 @@ class AdvancedPromptTransform(PromptTransform):
                     memory_config=memory_config,
                     raw_prompt=raw_prompt,
                     role_prefix=role_prefix,
-                    prompt_template=prompt_template,
+                    parser=parser,
                     prompt_inputs=prompt_inputs,
                     model_config=model_config,
                 )
 
             if query:
-                prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
+                prompt_inputs = self._set_query_variable(query, parser, prompt_inputs)
 
-            prompt = prompt_template.format(prompt_inputs)
+            prompt = parser.format(prompt_inputs)
         else:
             prompt = raw_prompt
             prompt_inputs = inputs
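Most of the churn in the completion path is one fix: the old code rebound the name prompt_template — the CompletionModelPromptTemplate parameter — to a PromptTemplateParser, shadowing the original object for the rest of the method. Giving the parser its own name keeps both reachable; in the chat path further down, the same rename also removes the need for the raw_prompt_list alias, which only existed so the loop's source was not rebound mid-iteration. A contrived, runnable reduction of the bug pattern (all names hypothetical):

    class Template:
        edition_type = "basic"
        text = "Hello, {name}!"


    class Parser:
        def __init__(self, template_text: str) -> None:
            self._text = template_text

        def format(self, inputs: dict) -> str:
            return self._text.format(**inputs)


    def render_shadowed(template: Template) -> str:
        template = Parser(template.text)  # parameter rebound, as in the old code
        # From here on, template.edition_type or template.text would raise
        # AttributeError: the original Template is no longer reachable.
        return template.format({"name": "Alice"})


    def render_fixed(template: Template) -> str:
        parser = Parser(template.text)    # distinct name: the template survives
        return parser.format({"name": "Alice"})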
@@ -116,9 +115,10 @@ class AdvancedPromptTransform(PromptTransform):
             prompt = Jinja2Formatter.format(prompt, prompt_inputs)
 
         if files:
-            prompt_message_contents = [TextPromptMessageContent(data=prompt)]
+            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents.append(TextPromptMessageContent(data=prompt))
             for file in files:
-                prompt_message_contents.append(file.prompt_message_content)
+                prompt_message_contents.append(file_manager.to_prompt_message_content(file))
 
             prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
         else:
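With files attached, the rendered prompt text and every file now end up as parts of a single multi-part UserPromptMessage, with file_manager.to_prompt_message_content converting each File into the matching content entity. A sketch of the resulting shape, reusing the sketch classes above plus UserPromptMessage from the diff; the helper stub and the FakeFile fields are hypothetical stand-ins:

    from dataclasses import dataclass


    @dataclass
    class FakeFile:  # hypothetical stand-in for core.file.models.File
        type: str
        remote_url: str = ""
        base64_data: str = ""
        extension: str = ""


    def to_prompt_message_content_stub(file: FakeFile) -> PromptMessageContent:
        # hypothetical stand-in for file_manager.to_prompt_message_content; the
        # real helper presumably dispatches on the file type and loads/encodes
        # the payload as needed
        if file.type == "image":
            return ImagePromptMessageContent(data=file.remote_url)
        if file.type == "audio":
            return AudioPromptMessageContent(data=file.base64_data, format=file.extension)
        raise ValueError(f"unsupported file type: {file.type}")


    files = [FakeFile(type="image", remote_url="https://example.com/cat.png")]
    prompt_message_contents: list[PromptMessageContent] = [
        TextPromptMessageContent(data="Describe the attached files."),
    ]
    prompt_message_contents.extend(to_prompt_message_content_stub(f) for f in files)
    message = UserPromptMessage(content=prompt_message_contents)
    # message.content is now [text, image, ...] instead of a bare string -- the
    # shape multi-modal chat APIs expect.

Splitting the one-element list literal into an annotated empty list plus an append also stops type checkers from inferring the narrower list[TextPromptMessageContent], which would reject the file parts.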
@@ -131,35 +131,38 @@ class AdvancedPromptTransform(PromptTransform):
         prompt_template: list[ChatModelMessage],
         inputs: dict,
         query: Optional[str],
-        files: list[FileVar],
+        files: Sequence[File],
         context: Optional[str],
         memory_config: Optional[MemoryConfig],
         memory: Optional[TokenBufferMemory],
         model_config: ModelConfigWithCredentialsEntity,
-        query_prompt_template: Optional[str] = None,
     ) -> list[PromptMessage]:
         """
         Get chat model prompt messages.
         """
-        raw_prompt_list = prompt_template
-
         prompt_messages = []
 
-        for prompt_item in raw_prompt_list:
+        for prompt_item in prompt_template:
             raw_prompt = prompt_item.text
 
             if prompt_item.edition_type == "basic" or not prompt_item.edition_type:
-                prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
-                prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
-
-                prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
-
-                prompt = prompt_template.format(prompt_inputs)
+                if self.with_variable_tmpl:
+                    vp = VariablePool()
+                    for k, v in inputs.items():
+                        if k.startswith("#"):
+                            vp.add(k[1:-1].split("."), v)
+                    raw_prompt = raw_prompt.replace("{{#context#}}", context or "")
+                    prompt = vp.convert_template(raw_prompt).text
+                else:
+                    parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+                    prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
+                    prompt_inputs = self._set_context_variable(
+                        context=context, parser=parser, prompt_inputs=prompt_inputs
+                    )
+                    prompt = parser.format(prompt_inputs)
             elif prompt_item.edition_type == "jinja2":
                 prompt = raw_prompt
                 prompt_inputs = inputs
-
-                prompt = Jinja2Formatter.format(prompt, prompt_inputs)
+                prompt = Jinja2Formatter.format(template=prompt, inputs=prompt_inputs)
+            else:
+                raise ValueError(f"Invalid edition type: {prompt_item.edition_type}")
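The with_variable_tmpl branch is the new workflow path: inputs arrive under selector keys such as "#sys.query#", so each key is stripped of its surrounding # and split on dots into a selector before being registered in a VariablePool, which then resolves {{#...#}} placeholders in the raw prompt. The calls below are taken from the diff itself; what convert_template returns beyond a .text attribute, and the exact output, are assumptions:

    from core.workflow.entities.variable_pool import VariablePool

    inputs = {"#sys.query#": "What is the weather?"}
    context = None

    vp = VariablePool()
    for k, v in inputs.items():
        if k.startswith("#"):
            vp.add(k[1:-1].split("."), v)  # "#sys.query#" -> selector ["sys", "query"]

    raw_prompt = "Use {{#context#}} to answer: {{#sys.query#}}"
    raw_prompt = raw_prompt.replace("{{#context#}}", context or "")  # context inlined first
    prompt = vp.convert_template(raw_prompt).text
    # presumably: "Use  to answer: What is the weather?"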
@@ -170,25 +173,25 @@ class AdvancedPromptTransform(PromptTransform):
             elif prompt_item.role == PromptMessageRole.ASSISTANT:
                 prompt_messages.append(AssistantPromptMessage(content=prompt))
 
-        if query and query_prompt_template:
-            prompt_template = PromptTemplateParser(
-                template=query_prompt_template, with_variable_tmpl=self.with_variable_tmpl
+        if query and memory_config and memory_config.query_prompt_template:
+            parser = PromptTemplateParser(
+                template=memory_config.query_prompt_template, with_variable_tmpl=self.with_variable_tmpl
             )
-            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+            prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
             prompt_inputs["#sys.query#"] = query
 
-            prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
+            prompt_inputs = self._set_context_variable(context, parser, prompt_inputs)
 
-            query = prompt_template.format(prompt_inputs)
+            query = parser.format(prompt_inputs)
 
         if memory and memory_config:
             prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
 
-        if files:
-            prompt_message_contents = [TextPromptMessageContent(data=query)]
+        if files and query is not None:
+            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents.append(TextPromptMessageContent(data=query))
             for file in files:
-                prompt_message_contents.append(file.prompt_message_content)
-
+                prompt_message_contents.append(file_manager.to_prompt_message_content(file))
             prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
         else:
             prompt_messages.append(UserPromptMessage(content=query))
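Two behavioural tweaks here: the query template now comes from memory_config.query_prompt_template instead of a separate parameter, and the attachment path is guarded with `query is not None`. The guard matters because the text part is built directly from query; assuming TextPromptMessageContent.data is a validated str field, the old code could fail when files were present without a query:

    TextPromptMessageContent(data="")    # fine: an empty text part
    TextPromptMessageContent(data=None)  # assumed to raise a validation error: data must be str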
@@ -200,19 +203,19 @@ class AdvancedPromptTransform(PromptTransform):
                     # get last user message content and add files
                     prompt_message_contents = [TextPromptMessageContent(data=last_message.content)]
                     for file in files:
-                        prompt_message_contents.append(file.prompt_message_content)
+                        prompt_message_contents.append(file_manager.to_prompt_message_content(file))
 
                     last_message.content = prompt_message_contents
                 else:
                     prompt_message_contents = [TextPromptMessageContent(data="")]  # not for query
                     for file in files:
-                        prompt_message_contents.append(file.prompt_message_content)
+                        prompt_message_contents.append(file_manager.to_prompt_message_content(file))
 
                     prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
             else:
                 prompt_message_contents = [TextPromptMessageContent(data=query)]
                 for file in files:
-                    prompt_message_contents.append(file.prompt_message_content)
+                    prompt_message_contents.append(file_manager.to_prompt_message_content(file))
 
                 prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
         elif query:
@@ -220,8 +223,8 @@ class AdvancedPromptTransform(PromptTransform):
 
         return prompt_messages
 
-    def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict:
-        if "#context#" in prompt_template.variable_keys:
+    def _set_context_variable(self, context: str | None, parser: PromptTemplateParser, prompt_inputs: dict) -> dict:
+        if "#context#" in parser.variable_keys:
             if context:
                 prompt_inputs["#context#"] = context
             else:
@@ -229,8 +232,8 @@ class AdvancedPromptTransform(PromptTransform):
 
         return prompt_inputs
 
-    def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict:
-        if "#query#" in prompt_template.variable_keys:
+    def _set_query_variable(self, query: str, parser: PromptTemplateParser, prompt_inputs: dict) -> dict:
+        if "#query#" in parser.variable_keys:
             if query:
                 prompt_inputs["#query#"] = query
             else:
@@ -244,16 +247,16 @@ class AdvancedPromptTransform(PromptTransform):
         memory_config: MemoryConfig,
         raw_prompt: str,
         role_prefix: MemoryConfig.RolePrefix,
-        prompt_template: PromptTemplateParser,
+        parser: PromptTemplateParser,
         prompt_inputs: dict,
         model_config: ModelConfigWithCredentialsEntity,
     ) -> dict:
-        if "#histories#" in prompt_template.variable_keys:
+        if "#histories#" in parser.variable_keys:
             if memory:
                 inputs = {"#histories#": "", **prompt_inputs}
-                prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
-                prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
-                tmp_human_message = UserPromptMessage(content=prompt_template.format(prompt_inputs))
+                parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+                prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
+                tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs))
 
                 rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
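_set_histories_variable keeps its existing budgeting trick, just through the renamed parser: render the prompt once with #histories# blanked out, wrap it in a throwaway user message, and let _calculate_rest_token work out how much of the context window remains for history. A self-contained sketch of the arithmetic, with the token counter and budget figures as stand-in assumptions:

    def count_tokens(text: str) -> int:
        # crude stand-in for the model's real tokenizer
        return len(text.split())

    context_window = 4096   # model's total token budget (assumed)
    max_new_tokens = 512    # reserved for the completion (assumed)

    prompt_without_history = "System rules...\n\nQuestion: {query}".format(query="hi")
    rest_tokens = context_window - count_tokens(prompt_without_history) - max_new_tokens
    # rest_tokens is the budget handed to the memory object, which serializes at
    # most that many tokens of prior turns into the #histories# slot.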
@@ -5,9 +5,11 @@ from typing import TYPE_CHECKING, Optional
 
 from core.app.app_config.entities import PromptTemplateEntity
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.file import file_manager
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import (
     PromptMessage,
+    PromptMessageContent,
     SystemPromptMessage,
     TextPromptMessageContent,
     UserPromptMessage,
@@ -18,10 +20,10 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from models.model import AppMode
 
 if TYPE_CHECKING:
-    from core.file.file_obj import FileVar
+    from core.file.models import File
 
 
-class ModelMode(enum.Enum):
+class ModelMode(str, enum.Enum):
     COMPLETION = "completion"
     CHAT = "chat"
 
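This second file (the ModelMode hunks match core/prompt/simple_prompt_transform.py) makes ModelMode a str-mixin enum: members now compare equal to their raw string values, which is what lets the PromptMessageUtil hunk further down drop the .value from its comparison. In isolation:

    import enum


    class ModelMode(str, enum.Enum):
        COMPLETION = "completion"
        CHAT = "chat"


    assert ModelMode.CHAT == "chat"             # str mixin: equal to the raw value
    assert ModelMode("chat") is ModelMode.CHAT  # and constructible from it
    assert ModelMode.CHAT.value == "chat"       # .value still works where needed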
@@ -53,7 +55,7 @@ class SimplePromptTransform(PromptTransform):
         prompt_template_entity: PromptTemplateEntity,
         inputs: dict,
         query: str,
-        files: list["FileVar"],
+        files: list["File"],
         context: Optional[str],
         memory: Optional[TokenBufferMemory],
         model_config: ModelConfigWithCredentialsEntity,
@@ -169,7 +171,7 @@ class SimplePromptTransform(PromptTransform):
         inputs: dict,
         query: str,
         context: Optional[str],
-        files: list["FileVar"],
+        files: list["File"],
         memory: Optional[TokenBufferMemory],
         model_config: ModelConfigWithCredentialsEntity,
     ) -> tuple[list[PromptMessage], Optional[list[str]]]:
@@ -214,7 +216,7 @@ class SimplePromptTransform(PromptTransform):
         inputs: dict,
         query: str,
         context: Optional[str],
-        files: list["FileVar"],
+        files: list["File"],
         memory: Optional[TokenBufferMemory],
         model_config: ModelConfigWithCredentialsEntity,
     ) -> tuple[list[PromptMessage], Optional[list[str]]]:
@@ -261,11 +263,12 @@ class SimplePromptTransform(PromptTransform):
 
         return [self.get_last_user_message(prompt, files)], stops
 
-    def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage:
+    def get_last_user_message(self, prompt: str, files: list["File"]) -> UserPromptMessage:
         if files:
-            prompt_message_contents = [TextPromptMessageContent(data=prompt)]
+            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents.append(TextPromptMessageContent(data=prompt))
             for file in files:
-                prompt_message_contents.append(file.prompt_message_content)
+                prompt_message_contents.append(file_manager.to_prompt_message_content(file))
 
             prompt_message = UserPromptMessage(content=prompt_message_contents)
         else:
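Throughout SimplePromptTransform the File type stays behind the TYPE_CHECKING guard with quoted annotations, so the core.file import is seen by type checkers but never executed at runtime — the usual way to break an import cycle. The pattern in isolation:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # evaluated only by static type checkers, never at runtime
        from core.file.models import File


    def get_last_user_message(prompt: str, files: list["File"]) -> None:
        # the quoted "File" defers evaluation, so no runtime import is needed
        ...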
@@ -1,7 +1,9 @@
+from typing import Any
+
 from constants import UUID_NIL
 
 
-def extract_thread_messages(messages: list[dict]) -> list[dict]:
+def extract_thread_messages(messages: list[Any]):
     thread_messages = []
     next_message = None
 
@@ -1,7 +1,8 @@
 from typing import cast
 
-from core.model_runtime.entities.message_entities import (
+from core.model_runtime.entities import (
     AssistantPromptMessage,
+    AudioPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
@@ -21,7 +22,7 @@ class PromptMessageUtil:
         :return:
         """
         prompts = []
-        if model_mode == ModelMode.CHAT.value:
+        if model_mode == ModelMode.CHAT:
             tool_calls = []
             for prompt_message in prompt_messages:
                 if prompt_message.role == PromptMessageRole.USER:
@@ -51,11 +52,9 @@ class PromptMessageUtil:
                 files = []
                 if isinstance(prompt_message.content, list):
                     for content in prompt_message.content:
-                        if content.type == PromptMessageContentType.TEXT:
-                            content = cast(TextPromptMessageContent, content)
+                        if isinstance(content, TextPromptMessageContent):
                             text += content.data
-                        else:
-                            content = cast(ImagePromptMessageContent, content)
+                        elif isinstance(content, ImagePromptMessageContent):
                             files.append(
                                 {
                                     "type": "image",
@@ -63,6 +62,14 @@ class PromptMessageUtil:
                                     "detail": content.detail.value,
                                 }
                             )
+                        elif isinstance(content, AudioPromptMessageContent):
+                            files.append(
+                                {
+                                    "type": "audio",
+                                    "data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
+                                    "format": content.format,
+                                }
+                            )
                 else:
                     text = prompt_message.content
 
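The cast() calls in the last file were assertions without teeth — cast is a no-op at runtime, so a mislabeled content part slipped straight through — whereas isinstance both guards at runtime and narrows the type for checkers. The new audio branch also truncates the (typically base64) payload before it lands in stored prompt records; the slicing itself is plain Python:

    data = "A" * 50_000  # stand-in for a base64-encoded audio payload
    preview = data[:10] + "...[TRUNCATED]..." + data[-10:]
    assert preview == "AAAAAAAAAA...[TRUNCATED]...AAAAAAAAAA"
    assert len(preview) == 37  # keeps records small but still identifiable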