فهرست منبع

修改合规性部分代码,增加超车时对目标车进行筛选

XGJ_zhaoyuan 1 ماه پیش
والد
کامیت
752895bf6c
2فایلهای تغییر یافته به همراه329 افزوده شده و 250 حذف شده
  1. 1 0
      modules/config/config.py
  2. 328 250
      modules/metric/traffic.py

+ 1 - 0
modules/config/config.py

@@ -48,6 +48,7 @@ OVERTAKE_INFO = [
     "road_type",
     "interid",
     "crossid",
+    "lane_width"
 ]
 SLOWDOWN_INFO = [
     "simTime",

+ 328 - 250
modules/metric/traffic.py

@@ -1,5 +1,4 @@
 
-
 import math
 import numpy as np
 import pandas as pd
@@ -20,26 +19,12 @@ class OvertakingViolation(object):
         self.traffic_violations_type = "超车违规类"
 
         # self.logger = log.get_logger()  # 使用时再初始化
-
-        self.data = df_data.ego_data
+        self.data = df_data
+        self.data_ego = df_data.ego_data
         self.ego_data = (
-            self.data[config.OVERTAKE_INFO].copy().reset_index(drop=True)
+            self.data_ego[config.OVERTAKE_INFO].copy().reset_index(drop=True)
         )  # Copy to avoid modifying the original DataFrame
         header = self.ego_data.columns
-        if 2 in df_data.obj_id_list:
-            self.data_obj = df_data.obj_data[2]
-            self.obj_data = (
-                self.data_obj[config.OVERTAKE_INFO].copy().reset_index(drop=True)
-            )  # Copy to avoid modifying the original DataFrame
-        else:
-            self.obj_data = pd.DataFrame(columns=header)
-        if 3 in df_data.obj_id_list:
-            self.other_obj_data1 = df_data.obj_data[3]
-            self.other_obj_data = (
-                self.other_obj_data1[config.OVERTAKE_INFO].copy().reset_index(drop=True)
-            )
-        else:
-            self.other_obj_data = pd.DataFrame(columns=header)
         self.overtake_on_right_count = 0
         self.overtake_when_turn_around_count = 0
         self.overtake_when_passing_car_count = 0
@@ -86,6 +71,19 @@ class OvertakingViolation(object):
         return car_dx, car_dy
 
         # 在前车右侧超车、会车时超车、前车掉头时超车
+    def _is_objectcar_of_overtake(self, obj_df):
+        if ((self.ego_data['posX'].values - obj_df['posX'].values).any() <= self.ego_data['lane_width'].values.any()) and (
+                (self.ego_data['posY'].values - obj_df['posY'].values).any() <= 50) and ((self.ego_data['speedX' ].values *obj_df['speedX'].values).any() >= 0):
+            return obj_df
+        else:
+            return None
+
+    def _is_passingcar(self, obj_df):
+        if ((self.ego_data['posX'].values - obj_df['posX'].values).any() <= self.ego_data['lane_width'].values.any()) and (
+                (self.ego_data['posY'].values - obj_df['posY'].values).any() <= 50) and ((self.ego_data['speedX'].values * obj_df['speedX'].values).any() < 0):
+            return obj_df
+        else:
+            return None
 
     def illegal_overtake_with_car(self, window_width=250):
 
@@ -93,91 +91,117 @@ class OvertakingViolation(object):
         frame_id_length = len(self.ego_data["simFrame"])
         start_frame_id = self.ego_data["simFrame"].iloc[0]  # 获取起始点的帧数
 
-        while (start_frame_id + window_width) < frame_id_length:
-            # if start_frame_id == 828:
-            #     print("end")
-            simframe_window1 = list(
-                np.arange(start_frame_id, start_frame_id + window_width)
-            )
-            simframe_window = list(map(int, simframe_window1))
-            # 读取滑动窗口的dataframe数据
-            ego_data_frames = self.ego_data[
-                self.ego_data["simFrame"].isin(simframe_window)
-            ]
-            obj_data_frames = self.obj_data[
-                self.obj_data["simFrame"].isin(simframe_window)
-            ]
-            other_data_frames = self.other_obj_data[
-                self.other_obj_data["simFrame"].isin(simframe_window)
-            ]
-            # 读取前后的laneId
-            lane_id = ego_data_frames["lane_id"].tolist()
-            # 读取前后方向盘转角steeringWheel,
-            driverctrl_start_state = ego_data_frames["posH"].iloc[0]
-            driverctrl_end_state = ego_data_frames["posH"].iloc[-1]
-            # 读取车辆前后的位置信息
-            dx, dy = self._is_dxy_of_car(ego_data_frames, obj_data_frames)
-            ego_speedx = ego_data_frames["speedX"].tolist()
-            ego_speedy = ego_data_frames["speedY"].tolist()
-
-            obj_speedx = obj_data_frames[obj_data_frames["playerId"] == 2][
-                "speedX"
-            ].tolist()
-            obj_speedy = obj_data_frames[obj_data_frames["playerId"] == 2][
-                "speedY"
-            ].tolist()
-            if len(other_data_frames) > 0:
-                other_start_speedx = other_data_frames["speedX"].iloc[0]
-                other_start_speedy = other_data_frames["speedY"].iloc[0]
-                if (
-                    ego_speedx[0] * other_start_speedx
-                    + ego_speedy[0] * other_start_speedy
-                    < 0
-                ):
-                    self.overtake_when_passing_car_count += self._is_overtake(
-                        lane_id, dx, dy, ego_speedx, ego_speedy
-                    )
-                    start_frame_id += window_width
-            """
-            如果滑动窗口开始和最后的laneid一致;
-            方向盘转角前后方向相反(开始方向盘转角向右后来方向盘转角向左);
-            自车和前车的位置发生的交换;
-            则认为右超车
-            """
-            if driverctrl_start_state > 0 and driverctrl_end_state < 0:
-                self.overtake_on_right_count += self._is_overtake(
-                    lane_id, dx, dy, ego_speedx, ego_speedy
-                )
-                start_frame_id += window_width
-            elif (len(ego_speedx)*len(obj_speedx) > 0) and (ego_speedx[0] * obj_speedx[0] + ego_speedy[0] * obj_speedy[0] < 0):
-                self.overtake_when_turn_around_count += self._is_overtake(
-                    lane_id, dx, dy, ego_speedx, ego_speedy
-                )
-                start_frame_id += window_width
+        for i in self.data.obj_id_list:
+            if i != 1:
+                data_obj = self.data.obj_data[i]
             else:
-                start_frame_id += 1
+                continue
+            obj_datas = self._is_objectcar_of_overtake(data_obj)
+            if obj_datas is not None:
+                obj_data = (
+                    obj_datas[config.OVERTAKE_INFO].copy().reset_index(drop=True)
+                )
+                while (start_frame_id + window_width) < frame_id_length:
+
+                    simframe_window1 = list(
+                        np.arange(start_frame_id, start_frame_id + window_width)
+                    )
+                    simframe_window = list(map(int, simframe_window1))
+                    # 读取滑动窗口的dataframe数据
+                    ego_data_frames = self.ego_data[
+                        self.ego_data["simFrame"].isin(simframe_window)
+                    ]
+                    obj_data_frames = obj_data[
+                        obj_data["simFrame"].isin(simframe_window)
+                    ]
+                    # 读取前后的laneId
+                    lane_id = ego_data_frames["lane_id"].tolist()
+                    # 读取前后方向盘转角steeringWheel,
+                    driverctrl_start_state = ego_data_frames["posH"].iloc[0]
+                    driverctrl_end_state = ego_data_frames["posH"].iloc[-1]
+                    # 读取车辆前后的位置信息
+                    dx, dy = self._is_dxy_of_car(ego_data_frames, obj_data_frames)
+                    ego_speedx = ego_data_frames["speedX"].tolist()
+                    ego_speedy = ego_data_frames["speedY"].tolist()
+
+                    obj_speedx = obj_data_frames[obj_data_frames["playerId"] == i][
+                        "speedX"
+                    ].tolist()
+                    obj_speedy = obj_data_frames[obj_data_frames["playerId"] == i][
+                        "speedY"
+                    ].tolist()
+                    for j in self.data.obj_id_list:
+                        data_pass_obj = self.data.obj_data[j]
+                        obj_pass_datas = self._is_passingcar(data_pass_obj)
+                        if obj_pass_datas:
+                            other_data_frames = obj_pass_datas[
+                                obj_pass_datas["simFrame"].isin(simframe_window)
+                            ]
+                            if len(other_data_frames) > 0:
+                                other_start_speedx = other_data_frames["speedX"].iloc[0]
+                                other_start_speedy = other_data_frames["speedY"].iloc[0]
+                                if (
+                                        ego_speedx[0] * other_start_speedx
+                                        + ego_speedy[0] * other_start_speedy
+                                        < 0
+                                ):
+                                    self.overtake_when_passing_car_count += self._is_overtake(
+                                        lane_id, dx, dy, ego_speedx, ego_speedy
+                                    )
+                                    start_frame_id += window_width
+
+
+                    """
+                    如果滑动窗口开始和最后的laneid一致;
+                    方向盘转角前后方向相反(开始方向盘转角向右后来方向盘转角向左);
+                    自车和前车的位置发生的交换;
+                    则认为右超车
+                    """
+                    if driverctrl_start_state > 0 and driverctrl_end_state < 0:
+                        self.overtake_on_right_count += self._is_overtake(
+                            lane_id, dx, dy, ego_speedx, ego_speedy
+                        )
+                        start_frame_id += window_width
+                    elif (len(ego_speedx ) *len(obj_speedx) > 0) and \
+                            (ego_speedx[0] * obj_speedx[0] + ego_speedy[0] * obj_speedy[0] < 0):
+                        self.overtake_when_turn_around_count += self._is_overtake(
+                            lane_id, dx, dy, ego_speedx, ego_speedy
+                        )
+                        start_frame_id += window_width
+                    else:
+                        start_frame_id += 1
         # print(
         #     f"在会车时超车{self.overtake_when_passing_car_count}次, 右侧超车{self.overtake_on_right_count}次, 在前车掉头时超车{self.overtake_when_turn_around_count}次")
 
     # 借道超车场景
     def overtake_in_forbid_lane(self):
-        simTime = self.obj_data["simTime"].tolist()
-        simtime_devide = self.different_road_area_simtime(simTime)
-        if len(simtime_devide) == 0:
-            self.overtake_in_forbid_lane_count += 0
-            return
-        else:
-            for simtime in simtime_devide:
-                lane_overtake = self.ego_data[self.ego_data["simTime"].isin(simtime)]
-                try:
-                    lane_type = lane_overtake["lane_type"].tolist()
-                    if (50002 in lane_type and len(set(lane_type)) > 2) or (
-                        50002 not in lane_type and len(set(lane_type)) > 1
-                    ):
-                        self.overtake_in_forbid_lane_count += 1
-                except Exception as e:
-                    print("数据缺少lane_type信息")
-            # print(f"在不该占用车道超车{self.overtake_in_forbid_lane_count}次")
+        for i in self.data.obj_id_list:
+            if i != 1:
+                data_obj = self.data.obj_data[i]
+            else:
+                continue
+            obj_datas = self._is_objectcar_of_overtake(data_obj)
+            if obj_datas is not None:
+                obj_data = (
+                    obj_datas[config.OVERTAKE_INFO].copy().reset_index(drop=True)
+                )
+                simTime = obj_data["simTime"].tolist()
+                simtime_devide = self.different_road_area_simtime(simTime)
+                if len(simtime_devide) == 0:
+                    self.overtake_in_forbid_lane_count += 0
+                    return
+                else:
+                    for simtime in simtime_devide:
+                        lane_overtake = self.ego_data[self.ego_data["simTime"].isin(simtime)]
+                        try:
+                            lane_type = lane_overtake["lane_type"].tolist()
+                            if (50002 in lane_type and len(set(lane_type)) > 2) or (
+                                    50002 not in lane_type and len(set(lane_type)) > 1
+                            ):
+                                self.overtake_in_forbid_lane_count += 1
+                        except Exception as e:
+                            print("数据缺少lane_type信息")
+                    # print(f"在不该占用车道超车{self.overtake_in_forbid_lane_count}次")
 
     # 在匝道超车
     def overtake_in_ramp_area(self):
@@ -190,20 +214,30 @@ class OvertakingViolation(object):
         else:
             ramp_simTime_list = self.different_road_area_simtime(ramp_simtime_list)
             for ramp_simtime in ramp_simTime_list:
-                lane_id = self.ego_data["lane_id"].tolist()
-                ego_in_ramp = self.ego_data[self.ego_data["simTime"].isin(ramp_simtime)]
-                objstate_in_ramp = self.obj_data[
-                    self.obj_data["simTime"].isin(ramp_simtime)
-                ]
-                dx, dy = self._is_dxy_of_car(ego_in_ramp, objstate_in_ramp)
-                ego_speedx = ego_in_ramp["speedX"].tolist()
-                ego_speedy = ego_in_ramp["speedY"].tolist()
-                if len(lane_id) > 0:
-                    self.overtake_in_ramp_count += self._is_overtake(
-                        lane_id, dx, dy, ego_speedx, ego_speedy
-                    )
-                else:
-                    continue
+                for i in self.data.obj_id_list:
+                    if i != 1:
+                        data_obj = self.data.obj_data[i]
+                    else:
+                        continue
+                    obj_datas = self._is_objectcar_of_overtake(data_obj)
+                    if obj_datas is not None:
+                        obj_data = (
+                            obj_datas[config.OVERTAKE_INFO].copy().reset_index(drop=True)
+                        )
+                        lane_id = self.ego_data["lane_id"].tolist()
+                        ego_in_ramp = self.ego_data[self.ego_data["simTime"].isin(ramp_simtime)]
+                        objstate_in_ramp = obj_data[
+                            obj_data["simTime"].isin(ramp_simtime)
+                        ]
+                        dx, dy = self._is_dxy_of_car(ego_in_ramp, objstate_in_ramp)
+                        ego_speedx = ego_in_ramp["speedX"].tolist()
+                        ego_speedy = ego_in_ramp["speedY"].tolist()
+                        if len(lane_id) > 0:
+                            self.overtake_in_ramp_count += self._is_overtake(
+                                lane_id, dx, dy, ego_speedx, ego_speedy
+                            )
+                        else:
+                            continue
         # print(f"在匝道超车{self.overtake_in_ramp_count}次")
 
     def overtake_in_tunnel_area(self):
@@ -218,18 +252,28 @@ class OvertakingViolation(object):
             for tunnel_simtime in tunnel_simTime_list:
                 lane_id = self.ego_data["lane_id"].tolist()
                 ego_in_tunnel = self.ego_data[self.ego_data["simTime"].isin(tunnel_simtime)]
-                objstate_in_tunnel = self.obj_data[
-                    self.obj_data["simTime"].isin(tunnel_simtime)
-                ]
-                dx, dy = self._is_dxy_of_car(ego_in_tunnel, objstate_in_tunnel)
-                ego_speedx = ego_in_tunnel["speedX"].tolist()
-                ego_speedy = ego_in_tunnel["speedY"].tolist()
-                if len(lane_id) > 0:
-                    self.overtake_in_tunnel_count += self._is_overtake(
-                        lane_id, dx, dy, ego_speedx, ego_speedy
-                    )
-                else:
-                    continue
+                for i in self.data.obj_id_list:
+                    if i != 1:
+                        data_obj = self.data.obj_data[i]
+                    else:
+                        continue
+                    obj_datas = self._is_objectcar_of_overtake(data_obj)
+                    if obj_datas is not None:
+                        obj_data = (
+                            obj_datas[config.OVERTAKE_INFO].copy().reset_index(drop=True)
+                        )
+                        objstate_in_tunnel = obj_data[
+                            obj_data["simTime"].isin(tunnel_simtime)
+                        ]
+                        dx, dy = self._is_dxy_of_car(ego_in_tunnel, objstate_in_tunnel)
+                        ego_speedx = ego_in_tunnel["speedX"].tolist()
+                        ego_speedy = ego_in_tunnel["speedY"].tolist()
+                        if len(lane_id) > 0:
+                            self.overtake_in_tunnel_count += self._is_overtake(
+                                lane_id, dx, dy, ego_speedx, ego_speedy
+                            )
+                        else:
+                            continue
             # print(f"在隧道超车{self.overtake_in_tunnel_count}次")
 
     # 加速车道超车
@@ -249,16 +293,26 @@ class OvertakingViolation(object):
                 ego_in_accelerate = self.ego_data[
                     self.ego_data["simTime"].isin(accelerate_simtime)
                 ]
-                objstate_in_accelerate = self.obj_data[
-                    self.obj_data["simTime"].isin(accelerate_simtime)
-                ]
-                dx, dy = self._is_dxy_of_car(ego_in_accelerate, objstate_in_accelerate)
-                ego_speedx = ego_in_accelerate["speedX"].tolist()
-                ego_speedy = ego_in_accelerate["speedY"].tolist()
-
-                self.overtake_on_accelerate_lane_count += self._is_overtake(
-                    lane_id, dx, dy, ego_speedx, ego_speedy
-                )
+                for i in self.data.obj_id_list:
+                    if i != 1:
+                        data_obj = self.data.obj_data[i]
+                    else:
+                        continue
+                    obj_datas = self._is_objectcar_of_overtake(data_obj)
+                    if obj_datas is not None:
+                        obj_data = (
+                            obj_datas[config.OVERTAKE_INFO].copy().reset_index(drop=True)
+                        )
+                        objstate_in_accelerate = obj_data[
+                            obj_data["simTime"].isin(accelerate_simtime)
+                        ]
+                        dx, dy = self._is_dxy_of_car(ego_in_accelerate, objstate_in_accelerate)
+                        ego_speedx = ego_in_accelerate["speedX"].tolist()
+                        ego_speedy = ego_in_accelerate["speedY"].tolist()
+
+                        self.overtake_on_accelerate_lane_count += self._is_overtake(
+                            lane_id, dx, dy, ego_speedx, ego_speedy
+                        )
             # print(f"在加速车道超车{self.overtake_on_accelerate_lane_count}次")
 
     # 减速车道超车
@@ -278,17 +332,27 @@ class OvertakingViolation(object):
                 ego_in_decelerate = self.ego_data[
                     self.ego_data["simTime"].isin(decelerate_simtime)
                 ]
-                objstate_in_decelerate = self.obj_data[
-                    self.obj_data["simTime"].isin(decelerate_simtime)
-                ]
-                dx, dy = self._is_dxy_of_car(ego_in_decelerate, objstate_in_decelerate)
-                ego_speedx = ego_in_decelerate["speedX"].tolist()
-                ego_speedy = ego_in_decelerate["speedY"].tolist()
-
-                self.overtake_on_decelerate_lane_count += self._is_overtake(
-                    lane_id, dx, dy, ego_speedx, ego_speedy
-                )
-            # print(f"在减速车道超车{self.overtake_on_decelerate_lane_count}次")
+                for i in self.data.obj_id_list:
+                    if i != 1:
+                        data_obj = self.data.obj_data[i]
+                    else:
+                        continue
+                    obj_datas = self._is_objectcar_of_overtake(data_obj)
+                    if obj_datas is not None:
+                        obj_data = (
+                            obj_datas[config.OVERTAKE_INFO].copy().reset_index(drop=True)
+                        )
+                        objstate_in_decelerate = obj_data[
+                            obj_data["simTime"].isin(decelerate_simtime)
+                        ]
+                        dx, dy = self._is_dxy_of_car(ego_in_decelerate, objstate_in_decelerate)
+                        ego_speedx = ego_in_decelerate["speedX"].tolist()
+                        ego_speedy = ego_in_decelerate["speedY"].tolist()
+
+                        self.overtake_on_decelerate_lane_count += self._is_overtake(
+                            lane_id, dx, dy, ego_speedx, ego_speedy
+                        )
+                        # print(f"在减速车道超车{self.overtake_on_decelerate_lane_count}次")
 
     # 在交叉路口
     def overtake_in_different_senerios(self):
@@ -301,42 +365,59 @@ class OvertakingViolation(object):
         else:
             # 筛选在路口或者隧道区域的objectstate、driverctrl、laneinfo数据
             crossroad_ego = self.ego_data[self.ego_data["simTime"].isin(crossroad_simTime)]
-            crossroad_objstate = self.obj_data[
-                self.obj_data["simTime"].isin(crossroad_simTime)
-            ]
-            # crossroad_laneinfo = self.laneinfo_new_data[self.laneinfo_new_data['simTime'].isin(crossroad_simTime)]
-
-            # 读取前后的laneId
-            lane_id = crossroad_ego["lane_id"].tolist()
-
-            # 读取车辆前后的位置信息
-            dx, dy = self._is_dxy_of_car(crossroad_ego, crossroad_objstate)
-            ego_speedx = crossroad_ego["speedX"].tolist()
-            ego_speedy = crossroad_ego["speedY"].tolist()
-            """
-            如果滑动窗口开始和最后的laneid一致;
-            自车和前车的位置发生的交换;
-            则认为发生超车
-            """
-            if len(lane_id) > 0:
-                self.overtake_in_different_senerios_count += self._is_overtake(
-                    lane_id, dx, dy, ego_speedx, ego_speedy
-                )
-            else:
-                pass
-            # print(f"在路口超车{self.overtake_in_different_senerios_count}次")
+            for i in self.data.obj_id_list:
+                if i != 1:
+                    data_obj = self.data.obj_data[i]
+                else:
+                    continue
+                obj_datas = self._is_objectcar_of_overtake(data_obj)
+                if obj_datas is not None:
+                    obj_data = (
+                        obj_datas[config.OVERTAKE_INFO].copy().reset_index(drop=True)
+                    )
+                    crossroad_objstate = obj_data[
+                        obj_data["simTime"].isin(crossroad_simTime)
+                    ]
+                    # crossroad_laneinfo = self.laneinfo_new_data[self.laneinfo_new_data['simTime'].isin(crossroad_simTime)]
+
+                    # 读取前后的laneId
+                    lane_id = crossroad_ego["lane_id"].tolist()
+
+                    # 读取车辆前后的位置信息
+                    dx, dy = self._is_dxy_of_car(crossroad_ego, crossroad_objstate)
+                    ego_speedx = crossroad_ego["speedX"].tolist()
+                    ego_speedy = crossroad_ego["speedY"].tolist()
+                    """
+                    如果滑动窗口开始和最后的laneid一致;
+                    自车和前车的位置发生的交换;
+                    则认为发生超车
+                    """
+                    if len(lane_id) > 0:
+                        self.overtake_in_different_senerios_count += self._is_overtake(
+                            lane_id, dx, dy, ego_speedx, ego_speedy
+                        )
+                    else:
+                        pass
+                    # print(f"在路口超车{self.overtake_in_different_senerios_count}次")
 
     def statistic(self):
-        if len(self.obj_data) == 0:
-            pass
-        else:
-            self.overtake_in_forbid_lane()
-            self.overtake_on_decelerate_lane()
-            self.overtake_on_accelerate_lane()
-            self.overtake_in_ramp_area()
-            self.overtake_in_tunnel_area()
-            self.overtake_in_different_senerios()
-            self.illegal_overtake_with_car()
+        for i in self.data.obj_id_list:
+            if i != 1:
+                data_obj = self.data.obj_data[i]
+            else:
+                continue
+            obj_datas = self._is_objectcar_of_overtake(data_obj)
+            if obj_datas is not None:
+
+                self.overtake_in_forbid_lane()
+                self.overtake_on_decelerate_lane()
+                self.overtake_on_accelerate_lane()
+                self.overtake_in_ramp_area()
+                self.overtake_in_tunnel_area()
+                self.overtake_in_different_senerios()
+                self.illegal_overtake_with_car()
+            else:
+                pass
 
         self.calculated_value = {
             "overtake_on_right": self.overtake_on_right_count,
@@ -389,12 +470,12 @@ class SlowdownViolation(object):
             )
 
             self.ego_data["rela_pos"] = (
-                self.ego_data["dx"] * self.ego_data["speedX"]
-                + self.ego_data["dy"] * self.ego_data["speedY"]
+                    self.ego_data["dx"] * self.ego_data["speedX"]
+                    + self.ego_data["dy"] * self.ego_data["speedY"]
             )
             simtime = self.ego_data[
                 (self.ego_data["rela_pos"] > 0) & (self.ego_data["dist"] < 50)
-            ]["simTime"].tolist()
+                ]["simTime"].tolist()
             return simtime
 
     def different_road_area_simtime(self, df, threshold=0.6):
@@ -433,12 +514,12 @@ class SlowdownViolation(object):
                 crosswalk_objstate = self.ego_data[
                     (self.ego_data["simTime"] >= start_time)
                     & (self.ego_data["simTime"] <= end_time)
-                ]
+                    ]
 
                 # 计算车辆速度
                 ego_speedx = np.array(crosswalk_objstate["speedX"].tolist())
                 ego_speedy = np.array(crosswalk_objstate["speedY"].tolist())
-                ego_speed = np.sqrt(ego_speedx**2 + ego_speedy**2)
+                ego_speed = np.sqrt(ego_speedx ** 2 + ego_speedy ** 2)
 
                 # 判断是否超速
                 if max(ego_speed) > 15 / 3.6:  # 15 km/h 转换为 m/s
@@ -498,8 +579,8 @@ class SlowdownViolation(object):
                         (ego_car["posX"].values - sub_pedestrian_on_the_road["posX"].values)
                         ** 2
                         + (
-                            ego_car["posY"].values
-                            - sub_pedestrian_on_the_road["posY"].values
+                                ego_car["posY"].values
+                                - sub_pedestrian_on_the_road["posY"].values
                         )
                         ** 2
                     )
@@ -520,7 +601,7 @@ class SlowdownViolation(object):
             simtime_list = self.ego_data[
                 (self.ego_data["simTime"].isin(pedestrian_simtime_list))
                 & (self.ego_data["lane_type"] == 20)
-            ]["simTime"].tolist()
+                ]["simTime"].tolist()
             simTime_list = self.different_road_area_simtime(simtime_list)
             pedestrian_on_the_road = self.pedestrian_data[
                 self.pedestrian_data["simTime"].isin(simtime_list)
@@ -538,8 +619,8 @@ class SlowdownViolation(object):
                             (ego_car["posX"].values - sub_pedestrian_on_the_road["posX"].values)
                             ** 2
                             + (
-                                ego_car["posY"].values
-                                - sub_pedestrian_on_the_road["posY"].values
+                                    ego_car["posY"].values
+                                    - sub_pedestrian_on_the_road["posY"].values
                             )
                             ** 2
                         )
@@ -599,12 +680,12 @@ class TurnaroundViolation(object):
             )
 
             self.ego_data["rela_pos"] = (
-                self.ego_data["dx"] * self.ego_data["speedX"]
-                + self.ego_data["dy"] * self.ego_data["speedY"]
+                    self.ego_data["dx"] * self.ego_data["speedX"]
+                    + self.ego_data["dy"] * self.ego_data["speedY"]
             )
             simtime = self.ego_data[
                 (self.ego_data["rela_pos"] > 0) & (self.ego_data["dist"] < 50)
-            ]["simTime"].tolist()
+                ]["simTime"].tolist()
             return simtime
 
     def different_road_area_simtime(self, df, threshold=0.5):
@@ -653,9 +734,9 @@ class TurnaroundViolation(object):
                 ego_end_speedy1 = ego_car1["speedY"].iloc[-1]
 
                 if (
-                    ego_end_speedx1 * ego_start_speedx1
-                    + ego_end_speedy1 * ego_start_speedy1
-                    < 0
+                        ego_end_speedx1 * ego_start_speedx1
+                        + ego_end_speedy1 * ego_start_speedy1
+                        < 0
                 ):
                     self.turning_in_forbiden_turn_back_sign_count += 1
         if len(forbiden_turn_left_simtime_devide) == 0:
@@ -671,9 +752,9 @@ class TurnaroundViolation(object):
                 ego_end_speedy2 = ego_car2["speedY"].iloc[-1]
 
                 if (
-                    ego_end_speedx2 * ego_start_speedx2
-                    + ego_end_speedy2 * ego_start_speedy2
-                    < 0
+                        ego_end_speedx2 * ego_start_speedx2
+                        + ego_end_speedy2 * ego_start_speedy2
+                        < 0
                 ):
                     self.turning_in_forbiden_turn_left_sign_count += 1
 
@@ -681,7 +762,7 @@ class TurnaroundViolation(object):
         sensor_on_intersection = self.pedestrian_in_front_of_car()
         avoid_pedestrian_when_turn_back_simTime_list = self.ego_data[
             self.ego_data["lane_type"] == 20
-        ]["simTime"].tolist()
+            ]["simTime"].tolist()
         avoid_pedestrian_when_turn_back_simTime_devide = (
             self.different_road_area_simtime(
                 avoid_pedestrian_when_turn_back_simTime_list
@@ -689,7 +770,7 @@ class TurnaroundViolation(object):
         )
         if (len(sensor_on_intersection) > 0) and (len(avoid_pedestrian_when_turn_back_simTime_devide) > 0):
             for (
-                avoid_pedestrian_when_turn_back_simtime
+                    avoid_pedestrian_when_turn_back_simtime
             ) in avoid_pedestrian_when_turn_back_simTime_devide:
                 pedestrian_in_intersection_simtime = self.pedestrian_data[
                     self.pedestrian_data["simTime"].isin(
@@ -750,19 +831,19 @@ class WrongWayViolation:
         # 使用向量化和条件判断进行违规判定
         conditions = [
             (
-                self.data["road_fc"].isin(urban_expressway_or_highway)
-                & self.data["lane_type"].isin(driving_lane)
-                & (self.data["v"] == 0)
+                    self.data["road_fc"].isin(urban_expressway_or_highway)
+                    & self.data["lane_type"].isin(driving_lane)
+                    & (self.data["v"] == 0)
             ),
             (
-                self.data["road_fc"].isin(urban_expressway_or_highway)
-                & self.data["lane_type"].isin(emergency_lane)
-                & (self.data["v"] == 0)
+                    self.data["road_fc"].isin(urban_expressway_or_highway)
+                    & self.data["lane_type"].isin(emergency_lane)
+                    & (self.data["v"] == 0)
             ),
             (
-                self.data["road_fc"].isin(urban_expressway_or_highway)
-                & self.data["lane_type"].isin(emergency_lane)
-                & (self.data["v"] != 0)
+                    self.data["road_fc"].isin(urban_expressway_or_highway)
+                    & self.data["lane_type"].isin(emergency_lane)
+                    & (self.data["v"] != 0)
             ),
         ]
 
@@ -786,11 +867,11 @@ class WrongWayViolation:
         )
 
     def statistic(self) -> str:
-
         self.process_violations()
         # self.logger.info(f"停车违规类指标统计完成,统计结果:{self.violation_count}")
         return self.violation_count
 
+
 class SpeedingViolation(object):
     """超速违规类"""
 
@@ -815,38 +896,37 @@ class SpeedingViolation(object):
         # 提取有效道路类型
         urban_expressway_or_highway = {1, 2}  # 使用大括号直接创建集合
         general_road = {3}  # 直接创建包含一个元素的集合
-        
-        
+
         # 转换速度
         self.data["v"] *= 3.6  # 转换速度
-        
+
         conditions = [
             (
-                self.data["road_fc"].isin(urban_expressway_or_highway)
-                & (self.data["v"] > self.data["road_speed_max"] * 1.5)
+                    self.data["road_fc"].isin(urban_expressway_or_highway)
+                    & (self.data["v"] > self.data["road_speed_max"] * 1.5)
             ),
             (
-                self.data["road_fc"].isin(urban_expressway_or_highway)
-                & (self.data["v"] > self.data["road_speed_max"] * 1.2)
-                & (self.data["v"] <= self.data["road_speed_max"] * 1.5)
+                    self.data["road_fc"].isin(urban_expressway_or_highway)
+                    & (self.data["v"] > self.data["road_speed_max"] * 1.2)
+                    & (self.data["v"] <= self.data["road_speed_max"] * 1.5)
             ),
             (
-                self.data["road_fc"].isin(urban_expressway_or_highway)
-                & (self.data["v"] > self.data["road_speed_max"])
-                & (self.data["v"] <= self.data["road_speed_max"] * 1.2)
+                    self.data["road_fc"].isin(urban_expressway_or_highway)
+                    & (self.data["v"] > self.data["road_speed_max"])
+                    & (self.data["v"] <= self.data["road_speed_max"] * 1.2)
             ),
             (
-                self.data["road_fc"].isin(urban_expressway_or_highway)
-                & (self.data["v"] < self.data["road_speed_min"])
+                    self.data["road_fc"].isin(urban_expressway_or_highway)
+                    & (self.data["v"] < self.data["road_speed_min"])
             ),
             (
-                self.data["road_fc"].isin(general_road)
-                & (self.data["v"] > self.data["road_speed_max"] * 1.5)
+                    self.data["road_fc"].isin(general_road)
+                    & (self.data["v"] > self.data["road_speed_max"] * 1.5)
             ),
             (
-                self.data["road_fc"].isin(general_road)
-                & (self.data["v"] > self.data["road_speed_max"] * 1.2)
-                & (self.data["v"] <= self.data["road_speed_max"] * 1.5)
+                    self.data["road_fc"].isin(general_road)
+                    & (self.data["v"] > self.data["road_speed_max"] * 1.2)
+                    & (self.data["v"] <= self.data["road_speed_max"] * 1.5)
             ),
         ]
 
@@ -866,8 +946,6 @@ class SpeedingViolation(object):
 
         # 统计各类违规情况
         self.violation_counts = self.data["violation_type"].value_counts().to_dict()
-        
-       
 
     # 添加statistic方法
     def statistic(self):
@@ -877,6 +955,7 @@ class SpeedingViolation(object):
         print(f"超速违规类指标统计完成,统计结果:{self.violation_counts}")
         return self.violation_counts
 
+
 class TrafficLightViolation(object):
     """违反交通灯类"""
 
@@ -921,8 +1000,8 @@ class TrafficLightViolation(object):
             return False
 
         mid_point = (
-            np.array([stop_line_points[0][0], stop_line_points[0][1]])
-            + 0.5 * line_vector
+                np.array([stop_line_points[0][0], stop_line_points[0][1]])
+                + 0.5 * line_vector
         )
         axletree_to_mid_vector = np.array(
             [point[0] - mid_point[0], point[1] - mid_point[1]]
@@ -936,7 +1015,7 @@ class TrafficLightViolation(object):
             return False
 
         cos_theta = np.dot(axletree_to_mid_vector, direction_vector) / (
-            norm_axletree_to_mid * norm_direction
+                norm_axletree_to_mid * norm_direction
         )
         angle_theta = math.degrees(math.acos(cos_theta))
 
@@ -948,7 +1027,7 @@ class TrafficLightViolation(object):
             (self.data_ego["stopline_id"] != -1)
             & (self.data_ego["stopline_type"] == 1)
             & (self.data_ego["trafficlight_id"] != -1)
-        ]
+            ]
 
     def _group_data(self, filtered_data):
         """按时间差对数据进行分组"""
@@ -983,13 +1062,13 @@ class TrafficLightViolation(object):
 
             if abs(row["speedH"]) > 0.01 or abs(row["speedH"]) < 0.01:
                 has_crossed_line_front = (
-                    self.is_point_cross_line(front_wheel_pos, stop_line_points)
-                    and stateMask == 1
+                        self.is_point_cross_line(front_wheel_pos, stop_line_points)
+                        and stateMask == 1
                 )
                 has_crossed_line_rear = (
-                    self.is_point_cross_line(rear_wheel_pos, stop_line_points)
-                    and row["v"] > 0
-                    and stateMask == 1
+                        self.is_point_cross_line(rear_wheel_pos, stop_line_points)
+                        and row["v"] > 0
+                        and stateMask == 1
                 )
                 has_stop_in_intersection = has_crossed_line_front and row["v"] == 0
                 has_passed_intersection = has_crossed_line_front and dist < 1.0
@@ -1118,6 +1197,7 @@ class WarningViolation(object):
         # self.logger.info(f"警告性违规类指标统计完成,统计结果:{self.violation_counts}")
         return self.violation_counts
 
+
 class TrafficSignViolation(object):
     """交通标志违规类"""
 
@@ -1133,9 +1213,10 @@ class TrafficSignViolation(object):
             "NoStraightThrough": 0,  # 禁止直行标志地方直行
             "SpeedLimitViolation": 0,  # 违反限速规定
             "MinimumSpeedLimitViolation": 0,  # 违反最低限速规定
-        }   
+        }
+
+        # def checkForProhibitionViolation(self):
 
-    # def checkForProhibitionViolation(self):
     #     """禁令标志判断违规:7 禁止直行,12:限制速度"""
     #     # 筛选出sign_type1为7(禁止直行)
     #     violation_straight_df = self.data_ego[self.data_ego["sign_type1"] == 7]
@@ -1145,24 +1226,23 @@ class TrafficSignViolation(object):
         """禁令标志判断违规:7 禁止直行,12:限制速度"""
         # 筛选出 sign_type1 为7(禁止直行)的数据
         violation_straight_df = self.data_ego[self.data_ego["sign_type1"] == 7].copy()
-        
+
         # 判断车辆是否在禁止直行路段直行
         if not violation_straight_df.empty:
             # 按时间戳排序(假设数据按时间顺序处理)
             violation_straight_df = violation_straight_df.sort_values('simTime')
-            
+
             # 计算航向角变化(前后时间点的差值绝对值)
             violation_straight_df['posH_diff'] = violation_straight_df['posH'].diff().abs()
-            
+
             # 筛选条件:航向角变化小于阈值(例如5度)且速度不为0
             threshold = 5  # 单位:度(根据场景调整)
             mask = (violation_straight_df['posH_diff'] <= threshold) & (violation_straight_df['v'] > 0)
             straight_violations = violation_straight_df[mask]
-            
+
             # 统计违规次数或记录违规数据
             self.violation_counts["prohibition_straight"] = len(straight_violations)
-            
-        
+
         # 限制速度判断(原代码)
         violation_speed_limit_df = self.data_ego[self.data_ego["sign_type1"] == 12]
         if violation_speed_limit_df.empty:
@@ -1175,6 +1255,7 @@ class TrafficSignViolation(object):
         if violation_minimum_speed_limit_df.empty:
             mask = self.data_ego["v"] < self.data_ego["sign_speed"]
             self.violation_counts["MinimumSpeedLimitViolation"] = len(self.data_ego[mask])
+
     def statistic(self):
         self.checkForProhibitionViolation()
         self.checkForInstructionViolation()
@@ -1186,7 +1267,6 @@ class ViolationManager:
     """违规管理类,用于管理所有违规行为"""
 
     def __init__(self, data_processed):
-
         self.violations = []
         self.data = data_processed
         self.config = data_processed.traffic_config
@@ -1201,7 +1281,6 @@ class ViolationManager:
         # self.report_statistic()
 
     def report_statistic(self):
-
         traffic_result = self.over_take_violation.statistic()
         traffic_result.update(self.slow_down_violation.statistic())
         traffic_result.update(self.traffic_light_violation.statistic())
@@ -1209,7 +1288,6 @@ class ViolationManager:
         traffic_result.update(self.speeding_violation.statistic())
         traffic_result.update(self.warning_violation.statistic())
 
-
         evaluator = Score(self.config)
         result = evaluator.evaluate(traffic_result)
 
@@ -1220,4 +1298,4 @@ class ViolationManager:
 
 # 示例使用
 if __name__ == "__main__":
-    pass
+    pass