diff --git a/README.md b/README.md index c52bf3a5..a1fed65f 100644 --- a/README.md +++ b/README.md @@ -57,8 +57,9 @@ isolattitude sampling scheme. A number of sampling schemes are currently supported. The equiangular sampling schemes of [McEwen & Wiaux -(2012)](https://arxiv.org/abs/1110.6298) and [Driscoll & Healy -(1995)](https://www.sciencedirect.com/science/article/pii/S0196885884710086) +(2012)](https://arxiv.org/abs/1110.6298), [Driscoll & Healy +(1995)](https://www.sciencedirect.com/science/article/pii/S0196885884710086) +and [Gauss-Legendre (1986)](https://link.springer.com/article/10.1007/BF02519350) are supported, which exhibit associated sampling theorems and so harmonic transforms can be computed to machine precision. Note that the McEwen & Wiaux sampling theorem reduces the Nyquist rate on the sphere diff --git a/docs/index.rst b/docs/index.rst index 78c61339..c7b072fd 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -21,7 +21,8 @@ Algorithms |:zap:| ``S2FFT`` leverages new algorithmic structures that can he highly parallelised and distributed, and so map very well onto the architecture of hardware accelerators (i.e. GPUs and TPUs). In particular, these algorithms are based on new Wigner-d recursions -that are stable to high angular resolution :math:`L`. The diagram below illustrates the recursions (for further details see Price & McEwen 2023). +that are stable to high angular resolution :math:`L`. The diagram below illustrates the +recursions (for further details see Price & McEwen 2023). .. image:: ./assets/figures/Wigner_recursion_github_docs.png @@ -46,7 +47,8 @@ Sampling |:earth_africa:| The structure of the algorithms implemented in ``S2FFT`` can support any isolattitude sampling scheme. A number of sampling schemes are currently supported. -The equiangular sampling schemes of `McEwen & Wiaux (2012) `_ and `Driscoll & Healy (1995) `_ are supported, which exhibit associated sampling theorems and so harmonic transforms can be computed to machine precision. Note that the McEwen & Wiaux sampling theorem reduces the Nyquist rate on the sphere by a factor of two compared to the Driscoll & Healy approach, halving the number of spherical samples required. +The equiangular sampling schemes of `McEwen & Wiaux (2012) `_, +`Driscoll & Healy (1995) `_, and `Gauss-Legendre (1986) `_ are supported, which exhibit associated sampling theorems and so harmonic transforms can be computed to machine precision. Note that the McEwen & Wiaux sampling theorem reduces the Nyquist rate on the sphere by a factor of two compared to the Driscoll & Healy approach, halving the number of spherical samples required. The popular `HEALPix `_ sampling scheme (`Gorski et al. 2005 `_) is also supported. The HEALPix sampling does not exhibit a sampling theorem and so the corresponding harmonic transforms do not achieve machine precision but exhibit some error. However, the HEALPix sampling provides pixels of equal areas, which has many practical advantages. diff --git a/requirements/requirements-core.txt b/requirements/requirements-core.txt index a2a369e2..7c86b0ea 100644 --- a/requirements/requirements-core.txt +++ b/requirements/requirements-core.txt @@ -3,4 +3,7 @@ numpy>=1.20 colorlog pyyaml jax>=0.3.13 -jaxlib \ No newline at end of file +jaxlib + +# Remove when subpackage functionality is fixed. +torch \ No newline at end of file diff --git a/s2fft/_version.py b/s2fft/_version.py new file mode 100644 index 00000000..684acf06 --- /dev/null +++ b/s2fft/_version.py @@ -0,0 +1,16 @@ +# file generated by setuptools_scm +# don't change, don't track in version control +TYPE_CHECKING = False +if TYPE_CHECKING: + from typing import Tuple, Union + VERSION_TUPLE = Tuple[Union[int, str], ...] +else: + VERSION_TUPLE = object + +version: str +__version__: str +__version_tuple__: VERSION_TUPLE +version_tuple: VERSION_TUPLE + +__version__ = version = '1.0.1.dev12' +__version_tuple__ = version_tuple = (1, 0, 1, 'dev12') diff --git a/s2fft/base_transforms/spherical.py b/s2fft/base_transforms/spherical.py index 4625e7bc..6bfaf5de 100644 --- a/s2fft/base_transforms/spherical.py +++ b/s2fft/base_transforms/spherical.py @@ -27,7 +27,7 @@ def inverse( spin (int, optional): Harmonic spin. Defaults to 0. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. diff --git a/s2fft/base_transforms/wigner.py b/s2fft/base_transforms/wigner.py index ba5295bf..b7ed60fe 100644 --- a/s2fft/base_transforms/wigner.py +++ b/s2fft/base_transforms/wigner.py @@ -33,7 +33,7 @@ def inverse( L_lower (int, optional): Harmonic lower bound. Defaults to 0. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl"}. Defaults to "mw". reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to diff --git a/s2fft/precompute_transforms/construct.py b/s2fft/precompute_transforms/construct.py index 1ae4ea6d..1ca2f2db 100644 --- a/s2fft/precompute_transforms/construct.py +++ b/s2fft/precompute_transforms/construct.py @@ -34,7 +34,7 @@ def spin_spherical_kernel( Defaults to False. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int): HEALPix Nside resolution parameter. Only required if sampling="healpix". @@ -106,7 +106,7 @@ def spin_spherical_kernel_jax( Defaults to False. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int): HEALPix Nside resolution parameter. Only required if sampling="healpix". @@ -185,7 +185,7 @@ def wigner_kernel( Defaults to False. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int): HEALPix Nside resolution parameter. Only required if sampling="healpix". @@ -258,7 +258,7 @@ def wigner_kernel_jax( Defaults to False. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int): HEALPix Nside resolution parameter. Only required if sampling="healpix". diff --git a/s2fft/precompute_transforms/spherical.py b/s2fft/precompute_transforms/spherical.py index 5bf859c2..05bbb3d4 100644 --- a/s2fft/precompute_transforms/spherical.py +++ b/s2fft/precompute_transforms/spherical.py @@ -32,7 +32,7 @@ def inverse( kernel (np.ndarray, optional): Wigner-d kernel. Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. @@ -88,7 +88,7 @@ def inverse_transform( L (int): Harmonic band-limit. sampling (str): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. + {"mw", "mwss", "dh", "gl", "healpix"}. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. @@ -152,7 +152,7 @@ def inverse_transform_jax( L (int): Harmonic band-limit. sampling (str): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. + {"mw", "mwss", "dh", "gl", "healpix"}. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. @@ -277,7 +277,7 @@ def forward( kernel (np.ndarray, optional): Wigner-d kernel. Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. @@ -333,7 +333,7 @@ def forward_transform( L (int): Harmonic band-limit. sampling (str): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. + {"mw", "mwss", "dh", "gl", "healpix"}. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. @@ -401,7 +401,7 @@ def forward_transform_jax( L (int): Harmonic band-limit. sampling (str): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. + {"mw", "mwss", "dh", "gl", "healpix"}. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. diff --git a/s2fft/precompute_transforms/wigner.py b/s2fft/precompute_transforms/wigner.py index 3a33103b..8e266f51 100644 --- a/s2fft/precompute_transforms/wigner.py +++ b/s2fft/precompute_transforms/wigner.py @@ -32,7 +32,7 @@ def inverse( kernel (np.ndarray, optional): Wigner-d kernel. Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. @@ -89,7 +89,7 @@ def inverse_transform( N (int): Directional band-limit. sampling (str): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. + {"mw", "mwss", "dh", "gl", "healpix"}. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. @@ -150,7 +150,7 @@ def inverse_transform_jax( N (int): Directional band-limit. sampling (str): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. + {"mw", "mwss", "dh", "gl", "healpix"}. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. @@ -306,7 +306,7 @@ def forward( kernel (np.ndarray, optional): Wigner-d kernel. Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. @@ -361,7 +361,7 @@ def forward_transform( N (int): Directional band-limit. sampling (str): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. + {"mw", "mwss", "dh", "gl", "healpix"}. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. @@ -442,7 +442,7 @@ def forward_transform_jax( N (int): Directional band-limit. sampling (str): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. + {"mw", "mwss", "dh", "gl", "healpix"}. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. diff --git a/s2fft/sampling/s2_samples.py b/s2fft/sampling/s2_samples.py index dbc6bec3..398bec5b 100644 --- a/s2fft/sampling/s2_samples.py +++ b/s2fft/sampling/s2_samples.py @@ -10,7 +10,7 @@ def ntheta(L: int = None, sampling: str = "mw", nside: int = None) -> int: Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. @@ -31,7 +31,7 @@ def ntheta(L: int = None, sampling: str = "mw", nside: int = None) -> int: f"Sampling scheme sampling={sampling} with L={L} not supported" ) - if sampling.lower() == "mw": + if sampling.lower() in ["mw", "gl"]: return L elif sampling.lower() == "mwss": @@ -93,7 +93,7 @@ def nphi_equiang(L: int, sampling: str = "mw") -> int: L (int): Harmonic band-limit. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl"}. Defaults to "mw". Raises: ValueError: HEALPix sampling scheme. @@ -104,7 +104,7 @@ def nphi_equiang(L: int, sampling: str = "mw") -> int: int: Number of :math:`\phi` samples. """ - if sampling.lower() == "mw": + if sampling.lower() in ["mw", "gl"]: return 2 * L - 1 elif sampling.lower() == "mwss": @@ -129,7 +129,7 @@ def ftm_shape(L: int, sampling: str = "mw", nside: int = None) -> Tuple[int, int L (int): Harmonic band-limit. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. @@ -144,7 +144,7 @@ def ftm_shape(L: int, sampling: str = "mw", nside: int = None) -> Tuple[int, int if sampling.lower() in ["mwss", "healpix"]: return ntheta(L, sampling, nside), 2 * L - elif sampling.lower() in ["mw", "dh"]: + elif sampling.lower() in ["mw", "dh", "gl"]: return ntheta(L, sampling, nside), 2 * L - 1 else: @@ -203,7 +203,7 @@ def thetas(L: int = None, sampling: str = "mw", nside: int = None) -> np.ndarray Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. @@ -211,6 +211,9 @@ def thetas(L: int = None, sampling: str = "mw", nside: int = None) -> np.ndarray Returns: np.ndarray: Array of :math:`\theta` samples for given sampling scheme. """ + if sampling.lower() == "gl": + return np.flip(np.arccos(np.polynomial.legendre.leggauss(L)[0])) + t = np.arange(0, ntheta(L=L, sampling=sampling, nside=nside)).astype(np.float64) return t2theta(t, L, sampling, nside) @@ -228,7 +231,7 @@ def t2theta( Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. @@ -346,7 +349,7 @@ def phis_equiang(L: int, sampling: str = "mw") -> np.ndarray: L (int, optional): Harmonic band-limit. sampling (str, optional): Sampling scheme. Supported equiangular sampling - schemes include {"mw", "mwss", "dh"}. Defaults to "mw". + schemes include {"mw", "mwss", "dh", "gl"}. Defaults to "mw". Returns: np.ndarray: Array of :math:`\phi` samples for given sampling scheme. @@ -365,7 +368,7 @@ def p2phi_equiang(L: int, p: int, sampling: str = "mw") -> np.ndarray: p (int): :math:`\phi` index. sampling (str, optional): Sampling scheme. Supported equiangular sampling - schemes include {"mw", "mwss", "dh"}. Defaults to "mw". + schemes include {"mw", "mwss", "dh", "gl"}. Defaults to "mw". Raises: ValueError: HEALPix sampling not support (only equiangular schemes supported). @@ -376,7 +379,7 @@ def p2phi_equiang(L: int, p: int, sampling: str = "mw") -> np.ndarray: np.ndarray: :math:`\phi` sample(s) for given sampling scheme. """ - if sampling.lower() == "mw": + if sampling.lower() in ["mw", "gl"]: return 2 * p * np.pi / (2 * L - 1) elif sampling.lower() == "mwss": @@ -431,7 +434,7 @@ def f_shape(L: int = None, sampling: str = "mw", nside: int = None) -> Tuple[int L (int, optional): Harmonic band-limit. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. diff --git a/s2fft/sampling/so3_samples.py b/s2fft/sampling/so3_samples.py index 5c56b550..63a29784 100644 --- a/s2fft/sampling/so3_samples.py +++ b/s2fft/sampling/so3_samples.py @@ -23,7 +23,7 @@ def f_shape( N (int): Directional band-limit. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. @@ -35,15 +35,12 @@ def f_shape( Tuple[int,int,int]: Shape of pixel-space sampling of rotation group :math:`SO(3)`. """ - if sampling in ["mw", "mwss", "dh"]: + if sampling in ["mw", "mwss", "dh", "gl"]: return _ngamma(N), _nbeta(L, sampling), _nalpha(L, sampling) elif sampling.lower() == "healpix": return _ngamma(N), 12 * nside**2 - elif sampling.lower() == "healpix": - return 12 * nside**2, _ngamma(N) - else: raise ValueError(f"Sampling scheme sampling={sampling} not supported") @@ -76,7 +73,7 @@ def fnab_shape( N (int): Directional band-limit. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. @@ -91,7 +88,7 @@ def fnab_shape( if sampling.lower() in ["mwss", "healpix"]: return _ngamma(N), samples.ntheta(L, sampling, nside), 2 * L - elif sampling.lower() in ["mw", "dh"]: + elif sampling.lower() in ["mw", "dh", "gl"]: return _ngamma(N), samples.ntheta(L, sampling, nside), 2 * L - 1 else: @@ -121,7 +118,7 @@ def _nalpha(L: int, sampling: str = "mw") -> int: L (int): Harmonic band-limit. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl"}. Defaults to "mw". Raises: ValueError: Unknown sampling scheme. @@ -129,7 +126,7 @@ def _nalpha(L: int, sampling: str = "mw") -> int: Returns: int: Number of :math:`\alpha` samples. """ - if sampling.lower() in ["mw", "dh"]: + if sampling.lower() in ["mw", "dh", "gl"]: return 2 * L - 1 elif sampling.lower() == "mwss": @@ -146,7 +143,7 @@ def _nbeta(L: int, sampling: str = "mw") -> int: L (int): Harmonic band-limit. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl"}. Defaults to "mw". Raises: ValueError: Unknown sampling scheme. @@ -154,7 +151,7 @@ def _nbeta(L: int, sampling: str = "mw") -> int: Returns: int: Number of :math:`\beta` samples. """ - if sampling.lower() == "mw": + if sampling.lower() in ["mw", "gl"]: return L elif sampling.lower() == "mwss": diff --git a/s2fft/transforms/otf_recursions.py b/s2fft/transforms/otf_recursions.py index 74b80b9c..530c3501 100644 --- a/s2fft/transforms/otf_recursions.py +++ b/s2fft/transforms/otf_recursions.py @@ -46,7 +46,7 @@ def inverse_latitudinal_step( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to @@ -195,7 +195,7 @@ def inverse_latitudinal_step_jax( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to @@ -441,7 +441,7 @@ def forward_latitudinal_step( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to @@ -597,7 +597,7 @@ def forward_latitudinal_step_jax( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to diff --git a/s2fft/transforms/spherical.py b/s2fft/transforms/spherical.py index eb53abc5..43e520c7 100644 --- a/s2fft/transforms/spherical.py +++ b/s2fft/transforms/spherical.py @@ -40,7 +40,7 @@ def inverse( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". @@ -112,7 +112,7 @@ def inverse_numpy( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". @@ -211,7 +211,7 @@ def inverse_jax( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". @@ -334,7 +334,7 @@ def forward( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". @@ -406,7 +406,7 @@ def forward_numpy( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". @@ -533,7 +533,7 @@ def forward_jax( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". diff --git a/s2fft/transforms/wigner.py b/s2fft/transforms/wigner.py index 8bb39e14..1db8382d 100644 --- a/s2fft/transforms/wigner.py +++ b/s2fft/transforms/wigner.py @@ -43,7 +43,7 @@ def inverse( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". @@ -124,7 +124,7 @@ def inverse_numpy( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". @@ -211,7 +211,7 @@ def inverse_jax( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". @@ -366,7 +366,7 @@ def forward( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". @@ -445,7 +445,7 @@ def forward_numpy( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". @@ -541,7 +541,7 @@ def forward_jax( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". diff --git a/s2fft/utils/quadrature.py b/s2fft/utils/quadrature.py index 34d1262e..571effcf 100644 --- a/s2fft/utils/quadrature.py +++ b/s2fft/utils/quadrature.py @@ -16,7 +16,7 @@ def quad_weights_transform( L (int): Harmonic band-limit. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mwss", "dh", "healpix}. Defaults to "mwss". + {"mwss", "dh", "gl", "healpix}. Defaults to "mwss". spin (int, optional): Harmonic spin. Defaults to 0. @@ -38,6 +38,9 @@ def quad_weights_transform( elif sampling.lower() == "dh": return quad_weights_dh(L) + elif sampling.lower() == "gl": + return quad_weights_gl(L) + elif sampling.lower() == "healpix": return quad_weights_hp(nside) @@ -56,7 +59,7 @@ def quad_weights( Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". spin (int, optional): Harmonic spin. Defaults to 0. @@ -80,6 +83,9 @@ def quad_weights( elif sampling.lower() == "dh": return quad_weights_dh(L) + elif sampling.lower() == "gl": + return quad_weights_gl(L) + elif sampling.lower() == "healpix": return quad_weights_hp(nside) @@ -112,6 +118,43 @@ def quad_weights_hp(nside: int) -> np.ndarray: return hp_weights +def quad_weights_gl(L: int) -> np.ndarray: + r"""Compute GL quadrature weights for :math:`\theta` and :math:`\phi` integration. + + Args: + L (int): Harmonic band-limit. + + Returns: + np.ndarray: Weights computed for each :math:`\theta` (weights are identical + as :math:`\phi` varies for given :math:`\theta`). + """ + x1, x2 = -1.0, 1.0 + ntheta = samples.ntheta(L, "gl") + weights = np.zeros(ntheta, dtype=np.float64) + + m = int((L + 1) / 2) + x1 = 0.5 * (x2 - x1) + + i = np.arange(1, m + 1) + z = np.cos(np.pi * (i - 0.25) / (L + 0.5)) + z1 = 2.0 + while np.max(np.abs(z - z1)) > 1e-14: + p1 = 1.0 + p2 = 0.0 + for j in range(1, L + 1): + p3 = p2 + p2 = p1 + p1 = ((2.0 * j - 1.0) * z * p2 - (j - 1.0) * p3) / j + pp = L * (z * p1 - p2) / (z * z - 1.0) + z1 = z + z = z1 - p1 / pp + + weights[i - 1] = 2.0 * x1 / ((1.0 - z**2) * pp * pp) + weights[L + 1 - i - 1] = weights[i - 1] + + return weights * 2 * np.pi / (2 * L - 1) + + def quad_weights_dh(L: int) -> np.ndarray: r"""Compute DH quadrature weights for :math:`\theta` and :math:`\phi` integration. diff --git a/s2fft/utils/quadrature_jax.py b/s2fft/utils/quadrature_jax.py index a1b762a9..5ab7672a 100644 --- a/s2fft/utils/quadrature_jax.py +++ b/s2fft/utils/quadrature_jax.py @@ -1,3 +1,4 @@ +import jax from jax import jit from functools import partial @@ -20,7 +21,7 @@ def quad_weights_transform( L (int): Harmonic band-limit. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mwss", "dh", "healpix}. Defaults to "mwss". + {"mwss", "dh", "gl", "healpix}. Defaults to "mwss". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. @@ -40,6 +41,9 @@ def quad_weights_transform( elif sampling.lower() == "dh": return quad_weights_dh(L) + elif sampling.lower() == "gl": + return quad_weights_gl(L) + elif sampling.lower() == "healpix": return quad_weights_hp(nside) @@ -58,7 +62,7 @@ def quad_weights(L: int = None, sampling: str = "mw", nside: int = None) -> jnp. Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". spin (int, optional): Harmonic spin. Defaults to 0. @@ -82,6 +86,9 @@ def quad_weights(L: int = None, sampling: str = "mw", nside: int = None) -> jnp. elif sampling.lower() == "dh": return quad_weights_dh(L) + elif sampling.lower() == "gl": + return quad_weights_gl(L) + elif sampling.lower() == "healpix": return quad_weights_hp(nside) @@ -111,6 +118,55 @@ def quad_weights_hp(nside: int) -> jnp.ndarray: return jnp.ones(rings, dtype=jnp.float64) * 4 * jnp.pi / npix +@partial(jit, static_argnums=(0)) +def quad_weights_gl(L: int) -> jnp.ndarray: + r"""Compute GL quadrature weights for :math:`\theta` and :math:`\phi` integration. + + Args: + L (int): Harmonic band-limit. + + Returns: + jnp.ndarray: Weights computed for each :math:`\theta` (weights are identical + as :math:`\phi` varies for given :math:`\theta`). + """ + x1, x2 = -1.0, 1.0 + ntheta = samples.ntheta(L, "gl") + weights = jnp.zeros(ntheta, dtype=jnp.float64) + + m = int((L + 1) / 2) + x1 = 0.5 * (x2 - x1) + i = jnp.arange(1, m + 1) + z = jnp.cos(jnp.pi * (i - 0.25) / (L + 0.5)) + z1 = 2.0 * jnp.ones_like(z) + pp = jnp.zeros_like(z) + + def optimizer(z, z1, pp): + def cond(arg): + z, z1, pp = arg + return jnp.max(jnp.abs(z - z1)) > 1e-14 + + def body(arg): + z, z1, pp = arg + p1 = 1.0 + p2 = 0.0 + for j in range(1, L + 1): + p3 = p2 + p2 = p1 + p1 = ((2.0 * j - 1.0) * z * p2 - (j - 1.0) * p3) / j + pp = L * (z * p1 - p2) / (z * z - 1.0) + z1 = z + z = z1 - p1 / pp + return z, z1, pp + + return jax.lax.while_loop(cond, body, (z, z1, pp)) + + z, z1, pp = optimizer(z, z1, pp) + weights = weights.at[i - 1].set(2.0 * x1 / ((1.0 - z**2) * pp * pp)) + weights = weights.at[L + 1 - i - 1].set(weights[i - 1]) + + return weights * 2 * jnp.pi / (2 * L - 1) + + @partial(jit, static_argnums=(0)) def quad_weights_dh(L: int) -> jnp.ndarray: r"""Compute DH quadrature weights for :math:`\theta` and :math:`\phi` integration. diff --git a/s2fft/utils/quadrature_torch.py b/s2fft/utils/quadrature_torch.py index 8743f6e6..8381642c 100644 --- a/s2fft/utils/quadrature_torch.py +++ b/s2fft/utils/quadrature_torch.py @@ -16,7 +16,7 @@ def quad_weights_transform( L (int): Harmonic band-limit. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mwss", "dh", "healpix}. Defaults to "mwss". + {"mwss", "dh", "gl", "healpix}. Defaults to "mwss". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. @@ -36,6 +36,9 @@ def quad_weights_transform( elif sampling.lower() == "dh": return quad_weights_dh(L) + elif sampling.lower() == "gl": + return quad_weights_gl(L) + elif sampling.lower() == "healpix": return quad_weights_hp(nside) @@ -55,7 +58,7 @@ def quad_weights( Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". spin (int, optional): Harmonic spin. Defaults to 0. @@ -79,6 +82,9 @@ def quad_weights( elif sampling.lower() == "dh": return quad_weights_dh(L) + elif sampling.lower() == "gl": + return quad_weights_gl(L) + elif sampling.lower() == "healpix": return quad_weights_hp(nside) @@ -107,6 +113,43 @@ def quad_weights_hp(nside: int) -> torch.tensor: return torch.ones(rings, dtype=torch.float64) * 4 * torch.pi / npix +def quad_weights_gl(L: int) -> torch.tensor: + r"""Compute GL quadrature weights for :math:`\theta` and :math:`\phi` integration. + + Args: + L (int): Harmonic band-limit. + + Returns: + np.ndarray: Weights computed for each :math:`\theta` (weights are identical + as :math:`\phi` varies for given :math:`\theta`). + """ + x1, x2 = -1.0, 1.0 + ntheta = samples.ntheta(L, "gl") + weights = torch.zeros(ntheta, dtype=torch.float64) + + m = int((L + 1) / 2) + x1 = 0.5 * (x2 - x1) + + i = torch.arange(1, m + 1) + z = torch.cos(torch.pi * (i.type(torch.float64) - 0.25) / (L + 0.5)) + z1 = 2.0 + while torch.max(torch.abs(z - z1)) > 1e-14: + p1 = 1.0 + p2 = 0.0 + for j in range(1, L + 1): + p3 = p2 + p2 = p1 + p1 = ((2.0 * j - 1.0) * z * p2 - (j - 1.0) * p3) / j + pp = L * (z * p1 - p2) / (z * z - 1.0) + z1 = z + z = z1 - p1 / pp + + weights[i - 1] = 2.0 * x1 / ((1.0 - z**2) * pp * pp) + weights[L + 1 - i - 1] = weights[i - 1] + + return weights * 2 * torch.pi / (2 * L - 1) + + def quad_weights_dh(L: int) -> torch.tensor: r"""Compute DH quadrature weights for :math:`\theta` and :math:`\phi` integration. Torch implementation of :func:`s2fft.quadrature.quad_weights_dh`. diff --git a/tests/test_quadrature.py b/tests/test_quadrature.py index 126a11a7..93c826e6 100644 --- a/tests/test_quadrature.py +++ b/tests/test_quadrature.py @@ -9,7 +9,7 @@ @pytest.mark.parametrize("L", [5, 6]) -@pytest.mark.parametrize("sampling", ["mw", "mwss", "dh"]) +@pytest.mark.parametrize("sampling", ["mw", "mwss", "dh", "gl"]) @pytest.mark.parametrize("method", ["numpy", "jax", "torch"]) def test_quadrature_mw_weights(flm_generator, L: int, sampling: str, method: str): spin = 0 @@ -31,6 +31,9 @@ def test_quadrature_mw_weights(flm_generator, L: int, sampling: str, method: str nphi = samples.nphi_equiang(L, sampling) Q = q.dot(np.ones((1, nphi))) + print(q.shape) + print(Q.shape) + print(f.shape) integral_check = np.sum(Q * f) np.testing.assert_allclose(integral, integral_check, atol=1e-14) diff --git a/tests/test_samples.py b/tests/test_samples.py index 44f18687..95aa9a1a 100644 --- a/tests/test_samples.py +++ b/tests/test_samples.py @@ -9,7 +9,7 @@ @pytest.mark.parametrize("L", [15, 16]) -@pytest.mark.parametrize("sampling", ["mw", "mwss", "dh"]) +@pytest.mark.parametrize("sampling", ["mw", "mwss", "dh", "gl"]) def test_samples_n_and_angles(L: int, sampling: str): # Test ntheta and nphi ntheta = samples.ntheta(L, sampling) @@ -18,8 +18,11 @@ def test_samples_n_and_angles(L: int, sampling: str): assert (ntheta, nphi) == pytest.approx((ntheta_ssht, nphi_ssht)) # Test thetas and phis - t = np.arange(0, ntheta) - thetas = samples.t2theta(t, L, sampling) + if sampling.lower() == "gl": + thetas = samples.thetas(L, sampling) + else: + t = np.arange(0, ntheta) + thetas = samples.t2theta(t, L, sampling) p = np.arange(0, nphi) phis = samples.p2phi_equiang(L, p, sampling) thetas_ssht, phis_ssht = ssht.sample_positions(L, sampling.upper()) diff --git a/tests/test_spherical_base.py b/tests/test_spherical_base.py index b888f0a2..9c79df75 100644 --- a/tests/test_spherical_base.py +++ b/tests/test_spherical_base.py @@ -11,7 +11,7 @@ spin_to_test = [-2, 0, 1] nside_to_test = [4, 5] L_to_nside_ratio = [2, 3] -sampling_to_test = ["mw", "mwss", "dh"] +sampling_to_test = ["mw", "mwss", "dh", "gl"] method_to_test = ["direct", "sov", "sov_fft", "sov_fft_vectorized"] reality_to_test = [False, True] diff --git a/tests/test_spherical_custom_grads.py b/tests/test_spherical_custom_grads.py index db52556d..2d34212d 100644 --- a/tests/test_spherical_custom_grads.py +++ b/tests/test_spherical_custom_grads.py @@ -12,7 +12,7 @@ L_lower_to_test = [2] spin_to_test = [-2, 0, 1] nside_to_test = [8] -sampling_to_test = ["mw", "mwss", "dh"] +sampling_to_test = ["mw", "mwss", "dh", "gl"] reality_to_test = [False, True] diff --git a/tests/test_spherical_precompute.py b/tests/test_spherical_precompute.py index 79e12390..de9d5662 100644 --- a/tests/test_spherical_precompute.py +++ b/tests/test_spherical_precompute.py @@ -9,7 +9,7 @@ spin_to_test = [-2, 0, 1] nside_to_test = [4, 5] L_to_nside_ratio = [2, 3] -sampling_to_test = ["mw", "mwss", "dh"] +sampling_to_test = ["mw", "mwss", "dh", "gl"] reality_to_test = [True, False] methods_to_test = ["numpy", "jax", "torch"] diff --git a/tests/test_spherical_transform.py b/tests/test_spherical_transform.py index cdb822cb..eaf03306 100644 --- a/tests/test_spherical_transform.py +++ b/tests/test_spherical_transform.py @@ -14,7 +14,7 @@ L_lower_to_test = [0, 2] spin_to_test = [-2, 0, 1] nside_to_test = [4, 5] -sampling_to_test = ["mw", "mwss", "dh"] +sampling_to_test = ["mw", "mwss", "dh", "gl"] method_to_test = ["numpy", "jax"] reality_to_test = [False, True] multiple_gpus = [False, True] diff --git a/tests/test_wigner_base.py b/tests/test_wigner_base.py index 64404ac3..e1f5b67a 100644 --- a/tests/test_wigner_base.py +++ b/tests/test_wigner_base.py @@ -9,7 +9,7 @@ N_to_test = [2, 3] L_lower_to_test = [0, 2] sampling_schemes_so3 = ["mw", "mwss"] -sampling_schemes = ["mw", "mwss", "dh"] +sampling_schemes = ["mw", "mwss", "dh", "gl"] reality_to_test = [False, True] diff --git a/tests/test_wigner_custom_grads.py b/tests/test_wigner_custom_grads.py index ffe491ab..780e81b3 100644 --- a/tests/test_wigner_custom_grads.py +++ b/tests/test_wigner_custom_grads.py @@ -11,7 +11,7 @@ L_to_test = [6] N_to_test = [3] L_lower_to_test = [1] -sampling_to_test = ["mw", "mwss", "dh"] +sampling_to_test = ["mw", "mwss", "dh", "gl"] reality_to_test = [False, True] diff --git a/tests/test_wigner_precompute.py b/tests/test_wigner_precompute.py index 8b8bc15c..43301f2b 100644 --- a/tests/test_wigner_precompute.py +++ b/tests/test_wigner_precompute.py @@ -11,7 +11,7 @@ nside_to_test = [4, 6] L_to_nside_ratio = [2] reality_to_test = [False, True] -sampling_schemes = ["mw", "mwss", "dh"] +sampling_schemes = ["mw", "mwss", "dh", "gl"] methods_to_test = ["numpy", "jax", "torch"] diff --git a/tests/test_wigner_recursions.py b/tests/test_wigner_recursions.py index 342c1f42..5cf059ce 100644 --- a/tests/test_wigner_recursions.py +++ b/tests/test_wigner_recursions.py @@ -10,7 +10,7 @@ L_to_test = [6, 7] spin_to_test = [-2, 0, 1] -sampling_schemes = ["mw", "mwss", "dh", "healpix"] +sampling_schemes = ["mw", "mwss", "dh", "gl", "healpix"] def test_trapani_with_ssht(): diff --git a/tests/test_wigner_transform.py b/tests/test_wigner_transform.py index 37e34b18..485b34a3 100644 --- a/tests/test_wigner_transform.py +++ b/tests/test_wigner_transform.py @@ -14,7 +14,7 @@ L_to_test = [6, 7] N_to_test = [2] L_lower_to_test = [0, 2] -sampling_to_test = ["mw", "mwss", "dh"] +sampling_to_test = ["mw", "mwss", "dh", "gl"] method_to_test = ["numpy", "jax"] reality_to_test = [False, True] multiple_gpus = [False, True]