I'm a little surprised by the result of `grad`. I have a piecewise-defined function, and when evaluating the `grad` transformation I get `nan` at a position where the function is perfectly well defined. Here's a snippet to reproduce the issue:

```python
import jax

def f(x):
    return jax.numpy.where(
        x < 1.0,
        0.0,
        jax.numpy.log(x)
    )

df = jax.grad(f)

print(df(-0.1))
print(df(0.0))
print(df(0.1))
```

This prints `nan` for the gradient even though the function is perfectly well defined at these points. Is this expected behaviour? I would assume that autograd should be perfectly fine handling points around 0.
Answered by jakevdp, Feb 22, 2021
This section of the FAQ might be helpful: https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where

Please let me know if that does not answer your question!
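In short, `jnp.where` differentiates both branches, so the backward pass still evaluates the derivative of `jnp.log` at `x = 0.0`, and the resulting `0 * inf` product produces a `nan` even though that branch is never selected. Below is a minimal sketch of the "double-where" workaround that FAQ entry describes, applied to the snippet above; the name `f_safe` is just for illustration.

```python
import jax
import jax.numpy as jnp

def f_safe(x):
    # Make the argument of log harmless in the untaken branch so its
    # derivative stays finite; the outer where still selects the right value.
    safe_x = jnp.where(x < 1.0, 1.0, x)
    return jnp.where(x < 1.0, 0.0, jnp.log(safe_x))

df_safe = jax.grad(f_safe)

print(df_safe(-0.1))  # 0.0
print(df_safe(0.0))   # 0.0 instead of nan
print(df_safe(0.1))   # 0.0
```

The inner `where` guarantees that `log` never sees an input where its derivative is infinite, which keeps the `0 * inf = nan` product out of the backward pass.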
Answer selected by wulu473