Skip to content

Commit

Permalink
Merge pull request #575 from mrava87/dev
Browse files Browse the repository at this point in the history
doc: added safe typing to dtcwt
  • Loading branch information
mrava87 authored Mar 19, 2024
2 parents 99d91f1 + b8e35b2 commit ad3d9cc
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions pylops/signalprocessing/dtcwt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = ["DTCWT"]

from typing import Union
from typing import Any, NewType, Union

import numpy as np

Expand All @@ -15,6 +15,12 @@
if dtcwt_message is None:
import dtcwt

pyramid_type = dtcwt.numpy.common.Pyramid
else:
pyramid_type = Any

PyramidType = NewType("PyramidType", pyramid_type)


class DTCWT(LinearOperator):
r"""Dual-Tree Complex Wavelet Transform
Expand Down Expand Up @@ -122,7 +128,11 @@ def __init__(
name=name,
)

def _interpret_coeffs(self, dims, axis):
def _interpret_coeffs(
self,
dims: Union[int, InputDimsLike],
axis: int,
) -> None:
x = np.ones(dims[axis])
pyr = self._transform.forward(
x, nlevels=self.level, include_scale=self.include_scale
Expand All @@ -134,16 +144,16 @@ def _interpret_coeffs(self, dims, axis):
self.highpass_sizes.append(_h.size)
self.coeff_array_size += _h.size

def _nd_to_2d(self, arr_nd):
def _nd_to_2d(self, arr_nd: NDArray) -> NDArray:
arr_2d = arr_nd.reshape(self.dims[self.axis], -1).squeeze()
return arr_2d

def _coeff_to_array(self, pyr): # cannot use dtcwt types as it may not be installed
def _coeff_to_array(self, pyr: PyramidType) -> NDArray:
highpass_coeffs = np.vstack([h for h in pyr.highpasses])
coeffs = np.concatenate((highpass_coeffs, pyr.lowpass), axis=0)
return coeffs

def _array_to_coeff(self, X): # cannot use dtcwt types as it may not be installed
def _array_to_coeff(self, X: NDArray) -> PyramidType:
lowpass = (X[-self.lowpass_size :].real).reshape((-1, self.otherdims))
_ptr = 0
highpasses = ()
Expand All @@ -154,7 +164,7 @@ def _array_to_coeff(self, X): # cannot use dtcwt types as it may not be install
highpasses += (_h,)
return dtcwt.Pyramid(lowpass, highpasses)

def get_pyramid(self, x): # cannot use dtcwt types as it may not be installed
def get_pyramid(self, x: NDArray) -> PyramidType:
"""Return Pyramid object from flat real-valued array"""
return self._array_to_coeff(x[0] + 1j * x[1])

Expand Down

0 comments on commit ad3d9cc

Please sign in to comment.