diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 0f27cb5ff409..1dfbc52422e7 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3736,8 +3736,50 @@ def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, return array(a, dtype=dtype, copy=bool(copy), order=order, device=device) -@util.implements(np.copy, lax_description=_ARRAY_DOC) def copy(a: ArrayLike, order: str | None = None) -> Array: + """Return a copy of the array. + + JAX implementation of :func:`numpy.copy`. + + Args: + a: arraylike object to copy + order: not implemented in JAX + + Returns: + a copy of the input array ``a``. + + See Also: + - :func:`jax.numpy.array`: create an array with or without a copy. + - :meth:`jax.Array.copy`: same function accessed as an array method. + + Examples: + Since JAX arrays are immutable, in most cases explicit array copies + are not necessary. One exception is when using a function with donated + arguments (see the ``donate_argnums`` argument to :func:`jax.jit`). + + >>> f = jax.jit(lambda x: 2 * x, donate_argnums=0) + >>> x = jnp.arange(4) + >>> y = f(x) + >>> print(y) + [0 2 4 6] + + Because we marked ``x`` as being donated, the original array is no longer + available: + + >>> print(x) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + RuntimeError: Array has been deleted with shape=int32[4]. + + In situations like this, an explicit copy will let you keep access to the + original buffer: + + >>> x = jnp.arange(4) + >>> y = f(x.copy()) + >>> print(y) + [0 2 4 6] + >>> print(x) + [0 1 2 3] + """ util.check_arraylike("copy", a) return array(a, copy=True, order=order)