Add cudnn_fusion decorator lowering computations to XLA cuDNN fusions. #22699
Conversation
Force-pushed from 917e5da to 65f6e35.
The required change in XLA is done; this one is ready.
Sorry for the slow review. Given this API isn't yet public, I'm comfortable merging it more or less as is.
Is there a minimum cudnn version for this test to pass?
9.0.
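(For anyone reading along: a test could make that requirement explicit with a guard along these lines. This is a sketch, not code from the PR; the `cuda_versions.cudnn_get_version` helper and its cudnnGetVersion-style integer encoding (9.0.0 → 90000) are assumptions about jaxlib's version introspection.)

```python
# Sketch of a cuDNN-version guard for a test; not part of this PR.
# Assumes jaxlib exposes cuda_versions.cudnn_get_version() returning the
# cudnnGetVersion()-style integer (e.g. 90000 for cuDNN 9.0.0).
from jax._src.lib import cuda_versions

def skip_unless_cudnn_9(test_case):
  if cuda_versions is None or cuda_versions.cudnn_get_version() < 90000:
    test_case.skipTest("cudnn_fusion requires cuDNN >= 9.0")
```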
The change looks fine to me, but it crashes in CI.
Am I right that both failing checks are in non-GPU configurations?
Yeah, you're right. You probably need to raise an error if lowering on a non-CUDA platform, and you should also skip the test if not on CUDA.
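A minimal sketch of the test-side skip being suggested here (helper names such as `jtu.test_device_matches` and `jtu.JaxTestCase` come from JAX's test utilities and are assumptions about what the PR ends up using):

```python
# Sketch only: skip the test unless we are on a CUDA GPU, so non-GPU CI
# configurations never try to lower the cuDNN fusion at all.
from jax._src import test_util as jtu

class CudnnFusionTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if not jtu.test_device_matches(["cuda"]):
      self.skipTest("cudnn_fusion requires a CUDA GPU")
```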
Force-pushed from 65f6e35 to 87f7704.
Done.
Sorry, it took me a long time to look at this. This test fails in our internal CI: it seems that on V100 (which we run in CI) the rewrite to a cuDNN fusion does not happen; instead, the post-optimization HLO ends up with a cuBLAS GEMM. Is that intended? Should the test be gated on particular GPU generations?
It should run on H100. Is this https://github.com/google/jax/pull/22699/files#diff-77b54950a53c3196a56e8f570cb6dcd4eca602b5a8b4220f5cd2acb86f060e7fR1548 not sufficient to filter by GPU type?
Force-pushed from 87f7704 to 85d792a.
Anyway, I looked at other tests and added a check with skipTest(). It actually works on Ampere+.
It appears not. However, in general BUILD rules aren't enough, because we support running the tests via other means such as pytest. So a BUILD rule filter is helpful (it stops us from running pointless tests), but the test should also skip itself if the hardware it needs isn't present.
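For the generation gating mentioned above, the in-test guard might look roughly like this, continuing the setUp() sketch from earlier (again an illustration; `jtu.is_cuda_compute_capability_at_least` is assumed to be available, and "8.0" corresponds to Ampere):

```python
    # Sketch: in addition to the CUDA check, skip on pre-Ampere GPUs
    # (e.g. V100), where the cuDNN fusion rewrite does not kick in.
    if not jtu.is_cuda_compute_capability_at_least("8.0"):
      self.skipTest("cudnn_fusion requires compute capability >= 8.0 (Ampere)")
```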
The rewrite also seems to fail on A100?
I tested it on A100.
I'm still finding this to fail in CI. It looks like the cuDNN fusion is present in the HLO fed to the compiler, but for some reason it gets rewritten away. Are we guaranteed that the fusion will be emitted, or can it sometimes be autotuned away or something? Are there any other circumstances under which the fusion will fall back?
Indeed. I examined the tests we have (https://github.com/openxla/xla/blob/main/xla/service/gpu/transforms/cudnn_custom_call_converter_test.cc#L27, https://github.com/openxla/xla/blob/main/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc#L709) and realised that the latter one relies on xla_gpu_cublas_fallback(false). Fix: #23505
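In other words, whether the fusion survives can depend on GEMM-fusion autotuning falling back to cuBLAS. A test that wants the cuDNN fusion emitted deterministically can turn that fallback off, for example via XLA_FLAGS before the backend is initialized (illustrative only; the flag name comes from the linked XLA test):

```python
import os

# Illustrative: disable the cuBLAS fallback so the GEMM-fusion autotuner
# cannot rewrite the cuDNN fusion into a plain cuBLAS GEMM. Must be set
# before JAX initializes its backends.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_gpu_cublas_fallback=false")
```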
This will require openxla/xla#15399 to work.
Code for jax/_src/cudnn/fusion.py provided by @hawkinsp.
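For reference, using the new decorator might look roughly like this. It is a sketch based on the PR title and file path, not code copied from the PR; the exact signature may differ, and per the discussion above it needs cuDNN >= 9.0 on an Ampere-or-newer CUDA GPU.

```python
# Sketch of how the decorator might be used; not copied from the PR.
import jax
import jax.numpy as jnp
from jax._src.cudnn.fusion import cudnn_fusion  # module added by this PR

@cudnn_fusion
def dot_and_add(x, y, z):
  # The whole body is meant to be lowered as a single XLA cuDNN fusion,
  # rather than as separate dot and add ops.
  return jnp.dot(x, y).astype(jnp.float32) + z

x = jnp.ones((16, 32), dtype=jnp.bfloat16)
y = jnp.ones((32, 8), dtype=jnp.bfloat16)
z = jnp.ones((16, 8), dtype=jnp.float32)
out = jax.jit(dot_and_add)(x, y, z)
```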