From 28265e2c42fddacda42474f0f8f79e4b8c1a10a7 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 11 Sep 2024 13:59:24 -0700 Subject: [PATCH] Remove code that existed to support jaxlib < 0.4.32. New minimum versions: * jaxlib 0.4.32 * xla_extension_version 283 * mlir_api_version 57 PiperOrigin-RevId: 673526233 --- jax/_src/compiler.py | 4 +- jax/_src/interpreters/pxla.py | 9 +--- jax/_src/lax/linalg.py | 29 ++++------- tests/export_back_compat_test.py | 85 ++++++++++++-------------------- tests/layout_test.py | 3 -- tests/linalg_test.py | 4 -- tests/pjit_test.py | 11 ----- tests/shape_poly_test.py | 5 -- tests/shard_map_test.py | 7 --- 9 files changed, 45 insertions(+), 112 deletions(-) diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 81457f1cbd07..108741b5f8fd 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -33,7 +33,6 @@ from jax._src import traceback_util from jax._src.interpreters import mlir from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir import numpy as np @@ -157,8 +156,7 @@ def get_compile_options( build_options = compile_options.executable_build_options build_options.use_spmd_partitioning = use_spmd_partitioning build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning - if xla_extension_version >= 280: - build_options.use_shardy_partitioner = use_shardy_partitioner + build_options.use_shardy_partitioner = use_shardy_partitioner if fdo_profile is not None: build_options.fdo_profile = fdo_profile if use_auto_spmd_partitioning: diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 882f71d58671..9c3eed0253ad 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -61,7 +61,6 @@ from jax._src.interpreters import xla from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.partition_spec import PartitionSpec @@ -3022,12 +3021,8 @@ def aot_cache_miss(*args, **kwargs): self.unsafe_call.name, None, aot_cache_miss, [], [], [], tree_util.dispatch_registry, cc_shard_arg) -if xla_extension_version < 282: - def cc_shard_arg(x, sharding): - return shard_args([sharding], [None], [x])[0] -else: - def cc_shard_arg(x, sharding, layout): # type: ignore - return shard_args([sharding], [layout], [x])[0] +def cc_shard_arg(x, sharding, layout): # type: ignore + return shard_args([sharding], [layout], [x])[0] def check_arg_avals_for_call(ref_avals, arg_avals, diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 0cc0e774af53..8752e0b6d1de 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -514,11 +514,7 @@ def _cholesky_cpu_lowering(ctx, operand): out_aval, = ctx.avals_out batch_dims = operand_aval.shape[:-2] op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - # TODO(b/344892332): Remove the check after the compatibility period. - if jaxlib_version < (0, 4, 31): - ctx_arg = () - else: - ctx_arg = (ctx,) + ctx_arg = (ctx,) result, info = lapack.potrf_hlo(*ctx_arg, operand_aval.dtype, operand, lower=True, a_shape_vals=op_shape_vals) @@ -556,7 +552,7 @@ def _cholesky_update_abstract_eval(r_matrix, w_vector): def _cholesky_update_gpu_lowering_rule(target_name_prefix, ctx, r_matrix, w_vector): # TODO(b/360781533): Remove guard after 3 week forward compatibility period. - if ctx.is_forward_compat() or jaxlib_version < (0, 4, 32): + if ctx.is_forward_compat(): r_matrix_aval, _ = ctx.avals_in try: [platform] = ctx.module_context.platforms @@ -726,8 +722,7 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors, out_aval = ctx.avals_out[0] batch_dims = operand_aval.shape[:-2] op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - # TODO(b/344892332): Remove the conditional after the compatibility period. - ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else () + ctx_args = (ctx,) w, vl, vr, info = lapack.geev_hlo(*ctx_args, operand_aval.dtype, operand, input_shape_vals=op_shape_vals, jobvl=compute_left_eigenvectors, @@ -937,8 +932,7 @@ def _eigh_cpu_gpu_lowering( op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) cpu_args = [] if platform == "cpu": - # TODO(b/344892332): Remove the conditional after the compatibility period. - ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else () + ctx_args = (ctx,) cpu_args.extend(ctx_args) v, w, info = syevd_impl(*cpu_args, operand_aval.dtype, operand, a_shape_vals=op_shape_vals, lower=lower) @@ -1511,9 +1505,9 @@ def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand, *, platform: str, info_aval = ShapedArray(batch_dims, np.dtype(np.int32)) m = operand_aval.shape[-2] - # TODO(b/357034884): Remove version gate once jaxlib 0.4.32 is the minimum - # version and the forward compat flag after the 3 week compatibility window. - if jaxlib_version < (0, 4, 32) or ctx.is_forward_compat(): + # TODO(b/357034884): Remove version gate on the forward compat flag after the + # 3 week compatibility window. + if ctx.is_forward_compat(): if not is_constant_shape(operand_aval.shape[-2:]): raise NotImplementedError( "Shape polymorphism for native lowering for lu on CPU and GPU is " @@ -1757,9 +1751,8 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a, *, a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a) else: a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape) - # TODO(b/344892332): Remove the conditional after the compatibility period ctx_args = ( - (ctx,) if platform == "cpu" and jaxlib_version >= (0, 4, 32) else () + (ctx,) if platform == "cpu" else () ) a_out, taus, *maybe_info_geqrf = geqrf_impl( *ctx_args, a_aval.dtype, a, a_shape_vals=a_shape_vals @@ -1881,9 +1874,8 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus, *, f"on GPU is not implemented; b/261671778; {a_aval.shape}") a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus) else: - # TODO(b/344892332): Remove the conditional after the compatibility period ctx_args = ( - (ctx,) if platform == "cpu" and jaxlib_version >= (0, 4, 32) else () + (ctx,) if platform == "cpu" else () ) a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape) tau_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, taus_aval.shape) @@ -2152,8 +2144,7 @@ def _svd_cpu_gpu_lowering( compute_uv=compute_uv) else: a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - # TODO(b/344892332): Remove the conditional after the compatibility period. - ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else () + ctx_args = (ctx,) s, u, vt, info = gesvd_impl(*ctx_args, operand_aval.dtype, operand, full_matrices=full_matrices, compute_uv=compute_uv, diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 4e7898d57fe0..103357ac18ac 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -66,7 +66,6 @@ from jax._src import config from jax._src import test_util as jtu from jax._src.lib import cuda_versions -from jax._src.lib import version as jaxlib_version config.parse_flags_with_absl() @@ -190,14 +189,11 @@ def test_cpu_cholesky_lapack_potrf(self, dtype_name="f32"): atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name] data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2023_06_19[dtype_name]) - # TODO(b/344892332): Remove the check after the compatibility period. - has_xla_ffi_support = jaxlib_version >= (0, 4, 31) self.run_one_test(func, data, rtol=rtol, atol=atol) - if has_xla_ffi_support: - with config.export_ignore_forward_compatibility(True): - # FFI Kernel test - data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2024_05_31[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol) + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2024_05_31[dtype_name]) + self.run_one_test(func, data, rtol=rtol, atol=atol) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) @@ -258,14 +254,11 @@ def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=check_eig_results) - # TODO(b/344892332): Remove the check after the compatibility period. - has_xla_ffi_support = jaxlib_version >= (0, 4, 32) - if has_xla_ffi_support: - with config.export_ignore_forward_compatibility(True): - # FFI Kernel test - data = self.load_testdata(cpu_eig_lapack_geev.data_2024_08_19[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=check_eig_results) + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata(cpu_eig_lapack_geev.data_2024_08_19[dtype_name]) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=check_eig_results) @staticmethod def eigh_input(shape, dtype): @@ -316,14 +309,11 @@ def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"): atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name] self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=partial(self.check_eigh_results, operand)) - # TODO(b/344892332): Remove the check after the compatibility period. - has_xla_ffi_support = jaxlib_version >= (0, 4, 32) - if has_xla_ffi_support: - # FFI Kernel test - with config.export_ignore_forward_compatibility(True): - data = self.load_testdata(cpu_eigh_lapack_syev.data_2024_08_19[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=partial(self.check_eigh_results, operand)) + # FFI Kernel test + with config.export_ignore_forward_compatibility(True): + data = self.load_testdata(cpu_eigh_lapack_syev.data_2024_08_19[dtype_name]) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=partial(self.check_eigh_results, operand)) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}_{variant}", @@ -385,8 +375,6 @@ def test_cuda_lu_pivots_to_permutation(self): def test_cuda_lu_lapack_getrf(self, dtype_name:str): if not config.enable_x64.value and dtype_name in ["f64", "c128"]: self.skipTest("Test disabled for x32 mode") - if jaxlib_version < (0, 4, 32): - self.skipTest("Not implemented in older versions of jaxlib") dtype = dict(f32=np.float32, f64=np.float64, c64=np.complex64, c128=np.complex128)[dtype_name] shape = (3, 4) @@ -416,15 +404,12 @@ def test_cpu_qr_lapack_geqrf(self, dtype_name="f32"): data = self.load_testdata(cpu_qr_lapack_geqrf.data_2023_03_17[dtype_name]) rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name] self.run_one_test(func, data, rtol=rtol) - # TODO(b/344892332): Remove the check after the compatibility period. - has_xla_ffi_support = jaxlib_version >= (0, 4, 32) - if has_xla_ffi_support: - with config.export_ignore_forward_compatibility(True): - # FFI Kernel test - data = self.load_testdata( - cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name] - ) - self.run_one_test(func, data, rtol=rtol) + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata( + cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name] + ) + self.run_one_test(func, data, rtol=rtol) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}_{batched}", @@ -502,14 +487,11 @@ def test_cpu_lu_lapack_getrf(self, dtype_name:str): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=partial(self.check_lu_results, operand, dtype=dtype)) - # TODO(b/344892332): Remove the check after the compatibility period. - has_xla_ffi_support = jaxlib_version >= (0, 4, 32) - if has_xla_ffi_support: - with config.export_ignore_forward_compatibility(True): - # FFI Kernel test - data = self.load_testdata(cpu_lu_lapack_getrf.data_2024_05_31[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=partial(self.check_lu_results, operand, + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata(cpu_lu_lapack_getrf.data_2024_05_31[dtype_name]) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=partial(self.check_lu_results, operand, dtype=dtype)) def check_svd_results(self, input, res_run, res_exp, @@ -629,16 +611,13 @@ def func(input): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=partial(self.check_svd_results, input)) - # TODO(b/344892332): Remove the check after the compatibility period. - has_xla_ffi_support = jaxlib_version >= (0, 4, 32) - if has_xla_ffi_support: - with config.export_ignore_forward_compatibility(True): - # FFI Kernel test - data = self.load_testdata( - cpu_svd_lapack_gesdd.data_2024_08_13[dtype_name] - ) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=partial(self.check_svd_results, input)) + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata( + cpu_svd_lapack_gesdd.data_2024_08_13[dtype_name] + ) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=partial(self.check_svd_results, input)) @jtu.parameterized_filterable( kwargs=[ diff --git a/tests/layout_test.py b/tests/layout_test.py index f14120e46116..1d18179ccfee 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -15,7 +15,6 @@ import contextlib import math from functools import partial -import unittest from absl.testing import absltest import numpy as np @@ -511,8 +510,6 @@ def g(x): 'Layout passed to jit does not match the layout on the respective arg'): g(arr) - @unittest.skipIf(xla_extension_version < 282, - "Requires xla_extension_version >= 282") def test_in_layouts_jit_jnp_input(self): major_last_layout = DLL(major_to_minor=(1, 0)) sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0]) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 4dcdeb19e1ef..446e10abd097 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -16,7 +16,6 @@ from functools import partial import itertools -import unittest import numpy as np import scipy @@ -2194,9 +2193,6 @@ def testHilbert(self, n): symmetrize_output=[True, False], ) @jtu.skip_on_devices("tpu") - @unittest.skipIf( - jax._src.lib.version < (0, 4, 32), "requires jaxlib >= 0.4.32" - ) def testSymmetricProduct(self, shape, dtype, symmetrize_output): rng = jtu.rand_default(self.rng()) batch_size = 10 diff --git a/tests/pjit_test.py b/tests/pjit_test.py index dbb867ab9a39..6c022653581d 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -56,7 +56,6 @@ from jax._src.lib.mlir import dialects from jax._src import xla_bridge from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.lib import xla_extension from jax._src.util import curry, unzip2 @@ -4433,8 +4432,6 @@ def f(x): "Compiled object called with input sharding.*does not match"): compiled(cpu_arr) - @unittest.skipIf(xla_extension_version < 281, - 'Requires xla_extension_version >= 281') def test_different_devices_wsc_abstract_mesh_cache_hit(self): if jax.device_count() < 4: self.skipTest('Requires >=4 devices') @@ -4463,8 +4460,6 @@ def f(x): self.assertEqual(lowering_count[0], 1) self.assertEqual(compilation_count[0], 2) # 2 misses since devices differ. - @unittest.skipIf(xla_extension_version < 281, - 'Requires xla_extension_version >= 281') def test_wsc_abstract_mesh(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) @@ -4484,8 +4479,6 @@ def f(x): self.assertArraysEqual(out_eager, np_inp * 2) self.assertEqual(out_eager.sharding, NamedSharding(mesh, P('x'))) - @unittest.skipIf(xla_extension_version < 281, - 'Requires xla_extension_version >= 281') def test_wsc_sds_abstract_mesh(self): mesh = jtu.create_mesh((2,), 'x') s = NamedSharding(mesh, P()) @@ -4499,8 +4492,6 @@ def f(x): sds = jax.ShapeDtypeStruct((8, 2), np.float32, sharding=s) f.eval_shape(sds) # doesn't crash - @unittest.skipIf(xla_extension_version < 281, - 'Requires xla_extension_version >= 281') def test_wsc_vmap_abstract_mesh(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) @@ -4517,8 +4508,6 @@ def f(x): out2 = jax.jit(jax.vmap(f, spmd_axis_name='y'))(arr) self.assertEqual(out2.sharding, NamedSharding(mesh, P('y', 'x'))) - @unittest.skipIf(xla_extension_version < 281, - 'Requires xla_extension_version >= 281') def test_wsc_abstract_mesh_errors(self): mesh = jtu.create_mesh((2,), ('x',)) np_inp = np.arange(8) diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 323d44b542d6..27199c874332 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -2843,11 +2843,6 @@ def test_vmap_error(self): ((2, 3, 8, 4), "b1, b2, ..."), ((2, 3, 4, 5), "b1, b2, m, n"), ] - # TODO(danfm): Remove once jaxlib v0.4.32 is the minimum version. - # jaxlib versions before 0.4.32 require a static shape for the non-batch - # dimensions because these are used for computing the "permuation_size" - # which is passed to lu_pivots_to_permutation. - if jaxlib_version >= (0, 4, 32) or not poly.endswith("m, n") ], [ # The random primitive tests, with threefry (both partitionable and diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 0c6848f94c03..3d9b567e2ef4 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -45,7 +45,6 @@ from jax._src import linear_util as lu from jax._src import tree_util import jax.numpy as jnp -from jax._src.lib import xla_extension_version from jax.experimental.custom_partitioning import custom_partitioning from jax.experimental.shard_map import shard_map @@ -777,8 +776,6 @@ def with_capture(y_slice): # is over an axis of size 2. This is a problem at the moment. jax.make_jaxpr(mapped)(x, y).jaxpr - @unittest.skipIf(xla_extension_version < 281, - 'Requires xla_extension_version >= 281') def test_shard_map_abstract_mesh(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) @@ -803,8 +800,6 @@ def f(x): self.assertArraysEqual(out2, np_inp) self.assertEqual(out2.sharding, NamedSharding(mesh, P('x'))) - @unittest.skipIf(xla_extension_version < 281, - 'Requires xla_extension_version >= 281') def test_different_devices_shmap_abstract_mesh_cache_hit(self): if jax.device_count() < 4: self.skipTest('Requires >=4 devices') @@ -835,8 +830,6 @@ def f(x): self.assertEqual(lowering_count[0], 1) self.assertEqual(compilation_count[0], 2) # 2 misses since devices differ. - @unittest.skipIf(xla_extension_version < 281, - 'Requires xla_extension_version >= 281') def test_shmap_abstract_mesh_errors(self): mesh = jtu.create_mesh((2,), ('x',)) np_inp = np.arange(8)