diff --git a/mpi4py_fft/distarray.py b/mpi4py_fft/distarray.py index 3f4ce5b..9416856 100644 --- a/mpi4py_fft/distarray.py +++ b/mpi4py_fft/distarray.py @@ -543,7 +543,7 @@ def newDistArray(pfft, forward_output=True, val=0, rank=0, view=False): dtype = pfft.forward.input_array.dtype global_shape = (len(global_shape),) * rank + global_shape - if pfft.xfftn[0].backend in ["cupy"]: + if pfft.xfftn[0].backend in ["cupy", "cupyx-scipy"]: from mpi4py_fft.distarrayCuPy import DistArrayCuPy as darraycls else: darraycls = DistArray diff --git a/mpi4py_fft/libfft.py b/mpi4py_fft/libfft.py index 3f74b4d..16feeda 100644 --- a/mpi4py_fft/libfft.py +++ b/mpi4py_fft/libfft.py @@ -151,6 +151,42 @@ def _Xfftn_plan_mkl(shape, axes, dtype, transforms, options): #pragma: no cover return (_Yfftn_wrap(plan_fwd, U, V, 1, {'s': s, 'axes': axes}), _Yfftn_wrap(plan_bck, V, U, M, {'s': s, 'axes': axes})) +def _Xfftn_plan_cupyx_scipy(shape, axes, dtype, transforms, options): + import cupy as cp + import cupyx.scipy.fft as fft_lib + + transforms = {} if transforms is None else transforms + if tuple(axes) in transforms: + _plan_fwd, _plan_bck = transforms[tuple(axes)] + else: + if cp.issubdtype(dtype, cp.floating): + _plan_fwd = fft_lib.rfftn + _plan_bck = fft_lib.irfftn + else: + _plan_fwd = fft_lib.fftn + _plan_bck = fft_lib.ifftn + + def swap_shape_for_s(kwargs): + _kwargs = { + 's': kwargs.pop('shape', None), + **kwargs, + } + return _kwargs + + def plan_fwd(*args, **kwargs): + return _plan_fwd(*args, **swap_shape_for_s(kwargs)) + + def plan_bck(*args, **kwargs): + return _plan_bck(*args, **swap_shape_for_s(kwargs)) + + s = tuple(np.take(shape, axes)) + U = cp.array(fftw.aligned(shape, dtype=dtype)) # TODO: Skip CPU detour + V = plan_fwd(U, s=s, axes=axes) + V = cp.array(fftw.aligned_like(V.get())) # TODO: skip CPU detour + M = np.prod(s) + return (_Yfftn_wrap(plan_fwd, U, V, 1, {'shape': s, 'axes': axes}), + _Yfftn_wrap(plan_bck, V, U, M, {'shape': s, 'axes': axes})) + def _Xfftn_plan_scipy(shape, axes, dtype, transforms, options): transforms = {} if transforms is None else transforms @@ -409,6 +445,7 @@ def __init__(self, shape, axes=None, dtype=float, padding=False, 'cupy': _Xfftn_plan_cupy, 'mkl_fft': _Xfftn_plan_mkl, 'scipy': _Xfftn_plan_scipy, + 'cupyx-scipy': _Xfftn_plan_cupyx_scipy, }[backend] self.backend = backend self.fwd, self.bck = plan(self.shape, self.axes, self.dtype, transforms, kw) @@ -427,7 +464,7 @@ def __init__(self, shape, axes=None, dtype=float, padding=False, if abs(self.padding_factor-1.0) > 1e-8: assert len(self.axes) == 1 trunc_array = self._get_truncarray(shape, V.dtype) - if self.backend in ['cupy']: # TODO: Skip detour via CPU + if 'cupy' in self.backend: # TODO: Skip detour via CPU import cupy as cp trunc_array = cp.array(trunc_array) self.forward = _Xfftn_wrap(self._forward, U, trunc_array) diff --git a/tests/test_libfft.py b/tests/test_libfft.py index d4523a6..180fc11 100644 --- a/tests/test_libfft.py +++ b/tests/test_libfft.py @@ -16,12 +16,22 @@ if has_backend['cupy']: import cupy as cp + has_backend['cupyx-scipy'] = True abstol = dict(f=5e-5, d=1e-14, g=1e-14) +def get_xpy(backend=None, array=None): + if backend in ['cupy', 'cupyx-scipy']: + return cp + if has_backend['cupy'] and array is not None: + if type(array) == cp.ndarray: + return cp + return np + + def allclose(a, b): atol = abstol[a.dtype.char.lower()] - xp = cp if type(a) == cp.ndarray else np + xp = get_xpy(array=a) return xp.allclose(a, b, rtol=0, atol=atol) def test_libfft(): @@ -37,6 +47,7 @@ def test_libfft(): for backend in has_backend.keys(): if has_backend[backend] is False: continue + xp = get_xpy(backend=backend) t0 = 0 for typecode in types: for dim in dims: @@ -51,7 +62,7 @@ def test_libfft(): A = fft.forward.input_array B = fft.forward.output_array - A[...] = (cp if backend in ['cupy'] else np).random.random(A.shape).astype(typecode) + A[...] = xp.random.random(A.shape).astype(typecode) X = A.copy() B.fill(0) @@ -71,7 +82,7 @@ def test_libfft(): for backend in has_backend.keys(): if has_backend[backend] is False: continue - xp = cp if backend in ['cupy'] else np + xp = get_xpy(backend=backend) for padding in (1.5, 2.0): for typecode in types: for dim in dims: @@ -105,7 +116,7 @@ def test_libfft(): for backend in has_backend.keys(): if has_backend[backend] is False: continue - xp = cp if backend in ['cupy'] else np + xp = get_xpy(backend=backend) if backend == 'fftw': dctn = functools.partial(fftw.dctn, type=3) @@ -130,6 +141,11 @@ def test_libfft(): from scipy.fftpack import fftn, ifftn transforms = {(1,): (fftn, ifftn), (0, 1): (fftn, ifftn)} + elif backend == 'cupyx-scipy': + from scipy.fftpack import fftn, ifftn + import cupyx.scipy.fft as fftlib + transforms = {(1,): (fftlib.fftn, fftlib.ifftn), + (0, 1): (fftlib.fftn, fftlib.ifftn)} for axis in ((1,), (0, 1)): fft = FFT(shape, axis, backend=backend, transforms=transforms)