# evaluation_engine.py
import sys
import warnings
import time
import importlib.util  # explicit submodule import; "import importlib" alone does not guarantee importlib.util
import copy
import yaml
from pathlib import Path
import argparse
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from typing import Dict, Any, List, Optional
from datetime import datetime

# Force-import all modules that may be loaded dynamically


# Resolve the project root path safely (dynamic path management)
# Detect whether we are running as a frozen (compiled) build
if hasattr(sys, "_MEIPASS"):
    # Frozen build: use PyInstaller's temporary resource directory
    _ROOT_PATH = Path(sys._MEIPASS)
else:
    # Development mode: use the original project layout
    _ROOT_PATH = Path(__file__).resolve().parent.parent

sys.path.insert(0, str(_ROOT_PATH))
print(f"Project root: {_ROOT_PATH}")
print(f"sys.path: {sys.path}")


class EvaluationCore:
    """评估引擎核心类(单例模式)"""

    _instance = None

    def __new__(cls, logPath: str, configPath: Optional[str] = None, customConfigPath: Optional[str] = None, customMetricsPath: Optional[str] = None):
        # Singleton: only the first call's arguments take effect; subsequent
        # calls return the already-initialized instance unchanged.
        if not cls._instance:
            cls._instance = super().__new__(cls)
            cls._instance._init(logPath, configPath, customConfigPath, customMetricsPath)
        return cls._instance

    def _init(self, logPath: Optional[str] = None, configPath: Optional[str] = None, customConfigPath: Optional[str] = None, customMetricsPath: Optional[str] = None) -> None:
        """Initialize engine components."""
        self.log_path = logPath
        self.config_path = configPath
        self.custom_config_path = customConfigPath
        self.custom_metrics_path = customMetricsPath

        # Loaded configurations
        self.metrics_config = {}
        self.custom_metrics_config = {}
        self.merged_config = {}  # built-in and custom configs merged together

        # Dynamically loaded custom metric script modules
        self.custom_metrics_modules = {}

        self._init_log_system()
        self._load_configs()  # load and merge the configurations
        self._init_metrics()
        self._load_custom_metrics()

    def _init_log_system(self) -> None:
        """Centralized log management."""
        try:
            from modules.lib.log_manager import LogManager

            log_manager = LogManager(self.log_path)
            self.logger = log_manager.get_logger()
        except (ImportError, PermissionError, IOError) as e:
            sys.stderr.write(f"Log system initialization failed: {str(e)}\n")
            sys.exit(1)

    def _init_metrics(self) -> None:
        """Initialize the evaluation modules (strategy pattern)."""
        self.metric_modules = {
            "safety": self._load_module("modules.metric.safety", "SafeManager"),
            "comfort": self._load_module("modules.metric.comfort", "ComfortManager"),
            "traffic": self._load_module("modules.metric.traffic", "TrafficManager"),
            "efficient": self._load_module("modules.metric.efficient", "EfficientManager"),
            "function": self._load_module("modules.metric.function", "FunctionManager"),
        }

    @lru_cache(maxsize=32)
    def _load_module(self, module_path: str, class_name: str) -> Any:
        """Dynamically load an evaluation module (cached).

        Note: lru_cache on a method keeps a reference to self; that is
        acceptable here because the class is a process-wide singleton.
        """
        try:
            __import__(module_path)
            return getattr(sys.modules[module_path], class_name)
        except (ImportError, AttributeError) as e:
            self.logger.error(f"Module load failed: {module_path}.{class_name} - {str(e)}")
            raise

    def _load_configs(self) -> None:
        """Load and merge the built-in and custom metric configurations."""
        # Load the built-in metrics configuration
        if self.config_path and Path(self.config_path).exists():
            try:
                with open(self.config_path, 'r', encoding='utf-8') as f:
                    # safe_load returns None for an empty file; fall back to {}
                    self.metrics_config = yaml.safe_load(f) or {}
                self.logger.info(f"Loaded built-in metrics config: {self.config_path}")
            except Exception as e:
                self.logger.error(f"Failed to load built-in metrics config: {str(e)}")
                self.metrics_config = {}

        # Load the custom metrics configuration
        if self.custom_config_path and Path(self.custom_config_path).exists():
            try:
                with open(self.custom_config_path, 'r', encoding='utf-8') as f:
                    self.custom_metrics_config = yaml.safe_load(f) or {}
                self.logger.info(f"Loaded custom metrics config: {self.custom_config_path}")
            except Exception as e:
                self.logger.error(f"Failed to load custom metrics config: {str(e)}")
                self.custom_metrics_config = {}

        # Merge the two configurations
        self.merged_config = self._merge_configs(self.metrics_config, self.custom_metrics_config)
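
    # A minimal sketch of the YAML shape the loaders above expect, inferred
    # from how _merge_configs and _check_metric_in_config walk the nested
    # dicts (all metric keys and values below are hypothetical):
    #
    #   safety:              # level-1 metric
    #     name: Safety
    #     priority: 0
    #     collision:         # level-2 metric
    #       name: Collision
    #       ttc:             # level-3 metric
    #         name: TTC
    #         priority: 0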

    def _merge_configs(self, base_config: Dict, custom_config: Dict) -> Dict:
        """
        合并内置指标和自定义指标配置
        
        策略:
        1. 如果自定义指标与内置指标有相同的一级指标,则合并其下的二级指标
        2. 如果自定义指标与内置指标有相同的二级指标,则合并其下的三级指标
        3. 如果是全新的指标,则直接添加
        """
        # Deep copy so that merging never mutates the built-in config in place
        merged = copy.deepcopy(base_config)
        
        for level1_key, level1_value in custom_config.items():
            # Skip non-metric configuration entries (e.g. vehicle settings)
            if not isinstance(level1_value, dict) or 'name' not in level1_value:
                if level1_key not in merged:
                    merged[level1_key] = level1_value
                continue

            if level1_key not in merged:
                # Entirely new level-1 metric
                merged[level1_key] = level1_value
            else:
                # Merge into the existing level-1 metric
                for level2_key, level2_value in level1_value.items():
                    if level2_key in ('name', 'priority'):
                        continue

                    if isinstance(level2_value, dict):
                        if level2_key not in merged[level1_key]:
                            # New level-2 metric
                            merged[level1_key][level2_key] = level2_value
                        else:
                            # Merge into the existing level-2 metric
                            for level3_key, level3_value in level2_value.items():
                                if level3_key in ('name', 'priority'):
                                    continue

                                if isinstance(level3_value, dict):
                                    if level3_key not in merged[level1_key][level2_key]:
                                        # New level-3 metric
                                        merged[level1_key][level2_key][level3_key] = level3_value

        return merged

    def _load_custom_metrics(self) -> None:
        """Load the custom metric scripts."""
        if not self.custom_metrics_path or not Path(self.custom_metrics_path).exists():
            return

        custom_metrics_dir = Path(self.custom_metrics_path)
        if not custom_metrics_dir.is_dir():
            self.logger.warning(f"Custom metrics path is not a directory: {custom_metrics_dir}")
            return

        # Walk the custom metric script directory
        for file_path in custom_metrics_dir.glob("*.py"):
            if file_path.name.startswith("metric_"):
                try:
                    # Parse the file name to recover the metric hierarchy
                    parts = file_path.stem[len("metric_"):].split('_')
                    if len(parts) < 3:
                        self.logger.warning(f"Custom metric script {file_path.name} does not follow the naming convention metric_<level1>_<level2>_<level3>.py")
                        continue

                    # Any extra underscores are treated as part of the level-3 name
                    level1, level2, level3 = parts[0], parts[1], '_'.join(parts[2:])

                    # Skip metrics that are absent from the custom configuration
                    if not self._check_metric_in_config(level1, level2, level3, self.custom_metrics_config):
                        self.logger.warning(f"Custom metric {level1}.{level2}.{level3} is not present in the config, skipping")
                        continue

                    # Load the script as a module
                    module_name = f"custom_metric_{level1}_{level2}_{level3}"
                    spec = importlib.util.spec_from_file_location(module_name, file_path)
                    module = importlib.util.module_from_spec(spec)
                    spec.loader.exec_module(module)

                    # The script must expose an evaluate() entry point
                    if not hasattr(module, 'evaluate'):
                        self.logger.warning(f"Custom metric script {file_path.name} lacks an evaluate function")
                        continue

                    # Keep a reference to the loaded module
                    key = f"{level1}.{level2}.{level3}"
                    self.custom_metrics_modules[key] = module
                    self.logger.info(f"Loaded custom metric script: {file_path.name}")

                except Exception as e:
                    self.logger.error(f"Failed to load custom metric script {file_path.name}: {str(e)}")

    def _check_metric_in_config(self, level1: str, level2: str, level3: str, config: Dict) -> bool:
        """Check whether a metric exists in the given configuration."""
        try:
            return (level1 in config and 
                    isinstance(config[level1], dict) and
                    level2 in config[level1] and
                    isinstance(config[level1][level2], dict) and
                    level3 in config[level1][level2] and
                    isinstance(config[level1][level2][level3], dict))
        except Exception:
            return False

    def parallel_evaluate(self, data: Any) -> Dict[str, Any]:
        """Parallel evaluation engine (dynamic thread pool)."""
        # Collect every evaluation result here
        results = {}

        # 1. Evaluate the built-in metrics first
        self._evaluate_built_in_metrics(data, results)

        # 2. Then evaluate the custom metrics and merge their results
        self._evaluate_and_merge_custom_metrics(data, results)
        
        return results
    
    def _evaluate_built_in_metrics(self, data: Any, results: Dict[str, Any]) -> None:
        """Evaluate the built-in metrics."""
        from modules.lib.score import Score

        # Thread count matches the number of modules
        with ThreadPoolExecutor(max_workers=len(self.metric_modules)) as executor:
            # Map each module name to its future
            futures = {
                module_name: executor.submit(
                    self._run_module, module, data, module_name
                )
                for module_name, module in self.metric_modules.items()
            }

            # Collect the results in module order
            for module_name, future in futures.items():
                try:
                    evaluator = Score(self.merged_config, module_name)
                    result_module = future.result()
                    result = evaluator.evaluate(result_module)
                    results.update(result)
                except Exception as e:
                    self.logger.error(
                        f"{module_name} evaluation failed: {str(e)}",
                        exc_info=True,
                        extra={"stack": True},  # record the full stack trace
                    )
                    # Store the error in a structured form
                    results[module_name] = {
                        "status": "error",
                        "message": str(e),
                        "timestamp": datetime.now().isoformat(),
                    }
    
    def _evaluate_and_merge_custom_metrics(self, data: Any, results: Dict[str, Any]) -> None:
        """Evaluate the custom metrics and merge their results."""
        if not self.custom_metrics_modules:
            return

        # Group the custom metrics by their level-1 metric
        grouped_metrics = {}
        for metric_key in self.custom_metrics_modules:
            level1 = metric_key.split('.')[0]
            grouped_metrics.setdefault(level1, []).append(metric_key)

        # Process each level-1 group
        for level1, metric_keys in grouped_metrics.items():
            # Is this a built-in level-1 metric?
            is_built_in = level1 in self.metrics_config and 'name' in self.metrics_config[level1]
            level1_name = self.merged_config[level1].get('name', level1) if level1 in self.merged_config else level1

            # For a built-in level-1 metric, merge into the existing results
            if is_built_in and level1_name in results:
                for metric_key in metric_keys:
                    self._evaluate_and_merge_single_metric(data, results, metric_key, level1_name)
            else:
                # For a new level-1 metric, create a fresh result structure
                if level1_name not in results:
                    results[level1_name] = {}

                # Evaluate all custom metrics under this level-1 metric
                for metric_key in metric_keys:
                    self._evaluate_and_merge_single_metric(data, results, metric_key, level1_name)
    
    def _evaluate_and_merge_single_metric(self, data: Any, results: Dict[str, Any], metric_key: str, level1_name: str) -> None:
        """Evaluate a single custom metric and merge its result."""
        try:
            level1, level2, level3 = metric_key.split('.')
            module = self.custom_metrics_modules[metric_key]

            # Fetch the metric configuration
            metric_config = self.custom_metrics_config[level1][level2][level3]

            # Resolve the display names
            level2_name = self.custom_metrics_config[level1][level2].get('name', level2)
            level3_name = metric_config.get('name', level3)

            # Make sure the nested result structure exists
            if level2_name not in results[level1_name]:
                results[level1_name][level2_name] = {}

            # Call the custom metric's evaluate() entry point and score it
            metric_result = module.evaluate(data)
            from modules.lib.score import Score
            evaluator = Score(self.merged_config, level1_name)

            result = evaluator.evaluate(metric_result)
            results.update(result)

            self.logger.info(f"Evaluated custom metric: {level1_name}.{level2_name}.{level3_name}")

        except Exception as e:
            self.logger.error(f"Failed to evaluate custom metric {metric_key}: {str(e)}")

            # Best-effort: record the error in the result structure
            try:
                level1, level2, level3 = metric_key.split('.')
                level2_name = self.custom_metrics_config[level1][level2].get('name', level2)
                level3_name = self.custom_metrics_config[level1][level2][level3].get('name', level3)

                if level2_name not in results[level1_name]:
                    results[level1_name][level2_name] = {}

                results[level1_name][level2_name][level3_name] = {
                    "status": "error",
                    "message": str(e),
                    "timestamp": datetime.now().isoformat(),
                }
            except Exception:
                pass

    def _run_module(
        self, module_class: Any, data: Any, module_name: str
    ) -> Dict[str, Any]:
        """执行单个评估模块(带熔断机制)"""
        try:
            instance = module_class(data)
            return {module_name: instance.report_statistic()}
        except Exception as e:
            self.logger.error(f"{module_name} 执行异常: {str(e)}", stack_info=True)
            return {module_name: {"error": str(e)}}




class EvaluationPipeline:
    """评估流水线控制器"""

    def __init__(self, configPath: str, logPath: str, dataPath: str, resultPath: str, customMetricsPath: Optional[str] = None, customConfigPath: Optional[str] = None):
        self.configPath = Path(configPath)
        self.custom_config_path = Path(customConfigPath) if customConfigPath else None
        self.data_path = Path(dataPath)
        self.report_path = Path(resultPath)
        self.custom_metrics_path = Path(customMetricsPath) if customMetricsPath else None
        
        # Create the evaluation engine, passing through all required parameters
        self.engine = EvaluationCore(
            logPath, 
            configPath=str(self.configPath), 
            customConfigPath=str(self.custom_config_path) if self.custom_config_path else None,
            customMetricsPath=str(self.custom_metrics_path) if self.custom_metrics_path else None
        )
        
        self.data_processor = self._load_data_processor()

    def _load_data_processor(self) -> Any:
        """Dynamically load the data preprocessing module."""
        try:
            from modules.lib import data_process

            return data_process.DataPreprocessing(self.data_path, self.configPath)
        except ImportError as e:
            raise RuntimeError(f"Failed to load data processor: {str(e)}") from e

    def execute_pipeline(self) -> Dict[str, Any]:
        """Run the end-to-end evaluation flow."""
        self._validate_case()

        try:
            metric_results = self.engine.parallel_evaluate(self.data_processor)
            report = self._generate_report(
                self.data_processor.case_name, metric_results
            )
            return report
        except Exception as e:
            self.engine.logger.critical(f"Pipeline execution failed: {str(e)}", exc_info=True)
            return {"error": str(e)}

    def _validate_case(self) -> None:
        """Validate the test case path."""
        case_path = self.data_path
        if not case_path.exists():
            raise FileNotFoundError(f"Case path does not exist: {case_path}")
        if not case_path.is_dir():
            raise NotADirectoryError(f"Invalid case directory: {case_path}")

    def _generate_report(self, case_name: str, results: Dict) -> Dict:
        """Generate the evaluation report (template method pattern)."""
        from modules.lib.common import dict2json

        report_path = self.report_path
        # Note: the effective permissions are still subject to the process umask
        report_path.mkdir(parents=True, exist_ok=True, mode=0o777)

        report_file = report_path / f"{case_name}_report.json"
        dict2json(results, report_file)
        self.engine.logger.info(f"Evaluation report written: {report_file}")
        return results


def main():
    """Command-line entry point (factory pattern)."""
    parser = argparse.ArgumentParser(
        description="Autonomous driving evaluation system V3.0 - supports dynamic metric selection and custom metrics",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Argument definitions with help text and default values
    parser.add_argument(
        "--logPath",
        type=str,
        default=r"D:\Cicv\招远\zhaoyuan0410\logs\test.log",
        help="Log file path",
    )
    parser.add_argument(
        "--dataPath",
        type=str,
        default=r"D:\Cicv\招远\V2V_CSAE53-2020_ForwardCollision_LST_02-03",
        help="Directory of preprocessed input data",
    )
    parser.add_argument(
        "--configPath",
        type=str,
        default=r"D:\Cicv\招远\zhaoyuan0410\config\metrics_config.yaml",
        help="Path to the metrics configuration file",
    )
    parser.add_argument(
        "--reportPath",
        type=str,
        default=r"D:\Cicv\招远\zhaoyuan0410\result",
        help="Output directory for evaluation reports",
    )
    # Optional: directory containing custom metric scripts
    parser.add_argument(
        "--customMetricsPath",
        type=str,
        default=r"D:\Cicv\招远\zhaoyuan0410\custom_metrics",
        help="Directory of custom metric scripts (optional)",
    )
    # Optional: configuration file for the custom metrics
    parser.add_argument(
        "--customConfigPath",
        type=str,
        default=r"D:\Cicv\招远\zhaoyuan0410\test\custom_metrics_config.yaml",
        help="Path to the custom metrics configuration file (optional)",
    )
    args = parser.parse_args()
    args = parser.parse_args()

    try:
        pipeline = EvaluationPipeline(
            args.configPath, args.logPath, args.dataPath, args.reportPath, args.customMetricsPath, args.customConfigPath
        )
        start_time = time.perf_counter()

        result = pipeline.execute_pipeline()

        if "error" in result:
            sys.exit(1)

        print(f"评估完成,耗时: {time.perf_counter()-start_time:.2f}s")
        print(f"报告路径: {pipeline.report_path}")
    except KeyboardInterrupt:
        print("\n用户中断操作")
        sys.exit(130)
    except Exception as e:
        print(f"程序执行异常: {str(e)}")
        sys.exit(1)


if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    main()
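
# A minimal sketch of a CLI invocation (placeholder paths, not the Windows
# defaults baked into the argument parser above):
#   python evaluation_engine.py --dataPath ./cases/demo_case \
#       --configPath ./config/metrics_config.yaml \
#       --logPath ./logs/run.log --reportPath ./result \
#       --customMetricsPath ./custom_metrics \
#       --customConfigPath ./config/custom_metrics_config.yaml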