Skip to content

Commit

Permalink
Merge pull request #23679 from selamw1:docstring_sort_complex
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675726527
  • Loading branch information
Google-ML-Automation committed Sep 17, 2024
2 parents 3f2c58b + 83a7555 commit 7864648
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8819,9 +8819,29 @@ def sort(
return lax.rev(result, dimensions=[dimension]) if descending else result


@util.implements(np.sort_complex)
@jit
def sort_complex(a: ArrayLike) -> Array:
"""Return a sorted copy of complex array.
JAX implementation of :func:`numpy.sort_complex`.
Complex numbers are sorted lexicographically, meaning by their real part
first, and then by their imaginary part if real parts are equal.
Args:
a: input array. If dtype is not complex, the array will be upcast to complex.
Returns:
A sorted array of the same shape and complex dtype as the input.
See also:
- :func:`jax.numpy.sort`: Return a sorted copy of an array.
Examples:
>>> a = jnp.array([1+2j, 2+4j, 3-1j, 2+3j])
>>> jnp.sort_complex(a)
Array([1.+2.j, 2.+3.j, 2.+4.j, 3.-1.j], dtype=complex64)
"""
util.check_arraylike("sort_complex", a)
a = lax.sort(asarray(a), dimension=0)
return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype))
Expand Down

0 comments on commit 7864648

Please sign in to comment.