Releases: jax-ml/jax
JAX release v0.3.1
- Changes:
- `jax.test_util.JaxTestCase` and `jax.test_util.JaxTestLoader` are now deprecated. The suggested replacement is to use `parameterized.TestCase` directly. For tests that rely on custom asserts such as `JaxTestCase.assertAllClose()`, the suggested replacement is to use standard NumPy testing utilities such as `numpy.testing.assert_allclose()`, which work directly with JAX arrays (#9620).
- `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by default (#9562). To recover the previous behavior, use the new `jax.test_util.with_config` decorator:
```python
from jax import test_util as jtu

@jtu.with_config(jax_numpy_rank_promotion='allow')
class MyTestCase(jtu.JaxTestCase):
  ...
```
- Added `jax.scipy.linalg.schur`, `jax.scipy.linalg.sqrtm`, `jax.scipy.signal.csd`, `jax.scipy.signal.stft`, and `jax.scipy.signal.welch`.
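The deprecation note above recommends `parameterized.TestCase` plus `numpy.testing`. A minimal sketch of a test written that way, assuming a recent JAX: it uses the standard library's `unittest.TestCase` as a stand-in for absl's `parameterized.TestCase` so no extra dependency is needed, and it uses the `jax.numpy_rank_promotion` context manager (a public configuration helper in recent releases; its availability in older versions is an assumption) to reproduce the `'raise'` setting that `JaxTestCase` now applies by default. The class and test names are illustrative.

```python
import unittest

import jax
import jax.numpy as jnp
import numpy as np


class LogAddExpTest(unittest.TestCase):
    # Plain unittest.TestCase used here in place of parameterized.TestCase.

    def test_matches_numpy(self):
        x = jnp.linspace(-3.0, 3.0, 7)
        # numpy.testing utilities accept JAX arrays directly.
        np.testing.assert_allclose(
            jnp.logaddexp(0.0, x), np.logaddexp(0.0, np.asarray(x)), rtol=1e-5
        )

    def test_rank_promotion_raises(self):
        # With rank promotion set to 'raise', implicitly broadcasting a
        # rank-1 array against a rank-2 array is an error.
        with jax.numpy_rank_promotion("raise"):
            with self.assertRaises(ValueError):
                jnp.ones(3) + jnp.ones((2, 3))


suite = unittest.defaultTestLoader.loadTestsFromTestCase(LogAddExpTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
print("all passed:", result.wasSuccessful())
```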
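A short sketch exercising two of the newly added functions, assuming a recent JAX with SciPy installed (SciPy is a JAX dependency). It checks that `jax.scipy.linalg.sqrtm` returns a matrix whose square reproduces the input, and compares `jax.scipy.signal.welch` against `scipy.signal.welch` on the same test signal. The matrix, signal, and `nperseg` value are illustrative choices; `sqrtm` (like `schur`) is believed to be implemented for the CPU backend only.

```python
import numpy as np
import scipy.signal

import jax.numpy as jnp
import jax.scipy.linalg
import jax.scipy.signal

# Matrix square root: r @ r should reproduce a (r is returned as a complex array).
a = jnp.array([[2.0, 1.0], [1.0, 2.0]])
r = jax.scipy.linalg.sqrtm(a)
print(np.round(np.asarray(r @ r).real, 4))

# Welch power spectral density of a 0.1-cycles-per-sample sine wave.
t = np.arange(256)
sig = np.sin(2 * np.pi * 0.1 * t).astype(np.float32)
f_jax, p_jax = jax.scipy.signal.welch(jnp.asarray(sig), fs=1.0, nperseg=64)
f_ref, p_ref = scipy.signal.welch(sig, fs=1.0, nperseg=64)

# The peak of the estimated spectrum should sit near 0.1.
print(f_jax[jnp.argmax(p_jax)])
```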
Jaxlib release v0.3.0
- Changes
- Bazel 5.0.0 is now required to build jaxlib.
- jaxlib version has been bumped to 0.3.0. Please see the design doc
for the explanation.
JAX release v0.3.0
- Changes
- jax version has been bumped to 0.3.0. Please see the design doc
for the explanation.
JAX release v0.2.28
- `jax.jit(f).lower(...).compiler_ir()` now defaults to the MHLO dialect if no `dialect=` is passed.
- `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR `ir.Module` object instead of its string representation.
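A minimal sketch of inspecting the lowered IR, assuming a recent JAX release. In current versions `compiler_ir()` with no `dialect=` still returns an MLIR module object, but the default dialect has since moved from MHLO to StableHLO, so the sketch avoids naming a dialect and uses `Lowered.as_text()` for the printable form. The function `f` and its input are illustrative.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) * 2.0


# Trace and lower without compiling or running f.
lowered = jax.jit(f).lower(jnp.float32(1.0))

# With no dialect= argument, compiler_ir() returns an MLIR module object,
# not a string.
module = lowered.compiler_ir()
print(type(module))

# as_text() gives the human-readable IR.
text = lowered.as_text()
print(text)
```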
Jaxlib v0.1.76
- New features
- Includes precompiled SASS for NVIDIA compute capability 8.0 GPUs (e.g. A100). Removes precompiled SASS for compute capability 6.1 so as not to increase the number of compute capabilities: GPUs with compute capability 6.1 can use the 6.0 SASS.
- With jaxlib 0.1.76, JAX uses the MHLO MLIR dialect as its primary target compiler IR by default.
- Breaking changes
- Support for NumPy 1.18 has been dropped, per the deprecation policy.
Please upgrade to a supported NumPy version.
- Bug fixes
JAX release v0.2.27
- Breaking changes:
- Support for NumPy 1.18 has been dropped, per the [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). Please upgrade to a supported NumPy version.
- The `host_callback` primitives have been simplified to drop the special autodiff handling for `hcb.id_tap` and `id_print`. From now on, only the primals are tapped. The old behavior can be obtained (for a limited time) by setting the `JAX_HOST_CALLBACK_AD_TRANSFORMS` environment variable, or the `--jax_host_callback_ad_transforms` flag. Additionally, added documentation for how to implement the old behavior using JAX custom AD APIs (#8678).
- Sorting now matches the behavior of NumPy for `0.0` and `NaN` regardless of the bit representation. In particular, `0.0` and `-0.0` are now treated as equivalent, where previously `-0.0` was treated as less than `0.0`. Additionally, all `NaN` representations are now treated as equivalent and sorted to the end of the array. Previously, negative `NaN` values were sorted to the front of the array, and `NaN` values with different internal bit representations were not treated as equivalent and were sorted according to those bit patterns (#9178).
- `jax.numpy.unique` now treats `NaN` values in the same way as `np.unique` in NumPy versions 1.21 and newer: at most one `NaN` value will appear in the uniquified output (#9184).
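A small check of the NumPy-compatible `NaN` and signed-zero handling described above, assuming a JAX version that includes #9178 and #9184. It compares `jnp.sort` against `np.sort` on an array containing both zeros and `NaN`s of either sign, then shows that `jnp.unique` keeps a single `NaN`. The input values are illustrative.

```python
import numpy as np
import jax.numpy as jnp

# -np.nan normally has its sign bit set; -0.0 and 0.0 differ only in the sign bit.
x = np.array([np.nan, 1.0, -0.0, 0.0, -np.nan, 2.0], dtype=np.float32)

s = jnp.sort(x)
print(s)  # zeros first (in either order), then 1 and 2, then both NaNs at the end

# assert_array_equal treats -0.0 == 0.0 and matches NaNs by position.
np.testing.assert_array_equal(s, np.sort(x))

u = jnp.unique(jnp.array([1.0, np.nan, 2.0, np.nan]))
print(u)  # at most one NaN survives
```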
- Bug fixes:
- `host_callback` now supports `ad_checkpoint.checkpoint` (#8907).
- New features:
- Added `jax.block_until_ready` (#8941).
- Added a new debugging flag/environment variable `JAX_DUMP_IR_TO=/path`. If set, JAX dumps the MHLO/HLO IR it generates for each computation to a file under the given path.
- Added `jax.ensure_compile_time_eval` to the public API (#7987).
- jax2tf now supports a flag `jax2tf_associative_scan_reductions` to change the lowering for associative reductions, e.g., `jnp.cumsum`, to behave like JAX on CPU and GPU (to use an associative scan). See the jax2tf README for more details (#9189).
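A combined sketch of the three new features above, assuming a recent JAX release. `JAX_DUMP_IR_TO` is set before `jax` is imported because, as far as we know, JAX reads it from the environment when its configuration is initialized; the names and number of dump files vary by version, so the sketch only lists them. `jax.ensure_compile_time_eval` makes operations inside the `with` block run eagerly on concrete values even while `f` is being traced, which is what allows an ordinary Python `if` on `z`. `jax.block_until_ready` waits for asynchronously dispatched results and accepts any pytree. The function `f` and the dictionary passed to `block_until_ready` are illustrative.

```python
import os
import tempfile

# Set before importing jax so the flag is picked up at initialization.
dump_dir = tempfile.mkdtemp()
os.environ["JAX_DUMP_IR_TO"] = dump_dir

import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    with jax.ensure_compile_time_eval():
        # Evaluated eagerly during tracing, so y and z are concrete values.
        y = jnp.sin(3.0)
        z = jnp.sin(y)
        positive = bool(z > 0)  # fine here: z is not a tracer
    return jnp.sin(x) if positive else jnp.cos(x)


# block_until_ready accepts a single array or any pytree of arrays.
out = jax.block_until_ready(f(jnp.float32(1.0)))
tree = jax.block_until_ready({"a": jnp.arange(3), "b": [out, out * 2]})

print(out)
print(sorted(os.listdir(dump_dir)))  # IR files written by JAX, if any
```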
JAX release v0.2.26
- Bug fixes:
- Out-of-bounds indices to `jax.ops.segment_sum` will now be handled with `FILL_OR_DROP` semantics, as documented. This primarily affects the reverse-mode derivative, where gradients corresponding to out-of-bounds indices will now be returned as 0 (#8634).
- jax2tf will force the converted code to use XLA for the code fragments under `jax.jit`, e.g., most `jax.numpy` functions (#7839).
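A minimal sketch of the `FILL_OR_DROP` behavior for `jax.ops.segment_sum`, assuming a recent JAX. With `num_segments=2`, the segment id `5` is out of bounds: its value is dropped from the forward sum and its entry in the gradient is 0. The data and ids are illustrative.

```python
import jax
import jax.numpy as jnp

data = jnp.array([1.0, 2.0, 3.0, 4.0])
segment_ids = jnp.array([0, 1, 1, 5])  # 5 is outside [0, num_segments)


def total(d):
    return jax.ops.segment_sum(d, segment_ids, num_segments=2).sum()


sums = jax.ops.segment_sum(data, segment_ids, num_segments=2)
grads = jax.grad(total)(data)

print(sums)   # [1. 5.]: the out-of-bounds element is dropped
print(grads)  # [1. 1. 1. 0.]: its gradient is filled with 0
```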
Jaxlib release v0.1.75
- New features:
- Support for Python 3.10.
Jaxlib release v0.1.74
JAX release v0.2.25
- New features:
- (Experimental) `jax.distributed.initialize` exposes multi-host GPU backend.
- `jax.random.permutation` supports new `independent` keyword argument (#8430).
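A short sketch contrasting the two modes of `jax.random.permutation` along `axis=1`, assuming a JAX version that supports both the `axis` and `independent` arguments. With the default `independent=False`, one permutation of the columns is applied to every row; with `independent=True`, each row is shuffled on its own. The key and input are illustrative.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jnp.arange(12).reshape(3, 4)

# Default: the same column permutation is applied to all rows.
shared = jax.random.permutation(key, x, axis=1)

# independent=True: every row gets its own permutation.
indep = jax.random.permutation(key, x, axis=1, independent=True)

print(shared)
print(indep)
```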
- Breaking changes:
- Moved `jax.experimental.stax` to `jax.example_libraries.stax`.
- Moved `jax.experimental.optimizers` to `jax.example_libraries.optimizers`.
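Code that imported these modules from `jax.experimental` mainly needs its import line changed. A minimal sketch, assuming a JAX version that still ships `jax.example_libraries`: it builds a tiny `stax` network and takes one SGD step with `optimizers`. The layer sizes, data, and step size are illustrative.

```python
import jax
import jax.numpy as jnp

# Previously: from jax.experimental import optimizers, stax
from jax.example_libraries import optimizers, stax

init_fn, apply_fn = stax.serial(stax.Dense(8), stax.Relu, stax.Dense(1))
out_shape, params = init_fn(jax.random.PRNGKey(0), (-1, 3))

x = jnp.ones((5, 3))
y = jnp.zeros((5, 1))


def loss(p):
    return jnp.mean((apply_fn(p, x) - y) ** 2)


opt_init, opt_update, get_params = optimizers.sgd(step_size=0.1)
state = opt_init(params)
# One gradient step; the first argument is the step index.
state = opt_update(0, jax.grad(loss)(get_params(state)), state)
new_params = get_params(state)

print(out_shape, loss(params), loss(new_params))
```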
- New features:
- Added `jax.lax.linalg.qdwh`.
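QDWH (QR-based dynamically weighted Halley iteration) computes a polar decomposition `x = u @ h`, where `u` is orthogonal (unitary in the complex case) and `h` is symmetric positive semidefinite. A minimal sketch, assuming the recent signature in which `qdwh` takes `is_hermitian` as a keyword and returns `(u, h, num_iters, is_converged)`; the signature may differ in older releases. The input matrix is illustrative.

```python
import numpy as np
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

x = jnp.array([[1.0, 2.0, 0.0],
               [0.0, 1.0, 3.0],
               [4.0, 0.0, 1.0]])

u, h, num_iters, is_converged = lax_linalg.qdwh(x, is_hermitian=False)

# u @ h reconstructs x; u is orthogonal; h is symmetric.
print(np.round(np.asarray(u @ h), 4))
print(int(num_iters), bool(is_converged))
```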