feat: Parallel Execution of Nodes in Workflows (#8192)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Yi <yxiaoisme@gmail.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
@@ -1,16 +1,17 @@
 import json
-from collections.abc import Generator
+from collections.abc import Generator, Mapping, Sequence
 from copy import deepcopy
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Any, Optional, cast
 
+from pydantic import BaseModel
+
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
 from core.entities.model_entities import ModelStatus
 from core.entities.provider_entities import QuotaUnit
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
     PromptMessage,
@@ -25,7 +26,9 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
+from core.workflow.graph_engine.entities.event import InNodeEvent
 from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
 from core.workflow.nodes.llm.entities import (
     LLMNodeChatModelMessage,
     LLMNodeCompletionModelPromptTemplate,
@@ -43,17 +46,26 @@ if TYPE_CHECKING:
 
 
 
+class ModelInvokeCompleted(BaseModel):
+    """
+    Model invoke completed
+    """
+    text: str
+    usage: LLMUsage
+    finish_reason: Optional[str] = None
+
+
 class LLMNode(BaseNode):
     _node_data_cls = LLMNodeData
     _node_type = NodeType.LLM
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
         """
         Run node
-        :param variable_pool: variable pool
         :return:
         """
         node_data = cast(LLMNodeData, deepcopy(self.node_data))
+        variable_pool = self.graph_runtime_state.variable_pool
 
         node_inputs = None
         process_data = None
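
Note: judging by the imports, this diff is against Dify's LLM workflow node (llm_node.py). The key change in the hunk above is that _run no longer returns a single NodeRunResult; it becomes a generator of run events, which is what lets the graph engine drive several nodes concurrently and stream their output as it arrives. A minimal, self-contained sketch of that generator-node pattern (plain dataclasses instead of the project's pydantic models; all names here are illustrative, not Dify's API):

from collections.abc import Generator
from dataclasses import dataclass, field


@dataclass
class StreamChunk:
    text: str


@dataclass
class Completed:
    outputs: dict = field(default_factory=dict)
    error: str | None = None


def run_node(prompt: str) -> Generator[StreamChunk | Completed, None, None]:
    # Intermediate output is yielded as events instead of being returned once at the end.
    for word in prompt.split():
        yield StreamChunk(text=word + " ")
    yield Completed(outputs={"text": prompt})


# The engine can now interleave several node generators instead of blocking on one result.
for event in run_node("hello parallel workflows"):
    print(type(event).__name__, event)
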
@@ -80,10 +92,15 @@ class LLMNode(BaseNode):
                 node_inputs['#files#'] = [file.to_dict() for file in files]
 
             # fetch context value
-            context = self._fetch_context(node_data, variable_pool)
+            generator = self._fetch_context(node_data, variable_pool)
+            context = None
+            for event in generator:
+                if isinstance(event, RunRetrieverResourceEvent):
+                    context = event.context
+                    yield event
 
             if context:
-                node_inputs['#context#'] = context
+                node_inputs['#context#'] = context  # type: ignore
 
             # fetch model config
             model_instance, model_config = self._fetch_model_config(node_data.model)
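
The context fetch above shows the recurring pattern in this commit: a helper that used to return a value is now itself a generator, and the caller iterates it, re-yields its events upstream, and captures the value it needs from one of them. A small sketch of that "drive a sub-generator and capture one value" shape (hypothetical names, not the project's types):

from collections.abc import Generator
from dataclasses import dataclass


@dataclass
class RetrieverResourceEvent:
    context: str
    retriever_resources: list


def fetch_context(query: str) -> Generator[RetrieverResourceEvent, None, None]:
    # The helper yields its result as an event rather than returning it.
    yield RetrieverResourceEvent(context=f"docs about {query}", retriever_resources=[])


def run() -> Generator[RetrieverResourceEvent, None, None]:
    context = None
    for event in fetch_context("parallel workflows"):
        context = event.context  # capture the value the node needs locally...
        yield event              # ...while still forwarding the event upstream
    print("context used for the prompt:", context)


list(run())
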
@@ -115,19 +132,34 @@ class LLMNode(BaseNode):
             }
 
             # handle invoke result
-            result_text, usage, finish_reason = self._invoke_llm(
+            generator = self._invoke_llm(
                 node_data_model=node_data.model,
                 model_instance=model_instance,
                 prompt_messages=prompt_messages,
                 stop=stop
             )
+
+            result_text = ''
+            usage = LLMUsage.empty_usage()
+            finish_reason = None
+            for event in generator:
+                if isinstance(event, RunStreamChunkEvent):
+                    yield event
+                elif isinstance(event, ModelInvokeCompleted):
+                    result_text = event.text
+                    usage = event.usage
+                    finish_reason = event.finish_reason
+                    break
         except Exception as e:
-            return NodeRunResult(
-                status=WorkflowNodeExecutionStatus.FAILED,
-                error=str(e),
-                inputs=node_inputs,
-                process_data=process_data
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    error=str(e),
+                    inputs=node_inputs,
+                    process_data=process_data
+                )
             )
+            return
 
         outputs = {
             'text': result_text,
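
With _invoke_llm now a generator, the try block above forwards streaming chunk events as soon as they arrive, records the final text/usage/finish_reason from a ModelInvokeCompleted event, and on failure yields a completed event with FAILED status and then returns, since a generator cannot return a result object the way the old code did. A runnable sketch of that consumption loop, with simplified stand-in event types:

from collections.abc import Generator
from dataclasses import dataclass


@dataclass
class Chunk:
    text: str


@dataclass
class InvokeCompleted:
    text: str
    total_tokens: int = 0


@dataclass
class RunCompleted:
    status: str
    outputs: dict | None = None
    error: str | None = None


def invoke_llm() -> Generator[Chunk | InvokeCompleted, None, None]:
    yield Chunk("Hel")
    yield Chunk("lo")
    yield InvokeCompleted(text="Hello", total_tokens=5)


def run() -> Generator[object, None, None]:
    result_text, tokens = "", 0
    try:
        for event in invoke_llm():
            if isinstance(event, Chunk):
                yield event  # stream chunks straight through to the caller
            elif isinstance(event, InvokeCompleted):
                result_text, tokens = event.text, event.total_tokens
                break  # invocation finished
    except Exception as e:
        # A generator cannot `return NodeRunResult(...)`, so failure is also an event.
        yield RunCompleted(status="failed", error=str(e))
        return
    yield RunCompleted(status="succeeded", outputs={"text": result_text, "tokens": tokens})


print([type(e).__name__ for e in run()])
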
@@ -135,22 +167,26 @@ class LLMNode(BaseNode):
             'finish_reason': finish_reason
         }
 
-        return NodeRunResult(
-            status=WorkflowNodeExecutionStatus.SUCCEEDED,
-            inputs=node_inputs,
-            process_data=process_data,
-            outputs=outputs,
-            metadata={
-                NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
-                NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
-                NodeRunMetadataKey.CURRENCY: usage.currency
-            }
+        yield RunCompletedEvent(
+            run_result=NodeRunResult(
+                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                inputs=node_inputs,
+                process_data=process_data,
+                outputs=outputs,
+                metadata={
+                    NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
+                    NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
+                    NodeRunMetadataKey.CURRENCY: usage.currency
+                },
+                llm_usage=usage
+            )
         )
 
     def _invoke_llm(self, node_data_model: ModelConfig,
                     model_instance: ModelInstance,
                     prompt_messages: list[PromptMessage],
-                    stop: list[str]) -> tuple[str, LLMUsage]:
+                    stop: Optional[list[str]] = None) \
+            -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
         """
         Invoke large language model
         :param node_data_model: node data model
@@ -170,23 +206,31 @@ class LLMNode(BaseNode):
         )
 
         # handle invoke result
-        text, usage, finish_reason = self._handle_invoke_result(
+        generator = self._handle_invoke_result(
             invoke_result=invoke_result
         )
 
+        usage = LLMUsage.empty_usage()
+        for event in generator:
+            yield event
+            if isinstance(event, ModelInvokeCompleted):
+                usage = event.usage
+
         # deduct quota
         self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
 
-        return text, usage, finish_reason
-
-    def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
+    def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
+            -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
         """
         Handle invoke result
         :param invoke_result: invoke result
         :return:
         """
+        if isinstance(invoke_result, LLMResult):
+            return
+
         model = None
-        prompt_messages = []
+        prompt_messages: list[PromptMessage] = []
         full_text = ''
         usage = None
         finish_reason = None
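
_invoke_llm keeps its quota-deduction side effect by watching the events it re-yields: every event still reaches the caller, but the usage carried by ModelInvokeCompleted is remembered so the deduction can run once the stream ends. A sketch of that "pass events through but remember one of them" pattern (the quota dict below is a stand-in for Dify's deduct_llm_quota, which is not shown here):

from collections.abc import Generator
from dataclasses import dataclass


@dataclass
class InvokeCompleted:
    text: str
    total_tokens: int


def handle_invoke_result() -> Generator[InvokeCompleted, None, None]:
    yield InvokeCompleted(text="hi", total_tokens=42)


def invoke_llm(quota: dict) -> Generator[InvokeCompleted, None, None]:
    total_tokens = 0
    for event in handle_invoke_result():
        yield event  # callers still see every event unchanged
        if isinstance(event, InvokeCompleted):
            total_tokens = event.total_tokens
    quota["remaining"] -= total_tokens  # side effect runs once the stream is exhausted


quota = {"remaining": 1000}
list(invoke_llm(quota))
print(quota)  # {'remaining': 958}
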
@@ -194,7 +238,10 @@ class LLMNode(BaseNode):
             text = result.delta.message.content
             full_text += text
 
-            self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])
+            yield RunStreamChunkEvent(
+                chunk_content=text,
+                from_variable_selector=[self.node_id, 'text']
+            )
 
             if not model:
                 model = result.model
@@ -211,11 +258,15 @@ class LLMNode(BaseNode):
         if not usage:
             usage = LLMUsage.empty_usage()
 
-        return full_text, usage, finish_reason
+        yield ModelInvokeCompleted(
+            text=full_text,
+            usage=usage,
+            finish_reason=finish_reason
+        )
 
     def _transform_chat_messages(self,
-                                 messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
-                                 ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
+        messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
+    ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
         """
         Transform chat messages
 
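
_handle_invoke_result becomes the adapter between the model runtime and the event stream: a blocking LLMResult short-circuits (the method simply returns), while a streaming generator is turned into one RunStreamChunkEvent per delta plus a single ModelInvokeCompleted at the end, replacing both publish_text_chunk and the old tuple return. A self-contained sketch of that adapter shape, with simplified delta objects standing in for the runtime's chunk type:

from collections.abc import Generator
from dataclasses import dataclass


@dataclass
class Delta:
    content: str


@dataclass
class StreamChunkEvent:
    chunk_content: str


@dataclass
class InvokeCompleted:
    text: str
    finish_reason: str | None = None


def handle_invoke_result(deltas: list[Delta]) -> Generator[StreamChunkEvent | InvokeCompleted, None, None]:
    full_text = ""
    for delta in deltas:
        full_text += delta.content
        yield StreamChunkEvent(chunk_content=delta.content)  # replaces publish_text_chunk(...)
    # One terminal event replaces `return full_text, usage, finish_reason`.
    yield InvokeCompleted(text=full_text, finish_reason="stop")


events = list(handle_invoke_result([Delta("Hel"), Delta("lo")]))
print(events[-1].text)  # Hello
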
@@ -224,13 +275,13 @@ class LLMNode(BaseNode):
         """
 
         if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
-            if messages.edition_type == 'jinja2':
+            if messages.edition_type == 'jinja2' and messages.jinja2_text:
                 messages.text = messages.jinja2_text
 
             return messages
 
         for message in messages:
-            if message.edition_type == 'jinja2':
+            if message.edition_type == 'jinja2' and message.jinja2_text:
                 message.text = message.jinja2_text
 
         return messages
@@ -348,7 +399,7 @@ class LLMNode(BaseNode):
 
         return files
 
-    def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]:
+    def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]:
         """
         Fetch context
         :param node_data: node data
@@ -356,15 +407,18 @@ class LLMNode(BaseNode):
         :return:
         """
         if not node_data.context.enabled:
-            return None
+            return
 
         if not node_data.context.variable_selector:
-            return None
+            return
 
         context_value = variable_pool.get_any(node_data.context.variable_selector)
         if context_value:
             if isinstance(context_value, str):
-                return context_value
+                yield RunRetrieverResourceEvent(
+                    retriever_resources=[],
+                    context=context_value
+                )
             elif isinstance(context_value, list):
                 context_str = ''
                 original_retriever_resource = []
@@ -381,17 +435,10 @@ class LLMNode(BaseNode):
                     if retriever_resource:
                         original_retriever_resource.append(retriever_resource)
 
-                if self.callbacks and original_retriever_resource:
-                    for callback in self.callbacks:
-                        callback.on_event(
-                            event=QueueRetrieverResourcesEvent(
-                                retriever_resources=original_retriever_resource
-                            )
-                        )
-
-                return context_str.strip()
-
-        return None
+                yield RunRetrieverResourceEvent(
+                    retriever_resources=original_retriever_resource,
+                    context=context_str.strip()
+                )
 
     def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
         """
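
Retriever resources follow the same move: instead of pushing a QueueRetrieverResourcesEvent through registered callbacks and returning the joined context string, _fetch_context now yields one RunRetrieverResourceEvent carrying both the context and the resources, and the consumer decides what to do with it. A sketch contrasting the two shapes (the callback interface shown for the "before" case is hypothetical):

from collections.abc import Callable, Generator
from dataclasses import dataclass, field


@dataclass
class RetrieverResourceEvent:
    context: str
    retriever_resources: list = field(default_factory=list)


# Before: push resources to callbacks as a side effect and return only the string.
def fetch_context_with_callbacks(chunks: list[str], callbacks: list[Callable]) -> str:
    context = "\n".join(chunks)
    for callback in callbacks:
        callback(chunks)
    return context.strip()


# After: yield a single event carrying both pieces; no callback plumbing needed.
def fetch_context(chunks: list[str]) -> Generator[RetrieverResourceEvent, None, None]:
    yield RetrieverResourceEvent(context="\n".join(chunks).strip(), retriever_resources=chunks)


print(next(fetch_context(["doc A", "doc B"])).context)
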
@@ -574,7 +621,8 @@ class LLMNode(BaseNode):
             if not isinstance(prompt_message.content, str):
                 prompt_message_content = []
                 for content_item in prompt_message.content:
-                    if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(content_item, ImagePromptMessageContent):
+                    if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(
+                            content_item, ImagePromptMessageContent):
                         # Override vision config if LLM node has vision config
                         if vision_detail:
                             content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail)
@@ -646,13 +694,19 @@ class LLMNode(BaseNode):
         db.session.commit()
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: LLMNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
         :param node_data: node data
         :return:
         """
-
         prompt_template = node_data.prompt_template
 
         variable_selectors = []
@@ -702,6 +756,10 @@ class LLMNode(BaseNode):
             for variable_selector in node_data.prompt_config.jinja2_variables or []:
                 variable_mapping[variable_selector.variable] = variable_selector.value_selector
 
+        variable_mapping = {
+            node_id + '.' + key: value for key, value in variable_mapping.items()
+        }
+
         return variable_mapping
 
     @classmethod
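
The variable-mapping classmethod now receives graph_config and node_id, and the added lines namespace every key with the node id, presumably so mappings extracted from different nodes in the same graph cannot collide when branches run in parallel. A tiny illustration of the key prefixing (the selector values here are made up):

node_id = "llm_1"
variable_mapping = {
    "#context#": ["knowledge_retrieval_1", "result"],
    "query": ["start_1", "user_query"],
}

# Same transformation as the added lines above: prefix every key with "<node_id>.".
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}

print(variable_mapping)
# {'llm_1.#context#': ['knowledge_retrieval_1', 'result'], 'llm_1.query': ['start_1', 'user_query']}
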