From b9ac0f71b336c7d421217de915e713276b9062c1 Mon Sep 17 00:00:00 2001 From: Thomas Date: Tue, 27 Feb 2024 16:23:54 +0100 Subject: [PATCH] Removed unnecessary rescaling in CuPy backend --- mpi4py_fft/libfft.py | 33 +++++++++++++++++---------------- mpi4py_fft/pencil.py | 2 +- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/mpi4py_fft/libfft.py b/mpi4py_fft/libfft.py index d271432..aca5265 100644 --- a/mpi4py_fft/libfft.py +++ b/mpi4py_fft/libfft.py @@ -95,14 +95,11 @@ def _Xfftn_plan_cupy(shape, axes, dtype, transforms, options): s = tuple(np.take(shape, axes)) U = cp.empty(shape=shape, dtype=dtype) - V = plan_fwd(U, s=s, axes=axes) - V = cp.array(V) - M = np.prod(s) + V = cp.empty_like(plan_fwd(U, s=s, axes=axes)) - # CuPy has forward transform unscaled and backward scaled with 1/N return ( - _Yfftn_wrap(plan_fwd, U, V, 1, {'s': s, 'axes': axes}, xp=cp), - _Yfftn_wrap(plan_bck, V, U, M, {'s': s, 'axes': axes}, xp=cp), + _Yfftn_wrap(plan_fwd, U, V, 1, {'s': s, 'axes': axes, 'norm': 'backward',}, xp=cp), + _Yfftn_wrap(plan_bck, V, U, 1, {'s': s, 'axes': axes, 'norm': 'forward',}, xp=cp), ) def _Xfftn_plan_numpy(shape, axes, dtype, transforms, options): @@ -189,6 +186,9 @@ def _Xfftn_plan_scipy(shape, axes, dtype, transforms, options): return (_Yfftn_wrap(plan_fwd, U, V, 1, {'shape': s, 'axes': axes}), _Yfftn_wrap(plan_bck, V, U, M, {'shape': s, 'axes': axes})) +def _copyto(dst, src, xp): + xp.copyto(dst, src, casting='unsafe') + class _Yfftn_wrap(object): #Wraps numpy/scipy/mkl transforms to FFTW style # pylint: disable=too-few-public-methods @@ -208,9 +208,6 @@ def __init__(self, xfftn_obj, input_array, output_array, M, opt, xp=np): def input_array(self): return object.__getattribute__(self, '_input_array') - def copyto(self, dst, src): - self.xp.copyto(dst, src, casting='unsafe') - @property def output_array(self): return object.__getattribute__(self, '_output_array') @@ -227,6 +224,9 @@ def opt(self): def M(self): return object.__getattribute__(self, '_M') + def copyto(self, dst, src): + _copyto(dst, src, self.xp) + def __call__(self, *args, **kwargs): self.opt.update(kwargs) self.copyto(self._output_array, self.xfftn(self.input_array, **self.opt)) @@ -234,7 +234,6 @@ def __call__(self, *args, **kwargs): self._output_array *= self.M return self.output_array - class _Xfftn_wrap(object): #Common interface for all serial transforms # pylint: disable=too-few-public-methods @@ -261,7 +260,7 @@ def xfftn(self): return object.__getattribute__(self, '_xfftn') def copyto(self, dst, src): - self.xp.copyto(dst, src, casting='unsafe') + _copyto(dst, src, self.xp) def __call__(self, input_array=None, output_array=None, **options): if input_array is not None: @@ -454,19 +453,21 @@ def __init__(self, shape, axes=None, dtype=float, padding=False, self.padding_factor = 1.0 if padding is not False: self.padding_factor = padding[self.axes[-1]] if np.ndim(padding) else padding + xp = np + if 'cupy' in self.backend: + import cupy as cp + xp = cp + if abs(self.padding_factor-1.0) > 1e-8: assert len(self.axes) == 1 trunc_array = self._get_truncarray(shape, V.dtype) - xp = np if 'cupy' in self.backend: # TODO: Skip detour via CPU - import cupy as cp trunc_array = cp.array(trunc_array) - xp = cp self.forward = _Xfftn_wrap(self._forward, U, trunc_array, xp=xp) self.backward = _Xfftn_wrap(self._backward, trunc_array, U, xp=xp) else: - self.forward = _Xfftn_wrap(self._forward, U, V) - self.backward = _Xfftn_wrap(self._backward, V, U) + self.forward = _Xfftn_wrap(self._forward, U, V, xp=xp) + self.backward = _Xfftn_wrap(self._backward, V, U, xp=xp) def _forward(self, **kw): normalize = kw.pop('normalize', True) diff --git a/mpi4py_fft/pencil.py b/mpi4py_fft/pencil.py index 8d225f6..2c4737e 100644 --- a/mpi4py_fft/pencil.py +++ b/mpi4py_fft/pencil.py @@ -321,7 +321,7 @@ def Alltoallw(self, arrayA, subtypesA, arrayB, subtypesB): events = {} for i in recvbufs.keys(): with streams[i]: - cp.copyto(arrayB[sliceBs[i]], recvbufs[i][:]) + cp.copyto(arrayB[sliceBs[i]], recvbufs[i]) events[i] = streams[i].record() for i in events.keys():