chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -36,17 +36,17 @@ from tasks.ops_trace_task import process_trace_tasks
provider_config_map = {
TracingProviderEnum.LANGFUSE.value: {
'config_class': LangfuseConfig,
'secret_keys': ['public_key', 'secret_key'],
'other_keys': ['host', 'project_key'],
'trace_instance': LangFuseDataTrace
"config_class": LangfuseConfig,
"secret_keys": ["public_key", "secret_key"],
"other_keys": ["host", "project_key"],
"trace_instance": LangFuseDataTrace,
},
TracingProviderEnum.LANGSMITH.value: {
'config_class': LangSmithConfig,
'secret_keys': ['api_key'],
'other_keys': ['project', 'endpoint'],
'trace_instance': LangSmithDataTrace
}
"config_class": LangSmithConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "endpoint"],
"trace_instance": LangSmithDataTrace,
},
}
@@ -64,14 +64,17 @@ class OpsTraceManager:
:return: encrypted tracing configuration
"""
# Get the configuration class and the keys that require encryption
config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \
provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys']
config_class, secret_keys, other_keys = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["secret_keys"],
provider_config_map[tracing_provider]["other_keys"],
)
new_config = {}
# Encrypt necessary keys
for key in secret_keys:
if key in tracing_config:
if '*' in tracing_config[key]:
if "*" in tracing_config[key]:
# If the key contains '*', retain the original value from the current config
new_config[key] = current_trace_config.get(key, tracing_config[key])
else:
@@ -94,8 +97,11 @@ class OpsTraceManager:
:param tracing_config: tracing config
:return:
"""
config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \
provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys']
config_class, secret_keys, other_keys = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["secret_keys"],
provider_config_map[tracing_provider]["other_keys"],
)
new_config = {}
for key in secret_keys:
if key in tracing_config:
@@ -114,8 +120,11 @@ class OpsTraceManager:
:param decrypt_tracing_config: tracing config
:return:
"""
config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \
provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys']
config_class, secret_keys, other_keys = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["secret_keys"],
provider_config_map[tracing_provider]["other_keys"],
)
new_config = {}
for key in secret_keys:
if key in decrypt_tracing_config:
@@ -133,9 +142,11 @@ class OpsTraceManager:
:param tracing_provider: tracing provider
:return:
"""
trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter(
TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider
).first()
trace_config_data: TraceAppConfig = (
db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
)
if not trace_config_data:
return None
@@ -164,21 +175,21 @@ class OpsTraceManager:
if app_id is None:
return None
app: App = db.session.query(App).filter(
App.id == app_id
).first()
app: App = db.session.query(App).filter(App.id == app_id).first()
app_ops_trace_config = json.loads(app.tracing) if app.tracing else None
if app_ops_trace_config is not None:
tracing_provider = app_ops_trace_config.get('tracing_provider')
tracing_provider = app_ops_trace_config.get("tracing_provider")
else:
return None
# decrypt_token
decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider)
if app_ops_trace_config.get('enabled'):
trace_instance, config_class = provider_config_map[tracing_provider]['trace_instance'], \
provider_config_map[tracing_provider]['config_class']
if app_ops_trace_config.get("enabled"):
trace_instance, config_class = (
provider_config_map[tracing_provider]["trace_instance"],
provider_config_map[tracing_provider]["config_class"],
)
tracing_instance = trace_instance(config_class(**decrypt_trace_config))
return tracing_instance
@@ -192,9 +203,11 @@ class OpsTraceManager:
conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
if conversation_data.app_model_config_id:
app_model_config = db.session.query(AppModelConfig).filter(
AppModelConfig.id == conversation_data.app_model_config_id
).first()
app_model_config = (
db.session.query(AppModelConfig)
.filter(AppModelConfig.id == conversation_data.app_model_config_id)
.first()
)
elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
app_model_config = conversation_data.override_model_configs
@@ -231,10 +244,7 @@ class OpsTraceManager:
"""
app: App = db.session.query(App).filter(App.id == app_id).first()
if not app.tracing:
return {
"enabled": False,
"tracing_provider": None
}
return {"enabled": False, "tracing_provider": None}
app_trace_config = json.loads(app.tracing)
return app_trace_config
@@ -246,8 +256,10 @@ class OpsTraceManager:
:param tracing_provider: tracing provider
:return:
"""
config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \
provider_config_map[tracing_provider]['trace_instance']
config_type, trace_instance = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["trace_instance"],
)
tracing_config = config_type(**tracing_config)
return trace_instance(tracing_config).api_check()
@@ -259,8 +271,10 @@ class OpsTraceManager:
:param tracing_provider: tracing provider
:return:
"""
config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \
provider_config_map[tracing_provider]['trace_instance']
config_type, trace_instance = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["trace_instance"],
)
tracing_config = config_type(**tracing_config)
return trace_instance(tracing_config).get_project_key()
@@ -272,8 +286,10 @@ class OpsTraceManager:
:param tracing_provider: tracing provider
:return:
"""
config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \
provider_config_map[tracing_provider]['trace_instance']
config_type, trace_instance = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["trace_instance"],
)
tracing_config = config_type(**tracing_config)
return trace_instance(tracing_config).get_project_url()
@@ -287,7 +303,7 @@ class TraceTask:
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
timer: Optional[Any] = None,
**kwargs
**kwargs,
):
self.trace_type = trace_type
self.message_id = message_id
@@ -310,9 +326,7 @@ class TraceTask:
self.workflow_run, self.conversation_id, self.user_id
),
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id),
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
self.message_id, self.timer, **self.kwargs
),
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(self.message_id, self.timer, **self.kwargs),
TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace(
self.message_id, self.timer, **self.kwargs
),
@@ -337,12 +351,8 @@ class TraceTask:
workflow_run_id = workflow_run.id
workflow_run_elapsed_time = workflow_run.elapsed_time
workflow_run_status = workflow_run.status
workflow_run_inputs = (
json.loads(workflow_run.inputs) if workflow_run.inputs else {}
)
workflow_run_outputs = (
json.loads(workflow_run.outputs) if workflow_run.outputs else {}
)
workflow_run_inputs = json.loads(workflow_run.inputs) if workflow_run.inputs else {}
workflow_run_outputs = json.loads(workflow_run.outputs) if workflow_run.outputs else {}
workflow_run_version = workflow_run.version
error = workflow_run.error if workflow_run.error else ""
@@ -352,17 +362,18 @@ class TraceTask:
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
# get workflow_app_log_id
workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by(
tenant_id=tenant_id,
app_id=workflow_run.app_id,
workflow_run_id=workflow_run.id
).first()
workflow_app_log_data = (
db.session.query(WorkflowAppLog)
.filter_by(tenant_id=tenant_id, app_id=workflow_run.app_id, workflow_run_id=workflow_run.id)
.first()
)
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
# get message_id
message_data = db.session.query(Message.id).filter_by(
conversation_id=conversation_id,
workflow_run_id=workflow_run_id
).first()
message_data = (
db.session.query(Message.id)
.filter_by(conversation_id=conversation_id, workflow_run_id=workflow_run_id)
.first()
)
message_id = str(message_data.id) if message_data else None
metadata = {
@@ -470,9 +481,9 @@ class TraceTask:
# get workflow_app_log_id
workflow_app_log_id = None
if message_data.workflow_run_id:
workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by(
workflow_run_id=message_data.workflow_run_id
).first()
workflow_app_log_data = (
db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
)
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
moderation_trace_info = ModerationTraceInfo(
@@ -510,9 +521,9 @@ class TraceTask:
# get workflow_app_log_id
workflow_app_log_id = None
if message_data.workflow_run_id:
workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by(
workflow_run_id=message_data.workflow_run_id
).first()
workflow_app_log_data = (
db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
)
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
suggested_question_trace_info = SuggestedQuestionTraceInfo(
@@ -569,9 +580,9 @@ class TraceTask:
return dataset_retrieval_trace_info
def tool_trace(self, message_id, timer, **kwargs):
tool_name = kwargs.get('tool_name')
tool_inputs = kwargs.get('tool_inputs')
tool_outputs = kwargs.get('tool_outputs')
tool_name = kwargs.get("tool_name")
tool_inputs = kwargs.get("tool_inputs")
tool_outputs = kwargs.get("tool_outputs")
message_data = get_message_data(message_id)
if not message_data:
return {}
@@ -586,11 +597,11 @@ class TraceTask:
if tool_name in agent_thought.tools:
created_time = agent_thought.created_at
tool_meta_data = agent_thought.tool_meta.get(tool_name, {})
tool_config = tool_meta_data.get('tool_config', {})
time_cost = tool_meta_data.get('time_cost', 0)
tool_config = tool_meta_data.get("tool_config", {})
time_cost = tool_meta_data.get("time_cost", 0)
end_time = created_time + timedelta(seconds=time_cost)
error = tool_meta_data.get('error', "")
tool_parameters = tool_meta_data.get('tool_parameters', {})
error = tool_meta_data.get("error", "")
tool_parameters = tool_meta_data.get("tool_parameters", {})
metadata = {
"message_id": message_id,
"tool_name": tool_name,
@@ -715,9 +726,7 @@ class TraceQueueManager:
def start_timer(self):
global trace_manager_timer
if trace_manager_timer is None or not trace_manager_timer.is_alive():
trace_manager_timer = threading.Timer(
trace_manager_interval, self.run
)
trace_manager_timer = threading.Timer(trace_manager_interval, self.run)
trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}"
trace_manager_timer.daemon = False
trace_manager_timer.start()