custom_metric_template.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 自定义指标模板
  5. 用户可以基于此模板创建自己的指标
  6. """
  7. from typing import Dict, Any
  8. import numpy as np
  9. from modules.lib.score import Score
  10. # 导入基类
  11. # 注意:实际使用时,请确保路径正确
  12. from modules.lib.metric_registry import BaseMetric
  13. # 指定指标类别(必须)
  14. # 可选值: safety, comfort, traffic, efficient, function, custom
  15. METRIC_CATEGORY = "custom"
  16. class CustomMetricExample(BaseMetric):
  17. """自定义指标示例 - 计算平均速度"""
  18. def __init__(self, data: Any):
  19. """
  20. 初始化指标
  21. Args:
  22. data: 输入数据
  23. """
  24. super().__init__(data)
  25. # 在这里添加自定义初始化代码
  26. def calculate(self) -> Dict[str, Any]:
  27. """
  28. 计算指标
  29. Returns:
  30. 计算结果字典
  31. """
  32. # 在这里实现指标计算逻辑
  33. result = {
  34. "value": 0.0, # 指标值
  35. "score": 100, # 评分
  36. "details": {} # 详细信息
  37. }
  38. # 示例:计算平均速度
  39. try:
  40. if hasattr(self.data, 'velocities') and self.data.velocities:
  41. velocities = self.data.velocities
  42. if isinstance(velocities, dict) and 'vx' in velocities and 'vy' in velocities:
  43. # 计算合速度
  44. vx = np.array(velocities['vx'])
  45. vy = np.array(velocities['vy'])
  46. speeds = np.sqrt(vx**2 + vy**2)
  47. # 计算平均速度
  48. avg_speed = np.mean(speeds)
  49. result['value'] = float(avg_speed)
  50. # 简单评分逻辑
  51. if avg_speed < 10:
  52. result['score'] = 60 # 速度过低
  53. elif avg_speed > 50:
  54. result['score'] = 70 # 速度过高
  55. else:
  56. result['score'] = 100 # 速度适中
  57. # 添加详细信息
  58. result['details'] = {
  59. "max_speed": float(np.max(speeds)),
  60. "min_speed": float(np.min(speeds)),
  61. "std_speed": float(np.std(speeds))
  62. }
  63. except Exception as e:
  64. # 出错时记录错误信息
  65. result['value'] = 0.0
  66. result['score'] = 0
  67. result['details'] = {"error": str(e)}
  68. return result
  69. def report_statistic(self) -> Dict[str, Any]:
  70. """
  71. 报告统计结果
  72. 可以在这里自定义结果格式
  73. """
  74. result = self.calculate()
  75. # 可以在这里添加额外的处理逻辑
  76. # 例如:添加时间戳、格式化结果等
  77. return result
  78. # 可以在同一文件中定义多个指标类
  79. class AnotherCustomMetric(BaseMetric):
  80. """另一个自定义指标示例 - 计算加速度变化率"""
  81. def __init__(self, data: Any):
  82. super().__init__(data)
  83. def calculate(self) -> Dict[str, Any]:
  84. # 实现您的计算逻辑
  85. return {"value": 0.0, "score": 100, "details": {}}