Skip to content

Commit

Permalink
Merge pull request #23291 from rajasekharporeddy:testbranch2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 668564782
  • Loading branch information
jax authors committed Aug 28, 2024
2 parents 3e63a65 + 63004b9 commit f0a7266
Showing 1 changed file with 62 additions and 2 deletions.
64 changes: 62 additions & 2 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,17 +126,77 @@ def positive(x: ArrayLike, /) -> Array:
def sign(x: ArrayLike, /) -> Array:
return lax.sign(*promote_args('sign', x))

@implements(np.floor, module='numpy')

@partial(jit, inline=True)
def floor(x: ArrayLike, /) -> Array:
"""Round input to the nearest integer downwards.
JAX implementation of :func:`numpy.floor`.
Args:
x: input array or scalar. Must not have complex dtype.
Returns:
An array with same shape and dtype as ``x`` containing the values rounded to
the nearest integer that is less than or equal to the value itself.
See also:
- :func:`jax.numpy.fix`: Rounds the input to the nearest interger towards zero.
- :func:`jax.numpy.trunc`: Rounds the input to the nearest interger towards
zero.
- :func:`jax.numpy.ceil`: Rounds the input up to the nearest integer.
Examples:
>>> key = jax.random.key(42)
>>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5)
>>> with jnp.printoptions(precision=2, suppress=True):
... print(x)
[[ 1.44 -1.77 -3.07]
[ 3.86 2.25 -3.08]
[-1.55 -2.48 1.32]]
>>> jnp.floor(x)
Array([[ 1., -2., -4.],
[ 3., 2., -4.],
[-2., -3., 1.]], dtype=float32)
"""
check_arraylike('floor', x)
if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')):
return lax.asarray(x)
return lax.floor(*promote_args_inexact('floor', x))

@implements(np.ceil, module='numpy')

@partial(jit, inline=True)
def ceil(x: ArrayLike, /) -> Array:
"""Round input to the nearest integer upwards.
JAX implementation of :func:`numpy.ceil`.
Args:
x: input array or scalar. Must not have complex dtype.
Returns:
An array with same shape and dtype as ``x`` containing the values rounded to
the nearest integer that is greater than or equal to the value itself.
See also:
- :func:`jax.numpy.fix`: Rounds the input to the nearest interger towards zero.
- :func:`jax.numpy.trunc`: Rounds the input to the nearest interger towards
zero.
- :func:`jax.numpy.floor`: Rounds the input down to the nearest integer.
Examples:
>>> key = jax.random.key(1)
>>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5)
>>> with jnp.printoptions(precision=2, suppress=True):
... print(x)
[[ 2.55 -1.87 -3.76]
[ 0.48 3.85 -1.94]
[ 3.2 4.56 -1.43]]
>>> jnp.ceil(x)
Array([[ 3., -1., -3.],
[ 1., 4., -1.],
[ 4., 5., -1.]], dtype=float32)
"""
check_arraylike('ceil', x)
if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')):
return lax.asarray(x)
Expand Down

0 comments on commit f0a7266

Please sign in to comment.