From 83a7555ffd355545c1f3d4642eaf4ab5d18ebcc8 Mon Sep 17 00:00:00 2001 From: selamw1 Date: Mon, 16 Sep 2024 16:47:52 -0700 Subject: [PATCH] docstring_sort_complex_added input_array_modified --- jax/_src/numpy/lax_numpy.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ef6a30400d30..69b3e1023fa3 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -8561,9 +8561,29 @@ def sort( return lax.rev(result, dimensions=[dimension]) if descending else result -@util.implements(np.sort_complex) @jit def sort_complex(a: ArrayLike) -> Array: + """Return a sorted copy of complex array. + + JAX implementation of :func:`numpy.sort_complex`. + + Complex numbers are sorted lexicographically, meaning by their real part + first, and then by their imaginary part if real parts are equal. + + Args: + a: input array. If dtype is not complex, the array will be upcast to complex. + + Returns: + A sorted array of the same shape and complex dtype as the input. + + See also: + - :func:`jax.numpy.sort`: Return a sorted copy of an array. + + Examples: + >>> a = jnp.array([1+2j, 2+4j, 3-1j, 2+3j]) + >>> jnp.sort_complex(a) + Array([1.+2.j, 2.+3.j, 2.+4.j, 3.-1.j], dtype=complex64) + """ util.check_arraylike("sort_complex", a) a = lax.sort(asarray(a), dimension=0) return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype))