12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- """
- 安全指标计算模块(支持多目标物) - 场景感知版本
- 优化要点:
- 1. 重构代码结构,提高可读性
- 2. 去除冗余代码,优化性能
- 3. 保持原始计算逻辑不变
- 4. 增加全面注释
- 5. 优化数据处理流程
- """
- import os
- import numpy as np
- import pandas as pd
- import math
- import scipy.integrate as spi
- from collections import defaultdict
- from typing import Dict, Any, List, Optional, Tuple
- from pathlib import Path
- from modules.lib.score import Score
- from modules.lib.log_manager import LogManager
- from modules.lib.chart_generator import generate_safety_chart_data
# ==================== Safety metric entry points ====================
# One public entry function per safety metric. SafetyRegistry resolves these
# by name (``calculate_<metric name lowercased>``) through ``globals()``, so
# the function names must not change.
# --------------------------------------------------------------------
def calculate_ttc(data_processed, plot_path) -> dict:
    """Compute Time To Collision (TTC)."""
    return _calculate_metric('TTC', data_processed, plot_path)


def calculate_mttc(data_processed, plot_path) -> dict:
    """Compute Modified Time To Collision (MTTC)."""
    return _calculate_metric('MTTC', data_processed, plot_path)


def calculate_thw(data_processed, plot_path) -> dict:
    """Compute Time HeadWay (THW)."""
    return _calculate_metric('THW', data_processed, plot_path)


def calculate_ttb(data_processed, plot_path) -> dict:
    """Compute Time To Brake (TTB)."""
    return _calculate_metric('TTB', data_processed, plot_path)


def calculate_tm(data_processed, plot_path) -> dict:
    """Compute Time Margin (TM)."""
    return _calculate_metric('TM', data_processed, plot_path)


def calculate_dtc(data_processed, plot_path) -> dict:
    """Compute Distance To Collision (DTC)."""
    return _calculate_metric('DTC', data_processed, plot_path)


def calculate_psd(data_processed, plot_path) -> dict:
    """Compute the Proportion of Stopping Distance (PSD)."""
    return _calculate_metric('PSD', data_processed, plot_path)


def calculate_collisionrisk(data_processed, plot_path) -> dict:
    """Compute collision risk (collisionRisk)."""
    return _calculate_metric('collisionRisk', data_processed, plot_path)


def calculate_lonsd(data_processed, plot_path) -> dict:
    """Compute the longitudinal safe distance (LonSD)."""
    return _calculate_metric('LonSD', data_processed, plot_path)


def calculate_latsd(data_processed, plot_path) -> dict:
    """Compute the lateral safe distance (LatSD)."""
    return _calculate_metric('LatSD', data_processed, plot_path)


def calculate_btn(data_processed, plot_path) -> dict:
    """Compute the Brake Threat Number (BTN)."""
    return _calculate_metric('BTN', data_processed, plot_path)


def calculate_collisionseverity(data_processed, plot_path) -> dict:
    """Compute collision severity (collisionSeverity)."""
    return _calculate_metric('collisionSeverity', data_processed, plot_path)


# Metrics whose SafetyCalculator getter does not follow the default
# ``get_<metric name lowercased>_value`` naming pattern.
_SPECIAL_METRIC_GETTERS = {
    'collisionrisk': 'get_collision_risk_value',
    'collisionseverity': 'get_collision_severity_value',
}


def _calculate_metric(metric_name, data_processed, plot_path) -> dict:
    """Shared driver behind every ``calculate_*`` entry point.

    Builds a SafetyCalculator over ``data_processed``, fetches the aggregated
    value of ``metric_name``, and triggers chart generation when the
    calculator recorded per-row data for that metric.

    Args:
        metric_name: Metric key as used in the result dict (e.g. ``'TTC'``).
        data_processed: Processed scenario data; must expose ``object_df``.
        plot_path: Output directory handed to the chart generator.

    Returns:
        ``{metric_name: value}``. The value is ``None`` when the input is
        missing/invalid or the computation raised (the error is logged).
    """
    result_key = {metric_name: None}
    # Missing input or input without an object table -> nothing to compute.
    if data_processed is None or not hasattr(data_processed, 'object_df'):
        return result_key
    try:
        safety = SafetyCalculator(data_processed)
        lowered = metric_name.lower()
        getter_name = _SPECIAL_METRIC_GETTERS.get(lowered, f'get_{lowered}_value')
        metric_value = getattr(safety, getter_name)()
        # The getter stores the per-row series as ``<lowered>_data``; only
        # draw a chart when that series is non-empty.
        if getattr(safety, f'{lowered}_data', None):
            safety.generate_metric_chart(metric_name, plot_path)
        LogManager().get_logger().info(f"安全指标[{metric_name}]计算结果: {metric_value}")
        return {metric_name: metric_value}
    except Exception as e:
        LogManager().get_logger().error(f"{metric_name}计算异常: {str(e)}", exc_info=True)
        return result_key
# ==================== Safety metric registry / batch runner ====================
# Runs the metric entry functions selected by the "safety" configuration tree.
# ------------------------------------------------------------------------------
class SafetyRegistry:
    """Map configured safety metric names to entry functions and run them in batch."""

    def __init__(self, data_processed, plot_path):
        self.logger = LogManager().get_logger()
        self.data = data_processed
        self.plot_path = plot_path

        raw_config = getattr(data_processed, 'safety_config', None)
        if not raw_config:
            # No configuration: register nothing so batch_execute returns {}.
            self.logger.warning("安全配置为空,跳过安全指标计算")
            self.safety_config = {}
            self.metrics = []
            self._registry = {}
            return

        self.safety_config = raw_config.get("safety", {})
        self.metrics = self._extract_metrics(self.safety_config)
        self._registry = self._build_registry()

    def _extract_metrics(self, config_node: dict) -> list:
        """Return the ``name`` of every leaf dict in the config tree, in depth-first order.

        A dict counts as a leaf metric when it has a ``name`` key and none of
        its values is itself a dict.
        """
        def _walk(node):
            if not isinstance(node, dict):
                return
            is_leaf = all(not isinstance(child, dict) for child in node.values())
            if 'name' in node and is_leaf:
                yield node['name']
            for child in node.values():
                yield from _walk(child)

        metrics = list(_walk(config_node))
        self.logger.info(f'评比的安全指标列表: {metrics}')
        return metrics

    def _build_registry(self) -> dict:
        """Resolve each metric to the module-level ``calculate_<name lowercased>`` function."""
        namespace = globals()
        registry = {}
        for metric_name in self.metrics:
            func_name = f"calculate_{metric_name.lower()}"
            if func_name not in namespace:
                self.logger.warning(f"未实现安全指标函数: {func_name}")
                continue
            registry[metric_name] = namespace[func_name]
        return registry

    def batch_execute(self) -> dict:
        """Run every registered metric and merge their result dicts.

        A metric whose function raises is logged and reported as ``None``.
        """
        results = {}
        if not (hasattr(self, 'safety_config') and self.safety_config and self._registry):
            self.logger.info("安全配置为空或无注册指标,返回空结果")
            return results

        for name, func in self._registry.items():
            try:
                results.update(func(self.data, self.plot_path))
            except Exception as e:
                self.logger.error(f"{name} 执行失败: {str(e)}", exc_info=True)
                results[name] = None

        self.logger.info(f'安全指标计算结果: {results}')
        return results
class SafeManager:
    """Facade that wires a SafetyRegistry to the processed data and runs it on demand."""

    def __init__(self, data_processed, plot_path):
        self.data = data_processed
        self.logger = LogManager().get_logger()
        self.plot_path = plot_path

        has_config = bool(getattr(data_processed, 'safety_config', None))
        if has_config:
            self.registry = SafetyRegistry(self.data, self.plot_path)
        else:
            # Without configuration there is nothing to register.
            self.logger.warning("安全配置为空,跳过安全指标计算初始化")
            self.registry = None

    def report_statistic(self):
        """Return the batch safety metric results, or ``{}`` when no registry exists."""
        if self.registry is not None:
            return self.registry.batch_execute()
        self.logger.info("安全指标管理器未初始化,返回空结果")
        return {}
# ==================== Safety metric calculator ====================
# Concrete implementations of all safety metrics.
# ------------------------------------------------------------------
class SafetyCalculator:
    """Compute per-frame, per-target safety metrics for one scenario.

    Construction runs the whole pipeline over ``data_processed`` (reads
    ``object_df``, ``ego_data``, ``obj_id_list`` and ``vehicle_config``) and
    stores one row per (frame, target) in ``self.df_safe``. The
    ``get_*_value`` methods aggregate one metric column each and store the
    per-row series in ``self.<metric lowercased>_data`` for chart generation.
    """

    # Value reported for a metric when no valid data is available.
    DEFAULT_VALUES = {
        'TTC': 10.0, 'MTTC': 3.3, 'THW': 2.5,
        'TTB': 10.0, 'TM': 10.0, 'DTC': 10.0, 'PSD': 10.0,
        'LonSD': 100.0, 'LatSD': 2.0, 'BTN': 0.0,
        'collisionRisk': 0.0, 'collisionSeverity': 0.0
    }

    # Reaction time (s) subtracted from TTC before integrating the density
    # used by the collision-severity model.
    COLLISION_REACTION_TIME = 0.3

    def __init__(self, data_processed):
        self.logger = LogManager().get_logger()
        self.data_processed = data_processed
        self.df = data_processed.object_df.copy()
        self.ego_df = data_processed.ego_data.copy()
        self.obj_id_list = data_processed.obj_id_list

        # Vehicle / pedestrian parameters from configuration.
        self._init_config_params()
        # Frame lists, result placeholders and output directory.
        self._init_data_structures()

        if self.obj_id_list:
            # The "- 1" presumably discounts the ego, which the target loop skips.
            self.logger.info("开始执行安全参数计算,目标物数量: %d", len(self.obj_id_list) - 1)
            self._calculate_safety_parameters()
            self.logger.info("安全参数计算完成")
        else:
            self.logger.info("没有目标物,跳过安全参数计算")
            self.empty_flag = True

    def _init_config_params(self):
        """Read vehicle/pedestrian parameters from ``vehicle_config`` with defaults."""
        config = self.data_processed.vehicle_config
        self.rho = config.get("RHO", 1.5)  # reaction time, default 1.5 s
        self.ego_accel_max = config.get("EGO_ACCEL_MAX", 3.0)  # ego max acceleration, m/s^2
        self.obj_decel_max = config.get("OBJ_DECEL_MAX", 3.0)  # target max deceleration, m/s^2
        self.ego_decel_min = config.get("EGO_DECEL_MIN", 4.0)  # ego min deceleration, m/s^2
        self.ego_decel_lon_max = config.get("EGO_DECEL_LON_MAX", 4.0)  # ego max longitudinal decel, m/s^2
        self.ego_decel_lat_max = config.get("EGO_DECEL_LAT_MAX", 3.0)  # ego max lateral decel, m/s^2
        self.ego_width = config.get("CAR_WIDTH", 2.0)  # ego width, m
        self.ego_length = config.get("CAR_LENGTH", 4.5)  # ego length, m
        self.vehicle_length = config.get("VEHICLE_LENGTH", 4.5)  # target vehicle length, m
        self.vehicle_width = config.get("VEHICLE_WIDTH", 2.0)  # target vehicle width, m
        self.pedestrian_length = config.get("PEDESTRIAN_LENGTH", 0.5)  # pedestrian length, m
        self.pedestrian_width = config.get("PEDESTRIAN_WIDTH", 0.5)  # pedestrian width, m

        # Derived parameters.
        self.max_deceleration = self.ego_decel_lon_max
        # Magnitude of combined longitudinal + lateral deceleration.
        self.ego_decel_max = np.sqrt(self.ego_decel_lon_max ** 2 + self.ego_decel_lat_max ** 2)
        self.ped_width = self.pedestrian_width
        self.ped_length = self.pedestrian_length

    def _init_data_structures(self):
        """Initialise frame/time lists, result placeholders and the output directory."""
        self.time_list = self.ego_df['simTime'].values.tolist()
        self.frame_list = self.ego_df['simFrame'].values.tolist()
        self.empty_flag = True  # cleared once a valid target has been processed
        # Defined up front so the getters and chart generation never hit an
        # AttributeError when there are no targets (_postprocess_results then
        # never runs).
        self.df_safe = None
        self._init_metric_storage()
        self.output_dir = os.path.join(os.getcwd(), 'data')
        os.makedirs(self.output_dir, exist_ok=True)

    def _init_metric_storage(self):
        """Create an empty ``<metric>_data`` list for every metric.

        Names are derived the same way ``_get_metric_value`` derives them
        (metric key lowercased), so e.g. ``collisionrisk_data`` and
        ``dtc_data`` exist before any getter runs. The historical
        ``collision_risk_data`` / ``collision_severity_data`` attributes are
        kept for backward compatibility.
        """
        names = [metric.lower() for metric in self.DEFAULT_VALUES]
        names += ['collision_risk', 'collision_severity']
        for name in names:
            setattr(self, f'{name}_data', [])

    def _calculate_safety_parameters(self):
        """Run the metric computation for every frame and every target."""
        obj_dict = self._preprocess_object_data()
        df_list = []
        EGO_PLAYER_ID = 1
        found_valid_target = False

        for frame_num in self.frame_list:
            try:
                ego_data = obj_dict[frame_num][EGO_PLAYER_ID]
            except KeyError:
                # Ego row missing for this frame: nothing to compare against.
                continue
            frame_targets = self._process_frame_targets(frame_num, ego_data, obj_dict, EGO_PLAYER_ID)
            if frame_targets:
                found_valid_target = True
                df_list.append(pd.DataFrame(frame_targets))

        self._postprocess_results(df_list, found_valid_target)

    def _preprocess_object_data(self):
        """Index object rows as ``{simFrame: {playerId: row_dict}}``."""
        obj_dict = defaultdict(dict)
        for item in self.df.to_dict('records'):
            obj_dict[item['simFrame']][item['playerId']] = item
        return obj_dict

    def _process_frame_targets(self, frame_num, ego_data, obj_dict, ego_id):
        """Compute metrics for every non-ego target present in ``frame_num``."""
        frame_targets = []
        vx_ego = ego_data.get('lon_v_vehicle', 0)
        vy_ego = ego_data.get('lat_v_vehicle', 0)
        # Used by the relevance filter and LatSD for this frame.
        self.lane_width = ego_data.get('lane_width', 3.75)

        for player_id in self.obj_id_list:
            if player_id == ego_id:
                continue
            try:
                obj_data = obj_dict[frame_num][player_id]
            except KeyError:
                continue
            result = self._calculate_target_metrics(ego_data, obj_data, vx_ego, vy_ego)
            # The metric keys are merged into the object's own row dict.
            obj_data.update(result)
            frame_targets.append(obj_data)
        return frame_targets

    def _calculate_target_metrics(self, ego_data, obj_data, vx_ego, vy_ego):
        """Compute the metric dict for one target relative to the ego."""
        obj_type = obj_data.get('type', 0)
        lon_d = obj_data.get('x_relative_start_dist', 0)
        lat_d = obj_data.get('y_relative_start_dist', 0)
        vx_obj = obj_data.get('lon_v_vehicle', 0)
        vy_obj = obj_data.get('lat_v_vehicle', 0)

        # Relative motion.
        vx_rel = vx_obj - vx_ego
        vy_rel = vy_obj - vy_ego
        dist = math.sqrt(lon_d ** 2 + lat_d ** 2)

        # Heading difference folded into [0, 180].
        h1 = ego_data['posH']
        h2 = obj_data.get('posH', h1)
        heading_diff = abs(h1 - h2)
        if heading_diff > 180:
            heading_diff = 360 - heading_diff

        # Relative acceleration.
        ax_ego = ego_data.get('lon_acc_vehicle', 0)
        ay_ego = ego_data.get('lat_acc_vehicle', 0)
        ax_obj = obj_data.get('lon_acc_vehicle', 0)
        ay_obj = obj_data.get('lat_acc_vehicle', 0)
        ax_rel_ego = ax_obj - ax_ego
        ay_rel_ego = ay_obj - ay_ego

        # Projections onto the ego->target line of sight.
        vrel_projection = self._cal_v_projection_using_relative(lon_d, lat_d, vx_rel, vy_rel)
        accl_projection = self._cal_a_projection_using_relative(lon_d, lat_d, ax_rel_ego, ay_rel_ego)

        is_relevant = self._is_relevant_target(lat_d, lon_d, vy_rel, heading_diff, obj_type)

        # Targets that already passed the conflict zone and are leaving (or
        # standing still) are not risky any more.
        passed_conflict = self._has_passed_conflict_point(lon_d, lat_d, obj_type, heading_diff)
        moving_away = self._is_moving_away(
            (vx_ego, vy_ego), (vx_obj, vy_obj), lon_d, lat_d, obj_type, heading_diff
        )
        skip_risk_calculation = passed_conflict and (moving_away or obj_data.get('v', 0) < 0.05)

        if not is_relevant or skip_risk_calculation:
            return self._get_default_metrics(obj_type)
        return self._calculate_relevant_metrics(
            dist, vrel_projection, accl_projection,
            lon_d, lat_d, vx_ego, vy_ego, vx_rel, ax_rel_ego,
            heading_diff, obj_type, ego_data, obj_data
        )

    def _get_default_metrics(self, obj_type):
        """Default metric dict; vehicle-only metrics are ``None`` for pedestrians (type 5)."""
        return {
            'TTC': self.DEFAULT_VALUES['TTC'],
            'MTTC': self.DEFAULT_VALUES['MTTC'],
            'THW': self.DEFAULT_VALUES['THW'] if obj_type != 5 else None,
            'TTB': self.DEFAULT_VALUES['TTB'] if obj_type != 5 else None,
            'TM': self.DEFAULT_VALUES['TM'] if obj_type != 5 else None,
            'PSD': self.DEFAULT_VALUES['PSD'],
            'DTC': self.DEFAULT_VALUES['DTC'],
            'LonSD': self.DEFAULT_VALUES['LonSD'] if obj_type != 5 else None,
            'LatSD': self.DEFAULT_VALUES['LatSD'] if obj_type != 5 else None,
            'BTN': self.DEFAULT_VALUES['BTN'] if obj_type != 5 else None,
            'collisionSeverity': self.DEFAULT_VALUES['collisionSeverity'],
            'pr_death': 0,
            'collisionRisk': self.DEFAULT_VALUES['collisionRisk']
        }

    def _calculate_relevant_metrics(self, dist, vrel_projection, accl_projection,
                                    lon_d, lat_d, vx_ego, vy_ego, vx_rel, ax_rel_ego,
                                    heading_diff, obj_type, ego_data, obj_data):
        """Compute every metric for a target that passed the relevance filters."""
        is_pedestrian = obj_type == 5
        ego_lon_acc = ego_data.get('lon_acc_vehicle', 0)

        metrics = {
            'TTC': self._cal_TTC(dist, vrel_projection),
            'MTTC': self._cal_MTTC(dist, vrel_projection, ego_lon_acc),
            'THW': self._cal_THW(lon_d, vx_ego),
            'TTB': None if is_pedestrian else self._cal_TTB(dist, vrel_projection, self.ego_decel_max),
            'TM': None if is_pedestrian else self._cal_TM(
                vrel_projection, obj_data.get('v', 0), obj_data.get('lon_acc_vehicle', 0),
                ego_data.get('v', 0), ego_lon_acc),
            'PSD': self._cal_PSD(dist, vrel_projection, self.max_deceleration,
                                 heading_diff, vx_ego, vy_ego, obj_type, lon_d, lat_d),
            'DTC': (self._cal_DTC(vx_rel, ax_rel_ego, self.rho)
                    if abs(heading_diff) < 30 else self.DEFAULT_VALUES['DTC']),
            'LonSD': None if is_pedestrian else self._cal_longitudinal_safe_dist(
                vx_ego, vx_rel, self.rho, self.ego_decel_min, self.obj_decel_max),
            'LatSD': None if is_pedestrian else self._cal_lateral_safe_dist(
                abs(lat_d), vy_ego, self.ego_width, self.lane_width, self.ego_decel_lat_max),
            # NOTE(review): the ego longitudinal acceleration is passed for both
            # acceleration arguments; the target's acceleration may have been
            # intended for one of them — confirm against the BTN definition.
            'BTN': None if is_pedestrian else self._cal_BTN_new(
                ego_lon_acc, ego_lon_acc, abs(lon_d), vx_ego, self.ego_decel_lon_max),
        }

        ttc = metrics['TTC']
        if ttc is None or ttc > 4000:
            # Not closing (or effectively never): no collision risk.
            metrics.update({
                'collisionSeverity': 0,
                'pr_death': 0,
                'collisionRisk': 0
            })
            return metrics

        try:
            collision_severity = self._cal_collision_severity(ttc, self.COLLISION_REACTION_TIME)
            pr_death = self._death_pr(obj_type, vrel_projection)
            collision_risk = self._cal_collisionRisk_level(obj_type, vrel_projection, collision_severity)
            metrics.update({
                'collisionSeverity': collision_severity * 100,
                'pr_death': pr_death * 100,
                'collisionRisk': collision_risk * 100
            })
        except Exception as e:
            self.logger.error(f"碰撞风险计算错误: {e}")
            metrics.update({
                'collisionSeverity': 0,
                'pr_death': 0,
                'collisionRisk': 0
            })
        return metrics

    def _postprocess_results(self, df_list, found_valid_target):
        """Concatenate the per-frame frames into ``self.df_safe`` and set ``empty_flag``."""
        self.empty_flag = not found_valid_target
        if df_list:
            self.df_safe = pd.concat(df_list)
            col_list = ['simTime', 'simFrame', 'playerId',
                        'TTC', 'MTTC', 'THW', 'TTB', 'TM', 'DTC', 'PSD', 'LonSD', 'LatSD', 'BTN',
                        'collisionSeverity', 'pr_death', 'collisionRisk']
            self.df_safe = self.df_safe[col_list].reset_index(drop=True)
        else:
            self.df_safe = pd.DataFrame()
            self.empty_flag = True
        self.logger.info(f"处理完成,找到有效目标: {not self.empty_flag}")

    # ==================== Core metric formulas ====================
    def _cal_v_projection_using_relative(self, lon_d, lat_d, vx_rel, vy_rel):
        """Project the relative velocity onto the ego->target unit vector (negative = closing)."""
        dist = math.sqrt(lon_d ** 2 + lat_d ** 2)
        if dist < 1e-6:
            return 0.0
        return vx_rel * (lon_d / dist) + vy_rel * (lat_d / dist)

    def _cal_a_projection_using_relative(self, lon_d, lat_d, ax_rel, ay_rel):
        """Project the relative acceleration onto the ego->target unit vector."""
        dist = math.sqrt(lon_d ** 2 + lat_d ** 2)
        if dist < 1e-6:
            return 0.0
        return ax_rel * (lon_d / dist) + ay_rel * (lat_d / dist)

    def _cal_TTC(self, dist, vrel_projection):
        """Time To Collision: distance / closing speed; ``None`` when not closing."""
        if vrel_projection >= 0:
            return None
        return dist / abs(vrel_projection)

    def _cal_MTTC(self, dist, vrel_projection, a_ego):
        """Modified Time To Collision including the ego's acceleration.

        The gap is modelled as ``d(t) = dist + v*t + 0.5*a_rel*t**2`` where
        ``v`` is the projected relative velocity (negative while closing)
        and ``a_rel = -a_ego`` (target assumed at constant speed). The
        smallest positive root of ``d(t) = 0`` is returned.

        Returns:
            MTTC in seconds, or ``None`` when not closing, ``dist <= 0``, or
            the ego's braking stops the approach before contact.
        """
        if vrel_projection >= 0 or dist <= 0:
            return None
        try:
            if abs(a_ego) < 1e-6:
                # Negligible acceleration: MTTC degenerates to TTC.
                return dist / abs(vrel_projection) if abs(vrel_projection) > 1e-6 else None
            a_rel = -a_ego
            discriminant = vrel_projection ** 2 - 2 * a_rel * dist
            if discriminant < 0:
                return None
            root = math.sqrt(discriminant)
            # Divide by the relative acceleration (not a_ego): dividing by
            # a_ego flipped both roots negative, so MTTC was never produced.
            t1 = (-vrel_projection + root) / a_rel
            t2 = (-vrel_projection - root) / a_rel
            valid_times = [t for t in (t1, t2) if t > 0]
            return min(valid_times) if valid_times else None
        except Exception as e:
            self.logger.warning(f"MTTC计算错误: {e}")
            return None

    def _cal_THW(self, lon_d, vx_ego):
        """Time headway: longitudinal gap / ego longitudinal speed (positive inputs only)."""
        if vx_ego is None or vx_ego <= 0 or lon_d is None or lon_d <= 0:
            return None
        return lon_d / vx_ego

    def _cal_TTB(self, dist, vrel_projection, ego_decel_max):
        """Time To Brake: time left before full braking must start to avoid contact.

        ``TTB = (dist - v**2 / (2*a)) / v = TTC - v / (2*a)`` with ``v`` the
        closing speed. The previous expression ``dist + v**2/(2a)/v`` added
        seconds to metres because of operator precedence.

        Returns:
            TTB in seconds, or ``None`` when not closing, when braking can no
            longer avoid contact (TTB <= 0), or on invalid input.
        """
        if vrel_projection is None or ego_decel_max is None or vrel_projection >= 0:
            return None
        try:
            closing_speed = -vrel_projection
            braking_dist = closing_speed ** 2 / (2 * ego_decel_max)
            TTB = (dist - braking_dist) / closing_speed
            return TTB if TTB > 0 and not (np.isinf(TTB) or np.isnan(TTB)) else None
        except Exception as e:
            self.logger.warning(f"计算TTB时出错: {e}")
            return None

    def _cal_TM(self, vrel_projection, v2, a2, v1, a1):
        """Time margin from target (v2, a2) and ego (v1, a1) stopping distances.

        Returns ``None`` when not closing, on missing or non-positive
        accelerations, on zero ego speed, or for a non-positive result.
        """
        if (vrel_projection is None or v2 is None or a2 is None or a2 <= 0 or
                v1 is None or a1 is None or a1 <= 0 or vrel_projection >= 0):
            return None
        if v1 == 0:
            # Would divide by zero below; previously logged a warning per frame.
            return None
        try:
            TM = (v2 ** 2 / (2 * a2) - v1 ** 2 / (2 * a1)) / v1
            return TM if TM > 0 and not (np.isinf(TM) or np.isnan(TM)) else None
        except Exception as e:
            self.logger.warning(f"计算TM时出错: {e}")
            return None

    def _cal_PSD(self, dist, vrel_projection, max_decel, heading_diff, vx_ego, vy_ego, obj_type, relative_x,
                 relative_y):
        """Proportion of Stopping Distance: available distance / minimum stopping distance.

        The scenario is chosen from the heading difference: longitudinal
        (< 45 deg), lateral (45-135 deg), otherwise opposing (default value).
        """
        abs_heading_diff = abs(heading_diff)
        if abs_heading_diff < 45:
            # Longitudinal scenario.
            D = abs(relative_x)
            v_dir = abs(vx_ego)
            min_stopping_dist = (v_dir ** 2) / (2 * self.ego_decel_lon_max)
        elif abs_heading_diff <= 135:
            # Lateral scenario.
            D = abs(relative_y)
            v_dir = abs(vy_ego)
            braking_dist = (v_dir ** 2) / (2 * self.ego_decel_lat_max)
            if obj_type == 5:
                # Pedestrian: add the distance covered during the reaction
                # time (speed * rho), as in _cal_longitudinal_safe_dist; the
                # old distance * rho had units of m*s.
                min_stopping_dist = v_dir * self.rho + braking_dist
            else:
                min_stopping_dist = braking_dist
        else:
            # Opposing scenario.
            return self.DEFAULT_VALUES['PSD']

        if min_stopping_dist < 1e-3:
            # (Almost) no stopping distance needed.
            return float('inf') if D > 0.1 else 0.0
        return D / min_stopping_dist

    def _cal_DTC(self, vx_rel, ax_rel_ego, t):
        """Distance To Collision: reaction distance + braking distance at the closing speed."""
        if vx_rel >= 0:
            return self.DEFAULT_VALUES['DTC']
        try:
            speed = abs(vx_rel)
            reaction_distance = speed * t
            braking_distance = (speed ** 2) / (2 * self.ego_decel_lon_max)
            return max(0, reaction_distance + braking_distance)
        except Exception as e:
            self.logger.warning(f"DTC计算错误: {e}")
            return self.DEFAULT_VALUES['DTC']

    def _cal_longitudinal_safe_dist(self, v_ego, v_lead_rel, rho, decel_ego, decel_lead):
        """Longitudinal safe distance (floor of 2.0 m)."""
        try:
            v_lead = v_ego + v_lead_rel
            reaction_distance = v_ego * rho
            braking_distance_ego = v_ego ** 2 / (2 * decel_ego) if v_ego > 0 else 0
            braking_distance_lead = v_lead ** 2 / (2 * decel_lead) if v_lead > 0 else 0
            safe_dist = reaction_distance + braking_distance_ego - braking_distance_lead
            return max(2.0, safe_dist)
        except Exception as e:
            self.logger.warning(f"计算纵向安全距离错误: {e}")
            return 25.0

    def _cal_lateral_safe_dist(self, lat_dist, v_ego_lat, vehicle_width, lane_width, decel_lat_max):
        """Lateral safe distance: 0.5 m plus a speed margin, capped by half the free lane width.

        ``lat_dist`` and ``decel_lat_max`` are accepted but not used.
        """
        try:
            available_space = lane_width - vehicle_width
            base_margin = 0.5
            speed_margin = 0.1 * abs(v_ego_lat)
            safe_dist = min(base_margin + speed_margin, available_space / 2)
            return max(base_margin, safe_dist)
        except Exception as e:
            self.logger.warning(f"计算横向安全距离错误: {e}")
            return 1.0

    def _cal_BTN_new(self, lon_a1, lon_a, lon_d, lon_v, ego_decel_lon_max):
        """Brake Threat Number normalised by the ego's max longitudinal deceleration."""
        if lon_d is None or lon_d <= 0 or ego_decel_lon_max is None or ego_decel_lon_max <= 0:
            return None
        try:
            return (lon_a1 + lon_a - (lon_v ** 2) / (2 * lon_d)) / ego_decel_lon_max
        except Exception:
            return None

    def _death_pr(self, obj_type, v_relative):
        """Fatality probability from a logistic model in impact speed.

        ``v_relative`` is the signed projected relative velocity (negative
        while closing), so its magnitude is used; with the signed value the
        probability decreased as the closing speed grew.
        NOTE(review): confirm the speed unit expected by these coefficients.
        """
        speed = abs(v_relative)
        if obj_type == 5:  # pedestrian
            return 1 / (1 + np.exp(7.723 - 0.15 * speed))
        return 1 / (1 + np.exp(8.192 - 0.12 * speed))

    def _cal_collisionRisk_level(self, obj_type, v_relative, collisionSeverity):
        """Collision risk: 0.4 * fatality probability + 0.6 * severity."""
        p_death = self._death_pr(obj_type, v_relative)
        return 0.4 * p_death + 0.6 * collisionSeverity

    def _cal_collision_severity(self, ttc, reaction_time=0.3):
        """Collision severity in [0, 1] for a given TTC.

        Severity is 1 minus the mass of ``_normal_distribution`` on
        ``[0, ttc - reaction_time]``. The upper bound is clamped at 0: if TTC
        is shorter than the reaction time, severity is maximal (1.0) rather
        than the >1 value a reversed integration interval would yield.
        """
        upper = max(ttc - reaction_time, 0.0)
        mass, _ = spi.quad(self._normal_distribution, 0, upper)
        return min(max(1.0 - mass, 0.0), 1.0)

    def _normal_distribution(self, x, mean=1.32, std_dev=0.26):
        """Normal density used by the collision-severity model.

        NOTE(review): ``std_dev`` is used where a normal pdf takes the
        variance (sqrt(2*pi*var) and /var), i.e. this is N(mean, var=std_dev).
        Kept as-is to preserve results — confirm the intended parameter.
        """
        return (1 / (math.sqrt(std_dev * 2 * math.pi))) * math.exp(-0.5 * (x - mean) ** 2 / std_dev)

    # ==================== Target filtering ====================
    def _is_relevant_target(self, lat_d, lon_d, v_lat, heading_diff, obj_type):
        """Decide whether a target is in the ego's zone of interest."""
        abs_lat_d = abs(lat_d)
        abs_heading_diff = abs(heading_diff)
        lane_half_width = self.lane_width / 2
        ego_half_width = self.ego_width / 2
        ped_half_width = self.ped_width / 2
        safe_width = ego_half_width + ped_half_width + 0.2

        if obj_type == 5:  # pedestrian
            if abs_heading_diff >= 60:  # crossing
                return (0 < lon_d < 50) and (abs_lat_d < safe_width or abs_lat_d < lane_half_width)
            # walking along the road
            return (0 < lon_d < 30) and (abs_lat_d < safe_width) and abs(v_lat) < 1.0

        if abs_heading_diff < 45:  # longitudinal
            return abs_lat_d < lane_half_width + 0.1 and abs(lon_d) < 80
        if 45 <= abs_heading_diff <= 135:  # lateral
            return 0 < lon_d < 30 and abs_lat_d < ego_half_width + 0.1
        return False  # opposing

    def _has_passed_conflict_point(self, relative_x, relative_y, obj_type, heading_diff):
        """Decide whether the target has already passed the conflict zone."""
        if obj_type == 5:
            obj_length = self.pedestrian_length
            obj_width = self.pedestrian_width
        else:
            obj_length = self.vehicle_length
            obj_width = self.vehicle_width

        length_safe_dist = (self.ego_length + obj_length) / 2
        width_safe_dist = (self.ego_width + obj_width) / 2

        if obj_type == 5:
            return relative_x < -length_safe_dist or abs(relative_y) > width_safe_dist * 2
        if abs(heading_diff) < 45:  # same direction
            return relative_x < -length_safe_dist or abs(relative_y) > width_safe_dist * 1.5
        if 45 <= abs(heading_diff) <= 135:  # crossing
            return relative_x < -length_safe_dist or relative_x > length_safe_dist
        # opposing
        return ((relative_x < -length_safe_dist and relative_y < -width_safe_dist) or
                (relative_x < -length_safe_dist and relative_y > width_safe_dist) or
                (relative_x > length_safe_dist and relative_y < -width_safe_dist) or
                (relative_x > length_safe_dist and relative_y > width_safe_dist))

    def _is_moving_away(self, ego_vel, target_vel, relative_x, relative_y, obj_type, heading_diff):
        """Decide whether the target is moving away from the ego."""
        vx_rel = target_vel[0] - ego_vel[0]
        vy_rel = target_vel[1] - ego_vel[1]
        dist = math.sqrt(relative_x ** 2 + relative_y ** 2)
        if dist < 1e-6:
            return False
        v_projection = (vx_rel * relative_x + vy_rel * relative_y) / dist

        if obj_type == 5:  # pedestrian
            if relative_x < 0:  # behind
                return v_projection > 0.5
            return v_projection < -0.5  # ahead

        abs_heading_diff = abs(heading_diff)
        if abs_heading_diff < 45:  # same direction
            if relative_x < 0:
                return v_projection > 0
            return v_projection < 0
        if 45 <= abs_heading_diff <= 135:  # crossing
            if relative_x > 0:
                return v_projection < 0
            return v_projection > 0
        return False  # opposing

    # ==================== Metric getters ====================
    def _get_metric_value(self, metric_name, aggregation='min'):
        """Aggregate one ``df_safe`` column and store its non-null rows.

        Args:
            metric_name: Column name in ``df_safe``.
            aggregation: ``'min'``, ``'max'`` or ``'mean'``; anything else
                falls back to ``'min'``.

        Returns:
            The aggregated float, or ``DEFAULT_VALUES[metric_name]`` when
            there is no data.
        """
        if self.empty_flag or self.df_safe is None or self.df_safe.empty:
            return self.DEFAULT_VALUES.get(metric_name, None)

        values = self.df_safe[metric_name].dropna()
        if values.empty:
            return self.DEFAULT_VALUES.get(metric_name, None)

        if aggregation == 'max':
            metric_value = float(values.max())
        elif aggregation == 'mean':
            metric_value = float(values.mean())
        else:
            metric_value = float(values.min())

        # Per-row series consumed by the chart generator.
        metric_data = []
        for _, row in self.df_safe.iterrows():
            if pd.notnull(row[metric_name]):
                metric_data.append({
                    'simTime': row['simTime'],
                    'simFrame': row['simFrame'],
                    'playerId': row['playerId'],
                    metric_name: row[metric_name]
                })
        setattr(self, f'{metric_name.lower()}_data', metric_data)
        return metric_value

    def get_ttc_value(self) -> float:
        return self._get_metric_value('TTC')

    def get_mttc_value(self) -> float:
        return self._get_metric_value('MTTC')

    def get_thw_value(self) -> float:
        return self._get_metric_value('THW')

    def get_ttb_value(self) -> float:
        return self._get_metric_value('TTB')

    def get_tm_value(self) -> float:
        return self._get_metric_value('TM')

    def get_dtc_value(self) -> float:
        return self._get_metric_value('DTC', 'min')

    def get_psd_value(self) -> float:
        return self._get_metric_value('PSD', 'min')

    def get_lonsd_value(self) -> float:
        return self._get_metric_value('LonSD', 'mean')

    def get_latsd_value(self) -> float:
        return self._get_metric_value('LatSD', 'min')

    def get_btn_value(self) -> float:
        return self._get_metric_value('BTN', 'max')

    def get_collision_risk_value(self) -> float:
        return self._get_metric_value('collisionRisk', 'max')

    def get_collision_severity_value(self) -> float:
        return self._get_metric_value('collisionSeverity', 'max')

    # ==================== Helpers ====================
    def generate_metric_chart(self, metric_name: str, plot_path: Path) -> None:
        """Generate chart data for ``metric_name`` into ``plot_path`` (or ./data)."""
        try:
            self.output_dir = plot_path if plot_path else os.path.join(os.getcwd(), 'data')
            os.makedirs(self.output_dir, exist_ok=True)
            chart_path = generate_safety_chart_data(self, metric_name, self.output_dir)
            if chart_path:
                self.logger.info(f"{metric_name}图表已生成: {chart_path}")
            else:
                self.logger.warning(f"{metric_name}图表生成失败")
        except Exception as e:
            self.logger.error(f"生成{metric_name}图表失败: {str(e)}", exc_info=True)
|