diff --git a/pgx/chess.py b/pgx/chess.py index 3c5f16ab6..be66db569 100644 --- a/pgx/chess.py +++ b/pgx/chess.py @@ -121,7 +121,7 @@ class State(core.State): terminated: Array = FALSE truncated: Array = FALSE legal_action_mask: Array = INIT_LEGAL_ACTION_MASK # 64 * 73 = 4672 - observation: Array = jnp.zeros((8, 8, 19), dtype=jnp.float32) + observation: Array = jnp.zeros((8, 8, 119), dtype=jnp.float32) _step_count: Array = jnp.int32(0) # --- Chess specific --- _turn: Array = jnp.int32(0) diff --git a/pgx/gardner_chess.py b/pgx/gardner_chess.py index b19d30e0b..62ca70933 100644 --- a/pgx/gardner_chess.py +++ b/pgx/gardner_chess.py @@ -84,7 +84,7 @@ class State(core.State): terminated: Array = FALSE truncated: Array = FALSE legal_action_mask: Array = INIT_LEGAL_ACTION_MASK - observation: Array = jnp.zeros((5, 5, 19), dtype=jnp.float32) + observation: Array = jnp.zeros((5, 5, 115), dtype=jnp.float32) _step_count: Array = jnp.int32(0) # --- Chess specific --- _turn: Array = jnp.int32(0)