Skip to content

Commit

Permalink
#sdy Add CPU targets in JAX.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 678626718
  • Loading branch information
bartchr808 authored and Google-ML-Automation committed Sep 26, 2024
1 parent e62a50c commit 9d9e468
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ jax_multiplatform_test(
"gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
},
enable_configs = [
"cpu_shardy",
"gpu_2gpu_shardy",
"tpu_v3_2x2_shardy",
"tpu_v4_2x2_shardy",
Expand Down Expand Up @@ -1347,6 +1348,7 @@ jax_multiplatform_test(
name = "shard_map_test",
srcs = ["shard_map_test.py"],
enable_configs = [
"cpu_shardy",
"gpu_2gpu_shardy",
"tpu_v3_2x2_shardy",
"tpu_v4_2x2_shardy",
Expand Down

0 comments on commit 9d9e468

Please sign in to comment.