data_process.py 11 KB


  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. ##################################################################
  4. #
  5. # Copyright (c) 2023 CICV, Inc. All Rights Reserved
  6. #
  7. ##################################################################
  8. """
  9. @Authors: yangzihao(yangzihao@china-icv.cn)
  10. @Date: 2023/11/27
  11. @Last Modified: 2023/11/27
  12. @Summary: Csv data process functions
  13. """
  14. import os
  15. import numpy as np
  16. import pandas as pd
  17. from common import cal_velocity
  18. from data_info import CsvData
  19. import matplotlib.pyplot as plt
  20. class DataProcess(object):
  21. """
  22. The data process class. It is a template to get evaluation raw data and process the raw data.
  23. Attributes:
  24. """
  25. def __init__(self, data_path, config):
  26. self.data_path = data_path
  27. self.casePath = data_path
  28. # config info
  29. self.config = config
  30. # self.safe_config = config.config['safe']
  31. # self.function_config = config.config['function']
  32. # self.compliance_config = config.config['compliance']
  33. self.comfort_config = config.config['comfort']
  34. self.efficient_config = config.config['efficient']
  35. # data process
  36. self.ego_df = pd.DataFrame()
  37. self.object_df = pd.DataFrame()
  38. # self.driver_ctrl_df = pd.DataFrame()
  39. # self.road_mark_df = pd.DataFrame()
  40. # self.road_pos_df = pd.DataFrame()
  41. # self.traffic_light_df = pd.DataFrame()
  42. # self.traffic_signal_df = pd.DataFrame()
  43. # self.status_df = pd.DataFrame()
  44. self.obj_data = {}
  45. self.ego_data = {}
  46. self.obj_id_list = {}
  47. self.car_info = {}
  48. self.report_info = {}
  49. self.driver_ctrl_data = {}
  50. self._process()
  51. def _process(self):
  52. # self._merge_csv()
  53. self._read_csv()
  54. self._draw_track()
  55. # self._signal_mapping()
  56. # self.car_info = self._get_car_info(self.object_df)
  57. # self._compact_data()
  58. # self._abnormal_detect()
  59. # self._status_map(self.object_df)
  60. self._object_df_process()
  61. self.report_info = self._get_report_info(self.obj_data[1])
  62. self.driver_ctrl_data = self._get_driver_ctrl_data(self.ego_df)
  63. def _read_csv(self):
  64. """
  65. Read csv files to dataframe.
  66. Args:
  67. data_path: A str of the path of csv files
  68. Returns:
  69. No returns.
  70. """
  71. # self.object_df = pd.read_csv(os.path.join(self.data_path, 'merged_ObjState.csv'))
  72. # self.object_df = pd.read_csv(os.path.join(self.data_path, 'ObjState.csv'))
  73. self.ego_df = pd.read_csv(os.path.join(self.data_path, 'EgoState_pji.csv'))
  74. self.ego_df['playerId'] = 1
  75. # self.driver_ctrl_df = pd.read_csv(os.path.join(self.data_path, 'DriverCtrl.csv'))
  76. # self.road_mark_df = pd.read_csv(os.path.join(self.data_path, 'RoadMark.csv'))
  77. # self.road_pos_df = pd.read_csv(os.path.join(self.data_path, 'RoadPos.csv'))
  78. # self.traffic_light_df = pd.read_csv(os.path.join(self.data_path, 'TrafficLight.csv'))
  79. # self.traffic_signal_df = pd.read_csv(os.path.join(self.data_path, 'TrafficSign.csv'))
  80. # self.lane_info_df = pd.read_csv(os.path.join(self.data_path, 'LaneInfo.csv')).drop_duplicates()
  81. # self.status_df = pd.read_csv(os.path.join(self.data_path, 'VehicleState.csv'))
  82. # self.vehicle_sys_df = pd.read_csv(
  83. # os.path.join(self.data_path, 'VehicleSystems.csv')).drop_duplicates() # 车灯信息
  84. def _draw_track(self):
  85. """
  86. """
  87. df = self.ego_df.copy()
  88. plt.scatter(df['posX'], df['posY'], c=df['simTime'], s=0.1)
  89. plt.axis('equal')
  90. # 添加坐标轴标签和标题
  91. plt.xlabel('posX')
  92. plt.ylabel('posY')
  93. plt.title('Trajectory')
  94. # 显示图形
  95. plt.savefig(os.path.join(self.casePath, "./track.png"))
  96. plt.close()
  97. def _signal_mapping(self):
  98. pass
  99. # singal mapping
  100. # signal_json = r'./signal.json'
  101. # signal_dict = json2dict(signal_json)
  102. # df_objectstate = signal_name_map(df_objectstate, signal_dict, 'objectState')
  103. # df_roadmark = signal_name_map(df_roadmark, signal_dict, 'roadMark')
  104. # df_roadpos = signal_name_map(df_roadpos, signal_dict, 'roadPos')
  105. # df_trafficlight = signal_name_map(df_trafficlight, signal_dict, 'trafficLight')
  106. # df_trafficsignal = signal_name_map(df_trafficsignal, signal_dict, 'trafficSignal')
  107. # df_drivectrl = signal_name_map(df_drivectrl, signal_dict, 'driverCtrl')
  108. # df_laneinfo = signal_name_map(df_laneinfo, signal_dict, 'laneInfo')
  109. # df_status = signal_name_map(df_status, signal_dict, 'statusMachine')
  110. # df_vehiclesys = signal_name_map(df_vehiclesys, signal_dict, 'vehicleSys')
  111. def _get_car_info(self, df):
  112. """
  113. Args:
  114. df:
  115. Returns:
  116. """
  117. first_row = df[df['playerId'] == 1].iloc[0].to_dict()
  118. length = first_row['dimX']
  119. width = first_row['dimY']
  120. height = first_row['dimZ']
  121. offset = first_row['offX']
  122. car_info = {
  123. "length": length,
  124. "width": width,
  125. "height": height,
  126. "offset": offset
  127. }
  128. return car_info
  129. def _compact_data(self):
  130. """
  131. Extra necessary data from dataframes.
  132. Returns:
  133. """
  134. self.object_df = self.object_df[CsvData.OBJECT_INFO].copy()
  135. def _abnormal_detect(self): # head and tail detect
  136. """
  137. Detect the head of the csv whether begin with 0 or not.
  138. Returns:
  139. A dataframe, which 'time' column begin with 0.
  140. """
  141. pass
  142. def _mileage_cal(self, df):
  143. """
  144. Calculate mileage of given df.
  145. Args:
  146. df: A dataframe of driving data.
  147. Returns:
  148. mileage: A float of the mileage(meter) of the driving data.
  149. """
  150. travelDist = df['traveledDist'].values.tolist()
  151. travelDist = [x for x in travelDist if not np.isnan(x)]
  152. # mile_start = df['travelDist'].iloc[0]
  153. mile_start = travelDist[0]
  154. mile_end = travelDist[-1]
  155. mileage = mile_end - mile_start
  156. return mileage
  157. def _duration_cal(self, df):
  158. """
  159. Calculate duration of given df.
  160. Args:
  161. df: A dataframe of driving data.
  162. Returns:
  163. duration: A float of the duration(second) of the driving data.
  164. """
  165. time_start = df['simTime'].iloc[0]
  166. time_end = df['simTime'].iloc[-1]
  167. duration = time_end - time_start
  168. return duration
  169. def _get_report_info(self, df):
  170. """
  171. Get report infomation from dataframe.
  172. Args:
  173. df: A dataframe of driving data.
  174. Returns:
  175. report_info: A dict of report infomation.
  176. """
  177. mileage = self._mileage_cal(df)
  178. duration = self._duration_cal(df)
  179. report_info = {
  180. "mileage": mileage,
  181. "duration": duration
  182. }
  183. return report_info
  184. # def _status_mapping(self, df):
  185. # df['ACC_status'] = df['ACC_status'].apply(lambda x: acc_status_mapping(x))
  186. # df['LKA_status'] = df['LKA_status'].apply(lambda x: lka_status_mapping(x))
  187. # df['LDW_status'] = df['LDW_status'].apply(lambda x: ldw_status_mapping(x))
  188. def _object_df_process(self):
  189. """
  190. Process the data of object dataframe. Save the data groupby object_ID.
  191. Returns:
  192. No returns.
  193. """
  194. data = self.ego_df.copy()
  195. data.rename(
  196. columns={"speedY": "lat_v", "speedX": "lon_v", "accelY": "lat_acc", "accelX": "lon_acc", "dimZ": "speedH"},
  197. inplace=True)
  198. # calculate common parameters
  199. # data['lat_v'] = data['speedX'] * np.sin(data['posH']) * -1 + data['speedY'] * np.cos(data['posH'])
  200. # data['lon_v'] = data['speedX'] * np.cos(data['posH']) + data['speedY'] * np.sin(data['posH'])
  201. data['v'] = data.apply(lambda row: cal_velocity(row['lat_v'], row['lon_v']), axis=1)
  202. data['time_diff'] = data['simTime'].diff()
  203. data['avg_speed'] = (data['v'] + data['v'].shift()) / 2 # 计算每个时间间隔的平均速度
  204. data['distance_increment'] = data['avg_speed'] * data['time_diff'] # 计算每个时间间隔的距离增量
  205. # 计算当前里程
  206. data['traveledDist'] = data['distance_increment'].cumsum()
  207. data['traveledDist'] = data['traveledDist'].fillna(0)
  208. # calculate acceleraton components
  209. # data['lat_acc'] = data['accelX'] * np.sin(data['posH']) * -1 + data['accelY'] * np.cos(data['posH'])
  210. # data['lon_acc'] = data['accelX'] * np.cos(data['posH']) + data['accelY'] * np.sin(data['posH'])
  211. data['accel'] = data.apply(lambda row: cal_velocity(row['lat_acc'], row['lon_acc']), axis=1)
  212. # data.rename(columns={"yawrate_roc": "accelH"}, inplace=True)
  213. self.object_df = data.copy()
  214. # calculate respective parameters
  215. for obj_id, obj_data in data.groupby("playerId"):
  216. self.obj_data[obj_id] = obj_data
  217. self.obj_data[obj_id]['lat_acc_diff'] = self.obj_data[obj_id]['lat_acc'].diff()
  218. self.obj_data[obj_id]['lon_acc_diff'] = self.obj_data[obj_id]['lon_acc'].diff()
  219. self.obj_data[obj_id]['speedH_diff'] = self.obj_data[obj_id]['speedH'].diff()
  220. self.obj_data[obj_id]['time_diff'] = self.obj_data[obj_id]['simTime'].diff()
  221. # self.obj_data['avg_speed'] = (self.obj_data['v'] + self.obj_data['v'].shift()) / 2 # 计算每个时间间隔的平均速度
  222. # self.obj_data['distance_increment'] = self.obj_data['avg_speed'] * self.obj_data['time_diff'] / 3.6 # 计算每个时间间隔的距离增量
  223. #
  224. # # 计算当前里程
  225. # self.obj_data['travelDist'] = self.obj_data['distance_increment'].cumsum()
  226. # self.obj_data['travelDist'] = self.obj_data['travelDist'].fillna(0)
  227. self.obj_data[obj_id]['lat_acc_roc'] = self.obj_data[obj_id]['lat_acc_diff'] / self.obj_data[obj_id][
  228. 'time_diff']
  229. self.obj_data[obj_id]['lon_acc_roc'] = self.obj_data[obj_id]['lon_acc_diff'] / self.obj_data[obj_id][
  230. 'time_diff']
  231. self.obj_data[obj_id]['accelH'] = self.obj_data[obj_id]['speedH_diff'] / self.obj_data[obj_id][
  232. 'time_diff']
  233. # get object id list
  234. self.obj_id_list = list(self.obj_data.keys())
  235. self.ego_data = self.obj_data[1]
  236. def _get_driver_ctrl_data(self, df):
  237. """
  238. Process and get drive ctrl information. Such as brake pedal, throttle pedal and steering wheel.
  239. Args:
  240. df: A dataframe of driver ctrl data.
  241. Returns:
  242. driver_ctrl_data: A dict of driver ctrl info.
  243. """
  244. time_list = df['simTime'].values.tolist()
  245. frame_list = df['simFrame'].values.tolist()
  246. # df['brakePedal'] = df['brakePedal'] * 100
  247. # brakePedal_list = df['brakePedal'].values.tolist()
  248. # df['throttlePedal'] = df['throttlePedal'] * 100
  249. # throttlePedal_list = df['throttlePedal'].values.tolist()
  250. # steeringWheel_list = df['steeringWheel'].values.tolist()
  251. driver_ctrl_data = {
  252. "time_list": time_list,
  253. "frame_list": frame_list,
  254. # "brakePedal_list": brakePedal_list,
  255. # "throttlePedal_list": throttlePedal_list,
  256. # "steeringWheel_list": steeringWheel_list
  257. }
  258. return driver_ctrl_data