mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-11 03:46:52 +08:00
feat: mypy for all type check (#10921)
This commit is contained in:
@@ -33,11 +33,11 @@ from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
|
||||
from core.ops.utils import get_message_data
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import App, AppModelConfig, Conversation, Message, MessageAgentThought, MessageFile, TraceAppConfig
|
||||
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
|
||||
from models.workflow import WorkflowAppLog, WorkflowRun
|
||||
from tasks.ops_trace_task import process_trace_tasks
|
||||
|
||||
provider_config_map = {
|
||||
provider_config_map: dict[str, dict[str, Any]] = {
|
||||
TracingProviderEnum.LANGFUSE.value: {
|
||||
"config_class": LangfuseConfig,
|
||||
"secret_keys": ["public_key", "secret_key"],
|
||||
@@ -145,7 +145,7 @@ class OpsTraceManager:
|
||||
:param tracing_provider: tracing provider
|
||||
:return:
|
||||
"""
|
||||
trace_config_data: TraceAppConfig = (
|
||||
trace_config_data: Optional[TraceAppConfig] = (
|
||||
db.session.query(TraceAppConfig)
|
||||
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
||||
.first()
|
||||
@@ -155,7 +155,11 @@ class OpsTraceManager:
|
||||
return None
|
||||
|
||||
# decrypt_token
|
||||
tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
|
||||
app = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError("App not found")
|
||||
|
||||
tenant_id = app.tenant_id
|
||||
decrypt_tracing_config = cls.decrypt_tracing_config(
|
||||
tenant_id, tracing_provider, trace_config_data.tracing_config
|
||||
)
|
||||
@@ -178,7 +182,7 @@ class OpsTraceManager:
|
||||
if app_id is None:
|
||||
return None
|
||||
|
||||
app: App = db.session.query(App).filter(App.id == app_id).first()
|
||||
app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
|
||||
|
||||
if app is None:
|
||||
return None
|
||||
@@ -209,8 +213,12 @@ class OpsTraceManager:
|
||||
def get_app_config_through_message_id(cls, message_id: str):
|
||||
app_model_config = None
|
||||
message_data = db.session.query(Message).filter(Message.id == message_id).first()
|
||||
if not message_data:
|
||||
return None
|
||||
conversation_id = message_data.conversation_id
|
||||
conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
|
||||
if not conversation_data:
|
||||
return None
|
||||
|
||||
if conversation_data.app_model_config_id:
|
||||
app_model_config = (
|
||||
@@ -236,7 +244,9 @@ class OpsTraceManager:
|
||||
if tracing_provider not in provider_config_map and tracing_provider is not None:
|
||||
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
|
||||
|
||||
app_config: App = db.session.query(App).filter(App.id == app_id).first()
|
||||
app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app_config:
|
||||
raise ValueError("App not found")
|
||||
app_config.tracing = json.dumps(
|
||||
{
|
||||
"enabled": enabled,
|
||||
@@ -252,7 +262,9 @@ class OpsTraceManager:
|
||||
:param app_id: app id
|
||||
:return:
|
||||
"""
|
||||
app: App = db.session.query(App).filter(App.id == app_id).first()
|
||||
app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError("App not found")
|
||||
if not app.tracing:
|
||||
return {"enabled": False, "tracing_provider": None}
|
||||
app_trace_config = json.loads(app.tracing)
|
||||
@@ -483,6 +495,8 @@ class TraceTask:
|
||||
|
||||
def moderation_trace(self, message_id, timer, **kwargs):
|
||||
moderation_result = kwargs.get("moderation_result")
|
||||
if not moderation_result:
|
||||
return {}
|
||||
inputs = kwargs.get("inputs")
|
||||
message_data = get_message_data(message_id)
|
||||
if not message_data:
|
||||
@@ -518,7 +532,7 @@ class TraceTask:
|
||||
return moderation_trace_info
|
||||
|
||||
def suggested_question_trace(self, message_id, timer, **kwargs):
|
||||
suggested_question = kwargs.get("suggested_question")
|
||||
suggested_question = kwargs.get("suggested_question", [])
|
||||
message_data = get_message_data(message_id)
|
||||
if not message_data:
|
||||
return {}
|
||||
@@ -586,7 +600,7 @@ class TraceTask:
|
||||
dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
|
||||
message_id=message_id,
|
||||
inputs=message_data.query or message_data.inputs,
|
||||
documents=[doc.model_dump() for doc in documents],
|
||||
documents=[doc.model_dump() for doc in documents] if documents else [],
|
||||
start_time=timer.get("start"),
|
||||
end_time=timer.get("end"),
|
||||
metadata=metadata,
|
||||
@@ -596,9 +610,9 @@ class TraceTask:
|
||||
return dataset_retrieval_trace_info
|
||||
|
||||
def tool_trace(self, message_id, timer, **kwargs):
|
||||
tool_name = kwargs.get("tool_name")
|
||||
tool_inputs = kwargs.get("tool_inputs")
|
||||
tool_outputs = kwargs.get("tool_outputs")
|
||||
tool_name = kwargs.get("tool_name", "")
|
||||
tool_inputs = kwargs.get("tool_inputs", {})
|
||||
tool_outputs = kwargs.get("tool_outputs", {})
|
||||
message_data = get_message_data(message_id)
|
||||
if not message_data:
|
||||
return {}
|
||||
@@ -608,7 +622,7 @@ class TraceTask:
|
||||
tool_parameters = {}
|
||||
created_time = message_data.created_at
|
||||
end_time = message_data.updated_at
|
||||
agent_thoughts: list[MessageAgentThought] = message_data.agent_thoughts
|
||||
agent_thoughts = message_data.agent_thoughts
|
||||
for agent_thought in agent_thoughts:
|
||||
if tool_name in agent_thought.tools:
|
||||
created_time = agent_thought.created_at
|
||||
@@ -672,6 +686,8 @@ class TraceTask:
|
||||
generate_conversation_name = kwargs.get("generate_conversation_name")
|
||||
inputs = kwargs.get("inputs")
|
||||
tenant_id = kwargs.get("tenant_id")
|
||||
if not tenant_id:
|
||||
return {}
|
||||
start_time = timer.get("start")
|
||||
end_time = timer.get("end")
|
||||
|
||||
@@ -693,8 +709,8 @@ class TraceTask:
|
||||
return generate_name_trace_info
|
||||
|
||||
|
||||
trace_manager_timer = None
|
||||
trace_manager_queue = queue.Queue()
|
||||
trace_manager_timer: Optional[threading.Timer] = None
|
||||
trace_manager_queue: queue.Queue = queue.Queue()
|
||||
trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5))
|
||||
trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))
|
||||
|
||||
@@ -706,7 +722,7 @@ class TraceQueueManager:
|
||||
self.app_id = app_id
|
||||
self.user_id = user_id
|
||||
self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
|
||||
self.flask_app = current_app._get_current_object()
|
||||
self.flask_app = current_app._get_current_object() # type: ignore
|
||||
if trace_manager_timer is None:
|
||||
self.start_timer()
|
||||
|
||||
@@ -723,7 +739,7 @@ class TraceQueueManager:
|
||||
|
||||
def collect_tasks(self):
|
||||
global trace_manager_queue
|
||||
tasks = []
|
||||
tasks: list[TraceTask] = []
|
||||
while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty():
|
||||
task = trace_manager_queue.get_nowait()
|
||||
tasks.append(task)
|
||||
@@ -749,6 +765,8 @@ class TraceQueueManager:
|
||||
def send_to_celery(self, tasks: list[TraceTask]):
|
||||
with self.flask_app.app_context():
|
||||
for task in tasks:
|
||||
if task.app_id is None:
|
||||
continue
|
||||
file_id = uuid4().hex
|
||||
trace_info = task.execute()
|
||||
task_data = TaskData(
|
||||
|
||||
Reference in New Issue
Block a user