# score.py
  1. import json
  2. from modules.lib.log_manager import LogManager
  3. class Score:
  4. def __init__(self, yaml_config, module_name: str ):
  5. self.logger = LogManager().get_logger() # 获取全局日志实例
  6. self.calculated_metrics = None
  7. self.config = yaml_config
  8. self.module_config = None
  9. self.module_name = module_name
  10. self.t_threshold = None
  11. self.process_config(self.config)
  12. self.level_3_merics = self._extract_level_3_metrics(self.module_config)
  13. self.result = {}
  14. def process_config(self, config_dict):
  15. t_threshold = config_dict.get("T_threshold")
  16. if t_threshold is None:
  17. raise ValueError("配置中缺少 T_threshold 键")
  18. module_keys = [key for key in config_dict if key != "T_threshold"]
  19. # if len(module_keys) != 1:
  20. # raise ValueError("配置字典应包含且仅包含一个模块配置键")
  21. # module_name = module_keys[0]
  22. module_config = config_dict[self.module_name]
  23. # print(f'模块名称:{module_name}')
  24. # print(f'模块配置:{module_config}')
  25. # print(f'T_threshold:{t_threshold}')
  26. # 实际业务逻辑(示例:存储到对象属性)
  27. # self.module_name = module_name
  28. self.module_config = module_config
  29. self.t_threshold = t_threshold
  30. self.logger.info(f'模块名称:{self.module_name}')
  31. self.logger.info(f'模块配置:{self.module_config}')
  32. self.logger.info(f'T_threshold: {t_threshold}')
  33. def _extract_level_3_metrics(self, d):
  34. name = []
  35. for key, value in d.items():
  36. if isinstance(value, dict): # 如果值是字典,继续遍历
  37. self._extract_level_3_metrics(value)
  38. elif key == 'name': # 找到name键时,将值添加到列表
  39. name.append(value)
  40. return name
  41. def is_within_range(self, value, min_val, max_val):
  42. return min_val <= value <= max_val
  43. def evaluate_level_3(self, metrics):
  44. result3 = {}
  45. name = metrics.get('name')
  46. priority = metrics.get('priority')
  47. max_val = metrics.get('max')
  48. min_val = metrics.get('min')
  49. self.level_3_merics.append(name)
  50. print(f'name: {name}')
  51. print(f'self.calculated_metrics: {self.calculated_metrics}')
  52. metric_value = self.calculated_metrics.get(name)
  53. print(f'metric_value: {metric_value}')
  54. result3[name] = {
  55. 'result': True,
  56. 'priority': priority
  57. }
  58. if metric_value is None:
  59. return result3
  60. if not self.is_within_range(metric_value, min_val, max_val) and priority == 0:
  61. result3[name]['result'] = False
  62. elif not self.is_within_range(metric_value, min_val, max_val) and priority == 1:
  63. result3[name]['priority_1_count'] += 1
  64. # Count priority 1 failures and override result if more than 3
  65. priority_1_metrics = [v for v in result3.values() if v['priority'] == 1 and not v['result']]
  66. if len([m for m in priority_1_metrics if not m['result']]) > 3:
  67. result3[name]['result'] = False
  68. return result3
  69. def evaluate_level_2(self, metrics):
  70. result2 = {}
  71. name = metrics.get('name')
  72. priority = metrics.get('priority')
  73. result2[name] = {}
  74. for metric, sub_metrics in metrics.items():
  75. if metric not in ['name', 'priority']:
  76. result2[name].update(self.evaluate_level_3(sub_metrics))
  77. # Aggregate results for level 2 config.T0 config.T1 config.T2
  78. priority_0_count = sum(1 for v in result2[name].values() if v['priority'] == 0 and not v['result'])
  79. priority_1_count = sum(1 for v in result2[name].values() if v['priority'] == 1 and not v['result'])
  80. priority_2_count = sum(1 for v in result2[name].values() if v['priority'] == 2 and not v['result'])
  81. if priority_0_count > self.t_threshold['T0_threshold']:
  82. result2[name]['result'] = False
  83. elif priority_1_count > self.t_threshold['T1_threshold']:
  84. for metric in result2[name].values():
  85. metric['result'] = False
  86. elif priority_2_count > self.t_threshold['T2_threshold']:
  87. for metric in result2[name].values():
  88. metric['result'] = False
  89. else:
  90. result2[name]['result'] = True # Default to True unless overridden
  91. result2[name]['priority'] = priority
  92. result2[name]['priority_0_count'] = priority_0_count
  93. result2[name]['priority_1_count'] = priority_1_count
  94. result2[name]['priority_2_count'] = priority_2_count
  95. return result2
  96. def evaluate_level_1(self):
  97. name = self.module_config.get('name')
  98. priority = self.module_config.get('priority')
  99. result1 = {}
  100. result1[name] = {}
  101. for metric, metrics in self.module_config.items():
  102. if metric not in ['name', 'priority']:
  103. result1[name].update(self.evaluate_level_2(metrics))
  104. # Aggregate results for level 2 config.T0 config.T1 config.T2
  105. priority_0_count = sum(1 for v in result1[name].values() if v['priority'] == 0 and not v['result'])
  106. priority_1_count = sum(1 for v in result1[name].values() if v['priority'] == 1 and not v['result'])
  107. priority_2_count = sum(1 for v in result1[name].values() if v['priority'] == 2 and not v['result'])
  108. if priority_0_count > self.t_threshold['T0_threshold']:
  109. result1[name]['result'] = False
  110. elif priority_1_count > self.t_threshold['T1_threshold']:
  111. for metric in result1[name].values():
  112. metric['result'] = False
  113. elif priority_2_count > self.t_threshold['T2_threshold']:
  114. for metric in result1[name].values():
  115. metric['result'] = False
  116. else:
  117. result1[name]['result'] = True # Default to True unless overridden
  118. result1[name]['priority'] = priority
  119. result1[name]['priority_0_count'] = priority_0_count
  120. result1[name]['priority_1_count'] = priority_1_count
  121. result1[name]['priority_2_count'] = priority_2_count
  122. return result1
  123. def evaluate(self, calculated_metrics):
  124. self.calculated_metrics = calculated_metrics
  125. self.result = self.evaluate_level_1()
  126. return self.result
  127. def evaluate_single_case(self, case_name, priority, json_dict):
  128. name = case_name
  129. result = {}
  130. result[name] = {}
  131. # print(json_dict)
  132. # Aggregate results for level 2 config.T0 config.T1 config.T2
  133. priority_0_count = sum(1 for v in json_dict.values() if v['priority'] == 0 and not v['result'])
  134. priority_1_count = sum(1 for v in json_dict.values() if v['priority'] == 1 and not v['result'])
  135. priority_2_count = sum(1 for v in json_dict.values() if v['priority'] == 2 and not v['result'])
  136. if priority_0_count > config.T0:
  137. result[name]['result'] = False
  138. elif priority_1_count > config.T1:
  139. for metric in result[name].values():
  140. metric['result'] = False
  141. elif priority_2_count > config.T2:
  142. for metric in result[name].values():
  143. metric['result'] = False
  144. else:
  145. result[name]['result'] = True # Default to True unless overridden
  146. result[name]['priority'] = priority
  147. result[name]['priority_0_count'] = priority_0_count
  148. result[name]['priority_1_count'] = priority_1_count
  149. result[name]['priority_2_count'] = priority_2_count
  150. result[case_name].update(json_dict)
  151. return result
  152. def evaluate_single_case_back(case_name, priority, json_dict):
  153. """对单个案例进行评估"""
  154. result = {case_name: {}}
  155. priority_counts = {priority: sum(1 for v in json_dict.values() if v['priority'] == priority and not v['result'])
  156. for priority in [0, 1, 2]}
  157. if priority_counts[0] > config.T0:
  158. result[case_name]['result'] = False
  159. elif priority_counts[1] > config.T1:
  160. for metric in result[case_name].values():
  161. metric['result'] = False
  162. elif priority_counts[2] > config.T2:
  163. for metric in result[case_name].values():
  164. metric['result'] = False
  165. else:
  166. result[case_name]['result'] = True
  167. result[case_name].update(priority_counts)
  168. result[case_name].update(json_dict) # 合并原始数据
  169. return result
  170. def main():
  171. # config_path = r'/home/kevin/kevin/zhaoyuan/evaluate_zhaoyuan/models/safety/config.yaml'
  172. # config_path1 = r'/home/kevin/kevin/zhaoyuan/evaluate_zhaoyuan/models/safety/config.json'
  173. # calculated_metrics = {
  174. # 'TTC': 1.0,
  175. # 'MTTC': 1.0,
  176. # 'THW': 1.0,
  177. # 'LonSD': 50.0,
  178. # 'LatSD': 3.0,
  179. # 'DRAC': 3.0,
  180. # 'BTN': -1000.0,
  181. # 'STN': 0.5,
  182. # 'collisionRisk': 5.0,
  183. # 'collisionSeverity': 2.0,
  184. # }
  185. # # evaluator = Score(config_path, calculated_metrics)
  186. # evaluator = Score(config_path)
  187. # result = evaluator.evaluate(calculated_metrics)
  188. # with open(config_path1, 'w') as json_file:
  189. # json.dump(result, json_file, indent=4) # `indent` 参数用于美化输出
  190. config_path = r'/home/kevin/kevin/zhaoyuan/zhaoyuan/models/caseMetric/single_config.yaml'
  191. config_path1 = r'/home/kevin/kevin/zhaoyuan/zhaoyuan/result/data_zhaoyuan/data_zhaoyuan_single_report.json'
  192. # evaluator = Score(config_path, calculated_metrics)
  193. with open(config_path1, 'r') as file:
  194. data = json.load(file)
  195. result = evaluate_single_case("case1", 0, data)
  196. print(result)
  197. if __name__ == '__main__':
  198. main()