Skip to content

Commit

Permalink
Use CUDA graphs in NCCL Alltoallw
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas committed Feb 29, 2024
1 parent b9ac0f7 commit 77b96f3
Showing 1 changed file with 58 additions and 28 deletions.
86 changes: 58 additions & 28 deletions mpi4py_fft/pencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,8 @@ class NCCLTransfer(Transfer):
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.forward_graph = None
self.backward_graph = None

from cupy.cuda import nccl
self.comm_nccl = nccl.NcclCommunicator(self.comm.size, self.comm.bcast(nccl.get_unique_id(), root=0), self.comm.rank)
Expand All @@ -286,51 +288,79 @@ def __init__(self, *args, **kwargs):
raise NotImplementedError(f'Don\'t know what NCCL dtype to use to send data of dtype {self.dtype}!')
self.count_modifier = 2 if 'complex' in str(self.dtype) else 1

def Alltoallw(self, arrayA, subtypesA, arrayB, subtypesB):
def backward(self, arrayB, arrayA):
    """Global redistribution from arrayB to arrayA

    The exchange is captured into a graph by ``Alltoallw`` on the first call
    and replayed afterwards. The graph is captured again whenever the
    address or layout of either array differs from the previous call.

    Parameters
    ----------
    arrayB : array
        Array of shape subshapeB, containing data to be redistributed
    arrayA : array
        Array of shape subshapeA, for receiving data
    """
    # A replayed graph is launched without the arrays being handed to it, so
    # it can only act on the buffers that were seen while it was captured.
    # Reusing it for other buffers would silently move the wrong data.
    key = (arrayB.data.ptr, arrayB.shape, arrayB.strides,
           arrayA.data.ptr, arrayA.shape, arrayA.strides)
    if getattr(self, '_backward_graph_key', None) != key:
        self.backward_graph = None
        self._backward_graph_key = key

    self.backward_graph = self.Alltoallw(arrayB, self._subtypesB, arrayA, self._subtypesA, graph=self.backward_graph)

def forward(self, arrayA, arrayB):
    """Global redistribution from arrayA to arrayB

    The exchange is captured into a graph by ``Alltoallw`` on the first call
    and replayed afterwards. The graph is captured again whenever the
    address or layout of either array differs from the previous call.

    Parameters
    ----------
    arrayA : array
        Array of shape subshapeA, containing data to be redistributed
    arrayB : array
        Array of shape subshapeB, for receiving data
    """
    # A replayed graph is launched without the arrays being handed to it, so
    # it can only act on the buffers that were seen while it was captured.
    # Reusing it for other buffers would silently move the wrong data.
    key = (arrayA.data.ptr, arrayA.shape, arrayA.strides,
           arrayB.data.ptr, arrayB.shape, arrayB.strides)
    if getattr(self, '_forward_graph_key', None) != key:
        self.forward_graph = None
        self._forward_graph_key = key

    self.forward_graph = self.Alltoallw(arrayA, self._subtypesA, arrayB, self._subtypesB, graph=self.forward_graph)

def Alltoallw(self, arrayA, subtypesA, arrayB, subtypesB, graph=None):
    """
    Redistribute arrayA to arrayB.

    The exchange is built from pairwise NCCL sends and receives. When no
    graph is supplied, the whole exchange is first captured into a graph on
    a dedicated stream; in either case the graph is then launched on the
    current stream.

    Parameters
    ----------
    arrayA : array
        Array containing the data to be redistributed
    subtypesA : sequence
        Entry ``r`` is turned into an index of arrayA by ``get_slice`` and
        selects the data sent to rank ``r``
    arrayB : array
        Array receiving the data
    subtypesB : sequence
        Entry ``r`` is turned into an index of arrayB by ``get_slice`` and
        selects where the data received from rank ``r`` is stored
    graph : optional
        Graph returned by an earlier call. If None, a new graph is captured.

    Returns
    -------
    graph
        The graph that was launched. Pass it back in to skip the capture.
    """
    import cupy as cp
    rank, size, comm = self.comm.rank, self.comm.size, self.comm_nccl

    # record to a graph if we haven't already done so
    if graph is None:
        # NOTE(review): the recorded operations refer to arrayA / arrayB and
        # to the staging buffers below as they exist right now. Confirm that
        # callers only replay the graph for these same arrays and that the
        # staging memory stays reserved for as long as the graph is used.
        stream = cp.cuda.Stream(non_blocking=True)
        with stream:
            stream.begin_capture()

            # dictionaries keep every buffer referenced until all sends and
            # receives have been issued
            recvbufs = {}
            sendbufs = {}
            sliceBs = {}

            # perform all sends and receives in a single kernel to allow overlap
            cp.cuda.nccl.groupStart()
            for i in range(1, size + 1):
                # i == size pairs this rank with itself
                send_to = (rank + i) % size
                recv_from = (rank - i + size) % size

                sliceA = get_slice(subtypesA[send_to])
                sliceBs[i] = get_slice(subtypesB[recv_from])

                # NCCL is handed a raw pointer and an element count, so the
                # data has to be contiguous
                recvbufs[i] = cp.ascontiguousarray(arrayB[sliceBs[i]])
                sendbufs[i] = cp.ascontiguousarray(arrayA[sliceA])

                # count_modifier scales the element count (2 for complex dtypes)
                comm.recv(recvbufs[i].data.ptr, recvbufs[i].size * self.count_modifier, self.NCCL_dtype, recv_from, stream.ptr)
                comm.send(sendbufs[i].data.ptr, sendbufs[i].size * self.count_modifier, self.NCCL_dtype, send_to, stream.ptr)
            cp.cuda.nccl.groupEnd()

            # unpack the receive buffers into their slices of arrayB
            for i, recvbuf in recvbufs.items():
                cp.copyto(arrayB[sliceBs[i]], recvbuf)

            graph = stream.end_capture()

    graph.launch(stream=cp.cuda.get_current_stream())
    return graph

def destroy(self):
    """Release the cached graphs and the NCCL communicator, then run the
    parent class clean-up.
    """
    # drop the graphs before the communicator they were recorded with
    del self.forward_graph
    del self.backward_graph
    self.comm_nccl.destroy()
    # the parent clean-up runs exactly once, after the NCCL resources are gone
    super().destroy()


class Pencil(object):
Expand Down

0 comments on commit 77b96f3

Please sign in to comment.