Skip to content

Commit

Permalink
Merge pull request #566 from mrava87/feat-unittaper
Browse files Browse the repository at this point in the history
feat: added exponent to cosinetaper
  • Loading branch information
mrava87 authored Jan 25, 2024
2 parents a6d09c1 + 38fe5ee commit e588fe3
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 9 deletions.
26 changes: 19 additions & 7 deletions pylops/utils/tapers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"tapernd",
]

from typing import Tuple, Union
from typing import Optional, Tuple, Union

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -59,6 +59,7 @@ def cosinetaper(
nmask: int,
ntap: int,
square: bool = False,
exponent: Optional[float] = None,
) -> npt.ArrayLike:
r"""1D Cosine or Cosine square taper
Expand All @@ -71,8 +72,10 @@ def cosinetaper(
Number of samples of mask
ntap : :obj:`int`
Number of samples of hanning tapering at edges
square : :obj:`bool`
Cosine square taper (``True``)or Cosine taper (``False``)
square : :obj:`bool`, optional
Cosine square taper (``True``) or Cosine taper (``False``)
exponent : :obj:`float`, optional
Exponent to apply to Cosine taper. If provided, takes precedence over ``square``
Returns
-------
Expand All @@ -81,7 +84,8 @@ def cosinetaper(
"""
ntap = 0 if ntap == 1 else ntap
exponent = 1 if not square else 2
if exponent is None:
exponent = 1 if not square else 2
cos_win = (
0.5
* (
Expand Down Expand Up @@ -123,7 +127,8 @@ def taper(
ntap : :obj:`int`
Number of samples of hanning tapering at edges
tapertype : :obj:`str`, optional
Type of taper (``hanning``, ``cosine``, ``cosinesquare`` or ``None``)
Type of taper (``hanning``, ``cosine``,
``cosinesquare``, ``cosinesqrt`` or ``None``)
Returns
-------
Expand All @@ -137,6 +142,8 @@ def taper(
tpr_1d = cosinetaper(nmask, ntap, False)
elif tapertype == "cosinesquare":
tpr_1d = cosinetaper(nmask, ntap, True)
elif tapertype == "cosinesqrt":
tpr_1d = cosinetaper(nmask, ntap, False, 0.5)
else:
tpr_1d = np.ones(nmask)
return tpr_1d
Expand Down Expand Up @@ -214,7 +221,7 @@ def taper3d(
Number of samples of tapering at edges of first and second dimensions
tapertype : :obj:`int`
Type of taper (``hanning``, ``cosine``,
``cosinesquare`` or ``None``)
``cosinesquare``, ``cosinesqrt`` or ``None``)
Returns
-------
Expand All @@ -236,6 +243,9 @@ def taper3d(
elif tapertype == "cosinesquare":
tpr_y = cosinetaper(nmasky, ntapy, True)
tpr_x = cosinetaper(nmaskx, ntapx, True)
elif tapertype == "cosinesqrt":
tpr_y = cosinetaper(nmasky, ntapy, False, 0.5)
tpr_x = cosinetaper(nmaskx, ntapx, False, 0.5)
else:
tpr_y = np.ones(nmasky)
tpr_x = np.ones(nmaskx)
Expand Down Expand Up @@ -266,7 +276,7 @@ def tapernd(
Number of samples of tapering at edges of every dimension
tapertype : :obj:`int`
Type of taper (``hanning``, ``cosine``,
``cosinesquare`` or ``None``)
``cosinesquare``, ``cosinesqrt`` or ``None``)
Returns
-------
Expand All @@ -282,6 +292,8 @@ def tapernd(
tpr = [cosinetaper(nm, nt, False) for nm, nt in zip(nmask, ntap)]
elif tapertype == "cosinesquare":
tpr = [cosinetaper(nm, nt, True) for nm, nt in zip(nmask, ntap)]
elif tapertype == "cosinesqrt":
tpr = [cosinetaper(nm, nt, False, 0.5) for nm, nt in zip(nmask, ntap)]
else:
tpr = [np.ones(nm) for nm in nmask]

Expand Down
20 changes: 18 additions & 2 deletions pytests/test_tapers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,23 @@
"ntap": (4, 6),
"tapertype": "cosinesquare",
} # cosinesquare, even samples and taper
par7 = {
"nt": 21,
"nspat": (11, 13),
"ntap": (3, 5),
"tapertype": "cosinesqrt",
} # cosinesqrt, odd samples and taper
par8 = {
"nt": 20,
"nspat": (12, 16),
"ntap": (4, 6),
"tapertype": "cosinesqrt",
} # cosinesqrt, even samples and taper


@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)])
@pytest.mark.parametrize(
"par", [(par1), (par2), (par3), (par4), (par5), (par6), (par7), (par8)]
)
def test_taper2d(par):
"""Create taper wavelet and check size and values"""
tap = taper2d(par["nt"], par["nspat"][0], par["ntap"][0], par["tapertype"])
Expand All @@ -54,7 +68,9 @@ def test_taper2d(par):
assert_array_equal(tap[par["nspat"][0] // 2], np.ones(par["nt"]))


@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)])
@pytest.mark.parametrize(
"par", [(par1), (par2), (par3), (par4), (par5), (par6), (par7), (par8)]
)
def test_taper3d(par):
"""Create taper wavelet and check size and values"""
tap = taper3d(par["nt"], par["nspat"], par["ntap"], par["tapertype"])
Expand Down

0 comments on commit e588fe3

Please sign in to comment.