123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131 |
- """指标注册系统模块
- 此模块提供了指标注册和管理的基础设施,包括BaseMetric基类和MetricRegistry类。
- 所有自定义指标都应该继承BaseMetric基类,并实现calculate方法。
- """
- from typing import Dict, Any, List, Type, Optional
- import logging
- import inspect
- import importlib.util
- from pathlib import Path
- class BaseMetric:
- """指标基类
-
- 所有自定义指标都应该继承此类,并实现calculate方法。
- """
-
- def __init__(self, data: Any):
- """初始化指标
-
- Args:
- data: 输入数据,包含场景、轨迹等信息
- """
- self.data = data
-
- def calculate(self) -> Dict[str, Any]:
- """计算指标
-
- Returns:
- 计算结果字典,包含指标值、评分和详细信息
- """
- raise NotImplementedError("子类必须实现calculate方法")
- class MetricRegistry:
- """指标注册管理器
-
- 负责注册和管理所有可用的指标(内置和自定义)
- """
-
- def __init__(self, logger: Optional[logging.Logger] = None):
- """初始化注册管理器
-
- Args:
- logger: 日志记录器,如果为None则创建一个默认的记录器
- """
- self.metrics: Dict[str, Type[BaseMetric]] = {}
- self.logger = logger or logging.getLogger(__name__)
-
- def register(self, metric_key: str, metric_class: Type[BaseMetric]) -> None:
- """注册指标类
-
- Args:
- metric_key: 指标键名,通常为'level1.level2.level3'格式
- metric_class: 指标类,必须是BaseMetric的子类
- """
- if not issubclass(metric_class, BaseMetric):
- raise TypeError(f"指标类 {metric_class.__name__} 必须继承BaseMetric")
-
- self.metrics[metric_key] = metric_class
- self.logger.info(f"已注册指标: {metric_key}")
-
- def get_metric(self, metric_key: str) -> Optional[Type[BaseMetric]]:
- """获取指标类
-
- Args:
- metric_key: 指标键名
-
- Returns:
- 指标类,如果不存在则返回None
- """
- return self.metrics.get(metric_key)
-
- def get_all_metrics(self) -> Dict[str, Type[BaseMetric]]:
- """获取所有注册的指标类
-
- Returns:
- 指标类字典
- """
- return self.metrics
-
- def load_metrics_from_directory(self, directory_path: Path) -> List[str]:
- """从目录加载指标类
-
- Args:
- directory_path: 指标脚本目录路径
-
- Returns:
- 加载成功的指标键名列表
- """
- if not directory_path.exists() or not directory_path.is_dir():
- self.logger.warning(f"指标目录不存在: {directory_path}")
- return []
-
- loaded_metrics = []
- for py_file in directory_path.glob("*.py"):
- try:
- # 动态导入模块
- module_name = f"custom_metric_{py_file.stem}"
- spec = importlib.util.spec_from_file_location(module_name, py_file)
- module = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(module)
-
- # 查找模块中的BaseMetric子类
- for name, obj in inspect.getmembers(module):
- if (inspect.isclass(obj) and
- issubclass(obj, BaseMetric) and
- obj != BaseMetric):
-
- # 获取指标类别
- category = getattr(module, 'METRIC_CATEGORY', 'custom')
-
- # 从文件名解析指标键名
- if py_file.stem.startswith('metric_'):
- parts = py_file.stem[len('metric_'):].split('_')
- if len(parts) >= 3:
- level1 = parts[0] if category == 'custom' else category
- level2 = parts[1]
- level3 = parts[2]
- metric_key = f"{level1}.{level2}.{level3}"
-
- # 注册指标类
- self.register(metric_key, obj)
- loaded_metrics.append(metric_key)
-
- # 一个文件只注册一个指标类
- break
- except Exception as e:
- self.logger.error(f"加载指标文件失败 {py_file}: {str(e)}")
-
- return loaded_metrics
|