Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[pallas] Move ops_test.py from jax_triton to jax/pallas
The `jax_triton/ops_test.py` has over time accumulated many tests that are in fact platform-independent tests. Furthermore, those tests were only Google-internal, and they can be external as well. This moves test coverage for Pallas from the jax_triton package to the Pallas core package. A small number of the tests were deleted, because they were already present in Pallas, e.g., tests in `jax_triton/ops_test.py:ControlFlowTest`, and tests for unary and binary ops in `jax_triton/ops_test.py:OpsTest`. The other tests were distributed to different files in the Pallas repo, according to their purpose: * tests in `jax_triton/ops_test.py:PrettyPrintingTest` are moved to `tpu_pallas_test.py::PrettyPrintingTest` * tests in `jax_triton/ops_test.py::IndexingTest` are appended to `indexing_test.py::IndexingTest`; some other indexing tests from `jax_triton/ops_test.py::LoadStoreTest` are also moved there. * some tests in `jax_triton/ops_test.py:OpsTest` are moved to `ops_test.py::OpsTest`. * some tests for TPU specific ops in `jax_triton/ops_test.py:OpsTest` are moved to a new test file `tpu_ops_tests.py` Some of this required adding sharding and hypothesis support to `ops_test.py`, and adding TPU versions of `indexing_test.py`. PiperOrigin-RevId: 662045774
- Loading branch information