diff --git a/metadrive/component/sensors/base_camera.py b/metadrive/component/sensors/base_camera.py index 4c1fec6c4..051bd9026 100644 --- a/metadrive/component/sensors/base_camera.py +++ b/metadrive/component/sensors/base_camera.py @@ -30,7 +30,7 @@ class BaseCamera(ImageBuffer, BaseSensor): CAM_MASK = None attached_object = None - num_channels=3 + num_channels = 3 def __init__(self, engine, need_cuda=False, frame_buffer_property=None): self._enable_cuda = need_cuda @@ -111,7 +111,7 @@ def perceive(self, base_object, clip=True) -> np.ndarray: self.track(base_object) if self.enable_cuda: assert self.cuda_rendered_result is not None - ret = self.cuda_rendered_result[..., :-1][..., ::-1][::-1][...,:self.num_channels] + ret = self.cuda_rendered_result[..., :-1][..., ::-1][::-1][..., :self.num_channels] else: ret = self.get_rgb_array_cpu() if self.engine.global_config["rgb_to_grayscale"]: diff --git a/metadrive/engine/core/image_buffer.py b/metadrive/engine/core/image_buffer.py index ed3a343b2..5a3fd12de 100644 --- a/metadrive/engine/core/image_buffer.py +++ b/metadrive/engine/core/image_buffer.py @@ -29,14 +29,14 @@ class ImageBuffer: num_channels = 3 def __init__( - self, - width: float, - height: float, - pos: Vec3, - bkg_color: Union[Vec4, Vec3], - parent_node: NodePath = None, - frame_buffer_property=None, - engine=None + self, + width: float, + height: float, + pos: Vec3, + bkg_color: Union[Vec4, Vec3], + parent_node: NodePath = None, + frame_buffer_property=None, + engine=None ): self.logger = get_logger() self._node_path_list = [] diff --git a/metadrive/engine/core/main_camera.py b/metadrive/engine/core/main_camera.py index 5d4b5f244..f6e022301 100644 --- a/metadrive/engine/core/main_camera.py +++ b/metadrive/engine/core/main_camera.py @@ -41,7 +41,7 @@ class MainCamera(BaseSensor): MOUSE_MOVE_INTO_LATENCY = 2 MOUSE_SPEED_MULTIPLIER = 1 - num_channels=3 + num_channels = 3 def __init__(self, engine, camera_height: float, camera_dist: float): self._origin_height = camera_height diff --git a/metadrive/obs/image_obs.py b/metadrive/obs/image_obs.py index c4f23e36d..bd611359c 100644 --- a/metadrive/obs/image_obs.py +++ b/metadrive/obs/image_obs.py @@ -63,7 +63,7 @@ def observation_space(self): assert sensor_cls == "MainCamera" or issubclass(sensor_cls, BaseCamera), "Sensor should be BaseCamera" channel = sensor_cls.num_channels if sensor_cls != "MainCamera" else 3 shape = (self.config["sensors"][self.image_source][2], self.config["sensors"][self.image_source][1] - ) + ((self.STACK_SIZE,) if self.config["rgb_to_grayscale"] else (channel, self.STACK_SIZE)) + ) + ((self.STACK_SIZE, ) if self.config["rgb_to_grayscale"] else (channel, self.STACK_SIZE)) if self.rgb_clip: return gym.spaces.Box(-0.0, 1.0, shape=shape, dtype=np.float32) else: diff --git a/metadrive/tests/test_sensors/test_depth_cam.py b/metadrive/tests/test_sensors/test_depth_cam.py index 8e2261311..ccd5a1759 100644 --- a/metadrive/tests/test_sensors/test_depth_cam.py +++ b/metadrive/tests/test_sensors/test_depth_cam.py @@ -48,14 +48,13 @@ def test_depth_cam(config, render=False): o, r, tm, tc, info = env.step([0, 1]) assert env.observation_space.contains(o) # Reverse - assert o["image"].shape == (config["height"], - config["width"], - DepthCamera.num_channels, - config["stack_size"]) + assert o["image"].shape == ( + config["height"], config["width"], DepthCamera.num_channels, config["stack_size"] + ) if render: cv2.imshow('img', o["image"][..., -1]) cv2.waitKey(1) - print("FPS:", 10/(time.time() - start)) + print("FPS:", 10 / (time.time() - start)) finally: env.close() diff --git a/metadrive/tests/test_sensors/test_semantic_cam.py b/metadrive/tests/test_sensors/test_semantic_cam.py index b645e1b28..4bfc4794c 100644 --- a/metadrive/tests/test_sensors/test_semantic_cam.py +++ b/metadrive/tests/test_sensors/test_semantic_cam.py @@ -48,8 +48,9 @@ def test_semantic_cam(config, render=False): o, r, tm, tc, info = env.step([0, 1]) assert env.observation_space.contains(o) # Reverse - assert o["image"].shape == (config["height"], config["width"], - SemanticCamera.num_channels, config["stack_size"]) + assert o["image"].shape == ( + config["height"], config["width"], SemanticCamera.num_channels, config["stack_size"] + ) if render: cv2.imshow('img', o["image"][..., -1]) cv2.waitKey(1) diff --git a/metadrive/tests/vis_functionality/vis_depth_cam.py b/metadrive/tests/vis_functionality/vis_depth_cam.py index aba875aef..960f628ea 100644 --- a/metadrive/tests/vis_functionality/vis_depth_cam.py +++ b/metadrive/tests/vis_functionality/vis_depth_cam.py @@ -26,7 +26,6 @@ ) env.reset() - def get_image(env): depth_cam = env.vehicle.get_camera(env.vehicle.config["image_source"]) rgb_cam = env.vehicle.get_camera("rgb_camera") @@ -38,7 +37,6 @@ def get_image(env): rgb_cam.save_image(env.vehicle, "rgb_{}.jpg".format(h)) env.engine.screenshot() - env.engine.accept("m", get_image, extraArgs=[env]) import cv2 diff --git a/metadrive/tests/vis_functionality/vis_semantic_cam.py b/metadrive/tests/vis_functionality/vis_semantic_cam.py index e544e32ed..7487a81f9 100644 --- a/metadrive/tests/vis_functionality/vis_semantic_cam.py +++ b/metadrive/tests/vis_functionality/vis_semantic_cam.py @@ -41,9 +41,9 @@ def get_image(env): cv2.waitKey(1) # if env.config["use_render"]: - # for i in range(ImageObservation.STACK_SIZE): - # ObservationType.show_gray_scale_array(o["image"][:, :, i]) - # env.render() + # for i in range(ImageObservation.STACK_SIZE): + # ObservationType.show_gray_scale_array(o["image"][:, :, i]) + # env.render() # if tm or tc: # # print("Reset") # env.reset()