#!/usr/bin/env python
# -*- coding: utf-8 -*-
##################################################################
#
# Copyright (c) 2025 CICV, Inc. All Rights Reserved
#
##################################################################
"""
@Authors:           zhanghaiwen(zhanghaiwen@china-icv.cn)
@Data:              2025/01/5
@Last Modified:     2025/01/5
@Summary:           Function Metrics Calculation
"""

import sys
from pathlib import Path

# 添加项目根目录到系统路径
root_path = Path(__file__).resolve().parent.parent
sys.path.append(str(root_path))

from modules.lib.score import Score
from modules.lib.log_manager import LogManager
import numpy as np
from typing import Dict, Tuple, Optional, Callable, Any
import pandas as pd
import yaml

# ----------------------
# 基础工具函数 (Pure functions)
# ----------------------
# Scenario name (function_config["function"]["scenario"]["name"]) -> the
# `ifwarning` code that marks a correct warning for that scenario.
# NOTE(review): keys are looked up with the configured scenario name, so their
# spelling (e.g. "Coorperative...") presumably must match the config — verify
# before renaming any key.
scenario_sign_dict = {
    "LeftTurnAssist": 206,
    "HazardousLocationW": 207,
    "RedLightViolationW": 208,
    "CoorperativeIntersectionPassing": 225,
    "GreenLightOptimalSpeedAdvisory": 234,
    "ForwardCollision": 212,
}


def calculate_distance_PGVIL(ego_pos: np.ndarray, obj_pos: np.ndarray) -> np.ndarray:
    """Vectorised Euclidean distance between ego and object positions.

    ``ego_pos - obj_pos`` follows NumPy broadcasting, so a single ego row
    (shape ``(1, k)``) can be compared against ``N`` object rows; the norm is
    taken along axis 1, yielding one distance per row.
    """
    offsets = ego_pos - obj_pos
    return np.linalg.norm(offsets, axis=1)


def calculate_relative_speed_PGVIL(
        ego_speed: np.ndarray, obj_speed: np.ndarray
) -> np.ndarray:
    """Vectorised magnitude of the velocity difference between ego and objects.

    Velocity components are subtracted row-wise (with NumPy broadcasting) and
    the Euclidean norm along axis 1 gives one relative speed per row.
    """
    velocity_delta = ego_speed - obj_speed
    return np.linalg.norm(velocity_delta, axis=1)


def calculate_distance(ego_df: pd.DataFrame, correctwarning: int) -> pd.Series:
    """Return ``relative_dist`` for every row flagged with the expected warning code.

    :param ego_df: ego data with ``ifwarning`` and ``relative_dist`` columns.
    :param correctwarning: the ``ifwarning`` value that marks a correct warning.
    :return: Series of relative distances, original index preserved (may be
             empty). Callers rely on the Series API (``.empty``, ``.iloc``),
             so this is a Series, not an ndarray.
    """
    warning_mask = (ego_df['ifwarning'] == correctwarning) & (ego_df['ifwarning'].notna())
    return ego_df.loc[warning_mask, 'relative_dist']


def calculate_relative_speed(ego_df: pd.DataFrame, correctwarning: int) -> pd.Series:
    """Return ``composite_v`` (used as the relative speed) for correct-warning rows.

    :param ego_df: ego data with ``ifwarning`` and ``composite_v`` columns.
    :param correctwarning: the ``ifwarning`` value that marks a correct warning.
    :return: Series of speeds with the original index preserved (may be
             empty); it aligns index-for-index with ``calculate_distance``
             for the same arguments because both use the same row mask.
    """
    warning_mask = (ego_df['ifwarning'] == correctwarning) & (ego_df['ifwarning'].notna())
    return ego_df.loc[warning_mask, 'composite_v']


def extract_ego_obj(data: pd.DataFrame) -> Tuple[pd.Series, pd.DataFrame]:
    """Split rows into the ego vehicle (playerId == 1) and all other participants.

    :param data: rows with a ``playerId`` column (callers pass one warning frame).
    :return: ``(first row with playerId == 1, every row with playerId != 1)``.
    :raises IndexError: if no row has ``playerId == 1``.
    """
    is_ego = data["playerId"] == 1
    others = data[~is_ego]
    ego_row = data[is_ego].iloc[0]
    return ego_row, others


def get_first_warning(data_processed) -> Optional[pd.DataFrame]:
    """Return the object-data rows at the timestamp of the first correct warning.

    The scenario's expected warning code is looked up in ``scenario_sign_dict``;
    the first ego row whose ``ifwarning`` equals it supplies a ``simTime``, and
    every ``object_df`` row with that same ``simTime`` is returned. Nothing is
    cached — each call recomputes the result.

    :return: a slice of ``object_df``, or None when the scenario has no known
             warning code or the code never appears in ``ego_data``.
    """
    ego_frame = data_processed.ego_data
    objects = data_processed.object_df

    config_name = data_processed.function_config["function"]["scenario"]["name"]
    expected_code = scenario_sign_dict.get(config_name)
    if expected_code is None:
        print("无法获取正确的预警信号标志位!")
        return None

    flags = ego_frame['ifwarning']
    hit_times = ego_frame.loc[(flags == expected_code) & (flags.notna()), 'simTime']
    if hit_times.empty:
        print("没有找到预警数据!")
        return None

    return objects[objects['simTime'] == hit_times.iloc[0]]


# ----------------------
# 核心计算功能函数
# ----------------------
def latestWarningDistance_LST(data) -> dict:
    """Relative distance at the last correct-warning sample (LST data).

    :return: ``{"latestWarningDistance_LST": distance}``, or 0.0 when the
             scenario's warning code never appears in ``ego_data``.
    """
    code = scenario_sign_dict[data.function_config["function"]["scenario"]["name"]]
    distances = calculate_distance(data.ego_data, code)
    value = 0.0 if distances.empty else float(distances.iloc[-1])
    return {"latestWarningDistance_LST": value}


def earliestWarningDistance_LST(data) -> dict:
    """Relative distance at the first correct-warning sample (LST data).

    :return: ``{"earliestWarningDistance_LST": distance}``, or 0.0 when the
             scenario's warning code never appears in ``ego_data``.
    """
    code = scenario_sign_dict[data.function_config["function"]["scenario"]["name"]]
    distances = calculate_distance(data.ego_data, code)
    if not distances.empty:
        return {"earliestWarningDistance_LST": float(distances.iloc[0])}
    return {"earliestWarningDistance_LST": 0.0}


def latestWarningDistance_TTC_LST(data) -> dict:
    """Time-to-collision at the last correct-warning sample (LST data).

    TTC = ``relative_dist / composite_v`` per warning row; a zero speed gives
    ``inf``. Returns 0.0 when the warning code never appears.
    """
    metric_key = "latestWarningDistance_TTC_LST"
    code = scenario_sign_dict[data.function_config["function"]["scenario"]["name"]]
    ego_frame = data.ego_data

    distances = calculate_distance(ego_frame, code)
    if distances.empty:
        return {metric_key: 0.0}
    speeds = calculate_relative_speed(ego_frame, code)

    # errstate silences NumPy floating-point warnings from the division;
    # zero-speed rows are replaced by inf via np.where.
    with np.errstate(divide='ignore', invalid='ignore'):
        ttc_values = np.where(speeds != 0, distances / speeds, np.inf)
    return {metric_key: float(ttc_values[-1])}


def earliestWarningDistance_TTC_LST(data) -> dict:
    """Time-to-collision at the first correct-warning sample (LST data).

    TTC = ``relative_dist / composite_v`` per warning row; a zero speed gives
    ``inf``. Returns 0.0 when the warning code never appears.
    """
    metric_key = "earliestWarningDistance_TTC_LST"
    code = scenario_sign_dict[data.function_config["function"]["scenario"]["name"]]
    ego_frame = data.ego_data

    distances = calculate_distance(ego_frame, code)
    if distances.empty:
        return {metric_key: 0.0}
    speeds = calculate_relative_speed(ego_frame, code)

    # errstate silences NumPy floating-point warnings from the division;
    # zero-speed rows are replaced by inf via np.where.
    with np.errstate(divide='ignore', invalid='ignore'):
        ttc_values = np.where(speeds != 0, distances / speeds, np.inf)
    return {metric_key: float(ttc_values[0])}


def warningDelayTime_LST(data):
    """Absolute delay between the first HMI warning and the first rosbag event (LST data).

    * HMI warning time: ``simTime`` of the first ego row whose ``ifwarning``
      equals the scenario's warning code.
    * rosbag warning time: ``simTime`` of the first ego row whose
      ``event_Type`` is not null.

    :return: ``{"warningDelayTime_LST": |t_HMI - t_rosbag|}``, or 100.0 as a
             sentinel when either time is missing.
    """
    scenario_name = data.function_config["function"]["scenario"]["name"]
    correctwarning = scenario_sign_dict[scenario_name]
    ego_df = data.ego_data

    hmi_times = ego_df.loc[ego_df['ifwarning'] == correctwarning, 'simTime'].tolist()
    # notna() alone identifies rows carrying an event; comparing a column to
    # np.nan with `!=` is always True (NaN equals nothing), so it added nothing.
    rosbag_times = ego_df.loc[ego_df['event_Type'].notna(), 'simTime'].tolist()

    if not hmi_times or not rosbag_times:
        print("预警出错!")
        delay_time = 100.0
    else:
        delay_time = abs(hmi_times[0] - rosbag_times[0])
    return {"warningDelayTime_LST": delay_time}


def warningDelayTimeOf4_LST(data):
    """Delay between ego deceleration reaching -4 m/s^2 and the first correct warning.

    Using the first occurrence of each event:
      * neither braking nor warning -> 0
      * braking only               -> simTime of the first braking sample
      * warning only               -> None
      * both                       -> warning simTime minus braking simTime

    NOTE(review): the braking-only case returns an absolute simTime rather
    than a delay — confirm this is intended.
    """
    code = scenario_sign_dict[data.function_config["function"]["scenario"]["name"]]
    ego_frame = data.ego_data
    brake_times = ego_frame[ego_frame['accel'] <= -4]['simTime'].tolist()  # accel in m/s^2
    warn_times = ego_frame[ego_frame['ifwarning'] == code]['simTime'].tolist()

    if not warn_times:
        return {"warningDelayTimeOf4_LST": brake_times[0] if brake_times else 0}
    if not brake_times:
        return {"warningDelayTimeOf4_LST": None}
    return {"warningDelayTimeOf4_LST": warn_times[0] - brake_times[0]}


def rightWarningSignal_LST(data):
    """Report whether the scenario's correct warning code was raised (LST data).

    :return: ``{"rightWarningSignal_LST": 1}`` if at least one ego row carries
             the expected ``ifwarning`` code, ``-1`` if none does, and ``None``
             as the value when ``ego_data`` has no rows at all.
    """
    scenario_name = data.function_config["function"]["scenario"]["name"]
    correctwarning = scenario_sign_dict[scenario_name]
    ego_df = data.ego_data
    if ego_df['ifwarning'].empty:
        print("无法获取正确预警信号标志位!")
        # Always return a dict: the result is merged with dict.update(), which
        # fails on a bare None.
        return {"rightWarningSignal_LST": None}
    warning_rows = ego_df[(ego_df['ifwarning'] == correctwarning) & (ego_df['ifwarning'].notna())]
    return {"rightWarningSignal_LST": -1 if warning_rows.empty else 1}


def ifCrossingRedLight_LST(data):
    """Flag a red-light crossing in LST data.

    A crossing is any ego row that simultaneously has the scenario's warning
    code, ``stateMask == 1``, ``relative_dist == 0`` and non-zero ``v``.

    :return: ``{"ifCrossingRedLight_LST": 1}`` if such a row exists, else ``-1``.
             NOTE(review): ifCrossingRedLight_PGVIL uses the opposite sign
             (-1 = crossed) — confirm which convention the scoring expects.
    """
    code = scenario_sign_dict[data.function_config["function"]["scenario"]["name"]]
    ego_frame = data.ego_data
    crossing_mask = (
        (ego_frame['ifwarning'] == code)
        & (ego_frame['stateMask'] == 1)
        & (ego_frame['relative_dist'] == 0)
        & (ego_frame['v'] != 0)
    )
    crossed = not ego_frame.loc[crossing_mask, 'simTime'].empty
    return {"ifCrossingRedLight_LST": 1 if crossed else -1}


def ifStopgreenWaveSpeedGuidance_LST(data):
    """Flag whether the ego stopped during green-wave speed guidance (LST data).

    A stop is any ego row with the scenario's warning code, ``stateMask == 0``
    and ``v == 0``.

    :return: ``{"ifStopgreenWaveSpeedGuidance_LST": 1}`` if such a row exists,
             else ``-1``.
    """
    code = scenario_sign_dict[data.function_config["function"]["scenario"]["name"]]
    ego_frame = data.ego_data
    stop_mask = (
        (ego_frame['ifwarning'] == code)
        & (ego_frame['stateMask'] == 0)
        & (ego_frame['v'] == 0)
    )
    stopped = not ego_frame.loc[stop_mask, 'simTime'].empty
    return {"ifStopgreenWaveSpeedGuidance_LST": 1 if stopped else -1}


def rightWarningSignal_PGVIL(data_processed) -> dict:
    """Judge whether the correct warning signal was issued (PGVIL data).

    :return: ``{"rightWarningSignal_PGVIL": 1}`` if any ego row's ``ifwarning``
             equals the scenario's code, ``-1`` if none does, and ``None`` as the
             value when the scenario has no known warning code.
    """
    ego_df = data_processed.ego_data
    scenario_name = data_processed.function_config["function"]["scenario"]["name"]
    # .get() so an unknown scenario reaches the guard below instead of raising KeyError.
    correctwarning = scenario_sign_dict.get(scenario_name)

    if correctwarning is None:
        print("无法获取正确的预警信号标志位!")
        # Return a dict (not bare None) so callers can merge it with dict.update().
        return {"rightWarningSignal_PGVIL": None}

    # Rows whose ifwarning equals the expected code and is not NaN.
    warning_rows = ego_df[
        (ego_df["ifwarning"] == correctwarning) & (ego_df["ifwarning"].notna())
    ]
    return {"rightWarningSignal_PGVIL": -1 if warning_rows.empty else 1}


def latestWarningDistance_PGVIL(data_processed) -> dict:
    """Minimum ego-to-object distance in the first correct-warning frame (PGVIL data).

    :return: ``{"latestWarningDistance_PGVIL": distance}``; 0.0 when no warning
             frame is found and 15 when that frame holds no other objects.

    NOTE(review): this uses the *first* warning frame (via get_first_warning),
    the same as earliestWarningDistance_PGVIL, so both report the same value —
    confirm whether "latest" should use the last warning frame instead.
    """
    metric_key = "latestWarningDistance_PGVIL"
    frame = get_first_warning(data_processed)
    if frame is None:
        return {metric_key: 0.0}

    ego_row, others = extract_ego_obj(frame)
    ego_xy = np.array([[ego_row["posX"], ego_row["posY"]]])
    gaps = calculate_distance_PGVIL(ego_xy, others[["posX", "posY"]].values)
    if gaps.size == 0:
        print("没有找到数据!")
        return {metric_key: 15}  # fallback when the frame holds no other objects

    return {metric_key: float(np.min(gaps))}


def latestWarningDistance_TTC_PGVIL(data_processed) -> dict:
    """Minimum time-to-collision over objects in the first correct-warning frame.

    Per object, TTC = ego/object distance divided by the magnitude of the
    velocity difference; a zero relative speed yields ``inf``.

    :return: ``{"latestWarningDistance_TTC_PGVIL": ttc}``; 0.0 when no warning
             frame is found and 2 when that frame holds no other objects.
    """
    metric_key = "latestWarningDistance_TTC_PGVIL"
    frame = get_first_warning(data_processed)
    if frame is None:
        return {metric_key: 0.0}

    ego_row, others = extract_ego_obj(frame)
    ego_xy = np.array([[ego_row["posX"], ego_row["posY"]]])
    ego_v = np.array([[ego_row["speedX"], ego_row["speedY"]]])
    others_xy = others[["posX", "posY"]].values
    others_v = others[["speedX", "speedY"]].values

    gaps = calculate_distance_PGVIL(ego_xy, others_xy)
    speed_deltas = calculate_relative_speed_PGVIL(ego_v, others_v)

    # errstate silences NumPy floating-point warnings from the division;
    # zero relative speed is replaced by inf via np.where.
    with np.errstate(divide="ignore", invalid="ignore"):
        ttc_values = np.where(speed_deltas != 0, gaps / speed_deltas, np.inf)
    if ttc_values.size == 0:
        print("没有找到数据!")
        return {metric_key: 2}  # fallback when the frame holds no other objects

    return {metric_key: float(np.nanmin(ttc_values))}


def earliestWarningDistance_PGVIL(data_processed) -> dict:
    """Minimum ego-to-object distance in the first correct-warning frame (PGVIL data).

    :return: ``{"earliestWarningDistance_PGVIL": distance}``; 0 when no warning
             frame is found and 15 when that frame holds no other objects.
    """
    warning_data = get_first_warning(data_processed)
    if warning_data is None:
        return {"earliestWarningDistance_PGVIL": 0}

    ego, obj = extract_ego_obj(warning_data)
    distances = calculate_distance_PGVIL(
        np.array([[ego["posX"], ego["posY"]]]), obj[["posX", "posY"]].values
    )
    if distances.size == 0:
        print("没有找到数据!")
        return {"earliestWarningDistance_PGVIL": 15}  # fallback when the frame holds no other objects

    # Same key as the other return paths and the metric/function name; it was
    # previously emitted as "earliestWarningDistance", so the value went missing
    # under the expected metric name.
    return {"earliestWarningDistance_PGVIL": float(np.min(distances))}


def earliestWarningDistance_TTC_PGVIL(data_processed) -> dict:
    """Minimum time-to-collision over objects in the first correct-warning frame.

    Per object, TTC = ego/object distance divided by the magnitude of the
    velocity difference; a zero relative speed yields ``inf``.

    :return: ``{"earliestWarningDistance_TTC_PGVIL": ttc}``; 0.0 when no warning
             frame is found and 2 when that frame holds no other objects.
    """
    metric_key = "earliestWarningDistance_TTC_PGVIL"
    frame = get_first_warning(data_processed)
    if frame is None:
        return {metric_key: 0.0}

    ego_row, others = extract_ego_obj(frame)
    ego_xy = np.array([[ego_row["posX"], ego_row["posY"]]])
    ego_v = np.array([[ego_row["speedX"], ego_row["speedY"]]])
    others_xy = others[["posX", "posY"]].values
    others_v = others[["speedX", "speedY"]].values

    gaps = calculate_distance_PGVIL(ego_xy, others_xy)
    speed_deltas = calculate_relative_speed_PGVIL(ego_v, others_v)

    # errstate silences NumPy floating-point warnings from the division;
    # zero relative speed is replaced by inf via np.where.
    with np.errstate(divide="ignore", invalid="ignore"):
        ttc_values = np.where(speed_deltas != 0, gaps / speed_deltas, np.inf)
    if ttc_values.size == 0:
        print("没有找到数据!")
        return {metric_key: 2}  # fallback when the frame holds no other objects

    return {metric_key: float(np.nanmin(ttc_values))}


# def delayOfEmergencyBrakeWarning(data_processed) -> dict:
#     #预警时机相对背景车辆减速度达到-4m/s2后的时延
#     ego_df = data_processed.ego_data
#     obj_df = data_processed.object_df
#     warning_data = get_first_warning(data_processed)
#     if warning_data is None:
#         return {"delayOfEmergencyBrakeWarning": -1}
#     try:
#         ego, obj = extract_ego_obj(warning_data)
#         # 向量化计算
#         obj_speed = np.array([[obj_df["speedX"], obj_df["speedY"]]])
#         # 计算背景车辆减速度
#         simtime_gap = obj["simTime"].iloc[1] - obj["simTime"].iloc[0]
#         simtime_freq = 1 / simtime_gap#每秒采样频率
#         # simtime_freq为一个时间窗,找出时间窗内的最大减速度
#         obj_speed_magnitude = np.linalg.norm(obj_speed, axis=1)#速度向量的模长
#         obj_speed_change = np.diff(speed_magnitude)#速度模长的变化量
#         obj_deceleration = np.diff(obj_speed_magnitude) / simtime_gap
#         #找到最大减速度,若最大减速度小于-4m/s2,则计算最大减速度对应的时间,和warning_data的差值进行对比
#         max_deceleration = np.max(obj_deceleration)
#         if max_deceleration < -4:
#             max_deceleration_times = obj["simTime"].iloc[np.argmax(obj_deceleration)]
#             max_deceleration_time = max_deceleration_times.iloc[0]
#             delay_time = ego["simTime"] - max_deceleration_time
#             return {"delayOfEmergencyBrakeWarning": float(delay_time)}

#         else:
#             print("没有达到预警减速度阈值:-4m/s^2")
#             return {"delayOfEmergencyBrakeWarning": -1}


def warningDelayTime_PGVIL(data_processed) -> dict:
    """Delay from the vehicle receiving a warning to the HMI issuing it (PGVIL data).

    * reference (receive) time: ``simTime`` of the first ``ego_data`` row with a
      non-null ``event_Type``.
    * HMI time: ``simTime`` of the ego (playerId == 1) row in the first
      correct-warning frame returned by get_first_warning.

    :return: ``{"warningDelayTime_PGVIL": HMI time - reference time}``, or -1
             when no warning frame exists or either time cannot be found.
    """
    ego_df = data_processed.ego_data

    warning_data = get_first_warning(data_processed)
    if warning_data is None:
        return {"warningDelayTime_PGVIL": -1}
    try:
        # Rows where event_Type is set; ego_df is not filtered by playerId here.
        event_times = ego_df.loc[ego_df["event_Type"].notna(), "simTime"]
        first_time = event_times.iloc[0]
        # Raises IndexError (handled below) if the frame has no ego row.
        warning_time = warning_data.loc[warning_data["playerId"] == 1, "simTime"].iloc[0]
        delay_time = warning_time - first_time
        return {"warningDelayTime_PGVIL": float(delay_time)}
    except Exception as e:
        # Deliberate best-effort: any lookup failure maps to the -1 sentinel.
        print(f"计算预警时延时发生错误: {e}")
        return {"warningDelayTime_PGVIL": -1}


def get_car_to_stop_line_distance(ego, car_point, stop_line_points):
    """
    Compute the distance from the ego vehicle's head (front) to a stop line.

    The perpendicular distance from ``car_point`` to the infinite line through
    the two stop-line points is computed first. The offset from ``car_point``
    to the vehicle head (``dimX / 2 + offX``) is then subtracted.

    :param ego: mapping/row providing ``dimX`` and ``offX``.
    :param car_point: (x, y) reference point of the ego vehicle (the original
        notes describe it as the rear-axle centre); any array-like is accepted.
    :param stop_line_points: two (x, y) points on the stop line; any
        array-like (ndarray, list or tuple of pairs) is accepted.
    :return: head-to-line distance. The perpendicular distance is unsigned, so
        the result is negative only while the reference point is closer to the
        line than the head offset.
    """
    # Offset from the reference point to the vehicle head.
    distance_carpoint_carhead = ego["dimX"] / 2 + ego["offX"]

    # Coerce to float arrays so tuple/list inputs support vector arithmetic
    # (tuple - tuple would raise TypeError).
    point = np.asarray(car_point, dtype=float)
    line_start = np.asarray(stop_line_points[0], dtype=float)
    line_end = np.asarray(stop_line_points[1], dtype=float)

    line_vector = line_end - line_start
    line_length = np.linalg.norm(line_vector)
    # Degenerate line (identical points): zero direction, so the foot is line_start.
    direction_unit = (
        line_vector / line_length if line_length != 0 else np.zeros_like(line_vector)
    )

    # Foot of the perpendicular from the reference point onto the stop line.
    projection_length = np.dot(point - line_start, direction_unit)
    perpendicular_foot = line_start + projection_length * direction_unit

    distance_to_foot = np.linalg.norm(point - perpendicular_foot)
    return distance_to_foot - distance_carpoint_carhead


def ifCrossingRedLight_PGVIL(data_processed, stop_line_points=None,
                             x_offset=258109.4239876,
                             y_offset=4149969.964821) -> dict:
    """Judge whether the ego vehicle crossed the stop line while the light was not green.

    Ego rows are scanned in order. The crossing is the first sample at which the
    head-to-stop-line distance (from get_car_to_stop_line_distance) goes from
    non-negative to negative. If the ``stateMask`` at that sample is known and
    is not the green value, the crossing counts as a red-light violation.

    stateMask values (from the original notes):
        0x100000   = 1048576   green
        0x1000000  = 16777216  yellow
        0x10000000 = 268435456 red

    :param data_processed: object whose ``ego_data`` DataFrame provides posX,
        posY, dimX, offX and stateMask per row.
    :param stop_line_points: two (x, y) stop-line points before the offsets are
        applied; defaults to the stop line previously hard-coded here.
    :param x_offset: added to every stop-line x coordinate.
    :param y_offset: added to every stop-line y coordinate.
        NOTE(review): the offsets presumably shift map-local coordinates into
        the posX/posY frame — confirm.
    :return: ``{"ifCrossingRedLight_PGVIL": -1}`` on a violation, otherwise ``1``.
    """
    green_state = 1048576  # 0x100000
    if stop_line_points is None:
        stop_line_points = [(276.555, -35.575), (279.751, -33.683)]
    # Build a fresh float array so caller-supplied points are never mutated.
    line_points = np.asarray(stop_line_points, dtype=float) + np.array([[x_offset, y_offset]])

    ego_df = data_processed.ego_data
    prev_distance = float("inf")  # nothing seen yet
    red_light_violation = False
    for _, ego in ego_df.iterrows():
        car_point = (ego["posX"], ego["posY"])
        state_mask = ego["stateMask"]
        distance_to_stopline = get_car_to_stop_line_distance(ego, car_point, line_points)

        # Head crosses the stop line. `>= 0` (not `> 0`) so a sample landing
        # exactly on the line does not hide the crossing on the next sample.
        if prev_distance >= 0 and distance_to_stopline < 0:
            # pd.notna rejects both None and NaN. The former `is not None` test
            # let NaN (missing light state) through, and NaN != green then
            # counted as a violation.
            if pd.notna(state_mask) and state_mask != green_state:
                red_light_violation = True
            break
        prev_distance = distance_to_stopline

    # -1 = ran the red light, 1 = no violation
    return {"ifCrossingRedLight_PGVIL": -1 if red_light_violation else 1}


# def ifStopgreenWaveSpeedGuidance(data_processed) -> dict:
#     #在绿波车速引导期间是否发生停车


# def mindisStopline(data_processed) -> dict:
#     """
#     当有停车让行标志/标线时车辆最前端与停车让行线的最小距离应在0-4m之间
#     """
#     ego_df = data_processed.ego_data
#     obj_df = data_processed.object_df
#     stop_giveway_simtime = ego_df[
#         ego_df["sign_type1"] == 32 |
#         ego_df["stopline_type"] == 3
#     ]["simTime"]
#     stop_giveway_data = ego_df[
#         ego_df["sign_type1"] == 32 |
#         ego_df["stopline_type"] == 3
#     ]["simTime"]
#     if stop_giveway_simtime.empty:
#         print("没有停车让行标志/标线")

#     ego_data = stop_giveway_data[stop_giveway_data['playerId'] == 1]
#     distance_carpoint_carhead = ego_data['dimX'].iloc[0]/2 + ego_data['offX'].iloc[0]
#     distance_to_stoplines = []
#     for _,row in ego_data.iterrows():
#         ego_pos = np.array([row["posX"], row["posY"], row["posH"]])
#         stop_line_points = [
#             [row["stopline_x1"], row["stopline_y1"]],
#             [row["stopline_x2"], row["stopline_y2"]],
#         ]
#         distance_to_stopline = get_car_to_stop_line_distance(ego_pos, stop_line_points)
#         distance_to_stoplines.append(distance_to_stopline)

#     mindisStopline = np.min(distance_to_stoplines) - distance_carpoint_carhead
#     return {"mindisStopline": mindisStopline}


class FunctionRegistry:
    """Dynamic registry of metric functions (with defensive validation).

    Collects the leaf metric names from ``function_config["function"]``,
    resolves each to the module-level callable of the same name, and runs them
    all against the processed data.
    """

    def __init__(self, data_processed):
        self.logger = LogManager().get_logger()  # shared global logger instance
        self.data = data_processed
        self.fun_config = data_processed.function_config["function"]
        # Attribute keeps its original spelling ("merics") for compatibility.
        self.level_3_merics = self._extract_level_3_metrics(self.fun_config)
        self._registry: Dict[str, Callable] = self._build_registry()

    def _extract_level_3_metrics(self, config_node: dict) -> list:
        """Collect the ``name`` of every leaf node in the metric config tree.

        A leaf is a dict that has a ``name`` key and no dict-valued entries.
        Traversal is depth-first in dict order and visits each node once.
        """
        metrics = []

        def _recurse(node):
            if isinstance(node, dict):
                if "name" in node and not any(
                        isinstance(v, dict) for v in node.values()
                ):
                    metrics.append(node["name"])
                for v in node.values():
                    _recurse(v)

        _recurse(config_node)
        self.logger.info(f"评比的功能指标列表:{metrics}")
        return metrics

    def _build_registry(self) -> dict:
        """Map each configured metric name to the module-level function of that name.

        Names that are missing from the module, or that resolve to something
        non-callable (e.g. a constant), are reported and skipped.
        """
        registry = {}
        module_namespace = globals()
        for func_name in self.level_3_merics:
            func = module_namespace.get(func_name)
            if not callable(func):
                print(f"未实现指标函数: {func_name}")
                self.logger.error(f"未实现指标函数: {func_name}")
                continue
            registry[func_name] = func
        return registry

    def batch_execute(self) -> dict:
        """Run every registered metric and merge the results (with a circuit breaker).

        Each metric is called with the shared data context and is expected to
        return a dict, which is merged into the result. A metric that raises,
        or returns something other than a dict, is logged and recorded as
        ``None`` under its own name so the remaining metrics still run.
        """
        results = {}
        for name, func in self._registry.items():
            try:
                result = func(self.data)  # every metric receives the same data context
            except Exception as e:
                print(f"{name} 执行失败: {str(e)}")
                self.logger.error(f"{name} 执行失败: {str(e)}", exc_info=True)
                results[name] = None
                continue
            if not isinstance(result, dict):
                # Handle non-dict results explicitly instead of relying on
                # dict.update() raising TypeError.
                message = f"{name} 执行失败: 返回值不是字典 ({type(result).__name__})"
                print(message)
                self.logger.error(message)
                results[name] = None
                continue
            results.update(result)
        self.logger.info(f"功能指标计算结果:{results}")
        return results


class FunctionManager:
    """Entry point that computes the function metrics and scores them."""

    def __init__(self, data_processed):
        self.data = data_processed
        # Resolves and runs the metric functions named in the config.
        self.function = FunctionRegistry(self.data)

    def report_statistic(self):
        """
        Compute all function metrics and evaluate them with ``Score``.
        :return: the evaluation result returned by ``Score.evaluate``
        """
        metric_values = self.function.batch_execute()
        scored = Score(self.data.function_config).evaluate(metric_values)
        print("\n[功能性表现及评价结果]")
        return scored


# Script entry point: running this module directly performs no computation;
# metrics are computed via FunctionManager(data_processed).report_statistic().
if __name__ == "__main__":
    pass