I'm working with the optimizers module in `jax.example_libraries`. I try two different learning rates (0.01 and 0.005), and both updates use the same gradient (`grad_1`). I then calculate the sum of absolute differences of all parameters before and after the update. However, the absolute difference when using a learning rate of 0.01 is not twice the absolute difference when using a learning rate of 0.005.
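Roughly what I'm doing (a minimal sketch, not my actual code; the parameter values, the gradient `grad_1`, and the shapes are made up for illustration):

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

# Made-up parameters and a fixed gradient (stand-ins for my real model).
params = {'w': jnp.full((1000,), 0.3), 'b': jnp.full((10,), 0.1)}
grad_1 = {'w': jnp.full((1000,), 0.7), 'b': jnp.full((10,), 0.2)}

def total_abs_change(lr):
    # One plain SGD step with the given learning rate, same gradient each time.
    opt_init, opt_update, get_params = optimizers.sgd(lr)
    opt_state = opt_init(params)
    new_params = get_params(opt_update(0, grad_1, opt_state))
    # Sum of absolute differences over all parameters.
    return sum(jnp.sum(jnp.abs(new - old))
               for old, new in zip(jax.tree_util.tree_leaves(params),
                                   jax.tree_util.tree_leaves(new_params)))

d_big = total_abs_change(0.01)
d_small = total_abs_change(0.005)
print(d_big, d_small, d_big / d_small)  # I expected the ratio to be exactly 2.0
```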
Thank you so much for your help. Appreciate it!
-
Thanks for the question! Could you share a fully runnable repro?
As you can see from the code, the sgd update isn't doing much.
It's hard to guess without having a repro, but it could just be standard precision loss from floating-point accumulation. If so, you may be able to reorder the computations to be more stable, but an easy way to check would be to set `jax.config.update('jax_enable_x64', True)` at the top of your file and see if the answers change.
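Something like this (just a sketch of where the flag goes and of the kind of rounding involved; the values 0.3 and 0.7 are arbitrary stand-ins for a parameter and a gradient entry):

```python
import jax
# This must run before any arrays are created; comment it out to see the
# default float32 behaviour again.
jax.config.update('jax_enable_x64', True)
import jax.numpy as jnp

p = jnp.asarray(0.3)   # one parameter value
g = jnp.asarray(0.7)   # the corresponding gradient entry

# |new - old| for each learning rate. Each subtraction rounds to the working
# precision, so the recovered update is not exactly lr * g.
d1 = jnp.abs((p - 0.01 * g) - p)
d2 = jnp.abs((p - 0.005 * g) - p)
print(d1 / d2)  # often not exactly 2.0 in float32; much closer with x64 enabled
```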