JAX release v0.4.13
NOTE: This is the last JAX release that will include Python 3.8 support.
- Changes
  - `jax.jit` now allows `None` to be passed to `in_shardings` and
    `out_shardings`. The semantics are as follows:
    - For `in_shardings`, JAX will mark it as replicated, but this behavior
      can change in the future.
    - For `out_shardings`, we will rely on the XLA GSPMD partitioner to
      determine the output shardings.
  - `jax.experimental.pjit.pjit` also allows `None` to be passed to
    `in_shardings` and `out_shardings`. The semantics are as follows:
    - If the mesh context manager is *not* provided, JAX has the freedom to
      choose whatever sharding it wants.
      - For `in_shardings`, JAX will mark it as replicated, but this behavior
        can change in the future.
      - For `out_shardings`, we will rely on the XLA GSPMD partitioner to
        determine the output shardings.
    - If the mesh context manager is provided, `None` will imply that the value
      will be replicated on all devices of the mesh.
  - `Executable.cost_analysis()` works on Cloud TPU.
  - Added a warning if a non-allowlisted `jaxlib` plugin is in use.
  - Added `jax.tree_util.tree_leaves_with_path`.
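A minimal sketch of the new `jax.jit` behavior, assuming JAX 0.4.13 or later on any backend (a single CPU device is enough). The function `scale_sin` is a hypothetical example; passing `None` for both arguments leaves input placement to JAX (currently replicated) and output placement to the GSPMD partitioner, as the `jax.jit` note describes.

```python
import jax
import jax.numpy as jnp


def scale_sin(x):
    # Hypothetical example function; any jittable function works here.
    return jnp.sin(x) * 2.0


# in_shardings=None: JAX marks the input as replicated (this may change).
# out_shardings=None: the XLA GSPMD partitioner picks the output sharding.
scale_sin_jit = jax.jit(scale_sin, in_shardings=None, out_shardings=None)

x = jnp.arange(8.0)
y = scale_sin_jit(x)
print(y.sharding)
```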
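The `pjit` variant differs from `jax.jit` only when a mesh context manager is active. The sketch below assumes the 0.4.13-era API in which `jax.sharding.Mesh` doubles as a context manager; the axis name `"x"` and the doubling lambda are illustrative. On a single device the mesh has one entry, so replication is trivial, but the code path is the same on multi-device hosts.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import Mesh

# A 1-D mesh over every available device; the axis name "x" is arbitrary.
mesh = Mesh(np.array(jax.devices()), ("x",))

# Illustrative function with None for both input and output shardings.
double = pjit(lambda v: v * 2.0, in_shardings=None, out_shardings=None)

with mesh:
    # With a mesh in context, None means "replicate on every device of the mesh".
    y = double(jnp.arange(8.0))

print(y.sharding)
```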
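Cost analysis is reachable through the ahead-of-time lowering path: the object returned by `jax.jit(...).lower(...).compile()` exposes a `cost_analysis()` method, which is assumed here to be the entry point the `Executable.cost_analysis()` note refers to. The matrix-multiply function is illustrative. Because the return type has varied across JAX versions (a list with one dict per executable in older releases, a single dict in newer ones), the sketch normalizes it; the available keys are backend-dependent, so `"flops"` may be absent.

```python
import jax
import jax.numpy as jnp


def matmul(a, b):
    # Illustrative workload; cost analysis applies to any compiled function.
    return a @ b


a = jnp.ones((128, 128), dtype=jnp.float32)
compiled = jax.jit(matmul).lower(a, a).compile()

raw = compiled.cost_analysis()
# Older JAX versions return a list (one dict per executable); newer ones
# return a dict directly. Some backends may return None or an empty list.
analysis = raw[0] if isinstance(raw, (list, tuple)) and raw else raw
if isinstance(analysis, dict):
    print("estimated flops:", analysis.get("flops"))
else:
    print("cost analysis not available on this backend")
```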
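`jax.tree_util.tree_leaves_with_path` returns each leaf paired with its key path. A short sketch; the nested `params` dict is made up, and `jax.tree_util.keystr` is used only to render each path in Python indexing syntax (for example `['layers'][0]`).

```python
import jax

# Hypothetical parameter tree: a dict holding one leaf and a tuple of leaves.
params = {"bias": 0.5, "layers": (1.0, 2.0)}

for path, leaf in jax.tree_util.tree_leaves_with_path(params):
    # path is a tuple of key objects (DictKey, SequenceKey, ...);
    # keystr joins them into a single indexing string.
    print(jax.tree_util.keystr(path), leaf)
```

Dict entries are visited in sorted key order, matching `jax.tree_util.tree_leaves`, so the paths line up one-to-one with the plain leaf list.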
- Bug fixes
  - Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel
    is named `cudnn89` instead of `cudnn88`.
- Deprecations
  - The `native_serialization_strict_checks` parameter to
    {func}`jax.experimental.jax2tf.convert` is deprecated in favor of the
    new `native_serialization_disabled_checks` parameter ({jax-issue}`#16347`).
- The