Skip to content

Commit

Permalink
Added cupyx-scipy backend
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas committed Dec 14, 2023
1 parent 90551b4 commit 202f30c
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 6 deletions.
2 changes: 1 addition & 1 deletion mpi4py_fft/distarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def newDistArray(pfft, forward_output=True, val=0, rank=0, view=False):
dtype = pfft.forward.input_array.dtype
global_shape = (len(global_shape),) * rank + global_shape

if pfft.xfftn[0].backend in ["cupy"]:
if pfft.xfftn[0].backend in ["cupy", "cupyx-scipy"]:
from mpi4py_fft.distarrayCuPy import DistArrayCuPy as darraycls
else:
darraycls = DistArray
Expand Down
39 changes: 38 additions & 1 deletion mpi4py_fft/libfft.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,42 @@ def _Xfftn_plan_mkl(shape, axes, dtype, transforms, options): #pragma: no cover
return (_Yfftn_wrap(plan_fwd, U, V, 1, {'s': s, 'axes': axes}),
_Yfftn_wrap(plan_bck, V, U, M, {'s': s, 'axes': axes}))

def _Xfftn_plan_cupyx_scipy(shape, axes, dtype, transforms, options):
import cupy as cp
import cupyx.scipy.fft as fft_lib

transforms = {} if transforms is None else transforms
if tuple(axes) in transforms:
_plan_fwd, _plan_bck = transforms[tuple(axes)]
else:
if cp.issubdtype(dtype, cp.floating):
_plan_fwd = fft_lib.rfftn
_plan_bck = fft_lib.irfftn
else:
_plan_fwd = fft_lib.fftn
_plan_bck = fft_lib.ifftn

def swap_shape_for_s(kwargs):
_kwargs = {
's': kwargs.pop('shape', None),
**kwargs,
}
return _kwargs

def plan_fwd(*args, **kwargs):
return _plan_fwd(*args, **swap_shape_for_s(kwargs))

def plan_bck(*args, **kwargs):
return _plan_bck(*args, **swap_shape_for_s(kwargs))

s = tuple(np.take(shape, axes))
U = cp.array(fftw.aligned(shape, dtype=dtype)) # TODO: Skip CPU detour
V = plan_fwd(U, s=s, axes=axes)
V = cp.array(fftw.aligned_like(V.get())) # TODO: skip CPU detour
M = np.prod(s)
return (_Yfftn_wrap(plan_fwd, U, V, 1, {'shape': s, 'axes': axes}),
_Yfftn_wrap(plan_bck, V, U, M, {'shape': s, 'axes': axes}))

def _Xfftn_plan_scipy(shape, axes, dtype, transforms, options):

transforms = {} if transforms is None else transforms
Expand Down Expand Up @@ -409,6 +445,7 @@ def __init__(self, shape, axes=None, dtype=float, padding=False,
'cupy': _Xfftn_plan_cupy,
'mkl_fft': _Xfftn_plan_mkl,
'scipy': _Xfftn_plan_scipy,
'cupyx-scipy': _Xfftn_plan_cupyx_scipy,
}[backend]
self.backend = backend
self.fwd, self.bck = plan(self.shape, self.axes, self.dtype, transforms, kw)
Expand All @@ -427,7 +464,7 @@ def __init__(self, shape, axes=None, dtype=float, padding=False,
if abs(self.padding_factor-1.0) > 1e-8:
assert len(self.axes) == 1
trunc_array = self._get_truncarray(shape, V.dtype)
if self.backend in ['cupy']: # TODO: Skip detour via CPU
if 'cupy' in self.backend: # TODO: Skip detour via CPU
import cupy as cp
trunc_array = cp.array(trunc_array)
self.forward = _Xfftn_wrap(self._forward, U, trunc_array)
Expand Down
24 changes: 20 additions & 4 deletions tests/test_libfft.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,22 @@

if has_backend['cupy']:
import cupy as cp
has_backend['cupyx-scipy'] = True

abstol = dict(f=5e-5, d=1e-14, g=1e-14)

def get_xpy(backend=None, array=None):
if backend in ['cupy', 'cupyx-scipy']:
return cp
if has_backend['cupy'] and array is not None:
if type(array) == cp.ndarray:
return cp
return np


def allclose(a, b):
atol = abstol[a.dtype.char.lower()]
xp = cp if type(a) == cp.ndarray else np
xp = get_xpy(array=a)
return xp.allclose(a, b, rtol=0, atol=atol)

def test_libfft():
Expand All @@ -37,6 +47,7 @@ def test_libfft():
for backend in has_backend.keys():
if has_backend[backend] is False:
continue
xp = get_xpy(backend=backend)
t0 = 0
for typecode in types:
for dim in dims:
Expand All @@ -51,7 +62,7 @@ def test_libfft():
A = fft.forward.input_array
B = fft.forward.output_array

A[...] = (cp if backend in ['cupy'] else np).random.random(A.shape).astype(typecode)
A[...] = xp.random.random(A.shape).astype(typecode)
X = A.copy()

B.fill(0)
Expand All @@ -71,7 +82,7 @@ def test_libfft():
for backend in has_backend.keys():
if has_backend[backend] is False:
continue
xp = cp if backend in ['cupy'] else np
xp = get_xpy(backend=backend)
for padding in (1.5, 2.0):
for typecode in types:
for dim in dims:
Expand Down Expand Up @@ -105,7 +116,7 @@ def test_libfft():
for backend in has_backend.keys():
if has_backend[backend] is False:
continue
xp = cp if backend in ['cupy'] else np
xp = get_xpy(backend=backend)

if backend == 'fftw':
dctn = functools.partial(fftw.dctn, type=3)
Expand All @@ -130,6 +141,11 @@ def test_libfft():
from scipy.fftpack import fftn, ifftn
transforms = {(1,): (fftn, ifftn),
(0, 1): (fftn, ifftn)}
elif backend == 'cupyx-scipy':
from scipy.fftpack import fftn, ifftn
import cupyx.scipy.fft as fftlib
transforms = {(1,): (fftlib.fftn, fftlib.ifftn),
(0, 1): (fftlib.fftn, fftlib.ifftn)}

for axis in ((1,), (0, 1)):
fft = FFT(shape, axis, backend=backend, transforms=transforms)
Expand Down

0 comments on commit 202f30c

Please sign in to comment.