#!/usr/bin/env python
# -*- coding: utf-8 -*-
##################################################################
#
# Copyright (c) 2023 CICV, Inc. All Rights Reserved
#
##################################################################
"""
@Authors:           zhanghaiwen(zhanghaiwen@china-icv.cn), yangzihao(yangzihao@china-icv.cn)
@Date:              2024/01/12
@Last Modified:     2024/01/12
@Summary:           Weight calculation and scoring utilities
"""

import sys

sys.path.append('../common')
sys.path.append('modules')
sys.path.append('score')

import numpy as np
from common import get_interpolation
from functools import reduce


def cal_weight_from_80_old(score_list):
    """Return normalized weights derived from each score's distance to 80.

    A score equal to 80 is replaced by 80.1 so that it still contributes a
    small non-zero weight. Every weight is ``|score - 80| / 100`` divided by
    the total of those values, so the weights add up to 1 (up to float error).

    :param score_list: iterable of numeric scores
    :return: numpy array of weights, in the same order as ``score_list``
    """
    shifted = []
    for score in score_list:
        shifted.append(80.1 if score == 80 else score)

    # distance of every score from the 80-point baseline, scaled by 100
    distance = np.abs((np.array(shifted) - 80) / 100)

    # normalization
    total = sum(distance)
    return distance / total


def cal_weight_from_80(score_list):
    """Return normalized weights, rounded to 4 decimals, that sum to 1.

    Each weight is proportional to ``|score - 80| / 100``. A score of exactly
    80 is treated as 80.1 so that it keeps a small non-zero weight. After
    rounding, any rounding residue is added to the last weight so that the
    rounded weights still add up to 1.

    :param score_list: iterable of numeric scores
    :return: list of floats with one weight per score; an empty list when
        ``score_list`` is empty
    """
    # weight process
    s_list = [80.1 if x == 80 else x for x in score_list]
    if not s_list:
        # Nothing to weight; without this guard the last-element adjustment
        # below would fail with IndexError on an empty list.
        return []

    weight_list = np.abs((np.array(s_list) - 80) / 100)

    # normalization
    weight_list = weight_list / sum(weight_list)

    # Round to 4 decimal places
    rounded_list = np.round(weight_list, 4).tolist()

    # Rounding can make the total drift away from 1; push the residue into
    # the last element and re-round it so every entry keeps 4 decimals.
    diff = 1 - sum(rounded_list)
    if diff != 0:
        rounded_list[-1] = round(rounded_list[-1] + diff, 4)

    return rounded_list


def cal_score_with_priority(score_list, weight_list, priority_list):
    """Return the weighted score, scaled down when any score is below 80.

    The weighted sum ``dot(weight_list, score_list)`` is multiplied by:

    * 1   when no score is below 80,
    * 0.9 when some score is below 80,
    * 0.8 when some score below 80 has a priority value of 0.

    :param score_list: sequence of scores
    :param weight_list: sequence of weights aligned with ``score_list``
    :param priority_list: sequence of priorities aligned with ``score_list``
    :return: penalized weighted score
    """
    # positions of the scores that fall below the 80-point threshold
    below = [idx for idx, value in enumerate(score_list) if value < 80]

    if not below:
        penalty = 1
    else:
        zero_priority_hits = [priority_list[idx] == 0 for idx in below]
        penalty = 0.8 if any(zero_priority_hits) else 0.9

    # calculate score
    return np.dot(weight_list, score_list) * penalty


def cal_score_from_80(score_list):
    """Return the overall score weighted by each score's distance from 80.

    The weights come from ``cal_weight_from_80``; the result is the dot
    product of those weights with the scores, rounded to 2 decimals.

    :param score_list: sequence of numeric scores
    :return: weighted overall score rounded to 2 decimal places
    """
    weights = cal_weight_from_80(score_list)
    weighted_total = np.dot(weights, score_list)
    return round(weighted_total, 2)


class ScoreModel(object):
    """Score indicator columns of a case matrix against their optimal values.

    Columns whose values all fall into a clear-cut range are scored directly
    (100, 0 or an interpolated value); the remaining columns are scored with
    a grey relational coefficient in :meth:`calculate_score`.

    Original design notes (translated):
        The more information, the larger the weight.
        The greater the contrast intensity and conflict, the more information.
        The larger the standard deviation, the greater the contrast intensity.
        The smaller the correlation, the greater the conflict.
        A column with a large standard deviation and a low correlation with
        the other columns gets a large weight.

    NOTE(review): the notes above describe a weighting idea; the code in this
    class does not compute standard deviations or correlations.
    """

    # Outside the [lower, upper] interval the score drops from 80 to 20 over
    # a band whose width is the matching half-interval divided by this value.
    _OUTSIDE_BAND_DIVISOR = 3

    def __init__(self, kind_list, optimal_value, multiple_list, arr):
        """
        :param kind_list: per-indicator type: 1 (larger is better),
            -1 (smaller is better) or 0 (interval type)
        :param optimal_value: per-indicator optimal (baseline) value
        :param multiple_list: per-indicator ``[lower, upper]`` multipliers
            that are applied to the optimal value to get the range bounds
        :param arr: 2-D numpy array; rows are cases, columns are indicators
        """
        # n for cases
        # m for indicators
        self.n, self.m = arr.shape
        self.kind = kind_list
        self.optimal = optimal_value
        self.X = arr
        self.rho = 1 / 3  # typically 0.5; lowest score formula: rho / (1 + rho)
        self.multiple = np.array(multiple_list)

    def calculate_score(self):
        """Score every column with a grey relational coefficient.

        The data and the optimal values are made dimensionless by dividing
        by the column mean (optimal value included), then the coefficient
        ``rho * max / (|optimal - x| + rho * max)`` is averaged over the
        cases and scaled by 80.

        The stored matrix ``self.X`` is left untouched, so repeated calls
        return the same result.

        :return: list with one score per indicator column
        """
        # m indicators, n cases
        optimal = np.asarray(self.optimal)
        data = np.asarray(self.X)

        # Eq(15): column mean over the n cases plus the optimal value
        column_mean = (data.sum(axis=0) + optimal) / (self.n + 1)

        scaled = data / column_mean  # make the data dimensionless
        scaled_optimal = optimal / column_mean  # same for the optimal values
        deviation = np.abs(scaled_optimal - scaled)

        min_dev = 0
        # 1.333333 for a 5x multiple, 1 for a 3x multiple.
        # NOTE(review): only the first indicator's upper multiple is used for
        # every column - confirm that all indicators share the same multiple.
        upper_multiple = self.multiple[0][1]
        max_dev = 2 * (upper_multiple - 1) / (upper_multiple + 1)

        # Eq(16)
        eta = (min_dev + self.rho * max_dev) / (deviation + self.rho * max_dev)
        return [x * 80 for x in np.mean(eta, axis=0)]

    def _outside_interval_score(self, dist, interval_dist):
        """Score an interval-type column lying entirely outside its interval.

        :param dist: per-case distance beyond the violated interval bound
        :param interval_dist: width of the band in which a score is still given
        :return: interpolated score when every case is inside the band, else 0
        """
        # NOTE(review): for multi-row columns "all rows within the band" is
        # assumed here; the single-row behaviour is unchanged.
        if np.all(dist < interval_dist):
            # get_interpolation comes from `common`; presumably a linear
            # interpolation through the two (x, score) points - TODO confirm.
            return float(get_interpolation(dist, [interval_dist, 20], [0, 80]))
        return 0

    def _special_case_score(self, i):
        """Return the directly assigned score of column ``i``, if any.

        A column gets a direct score only when all of its cases fall into the
        same range; ``np.all`` is used for every range test so that columns
        with more than one case are handled without an ambiguous-truth-value
        error.

        :param i: column (indicator) index
        :return: the score, or ``None`` when the column has to be scored by
            :meth:`calculate_score`
        :raises ValueError: if the indicator kind is not 1, -1 or 0
        """
        kind = self.kind[i]
        optimal = self.optimal[i]
        column = self.X[:, i]
        lower = optimal * self.multiple[i][0]
        upper = optimal * self.multiple[i][1]

        if kind == 1:  # larger is better
            if np.all(column >= upper):
                return 100
            if np.all(column >= optimal):
                return float(get_interpolation(column, [optimal, 80], [upper, 100]))
            if np.all(column <= lower):
                return 0
            return None

        if kind == -1:  # smaller is better
            if np.all(column <= lower):
                return 100
            if np.all(column <= optimal):
                return float(get_interpolation(column, [optimal, 80], [lower, 100]))
            if np.all(column >= upper):
                return 0
            return None

        if kind == 0:  # interval type
            if np.all((lower <= column) & (column <= optimal)):
                return float(
                    get_interpolation(optimal - column, [abs(optimal - lower), 80], [0, 100]))
            if np.all((optimal <= column) & (column <= upper)):
                return float(
                    get_interpolation(column - optimal, [abs(upper - optimal), 80], [0, 100]))
            if np.all(column < lower):
                band = (optimal - lower) / self._OUTSIDE_BAND_DIVISOR
                return self._outside_interval_score(lower - column, band)
            if np.all(upper < column):
                band = (upper - optimal) / self._OUTSIDE_BAND_DIVISOR
                return self._outside_interval_score(column - upper, band)
            return None

        raise ValueError(
            'Unknown indicator kind {!r} at index {}; expected 1, -1 or 0.'.format(kind, i))

    def cal_score(self):
        """Score every indicator column.

        Special cases are handled first: a column whose cases all lie in a
        clear-cut range is scored directly (for example 100 or 0 when the
        values are beyond the configured multiple of the baseline). Only the
        remaining columns are passed to the grey relational scoring.

        :return: list with one score per indicator, in column order
        :raises ValueError: if an indicator kind is not 1, -1 or 0
        """
        scores = [self._special_case_score(i) for i in range(self.m)]

        # Columns without a direct score are evaluated together in a
        # sub-model that contains only those columns.
        pending = [i for i, score in enumerate(scores) if score is None]
        if pending:
            sub_model = ScoreModel(
                [self.kind[i] for i in pending],
                [self.optimal[i] for i in pending],
                [self.multiple[i] for i in pending],
                self.X[:, pending],
            )
            for i, score in zip(pending, sub_model.calculate_score()):
                scores[i] = score

        return scores


class AHP:
    """Derive weights from a pairwise comparison matrix (AHP)."""

    _METHODS = ('arithmetic', 'geometric', 'eigenvalue')

    def __init__(self, matrix):
        """
        :param matrix: square pairwise comparison matrix (nested sequence)
        """
        self.A = np.array(matrix)
        self.n = len(matrix)

    def _get_consistency_ratio(self, w_max):
        """Return the consistency ratio ``CR = CI / RI`` of the matrix.

        ``CI = (w_max - n) / (n - 1)`` and ``RI`` is looked up by the matrix
        order ``n`` (the table below covers orders up to 20).

        :param w_max: largest eigenvalue of the comparison matrix
        :return: consistency ratio; 0.0 for a matrix of order below 2, where
            ``CI`` would otherwise divide by zero
        """
        # RI[2] is a tiny non-zero value, presumably so that CR stays defined
        # for 2x2 matrices - TODO confirm.
        RI = [0, 0, 0.0001, 0.52, 0.89, 1.12, 1.26, 1.36,
              1.41, 1.46, 1.49, 1.52, 1.54, 1.56, 1.58, 1.59,
              1.5943, 1.6064, 1.6133, 1.6207, 1.6292]
        if self.n < 2:
            return 0.0
        CI = (w_max - self.n) / (self.n - 1)
        CR = CI / RI[self.n]
        return CR

    def get_weights(self, method='eigenvalue'):
        """Return the weights of the compared items.

        :param method: ``'arithmetic'`` (row mean of the column-normalized
            matrix), ``'geometric'`` (normalized n-th root of the row
            products) or ``'eigenvalue'`` (normalized eigenvector of the
            largest eigenvalue)
        :return: numpy array for ``'arithmetic'``, list for the other methods
        :raises ValueError: if ``method`` is unknown or the consistency ratio
            of the matrix exceeds 0.1
        """
        if method not in self._METHODS:
            raise ValueError(
                "Unknown method {!r}; expected 'arithmetic', 'geometric' or 'eigenvalue'.".format(method))

        # Check consistency of pairwise comparison matrix
        w, v = np.linalg.eig(self.A)
        w_index = np.argmax(np.real(w))
        w_max = np.real(w[w_index])
        cr = self._get_consistency_ratio(w_max)

        if cr > 0.1:
            raise ValueError('The pairwise comparison matrix is inconsistent.')

        # Calculate weights with selected method
        if method == 'arithmetic':
            # Normalize every column to sum 1, then average each row
            normalized = self.A / self.A.sum(axis=0)
            return normalized.mean(axis=1)

        if method == 'geometric':
            row_products = np.prod(np.asarray(self.A, dtype=float), axis=1)
            roots = np.power(row_products, 1 / self.n)
            total = np.sum(roots)
            return [e / total for e in roots]

        # 'eigenvalue': column w_index of v is the eigenvector that belongs to
        # the largest eigenvalue; an argmax over v itself would give a flat
        # index that is not a valid column selector.
        principal = np.real(v[:, w_index])
        total = np.sum(principal)
        return [e / total for e in principal]


if __name__ == "__main__":
    # Demo: score one "smaller is better" indicator for a single case.
    demo_kinds = [-1]
    demo_optimal = [6]
    demo_multiples = [[0.5, 2]]  # bounds become [3, 12]

    # Other sample values that were used to probe the range boundaries:
    # 2.1, 2.999, 3, 4, 6, 11, 11.999, 12, 12.1
    demo_cases = np.array([[1.999]])

    demo_model = ScoreModel(demo_kinds, demo_optimal, demo_multiples, demo_cases)
    print(demo_model.cal_score())