# traffic.py


import sys
import os
import math
import numpy as np
import pandas as pd
from pathlib import Path

# Put the parent of this file's directory on sys.path so that the project
# packages imported below (`models.*`, `config`) resolve when this module is
# run directly.
root_path = Path(__file__).resolve().parent.parent
sys.path.append(str(root_path))

from models.common.score import Score
from config import config
from models.common import log  # Make sure this import path is correct, or adjust it.

# Module-level logger from the project's log helper, configured with the log
# path from config.LOG_PATH.
log_path = config.LOG_PATH
logger = log.get_logger(log_path)
  14. class TrafficViolation:
  15. """交通违规基类"""
  16. def __init__(self, vehicle_id: str):
  17. self.vehicle_id = vehicle_id
  18. self.violation_count = 0 # 违规次数
  19. def get_violation_info(self):
  20. return f"车辆ID: {self.vehicle_id} - 违规次数: {self.violation_count}"
  21. class OvertakingViolation(object):
  22. """超车违规类"""
  23. def __init__(self, df_data):
  24. print("OvertakingViolation-------------------------")
  25. self.traffic_violations_type = "超车违规类"
  26. self.ego_data = df_data.obj_data[1] # Copy to avoid modifying the original DataFrame
  27. self.laneinfo_new_data = df_data.lane_info_new_df
  28. self.roadinfo_data = df_data.road_info_df
  29. self.drivercrtrl_data = df_data.driver_ctrl_df
  30. self.interinfo_data = df_data.inter_info_df
  31. self.crosswalkinfo_data = df_data.cross_walk_df
  32. self.object_data = df_data.object_df
  33. self.overtake_on_right_count = 0
  34. self.overtake_when_turn_around_count = 0
  35. self.overtake_when_passing_car_count = 0
  36. self.overtake_in_forbid_lane_count = 0
  37. self.overtake_in_ramp_count = 0
  38. self.overtake_in_tunnel_count = 0
  39. self.overtake_on_accelerate_lane_count = 0
  40. self.overtake_on_decelerate_lane_count = 0
  41. self.overtake_in_different_senerios_count = 0
  42. def different_road_area_simtime(self, df, threshold = 0.5):
  43. if not df:
  44. return []
  45. simtime_group = []
  46. current_simtime_group = [df[0]]
  47. for i in range(1, len(df)):
  48. if abs(df[i] - df[i-1]) <= threshold:
  49. current_simtime_group.append(df[i])
  50. else:
  51. simtime_group.append(current_simtime_group)
  52. current_simtime_group = [df[i]]
  53. simtime_group.append(current_simtime_group)
  54. return simtime_group
  55. def _is_overtake(self, lane_id, dx, dy, ego_speedx, ego_speedy):
  56. lane_start_id = lane_id[0]
  57. lane_end_id = lane_id[-1]
  58. dx_start = dx[0]
  59. dx_end = dx[-1]
  60. dy_start = dy[0]
  61. dy_end = dy[-1]
  62. ego_start_speedx = ego_speedx[0]
  63. ego_end_speedx = ego_speedx[-1]
  64. ego_start_speedy = ego_speedy[0]
  65. ego_end_speedy = ego_speedy[-1]
  66. if (lane_start_id == lane_end_id
  67. and (dx_start * ego_start_speedx + dy_start * ego_start_speedy >= 0)
  68. and (dx_end * ego_end_speedx + dy_end * ego_end_speedy < 0)):
  69. # 返回读取的帧和下一个起始帧ID(即当前帧之后的第11帧,如果存在的话)
  70. return True
  71. else:
  72. return False
  73. def _is_dxy_of_car(self, df, id):
  74. '''
  75. :param df: objstate.csv and so on
  76. :param id: playerId
  77. :param string_type: posX/Y or speedX/Y and so on
  78. :return: dataframe of dx/y and so on
  79. '''
  80. car_dx = df[df['playerId'] == id]['posX'].values - df[df['playerId'] == 1]['posX'].values
  81. car_dy = df[df['playerId'] == id]['posY'].values - df[df['playerId'] == 1]['posY'].values
  82. return car_dx, car_dy
  83. # 在前车右侧超车、会车时超车、前车掉头时超车
  84. def illegal_overtake_with_car(self, window_width=250):
  85. # 获取csv文件中最短的帧数
  86. frame_id_length = min(len(self.object_data['simFrame']), len(self.laneinfo_new_data['simFrame']),
  87. len(self.drivercrtrl_data['simFrame']))
  88. start_frame_id = self.object_data['simFrame'].iloc[0] # 获取起始点的帧数
  89. while (start_frame_id + window_width) < frame_id_length:
  90. simframe_window = list(np.arange(start_frame_id, start_frame_id + window_width))
  91. # 读取滑动窗口的dataframe数据
  92. obj_data_frames = self.object_data[self.object_data['simFrame'].isin(simframe_window)]
  93. lane_data_frames = self.laneinfo_new_data[self.laneinfo_new_data['simFrame'].isin(simframe_window)]
  94. driver_data_frames = self.drivercrtrl_data[self.drivercrtrl_data['simFrame'].isin(simframe_window)]
  95. # 读取前后的laneId
  96. lane_id = lane_data_frames['lane_id']
  97. # 读取前后方向盘转角steeringWheel
  98. driverctrl_start_state = driver_data_frames['steeringWheel'].iloc[0]
  99. driverctrl_end_state = driver_data_frames['steeringWheel'].iloc[-1]
  100. # 读取车辆前后的位置信息
  101. dx, dy = self._is_dxy_of_car(obj_data_frames, 2)
  102. ego_speedx = obj_data_frames[obj_data_frames['playerId'] == 1]['speedX']
  103. ego_speedy = obj_data_frames[obj_data_frames['playerId'] == 1]['speedY']
  104. obj_speedx = obj_data_frames[obj_data_frames['playerId'] == 2]['speedX']
  105. obj_speedy = obj_data_frames[obj_data_frames['playerId'] == 2]['speedY']
  106. if len(obj_data_frames[obj_data_frames['playerId'] == 3]) > 0:
  107. obj1_start_speedx = obj_data_frames[obj_data_frames['playerId'] == 3]['speedX'].iloc[0]
  108. obj1_start_speedy = obj_data_frames[obj_data_frames['playerId'] == 3]['speedY'].iloc[0]
  109. if ego_speedx.iloc[0] * obj1_start_speedx + ego_speedy.iloc[0] * obj1_start_speedy < 0:
  110. self.overtake_when_passing_car_count += self._is_overtake(lane_id, dx, dy, ego_speedx,
  111. ego_speedy)
  112. start_frame_id += window_width
  113. '''
  114. 如果滑动窗口开始和最后的laneid一致;
  115. 方向盘转角前后方向相反(开始方向盘转角向右后来方向盘转角向左);
  116. 自车和前车的位置发生的交换;
  117. 则认为右超车
  118. '''
  119. if driverctrl_start_state > 0 and driverctrl_end_state < 0:
  120. self.overtake_on_right_count += self._is_overtake(lane_id, dx, dy, ego_speedx, ego_speedy)
  121. start_frame_id += window_width
  122. elif ego_speedx.iloc[0] * obj_speedx.iloc[0] + ego_speedy.iloc[0] * obj_speedy.iloc[0] < 0:
  123. self.overtake_when_turn_around_count += self._is_overtake(lane_id, dx, dy, ego_speedx, ego_speedy)
  124. start_frame_id += window_width
  125. else:
  126. start_frame_id += 1
  127. print(
  128. f"在会车时超车{self.overtake_when_passing_car_count}次, 右侧超车{self.overtake_on_right_count}次, 在前车掉头时超车{self.overtake_when_turn_around_count}次")
  129. # 借道超车场景
  130. def overtake_in_forbid_lane(self):
  131. simTime = self.object_data[self.object_data['playerId'] == 2]['simTime'].tolist()
  132. simtime_devide = self.different_road_area_simtime(simTime)
  133. for simtime in simtime_devide:
  134. lane_overtake = self.laneinfo_new_data[self.laneinfo_new_data['simTime'].isin(simtime)]
  135. try:
  136. lane_type = lane_overtake['lane_type'].tolist()
  137. if 2 in lane_type:
  138. self.overtake_in_forbid_lane_count += 1
  139. except Exception as e:
  140. print("数据缺少lane_type信息")
  141. print(f"在不该占用车道超车{self.overtake_in_forbid_lane_count}次")
  142. # 在匝道超车
  143. def overtake_in_ramp_area(self):
  144. ramp_simtime_list = self.roadinfo_data[(self.roadinfo_data['road_type'] == 19)]['simTime'].tolist()
  145. ramp_simTime_list = self.different_road_area_simtime(ramp_simtime_list)
  146. for ramp_simtime in ramp_simTime_list:
  147. lane_id = self.laneinfo_new_data['lane_id'].tolist()
  148. objstate_in_ramp = self.object_data[self.object_data['simTime'].isin(ramp_simtime)]
  149. dx, dy = self._is_dxy_of_car(objstate_in_ramp, 2)
  150. ego_speedx = objstate_in_ramp[objstate_in_ramp['playerId'] == 1]['speedX'].tolist()
  151. ego_speedy = objstate_in_ramp[objstate_in_ramp['playerId'] == 1]['speedY'].tolist()
  152. if len(lane_id) > 0:
  153. self.overtake_in_ramp_count += self._is_overtake(lane_id, dx, dy, ego_speedx, ego_speedy)
  154. else:
  155. continue
  156. print(f"在匝道超车{self.overtake_in_ramp_count}次")
  157. def overtake_in_tunnel_area(self):
  158. tunnel_simtime_list = self.roadinfo_data[(self.roadinfo_data['road_type'] == 15)]['simTime'].tolist()
  159. tunnel_simTime_list = self.different_road_area_simtime(tunnel_simtime_list)
  160. for tunnel_simtime in tunnel_simTime_list:
  161. lane_id = self.laneinfo_new_data['lane_id'].tolist()
  162. objstate_in_tunnel = self.object_data[self.object_data['simTime'].isin(tunnel_simtime)]
  163. dx, dy = self._is_dxy_of_car(objstate_in_tunnel, 2)
  164. ego_speedx = objstate_in_tunnel[objstate_in_tunnel['playerId'] == 1]['speedX'].tolist()
  165. ego_speedy = objstate_in_tunnel[objstate_in_tunnel['playerId'] == 1]['speedY'].tolist()
  166. if len(lane_id) > 0:
  167. self.overtake_in_tunnel_count += self._is_overtake(lane_id, dx, dy, ego_speedx, ego_speedy)
  168. else:
  169. continue
  170. print(f"在suidao超车{self.overtake_in_tunnel_count}次")
  171. # 加速车道超车
  172. def overtake_on_accelerate_lane(self):
  173. accelerate_simtime_list = \
  174. self.laneinfo_new_data[self.laneinfo_new_data['lane_type'] == 2]['simTime'].tolist()
  175. accelerate_simTime_list = self.different_road_area_simtime(accelerate_simtime_list)
  176. for accelerate_simtime in accelerate_simTime_list:
  177. lane_id = self.laneinfo_new_data['lane_id'].tolist()
  178. objstate_in_accelerate = self.object_data[self.object_data['simTime'].isin(accelerate_simtime)]
  179. dx, dy = self._is_dxy_of_car(objstate_in_accelerate, 2)
  180. ego_speedx = objstate_in_accelerate[objstate_in_accelerate['playerId'] == 1]['speedX'].tolist()
  181. ego_speedy = objstate_in_accelerate[objstate_in_accelerate['playerId'] == 1]['speedY'].tolist()
  182. self.overtake_on_accelerate_lane_count += self._is_overtake(lane_id, dx, dy, ego_speedx, ego_speedy)
  183. print(f"在加速车道超车{self.overtake_on_accelerate_lane_count}次")
  184. # 减速车道超车
  185. def overtake_on_decelerate_lane(self):
  186. decelerate_simtime_list = self.laneinfo_new_data[(self.laneinfo_new_data['lane_type'] == 3)]['simTime'].tolist()
  187. decelerate_simTime_list = self.different_road_area_simtime(decelerate_simtime_list)
  188. for decelerate_simtime in decelerate_simTime_list:
  189. lane_id = self.laneinfo_new_data[self.laneinfo_new_data['playerId'] == 1]['id']
  190. objstate_in_decelerate = self.object_data[self.object_data['simTime'].isin(decelerate_simtime)]
  191. dx, dy = self._is_dxy_of_car(objstate_in_decelerate, 2)
  192. ego_speedx = objstate_in_decelerate[objstate_in_decelerate['playerId'] == 1]['speedX']
  193. ego_speedy = objstate_in_decelerate[objstate_in_decelerate['playerId'] == 1]['speedY']
  194. self.overtake_on_decelerate_lane_count += self._is_overtake(lane_id, dx, dy, ego_speedx, ego_speedy)
  195. print(f"在减速车道超车{self.overtake_on_decelerate_lane_count}次")
  196. # 在交叉路口
  197. def overtake_in_different_senerios(self):
  198. crossroad_simTime = self.interinfo_data[self.interinfo_data['interid'] != 10000][
  199. 'simTime'].tolist() # 判断是路口或者隧道区域
  200. # 筛选在路口或者隧道区域的objectstate、driverctrl、laneinfo数据
  201. crossroad_objstate = self.object_data[self.object_data['simTime'].isin(crossroad_simTime)]
  202. crossroad_laneinfo = self.laneinfo_new_data[self.laneinfo_new_data['simTime'].isin(crossroad_simTime)]
  203. # 读取前后的laneId
  204. lane_id = crossroad_laneinfo['lane_id']
  205. # 读取车辆前后的位置信息
  206. dx, dy = self._is_dxy_of_car(crossroad_objstate, 2)
  207. ego_speedx = crossroad_objstate[crossroad_objstate['playerId'] == 1]['speedX'].tolist()
  208. ego_speedy = crossroad_objstate[crossroad_objstate['playerId'] == 1]['speedY'].tolist()
  209. '''
  210. 如果滑动窗口开始和最后的laneid一致;
  211. 自车和前车的位置发生的交换;
  212. 则认为发生超车
  213. '''
  214. if len(lane_id) > 0:
  215. self.overtake_in_different_senerios_count += self._is_overtake(lane_id, dx, dy, ego_speedx, ego_speedy)
  216. else:
  217. pass
  218. print(f"在路口超车{self.overtake_in_different_senerios_count}次")
  219. def overtake_statistic(self):
  220. self.overtake_in_forbid_lane()
  221. self.overtake_on_decelerate_lane()
  222. self.overtake_on_accelerate_lane()
  223. self.overtake_in_ramp_area()
  224. self.overtake_in_tunnel_area()
  225. self.overtake_in_different_senerios()
  226. self.illegal_overtake_with_car()
  227. self.calculated_value = {
  228. "overtake_on_right": self.overtake_on_right_count,
  229. "overtake_when_turn_around": self.overtake_when_turn_around_count,
  230. "overtake_when_passing_car": self.overtake_when_passing_car_count,
  231. "overtake_in_forbid_lane": self.overtake_in_forbid_lane_count,
  232. "overtake_in_ramp": self.overtake_in_ramp_count,
  233. "overtake_in_tunnel": self.overtake_in_tunnel_count,
  234. "overtake_on_accelerate_lane": self.overtake_on_accelerate_lane_count,
  235. "overtake_on_decelerate_lane": self.overtake_on_decelerate_lane_count,
  236. "overtake_in_different_senerios": self.overtake_in_different_senerios_count
  237. }
  238. return self.calculated_value
  239. class SlowdownViolation(object):
  240. """减速让行违规类"""
  241. def __init__(self, df_data):
  242. print("OvertakingViolation-------------------------")
  243. self.traffic_violations_type = "减速让行违规类"
  244. self.ego_data = df_data.obj_data[1] # Copy to avoid modifying the original DataFrame
  245. self.laneinfo_new_data = df_data.lane_info_new_df
  246. self.roadinfo_data = df_data.road_info_df
  247. self.drivercrtrl_data = df_data.driver_ctrl_df
  248. self.interinfo_data = df_data.inter_info_df
  249. self.crosswalkinfo_data = df_data.cross_walk_df
  250. self.object_data = df_data.object_df
  251. self.slow_down_in_crosswalk_count = 0
  252. self.pedestrian_in_crosswalk_count = 0
  253. def different_road_area_simtime(self, df, threshold = 0.5):
  254. if not df:
  255. return []
  256. simtime_group = []
  257. current_simtime_group = [df[0]]
  258. for i in range(1, len(df)):
  259. if abs(df[i] - df[i-1]) <= threshold:
  260. current_simtime_group.append(df[i])
  261. else:
  262. simtime_group.append(current_simtime_group)
  263. current_simtime_group = [df[i]]
  264. simtime_group.append(current_simtime_group)
  265. return simtime_group
  266. def slow_down_in_crosswalk(self):
  267. crosswalk_simTime = self.crosswalkinfo_data[self.crosswalkinfo_data['crossid'] != 20000][
  268. 'simTime'].tolist() # 判断是路口或者隧道区域
  269. crosswalk_simTime_devide = self.different_road_area_simtime(crosswalk_simTime)
  270. for crosswalk_simtime in crosswalk_simTime_devide:
  271. # 筛选在人行横道区域的crosswalk_objstate数据
  272. crosswalk_objstate = self.object_data[self.object_data['simTime'].isin(crosswalk_simtime)]
  273. ego_speedx = crosswalk_objstate[crosswalk_objstate['playerId'] == 1]['speedX'].tolist()
  274. ego_speedy = crosswalk_objstate[crosswalk_objstate['playerId'] == 1]['speedY'].tolist()
  275. ego_speed = np.sqrt(ego_speedx**2 + ego_speedy**2)
  276. if(max(ego_speed)*3.6 > 15):
  277. self.slow_down_in_crosswalk_count += 1
  278. print(f"在人行横道超车{self.slow_down_in_crosswalk_count}次")
  279. def pedestrian_in_crosswalk(self):
  280. crosswalk_simTime = self.crosswalkinfo_data[self.crosswalkinfo_data['crossid'] != 20000][
  281. 'simTime'].tolist() # 判断是路口或者隧道区域
  282. crosswalk_simTime_devide = self.different_road_area_simtime(crosswalk_simTime)
  283. for crosswalk_simtime in crosswalk_simTime_devide:
  284. crosswalk_objstate = self.object_data[self.object_data['simTime'].isin(crosswalk_simtime)]
  285. if len(crosswalk_objstate[crosswalk_objstate['playerId'] == 5]) > 0:
  286. pedestrian_simtime = crosswalk_objstate[crosswalk_objstate['playerId'] == 5]['simTime']
  287. pedestrian_objstate = crosswalk_objstate[crosswalk_objstate['simTime'].isin(pedestrian_simtime)]
  288. ego_speed = np.sqrt(pedestrian_objstate['speedX']**2 + pedestrian_objstate['speedY']**2)
  289. if ego_speed.any() > 0:
  290. self.pedestrian_in_crosswalk_count += 1
  291. def slowdown_statistic(self):
  292. self.slow_down_in_crosswalk()
  293. self.pedestrian_in_crosswalk()
  294. self.calculated_value = {
  295. "slow_down_in_crosswalk": self.slow_down_in_crosswalk_count,
  296. "pedestrian_in_crosswalk": self.pedestrian_in_crosswalk_count
  297. }
  298. return self.calculated_value
  299. class WrongWayViolation(TrafficViolation):
  300. """逆行违规类"""
  301. def __init__(self, vehicle_id: str, wrong_way_location: str):
  302. super().__init__(vehicle_id)
  303. self.wrong_way_location = wrong_way_location
  304. def get_violation_info(self):
  305. return f"{super().get_violation_info()} - 逆行发生在 {self.wrong_way_location}"
  306. class SpeedingViolation(object):
  307. """超速违规类"""
  308. def __init__(self, df_data):
  309. print("SpeedingViolation-------------------------")
  310. self.traffic_violations_type = "超速违规类"
  311. self.data = df_data.obj_data[1] # Copy to avoid modifying the original DataFrame
  312. # 初始化违规统计
  313. self.violation_counts = {
  314. "urbanExpresswayOrHighwaySpeedOverLimit50": 0,
  315. "urbanExpresswayOrHighwaySpeedOverLimit20to50": 0,
  316. "urbanExpresswayOrHighwaySpeedOverLimit0to20": 0,
  317. "urbanExpresswayOrHighwaySpeedUnderLimit": 0,
  318. "generalRoadSpeedOverLimit50": 0,
  319. "generalRoadSpeedOverLimit20to50": 0,
  320. }
  321. # 处理数据
  322. self.process_violations()
  323. def process_violations(self):
  324. """处理数据帧,检查超速和其他违规行为"""
  325. # 定义速度限制
  326. self.data["speed_limit_max"] = self.data["road_speed_max"]
  327. self.data["speed_limit_min"] = self.data["road_speed_min"]
  328. # 提取有效道路类型
  329. # 提取有效道路类型
  330. urban_expressway_or_highway = {1, 2} # 使用大括号直接创建集合
  331. general_road = {3} # 直接创建包含一个元素的集合
  332. # 违规判定
  333. conditions = [
  334. (
  335. self.data["road_fc"].isin(urban_expressway_or_highway)
  336. & (self.data["v"] > self.data["speed_limit_max"] * 1.5)
  337. ),
  338. (
  339. self.data["road_fc"].isin(urban_expressway_or_highway)
  340. & (self.data["v"] > self.data["speed_limit_max"] * 1.2)
  341. & (self.data["v"] <= self.data["speed_limit_max"] * 1.5)
  342. ),
  343. (
  344. self.data["road_fc"].isin(urban_expressway_or_highway)
  345. & (self.data["v"] > self.data["speed_limit_max"])
  346. & (self.data["v"] <= self.data["speed_limit_max"] * 1.2)
  347. ),
  348. (
  349. self.data["road_fc"].isin(urban_expressway_or_highway)
  350. & (self.data["v"] < self.data["speed_limit_min"])
  351. ),
  352. (
  353. self.data["road_fc"].isin(general_road)
  354. & (self.data["v"] > self.data["speed_limit_max"] * 1.5)
  355. ),
  356. (
  357. self.data["road_fc"].isin(general_road)
  358. & (self.data["v"] > self.data["speed_limit_max"] * 1.2)
  359. & (self.data["v"] <= self.data["speed_limit_max"] * 1.5)
  360. ),
  361. ]
  362. violation_types = [
  363. "urbanExpresswayOrHighwaySpeedOverLimit50",
  364. "urbanExpresswayOrHighwaySpeedOverLimit20to50",
  365. "urbanExpresswayOrHighwaySpeedOverLimit0to20",
  366. "urbanExpresswayOrHighwaySpeedUnderLimit",
  367. "generalRoadSpeedOverLimit50",
  368. "generalRoadSpeedOverLimit20to50",
  369. ]
  370. # 设置违规类型
  371. self.data["violation_type"] = None
  372. for condition, violation_type in zip(conditions, violation_types):
  373. self.data.loc[condition, "violation_type"] = violation_type
  374. # 统计各类违规情况
  375. self.violation_counts = self.data["violation_type"].value_counts().to_dict()
  376. # def get_violation_info(self) -> str:
  377. # return (
  378. # f"超过50%: {self.violation_counts.get('urbanExpresswayOrHighwaySpeedOverLimit50', 0)}, "
  379. # f"超过20%-50%: {self.violation_counts.get('urbanExpresswayOrHighwaySpeedOverLimit20to50', 0)}, "
  380. # f"超过20%以内: {self.violation_counts.get('urbanExpresswayOrHighwaySpeedOverLimit0to20', 0)}, "
  381. # f"低于最低时速: {self.violation_counts.get('urbanExpresswayOrHighwaySpeedUnderLimit', 0)}, "
  382. # f"非高速公路超速50%以上: {self.violation_counts.get('generalRoadSpeedOverLimit50', 0)}, "
  383. # f"非高速公路超速20%-50%: {self.violation_counts.get('generalRoadSpeedOverLimit20to50', 0)}"
  384. # )
  385. def report_statistic(self) -> str:
  386. return (
  387. f"高速或者城市快速路超过50%: {self.violation_counts.get('urbanExpresswayOrHighwaySpeedOverLimit50', 0)},\n"
  388. f"高速或者城市快速路超过20%-50%: {self.violation_counts.get('urbanExpresswayOrHighwaySpeedOverLimit20to50', 0)},\n"
  389. f"高速或者城市快速路超过20%以内: {self.violation_counts.get('urbanExpresswayOrHighwaySpeedOverLimit0to20', 0)},\n"
  390. f"高速或者城市快速路低于最低时速: {self.violation_counts.get('urbanExpresswayOrHighwaySpeedUnderLimit', 0)},\n"
  391. f"非高速公路超速50%以上: {self.violation_counts.get('generalRoadSpeedOverLimit50', 0)},\n"
  392. f"非高速公路超速20%-50%: {self.violation_counts.get('generalRoadSpeedOverLimit20to50', 0)}"
  393. )
  394. class ParkingViolation(TrafficViolation):
  395. """违规停车类"""
  396. def __init__(self, vehicle_id: str, parking_location: str):
  397. super().__init__(vehicle_id)
  398. self.traffic_violations_type = " 违规停车类 "
  399. self.parking_location = parking_location
  400. self.city_road_violation_count = 0
  401. self.city_emergency_violation_count = 0
  402. self.highway_road_violation_count = 0
  403. self.highway_emergency_violation_count = 0
  404. def get_violation_info(self):
  405. return (
  406. f"{super().get_violation_info()} - 违规停车发生在 {self.parking_location}"
  407. )
  408. def city_road_violation(self):
  409. # 驾驶机动车在城市快速路行车道上违法停车的
  410. self.city_road_violation_count += 1
  411. return self.city_road_violation_count, True
  412. def city_emergency_violation(self):
  413. # 驾驶机动车非紧急情况下在城市快速路应急车道上停车的
  414. self.city_emergency_violation_count += 1
  415. return self.city_emergency_violation_count, True
  416. def highway_road_violation(self):
  417. # 驾驶机动车在高速公路行车道上违法停车的
  418. self.highway_road_violation_count += 1
  419. return self.highway_road_violation_count, True
  420. def highway_emergency_violation(self):
  421. # 驾驶机动车非紧急情况下在高速公路应急车道上停车的
  422. self.highway_emergency_violation_count += 1
  423. return self.highway_emergency_violation_count, True
  424. class TrafficLightViolation(TrafficViolation):
  425. """违反交通灯类"""
  426. def __init__(self, vehicle_id: str, violation_type: str, traffic_location: str):
  427. super().__init__(vehicle_id)
  428. self.traffic_violations_type = " 违反交通灯类 "
  429. self.violation_type = violation_type # 可能是 '闯红灯' 等
  430. self.traffic_location = traffic_location
  431. def get_violation_info(self):
  432. return f"{super().get_violation_info()} - {self.violation_type} 发生在 {self.traffic_location}"
  433. class ViolationManager:
  434. """违规管理类,用于管理所有违规行为"""
  435. def __init__(self, data_processed):
  436. self.violations = []
  437. self.data = data_processed.obj_data[1]
  438. self.ego_df = self.data[config.TRIFFIC_INFO].copy()
  439. self.SpeedingViolation = SpeedingViolation
  440. self.ParkingViolation = ParkingViolation
  441. self.TrafficLightViolation = TrafficLightViolation
  442. def add_violation(self, violation: TrafficViolation):
  443. self.violations.append(violation)
  444. def report_violations(self):
  445. for violation in self.violations:
  446. print(violation.get_violation_info())
  447. # 示例使用
  448. if __name__ == "__main__":
  449. manager = ViolationManager()
  450. # 添加超车违规
  451. overtaking_violation = OvertakingViolation("ABC123", "1号公路")
  452. manager.add_violation(overtaking_violation)
  453. # 添加逆行违规
  454. wrong_way_violation = WrongWayViolation("XYZ789", "2号公路")
  455. manager.add_violation(wrong_way_violation)
  456. # 添加超速违规
  457. speeding_violation = SpeedingViolation("LMN456", 85.0, 60.0)
  458. manager.add_violation(speeding_violation)
  459. # 添加违规停车
  460. parking_violation = ParkingViolation("DEF012", "商场停车场")
  461. manager.add_violation(parking_violation)
  462. # 添加违反交通灯行为
  463. traffic_light_violation = TrafficLightViolation("GHI345", "闯红灯", "第三街口")
  464. manager.add_violation(traffic_light_violation)
  465. # 报告所有违规行为
  466. manager.report_violations()