JAX release v0.2.28
- GitHub commits.
- jax.jit(f).lower(...).compiler_ir() now defaults to the MHLO dialect if no dialect= argument is passed.
- jax.jit(f).lower(...).compiler_ir(dialect='mhlo') now returns an MLIR ir.Module object instead of its string representation.
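Below is a minimal sketch of both changes. It assumes a JAX release at or after v0.2.28 running on any backend (CPU is enough). The function f and the input x are illustrative placeholders. The sketch calls compiler_ir() without dialect=, so it uses whatever default the installed release has. It then shows that the result is a module object and that the textual IR now comes from str().

```python
import jax
import jax.numpy as jnp


def f(x):
    # Illustrative placeholder: any jittable function works here.
    return jnp.sin(x) * 2.0


x = jnp.ones((3,), dtype=jnp.float32)

# Lower the jitted function for this input's shape/dtype without compiling it.
lowered = jax.jit(f).lower(x)

# No dialect= argument, so the release's default dialect is used
# (MHLO as of v0.2.28, per the release note).
module = lowered.compiler_ir()

# The result is an MLIR module object, not a string.
print(type(module))

# The textual form of the IR is still available by converting to str.
ir_text = str(module)
print(ir_text)
```

Code that previously treated the return value of compiler_ir(dialect='mhlo') as text should now wrap it in str() (or operate on the ir.Module directly).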