Browse Source

修改功能性指标中读取二级指标名称的代码逻辑

XGJ_zhaoyuan 1 week ago
parent
commit
e556a1496c
1 changed files with 37 additions and 15 deletions
  1. 37 15
      modules/metric/function.py

+ 37 - 15
modules/metric/function.py

@@ -33,6 +33,28 @@ scenario_sign_dict = {"LeftTurnAssist": 206, "HazardousLocationW": 207, "RedLigh
                       "CoorperativeIntersectionPassing": 225, "GreenLightOptimalSpeedAdvisory": 234,
                       "ForwardCollision": 212}
 
+# 寻找二级指标的名称
+def find_nested_name(data):
+    """
+    查找字典中嵌套的name结构。
+
+    :param data: 要搜索的字典
+    :return: 找到的第一个嵌套name结构的值,如果没有找到则返回None
+    """
+    if isinstance(data, dict):
+        for key, value in data.items():
+            if isinstance(value, dict) and 'name' in value:
+                return value['name']
+            # 递归查找嵌套字典
+            result = find_nested_name(value)
+            if result is not None:
+                return result
+    elif isinstance(data, list):
+        for item in data:
+            result = find_nested_name(item)
+            if result is not None:
+                return result
+    return None
 
 def calculate_distance_PGVIL(ego_pos: np.ndarray, obj_pos: np.ndarray) -> np.ndarray:
     """向量化距离计算"""
@@ -69,7 +91,7 @@ def get_first_warning(data_processed) -> Optional[pd.DataFrame]:
     ego_df = data_processed.ego_data
     obj_df = data_processed.object_df
 
-    scenario_name = data_processed.function_config["function"]["scenario"]["name"]
+    scenario_name = find_nested_name(data_processed.function_config["function"])
     correctwarning = scenario_sign_dict.get(scenario_name)
 
     if correctwarning is None:
@@ -91,8 +113,8 @@ def get_first_warning(data_processed) -> Optional[pd.DataFrame]:
 # ----------------------
 def latestWarningDistance_LST(data) -> dict:
     """预警距离计算流水线"""
-    scenario_name = data.function_config["function"]["scenario"]["name"]
-    value = data.function_config["function"]["scenario"]["latestWarningDistance_LST"]["max"]
+    scenario_name = find_nested_name(data.function_config["function"])
+    value = data.function_config["function"][scenario_name]["latestWarningDistance_LST"]["max"]
     correctwarning = scenario_sign_dict[scenario_name]
     ego_df = data.ego_data
     warning_dist = calculate_distance(ego_df, correctwarning)
@@ -104,8 +126,8 @@ def latestWarningDistance_LST(data) -> dict:
 
 def earliestWarningDistance_LST(data) -> dict:
     """预警距离计算流水线"""
-    scenario_name = data.function_config["function"]["scenario"]["name"]
-    value = data.function_config["function"]["scenario"]["earliestWarningDistance_LST"]["max"]
+    scenario_name = find_nested_name(data.function_config["function"])
+    value = data.function_config["function"][scenario_name]["earliestWarningDistance_LST"]["max"]
     correctwarning = scenario_sign_dict[scenario_name]
     ego_df = data.ego_data
     warning_dist = calculate_distance(ego_df, correctwarning)
@@ -117,8 +139,8 @@ def earliestWarningDistance_LST(data) -> dict:
 
 def latestWarningDistance_TTC_LST(data) -> dict:
     """TTC计算流水线"""
-    scenario_name = data.function_config["function"]["scenario"]["name"]
-    value = data.function_config["function"]["scenario"]["latestWarningDistance_TTC_LST"]["max"]
+    scenario_name = find_nested_name(data.function_config["function"])
+    value = data.function_config["function"][scenario_name]["latestWarningDistance_TTC_LST"]["max"]
     correctwarning = scenario_sign_dict[scenario_name]
     ego_df = data.ego_data
     warning_dist = calculate_distance(ego_df, correctwarning)
@@ -143,8 +165,8 @@ def latestWarningDistance_TTC_LST(data) -> dict:
 
 def earliestWarningDistance_TTC_LST(data) -> dict:
     """TTC计算流水线"""
-    scenario_name = data.function_config["function"]["scenario"]["name"]
-    value = data.function_config["function"]["scenario"]["earliestWarningDistance_TTC_LST"]["max"]
+    scenario_name = find_nested_name(data.function_config["function"])
+    value = data.function_config["function"][scenario_name]["earliestWarningDistance_TTC_LST"]["max"]
     correctwarning = scenario_sign_dict[scenario_name]
     ego_df = data.ego_data
     warning_dist = calculate_distance(ego_df, correctwarning)
@@ -164,7 +186,7 @@ def earliestWarningDistance_TTC_LST(data) -> dict:
 
 
 def warningDelayTime_LST(data):
-    scenario_name = data.function_config["function"]["scenario"]["name"]
+    scenario_name = find_nested_name(data.function_config["function"])
     correctwarning = scenario_sign_dict[scenario_name]
     ego_df = data.ego_data
     HMI_warning_rows = ego_df[(ego_df['ifwarning'] == correctwarning)]['simTime'].tolist()
@@ -181,7 +203,7 @@ def warningDelayTime_LST(data):
 
 
 def warningDelayTimeofReachDecel_LST(data):
-    scenario_name = data.function_config["function"]["scenario"]["name"]
+    scenario_name = find_nested_name(data.function_config["function"])
     correctwarning = scenario_sign_dict[scenario_name]
     ego_df = data.ego_data
     ego_speed_simtime = ego_df[ego_df['accel'] <= -4]['simTime'].tolist()  # 单位m/s^2
@@ -197,7 +219,7 @@ def warningDelayTimeofReachDecel_LST(data):
 
 
 def rightWarningSignal_LST(data):
-    scenario_name = data.function_config["function"]["scenario"]["name"]
+    scenario_name = find_nested_name(data.function_config["function"])
     correctwarning = scenario_sign_dict[scenario_name]
     ego_df = data.ego_data
     if ego_df['ifwarning'].empty:
@@ -211,7 +233,7 @@ def rightWarningSignal_LST(data):
 
 
 def ifCrossingRedLight_LST(data):
-    scenario_name = data.function_config["function"]["scenario"]["name"]
+    scenario_name = find_nested_name(data.function_config["function"])
     correctwarning = scenario_sign_dict[scenario_name]
     ego_df = data.ego_data
     redlight_simtime = ego_df[
@@ -224,7 +246,7 @@ def ifCrossingRedLight_LST(data):
 
 
 def ifStopgreenWaveSpeedGuidance_LST(data):
-    scenario_name = data.function_config["function"]["scenario"]["name"]
+    scenario_name = find_nested_name(data.function_config["function"])
     correctwarning = scenario_sign_dict[scenario_name]
     ego_df = data.ego_data
     greenlight_simtime = \
@@ -239,7 +261,7 @@ def rightWarningSignal_PGVIL(data_processed) -> dict:
     """判断是否发出正确预警信号"""
 
     ego_df = data_processed.ego_data
-    scenario_name = data_processed.function_config["function"]["scenario"]["name"]
+    scenario_name = find_nested_name(data_processed.function_config["function"])
     correctwarning = scenario_sign_dict[scenario_name]
 
     if correctwarning is None: