Does JAX's XLA compiler optimize away sqrt followed by square? #7109
Answered by hawkinsp
NeilGirdhar asked this question in General
```python
import jax.numpy as jnp

x = jnp.array(1+1j)
y = jnp.square(jnp.abs(x))
```
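For a complex value x = a + bi, the expression above computes the squared magnitude. Algebraically, the square root that `jnp.abs` introduces cancels with the square:

$$
|x|^2 = \left(\sqrt{a^2 + b^2}\right)^2 = a^2 + b^2, \qquad x = a + b\,i
$$

For `x = 1+1j` this gives 1^2 + 1^2 = 2. The cancellation is exact only in real arithmetic. In IEEE-754 double precision, `sqrt(2)**2` evaluates to `2.0000000000000004`, so a compiler that preserves floating-point results cannot, in general, replace sqrt-then-square with the identity without sometimes changing the answer.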
Reply from hawkinsp, Jun 25, 2021:
Here's a semi-internal way to look at the compiled HLO:
Output on CPU:
So no, it looks like XLA left both operations alone, at least at this level of the compiler.
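Below is a minimal sketch for repeating this kind of check. It assumes a recent JAX release with the ahead-of-time `jax.jit(f).lower(...).compile()` API, and it is not necessarily the semi-internal method used in the reply. It prints the optimized HLO for the formulation from the question, plus a hand-simplified alternative that never takes a square root. Both function names, `abs_then_square` and `squared_magnitude`, are hypothetical. The exact HLO text will vary with the JAX/XLA version and the backend.

```python
import jax
import jax.numpy as jnp


def abs_then_square(x):
    # Formulation from the question: |x| (a square root internally), then square it.
    return jnp.square(jnp.abs(x))


def squared_magnitude(x):
    # Hypothetical hand-simplified alternative: Re(x)^2 + Im(x)^2, no square root.
    return jnp.real(x) ** 2 + jnp.imag(x) ** 2


x = jnp.array(1 + 1j)  # complex64 unless 64-bit mode is enabled

hlo = {}
for fn in (abs_then_square, squared_magnitude):
    # lower() traces fn for this input, compile() runs XLA's optimization
    # passes, and as_text() returns the optimized HLO module as a string.
    compiled = jax.jit(fn).lower(x).compile()
    hlo[fn.__name__] = compiled.as_text()
    print(f"=== {fn.__name__} ===")
    print(hlo[fn.__name__])
    print("value:", fn(x))
```

If the first module still contains the complex `abs` followed by a multiply, as the reply observed, writing the squared magnitude directly avoids the square root entirely. The two versions can differ in the last bit, since one rounds the intermediate square root and the other does not.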
Answer selected by NeilGirdhar