diff --git a/pySDC/helpers/spectral_helper.py b/pySDC/helpers/spectral_helper.py index 2cf912fbf4..5c45b75871 100644 --- a/pySDC/helpers/spectral_helper.py +++ b/pySDC/helpers/spectral_helper.py @@ -1178,7 +1178,7 @@ def put_BCs_in_matrix(self, A): def put_BCs_in_rhs_hat(self, rhs_hat): """ Put the BCs in the right hand side in spectral space for solving. - This function needs no transforms. + This function needs no transforms and caches a mask for faster subsequent use. Args: rhs_hat: Right hand side in spectral space @@ -1186,22 +1186,28 @@ def put_BCs_in_rhs_hat(self, rhs_hat): Returns: rhs in spectral space with BCs """ - ndim = self.ndim - - for axis in range(ndim): - for bc in self.full_BCs: - slices = ( - [slice(0, self.init[0][i + 1]) for i in range(axis)] - + [bc['line']] - + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))] - ) - if axis == bc['axis']: - _slice = [self.index(bc['equation'])] + slices - N = self.axes[axis].N - if (N + bc['line']) % N in self.xp.arange(N)[self.local_slice[axis]]: - _slice[axis + 1] -= self.local_slice[axis].start - rhs_hat[(*_slice,)] = 0 - + if not hasattr(self, '_rhs_hat_zero_mask'): + """ + Generate a mask where we need to set values in the rhs in spectral space to zero, such that can replace them + by the boundary conditions. The mask is then cached. + """ + self._rhs_hat_zero_mask = self.xp.zeros(shape=rhs_hat.shape, dtype=bool) + + for axis in range(self.ndim): + for bc in self.full_BCs: + slices = ( + [slice(0, self.init[0][i + 1]) for i in range(axis)] + + [bc['line']] + + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))] + ) + if axis == bc['axis']: + _slice = [self.index(bc['equation'])] + slices + N = self.axes[axis].N + if (N + bc['line']) % N in self.xp.arange(N)[self.local_slice[axis]]: + _slice[axis + 1] -= self.local_slice[axis].start + self._rhs_hat_zero_mask[(*_slice,)] = True + + rhs_hat[self._rhs_hat_zero_mask] = 0 return rhs_hat + self.rhs_BCs_hat def put_BCs_in_rhs(self, rhs): @@ -1347,18 +1353,22 @@ def get_fft(self, axes=None, direction='object', padding=None, shape=None): elif direction == 'object': self.fft_cache[key] = None else: - from mpi4py_fft import PFFT - - _fft = PFFT( - comm=self.comm, - shape=shape, - axes=sorted(axes), - dtype='D', - collapse=False, - backend=self.fft_backend, - comm_backend=self.fft_comm_backend, - padding=padding, - ) + if direction == 'object': + from mpi4py_fft import PFFT + + _fft = PFFT( + comm=self.comm, + shape=shape, + axes=sorted(axes), + dtype='D', + collapse=False, + backend=self.fft_backend, + comm_backend=self.fft_comm_backend, + padding=padding, + ) + else: + _fft = self.get_fft(axes=axes, direction='object', padding=padding, shape=shape) + if direction == 'forward': self.fft_cache[key] = _fft.forward elif direction == 'backward':