- #!/usr/bin/env python3
- # evaluator_enhanced.py
- import sys
- import warnings
- import time
- import importlib
- import importlib.util
- import yaml
- from pathlib import Path
- import argparse
- from concurrent.futures import ThreadPoolExecutor
- from functools import lru_cache
- from typing import Dict, Any, List, Optional, Type, Tuple, Callable, Union
- from datetime import datetime
- import logging
- import traceback
- import json
- import inspect
- # Constant definitions
- DEFAULT_WORKERS = 4
- CUSTOM_METRIC_PREFIX = "metric_"
- CUSTOM_METRIC_FILE_PATTERN = "*.py"
- # Resolve the project root path safely (sys._MEIPASS is set when bundled by PyInstaller)
- if hasattr(sys, "_MEIPASS"):
- _ROOT_PATH = Path(sys._MEIPASS)
- else:
- _ROOT_PATH = Path(__file__).resolve().parent.parent
- sys.path.insert(0, str(_ROOT_PATH))
- class ConfigManager:
- """配置管理组件"""
-
- def __init__(self, logger: logging.Logger):
- self.logger = logger
- self.base_config: Dict[str, Any] = {}
- self.custom_config: Dict[str, Any] = {}
- self.merged_config: Dict[str, Any] = {}
- self._config_cache = {}
-
- def split_configs(self, all_metrics_path: Path, builtin_metrics_path: Path, custom_metrics_path: Path) -> None:
- """Split all_metrics_config.yaml into built-in and custom metric configs"""
- # Skip the split if the extracted config files already exist
- extracted_builtin_path = builtin_metrics_path.parent / f"{builtin_metrics_path.stem}_extracted{builtin_metrics_path.suffix}"
- if extracted_builtin_path.exists() and custom_metrics_path and custom_metrics_path.exists():
- self.logger.info(f"Using existing split config file: {extracted_builtin_path}")
- return
- 
- try:
- # Load the config files through the cache to avoid repeated reads
- all_metrics_dict = self._safe_load_config(all_metrics_path)
- builtin_metrics_dict = self._safe_load_config(builtin_metrics_path)
- 
- # Recursively separate built-in and custom metrics
- extracted_builtin_metrics, custom_metrics_dict = self._split_metrics_recursive(
- all_metrics_dict, builtin_metrics_dict
- )
- 
- # Save the extracted built-in metrics to a new file
- with open(extracted_builtin_path, 'w', encoding='utf-8') as f:
- yaml.dump(extracted_builtin_metrics, f, allow_unicode=True, sort_keys=False, indent=2)
- self.logger.info(f"Config split: extracted built-in metrics saved to {extracted_builtin_path}")
- 
- if custom_metrics_dict and custom_metrics_path:
- with open(custom_metrics_path, 'w', encoding='utf-8') as f:
- yaml.dump(custom_metrics_dict, f, allow_unicode=True, sort_keys=False, indent=2)
- self.logger.info(f"Config split: custom metrics saved to {custom_metrics_path}")
- 
- except Exception as err:
- self.logger.error(f"Config split failed: {str(err)}")
- raise
-
- def _split_metrics_recursive(self, all_dict: Dict, builtin_dict: Dict) -> Tuple[Dict, Dict]:
- """递归拆分内置和自定义指标配置"""
- extracted_builtin = {}
- custom_metrics = {}
-
- for key, value in all_dict.items():
- if key in builtin_dict:
- # If both values are dicts, recurse
- if isinstance(value, dict) and isinstance(builtin_dict[key], dict):
- sub_builtin, sub_custom = self._split_metrics_recursive(value, builtin_dict[key])
- if sub_builtin:
- extracted_builtin[key] = sub_builtin
- if sub_custom:
- custom_metrics[key] = sub_custom
- else:
- # Otherwise copy the value directly
- extracted_builtin[key] = value
- else:
- # Keys not present in the built-in config are classified as custom metrics
- custom_metrics[key] = value
-
- return extracted_builtin, custom_metrics
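- # Illustrative example of the recursive split (metric names below are hypothetical,
- # not taken from the real config files): given an all-metrics dict
- #   {"safety": {"collision": {...}}, "custom_cat": {"metricA": {...}}}
- # and a built-in dict containing only {"safety": {"collision": {...}}},
- # the method returns
- #   extracted_builtin = {"safety": {"collision": {...}}}
- #   custom_metrics    = {"custom_cat": {"metricA": {...}}}
- # Keys absent from the built-in config are treated as custom metrics.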
-
- def load_configs(self, all_config_path: Optional[Path], builtin_metrics_path: Optional[Path], custom_metrics_path: Optional[Path]) -> Dict[str, Any]:
- """Load and merge the configurations"""
- # Return the cached result if this combination has already been loaded
- cache_key = f"{all_config_path}_{builtin_metrics_path}_{custom_metrics_path}"
- if cache_key in self._config_cache:
- self.logger.info("Using cached config data")
- return self._config_cache[cache_key]
- 
- # Automatically split the combined config
- extracted_builtin_path = None
- 
- if all_config_path and all_config_path.exists() and builtin_metrics_path:
- # Path of the extracted built-in metrics config file
- extracted_builtin_path = builtin_metrics_path.parent / f"{builtin_metrics_path.stem}_extracted{builtin_metrics_path.suffix}"
- self.split_configs(all_config_path, builtin_metrics_path, custom_metrics_path)
- 
- # Prefer the extracted built-in metrics config when it exists
- if extracted_builtin_path and extracted_builtin_path.exists():
- self.base_config = self._safe_load_config(extracted_builtin_path)
- else:
- self.base_config = self._safe_load_config(builtin_metrics_path) if builtin_metrics_path else {}
- 
- self.custom_config = self._safe_load_config(custom_metrics_path) if custom_metrics_path else {}
- if all_config_path and all_config_path.exists():
- self.merged_config = self._safe_load_config(all_config_path)
- # Cache the merged result
- self._config_cache[cache_key] = self.merged_config
- return self.merged_config
- return {}
-
- @lru_cache(maxsize=16)
- def _safe_load_config(self, config_path: Path) -> Dict[str, Any]:
- """安全加载YAML配置,使用lru_cache减少重复读取"""
- try:
- if not config_path or not config_path.exists():
- self.logger.warning(f"Config file not found: {config_path}")
- return {}
- with config_path.open('r', encoding='utf-8') as f:
- config_dict = yaml.safe_load(f) or {}
- self.logger.info(f"Loaded config: {config_path}")
- return config_dict
- except Exception as err:
- self.logger.error(f"Failed to load config {config_path}: {str(err)}")
- return {}
-
- def get_config(self) -> Dict[str, Any]:
- return self.merged_config
-
- def get_base_config(self) -> Dict[str, Any]:
- return self.base_config
-
- def get_custom_config(self) -> Dict[str, Any]:
- return self.custom_config
- class MetricLoader:
- """指标加载器组件"""
-
- def __init__(self, logger: logging.Logger, config_manager: ConfigManager):
- self.logger = logger
- self.config_manager = config_manager
- self.metric_modules: Dict[str, Type] = {}
- self.custom_metric_modules: Dict[str, Any] = {}
-
- def load_builtin_metrics(self) -> Dict[str, Type]:
- """加载内置指标模块"""
- module_mapping = {
- "safety": ("modules.metric.safety", "SafeManager"),
- "comfort": ("modules.metric.comfort", "ComfortManager"),
- "traffic": ("modules.metric.traffic", "TrafficManager"),
- "efficient": ("modules.metric.efficient", "EfficientManager"),
- "function": ("modules.metric.function", "FunctionManager"),
- }
-
- self.metric_modules = {
- name: self._load_module(*info)
- for name, info in module_mapping.items()
- }
-
- self.logger.info(f"Loaded builtin metrics: {', '.join(self.metric_modules.keys())}")
- return self.metric_modules
-
- @lru_cache(maxsize=32)
- def _load_module(self, module_path: str, class_name: str) -> Type:
- """动态加载Python模块"""
- try:
- module = __import__(module_path, fromlist=[class_name])
- return getattr(module, class_name)
- except (ImportError, AttributeError) as e:
- self.logger.error(f"Failed to load module: {module_path}.{class_name} - {str(e)}")
- raise
-
- def load_custom_metrics(self, custom_metrics_path: Optional[Path]) -> Dict[str, Any]:
- """Load custom metric modules"""
- if not custom_metrics_path or not custom_metrics_path.is_dir():
- self.logger.info("Custom metrics path not provided or does not exist")
- return {}
- # Check whether there are new custom metric files on disk.
- # Filenames map to metric keys (metric_<l1>_<l2>_<l3>.py -> "l1.l2.l3"),
- # so compare derived keys, not raw file names, against the loaded modules.
- current_keys = set()
- for f in custom_metrics_path.glob(CUSTOM_METRIC_FILE_PATTERN):
- if f.name.startswith(CUSTOM_METRIC_PREFIX):
- parts = f.stem[len(CUSTOM_METRIC_PREFIX):].split('_')
- if len(parts) >= 3:
- current_keys.add('.'.join(parts[:3]))
- loaded_keys = set(self.custom_metric_modules.keys())
- 
- # If nothing new is present and modules are already loaded, return the cache
- if self.custom_metric_modules and not (current_keys - loaded_keys):
- self.logger.info(f"No new custom metrics to load, using {len(self.custom_metric_modules)} cached modules")
- return self.custom_metric_modules
- loaded_count = 0
- for py_file in custom_metrics_path.glob(CUSTOM_METRIC_FILE_PATTERN):
- if py_file.name.startswith(CUSTOM_METRIC_PREFIX):
- if self._process_custom_metric_file(py_file):
- loaded_count += 1
-
- self.logger.info(f"Loaded {loaded_count} custom metric modules")
- return self.custom_metric_modules
-
- def _process_custom_metric_file(self, file_path: Path) -> bool:
- """Process a single custom metric file"""
- try:
- metric_key = self._validate_metric_file(file_path)
- 
- module_name = f"custom_metric_{file_path.stem}"
- spec = importlib.util.spec_from_file_location(module_name, file_path)
- if spec is None or spec.loader is None:
- raise ImportError(f"Cannot create import spec for {file_path}")
- module = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(module)
-
- from modules.lib.metric_registry import BaseMetric
- metric_class = None
-
- for name, obj in inspect.getmembers(module):
- if (inspect.isclass(obj) and
- issubclass(obj, BaseMetric) and
- obj != BaseMetric):
- metric_class = obj
- break
-
- if metric_class:
- self.custom_metric_modules[metric_key] = {
- 'type': 'class',
- 'module': module,
- 'class': metric_class
- }
- self.logger.info(f"Loaded class-based custom metric: {metric_key}")
- elif hasattr(module, 'evaluate'):
- self.custom_metric_modules[metric_key] = {
- 'type': 'function',
- 'module': module
- }
- self.logger.info(f"Loaded function-based custom metric: {metric_key}")
- else:
- raise AttributeError(f"Missing evaluate() function or BaseMetric subclass: {file_path.name}")
-
- return True
- except ValueError as e:
- self.logger.warning(str(e))
- return False
- except Exception as e:
- self.logger.error(f"Failed to load custom metric {file_path}: {str(e)}")
- return False
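- # Sketch of a custom metric file this loader accepts (assumption: only the hooks
- # exercised above are required; BaseMetric may define more). A class-based file
- # named e.g. metric_safety_collision_ttc.py (hypothetical name) could look like:
- #
- #     from modules.lib.metric_registry import BaseMetric
- #
- #     class TTCMetric(BaseMetric):
- #         def __init__(self, data):
- #             self.data = data  # preprocessed data object passed in by the engine
- #         def calculate(self) -> dict:
- #             return {"ttc": {"value": 2.5, "result": True}}  # illustrative shape
- #
- # Alternatively, a function-based file only needs a module-level evaluate(data)
- # function returning a dict.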
-
- def _validate_metric_file(self, file_path: Path) -> str:
- """验证自定义指标文件命名规范"""
- stem = file_path.stem[len(CUSTOM_METRIC_PREFIX):]
- parts = stem.split('_')
- if len(parts) < 3:
- raise ValueError(f"Invalid custom metric filename: {file_path.name}, should be metric_<level1>_<level2>_<level3>.py")
- level1, level2, level3 = parts[:3]
- if not self._is_metric_configured(level1, level2, level3):
- raise ValueError(f"Unconfigured metric: {level1}.{level2}.{level3}")
- return f"{level1}.{level2}.{level3}"
-
- def _is_metric_configured(self, level1: str, level2: str, level3: str) -> bool:
- """检查指标是否在配置中注册"""
- custom_config = self.config_manager.get_custom_config()
- try:
- return (level1 in custom_config and
- isinstance(custom_config[level1], dict) and
- level2 in custom_config[level1] and
- isinstance(custom_config[level1][level2], dict) and
- level3 in custom_config[level1][level2] and
- isinstance(custom_config[level1][level2][level3], dict))
- except Exception:
- return False
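- # Minimal sketch of a matching custom_metrics_config.yaml entry (names and leaf
- # fields are illustrative assumptions, not taken from the real config):
- #
- #     safety:
- #       collision:
- #         ttc:
- #           priority: 1
- #
- # The three nested mapping levels are what this check requires.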
-
- def get_builtin_metrics(self) -> Dict[str, Type]:
- return self.metric_modules
-
- def get_custom_metrics(self) -> Dict[str, Any]:
- return self.custom_metric_modules
- class EvaluationEngine:
- """评估引擎组件"""
-
- def __init__(self, logger: logging.Logger, config_manager: ConfigManager, metric_loader: MetricLoader):
- self.logger = logger
- self.config_manager = config_manager
- self.metric_loader = metric_loader
-
- def evaluate(self, data: Any) -> Dict[str, Any]:
- """执行评估流程"""
- raw_results = self._collect_builtin_metrics(data)
- custom_results = self._collect_custom_metrics(data)
- return self._process_merged_results(raw_results, custom_results)
-
- def _collect_builtin_metrics(self, data: Any) -> Dict[str, Any]:
- """Collect results from built-in metrics"""
- metric_modules = self.metric_loader.get_builtin_metrics()
- raw_results: Dict[str, Any] = {}
- 
- # Only keep metrics that are actually present in the configuration
- config = self.config_manager.get_config()
- available_metrics = {
- metric_name for metric_name in metric_modules.keys()
- if metric_name in config and isinstance(config[metric_name], dict)
- }
- 
- # Restrict execution to the configured metrics
- filtered_modules = {
- name: module for name, module in metric_modules.items()
- if name in available_metrics
- }
- 
- if not filtered_modules:
- self.logger.warning("No built-in metrics matched the configuration")
- return raw_results
- 
- # Cap the thread pool size to avoid creating more threads than needed
- max_workers = min(len(filtered_modules), DEFAULT_WORKERS)
- 
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
- futures = {
- executor.submit(self._run_module, module, data, module_name): module_name
- for module_name, module in filtered_modules.items()
- }
-
- for future in futures:
- module_name = futures[future]
- try:
- result = future.result()
- raw_results[module_name] = result[module_name]
- except Exception as e:
- self.logger.error(
- f"{module_name} evaluation failed: {str(e)}",
- exc_info=True,
- )
- raw_results[module_name] = {
- "status": "error",
- "message": str(e),
- "timestamp": datetime.now().isoformat(),
- }
-
- return raw_results
-
- def _collect_custom_metrics(self, data: Any) -> Dict[str, Dict]:
- """Collect results from custom metrics"""
- custom_metrics = self.metric_loader.get_custom_metrics()
- if not custom_metrics:
- return {}
- 
- custom_results = {}
- 
- # Run custom metrics in parallel with a bounded thread pool
- max_workers = min(len(custom_metrics), DEFAULT_WORKERS)
- 
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
- futures = {}
- 
- # Submit all custom metric tasks
- for metric_key, metric_info in custom_metrics.items():
- futures[executor.submit(self._run_custom_metric, metric_key, metric_info, data)] = metric_key
- 
- # Collect results, merging metrics that share the same level-1 category
- for future in futures:
- metric_key = futures[future]
- try:
- level1, result = future.result()
- if level1:
- custom_results.setdefault(level1, {}).update(result)
- except Exception as e:
- self.logger.error(f"Custom metric {metric_key} execution failed: {str(e)}")
- 
- return custom_results
-
- def _run_custom_metric(self, metric_key: str, metric_info: Dict, data: Any) -> Tuple[str, Dict]:
- """执行单个自定义指标"""
- try:
- level1, level2, level3 = metric_key.split('.')
-
- if metric_info['type'] == 'class':
- metric_class = metric_info['class']
- metric_instance = metric_class(data)
- metric_result = metric_instance.calculate()
- else:
- module = metric_info['module']
- metric_result = module.evaluate(data)
-
- self.logger.info(f"Calculated custom metric: {level1}.{level2}.{level3}")
- return level1, metric_result
-
- except Exception as e:
- self.logger.error(f"Custom metric {metric_key} failed: {str(e)}")
- try:
- level1 = metric_key.split('.')[0]
- return level1, {
- "status": "error",
- "message": str(e),
- "timestamp": datetime.now().isoformat(),
- }
- except Exception:
- return "", {}
-
- def _process_merged_results(self, raw_results: Dict, custom_results: Dict) -> Dict:
- """处理合并后的评估结果"""
- from modules.lib.score import Score
- final_results = {}
- merged_config = self.config_manager.get_config()
- for level1, level1_data in raw_results.items():
- if level1 in custom_results:
- level1_data.update(custom_results[level1])
- try:
- evaluator = Score(merged_config, level1)
- final_results.update(evaluator.evaluate(level1_data))
- except Exception as e:
- final_results[level1] = self._format_error(e)
- for level1, level1_data in custom_results.items():
- if level1 not in raw_results:
- try:
- evaluator = Score(merged_config, level1)
- final_results.update(evaluator.evaluate(level1_data))
- except Exception as e:
- final_results[level1] = self._format_error(e)
- return final_results
-
- def _format_error(self, e: Exception) -> Dict:
- return {
- "status": "error",
- "message": str(e),
- "timestamp": datetime.now().isoformat()
- }
-
- def _run_module(self, module_class: Any, data: Any, module_name: str) -> Dict[str, Any]:
- """执行单个评估模块"""
- try:
- instance = module_class(data)
- return {module_name: instance.report_statistic()}
- except Exception as e:
- self.logger.error(f"{module_name} execution error: {str(e)}", exc_info=True)
- return {module_name: {"error": str(e)}}
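- # The built-in managers loaded by MetricLoader only need to support the interface
- # used here: construction with the preprocessed data object and a
- # report_statistic() method returning a result dict. A minimal stand-in
- # (illustrative, not the real SafeManager) would be:
- #
- #     class SafeManager:
- #         def __init__(self, data):
- #             self.data = data
- #         def report_statistic(self) -> dict:
- #             return {"some_metric": 0.0}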
- class LoggingManager:
- """日志管理组件"""
-
- def __init__(self, log_path: Path):
- self.log_path = log_path
- self.logger = self._init_logger()
-
- def _init_logger(self) -> logging.Logger:
- """初始化日志系统"""
- try:
- from modules.lib.log_manager import LogManager
- log_manager = LogManager(self.log_path)
- return log_manager.get_logger()
- except (ImportError, PermissionError, IOError) as e:
- logger = logging.getLogger("evaluator")
- logger.setLevel(logging.INFO)
- # Avoid attaching duplicate handlers if the fallback logger is created more than once
- if not logger.handlers:
- console_handler = logging.StreamHandler()
- console_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
- logger.addHandler(console_handler)
- logger.warning(f"Failed to init standard logger: {str(e)}, using fallback logger")
- return logger
-
- def get_logger(self) -> logging.Logger:
- return self.logger
- class DataProcessor:
- """数据处理组件"""
-
- def __init__(self, logger: logging.Logger, data_path: Path, config_path: Optional[Path] = None):
- self.logger = logger
- self.data_path = data_path
- self.config_path = config_path
- self.case_name = self.data_path.name
- self._processor = None
-
- @property
- def processor(self) -> Any:
- """懒加载数据处理器,只在首次访问时创建"""
- if self._processor is None:
- self._processor = self._load_processor()
- return self._processor
-
- def _load_processor(self) -> Any:
- """加载数据处理器"""
- try:
- start_time = time.perf_counter()
- from modules.lib import data_process
- processor = data_process.DataPreprocessing(self.data_path, self.config_path)
- elapsed_time = time.perf_counter() - start_time
- self.logger.info(f"Data processor loaded in {elapsed_time:.2f}s")
- return processor
- except ImportError as e:
- self.logger.error(f"Failed to load data processor: {str(e)}")
- raise RuntimeError(f"Failed to load data processor: {str(e)}") from e
-
- def validate(self) -> None:
- """验证数据路径"""
- if not self.data_path.exists():
- raise FileNotFoundError(f"Data path does not exist: {self.data_path}")
- if not self.data_path.is_dir():
- raise NotADirectoryError(f"Invalid data directory: {self.data_path}")
- class EvaluationPipeline:
- """评估流水线控制器"""
-
- def __init__(self, all_config_path: str, base_config_path: str, log_path: str, data_path: str, report_path: str,
- custom_metrics_path: Optional[str] = None, custom_config_path: Optional[str] = None):
- # Path initialization
- self.all_config_path = Path(all_config_path) if all_config_path else None
- self.base_config_path = Path(base_config_path) if base_config_path else None
- self.custom_config_path = Path(custom_config_path) if custom_config_path else None
- self.data_path = Path(data_path)
- self.report_path = Path(report_path)
- self.custom_metrics_path = Path(custom_metrics_path) if custom_metrics_path else None
-
- # Logging
- self.logging_manager = LoggingManager(Path(log_path))
- self.logger = self.logging_manager.get_logger()
- # Configuration
- self.config_manager = ConfigManager(self.logger)
- self.config = self.config_manager.load_configs(
- self.all_config_path, self.base_config_path, self.custom_config_path
- )
- # Metric loading
- self.metric_loader = MetricLoader(self.logger, self.config_manager)
- self.metric_loader.load_builtin_metrics()
- self.metric_loader.load_custom_metrics(self.custom_metrics_path)
- # Data processing
- self.data_processor = DataProcessor(self.logger, self.data_path, self.all_config_path)
- self.evaluation_engine = EvaluationEngine(self.logger, self.config_manager, self.metric_loader)
-
- def execute(self) -> Dict[str, Any]:
- """Run the evaluation pipeline"""
- try:
- # Validate the data path before evaluation
- self.data_processor.validate()
- 
- self.logger.info(f"Start evaluation: {self.data_path.name}")
- start_time = time.perf_counter()
- 
- # Time the evaluation step for profiling
- eval_start = time.perf_counter()
- results = self.evaluation_engine.evaluate(self.data_processor.processor)
- eval_time = time.perf_counter() - eval_start
- 
- # Generate the report
- report_start = time.perf_counter()
- report = self._generate_report(self.data_processor.case_name, results)
- report_time = time.perf_counter() - report_start
- 
- # Total elapsed time
- elapsed_time = time.perf_counter() - start_time
- self.logger.info(f"Evaluation completed, time: {elapsed_time:.2f}s (evaluate: {eval_time:.2f}s, report: {report_time:.2f}s)")
-
- return report
-
- except Exception as e:
- self.logger.critical(f"Evaluation failed: {str(e)}", exc_info=True)
- return {"error": str(e), "traceback": traceback.format_exc()}
-
- def _add_overall_result(self, report: Dict[str, Any]) -> Dict[str, Any]:
- """Post-process the evaluation report and add an overall result field"""
- # Load threshold parameters
- thresholds = {
- "T0": self.config['T_threshold']['T0_threshold'],
- "T1": self.config['T_threshold']['T1_threshold'],
- "T2": self.config['T_threshold']['T2_threshold']
- }
- 
- # Initialize failure counters per priority
- counters = {'p0': 0, 'p1': 0, 'p2': 0}
- 
- # Collect all failed categories in a single pass
- failed_categories = [
- (category, category_data.get('priority'))
- for category, category_data in report.items()
- if isinstance(category_data, dict) and category != "metadata" and not category_data.get('result', True)
- ]
-
- # Count failures by priority
- for _, priority in failed_categories:
- if priority == 0:
- counters['p0'] += 1
- elif priority == 1:
- counters['p1'] += 1
- elif priority == 2:
- counters['p2'] += 1
-
- # Threshold check: fail the overall result if any priority count exceeds its threshold
- overall_result = not (
- counters['p0'] > thresholds['T0'] or
- counters['p1'] > thresholds['T1'] or
- counters['p2'] > thresholds['T2']
- )
-
- # Build the processed report
- processed_report = report.copy()
- processed_report['overall_result'] = overall_result
-
- # Attach threshold statistics
- processed_report['threshold_checks'] = {
- 'T0_threshold': thresholds['T0'],
- 'T1_threshold': thresholds['T1'],
- 'T2_threshold': thresholds['T2'],
- 'actual_counts': counters
- }
-
- self.logger.info(f"Added overall result: {overall_result}")
- return processed_report
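- # Worked example (hypothetical numbers): with thresholds T0=0, T1=1, T2=2 and
- # failure counts p0=0, p1=2, p2=1, the p1 count exceeds T1, so overall_result is
- # False; with p1=1 instead, no threshold is exceeded and overall_result is True.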
-
- def _generate_report(self, case_name: str, results: Dict[str, Any]) -> Dict[str, Any]:
- """生成评估报告"""
- from modules.lib.common import dict2json
-
- self.report_path.mkdir(parents=True, exist_ok=True)
-
- results["metadata"] = {
- "case_name": case_name,
- "timestamp": datetime.now().isoformat(),
- "version": "1.0",
- }
-
- # Add the overall result assessment
- results = self._add_overall_result(results)
-
- report_file = self.report_path / f"{case_name}_report.json"
- dict2json(results, report_file)
- self.logger.info(f"Report generated: {report_file}")
-
- return results
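- # The emitted <case_name>_report.json therefore contains the per-category scores
- # produced by the evaluation engine plus the fields added above: "metadata"
- # (case_name, timestamp, version), "overall_result", and "threshold_checks".
- # The exact shape of each category entry depends on modules.lib.score.Score.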
- def main():
- """命令行入口"""
- parser = argparse.ArgumentParser(
- description="Autonomous Driving Evaluation System V3.1",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Required arguments
- parser.add_argument(
- "--dataPath",
- type=str,
- default=r"D:\Kevin\zhaoyuan\data\V2V_CSAE53-2020_ForwardCollision_LST_01-02",
- help="Input data directory",
- )
-
- # Configuration arguments
- config_group = parser.add_argument_group('Configuration')
- config_group.add_argument(
- "--allConfigPath",
- type=str,
- default=r"config/all_metrics_config.yaml",
- help="Full metrics config file path (built-in + custom)",
- )
- config_group.add_argument(
- "--baseConfigPath",
- type=str,
- default=r"config/builtin_metrics_config.yaml",
- help="Built-in metrics config file path",
- )
- config_group.add_argument(
- "--customConfigPath",
- type=str,
- default=r"config/custom_metrics_config.yaml",
- help="Custom metrics config path (optional)",
- )
-
- # Output arguments
- output_group = parser.add_argument_group('Output')
- output_group.add_argument(
- "--logPath",
- type=str,
- default="test.log",
- help="Log file path",
- )
- output_group.add_argument(
- "--reportPath",
- type=str,
- default="reports",
- help="Output report directory",
- )
-
- # Extension arguments
- ext_group = parser.add_argument_group('Extensions')
- ext_group.add_argument(
- "--customMetricsPath",
- type=str,
- default="custom_metrics",
- help="Custom metrics scripts directory (optional)",
- )
-
- args = parser.parse_args()
- try:
- pipeline = EvaluationPipeline(
- all_config_path=args.allConfigPath,
- base_config_path=args.baseConfigPath,
- log_path=args.logPath,
- data_path=args.dataPath,
- report_path=args.reportPath,
- custom_metrics_path=args.customMetricsPath,
- custom_config_path=args.customConfigPath
- )
-
- start_time = time.perf_counter()
- result = pipeline.execute()
- elapsed_time = time.perf_counter() - start_time
- if "error" in result:
- print(f"Evaluation failed: {result['error']}")
- sys.exit(1)
- print(f"Evaluation completed, total time: {elapsed_time:.2f}s")
- print(f"Report path: {pipeline.report_path}")
-
- except KeyboardInterrupt:
- print("\nUser interrupted")
- sys.exit(130)
- except Exception as e:
- print(f"Execution error: {str(e)}")
- traceback.print_exc()
- sys.exit(1)
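- # Example invocation (paths are illustrative; the argparse defaults above apply
- # to any omitted option):
- #
- #     python evaluator_enhanced.py \
- #         --dataPath data/V2V_CSAE53-2020_ForwardCollision_LST_01-02 \
- #         --allConfigPath config/all_metrics_config.yaml \
- #         --reportPath reports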
- if __name__ == "__main__":
- warnings.filterwarnings("ignore")
- main()