diff --git a/WORKSPACE b/WORKSPACE index 9ebf45b89fcc..57e590ac5e0b 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -7,10 +7,10 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") # and update the sha256 with the result. http_archive( name = "xla", - sha256 = "8188f528ce2c81b62f6cc01b7b30bee21c276dd79c881c06813fc408d487c48f", - strip_prefix = "xla-d10fb016209af268e86e3b0f0083fd3c283c4592", + sha256 = "617a968b2d4154ef4368e2676c72e2bc9a019be3b6a1941c8dc741d3e5ea3d8e", + strip_prefix = "xla-7a7cee6e31a01d0103c41b753c7e7fe6e0eeece8", urls = [ - "https://github.com/openxla/xla/archive/d10fb016209af268e86e3b0f0083fd3c283c4592.tar.gz", + "https://github.com/openxla/xla/archive/7a7cee6e31a01d0103c41b753c7e7fe6e0eeece8.tar.gz", ], ) diff --git a/setup.py b/setup.py index 0bcb23b78bf7..92ce13a2cba1 100644 --- a/setup.py +++ b/setup.py @@ -19,13 +19,13 @@ from setuptools import setup, find_packages -_current_jaxlib_version = '0.4.10' +_current_jaxlib_version = '0.4.11' # The following should be updated with each new jaxlib release. _latest_jaxlib_version_on_pypi = '0.4.10' _available_cuda11_cudnn_versions = ['82', '86'] _default_cuda11_cudnn_version = '86' _default_cuda12_cudnn_version = '88' -_libtpu_version = '0.1.dev20230511' +_libtpu_version = '0.1.dev20230531' _dct = {} with open('jax/version.py', encoding='utf-8') as f: