mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-24 10:13:01 +08:00
feat: mypy for all type check (#10921)
This commit is contained in:
@@ -49,7 +49,6 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
|
||||
reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run")
|
||||
input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run")
|
||||
output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
|
||||
dotted_order: Optional[str] = Field(None, description="Dotted order of the run")
|
||||
|
||||
@field_validator("inputs", "outputs")
|
||||
@classmethod
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, cast
|
||||
|
||||
from langsmith import Client
|
||||
from langsmith.schemas import RunBase
|
||||
@@ -63,6 +64,8 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
trace_id = trace_info.message_id or trace_info.workflow_run_id
|
||||
if trace_info.start_time is None:
|
||||
trace_info.start_time = datetime.now()
|
||||
message_dotted_order = (
|
||||
generate_dotted_order(trace_info.message_id, trace_info.start_time) if trace_info.message_id else None
|
||||
)
|
||||
@@ -78,8 +81,8 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
message_run = LangSmithRunModel(
|
||||
id=trace_info.message_id,
|
||||
name=TraceTaskName.MESSAGE_TRACE.value,
|
||||
inputs=trace_info.workflow_run_inputs,
|
||||
outputs=trace_info.workflow_run_outputs,
|
||||
inputs=dict(trace_info.workflow_run_inputs),
|
||||
outputs=dict(trace_info.workflow_run_outputs),
|
||||
run_type=LangSmithRunType.chain,
|
||||
start_time=trace_info.start_time,
|
||||
end_time=trace_info.end_time,
|
||||
@@ -90,6 +93,15 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
error=trace_info.error,
|
||||
trace_id=trace_id,
|
||||
dotted_order=message_dotted_order,
|
||||
file_list=[],
|
||||
serialized=None,
|
||||
parent_run_id=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
session_name=None,
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
)
|
||||
self.add_run(message_run)
|
||||
|
||||
@@ -98,11 +110,11 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
total_tokens=trace_info.total_tokens,
|
||||
id=trace_info.workflow_run_id,
|
||||
name=TraceTaskName.WORKFLOW_TRACE.value,
|
||||
inputs=trace_info.workflow_run_inputs,
|
||||
inputs=dict(trace_info.workflow_run_inputs),
|
||||
run_type=LangSmithRunType.tool,
|
||||
start_time=trace_info.workflow_data.created_at,
|
||||
end_time=trace_info.workflow_data.finished_at,
|
||||
outputs=trace_info.workflow_run_outputs,
|
||||
outputs=dict(trace_info.workflow_run_outputs),
|
||||
extra={
|
||||
"metadata": metadata,
|
||||
},
|
||||
@@ -111,6 +123,13 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
parent_run_id=trace_info.message_id or None,
|
||||
trace_id=trace_id,
|
||||
dotted_order=workflow_dotted_order,
|
||||
serialized=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
session_name=None,
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
)
|
||||
|
||||
self.add_run(langsmith_run)
|
||||
@@ -211,25 +230,35 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
id=node_execution_id,
|
||||
trace_id=trace_id,
|
||||
dotted_order=node_dotted_order,
|
||||
error="",
|
||||
serialized=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
session_name=None,
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
)
|
||||
|
||||
self.add_run(langsmith_run)
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
# get message file data
|
||||
file_list = trace_info.file_list
|
||||
message_file_data: MessageFile = trace_info.message_file_data
|
||||
file_list = cast(list[str], trace_info.file_list) or []
|
||||
message_file_data: Optional[MessageFile] = trace_info.message_file_data
|
||||
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
||||
file_list.append(file_url)
|
||||
metadata = trace_info.metadata
|
||||
message_data = trace_info.message_data
|
||||
if message_data is None:
|
||||
return
|
||||
message_id = message_data.id
|
||||
|
||||
user_id = message_data.from_account_id
|
||||
metadata["user_id"] = user_id
|
||||
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: EndUser = (
|
||||
end_user_data: Optional[EndUser] = (
|
||||
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
@@ -247,12 +276,20 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
start_time=trace_info.start_time,
|
||||
end_time=trace_info.end_time,
|
||||
outputs=message_data.answer,
|
||||
extra={
|
||||
"metadata": metadata,
|
||||
},
|
||||
extra={"metadata": metadata},
|
||||
tags=["message", str(trace_info.conversation_mode)],
|
||||
error=trace_info.error,
|
||||
file_list=file_list,
|
||||
serialized=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
session_name=None,
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
trace_id=None,
|
||||
dotted_order=None,
|
||||
parent_run_id=None,
|
||||
)
|
||||
self.add_run(message_run)
|
||||
|
||||
@@ -267,17 +304,27 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
start_time=trace_info.start_time,
|
||||
end_time=trace_info.end_time,
|
||||
outputs=message_data.answer,
|
||||
extra={
|
||||
"metadata": metadata,
|
||||
},
|
||||
extra={"metadata": metadata},
|
||||
parent_run_id=message_id,
|
||||
tags=["llm", str(trace_info.conversation_mode)],
|
||||
error=trace_info.error,
|
||||
file_list=file_list,
|
||||
serialized=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
session_name=None,
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
trace_id=None,
|
||||
dotted_order=None,
|
||||
id=str(uuid.uuid4()),
|
||||
)
|
||||
self.add_run(llm_run)
|
||||
|
||||
def moderation_trace(self, trace_info: ModerationTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
langsmith_run = LangSmithRunModel(
|
||||
name=TraceTaskName.MODERATION_TRACE.value,
|
||||
inputs=trace_info.inputs,
|
||||
@@ -288,48 +335,82 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
"inputs": trace_info.inputs,
|
||||
},
|
||||
run_type=LangSmithRunType.tool,
|
||||
extra={
|
||||
"metadata": trace_info.metadata,
|
||||
},
|
||||
extra={"metadata": trace_info.metadata},
|
||||
tags=["moderation"],
|
||||
parent_run_id=trace_info.message_id,
|
||||
start_time=trace_info.start_time or trace_info.message_data.created_at,
|
||||
end_time=trace_info.end_time or trace_info.message_data.updated_at,
|
||||
id=str(uuid.uuid4()),
|
||||
serialized=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
session_name=None,
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
trace_id=None,
|
||||
dotted_order=None,
|
||||
error="",
|
||||
file_list=[],
|
||||
)
|
||||
|
||||
self.add_run(langsmith_run)
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||
message_data = trace_info.message_data
|
||||
if message_data is None:
|
||||
return
|
||||
suggested_question_run = LangSmithRunModel(
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.suggested_question,
|
||||
run_type=LangSmithRunType.tool,
|
||||
extra={
|
||||
"metadata": trace_info.metadata,
|
||||
},
|
||||
extra={"metadata": trace_info.metadata},
|
||||
tags=["suggested_question"],
|
||||
parent_run_id=trace_info.message_id,
|
||||
start_time=trace_info.start_time or message_data.created_at,
|
||||
end_time=trace_info.end_time or message_data.updated_at,
|
||||
id=str(uuid.uuid4()),
|
||||
serialized=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
session_name=None,
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
trace_id=None,
|
||||
dotted_order=None,
|
||||
error="",
|
||||
file_list=[],
|
||||
)
|
||||
|
||||
self.add_run(suggested_question_run)
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
dataset_retrieval_run = LangSmithRunModel(
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
|
||||
inputs=trace_info.inputs,
|
||||
outputs={"documents": trace_info.documents},
|
||||
run_type=LangSmithRunType.retriever,
|
||||
extra={
|
||||
"metadata": trace_info.metadata,
|
||||
},
|
||||
extra={"metadata": trace_info.metadata},
|
||||
tags=["dataset_retrieval"],
|
||||
parent_run_id=trace_info.message_id,
|
||||
start_time=trace_info.start_time or trace_info.message_data.created_at,
|
||||
end_time=trace_info.end_time or trace_info.message_data.updated_at,
|
||||
id=str(uuid.uuid4()),
|
||||
serialized=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
session_name=None,
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
trace_id=None,
|
||||
dotted_order=None,
|
||||
error="",
|
||||
file_list=[],
|
||||
)
|
||||
|
||||
self.add_run(dataset_retrieval_run)
|
||||
@@ -347,7 +428,18 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
parent_run_id=trace_info.message_id,
|
||||
start_time=trace_info.start_time,
|
||||
end_time=trace_info.end_time,
|
||||
file_list=[trace_info.file_url],
|
||||
file_list=[cast(str, trace_info.file_url)],
|
||||
id=str(uuid.uuid4()),
|
||||
serialized=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
session_name=None,
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
trace_id=None,
|
||||
dotted_order=None,
|
||||
error=trace_info.error or "",
|
||||
)
|
||||
|
||||
self.add_run(tool_run)
|
||||
@@ -358,12 +450,23 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.outputs,
|
||||
run_type=LangSmithRunType.tool,
|
||||
extra={
|
||||
"metadata": trace_info.metadata,
|
||||
},
|
||||
extra={"metadata": trace_info.metadata},
|
||||
tags=["generate_name"],
|
||||
start_time=trace_info.start_time or datetime.now(),
|
||||
end_time=trace_info.end_time or datetime.now(),
|
||||
id=str(uuid.uuid4()),
|
||||
serialized=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
session_name=None,
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
trace_id=None,
|
||||
dotted_order=None,
|
||||
error="",
|
||||
file_list=[],
|
||||
parent_run_id=None,
|
||||
)
|
||||
|
||||
self.add_run(name_run)
|
||||
|
||||
Reference in New Issue
Block a user