Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge release/0.4.33 into main and update versions. #23670

Merged
merged 2 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Remember to align the itemized text with the first line of an item within a list
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
-->

## jax 0.4.33
## jax 0.4.34

* Deletion:
* `jax.xla_computation` is deleted. It's been 3 months since it's deprecation
Expand All @@ -26,11 +26,24 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* {class}`jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument.
The argument was only used by `xmap` which was removed in 0.4.31.

## jaxlib 0.4.33

## jax 0.4.33 (September 16, 2024)

This is a patch release on top of jax 0.4.32, that fixes two bugs found in that
release.

A TPU-only data corruption bug was found in the version of libtpu pinned by
JAX 0.4.32, which manifested only if multiple TPU slices were present in the
same job, for example, if training on multiple v5e slices.
This release fixes that issue by pinning a fixed version of `libtpu`.

This release fixes an inaccurate result for F64 tanh on CPU (#23590).

## jax 0.4.32 (September 11, 2024)

Note: This release was yanked from PyPi because of a data corruption bug on TPU.
See the 0.4.33 release notes for more details.

* New Functionality
* Added {func}`jax.extend.ffi.ffi_call` and {func}`jax.extend.ffi.ffi_lowering`
to support the use of the new {ref}`ffi-tutorial` to interface with custom
Expand Down Expand Up @@ -94,6 +107,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.

## jaxlib 0.4.32 (September 11, 2024)

Note: This release was yanked from PyPi because of a data corruption bug on TPU.
See the 0.4.33 release notes for more details.

* Breaking changes
* Hermetic CUDA support is added.
Hermetic CUDA uses a specific downloadable version of CUDA instead of the
Expand Down
4 changes: 2 additions & 2 deletions jax/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pathlib
import subprocess

_version = "0.4.33"
_version = "0.4.34"
# The following line is overwritten by build scripts in distributions &
# releases. Do not modify this manually, or jax/jaxlib build will fail.
_release_version: str | None = None
Expand Down Expand Up @@ -133,7 +133,7 @@ def make_release_tree(self, base_dir, files):


__version__ = _get_version_string()
_minimum_jaxlib_version = "0.4.31"
_minimum_jaxlib_version = "0.4.33"

def _version_as_tuple(version_str):
return tuple(int(i) for i in version_str.split(".") if i.isdigit())
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@

project_name = 'jax'

_current_jaxlib_version = '0.4.32'
_current_jaxlib_version = '0.4.33'
# The following should be updated after each new jaxlib release.
_latest_jaxlib_version_on_pypi = '0.4.32'
_libtpu_version = '0.1.dev20240911'
_latest_jaxlib_version_on_pypi = '0.4.33'
_libtpu_version = '0.1.dev20240916'

def load_version_module(pkg_path):
spec = importlib.util.spec_from_file_location(
Expand Down
14 changes: 14 additions & 0 deletions third_party/xla/tanh.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
diff --git a/xla/service/cpu/llvm_ir_runtime.cc b/xla/service/cpu/llvm_ir_runtime.cc
index 89b40b915caa3..25541c16bfd61 100644
--- a/xla/service/cpu/llvm_ir_runtime.cc
+++ b/xla/service/cpu/llvm_ir_runtime.cc
@@ -410,7 +410,8 @@ void RewriteIRRuntimeFunctions(llvm::Module* module,
rewrite_calls(kTanhV8F32SymbolName, GenerateVF32Tanh, /*vector_width=*/8);
rewrite_calls(kTanhV16F32SymbolName, GenerateVF32Tanh, /*vector_width=*/16);

- rewrite_calls("tanh", GenerateVF64Tanh, /*vector_width=*/1);
+ // TODO(penporn): Re-enable after fixing JAX issue #23590.
+ // rewrite_calls("tanh", GenerateVF64Tanh, /*vector_width=*/1);

rewrite_calls("expf", GenerateVF32Exp, /*vector_width=*/1);
rewrite_calls("llvm.exp.f32", GenerateVF32Exp, /*vector_width=*/1);
3 changes: 3 additions & 0 deletions third_party/xla/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def repo():
sha256 = XLA_SHA256,
strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT),
urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)),
patch_file = [
"//third_party/xla:tanh.patch",
],
)

# For development, one often wants to make changes to the TF repository as well
Expand Down