#!/usr/bin/env python
# -*- coding: utf-8 -*-
##################################################################
#
# Copyright (c) 2025 CICV, Inc. All Rights Reserved
#
##################################################################
"""
@Authors: zhanghaiwen(zhanghaiwen@china-icv.cn)
@Data: 2025/01/5
@Last Modified: 2025/01/5
@Summary: Function Metrics Calculation
"""

from typing import Dict, Tuple, Optional, Callable, Any

import numpy as np
import pandas as pd

from modules.lib.score import Score
from modules.lib.log_manager import LogManager


# ----------------------
# Basic helpers (pure functions)
# ----------------------
def calculate_distance(ego_pos: np.ndarray, obj_pos: np.ndarray) -> np.ndarray:
    """Return the row-wise Euclidean distance between ``ego_pos`` and ``obj_pos``.

    Both arguments must be 2-D (the norm is taken over ``axis=1``). A single-row
    argument broadcasts against a multi-row one, giving one distance per row.
    """
    return np.linalg.norm(ego_pos - obj_pos, axis=1)


def calculate_relative_speed(ego_speed: np.ndarray, obj_speed: np.ndarray) -> np.ndarray:
    """Return the row-wise magnitude of the velocity difference ``ego_speed - obj_speed``.

    Both arguments must be 2-D (the norm is taken over ``axis=1``). This is the
    length of the relative velocity vector, not the closing speed along the
    line between the two players.
    """
    return np.linalg.norm(ego_speed - obj_speed, axis=1)


def extract_ego_obj(data: pd.DataFrame) -> Tuple[pd.Series, pd.DataFrame]:
    """Split ``data`` into the ego row and the rows of every other player.

    The ego is the row with ``playerId == 1``. If several ego rows are present,
    the first one is returned.

    Raises:
        IndexError: if ``data`` contains no row with ``playerId == 1``.
    """
    ego_rows = data[data['playerId'] == 1]
    if ego_rows.empty:
        # ``.iloc[0]`` used to raise a bare IndexError here; keep the exception
        # type so existing callers are unaffected, but make the cause readable.
        raise IndexError("extract_ego_obj: no ego row (playerId == 1) in data")
    ego = ego_rows.iloc[0]
    obj = data[data['playerId'] != 1]
    return ego, obj


def get_first_warning(ego_df: pd.DataFrame, obj_df: pd.DataFrame) -> Optional[pd.DataFrame]:
    """Return the rows of ``obj_df`` recorded at the ego's first warning.

    The first warning is the first row of ``ego_df`` (in its current order) with
    ``ifwarning == 1``. Returns ``None`` when the ego never warns. Nothing is
    cached; each call filters the frames again.
    """
    warning_times = ego_df[ego_df['ifwarning'] == 1]['simTime']
    if warning_times.empty:
        return None
    first_time = warning_times.iloc[0]
    # NOTE(review): exact equality on simTime assumes ego_data and object_df
    # carry identical timestamp values - confirm against the data loader.
    return obj_df[obj_df['simTime'] == first_time]


def _ego_and_objects_at_first_warning(data_processed) -> Optional[Tuple[pd.Series, pd.DataFrame]]:
    """Return ``(ego_row, other_objects)`` at the first ego warning, or ``None`` if there is none."""
    warning_data = get_first_warning(data_processed.ego_data, data_processed.object_df)
    if warning_data is None:
        return None
    return extract_ego_obj(warning_data)


# ----------------------
# Metric functions (looked up by name from the function config)
# ----------------------
def latestWarningDistance(data_processed) -> dict:
    """Distance from the ego to the nearest other object at the first warning.

    Returns ``{"latestWarningDistance": value}``. The value is 0.0 when the ego
    never warns or when no other object is present at that moment.
    """
    snapshot = _ego_and_objects_at_first_warning(data_processed)
    if snapshot is None:
        return {"latestWarningDistance": 0.0}
    ego, obj = snapshot
    if obj.empty:
        # np.min on an empty array raises ValueError; report the same
        # "nothing to measure" default as the no-warning case instead.
        return {"latestWarningDistance": 0.0}

    distances = calculate_distance(
        np.array([[ego['posX'], ego['posY']]]),
        obj[['posX', 'posY']].values,
    )
    return {"latestWarningDistance": float(np.min(distances))}


def latestWarningDistance_TTC(data_processed) -> dict:
    """Smallest time-to-collision estimate over all other objects at the first warning.

    Per object, TTC = distance / |ego_velocity - object_velocity|. Objects with
    zero relative speed get ``inf``. Returns ``{"latestWarningDistance_TTC": value}``.
    The value is 0.0 when the ego never warns or when no other object is present.
    """
    snapshot = _ego_and_objects_at_first_warning(data_processed)
    if snapshot is None:
        return {"latestWarningDistance_TTC": 0.0}
    ego, obj = snapshot
    if obj.empty:
        # np.nanmin on an empty array raises ValueError; use the default instead.
        return {"latestWarningDistance_TTC": 0.0}

    ego_pos = np.array([[ego['posX'], ego['posY']]])
    ego_speed = np.array([[ego['speedX'], ego['speedY']]])
    obj_pos = obj[['posX', 'posY']].values
    obj_speed = obj[['speedX', 'speedY']].values

    distances = calculate_distance(ego_pos, obj_pos)
    rel_speeds = calculate_relative_speed(ego_speed, obj_speed)

    # Silence divide-by-zero warnings; zero relative speed is mapped to inf.
    with np.errstate(divide='ignore', invalid='ignore'):
        ttc = np.where(rel_speeds != 0, distances / rel_speeds, np.inf)
    # NOTE(review): |v_rel| is not the closing speed, so a receding object can
    # still produce a finite TTC here - confirm this is the intended definition.
    return {"latestWarningDistance_TTC": float(np.nanmin(ttc))}


class FunctionRegistry:
    """Resolve configured metric names to module-level functions and run them."""

    def __init__(self, data_processed):
        self.logger = LogManager().get_logger()  # shared global logger instance
        self.data = data_processed
        self.fun_config = data_processed.function_config["function"]
        # Attribute keeps its original "merics" spelling for backward compatibility.
        self.level_3_merics = self._extract_level_3_metrics(self.fun_config)
        self._registry: Dict[str, Callable] = self._build_registry()

    def _extract_level_3_metrics(self, config_node: dict) -> list:
        """Collect the ``name`` of every leaf metric node in ``config_node``.

        A dict counts as a leaf metric when it has a ``name`` key and none of its
        values is itself a dict. The walk is depth-first over nested dict values
        only (list values are not descended into). Names are returned in
        traversal order and logged.
        """
        metrics = []

        def _recurse(node):
            if not isinstance(node, dict):
                return
            if 'name' in node and not any(isinstance(v, dict) for v in node.values()):
                metrics.append(node['name'])
            for child in node.values():
                _recurse(child)

        _recurse(config_node)
        self.logger.info(f'评比的功能指标列表:{metrics}')
        return metrics

    def _build_registry(self) -> dict:
        """Map each configured metric name to the module-level callable of that name.

        Names with no matching module-level callable (missing, or bound to a
        module/constant) are reported and skipped rather than registered.
        """
        registry = {}
        module_namespace = globals()
        for func_name in self.level_3_merics:
            func = module_namespace.get(func_name)
            if not callable(func):
                print(f"未实现指标函数: {func_name}")
                self.logger.error(f"未实现指标函数: {func_name}")
                continue
            registry[func_name] = func
        return registry

    def batch_execute(self) -> dict:
        """Run every registered metric and merge their result dicts.

        Each metric receives the same data object. A metric that raises is logged
        with its traceback and recorded as ``results[name] = None``; the remaining
        metrics still run.
        """
        results = {}
        for name, func in self._registry.items():
            try:
                result = func(self.data)
                results.update(result)
            except Exception as e:
                # Isolate failures so one broken metric does not abort the batch.
                print(f"{name} 执行失败: {str(e)}")
                self.logger.error(f"{name} 执行失败: {str(e)}", exc_info=True)
                results[name] = None
        self.logger.info(f'功能指标计算结果:{results}')
        return results


class FunctionManager:
    """Entry point for computing the function metrics of a processed data set."""

    def __init__(self, data_processed):
        self.data = data_processed
        self.function = FunctionRegistry(self.data)

    def report_statistic(self):
        """Compute the function metrics and return them.

        :return: the merged metric dict produced by ``FunctionRegistry.batch_execute``.
        """
        function_result = self.function.batch_execute()
        # TODO: scoring via Score(self.data.function_config).evaluate(function_result)
        # is currently disabled; the raw metric values are returned instead.
        return function_result


if __name__ == "__main__":
    pass