Skip to content

Commit

Permalink
Merge pull request #23197 from jakevdp:quantile-docs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 667602295
  • Loading branch information
jax authors committed Aug 26, 2024
2 parents e3e0860 + 9090b8a commit 550607a
Showing 1 changed file with 161 additions and 7 deletions.
168 changes: 161 additions & 7 deletions jax/_src/numpy/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1883,16 +1883,53 @@ def cumulative_sum(
# Quantiles

# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@implements(np.quantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, *, interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array:
"""Compute the quantile of the data along the specified axis.
JAX implementation of :func:`numpy.quantile`.
Args:
a: N-dimensional array input.
q: scalar or 1-dimensional array specifying the desired quantiles. ``q``
should contain floating-point values between ``0.0`` and ``1.0``.
axis: optional axis or tuple of axes along which to compute the quantile
out: not implemented by JAX; will error if not None
overwrite_input: not implemented by JAX; will error if not False
method: specify the interpolation method to use. Options are one of
``["linear", "lower", "higher", "midpoint", "nearest"]``.
default is ``linear``.
keepdims: if True, then the returned array will have the same number of
dimensions as the input. Default is False.
interpolation: deprecated alias of the ``method`` argument. Will result
in a :class:`DeprecationWarning` if used.
Returns:
An array containing the specified quantiles along the specified axes.
See also:
- :func:`jax.numpy.nanquantile`: compute the quantile while ignoring NaNs
- :func:`jax.numpy.percentile`: compute the percentile (0-100)
Examples:
Computing the median and quartiles of an array, with linear interpolation:
>>> x = jnp.arange(10)
>>> q = jnp.array([0.25, 0.5, 0.75])
>>> jnp.quantile(x, q)
Array([2.25, 4.5 , 6.75], dtype=float32)
Computing the quartiles using nearest-value interpolation:
>>> jnp.quantile(x, q, method='nearest')
Array([2., 4., 7.], dtype=float32)
"""
check_arraylike("quantile", a, q)
if overwrite_input or out is not None:
msg = ("jax.numpy.quantile does not support overwrite_input=True or "
"out != None")
raise ValueError(msg)
raise ValueError("jax.numpy.quantile does not support overwrite_input=True "
"or out != None")
if not isinstance(interpolation, DeprecatedArg):
deprecations.warn(
"jax-numpy-quantile-interpolation",
Expand All @@ -1902,11 +1939,50 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, False)

# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@implements(np.nanquantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, *, interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array:
"""Compute the quantile of the data along the specified axis, ignoring NaNs.
JAX implementation of :func:`numpy.nanquantile`.
Args:
a: N-dimensional array input.
q: scalar or 1-dimensional array specifying the desired quantiles. ``q``
should contain floating-point values between ``0.0`` and ``1.0``.
axis: optional axis or tuple of axes along which to compute the quantile
out: not implemented by JAX; will error if not None
overwrite_input: not implemented by JAX; will error if not False
method: specify the interpolation method to use. Options are one of
``["linear", "lower", "higher", "midpoint", "nearest"]``.
default is ``linear``.
keepdims: if True, then the returned array will have the same number of
dimensions as the input. Default is False.
interpolation: deprecated alias of the ``method`` argument. Will result
in a :class:`DeprecationWarning` if used.
Returns:
An array containing the specified quantiles along the specified axes.
See also:
- :func:`jax.numpy.quantile`: compute the quantile without ignoring nans
- :func:`jax.numpy.nanpercentile`: compute the percentile (0-100)
Examples:
Computing the median and quartiles of a 1D array:
>>> x = jnp.array([0, 1, 2, jnp.nan, 3, 4, 5, 6])
>>> q = jnp.array([0.25, 0.5, 0.75])
Because of the NaN value, :func:`jax.numpy.quantile` returns all NaNs,
while :func:`~jax.numpy.nanquantile` ignores them:
>>> jnp.quantile(x, q)
Array([nan, nan, nan], dtype=float32)
>>> jnp.nanquantile(x, q)
Array([1.5, 3. , 4.5], dtype=float32)
"""
check_arraylike("nanquantile", a, q)
if overwrite_input or out is not None:
msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
Expand Down Expand Up @@ -2043,12 +2119,50 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
return lax.convert_element_type(result, a.dtype)

# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@implements(np.percentile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def percentile(a: ArrayLike, q: ArrayLike,
axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array:
"""Compute the percentile of the data along the specified axis.
JAX implementation of :func:`numpy.percentile`.
Args:
a: N-dimensional array input.
q: scalar or 1-dimensional array specifying the desired quantiles. ``q``
should contain integer or floating point values between ``0`` and ``100``.
axis: optional axis or tuple of axes along which to compute the quantile
out: not implemented by JAX; will error if not None
overwrite_input: not implemented by JAX; will error if not False
method: specify the interpolation method to use. Options are one of
``["linear", "lower", "higher", "midpoint", "nearest"]``.
default is ``linear``.
keepdims: if True, then the returned array will have the same number of
dimensions as the input. Default is False.
interpolation: deprecated alias of the ``method`` argument. Will result
in a :class:`DeprecationWarning` if used.
Returns:
An array containing the specified percentiles along the specified axes.
See also:
- :func:`jax.numpy.quantile`: compute the quantile (0.0-1.0)
- :func:`jax.numpy.nanpercentile`: compute the percentile while ignoring NaNs
Examples:
Computing the median and quartiles of a 1D array:
>>> x = jnp.array([0, 1, 2, 3, 4, 5, 6])
>>> q = jnp.array([25, 50, 75])
>>> jnp.percentile(x, q)
Array([1.5, 3. , 4.5], dtype=float32)
Computing the same percentiles with nearest rather than linear interpolation:
>>> jnp.percentile(x, q, method='nearest')
Array([1., 3., 4.], dtype=float32)
"""
check_arraylike("percentile", a, q)
q, = promote_dtypes_inexact(q)
if not isinstance(interpolation, DeprecatedArg):
Expand All @@ -2061,12 +2175,52 @@ def percentile(a: ArrayLike, q: ArrayLike,
method=method, keepdims=keepdims)

# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@implements(np.nanpercentile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def nanpercentile(a: ArrayLike, q: ArrayLike,
axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array:
"""Compute the percentile of the data along the specified axis, ignoring NaN values.
JAX implementation of :func:`numpy.nanpercentile`.
Args:
a: N-dimensional array input.
q: scalar or 1-dimensional array specifying the desired quantiles. ``q``
should contain integer or floating point values between ``0`` and ``100``.
axis: optional axis or tuple of axes along which to compute the quantile
out: not implemented by JAX; will error if not None
overwrite_input: not implemented by JAX; will error if not False
method: specify the interpolation method to use. Options are one of
``["linear", "lower", "higher", "midpoint", "nearest"]``.
default is ``linear``.
keepdims: if True, then the returned array will have the same number of
dimensions as the input. Default is False.
interpolation: deprecated alias of the ``method`` argument. Will result
in a :class:`DeprecationWarning` if used.
Returns:
An array containing the specified percentiles along the specified axes.
See also:
- :func:`jax.numpy.nanquantile`: compute the nan-aware quantile (0.0-1.0)
- :func:`jax.numpy.percentile`: compute the percentile without special
handling of NaNs.
Examples:
Computing the median and quartiles of a 1D array:
>>> x = jnp.array([0, 1, 2, jnp.nan, 3, 4, 5, 6])
>>> q = jnp.array([25, 50, 75])
Because of the NaN value, :func:`jax.numpy.percentile` returns all NaNs,
while :func:`~jax.numpy.nanpercentile` ignores them:
>>> jnp.percentile(x, q)
Array([nan, nan, nan], dtype=float32)
>>> jnp.nanpercentile(x, q)
Array([1.5, 3. , 4.5], dtype=float32)
"""
check_arraylike("nanpercentile", a, q)
q = ufuncs.true_divide(q, 100.0)
if not isinstance(interpolation, DeprecatedArg):
Expand Down

0 comments on commit 550607a

Please sign in to comment.