#!/usr/bin/env python
# -*- coding: utf-8 -*-
##################################################################
#
# Copyright (c) 2023 CICV, Inc. All Rights Reserved
#
##################################################################
"""
@Authors:           yangzihao(yangzihao@china-icv.cn)
@Data:              2024/02/21
@Last Modified:     2024/02/21
@Summary:           The template of custom indicator.
"""

import math
import pandas as pd
import numpy as np
from common import zip_time_pairs, continuous_group
from log import logger

"""import functions"""


# def zip_time_pairs(time_list, zip_list):
#     zip_time_pairs = zip(time_list, zip_list)
#     zip_vs_time = [[x, y] for x, y in zip_time_pairs if not math.isnan(y)]
#     return zip_vs_time

# def continuous_group(df):
#     time_list = df['simTime'].values.tolist()
#     frame_list = df['simFrame'].values.tolist()
#
#     group_time = []
#     group_frame = []
#     sub_group_time = []
#     sub_group_frame = []
#
#     for i in range(len(frame_list)):
#         if not sub_group_time or frame_list[i] - frame_list[i - 1] <= 1:
#             sub_group_time.append(time_list[i])
#             sub_group_frame.append(frame_list[i])
#         else:
#             group_time.append(sub_group_time)
#             group_frame.append(sub_group_frame)
#             sub_group_time = [time_list[i]]
#             sub_group_frame = [frame_list[i]]
#
#     group_time.append(sub_group_time)
#     group_frame.append(sub_group_frame)
#     group_time = [g for g in group_time if len(g) >= 2]
#     group_frame = [g for g in group_frame if len(g) >= 2]
#
#     # 输出图表值
#     time = [[g[0], g[-1]] for g in group_time]
#     frame = [[g[0], g[-1]] for g in group_frame]
#
#     time_df = pd.DataFrame(time, columns=['start_time', 'end_time'])
#     frame_df = pd.DataFrame(frame, columns=['start_frame', 'end_frame'])
#
#     result_df = pd.concat([time_df, frame_df], axis=1)
#
#     return result_df


# def continous_judge(frame_list):
#     if not frame_list:
#         return 0
#
#     cnt = 1
#     for i in range(1, len(frame_list)):
#         if frame_list[i] - frame_list[i - 1] <= 3:
#             continue
#         cnt += 1
#     return cnt


# custom metric codes
class CustomMetric(object):
    """Custom metric "跟车距离偏差": time-gap deviation between ego and a target vehicle.

    For every frame of ``object_df`` whose ``ACC_status`` equals
    ``FOLLOW_STATUS``, the time gap is ``distance(ego, target) / ego v`` and
    the deviation is how far that gap falls below ``SAFE_TIME_GAP`` (0 when the
    gap is at least the safe gap). The metric value is the maximum deviation.
    The whole evaluation runs eagerly from ``__init__`` via :meth:`run`; all
    outputs are written into ``self.result``.
    """

    # ACC_status value selecting the frames that are evaluated.
    # NOTE(review): the original comment said "number 3 corresponds to ICA
    # Active" and a commented-out variant filtered on "Active"; "Shut_off"
    # looks like a debugging leftover -- confirm which status is intended.
    FOLLOW_STATUS = "Shut_off"
    # playerId of the ego vehicle and of the followed target vehicle.
    EGO_ID = 1
    TARGET_ID = 2
    # Time gap at or above which the deviation is 0.
    SAFE_TIME_GAP = 3
    # Deviation above which frames are reported as markLine segments.
    # NOTE(review): deviation = max(0, SAFE_TIME_GAP - time_gap) and the
    # distance is non-negative, so with non-negative ego ``v`` the deviation
    # never exceeds SAFE_TIME_GAP (3); a threshold of 10 cannot trigger.
    # Confirm the intended value.
    MARKLINE_THRESHOLD = 10

    def __init__(self, all_data, case_name):
        """Store inputs, initialise result containers and run the evaluation.

        Args:
            all_data: data container; ``config`` and ``object_df`` are read.
            case_name: case identifier, used in log messages.
        """
        self.data = all_data
        self.optimal_dict = self.data.config
        self.case_name = case_name
        self.markline_df = pd.DataFrame(columns=['start_time', 'end_time', 'start_frame', 'end_frame', 'type'])

        self.df = pd.DataFrame()
        self.df_follow = pd.DataFrame()

        # Per-frame values over the evaluated (following) ego frames.
        self.time_list_follow = list()
        self.frame_list_follow = list()
        self.dist_list = list()
        self.dist_deviation_list = list()
        # Ego timeline over all statuses, plus the deviation aligned to it
        # (NaN where the frame is not a following frame).
        self.time_list_full_time = list()
        self.dist_deviation_list_full_time = list()

        self.result = {
            "name": "跟车距离偏差",
            "value": [],
            # "weight": [],
            "tableData": {
                "avg": "",  # average, or the metric value itself
                "max": "",
                "min": ""
            },
            "reportData": {
                "name": "跟车距离偏差(s)",
                # "legend": [],  # with several data series, add one label per series,
                #                  e.g. ["lateral acceleration", "longitudinal acceleration"]
                "data": [],
                "markLine": [],
                "range": [],
            },
            "statusFlag": {}
        }
        self.run()

    def data_extract(self):
        """Load ``object_df`` and keep the rows whose status is ``FOLLOW_STATUS``.

        Sets ``result['statusFlag']['functionICA']`` to whether any such row exists.
        """
        self.df = self.data.object_df
        self.df_follow = self.df[self.df['ACC_status'] == self.FOLLOW_STATUS].copy()
        self.result['statusFlag']['functionICA'] = not self.df_follow.empty

    def dist(self, x1, y1, x2, y2):
        """Return the Euclidean distance between (x1, y1) and (x2, y2)."""
        dis = math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
        return dis

    def data_analyze(self):
        """Compute per-frame distance, time gap and time-gap deviation.

        Populates ``dist_list``, the ``*_follow`` lists, ``dist_deviation_list``,
        the aligned full-timeline lists used for the report curve, and
        ``result['value']``. That value is the max deviation rounded to 2
        decimals, or 0 when no deviation is defined.
        """
        col_list = ['simTime', 'simFrame', 'playerId', 'v', 'posX', 'posY']
        kin_cols = ['simTime', 'simFrame', 'v', 'posX', 'posY']
        df = self.df_follow[col_list].copy()

        ego_df = df[df['playerId'] == self.EGO_ID][kin_cols]

        # Target selection is hard-coded by playerId; a per-frame 'target_id'
        # column (nearest vehicle ahead in the same lane) was considered instead.
        obj_df = df[df['playerId'] == self.TARGET_ID][kin_cols]
        obj_df = obj_df.rename(columns={'v': 'v_obj', 'posX': 'posX_obj', 'posY': 'posY_obj'})

        df_merge = pd.merge(ego_df, obj_df, on=['simTime', 'simFrame'], how='left')

        # Vectorised Euclidean distance. NaN where the target row is missing.
        # Unlike a row-wise apply, this also works on an empty frame.
        df_merge['dist'] = np.hypot(df_merge['posX'] - df_merge['posX_obj'],
                                    df_merge['posY'] - df_merge['posY_obj'])
        self.dist_list = df_merge['dist'].values.tolist()

        df_merge['time_gap'] = df_merge['dist'] / df_merge['v']
        safe_time_gap = self.SAFE_TIME_GAP
        # The deviation is 0 when the gap is safe, including +inf from v == 0.
        # Otherwise it is the shortfall below the safe gap. NaN gaps stay NaN.
        df_merge['dist_deviation'] = (safe_time_gap - df_merge['time_gap']).mask(
            df_merge['time_gap'] >= safe_time_gap, 0)

        df_merge.replace([np.inf, -np.inf], np.nan, inplace=True)  # outlier handling

        self.time_list_follow = df_merge['simTime'].values.tolist()
        self.frame_list_follow = df_merge['simFrame'].values.tolist()
        self.dist_deviation_list = df_merge['dist_deviation'].values.tolist()

        # Ego timeline over ALL statuses with the deviation attached by simTime.
        # Times and values come from the same frame, so the lists stay aligned.
        ego_full_df = self.df[self.df['playerId'] == self.EGO_ID][['simTime', 'simFrame']]
        df_full = pd.merge(ego_full_df, df_merge[['simTime', 'dist_deviation']], on='simTime', how='left')
        self.time_list_full_time = df_full['simTime'].values.tolist()
        self.dist_deviation_list_full_time = df_full['dist_deviation'].values.tolist()

        distance_deviation = df_merge['dist_deviation'].max()
        self.result['value'] = [round(distance_deviation, 2)] if not np.isnan(distance_deviation) else [0]

    def markline_statistic(self):
        """Append runs of frames with deviation above ``MARKLINE_THRESHOLD`` to ``markline_df``.

        Runs are grouped by ``continuous_group`` and tagged with type "ICA".
        """
        unfunc_df = pd.DataFrame({'simTime': self.time_list_follow, 'simFrame': self.frame_list_follow,
                                  'dist_deviation': self.dist_deviation_list})
        unfunc_df = unfunc_df[unfunc_df['simFrame'] > 1]  # skip frame 1

        v_df = unfunc_df[unfunc_df['dist_deviation'] > self.MARKLINE_THRESHOLD]
        v_df = v_df[['simTime', 'simFrame', 'dist_deviation']]
        v_follow_df = continuous_group(v_df)
        v_follow_df['type'] = "ICA"
        self.markline_df = pd.concat([self.markline_df, v_follow_df], ignore_index=True)

    def report_data_statistic(self):
        """Fill ``tableData`` (avg/max/min) and ``reportData`` (curve, markLine, range)."""
        graph_list = [x for x in self.dist_deviation_list if not np.isnan(x)]
        self.result['tableData']['avg'] = f'{np.mean(graph_list):.2f}' if graph_list else '-'
        self.result['tableData']['max'] = f'{max(graph_list):.2f}' if graph_list else '-'
        self.result['tableData']['min'] = f'{min(graph_list):.2f}' if graph_list else '-'

        # Both lists come from the same merged ego timeline, so each time is
        # paired with its own frame's deviation.
        zip_vs_time = zip_time_pairs(self.time_list_full_time, self.dist_deviation_list_full_time)
        self.result['reportData']['data'] = zip_vs_time

        self.markline_statistic()
        self.result['reportData']['markLine'] = self.markline_df.to_dict('records')

        self.result['reportData']['range'] = "[0, 10]"

    def run(self):
        """Run extract -> analyze -> report.

        Each stage catches and logs its own exception, so a failure in one
        stage still lets the later stages run (best effort).
        """
        logger.info(f"[case:{self.case_name}] Custom metric:[ica_distance_deviation:{self.result['name']}] evaluate.")

        # The exception is embedded in the message rather than passed as a
        # stray positional argument. %-style loggers cannot format that argument.
        try:
            self.data_extract()
        except Exception as e:
            logger.error(f"[case:{self.case_name}] Custom metric:{self.result['name']} data extract ERROR! {e}")

        try:
            self.data_analyze()
        except Exception as e:
            logger.error(f"[case:{self.case_name}] Custom metric:{self.result['name']} data analyze ERROR! {e}")

        try:
            self.report_data_statistic()
        except Exception as e:
            logger.error(f"[case:{self.case_name}] Custom metric:{self.result['name']} report data statistic ERROR! {e}")

# if __name__ == "__main__":
#     pass