From ced012f5eda2da851ed1c905496b884ccca0c84e Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Wed, 28 Aug 2024 20:16:09 +0530 Subject: [PATCH] Update jnp.fabs to emulate the behavior of np.fabs for complex inputs --- CHANGELOG.md | 2 ++ jax/_src/numpy/ufuncs.py | 41 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d09d610bdb5..9d0b55b36476 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. `jax.config.update('jax_cpu_enable_async_dispatch', False)`. * Added new {func}`jax.process_indices` function to replace the `jax.host_ids()` function that was deprecated in JAX v0.2.13. + * To align with the behavior of `numpy.fabs`, `jax.numpy.fabs` has been + modified to no longer support `complex dtypes`. * Breaking changes * The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 2893b14f7059..c4f9009eb877 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -52,9 +52,48 @@ def _replace_inf(x: ArrayLike) -> Array: def _to_bool(x: Array) -> Array: return x if x.dtype == bool else lax.ne(x, _lax_const(x, 0)) -@implements(np.fabs, module='numpy') + @partial(jit, inline=True) def fabs(x: ArrayLike, /) -> Array: + """Compute the element-wise absolute values of the real-valued input. + + JAX implementation of :func:`numpy.fabs`. + + Args: + x: input array or scalar. Must not have a complex dtype. + + Returns: + An array with same shape as ``x`` and dtype float, containing the element-wise + absolute values. + + See also: + - :func:`jax.numpy.absolute`: Computes the absolute values of the input including + complex dtypes. + - :func:`jax.numpy.abs`: Computes the absolute values of the input including + complex dtypes. + + Examples: + For integer inputs: + + >>> x = jnp.array([-5, -9, 1, 10, 15]) + >>> jnp.fabs(x) + Array([ 5., 9., 1., 10., 15.], dtype=float32) + + For float type inputs: + + >>> x1 = jnp.array([-1.342, 5.649, 3.927]) + >>> jnp.fabs(x1) + Array([1.342, 5.649, 3.927], dtype=float32) + + For boolean inputs: + + >>> x2 = jnp.array([True, False]) + >>> jnp.fabs(x2) + Array([1., 0.], dtype=float32) + """ + check_arraylike('fabs', x) + if dtypes.issubdtype(dtypes.dtype(x), np.complexfloating): + raise TypeError("ufunc 'fabs' does not support complex dtypes") return lax.abs(*promote_args_inexact('fabs', x)) @implements(getattr(np, 'bitwise_invert', np.invert), module='numpy')