Remove code that existed to support jaxlib < 0.4.32. #23582

Merged · 1 commit · Sep 16, 2024

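For context: JAX gates features that require a newer jaxlib/XLA build behind runtime version checks (`jax._src.lib.version`, `xla_extension_version`) and deletes those checks once the minimum supported jaxlib catches up. This PR performs that deletion for the 0.4.32 floor. A minimal, hypothetical sketch of the pattern being removed — `legacy_impl` and `new_impl` are illustrative placeholders, not real JAX functions:

```python
# Hypothetical sketch of the jaxlib version-gating pattern removed by this PR.
from jax._src.lib import version as jaxlib_version


def legacy_impl(operand):
  return operand  # stand-in for the pre-0.4.32 code path


def new_impl(ctx, operand):
  return operand  # stand-in for the code path that takes the lowering context


def lowering_rule_before(ctx, operand):
  # Before: pick the code path based on the installed jaxlib.
  if jaxlib_version >= (0, 4, 32):
    return new_impl(ctx, operand)
  return legacy_impl(operand)


def lowering_rule_after(ctx, operand):
  # After: jaxlib >= 0.4.32 is the minimum, so only the new path remains.
  return new_impl(ctx, operand)
```
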
4 changes: 1 addition & 3 deletions jax/_src/compiler.py
@@ -33,7 +33,6 @@
 from jax._src import traceback_util
 from jax._src.interpreters import mlir
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
 import numpy as np

@@ -157,8 +156,7 @@ def get_compile_options(
  build_options = compile_options.executable_build_options
  build_options.use_spmd_partitioning = use_spmd_partitioning
  build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
-  if xla_extension_version >= 280:
-    build_options.use_shardy_partitioner = use_shardy_partitioner
+  build_options.use_shardy_partitioner = use_shardy_partitioner
  if fdo_profile is not None:
    build_options.fdo_profile = fdo_profile
  if use_auto_spmd_partitioning:
9 changes: 2 additions & 7 deletions jax/_src/interpreters/pxla.py
@@ -61,7 +61,6 @@
 from jax._src.interpreters import xla
 from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.partition_spec import PartitionSpec
@@ -3022,12 +3021,8 @@ def aot_cache_miss(*args, **kwargs):
    self.unsafe_call.name, None, aot_cache_miss, [], [], [],
    tree_util.dispatch_registry, cc_shard_arg)

-if xla_extension_version < 282:
-  def cc_shard_arg(x, sharding):
-    return shard_args([sharding], [None], [x])[0]
-else:
-  def cc_shard_arg(x, sharding, layout):  # type: ignore
-    return shard_args([sharding], [layout], [x])[0]
+def cc_shard_arg(x, sharding, layout):
+  return shard_args([sharding], [layout], [x])[0]


def check_arg_avals_for_call(ref_avals, arg_avals,
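
The pxla change is purely a signature update: the compiled-call sharding callback now always receives a layout, because every supported jaxlib (xla_extension_version >= 282) passes one. A self-contained sketch of the two shapes, with `shard_args` stubbed out rather than imported from `jax._src.interpreters.pxla`:

```python
# Stub standing in for pxla.shard_args, just to make the sketch runnable.
def shard_args(shardings, layouts, args):
  return args


# Old callback shape (pre xla_extension_version 282): no layout argument,
# so a None layout was filled in manually.
def cc_shard_arg_old(x, sharding):
  return shard_args([sharding], [None], [x])[0]


# New callback shape: the caller always supplies a layout.
def cc_shard_arg(x, sharding, layout):
  return shard_args([sharding], [layout], [x])[0]


print(cc_shard_arg(1.0, sharding=None, layout=None))  # -> 1.0
```
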
29 changes: 10 additions & 19 deletions jax/_src/lax/linalg.py
@@ -514,11 +514,7 @@ def _cholesky_cpu_lowering(ctx, operand):
out_aval, = ctx.avals_out
batch_dims = operand_aval.shape[:-2]
op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
-# TODO(b/344892332): Remove the check after the compatibility period.
-if jaxlib_version < (0, 4, 31):
-  ctx_arg = ()
-else:
-  ctx_arg = (ctx,)
+ctx_arg = (ctx,)
result, info = lapack.potrf_hlo(*ctx_arg, operand_aval.dtype, operand,
lower=True, a_shape_vals=op_shape_vals)

@@ -556,7 +552,7 @@ def _cholesky_update_abstract_eval(r_matrix, w_vector):

def _cholesky_update_gpu_lowering_rule(target_name_prefix, ctx, r_matrix, w_vector):
# TODO(b/360781533): Remove guard after 3 week forward compatibility period.
-if ctx.is_forward_compat() or jaxlib_version < (0, 4, 32):
+if ctx.is_forward_compat():
r_matrix_aval, _ = ctx.avals_in
try:
[platform] = ctx.module_context.platforms
@@ -726,8 +722,7 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
out_aval = ctx.avals_out[0]
batch_dims = operand_aval.shape[:-2]
op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
-# TODO(b/344892332): Remove the conditional after the compatibility period.
-ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else ()
+ctx_args = (ctx,)
w, vl, vr, info = lapack.geev_hlo(*ctx_args, operand_aval.dtype, operand,
input_shape_vals=op_shape_vals,
jobvl=compute_left_eigenvectors,
@@ -937,8 +932,7 @@ def _eigh_cpu_gpu_lowering(
op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
cpu_args = []
if platform == "cpu":
-# TODO(b/344892332): Remove the conditional after the compatibility period.
-ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else ()
+ctx_args = (ctx,)
cpu_args.extend(ctx_args)
v, w, info = syevd_impl(*cpu_args, operand_aval.dtype, operand,
a_shape_vals=op_shape_vals, lower=lower)
@@ -1511,9 +1505,9 @@ def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand, *, platform: str,
info_aval = ShapedArray(batch_dims, np.dtype(np.int32))
m = operand_aval.shape[-2]

-# TODO(b/357034884): Remove version gate once jaxlib 0.4.32 is the minimum
-# version and the forward compat flag after the 3 week compatibility window.
-if jaxlib_version < (0, 4, 32) or ctx.is_forward_compat():
+# TODO(b/357034884): Remove version gate on the forward compat flag after the
+# 3 week compatibility window.
+if ctx.is_forward_compat():
if not is_constant_shape(operand_aval.shape[-2:]):
raise NotImplementedError(
"Shape polymorphism for native lowering for lu on CPU and GPU is "
@@ -1757,9 +1751,8 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a, *,
a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a)
else:
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape)
-# TODO(b/344892332): Remove the conditional after the compatibility period
ctx_args = (
-    (ctx,) if platform == "cpu" and jaxlib_version >= (0, 4, 32) else ()
+    (ctx,) if platform == "cpu" else ()
)
a_out, taus, *maybe_info_geqrf = geqrf_impl(
*ctx_args, a_aval.dtype, a, a_shape_vals=a_shape_vals
@@ -1881,9 +1874,8 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus, *,
f"on GPU is not implemented; b/261671778; {a_aval.shape}")
a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus)
else:
-# TODO(b/344892332): Remove the conditional after the compatibility period
ctx_args = (
-    (ctx,) if platform == "cpu" and jaxlib_version >= (0, 4, 32) else ()
+    (ctx,) if platform == "cpu" else ()
)
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape)
tau_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, taus_aval.shape)
@@ -2152,8 +2144,7 @@ def _svd_cpu_gpu_lowering(
compute_uv=compute_uv)
else:
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
-# TODO(b/344892332): Remove the conditional after the compatibility period.
-ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else ()
+ctx_args = (ctx,)
s, u, vt, info = gesvd_impl(*ctx_args, operand_aval.dtype, operand,
full_matrices=full_matrices,
compute_uv=compute_uv,
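
Note that in `linalg.py` only the jaxlib version checks disappear; the `ctx.is_forward_compat()` branches stay, because they implement the 3-week export forward-compatibility window, which is independent of the installed jaxlib. A sketch of the resulting guard shape, using a stub context object instead of a real lowering context, with `_legacy_lowering` and `_ffi_lowering` as illustrative stand-ins rather than real JAX APIs:

```python
class StubLoweringCtx:
  """Illustrative stand-in for a lowering rule context."""

  def __init__(self, forward_compat):
    self._forward_compat = forward_compat

  def is_forward_compat(self):
    return self._forward_compat


def _legacy_lowering(operand):
  return ("legacy custom call", operand)


def _ffi_lowering(ctx, operand):
  return ("ffi custom call", operand)


def lowering_rule(ctx, operand):
  # Before this PR the guard read:
  #   if jaxlib_version < (0, 4, 32) or ctx.is_forward_compat():
  if ctx.is_forward_compat():
    return _legacy_lowering(operand)
  return _ffi_lowering(ctx, operand)


print(lowering_rule(StubLoweringCtx(forward_compat=True), operand=1))   # legacy path
print(lowering_rule(StubLoweringCtx(forward_compat=False), operand=1))  # FFI path
```
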
85 changes: 32 additions & 53 deletions tests/export_back_compat_test.py
@@ -66,7 +66,6 @@
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import cuda_versions
-from jax._src.lib import version as jaxlib_version

config.parse_flags_with_absl()

@@ -190,14 +189,11 @@ def test_cpu_cholesky_lapack_potrf(self, dtype_name="f32"):
atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]

data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2023_06_19[dtype_name])
-# TODO(b/344892332): Remove the check after the compatibility period.
-has_xla_ffi_support = jaxlib_version >= (0, 4, 31)
self.run_one_test(func, data, rtol=rtol, atol=atol)
-if has_xla_ffi_support:
-  with config.export_ignore_forward_compatibility(True):
-    # FFI Kernel test
-    data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2024_05_31[dtype_name])
-    self.run_one_test(func, data, rtol=rtol, atol=atol)
+with config.export_ignore_forward_compatibility(True):
+  # FFI Kernel test
+  data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2024_05_31[dtype_name])
+  self.run_one_test(func, data, rtol=rtol, atol=atol)

@parameterized.named_parameters(
dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
@@ -258,14 +254,11 @@ def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array):

self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=check_eig_results)
-# TODO(b/344892332): Remove the check after the compatibility period.
-has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
-if has_xla_ffi_support:
-  with config.export_ignore_forward_compatibility(True):
-    # FFI Kernel test
-    data = self.load_testdata(cpu_eig_lapack_geev.data_2024_08_19[dtype_name])
-    self.run_one_test(func, data, rtol=rtol, atol=atol,
-                      check_results=check_eig_results)
+with config.export_ignore_forward_compatibility(True):
+  # FFI Kernel test
+  data = self.load_testdata(cpu_eig_lapack_geev.data_2024_08_19[dtype_name])
+  self.run_one_test(func, data, rtol=rtol, atol=atol,
+                    check_results=check_eig_results)

@staticmethod
def eigh_input(shape, dtype):
@@ -316,14 +309,11 @@ def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"):
atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_eigh_results, operand))
-# TODO(b/344892332): Remove the check after the compatibility period.
-has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
-if has_xla_ffi_support:
-  # FFI Kernel test
-  with config.export_ignore_forward_compatibility(True):
-    data = self.load_testdata(cpu_eigh_lapack_syev.data_2024_08_19[dtype_name])
-    self.run_one_test(func, data, rtol=rtol, atol=atol,
-                      check_results=partial(self.check_eigh_results, operand))
+# FFI Kernel test
+with config.export_ignore_forward_compatibility(True):
+  data = self.load_testdata(cpu_eigh_lapack_syev.data_2024_08_19[dtype_name])
+  self.run_one_test(func, data, rtol=rtol, atol=atol,
+                    check_results=partial(self.check_eigh_results, operand))

@parameterized.named_parameters(
dict(testcase_name=f"_dtype={dtype_name}_{variant}",
@@ -385,8 +375,6 @@ def test_cuda_lu_pivots_to_permutation(self):
def test_cuda_lu_lapack_getrf(self, dtype_name:str):
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
self.skipTest("Test disabled for x32 mode")
-if jaxlib_version < (0, 4, 32):
-  self.skipTest("Not implemented in older versions of jaxlib")
dtype = dict(f32=np.float32, f64=np.float64,
c64=np.complex64, c128=np.complex128)[dtype_name]
shape = (3, 4)
@@ -416,15 +404,12 @@ def test_cpu_qr_lapack_geqrf(self, dtype_name="f32"):
data = self.load_testdata(cpu_qr_lapack_geqrf.data_2023_03_17[dtype_name])
rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
self.run_one_test(func, data, rtol=rtol)
-# TODO(b/344892332): Remove the check after the compatibility period.
-has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
-if has_xla_ffi_support:
-  with config.export_ignore_forward_compatibility(True):
-    # FFI Kernel test
-    data = self.load_testdata(
-        cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name]
-    )
-    self.run_one_test(func, data, rtol=rtol)
+with config.export_ignore_forward_compatibility(True):
+  # FFI Kernel test
+  data = self.load_testdata(
+      cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name]
+  )
+  self.run_one_test(func, data, rtol=rtol)

@parameterized.named_parameters(
dict(testcase_name=f"_dtype={dtype_name}_{batched}",
@@ -502,14 +487,11 @@ def test_cpu_lu_lapack_getrf(self, dtype_name:str):
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_lu_results, operand,
dtype=dtype))
-# TODO(b/344892332): Remove the check after the compatibility period.
-has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
-if has_xla_ffi_support:
-  with config.export_ignore_forward_compatibility(True):
-    # FFI Kernel test
-    data = self.load_testdata(cpu_lu_lapack_getrf.data_2024_05_31[dtype_name])
-    self.run_one_test(func, data, rtol=rtol, atol=atol,
-                      check_results=partial(self.check_lu_results, operand,
+with config.export_ignore_forward_compatibility(True):
+  # FFI Kernel test
+  data = self.load_testdata(cpu_lu_lapack_getrf.data_2024_05_31[dtype_name])
+  self.run_one_test(func, data, rtol=rtol, atol=atol,
+                    check_results=partial(self.check_lu_results, operand,
                                           dtype=dtype))

def check_svd_results(self, input, res_run, res_exp,
@@ -629,16 +611,13 @@ def func(input):
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_svd_results,
input))
-# TODO(b/344892332): Remove the check after the compatibility period.
-has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
-if has_xla_ffi_support:
-  with config.export_ignore_forward_compatibility(True):
-    # FFI Kernel test
-    data = self.load_testdata(
-        cpu_svd_lapack_gesdd.data_2024_08_13[dtype_name]
-    )
-    self.run_one_test(func, data, rtol=rtol, atol=atol,
-                      check_results=partial(self.check_svd_results, input))
+with config.export_ignore_forward_compatibility(True):
+  # FFI Kernel test
+  data = self.load_testdata(
+      cpu_svd_lapack_gesdd.data_2024_08_13[dtype_name]
+  )
+  self.run_one_test(func, data, rtol=rtol, atol=atol,
+                    check_results=partial(self.check_svd_results, input))

@jtu.parameterized_filterable(
kwargs=[
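
The test simplifications in this file all have the same shape: the `has_xla_ffi_support` version check is gone, so the FFI-kernel test data always runs, wrapped only in the forward-compatibility override. A schematic version with the test-harness methods stubbed out — `load_testdata` and `run_one_test` below are placeholders for the real test-class methods, not module-level functions:

```python
from jax._src import config


def load_testdata(name):
  return name  # placeholder for the test class's load_testdata


def run_one_test(data, rtol):
  print(f"would run {data!r} with rtol={rtol}")  # placeholder for run_one_test


# Legacy kernel data runs unconditionally, as before.
run_one_test("cpu_qr_lapack_geqrf.data_2023_03_17", rtol=1e-3)

# FFI kernel data: previously behind `jaxlib_version >= (0, 4, 32)`, now only
# behind the export forward-compatibility override.
with config.export_ignore_forward_compatibility(True):
  run_one_test("cpu_qr_lapack_geqrf.data_2024_08_22", rtol=1e-3)
```
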
3 changes: 0 additions & 3 deletions tests/layout_test.py
@@ -15,7 +15,6 @@
import contextlib
import math
from functools import partial
-import unittest
from absl.testing import absltest
import numpy as np

@@ -511,8 +510,6 @@ def g(x):
'Layout passed to jit does not match the layout on the respective arg'):
g(arr)

-@unittest.skipIf(xla_extension_version < 282,
-                 "Requires xla_extension_version >= 282")
def test_in_layouts_jit_jnp_input(self):
major_last_layout = DLL(major_to_minor=(1, 0))
sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
4 changes: 0 additions & 4 deletions tests/linalg_test.py
@@ -16,7 +16,6 @@

from functools import partial
import itertools
-import unittest

import numpy as np
import scipy
@@ -2194,9 +2193,6 @@ def testHilbert(self, n):
symmetrize_output=[True, False],
)
@jtu.skip_on_devices("tpu")
-@unittest.skipIf(
-    jax._src.lib.version < (0, 4, 32), "requires jaxlib >= 0.4.32"
-)
def testSymmetricProduct(self, shape, dtype, symmetrize_output):
rng = jtu.rand_default(self.rng())
batch_size = 10
11 changes: 0 additions & 11 deletions tests/pjit_test.py
@@ -56,7 +56,6 @@
from jax._src.lib.mlir import dialects
from jax._src import xla_bridge
from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
from jax._src.lib import xla_extension
from jax._src.util import curry, unzip2

@@ -4433,8 +4432,6 @@ def f(x):
"Compiled object called with input sharding.*does not match"):
compiled(cpu_arr)

-@unittest.skipIf(xla_extension_version < 281,
-                 'Requires xla_extension_version >= 281')
def test_different_devices_wsc_abstract_mesh_cache_hit(self):
if jax.device_count() < 4:
self.skipTest('Requires >=4 devices')
@@ -4463,8 +4460,6 @@ def f(x):
self.assertEqual(lowering_count[0], 1)
self.assertEqual(compilation_count[0], 2) # 2 misses since devices differ.

-@unittest.skipIf(xla_extension_version < 281,
-                 'Requires xla_extension_version >= 281')
def test_wsc_abstract_mesh(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
@@ -4484,8 +4479,6 @@ def f(x):
self.assertArraysEqual(out_eager, np_inp * 2)
self.assertEqual(out_eager.sharding, NamedSharding(mesh, P('x')))

-@unittest.skipIf(xla_extension_version < 281,
-                 'Requires xla_extension_version >= 281')
def test_wsc_sds_abstract_mesh(self):
mesh = jtu.create_mesh((2,), 'x')
s = NamedSharding(mesh, P())
@@ -4499,8 +4492,6 @@ def f(x):
sds = jax.ShapeDtypeStruct((8, 2), np.float32, sharding=s)
f.eval_shape(sds) # doesn't crash

-@unittest.skipIf(xla_extension_version < 281,
-                 'Requires xla_extension_version >= 281')
def test_wsc_vmap_abstract_mesh(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
@@ -4517,8 +4508,6 @@ def f(x):
out2 = jax.jit(jax.vmap(f, spmd_axis_name='y'))(arr)
self.assertEqual(out2.sharding, NamedSharding(mesh, P('y', 'x')))

-@unittest.skipIf(xla_extension_version < 281,
-                 'Requires xla_extension_version >= 281')
def test_wsc_abstract_mesh_errors(self):
mesh = jtu.create_mesh((2,), ('x',))
np_inp = np.arange(8)
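
The remaining test-file edits follow one recipe: drop `unittest.skipIf` decorators keyed on `xla_extension_version` (or `jax._src.lib.version`) because the condition can no longer be true once jaxlib 0.4.32 is the floor. Schematically — the version constant, test class, and test body below are illustrative, not the real values:

```python
import unittest

# Illustrative constant; the real value comes from jax._src.lib.
xla_extension_version = 283


class AbstractMeshTest(unittest.TestCase):

  # Guard of the kind deleted in this PR: with xla_extension_version >= 281
  # guaranteed, the decorator is dead code and can simply be removed.
  @unittest.skipIf(xla_extension_version < 281,
                   'Requires xla_extension_version >= 281')
  def test_wsc_abstract_mesh(self):
    self.assertTrue(True)  # placeholder body


if __name__ == "__main__":
  unittest.main()
```
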
5 changes: 0 additions & 5 deletions tests/shape_poly_test.py
@@ -2843,11 +2843,6 @@ def test_vmap_error(self):
((2, 3, 8, 4), "b1, b2, ..."),
((2, 3, 4, 5), "b1, b2, m, n"),
]
-# TODO(danfm): Remove once jaxlib v0.4.32 is the minimum version.
-# jaxlib versions before 0.4.32 require a static shape for the non-batch
-# dimensions because these are used for computing the "permuation_size"
-# which is passed to lu_pivots_to_permutation.
-if jaxlib_version >= (0, 4, 32) or not poly.endswith("m, n")
],
[
# The random primitive tests, with threefry (both partitionable and