Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reducing compile time of JAX HEALPix (I)FFT implementations #171

Merged
merged 1 commit into from
Dec 4, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 63 additions & 30 deletions s2fft/utils/healpix_ffts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from jax import jit
from jax import jit, vmap

import numpy as np
import jax.numpy as jnp
Expand Down Expand Up @@ -220,23 +220,47 @@ def healpix_fft_jax(f: jnp.ndarray, L: int, nside: int, reality: bool) -> jnp.nd
Returns:
jnp.ndarray: Array of Fourier coefficients for all latitudes.
"""
ntheta = samples.ntheta(L, "healpix", nside)
index = 0
ftm_rows = []
for t in range(ntheta):
nphi = samples.nphi_ring(t, nside)

def f_chunks_to_ftm_rows(f_chunks, nphi):
if reality and nphi == 2 * L:
fm_chunk = jnp.zeros(nphi, dtype=jnp.complex128)
fm_chunk = fm_chunk.at[nphi // 2 :].set(
jnp.fft.rfft(jnp.real(f[index : index + nphi]), norm="backward")[:-1]
fm_chunks = jnp.concatenate(
(
jnp.zeros((f_chunks.shape[0], nphi // 2)),
jnp.fft.rfft(jnp.real(f_chunks), norm="backward")[:, :-1],
),
axis=1,
)
else:
fm_chunk = jnp.fft.fftshift(
jnp.fft.fft(f[index : index + nphi], norm="backward")
fm_chunks = jnp.fft.fftshift(
jnp.fft.fft(f_chunks, norm="backward"), axes=-1
)
ftm_rows.append(spectral_periodic_extension_jax(fm_chunk, L))
index += nphi
return jnp.stack(ftm_rows)
return vmap(spectral_periodic_extension_jax, (0, None))(fm_chunks, L)

# Process together the f chunks for each pair of polar theta rings with the same
# number of phi samples, to reduce the size of the unrolled traced computational
# graph
ftm_rows_polar = []
start_index, end_index = 0, 12 * nside**2
for t in range(0, nside - 1):
nphi = 4 * (t + 1)
f_chunks = jnp.stack(
(f[start_index : start_index + nphi], f[end_index - nphi : end_index])
)
ftm_rows_polar.append(f_chunks_to_ftm_rows(f_chunks, nphi))
start_index, end_index = start_index + nphi, end_index - nphi
ftm_rows_polar = jnp.stack(ftm_rows_polar)
# Process all f chunks for the equal sized equatorial theta rings together
nphi = 4 * nside
f_chunks_equatorial = f[start_index:end_index].reshape((-1, nphi))
ftm_rows_equatorial = f_chunks_to_ftm_rows(f_chunks_equatorial, nphi)
# Concatenate Fourier coefficients for all latitudes, reversing the second polar set
# because its rings were processed from the end of f inwards
return jnp.concatenate(
(
ftm_rows_polar[:, 0],
ftm_rows_equatorial,
ftm_rows_polar[::-1, 1],
)
)


def healpix_ifft(
Expand Down Expand Up @@ -336,28 +360,37 @@ def healpix_ifft_jax(
Returns:
jnp.ndarray: HEALPix pixel-space array.
"""
f = jnp.zeros(
samples.f_shape(sampling="healpix", nside=nside), dtype=jnp.complex128
)
ntheta = ftm.shape[0]
index = 0

for t in range(ntheta):
nphi = samples.nphi_ring(t, nside)
fm_chunk = ftm[t] if nphi == 2 * L else spectral_folding_jax(ftm[t], nphi, L)
def ftm_rows_to_f_chunks(ftm_rows, nphi):
fm_chunks = (
ftm_rows
if nphi == 2 * L
else vmap(spectral_folding_jax, (0, None, None))(ftm_rows, nphi, L)
)
if reality and nphi == 2 * L:
f = f.at[index : index + nphi].set(
jnp.fft.irfft(fm_chunk[nphi // 2 :], nphi, norm="forward")
)
return jnp.fft.irfft(fm_chunks[:, nphi // 2 :], nphi, norm="forward")
else:
f = f.at[index : index + nphi].set(
jnp.conj(
jnp.fft.fft(jnp.fft.ifftshift(jnp.conj(fm_chunk)), norm="backward")
return jnp.conj(
jnp.fft.fft(
jnp.fft.ifftshift(jnp.conj(fm_chunks), axes=-1), norm="backward"
)
)

index += nphi
return f
# Process together the ftm rows for each pair of polar theta rings with the same
# number of phi samples, to reduce the size of the unrolled traced computational
# graph
f_chunks_polar = [
ftm_rows_to_f_chunks(jnp.stack((ftm[t], ftm[-(t + 1)])), 4 * (t + 1))
for t in range(nside - 1)
]
# Process all ftm rows for the equal sized equatorial theta rings together
f_chunks_equatorial = ftm_rows_to_f_chunks(ftm[nside - 1 : 3 * nside], 4 * nside)
# Concatenate f chunks for all theta rings together, reversing the second polar set
# because its rows were taken from the end of ftm inwards
return jnp.concatenate(
[f_chunks_polar[t][0] for t in range(nside - 1)]
+ [f_chunks_equatorial.flatten()]
+ [f_chunks_polar[t][1] for t in reversed(range(nside - 1))]
)


def p2phi_rings(t: np.ndarray, nside: int) -> np.ndarray:
Expand Down
Loading