
Whether the HLO generated by JAX through jax.jit.lower() is optimized by XLA #15899

Answered by jakevdp
songh11 asked this question in Q&A

jit(f).lower(...) returns unoptimized HLO. If you call jit(f).lower(...).compile(), it will be passed through the compiler for optimization. Here's a simple example:

import jax

def f(x):
  return jax.lax.sin(x) + jax.lax.sin(x)

lowered = jax.jit(f).lower(1.0)

print(lowered.as_text())
module @jit_f {
  func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
    %0 = stablehlo.sine %arg0 : tensor<f32>
    %1 = stablehlo.sine %arg0 : tensor<f32>
    %2 = stablehlo.add %0, %1 : tensor<f32>
    return %2 : tensor<f32>
  }
}

Notice that there are two duplicate sine computations in the lowered module. Now if you compile it you get …
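A sketch of the compile step, assuming a recent JAX version where the object returned by .compile() exposes .as_text() and is directly callable (the exact optimized HLO text varies by backend and XLA version, so none is shown here):

import math
import jax

def f(x):
  return jax.lax.sin(x) + jax.lax.sin(x)

# Lower to (unoptimized) StableHLO, then run it through the XLA compiler.
lowered = jax.jit(f).lower(1.0)
compiled = lowered.compile()

# Inspect the optimized HLO; passes such as common-subexpression
# elimination typically deduplicate the two sine ops.
print(compiled.as_text())

# The compiled executable can also be invoked directly.
result = compiled(1.0)
print(result)  # numerically equal to 2 * sin(1.0)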

Answer selected by songh11