diff --git a/bax/trainer.py b/bax/trainer.py index bb9f24d..eef55b0 100644 --- a/bax/trainer.py +++ b/bax/trainer.py @@ -226,7 +226,7 @@ def _grad_step( new_params, ) new_ema_params = jmp.select_tree( - should_skip, new_ema_params, train_state.ema_params + should_skip, train_state.ema_params, new_ema_params ) else: new_ema_params = train_state.ema_params