feat/enhance the multi-modal support (#8818)

commit e61752bd3a
parent 7a1d6fe509
Author: -LAN-
Date: 2024-10-21 10:43:49 +08:00
Committed by: GitHub
267 changed files with 6263 additions and 3523 deletions


@@ -0,0 +1,17 @@
+from .entities import (
+    LLMNodeChatModelMessage,
+    LLMNodeCompletionModelPromptTemplate,
+    LLMNodeData,
+    ModelConfig,
+    VisionConfig,
+)
+from .node import LLMNode
+
+__all__ = [
+    "LLMNode",
+    "LLMNodeChatModelMessage",
+    "LLMNodeCompletionModelPromptTemplate",
+    "LLMNodeData",
+    "ModelConfig",
+    "VisionConfig",
+]
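
Note: with the new package __init__.py above, downstream code can import the node's public surface from one place instead of reaching into the entities and node submodules. A one-line sketch based on the __all__ list:

# Sketch: consuming the consolidated exports (names taken from __all__ above)
from core.workflow.nodes.llm import LLMNode, LLMNodeData, VisionConfig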


@@ -1,17 +1,15 @@
-from typing import Any, Literal, Optional, Union
+from collections.abc import Sequence
+from typing import Any, Optional

-from pydantic import BaseModel
+from pydantic import BaseModel, Field

+from core.model_runtime.entities import ImagePromptMessageContent
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
-from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.variable_entities import VariableSelector
+from core.workflow.nodes.base import BaseNodeData


 class ModelConfig(BaseModel):
-    """
-    Model Config.
-    """
-
     provider: str
     name: str
     mode: str
@@ -19,62 +17,36 @@ class ModelConfig(BaseModel):
 class ContextConfig(BaseModel):
-    """
-    Context Config.
-    """
-
     enabled: bool
     variable_selector: Optional[list[str]] = None


+class VisionConfigOptions(BaseModel):
+    variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"])
+    detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH
+
+
 class VisionConfig(BaseModel):
-    """
-    Vision Config.
-    """
-
-    class Configs(BaseModel):
-        """
-        Configs.
-        """
-
-        detail: Literal["low", "high"]
-
-    enabled: bool
-    configs: Optional[Configs] = None
+    enabled: bool = False
+    configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions)


 class PromptConfig(BaseModel):
-    """
-    Prompt Config.
-    """
-
     jinja2_variables: Optional[list[VariableSelector]] = None


 class LLMNodeChatModelMessage(ChatModelMessage):
-    """
-    LLM Node Chat Model Message.
-    """
-
     jinja2_text: Optional[str] = None


 class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
-    """
-    LLM Node Chat Model Prompt Template.
-    """
-
     jinja2_text: Optional[str] = None


 class LLMNodeData(BaseNodeData):
-    """
-    LLM Node Data.
-    """
-
     model: ModelConfig
-    prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate]
+    prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
     prompt_config: Optional[PromptConfig] = None
     memory: Optional[MemoryConfig] = None
     context: ContextConfig
-    vision: VisionConfig
+    vision: VisionConfig = Field(default_factory=VisionConfig)
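
Note: the practical effect of the new Field(default_factory=...) defaults is that vision no longer has to be present in a stored node config, and an empty VisionConfig() resolves to disabled vision with the ["sys", "files"] selector and high detail. A minimal self-contained sketch of the same pydantic pattern, with stand-in classes (the real detail field is ImagePromptMessageContent.DETAIL.HIGH):

from collections.abc import Sequence

from pydantic import BaseModel, Field


class VisionConfigOptions(BaseModel):  # stand-in mirroring the diff above
    variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"])
    detail: str = "high"  # stands in for ImagePromptMessageContent.DETAIL.HIGH


class VisionConfig(BaseModel):
    enabled: bool = False
    configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions)


config = VisionConfig()  # no arguments required anymore
assert config.enabled is False
assert list(config.configs.variable_selector) == ["sys", "files"]

Using default_factory rather than a literal default also gives each instance its own selector list, avoiding the shared-mutable-default pitfall.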


@@ -1,39 +1,40 @@
 import json
 from collections.abc import Generator, Mapping, Sequence
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Optional, cast

-from pydantic import BaseModel
-
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.model_entities import ModelStatus
 from core.entities.provider_entities import QuotaUnit
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
-from core.model_runtime.entities.message_entities import (
+from core.model_runtime.entities import (
+    AudioPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
+    TextPromptMessageContent,
 )
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
-from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
+from core.variables import ArrayAnySegment, ArrayFileSegment, FileSegment
+from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
 from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.event import InNodeEvent
-from core.workflow.nodes.base_node import BaseNode
-from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
-from core.workflow.nodes.llm.entities import (
-    LLMNodeChatModelMessage,
-    LLMNodeCompletionModelPromptTemplate,
-    LLMNodeData,
-    ModelConfig,
-)
+from core.workflow.nodes.base import BaseNode
+from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.event import (
+    ModelInvokeCompletedEvent,
+    NodeEvent,
+    RunCompletedEvent,
+    RunRetrieverResourceEvent,
+    RunStreamChunkEvent,
+)
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from extensions.ext_database import db
@@ -41,44 +42,34 @@ from models.model import Conversation
 from models.provider import Provider, ProviderType
 from models.workflow import WorkflowNodeExecutionStatus

+from .entities import (
+    LLMNodeChatModelMessage,
+    LLMNodeCompletionModelPromptTemplate,
+    LLMNodeData,
+    ModelConfig,
+)
+
 if TYPE_CHECKING:
-    from core.file.file_obj import FileVar
+    from core.file.models import File


-class ModelInvokeCompleted(BaseModel):
-    """
-    Model invoke completed
-    """
-
-    text: str
-    usage: LLMUsage
-    finish_reason: Optional[str] = None
-
-
-class LLMNode(BaseNode):
+class LLMNode(BaseNode[LLMNodeData]):
     _node_data_cls = LLMNodeData
     _node_type = NodeType.LLM

-    def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
-        """
-        Run node
-        :return:
-        """
-        node_data = cast(LLMNodeData, deepcopy(self.node_data))
-        variable_pool = self.graph_runtime_state.variable_pool
-
+    def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]:
         node_inputs = None
         process_data = None

         try:
             # init messages template
-            node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template)
+            self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)

             # fetch variables and fetch values from variable pool
-            inputs = self._fetch_inputs(node_data, variable_pool)
+            inputs = self._fetch_inputs(node_data=self.node_data)

             # fetch jinja2 inputs
-            jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool)
+            jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data)

             # merge inputs
             inputs.update(jinja_inputs)
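
Note: LLMNode now subclasses a generic BaseNode[LLMNodeData], so self.node_data is statically typed and _run no longer deep-copies and casts it. A rough sketch of the pattern with stand-in classes (simplified; the real base class lives in core.workflow.nodes.base):

from typing import Any, Generic, TypeVar

from pydantic import BaseModel

NodeDataT = TypeVar("NodeDataT", bound=BaseModel)


class BaseNode(Generic[NodeDataT]):  # simplified stand-in, not the real BaseNode
    # subclasses set this to their data model, mirroring _node_data_cls above
    _node_data_cls: type[NodeDataT]

    def __init__(self, config: dict[str, Any]) -> None:
        # validated once; statically typed as NodeDataT from here on
        self.node_data: NodeDataT = self._node_data_cls.model_validate(config)


class EchoData(BaseModel):  # stand-in for LLMNodeData
    text: str


class EchoNode(BaseNode[EchoData]):
    _node_data_cls = EchoData


node = EchoNode({"text": "hi"})
assert node.node_data.text == "hi"  # no cast() needed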
@@ -86,13 +77,17 @@ class LLMNode(BaseNode):
             node_inputs = {}

             # fetch files
-            files = self._fetch_files(node_data, variable_pool)
+            files = (
+                self._fetch_files(selector=self.node_data.vision.configs.variable_selector)
+                if self.node_data.vision.enabled
+                else []
+            )

             if files:
                 node_inputs["#files#"] = [file.to_dict() for file in files]

             # fetch context value
-            generator = self._fetch_context(node_data, variable_pool)
+            generator = self._fetch_context(node_data=self.node_data)
             context = None
             for event in generator:
                 if isinstance(event, RunRetrieverResourceEvent):
@@ -103,21 +98,30 @@ class LLMNode(BaseNode):
                     node_inputs["#context#"] = context  # type: ignore

             # fetch model config
-            model_instance, model_config = self._fetch_model_config(node_data.model)
+            model_instance, model_config = self._fetch_model_config(self.node_data.model)

             # fetch memory
-            memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
+            memory = self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance)

             # fetch prompt messages
+            if self.node_data.memory:
+                query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
+                if not query:
+                    raise ValueError("Query not found")
+                query = query.text
+            else:
+                query = None
+
             prompt_messages, stop = self._fetch_prompt_messages(
-                node_data=node_data,
-                query=variable_pool.get_any(["sys", SystemVariableKey.QUERY.value]) if node_data.memory else None,
-                query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
+                system_query=query,
                 inputs=inputs,
                 files=files,
                 context=context,
                 memory=memory,
                 model_config=model_config,
+                vision_detail=self.node_data.vision.configs.detail,
+                prompt_template=self.node_data.prompt_template,
+                memory_config=self.node_data.memory,
             )

             process_data = {
@@ -131,7 +135,7 @@ class LLMNode(BaseNode):

             # handle invoke result
             generator = self._invoke_llm(
-                node_data_model=node_data.model,
+                node_data_model=self.node_data.model,
                 model_instance=model_instance,
                 prompt_messages=prompt_messages,
                 stop=stop,
@@ -143,7 +147,7 @@ class LLMNode(BaseNode):
             for event in generator:
                 if isinstance(event, RunStreamChunkEvent):
                     yield event
-                elif isinstance(event, ModelInvokeCompleted):
+                elif isinstance(event, ModelInvokeCompletedEvent):
                     result_text = event.text
                     usage = event.usage
                     finish_reason = event.finish_reason
@@ -182,15 +186,7 @@ class LLMNode(BaseNode):
         model_instance: ModelInstance,
         prompt_messages: list[PromptMessage],
         stop: Optional[list[str]] = None,
-    ) -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
-        """
-        Invoke large language model
-        :param node_data_model: node data model
-        :param model_instance: model instance
-        :param prompt_messages: prompt messages
-        :param stop: stop
-        :return:
-        """
+    ) -> Generator[NodeEvent, None, None]:
         db.session.close()

         invoke_result = model_instance.invoke_llm(
@@ -207,20 +203,13 @@ class LLMNode(BaseNode):
         usage = LLMUsage.empty_usage()

         for event in generator:
             yield event
-            if isinstance(event, ModelInvokeCompleted):
+            if isinstance(event, ModelInvokeCompletedEvent):
                 usage = event.usage

         # deduct quota
         self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)

-    def _handle_invoke_result(
-        self, invoke_result: LLMResult | Generator
-    ) -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
-        """
-        Handle invoke result
-        :param invoke_result: invoke result
-        :return:
-        """
+    def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
         if isinstance(invoke_result, LLMResult):
             return
@@ -250,18 +239,11 @@ class LLMNode(BaseNode):
         if not usage:
             usage = LLMUsage.empty_usage()

-        yield ModelInvokeCompleted(text=full_text, usage=usage, finish_reason=finish_reason)
+        yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason)
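
Note: the renamed ModelInvokeCompletedEvent makes the invoke pipeline a uniform event stream: one RunStreamChunkEvent per delta, then a single terminal event carrying the accumulated text and usage, which _run and _invoke_llm pattern-match on. A toy sketch of that two-phase protocol (hypothetical event classes, not the real ones):

from collections.abc import Generator, Iterable
from dataclasses import dataclass


@dataclass
class StreamChunk:  # plays the role of RunStreamChunkEvent
    text: str


@dataclass
class InvokeCompleted:  # plays the role of ModelInvokeCompletedEvent
    text: str


def handle_stream(chunks: Iterable[str]) -> Generator[StreamChunk | InvokeCompleted, None, None]:
    full_text = ""
    for chunk in chunks:
        full_text += chunk
        yield StreamChunk(text=chunk)  # forwarded as it arrives
    yield InvokeCompleted(text=full_text)  # terminal event with the accumulated result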

     def _transform_chat_messages(
-        self, messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
-    ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
-        """
-        Transform chat messages
-        :param messages: chat messages
-        :return:
-        """
+        self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
+    ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
         if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
             if messages.edition_type == "jinja2" and messages.jinja2_text:
                 messages.text = messages.jinja2_text
@@ -274,13 +256,7 @@ class LLMNode(BaseNode):

         return messages

-    def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
-        """
-        Fetch jinja inputs
-        :param node_data: node data
-        :param variable_pool: variable pool
-        :return:
-        """
+    def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
         variables = {}

         if not node_data.prompt_config:
@@ -288,7 +264,7 @@ class LLMNode(BaseNode):

         for variable_selector in node_data.prompt_config.jinja2_variables or []:
             variable = variable_selector.variable
-            value = variable_pool.get_any(variable_selector.value_selector)
+            value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)

             def parse_dict(d: dict) -> str:
                 """
@@ -330,13 +306,7 @@ class LLMNode(BaseNode):

         return variables

-    def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
-        """
-        Fetch inputs
-        :param node_data: node data
-        :param variable_pool: variable pool
-        :return:
-        """
+    def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
         inputs = {}

         prompt_template = node_data.prompt_template
@@ -350,7 +320,7 @@ class LLMNode(BaseNode):
             variable_selectors = variable_template_parser.extract_variable_selectors()

         for variable_selector in variable_selectors:
-            variable_value = variable_pool.get_any(variable_selector.value_selector)
+            variable_value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
             if variable_value is None:
                 raise ValueError(f"Variable {variable_selector.variable} not found")
@@ -362,7 +332,7 @@ class LLMNode(BaseNode):
                 template=memory.query_prompt_template
             ).extract_variable_selectors()
             for variable_selector in query_variable_selectors:
-                variable_value = variable_pool.get_any(variable_selector.value_selector)
+                variable_value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
                 if variable_value is None:
                     raise ValueError(f"Variable {variable_selector.variable} not found")
@@ -370,36 +340,28 @@ class LLMNode(BaseNode):

         return inputs

-    def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]:
-        """
-        Fetch files
-        :param node_data: node data
-        :param variable_pool: variable pool
-        :return:
-        """
-        if not node_data.vision.enabled:
-            return []
-
-        files = variable_pool.get_any(["sys", SystemVariableKey.FILES.value])
-        if not files:
-            return []
-
-        return files
+    def _fetch_files(self, *, selector: Sequence[str]) -> Sequence["File"]:
+        variable = self.graph_runtime_state.variable_pool.get(selector)
+        if variable is None:
+            return []
+        if isinstance(variable, FileSegment):
+            return [variable.value]
+        if isinstance(variable, ArrayFileSegment):
+            return variable.value
+        # FIXME: Temporary fix for empty array,
+        # all variables added to variable pool should be a Segment instance.
+        if isinstance(variable, ArrayAnySegment) and len(variable.value) == 0:
+            return []
+        raise ValueError(f"Invalid variable type: {type(variable)}")
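
Note: _fetch_files now takes an explicit selector and dispatches on the segment type found in the variable pool, normalizing every case to a flat sequence of files. The same branching, restated as a runnable toy (stand-in segment classes, not the core.variables ones):

from dataclasses import dataclass


@dataclass
class FileSegment:  # stand-ins for the core.variables segment types
    value: str


@dataclass
class ArrayFileSegment:
    value: list


@dataclass
class ArrayAnySegment:
    value: list


def normalize(variable) -> list:
    if variable is None:
        return []
    if isinstance(variable, FileSegment):
        return [variable.value]  # single file -> one-element list
    if isinstance(variable, ArrayFileSegment):
        return variable.value  # already a list of files
    if isinstance(variable, ArrayAnySegment) and len(variable.value) == 0:
        return []  # tolerated empty-array edge case, per the FIXME above
    raise ValueError(f"Invalid variable type: {type(variable)}")


assert normalize(FileSegment(value="a.png")) == ["a.png"]
assert normalize(ArrayAnySegment(value=[])) == []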

-    def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]:
-        """
-        Fetch context
-        :param node_data: node data
-        :param variable_pool: variable pool
-        :return:
-        """
+    def _fetch_context(self, node_data: LLMNodeData):
         if not node_data.context.enabled:
             return

         if not node_data.context.variable_selector:
             return

-        context_value = variable_pool.get_any(node_data.context.variable_selector)
+        context_value = self.graph_runtime_state.variable_pool.get_any(node_data.context.variable_selector)
         if context_value:
             if isinstance(context_value, str):
                 yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value)
@@ -424,11 +386,6 @@ class LLMNode(BaseNode):
             )

     def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
-        """
-        Convert to original retriever resource, temp.
-        :param context_dict: context dict
-        :return:
-        """
         if (
             "metadata" in context_dict
             and "_source" in context_dict["metadata"]
@@ -451,6 +408,7 @@ class LLMNode(BaseNode):
                 "segment_position": metadata.get("segment_position"),
                 "index_node_hash": metadata.get("segment_index_node_hash"),
                 "content": context_dict.get("content"),
+                "page": metadata.get("page"),
             }

             return source
@@ -460,11 +418,6 @@ class LLMNode(BaseNode):
     def _fetch_model_config(
         self, node_data_model: ModelConfig
     ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
-        """
-        Fetch model config
-        :param node_data_model: node data model
-        :return:
-        """
         model_name = node_data_model.name
         provider_name = node_data_model.provider
@@ -523,19 +476,15 @@ class LLMNode(BaseNode):
         )

     def _fetch_memory(
-        self, node_data_memory: Optional[MemoryConfig], variable_pool: VariablePool, model_instance: ModelInstance
+        self, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
     ) -> Optional[TokenBufferMemory]:
-        """
-        Fetch memory
-        :param node_data_memory: node data memory
-        :param variable_pool: variable pool
-        :return:
-        """
         if not node_data_memory:
             return None

         # get conversation id
-        conversation_id = variable_pool.get_any(["sys", SystemVariableKey.CONVERSATION_ID.value])
+        conversation_id = self.graph_runtime_state.variable_pool.get_any(
+            ["sys", SystemVariableKey.CONVERSATION_ID.value]
+        )
         if conversation_id is None:
             return None
@@ -555,43 +504,31 @@ class LLMNode(BaseNode):
     def _fetch_prompt_messages(
         self,
-        node_data: LLMNodeData,
-        query: Optional[str],
-        query_prompt_template: Optional[str],
-        inputs: dict[str, str],
-        files: list["FileVar"],
-        context: Optional[str],
-        memory: Optional[TokenBufferMemory],
+        *,
+        system_query: str | None = None,
+        inputs: dict[str, str] | None = None,
+        files: Sequence["File"],
+        context: str | None = None,
+        memory: TokenBufferMemory | None = None,
         model_config: ModelConfigWithCredentialsEntity,
+        prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
+        memory_config: MemoryConfig | None = None,
+        vision_detail: ImagePromptMessageContent.DETAIL,
     ) -> tuple[list[PromptMessage], Optional[list[str]]]:
-        """
-        Fetch prompt messages
-        :param node_data: node data
-        :param query: query
-        :param query_prompt_template: query prompt template
-        :param inputs: inputs
-        :param files: files
-        :param context: context
-        :param memory: memory
-        :param model_config: model config
-        :return:
-        """
+        inputs = inputs or {}
+
         prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
         prompt_messages = prompt_transform.get_prompt(
-            prompt_template=node_data.prompt_template,
+            prompt_template=prompt_template,
             inputs=inputs,
-            query=query or "",
+            query=system_query or "",
             files=files,
             context=context,
-            memory_config=node_data.memory,
+            memory_config=memory_config,
             memory=memory,
             model_config=model_config,
-            query_prompt_template=query_prompt_template,
         )
         stop = model_config.stop

-        vision_enabled = node_data.vision.enabled
-        vision_detail = node_data.vision.configs.detail if node_data.vision.configs else None
         filtered_prompt_messages = []
         for prompt_message in prompt_messages:
             if prompt_message.is_empty():
@@ -599,17 +536,13 @@ class LLMNode(BaseNode):

             if not isinstance(prompt_message.content, str):
                 prompt_message_content = []
-                for content_item in prompt_message.content:
-                    if (
-                        vision_enabled
-                        and content_item.type == PromptMessageContentType.IMAGE
-                        and isinstance(content_item, ImagePromptMessageContent)
-                    ):
-                        # Override vision config if LLM node has vision config
-                        if vision_detail:
-                            content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail)
+                for content_item in prompt_message.content or []:
+                    if isinstance(content_item, ImagePromptMessageContent):
+                        # Override vision config if LLM node has vision config,
+                        # cuz vision detail is related to the configuration from FileUpload feature.
+                        content_item.detail = vision_detail
                         prompt_message_content.append(content_item)
-                    elif content_item.type == PromptMessageContentType.TEXT:
+                    elif isinstance(content_item, TextPromptMessageContent | AudioPromptMessageContent):
                         prompt_message_content.append(content_item)

                 if len(prompt_message_content) > 1:
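
Note: this filtering loop is the multi-modal core of the commit: image parts always receive the node-level vision_detail, text and audio parts pass through, and any other content type is dropped. A compact sketch of that filter (hypothetical content classes, not the model_runtime ones):

from dataclasses import dataclass


@dataclass
class ImageContent:  # stand-in for ImagePromptMessageContent
    url: str
    detail: str = "low"


@dataclass
class TextContent:  # stand-in for TextPromptMessageContent
    text: str


def filter_contents(items: list, vision_detail: str) -> list:
    kept = []
    for item in items:
        if isinstance(item, ImageContent):
            item.detail = vision_detail  # node config overrides per-item detail
            kept.append(item)
        elif isinstance(item, TextContent):
            kept.append(item)
        # unsupported content types are silently dropped
    return kept


parts = filter_contents([TextContent("hi"), ImageContent("img://x")], vision_detail="high")
assert parts[1].detail == "high"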
@@ -631,13 +564,6 @@ class LLMNode(BaseNode):

     @classmethod
     def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
-        """
-        Deduct LLM quota
-        :param tenant_id: tenant id
-        :param model_instance: model instance
-        :param usage: usage
-        :return:
-        """
         provider_model_bundle = model_instance.provider_model_bundle
         provider_configuration = provider_model_bundle.configuration
@@ -668,7 +594,7 @@ class LLMNode(BaseNode):
             else:
                 used_quota = 1

-            if used_quota is not None:
+            if used_quota is not None and system_configuration.current_quota_type is not None:
                 db.session.query(Provider).filter(
                     Provider.tenant_id == tenant_id,
                     Provider.provider_name == model_instance.provider,
@@ -680,27 +606,28 @@ class LLMNode(BaseNode):
     @classmethod
     def _extract_variable_selector_to_variable_mapping(
-        cls, graph_config: Mapping[str, Any], node_id: str, node_data: LLMNodeData
+        cls,
+        *,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: LLMNodeData,
     ) -> Mapping[str, Sequence[str]]:
-        """
-        Extract variable selector to variable mapping
-        :param graph_config: graph config
-        :param node_id: node id
-        :param node_data: node data
-        :return:
-        """
         prompt_template = node_data.prompt_template

         variable_selectors = []
-        if isinstance(prompt_template, list):
+        if isinstance(prompt_template, list) and all(
+            isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template
+        ):
             for prompt in prompt_template:
                 if prompt.edition_type != "jinja2":
                     variable_template_parser = VariableTemplateParser(template=prompt.text)
                     variable_selectors.extend(variable_template_parser.extract_variable_selectors())
-        else:
+        elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
             if prompt_template.edition_type != "jinja2":
                 variable_template_parser = VariableTemplateParser(template=prompt_template.text)
                 variable_selectors = variable_template_parser.extract_variable_selectors()
+        else:
+            raise ValueError(f"Invalid prompt template type: {type(prompt_template)}")

         variable_mapping = {}
         for variable_selector in variable_selectors:
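
Note: VariableTemplateParser collects the selectors referenced by non-jinja2 templates. Assuming the {{#node_id.variable#}} syntax these workflow prompts use, a simplified extractor might look like this (illustrative regex, not the real parser in core.workflow.utils.variable_template_parser):

import re

# Assumed template syntax: {{#node_id.variable#}}
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)+)#\}\}")


def extract_selectors(template: str) -> list[list[str]]:
    # each match like "sys.query" becomes a selector ["sys", "query"]
    return [match.split(".") for match in VARIABLE_PATTERN.findall(template)]


assert extract_selectors("Hello, {{#sys.query#}}!") == [["sys", "query"]]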
@@ -745,11 +672,6 @@ class LLMNode(BaseNode):
     @classmethod
     def get_default_config(cls, filters: Optional[dict] = None) -> dict:
-        """
-        Get default config of node.
-        :param filters: filter by node config parameters.
-        :return:
-        """
         return {
             "type": "llm",
             "config": {