evaluator_optimized.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498
  1. # evaluation_engine.py
  2. import sys
  3. import warnings
  4. import time
  5. import importlib
  6. import yaml # 添加yaml模块导入
  7. from pathlib import Path
  8. import argparse
  9. from concurrent.futures import ThreadPoolExecutor
  10. from functools import lru_cache
  11. from typing import Dict, Any, List, Optional
  12. from datetime import datetime
  13. # 强制导入所有可能动态加载的模块
  14. # 安全设置根目录路径(动态路径管理)
  15. # 判断是否处于编译模式
  16. if hasattr(sys, "_MEIPASS"):
  17. # 编译模式下使用临时资源目录
  18. _ROOT_PATH = Path(sys._MEIPASS)
  19. else:
  20. # 开发模式下使用原工程路径
  21. _ROOT_PATH = Path(__file__).resolve().parent.parent
  22. sys.path.insert(0, str(_ROOT_PATH))
  23. print(f"当前根目录:{_ROOT_PATH}")
  24. print(f'当前系统路径:{sys.path}')
  25. class EvaluationCore:
  26. """评估引擎核心类(单例模式)"""
  27. _instance = None
  28. def __new__(cls, logPath: str, configPath: str = None, customConfigPath: str = None, customMetricsPath: str = None):
  29. if not cls._instance:
  30. cls._instance = super().__new__(cls)
  31. cls._instance._init(logPath, configPath, customConfigPath, customMetricsPath)
  32. return cls._instance
  33. def _init(self, logPath: str = None, configPath: str = None, customConfigPath: str = None,
  34. customMetricsPath: str = None) -> None:
  35. """初始化引擎组件"""
  36. self.log_path = logPath
  37. self.config_path = configPath
  38. self.custom_config_path = customConfigPath
  39. self.custom_metrics_path = customMetricsPath
  40. # 加载配置
  41. self.metrics_config = {}
  42. self.custom_metrics_config = {}
  43. self.merged_config = {} # 添加合并后的配置
  44. # 自定义指标脚本模块
  45. self.custom_metrics_modules = {}
  46. self._init_log_system()
  47. self._load_configs() # 加载并合并配置
  48. self._init_metrics()
  49. self._load_custom_metrics()
  50. def _init_log_system(self) -> None:
  51. """集中式日志管理"""
  52. try:
  53. from modules.lib.log_manager import LogManager
  54. log_manager = LogManager(self.log_path)
  55. self.logger = log_manager.get_logger()
  56. except (PermissionError, IOError) as e:
  57. sys.stderr.write(f"日志系统初始化失败: {str(e)}\n")
  58. sys.exit(1)
  59. def _init_metrics(self) -> None:
  60. """初始化评估模块(策略模式)"""
  61. # from modules.metric import safety, comfort, traffic, efficient, function
  62. self.metric_modules = {
  63. "safety": self._load_module("modules.metric.safety", "SafeManager"),
  64. "comfort": self._load_module("modules.metric.comfort", "ComfortManager"),
  65. "traffic": self._load_module("modules.metric.traffic", "TrafficManager"),
  66. "efficient": self._load_module("modules.metric.efficient", "EfficientManager"),
  67. "function": self._load_module("modules.metric.function", "FunctionManager"),
  68. }
  69. @lru_cache(maxsize=32)
  70. def _load_module(self, module_path: str, class_name: str) -> Any:
  71. """动态加载评估模块(带缓存)"""
  72. try:
  73. __import__(module_path)
  74. return getattr(sys.modules[module_path], class_name)
  75. except (ImportError, AttributeError) as e:
  76. self.logger.error(f"模块加载失败: {module_path}.{class_name} - {str(e)}")
  77. raise
  78. def _load_configs(self) -> None:
  79. """加载并合并内置指标和自定义指标配置"""
  80. # 加载内置指标配置
  81. if self.config_path and Path(self.config_path).exists():
  82. try:
  83. with open(self.config_path, 'r', encoding='utf-8') as f:
  84. self.metrics_config = yaml.safe_load(f)
  85. self.logger.info(f"成功加载内置指标配置: {self.config_path}")
  86. except Exception as e:
  87. self.logger.error(f"加载内置指标配置失败: {str(e)}")
  88. self.metrics_config = {}
  89. # 加载自定义指标配置
  90. if self.custom_config_path and Path(self.custom_config_path).exists():
  91. try:
  92. with open(self.custom_config_path, 'r', encoding='utf-8') as f:
  93. self.custom_metrics_config = yaml.safe_load(f)
  94. self.logger.info(f"成功加载自定义指标配置: {self.custom_config_path}")
  95. except Exception as e:
  96. self.logger.error(f"加载自定义指标配置失败: {str(e)}")
  97. self.custom_metrics_config = {}
  98. # 合并配置
  99. self.merged_config = self._merge_configs(self.metrics_config, self.custom_metrics_config)
  100. def _merge_configs(self, base_config: Dict, custom_config: Dict) -> Dict:
  101. """
  102. 合并内置指标和自定义指标配置
  103. 策略:
  104. 1. 如果自定义指标与内置指标有相同的一级指标,则合并其下的二级指标
  105. 2. 如果自定义指标与内置指标有相同的二级指标,则合并其下的三级指标
  106. 3. 如果是全新的指标,则直接添加
  107. """
  108. merged = base_config.copy()
  109. for level1_key, level1_value in custom_config.items():
  110. # 跳过非指标配置项(如vehicle等)
  111. if not isinstance(level1_value, dict) or 'name' not in level1_value:
  112. if level1_key not in merged:
  113. merged[level1_key] = level1_value
  114. continue
  115. if level1_key not in merged:
  116. # 全新的一级指标
  117. merged[level1_key] = level1_value
  118. else:
  119. # 合并已存在的一级指标下的内容
  120. for level2_key, level2_value in level1_value.items():
  121. if level2_key == 'name' or level2_key == 'priority':
  122. continue
  123. if isinstance(level2_value, dict):
  124. if level2_key not in merged[level1_key]:
  125. # 新的二级指标
  126. merged[level1_key][level2_key] = level2_value
  127. else:
  128. # 合并已存在的二级指标下的内容
  129. for level3_key, level3_value in level2_value.items():
  130. if level3_key == 'name' or level3_key == 'priority':
  131. continue
  132. if isinstance(level3_value, dict):
  133. if level3_key not in merged[level1_key][level2_key]:
  134. # 新的三级指标
  135. merged[level1_key][level2_key][level3_key] = level3_value
  136. return merged
  137. def _load_custom_metrics(self) -> None:
  138. """加载自定义指标脚本"""
  139. if not self.custom_metrics_path or not Path(self.custom_metrics_path).exists():
  140. return
  141. custom_metrics_dir = Path(self.custom_metrics_path)
  142. if not custom_metrics_dir.is_dir():
  143. self.logger.warning(f"自定义指标路径不是目录: {custom_metrics_dir}")
  144. return
  145. # 遍历自定义指标脚本目录
  146. for file_path in custom_metrics_dir.glob("*.py"):
  147. if file_path.name.startswith("metric_") and file_path.name.endswith(".py"):
  148. try:
  149. # 解析脚本名称,获取指标层级信息
  150. parts = file_path.stem[7:].split('_') # 去掉'metric_'前缀
  151. if len(parts) < 3:
  152. self.logger.warning(
  153. f"自定义指标脚本 {file_path.name} 命名不符合规范,应为 metric_<level1>_<level2>_<level3>.py")
  154. continue
  155. level1, level2, level3 = parts[0], parts[1], parts[2]
  156. # 检查指标是否在配置中
  157. if not self._check_metric_in_config(level1, level2, level3, self.custom_metrics_config):
  158. self.logger.warning(f"自定义指标 {level1}.{level2}.{level3} 在配置中不存在,跳过加载")
  159. continue
  160. # 加载脚本模块
  161. module_name = f"custom_metric_{level1}_{level2}_{level3}"
  162. spec = importlib.util.spec_from_file_location(module_name, file_path)
  163. module = importlib.util.module_from_spec(spec)
  164. spec.loader.exec_module(module)
  165. # 检查模块是否包含必要的函数
  166. if not hasattr(module, 'evaluate'):
  167. self.logger.warning(f"自定义指标脚本 {file_path.name} 缺少 evaluate 函数")
  168. continue
  169. # 存储模块引用
  170. key = f"{level1}.{level2}.{level3}"
  171. self.custom_metrics_modules[key] = module
  172. self.logger.info(f"成功加载自定义指标脚本: {file_path.name}")
  173. except Exception as e:
  174. self.logger.error(f"加载自定义指标脚本 {file_path.name} 失败: {str(e)}")
  175. def _check_metric_in_config(self, level1: str, level2: str, level3: str, config: Dict) -> bool:
  176. """检查指标是否在配置中存在"""
  177. try:
  178. return (level1 in config and
  179. isinstance(config[level1], dict) and
  180. level2 in config[level1] and
  181. isinstance(config[level1][level2], dict) and
  182. level3 in config[level1][level2] and
  183. isinstance(config[level1][level2][level3], dict))
  184. except Exception:
  185. return False
  186. def parallel_evaluate(self, data: Any) -> Dict[str, Any]:
  187. """并行化评估引擎(动态线程池)"""
  188. # 存储所有评估结果
  189. results = {}
  190. # 1. 先评估内置指标
  191. self._evaluate_built_in_metrics(data, results)
  192. # 2. 再评估自定义指标并合并结果
  193. self._evaluate_and_merge_custom_metrics(data, results)
  194. return results
  195. def _evaluate_built_in_metrics(self, data: Any, results: Dict[str, Any]) -> None:
  196. """评估内置指标"""
  197. # 关键修改点1:线程数=模块数
  198. with ThreadPoolExecutor(max_workers=len(self.metric_modules)) as executor:
  199. # 关键修改点2:按模块名创建future映射
  200. futures = {
  201. module_name: executor.submit(
  202. self._run_module, module, data, module_name
  203. )
  204. for module_name, module in self.metric_modules.items()
  205. }
  206. # 关键修改点3:按模块顺序处理结果
  207. for module_name, future in futures.items():
  208. try:
  209. from modules.lib.score import Score
  210. evaluator = Score(self.merged_config, module_name)
  211. result_module = future.result()
  212. result = evaluator.evaluate(result_module)
  213. # results.update(result[module_name])
  214. results.update(result)
  215. except Exception as e:
  216. self.logger.error(
  217. f"{module_name} 评估失败: {str(e)}",
  218. exc_info=True,
  219. extra={"stack": True}, # 记录完整堆栈
  220. )
  221. # 错误信息结构化存储
  222. results[module_name] = {
  223. "status": "error",
  224. "message": str(e),
  225. "timestamp": datetime.now().isoformat(),
  226. }
  227. def _evaluate_and_merge_custom_metrics(self, data: Any, results: Dict[str, Any]) -> None:
  228. """评估自定义指标并合并结果"""
  229. if not self.custom_metrics_modules:
  230. return
  231. # 按一级指标分组自定义指标
  232. grouped_metrics = {}
  233. for metric_key in self.custom_metrics_modules:
  234. level1 = metric_key.split('.')[0]
  235. if level1 not in grouped_metrics:
  236. grouped_metrics[level1] = []
  237. grouped_metrics[level1].append(metric_key)
  238. # 处理每个一级指标组
  239. for level1, metric_keys in grouped_metrics.items():
  240. # 检查是否为内置一级指标
  241. is_built_in = level1 in self.metrics_config and 'name' in self.metrics_config[level1]
  242. level1_name = self.merged_config[level1].get('name', level1) if level1 in self.merged_config else level1
  243. # 如果是内置一级指标,将结果合并到已有结果中
  244. if is_built_in and level1_name in results:
  245. for metric_key in metric_keys:
  246. self._evaluate_and_merge_single_metric(data, results, metric_key, level1_name)
  247. else:
  248. # 如果是新的一级指标,创建新的结果结构
  249. if level1_name not in results:
  250. results[level1_name] = {}
  251. # 评估该一级指标下的所有自定义指标
  252. for metric_key in metric_keys:
  253. self._evaluate_and_merge_single_metric(data, results, metric_key, level1_name)
  254. def _evaluate_and_merge_single_metric(self, data: Any, results: Dict[str, Any], metric_key: str,
  255. level1_name: str) -> None:
  256. """评估单个自定义指标并合并结果"""
  257. try:
  258. level1, level2, level3 = metric_key.split('.')
  259. module = self.custom_metrics_modules[metric_key]
  260. # 获取指标配置
  261. metric_config = self.custom_metrics_config[level1][level2][level3]
  262. # 获取指标名称
  263. level2_name = self.custom_metrics_config[level1][level2].get('name', level2)
  264. level3_name = metric_config.get('name', level3)
  265. # 确保结果字典结构存在
  266. if level2_name not in results[level1_name]:
  267. results[level1_name][level2_name] = {}
  268. # 调用自定义指标评测函数
  269. metric_result = module.evaluate(data)
  270. from modules.lib.score import Score
  271. evaluator = Score(self.merged_config, level1_name)
  272. result = evaluator.evaluate(metric_result)
  273. results.update(result)
  274. self.logger.info(f"评测自定义指标: {level1_name}.{level2_name}.{level3_name}")
  275. except Exception as e:
  276. self.logger.error(f"评测自定义指标 {metric_key} 失败: {str(e)}")
  277. # 尝试添加错误信息到结果中
  278. try:
  279. level1, level2, level3 = metric_key.split('.')
  280. level2_name = self.custom_metrics_config[level1][level2].get('name', level2)
  281. level3_name = self.custom_metrics_config[level1][level2][level3].get('name', level3)
  282. if level2_name not in results[level1_name]:
  283. results[level1_name][level2_name] = {}
  284. results[level1_name][level2_name][level3_name] = {
  285. "status": "error",
  286. "message": str(e),
  287. "timestamp": datetime.now().isoformat(),
  288. }
  289. except Exception:
  290. pass
  291. def _run_module(
  292. self, module_class: Any, data: Any, module_name: str
  293. ) -> Dict[str, Any]:
  294. """执行单个评估模块(带熔断机制)"""
  295. try:
  296. instance = module_class(data)
  297. return {module_name: instance.report_statistic()}
  298. except Exception as e:
  299. self.logger.error(f"{module_name} 执行异常: {str(e)}", stack_info=True)
  300. return {module_name: {"error": str(e)}}
  301. class EvaluationPipeline:
  302. """评估流水线控制器"""
  303. def __init__(self, configPath: str, logPath: str, dataPath: str, resultPath: str,
  304. customMetricsPath: Optional[str] = None, customConfigPath: Optional[str] = None):
  305. self.configPath = Path(configPath)
  306. self.custom_config_path = Path(customConfigPath) if customConfigPath else None
  307. self.data_path = Path(dataPath)
  308. self.report_path = Path(resultPath)
  309. self.custom_metrics_path = Path(customMetricsPath) if customMetricsPath else None
  310. # 创建评估引擎实例,传入所有必要参数
  311. self.engine = EvaluationCore(
  312. logPath,
  313. configPath=str(self.configPath),
  314. customConfigPath=str(self.custom_config_path) if self.custom_config_path else None,
  315. customMetricsPath=str(self.custom_metrics_path) if self.custom_metrics_path else None
  316. )
  317. self.data_processor = self._load_data_processor()
  318. def _load_data_processor(self) -> Any:
  319. """动态加载数据预处理模块"""
  320. try:
  321. from modules.lib import data_process
  322. return data_process.DataPreprocessing(self.data_path, self.configPath)
  323. except ImportError as e:
  324. raise RuntimeError(f"数据处理器加载失败: {str(e)}") from e
  325. def execute_pipeline(self) -> Dict[str, Any]:
  326. """端到端执行评估流程"""
  327. self._validate_case()
  328. try:
  329. metric_results = self.engine.parallel_evaluate(self.data_processor)
  330. report = self._generate_report(
  331. self.data_processor.case_name, metric_results
  332. )
  333. return report
  334. except Exception as e:
  335. self.engine.logger.critical(f"流程执行失败: {str(e)}", exc_info=True)
  336. return {"error": str(e)}
  337. def _validate_case(self) -> None:
  338. """用例路径验证"""
  339. case_path = self.data_path
  340. if not case_path.exists():
  341. raise FileNotFoundError(f"用例路径不存在: {case_path}")
  342. if not case_path.is_dir():
  343. raise NotADirectoryError(f"无效的用例目录: {case_path}")
  344. def _generate_report(self, case_name: str, results: Dict) -> Dict:
  345. """生成评估报告(模板方法模式)"""
  346. from modules.lib.common import dict2json
  347. report_path = self.report_path
  348. report_path.mkdir(parents=True, exist_ok=True, mode=0o777)
  349. report_file = report_path / f"{case_name}_report.json"
  350. dict2json(results, report_file)
  351. self.engine.logger.info(f"评估报告已生成: {report_file}")
  352. return results
  353. def main():
  354. """命令行入口(工厂模式)"""
  355. parser = argparse.ArgumentParser(
  356. description="自动驾驶评估系统 V3.0 - 支持动态指标选择和自定义指标",
  357. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  358. )
  359. # 带帮助说明的参数定义,增加默认值
  360. parser.add_argument(
  361. "--logPath",
  362. type=str,
  363. default=r"D:\Cicv\招远\zhaoyuan\test.log",
  364. help="日志文件存储路径",
  365. )
  366. parser.add_argument(
  367. "--dataPath",
  368. type=str,
  369. default=r"D:\Cicv\招远\V2V_CSAE53-2020_ForwardCollision_LST_01-02_new",
  370. help="预处理后的输入数据目录",
  371. )
  372. parser.add_argument(
  373. "--configPath",
  374. type=str,
  375. default=r"D:\Cicv\招远\zhaoyuan\zhaoyuan\config\all_metrics_config.yaml",
  376. help="评估指标配置文件路径",
  377. )
  378. parser.add_argument(
  379. "--reportPath",
  380. type=str,
  381. default=r"D:\Cicv\招远\zhaoyuan\zhaoyuan\result",
  382. help="评估报告输出目录",
  383. )
  384. # 新增自定义指标路径参数(可选)
  385. parser.add_argument(
  386. "--customMetricsPath",
  387. type=str,
  388. default=r"D:\Cicv\招远\zhaoyuan\zhaoyuan\custom_metrics",
  389. help="自定义指标脚本目录(可选)",
  390. )
  391. # 新增自定义指标路径参数(可选)
  392. parser.add_argument(
  393. "--customConfigPath",
  394. type=str,
  395. default=r"D:\Cicv\招远\zhaoyuan\zhaoyuan\test\custom_metrics_config.yaml",
  396. help="自定义指标脚本目录(可选)",
  397. )
  398. args = parser.parse_args()
  399. try:
  400. pipeline = EvaluationPipeline(
  401. args.configPath, args.logPath, args.dataPath, args.reportPath, args.customMetricsPath, args.customConfigPath
  402. )
  403. start_time = time.perf_counter()
  404. result = pipeline.execute_pipeline()
  405. if "error" in result:
  406. sys.exit(1)
  407. print(f"评估完成,耗时: {time.perf_counter() - start_time:.2f}s")
  408. print(f"报告路径: {pipeline.report_path}")
  409. except KeyboardInterrupt:
  410. print("\n用户中断操作")
  411. sys.exit(130)
  412. except Exception as e:
  413. print(f"程序执行异常: {str(e)}")
  414. sys.exit(1)
  415. if __name__ == "__main__":
  416. warnings.filterwarnings("ignore")
  417. main()