import argparse
import csv
import json
import os
import shutil
import sqlite3
import subprocess
import sys
import tempfile
import traceback
import zipfile
from collections import Counter
from contextlib import closing
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Any, NamedTuple

import cantools
import numpy as np
import pandas as pd
from bagpy.bagreader import bagreader
from pyproj import Proj

# --- Constants ---
PLAYER_ID_EGO = 1
PLAYER_ID_OBJ = 2
DEFAULT_TYPE = 1
OUTPUT_CSV_OBJSTATE = "ObjState.csv"
OUTPUT_CSV_TEMP_OBJSTATE = "ObjState_temp_intermediate.csv"  # Should be eliminated
OUTPUT_CSV_EGOSTATE = "EgoState.csv"  # Not used in final merge? Check logic if needed.
OUTPUT_CSV_MERGED = "merged_ObjState.csv"
OUTPUT_CSV_OBU = "OBUdata.csv"
OUTPUT_CSV_LANEMAP = "LaneMap.csv"
OUTPUT_CSV_EGOMAP = "EgoMap.csv"
OUTPUT_CSV_FUNCTION = "Function.csv"
ROADMARK_CSV = "RoadMark.csv"


# --- Configuration Class ---
@dataclass
class Config:
    """Holds configuration paths and settings.

    ``output_dir`` is derived from ``output_path`` and is created on
    construction (including missing parents).
    """

    zip_path: Path
    output_path: Path
    json_path: Optional[Path]  # Optional value, but has no default: must be passed
    dbc_path: Optional[Path] = None
    engine_path: Optional[Path] = None
    map_path: Optional[Path] = None
    utm_zone: int = 51  # Example UTM zone
    x_offset: float = 0.0
    y_offset: float = 0.0
    # Derived paths
    output_dir: Path = field(init=False)

    def __post_init__(self):
        # Use output_path directly as output_dir to avoid nested directories.
        self.output_dir = self.output_path
        self.output_dir.mkdir(parents=True, exist_ok=True)


# --- Zip/CSV Processing ---
class ZipCSVProcessor:
    """Processes DB files within a ZIP archive to generate CSV data."""

    # Output column order shared by the GNSS and CAN exports.
    EGO_COLS_NEW = [
        "simTime", "simFrame", "playerId", "v", "speedX", "speedY",
        "posH", "speedH", "posX", "posY", "accelX", "accelY",
        "travelDist", "composite_v", "relative_dist", "type",
    ]
    OBJ_COLS_OLD_SUFFIXED = [
        "v_obj", "speedX_obj", "speedY_obj", "posH_obj", "speedH_obj",
        "posX_obj", "posY_obj", "accelX_obj", "accelY_obj", "travelDist_obj",
    ]
    # Map the "_obj"-suffixed columns onto the standard (unsuffixed) names.
    OBJ_COLS_MAPPING = dict(zip(OBJ_COLS_OLD_SUFFIXED, EGO_COLS_NEW[3:13]))

    # Keep every N-th row when exporting.
    DOWNSAMPLE_STEP = 4
    # Seconds added to the first CAN timestamp when regenerating simTime.
    # NOTE(review): 28800 s == 8 h, presumably a UTC+8 shift -- confirm.
    TIME_OFFSET_S = 28800
    # Spacing used to regenerate the CAN time axis.
    SAMPLE_PERIOD_S = 0.01

    def __init__(self, config: Config):
        self.config = config
        self.dbc = self._load_dbc(config.dbc_path)
        self.projection = Proj(proj='utm', zone=config.utm_zone, ellps='WGS84', preserve_units='m')
        self._init_table_config()
        self._init_keyword_mapping()

    def _load_dbc(self, dbc_path: Optional[Path]) -> Optional[cantools.db.Database]:
        """Load the DBC database, returning None when it is missing or invalid."""
        if not dbc_path or not dbc_path.exists():
            print("DBC path not provided or file not found.")
            return None
        try:
            return cantools.db.load_file(dbc_path)
        except Exception as e:
            print(f"DBC loading failed: {e}")
            return None

    def _init_table_config(self):
        """Initializes configurations for different table types."""
        self.table_config = {
            "gnss_table": self._get_gnss_config(),
            "can_table": self._get_can_config(),
        }

    def _get_gnss_config(self):
        """Return output columns, column mapping and SELECT list for gnss_table."""
        return {
            "output_columns": self.EGO_COLS_NEW,  # Standard ego columns + type
            "mapping": {  # Output column -> source DB column
                "simTime": ("second", "usecond"),
                "simFrame": "ID",
                "v": "speed",
                "speedY": "y_speed",
                "speedX": "x_speed",
                "posH": "yaw",
                "speedH": "yaw_rate",
                "posX": "latitude_dd",  # Source before projection
                "posY": "longitude_dd",  # Source before projection
                "accelX": "x_acceleration",
                "accelY": "y_acceleration",
                "travelDist": "total_distance",
                "composite_v": "speed",  # Placeholder, adjust if needed
                "relative_dist": None,  # Placeholder, likely not in GNSS data
                "type": None,  # Will be set later
            },
            # Actual columns to SELECT
            "db_columns": [
                "ID", "second", "usecond", "speed", "y_speed", "x_speed",
                "yaw", "yaw_rate", "latitude_dd", "longitude_dd",
                "x_acceleration", "y_acceleration", "total_distance",
            ],
        }

    def _get_can_config(self):
        """Return the signal mapping and SELECT list for can_table (EGO and OBJ)."""
        return {
            "mapping": {  # Unified output column -> CAN signal name
                # EGO mappings (VUT = Vehicle Under Test)
                "v": "VUT_Speed_mps",
                "speedX": "VUT_Speed_x_mps",
                "speedY": "VUT_Speed_y_mps",
                "speedH": "VUT_Yaw_Rate",
                "posX": "VUT_GPS_Latitude",  # Source before projection
                "posY": "VUT_GPS_Longitude",  # Source before projection
                "posH": "VUT_Heading",
                "accelX": "VUT_Acc_X",
                "accelY": "VUT_Acc_Y",
                # OBJ mappings (UFO = Unidentified Flying Object / Other Vehicle)
                "v_obj": "Speed_mps",
                "speedX_obj": "UFO_Speed_x_mps",
                "speedY_obj": "UFO_Speed_y_mps",
                "speedH_obj": "Yaw_Rate",
                "posX_obj": "GPS_Latitude",  # Source before projection
                "posY_obj": "GPS_Longitude",  # Source before projection
                "posH_obj": "Heading",
                "accelX_obj": "Acc_X",
                "accelY_obj": "Acc_Y",
                # Relative mappings
                "composite_v": "VUT_Rel_speed_long_mps",
                "relative_dist": "VUT_Dist_MRP_Abs",
                # travelDist is often calculated, not a direct CAN signal
                "travelDist": None,  # Placeholder
                "travelDist_obj": None,  # Placeholder
            },
            # Core DB columns
            "db_columns": ["ID", "second", "usecond", "timestamp", "canid", "len", "frame"],
        }

    def _init_keyword_mapping(self):
        """Maps keywords in filenames to table configurations and output CSV names."""
        # NOTE(review): both keywords write to the same CSV, so whichever DB
        # file is processed last overwrites the other's output.
        self.keyword_mapping = {
            "gnss": ("gnss_table", OUTPUT_CSV_OBJSTATE),
            "can2": ("can_table", OUTPUT_CSV_OBJSTATE),
        }

    def process_zip(self) -> None:
        """Extracts and processes DB files from the configured ZIP path.

        Only ``*.db`` members under ``CANdata/`` whose name contains a known
        keyword are handled. Each one is extracted to a temporary directory,
        exported, and removed again. Errors are reported via ``print`` and
        never propagate to the caller.
        """
        print(f"Processing ZIP: {self.config.zip_path}")
        output_dir = self.config.output_dir  # Already created in Config
        try:
            with zipfile.ZipFile(self.config.zip_path, "r") as zip_ref:
                db_files_to_process = []
                for file_info in zip_ref.infolist():
                    if 'CANdata/' in file_info.filename and file_info.filename.endswith('.db'):
                        match = self._match_keyword(file_info.filename)
                        if match:
                            table_type, csv_name = match
                            db_files_to_process.append((file_info, table_type, csv_name))

                if not db_files_to_process:
                    print("No relevant DB files found in CANdata/ matching keywords.")
                    return

                with tempfile.TemporaryDirectory() as tmp_dir_str:
                    tmp_dir = Path(tmp_dir_str)
                    for file_info, table_type, csv_name in db_files_to_process:
                        print(f"Processing DB: {file_info.filename} for table type {table_type}")
                        extracted_path = tmp_dir / Path(file_info.filename).name
                        try:
                            with zip_ref.open(file_info.filename) as source, \
                                    open(extracted_path, "wb") as target:
                                shutil.copyfileobj(source, target)
                            self._process_db_file(extracted_path, output_dir, table_type, csv_name)
                        except (sqlite3.Error, pd.errors.EmptyDataError, FileNotFoundError, KeyError) as e:
                            print(f"Error processing DB file {file_info.filename}: {e}")
                        except Exception as e:
                            print(f"Unexpected error processing DB file {file_info.filename}: {e}")
                        finally:
                            if extracted_path.exists():
                                extracted_path.unlink()  # Clean up extracted file
        except zipfile.BadZipFile:
            print(f"Error: Bad ZIP file: {self.config.zip_path}")
        except FileNotFoundError:
            print(f"Error: ZIP file not found: {self.config.zip_path}")
        except Exception as e:
            print(f"An error occurred during ZIP processing: {e}")

    def _match_keyword(self, filename: str) -> Optional[Tuple[str, str]]:
        """Finds the first matching keyword configuration for a filename."""
        for keyword, (table_type, csv_name) in self.keyword_mapping.items():
            if keyword in filename:
                return table_type, csv_name
        return None

    def _process_db_file(
            self, db_path: Path, output_dir: Path, table_type: str, csv_name: str
    ) -> None:
        """Connects to SQLite DB and processes the specified table type.

        The database is opened read-only and is explicitly closed before
        returning, so the caller can delete the file afterwards.
        """
        output_csv_path = output_dir / csv_name
        try:
            # as_uri() percent-encodes the path so characters such as '?' or
            # '#' cannot be mistaken for URI syntax.
            conn_str = f"{db_path.resolve().as_uri()}?mode=ro"
            # sqlite3's own context manager only commits/rolls back; closing()
            # is what actually releases the file handle.
            with closing(sqlite3.connect(conn_str, uri=True)) as conn:
                cursor = conn.cursor()
                if not self._check_table_exists(cursor, table_type):
                    print(f"Table '{table_type}' does not exist in {db_path.name}. Skipping.")
                    return
                if self._check_table_empty(cursor, table_type):
                    print(f"Table '{table_type}' in {db_path.name} is empty. Skipping.")
                    return

                print(f"Exporting data from table '{table_type}' to {output_csv_path}")
                if table_type == "can_table":
                    self._process_can_table_optimized(cursor, output_csv_path)
                elif table_type == "gnss_table":
                    self._process_gnss_table(cursor, output_csv_path)
                else:
                    print(f"Warning: No specific processor for table type '{table_type}'. Skipping.")
        except sqlite3.OperationalError as e:
            print(f"Database operational error for {db_path.name}: {e}. Check file integrity/permissions.")
        except sqlite3.DatabaseError as e:
            print(f"Database error connecting to {db_path.name}: {e}")
        except Exception as e:
            print(f"Unexpected error processing DB {db_path.name}: {e}")

    def _check_table_exists(self, cursor, table_name: str) -> bool:
        """Checks if a table exists in the database."""
        try:
            cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name=?;",
                (table_name,),
            )
            return cursor.fetchone() is not None
        except sqlite3.Error as e:
            print(f"Error checking existence of table {table_name}: {e}")
            return False  # Assume not exists on error

    def _check_table_empty(self, cursor, table_name: str) -> bool:
        """Checks if a table is empty (errors are reported and count as empty)."""
        try:
            # Fetching a single row avoids scanning the table just to learn
            # whether it has any rows at all.
            cursor.execute(f'SELECT 1 FROM "{table_name}" LIMIT 1')
            return cursor.fetchone() is None
        except sqlite3.Error as e:
            print(f"Error checking if table {table_name} is empty: {e}")
            return True

    @staticmethod
    def _gnss_value(row_dict: Dict[str, Any], source: Any) -> Any:
        """Return one GNSS output value: rounded float, raw value, or 0.0 default."""
        if source is None or isinstance(source, tuple):
            return 0.0
        raw_value = row_dict.get(source)
        if raw_value is None:
            return 0.0  # Default for missing source data
        try:
            return round(float(raw_value), 3)
        except (ValueError, TypeError):
            return raw_value  # Keep as is if not numeric

    def _process_gnss_table(self, cursor, output_path: Path) -> None:
        """Processes gnss_table data and writes directly to CSV.

        Every row becomes an ego record (``playerId == PLAYER_ID_EGO``).
        Latitude/longitude are projected to UTM and stored in posX/posY, or
        0.0 when either is missing. Rows are then downsampled by
        ``DOWNSAMPLE_STEP`` and ``simFrame`` is renumbered from 1.
        """
        config = self.table_config["gnss_table"]
        db_columns = config["db_columns"]
        output_columns = config["output_columns"]
        mapping = config["mapping"]
        lat_key, lon_key = mapping["posX"], mapping["posY"]
        try:
            cursor.execute(f"SELECT {', '.join(db_columns)} FROM gnss_table")
            rows = cursor.fetchall()
            if not rows:
                print("No data found in gnss_table.")
                return

            processed_data = []
            for row in rows:
                row_dict = dict(zip(db_columns, row))

                # Project once per row; both position columns share the result.
                lat = row_dict.get(lat_key)
                lon = row_dict.get(lon_key)
                if lat is not None and lon is not None:
                    proj_x, proj_y = self.projection(lon, lat)
                    position = {"posX": round(proj_x, 6), "posY": round(proj_y, 6)}
                else:
                    position = {"posX": 0.0, "posY": 0.0}

                record = {
                    "simTime": round(
                        row_dict.get("second", 0) + row_dict.get("usecond", 0) / 1e6, 2
                    )
                }
                for out_col in output_columns:
                    if out_col == "simTime":
                        continue  # Already handled
                    if out_col == "playerId":
                        record[out_col] = PLAYER_ID_EGO  # Assuming GNSS is ego
                    elif out_col == "type":
                        record[out_col] = DEFAULT_TYPE
                    elif out_col in position:
                        record[out_col] = position[out_col]
                    else:
                        record[out_col] = self._gnss_value(row_dict, mapping.get(out_col))
                processed_data.append(record)

            df_final = (
                pd.DataFrame(processed_data)[output_columns]  # Ensure column order
                .iloc[::self.DOWNSAMPLE_STEP]
                .reset_index(drop=True)
            )
            df_final['simFrame'] = np.arange(1, len(df_final) + 1)
            df_final.to_csv(output_path, index=False, encoding="utf-8")
            print(f"Successfully wrote GNSS data to {output_path}")
        except sqlite3.Error as e:
            print(f"SQL error during GNSS processing: {e}")
        except Exception as e:
            print(f"Unexpected error during GNSS processing: {e}")

    def _process_can_table_optimized(self, cursor, output_path: Path) -> None:
        """Processes CAN data directly into the final merged DataFrame format.

        Frames are decoded with the DBC (when loaded), split into EGO and OBJ
        tables, projected to UTM, downsampled by ``DOWNSAMPLE_STEP``, merged
        and written to ``output_path``. NaNs from missing signals are left in
        the output unchanged.
        """
        config = self.table_config["can_table"]
        db_columns = config["db_columns"]
        mapping = config["mapping"]
        try:
            cursor.execute(f"SELECT {', '.join(db_columns)} FROM can_table")
            rows = cursor.fetchall()
            if not rows:
                print("No data found in can_table.")
                return

            all_records = []
            for row in rows:
                row_dict = dict(zip(db_columns, row))
                decoded_signals = self._decode_can_frame(row_dict)
                record = self._create_unified_can_record(row_dict, decoded_signals, mapping)
                if record:  # Only add if parsing was successful
                    all_records.append(record)

            if not all_records:
                print("No CAN records could be successfully processed.")
                return

            df_raw = pd.DataFrame(all_records)

            # Separate EGO and OBJ data based on available columns.
            df_ego = self._extract_vehicle_data(df_raw, PLAYER_ID_EGO)
            df_obj = self._extract_vehicle_data(df_raw, PLAYER_ID_OBJ)

            # Both frames use the standard column names after extraction.
            df_ego = self._project_coordinates(df_ego, 'posX', 'posY')
            df_obj = self._project_coordinates(df_obj, 'posX', 'posY')

            df_ego['type'] = DEFAULT_TYPE
            df_obj['type'] = DEFAULT_TYPE

            # Same columns on both sides before merging. copy() detaches the
            # downsampled slice so the simFrame assignment below is safe.
            final_columns = self.EGO_COLS_NEW
            df_ego = df_ego.reindex(columns=final_columns).iloc[::self.DOWNSAMPLE_STEP].copy()
            df_obj = df_obj.reindex(columns=final_columns).iloc[::self.DOWNSAMPLE_STEP].copy()

            df_ego['simFrame'] = np.arange(1, len(df_ego) + 1)
            df_obj['simFrame'] = np.arange(1, len(df_obj) + 1)

            df_merged = pd.concat([df_ego, df_obj], ignore_index=True)
            df_merged.sort_values(by=["simTime", "simFrame", "playerId"], inplace=True)
            df_merged.reset_index(drop=True, inplace=True)

            df_merged.to_csv(output_path, index=False, encoding="utf-8")
            print(f"Successfully processed CAN data and wrote merged output to {output_path}")
        except sqlite3.Error as e:
            print(f"SQL error during CAN processing: {e}")
        except KeyError as e:
            print(f"Key error during CAN processing - mapping issue? Missing key: {e}")
        except Exception as e:
            print(f"Unexpected error during CAN processing: {e}")
            traceback.print_exc()  # Print detailed traceback for debugging

    def _decode_can_frame(self, row_dict: Dict) -> Dict[str, Any]:
        """Decodes CAN frame using DBC file if available.

        Returns an empty dict when no DBC is loaded, when the row lacks a CAN
        id or payload, when the id is unknown to the DBC, or when decoding
        fails.
        """
        decoded_signals = {}
        if self.dbc is None or not all(key in row_dict for key in ('canid', 'frame', 'len')):
            return decoded_signals

        can_id = row_dict['canid']
        frame = row_dict['frame']
        if can_id is None or frame is None:
            # A NULL id/payload cannot be decoded; skip it instead of raising
            # and aborting the whole table.
            return decoded_signals

        frame_bytes = bytes(frame[:row_dict['len']])  # Ensure correct length
        try:
            message_def = self.dbc.get_message_by_frame_id(can_id)
            decoded_signals = message_def.decode(
                frame_bytes, decode_choices=False, allow_truncated=True
            )
        except KeyError:
            pass  # Ignore unknown IDs silently
        except ValueError as e:
            print(
                f"Warning: Decoding ValueError for CAN ID 0x{can_id:X} (length {row_dict['len']}, data: {frame_bytes.hex()}): {e}")
        except Exception as e:
            print(f"Warning: Error decoding CAN ID 0x{can_id:X}: {e}")
        return decoded_signals

    def _create_unified_can_record(self, row_dict: Dict, decoded_signals: Dict, mapping: Dict) -> Optional[
        Dict[str, Any]]:
        """Creates a single record combining DB fields and decoded signals based on mapping.

        Only signals present in ``decoded_signals`` are copied; other target
        columns are left out of the record. Returns None when the row cannot
        be converted (for example a NULL timestamp or CAN id).
        """
        record = {}
        try:
            record["simTime"] = round(row_dict.get("second", 0) + row_dict.get("usecond", 0) / 1e6, 2)
            record["simFrame"] = row_dict.get("ID")
            record["canid"] = f"0x{row_dict.get('canid'):X}"  # Store CAN ID if needed

            for target_col, signal_name in mapping.items():
                if target_col in ("simTime", "simFrame", "canid"):
                    continue  # Already handled
                if isinstance(signal_name, tuple) or not signal_name:
                    continue  # No direct signal for this column
                if signal_name not in decoded_signals:
                    continue  # Left empty; handled during DataFrame processing

                raw_value = decoded_signals[signal_name]
                try:
                    if not isinstance(raw_value, (int, float)):
                        record[target_col] = raw_value  # Keep non-numeric as is (e.g., enums)
                    elif "Latitude" in signal_name or "Longitude" in signal_name:
                        # The signal name (not the target column, which is
                        # posX/posY) identifies geographic coordinates; keep
                        # their full precision for the later projection.
                        record[target_col] = float(raw_value)
                    else:
                        record[target_col] = round(float(raw_value), 6)
                except (ValueError, TypeError):
                    record[target_col] = raw_value  # Assign raw value if conversion fails
            return record
        except Exception as e:
            print(f"Error creating unified record for row {row_dict.get('ID')}: {e}")
            return None

    @staticmethod
    def _compact_columns(df_raw: pd.DataFrame, columns: List[str],
                         rename_map: Dict[str, str]) -> pd.DataFrame:
        """Copy the available ``columns`` with their NaN rows squeezed out.

        Each column is compacted independently. The first column copied fixes
        the row index of the result; later columns are aligned to it
        (truncated or NaN-padded). Columns are taken in the given order so the
        result does not vary between runs.
        """
        compact = pd.DataFrame()
        for col in dict.fromkeys(columns):  # De-duplicate, keep order
            if col in df_raw.columns:
                compact[rename_map.get(col, col)] = df_raw[col].dropna().reset_index(drop=True)
        return compact

    def _extract_vehicle_data(self, df_raw: pd.DataFrame, player_id: int) -> pd.DataFrame:
        """Extracts and renames columns for a specific vehicle (EGO or OBJ).

        EGO takes the mapped columns that do not end in ``_obj``. OBJ takes
        the ``_obj`` columns renamed to the standard names, plus the relative
        speed/distance columns. ``simTime`` is regenerated as an evenly spaced
        axis starting at the first raw timestamp plus ``TIME_OFFSET_S``;
        ``simFrame`` and ``playerId`` are then filled in.
        """
        mapping = self.table_config['can_table']['mapping']
        relative_cols = ["composite_v", "relative_dist"]
        df_vehicle = pd.DataFrame()

        if player_id == PLAYER_ID_EGO:
            ego_cols = [
                target for target, source in mapping.items()
                if source and not isinstance(source, tuple) and not target.endswith('_obj')
            ]
            df_vehicle = self._compact_columns(df_raw, ego_cols + relative_cols, {})
        elif player_id == PLAYER_ID_OBJ:
            obj_cols = [
                target for target, source in mapping.items()
                if source and not isinstance(source, tuple) and target.endswith('_obj')
            ]
            df_vehicle = self._compact_columns(df_raw, obj_cols, self.OBJ_COLS_MAPPING)
            # Copy relative speed/distance (assuming it's relative *to* ego).
            for col in relative_cols:
                if col in df_raw.columns:
                    df_vehicle[col] = df_raw[col].dropna().reset_index(drop=True)

        try:
            n_rows = len(df_vehicle)
            start = df_raw["simTime"].tolist()[0] + self.TIME_OFFSET_S
            # NOTE(review): linspace includes its end point, so the spacing is
            # SAMPLE_PERIOD_S * n / (n - 1) rather than exactly
            # SAMPLE_PERIOD_S. Kept as-is because downstream files depend on
            # these values -- confirm whether endpoint=False was intended.
            df_vehicle["simTime"] = np.round(
                np.linspace(start, start + self.SAMPLE_PERIOD_S * n_rows, n_rows), 2
            )
            df_vehicle["simFrame"] = np.arange(1, n_rows + 1)
            df_vehicle["playerId"] = int(player_id)
            df_vehicle['playerId'] = pd.to_numeric(df_vehicle['playerId']).astype(int)
        except Exception as e:
            # Best effort: report and return whatever was assembled so far.
            print(f"{e}")
        return df_vehicle

    def _project_coordinates(self, df: pd.DataFrame, lat_col: str, lon_col: str) -> pd.DataFrame:
        """Applies UTM projection to latitude and longitude columns.

        ``df`` is modified in place and also returned. The latitude column is
        overwritten with the projected X and the longitude column with the
        projected Y; rows with non-numeric or missing coordinates become NaN.
        Both columns end up named ``posX``/``posY``, and are created as NaN
        when the inputs are missing.
        """
        if lat_col in df.columns and lon_col in df.columns:
            lat = pd.to_numeric(df[lat_col], errors='coerce')
            lon = pd.to_numeric(df[lon_col], errors='coerce')
            valid_coords = lat.notna() & lon.notna()
            if valid_coords.any():
                x, y = self.projection(lon[valid_coords].values, lat[valid_coords].values)
                df.loc[valid_coords, lat_col] = np.round(x, 6)  # Overwrite latitude col with X
                df.loc[valid_coords, lon_col] = np.round(y, 6)  # Overwrite longitude col with Y
                df.loc[~valid_coords, [lat_col, lon_col]] = np.nan
            else:
                df[lat_col] = np.nan
                df[lon_col] = np.nan
            df.rename(columns={lat_col: 'posX', lon_col: 'posY'}, inplace=True)
        else:
            # Ensure columns exist even if projection didn't happen.
            if 'posX' not in df.columns:
                df['posX'] = np.nan
            if 'posY' not in df.columns:
                df['posY'] = np.nan
            print(f"Warning: Latitude ('{lat_col}') or Longitude ('{lon_col}') columns not found for projection.")
        return df


# --- Polynomial
# Fitting (Largely unchanged, minor cleanup) ---
class PolynomialCurvatureFitting:
    """Calculates curvature and its derivative using polynomial fitting.

    The lane centre line is read from a CSV with ``centerLine_x`` and
    ``centerLine_y`` columns. When the file is missing, empty or unreadable,
    the instance holds no points and ``fit_and_project`` returns NaN results.
    """

    def __init__(self, lane_map_path: Path, degree: int = 3):
        self.lane_map_path = lane_map_path
        self.degree = degree
        self.data = self._load_data()
        if self.data is not None:
            self.points = self.data[["centerLine_x", "centerLine_y"]].values
            self.x_data, self.y_data = self.points[:, 0], self.points[:, 1]
        else:
            self.points = np.empty((0, 2))
            self.x_data, self.y_data = np.array([]), np.array([])

    def _load_data(self) -> Optional[pd.DataFrame]:
        """Loads lane map data safely, returning None on any failure."""
        if not self.lane_map_path.exists() or self.lane_map_path.stat().st_size == 0:
            print(f"Warning: LaneMap file not found or empty: {self.lane_map_path}")
            return None
        try:
            return pd.read_csv(self.lane_map_path)
        except pd.errors.EmptyDataError:
            print(f"Warning: LaneMap file is empty: {self.lane_map_path}")
            return None
        except Exception as e:
            print(f"Error reading LaneMap file {self.lane_map_path}: {e}")
            return None

    @staticmethod
    def _empty_result() -> Dict[str, Any]:
        """Return a fresh all-NaN result record."""
        return {
            "projection": (np.nan, np.nan),
            "curvHor": np.nan,
            "curvHorDot": np.nan,
            "laneOffset": np.nan,
        }

    def curvature(self, coefficients: np.ndarray, x: float) -> float:
        """Computes the (unsigned) curvature of the polynomial at x.

        Returns 0.0 for polynomials of degree below 2.
        """
        if len(coefficients) < 3:  # Need at least degree 2 for curvature
            return 0.0
        first_deriv_coeffs = np.polyder(coefficients)
        second_deriv_coeffs = np.polyder(first_deriv_coeffs)
        dy_dx = np.polyval(first_deriv_coeffs, x)
        d2y_dx2 = np.polyval(second_deriv_coeffs, x)
        # 1 + dy_dx**2 >= 1, so the denominator can never be zero.
        return np.abs(d2y_dx2) / (1 + dy_dx ** 2) ** 1.5

    def curvature_derivative(self, coefficients: np.ndarray, x: float) -> float:
        """Computes the derivative of (signed) curvature with respect to x.

        With k = y'' / (1 + y'^2)^(3/2), differentiating gives
        dk/dx = (y''' (1 + y'^2) - 3 y' y''^2) / (1 + y'^2)^(5/2).
        Returns 0.0 for polynomials of degree below 3.
        """
        if len(coefficients) < 4:  # Need at least degree 3 for derivative of curvature
            return 0.0
        first_deriv_coeffs = np.polyder(coefficients)
        second_deriv_coeffs = np.polyder(first_deriv_coeffs)
        third_deriv_coeffs = np.polyder(second_deriv_coeffs)
        dy_dx = np.polyval(first_deriv_coeffs, x)
        d2y_dx2 = np.polyval(second_deriv_coeffs, x)
        d3y_dx3 = np.polyval(third_deriv_coeffs, x)
        slope_term = 1 + dy_dx ** 2  # Always >= 1, so no zero-division guard needed
        return (d3y_dx3 * slope_term - 3 * dy_dx * d2y_dx2 ** 2) / slope_term ** 2.5

    def polynomial_fit(
            self, x_window: np.ndarray, y_window: np.ndarray
    ) -> Tuple[Optional[np.ndarray], Optional[np.poly1d]]:
        """Performs polynomial fitting, handling potential rank warnings.

        Returns ``(coefficients, poly1d)``. Returns ``(None, None)`` when the
        window has too few points, when the fit is rank deficient, or when
        fitting raises.
        """
        import warnings  # Local import: not imported at file level

        if len(x_window) <= self.degree:
            print(f"Warning: Window size {len(x_window)} is <= degree {self.degree}. Cannot fit.")
            return None, None

        # RankWarning lives in numpy.exceptions from NumPy 2.0 on and directly
        # on the numpy namespace before that.
        rank_warning = getattr(getattr(np, "exceptions", None), "RankWarning", None)
        if rank_warning is None:
            rank_warning = np.RankWarning

        try:
            # polyfit only *warns* on a rank-deficient system; promote that
            # warning to an exception so the handler below can reject the fit.
            with warnings.catch_warnings():
                warnings.simplefilter("error", category=rank_warning)
                coefficients = np.polyfit(x_window, y_window, self.degree)
            return coefficients, np.poly1d(coefficients)
        except rank_warning:
            print("Warning: Rank deficient fitting for window. Check data variability.")
            return None, None
        except Exception as e:
            print(f"Error during polynomial fit: {e}")
            return None, None

    def find_best_window(self, point: Tuple[float, float], window_size: int) -> Optional[int]:
        """Finds the start index of the window whose center is closest to the point.

        The centre of a window is the mean of its x and y samples. Returns
        None when there are fewer points than ``window_size`` or no window
        yields a finite distance.
        """
        if len(self.x_data) < window_size:
            print("Warning: Not enough data points for the specified window size.")
            return None

        x_point, y_point = point
        min_dist_sq = np.inf
        best_start_index = None
        for start in range(len(self.x_data) - window_size + 1):
            stop = start + window_size
            x_center = np.mean(self.x_data[start:stop])
            y_center = np.mean(self.y_data[start:stop])
            dist_sq = (x_point - x_center) ** 2 + (y_point - y_center) ** 2
            if dist_sq < min_dist_sq:  # Strict '<' keeps the first of equal minima
                min_dist_sq = dist_sq
                best_start_index = start
        return best_start_index

    def find_projection(
            self,
            x_target: float,
            y_target: float,
            polynomial: np.poly1d,
            x_range: Tuple[float, float],
            search_points: int = 100,  # Number of points instead of step size
    ) -> Optional[Tuple[float, float, float]]:
        """Finds the approximate closest point on the polynomial within the x_range.

        The polynomial is sampled at ``search_points`` evenly spaced x values.
        Returns ``(x, y, distance)`` of the nearest sample, or None for an
        empty or inverted range.
        """
        if x_range[1] <= x_range[0]:
            return None  # Invalid range
        x_values = np.linspace(x_range[0], x_range[1], search_points)
        y_values = polynomial(x_values)
        distances_sq = (x_target - x_values) ** 2 + (y_target - y_values) ** 2
        if len(distances_sq) == 0:
            return None
        min_idx = np.argmin(distances_sq)
        return x_values[min_idx], y_values[min_idx], np.sqrt(distances_sq[min_idx])

    def fit_and_project(
            self, points: np.ndarray, window_size: int
    ) -> List[Dict[str, Any]]:
        """Fits polynomial and calculates curvature for each point in the input array.

        Returns one dict per input point with keys ``projection``,
        ``curvHor``, ``curvHorDot`` and ``laneOffset``. Any point that cannot
        be processed gets NaN values.

        Raises:
            ValueError: if lane data is available and ``points`` is not
                shaped ``(n, 2)``.
        """
        if self.data is None or len(self.x_data) < window_size:
            print("Insufficient LaneMap data for fitting.")
            # One distinct dict per point, so a caller mutating one entry
            # does not change the others.
            return [self._empty_result() for _ in range(len(points))]

        if points.ndim != 2 or points.shape[1] != 2:
            raise ValueError("Input points must be a 2D numpy array with shape (n, 2).")

        results = []
        for x_target, y_target in points:
            result = self._empty_result()
            results.append(result)

            best_start = self.find_best_window((x_target, y_target), window_size)
            if best_start is None:
                continue
            x_window = self.x_data[best_start: best_start + window_size]
            y_window = self.y_data[best_start: best_start + window_size]

            coefficients, polynomial = self.polynomial_fit(x_window, y_window)
            if coefficients is None or polynomial is None:
                continue

            projection_result = self.find_projection(
                x_target, y_target, polynomial, (np.min(x_window), np.max(x_window))
            )
            if projection_result is None:
                continue

            proj_x, proj_y, min_distance = projection_result
            result["projection"] = (round(proj_x, 6), round(proj_y, 6))
            result["curvHor"] = round(self.curvature(coefficients, proj_x), 6)
            result["curvHorDot"] = round(self.curvature_derivative(coefficients, proj_x), 6)
            result["laneOffset"] = round(min_distance, 6)
        return results


# --- Data Quality Analyzer (Optimized) ---
class DataQualityAnalyzer:
    """Analyzes data quality metrics, focusing on frame loss."""

    def __init__(self, df: Optional[pd.DataFrame] = None):
        # Always hold a DataFrame so later checks need no None handling.
        self.df = df if df is not None and not df.empty else pd.DataFrame()

    def analyze_frame_loss(self) -> Dict[str, Any]:
        """Analyzes frame loss characteristics.

        Gaps are measured between consecutive unique ``simFrame`` values: a
        difference of ``d > 1`` counts as ``d - 1`` lost frames. Frame numbers
        and counts in the metrics are builtin ``int`` so the dict can be
        passed to ``json`` directly. ``valid`` is False when there is no
        usable ``simFrame`` data.
        """
        metrics = {
            "total_frames_data": 0,
            "unique_frames_count": 0,
            "min_frame": np.nan,
            "max_frame": np.nan,
            "expected_frames": 0,
            "dropped_frames_count": 0,
            "loss_rate": np.nan,
            "max_consecutive_loss": 0,
            "max_loss_start_frame": np.nan,
            "max_loss_end_frame": np.nan,
            "loss_intervals_distribution": {},
            "valid": False,  # Indicate if analysis was possible
            "message": "",
        }

        if self.df.empty or 'simFrame' not in self.df.columns:
            metrics["message"] = "DataFrame is empty or 'simFrame' column is missing."
            return metrics

        # Drop rows with NaN simFrame and ensure integer type.
        frames_series = self.df['simFrame'].dropna().astype(int)
        metrics["total_frames_data"] = len(frames_series)
        if frames_series.empty:
            metrics["message"] = "No valid 'simFrame' data found after dropping NaN."
            return metrics

        # Builtin ints rather than NumPy scalars (which json cannot encode).
        unique_frames = sorted(int(frame) for frame in frames_series.unique())
        metrics["unique_frames_count"] = len(unique_frames)

        if metrics["unique_frames_count"] < 2:
            metrics["message"] = "Less than two unique frames; cannot analyze loss."
            metrics["valid"] = True  # Data exists, just not enough to analyze loss
            if metrics["unique_frames_count"] == 1:
                metrics["min_frame"] = unique_frames[0]
                metrics["max_frame"] = unique_frames[0]
                metrics["expected_frames"] = 1
            return metrics

        metrics["min_frame"] = unique_frames[0]
        metrics["max_frame"] = unique_frames[-1]
        metrics["expected_frames"] = metrics["max_frame"] - metrics["min_frame"] + 1

        # Differences between consecutive unique frames; a gap is diff > 1.
        frame_diffs = np.diff(unique_frames)
        lost_frames_in_gaps = frame_diffs[frame_diffs > 1] - 1
        metrics["dropped_frames_count"] = int(lost_frames_in_gaps.sum())

        if metrics["expected_frames"] > 0:
            metrics["loss_rate"] = round(metrics["dropped_frames_count"] / metrics["expected_frames"], 4)
        else:
            metrics["loss_rate"] = 0.0

        if len(lost_frames_in_gaps) > 0:
            metrics["max_consecutive_loss"] = int(lost_frames_in_gaps.max())
            # First occurrence of the largest gap.
            max_loss_idx = int(np.where(frame_diffs == metrics["max_consecutive_loss"] + 1)[0][0])
            metrics["max_loss_start_frame"] = unique_frames[max_loss_idx]
            metrics["max_loss_end_frame"] = unique_frames[max_loss_idx + 1]
            # Distribution of loss interval lengths: {lost_frames: occurrences}.
            loss_counts = Counter(lost_frames_in_gaps.tolist())
            metrics["loss_intervals_distribution"] = {int(k): int(v) for k, v in loss_counts.items()}
        else:
            metrics["max_consecutive_loss"] = 0

        metrics["valid"] = True
        metrics["message"] = "Frame loss analysis complete."
        # NOTE: this method's closing `return metrics` statement is on the
        # next source line, outside the span covered here.
def get_all_csv_files(path: Path) -> List[Path]:
    """Return every CSV file below ``path`` that is eligible for analysis.

    The search is recursive. Files named ``OUTPUT_CSV_LANEMAP`` or
    ``ROADMARK_CSV`` are left out. The result is sorted so callers see the
    same order on every run regardless of filesystem enumeration order.
    """
    excluded_files = {OUTPUT_CSV_LANEMAP, ROADMARK_CSV}
    return sorted(
        file_path
        for file_path in path.rglob("*.csv")
        if file_path.is_file() and file_path.name not in excluded_files
    )


def run_frame_loss_analysis_on_folder(path: Path) -> Dict[str, Dict[str, Any]]:
    """Run the frame-loss analysis on every relevant CSV file below ``path``.

    Args:
        path: Folder that is searched recursively for CSV files.

    Returns:
        Mapping from the file's path relative to ``path`` (POSIX style; for a
        file directly inside ``path`` this is just its name) to the metrics
        dict produced by ``DataQualityAnalyzer.analyze_frame_loss``. Files
        that could not be read map to ``{"valid": False, "message": ...}``.
        Files named ``OUTPUT_CSV_FUNCTION`` or ``OUTPUT_CSV_OBU`` are skipped
        and do not appear in the result.
    """
    analysis_results: Dict[str, Dict[str, Any]] = {}
    csv_files = get_all_csv_files(path)
    if not csv_files:
        print(f"No relevant CSV files found in {path}")
        return analysis_results

    skipped_names = {OUTPUT_CSV_FUNCTION, OUTPUT_CSV_OBU}
    for file_path in csv_files:
        if file_path.name in skipped_names:
            print(f"Skipping frame analysis for: {file_path.name}")
            continue

        # The search is recursive, so the bare file name is not unique; the
        # relative path keeps same-named files in different subfolders from
        # overwriting each other's result.
        result_key = file_path.relative_to(path).as_posix()
        print(f"Analyzing frame loss for: {result_key}")

        if file_path.stat().st_size == 0:
            print(f"File {result_key} is empty. Skipping analysis.")
            analysis_results[result_key] = {"valid": False, "message": "File is empty."}
            continue

        try:
            # Only the frame column is needed; malformed lines are reported
            # by pandas as warnings instead of aborting the read.
            df = pd.read_csv(file_path, usecols=['simFrame'], index_col=False, on_bad_lines='warn')
            metrics = DataQualityAnalyzer(df).analyze_frame_loss()
            analysis_results[result_key] = metrics
            if metrics["valid"]:
                print(f" Loss Rate: {metrics.get('loss_rate', np.nan) * 100:.2f}%, "
                      f"Dropped: {metrics.get('dropped_frames_count', 'N/A')}, "
                      f"Max Gap: {metrics.get('max_consecutive_loss', 'N/A')}")
            else:
                print(f" Analysis failed: {metrics.get('message')}")
        except pd.errors.EmptyDataError:
            print(f"File {result_key} contains no data after reading.")
            analysis_results[result_key] = {"valid": False, "message": "Empty data after read."}
        except ValueError as ve:
            # pandas raises ValueError when a column named in usecols is absent.
            print(f"ValueError processing file {result_key}: {ve}. Is 'simFrame' column present?")
            analysis_results[result_key] = {"valid": False, "message": f"ValueError: {ve}"}
        except Exception as e:
            print(f"Unexpected error processing file {result_key}: {e}")
            analysis_results[result_key] = {"valid": False, "message": f"Unexpected error: {e}"}

    return analysis_results
def data_precheck(output_dir: Path, max_allowed_loss_rate: float = 0.20) -> bool:
    """Gate the pipeline on the frame-loss rate of the CSVs in ``output_dir``.

    Args:
        output_dir: Folder whose CSV files are analyzed.
        max_allowed_loss_rate: Highest acceptable loss rate as a fraction.

    Returns:
        False when ``output_dir`` is not an existing directory, when the
        analysis itself raises, or when any analyzed file exceeds the
        threshold. True otherwise, including when no file was analyzed, when
        a file could not be analyzed, or when its loss rate is NaN.
    """
    print(f"--- Running Data Quality Precheck on: {output_dir} ---")

    if not (output_dir.exists() and output_dir.is_dir()):
        print(f"Error: Output directory does not exist: {output_dir}")
        return False

    try:
        per_file_metrics = run_frame_loss_analysis_on_folder(output_dir)
    except Exception as e:
        # Any failure of the analysis as a whole counts as a failed precheck.
        print(f"Critical error during frame loss analysis: {e}")
        return False

    if not per_file_metrics:
        # Nothing to judge is reported but deliberately not treated as failure.
        print("Warning: No files were analyzed for frame loss.")
        return True

    threshold_exceeded = False
    for file_name, metrics in per_file_metrics.items():
        if not metrics.get("valid", False):
            # A file that could not be analyzed only produces a warning.
            print(
                f" WARN: {file_name} - Frame loss analysis could not be completed ({metrics.get('message', 'Unknown reason')}).")
            continue

        loss_rate = metrics.get("loss_rate", np.nan)
        if pd.isna(loss_rate):
            print(f" {file_name}: Loss rate could not be calculated.")
            continue

        if loss_rate > max_allowed_loss_rate:
            print(
                f" FAIL: {file_name} - Frame loss rate ({loss_rate * 100:.2f}%) exceeds threshold ({max_allowed_loss_rate * 100:.1f}%).")
            threshold_exceeded = True
        else:
            print(f" PASS: {file_name} - Frame loss rate ({loss_rate * 100:.2f}%) is acceptable.")

    verdict = 'FAILED' if threshold_exceeded else 'PASSED'
    print(f"--- Data Quality Precheck {verdict} ---")
    return not threshold_exceeded
# --- Final Preprocessing Step ---
class FinalDataProcessor:
    """Merges the per-source CSV files of the output directory into one table.

    ``process`` loads ``OUTPUT_CSV_OBJSTATE``, merges the optional EgoMap,
    Traffic, Function and OBU files into it, and writes the result as
    ``OUTPUT_CSV_MERGED``. Traffic-light data is first extracted from the
    JSON file named in the config and stored as ``Traffic.csv``.
    """

    # Signal phase id -> direction letter used for matching against the
    # vehicle's direction of travel.
    _PHASE_TO_DIRECTION = {
        1: 'S',  # south, straight
        2: 'W',  # west, straight
        3: 'N',  # north, straight
        4: 'E',  # east, straight
        5: 'S',  # south, pedestrian
        6: 'W',  # west, pedestrian
        7: 'S',  # south, left turn
        8: 'W',  # west, left turn
        9: 'N',  # north, left turn
        10: 'E',  # east, left turn
        11: 'N',  # north, pedestrian
        12: 'E',  # east, pedestrian
        13: 'S',  # south, right turn
        14: 'W',  # west, right turn
        15: 'N',  # north, right turn
        16: 'E',  # east, right turn
    }

    # Traffic-light id -> direction letter of the approach it belongs to.
    _TRAFFICLIGHT_TO_DIRECTION = {
        # lights for the south-to-north direction
        48100017: 'S',
        48100038: 'S',
        48100043: 'S',
        48100030: 'S',
        # lights for the west-to-east direction
        48100021: 'W',
        48100039: 'W',
        # lights for the east-to-west direction
        48100041: 'E',
        48100019: 'E',
        # lights for the north-to-south direction
        48100033: 'N',
        48100018: 'N',
        48100022: 'N',
    }

    def __init__(self, config: "Config"):
        """Remember the configuration and the directory all files live in."""
        self.config = config
        self.output_dir = config.output_dir

    def process(self) -> bool:
        """Run the final merge and write the merged CSV.

        Returns:
            True when the merged file was written. False when the object
            state file is missing or any step raised; the error and its
            traceback are printed, not propagated.
        """
        print("--- Starting Final Data Processing ---")
        try:
            obj_state_path = self.output_dir / OUTPUT_CSV_OBJSTATE
            if not obj_state_path.exists():
                print(f"Error: Required input file not found: {obj_state_path}")
                return False

            # Extract traffic-light states first so the merge step below can
            # pick them up from Traffic.csv.
            df_traffic = self._process_trafficlight_data()
            if not df_traffic.empty:
                traffic_csv_path = self.output_dir / "Traffic.csv"
                df_traffic.to_csv(traffic_csv_path, index=False, float_format='%.6f')
                print(f"Successfully created traffic light data file: {traffic_csv_path}")

            df_object = pd.read_csv(obj_state_path, dtype={"simTime": float}, low_memory=False)
            df_merged = self._merge_optional_data(df_object)

            merged_csv_path = self.output_dir / OUTPUT_CSV_MERGED
            print(f'merged_csv_path:{merged_csv_path}')
            df_merged.to_csv(merged_csv_path, index=False, float_format='%.6f')
            print(f"Successfully created final merged file: {merged_csv_path}")

            # The object state file is intentionally left in place.
            print("--- Final Data Processing Finished ---")
            return True
        except Exception as e:
            print(f"An unexpected error occurred during final data processing: {e}")
            import traceback
            traceback.print_exc()
            return False

    def _merge_optional_data(self, df_object: pd.DataFrame) -> pd.DataFrame:
        """Merge every optional data source that is present into a copy of ``df_object``.

        Sources are applied in the order EgoMap, Traffic, Function, OBU. Each
        one is skipped when its file is missing or empty, and a failure in
        one source is printed without preventing the remaining ones from
        being merged.
        """
        df_merged = df_object.copy()

        df_merged = self._merge_egomap(df_merged)
        # Resolve _x/_y suffixes left by the EgoMap merge before going on.
        df_merged = self._clean_duplicate_columns(df_merged)
        df_merged = self._merge_traffic(df_merged)
        df_merged = self._merge_time_keyed_csv(
            df_merged, self.output_dir / OUTPUT_CSV_FUNCTION, "Function")
        df_merged = self._merge_time_keyed_csv(
            df_merged, self.output_dir / OUTPUT_CSV_OBU, "OBU", debug_prefix="合并 OBU 前 ")

        # Final pass once every merge is done.
        return self._clean_duplicate_columns(df_merged)

    @staticmethod
    def _clean_duplicate_columns(df: pd.DataFrame) -> pd.DataFrame:
        """Collapse ``<name>_x`` / ``<name>_y`` column pairs created by merges.

        A pair holding identical data is reduced to a single column called
        ``<name>``. A pair with differing data is kept as is, except for
        ``simTime``: when no plain ``simTime`` column exists, the first of
        the pair is renamed to ``simTime`` and the other one is dropped.
        """
        duplicate_cols = []
        base_cols = {}

        print(f"清理重复列前的列名: {df.columns.tolist()}")

        # Group suffixed columns by the name they had before the merge.
        for col in df.columns:
            if col.endswith(('_x', '_y')):
                base_cols.setdefault(col[:-2], []).append(col)

        for base_name, cols in base_cols.items():
            if len(cols) < 2:
                continue
            first_col = cols[0]
            if all(df[first_col].equals(df[other]) for other in cols[1:]):
                # Same data on both sides: keep one copy under the base name.
                df = df.rename(columns={first_col: base_name})
                duplicate_cols.extend(cols[1:])
                print(f"列 {cols} 数据相同,保留为 {base_name}")
            else:
                print(f"列 {cols} 数据不同,保留所有列")
                # simTime must survive under its plain name.
                if base_name == 'simTime' and 'simTime' not in df.columns:
                    df = df.rename(columns={first_col: 'simTime'})
                    print(f"将 {first_col} 重命名为 simTime")
                    duplicate_cols.extend(cols[1:])

        if duplicate_cols:
            # Never drop the last simTime-like column: promote one instead.
            if 'simTime' not in df.columns:
                for col in duplicate_cols:
                    if col.startswith('simTime_'):
                        df = df.rename(columns={col: 'simTime'})
                        duplicate_cols.remove(col)
                        print(f"将 {col} 重命名为 simTime")
                        break
            df = df.drop(columns=duplicate_cols)
            print(f"删除了重复列: {duplicate_cols}")

        print(f"清理重复列后的列名: {df.columns.tolist()}")
        return df

    @staticmethod
    def _direction_from_heading(heading):
        """Map a heading in degrees to 'N', 'E', 'S' or 'W' (NaN for a missing heading)."""
        if pd.isna(heading):
            # A missing heading has no direction; without this guard every
            # comparison below is False and the row would be labelled 'S'.
            return np.nan
        # Normalize to the range (-180, 180].
        heading = heading % 360
        if heading > 180:
            heading -= 360
        if -45 <= heading <= 45:
            return 'N'
        if 45 < heading <= 135:
            return 'E'
        if -135 <= heading < -45:
            return 'W'
        return 'S'  # 135 < heading <= 180 or -180 < heading < -135

    def _merge_egomap(self, df_merged: pd.DataFrame) -> pd.DataFrame:
        """Attach EgoMap rows by nearest ``simTime`` (within 10 ms) per ``playerId``."""
        egomap_path = self.output_dir / OUTPUT_CSV_EGOMAP
        if not (egomap_path.exists() and egomap_path.stat().st_size > 0):
            return df_merged

        try:
            df_ego = pd.read_csv(egomap_path, dtype={"simTime": float})
            # The frame number of the main table is authoritative.
            if 'simFrame' in df_ego.columns:
                df_ego = df_ego.drop(columns=['simFrame'])

            print(f"合并 EgoMap 前 df_merged 的列: {df_merged.columns.tolist()}")
            print(f"df_ego 的列: {df_ego.columns.tolist()}")

            # merge_asof requires both sides to be sorted on the "on" key.
            df_ego.sort_values(['simTime', 'playerId'], inplace=True)
            df_merged.sort_values(['simTime', 'playerId'], inplace=True)
            df_merged = pd.merge_asof(
                df_merged,
                df_ego,
                on='simTime',
                by='playerId',
                direction='nearest',
                tolerance=0.01  # 10ms tolerance
            )

            print(f"合并 EgoMap 后 df_merged 的列: {df_merged.columns.tolist()}")

            if 'simTime' not in df_merged.columns:
                if 'simTime_x' in df_merged.columns:
                    df_merged.rename(columns={'simTime_x': 'simTime'}, inplace=True)
                    print("将 simTime_x 重命名为 simTime")
                else:
                    print("警告: 合并 EgoMap 后找不到 simTime 列!")

            print("EgoMap data merged.")
        except Exception as e:
            print(f"Warning: Could not merge EgoMap data from {egomap_path}: {e}")
            import traceback
            traceback.print_exc()
        return df_merged

    def _merge_traffic(self, df_merged: pd.DataFrame) -> pd.DataFrame:
        """Attach traffic-light phases from ``Traffic.csv`` and flag the relevant light.

        Rows are joined on ``simTime`` rounded to two decimals, so a row is
        repeated once per phase recorded at that time. Two columns are added:
        ``vehicle_direction`` (from the heading column) and
        ``filtered_trafficlight_id``, which holds the row's traffic-light id
        only when the phase direction, the vehicle direction and the light's
        direction all agree, and NaN otherwise.
        """
        traffic_path = self.output_dir / "Traffic.csv"
        if not (traffic_path.exists() and traffic_path.stat().st_size > 0):
            print("Traffic data not found or empty, skipping merge.")
            return df_merged

        try:
            df_traffic = pd.read_csv(traffic_path, dtype={"simTime": float}, low_memory=False).drop_duplicates()
            if 'simFrame' in df_traffic.columns:
                df_traffic = df_traffic.drop(columns=['simFrame'])

            # Fall back to posH_x when a previous merge suffixed the heading.
            heading_col = 'posH'
            if heading_col not in df_merged.columns:
                if 'posH_x' in df_merged.columns:
                    heading_col = 'posH_x'
                    print(f"使用 {heading_col} 替代 posH")
                else:
                    print(f"警告: 找不到航向角列 posH 或 posH_x")
                    # Only this merge is abandoned; the caller continues
                    # with the remaining data sources.
                    return df_merged

            df_merged['vehicle_direction'] = df_merged[heading_col].apply(self._direction_from_heading)

            # Temporary join key.
            df_traffic['time'] = df_traffic['simTime'].round(2).astype(float)

            if 'simTime' not in df_merged.columns:
                print("警告: 合并 Traffic 前 df_merged 中找不到 simTime 列!")
                if 'simTime_x' in df_merged.columns:
                    df_merged.rename(columns={'simTime_x': 'simTime'}, inplace=True)
                    print("将 simTime_x 重命名为 simTime")
                else:
                    print("严重错误: 无法找到任何 simTime 相关列,无法继续合并!")
                    return df_merged

            df_merged['time'] = df_merged['simTime'].round(2).astype(float)

            # The join is on "time" only. Leaving simTime/playerId on both
            # sides would make pandas suffix them (_x/_y), and the main
            # table's playerId would no longer exist under its own name.
            shared_identity_cols = [
                col for col in ('simTime', 'playerId')
                if col in df_traffic.columns and col in df_merged.columns
            ]
            df_traffic = df_traffic.drop(columns=shared_identity_cols)

            df_merged = pd.merge(df_merged, df_traffic, on=["time"], how="left")
            df_merged = self._clean_duplicate_columns(df_merged)

            trafficlight_col = 'trafficlight_id'
            if trafficlight_col not in df_merged.columns:
                if 'trafficlight_id_x' in df_merged.columns:
                    trafficlight_col = 'trafficlight_id_x'
                    print(f"使用 {trafficlight_col} 替代 trafficlight_id")
                else:
                    print(f"警告: 找不到红绿灯ID列 trafficlight_id 或 trafficlight_id_x")

            # Invariants of the row filter, computed once instead of per row.
            has_phase = 'phaseId' in df_merged.columns
            has_light = trafficlight_col in df_merged.columns
            ids_by_direction = {}
            for light_id, direction in self._TRAFFICLIGHT_TO_DIRECTION.items():
                ids_by_direction.setdefault(direction, set()).add(light_id)
            phase_to_direction = self._PHASE_TO_DIRECTION

            def filter_relevant_traffic_light(row):
                if not has_phase or pd.isna(row['phaseId']):
                    return np.nan
                phase_direction = phase_to_direction.get(int(row['phaseId']))
                if phase_direction != row['vehicle_direction']:
                    return np.nan
                if not has_light or pd.isna(row[trafficlight_col]):
                    return np.nan
                if row[trafficlight_col] in ids_by_direction.get(phase_direction, ()):
                    return row[trafficlight_col]
                return np.nan

            df_merged['filtered_trafficlight_id'] = df_merged.apply(filter_relevant_traffic_light, axis=1)

            # Remove the temporary join key.
            print(f"删除 time 列前 df_merged 的列: {df_merged.columns.tolist()}")
            df_merged.drop(columns=['time'], inplace=True)
            print(f"删除 time 列后 df_merged 的列: {df_merged.columns.tolist()}")

            if 'simTime' not in df_merged.columns:
                if 'simTime_x' in df_merged.columns:
                    df_merged.rename(columns={'simTime_x': 'simTime'}, inplace=True)
                    print("将 simTime_x 重命名为 simTime")
                else:
                    print("警告: 处理 Traffic 数据后找不到 simTime 列!")

            print("Traffic light data merged and filtered.")
        except Exception as e:
            print(f"Warning: Could not merge Traffic data from {traffic_path}: {e}")
            import traceback
            traceback.print_exc()
        return df_merged

    def _merge_time_keyed_csv(self, df_merged: pd.DataFrame, source_path: Path,
                              label: str, debug_prefix: str = "") -> pd.DataFrame:
        """Left-join the CSV at ``source_path`` onto ``df_merged`` by rounded ``simTime``.

        Shared by the Function and OBU merges, which differ only in the file
        and in the wording of their log lines.

        Args:
            df_merged: Table merged so far.
            source_path: CSV file to merge; skipped when missing or empty.
            label: Source name used in the log messages.
            debug_prefix: Text put in front of the two simTime debug lines.

        Returns:
            The merged table, or ``df_merged`` when the source was skipped or
            failed. Columns that already exist in ``df_merged`` are dropped
            from the source, so the main table's values always win.
        """
        if not (source_path.exists() and source_path.stat().st_size > 0):
            print(f"{label} 数据文件不存在或为空: {source_path}")
            return df_merged

        try:
            print(f"正在读取 {label} 数据: {source_path}")
            df_extra = pd.read_csv(source_path, low_memory=False).drop_duplicates()
            print(f"{label} 数据列名: {df_extra.columns.tolist()}")

            if 'simFrame' in df_extra.columns:
                df_extra = df_extra.drop(columns=['simFrame'])

            if 'simTime' not in df_extra.columns:
                print(f"警告: {source_path.name} 中找不到 'simTime' 列。可用的列: {df_extra.columns.tolist()}")
                return df_merged

            try:
                # Rows whose simTime cannot be parsed are discarded.
                df_extra['simTime'] = pd.to_numeric(df_extra['simTime'], errors='coerce')
                df_extra = df_extra.dropna(subset=['simTime'])
                df_extra['time'] = df_extra['simTime'].round(2)

                if 'simTime' in df_merged.columns:
                    print(f"{debug_prefix}df_merged['simTime'] 的类型: {df_merged['simTime'].dtype}")
                    print(f"{debug_prefix}df_merged['simTime'] 的前5个值: {df_merged['simTime'].head().tolist()}")

                    df_merged['time'] = pd.to_numeric(df_merged['simTime'], errors='coerce').round(2)
                    nan_count = df_merged['time'].isna().sum()
                    if nan_count > 0:
                        print(f"警告: 转换后有 {nan_count} 个 NaN 值,将删除这些行")
                        df_merged = df_merged.dropna(subset=['time'])

                    # Both join keys must share one dtype.
                    df_extra['time'] = df_extra['time'].astype(float)
                    df_merged['time'] = df_merged['time'].astype(float)

                    common_cols = list((set(df_merged.columns) & set(df_extra.columns)) - {'time'})
                    df_extra.drop(columns=common_cols, inplace=True, errors='ignore')

                    df_merged = pd.merge(df_merged, df_extra, on=["time"], how="left")
                    df_merged.drop(columns=['time'], inplace=True)
                    print(f"{label} 数据合并成功。")
                else:
                    print(f"警告: df_merged 中找不到 'simTime' 列,无法合并 {label} 数据。")
                    print(f"df_merged 的所有列: {df_merged.columns.tolist()}")
            except Exception as e:
                print(f"警告: 处理 {source_path.name} 中的 simTime 列时出错: {e}")
                import traceback
                traceback.print_exc()
        except Exception as e:
            print(f"警告: 无法合并 {label} 数据: {e}")
            import traceback
            traceback.print_exc()
        return df_merged

    def _process_trafficlight_data(self) -> pd.DataFrame:
        """Extract the current traffic-light states from the configured JSON file.

        The file may hold a JSON array, a single JSON object, or one JSON
        document per line. For every phase state whose ``startTime`` is 0 a
        row with ``simTime`` (intersection timestamp divided by 1000 and
        rounded to two decimals), ``phaseId``, ``stateMask`` and ``playerId``
        (always ``PLAYER_ID_EGO``) is produced.

        Returns:
            The rows sorted by ``simTime`` with duplicates removed, or an
            empty DataFrame when no JSON path is configured, the file is
            missing or unreadable, or no matching state was found.
        """
        if not self.config.json_path:
            print("No traffic light JSON file provided. Skipping traffic light processing.")
            return pd.DataFrame()

        if not self.config.json_path.exists():
            print("Traffic light JSON file not found. Skipping traffic light processing.")
            return pd.DataFrame()

        print(f"Processing traffic light data from: {self.config.json_path}")
        valid_trafficlights = []
        try:
            with open(self.config.json_path, 'r', encoding='utf-8') as f:
                try:
                    raw_data = json.load(f)
                    if not isinstance(raw_data, list):
                        raw_data = [raw_data]  # a single JSON object
                except json.JSONDecodeError:
                    # Not one JSON document: retry as one document per line.
                    f.seek(0)
                    raw_data = [json.loads(line) for line in f if line.strip()]

            for entry in raw_data:
                # An entry may itself be a string holding JSON.
                if isinstance(entry, str):
                    try:
                        entry = json.loads(entry)
                    except json.JSONDecodeError:
                        print(f"Warning: Skipping invalid JSON string in traffic light data: {entry[:100]}...")
                        continue
                if not isinstance(entry, dict):
                    # Skip the odd entry instead of losing the whole file.
                    continue

                intersections = entry.get('intersections', [])
                if not isinstance(intersections, list):
                    continue

                for intersection in intersections:
                    if not isinstance(intersection, dict):
                        continue
                    timestamp_ms = intersection.get('intersectionTimestamp', 0)
                    try:
                        sim_time = round(int(timestamp_ms) / 1000, 2)
                    except (TypeError, ValueError):
                        print(f"Warning: Skipping intersection with invalid timestamp: {timestamp_ms!r}")
                        continue

                    phases = intersection.get('phases', [])
                    if not isinstance(phases, list):
                        continue

                    for phase in phases:
                        if not isinstance(phase, dict):
                            continue
                        phase_id = phase.get('phaseId', 0)
                        phase_states = phase.get('phaseStates', [])
                        if not isinstance(phase_states, list):
                            continue

                        for phase_state in phase_states:
                            if not isinstance(phase_state, dict):
                                continue
                            # Only states with startTime == 0 are recorded.
                            if phase_state.get('startTime') == 0:
                                valid_trafficlights.append({
                                    'simTime': sim_time,
                                    'phaseId': phase_id,
                                    'stateMask': phase_state.get('light', 0),
                                    # Traffic-light rows are attributed to the ego player.
                                    'playerId': PLAYER_ID_EGO,
                                })

            if not valid_trafficlights:
                print("No valid traffic light states (with startTime=0) found in JSON.")
                return pd.DataFrame()

            df_trafficlights = pd.DataFrame(valid_trafficlights)
            df_trafficlights.drop_duplicates(subset=['simTime', 'playerId', 'phaseId', 'stateMask'],
                                             keep='first', inplace=True)
            print(f"Processed {len(df_trafficlights)} unique traffic light state entries.")

            # Ascending time order; a stable sort keeps rows that share a
            # timestamp in input order, so the output is reproducible.
            df_trafficlights = df_trafficlights.sort_values('simTime', ascending=True, kind='mergesort')

            print(f"交通灯数据时间范围: {df_trafficlights['simTime'].min()} 到 {df_trafficlights['simTime'].max()}")
            print(f"交通灯数据前5行时间: {df_trafficlights['simTime'].head().tolist()}")
            print(f"交通灯数据后5行时间: {df_trafficlights['simTime'].tail().tolist()}")
            return df_trafficlights

        except json.JSONDecodeError as e:
            print(f"Error decoding traffic light JSON file {self.config.json_path}: {e}")
            return pd.DataFrame()
        except Exception as e:
            print(f"Unexpected error processing traffic light data: {e}")
            return pd.DataFrame()
# --- Rosbag Processing ---
class RosbagProcessor:
    """Extracts data from Rosbag files within a ZIP archive."""

    # Pairs of keywords that must both occur in the archive name -> rostopic
    # to export. Entries are tried in order; the first match wins.
    ROSTOPIC_MAP = {
        ('V2I', 'HazardousLocationW'): "/HazardousLocationWarning",
        ('V2C', 'OtherVehicleRedLightViolationW'): "/c2v/GoThroughRadLight",
        ('V2I', 'LeftTurnAssist'): "/LeftTurnAssistant",
        ('V2V', 'LeftTurnAssist'): "/V2VLeftTurnAssistant",
        ('V2I', 'RedLightViolationW'): "/SignalViolationWarning",
        ('V2C', 'AbnormalVehicleW'): "/c2v/AbnormalVehicleWarnning",
        ('V2C', 'SignalLightReminder'): "/c2v/TrafficLightInfo",
        ('V2C', 'VulnerableRoadUserCollisionW'): "/c2v/VulnerableObject",
        ('V2C', 'EmergencyVehiclesPriority'): "/c2v/EmergencyVehiclesPriority",
        ('V2C', 'LitterW'): "/c2v/RoadSpillageWarning",
        ('V2V', 'ForwardCollision'): "/V2VForwardCollisionWarning",
        ('V2C', 'VisibilityW'): "/c2v/VisibilityWarinning",
        ('V2V', 'EmergencyBrakeW'): "/V2VEmergencyBrakeWarning",
        ('V2I', 'GreenLightOptimalSpeedAdvisory'): "/GreenLightOptimalSpeedAdvisory",  # Check exact topic name
        ('V2C', 'DynamicSpeedLimitingInformation'): "/c2v/DynamicSpeedLimit",
        ('V2C', 'TrafficJamW'): "/c2v/TrafficJam",
        ('V2C', 'DrivingLaneRecommendation'): "/c2v/LaneGuidance",
        ('V2C', 'RampMerge'): "/c2v/RampMerging",
        ('V2I', 'CooperativeIntersectionPassing'): "/CooperativeIntersectionPassing",
        ('V2I', 'IntersectionCollisionW'): "/IntersectionCollisionWarning",
        ('V2V', 'IntersectionCollisionW'): "/V2VIntersectionCollisionWarning",
        ('V2V', 'BlindSpotW'): "/V2VBlindSpotWarning",
        ('V2I', 'SpeedLimitW'): "/SpeedLimit",
        ('V2I', 'VulnerableRoadUserCollisionW'): "/VulnerableRoadUserCollisionWarning",
        ('V2I', 'CooperativeLaneChange'): "/CooperativeLaneChange",
        ('V2V', 'CooperativeLaneChange'): "/V2VCooperativeLaneChange",
        ('V2I', 'CooperativeVehicleMerge'): "/CooperativeVehicleMerge",
        ('V2V', 'AbnormalVehicleW'): "/V2VAbnormalVehicleWarning",
        ('V2V', 'ControlLossW'): "/V2VVehicleLossControlWarning",
        ('V2V', 'EmergencyVehicleW'): '/V2VEmergencyVehicleWarning',
        ('V2I', 'InVehicleSignage'): "/InVehicleSign",
        ('V2V', 'DoNotPassW'): "/V2VDoNotPassWarning",
        ('V2I', 'TrafficJamW'): "/TrafficJamWarning",
        # Add more mappings as needed
    }

    def __init__(self, config: "Config"):
        """Remember the configuration and the directory CSVs are written to."""
        self.config = config
        self.output_dir = config.output_dir

    def _get_target_rostopic(self, zip_filename: str) -> Optional[str]:
        """Return the rostopic whose two keywords both occur in ``zip_filename``.

        Returns None (after printing a warning) when no entry matches.
        """
        for (kw1, kw2), topic in self.ROSTOPIC_MAP.items():
            if kw1 in zip_filename and kw2 in zip_filename:
                print(f"Identified target topic '{topic}' for {zip_filename}")
                return topic
        print(f"Warning: No specific rostopic mapping found for {zip_filename}.")
        return None

    def process_zip_for_rosbags(self) -> None:
        """Extract bag and HMI files from the ZIP and convert each bag to CSV.

        ``*.bag`` members under a ``Rosbag/`` folder are unpacked into a
        temporary directory and converted one by one. ``*.csv`` members under
        ``HMIdata/`` are copied straight into the output directory. A bad or
        missing archive is reported and ends the method.
        """
        print(f"--- Processing Rosbags in {self.config.zip_path} ---")

        target_rostopic = self._get_target_rostopic(self.config.zip_path.stem)
        if not target_rostopic:
            # NOTE(review): despite the message, extraction and conversion
            # still run. With no topic every bag takes the "topic not found"
            # path and a header-only OBU CSV is written. Kept as is because
            # later stages may rely on that file - confirm before adding a
            # return here.
            print("Skipping Rosbag processing as no target topic was identified.")

        with tempfile.TemporaryDirectory() as tmp_dir_str:
            tmp_dir = Path(tmp_dir_str)
            bag_files_extracted = []

            try:
                with zipfile.ZipFile(self.config.zip_path, 'r') as zip_ref:
                    for member in zip_ref.infolist():
                        member_name = member.filename
                        if 'Rosbag/' in member_name and member_name.endswith('.bag'):
                            try:
                                extracted_path = Path(zip_ref.extract(member, path=tmp_dir))
                                bag_files_extracted.append(extracted_path)
                                print(f"Extracted Rosbag: {extracted_path.name}")
                            except Exception as e:
                                print(f"Error extracting Rosbag {member_name}: {e}")
                        elif 'HMIdata/' in member_name and member_name.endswith('.csv'):
                            try:
                                # Flatten: only the file name is kept.
                                target_path = self.output_dir / Path(member_name).name
                                with zip_ref.open(member) as source, open(target_path, "wb") as target:
                                    shutil.copyfileobj(source, target)
                                print(f"Extracted HMI data: {target_path.name}")
                            except Exception as e:
                                print(f"Error extracting HMI data {member_name}: {e}")
            except zipfile.BadZipFile:
                print(f"Error: Bad ZIP file provided: {self.config.zip_path}")
                return
            except FileNotFoundError:
                print(f"Error: ZIP file not found: {self.config.zip_path}")
                return

            if not bag_files_extracted:
                print("No Rosbag files found in the archive.")
                return

            # Must run inside the with-block: the bags live in tmp_dir.
            for bag_path in bag_files_extracted:
                print(f"Processing bag file: {bag_path.name}")
                self._convert_bag_topic_to_csv(bag_path, target_rostopic)

        print("--- Rosbag Processing Finished ---")

    @staticmethod
    def _normalize_obu_columns(df: pd.DataFrame) -> pd.DataFrame:
        """Rename bagpy's columns to the pipeline's names and round ``simTime``.

        ``Time`` is dropped. The first column ending in ``.timestamp``
        becomes ``simTime`` and the first ending in ``event_type`` becomes
        ``event_Type``. Further matches keep their name, because renaming
        them too would create duplicate labels and break the numeric
        conversion of ``simTime``.
        """
        if 'Time' in df.columns:
            df = df.drop(columns=['Time'])

        suffix_targets = (('.timestamp', 'simTime'), ('event_type', 'event_Type'))
        taken = set(df.columns)
        rename_dict = {}
        for col in df.columns:
            for suffix, target in suffix_targets:
                if str(col).endswith(suffix):
                    if target not in taken:
                        rename_dict[col] = target
                        taken.add(target)
                    break
        df = df.rename(columns=rename_dict)

        if 'simTime' in df.columns:
            df['simTime'] = pd.to_numeric(df['simTime'], errors='coerce').round(2)
        return df

    def _convert_bag_topic_to_csv(self, bag_file_path: Path, target_topic: Optional[str]) -> None:
        """Export one topic of one bag file to ``OUTPUT_CSV_OBU``.

        Args:
            bag_file_path: Bag file to read.
            target_topic: Topic to export. When it is None or not present in
                the bag, a CSV holding only the ``simTime`` and
                ``event_Type`` headers is written.

        Errors are printed, never raised. Whatever bagpy created on disk for
        this bag is removed afterwards. Every call writes the same output
        file, so with several bags the last one wins.
        """
        output_csv_path = self.output_dir / OUTPUT_CSV_OBU
        bag_reader = None
        temp_csv_path = None

        try:
            # bagpy is handed a plain string path.
            bag_reader = bagreader(str(bag_file_path), verbose=False)

            topic_table = getattr(bag_reader, 'topic_table', None)
            available_topics = topic_table['Topics'].tolist() if topic_table is not None else []

            if target_topic not in available_topics:
                print(f"Target topic '{target_topic}' not found in {bag_file_path.name}. Available: {available_topics}")
                df = pd.DataFrame(columns=['simTime', 'event_Type'])
            else:
                temp_csv_path_str = bag_reader.message_by_topic(target_topic)
                # Guard against a None/empty return (believed to happen when
                # the topic holds no messages - see bagpy docs); Path(None)
                # would raise TypeError.
                if not temp_csv_path_str:
                    print(
                        f"Warning: Bagpy generated an empty or non-existent CSV for topic '{target_topic}' from {bag_file_path.name}.")
                    return
                temp_csv_path = Path(temp_csv_path_str)

                if not temp_csv_path.exists() or temp_csv_path.stat().st_size == 0:
                    print(
                        f"Warning: Bagpy generated an empty or non-existent CSV for topic '{target_topic}' from {bag_file_path.name}.")
                    return

                df = pd.read_csv(temp_csv_path)
                if df.empty:
                    print(f"Warning: Bagpy CSV for topic '{target_topic}' is empty after reading.")
                    return

                df = self._normalize_obu_columns(df)

            df.to_csv(output_csv_path, index=False, float_format='%.6f')
            print(f"Saved processed OBU data to: {output_csv_path}")

        except ValueError as ve:
            print(
                f"ValueError processing bag {bag_file_path.name} (Topic: {target_topic}): {ve}. Topic might be empty.")
        except ImportError as ie:
            print(
                f"ImportError during bag processing: {ie}. Ensure all ROS dependencies are installed if needed by bagpy.")
        except Exception as e:
            print(f"Error processing bag file {bag_file_path.name} (Topic: {target_topic}): {e}")
            import traceback
            traceback.print_exc()
        finally:
            # Remove the CSV and the data folder bagpy created for this bag.
            if temp_csv_path is not None and temp_csv_path.exists():
                try:
                    temp_csv_path.unlink()
                except OSError as ose:
                    print(f"Warning: Could not delete bagpy temp csv {temp_csv_path}: {ose}")

            data_folder = getattr(bag_reader, 'data_folder', None)
            if data_folder:
                bagpy_folder = Path(data_folder)
                if bagpy_folder.is_dir():
                    try:
                        shutil.rmtree(bagpy_folder, ignore_errors=True)
                    except OSError as ose:
                        print(f"Warning: Could not delete bagpy temp folder {bagpy_folder}: {ose}")
# --- Utility Functions ---
def get_base_path() -> Path:
    """Return the directory the program's files are located in.

    For a frozen build that exposes ``sys._MEIPASS`` (PyInstaller) that
    directory is returned; otherwise the resolved folder of this source file.
    """
    is_frozen = getattr(sys, 'frozen', False)
    if is_frozen and hasattr(sys, '_MEIPASS'):
        return Path(sys._MEIPASS)
    return Path(__file__).parent.resolve()


def run_cpp_engine(config: "Config") -> bool:
    """Run the external C++ preprocessing engine as a subprocess.

    The engine is started from its own directory with the arguments
    ``map_path output_dir x_offset y_offset``.

    Args:
        config: Supplies ``engine_path``, ``map_path``, ``output_dir``,
            ``x_offset`` and ``y_offset``.

    Returns:
        True when the engine exited with code 0, or when ``engine_path`` or
        ``map_path`` is not configured (the step is then skipped). False
        when the executable is missing, exits non-zero, or anything else
        goes wrong while running it.
    """
    if not (config.engine_path and config.map_path):
        print("C++ engine path or map path not configured. Skipping C++ engine execution.")
        return True

    arguments = (
        config.engine_path,
        config.map_path,
        config.output_dir,
        config.x_offset,
        config.y_offset,
    )
    engine_cmd = [str(argument) for argument in arguments]

    print("--- Running C++ Preprocessing Engine ---")
    print(f"Command: {' '.join(engine_cmd)}")
    try:
        completed = subprocess.run(
            engine_cmd,
            check=True,  # non-zero exit code raises CalledProcessError
            capture_output=True,
            text=True,
            cwd=config.engine_path.parent,  # run from the engine's directory
        )
        print("C++ Engine Output:")
        print(completed.stdout)
        if completed.stderr:
            print("C++ Engine Error Output:")
            print(completed.stderr)
        print("--- C++ Engine Finished Successfully ---")
        return True
    except FileNotFoundError:
        print(f"Error: C++ engine executable not found at {config.engine_path}.")
        return False
    except subprocess.CalledProcessError as e:
        print(f"Error: C++ engine failed with exit code {e.returncode}.")
        print("C++ Engine Output (stdout):")
        print(e.stdout)
        print("C++ Engine Output (stderr):")
        print(e.stderr)
        return False
    except Exception as e:
        print(f"An unexpected error occurred while running the C++ engine: {e}")
        return False


if __name__ == "__main__":
    pass