From ad6c3a7f64c90472af50d59ea294d14a2aab5575 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 25 Sep 2024 14:41:13 -0700 Subject: [PATCH] Improve docs for jnp.pad --- jax/_src/numpy/lax_numpy.py | 120 ++++++++++++++++++++++++++++++++++-- jax/_src/numpy/util.py | 7 --- 2 files changed, 115 insertions(+), 12 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 73e27245cfa9..75bcbede6650 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3750,13 +3750,123 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int], mode: str, "not implemented modes") -@util.implements(np.pad, lax_description="""\ -Unlike numpy, JAX "function" mode's argument (which is another function) should return -the modified array. This is because Jax arrays are immutable. -(In numpy, "function" mode's argument should modify a rank 1 array in-place.) -""") def pad(array: ArrayLike, pad_width: PadValueLike[int | Array | np.ndarray], mode: str | Callable[..., Any] = "constant", **kwargs) -> Array: + """Add padding to an array. + + JAX implementation of :func:`numpy.pad`. + + Args: + array: array to pad. + pad_width: specify the pad width for each dimension of an array. Padding widths + may be separately specified for *before* and *after* the array. Options are: + + - ``int`` or ``(int,)``: pad each array dimension with the same number of values + both before and after. + - ``(before, after)``: pad each array with ``before`` elements before, and ``after`` + elements after + - ``((before_1, after_1), (before_2, after_2), ... (before_N, after_N))``: specify + distinct ``before`` and ``after`` values for each array dimension. + + mode: a string or callable. Supported pad modes are: + + - ``'constant'`` (default): pad with a constant value, which defaults to zero. + - ``'empty'``: pad with empty values (i.e. zero) + - ``'edge'``: pad with the edge values of the array. + - ``'wrap'``: pad by wrapping the array. + - ``'linear_ramp'``: pad with a linear ramp to specified ``end_values``. + - ``'maximum'``: pad with the maximum value. + - ``'mean'``: pad with the mean value. + - ``'median'``: pad with the median value. + - ``'minimum'``: pad with the minimum value. + - ``'reflect'``: pad by reflection. + - ``'symmetric'``: pad by symmetric reflection. + - ````: a callable function. See Notes below. + + constant_values: referenced for ``mode = 'constant'``. Specify the constant value + to pad with. + stat_length: referenced for ``mode in ['maximum', 'mean', 'median', 'minimum']``. + An integer or tuple specifying the number of edge values to use when calculating + the statistic. + end_values: referenced for ``mode = 'linear_ramp'``. Specify the end values to + ramp the padding values to. + reflect_type: referenced for ``mode in ['reflect', 'symmetric']``. Specify whether + to use even or odd reflection. + + Returns: + A padded copy of ``array``. + + Notes: + When ``mode`` is callable, it should have the following signature:: + + def pad_func(row: Array, pad_width: tuple[int, int], + iaxis: int, kwargs: dict) -> Array: + ... + + Here ``row`` is a 1D slice of the padded array along axis ``iaxis``, with the pad + values filled with zeros. ``pad_width`` is a tuple specifying the ``(before, after)`` + padding sizes, and ``kwargs`` are any additional keyword arguments passed to the + :func:`jax.numpy.pad` function. + + Note that while in NumPy, the function should modify ``row`` in-place, in JAX the + function should return the modified ``row``. In JAX, the custom padding function + will be mapped across the padded axis using the :func:`jax.vmap` transformation. + + See also: + - :func:`jax.numpy.resize`: resize an array + - :func:`jax.numpy.tile`: create a larger array by tiling a smaller array. + - :func:`jax.numpy.repeat`: create a larger array by repeating values of a smaller array. + + Examples: + + Pad a 1-dimensional array with zeros: + + >>> x = jnp.array([10, 20, 30, 40]) + >>> jnp.pad(x, 2) + Array([ 0, 0, 10, 20, 30, 40, 0, 0], dtype=int32) + >>> jnp.pad(x, (2, 4)) + Array([ 0, 0, 10, 20, 30, 40, 0, 0, 0, 0], dtype=int32) + + Pad a 1-dimensional array with specified values: + + >>> jnp.pad(x, 2, constant_values=99) + Array([99, 99, 10, 20, 30, 40, 99, 99], dtype=int32) + + Pad a 1-dimensional array with the mean array value: + + >>> jnp.pad(x, 2, mode='mean') + Array([25, 25, 10, 20, 30, 40, 25, 25], dtype=int32) + + Pad a 1-dimensional array with reflected values: + + >>> jnp.pad(x, 2, mode='reflect') + Array([30, 20, 10, 20, 30, 40, 30, 20], dtype=int32) + + Pad a 2-dimensional array with different paddings in each dimension: + + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> jnp.pad(x, ((1, 2), (3, 0))) + Array([[0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 2, 3], + [0, 0, 0, 4, 5, 6], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]], dtype=int32) + + Pad a 1-dimensional array with a custom padding function: + + >>> def custom_pad(row, pad_width, iaxis, kwargs): + ... # row represents a 1D slice of the zero-padded array. + ... before, after = pad_width + ... before_value = kwargs.get('before_value', 0) + ... after_value = kwargs.get('after_value', 0) + ... row = row.at[:before].set(before_value) + ... return row.at[len(row) - after:].set(after_value) + >>> x = jnp.array([2, 3, 4]) + >>> jnp.pad(x, 2, custom_pad, before_value=-10, after_value=10) + Array([-10, -10, 2, 3, 4, 10, 10], dtype=int32) + """ + util.check_arraylike("pad", array) pad_width = _broadcast_to_pairs(pad_width, ndim(array), "pad_width") if pad_width and not all(core.is_dim(p[0]) and core.is_dim(p[1]) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index e9d1db26731c..9c9bc5d389e1 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -114,7 +114,6 @@ def _parse_parameters(body: str) -> dict[str, str]: def implements( original_fun: Callable[..., Any] | None, update_doc: bool = True, - lax_description: str = "", sections: Sequence[str] = ('Parameters', 'Returns', 'References'), skip_params: Sequence[str] = (), module: str | None = None, @@ -132,8 +131,6 @@ def implements( update_doc: whether to transform the numpy docstring to remove references of parameters that are supported by the numpy version but not the JAX version. If False, include the numpy docstring verbatim. - lax_description: a string description that will be added to the beginning of - the docstring. sections: a list of sections to include in the docstring. The default is ["Parameters", "Returns", "References"] skip_params: a list of strings containing names of parameters accepted by the @@ -146,8 +143,6 @@ def decorator(wrapped_fun): wrapped_fun.__np_wrapped__ = original_fun # Allows this pattern: @implements(getattr(np, 'new_function', None)) if original_fun is None: - if lax_description: - wrapped_fun.__doc__ = lax_description return wrapped_fun docstr = getattr(original_fun, "__doc__", None) name = getattr(original_fun, "__name__", getattr(wrapped_fun, "__name__", str(wrapped_fun))) @@ -181,8 +176,6 @@ def decorator(wrapped_fun): docstr = parsed.summary.strip() + "\n" if parsed.summary else "" docstr += f"\nLAX-backend implementation of :func:`{name}`.\n" - if lax_description: - docstr += "\n" + lax_description.strip() + "\n" docstr += "\n*Original docstring below.*\n" # We remove signatures from the docstrings, because they redundant at best and