From 63004b9be96a5c6dfc669660b8782bc45243dbe7 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Wed, 28 Aug 2024 23:46:06 +0530 Subject: [PATCH] Better docs for jnp.floor and jnp.ceil --- jax/_src/numpy/ufuncs.py | 64 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 62 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index c4f9009eb877..36ce9f5135a3 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -126,17 +126,77 @@ def positive(x: ArrayLike, /) -> Array: def sign(x: ArrayLike, /) -> Array: return lax.sign(*promote_args('sign', x)) -@implements(np.floor, module='numpy') + @partial(jit, inline=True) def floor(x: ArrayLike, /) -> Array: + """Round input to the nearest integer downwards. + + JAX implementation of :func:`numpy.floor`. + + Args: + x: input array or scalar. Must not have complex dtype. + + Returns: + An array with same shape and dtype as ``x`` containing the values rounded to + the nearest integer that is less than or equal to the value itself. + + See also: + - :func:`jax.numpy.fix`: Rounds the input to the nearest interger towards zero. + - :func:`jax.numpy.trunc`: Rounds the input to the nearest interger towards + zero. + - :func:`jax.numpy.ceil`: Rounds the input up to the nearest integer. + + Examples: + >>> key = jax.random.key(42) + >>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(x) + [[ 1.44 -1.77 -3.07] + [ 3.86 2.25 -3.08] + [-1.55 -2.48 1.32]] + >>> jnp.floor(x) + Array([[ 1., -2., -4.], + [ 3., 2., -4.], + [-2., -3., 1.]], dtype=float32) + """ check_arraylike('floor', x) if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): return lax.asarray(x) return lax.floor(*promote_args_inexact('floor', x)) -@implements(np.ceil, module='numpy') + @partial(jit, inline=True) def ceil(x: ArrayLike, /) -> Array: + """Round input to the nearest integer upwards. + + JAX implementation of :func:`numpy.ceil`. + + Args: + x: input array or scalar. Must not have complex dtype. + + Returns: + An array with same shape and dtype as ``x`` containing the values rounded to + the nearest integer that is greater than or equal to the value itself. + + See also: + - :func:`jax.numpy.fix`: Rounds the input to the nearest interger towards zero. + - :func:`jax.numpy.trunc`: Rounds the input to the nearest interger towards + zero. + - :func:`jax.numpy.floor`: Rounds the input down to the nearest integer. + + Examples: + >>> key = jax.random.key(1) + >>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(x) + [[ 2.55 -1.87 -3.76] + [ 0.48 3.85 -1.94] + [ 3.2 4.56 -1.43]] + >>> jnp.ceil(x) + Array([[ 3., -1., -3.], + [ 1., 4., -1.], + [ 4., 5., -1.]], dtype=float32) + """ check_arraylike('ceil', x) if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): return lax.asarray(x)