Skip to content

Commit

Permalink
[pallas] Move ops_test.py from jax_triton to jax/pallas
Browse files Browse the repository at this point in the history
The `jax_triton/ops_test.py` has over time accumulated many tests that are in fact platform-independent tests.
Furthermore, those tests were only Google-internal, and they can be external as well.

This moves test coverage for Pallas from the jax_triton package to the Pallas core package.

A small number of the tests were deleted, because they were already present in Pallas, e.g., tests in `jax_triton/ops_test.py:ControlFlowTest`, and tests for unary and binary ops in `jax_triton/ops_test.py:OpsTest`.

The other tests were distributed to different files in the Pallas repo, according to their purpose:

  * tests in `jax_triton/ops_test.py:PrettyPrintingTest` are moved to `tpu_pallas_test.py::PrettyPrintingTest`
  * tests in `jax_triton/ops_test.py::IndexingTest` are appended to `indexing_test.py::IndexingTest`; some other indexing tests from `jax_triton/ops_test.py::LoadStoreTest` are also moved there.
   * some tests in `jax_triton/ops_test.py:OpsTest` are moved to `ops_test.py::OpsTest`.
   * some tests for TPU specific ops in `jax_triton/ops_test.py:OpsTest` are moved to a new test file `tpu_ops_tests.py`

Some of this required adding sharding and hypothesis support to `ops_test.py`,
and adding TPU versions of `indexing_test.py`.

PiperOrigin-RevId: 662045774
  • Loading branch information
gnecula authored and jax authors committed Aug 12, 2024
1 parent 3c014a4 commit 7f680aa
Show file tree
Hide file tree
Showing 5 changed files with 929 additions and 6 deletions.
25 changes: 23 additions & 2 deletions tests/pallas/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,17 @@ jax_test(
"gpu_a100_x32",
"gpu_h100_x32",
],
shard_count = {
"cpu": 4,
"gpu": 4,
"tpu": 4,
},
deps = [
"//jax:pallas",
"//jax:pallas_gpu", # build_cleaner: keep
"//jax:pallas_tpu",
"//jax:pallas_tpu_ops",
] + py_deps("absl/testing") + py_deps("numpy"),
] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
)

jax_test(
Expand All @@ -99,10 +104,10 @@ jax_test(
],
disable_backends = [
"gpu",
"tpu",
],
deps = [
"//jax:pallas",
"//jax:pallas_tpu",
] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
)

Expand Down Expand Up @@ -317,6 +322,22 @@ jax_test(
],
)

jax_test(
name = "tpu_ops_test",
srcs = [
"tpu_ops_test.py",
],
disable_backends = [
"gpu",
],
deps = [
"//jax:pallas",
"//jax:pallas_gpu", # build_cleaner: keep
"//jax:pallas_tpu",
"//jax:pallas_tpu_ops",
] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
)

jax_test(
name = "tpu_pallas_distributed_test",
srcs = ["tpu_pallas_distributed_test.py"],
Expand Down
Loading

0 comments on commit 7f680aa

Please sign in to comment.