Эх сурвалжийг харах

更新function.py,满足预警延迟指标的统计

XGJ_zhaoyuan 1 сар өмнө
parent
commit
3bccbe6fe3
1 өөрчлөгдсөн 71 нэмэгдсэн , 49 устгасан
  1. 71 49
      modules/metric/function.py

+ 71 - 49
modules/metric/function.py

@@ -14,6 +14,7 @@
 
 import sys
 from pathlib import Path
+
 # 添加项目根目录到系统路径
 root_path = Path(__file__).resolve().parent.parent
 sys.path.append(str(root_path))
@@ -25,24 +26,26 @@ from typing import Dict, Tuple, Optional, Callable, Any
 import pandas as pd
 import yaml
 
-
 # ----------------------
 # 基础工具函数 (Pure functions)
 # ----------------------
-scenario_sign_dict = {"LeftTurnAssist": 206, "HazardousLocationW": 207, "RedLightViolationW": 208, "CoorperativeIntersectionPassing": 225, "GreenLightOptimalSpeedAdvisory": 234,
+scenario_sign_dict = {"LeftTurnAssist": 206, "HazardousLocationW": 207, "RedLightViolationW": 208,
+                      "CoorperativeIntersectionPassing": 225, "GreenLightOptimalSpeedAdvisory": 234,
                       "ForwardCollision": 212}
 
+
 def calculate_distance_PGVIL(ego_pos: np.ndarray, obj_pos: np.ndarray) -> np.ndarray:
     """向量化距离计算"""
     return np.linalg.norm(ego_pos - obj_pos, axis=1)
 
 
 def calculate_relative_speed_PGVIL(
-    ego_speed: np.ndarray, obj_speed: np.ndarray
+        ego_speed: np.ndarray, obj_speed: np.ndarray
 ) -> np.ndarray:
     """向量化相对速度计算"""
     return np.linalg.norm(ego_speed - obj_speed, axis=1)
 
+
 def calculate_distance(ego_df: pd.DataFrame, correctwarning: int) -> np.ndarray:
     """向量化距离计算"""
     dist = ego_df[(ego_df['ifwarning'] == correctwarning) & (ego_df['ifwarning'].notna())]['relative_dist']
@@ -54,14 +57,13 @@ def calculate_relative_speed(ego_df: pd.DataFrame, correctwarning: int) -> np.nd
     return ego_df[(ego_df['ifwarning'] == correctwarning) & (ego_df['ifwarning'].notna())]['composite_v']
 
 
-
 def extract_ego_obj(data: pd.DataFrame) -> Tuple[pd.Series, pd.DataFrame]:
     """数据提取函数"""
     ego = data[data["playerId"] == 1].iloc[0]
     obj = data[data["playerId"] != 1]
     return ego, obj
 
-    
+
 def get_first_warning(data_processed) -> Optional[pd.DataFrame]:
     """带缓存的预警数据获取"""
     ego_df = data_processed.ego_data
@@ -74,7 +76,7 @@ def get_first_warning(data_processed) -> Optional[pd.DataFrame]:
         print("无法获取正确的预警信号标志位!")
         return None
     warning_rows = ego_df[(ego_df['ifwarning'] == correctwarning) & (ego_df['ifwarning'].notna())]
-    
+
     warning_times = warning_rows['simTime']
     if warning_times.empty:
         print("没有找到预警数据!")
@@ -108,7 +110,6 @@ def earliestWarningDistance_LST(data) -> dict:
     if warning_dist.empty:
         return {"earliestWarningDistance_LST": 0.0}
 
-
     return {"earliestWarningDistance_LST": float(warning_dist.iloc[0]) if len(warning_dist) > 0 else np.inf}
 
 
@@ -145,31 +146,29 @@ def earliestWarningDistance_TTC_LST(data) -> dict:
 
     return {"earliestWarningDistance_TTC_LST": float(ttc[0]) if len(ttc) > 0 else np.inf}
 
+
 def warningDelayTime_LST(data):
     scenario_name = data.function_config["function"]["scenario"]["name"]
     correctwarning = scenario_sign_dict[scenario_name]
     ego_df = data.ego_data
     HMI_warning_rows = ego_df[(ego_df['ifwarning'] == correctwarning)]['simTime'].tolist()
     simTime_HMI = HMI_warning_rows[0] if len(HMI_warning_rows) > 0 else None
-    rosbag_warning_rows = ego_df[(ego_df['event_info.eventSource'].notna())]['simTime'].tolist()
+    rosbag_warning_rows = ego_df[(ego_df['event_Type'].notna()) & ((ego_df['event_Type'] != np.nan))][
+        'simTime'].tolist()
     simTime_rosbag = rosbag_warning_rows[0] if len(rosbag_warning_rows) > 0 else None
-    if (simTime_HMI is None) and (simTime_rosbag is None):
-        print("没有发出预警!")
+    if (simTime_HMI is None) or (simTime_rosbag is None):
+        print("预警出错!")
         delay_time = 100.0
-    elif (simTime_HMI is not None) and (simTime_rosbag is not None):
-        delay_time = abs(simTime_HMI - simTime_rosbag)
-    elif (simTime_HMI is not None) and (simTime_rosbag is None):
-        print("没有发出预警!")
-        delay_time = None
     else:
-        delay_time = simTime_rosbag
+        delay_time = abs(simTime_HMI - simTime_rosbag)
     return {"warningDelayTime_LST": delay_time}
 
+
 def warningDelayTimeOf4_LST(data):
     scenario_name = data.function_config["function"]["scenario"]["name"]
     correctwarning = scenario_sign_dict[scenario_name]
     ego_df = data.ego_data
-    ego_speed_simtime = ego_df[ego_df['accel'] <= -4]['simTime'].tolist() # 单位m/s^2
+    ego_speed_simtime = ego_df[ego_df['accel'] <= -4]['simTime'].tolist()  # 单位m/s^2
     warning_simTime = ego_df[ego_df['ifwarning'] == correctwarning]['simTime'].tolist()
     if (len(warning_simTime) == 0) and (len(ego_speed_simtime) == 0):
         return {"warningDelayTimeOf4_LST": 0}
@@ -180,11 +179,12 @@ def warningDelayTimeOf4_LST(data):
     else:
         return {"warningDelayTimeOf4_LST": warning_simTime[0] - ego_speed_simtime[0]}
 
+
 def rightWarningSignal_LST(data):
     scenario_name = data.function_config["function"]["scenario"]["name"]
     correctwarning = scenario_sign_dict[scenario_name]
     ego_df = data.ego_data
-    if correctwarning.empty:
+    if ego_df['ifwarning'].empty:
         print("无法获取正确预警信号标志位!")
         return
     warning_rows = ego_df[(ego_df['ifwarning'] == correctwarning) & (ego_df['ifwarning'].notna())]
@@ -193,26 +193,32 @@ def rightWarningSignal_LST(data):
     else:
         return {"rightWarningSignal_LST": 1}
 
+
 def ifCrossingRedLight_LST(data):
     scenario_name = data.function_config["function"]["scenario"]["name"]
     correctwarning = scenario_sign_dict[scenario_name]
     ego_df = data.ego_data
-    redlight_simtime = ego_df[(ego_df['ifwarning'] == correctwarning) & (ego_df['stateMask'] == 1) & (ego_df['relative_dist'] == 0) & (ego_df['v'] != 0)]['simTime']
+    redlight_simtime = ego_df[
+        (ego_df['ifwarning'] == correctwarning) & (ego_df['stateMask'] == 1) & (ego_df['relative_dist'] == 0) & (
+                    ego_df['v'] != 0)]['simTime']
     if redlight_simtime.empty:
-        return {"ifCrossingRedLight_LST": 0}
+        return {"ifCrossingRedLight_LST": -1}
     else:
         return {"ifCrossingRedLight_LST": 1}
 
+
 def ifStopgreenWaveSpeedGuidance_LST(data):
     scenario_name = data.function_config["function"]["scenario"]["name"]
     correctwarning = scenario_sign_dict[scenario_name]
     ego_df = data.ego_data
-    greenlight_simtime = ego_df[(ego_df['ifwarning'] == correctwarning) & (ego_df['stateMask'] == 0) & (ego_df['v'] == 0)]['simTime']
+    greenlight_simtime = \
+    ego_df[(ego_df['ifwarning'] == correctwarning) & (ego_df['stateMask'] == 0) & (ego_df['v'] == 0)]['simTime']
     if greenlight_simtime.empty:
-        return {"ifStopgreenWaveSpeedGuidance_LST": 0}
+        return {"ifStopgreenWaveSpeedGuidance_LST": -1}
     else:
         return {"ifStopgreenWaveSpeedGuidance_LST": 1}
 
+
 def rightWarningSignal_PGVIL(data_processed) -> dict:
     """判断是否发出正确预警信号"""
 
@@ -224,7 +230,9 @@ def rightWarningSignal_PGVIL(data_processed) -> dict:
         print("无法获取正确的预警信号标志位!")
         return None
     # 找出本行 correctwarning 和 ifwarning 相等,且 correctwarning 不是 NaN 的行
-    warning_rows = ego_df[(ego_df['ifwarning'] == correctwarning) & (ego_df['ifwarning'].notna())]
+    warning_rows = ego_df[
+        (ego_df["ifwarning"] == correctwarning) & (ego_df["ifwarning"].notna())
+        ]
 
     if warning_rows.empty:
         return {"rightWarningSignal_PGVIL": -1}
@@ -278,6 +286,7 @@ def latestWarningDistance_TTC_PGVIL(data_processed) -> dict:
 
     return {"latestWarningDistance_TTC_PGVIL": float(np.nanmin(ttc))}
 
+
 def earliestWarningDistance_PGVIL(data_processed) -> dict:
     """预警距离计算流水线"""
     ego_df = data_processed.ego_data
@@ -297,6 +306,7 @@ def earliestWarningDistance_PGVIL(data_processed) -> dict:
 
     return {"earliestWarningDistance": float(np.min(distances))}
 
+
 def earliestWarningDistance_TTC_PGVIL(data_processed) -> dict:
     """TTC计算流水线"""
     ego_df = data_processed.ego_data
@@ -325,7 +335,7 @@ def earliestWarningDistance_TTC_PGVIL(data_processed) -> dict:
 
     return {"earliestWarningDistance_TTC_PGVIL": float(np.nanmin(ttc))}
 
-    
+
 # def delayOfEmergencyBrakeWarning(data_processed) -> dict:
 #     #预警时机相对背景车辆减速度达到-4m/s2后的时延
 #     ego_df = data_processed.ego_data
@@ -353,24 +363,30 @@ def earliestWarningDistance_TTC_PGVIL(data_processed) -> dict:
 #             return {"delayOfEmergencyBrakeWarning": float(delay_time)}
 
 #         else:
-#             print("没有达到预警减速度阈值:-4m/s^2")            
+#             print("没有达到预警减速度阈值:-4m/s^2")
 #             return {"delayOfEmergencyBrakeWarning": -1}
 
+
 def warningDelayTime_PGVIL(data_processed) -> dict:
-    """车端接收到预警到HMI发出预警的时延"""    
+    """车端接收到预警到HMI发出预警的时延"""
     ego_df = data_processed.ego_data
+    # #打印ego_df的列名
+    # print(ego_df.columns.tolist())
+
     warning_data = get_first_warning(data_processed)
 
     if warning_data is None:
         return {"warningDelayTime_PGVIL": -1}
     try:
         ego, obj = extract_ego_obj(warning_data)
-        rosbag_warning_rows = ego_df[(ego_df['event_Type'].notna())]
-        
+
+        # 找到event_Type不为空,且playerID为1的行
+        rosbag_warning_rows = ego_df[(ego_df["event_Type"].notna())]
+
         first_time = rosbag_warning_rows["simTime"].iloc[0]
         warning_time = warning_data[warning_data["playerId"] == 1]["simTime"].iloc[0]
         delay_time = warning_time - first_time
-        
+
         return {"warningDelayTime_PGVIL": float(delay_time)}
 
     except Exception as e:
@@ -383,7 +399,7 @@ def get_car_to_stop_line_distance(ego, car_point, stop_line_points):
     计算主车后轴中心点到停止线的距离
     :return 距离
     """
-    distance_carpoint_carhead = ego['dimX']/2 + ego['offX']
+    distance_carpoint_carhead = ego["dimX"] / 2 + ego["offX"]
     # 计算停止线的方向向量
     line_vector = np.array(
         [
@@ -392,27 +408,32 @@ def get_car_to_stop_line_distance(ego, car_point, stop_line_points):
         ]
     )
     direction_vector_norm = np.linalg.norm(line_vector)
-    direction_vector_unit = line_vector / direction_vector_norm if direction_vector_norm != 0 else np.array([0, 0])
+    direction_vector_unit = (
+        line_vector / direction_vector_norm
+        if direction_vector_norm != 0
+        else np.array([0, 0])
+    )
     # 计算主车后轴中心点到停止线投影的坐标(垂足)
     projection_length = np.dot(car_point - stop_line_points[0], direction_vector_unit)
     perpendicular_foot = stop_line_points[0] + projection_length * direction_vector_unit
-    
+
     # 计算主车后轴中心点到垂足的距离
     distance_to_foot = np.linalg.norm(car_point - perpendicular_foot)
     carhead_distance_to_foot = distance_to_foot - distance_carpoint_carhead
-    
+
     return carhead_distance_to_foot
 
+
 def ifCrossingRedLight_PGVIL(data_processed) -> dict:
-    #判断车辆是否闯红灯
+    # 判断车辆是否闯红灯
 
-    stop_line_points = np.array([(276.555,-35.575),(279.751,-33.683)])
+    stop_line_points = np.array([(276.555, -35.575), (279.751, -33.683)])
     X_OFFSET = 258109.4239876
     Y_OFFSET = 4149969.964821
     stop_line_points += np.array([[X_OFFSET, Y_OFFSET]])
     ego_df = data_processed.ego_data
-    
-    prev_distance = float('inf')  # 初始化为正无穷大
+
+    prev_distance = float("inf")  # 初始化为正无穷大
     """
     traffic_light_status
     0x100000为绿灯,1048576
@@ -420,23 +441,25 @@ def ifCrossingRedLight_PGVIL(data_processed) -> dict:
     0x10000000为红灯,268435456
     """
     red_light_violation = False
-    for index ,ego in ego_df.iterrows():
+    for index, ego in ego_df.iterrows():
         car_point = (ego["posX"], ego["posY"])
         stateMask = ego["stateMask"]
         simTime = ego["simTime"]
-        distance_to_stopline = get_car_to_stop_line_distance(ego, car_point, stop_line_points)
+        distance_to_stopline = get_car_to_stop_line_distance(
+            ego, car_point, stop_line_points
+        )
 
-        #主车车头跨越停止线时非绿灯,返回-1,闯红灯
+        # 主车车头跨越停止线时非绿灯,返回-1,闯红灯
         if prev_distance > 0 and distance_to_stopline < 0:
-            if stateMask != 1048576:
+            if stateMask is not None and stateMask != 1048576:
                 red_light_violation = True
             break
         prev_distance = distance_to_stopline
 
     if red_light_violation:
-        return {"ifCrossingRedLight_PGVIL": -1}#闯红灯
+        return {"ifCrossingRedLight_PGVIL": -1}  # 闯红灯
     else:
-        return {"ifCrossingRedLight_PGVIL": 1}#没有闯红灯
+        return {"ifCrossingRedLight_PGVIL": 1}  # 没有闯红灯
 
 
 # def ifStopgreenWaveSpeedGuidance(data_processed) -> dict:
@@ -459,7 +482,7 @@ def ifCrossingRedLight_PGVIL(data_processed) -> dict:
 #     ]["simTime"]
 #     if stop_giveway_simtime.empty:
 #         print("没有停车让行标志/标线")
-        
+
 #     ego_data = stop_giveway_data[stop_giveway_data['playerId'] == 1]
 #     distance_carpoint_carhead = ego_data['dimX'].iloc[0]/2 + ego_data['offX'].iloc[0]
 #     distance_to_stoplines = []
@@ -471,11 +494,10 @@ def ifCrossingRedLight_PGVIL(data_processed) -> dict:
 #         ]
 #         distance_to_stopline = get_car_to_stop_line_distance(ego_pos, stop_line_points)
 #         distance_to_stoplines.append(distance_to_stopline)
-    
+
 #     mindisStopline = np.min(distance_to_stoplines) - distance_carpoint_carhead
 #     return {"mindisStopline": mindisStopline}
-    
-   
+
 
 class FunctionRegistry:
     """动态函数注册器(支持参数验证)"""
@@ -495,7 +517,7 @@ class FunctionRegistry:
         def _recurse(node):
             if isinstance(node, dict):
                 if "name" in node and not any(
-                    isinstance(v, dict) for v in node.values()
+                        isinstance(v, dict) for v in node.values()
                 ):
                     metrics.append(node["name"])
                 for v in node.values():
@@ -555,4 +577,4 @@ class FunctionManager:
 # 使用示例
 if __name__ == "__main__":
     pass
-    # print("\n[功能类表现及得分情况]")
+    # print("\n[功能类表现及得分情况]")