feat: Parallel Execution of Nodes in Workflows (#8192)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Yi <yxiaoisme@gmail.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Author:    takatost
Date:      2024-09-10 15:23:16 +08:00
Committer: GitHub
Commit:    dabfd74622
Parent:    5da0182800

156 changed files with 11158 additions and 5605 deletions


@@ -1,16 +1,17 @@
 import json
-from collections.abc import Generator
+from collections.abc import Generator, Mapping, Sequence
 from copy import deepcopy
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Any, Optional, cast
 
+from pydantic import BaseModel
+
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
 from core.entities.model_entities import ModelStatus
 from core.entities.provider_entities import QuotaUnit
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
     PromptMessage,
@@ -25,7 +26,9 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
+from core.workflow.graph_engine.entities.event import InNodeEvent
 from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
 from core.workflow.nodes.llm.entities import (
     LLMNodeChatModelMessage,
     LLMNodeCompletionModelPromptTemplate,
@@ -43,17 +46,26 @@ if TYPE_CHECKING:
+class ModelInvokeCompleted(BaseModel):
+    """
+    Model invoke completed
+    """
+    text: str
+    usage: LLMUsage
+    finish_reason: Optional[str] = None
+
+
 class LLMNode(BaseNode):
     _node_data_cls = LLMNodeData
     _node_type = NodeType.LLM
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
         """
         Run node
-        :param variable_pool: variable pool
         :return:
         """
         node_data = cast(LLMNodeData, deepcopy(self.node_data))
+        variable_pool = self.graph_runtime_state.variable_pool
 
         node_inputs = None
         process_data = None
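
This hunk is the heart of the refactor (the file appears to be api/core/workflow/nodes/llm/llm_node.py): `_run()` no longer takes a `variable_pool` parameter and returns a single `NodeRunResult`; it reads state from `self.graph_runtime_state` and yields a stream of events, which is what lets the new graph engine interleave output from nodes running in parallel branches. A minimal sketch of how such a generator can be driven, using simplified stand-in event classes rather than Dify's real ones:

```python
# Minimal sketch of driving a generator-based node. StreamChunk and Completed
# are simplified stand-ins for RunStreamChunkEvent / RunCompletedEvent.
from collections.abc import Generator
from dataclasses import dataclass


@dataclass
class StreamChunk:
    chunk_content: str


@dataclass
class Completed:
    run_result: dict


def fake_node_run() -> Generator[StreamChunk | Completed, None, None]:
    # The node yields while it works instead of blocking until done.
    for token in ("Hello", ", ", "world"):
        yield StreamChunk(chunk_content=token)
    yield Completed(run_result={"status": "succeeded", "text": "Hello, world"})


def drive(node_run: Generator) -> dict:
    """Forward chunks as they arrive; keep the terminal result."""
    result = {}
    for event in node_run:
        if isinstance(event, StreamChunk):
            print(event.chunk_content, end="", flush=True)  # e.g. push to client
        elif isinstance(event, Completed):
            result = event.run_result  # terminal event: the node is finished
    print()
    return result


print(drive(fake_node_run()))
```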
@@ -80,10 +92,15 @@ class LLMNode(BaseNode):
                 node_inputs['#files#'] = [file.to_dict() for file in files]
 
             # fetch context value
-            context = self._fetch_context(node_data, variable_pool)
+            generator = self._fetch_context(node_data, variable_pool)
+            context = None
+            for event in generator:
+                if isinstance(event, RunRetrieverResourceEvent):
+                    context = event.context
+                    yield event
 
             if context:
-                node_inputs['#context#'] = context
+                node_inputs['#context#'] = context  # type: ignore
 
             # fetch model config
             model_instance, model_config = self._fetch_model_config(node_data.model)
@@ -115,19 +132,34 @@ class LLMNode(BaseNode):
             }
 
             # handle invoke result
-            result_text, usage, finish_reason = self._invoke_llm(
+            generator = self._invoke_llm(
                 node_data_model=node_data.model,
                 model_instance=model_instance,
                 prompt_messages=prompt_messages,
                 stop=stop
             )
+
+            result_text = ''
+            usage = LLMUsage.empty_usage()
+            finish_reason = None
+            for event in generator:
+                if isinstance(event, RunStreamChunkEvent):
+                    yield event
+                elif isinstance(event, ModelInvokeCompleted):
+                    result_text = event.text
+                    usage = event.usage
+                    finish_reason = event.finish_reason
+                    break
         except Exception as e:
-            return NodeRunResult(
-                status=WorkflowNodeExecutionStatus.FAILED,
-                error=str(e),
-                inputs=node_inputs,
-                process_data=process_data
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    error=str(e),
+                    inputs=node_inputs,
+                    process_data=process_data
+                )
             )
+            return
 
         outputs = {
             'text': result_text,
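
Note the failure path in this hunk: because `_run()` is now a generator, it can no longer report errors with `return NodeRunResult(...)`; the failed result is wrapped in a terminal `RunCompletedEvent`, followed by a bare `return` that ends the generator. A hedged illustration of the same error-as-event pattern, with stand-in names:

```python
# Sketch of the error-as-event pattern: a generator cannot return a value to
# its consumer, so failure is delivered as a final yielded event. The dicts
# here are illustrative stand-ins, not Dify's actual event classes.
from collections.abc import Generator


def node_run(fail: bool) -> Generator[dict, None, None]:
    try:
        if fail:
            raise RuntimeError("provider token not initialized")
        yield {"type": "chunk", "text": "partial output"}
    except Exception as e:
        yield {"type": "completed", "status": "failed", "error": str(e)}
        return  # bare return: ends the generator after the terminal event
    yield {"type": "completed", "status": "succeeded"}


print(list(node_run(fail=True)))   # one failed 'completed' event
print(list(node_run(fail=False)))  # a chunk, then a succeeded 'completed' event
```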
@@ -135,22 +167,26 @@ class LLMNode(BaseNode):
             'finish_reason': finish_reason
         }
 
-        return NodeRunResult(
-            status=WorkflowNodeExecutionStatus.SUCCEEDED,
-            inputs=node_inputs,
-            process_data=process_data,
-            outputs=outputs,
-            metadata={
-                NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
-                NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
-                NodeRunMetadataKey.CURRENCY: usage.currency
-            }
+        yield RunCompletedEvent(
+            run_result=NodeRunResult(
+                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                inputs=node_inputs,
+                process_data=process_data,
+                outputs=outputs,
+                metadata={
+                    NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
+                    NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
+                    NodeRunMetadataKey.CURRENCY: usage.currency
+                },
+                llm_usage=usage
+            )
         )
 
     def _invoke_llm(self, node_data_model: ModelConfig,
                     model_instance: ModelInstance,
                     prompt_messages: list[PromptMessage],
-                    stop: list[str]) -> tuple[str, LLMUsage]:
+                    stop: Optional[list[str]] = None) \
+            -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
         """
         Invoke large language model
         :param node_data_model: node data model
@@ -170,23 +206,31 @@ class LLMNode(BaseNode):
         )
 
         # handle invoke result
-        text, usage, finish_reason = self._handle_invoke_result(
+        generator = self._handle_invoke_result(
             invoke_result=invoke_result
         )
 
+        usage = LLMUsage.empty_usage()
+        for event in generator:
+            yield event
+            if isinstance(event, ModelInvokeCompleted):
+                usage = event.usage
+
         # deduct quota
         self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
 
-        return text, usage, finish_reason
 
-    def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
+    def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
+            -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
         """
         Handle invoke result
         :param invoke_result: invoke result
         :return:
         """
+        if isinstance(invoke_result, LLMResult):
+            return
+
         model = None
-        prompt_messages = []
+        prompt_messages: list[PromptMessage] = []
         full_text = ''
         usage = None
         finish_reason = None
@@ -194,7 +238,10 @@ class LLMNode(BaseNode):
             text = result.delta.message.content
             full_text += text
 
-            self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])
+            yield RunStreamChunkEvent(
+                chunk_content=text,
+                from_variable_selector=[self.node_id, 'text']
+            )
 
             if not model:
                 model = result.model
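
`_handle_invoke_result` now emits a `RunStreamChunkEvent` per streamed delta in place of the old `publish_text_chunk` callback, and (as the next hunk shows) closes its stream with a `ModelInvokeCompleted` sentinel; `_invoke_llm` re-yields every event while watching for that sentinel so it can deduct quota from the final usage. A simplified sketch of this forward-and-observe layering, with stand-in classes rather than Dify's actual API:

```python
# An inner generator streams chunks and ends with a completion sentinel; the
# outer generator re-yields everything while capturing the sentinel for side
# effects (here, quota accounting). All names are illustrative stand-ins.
from collections.abc import Generator
from dataclasses import dataclass


@dataclass
class Chunk:
    text: str


@dataclass
class InvokeCompleted:  # plays the role of ModelInvokeCompleted
    text: str
    total_tokens: int


def handle_result(deltas: list[str]) -> Generator[Chunk | InvokeCompleted, None, None]:
    full_text = ""
    for delta in deltas:
        full_text += delta
        yield Chunk(text=delta)  # one event per streamed delta
    yield InvokeCompleted(text=full_text, total_tokens=len(full_text.split()))


def invoke(deltas: list[str]) -> Generator[Chunk | InvokeCompleted, None, None]:
    usage = 0
    for event in handle_result(deltas):
        yield event  # forward every event to the caller
        if isinstance(event, InvokeCompleted):
            usage = event.total_tokens  # observe the sentinel in passing
    print(f"\ndeduct quota: {usage} tokens")  # side effect after the stream ends


for event in invoke(["streamed ", "in ", "parallel"]):
    if isinstance(event, Chunk):
        print(event.text, end="")
```

Because `ModelInvokeCompleted` is a plain pydantic model rather than a `RunEvent`, `_run` can consume it internally (note the `break` in the earlier hunk) without ever leaking it to the graph engine.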
@@ -211,11 +258,15 @@ class LLMNode(BaseNode):
         if not usage:
             usage = LLMUsage.empty_usage()
 
-        return full_text, usage, finish_reason
+        yield ModelInvokeCompleted(
+            text=full_text,
+            usage=usage,
+            finish_reason=finish_reason
+        )
 
     def _transform_chat_messages(self,
-        messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
-    ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
+                                 messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
+                                 ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
         """
         Transform chat messages
@@ -224,13 +275,13 @@ class LLMNode(BaseNode):
"""
if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
if messages.edition_type == 'jinja2':
if messages.edition_type == 'jinja2' and messages.jinja2_text:
messages.text = messages.jinja2_text
return messages
for message in messages:
if message.edition_type == 'jinja2':
if message.edition_type == 'jinja2' and message.jinja2_text:
message.text = message.jinja2_text
return messages
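
The added `and messages.jinja2_text` / `and message.jinja2_text` guards mean an empty or missing Jinja2 template no longer overwrites the message's plain text. The guard in miniature, using a hypothetical stand-in for `LLMNodeChatModelMessage`:

```python
# Only substitute the Jinja2 text when it is actually non-empty.
# Msg is a stand-in, not the real LLMNodeChatModelMessage class.
class Msg:
    def __init__(self, text: str, edition_type: str, jinja2_text: str | None = None):
        self.text = text
        self.edition_type = edition_type
        self.jinja2_text = jinja2_text


m = Msg(text="fallback prompt", edition_type="jinja2", jinja2_text=None)
if m.edition_type == "jinja2" and m.jinja2_text:
    m.text = m.jinja2_text
print(m.text)  # "fallback prompt"; before this fix, text would become None
```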
@@ -348,7 +399,7 @@ class LLMNode(BaseNode):
         return files
 
-    def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]:
+    def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]:
         """
         Fetch context
         :param node_data: node data
@@ -356,15 +407,18 @@ class LLMNode(BaseNode):
         :return:
         """
         if not node_data.context.enabled:
-            return None
+            return
 
         if not node_data.context.variable_selector:
-            return None
+            return
 
         context_value = variable_pool.get_any(node_data.context.variable_selector)
         if context_value:
             if isinstance(context_value, str):
-                return context_value
+                yield RunRetrieverResourceEvent(
+                    retriever_resources=[],
+                    context=context_value
+                )
             elif isinstance(context_value, list):
                 context_str = ''
                 original_retriever_resource = []
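
`_fetch_context` no longer returns an `Optional[str]`; it yields a `RunRetrieverResourceEvent` carrying both the retriever resources and the context string, and `_run` pulls the context back out of that event (the next hunk also removes the old callback-based `QueueRetrieverResourcesEvent` path this replaces). A rough sketch of collapsing a return-value-plus-callback contract into a single yielded event, with stand-in classes:

```python
# One yielded event replaces both the return value and the callback side
# channel. Field names mirror the diff; the classes are stand-ins.
from collections.abc import Generator
from dataclasses import dataclass, field


@dataclass
class RetrieverResourceEvent:
    context: str
    retriever_resources: list = field(default_factory=list)


def fetch_context(value: str | list | None) -> Generator[RetrieverResourceEvent, None, None]:
    if value is None:
        return  # nothing to yield: the caller's loop simply sees no events
    if isinstance(value, str):
        yield RetrieverResourceEvent(context=value)
    elif isinstance(value, list):
        chunks = [item["content"] for item in value]
        yield RetrieverResourceEvent(
            context="\n".join(chunks).strip(),
            retriever_resources=value,
        )


context = None
for event in fetch_context([{"content": "doc A"}, {"content": "doc B"}]):
    context = event.context  # the caller recovers the context from the event
print(context)
```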
@@ -381,17 +435,10 @@ class LLMNode(BaseNode):
                         if retriever_resource:
                             original_retriever_resource.append(retriever_resource)
 
-                if self.callbacks and original_retriever_resource:
-                    for callback in self.callbacks:
-                        callback.on_event(
-                            event=QueueRetrieverResourcesEvent(
-                                retriever_resources=original_retriever_resource
-                            )
-                        )
-
-                return context_str.strip()
-
-        return None
+                yield RunRetrieverResourceEvent(
+                    retriever_resources=original_retriever_resource,
+                    context=context_str.strip()
+                )
 
     def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
         """
@@ -574,7 +621,8 @@ class LLMNode(BaseNode):
             if not isinstance(prompt_message.content, str):
                 prompt_message_content = []
                 for content_item in prompt_message.content:
-                    if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(content_item, ImagePromptMessageContent):
+                    if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(
+                            content_item, ImagePromptMessageContent):
                         # Override vision config if LLM node has vision config
                         if vision_detail:
                             content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail)
@@ -646,13 +694,19 @@ class LLMNode(BaseNode):
             db.session.commit()
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+            cls,
+            graph_config: Mapping[str, Any],
+            node_id: str,
+            node_data: LLMNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
         :param node_data: node data
         :return:
         """
         prompt_template = node_data.prompt_template
 
         variable_selectors = []
@@ -702,6 +756,10 @@ class LLMNode(BaseNode):
             for variable_selector in node_data.prompt_config.jinja2_variables or []:
                 variable_mapping[variable_selector.variable] = variable_selector.value_selector
 
+        variable_mapping = {
+            node_id + '.' + key: value for key, value in variable_mapping.items()
+        }
+
         return variable_mapping
 
     @classmethod
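
`_extract_variable_selector_to_variable_mapping` now receives the whole `graph_config` and the `node_id`, and every key in the returned mapping is namespaced with that node id, which keeps mappings from different nodes distinct when they are merged across a graph with parallel branches. The prefixing step in isolation, with made-up sample data:

```python
# The node-id prefixing step on its own: values are variable selectors
# (lists of path segments); the sample keys and ids are hypothetical.
variable_mapping = {
    "#context#": ["sys", "query"],
    "greeting": ["start_node", "greeting"],
}
node_id = "llm_1"

variable_mapping = {
    node_id + "." + key: value for key, value in variable_mapping.items()
}
print(variable_mapping)
# {'llm_1.#context#': ['sys', 'query'], 'llm_1.greeting': ['start_node', 'greeting']}
```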