diff --git a/pgx/_src/games/chess.py b/pgx/_src/games/chess.py index 50e92c37b..3fd8eb13e 100644 --- a/pgx/_src/games/chess.py +++ b/pgx/_src/games/chess.py @@ -351,7 +351,7 @@ def legal_labels(label): a1 = jax.vmap(legal_normal_moves)(possible_piece_positions).flatten() a2 = legal_en_passants() actions = jnp.hstack((a1, a2)) # include -1 - ixs = jnp.nonzero(actions >= 0, size=16 * 19, fill_value=0)[0] # 16 * 27 = 432 -> 16 * 19 = 304 + ixs = jnp.nonzero(actions >= 0, size=128, fill_value=0)[0] # 128: random large number actions = actions[ixs] actions = jnp.where(jax.vmap(is_not_checked)(actions), actions, -1) mask = jnp.zeros(64 * 73 + 1, dtype=jnp.bool_) # +1 for sentinel