JAX v0.4.29

@hawkinsp hawkinsp released this 10 Jun 18:31
  • Changes

    • We anticipate that this will be the last release of JAX and jaxlib
      supporting a monolithic CUDA jaxlib. Future releases will use the CUDA
      plugin jaxlib (e.g. pip install jax[cuda12]).
    • JAX now requires ml_dtypes version 0.4.0 or newer.
    • Removed backwards-compatibility support for old usage of the
      jax.experimental.export API. It is no longer possible to use
      from jax.experimental.export import export; use
      from jax.experimental import export instead.
      The removed functionality has been deprecated since 0.4.24.
  • Deprecations

    • jax.sharding.XLACompatibleSharding is deprecated. Please use
      jax.sharding.Sharding.
    • jax.experimental.Exported.in_shardings has been renamed to
      jax.experimental.Exported.in_shardings_hlo. The same applies to
      out_shardings. The old names will be removed after 3 months.
    • Removed a number of previously-deprecated APIs:
      • from jax.core: non_negative_dim, DimSize, Shape
      • from jax.lax: tie_in
      • from jax.nn: normalize
      • from jax.interpreters.xla: backend_specific_translations,
        translations, register_translation, xla_destructure,
        TranslationRule, TranslationContext, XlaOp.
    • The tol argument of jax.numpy.linalg.matrix_rank is being
      deprecated and will soon be removed. Use rtol instead.
    • The rcond argument of jax.numpy.linalg.pinv is being
      deprecated and will soon be removed. Use rtol instead.
    • The deprecated jax.config submodule has been removed. To configure JAX,
      use import jax and then reference the config object via jax.config.
    • jax.random APIs no longer accept batched keys, where previously
      some did unintentionally. Going forward, we recommend explicit use of
      jax.vmap in such cases.
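
Three of the migrations listed above (accessing the config object via jax.config, passing rtol in place of tol or rcond, and mapping over batched keys with jax.vmap) might look roughly like the following. This is a minimal sketch rather than an official migration guide: the example matrix, the rtol value of 1e-3, the choice of the jax_enable_x64 option, the batch of four keys, and the sample shape are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Configuration: reach the config object through the top-level package.
# jax_enable_x64 is used here only to show the access path; False is the default.
jax.config.update("jax_enable_x64", False)

# Linear algebra: pass rtol instead of the deprecated tol / rcond arguments.
# Illustrative rank-1 matrix: the second column is twice the first.
a = jnp.array([[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]])
rank = jnp.linalg.matrix_rank(a, rtol=1e-3)
a_pinv = jnp.linalg.pinv(a, rtol=1e-3)

# Random numbers: jax.random functions expect a single key, so a batch of
# keys is handled by mapping over its leading axis explicitly with jax.vmap.
keys = jax.random.split(jax.random.key(0), 4)


def sample_one(key):
    # Draw three normal samples from one (unbatched) key.
    return jax.random.normal(key, (3,))


samples = jax.vmap(sample_one)(keys)  # one row of samples per key
```

Writing the per-key function separately and wrapping it in jax.vmap keeps the batching explicit, which is the pattern the note above recommends.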
  • New Functionality

    • Added jax.experimental.Exported.in_shardings_jax to construct
      shardings that can be used with the JAX APIs from the HloShardings
      that are stored in the Exported objects.