Skip to content

Commit

Permalink
Merge pull request #83 from astro-informatics/feature/acceleration
Browse files Browse the repository at this point in the history
Feature/acceleration
  • Loading branch information
CosmoMatt authored Apr 15, 2024
2 parents 8199cce + dd4337a commit 9a52c1f
Show file tree
Hide file tree
Showing 10 changed files with 372 additions and 256 deletions.
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
author = "Matthew Price, Jason McEwen, Jessica Whitney, Alicja Polanska"

# The short X.Y version
version = "1.0.3"
version = "1.0.4"
# The full version, including alpha/beta/rc tags
release = "1.0.3"
release = "1.0.4"


# -- General configuration ---------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements-core.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ pyyaml==6.0
scipy

# For spherical transforms
s2fft >= 1.1.0
s2fft >= 1.1.1
5 changes: 5 additions & 0 deletions s2wav/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
# JAX recursive transforms
from .transforms.wavelet import analysis, synthesis, flm_to_analysis

# C Backend transforms
from .transforms.wavelet_c import analysis as analysis_c
from .transforms.wavelet_c import synthesis as synthesis_c
from .transforms.wavelet_c import flm_to_analysis as flm_to_analysis_c

# Base transforms
from .transforms.base import analysis as analysis_base
from .transforms.base import synthesis as synthesis_base
Expand Down
1 change: 1 addition & 0 deletions s2wav/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from . import base
from . import construct
from . import wavelet
from . import wavelet_c
from . import wavelet_precompute
from . import wavelet_precompute_torch
171 changes: 20 additions & 151 deletions s2wav/transforms/wavelet.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from jax import jit
import jax.numpy as jnp
import numpy as np
from functools import partial
from typing import Tuple, List
import s2fft
from s2wav import samples
from s2wav.transforms import construct


@partial(jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9))
def synthesis(
f_wav: jnp.ndarray,
f_scal: jnp.ndarray,
Expand All @@ -21,8 +21,6 @@ def synthesis(
reality: bool = False,
filters: Tuple[jnp.ndarray] = None,
precomps: List[List[jnp.ndarray]] = None,
use_c_backend: bool = False,
_ssht_backend: int = 1,
) -> jnp.ndarray:
r"""Computes the synthesis directional wavelet transform [1,2].
Specifically, this transform synthesises the signal :math:`_{s}f(\omega) \in \mathbb{S}^2`
Expand Down Expand Up @@ -61,76 +59,30 @@ def synthesis(
precomps (List[jnp.ndarray]): Precomputed list of recursion coefficients. At most
of length :math:`L^2`, which is a minimal memory overhead.
use_c_backend (bool, optional): Execution mode in {"jax" = False, "jax_ssht" = True}.
Defaults to False.
_ssht_backend (int, optional, experimental): Whether to default to SSHT core
(set to 0) recursions or pick up ducc0 (set to 1) accelerated experimental
backend. Use with caution.
Raises:
AssertionError: Shape of wavelet/scaling coefficients incorrect.
ValueError: If healpix sampling is provided to SSHT C backend.
Returns:
jnp.ndarray: Signal :math:`f` on the sphere with shape :math:`[n_{\theta}, n_{\phi}]`.
Notes:
[1] B. Leidstedt et. al., "S2LET: A code to perform fast wavelet analysis on the sphere", A&A, vol. 558, p. A128, 2013.
[2] J. McEwen et. al., "Directional spin wavelets on the sphere", arXiv preprint arXiv:1509.06749 (2015).
"""
if precomps == None and not use_c_backend:
if precomps is None:
precomps = construct.generate_wigner_precomputes(
L, N, J_min, lam, sampling, nside, True, reality
)
if use_c_backend and sampling.lower() == "healpix":
raise ValueError("SSHT C backend does not support healpix sampling.")

J = samples.j_max(L, lam)
Ls = samples.scal_bandlimit(L, J_min, lam, True)
flm = jnp.zeros((L, 2 * L - 1), dtype=jnp.complex128)

f_scal_lm = (
s2fft.forward(
f_scal.real if reality else f_scal,
Ls,
spin,
nside,
sampling,
"jax_ssht",
reality,
_ssht_backend=_ssht_backend,
)
if use_c_backend
else s2fft.forward_jax(f_scal, Ls, spin, nside, sampling, reality)
)
f_scal_lm = s2fft.forward_jax(f_scal, Ls, spin, nside, sampling, reality)

# Sum the all wavelet wigner coefficients for each lmn
# Note that almost the entire compute is concentrated at the highest J
for j in range(J_min, J + 1):
Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True)
temp = (
s2fft.wigner.forward(
f_wav[j - J_min],
Lj,
Nj,
nside,
sampling,
"jax_ssht",
reality,
_ssht_backend=_ssht_backend,
)
if use_c_backend
else s2fft.wigner.forward_jax(
f_wav[j - J_min],
Lj,
Nj,
nside,
sampling,
reality,
precomps[j - J_min],
L_lower=L0j,
)
temp = s2fft.wigner.forward_jax(
f_wav[j - J_min], Lj, Nj, nside, sampling, reality, precomps[j - J_min], L0j
)
flm = flm.at[L0j:Lj, L - Lj : L - 1 + Lj].add(
jnp.einsum(
Expand All @@ -146,22 +98,10 @@ def synthesis(
flm = flm.at[:Ls, L - Ls : L - 1 + Ls].add(
jnp.einsum("lm,l->lm", f_scal_lm, phi, optimize=True)
)
return (
s2fft.inverse(
flm,
L,
spin,
nside,
sampling,
"jax_ssht",
reality,
_ssht_backend=_ssht_backend,
)
if use_c_backend
else s2fft.inverse_jax(flm, L, spin, nside, sampling, reality)
)
return s2fft.inverse_jax(flm, L, spin, nside, sampling, reality)


@partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8))
def analysis(
f: jnp.ndarray,
L: int,
Expand All @@ -174,8 +114,6 @@ def analysis(
reality: bool = False,
filters: Tuple[jnp.ndarray] = None,
precomps: List[List[jnp.ndarray]] = None,
use_c_backend: bool = False,
_ssht_backend: int = 1,
) -> Tuple[jnp.ndarray]:
r"""Wavelet analysis from pixel space to wavelet space for complex signals.
Expand Down Expand Up @@ -206,54 +144,32 @@ def analysis(
precomps (List[jnp.ndarray]): Precomputed list of recursion coefficients. At most
of length :math:`L^2`, which is a minimal memory overhead.
use_c_backend (bool, optional): Execution mode in {"jax" = False, "jax_ssht" = True}.
Defaults to False.
_ssht_backend (int, optional, experimental): Whether to default to SSHT core
(set to 0) recursions or pick up ducc0 (set to 1) accelerated experimental
backend. Use with caution.
Returns:
f_wav (jnp.ndarray): Array of wavelet pixel-space coefficients
with shape :math:`[n_{J}, 2N-1, n_{\theta}, n_{\phi}]`.
f_scal (jnp.ndarray): Array of scaling pixel-space coefficients
with shape :math:`[n_{\theta}, n_{\phi}]`.
"""
if precomps == None and not use_c_backend:
if precomps is None:
precomps = construct.generate_wigner_precomputes(
L, N, J_min, lam, sampling, nside, False, reality
)
J = samples.j_max(L, lam)
Ls = samples.scal_bandlimit(L, J_min, lam, True)

f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, J, lam, True)
f_wav = samples.construct_f_jax(L, J_min, J, lam)

wav_lm = jnp.einsum(
"jln, l->jln",
jnp.conj(filters[0]),
8 * jnp.pi**2 / (2 * jnp.arange(L) + 1),
optimize=True,
)

flm = (
s2fft.forward(
f,
L,
spin,
nside,
sampling,
"jax_ssht",
reality,
_ssht_backend=_ssht_backend,
)
if use_c_backend
else s2fft.forward_jax(f, L, spin, nside, sampling, reality)
)
flm = s2fft.forward_jax(f, L, spin, nside, sampling, reality)

# Project all wigner coefficients for each lmn onto wavelet coefficients
# Note that almost the entire compute is concentrated at the highest J
f_wav = []
f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, J, lam, True)
for j in range(J_min, J + 1):
Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True)
f_wav_lmn[j - J_min] = (
Expand All @@ -269,27 +185,15 @@ def analysis(
)
)

f_wav[j - J_min] = (
s2fft.wigner.inverse(
f_wav_lmn[j - J_min],
Lj,
Nj,
nside,
sampling,
"jax_ssht",
reality,
_ssht_backend=_ssht_backend,
)
if use_c_backend
else s2fft.wigner.inverse_jax(
f_wav.append(
s2fft.wigner.inverse_jax(
f_wav_lmn[j - J_min],
Lj,
Nj,
nside,
sampling,
reality,
precomps[j - J_min],
False,
L0j,
)
)
Expand All @@ -302,23 +206,11 @@ def analysis(
if Ls == 1:
f_scal = temp * jnp.sqrt(1 / (4 * jnp.pi))
else:
f_scal = (
s2fft.inverse(
temp,
Ls,
spin,
nside,
sampling,
"jax_ssht",
reality,
_ssht_backend=_ssht_backend,
)
if use_c_backend
else s2fft.inverse_jax(temp, Ls, spin, nside, sampling, reality)
)
f_scal = s2fft.inverse_jax(temp, Ls, spin, nside, sampling, reality)
return f_wav, f_scal


@partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8))
def flm_to_analysis(
flm: jnp.ndarray,
L: int,
Expand All @@ -331,8 +223,6 @@ def flm_to_analysis(
reality: bool = False,
filters: Tuple[jnp.ndarray] = None,
precomps: List[List[jnp.ndarray]] = None,
use_c_backend: bool = False,
_ssht_backend: int = 1,
) -> Tuple[jnp.ndarray]:
r"""Wavelet analysis from pixel space to wavelet space for complex signals.
Expand Down Expand Up @@ -363,27 +253,16 @@ def flm_to_analysis(
precomps (List[jnp.ndarray]): Precomputed list of recursion coefficients. At most
of length :math:`L^2`, which is a minimal memory overhead.
use_c_backend (bool, optional): Execution mode in {"jax" = False, "jax_ssht" = True}.
Defaults to False.
_ssht_backend (int, optional, experimental): Whether to default to SSHT core
(set to 0) recursions or pick up ducc0 (set to 1) accelerated experimental
backend. Use with caution.
Returns:
f_wav (jnp.ndarray): Array of wavelet pixel-space coefficients
with shape :math:`[n_{J}, 2N-1, n_{\theta}, n_{\phi}]`.
"""
if precomps == None and not use_c_backend:
if precomps is None:
precomps = construct.generate_wigner_precomputes(
L, N, J_min, lam, sampling, nside, False, reality
)

J = J_max if J_max is not None else samples.j_max(L, lam)

f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, J, lam, True)
f_wav = samples.construct_f_jax(L, J_min, J, lam)

wav_lm = jnp.einsum(
"jln, l->jln",
jnp.conj(filters),
Expand All @@ -393,6 +272,8 @@ def flm_to_analysis(

# Project all wigner coefficients for each lmn onto wavelet coefficients
# Note that almost the entire compute is concentrated at the highest J
f_wav = []
f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, J, lam, True)
for j in range(J_min, J + 1):
Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True)
f_wav_lmn[j - J_min] = (
Expand All @@ -408,27 +289,15 @@ def flm_to_analysis(
)
)

f_wav[j - J_min] = jnp.array(
s2fft.wigner.inverse(
jnp.array(f_wav_lmn[j - J_min]),
Lj,
Nj,
nside,
sampling,
"jax_ssht",
reality,
_ssht_backend=_ssht_backend,
)
if use_c_backend
else s2fft.wigner.inverse_jax(
f_wav.append(
s2fft.wigner.inverse_jax(
f_wav_lmn[j - J_min],
Lj,
Nj,
nside,
sampling,
reality,
precomps[j - J_min],
False,
L0j,
)
)
Expand Down
Loading

0 comments on commit 9a52c1f

Please sign in to comment.