From 91bb7c2d4f4903a9b60ac484d1838944dd7f2e20 Mon Sep 17 00:00:00 2001 From: Ryan Strauss Date: Wed, 27 Dec 2023 20:49:02 -0500 Subject: [PATCH] AlphaZero compute_action tweaks --- .gitignore | 1 + dopamax/agents/anakin/alphazero.py | 13 ++++++++++--- dopamax/environments/__init__.py | 2 +- dopamax/environments/pgx/__init__.py | 2 +- dopamax/environments/pgx/connect_four.py | 4 ++-- 5 files changed, 15 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index e7fee4b..1f38ce6 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ wandb artifacts build dopamax.egg-info +.ipynb_checkpoints \ No newline at end of file diff --git a/dopamax/agents/anakin/alphazero.py b/dopamax/agents/anakin/alphazero.py index f882328..0c205cd 100644 --- a/dopamax/agents/anakin/alphazero.py +++ b/dopamax/agents/anakin/alphazero.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Tuple, Dict +from typing import Tuple, Dict, Optional import haiku as hk import jax @@ -178,7 +178,13 @@ def default_config() -> ConfigDict: return config def compute_action( - self, params: hk.Params, key: PRNGKey, observation: Observation, env_state: EnvState, deterministic: bool = True + self, + params: hk.Params, + key: PRNGKey, + observation: Observation, + env_state: EnvState, + deterministic: bool = True, + num_simulations: Optional[int] = None, ) -> Action: model_key, search_key = jax.random.split(key) observation = jax.tree_map(lambda x: jnp.expand_dims(x, axis=0), observation) @@ -199,13 +205,14 @@ def compute_action( rng_key=search_key, root=root, recurrent_fn=self._recurrent_fn, - num_simulations=self.config.num_simulations, + num_simulations=num_simulations or self.config.num_simulations, invalid_actions=invalid_actions, max_depth=self.config.max_depth, dirichlet_fraction=self.config.root_exploration_fraction, dirichlet_alpha=self.config.root_dirichlet_alpha, pb_c_base=self.config.pb_c_base, pb_c_init=self.config.pb_c_init, + temperature=0.0 if deterministic else 1.0, 
) policy_output, value = jax.tree_map(lambda x: jnp.squeeze(x, axis=0), (policy_output, value)) diff --git a/dopamax/environments/__init__.py b/dopamax/environments/__init__.py index a5a3864..204d74b 100644 --- a/dopamax/environments/__init__.py +++ b/dopamax/environments/__init__.py @@ -4,5 +4,5 @@ from .cartpole import CartPole from .mountain_car import MountainCar from .mountain_car_continuous import MountainCarContinuous -from .pgx import ConnectFourEnvironment +from .pgx import ConnectFour from .utils import make_env diff --git a/dopamax/environments/pgx/__init__.py b/dopamax/environments/pgx/__init__.py index 8b1a509..3fe1a75 100644 --- a/dopamax/environments/pgx/__init__.py +++ b/dopamax/environments/pgx/__init__.py @@ -1 +1 @@ -from .connect_four import ConnectFourEnvironment +from .connect_four import ConnectFour diff --git a/dopamax/environments/pgx/connect_four.py b/dopamax/environments/pgx/connect_four.py index 0c91b09..7eb0175 100644 --- a/dopamax/environments/pgx/connect_four.py +++ b/dopamax/environments/pgx/connect_four.py @@ -9,10 +9,10 @@ @register(_NAME) @dataclass(frozen=True) -class ConnectFourEnvironment(PGXEnvironment): +class ConnectFour(PGXEnvironment): def __init__(self): pgx_env = pgx.make("connect_four") - super(ConnectFourEnvironment, self).__init__(_pgx_env=pgx_env) + super(ConnectFour, self).__init__(_pgx_env=pgx_env) @property def max_episode_length(self) -> int: