jax.numpy.arcsin(2+0j) = 1.57-1.31j

The two libraries return complex-conjugate values for arcsin(2+0j). My simulation requires computing the sine and then the arcsine in sequence. Starting from x = 2+0j:

- With jax.numpy: arcsin(2+0j) = 1.57-1.31j, the sine of that is 2+0.00...0001j, and the arcsine of that is 1.57+1.31j.
- With NumPy: arcsin(2+0j) = 1.57+1.31j, the sine of that is 2+0.0..001j, and the arcsine of that is 1.57+1.31j.
- With MATLAB: arcsin(2+0j) = 1.57-1.31j, the sine of that is 2-0.0...0001j, and the arcsine of that is 1.57-1.31j.

My simulation results are very sensitive to this, and for my simulation MATLAB's behavior is the correct one. Why do the libraries differ, and how can I reproduce MATLAB's behavior using `jax.numpy`?
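The sign flip in the JAX round trip can be reproduced with NumPy alone. The sketch below starts from pi/2 - i*acosh(2), the value JAX and MATLAB return for arcsin(2+0j), applies `sin` and then `arcsin`, and compares 64-bit with 32-bit precision. The helper name `sin_then_arcsin` is hypothetical, and using 32-bit NumPy as a stand-in for JAX's default 32-bit precision is an assumption about where the difference comes from. The mechanism it illustrates: float32(pi/2) rounds slightly *above* pi/2, so its cosine is negative, while float64(pi/2) rounds slightly *below* pi/2, so its cosine is positive. That cosine multiplies sinh(-acosh(2)) in the imaginary part of the sine, so the tiny imaginary residue comes out with opposite signs.

```python
import numpy as np

def sin_then_arcsin(real_dtype, complex_dtype):
    """Hypothetical helper: start from pi/2 - i*acosh(2), rounded to the given
    precision, then apply sin followed by arcsin."""
    w = complex_dtype(complex(real_dtype(np.pi / 2), -np.arccosh(2.0)))
    s = np.sin(w)        # exactly 2 mathematically; round-off leaves a tiny imaginary part
    back = np.arcsin(s)  # the sign of that tiny part decides which side of the cut we land on
    return s, back

s64, back64 = sin_then_arcsin(np.float64, np.complex128)
s32, back32 = sin_then_arcsin(np.float32, np.complex64)
print(s64, back64)  # 64-bit: tiny negative imaginary part, negative branch is kept
print(s32, back32)  # 32-bit: tiny positive imaginary part, result flips to the positive branch
```

If 32-bit precision is indeed the cause, enabling 64-bit mode in JAX with `jax.config.update("jax_enable_x64", True)` may make the JAX round trip behave like the 64-bit case. It does not change which branch `jnp.arcsin` picks for an input lying exactly on the cut.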
Hi, thanks for the question! If your simulation is sensitive to which branch a particular `arcsin` implementation returns, then you might want to think about rewriting it to make the computation more robust.
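One way to see why the branch matters: `2+0j` lies exactly on the branch cut of arcsin, which runs along the real axis outside [-1, 1]. Across the cut, the imaginary part of the result jumps between +acosh(2) and -acosh(2), so the value *on* the cut is a matter of convention. Any input that differs from 2 only by round-off in the imaginary part (as the output of `sin` does in the round trip described in the question) lands on one side or the other. The NumPy-only sketch below shows that NumPy lets the sign of the zero imaginary part choose the side, and that points slightly off the cut have unambiguous values.

```python
import numpy as np

# 2+0j lies exactly on arcsin's branch cut (the real axis outside [-1, 1]).
# NumPy lets the sign of the zero imaginary part choose the side of the cut.
above = np.arcsin(complex(2.0, 0.0))   # +0.0: value continuous with the upper half-plane
below = np.arcsin(complex(2.0, -0.0))  # -0.0: value continuous with the lower half-plane

# Slightly off the cut there is no ambiguity: each side matches its limit.
near_above = np.arcsin(complex(2.0, 1e-12))
near_below = np.arcsin(complex(2.0, -1e-12))

print(above, below)  # conjugates of each other: imaginary parts +acosh(2) and -acosh(2)
```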
No, JAX doesn't offer any fine-grained control over the branch that `jnp.arcsin` uses for particular inputs. You could always define your own version of `arcsin` with a different branch convention, but I would not recommend this as a fix. It would treat the symptom rather than the root cause of the problem, which is that your simulation is too sensitive to the particular branch-cut conventions of complex functions.