Does JAX's XLA compiler optimize away sqrt followed by square? #7109

Answered by hawkinsp
NeilGirdhar asked this question in General

Here's a semi-internal way to look at the compiled HLO:

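import jax
import jax.numpy as jnp

def f(x):
    # |x|**2: for complex x, |x| involves a square root that the square then undoes
    return jnp.square(jnp.abs(x))

x = jnp.array(1+1j)  # complex64 scalar
c = jax.xla_computation(f)(x)  # unoptimized HLO (jax.xla_computation is deprecated in newer JAX releases)

# Compile with the default backend and print the *optimized* HLO module
backend = jax.lib.xla_bridge.get_backend()
e = backend.compile(c)
print(e.hlo_modules()[0].to_string())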

Output on CPU:

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
HloModule xla_computation_f.6

%fused_computation (param_0.1: c64[]) -> f32[] {
  %param_0.1 = c64[] parameter(0)
  %abs.0 = f32[] abs(c64[] %param_0.1), metadata={op_type="abs" op_name="xla_computation(f)/abs" source_file="/Users/phawkins/p/jax/t.py" source_line=4}
  ROOT %multiply.0 = f32[…
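On newer JAX releases, where jax.xla_computation is deprecated, the same kind of inspection can be sketched with the ahead-of-time API: jax.jit(fn).lower(*args).compile().as_text() returns the optimized HLO text. The sketch below compares the abs-then-square pattern with a version that computes |x|² directly from the real and imaginary parts. The function names are hypothetical. Whether XLA removes the square root in the first version may depend on the XLA version and backend, so read the printed HLO rather than assuming either way.

```python
import jax
import jax.numpy as jnp


def abs_then_square(x):
    # The pattern from the question: |x| (a square root for complex x), then squared.
    return jnp.square(jnp.abs(x))


def direct_abs_squared(x):
    # Same value computed without any square root: re(x)**2 + im(x)**2.
    return jnp.real(x) ** 2 + jnp.imag(x) ** 2


def optimized_hlo(fn, *args):
    # Ahead-of-time API: trace and lower to HLO, compile, then dump the optimized module text.
    return jax.jit(fn).lower(*args).compile().as_text()


x = jnp.array(1 + 1j, dtype=jnp.complex64)
for fn in (abs_then_square, direct_abs_squared):
    print(fn.__name__, fn(x))
    print(optimized_hlo(fn, x))
```

If an abs or sqrt op still appears in the first function's optimized module, writing the squared magnitude directly avoids relying on the algebraic simplifier at all.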

Replies: 1 comment 1 reply

1 reply
@NeilGirdhar

Answer selected by NeilGirdhar
2 participants