Skip to content

Commit

Permalink
Merge pull request astropy#17204 from neutrinoceros/stats/bug/avoid_b…
Browse files Browse the repository at this point in the history
…ottleneck_with_float32

BUG: (slow, steady and) correct wins the race: prefer numpy over bottleneck for nanfunctions on float32 arrays
  • Loading branch information
larrybradley authored Oct 18, 2024
2 parents e9d9583 + c06e4f7 commit f913a55
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 5 deletions.
38 changes: 34 additions & 4 deletions astropy/stats/nanfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,40 @@ def _apply_bottleneck(
else:
return result

# Pre-change bindings (the four lines this commit deletes): every call went
# straight to bottleneck via ``_apply_bottleneck``, whatever the input dtype.
nansum = functools.partial(_apply_bottleneck, bottleneck.nansum)
nanmean = functools.partial(_apply_bottleneck, bottleneck.nanmean)
nanmedian = functools.partial(_apply_bottleneck, bottleneck.nanmedian)
nanstd = functools.partial(_apply_bottleneck, bottleneck.nanstd)
# bottleneck reductions, each pre-bound as the first argument of
# ``_apply_bottleneck`` and looked up by function name
bn_funcs = {
    "nansum": functools.partial(_apply_bottleneck, bottleneck.nansum),
    "nanmean": functools.partial(_apply_bottleneck, bottleneck.nanmean),
    "nanmedian": functools.partial(_apply_bottleneck, bottleneck.nanmedian),
    "nanstd": functools.partial(_apply_bottleneck, bottleneck.nanstd),
}

# plain numpy reductions, looked up by function name
np_funcs = {
    "nansum": np.nansum,
    "nanmean": np.nanmean,
    "nanmedian": np.nanmedian,
    "nanstd": np.nanstd,
}

def _dtype_dispatch(func_name):
    """Build a nan-function that picks its backend from the input dtype.

    Parameters
    ----------
    func_name : str
        Key present in both ``bn_funcs`` and ``np_funcs``
        (e.g. ``"nanmean"``).

    Returns
    -------
    wrapped : callable
        Forwards all positional and keyword arguments unchanged to
        ``bn_funcs[func_name]`` when the first positional argument carries
        an 8-byte float numpy dtype (either byte order), and to
        ``np_funcs[func_name]`` in every other case.
    """
    # dispatch to bottleneck or numpy depending on the input array dtype
    # this is done to workaround known accuracy bugs in bottleneck
    # affecting float32 calculations
    # see https://github.com/pydata/bottleneck/issues/379
    # see https://github.com/pydata/bottleneck/issues/462
    # see https://github.com/astropy/astropy/issues/17185
    # see https://github.com/astropy/astropy/issues/11492
    def wrapped(*args, **kwargs):
        # Inputs without a numpy dtype (lists, tuples, Python scalars) and
        # calls with no positional argument must not crash the dispatcher
        # itself: defer to numpy, which accepts any array-like and raises
        # its own TypeError for a genuinely invalid call.
        dtype = getattr(args[0], "dtype", None) if args else None
        if isinstance(dtype, np.dtype) and dtype.str[1:] == "f8":
            return bn_funcs[func_name](*args, **kwargs)
        return np_funcs[func_name](*args, **kwargs)

    # report the public function name instead of "wrapped" in reprs
    wrapped.__name__ = func_name

    return wrapped

# Public nan-functions: each call is routed to bottleneck or numpy by
# ``_dtype_dispatch`` according to the dtype of the array it receives.
nansum = _dtype_dispatch("nansum")
nanmean = _dtype_dispatch("nanmean")
nanmedian = _dtype_dispatch("nanmedian")
nanstd = _dtype_dispatch("nanstd")

else:
nansum = np.nansum
Expand Down
3 changes: 3 additions & 0 deletions astropy/stats/sigma_clipping.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class SigmaClip:
specified as a string. If one of the options is set to a string
while the other has a custom callable, you may in some cases see
better performance if you have the `bottleneck`_ package installed.
To preserve accuracy, bottleneck is only used for float64 computations.
.. _bottleneck: https://github.com/pydata/bottleneck
Expand Down Expand Up @@ -825,6 +826,7 @@ def sigma_clip(
specified as a string. If one of the options is set to a string
while the other has a custom callable, you may in some cases see
better performance if you have the `bottleneck`_ package installed.
To preserve accuracy, bottleneck is only used for float64 computations.
.. _bottleneck: https://github.com/pydata/bottleneck
Expand Down Expand Up @@ -973,6 +975,7 @@ def sigma_clipped_stats(
specified as a string. If one of the options is set to a string
while the other has a custom callable, you may in some cases see
better performance if you have the `bottleneck`_ package installed.
To preserve accuracy, bottleneck is only used for float64 computations.
.. _bottleneck: https://github.com/pydata/bottleneck
Expand Down
22 changes: 21 additions & 1 deletion astropy/stats/tests/test_sigma_clipping.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from astropy.stats import mad_std
from astropy.stats.sigma_clipping import SigmaClip, sigma_clip, sigma_clipped_stats
from astropy.table import MaskedColumn
from astropy.utils.compat.optional_deps import HAS_SCIPY
from astropy.utils.compat import COPY_IF_NEEDED
from astropy.utils.compat.optional_deps import HAS_BOTTLENECK, HAS_SCIPY
from astropy.utils.exceptions import AstropyUserWarning
from astropy.utils.misc import NumpyRNGContext

Expand Down Expand Up @@ -173,6 +174,25 @@ def test_sigma_clipped_stats_masked_col():
sigma_clipped_stats(col)


@pytest.mark.slow
@pytest.mark.skipif(
    not HAS_BOTTLENECK,
    reason="test a workaround for upstream bug in bottleneck",
)
@pytest.mark.parametrize("shape", [(1024, 1024), (6388, 9576)])
def test_sigma_clip_large_float32_arrays(shape):
    """Check sigma-clipped statistics of large float32 arrays.

    Regression test for https://github.com/astropy/astropy/issues/17185
    """
    # reference (mean, median, stddev) for samples drawn uniformly on [0, 1)
    reference = (0.5, 0.5, 0.288)

    generator = np.random.default_rng(0)
    samples = generator.random(size=shape, dtype="f4")

    # exercise both big- and little-endian float32 layouts of the same data
    for order in ">", "<":
        float32_dtype = f"{order}f4"
        data = samples.astype(dtype=float32_dtype, copy=COPY_IF_NEEDED)
        stats = sigma_clipped_stats(data, sigma=3, maxiters=5)
        assert_allclose(stats, reference, rtol=3e-3)


def test_invalid_sigma_clip():
"""Test sigma_clip of data containing invalid values."""

Expand Down
4 changes: 4 additions & 0 deletions docs/changes/stats/17204.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Fixed accuracy of sigma clipping for large ``float32`` arrays when
``bottleneck`` is installed. Performance may be impacted for computations
involving arrays with dtype other than ``float64``. This change has no impact
for environments that do not have ``bottleneck`` installed.

0 comments on commit f913a55

Please sign in to comment.