123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226 |
- """
- 自定义指标统一模板
- 本模板提供两种实现自定义指标的方式:
- 1. 基于类继承的方式(推荐):继承BaseMetric基类,实现calculate方法
- 2. 基于函数式的方式:实现evaluate函数
- 用户可以根据自己的需求选择合适的实现方式。
- """
- from typing import Dict, Any, Union, Optional
- import numpy as np
- import logging
- from modules.lib.score import Score
- from modules.lib.metric_registry import BaseMetric
- METRIC_CATEGORY = "custom"
- class CustomMetricExample(BaseMetric):
- """自定义指标示例 - 计算平均速度"""
-
- def __init__(self, data: Any):
- """
- 初始化指标
-
- Args:
- data: 输入数据,通常包含场景、轨迹等信息
- """
- super().__init__(data)
-
-
- def calculate(self) -> Dict[str, Any]:
- """
- 计算指标
-
- Returns:
- 计算结果字典,包含以下字段:
- - value: 指标值
- - score: 评分(0-100)
- - details: 详细信息(可选)
- """
-
- result = {
- "value": 0.0,
- "score": 100,
- "details": {}
- }
-
-
- try:
- if hasattr(self.data, 'ego_data') and hasattr(self.data.ego_data, 'v'):
-
- speeds = self.data.ego_data['v'].values
-
-
- avg_speed = np.mean(speeds)
- result['value'] = float(avg_speed)
-
-
- if avg_speed < 10:
- result['score'] = 60
- elif avg_speed > 50:
- result['score'] = 70
- else:
- result['score'] = 100
-
-
- result['details'] = {
- "max_speed": float(np.max(speeds)),
- "min_speed": float(np.min(speeds)),
- "std_speed": float(np.std(speeds))
- }
- except Exception as e:
-
- logging.error(f"计算指标失败: {str(e)}")
- result['value'] = 0.0
- result['score'] = 0
- result['details'] = {"error": str(e)}
-
- return result
-
- def report_statistic(self) -> Dict[str, Any]:
- """
- 报告统计结果
- 可以在这里自定义结果格式
-
- Returns:
- 统计结果字典
- """
- result = self.calculate()
-
-
-
-
- return result
- def evaluate(data) -> Dict[str, Any]:
- """
- 评测自定义指标
-
- Args:
- data: 评测数据,包含场景、轨迹等信息
-
- Returns:
- 评测结果,包含指标值、分数、详情等
- """
- try:
-
- result = calculate_metric(data)
-
-
-
-
- return result
-
- except Exception as e:
- logging.error(f"评测指标失败: {str(e)}")
-
- return {
- "value": 0.0,
- "score": 0,
- "details": {
- "error": str(e)
- }
- }
-
- def calculate_metric(data) -> Dict[str, Any]:
- """
- 计算指标值
-
- Args:
- data: 输入数据
-
- Returns:
- 指标计算结果
- """
-
-
-
- if data is None:
- raise ValueError("输入数据不能为空")
-
- try:
-
- if hasattr(data, 'ego_data'):
-
-
- metric_value = 1.5
-
-
- return {
- "value": metric_value,
- "score": 85,
- "details": {
- "min_value": metric_value,
- "max_value": metric_value * 2
- }
- }
- else:
- raise ValueError("数据格式不正确,缺少ego_data")
- except Exception as e:
- logging.error(f"计算指标失败: {str(e)}")
- raise
- """
- 如何选择实现方式:
- 1. 基于类继承的方式(推荐):
- - 适用于复杂指标计算
- - 需要维护状态或多步骤计算
- - 需要与系统深度集成
- 2. 基于函数式的方式:
- - 适用于简单指标计算
- - 逻辑简单,无需复杂状态管理
- - 快速实现原型
- 文件命名规范:
- - 文件名应以 metric_ 开头
- - 后跟指标类别、二级指标名、三级指标名
- - 例如:metric_safety_safeTime_CustomTTC.py
- 必要条件:
- 1. 类实现方式:必须继承 BaseMetric 基类并实现 calculate() 方法
- 2. 函数实现方式:必须实现 evaluate() 函数
- 3. 必须在文件中定义 METRIC_CATEGORY 变量,指定指标类别
- 返回结果格式:
- {
- "value": 0.0, # 指标值
- "score": 100, # 评分(0-100)
- "details": {} # 详细信息(可选)
- }
- """
- if __name__ == "__main__":
-
- pass
|