#!/usr/bin/env python # -*- coding: utf-8 -*- """ 自定义指标统一模板 本模板提供两种实现自定义指标的方式: 1. 基于类继承的方式(推荐):继承BaseMetric基类,实现calculate方法 2. 基于函数式的方式:实现evaluate函数 用户可以根据自己的需求选择合适的实现方式。 """ from typing import Dict, Any, Union, Optional import numpy as np import logging from modules.lib.score import Score # 导入基类 # 注意:实际使用时,请确保路径正确 from modules.lib.metric_registry import BaseMetric # 指定指标类别(必须) # 可选值: safety, comfort, traffic, efficient, function, custom METRIC_CATEGORY = "custom" ############################################################# # 方式一:基于类继承的实现方式(推荐) # 优点:结构清晰,易于扩展,支持复杂指标计算 # 适用场景:需要复杂状态管理、多步骤计算的指标 ############################################################# class CustomMetricExample(BaseMetric): """自定义指标示例 - 计算平均速度""" def __init__(self, data: Any): """ 初始化指标 Args: data: 输入数据,通常包含场景、轨迹等信息 """ super().__init__(data) # 在这里添加自定义初始化代码 def calculate(self) -> Dict[str, Any]: """ 计算指标 Returns: 计算结果字典,包含以下字段: - value: 指标值 - score: 评分(0-100) - details: 详细信息(可选) """ # 在这里实现指标计算逻辑 result = { "value": 0.0, # 指标值 "score": 100, # 评分 "details": {} # 详细信息 } # 示例:计算平均速度 try: if hasattr(self.data, 'ego_data') and hasattr(self.data.ego_data, 'v'): # 获取速度数据 speeds = self.data.ego_data['v'].values # 计算平均速度 avg_speed = np.mean(speeds) result['value'] = float(avg_speed) # 简单评分逻辑 if avg_speed < 10: result['score'] = 60 # 速度过低 elif avg_speed > 50: result['score'] = 70 # 速度过高 else: result['score'] = 100 # 速度适中 # 添加详细信息 result['details'] = { "max_speed": float(np.max(speeds)), "min_speed": float(np.min(speeds)), "std_speed": float(np.std(speeds)) } except Exception as e: # 出错时记录错误信息 logging.error(f"计算指标失败: {str(e)}") result['value'] = 0.0 result['score'] = 0 result['details'] = {"error": str(e)} return result def report_statistic(self) -> Dict[str, Any]: """ 报告统计结果 可以在这里自定义结果格式 Returns: 统计结果字典 """ result = self.calculate() # 可以在这里添加额外的处理逻辑 # 例如:添加时间戳、格式化结果等 return result ############################################################# # 方式二:基于函数式的实现方式 # 优点:简单直接,易于理解 # 适用场景:简单的指标计算,无需复杂状态管理 ############################################################# def evaluate(data) -> Dict[str, Any]: """ 评测自定义指标 Args: data: 评测数据,包含场景、轨迹等信息 Returns: 评测结果,包含指标值、分数、详情等 """ try: # 计算指标值 result = calculate_metric(data) # 可以使用Score类评估结果 # evaluator = Score(config) # result = evaluator.evaluate(result) return result except Exception as e: logging.error(f"评测指标失败: {str(e)}") # 发生异常时返回错误信息 return { "value": 0.0, "score": 0, "details": { "error": str(e) } } def calculate_metric(data) -> Dict[str, Any]: """ 计算指标值 Args: data: 输入数据 Returns: 指标计算结果 """ # 这里是计算指标的具体逻辑 # 以下是一个简化的示例 if data is None: raise ValueError("输入数据不能为空") try: # 示例:计算TTC (Time To Collision) if hasattr(data, 'ego_data'): # 这里应该实现实际的指标计算逻辑 # 临时使用固定值代替实际计算 metric_value = 1.5 # 返回结果 return { "value": metric_value, "score": 85, # 示例评分 "details": { "min_value": metric_value, "max_value": metric_value * 2 } } else: raise ValueError("数据格式不正确,缺少ego_data") except Exception as e: logging.error(f"计算指标失败: {str(e)}") raise ############################################################# # 使用说明 ############################################################# """ 如何选择实现方式: 1. 基于类继承的方式(推荐): - 适用于复杂指标计算 - 需要维护状态或多步骤计算 - 需要与系统深度集成 2. 基于函数式的方式: - 适用于简单指标计算 - 逻辑简单,无需复杂状态管理 - 快速实现原型 文件命名规范: - 文件名应以 metric_ 开头 - 后跟指标类别、二级指标名、三级指标名 - 例如:metric_safety_safeTime_CustomTTC.py 必要条件: 1. 类实现方式:必须继承 BaseMetric 基类并实现 calculate() 方法 2. 函数实现方式:必须实现 evaluate() 函数 3. 必须在文件中定义 METRIC_CATEGORY 变量,指定指标类别 返回结果格式: { "value": 0.0, # 指标值 "score": 100, # 评分(0-100) "details": {} # 详细信息(可选) } """ # 测试代码(实际使用时可删除) if __name__ == "__main__": # 这里可以添加测试代码 pass