Skip to content

Commit

Permalink
Merge pull request #23927 from jakevdp:pad-doc
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 679142773
  • Loading branch information
Google-ML-Automation committed Sep 26, 2024
2 parents 5788773 + ad6c3a7 commit c07652f
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 12 deletions.
120 changes: 115 additions & 5 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3750,13 +3750,123 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int], mode: str,
"not implemented modes")


@util.implements(np.pad, lax_description="""\
Unlike numpy, JAX "function" mode's argument (which is another function) should return
the modified array. This is because Jax arrays are immutable.
(In numpy, "function" mode's argument should modify a rank 1 array in-place.)
""")
def pad(array: ArrayLike, pad_width: PadValueLike[int | Array | np.ndarray],
mode: str | Callable[..., Any] = "constant", **kwargs) -> Array:
"""Add padding to an array.
JAX implementation of :func:`numpy.pad`.
Args:
array: array to pad.
pad_width: specify the pad width for each dimension of an array. Padding widths
may be separately specified for *before* and *after* the array. Options are:
- ``int`` or ``(int,)``: pad each array dimension with the same number of values
both before and after.
- ``(before, after)``: pad each array with ``before`` elements before, and ``after``
elements after
- ``((before_1, after_1), (before_2, after_2), ... (before_N, after_N))``: specify
distinct ``before`` and ``after`` values for each array dimension.
mode: a string or callable. Supported pad modes are:
- ``'constant'`` (default): pad with a constant value, which defaults to zero.
- ``'empty'``: pad with empty values (i.e. zero)
- ``'edge'``: pad with the edge values of the array.
- ``'wrap'``: pad by wrapping the array.
- ``'linear_ramp'``: pad with a linear ramp to specified ``end_values``.
- ``'maximum'``: pad with the maximum value.
- ``'mean'``: pad with the mean value.
- ``'median'``: pad with the median value.
- ``'minimum'``: pad with the minimum value.
- ``'reflect'``: pad by reflection.
- ``'symmetric'``: pad by symmetric reflection.
- ``<callable>``: a callable function. See Notes below.
constant_values: referenced for ``mode = 'constant'``. Specify the constant value
to pad with.
stat_length: referenced for ``mode in ['maximum', 'mean', 'median', 'minimum']``.
An integer or tuple specifying the number of edge values to use when calculating
the statistic.
end_values: referenced for ``mode = 'linear_ramp'``. Specify the end values to
ramp the padding values to.
reflect_type: referenced for ``mode in ['reflect', 'symmetric']``. Specify whether
to use even or odd reflection.
Returns:
A padded copy of ``array``.
Notes:
When ``mode`` is callable, it should have the following signature::
def pad_func(row: Array, pad_width: tuple[int, int],
iaxis: int, kwargs: dict) -> Array:
...
Here ``row`` is a 1D slice of the padded array along axis ``iaxis``, with the pad
values filled with zeros. ``pad_width`` is a tuple specifying the ``(before, after)``
padding sizes, and ``kwargs`` are any additional keyword arguments passed to the
:func:`jax.numpy.pad` function.
Note that while in NumPy, the function should modify ``row`` in-place, in JAX the
function should return the modified ``row``. In JAX, the custom padding function
will be mapped across the padded axis using the :func:`jax.vmap` transformation.
See also:
- :func:`jax.numpy.resize`: resize an array
- :func:`jax.numpy.tile`: create a larger array by tiling a smaller array.
- :func:`jax.numpy.repeat`: create a larger array by repeating values of a smaller array.
Examples:
Pad a 1-dimensional array with zeros:
>>> x = jnp.array([10, 20, 30, 40])
>>> jnp.pad(x, 2)
Array([ 0, 0, 10, 20, 30, 40, 0, 0], dtype=int32)
>>> jnp.pad(x, (2, 4))
Array([ 0, 0, 10, 20, 30, 40, 0, 0, 0, 0], dtype=int32)
Pad a 1-dimensional array with specified values:
>>> jnp.pad(x, 2, constant_values=99)
Array([99, 99, 10, 20, 30, 40, 99, 99], dtype=int32)
Pad a 1-dimensional array with the mean array value:
>>> jnp.pad(x, 2, mode='mean')
Array([25, 25, 10, 20, 30, 40, 25, 25], dtype=int32)
Pad a 1-dimensional array with reflected values:
>>> jnp.pad(x, 2, mode='reflect')
Array([30, 20, 10, 20, 30, 40, 30, 20], dtype=int32)
Pad a 2-dimensional array with different paddings in each dimension:
>>> x = jnp.array([[1, 2, 3],
... [4, 5, 6]])
>>> jnp.pad(x, ((1, 2), (3, 0)))
Array([[0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 2, 3],
[0, 0, 0, 4, 5, 6],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]], dtype=int32)
Pad a 1-dimensional array with a custom padding function:
>>> def custom_pad(row, pad_width, iaxis, kwargs):
... # row represents a 1D slice of the zero-padded array.
... before, after = pad_width
... before_value = kwargs.get('before_value', 0)
... after_value = kwargs.get('after_value', 0)
... row = row.at[:before].set(before_value)
... return row.at[len(row) - after:].set(after_value)
>>> x = jnp.array([2, 3, 4])
>>> jnp.pad(x, 2, custom_pad, before_value=-10, after_value=10)
Array([-10, -10, 2, 3, 4, 10, 10], dtype=int32)
"""

util.check_arraylike("pad", array)
pad_width = _broadcast_to_pairs(pad_width, ndim(array), "pad_width")
if pad_width and not all(core.is_dim(p[0]) and core.is_dim(p[1])
Expand Down
7 changes: 0 additions & 7 deletions jax/_src/numpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ def _parse_parameters(body: str) -> dict[str, str]:
def implements(
original_fun: Callable[..., Any] | None,
update_doc: bool = True,
lax_description: str = "",
sections: Sequence[str] = ('Parameters', 'Returns', 'References'),
skip_params: Sequence[str] = (),
module: str | None = None,
Expand All @@ -132,8 +131,6 @@ def implements(
update_doc: whether to transform the numpy docstring to remove references of
parameters that are supported by the numpy version but not the JAX version.
If False, include the numpy docstring verbatim.
lax_description: a string description that will be added to the beginning of
the docstring.
sections: a list of sections to include in the docstring. The default is
["Parameters", "Returns", "References"]
skip_params: a list of strings containing names of parameters accepted by the
Expand All @@ -146,8 +143,6 @@ def decorator(wrapped_fun):
wrapped_fun.__np_wrapped__ = original_fun
# Allows this pattern: @implements(getattr(np, 'new_function', None))
if original_fun is None:
if lax_description:
wrapped_fun.__doc__ = lax_description
return wrapped_fun
docstr = getattr(original_fun, "__doc__", None)
name = getattr(original_fun, "__name__", getattr(wrapped_fun, "__name__", str(wrapped_fun)))
Expand Down Expand Up @@ -181,8 +176,6 @@ def decorator(wrapped_fun):

docstr = parsed.summary.strip() + "\n" if parsed.summary else ""
docstr += f"\nLAX-backend implementation of :func:`{name}`.\n"
if lax_description:
docstr += "\n" + lax_description.strip() + "\n"
docstr += "\n*Original docstring below.*\n"

# We remove signatures from the docstrings, because they redundant at best and
Expand Down

0 comments on commit c07652f

Please sign in to comment.