Skip to content

Commit

Permalink
Replace disable_backends with enable_backends on jax_multiplatform_test.
Browse files Browse the repository at this point in the history
Most users of disable_backends were actually using it to enable only a single backend. So things are simpler if we negate the sense of the option to say that. Change disable_configs to enable_configs, with a default `None` value meaning "everything is enabled".

We change the relationship between enable_backends, disable_configs, enable_configs to be the following:
* `enable_backends` selects a set of initial test configurations to enable, based off backend only.
* `disable_configs` then prunes that set of test configurations, removing elements from the set.
* `enable_configs` then adds additional configurations to the set.

Fix code in jax/experimental/mosaic/gpu/examples not to depend on a Google-internal GPU support target.

PiperOrigin-RevId: 679563155
  • Loading branch information
hawkinsp authored and Google-ML-Automation committed Sep 27, 2024
1 parent 5740ab3 commit 26632fd
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 235 deletions.
18 changes: 2 additions & 16 deletions benchmarks/mosaic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,11 @@ package(

jax_generate_backend_suites()

DISABLED_BACKENDS = [
"cpu",
"tpu",
]

DISABLED_CONFIGS = [
"gpu_v100",
"gpu_a100",
"gpu_p100",
"gpu_p100_x32",
"gpu_x32",
"gpu_pjrt_c_api",
]

jax_multiplatform_test(
name = "matmul_bench",
srcs = ["matmul_bench.py"],
disable_backends = DISABLED_BACKENDS,
disable_configs = DISABLED_CONFIGS,
enable_backends = [],
enable_configs = ["gpu_h100"],
tags = ["notap"],
deps = [
"//jax:mosaic_gpu",
Expand Down
5 changes: 1 addition & 4 deletions docs/cuda_custom_call/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,7 @@ jax_multiplatform_test(
name = "cuda_custom_call_test",
srcs = ["cuda_custom_call_test.py"],
data = [":foo"],
disable_backends = [
"cpu",
"tpu",
],
enable_backends = ["gpu"],
tags = ["notap"],
deps = [
"//jax:extend",
Expand Down
9 changes: 4 additions & 5 deletions jax/experimental/mosaic/gpu/examples/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

load("@rules_python//python:defs.bzl", "py_library")
load("//jaxlib:jax.bzl", "jax_py_test", "py_deps")
load("//jaxlib:jax.bzl", "jax_multiplatform_test", "py_deps")

licenses(["notice"])

Expand Down Expand Up @@ -48,18 +48,17 @@ py_library(
],
)

jax_py_test(
jax_multiplatform_test(
name = "run_matmul",
srcs = ["matmul.py"],
enable_backends = [],
enable_configs = ["gpu_h100"],
main = "matmul.py",
tags = [
"manual",
"notap",
"requires-gpu-sm90-only",
],
deps = [
"//jax",
"//jax:mosaic_gpu",
"//learning/brain/research/jax:gpu_support",
] + py_deps("numpy"),
)
13 changes: 10 additions & 3 deletions jaxlib/jax.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -231,15 +231,22 @@ def jax_multiplatform_test(
shard_count = None,
deps = [],
data = [],
disable_backends = None, # buildifier: disable=unused-variable
enable_backends = None,
backend_variant_args = {}, # buildifier: disable=unused-variable
backend_tags = {}, # buildifier: disable=unused-variable
disable_configs = None, # buildifier: disable=unused-variable
enable_configs = None, # buildifier: disable=unused-variable
enable_configs = [],
config_tags_overrides = None, # buildifier: disable=unused-variable
tags = [],
main = None,
pjrt_c_api_bypass = False): # buildifier: disable=unused-variable
# enable_configs and disable_configs do not do anything in OSS, only in Google's CI.
# The order in which `enable_backends`, `enable_configs`, and `disable_configs` are applied is
# as follows:
# 1. `enable_backends` is applied first, enabling all test configs for the given backends.
# 2. `disable_configs` is applied second, disabling the named test configs.
# 3. `enable_configs` is applied last, enabling the named test configs.

if main == None:
if len(srcs) == 1:
main = srcs[0]
Expand All @@ -256,7 +263,7 @@ def jax_multiplatform_test(
"--jax_platform_name=" + backend,
]
test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, [])
if disable_backends and backend in disable_backends:
if enable_backends != None and backend not in enable_backends and not any([config.startswith(backend) for config in enable_configs]):
test_tags += ["manual"]
if backend == "gpu":
test_tags += tf_cuda_tests_tags()
Expand Down
87 changes: 22 additions & 65 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ jax_py_test(
jax_multiplatform_test(
name = "array_interoperability_test",
srcs = ["array_interoperability_test.py"],
disable_backends = ["tpu"],
enable_backends = [
"cpu",
"gpu",
],
tags = ["multiaccelerator"],
deps = py_deps("tensorflow_core"),
)
Expand Down Expand Up @@ -160,10 +163,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "gpu_memory_flags_test_no_preallocation",
srcs = ["gpu_memory_flags_test.py"],
disable_backends = [
"cpu",
"tpu",
],
enable_backends = ["gpu"],
env = {
"XLA_PYTHON_CLIENT_PREALLOCATE": "0",
},
Expand All @@ -173,10 +173,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "gpu_memory_flags_test",
srcs = ["gpu_memory_flags_test.py"],
disable_backends = [
"cpu",
"tpu",
],
enable_backends = ["gpu"],
env = {
"XLA_PYTHON_CLIENT_PREALLOCATE": "1",
},
Expand Down Expand Up @@ -273,10 +270,7 @@ jax_multiplatform_test(
backend_tags = {
"gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
},
disable_backends = [
"cpu",
"tpu",
],
enable_backends = ["gpu"],
env = {"XLA_FLAGS": "--xla_dump_to=sponge --xla_gpu_enable_latency_hiding_scheduler=true"},
tags = [
"config-cuda-only",
Expand All @@ -290,10 +284,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "mock_gpu_test",
srcs = ["mock_gpu_test.py"],
disable_backends = [
"cpu",
"tpu",
],
enable_backends = ["gpu"],
tags = [
"config-cuda-only",
],
Expand Down Expand Up @@ -556,11 +547,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "lax_metal_test",
srcs = ["lax_metal_test.py"],
disable_backends = [
"cpu",
"gpu",
"tpu",
],
enable_backends = ["metal"],
tags = ["notap"],
deps = [
"//jax:internal_test_util",
Expand Down Expand Up @@ -649,10 +636,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "metadata_test",
srcs = ["metadata_test.py"],
disable_backends = [
"gpu",
"tpu",
],
enable_backends = ["cpu"],
)

jax_py_test(
Expand All @@ -672,10 +656,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "multi_device_test",
srcs = ["multi_device_test.py"],
disable_backends = [
"gpu",
"tpu",
],
enable_backends = ["cpu"],
)

jax_multiplatform_test(
Expand Down Expand Up @@ -734,10 +715,7 @@ jax_multiplatform_test(
name = "polynomial_test",
srcs = ["polynomial_test.py"],
# No implementation of nonsymmetric Eigendecomposition.
disable_backends = [
"gpu",
"tpu",
],
enable_backends = ["cpu"],
shard_count = {
"cpu": 10,
},
Expand All @@ -753,32 +731,29 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "heap_profiler_test",
srcs = ["heap_profiler_test.py"],
disable_backends = [
"gpu",
"tpu",
],
enable_backends = ["cpu"],
)

jax_multiplatform_test(
name = "profiler_test",
srcs = ["profiler_test.py"],
disable_backends = [
"gpu",
"tpu",
],
enable_backends = ["cpu"],
)

jax_multiplatform_test(
name = "pytorch_interoperability_test",
srcs = ["pytorch_interoperability_test.py"],
disable_backends = ["tpu"],
# The following cases are disabled because they time out in Google's CI, mostly because the
# CUDA kernels in Torch take a very long time to compile.
disable_configs = [
"gpu_p100", # Pytorch P100 build times out in Google's CI.
"gpu_a100", # Pytorch A100 build times out in Google's CI.
"gpu_h100", # Pytorch H100 build times out in Google's CI.
],
enable_backends = [
"cpu",
"gpu",
],
tags = [
"not_build:arm",
# TODO(b/355237462): Re-enable once MSAN issue is addressed.
Expand Down Expand Up @@ -1019,16 +994,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "sparse_nm_test",
srcs = ["sparse_nm_test.py"],
config_tags_overrides = {
"gpu_a100": {
"ondemand": False, # Include in presubmit.
},
},
disable_backends = [
"cpu",
"gpu",
"tpu",
],
enable_backends = [],
enable_configs = [
"gpu_a100",
"gpu_h100",
Expand Down Expand Up @@ -1386,13 +1352,10 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "experimental_rnn_test",
srcs = ["experimental_rnn_test.py"],
disable_backends = [
"tpu",
"cpu",
],
disable_configs = [
"gpu_a100", # Numerical precision problems.
],
enable_backends = ["gpu"],
shard_count = 15,
deps = [
"//jax:rnn",
Expand Down Expand Up @@ -1505,10 +1468,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "fused_attention_stablehlo_test",
srcs = ["fused_attention_stablehlo_test.py"],
disable_backends = [
"tpu",
"cpu",
],
enable_backends = ["gpu"],
shard_count = {
"gpu": 4,
},
Expand Down Expand Up @@ -1542,10 +1502,7 @@ jax_py_test(
jax_multiplatform_test(
name = "cudnn_fusion_test",
srcs = ["cudnn_fusion_test.py"],
disable_backends = [
"cpu",
"tpu",
],
enable_backends = ["gpu"],
enable_configs = [
"gpu_a100",
"gpu_h100",
Expand Down
31 changes: 8 additions & 23 deletions tests/mosaic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,31 +28,16 @@ package(

jax_generate_backend_suites()

DISABLED_BACKENDS = [
"cpu",
"tpu",
]

DISABLED_CONFIGS = [
"gpu_a100",
"gpu_a100_x32",
"gpu_p100",
"gpu_p100_x32",
"gpu_pjrt_c_api",
"gpu_v100",
"gpu_x32",
]

jax_multiplatform_test(
name = "gpu_test",
srcs = ["gpu_test.py"],
disable_backends = DISABLED_BACKENDS,
disable_configs = DISABLED_CONFIGS,
enable_backends = [],
enable_configs = [
"gpu_h100",
"gpu_h100_2gpu",
],
shard_count = 4,
tags = ["multiaccelerator"],
deps = [
"//jax:mosaic_gpu",
] + py_deps("absl/testing") + py_deps("numpy"),
Expand All @@ -61,8 +46,8 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "matmul_test",
srcs = ["matmul_test.py"],
disable_backends = DISABLED_BACKENDS,
disable_configs = DISABLED_CONFIGS,
enable_backends = [],
enable_configs = ["gpu_h100"],
shard_count = 5,
deps = [
"//jax:mosaic_gpu",
Expand All @@ -73,8 +58,8 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "flash_attention",
srcs = ["//jax/experimental/mosaic/gpu/examples:flash_attention.py"],
disable_backends = DISABLED_BACKENDS,
disable_configs = DISABLED_CONFIGS,
enable_backends = [],
enable_configs = ["gpu_h100"],
main = "//jax/experimental/mosaic/gpu/examples:flash_attention.py",
tags = ["notap"],
deps = [
Expand All @@ -85,8 +70,8 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "flash_attention_test",
srcs = ["flash_attention_test.py"],
disable_backends = DISABLED_BACKENDS,
disable_configs = DISABLED_CONFIGS,
enable_backends = [],
enable_configs = ["gpu_h100"],
deps = [
"//jax:mosaic_gpu",
"//jax/experimental/mosaic/gpu/examples:flash_attention",
Expand Down
Loading

0 comments on commit 26632fd

Please sign in to comment.