diff --git a/benchmarks/mosaic/BUILD b/benchmarks/mosaic/BUILD index 727e347e5a64..39c7aa5f3395 100644 --- a/benchmarks/mosaic/BUILD +++ b/benchmarks/mosaic/BUILD @@ -28,25 +28,11 @@ package( jax_generate_backend_suites() -DISABLED_BACKENDS = [ - "cpu", - "tpu", -] - -DISABLED_CONFIGS = [ - "gpu_v100", - "gpu_a100", - "gpu_p100", - "gpu_p100_x32", - "gpu_x32", - "gpu_pjrt_c_api", -] - jax_multiplatform_test( name = "matmul_bench", srcs = ["matmul_bench.py"], - disable_backends = DISABLED_BACKENDS, - disable_configs = DISABLED_CONFIGS, + enable_backends = [], + enable_configs = ["gpu_h100"], tags = ["notap"], deps = [ "//jax:mosaic_gpu", diff --git a/docs/cuda_custom_call/BUILD b/docs/cuda_custom_call/BUILD index 0089b6b9fb0d..4954ce3db4fa 100644 --- a/docs/cuda_custom_call/BUILD +++ b/docs/cuda_custom_call/BUILD @@ -32,10 +32,7 @@ jax_multiplatform_test( name = "cuda_custom_call_test", srcs = ["cuda_custom_call_test.py"], data = [":foo"], - disable_backends = [ - "cpu", - "tpu", - ], + enable_backends = ["gpu"], tags = ["notap"], deps = [ "//jax:extend", diff --git a/jax/experimental/mosaic/gpu/examples/BUILD b/jax/experimental/mosaic/gpu/examples/BUILD index 57f78cb2c5c8..fe1a7e9180ac 100644 --- a/jax/experimental/mosaic/gpu/examples/BUILD +++ b/jax/experimental/mosaic/gpu/examples/BUILD @@ -13,7 +13,7 @@ # limitations under the License. load("@rules_python//python:defs.bzl", "py_library") -load("//jaxlib:jax.bzl", "jax_py_test", "py_deps") +load("//jaxlib:jax.bzl", "jax_multiplatform_test", "py_deps") licenses(["notice"]) @@ -48,18 +48,17 @@ py_library( ], ) -jax_py_test( +jax_multiplatform_test( name = "run_matmul", srcs = ["matmul.py"], + enable_backends = [], + enable_configs = ["gpu_h100"], main = "matmul.py", tags = [ "manual", "notap", - "requires-gpu-sm90-only", ], deps = [ - "//jax", "//jax:mosaic_gpu", - "//learning/brain/research/jax:gpu_support", ] + py_deps("numpy"), ) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 65ec572c7ee2..ece917a4dd4e 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -231,7 +231,7 @@ def jax_multiplatform_test( shard_count = None, deps = [], data = [], - disable_backends = None, # buildifier: disable=unused-variable + enable_backends = None, backend_variant_args = {}, # buildifier: disable=unused-variable backend_tags = {}, # buildifier: disable=unused-variable disable_configs = None, # buildifier: disable=unused-variable @@ -240,6 +240,13 @@ def jax_multiplatform_test( tags = [], main = None, pjrt_c_api_bypass = False): # buildifier: disable=unused-variable + # enable_configs and disable_configs do not do anything in OSS, only in Google's CI. + # The order in which `enable_backends`, `enable_configs`, and `disable_configs` are applied is + # as follows: + # 1. `enable_backends` is applied first, enabling all test configs for the given backends.. + # 2. `disable_configs` is applied second, disabling the named test configs. + # 3. `enable_configs` is applied last, enabling the named test configs. + if main == None: if len(srcs) == 1: main = srcs[0] @@ -256,7 +263,7 @@ def jax_multiplatform_test( "--jax_platform_name=" + backend, ] test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, []) - if disable_backends and backend in disable_backends: + if enable_backends != None and backend not in enable_backends: test_tags += ["manual"] if backend == "gpu": test_tags += tf_cuda_tests_tags() diff --git a/tests/BUILD b/tests/BUILD index c93f18dbb815..df9a28236e6a 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -66,7 +66,10 @@ jax_py_test( jax_multiplatform_test( name = "array_interoperability_test", srcs = ["array_interoperability_test.py"], - disable_backends = ["tpu"], + enable_backends = [ + "cpu", + "gpu", + ], tags = ["multiaccelerator"], deps = py_deps("tensorflow_core"), ) @@ -160,10 +163,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "gpu_memory_flags_test_no_preallocation", srcs = ["gpu_memory_flags_test.py"], - disable_backends = [ - "cpu", - "tpu", - ], + enable_backends = ["gpu"], env = { "XLA_PYTHON_CLIENT_PREALLOCATE": "0", }, @@ -173,10 +173,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "gpu_memory_flags_test", srcs = ["gpu_memory_flags_test.py"], - disable_backends = [ - "cpu", - "tpu", - ], + enable_backends = ["gpu"], env = { "XLA_PYTHON_CLIENT_PREALLOCATE": "1", }, @@ -273,10 +270,7 @@ jax_multiplatform_test( backend_tags = { "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, - disable_backends = [ - "cpu", - "tpu", - ], + enable_backends = ["gpu"], env = {"XLA_FLAGS": "--xla_dump_to=sponge --xla_gpu_enable_latency_hiding_scheduler=true"}, tags = [ "config-cuda-only", @@ -290,10 +284,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "mock_gpu_test", srcs = ["mock_gpu_test.py"], - disable_backends = [ - "cpu", - "tpu", - ], + enable_backends = ["gpu"], tags = [ "config-cuda-only", ], @@ -556,11 +547,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "lax_metal_test", srcs = ["lax_metal_test.py"], - disable_backends = [ - "cpu", - "gpu", - "tpu", - ], + enable_backends = ["metal"], tags = ["notap"], deps = [ "//jax:internal_test_util", @@ -649,10 +636,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "metadata_test", srcs = ["metadata_test.py"], - disable_backends = [ - "gpu", - "tpu", - ], + enable_backends = ["cpu"], ) jax_py_test( @@ -672,10 +656,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "multi_device_test", srcs = ["multi_device_test.py"], - disable_backends = [ - "gpu", - "tpu", - ], + enable_backends = ["cpu"], ) jax_multiplatform_test( @@ -734,10 +715,7 @@ jax_multiplatform_test( name = "polynomial_test", srcs = ["polynomial_test.py"], # No implementation of nonsymmetric Eigendecomposition. - disable_backends = [ - "gpu", - "tpu", - ], + enable_backends = ["cpu"], shard_count = { "cpu": 10, }, @@ -753,25 +731,18 @@ jax_multiplatform_test( jax_multiplatform_test( name = "heap_profiler_test", srcs = ["heap_profiler_test.py"], - disable_backends = [ - "gpu", - "tpu", - ], + enable_backends = ["cpu"], ) jax_multiplatform_test( name = "profiler_test", srcs = ["profiler_test.py"], - disable_backends = [ - "gpu", - "tpu", - ], + enable_backends = ["cpu"], ) jax_multiplatform_test( name = "pytorch_interoperability_test", srcs = ["pytorch_interoperability_test.py"], - disable_backends = ["tpu"], # The following cases are disabled because they time out in Google's CI, mostly because the # CUDA kernels in Torch take a very long time to compile. disable_configs = [ @@ -779,6 +750,10 @@ jax_multiplatform_test( "gpu_a100", # Pytorch A100 build times out in Google's CI. "gpu_h100", # Pytorch H100 build times out in Google's CI. ], + enable_backends = [ + "cpu", + "gpu", + ], tags = [ "not_build:arm", # TODO(b/355237462): Re-enable once MSAN issue is addressed. @@ -1019,16 +994,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "sparse_nm_test", srcs = ["sparse_nm_test.py"], - config_tags_overrides = { - "gpu_a100": { - "ondemand": False, # Include in presubmit. - }, - }, - disable_backends = [ - "cpu", - "gpu", - "tpu", - ], + enable_backends = [], enable_configs = [ "gpu_a100", "gpu_h100", @@ -1386,13 +1352,10 @@ jax_multiplatform_test( jax_multiplatform_test( name = "experimental_rnn_test", srcs = ["experimental_rnn_test.py"], - disable_backends = [ - "tpu", - "cpu", - ], disable_configs = [ "gpu_a100", # Numerical precision problems. ], + enable_backends = ["gpu"], shard_count = 15, deps = [ "//jax:rnn", @@ -1505,10 +1468,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "fused_attention_stablehlo_test", srcs = ["fused_attention_stablehlo_test.py"], - disable_backends = [ - "tpu", - "cpu", - ], + enable_backends = ["gpu"], shard_count = { "gpu": 4, }, @@ -1542,10 +1502,7 @@ jax_py_test( jax_multiplatform_test( name = "cudnn_fusion_test", srcs = ["cudnn_fusion_test.py"], - disable_backends = [ - "cpu", - "tpu", - ], + enable_backends = ["gpu"], enable_configs = [ "gpu_a100", "gpu_h100", diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index 4d33e228b906..3d1348371f07 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -28,31 +28,16 @@ package( jax_generate_backend_suites() -DISABLED_BACKENDS = [ - "cpu", - "tpu", -] - -DISABLED_CONFIGS = [ - "gpu_a100", - "gpu_a100_x32", - "gpu_p100", - "gpu_p100_x32", - "gpu_pjrt_c_api", - "gpu_v100", - "gpu_x32", -] - jax_multiplatform_test( name = "gpu_test", srcs = ["gpu_test.py"], - disable_backends = DISABLED_BACKENDS, - disable_configs = DISABLED_CONFIGS, + enable_backends = [], enable_configs = [ "gpu_h100", "gpu_h100_2gpu", ], shard_count = 4, + tags = ["multiaccelerator"], deps = [ "//jax:mosaic_gpu", ] + py_deps("absl/testing") + py_deps("numpy"), @@ -61,8 +46,8 @@ jax_multiplatform_test( jax_multiplatform_test( name = "matmul_test", srcs = ["matmul_test.py"], - disable_backends = DISABLED_BACKENDS, - disable_configs = DISABLED_CONFIGS, + enable_backends = [], + enable_configs = ["gpu_h100"], shard_count = 5, deps = [ "//jax:mosaic_gpu", @@ -73,8 +58,8 @@ jax_multiplatform_test( jax_multiplatform_test( name = "flash_attention", srcs = ["//jax/experimental/mosaic/gpu/examples:flash_attention.py"], - disable_backends = DISABLED_BACKENDS, - disable_configs = DISABLED_CONFIGS, + enable_backends = [], + enable_configs = ["gpu_h100"], main = "//jax/experimental/mosaic/gpu/examples:flash_attention.py", tags = ["notap"], deps = [ @@ -85,8 +70,8 @@ jax_multiplatform_test( jax_multiplatform_test( name = "flash_attention_test", srcs = ["flash_attention_test.py"], - disable_backends = DISABLED_BACKENDS, - disable_configs = DISABLED_CONFIGS, + enable_backends = [], + enable_configs = ["gpu_h100"], deps = [ "//jax:mosaic_gpu", "//jax/experimental/mosaic/gpu/examples:flash_attention", diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index ba82b8c4223c..1bb0e889dd64 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -38,11 +38,9 @@ jax_multiplatform_test( "ondemand": False, # Include in presubmit. }, }, - disable_configs = [ - "gpu_v100", - "gpu_x32", - "gpu_p100", - "gpu_p100_x32", + enable_backends = [ + "cpu", + "tpu", ], enable_configs = [ "gpu_a100_x32", @@ -75,9 +73,6 @@ jax_multiplatform_test( "gpu_p100_x32", "gpu_h100", ], - shard_count = { - "tpu": 1, - }, deps = [ "//jax:pallas", "//jax:pallas_tpu", @@ -130,8 +125,9 @@ jax_multiplatform_test( srcs = [ "indexing_test.py", ], - disable_backends = [ - "gpu", + enable_backends = [ + "cpu", + "tpu", ], tags = [ "noasan", # Times out. @@ -149,19 +145,7 @@ jax_multiplatform_test( srcs = [ "pallas_vmap_test.py", ], - config_tags_overrides = { - "gpu_a100_x32": { - "ondemand": False, # Include in presubmit. - }, - }, - disable_configs = [ - "gpu_v100", - "gpu_x32", - "gpu_a100", - "gpu_h100", - "gpu_p100", - "gpu_p100_x32", - ], + enable_backends = ["cpu"], enable_configs = [ "gpu_a100_x32", "gpu_h100_x32", @@ -186,26 +170,13 @@ jax_multiplatform_test( "ondemand": False, # Include in presubmit. }, }, - disable_backends = [ - "cpu", - "tpu", - ], - disable_configs = [ - "gpu_v100", - "gpu_x32", - "gpu_a100", - "gpu_a100_x32", - "gpu_p100", - "gpu_p100_x32", - "gpu_h100", - ], + enable_backends = [], enable_configs = [ "gpu_h100_x32", ], env = { "JAX_PALLAS_USE_MOSAIC_GPU": "1", }, - tags = ["notap"], deps = [ "//jax:pallas", "//jax:pallas_gpu", # build_cleaner: keep @@ -221,15 +192,7 @@ jax_multiplatform_test( "ondemand": False, # Include in presubmit. }, }, - disable_configs = [ - "gpu_v100", - "gpu_x32", - "gpu_a100", - "gpu_h100", - "gpu_p100", - "gpu_p100_x32", - "gpu_pjrt_c_api", - ], + enable_backends = ["cpu"], enable_configs = [ "gpu_a100_x32", "gpu_h100_x32", @@ -252,15 +215,7 @@ jax_multiplatform_test( "ondemand": False, # Include in presubmit. }, }, - disable_configs = [ - "gpu_v100", - "gpu_x32", - "gpu_a100", - "gpu_h100", - "gpu_p100", - "gpu_p100_x32", - "gpu_pjrt_c_api", - ], + enable_backends = ["cpu"], enable_configs = [ "gpu_a100_x32", ], @@ -304,10 +259,7 @@ jax_multiplatform_test( srcs = [ "pallas_error_handling_test.py", ], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], deps = [ "//jax:pallas", "//jax:pallas_tpu", @@ -322,10 +274,7 @@ jax_multiplatform_test( srcs = [ "tpu_all_gather_test.py", ], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], deps = [ "//jax:pallas_tpu_ops", ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), @@ -336,10 +285,7 @@ jax_multiplatform_test( srcs = [ "tpu_gmm_test.py", ], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], shard_count = 50, tags = [ "noasan", # Times out. @@ -361,10 +307,7 @@ jax_multiplatform_test( srcs = ["tpu_pallas_test.py"], # The flag is necessary for ``pl.debug_print`` tests to work on TPU. args = ["--logtostderr"], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], deps = [ "//jax:extend", "//jax:pallas_tpu", @@ -377,8 +320,9 @@ jax_multiplatform_test( srcs = [ "tpu_ops_test.py", ], - disable_backends = [ - "gpu", + enable_backends = [ + "cpu", + "tpu", ], deps = [ "//jax:pallas", @@ -391,10 +335,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "tpu_pallas_distributed_test", srcs = ["tpu_pallas_distributed_test.py"], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], deps = [ "//jax:extend", "//jax:pallas_tpu", @@ -405,10 +346,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "tpu_pallas_pipeline_test", srcs = ["tpu_pallas_pipeline_test.py"], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], shard_count = 5, tags = [ "noasan", # Times out. @@ -425,10 +363,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "tpu_pallas_async_test", srcs = ["tpu_pallas_async_test.py"], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], tags = [ ], deps = [ @@ -439,10 +374,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "tpu_pallas_mesh_test", srcs = ["tpu_pallas_mesh_test.py"], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], tags = [ "noasan", "nomsan", @@ -459,10 +391,7 @@ jax_multiplatform_test( srcs = [ "tpu_pallas_random_test.py", ], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], deps = [ "//jax:pallas", "//jax:pallas_tpu", @@ -475,10 +404,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "tpu_paged_attention_kernel_test", srcs = ["tpu_paged_attention_kernel_test.py"], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], shard_count = 5, tags = [ "noasan", # Times out. @@ -495,10 +421,7 @@ jax_multiplatform_test( srcs = [ "tpu_splash_attention_kernel_test.py", ], - disable_backends = [ - "gpu", - "cpu", - ], + enable_backends = ["tpu"], shard_count = 24, tags = [ "noasan", # Times out. @@ -515,8 +438,9 @@ jax_multiplatform_test( srcs = [ "tpu_splash_attention_mask_test.py", ], - disable_backends = [ - "gpu", + enable_backends = [ + "cpu", + "tpu", ], deps = [ "//jax:pallas_tpu_ops", @@ -533,17 +457,7 @@ jax_multiplatform_test( "ondemand": False, # Include in presubmit. }, }, - disable_backends = [ - "tpu", - ], - disable_configs = [ - "gpu_v100", - "gpu_x32", - "gpu_p100", - "gpu_p100_x32", - "gpu_a100", - "gpu_h100", - ], + enable_backends = ["cpu"], enable_configs = [ "gpu_a100_x32", "gpu_h100_x32", @@ -566,17 +480,7 @@ jax_multiplatform_test( "ondemand": False, # Include in presubmit. }, }, - disable_backends = [ - "tpu", - ], - disable_configs = [ - "gpu_v100", - "gpu_x32", - "gpu_a100", - "gpu_h100", - "gpu_p100", - "gpu_p100_x32", - ], + enable_backends = ["cpu"], enable_configs = [ "gpu_a100_x32", "gpu_h100_x32",