Skip to content

Commit

Permalink
Changed Alltoallw communication pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas committed Dec 20, 2023
1 parent d0a493e commit 6b47e36
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 21 deletions.
21 changes: 17 additions & 4 deletions mpi4py_fft/libfft.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,24 @@ def _Xfftn_plan_cupy(shape, axes, dtype, transforms, options):
plan_fwd, plan_bck = transforms[tuple(axes)]
else:
if cp.issubdtype(dtype, cp.floating):
plan_fwd = cp.fft.rfftn
plan_bck = cp.fft.irfftn
_plan_fwd = cp.fft.rfftn
_plan_bck = cp.fft.irfftn
else:
plan_fwd = cp.fft.fftn
plan_bck = cp.fft.ifftn
_plan_fwd = cp.fft.fftn
_plan_bck = cp.fft.ifftn

stream = cp.cuda.stream.Stream()
def execute_in_stream(function, *args, **kwargs):
with stream:
result = function(*args, **kwargs)
stream.synchronize()
return result

def plan_fwd(*args, **kwargs):
return execute_in_stream(_plan_fwd, *args, **kwargs)

def plan_bck(*args, **kwargs):
return execute_in_stream(_plan_bck, *args, **kwargs)

s = tuple(np.take(shape, axes))
U = cp.array(fftw.aligned(shape, dtype=dtype)) # TODO: avoid going via CPU
Expand Down
46 changes: 29 additions & 17 deletions mpi4py_fft/pencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,32 +242,44 @@ def Alltoallw(self, arrayA, subtypesA, arrayB, subtypesB):
assert self.comm.rank == self.comm_nccl.rank_id(), f'The structure of the communicator has changed unexpectedly'

rank, size, comm = self.comm.rank, self.comm.size, self.comm_nccl
stream = cp.cuda.Stream.null.ptr
iscomplex = cp.iscomplexobj(arrayA)
NCCL_dtype, real_dtype = self.get_nccl_and_real_dtypes(arrayA)

for recv_rank in range(size):
for send_rank in range(size):
send_stream = cp.cuda.Stream(non_blocking=False)
recv_stream = cp.cuda.Stream(non_blocking=False)

if rank == recv_rank:
local_slice, shape = self.get_slice_and_shape(subtypesB[send_rank])
buff = self.get_buffer(shape, iscomplex, real_dtype)
def send(array, subtype, send_to, iscomplex, stream):
local_slice, shape = self.get_slice_and_shape(subtype)
buff = self.get_buffer(shape, iscomplex, real_dtype)
self.fill_buffer(buff, array, local_slice, iscomplex)
comm.send(buff.data.ptr, buff.size, NCCL_dtype, send_to, stream.ptr)

if recv_rank == send_rank:
send_slice, _ = self.get_slice_and_shape(subtypesA[recv_rank])
self.fill_buffer(buff, arrayA, send_slice, iscomplex)
else:
comm.recv(buff.data.ptr, buff.size, NCCL_dtype, send_rank, stream)
for i in range(size):
send_to = (rank + i) % size
recv_from = (rank -i + size) % size

self.unpack_buffer(buff, arrayB, local_slice, iscomplex)
if send_to > rank:
with send_stream:
send(arrayA, subtypesA[send_to], send_to, iscomplex, send_stream)

elif rank == send_rank:
local_slice, shape = self.get_slice_and_shape(subtypesA[recv_rank])
buff = self.get_buffer(shape, iscomplex, real_dtype)
self.fill_buffer(buff, arrayA, local_slice, iscomplex)
with recv_stream:
local_slice, shape = self.get_slice_and_shape(subtypesB[recv_from])
buff = self.get_buffer(shape, iscomplex, real_dtype)

comm.send(buff.data.ptr, buff.size, NCCL_dtype, recv_rank, stream)
if recv_from == rank:
send_slice, _ = self.get_slice_and_shape(subtypesA[send_to])
self.fill_buffer(buff, arrayA, send_slice, iscomplex)
else:
comm.recv(buff.data.ptr, buff.size, NCCL_dtype, recv_from, recv_stream.ptr)

self.unpack_buffer(buff, arrayB, local_slice, iscomplex)

if send_to < rank:
with send_stream:
send(arrayA, subtypesA[send_to], send_to, iscomplex, send_stream)

send_stream.synchronize()
recv_stream.synchronize()

@staticmethod
def get_slice_and_shape(subtype):
Expand Down

0 comments on commit 6b47e36

Please sign in to comment.