XGJ_zhaoyuan пре 4 дана
родитељ
комит
a2a9b42fbf
2 измењених фајлова са 19 додато и 13 уклоњено
  1. 11 10
      modules/lib/chart_generator.py
  2. 8 3
      modules/metric/function.py

+ 11 - 10
modules/lib/chart_generator.py

@@ -255,12 +255,12 @@ def generate_earliest_warning_distance_pgvil_chart(function_calculator, output_d
         warning_dist = getattr(function_calculator, 'warning_dist', None)
         warning_time = getattr(function_calculator, 'warning_time', None)
 
-        if warning_dist.empty:
+        if len(warning_dist) == 0:
             logger.warning(f"Cannot generate {"earliestWarningDistance_LST"} chart: empty data")
             return None
 
         # Calculate metric value
-        metric_value = float(warning_dist.iloc[0]) if len(warning_dist) >= 0.0 else max_threshold
+        metric_value = float(warning_dist[0]) if len(warning_dist) >= 0.0 else max_threshold
 
         # Save CSV data
         csv_filename = os.path.join(output_dir, f"earliestWarningDistance_PGVIL_data.csv")
@@ -268,7 +268,7 @@ def generate_earliest_warning_distance_pgvil_chart(function_calculator, output_d
             'simTime': warning_time,
             'warning_distance': warning_dist,
             'min_threshold': min_threshold,
-            'max_threshold': max_threshold,
+            'max_threshold': max_threshold
         })
         df_csv.to_csv(csv_filename, index=False)
         logger.info(f"earliestWarningDistance_PGVIL data saved to: {csv_filename}")
@@ -348,7 +348,7 @@ def generate_latest_warning_ttc_pgvil_chart(function_calculator, output_dir: str
         warning_time = getattr(function_calculator, 'warning_time', None)
         ttc = getattr(function_calculator, 'ttc', None)
 
-        if warning_dist.empty:
+        if len(warning_dist) == 0:
             logger.warning("Cannot generate TTC warning chart: empty data")
             return None
 
@@ -357,7 +357,7 @@ def generate_latest_warning_ttc_pgvil_chart(function_calculator, output_dir: str
         # 保存 CSV 数据
         csv_filename = os.path.join(output_dir, f"latestwarningdistance_ttc_pgvil_data.csv")
         df_csv = pd.DataFrame({
-            'simTime': ego_df['simTime'],
+            'simTime': warning_time,
             'warning_distance': warning_dist,
             'warning_speed': warning_speed,
             'ttc': ttc,
@@ -650,17 +650,17 @@ def generate_latest_warning_distance_pgvil_chart(function_calculator, output_dir
         warning_dist = getattr(function_calculator, 'warning_dist', None)
         warning_time = getattr(function_calculator, 'warning_time', None)
 
-        if warning_dist.empty:
+        if len(warning_dist) == 0:
             logger.warning(f"Cannot generate latestWarningDistance_PGVIL chart: empty data")
             return None
 
         # Calculate metric value
-        metric_value = float(warning_dist.iloc[-1]) if len(warning_dist) > 0 else max_threshold
+        metric_value = float(warning_dist[-1]) if len(warning_dist) > 0 else max_threshold
 
         # Save CSV data
         csv_filename = os.path.join(output_dir, f"latestWarningDistance_PGVIL_data.csv")
         df_csv = pd.DataFrame({
-            'simTime': ego_df['simTime'],
+            'simTime': warning_time,
             'warning_distance': warning_dist,
             'min_threshold': min_threshold,
             'max_threshold': max_threshold
@@ -837,16 +837,17 @@ def generate_earliest_warning_distance_ttc_pgvil_chart(function_calculator, outp
         min_threshold = thresholds["min"]
 
         # Get calculated warning distance and speed
-        warning_dist = getattr(function_calculator, 'correctwarning', None)
+        warning_dist = getattr(function_calculator, 'warning_dist', None)
         warning_speed = getattr(function_calculator, 'warning_speed', None)
         ttc = getattr(function_calculator, 'ttc', None)
+        warning_time = getattr(function_calculator, 'warning_time', None)
 
         # Calculate metric value
         metric_value = float(ttc[0]) if len(ttc) > 0 else max_threshold
         # Save CSV data
         csv_filename = os.path.join(output_dir, f"{metric_name.lower()}_data.csv")
         df_csv = pd.DataFrame({
-            'simTime': ego_df[(ego_df['ifwarning'] == correctwarning) & (ego_df['ifwarning'].notna())]['simTime'],
+            'simTime': warning_time,
             'warning_distance': warning_dist,
             'warning_speed': warning_speed,
             'ttc': ttc,

+ 8 - 3
modules/metric/function.py

@@ -604,10 +604,11 @@ def latestWarningDistance_PGVIL(data_processed) -> dict:
 
     # 将计算结果保存到data对象中,供图表生成使用
     data_processed.warning_dist = distances
+    data_processed.warning_time = ego['simTime'].tolist()
     if distances.size == 0:
         print("没有找到数据!")
         return {"latestWarningDistance_PGVIL": 15}  # 或返回其他默认值,如0.0
-
+    generate_function_chart_data(data_processed, 'latestWarningDistance_PGVIL')
     return {"latestWarningDistance_PGVIL": float(np.min(distances))}
 
 
@@ -632,6 +633,7 @@ def latestWarningDistance_TTC_PGVIL(data_processed) -> dict:
 
     data_processed.warning_dist = distances
     data_processed.warning_speed = rel_speeds
+    data_processed.warning_time = ego['simTime'].tolist()
 
     with np.errstate(divide="ignore", invalid="ignore"):
         ttc = np.where(rel_speeds != 0, distances / rel_speeds, np.inf)
@@ -639,7 +641,7 @@ def latestWarningDistance_TTC_PGVIL(data_processed) -> dict:
         print("没有找到数据!")
         return {"latestWarningDistance_TTC_PGVIL": 2}  # 或返回其他默认值,如0.0
     data_processed.ttc = ttc
-
+    generate_function_chart_data(data_processed, 'latestWarningDistance_TTC_PGVIL')
     return {"latestWarningDistance_TTC_PGVIL": float(np.nanmin(ttc))}
 
 
@@ -658,10 +660,11 @@ def earliestWarningDistance_PGVIL(data_processed) -> dict:
     )
     # 将计算结果保存到data对象中,供图表生成使用
     data_processed.warning_dist = distances
+    data_processed.warning_time = ego['simTime'].tolist()
     if distances.size == 0:
         print("没有找到数据!")
         return {"earliestWarningDistance_PGVIL": 15}  # 或返回其他默认值,如0.0
-
+    generate_function_chart_data(data_processed, 'earliestWarningDistance_PGVIL')
     return {"earliestWarningDistance": float(np.min(distances))}
 
 
@@ -687,6 +690,7 @@ def earliestWarningDistance_TTC_PGVIL(data_processed) -> dict:
 
     data_processed.warning_dist = distances
     data_processed.warning_speed = rel_speeds
+    data_processed.warning_time = ego['simTime'].tolist()
 
     with np.errstate(divide="ignore", invalid="ignore"):
         ttc = np.where(rel_speeds != 0, distances / rel_speeds, np.inf)
@@ -694,6 +698,7 @@ def earliestWarningDistance_TTC_PGVIL(data_processed) -> dict:
         print("没有找到数据!")
         return {"earliestWarningDistance_TTC_PGVIL": 2}  # 或返回其他默认值,如0.0
     data_processed.ttc = ttc
+    generate_function_chart_data(data_processed, 'earliestWarningDistance_TTC_PGVIL')
     return {"earliestWarningDistance_TTC_PGVIL": float(np.nanmin(ttc))}