#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Safety metric computation module (multi-target support) - scenario-aware version.

Design notes:
1. Code restructured for readability.
2. Redundant code removed, performance improved.
3. Original calculation logic kept, except for the fixes noted inline
   (collision-severity integration bound, fatality-probability speed sign).
4. Comprehensive comments.
5. Streamlined data processing flow.
"""

import os
import numpy as np
import pandas as pd
import math
import scipy.integrate as spi
from collections import defaultdict
from typing import Dict, Any, List, Optional, Tuple
from pathlib import Path

from modules.lib.score import Score
from modules.lib.log_manager import LogManager
from modules.lib.chart_generator import generate_safety_chart_data


# ==================== Safety metric entry points ====================
# One function per safety metric; SafetyRegistry looks them up by the
# name ``calculate_<metric name lower-cased>``.
# ---------------------------------------------------------------------

def calculate_ttc(data_processed, plot_path) -> dict:
    """Compute time to collision (TTC)."""
    return _calculate_metric('TTC', data_processed, plot_path)


def calculate_mttc(data_processed, plot_path) -> dict:
    """Compute modified time to collision (MTTC)."""
    return _calculate_metric('MTTC', data_processed, plot_path)


def calculate_thw(data_processed, plot_path) -> dict:
    """Compute time headway (THW)."""
    return _calculate_metric('THW', data_processed, plot_path)


def calculate_ttb(data_processed, plot_path) -> dict:
    """Compute time to brake (TTB)."""
    return _calculate_metric('TTB', data_processed, plot_path)


def calculate_tm(data_processed, plot_path) -> dict:
    """Compute time margin (TM)."""
    return _calculate_metric('TM', data_processed, plot_path)


def calculate_dtc(data_processed, plot_path) -> dict:
    """Compute distance to collision (DTC)."""
    return _calculate_metric('DTC', data_processed, plot_path)


def calculate_psd(data_processed, plot_path) -> dict:
    """Compute the proportion-of-stopping-distance ratio (PSD)."""
    return _calculate_metric('PSD', data_processed, plot_path)


def calculate_collisionrisk(data_processed, plot_path) -> dict:
    """Compute collision risk (collisionRisk)."""
    return _calculate_metric('collisionRisk', data_processed, plot_path)


def calculate_lonsd(data_processed, plot_path) -> dict:
    """Compute longitudinal safe distance (LonSD)."""
    return _calculate_metric('LonSD', data_processed, plot_path)


def calculate_latsd(data_processed, plot_path) -> dict:
    """Compute lateral safe distance (LatSD)."""
    return _calculate_metric('LatSD', data_processed, plot_path)


def calculate_btn(data_processed, plot_path) -> dict:
    """Compute brake threat number (BTN)."""
    return _calculate_metric('BTN', data_processed, plot_path)


def calculate_collisionseverity(data_processed, plot_path) -> dict:
    """Compute collision severity (collisionSeverity)."""
    return _calculate_metric('collisionSeverity', data_processed, plot_path)


def _calculate_metric(metric_name, data_processed, plot_path) -> dict:
    """Shared driver behind every ``calculate_*`` entry point.

    Builds a :class:`SafetyCalculator`, reads the aggregated value of
    ``metric_name`` and, if per-frame data for that metric was collected,
    generates its chart under ``plot_path``.

    Args:
        metric_name: Metric key as used in ``SafetyCalculator.DEFAULT_VALUES``.
        data_processed: Object exposing ``object_df``, ``ego_data``,
            ``obj_id_list`` and ``vehicle_config``.
        plot_path: Output directory for charts (falls back to ``./data``).

    Returns:
        ``{metric_name: value}``; the value is ``None`` when the input is
        unusable or the computation raised (the exception is logged).
    """
    result_key = {metric_name: None}

    # SafetyCalculator reads object_df; bail out early when it is missing.
    if data_processed is None or not hasattr(data_processed, 'object_df'):
        return result_key

    try:
        # NOTE(review): every calculate_* call builds a fresh SafetyCalculator,
        # so the full per-frame pass is repeated once per configured metric.
        # Consider sharing one calculator per data_processed if this is slow.
        safety = SafetyCalculator(data_processed)

        # collisionRisk / collisionSeverity getters do not follow the
        # get_<lower-cased name>_value naming convention.
        if metric_name.lower() == 'collisionrisk':
            metric_value = safety.get_collision_risk_value()
        elif metric_name.lower() == 'collisionseverity':
            metric_value = safety.get_collision_severity_value()
        else:
            metric_value = getattr(safety, f'get_{metric_name.lower()}_value')()

        # _get_metric_value stores the per-frame rows in <name lower>_data.
        metric_data = getattr(safety, f'{metric_name.lower()}_data', None)
        if metric_data:
            safety.generate_metric_chart(metric_name, plot_path)

        LogManager().get_logger().info(f"安全指标[{metric_name}]计算结果: {metric_value}")
        return {metric_name: metric_value}

    except Exception as e:
        LogManager().get_logger().error(f"{metric_name}计算异常: {str(e)}", exc_info=True)
        return result_key


# ==================== Registry and batch execution ====================
# Runs the metrics listed in the safety configuration in one batch.
# ----------------------------------------------------------------------

class SafetyRegistry:
    """Maps configured safety metric names to their ``calculate_*`` functions and runs them."""

    def __init__(self, data_processed, plot_path):
        self.logger = LogManager().get_logger()
        self.data = data_processed
        self.plot_path = plot_path

        # Without a safety config there is nothing to register.
        if not hasattr(data_processed, 'safety_config') or not data_processed.safety_config:
            self.logger.warning("安全配置为空,跳过安全指标计算")
            self.safety_config = {}
            self.metrics = []
            self._registry = {}
            return

        self.safety_config = data_processed.safety_config.get("safety", {})
        self.metrics = self._extract_metrics(self.safety_config)
        self._registry = self._build_registry()

    def _extract_metrics(self, config_node: dict) -> list:
        """Recursively collect metric names from the config tree.

        A dict counts as a metric leaf when it has a ``'name'`` key and none
        of its values is itself a dict.
        """
        metrics = []

        def _recurse(node):
            if isinstance(node, dict):
                if 'name' in node and not any(isinstance(v, dict) for v in node.values()):
                    metrics.append(node['name'])
                for v in node.values():
                    _recurse(v)

        _recurse(config_node)
        self.logger.info(f'评比的安全指标列表: {metrics}')
        return metrics

    def _build_registry(self) -> dict:
        """Map each configured metric name to the module-level ``calculate_<name lower>`` function."""
        registry = {}
        for metric_name in self.metrics:
            func_name = f"calculate_{metric_name.lower()}"
            if func_name in globals():
                registry[metric_name] = globals()[func_name]
            else:
                self.logger.warning(f"未实现安全指标函数: {func_name}")
        return registry

    def batch_execute(self) -> dict:
        """Run every registered metric and merge their ``{name: value}`` results.

        A metric whose function raises is logged and reported as ``None``.
        """
        results = {}
        # __init__ always sets safety_config, so no hasattr check is needed.
        if not self.safety_config or not self._registry:
            self.logger.info("安全配置为空或无注册指标,返回空结果")
            return results

        for name, func in self._registry.items():
            try:
                result = func(self.data, self.plot_path)
                results.update(result)
            except Exception as e:
                self.logger.error(f"{name} 执行失败: {str(e)}", exc_info=True)
                results[name] = None
        self.logger.info(f'安全指标计算结果: {results}')
        return results


class SafeManager:
    """Facade that wires a SafetyRegistry to the processed data and exposes a report call."""

    def __init__(self, data_processed, plot_path):
        self.data = data_processed
        self.logger = LogManager().get_logger()
        self.plot_path = plot_path

        # Only build a registry when a safety config is present.
        if not hasattr(data_processed, 'safety_config') or not data_processed.safety_config:
            self.logger.warning("安全配置为空,跳过安全指标计算初始化")
            self.registry = None
        else:
            self.registry = SafetyRegistry(self.data, self.plot_path)

    def report_statistic(self):
        """Return ``{metric: value}`` for all configured metrics, or ``{}`` without a registry."""
        if self.registry is None:
            self.logger.info("安全指标管理器未初始化,返回空结果")
            return {}
        return self.registry.batch_execute()


# ==================== Core calculator ====================
# Concrete implementation of every safety metric.
# ---------------------------------------------------------

class SafetyCalculator:
    """Per-frame, per-target safety metric calculator.

    On construction it walks every frame in ``ego_data``, evaluates all
    metrics for each non-ego target present in that frame and stores the
    rows in ``self.df_safe``; the ``get_*_value`` methods then aggregate them.
    """

    # Values reported when a target is irrelevant or no data is available.
    DEFAULT_VALUES = {
        'TTC': 10.0,
        'MTTC': 3.3,
        'THW': 2.5,
        'TTB': 10.0,
        'TM': 10.0,
        'DTC': 10.0,
        'PSD': 10.0,
        'LonSD': 100.0,
        'LatSD': 2.0,
        'BTN': 0.0,
        'collisionRisk': 0.0,
        'collisionSeverity': 0.0
    }

    def __init__(self, data_processed):
        self.logger = LogManager().get_logger()
        self.data_processed = data_processed
        self.df = data_processed.object_df.copy()
        self.ego_df = data_processed.ego_data.copy()
        self.obj_id_list = data_processed.obj_id_list

        self._init_config_params()
        self._init_data_structures()

        if self.obj_id_list:
            # obj_id_list includes the ego vehicle, hence the -1 in the log.
            self.logger.info("开始执行安全参数计算,目标物数量: %d", len(self.obj_id_list) - 1)
            self._calculate_safety_parameters()
            self.logger.info("安全参数计算完成")
        else:
            self.logger.info("没有目标物,跳过安全参数计算")
            self.empty_flag = True

    def _init_config_params(self):
        """Read vehicle parameters from ``vehicle_config``, with defaults."""
        config = self.data_processed.vehicle_config
        self.rho = config.get("RHO", 1.5)  # reaction time, default 1.5 s
        self.ego_accel_max = config.get("EGO_ACCEL_MAX", 3.0)  # ego max acceleration, default 3.0 m/s²
        self.obj_decel_max = config.get("OBJ_DECEL_MAX", 3.0)  # target max deceleration, default 3.0 m/s²
        self.ego_decel_min = config.get("EGO_DECEL_MIN", 4.0)  # ego minimum deceleration, default 4.0 m/s²
        self.ego_decel_lon_max = config.get("EGO_DECEL_LON_MAX", 4.0)  # ego max longitudinal decel, default 4.0 m/s²
        self.ego_decel_lat_max = config.get("EGO_DECEL_LAT_MAX", 3.0)  # ego max lateral decel, default 3.0 m/s²
        self.ego_width = config.get("CAR_WIDTH", 2.0)  # ego width, default 2.0 m
        self.ego_length = config.get("CAR_LENGTH", 4.5)  # ego length, default 4.5 m
        self.vehicle_length = config.get("VEHICLE_LENGTH", 4.5)  # target vehicle length, default 4.5 m
        self.vehicle_width = config.get("VEHICLE_WIDTH", 2.0)  # target vehicle width, default 2.0 m
        self.pedestrian_length = config.get("PEDESTRIAN_LENGTH", 0.5)  # pedestrian length, default 0.5 m
        self.pedestrian_width = config.get("PEDESTRIAN_WIDTH", 0.5)  # pedestrian width, default 0.5 m

        # Derived parameters
        self.max_deceleration = self.ego_decel_lon_max
        # Combined (vector) deceleration limit from the lon/lat limits.
        self.ego_decel_max = np.sqrt(self.ego_decel_lon_max ** 2 + self.ego_decel_lat_max ** 2)
        self.ped_width = self.pedestrian_width
        self.ped_length = self.pedestrian_length

    def _init_data_structures(self):
        """Initialize frame lists, result containers and the default output directory."""
        self.time_list = self.ego_df['simTime'].values.tolist()
        self.frame_list = self.ego_df['simFrame'].values.tolist()
        self.empty_flag = True  # cleared once at least one valid target row exists
        # Defined up front so _get_metric_value never hits a missing attribute.
        self.df_safe = None

        self._init_metric_storage()

        # Default chart output directory (overridden in generate_metric_chart).
        self.output_dir = os.path.join(os.getcwd(), 'data')
        os.makedirs(self.output_dir, exist_ok=True)

    def _init_metric_storage(self):
        """Create an empty ``<metric>_data`` list attribute for each metric."""
        metrics = [
            'ttc', 'mttc', 'thw', 'ttb', 'tm',
            'lonsd', 'latsd', 'btn', 'collision_risk', 'collision_severity'
        ]
        for metric in metrics:
            setattr(self, f'{metric}_data', [])

    def _calculate_safety_parameters(self):
        """Evaluate all targets in every frame and build ``self.df_safe``."""
        obj_dict = self._preprocess_object_data()
        df_list = []
        EGO_PLAYER_ID = 1
        found_valid_target = False

        for frame_num in self.frame_list:
            try:
                ego_data = obj_dict[frame_num][EGO_PLAYER_ID]
            except KeyError:
                # Ego not present in this frame: nothing to compare against.
                continue

            frame_targets = self._process_frame_targets(frame_num, ego_data, obj_dict, EGO_PLAYER_ID)

            if frame_targets:
                found_valid_target = True
                df_list.append(pd.DataFrame(frame_targets))

        self._postprocess_results(df_list, found_valid_target)

    def _preprocess_object_data(self):
        """Index object rows as ``{simFrame: {playerId: row_dict}}``."""
        obj_dict = defaultdict(dict)
        for item in self.df.to_dict('records'):
            obj_dict[item['simFrame']][item['playerId']] = item
        return obj_dict

    def _process_frame_targets(self, frame_num, ego_data, obj_dict, ego_id):
        """Compute metrics for every non-ego target present in ``frame_num``.

        The metric values are merged into each target's row dict in place.
        """
        frame_targets = []
        vx_ego = ego_data.get('lon_v_vehicle', 0)
        vy_ego = ego_data.get('lat_v_vehicle', 0)
        # Per-frame lane width, read later by the target filters.
        self.lane_width = ego_data.get('lane_width', 3.75)

        for player_id in self.obj_id_list:
            if player_id == ego_id:
                continue

            try:
                obj_data = obj_dict[frame_num][player_id]
            except KeyError:
                continue

            result = self._calculate_target_metrics(ego_data, obj_data, vx_ego, vy_ego)
            obj_data.update(result)
            frame_targets.append(obj_data)

        return frame_targets

    def _calculate_target_metrics(self, ego_data, obj_data, vx_ego, vy_ego):
        """Return the metric dict for one target in one frame (defaults if filtered out)."""
        obj_type = obj_data.get('type', 0)
        lon_d = obj_data.get('x_relative_start_dist', 0)
        lat_d = obj_data.get('y_relative_start_dist', 0)
        vx_obj = obj_data.get('lon_v_vehicle', 0)
        vy_obj = obj_data.get('lat_v_vehicle', 0)

        # Relative motion
        vx_rel = vx_obj - vx_ego
        vy_rel = vy_obj - vy_ego
        dist = math.sqrt(lon_d ** 2 + lat_d ** 2)

        # Heading difference folded into [0, 180] (assumes posH is in degrees).
        h1 = ego_data['posH']
        h2 = obj_data.get('posH', h1)
        heading_diff = abs(h1 - h2)
        if heading_diff > 180:
            heading_diff = 360 - heading_diff

        # Relative acceleration
        ax_ego = ego_data.get('lon_acc_vehicle', 0)
        ay_ego = ego_data.get('lat_acc_vehicle', 0)
        ax_obj = obj_data.get('lon_acc_vehicle', 0)
        ay_obj = obj_data.get('lat_acc_vehicle', 0)
        ax_rel_ego = ax_obj - ax_ego
        ay_rel_ego = ay_obj - ay_ego

        # Projections onto the ego->target line of sight
        vrel_projection = self._cal_v_projection_using_relative(lon_d, lat_d, vx_rel, vy_rel)
        accl_projection = self._cal_a_projection_using_relative(lon_d, lat_d, ax_rel_ego, ay_rel_ego)

        is_relevant = self._is_relevant_target(lat_d, lon_d, vy_rel, heading_diff, obj_type)

        # Skip risk evaluation for targets already past the conflict point
        # that are moving away or (nearly) stationary.
        passed_conflict = self._has_passed_conflict_point(lon_d, lat_d, obj_type, heading_diff)
        moving_away = self._is_moving_away(
            (vx_ego, vy_ego), (vx_obj, vy_obj), lon_d, lat_d, obj_type, heading_diff
        )
        skip_risk_calculation = passed_conflict and (moving_away or obj_data.get('v', 0) < 0.05)

        if not is_relevant or skip_risk_calculation:
            return self._get_default_metrics(obj_type)
        return self._calculate_relevant_metrics(
            dist, vrel_projection, accl_projection, lon_d, lat_d,
            vx_ego, vy_ego, vx_rel, ax_rel_ego, heading_diff,
            obj_type, ego_data, obj_data
        )

    def _get_default_metrics(self, obj_type):
        """Default metric dict; vehicle-only metrics are ``None`` for obj_type 5."""
        return {
            'TTC': self.DEFAULT_VALUES['TTC'],
            'MTTC': self.DEFAULT_VALUES['MTTC'],
            'THW': self.DEFAULT_VALUES['THW'] if obj_type != 5 else None,
            'TTB': self.DEFAULT_VALUES['TTB'] if obj_type != 5 else None,
            'TM': self.DEFAULT_VALUES['TM'] if obj_type != 5 else None,
            'PSD': self.DEFAULT_VALUES['PSD'],
            'DTC': self.DEFAULT_VALUES['DTC'],
            'LonSD': self.DEFAULT_VALUES['LonSD'] if obj_type != 5 else None,
            'LatSD': self.DEFAULT_VALUES['LatSD'] if obj_type != 5 else None,
            'BTN': self.DEFAULT_VALUES['BTN'] if obj_type != 5 else None,
            'collisionSeverity': self.DEFAULT_VALUES['collisionSeverity'],
            'pr_death': 0,
            'collisionRisk': self.DEFAULT_VALUES['collisionRisk']
        }

    def _calculate_relevant_metrics(self, dist, vrel_projection, accl_projection,
                                    lon_d, lat_d, vx_ego, vy_ego, vx_rel, ax_rel_ego,
                                    heading_diff, obj_type, ego_data, obj_data):
        """Compute every metric for a target that passed the relevance filters.

        collisionSeverity, pr_death and collisionRisk are returned as
        percentages (0-100).
        """
        Tc = 0.3  # reaction time

        metrics = {
            'TTC': self._cal_TTC(dist, vrel_projection),
            'MTTC': self._cal_MTTC(dist, vrel_projection, ego_data.get('lon_acc_vehicle', 0)),
            'THW': self._cal_THW(lon_d, vx_ego),
            'TTB': self._cal_TTB(dist, vrel_projection, self.ego_decel_max) if obj_type != 5 else None,
            'TM': self._cal_TM(vrel_projection, obj_data.get('v', 0), obj_data.get('lon_acc_vehicle', 0),
                               ego_data.get('v', 0), ego_data.get('lon_acc_vehicle', 0))
            if obj_type != 5 else None,
            'PSD': self._cal_PSD(dist, vrel_projection, self.max_deceleration, heading_diff,
                                 vx_ego, vy_ego, obj_type, lon_d, lat_d),
            'DTC': self._cal_DTC(vx_rel, ax_rel_ego, self.rho)
            if abs(heading_diff) < 30 else self.DEFAULT_VALUES['DTC'],
            'LonSD': self._cal_longitudinal_safe_dist(vx_ego, vx_rel, self.rho,
                                                     self.ego_decel_min, self.obj_decel_max)
            if obj_type != 5 else None,
            'LatSD': self._cal_lateral_safe_dist(abs(lat_d), vy_ego, self.ego_width,
                                                 self.lane_width, self.ego_decel_lat_max)
            if obj_type != 5 else None,
            # NOTE(review): the ego longitudinal acceleration is passed for both
            # lon_a1 and lon_a; one of them may have been meant to be the target's
            # (or the relative) acceleration — confirm against the BTN definition.
            'BTN': self._cal_BTN_new(ego_data.get('lon_acc_vehicle', 0), ego_data.get('lon_acc_vehicle', 0),
                                     abs(lon_d), vx_ego, self.ego_decel_lon_max)
            if obj_type != 5 else None
        }

        ttc = metrics['TTC']
        if ttc is None or ttc > 4000:
            metrics.update({
                'collisionSeverity': 0,
                'pr_death': 0,
                'collisionRisk': 0
            })
        else:
            try:
                # Severity is 1 minus the probability mass on [0, TTC - Tc].
                # When TTC <= Tc that interval is empty; integrating up to a
                # negative bound would return a negative mass and push
                # severity above 100%, so the bound is clamped at 0.
                upper = max(0.0, ttc - Tc)
                result, _ = spi.quad(self._normal_distribution, 0, upper)
                collisionSeverity = 1 - result
                pr_death = self._death_pr(obj_type, vrel_projection)
                collisionRisk = self._cal_collisionRisk_level(obj_type, vrel_projection, collisionSeverity)
                metrics.update({
                    'collisionSeverity': collisionSeverity * 100,
                    'pr_death': pr_death * 100,
                    'collisionRisk': collisionRisk * 100
                })
            except Exception as e:
                self.logger.error(f"碰撞风险计算错误: {e}")
                metrics.update({
                    'collisionSeverity': 0,
                    'pr_death': 0,
                    'collisionRisk': 0
                })

        return metrics

    def _postprocess_results(self, df_list, found_valid_target):
        """Concatenate per-frame frames into ``self.df_safe`` and set ``empty_flag``."""
        self.empty_flag = not found_valid_target

        if df_list:
            self.df_safe = pd.concat(df_list)
            col_list = ['simTime', 'simFrame', 'playerId',
                        'TTC', 'MTTC', 'THW', 'TTB', 'TM', 'DTC', 'PSD', 'LonSD', 'LatSD', 'BTN',
                        'collisionSeverity', 'pr_death', 'collisionRisk']
            self.df_safe = self.df_safe[col_list].reset_index(drop=True)
        else:
            self.df_safe = pd.DataFrame()
            self.empty_flag = True

        self.logger.info(f"处理完成,找到有效目标: {not self.empty_flag}")

    # ==================== Core formulas ====================

    def _cal_v_projection_using_relative(self, lon_d, lat_d, vx_rel, vy_rel):
        """Project the relative velocity onto the unit vector towards the target.

        Negative means the target is closing in; returns 0.0 at zero distance.
        """
        dist = math.sqrt(lon_d ** 2 + lat_d ** 2)
        if dist < 1e-6:
            return 0.0
        U_ABx = lon_d / dist
        U_ABy = lat_d / dist
        return vx_rel * U_ABx + vy_rel * U_ABy

    def _cal_a_projection_using_relative(self, lon_d, lat_d, ax_rel, ay_rel):
        """Project the relative acceleration onto the unit vector towards the target."""
        dist = math.sqrt(lon_d ** 2 + lat_d ** 2)
        if dist < 1e-6:
            return 0.0
        U_ABx = lon_d / dist
        U_ABy = lat_d / dist
        return ax_rel * U_ABx + ay_rel * U_ABy

    def _cal_TTC(self, dist, vrel_projection):
        """Time to collision: ``dist / |v|`` while closing, else ``None``."""
        if vrel_projection >= 0:  # not approaching
            return None
        return dist / abs(vrel_projection)

    def _cal_MTTC(self, dist, vrel_projection, a_ego):
        """Modified TTC including the ego acceleration.

        Returns ``None`` unless the target is closing, ``dist > 0`` and
        ``a_ego < 0``; otherwise the smallest positive root of the motion equation.
        """
        if vrel_projection >= 0 or dist <= 0 or a_ego >= 0:
            return None

        try:
            if abs(a_ego) < 1e-6:
                return dist / abs(vrel_projection) if abs(vrel_projection) > 1e-6 else None

            discriminant = vrel_projection ** 2 + 2 * a_ego * dist
            if discriminant < 0:
                return None

            t1 = (-vrel_projection + math.sqrt(discriminant)) / a_ego
            t2 = (-vrel_projection - math.sqrt(discriminant)) / a_ego
            valid_times = [t for t in (t1, t2) if t > 0]
            return min(valid_times) if valid_times else None
        except Exception as e:
            self.logger.warning(f"MTTC计算错误: {e}")
            return None

    def _cal_THW(self, lon_d, vx_ego):
        """Time headway ``lon_d / vx_ego``; ``None`` unless both are positive."""
        if vx_ego is None or vx_ego <= 0 or lon_d is None or lon_d <= 0:
            return None
        return lon_d / vx_ego

    def _cal_TTB(self, dist, vrel_projection, ego_decel_max):
        """Time to brake; ``None`` when not closing or the result is non-positive/non-finite."""
        if vrel_projection is None or ego_decel_max is None or vrel_projection >= 0:
            return None

        try:
            # NOTE(review): by operator precedence this evaluates to
            # dist + vrel / (2 * decel), which adds a distance to a time. The
            # enclosing parentheses suggest a different grouping was intended
            # (e.g. dividing the whole sum by the speed) — confirm the formula.
            TTB = (dist + vrel_projection ** 2 / (2 * ego_decel_max) / vrel_projection)
            return TTB if TTB > 0 and not (np.isinf(TTB) or np.isnan(TTB)) else None
        except Exception as e:
            self.logger.warning(f"计算TTB时出错: {e}")
            return None

    def _cal_TM(self, vrel_projection, v2, a2, v1, a1):
        """Time margin from the two braking distances ``v²/(2a)``, divided by ``v1``.

        Returns ``None`` unless the target is closing and both ``a1`` and
        ``a2`` are positive, or when the result is non-positive/non-finite.
        """
        if (vrel_projection is None or v2 is None or a2 is None or a2 <= 0 or
                v1 is None or a1 is None or a1 <= 0 or vrel_projection >= 0):
            return None

        try:
            TM = (v2 ** 2 / (2 * a2) - v1 ** 2 / (2 * a1)) / v1
            return TM if TM > 0 and not (np.isinf(TM) or np.isnan(TM)) else None
        except Exception as e:
            self.logger.warning(f"计算TM时出错: {e}")
            return None

    def _cal_PSD(self, dist, vrel_projection, max_decel, heading_diff, vx_ego, vy_ego,
                 obj_type, relative_x, relative_y):
        """Ratio of available distance to minimum stopping distance.

        Heading difference < 45 deg uses the longitudinal axis, 45-135 deg the
        lateral axis; anything else returns the PSD default.
        """
        abs_heading_diff = abs(heading_diff)

        if abs_heading_diff < 45:  # longitudinal scenario
            D = abs(relative_x)
            v_dir = abs(vx_ego)
            min_stopping_dist = (v_dir ** 2) / (2 * self.ego_decel_lon_max)

        elif 45 <= abs_heading_diff <= 135:  # crossing scenario
            D = abs(relative_y)
            v_dir = abs(vy_ego)

            if obj_type == 5:  # pedestrian
                # NOTE(review): distance * reaction time yields m·s, not m; a
                # reaction distance is usually speed * reaction time — confirm.
                reaction_dist = abs(relative_y) * self.rho
                braking_dist = (v_dir ** 2) / (2 * self.ego_decel_lat_max)
                min_stopping_dist = reaction_dist + braking_dist
            else:  # vehicle
                min_stopping_dist = (v_dir ** 2) / (2 * self.ego_decel_lat_max)

        else:  # opposite direction
            return self.DEFAULT_VALUES['PSD']

        # Essentially zero stopping distance: infinite margin unless already touching.
        if min_stopping_dist < 1e-3:
            return float('inf') if D > 0.1 else 0.0

        return D / min_stopping_dist

    def _cal_DTC(self, vx_rel, ax_rel_ego, t):
        """Reaction distance plus braking distance for the closing speed ``|vx_rel|``.

        ``ax_rel_ego`` is currently unused; returns the DTC default when not closing.
        """
        if vx_rel >= 0:  # no approach
            return self.DEFAULT_VALUES['DTC']

        try:
            speed = abs(vx_rel)
            reaction_distance = speed * t
            braking_distance = (speed ** 2) / (2 * self.ego_decel_lon_max)
            dtc = reaction_distance + braking_distance
            return max(0, dtc)
        except Exception as e:
            self.logger.warning(f"DTC计算错误: {e}")
            return self.DEFAULT_VALUES['DTC']

    def _cal_longitudinal_safe_dist(self, v_ego, v_lead_rel, rho, decel_ego, decel_lead):
        """Longitudinal safe distance, floored at 2.0 (25.0 on error)."""
        try:
            v_lead = v_ego + v_lead_rel
            reaction_distance = v_ego * rho
            braking_distance_ego = v_ego ** 2 / (2 * decel_ego) if v_ego > 0 else 0
            braking_distance_lead = v_lead ** 2 / (2 * decel_lead) if v_lead > 0 else 0
            safe_dist = reaction_distance + braking_distance_ego - braking_distance_lead
            return max(2.0, safe_dist)
        except Exception as e:
            self.logger.warning(f"计算纵向安全距离错误: {e}")
            return 25.0

    def _cal_lateral_safe_dist(self, lat_dist, v_ego_lat, vehicle_width, lane_width, decel_lat_max):
        """Lateral safe distance: 0.5 + 0.1*|v_lat|, capped at half the free lane width, floored at 0.5.

        ``lat_dist`` and ``decel_lat_max`` are currently unused.
        """
        try:
            available_space = lane_width - vehicle_width
            base_margin = 0.5
            speed_margin = 0.1 * abs(v_ego_lat)
            safe_dist = min(base_margin + speed_margin, available_space / 2)
            return max(base_margin, safe_dist)
        except Exception as e:
            self.logger.warning(f"计算横向安全距离错误: {e}")
            return 1.0

    def _cal_BTN_new(self, lon_a1, lon_a, lon_d, lon_v, ego_decel_lon_max):
        """Brake threat number ``(lon_a1 + lon_a - lon_v²/(2·lon_d)) / decel_max``.

        Returns ``None`` for non-positive distance/deceleration or on error.
        """
        if lon_d is None or lon_d <= 0 or ego_decel_lon_max is None or ego_decel_lon_max <= 0:
            return None
        try:
            return (lon_a1 + lon_a - (lon_v ** 2) / (2 * lon_d)) / ego_decel_lon_max
        except Exception:
            # Narrowed from a bare except so KeyboardInterrupt/SystemExit propagate.
            return None

    def _death_pr(self, obj_type, v_relative):
        """Logistic fatality probability for a given closing speed.

        The caller passes the projected relative speed, which is negative
        whenever a TTC exists; the magnitude is used so that the probability
        increases with impact speed (the signed value made it decrease).
        NOTE(review): confirm the speed unit the coefficients were fitted for.
        """
        speed = abs(v_relative)
        if obj_type == 5:  # pedestrian
            return 1 / (1 + np.exp(7.723 - 0.15 * speed))
        # vehicle
        return 1 / (1 + np.exp(8.192 - 0.12 * speed))

    def _cal_collisionRisk_level(self, obj_type, v_relative, collisionSeverity):
        """Collision risk as 0.4 * fatality probability + 0.6 * severity (both in [0, 1])."""
        p_death = self._death_pr(obj_type, v_relative)
        return 0.4 * p_death + 0.6 * collisionSeverity

    def _normal_distribution(self, x, mean=1.32, std_dev=0.26):
        """Normal pdf used by the collision-severity integral.

        NOTE(review): despite its name, ``std_dev`` enters the formula as the
        variance (1/sqrt(2*pi*s) * exp(-(x-mean)^2 / (2*s))), so the effective
        standard deviation is sqrt(0.26) ~= 0.51. Left as-is to keep scores
        stable — confirm which was intended.
        """
        return (1 / (math.sqrt(std_dev * 2 * math.pi))) * math.exp(-0.5 * (x - mean) ** 2 / std_dev)

    # ==================== Target filters ====================

    def _is_relevant_target(self, lat_d, lon_d, v_lat, heading_diff, obj_type):
        """Decide whether a target is inside the region of interest for risk metrics."""
        abs_lat_d = abs(lat_d)
        abs_heading_diff = abs(heading_diff)
        lane_half_width = self.lane_width / 2
        ego_half_width = self.ego_width / 2
        ped_half_width = self.ped_width / 2
        safe_width = ego_half_width + ped_half_width + 0.2

        if obj_type == 5:  # pedestrian
            if abs_heading_diff >= 60:  # crossing pedestrian
                return (0 < lon_d < 50) and (abs_lat_d < safe_width or abs_lat_d < lane_half_width)
            # pedestrian walking along the road
            return (0 < lon_d < 30) and (abs_lat_d < safe_width) and abs(v_lat) < 1.0

        # vehicles
        if abs_heading_diff < 45:  # longitudinal
            return abs_lat_d < lane_half_width + 0.1 and abs(lon_d) < 80
        if 45 <= abs_heading_diff <= 135:  # crossing
            return 0 < lon_d < 30 and abs_lat_d < ego_half_width + 0.1
        return False  # oncoming

    def _has_passed_conflict_point(self, relative_x, relative_y, obj_type, heading_diff):
        """Decide whether the target has already cleared the conflict point."""
        if obj_type == 5:  # pedestrian
            obj_length = self.pedestrian_length
            obj_width = self.pedestrian_width
        else:  # vehicle
            obj_length = self.vehicle_length
            obj_width = self.vehicle_width

        length_safe_dist = (self.ego_length + obj_length) / 2
        width_safe_dist = (self.ego_width + obj_width) / 2

        if obj_type == 5:
            return relative_x < -length_safe_dist or abs(relative_y) > width_safe_dist * 2

        abs_heading = abs(heading_diff)
        if abs_heading < 45:  # same direction
            return relative_x < -length_safe_dist or abs(relative_y) > width_safe_dist * 1.5
        if 45 <= abs_heading <= 135:  # crossing
            return relative_x < -length_safe_dist or relative_x > length_safe_dist
        # oncoming: outside the ego footprint diagonally
        return ((relative_x < -length_safe_dist and relative_y < -width_safe_dist) or
                (relative_x < -length_safe_dist and relative_y > width_safe_dist) or
                (relative_x > length_safe_dist and relative_y < -width_safe_dist) or
                (relative_x > length_safe_dist and relative_y > width_safe_dist))

    def _is_moving_away(self, ego_vel, target_vel, relative_x, relative_y, obj_type, heading_diff):
        """Decide whether the target is moving away, from the projected relative speed."""
        vx_rel = target_vel[0] - ego_vel[0]
        vy_rel = target_vel[1] - ego_vel[1]
        dist = math.sqrt(relative_x ** 2 + relative_y ** 2)
        if dist < 1e-6:
            return False

        v_projection = (vx_rel * relative_x + vy_rel * relative_y) / dist

        if obj_type == 5:  # pedestrian: 0.5 dead-band
            if relative_x < 0:  # behind
                return v_projection > 0.5
            return v_projection < -0.5  # ahead

        abs_heading_diff = abs(heading_diff)
        if abs_heading_diff < 45:  # same direction
            if relative_x < 0:  # behind
                return v_projection > 0
            return v_projection < 0  # ahead
        if 45 <= abs_heading_diff <= 135:  # crossing
            if relative_x > 0:  # ahead
                return v_projection < 0
            return v_projection > 0  # behind
        return False  # oncoming

    # ==================== Metric getters ====================

    def _get_metric_value(self, metric_name, aggregation='min'):
        """Aggregate ``metric_name`` over all rows and cache its per-frame data.

        Args:
            metric_name: Column of ``df_safe``.
            aggregation: ``'min'``, ``'max'`` or ``'mean'``; anything else
                falls back to ``'min'``.

        Returns:
            The aggregate as ``float``, or the metric's default when no
            non-null values exist. As a side effect, sets
            ``<metric_name lower>_data`` to the non-null rows (simTime,
            simFrame, playerId, metric) for chart generation.
        """
        if self.empty_flag or self.df_safe is None or self.df_safe.empty:
            return self.DEFAULT_VALUES.get(metric_name, None)

        values = self.df_safe[metric_name].dropna()
        if values.empty:
            return self.DEFAULT_VALUES.get(metric_name, None)

        aggregators = {'min': values.min, 'max': values.max, 'mean': values.mean}
        metric_value = float(aggregators.get(aggregation, values.min)())

        # Vectorized replacement for the former per-row iterrows() loop:
        # keep exactly the rows whose metric value is non-null.
        rows = self.df_safe.loc[values.index, ['simTime', 'simFrame', 'playerId', metric_name]]
        setattr(self, f'{metric_name.lower()}_data', rows.to_dict('records'))

        return metric_value

    def get_ttc_value(self) -> float:
        return self._get_metric_value('TTC')

    def get_mttc_value(self) -> float:
        return self._get_metric_value('MTTC')

    def get_thw_value(self) -> float:
        return self._get_metric_value('THW')

    def get_ttb_value(self) -> float:
        return self._get_metric_value('TTB')

    def get_tm_value(self) -> float:
        return self._get_metric_value('TM')

    def get_dtc_value(self) -> float:
        return self._get_metric_value('DTC', 'min')

    def get_psd_value(self) -> float:
        return self._get_metric_value('PSD', 'min')

    def get_lonsd_value(self) -> float:
        return self._get_metric_value('LonSD', 'mean')

    def get_latsd_value(self) -> float:
        return self._get_metric_value('LatSD', 'min')

    def get_btn_value(self) -> float:
        return self._get_metric_value('BTN', 'max')

    def get_collision_risk_value(self) -> float:
        return self._get_metric_value('collisionRisk', 'max')

    def get_collision_severity_value(self) -> float:
        return self._get_metric_value('collisionSeverity', 'max')

    # ==================== Helpers ====================

    def generate_metric_chart(self, metric_name: str, plot_path: Path) -> None:
        """Generate the chart for ``metric_name`` in ``plot_path`` (or ``./data``); errors are logged."""
        try:
            self.output_dir = plot_path if plot_path else os.path.join(os.getcwd(), 'data')
            os.makedirs(self.output_dir, exist_ok=True)

            chart_path = generate_safety_chart_data(self, metric_name, self.output_dir)
            if chart_path:
                self.logger.info(f"{metric_name}图表已生成: {chart_path}")
            else:
                self.logger.warning(f"{metric_name}图表生成失败")
        except Exception as e:
            self.logger.error(f"生成{metric_name}图表失败: {str(e)}", exc_info=True)