-
import jax
import jax.numpy as jnp

@jax.custom_jvp
def absolute(x):
    return jnp.where(x >= 0, x, -x)

def assert_at0(x):
    # Runs on the host via pure_callback, only when the JVP rule is evaluated.
    if x == 0.:
        print("absolute is not differentiable at 0")
        raise AssertionError("absolute is not differentiable at 0")
    return x

@absolute.defjvp
def absolute_jvp(primals, tangents):
    x, = primals
    # Route the primal through a host callback so the check survives jit.
    shape_dtype = jax.ShapeDtypeStruct(x.shape, x.dtype)
    x = jax.pure_callback(assert_at0, shape_dtype, x)
    x_dot, = tangents
    return absolute(x), x_dot * jnp.sign(x)

jax.grad(jax.jit(absolute))(jnp.array(0.))
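A quick sanity check of the two paths (a sketch based on the snippet above, not part of the original reply):

# Plain evaluation never invokes the JVP rule, so the callback never fires.
absolute(jnp.array(0.))            # -> 0.0

# At nonzero points the callback just passes the value through.
jax.grad(absolute)(jnp.array(2.))  # -> 1.0

# At 0 the callback raises; how the AssertionError surfaces (plain exception vs.
# a runtime error from the compiled computation) depends on the JAX version.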
-
The best way to do these kinds of runtime checks is with the experimental checkify transform:

import jax
import jax.numpy as jnp
from jax.experimental import checkify

@checkify.checkify
def grad_abs(x):
    checkify.check(x != 0, "x must be nonzero", x=x)
    return jax.grad(jnp.abs)(x)

err, out = jax.jit(grad_abs)(0.)
err.throw()

which raises:

---------------------------------------------------------------------------
FailedCheckError                          Traceback (most recent call last)
FailedCheckError: x must be nonzero (check failed at <ipython-input-2-239f912d24b1>:7 (grad_abs))

The above exception was the direct cause of the following exception:
...
JaxRuntimeError: x must be nonzero (check failed at <ipython-input-2-239f912d24b1>:7 (grad_abs))

The API is still experimental and subject to change, but this is the main mechanism available for this kind of check.
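For comparison, at a nonzero input the check passes and err.throw() is a no-op (a sketch extending the reply above):

err, out = jax.jit(grad_abs)(2.)
err.throw()   # nothing happens: the check passed
print(out)    # 1.0, the gradient of jnp.abs at 2.0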
-
The workaround I settled on in the end is to make an "identity" function that is not everywhere differentiable (along the lines of the sketch below) and use it somewhere in my custom function to transform a variable. This way the check ends up in the automatically generated gradient code of my function and I do not need to provide a manual vjp for my complicated function. Still, it would be nice if JAX offered a way to "attach" some custom code to the gradient computation of a function without having to define a custom vjp.
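One possible shape of such a non-everywhere-differentiable identity, modelled on the pure_callback trick from the first reply (my sketch, not the original poster's code; the flagged point 0. is just a placeholder):

import jax
import jax.numpy as jnp

def _check_point(x):
    # Host-side check; only reached when a derivative is being computed.
    if x == 0.:
        raise AssertionError("gradient requested at a non-differentiable point")
    return x

@jax.custom_jvp
def checked_identity(x):
    return x  # exactly the identity in forward evaluation

@checked_identity.defjvp
def _checked_identity_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    shape_dtype = jax.ShapeDtypeStruct(x.shape, x.dtype)
    x = jax.pure_callback(_check_point, shape_dtype, x)
    return x, x_dot  # identity JVP; the callback only enforces the check

Dropping checked_identity into an otherwise ordinary JAX function leaves its values and jit behaviour unchanged, while any grad or jvp taken through it at the flagged point trips the callback.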
-
I know that for functions that are mathematically not everywhere differentiable (for example np.abs()) JAX tries to return a "reasonable" gradient anyway, i.e. jax.grad(np.abs)(0.) returns 1. (and I think this is useful). But now I have a (jit-compiled) function, much more complicated and thus much more expensive to evaluate and automatically differentiate than abs, that I also know does not have a well-defined gradient at certain points. I would like to raise an exception and handle such cases in the calling code rather than have the function return a useless (because it is unstable) automatic derivative after a lot of expensive computation.
In autograd I could detect when my code is traced for differentiation with an ArrayBox and raise an exception.
Is there a way to do the same in JAX? Can I somehow put code into my function that doesn't get in the way of (and actually survives) jit compilation, allows the function to be evaluated everywhere normally, but makes it raise an exception when its derivative is calculated at certain points? Thanks!