# lst.py
  1. import zipfile
  2. import sqlite3
  3. import csv
  4. import tempfile
  5. from pathlib import Path
  6. from typing import List, Dict, Tuple, Optional, Any, NamedTuple
  7. import cantools
  8. import os
  9. import subprocess
  10. import numpy as np
  11. import pandas as pd
  12. from collections import Counter
  13. # from datetime import datetime
import warnings  # needed below so polynomial_fit can escalate np.RankWarning to a catchable error
  14. import argparse
  15. import sys
  16. from pyproj import Proj
  17. from scipy.spatial import cKDTree
  18. import shutil
  19. import json
  20. from dataclasses import dataclass, field
  21. # --- Constants ---
  22. PLAYER_ID_EGO = 1
  23. PLAYER_ID_OBJ = 2
  24. PLAYER_ID_PEDESTRIAN = 5
  25. DEFAULT_TYPE = 1
  26. PEDESTRIAN_TYPE = 5
  27. OUTPUT_CSV_OBJSTATE = "ObjState.csv"
  28. OUTPUT_CSV_TEMP_OBJSTATE = "ObjState_temp_intermediate.csv" # Should be eliminated
  29. OUTPUT_CSV_EGOSTATE = "EgoState.csv" # Not used in final merge? Check logic if needed.
  30. OUTPUT_CSV_MERGED = "merged_ObjState.csv"
  31. OUTPUT_CSV_OBU = "OBUdata.csv"
  32. OUTPUT_CSV_LANEMAP = "LaneMap.csv"
  33. OUTPUT_CSV_EGOMAP = "EgoMap.csv"
  34. MERGED_CSV_EGOMAP = "merged_egomap.csv"
  35. OUTPUT_CSV_FUNCTION = "Function.csv"
  36. ROADMARK_CSV = "RoadMark.csv"
  37. HD_LANE_CSV = "hd_lane.csv"
  38. HD_LINK_CSV = "hd_link.csv"
  39. # --- Configuration Class ---
  40. @dataclass
  41. class Config:
  42. """Holds configuration paths and settings."""
  43. zip_path: Path
  44. output_path: Path
  45. json_path: Optional[Path] # Make json_path optional
  46. dbc_path: Optional[Path] = None
  47. engine_path: Optional[Path] = None
  48. map_path: Optional[Path] = None
  49. utm_zone: int = 51 # Example UTM zone
  50. x_offset: float = 0.0
  51. y_offset: float = 0.0
  52. # Derived paths
  53. output_dir: Path = field(init=False)
  54. def __post_init__(self):
  55. # Use output_path directly as output_dir to avoid nested directories
  56. self.output_dir = self.output_path
  57. self.output_dir.mkdir(parents=True, exist_ok=True)
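# Example of constructing a Config (hypothetical paths, for illustration only):
#   cfg = Config(zip_path=Path("run01.zip"), output_path=Path("out"), json_path=None)
#   cfg.output_dir is set by __post_init__ to the same directory as cfg.output_path and is created on the spot.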
  58. # --- Zip/CSV Processing ---
  59. class ZipCSVProcessor:
  60. """Processes DB files within a ZIP archive to generate CSV data."""
  61. # Define column mappings more clearly
  62. EGO_COLS_NEW = [
  63. "simTime", "simFrame", "playerId", "v", "speedX", "speedY", "speedZ",
  64. "posH", "pitch", "roll", "roll_rate", "pitch_rate", "speedH", "posX", "posY", "accelX", "accelY", "accelZ",
  65. "travelDist", "composite_v", "relative_dist", "x_relative_dist", "y_relative_dist", "type" # Added type
  66. ]
  67. OBJ_COLS_OLD_SUFFIXED = [
  68. "v_obj", "speedX_obj", "speedY_obj", "speedZ_obj", "posH_obj", "pitch_obj", "roll_obj", "roll_rate_obj",
  69. "pitch_rate_obj", "speedH_obj",
  70. "posX_obj", "posY_obj", "accelX_obj", "accelY_obj", "accelZ_obj", "travelDist_obj"
  71. ]
  72. OBJ_COLS_MAPPING = {old: new for old, new in
  73. zip(OBJ_COLS_OLD_SUFFIXED, EGO_COLS_NEW[3:19])} # Map suffixed cols to standard names
  74. def __init__(self, config: Config):
  75. self.config = config
  76. self.dbc = self._load_dbc(config.dbc_path)
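# UTM projection on the WGS84 ellipsoid, used to convert GPS latitude/longitude into planar x/y coordinates in meters.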
  77. self.projection = Proj(proj='utm', zone=config.utm_zone, ellps='WGS84', preserve_units='m')
  78. self._init_table_config()
  79. self._init_keyword_mapping()
  80. def _load_dbc(self, dbc_path: Optional[Path]) -> Optional[cantools.db.Database]:
  81. # Guard against a missing path before constructing a Path (Path(None) raises TypeError)
  82. if dbc_path is None or not Path(dbc_path).exists():
  83. print("DBC path not provided or file not found.")
  84. return None
  85. try:
  86. return cantools.db.load_file(dbc_path)
  87. except Exception as e:
  88. print(f"DBC loading failed: {e}")
  89. return None
  90. def _init_table_config(self):
  91. """Initializes configurations for different table types."""
  92. self.table_config = {
  93. "gnss_table": self._get_gnss_config(),
  94. "can_table": self._get_can_config()
  95. }
  96. def _get_gnss_config(self):
  97. # Keep relevant columns, adjust mapping as needed
  98. return {
  99. "output_columns": self.EGO_COLS_NEW, # Use the standard ego columns + type
  100. "mapping": { # Map output columns to source DB columns/signals
  101. "simTime": ("second", "usecond"),
  102. "simFrame": "ID",
  103. "v": "speed",
  104. "speedY": "y_speed",
  105. "speedX": "x_speed",
  106. "speedZ": "z_speed",
  107. "posH": "yaw",
  108. "pitch": "tilt",
  109. "roll": "roll",
  110. "speedH": "yaw_rate",
  111. "posX": "latitude_dd", # Source before projection
  112. "posY": "longitude_dd", # Source before projection
  113. "accelX": "x_acceleration",
  114. "accelY": "y_acceleration",
  115. "accelZ": "z_acceleration",
  116. "travelDist": "total_distance",
  117. # composite_v/relative_dist might not be direct fields in GNSS, handle later if needed
  118. "composite_v": "speed", # Placeholder, adjust if needed
  119. "relative_dist": "distance", # Placeholder, likely not in GNSS data
  120. "x_relative_dist": "x_distance",
  121. "y_relative_dist": "y_distance",
  122. "type": None # Will be set later
  123. },
  124. "db_columns": ["ID", "second", "usecond", "speed", "y_speed", "x_speed", "z_speed", "z_acceleration",
  125. "yaw", "tilt", "roll", "yaw_rate", "latitude_dd", "longitude_dd", "total_distance",
  126. "x_acceleration", "y_acceleration", "total_distance", "distance", "x_distance", "y_distance"]
  127. # Actual cols to SELECT
  128. }
  129. def _get_can_config(self):
  130. # Define columns needed from DB/CAN signals for both EGO and OBJ
  131. return {
  132. "mapping": { # Map unified output columns to CAN signals or direct fields
  133. # EGO mappings (VUT = Vehicle Under Test)
  134. "v": "VUT_Speed_mps",
  135. "speedX": "VUT_Speed_long_mps",
  136. "speedY": "VUT_Speed_lat_mps",
  137. "speedZ": "VUT_Speed_z_mps",
  138. "speedH": "VUT_Yaw_Rate",
  139. "posX": "VUT_GPS_Latitude", # Source before projection
  140. "posY": "VUT_GPS_Longitude", # Source before projection
  141. "posH": "VUT_Heading",
  142. "pitch": "VUT_Pitch",
  143. "roll": "VUT_Roll",
  144. "pitch_rate": None,
  145. "roll_rate": None,
  146. "accelX": "VUT_Acc_X",
  147. "accelY": "VUT_Acc_Y2",
  148. "accelZ": "VUT_Acc_Z",
  149. # OBJ mappings (UFO = Unidentified Flying Object / Other Vehicle)
  150. "v_obj": "Speed_mps",
  151. "speedX_obj": "UFO_Speed_long_mps",
  152. "speedY_obj": "UFO_Speed_lat_mps",
  153. "speedZ_obj": "UFO_Speed_z_mps",
  154. "speedH_obj": "Yaw_Rate",
  155. "posX_obj": "GPS_Latitude", # Source before projection
  156. "posY_obj": "GPS_Longitude", # Source before projection
  157. "posH_obj": "Heading",
  158. "pitch_obj": None,
  159. "roll_obj": None,
  160. "pitch_rate_obj": None,
  161. "roll_rate_obj": None,
  162. "accelX_obj": "Acc_X",
  163. "accelY_obj": "Acc_Y",
  164. "accelZ_obj": "Acc_Z",
  165. # Relative Mappings
  166. "composite_v": "VUT_Rel_speed_long_mps",
  167. "relative_dist": "VUT_Dist_MRP_Abs",
  168. "x_relative_dist": "VUT_Dist_MRP_long",
  169. "y_relative_dist": "VUT_Dist_MRP_lat",
  170. # travelDist often calculated, not direct CAN signal
  171. "travelDist": None, # Placeholder
  172. "travelDist_obj": None # Placeholder
  173. },
  174. "db_columns": ["ID", "second", "usecond", "timestamp", "canid", "len", "frame"] # Core DB columns
  175. }
  176. def _init_keyword_mapping(self):
  177. """Maps keywords in filenames to table configurations and output CSV names."""
  178. self.keyword_mapping = {
  179. "gnss": ("gnss_table", OUTPUT_CSV_OBJSTATE),
  180. # GNSS likely represents ego, writing to ObjState first? Revisit logic if needed.
  181. "can4": ("can_table", OUTPUT_CSV_OBJSTATE), # Process CAN data into the combined ObjState file
  182. }
  183. def process_zip(self) -> None:
  184. """Extracts and processes DB files from the configured ZIP path."""
  185. print(f"Processing ZIP: {self.config.zip_path}")
  186. output_dir = self.config.output_dir # Already created in Config
  187. try:
  188. with zipfile.ZipFile(self.config.zip_path, "r") as zip_ref:
  189. db_files_to_process = []
  190. for file_info in zip_ref.infolist():
  191. # Check if it's a DB file in the CANdata directory
  192. if 'CANdata/' in file_info.filename and file_info.filename.endswith('.db'):
  193. # Check if the filename contains any of the keywords
  194. match = self._match_keyword(file_info.filename)
  195. if match:
  196. table_type, csv_name = match
  197. db_files_to_process.append((file_info, table_type, csv_name))
  198. if not db_files_to_process:
  199. print("No relevant DB files found in CANdata/ matching keywords.")
  200. return
  201. # Process matched DB files
  202. with tempfile.TemporaryDirectory() as tmp_dir_str:
  203. tmp_dir = Path(tmp_dir_str)
  204. for file_info, table_type, csv_name in db_files_to_process:
  205. print(f"Processing DB: {file_info.filename} for table type {table_type}")
  206. extracted_path = tmp_dir / Path(file_info.filename).name
  207. try:
  208. # Extract the specific DB file
  209. with zip_ref.open(file_info.filename) as source, open(extracted_path, "wb") as target:
  210. shutil.copyfileobj(source, target)
  211. # Process the extracted DB file
  212. self._process_db_file(file_info.filename, extracted_path, output_dir, table_type, csv_name)
  213. except (sqlite3.Error, pd.errors.EmptyDataError, FileNotFoundError, KeyError) as e:
  214. print(f"Error processing DB file {file_info.filename}: {e}")
  215. except Exception as e:
  216. print(f"Unexpected error processing DB file {file_info.filename}: {e}")
  217. finally:
  218. if extracted_path.exists():
  219. extracted_path.unlink() # Clean up extracted file
  220. except zipfile.BadZipFile:
  221. print(f"Error: Bad ZIP file: {self.config.zip_path}")
  222. except FileNotFoundError:
  223. print(f"Error: ZIP file not found: {self.config.zip_path}")
  224. except Exception as e:
  225. print(f"An error occurred during ZIP processing: {e}")
  226. def _match_keyword(self, filename: str) -> Optional[Tuple[str, str]]:
  227. """Finds the first matching keyword configuration for a filename."""
  228. for keyword, (table_type, csv_name) in self.keyword_mapping.items():
  229. if keyword in filename:
  230. return table_type, csv_name
  231. return None
  232. def _process_db_file(
  233. self, filename: str, db_path: Path, output_dir: Path, table_type: str, csv_name: str
  234. ) -> None:
  235. """Connects to SQLite DB and processes the specified table type."""
  236. output_csv_path = output_dir / csv_name
  237. try:
  238. # Use URI for read-only connection
  239. conn_str = f"file:{db_path}?mode=ro"
  240. with sqlite3.connect(conn_str, uri=True) as conn:
  241. cursor = conn.cursor()
  242. if not self._check_table_exists(cursor, table_type):
  243. print(f"Table '{table_type}' does not exist in {db_path.name}. Skipping.")
  244. return
  245. if self._check_table_empty(cursor, table_type):
  246. print(f"Table '{table_type}' in {db_path.name} is empty. Skipping.")
  247. return
  248. print(f"Exporting data from table '{table_type}' to {output_csv_path}")
  249. if table_type == "can_table":
  250. self._process_can_table_optimized(filename, cursor, output_csv_path)
  251. elif table_type == "gnss_table":
  252. # Pass output_path directly, avoid intermediate steps
  253. self._process_gnss_table(cursor, output_csv_path)
  254. else:
  255. print(f"Warning: No specific processor for table type '{table_type}'. Skipping.")
  256. except sqlite3.OperationalError as e:
  257. print(f"Database operational error for {db_path.name}: {e}. Check file integrity/permissions.")
  258. except sqlite3.DatabaseError as e:
  259. print(f"Database error connecting to {db_path.name}: {e}")
  260. except Exception as e:
  261. print(f"Unexpected error processing DB {db_path.name}: {e}")
  262. def _check_table_exists(self, cursor, table_name: str) -> bool:
  263. """Checks if a table exists in the database."""
  264. try:
  265. cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", (table_name,))
  266. return cursor.fetchone() is not None
  267. except sqlite3.Error as e:
  268. print(f"Error checking existence of table {table_name}: {e}")
  269. return False # Assume not exists on error
  270. def _check_table_empty(self, cursor, table_name: str) -> bool:
  271. """Checks if a table is empty."""
  272. try:
  273. cursor.execute(f"SELECT COUNT(*) FROM {table_name}") # Use COUNT(*) for efficiency
  274. count = cursor.fetchone()[0]
  275. return count == 0
  276. except sqlite3.Error as e:
  277. # If error occurs (e.g., table doesn't exist after check - race condition?), treat as problematic/empty
  278. print(f"Error checking if table {table_name} is empty: {e}")
  279. return True
  280. def _process_gnss_table(self, cursor, output_path: Path) -> None:
  281. """Processes gnss_table data and writes directly to CSV."""
  282. config = self.table_config["gnss_table"]
  283. db_columns = config["db_columns"]
  284. output_columns = config["output_columns"]
  285. mapping = config["mapping"]
  286. try:
  287. cursor.execute(f"SELECT {', '.join(db_columns)} FROM gnss_table")
  288. rows = cursor.fetchall()
  289. if not rows:
  290. print("No data found in gnss_table.")
  291. return
  292. processed_data = []
  293. for row in rows:
  294. row_dict = dict(zip(db_columns, row))
  295. record = {}
  296. # Calculate simTime
  297. record["simTime"] = round(row_dict.get("second", 0) + row_dict.get("usecond", 0) / 1e6, 2)
  298. # Map other columns
  299. for out_col in output_columns:
  300. if out_col == "simTime": continue # Already handled
  301. if out_col == "playerId":
  302. record[out_col] = PLAYER_ID_EGO # Assuming GNSS is ego
  303. continue
  304. if out_col == "type":
  305. record[out_col] = DEFAULT_TYPE
  306. continue
  307. source_info = mapping.get(out_col)
  308. if source_info is None:
  309. record[out_col] = 0.0 # Or np.nan if preferred
  310. elif isinstance(source_info, tuple):
  311. # This case was only for simTime, handled above
  312. record[out_col] = 0.0
  313. else: # Direct mapping from db_columns
  314. raw_value = row_dict.get(source_info)
  315. if raw_value is not None:
  316. # Handle projection for position columns
  317. if out_col == "posX":
  318. # Assuming source_info = "latitude_dd"
  319. lat = row_dict.get(mapping["posX"])
  320. lon = row_dict.get(mapping["posY"])
  321. if lat is not None and lon is not None:
  322. proj_x, _ = self.projection(lon, lat)
  323. record[out_col] = round(proj_x, 6)
  324. else:
  325. record[out_col] = 0.0
  326. elif out_col == "posY":
  327. # Assuming source_info = "longitude_dd"
  328. lat = row_dict.get(mapping["posX"])
  329. lon = row_dict.get(mapping["posY"])
  330. if lat is not None and lon is not None:
  331. _, proj_y = self.projection(lon, lat)
  332. record[out_col] = round(proj_y, 6)
  333. else:
  334. record[out_col] = 0.0
  335. elif out_col in ["composite_v", "relative_dist"]:
  336. # Handle these based on source if available, else default
  337. record[out_col] = round(float(raw_value), 3) if source_info else 0.0
  338. else:
  339. # General case: round numeric values
  340. try:
  341. record[out_col] = round(float(raw_value), 3)
  342. except (ValueError, TypeError):
  343. record[out_col] = raw_value # Keep as is if not numeric
  344. else:
  345. record[out_col] = 0.0 # Default for missing source data
  346. processed_data.append(record)
  347. if processed_data:
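# Keep only every 4th GNSS sample (.iloc[::4]); the downsampling factor is presumably chosen to match the target output rate.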
  348. df_final = pd.DataFrame(processed_data)[output_columns].iloc[::4].reset_index(
  349. drop=True) # Ensure column order
  350. # df_final.to_csv("/home/output/V2I_CSAE53-2020_LeftTurnAssist_LST_01/ObjState_old.csv", index=False, encoding="utf-8")
  351. # print(df_final)
  352. # df_final['speedY'] = -df_final['speedY']
  353. # df_final['accelY'] = -df_final['accelY']
  354. # df_final['speedZ'] = -df_final['speedZ']
  355. # df_final['accelZ'] = -df_final['accelZ']
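# Re-number frames sequentially after downsampling and derive pitch/roll rates by finite differences.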
  356. df_final['simFrame'] = np.arange(1, len(df_final) + 1)
  357. df_final["pitch_rate"] = df_final["pitch"].diff() / df_final["simTime"].diff()
  358. df_final["roll_rate"] = df_final["roll"].diff() / df_final["simTime"].diff()
  359. # print("df_final[\"posH\"] is", df_final["posH"])
  360. df_final["posH"] = (90 - df_final["posH"])
  361. stopcar_flag = self.is_valid_interval(df_final)
  362. # print("stopcar_flag is", stopcar_flag)
  363. if stopcar_flag:
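# Keep only the moving segment: from the first sample with v > 1 to the last sample with v > 0.15 (thresholds appear to be empirical).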
  364. first_gt_1 = df_final['v'].gt(1).idxmax()
  365. last_gt_1 = df_final['v'].gt(0.15)[::-1].idxmax()
  366. result_df = df_final.loc[first_gt_1:last_gt_1].copy()
  367. result_df.to_csv(output_path, index=False, encoding="utf-8")
  368. else:
  369. df_final.to_csv(output_path, index=False, encoding="utf-8")
  370. print(f"Successfully wrote GNSS data to {output_path}")
  371. else:
  372. print("No processable records found in gnss_table.")
  373. except sqlite3.Error as e:
  374. print(f"SQL error during GNSS processing: {e}")
  375. except Exception as e:
  376. print(f"Unexpected error during GNSS processing: {e}")
  377. def _process_can_table_optimized(self, filename, cursor, output_path: Path) -> None:
  378. """Processes CAN data directly into the final merged DataFrame format."""
  379. config = self.table_config["can_table"]
  380. db_columns = config["db_columns"]
  381. mapping = config["mapping"]
  382. try:
  383. cursor.execute(f"SELECT {', '.join(db_columns)} FROM can_table")
  384. rows = cursor.fetchall()
  385. if not rows:
  386. print("No data found in can_table.")
  387. return
  388. all_records = []
  389. for row in rows:
  390. row_dict = dict(zip(db_columns, row))
  391. # Decode CAN frame if DBC is available
  392. decoded_signals = self._decode_can_frame(row_dict)
  393. # Create a unified record combining DB fields and decoded signals
  394. record = self._create_unified_can_record(row_dict, decoded_signals, mapping)
  395. if record: # Only add if parsing was successful
  396. all_records.append(record)
  397. if not all_records:
  398. print("No CAN records could be successfully processed.")
  399. return
  400. # Convert raw records to DataFrame for easier manipulation
  401. df_raw = pd.DataFrame(all_records)
  402. # Separate EGO and OBJ data based on available columns
  403. df_ego = self._extract_vehicle_data(df_raw, PLAYER_ID_EGO)
  404. # if 'pedestrian' in filename.lower():
  405. # df_obj = self._extract_vehicle_data(df_raw, PLAYER_ID_PEDESTRIAN)
  406. # else:
  407. df_obj = self._extract_vehicle_data(df_raw, PLAYER_ID_OBJ)
  408. # Project coordinates
  409. df_ego = self._project_coordinates(df_ego, 'posX', 'posY')
  410. # df_ego = self._project_coordinates(df_ego, 'posX', 'posY', 'speedX', 'speedY', 'speedZ', 'accelX', 'accelY', 'accelZ', 'posH', 'pitch', 'roll')
  411. print("df_ego is", df_ego.columns)
  412. df_obj = self._project_coordinates(df_obj, 'posX', 'posY') # Use same column names after extraction
  413. # df_obj = self._project_coordinates(df_obj, 'posX', 'posY', 'speedX', 'speedY', 'speedZ', 'accelX', 'accelY', 'accelZ', 'posH', 'pitch', 'roll') # Use same column names after extraction
  414. # Add calculated/default columns
  415. df_ego['type'] = DEFAULT_TYPE
  416. if 'pedestrian' in filename.lower():
  417. df_obj['type'] = PEDESTRIAN_TYPE
  418. else:
  419. df_obj['type'] = DEFAULT_TYPE
  420. # Note: travelDist is often calculated later or not available directly
  421. # Ensure both have the same columns before merging
  422. final_columns = self.EGO_COLS_NEW # Target columns
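# Both vehicles are reduced to the target column set and downsampled to every 4th row (the factor of 4 presumably matches the CAN-to-output rate ratio).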
  423. df_ego = df_ego.reindex(columns=final_columns).iloc[::4]
  424. df_obj = df_obj.reindex(columns=final_columns).iloc[::4]
  425. # Reindex simFrame of ego and obj
  426. df_ego['simFrame'] = np.arange(1, len(df_ego) + 1)
  427. df_obj['simFrame'] = np.arange(1, len(df_obj) + 1)
  428. # Merge EGO and OBJ dataframes
  429. df_merged = pd.concat([df_ego, df_obj], ignore_index=True)
  430. # Sort and clean up
  431. df_merged.sort_values(by=["simTime", "simFrame", "playerId"], inplace=True)
  432. df_merged.fillna(0, inplace=True)
  433. df_merged.reset_index(drop=True, inplace=True)
  434. # NaNs introduced by reindexing or missing data were zero-filled above; a different
  435. # strategy (e.g., forward fill, or leaving NaN) could be substituted there if preferred.
  437. stopcar_flag = self.is_valid_interval(df_merged)
  438. print("stopcar_flag is", stopcar_flag)
  439. if stopcar_flag:
  440. print("筛选非静止车辆数据!")
  441. first_gt_01 = df_merged[df_merged['playerId'] == 1]['v'].gt(1).idxmax()
  442. last_gt_01 = df_merged[df_merged['playerId'] == 1]['v'].gt(0.15)[::-1].idxmax()
  443. result_df = df_merged.loc[first_gt_01:last_gt_01 - 1].copy()
  444. # Save the final merged DataFrame
  445. result_df.to_csv(output_path, index=False, encoding="utf-8")
  446. else:
  447. df_merged.to_csv(output_path, index=False, encoding="utf-8")
  448. print(f"Successfully processed CAN data and wrote merged output to {output_path}")
  449. except sqlite3.Error as e:
  450. print(f"SQL error during CAN processing: {e}")
  451. except KeyError as e:
  452. print(f"Key error during CAN processing - mapping issue? Missing key: {e}")
  453. except Exception as e:
  454. print(f"Unexpected error during CAN processing: {e}")
  455. import traceback
  456. traceback.print_exc() # Print detailed traceback for debugging
  457. def _decode_can_frame(self, row_dict: Dict) -> Dict[str, Any]:
  458. """Decodes CAN frame using DBC file if available."""
  459. decoded_signals = {}
  460. if self.dbc and 'canid' in row_dict and 'frame' in row_dict and 'len' in row_dict:
  461. can_id = row_dict['canid']
  462. frame_bytes = bytes(row_dict['frame'][:row_dict['len']]) # Ensure correct length
  463. try:
  464. message_def = self.dbc.get_message_by_frame_id(can_id)
  465. decoded_signals = message_def.decode(frame_bytes, decode_choices=False,
  466. allow_truncated=True) # Allow truncated
  467. except KeyError:
  468. # Optional: print(f"Warning: CAN ID 0x{can_id:X} not found in DBC.")
  469. pass # Ignore unknown IDs silently
  470. except ValueError as e:
  471. print(
  472. f"Warning: Decoding ValueError for CAN ID 0x{can_id:X} (length {row_dict['len']}, data: {frame_bytes.hex()}): {e}")
  473. except Exception as e:
  474. print(f"Warning: Error decoding CAN ID 0x{can_id:X}: {e}")
  475. return decoded_signals
  476. def _create_unified_can_record(self, row_dict: Dict, decoded_signals: Dict, mapping: Dict) -> Optional[
  477. Dict[str, Any]]:
  478. """Creates a single record combining DB fields and decoded signals based on mapping."""
  479. record = {}
  480. try:
  481. # Handle time and frame ID first
  482. record["simTime"] = round(row_dict.get("second", 0) + row_dict.get("usecond", 0) / 1e6, 2)
  483. record["simFrame"] = row_dict.get("ID")
  484. record["canid"] = f"0x{row_dict.get('canid'):X}" # Store CAN ID if needed
  485. # Populate record using the mapping config
  486. for target_col, source_info in mapping.items():
  487. if target_col in ["simTime", "simFrame", "canid"]: continue # Already handled
  488. if isinstance(source_info, tuple): continue # Should only be time
  489. # source_info is now the signal name (or None)
  490. signal_name = source_info
  491. if signal_name and signal_name in decoded_signals:
  492. # Value from decoded CAN signal
  493. raw_value = decoded_signals[signal_name]
  494. try:
  495. # Apply scaling/offset if needed (cantools handles this)
  496. # Round appropriately, especially for floats
  497. if isinstance(raw_value, (int, float)):
  498. # Be cautious with lat/lon precision before projection
  499. if "Latitude" in target_col or "Longitude" in target_col:
  500. record[target_col] = float(raw_value) # Keep precision for projection
  501. else:
  502. record[target_col] = round(float(raw_value), 6)
  503. else:
  504. record[target_col] = raw_value # Keep non-numeric as is (e.g., enums)
  505. except (ValueError, TypeError):
  506. record[target_col] = raw_value # Assign raw value if conversion fails
  507. # If signal not found or source_info is None, leave it empty for now
  508. # Will be filled later or during DataFrame processing
  509. return record
  510. except Exception as e:
  511. print(f"Error creating unified record for row {row_dict.get('ID')}: {e}")
  512. return None
  513. def _extract_vehicle_data(self, df_raw: pd.DataFrame, player_id: int) -> pd.DataFrame:
  514. """Extracts and renames columns for a specific vehicle (EGO or OBJ)."""
  515. df_vehicle = pd.DataFrame()
  516. # df_vehicle["simTime"] = df_raw["simTime"].drop_duplicates().sort_values().reset_index(drop=True)
  517. # df_vehicle["simFrame"] = np.arange(1, len(df_vehicle) + 1)
  518. # df_vehicle["playerId"] = int(player_id)
  519. df_vehicle_temps_ego = pd.DataFrame()
  520. df_vehicle_temps_obj = pd.DataFrame()
  521. if player_id == PLAYER_ID_EGO:
  522. # Select EGO columns (not ending in _obj) + relative columns
  523. ego_cols = {target: source for target, source in self.table_config['can_table']['mapping'].items()
  524. if source and not isinstance(source, tuple) and not target.endswith('_obj')}
  525. print("ego_cols is", ego_cols)
  526. rename_map = {}
  527. select_cols_raw = []
  528. for target_col, source_info in ego_cols.items():
  529. if source_info: # Mapped signal/field name in df_raw
  530. select_cols_raw.append(target_col) # Column names in df_raw are already target names
  531. rename_map[target_col] = target_col # No rename needed here
  532. # Include relative speed and distance for ego frame
  533. relative_cols = ["composite_v", "relative_dist"]
  534. select_cols_raw.extend(relative_cols)
  535. for col in relative_cols:
  536. rename_map[col] = col
  537. # Select and rename
  538. df_vehicle_temp = df_raw[list(set(select_cols_raw) & set(df_raw.columns))] # Select available columns
  539. for col in df_vehicle_temp.columns:
  540. df_vehicle_temps_ego[col] = df_vehicle_temp[col].dropna().reset_index(drop=True)
  541. df_vehicle = pd.concat([df_vehicle, df_vehicle_temps_ego], axis=1)
  542. elif (player_id == PLAYER_ID_OBJ) or (player_id == PLAYER_ID_PEDESTRIAN):
  543. # Select OBJ columns (ending in _obj)
  544. obj_cols = {target: source for target, source in self.table_config['can_table']['mapping'].items()
  545. if source and not isinstance(source, tuple) and target.endswith('_obj')}
  546. rename_map = {}
  547. select_cols_raw = []
  548. for target_col, source_info in obj_cols.items():
  549. if source_info:
  550. select_cols_raw.append(target_col) # Original _obj column name
  551. # Map from VUT_XXX_obj -> VUT_XXX
  552. rename_map[target_col] = self.OBJ_COLS_MAPPING.get(target_col,
  553. target_col) # Rename to standard name
  554. # Select and rename
  555. df_vehicle_temp = df_raw[list(set(select_cols_raw) & set(df_raw.columns))] # Select available columns
  556. df_vehicle_temp = df_vehicle_temp.rename(columns=rename_map)  # avoid chained-assignment warnings on a sliced frame
  557. for col in df_vehicle_temp.columns:
  558. df_vehicle_temps_obj[col] = df_vehicle_temp[col].dropna().reset_index(drop=True)
  559. df_vehicle = pd.concat([df_vehicle, df_vehicle_temps_obj], axis=1)
  560. # Copy relative speed/distance from ego calculation (assuming it's relative *to* ego)
  561. if "composite_v" in df_raw.columns:
  562. df_vehicle["composite_v"] = df_raw["composite_v"].dropna().reset_index(drop=True)
  563. if "relative_dist" in df_raw.columns:
  564. df_vehicle["relative_dist"] = df_raw["relative_dist"].dropna().reset_index(drop=True)
  565. # Drop rows where essential position data might be missing after selection/renaming
  566. # Adjust required columns as necessary
  567. # required_pos = ['posX', 'posY', 'posH']
  568. # df_vehicle.dropna(subset=[col for col in required_pos if col in df_vehicle.columns], inplace=True)
  569. try:
  570. df_vehicle["simTime"] = np.round(
  571. np.linspace(df_raw["simTime"].tolist()[0], df_raw["simTime"].tolist()[0] + 0.01 * (len(df_vehicle)),
  572. len(df_vehicle)), 2)
  573. df_vehicle["simFrame"] = np.arange(1, len(df_vehicle) + 1)
  574. df_vehicle["playerId"] = int(player_id)
  575. df_vehicle['playerId'] = pd.to_numeric(df_vehicle['playerId']).astype(int)
  576. df_vehicle["pitch_rate"] = df_vehicle["pitch"].diff() / df_vehicle["simTime"].diff()
  577. df_vehicle["roll_rate"] = df_vehicle["roll"].diff() / df_vehicle["simTime"].diff()
  578. # print("df_vehicle is", df_vehicle)
  579. except ValueError as ve:
  580. print(f"ValueError in _extract_vehicle_data: {ve}")
  581. except TypeError as te:
  582. print(f"TypeError in _extract_vehicle_data: {te}")
  583. except Exception as Ee:
  584. print(f"Unexpected error in _extract_vehicle_data: {Ee}")
  585. return df_vehicle
  586. def _project_coordinates(self, df: pd.DataFrame, lat_col: str, lon_col: str) -> pd.DataFrame:
  587. """Applies UTM projection to latitude and longitude columns."""
  588. if lat_col in df.columns and lon_col in df.columns:
  589. # Ensure data is numeric and handle potential errors/missing values
  590. lat = pd.to_numeric(df[lat_col], errors='coerce')
  591. lon = pd.to_numeric(df[lon_col], errors='coerce')
  592. # speedX = pd.to_numeric(df[speedX_col], errors='coerce')
  593. # speedY = pd.to_numeric(df[speedY_col], errors='coerce')
  594. # speedZ = pd.to_numeric(0, errors='coerce')
  595. # accelX = pd.to_numeric(df[accelX_col], errors='coerce')
  596. # accelY = pd.to_numeric(df[accelY_col], errors='coerce')
  597. # accelZ = pd.to_numeric(0, errors='coerce')
  598. # posh = pd.to_numeric(df[posh_col], errors='coerce')
  599. # pitch = pd.to_numeric(df[pitch_col], errors='coerce')
  600. # roll = pd.to_numeric(df[roll_col], errors='coerce')
  601. valid_coords = lat.notna() & lon.notna()
  602. if valid_coords.any():
  603. x, y = self.projection(lon[valid_coords].values, lat[valid_coords].values)
  604. # Update DataFrame, assign NaN where original coords were invalid
  605. df.loc[valid_coords, lat_col] = np.round(x, 6) # Overwrite latitude col with X
  606. df.loc[valid_coords, lon_col] = np.round(y, 6) # Overwrite longitude col with Y
  607. df.loc[~valid_coords, [lat_col, lon_col]] = np.nan # Set invalid coords to NaN
  608. else:
  609. # No valid coordinates found, set columns to NaN or handle as needed
  610. df[lat_col] = np.nan
  611. df[lon_col] = np.nan
  612. # Rename columns AFTER projection for clarity
  613. df.rename(columns={lat_col: 'posX', lon_col: 'posY'}, inplace=True)
  614. else:
  615. # Ensure columns exist even if projection didn't happen
  616. if 'posX' not in df.columns: df['posX'] = np.nan
  617. if 'posY' not in df.columns: df['posY'] = np.nan
  618. print(f"Warning: Latitude ('{lat_col}') or Longitude ('{lon_col}') columns not found for projection.")
  619. return df
  620. def is_valid_interval(self, df, threshold=1, window_sec=0.1):
  621. '''
  622. Check whether the speeds within window_sec seconds of both the start and the end of the interval are all below threshold.
  623. '''
  624. # Get the overall time range
  625. print("Getting time range...")
  626. start_time = df['simTime'].tolist()[0]
  627. end_time = df['simTime'].tolist()[-1]
  628. # Data within the first window_sec seconds
  629. # print("Leading window data...")
  630. mask_before = (df['simTime'] >= start_time) & \
  631. (df['simTime'] < start_time + window_sec)
  632. # Data within the last window_sec seconds
  633. mask_after = (df['simTime'] < end_time) & \
  634. (df['simTime'] >= end_time - window_sec)
  635. # Both the leading and trailing windows must stay below threshold
  636. return (df.loc[mask_before, 'v'].max() < threshold) and \
  637. (df.loc[mask_after, 'v'].max() < threshold)
  638. # --- Polynomial Fitting (Largely unchanged, minor cleanup) ---
  639. class PolynomialCurvatureFitting:
  640. """Calculates curvature and its derivative using polynomial fitting."""
  641. def __init__(self, lane_map_path: Path, degree: int = 3):
  642. self.lane_map_path = Path(lane_map_path)
  643. self.degree = degree
  644. self.data = self._load_data()
  645. if self.data is not None:
  646. self.points = self.data[["centerLine_x", "centerLine_y"]].values
  647. self.x_data, self.y_data = self.points[:, 0], self.points[:, 1]
  648. else:
  649. self.points = np.empty((0, 2))
  650. self.x_data, self.y_data = np.array([]), np.array([])
  651. def _load_data(self) -> Optional[pd.DataFrame]:
  652. """Loads lane map data safely."""
  653. if not self.lane_map_path.exists() or self.lane_map_path.stat().st_size == 0:
  654. print(f"Warning: LaneMap file not found or empty: {self.lane_map_path}")
  655. return None
  656. try:
  657. return pd.read_csv(self.lane_map_path)
  658. except pd.errors.EmptyDataError:
  659. print(f"Warning: LaneMap file is empty: {self.lane_map_path}")
  660. return None
  661. except Exception as e:
  662. print(f"Error reading LaneMap file {self.lane_map_path}: {e}")
  663. return None
  664. def curvature(self, coefficients: np.ndarray, x: float) -> float:
  665. """Computes curvature of the polynomial at x."""
  666. if len(coefficients) < 3: # Need at least degree 2 for curvature
  667. return 0.0
  668. first_deriv_coeffs = np.polyder(coefficients)
  669. second_deriv_coeffs = np.polyder(first_deriv_coeffs)
  670. dy_dx = np.polyval(first_deriv_coeffs, x)
  671. d2y_dx2 = np.polyval(second_deriv_coeffs, x)
  672. denominator = (1 + dy_dx ** 2) ** 1.5
  673. return np.abs(d2y_dx2) / denominator if denominator != 0 else np.inf
  674. def curvature_derivative(self, coefficients: np.ndarray, x: float) -> float:
  675. """Computes the signed derivative of curvature with respect to x."""
  676. if len(coefficients) < 4: # Need at least degree 3 for a non-trivial curvature derivative
  677. return 0.0
  678. first_deriv_coeffs = np.polyder(coefficients)
  679. second_deriv_coeffs = np.polyder(first_deriv_coeffs)
  680. third_deriv_coeffs = np.polyder(second_deriv_coeffs)
  681. dy_dx = np.polyval(first_deriv_coeffs, x)
  682. d2y_dx2 = np.polyval(second_deriv_coeffs, x)
  683. d3y_dx3 = np.polyval(third_deriv_coeffs, x)
  684. # Signed curvature of y(x): k(x) = y'' / (1 + y'^2)^(3/2).
  685. # Differentiating with the quotient and chain rules gives
  686. # dk/dx = (y''' * (1 + y'^2) - 3 * y' * y''^2) / (1 + y'^2)^(5/2)
  687. # (this is dk/dx, not the arc-length derivative dk/ds).
  688. # The computation below keeps the two chain-rule terms explicit; it is algebraically
  689. # equivalent because numerator and denominator share a common factor (1 + y'^2)^(1/2).
  690. term1 = d3y_dx3 * (1 + dy_dx ** 2) ** (3 / 2)
  691. term2 = d2y_dx2 * (3 / 2) * (1 + dy_dx ** 2) ** (1 / 2) * (2 * dy_dx * d2y_dx2)
  692. numerator_dk_dx = term1 - term2
  693. denominator_dk_dx = (1 + dy_dx ** 2) ** 3
  694. if denominator_dk_dx == 0:
  695. return np.inf
  696. # The sign is kept (no abs()) so callers can distinguish tightening from relaxing curvature.
  697. return numerator_dk_dx / denominator_dk_dx
  704. def polynomial_fit(
  705. self, x_window: np.ndarray, y_window: np.ndarray
  706. ) -> Tuple[Optional[np.ndarray], Optional[np.poly1d]]:
  707. """Performs polynomial fitting, handling potential rank warnings."""
  708. if len(x_window) <= self.degree:
  709. print(f"Warning: Window size {len(x_window)} is <= degree {self.degree}. Cannot fit.")
  710. return None, None
  711. try:
  712. # np.RankWarning is emitted through the warnings machinery, so by default it would never
  713. # reach the except clause below; escalate it to an error for the duration of the fit.
  714. # (In newer NumPy releases the class also lives at np.exceptions.RankWarning.)
  715. with warnings.catch_warnings():
  716. warnings.simplefilter('error', np.RankWarning)
  717. coefficients = np.polyfit(x_window, y_window, self.degree)
  718. return coefficients, np.poly1d(coefficients)
  719. except np.RankWarning:
  720. print("Warning: Rank-deficient fit for this window. Check data variability.")
  721. # A lower-degree refit could be attempted here; for now, simply report failure.
  722. return None, None
  725. except Exception as e:
  726. print(f"Error during polynomial fit: {e}")
  727. return None, None
  728. def find_best_window(self, point: Tuple[float, float], window_size: int) -> Optional[int]:
  729. """Finds the start index of the window whose center is closest to the point."""
  730. if len(self.x_data) < window_size:
  731. print("Warning: Not enough data points for the specified window size.")
  732. return None
  733. x_point, y_point = point
  734. min_dist_sq = np.inf
  735. best_start_index = -1
  736. # Calculate window centers more efficiently
  737. # Use rolling mean if window_size is large, otherwise simple loop is fine
  738. num_windows = len(self.x_data) - window_size + 1
  739. if num_windows <= 0: return None
  740. for start in range(num_windows):
  741. x_center = np.mean(self.x_data[start: start + window_size])
  742. y_center = np.mean(self.y_data[start: start + window_size])
  743. dist_sq = (x_point - x_center) ** 2 + (y_point - y_center) ** 2
  744. if dist_sq < min_dist_sq:
  745. min_dist_sq = dist_sq
  746. best_start_index = start
  747. return best_start_index if best_start_index != -1 else None
  748. def find_projection(
  749. self,
  750. x_target: float,
  751. y_target: float,
  752. polynomial: np.poly1d,
  753. x_range: Tuple[float, float],
  754. search_points: int = 100, # Number of points instead of step size
  755. ) -> Optional[Tuple[float, float, float]]:
  756. """Finds the approximate closest point on the polynomial within the x_range."""
  757. if x_range[1] <= x_range[0]: return None # Invalid range
  758. x_values = np.linspace(x_range[0], x_range[1], search_points)
  759. y_values = polynomial(x_values)
  760. distances_sq = (x_target - x_values) ** 2 + (y_target - y_values) ** 2
  761. if len(distances_sq) == 0: return None
  762. min_idx = np.argmin(distances_sq)
  763. min_distance = np.sqrt(distances_sq[min_idx])
  764. return x_values[min_idx], y_values[min_idx], min_distance
  765. def fit_and_project(
  766. self, points: np.ndarray, window_size: int
  767. ) -> List[Dict[str, Any]]:
  768. """Fits polynomial and calculates curvature for each point in the input array."""
  769. if self.data is None or len(self.x_data) < window_size:
  770. print("Insufficient LaneMap data for fitting.")
  771. # Return default values for all points
  772. return [
  773. {
  774. "projection": (np.nan, np.nan),
  775. "lightMask": 0,
  776. "curvHor": np.nan,
  777. "curvHorDot": np.nan,
  778. "laneOffset": np.nan,
  779. }
  780. ] * len(points)
  781. results = []
  782. if points.ndim != 2 or points.shape[1] != 2:
  783. raise ValueError("Input points must be a 2D numpy array with shape (n, 2).")
  784. for x_target, y_target in points:
  785. result = { # Default result
  786. "projection": (np.nan, np.nan),
  787. "lightMask": 0,
  788. "curvHor": np.nan,
  789. "curvHorDot": np.nan,
  790. "laneOffset": np.nan,
  791. }
  792. best_start = self.find_best_window((x_target, y_target), window_size)
  793. if best_start is None:
  794. results.append(result)
  795. continue
  796. x_window = self.x_data[best_start: best_start + window_size]
  797. y_window = self.y_data[best_start: best_start + window_size]
  798. coefficients, polynomial = self.polynomial_fit(x_window, y_window)
  799. if coefficients is None or polynomial is None:
  800. results.append(result)
  801. continue
  802. x_min, x_max = np.min(x_window), np.max(x_window)
  803. projection_result = self.find_projection(
  804. x_target, y_target, polynomial, (x_min, x_max)
  805. )
  806. if projection_result is None:
  807. results.append(result)
  808. continue
  809. proj_x, proj_y, min_distance = projection_result
  810. curv_hor = self.curvature(coefficients, proj_x)
  811. curv_hor_dot = self.curvature_derivative(coefficients, proj_x)
  812. result = {
  813. "projection": (round(proj_x, 6), round(proj_y, 6)),
  814. "lightMask": 0,
  815. "curvHor": round(curv_hor, 6),
  816. "curvHorDot": round(curv_hor_dot, 6),
  817. "laneOffset": round(min_distance, 6),
  818. }
  819. results.append(result)
  820. return results
  821. # --- Data Quality Analyzer (Optimized) ---
  822. class DataQualityAnalyzer:
  823. """Analyzes data quality metrics, focusing on frame loss."""
  824. def __init__(self, df: Optional[pd.DataFrame] = None):
  825. self.df = df if df is not None and not df.empty else pd.DataFrame() # Ensure df is DataFrame
  826. def analyze_frame_loss(self) -> Dict[str, Any]:
  827. """Analyzes frame loss characteristics."""
  828. metrics = {
  829. "total_frames_data": 0,
  830. "unique_frames_count": 0,
  831. "min_frame": np.nan,
  832. "max_frame": np.nan,
  833. "expected_frames": 0,
  834. "dropped_frames_count": 0,
  835. "loss_rate": np.nan,
  836. "max_consecutive_loss": 0,
  837. "max_loss_start_frame": np.nan,
  838. "max_loss_end_frame": np.nan,
  839. "loss_intervals_distribution": {},
  840. "valid": False, # Indicate if analysis was possible
  841. "message": ""
  842. }
  843. if self.df.empty or 'simFrame' not in self.df.columns:
  844. metrics["message"] = "DataFrame is empty or 'simFrame' column is missing."
  845. return metrics
  846. # Drop rows with NaN simFrame and ensure integer type
  847. frames_series = self.df['simFrame'].dropna().astype(int)
  848. metrics["total_frames_data"] = len(frames_series)
  849. if frames_series.empty:
  850. metrics["message"] = "No valid 'simFrame' data found after dropping NaN."
  851. return metrics
  852. unique_frames = sorted(frames_series.unique())
  853. metrics["unique_frames_count"] = len(unique_frames)
  854. if metrics["unique_frames_count"] < 2:
  855. metrics["message"] = "Less than two unique frames; cannot analyze loss."
  856. metrics["valid"] = True # Data exists, just not enough to analyze loss
  857. if metrics["unique_frames_count"] == 1:
  858. metrics["min_frame"] = unique_frames[0]
  859. metrics["max_frame"] = unique_frames[0]
  860. metrics["expected_frames"] = 1
  861. return metrics
  862. metrics["min_frame"] = unique_frames[0]
  863. metrics["max_frame"] = unique_frames[-1]
  864. metrics["expected_frames"] = metrics["max_frame"] - metrics["min_frame"] + 1
  865. # Calculate differences between consecutive unique frames
  866. frame_diffs = np.diff(unique_frames)
  867. # Gaps are where diff > 1. The number of lost frames in a gap is diff - 1.
  868. gaps = frame_diffs[frame_diffs > 1]
  869. lost_frames_in_gaps = gaps - 1
  870. metrics["dropped_frames_count"] = int(lost_frames_in_gaps.sum())
  871. if metrics["expected_frames"] > 0:
  872. metrics["loss_rate"] = round(metrics["dropped_frames_count"] / metrics["expected_frames"], 4)
  873. else:
  874. metrics["loss_rate"] = 0.0 # Avoid division by zero if min_frame == max_frame (already handled)
  875. if len(lost_frames_in_gaps) > 0:
  876. metrics["max_consecutive_loss"] = int(lost_frames_in_gaps.max())
  877. # Find where the max loss occurred
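# A jump of (k + 1) between consecutive unique frame numbers means k frames were dropped, hence the "+ 1" below.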
  878. max_loss_indices = np.where(frame_diffs == metrics["max_consecutive_loss"] + 1)[0]
  879. # Get the first occurrence start/end frames
  880. max_loss_idx = max_loss_indices[0]
  881. metrics["max_loss_start_frame"] = unique_frames[max_loss_idx]
  882. metrics["max_loss_end_frame"] = unique_frames[max_loss_idx + 1]
  883. # Count distribution of loss interval lengths
  884. loss_counts = Counter(lost_frames_in_gaps)
  885. metrics["loss_intervals_distribution"] = {int(k): int(v) for k, v in loss_counts.items()}
  886. else:
  887. metrics["max_consecutive_loss"] = 0
  888. metrics["valid"] = True
  889. metrics["message"] = "Frame loss analysis complete."
  890. return metrics
  891. def get_all_csv_files(path: Path) -> List[Path]:
  892. """Gets all CSV files in path, excluding specific ones."""
  893. excluded_files = {OUTPUT_CSV_LANEMAP, ROADMARK_CSV}
  894. return [
  895. file_path
  896. for file_path in path.rglob("*.csv") # Recursive search
  897. if file_path.is_file() and file_path.name not in excluded_files
  898. ]
  899. def run_frame_loss_analysis_on_folder(path: Path) -> Dict[str, Dict[str, Any]]:
  900. """Runs frame loss analysis on all relevant CSV files in a folder."""
  901. analysis_results = {}
  902. csv_files = get_all_csv_files(path)
  903. if not csv_files:
  904. print(f"No relevant CSV files found in {path}")
  905. return analysis_results
  906. for file_path in csv_files:
  907. file_name = file_path.name
  908. if file_name in {OUTPUT_CSV_FUNCTION, OUTPUT_CSV_OBU}: # Skip specific files if needed
  909. print(f"Skipping frame analysis for: {file_name}")
  910. continue
  911. print(f"Analyzing frame loss for: {file_name}")
  912. if file_path.stat().st_size == 0:
  913. print(f"File {file_name} is empty. Skipping analysis.")
  914. analysis_results[file_name] = {"valid": False, "message": "File is empty."}
  915. continue
  916. try:
  917. # Read only necessary column if possible, handle errors
  918. df = pd.read_csv(file_path, usecols=['simFrame'], index_col=False,
  919. on_bad_lines='warn') # 'warn' or 'skip'
  920. analyzer = DataQualityAnalyzer(df)
  921. metrics = analyzer.analyze_frame_loss()
  922. analysis_results[file_name] = metrics
  923. # Optionally print a summary here
  924. if metrics["valid"]:
  925. print(f" Loss Rate: {metrics.get('loss_rate', np.nan) * 100:.2f}%, "
  926. f"Dropped: {metrics.get('dropped_frames_count', 'N/A')}, "
  927. f"Max Gap: {metrics.get('max_consecutive_loss', 'N/A')}")
  928. else:
  929. print(f" Analysis failed: {metrics.get('message')}")
  930. except pd.errors.EmptyDataError:
  931. print(f"File {file_name} contains no data after reading.")
  932. analysis_results[file_name] = {"valid": False, "message": "Empty data after read."}
  933. except ValueError as ve: # Handle case where simFrame might not be present
  934. print(f"ValueError processing file {file_name}: {ve}. Is 'simFrame' column present?")
  935. analysis_results[file_name] = {"valid": False, "message": f"ValueError: {ve}"}
  936. except Exception as e:
  937. print(f"Unexpected error processing file {file_name}: {e}")
  938. analysis_results[file_name] = {"valid": False, "message": f"Unexpected error: {e}"}
  939. return analysis_results
  940. def data_precheck(output_dir: Path, max_allowed_loss_rate: float = 0.20) -> bool:
  941. """Checks data quality, focusing on frame loss rate."""
  942. print(f"--- Running Data Quality Precheck on: {output_dir} ---")
  943. if not output_dir.exists() or not output_dir.is_dir():
  944. print(f"Error: Output directory does not exist: {output_dir}")
  945. return False
  946. try:
  947. frame_loss_results = run_frame_loss_analysis_on_folder(output_dir)
  948. except Exception as e:
  949. print(f"Critical error during frame loss analysis: {e}")
  950. return False # Treat critical error as failure
  951. if not frame_loss_results:
  952. print("Warning: No files were analyzed for frame loss.")
  953. # Decide if this is a failure or just a warning. Let's treat it as OK for now.
  954. return True
  955. all_checks_passed = True
  956. for file_name, metrics in frame_loss_results.items():
  957. if metrics.get("valid", False):
  958. loss_rate = metrics.get("loss_rate", np.nan)
  959. if pd.isna(loss_rate):
  960. print(f" {file_name}: Loss rate could not be calculated.")
  961. # Decide if NaN loss rate is acceptable.
  962. elif loss_rate > max_allowed_loss_rate:
  963. print(
  964. f" FAIL: {file_name} - Frame loss rate ({loss_rate * 100:.2f}%) exceeds threshold ({max_allowed_loss_rate * 100:.1f}%).")
  965. all_checks_passed = False
  966. else:
  967. print(f" PASS: {file_name} - Frame loss rate ({loss_rate * 100:.2f}%) is acceptable.")
  968. else:
  969. print(
  970. f" WARN: {file_name} - Frame loss analysis could not be completed ({metrics.get('message', 'Unknown reason')}).")
  971. # Decide if inability to analyze is a failure. Let's allow it for now.
  972. print(f"--- Data Quality Precheck {'PASSED' if all_checks_passed else 'FAILED'} ---")
  973. return all_checks_passed
  974. # --- Final Preprocessing Step ---
  975. class FinalDataProcessor:
  976. """Merges processed CSVs, adds curvature, and handles traffic lights."""
  977. def __init__(self, config: Config):
  978. self.config = config
  979. self.output_dir = config.output_dir
  980. def _axis_to_ENU(self, speedX, speedY, speedZ, accelX, accelY, accelZ, posH, pitch, roll):
  981. posh_ENU = posH % 360
  982. posh_ENU = posh_ENU * np.pi / 180
  983. pitch = pitch * np.pi / 180
  984. roll = roll * np.pi / 180
  985. east_speedX, north_speedY, north_speedZ = [], [], []
  986. east_accelX, north_accelY, north_accelZ = [], [], []
  987. for i in range(len(posH)):
  988. sy = np.sin(posh_ENU[i])
  989. cy = np.cos(posh_ENU[i])
  990. cp = np.cos(pitch[i])
  991. sp = np.sin(pitch[i])
  992. cr = np.cos(roll[i])
  993. sr = np.sin(roll[i])
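# Rotation matrix assembled from yaw (posh_ENU), pitch and roll; its pseudo-inverse is applied below to map the measured velocity/acceleration components into ENU (Euler convention assumed).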
  994. trametrix = np.array([[sy * cp, sy * sp * sr - cy * cr, sy * sp * cr + cy * sr],
  995. [cy * cp, cy * sp * sr + sy * cr, cy * sp * cr - sy * sr], [-sp, cp * sr, cp * cr]])
  996. # trametrix = np.array([[sy, cy], [-cy, sy]])
  997. east_speedX_i, north_speedY_i, north_speedZ_i = np.linalg.pinv(trametrix) @ np.array(
  998. [speedX[i], speedY[i], speedZ[i]])
  999. # east_speedX_i, north_speedY_i = np.linalg.pinv(trametrix) @ np.array([speedX[i], speedY[i]])
  1000. east_accelX_i, north_accelY_i, north_accelZ_i = np.linalg.pinv(trametrix) @ np.array(
  1001. [accelX[i], accelY[i], accelZ[i]])
  1002. # east_accelX_i, north_accelY_i = np.linalg.pinv(trametrix) @ np.array([accelX[i], accelY[i]])
  1003. east_speedX.append(east_speedX_i)
  1004. north_speedY.append(north_speedY_i)
  1005. north_speedZ.append(north_speedZ_i)
  1006. east_accelX.append(east_accelX_i)
  1007. north_accelY.append(north_accelY_i)
  1008. north_accelZ.append(north_accelZ_i)
  1009. return east_speedX, north_speedY, speedZ, east_accelX, north_accelY, accelZ  # Note: the vertical (Z) components are returned unrotated; the rotated Z values computed above are discarded
  1010. # return east_speedX, north_speedY, east_accelX, north_accelY
  1011. def process(self) -> bool:
  1012. """执行最终数据合并和处理步骤。"""
  1013. print("--- Starting Final Data Processing ---")
  1014. try:
  1015. # 1. Load main object state data
  1016. obj_state_path = self.output_dir / OUTPUT_CSV_OBJSTATE
  1017. lane_map_path = self.output_dir / OUTPUT_CSV_LANEMAP
  1018. if not obj_state_path.exists():
  1019. print(f"Error: Required input file not found: {obj_state_path}")
  1020. return False
  1021. # 处理交通灯数据并保存
  1022. df_traffic = self._process_trafficlight_data()
  1023. if not df_traffic.empty:
  1024. traffic_csv_path = self.output_dir / "Traffic.csv"
  1025. df_traffic.to_csv(traffic_csv_path, index=False, float_format='%.6f')
  1026. print(f"Successfully created traffic light data file: {traffic_csv_path}")
  1027. # Load and process data
  1028. df_object = pd.read_csv(obj_state_path, dtype={"simTime": float}, low_memory=False)
  1029. # 坐标转换
  1030. speedX = df_object['speedX']
  1031. speedY = df_object['speedY']
  1032. speedZ = df_object['speedZ']
  1033. accelX = df_object['accelX']
  1034. accelY = df_object['accelY']
  1035. accelZ = df_object['accelZ']
  1036. posH = df_object['posH']
  1037. pitch = df_object['pitch']
  1038. roll = df_object['roll']
  1039. east_speedX, north_speedY, north_speedZ, east_accelX, north_accelY, north_accelZ = self._axis_to_ENU(speedX,
  1040. speedY,
  1041. speedZ,
  1042. accelX,
  1043. accelY,
  1044. accelZ,
  1045. posH,
  1046. pitch,
  1047. roll)
  1048. # east_speedX, north_speedY, east_accelX, north_accelY = self._axis_to_ENU(speedX, speedY, speedZ, accelX, accelY, accelZ, posH, pitch, roll)
  1049. df_object['speedX'] = east_speedX
  1050. df_object['speedY'] = north_speedY
  1051. df_object['speedZ'] = north_speedZ
  1052. df_object['accelX'] = east_accelX
  1053. df_object['accelY'] = north_accelY
  1054. df_object['accelZ'] = north_accelZ
  1055. df_ego = df_object[df_object["playerId"] == 1]
  1056. points = df_ego[["posX", "posY"]].values
  1057. window_size = 4
  1058. fitting_instance = PolynomialCurvatureFitting(lane_map_path)
  1059. result_list = fitting_instance.fit_and_project(points, window_size)
  1060. curvHor_values = [result["curvHor"] for result in result_list]
  1061. curvature_change_value = [result["curvHorDot"] for result in result_list]
  1062. min_distance = [result["laneOffset"] for result in result_list]
  1063. indices = df_object[df_object["playerId"] == 1].index
  1064. if len(indices) == len(curvHor_values):
  1065. df_object.loc[indices, "lightMask"] = 0
  1066. df_object.loc[indices, "curvHor"] = curvHor_values
  1067. df_object.loc[indices, "curvHorDot"] = curvature_change_value
  1068. df_object.loc[indices, "laneOffset"] = min_distance
  1069. else:
  1070. print("计算值的长度与 playerId == 1 的行数不匹配!")
  1071. # Process and merge data
  1072. df_merged = self._merge_optional_data(df_object)
  1073. # df_merged[['speedH', 'accelX']] = -df_merged[['speedH', 'accelX']]
  1074. # Save final merged file directly to output directory
  1075. merged_csv_path = self.output_dir / OUTPUT_CSV_MERGED
  1076. print(f'merged_csv_path:{merged_csv_path}')
  1077. df_merged.to_csv(merged_csv_path, index=False, float_format='%.6f')
  1078. print(f"Successfully created final merged file: {merged_csv_path}")
  1079. # Clean up intermediate files
  1080. # if obj_state_path.exists():
  1081. # obj_state_path.unlink()
  1082. print("--- Final Data Processing Finished ---")
  1083. return True
  1084. except Exception as e:
  1085. print(f"An unexpected error occurred during final data processing: {e}")
  1086. import traceback
  1087. traceback.print_exc()
  1088. return False
    def _merge_optional_data(self, df_object: pd.DataFrame) -> pd.DataFrame:
        """Load and merge the optional data sources into the object state frame."""
        df_merged = df_object.copy()

        # Helper that checks for and removes duplicated columns (merge suffixes)
        def clean_duplicate_columns(df):
            # Find columns carrying an _x or _y suffix
            duplicate_cols = []
            base_cols = {}
            # Print the column names before cleaning
            print(f"Columns before duplicate cleanup: {df.columns.tolist()}")
            for col in df.columns:
                if col.endswith('_x') or col.endswith('_y'):
                    base_name = col[:-2]  # strip the suffix
                    if base_name not in base_cols:
                        base_cols[base_name] = []
                    base_cols[base_name].append(col)
            # For each group of duplicates, keep only one column if the data is identical
            for base_name, cols in base_cols.items():
                if len(cols) > 1:
                    # Check whether these columns hold identical data
                    is_identical = True
                    first_col = cols[0]
                    for col in cols[1:]:
                        if not df[first_col].equals(df[col]):
                            is_identical = False
                            break
                    if is_identical:
                        # Identical data: keep the first column, renamed to the base name
                        df = df.rename(columns={first_col: base_name})
                        # Drop the remaining duplicates
                        for col in cols[1:]:
                            duplicate_cols.append(col)
                        print(f"Columns {cols} hold identical data, kept as {base_name}")
                    else:
                        print(f"Columns {cols} differ, keeping all of them")
                        # For simTime-related columns, make sure one of them is kept
                        if base_name == 'simTime' and 'simTime' not in df.columns:
                            df = df.rename(columns={cols[0]: 'simTime'})
                            print(f"Renamed {cols[0]} to simTime")
                            # Drop the other simTime-related columns
                            for col in cols[1:]:
                                duplicate_cols.append(col)
            # Drop the duplicated columns
            if duplicate_cols:
                # Make sure the simTime column itself is not dropped
                if 'simTime' not in df.columns and any(col.startswith('simTime_') for col in duplicate_cols):
                    # Keep one simTime-related column
                    for col in duplicate_cols[:]:
                        if col.startswith('simTime_'):
                            df = df.rename(columns={col: 'simTime'})
                            duplicate_cols.remove(col)
                            print(f"Renamed {col} to simTime")
                            break
                df = df.drop(columns=duplicate_cols)
                print(f"Dropped duplicated columns: {duplicate_cols}")
            # Print the column names after cleaning
            print(f"Columns after duplicate cleanup: {df.columns.tolist()}")
            return df
        # --- Merge EgoMap ---
        egomap_path = self.output_dir / OUTPUT_CSV_EGOMAP
        merged_egomap_path = self.output_dir / MERGED_CSV_EGOMAP
        if egomap_path.exists() and egomap_path.stat().st_size > 0:
            try:
                df_ego = pd.read_csv(egomap_path, dtype={"simTime": float})
                ego_column = ['posX', 'posY', 'posH']
                ego_new_column = ['posX_map', 'posY_map', 'posH_map']
                df_ego = df_ego.rename(columns=dict(zip(ego_column, ego_new_column)))
                # Drop the simFrame column; the main data's simFrame is used instead
                if 'simFrame' in df_ego.columns:
                    df_ego = df_ego.drop(columns=['simFrame'])
                # Print the column names before merging
                print(f"df_merged columns before EgoMap merge: {df_merged.columns.tolist()}")
                print(f"df_ego columns: {df_ego.columns.tolist()}")
                # Sort by time and player ID
                df_ego.sort_values(['simTime', 'playerId'], inplace=True)
                df_merged.sort_values(['simTime', 'playerId'], inplace=True)
                # Nearest-match merge with merge_asof, excluding simFrame
                # df_merged = pd.merge_asof(
                #     df_merged,
                #     df_ego,
                #     on='simTime',
                #     by='playerId',
                #     direction='nearest',
                #     tolerance=0.01  # 10ms tolerance
                # )
                df_merged = pd.merge(
                    df_merged,
                    df_ego,
                    how='left',
                    on='simTime'
                )
                # Print the column names after merging
                print(f"df_merged columns after EgoMap merge: {df_merged.columns.tolist()}")
                # Make sure the simTime column exists
                if 'simTime' not in df_merged.columns:
                    if 'simTime_x' in df_merged.columns:
                        df_merged.rename(columns={'simTime_x': 'simTime'}, inplace=True)
                        print("Renamed simTime_x to simTime")
                    else:
                        print("Warning: simTime column not found after merging EgoMap!")
                df_merged = df_merged.drop(columns=['posX_map', 'posY_map', 'posH_map', 'stateMask'])
                df_merged.to_csv(merged_egomap_path, index=False, float_format='%.6f')
                print("EgoMap data merged.")
            except Exception as e:
                print(f"Warning: Could not merge EgoMap data from {egomap_path}: {e}")
                import traceback
                traceback.print_exc()
        # Clean up possible duplicated column names first
        df_merged = clean_duplicate_columns(df_merged)
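        # Note: the merge above joins EgoMap rows on simTime only (not playerId), so the
        # ego-map columns are attached to every row sharing that timestamp regardless of
        # player; the commented-out merge_asof variant (keyed by playerId with a 10 ms
        # tolerance) is kept for reference.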
        # --- Merge hd_lane.csv / hd_road.csv ---
        current_file_path = os.path.abspath(__file__)
        # --- Merge Traffic ---
        traffic_path = self.output_dir / "Traffic.csv"
        if traffic_path.exists() and traffic_path.stat().st_size > 0:
            try:
                df_traffic = pd.read_csv(traffic_path, dtype={"simTime": float}, low_memory=False).drop_duplicates()
                # Drop the simFrame column
                if 'simFrame' in df_traffic.columns:
                    df_traffic = df_traffic.drop(columns=['simFrame'])

                # Determine the travel direction from the vehicle heading and
                # select the traffic lights that match it.
                def get_direction_from_heading(heading):
                    # Normalise the angle into the -180 to 180 degree range
                    heading = heading % 360
                    if heading > 180:
                        heading -= 360
                    # Classify into north (N), east (E), south (S) or west (W)
                    if -45 <= heading <= 45:  # northbound
                        return 'N'
                    elif 45 < heading <= 135:  # eastbound
                        return 'E'
                    elif -135 <= heading < -45:  # westbound
                        return 'W'
                    else:  # southbound (135 < heading <= 180 or -180 <= heading < -135)
                        return 'S'

                # Use posH if present; fall back to posH_x otherwise
                heading_col = 'posH'
                if heading_col not in df_merged.columns:
                    if 'posH_x' in df_merged.columns:
                        heading_col = 'posH_x'
                        print(f"Using {heading_col} instead of posH")
                    else:
                        print("Warning: heading column posH or posH_x not found")
                        return df_merged
                # Add the direction column
                df_merged['vehicle_direction'] = df_merged[heading_col].apply(get_direction_from_heading)
                # Map phaseId to a travel direction
                phase_to_direction = {
                    1: 'S',  # southbound through
                    2: 'W',  # westbound through
                    3: 'N',  # northbound through
                    4: 'E',  # eastbound through
                    5: 'S',  # south pedestrian
                    6: 'W',  # west pedestrian
                    7: 'S',  # south left turn
                    8: 'W',  # west left turn
                    9: 'N',  # north left turn
                    10: 'E',  # east left turn
                    11: 'N',  # north pedestrian
                    12: 'E',  # east pedestrian
                    13: 'S',  # south right turn
                    14: 'W',  # west right turn
                    15: 'N',  # north right turn
                    16: 'E'  # east right turn
                }
                # Map trafficlight_id to a travel direction
                trafficlight_to_direction = {
                    # Traffic lights for south-to-north traffic
                    # 48100017: 'S',
                    # 48100038: 'S',
                    # 48100043: 'S',
                    # 48100030: 'S',
                    48100017: 'N',
                    48100038: 'N',
                    48100043: 'N',
                    48100030: 'N',
                    # Traffic lights for west-to-east traffic
                    # 48100021: 'W',
                    # 48100039: 'W',
                    48100021: 'E',
                    48100039: 'E',
                    # Traffic lights for east-to-west traffic
                    # 48100041: 'E',
                    # 48100019: 'E',
                    48100041: 'W',
                    48100019: 'W',
                    # Traffic lights for north-to-south traffic
                    # 48100033: 'N',
                    # 48100018: 'N',
                    # 48100022: 'N'
                    48100033: 'S',
                    48100018: 'S',
                    48100022: 'S'
                }
                # Add a rounded time column used for merging
                df_traffic['time'] = df_traffic['simTime'].round(2).astype(float)
                # Check that df_merged has a simTime column
                if 'simTime' not in df_merged.columns:
                    print("Warning: simTime column not found in df_merged before merging Traffic!")
                    # Try simTime_x or other candidate columns
                    if 'simTime_x' in df_merged.columns:
                        df_merged.rename(columns={'simTime_x': 'simTime'}, inplace=True)
                        print("Renamed simTime_x to simTime")
                    else:
                        print("Fatal: no simTime-related column found, cannot continue merging!")
                        return df_merged
                df_merged['time'] = df_merged['simTime'].round(2).astype(float)
                tree = cKDTree(df_traffic[['simTime']])
                # Query the nearest traffic timestamp for each timestamp in df_merged
                distances, indices = tree.query(df_merged[['simTime']], k=1)
                # df_merged['time1'] = df_merged['simTime'].round(0).astype(float)
                # df_traffic1 = df_traffic.rename(columns={'simTime': 'simTime1'})
                # df_traffic['time1'] = df_traffic['time'].round(0).astype(float)
                # Merge the Traffic data
                df_merged['matched_time'] = df_traffic.iloc[indices.flatten()]['simTime'].values
                # Merge the DataFrames
                df_merged = pd.merge(df_merged, df_traffic, left_on='matched_time', right_on='simTime', how='left')
                # df_merged = pd.merge(df_merged, df_traffic, on=["time1"], how="left")
                # df_merged = df_merged.drop(columns=['time1'])
                # Clean up possible duplicated column names again
                df_merged = clean_duplicate_columns(df_merged)
                df_merged = df_merged.drop(columns=['time_x', 'time_y', 'matched_time'])
                # Check whether the trafficlight_id column exists
                trafficlight_col = 'trafficlight_id'
                if trafficlight_col not in df_merged.columns:
                    if 'trafficlight_id_x' in df_merged.columns:
                        trafficlight_col = 'trafficlight_id_x'
                        print(f"Using {trafficlight_col} instead of trafficlight_id")
                    else:
                        print("Warning: traffic light ID column trafficlight_id or trafficlight_id_x not found")

                # Keep only the traffic lights relevant to the vehicle's travel direction
                def filter_relevant_traffic_light(row):
                    if 'phaseId' not in row or pd.isna(row['phaseId']):
                        return np.nan
                    # Direction associated with this phaseId
                    phase_id = int(row['phaseId']) if not pd.isna(row['phaseId']) else None
                    if phase_id is None:
                        return np.nan
                    phase_direction = phase_to_direction.get(phase_id, None)
                    # If the phase direction matches the vehicle direction
                    if phase_direction == row['vehicle_direction']:
                        # Collect all traffic light IDs for that direction
                        relevant_ids = [tid for tid, direction in trafficlight_to_direction.items()
                                        if direction == phase_direction]
                        # Keep trafficlight_id when it is present in EgoMap and its direction matches
                        # if trafficlight_col in row and not pd.isna(row[trafficlight_col]) and row[trafficlight_col] in relevant_ids:
                        if trafficlight_col in row:
                            if not pd.isna(row[trafficlight_col]):
                                if row[trafficlight_col] in relevant_ids:
                                    return row[trafficlight_col]
                    return np.nan

                # Apply the filter
                df_merged['filtered_trafficlight_id'] = df_merged.apply(filter_relevant_traffic_light, axis=1)
                # Clean up temporary columns
                # print(f"df_merged columns before dropping time: {df_merged.columns.tolist()}")
                # df_merged.drop(columns=['time'], inplace=True)
                # print(f"df_merged columns after dropping time: {df_merged.columns.tolist()}")
                # Make sure the simTime column exists
                if 'simTime' not in df_merged.columns:
                    if 'simTime_x' in df_merged.columns:
                        df_merged.rename(columns={'simTime_x': 'simTime'}, inplace=True)
                        print("Renamed simTime_x to simTime")
                    else:
                        print("Warning: simTime column not found after processing Traffic data!")
                print("Traffic light data merged and filtered.")
            except Exception as e:
                print(f"Warning: Could not merge Traffic data from {traffic_path}: {e}")
                import traceback
                traceback.print_exc()
        else:
            print("Traffic data not found or empty, skipping merge.")
        # --- Merge Function ---
        function_path = self.output_dir / OUTPUT_CSV_FUNCTION
        if function_path.exists() and function_path.stat().st_size > 0:
            try:
                # Debug information
                print(f"Reading Function data: {function_path}")
                df_function = pd.read_csv(function_path, low_memory=False).drop_duplicates()
                print(f"Function data columns: {df_function.columns.tolist()}")
                # Drop the simFrame column
                if 'simFrame' in df_function.columns:
                    df_function = df_function.drop(columns=['simFrame'])
                # Make sure the simTime column exists and is numeric
                if 'simTime' in df_function.columns:
                    # Convert simTime to float safely
                    try:
                        df_function['simTime'] = pd.to_numeric(df_function['simTime'], errors='coerce')
                        df_function = df_function.dropna(subset=['simTime'])  # drop rows that could not be converted
                        df_function['time'] = df_function['simTime'].round(2)
                        # Handle df_merged's simTime column safely
                        if 'simTime' in df_merged.columns:
                            print(f"dtype of df_merged['simTime']: {df_merged['simTime'].dtype}")
                            print(f"First 5 values of df_merged['simTime']: {df_merged['simTime'].head().tolist()}")
                            df_merged['time'] = pd.to_numeric(df_merged['simTime'], errors='coerce').round(2)
                            # Drop NaN values in the time column
                            nan_count = df_merged['time'].isna().sum()
                            if nan_count > 0:
                                print(f"Warning: {nan_count} NaN values after conversion, dropping those rows")
                                df_merged = df_merged.dropna(subset=['time'])
                            # Make sure the time columns of both DataFrames share the same dtype
                            df_function['time'] = df_function['time'].astype(float)
                            df_merged['time'] = df_merged['time'].astype(float)
                            common_cols = list(set(df_merged.columns) & set(df_function.columns) - {'time'})
                            df_function.drop(columns=common_cols, inplace=True, errors='ignore')
                            # Merge the data
                            df_merged = pd.merge(df_merged, df_function, on=["time"], how="left")
                            df_merged.drop(columns=['time'], inplace=True)
                            print("Function data merged successfully.")
                        else:
                            print("Warning: 'simTime' column not found in df_merged, cannot merge Function data.")
                            # Print all column names for debugging
                            print(f"All df_merged columns: {df_merged.columns.tolist()}")
                    except Exception as e:
                        print(f"Warning: error while processing the simTime column of Function.csv: {e}")
                        import traceback
                        traceback.print_exc()
                else:
                    print(f"Warning: 'simTime' column not found in Function.csv. Available columns: {df_function.columns.tolist()}")
            except Exception as e:
                print(f"Warning: could not merge Function data: {e}")
                import traceback
                traceback.print_exc()
        else:
            print(f"Function data file missing or empty: {function_path}")
        # --- Merge OBU ---
        obu_path = self.output_dir / OUTPUT_CSV_OBU
        if obu_path.exists() and obu_path.stat().st_size > 0:
            try:
                # Debug information
                print(f"Reading OBU data: {obu_path}")
                df_obu = pd.read_csv(obu_path, low_memory=False).drop_duplicates()
                print(f"OBU data columns: {df_obu.columns.tolist()}")
                # Drop the simFrame column
                if 'simFrame' in df_obu.columns:
                    df_obu = df_obu.drop(columns=['simFrame'])
                # Make sure the simTime column exists and is numeric
                if 'simTime' in df_obu.columns:
                    # Convert simTime to float safely
                    try:
                        df_obu['simTime'] = pd.to_numeric(df_obu['simTime'], errors='coerce')
                        df_obu = df_obu.dropna(subset=['simTime'])  # drop rows that could not be converted
                        df_obu['time'] = df_obu['simTime'].round(2)
                        # Handle df_merged's simTime column safely
                        if 'simTime' in df_merged.columns:
                            print(f"dtype of df_merged['simTime'] before OBU merge: {df_merged['simTime'].dtype}")
                            print(f"First 5 values of df_merged['simTime'] before OBU merge: {df_merged['simTime'].head().tolist()}")
                            df_merged['time'] = pd.to_numeric(df_merged['simTime'], errors='coerce').round(2)
                            # Drop NaN values in the time column
                            nan_count = df_merged['time'].isna().sum()
                            if nan_count > 0:
                                print(f"Warning: {nan_count} NaN values after conversion, dropping those rows")
                                df_merged = df_merged.dropna(subset=['time'])
                            # Make sure the time columns of both DataFrames share the same dtype
                            df_obu['time'] = df_obu['time'].astype(float)
                            df_merged['time'] = df_merged['time'].astype(float)
                            common_cols = list(set(df_merged.columns) & set(df_obu.columns) - {'time'})
                            df_obu.drop(columns=common_cols, inplace=True, errors='ignore')
                            # Merge the data
                            df_merged = pd.merge(df_merged, df_obu, on=["time"], how="left")
                            df_merged.drop(columns=['time'], inplace=True)
                            print("OBU data merged successfully.")
                        else:
                            print("Warning: 'simTime' column not found in df_merged, cannot merge OBU data.")
                            # Print all column names for debugging
                            print(f"All df_merged columns: {df_merged.columns.tolist()}")
                    except Exception as e:
                        print(f"Warning: error while processing the simTime column of OBUdata.csv: {e}")
                        import traceback
                        traceback.print_exc()
                else:
                    print(f"Warning: 'simTime' column not found in OBUdata.csv. Available columns: {df_obu.columns.tolist()}")
            except Exception as e:
                print(f"Warning: could not merge OBU data: {e}")
                import traceback
                traceback.print_exc()
        else:
            print(f"OBU data file missing or empty: {obu_path}")
        # After all merges are done, clean duplicated columns one more time
        df_merged = clean_duplicate_columns(df_merged)
        return df_merged
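
    # Expected traffic-light JSON shape, inferred from the parser below (the field
    # names are the ones accessed via .get(); the surrounding layout is an assumption):
    #   [{"intersections": [{"intersectionTimestamp": <ms>,
    #                        "phases": [{"phaseId": <int>,
    #                                    "phaseStates": [{"startTime": 0, "light": <state>}]}]}]}]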
    def _process_trafficlight_data(self) -> pd.DataFrame:
        """Processes traffic light JSON data if available."""
        # Check if json_path is provided and exists
        if not self.config.json_path:
            print("No traffic light JSON file provided. Skipping traffic light processing.")
            return pd.DataFrame()
        if not self.config.json_path.exists():
            print("Traffic light JSON file not found. Skipping traffic light processing.")
            return pd.DataFrame()
        print(f"Processing traffic light data from: {self.config.json_path}")
        valid_trafficlights = []
        try:
            with open(self.config.json_path, 'r', encoding='utf-8') as f:
                # Read the whole file, assuming it's a JSON array or JSON objects per line
                try:
                    # Attempt to read as a single JSON array
                    raw_data = json.load(f)
                    if not isinstance(raw_data, list):
                        raw_data = [raw_data]  # Handle case of single JSON object
                except json.JSONDecodeError:
                    # If that fails, assume one JSON object per line
                    f.seek(0)  # Reset file pointer
                    raw_data = [json.loads(line) for line in f if line.strip()]
            for entry in raw_data:
                # Normalize entry if it's a string containing JSON
                if isinstance(entry, str):
                    try:
                        entry = json.loads(entry)
                    except json.JSONDecodeError:
                        print(f"Warning: Skipping invalid JSON string in traffic light data: {entry[:100]}...")
                        continue
                # Safely extract data using .get()
                intersections = entry.get('intersections', [])
                if not isinstance(intersections, list):
                    continue  # Skip if not a list
                for intersection in intersections:
                    if not isinstance(intersection, dict):
                        continue
                    timestamp_ms = intersection.get('intersectionTimestamp', 0)
                    sim_time = round(int(timestamp_ms) / 1000, 2)  # Convert ms to s and round
                    phases = intersection.get('phases', [])
                    if not isinstance(phases, list):
                        continue
                    for phase in phases:
                        if not isinstance(phase, dict):
                            continue
                        phase_id = phase.get('phaseId', 0)
                        phase_states = phase.get('phaseStates', [])
                        if not isinstance(phase_states, list):
                            continue
                        for phase_state in phase_states:
                            if not isinstance(phase_state, dict):
                                continue
                            # Check for startTime == 0 as per original logic
                            if phase_state.get('startTime') == 0:
                                light_state = phase_state.get('light', 0)  # Extract light state
                                data = {
                                    'simTime': sim_time,
                                    'phaseId': phase_id,
                                    'stateMask': light_state,
                                    # Add playerId for merging - assume applies to ego
                                    'playerId': PLAYER_ID_EGO
                                }
                                valid_trafficlights.append(data)
            if not valid_trafficlights:
                print("No valid traffic light states (with startTime=0) found in JSON.")
                return pd.DataFrame()
            df_trafficlights = pd.DataFrame(valid_trafficlights)
            # Drop duplicates based on relevant fields
            df_trafficlights.drop_duplicates(subset=['simTime', 'playerId', 'phaseId', 'stateMask'], keep='first',
                                             inplace=True)
            print(f"Processed {len(df_trafficlights)} unique traffic light state entries.")
            # Sort by time ascending - fixes the reversed-order issue
            df_trafficlights = df_trafficlights.sort_values('simTime', ascending=True)
            # Debug information
            print(f"Traffic light data time range: {df_trafficlights['simTime'].min()} to {df_trafficlights['simTime'].max()}")
            print(f"First 5 traffic light timestamps: {df_trafficlights['simTime'].head().tolist()}")
            print(f"Last 5 traffic light timestamps: {df_trafficlights['simTime'].tail().tolist()}")
            return df_trafficlights
        except json.JSONDecodeError as e:
            print(f"Error decoding traffic light JSON file {self.config.json_path}: {e}")
            return pd.DataFrame()
        except Exception as e:
            print(f"Unexpected error processing traffic light data: {e}")
            return pd.DataFrame()


# --- Rosbag Processing ---
class RosbagProcessor:
    """Extracts data from HMIdata files within a ZIP archive."""

    def __init__(self, config: Config):
        self.config = config
        self.output_dir = config.output_dir

    def process_zip_for_rosbags(self) -> None:
        """Finds, extracts, and processes rosbags from the ZIP file."""
        print(f"--- Processing HMIdata in {self.config.zip_path} ---")
        # The temporary directory is currently unused; HMIdata CSVs are extracted
        # straight into output_dir.
        with tempfile.TemporaryDirectory() as tmp_dir_str:
            try:
                with zipfile.ZipFile(self.config.zip_path, 'r') as zip_ref:
                    for member in zip_ref.infolist():
                        # Extract HMIdata CSV files directly to output
                        if 'HMIdata/' in member.filename and member.filename.endswith('.csv'):
                            try:
                                target_path = self.output_dir / Path(member.filename).name
                                with zip_ref.open(member) as source, open(target_path, "wb") as target:
                                    shutil.copyfileobj(source, target)
                                print(f"Extracted HMI data: {target_path.name}")
                            except Exception as e:
                                print(f"Error extracting HMI data {member.filename}: {e}")
            except zipfile.BadZipFile:
                print(f"Error: Bad ZIP file provided: {self.config.zip_path}")
                return
            except FileNotFoundError:
                print(f"Error: ZIP file not found: {self.config.zip_path}")
                return
        print("--- HMIdata Processing Finished ---")


# --- Utility Functions ---
def get_base_path() -> Path:
    """Gets the base path of the script or executable."""
    if getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS'):
        # Running in a PyInstaller bundle
        return Path(sys._MEIPASS)
    else:
        # Running as a normal script
        return Path(__file__).parent.resolve()


def run_cpp_engine(config: Config):
    """Runs the external C++ preprocessing engine."""
    if not config.engine_path or not config.map_path:
        print("C++ engine path or map path not configured. Skipping C++ engine execution.")
        return True  # Return True assuming it's optional or handled elsewhere
    engine_cmd = [
        str(config.engine_path),
        str(config.map_path),
        str(config.output_dir),
        str(config.x_offset),
        str(config.y_offset)
    ]
    print("--- Running C++ Preprocessing Engine ---")
    print(f"Command: {' '.join(engine_cmd)}")
    try:
        result = subprocess.run(
            engine_cmd,
            check=True,           # Raise exception on non-zero exit code
            capture_output=True,  # Capture stdout/stderr
            text=True,            # Decode output as text
            # cwd=config.engine_path.parent  # Run from the engine's directory? Or the script's? Adjust if needed.
        )
        print("C++ Engine Output:")
        print(result.stdout)
        if result.stderr:
            print("C++ Engine Error Output:")
            print(result.stderr)
        print("--- C++ Engine Finished Successfully ---")
        return True
    except FileNotFoundError:
        print(f"Error: C++ engine executable not found at {config.engine_path}.")
        return False
    except subprocess.CalledProcessError as e:
        print(f"Error: C++ engine failed with exit code {e.returncode}.")
        print("C++ Engine Output (stdout):")
        print(e.stdout)
        print("C++ Engine Output (stderr):")
        print(e.stderr)
        return False
    except Exception as e:
        print(f"An unexpected error occurred while running the C++ engine: {e}")
        return False
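

# Typical wiring (a minimal sketch; the exact Config construction and call order
# are assumptions, not a definitive entry point shipped with this module):
#     config = Config(...)  # zip_path, json_path, output_dir, engine_path, map_path, x_offset, y_offset
#     run_cpp_engine(config)                              # optional C++ map preprocessing
#     RosbagProcessor(config).process_zip_for_rosbags()   # extract HMIdata CSVs from the ZIP
#     FinalDataProcessor(config).process()                # merge everything into the final CSV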
if __name__ == "__main__":
    pass