function.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. ##################################################################
  4. #
  5. # Copyright (c) 2025 CICV, Inc. All Rights Reserved
  6. #
  7. ##################################################################
  8. """
  9. @Authors: zhanghaiwen(zhanghaiwen@china-icv.cn)
  10. @Data: 2025/01/5
  11. @Last Modified: 2025/01/5
  12. @Summary: Function Metrics Calculation
  13. """
  14. from modules.lib.score import Score
  15. from modules.lib.log_manager import LogManager
  16. import numpy as np
  17. from typing import Dict, Tuple, Optional, Callable, Any
  18. import pandas as pd
  19. # ----------------------
  20. # 基础工具函数 (Pure functions)
  21. # ----------------------
  22. def calculate_distance(ego_pos: np.ndarray, obj_pos: np.ndarray) -> np.ndarray:
  23. """向量化距离计算"""
  24. return np.linalg.norm(ego_pos - obj_pos, axis=1)
  25. def calculate_relative_speed(ego_speed: np.ndarray, obj_speed: np.ndarray) -> np.ndarray:
  26. """向量化相对速度计算"""
  27. return np.linalg.norm(ego_speed - obj_speed, axis=1)
  28. def extract_ego_obj(data: pd.DataFrame) -> Tuple[pd.Series, pd.DataFrame]:
  29. """数据提取函数"""
  30. ego = data[data['playerId'] == 1].iloc[0]
  31. obj = data[data['playerId'] != 1]
  32. return ego, obj
  33. def get_first_warning(ego_df: pd.DataFrame, obj_df: pd.DataFrame) -> Optional[pd.DataFrame]:
  34. """带缓存的预警数据获取"""
  35. warning_times = ego_df[ego_df['ifwarning'] == 1]['simTime']
  36. if warning_times.empty:
  37. return None
  38. first_time = warning_times.iloc[0]
  39. return obj_df[obj_df['simTime'] == first_time]
  40. # ----------------------
  41. # 核心计算功能函数
  42. # ----------------------
  43. def latestWarningDistance(data_processed) -> dict:
  44. """预警距离计算流水线"""
  45. ego_df = data_processed.ego_data
  46. obj_df = data_processed.object_df
  47. warning_data = get_first_warning(ego_df, obj_df)
  48. if warning_data is None:
  49. return {"latestWarningDistance": 0.0}
  50. ego, obj = extract_ego_obj(warning_data)
  51. distances = calculate_distance(
  52. np.array([[ego['posX'], ego['posY']]]),
  53. obj[['posX', 'posY']].values
  54. )
  55. return {"latestWarningDistance": float(np.min(distances))}
  56. def latestWarningDistance_TTC(data_processed) -> dict:
  57. """TTC计算流水线"""
  58. ego_df = data_processed.ego_data
  59. obj_df = data_processed.object_df
  60. warning_data = get_first_warning(ego_df, obj_df)
  61. if warning_data is None:
  62. return {"latestWarningDistance_TTC": 0.0}
  63. ego, obj = extract_ego_obj(warning_data)
  64. # 向量化计算
  65. ego_pos = np.array([[ego['posX'], ego['posY']]])
  66. ego_speed = np.array([[ego['speedX'], ego['speedY']]])
  67. obj_pos = obj[['posX', 'posY']].values
  68. obj_speed = obj[['speedX', 'speedY']].values
  69. distances = calculate_distance(ego_pos, obj_pos)
  70. rel_speeds = calculate_relative_speed(ego_speed, obj_speed)
  71. with np.errstate(divide='ignore', invalid='ignore'):
  72. ttc = np.where(rel_speeds != 0, distances / rel_speeds, np.inf)
  73. return {"latestWarningDistance_TTC": float(np.nanmin(ttc))}
  74. class FunctionRegistry:
  75. """动态函数注册器(支持参数验证)"""
  76. def __init__(self, data_processed):
  77. self.logger = LogManager().get_logger() # 获取全局日志实例
  78. self.data = data_processed
  79. self.fun_config = data_processed.function_config["function"]
  80. self.level_3_merics = self._extract_level_3_metrics(self.fun_config)
  81. self._registry: Dict[str, Callable] = {}
  82. self._registry = self._build_registry()
  83. def _extract_level_3_metrics(self, config_node: dict) -> list:
  84. """DFS遍历提取第三层指标(时间复杂度O(n))[4](@ref)"""
  85. metrics = []
  86. def _recurse(node):
  87. if isinstance(node, dict):
  88. if 'name' in node and not any(isinstance(v, dict) for v in node.values()):
  89. metrics.append(node['name'])
  90. for v in node.values():
  91. _recurse(v)
  92. _recurse(config_node)
  93. self.logger.info(f'评比的功能指标列表:{metrics}')
  94. return metrics
  95. def _build_registry(self) -> dict:
  96. """自动注册指标函数(防御性编程)"""
  97. registry = {}
  98. for func_name in self.level_3_merics:
  99. try:
  100. registry[func_name] = globals()[func_name]
  101. except KeyError:
  102. print(f"未实现指标函数: {func_name}")
  103. self.logger.error(f"未实现指标函数: {func_name}")
  104. return registry
  105. def batch_execute(self) -> dict:
  106. """批量执行指标计算(带熔断机制)"""
  107. results = {}
  108. for name, func in self._registry.items():
  109. try:
  110. result = func(self.data) # 统一传递数据上下文
  111. results.update(result)
  112. except Exception as e:
  113. print(f"{name} 执行失败: {str(e)}")
  114. self.logger.error(f"{name} 执行失败: {str(e)}", exc_info=True)
  115. results[name] = None
  116. self.logger.info(f'功能指标计算结果:{results}')
  117. return results
  118. class FunctionManager:
  119. """管理功能指标计算的类"""
  120. def __init__(self, data_processed):
  121. self.data = data_processed
  122. self.function = FunctionRegistry(self.data)
  123. def report_statistic(self):
  124. """
  125. 计算并报告功能指标结果。
  126. :return: 评估结果
  127. """
  128. function_result = self.function.batch_execute()
  129. # evaluator = Score(self.data.function_config)
  130. # result = evaluator.evaluate(function_result)
  131. # return result
  132. return function_result
  133. # self.logger.info(f'Function Result: {function_result}')
# Script entry point: currently an empty placeholder (no usage example yet).
if __name__ == "__main__":
    pass
    # print("\n[功能类表现及得分情况]")  # disabled: header for a function-metric/score report