@@ -17,7 +17,6 @@ import traceback
import json
import inspect

-
# Constants
DEFAULT_WORKERS = 4
CUSTOM_METRIC_PREFIX = "metric_"
@@ -31,26 +30,27 @@ else:

sys.path.insert(0, str(_ROOT_PATH))

+
class ConfigManager:
    """Configuration management component"""
-
+
    def __init__(self, logger: logging.Logger):
        self.logger = logger
        self.base_config: Dict[str, Any] = {}
        self.custom_config: Dict[str, Any] = {}
        self.merged_config: Dict[str, Any] = {}
-
+
    def split_configs(self, all_config_path: Path, base_config_path: Path, custom_config_path: Path) -> None:
        """Split all_metrics_config.yaml into built-in and custom configs"""
        try:
            with open(all_config_path, 'r', encoding='utf-8') as f:
                all_metrics = yaml.safe_load(f) or {}
-
+
            with open(base_config_path, 'r', encoding='utf-8') as f:
                builtin_metrics = yaml.safe_load(f) or {}
-
+
            custom_metrics = self._find_custom_metrics(all_metrics, builtin_metrics)
-
+
            if custom_metrics:
                with open(custom_config_path, 'w', encoding='utf-8') as f:
                    yaml.dump(custom_metrics, f, allow_unicode=True, sort_keys=False, indent=2)
@@ -58,18 +58,18 @@ class ConfigManager:
        except Exception as e:
            self.logger.error(f"Failed to split configs: {str(e)}")
            raise
-
+
    def _find_custom_metrics(self, all_metrics, builtin_metrics, current_path=""):
        """Recursively compare configs to find custom metrics"""
        custom_metrics = {}
-
+
        if isinstance(all_metrics, dict) and isinstance(builtin_metrics, dict):
            for key in all_metrics:
                if key not in builtin_metrics:
                    custom_metrics[key] = all_metrics[key]
                else:
                    child_custom = self._find_custom_metrics(
-                        all_metrics[key],
+                        all_metrics[key],
                        builtin_metrics[key],
                        f"{current_path}.{key}" if current_path else key
                    )
@@ -77,34 +77,34 @@ class ConfigManager:
                        custom_metrics[key] = child_custom
        elif all_metrics != builtin_metrics:
            return all_metrics
-
+
        if custom_metrics:
            return self._ensure_structure(custom_metrics, all_metrics, current_path)
        return None
-
+
    def _ensure_structure(self, metrics_dict, full_dict, path):
        """Ensure every level keeps its name and priority fields"""
        if not isinstance(metrics_dict, dict):
            return metrics_dict
-
+
        current = full_dict
        for key in path.split('.'):
            if key in current:
                current = current[key]
            else:
                break
-
+
        result = {}
        if isinstance(current, dict):
            if 'name' in current:
                result['name'] = current['name']
            if 'priority' in current:
                result['priority'] = current['priority']
-
+
        for key, value in metrics_dict.items():
            if key not in ['name', 'priority']:
                result[key] = self._ensure_structure(value, full_dict, f"{path}.{key}" if path else key)
-
+
        return result

    def load_configs(self, base_config_path: Optional[Path], custom_config_path: Optional[Path]) -> Dict[str, Any]:
@@ -116,19 +116,19 @@ class ConfigManager:
            target_custom_path = custom_config_path or (base_config_path.parent / "custom_metrics_config.yaml")
            self.split_configs(all_config_path, base_config_path, target_custom_path)
            custom_config_path = target_custom_path
-
+
        self.base_config = self._safe_load_config(base_config_path) if base_config_path else {}
        self.custom_config = self._safe_load_config(custom_config_path) if custom_config_path else {}
        self.merged_config = self._merge_configs(self.base_config, self.custom_config)
        return self.merged_config
-
+
    def _safe_load_config(self, config_path: Path) -> Dict[str, Any]:
        """Safely load a YAML config file"""
        try:
            if not config_path.exists():
                self.logger.warning(f"Config file not found: {config_path}")
                return {}
-
+
            with config_path.open('r', encoding='utf-8') as f:
                config = yaml.safe_load(f) or {}
                self.logger.info(f"Loaded config: {config_path}")
@@ -136,24 +136,24 @@ class ConfigManager:
        except Exception as e:
            self.logger.error(f"Failed to load config {config_path}: {str(e)}")
            return {}
-
+
    def _merge_configs(self, base_config: Dict, custom_config: Dict) -> Dict:
        """Merge base and custom configurations intelligently"""
        merged = base_config.copy()
-
+
        for level1_key, level1_value in custom_config.items():
            if not isinstance(level1_value, dict) or 'name' not in level1_value:
                if level1_key not in merged:
                    merged[level1_key] = level1_value
                continue
-
+
            if level1_key not in merged:
                merged[level1_key] = level1_value
            else:
                for level2_key, level2_value in level1_value.items():
                    if level2_key in ['name', 'priority']:
                        continue
-
+
                    if isinstance(level2_value, dict):
                        if level2_key not in merged[level1_key]:
                            merged[level1_key][level2_key] = level2_value
@@ -161,31 +161,32 @@ class ConfigManager:
                            for level3_key, level3_value in level2_value.items():
                                if level3_key in ['name', 'priority']:
                                    continue
-
+
                                if isinstance(level3_value, dict):
                                    if level3_key not in merged[level1_key][level2_key]:
                                        merged[level1_key][level2_key][level3_key] = level3_value
-
+
        return merged
-
+
    def get_config(self) -> Dict[str, Any]:
        return self.merged_config
-
+
    def get_base_config(self) -> Dict[str, Any]:
        return self.base_config
-
+
    def get_custom_config(self) -> Dict[str, Any]:
        return self.custom_config

+
class MetricLoader:
    """Metric loader component"""
-
+
    def __init__(self, logger: logging.Logger, config_manager: ConfigManager):
        self.logger = logger
        self.config_manager = config_manager
        self.metric_modules: Dict[str, Type] = {}
        self.custom_metric_modules: Dict[str, Any] = {}
-
+
    def load_builtin_metrics(self) -> Dict[str, Type]:
        """Load built-in metric modules"""
        module_mapping = {
@@ -195,15 +196,15 @@ class MetricLoader:
            "efficient": ("modules.metric.efficient", "EfficientManager"),
            "function": ("modules.metric.function", "FunctionManager"),
        }
-
+
        self.metric_modules = {
            name: self._load_module(*info)
            for name, info in module_mapping.items()
        }
-
+
        self.logger.info(f"Loaded builtin metrics: {', '.join(self.metric_modules.keys())}")
        return self.metric_modules
-
+
    @lru_cache(maxsize=32)
    def _load_module(self, module_path: str, class_name: str) -> Type:
        """Dynamically load a Python module"""
@@ -213,7 +214,7 @@ class MetricLoader:
        except (ImportError, AttributeError) as e:
            self.logger.error(f"Failed to load module: {module_path}.{class_name} - {str(e)}")
            raise
-
+
    def load_custom_metrics(self, custom_metrics_path: Optional[Path]) -> Dict[str, Any]:
        """Load custom metric modules"""
        if not custom_metrics_path or not custom_metrics_path.is_dir():
@@ -225,30 +226,30 @@ class MetricLoader:
            if py_file.name.startswith(CUSTOM_METRIC_PREFIX):
                if self._process_custom_metric_file(py_file):
                    loaded_count += 1
-
+
        self.logger.info(f"Loaded {loaded_count} custom metric modules")
        return self.custom_metric_modules
-
+
    def _process_custom_metric_file(self, file_path: Path) -> bool:
        """Process a single custom metric file"""
        try:
            metric_key = self._validate_metric_file(file_path)
-
+
            module_name = f"custom_metric_{file_path.stem}"
            spec = importlib.util.spec_from_file_location(module_name, file_path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
-
+
            from modules.lib.metric_registry import BaseMetric
            metric_class = None
-
+
            for name, obj in inspect.getmembers(module):
-                if (inspect.isclass(obj) and
-                        issubclass(obj, BaseMetric) and
-                        obj != BaseMetric):
+                if (inspect.isclass(obj) and
+                        issubclass(obj, BaseMetric) and
+                        obj != BaseMetric):
                    metric_class = obj
                    break
-
+
            if metric_class:
                self.custom_metric_modules[metric_key] = {
                    'type': 'class',
@@ -264,7 +265,7 @@ class MetricLoader:
                self.logger.info(f"Loaded function-based custom metric: {metric_key}")
            else:
                raise AttributeError(f"Missing evaluate() function or BaseMetric subclass: {file_path.name}")
-
+
            return True
        except ValueError as e:
            self.logger.warning(str(e))
@@ -272,24 +273,25 @@ class MetricLoader:
        except Exception as e:
            self.logger.error(f"Failed to load custom metric {file_path}: {str(e)}")
            return False
-
+
    def _validate_metric_file(self, file_path: Path) -> str:
        """Validate the custom metric file naming convention"""
        stem = file_path.stem[len(CUSTOM_METRIC_PREFIX):]
        parts = stem.split('_')
        if len(parts) < 3:
-            raise ValueError(f"Invalid custom metric filename: {file_path.name}, should be metric_<level1>_<level2>_<level3>.py")
+            raise ValueError(
+                f"Invalid custom metric filename: {file_path.name}, should be metric_<level1>_<level2>_<level3>.py")

        level1, level2, level3 = parts[:3]
        if not self._is_metric_configured(level1, level2, level3):
            raise ValueError(f"Unconfigured metric: {level1}.{level2}.{level3}")
        return f"{level1}.{level2}.{level3}"
-
+
    def _is_metric_configured(self, level1: str, level2: str, level3: str) -> bool:
        """Check whether a metric is registered in the config"""
        custom_config = self.config_manager.get_custom_config()
        try:
-            return (level1 in custom_config and
+            return (level1 in custom_config and
                    isinstance(custom_config[level1], dict) and
                    level2 in custom_config[level1] and
                    isinstance(custom_config[level1][level2], dict) and
@@ -297,32 +299,33 @@ class MetricLoader:
                    isinstance(custom_config[level1][level2][level3], dict))
        except Exception:
            return False
-
+
    def get_builtin_metrics(self) -> Dict[str, Type]:
        return self.metric_modules
-
+
    def get_custom_metrics(self) -> Dict[str, Any]:
        return self.custom_metric_modules

+
class EvaluationEngine:
    """Evaluation engine component"""
-
+
    def __init__(self, logger: logging.Logger, config_manager: ConfigManager, metric_loader: MetricLoader):
        self.logger = logger
        self.config_manager = config_manager
        self.metric_loader = metric_loader
-
+
    def evaluate(self, data: Any) -> Dict[str, Any]:
        """Run the evaluation workflow"""
        raw_results = self._collect_builtin_metrics(data)
        custom_results = self._collect_custom_metrics(data)
        return self._process_merged_results(raw_results, custom_results)
-
+
    def _collect_builtin_metrics(self, data: Any) -> Dict[str, Any]:
        """Collect built-in metric results"""
        metric_modules = self.metric_loader.get_builtin_metrics()
        raw_results: Dict[str, Any] = {}
-
+
        with ThreadPoolExecutor(max_workers=len(metric_modules)) as executor:
            futures = {
                executor.submit(self._run_module, module, data, module_name): module_name
@@ -344,21 +347,21 @@ class EvaluationEngine:
                        "message": str(e),
                        "timestamp": datetime.now().isoformat(),
                    }
-
+
        return raw_results
-
+
    def _collect_custom_metrics(self, data: Any) -> Dict[str, Dict]:
        """Collect custom metric results"""
        custom_metrics = self.metric_loader.get_custom_metrics()
        if not custom_metrics:
            return {}
-
+
        custom_results = {}
-
+
        for metric_key, metric_info in custom_metrics.items():
            try:
                level1, level2, level3 = metric_key.split('.')
-
+
                if metric_info['type'] == 'class':
                    metric_class = metric_info['class']
                    metric_instance = metric_class(data)
@@ -366,22 +369,22 @@ class EvaluationEngine:
                else:
                    module = metric_info['module']
                    metric_result = module.evaluate(data)
-
+
                if level1 not in custom_results:
                    custom_results[level1] = {}
                custom_results[level1] = metric_result
-
+
                self.logger.info(f"Calculated custom metric: {level1}.{level2}.{level3}")
-
+
            except Exception as e:
                self.logger.error(f"Custom metric {metric_key} failed: {str(e)}")
-
+
                try:
                    level1, level2, level3 = metric_key.split('.')
-
+
                    if level1 not in custom_results:
                        custom_results[level1] = {}
-
+
                    custom_results[level1] = {
                        "status": "error",
                        "message": str(e),
@@ -389,9 +392,9 @@ class EvaluationEngine:
                    }
                except Exception:
                    pass
-
+
        return custom_results
-
+
    def _process_merged_results(self, raw_results: Dict, custom_results: Dict) -> Dict:
        """Process the merged evaluation results"""
        from modules.lib.score import Score
@@ -417,14 +420,14 @@ class EvaluationEngine:
                final_results[level1] = self._format_error(e)

        return final_results
-
+
    def _format_error(self, e: Exception) -> Dict:
        return {
            "status": "error",
            "message": str(e),
            "timestamp": datetime.now().isoformat()
        }
-
+
    def _run_module(self, module_class: Any, data: Any, module_name: str) -> Dict[str, Any]:
        """Run a single evaluation module"""
        try:
@@ -434,13 +437,14 @@ class EvaluationEngine:
            self.logger.error(f"{module_name} execution error: {str(e)}", exc_info=True)
            return {module_name: {"error": str(e)}}

+
class LoggingManager:
    """Logging management component"""
-
+
    def __init__(self, log_path: Path):
        self.log_path = log_path
        self.logger = self._init_logger()
-
+
    def _init_logger(self) -> logging.Logger:
        """Initialize the logging system"""
        try:
@@ -455,20 +459,21 @@ class LoggingManager:
            logger.addHandler(console_handler)
            logger.warning(f"Failed to init standard logger: {str(e)}, using fallback logger")
            return logger
-
+
    def get_logger(self) -> logging.Logger:
        return self.logger

+
class DataProcessor:
    """Data processing component"""
-
+
    def __init__(self, logger: logging.Logger, data_path: Path, config_path: Optional[Path] = None):
        self.logger = logger
        self.data_path = data_path
        self.config_path = config_path
        self.processor = self._load_processor()
        self.case_name = self.data_path.name
-
+
    def _load_processor(self) -> Any:
        """Load the data processor"""
        try:
@@ -477,7 +482,7 @@ class DataProcessor:
        except ImportError as e:
            self.logger.error(f"Failed to load data processor: {str(e)}")
            raise RuntimeError(f"Failed to load data processor: {str(e)}") from e
-
+
    def validate(self) -> None:
        """Validate the data path"""
        if not self.data_path.exists():
@@ -485,10 +490,11 @@ class DataProcessor:
        if not self.data_path.is_dir():
            raise NotADirectoryError(f"Invalid data directory: {self.data_path}")

+
class EvaluationPipeline:
    """Evaluation pipeline controller"""
-
-    def __init__(self, config_path: str, log_path: str, data_path: str, report_path: str,
+
+    def __init__(self, config_path: str, log_path: str, data_path: str, report_path: str,
                 custom_metrics_path: Optional[str] = None, custom_config_path: Optional[str] = None):
        # Path initialization
        self.config_path = Path(config_path) if config_path else None
@@ -496,7 +502,7 @@ class EvaluationPipeline:
        self.data_path = Path(data_path)
        self.report_path = Path(report_path)
        self.custom_metrics_path = Path(custom_metrics_path) if custom_metrics_path else None
-
+
        # Component initialization
        self.logging_manager = LoggingManager(Path(log_path))
        self.logger = self.logging_manager.get_logger()
@@ -507,50 +513,51 @@ class EvaluationPipeline:
        self.metric_loader.load_custom_metrics(self.custom_metrics_path)
        self.evaluation_engine = EvaluationEngine(self.logger, self.config_manager, self.metric_loader)
        self.data_processor = DataProcessor(self.logger, self.data_path, self.config_path)
-
+
    def execute(self) -> Dict[str, Any]:
        """Run the evaluation pipeline"""
        try:
            self.data_processor.validate()
-
+
            self.logger.info(f"Start evaluation: {self.data_path.name}")
            start_time = time.perf_counter()
            results = self.evaluation_engine.evaluate(self.data_processor.processor)
            elapsed_time = time.perf_counter() - start_time
            self.logger.info(f"Evaluation completed, time: {elapsed_time:.2f}s")
-
+
            report = self._generate_report(self.data_processor.case_name, results)
            return report
-
+
        except Exception as e:
            self.logger.critical(f"Evaluation failed: {str(e)}", exc_info=True)
            return {"error": str(e), "traceback": traceback.format_exc()}
-
+
    def _generate_report(self, case_name: str, results: Dict[str, Any]) -> Dict[str, Any]:
        """Generate the evaluation report"""
        from modules.lib.common import dict2json
-
+
        self.report_path.mkdir(parents=True, exist_ok=True)
-
+
        results["metadata"] = {
            "case_name": case_name,
            "timestamp": datetime.now().isoformat(),
            "version": "3.1.0",
        }
-
+
        report_file = self.report_path / f"{case_name}_report.json"
        dict2json(results, report_file)
        self.logger.info(f"Report generated: {report_file}")
-
+
        return results

+
def main():
    """Command-line entry point"""
    parser = argparse.ArgumentParser(
        description="Autonomous Driving Evaluation System V3.1",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
-
+
    parser.add_argument(
        "--logPath",
        type=str,
@@ -560,7 +567,7 @@ def main():
    parser.add_argument(
        "--dataPath",
        type=str,
-        default=r"D:\Kevin\zhaoyuan\data\V2V_CSAE53-2020_ForwardCollisionW_LST_01-01",
+        default=r"D:\Kevin\zhaoyuan\data\V2V_CSAE53-2020_ForwardCollision_LST_01-02",
        help="Input data directory",
    )
    parser.add_argument(
@@ -587,19 +594,19 @@ def main():
        default="config/custom_metrics_config.yaml",
        help="Custom metrics config path (optional)",
    )
-
+
    args = parser.parse_args()

    try:
        pipeline = EvaluationPipeline(
-            args.configPath,
-            args.logPath,
-            args.dataPath,
-            args.reportPath,
-            args.customMetricsPath,
+            args.configPath,
+            args.logPath,
+            args.dataPath,
+            args.reportPath,
+            args.customMetricsPath,
            args.customConfigPath
        )
-
+
        start_time = time.perf_counter()
        result = pipeline.execute()
        elapsed_time = time.perf_counter() - start_time
@@ -610,7 +617,7 @@ def main():

        print(f"Evaluation completed, total time: {elapsed_time:.2f}s")
        print(f"Report path: {pipeline.report_path}")
-
+
    except KeyboardInterrupt:
        print("\nUser interrupted")
        sys.exit(130)
@@ -619,6 +626,7 @@ def main():
        traceback.print_exc()
        sys.exit(1)

+
if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    main()