Skip to content

Commit

Permalink
Increased performance of tau methods (#491)
Browse files Browse the repository at this point in the history
  • Loading branch information
brownbaerchen authored Oct 8, 2024
1 parent 04853bb commit 3d59549
Showing 1 changed file with 39 additions and 29 deletions.
68 changes: 39 additions & 29 deletions pySDC/helpers/spectral_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1178,30 +1178,36 @@ def put_BCs_in_matrix(self, A):
def put_BCs_in_rhs_hat(self, rhs_hat):
"""
Put the BCs in the right hand side in spectral space for solving.
This function needs no transforms.
This function needs no transforms and caches a mask for faster subsequent use.
Args:
rhs_hat: Right hand side in spectral space
Returns:
rhs in spectral space with BCs
"""
ndim = self.ndim

for axis in range(ndim):
for bc in self.full_BCs:
slices = (
[slice(0, self.init[0][i + 1]) for i in range(axis)]
+ [bc['line']]
+ [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
)
if axis == bc['axis']:
_slice = [self.index(bc['equation'])] + slices
N = self.axes[axis].N
if (N + bc['line']) % N in self.xp.arange(N)[self.local_slice[axis]]:
_slice[axis + 1] -= self.local_slice[axis].start
rhs_hat[(*_slice,)] = 0

if not hasattr(self, '_rhs_hat_zero_mask'):
"""
Generate a mask where we need to set values in the rhs in spectral space to zero, such that can replace them
by the boundary conditions. The mask is then cached.
"""
self._rhs_hat_zero_mask = self.xp.zeros(shape=rhs_hat.shape, dtype=bool)

for axis in range(self.ndim):
for bc in self.full_BCs:
slices = (
[slice(0, self.init[0][i + 1]) for i in range(axis)]
+ [bc['line']]
+ [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
)
if axis == bc['axis']:
_slice = [self.index(bc['equation'])] + slices
N = self.axes[axis].N
if (N + bc['line']) % N in self.xp.arange(N)[self.local_slice[axis]]:
_slice[axis + 1] -= self.local_slice[axis].start
self._rhs_hat_zero_mask[(*_slice,)] = True

rhs_hat[self._rhs_hat_zero_mask] = 0
return rhs_hat + self.rhs_BCs_hat

def put_BCs_in_rhs(self, rhs):
Expand Down Expand Up @@ -1347,18 +1353,22 @@ def get_fft(self, axes=None, direction='object', padding=None, shape=None):
elif direction == 'object':
self.fft_cache[key] = None
else:
from mpi4py_fft import PFFT

_fft = PFFT(
comm=self.comm,
shape=shape,
axes=sorted(axes),
dtype='D',
collapse=False,
backend=self.fft_backend,
comm_backend=self.fft_comm_backend,
padding=padding,
)
if direction == 'object':
from mpi4py_fft import PFFT

_fft = PFFT(
comm=self.comm,
shape=shape,
axes=sorted(axes),
dtype='D',
collapse=False,
backend=self.fft_backend,
comm_backend=self.fft_comm_backend,
padding=padding,
)
else:
_fft = self.get_fft(axes=axes, direction='object', padding=padding, shape=shape)

if direction == 'forward':
self.fft_cache[key] = _fft.forward
elif direction == 'backward':
Expand Down

0 comments on commit 3d59549

Please sign in to comment.