Skip to content

Commit

Permalink
Update jnp.fabs to emulate the behavior of np.fabs for complex inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
rajasekharporeddy committed Aug 28, 2024
1 parent fc1af8d commit ced012f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
`jax.config.update('jax_cpu_enable_async_dispatch', False)`.
* Added new {func}`jax.process_indices` function to replace the
`jax.host_ids()` function that was deprecated in JAX v0.2.13.
* To align with the behavior of `numpy.fabs`, `jax.numpy.fabs` has been
modified to no longer support `complex dtypes`.

* Breaking changes
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the
Expand Down
41 changes: 40 additions & 1 deletion jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,48 @@ def _replace_inf(x: ArrayLike) -> Array:
def _to_bool(x: Array) -> Array:
return x if x.dtype == bool else lax.ne(x, _lax_const(x, 0))

@implements(np.fabs, module='numpy')

@partial(jit, inline=True)
def fabs(x: ArrayLike, /) -> Array:
"""Compute the element-wise absolute values of the real-valued input.
JAX implementation of :func:`numpy.fabs`.
Args:
x: input array or scalar. Must not have a complex dtype.
Returns:
An array with same shape as ``x`` and dtype float, containing the element-wise
absolute values.
See also:
- :func:`jax.numpy.absolute`: Computes the absolute values of the input including
complex dtypes.
- :func:`jax.numpy.abs`: Computes the absolute values of the input including
complex dtypes.
Examples:
For integer inputs:
>>> x = jnp.array([-5, -9, 1, 10, 15])
>>> jnp.fabs(x)
Array([ 5., 9., 1., 10., 15.], dtype=float32)
For float type inputs:
>>> x1 = jnp.array([-1.342, 5.649, 3.927])
>>> jnp.fabs(x1)
Array([1.342, 5.649, 3.927], dtype=float32)
For boolean inputs:
>>> x2 = jnp.array([True, False])
>>> jnp.fabs(x2)
Array([1., 0.], dtype=float32)
"""
check_arraylike('fabs', x)
if dtypes.issubdtype(dtypes.dtype(x), np.complexfloating):
raise TypeError("ufunc 'fabs' does not support complex dtypes")
return lax.abs(*promote_args_inexact('fabs', x))

@implements(getattr(np, 'bitwise_invert', np.invert), module='numpy')
Expand Down

0 comments on commit ced012f

Please sign in to comment.