# Source: evaluator_optimized.py (21 KB)
# evaluation_engine.py
import argparse
import importlib
import importlib.util  # explicitly import the submodule; `import importlib` alone does not bind it
import sys
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import Dict, Any, List, Optional

import yaml  # PyYAML, used to parse the metric configuration files
  13. # 强制导入所有可能动态加载的模块
  14. # 安全设置根目录路径(动态路径管理)
  15. # 判断是否处于编译模式
  16. if hasattr(sys, "_MEIPASS"):
  17. # 编译模式下使用临时资源目录
  18. _ROOT_PATH = Path(sys._MEIPASS)
  19. else:
  20. # 开发模式下使用原工程路径
  21. _ROOT_PATH = Path(__file__).resolve().parent.parent
  22. sys.path.insert(0, str(_ROOT_PATH))
  23. print(f"当前根目录:{_ROOT_PATH}")
  24. print(f'当前系统路径:{sys.path}')
  25. class EvaluationCore:
  26. """评估引擎核心类(单例模式)"""
  27. _instance = None
  28. def __new__(cls, logPath: str, configPath: str = None, customConfigPath: str = None, customMetricsPath: str = None):
  29. if not cls._instance:
  30. cls._instance = super().__new__(cls)
  31. cls._instance._init(logPath, configPath, customConfigPath, customMetricsPath)
  32. return cls._instance
  33. def _init(self, logPath: str = None, configPath: str = None, customConfigPath: str = None, customMetricsPath: str = None) -> None:
  34. """初始化引擎组件"""
  35. self.log_path = logPath
  36. self.config_path = configPath
  37. self.custom_config_path = customConfigPath
  38. self.custom_metrics_path = customMetricsPath
  39. # 加载配置
  40. self.metrics_config = {}
  41. self.custom_metrics_config = {}
  42. self.merged_config = {} # 添加合并后的配置
  43. # 自定义指标脚本模块
  44. self.custom_metrics_modules = {}
  45. self._init_log_system()
  46. self._load_configs() # 加载并合并配置
  47. self._init_metrics()
  48. self._load_custom_metrics()
  49. def _init_log_system(self) -> None:
  50. """集中式日志管理"""
  51. try:
  52. from modules.lib.log_manager import LogManager
  53. log_manager = LogManager(self.log_path)
  54. self.logger = log_manager.get_logger()
  55. except (PermissionError, IOError) as e:
  56. sys.stderr.write(f"日志系统初始化失败: {str(e)}\n")
  57. sys.exit(1)
  58. def _init_metrics(self) -> None:
  59. """初始化评估模块(策略模式)"""
  60. # from modules.metric import safety, comfort, traffic, efficient, function
  61. self.metric_modules = {
  62. "safety": self._load_module("modules.metric.safety", "SafeManager"),
  63. "comfort": self._load_module("modules.metric.comfort", "ComfortManager"),
  64. "traffic": self._load_module("modules.metric.traffic", "ViolationManager"),
  65. "efficient": self._load_module("modules.metric.efficient", "EfficientManager"),
  66. "function": self._load_module("modules.metric.function", "FunctionManager"),
  67. }
  68. @lru_cache(maxsize=32)
  69. def _load_module(self, module_path: str, class_name: str) -> Any:
  70. """动态加载评估模块(带缓存)"""
  71. try:
  72. __import__(module_path)
  73. return getattr(sys.modules[module_path], class_name)
  74. except (ImportError, AttributeError) as e:
  75. self.logger.error(f"模块加载失败: {module_path}.{class_name} - {str(e)}")
  76. raise
  77. def _load_configs(self) -> None:
  78. """加载并合并内置指标和自定义指标配置"""
  79. # 加载内置指标配置
  80. if self.config_path and Path(self.config_path).exists():
  81. try:
  82. with open(self.config_path, 'r', encoding='utf-8') as f:
  83. self.metrics_config = yaml.safe_load(f)
  84. self.logger.info(f"成功加载内置指标配置: {self.config_path}")
  85. except Exception as e:
  86. self.logger.error(f"加载内置指标配置失败: {str(e)}")
  87. self.metrics_config = {}
  88. # 加载自定义指标配置
  89. if self.custom_config_path and Path(self.custom_config_path).exists():
  90. try:
  91. with open(self.custom_config_path, 'r', encoding='utf-8') as f:
  92. self.custom_metrics_config = yaml.safe_load(f)
  93. self.logger.info(f"成功加载自定义指标配置: {self.custom_config_path}")
  94. except Exception as e:
  95. self.logger.error(f"加载自定义指标配置失败: {str(e)}")
  96. self.custom_metrics_config = {}
  97. # 合并配置
  98. self.merged_config = self._merge_configs(self.metrics_config, self.custom_metrics_config)
  99. def _merge_configs(self, base_config: Dict, custom_config: Dict) -> Dict:
  100. """
  101. 合并内置指标和自定义指标配置
  102. 策略:
  103. 1. 如果自定义指标与内置指标有相同的一级指标,则合并其下的二级指标
  104. 2. 如果自定义指标与内置指标有相同的二级指标,则合并其下的三级指标
  105. 3. 如果是全新的指标,则直接添加
  106. """
  107. merged = base_config.copy()
  108. for level1_key, level1_value in custom_config.items():
  109. # 跳过非指标配置项(如vehicle等)
  110. if not isinstance(level1_value, dict) or 'name' not in level1_value:
  111. if level1_key not in merged:
  112. merged[level1_key] = level1_value
  113. continue
  114. if level1_key not in merged:
  115. # 全新的一级指标
  116. merged[level1_key] = level1_value
  117. else:
  118. # 合并已存在的一级指标下的内容
  119. for level2_key, level2_value in level1_value.items():
  120. if level2_key == 'name' or level2_key == 'priority':
  121. continue
  122. if isinstance(level2_value, dict):
  123. if level2_key not in merged[level1_key]:
  124. # 新的二级指标
  125. merged[level1_key][level2_key] = level2_value
  126. else:
  127. # 合并已存在的二级指标下的内容
  128. for level3_key, level3_value in level2_value.items():
  129. if level3_key == 'name' or level3_key == 'priority':
  130. continue
  131. if isinstance(level3_value, dict):
  132. if level3_key not in merged[level1_key][level2_key]:
  133. # 新的三级指标
  134. merged[level1_key][level2_key][level3_key] = level3_value
  135. return merged
  136. def _load_custom_metrics(self) -> None:
  137. """加载自定义指标脚本"""
  138. if not self.custom_metrics_path or not Path(self.custom_metrics_path).exists():
  139. return
  140. custom_metrics_dir = Path(self.custom_metrics_path)
  141. if not custom_metrics_dir.is_dir():
  142. self.logger.warning(f"自定义指标路径不是目录: {custom_metrics_dir}")
  143. return
  144. # 遍历自定义指标脚本目录
  145. for file_path in custom_metrics_dir.glob("*.py"):
  146. if file_path.name.startswith("metric_") and file_path.name.endswith(".py"):
  147. try:
  148. # 解析脚本名称,获取指标层级信息
  149. parts = file_path.stem[7:].split('_') # 去掉'metric_'前缀
  150. if len(parts) < 3:
  151. self.logger.warning(f"自定义指标脚本 {file_path.name} 命名不符合规范,应为 metric_<level1>_<level2>_<level3>.py")
  152. continue
  153. level1, level2, level3 = parts[0], parts[1], parts[2]
  154. # 检查指标是否在配置中
  155. if not self._check_metric_in_config(level1, level2, level3, self.custom_metrics_config):
  156. self.logger.warning(f"自定义指标 {level1}.{level2}.{level3} 在配置中不存在,跳过加载")
  157. continue
  158. # 加载脚本模块
  159. module_name = f"custom_metric_{level1}_{level2}_{level3}"
  160. spec = importlib.util.spec_from_file_location(module_name, file_path)
  161. module = importlib.util.module_from_spec(spec)
  162. spec.loader.exec_module(module)
  163. # 检查模块是否包含必要的函数
  164. if not hasattr(module, 'evaluate'):
  165. self.logger.warning(f"自定义指标脚本 {file_path.name} 缺少 evaluate 函数")
  166. continue
  167. # 存储模块引用
  168. key = f"{level1}.{level2}.{level3}"
  169. self.custom_metrics_modules[key] = module
  170. self.logger.info(f"成功加载自定义指标脚本: {file_path.name}")
  171. except Exception as e:
  172. self.logger.error(f"加载自定义指标脚本 {file_path.name} 失败: {str(e)}")
  173. def _check_metric_in_config(self, level1: str, level2: str, level3: str, config: Dict) -> bool:
  174. """检查指标是否在配置中存在"""
  175. try:
  176. return (level1 in config and
  177. isinstance(config[level1], dict) and
  178. level2 in config[level1] and
  179. isinstance(config[level1][level2], dict) and
  180. level3 in config[level1][level2] and
  181. isinstance(config[level1][level2][level3], dict))
  182. except Exception:
  183. return False
  184. def parallel_evaluate(self, data: Any) -> Dict[str, Any]:
  185. """并行化评估引擎(动态线程池)"""
  186. # 存储所有评估结果
  187. results = {}
  188. # 1. 先评估内置指标
  189. self._evaluate_built_in_metrics(data, results)
  190. # 2. 再评估自定义指标并合并结果
  191. self._evaluate_and_merge_custom_metrics(data, results)
  192. return results
  193. def _evaluate_built_in_metrics(self, data: Any, results: Dict[str, Any]) -> None:
  194. """评估内置指标"""
  195. # 关键修改点1:线程数=模块数
  196. with ThreadPoolExecutor(max_workers=len(self.metric_modules)) as executor:
  197. # 关键修改点2:按模块名创建future映射
  198. futures = {
  199. module_name: executor.submit(
  200. self._run_module, module, data, module_name
  201. )
  202. for module_name, module in self.metric_modules.items()
  203. }
  204. # 关键修改点3:按模块顺序处理结果
  205. for module_name, future in futures.items():
  206. try:
  207. from modules.lib.score import Score
  208. evaluator = Score(self.merged_config, module_name)
  209. result_module = future.result()
  210. result = evaluator.evaluate(result_module)
  211. # results.update(result[module_name])
  212. results.update(result)
  213. except Exception as e:
  214. self.logger.error(
  215. f"{module_name} 评估失败: {str(e)}",
  216. exc_info=True,
  217. extra={"stack": True}, # 记录完整堆栈
  218. )
  219. # 错误信息结构化存储
  220. results[module_name] = {
  221. "status": "error",
  222. "message": str(e),
  223. "timestamp": datetime.now().isoformat(),
  224. }
  225. def _evaluate_and_merge_custom_metrics(self, data: Any, results: Dict[str, Any]) -> None:
  226. """评估自定义指标并合并结果"""
  227. if not self.custom_metrics_modules:
  228. return
  229. # 按一级指标分组自定义指标
  230. grouped_metrics = {}
  231. for metric_key in self.custom_metrics_modules:
  232. level1 = metric_key.split('.')[0]
  233. if level1 not in grouped_metrics:
  234. grouped_metrics[level1] = []
  235. grouped_metrics[level1].append(metric_key)
  236. # 处理每个一级指标组
  237. for level1, metric_keys in grouped_metrics.items():
  238. # 检查是否为内置一级指标
  239. is_built_in = level1 in self.metrics_config and 'name' in self.metrics_config[level1]
  240. level1_name = self.merged_config[level1].get('name', level1) if level1 in self.merged_config else level1
  241. # 如果是内置一级指标,将结果合并到已有结果中
  242. if is_built_in and level1_name in results:
  243. for metric_key in metric_keys:
  244. self._evaluate_and_merge_single_metric(data, results, metric_key, level1_name)
  245. else:
  246. # 如果是新的一级指标,创建新的结果结构
  247. if level1_name not in results:
  248. results[level1_name] = {}
  249. # 评估该一级指标下的所有自定义指标
  250. for metric_key in metric_keys:
  251. self._evaluate_and_merge_single_metric(data, results, metric_key, level1_name)
  252. def _evaluate_and_merge_single_metric(self, data: Any, results: Dict[str, Any], metric_key: str, level1_name: str) -> None:
  253. """评估单个自定义指标并合并结果"""
  254. try:
  255. level1, level2, level3 = metric_key.split('.')
  256. module = self.custom_metrics_modules[metric_key]
  257. # 获取指标配置
  258. metric_config = self.custom_metrics_config[level1][level2][level3]
  259. # 获取指标名称
  260. level2_name = self.custom_metrics_config[level1][level2].get('name', level2)
  261. level3_name = metric_config.get('name', level3)
  262. # 确保结果字典结构存在
  263. if level2_name not in results[level1_name]:
  264. results[level1_name][level2_name] = {}
  265. # 调用自定义指标评测函数
  266. metric_result = module.evaluate(data)
  267. from modules.lib.score import Score
  268. evaluator = Score(self.merged_config, level1_name)
  269. result = evaluator.evaluate(metric_result)
  270. results.update(result)
  271. self.logger.info(f"评测自定义指标: {level1_name}.{level2_name}.{level3_name}")
  272. except Exception as e:
  273. self.logger.error(f"评测自定义指标 {metric_key} 失败: {str(e)}")
  274. # 尝试添加错误信息到结果中
  275. try:
  276. level1, level2, level3 = metric_key.split('.')
  277. level2_name = self.custom_metrics_config[level1][level2].get('name', level2)
  278. level3_name = self.custom_metrics_config[level1][level2][level3].get('name', level3)
  279. if level2_name not in results[level1_name]:
  280. results[level1_name][level2_name] = {}
  281. results[level1_name][level2_name][level3_name] = {
  282. "status": "error",
  283. "message": str(e),
  284. "timestamp": datetime.now().isoformat(),
  285. }
  286. except Exception:
  287. pass
  288. def _run_module(
  289. self, module_class: Any, data: Any, module_name: str
  290. ) -> Dict[str, Any]:
  291. """执行单个评估模块(带熔断机制)"""
  292. try:
  293. instance = module_class(data)
  294. return {module_name: instance.report_statistic()}
  295. except Exception as e:
  296. self.logger.error(f"{module_name} 执行异常: {str(e)}", stack_info=True)
  297. return {module_name: {"error": str(e)}}
  298. class EvaluationPipeline:
  299. """评估流水线控制器"""
  300. def __init__(self, configPath: str, logPath: str, dataPath: str, resultPath: str, customMetricsPath: Optional[str] = None, customConfigPath: Optional[str] = None):
  301. self.configPath = Path(configPath)
  302. self.custom_config_path = Path(customConfigPath) if customConfigPath else None
  303. self.data_path = Path(dataPath)
  304. self.report_path = Path(resultPath)
  305. self.custom_metrics_path = Path(customMetricsPath) if customMetricsPath else None
  306. # 创建评估引擎实例,传入所有必要参数
  307. self.engine = EvaluationCore(
  308. logPath,
  309. configPath=str(self.configPath),
  310. customConfigPath=str(self.custom_config_path) if self.custom_config_path else None,
  311. customMetricsPath=str(self.custom_metrics_path) if self.custom_metrics_path else None
  312. )
  313. self.data_processor = self._load_data_processor()
  314. def _load_data_processor(self) -> Any:
  315. """动态加载数据预处理模块"""
  316. try:
  317. from modules.lib import data_process
  318. return data_process.DataPreprocessing(self.data_path, self.configPath)
  319. except ImportError as e:
  320. raise RuntimeError(f"数据处理器加载失败: {str(e)}") from e
  321. def execute_pipeline(self) -> Dict[str, Any]:
  322. """端到端执行评估流程"""
  323. self._validate_case()
  324. try:
  325. metric_results = self.engine.parallel_evaluate(self.data_processor)
  326. report = self._generate_report(
  327. self.data_processor.case_name, metric_results
  328. )
  329. return report
  330. except Exception as e:
  331. self.engine.logger.critical(f"流程执行失败: {str(e)}", exc_info=True)
  332. return {"error": str(e)}
  333. def _validate_case(self) -> None:
  334. """用例路径验证"""
  335. case_path = self.data_path
  336. if not case_path.exists():
  337. raise FileNotFoundError(f"用例路径不存在: {case_path}")
  338. if not case_path.is_dir():
  339. raise NotADirectoryError(f"无效的用例目录: {case_path}")
  340. def _generate_report(self, case_name: str, results: Dict) -> Dict:
  341. """生成评估报告(模板方法模式)"""
  342. from modules.lib.common import dict2json
  343. report_path = self.report_path
  344. report_path.mkdir(parents=True, exist_ok=True, mode=0o777)
  345. report_file = report_path / f"{case_name}_report.json"
  346. dict2json(results, report_file)
  347. self.engine.logger.info(f"评估报告已生成: {report_file}")
  348. return results
  349. def main():
  350. """命令行入口(工厂模式)"""
  351. parser = argparse.ArgumentParser(
  352. description="自动驾驶评估系统 V3.0 - 支持动态指标选择和自定义指标",
  353. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  354. )
  355. # 带帮助说明的参数定义,增加默认值
  356. parser.add_argument(
  357. "--logPath",
  358. type=str,
  359. default="d:/Kevin/zhaoyuan/zhaoyuan_new/logs/test.log",
  360. help="日志文件存储路径",
  361. )
  362. parser.add_argument(
  363. "--dataPath",
  364. type=str,
  365. default="d:/Kevin/zhaoyuan/zhaoyuan_new/data/zhaoyuan1",
  366. help="预处理后的输入数据目录",
  367. )
  368. parser.add_argument(
  369. "--configPath",
  370. type=str,
  371. default="d:/Kevin/zhaoyuan/zhaoyuan_new/config/metrics_config.yaml",
  372. help="评估指标配置文件路径",
  373. )
  374. parser.add_argument(
  375. "--reportPath",
  376. type=str,
  377. default="d:/Kevin/zhaoyuan/zhaoyuan_new/reports",
  378. help="评估报告输出目录",
  379. )
  380. # 新增自定义指标路径参数(可选)
  381. parser.add_argument(
  382. "--customMetricsPath",
  383. type=str,
  384. default="d:/Kevin/zhaoyuan/zhaoyuan_new/custom_metrics",
  385. help="自定义指标脚本目录(可选)",
  386. )
  387. # 新增自定义指标路径参数(可选)
  388. parser.add_argument(
  389. "--customConfigPath",
  390. type=str,
  391. default="d:/Kevin/zhaoyuan/zhaoyuan_new/config/custom_metrics_config.yaml",
  392. help="自定义指标脚本目录(可选)",
  393. )
  394. args = parser.parse_args()
  395. try:
  396. pipeline = EvaluationPipeline(
  397. args.configPath, args.logPath, args.dataPath, args.reportPath, args.customMetricsPath, args.customConfigPath
  398. )
  399. start_time = time.perf_counter()
  400. result = pipeline.execute_pipeline()
  401. if "error" in result:
  402. sys.exit(1)
  403. print(f"评估完成,耗时: {time.perf_counter()-start_time:.2f}s")
  404. print(f"报告路径: {pipeline.report_path}")
  405. except KeyboardInterrupt:
  406. print("\n用户中断操作")
  407. sys.exit(130)
  408. except Exception as e:
  409. print(f"程序执行异常: {str(e)}")
  410. sys.exit(1)
  411. if __name__ == "__main__":
  412. warnings.filterwarnings("ignore")
  413. main()