From dd6fd559e122d011738c710459ce6bd5f7562cdf Mon Sep 17 00:00:00 2001 From: QuanyiLi Date: Tue, 19 Sep 2023 22:39:04 +0100 Subject: [PATCH] format --- metadrive/component/vehicle/base_vehicle.py | 40 +++++++++---------- metadrive/envs/base_env.py | 23 +++++------ metadrive/envs/scenario_env.py | 4 +- metadrive/manager/scenario_traffic_manager.py | 12 +++--- metadrive/obs/top_down_obs_impl.py | 10 ++++- metadrive/obs/top_down_renderer.py | 5 +-- 6 files changed, 47 insertions(+), 47 deletions(-) diff --git a/metadrive/component/vehicle/base_vehicle.py b/metadrive/component/vehicle/base_vehicle.py index 138cc2b53..f53515d08 100644 --- a/metadrive/component/vehicle/base_vehicle.py +++ b/metadrive/component/vehicle/base_vehicle.py @@ -110,12 +110,12 @@ class BaseVehicle(BaseObject, BaseVehicleState): path = None def __init__( - self, - vehicle_config: Union[dict, Config] = None, - name: str = None, - random_seed=None, - position=None, - heading=None + self, + vehicle_config: Union[dict, Config] = None, + name: str = None, + random_seed=None, + position=None, + heading=None ): """ This Vehicle Config is different from self.get_config(), and it is used to define which modules to use, and @@ -266,14 +266,14 @@ def _update_energy_consumption(self): return step_energy, self.energy_consumption def reset( - self, - vehicle_config=None, - name=None, - random_seed=None, - position: np.ndarray = None, - heading: float = 0.0, - *args, - **kwargs + self, + vehicle_config=None, + name=None, + random_seed=None, + position: np.ndarray = None, + heading: float = 0.0, + *args, + **kwargs ): """ pos is a 2-d array, and heading is a float (unit degree) @@ -529,8 +529,8 @@ def heading_diff(self, target_lane): if not lateral_norm * forward_direction_norm: return 0 cos = ( - (forward_direction[0] * lateral[0] + forward_direction[1] * lateral[1]) / - (lateral_norm * forward_direction_norm) + (forward_direction[0] * lateral[0] + forward_direction[1] * lateral[1]) / + (lateral_norm * forward_direction_norm) ) # return cos # Normalize to 0, 1 @@ -850,7 +850,7 @@ def _update_overtake_stat(self): ckpt_idx = routing._target_checkpoints_index for surrounding_v in surrounding_vs: if surrounding_v.lane_index[:-1] == (routing.checkpoints[ckpt_idx[0]], routing.checkpoints[ckpt_idx[1] - ]): + ]): if self.lane.local_coordinates(self.position)[0] - \ self.lane.local_coordinates(surrounding_v.position)[0] < 0: self.front_vehicles.add(surrounding_v) @@ -885,9 +885,9 @@ def overspeed(self): @property def replay_done(self): return self._replay_done if hasattr(self, "_replay_done") else ( - self.crash_building or self.crash_vehicle or - # self.on_white_continuous_line or - self.on_yellow_continuous_line + self.crash_building or self.crash_vehicle or + # self.on_white_continuous_line or + self.on_yellow_continuous_line ) @property diff --git a/metadrive/envs/base_env.py b/metadrive/envs/base_env.py index 5fd301296..d4ecd4e5e 100644 --- a/metadrive/envs/base_env.py +++ b/metadrive/envs/base_env.py @@ -151,7 +151,7 @@ # ) # These sensors will be constructed automatically and can be accessed in engine.get_sensor("sensor_name") # NOTE: main_camera will be added automatically if you are using offscreen/onscreen mode - sensors=dict(lidar=(Lidar, 50), side_detector=(SideDetector,), lane_line_detector=(LaneLineDetector,)), + sensors=dict(lidar=(Lidar, 50), side_detector=(SideDetector, ), lane_line_detector=(LaneLineDetector, )), # when main_camera is not the image_source for vehicle, reduce the window size to (1,1) for boosting efficiency auto_resize_window=True, @@ -296,7 +296,7 @@ def _post_process_config(self, config): if not config["render_pipeline"]: for panel in config["interface_panel"]: if panel == "dashboard": - config["sensors"]["dashboard"] = (VehiclePanel,) + config["sensors"]["dashboard"] = (VehiclePanel, ) if panel not in config["sensors"]: self.logger.warning( "Fail to add sensor: {} to the interface. Remove it from panel list!".format(panel) @@ -712,20 +712,19 @@ def episode_step(self): return self.engine.episode_step if self.engine is not None else 0 def export_scenarios( - self, - policies: Union[dict, Callable], - scenario_index: Union[list, int], - max_episode_length=None, - verbose=False, - suppress_warning=False, - render_topdown=False, - return_done_info=True, - to_dict=True + self, + policies: Union[dict, Callable], + scenario_index: Union[list, int], + max_episode_length=None, + verbose=False, + suppress_warning=False, + render_topdown=False, + return_done_info=True, + to_dict=True ): """ We export scenarios into a unified format with 10hz sample rate """ - def _act(observation): if isinstance(policies, dict): ret = {} diff --git a/metadrive/envs/scenario_env.py b/metadrive/envs/scenario_env.py index 457e88cbd..84ed9ef50 100644 --- a/metadrive/envs/scenario_env.py +++ b/metadrive/envs/scenario_env.py @@ -243,8 +243,8 @@ def msg(reason): # for compatibility # crash almost equals to crashing with vehicles done_info[TerminationState.CRASH] = ( - done_info[TerminationState.CRASH_VEHICLE] or done_info[TerminationState.CRASH_OBJECT] - or done_info[TerminationState.CRASH_BUILDING] + done_info[TerminationState.CRASH_VEHICLE] or done_info[TerminationState.CRASH_OBJECT] + or done_info[TerminationState.CRASH_BUILDING] ) # log data to curriculum manager diff --git a/metadrive/manager/scenario_traffic_manager.py b/metadrive/manager/scenario_traffic_manager.py index de7dca17a..5977809e0 100644 --- a/metadrive/manager/scenario_traffic_manager.py +++ b/metadrive/manager/scenario_traffic_manager.py @@ -197,14 +197,12 @@ def spawn_vehicle(self, v_id, track): v_cfg["top_down_length"] = track["state"]["length"][self.episode_step] v_cfg["top_down_width"] = track["state"]["width"][self.episode_step] if v_cfg["top_down_length"] < 1 or v_cfg["top_down_width"] < 0.5: - logger.warning("Scenario ID: {}. The top_down size of vehicle {} is weird: " - "{}".format(self.engine.current_seed, v_id, [v_cfg["length"], v_cfg["width"]])) + logger.warning( + "Scenario ID: {}. The top_down size of vehicle {} is weird: " + "{}".format(self.engine.current_seed, v_id, [v_cfg["length"], v_cfg["width"]]) + ) v = self.spawn_object( - vehicle_class, - position=state["position"], - heading=state["heading"], - vehicle_config=v_cfg, - name=obj_name + vehicle_class, position=state["position"], heading=state["heading"], vehicle_config=v_cfg, name=obj_name ) self._scenario_id_to_obj_id[v_id] = v.name self._obj_id_to_scenario_id[v.name] = v_id diff --git a/metadrive/obs/top_down_obs_impl.py b/metadrive/obs/top_down_obs_impl.py index ec956ab69..56239f81d 100644 --- a/metadrive/obs/top_down_obs_impl.py +++ b/metadrive/obs/top_down_obs_impl.py @@ -197,8 +197,14 @@ class ObjectGraphics: @classmethod def display( - cls, object: history_object, surface, color, heading, label: bool = False, draw_countour=False, - contour_width=1 + cls, + object: history_object, + surface, + color, + heading, + label: bool = False, + draw_countour=False, + contour_width=1 ) -> None: """ Display a vehicle on a pygame surface. diff --git a/metadrive/obs/top_down_renderer.py b/metadrive/obs/top_down_renderer.py index b497a0e0b..b3e4c09c6 100644 --- a/metadrive/obs/top_down_renderer.py +++ b/metadrive/obs/top_down_renderer.py @@ -21,7 +21,6 @@ color_white = (255, 255, 255) - def draw_top_down_map( map, resolution: Iterable = (512, 512), @@ -415,9 +414,7 @@ def _draw(self, *args, **kwargs): c = TopDownSemanticColor.get_color(v.type, True) * (1 - alpha_f) + alpha_f * 255 else: c = (c[0] + alpha_f * (255 - c[0]), c[1] + alpha_f * (255 - c[1]), c[2] + alpha_f * (255 - c[2])) - ObjectGraphics.display( - object=v, surface=self._runtime_canvas, heading=h, color=c, draw_countour=False - ) + ObjectGraphics.display(object=v, surface=self._runtime_canvas, heading=h, color=c, draw_countour=False) # Draw the whole trajectory of ego vehicle with no gradient colors: if self.draw_target_vehicle_trajectory: