"""指标注册系统模块 此模块提供了指标注册和管理的基础设施,包括BaseMetric基类和MetricRegistry类。 所有自定义指标都应该继承BaseMetric基类,并实现calculate方法。 """ from typing import Dict, Any, List, Type, Optional import logging import inspect import importlib.util from pathlib import Path class BaseMetric: """指标基类 所有自定义指标都应该继承此类,并实现calculate方法。 """ def __init__(self, data: Any): """初始化指标 Args: data: 输入数据,包含场景、轨迹等信息 """ self.data = data def calculate(self) -> Dict[str, Any]: """计算指标 Returns: 计算结果字典,包含指标值、评分和详细信息 """ raise NotImplementedError("子类必须实现calculate方法") class MetricRegistry: """指标注册管理器 负责注册和管理所有可用的指标(内置和自定义) """ def __init__(self, logger: Optional[logging.Logger] = None): """初始化注册管理器 Args: logger: 日志记录器,如果为None则创建一个默认的记录器 """ self.metrics: Dict[str, Type[BaseMetric]] = {} self.logger = logger or logging.getLogger(__name__) def register(self, metric_key: str, metric_class: Type[BaseMetric]) -> None: """注册指标类 Args: metric_key: 指标键名,通常为'level1.level2.level3'格式 metric_class: 指标类,必须是BaseMetric的子类 """ if not issubclass(metric_class, BaseMetric): raise TypeError(f"指标类 {metric_class.__name__} 必须继承BaseMetric") self.metrics[metric_key] = metric_class self.logger.info(f"已注册指标: {metric_key}") def get_metric(self, metric_key: str) -> Optional[Type[BaseMetric]]: """获取指标类 Args: metric_key: 指标键名 Returns: 指标类,如果不存在则返回None """ return self.metrics.get(metric_key) def get_all_metrics(self) -> Dict[str, Type[BaseMetric]]: """获取所有注册的指标类 Returns: 指标类字典 """ return self.metrics def load_metrics_from_directory(self, directory_path: Path) -> List[str]: """从目录加载指标类 Args: directory_path: 指标脚本目录路径 Returns: 加载成功的指标键名列表 """ if not directory_path.exists() or not directory_path.is_dir(): self.logger.warning(f"指标目录不存在: {directory_path}") return [] loaded_metrics = [] for py_file in directory_path.glob("*.py"): try: # 动态导入模块 module_name = f"custom_metric_{py_file.stem}" spec = importlib.util.spec_from_file_location(module_name, py_file) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) # 查找模块中的BaseMetric子类 for name, obj in inspect.getmembers(module): if (inspect.isclass(obj) and issubclass(obj, BaseMetric) and obj != BaseMetric): # 获取指标类别 category = getattr(module, 'METRIC_CATEGORY', 'custom') # 从文件名解析指标键名 if py_file.stem.startswith('metric_'): parts = py_file.stem[len('metric_'):].split('_') if len(parts) >= 3: level1 = parts[0] if category == 'custom' else category level2 = parts[1] level3 = parts[2] metric_key = f"{level1}.{level2}.{level3}" # 注册指标类 self.register(metric_key, obj) loaded_metrics.append(metric_key) # 一个文件只注册一个指标类 break except Exception as e: self.logger.error(f"加载指标文件失败 {py_file}: {str(e)}") return loaded_metrics