Skip to content

Commit

Permalink
Merge pull request #195 from astro-informatics/feature/JAX_frontend_f…
Browse files Browse the repository at this point in the history
…or_C++_codes

add jax frontend support for c/c++ sht libraries
  • Loading branch information
CosmoMatt authored Apr 9, 2024
2 parents b3d033c + baa412b commit 76fa862
Show file tree
Hide file tree
Showing 32 changed files with 1,576 additions and 335 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
pip install jaxlib
pip install -r requirements/requirements-core.txt
pip install -r requirements/requirements-docs.txt
pip install .\[torch\]
pip install .
- name: Build Documentation
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements/requirements-tests.txt
pip install -r requirements/requirements-core.txt
pip install .\[torch\]
pip install .
- name: Run tests
run: |
Expand Down
8 changes: 7 additions & 1 deletion .pip_readme.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
.. image:: https://img.shields.io/badge/code%20style-black-000000.svg
:target: https://github.com/psf/black
.. image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/drive/1YmJ2ljsF8HBvhPmD4hrYPlyAKc4WPUgq?usp=sharing
:target: https://colab.research.google.com/github/astro-informatics/s2fft/blob/main/notebooksspherical_harmonic_transform.ipynb

Differentiable and accelerated spherical transforms
=================================================================================================================
Expand All @@ -31,6 +31,12 @@ As of version 1.0.2 `S2FFT` also provides PyTorch implementations of underlying
precompute transforms. In future releases this support will be extended to our
on-the-fly algorithms.

As of version 1.1.0 `S2FFT` also provides JAX support for existing C/C++ packages,
specifically `HEALPix` and `SSHT`. This works by wrapping python bindings with custom
JAX frontends. Note that currently this C/C++ to JAX interoperability is currently
limited to CPU, however for many applications this is desirable due to memory
constraints.

Documentation
=============
Read the full documentation `here <https://astro-informatics.github.io/s2fft/>`_.
Expand Down
65 changes: 49 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
[![image](https://badge.fury.io/py/s2fft.svg)](https://badge.fury.io/py/s2fft)
[![image](http://img.shields.io/badge/arXiv-2311.14670-orange.svg?style=flat)](https://arxiv.org/abs/2311.14670)<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
[![All Contributors](https://img.shields.io/badge/all_contributors-9-orange.svg?style=flat-square)](#contributors-)<!-- ALL-CONTRIBUTORS-BADGE:END -->
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1YmJ2ljsF8HBvhPmD4hrYPlyAKc4WPUgq?usp=sharing)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/astro-informatics/s2fft/blob/main/notebooks/spherical_harmonic_transform.ipynb)
<!-- [![image](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) -->

<img align="left" height="85" width="98" src="./docs/assets/sax_logo.png">
Expand All @@ -22,10 +22,20 @@ for adjoint transformations where needed, and comes with different
optimisations (precompute or not) that one may select depending on
available resources and desired angular resolution $L$.

> [!IMPORTANT]
> HEALPix long JIT compile time fixed for CPU! Fix for GPU coming soon.
> [!TIP]
As of version 1.0.2 `S2FFT` also provides PyTorch implementations of underlying
precompute transforms. In future releases this support will be extended to our
on-the-fly algorithms.

> [!TIP]
As of version 1.1.0 `S2FFT` also provides JAX support for existing C/C++ packages,
specifically `HEALPix` and `SSHT`. This works by wrapping python bindings with custom
JAX frontends. Note that currently this C/C++ to JAX interoperability is currently
limited to CPU.

## Algorithms :zap:

`S2FFT` leverages new algorithmic structures that can he highly
Expand Down Expand Up @@ -53,7 +63,7 @@ diagram below illustrates the separable spherical harmonic transform
## Sampling :earth_africa:

The structure of the algorithms implemented in `S2FFT` can support any
isolattitude sampling scheme. A number of sampling schemes are currently
isolatitude sampling scheme. A number of sampling schemes are currently
supported.

The equiangular sampling schemes of [McEwen & Wiaux
Expand All @@ -73,10 +83,10 @@ so the corresponding harmonic transforms do not achieve machine
precision but exhibit some error. However, the HEALPix sampling provides
pixels of equal areas, which has many practical advantages.

<p align="center"><img src="./docs/assets/figures/spherical_sampling.png" width="500"></p>
<p align="center"><img src="./docs/assets/figures/spherical_sampling.png" width="700"></p>

> [!NOTE]
> For algorithmic reasons JIT compilation of HEALPix transforms can become slow at high bandlimits, due to XLA unfolding of loops which currently cannot be avoided. After compiling HEALPix transforms should execute with the efficiency outlined in the associated paper, therefore this additional time overhead need only be incurred once. We are aware of this issue and will work to improve this in subsequent versions.
> For algorithmic reasons JIT compilation of HEALPix transforms can become slow at high bandlimits, due to XLA unfolding of loops which currently cannot be avoided. After compiling HEALPix transforms should execute with the efficiency outlined in the associated paper, therefore this additional time overhead need only be incurred once. We are aware of this issue and are working to fix it. A fix for CPU execution has now been implemented (see example [notebook](https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/JAX_HEALPix_backend.html)). Fix for GPU execution is coming soon.
## Installation :computer:

Expand All @@ -87,12 +97,7 @@ into the active python environment by [pip](https://pypi.org) when running
``` bash
pip install s2fft
```
This will install all core functionality which includes JAX support. To install `S2FFT`
with PyTorch support run

``` bash
pip install s2fft[torch]
```
This will install all core functionality which includes JAX support (including PyTorch support).

Alternatively, the `S2FFT` package may be installed directly from GitHub by cloning this
repository and then running
Expand All @@ -101,16 +106,22 @@ repository and then running
pip install .
```

from the root directory of the repository. To enable PyTorch support you will need to run
from the root directory of the repository.

Unit tests can then be executed to ensure the installation was successful by first installing the test requirements and then running pytest

``` bash
pip install .[torch]
pip install -r requirements/requirements-tests.txt
pytest tests/
```

Unit tests can then be executed to ensure the installation was successful by running
Documentation for the released version is available [here](https://astro-informatics.github.io/s2fft/). To build the documentation locally run

``` bash
pytest tests/
pip install -r requirements/requirements-docs.txt
cd docs
make html
open _build/html/index.html
```

> [!NOTE]
Expand Down Expand Up @@ -143,7 +154,29 @@ For further details on usage see the [documentation](https://astro-informatics.g
> [!NOTE]
> We also provide PyTorch support for the precompute version of our transforms. These are called through forward/inverse_torch(). Full PyTorch support will be provided in future releases.
## Benchmarking :hourglass_flowing_sand:
## C/C++ JAX Frontends for SSHT/HEALPix :bulb:

`S2FFT` also provides JAX support for existing C/C++ packages, specifically [`HEALPix`](https://healpix.jpl.nasa.gov) and [`SSHT`](https://github.com/astro-informatics/ssht). This works
by wrapping python bindings with custom JAX frontends. Note that this C/C++ to JAX interoperability is currently limited to CPU.

For example, one may call these alternate backends for the spherical harmonic transform by:

``` python
# Forward SSHT spherical harmonic transform
flm = s2fft.forward(f, L, sampling=["mw"], method="jax_ssht")

# Forward HEALPix spherical harmonic transform
flm = s2fft.forward(f, L, nside=nside, sampling="healpix", method="jax_healpy")
```

All of these JAX frontends supports out of the box reverse mode automatic differentiation,
and under the hood is simply linking to the C/C++ packages you are familiar with. In this
way `S2fft` enhances existing packages with gradient functionality for modern scientific computing or machine learning
applications!

For further details on usage see the associated [notebooks](https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/JAX_SSHT_backend.html).

<!-- ## Benchmarking :hourglass_flowing_sand:
We benchmarked the spherical harmonic and Wigner transforms implemented
in `S2FFT` against the C implementations in the
Expand All @@ -167,7 +200,7 @@ that scale linearly with spin).
| 8192 | 82 s | 110.8 | 2.14E-13 | N/A | N/A | N/A | N/A |
where the left hand results are for the recursive based algorithm and the right hand side are
our precompute implementation.
our precompute implementation. -->

## Contributors ✨

Expand Down
7 changes: 7 additions & 0 deletions docs/api/transforms/c_backend_spherical.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:html_theme.sidebar_secondary.remove:

**************************
C/C++ custom JAX support
**************************
.. automodule:: s2fft.transforms.c_backend_spherical
:members:
20 changes: 20 additions & 0 deletions docs/api/transforms/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,25 @@ Transforms
* - :func:`~s2fft.transforms.wigner.forward_jax`
- Forward Wigner transform (JAX)

.. list-table:: C/C++ backend gradient support
:widths: 25 25
:header-rows: 1

* - Function Name
- Description
* - :func:`~s2fft.transforms.c_backend_spherical.ssht_inverse`
- Custom JAX frontend for inverse SSHT C spherical harmonic library.
* - :func:`~s2fft.transforms.c_backend_spherical.ssht_forward`
- Custom JAX frontend for forward SSHT C spherical harmonic library.
* - :func:`~s2fft.transforms.c_backend_spherical.healpy_inverse`
- Custom JAX frontend for inverse HEALPix C++ spherical harmonic library.
* - :func:`~s2fft.transforms.c_backend_spherical.healpy_forward`
- Custom JAX frontend for forwardHEALPix C++ spherical harmonic library.
* - :func:`~s2fft.transforms.wigner.inverse_jax_ssht`
- Custom JAX frontend for hybrid inverse SSHT C Wigner transforms.
* - :func:`~s2fft.transforms.wigner.forward_jax_ssht`
- Custom JAX frontend for hybrid forward SSHT C Wigner transforms.

.. list-table:: On-the-fly Price-McEwen recursions.
:widths: 25 25
:header-rows: 1
Expand All @@ -64,4 +83,5 @@ Transforms
on_the_fly_recursions
spin_spherical_transform
wigner
.. c_backend_spherical

Binary file modified docs/assets/figures/spherical_sampling.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
16 changes: 8 additions & 8 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
author = "Matthew Price, Jason McEwen, Matthew Graham, Sofia Miñano, Devaraj Gopinathan"

# The short X.Y version
version = "1.0.2"
version = "1.1.0"
# The full version, including alpha/beta/rc tags
release = "1.0.2"
release = "1.1.0"


# -- General configuration ---------------------------------------------------
Expand Down Expand Up @@ -106,12 +106,12 @@
"icon": "_static/arxiv-logomark-small.png",
"type": "local",
},
# {
# "name": "YouTube",
# "url": "https://www.youtube.com/channel/UCrCOQsyQOJhOUaIYzmbkKQQ",
# "icon": "fa-brands fa-youtube fa-2x",
# "type": "fontawesome",
# },
{
"name": "Medium",
"url": "https://towardsdatascience.com/differentiable-and-accelerated-spherical-harmonic-transforms-c269393d08f1",
"icon": "fa-brands fa-medium",
"type": "fontawesome",
},
{
"name": "PyPi",
"url": "https://pypi.org/project/s2fft/",
Expand Down
20 changes: 15 additions & 5 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,19 @@ transforms (for both real and complex signals), with support for adjoint transfo
where needed, and comes with different optimisations (precompute or not) that one
may select depending on available resources and desired angular resolution :math:`L`.

As of version 1.0.2 ``S2FFT`` also provides PyTorch implementations of underlying
precompute transforms. In future releases this support will be extended to our
on-the-fly algorithms.
.. important::
HEALPix long JIT compile time fixed for CPU! Fix for GPU coming soon.

.. tip::
As of version 1.0.2 ``S2FFT`` also provides PyTorch implementations of underlying
precompute transforms. In future releases this support will be extended to our
on-the-fly algorithms.

.. tip::
As of version 1.1.0 ``S2FFT`` also provides JAX support for existing C/C++ packages,
specifically ``HEALPix`` and ``SSHT``. This works by wrapping python bindings with custom
JAX frontends. Note that currently this C/C++ to JAX interoperability is currently
limited to CPU.

Algorithms |:zap:|
-------------------
Expand All @@ -40,7 +50,7 @@ diagram below illustrates the separable spherical harmonic transform.
.. image:: ./assets/figures/sax_schematic_github_docs.png

.. note::
For algorithmic reasons JIT compilation of HEALPix transforms can become slow at high bandlimits, due to XLA unfolding of loops which currently cannot be avoided. After compiling HEALPix transforms should execute with the efficiency outlined in the associated paper, therefore this additional time overhead need only be incurred once. We are aware of this issue and will work to improve this in subsequent versions.
For algorithmic reasons JIT compilation of HEALPix transforms can become slow at high bandlimits, due to XLA unfolding of loops which currently cannot be avoided. After compiling HEALPix transforms should execute with the efficiency outlined in the associated paper, therefore this additional time overhead need only be incurred once. We are aware of this issue and are working to fix it. A fix for CPU execution has now been implemented (see example `notebook <https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/JAX_HEALPix_backend.html>`_). Fix for GPU execution is coming soon.

Sampling |:earth_africa:|
-----------------------------------
Expand All @@ -53,7 +63,7 @@ The equiangular sampling schemes of `McEwen & Wiaux (2012) <https://arxiv.org/ab
The popular `HEALPix <https://healpix.jpl.nasa.gov>`_ sampling scheme (`Gorski et al. 2005 <https://arxiv.org/abs/astro-ph/0409513>`_) is also supported. The HEALPix sampling does not exhibit a sampling theorem and so the corresponding harmonic transforms do not achieve machine precision but exhibit some error. However, the HEALPix sampling provides pixels of equal areas, which has many practical advantages.

.. image:: ./assets/figures/spherical_sampling.png
:width: 700
:width: 900
:align: center

Contributors ✨
Expand Down
3 changes: 3 additions & 0 deletions docs/tutorials/JAX_HEALPix/JAX_HEALPix_frontend.nblink
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"path": "../../../notebooks/JAX_HEALPix_frontend.ipynb"
}
3 changes: 3 additions & 0 deletions docs/tutorials/JAX_SSHT/JAX_SSHT_frontend.nblink
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"path": "../../../notebooks/JAX_SSHT_frontend.ipynb"
}
52 changes: 21 additions & 31 deletions docs/tutorials/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ in the time being feel free to contact contributors for advice! At a high-level
``S2FFT`` package is structured such that the 2 primary transforms, the Wigner and
spherical harmonic transforms, can easily be accessed.

Usage |:rocket:|
Core usage |:rocket:|
-----------------
To import and use ``S2FFT`` is as simple follows:

Expand All @@ -25,39 +25,27 @@ To import and use ``S2FFT`` is as simple follows:
| f = s2fft.inverse_jax(flm, L) | f = s2fft.wigner.inverse_jax(flmn, L, N) |
+-------------------------------------------------------+------------------------------------------------------------+

C/C++ backend usage |:bulb:|
-----------------
``S2FFT`` also provides JAX support for existing C/C++ packages, specifically `HEALPix <https://healpix.jpl.nasa.gov>`_
and `SSHT <https://github.com/astro-informatics/ssht>`_. This works
by wrapping python bindings with custom JAX frontends. Note that currently this C/C++ to JAX interoperability is currently
limited to CPU, however for many applications this is desirable due to memory constraints.

For example, one may call these alternate backends for the spherical harmonic transform by:

.. code-block:: python
Benchmarking |:hourglass_flowing_sand:|
-------------------------------------
We benchmarked the spherical harmonic and Wigner transforms implemented in ``S2FFT``
against the C implementations in the `SSHT <https://github.com/astro-informatics/ssht>`_
pacakge.
# Forward SSHT spherical harmonic transform
flm = s2fft.forward(f, L, sampling=["mw"], method="jax_ssht")
A brief summary is shown in the table below for the recursion (left) and precompute
(right) algorithms, with ``S2FFT`` running on GPUs (for further details see Price &
McEwen, in prep.). Note that our compute time is agnostic to spin number (which is not
the case for many other methods that scale linearly with spin).
# Forward HEALPix spherical harmonic transform
flm = s2fft.forward(f, L, nside=nside, sampling="healpix", method="jax_healpy")
+------+-----------+-----------+----------+-----------+----------+----------+---------+
| | Recursive Algorithm | Precompute Algorithm |
+------+-----------+-----------+----------+-----------+----------+----------+---------+
| L | Wall-Time | Speed-up | Error | Wall-Time | Speed-up | Error | Memory |
+------+-----------+-----------+----------+-----------+----------+----------+---------+
| 64 | 3.6 ms | 0.88 | 1.81E-15 | 52.4 μs | 60.5 | 1.67E-15 | 4.2 MB |
+------+-----------+-----------+----------+-----------+----------+----------+---------+
| 128 | 7.26 ms | 1.80 | 3.32E-15 | 162 μs | 80.5 | 3.64E-15 | 33 MB |
+------+-----------+-----------+----------+-----------+----------+----------+---------+
| 256 | 17.3 ms | 6.32 | 6.66E-15 | 669 μs | 163 | 6.74E-15 | 268 MB |
+------+-----------+-----------+----------+-----------+----------+----------+---------+
| 512 | 58.3 ms | 11.4 | 1.43E-14 | 3.6 ms | 184 | 1.37E-14 | 2.14 GB |
+------+-----------+-----------+----------+-----------+----------+----------+---------+
| 1024 | 194 ms | 32.9 | 2.69E-14 | 32.6 ms | 195 | 2.47E-14 | 17.1 GB |
+------+-----------+-----------+----------+-----------+----------+----------+---------+
| 2048 | 1.44 s | 49.7 | 5.17E-14 | N/A | N/A | N/A | N/A |
+------+-----------+-----------+----------+-----------+----------+----------+---------+
| 4096 | 8.48 s | 133.9 | 1.06E-13 | N/A | N/A | N/A | N/A |
+------+-----------+-----------+----------+-----------+----------+----------+---------+
| 8192 | 82 s | 110.8 | 2.14E-13 | N/A | N/A | N/A | N/A |
+------+-----------+-----------+----------+-----------+----------+----------+---------+
All of these JAX frontends supports out of the box reverse mode automatic differentiation,
and under the hood is simply linking to the C/C++ packages you are familiar with. In this
way ``S2FFT`` enhances existing packages with gradient functionality for modern signal processing
applications!


.. toctree::
Expand All @@ -69,3 +57,5 @@ the case for many other methods that scale linearly with spin).
wigner/wigner_transform.nblink
rotation/rotation.nblink
torch_frontend/torch_frontend.nblink
JAX_SSHT/JAX_SSHT_frontend.nblink
JAX_HEALPix/JAX_HEALPix_frontend.nblink
Loading

0 comments on commit 76fa862

Please sign in to comment.