Skip to content

Commit

Permalink
Accelerated transposes by reusing transfer objects
Browse files Browse the repository at this point in the history
  • Loading branch information
brownbaerchen committed Sep 12, 2024
1 parent cc4297b commit d94ad53
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 25 deletions.
88 changes: 65 additions & 23 deletions pySDC/helpers/spectral_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1255,7 +1255,9 @@ def transform_single_component(self, u, axes=None, padding=None):

fft = self.get_fft(_axes, 'object', padding=padding, shape=shape)

_in = self.get_aligned(result, axis_in=alignment, axis_out=self.ndim + _axes[-1], forward=False, fft=fft)
_in = self.get_aligned(
result, axis_in=alignment, axis_out=self.ndim + _axes[-1], forward=False, fft=fft, shape=shape
)

alignment = self.ndim + _axes[-1]

Expand All @@ -1266,10 +1268,12 @@ def transform_single_component(self, u, axes=None, padding=None):

axes_next_base = axes_collapsed[(trf + 1) % len(axes_collapsed)]
alignment = alignment if len(axes_next_base) == 0 else self.ndim + axes_next_base[-1]
result = self.get_aligned(_out, axis_in=self.ndim + _axes[0], axis_out=alignment, fft=fft, forward=True)
result = self.get_aligned(
_out, axis_in=self.ndim + _axes[0], axis_out=alignment, fft=fft, forward=True, shape=shape
)

fft = self.get_fft(axes=axes, padding=padding)
return self.get_aligned(result, axis_in=alignment, axis_out=self.ndim - 1, fft=fft, forward=True)
return self.get_aligned(result, axis_in=alignment, axis_out=self.ndim - 1, fft=fft, forward=True, shape=shape)

def transform(self, u, axes=None, padding=None):
"""
Expand Down Expand Up @@ -1384,7 +1388,9 @@ def itransform_single_component(self, u, axes=None, padding=None):

fft = self.get_fft(_axes, 'object', padding=padding, shape=shape)

_in = self.get_aligned(result, axis_in=alignment, axis_out=self.ndim + _axes[0], forward=True, fft=fft)
_in = self.get_aligned(
result, axis_in=alignment, axis_out=self.ndim + _axes[0], forward=True, fft=fft, shape=shape
)
if self.comm is not None:
_in /= np.prod([self.axes[i].N for i in _axes])

Expand All @@ -1393,43 +1399,79 @@ def itransform_single_component(self, u, axes=None, padding=None):
_out = trfs[base](_in, axes=_axes, padding=padding, shape=shape)

for _ax in _axes:
shape[_ax] = _out.shape[_ax]
if fft:
shape[_ax] = fft._input_shape[_ax]
else:
shape[_ax] = _out.shape[_ax]

axes_next_base = axes_collapsed[(trf + 1) % len(axes_collapsed)]
alignment = alignment if len(axes_next_base) == 0 else self.ndim + axes_next_base[0]
result = self.get_aligned(_out, axis_in=self.ndim + _axes[-1], axis_out=alignment, fft=fft, forward=False)
result = self.get_aligned(
_out, axis_in=self.ndim + _axes[-1], axis_out=alignment, fft=fft, forward=False, shape=shape
)

fft = self.get_fft(axes=axes, padding=padding)
return self.get_aligned(result, axis_in=alignment, axis_out=self.ndim - 1, fft=fft)
return self.get_aligned(result, axis_in=alignment, axis_out=self.ndim - 1, fft=fft, shape=shape)

def get_aligned(self, u, axis_in, axis_out, fft=None, forward=False, fill=True, **kwargs):
def get_aligned(self, u, axis_in, axis_out, fft=None, forward=False, **kwargs):
"""
Realign the data along the axis when using distributed FFTs
Args:
u: The solution
axis (int): New alignment
axis_in (int): Current alignment
axis_out (int): New alignment
fft (mpi4py_fft.PFFT), optional: parallel FFT object
forward (bool): Whether the input is in spectral space or not
Returns:
solution aligned on `axis`
solution aligned on `axis_out`
"""
if self.comm is None:
if fill:
return u
elif forward:
return self.u_init_forward
else:
return self.u_init

from mpi4py_fft import newDistArray
if self.comm is None or axis_in == axis_out:
return u.copy()

fft = self.get_fft(**kwargs) if fft is None else fft

_in = newDistArray(fft, forward).redistribute(axis_in)
if fill:
_in[...] = u
global_fft = self.get_fft(**kwargs)
axisA = [me.axisA for me in global_fft.transfer]
axisB = [me.axisB for me in global_fft.transfer]

current_axis = axis_in

if axis_in in axisA and axis_out in axisB:
while current_axis != axis_out:
transfer = global_fft.transfer[axisA.index(current_axis)]

arrayB = self.xp.empty(shape=transfer.subshapeB, dtype=transfer.dtype)
arrayA = self.xp.empty(shape=transfer.subshapeA, dtype=transfer.dtype)
arrayA[:] = u[:]

transfer.forward(arrayA=arrayA, arrayB=arrayB)

current_axis = transfer.axisB
u = arrayB
return u
elif axis_in in axisB and axis_out in axisA:
while current_axis != axis_out:
transfer = global_fft.transfer[axisB.index(current_axis)]

arrayB = self.xp.empty(shape=transfer.subshapeB, dtype=transfer.dtype)
arrayA = self.xp.empty(shape=transfer.subshapeA, dtype=transfer.dtype)
arrayB[:] = u[:]

transfer.backward(arrayA=arrayA, arrayB=arrayB)

current_axis = transfer.axisA
u = arrayA
return u
else: # go the potentially slower route of not reusing transfer classes
from mpi4py_fft import newDistArray

_in = newDistArray(fft, forward).redistribute(axis_in)
if fill:
_in[...] = u

return _in.redistribute(axis_out)
return _in.redistribute(axis_out)

def itransform(self, u, axes=None, padding=None):
axes = tuple(-i - 1 for i in range(self.ndim)[::-1]) if axes is None else axes
Expand Down
4 changes: 2 additions & 2 deletions pySDC/tests/test_problems/test_RayleighBenard.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,9 @@ def test_Nyquist_mode_elimination():
if __name__ == '__main__':
# test_eval_f(2**0, 2**2, 'z', True)
# test_Poisson_problem(1, 'T')
# test_Poisson_problem_v()
test_Poisson_problem_v()
# test_Nusselt_numbers(1)
# test_buoyancy_computation()
# test_viscous_dissipation()
# test_CFL()
test_Nyquist_mode_elimination()
# test_Nyquist_mode_elimination()

0 comments on commit d94ad53

Please sign in to comment.