Skip to content

Commit

Permalink
Better documentation for jnp.choice
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Sep 26, 2024
1 parent 0a66e2d commit 36f74bb
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 3 deletions.
80 changes: 78 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4388,9 +4388,85 @@ def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array:
return concatenate(arrs, axis=1)


@util.implements(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
def choose(a: ArrayLike, choices: Array | np.ndarray | Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
"""Construct an array by stacking slices of choice arrays.
JAX implementation of :func:`numpy.choose`.
The semantics of this function can be confusing, but in the simplest case where
``a`` is a one-dimensional array, ``choices`` is a two-dimensional array, and
all entries of ``a`` are in-bounds (i.e. ``0 <= a_i < len(choices)``), then the
function is equivalent to the following::
def choose(a, choices):
return jnp.array([choices[a_i, i] for i, a_i in enumerate(a)])
In the more general case, ``a`` may have any number of dimensions and ``choices``
may be an arbitrary sequence of broadcast-compatible arrays. In this case, again
for in-bound indices, the logic is equivalent to::
def choose(a, choices):
a, *choices = jnp.broadcast_arrays(a, *choices)
choices = jnp.array(choices)
return jnp.array([choices[a[idx], *idx] for idx in np.ndindex(a.shape)])
The only additional complexity comes from the ``mode`` argument, which controls
the behavior for out-of-bound indices in ``a`` as described below.
Args:
a: an N-dimensional array of integer indices.
choices: an array or sequence of arrays. All arrays in the sequence must be
mutually broadcast compatible with ``a``.
out: unused by JAX
mode: specify the out-of-bounds indexing mode; one of ``'raise'`` (default),
``'wrap'``, or ``'clip'``. Note that the default mode of ``'raise'`` is
not compatible with JAX transformations.
Returns:
an array containing stacked slices from ``choices`` at the indices
specified by ``a``. The shape of the result is
``broadcast_shapes(a.shape, *(c.shape for c in choices))``.
See also:
- :func:`jax.lax.switch`: choose between N functions based on an index.
Examples:
Here is the simplest case of a 1D index array with a 2D choice array,
in which case this chooses the indexed value from each column:
>>> choices = jnp.array([[ 1, 2, 3, 4],
... [ 5, 6, 7, 8],
... [ 9, 10, 11, 12]])
>>> a = jnp.array([2, 0, 1, 0])
>>> jnp.choose(a, choices)
Array([9, 2, 7, 4], dtype=int32)
The ``mode`` argument specifies what to do with out-of-bound indices;
options are to either ``wrap`` or ``clip``:
>>> a2 = jnp.array([2, 0, 1, 4]) # last index out-of-bound
>>> jnp.choose(a2, choices, mode='clip')
Array([ 9, 2, 7, 12], dtype=int32)
>>> jnp.choose(a2, choices, mode='wrap')
Array([9, 2, 7, 8], dtype=int32)
In the more general case, ``choices`` may be a sequence of array-like
objects with any broadcast-compatible shapes.
>>> choice_1 = jnp.array([1, 2, 3, 4])
>>> choice_2 = 99
>>> choice_3 = jnp.array([[10],
... [20],
... [30]])
>>> a = jnp.array([[0, 1, 2, 0],
... [1, 2, 0, 1],
... [2, 0, 1, 2]])
>>> jnp.choose(a, [choice_1, choice_2, choice_3], mode='wrap')
Array([[ 1, 99, 10, 4],
[99, 20, 3, 99],
[30, 2, 99, 30]], dtype=int32)
"""
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
Expand Down
2 changes: 1 addition & 1 deletion jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def cbrt(x: ArrayLike, /) -> Array: ...
cdouble: Any
def ceil(x: ArrayLike, /) -> Array: ...
character = _np.character
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
def choose(a: ArrayLike, choices: Array | np.ndarray | Sequence[ArrayLike],
out: None = ..., mode: str = ...) -> Array: ...
def clip(
x: ArrayLike | None = ...,
Expand Down

0 comments on commit 36f74bb

Please sign in to comment.