evaluator_enhanced.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759
  1. #!/usr/bin/env python3
  2. # evaluator_enhanced.py
  3. import sys
  4. import warnings
  5. import time
  6. import importlib
  7. import importlib.util
  8. import yaml
  9. from pathlib import Path
  10. import argparse
  11. from concurrent.futures import ThreadPoolExecutor
  12. from functools import lru_cache
  13. from typing import Dict, Any, List, Optional, Type, Tuple, Callable, Union
  14. from datetime import datetime
  15. import logging
  16. import traceback
  17. import json
  18. import inspect
  19. # 常量定义
  20. DEFAULT_WORKERS = 4
  21. CUSTOM_METRIC_PREFIX = "metric_"
  22. CUSTOM_METRIC_FILE_PATTERN = "*.py"
  23. # 安全设置根目录路径
  24. if hasattr(sys, "_MEIPASS"):
  25. _ROOT_PATH = Path(sys._MEIPASS)
  26. else:
  27. _ROOT_PATH = Path(__file__).resolve().parent.parent
  28. sys.path.insert(0, str(_ROOT_PATH))
  29. class ConfigManager:
  30. """配置管理组件"""
  31. def __init__(self, logger: logging.Logger):
  32. self.logger = logger
  33. self.base_config: Dict[str, Any] = {}
  34. self.custom_config: Dict[str, Any] = {}
  35. self.merged_config: Dict[str, Any] = {}
  36. self._config_cache = {}
  37. def split_configs(self, all_metrics_path: Path, builtin_metrics_path: Path, custom_metrics_path: Path) -> None:
  38. """从all_metrics_config.yaml拆分成内置和自定义配置"""
  39. # 检查是否已经存在提取的配置文件,如果存在则跳过拆分过程
  40. extracted_builtin_path = builtin_metrics_path.parent / f"{builtin_metrics_path.stem}_extracted{builtin_metrics_path.suffix}"
  41. if extracted_builtin_path.exists() and custom_metrics_path.exists():
  42. self.logger.info(f"使用已存在的拆分配置文件: {extracted_builtin_path}")
  43. return
  44. try:
  45. # 使用缓存加载配置文件,避免重复读取
  46. all_metrics_dict = self._safe_load_config(all_metrics_path)
  47. builtin_metrics_dict = self._safe_load_config(builtin_metrics_path)
  48. # 递归提取内置和自定义指标
  49. extracted_builtin_metrics, custom_metrics_dict = self._split_metrics_recursive(
  50. all_metrics_dict, builtin_metrics_dict
  51. )
  52. # 保存提取的内置指标到新文件
  53. with open(extracted_builtin_path, 'w', encoding='utf-8') as f:
  54. yaml.dump(extracted_builtin_metrics, f, allow_unicode=True, sort_keys=False, indent=2)
  55. self.logger.info(f"拆分配置: 提取的内置指标已保存到 {extracted_builtin_path}")
  56. if custom_metrics_dict:
  57. with open(custom_metrics_path, 'w', encoding='utf-8') as f:
  58. yaml.dump(custom_metrics_dict, f, allow_unicode=True, sort_keys=False, indent=2)
  59. self.logger.info(f"拆分配置: 自定义指标已保存到 {custom_metrics_path}")
  60. except Exception as err:
  61. self.logger.error(f"拆分配置失败: {str(err)}")
  62. raise
  63. def _split_metrics_recursive(self, all_dict: Dict, builtin_dict: Dict) -> Tuple[Dict, Dict]:
  64. """递归拆分内置和自定义指标配置"""
  65. extracted_builtin = {}
  66. custom_metrics = {}
  67. for key, value in all_dict.items():
  68. if key in builtin_dict:
  69. # 如果是字典类型,继续递归
  70. if isinstance(value, dict) and isinstance(builtin_dict[key], dict):
  71. sub_builtin, sub_custom = self._split_metrics_recursive(value, builtin_dict[key])
  72. if sub_builtin:
  73. extracted_builtin[key] = sub_builtin
  74. if sub_custom:
  75. custom_metrics[key] = sub_custom
  76. else:
  77. # 如果不是字典类型,直接复制
  78. extracted_builtin[key] = value
  79. else:
  80. # 如果键不在内置配置中,归类为自定义指标
  81. custom_metrics[key] = value
  82. return extracted_builtin, custom_metrics
  83. def load_configs(self, all_config_path: Optional[Path], builtin_metrics_path: Optional[Path],
  84. custom_metrics_path: Optional[Path]) -> Dict[str, Any]:
  85. """加载并合并配置"""
  86. # 如果已经加载过配置,直接返回缓存的结果
  87. cache_key = f"{all_config_path}_{builtin_metrics_path}_{custom_metrics_path}"
  88. if cache_key in self._config_cache:
  89. self.logger.info("使用缓存的配置数据")
  90. return self._config_cache[cache_key]
  91. # 自动拆分配置
  92. extracted_builtin_path = None
  93. if all_config_path and all_config_path.exists():
  94. # 生成提取的内置指标配置文件路径
  95. extracted_builtin_path = builtin_metrics_path.parent / f"{builtin_metrics_path.stem}_extracted{builtin_metrics_path.suffix}"
  96. self.split_configs(all_config_path, builtin_metrics_path, custom_metrics_path)
  97. # 优先使用提取的内置指标配置
  98. if extracted_builtin_path and extracted_builtin_path.exists():
  99. self.base_config = self._safe_load_config(extracted_builtin_path)
  100. else:
  101. self.base_config = self._safe_load_config(builtin_metrics_path) if builtin_metrics_path else {}
  102. self.custom_config = self._safe_load_config(custom_metrics_path) if custom_metrics_path else {}
  103. if all_config_path and all_config_path.exists():
  104. self.merged_config = self._safe_load_config(all_config_path)
  105. # 缓存配置结果
  106. self._config_cache[cache_key] = self.merged_config
  107. return self.merged_config
  108. return {}
  109. @lru_cache(maxsize=16)
  110. def _safe_load_config(self, config_path: Path) -> Dict[str, Any]:
  111. """安全加载YAML配置,使用lru_cache减少重复读取"""
  112. try:
  113. if not config_path or not config_path.exists():
  114. self.logger.warning(f"Config file not found: {config_path}")
  115. return {}
  116. with config_path.open('r', encoding='utf-8') as f:
  117. config_dict = yaml.safe_load(f) or {}
  118. self.logger.info(f"Loaded config: {config_path}")
  119. return config_dict
  120. except Exception as err:
  121. self.logger.error(f"Failed to load config {config_path}: {str(err)}")
  122. return {}
  123. def get_config(self) -> Dict[str, Any]:
  124. return self.merged_config
  125. def get_base_config(self) -> Dict[str, Any]:
  126. return self.base_config
  127. def get_custom_config(self) -> Dict[str, Any]:
  128. return self.custom_config
  129. class MetricLoader:
  130. """指标加载器组件"""
  131. def __init__(self, logger: logging.Logger, config_manager: ConfigManager):
  132. self.logger = logger
  133. self.config_manager = config_manager
  134. self.metric_modules: Dict[str, Type] = {}
  135. self.custom_metric_modules: Dict[str, Any] = {}
  136. def load_builtin_metrics(self) -> Dict[str, Type]:
  137. """加载内置指标模块"""
  138. module_mapping = {
  139. "safety": ("modules.metric.safety", "SafeManager"),
  140. "comfort": ("modules.metric.comfort", "ComfortManager"),
  141. "traffic": ("modules.metric.traffic", "TrafficManager"),
  142. "efficient": ("modules.metric.efficient", "EfficientManager"),
  143. "function": ("modules.metric.function", "FunctionManager"),
  144. }
  145. self.metric_modules = {
  146. name: self._load_module(*info)
  147. for name, info in module_mapping.items()
  148. }
  149. self.logger.info(f"Loaded builtin metrics: {', '.join(self.metric_modules.keys())}")
  150. return self.metric_modules
  151. @lru_cache(maxsize=32)
  152. def _load_module(self, module_path: str, class_name: str) -> Type:
  153. """动态加载Python模块"""
  154. try:
  155. module = __import__(module_path, fromlist=[class_name])
  156. return getattr(module, class_name)
  157. except (ImportError, AttributeError) as e:
  158. self.logger.error(f"Failed to load module: {module_path}.{class_name} - {str(e)}")
  159. raise
  160. def load_custom_metrics(self, custom_metrics_path: Optional[Path]) -> Dict[str, Any]:
  161. """加载自定义指标模块"""
  162. if not custom_metrics_path or not custom_metrics_path.is_dir():
  163. self.logger.info("No custom metrics path or path not exists")
  164. return {}
  165. # 检查是否有新的自定义指标文件
  166. current_files = set(f.name for f in custom_metrics_path.glob(CUSTOM_METRIC_FILE_PATTERN)
  167. if f.name.startswith(CUSTOM_METRIC_PREFIX))
  168. loaded_files = set(self.custom_metric_modules.keys())
  169. # 如果没有新文件且已有加载的模块,直接返回
  170. if self.custom_metric_modules and not (current_files - loaded_files):
  171. self.logger.info(f"No new custom metrics to load, using {len(self.custom_metric_modules)} cached modules")
  172. return self.custom_metric_modules
  173. loaded_count = 0
  174. for py_file in custom_metrics_path.glob(CUSTOM_METRIC_FILE_PATTERN):
  175. if py_file.name.startswith(CUSTOM_METRIC_PREFIX):
  176. if self._process_custom_metric_file(py_file):
  177. loaded_count += 1
  178. self.logger.info(f"Loaded {loaded_count} custom metric modules")
  179. return self.custom_metric_modules
  180. def _process_custom_metric_file(self, file_path: Path) -> bool:
  181. """处理单个自定义指标文件"""
  182. try:
  183. metric_key = self._validate_metric_file(file_path)
  184. module_name = f"custom_metric_{file_path.stem}"
  185. spec = importlib.util.spec_from_file_location(module_name, file_path)
  186. module = importlib.util.module_from_spec(spec)
  187. spec.loader.exec_module(module)
  188. from modules.lib.metric_registry import BaseMetric
  189. metric_class = None
  190. for name, obj in inspect.getmembers(module):
  191. if (inspect.isclass(obj) and
  192. issubclass(obj, BaseMetric) and
  193. obj != BaseMetric):
  194. metric_class = obj
  195. break
  196. if metric_class:
  197. self.custom_metric_modules[metric_key] = {
  198. 'type': 'class',
  199. 'module': module,
  200. 'class': metric_class
  201. }
  202. self.logger.info(f"Loaded class-based custom metric: {metric_key}")
  203. elif hasattr(module, 'evaluate'):
  204. self.custom_metric_modules[metric_key] = {
  205. 'type': 'function',
  206. 'module': module
  207. }
  208. self.logger.info(f"Loaded function-based custom metric: {metric_key}")
  209. else:
  210. raise AttributeError(f"Missing evaluate() function or BaseMetric subclass: {file_path.name}")
  211. return True
  212. except ValueError as e:
  213. self.logger.warning(str(e))
  214. return False
  215. except Exception as e:
  216. self.logger.error(f"Failed to load custom metric {file_path}: {str(e)}")
  217. return False
  218. def _validate_metric_file(self, file_path: Path) -> str:
  219. """验证自定义指标文件命名规范"""
  220. stem = file_path.stem[len(CUSTOM_METRIC_PREFIX):]
  221. parts = stem.split('_')
  222. if len(parts) < 3:
  223. raise ValueError(
  224. f"Invalid custom metric filename: {file_path.name}, should be metric_<level1>_<level2>_<level3>.py")
  225. level1, level2, level3 = parts[:3]
  226. if not self._is_metric_configured(level1, level2, level3):
  227. raise ValueError(f"Unconfigured metric: {level1}.{level2}.{level3}")
  228. return f"{level1}.{level2}.{level3}"
  229. def _is_metric_configured(self, level1: str, level2: str, level3: str) -> bool:
  230. """检查指标是否在配置中注册"""
  231. custom_config = self.config_manager.get_custom_config()
  232. try:
  233. return (level1 in custom_config and
  234. isinstance(custom_config[level1], dict) and
  235. level2 in custom_config[level1] and
  236. isinstance(custom_config[level1][level2], dict) and
  237. level3 in custom_config[level1][level2] and
  238. isinstance(custom_config[level1][level2][level3], dict))
  239. except Exception:
  240. return False
  241. def get_builtin_metrics(self) -> Dict[str, Type]:
  242. return self.metric_modules
  243. def get_custom_metrics(self) -> Dict[str, Any]:
  244. return self.custom_metric_modules
  245. class EvaluationEngine:
  246. """评估引擎组件"""
  247. def __init__(self, logger: logging.Logger, config_manager: ConfigManager, metric_loader: MetricLoader,
  248. plot_path: str):
  249. self.logger = logger
  250. self.config_manager = config_manager
  251. self.metric_loader = metric_loader
  252. self.plot_path = plot_path
  253. def evaluate(self, data: Any) -> Dict[str, Any]:
  254. """执行评估流程"""
  255. raw_results = self._collect_builtin_metrics(data)
  256. custom_results = self._collect_custom_metrics(data)
  257. return self._process_merged_results(raw_results, custom_results)
  258. def _collect_builtin_metrics(self, data: Any) -> Dict[str, Any]:
  259. """收集内置指标结果"""
  260. metric_modules = self.metric_loader.get_builtin_metrics()
  261. x = metric_modules.items()
  262. raw_results: Dict[str, Any] = {}
  263. # 获取配置中实际存在的指标
  264. config = self.config_manager.get_config()
  265. available_metrics = {
  266. metric_name for metric_name in metric_modules.keys()
  267. if metric_name in config and isinstance(config[metric_name], dict)
  268. }
  269. # 只处理配置中存在的指标
  270. filtered_modules = {
  271. name: module for name, module in metric_modules.items()
  272. if name in available_metrics
  273. }
  274. # 优化线程池大小,避免创建过多线程
  275. max_workers = min(len(filtered_modules), DEFAULT_WORKERS)
  276. with ThreadPoolExecutor(max_workers=max_workers) as executor:
  277. futures = {
  278. executor.submit(self._run_module, module, data, module_name, self.plot_path): module_name for
  279. module_name, module in filtered_modules.items()
  280. }
  281. for future in futures:
  282. module_name = futures[future]
  283. try:
  284. result = future.result()
  285. raw_results[module_name] = result[module_name]
  286. except Exception as e:
  287. self.logger.error(
  288. f"{module_name} evaluation failed: {str(e)}",
  289. exc_info=True,
  290. )
  291. raw_results[module_name] = {
  292. "status": "error",
  293. "message": str(e),
  294. "timestamp": datetime.now().isoformat(),
  295. }
  296. return raw_results
  297. def _collect_custom_metrics(self, data: Any) -> Dict[str, Dict]:
  298. """收集自定义指标结果"""
  299. custom_metrics = self.metric_loader.get_custom_metrics()
  300. if not custom_metrics:
  301. return {}
  302. custom_results = {}
  303. # 使用线程池并行处理自定义指标
  304. max_workers = min(len(custom_metrics), DEFAULT_WORKERS)
  305. with ThreadPoolExecutor(max_workers=max_workers) as executor:
  306. futures = {}
  307. # 提交所有自定义指标任务
  308. for metric_key, metric_info in custom_metrics.items():
  309. futures[executor.submit(self._run_custom_metric, metric_key, metric_info, data)] = metric_key
  310. # 收集结果
  311. for future in futures:
  312. metric_key = futures[future]
  313. try:
  314. level1, result = future.result()
  315. if level1:
  316. custom_results[level1] = result
  317. except Exception as e:
  318. self.logger.error(f"Custom metric {metric_key} execution failed: {str(e)}")
  319. return custom_results
  320. def _run_custom_metric(self, metric_key: str, metric_info: Dict, data: Any) -> Tuple[str, Dict]:
  321. """执行单个自定义指标"""
  322. try:
  323. level1, level2, level3 = metric_key.split('.')
  324. if metric_info['type'] == 'class':
  325. metric_class = metric_info['class']
  326. metric_instance = metric_class(data)
  327. metric_result = metric_instance.calculate()
  328. else:
  329. module = metric_info['module']
  330. metric_result = module.evaluate(data)
  331. self.logger.info(f"Calculated custom metric: {level1}.{level2}.{level3}")
  332. return level1, metric_result
  333. except Exception as e:
  334. self.logger.error(f"Custom metric {metric_key} failed: {str(e)}")
  335. try:
  336. level1 = metric_key.split('.')[0]
  337. return level1, {
  338. "status": "error",
  339. "message": str(e),
  340. "timestamp": datetime.now().isoformat(),
  341. }
  342. except Exception:
  343. return "", {}
  344. def _process_merged_results(self, raw_results: Dict, custom_results: Dict) -> Dict:
  345. """处理合并后的评估结果"""
  346. from modules.lib.score import Score
  347. final_results = {}
  348. merged_config = self.config_manager.get_config()
  349. for level1, level1_data in raw_results.items():
  350. if level1 in custom_results:
  351. level1_data.update(custom_results[level1])
  352. try:
  353. evaluator = Score(merged_config, level1)
  354. final_results.update(evaluator.evaluate(level1_data))
  355. except Exception as e:
  356. final_results[level1] = self._format_error(e)
  357. for level1, level1_data in custom_results.items():
  358. if level1 not in raw_results:
  359. try:
  360. evaluator = Score(merged_config, level1)
  361. final_results.update(evaluator.evaluate(level1_data))
  362. except Exception as e:
  363. final_results[level1] = self._format_error(e)
  364. return final_results
  365. def _format_error(self, e: Exception) -> Dict:
  366. return {
  367. "status": "error",
  368. "message": str(e),
  369. "timestamp": datetime.now().isoformat()
  370. }
  371. def _run_module(self, module_class: Any, data: Any, module_name: str, plot_path: str) -> Dict[str, Any]:
  372. """执行单个评估模块"""
  373. try:
  374. instance = module_class(data, plot_path)
  375. return {module_name: instance.report_statistic()}
  376. except Exception as e:
  377. self.logger.error(f"{module_name} execution error: {str(e)}", exc_info=True)
  378. return {module_name: {"error": str(e)}}
  379. class LoggingManager:
  380. """日志管理组件"""
  381. def __init__(self, log_path: Path):
  382. self.log_path = log_path
  383. self.logger = self._init_logger()
  384. def _init_logger(self) -> logging.Logger:
  385. """初始化日志系统"""
  386. try:
  387. from modules.lib.log_manager import LogManager
  388. log_manager = LogManager(self.log_path)
  389. return log_manager.get_logger()
  390. except (ImportError, PermissionError, IOError) as e:
  391. logger = logging.getLogger("evaluator")
  392. logger.setLevel(logging.INFO)
  393. console_handler = logging.StreamHandler()
  394. console_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
  395. logger.addHandler(console_handler)
  396. logger.warning(f"Failed to init standard logger: {str(e)}, using fallback logger")
  397. return logger
  398. def get_logger(self) -> logging.Logger:
  399. return self.logger
  400. class DataProcessor:
  401. """数据处理组件"""
  402. def __init__(self, logger: logging.Logger, data_path: Path, config_path: Optional[Path] = None):
  403. self.logger = logger
  404. self.data_path = data_path
  405. self.config_path = config_path
  406. self.case_name = self.data_path.name
  407. self._processor = None
  408. @property
  409. def processor(self) -> Any:
  410. """懒加载数据处理器,只在首次访问时创建"""
  411. if self._processor is None:
  412. self._processor = self._load_processor()
  413. return self._processor
  414. def _load_processor(self) -> Any:
  415. """加载数据处理器"""
  416. try:
  417. start_time = time.perf_counter()
  418. from modules.lib import data_process
  419. processor = data_process.DataPreprocessing(self.data_path, self.config_path)
  420. elapsed_time = time.perf_counter() - start_time
  421. self.logger.info(f"Data processor loaded in {elapsed_time:.2f}s")
  422. return processor
  423. except ImportError as e:
  424. self.logger.error(f"Failed to load data processor: {str(e)}")
  425. raise RuntimeError(f"Failed to load data processor: {str(e)}") from e
  426. def validate(self) -> None:
  427. """验证数据路径"""
  428. if not self.data_path.exists():
  429. raise FileNotFoundError(f"Data path not exists: {self.data_path}")
  430. if not self.data_path.is_dir():
  431. raise NotADirectoryError(f"Invalid data directory: {self.data_path}")
  432. class EvaluationPipeline:
  433. """评估流水线控制器"""
  434. def __init__(self, all_config_path: str, base_config_path: str, log_path: str, data_path: str, report_path: str,
  435. plot_path: str,
  436. custom_metrics_path: Optional[str] = None, custom_config_path: Optional[str] = None):
  437. # 路径初始化
  438. self.all_config_path = Path(all_config_path) if all_config_path else None
  439. self.base_config_path = Path(base_config_path) if base_config_path else None
  440. self.custom_config_path = Path(custom_config_path) if custom_config_path else None
  441. self.data_path = Path(data_path)
  442. self.report_path = Path(report_path)
  443. self.plot_path = Path(plot_path)
  444. self.custom_metrics_path = Path(custom_metrics_path) if custom_metrics_path else None
  445. # 日志
  446. self.logging_manager = LoggingManager(Path(log_path))
  447. self.logger = self.logging_manager.get_logger()
  448. # 配置
  449. self.config_manager = ConfigManager(self.logger)
  450. self.config = self.config_manager.load_configs(
  451. self.all_config_path, self.base_config_path, self.custom_config_path
  452. )
  453. # 指标加载
  454. self.metric_loader = MetricLoader(self.logger, self.config_manager)
  455. self.metric_loader.load_builtin_metrics()
  456. self.metric_loader.load_custom_metrics(self.custom_metrics_path)
  457. # 数据处理
  458. self.data_processor = DataProcessor(self.logger, self.data_path, self.all_config_path)
  459. self.evaluation_engine = EvaluationEngine(self.logger, self.config_manager, self.metric_loader, self.plot_path)
  460. def execute(self) -> Dict[str, Any]:
  461. """执行评估流水线"""
  462. try:
  463. # 只在首次运行时验证数据路径
  464. self.data_processor.validate()
  465. self.logger.info(f"Start evaluation: {self.data_path.name}")
  466. start_time = time.perf_counter()
  467. # 性能分析日志
  468. config_start = time.perf_counter()
  469. results = self.evaluation_engine.evaluate(self.data_processor.processor)
  470. eval_time = time.perf_counter() - config_start
  471. # 生成报告
  472. report_start = time.perf_counter()
  473. report = self._generate_report(self.data_processor.case_name, results)
  474. report_time = time.perf_counter() - report_start
  475. # 总耗时
  476. elapsed_time = time.perf_counter() - start_time
  477. self.logger.info(
  478. f"Evaluation completed, time: {elapsed_time:.2f}s (评估: {eval_time:.2f}s, 报告: {report_time:.2f}s)")
  479. return report
  480. except Exception as e:
  481. self.logger.critical(f"Evaluation failed: {str(e)}", exc_info=True)
  482. return {"error": str(e), "traceback": traceback.format_exc()}
  483. def _add_overall_result(self, report: Dict[str, Any]) -> Dict[str, Any]:
  484. """处理评测报告并添加总体结果字段"""
  485. # 加载阈值参数
  486. thresholds = {
  487. "T0": self.config['T_threshold']['T0_threshold'],
  488. "T1": self.config['T_threshold']['T1_threshold'],
  489. "T2": self.config['T_threshold']['T2_threshold']
  490. }
  491. # 初始化计数器
  492. counters = {'p0': 0, 'p1': 0, 'p2': 0}
  493. # 优化:一次性收集所有失败的指标
  494. failed_categories = [
  495. (category, category_data.get('priority'))
  496. for category, category_data in report.items()
  497. if isinstance(category_data, dict) and category != "metadata" and not category_data.get('result', True)
  498. ]
  499. # 计数
  500. for _, priority in failed_categories:
  501. if priority == 0:
  502. counters['p0'] += 1
  503. elif priority == 1:
  504. counters['p1'] += 1
  505. elif priority == 2:
  506. counters['p2'] += 1
  507. # 阈值判断逻辑
  508. overall_result = not (
  509. counters['p0'] > thresholds['T0'] or
  510. counters['p1'] > thresholds['T1'] or
  511. counters['p2'] > thresholds['T2']
  512. )
  513. # 生成处理后的报告
  514. processed_report = report.copy()
  515. processed_report['overall_result'] = overall_result
  516. # 添加统计信息
  517. processed_report['threshold_checks'] = {
  518. 'T0_threshold': thresholds['T0'],
  519. 'T1_threshold': thresholds['T1'],
  520. 'T2_threshold': thresholds['T2'],
  521. 'actual_counts': counters
  522. }
  523. self.logger.info(f"Added overall result: {overall_result}")
  524. return processed_report
  525. def _generate_report(self, case_name: str, results: Dict[str, Any]) -> Dict[str, Any]:
  526. """生成评估报告"""
  527. from modules.lib.common import dict2json
  528. self.report_path.mkdir(parents=True, exist_ok=True)
  529. results["metadata"] = {
  530. "case_name": case_name,
  531. "timestamp": datetime.now().isoformat(),
  532. "version": "1.0",
  533. }
  534. # 添加总体结果评估
  535. results = self._add_overall_result(results)
  536. report_file = self.report_path / f"{case_name}_report.json"
  537. dict2json(results, report_file)
  538. self.logger.info(f"Report generated: {report_file}")
  539. return results
  540. def main():
  541. """命令行入口"""
  542. parser = argparse.ArgumentParser(
  543. description="Autonomous Driving Evaluation System V3.1",
  544. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  545. )
  546. # 必要参数
  547. parser.add_argument(
  548. "--dataPath",
  549. type=str,
  550. # default=r"D:\Cicv\招远\V2V_CSAE53-2020_ForwardCollision_LST_01-02_new",
  551. # default=r"D:\Cicv\招远\AD_GBT41798-2022_TrafficSignalRecognitionAndResponse_LST_01",
  552. # default=r"/home/server/桌面/XGJ/zhaoyuan_DataPreProcess/output/AD_GBT41798-2022_TrafficSignalRecognitionAndResponse_LST_02",
  553. default=r"/home/server/桌面/XGJ/zhaoyuan_DataPreProcess/output/V2I_CSAE53-2020_LeftTurnAssist_PGVIL_demo",
  554. help="Input data directory",
  555. )
  556. # 配置参数
  557. config_group = parser.add_argument_group('Configuration')
  558. config_group.add_argument(
  559. "--allConfigPath",
  560. type=str,
  561. default=r"/home/server/anaconda3/envs/vitual_XGJ/zhaoyuan_0617/zhaoyuan/config/all_metrics_config.yaml",
  562. help="Full metrics config file path (built-in + custom)",
  563. )
  564. config_group.add_argument(
  565. "--baseConfigPath",
  566. type=str,
  567. default=r"/home/server/anaconda3/envs/vitual_XGJ/zhaoyuan_0617/zhaoyuan/config/all_metrics_config.yaml",
  568. help="Built-in metrics config file path",
  569. )
  570. config_group.add_argument(
  571. "--customConfigPath",
  572. type=str,
  573. default=r"config/custom_metrics_config.yaml",
  574. help="Custom metrics config path (optional)",
  575. )
  576. # 输出参数
  577. output_group = parser.add_argument_group('Output')
  578. output_group.add_argument(
  579. "--logPath",
  580. type=str,
  581. default="test.log",
  582. help="Log file path",
  583. )
  584. output_group.add_argument(
  585. "--reportPath",
  586. type=str,
  587. default="reports",
  588. help="Output report directory",
  589. )
  590. output_group.add_argument(
  591. "--plotPath",
  592. type=str,
  593. default=r"/home/server/anaconda3/envs/vitual_XGJ/zhaoyuan_0617/zhaoyuan/scripts/reports/datas",
  594. help="Output plot csv directory",
  595. )
  596. # 扩展参数
  597. ext_group = parser.add_argument_group('Extensions')
  598. ext_group.add_argument(
  599. "--customMetricsPath",
  600. type=str,
  601. default="custom_metrics",
  602. help="Custom metrics scripts directory (optional)",
  603. )
  604. args = parser.parse_args()
  605. try:
  606. pipeline = EvaluationPipeline(
  607. all_config_path=args.allConfigPath,
  608. base_config_path=args.baseConfigPath,
  609. log_path=args.logPath,
  610. data_path=args.dataPath,
  611. report_path=args.reportPath,
  612. plot_path=args.plotPath,
  613. custom_metrics_path=args.customMetricsPath,
  614. custom_config_path=args.customConfigPath
  615. )
  616. start_time = time.perf_counter()
  617. result = pipeline.execute()
  618. elapsed_time = time.perf_counter() - start_time
  619. if "error" in result:
  620. print(f"Evaluation failed: {result['error']}")
  621. sys.exit(1)
  622. print(f"Evaluation completed, total time: {elapsed_time:.2f}s")
  623. print(f"Report path: {pipeline.report_path}")
  624. except KeyboardInterrupt:
  625. print("\nUser interrupted")
  626. sys.exit(130)
  627. except Exception as e:
  628. print(f"Execution error: {str(e)}")
  629. traceback.print_exc()
  630. sys.exit(1)
  631. if __name__ == "__main__":
  632. warnings.filterwarnings("ignore")
  633. main()