diff --git a/pgx/_src/games/chess.py b/pgx/_src/games/chess.py index 3fd8eb13e..81c5d2256 100644 --- a/pgx/_src/games/chess.py +++ b/pgx/_src/games/chess.py @@ -351,7 +351,7 @@ def legal_labels(label): a1 = jax.vmap(legal_normal_moves)(possible_piece_positions).flatten() a2 = legal_en_passants() actions = jnp.hstack((a1, a2)) # include -1 - ixs = jnp.nonzero(actions >= 0, size=128, fill_value=0)[0] # 128: random large number + ixs = jnp.nonzero(actions >= 0, size=200, fill_value=0)[0] # 200: random large number actions = actions[ixs] actions = jnp.where(jax.vmap(is_not_checked)(actions), actions, -1) mask = jnp.zeros(64 * 73 + 1, dtype=jnp.bool_) # +1 for sentinel