#!/usr/bin/env python3
# evaluator_enhanced.py
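"""Enhanced evaluation entry point.

Loads built-in and custom metric modules according to the YAML configs,
runs them in parallel against a preprocessed data case, and writes a JSON
report containing per-category scores and an overall pass/fail result.

Typical invocation:
    python evaluator_enhanced.py --dataPath <case_dir> --reportPath reports
"""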
import sys
import warnings
import time
import importlib
import importlib.util
import yaml
from pathlib import Path
import argparse
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from typing import Dict, Any, List, Optional, Type, Tuple, Callable, Union
from datetime import datetime
import logging
import traceback
import json
import inspect

# Constants
DEFAULT_WORKERS = 4
CUSTOM_METRIC_PREFIX = "metric_"
CUSTOM_METRIC_FILE_PATTERN = "*.py"

# Set the project root path safely (sys._MEIPASS exists when running as a frozen bundle)
if hasattr(sys, "_MEIPASS"):
    _ROOT_PATH = Path(sys._MEIPASS)
else:
    _ROOT_PATH = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(_ROOT_PATH))


class ConfigManager:
    """Configuration management component."""

    def __init__(self, logger: logging.Logger):
        self.logger = logger
        self.base_config: Dict[str, Any] = {}
        self.custom_config: Dict[str, Any] = {}
        self.merged_config: Dict[str, Any] = {}
        self._config_cache = {}

    def split_configs(self, all_metrics_path: Path, builtin_metrics_path: Path, custom_metrics_path: Path) -> None:
        """Split all_metrics_config.yaml into built-in and custom configs."""
        # Skip the split if the extracted config files already exist
        extracted_builtin_path = builtin_metrics_path.parent / f"{builtin_metrics_path.stem}_extracted{builtin_metrics_path.suffix}"
        if extracted_builtin_path.exists() and custom_metrics_path.exists():
            self.logger.info(f"Using existing split config file: {extracted_builtin_path}")
            return
        try:
            # Load config files through the cache to avoid repeated reads
            all_metrics_dict = self._safe_load_config(all_metrics_path)
            builtin_metrics_dict = self._safe_load_config(builtin_metrics_path)
            # Recursively split into built-in and custom metrics
            extracted_builtin_metrics, custom_metrics_dict = self._split_metrics_recursive(
                all_metrics_dict, builtin_metrics_dict
            )
            # Save the extracted built-in metrics to a new file
            with open(extracted_builtin_path, 'w', encoding='utf-8') as f:
                yaml.dump(extracted_builtin_metrics, f, allow_unicode=True, sort_keys=False, indent=2)
            self.logger.info(f"Config split: extracted built-in metrics saved to {extracted_builtin_path}")
            if custom_metrics_dict:
                with open(custom_metrics_path, 'w', encoding='utf-8') as f:
                    yaml.dump(custom_metrics_dict, f, allow_unicode=True, sort_keys=False, indent=2)
                self.logger.info(f"Config split: custom metrics saved to {custom_metrics_path}")
        except Exception as err:
            self.logger.error(f"Config split failed: {str(err)}")
            raise

    def _split_metrics_recursive(self, all_dict: Dict, builtin_dict: Dict) -> Tuple[Dict, Dict]:
        """Recursively split the metric config into built-in and custom parts."""
        extracted_builtin = {}
        custom_metrics = {}
        for key, value in all_dict.items():
            if key in builtin_dict:
                # Recurse into nested dictionaries
                if isinstance(value, dict) and isinstance(builtin_dict[key], dict):
                    sub_builtin, sub_custom = self._split_metrics_recursive(value, builtin_dict[key])
                    if sub_builtin:
                        extracted_builtin[key] = sub_builtin
                    if sub_custom:
                        custom_metrics[key] = sub_custom
                else:
                    # Non-dict values are copied as-is
                    extracted_builtin[key] = value
            else:
                # Keys missing from the built-in config are treated as custom metrics
                custom_metrics[key] = value
        return extracted_builtin, custom_metrics

    def load_configs(self, all_config_path: Optional[Path], builtin_metrics_path: Optional[Path], custom_metrics_path: Optional[Path]) -> Dict[str, Any]:
        """Load and merge the configurations."""
        # Return the cached result if this combination has already been loaded
        cache_key = f"{all_config_path}_{builtin_metrics_path}_{custom_metrics_path}"
        if cache_key in self._config_cache:
            self.logger.info("Using cached config data")
            return self._config_cache[cache_key]
        # Automatically split the full config
        extracted_builtin_path = None
        if all_config_path and all_config_path.exists():
            # Path of the extracted built-in metrics config file
            extracted_builtin_path = builtin_metrics_path.parent / f"{builtin_metrics_path.stem}_extracted{builtin_metrics_path.suffix}"
            self.split_configs(all_config_path, builtin_metrics_path, custom_metrics_path)
        # Prefer the extracted built-in metrics config
        if extracted_builtin_path and extracted_builtin_path.exists():
            self.base_config = self._safe_load_config(extracted_builtin_path)
        else:
            self.base_config = self._safe_load_config(builtin_metrics_path) if builtin_metrics_path else {}
        self.custom_config = self._safe_load_config(custom_metrics_path) if custom_metrics_path else {}
        if all_config_path and all_config_path.exists():
            self.merged_config = self._safe_load_config(all_config_path)
            # Cache the merged config
            self._config_cache[cache_key] = self.merged_config
            return self.merged_config
        return {}

    @lru_cache(maxsize=16)
    def _safe_load_config(self, config_path: Path) -> Dict[str, Any]:
        """Safely load a YAML config; lru_cache avoids repeated reads."""
        try:
            if not config_path or not config_path.exists():
                self.logger.warning(f"Config file not found: {config_path}")
                return {}
            with config_path.open('r', encoding='utf-8') as f:
                config_dict = yaml.safe_load(f) or {}
            self.logger.info(f"Loaded config: {config_path}")
            return config_dict
        except Exception as err:
            self.logger.error(f"Failed to load config {config_path}: {str(err)}")
            return {}

    def get_config(self) -> Dict[str, Any]:
        return self.merged_config

    def get_base_config(self) -> Dict[str, Any]:
        return self.base_config

    def get_custom_config(self) -> Dict[str, Any]:
        return self.custom_config
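
# Illustrative sketch of what ConfigManager.split_configs produces (the metric
# names below are hypothetical, not taken from the real config files):
#
#   all_metrics_config.yaml contains
#       safety:
#         collision: {...}          # also present in builtin_metrics_config.yaml
#         custom_near_miss: {...}   # not present in the built-in config
#
#   Keys that also exist in the built-in config are written to
#   builtin_metrics_config_extracted.yaml; the remaining keys
#   (safety.custom_near_miss here) are written to custom_metrics_config.yaml.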


class MetricLoader:
    """Metric loader component."""

    def __init__(self, logger: logging.Logger, config_manager: ConfigManager):
        self.logger = logger
        self.config_manager = config_manager
        self.metric_modules: Dict[str, Type] = {}
        self.custom_metric_modules: Dict[str, Any] = {}
        # File names of custom metric files that were successfully loaded
        self._loaded_files: set = set()

    def load_builtin_metrics(self) -> Dict[str, Type]:
        """Load the built-in metric modules."""
        module_mapping = {
            "safety": ("modules.metric.safety", "SafeManager"),
            "comfort": ("modules.metric.comfort", "ComfortManager"),
            "traffic": ("modules.metric.traffic", "TrafficManager"),
            "efficient": ("modules.metric.efficient", "EfficientManager"),
            "function": ("modules.metric.function", "FunctionManager"),
        }
        self.metric_modules = {
            name: self._load_module(*info)
            for name, info in module_mapping.items()
        }
        self.logger.info(f"Loaded builtin metrics: {', '.join(self.metric_modules.keys())}")
        return self.metric_modules

    @lru_cache(maxsize=32)
    def _load_module(self, module_path: str, class_name: str) -> Type:
        """Dynamically load a Python module and return the requested class."""
        try:
            module = __import__(module_path, fromlist=[class_name])
            return getattr(module, class_name)
        except (ImportError, AttributeError) as e:
            self.logger.error(f"Failed to load module: {module_path}.{class_name} - {str(e)}")
            raise

    def load_custom_metrics(self, custom_metrics_path: Optional[Path]) -> Dict[str, Any]:
        """Load the custom metric modules."""
        if not custom_metrics_path or not custom_metrics_path.is_dir():
            self.logger.info("No custom metrics path or path does not exist")
            return {}
        # Check whether there are new custom metric files
        current_files = set(f.name for f in custom_metrics_path.glob(CUSTOM_METRIC_FILE_PATTERN)
                            if f.name.startswith(CUSTOM_METRIC_PREFIX))
        # Compare by file name; custom_metric_modules is keyed by metric key, not file name
        loaded_files = self._loaded_files
        # If there are no new files and modules are already loaded, reuse the cache
        if self.custom_metric_modules and not (current_files - loaded_files):
            self.logger.info(f"No new custom metrics to load, using {len(self.custom_metric_modules)} cached modules")
            return self.custom_metric_modules
        loaded_count = 0
        for py_file in custom_metrics_path.glob(CUSTOM_METRIC_FILE_PATTERN):
            if py_file.name.startswith(CUSTOM_METRIC_PREFIX):
                if self._process_custom_metric_file(py_file):
                    loaded_count += 1
        self.logger.info(f"Loaded {loaded_count} custom metric modules")
        return self.custom_metric_modules

    def _process_custom_metric_file(self, file_path: Path) -> bool:
        """Process a single custom metric file."""
        try:
            metric_key = self._validate_metric_file(file_path)
            module_name = f"custom_metric_{file_path.stem}"
            spec = importlib.util.spec_from_file_location(module_name, file_path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            from modules.lib.metric_registry import BaseMetric
            metric_class = None
            for name, obj in inspect.getmembers(module):
                if (inspect.isclass(obj) and
                        issubclass(obj, BaseMetric) and
                        obj != BaseMetric):
                    metric_class = obj
                    break
            if metric_class:
                self.custom_metric_modules[metric_key] = {
                    'type': 'class',
                    'module': module,
                    'class': metric_class
                }
                self.logger.info(f"Loaded class-based custom metric: {metric_key}")
            elif hasattr(module, 'evaluate'):
                self.custom_metric_modules[metric_key] = {
                    'type': 'function',
                    'module': module
                }
                self.logger.info(f"Loaded function-based custom metric: {metric_key}")
            else:
                raise AttributeError(f"Missing evaluate() function or BaseMetric subclass: {file_path.name}")
            # Remember the file so the reload check above can skip it next time
            self._loaded_files.add(file_path.name)
            return True
        except ValueError as e:
            self.logger.warning(str(e))
            return False
        except Exception as e:
            self.logger.error(f"Failed to load custom metric {file_path}: {str(e)}")
            return False

    def _validate_metric_file(self, file_path: Path) -> str:
        """Validate the naming convention of a custom metric file."""
        stem = file_path.stem[len(CUSTOM_METRIC_PREFIX):]
        parts = stem.split('_')
        if len(parts) < 3:
            raise ValueError(f"Invalid custom metric filename: {file_path.name}, should be metric_<level1>_<level2>_<level3>.py")
        level1, level2, level3 = parts[:3]
        if not self._is_metric_configured(level1, level2, level3):
            raise ValueError(f"Unconfigured metric: {level1}.{level2}.{level3}")
        return f"{level1}.{level2}.{level3}"

    def _is_metric_configured(self, level1: str, level2: str, level3: str) -> bool:
        """Check whether the metric is registered in the custom config."""
        custom_config = self.config_manager.get_custom_config()
        try:
            return (level1 in custom_config and
                    isinstance(custom_config[level1], dict) and
                    level2 in custom_config[level1] and
                    isinstance(custom_config[level1][level2], dict) and
                    level3 in custom_config[level1][level2] and
                    isinstance(custom_config[level1][level2][level3], dict))
        except Exception:
            return False

    def get_builtin_metrics(self) -> Dict[str, Type]:
        return self.metric_modules

    def get_custom_metrics(self) -> Dict[str, Any]:
        return self.custom_metric_modules
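
# Illustrative sketch of a custom metric file that MetricLoader would accept
# (file name, metric levels and result shape are hypothetical; the actual
# BaseMetric interface is defined in modules.lib.metric_registry):
#
#   # custom_metrics/metric_safety_collision_ttc.py
#   from modules.lib.metric_registry import BaseMetric
#
#   class TTCMetric(BaseMetric):
#       def __init__(self, data):
#           self.data = data
#
#       def calculate(self):
#           return {"ttc": {"value": 2.5, "result": True}}
#
# A function-based metric is also accepted: a module-level evaluate(data)
# function returning a result dict. Either way, the metric must be declared
# in custom_metrics_config.yaml under safety.collision.ttc, otherwise
# _validate_metric_file rejects the file.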


class EvaluationEngine:
    """Evaluation engine component."""

    def __init__(self, logger: logging.Logger, config_manager: ConfigManager, metric_loader: MetricLoader):
        self.logger = logger
        self.config_manager = config_manager
        self.metric_loader = metric_loader

    def evaluate(self, data: Any) -> Dict[str, Any]:
        """Run the evaluation flow."""
        raw_results = self._collect_builtin_metrics(data)
        custom_results = self._collect_custom_metrics(data)
        return self._process_merged_results(raw_results, custom_results)

    def _collect_builtin_metrics(self, data: Any) -> Dict[str, Any]:
        """Collect results from the built-in metrics."""
        metric_modules = self.metric_loader.get_builtin_metrics()
        raw_results: Dict[str, Any] = {}
        # Determine which metrics are actually present in the config
        config = self.config_manager.get_config()
        available_metrics = {
            metric_name for metric_name in metric_modules.keys()
            if metric_name in config and isinstance(config[metric_name], dict)
        }
        # Only run the metrics that are configured
        filtered_modules = {
            name: module for name, module in metric_modules.items()
            if name in available_metrics
        }
        # Cap the thread pool size, but keep at least one worker
        max_workers = max(1, min(len(filtered_modules), DEFAULT_WORKERS))
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {
                executor.submit(self._run_module, module, data, module_name): module_name
                for module_name, module in filtered_modules.items()
            }
            for future in futures:
                module_name = futures[future]
                try:
                    result = future.result()
                    raw_results[module_name] = result[module_name]
                except Exception as e:
                    self.logger.error(
                        f"{module_name} evaluation failed: {str(e)}",
                        exc_info=True,
                    )
                    raw_results[module_name] = {
                        "status": "error",
                        "message": str(e),
                        "timestamp": datetime.now().isoformat(),
                    }
        return raw_results

    def _collect_custom_metrics(self, data: Any) -> Dict[str, Dict]:
        """Collect results from the custom metrics."""
        custom_metrics = self.metric_loader.get_custom_metrics()
        if not custom_metrics:
            return {}
        custom_results = {}
        # Run custom metrics in parallel using a thread pool
        max_workers = min(len(custom_metrics), DEFAULT_WORKERS)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {}
            # Submit all custom metric tasks
            for metric_key, metric_info in custom_metrics.items():
                futures[executor.submit(self._run_custom_metric, metric_key, metric_info, data)] = metric_key
            # Collect the results
            for future in futures:
                metric_key = futures[future]
                try:
                    level1, result = future.result()
                    if level1:
                        custom_results[level1] = result
                except Exception as e:
                    self.logger.error(f"Custom metric {metric_key} execution failed: {str(e)}")
        return custom_results

    def _run_custom_metric(self, metric_key: str, metric_info: Dict, data: Any) -> Tuple[str, Dict]:
        """Execute a single custom metric."""
        try:
            level1, level2, level3 = metric_key.split('.')
            if metric_info['type'] == 'class':
                metric_class = metric_info['class']
                metric_instance = metric_class(data)
                metric_result = metric_instance.calculate()
            else:
                module = metric_info['module']
                metric_result = module.evaluate(data)
            self.logger.info(f"Calculated custom metric: {level1}.{level2}.{level3}")
            return level1, metric_result
        except Exception as e:
            self.logger.error(f"Custom metric {metric_key} failed: {str(e)}")
            try:
                level1 = metric_key.split('.')[0]
                return level1, {
                    "status": "error",
                    "message": str(e),
                    "timestamp": datetime.now().isoformat(),
                }
            except Exception:
                return "", {}

    def _process_merged_results(self, raw_results: Dict, custom_results: Dict) -> Dict:
        """Process the merged evaluation results."""
        from modules.lib.score import Score
        final_results = {}
        merged_config = self.config_manager.get_config()
        for level1, level1_data in raw_results.items():
            if level1 in custom_results:
                level1_data.update(custom_results[level1])
            try:
                evaluator = Score(merged_config, level1)
                final_results.update(evaluator.evaluate(level1_data))
            except Exception as e:
                final_results[level1] = self._format_error(e)
        for level1, level1_data in custom_results.items():
            if level1 not in raw_results:
                try:
                    evaluator = Score(merged_config, level1)
                    final_results.update(evaluator.evaluate(level1_data))
                except Exception as e:
                    final_results[level1] = self._format_error(e)
        return final_results

    def _format_error(self, e: Exception) -> Dict:
        return {
            "status": "error",
            "message": str(e),
            "timestamp": datetime.now().isoformat()
        }

    def _run_module(self, module_class: Any, data: Any, module_name: str) -> Dict[str, Any]:
        """Execute a single built-in evaluation module."""
        try:
            instance = module_class(data)
            return {module_name: instance.report_statistic()}
        except Exception as e:
            self.logger.error(f"{module_name} execution error: {str(e)}", exc_info=True)
            return {module_name: {"error": str(e)}}


class LoggingManager:
    """Logging management component."""

    def __init__(self, log_path: Path):
        self.log_path = log_path
        self.logger = self._init_logger()

    def _init_logger(self) -> logging.Logger:
        """Initialize the logging system."""
        try:
            from modules.lib.log_manager import LogManager
            log_manager = LogManager(self.log_path)
            return log_manager.get_logger()
        except (ImportError, PermissionError, IOError) as e:
            logger = logging.getLogger("evaluator")
            logger.setLevel(logging.INFO)
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
            logger.addHandler(console_handler)
            logger.warning(f"Failed to init standard logger: {str(e)}, using fallback logger")
            return logger

    def get_logger(self) -> logging.Logger:
        return self.logger


class DataProcessor:
    """Data processing component."""

    def __init__(self, logger: logging.Logger, data_path: Path, config_path: Optional[Path] = None):
        self.logger = logger
        self.data_path = data_path
        self.config_path = config_path
        self.case_name = self.data_path.name
        self._processor = None

    @property
    def processor(self) -> Any:
        """Lazily create the data processor on first access."""
        if self._processor is None:
            self._processor = self._load_processor()
        return self._processor

    def _load_processor(self) -> Any:
        """Load the data processor."""
        try:
            start_time = time.perf_counter()
            from modules.lib import data_process
            processor = data_process.DataPreprocessing(self.data_path, self.config_path)
            elapsed_time = time.perf_counter() - start_time
            self.logger.info(f"Data processor loaded in {elapsed_time:.2f}s")
            return processor
        except ImportError as e:
            self.logger.error(f"Failed to load data processor: {str(e)}")
            raise RuntimeError(f"Failed to load data processor: {str(e)}") from e

    def validate(self) -> None:
        """Validate the data path."""
        if not self.data_path.exists():
            raise FileNotFoundError(f"Data path does not exist: {self.data_path}")
        if not self.data_path.is_dir():
            raise NotADirectoryError(f"Invalid data directory: {self.data_path}")


class EvaluationPipeline:
    """Evaluation pipeline controller."""

    def __init__(self, all_config_path: str, base_config_path: str, log_path: str, data_path: str, report_path: str,
                 custom_metrics_path: Optional[str] = None, custom_config_path: Optional[str] = None):
        # Path initialization
        self.all_config_path = Path(all_config_path) if all_config_path else None
        self.base_config_path = Path(base_config_path) if base_config_path else None
        self.custom_config_path = Path(custom_config_path) if custom_config_path else None
        self.data_path = Path(data_path)
        self.report_path = Path(report_path)
        self.custom_metrics_path = Path(custom_metrics_path) if custom_metrics_path else None
        # Logging
        self.logging_manager = LoggingManager(Path(log_path))
        self.logger = self.logging_manager.get_logger()
        # Configuration
        self.config_manager = ConfigManager(self.logger)
        self.config = self.config_manager.load_configs(
            self.all_config_path, self.base_config_path, self.custom_config_path
        )
        # Metric loading
        self.metric_loader = MetricLoader(self.logger, self.config_manager)
        self.metric_loader.load_builtin_metrics()
        self.metric_loader.load_custom_metrics(self.custom_metrics_path)
        # Data processing
        self.data_processor = DataProcessor(self.logger, self.data_path, self.all_config_path)
        self.evaluation_engine = EvaluationEngine(self.logger, self.config_manager, self.metric_loader)

    def execute(self) -> Dict[str, Any]:
        """Run the evaluation pipeline."""
        try:
            # Validate the data path before running
            self.data_processor.validate()
            self.logger.info(f"Start evaluation: {self.data_path.name}")
            start_time = time.perf_counter()
            # Timing for profiling
            config_start = time.perf_counter()
            results = self.evaluation_engine.evaluate(self.data_processor.processor)
            eval_time = time.perf_counter() - config_start
            # Generate the report
            report_start = time.perf_counter()
            report = self._generate_report(self.data_processor.case_name, results)
            report_time = time.perf_counter() - report_start
            # Total elapsed time
            elapsed_time = time.perf_counter() - start_time
            self.logger.info(f"Evaluation completed, time: {elapsed_time:.2f}s (evaluation: {eval_time:.2f}s, report: {report_time:.2f}s)")
            return report
        except Exception as e:
            self.logger.critical(f"Evaluation failed: {str(e)}", exc_info=True)
            return {"error": str(e), "traceback": traceback.format_exc()}

    def _add_overall_result(self, report: Dict[str, Any]) -> Dict[str, Any]:
        """Post-process the report and add an overall result field."""
        # Load the threshold parameters
        thresholds = {
            "T0": self.config['T_threshold']['T0_threshold'],
            "T1": self.config['T_threshold']['T1_threshold'],
            "T2": self.config['T_threshold']['T2_threshold']
        }
        # Initialize the counters
        counters = {'p0': 0, 'p1': 0, 'p2': 0}
        # Collect all failed categories in a single pass
        failed_categories = [
            (category, category_data.get('priority'))
            for category, category_data in report.items()
            if isinstance(category_data, dict) and category != "metadata" and not category_data.get('result', True)
        ]
        # Count failures per priority
        for _, priority in failed_categories:
            if priority == 0:
                counters['p0'] += 1
            elif priority == 1:
                counters['p1'] += 1
            elif priority == 2:
                counters['p2'] += 1
        # Threshold check: the case passes only if no priority exceeds its threshold
        overall_result = not (
            counters['p0'] > thresholds['T0'] or
            counters['p1'] > thresholds['T1'] or
            counters['p2'] > thresholds['T2']
        )
        # Build the processed report
        processed_report = report.copy()
        processed_report['overall_result'] = overall_result
        # Attach the threshold statistics
        processed_report['threshold_checks'] = {
            'T0_threshold': thresholds['T0'],
            'T1_threshold': thresholds['T1'],
            'T2_threshold': thresholds['T2'],
            'actual_counts': counters
        }
        self.logger.info(f"Added overall result: {overall_result}")
        return processed_report
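
    # _add_overall_result assumes the merged config provides a T_threshold
    # section shaped as follows (values are hypothetical):
    #
    #   T_threshold:
    #     T0_threshold: 0
    #     T1_threshold: 1
    #     T2_threshold: 2
    #
    # The case fails as soon as more than T0/T1/T2 categories of priority
    # 0/1/2 report result == False.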

    def _generate_report(self, case_name: str, results: Dict[str, Any]) -> Dict[str, Any]:
        """Generate the evaluation report."""
        from modules.lib.common import dict2json
        self.report_path.mkdir(parents=True, exist_ok=True)
        results["metadata"] = {
            "case_name": case_name,
            "timestamp": datetime.now().isoformat(),
            "version": "1.0",
        }
        # Add the overall result evaluation
        results = self._add_overall_result(results)
        report_file = self.report_path / f"{case_name}_report.json"
        dict2json(results, report_file)
        self.logger.info(f"Report generated: {report_file}")
        return results


def main():
    """Command-line entry point."""
    parser = argparse.ArgumentParser(
        description="Autonomous Driving Evaluation System V3.1",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Required parameters
    parser.add_argument(
        "--dataPath",
        type=str,
        default=r"D:\Kevin\zhaoyuan\data\V2V_CSAE53-2020_ForwardCollision_LST_01-02",
        help="Input data directory",
    )
    # Configuration parameters
    config_group = parser.add_argument_group('Configuration')
    config_group.add_argument(
        "--allConfigPath",
        type=str,
        default=r"config/all_metrics_config.yaml",
        help="Full metrics config file path (built-in + custom)",
    )
    config_group.add_argument(
        "--baseConfigPath",
        type=str,
        default=r"config/builtin_metrics_config.yaml",
        help="Built-in metrics config file path",
    )
    config_group.add_argument(
        "--customConfigPath",
        type=str,
        default=r"config/custom_metrics_config.yaml",
        help="Custom metrics config path (optional)",
    )
    # Output parameters
    output_group = parser.add_argument_group('Output')
    output_group.add_argument(
        "--logPath",
        type=str,
        default="test.log",
        help="Log file path",
    )
    output_group.add_argument(
        "--reportPath",
        type=str,
        default="reports",
        help="Output report directory",
    )
    # Extension parameters
    ext_group = parser.add_argument_group('Extensions')
    ext_group.add_argument(
        "--customMetricsPath",
        type=str,
        default="custom_metrics",
        help="Custom metrics scripts directory (optional)",
    )
    args = parser.parse_args()
    try:
        pipeline = EvaluationPipeline(
            all_config_path=args.allConfigPath,
            base_config_path=args.baseConfigPath,
            log_path=args.logPath,
            data_path=args.dataPath,
            report_path=args.reportPath,
            custom_metrics_path=args.customMetricsPath,
            custom_config_path=args.customConfigPath
        )
        start_time = time.perf_counter()
        result = pipeline.execute()
        elapsed_time = time.perf_counter() - start_time
        if "error" in result:
            print(f"Evaluation failed: {result['error']}")
            sys.exit(1)
        print(f"Evaluation completed, total time: {elapsed_time:.2f}s")
        print(f"Report path: {pipeline.report_path}")
    except KeyboardInterrupt:
        print("\nUser interrupted")
        sys.exit(130)
    except Exception as e:
        print(f"Execution error: {str(e)}")
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    main()