Skip to content

Commit

Permalink
docstring_sort_complex_added
Browse files Browse the repository at this point in the history
  • Loading branch information
selamw1 committed Sep 16, 2024
1 parent 7dde9b2 commit 0a14edc
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8561,9 +8561,29 @@ def sort(
return lax.rev(result, dimensions=[dimension]) if descending else result


@util.implements(np.sort_complex)
@jit
def sort_complex(a: ArrayLike) -> Array:
"""Return a sorted copy of complex array.
JAX implementation of :func:`numpy.sort_complex`.
Complex numbers are sorted lexicographically, meaning by their real part
first, and then by their imaginary part if real parts are equal.
Args:
a: input complex array.
Returns:
A sorted array of the same shape and complex dtype as the input.
See also:
- :func:`jax.numpy.sort`: Return a sorted copy of an array.
Examples:
>>> x = jnp.array([1+2j, 2+4j, 3-1j, 2+3j])
>>> jnp.sort_complex(x)
Array([1.+2.j, 2.+3.j, 2.+4.j, 3.-1.j], dtype=complex64)
"""
util.check_arraylike("sort_complex", a)
a = lax.sort(asarray(a), dimension=0)
return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype))
Expand Down

0 comments on commit 0a14edc

Please sign in to comment.