
Add cudnn_fusion decorator lowering computations to XLA cuDNN fusions. #22699

Merged
merged 1 commit into jax-ml:main from sergachev:cudnn_fusion on Sep 5, 2024

Conversation

sergachev
Contributor

This will require openxla/xla#15399 to work.

Code for jax/_src/cudnn/fusion.py provided by @hawkinsp.
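
For readers: a minimal sketch of how the decorator is meant to be used, assuming the import path from the file named above (jax/_src/cudnn/fusion.py); the shapes, dtypes, and the helper name `comp` are illustrative, not taken from the PR:

```python
import jax
import jax.numpy as jnp
from jax._src.cudnn.fusion import cudnn_fusion

# The decorator marks the whole function as one unit to be lowered to an
# XLA fusion that the cuDNN backend compiles as a single kernel, rather
# than letting XLA split it into separate ops.
@jax.jit
@cudnn_fusion
def comp(x, y, z):
    # matmul + bias add: the kind of pattern cuDNN graph fusions support
    return jnp.dot(x, y) + z

k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(k1, (32, 16), dtype=jnp.bfloat16)
y = jax.random.normal(k2, (16, 64), dtype=jnp.bfloat16)
z = jax.random.normal(k3, (32, 64), dtype=jnp.bfloat16)
print(comp(x, y, z).shape)  # (32, 64)
```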

@sergachev force-pushed the cudnn_fusion branch 2 times, most recently from 917e5da to 65f6e35 on August 8, 2024
@sergachev
Contributor Author

The required change in XLA is done; this one is ready.

@hawkinsp
Collaborator

Sorry for the slow review. Given this API isn't yet public, I'm comfortable merging it more or less as is.

Is there a minimum cuDNN version for this test to pass?

google-ml-butler bot added the kokoro:force-run and pull ready (Ready for copybara import and testing) labels on Aug 27, 2024
@sergachev
Contributor Author

> Is there a minimum cuDNN version for this test to pass?

9.0.

@hawkinsp
Collaborator

The change looks fine to me, but it crashes in CI.

@sergachev
Contributor Author

Am I seeing this right, that both failing checks are in non-GPU configurations?

@hawkinsp
Collaborator

Yeah, you're right. You probably need to raise an error when lowering on a non-CUDA platform.

I think just add platform="cuda" to the register_lowering call; currently you're asserting that the lowering works everywhere.

You should also skip the test when not on CUDA, with @jtu.run_on_devices("cuda") IIRC.
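
For reference, a sketch of what that suggestion looks like in code; cudnn_fusion_p and _cudnn_fusion_lowering are placeholder names, not the PR's actual identifiers:

```python
from jax._src import test_util as jtu
from jax._src.interpreters import mlir

# Registering the lowering rule only for CUDA means that lowering the
# primitive on CPU/TPU fails with a clear "unsupported platform" error,
# instead of producing a fusion no other backend can compile.
mlir.register_lowering(cudnn_fusion_p, _cudnn_fusion_lowering,
                       platform="cuda")

class CudnnFusionTest(jtu.JaxTestCase):

  @jtu.run_on_devices("cuda")  # skips the test on non-CUDA backends
  def test_cudnn_fusion(self):
    ...
```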

@sergachev
Contributor Author

Done.

@hawkinsp added and removed the pull ready (Ready for copybara import and testing) label on Sep 4, 2024
@hawkinsp
Collaborator

hawkinsp commented Sep 4, 2024

Sorry, it took me a long time to look at this. This test fails in our internal CI: it seems that on V100 (which we run in CI) the rewrite to a cuDNN fusion does not happen; instead, the after-optimization HLO ends up with a cuBLAS gemm. Is that intended? Should the test be gated on particular GPU generations?

@sergachev
Contributor Author

It should run on H100. Is this https://github.com/google/jax/pull/22699/files#diff-77b54950a53c3196a56e8f570cb6dcd4eca602b5a8b4220f5cd2acb86f060e7fR1548 not sufficient to filter by GPU type?

@sergachev
Contributor Author

Anyway, I looked at other tests and added a check with skipTest(). It actually works on Ampere+.
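
As a sketch, the kind of guard being described here, assuming the device exposes its compute capability as a string (as JAX's CUDA devices do); the exact check in the PR may differ:

```python
# Inside the test: bail out unless we are on an Ampere-or-newer GPU
# (compute capability >= 8.0), since the cuDNN fusion rewrite only
# applies there.
if jax.devices()[0].platform != "gpu":
  self.skipTest("Requires a GPU.")
if float(jax.devices()[0].compute_capability) < 8.0:
  self.skipTest("cuDNN fusion requires Ampere (sm_80) or newer.")
```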

@hawkinsp
Collaborator

hawkinsp commented Sep 5, 2024

> It should run on H100. Is this https://github.com/google/jax/pull/22699/files#diff-77b54950a53c3196a56e8f570cb6dcd4eca602b5a8b4220f5cd2acb86f060e7fR1548 not sufficient to filter by GPU type?

It appears not. However, in general BUILD rules aren't enough, because we support running the tests via other means such as pytest. So a BUILD rule filter is helpful (it stops us from running pointless tests), but the test should also skip itself if the hardware it needs isn't present.

copybara-service bot merged commit 8fe99ff into jax-ml:main on Sep 5, 2024
14 checks passed
@sergachev deleted the cudnn_fusion branch on September 5, 2024
@hawkinsp
Collaborator

hawkinsp commented Sep 6, 2024

The rewrite also seems to fail on A100?

@sergachev
Contributor Author

I tested it on A100.

@hawkinsp
Collaborator

hawkinsp commented Sep 8, 2024

I'm still finding this to fail in CI. It looks like the cuDNN fusion is present in the HLO fed to the compiler, but for some reason it gets rewritten away.

Are we guaranteed that the fusion will be emitted, or can it sometimes be autotuned away or something? Are there any other circumstances under which the fusion will fall back?
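
One way to see what was actually emitted, as a sketch reusing comp, x, y, z from the earlier example; the "__cudnn$fusion" marker is the backend kind XLA uses for these fusions, assuming it appears verbatim in the optimized HLO text:

```python
# Compile and inspect the after-optimization HLO to check whether the
# cuDNN fusion survived, or was rewritten away (e.g. to a cuBLAS gemm).
hlo = jax.jit(comp).lower(x, y, z).compile().as_text()
print("__cudnn$fusion" in hlo)
```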

