evaluator_enhanced.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862
  1. #!/usr/bin/env python3
  2. # evaluator_enhanced.py
import argparse
import glob
import importlib
import importlib.util
import inspect
import json
import logging
import os
import sys
import time
import traceback
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import Dict, Any, List, Optional, Type, Tuple, Callable, Union

import yaml
# Constants
DEFAULT_WORKERS = 4  # default thread-pool size for metric evaluation
CUSTOM_METRIC_PREFIX = "metric_"  # custom metric files must start with this prefix
CUSTOM_METRIC_FILE_PATTERN = "*.py"  # glob pattern for custom metric scripts
MAX_LOG_SIZE_MB = 300 # maximum total size of the log directory before cleanup
# Resolve the project root for both frozen (PyInstaller) and source runs.
if hasattr(sys, "_MEIPASS"):
    # PyInstaller bundles extract to sys._MEIPASS at runtime.
    _ROOT_PATH = Path(sys._MEIPASS)
else:
    _ROOT_PATH = Path(__file__).resolve().parent.parent
# Make project-local packages (modules.*) importable.
sys.path.insert(0, str(_ROOT_PATH))
  32. class ConfigManager:
  33. """配置管理组件"""
  34. def __init__(self, logger: logging.Logger):
  35. self.logger = logger
  36. self.base_config: Dict[str, Any] = {}
  37. self.custom_config: Dict[str, Any] = {}
  38. self.merged_config: Dict[str, Any] = {}
  39. self._config_cache = {}
  40. def split_configs(self, all_metrics_path: Path, builtin_metrics_path: Path, custom_metrics_path: Path) -> None:
  41. """从all_metrics_config.yaml拆分成内置和自定义配置"""
  42. # 检查是否已经存在提取的配置文件,如果存在则跳过拆分过程
  43. extracted_builtin_path = builtin_metrics_path.parent / f"{builtin_metrics_path.stem}_extracted{builtin_metrics_path.suffix}"
  44. if extracted_builtin_path.exists() and custom_metrics_path.exists():
  45. self.logger.info(f"使用已存在的拆分配置文件: {extracted_builtin_path}")
  46. return
  47. try:
  48. # 使用缓存加载配置文件,避免重复读取
  49. all_metrics_dict = self._safe_load_config(all_metrics_path)
  50. builtin_metrics_dict = self._safe_load_config(builtin_metrics_path)
  51. # 递归提取内置和自定义指标
  52. extracted_builtin_metrics, custom_metrics_dict = self._split_metrics_recursive(
  53. all_metrics_dict, builtin_metrics_dict
  54. )
  55. # 保存提取的内置指标到新文件
  56. with open(extracted_builtin_path, 'w', encoding='utf-8') as f:
  57. yaml.dump(extracted_builtin_metrics, f, allow_unicode=True, sort_keys=False, indent=2)
  58. self.logger.info(f"拆分配置: 提取的内置指标已保存到 {extracted_builtin_path}")
  59. if custom_metrics_dict:
  60. with open(custom_metrics_path, 'w', encoding='utf-8') as f:
  61. yaml.dump(custom_metrics_dict, f, allow_unicode=True, sort_keys=False, indent=2)
  62. self.logger.info(f"拆分配置: 自定义指标已保存到 {custom_metrics_path}")
  63. except Exception as err:
  64. self.logger.error(f"拆分配置失败: {str(err)}")
  65. raise
  66. def _split_metrics_recursive(self, all_dict: Dict, builtin_dict: Dict) -> Tuple[Dict, Dict]:
  67. """递归拆分内置和自定义指标配置"""
  68. extracted_builtin = {}
  69. custom_metrics = {}
  70. for key, value in all_dict.items():
  71. if key in builtin_dict:
  72. # 如果是字典类型,继续递归
  73. if isinstance(value, dict) and isinstance(builtin_dict[key], dict):
  74. sub_builtin, sub_custom = self._split_metrics_recursive(value, builtin_dict[key])
  75. if sub_builtin:
  76. extracted_builtin[key] = sub_builtin
  77. if sub_custom:
  78. custom_metrics[key] = sub_custom
  79. else:
  80. # 如果不是字典类型,直接复制
  81. extracted_builtin[key] = value
  82. else:
  83. # 如果键不在内置配置中,归类为自定义指标
  84. custom_metrics[key] = value
  85. return extracted_builtin, custom_metrics
  86. def load_configs(self, all_config_path: Optional[Path], builtin_metrics_path: Optional[Path],
  87. custom_metrics_path: Optional[Path]) -> Dict[str, Any]:
  88. """加载并合并配置"""
  89. # 如果已经加载过配置,直接返回缓存的结果
  90. cache_key = f"{all_config_path}_{builtin_metrics_path}_{custom_metrics_path}"
  91. if cache_key in self._config_cache:
  92. self.logger.info("使用缓存的配置数据")
  93. return self._config_cache[cache_key]
  94. # 自动拆分配置
  95. extracted_builtin_path = None
  96. if all_config_path and all_config_path.exists():
  97. # 生成提取的内置指标配置文件路径
  98. extracted_builtin_path = builtin_metrics_path.parent / f"{builtin_metrics_path.stem}_extracted{builtin_metrics_path.suffix}"
  99. self.split_configs(all_config_path, builtin_metrics_path, custom_metrics_path)
  100. # 优先使用提取的内置指标配置
  101. if extracted_builtin_path and extracted_builtin_path.exists():
  102. self.base_config = self._safe_load_config(extracted_builtin_path)
  103. else:
  104. self.base_config = self._safe_load_config(builtin_metrics_path) if builtin_metrics_path else {}
  105. self.custom_config = self._safe_load_config(custom_metrics_path) if custom_metrics_path else {}
  106. if all_config_path and all_config_path.exists():
  107. self.merged_config = self._safe_load_config(all_config_path)
  108. # 缓存配置结果
  109. self._config_cache[cache_key] = self.merged_config
  110. return self.merged_config
  111. return {}
  112. @lru_cache(maxsize=32)
  113. def _safe_load_config(self, config_path: Path) -> Dict[str, Any]:
  114. """安全加载YAML配置,使用增强的lru_cache减少重复读取"""
  115. try:
  116. if not config_path or not config_path.exists():
  117. self.logger.warning(f"Config file not found: {config_path}")
  118. return {}
  119. # 检查文件修改时间,避免使用过期缓存
  120. file_mtime = config_path.stat().st_mtime
  121. cache_key = f"{config_path}_{file_mtime}"
  122. # 使用类级别缓存存储解析后的配置
  123. if not hasattr(self.__class__, '_config_file_cache'):
  124. self.__class__._config_file_cache = {}
  125. if cache_key in self.__class__._config_file_cache:
  126. self.logger.info(f"Using cached config: {config_path}")
  127. return self.__class__._config_file_cache[cache_key]
  128. with config_path.open('r', encoding='utf-8') as f:
  129. config_dict = yaml.safe_load(f) or {}
  130. # 限制缓存大小
  131. if len(self.__class__._config_file_cache) > 20:
  132. # 移除最早添加的项
  133. self.__class__._config_file_cache.pop(next(iter(self.__class__._config_file_cache)))
  134. # 存储到缓存
  135. self.__class__._config_file_cache[cache_key] = config_dict
  136. self.logger.info(f"Loaded config: {config_path}")
  137. return config_dict
  138. except Exception as err:
  139. self.logger.error(f"Failed to load config {config_path}: {str(err)}")
  140. return {}
  141. def get_config(self) -> Dict[str, Any]:
  142. return self.merged_config
  143. def get_base_config(self) -> Dict[str, Any]:
  144. return self.base_config
  145. def get_custom_config(self) -> Dict[str, Any]:
  146. return self.custom_config
  147. class MetricLoader:
  148. """指标加载器组件"""
  149. def __init__(self, logger: logging.Logger, config_manager: ConfigManager):
  150. self.logger = logger
  151. self.config_manager = config_manager
  152. self.metric_modules: Dict[str, Type] = {}
  153. self.custom_metric_modules: Dict[str, Any] = {}
  154. def load_builtin_metrics(self) -> Dict[str, Type]:
  155. """加载内置指标模块"""
  156. module_mapping = {
  157. "safety": ("modules.metric.safety", "SafeManager"),
  158. "comfort": ("modules.metric.comfort", "ComfortManager"),
  159. "traffic": ("modules.metric.traffic", "TrafficManager"),
  160. "efficient": ("modules.metric.efficient", "EfficientManager"),
  161. "function": ("modules.metric.function", "FunctionManager"),
  162. }
  163. self.metric_modules = {
  164. name: self._load_module(*info)
  165. for name, info in module_mapping.items()
  166. }
  167. self.logger.info(f"Loaded builtin metrics: {', '.join(self.metric_modules.keys())}")
  168. return self.metric_modules
  169. @lru_cache(maxsize=32)
  170. def _load_module(self, module_path: str, class_name: str) -> Type:
  171. """动态加载Python模块"""
  172. try:
  173. module = __import__(module_path, fromlist=[class_name])
  174. return getattr(module, class_name)
  175. except (ImportError, AttributeError) as e:
  176. self.logger.error(f"Failed to load module: {module_path}.{class_name} - {str(e)}")
  177. raise
  178. def load_custom_metrics(self, custom_metrics_path: Optional[Path]) -> Dict[str, Any]:
  179. """加载自定义指标模块"""
  180. if not custom_metrics_path or not custom_metrics_path.is_dir():
  181. self.logger.info("No custom metrics path or path not exists")
  182. return {}
  183. # 检查是否有新的自定义指标文件
  184. current_files = set(f.name for f in custom_metrics_path.glob(CUSTOM_METRIC_FILE_PATTERN)
  185. if f.name.startswith(CUSTOM_METRIC_PREFIX))
  186. loaded_files = set(self.custom_metric_modules.keys())
  187. # 如果没有新文件且已有加载的模块,直接返回
  188. if self.custom_metric_modules and not (current_files - loaded_files):
  189. self.logger.info(f"No new custom metrics to load, using {len(self.custom_metric_modules)} cached modules")
  190. return self.custom_metric_modules
  191. loaded_count = 0
  192. for py_file in custom_metrics_path.glob(CUSTOM_METRIC_FILE_PATTERN):
  193. if py_file.name.startswith(CUSTOM_METRIC_PREFIX):
  194. if self._process_custom_metric_file(py_file):
  195. loaded_count += 1
  196. self.logger.info(f"Loaded {loaded_count} custom metric modules")
  197. return self.custom_metric_modules
  198. def _process_custom_metric_file(self, file_path: Path) -> bool:
  199. """处理单个自定义指标文件"""
  200. try:
  201. metric_key = self._validate_metric_file(file_path)
  202. module_name = f"custom_metric_{file_path.stem}"
  203. spec = importlib.util.spec_from_file_location(module_name, file_path)
  204. module = importlib.util.module_from_spec(spec)
  205. spec.loader.exec_module(module)
  206. from modules.lib.metric_registry import BaseMetric
  207. metric_class = None
  208. for name, obj in inspect.getmembers(module):
  209. if (inspect.isclass(obj) and
  210. issubclass(obj, BaseMetric) and
  211. obj != BaseMetric):
  212. metric_class = obj
  213. break
  214. if metric_class:
  215. self.custom_metric_modules[metric_key] = {
  216. 'type': 'class',
  217. 'module': module,
  218. 'class': metric_class
  219. }
  220. self.logger.info(f"Loaded class-based custom metric: {metric_key}")
  221. elif hasattr(module, 'evaluate'):
  222. self.custom_metric_modules[metric_key] = {
  223. 'type': 'function',
  224. 'module': module
  225. }
  226. self.logger.info(f"Loaded function-based custom metric: {metric_key}")
  227. else:
  228. raise AttributeError(f"Missing evaluate() function or BaseMetric subclass: {file_path.name}")
  229. return True
  230. except ValueError as e:
  231. self.logger.warning(str(e))
  232. return False
  233. except Exception as e:
  234. self.logger.error(f"Failed to load custom metric {file_path}: {str(e)}")
  235. return False
  236. def _validate_metric_file(self, file_path: Path) -> str:
  237. """验证自定义指标文件命名规范"""
  238. stem = file_path.stem[len(CUSTOM_METRIC_PREFIX):]
  239. parts = stem.split('_')
  240. if len(parts) < 3:
  241. raise ValueError(
  242. f"Invalid custom metric filename: {file_path.name}, should be metric_<level1>_<level2>_<level3>.py")
  243. level1, level2, level3 = parts[:3]
  244. if not self._is_metric_configured(level1, level2, level3):
  245. raise ValueError(f"Unconfigured metric: {level1}.{level2}.{level3}")
  246. return f"{level1}.{level2}.{level3}"
  247. def _is_metric_configured(self, level1: str, level2: str, level3: str) -> bool:
  248. """检查指标是否在配置中注册"""
  249. custom_config = self.config_manager.get_custom_config()
  250. try:
  251. return (level1 in custom_config and
  252. isinstance(custom_config[level1], dict) and
  253. level2 in custom_config[level1] and
  254. isinstance(custom_config[level1][level2], dict) and
  255. level3 in custom_config[level1][level2] and
  256. isinstance(custom_config[level1][level2][level3], dict))
  257. except Exception:
  258. return False
  259. def get_builtin_metrics(self) -> Dict[str, Type]:
  260. return self.metric_modules
  261. def get_custom_metrics(self) -> Dict[str, Any]:
  262. return self.custom_metric_modules
  263. class EvaluationEngine:
  264. """评估引擎组件"""
  265. def __init__(self, logger: logging.Logger, config_manager: ConfigManager, metric_loader: MetricLoader,
  266. plot_path: str):
  267. self.logger = logger
  268. self.config_manager = config_manager
  269. self.metric_loader = metric_loader
  270. self.plot_path = plot_path
  271. def evaluate(self, data: Any) -> Dict[str, Any]:
  272. """执行评估流程"""
  273. raw_results = self._collect_builtin_metrics(data)
  274. custom_results = self._collect_custom_metrics(data)
  275. return self._process_merged_results(raw_results, custom_results)
  276. def _collect_builtin_metrics(self, data: Any) -> Dict[str, Any]:
  277. """收集内置指标结果 - 优化版本"""
  278. metric_modules = self.metric_loader.get_builtin_metrics()
  279. raw_results: Dict[str, Any] = {}
  280. # 获取配置中实际存在的指标 - 使用更高效的字典推导
  281. config = self.config_manager.get_config()
  282. filtered_modules = {
  283. name: module for name, module in metric_modules.items()
  284. if name in config and isinstance(config[name], dict)
  285. }
  286. if not filtered_modules:
  287. return raw_results
  288. # 优化线程池大小,根据CPU核心数和任务数量动态调整
  289. import multiprocessing
  290. cpu_count = multiprocessing.cpu_count()
  291. max_workers = min(len(filtered_modules), max(DEFAULT_WORKERS, cpu_count))
  292. # 使用上下文管理器确保线程池正确关闭
  293. with ThreadPoolExecutor(max_workers=max_workers) as executor:
  294. # 预先创建所有任务
  295. futures_to_names = {
  296. executor.submit(self._run_module, module, data, module_name, self.plot_path): module_name
  297. for module_name, module in filtered_modules.items()
  298. }
  299. # 使用as_completed获取结果,避免阻塞
  300. from concurrent.futures import as_completed
  301. for future in as_completed(futures_to_names):
  302. module_name = futures_to_names[future]
  303. try:
  304. result = future.result()
  305. raw_results[module_name] = result[module_name]
  306. except Exception as e:
  307. self.logger.error(
  308. f"{module_name} evaluation failed: {str(e)}",
  309. exc_info=True,
  310. )
  311. raw_results[module_name] = {
  312. "status": "error",
  313. "message": str(e),
  314. "timestamp": datetime.now().isoformat(),
  315. }
  316. return raw_results
  317. def _collect_custom_metrics(self, data: Any) -> Dict[str, Dict]:
  318. """收集自定义指标结果"""
  319. custom_metrics = self.metric_loader.get_custom_metrics()
  320. if not custom_metrics:
  321. return {}
  322. custom_results = {}
  323. # 使用线程池并行处理自定义指标
  324. max_workers = min(len(custom_metrics), DEFAULT_WORKERS)
  325. with ThreadPoolExecutor(max_workers=max_workers) as executor:
  326. futures = {}
  327. # 提交所有自定义指标任务
  328. for metric_key, metric_info in custom_metrics.items():
  329. futures[executor.submit(self._run_custom_metric, metric_key, metric_info, data)] = metric_key
  330. # 收集结果
  331. for future in futures:
  332. metric_key = futures[future]
  333. try:
  334. level1, result = future.result()
  335. if level1:
  336. custom_results[level1] = result
  337. except Exception as e:
  338. self.logger.error(f"Custom metric {metric_key} execution failed: {str(e)}")
  339. return custom_results
  340. def _run_custom_metric(self, metric_key: str, metric_info: Dict, data: Any) -> Tuple[str, Dict]:
  341. """执行单个自定义指标"""
  342. try:
  343. level1, level2, level3 = metric_key.split('.')
  344. if metric_info['type'] == 'class':
  345. metric_class = metric_info['class']
  346. metric_instance = metric_class(data)
  347. metric_result = metric_instance.calculate()
  348. else:
  349. module = metric_info['module']
  350. metric_result = module.evaluate(data)
  351. self.logger.info(f"Calculated custom metric: {level1}.{level2}.{level3}")
  352. return level1, metric_result
  353. except Exception as e:
  354. self.logger.error(f"Custom metric {metric_key} failed: {str(e)}")
  355. try:
  356. level1 = metric_key.split('.')[0]
  357. return level1, {
  358. "status": "error",
  359. "message": str(e),
  360. "timestamp": datetime.now().isoformat(),
  361. }
  362. except Exception:
  363. return "", {}
  364. def _process_merged_results(self, raw_results: Dict, custom_results: Dict) -> Dict:
  365. """处理合并后的评估结果"""
  366. from modules.lib.score import Score
  367. final_results = {}
  368. merged_config = self.config_manager.get_config()
  369. for level1, level1_data in raw_results.items():
  370. if level1 in custom_results:
  371. level1_data.update(custom_results[level1])
  372. try:
  373. evaluator = Score(merged_config, level1)
  374. final_results.update(evaluator.evaluate(level1_data))
  375. except Exception as e:
  376. final_results[level1] = self._format_error(e)
  377. for level1, level1_data in custom_results.items():
  378. if level1 not in raw_results:
  379. try:
  380. evaluator = Score(merged_config, level1)
  381. final_results.update(evaluator.evaluate(level1_data))
  382. except Exception as e:
  383. final_results[level1] = self._format_error(e)
  384. return final_results
  385. def _format_error(self, e: Exception) -> Dict:
  386. return {
  387. "status": "error",
  388. "message": str(e),
  389. "timestamp": datetime.now().isoformat()
  390. }
  391. def _run_module(self, module_class: Any, data: Any, module_name: str, plot_path: str) -> Dict[str, Any]:
  392. """执行单个评估模块"""
  393. try:
  394. instance = module_class(data, plot_path)
  395. return {module_name: instance.report_statistic()}
  396. except Exception as e:
  397. self.logger.error(f"{module_name} execution error: {str(e)}", exc_info=True)
  398. return {module_name: {"error": str(e)}}
  399. class LoggingManager:
  400. """日志管理组件"""
  401. def __init__(self, log_dir: Path, case_name: str):
  402. self.log_dir = log_dir
  403. self.case_name = case_name
  404. self.logger = self._init_logger()
  405. def _init_logger(self) -> logging.Logger:
  406. """初始化日志系统"""
  407. try:
  408. # 确保日志目录存在
  409. self.log_dir.mkdir(parents=True, exist_ok=True)
  410. # 创建以用例名命名的日志文件
  411. log_file = self.log_dir / f"{self.case_name}.log"
  412. # 检查日志文件夹大小并清理
  413. self._cleanup_old_logs()
  414. from modules.lib.log_manager import LogManager
  415. log_manager = LogManager(log_file)
  416. return log_manager.get_logger()
  417. except (ImportError, PermissionError, IOError) as e:
  418. # 回退到基础日志配置
  419. logger = logging.getLogger("evaluator")
  420. logger.setLevel(logging.INFO)
  421. # 创建控制台处理器
  422. console_handler = logging.StreamHandler()
  423. console_handler.setFormatter(
  424. logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
  425. logger.addHandler(console_handler)
  426. # 尝试创建文件处理器
  427. try:
  428. self.log_dir.mkdir(parents=True, exist_ok=True)
  429. log_file = self.log_dir / f"{self.case_name}.log"
  430. file_handler = logging.FileHandler(log_file, encoding='utf-8')
  431. file_handler.setFormatter(
  432. logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
  433. logger.addHandler(file_handler)
  434. except Exception as file_err:
  435. logger.warning(f"无法创建文件日志: {str(file_err)}")
  436. logger.warning(f"Failed to init standard logger: {str(e)}, using fallback logger")
  437. return logger
  438. def _cleanup_old_logs(self):
  439. """当日志文件夹超过300MB时删除最早的日志文件"""
  440. try:
  441. # 获取所有日志文件
  442. log_files = list(self.log_dir.glob("*.log"))
  443. # 计算总大小 (MB)
  444. total_size = sum(f.stat().st_size for f in log_files) / (1024 * 1024)
  445. if total_size > MAX_LOG_SIZE_MB:
  446. # 使用基础日志记录警告(因为标准日志器可能尚未初始化)
  447. logging.warning(f"日志文件夹大小 {total_size:.2f}MB > {MAX_LOG_SIZE_MB}MB, 开始清理...")
  448. # 按修改时间排序 (旧文件在前)
  449. log_files.sort(key=lambda f: f.stat().st_mtime)
  450. # 删除文件直到总大小低于阈值
  451. while total_size > MAX_LOG_SIZE_MB and log_files:
  452. oldest_file = log_files.pop(0)
  453. file_size_mb = oldest_file.stat().st_size / (1024 * 1024)
  454. try:
  455. oldest_file.unlink()
  456. total_size -= file_size_mb
  457. logging.info(f"已删除旧日志文件: {oldest_file.name} ({file_size_mb:.2f}MB)")
  458. except Exception as e:
  459. logging.error(f"删除日志文件失败 {oldest_file.name}: {str(e)}")
  460. logging.info(f"清理完成, 当前日志文件夹大小: {total_size:.2f}MB")
  461. except Exception as e:
  462. # 使用基础日志记录错误
  463. logging.error(f"清理日志时出错: {str(e)}")
  464. def get_logger(self) -> logging.Logger:
  465. return self.logger
  466. class DataProcessor:
  467. """数据处理组件"""
  468. def __init__(self, logger: logging.Logger, data_path: Path, config_path: Optional[Path] = None):
  469. self.logger = logger
  470. self.data_path = data_path
  471. self.config_path = config_path
  472. self.case_name = self.data_path.name
  473. self._processor = None
  474. @property
  475. def processor(self) -> Any:
  476. """懒加载数据处理器,只在首次访问时创建"""
  477. if self._processor is None:
  478. self._processor = self._load_processor()
  479. return self._processor
  480. def _load_processor(self) -> Any:
  481. """加载数据处理器,优化版本"""
  482. try:
  483. start_time = time.perf_counter()
  484. from modules.lib import data_process
  485. # 使用缓存机制,避免重复加载相同数据
  486. cache_key = f"{self.data_path}_{self.config_path}"
  487. if hasattr(self.__class__, '_processor_cache') and cache_key in self.__class__._processor_cache:
  488. processor = self.__class__._processor_cache[cache_key]
  489. self.logger.info(f"Using cached data processor for {self.data_path.name}")
  490. else:
  491. processor = data_process.DataPreprocessing(self.data_path, self.config_path)
  492. # 初始化类级别缓存
  493. if not hasattr(self.__class__, '_processor_cache'):
  494. self.__class__._processor_cache = {}
  495. # 限制缓存大小,避免内存泄漏
  496. if len(self.__class__._processor_cache) > 5:
  497. # 移除最早添加的缓存项
  498. self.__class__._processor_cache.pop(next(iter(self.__class__._processor_cache)))
  499. self.__class__._processor_cache[cache_key] = processor
  500. elapsed_time = time.perf_counter() - start_time
  501. self.logger.info(f"Data processor loaded in {elapsed_time:.2f}s")
  502. return processor
  503. except ImportError as e:
  504. self.logger.error(f"Failed to load data processor: {str(e)}")
  505. raise RuntimeError(f"Failed to load data processor: {str(e)}") from e
  506. def validate(self) -> None:
  507. """验证数据路径"""
  508. if not self.data_path.exists():
  509. raise FileNotFoundError(f"Data path not exists: {self.data_path}")
  510. if not self.data_path.is_dir():
  511. raise NotADirectoryError(f"Invalid data directory: {self.data_path}")
  512. class EvaluationPipeline:
  513. """评估流水线控制器"""
  514. def __init__(self, all_config_path: str, base_config_path: str, log_dir: str, data_path: str, report_path: str,
  515. plot_path: str,
  516. custom_metrics_path: Optional[str] = None, custom_config_path: Optional[str] = None):
  517. # 路径初始化
  518. self.all_config_path = Path(all_config_path) if all_config_path else None
  519. self.base_config_path = Path(base_config_path) if base_config_path else None
  520. self.custom_config_path = Path(custom_config_path) if custom_config_path else None
  521. self.data_path = Path(data_path)
  522. self.report_path = Path(report_path)
  523. self.plot_path = Path(plot_path)
  524. self.custom_metrics_path = Path(custom_metrics_path) if custom_metrics_path else None
  525. self.log_dir = Path(log_dir)
  526. # 获取用例名
  527. self.case_name = self.data_path.name
  528. # 日志管理 (使用用例名)
  529. self.logging_manager = LoggingManager(self.log_dir, self.case_name)
  530. self.logger = self.logging_manager.get_logger()
  531. # 配置
  532. self.config_manager = ConfigManager(self.logger)
  533. self.config = self.config_manager.load_configs(
  534. self.all_config_path, self.base_config_path, self.custom_config_path
  535. )
  536. # 指标加载
  537. self.metric_loader = MetricLoader(self.logger, self.config_manager)
  538. self.metric_loader.load_builtin_metrics()
  539. self.metric_loader.load_custom_metrics(self.custom_metrics_path)
  540. # 数据处理
  541. self.data_processor = DataProcessor(self.logger, self.data_path, self.all_config_path)
  542. self.evaluation_engine = EvaluationEngine(self.logger, self.config_manager, self.metric_loader, self.plot_path)
  543. def execute(self) -> Dict[str, Any]:
  544. """执行评估流水线"""
  545. try:
  546. # 只在首次运行时验证数据路径
  547. self.data_processor.validate()
  548. self.logger.info(f"Start evaluation: {self.data_path.name}")
  549. start_time = time.perf_counter()
  550. # 性能分析日志
  551. config_start = time.perf_counter()
  552. results = self.evaluation_engine.evaluate(self.data_processor.processor)
  553. eval_time = time.perf_counter() - config_start
  554. # 生成报告
  555. report_start = time.perf_counter()
  556. report = self._generate_report(self.data_processor.case_name, results)
  557. report_time = time.perf_counter() - report_start
  558. # 总耗时
  559. elapsed_time = time.perf_counter() - start_time
  560. self.logger.info(
  561. f"Evaluation completed, time: {elapsed_time:.2f}s (评估: {eval_time:.2f}s, 报告: {report_time:.2f}s)")
  562. return report
  563. except Exception as e:
  564. self.logger.critical(f"Evaluation failed: {str(e)}", exc_info=True)
  565. return {"error": str(e), "traceback": traceback.format_exc()}
  566. def _add_overall_result(self, report: Dict[str, Any]) -> Dict[str, Any]:
  567. """处理评测报告并添加总体结果字段"""
  568. # 加载阈值参数
  569. thresholds = {
  570. "T0": self.config['T_threshold']['T0_threshold'],
  571. "T1": self.config['T_threshold']['T1_threshold'],
  572. "T2": self.config['T_threshold']['T2_threshold']
  573. }
  574. # 初始化计数器
  575. counters = {'p0': 0, 'p1': 0, 'p2': 0}
  576. # 优化:一次性收集所有失败的指标
  577. failed_categories = [
  578. (category, category_data.get('priority'))
  579. for category, category_data in report.items()
  580. if isinstance(category_data, dict) and category != "metadata" and not category_data.get('result', True)
  581. ]
  582. # 计数
  583. for _, priority in failed_categories:
  584. if priority == 0:
  585. counters['p0'] += 1
  586. elif priority == 1:
  587. counters['p1'] += 1
  588. elif priority == 2:
  589. counters['p2'] += 1
  590. # 阈值判断逻辑
  591. overall_result = not (
  592. counters['p0'] > thresholds['T0'] or
  593. counters['p1'] > thresholds['T1'] or
  594. counters['p2'] > thresholds['T2']
  595. )
  596. # 生成处理后的报告
  597. processed_report = report.copy()
  598. processed_report['overall_result'] = overall_result
  599. # 添加统计信息
  600. processed_report['threshold_checks'] = {
  601. 'T0_threshold': thresholds['T0'],
  602. 'T1_threshold': thresholds['T1'],
  603. 'T2_threshold': thresholds['T2'],
  604. 'actual_counts': counters
  605. }
  606. self.logger.info(f"Added overall result: {overall_result}")
  607. return processed_report
  608. def _generate_report(self, case_name: str, results: Dict[str, Any]) -> Dict[str, Any]:
  609. """生成评估报告"""
  610. from modules.lib.common import dict2json
  611. self.report_path.mkdir(parents=True, exist_ok=True)
  612. results["metadata"] = {
  613. "case_name": case_name,
  614. "timestamp": datetime.now().isoformat(),
  615. "version": "1.0",
  616. }
  617. # 添加总体结果评估
  618. results = self._add_overall_result(results)
  619. report_file = self.report_path / f"{case_name}_report.json"
  620. dict2json(results, report_file)
  621. self.logger.info(f"Report generated: {report_file}")
  622. return results
  623. def main():
  624. """命令行入口"""
  625. parser = argparse.ArgumentParser(
  626. description="Autonomous Driving Evaluation System V3.1",
  627. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  628. )
  629. # 必要参数
  630. parser.add_argument(
  631. "--dataPath",
  632. default=r"D:\Cicv\招远\测试数据\二轮评测数据\单车\AD_GBT41798-2022_RoadTraInfraObstRR_LST_010",
  633. type=str,
  634. # required=True,
  635. help="Input data directory",
  636. )
  637. # 配置参数
  638. config_group = parser.add_argument_group('Configuration')
  639. config_group.add_argument(
  640. "--allConfigPath",
  641. type=str,
  642. default=r"D:\Cicv\招远\评测代码\zhaoyuan_0808new\zhaoyuan\config\LST_1_all_metrics_config.yaml",
  643. help="Full metrics config file path (built-in + custom)",
  644. )
  645. config_group.add_argument(
  646. "--baseConfigPath",
  647. type=str,
  648. default=r"D:\Cicv\招远\评测代码\zhaoyuan_0808new\zhaoyuan\config\LST_1_builtin_metrics_config.yaml",
  649. # required=True,
  650. help="Built-in metrics config file path",
  651. )
  652. config_group.add_argument(
  653. "--customConfigPath",
  654. type=str,
  655. default=r"config/custom_metrics_config.yaml",
  656. help="Custom metrics config path (optional)",
  657. )
  658. # 输出参数
  659. output_group = parser.add_argument_group('Output')
  660. output_group.add_argument(
  661. "--logDir",
  662. type=str,
  663. default="logs",
  664. help="Log directory path",
  665. )
  666. output_group.add_argument(
  667. "--reportPath",
  668. type=str,
  669. default="reports",
  670. help="Output report directory",
  671. )
  672. output_group.add_argument(
  673. "--plotPath",
  674. type=str,
  675. default=r"D:\Cicv\招远\评测代码\zhaoyuan_0808new\zhaoyuan\plots",
  676. # required=True,
  677. help="Output plot csv directory",
  678. )
  679. # 扩展参数
  680. ext_group = parser.add_argument_group('Extensions')
  681. ext_group.add_argument(
  682. "--customMetricsPath",
  683. type=str,
  684. help="Custom metrics scripts directory (optional)",
  685. )
  686. args = parser.parse_args()
  687. try:
  688. pipeline = EvaluationPipeline(
  689. all_config_path=args.allConfigPath,
  690. base_config_path=args.baseConfigPath,
  691. log_dir=args.logDir,
  692. data_path=args.dataPath,
  693. report_path=args.reportPath,
  694. plot_path=args.plotPath,
  695. custom_metrics_path=args.customMetricsPath,
  696. custom_config_path=args.customConfigPath
  697. )
  698. start_time = time.perf_counter()
  699. result = pipeline.execute()
  700. elapsed_time = time.perf_counter() - start_time
  701. if "error" in result:
  702. print(f"Evaluation failed: {result['error']}")
  703. sys.exit(1)
  704. print(f"Evaluation completed, total time: {elapsed_time:.2f}s")
  705. print(f"Report path: {pipeline.report_path}")
  706. except KeyboardInterrupt:
  707. print("\nUser interrupted")
  708. sys.exit(130)
  709. except Exception as e:
  710. print(f"Execution error: {str(e)}")
  711. traceback.print_exc()
  712. sys.exit(1)
if __name__ == "__main__":
    # Silence library warnings for cleaner CLI output, then run the CLI.
    warnings.filterwarnings("ignore")
    main()