compliance.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. ##################################################################
  4. #
  5. # Copyright (c) 2023 CICV, Inc. All Rights Reserved
  6. #
  7. ##################################################################
  8. """
  9. @Authors: xieguijin(xieguijin@china-icv.cn), yangzihao(yangzihao@china-icv.cn)
  10. @Data: 2023/08/21
  11. @Last Modified: 2023/08/21
  12. @Summary: Compliance metrics
  13. """
  14. import sys
  15. sys.path.append('../common')
  16. sys.path.append('../modules')
  17. sys.path.append('../results')
  18. import numpy as np
  19. import pandas as pd
  20. from common import score_grade, string_concatenate, replace_key_with_value, score_over_100
  21. from scipy.spatial.distance import euclidean
  22. class traffic_rule(object):
  23. """
  24. Class for achieving compliance metrics for autonomous driving.
  25. Attributes:
  26. droadMark_df: Roadmark data, stored in dataframe format.
  27. """
  28. def __init__(self, data_processed, scoreModel):
  29. self.scoreModel = scoreModel
  30. self.roadMark_df = data_processed.road_mark_df
  31. self.trafficLight_df = data_processed.traffic_light_df
  32. self.trafficSignal_df = data_processed.traffic_signal_df
  33. self.objState_df = data_processed.object_df
  34. self.ego_df = self.objState_df[(self.objState_df.playerId == 1) & (self.objState_df.type == 1)]
  35. self.violation_df = pd.DataFrame(columns=['start_time', 'end_time', 'start_frame', 'end_frame', 'violation'])
  36. def _filter_groups_by_frame_period(self, grouped_violations):
  37. """
  38. Filter groups by a minimum continuous frame period.
  39. """
  40. CONTINUOUS_FRAME_PERIOD = 13
  41. return [g for g in grouped_violations if len(g[0]) >= CONTINUOUS_FRAME_PERIOD]
  42. def _extract_violation_times(self, filtered_groups):
  43. """
  44. Create a dataframe with start and end times for each violation group.
  45. """
  46. return [[g[0][0], g[0][-1]] for g in filtered_groups]
  47. def _filter_solid_lines(self, Dimy):
  48. """
  49. Filter solid lines within the player's lateral distance.
  50. """
  51. dist_line = self.roadMark_df[self.roadMark_df["type"] == 1]
  52. dist_line = dist_line.reset_index()
  53. return dist_line[abs(dist_line["lateralDist"].values) <= Dimy]
  54. def _group_violations(self, dist_press):
  55. """
  56. Group violations by continuous frames.
  57. """
  58. t_list = dist_press['simTime'].values.tolist()
  59. f_list = dist_press['simFrame'].values.tolist()
  60. group_time = []
  61. group_frame = []
  62. sub_group_time = []
  63. sub_group_frame = []
  64. for i in range(len(f_list)):
  65. if not sub_group_time or t_list[i] - t_list[i - 1] <= 1:
  66. sub_group_time.append(t_list[i])
  67. sub_group_frame.append(f_list[i])
  68. else:
  69. group_time.append(sub_group_time)
  70. group_frame.append(sub_group_frame)
  71. sub_group_time = [t_list[i]]
  72. sub_group_frame = [f_list[i]]
  73. group_time.append(sub_group_time)
  74. group_frame.append(sub_group_frame)
  75. return list(zip(group_time, group_frame))
  76. def get_solid_line_violations(self):
  77. """
  78. Detect violations of pressing solid lines and return a dictionary with violation details.
  79. """
  80. # Extract relevant data
  81. Dimy = self.objState_df[self.objState_df["playerId"] == 1]["dimY"][0] / 2
  82. dist_press = self._filter_solid_lines(Dimy)
  83. grouped_violations = self._group_violations(dist_press)# Group violations by continuous frames
  84. filtered_groups = self._filter_groups_by_frame_period(grouped_violations)# Filter groups by minimum frame period
  85. # Calculate violation count and create violation dataframe
  86. press_line_count = len(filtered_groups)
  87. press_line_time = self._extract_violation_times(filtered_groups)
  88. if press_line_time:
  89. time_df = pd.DataFrame(press_line_time, columns=['start_time', 'end_time'])
  90. time_df['violation'] = '压实线'
  91. # Update violation dataframe
  92. self.violation_df = pd.concat([self.violation_df, press_line_time], ignore_index=True)
  93. # Create and return violation dictionary
  94. warning_count = 0
  95. press_line_dict = {
  96. 'metric': 'pressSolidLine',
  97. 'weight': 3,
  98. 'illegal_count': press_line_count,
  99. 'penalty_points': press_line_count * 3,
  100. 'penalty_money': press_line_count * 200,
  101. 'warning_count': warning_count,
  102. 'penalty_law': '《中华人民共和国道路交通安全法》第八十二条:机动车在高速公路上行驶,不得有下列行为:(三)骑、轧车行道分界线或者在路肩上行驶。'
  103. }
  104. return press_line_dict
  105. def normalize_angle(self, angle):
  106. """Normalize angle to the range [0, 360)."""
  107. difference = angle
  108. while difference >= 360:
  109. difference -= 360
  110. return difference
  111. def is_red_light(self, simTime, cycleTime, duration_start, duration_end):
  112. """Check if the current time corresponds to a red light phase."""
  113. divisor = simTime / cycleTime
  114. decimal_part = divisor - int(divisor)
  115. return duration_start <= decimal_part < duration_end
  116. def process_traffic_light(self, trafficLight_id):
  117. """Process a single traffic light and detect run red light events."""
  118. trafficLight_position = self.trafficSignal_df[self.trafficSignal_df["playerId"] == trafficLight_id].iloc[:1, :]
  119. if trafficLight_position.empty:
  120. return
  121. trafficLight_position_x = trafficLight_position['posX'].values[0]
  122. trafficLight_position_y = trafficLight_position['posY'].values[0]
  123. trafficLight_position_heading = trafficLight_position['posH'].values[0]
  124. trafficLight_character = self.trafficLight_df[self.trafficLight_df.id == trafficLight_id]
  125. cycleTime = trafficLight_character["cycleTime"].values[0]
  126. noPhases = trafficLight_character["noPhases"].values[0]
  127. # Calculate distances and headings
  128. self.ego_df["traffic_light_distance_absolute"] = self.ego_df[['posX', 'posY']].apply(
  129. lambda x: euclidean((trafficLight_position_x, trafficLight_position_y), (x['posX'], x['posY'])), axis=1)
  130. self.ego_df["traffic_light_h_diff"] = self.ego_df.apply(
  131. lambda x: abs(x['posH'] - trafficLight_position_heading) * 57.3, axis=1)
  132. self.ego_df["traffic_light_h_diff"] = self.ego_df["traffic_light_h_diff"].apply(self.normalize_angle)
  133. # Filter ego vehicles near the traffic light with the correct heading
  134. mask_traffic_light = ((self.ego_df['traffic_light_h_diff'] <= 210) & (
  135. self.ego_df['traffic_light_h_diff'] >= 150)) | (self.ego_df['traffic_light_h_diff'] <= 30) | (
  136. self.ego_df['traffic_light_h_diff'] >= 330)
  137. ego_near_light = self.ego_df[(self.ego_df.traffic_light_distance_absolute <= 10) & mask_traffic_light]
  138. if ego_near_light.empty:
  139. return
  140. # Check for red light violations
  141. ego_near_light["flag_red_traffic_light"] = 0
  142. type_list = trafficLight_character['violation'][:noPhases]
  143. duration = trafficLight_character['duration'][:noPhases]
  144. duration_correct = [0] * noPhases
  145. for number in range(noPhases):
  146. duration_correct[number] = sum(duration[:number + 1])
  147. type_current = type_list.values[number]
  148. if type_current == 1: # Red light phase
  149. if number == 0:
  150. duration_start = 0
  151. else:
  152. duration_start = duration_correct[number - 1]
  153. duration_end = duration_correct[number]
  154. ego_near_light["flag_red_traffic_light"] = ego_near_light.apply(
  155. lambda x: self.is_red_light(x['simTime'], cycleTime, duration_start, duration_end), axis=1)
  156. # Collect run red light events
  157. run_red_light_df = ego_near_light[ego_near_light['flag_red_traffic_light'] == 1]
  158. self.collect_run_red_light_events(run_red_light_df)
  159. def collect_run_red_light_events(self, run_red_light_df):
  160. grouped_events = self._group_violations(run_red_light_df)
  161. filtered_events = self._filter_groups_by_frame_period(grouped_events)
  162. violation_times = self._extract_violation_times(filtered_events)
  163. if violation_times:
  164. time_df = pd.DataFrame(violation_times, columns=['start_time', 'end_time'])
  165. time_df['violation'] = '闯红灯'
  166. self.violation_df = pd.concat([self.violation_df, time_df], ignore_index=True)
  167. def run_red_light_detection(self):
  168. """Main function to detect run red light events."""
  169. trafficLight_id_list = set(self.trafficLight_df["id"].tolist())
  170. run_red_light_count = 0
  171. for trafficLight_id in trafficLight_id_list:
  172. self.process_traffic_light(trafficLight_id)
  173. # 闯红灯次数统计(这里可以根据需要修改统计逻辑)
  174. if 'flag_red_traffic_light' in self.ego_df.columns and self.ego_df['flag_red_traffic_light'].any() == 1:
  175. run_red_light_count += 1
  176. run_red_light_dict = {
  177. 'metric': 'runRedLight',
  178. 'weight': 6,
  179. 'illegal_count': run_red_light_count,
  180. 'penalty_points': run_red_light_count * 6,
  181. 'penalty_money': run_red_light_count * 200,
  182. 'warning_count': 0,
  183. 'penalty_law': '《中华人民共和国道路交通安全法实施条例》第四十条:(二)红色叉形灯或者箭头灯亮时,禁止本车道车辆通行。'
  184. }
  185. return run_red_light_dict
  186. def _find_speed_violations(self):
  187. DimX = self.objState_df[self.objState_df["playerId"] == 1]["dimY"][0] / 2
  188. data_ego = self.objState_df[self.objState_df["playerId"] == 1]
  189. speed_limit_sign = self.trafficSignal_df[self.trafficSignal_df["type"] == 274]
  190. same_df_rate = pd.merge(speed_limit_sign, data_ego, on=['simTime', 'simFrame'], how='inner').reset_index()
  191. speed_df = same_df_rate[(abs(same_df_rate["posX_x"] - same_df_rate["posX_y"]) <= 7) & (
  192. abs(same_df_rate["posY_x"] - same_df_rate["posY_y"]) <= DimX)]
  193. speed_df["speed"] = np.sqrt(speed_df["speedX"] ** 2 + speed_df["speedY"] ** 2) * 3.6
  194. list_sign = speed_df[speed_df["speed"] > speed_df["value"]]
  195. return list_sign, speed_df
  196. def _calculate_overspeed_statistics(self, speed_df, list_sign):
  197. index_sign = list_sign.index.to_list()
  198. speed_df["flag_press"] = speed_df["simFrame"].apply(lambda x: 1 if x in list_sign["simFrame"] else 0)
  199. speed_df["diff_press"] = speed_df["flag_press"].diff()
  200. index_list = []
  201. subindex_list = []
  202. for i in range(len(index_sign)):
  203. if not subindex_list or index_sign[i] - index_sign[i - 1] == 1:
  204. subindex_list.append(index_sign[i])
  205. else:
  206. index_list.append(subindex_list)
  207. subindex_list = [index_sign[i]]
  208. index_list.append(subindex_list)
  209. overspeed_count_0_to_10 = 0
  210. overspeed_count_10_to_20 = 0
  211. overspeed_count_20_to_50 = 0
  212. overspeed_count_50_to_ = 0
  213. if index_list[0]:
  214. for i in range(len(index_list)):
  215. left = index_list[i][0]
  216. right = index_list[i][-1]
  217. df_tmp = speed_df.loc[left:right + 1]
  218. max_ratio = ((df_tmp["speed"] - df_tmp["value"]) / df_tmp["value"]).max()
  219. if 0 <= max_ratio < 0.1:
  220. overspeed_count_0_to_10 += 1
  221. elif 0.1 <= max_ratio < 0.2:
  222. overspeed_count_10_to_20 += 1
  223. elif 0.2 <= max_ratio < 0.5:
  224. overspeed_count_20_to_50 += 1
  225. elif max_ratio >= 0.5:
  226. overspeed_count_50_to_ += 1
  227. return (
  228. self._create_overspeed_dict(overspeed_count_0_to_10, 'overspeed10', 0, 0),
  229. self._create_overspeed_dict(overspeed_count_10_to_20, 'overspeed10_20', 0, 200),
  230. self._create_overspeed_dict(overspeed_count_20_to_50, 'overspeed20_50', 6, 200),
  231. self._create_overspeed_dict(overspeed_count_50_to_, 'overspeed50', 12, 2000)
  232. )
  233. def _create_overspeed_dict(self, count, metric, penalty_points, penalty_money):
  234. return {
  235. 'metric': metric,
  236. 'weight': None,
  237. 'illegal_count': count,
  238. 'penalty_points': count * penalty_points,
  239. 'penalty_money': count * penalty_money,
  240. 'warning_count': count if penalty_points == 0 else 0,
  241. 'penalty_law': '《中华人民共和国道路交通安全法》第四十二条:机动车上道路行驶,不得超过限速标志标明的最高时速。'
  242. }
  243. def overspeed(self):
  244. list_sign, speed_df = self._find_speed_violations()
  245. grouped_events = self._group_violations(list_sign)
  246. filtered_events = self._filter_groups_by_frame_period(grouped_events)
  247. violation_times = self._extract_violation_times(filtered_events)
  248. if violation_times:
  249. time_df = pd.DataFrame([d[0] for d in violation_times], columns=['start_time', 'end_time'])
  250. time_df['violation'] = '超速'
  251. self.violation_df = pd.concat([self.violation_df, time_df], ignore_index=True)
  252. return self._calculate_overspeed_statistics(speed_df, list_sign)
  253. class Compliance(object):
  254. def __init__(self, data_processed, custom_data, scoreModel):
  255. self.eval_data = pd.DataFrame()
  256. self.penalty_points = 0
  257. self.config = data_processed.config
  258. compliance_config = self.config.config['compliance']
  259. self.compliance_config = compliance_config
  260. self.weight_dict = compliance_config['weight']
  261. self.metric_list = compliance_config['metric']
  262. self.type_list = compliance_config['type']
  263. print("self.type_list is", self.type_list)
  264. self.weight_custom = compliance_config['weightCustom']
  265. self.name_dict = compliance_config['name']
  266. self.metric_dict = compliance_config['typeMetricDict']
  267. self.type_name_dict = compliance_config['typeName']
  268. self.weight = compliance_config['weightDimension']
  269. self.weight_type_dict = compliance_config['typeWeight']
  270. self.weight_type_list = compliance_config['typeWeightList']
  271. self.type_illegal_count_dict = {}
  272. self.traffic_rule = traffic_rule(data_processed, scoreModel)
  273. self.violation_df = self.traffic_rule.violation_df
  274. def score_cal_penalty_points(self, penalty_points):
  275. if penalty_points == 0:
  276. score = 100
  277. elif penalty_points >= 12:
  278. score = 0
  279. else:
  280. score = (12 - penalty_points) / 12 * 60
  281. return score
  282. def time_splice(self, start_time, end_time):
  283. str_time = f"[{start_time}s, {end_time}s]"
  284. return str_time
  285. def weight_type_cal(self):
  286. # penalty_list = [1, 3, 6, 9, 12]
  287. penalty_list = [1, 3, 6, 12]
  288. sum_penalty = sum(penalty_list)
  289. weight_type_list = [round(x / sum_penalty, 2) for x in penalty_list]
  290. return weight_type_list
  291. def compliance_statistic(self):
  292. # metric analysis
  293. press_line_dict = self.traffic_rule.get_solid_line_violations()
  294. run_red_light_dict = self.traffic_rule.run_red_light_detection()
  295. overspeed_0_to_10_dict, overspeed_10_to_20_dict, overspeed_20_to_50_dict, overspeed_50_dict = self.traffic_rule.overspeed()
  296. df_list = []
  297. if "overspeed10" in self.metric_list:
  298. df_list.append(overspeed_0_to_10_dict)
  299. if "overspeed10_20" in self.metric_list:
  300. df_list.append(overspeed_10_to_20_dict)
  301. if "pressSolidLine" in self.metric_list:
  302. df_list.append(press_line_dict)
  303. if "runRedLight" in self.metric_list:
  304. df_list.append(run_red_light_dict)
  305. if "overspeed20_50" in self.metric_list:
  306. df_list.append(overspeed_20_to_50_dict)
  307. if "overspeed50" in self.metric_list:
  308. df_list.append(overspeed_50_dict)
  309. # generate dataframe and dicts
  310. compliance_df = pd.DataFrame(df_list)
  311. return compliance_df
  312. def prepare_data(self):
  313. self.compliance_df = self.compliance_statistic()
  314. self.illegal_count = int(self.compliance_df['illegal_count'].sum())
  315. self.metric_penalty_points_dict = self.compliance_df.set_index('metric').to_dict()['penalty_points']
  316. self.metric_illegal_count_dict = self.compliance_df.set_index('metric').to_dict()['illegal_count']
  317. self.metric_penalty_money_dict = self.compliance_df.set_index('metric').to_dict()['penalty_money']
  318. self.metric_warning_count_dict = self.compliance_df.set_index('metric').to_dict()['warning_count']
  319. self.metric_penalty_law_dict = self.compliance_df.set_index('metric').to_dict()['penalty_law']
  320. # 初始化数据字典
  321. self.illegal_count = int(self.compliance_df['illegal_count'].sum())
  322. self.metric_penalty_points_dict = self.compliance_df.set_index('metric').to_dict()['penalty_points']
  323. self.metric_illegal_count_dict = self.compliance_df.set_index('metric').to_dict()['illegal_count']
  324. self.metric_penalty_money_dict = self.compliance_df.set_index('metric').to_dict()['penalty_money']
  325. self.metric_warning_count_dict = self.compliance_df.set_index('metric').to_dict()['warning_count']
  326. self.metric_penalty_law_dict = self.compliance_df.set_index('metric').to_dict()['penalty_law']
  327. def calculate_deduct_scores(self):
  328. score_type_dict = {}
  329. deduct_functions = {
  330. "deduct1": self.calculate_deduct(1),
  331. "deduct3": self.calculate_deduct(3),
  332. "deduct6": self.calculate_deduct(6),
  333. "deduct9": self.calculate_deduct(9),
  334. "deduct12": self.calculate_deduct(12),
  335. }
  336. for deduct_type in self.type_list:
  337. print("deduct_type is", deduct_type)
  338. if deduct_type in deduct_functions:
  339. penalty_points, illegal_count = deduct_functions[deduct_type]
  340. score_type_dict[deduct_type] = self.score_cal_penalty_points(penalty_points)
  341. self.type_illegal_count_dict[deduct_type] = illegal_count
  342. return score_type_dict
  343. def calculate_deduct(self, num):
  344. deduct_df = self.compliance_df[(self.compliance_df['weight'].isna()) | (self.compliance_df['weight'] == num)]
  345. return deduct_df['penalty_points'].sum(), deduct_df['illegal_count'].sum()
  346. def calculate_weights(self):
  347. weight_dict = {
  348. "overspeed10": 0.5,
  349. "overspeed10_20": 0.5,
  350. "pressSolidLine": 1.0,
  351. "runRedLight": 0.5,
  352. "overspeed20_50": 0.5,
  353. "overspeed50": 1.0
  354. }
  355. self.weight_type_list = [weight_dict.get(metric, 0.5) for metric in
  356. self.type_list] # 假设 type_list 中的每个元素都在 weight_dict 的键中,或者默认为 0.5
  357. self.weight_type_dict = {key: value for key, value in zip(self.type_list, self.weight_type_list)}
  358. self.weight_dict = weight_dict
  359. def calculate_compliance_score(self, score_type_dict):
  360. if not self.weight_custom: # 客观赋权
  361. self.calculate_weights()
  362. penalty_points_threshold = 12 # 假设的扣分阈值
  363. if hasattr(self, 'penalty_points') and self.penalty_points >= penalty_points_threshold:
  364. score_compliance = 0
  365. elif sum(score_type_dict.values()) / len(score_type_dict) == 100:
  366. score_compliance = 100
  367. else:
  368. score_type_tmp = [80 if x == 100 else x for key, x in score_type_dict.items()]
  369. score_compliance = np.dot(self.weight_type_list, score_type_tmp)
  370. return round(score_compliance, 2)
  371. def output_results(self, score_compliance, score_type_dict):
  372. print("\n[合规性表现及得分情况]")
  373. print(f"合规性得分为:{score_compliance:.2f}分。")
  374. print(f"合规性各分组得分为:{score_type_dict}。")
  375. print(f"合规性各分组权重为:{self.weight_type_list}。")
  376. def compliance_score(self):
  377. self.prepare_data()
  378. score_type_dict = self.calculate_deduct_scores()
  379. score_compliance = self.calculate_compliance_score(score_type_dict)
  380. self.output_results(score_compliance, score_type_dict)
  381. return score_compliance, score_type_dict
  382. def _get_weight_distribution(self):
  383. # get weight distribution
  384. weight_distribution = {}
  385. weight_distribution["name"] = self.config.dimension_name["compliance"]
  386. for type in self.type_list:
  387. type_weight_indexes_dict = {key: f"{self.name_dict[key]}({value * 100:.2f}%)" for key, value in
  388. self.weight_dict.items() if
  389. key in self.metric_dict[type]}
  390. weight_distribution_type = {
  391. "weight": f"{self.type_name_dict[type]}({self.weight_type_dict[type] * 100:.2f}%)",
  392. "indexes": type_weight_indexes_dict
  393. }
  394. weight_distribution[type] = weight_distribution_type
  395. return weight_distribution
  396. def build_weight_index(self, metric_prefix, weight_key):
  397. if metric_prefix in self.metric_list:
  398. return {
  399. f"{metric_prefix}Weight": f"{metric_prefix.replace('_', ' ').capitalize()}({self.weight_dict[weight_key] * 100:.2f}%)"}
  400. return {}
  401. def compliance_weight_distribution(self):
  402. weight_distribution = {"name": "合规性"}
  403. deduct_types = {
  404. "deduct1": {"metrics": ["overspeed10", "overspeed10_20"], "weight_key": "deduct1"},
  405. "deduct3": {"metrics": ["pressSolidLine"], "weight_key": "deduct3"},
  406. "deduct6": {"metrics": ["runRedLight", "overspeed20_50"], "weight_key": "deduct6"},
  407. "deduct9": {"metrics": [], "weight_key": "deduct9"}, # Assuming no specific metrics for deduct9
  408. "deduct12": {"metrics": ["overspeed50"], "weight_key": "deduct12"},
  409. }
  410. for deduct_type, details in deduct_types.items():
  411. if deduct_type in self.type_list:
  412. indexes_dict = {}
  413. for metric in details["metrics"]:
  414. indexes_dict.update(self.build_weight_index(metric, details["weight_key"]))
  415. weight_distribution[deduct_type] = {
  416. f"{deduct_type.replace('deduct', '')}Weight":
  417. f"{details['weight_key'].replace('deduct', '').replace('_', ' ').capitalize()}"
  418. f"违规({int(details['weight_key'].replace('deduct', ''))}分)({self.weight_type_dict[deduct_type] * 100:.2f} % )",
  419. "indexes": indexes_dict
  420. }
  421. if deduct_type == "deduct9" and "deduct9" in self.type_list:
  422. weight_distribution[deduct_type] = {
  423. "deduct9Weight": f"严重违规(9分)({self.weight_type_dict[deduct_type] * 100:.2f}%)",
  424. "indexes": {}
  425. }
  426. return weight_distribution
  427. def report_statistic(self):
  428. score_compliance, score_type_dict = self.compliance_score()
  429. grade_compliance = score_grade(score_compliance)
  430. score_compliance = int(score_compliance) if int(score_compliance) == score_compliance else score_compliance
  431. score_type = [int(n) if int(n) == n else n for key, n in score_type_dict.items()]
  432. # 获取合规性描述
  433. comp_description1 = self.get_compliance_description(grade_compliance, self.illegal_count)
  434. comp_description2 = self.get_violation_details(self.type_list, self.type_illegal_count_dict,
  435. self.type_name_dict)
  436. weight_distribution = self._get_weight_distribution()
  437. # 获取违规数据表
  438. violations_slices = self.get_violations_table()
  439. # 获取扣分详情
  440. deductPoint_dict = self.get_deduct_points_dict(score_type_dict)
  441. # 返回结果(这里可以根据需要返回具体的数据结构)
  442. return {
  443. "name": "合规性",
  444. "weight": f"{self.weight * 100:.2f}%",
  445. "weightDistribution": weight_distribution,
  446. "score": score_compliance,
  447. "level": grade_compliance,
  448. 'score_type': score_type,
  449. # 'score_metric': score_metric,
  450. 'illegalCount': self.illegal_count,
  451. "description1": comp_description1,
  452. "description2": comp_description2,
  453. "details": deductPoint_dict,
  454. "violations": violations_slices
  455. }
  456. def get_compliance_description(self, grade_compliance, illegal_count):
  457. # 获取合规性描述
  458. if grade_compliance == '优秀':
  459. return '车辆在本轮测试中无违反交通法规行为;'
  460. else:
  461. return f'车辆在本轮测试中共发生{illegal_count}次违反交通法规行为;' \
  462. f'如果等级为一般或较差,需要提高算法在合规性上的表现。'
  463. def get_violation_details(self, type_list, type_illegal_count_dict, type_name_dict):
  464. # 获取违规详情描述
  465. if self.illegal_count == 0:
  466. return "车辆在该用例中无违反交通法规行为,算法表现良好。"
  467. else:
  468. str_illegel_type = ", ".join(
  469. [f"{type_name_dict[type]}行为{count}次" for type, count in type_illegal_count_dict.items() if
  470. count > 0])
  471. return f"车辆在该用例共违反交通法规{self.illegal_count}次。其中{str_illegel_type}。违规行为详情见附录C。"
  472. def get_violations_table(self):
  473. # 获取违规数据表
  474. if not self.violation_df.empty:
  475. self.violation_df['time'] = self.violation_df.apply(
  476. lambda row: self.time_splice(row['start_time'], row['end_time']), axis=1)
  477. df_violations = self.violation_df[['time', 'violation']]
  478. return df_violations.to_dict('records')
  479. else:
  480. return []
  481. def get_deduct_points_dict(self, score_type_dict):
  482. # 获取扣分详情字典
  483. deductPoint_dict = {}
  484. for type in self.type_list:
  485. type_dict = {
  486. "name": self.type_name_dict[type],
  487. "score": score_type_dict[type],
  488. "indexes": {}
  489. }
  490. for metric in self.metric_list:
  491. type_dict["indexes"][metric] = {
  492. "name": self.name_dict[metric],
  493. "times": self.metric_illegal_count_dict[metric],
  494. "deductPoints": self.metric_penalty_points_dict[metric],
  495. "fine": self.metric_penalty_money_dict[metric],
  496. "basis": self.metric_penalty_law_dict[metric]
  497. }
  498. deductPoint_dict[type] = type_dict
  499. for deduct_type, metrics in [
  500. ("deduct1", ["overspeed10", "overspeed10_20"]),
  501. ("deduct3", ["pressSolidLine"]),
  502. ("deduct6", ["runRedLight", "overspeed20_50"]),
  503. ("deduct9", ["xx"]), # 注意:这里xx是示例,实际应替换为具体的违规类型
  504. ("deduct12", ["overspeed50"])
  505. ]:
  506. if deduct_type in self.type_list:
  507. deduct_indexes = {}
  508. for metric in metrics:
  509. if metric in self.metric_list:
  510. deduct_indexes[metric] = {
  511. "name": self.name_dict[metric],
  512. "times": self.metric_illegal_count_dict[metric],
  513. "deductPoints": self.metric_penalty_points_dict[metric],
  514. "fine": self.metric_penalty_money_dict[metric],
  515. "basis": self.metric_penalty_law_dict[metric]
  516. }
  517. # 对于deduct9,特殊处理其score值
  518. score = score_type_dict.get(deduct_type, 100) if deduct_type != "deduct9" else 100
  519. deductPoint_dict[deduct_type] = {
  520. "name": f"{deduct_type.replace('deduct', '').strip()}违规({int(deduct_type.replace('deduct', ''))}分)",
  521. "score": score,
  522. "indexes": deduct_indexes
  523. }
  524. return deductPoint_dict
  525. def get_eval_data(self):
  526. df = self.eval_data
  527. return df