
How can I stop gradient when losses include nan #14217

garam-kim1 asked this question in Q&A
Answered by jakevdp

This looks related to the situation covered in the following FAQ entry: https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
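For context, the fix that FAQ entry recommends is the "double where" trick: substitute a safe dummy value before the NaN-producing operation, then mask the result, so NaN never enters the backward pass. A minimal sketch (the log-based loss here is illustrative, not from the original question):

import jax
import jax.numpy as jnp

def safe_loss(x):
    # "Double where": swap invalid inputs for a harmless placeholder *before*
    # the op that would produce NaN, then mask the op's output.
    valid = x > 0
    safe_x = jnp.where(valid, x, 1.0)  # log(1.0) == 0.0, never NaN
    return jnp.where(valid, jnp.log(safe_x), 0.0).mean()

print(jax.grad(safe_loss)(jnp.array([1.0, 2.0, -1.0])))
# [0.33333334 0.16666667 0.        ]  -- finite everywhere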

But if your goal is simply to have the specified entries of x not contribute to the gradient, you can zero them out before taking the mean:

import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)  # which entries get masked depends on this key

def loss(x):
    # Randomly mark ~half the entries as "NaN positions" for demonstration.
    nan_mask = random.uniform(key, x.shape) > 0.5
    x = x * 2.0
    # Zero out the masked entries so they contribute nothing to the gradient.
    x = jnp.where(nan_mask, 0.0, x)
    return x.mean()

print(jax.grad(loss)(jnp.ones(10)))
# [0.2 0.  0.2 0.2 0.2 0.2 0.2 0.2 0.  0. ]  (zeros at the masked positions)
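Note that x.mean() still divides by the full length of x, which is why each surviving gradient is 2.0 / 10 = 0.2; if you want the mean over only the unmasked entries, divide the masked sum by (~nan_mask).sum() instead.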

Is that the output you're hoping to see?

Answer selected by garam-kim1