metric_registry.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. """指标注册系统模块
  2. 此模块提供了指标注册和管理的基础设施,包括BaseMetric基类和MetricRegistry类。
  3. 所有自定义指标都应该继承BaseMetric基类,并实现calculate方法。
  4. """
  5. from typing import Dict, Any, List, Type, Optional
  6. import logging
  7. import inspect
  8. import importlib.util
  9. from pathlib import Path
  10. class BaseMetric:
  11. """指标基类
  12. 所有自定义指标都应该继承此类,并实现calculate方法。
  13. """
  14. def __init__(self, data: Any):
  15. """初始化指标
  16. Args:
  17. data: 输入数据,包含场景、轨迹等信息
  18. """
  19. self.data = data
  20. def calculate(self) -> Dict[str, Any]:
  21. """计算指标
  22. Returns:
  23. 计算结果字典,包含指标值、评分和详细信息
  24. """
  25. raise NotImplementedError("子类必须实现calculate方法")
  26. class MetricRegistry:
  27. """指标注册管理器
  28. 负责注册和管理所有可用的指标(内置和自定义)
  29. """
  30. def __init__(self, logger: Optional[logging.Logger] = None):
  31. """初始化注册管理器
  32. Args:
  33. logger: 日志记录器,如果为None则创建一个默认的记录器
  34. """
  35. self.metrics: Dict[str, Type[BaseMetric]] = {}
  36. self.logger = logger or logging.getLogger(__name__)
  37. def register(self, metric_key: str, metric_class: Type[BaseMetric]) -> None:
  38. """注册指标类
  39. Args:
  40. metric_key: 指标键名,通常为'level1.level2.level3'格式
  41. metric_class: 指标类,必须是BaseMetric的子类
  42. """
  43. if not issubclass(metric_class, BaseMetric):
  44. raise TypeError(f"指标类 {metric_class.__name__} 必须继承BaseMetric")
  45. self.metrics[metric_key] = metric_class
  46. self.logger.info(f"已注册指标: {metric_key}")
  47. def get_metric(self, metric_key: str) -> Optional[Type[BaseMetric]]:
  48. """获取指标类
  49. Args:
  50. metric_key: 指标键名
  51. Returns:
  52. 指标类,如果不存在则返回None
  53. """
  54. return self.metrics.get(metric_key)
  55. def get_all_metrics(self) -> Dict[str, Type[BaseMetric]]:
  56. """获取所有注册的指标类
  57. Returns:
  58. 指标类字典
  59. """
  60. return self.metrics
  61. def load_metrics_from_directory(self, directory_path: Path) -> List[str]:
  62. """从目录加载指标类
  63. Args:
  64. directory_path: 指标脚本目录路径
  65. Returns:
  66. 加载成功的指标键名列表
  67. """
  68. if not directory_path.exists() or not directory_path.is_dir():
  69. self.logger.warning(f"指标目录不存在: {directory_path}")
  70. return []
  71. loaded_metrics = []
  72. for py_file in directory_path.glob("*.py"):
  73. try:
  74. # 动态导入模块
  75. module_name = f"custom_metric_{py_file.stem}"
  76. spec = importlib.util.spec_from_file_location(module_name, py_file)
  77. module = importlib.util.module_from_spec(spec)
  78. spec.loader.exec_module(module)
  79. # 查找模块中的BaseMetric子类
  80. for name, obj in inspect.getmembers(module):
  81. if (inspect.isclass(obj) and
  82. issubclass(obj, BaseMetric) and
  83. obj != BaseMetric):
  84. # 获取指标类别
  85. category = getattr(module, 'METRIC_CATEGORY', 'custom')
  86. # 从文件名解析指标键名
  87. if py_file.stem.startswith('metric_'):
  88. parts = py_file.stem[len('metric_'):].split('_')
  89. if len(parts) >= 3:
  90. level1 = parts[0] if category == 'custom' else category
  91. level2 = parts[1]
  92. level3 = parts[2]
  93. metric_key = f"{level1}.{level2}.{level3}"
  94. # 注册指标类
  95. self.register(metric_key, obj)
  96. loaded_metrics.append(metric_key)
  97. # 一个文件只注册一个指标类
  98. break
  99. except Exception as e:
  100. self.logger.error(f"加载指标文件失败 {py_file}: {str(e)}")
  101. return loaded_metrics