Skip to content

Commit

Permalink
Merge pull request #23336 from rajasekharporeddy:testbranch3
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 669363056
  • Loading branch information
jax authors committed Aug 30, 2024
2 parents ca063c7 + 9d757cd commit e494de8
Showing 1 changed file with 25 additions and 23 deletions.
48 changes: 25 additions & 23 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _to_bool(x: Array) -> Array:
def fabs(x: ArrayLike, /) -> Array:
"""Compute the element-wise absolute values of the real-valued input.
JAX implementation of :func:`numpy.fabs`.
JAX implementation of :obj:`numpy.fabs`.
Args:
x: input array or scalar. Must not have a complex dtype.
Expand Down Expand Up @@ -132,7 +132,7 @@ def sign(x: ArrayLike, /) -> Array:
def floor(x: ArrayLike, /) -> Array:
"""Round input to the nearest integer downwards.
JAX implementation of :func:`numpy.floor`.
JAX implementation of :obj:`numpy.floor`.
Args:
x: input array or scalar. Must not have complex dtype.
Expand Down Expand Up @@ -170,7 +170,7 @@ def floor(x: ArrayLike, /) -> Array:
def ceil(x: ArrayLike, /) -> Array:
"""Round input to the nearest integer upwards.
JAX implementation of :func:`numpy.ceil`.
JAX implementation of :obj:`numpy.ceil`.
Args:
x: input array or scalar. Must not have complex dtype.
Expand Down Expand Up @@ -466,7 +466,7 @@ def bitwise_count(x: ArrayLike, /) -> Array:
r"""Counts the number of 1 bits in the binary representation of the absolute value
of each element of ``x``.
LAX-backend implementation of :func:`numpy.bitwise_count`.
JAX implementation of :obj:`numpy.bitwise_count`.
Args:
x: Input array, only accepts integer subtypes
Expand Down Expand Up @@ -500,7 +500,7 @@ def bitwise_count(x: ArrayLike, /) -> Array:
def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
r"""Right shift the bits of ``x1`` to the amount specified in ``x2``.
LAX-backend implementation of :func:`numpy.right_shift`.
JAX implementation of :obj:`numpy.right_shift`.
Args:
x1: Input array, only accepts unsigned integer subtypes
Expand Down Expand Up @@ -559,7 +559,7 @@ def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
def absolute(x: ArrayLike, /) -> Array:
r"""Calculate the absolute value element-wise.
LAX-backend implementation of :func:`numpy.absolute`.
JAX implementation of :obj:`numpy.absolute`.
This is the same function as :func:`jax.numpy.abs`.
Expand Down Expand Up @@ -600,7 +600,7 @@ def abs(x: ArrayLike, /) -> Array:
def rint(x: ArrayLike, /) -> Array:
"""Rounds the elements of x to the nearest integer
LAX-backend implementation of :func:`numpy.rint`.
JAX implementation of :obj:`numpy.rint`.
Args:
x: Input array
Expand Down Expand Up @@ -639,7 +639,7 @@ def rint(x: ArrayLike, /) -> Array:
def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Copies the sign of each element in ``x2`` to the corresponding element in ``x1``.
LAX-backend implementation of :func:`numpy.copysign`.
JAX implementation of :obj:`numpy.copysign`.
Args:
x1: Input array
Expand Down Expand Up @@ -687,7 +687,7 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Calculates the floor division of x1 by x2 element-wise
LAX-backend implementation of :func:`numpy.floor_divide`.
JAX implementation of :obj:`numpy.floor_divide`.
Args:
x1: Input array, the dividend
Expand All @@ -698,6 +698,14 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
to the nearest integer towards negative infinity. This is equivalent
to ``x1 // x2`` in Python.
Note:
``x1 // x2`` is equivalent to ``jnp.floor_divide(x1, x2)`` for arrays ``x1``
and ``x2``
See Also:
:func:`jax.numpy.divide` and :func:`jax.numpy.true_divide` for floating point
division.
Examples:
>>> x1 = jnp.array([10, 20, 30])
>>> x2 = jnp.array([3, 4, 7])
Expand All @@ -713,12 +721,6 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
>>> x2 = jnp.array([2.0, 2.5, 3.0], dtype=jnp.float32)
>>> jnp.floor_divide(x1, x2)
Array([3., 2., 2.], dtype=float32)
Note:
``x1 // x2`` is equivalent to ``jnp.floor_divide(x1, x2)`` for arrays ``x1`` and ``x2``
See Also:
:func:`jnp.divide` and :func:`jnp.true_divide` for floating point division
"""
x1, x2 = promote_args_numeric("floor_divide", x1, x2)
dtype = dtypes.dtype(x1)
Expand All @@ -739,7 +741,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]:
"""Calculates the integer quotient and remainder of x1 by x2 element-wise
LAX-backend implementation of :func:`numpy.divmod`.
JAX implementation of :obj:`numpy.divmod`.
Args:
x1: Input array, the dividend
Expand All @@ -748,6 +750,10 @@ def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]:
Returns:
A tuple of arrays ``(x1 // x2, x1 % x2)``.
See Also:
- :func:`jax.numpy.floor_divide`: floor division function
- :func:`jax.numpy.remainder`: remainder function
Examples:
>>> x1 = jnp.array([10, 20, 30])
>>> x2 = jnp.array([3, 4, 7])
Expand All @@ -765,10 +771,6 @@ def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]:
>>> jnp.divmod(x1, x2)
(Array([3., 2., 1.], dtype=float32),
Array([0.30000007, 1. , 2.9 ], dtype=float32))
See Also:
- :func:`jax.numpy.floor_divide`: floor division function
- :func:`jax.numpy.remainder`: remainder function
"""
x1, x2 = promote_args_numeric("divmod", x1, x2)
if dtypes.issubdtype(dtypes.dtype(x1), np.integer):
Expand Down Expand Up @@ -862,7 +864,7 @@ def _pow_int_int(x1, x2):
def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Compute ``log(exp(x1) + exp(x2))`` avoiding overflow.
JAX implementation of :func:`numpy.logaddexp`
JAX implementation of :obj:`numpy.logaddexp`
Args:
x1: input array
Expand Down Expand Up @@ -927,7 +929,7 @@ def _logaddexp2_jvp(primals, tangents):
def log2(x: ArrayLike, /) -> Array:
"""Calculates the base-2 logarithm of x element-wise
LAX-backend implementation of :func:`numpy.log2`.
JAX implementation of :obj:`numpy.log2`.
Args:
x: Input array
Expand All @@ -949,7 +951,7 @@ def log2(x: ArrayLike, /) -> Array:
def log10(x: ArrayLike, /) -> Array:
"""Calculates the base-10 logarithm of x element-wise
LAX-backend implementation of :func:`numpy.log10`.
JAX implementation of :obj:`numpy.log10`.
Args:
x: Input array
Expand Down

0 comments on commit e494de8

Please sign in to comment.