diff --git a/mpi4py_fft/pencil.py b/mpi4py_fft/pencil.py
index 2c4737e..3cf5c16 100644
--- a/mpi4py_fft/pencil.py
+++ b/mpi4py_fft/pencil.py
@@ -269,6 +269,8 @@ class NCCLTransfer(Transfer):
     """
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.forward_graph = None
+        self.backward_graph = None
         from cupy.cuda import nccl
         self.comm_nccl = nccl.NcclCommunicator(self.comm.size, self.comm.bcast(nccl.get_unique_id(), root=0), self.comm.rank)
 
@@ -286,51 +288,79 @@ def __init__(self, *args, **kwargs):
             raise NotImplementedError(f'Don\'t know what NCCL dtype to use to send data of dtype {self.dtype}!')
         self.count_modifier = 2 if 'complex' in str(self.dtype) else 1
 
-    def Alltoallw(self, arrayA, subtypesA, arrayB, subtypesB):
+    def backward(self, arrayB, arrayA):
+        """Global redistribution from arrayB to arrayA
+
+        Parameters
+        ----------
+        arrayB : array
+            Array of shape subshapeB, containing data to be redistributed
+        arrayA : array
+            Array of shape subshapeA, for receiving data
+
+        """
+        self.backward_graph = self.Alltoallw(arrayB, self._subtypesB, arrayA, self._subtypesA, graph=self.backward_graph)
+
+    def forward(self, arrayA, arrayB):
+        """Global redistribution from arrayA to arrayB
+
+        Parameters
+        ----------
+        arrayA : array
+            Array of shape subshapeA, containing data to be redistributed
+        arrayB : array
+            Array of shape subshapeB, for receiving data
+        """
+        self.forward_graph = self.Alltoallw(arrayA, self._subtypesA, arrayB, self._subtypesB, graph=self.forward_graph)
+
+    def Alltoallw(self, arrayA, subtypesA, arrayB, subtypesB, graph=None):
         """
         Redistribute arrayA to arrayB.
         """
         import cupy as cp
         rank, size, comm = self.comm.rank, self.comm.size, self.comm_nccl
-        stream = cp.cuda.get_current_stream()
 
-        # initialize dictionaries required to overlap sends
-        recvbufs = {}
-        sendbufs = {}
-        sliceBs = {}
+        # record to a graph if we haven't already done so
+        if graph is None:
+            stream = cp.cuda.Stream(non_blocking=True)
+            with stream:
+                stream.begin_capture()
 
-        # perform all sends and receives in a single kernel to allow overlap
-        cp.cuda.nccl.groupStart()
-        for i in range(1, size + 1):
+                # initialize dictionaries required to overlap sends
+                recvbufs = {}
+                sendbufs = {}
+                sliceBs = {}
 
-            send_to = (rank + i) % size
-            recv_from = (rank -i + size) % size
+                # perform all sends and receives in a single kernel to allow overlap
+                cp.cuda.nccl.groupStart()
+                for i in range(1, size + 1):
 
-            sliceA = get_slice(subtypesA[send_to])
-            sliceBs[i] = get_slice(subtypesB[recv_from])
+                    send_to = (rank + i) % size
+                    recv_from = (rank - i + size) % size
 
-            recvbufs[i] = cp.ascontiguousarray(arrayB[sliceBs[i]])
-            sendbufs[i] = cp.ascontiguousarray(arrayA[sliceA])
+                    sliceA = get_slice(subtypesA[send_to])
+                    sliceBs[i] = get_slice(subtypesB[recv_from])
 
-            comm.recv(recvbufs[i].data.ptr, recvbufs[i].size * self.count_modifier, self.NCCL_dtype, recv_from, stream.ptr)
-            comm.send(sendbufs[i].data.ptr, sendbufs[i].size * self.count_modifier, self.NCCL_dtype, send_to, stream.ptr)
-        cp.cuda.nccl.groupEnd()
+                    recvbufs[i] = cp.ascontiguousarray(arrayB[sliceBs[i]])
+                    sendbufs[i] = cp.ascontiguousarray(arrayA[sliceA])
 
-        # unpack the buffers concurrently in different streams
-        streams = {key: cp.cuda.Stream() for key in recvbufs.keys()}
-        events = {}
-        for i in recvbufs.keys():
-            with streams[i]:
-                cp.copyto(arrayB[sliceBs[i]], recvbufs[i])
-                events[i] = streams[i].record()
+                    comm.recv(recvbufs[i].data.ptr, recvbufs[i].size * self.count_modifier, self.NCCL_dtype, recv_from, stream.ptr)
+                    comm.send(sendbufs[i].data.ptr, sendbufs[i].size * self.count_modifier, self.NCCL_dtype, send_to, stream.ptr)
+                cp.cuda.nccl.groupEnd()
 
-        for i in events.keys():
-            stream.wait_event(events[i])
+                for i in recvbufs.keys():
+                    cp.copyto(arrayB[sliceBs[i]], recvbufs[i])
+                graph = stream.end_capture()
+
+        graph.launch(stream=cp.cuda.get_current_stream())
+        return graph
 
     def destroy(self):
-        super().destroy()
+        del self.forward_graph
+        del self.backward_graph
         self.comm_nccl.destroy()
+        super().destroy()
 
 
 class Pencil(object):
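
Note on the pattern: the patch replaces the per-call stream/event bookkeeping in
Alltoallw with a capture-once, replay-many CUDA graph, cached separately for the
forward and backward directions. The sketch below is a minimal, single-GPU
illustration of that pattern, assuming CuPy built with CUDA graph support
(Stream.begin_capture / end_capture, Graph.launch); GraphedOp and the arithmetic
kernels are hypothetical stand-ins for the NCCL send/recv sequence, not part of
this patch.

    import cupy as cp

    class GraphedOp:
        """Capture out = 2*x + y into a CUDA graph once, then replay it."""
        def __init__(self):
            self.graph = None

        def __call__(self, x, y, out):
            if self.graph is None:
                stream = cp.cuda.Stream(non_blocking=True)
                with stream:
                    stream.begin_capture()
                    # kernel launches here are recorded into the graph, not
                    # executed; out= avoids allocations during capture
                    cp.multiply(x, 2.0, out=out)
                    cp.add(out, y, out=out)
                    self.graph = stream.end_capture()
            # replay the recorded kernel sequence on the current stream
            self.graph.launch(stream=cp.cuda.get_current_stream())

    op = GraphedOp()
    x, y, out = cp.arange(4.0), cp.ones(4), cp.empty(4)
    op(x, y, out)   # first call: capture, then launch
    x[:] = 5.0      # replay reads the captured device pointers,
    op(x, y, out)   # so inputs must be updated in place

The same constraint applies to the graphs cached in forward/backward above:
replaying a graph re-runs the captured device pointers, so each call must pass
the same arrayA/arrayB buffers that were captured on the first call, with new
data written into them in place.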