JAX release v0.4.13

skye released this on 23 Jun 00:53

NOTE: This is the last JAX release that will include Python 3.8 support.

  • Changes

    • jax.jit now allows None to be passed to in_shardings and
      out_shardings. The semantics are as follows:
      • For in_shardings, JAX will mark it as replicated, but this behavior
        may change in the future.
      • For out_shardings, we will rely on the XLA GSPMD partitioner to
        determine the output shardings.
    • jax.experimental.pjit.pjit also allows None to be passed to
      in_shardings and out_shardings. The semantics are as follows:
      • If the mesh context manager is not provided, JAX has the freedom to
        choose whatever sharding it wants.
        • For in_shardings, JAX will mark it as replicated, but this behavior
          may change in the future.
        • For out_shardings, we will rely on the XLA GSPMD partitioner to
          determine the output shardings.
      • If the mesh context manager is provided, None will imply that the value
        will be replicated on all devices of the mesh.
    • Executable.cost_analysis() now works on Cloud TPU.
    • Added a warning if a non-allowlisted jaxlib plugin is in use.
    • Added jax.tree_util.tree_leaves_with_path.
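The None semantics for in_shardings and out_shardings described above can be exercised with a small sketch. It assumes a recent JAX install in which `jax.jit` and `jax.experimental.pjit.pjit` both accept these keyword arguments, and it builds a one-axis mesh over whatever devices are available (a single CPU device is enough). The axis name `"x"` and the function `f` are illustrative choices only.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import Mesh

def f(x):
    # Illustrative computation; any elementwise function would do.
    return jnp.sin(x) * 2.0

x = jnp.arange(8.0)

# jax.jit with None: inputs are marked replicated and the XLA GSPMD
# partitioner is left to choose the output sharding.
jit_out = jax.jit(f, in_shardings=None, out_shardings=None)(x)

# pjit inside a mesh context manager: None means "replicate the value on
# every device of the mesh". The axis name "x" is an arbitrary assumption.
mesh = Mesh(np.array(jax.devices()), ("x",))
with mesh:
    pjit_out = pjit(f, in_shardings=None, out_shardings=None)(x)

print(bool(jnp.allclose(jit_out, pjit_out)))
```

Both calls compute the same values. The difference is only in how the compiled computation is allowed to place its inputs and outputs across devices.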
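A rough sketch of how cost analysis is typically reached from user code, via the ahead-of-time `lower(...).compile()` path. This assumes that the "Executable" in the note corresponds to the compiled object returned by `.compile()`. It also assumes that the return type varies by version: some releases return a list with one dict per computation, others return the dict directly, so the code normalizes both. The available keys (for example `"flops"`) depend on the backend.

```python
import jax
import jax.numpy as jnp

def matmul(a, b):
    return a @ b

a = jnp.ones((128, 128))
b = jnp.ones((128, 128))

# Lower and compile ahead of time, then ask the compiled object for XLA's
# cost estimates.
compiled = jax.jit(matmul).lower(a, b).compile()
costs = compiled.cost_analysis()

# Assumption: older versions return a list of dicts, newer ones a dict.
if isinstance(costs, (list, tuple)):
    costs = costs[0]

print(costs.get("flops") if costs else None)
```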
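A minimal sketch of `jax.tree_util.tree_leaves_with_path`, which returns `(key_path, leaf)` pairs instead of bare leaves. It pairs the new function with `jax.tree_util.keystr` to print each path in readable form. The nested `tree` below is a made-up example, and dict keys are visited in sorted order.

```python
import jax

# Hypothetical nested structure mixing dicts and a list.
tree = {"params": {"w": 1.0, "b": 0.0}, "steps": [10, 20]}

# Each entry is (key path, leaf); keystr renders the path as a string.
pairs = [
    (jax.tree_util.keystr(path), leaf)
    for path, leaf in jax.tree_util.tree_leaves_with_path(tree)
]

for path_str, leaf in pairs:
    print(path_str, leaf)
# ['params']['b'] 0.0
# ['params']['w'] 1.0
# ['steps'][0] 10
# ['steps'][1] 20
```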
  • Bug fixes

    • Fixed an incorrect wheel name in CUDA 12 releases (#16362): the wheel
      is now correctly named cudnn89 instead of cudnn88.
  • Deprecations

    • The native_serialization_strict_checks parameter to
      jax.experimental.jax2tf.convert is deprecated in favor of the
      new native_serialization_disabled_checks (#16347).