I see from the documentation that HLO is the input to the XLA compiler, so is HLO not XLA-optimized?
Answered by
jakevdp
May 6, 2023
    import jax

    def f(x):
        return jax.lax.sin(x) + jax.lax.sin(x)

    lowered = jax.jit(f).lower(1.0)
    print(lowered.as_text())

Notice that there are two duplicate sine computations. Now compile it and print the result:

    print(lowered.compile().as_text())

Notice that the compiler has de-duplicated the two sine computations.
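To check the difference without reading the dumps by eye, here is a minimal sketch that counts sine operations in both texts. The helper `count_sine_ops` is hypothetical (not part of JAX) and is only a rough text heuristic. It assumes the lowered text spells the op as `stablehlo.sine` or `mhlo.sine` and the compiled HLO text spells it as `sine(`. The exact text format varies by JAX version and backend, so treat the counts as indicative.

```python
import math
import re

import jax


def f(x):
    # Same function as in the answer: sine is deliberately written twice.
    return jax.lax.sin(x) + jax.lax.sin(x)


def count_sine_ops(text):
    # Hypothetical helper, a rough heuristic on the printed text:
    # matches "<dialect>.sine" in MLIR-style output and "sine(" in HLO text,
    # while skipping instruction names such as "sine.1".
    return len(re.findall(r"\.sine\b|\bsine\(", text))


lowered = jax.jit(f).lower(1.0)
lowered_text = lowered.as_text()         # un-optimized, straight from tracing

compiled = lowered.compile()
compiled_text = compiled.as_text() or ""  # after XLA's optimization passes

print("sine ops before compile:", count_sine_ops(lowered_text))
print("sine ops after compile: ", count_sine_ops(compiled_text))

# The compiled object is callable and computes the same value as f.
print("result:", float(compiled(1.0)), "expected:", 2 * math.sin(1.0))
```

If the second count is lower than the first, the compiler has merged the duplicate computations, which is the behavior described above.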
Answer selected by
songh11
jit(f).lower(...) returns un-optimized HLO. If you call jit(f).lower(...).compile(), it will be passed through the compiler for optimization. In the simple example above, notice that there are two duplicate sine computations. Now if you compile it you get …