diff --git a/CHANGELOG.md b/CHANGELOG.md index a683ee83b7af..1ebbf585da00 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,6 +50,7 @@ Remember to align the itemized text with the first line of an item within a list * Changes * JAX now supports CUDA 12.1 or newer only. Support for CUDA 11.8 has been dropped. + * JAX now supports NumPy 2.0. ## jax 0.4.25 (Feb 26, 2024) diff --git a/setup.py b/setup.py index 3b465c0ef37b..bcf687dfdf57 100644 --- a/setup.py +++ b/setup.py @@ -22,12 +22,12 @@ project_name = 'jax' -_current_jaxlib_version = '0.4.25' +_current_jaxlib_version = '0.4.26' # The following should be updated with each new jaxlib release. _latest_jaxlib_version_on_pypi = '0.4.25' _default_cuda12_cudnn_version = '89' _available_cuda12_cudnn_versions = [_default_cuda12_cudnn_version] -_libtpu_version = '0.1.dev20240224' +_libtpu_version = '0.1.dev20240403' def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location(