#!/usr/bin/env python
# -*- coding: utf-8 -*-
##################################################################
#
# Copyright (c) 2024 CICV, Inc. All Rights Reserved
#
##################################################################
"""
@Authors:           zhanghaiwen(zhanghaiwen@china-icv.cn)
@Date:              2024/10/17
@Last Modified:     2024/10/17
@Summary:           Evaluation functions
"""

import os

import numpy as np
import pandas as pd

import yaml



from modules.lib.log_manager import LogManager

# from lib import log  # make sure this import path is correct, or adjust it
# logger = None  # initialised to None


class DataPreprocessing:
    """Load one evaluation case and derive per-object kinematic data.

    Constructing an instance reads ``merged_ObjState.csv`` from ``data_path``,
    derives speed / acceleration / rate-of-change columns for every
    ``playerId``, rotates the ego acceleration using the ego heading, loads
    the YAML evaluation config and computes the report summary.

    Attributes set by the constructor:
        data_path, config_path: the paths given by the caller.
        case_name: last path component of ``data_path``.
        object_df: cleaned object-state table for all players.
        obj_data: ``{playerId: DataFrame}`` with per-object derived columns.
        obj_id_list: the keys of ``obj_data``.
        ego_data: the ``EGO_PLAYER_ID`` entry of ``obj_data`` after
            :meth:`process_ego_data` (same object, modified in place).
        vehicle_config / safety_config / comfort_config / efficient_config /
            function_config / traffic_config: sections of the YAML config.
        report_info: ``{"mileage": ..., "duration": ...}`` for the ego.
    """

    # playerId that identifies the ego vehicle in merged_ObjState.csv.
    EGO_PLAYER_ID = 1

    def __init__(self, data_path, config_path):
        """Read the case data and configuration.

        Args:
            data_path: Directory that contains ``merged_ObjState.csv``.
            config_path: Path of the YAML evaluation configuration.

        Raises:
            KeyError: If the CSV has no ego rows or a required column / config
                section is missing.
            OSError: If either file cannot be opened.
        """
        self.data_path = data_path
        self.case_name = os.path.basename(os.path.normpath(data_path))

        self.config_path = config_path

        # Object / ego tables (only object_df is filled by this class).
        self.object_df = pd.DataFrame()
        self.driver_ctrl_df = pd.DataFrame()
        self.vehicle_sys_df = pd.DataFrame()
        self.ego_data_df = pd.DataFrame()

        # Environment tables (left empty by this class).
        self.lane_info_df = pd.DataFrame()
        self.road_mark_df = pd.DataFrame()
        self.road_pos_df = pd.DataFrame()
        self.traffic_light_df = pd.DataFrame()
        self.traffic_signal_df = pd.DataFrame()

        # Configuration sections, filled by _get_yaml_config().
        self.vehicle_config = {}
        self.safety_config = {}
        self.comfort_config = {}
        self.efficient_config = {}
        self.function_config = {}
        self.traffic_config = {}

        # Per-object results, filled by _real_process_object_df().
        self.obj_data = {}
        self.ego_data = {}
        self.obj_id_list = []

        # Data quality level
        self.data_quality_level = 15

        # Order matters: the report summary needs the processed ego data.
        self._process_mode()
        self._get_yaml_config()
        self.report_info = self._get_report_info(
            self.obj_data.get(self.EGO_PLAYER_ID, pd.DataFrame())
        )

    def _process_mode(self):
        """Handle different processing modes (currently only the real-data path)."""
        self._real_process_object_df()

    def _get_yaml_config(self):
        """Load the YAML config and split it into per-module dictionaries.

        ``vehicle_config`` is the raw ``vehicle`` section.  Every other module
        config has the shape ``{<module>: <section>, "T_threshold": <section>}``;
        the ``T_threshold`` section object is shared between them.

        Raises:
            KeyError: If a required top-level section is missing.
        """
        # Explicit UTF-8 so non-ASCII text in the config is decoded the same
        # way on every platform instead of using the locale default.
        with open(self.config_path, "r", encoding="utf-8") as f:
            full_config = yaml.safe_load(f)

        self.vehicle_config = full_config["vehicle"]
        t_threshold = full_config["T_threshold"]

        def with_threshold(module):
            # Module section first, shared threshold section second.
            return {module: full_config[module], "T_threshold": t_threshold}

        self.safety_config = with_threshold("safety")
        self.comfort_config = with_threshold("comfort")
        self.efficient_config = with_threshold("efficient")
        self.function_config = with_threshold("function")
        self.traffic_config = with_threshold("traffic")

    @staticmethod
    def cal_velocity(lat_v, lon_v):
        """Return the magnitude of a 2-D vector given its two components.

        Works element-wise on scalars, NumPy arrays and pandas Series.
        """
        return np.sqrt(lat_v**2 + lon_v**2)

    @classmethod
    def _add_kinematics(cls, data):
        """Return a copy of ``data`` with speed/acceleration columns added.

        Adds ``lat_v``, ``lon_v``, ``v``, ``lat_acc``, ``lon_acc`` and ``accel``
        from ``speedX/speedY`` and ``accelX/accelY``.
        """
        data = data.copy()

        # "* 1" is the (currently neutral) scale factor of the original code.
        data["lat_v"] = data["speedY"] * 1
        data["lon_v"] = data["speedX"] * 1
        # NOTE(review): no unit conversion is applied here, so `v` has the
        # same unit as speedX/speedY — confirm against consumers expecting km/h.
        data["v"] = cls.cal_velocity(data["lat_v"], data["lon_v"])

        data["lat_acc"] = data["accelY"] * 1
        data["lon_acc"] = data["accelX"] * 1
        data["accel"] = cls.cal_velocity(data["lat_acc"], data["lon_acc"])
        return data

    def _real_process_object_df(self):
        """Read ``merged_ObjState.csv`` and populate the object/ego data.

        Fills ``object_df``, ``obj_data``, ``obj_id_list`` and ``ego_data``.

        Raises:
            KeyError: If a required column is missing or the CSV contains no
                rows for ``EGO_PLAYER_ID``.
        """
        merged_csv_path = os.path.join(self.data_path, "merged_ObjState.csv")
        self.object_df = pd.read_csv(
            merged_csv_path,
            dtype={"simTime": float},
            engine="python",
            on_bad_lines="skip",  # silently skip malformed rows
            na_values=["", "NA", "null", "NaN"],  # explicit missing-value markers
        ).drop_duplicates(subset=["simTime", "simFrame", "playerId"])
        # "+AF8-" is the UTF-7 escape of "_" that some exports leave in headers.
        self.object_df.columns = [
            col.replace("+AF8-", "_") for col in self.object_df.columns
        ]

        data = self._add_kinematics(self.object_df)

        # Rows without a 'type' are dropped; re-index the remainder from 0.
        data = data.dropna(subset=["type"])
        data.reset_index(drop=True, inplace=True)
        self.object_df = data.copy()

        # Per-object derived parameters.
        for obj_id, obj_data in data.groupby("playerId"):
            self.obj_data[obj_id] = self._calculate_object_parameters(obj_data)

        self.obj_id_list = list(self.obj_data.keys())
        if self.EGO_PLAYER_ID not in self.obj_data:
            raise KeyError(
                f"ego playerId {self.EGO_PLAYER_ID} not found in {merged_csv_path}"
            )

        # Rotate the ego acceleration using the ego heading (done in place, so
        # obj_data[EGO_PLAYER_ID] and ego_data stay the same object).
        self.ego_data = self.process_ego_data(self.obj_data[self.EGO_PLAYER_ID])

    @staticmethod
    def _rate_of_change(delta, dt):
        """Return ``delta / dt`` with +/-inf replaced by +/-9999."""
        return (delta / dt).replace([np.inf, -np.inf], [9999, -9999])

    def _calculate_object_parameters(self, obj_data):
        """Return a copy of one object's rows with rate-of-change columns.

        Adds ``time_diff``, ``lat_acc_diff``, ``lon_acc_diff``, ``yawrate_diff``
        (row-to-row differences) and ``lat_acc_roc``, ``lon_acc_roc``, ``accelH``
        (those differences divided by ``time_diff``).  The first row of every
        new column is NaN; a zero time step yields +/-9999 instead of +/-inf
        (and NaN for 0/0).
        """
        obj_data = obj_data.copy()
        dt = obj_data["simTime"].diff()
        obj_data["time_diff"] = dt

        obj_data["lat_acc_diff"] = obj_data["lat_acc"].diff()
        obj_data["lon_acc_diff"] = obj_data["lon_acc"].diff()
        obj_data["yawrate_diff"] = obj_data["speedH"].diff()

        obj_data["lat_acc_roc"] = self._rate_of_change(obj_data["lat_acc_diff"], dt)
        obj_data["lon_acc_roc"] = self._rate_of_change(obj_data["lon_acc_diff"], dt)
        obj_data["accelH"] = self._rate_of_change(obj_data["yawrate_diff"], dt)

        return obj_data

    @staticmethod
    def _pedal_to_percent(pedal):
        """Return pedal positions as a list on a 0-100 scale.

        A series whose maximum is at most 1 is taken to be a 0-1 fraction and
        is multiplied by 100; anything else is returned unchanged.  The bound
        is inclusive so that a fully pressed pedal (1.0) is still recognised
        as a fraction.
        """
        if pedal.max() <= 1:
            return (pedal * 100).tolist()
        return pedal.tolist()

    def _get_driver_ctrl_data(self, df):
        """
        Process and get driver control information.

        Args:
            df: A DataFrame with ``simTime``, ``simFrame``, ``brakePedal``,
                ``throttlePedal`` and ``steeringWheel`` columns.

        Returns:
            A dictionary of lists: ``time_list`` (rounded to 2 decimals),
            ``frame_list``, ``brakePedal_list`` and ``throttlePedal_list``
            (0-100 scale) and ``steeringWheel_list``.
        """
        return {
            "time_list": df["simTime"].round(2).tolist(),
            "frame_list": df["simFrame"].tolist(),
            "brakePedal_list": self._pedal_to_percent(df["brakePedal"]),
            "throttlePedal_list": self._pedal_to_percent(df["throttlePedal"]),
            "steeringWheel_list": df["steeringWheel"].tolist(),
        }

    def _get_report_info(self, df):
        """Return ``{"mileage": ..., "duration": ...}`` for the given rows."""
        mileage = self._mileage_cal(df)
        duration = self._duration_cal(df)
        return {"mileage": mileage, "duration": duration}

    def _mileage_cal(self, df):
        """Return the distance covered, rounded to 2 decimals.

        If ``travelDist`` varies it is used directly (last minus first value).
        If it is constant or entirely missing, it is rebuilt by trapezoidal
        integration of ``v`` over ``simTime``; in that case ``df`` is modified
        in place (``time_diff``, ``avg_speed``, ``distance_increment`` and
        ``travelDist`` are written).  An empty frame yields ``0.0``.
        """
        if df.empty:
            return 0.0

        if df["travelDist"].nunique() <= 1:
            # No usable odometer signal: integrate the speed instead.
            df["time_diff"] = df["simTime"].diff().fillna(0)
            df["avg_speed"] = (df["v"] + df["v"].shift()).fillna(0) / 2
            # NOTE(review): "/ 3.6" treats `v` as km/h, but this class applies
            # no km/h conversion when deriving `v` — confirm the speed unit.
            df["distance_increment"] = df["avg_speed"] * df["time_diff"] / 3.6
            df["travelDist"] = df["distance_increment"].cumsum().fillna(0)

        return round(df["travelDist"].iloc[-1] - df["travelDist"].iloc[0], 2)

    def _duration_cal(self, df):
        """Return last minus first ``simTime`` (``0.0`` for an empty frame)."""
        if df.empty:
            return 0.0
        return df["simTime"].iloc[-1] - df["simTime"].iloc[0]

    def process_ego_data(self, ego_data):
        """Rotate the ego acceleration by the heading angle, in place.

        Adds ``heading_rad``, ``lon_acc_vehicle`` and ``lat_acc_vehicle`` and
        overwrites ``lon_acc`` / ``lat_acc`` with the rotated values; the
        unrotated values remain available as ``accelX`` / ``accelY``.

        Args:
            ego_data: Ego rows with ``posH``, ``accelX`` and ``accelY``.
                Presumably ``posH`` is a heading in degrees and
                ``accelX/accelY`` are East-North-Up components, as the
                original comments state — verify against the data source.

        Returns:
            The same ``ego_data`` object.
        """
        # (90 - heading) is used as the angle to the x axis, matching the
        # convention the original comment attributes to safety.py.
        ego_data["heading_rad"] = np.deg2rad(90 - ego_data["posH"])
        cos_h = np.cos(ego_data["heading_rad"])
        sin_h = np.sin(ego_data["heading_rad"])

        # Plane rotation of (accelX, accelY) by heading_rad.
        ego_data["lon_acc_vehicle"] = (
            ego_data["accelX"] * cos_h + ego_data["accelY"] * sin_h
        )
        ego_data["lat_acc_vehicle"] = (
            -ego_data["accelX"] * sin_h + ego_data["accelY"] * cos_h
        )

        # NOTE(review): lat_acc_roc / lon_acc_roc were derived from the
        # unrotated lat_acc / lon_acc before this call and are not recomputed
        # here — confirm that is what consumers expect.
        ego_data["lon_acc"] = ego_data["lon_acc_vehicle"]
        ego_data["lat_acc"] = ego_data["lat_acc_vehicle"]

        return ego_data