#!/usr/bin/env python3
# evaluator_enhanced.py
import sys
import warnings
import time
import importlib
import importlib.util
import yaml
from pathlib import Path
import argparse
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from typing import Dict, Any, List, Optional, Type, Tuple, Callable, Union
from datetime import datetime
import logging
import traceback
import json
import inspect

# Constants
DEFAULT_WORKERS = 4
CUSTOM_METRIC_PREFIX = "metric_"
CUSTOM_METRIC_FILE_PATTERN = "*.py"

# Resolve the project root path (supports PyInstaller bundles via sys._MEIPASS)
if hasattr(sys, "_MEIPASS"):
    _ROOT_PATH = Path(sys._MEIPASS)
else:
    _ROOT_PATH = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(_ROOT_PATH))


class ConfigManager:
    """Configuration management component."""

    def __init__(self, logger: logging.Logger):
        self.logger = logger
        self.base_config: Dict[str, Any] = {}
        self.custom_config: Dict[str, Any] = {}
        self.merged_config: Dict[str, Any] = {}

    def split_configs(self, all_config_path: Path, base_config_path: Path, custom_config_path: Path) -> None:
        """Split all_metrics_config.yaml into built-in and custom metric configs."""
        try:
            with open(all_config_path, 'r', encoding='utf-8') as f:
                all_metrics = yaml.safe_load(f) or {}
            with open(base_config_path, 'r', encoding='utf-8') as f:
                builtin_metrics = yaml.safe_load(f) or {}
            custom_metrics = self._find_custom_metrics(all_metrics, builtin_metrics)
            if custom_metrics:
                with open(custom_config_path, 'w', encoding='utf-8') as f:
                    yaml.dump(custom_metrics, f, allow_unicode=True, sort_keys=False, indent=2)
                self.logger.info(f"Split configs: custom metrics saved to {custom_config_path}")
        except Exception as e:
            self.logger.error(f"Failed to split configs: {str(e)}")
            raise
    def _find_custom_metrics(self, all_metrics, builtin_metrics, current_path=""):
        """Recursively diff the combined config against the built-in config to find custom metrics."""
        custom_metrics = {}
        if isinstance(all_metrics, dict) and isinstance(builtin_metrics, dict):
            for key in all_metrics:
                if key not in builtin_metrics:
                    custom_metrics[key] = all_metrics[key]
                else:
                    child_custom = self._find_custom_metrics(
                        all_metrics[key],
                        builtin_metrics[key],
                        f"{current_path}.{key}" if current_path else key
                    )
                    if child_custom:
                        custom_metrics[key] = child_custom
        elif all_metrics != builtin_metrics:
            return all_metrics
        if custom_metrics:
            return self._ensure_structure(custom_metrics, all_metrics, current_path)
        return None

    def _ensure_structure(self, metrics_dict, full_dict, path):
        """Ensure every level keeps its 'name' and 'priority' fields."""
        if not isinstance(metrics_dict, dict):
            return metrics_dict
        current = full_dict
        for key in path.split('.'):
            if key in current:
                current = current[key]
            else:
                break
        result = {}
        if isinstance(current, dict):
            if 'name' in current:
                result['name'] = current['name']
            if 'priority' in current:
                result['priority'] = current['priority']
        for key, value in metrics_dict.items():
            if key not in ['name', 'priority']:
                result[key] = self._ensure_structure(value, full_dict, f"{path}.{key}" if path else key)
        return result
    def load_configs(self, base_config_path: Optional[Path], custom_config_path: Optional[Path]) -> Dict[str, Any]:
        """Load and merge the base and custom configurations."""
        # Automatically split the combined config when it sits next to the base config
        if base_config_path and base_config_path.exists():
            all_config_path = base_config_path.parent / "all_metrics_config.yaml"
            if all_config_path.exists():
                target_custom_path = custom_config_path or (base_config_path.parent / "custom_metrics_config.yaml")
                self.split_configs(all_config_path, base_config_path, target_custom_path)
                custom_config_path = target_custom_path
        self.base_config = self._safe_load_config(base_config_path) if base_config_path else {}
        self.custom_config = self._safe_load_config(custom_config_path) if custom_config_path else {}
        self.merged_config = self._merge_configs(self.base_config, self.custom_config)
        return self.merged_config

    def _safe_load_config(self, config_path: Path) -> Dict[str, Any]:
        """Safely load a YAML config file, returning an empty dict on failure."""
        try:
            if not config_path.exists():
                self.logger.warning(f"Config file not found: {config_path}")
                return {}
            with config_path.open('r', encoding='utf-8') as f:
                config = yaml.safe_load(f) or {}
            self.logger.info(f"Loaded config: {config_path}")
            return config
        except Exception as e:
            self.logger.error(f"Failed to load config {config_path}: {str(e)}")
            return {}

    def _merge_configs(self, base_config: Dict, custom_config: Dict) -> Dict:
        """Merge the custom config into the base config, adding only keys missing from the base."""
        merged = base_config.copy()
        for level1_key, level1_value in custom_config.items():
            if not isinstance(level1_value, dict) or 'name' not in level1_value:
                if level1_key not in merged:
                    merged[level1_key] = level1_value
                continue
            if level1_key not in merged:
                merged[level1_key] = level1_value
            else:
                for level2_key, level2_value in level1_value.items():
                    if level2_key in ['name', 'priority']:
                        continue
                    if isinstance(level2_value, dict):
                        if level2_key not in merged[level1_key]:
                            merged[level1_key][level2_key] = level2_value
                        else:
                            for level3_key, level3_value in level2_value.items():
                                if level3_key in ['name', 'priority']:
                                    continue
                                if isinstance(level3_value, dict):
                                    if level3_key not in merged[level1_key][level2_key]:
                                        merged[level1_key][level2_key][level3_key] = level3_value
        return merged

    def get_config(self) -> Dict[str, Any]:
        return self.merged_config

    def get_base_config(self) -> Dict[str, Any]:
        return self.base_config

    def get_custom_config(self) -> Dict[str, Any]:
        return self.custom_config
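
# The merge and validation logic above assumes a three-level metrics hierarchy in
# which every level carries 'name' and 'priority' keys. A minimal illustrative
# layout is sketched below; the category and metric names other than
# 'name'/'priority' are hypothetical, not taken from the real
# all_metrics_config.yaml:
#
#   safety:                      # level 1
#     name: "Safety"
#     priority: 0
#     collision:                 # level 2
#       name: "Collision"
#       priority: 0
#       ttc:                     # level 3
#         name: "TTC"
#         priority: 0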


class MetricLoader:
    """Metric loader component."""

    def __init__(self, logger: logging.Logger, config_manager: ConfigManager):
        self.logger = logger
        self.config_manager = config_manager
        self.metric_modules: Dict[str, Type] = {}
        self.custom_metric_modules: Dict[str, Any] = {}

    def load_builtin_metrics(self) -> Dict[str, Type]:
        """Load the built-in metric manager classes."""
        module_mapping = {
            "safety": ("modules.metric.safety", "SafeManager"),
            "comfort": ("modules.metric.comfort", "ComfortManager"),
            "traffic": ("modules.metric.traffic", "TrafficManager"),
            "efficient": ("modules.metric.efficient", "EfficientManager"),
            "function": ("modules.metric.function", "FunctionManager"),
        }
        self.metric_modules = {
            name: self._load_module(*info)
            for name, info in module_mapping.items()
        }
        self.logger.info(f"Loaded builtin metrics: {', '.join(self.metric_modules.keys())}")
        return self.metric_modules

    @lru_cache(maxsize=32)
    def _load_module(self, module_path: str, class_name: str) -> Type:
        """Dynamically import a module and return the named class."""
        try:
            module = __import__(module_path, fromlist=[class_name])
            return getattr(module, class_name)
        except (ImportError, AttributeError) as e:
            self.logger.error(f"Failed to load module: {module_path}.{class_name} - {str(e)}")
            raise
    def load_custom_metrics(self, custom_metrics_path: Optional[Path]) -> Dict[str, Any]:
        """Load custom metric modules from a directory of metric_*.py files."""
        if not custom_metrics_path or not custom_metrics_path.is_dir():
            self.logger.info("No custom metrics path given, or the path does not exist")
            return {}
        loaded_count = 0
        for py_file in custom_metrics_path.glob(CUSTOM_METRIC_FILE_PATTERN):
            if py_file.name.startswith(CUSTOM_METRIC_PREFIX):
                if self._process_custom_metric_file(py_file):
                    loaded_count += 1
        self.logger.info(f"Loaded {loaded_count} custom metric modules")
        return self.custom_metric_modules

    def _process_custom_metric_file(self, file_path: Path) -> bool:
        """Load a single custom metric file and register it."""
        try:
            metric_key = self._validate_metric_file(file_path)
            module_name = f"custom_metric_{file_path.stem}"
            spec = importlib.util.spec_from_file_location(module_name, file_path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            from modules.lib.metric_registry import BaseMetric
            # Prefer a class-based metric (a BaseMetric subclass) over a module-level evaluate()
            metric_class = None
            for name, obj in inspect.getmembers(module):
                if (inspect.isclass(obj) and
                        issubclass(obj, BaseMetric) and
                        obj != BaseMetric):
                    metric_class = obj
                    break
            if metric_class:
                self.custom_metric_modules[metric_key] = {
                    'type': 'class',
                    'module': module,
                    'class': metric_class
                }
                self.logger.info(f"Loaded class-based custom metric: {metric_key}")
            elif hasattr(module, 'evaluate'):
                self.custom_metric_modules[metric_key] = {
                    'type': 'function',
                    'module': module
                }
                self.logger.info(f"Loaded function-based custom metric: {metric_key}")
            else:
                raise AttributeError(f"Missing evaluate() function or BaseMetric subclass: {file_path.name}")
            return True
        except ValueError as e:
            self.logger.warning(str(e))
            return False
        except Exception as e:
            self.logger.error(f"Failed to load custom metric {file_path}: {str(e)}")
            return False

    def _validate_metric_file(self, file_path: Path) -> str:
        """Validate the custom metric filename and return its dotted metric key."""
        stem = file_path.stem[len(CUSTOM_METRIC_PREFIX):]
        parts = stem.split('_')
        if len(parts) < 3:
            raise ValueError(f"Invalid custom metric filename: {file_path.name}, expected metric_<level1>_<level2>_<level3>.py")
        level1, level2, level3 = parts[:3]
        if not self._is_metric_configured(level1, level2, level3):
            raise ValueError(f"Unconfigured metric: {level1}.{level2}.{level3}")
        return f"{level1}.{level2}.{level3}"

    def _is_metric_configured(self, level1: str, level2: str, level3: str) -> bool:
        """Check whether the metric is registered in the custom config."""
        custom_config = self.config_manager.get_custom_config()
        try:
            return (level1 in custom_config and
                    isinstance(custom_config[level1], dict) and
                    level2 in custom_config[level1] and
                    isinstance(custom_config[level1][level2], dict) and
                    level3 in custom_config[level1][level2] and
                    isinstance(custom_config[level1][level2][level3], dict))
        except Exception:
            return False

    def get_builtin_metrics(self) -> Dict[str, Type]:
        return self.metric_modules

    def get_custom_metrics(self) -> Dict[str, Any]:
        return self.custom_metric_modules
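
# For reference, a minimal function-based custom metric file that this loader would
# accept might look like the sketch below. The file name, metric levels, and the shape
# of the returned dict are illustrative only (the exact result shape depends on what
# modules.lib.score.Score expects), and the metric must also be declared in
# custom_metrics_config.yaml. A class-based variant would instead subclass
# modules.lib.metric_registry.BaseMetric and implement calculate().
#
#   # custom_metrics/metric_safety_collision_customTTC.py
#   def evaluate(data):
#       """Compute this metric from the preprocessed case data."""
#       return {"customTTC": {"value": 4.2}}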


class EvaluationEngine:
    """Evaluation engine component."""

    def __init__(self, logger: logging.Logger, config_manager: ConfigManager, metric_loader: MetricLoader):
        self.logger = logger
        self.config_manager = config_manager
        self.metric_loader = metric_loader

    def evaluate(self, data: Any) -> Dict[str, Any]:
        """Run the full evaluation flow: built-in metrics, custom metrics, then scoring."""
        raw_results = self._collect_builtin_metrics(data)
        custom_results = self._collect_custom_metrics(data)
        return self._process_merged_results(raw_results, custom_results)

    def _collect_builtin_metrics(self, data: Any) -> Dict[str, Any]:
        """Collect built-in metric results, one worker thread per metric module."""
        metric_modules = self.metric_loader.get_builtin_metrics()
        raw_results: Dict[str, Any] = {}
        with ThreadPoolExecutor(max_workers=len(metric_modules)) as executor:
            futures = {
                executor.submit(self._run_module, module, data, module_name): module_name
                for module_name, module in metric_modules.items()
            }
            for future in futures:
                module_name = futures[future]
                try:
                    result = future.result()
                    raw_results[module_name] = result[module_name]
                except Exception as e:
                    self.logger.error(
                        f"{module_name} evaluation failed: {str(e)}",
                        exc_info=True,
                    )
                    raw_results[module_name] = {
                        "status": "error",
                        "message": str(e),
                        "timestamp": datetime.now().isoformat(),
                    }
        return raw_results
    def _collect_custom_metrics(self, data: Any) -> Dict[str, Dict]:
        """Collect custom metric results, grouped by their level-1 category."""
        custom_metrics = self.metric_loader.get_custom_metrics()
        if not custom_metrics:
            return {}
        custom_results = {}
        for metric_key, metric_info in custom_metrics.items():
            try:
                level1, level2, level3 = metric_key.split('.')
                if metric_info['type'] == 'class':
                    metric_class = metric_info['class']
                    metric_instance = metric_class(data)
                    metric_result = metric_instance.calculate()
                else:
                    module = metric_info['module']
                    metric_result = module.evaluate(data)
                if level1 not in custom_results:
                    custom_results[level1] = {}
                # Merge this metric's result into its level-1 category
                # (assumes the metric returns a dict)
                custom_results[level1].update(metric_result)
                self.logger.info(f"Calculated custom metric: {level1}.{level2}.{level3}")
            except Exception as e:
                self.logger.error(f"Custom metric {metric_key} failed: {str(e)}")
                try:
                    level1, level2, level3 = metric_key.split('.')
                    if level1 not in custom_results:
                        custom_results[level1] = {}
                    custom_results[level1] = {
                        "status": "error",
                        "message": str(e),
                        "timestamp": datetime.now().isoformat(),
                    }
                except Exception:
                    pass
        return custom_results
    def _process_merged_results(self, raw_results: Dict, custom_results: Dict) -> Dict:
        """Merge built-in and custom results and score each level-1 category."""
        from modules.lib.score import Score
        final_results = {}
        merged_config = self.config_manager.get_config()
        for level1, level1_data in raw_results.items():
            if level1 in custom_results:
                level1_data.update(custom_results[level1])
            try:
                evaluator = Score(merged_config, level1)
                final_results.update(evaluator.evaluate(level1_data))
            except Exception as e:
                final_results[level1] = self._format_error(e)
        # Score categories that only have custom metric results
        for level1, level1_data in custom_results.items():
            if level1 not in raw_results:
                try:
                    evaluator = Score(merged_config, level1)
                    final_results.update(evaluator.evaluate(level1_data))
                except Exception as e:
                    final_results[level1] = self._format_error(e)
        return final_results

    def _format_error(self, e: Exception) -> Dict:
        return {
            "status": "error",
            "message": str(e),
            "timestamp": datetime.now().isoformat()
        }

    def _run_module(self, module_class: Any, data: Any, module_name: str) -> Dict[str, Any]:
        """Run a single built-in evaluation module."""
        try:
            instance = module_class(data)
            return {module_name: instance.report_statistic()}
        except Exception as e:
            self.logger.error(f"{module_name} execution error: {str(e)}", exc_info=True)
            return {module_name: {"error": str(e)}}
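
# Each built-in manager class loaded by MetricLoader (e.g. SafeManager) is expected to
# follow the interface implied by _run_module: it is constructed with the preprocessed
# case data and exposes report_statistic() returning that category's raw results.
# A hypothetical skeleton, for illustration only:
#
#   class SafeManager:
#       def __init__(self, data):
#           self.data = data
#
#       def report_statistic(self) -> dict:
#           return {}  # raw results for the "safety" category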


class LoggingManager:
    """Logging management component."""

    def __init__(self, log_path: Path):
        self.log_path = log_path
        self.logger = self._init_logger()

    def _init_logger(self) -> logging.Logger:
        """Initialize the project logger, falling back to a console logger on failure."""
        try:
            from modules.lib.log_manager import LogManager
            log_manager = LogManager(self.log_path)
            return log_manager.get_logger()
        except (ImportError, PermissionError, IOError) as e:
            logger = logging.getLogger("evaluator")
            logger.setLevel(logging.INFO)
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
            logger.addHandler(console_handler)
            logger.warning(f"Failed to init standard logger: {str(e)}, using fallback logger")
            return logger

    def get_logger(self) -> logging.Logger:
        return self.logger


class DataProcessor:
    """Data processing component."""

    def __init__(self, logger: logging.Logger, data_path: Path, config_path: Optional[Path] = None):
        self.logger = logger
        self.data_path = data_path
        self.config_path = config_path
        self.processor = self._load_processor()
        self.case_name = self.data_path.name

    def _load_processor(self) -> Any:
        """Load the data preprocessor."""
        try:
            from modules.lib import data_process
            return data_process.DataPreprocessing(self.data_path, self.config_path)
        except ImportError as e:
            self.logger.error(f"Failed to load data processor: {str(e)}")
            raise RuntimeError(f"Failed to load data processor: {str(e)}") from e

    def validate(self) -> None:
        """Validate the data path."""
        if not self.data_path.exists():
            raise FileNotFoundError(f"Data path does not exist: {self.data_path}")
        if not self.data_path.is_dir():
            raise NotADirectoryError(f"Invalid data directory: {self.data_path}")


class EvaluationPipeline:
    """Evaluation pipeline controller."""

    def __init__(self, config_path: str, log_path: str, data_path: str, report_path: str,
                 custom_metrics_path: Optional[str] = None, custom_config_path: Optional[str] = None):
        # Path initialization
        self.config_path = Path(config_path) if config_path else None
        self.custom_config_path = Path(custom_config_path) if custom_config_path else None
        self.data_path = Path(data_path)
        self.report_path = Path(report_path)
        self.custom_metrics_path = Path(custom_metrics_path) if custom_metrics_path else None
        # Component initialization
        self.logging_manager = LoggingManager(Path(log_path))
        self.logger = self.logging_manager.get_logger()
        self.config_manager = ConfigManager(self.logger)
        self.config_manager.load_configs(self.config_path, self.custom_config_path)
        self.metric_loader = MetricLoader(self.logger, self.config_manager)
        self.metric_loader.load_builtin_metrics()
        self.metric_loader.load_custom_metrics(self.custom_metrics_path)
        self.evaluation_engine = EvaluationEngine(self.logger, self.config_manager, self.metric_loader)
        self.data_processor = DataProcessor(self.logger, self.data_path, self.config_path)

    def execute(self) -> Dict[str, Any]:
        """Run the evaluation pipeline end to end."""
        try:
            self.data_processor.validate()
            self.logger.info(f"Start evaluation: {self.data_path.name}")
            start_time = time.perf_counter()
            results = self.evaluation_engine.evaluate(self.data_processor.processor)
            elapsed_time = time.perf_counter() - start_time
            self.logger.info(f"Evaluation completed, time: {elapsed_time:.2f}s")
            report = self._generate_report(self.data_processor.case_name, results)
            return report
        except Exception as e:
            self.logger.critical(f"Evaluation failed: {str(e)}", exc_info=True)
            return {"error": str(e), "traceback": traceback.format_exc()}

    def _generate_report(self, case_name: str, results: Dict[str, Any]) -> Dict[str, Any]:
        """Generate the JSON evaluation report."""
        from modules.lib.common import dict2json
        self.report_path.mkdir(parents=True, exist_ok=True)
        results["metadata"] = {
            "case_name": case_name,
            "timestamp": datetime.now().isoformat(),
            "version": "3.1.0",
        }
        report_file = self.report_path / f"{case_name}_report.json"
        dict2json(results, report_file)
        self.logger.info(f"Report generated: {report_file}")
        return results


def main():
    """Command-line entry point."""
    parser = argparse.ArgumentParser(
        description="Autonomous Driving Evaluation System V3.1",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--logPath",
        type=str,
        default="logs/test.log",
        help="Log file path",
    )
    parser.add_argument(
        "--dataPath",
        type=str,
        default=r"D:\Kevin\zhaoyuan\data\V2V_CSAE53-2020_ForwardCollisionW_LST_01-01",
        help="Input data directory",
    )
    parser.add_argument(
        "--configPath",
        type=str,
        default="config/all_metrics_config.yaml",
        help="Metrics config file path",
    )
    parser.add_argument(
        "--reportPath",
        type=str,
        default="reports",
        help="Output report directory",
    )
    parser.add_argument(
        "--customMetricsPath",
        type=str,
        default="custom_metrics",
        help="Custom metrics scripts directory (optional)",
    )
    parser.add_argument(
        "--customConfigPath",
        type=str,
        default="config/custom_metrics_config.yaml",
        help="Custom metrics config path (optional)",
    )
    args = parser.parse_args()
    try:
        pipeline = EvaluationPipeline(
            args.configPath,
            args.logPath,
            args.dataPath,
            args.reportPath,
            args.customMetricsPath,
            args.customConfigPath
        )
        start_time = time.perf_counter()
        result = pipeline.execute()
        elapsed_time = time.perf_counter() - start_time
        if "error" in result:
            print(f"Evaluation failed: {result['error']}")
            sys.exit(1)
        print(f"Evaluation completed, total time: {elapsed_time:.2f}s")
        print(f"Report path: {pipeline.report_path}")
    except KeyboardInterrupt:
        print("\nUser interrupted")
        sys.exit(130)
    except Exception as e:
        print(f"Execution error: {str(e)}")
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    main()
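
# Example invocation (paths shown are the argparse defaults above, except the case
# directory placeholder; adjust to your own layout):
#   python evaluator_enhanced.py --dataPath <case_directory> \
#       --configPath config/all_metrics_config.yaml \
#       --customMetricsPath custom_metrics --reportPath reports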