Skip to content

Commit

Permalink
AlphaZero compute_action tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
rystrauss committed Dec 28, 2023
1 parent 2f0dd96 commit 91bb7c2
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 7 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ wandb
artifacts
build
dopamax.egg-info
.ipynb_checkpoints
13 changes: 10 additions & 3 deletions dopamax/agents/anakin/alphazero.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Tuple, Dict
from typing import Tuple, Dict, Optional

import haiku as hk
import jax
Expand Down Expand Up @@ -178,7 +178,13 @@ def default_config() -> ConfigDict:
return config

def compute_action(
self, params: hk.Params, key: PRNGKey, observation: Observation, env_state: EnvState, deterministic: bool = True
self,
params: hk.Params,
key: PRNGKey,
observation: Observation,
env_state: EnvState,
deterministic: bool = True,
num_simulations: Optional[int] = None,
) -> Action:
model_key, search_key = jax.random.split(key)
observation = jax.tree_map(lambda x: jnp.expand_dims(x, axis=0), observation)
Expand All @@ -199,13 +205,14 @@ def compute_action(
rng_key=search_key,
root=root,
recurrent_fn=self._recurrent_fn,
num_simulations=self.config.num_simulations,
num_simulations=num_simulations or self.config.num_simulations,
invalid_actions=invalid_actions,
max_depth=self.config.max_depth,
dirichlet_fraction=self.config.root_exploration_fraction,
dirichlet_alpha=self.config.root_dirichlet_alpha,
pb_c_base=self.config.pb_c_base,
pb_c_init=self.config.pb_c_init,
temperature=0.0 if deterministic else 1.0,
)

policy_output, value = jax.tree_map(lambda x: jnp.squeeze(x, axis=0), (policy_output, value))
Expand Down
2 changes: 1 addition & 1 deletion dopamax/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
from .cartpole import CartPole
from .mountain_car import MountainCar
from .mountain_car_continuous import MountainCarContinuous
from .pgx import ConnectFourEnvironment
from .pgx import ConnectFour
from .utils import make_env
2 changes: 1 addition & 1 deletion dopamax/environments/pgx/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .connect_four import ConnectFourEnvironment
from .connect_four import ConnectFour
4 changes: 2 additions & 2 deletions dopamax/environments/pgx/connect_four.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

@register(_NAME)
@dataclass(frozen=True)
class ConnectFourEnvironment(PGXEnvironment):
class ConnectFour(PGXEnvironment):
def __init__(self):
pgx_env = pgx.make("connect_four")
super(ConnectFourEnvironment, self).__init__(_pgx_env=pgx_env)
super(ConnectFour, self).__init__(_pgx_env=pgx_env)

@property
def max_episode_length(self) -> int:
Expand Down

0 comments on commit 91bb7c2

Please sign in to comment.