diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index c1b74205e140..47a36a2b83ff 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -10808,11 +10808,46 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', }[method] return impl(asarray(a), asarray(v), side, dtype) # type: ignore -@util.implements(np.digitize, lax_description=_dedent(""" - Optionally, the ``method`` argument can be used to configure the - underlying :func:`jax.numpy.searchsorted` algorithm.""")) + @partial(jit, static_argnames=('right', 'method')) -def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, *, method: str = 'scan') -> Array: +def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, + *, method: str | None = None) -> Array: + """Convert an array to bin indices. + + JAX implementation of :func:`numpy.digitize`. + + Args: + x: array of values to digitize. + bins: 1D array of bin edges. Must be monotonically increasing or decreasing. + right: if true, the intervals include the right bin edges. If false (default) + the intervals include the left bin edges. + method: optional method argument to be passed to :func:`~jax.numpy.searchsorted`. + See that function for available options. + + Returns: + An integer array of the same shape as ``x`` indicating the bin number that + the values are in. + + See also: + - :func:`jax.numpy.searchsorted`: find insertion indices for values in a + sorted array. + - :func:`jax.numpy.histogram`: compute frequency of array values within + specified bins. + + Examples: + >>> x = jnp.array([1.0, 2.0, 2.5, 1.5, 3.0, 3.5]) + >>> bins = jnp.array([1, 2, 3]) + >>> jnp.digitize(x, bins) + Array([1, 2, 2, 1, 3, 3], dtype=int32) + >>> jnp.digitize(x, bins, right=True) + Array([0, 1, 2, 1, 2, 3], dtype=int32) + + ``digitize`` supports reverse-ordered bins as well: + + >>> bins = jnp.array([3, 2, 1]) + >>> jnp.digitize(x, bins) + Array([2, 1, 1, 2, 0, 0], dtype=int32) + """ util.check_arraylike("digitize", x, bins) right = core.concrete_or_error(bool, right, "right argument of jnp.digitize()") bins_arr = asarray(bins) @@ -10821,10 +10856,11 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, *, method: str if bins_arr.shape[0] == 0: return zeros_like(x, dtype=int32) side = 'right' if not right else 'left' + kwds: dict[str, str] = {} if method is None else {'method': method} return where( bins_arr[-1] >= bins_arr[0], - searchsorted(bins_arr, x, side=side, method=method), - bins_arr.shape[0] - searchsorted(bins_arr[::-1], x, side=side, method=method) + searchsorted(bins_arr, x, side=side, **kwds), + bins_arr.shape[0] - searchsorted(bins_arr[::-1], x, side=side, **kwds) ) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index d5b66c1b3b32..c23f659bd3f9 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -300,7 +300,8 @@ def diagonal( def diff(a: ArrayLike, n: int = ..., axis: int = ..., prepend: ArrayLike | None = ..., append: ArrayLike | None = ...) -> Array: ... -def digitize(x: ArrayLike, bins: ArrayLike, right: builtins.bool = ...) -> Array: ... +def digitize(x: ArrayLike, bins: ArrayLike, right: builtins.bool = ..., *, + method: str | None = ...) -> Array: ... divide = true_divide def divmod(x: ArrayLike, y: ArrayLike, /) -> tuple[Array, Array]: ... def dot(