#!/usr/bin/env python3
# evaluator_enhanced.py
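#
# Example invocation (paths are illustrative; adjust them to your environment):
#   python evaluator_enhanced.py \
#       --dataPath ./data/scenario_01 \
#       --allConfigPath ./config/all_metrics_config.yaml \
#       --baseConfigPath ./config/builtin_metrics_config.yaml \
#       --customConfigPath ./config/custom_metrics_config.yaml \
#       --customMetricsPath ./custom_metrics \
#       --reportPath ./reports \
#       --logPath ./test.log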
import sys
import warnings
import time
import importlib
import importlib.util
import yaml
from pathlib import Path
import argparse
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import lru_cache
from typing import Dict, Any, List, Optional, Type, Tuple, Callable, Union
from datetime import datetime
import logging
import traceback
import json
import inspect


# Constants
DEFAULT_WORKERS = 4
CUSTOM_METRIC_PREFIX = "metric_"
CUSTOM_METRIC_FILE_PATTERN = "*.py"

# Determine the project root path safely (supports frozen/PyInstaller builds via sys._MEIPASS)
if hasattr(sys, "_MEIPASS"):
    _ROOT_PATH = Path(sys._MEIPASS)
else:
    _ROOT_PATH = Path(__file__).resolve().parent.parent

sys.path.insert(0, str(_ROOT_PATH))

class ConfigManager:
    """配置管理组件"""
    
    def __init__(self, logger: logging.Logger):
        self.logger = logger
        self.base_config: Dict[str, Any] = {}
        self.custom_config: Dict[str, Any] = {}
        self.merged_config: Dict[str, Any] = {}
    
    def split_configs(self, all_metrics_path: Path, builtin_metrics_path: Path, custom_metrics_path: Path) -> None:
        """从all_metrics_config.yaml拆分成内置和自定义配置"""
        try:
            with open(all_metrics_path, 'r', encoding='utf-8') as f:
                all_metrics_dict = yaml.safe_load(f) or {}
            with open(builtin_metrics_path, 'r', encoding='utf-8') as f:
                builtin_metrics_dict = yaml.safe_load(f) or {}
            custom_metrics_dict = self._find_custom_metrics(all_metrics_dict, builtin_metrics_dict)
            if custom_metrics_dict:
                with open(custom_metrics_path, 'w', encoding='utf-8') as f:
                    yaml.dump(custom_metrics_dict, f, allow_unicode=True, sort_keys=False, indent=2)
                self.logger.info(f"Split configs: custom metrics saved to {custom_metrics_path}")
        except Exception as err:
            self.logger.error(f"Failed to split configs: {str(err)}")
            raise
    
    def _find_custom_metrics(self, all_metrics, builtin_metrics, current_path=""):
        """递归比较找出自定义指标"""
        custom_metrics = {}
        
        if isinstance(all_metrics, dict) and isinstance(builtin_metrics, dict):
            for key in all_metrics:
                if key not in builtin_metrics:
                    custom_metrics[key] = all_metrics[key]
                else:
                    child_custom = self._find_custom_metrics(
                        all_metrics[key], 
                        builtin_metrics[key],
                        f"{current_path}.{key}" if current_path else key
                    )
                    if child_custom:
                        custom_metrics[key] = child_custom
        elif all_metrics != builtin_metrics:
            return all_metrics
        
        if custom_metrics:
            return self._ensure_structure(custom_metrics, all_metrics, current_path)
        return None
    
    def _ensure_structure(self, metrics_dict, full_dict, path):
        """确保每级包含name和priority"""
        if not isinstance(metrics_dict, dict):
            return metrics_dict
        
        current = full_dict
        for key in path.split('.'):
            if key in current:
                current = current[key]
            else:
                break
        
        result = {}
        if isinstance(current, dict):
            if 'name' in current:
                result['name'] = current['name']
            if 'priority' in current:
                result['priority'] = current['priority']
        
        for key, value in metrics_dict.items():
            if key not in ['name', 'priority']:
                result[key] = self._ensure_structure(value, full_dict, f"{path}.{key}" if path else key)
        
        return result

    def load_configs(self, all_config_path: Optional[Path], builtin_metrics_path: Optional[Path], custom_metrics_path: Optional[Path]) -> Dict[str, Any]:
        """加载并合并配置"""
        # 自动拆分配置
        
        if all_config_path.exists():
            self.split_configs(all_config_path, builtin_metrics_path, custom_metrics_path)
            
        self.base_config = self._safe_load_config(builtin_metrics_path) if builtin_metrics_path else {}
        self.custom_config = self._safe_load_config(custom_metrics_path) if custom_metrics_path else {}
        self.merged_config = self._merge_configs(self.base_config, self.custom_config)
        return self.merged_config
    
    def _safe_load_config(self, config_path: Path) -> Dict[str, Any]:
        """安全加载YAML配置"""
        try:
            if not config_path.exists():
                self.logger.warning(f"Config file not found: {config_path}")
                return {}
            with config_path.open('r', encoding='utf-8') as f:
                config_dict = yaml.safe_load(f) or {}
                self.logger.info(f"Loaded config: {config_path}")
                return config_dict
        except Exception as err:
            self.logger.error(f"Failed to load config {config_path}: {str(err)}")
            return {}
    
    def _merge_configs(self, builtin_config: Dict, custom_config: Dict) -> Dict:
        """智能合并配置"""
        merged_config = builtin_config.copy()
        for level1_key, level1_value in custom_config.items():
            if not isinstance(level1_value, dict) or 'name' not in level1_value:
                if level1_key not in merged_config:
                    merged_config[level1_key] = level1_value
                continue
            if level1_key not in merged_config:
                merged_config[level1_key] = level1_value
            else:
                for level2_key, level2_value in level1_value.items():
                    if level2_key in ['name', 'priority']:
                        continue
                    if isinstance(level2_value, dict):
                        if level2_key not in merged_config[level1_key]:
                            merged_config[level1_key][level2_key] = level2_value
                        else:
                            for level3_key, level3_value in level2_value.items():
                                if level3_key in ['name', 'priority']:
                                    continue
                                if isinstance(level3_value, dict):
                                    if level3_key not in merged_config[level1_key][level2_key]:
                                        merged_config[level1_key][level2_key][level3_key] = level3_value
        return merged_config
    
    def get_config(self) -> Dict[str, Any]:
        return self.merged_config
    
    def get_base_config(self) -> Dict[str, Any]:
        return self.base_config
    
    def get_custom_config(self) -> Dict[str, Any]:
        return self.custom_config
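
# A minimal sketch of the metrics config layout that ConfigManager expects,
# assuming a three-level hierarchy where each node carries 'name' and
# 'priority' (the metric names below are illustrative, not the project's):
#
#   T_threshold:
#     T0_threshold: 0
#     T1_threshold: 0
#     T2_threshold: 0
#   safety:
#     name: safety
#     priority: 0
#     collision:
#       name: collision
#       priority: 0
#       collision_rate:
#         name: collision_rate
#         priority: 0
#
# _merge_configs() copies custom level-2/level-3 entries into this structure
# without overwriting built-in ones and treats 'name'/'priority' as metadata.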

class MetricLoader:
    """指标加载器组件"""
    
    def __init__(self, logger: logging.Logger, config_manager: ConfigManager):
        self.logger = logger
        self.config_manager = config_manager
        self.metric_modules: Dict[str, Type] = {}
        self.custom_metric_modules: Dict[str, Any] = {}
    
    def load_builtin_metrics(self) -> Dict[str, Type]:
        """加载内置指标模块"""
        module_mapping = {
            "safety": ("modules.metric.safety", "SafeManager"),
            "comfort": ("modules.metric.comfort", "ComfortManager"),
            "traffic": ("modules.metric.traffic", "TrafficManager"),
            "efficient": ("modules.metric.efficient", "EfficientManager"),
            "function": ("modules.metric.function", "FunctionManager"),
        }
        
        self.metric_modules = {
            name: self._load_module(*info)
            for name, info in module_mapping.items()
        }
        
        self.logger.info(f"Loaded builtin metrics: {', '.join(self.metric_modules.keys())}")
        return self.metric_modules
    
    @lru_cache(maxsize=32)
    def _load_module(self, module_path: str, class_name: str) -> Type:
        """动态加载Python模块"""
        try:
            module = __import__(module_path, fromlist=[class_name])
            return getattr(module, class_name)
        except (ImportError, AttributeError) as e:
            self.logger.error(f"Failed to load module: {module_path}.{class_name} - {str(e)}")
            raise
    
    def load_custom_metrics(self, custom_metrics_path: Optional[Path]) -> Dict[str, Any]:
        """加载自定义指标模块"""
        if not custom_metrics_path or not custom_metrics_path.is_dir():
            self.logger.info("No custom metrics path or path not exists")
            return {}

        loaded_count = 0
        for py_file in custom_metrics_path.glob(CUSTOM_METRIC_FILE_PATTERN):
            if py_file.name.startswith(CUSTOM_METRIC_PREFIX):
                if self._process_custom_metric_file(py_file):
                    loaded_count += 1
        
        self.logger.info(f"Loaded {loaded_count} custom metric modules")
        return self.custom_metric_modules
    
    def _process_custom_metric_file(self, file_path: Path) -> bool:
        """处理单个自定义指标文件"""
        try:
            metric_key = self._validate_metric_file(file_path)
            
            module_name = f"custom_metric_{file_path.stem}"
            spec = importlib.util.spec_from_file_location(module_name, file_path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            
            from modules.lib.metric_registry import BaseMetric
            metric_class = None
            
            for name, obj in inspect.getmembers(module):
                if (inspect.isclass(obj) and 
                    issubclass(obj, BaseMetric) and 
                    obj != BaseMetric):
                    metric_class = obj
                    break
            
            if metric_class:
                self.custom_metric_modules[metric_key] = {
                    'type': 'class',
                    'module': module,
                    'class': metric_class
                }
                self.logger.info(f"Loaded class-based custom metric: {metric_key}")
            elif hasattr(module, 'evaluate'):
                self.custom_metric_modules[metric_key] = {
                    'type': 'function',
                    'module': module
                }
                self.logger.info(f"Loaded function-based custom metric: {metric_key}")
            else:
                raise AttributeError(f"Missing evaluate() function or BaseMetric subclass: {file_path.name}")
                
            return True
        except ValueError as e:
            self.logger.warning(str(e))
            return False
        except Exception as e:
            self.logger.error(f"Failed to load custom metric {file_path}: {str(e)}")
            return False
    
    def _validate_metric_file(self, file_path: Path) -> str:
        """验证自定义指标文件命名规范"""
        stem = file_path.stem[len(CUSTOM_METRIC_PREFIX):]
        parts = stem.split('_')
        if len(parts) < 3:
            raise ValueError(f"Invalid custom metric filename: {file_path.name}, should be metric_<level1>_<level2>_<level3>.py")

        level1, level2, level3 = parts[:3]
        if not self._is_metric_configured(level1, level2, level3):
            raise ValueError(f"Unconfigured metric: {level1}.{level2}.{level3}")
        return f"{level1}.{level2}.{level3}"
    
    def _is_metric_configured(self, level1: str, level2: str, level3: str) -> bool:
        """检查指标是否在配置中注册"""
        custom_config = self.config_manager.get_custom_config()
        try:
            return (level1 in custom_config and 
                    isinstance(custom_config[level1], dict) and
                    level2 in custom_config[level1] and
                    isinstance(custom_config[level1][level2], dict) and
                    level3 in custom_config[level1][level2] and
                    isinstance(custom_config[level1][level2][level3], dict))
        except Exception:
            return False
    
    def get_builtin_metrics(self) -> Dict[str, Type]:
        return self.metric_modules
    
    def get_custom_metrics(self) -> Dict[str, Any]:
        return self.custom_metric_modules
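
# A minimal sketch of a custom metric file that MetricLoader would accept,
# assuming the metric is registered in the custom config under
# comfort.ride.jerk (illustrative names); the file would be named
# metric_comfort_ride_jerk.py and live in the custom metrics directory:
#
#   def evaluate(data):
#       # 'data' is the preprocessed data object passed to the engine
#       return {"jerk": {"value": 0.42}}
#
# A class-based variant would instead subclass
# modules.lib.metric_registry.BaseMetric, take 'data' in its constructor,
# and implement calculate() returning a dict.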

class EvaluationEngine:
    """评估引擎组件"""
    
    def __init__(self, logger: logging.Logger, config_manager: ConfigManager, metric_loader: MetricLoader):
        self.logger = logger
        self.config_manager = config_manager
        self.metric_loader = metric_loader
    
    def evaluate(self, data: Any) -> Dict[str, Any]:
        """执行评估流程"""
        raw_results = self._collect_builtin_metrics(data)
        custom_results = self._collect_custom_metrics(data)
        return self._process_merged_results(raw_results, custom_results)
    
    def _collect_builtin_metrics(self, data: Any) -> Dict[str, Any]:
        """收集内置指标结果"""
        metric_modules = self.metric_loader.get_builtin_metrics()
        raw_results: Dict[str, Any] = {}
        
        # Guard against an empty module map (max_workers must be >= 1)
        with ThreadPoolExecutor(max_workers=len(metric_modules) or DEFAULT_WORKERS) as executor:
            futures = {
                executor.submit(self._run_module, module, data, module_name): module_name
                for module_name, module in metric_modules.items()
            }

            for future in as_completed(futures):
                module_name = futures[future]
                try:
                    result = future.result()
                    raw_results[module_name] = result[module_name]
                except Exception as e:
                    self.logger.error(
                        f"{module_name} evaluation failed: {str(e)}",
                        exc_info=True,
                    )
                    raw_results[module_name] = {
                        "status": "error",
                        "message": str(e),
                        "timestamp": datetime.now().isoformat(),
                    }
        
        return raw_results
    
    def _collect_custom_metrics(self, data: Any) -> Dict[str, Dict]:
        """收集自定义指标结果"""
        custom_metrics = self.metric_loader.get_custom_metrics()
        if not custom_metrics:
            return {}
            
        custom_results = {}
        
        for metric_key, metric_info in custom_metrics.items():
            try:
                level1, level2, level3 = metric_key.split('.')
                
                if metric_info['type'] == 'class':
                    metric_class = metric_info['class']
                    metric_instance = metric_class(data)
                    metric_result = metric_instance.calculate()
                else:
                    module = metric_info['module']
                    metric_result = module.evaluate(data)
                
                # Merge into the level-1 bucket instead of overwriting it, so
                # multiple custom metrics under the same level-1 category coexist
                if isinstance(metric_result, dict):
                    custom_results.setdefault(level1, {}).update(metric_result)
                else:
                    custom_results[level1] = metric_result
                
                self.logger.info(f"Calculated custom metric: {level1}.{level2}.{level3}")
                
            except Exception as e:
                self.logger.error(f"Custom metric {metric_key} failed: {str(e)}")
                
                try:
                    level1, level2, level3 = metric_key.split('.')
                    
                    custom_results[level1] = {
                        "status": "error",
                        "message": str(e),
                        "timestamp": datetime.now().isoformat(),
                    }
                except Exception:
                    pass
        
        return custom_results
    
    def _process_merged_results(self, raw_results: Dict, custom_results: Dict) -> Dict:
        """处理合并后的评估结果"""
        from modules.lib.score import Score
        final_results = {}
        merged_config = self.config_manager.get_config()

        for level1, level1_data in raw_results.items():
            if level1 in custom_results:
                level1_data.update(custom_results[level1])

            try:
                evaluator = Score(merged_config, level1)
                final_results.update(evaluator.evaluate(level1_data))
            except Exception as e:
                final_results[level1] = self._format_error(e)

        for level1, level1_data in custom_results.items():
            if level1 not in raw_results:
                try:
                    evaluator = Score(merged_config, level1)
                    final_results.update(evaluator.evaluate(level1_data))
                except Exception as e:
                    final_results[level1] = self._format_error(e)

        return final_results
        
    def _format_error(self, e: Exception) -> Dict:
        return {
            "status": "error",
            "message": str(e),
            "timestamp": datetime.now().isoformat()
        }
                
    def _run_module(self, module_class: Any, data: Any, module_name: str) -> Dict[str, Any]:
        """执行单个评估模块"""
        try:
            instance = module_class(data)
            return {module_name: instance.report_statistic()}
        except Exception as e:
            self.logger.error(f"{module_name} execution error: {str(e)}", exc_info=True)
            return {module_name: {"error": str(e)}}

class LoggingManager:
    """日志管理组件"""
    
    def __init__(self, log_path: Path):
        self.log_path = log_path
        self.logger = self._init_logger()
    
    def _init_logger(self) -> logging.Logger:
        """初始化日志系统"""
        try:
            from modules.lib.log_manager import LogManager
            log_manager = LogManager(self.log_path)
            return log_manager.get_logger()
        except (ImportError, PermissionError, IOError) as e:
            logger = logging.getLogger("evaluator")
            logger.setLevel(logging.INFO)
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
            logger.addHandler(console_handler)
            logger.warning(f"Failed to init standard logger: {str(e)}, using fallback logger")
            return logger
    
    def get_logger(self) -> logging.Logger:
        return self.logger

class DataProcessor:
    """数据处理组件"""
    
    def __init__(self, logger: logging.Logger, data_path: Path, config_path: Optional[Path] = None):
        self.logger = logger
        self.data_path = data_path
        self.config_path = config_path
        self.processor = self._load_processor()
        self.case_name = self.data_path.name
    
    def _load_processor(self) -> Any:
        """加载数据处理器"""
        try:
            from modules.lib import data_process
            return data_process.DataPreprocessing(self.data_path, self.config_path)
        except ImportError as e:
            self.logger.error(f"Failed to load data processor: {str(e)}")
            raise RuntimeError(f"Failed to load data processor: {str(e)}") from e
    
    def validate(self) -> None:
        """验证数据路径"""
        if not self.data_path.exists():
            raise FileNotFoundError(f"Data path not exists: {self.data_path}")
        if not self.data_path.is_dir():
            raise NotADirectoryError(f"Invalid data directory: {self.data_path}")

class EvaluationPipeline:
    """评估流水线控制器"""
    
    def __init__(self, all_config_path: str, base_config_path: str, log_path: str, data_path: str, report_path: str, 
                 custom_metrics_path: Optional[str] = None, custom_config_path: Optional[str] = None):
        # Path initialization
        self.all_config_path = Path(all_config_path) if all_config_path else None
        self.base_config_path = Path(base_config_path) if base_config_path else None
        self.custom_config_path = Path(custom_config_path) if custom_config_path else None
        self.data_path = Path(data_path)
        self.report_path = Path(report_path)
        self.custom_metrics_path = Path(custom_metrics_path) if custom_metrics_path else None
        
        # Logging
        self.logging_manager = LoggingManager(Path(log_path))
        self.logger = self.logging_manager.get_logger()
        # Configuration
        self.config_manager = ConfigManager(self.logger)
        self.config = self.config_manager.load_configs(
            self.all_config_path, self.base_config_path, self.custom_config_path
        )
        # Metric loading
        self.metric_loader = MetricLoader(self.logger, self.config_manager)
        self.metric_loader.load_builtin_metrics()
        self.metric_loader.load_custom_metrics(self.custom_metrics_path)
        # Data processing and evaluation engine
        self.data_processor = DataProcessor(self.logger, self.data_path, self.all_config_path)
        self.evaluation_engine = EvaluationEngine(self.logger, self.config_manager, self.metric_loader)
    
    def execute(self) -> Dict[str, Any]:
        """执行评估流水线"""
        try:
            self.data_processor.validate()
            
            self.logger.info(f"Start evaluation: {self.data_path.name}")
            start_time = time.perf_counter()
            results = self.evaluation_engine.evaluate(self.data_processor.processor)
            elapsed_time = time.perf_counter() - start_time
            self.logger.info(f"Evaluation completed, time: {elapsed_time:.2f}s")
            
            report = self._generate_report(self.data_processor.case_name, results)
            return report
            
        except Exception as e:
            self.logger.critical(f"Evaluation failed: {str(e)}", exc_info=True)
            return {"error": str(e), "traceback": traceback.format_exc()}
    
    def _add_overall_result(self, report: Dict[str, Any]) -> Dict[str, Any]:
        """处理评测报告并添加总体结果字段"""
        # 加载阈值参数
        thresholds = {
            "T0": self.config['T_threshold']['T0_threshold'],
            "T1": self.config['T_threshold']['T1_threshold'],
            "T2": self.config['T_threshold']['T2_threshold']
        }
        
        # Initialize failure counters by priority
        counters = {'p0': 0, 'p1': 0, 'p2': 0}
        
        # Iterate over all report keys, including built-in and custom level-1 metrics
        for category, category_data in report.items():
            # Skip non-metric keys (e.g. metadata)
            if not isinstance(category_data, dict) or category == "metadata":
                continue
                
            # If this category failed, increment the counter for its priority
            if not category_data.get('result', True):
                priority = category_data.get('priority')
                if priority == 0:
                    counters['p0'] += 1
                elif priority == 1:
                    counters['p1'] += 1
                elif priority == 2:
                    counters['p2'] += 1
        
        # Threshold check logic
        thresholds_exceeded = (
            counters['p0'] > thresholds['T0'],
            counters['p1'] > thresholds['T1'],
            counters['p2'] > thresholds['T2']
        )
        
        # Build the processed report
        processed_report = report.copy()
        processed_report['overall_result'] = not any(thresholds_exceeded)
        
        # Add threshold statistics
        processed_report['threshold_checks'] = {
            'T0_threshold': thresholds['T0'],
            'T1_threshold': thresholds['T1'],
            'T2_threshold': thresholds['T2'],
            'actual_counts': counters
        }
        
        self.logger.info(f"Added overall result: {processed_report['overall_result']}")
        return processed_report
        
    def _generate_report(self, case_name: str, results: Dict[str, Any]) -> Dict[str, Any]:
        """生成评估报告"""
        from modules.lib.common import dict2json
        
        self.report_path.mkdir(parents=True, exist_ok=True)
        
        results["metadata"] = {
            "case_name": case_name,
            "timestamp": datetime.now().isoformat(),
            "version": "3.1.0",
        }
        
        # Add the overall result evaluation
        results = self._add_overall_result(results)
        
        report_file = self.report_path / f"{case_name}_report.json"
        dict2json(results, report_file)
        self.logger.info(f"Report generated: {report_file}")
        
        return results
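
# Illustrative shape of the final report written by _generate_report()
# (field values are made up; per-category entries come from Score.evaluate()):
#
#   {
#     "safety":  {"result": true, "priority": 0, ...},
#     "comfort": {"result": false, "priority": 1, ...},
#     "metadata": {"case_name": "...", "timestamp": "...", "version": "3.1.0"},
#     "overall_result": false,
#     "threshold_checks": {
#       "T0_threshold": 0, "T1_threshold": 0, "T2_threshold": 0,
#       "actual_counts": {"p0": 0, "p1": 1, "p2": 0}
#     }
#   }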

def main():
    """命令行入口"""
    parser = argparse.ArgumentParser(
        description="Autonomous Driving Evaluation System V3.1",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    
    parser.add_argument(
        "--logPath",
        type=str,
        default="test.log",
        help="Log file path",
    )
    parser.add_argument(
        "--dataPath",
        type=str,
        default=r"D:\Cicv\招远\AD_GBT41798-2022_TrafficSignalRecognitionAndResponse_LST_01",
        help="Input data directory",
    )
    
    parser.add_argument(
        "--allConfigPath",
        type=str,
        default=r"D:\Cicv\招远\zhaoyuan\config\all_metrics_config.yaml",
        help="Full metrics config file path (built-in + custom)",
    )
    
    parser.add_argument(
        "--baseConfigPath",
        type=str,
        default=r"D:\Cicv\招远\zhaoyuan\config\builtin_metrics_config.yaml",
        help="Built-in metrics config file path",
    )
    parser.add_argument(
        "--reportPath",
        type=str,
        default="reports",
        help="Output report directory",
    )
    parser.add_argument(
        "--customMetricsPath",
        type=str,
        default="custom_metrics",
        help="Custom metrics scripts directory (optional)",
    )
    parser.add_argument(
        "--customConfigPath",
        type=str,
        default=r"D:\Cicv\招远\zhaoyuan\config\custom_metrics_config.yaml",
        help="Custom metrics config path (optional)",
    )
    
    args = parser.parse_args()

    try:
        pipeline = EvaluationPipeline(
            all_config_path=args.allConfigPath,
            base_config_path=args.baseConfigPath,
            log_path=args.logPath, 
            data_path=args.dataPath, 
            report_path=args.reportPath, 
            custom_metrics_path=args.customMetricsPath, 
            custom_config_path=args.customConfigPath
        )
        
        start_time = time.perf_counter()
        result = pipeline.execute()
        elapsed_time = time.perf_counter() - start_time

        if "error" in result:
            print(f"Evaluation failed: {result['error']}")
            sys.exit(1)

        print(f"Evaluation completed, total time: {elapsed_time:.2f}s")
        print(f"Report path: {pipeline.report_path}")
        
    except KeyboardInterrupt:
        print("\nUser interrupted")
        sys.exit(130)
    except Exception as e:
        print(f"Execution error: {str(e)}")
        traceback.print_exc()
        sys.exit(1)

if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    main()