Skip to content

Commit

Permalink
Array transfer perf fix (Bears-R-Us#3671)
Browse files Browse the repository at this point in the history
* use in-intent when assigning binary payload to MsgTuple and MsgArgs

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* address some of the array transfer performance hit from recent refactors. Convert array and tondarray commands to use new annotations

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* reenable metrics tracking without copying bytes payload

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* revert changes to run_benchmarks

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* fix flake8

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* fix broken arraySegString command

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* fix test_client_get_server_commands test

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* fix other test_client_get_server_commands test

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* fix other test_client_get_server_commands test

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* remove accidentally committed client_test file

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

---------

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>
  • Loading branch information
jeremiah-corrado authored Aug 19, 2024
1 parent 37df634 commit d0db02f
Show file tree
Hide file tree
Showing 10 changed files with 240 additions and 187 deletions.
10 changes: 7 additions & 3 deletions arkouda/pdarrayclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,8 @@ def _to_pdarray(value: np.ndarray, dt=None) -> pdarray:
value_flat = value.flatten()
return create_pdarray(
generic_msg(
cmd=f"array{value.ndim}D",
args={"dtype": _dtype, "shape": np.shape(value), "seg_string": False},
cmd=f"array<{_dtype},{value.ndim}>",
args={"shape": np.shape(value)},
payload=_array_memview(value_flat),
send_binary=True,
)
Expand Down Expand Up @@ -1827,7 +1827,11 @@ def to_ndarray(self) -> np.ndarray:
# The reply from the server will be binary data
data = cast(
memoryview,
generic_msg(cmd=f"tondarray{self.ndim}D", args={"array": self}, recv_binary=True),
generic_msg(
cmd=f"tondarray<{self.dtype},{self.ndim}>",
args={"array": self},
recv_binary=True
),
)
# Make sure the received data has the expected length
if len(data) != self.size * self.dtype.itemsize:
Expand Down
8 changes: 4 additions & 4 deletions arkouda/pdarraycreation.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,8 @@ def array(
)
encoded_np = np.array(encoded, dtype=np.uint8)
rep_msg = generic_msg(
cmd="array1D",
args={"dtype": encoded_np.dtype.name, "shape": encoded_np.size, "seg_string": True},
cmd=f"arraySegString<{encoded_np.dtype.name}>",
args={"size": encoded_np.size},
payload=_array_memview(encoded_np),
send_binary=True,
)
Expand Down Expand Up @@ -297,8 +297,8 @@ def array(
# native endian bytes
aview = _array_memview(a)
rep_msg = generic_msg(
cmd="array1D",
args={"dtype": a.dtype.name, "shape": size, "seg_string": False},
cmd=f"array<{a.dtype.name},1>",
args={"shape": size},
payload=aview,
send_binary=True,
)
Expand Down
10 changes: 2 additions & 8 deletions src/CommandMap.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,11 @@ module CommandMap {
}

proc executeCommand(cmd: string, msgArgs, st): MsgTuple throws {
var response: MsgTuple;
if commandMap.contains(cmd) {
if moduleMap.contains(cmd) then usedModules.add(moduleMap[cmd][0]);
try {
response = commandMap[cmd](cmd, msgArgs, st);
} catch e {
response = MsgTuple.error("Error executing command: %s".format(e.message()));
}
return commandMap[cmd](cmd, msgArgs, st);
} else {
response = MsgTuple.error("Unrecognized command: %s".format(cmd));
return MsgTuple.error("Unrecognized command: %s".format(cmd));
}
return response;
}
}
172 changes: 76 additions & 96 deletions src/GenSymIO.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ module GenSymIO {
use CTypes;
use CommAggregation;
use IOUtils;
use BigInteger;

private config const logLevel = ServerConfig.logLevel;
private config const logChannel = ServerConfig.logChannel;
Expand All @@ -30,74 +31,72 @@ module GenSymIO {
* Creates a pdarray server-side and returns the SymTab name used to
* retrieve the pdarray from the SymTab.
*/
@arkouda.registerND
proc arrayMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
const dtype = str2dtype(msgArgs.getValueOf("dtype")),
shape = msgArgs.get("shape").getTuple(nd),
asSegStr = msgArgs.get("seg_string").getBoolValue(),
rname = st.nextName();
@arkouda.instantiateAndRegister
proc array(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_dtype != bigint
{
const shape = msgArgs["shape"].toScalarTuple(int, array_nd);

gsLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),
"dtype: %? shape: %?".format(array_dtype:string,shape));

return st.insert(new shared SymEntry(makeArrayFromBytes(msgArgs.payload, shape, array_dtype)));
}

proc array(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_dtype == bigint
{
return MsgTuple.error("Array creation from binary payload is not supported for bigint arrays");
}

proc makeArrayFromBytes(ref payload: bytes, shape: ?N*int, type t): [] t throws {
var size = 1;
for s in shape do size *= s;
overMemLimit(2*size*dtypeSize(dtype));
overMemLimit(2*size*typeSize(t));

var ret = makeDistArray((...shape), t),
localA = makeArrayFromPtr(payload.c_str():c_ptr(void):c_ptr(t), num_elts=size:uint);
if N == 1 {
ret = localA;
} else {
forall (i, a) in zip(localA.domain, localA) with (var agg = newDstAggregator(t)) do
agg.copy(ret[ret.domain.orderToIndex(i)], a);
}

return ret;
}

@arkouda.instantiateAndRegister()
proc arraySegString(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype): MsgTuple throws {
const size = msgArgs["size"].toScalar(int),
rname = st.nextName();

gsLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),
"dtype: %? shape: %? size: %i".format(dtype,shape,size));
"dtype: %? size: %?".format(array_dtype:string,size));

proc bytesToSymEntry(type t) throws {
var entry = createSymEntry((...shape), t);
var localA = makeArrayFromPtr(msgArgs.payload.c_str():c_ptr(void):c_ptr(t), num_elts=size:uint);
if nd == 1 {
entry.a = localA;
const a = makeArrayFromBytes(msgArgs.payload, (size,), array_dtype);
st.addEntry(rname, createSymEntry(a));

try {
st.checkTable(rname, "arrayMsg");
var g = st.lookup(rname);
if g.isAssignableTo(SymbolEntryType.TypedArraySymEntry){
var values = toSymEntry( (g:GenSymEntry), uint(8) );
var offsets = segmentedCalcOffsets(values.a, values.a.domain);
var oname = st.nextName();
var offsetsEntry = createSymEntry(offsets);
st.addEntry(oname, offsetsEntry);
const msg = "created " + st.attrib(oname) + "+created " + st.attrib(rname);
gsLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),msg);
return new MsgTuple(msg, MsgType.NORMAL);
} else {
forall (i, a) in zip(localA.domain, localA) with (var agg = newDstAggregator(t)) do
agg.copy(entry.a[entry.a.domain.orderToIndex(i)], a);
throw new Error("Unsupported Type %s".format(g.entryType));
}
st.addEntry(rname, entry);
}

if dtype == DType.Int64 {
bytesToSymEntry(int);
} else if dtype == DType.UInt64 {
bytesToSymEntry(uint);
} else if dtype == DType.Float64 {
bytesToSymEntry(real);
} else if dtype == DType.Bool {
bytesToSymEntry(bool);
} else if dtype == DType.UInt8 {
bytesToSymEntry(uint(8));
} else {
const msg = "Unhandled data type %s".format(msgArgs.getValueOf("dtype"));
} catch e: Error {
const msg = "Error creating offsets for SegString";
gsLogger.error(getModuleName(),getRoutineName(),getLineNumber(),msg);
return new MsgTuple(msg, MsgType.ERROR);
}

if asSegStr {
try {
st.checkTable(rname, "arrayMsg");
var g = st.lookup(rname);
if g.isAssignableTo(SymbolEntryType.TypedArraySymEntry){
var values = toSymEntry( (g:GenSymEntry), uint(8) );
var offsets = segmentedCalcOffsets(values.a, values.a.domain);
var oname = st.nextName();
var offsetsEntry = createSymEntry(offsets);
st.addEntry(oname, offsetsEntry);
const msg = "created " + st.attrib(oname) + "+created " + st.attrib(rname);
gsLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),msg);
return new MsgTuple(msg, MsgType.NORMAL);
} else {
throw new Error("Unsupported Type %s".format(g.entryType));
}
} catch e: Error {
const msg = "Error creating offsets for SegString";
gsLogger.error(getModuleName(),getRoutineName(),getLineNumber(),msg);
return new MsgTuple(msg, MsgType.ERROR);
}
}

const msg = "created " + st.attrib(rname);
gsLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),msg);
return new MsgTuple(msg, MsgType.NORMAL);
}

/**
Expand All @@ -119,49 +118,30 @@ module GenSymIO {
* Outputs the pdarray as a Numpy ndarray in the form of a
* Chapel Bytes object
*/
@arkouda.registerND
proc tondarrayMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
var arrayBytes: bytes;
var abstractEntry = st.lookup(msgArgs.getValueOf("array"));
if !abstractEntry.isAssignableTo(SymbolEntryType.TypedArraySymEntry) {
var errorMsg = "Error: Unhandled SymbolEntryType %s".format(abstractEntry.entryType);
gsLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
return MsgTuple.error(errorMsg);
}
var entry:borrowed GenSymEntry = abstractEntry: borrowed GenSymEntry;

overMemLimit(2 * entry.getSizeEstimate());

proc distArrToBytes(A: [?D] ?eltType) {
var ptr = allocate(eltType, D.size);
var localA = makeArrayFromPtr(ptr, D.size:uint);
if nd == 1 {
localA = A;
} else {
forall (i, a) in zip(localA.domain, localA) with (var agg = newSrcAggregator(eltType)) do
agg.copy(localA[i], A[D.orderToIndex(i)]);
}
const size = D.size*c_sizeof(eltType):int;
return bytes.createAdoptingBuffer(ptr:c_ptr(uint(8)), size, size);
}
@arkouda.instantiateAndRegister
proc tondarray(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_dtype != bigint
{
const array = st[msgArgs["array"]]: borrowed SymEntry(array_dtype, array_nd);

overMemLimit(2 * array.size * typeSize(array_dtype));

if entry.dtype == DType.Int64 {
arrayBytes = distArrToBytes(toSymEntry(entry, int, nd).a);
} else if entry.dtype == DType.UInt64 {
arrayBytes = distArrToBytes(toSymEntry(entry, uint, nd).a);
} else if entry.dtype == DType.Float64 {
arrayBytes = distArrToBytes(toSymEntry(entry, real, nd).a);
} else if entry.dtype == DType.Bool {
arrayBytes = distArrToBytes(toSymEntry(entry, bool, nd).a);
} else if entry.dtype == DType.UInt8 {
arrayBytes = distArrToBytes(toSymEntry(entry, uint(8), nd).a);
var ptr = allocate(array_dtype, array.size);
var localA = makeArrayFromPtr(ptr, array.size:uint);
if array_nd == 1 {
localA = array.a;
} else {
const errorMsg = "Error: Unhandled dtype %s".format(entry.dtype);
gsLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
return MsgTuple.error(errorMsg);
forall (i, a) in zip(localA.domain, localA) with (var agg = newSrcAggregator(array_dtype)) do
agg.copy(localA[i], array.a[array.a.domain.orderToIndex(i)]);
}
const size = array.size*c_sizeof(array_dtype):int;
return MsgTuple.payload(bytes.createAdoptingBuffer(ptr:c_ptr(uint(8)), size, size));
}

return MsgTuple.payload(arrayBytes);
proc tondarray(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_dtype == bigint
{
return MsgTuple.error("cannot create ndarray from bigint array");
}

/*
Expand Down
12 changes: 6 additions & 6 deletions src/Message.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ module Message {
this.payload = b"";
}

proc MsgTuple.init(msg: string, msgType: MsgType, msgFormat: MsgFormat, user = "", payload = b"") {
proc MsgTuple.init(msg: string, msgType: MsgType, msgFormat: MsgFormat, user = "", in payload = b"") {
this.msg = msg;
this.msgType = msgType;
this.msgFormat = msgFormat;
Expand Down Expand Up @@ -135,7 +135,7 @@ module Message {
);
}

proc type MsgTuple.payload(data: bytes): MsgTuple {
proc type MsgTuple.payload(in data: bytes): MsgTuple {
return new MsgTuple(
msg = "",
msgType = MsgType.NORMAL,
Expand Down Expand Up @@ -458,7 +458,7 @@ module Message {
this.payload = b"";
}

proc init(param_list: list(ParameterObj, parSafe=true)) {
proc init(param_list: list(ParameterObj, parSafe=true), in payload: bytes) {
// Intentionally initializes the param_list with `parSafe=false`.
// It would be initialized that way anyways due to the field
// declaration relying on the default value, this just makes it
Expand All @@ -467,7 +467,7 @@ module Message {
this.size = param_list.size;

this.param_list = param_list;
this.payload = b"";
this.payload = payload;
}

/*
Expand Down Expand Up @@ -557,13 +557,13 @@ module Message {
/*
Parse arguments formatted as json string into objects
*/
proc parseMessageArgs(json_str: string, size: int) throws {
proc parseMessageArgs(json_str: string, size: int, in payload = b"") throws {
var pArr = jsonToArray(json_str, string, size);
var param_list = new list(ParameterObj, parSafe=true);
forall j_str in pArr with (ref param_list) {
param_list.pushBack(parseParameter(j_str));
}
return new owned MessageArgs(param_list);
return new owned MessageArgs(param_list, payload);
}

/*
Expand Down
Loading

0 comments on commit d0db02f

Please sign in to comment.