Skip to content

Commit

Permalink
Merge pull request #23233 from rajasekharporeddy:testbranch1
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 668061985
  • Loading branch information
jax authors committed Aug 27, 2024
2 parents 140955d + 2f3d428 commit 6ee4136
Showing 1 changed file with 57 additions and 2 deletions.
59 changes: 57 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,9 +465,36 @@ def result_type(*args: Any) -> DType:
return dtypes.result_type(*args)


@util.implements(np.trunc, module='numpy')
@jit
def trunc(x: ArrayLike) -> Array:
"""Round input to the nearest integer towards zero.
JAX implementation of :func:`numpy.trunc`.
Args:
x: input array or scalar.
Returns:
An array with same shape and dtype as ``x`` containing the rounded values.
See also:
- :func:`jax.numpy.fix`: Rounds the input to the nearest integer towards zero.
- :func:`jax.numpy.ceil`: Rounds the input up to the nearest integer.
- :func:`jax.numpy.floor`: Rounds the input down to the nearest integer.
Examples:
>>> key = jax.random.key(42)
>>> x = jax.random.uniform(key, (3, 3), minval=-10, maxval=10)
>>> with jnp.printoptions(precision=2, suppress=True):
... print(x)
[[ 2.88 -3.55 -6.13]
[ 7.73 4.49 -6.16]
[-3.1 -4.95 2.64]]
>>> jnp.trunc(x)
Array([[ 2., -3., -6.],
[ 7., 4., -6.],
[-3., -4., 2.]], dtype=float32)
"""
util.check_arraylike('trunc', x)
if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')):
return lax_internal.asarray(x)
Expand Down Expand Up @@ -2653,9 +2680,37 @@ def _round_float(x: ArrayLike) -> Array:
round_ = round


@util.implements(np.fix, skip_params=['out'])
@jit
def fix(x: ArrayLike, out: None = None) -> Array:
"""Round input to the nearest integer towards zero.
JAX implementation of :func:`numpy.fix`.
Args:
x: input array.
out: unused by JAX.
Returns:
An array with same shape and dtype as ``x`` containing the rounded values.
See also:
- :func:`jax.numpy.trunc`: Rounds the input to nearest integer towards zero.
- :func:`jax.numpy.ceil`: Rounds the input up to the nearest integer.
- :func:`jax.numpy.floor`: Rounds the input down to the nearest integer.
Examples:
>>> key = jax.random.key(0)
>>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5)
>>> with jnp.printoptions(precision=2, suppress=True):
... print(x)
[[-1.45 1.04 -0.72]
[-2.69 1.74 -0.6 ]
[-2.49 -2.23 2.68]]
>>> jnp.fix(x)
Array([[-1., 1., -0.],
[-2., 1., -0.],
[-2., -2., 2.]], dtype=float32)
"""
util.check_arraylike("fix", x)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.fix is not supported.")
Expand Down

0 comments on commit 6ee4136

Please sign in to comment.