Skip to content

Commit

Permalink
Fixes Bears-R-Us#3676 - index order bug in nonzero (Bears-R-Us#3690)
Browse files Browse the repository at this point in the history
* convert nonzero command to use instantiateAndRegister. Fix intermittent out of bounds indexing error

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* re-include deleted comments in nonzero

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* add bigint overload of nonzero proc

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* assert order of indices in test_nonzero

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* modify subDomChunk to only chunk along 0th dimension to ensure row-major order is maintained in nonzero command

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* first pass at fixing nonzero correctness issue for multi-locale case

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* fix return type on nonzero1D

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* clean up code and add test for 1D non-zero cmd

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* remove unused proc from AryUtil

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

---------

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>
  • Loading branch information
jeremiah-corrado committed Aug 23, 2024
1 parent 3d65f85 commit eb66ea7
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 85 deletions.
6 changes: 3 additions & 3 deletions arkouda/array_api/searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .manipulation_functions import squeeze, reshape, broadcast_arrays

from arkouda.client import generic_msg
from arkouda.pdarrayclass import parse_single_value, create_pdarray
from arkouda.pdarrayclass import parse_single_value, create_pdarray, create_pdarrays
from arkouda.pdarraycreation import scalar_array
from arkouda.numeric import cast as akcast
import arkouda as ak
Expand Down Expand Up @@ -110,12 +110,12 @@ def nonzero(x: Array, /) -> Tuple[Array, ...]:
resp = cast(
str,
generic_msg(
cmd=f"nonzero{x.ndim}D",
cmd=f"nonzero<{x.dtype},{x.ndim}>",
args={"x": x._array},
),
)

return tuple([Array._new(create_pdarray(a)) for a in resp.split("+")])
return tuple([Array._new(a) for a in create_pdarrays(resp)])


def where(condition: Array, x1: Array, x2: Array, /) -> Array:
Expand Down
40 changes: 27 additions & 13 deletions src/AryUtil.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -260,30 +260,44 @@ module AryUtil
}
}

// overload for tuple of axes
iter axisSlices(D: domain(?), axes: int ...?N): (domain(?), D.rank*int) throws
where N <= D.rank
{
for sliceIdx in domOffAxis(D, (...axes)) {
yield (domOnAxis(D, if D.rank == 1 then (sliceIdx,) else sliceIdx, (...axes)), sliceIdx);
}
}

iter axisSlices(param tag: iterKind, D: domain(?), axes: int ...?N): (domain(?), D.rank*int) throws
where tag == iterKind.standalone && N <= D.rank
{
forall sliceIdx in domOffAxis(D, (...axes)) {
yield (domOnAxis(D, if D.rank == 1 then (sliceIdx,) else sliceIdx, (...axes)), sliceIdx);
}
}

/*
Naively create a domain over a chunk of the input domain
Create a domain over a chunk of the input domain
Chunks are created by splitting the largest dimension of the input domain
Chunks are created by splitting the 0th dimension of the input domain
into 'nChunks' roughly equal-sized chunks, and then taking the
'chunkIdx'-th chunk
(if 'nChunks' is greater than the size of the largest dimension, the
(if 'nChunks' is greater than the size of the first dimension, the
first 'nChunks-1' chunks will be empty, and the last chunk will contain
the entire set of indices along that dimension)
the entire set of indices)
*/
proc subDomChunk(dom: domain(?), chunkIdx: int, nChunks: int): domain(?) {
const dimSizes = [i in 0..<dom.rank] dom.dim(i).size,
(maxDim, maxDimIdx) = maxloc reduce zip(dimSizes, dimSizes.domain);

const chunkSize = maxDim / nChunks,
start = chunkIdx * chunkSize + dom.dim(maxDimIdx).low,
const chunkSize = dom.dim(0).size / nChunks,
start = chunkIdx * chunkSize + dom.dim(0).low,
end = if chunkIdx == nChunks-1
then dom.dim(maxDimIdx).high
else (chunkIdx+1) * chunkSize + dom.dim(maxDimIdx).low - 1;
then dom.dim(0).high
else (chunkIdx+1) * chunkSize + dom.dim(0).low - 1;

var rngs: dom.rank*range;
for i in 0..<dom.rank do rngs[i] = dom.dim(i);
rngs[maxDimIdx] = start..end;
for i in 1..<dom.rank do rngs[i] = dom.dim(i);
rngs[0] = start..end;
return {(...rngs)};
}

Expand Down
159 changes: 95 additions & 64 deletions src/ReductionMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -321,80 +321,111 @@ module ReductionMsg
}
}

/*
Find the indices of all the non-zero elements along each dimension of the input array
@arkouda.instantiateAndRegister
proc nonzero(
cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab,
type array_dtype,
param array_nd: int
): MsgTuple throws
where array_dtype != bigint
{
var x = st[msgArgs['x']]: SymEntry(array_dtype, array_nd);

// call fast / simple path for 1D arrays
if array_nd == 1 then return st.insert(new shared SymEntry(nonzero1D(x.a)));

var nnzPerSlab: [0..<x.a.domain.dim(0).size] int;
var axes: (x.a.rank - 1)*int;
for i in 1..<x.a.rank do axes[i-1] = i;

// count the number of non-zero elements in each slab
forall (slabDom, slabIdx) in axisSlices(x.a.domain, (...axes)) {
var nnzSlabCount = 0;

// TODO: see comment below about making this a coforall loop
for idx in slabDom do
if x.a[idx] != 0 then nnzSlabCount += 1;
nnzPerSlab[slabIdx[0]] = nnzSlabCount;
}

const nnzTotalCount = + reduce nnzPerSlab,
dimIndexStarts = (+ scan nnzPerSlab) - nnzPerSlab;

var dimIndices = for 0..<array_nd do createSymEntry(nnzTotalCount, int);

// populate the arrays with the indices of the non-zero elements
forall (slabDom, slabIdx) in axisSlices(x.a.domain, (...axes)) {
var i = dimIndexStarts[slabIdx[0]];

/*
TODO: make this a coforall loop over a locale-wise decomposition of 'slabDom'
since it is a (potentially large) distributed domain. This requires computing
each task's starting index in the output array ahead of time (and ensuring
their proper relative ordering in the output arrays (not trivial)). Potentially
not the most performant strategy since multiple tasks and `on` blocks have to be
kicked off by each iteration of the outer `forall` loop?
*/
for idx in slabDom {
if x.a[idx] != 0 {
for d in 0..<array_nd do
dimIndices[d].a[i] = idx[d];
i += 1;
}
}
}

Returns one array of indices for each dimension of the input array
*/
@arkouda.registerND
proc nonzeroMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
param pn = Reflection.getRoutineName();
const x = msgArgs.getValueOf("x"),
rnames = [i in 0..<nd] st.nextName();
const responses = for di in dimIndices do st.insert(di);
return MsgTuple.fromResponses(responses);
}

var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(x, st);

proc findNonZero(type t): MsgTuple throws {
const eIn = toSymEntry(gEnt, t, nd),
nTasks = here.maxTaskPar;

// count the number of non-zero elements in a chunk of the input array owned by each task
var nnzPerTask: [0..<numLocales] [0..<nTasks] int;
coforall loc in Locales with (ref nnzPerTask) do on loc {
const locDom = eIn.a.localSubdomain();
coforall tid in 0..<nTasks with (ref nnzPerTask) {
var nnzTask = 0;
// TODO: evaluate whether 'subDomChunk' chunking along the largest dimension
// is the best choice. Perhaps it would be better to always chunk along the
// zeroth dimension for best cache locality (or to use some other technique to
// split work among tasks).
for idx in subDomChunk(locDom, tid, nTasks) do
if eIn.a[idx] != 0:t then nnzTask += 1;
nnzPerTask[loc.id][tid] = nnzTask;
}
// simple and efficient 'nonzero' implementation for 1D arrays
proc nonzero(
cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab,
type array_dtype,
param array_nd: int
): MsgTuple throws
where array_dtype == bigint
{
return MsgTuple.error("nonzero is not supported for bigint arrays");
}

proc nonzero1D(x: [?d] ?t): [] int throws {
const nTasksPerLoc = here.maxTaskPar;
var nnzPerTask: [0..<numLocales] [0..<nTasksPerLoc] int;

coforall loc in Locales with (ref nnzPerTask) do on loc {
const locDom = x.localSubdomain();
coforall tid in 0..<nTasksPerLoc with (ref nnzPerTask) {
var nnzTaskCount = 0;
for idx in subDomChunk(locDom, tid, nTasksPerLoc) do
if x[idx] != 0 then nnzTaskCount += 1;
nnzPerTask[loc.id][tid] = nnzTaskCount;
}
}

// calculate the total number of non-zero elements and the starting index of each locale
const nnzPerLocale = [locTasks in nnzPerTask] + reduce locTasks,
numNonZero = + reduce nnzPerLocale,
locStarts = (+ scan nnzPerLocale) - nnzPerLocale;

// create an index array for each dimension of the input array
var eOuts = for rn in rnames do st.addEntry(rn, numNonZero, int);

// populate the arrays with the indices of the non-zero elements
// TODO: refactor to use aggregation or bulk assignment to avoid fine-grained communication
coforall loc in Locales with (const ref nnzPerTask, const ref locStarts) do on loc {
const taskStarts = ((+ scan nnzPerTask[loc.id]) - nnzPerTask[loc.id]) + locStarts[loc.id],
locDom = eIn.a.localSubdomain();
coforall tid in 0..<nTasks {
var i = taskStarts[tid];
for idx in subDomChunk(locDom, tid, nTasks) {
if eIn.a[idx] != 0:t {
for d in 0..<nd do
eOuts[d].a[i] = if nd == 1 then idx else idx[d];
i += 1;
}
const nnzPerLoc = [locCounts in nnzPerTask] + reduce locCounts,
nnzTotalCount = + reduce nnzPerLoc,
locStarts = (+ scan nnzPerLoc) - nnzPerLoc;

var nnzIndices = makeDistArray(nnzTotalCount, int);

coforall loc in Locales with (ref nnzIndices) do on loc {
const taskStarts = ((+ scan nnzPerTask[loc.id]) - nnzPerTask[loc.id]) + locStarts[loc.id],
locDom = x.localSubdomain();

coforall tid in 0..<nTasksPerLoc with (ref nnzIndices) {
var i = taskStarts[tid];
for idx in subDomChunk(locDom, tid, nTasksPerLoc) {
if x[idx] != 0 then {
nnzIndices[i] = idx;
i += 1;
}
}
}

const repMsg = try! '+'.join([rn in rnames] "created " + st.attrib(rn));
rmLogger.info(getModuleName(),pn,getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}

select gEnt.dtype {
when DType.Int64 do return findNonZero(int);
when DType.UInt64 do return findNonZero(uint);
when DType.Float64 do return findNonZero(real);
when DType.Bool do return findNonZero(bool);
otherwise {
var errorMsg = notImplementedError(pn,dtype2str(gEnt.dtype));
rmLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(errorMsg,MsgType.ERROR);
}
}
return nnzIndices;
}

private module SliceReductionOps {
Expand Down
26 changes: 26 additions & 0 deletions src/registry/Commands.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -1773,6 +1773,32 @@ proc ark_shuffle_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrow
return RandMsg.shuffle(cmd, msgArgs, st, array_dtype=bigint, array_nd=1);
registerFunction('shuffle<bigint,1>', ark_shuffle_bigint_1, 'RandMsg', 696);

import ReductionMsg;

proc ark_nonzero_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do
return ReductionMsg.nonzero(cmd, msgArgs, st, array_dtype=int, array_nd=1);
registerFunction('nonzero<int64,1>', ark_nonzero_int_1, 'ReductionMsg', 325);

proc ark_nonzero_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do
return ReductionMsg.nonzero(cmd, msgArgs, st, array_dtype=uint, array_nd=1);
registerFunction('nonzero<uint64,1>', ark_nonzero_uint_1, 'ReductionMsg', 325);

proc ark_nonzero_uint8_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do
return ReductionMsg.nonzero(cmd, msgArgs, st, array_dtype=uint(8), array_nd=1);
registerFunction('nonzero<uint8,1>', ark_nonzero_uint8_1, 'ReductionMsg', 325);

proc ark_nonzero_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do
return ReductionMsg.nonzero(cmd, msgArgs, st, array_dtype=real, array_nd=1);
registerFunction('nonzero<float64,1>', ark_nonzero_real_1, 'ReductionMsg', 325);

proc ark_nonzero_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do
return ReductionMsg.nonzero(cmd, msgArgs, st, array_dtype=bool, array_nd=1);
registerFunction('nonzero<bool,1>', ark_nonzero_bool_1, 'ReductionMsg', 325);

proc ark_nonzero_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do
return ReductionMsg.nonzero(cmd, msgArgs, st, array_dtype=bigint, array_nd=1);
registerFunction('nonzero<bigint,1>', ark_nonzero_bigint_1, 'ReductionMsg', 325);

import StatsMsg;

proc ark_reg_mean_generic(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_0, param array_nd_0: int): MsgTuple throws {
Expand Down
27 changes: 22 additions & 5 deletions tests/array_api/searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,36 @@ def test_argmin(self):

@pytest.mark.skip_if_max_rank_less_than(3)
def test_nonzero(self):
a = xp.zeros((4, 5, 6), dtype=ak.int64)
a = xp.zeros((40, 15, 16), dtype=ak.int64)
a[0, 1, 0] = 1
a[1, 2, 3] = 1
a[2, 2, 2] = 1
a[3, 2, 1] = 1
a[10, 10, 10] = 1
a[30, 12, 11] = 1
a[2, 13, 14] = 1
a[3, 14, 15] = 1

nz = xp.nonzero(a)

print(nz)
a_np = a.to_ndarray()
nz_np = np.nonzero(a_np)

assert nz[0].tolist() == nz_np[0].tolist()
assert nz[1].tolist() == nz_np[1].tolist()
assert nz[2].tolist() == nz_np[2].tolist()

def test_nonzero_1d(self):
b = xp.zeros(500, dtype=ak.int64)
b[0] = 1
b[12] = 1
b[100] = 1
b[205] = 1
b[490] = 1

nz = xp.nonzero(b)

assert sorted(nz[0].tolist()) == sorted([0, 1, 2, 3])
assert sorted(nz[1].tolist()) == sorted([1, 2, 2, 2])
assert sorted(nz[2].tolist()) == sorted([0, 3, 2, 1])
assert nz[0].tolist() == [0, 12, 100, 205, 490]

@pytest.mark.skip_if_max_rank_less_than(3)
def test_where(self):
Expand Down

0 comments on commit eb66ea7

Please sign in to comment.