diff --git a/arkouda/array_api/searching_functions.py b/arkouda/array_api/searching_functions.py index 282d4e75b8..eceabe9d03 100644 --- a/arkouda/array_api/searching_functions.py +++ b/arkouda/array_api/searching_functions.py @@ -8,7 +8,7 @@ from .manipulation_functions import squeeze, reshape, broadcast_arrays from arkouda.client import generic_msg -from arkouda.pdarrayclass import parse_single_value, create_pdarray +from arkouda.pdarrayclass import parse_single_value, create_pdarray, create_pdarrays from arkouda.pdarraycreation import scalar_array from arkouda.numeric import cast as akcast import arkouda as ak @@ -110,12 +110,12 @@ def nonzero(x: Array, /) -> Tuple[Array, ...]: resp = cast( str, generic_msg( - cmd=f"nonzero{x.ndim}D", + cmd=f"nonzero<{x.dtype},{x.ndim}>", args={"x": x._array}, ), ) - return tuple([Array._new(create_pdarray(a)) for a in resp.split("+")]) + return tuple([Array._new(a) for a in create_pdarrays(resp)]) def where(condition: Array, x1: Array, x2: Array, /) -> Array: diff --git a/src/AryUtil.chpl b/src/AryUtil.chpl index 453ec8b148..da573e77c2 100644 --- a/src/AryUtil.chpl +++ b/src/AryUtil.chpl @@ -260,30 +260,44 @@ module AryUtil } } + // overload for tuple of axes + iter axisSlices(D: domain(?), axes: int ...?N): (domain(?), D.rank*int) throws + where N <= D.rank + { + for sliceIdx in domOffAxis(D, (...axes)) { + yield (domOnAxis(D, if D.rank == 1 then (sliceIdx,) else sliceIdx, (...axes)), sliceIdx); + } + } + + iter axisSlices(param tag: iterKind, D: domain(?), axes: int ...?N): (domain(?), D.rank*int) throws + where tag == iterKind.standalone && N <= D.rank + { + forall sliceIdx in domOffAxis(D, (...axes)) { + yield (domOnAxis(D, if D.rank == 1 then (sliceIdx,) else sliceIdx, (...axes)), sliceIdx); + } + } + /* - Naively create a domain over a chunk of the input domain + Create a domain over a chunk of the input domain - Chunks are created by splitting the largest dimension of the input domain + Chunks are created by splitting the 0th dimension of the input domain into 'nChunks' roughly equal-sized chunks, and then taking the 'chunkIdx'-th chunk - (if 'nChunks' is greater than the size of the largest dimension, the + (if 'nChunks' is greater than the size of the first dimension, the first 'nChunks-1' chunks will be empty, and the last chunk will contain - the entire set of indices along that dimension) + the entire set of indices) */ proc subDomChunk(dom: domain(?), chunkIdx: int, nChunks: int): domain(?) { - const dimSizes = [i in 0..', ark_shuffle_bigint_1, 'RandMsg', 696); +import ReductionMsg; + +proc ark_nonzero_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return ReductionMsg.nonzero(cmd, msgArgs, st, array_dtype=int, array_nd=1); +registerFunction('nonzero', ark_nonzero_int_1, 'ReductionMsg', 325); + +proc ark_nonzero_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return ReductionMsg.nonzero(cmd, msgArgs, st, array_dtype=uint, array_nd=1); +registerFunction('nonzero', ark_nonzero_uint_1, 'ReductionMsg', 325); + +proc ark_nonzero_uint8_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return ReductionMsg.nonzero(cmd, msgArgs, st, array_dtype=uint(8), array_nd=1); +registerFunction('nonzero', ark_nonzero_uint8_1, 'ReductionMsg', 325); + +proc ark_nonzero_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return ReductionMsg.nonzero(cmd, msgArgs, st, array_dtype=real, array_nd=1); +registerFunction('nonzero', ark_nonzero_real_1, 'ReductionMsg', 325); + +proc ark_nonzero_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return ReductionMsg.nonzero(cmd, msgArgs, st, array_dtype=bool, array_nd=1); +registerFunction('nonzero', ark_nonzero_bool_1, 'ReductionMsg', 325); + +proc ark_nonzero_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return ReductionMsg.nonzero(cmd, msgArgs, st, array_dtype=bigint, array_nd=1); +registerFunction('nonzero', ark_nonzero_bigint_1, 'ReductionMsg', 325); + import StatsMsg; proc ark_reg_mean_generic(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_0, param array_nd_0: int): MsgTuple throws { diff --git a/tests/array_api/searching_functions.py b/tests/array_api/searching_functions.py index 4209c60157..7a8ea99bd5 100644 --- a/tests/array_api/searching_functions.py +++ b/tests/array_api/searching_functions.py @@ -44,19 +44,36 @@ def test_argmin(self): @pytest.mark.skip_if_max_rank_less_than(3) def test_nonzero(self): - a = xp.zeros((4, 5, 6), dtype=ak.int64) + a = xp.zeros((40, 15, 16), dtype=ak.int64) a[0, 1, 0] = 1 a[1, 2, 3] = 1 a[2, 2, 2] = 1 a[3, 2, 1] = 1 + a[10, 10, 10] = 1 + a[30, 12, 11] = 1 + a[2, 13, 14] = 1 + a[3, 14, 15] = 1 nz = xp.nonzero(a) - print(nz) + a_np = a.to_ndarray() + nz_np = np.nonzero(a_np) + + assert nz[0].tolist() == nz_np[0].tolist() + assert nz[1].tolist() == nz_np[1].tolist() + assert nz[2].tolist() == nz_np[2].tolist() + + def test_nonzero_1d(self): + b = xp.zeros(500, dtype=ak.int64) + b[0] = 1 + b[12] = 1 + b[100] = 1 + b[205] = 1 + b[490] = 1 + + nz = xp.nonzero(b) - assert sorted(nz[0].tolist()) == sorted([0, 1, 2, 3]) - assert sorted(nz[1].tolist()) == sorted([1, 2, 2, 2]) - assert sorted(nz[2].tolist()) == sorted([0, 3, 2, 1]) + assert nz[0].tolist() == [0, 12, 100, 205, 490] @pytest.mark.skip_if_max_rank_less_than(3) def test_where(self):