Skip to content

Commit

Permalink
Minor refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas committed Dec 18, 2023
1 parent e5d4b56 commit d0a493e
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 88 deletions.
113 changes: 38 additions & 75 deletions mpi4py_fft/distarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,17 @@ def alignment(self):
@property
def global_shape(self):
"""Return global shape of ``self``"""
return self.shape[: self.rank] + self._p0.shape
return self.shape[:self.rank] + self._p0.shape

@property
def substart(self):
"""Return starting indices of local ``self`` array"""
return (0,) * self.rank + self._p0.substart
return (0,)*self.rank + self._p0.substart

@property
def subcomm(self):
"""Return tuple of subcommunicators for all axes of ``self``"""
return (MPI.COMM_SELF,) * self.rank + self._p0.subcomm
return (MPI.COMM_SELF,)*self.rank + self._p0.subcomm

@property
def commsizes(self):
Expand All @@ -83,7 +83,7 @@ def dimensions(self):
return len(self._p0.shape)

@staticmethod
def getSubcomm(subcomm, global_shape, rank, alignment):
def get_subcomm(subcomm, global_shape, rank, alignment):
if isinstance(subcomm, Subcomm):
pass
else:
Expand All @@ -104,7 +104,7 @@ def getSubcomm(subcomm, global_shape, rank, alignment):
return subcomm

@classmethod
def getPencil(cls, subcomm, rank, global_shape, alignment):
def setup_pencil(cls, subcomm, rank, global_shape, alignment):
sizes = [s.Get_size() for s in subcomm]
if alignment is not None:
assert isinstance(alignment, (int, cls.xp.integer))
Expand Down Expand Up @@ -185,36 +185,31 @@ def get(self, gslice):
# global MPI. We create a global file with MPI, but then open it without
# MPI and only on rank 0.
import h5py

# TODO: can we use h5py to communicate the data without copying to cpu first when using cupy?
f = h5py.File("tmp.h5", "w", driver="mpio", comm=comm)
f = h5py.File('tmp.h5', 'w', driver="mpio", comm=comm)
s = self.local_slice()
sp = np.nonzero([isinstance(x, slice) for x in gslice])[0]
sf = tuple(np.take(s, sp))
f.require_dataset(
"data", shape=tuple(np.take(self.global_shape, sp)), dtype=self.dtype
)
f.require_dataset('data', shape=tuple(np.take(self.global_shape, sp)), dtype=self.dtype)
gslice = list(gslice)
# We are required to check if the indices in si are on this processor
si = np.nonzero(
[isinstance(x, int) and not z == slice(None) for x, z in zip(gslice, s)]
)[0]
si = np.nonzero([isinstance(x, int) and not z == slice(None) for x, z in zip(gslice, s)])[0]
on_this_proc = True
for i in si:
if gslice[i] >= s[i].start and gslice[i] < s[i].stop:
gslice[i] -= s[i].start
else:
on_this_proc = False
if on_this_proc:
data = self.asnumpy
data = self.asnumpy()
f["data"][sf] = data[tuple(gslice)]
f.close()
c = None
if comm.Get_rank() == 0:
h = h5py.File("tmp.h5", "r")
c = h["data"].__array__()
h = h5py.File('tmp.h5', 'r')
c = h['data'].__array__()
h.close()
os.remove("tmp.h5")
os.remove('tmp.h5')
return c

def local_slice(self):
Expand Down Expand Up @@ -250,11 +245,9 @@ def local_slice(self):
(slice(0, 16, None), slice(7, 14, None), slice(0, 6, None))
(slice(0, 16, None), slice(7, 14, None), slice(6, 12, None))
"""
v = [
slice(start, start + shape)
for start, shape in zip(self._p0.substart, self._p0.subshape)
]
return tuple([slice(0, s) for s in self.shape[: self.rank]] + v)
v = [slice(start, start+shape) for start, shape in zip(self._p0.substart,
self._p0.subshape)]
return tuple([slice(0, s) for s in self.shape[:self.rank]] + v)

def redistribute(self, axis=None, out=None):
"""Global redistribution of local ``self`` array
Expand Down Expand Up @@ -283,7 +276,7 @@ def redistribute(self, axis=None, out=None):
# Check if self is already aligned along axis. In that case just switch
# axis of pencil (both axes are undivided) and return
if axis is not None:
if self.commsizes[self.rank + axis] == 1:
if self.commsizes[self.rank+axis] == 1:
self.pencil.axis = axis
return self

Expand All @@ -304,13 +297,11 @@ def redistribute(self, axis=None, out=None):

p1, transfer = self.get_pencil_and_transfer(axis)
if out is None:
out = type(self)(
self.global_shape,
out = type(self)(self.global_shape,
subcomm=p1.subcomm,
dtype=self.dtype,
alignment=axis,
rank=self.rank,
)
rank=self.rank)

if self.rank == 0:
transfer.forward(self, out)
Expand Down Expand Up @@ -343,15 +334,8 @@ def get_pencil_and_transfer(self, axis):
p1 = self._p0.pencil(axis)
return p1, self._p0.transfer(p1, self.dtype)

def write(
self,
filename,
name="darray",
step=0,
global_slice=None,
domain=None,
as_scalar=False,
):
def write(self, filename, name='darray', step=0, global_slice=None,
domain=None, as_scalar=False):
"""Write snapshot ``step`` of ``self`` to file ``filename``
Parameters
Expand Down Expand Up @@ -384,14 +368,14 @@ def write(
>>> u.write('h5file.h5', 'u', (slice(None), 4))
"""
if isinstance(filename, str):
writer = HDF5File if filename.endswith(".h5") else NCFile
f = writer(filename, domain=domain, mode="a")
writer = HDF5File if filename.endswith('.h5') else NCFile
f = writer(filename, domain=domain, mode='a')
elif isinstance(filename, FileBase):
f = filename
field = [self] if global_slice is None else [(self, global_slice)]
f.write(step, {name: field}, as_scalar=as_scalar)

def read(self, filename, name="darray", step=0):
def read(self, filename, name='darray', step=0):
"""Read data ``name`` at index ``step``from file ``filename`` into
``self``
Expand Down Expand Up @@ -420,8 +404,8 @@ def read(self, filename, name="darray", step=0):
"""
if isinstance(filename, str):
writer = HDF5File if filename.endswith(".h5") else NCFile
f = writer(filename, mode="r")
writer = HDF5File if filename.endswith('.h5') else NCFile
f = writer(filename, mode='r')
elif isinstance(filename, FileBase):
f = filename
f.read(self, name, step=step)
Expand Down Expand Up @@ -460,29 +444,18 @@ class DistArray(DistArrayBase, np.ndarray):

xp = np

def __new__(
cls,
global_shape,
subcomm=None,
val=None,
dtype=float,
buffer=None,
strides=None,
alignment=None,
rank=0,
):
if len(global_shape[rank:]) < 2: # 1D case
obj = cls.xp.ndarray.__new__(
cls, global_shape, dtype=dtype, buffer=buffer, strides=strides
)
def __new__(cls, global_shape, subcomm=None, val=None, dtype=float,
buffer=None, strides=None, alignment=None, rank=0):
if len(global_shape[rank:]) < 2: # 1D case
obj = cls.xp.ndarray.__new__(cls, global_shape, dtype=dtype, buffer=buffer, strides=strides)
if buffer is None and isinstance(val, Number):
obj.fill(val)
obj._rank = rank
obj._p0 = None
return obj

subcomm = cls.getSubcomm(subcomm, global_shape, rank, alignment)
p0, subshape = cls.getPencil(subcomm, rank, global_shape, alignment)
subcomm = cls.get_subcomm(subcomm, global_shape, rank, alignment)
p0, subshape = cls.setup_pencil(subcomm, rank, global_shape, alignment)

obj = cls.xp.ndarray.__new__(cls, subshape, dtype=dtype, buffer=buffer)
if buffer is None and isinstance(val, Number):
Expand All @@ -496,11 +469,9 @@ def v(self):
"""Return local ``self`` array as an ``ndarray`` object"""
return self.__array__()

@property
def asnumpy(self):
return self


def newDistArray(pfft, forward_output=True, val=0, rank=0, view=False):
"""Return a new :class:`.DistArray` object for provided :class:`.PFFT` object
Expand Down Expand Up @@ -541,29 +512,21 @@ def newDistArray(pfft, forward_output=True, val=0, rank=0, view=False):
dtype = pfft.forward.output_array.dtype
else:
dtype = pfft.forward.input_array.dtype
global_shape = (len(global_shape),) * rank + global_shape
global_shape = (len(global_shape),)*rank + global_shape

if pfft.xfftn[0].backend in ["cupy", "cupyx-scipy"]:
from mpi4py_fft.distarrayCuPy import DistArrayCuPy as darraycls
else:
darraycls = DistArray

z = darraycls(
global_shape,
subcomm=p0.subcomm,
val=val,
dtype=dtype,
alignment=p0.axis,
rank=rank,
)
z = darraycls(global_shape, subcomm=p0.subcomm, val=val, dtype=dtype,
alignment=p0.axis, rank=rank)
return z.v if view else z


def Function(*args, **kwargs): # pragma: no cover
def Function(*args, **kwargs): #pragma: no cover
import warnings

warnings.warn("Function() is deprecated; use newDistArray().", FutureWarning)
if "tensor" in kwargs:
kwargs["rank"] = 1
del kwargs["tensor"]
if 'tensor' in kwargs:
kwargs['rank'] = 1
del kwargs['tensor']
return newDistArray(*args, **kwargs)
5 changes: 2 additions & 3 deletions mpi4py_fft/distarrayCuPy.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def __new__(
obj._p0 = None
return obj

subcomm = cls.getSubcomm(subcomm, global_shape, rank, alignment)
p0, subshape = cls.getPencil(subcomm, rank, global_shape, alignment)
subcomm = cls.get_subcomm(subcomm, global_shape, rank, alignment)
p0, subshape = cls.setup_pencil(subcomm, rank, global_shape, alignment)

obj = cls.xp.ndarray.__new__(cls, subshape, dtype=dtype, memptr=memptr)
if memptr is None and isinstance(val, Number):
Expand All @@ -76,7 +76,6 @@ def get(self, *args, **kwargs):
else:
return cp.ndarray.get(self, *args, **kwargs)

@property
def asnumpy(self):
"""Copy the array to CPU"""
return self.get()
Expand Down
20 changes: 10 additions & 10 deletions mpi4py_fft/pencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,27 +246,27 @@ def Alltoallw(self, arrayA, subtypesA, arrayB, subtypesB):
iscomplex = cp.iscomplexobj(arrayA)
NCCL_dtype, real_dtype = self.get_nccl_and_real_dtypes(arrayA)

for i in range(size):
for j in range(size):
for recv_rank in range(size):
for send_rank in range(size):

if rank == i:
local_slice, shape = self.get_slice_and_shape(subtypesB[j])
if rank == recv_rank:
local_slice, shape = self.get_slice_and_shape(subtypesB[send_rank])
buff = self.get_buffer(shape, iscomplex, real_dtype)

if i == j:
send_slice, _ = self.get_slice_and_shape(subtypesA[i])
if recv_rank == send_rank:
send_slice, _ = self.get_slice_and_shape(subtypesA[recv_rank])
self.fill_buffer(buff, arrayA, send_slice, iscomplex)
else:
comm.recv(buff.data.ptr, buff.size, NCCL_dtype, j, stream)
comm.recv(buff.data.ptr, buff.size, NCCL_dtype, send_rank, stream)

self.unpack_buffer(buff, arrayB, local_slice, iscomplex)

elif rank == j:
local_slice, shape = self.get_slice_and_shape(subtypesA[i])
elif rank == send_rank:
local_slice, shape = self.get_slice_and_shape(subtypesA[recv_rank])
buff = self.get_buffer(shape, iscomplex, real_dtype)
self.fill_buffer(buff, arrayA, local_slice, iscomplex)

comm.send(buff.data.ptr, buff.size, NCCL_dtype, i, stream)
comm.send(buff.data.ptr, buff.size, NCCL_dtype, recv_rank, stream)


@staticmethod
Expand Down

0 comments on commit d0a493e

Please sign in to comment.