Anything faster than cond for splicing together functions? #14860

Answered by jakevdp
dieterichlawson asked this question in Q&A

Using cond this way is fine as long as x is a scalar. If x is a vector, you can use lax.select, or jnp.where, which is a wrapper around it:

import jax.numpy as jnp

def exprel(x):
  return jnp.where(x == 0, 1, jnp.expm1(x) / x)

This is the pattern used throughout the JAX source code for this kind of thing.
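For comparison, here is a rough sketch of the same function written directly against lax.select. The name exprel_select and the float32 cast are illustrative assumptions, not part of the original answer. Unlike jnp.where, lax.select does not broadcast or promote its arguments, so both branches are built with the same shape and dtype as x:

```python
import jax.numpy as jnp
from jax import lax

def exprel_select(x):
    # Hypothetical name; cast so both branches share one dtype.
    x = jnp.asarray(x, dtype=jnp.float32)
    # Both branches are evaluated for every element; select only picks
    # which result to keep, so the 0/0 at x == 0 is discarded.
    return lax.select(x == 0, jnp.ones_like(x), jnp.expm1(x) / x)

x = jnp.array([-1.0, 0.0, 1.0])
y = exprel_select(x)
```

In most user code jnp.where is the more convenient choice, since it handles scalar branches and broadcasting for you.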

Side note: if you're interested in autodiff over this function, you may also need a custom_jvp to correctly define the derivative at x = 0; there's a tutorial on this here: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
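As a minimal sketch of what such a custom_jvp might look like (one possible formulation, not taken from the original answer): for f(x) = (e^x - 1) / x, the derivative can be written as f'(x) = (e^x - f(x)) / x, with limit 1/2 at x = 0. The helper names exprel_jvp and safe_x are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def exprel(x):
    return jnp.where(x == 0, 1.0, jnp.expm1(x) / x)

@exprel.defjvp
def exprel_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    y = exprel(x)
    # Replace x == 0 with a dummy value so the unselected branch
    # below never computes 0/0.
    safe_x = jnp.where(x == 0, 1.0, x)
    # f'(x) = (exp(x) - f(x)) / x, and f'(0) = 1/2 by the series expansion.
    dydx = jnp.where(x == 0, 0.5, (jnp.exp(x) - y) / safe_x)
    return y, dydx * t

g0 = jax.grad(exprel)(0.0)  # derivative at the removable singularity
g1 = jax.grad(exprel)(1.0)  # derivative away from zero
```

Note that the closed-form branch may lose precision for very small nonzero x because of cancellation; a series expansion near zero would be one way to handle that if it matters for your use case.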

Answer selected by dieterichlawson