Skip to content

Commit

Permalink
Merge pull request #23916 from jakevdp:interp-doc
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 678784484
  • Loading branch information
Google-ML-Automation committed Sep 25, 2024
2 parents 70346bd + ee6fd5a commit 5d4cae0
Showing 1 changed file with 48 additions and 6 deletions.
54 changes: 48 additions & 6 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import types
from typing import (overload, Any, Literal, NamedTuple,
Protocol, TypeVar, Union)
from textwrap import dedent as _dedent
import warnings

import numpy as np
Expand Down Expand Up @@ -2316,15 +2315,58 @@ def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
return f


@util.implements(np.interp,
lax_description=_dedent("""
In addition to constant interpolation supported by NumPy, jnp.interp also
supports left='extrapolate' and right='extrapolate' to indicate linear
extrapolation instead."""))
def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
left: ArrayLike | str | None = None,
right: ArrayLike | str | None = None,
period: ArrayLike | None = None) -> Array:
"""One-dimensional linear interpolation.
JAX implementation of :func:`numpy.interp`.
Args:
x: N-dimensional array of x coordinates at which to evaluate the interpolation.
xp: one-dimensional sorted array of points to be interpolated.
fp: array of shape ``xp.shape`` containing the function values associated with ``xp``.
left: specify how to handle points ``x < xp[0]``. Default is to return ``fp[0]``.
If ``left`` is a scalar value, it will return this value. if ``left`` is the string
``"extrapolate"``, then the value will be determined by linear extrapolation.
``left`` is ignored if ``period`` is specified.
right: specify how to handle points ``x > xp[-1]``. Default is to return ``fp[-1]``.
If ``right`` is a scalar value, it will return this value. if ``right`` is the string
``"extrapolate"``, then the value will be determined by linear extrapolation.
``right`` is ignored if ``period`` is specified.
period: optionally specify the period for the *x* coordinates, for e.g. interpolation
in angular space.
Returns:
an array of shape ``x.shape`` containing the interpolated function at values ``x``.
Examples:
>>> xp = jnp.arange(10)
>>> fp = 2 * xp
>>> x = jnp.array([0.5, 2.0, 3.5])
>>> interp(x, xp, fp)
Array([1., 4., 7.], dtype=float32)
Unless otherwise specified, extrapolation will be constant:
>>> x = jnp.array([-10., 10.])
>>> interp(x, xp, fp)
Array([ 0., 18.], dtype=float32)
Use ``"extrapolate"`` mode for linear extrapolation:
>>> interp(x, xp, fp, left='extrapolate', right='extrapolate')
Array([-20., 20.], dtype=float32)
For periodic interpolation, specify the ``period``:
>>> xp = jnp.array([0, jnp.pi / 2, jnp.pi, 3 * jnp.pi / 2])
>>> fp = jnp.sin(xp)
>>> x = 2 * jnp.pi # note: not in input array
>>> jnp.interp(x, xp, fp, period=2 * jnp.pi)
Array(0., dtype=float32)
"""
static_argnames = []
if isinstance(left, str) or left is None:
static_argnames.append('left')
Expand Down

0 comments on commit 5d4cae0

Please sign in to comment.