Skip to content

Commit

Permalink
Removed unnecessary rescaling in CuPy backend
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas committed Feb 27, 2024
1 parent b69a303 commit b9ac0f7
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 17 deletions.
33 changes: 17 additions & 16 deletions mpi4py_fft/libfft.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,11 @@ def _Xfftn_plan_cupy(shape, axes, dtype, transforms, options):

s = tuple(np.take(shape, axes))
U = cp.empty(shape=shape, dtype=dtype)
V = plan_fwd(U, s=s, axes=axes)
V = cp.array(V)
M = np.prod(s)
V = cp.empty_like(plan_fwd(U, s=s, axes=axes))

# CuPy has forward transform unscaled and backward scaled with 1/N
return (
_Yfftn_wrap(plan_fwd, U, V, 1, {'s': s, 'axes': axes}, xp=cp),
_Yfftn_wrap(plan_bck, V, U, M, {'s': s, 'axes': axes}, xp=cp),
_Yfftn_wrap(plan_fwd, U, V, 1, {'s': s, 'axes': axes, 'norm': 'backward',}, xp=cp),
_Yfftn_wrap(plan_bck, V, U, 1, {'s': s, 'axes': axes, 'norm': 'forward',}, xp=cp),
)

def _Xfftn_plan_numpy(shape, axes, dtype, transforms, options):
Expand Down Expand Up @@ -189,6 +186,9 @@ def _Xfftn_plan_scipy(shape, axes, dtype, transforms, options):
return (_Yfftn_wrap(plan_fwd, U, V, 1, {'shape': s, 'axes': axes}),
_Yfftn_wrap(plan_bck, V, U, M, {'shape': s, 'axes': axes}))

def _copyto(dst, src, xp):
xp.copyto(dst, src, casting='unsafe')

class _Yfftn_wrap(object):
#Wraps numpy/scipy/mkl transforms to FFTW style
# pylint: disable=too-few-public-methods
Expand All @@ -208,9 +208,6 @@ def __init__(self, xfftn_obj, input_array, output_array, M, opt, xp=np):
def input_array(self):
return object.__getattribute__(self, '_input_array')

def copyto(self, dst, src):
self.xp.copyto(dst, src, casting='unsafe')

@property
def output_array(self):
return object.__getattribute__(self, '_output_array')
Expand All @@ -227,14 +224,16 @@ def opt(self):
def M(self):
return object.__getattribute__(self, '_M')

def copyto(self, dst, src):
_copyto(dst, src, self.xp)

def __call__(self, *args, **kwargs):
self.opt.update(kwargs)
self.copyto(self._output_array, self.xfftn(self.input_array, **self.opt))
if abs(self.M-1) > 1e-8:
self._output_array *= self.M
return self.output_array


class _Xfftn_wrap(object):
#Common interface for all serial transforms
# pylint: disable=too-few-public-methods
Expand All @@ -261,7 +260,7 @@ def xfftn(self):
return object.__getattribute__(self, '_xfftn')

def copyto(self, dst, src):
self.xp.copyto(dst, src, casting='unsafe')
_copyto(dst, src, self.xp)

def __call__(self, input_array=None, output_array=None, **options):
if input_array is not None:
Expand Down Expand Up @@ -454,19 +453,21 @@ def __init__(self, shape, axes=None, dtype=float, padding=False,
self.padding_factor = 1.0
if padding is not False:
self.padding_factor = padding[self.axes[-1]] if np.ndim(padding) else padding
xp = np
if 'cupy' in self.backend:
import cupy as cp
xp = cp

if abs(self.padding_factor-1.0) > 1e-8:
assert len(self.axes) == 1
trunc_array = self._get_truncarray(shape, V.dtype)
xp = np
if 'cupy' in self.backend: # TODO: Skip detour via CPU
import cupy as cp
trunc_array = cp.array(trunc_array)
xp = cp
self.forward = _Xfftn_wrap(self._forward, U, trunc_array, xp=xp)
self.backward = _Xfftn_wrap(self._backward, trunc_array, U, xp=xp)
else:
self.forward = _Xfftn_wrap(self._forward, U, V)
self.backward = _Xfftn_wrap(self._backward, V, U)
self.forward = _Xfftn_wrap(self._forward, U, V, xp=xp)
self.backward = _Xfftn_wrap(self._backward, V, U, xp=xp)

def _forward(self, **kw):
normalize = kw.pop('normalize', True)
Expand Down
2 changes: 1 addition & 1 deletion mpi4py_fft/pencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def Alltoallw(self, arrayA, subtypesA, arrayB, subtypesB):
events = {}
for i in recvbufs.keys():
with streams[i]:
cp.copyto(arrayB[sliceBs[i]], recvbufs[i][:])
cp.copyto(arrayB[sliceBs[i]], recvbufs[i])
events[i] = streams[i].record()

for i in events.keys():
Expand Down

0 comments on commit b9ac0f7

Please sign in to comment.