JAX v0.4.34
- New Functionality
  - This release includes wheels for Python 3.13. Free-threading mode is not yet
    supported.
  - jax.errors.JaxRuntimeError has been added as a public alias for the
    formerly private XlaRuntimeError type.
- Breaking changes
  - The jax_pmap_no_rank_reduction flag is set to True by default.
    - array[0] on a pmap result now introduces a reshape (use array[0:1]
      instead).
    - The per-shard shape (accessible via jax_array.addressable_shards or
      jax_array.addressable_data(0)) now has a leading (1, ...). Update code
      that directly accesses shards accordingly. The rank of the per-shard shape
      now matches that of the global shape, which is the same behavior as jit.
      This avoids costly reshapes when passing results from pmap into jit.
  - jax.experimental.host_callback has been deprecated since March 2024, with
    JAX version 0.4.26. Now we set the default value of the
    --jax_host_callback_legacy configuration value to False, which means that
    if your code uses jax.experimental.host_callback APIs, those API calls
    will be implemented in terms of the new jax.experimental.io_callback API.
    If this breaks your code, for a very limited time, you can set
    --jax_host_callback_legacy to True. Soon we will remove that
    configuration option, so you should instead transition to using the
    new JAX callback APIs. See #20385 for a discussion.
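
For code moving off jax.experimental.host_callback, a minimal sketch of calling host-side Python through jax.experimental.io_callback might look like this. The function names host_add_one and f are illustrative only, and the right replacement for a given host_callback call depends on how it was used.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback


def host_add_one(x):
    # Runs on the host; receives the argument as an array and must return
    # a value matching the declared result shape and dtype.
    return np.asarray(x) + 1


@jax.jit
def f(x):
    # Declare the result shape and dtype so JAX can trace through the call.
    result_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return io_callback(host_add_one, result_spec, x)


out = f(jnp.arange(3, dtype=jnp.float32))
print(out)
```
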
- Deprecations
  - In jax.numpy.trim_zeros, non-arraylike arguments or arraylike
    arguments with ndim != 1 are now deprecated, and in the future will result
    in an error.
  - Internal pretty-printing tools jax.core.pp_* have been removed, after
    being deprecated in JAX v0.4.30.
  - jax.lib.xla_client.Device is deprecated; use jax.Device instead.
  - jax.lib.xla_client.XlaRuntimeError has been deprecated. Use
    jax.errors.JaxRuntimeError instead.
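
A short sketch of usage that stays clear of these deprecations: trim_zeros is given a one-dimensional array, and device type checks go through the public jax.Device name. The input values are arbitrary.

```python
import jax
import jax.numpy as jnp

# trim_zeros with a 1-D array input, the form that is not deprecated.
padded = jnp.array([0, 0, 1, 2, 0])
trimmed = jnp.trim_zeros(padded)
print(trimmed)

# Use the public jax.Device type rather than jax.lib.xla_client.Device.
device = jax.devices()[0]
is_public_device = isinstance(device, jax.Device)
print(is_public_device)
```
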
- Deletion:
  - jax.xla_computation is deleted. It has been 3 months since its deprecation
    in the 0.4.30 JAX release.
    Please use the AOT APIs to get the same functionality as jax.xla_computation.
    - jax.xla_computation(fn)(*args, **kwargs) can be replaced with
      jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo').
    - You can also use the .out_info property of jax.stages.Lowered to get the
      output information (like tree structure, shape and dtype).
    - For cross-backend lowering, you can replace
      jax.xla_computation(fn, backend='tpu')(*args, **kwargs) with
      jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo').
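
The replacement for jax.xla_computation can be sketched as follows for the default backend. The function fn and its arguments are placeholders chosen for illustration.

```python
import jax
import jax.numpy as jnp


def fn(x, y):
    return jnp.sin(x) + y


args = (jnp.ones((2, 3)), jnp.zeros((2, 3)))

# Lower ahead of time instead of calling jax.xla_computation(fn)(*args).
lowered = jax.jit(fn).lower(*args)
hlo = lowered.compiler_ir('hlo')
print(hlo.as_hlo_text())

# Output structure, shape and dtype are available without running fn.
out_info = lowered.out_info
print(out_info)
```
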
  - jax.ShapeDtypeStruct no longer accepts the named_shape argument.
    The argument was only used by xmap which was removed in 0.4.31.
  - jax.tree.map(f, None, non-None), which previously emitted a
    DeprecationWarning, now raises an error. None
    is only a tree-prefix of itself. To preserve the current behavior, you can
    ask jax.tree.map to treat None as a leaf value by writing:
    jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None).
  - jax.sharding.XLACompatibleSharding has been removed. Please use
    jax.sharding.Sharding.
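
The is_leaf workaround for None entries can be tried on a small pair of trees. The dictionaries a and b and the function f below are made up for illustration.

```python
import jax


def f(x, y):
    return x + y


# a has a None where b has a value; None alone is no longer a valid prefix.
a = {"w": 1.0, "b": None}
b = {"w": 2.0, "b": 5.0}

# Treat None as a leaf and pass it through unchanged.
out = jax.tree.map(
    lambda x, y: None if x is None else f(x, y),
    a,
    b,
    is_leaf=lambda x: x is None,
)
print(out)
```
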
- Bug fixes
  - Fixed a bug where jax.numpy.cumsum would produce incorrect outputs
    if a non-boolean input was provided and dtype=bool was specified.
  - Edited the implementation of jax.numpy.ldexp to produce the correct gradient.