Skip to content

Commit

Permalink
[Example] Replace deprecated jax.tree_map function with jax.tree_util…
Browse files Browse the repository at this point in the history
….tree_map (#1188)
  • Loading branch information
KazukiOhta authored Jun 18, 2024
1 parent 51e9055 commit d485956
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions examples/alphazero/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,16 +336,16 @@ def body_fn(val):
samples = jax.tree_util.tree_map(lambda x: x.reshape((-1, *x.shape[3:])), samples)
rng_key, subkey = jax.random.split(rng_key)
ixs = jax.random.permutation(subkey, jnp.arange(samples.obs.shape[0]))
samples = jax.tree_map(lambda x: x[ixs], samples) # shuffle
samples = jax.tree_util.tree_map(lambda x: x[ixs], samples) # shuffle
num_updates = samples.obs.shape[0] // config.training_batch_size
minibatches = jax.tree_map(
minibatches = jax.tree_util.tree_map(
lambda x: x.reshape((num_updates, num_devices, -1) + x.shape[1:]), samples
)

# Training
policy_losses, value_losses = [], []
for i in range(num_updates):
minibatch: Sample = jax.tree_map(lambda x: x[i], minibatches)
minibatch: Sample = jax.tree_util.tree_map(lambda x: x[i], minibatches)
model, opt_state, policy_loss, value_loss = train(model, opt_state, minibatch)
policy_losses.append(policy_loss.mean().item())
value_losses.append(value_loss.mean().item())
Expand Down

0 comments on commit d485956

Please sign in to comment.