Skip to content

Commit

Permalink
Cosmetic changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas committed Feb 13, 2024
1 parent c3c20ff commit 405fb4d
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 25 deletions.
10 changes: 5 additions & 5 deletions mpi4py_fft/distarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def local_slice(self):
(slice(0, 16, None), slice(7, 14, None), slice(6, 12, None))
"""
v = [slice(start, start+shape) for start, shape in zip(self._p0.substart,
self._p0.subshape)]
self._p0.subshape)]
return tuple([slice(0, s) for s in self.shape[:self.rank]] + v)

def redistribute(self, axis=None, out=None):
Expand Down Expand Up @@ -298,10 +298,10 @@ def redistribute(self, axis=None, out=None):
p1, transfer = self.get_pencil_and_transfer(axis)
if out is None:
out = type(self)(self.global_shape,
subcomm=p1.subcomm,
dtype=self.dtype,
alignment=axis,
rank=self.rank)
subcomm=p1.subcomm,
dtype=self.dtype,
alignment=axis,
rank=self.rank)

if self.rank == 0:
transfer.forward(self, out)
Expand Down
6 changes: 3 additions & 3 deletions mpi4py_fft/libfft.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _Xfftn_plan_mkl(shape, axes, dtype, transforms, options): #pragma: no cover

def _Xfftn_plan_cupyx_scipy(shape, axes, dtype, transforms, options):
import cupy as cp
import cupyx.scipy.fftpack as cufft
import cupyx.scipy.fft as cufft

transforms = {} if transforms is None else transforms
if tuple(axes) in transforms:
Expand All @@ -168,8 +168,8 @@ def _Xfftn_plan_cupyx_scipy(shape, axes, dtype, transforms, options):
V = plan_fwd(U, s=s, axes=axes)
V = cp.array(V)
M = np.prod(s)
return (_Yfftn_wrap(plan_fwd, U, V, 1, {'shape': s, 'axes': axes, 'overwrite_x': True}),
_Yfftn_wrap(plan_bck, V, U, M, {'shape': s, 'axes': axes, 'overwrite_x': True}))
return (_Yfftn_wrap(plan_fwd, U, V, 1, {'s': s, 'axes': axes, 'overwrite_x': True}),
_Yfftn_wrap(plan_bck, V, U, M, {'s': s, 'axes': axes, 'overwrite_x': True}))

def _Xfftn_plan_scipy(shape, axes, dtype, transforms, options):

Expand Down
2 changes: 0 additions & 2 deletions mpi4py_fft/pencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,6 @@ def transfer(self, pencil, dtype):
transfer_class = Transfer
elif self.backend == 'NCCL':
transfer_class = NCCLTransfer
elif self.backend == 'CUDAMemCpy':
transfer_class = CUDAMemCpy
elif self.backend == 'customMPI':
transfer_class = CustomMPITransfer
else:
Expand Down
41 changes: 26 additions & 15 deletions tests/test_transfer_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,17 @@
from mpi4py_fft.pencil import Transfer, CustomMPITransfer, Pencil, Subcomm
import numpy as np

transfer_classes = [CustomMPITransfer]
xps = {CustomMPITransfer: np}

try:
import cupy as cp
from mpi4py_fft.pencil import NCCLTransfer
transfer_classes += [NCCLTransfer]
xps[NCCLTransfer] = cp
except ModuleNotFoundError:
pass


def get_args(comm, shape, dtype):
subcomm = Subcomm(comm=comm, dims=None)
Expand All @@ -20,43 +31,43 @@ def get_args(comm, shape, dtype):
return kwargs


def get_arrays(kwargs):
arrayA = np.zeros(shape=kwargs['subshapeA'], dtype=kwargs['dtype'])
arrayB = np.zeros(shape=kwargs['subshapeB'], dtype=kwargs['dtype'])
def get_arrays(kwargs, xp):
arrayA = xp.zeros(shape=kwargs['subshapeA'], dtype=kwargs['dtype'])
arrayB = xp.zeros(shape=kwargs['subshapeB'], dtype=kwargs['dtype'])

arrayA[:] = np.random.random(arrayA.shape).astype(arrayA.dtype)
arrayA[:] = xp.random.random(arrayA.shape).astype(arrayA.dtype)
return arrayA, arrayB


def single_test_all_to_allw(transfer_class, shape, dtype, comm=None):
def single_test_all_to_allw(transfer_class, shape, dtype, comm=None, xp=None):
comm = comm if comm else MPI.COMM_WORLD
kwargs = get_args(comm, shape, dtype)
arrayA, arrayB = get_arrays(kwargs)
arrayA, arrayB = get_arrays(kwargs, xp)
arrayB_ref = arrayB.copy()

transfer = transfer_class(**kwargs)
reference_transfer = Transfer(**kwargs)

transfer.Alltoallw(arrayA, transfer._subtypesA, arrayB, transfer._subtypesB)
reference_transfer.Alltoallw(arrayA, transfer._subtypesA, arrayB_ref, transfer._subtypesB)
assert np.allclose(arrayB, arrayB_ref), f'Did not get the same result from `alltoallw` with {transfer_class.__name__} transfer class as MPI implementation on rank {comm.rank}!'
assert xp.allclose(arrayB, arrayB_ref), f'Did not get the same result from `alltoallw` with {transfer_class.__name__} transfer class as MPI implementation on rank {comm.rank}!'

comm.Barrier()
if comm.rank == 0:
print(f'{transfer_class.__name__} passed alltoallw test with shape {shape} and dtype {dtype}')


def single_test_forward_backward(transfer_class, shape, dtype, comm=None):
def single_test_forward_backward(transfer_class, shape, dtype, comm=None, xp=None):
comm = comm if comm else MPI.COMM_WORLD
kwargs = get_args(comm, shape, dtype)
arrayA, arrayB = get_arrays(kwargs)
arrayA, arrayB = get_arrays(kwargs, xp)
arrayA_ref = arrayA.copy()

transfer = transfer_class(**kwargs)

transfer.forward(arrayA, arrayB)
transfer.backward(arrayB, arrayA)
assert np.allclose(arrayA, arrayA_ref), f'Did not get the same result when transferring back and forth with {transfer_class.__name__} transfer class on rank {comm.rank}!'
assert xp.allclose(arrayA, arrayA_ref), f'Did not get the same result when transferring back and forth with {transfer_class.__name__} transfer class on rank {comm.rank}!'

comm.Barrier()
if comm.rank == 0:
Expand All @@ -67,14 +78,14 @@ def test_transfer_class():
dims = (2, 3)
sizes = (7, 8, 9, 128)
dtypes = 'fFdD'
transfer_class = CustomMPITransfer

shapes = [[size] * dim for size in sizes for dim in dims] + [[32, 256, 129]]

for shape in shapes:
for dtype in dtypes:
single_test_all_to_allw(transfer_class, shape, dtype)
single_test_forward_backward(transfer_class, shape, dtype)
for transfer_class in transfer_classes:
for shape in shapes:
for dtype in dtypes:
single_test_all_to_allw(transfer_class, shape, dtype, xp=xps[transfer_class])
single_test_forward_backward(transfer_class, shape, dtype, xp=xps[transfer_class])


if __name__ == '__main__':
Expand Down

0 comments on commit 405fb4d

Please sign in to comment.