Skip to content

Commit

Permalink
Adopt pgx for two player environments
Browse files Browse the repository at this point in the history
  • Loading branch information
rystrauss committed Dec 27, 2023
1 parent 30251dc commit 3e7dc59
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 168 deletions.
2 changes: 1 addition & 1 deletion dopamax/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
from .cartpole import CartPole
from .mountain_car import MountainCar
from .mountain_car_continuous import MountainCarContinuous
from .two_player.connect_four import ConnectFour
from .pgx import ConnectFourEnvironment
from .utils import make_env
1 change: 1 addition & 0 deletions dopamax/environments/pgx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .connect_four import ConnectFourEnvironment
80 changes: 80 additions & 0 deletions dopamax/environments/pgx/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from abc import ABC
from typing import Tuple, Optional

import jax
import jax.numpy as jnp
import pgx
from chex import dataclass, PRNGKey
from dm_env import StepType

from dopamax.environments.environment import Environment, EnvState, TimeStep
from dopamax.spaces import Box, Space, Discrete
from dopamax.typing import Action


@dataclass(frozen=True)
class PGXEnvState(EnvState):
    """Immutable state threaded through a pgx-backed environment.

    Pairs running episode statistics with the state object of the wrapped
    pgx environment.
    """

    # Running sum of the rewards accumulated over the episode.
    episode_reward: float
    # Running count of the steps taken in the episode.
    episode_length: float
    # State of the underlying pgx environment.
    pgx_state: pgx.State


@dataclass(frozen=True)
class PGXEnvironment(Environment, ABC):
    """Abstract adapter that exposes a pgx environment as an `Environment`.

    Subclasses supply the concrete pgx environment through `_pgx_env`.
    """

    # The wrapped pgx environment that all calls are delegated to.
    _pgx_env: pgx.Env

    @property
    def action_space(self) -> Space:
        """Discrete space with one entry per action of the wrapped pgx environment."""
        return Discrete(self._pgx_env.num_actions)

    @property
    def observation_space(self) -> Space:
        """Unbounded float32 box shaped like the wrapped environment's observations."""
        return Box(
            low=-jnp.inf,
            high=jnp.inf,
            shape=self._pgx_env.observation_shape,
            dtype=jnp.float32,
        )

    def reset(self, key: PRNGKey) -> Tuple[TimeStep, PGXEnvState]:
        """Begin a new episode.

        Args:
            key: PRNG key forwarded to the pgx environment's `init`.

        Returns:
            A restart time step (float32 observation plus an info dict holding
            `current_player` and `legal_action_mask`) and a state whose episode
            statistics are zeroed.
        """
        initial = self._pgx_env.init(key)

        info = {
            "current_player": initial.current_player,
            "legal_action_mask": initial.legal_action_mask,
        }
        first_step = TimeStep.restart(initial.observation.astype(jnp.float32), info)

        fresh_state = PGXEnvState(
            episode_reward=0.0,
            episode_length=0,
            pgx_state=initial,
        )

        return first_step, fresh_state

    def step(self, key: PRNGKey, state: PGXEnvState, action: Action) -> Tuple[TimeStep, PGXEnvState]:
        """Advance the wrapped pgx environment by one action.

        Args:
            key: PRNG key forwarded to the pgx environment's `step`.
            state: The environment state prior to the action.
            action: The action to apply.

        Returns:
            The resulting time step and the updated environment state. The
            discount is 0 only when the episode finished without being
            truncated; otherwise it is 1.
        """
        before = state.pgx_state
        # If the incoming state had already finished, this step must not
        # lengthen the episode.
        already_over = jnp.squeeze(jnp.bool_(before.terminated | before.truncated))

        after = self._pgx_env.step(before, action, key)

        # The reward is read at the index of the post-step `current_player`.
        # NOTE(review): presumably this is the player to act next -- confirm
        # against pgx semantics.
        reward = jnp.squeeze(after.rewards[after.current_player])

        was_truncated = jnp.squeeze(jnp.bool_(after.truncated))
        finished = jnp.squeeze(jnp.bool_(after.terminated | after.truncated))

        time_step = TimeStep(
            observation=after.observation.astype(jnp.float32),
            reward=reward,
            discount=1.0 - jnp.float32(finished & ~was_truncated),
            step_type=jax.lax.select(finished, StepType.LAST, StepType.MID),
            info={
                "current_player": after.current_player,
                "legal_action_mask": after.legal_action_mask,
            },
        )

        next_state = PGXEnvState(
            episode_reward=state.episode_reward + reward,
            episode_length=state.episode_length + (1 - already_over),
            pgx_state=after,
        )

        return time_step, next_state
23 changes: 23 additions & 0 deletions dopamax/environments/pgx/connect_four.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pgx
from chex import dataclass

from dopamax.environments.utils import register
from dopamax.environments.pgx.base import PGXEnvironment

# Identifier passed to `register` and returned by the environment's `name` property.
_NAME = "ConnectFour"


@register(_NAME)
@dataclass(frozen=True)
class ConnectFourEnvironment(PGXEnvironment):
    """Connect Four environment backed by pgx's ``connect_four`` game."""

    def __init__(self):
        # Create the pgx game and pass it to the parent initialiser as `_pgx_env`.
        super(ConnectFourEnvironment, self).__init__(_pgx_env=pgx.make("connect_four"))

    @property
    def name(self) -> str:
        """The name this environment is registered under."""
        return _NAME

    @property
    def max_episode_length(self) -> int:
        """Upper bound on the number of steps in an episode."""
        # NOTE(review): presumably one move per cell of a 7 x 6 board -- confirm.
        return 7 * 6
Empty file.
17 changes: 0 additions & 17 deletions dopamax/environments/two_player/base.py

This file was deleted.

150 changes: 0 additions & 150 deletions dopamax/environments/two_player/connect_four.py

This file was deleted.

1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"distrax>=0.1.3",
"dm-env>=1.6",
"dm-haiku>=0.0.9",
"pgx>=2.0.1",
"einops>=0.6.0",
"ffmpeg>=1.4",
"imageio>=2.25.1",
Expand Down

0 comments on commit 3e7dc59

Please sign in to comment.