Reducing compile time of JAX HEALPix (I)FFT implementations #171
Related to #140. This doesn't completely remove the loops in the JAX HEALPix FFT and IFFT implementations, but it does reduce the number of unrolled operations and therefore the compile time. Unfortunately, the optimizations make the code a bit less readable and less directly tied to the NumPy implementations.
I've tested locally against the tests added in #170, which pass. We would probably want to merge that in first, so I'm marking this as draft until it is merged and we can rebase on top of it.
Compared to previous implementations, this tries to vectorize operations as much as possible by processing data connected to $\theta$ rings of the same size (all equatorial rings and the pairs of polar rings of equal sizes) together.
The big gain is in vectorizing the operations on all the equally sized equatorial bands together, as this removes around `2 * nside` unrolled loop iterations in favour of one set of vectorized operations. Processing the pairs of polar rings together gives a smaller but still helpful reduction in compile time.
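As a rough illustration of the grouping idea, here is a minimal sketch, not the code in this PR. It assumes a flat array of `12 * nside**2` pixel values in HEALPix ring ordering and only computes a plain FFT along each ring, leaving out the phase shifts and spectral padding/truncation the real transforms need. The function names `ring_ffts_looped` and `ring_ffts_grouped` are hypothetical and exist only for this example.

```python
import jax
import jax.numpy as jnp


def ring_ffts_looped(f, nside):
    """Reference: one FFT per ring, i.e. 4 * nside - 1 unrolled iterations under jit."""
    outputs = []
    start = 0
    for t in range(1, 4 * nside):
        # Ring t has 4 * t pixels in the north cap, 4 * nside on the
        # equatorial rings and 4 * (4 * nside - t) in the south cap.
        n_phi = 4 * min(t, nside, 4 * nside - t)
        outputs.append(jnp.fft.fft(f[start:start + n_phi]))
        start += n_phi
    return jnp.concatenate(outputs)


def ring_ffts_grouped(f, nside):
    """Same result, but rings of equal size are transformed together."""
    n_pix = 12 * nside**2
    cap = 2 * nside * (nside - 1)  # number of pixels in each polar cap

    # All 2 * nside + 1 equatorial rings have 4 * nside pixels, so they can
    # be reshaped into a 2D array and transformed with a single batched FFT.
    equatorial = f[cap:n_pix - cap].reshape(2 * nside + 1, 4 * nside)
    equatorial_ft = jnp.fft.fft(equatorial, axis=-1).reshape(-1)

    # Each polar ring size occurs twice (once per cap), so the north and
    # south rings of equal size are stacked and transformed as a pair.
    north, south = [], []
    for t in range(1, nside):
        n_start = 2 * t * (t - 1)
        s_start = n_pix - 2 * t * (t + 1)
        pair = jnp.stack(
            [f[n_start:n_start + 4 * t], f[s_start:s_start + 4 * t]]
        )
        pair_ft = jnp.fft.fft(pair, axis=-1)
        north.append(pair_ft[0])
        south.append(pair_ft[1])

    # South-cap rings shrink towards the pole, so reverse their order.
    return jnp.concatenate(north + [equatorial_ft] + south[::-1])
```

In this sketch the looped version traces `4 * nside - 1` separate FFTs, while the grouped version traces `nside`: one batched FFT for the equatorial rings plus `nside - 1` for the polar pairs. Both need `nside` to be static under `jax.jit` (for example `jax.jit(ring_ffts_grouped, static_argnums=1)`), since it determines the array shapes.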