
astro-informatics/s2fft



Differentiable and accelerated spherical transforms

S2FFT is a Python package for computing Fourier transforms on the sphere and rotation group (Price & McEwen 2024) using JAX or PyTorch. It leverages autodiff to provide differentiable transforms, which are also deployable on hardware accelerators (e.g. GPUs and TPUs).

More specifically, S2FFT provides support for spin spherical harmonic and Wigner transforms (for both real and complex signals), with support for adjoint transformations where needed, and comes with different optimisations (precompute or not) that one may select depending on available resources and desired angular resolution $L$.
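To get a feel for the precompute trade-off, the following sketch estimates the memory a dense spin-0 kernel would need. It assumes a hypothetical layout with one double-precision (8-byte) value per (θ ring, ℓ, m) triple, with L rings and 2L - 1 orders as on a McEwen & Wiaux grid. The layout S2FFT actually uses may differ, so treat the figures as order-of-magnitude estimates only.

```python
def dense_kernel_bytes(L: int, bytes_per_value: int = 8) -> int:
    """Memory for a hypothetical dense spin-0 kernel: one value per (ring, ell, m)."""
    n_rings = L      # theta rings, as for McEwen & Wiaux sampling (assumption)
    n_ell = L        # multipoles 0 <= ell < L
    n_m = 2 * L - 1  # azimuthal orders -(L - 1) <= m <= L - 1
    return n_rings * n_ell * n_m * bytes_per_value


def signal_bytes(L: int, bytes_per_value: int = 8) -> int:
    """Memory for a real-valued map on an L x (2L - 1) grid, for comparison."""
    return L * (2 * L - 1) * bytes_per_value


for L in (128, 512, 1024):
    print(f"L = {L:4d}: kernel {dense_kernel_bytes(L) / 2**30:7.3f} GiB, "
          f"map {signal_bytes(L) / 2**20:6.2f} MiB")
```

The kernel grows as L³ while the map grows as L², so under these assumptions the kernel alone reaches roughly 16 GiB at L = 1024.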

Important

The long JIT compile time of HEALPix transforms has been fixed for CPU! A fix for GPU is coming soon.

Tip

As of version 1.0.2, S2FFT also provides PyTorch implementations of the underlying precompute transforms. In future releases, this support will be extended to our on-the-fly algorithms.

Tip

As of version 1.1.0, S2FFT also provides JAX support for existing C/C++ packages, specifically HEALPix and SSHT. This works by wrapping their Python bindings with custom JAX frontends. Note that this C/C++ to JAX interoperability is currently limited to CPU.

Algorithms ⚑

S2FFT leverages new algorithmic structures that can be highly parallelised and distributed, and so map very well onto the architecture of hardware accelerators (i.e. GPUs and TPUs). In particular, these algorithms are based on new Wigner-d recursions that are stable to high angular resolution $L$. The diagram below illustrates the recursions (for further details see Price & McEwen 2024).

[Figure: Wigner-d recursions (Price & McEwen 2024)]

With this recursion to hand, the spherical harmonic coefficients of an isolatitudinally sampled map may be computed as a two-step process: first, a 1D Fourier transform over longitude for each latitudinal ring; second, a projection onto the real polar-d functions. One may precompute and store all real polar-d functions for extreme acceleration; however, this comes with an equally extreme memory overhead, which is infeasible at $L \sim 1024$. Alternatively, the real polar-d functions may be calculated recursively, computing only a portion of the projection at a time, hence incurring negligible memory overhead at the cost of slightly slower execution. The diagram below illustrates the separable spherical harmonic transform (for further details see Price & McEwen 2024).

[Figure: separable spherical harmonic transform (Price & McEwen 2024)]
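A minimal pure-Python sketch of the two-step structure for spin-0 signals follows. To stay short it swaps in Gauss-Legendre sampling (L rings at Gauss-Legendre nodes, 2L - 1 equispaced longitudes) and the classic normalised associated Legendre recursion in ℓ, rather than S2FFT's Wigner-d recursions, and it stores coefficients in a dict keyed by (ℓ, m), a layout chosen purely for illustration. The polar functions are regenerated ring by ring instead of being stored, loosely mirroring the recursive option described above.

```python
import cmath
import math
import random


def gauss_legendre(n):
    """Nodes x_j = cos(theta_j) and weights of n-point Gauss-Legendre quadrature."""
    nodes, weights = [], []
    for i in range(1, n + 1):
        x = math.cos(math.pi * (i - 0.25) / (n + 0.5))  # standard initial guess
        for _ in range(100):  # Newton iterations on P_n(x) = 0
            p_prev, p = 1.0, x  # P_0(x), P_1(x)
            for k in range(2, n + 1):
                p_prev, p = p, ((2 * k - 1) * x * p - (k - 1) * p_prev) / k
            dp = n * (x * p - p_prev) / (x * x - 1.0)  # P_n'(x)
            step = p / dp
            x -= step
            if abs(step) < 1e-15:
                break
        nodes.append(x)
        weights.append(2.0 / ((1.0 - x * x) * dp * dp))
    return nodes, weights


def legendre_table(L, x):
    """lam[l][m] for 0 <= m <= l < L, with Y_lm = lam[l][m] * exp(i m phi), x = cos(theta).

    Classic normalised three-term recursion in l (Condon-Shortley phase).
    """
    s = math.sqrt(max(0.0, 1.0 - x * x))  # sin(theta)
    lam = [[0.0] * L for _ in range(L)]
    lam_mm = 1.0 / math.sqrt(4.0 * math.pi)  # lambda_{0,0}
    for m in range(L):
        if m > 0:
            lam_mm *= -math.sqrt((2 * m + 1) / (2 * m)) * s  # lambda_{m,m}
        lam[m][m] = lam_mm
        if m + 1 < L:
            lam[m + 1][m] = x * math.sqrt(2 * m + 3) * lam_mm
        for l in range(m + 2, L):
            a = math.sqrt((4 * l * l - 1) / (l * l - m * m))
            b = math.sqrt(((l - 1) ** 2 - m * m) / (4 * (l - 1) ** 2 - 1))
            lam[l][m] = a * (x * lam[l - 1][m] - b * lam[l - 2][m])
    return lam


def lam_signed(lam, l, m):
    """Negative orders via lambda_{l,-m} = (-1)^m lambda_{l,m}."""
    return lam[l][m] if m >= 0 else (-1) ** m * lam[l][-m]


def inverse(flm, L):
    """Synthesise samples f[j][k] at (theta_j, phi_k) from flm[(l, m)]."""
    x, _ = gauss_legendre(L)
    nphi = 2 * L - 1  # resolves every |m| <= L - 1 without aliasing
    f = []
    for xj in x:
        lam = legendre_table(L, xj)
        fm = {m: sum(flm[(l, m)] * lam_signed(lam, l, m) for l in range(abs(m), L))
              for m in range(-(L - 1), L)}
        f.append([sum(c * cmath.exp(2j * math.pi * m * k / nphi) for m, c in fm.items())
                  for k in range(nphi)])
    return f


def forward(f, L):
    """Two-step analysis of samples f[j][k] into coefficients flm[(l, m)]."""
    x, w = gauss_legendre(L)
    nphi = 2 * L - 1
    flm = {(l, m): 0j for l in range(L) for m in range(-l, l + 1)}
    for xj, wj, ring in zip(x, w, f):
        # Step 1: Fourier transform over longitude (a plain DFT here; an FFT in practice).
        fm = {m: 2 * math.pi / nphi
              * sum(v * cmath.exp(-2j * math.pi * m * k / nphi) for k, v in enumerate(ring))
              for m in range(-(L - 1), L)}
        # Step 2: project onto the polar functions, generated for this ring only.
        lam = legendre_table(L, xj)
        for (l, m) in flm:
            flm[(l, m)] += wj * fm[m] * lam_signed(lam, l, m)
    return flm


random.seed(0)
L = 6
flm = {(l, m): complex(random.gauss(0, 1), random.gauss(0, 1))
       for l in range(L) for m in range(-l, l + 1)}
recovered = forward(inverse(flm, L), L)
print("max round-trip error:", max(abs(recovered[key] - flm[key]) for key in flm))
```

Both quadratures are exact for band-limited signals, so the round trip should recover the coefficients to near machine precision. A practical implementation would replace the explicit φ sums with FFTs and vectorise both steps.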

Sampling 🌍

The structure of the algorithms implemented in S2FFT can support any isolatitude sampling scheme. A number of sampling schemes are currently supported.

The equiangular sampling schemes of McEwen & Wiaux (2011), Driscoll & Healy (1995) and Gauss-Legendre (1986) are supported. These schemes exhibit associated sampling theorems, so harmonic transforms can be computed to machine precision. Note that the McEwen & Wiaux sampling theorem reduces the Nyquist rate on the sphere by a factor of two compared to the Driscoll & Healy approach, halving the number of spherical samples required.
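The factor-of-two saving can be checked by counting samples. The sketch below assumes the McEwen & Wiaux count (L - 1)(2L - 1) + 1, in which the south pole is a single sample, and a Driscoll & Healy style grid of 2L rings with 2L - 1 samples each; conventions for the latter vary slightly between implementations, so the exact numbers are illustrative.

```python
def n_samples_mw(L: int) -> int:
    """McEwen & Wiaux sample count: (L - 1)(2L - 1) + 1."""
    return (L - 1) * (2 * L - 1) + 1


def n_samples_dh(L: int) -> int:
    """Driscoll & Healy style grid, assuming 2L rings of 2L - 1 samples each."""
    return 2 * L * (2 * L - 1)


for L in (8, 128, 1024):
    mw, dh = n_samples_mw(L), n_samples_dh(L)
    print(f"L={L:5d}  MW={mw:>9d}  DH={dh:>9d}  DH/MW={dh / mw:.3f}")
```

The ratio tends to 2 as L grows, matching the halving quoted above.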

The popular HEALPix sampling scheme (Gorski et al. 2005) is also supported. HEALPix sampling does not exhibit a sampling theorem, so the corresponding harmonic transforms do not achieve machine precision and instead incur some error. However, HEALPix provides equal-area pixels, which has many practical advantages.
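The equal-area property is easy to quantify: HEALPix divides the sphere into 12 N_side² pixels, so each pixel covers 4π / (12 N_side²) steradians. The helper names below are illustrative only; healpy offers equivalents such as healpy.nside2npix and healpy.nside2pixarea.

```python
import math


def healpix_npix(nside: int) -> int:
    """Total number of HEALPix pixels for resolution parameter nside."""
    return 12 * nside * nside


def healpix_pixel_area_sr(nside: int) -> float:
    """Area of every pixel in steradians: the full sphere (4 pi) split evenly."""
    return 4.0 * math.pi / healpix_npix(nside)


def healpix_pixel_area_arcmin2(nside: int) -> float:
    """The same area converted to square arcminutes."""
    arcmin_per_rad = 180.0 * 60.0 / math.pi
    return healpix_pixel_area_sr(nside) * arcmin_per_rad ** 2


for nside in (64, 256, 1024):
    area = healpix_pixel_area_arcmin2(nside)
    print(f"nside={nside:5d}: npix={healpix_npix(nside):9d}, "
          f"pixel area {area:8.3f} arcmin^2 (~{math.sqrt(area):.2f} arcmin across)")
```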

Note

For algorithmic reasons, JIT compilation of HEALPix transforms can become slow at high bandlimits, due to XLA unfolding of loops, which currently cannot be avoided. After compilation, HEALPix transforms should execute with the efficiency outlined in the associated paper, so this additional time overhead need only be incurred once. We are aware of this issue and are working to fix it. A fix for CPU execution has now been implemented (see example notebook); a fix for GPU execution is coming soon.

Installation 💻

The Python dependencies for the S2FFT package are listed in the file requirements/requirements-core.txt and will be automatically installed into the active Python environment by pip when running

pip install s2fft

This will install all core functionality, including JAX and PyTorch support.

Alternatively, the S2FFT package may be installed directly from GitHub by cloning this repository and then running

pip install .        

from the root directory of the repository.

Unit tests can then be executed to ensure the installation was successful by first installing the test requirements and then running pytest

pip install -r requirements/requirements-tests.txt
pytest tests/  

Documentation for the released version is available here. To build the documentation locally run

pip install -r requirements/requirements-docs.txt
cd docs 
make html
open _build/html/index.html

Note

For plotting functionality which can be found throughout our various notebooks, one must install the requirements which can be found in requirements/requirements-plotting.txt.

Usage 🚀

Importing and using S2FFT is as simple as follows:

For a signal on the sphere

import s2fft

# f: signal sampled on the sphere, L: harmonic band-limit
# Compute harmonic coefficients
flm = s2fft.forward_jax(f, L)
# Map back to pixel-space signal
f = s2fft.inverse_jax(flm, L)

For a signal on the rotation group

import s2fft

# Compute Wigner coefficients
flmn = s2fft.wigner.forward_jax(f, L, N)
# Map back to pixel-space signal
f = s2fft.wigner.inverse_jax(flmn, L, N)
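When choosing the azimuthal band-limit N for the Wigner calls above, it can help to know how many coefficients are involved. The sketch below counts the valid f_{ℓmn} with ℓ < L, |m| ≤ ℓ and |n| ≤ min(ℓ, N - 1), assuming 1 ≤ N ≤ L; N = 1 recovers the L² spherical harmonic coefficients. The comparison with a zero-padded (2N - 1, L, 2L - 1) array refers to a hypothetical storage layout, included only to show the padding overhead.

```python
def n_wigner_coeffs(L: int, N: int) -> int:
    """Valid f_{l m n} with 0 <= l < L, |m| <= l, |n| <= min(l, N - 1); assumes 1 <= N <= L."""
    # For fixed n, l runs from |n| to L - 1 and contributes 2l + 1 orders m,
    # i.e. L**2 - n**2 coefficients; summing over |n| <= N - 1 gives this closed form.
    return (2 * N - 1) * L * L - (N - 1) * N * (2 * N - 1) // 3


def n_wigner_coeffs_bruteforce(L: int, N: int) -> int:
    """Direct count over all index triples, used to check the closed form."""
    return sum(1 for l in range(L) for m in range(-l, l + 1)
               for n in range(-(N - 1), N) if abs(n) <= l)


def dense_size(L: int, N: int) -> int:
    """Entries in a hypothetical zero-padded (2N - 1, L, 2L - 1) array."""
    return (2 * N - 1) * L * (2 * L - 1)


for L, N in ((8, 1), (64, 4), (256, 8)):
    print(f"L={L}, N={N}: {n_wigner_coeffs(L, N)} coefficients, dense array {dense_size(L, N)}")
```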

For further details on usage see the documentation and associated notebooks.

Note

We also provide PyTorch support for the precompute version of our transforms. These are called through forward_torch() and inverse_torch(). Full PyTorch support will be provided in future releases.

C/C++ JAX Frontends for SSHT/HEALPix πŸ’‘

S2FFT also provides JAX support for existing C/C++ packages, specifically HEALPix and SSHT. This works by wrapping their Python bindings with custom JAX frontends. Note that this C/C++ to JAX interoperability is currently limited to CPU.

For example, one may call these alternate backends for the spherical harmonic transform by:

import s2fft

# Forward SSHT spherical harmonic transform
flm = s2fft.forward(f, L, sampling="mw", method="jax_ssht")

# Forward HEALPix spherical harmonic transform
flm = s2fft.forward(f, L, nside=nside, sampling="healpix", method="jax_healpy")

All of these JAX frontends support reverse-mode automatic differentiation out of the box, while under the hood they simply call the C/C++ packages you are familiar with. In this way S2FFT enhances existing packages with gradient functionality for modern scientific computing and machine learning applications!
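For a linear transform y = A x, reverse-mode differentiation amounts to applying the adjoint transform to the incoming cotangent (up to complex-conjugation conventions, which differ between frameworks), so wrapping an external linear transform for gradients typically involves supplying a correct adjoint. The standard check is the dot-product test ⟨A x, y⟩ = ⟨x, A^H y⟩. The sketch below uses a small unnormalised DFT as a stand-in for a wrapped C/C++ transform; it illustrates the principle and is not S2FFT's frontend code.

```python
import cmath
import random


def dft(x):
    """Forward transform A x: unnormalised DFT (stand-in for a wrapped C/C++ transform)."""
    n = len(x)
    return [sum(x[k] * cmath.exp(-2j * cmath.pi * j * k / n) for k in range(n)) for j in range(n)]


def dft_adjoint(y):
    """Adjoint A^H y: the conjugate-transposed DFT matrix applied to y."""
    n = len(y)
    return [sum(y[j] * cmath.exp(2j * cmath.pi * j * k / n) for j in range(n)) for k in range(n)]


def inner(a, b):
    """Complex inner product <a, b> = sum(conj(a_i) * b_i)."""
    return sum(ai.conjugate() * bi for ai, bi in zip(a, b))


random.seed(1)
n = 8
x = [complex(random.gauss(0, 1), random.gauss(0, 1)) for _ in range(n)]
y = [complex(random.gauss(0, 1), random.gauss(0, 1)) for _ in range(n)]
lhs = inner(dft(x), y)          # <A x, y>
rhs = inner(x, dft_adjoint(y))  # <x, A^H y>
print("dot-product test mismatch:", abs(lhs - rhs))
```

Replacing dft_adjoint with dft (dropping the conjugation) makes the test fail, which is exactly the kind of mistake it is designed to catch.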

For further details on usage see the associated notebooks.

Contributors ✨

Thanks go to these wonderful people (emoji key):

Matt Price 💻 👀 🤔
Jason McEwen 💻 👀 🤔
Matt Graham 💻 👀
sfmig 💻 👀
Devaraj Gopinathan 💻
Francois Lanusse 💻 🐛
Ikko Eltociear Ashimine 📖
Kevin Mulder 🐛
Philipp Misof 🐛
Elis Roberts 🐛 📖
Wassim KABALAN 💻 👀 ⚠️

We encourage contributions from any interested developers. A simple first addition could be adding support for more spherical sampling patterns!

Attribution πŸ“š

Should this code be used in any way, we kindly request that the following article is referenced. A BibTeX entry for this reference may look like:

@article{price:s2fft, 
   author      = "Matthew A. Price and Jason D. McEwen",
   title       = "Differentiable and accelerated spherical harmonic and Wigner transforms",
   journal     = "Journal of Computational Physics",
   year        = "2024",
   volume      = "510",
   pages       = "113109",
   eprint      = "arXiv:2311.14670",
   doi         = "10.1016/j.jcp.2024.113109"
}

You might also like to consider citing our related papers on which this code builds:

@article{mcewen:fssht,
    author      = "Jason D. McEwen and Yves Wiaux",
    title       = "A novel sampling theorem on the sphere",
    journal     = "IEEE Trans. Sig. Proc.",
    year        = "2011",
    volume      = "59",
    number      = "12",
    pages       = "5876--5887",        
    eprint      = "arXiv:1110.6298",
    doi         = "10.1109/TSP.2011.2166394"
}
@article{mcewen:so3,
    author      = "Jason D. McEwen and Martin B{\"u}ttner and Boris Leistedt and Hiranya V. Peiris and Yves Wiaux",
    title       = "A novel sampling theorem on the rotation group",
    journal     = "IEEE Sig. Proc. Let.",
    year        = "2015",
    volume      = "22",
    number      = "12",
    pages       = "2425--2429",
    eprint      = "arXiv:1508.03101",
    doi         = "10.1109/LSP.2015.2490676"    
}

License πŸ“

We provide this code under an MIT open-source licence with the hope that it will be of use to a wider community.

Copyright 2023 Matthew Price, Jason McEwen and contributors.

S2FFT is free software made available under the MIT License. For details see the LICENCE.txt file.

The file lib/include/kernel_helpers.h is adapted from code in a tutorial on extending JAX by Dan Foreman-Mackey and licensed under an MIT license.

The file lib/include/kernel_nanobind_helpers.h is adapted from code by the JAX authors and licensed under an Apache-2.0 license.