diff --git a/README.md b/README.md index 17512736..8ee6ff83 100644 --- a/README.md +++ b/README.md @@ -150,3 +150,4 @@ A list of video tutorials to learn more about PyLops: * Wei Zhang, ZhangWeiGeo * Fedor Goncharov, fedor-goncharov * Alex Rakowski, alex-rakowski +* David Sollberger, solldavid diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 6879b247..e86981fe 100755 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -102,6 +102,7 @@ Signal processing Shift DWT DWT2D + DWTND DCT DTCWT Seislet diff --git a/docs/source/credits.rst b/docs/source/credits.rst index 6310549d..cea46fe0 100755 --- a/docs/source/credits.rst +++ b/docs/source/credits.rst @@ -22,3 +22,4 @@ Contributors * `Wei Zhang `_, ZhangWeiGeo * `Fedor Goncharov `_, fedor-goncharov * `Alex Rakowski `_, alex-rakowski +* `David Sollberger `_, solldavid diff --git a/examples/plot_wavelet.py b/examples/plot_wavelet.py index d080b025..4c410112 100644 --- a/examples/plot_wavelet.py +++ b/examples/plot_wavelet.py @@ -1,8 +1,9 @@ """ Wavelet transform ================= -This example shows how to use the :py:class:`pylops.DWT` and -:py:class:`pylops.DWT2D` operators to perform 1- and 2-dimensional DWT. +This example shows how to use the :py:class:`pylops.DWT`, +:py:class:`pylops.DWT2D`, and :py:class:`pylops.DWTND` operators +to perform 1-, 2-, and N-dimensional DWT. """ import matplotlib.pyplot as plt import numpy as np @@ -67,3 +68,46 @@ axs[1, 1].set_title("DWT2 coefficients (zeroed)") axs[1, 1].axis("tight") plt.tight_layout() + +############################################################################### +# Let us now try the same with a 3D volumetric model, where we use the +# N-dimensional DWT. This time, we only retain 10 percent of the coefficients +# of the DWT. + +nx = 128 +ny = 256 +nz = 128 + +x = np.arange(nx) +y = np.arange(ny) +z = np.arange(nz) + +xx, yy, zz = np.meshgrid(x, y, z, indexing="ij") +# Generate a 3D model with two block anomalies +m = np.ones_like(xx, dtype=float) +block1 = (xx > 10) & (xx < 60) & (yy > 100) & (yy < 150) & (zz > 20) & (zz < 70) +block2 = (xx > 70) & (xx < 80) & (yy > 100) & (yy < 200) & (zz > 10) & (zz < 50) +m[block1] = 1.2 +m[block2] = 0.8 +Wop = pylops.signalprocessing.DWTND((nx, ny, nz), wavelet="haar", level=3) +y = Wop * m + +ratio = 0.1 +yf = y.copy() +yf.flat[int(ratio * y.size) :] = 0 +iminv = Wop.H * yf + +fig, axs = plt.subplots(2, 2, figsize=(6, 6)) +axs[0, 0].imshow(m[:, :, 30], cmap="gray") +axs[0, 0].set_title("Model (Slice at z=30)") +axs[0, 0].axis("tight") +axs[0, 1].imshow(y[:, :, 90], cmap="gray_r") +axs[0, 1].set_title("DWTNT coefficients") +axs[0, 1].axis("tight") +axs[1, 0].imshow(iminv[:, :, 30], cmap="gray") +axs[1, 0].set_title("Reconstructed model (Slice at z=30)") +axs[1, 0].axis("tight") +axs[1, 1].imshow(yf[:, :, 90], cmap="gray_r") +axs[1, 1].set_title("DWTNT coefficients (zeroed)") +axs[1, 1].axis("tight") +plt.tight_layout() diff --git a/pylops/signalprocessing/__init__.py b/pylops/signalprocessing/__init__.py index 7137b586..a8e5ed65 100755 --- a/pylops/signalprocessing/__init__.py +++ b/pylops/signalprocessing/__init__.py @@ -23,6 +23,7 @@ Shift Fractional Shift operator. DWT One dimensional Wavelet operator. DWT2D Two dimensional Wavelet operator. + DWTND N-dimensional Wavelet operator. DCT Discrete Cosine Transform. DTCWT Dual-Tree Complex Wavelet Transform. Radon2D Two dimensional Radon transform. @@ -61,6 +62,7 @@ from .fredholm1 import * from .dwt import * from .dwt2d import * +from .dwtnd import * from .seislet import * from .dct import * from .dtcwt import * @@ -93,6 +95,7 @@ "Fredholm1", "DWT", "DWT2D", + "DWTND", "Seislet", "DCT", "DTCWT", diff --git a/pylops/signalprocessing/dwtnd.py b/pylops/signalprocessing/dwtnd.py new file mode 100644 index 00000000..af43bb0d --- /dev/null +++ b/pylops/signalprocessing/dwtnd.py @@ -0,0 +1,138 @@ +__all__ = ["DWTND"] + +import logging +from math import ceil, log + +import numpy as np + +from pylops import LinearOperator +from pylops.basicoperators import Pad +from pylops.utils import deps +from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray + +from .dwt import _adjointwavelet, _checkwavelet + +pywt_message = deps.pywt_import("the dwtnd module") + +if pywt_message is None: + import pywt + +logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING) + + +class DWTND(LinearOperator): + """N-dimensional Wavelet operator. + + Apply ND-Wavelet transform along N ``axes`` of a + multi-dimensional array of size ``dims``. + + Note that the Wavelet operator is an overload of the ``pywt`` + implementation of the wavelet transform. Refer to + https://pywavelets.readthedocs.io for a detailed description of the + input parameters. + + Defaults to a 3D wavelet transform along the last three dimensions + of the input array. + + Parameters + ---------- + dims : :obj:`tuple` + Number of samples for each dimension + axes : :obj:`int`, optional + Axis along which DWTND is applied + wavelet : :obj:`str`, optional + Name of wavelet type. Use :func:`pywt.wavelist(kind='discrete')` for + a list of available wavelets. + level : :obj:`int`, optional + Number of scaling levels (must be >=0). + dtype : :obj:`str`, optional + Type of elements in input array. + name : :obj:`str`, optional + Name of operator (to be used by :func:`pylops.utils.describe.describe`) + + Attributes + ---------- + shape : :obj:`tuple` + Operator shape + explicit : :obj:`bool` + Operator contains a matrix that can be solved explicitly + (``True``) or not (``False``) + + Raises + ------ + ModuleNotFoundError + If ``pywt`` is not installed + ValueError + If ``wavelet`` does not belong to ``pywt.families`` + + Notes + ----- + The Wavelet operator applies the N-dimensional multilevel Discrete + Wavelet Transform (DWTN) in forward mode and the N-dimensional multilevel + Inverse Discrete Wavelet Transform (IDWTN) in adjoint mode. + + """ + + def __init__( + self, + dims: InputDimsLike, + axes: InputDimsLike = (-3, -2, -1), + wavelet: str = "haar", + level: int = 1, + dtype: DTypeLike = "float64", + name: str = "D", + ) -> None: + if pywt_message is not None: + raise ModuleNotFoundError(pywt_message) + _checkwavelet(wavelet) + + # define padding for length to be power of 2 + ndimpow2 = [max(2 ** ceil(log(dims[ax], 2)), 2**level) for ax in axes] + pad = [(0, 0)] * len(dims) + for i, ax in enumerate(axes): + pad[ax] = (0, ndimpow2[i] - dims[ax]) + self.pad = Pad(dims, pad) + self.axes = axes + dimsd = list(dims) + for i, ax in enumerate(axes): + dimsd[ax] = ndimpow2[i] + super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dimsd, name=name) + + # apply transform once again to find out slices + _, self.sl = pywt.coeffs_to_array( + pywt.wavedecn( + np.ones(self.dimsd), + wavelet=wavelet, + level=level, + mode="periodization", + axes=self.axes, + ), + axes=self.axes, + ) + self.wavelet = wavelet + self.waveletadj = _adjointwavelet(wavelet) + self.level = level + + def _matvec(self, x: NDArray) -> NDArray: + x = self.pad.matvec(x) + x = np.reshape(x, self.dimsd) + y = pywt.coeffs_to_array( + pywt.wavedecn( + x, + wavelet=self.wavelet, + level=self.level, + mode="periodization", + axes=self.axes, + ), + axes=(self.axes), + )[0] + return y.ravel() + + def _rmatvec(self, x: NDArray) -> NDArray: + x = np.reshape(x, self.dimsd) + x = pywt.array_to_coeffs(x, self.sl, output_format="wavedecn") + y = pywt.waverecn( + x, wavelet=self.waveletadj, mode="periodization", axes=self.axes + ) + y = self.pad.rmatvec(y.ravel()) + return y diff --git a/pytests/test_dwts.py b/pytests/test_dwts.py index 0fca4526..09f567dc 100755 --- a/pytests/test_dwts.py +++ b/pytests/test_dwts.py @@ -3,11 +3,20 @@ from numpy.testing import assert_array_almost_equal from scipy.sparse.linalg import lsqr -from pylops.signalprocessing import DWT, DWT2D +from pylops.signalprocessing import DWT, DWT2D, DWTND from pylops.utils import dottest par1 = {"ny": 7, "nx": 9, "nt": 10, "imag": 0, "dtype": "float32"} # real par2 = {"ny": 7, "nx": 9, "nt": 10, "imag": 1j, "dtype": "complex64"} # complex +par3 = {"ny": 7, "nx": 9, "nz": 9, "nt": 10, "imag": 0, "dtype": "float32"} # real 4D +par4 = { + "ny": 7, + "nx": 9, + "nz": 9, + "nt": 10, + "imag": 1j, + "dtype": "complex64", +} # complex 4D np.random.seed(10) @@ -133,3 +142,56 @@ def test_DWT2D_3dsignal(par): assert_array_almost_equal(x.ravel(), xadj, decimal=8) assert_array_almost_equal(x.ravel(), xinv, decimal=8) + + +@pytest.mark.parametrize("par", [(par3), (par4)]) +def test_DWTND_3dsignal(par): + """Dot-test and inversion for DWTND operator for 3d signal""" + DWTop = DWTND( + dims=(par["nt"], par["nx"], par["ny"]), axes=(0, 1, 2), wavelet="haar", level=3 + ) + x = np.random.normal(0.0, 1.0, (par["nt"], par["nx"], par["ny"])) + par[ + "imag" + ] * np.random.normal(0.0, 1.0, (par["nt"], par["nx"], par["ny"])) + + assert dottest( + DWTop, DWTop.shape[0], DWTop.shape[1], complexflag=0 if par["imag"] == 0 else 3 + ) + + y = DWTop * x.ravel() + xadj = DWTop.H * y # adjoint is same as inverse for dwt + xinv = lsqr(DWTop, y, damp=1e-10, iter_lim=10, atol=1e-8, btol=1e-8, show=0)[0] + + assert_array_almost_equal(x.ravel(), xadj, decimal=8) + assert_array_almost_equal(x.ravel(), xinv, decimal=8) + + +@pytest.mark.parametrize("par", [(par3), (par4)]) +def test_DWTND_4dsignal(par): + """Dot-test and inversion for DWTND operator for 4d signal""" + for axes in [(0, 1, 2), (0, 2, 3), (1, 2, 3), (0, 1, 3), (0, 1, 2, 3)]: + DWTop = DWTND( + dims=(par["nt"], par["nx"], par["ny"], par["nz"]), + axes=axes, + wavelet="haar", + level=3, + ) + x = np.random.normal( + 0.0, 1.0, (par["nt"], par["nx"], par["ny"], par["nz"]) + ) + par["imag"] * np.random.normal( + 0.0, 1.0, (par["nt"], par["nx"], par["ny"], par["nz"]) + ) + + assert dottest( + DWTop, + DWTop.shape[0], + DWTop.shape[1], + complexflag=0 if par["imag"] == 0 else 3, + ) + + y = DWTop * x.ravel() + xadj = DWTop.H * y # adjoint is same as inverse for dwt + xinv = lsqr(DWTop, y, damp=1e-10, iter_lim=10, atol=1e-8, btol=1e-8, show=0)[0] + + assert_array_almost_equal(x.ravel(), xadj, decimal=8) + assert_array_almost_equal(x.ravel(), xinv, decimal=8)