Skip to content

Commit

Permalink
MPI version of IMEX sweeper (#374)
Browse files Browse the repository at this point in the history
* MPI version of IMEX sweeper

* Tightened tolerances for test
  • Loading branch information
brownbaerchen authored Nov 7, 2023
1 parent 8987350 commit 36e06fe
Show file tree
Hide file tree
Showing 4 changed files with 249 additions and 8 deletions.
7 changes: 7 additions & 0 deletions pySDC/implementations/problem_classes/TestEquation_0D.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ class testequation0d(ptype):

def __init__(self, lambdas=None, u0=0.0):
"""Initialization routine"""
if lambdas is None:
re = np.linspace(-30, 19, 50)
im = np.linspace(-50, 49, 50)
lambdas = np.array([[complex(re[i], im[j]) for i in range(len(re))] for j in range(len(im))]).reshape(
(len(re) * len(im))
)

assert not any(isinstance(i, list) for i in lambdas), 'ERROR: expect flat list here, got %s' % lambdas
nvars = len(lambdas)
assert nvars > 0, 'ERROR: expect at least one lambda parameter here'
Expand Down
30 changes: 22 additions & 8 deletions pySDC/implementations/sweeper_classes/generic_implicit_MPI.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ def compute_end_point(self):
# check if Mth node is equal to right point and do_coll_update is false, perform a simple copy
if self.coll.right_is_node and not self.params.do_coll_update:
# a copy is sufficient
L.uend[:] = self.params.comm.bcast(L.u[self.rank + 1], root=self.params.comm.Get_size() - 1)
root = self.comm.Get_size() - 1
if self.comm.rank == root:
L.uend[:] = L.u[-1]
self.comm.Bcast(L.uend, root=root)
else:
raise NotImplementedError('require last node to be identical with right interval boundary')

Expand Down Expand Up @@ -88,7 +91,7 @@ def compute_residual(self, stage=None):
# compute the residual for each node

# build QF(u)
res = self.integrate()
res = self.integrate(last_only=L.params.residual_type[:4] == 'last')
res += L.u[0] - L.u[self.rank + 1]
# add tau if associated
if L.tau[self.rank] is not None:
Expand All @@ -97,7 +100,16 @@ def compute_residual(self, stage=None):
res_norm = abs(res)

# find maximal residual over the nodes
L.status.residual = self.params.comm.allreduce(res_norm, op=MPI.MAX)
if L.params.residual_type == 'full_abs':
L.status.residual = self.comm.allreduce(res_norm, op=MPI.MAX)
elif L.params.residual_type == 'last_abs':
L.status.residual = self.comm.bcast(res_norm, root=self.comm.size - 1)
elif L.params.residual_type == 'full_rel':
L.status.residual = self.comm.allreduce(res_norm / abs(L.u[0]), op=MPI.MAX)
elif L.params.residual_type == 'last_rel':
L.status.residual = self.comm.bcast(res_norm / abs(L.u[0]), root=self.comm.size - 1)
else:
raise NotImplementedError(f'residual type \"{L.params.residual_type}\" not implemented!')

# indicate that the residual has seen the new values
L.status.updated = False
Expand Down Expand Up @@ -139,21 +151,23 @@ class generic_implicit_MPI(SweeperMPI, generic_implicit):
rank (int): MPI rank
"""

def integrate(self):
def integrate(self, last_only=False):
"""
Integrates the right-hand side
Args:
last_only (bool): Integrate only the last node for the residual or all of them
Returns:
list of dtype_u: containing the integral as values
"""

L = self.level
P = L.prob

me = P.dtype_u(P.init, val=0.0)
for m in range(self.coll.num_nodes):
for m in [self.coll.num_nodes - 1] if last_only else range(self.coll.num_nodes):
recvBuf = me if m == self.rank else None
self.params.comm.Reduce(
self.comm.Reduce(
L.dt * self.coll.Qmat[m + 1, self.rank + 1] * L.f[self.rank + 1], recvBuf, root=m, op=MPI.SUM
)

Expand Down Expand Up @@ -225,7 +239,7 @@ def compute_end_point(self):
super().compute_end_point()
else:
L.uend = P.dtype_u(L.u[0])
self.params.comm.Allreduce(L.dt * self.coll.weights[self.rank] * L.f[self.rank + 1], L.uend, op=MPI.SUM)
self.comm.Allreduce(L.dt * self.coll.weights[self.rank] * L.f[self.rank + 1], L.uend, op=MPI.SUM)
L.uend += L.u[0]

# add up tau correction of the full interval (last entry)
Expand Down
112 changes: 112 additions & 0 deletions pySDC/implementations/sweeper_classes/imex_1st_order_MPI.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from mpi4py import MPI
from pySDC.implementations.sweeper_classes.generic_implicit_MPI import SweeperMPI
from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order


class imex_1st_order_MPI(SweeperMPI, imex_1st_order):
def __init__(self, params):
super().__init__(params)
assert (
self.params.QE == 'PIC'
), f"Only Picard is implemented for explicit precondioner so far in {type(self).__name__}! You chose \"{self.params.QE}\""

def integrate(self, last_only=False):
"""
Integrates the right-hand side (here impl + expl)
Args:
last_only (bool): Integrate only the last node for the residual or all of them
Returns:
list of dtype_u: containing the integral as values
"""

L = self.level
P = L.prob

me = P.dtype_u(P.init, val=0.0)
for m in [self.coll.num_nodes - 1] if last_only else range(self.coll.num_nodes):
recvBuf = me if m == self.rank else None
self.comm.Reduce(
L.dt * self.coll.Qmat[m + 1, self.rank + 1] * (L.f[self.rank + 1].impl + L.f[self.rank + 1].expl),
recvBuf,
root=m,
op=MPI.SUM,
)

return me

def update_nodes(self):
"""
Update the u- and f-values at the collocation nodes -> corresponds to a single sweep over all nodes
Returns:
None
"""

L = self.level
P = L.prob

# only if the level has been touched before
assert L.status.unlocked

# get number of collocation nodes for easier access

# gather all terms which are known already (e.g. from the previous iteration)
# this corresponds to u0 + QF(u^k) - QdF(u^k) + tau

# get QF(u^k)
rhs = self.integrate()

# subtract QdF(u^k)
rhs -= L.dt * (self.QI[self.rank + 1, self.rank + 1] * L.f[self.rank + 1].impl)

# add initial conditions
rhs += L.u[0]
# add tau if associated
if L.tau[self.rank] is not None:
rhs += L.tau[self.rank]

# implicit solve with prefactor stemming from the diagonal of Qd
L.u[self.rank + 1] = P.solve_system(
rhs,
L.dt * self.QI[self.rank + 1, self.rank + 1],
L.u[self.rank + 1],
L.time + L.dt * self.coll.nodes[self.rank],
)
# update function values
L.f[self.rank + 1] = P.eval_f(L.u[self.rank + 1], L.time + L.dt * self.coll.nodes[self.rank])

# indicate presence of new values at this level
L.status.updated = True

return None

def compute_end_point(self):
"""
Compute u at the right point of the interval
Returns:
None
"""

L = self.level
P = L.prob
L.uend = P.dtype_u(P.init, val=0.0)

# check if Mth node is equal to right point and do_coll_update is false, perform a simple copy
if self.coll.right_is_node and not self.params.do_coll_update:
super().compute_end_point()
else:
L.uend = P.dtype_u(L.u[0])
self.comm.Allreduce(
L.dt * self.coll.weights[self.rank] * (L.f[self.rank + 1].impl + L.f[self.rank + 1].expl),
L.uend,
op=MPI.SUM,
)
L.uend += L.u[0]

# add up tau correction of the full interval (last entry)
if L.tau[-1] is not None:
L.uend += L.tau[-1]
return None
108 changes: 108 additions & 0 deletions pySDC/tests/test_sweepers/test_MPI_sweeper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import pytest


def run(use_MPI, num_nodes, quad_type, residual_type, imex):
"""
Run a single sweep for a problem and compute the solution at the end point with a sweeper as specified.
Args:
use_MPI (bool): Use the MPI version of the sweeper or not
num_nodes (int): The number of nodes to use
quad_type (str): Type of nodes
residual_type (str): Type of residual computation
imex (bool): Use IMEX sweeper or not
Returns:
pySDC.Level.level: The level containing relevant data
"""
import numpy as np
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI

if not imex:
if use_MPI:
from pySDC.implementations.sweeper_classes.generic_implicit_MPI import generic_implicit_MPI as sweeper_class
else:
from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit as sweeper_class

from pySDC.implementations.problem_classes.TestEquation_0D import testequation0d as problem_class
else:
if use_MPI:
from pySDC.implementations.sweeper_classes.imex_1st_order_MPI import imex_1st_order_MPI as sweeper_class
else:
from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order as sweeper_class

from pySDC.implementations.problem_classes.HeatEquation_ND_FD import heatNd_forced as problem_class

dt = 1e-1
sweeper_params = {'num_nodes': num_nodes, 'quad_type': quad_type, 'QI': 'IEpar', 'QE': 'PIC'}
description = {}
description['problem_class'] = problem_class
description['sweeper_class'] = sweeper_class
description['sweeper_params'] = sweeper_params
description['level_params'] = {'dt': dt, 'residual_type': residual_type}
description['step_params'] = {'maxiter': 1}

controller = controller_nonMPI(1, {'logger_level': 30}, description)

if imex:
u0 = controller.MS[0].levels[0].prob.u_exact(0)
else:
u0 = np.ones_like(controller.MS[0].levels[0].prob.u_exact(0))
controller.run(u0, 0, dt)
controller.MS[0].levels[0].sweep.compute_end_point()
return controller.MS[0].levels[0]


@pytest.mark.mpi4py
@pytest.mark.parametrize("num_nodes", [2])
@pytest.mark.parametrize("quad_type", ['GAUSS', 'RADAU-RIGHT'])
@pytest.mark.parametrize("residual_type", ['last_abs', 'full_rel'])
@pytest.mark.parametrize("imex", [True, False])
def test_sweeper(num_nodes, quad_type, residual_type, imex, launch=True):
"""
Make a test if the result matches between the MPI and non-MPI versions of a sweeper.
Tests solution at the right end point and the residual.
Args:
num_nodes (int): The number of nodes to use
quad_type (str): Type of nodes
residual_type (str): Type of residual computation
imex (bool): Use IMEX sweeper or not
launch (bool): If yes, it will launch `mpirun` with the required number of processes
"""
if launch:
import os
import subprocess

# Set python path once
my_env = os.environ.copy()
my_env['PYTHONPATH'] = '../../..:.'
my_env['COVERAGE_PROCESS_START'] = 'pyproject.toml'

cmd = f"mpirun -np {num_nodes} python {__file__} --test_sweeper {num_nodes} {quad_type} {residual_type} {imex}".split()

p = subprocess.Popen(cmd, env=my_env, cwd=".")

p.wait()
assert p.returncode == 0, 'ERROR: did not get return code 0, got %s with %2i processes' % (
p.returncode,
num_nodes,
)
else:
import numpy as np

imex = False if imex == 'False' else True
MPI = run(use_MPI=True, num_nodes=int(num_nodes), quad_type=quad_type, residual_type=residual_type, imex=imex)
nonMPI = run(
use_MPI=False, num_nodes=int(num_nodes), quad_type=quad_type, residual_type=residual_type, imex=imex
)

assert np.allclose(MPI.uend, nonMPI.uend, atol=1e-14), 'Got different solutions at end point!'
assert np.allclose(MPI.status.residual, nonMPI.status.residual, atol=1e-14), 'Got different residuals!'


if __name__ == '__main__':
import sys

if '--test_sweeper' in sys.argv:
test_sweeper(sys.argv[-4], sys.argv[-3], sys.argv[-2], sys.argv[-1], launch=False)

0 comments on commit 36e06fe

Please sign in to comment.