From be09383b0b1902983b783161babf9af46d85a166 Mon Sep 17 00:00:00 2001 From: bkorpan <43765130+bkorpan@users.noreply.github.com> Date: Fri, 17 May 2024 01:51:44 -0700 Subject: [PATCH] [Chess] fix dummy observation (#1186) --- pgx/chess.py | 2 +- pgx/gardner_chess.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pgx/chess.py b/pgx/chess.py index 3c5f16ab6..be66db569 100644 --- a/pgx/chess.py +++ b/pgx/chess.py @@ -121,7 +121,7 @@ class State(core.State): terminated: Array = FALSE truncated: Array = FALSE legal_action_mask: Array = INIT_LEGAL_ACTION_MASK # 64 * 73 = 4672 - observation: Array = jnp.zeros((8, 8, 19), dtype=jnp.float32) + observation: Array = jnp.zeros((8, 8, 119), dtype=jnp.float32) _step_count: Array = jnp.int32(0) # --- Chess specific --- _turn: Array = jnp.int32(0) diff --git a/pgx/gardner_chess.py b/pgx/gardner_chess.py index b19d30e0b..62ca70933 100644 --- a/pgx/gardner_chess.py +++ b/pgx/gardner_chess.py @@ -84,7 +84,7 @@ class State(core.State): terminated: Array = FALSE truncated: Array = FALSE legal_action_mask: Array = INIT_LEGAL_ACTION_MASK - observation: Array = jnp.zeros((5, 5, 19), dtype=jnp.float32) + observation: Array = jnp.zeros((5, 5, 115), dtype=jnp.float32) _step_count: Array = jnp.int32(0) # --- Chess specific --- _turn: Array = jnp.int32(0)