Преглед изворни кода

增加实车功能性指标画图功能

XGJ_zhaoyuan пре 14 часа
родитељ
комит
c2bd6a4b35

+ 34 - 34
config/all_metrics_config.yaml

@@ -139,39 +139,39 @@ comfort:
       priority: 0
       max: 0
       min: 0
-  # comforDynamic:
-  #   name: comforDynamic
-  #   priority: 0
-  #   rideQualityScore:
-  #     name: rideQualityScore
-  #     priority: 0
-  #     max: 100
-  #     min: 80
-  #   motionSickness:
-  #     name: motionSickness
-  #     priority: 0
-  #     max: 30.0
-  #     min: 0.0
-  #   motionComfortIndex:
-  #     name: motionComfortIndex
-  #     priority: 0
-  #     max: 10.0
-  #     min: 8.0
-  #   vdv:
-  #     name: vdv
-  #     priority: 0
-  #     max: 8.0
-  #     min: 0
-  #   ava_vav:
-  #     name: ava_vav
-  #     priority: 0
-  #     max: 0.63
-  #     min: 0.0
-  #   msdv:
-  #     name: msdv
-  #     priority: 0
-  #     max: 6.0
-  #     min: 0.0
+  comforDynamic:
+    name: comforDynamic
+    priority: 0
+    rideQualityScore:
+      name: rideQualityScore
+      priority: 0
+      max: 0
+      min: 0
+    motionSickness:
+      name: motionSickness
+      priority: 0
+      max: 0.0
+      min: 0.0
+    motionComfortIndex:
+      name: motionComfortIndex
+      priority: 0
+      max: 0.0
+      min: 0.0
+    vdv:
+      name: vdv
+      priority: 0
+      max: 0
+      min: 0
+    ava_vav:
+      name: ava_vav
+      priority: 0
+      max: 0
+      min: 0.0
+    msdv:
+      name: msdv
+      priority: 0
+      max: 0.0
+      min: 0.0
 
 efficient:
   name: efficient
@@ -216,7 +216,7 @@ efficient:
 function:
   name: function
   priority: 0
-  scenario:
+  ForwardCollision:
     name: ForwardCollision
     priority: 0
     latestWarningDistance_TTC_LST:

+ 2 - 2
config/custom_metrics_config.yaml

@@ -26,8 +26,8 @@ comfort:
   comfortLat:
     name: comfortLat
     priority: 0
-    Weaving:
-      name: Weaving
+    zigzag:
+      name: zigzag
       priority: 0
       max: 0
       min: 0

Разлика између датотеке није приказан због своје велике величине
+ 875 - 85
modules/lib/chart_generator.py


+ 84 - 55
modules/lib/data_process.py

@@ -19,10 +19,9 @@ import pandas as pd
 
 import yaml
 
-
-
 from modules.lib.log_manager import LogManager
 
+
 # from lib import log  # 确保这个路径是正确的,或者调整它
 # logger = None  # 初始化为 None
 
@@ -31,7 +30,7 @@ class DataPreprocessing:
     def __init__(self, data_path, config_path):
         # Initialize paths and data containers
         # self.logger = log.get_logger()
-        
+
         self.data_path = data_path
         self.case_name = os.path.basename(os.path.normpath(data_path))
 
@@ -79,30 +78,30 @@ class DataPreprocessing:
             full_config = yaml.safe_load(f)
 
         modules = ["vehicle", "T_threshold", "safety", "comfort", "efficient", "function", "traffic"]
-        
+
         # 1. 初始化 vehicle_config(不涉及 T_threshold 合并)
         self.vehicle_config = full_config[modules[0]]
-        
+
         # 2. 定义 T_threshold_config(封装为字典)
         T_threshold_config = {"T_threshold": full_config[modules[1]]}
-        
+
         # 3. 统一处理需要合并 T_threshold 的模块
         # 3.1 safety_config
         self.safety_config = {"safety": full_config[modules[2]]}
         self.safety_config.update(T_threshold_config)
-        
+
         # 3.2 comfort_config
         self.comfort_config = {"comfort": full_config[modules[3]]}
         self.comfort_config.update(T_threshold_config)
-        
+
         # 3.3 efficient_config
         self.efficient_config = {"efficient": full_config[modules[4]]}
         self.efficient_config.update(T_threshold_config)
-        
+
         # 3.4 function_config
         self.function_config = {"function": full_config[modules[5]]}
         self.function_config.update(T_threshold_config)
-        
+
         # 3.5 traffic_config
         self.traffic_config = {"traffic": full_config[modules[6]]}
         self.traffic_config.update(T_threshold_config)
@@ -110,26 +109,26 @@ class DataPreprocessing:
     @staticmethod
     def cal_velocity(lat_v, lon_v):
         """Calculate resultant velocity from lateral and longitudinal components."""
-        return np.sqrt(lat_v**2 + lon_v**2)
+        return np.sqrt(lat_v ** 2 + lon_v ** 2)
 
     def _real_process_object_df(self):
         """Process the object DataFrame."""
         try:
             # 读取 CSV 文件
             merged_csv_path = os.path.join(self.data_path, "merged_ObjState.csv")
-            
+
             # 检查文件是否存在
             if not os.path.exists(merged_csv_path):
                 logger = LogManager().get_logger()
                 logger.error(f"文件不存在: {merged_csv_path}")
                 raise FileNotFoundError(f"文件不存在: {merged_csv_path}")
-                
+
             self.object_df = pd.read_csv(
                 merged_csv_path,
                 dtype={"simTime": float},
                 engine="python",
                 on_bad_lines="skip",  # 自动跳过异常行
-                na_values=["","NA","null","NaN"]  # 明确处理缺失值
+                na_values=["", "NA", "null", "NaN"]  # 明确处理缺失值
             ).drop_duplicates(subset=["simTime", "simFrame", "playerId"])
             self.object_df.columns = [col.replace("+AF8-", "_") for col in self.object_df.columns]
 
@@ -139,14 +138,13 @@ class DataPreprocessing:
             data["lat_v"] = data["speedY"] * 1
             data["lon_v"] = data["speedX"] * 1
             # 使用向量化操作代替 apply
-            data["v"] = np.sqrt(data["lat_v"]**2 + data["lon_v"]**2)
+            data["v"] = np.sqrt(data["lat_v"] ** 2 + data["lon_v"] ** 2)
 
             # 计算加速度分量
             data["lat_acc"] = data["accelY"] * 1
             data["lon_acc"] = data["accelX"] * 1
             # 使用向量化操作代替 apply
-            data["accel"] = np.sqrt(data["lat_acc"]**2 + data["lon_acc"]**2)
-
+            data["accel"] = np.sqrt(data["lat_acc"] ** 2 + data["lon_acc"] ** 2)
 
             # Drop rows with missing 'type' and reset index
             data = data.dropna(subset=["type"])
@@ -173,12 +171,12 @@ class DataPreprocessing:
     def _calculate_object_parameters(self, obj_data):
         """Calculate additional parameters for a single object."""
         obj_data = obj_data.copy()
-        
+
         # 确保数据按时间排序
         obj_data = obj_data.sort_values(by="simTime").reset_index(drop=True)
-        
+
         obj_data["time_diff"] = obj_data["simTime"].diff()
-        
+
         # 处理可能的零时间差
         zero_time_diff = obj_data["time_diff"] == 0
         if zero_time_diff.any():
@@ -193,13 +191,13 @@ class DataPreprocessing:
         obj_data["yawrate_diff"] = obj_data["speedH"].diff()
 
         obj_data["lat_acc_roc"] = (
-            obj_data["lat_acc_diff"] / obj_data["time_diff"]
+                obj_data["lat_acc_diff"] / obj_data["time_diff"]
         ).replace([np.inf, -np.inf], [9999, -9999])
         obj_data["lon_acc_roc"] = (
-            obj_data["lon_acc_diff"] / obj_data["time_diff"]
+                obj_data["lon_acc_diff"] / obj_data["time_diff"]
         ).replace([np.inf, -np.inf], [9999, -9999])
         obj_data["accelH"] = (
-            obj_data["yawrate_diff"] / obj_data["time_diff"]
+                obj_data["yawrate_diff"] / obj_data["time_diff"]
         ).replace([np.inf, -np.inf], [9999, -9999])
 
         return obj_data
@@ -241,7 +239,7 @@ class DataPreprocessing:
         """Calculate mileage based on the driving data."""
         if len(df) < 2:
             return 0.0  # 数据不足,无法计算里程
-            
+
         if df["travelDist"].nunique() == 1:
             # 创建临时DataFrame进行计算,避免修改原始数据
             temp_df = df.copy()
@@ -262,38 +260,69 @@ class DataPreprocessing:
         return df["simTime"].iloc[-1] - df["simTime"].iloc[0]
 
     def process_ego_data(self, ego_data):
-        """处理自车数据,包括坐标系转换等"""
+        """处理自车数据:将东北天(ENU)坐标系下的速度/加速度转换为车辆坐标系(考虑yaw, pitch, roll)"""
+        '''
+        原字段	新字段名	描述
+        a_x_body	lon_acc_vehicle	车辆坐标系下的纵向加速度
+        a_y_body	lat_acc_vehicle	车辆坐标系下的横向加速度
+        a_z_body	acc_z_vehicle	车辆坐标系下的垂向加速度
+        v_x_body	lon_v_vehicle	车辆坐标系下的纵向速度
+        v_y_body	lat_v_vehicle	车辆坐标系下的横向速度
+        v_z_body	vel_z_vehicle	车辆坐标系下的垂向速度
+        posH	heading_rad	航向角(弧度)
+        pitch_rad	pitch_rad	俯仰角(弧度)
+        roll_rad	roll_rad	横滚角(弧度)
+        '''
+        logger = LogManager().get_logger()
+
         if ego_data is None or len(ego_data) == 0:
-            logger = LogManager().get_logger()
             logger.warning("自车数据为空,无法进行坐标系转换")
             return ego_data
-            
-        # 创建副本避免修改原始数据
+
         ego_data = ego_data.copy()
-        
-        # 添加坐标系转换:将东北天坐标系下的加速度和速度转换为车辆坐标系下的值
-        # 使用车辆航向角进行转换
-        # 注意:与safety.py保持一致,使用(90 - heading)作为与x轴的夹角
-        ego_data['heading_rad'] = np.deg2rad(90 - ego_data['posH'])  # 转换为与x轴的夹角
-        
-        # 使用向量化操作计算车辆坐标系下的纵向和横向加速度
-        ego_data['lon_acc_vehicle'] = ego_data['accelX'] * np.cos(ego_data['heading_rad']) + \
-                                     ego_data['accelY'] * np.sin(ego_data['heading_rad'])
-        ego_data['lat_acc_vehicle'] = -ego_data['accelX'] * np.sin(ego_data['heading_rad']) + \
-                                     ego_data['accelY'] * np.cos(ego_data['heading_rad'])
-         
-        # 使用向量化操作计算车辆坐标系下的纵向和横向速度
-        ego_data['lon_v_vehicle'] = ego_data['speedX'] * np.cos(ego_data['heading_rad']) + \
-                                   ego_data['speedY'] * np.sin(ego_data['heading_rad'])
-        ego_data['lat_v_vehicle'] = -ego_data['speedX'] * np.sin(ego_data['heading_rad']) + \
-                                   ego_data['speedY'] * np.cos(ego_data['heading_rad'])
-        
-        # 将原始的东北天坐标系加速度和速度保留,但在其他模块中可以直接使用车辆坐标系的值
-        ego_data['lon_acc'] = ego_data['lon_acc_vehicle']
-        ego_data['lat_acc'] = ego_data['lat_acc_vehicle']
-        
-        # 记录日志
-        logger = LogManager().get_logger()
-        logger.info("已将加速度和速度转换为车辆坐标系")
-        
-        return ego_data
+        for col in ['speedZ', 'accelZ']:
+            if col not in ego_data.columns:
+                ego_data[col] = 0.0
+                logger.warning(f"自车数据中缺少列 '{col}',已将其填充为 0.0")
+
+        # 角度转为弧度(修正 posH 表示正北为 0° => 车辆朝正东为 0°)
+        ego_data['yaw_rad'] = np.deg2rad(90 - ego_data['posH'])
+        ego_data['pitch_rad'] = np.deg2rad(ego_data.get('pitch', 0))
+        ego_data['roll_rad'] = np.deg2rad(ego_data.get('roll', 0))
+
+        # 预计算三角函数(向量化)
+        cy = np.cos(ego_data['yaw_rad'])
+        sy = np.sin(ego_data['yaw_rad'])
+        cp = np.cos(ego_data['pitch_rad'])
+        sp = np.sin(ego_data['pitch_rad'])
+        cr = np.cos(ego_data['roll_rad'])
+        sr = np.sin(ego_data['roll_rad'])
+
+        # === 加速度(ENU → 车辆坐标系) ===
+        ego_data['lon_acc_vehicle'] = (ego_data['accelX'] * (cy * cp) +
+                                       ego_data['accelY'] * (cy * sp * sr - sy * cr) +
+                                       ego_data['accelZ'] * (cy * sp * cr + sy * sr))
+
+        ego_data['lat_acc_vehicle'] = (ego_data['accelX'] * (sy * cp) +
+                                       ego_data['accelY'] * (sy * sp * sr + cy * cr) +
+                                       ego_data['accelZ'] * (sy * sp * cr - cy * sr))
+
+        ego_data['acc_z_vehicle'] = (ego_data['accelX'] * (-sp) +
+                                     ego_data['accelY'] * (cp * sr) +
+                                     ego_data['accelZ'] * (cp * cr))
+
+        # === 速度(ENU → 车辆坐标系) ===
+        ego_data['lon_v_vehicle'] = (ego_data['speedX'] * (cy * cp) +
+                                     ego_data['speedY'] * (cy * sp * sr - sy * cr) +
+                                     ego_data['speedZ'] * (cy * sp * cr + sy * sr))
+
+        ego_data['lat_v_vehicle'] = (ego_data['speedX'] * (sy * cp) +
+                                     ego_data['speedY'] * (sy * sp * sr + cy * cr) +
+                                     ego_data['speedZ'] * (sy * sp * cr - cy * sr))
+
+        ego_data['vel_z_vehicle'] = (ego_data['speedX'] * (-sp) +
+                                     ego_data['speedY'] * (cp * sr) +
+                                     ego_data['speedZ'] * (cp * cr))
+
+        logger.info("完成车辆坐标系转换(考虑yaw/pitch/roll)")
+        return ego_data

Разлика између датотеке није приказан због своје велике величине
+ 1264 - 579
modules/metric/comfort.py


+ 306 - 9
modules/metric/function.py

@@ -18,6 +18,7 @@ from pathlib import Path
 # 添加项目根目录到系统路径
 root_path = Path(__file__).resolve().parent.parent
 sys.path.append(str(root_path))
+print(root_path)
 
 from modules.lib.score import Score
 from modules.lib.log_manager import LogManager
@@ -25,6 +26,9 @@ import numpy as np
 from typing import Dict, Tuple, Optional, Callable, Any
 import pandas as pd
 import yaml
+from modules.lib.chart_generator import generate_function_chart_data
+from shapely.geometry import Point, Polygon
+from modules.lib.common import get_interpolation
 
 # ----------------------
 # 基础工具函数 (Pure functions)
@@ -33,6 +37,47 @@ scenario_sign_dict = {"LeftTurnAssist": 206, "HazardousLocationW": 207, "RedLigh
                       "CoorperativeIntersectionPassing": 225, "GreenLightOptimalSpeedAdvisory": 234,
                       "ForwardCollision": 212}
 
+
def _is_pedestrian_in_crosswalk(polygon, test_point) -> bool:
    """Return True when *test_point* lies strictly inside the crosswalk *polygon*.

    polygon: sequence of (x, y) vertices; test_point: (x, y) coordinate pair.
    """
    crosswalk_area = Polygon(polygon)
    pedestrian_point = Point(test_point)
    return crosswalk_area.contains(pedestrian_point)
+
+
+def _is_segment_by_interval(time_list, expected_step) -> list:
+    """
+    根据时间戳之间的间隔进行分段。
+
+    参数:
+    time_list (list): 时间戳列表。
+    expected_step (float): 预期的固定步长。
+
+    返回:
+    list: 分段后的时间戳列表,每个元素是一个子列表。
+    """
+    if not time_list:
+        return []
+
+    segments = []
+    current_segment = [time_list[0]]
+
+    for i in range(1, len(time_list)):
+        actual_step = time_list[i] - time_list[i - 1]
+        if actual_step != expected_step:
+            # 如果间隔不符合预期,则开始一个新的段
+            segments.append(current_segment)
+            current_segment = [time_list[i]]
+        else:
+            # 否则,将当前时间戳添加到当前段中
+            current_segment.append(time_list[i])
+
+    # 添加最后一个段
+    if current_segment:
+        segments.append(current_segment)
+
+    return segments
+
+
 # 寻找二级指标的名称
 def find_nested_name(data):
     """
@@ -56,6 +101,7 @@ def find_nested_name(data):
                 return result
     return None
 
+
 def calculate_distance_PGVIL(ego_pos: np.ndarray, obj_pos: np.ndarray) -> np.ndarray:
     """向量化距离计算"""
     return np.linalg.norm(ego_pos - obj_pos, axis=1)
@@ -118,9 +164,19 @@ def latestWarningDistance_LST(data) -> dict:
     correctwarning = scenario_sign_dict[scenario_name]
     ego_df = data.ego_data
     warning_dist = calculate_distance(ego_df, correctwarning)
+    warning_speed = calculate_relative_speed(ego_df, correctwarning)
+
+    # 将计算结果保存到data对象中,供图表生成使用
+    data.warning_dist = warning_dist
+    data.warning_speed = warning_speed
+    data.correctwarning = correctwarning
+
     if warning_dist.empty:
         return {"latestWarningDistance_LST": 0.0}
 
+    # 生成图表数据
+    generate_function_chart_data(data, 'latestWarningDistance_LST')
+
     return {"latestWarningDistance_LST": float(warning_dist.iloc[-1]) if len(warning_dist) > 0 else value}
 
 
@@ -131,9 +187,20 @@ def earliestWarningDistance_LST(data) -> dict:
     correctwarning = scenario_sign_dict[scenario_name]
     ego_df = data.ego_data
     warning_dist = calculate_distance(ego_df, correctwarning)
+    warning_speed = calculate_relative_speed(ego_df, correctwarning)
+
+    # 将计算结果保存到data对象中,供图表生成使用
+    data.warning_dist = warning_dist
+    data.warning_speed = warning_speed
+    data.correctwarning = correctwarning
+
     if warning_dist.empty:
         return {"earliestWarningDistance_LST": 0.0}
 
+    # 生成图表数据
+
+    generate_function_chart_data(data, 'earliestWarningDistance_LST')
+
     return {"earliestWarningDistance_LST": float(warning_dist.iloc[0]) if len(warning_dist) > 0 else value}
 
 
@@ -147,19 +214,25 @@ def latestWarningDistance_TTC_LST(data) -> dict:
     if warning_dist.empty:
         return {"latestWarningDistance_TTC_LST": 0.0}
 
+    # 将correctwarning保存到data对象中,供图表生成使用
+    data.correctwarning = correctwarning
+
     warning_speed = calculate_relative_speed(ego_df, correctwarning)
 
     with np.errstate(divide='ignore', invalid='ignore'):
         ttc = np.where(warning_speed != 0, warning_dist / warning_speed, np.inf)
-    
+
     # 处理无效的TTC值
     for i in range(len(ttc)):
         ttc[i] = float(value) if (not ttc[i] or ttc[i] < 0) else ttc[i]
-    
+
+    data.warning_dist = warning_dist
+    data.warning_speed = warning_speed
+    data.ttc = ttc
     # 生成图表数据
-    from modules.lib.chart_generator import generate_function_chart_data
+    # from modules.lib.chart_generator import generate_function_chart_data
     generate_function_chart_data(data, 'latestWarningDistance_TTC_LST')
-        
+
     return {"latestWarningDistance_TTC_LST": float(ttc[-1]) if len(ttc) > 0 else value}
 
 
@@ -173,21 +246,35 @@ def earliestWarningDistance_TTC_LST(data) -> dict:
     if warning_dist.empty:
         return {"earliestWarningDistance_TTC_LST": 0.0}
 
+    # 将correctwarning保存到data对象中,供图表生成使用
+    data.correctwarning = correctwarning
+
     warning_speed = calculate_relative_speed(ego_df, correctwarning)
 
     with np.errstate(divide='ignore', invalid='ignore'):
         ttc = np.where(warning_speed != 0, warning_dist / warning_speed, np.inf)
-    
+
     # 处理无效的TTC值
     for i in range(len(ttc)):
         ttc[i] = float(value) if (not ttc[i] or ttc[i] < 0) else ttc[i]
-        
+
+    # 将计算结果保存到data对象中,供图表生成使用
+    data.warning_dist = warning_dist
+    data.warning_speed = warning_speed
+    data.ttc = ttc
+    data.correctwarning = correctwarning
+
+    # 生成图表数据
+    generate_function_chart_data(data, 'earliestWarningDistance_TTC_LST')
+
     return {"earliestWarningDistance_TTC_LST": float(ttc[0]) if len(ttc) > 0 else value}
 
 
 def warningDelayTime_LST(data):
     scenario_name = find_nested_name(data.function_config["function"])
     correctwarning = scenario_sign_dict[scenario_name]
+    # 将correctwarning保存到data对象中,供图表生成使用
+    data.correctwarning = correctwarning
     ego_df = data.ego_data
     HMI_warning_rows = ego_df[(ego_df['ifwarning'] == correctwarning)]['simTime'].tolist()
     simTime_HMI = HMI_warning_rows[0] if len(HMI_warning_rows) > 0 else None
@@ -205,6 +292,8 @@ def warningDelayTime_LST(data):
 def warningDelayTimeofReachDecel_LST(data):
     scenario_name = find_nested_name(data.function_config["function"])
     correctwarning = scenario_sign_dict[scenario_name]
+    # 将correctwarning保存到data对象中,供图表生成使用
+    data.correctwarning = correctwarning
     ego_df = data.ego_data
     ego_speed_simtime = ego_df[ego_df['accel'] <= -4]['simTime'].tolist()  # 单位m/s^2
     warning_simTime = ego_df[ego_df['ifwarning'] == correctwarning]['simTime'].tolist()
@@ -221,6 +310,8 @@ def warningDelayTimeofReachDecel_LST(data):
 def rightWarningSignal_LST(data):
     scenario_name = find_nested_name(data.function_config["function"])
     correctwarning = scenario_sign_dict[scenario_name]
+    # 将correctwarning保存到data对象中,供图表生成使用
+    data.correctwarning = correctwarning
     ego_df = data.ego_data
     if ego_df['ifwarning'].empty:
         print("无法获取正确预警信号标志位!")
@@ -235,10 +326,12 @@ def rightWarningSignal_LST(data):
 def ifCrossingRedLight_LST(data):
     scenario_name = find_nested_name(data.function_config["function"])
     correctwarning = scenario_sign_dict[scenario_name]
+    # 将correctwarning保存到data对象中,供图表生成使用
+    data.correctwarning = correctwarning
     ego_df = data.ego_data
     redlight_simtime = ego_df[
         (ego_df['ifwarning'] == correctwarning) & (ego_df['stateMask'] == 1) & (ego_df['relative_dist'] == 0) & (
-                    ego_df['v'] != 0)]['simTime']
+                ego_df['v'] != 0)]['simTime']
     if redlight_simtime.empty:
         return {"ifCrossingRedLight_LST": -1}
     else:
@@ -248,15 +341,220 @@ def ifCrossingRedLight_LST(data):
 def ifStopgreenWaveSpeedGuidance_LST(data):
     scenario_name = find_nested_name(data.function_config["function"])
     correctwarning = scenario_sign_dict[scenario_name]
+    # 将correctwarning保存到data对象中,供图表生成使用
+    data.correctwarning = correctwarning
     ego_df = data.ego_data
     greenlight_simtime = \
-    ego_df[(ego_df['ifwarning'] == correctwarning) & (ego_df['stateMask'] == 0) & (ego_df['v'] == 0)]['simTime']
+        ego_df[(ego_df['ifwarning'] == correctwarning) & (ego_df['stateMask'] == 0) & (ego_df['v'] == 0)]['simTime']
     if greenlight_simtime.empty:
         return {"ifStopgreenWaveSpeedGuidance_LST": -1}
     else:
         return {"ifStopgreenWaveSpeedGuidance_LST": 1}
 
 
+# ------ 单车智能指标 ------
def limitSpeed_LST(data):
    """Max ego speed at the speed-limit sign position (x_relative_dist == 0); -1 if never sampled there."""
    ego_df = data.ego_data
    speeds_at_sign = ego_df.loc[ego_df['x_relative_dist'] == 0, 'v'].tolist()
    if not speeds_at_sign:
        return {"speedLimit_LST": -1}
    # NOTE(review): chart key is lowercase 'limitspeed_LST' while sibling
    # metrics use camelCase — confirm chart_generator expects this spelling.
    generate_function_chart_data(data, 'limitspeed_LST')
    return {"speedLimit_LST": max(speeds_at_sign)}
+
+
def limitSpeedPastLimitSign_LST(data):
    """Ego speed sampled 100 m plus one car length past the limit sign; -1 if that sample is absent."""
    ego_df = data.ego_data
    car_length = data.function_config["vehicle"]['CAR_LENGTH']
    target_offset = -100 - car_length
    speeds = ego_df.loc[ego_df['x_relative_dist'] == target_offset, 'v'].tolist()
    if not speeds:
        return {"speedPastLimitSign_LST": -1}
    generate_function_chart_data(data, 'limitSpeedPastLimitSign_LST')
    return {"speedPastLimitSign_LST": speeds[0]}
+
+
def leastDistance_LST(data):
    """Minimum distance to the forward object over the samples where the ego is stopped; -1 if it never stops."""
    ego_df = data.ego_data
    stopped_distances = ego_df.loc[ego_df['v'] == 0, 'relative_dist'].tolist()
    if not stopped_distances:
        return {"minimumDistance_LST": -1}
    return {"minimumDistance_LST": min(stopped_distances)}
+
+
def launchTimeinStopLine_LST(data):
    """Span between the first and last stopped sample (v == 0); -1 if the ego never stops."""
    ego_df = data.ego_data
    stop_times = ego_df.loc[ego_df['v'] == 0, 'simTime'].tolist()
    if not stop_times:
        return {"launchTimeinStopLine_LST": -1}
    return {"launchTimeinStopLine_LST": stop_times[-1] - stop_times[0]}
+
+
def launchTimewhenFollowingCar_LST(data):
    """Longest single stop duration while following a car; 0 if the ego never stops."""
    ego_df = data.ego_data
    sim_times = ego_df['simTime'].tolist()
    # Nominal sample step, taken from the first two timestamps.
    sample_step = sim_times[1] - sim_times[0]
    stop_times = ego_df.loc[ego_df['v'] == 0, 'simTime'].tolist()
    if not stop_times:
        return {"launchTimewhenFollowingCar_LST": 0}
    stop_segments = _is_segment_by_interval(stop_times, sample_step)
    durations = [segment[-1] - segment[0] for segment in stop_segments]
    return {"launchTimewhenFollowingCar_LST": max(durations)}
+
+
def noStop_LST(data):
    """Return -1 if the ego vehicle ever comes to a stop (v == 0), otherwise 1.

    Bug fix: the previous code called `.any()` on a plain Python list
    (`ego_df['v'].tolist()`), which raises AttributeError; the zero-speed test
    is now done on the pandas Series directly.
    """
    ego_df = data.ego_data
    if (ego_df['v'] == 0).any():
        return {"noStop_LST": -1}
    else:
        return {"noStop_LST": 1}
+
+
def launchTimeinTrafficLight_LST(data):
    """Time between the first and last stopped sample that coincides with the
    light phase selected below; -1 when the ego never stops during that phase.

    TODO(review): the original docstring says red == 1 / green == 0, but the
    filter selects stateMask == 0 — confirm which state value is red.

    Bug fix: `.empty` is a pandas property; the previous code called it as a
    method (`.empty()`), raising TypeError.
    """
    ego_df = data.ego_data
    simtime_of_redlight = ego_df[ego_df['stateMask'] == 0]['simTime']
    simtime_of_stop = ego_df[ego_df['v'] == 0]['simTime']
    if simtime_of_stop.empty or simtime_of_redlight.empty:
        return {"timeInterval_LST": -1}
    simtime_of_launch = simtime_of_redlight[simtime_of_redlight.isin(simtime_of_stop)].tolist()
    if not simtime_of_launch:
        return {"timeInterval_LST": -1}
    return {"timeInterval_LST": simtime_of_launch[-1] - simtime_of_launch[0]}
+
+
def crossJunctionToTargetLane_LST(data):
    """Check whether the ego reached the configured target lane after crossing the junction."""
    ego_df = data.ego_data
    # NOTE(review): the config diff renamed the 'scenario' node in
    # all_metrics_config.yaml — confirm this lookup path still exists.
    target_lane_id = data.function_config["function"]["scenario"]["crossJunctionToTargetLane_LST"]['max']
    lanes_visited = set(ego_df['lane_id'].tolist())
    result = target_lane_id if target_lane_id in lanes_visited else -1
    return {"crossJunctionToTargetLane_LST": result}
+
+
def keepInLane_LST(data):
    """Verify the ego stayed in a single lane while on the configured road type; -1 otherwise."""
    ego_df = data.ego_data
    target_road_type = data.function_config["function"]["scenario"]["keepInLane_LST"]['max']
    in_target_road = ego_df[ego_df['road_type'] == target_road_type]
    if in_target_road.empty:
        return {"keepInLane_LST": -1}
    distinct_lanes = set(in_target_road['lane_id'].tolist())
    # Two or more lane ids means the ego changed lanes inside the section.
    if len(distinct_lanes) >= 2:
        return {"keepInLane_LST": -1}
    return {"keepInLane_LST": target_road_type}
+
+
def leastLateralDistance_LST(data):
    """Return 1 when the ego's lateral offset stays within half the lane width
    at the reference position (x_relative_dist == 0), otherwise -1 (also -1
    when that position was never sampled).

    Bug fix: `.empty` is a pandas property; the previous code called it as a
    method (`.empty()`), raising TypeError.
    """
    ego_df = data.ego_data
    at_reference = ego_df['x_relative_dist'] == 0
    lane_width = ego_df.loc[at_reference, 'lane_width']
    if lane_width.empty:
        return {"leastLateralDistance_LST": -1}
    y_relative_dist = ego_df.loc[at_reference, 'y_relative_dist']
    if (y_relative_dist <= lane_width / 2).all():
        return {"leastLateralDistance_LST": 1}
    return {"leastLateralDistance_LST": -1}
+
+
def waitTimeAtCrosswalkwithPedestrian_LST(data):
    """Intended metric: how long the ego waits at a crosswalk while a pedestrian
    is inside it.

    NOTE(review): this function looks unfinished — see the inline notes; it
    cannot run as written and fixing it needs the real data schema.
    """
    ego_df = data.ego_data
    object_df = data.object_data
    # NOTE(review): `data` is the processed-data container, not a DataFrame —
    # item assignment will fail here (and a list is not a valid column value).
    data['in_crosswalk'] = []
    # NOTE(review): with inplace=True, drop_duplicates returns None, so
    # position_data is None; drop `inplace` or use the returned frame.
    position_data = data.drop_duplicates(subset=['cross_id', 'cross_coords'], inplace=True)
    for cross_id in position_data['cross_id'].tolist():
        # NOTE(review): object_df['posX', 'posY'] indexes with a single tuple
        # key; iterating rows needs object_df[['posX', 'posY']].itertuples().
        for posX, posY in object_df['posX', 'posY']:
            pedestrian_pos = (posX, posY)
            plogan_pos = position_data[position_data['cross_id'] == cross_id]['cross_coords'].tolist()[0]
            # NOTE(review): chained-indexing assignment below writes to a
            # temporary copy, not to `data`.
            if _is_pedestrian_in_crosswalk(plogan_pos, pedestrian_pos):
                data[data['playerId'] == 2]['in_crosswalk'] = 1
            else:
                data[data['playerId'] == 2]['in_crosswalk'] = 0
    car_stop_time = ego_df[ego_df['v'] == 0]['simTime']
    pedestrian_in_crosswalk_time = data[(data['in_crosswalk'] == 1)]['simTime']
    # NOTE(review): `isin(...).tolist()` yields booleans; subtracting the first
    # from the last gives 0/±1, not a duration — filter the timestamps instead.
    car_wait_pedestrian = car_stop_time.isin(pedestrian_in_crosswalk_time).tolist()
    return {'waitTimeAtCrosswalkwithPedestrian_LST': car_wait_pedestrian[-1] - car_wait_pedestrian[0] if len(
        car_wait_pedestrian) > 0 else 0}
+
+
def launchTimewhenPedestrianLeave_LST(data):
    """Duration the ego stays stopped while the pedestrian is laterally within
    half a lane width; -1 if the ego never stops (or no stop overlaps).

    Bug fixes vs. the original:
      * `.empty` is a pandas property, not a method (the call raised TypeError);
      * the duration was computed from a list of booleans produced by
        `isin(...).tolist()`; it now subtracts the actual matching timestamps.
    """
    ego_df = data.ego_data
    car_stop_time = ego_df[ego_df['v'] == 0]["simTime"]
    if car_stop_time.empty:
        return {"launchTimewhenPedestrianLeave_LST": -1}
    # Lane width taken at the first stopped sample.
    lane_width = ego_df[ego_df['v'] == 0]['lane_width'].tolist()[0]
    # Samples where the pedestrian is within the ego's half lane laterally.
    car_to_pedestrain = ego_df[ego_df['y_relative_dist'] <= lane_width / 2]["simTime"]
    legal_stop_time = car_stop_time[car_stop_time.isin(car_to_pedestrain)].tolist()
    if not legal_stop_time:
        return {"launchTimewhenPedestrianLeave_LST": -1}
    return {"launchTimewhenPedestrianLeave_LST": legal_stop_time[-1] - legal_stop_time[0]}
+
+
def noCollision_LST(data):
    """Return -1 when a collision occurred (relative distance reached 0), else 1.

    Bug fix: the original tested `ego_df['relative_dist'].any() == 0`, which
    compares the truthiness of the whole column with 0 instead of checking
    whether any individual sample equals 0.
    """
    ego_df = data.ego_data
    if (ego_df['relative_dist'] == 0).any():
        return {"noCollision_LST": -1}
    else:
        return {"noCollision_LST": 1}
+
+
def noReverse_LST(data):
    """Return -1 when the ego moves backwards at any sample, else 1.

    Bug fix: in the original, `.any()` bound to `posH` alone (operator
    precedence), so the `if` tested a whole Series and raised
    "truth value of a Series is ambiguous". The product is now computed
    element-wise before the sign test.

    NOTE(review): multiplying longitudinal speed by the heading angle is an
    odd reverse test (posH in degrees is normally non-negative) — confirm the
    intended formula.
    """
    ego_df = data.ego_data
    if ((ego_df["lon_v_vehicle"] * ego_df["posH"]) < 0).any():
        return {"noReverse_LST": -1}
    else:
        return {"noReverse_LST": 1}
+
+
def turnAround_LST(data):
    """Return 1 when the ego's longitudinal velocity reverses sign between the
    first and last sample while the speed/heading product keeps one sign, else -1.

    Bug fix: the original wrote `series * series.all() > 0`, leaving a Series
    in the `if` condition (operator precedence), which raises
    "truth value of a Series is ambiguous". The comparison is now reduced with
    `.all()` on the element-wise product.

    NOTE(review): the two conditions can only hold together if posH also
    changes sign — confirm the intended turn-around criterion.
    """
    ego_df = data.ego_data
    lon_v = ego_df["lon_v_vehicle"]
    reversed_direction = lon_v.tolist()[0] * lon_v.tolist()[-1] < 0
    consistent_sign = ((lon_v * ego_df["posH"]) > 0).all()
    if reversed_direction and consistent_sign:
        return {"turnAround_LST": 1}
    else:
        return {"turnAround_LST": -1}
+
+
def laneOffset_LST(data):
    """Maximum lateral offset of the ego body edge (half car width subtracted) from the reference."""
    half_width = data.function_config['vehicle']['CAR_WIDTH'] / 2
    offsets = data.ego_data['y_relative_dist'] - half_width
    return {"laneOffset_LST": max(offsets)}
+
+
def maxLongitudeDist_LST(data):
    """Maximum longitudinal relative distance over the run; -1 when no samples exist.

    Bug fix: the success branch returned the key "maxLongDist_LST" while the
    empty branch (and the metric/chart name) use "maxLongitudeDist_LST"; the
    keys are now consistent.
    """
    ego_df = data.ego_data
    if len(ego_df['x_relative_dist']) == 0:
        return {"maxLongitudeDist_LST": -1}
    generate_function_chart_data(data, 'maxLongitudeDist_LST')
    return {"maxLongitudeDist_LST": max(ego_df['x_relative_dist'])}
+
+
def noEmergencyBraking_LST(data):
    """Return 1 when the ego never brakes harder than the speed-dependent
    emergency threshold, else -1.

    The threshold is linearly interpolated between the points (18, -5) and
    (72, -3.5) via get_interpolation; a sample counts as a slam-brake when the
    measured acceleration is below the interpolated threshold.

    Bug fix: the original read column 'accleX', which does not exist — the
    acceleration column is spelled 'accelX' everywhere else in the pipeline.
    """
    ego_df = data.ego_data
    ego_df['ip_dec'] = ego_df['v'].apply(
        get_interpolation, point1=[18, -5], point2=[72, -3.5])
    ego_df['slam_brake'] = (ego_df['accelX'] - ego_df['ip_dec']).apply(
        lambda x: 1 if x < 0 else 0)
    if sum(ego_df['slam_brake']) == 0:
        return {"noEmergencyBraking_LST": 1}
    else:
        return {"noEmergencyBraking_LST": -1}
+
+
 def rightWarningSignal_PGVIL(data_processed) -> dict:
     """判断是否发出正确预警信号"""
 
@@ -605,7 +903,6 @@ class FunctionManager:
         """
         function_result = self.function.batch_execute()
 
-
         print("\n[功能性表现及评价结果]")
         return function_result
         # self.logger.info(f'Function Result: {function_result}')

+ 1 - 1
modules/metric/safety.py

@@ -271,7 +271,7 @@ class SafetyRegistry:
         registry = {}
         for metric_name in self.metrics:
             func_name = f"calculate_{metric_name.lower()}"
-            if func_name in globals():
+            if func_name in globals(): # global()会获取当前模块下所有全局变量、函数、类和其他对象的名称及其对应的值
                 registry[metric_name] = globals()[func_name]
             else:
                 self.logger.warning(f"未实现安全指标函数: {func_name}")

+ 104 - 96
scripts/evaluator_enhanced.py

@@ -17,7 +17,6 @@ import traceback
 import json
 import inspect
 
-
 # 常量定义
 DEFAULT_WORKERS = 4
 CUSTOM_METRIC_PREFIX = "metric_"
@@ -31,26 +30,27 @@ else:
 
 sys.path.insert(0, str(_ROOT_PATH))
 
+
 class ConfigManager:
     """配置管理组件"""
-    
+
     def __init__(self, logger: logging.Logger):
         self.logger = logger
         self.base_config: Dict[str, Any] = {}
         self.custom_config: Dict[str, Any] = {}
         self.merged_config: Dict[str, Any] = {}
-    
+
     def split_configs(self, all_config_path: Path, base_config_path: Path, custom_config_path: Path) -> None:
         """从all_metrics_config.yaml拆分成内置和自定义配置"""
         try:
             with open(all_config_path, 'r', encoding='utf-8') as f:
                 all_metrics = yaml.safe_load(f) or {}
-            
+
             with open(base_config_path, 'r', encoding='utf-8') as f:
                 builtin_metrics = yaml.safe_load(f) or {}
-            
+
             custom_metrics = self._find_custom_metrics(all_metrics, builtin_metrics)
-            
+
             if custom_metrics:
                 with open(custom_config_path, 'w', encoding='utf-8') as f:
                     yaml.dump(custom_metrics, f, allow_unicode=True, sort_keys=False, indent=2)
@@ -58,18 +58,18 @@ class ConfigManager:
         except Exception as e:
             self.logger.error(f"Failed to split configs: {str(e)}")
             raise
-    
+
     def _find_custom_metrics(self, all_metrics, builtin_metrics, current_path=""):
         """递归比较找出自定义指标"""
         custom_metrics = {}
-        
+
         if isinstance(all_metrics, dict) and isinstance(builtin_metrics, dict):
             for key in all_metrics:
                 if key not in builtin_metrics:
                     custom_metrics[key] = all_metrics[key]
                 else:
                     child_custom = self._find_custom_metrics(
-                        all_metrics[key], 
+                        all_metrics[key],
                         builtin_metrics[key],
                         f"{current_path}.{key}" if current_path else key
                     )
@@ -77,34 +77,34 @@ class ConfigManager:
                         custom_metrics[key] = child_custom
         elif all_metrics != builtin_metrics:
             return all_metrics
-        
+
         if custom_metrics:
             return self._ensure_structure(custom_metrics, all_metrics, current_path)
         return None
-    
+
     def _ensure_structure(self, metrics_dict, full_dict, path):
         """确保每级包含name和priority"""
         if not isinstance(metrics_dict, dict):
             return metrics_dict
-        
+
         current = full_dict
         for key in path.split('.'):
             if key in current:
                 current = current[key]
             else:
                 break
-        
+
         result = {}
         if isinstance(current, dict):
             if 'name' in current:
                 result['name'] = current['name']
             if 'priority' in current:
                 result['priority'] = current['priority']
-        
+
         for key, value in metrics_dict.items():
             if key not in ['name', 'priority']:
                 result[key] = self._ensure_structure(value, full_dict, f"{path}.{key}" if path else key)
-        
+
         return result
 
     def load_configs(self, base_config_path: Optional[Path], custom_config_path: Optional[Path]) -> Dict[str, Any]:
@@ -116,19 +116,19 @@ class ConfigManager:
                 target_custom_path = custom_config_path or (base_config_path.parent / "custom_metrics_config.yaml")
                 self.split_configs(all_config_path, base_config_path, target_custom_path)
                 custom_config_path = target_custom_path
-        
+
         self.base_config = self._safe_load_config(base_config_path) if base_config_path else {}
         self.custom_config = self._safe_load_config(custom_config_path) if custom_config_path else {}
         self.merged_config = self._merge_configs(self.base_config, self.custom_config)
         return self.merged_config
-    
+
     def _safe_load_config(self, config_path: Path) -> Dict[str, Any]:
         """安全加载YAML配置"""
         try:
             if not config_path.exists():
                 self.logger.warning(f"Config file not found: {config_path}")
                 return {}
-                
+
             with config_path.open('r', encoding='utf-8') as f:
                 config = yaml.safe_load(f) or {}
                 self.logger.info(f"Loaded config: {config_path}")
@@ -136,24 +136,24 @@ class ConfigManager:
         except Exception as e:
             self.logger.error(f"Failed to load config {config_path}: {str(e)}")
             return {}
-    
+
     def _merge_configs(self, base_config: Dict, custom_config: Dict) -> Dict:
         """智能合并配置"""
         merged = base_config.copy()
-        
+
         for level1_key, level1_value in custom_config.items():
             if not isinstance(level1_value, dict) or 'name' not in level1_value:
                 if level1_key not in merged:
                     merged[level1_key] = level1_value
                 continue
-                
+
             if level1_key not in merged:
                 merged[level1_key] = level1_value
             else:
                 for level2_key, level2_value in level1_value.items():
                     if level2_key in ['name', 'priority']:
                         continue
-                        
+
                     if isinstance(level2_value, dict):
                         if level2_key not in merged[level1_key]:
                             merged[level1_key][level2_key] = level2_value
@@ -161,31 +161,32 @@ class ConfigManager:
                             for level3_key, level3_value in level2_value.items():
                                 if level3_key in ['name', 'priority']:
                                     continue
-                                    
+
                                 if isinstance(level3_value, dict):
                                     if level3_key not in merged[level1_key][level2_key]:
                                         merged[level1_key][level2_key][level3_key] = level3_value
-        
+
         return merged
-    
+
     def get_config(self) -> Dict[str, Any]:
         return self.merged_config
-    
+
     def get_base_config(self) -> Dict[str, Any]:
         return self.base_config
-    
+
     def get_custom_config(self) -> Dict[str, Any]:
         return self.custom_config
 
+
 class MetricLoader:
     """指标加载器组件"""
-    
+
     def __init__(self, logger: logging.Logger, config_manager: ConfigManager):
         self.logger = logger
         self.config_manager = config_manager
         self.metric_modules: Dict[str, Type] = {}
         self.custom_metric_modules: Dict[str, Any] = {}
-    
+
     def load_builtin_metrics(self) -> Dict[str, Type]:
         """加载内置指标模块"""
         module_mapping = {
@@ -195,15 +196,15 @@ class MetricLoader:
             "efficient": ("modules.metric.efficient", "EfficientManager"),
             "function": ("modules.metric.function", "FunctionManager"),
         }
-        
+
         self.metric_modules = {
             name: self._load_module(*info)
             for name, info in module_mapping.items()
         }
-        
+
         self.logger.info(f"Loaded builtin metrics: {', '.join(self.metric_modules.keys())}")
         return self.metric_modules
-    
+
     @lru_cache(maxsize=32)
     def _load_module(self, module_path: str, class_name: str) -> Type:
         """动态加载Python模块"""
@@ -213,7 +214,7 @@ class MetricLoader:
         except (ImportError, AttributeError) as e:
             self.logger.error(f"Failed to load module: {module_path}.{class_name} - {str(e)}")
             raise
-    
+
     def load_custom_metrics(self, custom_metrics_path: Optional[Path]) -> Dict[str, Any]:
         """加载自定义指标模块"""
         if not custom_metrics_path or not custom_metrics_path.is_dir():
@@ -225,30 +226,30 @@ class MetricLoader:
             if py_file.name.startswith(CUSTOM_METRIC_PREFIX):
                 if self._process_custom_metric_file(py_file):
                     loaded_count += 1
-        
+
         self.logger.info(f"Loaded {loaded_count} custom metric modules")
         return self.custom_metric_modules
-    
+
     def _process_custom_metric_file(self, file_path: Path) -> bool:
         """处理单个自定义指标文件"""
         try:
             metric_key = self._validate_metric_file(file_path)
-            
+
             module_name = f"custom_metric_{file_path.stem}"
             spec = importlib.util.spec_from_file_location(module_name, file_path)
             module = importlib.util.module_from_spec(spec)
             spec.loader.exec_module(module)
-            
+
             from modules.lib.metric_registry import BaseMetric
             metric_class = None
-            
+
             for name, obj in inspect.getmembers(module):
-                if (inspect.isclass(obj) and 
-                    issubclass(obj, BaseMetric) and 
-                    obj != BaseMetric):
+                if (inspect.isclass(obj) and
+                        issubclass(obj, BaseMetric) and
+                        obj != BaseMetric):
                     metric_class = obj
                     break
-            
+
             if metric_class:
                 self.custom_metric_modules[metric_key] = {
                     'type': 'class',
@@ -264,7 +265,7 @@ class MetricLoader:
                 self.logger.info(f"Loaded function-based custom metric: {metric_key}")
             else:
                 raise AttributeError(f"Missing evaluate() function or BaseMetric subclass: {file_path.name}")
-                
+
             return True
         except ValueError as e:
             self.logger.warning(str(e))
@@ -272,24 +273,25 @@ class MetricLoader:
         except Exception as e:
             self.logger.error(f"Failed to load custom metric {file_path}: {str(e)}")
             return False
-    
+
     def _validate_metric_file(self, file_path: Path) -> str:
         """验证自定义指标文件命名规范"""
         stem = file_path.stem[len(CUSTOM_METRIC_PREFIX):]
         parts = stem.split('_')
         if len(parts) < 3:
-            raise ValueError(f"Invalid custom metric filename: {file_path.name}, should be metric_<level1>_<level2>_<level3>.py")
+            raise ValueError(
+                f"Invalid custom metric filename: {file_path.name}, should be metric_<level1>_<level2>_<level3>.py")
 
         level1, level2, level3 = parts[:3]
         if not self._is_metric_configured(level1, level2, level3):
             raise ValueError(f"Unconfigured metric: {level1}.{level2}.{level3}")
         return f"{level1}.{level2}.{level3}"
-    
+
     def _is_metric_configured(self, level1: str, level2: str, level3: str) -> bool:
         """检查指标是否在配置中注册"""
         custom_config = self.config_manager.get_custom_config()
         try:
-            return (level1 in custom_config and 
+            return (level1 in custom_config and
                     isinstance(custom_config[level1], dict) and
                     level2 in custom_config[level1] and
                     isinstance(custom_config[level1][level2], dict) and
@@ -297,32 +299,33 @@ class MetricLoader:
                     isinstance(custom_config[level1][level2][level3], dict))
         except Exception:
             return False
-    
+
     def get_builtin_metrics(self) -> Dict[str, Type]:
         return self.metric_modules
-    
+
     def get_custom_metrics(self) -> Dict[str, Any]:
         return self.custom_metric_modules
 
+
 class EvaluationEngine:
     """评估引擎组件"""
-    
+
     def __init__(self, logger: logging.Logger, config_manager: ConfigManager, metric_loader: MetricLoader):
         self.logger = logger
         self.config_manager = config_manager
         self.metric_loader = metric_loader
-    
+
     def evaluate(self, data: Any) -> Dict[str, Any]:
         """执行评估流程"""
         raw_results = self._collect_builtin_metrics(data)
         custom_results = self._collect_custom_metrics(data)
         return self._process_merged_results(raw_results, custom_results)
-    
+
     def _collect_builtin_metrics(self, data: Any) -> Dict[str, Any]:
         """收集内置指标结果"""
         metric_modules = self.metric_loader.get_builtin_metrics()
         raw_results: Dict[str, Any] = {}
-        
+
         with ThreadPoolExecutor(max_workers=len(metric_modules)) as executor:
             futures = {
                 executor.submit(self._run_module, module, data, module_name): module_name
@@ -344,21 +347,21 @@ class EvaluationEngine:
                         "message": str(e),
                         "timestamp": datetime.now().isoformat(),
                     }
-        
+
         return raw_results
-    
+
     def _collect_custom_metrics(self, data: Any) -> Dict[str, Dict]:
         """收集自定义指标结果"""
         custom_metrics = self.metric_loader.get_custom_metrics()
         if not custom_metrics:
             return {}
-            
+
         custom_results = {}
-        
+
         for metric_key, metric_info in custom_metrics.items():
             try:
                 level1, level2, level3 = metric_key.split('.')
-                
+
                 if metric_info['type'] == 'class':
                     metric_class = metric_info['class']
                     metric_instance = metric_class(data)
@@ -366,22 +369,22 @@ class EvaluationEngine:
                 else:
                     module = metric_info['module']
                     metric_result = module.evaluate(data)
-                
+
                 if level1 not in custom_results:
                     custom_results[level1] = {}
                 custom_results[level1] = metric_result
-                
+
                 self.logger.info(f"Calculated custom metric: {level1}.{level2}.{level3}")
-                
+
             except Exception as e:
                 self.logger.error(f"Custom metric {metric_key} failed: {str(e)}")
-                
+
                 try:
                     level1, level2, level3 = metric_key.split('.')
-                    
+
                     if level1 not in custom_results:
                         custom_results[level1] = {}
-                        
+
                     custom_results[level1] = {
                         "status": "error",
                         "message": str(e),
@@ -389,9 +392,9 @@ class EvaluationEngine:
                     }
                 except Exception:
                     pass
-        
+
         return custom_results
-    
+
     def _process_merged_results(self, raw_results: Dict, custom_results: Dict) -> Dict:
         """处理合并后的评估结果"""
         from modules.lib.score import Score
@@ -417,14 +420,14 @@ class EvaluationEngine:
                     final_results[level1] = self._format_error(e)
 
         return final_results
-        
+
     def _format_error(self, e: Exception) -> Dict:
         return {
             "status": "error",
             "message": str(e),
             "timestamp": datetime.now().isoformat()
         }
-                
+
     def _run_module(self, module_class: Any, data: Any, module_name: str) -> Dict[str, Any]:
         """执行单个评估模块"""
         try:
@@ -434,13 +437,14 @@ class EvaluationEngine:
             self.logger.error(f"{module_name} execution error: {str(e)}", exc_info=True)
             return {module_name: {"error": str(e)}}
 
+
 class LoggingManager:
     """日志管理组件"""
-    
+
     def __init__(self, log_path: Path):
         self.log_path = log_path
         self.logger = self._init_logger()
-    
+
     def _init_logger(self) -> logging.Logger:
         """初始化日志系统"""
         try:
@@ -455,20 +459,21 @@ class LoggingManager:
             logger.addHandler(console_handler)
             logger.warning(f"Failed to init standard logger: {str(e)}, using fallback logger")
             return logger
-    
+
     def get_logger(self) -> logging.Logger:
         return self.logger
 
+
 class DataProcessor:
     """数据处理组件"""
-    
+
     def __init__(self, logger: logging.Logger, data_path: Path, config_path: Optional[Path] = None):
         self.logger = logger
         self.data_path = data_path
         self.config_path = config_path
         self.processor = self._load_processor()
         self.case_name = self.data_path.name
-    
+
     def _load_processor(self) -> Any:
         """加载数据处理器"""
         try:
@@ -477,7 +482,7 @@ class DataProcessor:
         except ImportError as e:
             self.logger.error(f"Failed to load data processor: {str(e)}")
             raise RuntimeError(f"Failed to load data processor: {str(e)}") from e
-    
+
     def validate(self) -> None:
         """验证数据路径"""
         if not self.data_path.exists():
@@ -485,10 +490,11 @@ class DataProcessor:
         if not self.data_path.is_dir():
             raise NotADirectoryError(f"Invalid data directory: {self.data_path}")
 
+
 class EvaluationPipeline:
     """评估流水线控制器"""
-    
-    def __init__(self, config_path: str, log_path: str, data_path: str, report_path: str, 
+
+    def __init__(self, config_path: str, log_path: str, data_path: str, report_path: str,
                  custom_metrics_path: Optional[str] = None, custom_config_path: Optional[str] = None):
         # 路径初始化
         self.config_path = Path(config_path) if config_path else None
@@ -496,7 +502,7 @@ class EvaluationPipeline:
         self.data_path = Path(data_path)
         self.report_path = Path(report_path)
         self.custom_metrics_path = Path(custom_metrics_path) if custom_metrics_path else None
-        
+
         # 组件初始化
         self.logging_manager = LoggingManager(Path(log_path))
         self.logger = self.logging_manager.get_logger()
@@ -507,50 +513,51 @@ class EvaluationPipeline:
         self.metric_loader.load_custom_metrics(self.custom_metrics_path)
         self.evaluation_engine = EvaluationEngine(self.logger, self.config_manager, self.metric_loader)
         self.data_processor = DataProcessor(self.logger, self.data_path, self.config_path)
-    
+
     def execute(self) -> Dict[str, Any]:
         """执行评估流水线"""
         try:
             self.data_processor.validate()
-            
+
             self.logger.info(f"Start evaluation: {self.data_path.name}")
             start_time = time.perf_counter()
             results = self.evaluation_engine.evaluate(self.data_processor.processor)
             elapsed_time = time.perf_counter() - start_time
             self.logger.info(f"Evaluation completed, time: {elapsed_time:.2f}s")
-            
+
             report = self._generate_report(self.data_processor.case_name, results)
             return report
-            
+
         except Exception as e:
             self.logger.critical(f"Evaluation failed: {str(e)}", exc_info=True)
             return {"error": str(e), "traceback": traceback.format_exc()}
-    
+
     def _generate_report(self, case_name: str, results: Dict[str, Any]) -> Dict[str, Any]:
         """生成评估报告"""
         from modules.lib.common import dict2json
-        
+
         self.report_path.mkdir(parents=True, exist_ok=True)
-        
+
         results["metadata"] = {
             "case_name": case_name,
             "timestamp": datetime.now().isoformat(),
             "version": "3.1.0",
         }
-        
+
         report_file = self.report_path / f"{case_name}_report.json"
         dict2json(results, report_file)
         self.logger.info(f"Report generated: {report_file}")
-        
+
         return results
 
+
 def main():
     """命令行入口"""
     parser = argparse.ArgumentParser(
         description="Autonomous Driving Evaluation System V3.1",
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
-    
+
     parser.add_argument(
         "--logPath",
         type=str,
@@ -560,7 +567,7 @@ def main():
     parser.add_argument(
         "--dataPath",
         type=str,
-        default=r"D:\Kevin\zhaoyuan\data\V2V_CSAE53-2020_ForwardCollisionW_LST_01-01",
+        default=r"D:\Kevin\zhaoyuan\data\V2V_CSAE53-2020_ForwardCollision_LST_01-02",
         help="Input data directory",
     )
     parser.add_argument(
@@ -587,19 +594,19 @@ def main():
         default="config/custom_metrics_config.yaml",
         help="Custom metrics config path (optional)",
     )
-    
+
     args = parser.parse_args()
 
     try:
         pipeline = EvaluationPipeline(
-            args.configPath, 
-            args.logPath, 
-            args.dataPath, 
-            args.reportPath, 
-            args.customMetricsPath, 
+            args.configPath,
+            args.logPath,
+            args.dataPath,
+            args.reportPath,
+            args.customMetricsPath,
             args.customConfigPath
         )
-        
+
         start_time = time.perf_counter()
         result = pipeline.execute()
         elapsed_time = time.perf_counter() - start_time
@@ -610,7 +617,7 @@ def main():
 
         print(f"Evaluation completed, total time: {elapsed_time:.2f}s")
         print(f"Report path: {pipeline.report_path}")
-        
+
     except KeyboardInterrupt:
         print("\nUser interrupted")
         sys.exit(130)
@@ -619,6 +626,7 @@ def main():
         traceback.print_exc()
         sys.exit(1)
 
+
 if __name__ == "__main__":
     warnings.filterwarnings("ignore")
     main()

+ 61 - 61
scripts/evaluator_optimized.py

@@ -14,7 +14,6 @@ from datetime import datetime
 # 强制导入所有可能动态加载的模块
 
 
-
 # 安全设置根目录路径(动态路径管理)
 # 判断是否处于编译模式
 if hasattr(sys, "_MEIPASS"):
@@ -40,21 +39,22 @@ class EvaluationCore:
             cls._instance._init(logPath, configPath, customConfigPath, customMetricsPath)
         return cls._instance
 
-    def _init(self, logPath: str = None, configPath: str = None, customConfigPath: str = None, customMetricsPath: str = None) -> None:
+    def _init(self, logPath: str = None, configPath: str = None, customConfigPath: str = None,
+              customMetricsPath: str = None) -> None:
         """初始化引擎组件"""
         self.log_path = logPath
         self.config_path = configPath
         self.custom_config_path = customConfigPath
         self.custom_metrics_path = customMetricsPath
-        
+
         # 加载配置
         self.metrics_config = {}
         self.custom_metrics_config = {}
         self.merged_config = {}  # 添加合并后的配置
-        
+
         # 自定义指标脚本模块
         self.custom_metrics_modules = {}
-        
+
         self._init_log_system()
         self._load_configs()  # 加载并合并配置
         self._init_metrics()
@@ -103,7 +103,7 @@ class EvaluationCore:
             except Exception as e:
                 self.logger.error(f"加载内置指标配置失败: {str(e)}")
                 self.metrics_config = {}
-        
+
         # 加载自定义指标配置
         if self.custom_config_path and Path(self.custom_config_path).exists():
             try:
@@ -113,28 +113,28 @@ class EvaluationCore:
             except Exception as e:
                 self.logger.error(f"加载自定义指标配置失败: {str(e)}")
                 self.custom_metrics_config = {}
-        
+
         # 合并配置
         self.merged_config = self._merge_configs(self.metrics_config, self.custom_metrics_config)
 
     def _merge_configs(self, base_config: Dict, custom_config: Dict) -> Dict:
         """
         合并内置指标和自定义指标配置
-        
+
         策略:
         1. 如果自定义指标与内置指标有相同的一级指标,则合并其下的二级指标
         2. 如果自定义指标与内置指标有相同的二级指标,则合并其下的三级指标
         3. 如果是全新的指标,则直接添加
         """
         merged = base_config.copy()
-        
+
         for level1_key, level1_value in custom_config.items():
             # 跳过非指标配置项(如vehicle等)
             if not isinstance(level1_value, dict) or 'name' not in level1_value:
                 if level1_key not in merged:
                     merged[level1_key] = level1_value
                 continue
-                
+
             if level1_key not in merged:
                 # 全新的一级指标
                 merged[level1_key] = level1_value
@@ -143,7 +143,7 @@ class EvaluationCore:
                 for level2_key, level2_value in level1_value.items():
                     if level2_key == 'name' or level2_key == 'priority':
                         continue
-                        
+
                     if isinstance(level2_value, dict):
                         if level2_key not in merged[level1_key]:
                             # 新的二级指标
@@ -153,24 +153,24 @@ class EvaluationCore:
                             for level3_key, level3_value in level2_value.items():
                                 if level3_key == 'name' or level3_key == 'priority':
                                     continue
-                                    
+
                                 if isinstance(level3_value, dict):
                                     if level3_key not in merged[level1_key][level2_key]:
                                         # 新的三级指标
                                         merged[level1_key][level2_key][level3_key] = level3_value
-        
+
         return merged
 
     def _load_custom_metrics(self) -> None:
         """加载自定义指标脚本"""
         if not self.custom_metrics_path or not Path(self.custom_metrics_path).exists():
             return
-            
+
         custom_metrics_dir = Path(self.custom_metrics_path)
         if not custom_metrics_dir.is_dir():
             self.logger.warning(f"自定义指标路径不是目录: {custom_metrics_dir}")
             return
-            
+
         # 遍历自定义指标脚本目录
         for file_path in custom_metrics_dir.glob("*.py"):
             if file_path.name.startswith("metric_") and file_path.name.endswith(".py"):
@@ -178,39 +178,40 @@ class EvaluationCore:
                     # 解析脚本名称,获取指标层级信息
                     parts = file_path.stem[7:].split('_')  # 去掉'metric_'前缀
                     if len(parts) < 3:
-                        self.logger.warning(f"自定义指标脚本 {file_path.name} 命名不符合规范,应为 metric_<level1>_<level2>_<level3>.py")
+                        self.logger.warning(
+                            f"自定义指标脚本 {file_path.name} 命名不符合规范,应为 metric_<level1>_<level2>_<level3>.py")
                         continue
-                    
+
                     level1, level2, level3 = parts[0], parts[1], parts[2]
-                    
+
                     # 检查指标是否在配置中
                     if not self._check_metric_in_config(level1, level2, level3, self.custom_metrics_config):
                         self.logger.warning(f"自定义指标 {level1}.{level2}.{level3} 在配置中不存在,跳过加载")
                         continue
-                    
+
                     # 加载脚本模块
                     module_name = f"custom_metric_{level1}_{level2}_{level3}"
                     spec = importlib.util.spec_from_file_location(module_name, file_path)
                     module = importlib.util.module_from_spec(spec)
                     spec.loader.exec_module(module)
-                    
+
                     # 检查模块是否包含必要的函数
                     if not hasattr(module, 'evaluate'):
                         self.logger.warning(f"自定义指标脚本 {file_path.name} 缺少 evaluate 函数")
                         continue
-                    
+
                     # 存储模块引用
                     key = f"{level1}.{level2}.{level3}"
                     self.custom_metrics_modules[key] = module
                     self.logger.info(f"成功加载自定义指标脚本: {file_path.name}")
-                    
+
                 except Exception as e:
                     self.logger.error(f"加载自定义指标脚本 {file_path.name} 失败: {str(e)}")
 
     def _check_metric_in_config(self, level1: str, level2: str, level3: str, config: Dict) -> bool:
         """检查指标是否在配置中存在"""
         try:
-            return (level1 in config and 
+            return (level1 in config and
                     isinstance(config[level1], dict) and
                     level2 in config[level1] and
                     isinstance(config[level1][level2], dict) and
@@ -223,15 +224,15 @@ class EvaluationCore:
         """并行化评估引擎(动态线程池)"""
         # 存储所有评估结果
         results = {}
-        
+
         # 1. 先评估内置指标
         self._evaluate_built_in_metrics(data, results)
-        
+
         # 2. 再评估自定义指标并合并结果
         self._evaluate_and_merge_custom_metrics(data, results)
-        
+
         return results
-    
+
     def _evaluate_built_in_metrics(self, data: Any, results: Dict[str, Any]) -> None:
         """评估内置指标"""
         # 关键修改点1:线程数=模块数
@@ -265,12 +266,12 @@ class EvaluationCore:
                         "message": str(e),
                         "timestamp": datetime.now().isoformat(),
                     }
-    
+
     def _evaluate_and_merge_custom_metrics(self, data: Any, results: Dict[str, Any]) -> None:
         """评估自定义指标并合并结果"""
         if not self.custom_metrics_modules:
             return
-            
+
         # 按一级指标分组自定义指标
         grouped_metrics = {}
         for metric_key in self.custom_metrics_modules:
@@ -278,13 +279,13 @@ class EvaluationCore:
             if level1 not in grouped_metrics:
                 grouped_metrics[level1] = []
             grouped_metrics[level1].append(metric_key)
-        
+
         # 处理每个一级指标组
         for level1, metric_keys in grouped_metrics.items():
             # 检查是否为内置一级指标
             is_built_in = level1 in self.metrics_config and 'name' in self.metrics_config[level1]
             level1_name = self.merged_config[level1].get('name', level1) if level1 in self.merged_config else level1
-            
+
             # 如果是内置一级指标,将结果合并到已有结果中
             if is_built_in and level1_name in results:
                 for metric_key in metric_keys:
@@ -293,52 +294,52 @@ class EvaluationCore:
                 # 如果是新的一级指标,创建新的结果结构
                 if level1_name not in results:
                     results[level1_name] = {}
-                
+
                 # 评估该一级指标下的所有自定义指标
                 for metric_key in metric_keys:
                     self._evaluate_and_merge_single_metric(data, results, metric_key, level1_name)
-    
-    def _evaluate_and_merge_single_metric(self, data: Any, results: Dict[str, Any], metric_key: str, level1_name: str) -> None:
+
+    def _evaluate_and_merge_single_metric(self, data: Any, results: Dict[str, Any], metric_key: str,
+                                          level1_name: str) -> None:
         """评估单个自定义指标并合并结果"""
         try:
             level1, level2, level3 = metric_key.split('.')
             module = self.custom_metrics_modules[metric_key]
-            
+
             # 获取指标配置
             metric_config = self.custom_metrics_config[level1][level2][level3]
-            
+
             # 获取指标名称
             level2_name = self.custom_metrics_config[level1][level2].get('name', level2)
             level3_name = metric_config.get('name', level3)
-            
+
             # 确保结果字典结构存在
             if level2_name not in results[level1_name]:
                 results[level1_name][level2_name] = {}
-            
+
             # 调用自定义指标评测函数
             metric_result = module.evaluate(data)
             from modules.lib.score import Score
             evaluator = Score(self.merged_config, level1_name)
-            
+
             result = evaluator.evaluate(metric_result)
-           
+
             results.update(result)
-            
-            
+
             self.logger.info(f"评测自定义指标: {level1_name}.{level2_name}.{level3_name}")
-            
+
         except Exception as e:
             self.logger.error(f"评测自定义指标 {metric_key} 失败: {str(e)}")
-            
+
             # 尝试添加错误信息到结果中
             try:
                 level1, level2, level3 = metric_key.split('.')
                 level2_name = self.custom_metrics_config[level1][level2].get('name', level2)
                 level3_name = self.custom_metrics_config[level1][level2][level3].get('name', level3)
-                
+
                 if level2_name not in results[level1_name]:
                     results[level1_name][level2_name] = {}
-                    
+
                 results[level1_name][level2_name][level3_name] = {
                     "status": "error",
                     "message": str(e),
@@ -348,7 +349,7 @@ class EvaluationCore:
                 pass
 
     def _run_module(
-        self, module_class: Any, data: Any, module_name: str
+            self, module_class: Any, data: Any, module_name: str
     ) -> Dict[str, Any]:
         """执行单个评估模块(带熔断机制)"""
         try:
@@ -359,26 +360,25 @@ class EvaluationCore:
             return {module_name: {"error": str(e)}}
 
 
-
-
 class EvaluationPipeline:
     """评估流水线控制器"""
 
-    def __init__(self, configPath: str, logPath: str, dataPath: str, resultPath: str, customMetricsPath: Optional[str] = None, customConfigPath: Optional[str] = None):
+    def __init__(self, configPath: str, logPath: str, dataPath: str, resultPath: str,
+                 customMetricsPath: Optional[str] = None, customConfigPath: Optional[str] = None):
         self.configPath = Path(configPath)
         self.custom_config_path = Path(customConfigPath) if customConfigPath else None
         self.data_path = Path(dataPath)
         self.report_path = Path(resultPath)
         self.custom_metrics_path = Path(customMetricsPath) if customMetricsPath else None
-        
+
         # 创建评估引擎实例,传入所有必要参数
         self.engine = EvaluationCore(
-            logPath, 
-            configPath=str(self.configPath), 
+            logPath,
+            configPath=str(self.configPath),
             customConfigPath=str(self.custom_config_path) if self.custom_config_path else None,
             customMetricsPath=str(self.custom_metrics_path) if self.custom_metrics_path else None
         )
-        
+
         self.data_processor = self._load_data_processor()
 
     def _load_data_processor(self) -> Any:
@@ -435,39 +435,39 @@ def main():
     parser.add_argument(
         "--logPath",
         type=str,
-        default=r"D:\Cicv\招远\zhaoyuan0410\logs\test.log",
+        default=r"D:\Cicv\招远\zhaoyuan\test.log",
         help="日志文件存储路径",
     )
     parser.add_argument(
         "--dataPath",
         type=str,
-        default=r"D:\Cicv\招远\V2V_CSAE53-2020_ForwardCollision_LST_02-03",
+        default=r"D:\Cicv\招远\V2V_CSAE53-2020_ForwardCollision_LST_01-02_new",
         help="预处理后的输入数据目录",
     )
     parser.add_argument(
         "--configPath",
         type=str,
-        default=r"D:\Cicv\招远\zhaoyuan0410\config\metrics_config.yaml",
+        default=r"D:\Cicv\招远\zhaoyuan\zhaoyuan\config\all_metrics_config.yaml",
         help="评估指标配置文件路径",
     )
     parser.add_argument(
         "--reportPath",
         type=str,
-        default=r"D:\Cicv\招远\zhaoyuan0410\result",
+        default=r"D:\Cicv\招远\zhaoyuan\zhaoyuan\result",
         help="评估报告输出目录",
     )
     # 新增自定义指标路径参数(可选)
     parser.add_argument(
         "--customMetricsPath",
         type=str,
-        default=r"D:\Cicv\招远\zhaoyuan0410\custom_metrics",
+        default=r"D:\Cicv\招远\zhaoyuan\zhaoyuan\custom_metrics",
         help="自定义指标脚本目录(可选)",
     )
     # 新增自定义指标路径参数(可选)
     parser.add_argument(
         "--customConfigPath",
         type=str,
-        default=r"D:\Cicv\招远\zhaoyuan0410\test\custom_metrics_config.yaml",
+        default=r"D:\Cicv\招远\zhaoyuan\zhaoyuan\test\custom_metrics_config.yaml",
         help="自定义指标脚本目录(可选)",
     )
     args = parser.parse_args()
@@ -483,7 +483,7 @@ def main():
         if "error" in result:
             sys.exit(1)
 
-        print(f"评估完成,耗时: {time.perf_counter()-start_time:.2f}s")
+        print(f"评估完成,耗时: {time.perf_counter() - start_time:.2f}s")
         print(f"报告路径: {pipeline.report_path}")
     except KeyboardInterrupt:
         print("\n用户中断操作")

Неке датотеке нису приказане због велике количине промена