Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
QuanyiLi committed Sep 19, 2023
1 parent 7d4ad31 commit dd6fd55
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 47 deletions.
40 changes: 20 additions & 20 deletions metadrive/component/vehicle/base_vehicle.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,12 @@ class BaseVehicle(BaseObject, BaseVehicleState):
path = None

def __init__(
self,
vehicle_config: Union[dict, Config] = None,
name: str = None,
random_seed=None,
position=None,
heading=None
self,
vehicle_config: Union[dict, Config] = None,
name: str = None,
random_seed=None,
position=None,
heading=None
):
"""
This Vehicle Config is different from self.get_config(), and it is used to define which modules to use, and
Expand Down Expand Up @@ -266,14 +266,14 @@ def _update_energy_consumption(self):
return step_energy, self.energy_consumption

def reset(
self,
vehicle_config=None,
name=None,
random_seed=None,
position: np.ndarray = None,
heading: float = 0.0,
*args,
**kwargs
self,
vehicle_config=None,
name=None,
random_seed=None,
position: np.ndarray = None,
heading: float = 0.0,
*args,
**kwargs
):
"""
pos is a 2-d array, and heading is a float (unit degree)
Expand Down Expand Up @@ -529,8 +529,8 @@ def heading_diff(self, target_lane):
if not lateral_norm * forward_direction_norm:
return 0
cos = (
(forward_direction[0] * lateral[0] + forward_direction[1] * lateral[1]) /
(lateral_norm * forward_direction_norm)
(forward_direction[0] * lateral[0] + forward_direction[1] * lateral[1]) /
(lateral_norm * forward_direction_norm)
)
# return cos
# Normalize to 0, 1
Expand Down Expand Up @@ -850,7 +850,7 @@ def _update_overtake_stat(self):
ckpt_idx = routing._target_checkpoints_index
for surrounding_v in surrounding_vs:
if surrounding_v.lane_index[:-1] == (routing.checkpoints[ckpt_idx[0]], routing.checkpoints[ckpt_idx[1]
]):
]):
if self.lane.local_coordinates(self.position)[0] - \
self.lane.local_coordinates(surrounding_v.position)[0] < 0:
self.front_vehicles.add(surrounding_v)
Expand Down Expand Up @@ -885,9 +885,9 @@ def overspeed(self):
@property
def replay_done(self):
return self._replay_done if hasattr(self, "_replay_done") else (
self.crash_building or self.crash_vehicle or
# self.on_white_continuous_line or
self.on_yellow_continuous_line
self.crash_building or self.crash_vehicle or
# self.on_white_continuous_line or
self.on_yellow_continuous_line
)

@property
Expand Down
23 changes: 11 additions & 12 deletions metadrive/envs/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@
# )
# These sensors will be constructed automatically and can be accessed in engine.get_sensor("sensor_name")
# NOTE: main_camera will be added automatically if you are using offscreen/onscreen mode
sensors=dict(lidar=(Lidar, 50), side_detector=(SideDetector,), lane_line_detector=(LaneLineDetector,)),
sensors=dict(lidar=(Lidar, 50), side_detector=(SideDetector, ), lane_line_detector=(LaneLineDetector, )),

# when main_camera is not the image_source for vehicle, reduce the window size to (1,1) for boosting efficiency
auto_resize_window=True,
Expand Down Expand Up @@ -296,7 +296,7 @@ def _post_process_config(self, config):
if not config["render_pipeline"]:
for panel in config["interface_panel"]:
if panel == "dashboard":
config["sensors"]["dashboard"] = (VehiclePanel,)
config["sensors"]["dashboard"] = (VehiclePanel, )
if panel not in config["sensors"]:
self.logger.warning(
"Fail to add sensor: {} to the interface. Remove it from panel list!".format(panel)
Expand Down Expand Up @@ -712,20 +712,19 @@ def episode_step(self):
return self.engine.episode_step if self.engine is not None else 0

def export_scenarios(
self,
policies: Union[dict, Callable],
scenario_index: Union[list, int],
max_episode_length=None,
verbose=False,
suppress_warning=False,
render_topdown=False,
return_done_info=True,
to_dict=True
self,
policies: Union[dict, Callable],
scenario_index: Union[list, int],
max_episode_length=None,
verbose=False,
suppress_warning=False,
render_topdown=False,
return_done_info=True,
to_dict=True
):
"""
We export scenarios into a unified format with 10hz sample rate
"""

def _act(observation):
if isinstance(policies, dict):
ret = {}
Expand Down
4 changes: 2 additions & 2 deletions metadrive/envs/scenario_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ def msg(reason):
# for compatibility
# crash almost equals to crashing with vehicles
done_info[TerminationState.CRASH] = (
done_info[TerminationState.CRASH_VEHICLE] or done_info[TerminationState.CRASH_OBJECT]
or done_info[TerminationState.CRASH_BUILDING]
done_info[TerminationState.CRASH_VEHICLE] or done_info[TerminationState.CRASH_OBJECT]
or done_info[TerminationState.CRASH_BUILDING]
)

# log data to curriculum manager
Expand Down
12 changes: 5 additions & 7 deletions metadrive/manager/scenario_traffic_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,12 @@ def spawn_vehicle(self, v_id, track):
v_cfg["top_down_length"] = track["state"]["length"][self.episode_step]
v_cfg["top_down_width"] = track["state"]["width"][self.episode_step]
if v_cfg["top_down_length"] < 1 or v_cfg["top_down_width"] < 0.5:
logger.warning("Scenario ID: {}. The top_down size of vehicle {} is weird: "
"{}".format(self.engine.current_seed, v_id, [v_cfg["length"], v_cfg["width"]]))
logger.warning(
"Scenario ID: {}. The top_down size of vehicle {} is weird: "
"{}".format(self.engine.current_seed, v_id, [v_cfg["length"], v_cfg["width"]])
)
v = self.spawn_object(
vehicle_class,
position=state["position"],
heading=state["heading"],
vehicle_config=v_cfg,
name=obj_name
vehicle_class, position=state["position"], heading=state["heading"], vehicle_config=v_cfg, name=obj_name
)
self._scenario_id_to_obj_id[v_id] = v.name
self._obj_id_to_scenario_id[v.name] = v_id
Expand Down
10 changes: 8 additions & 2 deletions metadrive/obs/top_down_obs_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,14 @@ class ObjectGraphics:

@classmethod
def display(
cls, object: history_object, surface, color, heading, label: bool = False, draw_countour=False,
contour_width=1
cls,
object: history_object,
surface,
color,
heading,
label: bool = False,
draw_countour=False,
contour_width=1
) -> None:
"""
Display a vehicle on a pygame surface.
Expand Down
5 changes: 1 addition & 4 deletions metadrive/obs/top_down_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
color_white = (255, 255, 255)



def draw_top_down_map(
map,
resolution: Iterable = (512, 512),
Expand Down Expand Up @@ -415,9 +414,7 @@ def _draw(self, *args, **kwargs):
c = TopDownSemanticColor.get_color(v.type, True) * (1 - alpha_f) + alpha_f * 255
else:
c = (c[0] + alpha_f * (255 - c[0]), c[1] + alpha_f * (255 - c[1]), c[2] + alpha_f * (255 - c[2]))
ObjectGraphics.display(
object=v, surface=self._runtime_canvas, heading=h, color=c, draw_countour=False
)
ObjectGraphics.display(object=v, surface=self._runtime_canvas, heading=h, color=c, draw_countour=False)

# Draw the whole trajectory of ego vehicle with no gradient colors:
if self.draw_target_vehicle_trajectory:
Expand Down

0 comments on commit dd6fd55

Please sign in to comment.