Update AlphaZero for pgx
rystrauss committed Dec 27, 2023
1 parent 3e7dc59 commit 98af28d
Showing 2 changed files with 26 additions and 19 deletions.
dopamax/agents/anakin/alphazero.py: 24 changes (11 additions & 13 deletions)
@@ -14,7 +14,7 @@
 from dopamax.agents.anakin.base import AnakinAgent, AnakinTrainState, AnakinTrainStateWithReplayBuffer
 from dopamax.agents.utils import register
 from dopamax.environments.environment import Environment, EnvState
-from dopamax.environments.two_player.base import TwoPlayerZeroSumEnvironment
+from dopamax.environments.pgx.base import PGXEnvironment
 from dopamax.networks import get_network_build_fn, get_actor_critic_model_fn
 from dopamax.rollouts import SampleBatch, rollout_truncated
 from dopamax.typing import Metrics, Observation, Action
@@ -54,7 +54,7 @@ class AlphaZero(AnakinAgent):
     def __init__(self, env: Environment, config: ConfigDict):
         super().__init__(env, config)

-        assert isinstance(env, TwoPlayerZeroSumEnvironment), "AlphaZero only supports `TwoPlayerZeroSumEnvironment`s."
+        assert isinstance(env, PGXEnvironment), "AlphaZero only supports `PGXEnvironment`s."

         network_build_fn = get_network_build_fn(self.config.network, **self.config.network_config)
         model_fn = get_actor_critic_model_fn(
@@ -80,13 +80,17 @@ def recurrent_fn(
         )

         pi, value = self._model.apply(params, model_key, next_time_step.observation["observation"])
-        prior_logits = pi.logits - 1e10 * next_time_step.observation["invalid_actions"]
+
+        prior_logits = pi.logits - jnp.max(pi.logits, axis=-1, keepdims=True)
+        prior_logits = jnp.where(
+            next_time_step.observation["invalid_actions"], jnp.finfo(prior_logits.dtype).min, prior_logits
+        )

         value = jnp.where(next_time_step.discount == 0.0, 0.0, value)

         recurrent_fn_output = mctx.RecurrentFnOutput(
             reward=next_time_step.reward,
-            discount=next_time_step.discount,
+            discount=-next_time_step.discount,
             prior_logits=prior_logits,
             value=value,
         )
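Note on the logit masking above: the old code subtracted 1e10 times the invalid-action mask from the logits, which leaves enormous finite values in the softmax and can overflow in half precision; the new code first shifts by the row max (a softmax-invariant change) and then writes the dtype's most negative value into the invalid slots. A minimal standalone sketch of the same pattern, with made-up example values that are not from the repository:

import jax
import jax.numpy as jnp

logits = jnp.array([2.0, -1.0, 0.5, 3.0])
invalid_actions = jnp.array([0.0, 1.0, 0.0, 1.0])  # 1.0 marks an illegal action

# Shift by the max for numerical stability; softmax output is unchanged by this.
masked = logits - jnp.max(logits, axis=-1, keepdims=True)

# Force illegal actions to the most negative representable value so that
# softmax assigns them (effectively) zero probability.
masked = jnp.where(invalid_actions.astype(bool), jnp.finfo(masked.dtype).min, masked)

probs = jax.nn.softmax(masked)  # probabilities for the illegal actions are ~0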
@@ -293,21 +297,15 @@ def update_scan_fn(carry, _):

            metrics = jax.tree_map(jnp.mean, metrics)

-           incremental_episodes, episode_metrics = self._get_episode_metrics(rollout_data)
-           incremental_episodes = self._maybe_all_reduce("psum", incremental_episodes)
-           incremental_timesteps = (
-               self.config.rollout_fragment_length * self.config.num_envs_per_device * self.config.num_devices
-           )
-
            next_train_state = train_state.update(
                new_key=next_train_state_key,
-               incremental_timesteps=incremental_timesteps,
-               incremental_episodes=incremental_episodes,
+               rollout_data=rollout_data,
                new_params=new_params,
                new_opt_state=new_opt_state,
                new_time_step=new_time_step,
                new_env_state=new_env_state,
                new_buffer_state=new_buffer_state,
+               maybe_all_reduce_fn=self._maybe_all_reduce,
            )

-           return next_train_state, (metrics, episode_metrics)
+           return next_train_state, metrics
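Note on the negated discount in recurrent_fn: in alternating two-player zero-sum games, a discount of -1 is the usual convention in mctx-based implementations, so backed-up values flip sign at every ply (a position good for the player to move is bad for the opponent), while a discount of 0 at terminal steps still disables bootstrapping. A hedged sketch of how a recurrent_fn with this convention plugs into mctx; the shapes, zero-valued network outputs, and the use of muzero_policy are illustrative only, not dopamax's actual call site:

import jax
import jax.numpy as jnp
import mctx

B, A = 4, 9  # illustrative batch size and action count

def recurrent_fn(params, rng_key, action, embedding):
    # A real implementation would step the environment and evaluate the model;
    # zeros are used here purely to show the plumbing.
    output = mctx.RecurrentFnOutput(
        reward=jnp.zeros((B,)),
        discount=-jnp.ones((B,)),  # flip perspective each ply (two-player zero-sum)
        prior_logits=jnp.zeros((B, A)),
        value=jnp.zeros((B,)),
    )
    return output, embedding

root = mctx.RootFnOutput(
    prior_logits=jnp.zeros((B, A)),
    value=jnp.zeros((B,)),
    embedding=jnp.zeros((B, 1)),
)

policy_output = mctx.muzero_policy(
    params=None,
    rng_key=jax.random.PRNGKey(0),
    root=root,
    recurrent_fn=recurrent_fn,
    num_simulations=16,
)
print(policy_output.action)  # one selected action per batch element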
dopamax/environments/pgx/base.py: 21 changes (15 additions & 6 deletions)
@@ -8,7 +8,7 @@
 from dm_env import StepType

 from dopamax.environments.environment import Environment, EnvState, TimeStep
-from dopamax.spaces import Box, Space, Discrete
+from dopamax.spaces import Box, Space, Discrete, Dict
 from dopamax.typing import Action

@@ -27,10 +27,12 @@ def reset(self, key: PRNGKey) -> Tuple[TimeStep, PGXEnvState]:
         pgx_state = self._pgx_env.init(key)

         time_step = TimeStep.restart(
-            pgx_state.observation.astype(jnp.float32),
+            {
+                "observation": pgx_state.observation.astype(jnp.float32),
+                "invalid_actions": (~pgx_state.legal_action_mask).astype(jnp.float32),
+            },
             {
                 "current_player": pgx_state.current_player,
-                "legal_action_mask": pgx_state.legal_action_mask,
             },
         )
         env_state = PGXEnvState(
@@ -43,7 +45,12 @@ def reset(self, key: PRNGKey) -> Tuple[TimeStep, PGXEnvState]:

     @property
     def observation_space(self) -> Space:
-        return Box(low=-jnp.inf, high=jnp.inf, shape=self._pgx_env.observation_shape, dtype=jnp.float32)
+        return Dict(
+            {
+                "observation": Box(low=-jnp.inf, high=jnp.inf, shape=self._pgx_env.observation_shape),
+                "invalid_actions": Box(low=0, high=1, shape=(self._pgx_env.num_actions,)),
+            }
+        )

     @property
     def action_space(self) -> Space:
@@ -67,13 +74,15 @@ def step(self, key: PRNGKey, state: PGXEnvState, action: Action) -> Tuple[TimeStep, PGXEnvState]:
         truncate = jnp.squeeze(jnp.bool_(new_pgx_state.truncated))

         time_step = TimeStep(
-            observation=new_pgx_state.observation.astype(jnp.float32),
+            observation={
+                "observation": new_pgx_state.observation.astype(jnp.float32),
+                "invalid_actions": (~new_pgx_state.legal_action_mask).astype(jnp.float32),
+            },
             reward=reward,
             discount=1.0 - jnp.float32(done & ~truncate),
             step_type=jax.lax.select(done, StepType.LAST, StepType.MID),
             info={
                 "current_player": new_pgx_state.current_player,
-                "legal_action_mask": new_pgx_state.legal_action_mask,
             },
         )

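For reference, pgx exposes each game through pgx.make(...) and a single State object whose fields (observation, legal_action_mask, current_player, rewards, terminated, truncated) are exactly what the wrapper above repackages into dopamax's TimeStep. A small standalone sketch of the raw pgx loop; the environment id and the random-action policy are arbitrary choices for illustration:

import jax
import jax.numpy as jnp
import pgx

env = pgx.make("tic_tac_toe")
key = jax.random.PRNGKey(0)

state = env.init(key)
while not bool(state.terminated):
    # Sample a random legal action, using the same invalid-action idea as above.
    key, subkey = jax.random.split(key)
    logits = jnp.where(state.legal_action_mask, 0.0, -jnp.inf)
    action = jax.random.categorical(subkey, logits)
    state = env.step(state, action)

print(state.rewards)  # terminal rewards, one entry per player (zero-sum)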