diff --git a/docs/source/installation.rst b/docs/source/installation.rst
index 563b62aa..a9c2d52a 100755
--- a/docs/source/installation.rst
+++ b/docs/source/installation.rst
@@ -319,9 +319,20 @@ of GPUs should install it prior to installing PyLops as described in :ref:`Optio
In alphabetic order:
+dtcwt
+-----
+`dtcwt `_ is a library used to implement the DT-CWT operators.
+
+Install it via ``pip`` with:
+
+.. code-block:: bash
+
+ >> pip install dtcwt
+
+
Devito
------
-`Devito `_ is library used to solve PDEs via
+`Devito `_ is a library used to solve PDEs via
the finite-difference method. It is used in PyLops to compute wavefields
:py:class:`pylops.waveeqprocessing.AcousticWave2D`
diff --git a/environment-dev-arm.yml b/environment-dev-arm.yml
index e04c5af7..7cb73753 100755
--- a/environment-dev-arm.yml
+++ b/environment-dev-arm.yml
@@ -25,6 +25,7 @@ dependencies:
- black
- pip:
- devito
+ - dtcwt
- scikit-fmm
- spgl1
- pytest-runner
diff --git a/environment-dev.yml b/environment-dev.yml
index 2e692c53..59b2c127 100755
--- a/environment-dev.yml
+++ b/environment-dev.yml
@@ -26,6 +26,7 @@ dependencies:
- black
- pip:
- devito
+ - dtcwt
- scikit-fmm
- spgl1
- pytest-runner
diff --git a/examples/plot_dtcwt.py b/examples/plot_dtcwt.py
new file mode 100644
index 00000000..b2a51f8b
--- /dev/null
+++ b/examples/plot_dtcwt.py
@@ -0,0 +1,82 @@
+"""
+Dual-Tree Complex Wavelet Transform
+===================================
+This example shows how to use the :py:class:`pylops.signalprocessing.DTCWT` operator to perform the
+1D Dual-Tree Complex Wavelet Transform on a (single or multi-dimensional) input array. Such a transform
+provides advantages over the DWT which lacks shift invariance in 1-D and directional sensitivity in N-D.
+"""
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pywt
+
+import pylops
+
+plt.close("all")
+
+###############################################################################
+# To begin with, let's define two 1D arrays with a spike at slightly different location
+
+n = 128
+x = np.zeros(n)
+x1 = np.zeros(n)
+
+x[59] = 1
+x1[63] = 1
+
+###############################################################################
+# We now create the DTCWT operator with the shape of our input array. The DTCWT transform
+# provides a Pyramid object that is internally flattened out into a vector. Here we re-obtain
+# the Pyramid object such that we can visualize the different scales indipendently.
+
+level = 3
+DCOp = pylops.signalprocessing.DTCWT(dims=n, level=level)
+Xc = DCOp.get_pyramid(DCOp @ x)
+Xc1 = DCOp.get_pyramid(DCOp @ x1)
+
+###############################################################################
+# To prove the superiority of the DTCWT transform over the DWT in shift-invariance,
+# let's also compute the DWT transform of these two signals and compare the coefficents
+# of both transform at level 3. As you will see, the coefficients change completely for
+# the DWT despite the two input signals are very similar; this is not the case for the
+# DCWT transform.
+
+DOp = pylops.signalprocessing.DWT(dims=n, level=level, wavelet="sym7")
+X = pywt.array_to_coeffs(DOp @ x, DOp.sl, output_format="wavedecn")
+X1 = pywt.array_to_coeffs(DOp @ x1, DOp.sl, output_format="wavedecn")
+
+fig, axs = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(10, 5))
+axs[0, 0].stem(np.abs(X[1]["d"]), linefmt="k", markerfmt=".k", basefmt="k")
+axs[0, 0].set_title(f"DWT (Norm={np.linalg.norm(np.abs(X[1]['d']))**2:.3f})")
+axs[0, 1].stem(np.abs(X1[1]["d"]), linefmt="k", markerfmt=".k", basefmt="k")
+axs[0, 1].set_title(f"DWT (Norm={np.linalg.norm(np.abs(X1[1]['d']))**2:.3f})")
+axs[1, 0].stem(np.abs(Xc.highpasses[2]), linefmt="k", markerfmt=".k", basefmt="k")
+axs[1, 0].set_title(f"DCWT (Norm={np.linalg.norm(np.abs(Xc.highpasses[2]))**2:.3f})")
+axs[1, 1].stem(np.abs(Xc1.highpasses[2]), linefmt="k", markerfmt=".k", basefmt="k")
+axs[1, 1].set_title(f"DCWT (Norm={np.linalg.norm(np.abs(Xc1.highpasses[2]))**2:.3f})")
+plt.tight_layout()
+
+###################################################################################
+# The DTCWT can also be performed on multi-dimension arrays, where the parameter
+# ``axis`` is used to define the axis over which the transform is performed. Let's
+# just replicate our input signal over the second axis and see how the transform
+# will produce the same series of coefficients for all replicas.
+
+nrepeat = 10
+x = np.repeat(np.random.rand(n, 1), 10, axis=1).T
+
+level = 3
+DCOp = pylops.signalprocessing.DTCWT(dims=(nrepeat, n), level=level, axis=1)
+X = DCOp @ x
+
+fig, axs = plt.subplots(1, 2, sharey=True, figsize=(10, 3))
+axs[0].imshow(X[0])
+axs[0].axis("tight")
+axs[0].set_xlabel("Coeffs")
+axs[0].set_ylabel("Replicas")
+axs[0].set_title("DTCWT Real")
+axs[1].imshow(X[1])
+axs[1].axis("tight")
+axs[1].set_xlabel("Coeffs")
+axs[1].set_title("DTCWT Imag")
+plt.tight_layout()
diff --git a/pylops/signalprocessing/__init__.py b/pylops/signalprocessing/__init__.py
index 8efa532e..7137b586 100755
--- a/pylops/signalprocessing/__init__.py
+++ b/pylops/signalprocessing/__init__.py
@@ -24,9 +24,10 @@
DWT One dimensional Wavelet operator.
DWT2D Two dimensional Wavelet operator.
DCT Discrete Cosine Transform.
- Seislet Two dimensional Seislet operator.
+ DTCWT Dual-Tree Complex Wavelet Transform.
Radon2D Two dimensional Radon transform.
Radon3D Three dimensional Radon transform.
+ Seislet Two dimensional Seislet operator.
Sliding1D 1D Sliding transform operator.
Sliding2D 2D Sliding transform operator.
Sliding3D 3D Sliding transform operator.
@@ -62,6 +63,8 @@
from .dwt2d import *
from .seislet import *
from .dct import *
+from .dtcwt import *
+
__all__ = [
"FFT",
@@ -92,4 +95,5 @@
"DWT2D",
"Seislet",
"DCT",
+ "DTCWT",
]
diff --git a/pylops/signalprocessing/dct.py b/pylops/signalprocessing/dct.py
index eb46e872..1a336be6 100644
--- a/pylops/signalprocessing/dct.py
+++ b/pylops/signalprocessing/dct.py
@@ -29,7 +29,7 @@ class DCT(LinearOperator):
axes : :obj:`int` or :obj:`list`, optional
Axes over which the DCT is computed. If ``None``, the transform is applied
over all axes.
- workers :obj:`int`, optional
+ workers : :obj:`int`, optional
Maximum number of workers to use for parallel computation. If negative,
the value wraps around from os.cpu_count().
dtype : :obj:`DTypeLike`, optional
diff --git a/pylops/signalprocessing/dtcwt.py b/pylops/signalprocessing/dtcwt.py
new file mode 100644
index 00000000..2c78ff52
--- /dev/null
+++ b/pylops/signalprocessing/dtcwt.py
@@ -0,0 +1,182 @@
+__all__ = ["DTCWT"]
+
+from typing import Union
+
+import numpy as np
+
+from pylops import LinearOperator
+from pylops.utils import deps
+from pylops.utils._internal import _value_or_sized_to_tuple
+from pylops.utils.decorators import reshaped
+from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
+
+dtcwt_message = deps.dtcwt_import("the dtcwt module")
+
+if dtcwt_message is None:
+ import dtcwt
+
+
+class DTCWT(LinearOperator):
+ r"""Dual-Tree Complex Wavelet Transform
+
+ Perform 1D Dual-Tree Complex Wavelet Transform along an ``axis`` of a
+ multi-dimensional array of size ``dims``.
+
+ Note that the DTCWT operator is an overload of the ``dtcwt``
+ implementation of the DT-CWT transform. Refer to
+ https://dtcwt.readthedocs.io for a detailed description of the
+ input parameters.
+
+ Parameters
+ ----------
+ dims : :obj:`int` or :obj:`tuple`
+ Number of samples for each dimension.
+ birot : :obj:`str`, optional
+ Level 1 wavelets to use. See :py:func:`dtcwt.coeffs.birot`. Default is `"near_sym_a"`.
+ qshift : :obj:`str`, optional
+ Level >= 2 wavelets to use. See :py:func:`dtcwt.coeffs.qshift`. Default is `"qshift_a"`
+ level : :obj:`int`, optional
+ Number of levels of wavelet decomposition. Default is 3.
+ include_scale : :obj:`bool`, optional
+ Include scales in pyramid. See :py:class:`dtcwt.Pyramid`. Default is False.
+ axis : :obj:`int`, optional
+ Axis on which the transform is performed.
+ dtype : :obj:`DTypeLike`, optional
+ Type of elements in input array.
+ name : :obj:`str`, optional
+ Name of operator (to be used by :func:`pylops.utils.describe.describe`)
+
+ Notes
+ -----
+ The DTCWT operator applies the dual-tree complex wavelet transform
+ in forward mode and the dual-tree complex inverse wavelet transform in adjoint mode
+ from the ``dtcwt`` library.
+
+ The ``dtcwt`` library uses a Pyramid object to represent the signal in the transformed domain,
+ which is composed of:
+ - `lowpass` (coarsest scale lowpass signal);
+ - `highpasses` (complex subband coefficients for corresponding scales);
+ - `scales` (lowpass signal for corresponding scales finest to coarsest).
+
+ To make the dtcwt forward() and inverse() functions compatible with PyLops, in forward model
+ the Pyramid object is flattened out and all coefficients (high-pass and low pass coefficients)
+ are appended into one array using the `_coeff_to_array` method.
+
+ In adjoint mode, the input array is transformed back into a Pyramid object using the `_array_to_coeff`
+ method and then the inverse transform is performed.
+
+ """
+
+ def __init__(
+ self,
+ dims: Union[int, InputDimsLike],
+ biort: str = "near_sym_a",
+ qshift: str = "qshift_a",
+ level: int = 3,
+ include_scale: bool = False,
+ axis: int = -1,
+ dtype: DTypeLike = "float64",
+ name: str = "C",
+ ) -> None:
+ if dtcwt_message is not None:
+ raise NotImplementedError(dtcwt_message)
+
+ dims = _value_or_sized_to_tuple(dims)
+ self.ndim = len(dims)
+ self.axis = axis
+
+ self.otherdims = int(np.prod(dims) / dims[self.axis])
+ self.dims_swapped = list(dims)
+ self.dims_swapped[0], self.dims_swapped[self.axis] = (
+ self.dims_swapped[self.axis],
+ self.dims_swapped[0],
+ )
+ self.dims_swapped = tuple(self.dims_swapped)
+ self.level = level
+ self.include_scale = include_scale
+
+ # dry-run of transform to find dimensions of coefficients at different levels
+ self._transform = dtcwt.Transform1d(biort=biort, qshift=qshift)
+ self._interpret_coeffs(dims, self.axis)
+
+ dimsd = list(dims)
+ dimsd[self.axis] = self.coeff_array_size
+ self.dimsd_swapped = list(dimsd)
+ self.dimsd_swapped[0], self.dimsd_swapped[self.axis] = (
+ self.dimsd_swapped[self.axis],
+ self.dimsd_swapped[0],
+ )
+ self.dimsd_swapped = tuple(self.dimsd_swapped)
+ dimsd = tuple(
+ [
+ 2,
+ ]
+ + dimsd
+ )
+
+ super().__init__(
+ dtype=np.dtype(dtype),
+ clinear=False,
+ dims=dims,
+ dimsd=dimsd,
+ name=name,
+ )
+
+ def _interpret_coeffs(self, dims, axis):
+ x = np.ones(dims[axis])
+ pyr = self._transform.forward(
+ x, nlevels=self.level, include_scale=self.include_scale
+ )
+ self.lowpass_size = pyr.lowpass.size
+ self.coeff_array_size = self.lowpass_size
+ self.highpass_sizes = []
+ for _h in pyr.highpasses:
+ self.highpass_sizes.append(_h.size)
+ self.coeff_array_size += _h.size
+
+ def _nd_to_2d(self, arr_nd):
+ arr_2d = arr_nd.reshape(self.dims[self.axis], -1).squeeze()
+ return arr_2d
+
+ def _coeff_to_array(self, pyr: dtcwt.Pyramid) -> NDArray:
+ highpass_coeffs = np.vstack([h for h in pyr.highpasses])
+ coeffs = np.concatenate((highpass_coeffs, pyr.lowpass), axis=0)
+ return coeffs
+
+ def _array_to_coeff(self, X: NDArray) -> dtcwt.Pyramid:
+ lowpass = (X[-self.lowpass_size :].real).reshape((-1, self.otherdims))
+ _ptr = 0
+ highpasses = ()
+ for _sl in self.highpass_sizes:
+ _h = X[_ptr : _ptr + _sl]
+ _ptr += _sl
+ _h = _h.reshape(-1, self.otherdims)
+ highpasses += (_h,)
+ return dtcwt.Pyramid(lowpass, highpasses)
+
+ def get_pyramid(self, x: NDArray) -> dtcwt.Pyramid:
+ """Return Pyramid object from flat real-valued array"""
+ return self._array_to_coeff(x[0] + 1j * x[1])
+
+ @reshaped
+ def _matvec(self, x: NDArray) -> NDArray:
+ x = x.swapaxes(self.axis, 0)
+ y = self._nd_to_2d(x)
+ y = self._coeff_to_array(
+ self._transform.forward(
+ y, nlevels=self.level, include_scale=self.include_scale
+ )
+ )
+ y = y.reshape(self.dimsd_swapped)
+ y = y.swapaxes(self.axis, 0)
+ y = np.concatenate([y.real[np.newaxis], y.imag[np.newaxis]])
+ return y
+
+ @reshaped
+ def _rmatvec(self, x: NDArray) -> NDArray:
+ x = x[0] + 1j * x[1]
+ x = x.swapaxes(self.axis, 0)
+ y = self._transform.inverse(self._array_to_coeff(x))
+ y = y.reshape(self.dims_swapped)
+ y = y.swapaxes(self.axis, 0)
+ return y
diff --git a/pylops/utils/deps.py b/pylops/utils/deps.py
index cbd5b3e3..4b2d21e7 100644
--- a/pylops/utils/deps.py
+++ b/pylops/utils/deps.py
@@ -1,6 +1,7 @@
__all__ = [
"cupy_enabled",
"devito_enabled",
+ "dtcwt_enabled",
"numba_enabled",
"pyfftw_enabled",
"pywt_enabled",
@@ -67,6 +68,23 @@ def devito_import(message: Optional[str] = None) -> str:
return devito_message
+def dtcwt_import(message: Optional[str] = None) -> str:
+ if dtcwt_enabled:
+ try:
+ import dtcwt # noqa: F401
+
+ dtcwt_message = None
+ except Exception as e:
+ dtcwt_message = f"Failed to import dtcwt (error:{e})."
+ else:
+ dtcwt_message = (
+ f"Dtcwt not available. "
+ f"In order to be able to use "
+ f'{message} run "pip install dtcwt".'
+ )
+ return dtcwt_message
+
+
def numba_import(message: Optional[str] = None) -> str:
if numba_enabled:
try:
@@ -187,6 +205,7 @@ def sympy_import(message: Optional[str] = None) -> str:
True if (cupy_import() is None and int(os.getenv("CUPY_PYLOPS", 1)) == 1) else False
)
devito_enabled = util.find_spec("devito") is not None
+dtcwt_enabled = util.find_spec("dtcwt") is not None
numba_enabled = util.find_spec("numba") is not None
pyfftw_enabled = util.find_spec("pyfftw") is not None
pywt_enabled = util.find_spec("pywt") is not None
diff --git a/pyproject.toml b/pyproject.toml
index d2f6c854..9c338a48 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,6 +43,7 @@ advanced = [
"PyWavelets",
"scikit-fmm",
"spgl1",
+ "dtcwt",
]
[tool.setuptools.packages.find]
diff --git a/pytests/test_dtcwt.py b/pytests/test_dtcwt.py
new file mode 100644
index 00000000..b0cf2b61
--- /dev/null
+++ b/pytests/test_dtcwt.py
@@ -0,0 +1,82 @@
+import numpy as np
+import pytest
+
+from pylops.signalprocessing import DTCWT
+
+par1 = {"ny": 10, "nx": 10, "dtype": "float64"}
+par2 = {"ny": 50, "nx": 50, "dtype": "float64"}
+
+
+def sequential_array(shape):
+ num_elements = np.prod(shape)
+ seq_array = np.arange(1, num_elements + 1)
+ result = seq_array.reshape(shape)
+ return result
+
+
+@pytest.mark.parametrize("par", [(par1), (par2)])
+def test_dtcwt1D_input1D(par):
+ """Test for DTCWT with 1D input"""
+
+ t = sequential_array((par["ny"],))
+
+ for level in range(1, 10):
+ Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"])
+ x = Dtcwt @ t
+ y = Dtcwt.H @ x
+
+ np.testing.assert_allclose(t, y)
+
+
+@pytest.mark.parametrize("par", [(par1), (par2)])
+def test_dtcwt1D_input2D(par):
+ """Test for DTCWT with 2D input (forward-inverse pair)"""
+
+ t = sequential_array(
+ (
+ par["ny"],
+ par["ny"],
+ )
+ )
+
+ for level in range(1, 10):
+ Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"])
+ x = Dtcwt @ t
+ y = Dtcwt.H @ x
+
+ np.testing.assert_allclose(t, y)
+
+
+@pytest.mark.parametrize("par", [(par1), (par2)])
+def test_dtcwt1D_input3D(par):
+ """Test for DTCWT with 3D input (forward-inverse pair)"""
+
+ t = sequential_array((par["ny"], par["ny"], par["ny"]))
+
+ for level in range(1, 10):
+ Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"])
+ x = Dtcwt @ t
+ y = Dtcwt.H @ x
+
+ np.testing.assert_allclose(t, y)
+
+
+@pytest.mark.parametrize("par", [(par1), (par2)])
+def test_dtcwt1D_birot(par):
+ """Test for DTCWT birot (forward-inverse pair)"""
+ birots = ["antonini", "legall", "near_sym_a", "near_sym_b"]
+
+ t = sequential_array(
+ (
+ par["ny"],
+ par["ny"],
+ )
+ )
+
+ for _b in birots:
+ print(f"birot {_b}")
+ Dtcwt = DTCWT(dims=t.shape, biort=_b, dtype=par["dtype"])
+ x = Dtcwt @ t
+ y = Dtcwt.H @ x
+
+ np.testing.assert_allclose(t, y)
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 8eabb62c..d86f07f1 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -8,6 +8,7 @@ spgl1
scikit-fmm
sympy
devito
+dtcwt
matplotlib
ipython
pytest