Skip to content

Commit

Permalink
Merge pull request #583 from solldavid/sollberger/3d-dwt
Browse files Browse the repository at this point in the history
Feature: N-dimensional discrete wavelet transforms
  • Loading branch information
mrava87 authored Jun 17, 2024
2 parents b19a63c + 9b0c9b9 commit 4860e7a
Show file tree
Hide file tree
Showing 7 changed files with 253 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,4 @@ A list of video tutorials to learn more about PyLops:
* Wei Zhang, ZhangWeiGeo
* Fedor Goncharov, fedor-goncharov
* Alex Rakowski, alex-rakowski
* David Sollberger, solldavid
1 change: 1 addition & 0 deletions docs/source/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ Signal processing
Shift
DWT
DWT2D
DWTND
DCT
DTCWT
Seislet
Expand Down
1 change: 1 addition & 0 deletions docs/source/credits.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ Contributors
* `Wei Zhang <https://github.com/ZhangWeiGeo>`_, ZhangWeiGeo
* `Fedor Goncharov <https://github.com/fedor-goncharov>`_, fedor-goncharov
* `Alex Rakowski <https://github.com/alex-rakowski>`_, alex-rakowski
* `David Sollberger <https://github.com/solldavid>`_, solldavid
48 changes: 46 additions & 2 deletions examples/plot_wavelet.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""
Wavelet transform
=================
This example shows how to use the :py:class:`pylops.DWT` and
:py:class:`pylops.DWT2D` operators to perform 1- and 2-dimensional DWT.
This example shows how to use the :py:class:`pylops.DWT`,
:py:class:`pylops.DWT2D`, and :py:class:`pylops.DWTND` operators
to perform 1-, 2-, and N-dimensional DWT.
"""
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -67,3 +68,46 @@
axs[1, 1].set_title("DWT2 coefficients (zeroed)")
axs[1, 1].axis("tight")
plt.tight_layout()

###############################################################################
# Let us now try the same with a 3D volumetric model, where we use the
# N-dimensional DWT. This time, we only retain 10 percent of the coefficients
# of the DWT.

nx = 128
ny = 256
nz = 128

x = np.arange(nx)
y = np.arange(ny)
z = np.arange(nz)

xx, yy, zz = np.meshgrid(x, y, z, indexing="ij")
# Generate a 3D model with two block anomalies
m = np.ones_like(xx, dtype=float)
block1 = (xx > 10) & (xx < 60) & (yy > 100) & (yy < 150) & (zz > 20) & (zz < 70)
block2 = (xx > 70) & (xx < 80) & (yy > 100) & (yy < 200) & (zz > 10) & (zz < 50)
m[block1] = 1.2
m[block2] = 0.8
Wop = pylops.signalprocessing.DWTND((nx, ny, nz), wavelet="haar", level=3)
y = Wop * m

ratio = 0.1
yf = y.copy()
yf.flat[int(ratio * y.size) :] = 0
iminv = Wop.H * yf

fig, axs = plt.subplots(2, 2, figsize=(6, 6))
axs[0, 0].imshow(m[:, :, 30], cmap="gray")
axs[0, 0].set_title("Model (Slice at z=30)")
axs[0, 0].axis("tight")
axs[0, 1].imshow(y[:, :, 90], cmap="gray_r")
axs[0, 1].set_title("DWTNT coefficients")
axs[0, 1].axis("tight")
axs[1, 0].imshow(iminv[:, :, 30], cmap="gray")
axs[1, 0].set_title("Reconstructed model (Slice at z=30)")
axs[1, 0].axis("tight")
axs[1, 1].imshow(yf[:, :, 90], cmap="gray_r")
axs[1, 1].set_title("DWTNT coefficients (zeroed)")
axs[1, 1].axis("tight")
plt.tight_layout()
3 changes: 3 additions & 0 deletions pylops/signalprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Shift Fractional Shift operator.
DWT One dimensional Wavelet operator.
DWT2D Two dimensional Wavelet operator.
DWTND N-dimensional Wavelet operator.
DCT Discrete Cosine Transform.
DTCWT Dual-Tree Complex Wavelet Transform.
Radon2D Two dimensional Radon transform.
Expand Down Expand Up @@ -61,6 +62,7 @@
from .fredholm1 import *
from .dwt import *
from .dwt2d import *
from .dwtnd import *
from .seislet import *
from .dct import *
from .dtcwt import *
Expand Down Expand Up @@ -93,6 +95,7 @@
"Fredholm1",
"DWT",
"DWT2D",
"DWTND",
"Seislet",
"DCT",
"DTCWT",
Expand Down
138 changes: 138 additions & 0 deletions pylops/signalprocessing/dwtnd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
__all__ = ["DWTND"]

import logging
from math import ceil, log

import numpy as np

from pylops import LinearOperator
from pylops.basicoperators import Pad
from pylops.utils import deps
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray

from .dwt import _adjointwavelet, _checkwavelet

pywt_message = deps.pywt_import("the dwtnd module")

if pywt_message is None:
import pywt

logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING)


class DWTND(LinearOperator):
"""N-dimensional Wavelet operator.
Apply ND-Wavelet transform along N ``axes`` of a
multi-dimensional array of size ``dims``.
Note that the Wavelet operator is an overload of the ``pywt``
implementation of the wavelet transform. Refer to
https://pywavelets.readthedocs.io for a detailed description of the
input parameters.
Defaults to a 3D wavelet transform along the last three dimensions
of the input array.
Parameters
----------
dims : :obj:`tuple`
Number of samples for each dimension
axes : :obj:`int`, optional
Axis along which DWTND is applied
wavelet : :obj:`str`, optional
Name of wavelet type. Use :func:`pywt.wavelist(kind='discrete')` for
a list of available wavelets.
level : :obj:`int`, optional
Number of scaling levels (must be >=0).
dtype : :obj:`str`, optional
Type of elements in input array.
name : :obj:`str`, optional
Name of operator (to be used by :func:`pylops.utils.describe.describe`)
Attributes
----------
shape : :obj:`tuple`
Operator shape
explicit : :obj:`bool`
Operator contains a matrix that can be solved explicitly
(``True``) or not (``False``)
Raises
------
ModuleNotFoundError
If ``pywt`` is not installed
ValueError
If ``wavelet`` does not belong to ``pywt.families``
Notes
-----
The Wavelet operator applies the N-dimensional multilevel Discrete
Wavelet Transform (DWTN) in forward mode and the N-dimensional multilevel
Inverse Discrete Wavelet Transform (IDWTN) in adjoint mode.
"""

def __init__(
self,
dims: InputDimsLike,
axes: InputDimsLike = (-3, -2, -1),
wavelet: str = "haar",
level: int = 1,
dtype: DTypeLike = "float64",
name: str = "D",
) -> None:
if pywt_message is not None:
raise ModuleNotFoundError(pywt_message)
_checkwavelet(wavelet)

# define padding for length to be power of 2
ndimpow2 = [max(2 ** ceil(log(dims[ax], 2)), 2**level) for ax in axes]
pad = [(0, 0)] * len(dims)
for i, ax in enumerate(axes):
pad[ax] = (0, ndimpow2[i] - dims[ax])
self.pad = Pad(dims, pad)
self.axes = axes
dimsd = list(dims)
for i, ax in enumerate(axes):
dimsd[ax] = ndimpow2[i]
super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dimsd, name=name)

# apply transform once again to find out slices
_, self.sl = pywt.coeffs_to_array(
pywt.wavedecn(
np.ones(self.dimsd),
wavelet=wavelet,
level=level,
mode="periodization",
axes=self.axes,
),
axes=self.axes,
)
self.wavelet = wavelet
self.waveletadj = _adjointwavelet(wavelet)
self.level = level

def _matvec(self, x: NDArray) -> NDArray:
x = self.pad.matvec(x)
x = np.reshape(x, self.dimsd)
y = pywt.coeffs_to_array(
pywt.wavedecn(
x,
wavelet=self.wavelet,
level=self.level,
mode="periodization",
axes=self.axes,
),
axes=(self.axes),
)[0]
return y.ravel()

def _rmatvec(self, x: NDArray) -> NDArray:
x = np.reshape(x, self.dimsd)
x = pywt.array_to_coeffs(x, self.sl, output_format="wavedecn")
y = pywt.waverecn(
x, wavelet=self.waveletadj, mode="periodization", axes=self.axes
)
y = self.pad.rmatvec(y.ravel())
return y
64 changes: 63 additions & 1 deletion pytests/test_dwts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,20 @@
from numpy.testing import assert_array_almost_equal
from scipy.sparse.linalg import lsqr

from pylops.signalprocessing import DWT, DWT2D
from pylops.signalprocessing import DWT, DWT2D, DWTND
from pylops.utils import dottest

par1 = {"ny": 7, "nx": 9, "nt": 10, "imag": 0, "dtype": "float32"} # real
par2 = {"ny": 7, "nx": 9, "nt": 10, "imag": 1j, "dtype": "complex64"} # complex
par3 = {"ny": 7, "nx": 9, "nz": 9, "nt": 10, "imag": 0, "dtype": "float32"} # real 4D
par4 = {
"ny": 7,
"nx": 9,
"nz": 9,
"nt": 10,
"imag": 1j,
"dtype": "complex64",
} # complex 4D

np.random.seed(10)

Expand Down Expand Up @@ -133,3 +142,56 @@ def test_DWT2D_3dsignal(par):

assert_array_almost_equal(x.ravel(), xadj, decimal=8)
assert_array_almost_equal(x.ravel(), xinv, decimal=8)


@pytest.mark.parametrize("par", [(par3), (par4)])
def test_DWTND_3dsignal(par):
"""Dot-test and inversion for DWTND operator for 3d signal"""
DWTop = DWTND(
dims=(par["nt"], par["nx"], par["ny"]), axes=(0, 1, 2), wavelet="haar", level=3
)
x = np.random.normal(0.0, 1.0, (par["nt"], par["nx"], par["ny"])) + par[
"imag"
] * np.random.normal(0.0, 1.0, (par["nt"], par["nx"], par["ny"]))

assert dottest(
DWTop, DWTop.shape[0], DWTop.shape[1], complexflag=0 if par["imag"] == 0 else 3
)

y = DWTop * x.ravel()
xadj = DWTop.H * y # adjoint is same as inverse for dwt
xinv = lsqr(DWTop, y, damp=1e-10, iter_lim=10, atol=1e-8, btol=1e-8, show=0)[0]

assert_array_almost_equal(x.ravel(), xadj, decimal=8)
assert_array_almost_equal(x.ravel(), xinv, decimal=8)


@pytest.mark.parametrize("par", [(par3), (par4)])
def test_DWTND_4dsignal(par):
"""Dot-test and inversion for DWTND operator for 4d signal"""
for axes in [(0, 1, 2), (0, 2, 3), (1, 2, 3), (0, 1, 3), (0, 1, 2, 3)]:
DWTop = DWTND(
dims=(par["nt"], par["nx"], par["ny"], par["nz"]),
axes=axes,
wavelet="haar",
level=3,
)
x = np.random.normal(
0.0, 1.0, (par["nt"], par["nx"], par["ny"], par["nz"])
) + par["imag"] * np.random.normal(
0.0, 1.0, (par["nt"], par["nx"], par["ny"], par["nz"])
)

assert dottest(
DWTop,
DWTop.shape[0],
DWTop.shape[1],
complexflag=0 if par["imag"] == 0 else 3,
)

y = DWTop * x.ravel()
xadj = DWTop.H * y # adjoint is same as inverse for dwt
xinv = lsqr(DWTop, y, damp=1e-10, iter_lim=10, atol=1e-8, btol=1e-8, show=0)[0]

assert_array_almost_equal(x.ravel(), xadj, decimal=8)
assert_array_almost_equal(x.ravel(), xinv, decimal=8)

0 comments on commit 4860e7a

Please sign in to comment.