Anything faster than `cond` for splicing together functions?
#14860
I am writing a version of Using
My main question is: are there any better or faster ways to accomplish this? I will be calling this function millions of times, so I'd like to make sure it's fast.
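Assuming the function in question is the `exprel` discussed in the reply (`expm1(x) / x`, with the value 1 at `x = 0`), here is a minimal sketch of what a `lax.cond`-based scalar version might look like. The name `exprel_cond` and the exact branch structure are hypothetical, not the asker's code.

```python
import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical scalar version: for a scalar predicate, lax.cond runs one branch.
def exprel_cond(x):
  return lax.cond(
      x == 0,
      lambda v: jnp.ones_like(v),   # the limit of expm1(v) / v as v -> 0
      lambda v: jnp.expm1(v) / v,   # expm1 stays accurate for small v, unlike exp(v) - 1
      x,
  )


exprel_cond_jit = jax.jit(exprel_cond)
print(exprel_cond_jit(jnp.float32(0.0)))
print(exprel_cond_jit(jnp.float32(1.0)))
```

Under `jax.vmap`, a `lax.cond` with a batched predicate is converted into a select, so both branches end up being evaluated for every element anyway.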
This is fine as long as `x` is a scalar. If `x` is a vector, you can use `lax.select`, or `jnp.where`, which is a wrapper around it:

```python
import jax.numpy as jnp

def exprel(x):
  return jnp.where(x == 0, 1, jnp.expm1(x) / x)
```

This is the pattern used throughout the JAX source code for this kind of thing.
Side-note: if you're interested in autodiff over this function, you may also need a `custom_jvp` to correctly define the derivative at x = 0; there's a tutorial on this here: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
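Here is a minimal sketch of such a `custom_jvp`, following the pattern in the linked tutorial. It relies on the derivative of `exprel(x) = expm1(x) / x`, which is `(exp(x) - exprel(x)) / x` for `x != 0` and tends to `1/2` as `x -> 0`. Without a custom rule, `jax.grad` of the plain `jnp.where` version returns `nan` at 0. This happens because the unselected `expm1(x) / x` branch has an undefined derivative there, and it still leaks into the backward pass. The `safe_x` substitution keeps the rule itself from dividing by zero. The formula `(exp(x) - y) / x` also loses precision through cancellation for very small nonzero `x`, so treat this as a starting point.

```python
import jax
import jax.numpy as jnp


def exprel_naive(x):
  return jnp.where(x == 0, 1.0, jnp.expm1(x) / x)


@jax.custom_jvp
def exprel(x):
  return jnp.where(x == 0, 1.0, jnp.expm1(x) / x)


@exprel.defjvp
def exprel_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  y = exprel(x)
  # Replace x with 1 where it is 0 so the unselected branch never divides by zero.
  safe_x = jnp.where(x == 0, 1.0, x)
  # d/dx [expm1(x) / x] = (exp(x) - y) / x for x != 0, with limit 1/2 at x = 0.
  dydx = jnp.where(x == 0, 0.5, (jnp.exp(x) - y) / safe_x)
  return y, dydx * x_dot


print(jax.grad(exprel_naive)(0.0))
print(jax.grad(exprel)(0.0))
print(jax.grad(exprel)(1.0))
```

Because the rule is written with elementwise operations, it also composes with `jax.vmap` for vector inputs.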