safety.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 安全指标计算模块
  5. """
  6. import numpy as np
  7. import pandas as pd
  8. from typing import Dict, Any, List, Optional
  9. from modules.lib.score import Score
  10. from modules.lib.log_manager import LogManager
  11. # 安全指标计算函数
  12. def calculate_ttc(data_processed) -> dict:
  13. """计算TTC (Time To Collision)"""
  14. # 实现TTC计算逻辑
  15. # ...
  16. return {"TTC": 3.5} # 示例返回值
  17. def calculate_mttc(data_processed) -> dict:
  18. """计算MTTC (Modified Time To Collision)"""
  19. # 实现MTTC计算逻辑
  20. # ...
  21. return {"MTTC": 4.2} # 示例返回值
  22. def calculate_thw(data_processed) -> dict:
  23. """计算THW (Time Headway)"""
  24. # 实现THW计算逻辑
  25. # ...
  26. return {"THW": 2.1} # 示例返回值
  27. def calculate_collision_risk(data_processed) -> dict:
  28. """计算碰撞风险"""
  29. # 实现碰撞风险计算逻辑
  30. # ...
  31. return {"collisionRisk": 0.15} # 示例返回值
  32. class SafetyRegistry:
  33. """安全指标注册器"""
  34. def __init__(self, data_processed):
  35. self.logger = LogManager().get_logger()
  36. self.data = data_processed
  37. self.safety_config = data_processed.safety_config["safety"]
  38. self.metrics = self._extract_metrics(self.safety_config)
  39. self._registry = self._build_registry()
  40. def _extract_metrics(self, config_node: dict) -> list:
  41. """从配置中提取指标名称"""
  42. metrics = []
  43. def _recurse(node):
  44. if isinstance(node, dict):
  45. if 'name' in node and not any(isinstance(v, dict) for v in node.values()):
  46. metrics.append(node['name'])
  47. for v in node.values():
  48. _recurse(v)
  49. _recurse(config_node)
  50. self.logger.info(f'评比的安全指标列表:{metrics}')
  51. return metrics
  52. def _build_registry(self) -> dict:
  53. """构建指标函数注册表"""
  54. registry = {}
  55. for metric_name in self.metrics:
  56. func_name = f"calculate_{metric_name.lower()}"
  57. if func_name in globals():
  58. registry[metric_name] = globals()[func_name]
  59. else:
  60. self.logger.warning(f"未实现安全指标函数: {func_name}")
  61. return registry
  62. def batch_execute(self) -> dict:
  63. """批量执行指标计算"""
  64. results = {}
  65. for name, func in self._registry.items():
  66. try:
  67. result = func(self.data)
  68. results.update(result)
  69. except Exception as e:
  70. self.logger.error(f"{name} 执行失败: {str(e)}", exc_info=True)
  71. results[name] = None
  72. self.logger.info(f'安全指标计算结果:{results}')
  73. return results
  74. class SafeManager:
  75. """安全指标管理类"""
  76. def __init__(self, data_processed):
  77. self.data = data_processed
  78. self.registry = SafetyRegistry(self.data)
  79. def report_statistic(self):
  80. """计算并报告安全指标结果"""
  81. safety_result = self.registry.batch_execute()
  82. # evaluator = Score(self.data.safety_config)
  83. # result = evaluator.evaluate(safety_result)
  84. # return result
  85. return safety_result