Difference in reverse and forward mode autodiff JAX #24895
Unanswered
SanderSchomaker
asked this question in
Q&A
Replies: 1 comment
-
Hi - without example code it's hard to give a detailed answer, but depending on the computation it's not unreasonable that two different ways of computing a value might differ by 1 part in 10^4. Here's a NumPy example of computing the same sum in two different orders, which gives a difference of about one part in 10^4:

import numpy as np
rng = np.random.default_rng(0)
x = rng.normal(size=1000000).astype('float32')
# sort by magnitude so that the two summation orders below differ
x = x[np.argsort(abs(x))]
result1 = x.sum() - 1000
result2 = x[::-1].sum() - 1000
print(result1)
# -1.42938232421875
print(result2)
# -1.42919921875
print(result1 - result2)
# -0.00018310546875

All that to say, things may just be working as expected for floating point math. If they're not, we'll need some kind of reproduction demonstrating the bug.
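
To put the numbers quoted in the question on the same relative scale, here is a small sketch. It assumes the computation runs in float32 (the JAX default), which the question does not state; the variable names are only illustrative.

```python
# First gradient element in each mode, as quoted in the question.
grad_a = 3.58691570e+03
grad_b = 3.58668399e+03

abs_diff = abs(grad_a - grad_b)
rel_diff = abs_diff / abs(grad_a)

# float32 machine epsilon is 2**-23 (about 1.19e-07).
# Assumption: the gradients were computed in float32.
eps32 = 2.0 ** -23

print(f"absolute difference: {abs_diff:.5f}")
print(f"relative difference: {rel_diff:.2e}")
print(f"relative difference in units of float32 eps: {rel_diff / eps32:.0f}")
```

The relative difference comes out below 1 part in 10^4, so it is in the same range as the summation example, even though it is well above a single float32 epsilon.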
-
For a project, I compute the gradient of a cost function with both reverse-mode and forward-mode JAX autodiff. I expect
both modes to give the same gradient (up to machine precision). After I introduced
jax.tree_util.register_dataclass with data_fields and meta_fields and recomputed the gradient, the two modes no longer
agree (difference in the first element of the gradients:
3.58691570e+03 - 3.58668399e+03 = 0.23171, larger than machine precision). Reverse mode still works as
expected, but the forward-mode result has changed compared to previous versions of my code.
Unfortunately, I am not allowed to share the code, and I have not been able to reproduce the difference in a standalone example.
Has anyone encountered the same problem? Are there any suggestions on how it can be solved?
What could be causing this difference?
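
For reference, the kind of comparison described above can be sketched as follows. This is only an illustration of the setup and does not reproduce the discrepancy: the `Params` class, its fields, and the `cost` function are hypothetical stand-ins for the project code that cannot be shared.

```python
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical container standing in for the project's dataclass.
@partial(jax.tree_util.register_dataclass,
         data_fields=["weights"], meta_fields=["power"])
@dataclass
class Params:
    weights: jax.Array  # data field: traced and differentiated
    power: int          # meta field: static, not differentiated


def cost(params: Params) -> jax.Array:
    # Hypothetical scalar cost function.
    return jnp.sum(jnp.sin(params.weights) ** params.power)


params = Params(weights=jnp.linspace(0.1, 2.0, 8), power=3)

# The gradients come back as Params instances; read the data field.
grad_rev = jax.jacrev(cost)(params).weights  # reverse mode
grad_fwd = jax.jacfwd(cost)(params).weights  # forward mode

# Compare on a relative scale rather than by absolute difference.
rel_diff = jnp.max(jnp.abs(grad_rev - grad_fwd) / jnp.abs(grad_rev))
print(rel_diff)
```

In a toy case like this the two modes should agree to within a few float32 rounding errors, which is the behaviour expected of the real code as well.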