Skip to content

Commit

Permalink
Add auto wrapping of gymnax environments
Browse files Browse the repository at this point in the history
  • Loading branch information
rystrauss committed Dec 15, 2024
1 parent cc41541 commit a7a2b73
Show file tree
Hide file tree
Showing 12 changed files with 91 additions and 786 deletions.
4 changes: 2 additions & 2 deletions examples/ppo-cartpole/config.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
env_name: "CartPole"
env_name: "gymnax:CartPole-v1"
agent_name: "PPO"
train_iterations: 1500
# Comment out the seed if you want each run to be different.
seed: 1035704761
seed: 88476718
agent_config:
rollout_fragment_length: 128
minibatch_size: 32
Expand Down
4 changes: 3 additions & 1 deletion src/dopamax/_scripts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def train(config, offline):
with open(params_path, "wb") as f:
pickle.dump(params, f)

params_artifact = wandb.Artifact(config.env_name + "-" + config.agent_name + "-agent", type="agent")
params_artifact = wandb.Artifact(
config.env_name.replace(":", "-") + "-" + config.agent_name + "-agent", type="agent"
)
params_artifact.add_file(params_path)
run.log_artifact(params_artifact)

Expand Down
7 changes: 0 additions & 7 deletions src/dopamax/environments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1 @@
from .brax.ant import Ant
from .brax.half_cheetah import HalfCheetah
from .brax.inverted_pendulum import InvertedPendulum
from .cartpole import CartPole
from .mountain_car import MountainCar
from .mountain_car_continuous import MountainCarContinuous
from .pgx import ConnectFour, TicTacToe, Go9x9, Go19x19
from .utils import make_env
223 changes: 0 additions & 223 deletions src/dopamax/environments/cartpole.py

This file was deleted.

76 changes: 76 additions & 0 deletions src/dopamax/environments/gymnax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import Tuple

import gymnax
import jax
from chex import PRNGKey, dataclass
from dm_env import StepType
from gymnax.environments.spaces import Space
from dopamax import spaces
from dopamax.environments.environment import Environment, EnvState, TimeStep
from dopamax.typing import Action


def _convert_space(space: gymnax.environments.spaces.Space) -> spaces.Space:
    """Convert a gymnax space into the equivalent dopamax space.

    Args:
        space: The gymnax space to convert.

    Returns:
        The corresponding dopamax space.

    Raises:
        ValueError: If the gymnax space type is not supported.
    """
    if isinstance(space, gymnax.environments.spaces.Box):
        return spaces.Box(low=space.low, high=space.high, shape=space.shape, dtype=space.dtype)

    if isinstance(space, gymnax.environments.spaces.Discrete):
        return spaces.Discrete(space.n, dtype=space.dtype)

    if isinstance(space, gymnax.environments.spaces.Dict):
        # Bug fix: sub-spaces must be converted recursively. The original passed the
        # raw gymnax sub-spaces into the dopamax Dict, so nested Box/Discrete spaces
        # would remain gymnax objects instead of dopamax ones.
        return spaces.Dict(spaces={k: _convert_space(v) for k, v in space.spaces.items()})

    raise ValueError(f"Unknown space: {space}")


@dataclass(frozen=True)
class GymnaxEnvState(EnvState):
    """Environment state for a wrapped gymnax environment.

    Extends the base ``EnvState`` (which appears to track ``episode_length`` and
    ``episode_reward``, judging from how GymnaxEnvironment constructs it) with
    the wrapped environment's own internal state.
    """

    # Internal state object of the underlying gymnax environment, threaded
    # through reset/step unchanged.
    gymnax_state: gymnax.environments.EnvState


@dataclass(frozen=True)
class GymnaxEnvironment(Environment):
    """Wraps a gymnax environment to conform to the dopamax ``Environment`` interface.

    Attributes:
        env: The wrapped gymnax environment.
        env_params: The gymnax parameters used for every reset/step call.
    """

    env: gymnax.environments.environment.Environment
    env_params: gymnax.environments.environment.EnvParams

    @property
    def name(self) -> str:
        """Name of the wrapped gymnax environment."""
        return self.env.name

    @property
    def max_episode_length(self) -> int:
        """Maximum episode length, taken from the gymnax environment params."""
        return self.env_params.max_steps_in_episode

    @property
    def observation_space(self) -> Space:
        """The dopamax observation space converted from the gymnax space."""
        return _convert_space(self.env.observation_space(self.env_params))

    @property
    def action_space(self) -> Space:
        """The dopamax action space converted from the gymnax space."""
        return _convert_space(self.env.action_space(self.env_params))

    def reset(self, key: PRNGKey) -> Tuple[TimeStep, GymnaxEnvState]:
        """Reset the wrapped environment to start a new episode.

        Args:
            key: PRNG key consumed by the gymnax reset.

        Returns:
            A tuple of the initial TimeStep and the initial environment state.
        """
        obs, gymnax_state = self.env.reset(key, self.env_params)
        state = GymnaxEnvState(episode_length=0, episode_reward=0.0, gymnax_state=gymnax_state)
        # Fix: use the observation returned by reset directly instead of
        # re-deriving it with env.get_obs — the original discarded `obs` and
        # recomputed it, which is redundant (and get_obs requires params in
        # newer gymnax versions).
        ts = TimeStep.restart(obs)
        return ts, state

    def step(self, key: PRNGKey, state: GymnaxEnvState, action: Action) -> Tuple[TimeStep, GymnaxEnvState]:
        """Advance the wrapped environment by one step.

        Args:
            key: PRNG key consumed by the gymnax step.
            state: The current environment state.
            action: The action to apply.

        Returns:
            A tuple of the resulting TimeStep and the new environment state.
        """
        obs, gymnax_state, reward, done, info = self.env.step(key, state.gymnax_state, action, self.env_params)

        # Ensure `done` is a JAX boolean so it can drive jax.lax.select below.
        done = jax.numpy.bool_(done)

        # Fix: use the observation returned by step. gymnax's step wrapper
        # handles auto-reset, so the returned `obs` is authoritative; the
        # original discarded it and re-called get_obs on the new state.
        ts = TimeStep(
            observation=obs,
            reward=reward,
            discount=info["discount"],
            step_type=jax.lax.select(done, StepType.LAST, StepType.MID),
        )

        new_state = GymnaxEnvState(
            episode_reward=state.episode_reward + reward,
            episode_length=state.episode_length + 1,
            gymnax_state=gymnax_state,
        )

        return ts, new_state
Loading

0 comments on commit a7a2b73

Please sign in to comment.