diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 559d17cd9514..73e27245cfa9 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -35,7 +35,6 @@ import types from typing import (overload, Any, Literal, NamedTuple, Protocol, TypeVar, Union) -from textwrap import dedent as _dedent import warnings import numpy as np @@ -2316,15 +2315,58 @@ def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, return f -@util.implements(np.interp, - lax_description=_dedent(""" - In addition to constant interpolation supported by NumPy, jnp.interp also - supports left='extrapolate' and right='extrapolate' to indicate linear - extrapolation instead.""")) def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, left: ArrayLike | str | None = None, right: ArrayLike | str | None = None, period: ArrayLike | None = None) -> Array: + """One-dimensional linear interpolation. + + JAX implementation of :func:`numpy.interp`. + + Args: + x: N-dimensional array of x coordinates at which to evaluate the interpolation. + xp: one-dimensional sorted array of points to be interpolated. + fp: array of shape ``xp.shape`` containing the function values associated with ``xp``. + left: specify how to handle points ``x < xp[0]``. Default is to return ``fp[0]``. + If ``left`` is a scalar value, it will return this value. if ``left`` is the string + ``"extrapolate"``, then the value will be determined by linear extrapolation. + ``left`` is ignored if ``period`` is specified. + right: specify how to handle points ``x > xp[-1]``. Default is to return ``fp[-1]``. + If ``right`` is a scalar value, it will return this value. if ``right`` is the string + ``"extrapolate"``, then the value will be determined by linear extrapolation. + ``right`` is ignored if ``period`` is specified. + period: optionally specify the period for the *x* coordinates, for e.g. interpolation + in angular space. + + Returns: + an array of shape ``x.shape`` containing the interpolated function at values ``x``. + + Examples: + >>> xp = jnp.arange(10) + >>> fp = 2 * xp + >>> x = jnp.array([0.5, 2.0, 3.5]) + >>> interp(x, xp, fp) + Array([1., 4., 7.], dtype=float32) + + Unless otherwise specified, extrapolation will be constant: + + >>> x = jnp.array([-10., 10.]) + >>> interp(x, xp, fp) + Array([ 0., 18.], dtype=float32) + + Use ``"extrapolate"`` mode for linear extrapolation: + + >>> interp(x, xp, fp, left='extrapolate', right='extrapolate') + Array([-20., 20.], dtype=float32) + + For periodic interpolation, specify the ``period``: + + >>> xp = jnp.array([0, jnp.pi / 2, jnp.pi, 3 * jnp.pi / 2]) + >>> fp = jnp.sin(xp) + >>> x = 2 * jnp.pi # note: not in input array + >>> jnp.interp(x, xp, fp, period=2 * jnp.pi) + Array(0., dtype=float32) + """ static_argnames = [] if isinstance(left, str) or left is None: static_argnames.append('left')