#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
自定义指标统一模板

本模板提供两种实现自定义指标的方式:
1. 基于类继承的方式(推荐):继承BaseMetric基类,实现calculate方法
2. 基于函数式的方式:实现evaluate函数

用户可以根据自己的需求选择合适的实现方式。
"""

from typing import Dict, Any, Union, Optional
import numpy as np
import logging
from modules.lib.score import Score

# 导入基类
# 注意:实际使用时,请确保路径正确
from modules.lib.metric_registry import BaseMetric

# 指定指标类别(必须)
# 可选值: safety, comfort, traffic, efficient, function, custom
METRIC_CATEGORY = "custom"

#############################################################
# 方式一:基于类继承的实现方式(推荐)
# 优点:结构清晰,易于扩展,支持复杂指标计算
# 适用场景:需要复杂状态管理、多步骤计算的指标
#############################################################

class CustomMetricExample(BaseMetric):
    """自定义指标示例 - 计算平均速度"""
    
    def __init__(self, data: Any):
        """
        初始化指标
        
        Args:
            data: 输入数据,通常包含场景、轨迹等信息
        """
        super().__init__(data)
        # 在这里添加自定义初始化代码
        
    def calculate(self) -> Dict[str, Any]:
        """
        计算指标
        
        Returns:
            计算结果字典,包含以下字段:
            - value: 指标值
            - score: 评分(0-100)
            - details: 详细信息(可选)
        """
        # 在这里实现指标计算逻辑
        result = {
            "value": 0.0,  # 指标值
            "score": 100,  # 评分
            "details": {}  # 详细信息
        }
        
        # 示例:计算平均速度
        try:
            if hasattr(self.data, 'ego_data') and hasattr(self.data.ego_data, 'v'):
                # 获取速度数据
                speeds = self.data.ego_data['v'].values
                
                # 计算平均速度
                avg_speed = np.mean(speeds)
                result['value'] = float(avg_speed)
                
                # 简单评分逻辑
                if avg_speed < 10:
                    result['score'] = 60  # 速度过低
                elif avg_speed > 50:
                    result['score'] = 70  # 速度过高
                else:
                    result['score'] = 100  # 速度适中
                
                # 添加详细信息
                result['details'] = {
                    "max_speed": float(np.max(speeds)),
                    "min_speed": float(np.min(speeds)),
                    "std_speed": float(np.std(speeds))
                }
        except Exception as e:
            # 出错时记录错误信息
            logging.error(f"计算指标失败: {str(e)}")
            result['value'] = 0.0
            result['score'] = 0
            result['details'] = {"error": str(e)}
                
        return result
    
    def report_statistic(self) -> Dict[str, Any]:
        """
        报告统计结果
        可以在这里自定义结果格式
        
        Returns:
            统计结果字典
        """
        result = self.calculate()
        
        # 可以在这里添加额外的处理逻辑
        # 例如:添加时间戳、格式化结果等
        
        return result


#############################################################
# 方式二:基于函数式的实现方式
# 优点:简单直接,易于理解
# 适用场景:简单的指标计算,无需复杂状态管理
#############################################################

def evaluate(data) -> Dict[str, Any]:
    """
    评测自定义指标
    
    Args:
        data: 评测数据,包含场景、轨迹等信息
        
    Returns:
        评测结果,包含指标值、分数、详情等
    """

    try:
        # 计算指标值
        result = calculate_metric(data)
        
        # 可以使用Score类评估结果
        # evaluator = Score(config)   
        # result = evaluator.evaluate(result)
        return result
        
    except Exception as e:
        logging.error(f"评测指标失败: {str(e)}")
        # 发生异常时返回错误信息
        return {
            "value": 0.0,
            "score": 0,
            "details": {
                "error": str(e)
            }
        }
    

def calculate_metric(data) -> Dict[str, Any]:
    """
    计算指标值
    
    Args:
        data: 输入数据
        
    Returns:
        指标计算结果
    """
    # 这里是计算指标的具体逻辑
    # 以下是一个简化的示例
    
    if data is None:
        raise ValueError("输入数据不能为空")
    
    try:
        # 示例:计算TTC (Time To Collision)
        if hasattr(data, 'ego_data'):
            # 这里应该实现实际的指标计算逻辑
            # 临时使用固定值代替实际计算
            metric_value = 1.5
            
            # 返回结果
            return {
                "value": metric_value,
                "score": 85,  # 示例评分
                "details": {
                    "min_value": metric_value,
                    "max_value": metric_value * 2
                }
            }
        else:
            raise ValueError("数据格式不正确,缺少ego_data")
    except Exception as e:
        logging.error(f"计算指标失败: {str(e)}")
        raise


#############################################################
# 使用说明
#############################################################
"""
如何选择实现方式:

1. 基于类继承的方式(推荐):
   - 适用于复杂指标计算
   - 需要维护状态或多步骤计算
   - 需要与系统深度集成

2. 基于函数式的方式:
   - 适用于简单指标计算
   - 逻辑简单,无需复杂状态管理
   - 快速实现原型

文件命名规范:
- 文件名应以 metric_ 开头
- 后跟指标类别、二级指标名、三级指标名
- 例如:metric_safety_safeTime_CustomTTC.py

必要条件:
1. 类实现方式:必须继承 BaseMetric 基类并实现 calculate() 方法
2. 函数实现方式:必须实现 evaluate() 函数
3. 必须在文件中定义 METRIC_CATEGORY 变量,指定指标类别

返回结果格式:
{
    "value": 0.0,    # 指标值
    "score": 100,   # 评分(0-100)
    "details": {}   # 详细信息(可选)
}
"""

# 测试代码(实际使用时可删除)
if __name__ == "__main__":
    # 这里可以添加测试代码
    pass