diff --git a/pgx/shogi.py b/pgx/shogi.py index 31b953ce8..da7039f21 100644 --- a/pgx/shogi.py +++ b/pgx/shogi.py @@ -238,10 +238,9 @@ def _step(state: State, action: Array): state = jax.lax.cond(a.is_drop, _step_drop, _step_move, *(state, a)) # flip state state = _flip(state) - x = state._x._replace(turn=(state._x.turn + 1) % 2) state = state.replace( # type: ignore current_player=(state.current_player + 1) % 2, - _x=x, + _x=state._x._replace(turn=(state._x.turn + 1) % 2), ) legal_action_mask = _legal_action_mask(state) terminated = ~legal_action_mask.any()