From 9d757cdb85d5f4b854abd23fb9178280c9c1f481 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Fri, 30 Aug 2024 21:35:44 +0530 Subject: [PATCH] Use :obj: instead of :func: in ufuncs.py --- jax/_src/numpy/ufuncs.py | 48 +++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index dfeff38df0fe..1f66ee22265b 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -58,7 +58,7 @@ def _to_bool(x: Array) -> Array: def fabs(x: ArrayLike, /) -> Array: """Compute the element-wise absolute values of the real-valued input. - JAX implementation of :func:`numpy.fabs`. + JAX implementation of :obj:`numpy.fabs`. Args: x: input array or scalar. Must not have a complex dtype. @@ -132,7 +132,7 @@ def sign(x: ArrayLike, /) -> Array: def floor(x: ArrayLike, /) -> Array: """Round input to the nearest integer downwards. - JAX implementation of :func:`numpy.floor`. + JAX implementation of :obj:`numpy.floor`. Args: x: input array or scalar. Must not have complex dtype. @@ -170,7 +170,7 @@ def floor(x: ArrayLike, /) -> Array: def ceil(x: ArrayLike, /) -> Array: """Round input to the nearest integer upwards. - JAX implementation of :func:`numpy.ceil`. + JAX implementation of :obj:`numpy.ceil`. Args: x: input array or scalar. Must not have complex dtype. @@ -466,7 +466,7 @@ def bitwise_count(x: ArrayLike, /) -> Array: r"""Counts the number of 1 bits in the binary representation of the absolute value of each element of ``x``. - LAX-backend implementation of :func:`numpy.bitwise_count`. + JAX implementation of :obj:`numpy.bitwise_count`. Args: x: Input array, only accepts integer subtypes @@ -500,7 +500,7 @@ def bitwise_count(x: ArrayLike, /) -> Array: def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: r"""Right shift the bits of ``x1`` to the amount specified in ``x2``. - LAX-backend implementation of :func:`numpy.right_shift`. + JAX implementation of :obj:`numpy.right_shift`. Args: x1: Input array, only accepts unsigned integer subtypes @@ -559,7 +559,7 @@ def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: def absolute(x: ArrayLike, /) -> Array: r"""Calculate the absolute value element-wise. - LAX-backend implementation of :func:`numpy.absolute`. + JAX implementation of :obj:`numpy.absolute`. This is the same function as :func:`jax.numpy.abs`. @@ -600,7 +600,7 @@ def abs(x: ArrayLike, /) -> Array: def rint(x: ArrayLike, /) -> Array: """Rounds the elements of x to the nearest integer - LAX-backend implementation of :func:`numpy.rint`. + JAX implementation of :obj:`numpy.rint`. Args: x: Input array @@ -639,7 +639,7 @@ def rint(x: ArrayLike, /) -> Array: def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Copies the sign of each element in ``x2`` to the corresponding element in ``x1``. - LAX-backend implementation of :func:`numpy.copysign`. + JAX implementation of :obj:`numpy.copysign`. Args: x1: Input array @@ -687,7 +687,7 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Calculates the floor division of x1 by x2 element-wise - LAX-backend implementation of :func:`numpy.floor_divide`. + JAX implementation of :obj:`numpy.floor_divide`. Args: x1: Input array, the dividend @@ -698,6 +698,14 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: to the nearest integer towards negative infinity. This is equivalent to ``x1 // x2`` in Python. + Note: + ``x1 // x2`` is equivalent to ``jnp.floor_divide(x1, x2)`` for arrays ``x1`` + and ``x2`` + + See Also: + :func:`jax.numpy.divide` and :func:`jax.numpy.true_divide` for floating point + division. + Examples: >>> x1 = jnp.array([10, 20, 30]) >>> x2 = jnp.array([3, 4, 7]) @@ -713,12 +721,6 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: >>> x2 = jnp.array([2.0, 2.5, 3.0], dtype=jnp.float32) >>> jnp.floor_divide(x1, x2) Array([3., 2., 2.], dtype=float32) - - Note: - ``x1 // x2`` is equivalent to ``jnp.floor_divide(x1, x2)`` for arrays ``x1`` and ``x2`` - - See Also: - :func:`jnp.divide` and :func:`jnp.true_divide` for floating point division """ x1, x2 = promote_args_numeric("floor_divide", x1, x2) dtype = dtypes.dtype(x1) @@ -739,7 +741,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]: """Calculates the integer quotient and remainder of x1 by x2 element-wise - LAX-backend implementation of :func:`numpy.divmod`. + JAX implementation of :obj:`numpy.divmod`. Args: x1: Input array, the dividend @@ -748,6 +750,10 @@ def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]: Returns: A tuple of arrays ``(x1 // x2, x1 % x2)``. + See Also: + - :func:`jax.numpy.floor_divide`: floor division function + - :func:`jax.numpy.remainder`: remainder function + Examples: >>> x1 = jnp.array([10, 20, 30]) >>> x2 = jnp.array([3, 4, 7]) @@ -765,10 +771,6 @@ def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]: >>> jnp.divmod(x1, x2) (Array([3., 2., 1.], dtype=float32), Array([0.30000007, 1. , 2.9 ], dtype=float32)) - - See Also: - - :func:`jax.numpy.floor_divide`: floor division function - - :func:`jax.numpy.remainder`: remainder function """ x1, x2 = promote_args_numeric("divmod", x1, x2) if dtypes.issubdtype(dtypes.dtype(x1), np.integer): @@ -862,7 +864,7 @@ def _pow_int_int(x1, x2): def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Compute ``log(exp(x1) + exp(x2))`` avoiding overflow. - JAX implementation of :func:`numpy.logaddexp` + JAX implementation of :obj:`numpy.logaddexp` Args: x1: input array @@ -927,7 +929,7 @@ def _logaddexp2_jvp(primals, tangents): def log2(x: ArrayLike, /) -> Array: """Calculates the base-2 logarithm of x element-wise - LAX-backend implementation of :func:`numpy.log2`. + JAX implementation of :obj:`numpy.log2`. Args: x: Input array @@ -949,7 +951,7 @@ def log2(x: ArrayLike, /) -> Array: def log10(x: ArrayLike, /) -> Array: """Calculates the base-10 logarithm of x element-wise - LAX-backend implementation of :func:`numpy.log10`. + JAX implementation of :obj:`numpy.log10`. Args: x: Input array