jax.numpy.arcsin(2+0j) shows a different value from numpy.arcsin(2+0j) #15927

Closed. Answered by jakevdp.
sofomryu asked this question in Q&A

No, JAX doesn't offer any fine-grained control over the branch that is used for particular inputs to jnp.arcsin. You could always define your own version of arcsin with a different branch convention, e.g.

import jax.numpy as jnp

def my_arcsin(z):
  result = jnp.arcsin(z)
  # For real z with |z| > 1, conj() gives the value across the branch cut.
  return jnp.where(jnp.imag(z) == 0, result.conj(), result)

But I would not recommend this as a fix. It treats the symptom rather than the root cause of the problem, which is that your simulation is too sensitive to the branch cut conventions of particular complex functions.
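To see why this sensitivity is the real issue, here is a small numerical sketch. It uses NumPy rather than JAX (an assumption for illustration only; any principal-branch arcsin behaves the same way off the cut) and evaluates arcsin a distance of 1e-12 above and below the branch cut [1, ∞) at z = 2. The inputs differ by about 2e-12, yet the outputs differ by about 2·arccosh(2) ≈ 2.634 because the imaginary part flips sign. Any computation whose values land on or near the cut inherits that jump, whichever side a library assigns to points exactly on the cut.

```python
import numpy as np

# arcsin has branch cuts on the real axis for |x| > 1. Probe the cut [1, inf)
# at x = 2 from just above and just below.
eps = 1e-12
above = np.arcsin(2 + eps * 1j)  # approach from Im(z) > 0
below = np.arcsin(2 - eps * 1j)  # approach from Im(z) < 0

# Real parts agree (about pi/2); imaginary parts have opposite signs
# (about +arccosh(2) and -arccosh(2)).
print(above, below)

# An input change of 2e-12 produces an output change of about 2.634.
jump = abs(above - below)
print(jump, 2 * np.arccosh(2))

# Conjugating one side's value gives the other side's value, which is
# the effect that my_arcsin relies on for real inputs.
print(np.isclose(np.conj(below), above))
```

Off the cut, arcsin(conj(z)) = conj(arcsin(z)), so the `conj()` in `my_arcsin` only changes which side of the cut a real input is assigned to; it does nothing to make the computation less sensitive to crossing the cut.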

Replies: 1 comment, 4 replies
@sofomryu
@jakevdp (answer selected by sofomryu)
@hawkinsp
@sofomryu
Category: Q&A
3 participants