diff --git a/examples/ppo-cartpole/config.yaml b/examples/ppo-cartpole/config.yaml index df8d5ee..bd9afa1 100644 --- a/examples/ppo-cartpole/config.yaml +++ b/examples/ppo-cartpole/config.yaml @@ -1,8 +1,8 @@ -env_name: "CartPole" +env_name: "gymnax:CartPole-v1" agent_name: "PPO" train_iterations: 1500 # Comment out the seed if you want each run to be different. -seed: 1035704761 +seed: 88476718 agent_config: rollout_fragment_length: 128 minibatch_size: 32 diff --git a/src/dopamax/_scripts/cli.py b/src/dopamax/_scripts/cli.py index 606d61d..113c77e 100644 --- a/src/dopamax/_scripts/cli.py +++ b/src/dopamax/_scripts/cli.py @@ -71,7 +71,9 @@ def train(config, offline): with open(params_path, "wb") as f: pickle.dump(params, f) - params_artifact = wandb.Artifact(config.env_name + "-" + config.agent_name + "-agent", type="agent") + params_artifact = wandb.Artifact( + config.env_name.replace(":", "-") + "-" + config.agent_name + "-agent", type="agent" + ) params_artifact.add_file(params_path) run.log_artifact(params_artifact) diff --git a/src/dopamax/environments/__init__.py b/src/dopamax/environments/__init__.py index 8ccb931..17e57cb 100644 --- a/src/dopamax/environments/__init__.py +++ b/src/dopamax/environments/__init__.py @@ -1,8 +1 @@ -from .brax.ant import Ant -from .brax.half_cheetah import HalfCheetah -from .brax.inverted_pendulum import InvertedPendulum -from .cartpole import CartPole -from .mountain_car import MountainCar -from .mountain_car_continuous import MountainCarContinuous -from .pgx import ConnectFour, TicTacToe, Go9x9, Go19x19 from .utils import make_env diff --git a/src/dopamax/environments/cartpole.py b/src/dopamax/environments/cartpole.py deleted file mode 100644 index 50b0dcd..0000000 --- a/src/dopamax/environments/cartpole.py +++ /dev/null @@ -1,223 +0,0 @@ -from typing import Tuple, Optional - -import jax -import jax.numpy as jnp -import numpy as np -from chex import dataclass, PRNGKey -from dm_env import StepType - -from dopamax.environments.environment import EnvState, Environment, TimeStep -from dopamax.environments.utils import register -from dopamax.spaces import Space, Box, Discrete -from dopamax.typing import Action, Observation - -_NAME = "CartPole" - - -@dataclass(frozen=True) -class CartPoleEnvState(EnvState): - episode_reward: float - episode_length: float - x: float - x_dot: float - theta: float - theta_dot: float - time: int - - def to_obs(self) -> Observation: - return jnp.array([self.x, self.x_dot, self.theta, self.theta_dot]) - - -@register(_NAME) -@dataclass(frozen=True) -class CartPole(Environment): - """The CartPole environment, as defined by Barto, Sutton, and Anderson. - - References: - This implementation is adapted from: - https://github.com/RobertTLange/gymnax/blob/main/gymnax/environments/classic_control/cartpole.py - """ - - gravity: float = 9.8 - masscart: float = 1.0 - masspole: float = 0.1 - total_mass: float = 1.0 + 0.1 - length: float = 0.5 - polemass_length: float = 0.05 - force_mag: float = 10.0 - tau: float = 0.02 - theta_threshold_radians: float = 12 * 2 * jnp.pi / 360 - x_threshold: float = 2.4 - - @property - def name(self) -> str: - return _NAME - - @property - def max_episode_length(self) -> int: - return 500 - - @property - def observation_space(self) -> Space: - high = jnp.array( - [ - self.x_threshold * 2, - jnp.finfo(jnp.float32).max, - self.theta_threshold_radians * 2, - jnp.finfo(jnp.float32).max, - ] - ) - return Box(low=-high, high=high, shape=(4,)) - - @property - def action_space(self) -> Space: - return Discrete(2) - - @property - def renderable(self) -> bool: - return True - - @property - def fps(self) -> Optional[int]: - return 30 - - @property - def render_shape(self) -> Optional[Tuple[int, int, int]]: - return 400, 600, 3 - - def _is_terminal(self, state: CartPoleEnvState) -> Tuple[bool, bool]: - done1 = jnp.logical_or( - state.x < -self.x_threshold, - state.x > self.x_threshold, - ) - - done2 = jnp.logical_or( - state.theta < -self.theta_threshold_radians, - state.theta > self.theta_threshold_radians, - ) - - truncate = state.time >= self.max_episode_length - done = jnp.logical_or(jnp.logical_or(done1, done2), truncate) - - return done, truncate - - def reset(self, key: PRNGKey) -> Tuple[TimeStep, CartPoleEnvState]: - x, x_dot, theta, theta_dot = jax.random.uniform(key, minval=-0.05, maxval=0.05, shape=(4,)) - - state = CartPoleEnvState( - episode_reward=0.0, - episode_length=0, - x=x, - x_dot=x_dot, - theta=theta, - theta_dot=theta_dot, - time=0, - ) - time_step = TimeStep.restart(state.to_obs()) - - return time_step, state - - def step(self, key: PRNGKey, state: CartPoleEnvState, action: Action) -> Tuple[TimeStep, CartPoleEnvState]: - prev_terminal, _ = self._is_terminal(state) - - force = self.force_mag * action - self.force_mag * (1 - action) - costheta = jnp.cos(state.theta) - sintheta = jnp.sin(state.theta) - - temp = (force + self.polemass_length * state.theta_dot**2 * sintheta) / self.total_mass - thetaacc = (self.gravity * sintheta - costheta * temp) / ( - self.length * (4.0 / 3.0 - self.masspole * costheta**2 / self.total_mass) - ) - xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass - - x = state.x + self.tau * state.x_dot - x_dot = state.x_dot + self.tau * xacc - theta = state.theta + self.tau * state.theta_dot - theta_dot = state.theta_dot + self.tau * thetaacc - - reward = 1.0 - prev_terminal - length = 1 - prev_terminal - - state = CartPoleEnvState( - episode_reward=state.episode_reward + reward, - episode_length=state.episode_length + length, - x=x, - x_dot=x_dot, - theta=theta, - theta_dot=theta_dot, - time=state.time + 1, - ) - done, truncate = self._is_terminal(state) - - time_step = TimeStep( - observation=state.to_obs(), - reward=reward, - discount=1.0 - jnp.float32(done & ~truncate), - step_type=jax.lax.select(done, StepType.LAST, StepType.MID), - ) - - return time_step, state - - def render(self, state: CartPoleEnvState) -> np.ndarray: - import pygame - from pygame import gfxdraw - - screen_width, screen_height = 600, 400 - - screen = pygame.Surface((screen_width, screen_height)) - - world_width = self.x_threshold * 2 - scale = screen_width / world_width - polewidth = 10.0 - polelen = scale * (2 * self.length) - cartwidth = 50.0 - cartheight = 30.0 - - surf = pygame.Surface((screen_width, screen_height)) - surf.fill((255, 255, 255)) - - l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2 - axleoffset = cartheight / 4.0 - cartx = state.x * scale + screen_width / 2.0 # MIDDLE OF CART - carty = 100 # TOP OF CART - cart_coords = [(l, b), (l, t), (r, t), (r, b)] - cart_coords = [(c[0] + cartx, c[1] + carty) for c in cart_coords] - gfxdraw.aapolygon(surf, cart_coords, (0, 0, 0)) - gfxdraw.filled_polygon(surf, cart_coords, (0, 0, 0)) - - l, r, t, b = ( - -polewidth / 2, - polewidth / 2, - polelen - polewidth / 2, - -polewidth / 2, - ) - - pole_coords = [] - for coord in [(l, b), (l, t), (r, t), (r, b)]: - coord = pygame.math.Vector2(coord).rotate_rad(-state.theta) - coord = (coord[0] + cartx, coord[1] + carty + axleoffset) - pole_coords.append(coord) - gfxdraw.aapolygon(surf, pole_coords, (202, 152, 101)) - gfxdraw.filled_polygon(surf, pole_coords, (202, 152, 101)) - - gfxdraw.aacircle( - surf, - int(cartx), - int(carty + axleoffset), - int(polewidth / 2), - (129, 132, 203), - ) - gfxdraw.filled_circle( - surf, - int(cartx), - int(carty + axleoffset), - int(polewidth / 2), - (129, 132, 203), - ) - - gfxdraw.hline(surf, 0, screen_width, carty, (0, 0, 0)) - - surf = pygame.transform.flip(surf, False, True) - screen.blit(surf, (0, 0)) - - return np.transpose(np.array(pygame.surfarray.pixels3d(screen)), axes=(1, 0, 2)) diff --git a/src/dopamax/environments/gymnax.py b/src/dopamax/environments/gymnax.py new file mode 100644 index 0000000..b8623a6 --- /dev/null +++ b/src/dopamax/environments/gymnax.py @@ -0,0 +1,76 @@ +from typing import Tuple + +import gymnax +import jax +from chex import PRNGKey, dataclass +from dm_env import StepType +from gymnax.environments.spaces import Space +from dopamax import spaces +from dopamax.environments.environment import Environment, EnvState, TimeStep +from dopamax.typing import Action + + +def _convert_space(space: gymnax.environments.spaces.Space) -> spaces.Space: + if isinstance(space, gymnax.environments.spaces.Box): + return spaces.Box(low=space.low, high=space.high, shape=space.shape, dtype=space.dtype) + + if isinstance(space, gymnax.environments.spaces.Discrete): + return spaces.Discrete(space.n, dtype=space.dtype) + + if isinstance(space, gymnax.environments.spaces.Dict): + return spaces.Dict(spaces=space.spaces) + + raise ValueError(f"Unknown space: {space}") + + +@dataclass(frozen=True) +class GymnaxEnvState(EnvState): + gymnax_state: gymnax.environments.EnvState + + +@dataclass(frozen=True) +class GymnaxEnvironment(Environment): + env: gymnax.environments.environment.Environment + env_params: gymnax.environments.environment.EnvParams + + @property + def name(self) -> str: + return self.env.name + + @property + def max_episode_length(self) -> int: + return self.env_params.max_steps_in_episode + + @property + def observation_space(self) -> Space: + return _convert_space(self.env.observation_space(self.env_params)) + + @property + def action_space(self) -> Space: + return _convert_space(self.env.action_space(self.env_params)) + + def reset(self, key: PRNGKey) -> Tuple[TimeStep, GymnaxEnvState]: + obs, gymnax_state = self.env.reset(key, self.env_params) + state = GymnaxEnvState(episode_length=0, episode_reward=0.0, gymnax_state=gymnax_state) + ts = TimeStep.restart(self.env.get_obs(gymnax_state)) + return ts, state + + def step(self, key: PRNGKey, state: GymnaxEnvState, action: Action) -> Tuple[TimeStep, GymnaxEnvState]: + obs, gymnax_state, reward, done, info = self.env.step(key, state.gymnax_state, action, self.env_params) + + done = jax.numpy.bool_(done) + + ts = TimeStep( + observation=self.env.get_obs(gymnax_state), + reward=reward, + discount=info["discount"], + step_type=jax.lax.select(done, StepType.LAST, StepType.MID), + ) + + new_state = GymnaxEnvState( + episode_reward=state.episode_reward + reward, + episode_length=state.episode_length + 1, + gymnax_state=gymnax_state, + ) + + return ts, new_state diff --git a/src/dopamax/environments/mountain_car.py b/src/dopamax/environments/mountain_car.py deleted file mode 100644 index cad0528..0000000 --- a/src/dopamax/environments/mountain_car.py +++ /dev/null @@ -1,201 +0,0 @@ -import math -from typing import Tuple, Optional - -import jax -import jax.numpy as jnp -import numpy as np -from chex import dataclass, PRNGKey -from dm_env import StepType - -from dopamax.environments.environment import EnvState, Environment, TimeStep -from dopamax.environments.utils import register -from dopamax.spaces import Space, Discrete, Box -from dopamax.typing import Observation, Action - -_NAME = "MountainCar" - - -@dataclass(frozen=True) -class MountainCarEnvState(EnvState): - episode_reward: float - episode_length: float - position: float - velocity: float - time: int - - def to_obs(self) -> Observation: - return jnp.array([self.position, self.velocity]) - - -@register(_NAME) -@dataclass(frozen=True) -class MountainCar(Environment): - """The MountainCar environment. - - References: - This implementation is adapted from: - https://github.com/RobertTLange/gymnax/blob/main/gymnax/environments/classic_control/mountain_car.py - """ - - min_position: float = -1.2 - max_position: float = 0.6 - max_speed: float = 0.07 - goal_position: float = 0.5 - goal_velocity: float = 0.0 - force: float = 0.001 - gravity: float = 0.0025 - - @property - def name(self) -> str: - return _NAME - - @property - def max_episode_length(self) -> int: - return 200 - - @property - def observation_space(self) -> Space: - low = jnp.array([self.min_position, -self.max_speed], dtype=jnp.float32) - high = jnp.array([self.max_position, self.max_speed], dtype=jnp.float32) - return Box(low=low, high=high, shape=(2,)) - - @property - def action_space(self) -> Space: - return Discrete(3) - - @property - def renderable(self) -> bool: - return True - - @property - def fps(self) -> Optional[int]: - return 30 - - @property - def render_shape(self) -> Optional[Tuple[int, int, int]]: - return 400, 600, 3 - - def _is_terminal(self, state: MountainCarEnvState) -> Tuple[bool, bool]: - done = jnp.logical_and(state.position >= self.goal_position, state.velocity >= self.goal_velocity) - truncate = state.time >= self.max_episode_length - done = jnp.logical_or(done, truncate) - - return done, truncate - - def reset(self, key: PRNGKey) -> Tuple[TimeStep, EnvState]: - position = jax.random.uniform(key, shape=(), minval=-0.6, maxval=-0.4) - - state = MountainCarEnvState( - episode_reward=0.0, - episode_length=0, - position=position, - velocity=0.0, - time=0, - ) - time_step = TimeStep.restart(state.to_obs()) - - return time_step, state - - def step(self, key: PRNGKey, state: MountainCarEnvState, action: Action) -> Tuple[TimeStep, MountainCarEnvState]: - prev_terminal, _ = self._is_terminal(state) - - velocity = state.velocity + ((action - 1) * self.force + jnp.cos(3 * state.position) * -self.gravity) - velocity = jnp.clip(velocity, -self.max_speed, self.max_speed) - position = state.position + velocity - position = jnp.clip(position, self.min_position, self.max_position) - velocity = velocity * (1 - (position == self.min_position) * (velocity < 0)) - - reward = -1.0 + prev_terminal - length = 1 - prev_terminal - - state = MountainCarEnvState( - episode_reward=state.episode_reward + reward, - episode_length=state.episode_length + length, - position=position, - velocity=velocity, - time=state.time + 1, - ) - done, truncate = self._is_terminal(state) - - time_step = TimeStep( - observation=state.to_obs(), - reward=reward, - discount=1.0 - jnp.float32(done & ~truncate), - step_type=jax.lax.select(done, StepType.LAST, StepType.MID), - ) - - return time_step, state - - def _height(self, xs): - return np.sin(3 * xs) * 0.45 + 0.55 - - def render(self, state: MountainCarEnvState) -> np.ndarray: - import pygame - from pygame import gfxdraw - - screen_width, screen_height = 600, 400 - - screen = pygame.Surface((screen_width, screen_height)) - - world_width = self.max_position - self.min_position - scale = screen_width / world_width - carwidth = 40 - carheight = 20 - - surf = pygame.Surface((screen_width, screen_height)) - surf.fill((255, 255, 255)) - - pos = state.position - - xs = np.linspace(self.min_position, self.max_position, 100) - ys = self._height(xs) - xys = list(zip((xs - self.min_position) * scale, ys * scale)) - - pygame.draw.aalines(surf, points=xys, closed=False, color=(0, 0, 0)) - - clearance = 10 - - l, r, t, b = -carwidth / 2, carwidth / 2, carheight, 0 - coords = [] - for c in [(l, b), (l, t), (r, t), (r, b)]: - c = pygame.math.Vector2(c).rotate_rad(math.cos(3 * pos)) - coords.append( - ( - c[0] + (pos - self.min_position) * scale, - c[1] + clearance + self._height(pos) * scale, - ) - ) - - gfxdraw.aapolygon(surf, coords, (0, 0, 0)) - gfxdraw.filled_polygon(surf, coords, (0, 0, 0)) - - for c in [(carwidth / 4, 0), (-carwidth / 4, 0)]: - c = pygame.math.Vector2(c).rotate_rad(math.cos(3 * pos)) - wheel = ( - int(c[0] + (pos - self.min_position) * scale), - int(c[1] + clearance + self._height(pos) * scale), - ) - - gfxdraw.aacircle(surf, wheel[0], wheel[1], int(carheight / 2.5), (128, 128, 128)) - gfxdraw.filled_circle(surf, wheel[0], wheel[1], int(carheight / 2.5), (128, 128, 128)) - - flagx = int((self.goal_position - self.min_position) * scale) - flagy1 = int(self._height(self.goal_position) * scale) - flagy2 = flagy1 + 50 - gfxdraw.vline(surf, flagx, flagy1, flagy2, (0, 0, 0)) - - gfxdraw.aapolygon( - surf, - [(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)], - (204, 204, 0), - ) - gfxdraw.filled_polygon( - surf, - [(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)], - (204, 204, 0), - ) - - surf = pygame.transform.flip(surf, False, True) - screen.blit(surf, (0, 0)) - - return np.transpose(np.array(pygame.surfarray.pixels3d(screen)), axes=(1, 0, 2)) diff --git a/src/dopamax/environments/mountain_car_continuous.py b/src/dopamax/environments/mountain_car_continuous.py deleted file mode 100644 index e8764ec..0000000 --- a/src/dopamax/environments/mountain_car_continuous.py +++ /dev/null @@ -1,201 +0,0 @@ -import math -from typing import Tuple, Optional - -import jax -import jax.numpy as jnp -import numpy as np -from chex import dataclass, PRNGKey -from dm_env import StepType - -from dopamax.environments.environment import Environment, TimeStep, EnvState -from dopamax.environments.mountain_car import MountainCarEnvState -from dopamax.environments.utils import register -from dopamax.spaces import Space, Box -from dopamax.typing import Action - -_NAME = "MountainCarContinuous" - -MountainCarContinuousEnvState = MountainCarEnvState - - -@register(_NAME) -@dataclass(frozen=True) -class MountainCarContinuous(Environment): - """The continuous version of the MountainCar environment. - - References: - This implementation is adapted from: - https://github.com/RobertTLange/gymnax/blob/main/gymnax/environments/classic_control/continuous_mountain_car.py - """ - - min_action: float = -1.0 - max_action: float = 1.0 - min_position: float = -1.2 - max_position: float = 0.6 - max_speed: float = 0.07 - goal_position: float = 0.45 - goal_velocity: float = 0.0 - power: float = 0.0015 - gravity: float = 0.0025 - - @property - def name(self) -> str: - return _NAME - - @property - def max_episode_length(self) -> int: - return 999 - - @property - def observation_space(self) -> Space: - low = jnp.array([self.min_position, -self.max_speed], dtype=jnp.float32) - high = jnp.array([self.max_position, self.max_speed], dtype=jnp.float32) - return Box(low=low, high=high, shape=(2,)) - - @property - def action_space(self) -> Space: - return Box(low=self.min_action, high=self.max_action, shape=(1,)) - - @property - def renderable(self) -> bool: - return True - - @property - def fps(self) -> Optional[int]: - return 30 - - @property - def render_shape(self) -> Optional[Tuple[int, int, int]]: - return 400, 600, 3 - - def _is_terminal(self, state: MountainCarContinuousEnvState) -> Tuple[bool, bool]: - done = jnp.logical_and(state.position >= self.goal_position, state.velocity >= self.goal_velocity) - truncate = state.time >= self.max_episode_length - done = jnp.logical_or(done, truncate) - - return done, truncate - - def reset(self, key: PRNGKey) -> Tuple[TimeStep, EnvState]: - position = jax.random.uniform(key, shape=(), minval=-0.6, maxval=-0.4) - - state = MountainCarEnvState( - episode_reward=0.0, - episode_length=0, - position=position, - velocity=0.0, - time=0, - ) - time_step = TimeStep.restart(state.to_obs()) - - return time_step, state - - def step( - self, key: PRNGKey, state: MountainCarContinuousEnvState, action: Action - ) -> Tuple[TimeStep, MountainCarContinuousEnvState]: - prev_terminal, _ = self._is_terminal(state) - - action = jnp.squeeze(action, 0) - - force = jnp.clip(action, self.min_action, self.max_action) - velocity = state.velocity + (force * self.power - jnp.cos(3 * state.position) * self.gravity) - velocity = jnp.clip(velocity, -self.max_speed, self.max_speed) - position = state.position + velocity - position = jnp.clip(position, self.min_position, self.max_position) - velocity = velocity * (1 - (position >= self.goal_position) * (velocity < 0)) - - reward = -0.1 * action**2 + 100 * ((position >= self.goal_position) * (velocity >= self.goal_velocity)) - reward *= ~prev_terminal - - length = 1 - prev_terminal - - state = MountainCarEnvState( - episode_reward=state.episode_reward + reward, - episode_length=state.episode_length + length, - position=position, - velocity=velocity, - time=state.time + 1, - ) - done, truncate = self._is_terminal(state) - - time_step = TimeStep( - observation=state.to_obs(), - reward=reward, - discount=1.0 - jnp.float32(done & ~truncate), - step_type=jax.lax.select(done, StepType.LAST, StepType.MID), - ) - - return time_step, state - - def _height(self, xs): - return np.sin(3 * xs) * 0.45 + 0.55 - - def render(self, state: MountainCarContinuousEnvState) -> np.ndarray: - import pygame - from pygame import gfxdraw - - screen_width, screen_height = 600, 400 - - screen = pygame.Surface((screen_width, screen_height)) - - world_width = self.max_position - self.min_position - scale = screen_width / world_width - carwidth = 40 - carheight = 20 - - surf = pygame.Surface((screen_width, screen_height)) - surf.fill((255, 255, 255)) - - pos = state.position - - xs = np.linspace(self.min_position, self.max_position, 100) - ys = self._height(xs) - xys = list(zip((xs - self.min_position) * scale, ys * scale)) - - pygame.draw.aalines(surf, points=xys, closed=False, color=(0, 0, 0)) - - clearance = 10 - - l, r, t, b = -carwidth / 2, carwidth / 2, carheight, 0 - coords = [] - for c in [(l, b), (l, t), (r, t), (r, b)]: - c = pygame.math.Vector2(c).rotate_rad(math.cos(3 * pos)) - coords.append( - ( - c[0] + (pos - self.min_position) * scale, - c[1] + clearance + self._height(pos) * scale, - ) - ) - - gfxdraw.aapolygon(surf, coords, (0, 0, 0)) - gfxdraw.filled_polygon(surf, coords, (0, 0, 0)) - - for c in [(carwidth / 4, 0), (-carwidth / 4, 0)]: - c = pygame.math.Vector2(c).rotate_rad(math.cos(3 * pos)) - wheel = ( - int(c[0] + (pos - self.min_position) * scale), - int(c[1] + clearance + self._height(pos) * scale), - ) - - gfxdraw.aacircle(surf, wheel[0], wheel[1], int(carheight / 2.5), (128, 128, 128)) - gfxdraw.filled_circle(surf, wheel[0], wheel[1], int(carheight / 2.5), (128, 128, 128)) - - flagx = int((self.goal_position - self.min_position) * scale) - flagy1 = int(self._height(self.goal_position) * scale) - flagy2 = flagy1 + 50 - gfxdraw.vline(surf, flagx, flagy1, flagy2, (0, 0, 0)) - - gfxdraw.aapolygon( - surf, - [(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)], - (204, 204, 0), - ) - gfxdraw.filled_polygon( - surf, - [(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)], - (204, 204, 0), - ) - - surf = pygame.transform.flip(surf, False, True) - screen.blit(surf, (0, 0)) - - return np.transpose(np.array(pygame.surfarray.pixels3d(screen)), axes=(1, 0, 2)) diff --git a/src/dopamax/environments/utils.py b/src/dopamax/environments/utils.py index efe340c..6da806e 100644 --- a/src/dopamax/environments/utils.py +++ b/src/dopamax/environments/utils.py @@ -1,4 +1,5 @@ from dopamax.environments.environment import Environment +from dopamax.environments.gymnax import GymnaxEnvironment _registry = {} @@ -23,5 +24,14 @@ def make_env(env_name: str, **kwargs) -> Environment: Returns: The environment. """ + if env_name.startswith("gymnax:"): + try: + import gymnax + except ImportError: + raise ImportError("Unable to import gymnax. Please install gymnax (e.g. via 'pip install gymnax').") + + env, env_params = gymnax.make(env_name[7:], **kwargs) + return GymnaxEnvironment(env=env, env_params=env_params) + env_cls = _registry[env_name] return env_cls(**kwargs) diff --git a/tests/environments/__init__.py b/tests/environments/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/environments/test_cartpole.py b/tests/environments/test_cartpole.py deleted file mode 100644 index 2c4ddfc..0000000 --- a/tests/environments/test_cartpole.py +++ /dev/null @@ -1,53 +0,0 @@ -import chex -import gymnasium as gym -import jax -import numpy as np -import pytest -from dm_env import StepType - -from dopamax.environments.cartpole import CartPole - - -def test_cartpole(): - key = jax.random.PRNGKey(0) - - jax_env = CartPole() - gym_env = gym.make("CartPole-v1", render_mode="rgb_array") - - time_step, state = jax_env.reset(key) - chex.assert_trees_all_equal(time_step.observation, state.to_obs()) - - gym_env.reset() - gym_env.unwrapped.state = (state.x, state.x_dot, state.theta, state.theta_dot) - - assert gym_env.unwrapped.state[0] == pytest.approx(state.x, rel=0.001) - assert gym_env.unwrapped.state[1] == pytest.approx(state.x_dot, rel=0.001) - assert gym_env.unwrapped.state[2] == pytest.approx(state.theta, rel=0.001) - assert gym_env.unwrapped.state[3] == pytest.approx(state.theta_dot, rel=0.001) - - for _ in range(500): - gym_render = gym_env.render() - jax_render = jax_env.render(state) - - chex.assert_trees_all_equal(gym_render, jax_render) - chex.assert_shape((gym_render, jax_render), jax_env.render_shape) - - action = jax_env.action_space.sample(key) - time_step, state = jax_env.step(key, state, action) - - gym_obs, gym_reward, gym_terminated, gym_truncated, _ = gym_env.step(np.asarray(action)) - - assert gym_reward == time_step.reward - assert (gym_terminated or gym_truncated) == bool(time_step.step_type == StepType.LAST) - assert gym_env.unwrapped.state[0] == pytest.approx(state.x, rel=0.001) - assert gym_env.unwrapped.state[1] == pytest.approx(state.x_dot, rel=0.001) - assert gym_env.unwrapped.state[2] == pytest.approx(state.theta, rel=0.001) - assert gym_env.unwrapped.state[3] == pytest.approx(state.theta_dot, rel=0.001) - chex.assert_trees_all_close(gym_obs, time_step.observation, rtol=0.001) - - key, _ = jax.random.split(key) - - if gym_terminated or gym_truncated: - time_step, state = jax_env.reset(key) - gym_env.reset() - gym_env.unwrapped.state = (state.x, state.x_dot, state.theta, state.theta_dot) diff --git a/tests/environments/test_mountain_car.py b/tests/environments/test_mountain_car.py deleted file mode 100644 index 068cbbc..0000000 --- a/tests/environments/test_mountain_car.py +++ /dev/null @@ -1,49 +0,0 @@ -import chex -import gymnasium as gym -import jax -import numpy as np -import pytest -from dm_env import StepType - -from dopamax.environments.mountain_car import MountainCar - - -def test_mountain_car(): - key = jax.random.PRNGKey(0) - - jax_env = MountainCar() - gym_env = gym.make("MountainCar-v0", render_mode="rgb_array") - - time_step, state = jax_env.reset(key) - chex.assert_trees_all_equal(time_step.observation, state.to_obs()) - - gym_env.reset() - gym_env.unwrapped.state = (state.position, state.velocity) - - assert gym_env.unwrapped.state[0] == pytest.approx(state.position, rel=0.001) - assert gym_env.unwrapped.state[1] == pytest.approx(state.velocity, rel=0.001) - - for _ in range(300): - gym_render = gym_env.render() - jax_render = jax_env.render(state) - - chex.assert_trees_all_equal(gym_render, jax_render) - chex.assert_shape((gym_render, jax_render), jax_env.render_shape) - - action = jax_env.action_space.sample(key) - time_step, state = jax_env.step(key, state, action) - - gym_obs, gym_reward, gym_terminated, gym_truncated, _ = gym_env.step(np.asarray(action)) - - assert gym_reward == time_step.reward - assert (gym_terminated or gym_truncated) == bool(time_step.step_type == StepType.LAST) - assert gym_env.unwrapped.state[0] == pytest.approx(state.position, rel=0.001) - assert gym_env.unwrapped.state[1] == pytest.approx(state.velocity, rel=0.001) - chex.assert_trees_all_close(gym_obs, time_step.observation, rtol=0.001) - - key, _ = jax.random.split(key) - - if gym_terminated or gym_truncated: - time_step, state = jax_env.reset(key) - gym_env.reset() - gym_env.unwrapped.state = (state.position, state.velocity) diff --git a/tests/environments/test_mountain_car_continuous.py b/tests/environments/test_mountain_car_continuous.py deleted file mode 100644 index 4a3475b..0000000 --- a/tests/environments/test_mountain_car_continuous.py +++ /dev/null @@ -1,49 +0,0 @@ -import chex -import gymnasium as gym -import jax -import numpy as np -import pytest -from dm_env import StepType - -from dopamax.environments.mountain_car_continuous import MountainCarContinuous - - -def test_mountain_car_continuous(): - key = jax.random.PRNGKey(0) - - jax_env = MountainCarContinuous() - gym_env = gym.make("MountainCarContinuous-v0", render_mode="rgb_array") - - time_step, state = jax_env.reset(key) - chex.assert_trees_all_equal(time_step.observation, state.to_obs()) - - gym_env.reset() - gym_env.unwrapped.state = (state.position, state.velocity) - - assert gym_env.unwrapped.state[0] == pytest.approx(state.position, rel=0.001) - assert gym_env.unwrapped.state[1] == pytest.approx(state.velocity, rel=0.001) - - for _ in range(300): - gym_render = gym_env.render() - jax_render = jax_env.render(state) - - chex.assert_trees_all_equal(gym_render, jax_render) - chex.assert_shape((gym_render, jax_render), jax_env.render_shape) - - action = jax_env.action_space.sample(key) - time_step, state = jax_env.step(key, state, action) - - gym_obs, gym_reward, gym_terminated, gym_truncated, _ = gym_env.step(np.array(action)) - - assert gym_reward == pytest.approx(time_step.reward, rel=0.0001) - assert (gym_terminated or gym_truncated) == bool(time_step.step_type == StepType.LAST) - assert gym_env.unwrapped.state[0] == pytest.approx(state.position, rel=0.001) - assert gym_env.unwrapped.state[1] == pytest.approx(state.velocity, rel=0.001) - chex.assert_trees_all_close(gym_obs, time_step.observation, rtol=0.001) - - key, _ = jax.random.split(key) - - if gym_terminated or gym_truncated: - time_step, state = jax_env.reset(key) - gym_env.reset() - gym_env.unwrapped.state = (state.position, state.velocity)