Skip to content

Commit

Permalink
feat: added savetaper to Patch2D and Patch3D
Browse files Browse the repository at this point in the history
  • Loading branch information
mrava87 committed May 18, 2024
1 parent c3da131 commit 3066a4f
Show file tree
Hide file tree
Showing 3 changed files with 741 additions and 120 deletions.
166 changes: 148 additions & 18 deletions pylops/signalprocessing/patch2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ class Patch2D(LinearOperator):
Size of model in the transformed domain
tapertype : :obj:`str`, optional
Type of taper (``hanning``, ``cosine``, ``cosinesquare`` or ``None``)
savetaper: :obj:`bool`, optional
Save all tapers and apply them in one go (``True``) or save unique tapers and apply them one by one (``False``).
The first option is more computationally efficient, whilst the second is more memory efficient.
scalings : :obj:`tuple` or :obj:`list`, optional
Set of scalings to apply to each patch. If ``None``, no scale will be
applied
Expand Down Expand Up @@ -178,6 +181,7 @@ def __init__(
nover: Tuple[int, int],
nop: Tuple[int, int],
tapertype: str = "hanning",
savetaper: bool = True,
scalings: Optional[Sequence[float]] = None,
name: str = "P",
) -> None:
Expand Down Expand Up @@ -206,52 +210,68 @@ def __init__(

# create tapers
self.tapertype = tapertype
self.savetaper = savetaper
if self.tapertype is not None:
tap = taper2d(nwin[1], nwin[0], nover, tapertype=tapertype).astype(Op.dtype)
taps = [
tap,
] * nwins
# topmost tapers
taptop = tap.copy()
taptop[: nover[0]] = tap[nwin[0] // 2]
for itap in range(0, nwins1):
taps[itap] = taptop
# bottommost tapers
tapbottom = tap.copy()
tapbottom[-nover[0] :] = tap[nwin[0] // 2]
for itap in range(nwins - nwins1, nwins):
taps[itap] = tapbottom
# leftmost tapers
tapleft = tap.copy()
tapleft[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis]
for itap in range(0, nwins, nwins1):
taps[itap] = tapleft
# rightmost tapers
tapright = tap.copy()
tapright[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis]
for itap in range(nwins1 - 1, nwins, nwins1):
taps[itap] = tapright
# lefttopcorner taper
taplefttop = tap.copy()
taplefttop[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis]
taplefttop[: nover[0]] = taplefttop[nwin[0] // 2]
taps[0] = taplefttop
# righttopcorner taper
taprighttop = tap.copy()
taprighttop[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis]
taprighttop[: nover[0]] = taprighttop[nwin[0] // 2]
taps[nwins1 - 1] = taprighttop
# leftbottomcorner taper
tapleftbottom = tap.copy()
tapleftbottom[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis]
tapleftbottom[-nover[0] :] = tapleftbottom[nwin[0] // 2]
taps[nwins - nwins1] = tapleftbottom
# rightbottomcorner taper
taprightbottom = tap.copy()
taprightbottom[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis]
taprightbottom[-nover[0] :] = taprightbottom[nwin[0] // 2]
taps[nwins - 1] = taprightbottom
self.taps = np.vstack(taps).reshape(nwins0, nwins1, nwin[0], nwin[1])

if self.savetaper:
taps = [
tap,
] * nwins
for itap in range(0, nwins1):
taps[itap] = taptop
for itap in range(nwins - nwins1, nwins):
taps[itap] = tapbottom
for itap in range(0, nwins, nwins1):
taps[itap] = tapleft
for itap in range(nwins1 - 1, nwins, nwins1):
taps[itap] = tapright
taps[0] = taplefttop
taps[nwins1 - 1] = taprighttop
taps[nwins - nwins1] = tapleftbottom
taps[nwins - 1] = taprightbottom
self.taps = np.vstack(taps).reshape(nwins0, nwins1, nwin[0], nwin[1])
else:
taps = [
taplefttop,
taptop,
taprighttop,
tapleft,
tap,
tapright,
tapleftbottom,
tapbottom,
taprightbottom,
]
self.taps = np.vstack(taps).reshape(3, 3, nwin[0], nwin[1])

# define scalings
if scalings is None:
Expand All @@ -273,8 +293,10 @@ def __init__(
name=name,
)

self._register_multiplications(self.savetaper)

@reshaped()
def _matvec(self, x: NDArray) -> NDArray:
def _matvec_savetaper(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
if self.tapertype is not None:
self.taps = to_cupy_conditional(x, self.taps)
Expand All @@ -299,7 +321,7 @@ def _matvec(self, x: NDArray) -> NDArray:
return y

@reshaped
def _rmatvec(self, x: NDArray) -> NDArray:
def _rmatvec_savetaper(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
ncp_sliding_window_view = get_sliding_window_view(x)
if self.tapertype is not None:
Expand All @@ -319,3 +341,111 @@ def _rmatvec(self, x: NDArray) -> NDArray:
ywins[iwin0, iwin1].ravel()
).reshape(self.dims[2], self.dims[3])
return y

@reshaped()
def _matvec_nosavetaper(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
if self.tapertype is not None:
self.taps = to_cupy_conditional(x, self.taps)
y = ncp.zeros(self.dimsd, dtype=self.dtype)
if self.simOp:
x = self.Op @ x
for iwin0 in range(self.dims[0]):
for iwin1 in range(self.dims[1]):
if self.simOp:
xxwin = x[iwin0, iwin1].reshape(self.nwin)
else:
xxwin = self.Op.matvec(x[iwin0, iwin1].ravel()).reshape(self.nwin)
if self.tapertype is not None:
if iwin0 == 0 and iwin1 == 0:
xxwin = self.taps[0, 0] * xxwin
elif iwin0 == 0 and iwin1 == self.dims[1] - 1:
xxwin = self.taps[0, -1] * xxwin
elif iwin0 == 0:
xxwin = self.taps[0, 1] * xxwin
elif iwin0 == self.dims[0] - 1 and iwin1 == 0:
xxwin = self.taps[-1, 0] * xxwin
elif iwin0 == self.dims[0] - 1 and iwin1 == self.dims[1] - 1:
xxwin = self.taps[-1, -1] * xxwin
elif iwin0 == self.dims[0] - 1:
xxwin = self.taps[-1, 1] * xxwin
elif iwin1 == 0:
xxwin = self.taps[1, 0] * xxwin
elif iwin1 == self.dims[1] - 1:
xxwin = self.taps[1, -1] * xxwin
else:
xxwin = self.taps[1, 1] * xxwin

y[
self.dwins_inends[0][0][iwin0] : self.dwins_inends[0][1][iwin0],
self.dwins_inends[1][0][iwin1] : self.dwins_inends[1][1][iwin1],
] += xxwin
return y

@reshaped
def _rmatvec_nosavetaper(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
ncp_sliding_window_view = get_sliding_window_view(x)
if self.tapertype is not None:
self.taps = to_cupy_conditional(x, self.taps)
ywins = ncp_sliding_window_view(x, self.nwin)[
:: self.nwin[0] - self.nover[0], :: self.nwin[1] - self.nover[1]
].copy()
if self.simOp:
if self.tapertype is not None:
for iwin0 in range(self.dims[0]):
for iwin1 in range(self.dims[1]):
if iwin0 == 0 and iwin1 == 0:
ywins[0, 0] = self.taps[0, 0] * ywins[0, 0]
elif iwin0 == 0 and iwin1 == self.dims[1] - 1:
ywins[0, -1] = self.taps[0, -1] * ywins[0, -1]
elif iwin0 == 0:
ywins[0, iwin1] = self.taps[0, 1] * ywins[0, iwin1]
elif iwin0 == self.dims[0] - 1 and iwin1 == 0:
ywins[-1, 0] = self.taps[-1, 0] * ywins[-1, 0]
elif iwin0 == self.dims[0] - 1 and iwin1 == self.dims[1] - 1:
ywins[-1, -1] = self.taps[-1, -1] * ywins[-1, -1]
elif iwin0 == self.dims[0] - 1:
ywins[-1, iwin1] = self.taps[-1, 1] * ywins[-1, iwin1]
elif iwin1 == 0:
ywins[iwin0, 0] = self.taps[1, 0] * ywins[iwin0, 0]
elif iwin1 == self.dims[1] - 1:
ywins[iwin0, -1] = self.taps[1, -1] * ywins[iwin0, -1]
else:
ywins[iwin0, iwin1] = self.taps[1, 1] * ywins[iwin0, iwin1]
y = self.Op.H @ ywins
else:
y = ncp.zeros(self.dims, dtype=self.dtype)
for iwin0 in range(self.dims[0]):
for iwin1 in range(self.dims[1]):
if self.tapertype is not None:
if iwin0 == 0 and iwin1 == 0:
ywins[0, 0] = self.taps[0, 0] * ywins[0, 0]
elif iwin0 == 0 and iwin1 == self.dims[1] - 1:
ywins[0, -1] = self.taps[0, -1] * ywins[0, -1]
elif iwin0 == 0:
ywins[0, iwin1] = self.taps[0, 1] * ywins[0, iwin1]
elif iwin0 == self.dims[0] - 1 and iwin1 == 0:
ywins[-1, 0] = self.taps[-1, 0] * ywins[-1, 0]
elif iwin0 == self.dims[0] - 1 and iwin1 == self.dims[1] - 1:
ywins[-1, -1] = self.taps[-1, -1] * ywins[-1, -1]
elif iwin0 == self.dims[0] - 1:
ywins[-1, iwin1] = self.taps[-1, 1] * ywins[-1, iwin1]
elif iwin1 == 0:
ywins[iwin0, 0] = self.taps[1, 0] * ywins[iwin0, 0]
elif iwin1 == self.dims[1] - 1:
ywins[iwin0, -1] = self.taps[1, -1] * ywins[iwin0, -1]
else:
ywins[iwin0, iwin1] = self.taps[1, 1] * ywins[iwin0, iwin1]
y[iwin0, iwin1] = self.Op.rmatvec(
ywins[iwin0, iwin1].ravel()
).reshape(self.dims[2], self.dims[3])
return y

def _register_multiplications(self, savetaper: bool) -> None:
if savetaper:
self._matvec = self._matvec_savetaper
self._rmatvec = self._rmatvec_savetaper
else:
self._matvec = self._matvec_nosavetaper
self._rmatvec = self._rmatvec_nosavetaper
Loading

0 comments on commit 3066a4f

Please sign in to comment.