evaluator_test.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. # evaluation_engine.py
  2. import sys
  3. import warnings
  4. import time
  5. from pathlib import Path
  6. import argparse
  7. from concurrent.futures import ThreadPoolExecutor
  8. from functools import lru_cache
  9. from typing import Dict, Any
  10. from datetime import datetime
  # NOTE(review): a comment here previously said that dynamically loaded modules
  # are force-imported, but no such imports exist in this file; confirm whether
  # frozen builds need explicit imports of the modules that are loaded by name.
  #
  # Resolve the project root. A frozen build exposes its unpacked resource
  # directory as ``sys._MEIPASS`` (PyInstaller convention); in a development
  # checkout the root is two directories above this file.
  _ROOT_PATH = (
      Path(sys._MEIPASS)
      if hasattr(sys, "_MEIPASS")
      else Path(__file__).resolve().parent.parent
  )
  # Prepend the root to the import path (slice assignment == insert at index 0),
  # presumably so the project's ``modules`` package resolves from it.
  sys.path[:0] = [str(_ROOT_PATH)]
  print(f"当前根目录:{_ROOT_PATH}")
  print(f'当前系统路径:{sys.path}')
  class EvaluationCore:
      """Core evaluation engine (process-wide singleton).

      The first successful construction initialises logging and loads the
      metric module classes. Later constructions return that same instance and
      ignore their ``logPath`` argument.
      """

      # The shared instance; assigned only after initialisation has succeeded.
      _instance = None

      def __new__(cls, logPath: str):
          if cls._instance is None:
              instance = super().__new__(cls)
              # Initialise BEFORE publishing the instance. Otherwise a failed
              # init (e.g. a metric module that cannot be imported) would leave
              # a half-built singleton that every later caller receives.
              instance._init(logPath)
              cls._instance = instance
          return cls._instance

      def _init(self, logPath: "str | None" = None) -> None:
          """Initialise engine components: logging first, then metric modules."""
          self.log_path = logPath
          self._init_log_system()
          self._init_metrics()

      def _init_log_system(self) -> None:
          """Create ``self.logger`` through the project's ``LogManager``.

          Exits the process with status 1 if the log destination raises an
          ``OSError`` (this includes ``PermissionError``; ``IOError`` is an alias).
          """
          try:
              from modules.lib.log_manager import LogManager

              log_manager = LogManager(self.log_path)
              self.logger = log_manager.get_logger()
          except OSError as e:
              sys.stderr.write(f"日志系统初始化失败: {str(e)}\n")
              sys.exit(1)

      def _init_metrics(self) -> None:
          """Load every metric module class, keyed by metric name (strategy pattern)."""
          self.metric_modules = {
              "safety": self._load_module("modules.metric.safety", "Safe"),
              "comfort": self._load_module("modules.metric.comfort", "Comfort"),
              "traffic": self._load_module("modules.metric.traffic", "ViolationManager"),
              "efficient": self._load_module("modules.metric.efficient", "Efficient"),
              "function": self._load_module("modules.metric.function", "FunctionManager"),
          }

      def _load_module(self, module_path: str, class_name: str) -> Any:
          """Import ``module_path`` and return its attribute ``class_name``.

          ``importlib.import_module`` already caches modules in ``sys.modules``,
          so no extra memoisation is needed. (An ``lru_cache`` on this method
          would also key on, and keep alive, ``self``.)

          Raises:
              ImportError, AttributeError: logged, then re-raised unchanged.
          """
          import importlib

          try:
              module = importlib.import_module(module_path)
              return getattr(module, class_name)
          except (ImportError, AttributeError) as e:
              self.logger.error(f"模块加载失败: {module_path}.{class_name} - {str(e)}")
              raise

      def parallel_evaluate(self, data: Any) -> Dict[str, Any]:
          """Run every metric module concurrently on ``data`` and merge the results.

          Each module's ``report_statistic()`` return value (expected to be a
          mapping) is merged into the top level of the returned dict. A module
          that fails (in its constructor, in ``report_statistic()``, or while
          merging) is instead recorded under its own name as
          ``{"status": "error", "message": ..., "timestamp": ...}``. One failure
          therefore cannot inject or overwrite a shared top-level ``"error"`` key.

          Returns an empty dict when no metric modules are registered.
          """
          results: Dict[str, Any] = {}
          if not self.metric_modules:
              # ThreadPoolExecutor rejects max_workers=0.
              return results
          # One worker per module so every module starts immediately.
          with ThreadPoolExecutor(max_workers=len(self.metric_modules)) as executor:
              futures = {
                  module_name: executor.submit(self._execute_module, module, data)
                  for module_name, module in self.metric_modules.items()
              }
              # Collect in registration order so the merge order is deterministic.
              for module_name, future in futures.items():
                  try:
                      results.update(future.result())
                  except Exception as e:
                      self.logger.error(
                          f"{module_name} 评估失败: {str(e)}",
                          exc_info=True,
                          extra={"stack": True},  # request the full stack
                      )
                      # Store the failure in structured form under the module name.
                      results[module_name] = {
                          "status": "error",
                          "message": str(e),
                          "timestamp": datetime.now().isoformat(),
                      }
          return results

      @staticmethod
      def _execute_module(module_class: Any, data: Any) -> Any:
          """Instantiate ``module_class`` with ``data`` and return its report.

          Exceptions propagate so that ``parallel_evaluate`` can record them per module.
          """
          return module_class(data).report_statistic()

      def _run_module(
          self, module_class: Any, data: Any, module_name: str
      ) -> Dict[str, Any]:
          """Run one module without ever raising (circuit-breaker form).

          Returns ``{module_name: report}`` on success, or
          ``{module_name: {"error": message}}`` on failure. Kept for callers
          that want the non-raising form.
          """
          try:
              return {module_name: self._execute_module(module_class, data)}
          except Exception as e:
              # exc_info attaches the exception's traceback; stack_info would
              # only record the call stack at this logging call.
              self.logger.error(f"{module_name} 执行异常: {str(e)}", exc_info=True)
              return {module_name: {"error": str(e)}}
  class EvaluationPipeline:
      """Evaluation pipeline controller.

      Steps: validate the case directory, run the metric engine, score the
      results, and write a JSON report.
      """

      def __init__(self, configPath: str, logPath: str, dataPath: str, resultPath: str):
          self.engine = EvaluationCore(logPath)
          self.configPath = configPath
          self.data_path = dataPath
          self.report_path = resultPath
          self.data_processor = self._load_data_processor()

      def _load_data_processor(self) -> Any:
          """Build the data preprocessor for ``data_path`` and ``configPath``.

          Raises:
              RuntimeError: if an ``ImportError`` occurs while importing
                  ``modules.lib.data_process`` or constructing the preprocessor.
          """
          try:
              from modules.lib import data_process

              return data_process.DataPreprocessing(self.data_path, self.configPath)
          except ImportError as e:
              raise RuntimeError(f"数据处理器加载失败: {str(e)}") from e

      def execute_pipeline(self) -> Dict[str, Any]:
          """Run the evaluation end to end and return the scored result.

          Raises:
              FileNotFoundError, NotADirectoryError: from case validation.

          Any later failure is logged as critical and returned as
          ``{"error": message}``.
          """
          self._validate_case()
          try:
              metric_results = self.engine.parallel_evaluate(self.data_processor)
              from modules.lib.score import get_overall_result

              all_result = get_overall_result(metric_results, self.configPath)
              report = self._generate_report(
                  self.data_processor.case_name, all_result
              )
              return report
          except Exception as e:
              self.engine.logger.critical(f"流程执行失败: {str(e)}", exc_info=True)
              return {"error": str(e)}

      def _validate_case(self) -> None:
          """Check that ``data_path`` exists and is a directory.

          Accepts ``str`` or ``Path``: the constructor is annotated ``str``, so
          the value is coerced before any ``Path`` method is called.

          Raises:
              FileNotFoundError: the path does not exist.
              NotADirectoryError: the path exists but is not a directory.
          """
          case_path = Path(self.data_path)
          if not case_path.exists():
              raise FileNotFoundError(f"用例路径不存在: {case_path}")
          if not case_path.is_dir():
              raise NotADirectoryError(f"无效的用例目录: {case_path}")

      def _generate_report(self, case_name: str, results: Dict) -> Dict:
          """Write ``results`` to ``<report_path>/<case_name>_report.json`` and return them.

          Uses a template-method style. The report directory is created,
          including any parents, when missing.
          """
          from modules.lib.common import dict2json

          # Coerce so a plain ``str`` report path also works.
          report_dir = Path(self.report_path)
          report_dir.mkdir(parents=True, exist_ok=True, mode=0o777)
          dict2json(results, report_dir / f"{case_name}_report.json")
          return results
  def main():
      """Command-line entry point: parse options, run the pipeline, report timing.

      Exits with status 1 when the pipeline result contains an ``"error"`` key,
      and with status 130 on Ctrl-C.
      """
      parser = argparse.ArgumentParser(
          description="自动驾驶评估系统 V2.0",
          formatter_class=argparse.ArgumentDefaultsHelpFormatter,
      )
      # Each option has a default (developer-machine paths), so none is truly
      # required. argparse passes string defaults through ``type`` (Path here).
      # The order of this table is the order shown in --help.
      cli_options = (
          (
              "--logPath",
              "/home/kevin/kevin/zhaoyuan/zhaoyuan/log/runtime.log",
              "日志文件存储路径",
          ),
          (
              "--dataPath",
              "/home/kevin/kevin/zhaoyuan/sqlite3_demo/docker_build/zhaoyuan_0320/V2V_CSAE53-2020_ForwardCollision_LST_02-03",
              "预处理后的输入数据目录",
          ),
          (
              "--configPath",
              "/home/kevin/kevin/zhaoyuan/sqlite3_demo/docker_build/zhaoyuan_0320/config/metric_config.yaml",
              "评估指标配置文件路径",
          ),
          (
              "--reportPath",
              "/home/kevin/kevin/zhaoyuan/sqlite3_demo/docker_build/zhaoyuan_0320/result",
              "评估报告输出目录",
          ),
      )
      for flag, default_value, help_text in cli_options:
          parser.add_argument(flag, type=Path, default=default_value, help=help_text)
      args = parser.parse_args()

      try:
          pipeline = EvaluationPipeline(
              args.configPath, args.logPath, args.dataPath, args.reportPath
          )
          # Timing starts after construction, so it covers execute_pipeline only.
          t0 = time.perf_counter()
          outcome = pipeline.execute_pipeline()
          if "error" in outcome:
              sys.exit(1)
          elapsed = time.perf_counter() - t0
          print(f"评估完成,耗时: {elapsed:.2f}s")
          print(f"报告路径: {pipeline.report_path}")
      except KeyboardInterrupt:
          print("\n用户中断操作")
          sys.exit(130)
  if __name__ == "__main__":
      # Suppress every warning, but only when executed as a script; importing
      # this module leaves the warning filters untouched.
      warnings.simplefilter("ignore")
      main()