Skip to content

Commit

Permalink
Closes Bears-R-Us#3722: Remove @arkouda.registerND from SortMsg.chpl
Browse files Browse the repository at this point in the history
  • Loading branch information
ajpotts committed Sep 11, 2024
1 parent 95cbb1b commit 478275f
Showing 1 changed file with 108 additions and 45 deletions.
153 changes: 108 additions & 45 deletions src/SortMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -137,49 +137,28 @@ module SortMsg
}
}

// https://data-apis.org/array-api/latest/API_specification/generated/array_api.searchsorted.html#array_api.searchsorted
@arkouda.registerND
proc searchSortedMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
param pn = Reflection.getRoutineName();
const x1 = msgArgs.getValueOf("x1"),
x2 = msgArgs.getValueOf("x2"),
side = msgArgs.getValueOf("side"),
rname = st.nextName();
// proc transpose(array: [?d] ?t): [] t throws
// where d.rank >= 2 {
// var outShape = array.shape;
// outShape[outShape.size-2] <=> outShape[outShape.size-1];
// var ret = makeDistArray((...outShape), t);

// // // TODO: performance improvements. Should use tiling to keep data local
// forall idx in d {
// var bIdx = idx;
// bIdx[d.rank-1] <=> bIdx[d.rank-2]; // bIdx is now the reverse of idx
// ret[bIdx] = array[idx]; // making B the transpose of A
// }

// return ret;
// }

var gEntX1: borrowed GenSymEntry = getGenericTypedArrayEntry(x1, st),
gEntX2: borrowed GenSymEntry = getGenericTypedArrayEntry(x2, st);

if side != "left" && side != "right" {
throw getErrorWithContext(
msg="Unrecognized side: %s".format(side),
lineNumber=getLineNumber(),
pn,
moduleName=getModuleName(),
errorClass="NotImplementedError"
);
}

// TODO: add support for Float32
if gEntX1.dtype != DType.Float64 || gEntX2.dtype != DType.Float64 {
throw getErrorWithContext(
msg="searchsorted only supports Float64 arrays",
lineNumber=getLineNumber(),
pn,
moduleName=getModuleName(),
errorClass="NotImplementedError"
);
}

sortLogger.debug(
getModuleName(),pn,getLineNumber(),
"cmd: %s, x1: %s, x2: %s, side: %s, rname: %s, dtype: %?, nd: %i".format(
cmd, x1, x2, side, rname, gEntX1.dtype, nd
)
);
// https://data-apis.org/array-api/latest/API_specification/generated/array_api.searchsorted.html#array_api.searchsorted
@arkouda.registerCommand
proc searchSorted(x1: [?d] ?t, x2: [d] t, side: string): [] t
where ((side == "left") || (side == "right")) && (t == real){

const e1 = toSymEntry(gEntX1, real, 1),
e2 = toSymEntry(gEntX2, real, nd);
var ret = makeDistArray((...e2.a.domain.shape), int);
var ret = makeDistArray((...x2.shape), int);

proc doSearch(const ref a1: [] real, const ref a2: [?d] real, cmp) {
forall idx in ret.domain {
Expand All @@ -194,11 +173,18 @@ module SortMsg
otherwise do halt("unreachable");
}

st.addEntry(rname, createSymEntry(ret));
const repMsg = "created " + st.attrib(rname);
sortLogger.debug(getModuleName(),pn,getLineNumber(),repMsg);
return ret;
}

proc searchSorted(x1: [?d] ?t, x2: [d] t, side: string): [] t
where (side != "right") && (side != "left"){
throw new Error("Unrecognized side: %s".format(side));
}

return new MsgTuple(repMsg, MsgType.NORMAL);
proc searchSorted(x1: [?d] ?t, x2: [d] t, side: string): [] t
where (t != real){
// TODO: add support for Float32
throw new Error("searchsorted only supports Float64 arrays");
}

record leftCmp: relativeComparator {
Expand All @@ -215,4 +201,81 @@ module SortMsg
}
}

// @arkouda.registerND
// proc searchSortedMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
// param pn = Reflection.getRoutineName();
// const x1 = msgArgs.getValueOf("x1"),
// x2 = msgArgs.getValueOf("x2"),
// side = msgArgs.getValueOf("side"),
// rname = st.nextName();

// var gEntX1: borrowed GenSymEntry = getGenericTypedArrayEntry(x1, st),
// gEntX2: borrowed GenSymEntry = getGenericTypedArrayEntry(x2, st);

// if side != "left" && side != "right" {
// throw getErrorWithContext(
// msg="Unrecognized side: %s".format(side),
// lineNumber=getLineNumber(),
// pn,
// moduleName=getModuleName(),
// errorClass="NotImplementedError"
// );
// }

// // TODO: add support for Float32
// if gEntX1.dtype != DType.Float64 || gEntX2.dtype != DType.Float64 {
// throw getErrorWithContext(
// msg="searchsorted only supports Float64 arrays",
// lineNumber=getLineNumber(),
// pn,
// moduleName=getModuleName(),
// errorClass="NotImplementedError"
// );
// }

// sortLogger.debug(
// getModuleName(),pn,getLineNumber(),
// "cmd: %s, x1: %s, x2: %s, side: %s, rname: %s, dtype: %?, nd: %i".format(
// cmd, x1, x2, side, rname, gEntX1.dtype, nd
// )
// );

// const e1 = toSymEntry(gEntX1, real, 1),
// e2 = toSymEntry(gEntX2, real, nd);
// var ret = makeDistArray((...e2.a.domain.shape), int);

// proc doSearch(const ref a1: [] real, const ref a2: [?d] real, cmp) {
// forall idx in ret.domain {
// const (_, i) = Search.binarySearch(a1, a2[idx], cmp);
// ret[idx] = i;
// }
// }

// select side {
// when "left" do doSearch(e1.a, e2.a, new leftCmp());
// when "right" do doSearch(e1.a, e2.a, new rightCmp());
// otherwise do halt("unreachable");
// }

// st.addEntry(rname, createSymEntry(ret));
// const repMsg = "created " + st.attrib(rname);
// sortLogger.debug(getModuleName(),pn,getLineNumber(),repMsg);

// return new MsgTuple(repMsg, MsgType.NORMAL);
// }

// record leftCmp: relativeComparator {
// proc compare(a: real, b: real): int {
// if a < b then return -1;
// else return 1;
// }
// }

// record rightCmp: relativeComparator {
// proc compare(a: real, b: real): int {
// if a <= b then return -1;
// else return 1;
// }
// }

}// end module SortMsg

0 comments on commit 478275f

Please sign in to comment.