diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index e3af13f98..3fcb61650 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -135,7 +135,7 @@ jobs: - tox-env: py310-coverage-lmoments # No markers -- includes slow tests python-version: "3.10" os: ubuntu-latest - - tox-env: py311-coverage-sbck + - tox-env: py311-coverage-sbck-extras python-version: "3.11" markers: -m 'not slow' os: ubuntu-latest diff --git a/AUTHORS.rst b/AUTHORS.rst index f9a2ff924..c83977550 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -29,7 +29,7 @@ Contributors * David Caron `@davidcaron `_ * Carsten Ehbrecht `@cehbrecht `_ * Jeremy Fyke `@jeremyfyke `_ -* Sarah Gammon `@SarahG-579462 `_ +* Sarah Gammon `@SarahG-579462 `_ * Tom Keel `@Thomasjkeel `_ * Marie-Pier Labonté `@marielabonte `_ * Ludwig Lierhammer `@ludwiglierhammer `_ diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 026917a27..1f297d24d 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,7 +4,11 @@ Changelog v0.52.0 (unreleased) -------------------- -Contributors to this version: David Huard (:user:`huard`), Trevor James Smith (:user:`Zeitsperre`), Hui-Min Wang (:user:`Hem-W`). +Contributors to this version: David Huard (:user:`huard`), Trevor James Smith (:user:`Zeitsperre`), Hui-Min Wang (:user:`Hem-W`), Éric Dupuis (:user:`coxipi`), Sarah Gammon (:user:`SarahG-579462`). + +New features and enhancements +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +* ``xclim.sdba.nbutils.quantile`` and its child functions are now faster. If the module ``fastnanquantile`` is installed, it is used as the backend for the computation of quantiles and yields even faster results. (:issue:`1255`, :pull:`1513`). 
Bug fixes ^^^^^^^^^ diff --git a/docs/conf.py b/docs/conf.py index b4e71d3a6..5c504e44b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -267,6 +267,7 @@ class XCStyle(AlphaStyle): "_build", "Thumbs.db", ".DS_Store", + "notebooks/benchmarks", "notebooks/xclim_training", "paper/paper.md", "**.ipynb_checkpoints", diff --git a/docs/notebooks/benchmarks/sdba_quantile.ipynb b/docs/notebooks/benchmarks/sdba_quantile.ipynb new file mode 100644 index 000000000..ce19b230c --- /dev/null +++ b/docs/notebooks/benchmarks/sdba_quantile.ipynb @@ -0,0 +1,142 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from __future__ import annotations\n", + "\n", + "import time\n", + "\n", + "import numpy as np\n", + "\n", + "import xclim\n", + "from xclim import sdba\n", + "from xclim.testing import open_dataset\n", + "\n", + "ds = open_dataset(\"sdba/CanESM2_1950-2100.nc\")\n", + "tx = ds.sel(time=slice(\"1950\", \"1980\")).tasmax\n", + "kws = {\"dim\": \"time\", \"q\": np.linspace(0, 1, 50)}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tests with %%timeit (full 30 years)\n", + "\n", + "Here `fastnanquantile` is the best algorithm out of \n", + "* `xr.DataArray.quantile`\n", + "* `nbutils.quantile`, using: \n", + " * `xclim.core.utils.nan_quantile`\n", + " * `fastnanquantile`\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%timeit\n", + "tx.quantile(**kws).compute()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%timeit\n", + "sdba.nbutils.USE_FASTNANQUANTILE = False\n", + "sdba.nbutils.quantile(tx, **kws).compute()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "! 
pip install fastnanquantile" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%timeit\n", + "sdba.nbutils.USE_FASTNANQUANTILE = True\n", + "sdba.nbutils.quantile(tx, **kws).compute()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Test computation time as a function of number of points\n", + "\n", + "For a smaller number of time steps <=2000, `_sortquantile` is the best algorithm in general" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import xarray as xr\n", + "\n", + "num_tests = 500\n", + "timed = {}\n", + "# fastnanquantile has nothing to do with sortquantile\n", + "# I just added a third step using this variable\n", + "\n", + "for use_fnq in [True, False]:\n", + " sdba.nbutils.USE_FASTNANQUANTILE = use_fnq\n", + " # heat-up the jit\n", + " sdba.nbutils.quantile(\n", + " xr.DataArray(np.array([0, 1.5])), dim=\"dim_0\", q=np.array([0.5])\n", + " )\n", + " for size in np.arange(250, 2000 + 250, 250):\n", + " da = tx.isel(time=slice(0, size))\n", + " t0 = time.time()\n", + " for ii in range(num_tests):\n", + " sdba.nbutils.quantile(da, **kws).compute()\n", + " timed[use_fnq].append([size, time.time() - t0])\n", + "\n", + "for k, lab in zip([True, False], [\"xclim.core.utils.nan_quantile\", \"fastnanquantile\"]):\n", + " arr = np.array(timed[k])\n", + " plt.plot(arr[:, 0], arr[:, 1] / num_tests, label=lab)\n", + "plt.legend()\n", + "plt.title(\"Quantile computation, average time vs array size, for 50 quantiles\")\n", + "plt.xlabel(\"Number of time steps in the distribution\")\n", + "plt.ylabel(\"Computation time (s)\")" + ] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": 
"python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index 68dcadab9..11ef7cfe2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,7 +110,8 @@ docs = [ "sphinxcontrib-bibtex", "sphinxcontrib-svg2pdfconverter[Cairosvg]" ] -all = ["xclim[dev]", "xclim[docs]"] +extras = ["fastnanquantile"] +all = ["xclim[dev]", "xclim[docs]", "xclim[extras]"] [project.scripts] xclim = "xclim.cli:cli" diff --git a/tests/test_sdba/test_nbutils.py b/tests/test_sdba/test_nbutils.py new file mode 100644 index 000000000..20ece1e21 --- /dev/null +++ b/tests/test_sdba/test_nbutils.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import numpy as np +import pytest +import xarray as xr + +from xclim.sdba import nbutils as nbu + + +class TestQuantiles: + @pytest.mark.parametrize("uses_dask", [True, False]) + def test_quantile(self, open_dataset, uses_dask): + da = ( + open_dataset("sdba/CanESM2_1950-2100.nc").sel(time=slice("1950", "1955")).pr + ).load() + if uses_dask: + da = da.chunk({"location": 1}) + else: + da = da.load() + q = np.linspace(0.1, 0.99, 50) + out_nbu = nbu.quantile(da, q, dim="time").transpose("location", ...) + out_xr = da.quantile(q=q, dim="time").transpose("location", ...) 
+ np.testing.assert_array_almost_equal(out_nbu.values, out_xr.values) + + def test_edge_cases(self, open_dataset): + q = np.linspace(0.1, 0.99, 50) + + # only 1 non-null value + da = xr.DataArray([1] + [np.nan] * 100, dims="dim_0") + out_nbu = nbu.quantile(da, q, dim="dim_0") + np.testing.assert_array_equal(out_nbu.values, np.full_like(q, 1)) + + # only NANs + da = xr.DataArray([np.nan] * 100, dims="dim_0") + out_nbu = nbu.quantile(da, q, dim="dim_0") + np.testing.assert_array_equal(out_nbu.values, np.full_like(q, np.nan)) diff --git a/tox.ini b/tox.ini index dcb355920..ca18e6cb3 100644 --- a/tox.ini +++ b/tox.ini @@ -113,7 +113,9 @@ passenv = LD_LIBRARY_PATH SKIP_NOTEBOOKS XCLIM_* -extras = dev +extras = + dev + extras: extras deps = upstream: -r CI/requirements_upstream.txt sbck: pybind11 diff --git a/xclim/sdba/nbutils.py b/xclim/sdba/nbutils.py index aa87fd0a3..29a1ae9ad 100644 --- a/xclim/sdba/nbutils.py +++ b/xclim/sdba/nbutils.py @@ -9,15 +9,154 @@ import numpy as np from numba import boolean, float32, float64, guvectorize, njit -from xarray import DataArray +from xarray import DataArray, apply_ufunc from xarray.core import utils +try: + from fastnanquantile.xrcompat import xr_apply_nanquantile + + USE_FASTNANQUANTILE = True +except ImportError: + USE_FASTNANQUANTILE = False + + +@njit( + fastmath={"arcp", "contract", "reassoc", "nsz", "afn"}, + nogil=True, + cache=False, +) +def _get_indexes( + arr: np.array, virtual_indexes: np.array, valid_values_count: np.array +) -> tuple[np.array, np.array]: + """Get the valid indexes of arr neighbouring virtual_indexes. + + Parameters + ---------- + arr : array-like + virtual_indexes : array-like + valid_values_count : array-like + + Returns + ------- + array-like, array-like + A tuple of virtual_indexes neighbouring indexes (previous and next) + + Notes + ----- + This is a companion function to linear interpolation of quantiles. 
+ """ + previous_indexes = np.asarray(np.floor(virtual_indexes)) + next_indexes = np.asarray(previous_indexes + 1) + indexes_above_bounds = virtual_indexes >= valid_values_count - 1 + # When indexes is above max index, take the max value of the array + if indexes_above_bounds.any(): + previous_indexes[indexes_above_bounds] = -1 + next_indexes[indexes_above_bounds] = -1 + # When indexes is below min index, take the min value of the array + indexes_below_bounds = virtual_indexes < 0 + if indexes_below_bounds.any(): + previous_indexes[indexes_below_bounds] = 0 + next_indexes[indexes_below_bounds] = 0 + if (arr.dtype is np.dtype(np.float64)) or (arr.dtype is np.dtype(np.float32)): + # After the sort, slices having NaNs will have for last element a NaN + virtual_indexes_nans = np.isnan(virtual_indexes) + if virtual_indexes_nans.any(): + previous_indexes[virtual_indexes_nans] = -1 + next_indexes[virtual_indexes_nans] = -1 + previous_indexes = previous_indexes.astype(np.intp) + next_indexes = next_indexes.astype(np.intp) + return previous_indexes, next_indexes + + +@njit( + fastmath={"arcp", "contract", "reassoc", "nsz", "afn"}, + nogil=True, + cache=False, +) +def _linear_interpolation( + left: np.array, + right: np.array, + gamma: np.array, +) -> np.array: + """Compute the linear interpolation weighted by gamma on each point of two same shape arrays. + + Parameters + ---------- + left : array_like + Left bound. + right : array_like + Right bound. + gamma : array_like + The interpolation weight. 
+ + Returns + ------- + array_like + + Notes + ----- + This is a companion function for `_nan_quantile_1d` + """ + diff_b_a = np.subtract(right, left) + lerp_interpolation = np.asarray(np.add(left, diff_b_a * gamma)) + ind = gamma >= 0.5 + lerp_interpolation[ind] = right[ind] - diff_b_a[ind] * (1 - gamma[ind]) + return lerp_interpolation + + +@njit( + fastmath={"arcp", "contract", "reassoc", "nsz", "afn"}, + nogil=True, + cache=False, +) +def _nan_quantile_1d( + arr: np.array, + quantiles: np.array, + alpha: float = 1.0, + beta: float = 1.0, +) -> float | np.array: + """Get the quantiles of the 1-dimensional array. + + A linear interpolation is performed using alpha and beta. + + Notes + ----- + By default, `alpha == beta == 1`, which performs the 7th method of :cite:t:`hyndman_sample_1996`. + With `alpha == beta == 1/3` we get the 8th method. `alpha == beta == 1` reproduces the behaviour of `np.nanquantile`. + """ + # We need at least two values to do an interpolation + valid_values_count = (~np.isnan(arr)).sum() + + # Computation of indexes + virtual_indexes = ( + valid_values_count * quantiles + (alpha + quantiles * (1 - alpha - beta)) - 1 + ) + virtual_indexes = np.asarray(virtual_indexes) + previous_indexes, next_indexes = _get_indexes( + arr, virtual_indexes, valid_values_count + ) + # Sorting + arr.sort() + + previous = arr[previous_indexes] + next_elements = arr[next_indexes] + + # Linear interpolation + gamma = np.asarray(virtual_indexes - previous_indexes, dtype=arr.dtype) + interpolation = _linear_interpolation(previous, next_elements, gamma) + # When an interpolation is in Nan range, (near the end of the sorted array) it means + # we can clip to the array max value. 
+ result = np.where( + np.isnan(interpolation), arr[np.intp(valid_values_count) - 1], interpolation + ) + return result + @guvectorize( [(float32[:], float32, float32[:]), (float64[:], float64, float64[:])], "(n),()->()", nopython=True, - cache=True, + cache=False, ) def _vecquantiles(arr, rnk, res): if np.isnan(rnk): @@ -62,14 +201,27 @@ def vecquantiles( @njit -def _quantile(arr, q): - if arr.ndim == 1: - out = np.empty((q.size,), dtype=arr.dtype) - out[:] = np.nanquantile(arr, q) +def _wrapper_quantile1d(arr, q): + out = np.empty((arr.shape[0], q.size), dtype=arr.dtype) + for index in range(out.shape[0]): + out[index] = _nan_quantile_1d(arr[index], q) + return out + + +def _quantile(arr, q, nreduce): + if arr.ndim == nreduce: + out = _nan_quantile_1d(arr.flatten(), q) else: - out = np.empty((arr.shape[0], q.size), dtype=arr.dtype) - for index in range(out.shape[0]): - out[index] = np.nanquantile(arr[index], q) + # dimensions that are reduced by quantile + red_axis = np.arange(len(arr.shape) - nreduce, len(arr.shape)) + reduction_dim_size = np.prod([arr.shape[idx] for idx in red_axis]) + # kept dimensions + keep_axis = np.arange(len(arr.shape) - nreduce) + final_shape = [arr.shape[idx] for idx in keep_axis] + [len(q)] + # reshape as (keep_dims, red_dims), compute, reshape back + arr = arr.reshape(-1, reduction_dim_size) + out = _wrapper_quantile1d(arr, q) + out = out.reshape(final_shape) return out @@ -90,49 +242,39 @@ def quantile(da: DataArray, q: np.ndarray, dim: str | Sequence[Hashable]) -> Dat xarray.DataArray The quantiles computed along the `dim` dimension. """ - # We have two cases : - # - When all dims are processed : we stack them and use _quantile1d - # - When the quantiles are vectorized over some dims, these are also stacked and then _quantile2D is used. - # All this stacking is so that we can cover all ND+1D cases with one numba function. 
- - # Stack the dims and send to the last position - # This is in case there are more than one - dims = [dim] if isinstance(dim, str) else dim - tem = utils.get_temp_dimname(da.dims, "temporal") - da = da.stack({tem: dims}) - - # So we cut in half the definitions to declare in numba - # We still use q as the coords, so it corresponds to what was done upstream - if not hasattr(q, "dtype") or q.dtype != da.dtype: - qc = np.array(q, dtype=da.dtype) + if USE_FASTNANQUANTILE is True: + return xr_apply_nanquantile(da, dim=dim, q=q).rename({"quantile": "quantiles"}) else: - qc = q - - if len(da.dims) > 1: - # There are some extra dims - extra = utils.get_temp_dimname(da.dims, "extra") - da = da.stack({extra: list(set(da.dims) - {tem})}) - da = da.transpose(..., tem) - res = DataArray( - _quantile(da.values, qc), - dims=(extra, "quantiles"), - coords={extra: da[extra], "quantiles": q}, - attrs=da.attrs, - ).unstack([extra]) - - else: - # All dims are processed - res = DataArray( - _quantile(da.values, qc), - dims="quantiles", - coords={"quantiles": q}, - attrs=da.attrs, + qc = np.array(q, dtype=da.dtype) + dims = [dim] if isinstance(dim, str) else dim + kwargs = dict(nreduce=len(dims), q=qc) + res = ( + apply_ufunc( + _quantile, + da, + input_core_dims=[dims], + exclude_dims=set(dims), + output_core_dims=[["quantiles"]], + output_dtypes=[da.dtype], + dask_gufunc_kwargs=dict(output_sizes={"quantiles": len(q)}), + dask="parallelized", + kwargs=kwargs, + ) + .assign_coords(quantiles=q) + .assign_attrs(da.attrs) ) - - return res + return res -@njit +@njit( + [ + float32[:, :](float32[:, :]), + float64[:, :](float64[:, :]), + ], + fastmath=False, + nogil=True, + cache=False, +) def remove_NaNs(x): # noqa """Remove NaN values from series.""" remove = np.zeros_like(x[0, :], dtype=boolean) @@ -141,7 +283,15 @@ def remove_NaNs(x): # noqa return x[:, ~remove] -@njit(fastmath=True) +@njit( + [ + float32(float32[:, :], float32[:, :]), + float64(float64[:, :], float64[:, :]), + ], + 
fastmath=True, + nogil=True, + cache=False, +) def _correlation(X, Y): """Compute a correlation as the mean of pairwise distances between points in X and Y. @@ -158,7 +308,15 @@ def _correlation(X, Y): return d / (X.shape[1] * Y.shape[1]) -@njit(fastmath=True) +@njit( + [ + float32(float32[:, :]), + float64(float64[:, :]), + ], + fastmath=True, + nogil=True, + cache=False, +) def _autocorrelation(X): """Mean of the NxN pairwise distances of points in X of shape KxN. @@ -181,7 +339,7 @@ def _autocorrelation(X): ], "(k, n),(k, m)->()", nopython=True, - cache=True, + cache=False, ) def _escore(tgt, sim, out): """E-score based on the Székely-Rizzo e-distances between clusters. @@ -204,7 +362,11 @@ def _escore(tgt, sim, out): out[0] = w * (sXY + sXY - sXX - sYY) / 2 -@njit +@njit( + fastmath=False, + nogil=True, + cache=False, +) def _first_and_last_nonnull(arr): """For each row of arr, get the first and last non NaN elements.""" out = np.empty((arr.shape[0], 2)) @@ -217,8 +379,14 @@ def _first_and_last_nonnull(arr): return out -@njit -def _extrapolate_on_quantiles(interp, oldx, oldg, oldy, newx, newg, method="constant"): +@njit( + fastmath=False, + nogil=True, + cache=False, +) +def _extrapolate_on_quantiles( + interp, oldx, oldg, oldy, newx, newg, method="constant" +): # noqa """Apply extrapolation to the output of interpolation on quantiles with a given grouping. Arguments are the same as _interp_on_quantiles_2D. @@ -239,7 +407,11 @@ def _extrapolate_on_quantiles(interp, oldx, oldg, oldy, newx, newg, method="cons return interp -@njit +@njit( + fastmath=False, + nogil=True, + cache=False, +) def _pairwise_haversine_and_bins(lond, latd, transpose=False): """Inter-site distances with the haversine approximation.""" N = lond.shape[0]