Skip to content

Commit

Permalink
EMA params bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
rystrauss committed Mar 9, 2022
1 parent 818e187 commit 3cef653
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion bax/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def _grad_step(
new_params,
)
new_ema_params = jmp.select_tree(
should_skip, new_ema_params, train_state.ema_params
should_skip, train_state.ema_params, new_ema_params
)
else:
new_ema_params = train_state.ema_params
Expand Down

0 comments on commit 3cef653

Please sign in to comment.