pgvil.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. from pathlib import Path
  2. from typing import Dict, Any, Optional
  3. import pandas as pd
  4. import numpy as np
  5. from pyproj import Proj
  6. from dataclasses import dataclass, field
  7. from typing import Dict, Optional
  8. from pathlib import Path
  9. import pandas as pd
  10. # from core.error_handler import ErrorHandler
  11. # from core.config_manager import get_config
  12. import sys
  13. import csv
  14. import os
  15. import zipfile
  16. import argparse
  17. from genpy import Message
  18. import shutil
  19. import tempfile
  20. import pandas as pd
  21. import subprocess
  22. import pandas as pd
  23. import numpy as np
  24. @dataclass
  25. class Config:
  26. """PGVIL处理器配置类"""
  27. zip_path: Path
  28. output_path: Path
  29. engine_path: Optional[Path] = None
  30. map_path: Optional[Path] = None
  31. utm_zone: int = 51 # Example UTM zone
  32. x_offset: float = 0.0
  33. y_offset: float = 0.0
  34. def __post_init__(self):
  35. # Use output_path directly as output_dir to avoid nested directories
  36. self.output_dir = self.output_path
  37. self.output_dir.mkdir(parents=True, exist_ok=True)
  38. def run_pgvil_engine(config: Config):
  39. """Runs the external C++ preprocessing engine."""
  40. if not config.engine_path or not config.map_path:
  41. print("C++ engine path or map path not configured. Skipping C++ engine execution.")
  42. return True # Return True assuming it's optional or handled elsewhere
  43. engine_cmd = [
  44. str(config.engine_path),
  45. str(config.map_path),
  46. str(config.output_dir),
  47. str(config.x_offset),
  48. str(config.y_offset)
  49. ]
  50. print(f"run_pgvil_engine: x={config.x_offset}, y={config.y_offset}")
  51. print(f"--- Running C++ Preprocessing Engine ---")
  52. print(f"Command: {' '.join(engine_cmd)}")
  53. try:
  54. result = subprocess.run(
  55. engine_cmd,
  56. check=True, # Raise exception on non-zero exit code
  57. capture_output=True, # Capture stdout/stderr
  58. text=True, # Decode output as text
  59. # cwd=config.engine_path.parent # Run from the engine's directory? Or script's? Adjust if needed.
  60. )
  61. print("C++ Engine Output:")
  62. print(result.stdout)
  63. if result.stderr:
  64. print("C++ Engine Error Output:")
  65. print(result.stderr)
  66. print("--- C++ Engine Finished Successfully ---")
  67. return True
  68. except FileNotFoundError:
  69. print(f"Error: C++ engine executable not found at {config.engine_path}.")
  70. return False
  71. except subprocess.CalledProcessError as e:
  72. print(f"Error: C++ engine failed with exit code {e.returncode}.")
  73. print("C++ Engine Output (stdout):")
  74. print(e.stdout)
  75. print("C++ Engine Output (stderr):")
  76. print(e.stderr)
  77. return False
  78. except Exception as e:
  79. print(f"An unexpected error occurred while running the C++ engine: {e}")
  80. return False
  81. def remove_conflicting_columns(df_object, df_csv_info):
  82. """
  83. 找到连个表中除(simTime, simFrame, or playerId) 都存在的列,删掉df_csv_info中对应的重复列
  84. """
  85. renamed = {}
  86. conflicting_columns = set(df_object.columns) & set(df_csv_info.columns)
  87. for col in conflicting_columns:
  88. # if col not in ["simTime", "simFrame", "playerId"]:
  89. if col not in ["simFrame", "playerId"]:
  90. del df_csv_info[col]
  91. return df_csv_info
  92. def align_simtime_by_simframe(df):
  93. # 创建一个映射,将simFrame映射到其对应的simTime代表值
  94. sim_frame_to_time_map = df.groupby('simFrame')['simTime'].first().to_dict()
  95. frames_sorted = sorted(sim_frame_to_time_map.keys())
  96. times_sorted = [sim_frame_to_time_map[f] for f in frames_sorted]
  97. times_head = times_sorted[:100]
  98. if len(times_head) > 2:
  99. diffs = np.diff(times_head)
  100. diffs_rounded = np.round(diffs, 3)
  101. values, counts = np.unique(diffs_rounded, return_counts=True)
  102. mode_dt = values[np.argmax(counts)]
  103. new_frame_to_time_map = {
  104. frame: round(times_sorted[0] + mode_dt * i, 3)
  105. for i, frame in enumerate(frames_sorted)
  106. }
  107. else:
  108. new_frame_to_time_map = sim_frame_to_time_map
  109. # 使用映射来更新DataFrame中的simTime值
  110. df['simTime'] = df['simFrame'].map(new_frame_to_time_map)
  111. # 检查simFrame列是否为空或包含非整数类型的数据
  112. if df['simFrame'].empty or not df['simFrame'].apply(lambda x: isinstance(x, (int, np.integer))).all():
  113. return df
  114. # 识别缺失的simFrame
  115. all_frames = np.arange(df['simFrame'].min(), df['simFrame'].max() + 1)
  116. missing_frames = set(all_frames) - set(df['simFrame'])
  117. new_rows = []
  118. # 填补缺失的simFrame
  119. for missing_frame in missing_frames:
  120. prev_frame = df[df['simFrame'] < missing_frame]['simFrame'].max()
  121. next_frame = df[df['simFrame'] > missing_frame]['simFrame'].min()
  122. if prev_frame is not None and next_frame is not None:
  123. prev_row = df[df['simFrame'] == prev_frame].iloc[0]
  124. next_row = df[df['simFrame'] == next_frame].iloc[0]
  125. # 计算平均值并创建新行
  126. new_row = prev_row.copy()
  127. new_row['simFrame'] = missing_frame
  128. for col in df.columns:
  129. if col not in ['simTime', 'simFrame']:
  130. new_row[col] = (prev_row[col] + next_row[col]) / 2
  131. # 更新simTime值
  132. new_row['simTime'] = new_frame_to_time_map.get(missing_frame, np.nan)
  133. # 将新行添加到DataFrame中
  134. new_rows.append(new_row)
  135. if new_rows:
  136. df = pd.concat([df, pd.DataFrame(new_rows)], ignore_index=True)
  137. return df.sort_values(by='simFrame').reset_index(drop=True)
  138. def mergecopy_by_simtime(merged_df, external_df, ignore_cols, prefix=None, ):
  139. """
  140. 将external_df中所有的字段,基于nearest simTime,批量合并到merged_df中simtime相同的所有行中
  141. """
  142. useful_cols = [col for col in external_df.columns if col not in ignore_cols]
  143. for col in useful_cols:
  144. col_name = f"{prefix}_{col}" if prefix else col
  145. mapping = external_df.set_index('nearest_simTime')[col].to_dict()
  146. merged_df[col_name] = merged_df['simTime'].map(mapping)
  147. return merged_df
  148. def read_csv_with_filled_columns(file_path):
  149. try:
  150. # 确保 file_path 是字符串类型且有效
  151. if not isinstance(file_path, str):
  152. raise ValueError("提供的文件路径无效")
  153. if not os.path.exists(file_path):
  154. raise FileNotFoundError(f"文件 {file_path} 不存在")
  155. # 使用 on_bad_lines='skip' 跳过格式错误的行
  156. df = pd.read_csv(file_path, on_bad_lines='skip') # 跳过格式错误的行
  157. # 强制填充缺失的列为 NaN,确保列数一致
  158. if not df.empty: # 确保 df 不为空
  159. df.fillna(np.nan, inplace=True) # 用 NaN 填充所有空值
  160. return df
  161. except Exception as e:
  162. print(f"读取 CSV 文件 {file_path} 时发生错误: {str(e)}")
  163. return pd.DataFrame() # 返回空的 DataFrame 以便继续处理
  164. def convert_heading(posH_rad):
  165. # 将弧度转换为角度
  166. angle_deg = np.degrees(posH_rad)
  167. # 逆时针东为0 => 顺时针北为0,相当于 new_angle = (90 - angle_deg) % 360
  168. heading_deg = (90 - angle_deg) % 360
  169. return round(heading_deg, 3)
  170. # def find_closest_time(sim_time, sim_time_to_index, tolerance=0.01):
  171. # # 找到最接近的时间点,并且该时间点的差异小于 tolerance
  172. # closest_time = min(sim_time_to_index.keys(), key=lambda y: abs(y - sim_time) if abs(y - sim_time) < tolerance else float('inf'))
  173. # return closest_time
  174. def find_closest_time(sim_time, sim_time_to_index, tolerance=0.04):
  175. # 计算所有 simTime 的差值
  176. diffs = {k: abs(k - sim_time) for k in sim_time_to_index.keys()}
  177. # Step 1: 优先在容差范围内找
  178. within_tolerance = {k: v for k, v in diffs.items() if v <= tolerance}
  179. if within_tolerance:
  180. return min(within_tolerance, key=within_tolerance.get)
  181. # Step 2: 容忍失败,强制返回最近值
  182. return min(diffs, key=diffs.get)
  183. def convert_numeric_columns(df):
  184. numeric_cols = df.select_dtypes(include=['number']).columns
  185. # 强制保留为 int 类型的列,其余为float
  186. int_columns = ["simFrame", "playerId", "type", "stateMask", "ctrlId", "ifwarning"]
  187. for col in numeric_cols:
  188. if col in int_columns and col in df.columns:
  189. df[col] = df[col].astype(int)
  190. else:
  191. df[col] = df[col].astype(float)
  192. return df
  193. def safe_convert_numeric(df, name):
  194. if df is None or df.empty:
  195. return df
  196. return convert_numeric_columns(df)
  197. def safe_align_simtime(df, name):
  198. if df is None or df.empty:
  199. return df
  200. return align_simtime_by_simframe(df)
  201. class PGVILProcessor:
  202. """PGVIL数据处理器,实现PGVIL特有的处理逻辑"""
  203. def __init__(self, config: Config):
  204. self.config = config
  205. def process_zip(self) -> Path:
  206. """处理输入ZIP文件,并返回输出目录路径
  207. zip_path
  208. output_dir
  209. """
  210. print(f"Processing ZIP: {self.config.zip_path}")
  211. zip_path = self.config.zip_path
  212. output_dir = Path(self.config.output_dir) # 将目录路径转换为Path对象
  213. # 创建以 ZIP 名称为子目录的提取目录
  214. zip_name = Path(zip_path).stem
  215. output_dir.mkdir(parents=True, exist_ok=True)
  216. # 提取HMIdata和RDBdata中的CSV文件
  217. with zipfile.ZipFile(zip_path, 'r') as zip_ref:
  218. for name in zip_ref.namelist():
  219. if ('HMIdata/' in name or 'RDBdata/' in name) and name.endswith('.csv'):
  220. # 原 zip 内的子路径最后一部分为文件名
  221. filename = os.path.basename(name)
  222. src = zip_ref.open(name)
  223. dst_path = output_dir / filename
  224. # print(f"提取 {name} 到 {dst_path}")
  225. with open(dst_path, 'wb') as dst_file:
  226. shutil.copyfileobj(src, dst_file)
  227. # 更新 config 中的输出目录为刚才的子目录
  228. self.config.output_dir = output_dir
  229. return output_dir
  230. def merge_csv_files(self):
  231. x_offset = self.config.x_offset
  232. y_offset = self.config.y_offset
  233. data_path = self.config.output_dir
  234. # 定义CSV文件路径
  235. try:
  236. obj_state_path = os.path.join(data_path, "ObjState.csv")
  237. ego_map_path = os.path.join(data_path, "EgoMap.csv")
  238. lane_map_path = os.path.join(data_path, "LaneMap.csv")
  239. laneINfo_path = os.path.join(data_path, "LaneInfo.csv")
  240. roadPos_path = os.path.join(data_path, "RoadPos.csv")
  241. vehicleystems_path = os.path.join(data_path, "VehicleSystems.csv")
  242. trafficlight_path = os.path.join(data_path, "TrafficLight.csv")
  243. function_path = os.path.join(data_path, "Function.csv")
  244. except FileNotFoundError:
  245. raise Exception("File not found")
  246. print("777777:")
  247. df_object = read_csv_with_filled_columns(obj_state_path)
  248. df_map_info = read_csv_with_filled_columns(ego_map_path)
  249. df_lane_map = read_csv_with_filled_columns(lane_map_path)
  250. df_laneINfo = read_csv_with_filled_columns(laneINfo_path)
  251. df_roadPos = read_csv_with_filled_columns(roadPos_path)
  252. df_vehicleystems = read_csv_with_filled_columns(vehicleystems_path)
  253. df_trafficlight = read_csv_with_filled_columns(trafficlight_path)
  254. df_function = None
  255. if os.path.exists(function_path):
  256. df_function = read_csv_with_filled_columns(function_path)
  257. # 对df_object中的posX和posY应用偏置
  258. if df_object is not None and not df_object.empty:
  259. df_object['posX'] += x_offset
  260. df_object['posY'] += y_offset
  261. # 对齐simTime和simFrame
  262. df_object = safe_align_simtime(df_object, "df_object")
  263. df_map_info = safe_align_simtime(df_map_info, "df_map_info")
  264. df_lane_map = safe_align_simtime(df_lane_map, "df_lane_map")
  265. df_laneINfo = safe_align_simtime(df_laneINfo, "df_laneINfo")
  266. df_roadPos = safe_align_simtime(df_roadPos, "df_roadPos")
  267. df_vehicleystems = safe_align_simtime(df_vehicleystems, "df_vehicleystems")
  268. df_trafficlight = safe_align_simtime(df_trafficlight, "df_trafficlight")
  269. print("0000000<<<<<<<<<<<<<<<<<<<<<")
  270. df_object = safe_convert_numeric(df_object, "df_object")
  271. df_map_info = safe_convert_numeric(df_map_info, "df_map_info")
  272. df_lane_map = safe_convert_numeric(df_lane_map, "df_lane_map")
  273. df_laneINfo = safe_convert_numeric(df_laneINfo, "df_laneINfo")
  274. df_roadPos = safe_convert_numeric(df_roadPos, "df_roadPos")
  275. df_vehicleystems = safe_convert_numeric(df_vehicleystems, "df_vehicleystems")
  276. df_trafficlight = safe_convert_numeric(df_trafficlight, "df_trafficlight")
  277. print("1111111<<<<<<<<<<<<<<<<<<<<<")
  278. if df_function is not None:
  279. df_function = safe_convert_numeric(df_function, "df_function")
  280. # 使用simTime, simFrame, playerId合并ObjState和df_roadPos
  281. del_roadPos = remove_conflicting_columns(df_object, df_roadPos)
  282. if df_object is not None and not df_object.empty and df_roadPos is not None and not df_roadPos.empty:
  283. merged_df = df_object.merge(df_roadPos, on=["simFrame", "playerId"], how="inner")
  284. # merged_df = pd.merge(df_object, del_roadPos, on=["simTime", "simFrame", "playerId"], how="left").drop_duplicates()
  285. # 创建一个映射,存储 df_object 中每个 simTime 值及其对应的行索引
  286. sim_time_to_index = {row['simTime']: idx for idx, row in merged_df.iterrows()}
  287. ego_df = merged_df[merged_df["playerId"] == 1].copy() # 拆成ego和other
  288. other_df = merged_df[merged_df["playerId"] != 1].copy()
  289. print("444444<<<<<<<<<<<<<<<<<<<<<")
  290. # ego merge del_trafficlight
  291. if df_trafficlight is not None and not df_trafficlight.empty:
  292. df_trafficlight = df_trafficlight[df_trafficlight["ctrlId"] == 3][
  293. ["simTime", "simFrame", "stateMask", "ctrlId"]].copy()
  294. df_trafficlight = df_trafficlight.drop_duplicates(subset=["simTime", "simFrame", "ctrlId"]).reset_index(
  295. drop=True)
  296. if df_trafficlight.empty:
  297. ego_df["stateMask"] = np.nan
  298. ego_df["ctrlId"] = np.nan
  299. else:
  300. ego_df = pd.merge(ego_df, df_trafficlight, on=["simTime", "simFrame"], how="left")
  301. else:
  302. ego_df["stateMask"] = np.nan
  303. ego_df["ctrlId"] = np.nan
  304. merged_df = pd.concat([ego_df, other_df], ignore_index=True)
  305. print("33333333<<<<<<<<<<<<<<<<<<<<<")
  306. if df_laneINfo is not None and not df_laneINfo.empty:
  307. del_laneINfo = remove_conflicting_columns(merged_df, df_laneINfo)
  308. # merged_df = pd.merge(merged_df, del_laneINfo, on=["simTime", "simFrame", "playerId"], how="left").drop_duplicates()
  309. merged_df = pd.merge(merged_df, del_laneINfo, on=["simFrame", "playerId"], how="left").drop_duplicates()
  310. if df_map_info is not None and not df_map_info.empty:
  311. del_ego_map = remove_conflicting_columns(merged_df, df_map_info)
  312. # merged_df = pd.merge(merged_df, del_ego_map, on=["simTime", "simFrame", "playerId"], how="left")
  313. merged_df = pd.merge(merged_df, del_ego_map, on=["simFrame", "playerId"], how="left")
  314. if df_lane_map is not None and not df_lane_map.empty:
  315. del_lane_map = remove_conflicting_columns(merged_df, df_lane_map)
  316. # merged_df = pd.merge(merged_df, del_lane_map, on=["simTime", "simFrame", "playerId"], how="left").drop_duplicates()
  317. merged_df = pd.merge(merged_df, del_lane_map, on=["simFrame", "playerId"], how="left").drop_duplicates()
  318. if df_vehicleystems is not None and not df_vehicleystems.empty:
  319. del_vehicleystems = remove_conflicting_columns(merged_df, df_vehicleystems)
  320. # merged_df = pd.merge(merged_df, del_vehicleystems, on=["simTime", "simFrame", "playerId"], how="left").drop_duplicates()
  321. merged_df = pd.merge(merged_df, del_vehicleystems, on=["simFrame", "playerId"],
  322. how="left").drop_duplicates()
  323. if df_function is not None and not df_function.empty:
  324. tolerance = 0.01
  325. df_function = df_function.sort_values(by='simTime').reset_index(drop=True) # 按simTime列排序
  326. # 找到 function.csv 中每个 simTime 值在 df_object 中的最近时间点
  327. df_function['nearest_simTime'] = df_function['simTime'].apply(
  328. lambda x: find_closest_time(x, sim_time_to_index, tolerance))
  329. df_function['nearest_index'] = df_function['nearest_simTime'].map(sim_time_to_index)
  330. # 确保df_function中的nearest_index为整数类型,且去掉NaN值
  331. df_function_renamed = df_function.rename(
  332. columns={'simTime': 'function_simTime'}) # 重命名 df_function 中的 simTime 列
  333. df_function_valid = df_function_renamed.dropna(subset=['nearest_index']).copy()
  334. df_function_valid['nearest_index'] = df_function_valid['nearest_index'].astype(int)
  335. ignore_cols = ['function_simTime', 'nearest_simTime', 'nearest_index']
  336. merged_df = mergecopy_by_simtime(merged_df, df_function_valid, ignore_cols)
  337. """
  338. def check_matching(df_function, sim_time_to_index, tolerance=0.01):
  339. #检查 function.csv 中的所有行是否都成功匹配
  340. # 计算每个 simTime 对应的 nearest_simTime
  341. df_function['nearest_simTime'] = df_function['simTime'].apply(lambda x: find_closest_time(x, sim_time_to_index, tolerance))
  342. # 检查是否有没有匹配到的行
  343. unmatched_rows = df_function[df_function['nearest_simTime'].isna()]
  344. if not unmatched_rows.empty:
  345. print(f"没有匹配上的行: {unmatched_rows}")
  346. else:
  347. print("所有行都成功匹配!")
  348. # 统计匹配上了的行数和没有匹配上的行数
  349. total_rows = len(df_function)
  350. matched_rows = len(df_function) - len(unmatched_rows)
  351. print(f"总行数: {total_rows}, 匹配上的行数: {matched_rows}, 没有匹配上的行数: {len(unmatched_rows)}")
  352. return unmatched_rows
  353. # 调用检查函数
  354. unmatched_rows = check_matching(df_function, sim_time_to_index, tolerance=0.01)
  355. # 获取最后一行的 simTime
  356. last_row_simtime = df_function.iloc[-1]['simTime']
  357. print(f"最后一行的 simTime: {last_row_simtime}")
  358. # 获取最后一行的 nearest_simTime
  359. last_row_nearest_simtime = df_function.iloc[-1]['nearest_simTime']
  360. print(f"最后一行的 nearest_simTime: {last_row_nearest_simtime}")
  361. """
  362. columns_to_convert = ['posH']
  363. for col in columns_to_convert:
  364. if col in merged_df.columns:
  365. merged_df[col] = merged_df[col].apply(convert_heading)
  366. # 将弧度/秒转换为度/秒
  367. rad_to_deg = 180 / np.pi
  368. for col in ['speedH', 'accelH']:
  369. if col in merged_df.columns:
  370. merged_df[col] = merged_df[col] * rad_to_deg
  371. if 'posP' in merged_df.columns:
  372. merged_df.rename(columns={'posP': 'pitch_rate'}, inplace=True)
  373. merged_df['pitch_rate'] = merged_df['pitch_rate'].apply(convert_heading)
  374. if 'posR' in merged_df.columns:
  375. merged_df.rename(columns={'posR': 'roll_rate'}, inplace=True)
  376. merged_df['roll_rate'] = merged_df['roll_rate'].apply(convert_heading)
  377. # 先使用 infer_objects 来确保类型一致
  378. merged_df = merged_df.infer_objects()
  379. merged_df.fillna(np.nan, inplace=True) # 确保空值填充为 NaN
  380. merged_df = merged_df.sort_values(by=["simTime", "simFrame", "playerId"]).reset_index(drop=True)
  381. merged_csv_path = Path(data_path) / "merged_ObjState.csv"
  382. # merged_df.to_csv(merged_csv_path, index=False,na_rep="NaN")
  383. merged_df.to_csv(merged_csv_path, index=False)
  384. return merged_csv_path