feat: add ops trace (#5483)

Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
Joe
2024-06-26 17:33:29 +08:00
committed by GitHub
parent 31a061ebaa
commit 4e2de638af
58 changed files with 3553 additions and 622 deletions

View File

@@ -44,6 +44,7 @@ from core.model_runtime.entities.message_entities import (
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from events.message_event import message_was_created
@@ -100,7 +101,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._conversation_name_generate_thread = None
def process(self) -> Union[
def process(
self,
) -> Union[
ChatbotAppBlockingResponse,
CompletionAppBlockingResponse,
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]
@@ -120,7 +123,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._application_generate_entity.query
)
generator = self._process_stream_response()
generator = self._process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
if self._stream:
return self._to_stream_response(generator)
else:
@@ -197,7 +202,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
stream_response=stream_response
)
def _process_stream_response(self) -> Generator[StreamResponse, None, None]:
def _process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
:return:
@@ -224,7 +231,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
# Save message
self._save_message()
self._save_message(trace_manager)
yield self._message_end_to_stream_response()
elif isinstance(event, QueueRetrieverResourcesEvent):
@@ -269,7 +276,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(self) -> None:
def _save_message(
self, trace_manager: Optional[TraceQueueManager] = None
) -> None:
"""
Save message.
:return:
@@ -300,6 +309,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
db.session.commit()
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.MESSAGE_TRACE,
conversation_id=self._conversation.id,
message_id=self._message.id
)
)
message_was_created.send(
self._message,
application_generate_entity=self._application_generate_entity,

View File

@@ -22,6 +22,7 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage
from core.file.file_obj import FileVar
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
@@ -94,11 +95,15 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
return workflow_run
def _workflow_run_success(self, workflow_run: WorkflowRun,
start_at: float,
total_tokens: int,
total_steps: int,
outputs: Optional[str] = None) -> WorkflowRun:
def _workflow_run_success(
self, workflow_run: WorkflowRun,
start_at: float,
total_tokens: int,
total_steps: int,
outputs: Optional[str] = None,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None
) -> WorkflowRun:
"""
Workflow run success
:param workflow_run: workflow run
@@ -106,6 +111,7 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
:param total_tokens: total tokens
:param total_steps: total steps
:param outputs: outputs
:param conversation_id: conversation id
:return:
"""
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
@@ -119,14 +125,27 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
db.session.refresh(workflow_run)
db.session.close()
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_run=workflow_run,
conversation_id=conversation_id,
)
)
return workflow_run
def _workflow_run_failed(self, workflow_run: WorkflowRun,
start_at: float,
total_tokens: int,
total_steps: int,
status: WorkflowRunStatus,
error: str) -> WorkflowRun:
def _workflow_run_failed(
self, workflow_run: WorkflowRun,
start_at: float,
total_tokens: int,
total_steps: int,
status: WorkflowRunStatus,
error: str,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None
) -> WorkflowRun:
"""
Workflow run failed
:param workflow_run: workflow run
@@ -148,6 +167,14 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
db.session.refresh(workflow_run)
db.session.close()
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_run=workflow_run,
conversation_id=conversation_id,
)
)
return workflow_run
def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun,
@@ -180,7 +207,8 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
title=node_title,
status=WorkflowNodeExecutionStatus.RUNNING.value,
created_by_role=workflow_run.created_by_role,
created_by=workflow_run.created_by
created_by=workflow_run.created_by,
created_at=datetime.now(timezone.utc).replace(tzinfo=None)
)
db.session.add(workflow_node_execution)
@@ -440,9 +468,9 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
current_node_execution = self._task_state.ran_node_execution_infos[event.node_id]
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None
if self._iteration_state and self._iteration_state.current_iterations:
if not execution_metadata:
execution_metadata = {}
@@ -470,7 +498,7 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
if execution_metadata and execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
self._task_state.total_tokens += (
int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
if self._iteration_state:
for iteration_node_id in self._iteration_state.current_iterations:
data = self._iteration_state.current_iterations[iteration_node_id]
@@ -496,13 +524,18 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
return workflow_node_execution
def _handle_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \
-> Optional[WorkflowRun]:
def _handle_workflow_finished(
self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None
) -> Optional[WorkflowRun]:
workflow_run = db.session.query(WorkflowRun).filter(
WorkflowRun.id == self._task_state.workflow_run_id).first()
if not workflow_run:
return None
if conversation_id is None:
conversation_id = self._application_generate_entity.inputs.get('sys.conversation_id')
if isinstance(event, QueueStopEvent):
workflow_run = self._workflow_run_failed(
workflow_run=workflow_run,
@@ -510,7 +543,9 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
status=WorkflowRunStatus.STOPPED,
error='Workflow stopped.'
error='Workflow stopped.',
conversation_id=conversation_id,
trace_manager=trace_manager
)
latest_node_execution_info = self._task_state.latest_node_execution_info
@@ -531,7 +566,9 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
status=WorkflowRunStatus.FAILED,
error=event.error
error=event.error,
conversation_id=conversation_id,
trace_manager=trace_manager
)
else:
if self._task_state.latest_node_execution_info:
@@ -546,7 +583,9 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
start_at=self._task_state.start_at,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
outputs=outputs
outputs=outputs,
conversation_id=conversation_id,
trace_manager=trace_manager
)
self._task_state.workflow_run_id = workflow_run.id

View File

@@ -1,6 +1,7 @@
import json
import time
from collections.abc import Generator
from datetime import datetime, timezone
from typing import Optional, Union
from core.app.entities.queue_entities import (
@@ -131,7 +132,8 @@ class WorkflowIterationCycleManage(WorkflowCycleStateManager):
'started_run_index': node_run_index + 1,
'current_index': 0,
'steps_boundary': [],
})
}),
created_at=datetime.now(timezone.utc).replace(tzinfo=None)
)
db.session.add(workflow_node_execution)