Skip to content

Commit

Permalink
Remove code that existed to support jaxlib < 0.4.32.
Browse files Browse the repository at this point in the history
New minimum versions:
* jaxlib 0.4.32
* xla_extension_version 283
* mlir_api_version 57

PiperOrigin-RevId: 673526233
  • Loading branch information
hawkinsp authored and Google-ML-Automation committed Sep 16, 2024
1 parent 29163fc commit 28265e2
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 112 deletions.
4 changes: 1 addition & 3 deletions jax/_src/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from jax._src import traceback_util
from jax._src.interpreters import mlir
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
import numpy as np

Expand Down Expand Up @@ -157,8 +156,7 @@ def get_compile_options(
build_options = compile_options.executable_build_options
build_options.use_spmd_partitioning = use_spmd_partitioning
build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
if xla_extension_version >= 280:
build_options.use_shardy_partitioner = use_shardy_partitioner
build_options.use_shardy_partitioner = use_shardy_partitioner
if fdo_profile is not None:
build_options.fdo_profile = fdo_profile
if use_auto_spmd_partitioning:
Expand Down
9 changes: 2 additions & 7 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
from jax._src.interpreters import xla
from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
Expand Down Expand Up @@ -3022,12 +3021,8 @@ def aot_cache_miss(*args, **kwargs):
self.unsafe_call.name, None, aot_cache_miss, [], [], [],
tree_util.dispatch_registry, cc_shard_arg)

if xla_extension_version < 282:
def cc_shard_arg(x, sharding):
return shard_args([sharding], [None], [x])[0]
else:
def cc_shard_arg(x, sharding, layout): # type: ignore
return shard_args([sharding], [layout], [x])[0]
def cc_shard_arg(x, sharding, layout): # type: ignore
return shard_args([sharding], [layout], [x])[0]


def check_arg_avals_for_call(ref_avals, arg_avals,
Expand Down
29 changes: 10 additions & 19 deletions jax/_src/lax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,11 +514,7 @@ def _cholesky_cpu_lowering(ctx, operand):
out_aval, = ctx.avals_out
batch_dims = operand_aval.shape[:-2]
op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
# TODO(b/344892332): Remove the check after the compatibility period.
if jaxlib_version < (0, 4, 31):
ctx_arg = ()
else:
ctx_arg = (ctx,)
ctx_arg = (ctx,)
result, info = lapack.potrf_hlo(*ctx_arg, operand_aval.dtype, operand,
lower=True, a_shape_vals=op_shape_vals)

Expand Down Expand Up @@ -556,7 +552,7 @@ def _cholesky_update_abstract_eval(r_matrix, w_vector):

def _cholesky_update_gpu_lowering_rule(target_name_prefix, ctx, r_matrix, w_vector):
# TODO(b/360781533): Remove guard after 3 week forward compatibility period.
if ctx.is_forward_compat() or jaxlib_version < (0, 4, 32):
if ctx.is_forward_compat():
r_matrix_aval, _ = ctx.avals_in
try:
[platform] = ctx.module_context.platforms
Expand Down Expand Up @@ -726,8 +722,7 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
out_aval = ctx.avals_out[0]
batch_dims = operand_aval.shape[:-2]
op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
# TODO(b/344892332): Remove the conditional after the compatibility period.
ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else ()
ctx_args = (ctx,)
w, vl, vr, info = lapack.geev_hlo(*ctx_args, operand_aval.dtype, operand,
input_shape_vals=op_shape_vals,
jobvl=compute_left_eigenvectors,
Expand Down Expand Up @@ -937,8 +932,7 @@ def _eigh_cpu_gpu_lowering(
op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
cpu_args = []
if platform == "cpu":
# TODO(b/344892332): Remove the conditional after the compatibility period.
ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else ()
ctx_args = (ctx,)
cpu_args.extend(ctx_args)
v, w, info = syevd_impl(*cpu_args, operand_aval.dtype, operand,
a_shape_vals=op_shape_vals, lower=lower)
Expand Down Expand Up @@ -1511,9 +1505,9 @@ def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand, *, platform: str,
info_aval = ShapedArray(batch_dims, np.dtype(np.int32))
m = operand_aval.shape[-2]

# TODO(b/357034884): Remove version gate once jaxlib 0.4.32 is the minimum
# version and the forward compat flag after the 3 week compatibility window.
if jaxlib_version < (0, 4, 32) or ctx.is_forward_compat():
# TODO(b/357034884): Remove version gate on the forward compat flag after the
# 3 week compatibility window.
if ctx.is_forward_compat():
if not is_constant_shape(operand_aval.shape[-2:]):
raise NotImplementedError(
"Shape polymorphism for native lowering for lu on CPU and GPU is "
Expand Down Expand Up @@ -1757,9 +1751,8 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a, *,
a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a)
else:
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape)
# TODO(b/344892332): Remove the conditional after the compatibility period
ctx_args = (
(ctx,) if platform == "cpu" and jaxlib_version >= (0, 4, 32) else ()
(ctx,) if platform == "cpu" else ()
)
a_out, taus, *maybe_info_geqrf = geqrf_impl(
*ctx_args, a_aval.dtype, a, a_shape_vals=a_shape_vals
Expand Down Expand Up @@ -1881,9 +1874,8 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus, *,
f"on GPU is not implemented; b/261671778; {a_aval.shape}")
a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus)
else:
# TODO(b/344892332): Remove the conditional after the compatibility period
ctx_args = (
(ctx,) if platform == "cpu" and jaxlib_version >= (0, 4, 32) else ()
(ctx,) if platform == "cpu" else ()
)
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape)
tau_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, taus_aval.shape)
Expand Down Expand Up @@ -2152,8 +2144,7 @@ def _svd_cpu_gpu_lowering(
compute_uv=compute_uv)
else:
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
# TODO(b/344892332): Remove the conditional after the compatibility period.
ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else ()
ctx_args = (ctx,)
s, u, vt, info = gesvd_impl(*ctx_args, operand_aval.dtype, operand,
full_matrices=full_matrices,
compute_uv=compute_uv,
Expand Down
85 changes: 32 additions & 53 deletions tests/export_back_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import cuda_versions
from jax._src.lib import version as jaxlib_version

config.parse_flags_with_absl()

Expand Down Expand Up @@ -190,14 +189,11 @@ def test_cpu_cholesky_lapack_potrf(self, dtype_name="f32"):
atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]

data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2023_06_19[dtype_name])
# TODO(b/344892332): Remove the check after the compatibility period.
has_xla_ffi_support = jaxlib_version >= (0, 4, 31)
self.run_one_test(func, data, rtol=rtol, atol=atol)
if has_xla_ffi_support:
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2024_05_31[dtype_name])
self.run_one_test(func, data, rtol=rtol, atol=atol)
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2024_05_31[dtype_name])
self.run_one_test(func, data, rtol=rtol, atol=atol)

@parameterized.named_parameters(
dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
Expand Down Expand Up @@ -258,14 +254,11 @@ def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array):

self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=check_eig_results)
# TODO(b/344892332): Remove the check after the compatibility period.
has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
if has_xla_ffi_support:
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(cpu_eig_lapack_geev.data_2024_08_19[dtype_name])
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=check_eig_results)
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(cpu_eig_lapack_geev.data_2024_08_19[dtype_name])
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=check_eig_results)

@staticmethod
def eigh_input(shape, dtype):
Expand Down Expand Up @@ -316,14 +309,11 @@ def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"):
atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_eigh_results, operand))
# TODO(b/344892332): Remove the check after the compatibility period.
has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
if has_xla_ffi_support:
# FFI Kernel test
with config.export_ignore_forward_compatibility(True):
data = self.load_testdata(cpu_eigh_lapack_syev.data_2024_08_19[dtype_name])
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_eigh_results, operand))
# FFI Kernel test
with config.export_ignore_forward_compatibility(True):
data = self.load_testdata(cpu_eigh_lapack_syev.data_2024_08_19[dtype_name])
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_eigh_results, operand))

@parameterized.named_parameters(
dict(testcase_name=f"_dtype={dtype_name}_{variant}",
Expand Down Expand Up @@ -385,8 +375,6 @@ def test_cuda_lu_pivots_to_permutation(self):
def test_cuda_lu_lapack_getrf(self, dtype_name:str):
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
self.skipTest("Test disabled for x32 mode")
if jaxlib_version < (0, 4, 32):
self.skipTest("Not implemented in older versions of jaxlib")
dtype = dict(f32=np.float32, f64=np.float64,
c64=np.complex64, c128=np.complex128)[dtype_name]
shape = (3, 4)
Expand Down Expand Up @@ -416,15 +404,12 @@ def test_cpu_qr_lapack_geqrf(self, dtype_name="f32"):
data = self.load_testdata(cpu_qr_lapack_geqrf.data_2023_03_17[dtype_name])
rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
self.run_one_test(func, data, rtol=rtol)
# TODO(b/344892332): Remove the check after the compatibility period.
has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
if has_xla_ffi_support:
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(
cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name]
)
self.run_one_test(func, data, rtol=rtol)
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(
cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name]
)
self.run_one_test(func, data, rtol=rtol)

@parameterized.named_parameters(
dict(testcase_name=f"_dtype={dtype_name}_{batched}",
Expand Down Expand Up @@ -502,14 +487,11 @@ def test_cpu_lu_lapack_getrf(self, dtype_name:str):
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_lu_results, operand,
dtype=dtype))
# TODO(b/344892332): Remove the check after the compatibility period.
has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
if has_xla_ffi_support:
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(cpu_lu_lapack_getrf.data_2024_05_31[dtype_name])
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_lu_results, operand,
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(cpu_lu_lapack_getrf.data_2024_05_31[dtype_name])
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_lu_results, operand,
dtype=dtype))

def check_svd_results(self, input, res_run, res_exp,
Expand Down Expand Up @@ -629,16 +611,13 @@ def func(input):
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_svd_results,
input))
# TODO(b/344892332): Remove the check after the compatibility period.
has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
if has_xla_ffi_support:
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(
cpu_svd_lapack_gesdd.data_2024_08_13[dtype_name]
)
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_svd_results, input))
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(
cpu_svd_lapack_gesdd.data_2024_08_13[dtype_name]
)
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_svd_results, input))

@jtu.parameterized_filterable(
kwargs=[
Expand Down
3 changes: 0 additions & 3 deletions tests/layout_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import contextlib
import math
from functools import partial
import unittest
from absl.testing import absltest
import numpy as np

Expand Down Expand Up @@ -511,8 +510,6 @@ def g(x):
'Layout passed to jit does not match the layout on the respective arg'):
g(arr)

@unittest.skipIf(xla_extension_version < 282,
"Requires xla_extension_version >= 282")
def test_in_layouts_jit_jnp_input(self):
major_last_layout = DLL(major_to_minor=(1, 0))
sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
Expand Down
4 changes: 0 additions & 4 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from functools import partial
import itertools
import unittest

import numpy as np
import scipy
Expand Down Expand Up @@ -2194,9 +2193,6 @@ def testHilbert(self, n):
symmetrize_output=[True, False],
)
@jtu.skip_on_devices("tpu")
@unittest.skipIf(
jax._src.lib.version < (0, 4, 32), "requires jaxlib >= 0.4.32"
)
def testSymmetricProduct(self, shape, dtype, symmetrize_output):
rng = jtu.rand_default(self.rng())
batch_size = 10
Expand Down
11 changes: 0 additions & 11 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
from jax._src.lib.mlir import dialects
from jax._src import xla_bridge
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib import xla_extension
from jax._src.util import curry, unzip2

Expand Down Expand Up @@ -4433,8 +4432,6 @@ def f(x):
"Compiled object called with input sharding.*does not match"):
compiled(cpu_arr)

@unittest.skipIf(xla_extension_version < 281,
'Requires xla_extension_version >= 281')
def test_different_devices_wsc_abstract_mesh_cache_hit(self):
if jax.device_count() < 4:
self.skipTest('Requires >=4 devices')
Expand Down Expand Up @@ -4463,8 +4460,6 @@ def f(x):
self.assertEqual(lowering_count[0], 1)
self.assertEqual(compilation_count[0], 2) # 2 misses since devices differ.

@unittest.skipIf(xla_extension_version < 281,
'Requires xla_extension_version >= 281')
def test_wsc_abstract_mesh(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
Expand All @@ -4484,8 +4479,6 @@ def f(x):
self.assertArraysEqual(out_eager, np_inp * 2)
self.assertEqual(out_eager.sharding, NamedSharding(mesh, P('x')))

@unittest.skipIf(xla_extension_version < 281,
'Requires xla_extension_version >= 281')
def test_wsc_sds_abstract_mesh(self):
mesh = jtu.create_mesh((2,), 'x')
s = NamedSharding(mesh, P())
Expand All @@ -4499,8 +4492,6 @@ def f(x):
sds = jax.ShapeDtypeStruct((8, 2), np.float32, sharding=s)
f.eval_shape(sds) # doesn't crash

@unittest.skipIf(xla_extension_version < 281,
'Requires xla_extension_version >= 281')
def test_wsc_vmap_abstract_mesh(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
Expand All @@ -4517,8 +4508,6 @@ def f(x):
out2 = jax.jit(jax.vmap(f, spmd_axis_name='y'))(arr)
self.assertEqual(out2.sharding, NamedSharding(mesh, P('y', 'x')))

@unittest.skipIf(xla_extension_version < 281,
'Requires xla_extension_version >= 281')
def test_wsc_abstract_mesh_errors(self):
mesh = jtu.create_mesh((2,), ('x',))
np_inp = np.arange(8)
Expand Down
5 changes: 0 additions & 5 deletions tests/shape_poly_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2843,11 +2843,6 @@ def test_vmap_error(self):
((2, 3, 8, 4), "b1, b2, ..."),
((2, 3, 4, 5), "b1, b2, m, n"),
]
# TODO(danfm): Remove once jaxlib v0.4.32 is the minimum version.
# jaxlib versions before 0.4.32 require a static shape for the non-batch
# dimensions because these are used for computing the "permuation_size"
# which is passed to lu_pivots_to_permutation.
if jaxlib_version >= (0, 4, 32) or not poly.endswith("m, n")
],
[
# The random primitive tests, with threefry (both partitionable and
Expand Down
Loading

0 comments on commit 28265e2

Please sign in to comment.