Skip to content

Commit

Permalink
Closes Bears-R-Us#3720: Update SetMsg to use the new message framework
Browse files Browse the repository at this point in the history
This PR (Closes Bears-R-Us#3720) updates `SetMsg` to use the new message framework and modifies the test to run with less than 3 dims
  • Loading branch information
stress-tess committed Sep 16, 2024
1 parent 8fc7bf8 commit 004d15d
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 182 deletions.
72 changes: 36 additions & 36 deletions arkouda/array_api/set_functions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

from .array_object import Array

from typing import NamedTuple, cast

from arkouda.client import generic_msg
from arkouda.pdarrayclass import create_pdarray
from arkouda.pdarrayclass import create_pdarray, create_pdarrays

from .array_object import Array


class UniqueAllResult(NamedTuple):
Expand Down Expand Up @@ -33,21 +33,21 @@ def unique_all(x: Array, /) -> UniqueAllResult:
- the inverse indices that reconstruct `x` from the unique values
- the counts of each unique value
"""
resp = cast(
str,
generic_msg(
cmd=f"uniqueAll{x.ndim}D",
args={"name": x._array},
),
arrays = create_pdarrays(
cast(
str,
generic_msg(
cmd=f"uniqueAll<{x.dtype},{x.ndim}>",
args={"name": x._array},
),
)
)

arrays = [Array._new(create_pdarray(r)) for r in resp.split("+")]

return UniqueAllResult(
values=arrays[0],
indices=arrays[1],
inverse_indices=arrays[2],
counts=arrays[3],
values=Array._new(arrays[0]),
indices=Array._new(arrays[1]),
inverse_indices=Array._new(arrays[2]),
counts=Array._new(arrays[3]),
)


Expand All @@ -57,19 +57,19 @@ def unique_counts(x: Array, /) -> UniqueCountsResult:
- the unique values in `x`
- the counts of each unique value
"""
resp = cast(
str,
generic_msg(
cmd=f"uniqueCounts{x.ndim}D",
args={"name": x._array},
),
arrays = create_pdarrays(
cast(
str,
generic_msg(
cmd=f"uniqueCounts<{x.dtype},{x.ndim}>",
args={"name": x._array},
),
)
)

arrays = [Array._new(create_pdarray(r)) for r in resp.split("+")]

return UniqueCountsResult(
values=arrays[0],
counts=arrays[1],
values=Array._new(arrays[0]),
counts=Array._new(arrays[1]),
)


Expand All @@ -79,19 +79,19 @@ def unique_inverse(x: Array, /) -> UniqueInverseResult:
- the unique values in `x`
- the inverse indices that reconstruct `x` from the unique values
"""
resp = cast(
str,
generic_msg(
cmd=f"uniqueInverse{x.ndim}D",
args={"name": x._array},
),
arrays = create_pdarrays(
cast(
str,
generic_msg(
cmd=f"uniqueInverse<{x.dtype},{x.ndim}>",
args={"name": x._array},
),
)
)

arrays = [Array._new(create_pdarray(r)) for r in resp.split("+")]

return UniqueInverseResult(
values=arrays[0],
inverse_indices=arrays[1],
values=Array._new(arrays[0]),
inverse_indices=Array._new(arrays[1]),
)


Expand All @@ -104,7 +104,7 @@ def unique_values(x: Array, /) -> Array:
cast(
str,
generic_msg(
cmd=f"uniqueValues{x.ndim}D",
cmd=f"uniqueValues<{x.dtype},{x.ndim}>",
args={"name": x._array},
),
)
Expand Down
211 changes: 72 additions & 139 deletions src/SetMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -11,155 +11,88 @@ module SetMsg {
use RadixSortLSD;
use Unique;
use Reflection;
use BigInteger;

private config const logLevel = ServerConfig.logLevel;
private config const logChannel = ServerConfig.logChannel;
const sLogger = new Logger(logLevel, logChannel);
@arkouda.instantiateAndRegister
proc uniqueValues(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where (array_dtype != BigInteger.bigint) && (array_dtype != uint(8))
{
const name = msgArgs["name"],
eIn = st[msgArgs["name"]]: SymEntry(array_dtype, array_nd),
eFlat = if array_nd == 1 then eIn.a else flatten(eIn.a);

@arkouda.registerND
proc uniqueValuesMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
param pn = Reflection.getRoutineName();
const name = msgArgs.getValueOf("name"),
rname = st.nextName();

var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st);

proc getUniqueVals(type t): MsgTuple throws {
const eIn = toSymEntry(gEnt, t, nd),
eFlat = if nd == 1 then eIn.a else flatten(eIn.a);

const eSorted = radixSortLSD_keys(eFlat);
const eUnique = uniqueFromSorted(eSorted, needCounts=false);

st.addEntry(rname, createSymEntry(eUnique));

const repMsg = "created " + st.attrib(rname);
sLogger.info(getModuleName(),pn,getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}

select gEnt.dtype {
when DType.Int64 do return getUniqueVals(int);
// when DType.UInt8 do return getUniqueVals(uint(8));
when DType.UInt64 do return getUniqueVals(uint);
when DType.Float64 do return getUniqueVals(real);
when DType.Bool do return getUniqueVals(bool);
otherwise {
var errorMsg = notImplementedError(getRoutineName(),gEnt.dtype);
sLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
const eSorted = radixSortLSD_keys(eFlat);
const eUnique = uniqueFromSorted(eSorted, needCounts=false);

return st.insert(new shared SymEntry(eUnique));
}

proc uniqueValues(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where (array_dtype == BigInteger.bigint) || (array_dtype == uint(8))
{
return MsgTuple.error("unique_values does not support the %s dtype".format(array_dtype:string));
}

@arkouda.registerND
proc uniqueCountsMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
param pn = Reflection.getRoutineName();
@arkouda.instantiateAndRegister
proc uniqueCounts(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws {
const name = msgArgs.getValueOf("name"),
uname = st.nextName(),
cname = st.nextName();

var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st);

proc getUniqueVals(type t): MsgTuple throws {
const eIn = toSymEntry(gEnt, t, nd),
eFlat = if nd == 1 then eIn.a else flatten(eIn.a);

const eSorted = radixSortLSD_keys(eFlat);
const (eUnique, eCounts) = uniqueFromSorted(eSorted);

st.addEntry(uname, createSymEntry(eUnique));
st.addEntry(cname, createSymEntry(eCounts));

const repMsg = "created " + st.attrib(uname) + "+created " + st.attrib(cname);
sLogger.info(getModuleName(),pn,getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}

select gEnt.dtype {
when DType.Int64 do return getUniqueVals(int);
// when DType.UInt8 do return getUniqueVals(uint(8));
when DType.UInt64 do return getUniqueVals(uint);
when DType.Float64 do return getUniqueVals(real);
when DType.Bool do return getUniqueVals(bool);
otherwise {
var errorMsg = notImplementedError(getRoutineName(),gEnt.dtype);
sLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
eIn = st[msgArgs["name"]]: SymEntry(array_dtype, array_nd),
eFlat = if array_nd == 1 then eIn.a else flatten(eIn.a);

const eSorted = radixSortLSD_keys(eFlat);
const (eUnique, eCounts) = uniqueFromSorted(eSorted);

return MsgTuple.fromResponses([
st.insert(new shared SymEntry(eUnique)),
st.insert(new shared SymEntry(eCounts)),
]);
}

proc uniqueCounts(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where (array_dtype == BigInteger.bigint) || (array_dtype == uint(8))
{
return MsgTuple.error("unique_counts does not support the %s dtype".format(array_dtype:string));
}

@arkouda.registerND
proc uniqueInverseMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
param pn = Reflection.getRoutineName();
@arkouda.instantiateAndRegister
proc uniqueInverse(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws {
const name = msgArgs.getValueOf("name"),
uname = st.nextName(),
iname = st.nextName();

var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st);

proc getUniqueVals(type t): MsgTuple throws {
const eIn = toSymEntry(gEnt, t, nd),
eFlat = if nd == 1 then eIn.a else flatten(eIn.a);

const (eUnique, _, inv) = uniqueSortWithInverse(eFlat);
st.addEntry(uname, createSymEntry(eUnique));
st.addEntry(iname, createSymEntry(if nd == 1 then inv else unflatten(inv, eIn.a.shape)));

const repMsg = "created " + st.attrib(uname) + "+created " + st.attrib(iname);
sLogger.info(getModuleName(),pn,getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}

select gEnt.dtype {
when DType.Int64 do return getUniqueVals(int);
// when DType.UInt8 do return getUniqueVals(uint(8));
when DType.UInt64 do return getUniqueVals(uint);
when DType.Float64 do return getUniqueVals(real);
when DType.Bool do return getUniqueVals(bool);
otherwise {
var errorMsg = notImplementedError(getRoutineName(),gEnt.dtype);
sLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
eIn = st[msgArgs["name"]]: SymEntry(array_dtype, array_nd),
eFlat = if array_nd == 1 then eIn.a else flatten(eIn.a);

const (eUnique, _, inv) = uniqueSortWithInverse(eFlat);

return MsgTuple.fromResponses([
st.insert(new shared SymEntry(eUnique)),
st.insert(new shared SymEntry(if array_nd == 1 then inv else unflatten(inv, eIn.a.shape))),
]);
}

proc uniqueInverse(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where (array_dtype == BigInteger.bigint) || (array_dtype == uint(8))
{
return MsgTuple.error("unique_inverse does not support the %s dtype".format(array_dtype:string));
}

@arkouda.registerND
proc uniqueAllMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
param pn = Reflection.getRoutineName();
@arkouda.instantiateAndRegister
proc uniqueAll(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws {
const name = msgArgs.getValueOf("name"),
rnames = for 0..<4 do st.nextName();

var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st);

proc getUniqueVals(type t): MsgTuple throws {
const eIn = toSymEntry(gEnt, t, nd),
eFlat = if nd == 1 then eIn.a else flatten(eIn.a);

const (eUnique, eCounts, inv, eIndices) = uniqueSortWithInverse(eFlat, needIndices=true);
st.addEntry(rnames[0], createSymEntry(eUnique));
st.addEntry(rnames[1], createSymEntry(eIndices));
st.addEntry(rnames[2], createSymEntry(if nd == 1 then inv else unflatten(inv, eIn.a.shape)));
st.addEntry(rnames[3], createSymEntry(eCounts));

const repMsg = try! "+".join([rn in rnames] "created " + st.attrib(rn));
sLogger.info(getModuleName(),pn,getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}

select gEnt.dtype {
when DType.Int64 do return getUniqueVals(int);
// when DType.UInt8 do return getUniqueVals(uint(8));
when DType.UInt64 do return getUniqueVals(uint);
when DType.Float64 do return getUniqueVals(real);
when DType.Bool do return getUniqueVals(bool);
otherwise {
var errorMsg = notImplementedError(getRoutineName(),gEnt.dtype);
sLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
eIn = st[msgArgs["name"]]: SymEntry(array_dtype, array_nd),
eFlat = if array_nd == 1 then eIn.a else flatten(eIn.a);

const (eUnique, eCounts, inv, eIndices) = uniqueSortWithInverse(eFlat, needIndices=true);

return MsgTuple.fromResponses([
st.insert(new shared SymEntry(eUnique)),
st.insert(new shared SymEntry(eIndices)),
st.insert(new shared SymEntry(if array_nd == 1 then inv else unflatten(inv, eIn.a.shape))),
st.insert(new shared SymEntry(eCounts)),
]);
}

proc uniqueAll(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where (array_dtype == BigInteger.bigint) || (array_dtype == uint(8))
{
return MsgTuple.error("unique_all does not support the %s dtype".format(array_dtype:string));
}
}
Loading

0 comments on commit 004d15d

Please sign in to comment.