123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- ##################################################################
- #
- # Copyright (c) 2025 CICV, Inc. All Rights Reserved
- #
- ##################################################################
"""
@Authors: zhanghaiwen(zhanghaiwen@china-icv.cn)
@Date: 2025/01/5
@Last Modified: 2025/01/5
@Summary: Function Metrics Calculation
"""
- from modules.lib.score import Score
- from modules.lib.log_manager import LogManager
- import numpy as np
- from typing import Dict, Tuple, Optional, Callable, Any
- import pandas as pd
# ----------------------
# Basic utility functions (pure functions)
# ----------------------
def calculate_distance(ego_pos: np.ndarray, obj_pos: np.ndarray) -> np.ndarray:
    """Return the row-wise Euclidean distance between two position arrays.

    The element-wise difference ``ego_pos - obj_pos`` is formed first (normal
    NumPy broadcasting applies, so a single ego row can be compared against
    many object rows), then the L2 norm is taken along axis 1.

    :param ego_pos: 2-D array of ego coordinates, one position per row.
    :param obj_pos: 2-D array of object coordinates, one position per row.
    :return: 1-D array holding one distance per row of the difference.
    """
    offset = ego_pos - obj_pos
    return np.linalg.norm(offset, axis=1)
def calculate_relative_speed(ego_speed: np.ndarray, obj_speed: np.ndarray) -> np.ndarray:
    """Return the row-wise magnitude of the velocity difference.

    The result is the L2 norm (along axis 1) of ``ego_speed - obj_speed``,
    i.e. the length of the relative velocity vector for each row. NumPy
    broadcasting applies to the subtraction.

    :param ego_speed: 2-D array of ego velocity components, one vector per row.
    :param obj_speed: 2-D array of object velocity components, one vector per row.
    :return: 1-D array holding one relative-speed magnitude per row.
    """
    velocity_delta = ego_speed - obj_speed
    return np.linalg.norm(velocity_delta, axis=1)
def extract_ego_obj(data: pd.DataFrame) -> Tuple[pd.Series, pd.DataFrame]:
    """Split a frame into the ego row and the rows of every other player.

    Rows whose ``playerId`` equals 1 are treated as the ego vehicle; only the
    first such row is returned. All rows whose ``playerId`` differs from 1
    are returned together as the object frame.

    :param data: frame containing a ``playerId`` column.
    :return: ``(ego, obj)`` where ``ego`` is a Series (first ego row) and
        ``obj`` is a DataFrame of the non-ego rows.
    :raises IndexError: if no row has ``playerId == 1`` (raised by ``.iloc[0]``).
    """
    is_ego = data['playerId'] == 1
    ego_row = data[is_ego].iloc[0]
    other_rows = data[~is_ego]
    return ego_row, other_rows
def get_first_warning(ego_df: pd.DataFrame, obj_df: pd.DataFrame) -> Optional[pd.DataFrame]:
    """Return the object rows recorded at the first ego warning timestamp.

    The ego frame is filtered to rows with ``ifwarning == 1``; the ``simTime``
    of the first such row (in row order, no sorting is done) selects the
    matching rows of ``obj_df``.

    :param ego_df: ego frame with ``ifwarning`` and ``simTime`` columns.
    :param obj_df: object frame with a ``simTime`` column.
    :return: rows of ``obj_df`` whose ``simTime`` equals the first warning
        time, or ``None`` when the ego frame contains no warning row.
    """
    flagged_times = ego_df.loc[ego_df['ifwarning'] == 1, 'simTime']
    if flagged_times.empty:
        return None

    first_warning_time = flagged_times.iloc[0]
    at_first_warning = obj_df['simTime'] == first_warning_time
    return obj_df[at_first_warning]
# ----------------------
# Core metric calculation functions
# ----------------------
def latestWarningDistance(data_processed) -> dict:
    """Compute the smallest ego-to-object distance at the first warning.

    The object rows at the first warning timestamp are obtained through
    ``get_first_warning`` and split into ego / non-ego rows with
    ``extract_ego_obj``. The metric is the minimum Euclidean distance between
    the ego ``(posX, posY)`` and the ``(posX, posY)`` of every non-ego row.

    :param data_processed: object exposing ``ego_data`` and ``object_df``.
    :return: ``{"latestWarningDistance": value}``; the value is ``0.0`` when
        no warning was ever raised.
    """
    frame_at_warning = get_first_warning(
        data_processed.ego_data,
        data_processed.object_df,
    )
    if frame_at_warning is None:
        return {"latestWarningDistance": 0.0}

    ego_row, others = extract_ego_obj(frame_at_warning)
    # Shape (1, 2) so the single ego position broadcasts against all objects.
    ego_xy = np.array([[ego_row['posX'], ego_row['posY']]])
    others_xy = others[['posX', 'posY']].values
    nearest = np.min(calculate_distance(ego_xy, others_xy))
    return {"latestWarningDistance": float(nearest)}
def latestWarningDistance_TTC(data_processed) -> dict:
    """Compute the smallest time-to-collision value at the first warning.

    For every non-ego row at the first warning timestamp, the value is the
    Euclidean distance to the ego position divided by the magnitude of the
    velocity difference (``speedX``/``speedY``). Where that magnitude is
    exactly zero the value is taken as ``inf``. The metric is the NaN-ignoring
    minimum over all objects.

    :param data_processed: object exposing ``ego_data`` and ``object_df``.
    :return: ``{"latestWarningDistance_TTC": value}``; the value is ``0.0``
        when no warning was ever raised.
    """
    frame_at_warning = get_first_warning(
        data_processed.ego_data,
        data_processed.object_df,
    )
    if frame_at_warning is None:
        return {"latestWarningDistance_TTC": 0.0}

    ego_row, others = extract_ego_obj(frame_at_warning)

    # Ego values are wrapped as (1, 2) arrays so they broadcast over objects.
    ego_xy = np.array([[ego_row['posX'], ego_row['posY']]])
    ego_velocity = np.array([[ego_row['speedX'], ego_row['speedY']]])
    others_xy = others[['posX', 'posY']].values
    others_velocity = others[['speedX', 'speedY']].values

    gaps = calculate_distance(ego_xy, others_xy)
    closing = calculate_relative_speed(ego_velocity, others_velocity)

    # The division is evaluated for every element before np.where selects,
    # so divide-by-zero / invalid warnings are silenced here.
    with np.errstate(divide='ignore', invalid='ignore'):
        quotient = gaps / closing
        ttc_values = np.where(closing != 0, quotient, np.inf)

    return {"latestWarningDistance_TTC": float(np.nanmin(ttc_values))}
class FunctionRegistry:
    """Registry mapping configured metric names to module-level functions.

    The metric names are read from ``data_processed.function_config["function"]``;
    every name is resolved against this module's globals and, when a callable
    of that name exists, registered for later batch execution.
    """

    def __init__(self, data_processed):
        """Read the metric configuration and build the function registry.

        :param data_processed: object exposing ``function_config`` (a mapping
            with a ``"function"`` entry); it is also passed unchanged to every
            metric function in ``batch_execute``.
        """
        self.logger = LogManager().get_logger()  # project-wide logger instance
        self.data = data_processed
        self.fun_config = data_processed.function_config["function"]
        # NOTE: the attribute keeps its original spelling ("merics") so that
        # any external reader of this attribute keeps working.
        self.level_3_merics = self._extract_level_3_metrics(self.fun_config)
        self._registry: Dict[str, Callable] = self._build_registry()

    def _extract_level_3_metrics(self, config_node: dict) -> list:
        """Collect the names of all leaf metric nodes in the configuration.

        A dict counts as a leaf metric when it has a ``'name'`` key and none
        of its values is itself a dict. The walk is depth-first and descends
        into every dict value (values of other types are ignored).

        :param config_node: root of the nested configuration.
        :return: metric names in traversal order.
        """
        metrics = []

        def _recurse(node):
            if not isinstance(node, dict):
                return
            has_child_dict = any(isinstance(v, dict) for v in node.values())
            if 'name' in node and not has_child_dict:
                metrics.append(node['name'])
            for child in node.values():
                _recurse(child)

        _recurse(config_node)
        self.logger.info(f'评比的功能指标列表:{metrics}')
        return metrics

    def _build_registry(self) -> dict:
        """Resolve every configured metric name to a callable in this module.

        Names that are missing from the module globals, or that resolve to a
        non-callable object (for example an imported module that happens to
        share the metric's name), are reported and skipped instead of being
        registered, so they cannot fail later with an unrelated error.

        :return: mapping ``metric name -> callable``.
        """
        registry = {}
        module_scope = globals()
        for func_name in self.level_3_merics:
            candidate = module_scope.get(func_name)
            if not callable(candidate):
                print(f"未实现指标函数: {func_name}")
                self.logger.error(f"未实现指标函数: {func_name}")
                continue
            registry[func_name] = candidate
        return registry

    def batch_execute(self) -> dict:
        """Run every registered metric and merge the returned dicts.

        Each function receives ``self.data``. A metric that raises is logged
        (with traceback) and recorded as ``None`` under its registry name, so
        one failing metric does not abort the others.

        :return: merged results of all metric functions.
        """
        results = {}
        for name, func in self._registry.items():
            try:
                result = func(self.data)  # every metric gets the same data context
                results.update(result)
            except Exception as e:
                print(f"{name} 执行失败: {str(e)}")
                self.logger.error(f"{name} 执行失败: {str(e)}", exc_info=True)
                results[name] = None
        self.logger.info(f'功能指标计算结果:{results}')
        return results
class FunctionManager:
    """Entry point that runs the function-metric calculation for one data set."""

    def __init__(self, data_processed):
        """Store the processed data and build the metric registry for it.

        :param data_processed: processed data object handed to ``FunctionRegistry``.
        """
        self.data = data_processed
        self.function = FunctionRegistry(data_processed)

    def report_statistic(self):
        """Run all registered function metrics and return their raw results.

        NOTE: scoring through ``Score(self.data.function_config).evaluate(...)``
        is currently disabled, so the unscored metric dictionary produced by
        ``batch_execute`` is returned as-is.

        :return: dictionary of metric results.
        """
        metrics = self.function.batch_execute()
        return metrics
# Usage example: running this module as a script is intentionally a no-op.
if __name__ == "__main__":
    # print("\n[功能类表现及得分情况]")
    pass
-
|