How can I stop gradient when losses include nan #14217
garam-kim1 asked this question in Q&A
When I run the code below, the resulting gradient contains nan at the masked positions:

```python
import jax
import jax.numpy as jnp
from jax import lax, random

key = random.PRNGKey(0)

def loss(x):
    nan_mask = random.uniform(key, x.shape) > 0.5
    x = x * 2.0
    x = x / (~nan_mask)  # divides by 0 where nan_mask is True, producing inf
    x = jnp.nan_to_num(x)
    x = jnp.where(nan_mask, 0, x)
    x = (1 - nan_mask) * x
    x = jnp.where(nan_mask, lax.stop_gradient(x), x)
    return x.mean()

print(jax.grad(loss)(jnp.ones(10)))
```

What I expected is a gradient in which the masked entries are simply ignored (zero). Is there any dynamic way to stop (or ignore) the gradient when losses include nan?
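As a general debugging aid (an aside, not part of the original question): JAX's `jax_debug_nans` option makes `grad` and `jit` raise an error at the first operation that produces a nan, which helps locate where the nan enters a computation like this one:

```python
import jax

# With this flag set, JAX re-runs jitted computations un-optimized when a nan
# appears and raises FloatingPointError at the op that produced it.
jax.config.update("jax_debug_nans", True)
```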
Answered by jakevdp · Jan 31, 2023
This looks related to the situation covered in the following FAQ entry: https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where

But if your goal is to simply have the specified entries in `x` not contribute to the gradient, you can do so by zeroing them out:

```python
def loss(x):
    nan_mask = random.uniform(key, x.shape) > 0.5
    x = x * 2.0
    x = jnp.where(nan_mask, 0, x)
    return x.mean()

print(jax.grad(loss)(jnp.ones(10)))
# [0.2 0.  0.2 0.2 0.2 0.2 0.2 0.2 0.  0. ]
```

Is that the output you're hoping to see?
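Why the original attempt still produced nan: the derivative of `x / (~nan_mask)` is `1 / (~nan_mask)`, which is inf at the masked (zero-denominator) positions, and the zero cotangent arriving at the division from the later masking steps times inf gives `0 * inf = nan`. The linked FAQ entry's fix is the "double where" pattern: make the denominator safe *before* dividing. Below is a minimal sketch of that pattern applied to this example (the name `loss_double_where` and the reuse of the same `key` are assumptions, not from the thread):

```python
def loss_double_where(x):  # hypothetical name, not from the thread
    nan_mask = random.uniform(key, x.shape) > 0.5
    x = x * 2.0
    denom = (~nan_mask).astype(x.dtype)           # 0.0 where nan_mask is True
    safe_denom = jnp.where(nan_mask, 1.0, denom)  # inner where: denominator is never 0
    x = jnp.where(nan_mask, 0.0, x / safe_denom)  # outer where: zero out masked entries
    return x.mean()

print(jax.grad(loss_double_where)(jnp.ones(10)))
# Expected: 0.2 at unmasked positions and 0. at masked ones, with no nan.
```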
Answer selected by garam-kim1