diff --git a/setup.py b/setup.py index e1cb3e2e38b0..ecb06d9da77c 100644 --- a/setup.py +++ b/setup.py @@ -19,12 +19,12 @@ project_name = 'jax' -_current_jaxlib_version = '0.4.28' +_current_jaxlib_version = '0.4.29' # The following should be updated with each new jaxlib release. _latest_jaxlib_version_on_pypi = '0.4.28' _default_cuda12_cudnn_version = '91' _available_cuda12_cudnn_versions = [_default_cuda12_cudnn_version] -_libtpu_version = '0.1.dev20240508' +_libtpu_version = '0.1.dev20240609' def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location(