import json 

from typing import Dict, Any

from modules.lib.log_manager import LogManager

from modules.config import config

class Score:
    """Evaluate calculated metrics against a nested, priority-weighted config.

    ``yaml_config`` must contain a ``T_threshold`` mapping (read here via the
    ``T0_threshold``, ``T1_threshold`` and ``T2_threshold`` keys) and exactly
    one module entry. The module entry is a three-level dict tree:

    - the module itself (level 1) holds metric groups (level 2);
    - metric groups hold individual metrics (level 3) with ``min``/``max``
      bounds.

    Every node carries ``name`` and ``priority`` keys.

    Aggregation rule (levels 1 and 2, and ``evaluate_single_case``): a node
    fails when, for any priority ``p`` in {0, 1, 2}, the number of failed
    children with that priority exceeds the ``Tp`` threshold.
    """

    def __init__(self, yaml_config):
        """Parse *yaml_config* and pre-compute the level-3 metric names.

        Args:
            yaml_config: Mapping with a ``T_threshold`` key and exactly one
                module configuration key.

        Raises:
            ValueError: If the config shape is invalid (see ``process_config``).
        """
        self.logger = LogManager().get_logger()  # shared global logger instance
        self.calculated_metrics = None
        self.config = yaml_config
        self.module_config = None
        self.module_name = None
        self.t_threshold = None
        self.process_config(self.config)
        # NOTE: the attribute keeps its historical "merics" spelling because
        # code outside this class may read it.
        self.level_3_merics = self._extract_level_3_metrics(self.module_config)
        self.result = {}

    def process_config(self, config_dict):
        """Split *config_dict* into the thresholds and the single module config.

        Stores the pieces on ``self.module_name``, ``self.module_config`` and
        ``self.t_threshold``, and logs them.

        Raises:
            ValueError: If ``T_threshold`` is missing, or if there is not
                exactly one other (module) key.
        """
        t_threshold = config_dict.get("T_threshold")
        if t_threshold is None:
            raise ValueError("配置中缺少 T_threshold 键")

        module_keys = [key for key in config_dict if key != "T_threshold"]
        if len(module_keys) != 1:
            raise ValueError("配置字典应包含且仅包含一个模块配置键")

        module_name = module_keys[0]
        module_config = config_dict[module_name]

        self.module_name = module_name
        self.module_config = module_config
        self.t_threshold = t_threshold
        self.logger.info(f'模块名称:{module_name}')
        self.logger.info(f'模块配置:{module_config}')
        self.logger.info(f'T_threshold:{t_threshold}')

    def _extract_level_3_metrics(self, d):
        """Return the ``name`` of every leaf metric dict nested inside *d*.

        A leaf is a dict that has no dict-valued entries, i.e. a level-3
        metric. Names are collected depth-first in config order. Non-dict
        input yields an empty list.
        """
        if not isinstance(d, dict):
            return []
        children = [value for value in d.values() if isinstance(value, dict)]
        if not children:
            return [d['name']] if 'name' in d else []
        names = []
        for child in children:
            # The recursive result must be accumulated. Discarding it would
            # return only the top-level name.
            names.extend(self._extract_level_3_metrics(child))
        return names

    def is_within_range(self, value, min_val, max_val):
        """Return True if ``min_val <= value <= max_val``.

        A bound of ``None`` means "unbounded on that side". NaN values compare
        False against any bound and therefore are out of range.
        """
        return ((min_val is None or min_val <= value)
                and (max_val is None or value <= max_val))

    def evaluate_level_3(self, metrics):
        """Evaluate one level-3 metric config against ``self.calculated_metrics``.

        A metric with no calculated value passes. Otherwise it passes only if
        the value lies within its ``min``/``max`` bounds, whatever its
        priority. The priority-specific thresholds are applied by the parent
        aggregation, not here.

        Returns:
            ``{name: {'result': bool, 'priority': priority}}``.
        """
        name = metrics.get('name')
        priority = metrics.get('priority')

        # Record the metric name once. Appending on every evaluation would grow
        # the list without bound.
        if name not in self.level_3_merics:
            self.level_3_merics.append(name)

        metric_value = self.calculated_metrics.get(name)
        passed = True
        if metric_value is not None:
            passed = self.is_within_range(
                metric_value, metrics.get('min'), metrics.get('max'))
        return {name: {'result': passed, 'priority': priority}}

    @staticmethod
    def _aggregate(node, children, priority, t0, t1, t2):
        """Write the threshold verdict over *children* into *node* and return it.

        Counts failed children (falsy ``result``) for priority 0, 1 and 2. The
        node fails if any count exceeds its threshold. Sets ``result``,
        ``priority`` and ``priority_{0,1,2}_count`` on *node*.
        """
        # Snapshot first: *children* may be a live view of *node*, which we
        # are about to mutate.
        children = list(children)
        counts = [
            sum(1 for child in children
                if child['priority'] == level and not child['result'])
            for level in (0, 1, 2)
        ]
        # Every exceeded threshold must set the verdict explicitly. Previously
        # the T1/T2 branches left 'result' unset.
        node['result'] = not (counts[0] > t0 or counts[1] > t1 or counts[2] > t2)
        node['priority'] = priority
        node['priority_0_count'] = counts[0]
        node['priority_1_count'] = counts[1]
        node['priority_2_count'] = counts[2]
        return node

    def _config_thresholds(self):
        """Return ``(T0, T1, T2)`` read from ``self.t_threshold``."""
        return (self.t_threshold['T0_threshold'],
                self.t_threshold['T1_threshold'],
                self.t_threshold['T2_threshold'])

    def evaluate_level_2(self, metrics):
        """Evaluate a metric group (level 2) and aggregate its level-3 results.

        Returns:
            ``{group_name: {metric_name: metric_result, ..., 'result',
            'priority', 'priority_0_count', 'priority_1_count',
            'priority_2_count'}}``.
        """
        name = metrics.get('name')
        node = {}
        for key, sub_metrics in metrics.items():
            # Skip the node's own attributes and any non-metric scalar values.
            if key in ('name', 'priority') or not isinstance(sub_metrics, dict):
                continue
            node.update(self.evaluate_level_3(sub_metrics))
        self._aggregate(node, node.values(), metrics.get('priority'),
                        *self._config_thresholds())
        return {name: node}

    def evaluate_level_1(self):
        """Evaluate the whole module (level 1) from ``self.module_config``.

        Returns:
            ``{module_name: {group_name: group_result, ..., 'result',
            'priority', 'priority_0_count', 'priority_1_count',
            'priority_2_count'}}``.
        """
        name = self.module_config.get('name')
        node = {}
        for key, group in self.module_config.items():
            if key in ('name', 'priority') or not isinstance(group, dict):
                continue
            node.update(self.evaluate_level_2(group))
        self._aggregate(node, node.values(), self.module_config.get('priority'),
                        *self._config_thresholds())
        return {name: node}

    def evaluate(self, calculated_metrics):
        """Evaluate *calculated_metrics* and store the result in ``self.result``.

        Args:
            calculated_metrics: Mapping of level-3 metric name to its value.

        Returns:
            The level-1 result produced by ``evaluate_level_1``.
        """
        self.calculated_metrics = calculated_metrics
        self.result = self.evaluate_level_1()
        return self.result

    def evaluate_single_case(self, case_name, priority, json_dict):
        """Aggregate already-evaluated sub-results into one case-level verdict.

        Args:
            case_name: Key for the returned result.
            priority: Priority recorded on the case node.
            json_dict: Mapping of sub-result name to a dict that has at least
                ``result`` and ``priority`` keys.

        Returns:
            ``{case_name: {'result', 'priority', 'priority_N_count', ...,
            **json_dict}}``. Entries of *json_dict* are merged in last.

        Note:
            Thresholds come from the module-level ``config.T0/T1/T2`` rather
            than ``self.t_threshold``. This matches the original behaviour.
        """
        node = self._aggregate({}, json_dict.values(), priority,
                               config.T0, config.T1, config.T2)
        node.update(json_dict)
        return {case_name: node}
import yaml
def load_thresholds(config_path: str) -> Dict[str, int]:
    """Load the T0/T1/T2 priority thresholds from a YAML config file.

    Args:
        config_path: Path to a YAML file whose ``T_threshold`` mapping has
            ``T0_threshold``, ``T1_threshold`` and ``T2_threshold`` keys.

    Returns:
        ``{"T0": ..., "T1": ..., "T2": ...}`` with the values as stored in
        the file.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        KeyError: If a required key is missing.
    """
    # Binary mode lets PyYAML detect the stream encoding itself (UTF-8 by
    # default, UTF-16 via BOM), as the YAML spec requires. Text mode would use
    # the platform locale encoding (e.g. cp936 on Chinese Windows), which can
    # mis-decode UTF-8 config files.
    with open(config_path, 'rb') as f:
        # safe_load: never construct arbitrary Python objects from the file.
        yaml_config = yaml.safe_load(f)
    # Named so it does not shadow the module-level ``config`` import.
    t_threshold = yaml_config['T_threshold']
    return {
        "T0": t_threshold['T0_threshold'],
        "T1": t_threshold['T1_threshold'],
        "T2": t_threshold['T2_threshold'],
    }

def get_overall_result(report: Dict[str, Any], config_path: str, *,
                       categories=None) -> Dict[str, Any]:
    """Add an overall pass/fail verdict to an evaluation report.

    Each inspected category that is present in *report* and has a falsy
    ``result`` increments the failure counter for its ``priority`` (0, 1 or
    2). Other priority values are ignored. The overall result is True only
    when no counter exceeds its threshold, as loaded from *config_path*.

    Args:
        report: Original evaluation report. Each inspected category entry must
            support ``.get``. A missing ``result`` counts as a pass.
        config_path: Path to the YAML config that holds ``T_threshold``.
        categories: Optional iterable of category keys to inspect. Defaults
            to ``function``, ``safety``, ``comfort``, ``traffic`` and
            ``efficient``.

    Returns:
        A shallow copy of *report* with ``overall_result`` (bool) and
        ``threshold_checks`` (the three thresholds plus ``actual_counts``)
        added. Nested category entries are shared with *report*, not copied.
    """
    thresholds = load_thresholds(config_path)
    if categories is None:
        categories = ('function', 'safety', 'comfort', 'traffic', 'efficient')

    counters = {'p0': 0, 'p1': 0, 'p2': 0}
    for category in categories:
        if category not in report:
            continue
        entry = report[category]
        # Only failed categories are counted.
        if entry.get('result', True):
            continue
        priority = entry.get('priority')
        if priority == 0:
            counters['p0'] += 1
        elif priority == 1:
            counters['p1'] += 1
        elif priority == 2:
            counters['p2'] += 1

    exceeded = (
        counters['p0'] > thresholds['T0']
        or counters['p1'] > thresholds['T1']
        or counters['p2'] > thresholds['T2']
    )

    processed_report = report.copy()
    processed_report['overall_result'] = not exceeded
    processed_report['threshold_checks'] = {
        'T0_threshold': thresholds['T0'],
        'T1_threshold': thresholds['T1'],
        'T2_threshold': thresholds['T2'],
        'actual_counts': counters,
    }
    return processed_report
  
  
def main():
    """Script entry point; currently a placeholder that performs no work."""
    return None


if __name__ == "__main__":
    main()