data_process.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. ##################################################################
  4. #
  5. # Copyright (c) 2024 CICV, Inc. All Rights Reserved
  6. #
  7. ##################################################################
  8. """
  9. @Authors: zhanghaiwen(zhanghaiwen@china-icv.cn)
  10. @Date: 2024/10/17
  11. @Last Modified: 2024/10/17
  12. @Summary: Evaluation functions
  13. """
  14. import os
  15. import numpy as np
  16. import pandas as pd
  17. import yaml
  18. from modules.lib.log_manager import LogManager
  19. # from lib import log # 确保这个路径是正确的,或者调整它
  20. # logger = None # 初始化为 None
  21. class DataPreprocessing:
  22. def __init__(self, data_path, config_path):
  23. # Initialize paths and data containers
  24. # self.logger = log.get_logger()
  25. self.data_path = data_path
  26. self.case_name = os.path.basename(os.path.dirname(data_path))
  27. self.config_path = config_path
  28. # Initialize DataFrames
  29. self.object_df = pd.DataFrame()
  30. self.driver_ctrl_df = pd.DataFrame()
  31. self.vehicle_sys_df = pd.DataFrame()
  32. self.ego_data_df = pd.DataFrame()
  33. # Environment data
  34. self.lane_info_df = pd.DataFrame()
  35. self.road_mark_df = pd.DataFrame()
  36. self.road_pos_df = pd.DataFrame()
  37. self.traffic_light_df = pd.DataFrame()
  38. self.traffic_signal_df = pd.DataFrame()
  39. self.vehicle_config = {}
  40. self.safety_config = {}
  41. self.comfort_config = {}
  42. self.efficient_config = {}
  43. self.function_config = {}
  44. self.traffic_config = {}
  45. # Initialize data for later processing
  46. self.obj_data = {}
  47. self.ego_data = {}
  48. self.obj_id_list = []
  49. # Data quality level
  50. self.data_quality_level = 15
  51. # Process mode and prepare report information
  52. self._process_mode()
  53. self._get_yaml_config()
  54. self.report_info = self._get_report_info(self.obj_data.get(1, pd.DataFrame()))
  55. def _process_mode(self):
  56. """Handle different processing modes."""
  57. self._real_process_object_df()
  58. def _get_yaml_config(self):
  59. with open(self.config_path, 'r') as f:
  60. full_config = yaml.safe_load(f)
  61. modules = ["vehicle", "T_threshold", "safety", "comfort", "efficient", "function", "traffic"]
  62. # 1. 初始化 vehicle_config(不涉及 T_threshold 合并)
  63. self.vehicle_config = full_config[modules[0]]
  64. # 2. 定义 T_threshold_config(封装为字典)
  65. T_threshold_config = {"T_threshold": full_config[modules[1]]}
  66. # 3. 统一处理需要合并 T_threshold 的模块
  67. # 3.1 safety_config
  68. self.safety_config = {"safety": full_config[modules[2]]}
  69. self.safety_config.update(T_threshold_config)
  70. # 3.2 comfort_config
  71. self.comfort_config = {"comfort": full_config[modules[3]]}
  72. self.comfort_config.update(T_threshold_config)
  73. # 3.3 efficient_config
  74. self.efficient_config = {"efficient": full_config[modules[4]]}
  75. self.efficient_config.update(T_threshold_config)
  76. # 3.4 function_config
  77. self.function_config = {"function": full_config[modules[5]]}
  78. self.function_config.update(T_threshold_config)
  79. # 3.5 traffic_config
  80. self.traffic_config = {"traffic": full_config[modules[6]]}
  81. self.traffic_config.update(T_threshold_config)
  82. @staticmethod
  83. def cal_velocity(lat_v, lon_v):
  84. """Calculate resultant velocity from lateral and longitudinal components."""
  85. return np.sqrt(lat_v**2 + lon_v**2)
  86. def _real_process_object_df(self):
  87. """Process the object DataFrame."""
  88. try:
  89. # 读取 CSV 文件
  90. merged_csv_path = os.path.join(self.data_path, "merged_ObjState.csv")
  91. self.object_df = pd.read_csv(
  92. merged_csv_path, dtype={"simTime": float}
  93. ).drop_duplicates(subset=["simTime", "simFrame", "playerId"])
  94. data = self.object_df.copy()
  95. # Calculate common parameters
  96. data["lat_v"] = data["speedY"] * 1
  97. data["lon_v"] = data["speedX"] * 1
  98. data["v"] = data.apply(
  99. lambda row: self.cal_velocity(row["lat_v"], row["lon_v"]), axis=1
  100. )
  101. data["v"] = data["v"] # km/h
  102. # Calculate acceleration components
  103. data["lat_acc"] = data["accelY"] * 1
  104. data["lon_acc"] = data["accelX"] * 1
  105. data["accel"] = data.apply(
  106. lambda row: self.cal_velocity(row["lat_acc"], row["lon_acc"]), axis=1
  107. )
  108. # Drop rows with missing 'type' and reset index
  109. data = data.dropna(subset=["type"])
  110. data.reset_index(drop=True, inplace=True)
  111. self.object_df = data.copy()
  112. # Calculate respective parameters for each object
  113. for obj_id, obj_data in data.groupby("playerId"):
  114. self.obj_data[obj_id] = self._calculate_object_parameters(obj_data)
  115. # Get object id list
  116. EGO_PLAYER_ID = 1
  117. self.obj_id_list = list(self.obj_data.keys())
  118. self.ego_data = self.obj_data[EGO_PLAYER_ID]
  119. except Exception as e:
  120. # self.logger.error(f"Error processing object DataFrame: {e}")
  121. raise
  122. def _calculate_object_parameters(self, obj_data):
  123. """Calculate additional parameters for a single object."""
  124. obj_data = obj_data.copy()
  125. obj_data["time_diff"] = obj_data["simTime"].diff()
  126. obj_data["lat_acc_diff"] = obj_data["lat_acc"].diff()
  127. obj_data["lon_acc_diff"] = obj_data["lon_acc"].diff()
  128. obj_data["yawrate_diff"] = obj_data["speedH"].diff()
  129. obj_data["lat_acc_roc"] = (
  130. obj_data["lat_acc_diff"] / obj_data["time_diff"]
  131. ).replace([np.inf, -np.inf], [9999, -9999])
  132. obj_data["lon_acc_roc"] = (
  133. obj_data["lon_acc_diff"] / obj_data["time_diff"]
  134. ).replace([np.inf, -np.inf], [9999, -9999])
  135. obj_data["accelH"] = (
  136. obj_data["yawrate_diff"] / obj_data["time_diff"]
  137. ).replace([np.inf, -np.inf], [9999, -9999])
  138. return obj_data
  139. def _get_driver_ctrl_data(self, df):
  140. """
  141. Process and get driver control information.
  142. Args:
  143. df: A DataFrame containing driver control data.
  144. Returns:
  145. A dictionary of driver control info.
  146. """
  147. driver_ctrl_data = {
  148. "time_list": df["simTime"].round(2).tolist(),
  149. "frame_list": df["simFrame"].tolist(),
  150. "brakePedal_list": (
  151. (df["brakePedal"] * 100).tolist()
  152. if df["brakePedal"].max() < 1
  153. else df["brakePedal"].tolist()
  154. ),
  155. "throttlePedal_list": (
  156. (df["throttlePedal"] * 100).tolist()
  157. if df["throttlePedal"].max() < 1
  158. else df["throttlePedal"].tolist()
  159. ),
  160. "steeringWheel_list": df["steeringWheel"].tolist(),
  161. }
  162. return driver_ctrl_data
  163. def _get_report_info(self, df):
  164. """Extract report information from the DataFrame."""
  165. mileage = self._mileage_cal(df)
  166. duration = self._duration_cal(df)
  167. return {"mileage": mileage, "duration": duration}
  168. def _mileage_cal(self, df):
  169. """Calculate mileage based on the driving data."""
  170. if df["travelDist"].nunique() == 1:
  171. df["time_diff"] = df["simTime"].diff().fillna(0)
  172. df["avg_speed"] = (df["v"] + df["v"].shift()).fillna(0) / 2
  173. df["distance_increment"] = df["avg_speed"] * df["time_diff"] / 3.6
  174. df["travelDist"] = df["distance_increment"].cumsum().fillna(0)
  175. mileage = round(df["travelDist"].iloc[-1] - df["travelDist"].iloc[0], 2)
  176. return mileage
  177. return 0.0 # Return 0 if travelDist is not valid
  178. def _duration_cal(self, df):
  179. """Calculate duration of the driving data."""
  180. return df["simTime"].iloc[-1] - df["simTime"].iloc[0]