From 3cef65374f0742c0fdd396dfa10da536464df60f Mon Sep 17 00:00:00 2001 From: Ryan Strauss Date: Wed, 9 Mar 2022 07:49:04 -0500 Subject: [PATCH] EMA params bug fix --- bax/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bax/trainer.py b/bax/trainer.py index bb9f24d..eef55b0 100644 --- a/bax/trainer.py +++ b/bax/trainer.py @@ -226,7 +226,7 @@ def _grad_step( new_params, ) new_ema_params = jmp.select_tree( - should_skip, new_ema_params, train_state.ema_params + should_skip, train_state.ema_params, new_ema_params ) else: new_ema_params = train_state.ema_params