Browse Source

开放保存画图csv的路径接口

XGJ_zhaoyuan 19 hours ago
parent
commit
ad2d91a42a

+ 8 - 11
modules/lib/chart_generator.py

@@ -95,10 +95,9 @@ def generate_function_chart_data(function_calculator, metric_name: str, output_d
 
     try:
         # 确保输出目录存在
-        if output_dir:
-            os.makedirs(output_dir, exist_ok=True)
-        else:
+        if not output_dir:
             output_dir = os.path.join(os.getcwd(), 'data')
+            os.makedirs(output_dir, exist_ok=True)
 
         # 根据指标名称选择不同的图表生成方法
         if metric_name.lower() == 'latestwarningdistance_ttc_lst':
@@ -427,6 +426,7 @@ def generate_latest_warning_ttc_pgvil_chart(function_calculator, output_dir: str
         logger.error(f"Failed to generate latestwarningdistance_ttc_pgvil chart: {str(e)}", exc_info=True)
         return None
 
+
 def generate_latest_warning_ttc_chart(function_calculator, output_dir: str) -> Optional[str]:
     """
     Generate TTC warning chart with data visualization.
@@ -534,6 +534,7 @@ def generate_latest_warning_ttc_chart(function_calculator, output_dir: str) -> O
         logger.error(f"Failed to generate latestwarningdistance_ttc_lst chart: {str(e)}", exc_info=True)
         return None
 
+
 def generate_latest_warning_distance_chart(function_calculator, output_dir: str) -> Optional[str]:
     """
     Generate warning distance chart with data visualization.
@@ -623,6 +624,7 @@ def generate_latest_warning_distance_chart(function_calculator, output_dir: str)
         logger.error(f"Failed to generate latestWarningDistance_LST chart: {str(e)}", exc_info=True)
         return None
 
+
 def generate_latest_warning_distance_pgvil_chart(function_calculator, output_dir: str) -> Optional[str]:
     """
     Generate warning distance chart with data visualization.
@@ -814,6 +816,7 @@ def generate_earliest_warning_distance_ttc_chart(function_calculator, output_dir
         logger.error(f"Failed to generate earliestwarningdistance_ttc_lst chart: {str(e)}", exc_info=True)
         return None
 
+
 def generate_earliest_warning_distance_ttc_pgvil_chart(function_calculator, output_dir: str) -> Optional[str]:
     """
     Generate TTC warning chart with data visualization for earliestWarningDistance_TTC_PGVIL metric.
@@ -917,6 +920,7 @@ def generate_earliest_warning_distance_ttc_pgvil_chart(function_calculator, outp
         logger.error(f"Failed to generate earliestwarningdistance_ttc_pgvil chart: {str(e)}", exc_info=True)
         return None
 
+
 def generate_limit_speed_chart(function_calculator, output_dir: str) -> Optional[str]:
     """
     Generate limit speed chart with data visualization for limitSpeed_LST metric.
@@ -2113,7 +2117,6 @@ def generate_mttc_chart(safety_calculator, output_dir: str) -> Optional[str]:
                 logger.info(
                     f"MTTC不安全事件 #{i + 1}: 开始时间={event['start_time']:.2f}s, 结束时间={event['end_time']:.2f}s, 持续时间={event['duration']:.2f}s, 最小MTTC={event['min_mttc']:.2f}s")
 
-        
         logger.info(f"MTTC data saved to: {csv_filename}")
         return csv_filename
 
@@ -2217,7 +2220,6 @@ def generate_thw_chart(safety_calculator, output_dir: str) -> Optional[str]:
                 logger.info(
                     f"THW不安全事件 #{i + 1}: 开始时间={event['start_time']:.2f}s, 结束时间={event['end_time']:.2f}s, 持续时间={event['duration']:.2f}s, 最小THW={event['min_thw']:.2f}s")
 
-        
         logger.info(f"THW data saved to: {csv_filename}")
         return csv_filename
 
@@ -2285,7 +2287,6 @@ def generate_lonsd_chart(safety_calculator, output_dir: str) -> Optional[str]:
         df_csv['max_threshold'] = max_threshold
         df_csv.to_csv(csv_filename, index=False)
 
-        
         logger.info(f"Longitudinal Safe Distance data saved to: {csv_filename}")
         return csv_filename
 
@@ -2389,7 +2390,6 @@ def generate_latsd_chart(safety_calculator, output_dir: str) -> Optional[str]:
                 logger.info(
                     f"LatSD不安全事件 #{i + 1}: 开始时间={event['start_time']:.2f}s, 结束时间={event['end_time']:.2f}s, 持续时间={event['duration']:.2f}s, 最小LatSD={event['min_latsd']:.2f}m")
 
-        
         logger.info(f"Lateral Safe Distance data saved to: {csv_filename}")
         return csv_filename
 
@@ -2493,7 +2493,6 @@ def generate_btn_chart(safety_calculator, output_dir: str) -> Optional[str]:
                 logger.info(
                     f"BTN不安全事件 #{i + 1}: 开始时间={event['start_time']:.2f}s, 结束时间={event['end_time']:.2f}s, 持续时间={event['duration']:.2f}s, 最大BTN={event['max_btn']:.2f}")
 
-        
         logger.info(f"Brake Threat Number data saved to: {csv_filename}")
         return csv_filename
 
@@ -2561,7 +2560,6 @@ def generate_collision_risk_chart(safety_calculator, output_dir: str) -> Optiona
         df_csv['max_threshold'] = max_threshold
         df_csv.to_csv(csv_filename, index=False)
 
-        
         logger.info(f"Collision Risk data saved to: {csv_filename}")
         return csv_filename
 
@@ -2620,7 +2618,7 @@ def generate_collision_severity_chart(safety_calculator, output_dir: str) -> Opt
         # chart_filename = os.path.join(output_dir, f"collision_severity_chart.png")
         # plt.savefig(chart_filename, dpi=300)
         # plt.close()
-        #logger.info(f"Collision Severity chart saved to: {chart_filename}")
+        # logger.info(f"Collision Severity chart saved to: {chart_filename}")
 
         # 保存CSV数据,包含阈值信息
         csv_filename = os.path.join(output_dir, f"collisionseverity_data.csv")
@@ -2629,7 +2627,6 @@ def generate_collision_severity_chart(safety_calculator, output_dir: str) -> Opt
         df_csv['max_threshold'] = max_threshold
         df_csv.to_csv(csv_filename, index=False)
 
-        
         logger.info(f"Collision Severity data saved to: {csv_filename}")
         return csv_filename
 

+ 60 - 60
modules/metric/comfort.py

@@ -157,98 +157,98 @@ COMFORT_INFO = [
 # 独立指标计算函数
 # ----------------------
 # 更新指标计算函数,返回事件次数而非指标值
-def calculate_motioncomfortindex(data_processed) -> dict:
+def calculate_motioncomfortindex(data_processed, plot_path) -> dict:
     """计算运动舒适度指数事件次数"""
     comfort = ComfortCalculator(data_processed)
     # 计算舒适度指数并检测事件
-    comfort.calculate_motion_comfort_index()
+    comfort.calculate_motion_comfort_index(plot_path)
     # 统计事件类型为'motionComfortIndex'的事件次数
     count = len(comfort.discomfort_df[comfort.discomfort_df['type'] == 'motionComfortIndex'])
     return {"motionComfortIndex": float(count)}
 
 
-def calculate_ridequalityscore(data_processed) -> dict:
+def calculate_ridequalityscore(data_processed, plot_path) -> dict:
     """计算乘坐质量评分事件次数"""
     comfort = ComfortCalculator(data_processed)
     # 计算乘坐质量评分并检测事件
-    comfort.calculate_ride_quality_score()
+    comfort.calculate_ride_quality_score(plot_path)
     # 统计事件类型为'rideQualityScore'的事件次数
     count = len(comfort.discomfort_df[comfort.discomfort_df['type'] == 'rideQualityScore'])
     return {"rideQualityScore": float(count)}
 
 
-def calculate_motionsickness(data_processed) -> dict:
+def calculate_motionsickness(data_processed, plot_path) -> dict:
     """计算晕车概率事件次数"""
     comfort = ComfortCalculator(data_processed)
     # 计算晕车概率并检测事件
-    comfort.calculate_motion_sickness_probability()
+    comfort.calculate_motion_sickness_probability(plot_path)
     # 统计事件类型为'motionSickness'的事件次数
     count = len(comfort.discomfort_df[comfort.discomfort_df['type'] == 'motionSickness'])
     return {"motionSickness": float(count)}
 
 
-def calculate_vdv(data_processed) -> dict:
+def calculate_vdv(data_processed, plot_path) -> dict:
     """计算振动剂量值(VDV)事件次数"""
     comfort = ComfortCalculator(data_processed)
     # 计算VDV并检测事件
-    comfort.calculate_vdv()
+    comfort.calculate_vdv(plot_path)
     # 统计事件类型为'vdv'的事件次数
     count = len(comfort.discomfort_df[comfort.discomfort_df['type'] == 'vdv'])
     return {"vdv": float(count)}
 
 
-def calculate_ava_vav(data_processed) -> dict:
+def calculate_ava_vav(data_processed, plot_path) -> dict:
     """计算多维度综合加权加速度事件次数"""
     comfort = ComfortCalculator(data_processed)
     # 计算AVA/VAV并检测事件
-    comfort.calculate_ava_vav()
+    comfort.calculate_ava_vav(plot_path)
     # 统计事件类型为'ava_vav'的事件次数
     count = len(comfort.discomfort_df[comfort.discomfort_df['type'] == 'ava_vav'])
     return {"ava_vav": float(count)}
 
 
-def calculate_msdv(data_processed) -> dict:
+def calculate_msdv(data_processed, plot_path) -> dict:
     """计算晕动剂量值(MSDV)事件次数"""
     comfort = ComfortCalculator(data_processed)
     # 计算MSDV并检测事件
-    comfort.calculate_msdv()
+    comfort.calculate_msdv(plot_path)
     # 统计事件类型为'msdv'的事件次数
     count = len(comfort.discomfort_df[comfort.discomfort_df['type'] == 'msdv'])
     return {"msdv": float(count)}
 
 
-def calculate_zigzag(data_processed) -> dict:
+def calculate_zigzag(data_processed, plot_path) -> dict:
     """计算蛇行指标"""
     comfort = ComfortCalculator(data_processed)
-    zigzag_count = comfort.calculate_zigzag_count()
+    zigzag_count = comfort.calculate_zigzag_count(plot_path)
     return {"zigzag": float(zigzag_count)}
 
 
-def calculate_shake(data_processed) -> dict:
+def calculate_shake(data_processed, plot_path) -> dict:
     """计算晃动指标"""
     comfort = ComfortCalculator(data_processed)
-    shake_count = comfort.calculate_shake_count()
+    shake_count = comfort.calculate_shake_count(plot_path)
     return {"shake": float(shake_count)}
 
 
-def calculate_cadence(data_processed) -> dict:
+def calculate_cadence(data_processed, plot_path) -> dict:
     """计算顿挫指标"""
     comfort = ComfortCalculator(data_processed)
-    cadence_count = comfort.calculate_cadence_count()
+    cadence_count = comfort.calculate_cadence_count(plot_path)
     return {"cadence": float(cadence_count)}
 
 
-def calculate_slambrake(data_processed) -> dict:
+def calculate_slambrake(data_processed, plot_path) -> dict:
     """计算急刹车指标"""
     comfort = ComfortCalculator(data_processed)
-    slam_brake_count = comfort.calculate_slam_brake_count()
+    slam_brake_count = comfort.calculate_slam_brake_count(plot_path)
     return {"slamBrake": float(slam_brake_count)}
 
 
-def calculate_slamaccelerate(data_processed) -> dict:
+def calculate_slamaccelerate(data_processed, plot_path) -> dict:
     """计算急加速指标"""
     comfort = ComfortCalculator(data_processed)
-    slam_accel_count = comfort.calculate_slam_accel_count()
+    slam_accel_count = comfort.calculate_slam_accel_count(plot_path)
     return {"slamAccelerate": float(slam_accel_count)}
 
 
@@ -284,11 +284,11 @@ def peak_valley_decorator(method):
 class ComfortRegistry:
     """舒适性指标注册器"""
 
-    def __init__(self, data_processed):
+    def __init__(self, data_processed, plot_path):
         self.logger = LogManager().get_logger()  # 获取全局日志实例
         self.data = data_processed
-        self.output_dir = None  # 图表数据输出目录
-        
+        self.output_dir = plot_path  # 图表数据输出目录
+
         # 检查comfort_config是否为空
         if not hasattr(data_processed, 'comfort_config') or not data_processed.comfort_config:
             self.logger.warning("舒适性配置为空,跳过舒适性指标计算")
@@ -296,7 +296,7 @@ class ComfortRegistry:
             self.metrics = []
             self._registry = {}
             return
-            
+
         self.comfort_config = data_processed.comfort_config.get("comfort", {})
         self.metrics = self._extract_metrics(self.comfort_config)
         self._registry = self._build_registry()
@@ -332,7 +332,7 @@ class ComfortRegistry:
         results = {}
         for name, func in self._registry.items():
             try:
-                result = func(self.data)
+                result = func(self.data, self.output_dir)
                 results.update(result)
                 # 新增:将每个指标的结果写入日志
                 self.logger.info(f'舒适性指标[{name}]计算结果: {result}')
@@ -344,15 +344,16 @@ class ComfortRegistry:
 class ComfortManager:
     """舒适性指标计算主类"""
 
-    def __init__(self, data_processed):
+    def __init__(self, data_processed, plot_path):
         self.data = data_processed
         self.logger = LogManager().get_logger()
+        self.plot_path = plot_path
         # 检查comfort_config是否为空
         if not hasattr(data_processed, 'comfort_config') or not data_processed.comfort_config:
             self.logger.warning("舒适性配置为空,跳过舒适性指标计算初始化")
             self.registry = None
         else:
-            self.registry = ComfortRegistry(self.data)
+            self.registry = ComfortRegistry(self.data, self.plot_path)
 
     def report_statistic(self):
         """生成舒适性评分报告"""
@@ -360,7 +361,7 @@ class ComfortManager:
         if self.registry is None:
             self.logger.info("舒适性指标管理器未初始化,返回空结果")
             return {}
-            
+
         comfort_result = self.registry.batch_execute()
 
         return comfort_result
@@ -368,7 +369,7 @@ class ComfortManager:
 class ComfortCalculator:
     """舒适性指标计算类 - 提供核心计算功能"""
 
-    def generate_metric_chart(self, metric_name: str) -> None:
+    def generate_metric_chart(self, metric_name: str, plot_path: Path) -> None:
         """
         生成指标图表
 
@@ -376,12 +377,12 @@ class ComfortCalculator:
             metric_name: 指标名称
         """
         # 设置输出目录
-        if not hasattr(self, 'output_dir') or not self.output_dir:
-            self.output_dir = os.path.join(os.getcwd(), 'data')
-            os.makedirs(self.output_dir, exist_ok=True)
+        if not plot_path:
+            plot_path = os.path.join(os.getcwd(), 'data')
+            os.makedirs(plot_path, exist_ok=True)
 
         # 调用chart_generator中的函数生成图表
-        chart_path = generate_comfort_chart_data(self, metric_name, self.output_dir)
+        chart_path = generate_comfort_chart_data(self, metric_name, plot_path)
         if chart_path:
             self.logger.info(f"{metric_name}图表已生成: {chart_path}")
 
@@ -547,7 +548,7 @@ class ComfortCalculator:
 
         return pd.Series(filtered_data, index=acceleration_data.index)
 
-    def calculate_motion_comfort_index(self):
+    def calculate_motion_comfort_index(self, plot_path):
         """
         计算运动舒适度指数(Motion Comfort Index)并检测低舒适度事件
 
@@ -631,7 +632,7 @@ class ComfortCalculator:
 
         return comfort_index
 
-    def calculate_ride_quality_score(self):
+    def calculate_ride_quality_score(self, plot_path):
         """
         计算乘坐质量评分(Ride Quality Score)并检测低质量事件
 
@@ -641,7 +642,7 @@ class ComfortCalculator:
         同时检测评分低于阈值(60)的事件
         """
         # 实际计算乘坐质量评分
-        ride_quality_score = self._calculate_ride_quality_score()
+        ride_quality_score = self._calculate_ride_quality_score(plot_path)
 
         # 直接设置阈值
         self._detect_threshold_events(
@@ -653,7 +654,7 @@ class ComfortCalculator:
 
         return ride_quality_score
 
-    def _calculate_ride_quality_score(self):
+    def _calculate_ride_quality_score(self, plot_path):
         """实际计算乘坐质量评分"""
         df = self.ego_df.copy()
 
@@ -712,7 +713,7 @@ class ComfortCalculator:
 
         return ride_quality_score
 
-    def calculate_motion_sickness_probability(self):
+    def calculate_motion_sickness_probability(self, plot_path):
         """计算晕车概率指标并检测高概率事件"""
         # 实际计算晕车概率
         motion_sickness_prob = self._calculate_motion_sickness_probability()
@@ -794,10 +795,10 @@ class ComfortCalculator:
 
         return probability
 
-    def calculate_vdv(self):
+    def calculate_vdv(self, plot_path):
         """计算振动剂量值(Vibration Dose Value, VDV)指标并检测高VDV事件"""
         # 实际计算VDV
-        vdv_value = self._calculate_vdv()
+        vdv_value = self._calculate_vdv(plot_path)
 
         # 直接设置阈值
         self._detect_threshold_events(
@@ -806,11 +807,11 @@ class ComfortCalculator:
             min_threshold=0.0,
             max_threshold=8.0  # 硬编码阈值
         )
-        # self.generate_metric_chart('vdv')
+        # self.generate_metric_chart('vdv', plot_path)
 
         return vdv_value
 
-    def _calculate_vdv(self):
+    def _calculate_vdv(self, plot_path):
         """实际计算振动剂量值"""
         # 获取数据
         df = self.ego_df.copy()
@@ -870,11 +871,11 @@ class ComfortCalculator:
         self.logger.info(f"X方向VDV: {vdv_x}, Y方向VDV: {vdv_y}, Z方向VDV: {vdv_z}")
 
         # 生成VDV指标图表
-        # self.generate_metric_chart('vdv')
+        # self.generate_metric_chart('vdv', plot_path)
 
         return vdv
 
-    def calculate_ava_vav(self):
+    def calculate_ava_vav(self, plot_path):
         """计算多维度综合加权加速度并检测高值事件"""
         # 实际计算AVA/VAV
         ava_vav_value = self._calculate_ava_vav()
@@ -887,7 +888,7 @@ class ComfortCalculator:
             min_threshold=0.0,
             max_threshold=0.63  # 硬编码阈值
         )
-        # self.generate_metric_chart('ava_vav')
+        # self.generate_metric_chart('ava_vav', plot_path)
         return ava_vav_value
 
     def _calculate_ava_vav(self):
@@ -961,7 +962,7 @@ class ComfortCalculator:
 
         return ava_vav
 
-    def calculate_msdv(self):
+    def calculate_msdv(self, plot_path):
         """计算晕动剂量值(Motion Sickness Dose Value, MSDV)指标并检测高值事件"""
         # 实际计算MSDV
         msdv_value = self._calculate_msdv()
@@ -974,7 +975,7 @@ class ComfortCalculator:
             min_threshold=0.0,
             max_threshold=6.0  # 硬编码阈值
         )
-        # self.generate_metric_chart('msdv')
+        # self.generate_metric_chart('msdv', plot_path)
         return msdv_value
 
     def _calculate_msdv(self):
@@ -1038,14 +1039,14 @@ class ComfortCalculator:
 
         return msdv
 
-    def calculate_zigzag_count(self):
+    def calculate_zigzag_count(self, plot_path):
         """计算蛇行指标并检测事件"""
         # 原有的计算逻辑
         self._zigzag_detector()
 
         # 检测蛇行事件
         zigzag_events = self._detect_zigzag_events()
-        self.generate_metric_chart('zigzag')
+        self.generate_metric_chart('zigzag', plot_path)
 
         # 返回事件次数
         return len(zigzag_events)
@@ -1077,14 +1078,14 @@ class ComfortCalculator:
 
         return events
 
-    def calculate_shake_count(self):
+    def calculate_shake_count(self, plot_path):
         """计算晃动指标并检测事件"""
         # 原有的计算逻辑
         shake_events = self._shake_detector()
 
         # 检测晃动事件
         # shake_events = self._detect_shake_events()
-        self.generate_metric_chart('shake')
+        self.generate_metric_chart('shake', plot_path)
 
         # 返回事件次数
         return len(shake_events)
@@ -1118,11 +1119,11 @@ class ComfortCalculator:
 
         return events
 
-    def calculate_cadence_count(self):
+    def calculate_cadence_count(self, plot_path):
         """计算顿挫指标并检测事件"""
         # 原有的计算逻辑
         cadence_events = self._cadence_detector()
-        self.generate_metric_chart('cadence')
+        self.generate_metric_chart('cadence', plot_path)
         # 返回事件次数
         return len(cadence_events)
 
@@ -1222,14 +1223,14 @@ class ComfortCalculator:
 
         return cadence_time_ranges
 
-    def calculate_slam_brake_count(self):
+    def calculate_slam_brake_count(self, plot_path):
         """计算急刹车指标并检测事件"""
         # 原有的计算逻辑
         self._slam_brake_detector()
 
         # 返回事件次数
         # 生成急刹车指标图表
-        self.generate_metric_chart('slamBrake')
+        self.generate_metric_chart('slamBrake', plot_path)
         return self.slam_brake_count
 
     def _slam_brake_detector(self):
@@ -1323,11 +1324,11 @@ class ComfortCalculator:
         self.slam_brake_count = len(slam_brake_events)
         self.logger.info(f"检测到 {self.slam_brake_count} 次急刹车事件")
 
-    def calculate_slam_accel_count(self):
+    def calculate_slam_accel_count(self, plot_path):
         """计算急加速指标并检测事件"""
         # 原有的计算逻辑
         self._slam_accel_detector()
-        self.generate_metric_chart('slamaccelerate')
+        self.generate_metric_chart('slamaccelerate', plot_path)
         # 返回事件次数
         return self.slam_accel_count
 
@@ -1882,4 +1883,3 @@ class ComfortCalculator:
 
 
 
-

+ 73 - 64
modules/metric/efficient.py

@@ -21,10 +21,10 @@ import pandas as pd
 
 class Efficient:
     """高效性指标计算类"""
-    
+
     def __init__(self, data_processed):
         """初始化高效性指标计算类
-        
+
         Args:
             data_processed: 预处理后的数据对象
         """
@@ -32,17 +32,17 @@ class Efficient:
         self.data_processed = data_processed
         self.df = data_processed.object_df.copy()  # 浅拷贝
         self.ego_df = data_processed.ego_data.copy()  # 浅拷贝
-        
+
         # 配置参数
         self.STOP_SPEED_THRESHOLD = 0.05  # 停车速度阈值 (m/s)
-        self.STOP_TIME_THRESHOLD = 0.5    # 停车时间阈值 (秒)
-        self.FRAME_RANGE = 13             # 停车帧数阈值
-        
+        self.STOP_TIME_THRESHOLD = 0.5  # 停车时间阈值 (秒)
+        self.FRAME_RANGE = 13  # 停车帧数阈值
+
         # 初始化结果变量
-        self.stop_count = 0     # 停车次数
+        self.stop_count = 0  # 停车次数
         self.stop_duration = 0  # 平均停车时长
-        self.average_v = 0      # 平均速度
-        
+        self.average_v = 0  # 平均速度
+
         # 统计指标结果字典
         self.calculated_value = {
             'maxSpeed': 0,
@@ -52,10 +52,10 @@ class Efficient:
             'speedUtilizationRatio': 0,
             'accelerationSmoothness': 0  # 添加新指标的默认值
         }
-        
+
     def _max_speed(self):
         """计算最大速度
-        
+
         Returns:
             float: 最大速度 (m/s)
         """
@@ -65,7 +65,7 @@ class Efficient:
 
     def _deviation_speed(self):
         """计算速度方差
-        
+
         Returns:
             float: 速度方差
         """
@@ -75,20 +75,21 @@ class Efficient:
 
     def average_velocity(self):
         """计算平均速度
-        
+
         Returns:
             float: 平均速度 (km/h)
         """
         self.average_v = self.ego_df['v'].mean() * 3.6  # 转换为 km/h
         self.calculated_value['averagedSpeed'] = self.average_v
         return self.average_v
+
     def acceleration_smoothness(self):
         """计算加速度平稳度
-        
+
         加速度平稳度用以衡量车辆加减速过程的平滑程度,
         通过计算加速度序列的波动程度(标准差)来评估。
         平稳度指标定义为 1-σ_a/a_max(归一化后靠近1代表加速度更稳定)。
-        
+
         Returns:
             float: 加速度平稳度 (0-1之间的比率,越接近1表示越平稳)
         """
@@ -98,13 +99,13 @@ class Efficient:
             # 使用车辆坐标系下的加速度计算合成加速度
             lon_acc = self.ego_df['lon_acc_vehicle'].values
             lat_acc = self.ego_df['lat_acc_vehicle'].values
-            accel_magnitude = np.sqrt(lon_acc**2 + lat_acc**2)
+            accel_magnitude = np.sqrt(lon_acc ** 2 + lat_acc ** 2)
             self.logger.info("使用车辆坐标系下的加速度计算合成加速度")
         elif 'accelX' in self.ego_df.columns and 'accelY' in self.ego_df.columns:
             # 计算合成加速度(考虑X和Y方向)
             accel_x = self.ego_df['accelX'].values
             accel_y = self.ego_df['accelY'].values
-            accel_magnitude = np.sqrt(accel_x**2 + accel_y**2)
+            accel_magnitude = np.sqrt(accel_x ** 2 + accel_y ** 2)
             self.logger.info("使用accelX和accelY计算合成加速度")
         else:
             # 从速度差分计算加速度
@@ -114,47 +115,48 @@ class Efficient:
             time_diff[time_diff == 0] = 1e-6
             accel_magnitude = np.abs(np.diff(velocity, prepend=velocity[0]) / time_diff)
             self.logger.info("从速度差分计算加速度")
-        
+
         # 过滤掉异常值(可选)
         # 使用3倍标准差作为阈值
         mean_accel = np.mean(accel_magnitude)
         std_accel = np.std(accel_magnitude)
         threshold = mean_accel + 3 * std_accel
         filtered_accel = accel_magnitude[accel_magnitude <= threshold]
-        
+
         # 如果过滤后数据太少,则使用原始数据
         if len(filtered_accel) < len(accel_magnitude) * 0.8:
             filtered_accel = accel_magnitude
             self.logger.info("过滤后数据太少,使用原始加速度数据")
         else:
             self.logger.info(f"过滤掉 {len(accel_magnitude) - len(filtered_accel)} 个异常加速度值")
-        
+
         # 计算加速度标准差
         accel_std = np.std(filtered_accel)
-        
+
         # 计算最大加速度(使用95百分位数以避免极端值影响)
         accel_max = np.percentile(filtered_accel, 95)
-        
+
         # 防止除以零
         if accel_max < 0.001:
             accel_max = 0.001
-        
+
         # 计算平稳度指标: 1 - σ_a/a_max
         smoothness = 1.0 - (accel_std / accel_max)
-        
+
         # 限制在0-1范围内
         smoothness = np.clip(smoothness, 0.0, 1.0)
-        
+
         self.calculated_value['accelerationSmoothness'] = smoothness
-        
+
         self.logger.info(f"加速度标准差: {accel_std:.4f} m/s²")
         self.logger.info(f"加速度最大值(95百分位): {accel_max:.4f} m/s²")
         self.logger.info(f"加速度平稳度(Acceleration Smoothness): {smoothness:.4f}")
-        
+
         return smoothness
+
     def stop_duration_and_count(self):
         """计算停车次数和平均停车时长
-        
+
         Returns:
             float: 平均停车时长 (秒)
         """
@@ -163,18 +165,18 @@ class Efficient:
         if not any(stop_mask):
             self.calculated_value['stopDuration'] = 0
             return 0  # 如果没有停车,直接返回0
-            
+
         stop_time_list = self.ego_df.loc[stop_mask, 'simTime'].values.tolist()
         stop_frame_list = self.ego_df.loc[stop_mask, 'simFrame'].values.tolist()
-        
+
         if not stop_frame_list:
             return 0  # 防止空列表导致的索引错误
-            
+
         stop_frame_group = []
         stop_time_group = []
         sum_stop_time = 0
         f1, t1 = stop_frame_list[0], stop_time_list[0]
-        
+
         # 检测停车段
         for i in range(1, len(stop_frame_list)):
             if stop_frame_list[i] - stop_frame_list[i - 1] != 1:  # 帧不连续
@@ -187,7 +189,7 @@ class Efficient:
                     self.stop_count += 1
                 # 更新起始点
                 f1, t1 = stop_frame_list[i], stop_time_list[i]
-        
+
         # 检查最后一段停车
         if len(stop_frame_list) > 0:
             f2, t2 = stop_frame_list[-1], stop_time_list[-1]
@@ -198,31 +200,31 @@ class Efficient:
                 stop_time_group.append((t1, t2))
                 sum_stop_time += (t2 - t1)
                 self.stop_count += 1
-        
+
         # 计算平均停车时长
         self.stop_duration = sum_stop_time / self.stop_count if self.stop_count > 0 else 0
         self.calculated_value['stopDuration'] = self.stop_duration
-        
+
         self.logger.info(f"检测到停车次数: {self.stop_count}, 平均停车时长: {self.stop_duration:.2f}秒")
         return self.stop_duration
 
     def speed_utilization_ratio(self, default_speed_limit=60.0):
         """计算速度利用率
-        
+
         速度利用率度量车辆实际速度与道路限速之间的比率,
         反映车辆对道路速度资源的利用程度。
-        
+
         计算公式: R_v = v_actual / v_limit
-        
+
         Args:
             default_speed_limit: 默认道路限速 (km/h),当无法获取实际限速时使用
-            
+
         Returns:
             float: 速度利用率 (0-1之间的比率)
         """
         # 获取车辆速度数据 (m/s)
         speeds = self.ego_df['v'].values
-        
+
         # 尝试从数据中获取道路限速信息
         # 首先检查road_speed_max列,其次检查speedLimit列,最后使用默认值
         if 'road_speed_max' in self.ego_df.columns:
@@ -236,34 +238,37 @@ class Efficient:
             default_limit_ms = default_speed_limit / 3.6
             speed_limits = np.full_like(speeds, default_limit_ms)
             self.logger.info(f"未找到道路限速信息,使用默认限速: {default_speed_limit} km/h")
-        
+
         # 确保限速值为m/s单位,如果数据是km/h需要转换
         # 假设如果限速值大于30,则认为是km/h单位,需要转换为m/s
         if np.mean(speed_limits) > 30:
             speed_limits = speed_limits / 3.6
             self.logger.info("将限速单位从km/h转换为m/s")
-        
+
         # 计算每一帧的速度利用率
-        ratios = np.divide(speeds, speed_limits, 
-                          out=np.zeros_like(speeds), 
-                          where=speed_limits!=0)
-        
+        ratios = np.divide(speeds, speed_limits,
+                           out=np.zeros_like(speeds),
+                           where=speed_limits != 0)
+
         # 限制比率不超过1(超速按1计算)
         ratios = np.minimum(ratios, 1.0)
-        
+
         # 计算平均速度利用率
         avg_ratio = np.mean(ratios)
         self.calculated_value['speedUtilizationRatio'] = avg_ratio
-        
+
         self.logger.info(f"速度利用率(Speed Utilization Ratio): {avg_ratio:.4f}")
         return avg_ratio
 
+
 class EfficientManager:
-    """高效性指标管理类"""  
-    def __init__(self, data_processed):
+    """高效性指标管理类"""
+
+    def __init__(self, data_processed, plot_path):
         self.data = data_processed
         self.efficient = EfficientRegistry(self.data)
-    
+        self.plot_path = plot_path
+
     def report_statistic(self):
         """Generate the statistics and report the results."""
         # 使用注册表批量执行指标计算
@@ -271,9 +276,6 @@ class EfficientManager:
         return efficient_result
 
 
-    
-
-    
 # ----------------------
 # 基础指标计算函数
 # ----------------------
@@ -283,30 +285,35 @@ def maxSpeed(data_processed) -> dict:
     max_speed = efficient._max_speed()
     return {"maxSpeed": float(max_speed)}
 
+
 def deviationSpeed(data_processed) -> dict:
     """计算速度方差"""
     efficient = Efficient(data_processed)
     deviation = efficient._deviation_speed()
     return {"deviationSpeed": float(deviation)}
 
+
 def averagedSpeed(data_processed) -> dict:
     """计算平均速度"""
     efficient = Efficient(data_processed)
     avg_speed = efficient.average_velocity()
     return {"averagedSpeed": float(avg_speed)}
 
+
 def stopDuration(data_processed) -> dict:
     """计算停车持续时间和次数"""
     efficient = Efficient(data_processed)
     stop_duration = efficient.stop_duration_and_count()
     return {"stopDuration": float(stop_duration)}
 
+
 def speedUtilizationRatio(data_processed) -> dict:
     """计算速度利用率"""
     efficient = Efficient(data_processed)
     ratio = efficient.speed_utilization_ratio()
     return {"speedUtilizationRatio": float(ratio)}
 
+
 def accelerationSmoothness(data_processed) -> dict:
     """计算加速度平稳度"""
     efficient = Efficient(data_processed)
@@ -316,31 +323,32 @@ def accelerationSmoothness(data_processed) -> dict:
 
 class EfficientManager:
     """高效性指标管理类"""
-    
-    def __init__(self, data_processed):
+
+    def __init__(self, data_processed, plot_path):
         self.data = data_processed
         self.logger = LogManager().get_logger()
+        self.plot_path = plot_path
         # 检查efficient_config是否为空
         if not hasattr(data_processed, 'efficient_config') or not data_processed.efficient_config:
             self.logger.warning("高效性配置为空,跳过高效性指标计算初始化")
             self.registry = None
         else:
             self.registry = EfficientRegistry(self.data)
-    
+
     def report_statistic(self):
         """计算并报告高效性指标结果"""
         # 如果registry为None,直接返回空字典
         if self.registry is None:
             self.logger.info("高效性指标管理器未初始化,返回空结果")
             return {}
-            
+
         efficient_result = self.registry.batch_execute()
         return efficient_result
 
 
 class EfficientRegistry:
     """高效性指标注册器"""
-    
+
     def __init__(self, data_processed):
         self.logger = LogManager().get_logger()  # 获取全局日志实例
         self.data = data_processed
@@ -354,20 +362,22 @@ class EfficientRegistry:
         self.eff_config = data_processed.efficient_config.get("efficient", {})
         self.metrics = self._extract_metrics(self.eff_config)
         self._registry = self._build_registry()
-    
+
     def _extract_metrics(self, config_node: dict) -> list:
         """DFS遍历提取指标"""
         metrics = []
+
         def _recurse(node):
             if isinstance(node, dict):
                 if 'name' in node and not any(isinstance(v, dict) for v in node.values()):
                     metrics.append(node['name'])
                 for v in node.values():
                     _recurse(v)
+
         _recurse(config_node)
         self.logger.info(f'评比的高效性指标列表:{metrics}')
         return metrics
-    
+
     def _build_registry(self) -> dict:
         """自动注册指标函数"""
         registry = {}
@@ -377,7 +387,7 @@ class EfficientRegistry:
             except KeyError:
                 self.logger.error(f"未实现指标函数: {metric_name}")
         return registry
-    
+
     def batch_execute(self) -> dict:
         """批量执行指标计算"""
         results = {}
@@ -385,7 +395,7 @@ class EfficientRegistry:
         if not hasattr(self, 'eff_config') or not self.eff_config or not self._registry:
             self.logger.info("高效性配置为空或无注册指标,返回空结果")
             return results
-            
+
         for name, func in self._registry.items():
             try:
                 result = func(self.data)
@@ -399,4 +409,3 @@ class EfficientRegistry:
         return results
 
 
-        

+ 131 - 126
modules/metric/function.py

@@ -40,24 +40,24 @@ scenario_sign_dict = {
     "LeftTurnAssist": 206,
     "HazardousLocationW": 207,
     "RedLightViolationW": 208,
-    "AbnormalVehicleW":209,
-    "NsightVulnerableRoadUserCollisionW":210,
-    "LitterW":211,
-    "ForwardCollisionW":212,
-    "VisibilityW":213,
-    "EmergencyBrakeW":214,
-    "IntersectionCollisionW":215,
-    "BlindSpotW":216,
-    "DoNotPassW":217,
-    "ControlLossW":218,
-    "FrontTrafficJamW":219,
-    "EmergencyVehicleW":220,
-    "CooperativeVehicleMerge":221,
-    "CooperativeLaneChange":223,
-    "VulnerableRoadUserCollisionW":224,
-    "CooperativeIntersectionPassing":225,
-    "RampMerge":226,
-    "DrivingLaneRecommendation":227,
+    "AbnormalVehicleW": 209,
+    "NsightVulnerableRoadUserCollisionW": 210,
+    "LitterW": 211,
+    "ForwardCollisionW": 212,
+    "VisibilityW": 213,
+    "EmergencyBrakeW": 214,
+    "IntersectionCollisionW": 215,
+    "BlindSpotW": 216,
+    "DoNotPassW": 217,
+    "ControlLossW": 218,
+    "FrontTrafficJamW": 219,
+    "EmergencyVehicleW": 220,
+    "CooperativeVehicleMerge": 221,
+    "CooperativeLaneChange": 223,
+    "VulnerableRoadUserCollisionW": 224,
+    "CooperativeIntersectionPassing": 225,
+    "RampMerge": 226,
+    "DrivingLaneRecommendation": 227,
     "TrafficJamW": 228,
     "DynamicSpeedLimitingInformation": 229,
     "EmergencyVehiclesPriority": 230,
@@ -67,6 +67,7 @@ scenario_sign_dict = {
     "GreenLightOptimalSpeedAdvisory": 234,
 }
 
+
 def _is_pedestrian_in_crosswalk(polygon, test_point) -> bool:
     polygon = Polygon(polygon)
     point = Point(test_point)
@@ -182,6 +183,7 @@ def get_first_warning(data_processed) -> Optional[pd.DataFrame]:
     first_time = warning_times.iloc[0]
     return obj_df[obj_df['simTime'] == first_time]
 
+
 def getAxis(heading):
     AxisL = [0, 0]
     AxisW = [0, 0]
@@ -211,15 +213,15 @@ def getProjectionRadius(AxisL, AxisW, baseAxis, halfLength, halfWidth):
 
 
 def isCollision(
-    firstAxisL,
-    firstAxisW,
-    firstHalfLength,
-    firstHalfWidth,
-    secondAxisL,
-    secondAxisW,
-    secondHalfLength,
-    secondHalfWidth,
-    disVector,
+        firstAxisL,
+        firstAxisW,
+        firstHalfLength,
+        firstHalfWidth,
+        secondAxisL,
+        secondAxisW,
+        secondHalfLength,
+        secondHalfWidth,
+        disVector,
 ):
     isCollision = True
     axes = [firstAxisL, firstAxisW, secondAxisL, secondAxisW]
@@ -241,20 +243,20 @@ def isCollision(
 
 
 def funcIsCollision(
-    firstDimX,
-    firstDimY,
-    firstOffX,
-    firstOffY,
-    firstX,
-    firstY,
-    firstHeading,
-    secondDimX,
-    secondDimY,
-    secondOffX,
-    secondOffY,
-    secondX,
-    secondY,
-    secondHeading,
+        firstDimX,
+        firstDimY,
+        firstOffX,
+        firstOffY,
+        firstX,
+        firstY,
+        firstHeading,
+        secondDimX,
+        secondDimY,
+        secondOffX,
+        secondOffY,
+        secondX,
+        secondY,
+        secondHeading,
 ):
     firstAxisL = getAxis(firstHeading)[0]
     firstAxisW = getAxis(firstHeading)[1]
@@ -292,10 +294,11 @@ def funcIsCollision(
 
     return varIsCollision
 
+
 # ----------------------
 # 核心计算功能函数
 # ----------------------
-def latestWarningDistance_LST(data) -> dict:
+def latestWarningDistance_LST(data, plot_path) -> dict:
     """预警距离计算流水线"""
     scenario_name = find_nested_name(data.function_config["function"])
     value = data.function_config["function"][scenario_name]["latestWarningDistance_LST"]["max"]
@@ -313,12 +316,12 @@ def latestWarningDistance_LST(data) -> dict:
         return {"latestWarningDistance_LST": 0.0}
 
     # 生成图表数据
-    generate_function_chart_data(data, 'latestWarningDistance_LST')
+    generate_function_chart_data(data, 'latestWarningDistance_LST', plot_path)
 
     return {"latestWarningDistance_LST": float(warning_dist.iloc[-1]) if len(warning_dist) > 0 else value}
 
 
-def earliestWarningDistance_LST(data) -> dict:
+def earliestWarningDistance_LST(data, plot_path) -> dict:
     """预警距离计算流水线"""
     scenario_name = find_nested_name(data.function_config["function"])
     value = data.function_config["function"][scenario_name]["earliestWarningDistance_LST"]["max"]
@@ -337,12 +340,12 @@ def earliestWarningDistance_LST(data) -> dict:
 
     # 生成图表数据
 
-    generate_function_chart_data(data, 'earliestWarningDistance_LST')
+    generate_function_chart_data(data, 'earliestWarningDistance_LST', plot_path)
 
     return {"earliestWarningDistance_LST": float(warning_dist.iloc[0]) if len(warning_dist) > 0 else value}
 
 
-def latestWarningDistance_TTC_LST(data) -> dict:
+def latestWarningDistance_TTC_LST(data, plot_path) -> dict:
     """TTC计算流水线"""
     scenario_name = find_nested_name(data.function_config["function"])
     value = data.function_config["function"][scenario_name]["latestWarningDistance_TTC_LST"]["max"]
@@ -369,12 +372,12 @@ def latestWarningDistance_TTC_LST(data) -> dict:
     data.ttc = ttc
     # 生成图表数据
     # from modules.lib.chart_generator import generate_function_chart_data
-    generate_function_chart_data(data, 'latestWarningDistance_TTC_LST')
+    generate_function_chart_data(data, 'latestWarningDistance_TTC_LST', plot_path)
 
     return {"latestWarningDistance_TTC_LST": float(ttc[-1]) if len(ttc) > 0 else value}
 
 
-def earliestWarningDistance_TTC_LST(data) -> dict:
+def earliestWarningDistance_TTC_LST(data, plot_path) -> dict:
     """TTC计算流水线"""
     scenario_name = find_nested_name(data.function_config["function"])
     value = data.function_config["function"][scenario_name]["earliestWarningDistance_TTC_LST"]["max"]
@@ -403,12 +406,12 @@ def earliestWarningDistance_TTC_LST(data) -> dict:
     data.correctwarning = correctwarning
 
     # 生成图表数据
-    generate_function_chart_data(data, 'earliestWarningDistance_TTC_LST')
+    generate_function_chart_data(data, 'earliestWarningDistance_TTC_LST', plot_path)
 
     return {"earliestWarningDistance_TTC_LST": float(ttc[0]) if len(ttc) > 0 else value}
 
 
-def warningDelayTime_LST(data):
+def warningDelayTime_LST(data, plot_path):
     scenario_name = find_nested_name(data.function_config["function"])
     correctwarning = scenario_sign_dict[scenario_name]
     # 将correctwarning保存到data对象中,供图表生成使用
@@ -427,7 +430,7 @@ def warningDelayTime_LST(data):
     return {"warningDelayTime_LST": delay_time}
 
 
-def warningDelayTimeofReachDecel_LST(data):
+def warningDelayTimeofReachDecel_LST(data, plot_path):
     scenario_name = find_nested_name(data.function_config["function"])
     correctwarning = scenario_sign_dict[scenario_name]
     # 将correctwarning保存到data对象中,供图表生成使用
@@ -451,7 +454,7 @@ def warningDelayTimeofReachDecel_LST(data):
         return {"warningDelayTimeofReachDecel_LST": warning_simTime[0] - obj_speed_simtime[0]}
 
 
-def rightWarningSignal_LST(data):
+def rightWarningSignal_LST(data, plot_path):
     scenario_name = find_nested_name(data.function_config["function"])
     correctwarning = scenario_sign_dict[scenario_name]
     # 将correctwarning保存到data对象中,供图表生成使用
@@ -467,7 +470,7 @@ def rightWarningSignal_LST(data):
         return {"rightWarningSignal_LST": 1}
 
 
-def ifCrossingRedLight_LST(data):
+def ifCrossingRedLight_LST(data, plot_path):
     scenario_name = find_nested_name(data.function_config["function"])
     correctwarning = scenario_sign_dict[scenario_name]
     # 将correctwarning保存到data对象中,供图表生成使用
@@ -482,7 +485,7 @@ def ifCrossingRedLight_LST(data):
         return {"ifCrossingRedLight_LST": 1}
 
 
-def ifStopgreenWaveSpeedGuidance_LST(data):
+def ifStopgreenWaveSpeedGuidance_LST(data, plot_path):
     scenario_name = find_nested_name(data.function_config["function"])
     correctwarning = scenario_sign_dict[scenario_name]
     # 将correctwarning保存到data对象中,供图表生成使用
@@ -497,7 +500,7 @@ def ifStopgreenWaveSpeedGuidance_LST(data):
 
 
 # ------ 单车智能指标 ------
-def limitSpeed_LST(data):
+def limitSpeed_LST(data, plot_path):
     ego_df = data.ego_data
     scenario_name = find_nested_name(data.function_config["function"])
     limit_speed = data.function_config["function"][scenario_name]["limitSpeed_LST"]["max"]
@@ -506,11 +509,11 @@ def limitSpeed_LST(data):
         return {"speedLimit_LST": -1}
     max_speed = max(speed_limit)
     data.speedLimit = limit_speed
-    generate_function_chart_data(data, 'limitspeed_LST')
+    generate_function_chart_data(data, 'limitspeed_LST', plot_path)
     return {"speedLimit_LST": max_speed}
 
 
-def limitSpeedPastLimitSign_LST(data):
+def limitSpeedPastLimitSign_LST(data, plot_path):
     ego_df = data.ego_data
     scenario_name = find_nested_name(data.function_config["function"])
     limit_speed = data.function_config["function"][scenario_name]["limitSpeed_LST"]["max"]
@@ -519,13 +522,13 @@ def limitSpeedPastLimitSign_LST(data):
     ego_time = ego_df[ego_df['x_relative_dist'] <= -100 - car_length]['simTime'].tolist()
     data.speedLimit = limit_speed
     data.speedPastLimitSign_LST = ego_time[0] if len(ego_time) > 0 else None
-    generate_function_chart_data(data, 'limitSpeedPastLimitSign_LST')
+    generate_function_chart_data(data, 'limitSpeedPastLimitSign_LST', plot_path)
     if len(ego_speed) == 0:
         return {"speedPastLimitSign_LST": -1}
     return {"speedPastLimitSign_LST": ego_speed[0]}
 
 
-def leastDistance_LST(data):
+def leastDistance_LST(data, plot_path):
     ego_df = data.ego_data
     dist_row = ego_df[ego_df['v'] == 0]['relative_dist'].tolist()
     if len(dist_row) == 0:
@@ -535,7 +538,7 @@ def leastDistance_LST(data):
         return {"leastDistance_LST": min_dist}
 
 
-def launchTimeinStopLine_LST(data):
+def launchTimeinStopLine_LST(data, plot_path):
     ego_df = data.ego_data
     simtime_row = ego_df[ego_df['v'] == 0]['simTime'].tolist()
     if len(simtime_row) == 0:
@@ -545,7 +548,7 @@ def launchTimeinStopLine_LST(data):
         return {"launchTimeinStopLine_LST": delta_t}
 
 
-def launchTimewhenFollowingCar_LST(data):
+def launchTimewhenFollowingCar_LST(data, plot_path):
     ego_df = data.ego_data
     t_interval = ego_df['simTime'].tolist()[1] - ego_df['simTime'].tolist()[0]
     simtime_row = ego_df[ego_df['v'] == 0]['simTime'].tolist()
@@ -559,7 +562,7 @@ def launchTimewhenFollowingCar_LST(data):
         return {"launchTimewhenFollowingCar_LST": max(delta_t)}
 
 
-def noStop_LST(data):
+def noStop_LST(data, plot_path):
     ego_df_ini = data.ego_data
     min_time = ego_df_ini['simTime'].min() + 5
     max_time = ego_df_ini['simTime'].max() - 5
@@ -571,7 +574,7 @@ def noStop_LST(data):
         return {"noStop_LST": 1}
 
 
-def launchTimeinTrafficLight_LST(data):
+def launchTimeinTrafficLight_LST(data, plot_path):
     '''
     待修改:
     红灯的状态值:1
@@ -588,7 +591,7 @@ def launchTimeinTrafficLight_LST(data):
     return {"timeInterval_LST": simtime_of_launch[-1] - simtime_of_launch[0]}
 
 
-def crossJunctionToTargetLane_LST(data):
+def crossJunctionToTargetLane_LST(data, plot_path):
     ego_df = data.ego_data
     lane_in_leftturn = set(ego_df['lane_id'].tolist())
     scenario_name = find_nested_name(data.function_config["function"])
@@ -599,15 +602,15 @@ def crossJunctionToTargetLane_LST(data):
         return {"crossJunctionToTargetLane_LST": target_lane_id}
 
 
-def keepInLane_LST(data):
+def keepInLane_LST(data, plot_path):
     ego_df = data.ego_data
-    notkeepinlane = ego_df[ego_df['laneOffset'] > ego_df['lane_width']/2].tolist()
+    notkeepinlane = ego_df[ego_df['laneOffset'] > ego_df['lane_width'] / 2].tolist()
     if len(notkeepinlane):
         return {"keepInLane_LST": -1}
     return {"keepInLane_LST": 1}
 
 
-def leastLateralDistance_LST(data):
+def leastLateralDistance_LST(data, plot_path):
     ego_df = data.ego_data
     lane_width = ego_df[ego_df['x_relative_dist'] == 0]['lane_width']
     if lane_width.empty():
@@ -620,7 +623,7 @@ def leastLateralDistance_LST(data):
             return {"leastLateralDistance_LST": -1}
 
 
-def waitTimeAtCrosswalkwithPedestrian_LST(data):
+def waitTimeAtCrosswalkwithPedestrian_LST(data, plot_path):
     ego_df = data.ego_data
     object_df = data.object_data
     data['in_crosswalk'] = []
@@ -640,7 +643,7 @@ def waitTimeAtCrosswalkwithPedestrian_LST(data):
         car_wait_pedestrian) > 0 else 0}
 
 
-def launchTimewhenPedestrianLeave_LST(data):
+def launchTimewhenPedestrianLeave_LST(data, plot_path):
     ego_df = data.ego_data
     car_stop_time = ego_df[ego_df['v'] == 0]["simTime"]
     if car_stop_time.empty():
@@ -652,7 +655,7 @@ def launchTimewhenPedestrianLeave_LST(data):
         return {"launchTimewhenPedestrianLeave_LST": legal_stop_time[-1] - legal_stop_time[0]}
 
 
-def noCollision_LST(data):
+def noCollision_LST(data, plot_path):
     ego_df = data.ego_data
     if ego_df['relative_dist'].any() == 0:
         return {"noCollision_LST": -1}
@@ -660,7 +663,7 @@ def noCollision_LST(data):
         return {"noCollision_LST": 1}
 
 
-def noReverse_LST(data):
+def noReverse_LST(data, plot_path):
     ego_df = data.ego_data
     if (ego_df["lon_v_vehicle"] * ego_df["posH"]).any() < 0:
         return {"noReverse_LST": -1}
@@ -668,7 +671,7 @@ def noReverse_LST(data):
         return {"noReverse_LST": 1}
 
 
-def turnAround_LST(data):
+def turnAround_LST(data, plot_path):
     ego_df = data.ego_data
     if (ego_df["lon_v_vehicle"].tolist()[0] * ego_df["lon_v_vehicle"].tolist()[-1] < 0) and (
             ego_df["lon_v_vehicle"] * ego_df["posH"].all() > 0):
@@ -677,7 +680,7 @@ def turnAround_LST(data):
         return {"turnAround_LST": -1}
 
 
-def laneOffset_LST(data):
+def laneOffset_LST(data, plot_path):
     car_width = data.function_config['vehicle']['CAR_WIDTH']
     ego_df_ini = data.ego_data
     min_time = ego_df_ini['simTime'].min() + 5
@@ -687,7 +690,7 @@ def laneOffset_LST(data):
     return {"laneOffset_LST": max(laneoffset)}
 
 
-def maxLongitudeDist_LST(data):
+def maxLongitudeDist_LST(data, plot_path):
     ego_df = data.ego_data
     longitude_dist = abs(ego_df[ego_df['v'] == 0]['x_relative_dist'].tolist())
     data.longitude_dist = min(abs(ego_df[ego_df['v'] == 0]['x_relative_dist'].tolist()))
@@ -695,11 +698,11 @@ def maxLongitudeDist_LST(data):
     data.stop_time = min(stop_time)
     if len(longitude_dist) == 0:
         return {"maxLongitudeDist_LST": -1}
-    generate_function_chart_data(data, 'maxLongitudeDist_LST')
+    generate_function_chart_data(data, 'maxLongitudeDist_LST', plot_path)
     return {"maxLongDist_LST": min(longitude_dist)}
 
 
-def noEmergencyBraking_LST(data):
+def noEmergencyBraking_LST(data, plot_path):
     ego_df = data.ego_data
     ego_df['ip_dec'] = ego_df['v'].apply(
         get_interpolation, point1=[18, -5], point2=[72, -3.5])
@@ -711,7 +714,7 @@ def noEmergencyBraking_LST(data):
         return {"noEmergencyBraking_LST": -1}
 
 
-def rightWarningSignal_PGVIL(data_processed) -> dict:
+def rightWarningSignal_PGVIL(data_processed, plot_path) -> dict:
     """判断是否发出正确预警信号"""
 
     ego_df = data_processed.ego_data
@@ -732,7 +735,7 @@ def rightWarningSignal_PGVIL(data_processed) -> dict:
         return {"rightWarningSignal_PGVIL": 1}
 
 
-def latestWarningDistance_PGVIL(data_processed) -> dict:
+def latestWarningDistance_PGVIL(data_processed, plot_path) -> dict:
     """预警距离计算流水线"""
     ego_df = data_processed.ego_data
     obj_df = data_processed.object_df
@@ -751,11 +754,11 @@ def latestWarningDistance_PGVIL(data_processed) -> dict:
     if distances.size == 0:
         print("没有找到数据!")
         return {"latestWarningDistance_PGVIL": 15}  # 或返回其他默认值,如0.0
-    generate_function_chart_data(data_processed, 'latestWarningDistance_PGVIL')
+    generate_function_chart_data(data_processed, 'latestWarningDistance_PGVIL', plot_path)
     return {"latestWarningDistance_PGVIL": float(np.min(distances))}
 
 
-def latestWarningDistance_TTC_PGVIL(data_processed) -> dict:
+def latestWarningDistance_TTC_PGVIL(data_processed, plot_path) -> dict:
     """TTC计算流水线"""
     ego_df = data_processed.ego_data
     obj_df = data_processed.object_df
@@ -784,11 +787,11 @@ def latestWarningDistance_TTC_PGVIL(data_processed) -> dict:
         print("没有找到数据!")
         return {"latestWarningDistance_TTC_PGVIL": 2}  # 或返回其他默认值,如0.0
     data_processed.ttc = ttc
-    generate_function_chart_data(data_processed, 'latestWarningDistance_TTC_PGVIL')
+    generate_function_chart_data(data_processed, 'latestWarningDistance_TTC_PGVIL', plot_path)
     return {"latestWarningDistance_TTC_PGVIL": float(np.nanmin(ttc))}
 
 
-def earliestWarningDistance_PGVIL(data_processed) -> dict:
+def earliestWarningDistance_PGVIL(data_processed, plot_path) -> dict:
     """预警距离计算流水线"""
     ego_df = data_processed.ego_data
     obj_df = data_processed.object_df
@@ -807,11 +810,11 @@ def earliestWarningDistance_PGVIL(data_processed) -> dict:
     if distances.size == 0:
         print("没有找到数据!")
         return {"earliestWarningDistance_PGVIL": 15}  # 或返回其他默认值,如0.0
-    generate_function_chart_data(data_processed, 'earliestWarningDistance_PGVIL')
+    generate_function_chart_data(data_processed, 'earliestWarningDistance_PGVIL', plot_path)
     return {"earliestWarningDistance": float(np.min(distances))}
 
 
-def earliestWarningDistance_TTC_PGVIL(data_processed) -> dict:
+def earliestWarningDistance_TTC_PGVIL(data_processed, plot_path) -> dict:
     """TTC计算流水线"""
     ego_df = data_processed.ego_data
     obj_df = data_processed.object_df
@@ -841,7 +844,7 @@ def earliestWarningDistance_TTC_PGVIL(data_processed) -> dict:
         print("没有找到数据!")
         return {"earliestWarningDistance_TTC_PGVIL": 2}  # 或返回其他默认值,如0.0
     data_processed.ttc = ttc
-    generate_function_chart_data(data_processed, 'earliestWarningDistance_TTC_PGVIL')
+    generate_function_chart_data(data_processed, 'earliestWarningDistance_TTC_PGVIL', plot_path)
     return {"earliestWarningDistance_TTC_PGVIL": float(np.nanmin(ttc))}
 
 
@@ -876,7 +879,7 @@ def earliestWarningDistance_TTC_PGVIL(data_processed) -> dict:
 #             return {"delayOfEmergencyBrakeWarning": -1}
 
 
-def warningDelayTime_PGVIL(data_processed) -> dict:
+def warningDelayTime_PGVIL(data_processed, plot_path) -> dict:
     """车端接收到预警到HMI发出预警的时延"""
     ego_df = data_processed.ego_data
     # #打印ego_df的列名
@@ -933,7 +936,7 @@ def get_car_to_stop_line_distance(ego, car_point, stop_line_points):
     return carhead_distance_to_foot
 
 
-def ifCrossingRedLight_PGVIL(data_processed) -> dict:
+def ifCrossingRedLight_PGVIL(data_processed, plot_path) -> dict:
     # 判断车辆是否闯红灯
 
     stop_line_points = np.array([(276.555, -35.575), (279.751, -33.683)])
@@ -1007,7 +1010,7 @@ def ifCrossingRedLight_PGVIL(data_processed) -> dict:
 #     mindisStopline = np.min(distance_to_stoplines) - distance_carpoint_carhead
 #     return {"mindisStopline": mindisStopline}
 
-def limitSpeed_PGVIL(data):
+def limitSpeed_PGVIL(data, plot_path):
     ego_df = data.ego_data
     max_speed = max(ego_df["v"])
     if len(ego_df["v"]) == 0:
@@ -1016,7 +1019,7 @@ def limitSpeed_PGVIL(data):
         return {"speedLimit_PGVIL": max_speed}
 
 
-def leastDistance_PGVIL(data):
+def leastDistance_PGVIL(data, plot_path):
     exclude_seconds = 2.0
     ego_df = data.ego_data
     max_sim_time = ego_df["simTime"].max()
@@ -1036,12 +1039,12 @@ def leastDistance_PGVIL(data):
                 ego, car_point, stop_line_points
             )
 
-            return {"leastDistance_PGVIL": distance_to_stopline}
+            return {"minimumDistance_PGVIL": distance_to_stopline}
 
-    return {"leastDistance_PGVIL": -1}
+    return {"minimumDistance_PGVIL": -1}
 
 
-def launchTimeinStopLine_PGVIL(data):
+def launchTimeinStopLine_PGVIL(data, plot_path):
     ego_df = data.ego_data
     in_stop = False
     start_time = None
@@ -1074,13 +1077,13 @@ def launchTimeinStopLine_PGVIL(data):
         return {"launchTimeinStopLine_PGVIL": float(max_duration)}
 
 
-def launchTimeinTrafficLight_PGVIL(data):
+def launchTimeinTrafficLight_PGVIL(data, plot_path):
     GREEN = 0x100000
     RED = 0x10000000
     ego_df = data.ego_data
     # 找到第一次红灯 to 绿灯的切换时刻
     is_transition = (ego_df["stateMask"] == GREEN) & (
-        ego_df["stateMask"].shift(1) == RED
+            ego_df["stateMask"].shift(1) == RED
     )
     transition_times = ego_df.loc[is_transition, "simTime"]
     if transition_times.empty:
@@ -1097,7 +1100,7 @@ def launchTimeinTrafficLight_PGVIL(data):
     return {"timeInterval_PGVIL": time_move - time_red2green}
 
 
-def crossJunctionToTargetLane_PGVIL(data):
+def crossJunctionToTargetLane_PGVIL(data, plot_path):
     ego_df = data.ego_data
     lane_ids = set(ego_df["lane_id"].dropna())
 
@@ -1113,7 +1116,7 @@ def crossJunctionToTargetLane_PGVIL(data):
     return {"crossJunctionToTargetLane_PGVIL": result}
 
 
-def noStop_PGVIL(data):
+def noStop_PGVIL(data, plot_path):
     exclude_end_seconds = 5.0
     exclude_start_seconds = 5.0
     ego_df = data.ego_data
@@ -1123,7 +1126,7 @@ def noStop_PGVIL(data):
     end_threshold = max_sim_time - exclude_end_seconds
     filtered_df = ego_df[
         (ego_df["simTime"] >= start_threshold) & (ego_df["simTime"] <= end_threshold)
-    ]
+        ]
 
     if (filtered_df["v"] == 0).any():
         return {"noStop_PGVIL": -1}
@@ -1131,7 +1134,7 @@ def noStop_PGVIL(data):
         return {"noStop_PGVIL": 1}
 
 
-def noEmergencyBraking_PGVIL(data):
+def noEmergencyBraking_PGVIL(data, plot_path):
     ego_df = data.ego_data
     ego_df["ip_dec"] = ego_df["v"].apply(
         get_interpolation, point1=[18, -5], point2=[72, -3.5]
@@ -1145,7 +1148,7 @@ def noEmergencyBraking_PGVIL(data):
         return {"noEmergencyBraking_PGVIL": -1}
 
 
-def noReverse_PGVIL(data):
+def noReverse_PGVIL(data, plot_path):
     ego_df = data.ego_data.copy()
     heading_x = np.cos(ego_df["posH"])
     reverse_flag = (ego_df["speedX"] * heading_x) < 0
@@ -1156,7 +1159,7 @@ def noReverse_PGVIL(data):
         return {"noReverse_PGVIL": 1}
 
 
-def laneOffset_PGVIL(data):
+def laneOffset_PGVIL(data, plot_path):
     car_width = data.function_config["vehicle"]["CAR_WIDTH"]
     ego_df = data.ego_data
     is_zero = ego_df["v"] == 0
@@ -1169,12 +1172,12 @@ def laneOffset_PGVIL(data):
 
     # 距离右侧车道线
     edge_dist = (last_stop["lane_width"] / 2 + last_stop["laneOffset"]) - (
-        car_width / 2
+            car_width / 2
     )
     return {"laneOffset_PGVIL": edge_dist.max()}
 
 
-def maxLongitudelDistance_PGVIL(data):
+def maxLongitudelDistance_PGVIL(data, plot_path):
     scenario_name = find_nested_name(data.function_config["function"])
     stopX_pos = data.function_config["function"][scenario_name][
         "maxLongitudelDistance_PGVIL"
@@ -1184,7 +1187,7 @@ def maxLongitudelDistance_PGVIL(data):
     ]["min"]
 
 
-def keepInLane_PGVIL(data):
+def keepInLane_PGVIL(data, plot_path):
     ego_df = data.ego_data.copy()
     ego_df = ego_df.dropna(subset=["laneOffset", "lane_width"])
     if ego_df.empty:
@@ -1196,7 +1199,7 @@ def keepInLane_PGVIL(data):
         return {"keepInLane_PGVIL": 1}
 
 
-def launchTimewhenPedestrianLeave_PGVIL(data):
+def launchTimewhenPedestrianLeave_PGVIL(data, plot_path):
     ego_df = data.ego_data
     ped_df = data.object_data.loc[
         data.object_data["playerId"] == 2, ["simTime", "posX", "posY"]
@@ -1242,7 +1245,7 @@ def launchTimewhenPedestrianLeave_PGVIL(data):
     return {"launchTimewhenPedestrianLeave_PGVIL": t_launch - t_stop}
 
 
-def waitTimeAtCrosswalkwithPedestrian_PGVIL(data):
+def waitTimeAtCrosswalkwithPedestrian_PGVIL(data, plot_path):
     ego_df = data.ego_data
     ped_df = data.object_data.loc[
         data.object_data["playerId"] == 2, ["simTime", "posX", "posY"]
@@ -1266,25 +1269,25 @@ def waitTimeAtCrosswalkwithPedestrian_PGVIL(data):
 
     stops_df = merged.loc[
         (merged["simTime"] >= t0_launch) & (merged["v"] == 0) & valid_times
-    ].sort_values("simTime")
+        ].sort_values("simTime")
     if stops_df.empty:
         return {"waitTimeAtCrosswalkwithPedestrian_PGVIL": -1}
     wait_time = stops_df["simTime"].iloc[-1] - stops_df["simTime"].iloc[0]
     return {"waitTimeAtCrosswalkwithPedestrian_PGVIL": wait_time}
 
 
-def noCollision_PGVIL(data):
+def noCollision_PGVIL(data, plot_path):
     ego = data.ego_data[["simTime", "posX", "posY", "posH", "dimX", "dimY", "offX", "offY"]]
     tar = data.object_data.loc[data.object_data["playerId"] == 2,
-        ["simTime", "posX", "posY", "posH", "dimX", "dimY", "offX", "offY"]]
+    ["simTime", "posX", "posY", "posH", "dimX", "dimY", "offX", "offY"]]
     df = ego.merge(tar, on="simTime", suffixes=("", "_tar"))
 
     df["posH_tar_rad"] = np.deg2rad(df["posH_tar"])
     df["posH_rad"] = np.deg2rad(df["posH"])
     df["collision"] = df.apply(lambda row: funcIsCollision(
-            row["dimX"],row["dimY"],row["offX"],row["offY"],row["posX"],row["posY"],
-            row["posH_rad"],row["dimX_tar"],row["dimY_tar"],row["offX_tar"],row["offY_tar"],
-            row["posX_tar"],row["posY_tar"],row["posH_tar_rad"],),axis=1,)
+        row["dimX"], row["dimY"], row["offX"], row["offY"], row["posX"], row["posY"],
+        row["posH_rad"], row["dimX_tar"], row["dimY_tar"], row["offX_tar"], row["offY_tar"],
+        row["posX_tar"], row["posY_tar"], row["posH_tar_rad"], ), axis=1, )
 
     if df["collision"].any():
         return {"noCollision_PGVIL": -1}
@@ -1292,40 +1295,41 @@ def noCollision_PGVIL(data):
         return {"noCollision_PGVIL": 1}
 
 
-def leastLateralDistance_PGVIL(data):
+def leastLateralDistance_PGVIL(data, plot_path):
     ego_df = data.ego_data
     cones = data.object_data.loc[data.object_data['playerId'] != 1,
-        ['simTime','playerId','posX','posY']].copy()
-    df = ego.merge(cones, on='simTime', how='inner', suffixes=('','_cone'))
+    ['simTime', 'playerId', 'posX', 'posY']].copy()
+    df = ego_df.merge(cones, on='simTime', how='inner', suffixes=('', '_cone'))
     yaw = np.deg2rad(df['posH'])
     x_c = df['posX'] + df['offX'] * np.cos(yaw)
     y_c = df['posY'] + df['offX'] * np.sin(yaw)
-    dx = df['posX']   - x_c
-    dy = df['posY']   - y_c
+    dx = df['posX'] - x_c
+    dy = df['posY'] - y_c
     dx = df['posX_cone'] - x_c
     dy = df['posY_cone'] - y_c
-    local_x =  np.cos(yaw) * dx + np.sin(yaw) * dy
+    local_x = np.cos(yaw) * dx + np.sin(yaw) * dy
     local_y = -np.sin(yaw) * dx + np.cos(yaw) * dy
     half_length = df['dimX'] / 2
-    half_width  = df['dimY'] / 2
+    half_width = df['dimY'] / 2
 
     inside = (np.abs(local_x) <= half_length) & (np.abs(local_y) <= half_width)
-    collisions = df.loc[inside, ['simTime','playerId']]
+    collisions = df.loc[inside, ['simTime', 'playerId']]
 
     if collisions.empty:
         return {"noConeRectCollision_PGVIL": 1}
     else:
         collision_times = collisions['simTime'].unique().tolist()
-        collision_ids   = collisions['playerId'].unique().tolist()
+        collision_ids = collisions['playerId'].unique().tolist()
         return {"noConeRectCollision_PGVIL": -1}
 
 
 class FunctionRegistry:
     """动态函数注册器(支持参数验证)"""
 
-    def __init__(self, data_processed):
+    def __init__(self, data_processed, plot_path):
         self.logger = LogManager().get_logger()  # 获取全局日志实例
         self.data = data_processed
+        self.plot_path = plot_path
         # 检查function_config是否为空
         if not hasattr(data_processed, 'function_config') or not data_processed.function_config:
             self.logger.warning("功能配置为空,跳过功能指标计算")
@@ -1373,10 +1377,10 @@ class FunctionRegistry:
         if not hasattr(self, 'fun_config') or not self.fun_config or not self._registry:
             self.logger.info("功能配置为空或无注册指标,返回空结果")
             return results
-            
+
         for name, func in self._registry.items():
             try:
-                result = func(self.data)  # 统一传递数据上下文
+                result = func(self.data, self.plot_path)  # 统一传递数据上下文
                 results.update(result)
             except Exception as e:
                 print(f"{name} 执行失败: {str(e)}")
@@ -1389,15 +1393,16 @@ class FunctionRegistry:
 class FunctionManager:
     """管理功能指标计算的类"""
 
-    def __init__(self, data_processed):
+    def __init__(self, data_processed, plot_path):
         self.data = data_processed
         self.logger = LogManager().get_logger()
+        self.plot_path = plot_path
         # 检查function_config是否为空
         if not hasattr(data_processed, 'function_config') or not data_processed.function_config:
             self.logger.warning("功能配置为空,跳过功能指标计算初始化")
             self.function = None
         else:
-            self.function = FunctionRegistry(self.data)
+            self.function = FunctionRegistry(self.data, self.plot_path)
 
     def report_statistic(self):
         """
@@ -1408,7 +1413,7 @@ class FunctionManager:
         if self.function is None:
             self.logger.info("功能指标管理器未初始化,返回空结果")
             return {}
-            
+
         function_result = self.function.batch_execute()
 
         print("\n[功能性表现及评价结果]")

+ 39 - 34
modules/metric/safety.py

@@ -38,7 +38,7 @@ SAFETY_INFO = [
 # ----------------------
 # 独立指标计算函数
 # ----------------------
-def calculate_ttc(data_processed) -> dict:
+def calculate_ttc(data_processed, plot_path) -> dict:
     """计算TTC (Time To Collision)"""
     if data_processed is None or not hasattr(data_processed, 'object_df'):
         return {"TTC": None}
@@ -47,7 +47,7 @@ def calculate_ttc(data_processed) -> dict:
         ttc_value = safety.get_ttc_value()
         # 只生成图表,数据导出由chart_generator处理
         if safety.ttc_data:
-            safety.generate_metric_chart('TTC')
+            safety.generate_metric_chart('TTC', plot_path)
         LogManager().get_logger().info(f"安全指标[TTC]计算结果: {ttc_value}")
         return {"TTC": ttc_value}
     except Exception as e:
@@ -55,7 +55,7 @@ def calculate_ttc(data_processed) -> dict:
         return {"TTC": None}
 
 
-def calculate_mttc(data_processed) -> dict:
+def calculate_mttc(data_processed, plot_path) -> dict:
     """计算MTTC (Modified Time To Collision)"""
     if data_processed is None or not hasattr(data_processed, 'object_df'):
         return {"MTTC": None}
@@ -63,7 +63,7 @@ def calculate_mttc(data_processed) -> dict:
         safety = SafetyCalculator(data_processed)
         mttc_value = safety.get_mttc_value()
         if safety.mttc_data:
-            safety.generate_metric_chart('MTTC')
+            safety.generate_metric_chart('MTTC', plot_path)
         LogManager().get_logger().info(f"安全指标[MTTC]计算结果: {mttc_value}")
         return {"MTTC": mttc_value}
     except Exception as e:
@@ -71,7 +71,7 @@ def calculate_mttc(data_processed) -> dict:
         return {"MTTC": None}
 
 
-def calculate_thw(data_processed) -> dict:
+def calculate_thw(data_processed, plot_path) -> dict:
     """计算THW (Time Headway)"""
     if data_processed is None or not hasattr(data_processed, 'object_df'):
         return {"THW": None}
@@ -79,7 +79,7 @@ def calculate_thw(data_processed) -> dict:
         safety = SafetyCalculator(data_processed)
         thw_value = safety.get_thw_value()
         if safety.thw_data:
-            safety.generate_metric_chart('THW')
+            safety.generate_metric_chart('THW', plot_path)
         LogManager().get_logger().info(f"安全指标[THW]计算结果: {thw_value}")
         return {"THW": thw_value}
     except Exception as e:
@@ -87,7 +87,7 @@ def calculate_thw(data_processed) -> dict:
         return {"THW": None}
 
 
-def calculate_tlc(data_processed) -> dict:
+def calculate_tlc(data_processed, plot_path) -> dict:
     """计算TLC (Time to Line Crossing)"""
     if data_processed is None or not hasattr(data_processed, 'object_df'):
         return {"TLC": None}
@@ -95,7 +95,7 @@ def calculate_tlc(data_processed) -> dict:
         safety = SafetyCalculator(data_processed)
         tlc_value = safety.get_tlc_value()
         if safety.tlc_data:
-            safety.generate_metric_chart('TLC')
+            safety.generate_metric_chart('TLC', plot_path)
         LogManager().get_logger().info(f"安全指标[TLC]计算结果: {tlc_value}")
         return {"TLC": tlc_value}
     except Exception as e:
@@ -103,7 +103,7 @@ def calculate_tlc(data_processed) -> dict:
         return {"TLC": None}
 
 
-def calculate_ttb(data_processed) -> dict:
+def calculate_ttb(data_processed, plot_path) -> dict:
     """计算TTB (Time to Brake)"""
     if data_processed is None or not hasattr(data_processed, 'object_df'):
         return {"TTB": None}
@@ -111,7 +111,7 @@ def calculate_ttb(data_processed) -> dict:
         safety = SafetyCalculator(data_processed)
         ttb_value = safety.get_ttb_value()
         if safety.ttb_data:
-            safety.generate_metric_chart('TTB')
+            safety.generate_metric_chart('TTB', plot_path)
         LogManager().get_logger().info(f"安全指标[TTB]计算结果: {ttb_value}")
         return {"TTB": ttb_value}
     except Exception as e:
@@ -119,7 +119,7 @@ def calculate_ttb(data_processed) -> dict:
         return {"TTB": None}
 
 
-def calculate_tm(data_processed) -> dict:
+def calculate_tm(data_processed, plot_path) -> dict:
     """计算TM (Time Margin)"""
     if data_processed is None or not hasattr(data_processed, 'object_df'):
         return {"TM": None}
@@ -127,7 +127,7 @@ def calculate_tm(data_processed) -> dict:
         safety = SafetyCalculator(data_processed)
         tm_value = safety.get_tm_value()
         if safety.tm_data:
-            safety.generate_metric_chart('TM')
+            safety.generate_metric_chart('TM', plot_path)
         LogManager().get_logger().info(f"安全指标[TM]计算结果: {tm_value}")
         return {"TM": tm_value}
     except Exception as e:
@@ -135,7 +135,7 @@ def calculate_tm(data_processed) -> dict:
         return {"TM": None}
 
 
-def calculate_dtc(data_processed) -> dict:
+def calculate_dtc(data_processed, plot_path) -> dict:
     """计算DTC (Distance to Collision)"""
     if data_processed is None or not hasattr(data_processed, 'object_df'):
         return {"DTC": None}
@@ -149,7 +149,7 @@ def calculate_dtc(data_processed) -> dict:
         return {"DTC": None}
 
 
-def calculate_psd(data_processed) -> dict:
+def calculate_psd(data_processed, plot_path) -> dict:
     """计算PSD (Potential Safety Distance)"""
     if data_processed is None or not hasattr(data_processed, 'object_df'):
         return {"PSD": None}
@@ -163,7 +163,7 @@ def calculate_psd(data_processed) -> dict:
         return {"PSD": None}
 
 
-def calculate_collisionrisk(data_processed) -> dict:
+def calculate_collisionrisk(data_processed, plot_path) -> dict:
     """计算碰撞风险"""
     if data_processed is None or not hasattr(data_processed, 'object_df'):
         return {"collisionRisk": None}
@@ -171,7 +171,7 @@ def calculate_collisionrisk(data_processed) -> dict:
         safety = SafetyCalculator(data_processed)
         collision_risk_value = safety.get_collision_risk_value()
         if safety.collision_risk_data:
-            safety.generate_metric_chart('collisionRisk')
+            safety.generate_metric_chart('collisionRisk', plot_path)
         LogManager().get_logger().info(f"安全指标[collisionRisk]计算结果: {collision_risk_value}")
         return {"collisionRisk": collision_risk_value}
     except Exception as e:
@@ -179,17 +179,17 @@ def calculate_collisionrisk(data_processed) -> dict:
         return {"collisionRisk": None}
 
 
-def calculate_lonsd(data_processed) -> dict:
+def calculate_lonsd(data_processed, plot_path) -> dict:
     """计算纵向安全距离"""
     safety = SafetyCalculator(data_processed)
     lonsd_value = safety.get_lonsd_value()
     if safety.lonsd_data:
-        safety.generate_metric_chart('LonSD')
+        safety.generate_metric_chart('LonSD', plot_path)
     LogManager().get_logger().info(f"安全指标[LonSD]计算结果: {lonsd_value}")
     return {"LonSD": lonsd_value}
 
 
-def calculate_latsd(data_processed) -> dict:
+def calculate_latsd(data_processed, plot_path) -> dict:
     """计算横向安全距离"""
     if data_processed is None or not hasattr(data_processed, 'object_df'):
         return {"LatSD": None}
@@ -198,7 +198,7 @@ def calculate_latsd(data_processed) -> dict:
         latsd_value = safety.get_latsd_value()
         if safety.latsd_data:
             # 只生成图表,数据导出由chart_generator处理
-            safety.generate_metric_chart('LatSD')
+            safety.generate_metric_chart('LatSD', plot_path)
         LogManager().get_logger().info(f"安全指标[LatSD]计算结果: {latsd_value}")
         return {"LatSD": latsd_value}
     except Exception as e:
@@ -206,7 +206,7 @@ def calculate_latsd(data_processed) -> dict:
         return {"LatSD": None}
 
 
-def calculate_btn(data_processed) -> dict:
+def calculate_btn(data_processed, plot_path) -> dict:
     """计算制动威胁数"""
     if data_processed is None or not hasattr(data_processed, 'object_df'):
         return {"BTN": None}
@@ -215,7 +215,7 @@ def calculate_btn(data_processed) -> dict:
         btn_value = safety.get_btn_value()
         if safety.btn_data:
             # 只生成图表,数据导出由chart_generator处理
-            safety.generate_metric_chart('BTN')
+            safety.generate_metric_chart('BTN', plot_path)
         LogManager().get_logger().info(f"安全指标[BTN]计算结果: {btn_value}")
         return {"BTN": btn_value}
     except Exception as e:
@@ -223,7 +223,7 @@ def calculate_btn(data_processed) -> dict:
         return {"BTN": None}
 
 
-def calculate_collisionseverity(data_processed) -> dict:
+def calculate_collisionseverity(data_processed, plot_path) -> dict:
     """计算碰撞严重性"""
     if data_processed is None or not hasattr(data_processed, 'object_df'):
         return {"collisionSeverity": None}
@@ -232,7 +232,7 @@ def calculate_collisionseverity(data_processed) -> dict:
         collision_severity_value = safety.get_collision_severity_value()
         if safety.collision_severity_data:
             # 只生成图表,数据导出由chart_generator处理
-            safety.generate_metric_chart('collisionSeverity')
+            safety.generate_metric_chart('collisionSeverity', plot_path)
         LogManager().get_logger().info(f"安全指标[collisionSeverity]计算结果: {collision_severity_value}")
         return {"collisionSeverity": collision_severity_value}
     except Exception as e:
@@ -243,9 +243,10 @@ def calculate_collisionseverity(data_processed) -> dict:
 class SafetyRegistry:
     """安全指标注册器"""
 
-    def __init__(self, data_processed):
+    def __init__(self, data_processed, plot_path):
         self.logger = LogManager().get_logger()
         self.data = data_processed
+        self.plot_path = plot_path
         # 检查safety_config是否为空
         if not hasattr(data_processed, 'safety_config') or not data_processed.safety_config:
             self.logger.warning("安全配置为空,跳过安全指标计算")
@@ -290,10 +291,10 @@ class SafetyRegistry:
         if not hasattr(self, 'safety_config') or not self.safety_config or not self._registry:
             self.logger.info("安全配置为空或无注册指标,返回空结果")
             return results
-            
+
         for name, func in self._registry.items():
             try:
-                result = func(self.data)
+                result = func(self.data, self.plot_path)
                 results.update(result)
             except Exception as e:
                 self.logger.error(f"{name} 执行失败: {str(e)}", exc_info=True)
@@ -305,15 +306,16 @@ class SafetyRegistry:
 class SafeManager:
     """安全指标管理类"""
 
-    def __init__(self, data_processed):
+    def __init__(self, data_processed, plot_path):
         self.data = data_processed
         self.logger = LogManager().get_logger()
+        self.plot_path = plot_path
         # 检查safety_config是否为空
         if not hasattr(data_processed, 'safety_config') or not data_processed.safety_config:
             self.logger.warning("安全配置为空,跳过安全指标计算初始化")
             self.registry = None
         else:
-            self.registry = SafetyRegistry(self.data)
+            self.registry = SafetyRegistry(self.data, self.plot_path)
 
     def report_statistic(self):
         """计算并报告安全指标结果"""
@@ -321,7 +323,7 @@ class SafeManager:
         if self.registry is None:
             self.logger.info("安全指标管理器未初始化,返回空结果")
             return {}
-            
+
         safety_result = self.registry.batch_execute()
         return safety_result
 
@@ -336,7 +338,8 @@ class SafetyCalculator:
         self.df = data_processed.object_df.copy()
         self.ego_df = data_processed.ego_data.copy()  # 使用copy()避免修改原始数据
         self.obj_id_list = data_processed.obj_id_list
-        self.obj_df = self.df[self.df['playerId'] == 2].copy().reset_index(drop=True) if len(self.obj_id_list) > 1 else pd.DataFrame(columns = self.ego_df.columns) # 使用copy()避免修改原始数据
+        self.obj_df = self.df[self.df['playerId'] == 2].copy().reset_index(drop=True) if len(
+            self.obj_id_list) > 1 else pd.DataFrame(columns=self.ego_df.columns)  # 使用copy()避免修改原始数据
         self.metric_list = [
             'TTC', 'MTTC', 'THW', 'TLC', 'TTB', 'TM', 'DTC', 'PSD', 'LonSD', 'LatSD', 'BTN', 'collisionRisk',
             'collisionSeverity'
@@ -555,7 +558,7 @@ class SafetyCalculator:
                 lat_v = -v_x1 * math.sin(h1_rad) + v_y1 * math.cos(h1_rad)
 
                 obj_dict[frame_num][playerId]['lat_v_rel'] = lat_v - (
-                            -v_x2 * math.sin(h1_rad) + v_y2 * math.cos(h1_rad))
+                        -v_x2 * math.sin(h1_rad) + v_y2 * math.cos(h1_rad))
                 obj_dict[frame_num][playerId]['lon_v_rel'] = lon_v - (v_x2 * math.cos(h1_rad) + v_y2 * math.sin(h1_rad))
 
                 TTC = None if (TTC is None or TTC < 0) else TTC
@@ -715,7 +718,7 @@ class SafetyCalculator:
         dist = np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
         return dist
 
-    def generate_metric_chart(self, metric_name: str) -> None:
+    def generate_metric_chart(self, metric_name: str, plot_path: Path) -> None:
         """生成指标图表
 
         Args:
@@ -723,9 +726,11 @@ class SafetyCalculator:
         """
         try:
             # 确定输出目录
-            if self.output_dir is None:
+            if plot_path is None:
                 self.output_dir = os.path.join(os.getcwd(), 'data')
                 os.makedirs(self.output_dir, exist_ok=True)
+            else:
+                self.output_dir = plot_path
 
             # 调用图表生成函数
             chart_path = generate_safety_chart_data(self, metric_name, self.output_dir)

+ 106 - 97
modules/metric/traffic.py

@@ -1,6 +1,6 @@
 # ... 保留原有导入和常量定义 ...
 import math
-import operator 
+import operator
 import copy
 import numpy as np
 import pandas as pd
@@ -10,7 +10,6 @@ from modules.lib.score import Score
 from modules.lib.log_manager import LogManager
 from modules.lib import data_process
 
-
 OVERTAKE_INFO = [
     "simTime",
     "simFrame",
@@ -51,7 +50,7 @@ TURNAROUND_INFO = [
 
 TRFFICSIGN_INFO = [
     "simTime",
-    "simFrame",  
+    "simFrame",
     "playerId",
     "speedX",
     "speedY",
@@ -63,6 +62,8 @@ TRFFICSIGN_INFO = [
     "sign_x",
     "sign_y",
 ]
+
+
 # 修改指标函数名称为 calculate_xxx 格式
 def calculate_overtake_when_passing_car(data_processed):
     """计算会车时超车指标"""
@@ -266,6 +267,7 @@ def calculate_urbanexpresswayorhighwayridelanedivider(data_processed):
     urbanExpresswayOrHighwayRideLaneDivider_count = warningviolation.calculate_urbanExpresswayOrHighwayRideLaneDivider_count()
     return {"urbanExpresswayOrHighwayRideLaneDivider": urbanExpresswayOrHighwayRideLaneDivider_count}
 
+
 def calculate_nostraightthrough(data_processed):
     """计算禁止直行标志牌处直行指标"""
     trafficsignviolation = TrafficSignViolation(data_processed)
@@ -290,7 +292,7 @@ def calculate_minimumspeedlimitviolation(data_processed):
 # 修改 TrafficRegistry 类的 _build_registry 方法
 class TrafficRegistry:
     """交通违规指标注册器"""
-    
+
     def __init__(self, data_processed):
         self.logger = LogManager().get_logger()
         self.data = data_processed
@@ -304,20 +306,22 @@ class TrafficRegistry:
         self.traffic_config = data_processed.traffic_config.get("traffic", {})
         self.metrics = self._extract_metrics(self.traffic_config)
         self._registry = self._build_registry()
-    
+
     def _extract_metrics(self, config_node: dict) -> list:
         """从配置中提取指标名称"""
         metrics = []
+
         def _recurse(node):
             if isinstance(node, dict):
                 if 'name' in node and not any(isinstance(v, dict) for v in node.values()):
                     metrics.append(node['name'])
                 for v in node.values():
                     _recurse(v)
+
         _recurse(config_node)
         self.logger.info(f'评比的交通违规指标列表:{metrics}')
         return metrics
-    
+
     def _build_registry(self) -> dict:
         """构建指标函数注册表"""
         registry = {}
@@ -328,7 +332,7 @@ class TrafficRegistry:
             except KeyError:
                 self.logger.error(f"未实现交通违规指标函数: {func_name}")
         return registry
-    
+
     def batch_execute(self) -> dict:
         """批量执行指标计算"""
         results = {}
@@ -336,7 +340,7 @@ class TrafficRegistry:
         if not hasattr(self, 'traffic_config') or not self.traffic_config or not self._registry:
             self.logger.info("交通违规配置为空或无注册指标,返回空结果")
             return results
-            
+
         for name, func in self._registry.items():
             try:
                 result = func(self.data)
@@ -352,24 +356,25 @@ class TrafficRegistry:
 
 class TrafficManager:
     """交通违规指标管理类"""
-    
-    def __init__(self, data_processed):
+
+    def __init__(self, data_processed, plot_path):
         self.data = data_processed
         self.logger = LogManager().get_logger()
+        self.plot_path = plot_path
         # 检查traffic_config是否为空
         if not hasattr(data_processed, 'traffic_config') or not data_processed.traffic_config:
             self.logger.warning("交通违规配置为空,跳过交通违规指标计算初始化")
             self.registry = None
         else:
             self.registry = TrafficRegistry(self.data)
-    
+
     def report_statistic(self):
         """计算并报告交通违规指标结果"""
         # 如果registry为None,直接返回空字典
         if self.registry is None:
             self.logger.info("交通违规指标管理器未初始化,返回空结果")
             return {}
-            
+
         traffic_result = self.registry.batch_execute()
         return traffic_result
 
@@ -383,24 +388,24 @@ class OvertakingViolation(object):
 
         # 存储原始数据引用,不进行拷贝
         self._raw_data = df_data.obj_data[1]  # 自车数据
-        
+
         # 安全获取其他车辆数据
         self._data_obj = None
         self._other_obj_data1 = None
-        
+
         # 检查是否存在ID为2的对象数据
         if 2 in df_data.obj_id_list:
             self._data_obj = df_data.obj_data[2]
-        
+
         # 检查是否存在ID为3的对象数据
         if 3 in df_data.obj_id_list:
             self._other_obj_data1 = df_data.obj_data[3]
-        
+
         # 初始化属性,但不立即创建数据副本
         self._ego_data = None
         self._obj_data = None
         self._other_obj_data = None
-        
+
         # 使用字典统一管理违规计数器
         self.violation_counts = {
             "overtake_on_right": 0,
@@ -413,7 +418,7 @@ class OvertakingViolation(object):
             "overtake_on_decelerate_lane": 0,
             "overtake_in_different_senerios": 0
         }
-        
+
         # 标记计算状态
         self._calculated = {
             "illegal_overtake": False,
@@ -424,14 +429,14 @@ class OvertakingViolation(object):
             "decelerate_lane": False,
             "different_senerios": False
         }
-     
+
     @property
     def ego_data(self):
         """懒加载方式获取ego数据,只在首次访问时创建副本"""
         if self._ego_data is None:
             self._ego_data = self._raw_data[OVERTAKE_INFO].copy().reset_index(drop=True)
         return self._ego_data
-    
+
     @property
     def obj_data(self):
         """懒加载方式获取obj数据"""
@@ -442,7 +447,7 @@ class OvertakingViolation(object):
                 # 如果没有数据,创建一个空的DataFrame,列名与ego_data相同
                 self._obj_data = pd.DataFrame(columns=OVERTAKE_INFO)
         return self._obj_data
-    
+
     @property
     def other_obj_data(self):
         """懒加载方式获取other_obj数据"""
@@ -495,7 +500,7 @@ class OvertakingViolation(object):
         # 如果已经计算过,直接返回
         if self._calculated["illegal_overtake"]:
             return
-            
+
         # 如果没有其他车辆数据,直接返回,保持默认值0
         if self.obj_data.empty:
             print("没有其他车辆数据,无法检测超车违规,默认为0")
@@ -515,32 +520,32 @@ class OvertakingViolation(object):
             ego_data_frames = self.ego_data[
                 self.ego_data["simFrame"].isin(simframe_window)
             ]
-            
+
             # 确保有足够的数据进行处理
             if len(ego_data_frames) == 0:
                 start_frame_id += 1
                 continue
-                
+
             obj_data_frames = self.obj_data[
                 self.obj_data["simFrame"].isin(simframe_window)
             ]
-            
+
             # 如果没有其他车辆数据,跳过当前窗口
             if len(obj_data_frames) == 0:
                 start_frame_id += 1
                 continue
-                
+
             other_data_frames = self.other_obj_data[
                 self.other_obj_data["simFrame"].isin(simframe_window)
             ]
-            
+
             # 读取前后的laneId
             lane_id = ego_data_frames["lane_id"].tolist()
-            
+
             # 读取前后方向盘转角steeringWheel
             driverctrl_start_state = ego_data_frames["posH"].iloc[0]
             driverctrl_end_state = ego_data_frames["posH"].iloc[-1]
-            
+
             # 读取车辆前后的位置信息
             dx, dy = self._is_dxy_of_car(ego_data_frames, obj_data_frames)
             ego_speedx = ego_data_frames["speedX"].tolist()
@@ -554,7 +559,7 @@ class OvertakingViolation(object):
             else:
                 obj_speedx = []
                 obj_speedy = []
-                
+
             # 检查会车时超车
             if len(other_data_frames) > 0:
                 other_start_speedx = other_data_frames["speedX"].iloc[0]
@@ -569,7 +574,7 @@ class OvertakingViolation(object):
                     )
                     start_frame_id += window_width
                     continue
-            
+
             # 检查右侧超车
             if driverctrl_start_state > 0 and driverctrl_end_state < 0:
                 self.violation_counts["overtake_on_right"] += self._is_overtake(
@@ -577,7 +582,7 @@ class OvertakingViolation(object):
                 )
                 start_frame_id += window_width
                 continue
-            
+
             # 检查掉头时超车
             if obj_speedx and obj_speedy:  # 确保列表不为空
                 if ego_speedx[0] * obj_speedx[0] + ego_speedy[0] * obj_speedy[0] < 0:
@@ -586,10 +591,10 @@ class OvertakingViolation(object):
                     )
                     start_frame_id += window_width
                     continue
-            
+
             # 如果没有检测到任何违规,移动窗口
             start_frame_id += 1
-            
+
         self._calculated["illegal_overtake"] = True
 
     # 借道超车场景
@@ -598,13 +603,13 @@ class OvertakingViolation(object):
         # 如果已经计算过,直接返回
         if self._calculated["forbid_lane"]:
             return
-            
+
         # 如果没有其他车辆数据,直接返回,保持默认值0
         if self.obj_data.empty:
             print("没有其他车辆数据,无法检测借道超车违规,默认为0")
             self._calculated["forbid_lane"] = True
             return
-            
+
         simTime = self.obj_data["simTime"].tolist()
         simtime_devide = self.different_road_area_simtime(simTime)
         for simtime in simtime_devide:
@@ -617,7 +622,7 @@ class OvertakingViolation(object):
                     self.violation_counts["overtake_in_forbid_lane"] += 1
             except Exception as e:
                 print("数据缺少lane_type信息")
-                
+
         self._calculated["forbid_lane"] = True
 
     # 在匝道超车
@@ -626,13 +631,13 @@ class OvertakingViolation(object):
         # 如果已经计算过,直接返回
         if self._calculated["ramp_area"]:
             return
-            
+
         # 如果没有其他车辆数据,直接返回,保持默认值0
         if self.obj_data.empty:
             print("没有其他车辆数据,无法检测匝道超车违规,默认为0")
             self._calculated["ramp_area"] = True
             return
-            
+
         ramp_simtime_list = self.ego_data[(self.ego_data["road_type"] == 19)][
             "simTime"
         ].tolist()
@@ -652,7 +657,7 @@ class OvertakingViolation(object):
                 )
             else:
                 continue
-                
+
         self._calculated["ramp_area"] = True
 
     def overtake_in_tunnel_area_detector(self):
@@ -660,13 +665,13 @@ class OvertakingViolation(object):
         # 如果已经计算过,直接返回
         if self._calculated["tunnel_area"]:
             return
-            
+
         # 如果没有其他车辆数据,直接返回,保持默认值0
         if self.obj_data.empty:
             print("没有其他车辆数据,无法检测隧道超车违规,默认为0")
             self._calculated["tunnel_area"] = True
             return
-            
+
         tunnel_simtime_list = self.ego_data[(self.ego_data["road_type"] == 15)][
             "simTime"
         ].tolist()
@@ -686,7 +691,7 @@ class OvertakingViolation(object):
                 )
             else:
                 continue
-                
+
         self._calculated["tunnel_area"] = True
 
     # 加速车道超车
@@ -695,13 +700,13 @@ class OvertakingViolation(object):
         # 如果已经计算过,直接返回
         if self._calculated["accelerate_lane"]:
             return
-            
+
         # 如果没有其他车辆数据,直接返回,保持默认值0
         if self.obj_data.empty:
             print("没有其他车辆数据,无法检测加速车道超车违规,默认为0")
             self._calculated["accelerate_lane"] = True
             return
-            
+
         accelerate_simtime_list = self.ego_data[self.ego_data["lane_type"] == 2][
             "simTime"
         ].tolist()
@@ -723,7 +728,7 @@ class OvertakingViolation(object):
             self.violation_counts["overtake_on_accelerate_lane"] += self._is_overtake(
                 lane_id, dx, dy, ego_speedx, ego_speedy
             )
-            
+
         self._calculated["accelerate_lane"] = True
 
     # 减速车道超车
@@ -732,13 +737,13 @@ class OvertakingViolation(object):
         # 如果已经计算过,直接返回
         if self._calculated["decelerate_lane"]:
             return
-            
+
         # 如果没有其他车辆数据,直接返回,保持默认值0
         if self.obj_data.empty:
             print("没有其他车辆数据,无法检测减速车道超车违规,默认为0")
             self._calculated["decelerate_lane"] = True
             return
-            
+
         decelerate_simtime_list = self.ego_data[(self.ego_data["lane_type"] == 3)][
             "simTime"
         ].tolist()
@@ -760,7 +765,7 @@ class OvertakingViolation(object):
             self.violation_counts["overtake_on_decelerate_lane"] += self._is_overtake(
                 lane_id, dx, dy, ego_speedx, ego_speedy
             )
-            
+
         self._calculated["decelerate_lane"] = True
 
     # 在交叉路口
@@ -769,13 +774,13 @@ class OvertakingViolation(object):
         # 如果已经计算过,直接返回
         if self._calculated["different_senerios"]:
             return
-            
+
         # 如果没有其他车辆数据,直接返回,保持默认值0
         if self.obj_data.empty:
             print("没有其他车辆数据,无法检测不同场景超车违规,默认为0")
             self._calculated["different_senerios"] = True
             return
-            
+
         crossroad_simTime = self.ego_data[self.ego_data["interid"] != 10000][
             "simTime"
         ].tolist()  # 判断是路口或者隧道区域
@@ -800,7 +805,7 @@ class OvertakingViolation(object):
             self.violation_counts["overtake_in_different_senerios"] += self._is_overtake(
                 lane_id, dx, dy, ego_speedx, ego_speedy
             )
-            
+
         self._calculated["different_senerios"] = True
 
     def calculate_overtake_when_passing_car_count(self):
@@ -855,33 +860,33 @@ class SlowdownViolation(object):
     def __init__(self, df_data):
         print("减速让行违规类-------------------------")
         self.traffic_violations_type = "减速让行违规类"
-        
+
         # 存储原始数据引用
         self._raw_data = df_data.obj_data[1]
         self.object_items = set(df_data.object_df.type.tolist())
-        
+
         # 存储行人数据引用
         self._pedestrian_df = None
         if 13 in self.object_items:  # 行人的type是13
             self._pedestrian_df = df_data.object_df[df_data.object_df.type == 13]
-        
+
         # 初始化属性,但不立即创建数据副本
         self._ego_data = None
         self._pedestrian_data = None
-        
+
         # 初始化计数器
         self.slow_down_in_crosswalk_count = 0
         self.avoid_pedestrian_in_crosswalk_count = 0
         self.avoid_pedestrian_in_the_road_count = 0
         self.aviod_pedestrian_when_turning_count = 0
-    
+
     @property
     def ego_data(self):
         """懒加载方式获取ego数据"""
         if self._ego_data is None:
             self._ego_data = self._raw_data[SLOWDOWN_INFO].copy().reset_index(drop=True)
         return self._ego_data
-    
+
     @property
     def pedestrian_data(self):
         """懒加载方式获取行人数据"""
@@ -1074,28 +1079,28 @@ class TurnaroundViolation(object):
         # 存储原始数据引用
         self._raw_data = df_data.obj_data[1]
         self.object_items = set(df_data.object_df.type.tolist())
-        
+
         # 存储行人数据引用
         self._pedestrian_df = None
         if 13 in self.object_items:  # 行人的type是13
             self._pedestrian_df = df_data.object_df[df_data.object_df.type == 13]
-        
+
         # 初始化属性,但不立即创建数据副本
         self._ego_data = None
         self._pedestrian_data = None
-        
+
         # 初始化计数器
         self.turning_in_forbiden_turn_back_sign_count = 0
         self.turning_in_forbiden_turn_left_sign_count = 0
         self.avoid_pedestrian_when_turn_back_count = 0
-    
+
     @property
     def ego_data(self):
         """懒加载方式获取ego数据"""
         if self._ego_data is None:
             self._ego_data = self._raw_data[TURNAROUND_INFO].copy().reset_index(drop=True)
         return self._ego_data
-    
+
     @property
     def pedestrian_data(self):
         """懒加载方式获取行人数据"""
@@ -1235,26 +1240,27 @@ class TurnaroundViolation(object):
         self.avoid_pedestrian_when_turn_back_detector()
         return self.avoid_pedestrian_when_turn_back_count
 
+
 class WrongWayViolation(object):
     """停车违规类"""
 
     def __init__(self, df_data):
         print("停车违规类初始化中...")
         self.traffic_violations_type = "停车违规类"
-        
+
         # 存储原始数据引用
         self._raw_data = df_data.obj_data[1]
-        
+
         # 初始化属性,但不立即创建数据副本
         self._data = None
-        
+
         # 初始化违规统计
         self.violation_count = {
             "urbanExpresswayOrHighwayDrivingLaneStopped": 0,
             "urbanExpresswayOrHighwayEmergencyLaneStopped": 0,
             "urbanExpresswayEmergencyLaneDriving": 0,
         }
-    
+
     @property
     def data(self):
         """懒加载方式获取数据"""
@@ -1321,6 +1327,7 @@ class WrongWayViolation(object):
         self.process_violations()
         return self.violation_count["urbanExpresswayEmergencyLaneDriving"]
 
+
 class SpeedingViolation(object):
     """超速违规类"""
 
@@ -1331,7 +1338,7 @@ class SpeedingViolation(object):
         self.traffic_violations_type = "超速违规类"
         # 存储原始数据引用
         self._raw_data = df_data.obj_data[1]
-        
+
         # 初始化属性,但不立即创建数据副本
         self._data = None
         # 初始化违规统计
@@ -1439,6 +1446,7 @@ class SpeedingViolation(object):
         return self.violation_counts["generalRoadSpeedOverLimit20to50"] if self.violation_counts.get(
             "generalRoadSpeedOverLimit20to50") else 0
 
+
 class TrafficLightViolation(object):
     """违反交通灯类"""
 
@@ -1600,26 +1608,27 @@ class TrafficLightViolation(object):
         self.process_violations()
         return self.violation_counts["illegalDrivingOrParkingAtCrossroads"]
 
+
 class WarningViolation(object):
     """警告性违规类"""
 
     def __init__(self, df_data):
         print("警告性违规类初始化中...")
         self.traffic_violations_type = "警告性违规类"
-        
+
         # 存储原始数据引用
         self.config = df_data.vehicle_config
         self._raw_data = df_data.obj_data[1]
-        
+
         # 初始化属性,但不立即创建数据副本
         self._data = None
-        
+
         # 初始化违规计数器
         self.violation_counts = {
             "generalRoadIrregularLaneUse": 0,  # 驾驶机动车在高速公路、城市快速路以外的道路上不按规定车道行驶
             "urbanExpresswayOrHighwayRideLaneDivider": 0,  # 机动车在高速公路或者城市快速路上骑、轧车行道分界线
         }
-    
+
     @property
     def data(self):
         """懒加载方式获取数据"""
@@ -1632,7 +1641,7 @@ class WarningViolation(object):
         """处理所有违规类型"""
         # 处理普通道路不按规定车道行驶违规
         self._process_irregular_lane_use()
-        
+
         # 处理骑、轧车行道分界线违规
         self._process_lane_divider_violation()
 
@@ -1640,21 +1649,21 @@ class WarningViolation(object):
         """处理普通道路不按规定车道行驶违规"""
         # 定义道路和车道类型
         general_road = {3}  # 普通道路
-        lane_type = {11}    # 非机动车道
-        
+        lane_type = {11}  # 非机动车道
+
         # 使用布尔索引来筛选满足条件的行
         condition = (self.data["road_fc"].isin(general_road)) & (
             self.data["lane_type"].isin(lane_type)
         )
-        
+
         # 创建一个新的列,并根据条件设置值
         self.data["is_violation"] = condition
-        
+
         # 统计满足条件的连续时间段
         violation_segments = self.count_continuous_violations(
             self.data["is_violation"], self.data["simTime"]
         )
-        
+
         # 更新违规计数
         self.violation_counts["generalRoadIrregularLaneUse"] = len(violation_segments)
 
@@ -1663,18 +1672,18 @@ class WarningViolation(object):
         # 获取车辆和车道宽度
         car_width = self.config["CAR_WIDTH"]
         lane_width = self.data["lane_width"]
-        
+
         # 计算阈值
         threshold = (lane_width - car_width) / 2
-        
+
         # 找到满足条件的行
         self.data["is_violation"] = self.data["laneOffset"] > threshold
-        
+
         # 统计满足条件的连续时间段
         violation_segments = self.count_continuous_violations(
             self.data["is_violation"], self.data["simTime"]
         )
-        
+
         # 更新违规计数
         self.violation_counts["urbanExpresswayOrHighwayRideLaneDivider"] = len(
             violation_segments
@@ -1682,11 +1691,11 @@ class WarningViolation(object):
 
     def count_continuous_violations(self, violation_series, time_series):
         """统计连续违规的时间段数量
-        
+
         Args:
             violation_series: 表示是否违规的布尔序列
             time_series: 对应的时间序列
-            
+
         Returns:
             list: 连续违规时间段列表
         """
@@ -1722,9 +1731,10 @@ class WarningViolation(object):
         self._process_lane_divider_violation()
         return self.violation_counts["urbanExpresswayOrHighwayRideLaneDivider"]
 
+
 class TrafficSignViolation:
     """交通标志违规类"""
-    
+
     PROHIBITED_STRAIGHT_THRESHOLD = 5
     SIGN_TYPE_STRAIGHT_PROHIBITED = 7
     SIGN_TYPE_SPEED_LIMIT = 12
@@ -1733,13 +1743,13 @@ class TrafficSignViolation:
     def __init__(self, df_data):
         print("交通标志违规类初始化中...")
         self.traffic_violations_type = "交通标志违规类"
-        
+
         # 存储原始数据引用
         self._raw_data = df_data.obj_data[1]
-        
+
         # 初始化属性,但不立即创建数据副本
         self._data = None
-        
+
         # 延迟计算标志
         self._calculated = False
         self._violation_counts = {
@@ -1747,7 +1757,7 @@ class TrafficSignViolation:
             "SpeedLimitViolation": 0,
             "MinimumSpeedLimitViolation": 0
         }
-    
+
     @property
     def data(self):
         """懒加载方式获取数据"""
@@ -1800,30 +1810,30 @@ class TrafficSignViolation:
     def _check_straight_violation(self):
         """检查禁止直行违规"""
         straight_df = self.data[self.data["sign_type1"] == self.SIGN_TYPE_STRAIGHT_PROHIBITED]
-        
+
         if not straight_df.empty:
             # 计算航向角变化并填充缺失值
             straight_df = straight_df.copy()
             straight_df['posH_diff'] = straight_df['posH'].diff().abs().fillna(0)
-            
+
             # 创建筛选条件
             mask = (
-                (straight_df['posH_diff'] <= self.PROHIBITED_STRAIGHT_THRESHOLD) &
-                (straight_df['v'] > 0)
+                    (straight_df['posH_diff'] <= self.PROHIBITED_STRAIGHT_THRESHOLD) &
+                    (straight_df['v'] > 0)
             )
-            
+
             self._violation_counts["NoStraightThrough"] = mask.sum()
 
     def _check_speed_violation(self, sign_type, compare_op, count_key):
         """通用速度违规检查方法
-        
+
         Args:
             sign_type: 标志类型
             compare_op: 比较操作符
             count_key: 违规计数键名
         """
         violation_df = self.data[self.data["sign_type1"] == sign_type]
-        
+
         if not violation_df.empty:
             mask = compare_op(violation_df['v'], violation_df['sign_speed'])
             self._violation_counts[count_key] = mask.sum()
@@ -1834,4 +1844,3 @@ class TrafficSignViolation:
 
 
 
- 

+ 149 - 124
scripts/evaluator_enhanced.py

@@ -17,7 +17,6 @@ import traceback
 import json
 import inspect
 
-
 # 常量定义
 DEFAULT_WORKERS = 4
 CUSTOM_METRIC_PREFIX = "metric_"
@@ -31,16 +30,17 @@ else:
 
 sys.path.insert(0, str(_ROOT_PATH))
 
+
 class ConfigManager:
     """配置管理组件"""
-    
+
     def __init__(self, logger: logging.Logger):
         self.logger = logger
         self.base_config: Dict[str, Any] = {}
         self.custom_config: Dict[str, Any] = {}
         self.merged_config: Dict[str, Any] = {}
         self._config_cache = {}
-    
+
     def split_configs(self, all_metrics_path: Path, builtin_metrics_path: Path, custom_metrics_path: Path) -> None:
         """从all_metrics_config.yaml拆分成内置和自定义配置"""
         # 检查是否已经存在提取的配置文件,如果存在则跳过拆分过程
@@ -48,36 +48,36 @@ class ConfigManager:
         if extracted_builtin_path.exists() and custom_metrics_path.exists():
             self.logger.info(f"使用已存在的拆分配置文件: {extracted_builtin_path}")
             return
-            
+
         try:
             # 使用缓存加载配置文件,避免重复读取
             all_metrics_dict = self._safe_load_config(all_metrics_path)
             builtin_metrics_dict = self._safe_load_config(builtin_metrics_path)
-            
+
             # 递归提取内置和自定义指标
             extracted_builtin_metrics, custom_metrics_dict = self._split_metrics_recursive(
                 all_metrics_dict, builtin_metrics_dict
             )
-            
+
             # 保存提取的内置指标到新文件
             with open(extracted_builtin_path, 'w', encoding='utf-8') as f:
                 yaml.dump(extracted_builtin_metrics, f, allow_unicode=True, sort_keys=False, indent=2)
             self.logger.info(f"拆分配置: 提取的内置指标已保存到 {extracted_builtin_path}")
-            
+
             if custom_metrics_dict:
                 with open(custom_metrics_path, 'w', encoding='utf-8') as f:
                     yaml.dump(custom_metrics_dict, f, allow_unicode=True, sort_keys=False, indent=2)
                 self.logger.info(f"拆分配置: 自定义指标已保存到 {custom_metrics_path}")
-                
+
         except Exception as err:
             self.logger.error(f"拆分配置失败: {str(err)}")
             raise
-    
+
     def _split_metrics_recursive(self, all_dict: Dict, builtin_dict: Dict) -> Tuple[Dict, Dict]:
         """递归拆分内置和自定义指标配置"""
         extracted_builtin = {}
         custom_metrics = {}
-        
+
         for key, value in all_dict.items():
             if key in builtin_dict:
                 # 如果是字典类型,继续递归
@@ -93,31 +93,32 @@ class ConfigManager:
             else:
                 # 如果键不在内置配置中,归类为自定义指标
                 custom_metrics[key] = value
-        
+
         return extracted_builtin, custom_metrics
-    
-    def load_configs(self, all_config_path: Optional[Path], builtin_metrics_path: Optional[Path], custom_metrics_path: Optional[Path]) -> Dict[str, Any]:
+
+    def load_configs(self, all_config_path: Optional[Path], builtin_metrics_path: Optional[Path],
+                     custom_metrics_path: Optional[Path]) -> Dict[str, Any]:
         """加载并合并配置"""
         # 如果已经加载过配置,直接返回缓存的结果
         cache_key = f"{all_config_path}_{builtin_metrics_path}_{custom_metrics_path}"
         if cache_key in self._config_cache:
             self.logger.info("使用缓存的配置数据")
             return self._config_cache[cache_key]
-            
+
         # 自动拆分配置
         extracted_builtin_path = None
-        
+
         if all_config_path and all_config_path.exists():
             # 生成提取的内置指标配置文件路径
             extracted_builtin_path = builtin_metrics_path.parent / f"{builtin_metrics_path.stem}_extracted{builtin_metrics_path.suffix}"
             self.split_configs(all_config_path, builtin_metrics_path, custom_metrics_path)
-            
+
         # 优先使用提取的内置指标配置
         if extracted_builtin_path and extracted_builtin_path.exists():
             self.base_config = self._safe_load_config(extracted_builtin_path)
         else:
             self.base_config = self._safe_load_config(builtin_metrics_path) if builtin_metrics_path else {}
-            
+
         self.custom_config = self._safe_load_config(custom_metrics_path) if custom_metrics_path else {}
         if all_config_path and all_config_path.exists():
             self.merged_config = self._safe_load_config(all_config_path)
@@ -125,7 +126,7 @@ class ConfigManager:
             self._config_cache[cache_key] = self.merged_config
             return self.merged_config
         return {}
-    
+
     @lru_cache(maxsize=16)
     def _safe_load_config(self, config_path: Path) -> Dict[str, Any]:
         """安全加载YAML配置,使用lru_cache减少重复读取"""
@@ -140,25 +141,26 @@ class ConfigManager:
         except Exception as err:
             self.logger.error(f"Failed to load config {config_path}: {str(err)}")
             return {}
-    
+
     def get_config(self) -> Dict[str, Any]:
         return self.merged_config
-    
+
     def get_base_config(self) -> Dict[str, Any]:
         return self.base_config
-    
+
     def get_custom_config(self) -> Dict[str, Any]:
         return self.custom_config
 
+
 class MetricLoader:
     """指标加载器组件"""
-    
+
     def __init__(self, logger: logging.Logger, config_manager: ConfigManager):
         self.logger = logger
         self.config_manager = config_manager
         self.metric_modules: Dict[str, Type] = {}
         self.custom_metric_modules: Dict[str, Any] = {}
-    
+
     def load_builtin_metrics(self) -> Dict[str, Type]:
         """加载内置指标模块"""
         module_mapping = {
@@ -168,15 +170,15 @@ class MetricLoader:
             "efficient": ("modules.metric.efficient", "EfficientManager"),
             "function": ("modules.metric.function", "FunctionManager"),
         }
-        
+
         self.metric_modules = {
             name: self._load_module(*info)
             for name, info in module_mapping.items()
         }
-        
+
         self.logger.info(f"Loaded builtin metrics: {', '.join(self.metric_modules.keys())}")
         return self.metric_modules
-    
+
     @lru_cache(maxsize=32)
     def _load_module(self, module_path: str, class_name: str) -> Type:
         """动态加载Python模块"""
@@ -186,7 +188,7 @@ class MetricLoader:
         except (ImportError, AttributeError) as e:
             self.logger.error(f"Failed to load module: {module_path}.{class_name} - {str(e)}")
             raise
-    
+
     def load_custom_metrics(self, custom_metrics_path: Optional[Path]) -> Dict[str, Any]:
         """加载自定义指标模块"""
         if not custom_metrics_path or not custom_metrics_path.is_dir():
@@ -194,10 +196,10 @@ class MetricLoader:
             return {}
 
         # 检查是否有新的自定义指标文件
-        current_files = set(f.name for f in custom_metrics_path.glob(CUSTOM_METRIC_FILE_PATTERN) 
-                          if f.name.startswith(CUSTOM_METRIC_PREFIX))
+        current_files = set(f.name for f in custom_metrics_path.glob(CUSTOM_METRIC_FILE_PATTERN)
+                            if f.name.startswith(CUSTOM_METRIC_PREFIX))
         loaded_files = set(self.custom_metric_modules.keys())
-        
+
         # 如果没有新文件且已有加载的模块,直接返回
         if self.custom_metric_modules and not (current_files - loaded_files):
             self.logger.info(f"No new custom metrics to load, using {len(self.custom_metric_modules)} cached modules")
@@ -208,30 +210,30 @@ class MetricLoader:
             if py_file.name.startswith(CUSTOM_METRIC_PREFIX):
                 if self._process_custom_metric_file(py_file):
                     loaded_count += 1
-        
+
         self.logger.info(f"Loaded {loaded_count} custom metric modules")
         return self.custom_metric_modules
-    
+
     def _process_custom_metric_file(self, file_path: Path) -> bool:
         """处理单个自定义指标文件"""
         try:
             metric_key = self._validate_metric_file(file_path)
-            
+
             module_name = f"custom_metric_{file_path.stem}"
             spec = importlib.util.spec_from_file_location(module_name, file_path)
             module = importlib.util.module_from_spec(spec)
             spec.loader.exec_module(module)
-            
+
             from modules.lib.metric_registry import BaseMetric
             metric_class = None
-            
+
             for name, obj in inspect.getmembers(module):
-                if (inspect.isclass(obj) and 
-                    issubclass(obj, BaseMetric) and 
-                    obj != BaseMetric):
+                if (inspect.isclass(obj) and
+                        issubclass(obj, BaseMetric) and
+                        obj != BaseMetric):
                     metric_class = obj
                     break
-            
+
             if metric_class:
                 self.custom_metric_modules[metric_key] = {
                     'type': 'class',
@@ -247,7 +249,7 @@ class MetricLoader:
                 self.logger.info(f"Loaded function-based custom metric: {metric_key}")
             else:
                 raise AttributeError(f"Missing evaluate() function or BaseMetric subclass: {file_path.name}")
-                
+
             return True
         except ValueError as e:
             self.logger.warning(str(e))
@@ -255,24 +257,25 @@ class MetricLoader:
         except Exception as e:
             self.logger.error(f"Failed to load custom metric {file_path}: {str(e)}")
             return False
-    
+
     def _validate_metric_file(self, file_path: Path) -> str:
         """验证自定义指标文件命名规范"""
         stem = file_path.stem[len(CUSTOM_METRIC_PREFIX):]
         parts = stem.split('_')
         if len(parts) < 3:
-            raise ValueError(f"Invalid custom metric filename: {file_path.name}, should be metric_<level1>_<level2>_<level3>.py")
+            raise ValueError(
+                f"Invalid custom metric filename: {file_path.name}, should be metric_<level1>_<level2>_<level3>.py")
 
         level1, level2, level3 = parts[:3]
         if not self._is_metric_configured(level1, level2, level3):
             raise ValueError(f"Unconfigured metric: {level1}.{level2}.{level3}")
         return f"{level1}.{level2}.{level3}"
-    
+
     def _is_metric_configured(self, level1: str, level2: str, level3: str) -> bool:
         """检查指标是否在配置中注册"""
         custom_config = self.config_manager.get_custom_config()
         try:
-            return (level1 in custom_config and 
+            return (level1 in custom_config and
                     isinstance(custom_config[level1], dict) and
                     level2 in custom_config[level1] and
                     isinstance(custom_config[level1][level2], dict) and
@@ -280,54 +283,58 @@ class MetricLoader:
                     isinstance(custom_config[level1][level2][level3], dict))
         except Exception:
             return False
-    
+
     def get_builtin_metrics(self) -> Dict[str, Type]:
         return self.metric_modules
-    
+
     def get_custom_metrics(self) -> Dict[str, Any]:
         return self.custom_metric_modules
 
+
 class EvaluationEngine:
     """评估引擎组件"""
-    
-    def __init__(self, logger: logging.Logger, config_manager: ConfigManager, metric_loader: MetricLoader):
+
+    def __init__(self, logger: logging.Logger, config_manager: ConfigManager, metric_loader: MetricLoader,
+                 plot_path: str):
         self.logger = logger
         self.config_manager = config_manager
         self.metric_loader = metric_loader
-    
+        self.plot_path = plot_path
+
     def evaluate(self, data: Any) -> Dict[str, Any]:
         """执行评估流程"""
         raw_results = self._collect_builtin_metrics(data)
         custom_results = self._collect_custom_metrics(data)
         return self._process_merged_results(raw_results, custom_results)
-    
+
     def _collect_builtin_metrics(self, data: Any) -> Dict[str, Any]:
         """收集内置指标结果"""
         metric_modules = self.metric_loader.get_builtin_metrics()
+        x = metric_modules.items()
         raw_results: Dict[str, Any] = {}
-        
+
         # 获取配置中实际存在的指标
         config = self.config_manager.get_config()
         available_metrics = {
             metric_name for metric_name in metric_modules.keys()
             if metric_name in config and isinstance(config[metric_name], dict)
         }
-        
+
         # 只处理配置中存在的指标
         filtered_modules = {
             name: module for name, module in metric_modules.items()
             if name in available_metrics
         }
-        
+
         # 优化线程池大小,避免创建过多线程
         max_workers = min(len(filtered_modules), DEFAULT_WORKERS)
-        
+
         with ThreadPoolExecutor(max_workers=max_workers) as executor:
             futures = {
-                executor.submit(self._run_module, module, data, module_name): module_name
-                for module_name, module in filtered_modules.items()
+                executor.submit(self._run_module, module, data, module_name, self.plot_path): module_name for
+                module_name, module in filtered_modules.items()
             }
-    
+
             for future in futures:
                 module_name = futures[future]
                 try:
@@ -343,27 +350,27 @@ class EvaluationEngine:
                         "message": str(e),
                         "timestamp": datetime.now().isoformat(),
                     }
-        
+
         return raw_results
-    
+
     def _collect_custom_metrics(self, data: Any) -> Dict[str, Dict]:
         """收集自定义指标结果"""
         custom_metrics = self.metric_loader.get_custom_metrics()
         if not custom_metrics:
             return {}
-            
+
         custom_results = {}
-        
+
         # 使用线程池并行处理自定义指标
         max_workers = min(len(custom_metrics), DEFAULT_WORKERS)
-        
+
         with ThreadPoolExecutor(max_workers=max_workers) as executor:
             futures = {}
-            
+
             # 提交所有自定义指标任务
             for metric_key, metric_info in custom_metrics.items():
                 futures[executor.submit(self._run_custom_metric, metric_key, metric_info, data)] = metric_key
-            
+
             # 收集结果
             for future in futures:
                 metric_key = futures[future]
@@ -373,14 +380,14 @@ class EvaluationEngine:
                         custom_results[level1] = result
                 except Exception as e:
                     self.logger.error(f"Custom metric {metric_key} execution failed: {str(e)}")
-        
+
         return custom_results
-        
+
     def _run_custom_metric(self, metric_key: str, metric_info: Dict, data: Any) -> Tuple[str, Dict]:
         """执行单个自定义指标"""
         try:
             level1, level2, level3 = metric_key.split('.')
-            
+
             if metric_info['type'] == 'class':
                 metric_class = metric_info['class']
                 metric_instance = metric_class(data)
@@ -388,10 +395,10 @@ class EvaluationEngine:
             else:
                 module = metric_info['module']
                 metric_result = module.evaluate(data)
-            
+
             self.logger.info(f"Calculated custom metric: {level1}.{level2}.{level3}")
             return level1, metric_result
-            
+
         except Exception as e:
             self.logger.error(f"Custom metric {metric_key} failed: {str(e)}")
             try:
@@ -403,7 +410,7 @@ class EvaluationEngine:
                 }
             except Exception:
                 return "", {}
-    
+
     def _process_merged_results(self, raw_results: Dict, custom_results: Dict) -> Dict:
         """处理合并后的评估结果"""
         from modules.lib.score import Score
@@ -429,30 +436,31 @@ class EvaluationEngine:
                     final_results[level1] = self._format_error(e)
 
         return final_results
-        
+
     def _format_error(self, e: Exception) -> Dict:
         return {
             "status": "error",
             "message": str(e),
             "timestamp": datetime.now().isoformat()
         }
-                
-    def _run_module(self, module_class: Any, data: Any, module_name: str) -> Dict[str, Any]:
+
+    def _run_module(self, module_class: Any, data: Any, module_name: str, plot_path: str) -> Dict[str, Any]:
         """执行单个评估模块"""
         try:
-            instance = module_class(data)
+            instance = module_class(data, plot_path)
             return {module_name: instance.report_statistic()}
         except Exception as e:
             self.logger.error(f"{module_name} execution error: {str(e)}", exc_info=True)
             return {module_name: {"error": str(e)}}
 
+
 class LoggingManager:
     """日志管理组件"""
-    
+
     def __init__(self, log_path: Path):
         self.log_path = log_path
         self.logger = self._init_logger()
-    
+
     def _init_logger(self) -> logging.Logger:
         """初始化日志系统"""
         try:
@@ -467,27 +475,28 @@ class LoggingManager:
             logger.addHandler(console_handler)
             logger.warning(f"Failed to init standard logger: {str(e)}, using fallback logger")
             return logger
-    
+
     def get_logger(self) -> logging.Logger:
         return self.logger
 
+
 class DataProcessor:
     """数据处理组件"""
-    
+
     def __init__(self, logger: logging.Logger, data_path: Path, config_path: Optional[Path] = None):
         self.logger = logger
         self.data_path = data_path
         self.config_path = config_path
         self.case_name = self.data_path.name
         self._processor = None
-    
+
     @property
     def processor(self) -> Any:
         """懒加载数据处理器,只在首次访问时创建"""
         if self._processor is None:
             self._processor = self._load_processor()
         return self._processor
-    
+
     def _load_processor(self) -> Any:
         """加载数据处理器"""
         try:
@@ -500,7 +509,7 @@ class DataProcessor:
         except ImportError as e:
             self.logger.error(f"Failed to load data processor: {str(e)}")
             raise RuntimeError(f"Failed to load data processor: {str(e)}") from e
-    
+
     def validate(self) -> None:
         """验证数据路径"""
         if not self.data_path.exists():
@@ -508,10 +517,12 @@ class DataProcessor:
         if not self.data_path.is_dir():
             raise NotADirectoryError(f"Invalid data directory: {self.data_path}")
 
+
 class EvaluationPipeline:
     """评估流水线控制器"""
-    
-    def __init__(self, all_config_path: str, base_config_path: str, log_path: str, data_path: str, report_path: str, 
+
+    def __init__(self, all_config_path: str, base_config_path: str, log_path: str, data_path: str, report_path: str,
+                 plot_path: str,
                  custom_metrics_path: Optional[str] = None, custom_config_path: Optional[str] = None):
         # 路径初始化
         self.all_config_path = Path(all_config_path) if all_config_path else None
@@ -519,8 +530,9 @@ class EvaluationPipeline:
         self.custom_config_path = Path(custom_config_path) if custom_config_path else None
         self.data_path = Path(data_path)
         self.report_path = Path(report_path)
+        self.plot_path = Path(plot_path)
         self.custom_metrics_path = Path(custom_metrics_path) if custom_metrics_path else None
-        
+
         # 日志
         self.logging_manager = LoggingManager(Path(log_path))
         self.logger = self.logging_manager.get_logger()
@@ -535,37 +547,38 @@ class EvaluationPipeline:
         self.metric_loader.load_custom_metrics(self.custom_metrics_path)
         # 数据处理
         self.data_processor = DataProcessor(self.logger, self.data_path, self.all_config_path)
-        self.evaluation_engine = EvaluationEngine(self.logger, self.config_manager, self.metric_loader)
-    
+        self.evaluation_engine = EvaluationEngine(self.logger, self.config_manager, self.metric_loader, self.plot_path)
+
     def execute(self) -> Dict[str, Any]:
         """执行评估流水线"""
         try:
             # 只在首次运行时验证数据路径
             self.data_processor.validate()
-            
+
             self.logger.info(f"Start evaluation: {self.data_path.name}")
             start_time = time.perf_counter()
-            
+
             # 性能分析日志
             config_start = time.perf_counter()
             results = self.evaluation_engine.evaluate(self.data_processor.processor)
             eval_time = time.perf_counter() - config_start
-            
+
             # 生成报告
             report_start = time.perf_counter()
             report = self._generate_report(self.data_processor.case_name, results)
             report_time = time.perf_counter() - report_start
-            
+
             # 总耗时
             elapsed_time = time.perf_counter() - start_time
-            self.logger.info(f"Evaluation completed, time: {elapsed_time:.2f}s (评估: {eval_time:.2f}s, 报告: {report_time:.2f}s)")
-            
+            self.logger.info(
+                f"Evaluation completed, time: {elapsed_time:.2f}s (评估: {eval_time:.2f}s, 报告: {report_time:.2f}s)")
+
             return report
-            
+
         except Exception as e:
             self.logger.critical(f"Evaluation failed: {str(e)}", exc_info=True)
             return {"error": str(e), "traceback": traceback.format_exc()}
-    
+
     def _add_overall_result(self, report: Dict[str, Any]) -> Dict[str, Any]:
         """处理评测报告并添加总体结果字段"""
         # 加载阈值参数
@@ -574,17 +587,17 @@ class EvaluationPipeline:
             "T1": self.config['T_threshold']['T1_threshold'],
             "T2": self.config['T_threshold']['T2_threshold']
         }
-        
+
         # 初始化计数器
         counters = {'p0': 0, 'p1': 0, 'p2': 0}
-        
+
         # 优化:一次性收集所有失败的指标
         failed_categories = [
             (category, category_data.get('priority'))
             for category, category_data in report.items()
             if isinstance(category_data, dict) and category != "metadata" and not category_data.get('result', True)
         ]
-        
+
         # 计数
         for _, priority in failed_categories:
             if priority == 0:
@@ -593,18 +606,18 @@ class EvaluationPipeline:
                 counters['p1'] += 1
             elif priority == 2:
                 counters['p2'] += 1
-        
+
         # 阈值判断逻辑
         overall_result = not (
-            counters['p0'] > thresholds['T0'] or
-            counters['p1'] > thresholds['T1'] or
-            counters['p2'] > thresholds['T2']
+                counters['p0'] > thresholds['T0'] or
+                counters['p1'] > thresholds['T1'] or
+                counters['p2'] > thresholds['T2']
         )
-        
+
         # 生成处理后的报告
         processed_report = report.copy()
         processed_report['overall_result'] = overall_result
-        
+
         # 添加统计信息
         processed_report['threshold_checks'] = {
             'T0_threshold': thresholds['T0'],
@@ -612,58 +625,62 @@ class EvaluationPipeline:
             'T2_threshold': thresholds['T2'],
             'actual_counts': counters
         }
-        
+
         self.logger.info(f"Added overall result: {overall_result}")
         return processed_report
-        
+
     def _generate_report(self, case_name: str, results: Dict[str, Any]) -> Dict[str, Any]:
         """生成评估报告"""
         from modules.lib.common import dict2json
-        
+
         self.report_path.mkdir(parents=True, exist_ok=True)
-        
+
         results["metadata"] = {
             "case_name": case_name,
             "timestamp": datetime.now().isoformat(),
             "version": "1.0",
         }
-        
+
         # 添加总体结果评估
         results = self._add_overall_result(results)
-        
+
         report_file = self.report_path / f"{case_name}_report.json"
         dict2json(results, report_file)
         self.logger.info(f"Report generated: {report_file}")
-        
+
         return results
 
+
 def main():
     """命令行入口"""
     parser = argparse.ArgumentParser(
         description="Autonomous Driving Evaluation System V3.1",
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
-    
+
     # 必要参数
     parser.add_argument(
         "--dataPath",
         type=str,
-        default=r"D:\Kevin\zhaoyuan\data\V2V_CSAE53-2020_ForwardCollision_LST_01-02",
+        # default=r"D:\Cicv\招远\V2V_CSAE53-2020_ForwardCollision_LST_01-02_new",
+        # default=r"D:\Cicv\招远\AD_GBT41798-2022_TrafficSignalRecognitionAndResponse_LST_01",
+        # default=r"/home/server/桌面/XGJ/zhaoyuan_DataPreProcess/output/AD_GBT41798-2022_TrafficSignalRecognitionAndResponse_LST_02",
+        default=r"/home/server/桌面/XGJ/zhaoyuan_DataPreProcess/output/V2I_CSAE53-2020_LeftTurnAssist_PGVIL_demo",
         help="Input data directory",
     )
-    
+
     # 配置参数
     config_group = parser.add_argument_group('Configuration')
     config_group.add_argument(
         "--allConfigPath",
         type=str,
-        default=r"config/all_metrics_config.yaml",
+        default=r"/home/server/anaconda3/envs/vitual_XGJ/zhaoyuan_0617/zhaoyuan/config/all_metrics_config.yaml",
         help="Full metrics config file path (built-in + custom)",
     )
     config_group.add_argument(
         "--baseConfigPath",
         type=str,
-        default=r"config/builtin_metrics_config.yaml",
+        default=r"/home/server/anaconda3/envs/vitual_XGJ/zhaoyuan_0617/zhaoyuan/config/all_metrics_config.yaml",
         help="Built-in metrics config file path",
     )
     config_group.add_argument(
@@ -672,7 +689,7 @@ def main():
         default=r"config/custom_metrics_config.yaml",
         help="Custom metrics config path (optional)",
     )
-    
+
     # 输出参数
     output_group = parser.add_argument_group('Output')
     output_group.add_argument(
@@ -687,7 +704,13 @@ def main():
         default="reports",
         help="Output report directory",
     )
-    
+    output_group.add_argument(
+        "--plotPath",
+        type=str,
+        default=r"/home/server/anaconda3/envs/vitual_XGJ/zhaoyuan_0617/zhaoyuan/scripts/reports/datas",
+        help="Output plot csv directory",
+    )
+
     # 扩展参数
     ext_group = parser.add_argument_group('Extensions')
     ext_group.add_argument(
@@ -696,20 +719,21 @@ def main():
         default="custom_metrics",
         help="Custom metrics scripts directory (optional)",
     )
-    
+
     args = parser.parse_args()
 
     try:
         pipeline = EvaluationPipeline(
             all_config_path=args.allConfigPath,
             base_config_path=args.baseConfigPath,
-            log_path=args.logPath, 
-            data_path=args.dataPath, 
-            report_path=args.reportPath, 
-            custom_metrics_path=args.customMetricsPath, 
+            log_path=args.logPath,
+            data_path=args.dataPath,
+            report_path=args.reportPath,
+            plot_path=args.plotPath,
+            custom_metrics_path=args.customMetricsPath,
             custom_config_path=args.customConfigPath
         )
-        
+
         start_time = time.perf_counter()
         result = pipeline.execute()
         elapsed_time = time.perf_counter() - start_time
@@ -720,7 +744,7 @@ def main():
 
         print(f"Evaluation completed, total time: {elapsed_time:.2f}s")
         print(f"Report path: {pipeline.report_path}")
-        
+
     except KeyboardInterrupt:
         print("\nUser interrupted")
         sys.exit(130)
@@ -729,6 +753,7 @@ def main():
         traceback.print_exc()
         sys.exit(1)
 
+
 if __name__ == "__main__":
     warnings.filterwarnings("ignore")
     main()