Skip to content

Commit

Permalink
Merge pull request #34 from Farama-Foundation/feature/pickle-all-envs
Browse files Browse the repository at this point in the history
Add EzPickle to all envs
  • Loading branch information
ffelten authored Feb 10, 2023
2 parents 9070a74 + 6d1b876 commit 6dc669c
Show file tree
Hide file tree
Showing 18 changed files with 97 additions and 23 deletions.
5 changes: 4 additions & 1 deletion mo_gymnasium/envs/breakable_bottles/breakable_bottles.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import numpy as np
from gymnasium import Env
from gymnasium.spaces import Box, Dict, Discrete, MultiBinary
from gymnasium.utils import EzPickle


class BreakableBottles(Env):
class BreakableBottles(Env, EzPickle):
"""
## Description
This environment implements the problems UnbreakableBottles and BreakableBottles defined in Section 4.1.2 of the paper
Expand Down Expand Up @@ -64,6 +65,8 @@ def __init__(
bottle_reward=25,
unbreakable_bottles=False,
):
EzPickle.__init__(self, render_mode, size, prob_drop, time_penalty, bottle_reward, unbreakable_bottles)

self.render_mode = render_mode

# settings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from gymnasium.envs.classic_control.continuous_mountain_car import (
Continuous_MountainCarEnv,
)
from gymnasium.utils import EzPickle


class MOContinuousMountainCar(Continuous_MountainCarEnv):
class MOContinuousMountainCar(Continuous_MountainCarEnv, EzPickle):
"""
A continuous version of the MountainCar environment, where the goal is to reach the top of the mountain.
Expand All @@ -22,6 +23,7 @@ class MOContinuousMountainCar(Continuous_MountainCarEnv):

def __init__(self, render_mode: Optional[str] = None, goal_velocity=0):
super().__init__(render_mode, goal_velocity)
EzPickle.__init__(self, render_mode, goal_velocity)

self.reward_space = spaces.Box(low=np.array([-1.0, -1.0]), high=np.array([0.0, 0.0]), shape=(2,), dtype=np.float32)

Expand Down
5 changes: 4 additions & 1 deletion mo_gymnasium/envs/deep_sea_treasure/deep_sea_treasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import pygame
from gymnasium.spaces import Box, Discrete
from gymnasium.utils import EzPickle


# As in Yang et al. (2019):
Expand Down Expand Up @@ -42,7 +43,7 @@
)


class DeepSeaTreasure(gym.Env):
class DeepSeaTreasure(gym.Env, EzPickle):
"""
## Description
The Deep Sea Treasure environment is classic MORL problem in which the agent controls a submarine in a 2D grid world.
Expand Down Expand Up @@ -79,6 +80,8 @@ class DeepSeaTreasure(gym.Env):
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

def __init__(self, render_mode: Optional[str] = None, dst_map=DEFAULT_MAP, float_state=False):
EzPickle.__init__(self, render_mode, dst_map, float_state)

self.render_mode = render_mode
self.size = 11
self.window_size = 512
Expand Down
5 changes: 4 additions & 1 deletion mo_gymnasium/envs/fishwood/fishwood.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from gymnasium.utils import EzPickle


class FishWood(gym.Env):
class FishWood(gym.Env, EzPickle):
"""
## Description
The FishWood environment is a simple MORL problem in which the agent controls a fisherman which can either fish or go collect wood.
Expand Down Expand Up @@ -46,6 +47,8 @@ class FishWood(gym.Env):
MAX_TS = 200

def __init__(self, render_mode: Optional[str] = None, fishproba=0.1, woodproba=0.9):
EzPickle.__init__(self, render_mode, fishproba, woodproba)

self.render_mode = render_mode
self._fishproba = fishproba
self._woodproba = woodproba
Expand Down
5 changes: 4 additions & 1 deletion mo_gymnasium/envs/four_room/four_room.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import pygame
from gymnasium.spaces import Box, Discrete
from gymnasium.utils import EzPickle


MAZE = np.array(
Expand All @@ -30,7 +31,7 @@
BLACK = (0, 0, 0)


class FourRoom(gym.Env):
class FourRoom(gym.Env, EzPickle):
"""
## Description
A discretized version of the gridworld environment introduced in [1]. Here, an agent learns to
Expand Down Expand Up @@ -85,6 +86,8 @@ def __init__(self, render_mode: Optional[str] = None, maze=MAZE):
0, 1, .... 9 indicates the type of shape to be placed in the corresponding cell
entries containing other characters are treated as regular empty cells
"""
EzPickle.__init__(self, render_mode, maze)

self.render_mode = render_mode
self.window_size = 512
self.window = None
Expand Down
5 changes: 4 additions & 1 deletion mo_gymnasium/envs/fruit_tree/fruit_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from gymnasium.utils import EzPickle


FRUITS = {
Expand Down Expand Up @@ -238,7 +239,7 @@
}


class FruitTreeEnv(gym.Env):
class FruitTreeEnv(gym.Env, EzPickle):
"""
## Description
Expand All @@ -263,6 +264,8 @@ class FruitTreeEnv(gym.Env):

def __init__(self, depth=6):
assert depth in [5, 6, 7], "Depth must be 5, 6 or 7."
EzPickle.__init__(self, depth)

self.reward_dim = 6
self.tree_depth = depth # zero based depth
branches = np.zeros((int(2**self.tree_depth - 1), self.reward_dim))
Expand Down
5 changes: 4 additions & 1 deletion mo_gymnasium/envs/highway/highway.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
Text,
Tuple,
)
from gymnasium.utils import EzPickle
from highway_env.envs import HighwayEnv, HighwayEnvFast


class MOHighwayEnv(HighwayEnv):
class MOHighwayEnv(HighwayEnv, EzPickle):
"""
## Description
Multi-objective version of the HighwayEnv environment.
Expand All @@ -30,6 +31,8 @@ class MOHighwayEnv(HighwayEnv):
"""

def __init__(self, *args, **kwargs):
EzPickle.__init__(self, *args, **kwargs)

super().__init__(*args, **kwargs)
self.reward_space = Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
self.observation_space = _convert_space(self.observation_space)
Expand Down
2 changes: 1 addition & 1 deletion mo_gymnasium/envs/lunar_lander/lunar_lander.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)


class MOLunarLander(LunarLander):
class MOLunarLander(LunarLander): # no need for EzPickle, it's already in LunarLander
"""
## Description
Multi-objective version of the LunarLander environment.
Expand Down
5 changes: 3 additions & 2 deletions mo_gymnasium/envs/mario/mario.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from gym_super_mario_bros import SuperMarioBrosEnv
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from gymnasium.utils import seeding
from gymnasium.utils import EzPickle, seeding

# from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv
from gymnasium.wrappers import GrayScaleObservation, ResizeObservation
Expand All @@ -16,7 +16,7 @@
from mo_gymnasium.envs.mario.joypad_space import JoypadSpace


class MOSuperMarioBros(SuperMarioBrosEnv):
class MOSuperMarioBros(SuperMarioBrosEnv, EzPickle):
"""
## Description
Multi-objective version of the SuperMarioBro environment.
Expand Down Expand Up @@ -45,6 +45,7 @@ def __init__(
objectives=["x_pos", "time", "death", "coin", "enemy"],
render_mode: Optional[str] = None,
):
EzPickle.__init__(self, rom_mode, lost_levels, target, objectives, render_mode)
super().__init__(rom_mode, lost_levels, target)

self.render_mode = render_mode
Expand Down
5 changes: 4 additions & 1 deletion mo_gymnasium/envs/minecart/minecart.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pygame
import scipy.stats
from gymnasium.spaces import Box, Discrete
from gymnasium.utils import EzPickle
from scipy.spatial import ConvexHull


Expand Down Expand Up @@ -86,7 +87,7 @@
MINE_IMG = str(Path(__file__).parent.absolute()) + "/assets/mine.png"


class Minecart(gym.Env):
class Minecart(gym.Env, EzPickle):
"""
## Description
Agent must collect two types of ores and minimize fuel consumption.
Expand Down Expand Up @@ -133,6 +134,8 @@ def __init__(
image_observation=False,
config=str(Path(__file__).parent.absolute()) + "/mine_config.json",
):
EzPickle.__init__(self, render_mode, image_observation, config)

self.render_mode = render_mode
self.screen = None
self.last_render_mode_used = None
Expand Down
4 changes: 3 additions & 1 deletion mo_gymnasium/envs/mountain_car/mountain_car.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import numpy as np
from gymnasium import spaces
from gymnasium.envs.classic_control.mountain_car import MountainCarEnv
from gymnasium.utils import EzPickle


class MOMountainCar(MountainCarEnv):
class MOMountainCar(MountainCarEnv, EzPickle):
"""
A multi-objective version of the MountainCar environment, where the goal is to reach the top of the mountain.
Expand All @@ -21,6 +22,7 @@ class MOMountainCar(MountainCarEnv):

def __init__(self, render_mode: Optional[str] = None, goal_velocity=0):
super().__init__(render_mode, goal_velocity)
EzPickle.__init__(self, render_mode, goal_velocity)

self.reward_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)

Expand Down
4 changes: 3 additions & 1 deletion mo_gymnasium/envs/mujoco/half_cheetah.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import numpy as np
from gymnasium.envs.mujoco.half_cheetah_v4 import HalfCheetahEnv
from gymnasium.spaces import Box
from gymnasium.utils import EzPickle


class MOHalfCheehtahEnv(HalfCheetahEnv):
class MOHalfCheehtahEnv(HalfCheetahEnv, EzPickle):
"""
## Description
Multi-objective version of the HalfCheetahEnv environment.
Expand All @@ -18,6 +19,7 @@ class MOHalfCheehtahEnv(HalfCheetahEnv):

def __init__(self, **kwargs):
super().__init__(**kwargs)
EzPickle.__init__(self, **kwargs)
self.reward_space = Box(low=-np.inf, high=np.inf, shape=(2,))

def step(self, action):
Expand Down
4 changes: 3 additions & 1 deletion mo_gymnasium/envs/mujoco/hopper.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import numpy as np
from gymnasium.envs.mujoco.hopper_v4 import HopperEnv
from gymnasium.spaces import Box
from gymnasium.utils import EzPickle


class MOHopperEnv(HopperEnv):
class MOHopperEnv(HopperEnv, EzPickle):
"""
## Description
Multi-objective version of the HopperEnv environment.
Expand All @@ -20,6 +21,7 @@ class MOHopperEnv(HopperEnv):

def __init__(self, cost_objective=True, **kwargs):
super().__init__(**kwargs)
EzPickle.__init__(self, cost_objective, **kwargs)
self.cost_objetive = cost_objective
self.rew_dim = 3 if cost_objective else 2
self.reward_space = Box(low=-np.inf, high=np.inf, shape=(self.rew_dim,))
Expand Down
4 changes: 3 additions & 1 deletion mo_gymnasium/envs/reacher/reacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
from gymnasium import spaces
from gymnasium.utils import EzPickle
from pybulletgym.envs.roboschool.envs.env_bases import BaseBulletEnv
from pybulletgym.envs.roboschool.robots.robot_bases import MJCFBasedRobot
from pybulletgym.envs.roboschool.scenes.scene_bases import SingleRobotEmptyScene
Expand All @@ -10,7 +11,7 @@
target_positions = list(map(lambda l: np.array(l), [(0.14, 0.0), (-0.14, 0.0), (0.0, 0.14), (0.0, -0.14)]))


class ReacherBulletEnv(BaseBulletEnv):
class ReacherBulletEnv(BaseBulletEnv, EzPickle):

metadata = {"render_modes": ["human", "rgb_array"]}

Expand All @@ -20,6 +21,7 @@ def __init__(
target=(0.14, 0.0),
fixed_initial_state: Optional[tuple] = (3.14, 0),
):
EzPickle.__init__(self, render_mode, target, fixed_initial_state)
self.robot = ReacherRobot(target, fixed_initial_state=fixed_initial_state)
self.render_mode = render_mode
BaseBulletEnv.__init__(self, self.robot, render=render_mode == "human")
Expand Down
5 changes: 4 additions & 1 deletion mo_gymnasium/envs/resource_gathering/resource_gathering.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import numpy as np
import pygame
from gymnasium.spaces import Box, Discrete
from gymnasium.utils import EzPickle


class ResourceGathering(gym.Env):
class ResourceGathering(gym.Env, EzPickle):
"""
## Description
From "Barrett, Leon & Narayanan, Srini. (2008). Learning all optimal policies with multiple criteria.
Expand Down Expand Up @@ -43,6 +44,8 @@ class ResourceGathering(gym.Env):
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

def __init__(self, render_mode: Optional[str] = None):
EzPickle.__init__(self, render_mode)

self.render_mode = render_mode
self.size = 5
self.window_size = 512
Expand Down
5 changes: 4 additions & 1 deletion mo_gymnasium/envs/water_reservoir/dam_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import gymnasium as gym
import numpy as np
from gymnasium.spaces.box import Box
from gymnasium.utils import EzPickle


class DamEnv(gym.Env):
class DamEnv(gym.Env, EzPickle):
"""
## Description
A Water reservoir environment.
Expand Down Expand Up @@ -73,6 +74,8 @@ def __init__(
nO=2,
penalize: bool = False,
):
EzPickle.__init__(self, render_mode, time_limit, nO, penalize)

self.observation_space = Box(low=0.0, high=np.inf, shape=(1,), dtype=np.float32)
self.action_space = Box(low=0, high=np.inf, shape=(1,), dtype=np.float32)

Expand Down
Loading

0 comments on commit 6dc669c

Please sign in to comment.