chart_generator.py 101 KB


  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. ##################################################################
  4. #
  5. # Copyright (c) 2023 CICV, Inc. All Rights Reserved
  6. #
  7. ##################################################################
  8. """
  9. @Authors: zhanghaiwen(zhanghaiwen@china-icv.cn)
  10. @Date: 2023/06/25
  11. @Last Modified: 2025/05/20
  12. @Summary: Chart generation utilities for metrics visualization
  13. """
  14. """
  15. 主要功能模块:
  16. 函数指标绘图(generate_function_chart_data)
  17. 舒适性指标绘图(generate_comfort_chart_data)
  18. 安全性指标绘图(generate_safety_chart_data)
  19. 交通指标绘图(generate_traffic_chart_data,待实现)
  20. 新增的急加速指标绘图:
  21. 在generate_comfort_chart_data中增加了slamAccelerate指标的支持
  22. 实现了完整的generate_slam_accelerate_chart函数,用于绘制急加速事件
  23. 该函数包含:
  24. 数据获取与预处理
  25. CSV数据保存与读取
  26. 纵向加速度与速度的双子图绘制
  27. 急加速事件的橙色背景标记
  28. 阈值线绘制
  29. 高质量图表输出(300dpi PNG)
  30. 其他关键特性:
  31. 统一的日志记录系统
  32. 阈值获取函数get_metric_thresholds
  33. 错误处理和异常捕获
  34. 时间戳管理
  35. 数据验证机制
  36. 详细的日志输出
  37. 辅助函数:
  38. calculate_distance 和 calculate_relative_speed(简化实现)
  39. scenario_sign_dict 场景签名字典(简化实现)
  40. 此代码实现了完整的指标可视化工具,特别针对急加速指标slamAccelerate提供了详细的绘图功能,能够清晰展示急加速事件的发生时间和相关数据变化。
  41. """
  42. import matplotlib
  43. matplotlib.use('Agg') # 使用非图形界面的后端
  44. import matplotlib.pyplot as plt
  45. import os
  46. import numpy as np
  47. import pandas as pd
  48. from typing import Optional, Dict, List, Any, Union
  49. from pathlib import Path
  50. from modules.lib.log_manager import LogManager
  def generate_function_chart_data(function_calculator, metric_name: str,
                                   output_dir: Optional[str] = None) -> Optional[str]:
      """
      Generate the chart for a function metric, dispatching on its name.

      The metric name is matched case-insensitively against the function
      metrics that have a dedicated chart generator in this module.

      Args:
          function_calculator: FunctionCalculator instance, passed through to
              the per-metric chart generator.
          metric_name: Metric name, e.g. ``latestWarningDistance_TTC_LST``
              (case-insensitive).
          output_dir: Output directory. Falls back to ``<cwd>/data`` when empty
              or None; the directory is created if it does not exist.

      Returns:
          str: Chart file path, or None if the metric has no chart generator
          or generation fails.
      """
      logger = LogManager().get_logger()
      try:
          # Resolve the fallback first so the directory is created in both
          # cases; previously the <cwd>/data fallback was never created, so
          # every CSV/PNG write into it failed.
          if not output_dir:
              output_dir = os.path.join(os.getcwd(), 'data')
          os.makedirs(output_dir, exist_ok=True)

          # Lower-cased metric name -> chart generator. Built at call time
          # because the generators are defined after this function.
          chart_generators = {
              'latestwarningdistance_ttc_lst': generate_latest_warning_ttc_chart,
              'earliestwarningdistance_ttc_lst': generate_earliest_warning_distance_ttc_chart,
              'earliestwarningdistance_lst': generate_earliest_warning_distance_chart,
              'latestwarningdistance_lst': generate_latest_warning_distance_chart,
              'latestwarningdistance_ttc_pgvil': generate_latest_warning_ttc_pgvil_chart,
              'earliestwarningdistance_ttc_pgvil': generate_earliest_warning_distance_ttc_pgvil_chart,
              'earliestwarningdistance_pgvil': generate_earliest_warning_distance_pgvil_chart,
              'latestwarningdistance_pgvil': generate_latest_warning_distance_pgvil_chart,
              'limitspeed_lst': generate_limit_speed_chart,
              'limitspeedpastlimitsign_lst': generate_limit_speed_past_sign_chart,
              'maxlongitudedist_lst': generate_max_longitude_dist_chart,
          }
          generator = chart_generators.get(metric_name.lower())
          if generator is None:
              logger.warning(f"Chart generation not implemented for metric [{metric_name}]")
              return None
          return generator(function_calculator, output_dir)
      except Exception as e:
          logger.error(f"Failed to generate chart data: {str(e)}", exc_info=True)
          return None
  def generate_earliest_warning_distance_chart(function_calculator, output_dir: str) -> Optional[str]:
      """
      Generate the earliestWarningDistance_LST chart.

      The warning-distance series is first saved, together with the configured
      min/max thresholds, to ``earliestWarningDistance_LST_data.csv``. The CSV
      is then read back and plotted over time with the threshold lines, and
      the first sample is highlighted as the metric value.

      Args:
          function_calculator: FunctionCalculator instance. Its ``ego_data``,
              ``correctwarning`` and ``warning_dist`` attributes are read.
          output_dir: Output directory for the CSV and PNG files.

      Returns:
          str: Chart file path, or None if there is no data or generation fails.
      """
      logger = LogManager().get_logger()
      metric_name = 'earliestWarningDistance_LST'
      fig = None
      try:
          # Read-only access, so no defensive copy is needed.
          ego_df = function_calculator.ego_data
          correctwarning = getattr(function_calculator, 'correctwarning', None)

          thresholds = get_metric_thresholds(function_calculator, metric_name)
          max_threshold = thresholds["max"]
          min_threshold = thresholds["min"]

          # Guard against a missing attribute as well as an empty series.
          warning_dist = getattr(function_calculator, 'warning_dist', None)
          if warning_dist is None or len(warning_dist) == 0:
              logger.warning(f"Cannot generate {metric_name} chart: empty data")
              return None

          # Earliest warning distance = first sample (non-empty checked above).
          metric_value = float(warning_dist.iloc[0])

          # Timestamps of the frames where the expected warning was raised.
          # NOTE(review): DataFrame columns are aligned on the pandas index, so
          # warning_dist is assumed to share ego_data's index -- confirm.
          warning_mask = (ego_df['ifwarning'] == correctwarning) & ego_df['ifwarning'].notna()
          csv_filename = os.path.join(output_dir, f"{metric_name}_data.csv")
          df_csv = pd.DataFrame({
              'simTime': ego_df.loc[warning_mask, 'simTime'],
              'warning_distance': warning_dist,
              'min_threshold': min_threshold,
              'max_threshold': max_threshold,
          })
          df_csv.to_csv(csv_filename, index=False)
          logger.info(f"{metric_name} data saved to: {csv_filename}")

          # Plot from the persisted CSV so the chart matches the saved data.
          df = pd.read_csv(csv_filename)

          fig = plt.figure(figsize=(12, 6), constrained_layout=True)
          plt.plot(df['simTime'], df['warning_distance'], 'b-', label='Warning Distance')
          plt.axhline(y=max_threshold, color='r', linestyle='--', label=f'Max Threshold ({max_threshold}m)')
          plt.axhline(y=min_threshold, color='g', linestyle='--', label=f'Min Threshold ({min_threshold}m)')

          if len(df) > 0:
              plt.scatter(df['simTime'].iloc[0], df['warning_distance'].iloc[0],
                          color='red', s=100, zorder=5,
                          label=f'Earliest Warning Distance: {metric_value:.2f}m')

          # 10% headroom above the larger of the max threshold and the data.
          plt.ylim(bottom=-1, top=max(max_threshold * 1.1, df['warning_distance'].max() * 1.1))
          plt.xlabel('Time (s)')
          plt.ylabel('Distance (m)')
          plt.title(f'{metric_name} - Warning Distance Over Time')
          plt.grid(True)
          plt.legend()

          chart_filename = os.path.join(output_dir, f"{metric_name}_chart.png")
          fig.savefig(chart_filename, dpi=300)
          logger.info(f"{metric_name} chart saved to: {chart_filename}")
          return chart_filename
      except Exception as e:
          logger.error(f"Failed to generate {metric_name} chart: {str(e)}", exc_info=True)
          return None
      finally:
          # Close the figure on every path so failed runs do not leak figures.
          if fig is not None:
              plt.close(fig)
  def generate_earliest_warning_distance_pgvil_chart(function_calculator, output_dir: str) -> Optional[str]:
      """
      Generate the earliestWarningDistance_PGVIL chart.

      The warning-distance series and its timestamps are first saved, together
      with the configured min/max thresholds, to
      ``earliestWarningDistance_PGVIL_data.csv``. The CSV is then read back and
      plotted over time with the threshold lines, and the first sample is
      highlighted as the metric value.

      Args:
          function_calculator: FunctionCalculator instance. Its ``warning_dist``
              and ``warning_time`` attributes are read (presumably equal-length
              sequences -- confirm).
          output_dir: Output directory for the CSV and PNG files.

      Returns:
          str: Chart file path, or None if there is no data or generation fails.
      """
      logger = LogManager().get_logger()
      metric_name = 'earliestWarningDistance_PGVIL'
      fig = None
      try:
          thresholds = get_metric_thresholds(function_calculator, metric_name)
          max_threshold = thresholds["max"]
          min_threshold = thresholds["min"]

          warning_dist = getattr(function_calculator, 'warning_dist', None)
          warning_time = getattr(function_calculator, 'warning_time', None)
          # Guard against a missing attribute as well as empty data.
          if warning_dist is None or len(warning_dist) == 0:
              logger.warning(f"Cannot generate {metric_name} chart: empty data")
              return None

          # Earliest warning distance = first sample. np.asarray makes the access
          # positional for lists, arrays and Series alike.
          metric_value = float(np.asarray(warning_dist)[0])

          csv_filename = os.path.join(output_dir, f"{metric_name}_data.csv")
          df_csv = pd.DataFrame({
              'simTime': warning_time,
              'warning_distance': warning_dist,
              'min_threshold': min_threshold,
              'max_threshold': max_threshold,
          })
          df_csv.to_csv(csv_filename, index=False)
          logger.info(f"{metric_name} data saved to: {csv_filename}")

          # Plot from the persisted CSV so the chart matches the saved data.
          df = pd.read_csv(csv_filename)

          fig = plt.figure(figsize=(12, 6), constrained_layout=True)
          plt.plot(df['simTime'], df['warning_distance'], 'b-', label='Warning Distance')
          plt.axhline(y=max_threshold, color='r', linestyle='--', label=f'Max Threshold ({max_threshold}m)')
          plt.axhline(y=min_threshold, color='g', linestyle='--', label=f'Min Threshold ({min_threshold}m)')

          if len(df) > 0:
              plt.scatter(df['simTime'].iloc[0], df['warning_distance'].iloc[0],
                          color='red', s=100, zorder=5,
                          label=f'Earliest Warning Distance: {metric_value:.2f}m')

          # 10% headroom above the larger of the max threshold and the data.
          plt.ylim(bottom=-1, top=max(max_threshold * 1.1, df['warning_distance'].max() * 1.1))
          plt.xlabel('Time (s)')
          plt.ylabel('Distance (m)')
          plt.title(f'{metric_name} - Warning Distance Over Time')
          plt.grid(True)
          plt.legend()

          chart_filename = os.path.join(output_dir, f"{metric_name}_chart.png")
          fig.savefig(chart_filename, dpi=300)
          logger.info(f"{metric_name} chart saved to: {chart_filename}")
          return chart_filename
      except Exception as e:
          logger.error(f"Failed to generate {metric_name} chart: {str(e)}", exc_info=True)
          return None
      finally:
          # Close the figure on every path so failed runs do not leak figures.
          if fig is not None:
              plt.close(fig)
  235. # # 使用function.py中已实现的find_nested_name函数
  236. # from modules.metric.function import find_nested_name
  def generate_latest_warning_ttc_pgvil_chart(function_calculator, output_dir: str) -> Optional[str]:
      """
      Generate the latestWarningDistance_TTC_PGVIL chart.

      Warning distance, relative speed and TTC are first saved, together with
      the configured min/max thresholds, to
      ``latestwarningdistance_ttc_pgvil_data.csv``. The CSV is then read back
      and drawn as three stacked subplots: warning distance, relative speed,
      and TTC with threshold lines. The last TTC sample is highlighted as the
      metric value.

      Args:
          function_calculator: FunctionCalculator instance. Its
              ``warning_dist``, ``warning_speed``, ``warning_time`` and ``ttc``
              attributes are read.
          output_dir: Output directory for the CSV and PNG files.

      Returns:
          str: Chart file path, or None if there is no data or generation fails.
      """
      logger = LogManager().get_logger()
      metric_name = 'latestWarningDistance_TTC_PGVIL'
      file_prefix = metric_name.lower()
      fig = None
      try:
          thresholds = get_metric_thresholds(function_calculator, metric_name)
          max_threshold = thresholds["max"]
          min_threshold = thresholds["min"]

          warning_dist = getattr(function_calculator, 'warning_dist', None)
          warning_speed = getattr(function_calculator, 'warning_speed', None)
          warning_time = getattr(function_calculator, 'warning_time', None)
          ttc = getattr(function_calculator, 'ttc', None)
          # Guard against a missing attribute as well as empty data.
          if warning_dist is None or len(warning_dist) == 0:
              logger.warning("Cannot generate TTC warning chart: empty data")
              return None

          # Latest TTC = last sample. np.asarray makes the access positional, so
          # it works for lists, arrays and Series (on a Series ``ttc[-1]`` is a
          # label lookup and raises KeyError).
          ttc_values = np.asarray(ttc) if ttc is not None else np.empty(0)
          metric_value = float(ttc_values[-1]) if ttc_values.size > 0 else max_threshold

          csv_filename = os.path.join(output_dir, f"{file_prefix}_data.csv")
          df_csv = pd.DataFrame({
              'simTime': warning_time,
              'warning_distance': warning_dist,
              'warning_speed': warning_speed,
              'ttc': ttc,
              'min_threshold': min_threshold,
              'max_threshold': max_threshold,
          })
          df_csv.to_csv(csv_filename, index=False)
          logger.info(f"{file_prefix} data saved to: {csv_filename}")

          # Plot from the persisted CSV so the chart matches the saved data.
          df = pd.read_csv(csv_filename)

          fig, (ax_dist, ax_speed, ax_ttc) = plt.subplots(3, 1, figsize=(12, 8), constrained_layout=True)

          # Subplots 1 and 2: warning distance and relative speed over time.
          for ax, column, style, label, ylabel, title in (
              (ax_dist, 'warning_distance', 'b-', 'Warning Distance', 'Distance (m)', 'Warning Distance Over Time'),
              (ax_speed, 'warning_speed', 'g-', 'Relative Speed', 'Speed (m/s)', 'Relative Speed Over Time'),
          ):
              ax.plot(df['simTime'], df[column], style, label=label)
              ax.set_xlabel('Time (s)')
              ax.set_ylabel(ylabel)
              ax.set_title(title)
              ax.grid(True)
              ax.legend()

          # Subplot 3: TTC with threshold lines; the latest sample is the metric.
          ax_ttc.plot(df['simTime'], df['ttc'], 'r-', label='TTC')
          ax_ttc.axhline(y=max_threshold, color='r', linestyle='--', label=f'Max Threshold ({max_threshold}s)')
          ax_ttc.axhline(y=min_threshold, color='g', linestyle='--', label=f'Min Threshold ({min_threshold}s)')
          if len(df) > 0:
              ax_ttc.scatter(df['simTime'].iloc[-1], df['ttc'].iloc[-1],
                             color='red', s=100, zorder=5,
                             label=f'Latest TTC: {metric_value:.2f}s')
          ax_ttc.set_xlabel('Time (s)')
          ax_ttc.set_ylabel('TTC (s)')
          ax_ttc.set_title('Time To Collision (TTC) Over Time')
          ax_ttc.grid(True)
          ax_ttc.legend()

          chart_filename = os.path.join(output_dir, f"{file_prefix}_chart.png")
          fig.savefig(chart_filename, dpi=300)
          logger.info(f"{file_prefix} chart saved to: {chart_filename}")
          return chart_filename
      except Exception as e:
          logger.error(f"Failed to generate {file_prefix} chart: {str(e)}", exc_info=True)
          return None
      finally:
          # Close the figure on every path so failed runs do not leak figures.
          if fig is not None:
              plt.close(fig)
  def generate_latest_warning_ttc_chart(function_calculator, output_dir: str) -> Optional[str]:
      """
      Generate the latestWarningDistance_TTC_LST chart.

      Warning distance, relative speed and TTC are first saved, together with
      the configured min/max thresholds, to
      ``latestwarningdistance_ttc_lst_data.csv``. Timestamps are taken from the
      ``ego_data`` rows whose ``ifwarning`` equals ``correctwarning``. The CSV
      is then read back and drawn as three stacked subplots: warning distance,
      relative speed, and TTC with threshold lines. The last TTC sample is
      highlighted as the metric value.

      Args:
          function_calculator: FunctionCalculator instance. Its ``ego_data``,
              ``correctwarning``, ``warning_dist``, ``warning_speed`` and
              ``ttc`` attributes are read.
          output_dir: Output directory for the CSV and PNG files.

      Returns:
          str: Chart file path, or None if there is no data or generation fails.
      """
      logger = LogManager().get_logger()
      metric_name = 'latestWarningDistance_TTC_LST'
      file_prefix = metric_name.lower()
      fig = None
      try:
          # Read-only access, so no defensive copy is needed.
          ego_df = function_calculator.ego_data
          correctwarning = getattr(function_calculator, 'correctwarning', None)

          thresholds = get_metric_thresholds(function_calculator, metric_name)
          max_threshold = thresholds["max"]
          min_threshold = thresholds["min"]

          warning_dist = getattr(function_calculator, 'warning_dist', None)
          warning_speed = getattr(function_calculator, 'warning_speed', None)
          ttc = getattr(function_calculator, 'ttc', None)
          # Guard against a missing attribute as well as an empty series.
          if warning_dist is None or len(warning_dist) == 0:
              logger.warning("Cannot generate TTC warning chart: empty data")
              return None

          # Latest TTC = last sample. np.asarray makes the access positional, so
          # it works for lists, arrays and Series (on a Series ``ttc[-1]`` is a
          # label lookup and raises KeyError).
          ttc_values = np.asarray(ttc) if ttc is not None else np.empty(0)
          metric_value = float(ttc_values[-1]) if ttc_values.size > 0 else max_threshold

          # Timestamps of the frames where the expected warning was raised.
          # NOTE(review): DataFrame columns are aligned on the pandas index, so
          # the warning series are assumed to share ego_data's index -- confirm.
          warning_mask = (ego_df['ifwarning'] == correctwarning) & ego_df['ifwarning'].notna()
          csv_filename = os.path.join(output_dir, f"{file_prefix}_data.csv")
          df_csv = pd.DataFrame({
              'simTime': ego_df.loc[warning_mask, 'simTime'],
              'warning_distance': warning_dist,
              'warning_speed': warning_speed,
              'ttc': ttc,
              'min_threshold': min_threshold,
              'max_threshold': max_threshold,
          })
          df_csv.to_csv(csv_filename, index=False)
          logger.info(f"{file_prefix} data saved to: {csv_filename}")

          # Plot from the persisted CSV so the chart matches the saved data.
          df = pd.read_csv(csv_filename)

          fig, (ax_dist, ax_speed, ax_ttc) = plt.subplots(3, 1, figsize=(12, 8), constrained_layout=True)

          # Subplots 1 and 2: warning distance and relative speed over time.
          for ax, column, style, label, ylabel, title in (
              (ax_dist, 'warning_distance', 'b-', 'Warning Distance', 'Distance (m)', 'Warning Distance Over Time'),
              (ax_speed, 'warning_speed', 'g-', 'Relative Speed', 'Speed (m/s)', 'Relative Speed Over Time'),
          ):
              ax.plot(df['simTime'], df[column], style, label=label)
              ax.set_xlabel('Time (s)')
              ax.set_ylabel(ylabel)
              ax.set_title(title)
              ax.grid(True)
              ax.legend()

          # Subplot 3: TTC with threshold lines; the latest sample is the metric.
          ax_ttc.plot(df['simTime'], df['ttc'], 'r-', label='TTC')
          ax_ttc.axhline(y=max_threshold, color='r', linestyle='--', label=f'Max Threshold ({max_threshold}s)')
          ax_ttc.axhline(y=min_threshold, color='g', linestyle='--', label=f'Min Threshold ({min_threshold}s)')
          if len(df) > 0:
              ax_ttc.scatter(df['simTime'].iloc[-1], df['ttc'].iloc[-1],
                             color='red', s=100, zorder=5,
                             label=f'Latest TTC: {metric_value:.2f}s')
          ax_ttc.set_xlabel('Time (s)')
          ax_ttc.set_ylabel('TTC (s)')
          ax_ttc.set_title('Time To Collision (TTC) Over Time')
          ax_ttc.grid(True)
          ax_ttc.legend()

          chart_filename = os.path.join(output_dir, f"{file_prefix}_chart.png")
          fig.savefig(chart_filename, dpi=300)
          logger.info(f"{file_prefix} chart saved to: {chart_filename}")
          return chart_filename
      except Exception as e:
          logger.error(f"Failed to generate {file_prefix} chart: {str(e)}", exc_info=True)
          return None
      finally:
          # Close the figure on every path so failed runs do not leak figures.
          if fig is not None:
              plt.close(fig)
  def generate_latest_warning_distance_chart(function_calculator, output_dir: str) -> Optional[str]:
      """
      Generate the latestWarningDistance_LST chart.

      The warning-distance series is first saved, together with the configured
      min/max thresholds, to ``latestWarningDistance_LST_data.csv``. The CSV is
      then read back and plotted over time with the threshold lines, and the
      last sample is highlighted as the metric value.

      Args:
          function_calculator: FunctionCalculator instance. Its ``ego_data``,
              ``correctwarning`` and ``warning_dist`` attributes are read.
          output_dir: Output directory for the CSV and PNG files.

      Returns:
          str: Chart file path, or None if there is no data or generation fails.
      """
      logger = LogManager().get_logger()
      metric_name = 'latestWarningDistance_LST'
      fig = None
      try:
          # Read-only access, so no defensive copy is needed.
          ego_df = function_calculator.ego_data
          correctwarning = getattr(function_calculator, 'correctwarning', None)

          thresholds = get_metric_thresholds(function_calculator, metric_name)
          max_threshold = thresholds["max"]
          min_threshold = thresholds["min"]

          # Guard against a missing attribute as well as an empty series.
          warning_dist = getattr(function_calculator, 'warning_dist', None)
          if warning_dist is None or len(warning_dist) == 0:
              logger.warning(f"Cannot generate {metric_name} chart: empty data")
              return None

          # Latest warning distance = last sample (non-empty checked above).
          metric_value = float(warning_dist.iloc[-1])

          # Timestamps of the frames where the expected warning was raised.
          # NOTE(review): DataFrame columns are aligned on the pandas index, so
          # warning_dist is assumed to share ego_data's index -- confirm.
          warning_mask = (ego_df['ifwarning'] == correctwarning) & ego_df['ifwarning'].notna()
          csv_filename = os.path.join(output_dir, f"{metric_name}_data.csv")
          df_csv = pd.DataFrame({
              'simTime': ego_df.loc[warning_mask, 'simTime'],
              'warning_distance': warning_dist,
              'min_threshold': min_threshold,
              'max_threshold': max_threshold,
          })
          df_csv.to_csv(csv_filename, index=False)
          logger.info(f"{metric_name} data saved to: {csv_filename}")

          # Plot from the persisted CSV so the chart matches the saved data.
          df = pd.read_csv(csv_filename)

          fig = plt.figure(figsize=(12, 6), constrained_layout=True)
          plt.plot(df['simTime'], df['warning_distance'], 'b-', label='Warning Distance')
          plt.axhline(y=max_threshold, color='r', linestyle='--', label=f'Max Threshold ({max_threshold}m)')
          plt.axhline(y=min_threshold, color='g', linestyle='--', label=f'Min Threshold ({min_threshold}m)')

          if len(df) > 0:
              plt.scatter(df['simTime'].iloc[-1], df['warning_distance'].iloc[-1],
                          color='red', s=100, zorder=5,
                          label=f'Latest Warning Distance: {metric_value:.2f}m')

          # 10% headroom above the larger of the max threshold and the data.
          plt.ylim(bottom=-1, top=max(max_threshold * 1.1, df['warning_distance'].max() * 1.1))
          plt.xlabel('Time (s)')
          plt.ylabel('Distance (m)')
          plt.title(f'{metric_name} - Warning Distance Over Time')
          plt.grid(True)
          plt.legend()

          chart_filename = os.path.join(output_dir, f"{metric_name}_chart.png")
          fig.savefig(chart_filename, dpi=300)
          logger.info(f"{metric_name} chart saved to: {chart_filename}")
          return chart_filename
      except Exception as e:
          logger.error(f"Failed to generate {metric_name} chart: {str(e)}", exc_info=True)
          return None
      finally:
          # Close the figure on every path so failed runs do not leak figures.
          if fig is not None:
              plt.close(fig)
  def generate_latest_warning_distance_pgvil_chart(function_calculator, output_dir: str) -> Optional[str]:
      """
      Generate the latestWarningDistance_PGVIL chart.

      The warning-distance series and its timestamps are first saved, together
      with the configured min/max thresholds, to
      ``latestWarningDistance_PGVIL_data.csv``. The CSV is then read back and
      plotted over time with the threshold lines, and the last sample is
      highlighted as the metric value.

      Args:
          function_calculator: FunctionCalculator instance. Its ``warning_dist``
              and ``warning_time`` attributes are read (presumably equal-length
              sequences -- confirm).
          output_dir: Output directory for the CSV and PNG files.

      Returns:
          str: Chart file path, or None if there is no data or generation fails.
      """
      logger = LogManager().get_logger()
      metric_name = 'latestWarningDistance_PGVIL'
      fig = None
      try:
          thresholds = get_metric_thresholds(function_calculator, metric_name)
          max_threshold = thresholds["max"]
          min_threshold = thresholds["min"]

          warning_dist = getattr(function_calculator, 'warning_dist', None)
          warning_time = getattr(function_calculator, 'warning_time', None)
          # Guard against a missing attribute as well as empty data.
          if warning_dist is None or len(warning_dist) == 0:
              logger.warning(f"Cannot generate {metric_name} chart: empty data")
              return None

          # Latest warning distance = last sample. np.asarray makes the access
          # positional (on a Series ``[-1]`` is a label lookup).
          metric_value = float(np.asarray(warning_dist)[-1])

          csv_filename = os.path.join(output_dir, f"{metric_name}_data.csv")
          df_csv = pd.DataFrame({
              'simTime': warning_time,
              'warning_distance': warning_dist,
              'min_threshold': min_threshold,
              'max_threshold': max_threshold,
          })
          df_csv.to_csv(csv_filename, index=False)
          logger.info(f"{metric_name} data saved to: {csv_filename}")

          # Plot from the persisted CSV so the chart matches the saved data.
          df = pd.read_csv(csv_filename)

          fig = plt.figure(figsize=(12, 6), constrained_layout=True)
          plt.plot(df['simTime'], df['warning_distance'], 'b-', label='Warning Distance')
          plt.axhline(y=max_threshold, color='r', linestyle='--', label=f'Max Threshold ({max_threshold}m)')
          plt.axhline(y=min_threshold, color='g', linestyle='--', label=f'Min Threshold ({min_threshold}m)')

          if len(df) > 0:
              plt.scatter(df['simTime'].iloc[-1], df['warning_distance'].iloc[-1],
                          color='red', s=100, zorder=5,
                          label=f'Latest Warning Distance: {metric_value:.2f}m')

          # 10% headroom above the larger of the max threshold and the data.
          plt.ylim(bottom=-1, top=max(max_threshold * 1.1, df['warning_distance'].max() * 1.1))
          plt.xlabel('Time (s)')
          plt.ylabel('Distance (m)')
          plt.title(f'{metric_name} - Warning Distance Over Time')
          plt.grid(True)
          plt.legend()

          chart_filename = os.path.join(output_dir, f"{metric_name}_chart.png")
          fig.savefig(chart_filename, dpi=300)
          logger.info(f"{metric_name} chart saved to: {chart_filename}")
          return chart_filename
      except Exception as e:
          logger.error(f"Failed to generate {metric_name} chart: {str(e)}", exc_info=True)
          return None
      finally:
          # Close the figure on every path so failed runs do not leak figures.
          if fig is not None:
              plt.close(fig)
  def generate_earliest_warning_distance_ttc_chart(function_calculator, output_dir: str) -> Optional[str]:
      """
      Generate the earliestWarningDistance_TTC_LST chart.

      Warning distance, relative speed and TTC are first saved, together with
      the configured min/max thresholds, to
      ``earliestwarningdistance_ttc_lst_data.csv``. Timestamps are taken from
      the ``ego_data`` rows whose ``ifwarning`` equals ``correctwarning``. The
      CSV is then read back and drawn as three stacked subplots: warning
      distance, relative speed, and TTC with threshold lines. The first TTC
      sample is highlighted as the metric value.

      Args:
          function_calculator: FunctionCalculator instance. Its ``ego_data``,
              ``correctwarning``, ``warning_dist``, ``warning_speed`` and
              ``ttc`` attributes are read.
          output_dir: Output directory for the CSV and PNG files.

      Returns:
          str: Chart file path, or None if there is no data or generation fails.
      """
      logger = LogManager().get_logger()
      metric_name = 'earliestWarningDistance_TTC_LST'
      file_prefix = metric_name.lower()
      fig = None
      try:
          # Read-only access, so no defensive copy is needed.
          ego_df = function_calculator.ego_data
          correctwarning = getattr(function_calculator, 'correctwarning', None)

          thresholds = get_metric_thresholds(function_calculator, metric_name)
          max_threshold = thresholds["max"]
          min_threshold = thresholds["min"]

          # Bug fix: the distance series was previously read from the
          # 'correctwarning' attribute, which plotted the wrong data.
          warning_dist = getattr(function_calculator, 'warning_dist', None)
          warning_speed = getattr(function_calculator, 'warning_speed', None)
          ttc = getattr(function_calculator, 'ttc', None)
          # Same empty-data guard as the latestWarningDistance_TTC_LST chart.
          if warning_dist is None or len(warning_dist) == 0:
              logger.warning(f"Cannot generate {metric_name} chart: empty data")
              return None

          # Earliest TTC = first sample. np.asarray makes the access positional,
          # so it works for lists, arrays and Series (on a Series ``ttc[0]`` is a
          # label lookup).
          ttc_values = np.asarray(ttc) if ttc is not None else np.empty(0)
          metric_value = float(ttc_values[0]) if ttc_values.size > 0 else max_threshold

          # Timestamps of the frames where the expected warning was raised.
          # NOTE(review): DataFrame columns are aligned on the pandas index, so
          # the warning series are assumed to share ego_data's index -- confirm.
          warning_mask = (ego_df['ifwarning'] == correctwarning) & ego_df['ifwarning'].notna()
          csv_filename = os.path.join(output_dir, f"{file_prefix}_data.csv")
          df_csv = pd.DataFrame({
              'simTime': ego_df.loc[warning_mask, 'simTime'],
              'warning_distance': warning_dist,
              'warning_speed': warning_speed,
              'ttc': ttc,
              'min_threshold': min_threshold,
              'max_threshold': max_threshold,
          })
          df_csv.to_csv(csv_filename, index=False)
          logger.info(f"{metric_name} data saved to: {csv_filename}")

          # Plot from the persisted CSV so the chart matches the saved data.
          df = pd.read_csv(csv_filename)

          fig, (ax_dist, ax_speed, ax_ttc) = plt.subplots(3, 1, figsize=(12, 8), constrained_layout=True)

          # Subplots 1 and 2: warning distance and relative speed over time.
          for ax, column, style, label, ylabel, title in (
              (ax_dist, 'warning_distance', 'b-', 'Warning Distance', 'Distance (m)', 'Warning Distance Over Time'),
              (ax_speed, 'warning_speed', 'g-', 'Relative Speed', 'Speed (m/s)', 'Relative Speed Over Time'),
          ):
              ax.plot(df['simTime'], df[column], style, label=label)
              ax.set_xlabel('Time (s)')
              ax.set_ylabel(ylabel)
              ax.set_title(title)
              ax.grid(True)
              ax.legend()

          # Subplot 3: TTC with threshold lines; the earliest sample is the metric.
          ax_ttc.plot(df['simTime'], df['ttc'], 'r-', label='TTC')
          ax_ttc.axhline(y=max_threshold, color='r', linestyle='--', label=f'Max Threshold ({max_threshold}s)')
          ax_ttc.axhline(y=min_threshold, color='g', linestyle='--', label=f'Min Threshold ({min_threshold}s)')
          if len(df) > 0:
              ax_ttc.scatter(df['simTime'].iloc[0], df['ttc'].iloc[0],
                             color='red', s=100, zorder=5,
                             label=f'Earliest TTC: {metric_value:.2f}s')
          ax_ttc.set_xlabel('Time (s)')
          ax_ttc.set_ylabel('TTC (s)')
          ax_ttc.set_title('Time To Collision (TTC) Over Time')
          ax_ttc.grid(True)
          ax_ttc.legend()

          chart_filename = os.path.join(output_dir, f"{file_prefix}_chart.png")
          fig.savefig(chart_filename, dpi=300)
          logger.info(f"{metric_name} chart saved to: {chart_filename}")
          return chart_filename
      except Exception as e:
          logger.error(f"Failed to generate {file_prefix} chart: {str(e)}", exc_info=True)
          return None
      finally:
          # Close the figure on every path so failed runs do not leak figures.
          if fig is not None:
              plt.close(fig)
  def generate_earliest_warning_distance_ttc_pgvil_chart(function_calculator, output_dir: str) -> Optional[str]:
      """
      Generate the earliestWarningDistance_TTC_PGVIL chart.

      Warning distance, relative speed and TTC are first saved, together with
      the configured min/max thresholds, to
      ``earliestwarningdistance_ttc_pgvil_data.csv``. The CSV is then read back
      and drawn as three stacked subplots: warning distance, relative speed,
      and TTC with threshold lines. The first TTC sample is highlighted as the
      metric value.

      Args:
          function_calculator: FunctionCalculator instance. Its
              ``warning_dist``, ``warning_speed``, ``warning_time`` and ``ttc``
              attributes are read.
          output_dir: Output directory for the CSV and PNG files.

      Returns:
          str: Chart file path, or None if there is no data or generation fails.
      """
      logger = LogManager().get_logger()
      metric_name = 'earliestWarningDistance_TTC_PGVIL'
      file_prefix = metric_name.lower()
      fig = None
      try:
          thresholds = get_metric_thresholds(function_calculator, metric_name)
          max_threshold = thresholds["max"]
          min_threshold = thresholds["min"]

          warning_dist = getattr(function_calculator, 'warning_dist', None)
          warning_speed = getattr(function_calculator, 'warning_speed', None)
          ttc = getattr(function_calculator, 'ttc', None)
          warning_time = getattr(function_calculator, 'warning_time', None)
          # Same empty-data guard as the other warning charts in this module.
          if warning_dist is None or len(warning_dist) == 0:
              logger.warning(f"Cannot generate {metric_name} chart: empty data")
              return None

          # Earliest TTC = first sample. np.asarray makes the access positional,
          # so it works for lists, arrays and Series (on a Series ``ttc[0]`` is a
          # label lookup).
          ttc_values = np.asarray(ttc) if ttc is not None else np.empty(0)
          metric_value = float(ttc_values[0]) if ttc_values.size > 0 else max_threshold

          csv_filename = os.path.join(output_dir, f"{file_prefix}_data.csv")
          df_csv = pd.DataFrame({
              'simTime': warning_time,
              'warning_distance': warning_dist,
              'warning_speed': warning_speed,
              'ttc': ttc,
              'min_threshold': min_threshold,
              'max_threshold': max_threshold,
          })
          df_csv.to_csv(csv_filename, index=False)
          logger.info(f"{metric_name} data saved to: {csv_filename}")

          # Plot from the persisted CSV so the chart matches the saved data.
          df = pd.read_csv(csv_filename)

          fig, (ax_dist, ax_speed, ax_ttc) = plt.subplots(3, 1, figsize=(12, 8), constrained_layout=True)

          # Subplots 1 and 2: warning distance and relative speed over time.
          for ax, column, style, label, ylabel, title in (
              (ax_dist, 'warning_distance', 'b-', 'Warning Distance', 'Distance (m)', 'Warning Distance Over Time'),
              (ax_speed, 'warning_speed', 'g-', 'Relative Speed', 'Speed (m/s)', 'Relative Speed Over Time'),
          ):
              ax.plot(df['simTime'], df[column], style, label=label)
              ax.set_xlabel('Time (s)')
              ax.set_ylabel(ylabel)
              ax.set_title(title)
              ax.grid(True)
              ax.legend()

          # Subplot 3: TTC with threshold lines; the earliest sample is the metric.
          ax_ttc.plot(df['simTime'], df['ttc'], 'r-', label='TTC')
          ax_ttc.axhline(y=max_threshold, color='r', linestyle='--', label=f'Max Threshold ({max_threshold}s)')
          ax_ttc.axhline(y=min_threshold, color='g', linestyle='--', label=f'Min Threshold ({min_threshold}s)')
          if len(df) > 0:
              ax_ttc.scatter(df['simTime'].iloc[0], df['ttc'].iloc[0],
                             color='red', s=100, zorder=5,
                             label=f'Earliest TTC: {metric_value:.2f}s')
          ax_ttc.set_xlabel('Time (s)')
          ax_ttc.set_ylabel('TTC (s)')
          ax_ttc.set_title('Time To Collision (TTC) Over Time')
          ax_ttc.grid(True)
          ax_ttc.legend()

          chart_filename = os.path.join(output_dir, f"{file_prefix}_chart.png")
          fig.savefig(chart_filename, dpi=300)
          logger.info(f"{metric_name} chart saved to: {chart_filename}")
          return chart_filename
      except Exception as e:
          logger.error(f"Failed to generate {file_prefix} chart: {str(e)}", exc_info=True)
          return None
      finally:
          # Close the figure on every path so failed runs do not leak figures.
          if fig is not None:
              plt.close(fig)
  714. def generate_limit_speed_chart(function_calculator, output_dir: str) -> Optional[str]:
  715. """
  716. Generate limit speed chart with data visualization for limitSpeed_LST metric.
  717. Args:
  718. function_calculator: FunctionCalculator instance
  719. output_dir: Output directory
  720. Returns:
  721. str: Chart file path, or None if generation fails
  722. """
  723. logger = LogManager().get_logger()
  724. metric_name = 'limitSpeed_LST'
  725. try:
  726. # Get data
  727. ego_df = function_calculator.ego_data.copy()
  728. # Get configured thresholds
  729. thresholds = get_metric_thresholds(function_calculator, metric_name)
  730. max_threshold = thresholds["max"]
  731. min_threshold = thresholds["min"]
  732. if ego_df.empty:
  733. logger.warning(f"Cannot generate {metric_name} chart: empty data")
  734. return None
  735. # Save CSV data
  736. csv_filename = os.path.join(output_dir, f"{metric_name.lower()}_data.csv")
  737. df_csv = pd.DataFrame({
  738. 'simTime': ego_df['simTime'],
  739. 'speed': ego_df['v'],
  740. 'speed_limit': ego_df.get('speed_limit', pd.Series([max_threshold] * len(ego_df)))
  741. })
  742. df_csv.to_csv(csv_filename, index=False)
  743. logger.info(f"{metric_name} data saved to: {csv_filename}")
  744. # Read data from CSV
  745. df = pd.read_csv(csv_filename)
  746. # Create chart
  747. plt.figure(figsize=(12, 6), constrained_layout=True)
  748. # Plot speed
  749. plt.plot(df['simTime'], df['speed'], 'b-', label='Vehicle Speed')
  750. plt.plot(df['simTime'], df['speed_limit'], 'r--', label='Speed Limit')
  751. # Set y-axis range
  752. plt.ylim(bottom=0, top=max(max_threshold * 1.1, df['speed'].max() * 1.1))
  753. plt.xlabel('Time (s)')
  754. plt.ylabel('Speed (m/s)')
  755. plt.title(f'{metric_name} - Vehicle Speed vs Speed Limit')
  756. plt.grid(True)
  757. plt.legend()
  758. # Save image
  759. chart_filename = os.path.join(output_dir, f"{metric_name.lower()}_chart.png")
  760. plt.savefig(chart_filename, dpi=300)
  761. plt.close()
  762. logger.info(f"{metric_name} chart saved to: {chart_filename}")
  763. return chart_filename
  764. except Exception as e:
  765. logger.error(f"Failed to generate {metric_name} chart: {str(e)}", exc_info=True)
  766. return None
  767. def generate_limit_speed_past_sign_chart(function_calculator, output_dir: str) -> Optional[str]:
  768. """
  769. Generate limit speed past sign chart with data visualization for limitSpeedPastLimitSign_LST metric.
  770. Args:
  771. function_calculator: FunctionCalculator instance
  772. output_dir: Output directory
  773. Returns:
  774. str: Chart file path, or None if generation fails
  775. """
  776. logger = LogManager().get_logger()
  777. metric_name = 'limitSpeedPastLimitSign_LST'
  778. try:
  779. # Get data
  780. ego_df = function_calculator.ego_data.copy()
  781. # Get configured thresholds
  782. thresholds = get_metric_thresholds(function_calculator, metric_name)
  783. max_threshold = thresholds["max"]
  784. min_threshold = thresholds["min"]
  785. if ego_df.empty:
  786. logger.warning(f"Cannot generate {metric_name} chart: empty data")
  787. return None
  788. # Get sign passing time if available
  789. sign_time = getattr(function_calculator, 'sign_pass_time', None)
  790. if sign_time is None:
  791. # Try to estimate sign passing time (middle of the simulation)
  792. sign_time = ego_df['simTime'].iloc[len(ego_df) // 2]
  793. # Save CSV data
  794. csv_filename = os.path.join(output_dir, f"{metric_name.lower()}_data.csv")
  795. df_csv = pd.DataFrame({
  796. 'simTime': ego_df['simTime'],
  797. 'speed': ego_df['v'],
  798. 'speed_limit': ego_df.get('speed_limit', pd.Series([max_threshold] * len(ego_df))),
  799. 'sign_pass_time': sign_time
  800. })
  801. df_csv.to_csv(csv_filename, index=False)
  802. logger.info(f"{metric_name} data saved to: {csv_filename}")
  803. # Read data from CSV
  804. df = pd.read_csv(csv_filename)
  805. # Create chart
  806. plt.figure(figsize=(12, 6), constrained_layout=True)
  807. # Plot speed
  808. plt.plot(df['simTime'], df['speed'], 'b-', label='Vehicle Speed')
  809. plt.plot(df['simTime'], df['speed_limit'], 'r--', label='Speed Limit')
  810. # Mark sign passing time
  811. plt.axvline(x=sign_time, color='g', linestyle='--', label='Speed Limit Sign')
  812. # Set y-axis range
  813. plt.ylim(bottom=0, top=max(max_threshold * 1.1, df['speed'].max() * 1.1))
  814. plt.xlabel('Time (s)')
  815. plt.ylabel('Speed (m/s)')
  816. plt.title(f'{metric_name} - Vehicle Speed vs Speed Limit')
  817. plt.grid(True)
  818. plt.legend()
  819. # Save image
  820. chart_filename = os.path.join(output_dir, f"{metric_name.lower()}_chart.png")
  821. plt.savefig(chart_filename, dpi=300)
  822. plt.close()
  823. logger.info(f"{metric_name} chart saved to: {chart_filename}")
  824. return chart_filename
  825. except Exception as e:
  826. logger.error(f"Failed to generate {metric_name} chart: {str(e)}", exc_info=True)
  827. return None
  828. def generate_max_longitude_dist_chart(function_calculator, output_dir: str) -> Optional[str]:
  829. """
  830. Generate maximum longitudinal distance chart with data visualization for maxLongitudeDist_LST metric.
  831. Args:
  832. function_calculator: FunctionCalculator instance
  833. output_dir: Output directory
  834. Returns:
  835. str: Chart file path, or None if generation fails
  836. """
  837. logger = LogManager().get_logger()
  838. metric_name = 'maxLongitudeDist_LST'
  839. try:
  840. # Get data
  841. ego_df = function_calculator.ego_data.copy()
  842. # Get configured thresholds
  843. thresholds = get_metric_thresholds(function_calculator, metric_name)
  844. max_threshold = thresholds["max"]
  845. min_threshold = thresholds["min"]
  846. # Get longitudinal distance data
  847. longitude_dist = ego_df['longitude_dist'] if 'longitude_dist' in ego_df.columns else None
  848. stop_time = ego_df['stop_time'] if 'stop_time' in ego_df.columns else None
  849. if longitude_dist is None or longitude_dist.empty:
  850. logger.warning(f"Cannot generate {metric_name} chart: missing longitudinal distance data")
  851. return None
  852. # Calculate metric value
  853. metric_value = longitude_dist.max()
  854. max_distance_time = ego_df.loc[longitude_dist.idxmax(), 'simTime']
  855. # Save CSV data
  856. csv_filename = os.path.join(output_dir, f"{metric_name.lower()}_data.csv")
  857. df_csv = pd.DataFrame({
  858. 'simTime': ego_df['simTime'],
  859. 'x_relative_dist': ego_df['x_relative_dist'],
  860. 'stop_time': stop_time,
  861. 'longitude_dist': longitude_dist
  862. })
  863. df_csv.to_csv(csv_filename, index=False)
  864. logger.info(f"{metric_name} data saved to: {csv_filename}")
  865. # Read data from CSV
  866. df = pd.read_csv(csv_filename)
  867. # Create chart
  868. plt.figure(figsize=(12, 6), constrained_layout=True)
  869. # Plot longitudinal distance
  870. plt.plot(df['simTime'], df['x_relative_dist'], 'b-', label='Longitudinal Distance')
  871. # Add threshold lines
  872. plt.axhline(y=max_threshold, color='r', linestyle='--', label=f'Max Threshold ({max_threshold}m)')
  873. plt.axhline(y=min_threshold, color='g', linestyle='--', label=f'Min Threshold ({min_threshold}m)')
  874. # Mark maximum longitudinal distance
  875. plt.scatter(max_distance_time, metric_value,
  876. color='red', s=100, zorder=5,
  877. label=f'Maximum Longitudinal Distance: {metric_value:.2f}m')
  878. # Set y-axis range
  879. plt.ylim(bottom=min(0, min_threshold * 0.9), top=max(max_threshold * 1.1, df['longitude_dist'].max() * 1.1))
  880. plt.xlabel('Time (s)')
  881. plt.ylabel('Longitudinal Distance (m)')
  882. plt.title(f'{metric_name} - Longitudinal Distance Over Time')
  883. plt.grid(True)
  884. plt.legend()
  885. # Save image
  886. chart_filename = os.path.join(output_dir, f"{metric_name.lower()}_chart.png")
  887. plt.savefig(chart_filename, dpi=300)
  888. plt.close()
  889. logger.info(f"{metric_name} chart saved to: {chart_filename}")
  890. return chart_filename
  891. except Exception as e:
  892. logger.error(f"Failed to generate {metric_name} chart: {str(e)}", exc_info=True)
  893. return None
  894. def generate_warning_delay_time_chart(function_calculator, output_dir: str) -> Optional[str]:
  895. """
  896. Generate warning delay time chart with data visualization for warningDelayTime_LST metric.
  897. Args:
  898. function_calculator: FunctionCalculator instance
  899. output_dir: Output directory
  900. Returns:
  901. str: Chart file path, or None if generation fails
  902. """
  903. logger = LogManager().get_logger()
  904. metric_name = 'warningDelayTime_LST'
  905. try:
  906. # Get data
  907. ego_df = function_calculator.ego_data.copy()
  908. # Get configured thresholds
  909. thresholds = get_metric_thresholds(function_calculator, metric_name)
  910. max_threshold = thresholds["max"]
  911. min_threshold = thresholds["min"]
  912. # Check if correctwarning is already calculated
  913. correctwarning = getattr(function_calculator, 'correctwarning', None)
  914. if correctwarning is None:
  915. logger.warning(f"Cannot generate {metric_name} chart: missing correctwarning value")
  916. return None
  917. # Get HMI warning time and rosbag warning time
  918. HMI_warning_rows = ego_df[(ego_df['ifwarning'] == correctwarning)]['simTime'].tolist()
  919. simTime_HMI = HMI_warning_rows[0] if len(HMI_warning_rows) > 0 else None
  920. rosbag_warning_rows = ego_df[(ego_df['event_Type'].notna()) & ((ego_df['event_Type'] != np.nan))][
  921. 'simTime'].tolist()
  922. simTime_rosbag = rosbag_warning_rows[0] if len(rosbag_warning_rows) > 0 else None
  923. if (simTime_HMI is None) or (simTime_rosbag is None):
  924. logger.warning(f"Cannot generate {metric_name} chart: missing warning time data")
  925. return None
  926. # Calculate delay time
  927. delay_time = abs(simTime_HMI - simTime_rosbag)
  928. # Save CSV data
  929. csv_filename = os.path.join(output_dir, f"{metric_name.lower()}_data.csv")
  930. df_csv = pd.DataFrame({
  931. 'HMI_warning_time': [simTime_HMI],
  932. 'rosbag_warning_time': [simTime_rosbag],
  933. 'delay_time': [delay_time],
  934. 'min_threshold': [min_threshold],
  935. 'max_threshold': [max_threshold]
  936. })
  937. df_csv.to_csv(csv_filename, index=False)
  938. logger.info(f"{metric_name} data saved to: {csv_filename}")
  939. # Create chart - bar chart for delay time
  940. plt.figure(figsize=(10, 6), constrained_layout=True)
  941. # Plot delay time as bar
  942. plt.bar(['Warning Delay Time'], [delay_time], color='blue', width=0.4)
  943. # Add threshold lines
  944. plt.axhline(y=max_threshold, color='r', linestyle='--', label=f'Max Threshold ({max_threshold}s)')
  945. plt.axhline(y=min_threshold, color='g', linestyle='--', label=f'Min Threshold ({min_threshold}s)')
  946. # Add value label
  947. plt.text(0, delay_time + 0.05, f'{delay_time:.3f}s', ha='center', va='bottom', fontweight='bold')
  948. # Set y-axis range
  949. plt.ylim(bottom=0, top=max(max_threshold * 1.2, delay_time * 1.2))
  950. plt.ylabel('Delay Time (s)')
  951. plt.title(f'{metric_name} - Warning Delay Time')
  952. plt.grid(True, axis='y')
  953. plt.legend()
  954. # Save image
  955. chart_filename = os.path.join(output_dir, f"{metric_name.lower()}_chart.png")
  956. plt.savefig(chart_filename, dpi=300)
  957. plt.close()
  958. logger.info(f"{metric_name} chart saved to: {chart_filename}")
  959. return chart_filename
  960. except Exception as e:
  961. logger.error(f"Failed to generate {metric_name} chart: {str(e)}", exc_info=True)
  962. return None
  963. def generate_comfort_chart_data(comfort_calculator, metric_name: str, output_dir: Optional[str] = None) -> Optional[
  964. str]:
  965. """
  966. Generate chart data for comfort metrics
  967. Args:
  968. comfort_calculator: ComfortCalculator instance
  969. metric_name: Metric name
  970. output_dir: Output directory
  971. Returns:
  972. str: Chart file path, or None if generation fails
  973. """
  974. logger = LogManager().get_logger()
  975. try:
  976. # 确保输出目录存在
  977. if output_dir:
  978. os.makedirs(output_dir, exist_ok=True)
  979. else:
  980. output_dir = os.getcwd()
  981. # 根据指标名称选择不同的图表生成方法
  982. if metric_name.lower() == 'shake':
  983. return generate_shake_chart(comfort_calculator, output_dir)
  984. elif metric_name.lower() == 'zigzag':
  985. return generate_zigzag_chart(comfort_calculator, output_dir)
  986. elif metric_name.lower() == 'cadence':
  987. return generate_cadence_chart(comfort_calculator, output_dir)
  988. elif metric_name.lower() == 'slambrake':
  989. return generate_slam_brake_chart(comfort_calculator, output_dir)
  990. elif metric_name.lower() == 'slamaccelerate':
  991. return generate_slam_accelerate_chart(comfort_calculator, output_dir)
  992. else:
  993. logger.warning(f"Chart generation not implemented for metric [{metric_name}]")
  994. return None
  995. except Exception as e:
  996. logger.error(f"Failed to generate chart data: {str(e)}", exc_info=True)
  997. return None
  998. def generate_shake_chart(comfort_calculator, output_dir: str) -> Optional[str]:
  999. """
  1000. Generate shake metric chart with orange background for shake events.
  1001. This version first saves data to CSV, then uses the CSV to generate the chart.
  1002. Args:
  1003. comfort_calculator: ComfortCalculator instance
  1004. output_dir: Output directory
  1005. Returns:
  1006. str: Chart file path, or None if generation fails
  1007. """
  1008. logger = LogManager().get_logger()
  1009. try:
  1010. # 获取数据
  1011. df = comfort_calculator.ego_df.copy()
  1012. shake_events = comfort_calculator.shake_events
  1013. if df.empty:
  1014. logger.warning("Cannot generate shake chart: empty data")
  1015. return None
  1016. # 生成时间戳
  1017. # 保存 CSV 数据(第一步)
  1018. csv_filename = os.path.join(output_dir, f"shake_data.csv")
  1019. df_csv = pd.DataFrame({
  1020. 'simTime': df['simTime'],
  1021. 'lat_acc': df['lat_acc'],
  1022. 'lat_acc_rate': df['lat_acc_rate'],
  1023. 'speedH_std': df['speedH_std'],
  1024. 'lat_acc_threshold': df.get('lat_acc_threshold', pd.Series([None] * len(df))),
  1025. 'lat_acc_rate_threshold': 0.5,
  1026. 'speedH_std_threshold': df.get('speedH_threshold', pd.Series([None] * len(df))),
  1027. })
  1028. df_csv.to_csv(csv_filename, index=False)
  1029. logger.info(f"Shake data saved to: {csv_filename}")
  1030. # 第二步:从 CSV 读取(可验证保存数据无误)
  1031. df = pd.read_csv(csv_filename)
  1032. # 创建图表(第三步)
  1033. import matplotlib.pyplot as plt
  1034. plt.figure(figsize=(12, 8), constrained_layout=True)
  1035. # 图 1:横向加速度
  1036. ax1 = plt.subplot(3, 1, 1)
  1037. ax1.plot(df['simTime'], df['lat_acc'], 'b-', label='Lateral Acceleration')
  1038. if 'lat_acc_threshold' in df.columns:
  1039. ax1.plot(df['simTime'], df['lat_acc_threshold'], 'r--', label='lat_acc_threshold')
  1040. for idx, event in enumerate(shake_events):
  1041. label = 'Shake Event' if idx == 0 else None
  1042. ax1.axvspan(event['start_time'], event['end_time'], alpha=0.3, color='orange', label=label)
  1043. ax1.set_xlabel('Time (s)')
  1044. ax1.set_ylabel('Lateral Acceleration (m/s²)')
  1045. ax1.set_title('Shake Event Detection - Lateral Acceleration')
  1046. ax1.grid(True)
  1047. ax1.legend()
  1048. # 图 2:lat_acc_rate
  1049. ax2 = plt.subplot(3, 1, 2)
  1050. ax2.plot(df['simTime'], df['lat_acc_rate'], 'g-', label='lat_acc_rate')
  1051. ax2.axhline(
  1052. y=0.5, color='orange', linestyle='--', linewidth=1.2, label='lat_acc_rate_threshold'
  1053. )
  1054. for idx, event in enumerate(shake_events):
  1055. label = 'Shake Event' if idx == 0 else None
  1056. ax2.axvspan(event['start_time'], event['end_time'], alpha=0.3, color='orange', label=label)
  1057. ax2.set_xlabel('Time (s)')
  1058. ax2.set_ylabel('Angular Velocity (m/s³)')
  1059. ax2.set_title('Shake Event Detection - lat_acc_rate')
  1060. ax2.grid(True)
  1061. ax2.legend()
  1062. # 图 3:speedH_std
  1063. ax3 = plt.subplot(3, 1, 3)
  1064. ax3.plot(df['simTime'], df['speedH_std'], 'b-', label='speedH_std')
  1065. if 'speedH_std_threshold' in df.columns:
  1066. ax3.plot(df['simTime'], df['speedH_std_threshold'], 'r--', label='speedH_threshold')
  1067. for idx, event in enumerate(shake_events):
  1068. label = 'Shake Event' if idx == 0 else None
  1069. ax3.axvspan(event['start_time'], event['end_time'], alpha=0.3, color='orange', label=label)
  1070. ax3.set_xlabel('Time (s)')
  1071. ax3.set_ylabel('Angular Velocity (deg/s)')
  1072. ax3.set_title('Shake Event Detection - speedH_std')
  1073. ax3.grid(True)
  1074. ax3.legend()
  1075. # 保存图像
  1076. chart_filename = os.path.join(output_dir, f"shake_chart.png")
  1077. plt.savefig(chart_filename, dpi=300)
  1078. plt.close()
  1079. logger.info(f"Shake chart saved to: {chart_filename}")
  1080. return chart_filename
  1081. except Exception as e:
  1082. logger.error(f"Failed to generate shake chart: {str(e)}", exc_info=True)
  1083. return None
  1084. def generate_zigzag_chart(comfort_calculator, output_dir: str) -> Optional[str]:
  1085. """
  1086. Generate zigzag metric chart with orange background for zigzag events.
  1087. This version first saves data to CSV, then uses the CSV to generate the chart.
  1088. Args:
  1089. comfort_calculator: ComfortCalculator instance
  1090. output_dir: Output directory
  1091. Returns:
  1092. str: Chart file path, or None if generation fails
  1093. """
  1094. logger = LogManager().get_logger()
  1095. try:
  1096. # 获取数据
  1097. df = comfort_calculator.ego_df.copy()
  1098. zigzag_events = comfort_calculator.discomfort_df[
  1099. comfort_calculator.discomfort_df['type'] == 'zigzag'
  1100. ].copy()
  1101. if df.empty:
  1102. logger.warning("Cannot generate zigzag chart: empty data")
  1103. return None
  1104. # 生成时间戳
  1105. # 保存 CSV 数据(第一步)
  1106. csv_filename = os.path.join(output_dir, f"zigzag_data.csv")
  1107. df_csv = pd.DataFrame({
  1108. 'simTime': df['simTime'],
  1109. 'speedH': df['speedH'],
  1110. 'posH': df['posH'],
  1111. 'min_speedH_threshold': -2.3, # 可替换为动态阈值
  1112. 'max_speedH_threshold': 2.3
  1113. })
  1114. df_csv.to_csv(csv_filename, index=False)
  1115. logger.info(f"Zigzag data saved to: {csv_filename}")
  1116. # 第二步:从 CSV 读取(可验证保存数据无误)
  1117. df = pd.read_csv(csv_filename)
  1118. # 创建图表(第三步)
  1119. import matplotlib.pyplot as plt
  1120. plt.figure(figsize=(12, 8), constrained_layout=True)
  1121. # ===== 子图1:Yaw Rate =====
  1122. ax1 = plt.subplot(2, 1, 1)
  1123. ax1.plot(df['simTime'], df['speedH'], 'g-', label='Yaw Rate')
  1124. # 添加 speedH 上下限阈值线
  1125. ax1.axhline(y=2.3, color='m', linestyle='--', linewidth=1.2, label='Max Threshold (+2.3)')
  1126. ax1.axhline(y=-2.3, color='r', linestyle='--', linewidth=1.2, label='Min Threshold (-2.3)')
  1127. # 添加橙色背景:Zigzag Events
  1128. for idx, event in zigzag_events.iterrows():
  1129. label = 'Zigzag Event' if idx == 0 else None
  1130. ax1.axvspan(event['start_time'], event['end_time'],
  1131. alpha=0.3, color='orange', label=label)
  1132. ax1.set_xlabel('Time (s)')
  1133. ax1.set_ylabel('Angular Velocity (deg/s)')
  1134. ax1.set_title('Zigzag Event Detection - Yaw Rate')
  1135. ax1.grid(True)
  1136. ax1.legend(loc='upper left')
  1137. # ===== 子图2:Yaw Angle =====
  1138. ax2 = plt.subplot(2, 1, 2)
  1139. ax2.plot(df['simTime'], df['posH'], 'b-', label='Yaw')
  1140. # 添加橙色背景:Zigzag Events
  1141. for idx, event in zigzag_events.iterrows():
  1142. label = 'Zigzag Event' if idx == 0 else None
  1143. ax2.axvspan(event['start_time'], event['end_time'],
  1144. alpha=0.3, color='orange', label=label)
  1145. ax2.set_xlabel('Time (s)')
  1146. ax2.set_ylabel('Yaw (deg)')
  1147. ax2.set_title('Zigzag Event Detection - Yaw Angle')
  1148. ax2.grid(True)
  1149. ax2.legend(loc='upper left')
  1150. # 保存图像
  1151. chart_filename = os.path.join(output_dir, f"zigzag_chart.png")
  1152. plt.savefig(chart_filename, dpi=300)
  1153. plt.close()
  1154. logger.info(f"Zigzag chart saved to: {chart_filename}")
  1155. return csv_filename
  1156. except Exception as e:
  1157. logger.error(f"Failed to generate zigzag chart: {str(e)}", exc_info=True)
  1158. return None
  1159. def generate_cadence_chart(comfort_calculator, output_dir: str) -> Optional[str]:
  1160. """
  1161. Generate cadence metric chart with orange background for cadence events.
  1162. This version first saves data to CSV, then uses the CSV to generate the chart.
  1163. Args:
  1164. comfort_calculator: ComfortCalculator instance
  1165. output_dir: Output directory
  1166. Returns:
  1167. str: Chart file path, or None if generation fails
  1168. """
  1169. logger = LogManager().get_logger()
  1170. try:
  1171. # 获取数据
  1172. df = comfort_calculator.ego_df.copy()
  1173. cadence_events = comfort_calculator.discomfort_df[comfort_calculator.discomfort_df['type'] == 'cadence'].copy()
  1174. if df.empty:
  1175. logger.warning("Cannot generate cadence chart: empty data")
  1176. return None
  1177. # 生成时间戳
  1178. # 保存 CSV 数据(第一步)
  1179. csv_filename = os.path.join(output_dir, f"cadence_data.csv")
  1180. df_csv = pd.DataFrame({
  1181. 'simTime': df['simTime'],
  1182. 'lon_acc': df['lon_acc'],
  1183. 'v': df['v'],
  1184. 'ip_acc': df.get('ip_acc', pd.Series([None] * len(df))),
  1185. 'ip_dec': df.get('ip_dec', pd.Series([None] * len(df)))
  1186. })
  1187. df_csv.to_csv(csv_filename, index=False)
  1188. logger.info(f"Cadence data saved to: {csv_filename}")
  1189. # 第二步:从 CSV 读取(可验证保存数据无误)
  1190. df = pd.read_csv(csv_filename)
  1191. # 创建图表(第三步)
  1192. import matplotlib.pyplot as plt
  1193. plt.figure(figsize=(12, 8), constrained_layout=True)
  1194. # 图 1:纵向加速度
  1195. ax1 = plt.subplot(2, 1, 1)
  1196. ax1.plot(df['simTime'], df['lon_acc'], 'b-', label='Longitudinal Acceleration')
  1197. if 'ip_acc' in df.columns and 'ip_dec' in df.columns:
  1198. ax1.plot(df['simTime'], df['ip_acc'], 'r--', label='Acceleration Threshold')
  1199. ax1.plot(df['simTime'], df['ip_dec'], 'g--', label='Deceleration Threshold')
  1200. # 添加橙色背景标识顿挫事件
  1201. for idx, event in cadence_events.iterrows():
  1202. label = 'Cadence Event' if idx == 0 else None
  1203. ax1.axvspan(event['start_time'], event['end_time'],
  1204. alpha=0.3, color='orange', label=label)
  1205. ax1.set_xlabel('Time (s)')
  1206. ax1.set_ylabel('Longitudinal Acceleration (m/s²)')
  1207. ax1.set_title('Cadence Event Detection - Longitudinal Acceleration')
  1208. ax1.grid(True)
  1209. ax1.legend()
  1210. # 图 2:速度
  1211. ax2 = plt.subplot(2, 1, 2)
  1212. ax2.plot(df['simTime'], df['v'], 'g-', label='Velocity')
  1213. # 添加橙色背景标识顿挫事件
  1214. for idx, event in cadence_events.iterrows():
  1215. label = 'Cadence Event' if idx == 0 else None
  1216. ax2.axvspan(event['start_time'], event['end_time'],
  1217. alpha=0.3, color='orange', label=label)
  1218. ax2.set_xlabel('Time (s)')
  1219. ax2.set_ylabel('Velocity (m/s)')
  1220. ax2.set_title('Cadence Event Detection - Vehicle Speed')
  1221. ax2.grid(True)
  1222. ax2.legend()
  1223. # 保存图像
  1224. chart_filename = os.path.join(output_dir, f"cadence_chart.png")
  1225. plt.savefig(chart_filename, dpi=300)
  1226. plt.close()
  1227. logger.info(f"Cadence chart saved to: {chart_filename}")
  1228. return chart_filename
  1229. except Exception as e:
  1230. logger.error(f"Failed to generate cadence chart: {str(e)}", exc_info=True)
  1231. return None
  1232. def generate_slam_brake_chart(comfort_calculator, output_dir: str) -> Optional[str]:
  1233. """
  1234. Generate slam brake metric chart with orange background for slam brake events.
  1235. This version first saves data to CSV, then uses the CSV to generate the chart.
  1236. Args:
  1237. comfort_calculator: ComfortCalculator instance
  1238. output_dir: Output directory
  1239. Returns:
  1240. str: Chart file path, or None if generation fails
  1241. """
  1242. logger = LogManager().get_logger()
  1243. try:
  1244. # 获取数据
  1245. df = comfort_calculator.ego_df.copy()
  1246. slam_brake_events = comfort_calculator.discomfort_df[
  1247. comfort_calculator.discomfort_df['type'] == 'slam_brake'].copy()
  1248. if df.empty:
  1249. logger.warning("Cannot generate slam brake chart: empty data")
  1250. return None
  1251. # 生成时间戳
  1252. # 保存 CSV 数据(第一步)
  1253. csv_filename = os.path.join(output_dir, f"slam_brake_data.csv")
  1254. df_csv = pd.DataFrame({
  1255. 'simTime': df['simTime'],
  1256. 'lon_acc': df['lon_acc'],
  1257. 'v': df['v'],
  1258. 'min_threshold': df.get('ip_dec', pd.Series([None] * len(df))),
  1259. 'max_threshold': 0.0
  1260. })
  1261. df_csv.to_csv(csv_filename, index=False)
  1262. logger.info(f"Slam brake data saved to: {csv_filename}")
  1263. # 第二步:从 CSV 读取(可验证保存数据无误)
  1264. df = pd.read_csv(csv_filename)
  1265. # 创建图表(第三步)
  1266. plt.figure(figsize=(12, 8), constrained_layout=True)
  1267. # 图 1:纵向加速度
  1268. ax1 = plt.subplot(2, 1, 1)
  1269. ax1.plot(df['simTime'], df['lon_acc'], 'b-', label='Longitudinal Acceleration')
  1270. if 'min_threshold' in df.columns:
  1271. ax1.plot(df['simTime'], df['min_threshold'], 'r--', label='Deceleration Threshold')
  1272. # 添加橙色背景标识急刹车事件
  1273. for idx, event in slam_brake_events.iterrows():
  1274. label = 'Slam Brake Event' if idx == 0 else None
  1275. ax1.axvspan(event['start_time'], event['end_time'],
  1276. alpha=0.3, color='orange', label=label)
  1277. ax1.set_xlabel('Time (s)')
  1278. ax1.set_ylabel('Longitudinal Acceleration (m/s²)')
  1279. ax1.set_title('Slam Brake Event Detection - Longitudinal Acceleration')
  1280. ax1.grid(True)
  1281. ax1.legend()
  1282. # 图 2:速度
  1283. ax2 = plt.subplot(2, 1, 2)
  1284. ax2.plot(df['simTime'], df['v'], 'g-', label='Velocity')
  1285. # 添加橙色背景标识急刹车事件
  1286. for idx, event in slam_brake_events.iterrows():
  1287. label = 'Slam Brake Event' if idx == 0 else None
  1288. ax2.axvspan(event['start_time'], event['end_time'],
  1289. alpha=0.3, color='orange', label=label)
  1290. ax2.set_xlabel('Time (s)')
  1291. ax2.set_ylabel('Velocity (m/s)')
  1292. ax2.set_title('Slam Brake Event Detection - Vehicle Speed')
  1293. ax2.grid(True)
  1294. ax2.legend()
  1295. # 保存图像
  1296. chart_filename = os.path.join(output_dir, f"slam_brake_chart.png")
  1297. plt.savefig(chart_filename, dpi=300)
  1298. plt.close()
  1299. logger.info(f"Slam brake chart saved to: {chart_filename}")
  1300. return chart_filename
  1301. except Exception as e:
  1302. logger.error(f"Failed to generate slam brake chart: {str(e)}", exc_info=True)
  1303. return None
  1304. def generate_slam_accelerate_chart(comfort_calculator, output_dir: str) -> Optional[str]:
  1305. """
  1306. Generate slam accelerate metric chart with orange background for slam accelerate events.
  1307. This version first saves data to CSV, then uses the CSV to generate the chart.
  1308. Args:
  1309. comfort_calculator: ComfortCalculator instance
  1310. output_dir: Output directory
  1311. Returns:
  1312. str: Chart file path, or None if generation fails
  1313. """
  1314. logger = LogManager().get_logger()
  1315. try:
  1316. # 获取数据
  1317. df = comfort_calculator.ego_df.copy()
  1318. slam_accel_events = comfort_calculator.discomfort_df[
  1319. (comfort_calculator.discomfort_df['type'] == 'slam_accel')
  1320. ].copy()
  1321. if df.empty:
  1322. logger.warning("Cannot generate slam accelerate chart: empty data")
  1323. return None
  1324. # 生成时间戳
  1325. # 保存 CSV 数据(第一步)
  1326. csv_filename = os.path.join(output_dir, f"slam_accel_data.csv")
  1327. # 获取加速度阈值(如果存在)
  1328. accel_threshold = df.get('ip_acc', pd.Series([None] * len(df)))
  1329. df_csv = pd.DataFrame({
  1330. 'simTime': df['simTime'],
  1331. 'lon_acc': df['lon_acc'],
  1332. 'v': df['v'],
  1333. 'min_threshold': 0.0, # 加速度最小阈值设为0
  1334. 'max_threshold': accel_threshold # 急加速阈值
  1335. })
  1336. df_csv.to_csv(csv_filename, index=False)
  1337. logger.info(f"Slam accelerate data saved to: {csv_filename}")
  1338. # 第二步:从 CSV 读取(可验证保存数据无误)
  1339. df = pd.read_csv(csv_filename)
  1340. # 创建图表(第三步)
  1341. plt.figure(figsize=(12, 8), constrained_layout=True)
  1342. # 图 1:纵向加速度
  1343. ax1 = plt.subplot(2, 1, 1)
  1344. ax1.plot(df['simTime'], df['lon_acc'], 'b-', label='Longitudinal Acceleration')
  1345. # 添加加速度阈值线
  1346. if 'max_threshold' in df.columns and not df['max_threshold'].isnull().all():
  1347. ax1.plot(df['simTime'], df['max_threshold'], 'r--', label='Acceleration Threshold')
  1348. # 添加橙色背景标识急加速事件
  1349. for idx, event in slam_accel_events.iterrows():
  1350. label = 'Slam Accelerate Event' if idx == 0 else None
  1351. ax1.axvspan(event['start_time'], event['end_time'],
  1352. alpha=0.3, color='orange', label=label)
  1353. ax1.set_xlabel('Time (s)')
  1354. ax1.set_ylabel('Acceleration (m/s²)')
  1355. ax1.set_title('Slam Accelerate Event Detection - Longitudinal Acceleration')
  1356. ax1.grid(True)
  1357. ax1.legend()
  1358. # 图 2:速度
  1359. ax2 = plt.subplot(2, 1, 2)
  1360. ax2.plot(df['simTime'], df['v'], 'g-', label='Velocity')
  1361. # 添加橙色背景标识急加速事件
  1362. for idx, event in slam_accel_events.iterrows():
  1363. label = 'Slam Accelerate Event' if idx == 0 else None
  1364. ax2.axvspan(event['start_time'], event['end_time'],
  1365. alpha=0.3, color='orange', label=label)
  1366. ax2.set_xlabel('Time (s)')
  1367. ax2.set_ylabel('Velocity (m/s)')
  1368. ax2.set_title('Slam Accelerate Event Detection - Vehicle Speed')
  1369. ax2.grid(True)
  1370. ax2.legend()
  1371. # 保存图像
  1372. chart_filename = os.path.join(output_dir, f"slam_accel_chart.png")
  1373. plt.savefig(chart_filename, dpi=300)
  1374. plt.close()
  1375. logger.info(f"Slam accelerate chart saved to: {chart_filename}")
  1376. return chart_filename
  1377. except Exception as e:
  1378. logger.error(f"Failed to generate slam accelerate chart: {str(e)}", exc_info=True)
  1379. return None
  1380. def get_metric_thresholds(calculator, metric_name: str) -> dict:
  1381. """
  1382. 从配置文件中获取指标的阈值
  1383. Args:
  1384. calculator: Calculator instance (FunctionCalculator, SafetyCalculator, ComfortCalculator, EfficientCalculator, TrafficCalculator)
  1385. metric_name: 指标名称
  1386. Returns:
  1387. dict: 包含min和max阈值的字典
  1388. """
  1389. logger = LogManager().get_logger()
  1390. thresholds = {"min": None, "max": None}
  1391. try:
  1392. # 根据计算器类型获取配置
  1393. if hasattr(calculator, 'data_processed'):
  1394. # 检查安全性指标配置
  1395. if hasattr(calculator.data_processed,
  1396. 'safety_config') and 'safety' in calculator.data_processed.safety_config:
  1397. config = calculator.data_processed.safety_config['safety']
  1398. metric_type = 'safety'
  1399. # 检查舒适性指标配置
  1400. elif hasattr(calculator.data_processed,
  1401. 'comfort_config') and 'comfort' in calculator.data_processed.comfort_config:
  1402. config = calculator.data_processed.comfort_config['comfort']
  1403. metric_type = 'comfort'
  1404. # 检查功能性指标配置
  1405. elif hasattr(calculator.data_processed,
  1406. 'function_config') and 'function' in calculator.data_processed.function_config:
  1407. config = calculator.data_processed.function_config['function']
  1408. metric_type = 'function'
  1409. # 检查高效性指标配置
  1410. elif hasattr(calculator.data_processed,
  1411. 'efficient_config') and 'efficient' in calculator.data_processed.efficient_config:
  1412. config = calculator.data_processed.efficient_config['efficient']
  1413. metric_type = 'efficient'
  1414. # 检查交通性指标配置
  1415. elif hasattr(calculator.data_processed,
  1416. 'traffic_config') and 'traffic' in calculator.data_processed.traffic_config:
  1417. config = calculator.data_processed.traffic_config['traffic']
  1418. metric_type = 'traffic'
  1419. else:
  1420. # 直接检查calculator是否有function_config属性(针对FunctionCalculator)
  1421. if hasattr(calculator, 'function_config') and 'function' in calculator.function_config:
  1422. config = calculator.function_config['function']
  1423. metric_type = 'function'
  1424. else:
  1425. logger.warning(f"无法找到{metric_name}的配置信息")
  1426. return thresholds
  1427. else:
  1428. # 直接检查calculator是否有function_config属性(针对FunctionCalculator)
  1429. if hasattr(calculator, 'function_config') and 'function' in calculator.function_config:
  1430. config = calculator.function_config['function']
  1431. metric_type = 'function'
  1432. else:
  1433. logger.warning(f"计算器没有data_processed属性或function_config属性")
  1434. return thresholds
  1435. # 递归查找指标配置
  1436. def find_metric_config(node, target_name):
  1437. if isinstance(node, dict):
  1438. if 'name' in node and node['name'].lower() == target_name.lower() and 'min' in node and 'max' in node:
  1439. return node
  1440. for key, value in node.items():
  1441. result = find_metric_config(value, target_name)
  1442. if result:
  1443. return result
  1444. return None
  1445. # 查找指标配置
  1446. metric_config = find_metric_config(config, metric_name)
  1447. if metric_config:
  1448. thresholds["min"] = metric_config.get("min")
  1449. thresholds["max"] = metric_config.get("max")
  1450. logger.info(f"找到{metric_name}的阈值: min={thresholds['min']}, max={thresholds['max']}")
  1451. else:
  1452. logger.warning(f"在{metric_type}配置中未找到{metric_name}的阈值信息")
  1453. except Exception as e:
  1454. logger.error(f"获取{metric_name}阈值时出错: {str(e)}", exc_info=True)
  1455. return thresholds
  1456. def generate_safety_chart_data(safety_calculator, metric_name: str, output_dir: Optional[str] = None) -> Optional[str]:
  1457. """
  1458. Generate chart data for safety metrics
  1459. Args:
  1460. safety_calculator: SafetyCalculator instance
  1461. metric_name: Metric name
  1462. output_dir: Output directory
  1463. Returns:
  1464. str: Chart file path, or None if generation fails
  1465. """
  1466. logger = LogManager().get_logger()
  1467. try:
  1468. # 确保输出目录存在
  1469. if output_dir:
  1470. os.makedirs(output_dir, exist_ok=True)
  1471. else:
  1472. output_dir = os.getcwd()
  1473. # 根据指标名称选择不同的图表生成方法
  1474. if metric_name.lower() == 'ttc':
  1475. return generate_ttc_chart(safety_calculator, output_dir)
  1476. elif metric_name.lower() == 'mttc':
  1477. return generate_mttc_chart(safety_calculator, output_dir)
  1478. elif metric_name.lower() == 'thw':
  1479. return generate_thw_chart(safety_calculator, output_dir)
  1480. elif metric_name.lower() == 'lonsd':
  1481. return generate_lonsd_chart(safety_calculator, output_dir)
  1482. elif metric_name.lower() == 'latsd':
  1483. return generate_latsd_chart(safety_calculator, output_dir)
  1484. elif metric_name.lower() == 'btn':
  1485. return generate_btn_chart(safety_calculator, output_dir)
  1486. elif metric_name.lower() == 'collisionrisk':
  1487. return generate_collision_risk_chart(safety_calculator, output_dir)
  1488. elif metric_name.lower() == 'collisionseverity':
  1489. return generate_collision_severity_chart(safety_calculator, output_dir)
  1490. else:
  1491. logger.warning(f"Chart generation not implemented for metric [{metric_name}]")
  1492. return None
  1493. except Exception as e:
  1494. logger.error(f"Failed to generate chart data: {str(e)}", exc_info=True)
  1495. return None
  1496. def generate_ttc_chart(safety_calculator, output_dir: str) -> Optional[str]:
  1497. """
  1498. Generate TTC metric chart with orange background for unsafe events.
  1499. This version first saves data to CSV, then uses the CSV to generate the chart.
  1500. Args:
  1501. safety_calculator: SafetyCalculator instance
  1502. output_dir: Output directory
  1503. Returns:
  1504. str: Chart file path, or None if generation fails
  1505. """
  1506. logger = LogManager().get_logger()
  1507. try:
  1508. # 获取数据
  1509. ttc_data = safety_calculator.ttc_data
  1510. if not ttc_data:
  1511. logger.warning("Cannot generate TTC chart: empty data")
  1512. return None
  1513. # 创建DataFrame
  1514. df = pd.DataFrame(ttc_data)
  1515. # 获取阈值
  1516. thresholds = get_metric_thresholds(safety_calculator, 'TTC')
  1517. min_threshold = thresholds.get('min')
  1518. max_threshold = thresholds.get('max')
  1519. # 生成时间戳
  1520. # 保存 CSV 数据(第一步)
  1521. csv_filename = os.path.join(output_dir, f"ttc_data.csv")
  1522. df_csv = pd.DataFrame({
  1523. 'simTime': df['simTime'],
  1524. 'simFrame': df['simFrame'],
  1525. 'TTC': df['TTC'],
  1526. 'min_threshold': min_threshold,
  1527. 'max_threshold': max_threshold
  1528. })
  1529. df_csv.to_csv(csv_filename, index=False)
  1530. logger.info(f"TTC data saved to: {csv_filename}")
  1531. # 第二步:从 CSV 读取(可验证保存数据无误)
  1532. df = pd.read_csv(csv_filename)
  1533. # 检测超阈值事件
  1534. unsafe_events = []
  1535. if min_threshold is not None:
  1536. # 对于TTC,小于最小阈值视为不安全
  1537. unsafe_condition = df['TTC'] < min_threshold
  1538. event_groups = (unsafe_condition != unsafe_condition.shift()).cumsum()
  1539. for _, group in df[unsafe_condition].groupby(event_groups):
  1540. if len(group) >= 2: # 至少2帧才算一次事件
  1541. start_time = group['simTime'].iloc[0]
  1542. end_time = group['simTime'].iloc[-1]
  1543. duration = end_time - start_time
  1544. if duration >= 0.1: # 只记录持续时间超过0.1秒的事件
  1545. unsafe_events.append({
  1546. 'start_time': start_time,
  1547. 'end_time': end_time,
  1548. 'start_frame': group['simFrame'].iloc[0],
  1549. 'end_frame': group['simFrame'].iloc[-1],
  1550. 'duration': duration,
  1551. 'min_ttc': group['TTC'].min()
  1552. })
  1553. # 创建图表(第三步)
  1554. plt.figure(figsize=(12, 8))
  1555. plt.plot(df['simTime'], df['TTC'], 'b-', label='TTC')
  1556. # 添加阈值线
  1557. if min_threshold is not None:
  1558. plt.axhline(y=min_threshold, color='r', linestyle='--', label=f'Min Threshold ({min_threshold}s)')
  1559. if max_threshold is not None:
  1560. plt.axhline(y=max_threshold, color='g', linestyle='--', label=f'Max Threshold ({max_threshold})')
  1561. # 添加橙色背景标识不安全事件
  1562. for idx, event in enumerate(unsafe_events):
  1563. label = 'Unsafe TTC Event' if idx == 0 else None
  1564. plt.axvspan(event['start_time'], event['end_time'],
  1565. alpha=0.3, color='orange', label=label)
  1566. plt.xlabel('Time (s)')
  1567. plt.ylabel('TTC (s)')
  1568. plt.title('Time To Collision (TTC) Trend')
  1569. plt.grid(True)
  1570. plt.legend()
  1571. # 保存图像
  1572. chart_filename = os.path.join(output_dir, f"ttc_chart.png")
  1573. plt.savefig(chart_filename, dpi=300)
  1574. plt.close()
  1575. # 记录不安全事件信息
  1576. if unsafe_events:
  1577. logger.info(f"检测到 {len(unsafe_events)} 个TTC不安全事件")
  1578. for i, event in enumerate(unsafe_events):
  1579. logger.info(
  1580. f"TTC不安全事件 #{i + 1}: 开始时间={event['start_time']:.2f}s, 结束时间={event['end_time']:.2f}s, 持续时间={event['duration']:.2f}s, 最小TTC={event['min_ttc']:.2f}s")
  1581. logger.info(f"TTC chart saved to: {chart_filename}")
  1582. return chart_filename
  1583. except Exception as e:
  1584. logger.error(f"Failed to generate TTC chart: {str(e)}", exc_info=True)
  1585. return None
  1586. def generate_mttc_chart(safety_calculator, output_dir: str) -> Optional[str]:
  1587. """
  1588. Generate MTTC metric chart with orange background for unsafe events
  1589. Args:
  1590. safety_calculator: SafetyCalculator instance
  1591. output_dir: Output directory
  1592. Returns:
  1593. str: Chart file path, or None if generation fails
  1594. """
  1595. logger = LogManager().get_logger()
  1596. try:
  1597. # 获取数据
  1598. mttc_data = safety_calculator.mttc_data
  1599. if not mttc_data:
  1600. logger.warning("Cannot generate MTTC chart: empty data")
  1601. return None
  1602. # 创建DataFrame
  1603. df = pd.DataFrame(mttc_data)
  1604. # 获取阈值
  1605. thresholds = get_metric_thresholds(safety_calculator, 'MTTC')
  1606. min_threshold = thresholds.get('min')
  1607. max_threshold = thresholds.get('max')
  1608. # 检测超阈值事件
  1609. unsafe_events = []
  1610. if min_threshold is not None:
  1611. # 对于MTTC,小于最小阈值视为不安全
  1612. unsafe_condition = df['MTTC'] < min_threshold
  1613. event_groups = (unsafe_condition != unsafe_condition.shift()).cumsum()
  1614. for _, group in df[unsafe_condition].groupby(event_groups):
  1615. if len(group) >= 2: # 至少2帧才算一次事件
  1616. start_time = group['simTime'].iloc[0]
  1617. end_time = group['simTime'].iloc[-1]
  1618. duration = end_time - start_time
  1619. if duration >= 0.1: # 只记录持续时间超过0.1秒的事件
  1620. unsafe_events.append({
  1621. 'start_time': start_time,
  1622. 'end_time': end_time,
  1623. 'start_frame': group['simFrame'].iloc[0],
  1624. 'end_frame': group['simFrame'].iloc[-1],
  1625. 'duration': duration,
  1626. 'min_mttc': group['MTTC'].min()
  1627. })
  1628. # 创建图表
  1629. plt.figure(figsize=(12, 6))
  1630. plt.plot(df['simTime'], df['MTTC'], 'g-', label='MTTC')
  1631. # 添加阈值线
  1632. if min_threshold is not None:
  1633. plt.axhline(y=min_threshold, color='r', linestyle='--', label=f'Min Threshold ({min_threshold}s)')
  1634. if max_threshold is not None:
  1635. plt.axhline(y=max_threshold, color='g', linestyle='--', label=f'Max Threshold ({max_threshold})')
  1636. # 添加橙色背景标识不安全事件
  1637. for idx, event in enumerate(unsafe_events):
  1638. label = 'Unsafe MTTC Event' if idx == 0 else None
  1639. plt.axvspan(event['start_time'], event['end_time'],
  1640. alpha=0.3, color='orange', label=label)
  1641. plt.xlabel('Time (s)')
  1642. plt.ylabel('MTTC (s)')
  1643. plt.title('Modified Time To Collision (MTTC) Trend')
  1644. plt.grid(True)
  1645. plt.legend()
  1646. # 保存图表
  1647. chart_filename = os.path.join(output_dir, f"mttc_chart.png")
  1648. plt.savefig(chart_filename, dpi=300)
  1649. plt.close()
  1650. # 保存CSV数据,包含阈值信息
  1651. csv_filename = os.path.join(output_dir, f"mttc_data.csv")
  1652. df_csv = df.copy()
  1653. df_csv['min_threshold'] = min_threshold
  1654. df_csv['max_threshold'] = max_threshold
  1655. df_csv.to_csv(csv_filename, index=False)
  1656. # 记录不安全事件信息
  1657. if unsafe_events:
  1658. logger.info(f"检测到 {len(unsafe_events)} 个MTTC不安全事件")
  1659. for i, event in enumerate(unsafe_events):
  1660. logger.info(
  1661. f"MTTC不安全事件 #{i + 1}: 开始时间={event['start_time']:.2f}s, 结束时间={event['end_time']:.2f}s, 持续时间={event['duration']:.2f}s, 最小MTTC={event['min_mttc']:.2f}s")
  1662. logger.info(f"MTTC chart saved to: {chart_filename}")
  1663. logger.info(f"MTTC data saved to: {csv_filename}")
  1664. return chart_filename
  1665. except Exception as e:
  1666. logger.error(f"Failed to generate MTTC chart: {str(e)}", exc_info=True)
  1667. return None
  1668. def generate_thw_chart(safety_calculator, output_dir: str) -> Optional[str]:
  1669. """
  1670. Generate THW metric chart with orange background for unsafe events
  1671. Args:
  1672. safety_calculator: SafetyCalculator instance
  1673. output_dir: Output directory
  1674. Returns:
  1675. str: Chart file path, or None if generation fails
  1676. """
  1677. logger = LogManager().get_logger()
  1678. try:
  1679. # 获取数据
  1680. thw_data = safety_calculator.thw_data
  1681. if not thw_data:
  1682. logger.warning("Cannot generate THW chart: empty data")
  1683. return None
  1684. # 创建DataFrame
  1685. df = pd.DataFrame(thw_data)
  1686. # 获取阈值
  1687. thresholds = get_metric_thresholds(safety_calculator, 'THW')
  1688. min_threshold = thresholds.get('min')
  1689. max_threshold = thresholds.get('max')
  1690. # 检测超阈值事件
  1691. unsafe_events = []
  1692. if min_threshold is not None:
  1693. # 对于THW,小于最小阈值视为不安全
  1694. unsafe_condition = df['THW'] < min_threshold
  1695. event_groups = (unsafe_condition != unsafe_condition.shift()).cumsum()
  1696. for _, group in df[unsafe_condition].groupby(event_groups):
  1697. if len(group) >= 2: # 至少2帧才算一次事件
  1698. start_time = group['simTime'].iloc[0]
  1699. end_time = group['simTime'].iloc[-1]
  1700. duration = end_time - start_time
  1701. if duration >= 0.1: # 只记录持续时间超过0.1秒的事件
  1702. unsafe_events.append({
  1703. 'start_time': start_time,
  1704. 'end_time': end_time,
  1705. 'start_frame': group['simFrame'].iloc[0],
  1706. 'end_frame': group['simFrame'].iloc[-1],
  1707. 'duration': duration,
  1708. 'min_thw': group['THW'].min()
  1709. })
  1710. # 创建图表
  1711. plt.figure(figsize=(12, 10))
  1712. plt.plot(df['simTime'], df['THW'], 'c-', label='THW')
  1713. # 添加阈值线
  1714. if min_threshold is not None:
  1715. plt.axhline(y=min_threshold, color='r', linestyle='--', label=f'Min Threshold ({min_threshold}s)')
  1716. if max_threshold is not None:
  1717. plt.axhline(y=max_threshold, color='g', linestyle='--', label=f'Max Threshold ({max_threshold})')
  1718. # 添加橙色背景标识不安全事件
  1719. for idx, event in enumerate(unsafe_events):
  1720. label = 'Unsafe THW Event' if idx == 0 else None
  1721. plt.axvspan(event['start_time'], event['end_time'],
  1722. alpha=0.3, color='orange', label=label)
  1723. plt.xlabel('Time (s)')
  1724. plt.ylabel('THW (s)')
  1725. plt.title('Time Headway (THW) Trend')
  1726. plt.grid(True)
  1727. plt.legend()
  1728. # 保存图表
  1729. chart_filename = os.path.join(output_dir, f"thw_chart.png")
  1730. plt.savefig(chart_filename, dpi=300)
  1731. plt.close()
  1732. # 保存CSV数据,包含阈值信息
  1733. csv_filename = os.path.join(output_dir, f"thw_data.csv")
  1734. df_csv = df.copy()
  1735. df_csv['min_threshold'] = min_threshold
  1736. df_csv['max_threshold'] = max_threshold
  1737. df_csv.to_csv(csv_filename, index=False)
  1738. # 记录不安全事件信息
  1739. if unsafe_events:
  1740. logger.info(f"检测到 {len(unsafe_events)} 个THW不安全事件")
  1741. for i, event in enumerate(unsafe_events):
  1742. logger.info(
  1743. f"THW不安全事件 #{i + 1}: 开始时间={event['start_time']:.2f}s, 结束时间={event['end_time']:.2f}s, 持续时间={event['duration']:.2f}s, 最小THW={event['min_thw']:.2f}s")
  1744. logger.info(f"THW chart saved to: {chart_filename}")
  1745. logger.info(f"THW data saved to: {csv_filename}")
  1746. return chart_filename
  1747. except Exception as e:
  1748. logger.error(f"Failed to generate THW chart: {str(e)}", exc_info=True)
  1749. return None
  1750. def generate_lonsd_chart(safety_calculator, output_dir: str) -> Optional[str]:
  1751. """
  1752. Generate Longitudinal Safe Distance metric chart
  1753. Args:
  1754. safety_calculator: SafetyCalculator instance
  1755. output_dir: Output directory
  1756. Returns:
  1757. str: Chart file path, or None if generation fails
  1758. """
  1759. logger = LogManager().get_logger()
  1760. try:
  1761. # 获取数据
  1762. lonsd_data = safety_calculator.lonsd_data
  1763. if not lonsd_data:
  1764. logger.warning("Cannot generate Longitudinal Safe Distance chart: empty data")
  1765. return None
  1766. # 创建DataFrame
  1767. df = pd.DataFrame(lonsd_data)
  1768. # 获取阈值
  1769. thresholds = get_metric_thresholds(safety_calculator, 'LonSD')
  1770. min_threshold = thresholds.get('min')
  1771. max_threshold = thresholds.get('max')
  1772. # 创建图表
  1773. plt.figure(figsize=(12, 6))
  1774. plt.plot(df['simTime'], df['LonSD'], 'm-', label='Longitudinal Safe Distance')
  1775. # 添加阈值线
  1776. if min_threshold is not None:
  1777. plt.axhline(y=min_threshold, color='r', linestyle='--', label=f'Min Threshold ({min_threshold}m)')
  1778. if max_threshold is not None:
  1779. plt.axhline(y=max_threshold, color='g', linestyle='--', label=f'Max Threshold ({max_threshold}m)')
  1780. plt.xlabel('Time (s)')
  1781. plt.ylabel('Distance (m)')
  1782. plt.title('Longitudinal Safe Distance (LonSD) Trend')
  1783. plt.grid(True)
  1784. plt.legend()
  1785. # 保存图表
  1786. chart_filename = os.path.join(output_dir, f"lonsd_chart.png")
  1787. plt.savefig(chart_filename, dpi=300)
  1788. plt.close()
  1789. # 保存CSV数据,包含阈值信息
  1790. csv_filename = os.path.join(output_dir, f"lonsd_data.csv")
  1791. df_csv = df.copy()
  1792. df_csv['min_threshold'] = min_threshold
  1793. df_csv['max_threshold'] = max_threshold
  1794. df_csv.to_csv(csv_filename, index=False)
  1795. logger.info(f"Longitudinal Safe Distance chart saved to: {chart_filename}")
  1796. logger.info(f"Longitudinal Safe Distance data saved to: {csv_filename}")
  1797. return chart_filename
  1798. except Exception as e:
  1799. logger.error(f"Failed to generate Longitudinal Safe Distance chart: {str(e)}", exc_info=True)
  1800. return None
  1801. def generate_latsd_chart(safety_calculator, output_dir: str) -> Optional[str]:
  1802. """
  1803. Generate Lateral Safe Distance metric chart with orange background for unsafe events
  1804. Args:
  1805. safety_calculator: SafetyCalculator instance
  1806. output_dir: Output directory
  1807. Returns:
  1808. str: Chart file path, or None if generation fails
  1809. """
  1810. logger = LogManager().get_logger()
  1811. try:
  1812. # 获取数据
  1813. latsd_data = safety_calculator.latsd_data
  1814. if not latsd_data:
  1815. logger.warning("Cannot generate Lateral Safe Distance chart: empty data")
  1816. return None
  1817. # 创建DataFrame
  1818. df = pd.DataFrame(latsd_data)
  1819. # 获取阈值
  1820. thresholds = get_metric_thresholds(safety_calculator, 'LatSD')
  1821. min_threshold = thresholds.get('min')
  1822. max_threshold = thresholds.get('max')
  1823. # 检测超阈值事件
  1824. unsafe_events = []
  1825. if min_threshold is not None:
  1826. # 对于LatSD,小于最小阈值视为不安全
  1827. unsafe_condition = df['LatSD'] < min_threshold
  1828. event_groups = (unsafe_condition != unsafe_condition.shift()).cumsum()
  1829. for _, group in df[unsafe_condition].groupby(event_groups):
  1830. if len(group) >= 2: # 至少2帧才算一次事件
  1831. start_time = group['simTime'].iloc[0]
  1832. end_time = group['simTime'].iloc[-1]
  1833. duration = end_time - start_time
  1834. if duration >= 0.1: # 只记录持续时间超过0.1秒的事件
  1835. unsafe_events.append({
  1836. 'start_time': start_time,
  1837. 'end_time': end_time,
  1838. 'start_frame': group['simFrame'].iloc[0],
  1839. 'end_frame': group['simFrame'].iloc[-1],
  1840. 'duration': duration,
  1841. 'min_latsd': group['LatSD'].min()
  1842. })
  1843. # 创建图表
  1844. plt.figure(figsize=(12, 6))
  1845. plt.plot(df['simTime'], df['LatSD'], 'y-', label='Lateral Safe Distance')
  1846. # 添加阈值线
  1847. if min_threshold is not None:
  1848. plt.axhline(y=min_threshold, color='r', linestyle='--', label=f'Min Threshold ({min_threshold}m)')
  1849. if max_threshold is not None:
  1850. plt.axhline(y=max_threshold, color='g', linestyle='--', label=f'Max Threshold ({max_threshold}m)')
  1851. # 添加橙色背景标识不安全事件
  1852. for idx, event in enumerate(unsafe_events):
  1853. label = 'Unsafe LatSD Event' if idx == 0 else None
  1854. plt.axvspan(event['start_time'], event['end_time'],
  1855. alpha=0.3, color='orange', label=label)
  1856. plt.xlabel('Time (s)')
  1857. plt.ylabel('Distance (m)')
  1858. plt.title('Lateral Safe Distance (LatSD) Trend')
  1859. plt.grid(True)
  1860. plt.legend()
  1861. # 保存图表
  1862. chart_filename = os.path.join(output_dir, f"latsd_chart.png")
  1863. plt.savefig(chart_filename, dpi=300)
  1864. plt.close()
  1865. # 保存CSV数据,包含阈值信息
  1866. csv_filename = os.path.join(output_dir, f"latsd_data.csv")
  1867. df_csv = df.copy()
  1868. df_csv['min_threshold'] = min_threshold
  1869. df_csv['max_threshold'] = max_threshold
  1870. df_csv.to_csv(csv_filename, index=False)
  1871. # 记录不安全事件信息
  1872. if unsafe_events:
  1873. logger.info(f"检测到 {len(unsafe_events)} 个LatSD不安全事件")
  1874. for i, event in enumerate(unsafe_events):
  1875. logger.info(
  1876. f"LatSD不安全事件 #{i + 1}: 开始时间={event['start_time']:.2f}s, 结束时间={event['end_time']:.2f}s, 持续时间={event['duration']:.2f}s, 最小LatSD={event['min_latsd']:.2f}m")
  1877. logger.info(f"Lateral Safe Distance chart saved to: {chart_filename}")
  1878. logger.info(f"Lateral Safe Distance data saved to: {csv_filename}")
  1879. return chart_filename
  1880. except Exception as e:
  1881. logger.error(f"Failed to generate Lateral Safe Distance chart: {str(e)}", exc_info=True)
  1882. return None
  1883. def generate_btn_chart(safety_calculator, output_dir: str) -> Optional[str]:
  1884. """
  1885. Generate Brake Threat Number metric chart with orange background for unsafe events
  1886. Args:
  1887. safety_calculator: SafetyCalculator instance
  1888. output_dir: Output directory
  1889. Returns:
  1890. str: Chart file path, or None if generation fails
  1891. """
  1892. logger = LogManager().get_logger()
  1893. try:
  1894. # 获取数据
  1895. btn_data = safety_calculator.btn_data
  1896. if not btn_data:
  1897. logger.warning("Cannot generate Brake Threat Number chart: empty data")
  1898. return None
  1899. # 创建DataFrame
  1900. df = pd.DataFrame(btn_data)
  1901. # 获取阈值
  1902. thresholds = get_metric_thresholds(safety_calculator, 'BTN')
  1903. min_threshold = thresholds.get('min')
  1904. max_threshold = thresholds.get('max')
  1905. # 检测超阈值事件
  1906. unsafe_events = []
  1907. if max_threshold is not None:
  1908. # 对于BTN,大于最大阈值视为不安全
  1909. unsafe_condition = df['BTN'] > max_threshold
  1910. event_groups = (unsafe_condition != unsafe_condition.shift()).cumsum()
  1911. for _, group in df[unsafe_condition].groupby(event_groups):
  1912. if len(group) >= 2: # 至少2帧才算一次事件
  1913. start_time = group['simTime'].iloc[0]
  1914. end_time = group['simTime'].iloc[-1]
  1915. duration = end_time - start_time
  1916. if duration >= 0.1: # 只记录持续时间超过0.1秒的事件
  1917. unsafe_events.append({
  1918. 'start_time': start_time,
  1919. 'end_time': end_time,
  1920. 'start_frame': group['simFrame'].iloc[0],
  1921. 'end_frame': group['simFrame'].iloc[-1],
  1922. 'duration': duration,
  1923. 'max_btn': group['BTN'].max()
  1924. })
  1925. # 创建图表
  1926. plt.figure(figsize=(12, 6))
  1927. plt.plot(df['simTime'], df['BTN'], 'r-', label='Brake Threat Number')
  1928. # 添加阈值线
  1929. if min_threshold is not None:
  1930. plt.axhline(y=min_threshold, color='r', linestyle='--', label=f'Min Threshold ({min_threshold})')
  1931. if max_threshold is not None:
  1932. plt.axhline(y=max_threshold, color='g', linestyle='--', label=f'Max Threshold ({max_threshold})')
  1933. # 添加橙色背景标识不安全事件
  1934. for idx, event in enumerate(unsafe_events):
  1935. label = 'Unsafe BTN Event' if idx == 0 else None
  1936. plt.axvspan(event['start_time'], event['end_time'],
  1937. alpha=0.3, color='orange', label=label)
  1938. plt.xlabel('Time (s)')
  1939. plt.ylabel('BTN')
  1940. plt.title('Brake Threat Number (BTN) Trend')
  1941. plt.grid(True)
  1942. plt.legend()
  1943. # 保存图表
  1944. chart_filename = os.path.join(output_dir, f"btn_chart.png")
  1945. plt.savefig(chart_filename, dpi=300)
  1946. plt.close()
  1947. # 保存CSV数据,包含阈值信息
  1948. csv_filename = os.path.join(output_dir, f"btn_data.csv")
  1949. df_csv = df.copy()
  1950. df_csv['min_threshold'] = min_threshold
  1951. df_csv['max_threshold'] = max_threshold
  1952. df_csv.to_csv(csv_filename, index=False)
  1953. # 记录不安全事件信息
  1954. if unsafe_events:
  1955. logger.info(f"检测到 {len(unsafe_events)} 个BTN不安全事件")
  1956. for i, event in enumerate(unsafe_events):
  1957. logger.info(
  1958. f"BTN不安全事件 #{i + 1}: 开始时间={event['start_time']:.2f}s, 结束时间={event['end_time']:.2f}s, 持续时间={event['duration']:.2f}s, 最大BTN={event['max_btn']:.2f}")
  1959. logger.info(f"Brake Threat Number chart saved to: {chart_filename}")
  1960. logger.info(f"Brake Threat Number data saved to: {csv_filename}")
  1961. return chart_filename
  1962. except Exception as e:
  1963. logger.error(f"Failed to generate Brake Threat Number chart: {str(e)}", exc_info=True)
  1964. return None
  1965. def generate_collision_risk_chart(safety_calculator, output_dir: str) -> Optional[str]:
  1966. """
  1967. Generate Collision Risk metric chart
  1968. Args:
  1969. safety_calculator: SafetyCalculator instance
  1970. output_dir: Output directory
  1971. Returns:
  1972. str: Chart file path, or None if generation fails
  1973. """
  1974. logger = LogManager().get_logger()
  1975. try:
  1976. # 获取数据
  1977. risk_data = safety_calculator.collision_risk_data
  1978. if not risk_data:
  1979. logger.warning("Cannot generate Collision Risk chart: empty data")
  1980. return None
  1981. # 创建DataFrame
  1982. df = pd.DataFrame(risk_data)
  1983. # 获取阈值
  1984. thresholds = get_metric_thresholds(safety_calculator, 'collisionRisk')
  1985. min_threshold = thresholds.get('min')
  1986. max_threshold = thresholds.get('max')
  1987. # 创建图表
  1988. plt.figure(figsize=(12, 6))
  1989. plt.plot(df['simTime'], df['collisionRisk'], 'r-', label='Collision Risk')
  1990. # 添加阈值线
  1991. if min_threshold is not None:
  1992. plt.axhline(y=min_threshold, color='r', linestyle='--', label=f'Min Threshold ({min_threshold}%)')
  1993. if max_threshold is not None:
  1994. plt.axhline(y=max_threshold, color='g', linestyle='--', label=f'Max Threshold ({max_threshold}%)')
  1995. plt.xlabel('Time (s)')
  1996. plt.ylabel('Risk Value (%)')
  1997. plt.title('Collision Risk (collisionRisk) Trend')
  1998. plt.grid(True)
  1999. plt.legend()
  2000. # 保存图表
  2001. chart_filename = os.path.join(output_dir, f"collision_risk_chart.png")
  2002. plt.savefig(chart_filename, dpi=300)
  2003. plt.close()
  2004. # 保存CSV数据,包含阈值信息
  2005. csv_filename = os.path.join(output_dir, f"collisionrisk_data.csv")
  2006. df_csv = df.copy()
  2007. df_csv['min_threshold'] = min_threshold
  2008. df_csv['max_threshold'] = max_threshold
  2009. df_csv.to_csv(csv_filename, index=False)
  2010. logger.info(f"Collision Risk chart saved to: {chart_filename}")
  2011. logger.info(f"Collision Risk data saved to: {csv_filename}")
  2012. return chart_filename
  2013. except Exception as e:
  2014. logger.error(f"Failed to generate Collision Risk chart: {str(e)}", exc_info=True)
  2015. return None
  2016. def generate_collision_severity_chart(safety_calculator, output_dir: str) -> Optional[str]:
  2017. """
  2018. Generate Collision Severity metric chart
  2019. Args:
  2020. safety_calculator: SafetyCalculator instance
  2021. output_dir: Output directory
  2022. Returns:
  2023. str: Chart file path, or None if generation fails
  2024. """
  2025. logger = LogManager().get_logger()
  2026. try:
  2027. # 获取数据
  2028. severity_data = safety_calculator.collision_severity_data
  2029. if not severity_data:
  2030. logger.warning("Cannot generate Collision Severity chart: empty data")
  2031. return None
  2032. # 创建DataFrame
  2033. df = pd.DataFrame(severity_data)
  2034. # 获取阈值
  2035. thresholds = get_metric_thresholds(safety_calculator, 'collisionSeverity')
  2036. min_threshold = thresholds.get('min')
  2037. max_threshold = thresholds.get('max')
  2038. # 创建图表
  2039. plt.figure(figsize=(12, 6))
  2040. plt.plot(df['simTime'], df['collisionSeverity'], 'r-', label='Collision Severity')
  2041. # 添加阈值线
  2042. if min_threshold is not None:
  2043. plt.axhline(y=min_threshold, color='r', linestyle='--', label=f'Min Threshold ({min_threshold}%)')
  2044. if max_threshold is not None:
  2045. plt.axhline(y=max_threshold, color='g', linestyle='--', label=f'Max Threshold ({max_threshold}%)')
  2046. plt.xlabel('Time (s)')
  2047. plt.ylabel('Severity (%)')
  2048. plt.title('Collision Severity (collisionSeverity) Trend')
  2049. plt.grid(True)
  2050. plt.legend()
  2051. # 保存图表
  2052. chart_filename = os.path.join(output_dir, f"collision_severity_chart.png")
  2053. plt.savefig(chart_filename, dpi=300)
  2054. plt.close()
  2055. # 保存CSV数据,包含阈值信息
  2056. csv_filename = os.path.join(output_dir, f"collisionseverity_data.csv")
  2057. df_csv = df.copy()
  2058. df_csv['min_threshold'] = min_threshold
  2059. df_csv['max_threshold'] = max_threshold
  2060. df_csv.to_csv(csv_filename, index=False)
  2061. logger.info(f"Collision Severity chart saved to: {chart_filename}")
  2062. logger.info(f"Collision Severity data saved to: {csv_filename}")
  2063. return chart_filename
  2064. except Exception as e:
  2065. logger.error(f"Failed to generate Collision Severity chart: {str(e)}", exc_info=True)
  2066. return None
  2067. def generate_traffic_chart_data(traffic_calculator, metric_name: str, output_dir: Optional[str] = None) -> Optional[
  2068. str]:
  2069. """Generate chart data for traffic metrics"""
  2070. # 待实现
  2071. return None
  2072. def calculate_distance(ego_df, correctwarning):
  2073. """计算预警距离"""
  2074. dist = ego_df[(ego_df['ifwarning'] == correctwarning) & (ego_df['ifwarning'].notna())]['relative_dist']
  2075. return dist
  2076. def calculate_relative_speed(ego_df, correctwarning):
  2077. """计算相对速度"""
  2078. return ego_df[(ego_df['ifwarning'] == correctwarning) & (ego_df['ifwarning'].notna())]['composite_v']
# Reuse the scenario_sign_dict already implemented in function.py.
# NOTE(review): imported at the bottom of the module rather than with the other
# imports -- presumably to avoid a circular import; confirm before moving it.
from modules.metric.function import scenario_sign_dict
if __name__ == "__main__":
    # Minimal self-check: reaching this print means the module-level imports succeeded.
    print("Metrics visualization utilities loaded.")