Browse Source

新增DTC指标

XGJ_zhaoyuan 2 weeks ago
parent
commit
568536053a
2 changed files with 92 additions and 3 deletions
  1. 5 0
      config/all_metrics_config.yaml
  2. 87 3
      modules/metric/safety.py

+ 5 - 0
config/all_metrics_config.yaml

@@ -58,6 +58,11 @@ safety:
       priority: 0
       max: 2000.0
       min: 1.5
+    DTC:
+      name: DTC
+      priority: 0
+      max: 2000.0
+      min: 1.5
   safeDistance:
     name: safeDistance
     priority: 0

+ 87 - 3
modules/metric/safety.py

@@ -110,6 +110,58 @@ def calculate_tm(data_processed) -> dict:
         LogManager().get_logger().error(f"TM计算异常: {str(e)}", exc_info=True)
         return {"TM": None}
 
+# def calculate_MPrTTC(data_processed) -> dict:
+#     """计算MPrTTC (Model Predictive Time-to-Collision)"""
+#     if data_processed is None or not hasattr(data_processed, 'object_df'):
+#         return {"MPrTTC": None}
+#     try:
+#         safety = SafetyCalculator(data_processed)
+#         mprttc_value = safety.get_mprttc_value()
+#         LogManager().get_logger().info(f"安全指标[MPrTTC]计算结果: {mprttc_value}")
+#         return {"MPrTTC": mprttc_value}
+#     except Exception as e:
+#         LogManager().get_logger().error(f"MPrTTC计算异常: {str(e)}", exc_info=True)
+#         return {"MPrTTC": None}
+
+# def calculate_pet(data_processed) -> dict:
+#     # PET (Post Encroachment Time)
+#     if data_processed is None or not hasattr(data_processed, 'object_df'):
+#         return {"PET": None}
+#     try:
+#         safety = SafetyCalculator(data_processed)
+#         pet_value = safety.get_pet_value()
+#         LogManager().get_logger().info(f"安全指标[PET]计算结果: {pet_value}")
+#         return {"PET": pet_value}
+#     except Exception as e:
+#         LogManager().get_logger().error(f"PET计算异常: {str(e)}", exc_info=True)
+#         return {"PET": None}
+
+def calculate_dtc(data_processed) -> dict:
+    """计算DTC (Distance to Collision)"""
+    if data_processed is None or not hasattr(data_processed, 'object_df'):
+        return {"DTC": None}
+    try:
+        safety = SafetyCalculator(data_processed)
+        dtc_value = safety.get_dtc_value()
+        LogManager().get_logger().info(f"安全指标[DTC]计算结果: {dtc_value}")
+        return {"DTC": dtc_value}
+    except Exception as e:
+        LogManager().get_logger().error(f"DTC计算异常: {str(e)}", exc_info=True)
+        return {"DTC": None}
+
+def calculate_dtc(data_processed) -> dict:
+    """计算DTC (Distance to Collision)"""
+    if data_processed is None or not hasattr(data_processed, 'object_df'):
+        return {"DTC": None}
+    try:
+        safety = SafetyCalculator(data_processed)
+        dtc_value = safety.get_dtc_value()
+        LogManager().get_logger().info(f"安全指标[DTC]计算结果: {dtc_value}")
+        return {"DTC": dtc_value}
+    except Exception as e:
+        LogManager().get_logger().error(f"DTC计算异常: {str(e)}", exc_info=True)
+        return {"DTC": None}
+
 def calculate_collisionrisk(data_processed) -> dict:
     """计算碰撞风险"""
     safety = SafetyCalculator(data_processed)
@@ -218,8 +270,9 @@ class SafetyCalculator:
         self.df = data_processed.object_df.copy()
         self.ego_df = data_processed.ego_data.copy()  # 使用copy()避免修改原始数据
         self.obj_id_list = data_processed.obj_id_list
+
         self.metric_list = [
-            'TTC', 'MTTC', 'THW', 'LonSD', 'LatSD', 'BTN', 'collisionRisk', 'collisionSeverity'
+            'TTC', 'MTTC', 'THW', 'TLC', 'TTB', 'TM', 'DTC', 'LonSD', 'LatSD', 'BTN', 'collisionRisk', 'collisionSeverity'
         ]
 
         # 初始化默认值
@@ -230,6 +283,9 @@ class SafetyCalculator:
             "TLC": 10.0,
             "TTB": 10.0,
             "TM": 10.0,
+            # "MPrTTC": 10.0,
+            # "PET": 10.0,
+            "DTC": 10.0,
             "LatSD": 3.0,
             "BTN": 1.0,
             "collisionRisk": 0.0,
@@ -269,6 +325,7 @@ class SafetyCalculator:
         ego_decel_min = self.data_processed.vehicle_config["EGO_DECEL_MIN"]
         ego_decel_lon_max = self.data_processed.vehicle_config["EGO_DECEL_LON_MAX"]
         ego_decel_lat_max = self.data_processed.vehicle_config["EGO_DECEL_LAT_MAX"]
+        driver_reaction_time = self.data_processed.vehicle_config["RHO"]
         ego_decel_max = np.sqrt(ego_decel_lon_max ** 2 + ego_decel_lat_max ** 2)
         x_relative_start_dist = self.ego_df["x_relative_start_dist"]
 
@@ -358,6 +415,10 @@ class SafetyCalculator:
                 TLC = self._cal_TLC(v1, h1, laneOffset)
                 TTB = self._cal_TTB(x_relative_start_dist, relative_v, ego_decel_max)
                 TM = self._cal_TM(x_relative_start_dist, v2, a2, v1, a1)
+                DTC = self._cal_DTC(vrel_projection_in_dist, arel_projection_in_dist, driver_reaction_time)
+                # MPrTTC = self._cal_MPrTTC(x_relative_start_dist)
+                # PET = self._cal_PET(x_relative_start_dist, v2, a2, v1, a1)
+
 
                 LonSD = self._cal_longitudinal_safe_dist(v_ego_p, v_obj_p, rho, ego_accel_max, ego_decel_min, obj_decel_max)
 
@@ -394,6 +455,7 @@ class SafetyCalculator:
                 TLC = None if (TLC is None or TLC < 0) else TLC
                 TTB = None if (TTB is None or TTB < 0) else TTB
                 TM = None if (TM is None or TM < 0) else TM
+                DTC = None if (DTC is None or DTC < 0) else DTC
 
                 obj_dict[frame_num][playerId]['TTC'] = TTC
                 obj_dict[frame_num][playerId]['MTTC'] = MTTC
@@ -401,6 +463,7 @@ class SafetyCalculator:
                 obj_dict[frame_num][playerId]['TLC'] = TLC
                 obj_dict[frame_num][playerId]['TTB'] = TTB
                 obj_dict[frame_num][playerId]['TM'] = TM
+                obj_dict[frame_num][playerId]['DTC'] = DTC
                 obj_dict[frame_num][playerId]['LonSD'] = LonSD
                 obj_dict[frame_num][playerId]['LatSD'] = LatSD
                 obj_dict[frame_num][playerId]['BTN'] = abs(BTN)
@@ -418,7 +481,7 @@ class SafetyCalculator:
 
         df_safe = pd.concat(df_list)
         col_list = ['simTime', 'simFrame', 'playerId',
-                    'TTC', 'MTTC', 'THW', 'TLC', 'TTB', 'TM', 'LonSD', 'LatSD', 'BTN',
+                    'TTC', 'MTTC', 'THW', 'TLC', 'TTB', 'TM', 'DTC', 'LonSD', 'LatSD', 'BTN',
                     'collisionSeverity', 'pr_death', 'collisionRisk']
         self.df_safe = df_safe[col_list].reset_index(drop=True)
 
@@ -597,6 +660,20 @@ class SafetyCalculator:
         TM = (x_relative_start_dist0 + v2**2/(2*a2) - v1**2/(2*a1)) / v1
         return TM
 
+    # def _cal_MPrTTC(self, T=5, c = False, collision_dist = 5.99):
+    #     time_interval = self.ego_df['simTime'].tolist()[1] - self.ego_df['simTime'].tolist()[0]
+    #
+    #     for i in range(len(self.obj_id_list)):
+    #         for j in range(T):
+    #             MPrTTC = j * time_interval
+
+    def _cal_DTC(self, v_on_dist, a_on_dist, t):
+        if a_on_dist == 0:
+            return None
+        DTC = v_on_dist * t + v_on_dist ** 2 / a_on_dist
+        return DTC
+
+
     def velocity(self, v_x, v_y):
         v = math.sqrt(v_x ** 2 + v_y ** 2) * 3.6
         return v
@@ -700,7 +777,7 @@ class SafetyCalculator:
     
     def _safe_statistic_most_dangerous(self):
         min_list = ['TTC', 'MTTC', 'THW', 'TLC', 'TTB', 'LonSD', 'LatSD', 'TM']
-        max_list = ['BTN', 'collisionRisk', 'collisionSeverity']
+        max_list = ['DTC', 'BTN', 'collisionRisk', 'collisionSeverity']
         result = {}
         for metric in min_list:
             if metric in self.metric_list:
@@ -732,6 +809,7 @@ class SafetyCalculator:
             'TLC': 10.0,
             'TTB': 10.0,
             'TM': 10.0,
+            'DTC': 10.0,
             'LonSD': 10.0,
             'LatSD': 2.0,
             'BTN': 1.0,
@@ -786,6 +864,12 @@ class SafetyCalculator:
         tm_values = self.df_safe['TM'].dropna()
         return float(tm_values.min()) if not tm_values.empty else self._default_value('TM')
 
+    def get_dtc_value(self) -> float:
+        if self.empty_flag or self.df_safe is None:
+            return self._default_value('DTC')
+        dtc_values = self.df_safe['DTC'].dropna()
+        return float(dtc_values.min()) if not dtc_values.empty else self._default_value('DTC')
+
     def get_lonsd_value(self) -> float:
         if self.empty_flag or self.df_safe is None:
             return self._default_value('LonSD')