score.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. import json
  2. from typing import Dict, Any
  3. from modules.lib.log_manager import LogManager
  4. from modules.config import config
  class Score:
      """Hierarchical pass/fail scorer for a single evaluation module.

      ``yaml_config`` must contain a ``T_threshold`` mapping (read here via the
      ``T0_threshold``/``T1_threshold``/``T2_threshold`` keys) plus exactly one
      module key.  The module config is walked as a three-level tree:

      * level 1 - the module: ``name``, ``priority`` and level-2 groups;
      * level 2 - a group: ``name``, ``priority`` and level-3 metrics;
      * level 3 - a metric: ``name``, ``priority``, ``min`` and ``max``.

      At levels 1 and 2 every key other than ``name``/``priority`` is treated as
      a child node.  A level-3 metric passes when its calculated value lies in
      ``[min, max]`` (or is missing from the calculated metrics).  A level-1/2
      node fails when the number of failed children with priority 0, 1 or 2
      exceeds the T0, T1 or T2 threshold respectively.
      """

      def __init__(self, yaml_config):
          """Parse ``yaml_config`` and prepare an evaluator for its single module.

          Raises:
              ValueError: propagated from ``process_config`` when the
                  configuration is malformed.
          """
          self.logger = LogManager().get_logger()  # global logger instance
          self.calculated_metrics = None  # set by evaluate()
          self.config = yaml_config
          self.module_config = None
          self.module_name = None
          self.t_threshold = None
          self.process_config(self.config)
          # The misspelled attribute name ("merics") is kept so existing callers keep working.
          self.level_3_merics = self._extract_level_3_metrics(self.module_config)
          self.result = {}

      def process_config(self, config_dict):
          """Split ``config_dict`` into the threshold block and the single module config.

          Sets ``self.module_name``, ``self.module_config`` and ``self.t_threshold``
          and logs all three.

          Raises:
              ValueError: if ``T_threshold`` is missing, or if there is not exactly
                  one other (module) key.
          """
          t_threshold = config_dict.get("T_threshold")
          if t_threshold is None:
              raise ValueError("配置中缺少 T_threshold 键")
          module_keys = [key for key in config_dict if key != "T_threshold"]
          if len(module_keys) != 1:
              raise ValueError("配置字典应包含且仅包含一个模块配置键")
          module_name = module_keys[0]
          module_config = config_dict[module_name]
          self.module_name = module_name
          self.module_config = module_config
          self.t_threshold = t_threshold
          self.logger.info(f'模块名称:{module_name}')
          self.logger.info(f'模块配置:{module_config}')
          self.logger.info(f'T_threshold:{t_threshold}')

      def _extract_level_3_metrics(self, d):
          """Recursively collect every ``name`` value found in the nested dict ``d``.

          Names are returned in depth-first dict-iteration order, so the module and
          group names appear alongside the level-3 metric names.
          """
          names = []
          for key, value in d.items():
              if isinstance(value, dict):
                  # Keep the names found in the sub-tree (previously they were discarded).
                  names.extend(self._extract_level_3_metrics(value))
              elif key == 'name':
                  names.append(value)
          return names

      def is_within_range(self, value, min_val, max_val):
          """Return True if ``min_val <= value <= max_val``.

          A ``None`` bound is treated as unbounded on that side instead of raising
          ``TypeError``.  Comparisons are written so that a NaN value still fails.
          """
          return ((min_val is None or min_val <= value)
                  and (max_val is None or value <= max_val))

      def evaluate_level_3(self, metrics):
          """Evaluate one level-3 metric against its ``[min, max]`` range.

          Returns ``{name: {'result': bool, 'priority': priority}}``.  ``result`` is
          False whenever the calculated value is out of range, whatever its
          priority.  The parent levels count failures per priority and compare them
          with the T0/T1/T2 thresholds, so every priority must be able to fail here.
          A metric with no calculated value is treated as passing.
          """
          name = metrics.get('name')
          priority = metrics.get('priority')
          # Record each name once; repeated evaluate() calls used to append duplicates.
          if name not in self.level_3_merics:
              self.level_3_merics.append(name)
          metric_value = self.calculated_metrics.get(name)
          if metric_value is None:
              passed = True  # nothing calculated for this metric -> nothing to judge
          else:
              passed = self.is_within_range(metric_value, metrics.get('min'), metrics.get('max'))
          return {name: {'result': passed, 'priority': priority}}

      @staticmethod
      def _count_failures(children):
          """Return ``(p0, p1, p2)``: how many children of priority 0/1/2 failed."""
          return tuple(
              sum(1 for child in children.values()
                  if child['priority'] == level and not child['result'])
              for level in (0, 1, 2)
          )

      @classmethod
      def _summarize(cls, children, priority, thresholds, cascade_to_children):
          """Build the aggregate fields for a node whose children are already evaluated.

          Args:
              children: Mapping of child name -> child result dict (each carrying
                  ``priority`` and ``result``).
              priority: The node's own priority.
              thresholds: ``(T0, T1, T2)``, the maximum allowed failures per priority.
              cascade_to_children: When True and the T1 or T2 limit is exceeded,
                  every child's ``result`` is also set to False (the existing
                  behaviour of the level-1/level-2 evaluators).

          Returns:
              Dict with ``result``, ``priority`` and ``priority_{0,1,2}_count``.
              ``result`` is always present.
          """
          p0, p1, p2 = cls._count_failures(children)
          t0, t1, t2 = thresholds
          if p0 > t0:
              passed = False
          elif p1 > t1 or p2 > t2:
              passed = False
              if cascade_to_children:
                  for child in children.values():
                      child['result'] = False
          else:
              passed = True
          return {
              'result': passed,
              'priority': priority,
              'priority_0_count': p0,
              'priority_1_count': p1,
              'priority_2_count': p2,
          }

      def _yaml_thresholds(self):
          """Return ``(T0, T1, T2)`` from the ``T_threshold`` block of the YAML config."""
          t = self.t_threshold
          return t['T0_threshold'], t['T1_threshold'], t['T2_threshold']

      def evaluate_level_2(self, metrics):
          """Evaluate a level-2 group and all of its level-3 metrics.

          Returns ``{group_name: {<metric results>..., 'result', 'priority',
          'priority_0_count', 'priority_1_count', 'priority_2_count'}}``.
          """
          name = metrics.get('name')
          priority = metrics.get('priority')
          children = {}
          for key, sub_metrics in metrics.items():
              if key not in ('name', 'priority'):
                  children.update(self.evaluate_level_3(sub_metrics))
          # 'result' is now always set; it used to be missing when T1/T2 was exceeded.
          children.update(self._summarize(children, priority, self._yaml_thresholds(),
                                          cascade_to_children=True))
          return {name: children}

      def evaluate_level_1(self):
          """Evaluate the whole module (level 1) from ``self.module_config``.

          Returns ``{module_name: {<group results>..., 'result', 'priority',
          'priority_0_count', 'priority_1_count', 'priority_2_count'}}``.
          """
          name = self.module_config.get('name')
          priority = self.module_config.get('priority')
          groups = {}
          for key, group_config in self.module_config.items():
              if key not in ('name', 'priority'):
                  groups.update(self.evaluate_level_2(group_config))
          groups.update(self._summarize(groups, priority, self._yaml_thresholds(),
                                        cascade_to_children=True))
          return {name: groups}

      def evaluate(self, calculated_metrics):
          """Evaluate ``calculated_metrics`` (metric name -> value) for the module.

          The nested result is stored on ``self.result`` and returned.
          """
          self.calculated_metrics = calculated_metrics
          self.result = self.evaluate_level_1()
          return self.result

      def evaluate_single_case(self, case_name, priority, json_dict):
          """Aggregate the already-evaluated entries of one case into a verdict.

          ``json_dict`` maps entry names to dicts carrying ``priority`` and
          ``result``.  It is not modified.  Thresholds come from the imported
          ``config`` module (``config.T0``/``T1``/``T2``), not from
          ``self.t_threshold``.  NOTE(review): confirm this difference is intended.

          Returns ``{case_name: {'result', 'priority', 'priority_{0,1,2}_count',
          <json_dict entries>...}}``.  ``result`` is now also set (to False) when
          the T1/T2 limit is exceeded.
          """
          summary = self._summarize(json_dict, priority, (config.T0, config.T1, config.T2),
                                    cascade_to_children=False)
          summary.update(json_dict)
          return {case_name: summary}
  151. import yaml
  def load_thresholds(config_path: str) -> Dict[str, int]:
      """Load the T0/T1/T2 failure-count thresholds from a YAML config file.

      Args:
          config_path: Path to a YAML file with a ``T_threshold`` mapping that
              contains ``T0_threshold``, ``T1_threshold`` and ``T2_threshold``.

      Returns:
          ``{"T0": ..., "T1": ..., "T2": ...}`` copied from that mapping.

      Raises:
          KeyError: if the ``T_threshold`` section or one of its keys is missing.
              This is also raised for an empty file, which ``yaml.safe_load``
              parses as ``None``.
      """
      # Explicit UTF-8: the platform default encoding (e.g. cp1252/cp936 on
      # Windows) can fail to decode non-ASCII content in the config file.
      with open(config_path, 'r', encoding='utf-8') as f:
          # Named ``cfg`` so it does not shadow the module-level ``config`` import.
          cfg = yaml.safe_load(f) or {}
      section = cfg['T_threshold']
      return {
          "T0": section['T0_threshold'],
          "T1": section['T1_threshold'],
          "T2": section['T2_threshold'],
      }
  def get_overall_result(report: Dict[str, Any], config_path: str) -> Dict[str, Any]:
      """Return a copy of ``report`` with an overall pass/fail verdict attached.

      Args:
          report: Evaluation report.  The entries for the ``function``, ``safety``,
              ``comfort``, ``traffic`` and ``efficient`` categories are read for
              their ``result`` (a missing ``result`` counts as passing) and their
              ``priority``.
          config_path: YAML file from which the T0/T1/T2 thresholds are loaded.

      Returns:
          A shallow copy of ``report`` with two extra keys:

          * ``overall_result``: False if the number of failed categories for any
            priority 0/1/2 exceeds its threshold.
          * ``threshold_checks``: the thresholds plus the actual failure counts.
      """
      thresholds = load_thresholds(config_path)
      counters = {'p0': 0, 'p1': 0, 'p2': 0}
      # (priority value, counter key, threshold key) for each priority level.
      levels = ((0, 'p0', 'T0'), (1, 'p1', 'T1'), (2, 'p2', 'T2'))

      for category in ('function', 'safety', 'comfort', 'traffic', 'efficient'):
          if category not in report:
              continue
          section = report[category]
          if section.get('result', True):
              continue  # passed, or no verdict recorded
          priority = section.get('priority')
          for level, counter_key, _ in levels:
              if priority == level:
                  counters[counter_key] += 1
                  break

      # All three comparisons are evaluated up front, before any() is applied.
      over_limit = [counters[counter_key] > thresholds[threshold_key]
                    for _, counter_key, threshold_key in levels]

      processed_report = report.copy()
      processed_report['overall_result'] = not any(over_limit)
      processed_report['threshold_checks'] = {
          'T0_threshold': thresholds['T0'],
          'T1_threshold': thresholds['T1'],
          'T2_threshold': thresholds['T2'],
          'actual_counts': counters,
      }
      return processed_report
  def main():
      """Script entry point; it currently does nothing."""
      return None


  if __name__ == '__main__':
      main()