Currently, we have been unable to reproduce the schedule-free AdamW results with JAX.
There seem to be differences between the Optax implementation of schedule-free AdamW and the PyTorch submission.
I can help debug any issues here. Do you have any code you can share? If there are issues with the Optax JAX implementation, I want to get them fixed ASAP.
There are many small differences between the behavior of the schedule-free JAX wrapper and the original AlgoPerf submission. Some differences I'm aware of:
- The bias correction in the submission also scales the weight decay at early steps. This makes training slightly faster on fastMRI but doesn't appear to affect any other workloads in my experiments.
- Weight decay is applied at y in the JAX version. In my experiments (testing in PyTorch), this decay-at-y version performs very similarly, if not slightly better. The experiments in the schedule-free paper use this decay-at-y version.
- The submission version uses an r=0.5 weighting. This seems to make little if any difference in practice (it's hard to tell due to noise).
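To make the three differences concrete, here is a minimal scalar sketch of a schedule-free AdamW step in plain Python (rather than JAX, to keep it dependency-free), following the y/z/x form from the schedule-free paper. `SFState`, `sf_adamw_step` and its `r`, `decay_at` and `bias_correct_step` toggles are hypothetical names, not Optax or PyTorch APIs. Two parts are assumptions: that the submission's bias correction enters by folding sqrt(1 - b2**t) into the step size, and the `"z"` decay point, which is only an illustrative alternative because the thread does not say where the submission applies decay.

```python
import math
from dataclasses import dataclass


@dataclass
class SFState:
    """Hypothetical optimizer state for ONE scalar parameter."""
    z: float                 # base (SGD-style) iterate
    x: float                 # weighted average of z; the point used for evaluation
    v: float = 0.0           # Adam second-moment EMA
    t: int = 0               # steps taken so far
    lr_max: float = 0.0      # largest learning rate seen (matters during warmup)
    weight_sum: float = 0.0  # running sum of averaging weights


def sf_adamw_step(state, grad_fn, lr=1e-2, b1=0.9, b2=0.999, eps=1e-8,
                  wd=0.0, warmup_steps=0, weight_lr_power=2.0,
                  r=0.0, decay_at="y", bias_correct_step=False):
    """One schedule-free AdamW step. Toggles (hypothetical names):

    r                 -- exponent in the averaging weight t**r (0.0 vs 0.5)
    decay_at          -- "y": decay at the gradient point; "z": illustrative alternative
    bias_correct_step -- True folds sqrt(1 - b2**t) into the step size, which
                         also shrinks the weight-decay term at early steps
    """
    if decay_at not in ("y", "z"):
        raise ValueError(f"decay_at must be 'y' or 'z', got {decay_at!r}")
    t = state.t + 1

    # Gradient is taken at y, an interpolation between z and the average x.
    y = (1.0 - b1) * state.z + b1 * state.x
    g = grad_fn(y)

    v = b2 * state.v + (1.0 - b2) * g * g
    bc2 = 1.0 - b2 ** t  # second-moment bias correction
    sched = min(1.0, t / warmup_steps) if warmup_steps > 0 else 1.0
    lr_t = lr * sched

    if bias_correct_step:
        # Correction lives in the step size, so it multiplies BOTH terms below.
        step = lr_t * math.sqrt(bc2)
        adam_dir = g / (math.sqrt(v) + eps)
    else:
        # Correction applied to v only; the decay term is left unscaled.
        step = lr_t
        adam_dir = g / (math.sqrt(v / bc2) + eps)

    decay_point = y if decay_at == "y" else state.z
    z = state.z - step * (adam_dir + wd * decay_point)

    # Averaging coefficient c_t = w_t / sum_i w_i, with w_t = t**r * lr_max**power.
    lr_max = max(state.lr_max, lr_t)
    w = (t ** r) * (lr_max ** weight_lr_power)
    weight_sum = state.weight_sum + w
    c = w / weight_sum
    x = (1.0 - c) * state.x + c * z

    return SFState(z=z, x=x, v=v, t=t, lr_max=lr_max, weight_sum=weight_sum)


if __name__ == "__main__":
    # Zero gradient isolates the weight-decay term on the first step.
    zero_grad = lambda y: 0.0
    for flag in (False, True):
        s = sf_adamw_step(SFState(z=1.0, x=1.0), zero_grad, lr=1e-2, wd=0.1,
                          bias_correct_step=flag)
        print(f"bias_correct_step={flag}: z after one step = {s.z:.8f}")
```

Run as a script, the demo decays a parameter of 1.0 by 0.1% (z = 0.999) with `bias_correct_step=False`, but only by about 0.003% with `bias_correct_step=True`, because sqrt(1 - 0.999) is roughly 0.032 at step 1. With decay switched off, the toggle changes the gradient step only through `eps`. Matching all three toggles between two implementations before comparing trajectories is one way to separate these intended differences from a real bug.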
So overall, I expect the JAX wrapper to give results as good as the submission on all problems (maybe slightly slower on fastMRI); if there is a difference, it is most likely due to a bug.