diff --git a/mpi4py_fft/distarray.py b/mpi4py_fft/distarray.py index 9416856..a8815a5 100644 --- a/mpi4py_fft/distarray.py +++ b/mpi4py_fft/distarray.py @@ -50,17 +50,17 @@ def alignment(self): @property def global_shape(self): """Return global shape of ``self``""" - return self.shape[: self.rank] + self._p0.shape + return self.shape[:self.rank] + self._p0.shape @property def substart(self): """Return starting indices of local ``self`` array""" - return (0,) * self.rank + self._p0.substart + return (0,)*self.rank + self._p0.substart @property def subcomm(self): """Return tuple of subcommunicators for all axes of ``self``""" - return (MPI.COMM_SELF,) * self.rank + self._p0.subcomm + return (MPI.COMM_SELF,)*self.rank + self._p0.subcomm @property def commsizes(self): @@ -83,7 +83,7 @@ def dimensions(self): return len(self._p0.shape) @staticmethod - def getSubcomm(subcomm, global_shape, rank, alignment): + def get_subcomm(subcomm, global_shape, rank, alignment): if isinstance(subcomm, Subcomm): pass else: @@ -104,7 +104,7 @@ def getSubcomm(subcomm, global_shape, rank, alignment): return subcomm @classmethod - def getPencil(cls, subcomm, rank, global_shape, alignment): + def setup_pencil(cls, subcomm, rank, global_shape, alignment): sizes = [s.Get_size() for s in subcomm] if alignment is not None: assert isinstance(alignment, (int, cls.xp.integer)) @@ -185,20 +185,15 @@ def get(self, gslice): # global MPI. We create a global file with MPI, but then open it without # MPI and only on rank 0. import h5py - # TODO: can we use h5py to communicate the data without copying to cpu first when using cupy? - f = h5py.File("tmp.h5", "w", driver="mpio", comm=comm) + f = h5py.File('tmp.h5', 'w', driver="mpio", comm=comm) s = self.local_slice() sp = np.nonzero([isinstance(x, slice) for x in gslice])[0] sf = tuple(np.take(s, sp)) - f.require_dataset( - "data", shape=tuple(np.take(self.global_shape, sp)), dtype=self.dtype - ) + f.require_dataset('data', shape=tuple(np.take(self.global_shape, sp)), dtype=self.dtype) gslice = list(gslice) # We are required to check if the indices in si are on this processor - si = np.nonzero( - [isinstance(x, int) and not z == slice(None) for x, z in zip(gslice, s)] - )[0] + si = np.nonzero([isinstance(x, int) and not z == slice(None) for x, z in zip(gslice, s)])[0] on_this_proc = True for i in si: if gslice[i] >= s[i].start and gslice[i] < s[i].stop: @@ -206,15 +201,15 @@ def get(self, gslice): else: on_this_proc = False if on_this_proc: - data = self.asnumpy + data = self.asnumpy() f["data"][sf] = data[tuple(gslice)] f.close() c = None if comm.Get_rank() == 0: - h = h5py.File("tmp.h5", "r") - c = h["data"].__array__() + h = h5py.File('tmp.h5', 'r') + c = h['data'].__array__() h.close() - os.remove("tmp.h5") + os.remove('tmp.h5') return c def local_slice(self): @@ -250,11 +245,9 @@ def local_slice(self): (slice(0, 16, None), slice(7, 14, None), slice(0, 6, None)) (slice(0, 16, None), slice(7, 14, None), slice(6, 12, None)) """ - v = [ - slice(start, start + shape) - for start, shape in zip(self._p0.substart, self._p0.subshape) - ] - return tuple([slice(0, s) for s in self.shape[: self.rank]] + v) + v = [slice(start, start+shape) for start, shape in zip(self._p0.substart, + self._p0.subshape)] + return tuple([slice(0, s) for s in self.shape[:self.rank]] + v) def redistribute(self, axis=None, out=None): """Global redistribution of local ``self`` array @@ -283,7 +276,7 @@ def redistribute(self, axis=None, out=None): # Check if self is already aligned along axis. In that case just switch # axis of pencil (both axes are undivided) and return if axis is not None: - if self.commsizes[self.rank + axis] == 1: + if self.commsizes[self.rank+axis] == 1: self.pencil.axis = axis return self @@ -304,13 +297,11 @@ def redistribute(self, axis=None, out=None): p1, transfer = self.get_pencil_and_transfer(axis) if out is None: - out = type(self)( - self.global_shape, + out = type(self)(self.global_shape, subcomm=p1.subcomm, dtype=self.dtype, alignment=axis, - rank=self.rank, - ) + rank=self.rank) if self.rank == 0: transfer.forward(self, out) @@ -343,15 +334,8 @@ def get_pencil_and_transfer(self, axis): p1 = self._p0.pencil(axis) return p1, self._p0.transfer(p1, self.dtype) - def write( - self, - filename, - name="darray", - step=0, - global_slice=None, - domain=None, - as_scalar=False, - ): + def write(self, filename, name='darray', step=0, global_slice=None, + domain=None, as_scalar=False): """Write snapshot ``step`` of ``self`` to file ``filename`` Parameters @@ -384,14 +368,14 @@ def write( >>> u.write('h5file.h5', 'u', (slice(None), 4)) """ if isinstance(filename, str): - writer = HDF5File if filename.endswith(".h5") else NCFile - f = writer(filename, domain=domain, mode="a") + writer = HDF5File if filename.endswith('.h5') else NCFile + f = writer(filename, domain=domain, mode='a') elif isinstance(filename, FileBase): f = filename field = [self] if global_slice is None else [(self, global_slice)] f.write(step, {name: field}, as_scalar=as_scalar) - def read(self, filename, name="darray", step=0): + def read(self, filename, name='darray', step=0): """Read data ``name`` at index ``step``from file ``filename`` into ``self`` @@ -420,8 +404,8 @@ def read(self, filename, name="darray", step=0): """ if isinstance(filename, str): - writer = HDF5File if filename.endswith(".h5") else NCFile - f = writer(filename, mode="r") + writer = HDF5File if filename.endswith('.h5') else NCFile + f = writer(filename, mode='r') elif isinstance(filename, FileBase): f = filename f.read(self, name, step=step) @@ -460,29 +444,18 @@ class DistArray(DistArrayBase, np.ndarray): xp = np - def __new__( - cls, - global_shape, - subcomm=None, - val=None, - dtype=float, - buffer=None, - strides=None, - alignment=None, - rank=0, - ): - if len(global_shape[rank:]) < 2: # 1D case - obj = cls.xp.ndarray.__new__( - cls, global_shape, dtype=dtype, buffer=buffer, strides=strides - ) + def __new__(cls, global_shape, subcomm=None, val=None, dtype=float, + buffer=None, strides=None, alignment=None, rank=0): + if len(global_shape[rank:]) < 2: # 1D case + obj = cls.xp.ndarray.__new__(cls, global_shape, dtype=dtype, buffer=buffer, strides=strides) if buffer is None and isinstance(val, Number): obj.fill(val) obj._rank = rank obj._p0 = None return obj - subcomm = cls.getSubcomm(subcomm, global_shape, rank, alignment) - p0, subshape = cls.getPencil(subcomm, rank, global_shape, alignment) + subcomm = cls.get_subcomm(subcomm, global_shape, rank, alignment) + p0, subshape = cls.setup_pencil(subcomm, rank, global_shape, alignment) obj = cls.xp.ndarray.__new__(cls, subshape, dtype=dtype, buffer=buffer) if buffer is None and isinstance(val, Number): @@ -496,11 +469,9 @@ def v(self): """Return local ``self`` array as an ``ndarray`` object""" return self.__array__() - @property def asnumpy(self): return self - def newDistArray(pfft, forward_output=True, val=0, rank=0, view=False): """Return a new :class:`.DistArray` object for provided :class:`.PFFT` object @@ -541,29 +512,21 @@ def newDistArray(pfft, forward_output=True, val=0, rank=0, view=False): dtype = pfft.forward.output_array.dtype else: dtype = pfft.forward.input_array.dtype - global_shape = (len(global_shape),) * rank + global_shape + global_shape = (len(global_shape),)*rank + global_shape if pfft.xfftn[0].backend in ["cupy", "cupyx-scipy"]: from mpi4py_fft.distarrayCuPy import DistArrayCuPy as darraycls else: darraycls = DistArray - z = darraycls( - global_shape, - subcomm=p0.subcomm, - val=val, - dtype=dtype, - alignment=p0.axis, - rank=rank, - ) + z = darraycls(global_shape, subcomm=p0.subcomm, val=val, dtype=dtype, + alignment=p0.axis, rank=rank) return z.v if view else z - -def Function(*args, **kwargs): # pragma: no cover +def Function(*args, **kwargs): #pragma: no cover import warnings - warnings.warn("Function() is deprecated; use newDistArray().", FutureWarning) - if "tensor" in kwargs: - kwargs["rank"] = 1 - del kwargs["tensor"] + if 'tensor' in kwargs: + kwargs['rank'] = 1 + del kwargs['tensor'] return newDistArray(*args, **kwargs) diff --git a/mpi4py_fft/distarrayCuPy.py b/mpi4py_fft/distarrayCuPy.py index d3ead89..7374043 100644 --- a/mpi4py_fft/distarrayCuPy.py +++ b/mpi4py_fft/distarrayCuPy.py @@ -59,8 +59,8 @@ def __new__( obj._p0 = None return obj - subcomm = cls.getSubcomm(subcomm, global_shape, rank, alignment) - p0, subshape = cls.getPencil(subcomm, rank, global_shape, alignment) + subcomm = cls.get_subcomm(subcomm, global_shape, rank, alignment) + p0, subshape = cls.setup_pencil(subcomm, rank, global_shape, alignment) obj = cls.xp.ndarray.__new__(cls, subshape, dtype=dtype, memptr=memptr) if memptr is None and isinstance(val, Number): @@ -76,7 +76,6 @@ def get(self, *args, **kwargs): else: return cp.ndarray.get(self, *args, **kwargs) - @property def asnumpy(self): """Copy the array to CPU""" return self.get() diff --git a/mpi4py_fft/pencil.py b/mpi4py_fft/pencil.py index 3a8a7b2..4cf3ca1 100644 --- a/mpi4py_fft/pencil.py +++ b/mpi4py_fft/pencil.py @@ -246,27 +246,27 @@ def Alltoallw(self, arrayA, subtypesA, arrayB, subtypesB): iscomplex = cp.iscomplexobj(arrayA) NCCL_dtype, real_dtype = self.get_nccl_and_real_dtypes(arrayA) - for i in range(size): - for j in range(size): + for recv_rank in range(size): + for send_rank in range(size): - if rank == i: - local_slice, shape = self.get_slice_and_shape(subtypesB[j]) + if rank == recv_rank: + local_slice, shape = self.get_slice_and_shape(subtypesB[send_rank]) buff = self.get_buffer(shape, iscomplex, real_dtype) - if i == j: - send_slice, _ = self.get_slice_and_shape(subtypesA[i]) + if recv_rank == send_rank: + send_slice, _ = self.get_slice_and_shape(subtypesA[recv_rank]) self.fill_buffer(buff, arrayA, send_slice, iscomplex) else: - comm.recv(buff.data.ptr, buff.size, NCCL_dtype, j, stream) + comm.recv(buff.data.ptr, buff.size, NCCL_dtype, send_rank, stream) self.unpack_buffer(buff, arrayB, local_slice, iscomplex) - elif rank == j: - local_slice, shape = self.get_slice_and_shape(subtypesA[i]) + elif rank == send_rank: + local_slice, shape = self.get_slice_and_shape(subtypesA[recv_rank]) buff = self.get_buffer(shape, iscomplex, real_dtype) self.fill_buffer(buff, arrayA, local_slice, iscomplex) - comm.send(buff.data.ptr, buff.size, NCCL_dtype, i, stream) + comm.send(buff.data.ptr, buff.size, NCCL_dtype, recv_rank, stream) @staticmethod