Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new scripts and configs for running JAX tests #23677

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

copybara-service[bot]
Copy link

Add new scripts and configs for running JAX tests

This adds new CI scripts and envs for running JAX tests. As with the build artifact scripts, these scripts require that JAXCI_ENV_FILE be set to one of the envs inside ci/envs/run_tests before invoking run_bazel_test.sh or run_pytest.sh. Both CPU and GPU commands are in the same script and script behaviors are controlled by JAXCI environment variables set by JAXCI_ENV_FILE. In order to make the Bazel command concise, test configs are moved to the .bazelrc and are grouped under the multiaccelerator and non_multiaccelerator configs

As an example, for running Bazel CPU tests with RBE, we would run:

1. export JAXCI_ENV_FILE=ci/envs/run_tests/bazel_cpu
2. ./ci/run_bazel_test.sh

for running Pytest GPU tests, we would run:

1. export JAXCI_ENV_FILE=ci/envs/run_tests/pytest_gpu
2. ./ci/run_pytest.sh

As Pytests are run locally, note that these scripts require the JAX wheels to be present inside the dist/ folder in the JAX git repository root.

This adds new CI scripts and envs for running JAX tests. As with the build artifact scripts, these scripts require that `JAXCI_ENV_FILE` be set to one of the envs inside `ci/envs/run_tests` before invoking `run_bazel_test.sh` or `run_pytest.sh`. Both CPU and GPU commands are in the same script and script behaviors are controlled by `JAXCI` environment variables set by `JAXCI_ENV_FILE`. In order to make the Bazel command concise, test configs are moved to the .bazelrc and are grouped under the `multiaccelerator` and `non_multiaccelerator` configs

As an example, for running Bazel CPU tests with RBE, we would run:
```
1. export JAXCI_ENV_FILE=ci/envs/run_tests/bazel_cpu
2. ./ci/run_bazel_test.sh
```

for running Pytest GPU tests, we would run:
```
1. export JAXCI_ENV_FILE=ci/envs/run_tests/pytest_gpu
2. ./ci/run_pytest.sh
```

As Pytests are run locally, note that these scripts require the JAX wheels to be present inside the `dist/` folder in the JAX git repository root.

PiperOrigin-RevId: 675316647
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant