diff --git a/dopamax/agents/anakin/alphazero.py b/dopamax/agents/anakin/alphazero.py
index 89cad63..f882328 100644
--- a/dopamax/agents/anakin/alphazero.py
+++ b/dopamax/agents/anakin/alphazero.py
@@ -14,7 +14,7 @@
 from dopamax.agents.anakin.base import AnakinAgent, AnakinTrainState, AnakinTrainStateWithReplayBuffer
 from dopamax.agents.utils import register
 from dopamax.environments.environment import Environment, EnvState
-from dopamax.environments.two_player.base import TwoPlayerZeroSumEnvironment
+from dopamax.environments.pgx.base import PGXEnvironment
 from dopamax.networks import get_network_build_fn, get_actor_critic_model_fn
 from dopamax.rollouts import SampleBatch, rollout_truncated
 from dopamax.typing import Metrics, Observation, Action
@@ -54,7 +54,7 @@ class AlphaZero(AnakinAgent):
     def __init__(self, env: Environment, config: ConfigDict):
         super().__init__(env, config)
 
-        assert isinstance(env, TwoPlayerZeroSumEnvironment), "AlphaZero only supports `TwoPlayerZeroSumEnvironment`s."
+        assert isinstance(env, PGXEnvironment), "AlphaZero only supports `PGXEnvironment`s."
 
         network_build_fn = get_network_build_fn(self.config.network, **self.config.network_config)
         model_fn = get_actor_critic_model_fn(
@@ -80,13 +80,17 @@ def recurrent_fn(
         )
 
         pi, value = self._model.apply(params, model_key, next_time_step.observation["observation"])
-        prior_logits = pi.logits - 1e10 * next_time_step.observation["invalid_actions"]
+
+        prior_logits = pi.logits - jnp.max(pi.logits, axis=-1, keepdims=True)
+        prior_logits = jnp.where(
+            next_time_step.observation["invalid_actions"], jnp.finfo(prior_logits.dtype).min, prior_logits
+        )
         value = jnp.where(next_time_step.discount == 0.0, 0.0, value)
 
         recurrent_fn_output = mctx.RecurrentFnOutput(
             reward=next_time_step.reward,
-            discount=next_time_step.discount,
+            discount=-next_time_step.discount,
             prior_logits=prior_logits,
             value=value,
         )
@@ -293,21 +297,15 @@ def update_scan_fn(carry, _):
             metrics = jax.tree_map(jnp.mean, metrics)
 
-            incremental_episodes, episode_metrics = self._get_episode_metrics(rollout_data)
-            incremental_episodes = self._maybe_all_reduce("psum", incremental_episodes)
-            incremental_timesteps = (
-                self.config.rollout_fragment_length * self.config.num_envs_per_device * self.config.num_devices
-            )
-
             next_train_state = train_state.update(
                 new_key=next_train_state_key,
-                incremental_timesteps=incremental_timesteps,
-                incremental_episodes=incremental_episodes,
+                rollout_data=rollout_data,
                 new_params=new_params,
                 new_opt_state=new_opt_state,
                 new_time_step=new_time_step,
                 new_env_state=new_env_state,
                 new_buffer_state=new_buffer_state,
+                maybe_all_reduce_fn=self._maybe_all_reduce,
             )
 
-            return next_train_state, (metrics, episode_metrics)
+            return next_train_state, metrics
diff --git a/dopamax/environments/pgx/base.py b/dopamax/environments/pgx/base.py
index d44ff93..cbe9852 100644
--- a/dopamax/environments/pgx/base.py
+++ b/dopamax/environments/pgx/base.py
@@ -8,7 +8,7 @@
 from dm_env import StepType
 
 from dopamax.environments.environment import Environment, EnvState, TimeStep
-from dopamax.spaces import Box, Space, Discrete
+from dopamax.spaces import Box, Space, Discrete, Dict
 from dopamax.typing import Action
 
@@ -27,10 +27,12 @@ def reset(self, key: PRNGKey) -> Tuple[TimeStep, PGXEnvState]:
         pgx_state = self._pgx_env.init(key)
 
         time_step = TimeStep.restart(
-            pgx_state.observation.astype(jnp.float32),
+            {
+                "observation": pgx_state.observation.astype(jnp.float32),
+                "invalid_actions": (~pgx_state.legal_action_mask).astype(jnp.float32),
+            },
             {
                 "current_player": pgx_state.current_player,
-                "legal_action_mask": pgx_state.legal_action_mask,
             },
         )
         env_state = PGXEnvState(
@@ -43,7 +45,12 @@ def reset(self, key: PRNGKey) -> Tuple[TimeStep, PGXEnvState]:
 
     @property
     def observation_space(self) -> Space:
-        return Box(low=-jnp.inf, high=jnp.inf, shape=self._pgx_env.observation_shape, dtype=jnp.float32)
+        return Dict(
+            {
+                "observation": Box(low=-jnp.inf, high=jnp.inf, shape=self._pgx_env.observation_shape),
+                "invalid_actions": Box(low=0, high=1, shape=(self._pgx_env.num_actions,)),
+            }
+        )
 
     @property
     def action_space(self) -> Space:
@@ -67,13 +74,15 @@ def step(self, key: PRNGKey, state: PGXEnvState, action: Action) -> Tuple[TimeSt
         truncate = jnp.squeeze(jnp.bool_(new_pgx_state.truncated))
 
         time_step = TimeStep(
-            observation=new_pgx_state.observation.astype(jnp.float32),
+            observation={
+                "observation": new_pgx_state.observation.astype(jnp.float32),
+                "invalid_actions": (~new_pgx_state.legal_action_mask).astype(jnp.float32),
+            },
             reward=reward,
             discount=1.0 - jnp.float32(done & ~truncate),
             step_type=jax.lax.select(done, StepType.LAST, StepType.MID),
             info={
                 "current_player": new_pgx_state.current_player,
-                "legal_action_mask": new_pgx_state.legal_action_mask,
             },
         )
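Note on the recurrent_fn changes above: illegal moves are now masked by shifting the logits for numerical stability and then setting invalid entries to the dtype minimum, and the discount handed to mctx is negated so that value estimates flip sign between the two alternating players during backup. The sketch below shows how the same masking would be applied at the search root when invoking mctx; it is a minimal illustration only, and the function act_with_search together with the model, recurrent_fn, time_step, env_state, and num_simulations arguments are assumed placeholders, not actual dopamax internals.

import jax.numpy as jnp
import mctx


def act_with_search(params, search_key, model, recurrent_fn, time_step, env_state, num_simulations):
    # mctx expects a leading batch dimension on every array used below.
    pi, value = model.apply(params, search_key, time_step.observation["observation"])

    # Same masking as in recurrent_fn: shift logits for stability, then push
    # invalid actions down to the smallest representable value.
    root_logits = pi.logits - jnp.max(pi.logits, axis=-1, keepdims=True)
    root_logits = jnp.where(
        time_step.observation["invalid_actions"], jnp.finfo(root_logits.dtype).min, root_logits
    )

    root = mctx.RootFnOutput(prior_logits=root_logits, value=value, embedding=env_state)

    # The recurrent_fn returns discount=-next_time_step.discount, so child
    # values are backed up with flipped sign for the opposing player.
    policy_output = mctx.muzero_policy(
        params=params,
        rng_key=search_key,
        root=root,
        recurrent_fn=recurrent_fn,
        num_simulations=num_simulations,
        invalid_actions=time_step.observation["invalid_actions"],
    )

    return policy_output.action, policy_output.action_weights

Passing invalid_actions to the policy as well keeps root action selection consistent with the masked priors.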