123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255 |
- import json
- from typing import Dict, Any
- from modules.lib.log_manager import LogManager
- from modules.config import config
class Score:
    """Evaluate calculated metrics against a three-level threshold config.

    The configuration dict must contain a ``T_threshold`` entry plus exactly
    one module entry.  The module entry is a nested dict: the module itself
    (level 1), its groups (level 2) and the individual metrics (level 3).
    Every level carries ``name`` and ``priority``; level-3 entries also carry
    ``min`` and ``max`` bounds.

    Each level's outcome is stored as a dict that mixes the child results
    (keyed by child name) with the summary keys ``result``, ``priority`` and
    ``priority_{0,1,2}_count``.
    """

    def __init__(self, yaml_config):
        """Store the configuration and pre-compute the metric name list.

        Args:
            yaml_config: Parsed configuration dict (see class docstring).

        Raises:
            ValueError: Propagated from ``process_config`` on a bad config.
        """
        self.logger = LogManager().get_logger()  # project-wide logger instance
        self.calculated_metrics = None
        self.config = yaml_config
        self.module_config = None
        self.module_name = None
        self.t_threshold = None
        self.process_config(self.config)
        # NOTE: the attribute name keeps its historical spelling ("merics")
        # because external code may read it.
        self.level_3_merics = self._extract_level_3_metrics(self.module_config)
        self.result = {}

    def process_config(self, config_dict):
        """Split the config into thresholds and the single module section.

        Args:
            config_dict: Dict holding ``T_threshold`` and exactly one other
                key, whose value is the module configuration.

        Raises:
            ValueError: If ``T_threshold`` is missing, or if the number of
                remaining keys is not exactly one.
        """
        t_threshold = config_dict.get("T_threshold")
        if t_threshold is None:
            raise ValueError("配置中缺少 T_threshold 键")

        module_keys = [key for key in config_dict if key != "T_threshold"]
        if len(module_keys) != 1:
            raise ValueError("配置字典应包含且仅包含一个模块配置键")

        module_name = module_keys[0]
        module_config = config_dict[module_name]

        self.module_name = module_name
        self.module_config = module_config
        self.t_threshold = t_threshold
        self.logger.info(f'模块名称:{module_name}')
        self.logger.info(f'模块配置:{module_config}')
        self.logger.info(f'T_threshold:{t_threshold}')

    def _extract_level_3_metrics(self, d):
        """Collect every ``name`` value found anywhere in a nested dict.

        Names are returned depth-first in dict iteration order.  The walk
        does not distinguish levels, so names of the module and of its
        groups are included alongside the leaf metric names.

        Args:
            d: Nested configuration dict.

        Returns:
            List of the values stored under ``name`` keys.
        """
        names = []
        for key, value in d.items():
            if isinstance(value, dict):
                # Keep what the nested walk found; dropping this return
                # value would lose every name below the top level.
                names.extend(self._extract_level_3_metrics(value))
            elif key == 'name':
                names.append(value)
        return names

    def is_within_range(self, value, min_val, max_val):
        """Return True when ``min_val <= value <= max_val``.

        A bound of ``None`` (e.g. ``min``/``max`` absent from the config) is
        treated as "no limit on that side" instead of raising ``TypeError``.
        """
        lower_ok = min_val is None or min_val <= value
        upper_ok = max_val is None or value <= max_val
        return lower_ok and upper_ok

    @staticmethod
    def _count_failures(children):
        """Count failed children per priority; returns ``(p0, p1, p2)``."""
        return tuple(
            sum(1 for v in children.values() if v['priority'] == level and not v['result'])
            for level in (0, 1, 2)
        )

    @staticmethod
    def _summarize(passed, priority, counts):
        """Build the summary keys shared by every aggregated node."""
        return {
            'result': passed,
            'priority': priority,
            'priority_0_count': counts[0],
            'priority_1_count': counts[1],
            'priority_2_count': counts[2],
        }

    def evaluate_level_3(self, metrics):
        """Check one leaf metric against its configured bounds.

        A metric with no calculated value passes.  A metric whose value is
        outside ``[min, max]`` fails regardless of its priority; whether that
        failure matters is decided by the per-priority thresholds applied at
        the level above.

        Args:
            metrics: Leaf config dict with ``name``, ``priority``, ``min``
                and ``max``.

        Returns:
            ``{name: {'result': bool, 'priority': priority}}``.
        """
        name = metrics.get('name')
        priority = metrics.get('priority')

        # Register the metric once; repeated evaluate() calls must not grow
        # the list with duplicates.
        if name not in self.level_3_merics:
            self.level_3_merics.append(name)

        verdict = {'result': True, 'priority': priority}
        metric_value = self.calculated_metrics.get(name)
        if metric_value is not None and not self.is_within_range(
            metric_value, metrics.get('min'), metrics.get('max')
        ):
            verdict['result'] = False
        return {name: verdict}

    def evaluate_level_2(self, metrics):
        """Evaluate one group: run its leaf metrics and aggregate them.

        The group fails when the number of failed leaves of priority 0, 1 or
        2 exceeds ``T0_threshold``, ``T1_threshold`` or ``T2_threshold``
        respectively.  Leaf results are left untouched.

        Args:
            metrics: Group config dict; every key other than ``name`` and
                ``priority`` is expected to hold a leaf metric dict.

        Returns:
            ``{group_name: {<leaf results>, 'result', 'priority',
            'priority_0_count', 'priority_1_count', 'priority_2_count'}}``.
        """
        name = metrics.get('name')
        priority = metrics.get('priority')

        children = {}
        for key, sub_metrics in metrics.items():
            if key not in ('name', 'priority'):
                children.update(self.evaluate_level_3(sub_metrics))

        counts = self._count_failures(children)
        passed = not (
            counts[0] > self.t_threshold['T0_threshold']
            or counts[1] > self.t_threshold['T1_threshold']
            or counts[2] > self.t_threshold['T2_threshold']
        )
        node = dict(children)
        node.update(self._summarize(passed, priority, counts))
        return {name: node}

    def evaluate_level_1(self):
        """Evaluate the whole module by aggregating its level-2 groups.

        Uses the same threshold rule as ``evaluate_level_2``, applied to the
        groups' own ``result``/``priority`` values.

        Returns:
            ``{module_name: {<group results>, 'result', 'priority',
            'priority_0_count', 'priority_1_count', 'priority_2_count'}}``.
        """
        name = self.module_config.get('name')
        priority = self.module_config.get('priority')

        children = {}
        for key, group in self.module_config.items():
            if key not in ('name', 'priority'):
                children.update(self.evaluate_level_2(group))

        counts = self._count_failures(children)
        passed = not (
            counts[0] > self.t_threshold['T0_threshold']
            or counts[1] > self.t_threshold['T1_threshold']
            or counts[2] > self.t_threshold['T2_threshold']
        )
        node = dict(children)
        node.update(self._summarize(passed, priority, counts))
        return {name: node}

    def evaluate(self, calculated_metrics):
        """Evaluate ``calculated_metrics`` (name -> value) for the module.

        The result is also stored on ``self.result``.
        """
        self.calculated_metrics = calculated_metrics
        self.result = self.evaluate_level_1()
        return self.result

    def evaluate_single_case(self, case_name, priority, json_dict):
        """Aggregate already-evaluated entries into one case-level result.

        Unlike the level-1/level-2 aggregation, the limits are read from the
        imported ``config`` object (``config.T0``, ``config.T1``,
        ``config.T2``) rather than from ``self.t_threshold``.

        Args:
            case_name: Key under which the result is returned.
            priority: Priority recorded for the case.
            json_dict: Mapping whose values each expose ``priority`` and
                ``result``.

        Returns:
            ``{case_name: {'result', 'priority', 'priority_0_count',
            'priority_1_count', 'priority_2_count', **json_dict}}``.
        """
        counts = self._count_failures(json_dict)
        passed = not (
            counts[0] > config.T0
            or counts[1] > config.T1
            or counts[2] > config.T2
        )
        node = self._summarize(passed, priority, counts)
        # Child entries are merged last, as before.
        node.update(json_dict)
        return {case_name: node}
- import yaml
def load_thresholds(config_path: str) -> Dict[str, int]:
    """Load the T0/T1/T2 threshold values from a YAML configuration file.

    Args:
        config_path: Path of the YAML file.  It must contain a
            ``T_threshold`` mapping with the keys ``T0_threshold``,
            ``T1_threshold`` and ``T2_threshold``.

    Returns:
        Dict with the keys ``"T0"``, ``"T1"`` and ``"T2"``.

    Raises:
        KeyError: If ``T_threshold`` or one of its three keys is missing.
    """
    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # default locale encoding when the file contains non-ASCII text.
    with open(config_path, 'r', encoding='utf-8') as f:
        # Local name deliberately differs from the module-level ``config``.
        loaded = yaml.safe_load(f)
    return {
        "T0": loaded['T_threshold']['T0_threshold'],
        "T1": loaded['T_threshold']['T1_threshold'],
        "T2": loaded['T_threshold']['T2_threshold'],
    }
def get_overall_result(report: Dict[str, Any], config_path: str) -> Dict[str, Any]:
    """Annotate an evaluation report with an overall pass/fail verdict.

    For each of the dimensions ``function``, ``safety``, ``comfort``,
    ``traffic`` and ``efficient`` present in ``report``, a failed entry
    (``result`` is falsy) is counted under its ``priority`` (0, 1 or 2).
    The overall result is True only if none of the three counts exceeds the
    corresponding threshold loaded from ``config_path``.

    Args:
        report: Original evaluation report.
        config_path: Path of the YAML file holding the thresholds.

    Returns:
        A shallow copy of ``report`` with ``overall_result`` and
        ``threshold_checks`` added.
    """
    limits = load_thresholds(config_path)

    failures = {'p0': 0, 'p1': 0, 'p2': 0}
    counter_for_priority = ((0, 'p0'), (1, 'p1'), (2, 'p2'))

    for dimension in ('function', 'safety', 'comfort', 'traffic', 'efficient'):
        if dimension not in report:
            continue
        entry = report[dimension]
        # A missing 'result' counts as a pass.
        if entry.get('result', True):
            continue
        level = entry.get('priority')
        for candidate, counter_key in counter_for_priority:
            if level == candidate:
                failures[counter_key] += 1
                break

    # All three comparisons are evaluated before combining them.
    exceeded = [
        failures[counter_key] > limits[limit_key]
        for counter_key, limit_key in (('p0', 'T0'), ('p1', 'T1'), ('p2', 'T2'))
    ]

    annotated = report.copy()
    annotated['overall_result'] = not any(exceeded)
    annotated['threshold_checks'] = {
        'T0_threshold': limits['T0'],
        'T1_threshold': limits['T1'],
        'T2_threshold': limits['T2'],
        'actual_counts': failures,
    }
    return annotated
-
-
def main():
    """Script entry point; currently a placeholder that does nothing."""
    return None


if __name__ == "__main__":
    main()
|