feat/enhance the multi-modal support (#8818)
@@ -0,0 +1,17 @@
from .entities import (
    LLMNodeChatModelMessage,
    LLMNodeCompletionModelPromptTemplate,
    LLMNodeData,
    ModelConfig,
    VisionConfig,
)
from .node import LLMNode

__all__ = [
    "LLMNode",
    "LLMNodeChatModelMessage",
    "LLMNodeCompletionModelPromptTemplate",
    "LLMNodeData",
    "ModelConfig",
    "VisionConfig",
]
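Illustrative sketch, not part of this commit: with the new package __init__ re-exporting these names, callers can import the LLM node's public API from the package itself. The package path core.workflow.nodes.llm is inferred from the relative imports above and the absolute imports later in the diff.

# Illustrative usage only, assuming the package path core.workflow.nodes.llm:
from core.workflow.nodes.llm import LLMNode, LLMNodeData, VisionConfig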
@@ -1,17 +1,15 @@
from typing import Any, Literal, Optional, Union
from collections.abc import Sequence
from typing import Any, Optional

from pydantic import BaseModel
from pydantic import BaseModel, Field

from core.model_runtime.entities import ImagePromptMessageContent
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.nodes.base import BaseNodeData


class ModelConfig(BaseModel):
    """
    Model Config.
    """

    provider: str
    name: str
    mode: str
@@ -19,62 +17,36 @@ class ModelConfig(BaseModel):


class ContextConfig(BaseModel):
    """
    Context Config.
    """

    enabled: bool
    variable_selector: Optional[list[str]] = None


class VisionConfigOptions(BaseModel):
    variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"])
    detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH


class VisionConfig(BaseModel):
    """
    Vision Config.
    """

    class Configs(BaseModel):
        """
        Configs.
        """

        detail: Literal["low", "high"]

    enabled: bool
    configs: Optional[Configs] = None
    enabled: bool = False
    configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions)


class PromptConfig(BaseModel):
    """
    Prompt Config.
    """

    jinja2_variables: Optional[list[VariableSelector]] = None


class LLMNodeChatModelMessage(ChatModelMessage):
    """
    LLM Node Chat Model Message.
    """

    jinja2_text: Optional[str] = None


class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
    """
    LLM Node Chat Model Prompt Template.
    """

    jinja2_text: Optional[str] = None


class LLMNodeData(BaseNodeData):
    """
    LLM Node Data.
    """

    model: ModelConfig
    prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate]
    prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
    prompt_config: Optional[PromptConfig] = None
    memory: Optional[MemoryConfig] = None
    context: ContextConfig
    vision: VisionConfig
    vision: VisionConfig = Field(default_factory=VisionConfig)
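Illustrative sketch, not part of this commit: the new VisionConfig defaults (enabled=False, configs built by a default factory) mean a node config without a "vision" block now validates cleanly. The stand-in below mirrors only the default behaviour; the detail field uses a plain string in place of ImagePromptMessageContent.DETAIL.

# Standalone stand-in mirroring the new defaults; not the real Dify classes.
from collections.abc import Sequence

from pydantic import BaseModel, Field


class VisionConfigOptionsStandIn(BaseModel):
    variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"])
    detail: str = "high"  # stand-in for ImagePromptMessageContent.DETAIL.HIGH


class VisionConfigStandIn(BaseModel):
    enabled: bool = False
    configs: VisionConfigOptionsStandIn = Field(default_factory=VisionConfigOptionsStandIn)


class NodeDataStandIn(BaseModel):
    vision: VisionConfigStandIn = Field(default_factory=VisionConfigStandIn)


data = NodeDataStandIn()  # no "vision" key supplied by the workflow definition
print(data.vision.enabled)                           # False
print(list(data.vision.configs.variable_selector))   # ['sys', 'files']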
@@ -1,39 +1,40 @@
import json
from collections.abc import Generator, Mapping, Sequence
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Optional, cast

from pydantic import BaseModel

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import (
from core.model_runtime.entities import (
    AudioPromptMessageContent,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentType,
    TextPromptMessageContent,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.variables import ArrayAnySegment, ArrayFileSegment, FileSegment
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.entities import (
    LLMNodeChatModelMessage,
    LLMNodeCompletionModelPromptTemplate,
    LLMNodeData,
    ModelConfig,
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import (
    ModelInvokeCompletedEvent,
    NodeEvent,
    RunCompletedEvent,
    RunRetrieverResourceEvent,
    RunStreamChunkEvent,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
@@ -41,44 +42,34 @@ from models.model import Conversation
from models.provider import Provider, ProviderType
from models.workflow import WorkflowNodeExecutionStatus

from .entities import (
    LLMNodeChatModelMessage,
    LLMNodeCompletionModelPromptTemplate,
    LLMNodeData,
    ModelConfig,
)

if TYPE_CHECKING:
    from core.file.file_obj import FileVar
    from core.file.models import File


class ModelInvokeCompleted(BaseModel):
    """
    Model invoke completed
    """

    text: str
    usage: LLMUsage
    finish_reason: Optional[str] = None


class LLMNode(BaseNode):
class LLMNode(BaseNode[LLMNodeData]):
    _node_data_cls = LLMNodeData
    _node_type = NodeType.LLM

    def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
        """
        Run node
        :return:
        """
        node_data = cast(LLMNodeData, deepcopy(self.node_data))
        variable_pool = self.graph_runtime_state.variable_pool

    def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]:
        node_inputs = None
        process_data = None

        try:
            # init messages template
            node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template)
            self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)

            # fetch variables and fetch values from variable pool
            inputs = self._fetch_inputs(node_data, variable_pool)
            inputs = self._fetch_inputs(node_data=self.node_data)

            # fetch jinja2 inputs
            jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool)
            jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data)

            # merge inputs
            inputs.update(jinja_inputs)
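Illustrative sketch, not part of this commit: parameterizing the node as BaseNode[LLMNodeData] gives self.node_data a concrete type, which is why the old deepcopy/cast at the top of _run disappears. The generic base class below is a stand-in for the idea only, not the real BaseNode from core.workflow.nodes.base.

# Generic-node stand-in; not the real Dify BaseNode.
from typing import Generic, TypeVar

from pydantic import BaseModel

NodeDataT = TypeVar("NodeDataT", bound=BaseModel)


class BaseNodeStandIn(Generic[NodeDataT]):
    def __init__(self, node_data: NodeDataT) -> None:
        self.node_data = node_data  # typed as NodeDataT, no cast() needed


class LLMNodeDataStandIn(BaseModel):
    prompt: str = ""


class LLMNodeStandIn(BaseNodeStandIn[LLMNodeDataStandIn]):
    def run(self) -> str:
        # type checkers see self.node_data as LLMNodeDataStandIn here
        return self.node_data.prompt


print(LLMNodeStandIn(LLMNodeDataStandIn(prompt="hi")).run())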
@@ -86,13 +77,17 @@ class LLMNode(BaseNode):
            node_inputs = {}

            # fetch files
            files = self._fetch_files(node_data, variable_pool)
            files = (
                self._fetch_files(selector=self.node_data.vision.configs.variable_selector)
                if self.node_data.vision.enabled
                else []
            )

            if files:
                node_inputs["#files#"] = [file.to_dict() for file in files]

            # fetch context value
            generator = self._fetch_context(node_data, variable_pool)
            generator = self._fetch_context(node_data=self.node_data)
            context = None
            for event in generator:
                if isinstance(event, RunRetrieverResourceEvent):
@@ -103,21 +98,30 @@ class LLMNode(BaseNode):
                node_inputs["#context#"] = context  # type: ignore

            # fetch model config
            model_instance, model_config = self._fetch_model_config(node_data.model)
            model_instance, model_config = self._fetch_model_config(self.node_data.model)

            # fetch memory
            memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
            memory = self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance)

            # fetch prompt messages
            if self.node_data.memory:
                query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
                if not query:
                    raise ValueError("Query not found")
                query = query.text
            else:
                query = None

            prompt_messages, stop = self._fetch_prompt_messages(
                node_data=node_data,
                query=variable_pool.get_any(["sys", SystemVariableKey.QUERY.value]) if node_data.memory else None,
                query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
                system_query=query,
                inputs=inputs,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config,
                vision_detail=self.node_data.vision.configs.detail,
                prompt_template=self.node_data.prompt_template,
                memory_config=self.node_data.memory,
            )

            process_data = {
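Illustrative sketch, not part of this commit: _run now reads the user query itself from the variable pool under the system namespace and unwraps the segment's text, instead of receiving the query as a call argument. The pool and segment below are simplified stand-ins, not the real Dify types.

# Simplified stand-ins for the lookup-and-unwrap pattern used above.
from dataclasses import dataclass, field


@dataclass
class StringSegmentStandIn:
    text: str


@dataclass
class VariablePoolStandIn:
    store: dict[tuple[str, str], StringSegmentStandIn] = field(default_factory=dict)

    def get(self, selector: tuple[str, str]) -> StringSegmentStandIn | None:
        return self.store.get(selector)


pool = VariablePoolStandIn({("sys", "query"): StringSegmentStandIn(text="What is in this image?")})
segment = pool.get(("sys", "query"))
if segment is None:
    raise ValueError("Query not found")
query = segment.text  # "What is in this image?"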
@@ -131,7 +135,7 @@ class LLMNode(BaseNode):

            # handle invoke result
            generator = self._invoke_llm(
                node_data_model=node_data.model,
                node_data_model=self.node_data.model,
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                stop=stop,
@@ -143,7 +147,7 @@ class LLMNode(BaseNode):
            for event in generator:
                if isinstance(event, RunStreamChunkEvent):
                    yield event
                elif isinstance(event, ModelInvokeCompleted):
                elif isinstance(event, ModelInvokeCompletedEvent):
                    result_text = event.text
                    usage = event.usage
                    finish_reason = event.finish_reason
@@ -182,15 +186,7 @@ class LLMNode(BaseNode):
        model_instance: ModelInstance,
        prompt_messages: list[PromptMessage],
        stop: Optional[list[str]] = None,
    ) -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
        """
        Invoke large language model
        :param node_data_model: node data model
        :param model_instance: model instance
        :param prompt_messages: prompt messages
        :param stop: stop
        :return:
        """
    ) -> Generator[NodeEvent, None, None]:
        db.session.close()

        invoke_result = model_instance.invoke_llm(
@@ -207,20 +203,13 @@ class LLMNode(BaseNode):
        usage = LLMUsage.empty_usage()
        for event in generator:
            yield event
            if isinstance(event, ModelInvokeCompleted):
            if isinstance(event, ModelInvokeCompletedEvent):
                usage = event.usage

        # deduct quota
        self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)

    def _handle_invoke_result(
        self, invoke_result: LLMResult | Generator
    ) -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
        """
        Handle invoke result
        :param invoke_result: invoke result
        :return:
        """
    def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
        if isinstance(invoke_result, LLMResult):
            return

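Illustrative sketch, not part of this commit: the renamed ModelInvokeCompletedEvent is both re-yielded and inspected, so callers keep receiving stream chunks while the node still captures final usage for quota deduction. Below is a generic re-implementation of that consume-and-capture pattern with stand-in event classes, not the real ones from core.workflow.nodes.event.

# Stand-in event classes and relay loop; illustrative only.
from collections.abc import Generator, Iterable
from dataclasses import dataclass


@dataclass
class RunStreamChunkEventStandIn:
    chunk: str


@dataclass
class ModelInvokeCompletedEventStandIn:
    text: str
    usage: float


def relay(events: Iterable) -> Generator:
    usage = 0.0
    for event in events:
        yield event  # forward every event downstream unchanged
        if isinstance(event, ModelInvokeCompletedEventStandIn):
            usage = event.usage  # capture final usage for quota deduction
    print(f"deduct quota for usage={usage}")


stream = [
    RunStreamChunkEventStandIn("Hel"),
    RunStreamChunkEventStandIn("lo"),
    ModelInvokeCompletedEventStandIn(text="Hello", usage=42.0),
]
for e in relay(stream):
    pass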
@@ -250,18 +239,11 @@ class LLMNode(BaseNode):
        if not usage:
            usage = LLMUsage.empty_usage()

        yield ModelInvokeCompleted(text=full_text, usage=usage, finish_reason=finish_reason)
        yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason)

    def _transform_chat_messages(
        self, messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
    ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
        """
        Transform chat messages

        :param messages: chat messages
        :return:
        """

        self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
    ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
        if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
            if messages.edition_type == "jinja2" and messages.jinja2_text:
                messages.text = messages.jinja2_text
@@ -274,13 +256,7 @@ class LLMNode(BaseNode):

        return messages

    def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
        """
        Fetch jinja inputs
        :param node_data: node data
        :param variable_pool: variable pool
        :return:
        """
    def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
        variables = {}

        if not node_data.prompt_config:
@@ -288,7 +264,7 @@ class LLMNode(BaseNode):

        for variable_selector in node_data.prompt_config.jinja2_variables or []:
            variable = variable_selector.variable
            value = variable_pool.get_any(variable_selector.value_selector)
            value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)

            def parse_dict(d: dict) -> str:
                """
@@ -330,13 +306,7 @@ class LLMNode(BaseNode):

        return variables

    def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
        """
        Fetch inputs
        :param node_data: node data
        :param variable_pool: variable pool
        :return:
        """
    def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
        inputs = {}
        prompt_template = node_data.prompt_template

@@ -350,7 +320,7 @@ class LLMNode(BaseNode):
            variable_selectors = variable_template_parser.extract_variable_selectors()

        for variable_selector in variable_selectors:
            variable_value = variable_pool.get_any(variable_selector.value_selector)
            variable_value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
            if variable_value is None:
                raise ValueError(f"Variable {variable_selector.variable} not found")

@@ -362,7 +332,7 @@ class LLMNode(BaseNode):
                template=memory.query_prompt_template
            ).extract_variable_selectors()
            for variable_selector in query_variable_selectors:
                variable_value = variable_pool.get_any(variable_selector.value_selector)
                variable_value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
                if variable_value is None:
                    raise ValueError(f"Variable {variable_selector.variable} not found")

@@ -370,36 +340,28 @@ class LLMNode(BaseNode):

        return inputs

    def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]:
        """
        Fetch files
        :param node_data: node data
        :param variable_pool: variable pool
        :return:
        """
        if not node_data.vision.enabled:
    def _fetch_files(self, *, selector: Sequence[str]) -> Sequence["File"]:
        variable = self.graph_runtime_state.variable_pool.get(selector)
        if variable is None:
            return []

        files = variable_pool.get_any(["sys", SystemVariableKey.FILES.value])
        if not files:
        if isinstance(variable, FileSegment):
            return [variable.value]
        if isinstance(variable, ArrayFileSegment):
            return variable.value
        # FIXME: Temporary fix for empty array,
        # all variables added to variable pool should be a Segment instance.
        if isinstance(variable, ArrayAnySegment) and len(variable.value) == 0:
            return []
        raise ValueError(f"Invalid variable type: {type(variable)}")

        return files

    def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]:
        """
        Fetch context
        :param node_data: node data
        :param variable_pool: variable pool
        :return:
        """
    def _fetch_context(self, node_data: LLMNodeData):
        if not node_data.context.enabled:
            return

        if not node_data.context.variable_selector:
            return

        context_value = variable_pool.get_any(node_data.context.variable_selector)
        context_value = self.graph_runtime_state.variable_pool.get_any(node_data.context.variable_selector)
        if context_value:
            if isinstance(context_value, str):
                yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value)
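Illustrative sketch, not part of this commit: the reworked _fetch_files resolves the configured selector in the variable pool and then dispatches on the segment type (single FileSegment, ArrayFileSegment, or an empty ArrayAnySegment tolerated as a temporary fix). The segment classes below are simplified stand-ins for the ones in core.variables.

# Stand-in segment types and the dispatch logic they feed; illustrative only.
from collections.abc import Sequence
from dataclasses import dataclass


@dataclass
class FileSegmentStandIn:
    value: str  # a single file


@dataclass
class ArrayFileSegmentStandIn:
    value: Sequence[str]  # a list of files


@dataclass
class ArrayAnySegmentStandIn:
    value: Sequence


def fetch_files(variable) -> Sequence[str]:
    if variable is None:
        return []
    if isinstance(variable, FileSegmentStandIn):
        return [variable.value]
    if isinstance(variable, ArrayFileSegmentStandIn):
        return variable.value
    # empty "any" arrays are tolerated until every pool value is a typed segment
    if isinstance(variable, ArrayAnySegmentStandIn) and len(variable.value) == 0:
        return []
    raise ValueError(f"Invalid variable type: {type(variable)}")


print(fetch_files(FileSegmentStandIn("a.png")))              # ['a.png']
print(fetch_files(ArrayFileSegmentStandIn(["a.png", "b"])))  # ['a.png', 'b']
print(fetch_files(ArrayAnySegmentStandIn([])))               # []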
@@ -424,11 +386,6 @@ class LLMNode(BaseNode):
            )

    def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
        """
        Convert to original retriever resource, temp.
        :param context_dict: context dict
        :return:
        """
        if (
            "metadata" in context_dict
            and "_source" in context_dict["metadata"]
@@ -451,6 +408,7 @@ class LLMNode(BaseNode):
                "segment_position": metadata.get("segment_position"),
                "index_node_hash": metadata.get("segment_index_node_hash"),
                "content": context_dict.get("content"),
                "page": metadata.get("page"),
            }

            return source
@@ -460,11 +418,6 @@ class LLMNode(BaseNode):
    def _fetch_model_config(
        self, node_data_model: ModelConfig
    ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
        """
        Fetch model config
        :param node_data_model: node data model
        :return:
        """
        model_name = node_data_model.name
        provider_name = node_data_model.provider

@@ -523,19 +476,15 @@ class LLMNode(BaseNode):
        )

    def _fetch_memory(
        self, node_data_memory: Optional[MemoryConfig], variable_pool: VariablePool, model_instance: ModelInstance
        self, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
    ) -> Optional[TokenBufferMemory]:
        """
        Fetch memory
        :param node_data_memory: node data memory
        :param variable_pool: variable pool
        :return:
        """
        if not node_data_memory:
            return None

        # get conversation id
        conversation_id = variable_pool.get_any(["sys", SystemVariableKey.CONVERSATION_ID.value])
        conversation_id = self.graph_runtime_state.variable_pool.get_any(
            ["sys", SystemVariableKey.CONVERSATION_ID.value]
        )
        if conversation_id is None:
            return None

@@ -555,43 +504,31 @@ class LLMNode(BaseNode):

    def _fetch_prompt_messages(
        self,
        node_data: LLMNodeData,
        query: Optional[str],
        query_prompt_template: Optional[str],
        inputs: dict[str, str],
        files: list["FileVar"],
        context: Optional[str],
        memory: Optional[TokenBufferMemory],
        *,
        system_query: str | None = None,
        inputs: dict[str, str] | None = None,
        files: Sequence["File"],
        context: str | None = None,
        memory: TokenBufferMemory | None = None,
        model_config: ModelConfigWithCredentialsEntity,
        prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
        memory_config: MemoryConfig | None = None,
        vision_detail: ImagePromptMessageContent.DETAIL,
    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
        """
        Fetch prompt messages
        :param node_data: node data
        :param query: query
        :param query_prompt_template: query prompt template
        :param inputs: inputs
        :param files: files
        :param context: context
        :param memory: memory
        :param model_config: model config
        :return:
        """
        inputs = inputs or {}

        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
        prompt_messages = prompt_transform.get_prompt(
            prompt_template=node_data.prompt_template,
            prompt_template=prompt_template,
            inputs=inputs,
            query=query or "",
            query=system_query or "",
            files=files,
            context=context,
            memory_config=node_data.memory,
            memory_config=memory_config,
            memory=memory,
            model_config=model_config,
            query_prompt_template=query_prompt_template,
        )
        stop = model_config.stop

        vision_enabled = node_data.vision.enabled
        vision_detail = node_data.vision.configs.detail if node_data.vision.configs else None
        filtered_prompt_messages = []
        for prompt_message in prompt_messages:
            if prompt_message.is_empty():
@@ -599,17 +536,13 @@ class LLMNode(BaseNode):

            if not isinstance(prompt_message.content, str):
                prompt_message_content = []
                for content_item in prompt_message.content:
                    if (
                        vision_enabled
                        and content_item.type == PromptMessageContentType.IMAGE
                        and isinstance(content_item, ImagePromptMessageContent)
                    ):
                        # Override vision config if LLM node has vision config
                        if vision_detail:
                            content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail)
                for content_item in prompt_message.content or []:
                    if isinstance(content_item, ImagePromptMessageContent):
                        # Override vision config if LLM node has vision config,
                        # cuz vision detail is related to the configuration from FileUpload feature.
                        content_item.detail = vision_detail
                        prompt_message_content.append(content_item)
                    elif content_item.type == PromptMessageContentType.TEXT:
                    elif isinstance(content_item, TextPromptMessageContent | AudioPromptMessageContent):
                        prompt_message_content.append(content_item)

                if len(prompt_message_content) > 1:
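Illustrative sketch, not part of this commit: the new filtering loop keeps image, text, and audio content items and always overrides an image item's detail with the node-level vision setting. The stand-in below is simplified to text and image items only; the real classes are in core.model_runtime.entities.

# Stand-in prompt content classes and the keep-and-override loop; illustrative only.
from dataclasses import dataclass


@dataclass
class ImageContentStandIn:
    url: str
    detail: str = "low"


@dataclass
class TextContentStandIn:
    data: str


def filter_contents(items, vision_detail: str):
    kept = []
    for item in items:
        if isinstance(item, ImageContentStandIn):
            # the node-level vision detail wins over the per-file setting
            item.detail = vision_detail
            kept.append(item)
        elif isinstance(item, TextContentStandIn):
            kept.append(item)
        # anything else (unsupported modalities) is dropped
    return kept


contents = [TextContentStandIn("describe this"), ImageContentStandIn("img://1", detail="low")]
print([type(c).__name__ for c in filter_contents(contents, vision_detail="high")])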
@@ -631,13 +564,6 @@ class LLMNode(BaseNode):

    @classmethod
    def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
        """
        Deduct LLM quota
        :param tenant_id: tenant id
        :param model_instance: model instance
        :param usage: usage
        :return:
        """
        provider_model_bundle = model_instance.provider_model_bundle
        provider_configuration = provider_model_bundle.configuration

@@ -668,7 +594,7 @@ class LLMNode(BaseNode):
            else:
                used_quota = 1

        if used_quota is not None:
        if used_quota is not None and system_configuration.current_quota_type is not None:
            db.session.query(Provider).filter(
                Provider.tenant_id == tenant_id,
                Provider.provider_name == model_instance.provider,
@@ -680,27 +606,28 @@ class LLMNode(BaseNode):

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls, graph_config: Mapping[str, Any], node_id: str, node_data: LLMNodeData
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: LLMNodeData,
    ) -> Mapping[str, Sequence[str]]:
        """
        Extract variable selector to variable mapping
        :param graph_config: graph config
        :param node_id: node id
        :param node_data: node data
        :return:
        """
        prompt_template = node_data.prompt_template

        variable_selectors = []
        if isinstance(prompt_template, list):
        if isinstance(prompt_template, list) and all(
            isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template
        ):
            for prompt in prompt_template:
                if prompt.edition_type != "jinja2":
                    variable_template_parser = VariableTemplateParser(template=prompt.text)
                    variable_selectors.extend(variable_template_parser.extract_variable_selectors())
        else:
        elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
            if prompt_template.edition_type != "jinja2":
                variable_template_parser = VariableTemplateParser(template=prompt_template.text)
                variable_selectors = variable_template_parser.extract_variable_selectors()
        else:
            raise ValueError(f"Invalid prompt template type: {type(prompt_template)}")

        variable_mapping = {}
        for variable_selector in variable_selectors:
@@ -745,11 +672,6 @@ class LLMNode(BaseNode):

    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """
        Get default config of node.
        :param filters: filter by node config parameters.
        :return:
        """
        return {
            "type": "llm",
            "config": {