Skip to content

Commit

Permalink
Merge pull request #23670 from hawkinsp:postrelease
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675252501
  • Loading branch information
Google-ML-Automation committed Sep 16, 2024
2 parents 8804be0 + ae0e403 commit 8c84e16
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 7 deletions.
20 changes: 18 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Remember to align the itemized text with the first line of an item within a list
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
-->

## jax 0.4.33
## jax 0.4.34

* Deletion:
* `jax.xla_computation` is deleted. It's been 3 months since it's deprecation
Expand All @@ -26,11 +26,24 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* {class}`jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument.
The argument was only used by `xmap` which was removed in 0.4.31.

## jaxlib 0.4.33

## jax 0.4.33 (September 16, 2024)

This is a patch release on top of jax 0.4.32, that fixes two bugs found in that
release.

A TPU-only data corruption bug was found in the version of libtpu pinned by
JAX 0.4.32, which manifested only if multiple TPU slices were present in the
same job, for example, if training on multiple v5e slices.
This release fixes that issue by pinning a fixed version of `libtpu`.

This release fixes an inaccurate result for F64 tanh on CPU (#23590).

## jax 0.4.32 (September 11, 2024)

Note: This release was yanked from PyPi because of a data corruption bug on TPU.
See the 0.4.33 release notes for more details.

* New Functionality
* Added {func}`jax.extend.ffi.ffi_call` and {func}`jax.extend.ffi.ffi_lowering`
to support the use of the new {ref}`ffi-tutorial` to interface with custom
Expand Down Expand Up @@ -94,6 +107,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.

## jaxlib 0.4.32 (September 11, 2024)

Note: This release was yanked from PyPi because of a data corruption bug on TPU.
See the 0.4.33 release notes for more details.

* Breaking changes
* Hermetic CUDA support is added.
Hermetic CUDA uses a specific downloadable version of CUDA instead of the
Expand Down
4 changes: 2 additions & 2 deletions jax/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pathlib
import subprocess

_version = "0.4.33"
_version = "0.4.34"
# The following line is overwritten by build scripts in distributions &
# releases. Do not modify this manually, or jax/jaxlib build will fail.
_release_version: str | None = None
Expand Down Expand Up @@ -133,7 +133,7 @@ def make_release_tree(self, base_dir, files):


__version__ = _get_version_string()
_minimum_jaxlib_version = "0.4.31"
_minimum_jaxlib_version = "0.4.33"

def _version_as_tuple(version_str):
return tuple(int(i) for i in version_str.split(".") if i.isdigit())
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@

project_name = 'jax'

_current_jaxlib_version = '0.4.32'
_current_jaxlib_version = '0.4.33'
# The following should be updated after each new jaxlib release.
_latest_jaxlib_version_on_pypi = '0.4.32'
_libtpu_version = '0.1.dev20240911'
_latest_jaxlib_version_on_pypi = '0.4.33'
_libtpu_version = '0.1.dev20240916'

def load_version_module(pkg_path):
spec = importlib.util.spec_from_file_location(
Expand Down
14 changes: 14 additions & 0 deletions third_party/xla/tanh.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
diff --git a/xla/service/cpu/llvm_ir_runtime.cc b/xla/service/cpu/llvm_ir_runtime.cc
index 89b40b915caa3..25541c16bfd61 100644
--- a/xla/service/cpu/llvm_ir_runtime.cc
+++ b/xla/service/cpu/llvm_ir_runtime.cc
@@ -410,7 +410,8 @@ void RewriteIRRuntimeFunctions(llvm::Module* module,
rewrite_calls(kTanhV8F32SymbolName, GenerateVF32Tanh, /*vector_width=*/8);
rewrite_calls(kTanhV16F32SymbolName, GenerateVF32Tanh, /*vector_width=*/16);

- rewrite_calls("tanh", GenerateVF64Tanh, /*vector_width=*/1);
+ // TODO(penporn): Re-enable after fixing JAX issue #23590.
+ // rewrite_calls("tanh", GenerateVF64Tanh, /*vector_width=*/1);

rewrite_calls("expf", GenerateVF32Exp, /*vector_width=*/1);
rewrite_calls("llvm.exp.f32", GenerateVF32Exp, /*vector_width=*/1);
3 changes: 3 additions & 0 deletions third_party/xla/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def repo():
sha256 = XLA_SHA256,
strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT),
urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)),
patch_file = [
"//third_party/xla:tanh.patch",
],
)

# For development, one often wants to make changes to the TF repository as well
Expand Down

0 comments on commit 8c84e16

Please sign in to comment.