From b56ab830535c176ccf76be8b76df236c3c57ef7e Mon Sep 17 00:00:00 2001 From: mrava87 Date: Thu, 29 Feb 2024 11:15:03 +0300 Subject: [PATCH] feat: added DTCWT 1d operator --- docs/source/installation.rst | 13 +- environment-dev-arm.yml | 1 + environment-dev.yml | 1 + examples/plot_dtcwt.py | 82 +++++++++++++ pylops/signalprocessing/__init__.py | 6 +- pylops/signalprocessing/dct.py | 2 +- pylops/signalprocessing/dtcwt.py | 182 ++++++++++++++++++++++++++++ pylops/utils/deps.py | 19 +++ pyproject.toml | 1 + pytests/test_dtcwt.py | 82 +++++++++++++ requirements-dev.txt | 1 + 11 files changed, 387 insertions(+), 3 deletions(-) create mode 100644 examples/plot_dtcwt.py create mode 100644 pylops/signalprocessing/dtcwt.py create mode 100644 pytests/test_dtcwt.py diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 563b62aa..a9c2d52a 100755 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -319,9 +319,20 @@ of GPUs should install it prior to installing PyLops as described in :ref:`Optio In alphabetic order: +dtcwt +----- +`dtcwt `_ is a library used to implement the DT-CWT operators. + +Install it via ``pip`` with: + +.. code-block:: bash + + >> pip install dtcwt + + Devito ------ -`Devito `_ is library used to solve PDEs via +`Devito `_ is a library used to solve PDEs via the finite-difference method. It is used in PyLops to compute wavefields :py:class:`pylops.waveeqprocessing.AcousticWave2D` diff --git a/environment-dev-arm.yml b/environment-dev-arm.yml index e04c5af7..7cb73753 100755 --- a/environment-dev-arm.yml +++ b/environment-dev-arm.yml @@ -25,6 +25,7 @@ dependencies: - black - pip: - devito + - dtcwt - scikit-fmm - spgl1 - pytest-runner diff --git a/environment-dev.yml b/environment-dev.yml index 2e692c53..59b2c127 100755 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -26,6 +26,7 @@ dependencies: - black - pip: - devito + - dtcwt - scikit-fmm - spgl1 - pytest-runner diff --git a/examples/plot_dtcwt.py b/examples/plot_dtcwt.py new file mode 100644 index 00000000..b2a51f8b --- /dev/null +++ b/examples/plot_dtcwt.py @@ -0,0 +1,82 @@ +""" +Dual-Tree Complex Wavelet Transform +=================================== +This example shows how to use the :py:class:`pylops.signalprocessing.DTCWT` operator to perform the +1D Dual-Tree Complex Wavelet Transform on a (single or multi-dimensional) input array. Such a transform +provides advantages over the DWT which lacks shift invariance in 1-D and directional sensitivity in N-D. +""" + +import matplotlib.pyplot as plt +import numpy as np +import pywt + +import pylops + +plt.close("all") + +############################################################################### +# To begin with, let's define two 1D arrays with a spike at slightly different location + +n = 128 +x = np.zeros(n) +x1 = np.zeros(n) + +x[59] = 1 +x1[63] = 1 + +############################################################################### +# We now create the DTCWT operator with the shape of our input array. The DTCWT transform +# provides a Pyramid object that is internally flattened out into a vector. Here we re-obtain +# the Pyramid object such that we can visualize the different scales indipendently. + +level = 3 +DCOp = pylops.signalprocessing.DTCWT(dims=n, level=level) +Xc = DCOp.get_pyramid(DCOp @ x) +Xc1 = DCOp.get_pyramid(DCOp @ x1) + +############################################################################### +# To prove the superiority of the DTCWT transform over the DWT in shift-invariance, +# let's also compute the DWT transform of these two signals and compare the coefficents +# of both transform at level 3. As you will see, the coefficients change completely for +# the DWT despite the two input signals are very similar; this is not the case for the +# DCWT transform. + +DOp = pylops.signalprocessing.DWT(dims=n, level=level, wavelet="sym7") +X = pywt.array_to_coeffs(DOp @ x, DOp.sl, output_format="wavedecn") +X1 = pywt.array_to_coeffs(DOp @ x1, DOp.sl, output_format="wavedecn") + +fig, axs = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(10, 5)) +axs[0, 0].stem(np.abs(X[1]["d"]), linefmt="k", markerfmt=".k", basefmt="k") +axs[0, 0].set_title(f"DWT (Norm={np.linalg.norm(np.abs(X[1]['d']))**2:.3f})") +axs[0, 1].stem(np.abs(X1[1]["d"]), linefmt="k", markerfmt=".k", basefmt="k") +axs[0, 1].set_title(f"DWT (Norm={np.linalg.norm(np.abs(X1[1]['d']))**2:.3f})") +axs[1, 0].stem(np.abs(Xc.highpasses[2]), linefmt="k", markerfmt=".k", basefmt="k") +axs[1, 0].set_title(f"DCWT (Norm={np.linalg.norm(np.abs(Xc.highpasses[2]))**2:.3f})") +axs[1, 1].stem(np.abs(Xc1.highpasses[2]), linefmt="k", markerfmt=".k", basefmt="k") +axs[1, 1].set_title(f"DCWT (Norm={np.linalg.norm(np.abs(Xc1.highpasses[2]))**2:.3f})") +plt.tight_layout() + +################################################################################### +# The DTCWT can also be performed on multi-dimension arrays, where the parameter +# ``axis`` is used to define the axis over which the transform is performed. Let's +# just replicate our input signal over the second axis and see how the transform +# will produce the same series of coefficients for all replicas. + +nrepeat = 10 +x = np.repeat(np.random.rand(n, 1), 10, axis=1).T + +level = 3 +DCOp = pylops.signalprocessing.DTCWT(dims=(nrepeat, n), level=level, axis=1) +X = DCOp @ x + +fig, axs = plt.subplots(1, 2, sharey=True, figsize=(10, 3)) +axs[0].imshow(X[0]) +axs[0].axis("tight") +axs[0].set_xlabel("Coeffs") +axs[0].set_ylabel("Replicas") +axs[0].set_title("DTCWT Real") +axs[1].imshow(X[1]) +axs[1].axis("tight") +axs[1].set_xlabel("Coeffs") +axs[1].set_title("DTCWT Imag") +plt.tight_layout() diff --git a/pylops/signalprocessing/__init__.py b/pylops/signalprocessing/__init__.py index 8efa532e..7137b586 100755 --- a/pylops/signalprocessing/__init__.py +++ b/pylops/signalprocessing/__init__.py @@ -24,9 +24,10 @@ DWT One dimensional Wavelet operator. DWT2D Two dimensional Wavelet operator. DCT Discrete Cosine Transform. - Seislet Two dimensional Seislet operator. + DTCWT Dual-Tree Complex Wavelet Transform. Radon2D Two dimensional Radon transform. Radon3D Three dimensional Radon transform. + Seislet Two dimensional Seislet operator. Sliding1D 1D Sliding transform operator. Sliding2D 2D Sliding transform operator. Sliding3D 3D Sliding transform operator. @@ -62,6 +63,8 @@ from .dwt2d import * from .seislet import * from .dct import * +from .dtcwt import * + __all__ = [ "FFT", @@ -92,4 +95,5 @@ "DWT2D", "Seislet", "DCT", + "DTCWT", ] diff --git a/pylops/signalprocessing/dct.py b/pylops/signalprocessing/dct.py index eb46e872..1a336be6 100644 --- a/pylops/signalprocessing/dct.py +++ b/pylops/signalprocessing/dct.py @@ -29,7 +29,7 @@ class DCT(LinearOperator): axes : :obj:`int` or :obj:`list`, optional Axes over which the DCT is computed. If ``None``, the transform is applied over all axes. - workers :obj:`int`, optional + workers : :obj:`int`, optional Maximum number of workers to use for parallel computation. If negative, the value wraps around from os.cpu_count(). dtype : :obj:`DTypeLike`, optional diff --git a/pylops/signalprocessing/dtcwt.py b/pylops/signalprocessing/dtcwt.py new file mode 100644 index 00000000..2c78ff52 --- /dev/null +++ b/pylops/signalprocessing/dtcwt.py @@ -0,0 +1,182 @@ +__all__ = ["DTCWT"] + +from typing import Union + +import numpy as np + +from pylops import LinearOperator +from pylops.utils import deps +from pylops.utils._internal import _value_or_sized_to_tuple +from pylops.utils.decorators import reshaped +from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray + +dtcwt_message = deps.dtcwt_import("the dtcwt module") + +if dtcwt_message is None: + import dtcwt + + +class DTCWT(LinearOperator): + r"""Dual-Tree Complex Wavelet Transform + + Perform 1D Dual-Tree Complex Wavelet Transform along an ``axis`` of a + multi-dimensional array of size ``dims``. + + Note that the DTCWT operator is an overload of the ``dtcwt`` + implementation of the DT-CWT transform. Refer to + https://dtcwt.readthedocs.io for a detailed description of the + input parameters. + + Parameters + ---------- + dims : :obj:`int` or :obj:`tuple` + Number of samples for each dimension. + birot : :obj:`str`, optional + Level 1 wavelets to use. See :py:func:`dtcwt.coeffs.birot`. Default is `"near_sym_a"`. + qshift : :obj:`str`, optional + Level >= 2 wavelets to use. See :py:func:`dtcwt.coeffs.qshift`. Default is `"qshift_a"` + level : :obj:`int`, optional + Number of levels of wavelet decomposition. Default is 3. + include_scale : :obj:`bool`, optional + Include scales in pyramid. See :py:class:`dtcwt.Pyramid`. Default is False. + axis : :obj:`int`, optional + Axis on which the transform is performed. + dtype : :obj:`DTypeLike`, optional + Type of elements in input array. + name : :obj:`str`, optional + Name of operator (to be used by :func:`pylops.utils.describe.describe`) + + Notes + ----- + The DTCWT operator applies the dual-tree complex wavelet transform + in forward mode and the dual-tree complex inverse wavelet transform in adjoint mode + from the ``dtcwt`` library. + + The ``dtcwt`` library uses a Pyramid object to represent the signal in the transformed domain, + which is composed of: + - `lowpass` (coarsest scale lowpass signal); + - `highpasses` (complex subband coefficients for corresponding scales); + - `scales` (lowpass signal for corresponding scales finest to coarsest). + + To make the dtcwt forward() and inverse() functions compatible with PyLops, in forward model + the Pyramid object is flattened out and all coefficients (high-pass and low pass coefficients) + are appended into one array using the `_coeff_to_array` method. + + In adjoint mode, the input array is transformed back into a Pyramid object using the `_array_to_coeff` + method and then the inverse transform is performed. + + """ + + def __init__( + self, + dims: Union[int, InputDimsLike], + biort: str = "near_sym_a", + qshift: str = "qshift_a", + level: int = 3, + include_scale: bool = False, + axis: int = -1, + dtype: DTypeLike = "float64", + name: str = "C", + ) -> None: + if dtcwt_message is not None: + raise NotImplementedError(dtcwt_message) + + dims = _value_or_sized_to_tuple(dims) + self.ndim = len(dims) + self.axis = axis + + self.otherdims = int(np.prod(dims) / dims[self.axis]) + self.dims_swapped = list(dims) + self.dims_swapped[0], self.dims_swapped[self.axis] = ( + self.dims_swapped[self.axis], + self.dims_swapped[0], + ) + self.dims_swapped = tuple(self.dims_swapped) + self.level = level + self.include_scale = include_scale + + # dry-run of transform to find dimensions of coefficients at different levels + self._transform = dtcwt.Transform1d(biort=biort, qshift=qshift) + self._interpret_coeffs(dims, self.axis) + + dimsd = list(dims) + dimsd[self.axis] = self.coeff_array_size + self.dimsd_swapped = list(dimsd) + self.dimsd_swapped[0], self.dimsd_swapped[self.axis] = ( + self.dimsd_swapped[self.axis], + self.dimsd_swapped[0], + ) + self.dimsd_swapped = tuple(self.dimsd_swapped) + dimsd = tuple( + [ + 2, + ] + + dimsd + ) + + super().__init__( + dtype=np.dtype(dtype), + clinear=False, + dims=dims, + dimsd=dimsd, + name=name, + ) + + def _interpret_coeffs(self, dims, axis): + x = np.ones(dims[axis]) + pyr = self._transform.forward( + x, nlevels=self.level, include_scale=self.include_scale + ) + self.lowpass_size = pyr.lowpass.size + self.coeff_array_size = self.lowpass_size + self.highpass_sizes = [] + for _h in pyr.highpasses: + self.highpass_sizes.append(_h.size) + self.coeff_array_size += _h.size + + def _nd_to_2d(self, arr_nd): + arr_2d = arr_nd.reshape(self.dims[self.axis], -1).squeeze() + return arr_2d + + def _coeff_to_array(self, pyr: dtcwt.Pyramid) -> NDArray: + highpass_coeffs = np.vstack([h for h in pyr.highpasses]) + coeffs = np.concatenate((highpass_coeffs, pyr.lowpass), axis=0) + return coeffs + + def _array_to_coeff(self, X: NDArray) -> dtcwt.Pyramid: + lowpass = (X[-self.lowpass_size :].real).reshape((-1, self.otherdims)) + _ptr = 0 + highpasses = () + for _sl in self.highpass_sizes: + _h = X[_ptr : _ptr + _sl] + _ptr += _sl + _h = _h.reshape(-1, self.otherdims) + highpasses += (_h,) + return dtcwt.Pyramid(lowpass, highpasses) + + def get_pyramid(self, x: NDArray) -> dtcwt.Pyramid: + """Return Pyramid object from flat real-valued array""" + return self._array_to_coeff(x[0] + 1j * x[1]) + + @reshaped + def _matvec(self, x: NDArray) -> NDArray: + x = x.swapaxes(self.axis, 0) + y = self._nd_to_2d(x) + y = self._coeff_to_array( + self._transform.forward( + y, nlevels=self.level, include_scale=self.include_scale + ) + ) + y = y.reshape(self.dimsd_swapped) + y = y.swapaxes(self.axis, 0) + y = np.concatenate([y.real[np.newaxis], y.imag[np.newaxis]]) + return y + + @reshaped + def _rmatvec(self, x: NDArray) -> NDArray: + x = x[0] + 1j * x[1] + x = x.swapaxes(self.axis, 0) + y = self._transform.inverse(self._array_to_coeff(x)) + y = y.reshape(self.dims_swapped) + y = y.swapaxes(self.axis, 0) + return y diff --git a/pylops/utils/deps.py b/pylops/utils/deps.py index cbd5b3e3..4b2d21e7 100644 --- a/pylops/utils/deps.py +++ b/pylops/utils/deps.py @@ -1,6 +1,7 @@ __all__ = [ "cupy_enabled", "devito_enabled", + "dtcwt_enabled", "numba_enabled", "pyfftw_enabled", "pywt_enabled", @@ -67,6 +68,23 @@ def devito_import(message: Optional[str] = None) -> str: return devito_message +def dtcwt_import(message: Optional[str] = None) -> str: + if dtcwt_enabled: + try: + import dtcwt # noqa: F401 + + dtcwt_message = None + except Exception as e: + dtcwt_message = f"Failed to import dtcwt (error:{e})." + else: + dtcwt_message = ( + f"Dtcwt not available. " + f"In order to be able to use " + f'{message} run "pip install dtcwt".' + ) + return dtcwt_message + + def numba_import(message: Optional[str] = None) -> str: if numba_enabled: try: @@ -187,6 +205,7 @@ def sympy_import(message: Optional[str] = None) -> str: True if (cupy_import() is None and int(os.getenv("CUPY_PYLOPS", 1)) == 1) else False ) devito_enabled = util.find_spec("devito") is not None +dtcwt_enabled = util.find_spec("dtcwt") is not None numba_enabled = util.find_spec("numba") is not None pyfftw_enabled = util.find_spec("pyfftw") is not None pywt_enabled = util.find_spec("pywt") is not None diff --git a/pyproject.toml b/pyproject.toml index d2f6c854..9c338a48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ advanced = [ "PyWavelets", "scikit-fmm", "spgl1", + "dtcwt", ] [tool.setuptools.packages.find] diff --git a/pytests/test_dtcwt.py b/pytests/test_dtcwt.py new file mode 100644 index 00000000..b0cf2b61 --- /dev/null +++ b/pytests/test_dtcwt.py @@ -0,0 +1,82 @@ +import numpy as np +import pytest + +from pylops.signalprocessing import DTCWT + +par1 = {"ny": 10, "nx": 10, "dtype": "float64"} +par2 = {"ny": 50, "nx": 50, "dtype": "float64"} + + +def sequential_array(shape): + num_elements = np.prod(shape) + seq_array = np.arange(1, num_elements + 1) + result = seq_array.reshape(shape) + return result + + +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_dtcwt1D_input1D(par): + """Test for DTCWT with 1D input""" + + t = sequential_array((par["ny"],)) + + for level in range(1, 10): + Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"]) + x = Dtcwt @ t + y = Dtcwt.H @ x + + np.testing.assert_allclose(t, y) + + +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_dtcwt1D_input2D(par): + """Test for DTCWT with 2D input (forward-inverse pair)""" + + t = sequential_array( + ( + par["ny"], + par["ny"], + ) + ) + + for level in range(1, 10): + Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"]) + x = Dtcwt @ t + y = Dtcwt.H @ x + + np.testing.assert_allclose(t, y) + + +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_dtcwt1D_input3D(par): + """Test for DTCWT with 3D input (forward-inverse pair)""" + + t = sequential_array((par["ny"], par["ny"], par["ny"])) + + for level in range(1, 10): + Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"]) + x = Dtcwt @ t + y = Dtcwt.H @ x + + np.testing.assert_allclose(t, y) + + +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_dtcwt1D_birot(par): + """Test for DTCWT birot (forward-inverse pair)""" + birots = ["antonini", "legall", "near_sym_a", "near_sym_b"] + + t = sequential_array( + ( + par["ny"], + par["ny"], + ) + ) + + for _b in birots: + print(f"birot {_b}") + Dtcwt = DTCWT(dims=t.shape, biort=_b, dtype=par["dtype"]) + x = Dtcwt @ t + y = Dtcwt.H @ x + + np.testing.assert_allclose(t, y) diff --git a/requirements-dev.txt b/requirements-dev.txt index 8eabb62c..d86f07f1 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -8,6 +8,7 @@ spgl1 scikit-fmm sympy devito +dtcwt matplotlib ipython pytest