Skip to content

Commit

Permalink
Inexactness and some tests (#376)
Browse files Browse the repository at this point in the history
* Inexactness and some other stuff

* Fixes

* Fixes

* Hopefully fixed a flaky test

* IMEX version of the polynomial test problem
  • Loading branch information
brownbaerchen authored Nov 10, 2023
1 parent 39a208c commit 39903ea
Show file tree
Hide file tree
Showing 10 changed files with 281 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,8 @@ def setup(self, controller, params, description, **kwargs):
default_params = {
'Taylor_order': 2 * num_nodes,
'n': num_nodes,
'recompute_coefficients': False,
**params,
}

return {**super().setup(controller, params, description, **kwargs), **default_params}
Expand All @@ -485,18 +487,21 @@ def post_iteration_processing(self, controller, S, **kwargs):
t_eval = S.time + nodes_[-1]

dts = np.append(nodes_[0], nodes_[1:] - nodes_[:-1])
self.params.Taylor_order = 2 * len(nodes)
self.params.Taylor_order = len(nodes)
self.params.n = len(nodes)

# compute the extrapolation coefficients
# TODO: Maybe this can be reused
self.get_extrapolation_coefficients(nodes, dts, t_eval)
if None in self.coeff.u or self.params.recompute_coefficients:
self.get_extrapolation_coefficients(nodes, dts, t_eval)

# compute the extrapolated solution
if lvl.f[0] is None:
lvl.f[0] = lvl.prob.eval_f(lvl.u[0], lvl.time)

if type(lvl.f[0]) == imex_mesh:
f = [me.impl + me.expl for me in lvl.f]
f = [lvl.f[i].impl + lvl.f[i].expl if self.coeff.f[i] and lvl.f[i] else 0.0 for i in range(len(lvl.f) - 1)]
elif type(lvl.f[0]) == mesh:
f = lvl.f
f = [lvl.f[i] if self.coeff.f[i] else 0.0 for i in range(len(lvl.f) - 1)]
else:
raise DataError(
f"Unable to store f from datatype {type(lvl.f[0])}, extrapolation based error estimate only\
Expand All @@ -506,7 +511,7 @@ def post_iteration_processing(self, controller, S, **kwargs):
# compute the error with the weighted sum
if self.comm:
idx = (self.comm.rank + 1) % self.comm.size
sendbuf = self.coeff.u[idx] * lvl.u[idx] + self.coeff.f[idx] * lvl.f[idx]
sendbuf = self.coeff.u[idx] * lvl.u[idx] + self.coeff.f[idx] * f[idx]
u_ex = lvl.prob.dtype_u(lvl.prob.init, val=0.0) if self.comm.rank == self.comm.size - 1 else None
self.comm.Reduce(sendbuf, u_ex, op=self.sum, root=self.comm.size - 1)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,39 @@ def setup(self, controller, params, description, **kwargs):
"ratio": 1e-2,
"min_tol": 0,
"max_tol": 1e99,
"maxiter": None,
"use_e_tol": 'e_tol' in description['level_params'].keys(),
"initial_tol": 1e-3,
**super().setup(controller, params, description, **kwargs),
}
return {**defaults, **super().setup(controller, params, description, **kwargs)}
if defaults['maxiter']:
self.set_maxiter(description, defaults['maxiter'])
return defaults

def dependencies(self, controller, description, **kwargs):
"""
Load the embedded error estimator if needed.
Args:
controller (pySDC.Controller): The controller
description (dict): The description object used to instantiate the controller
Returns:
None
"""
super().dependencies(controller, description)

if self.params.use_e_tol:
from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import (
EstimateEmbeddedError,
)

controller.add_convergence_controller(
EstimateEmbeddedError,
description=description,
)

return None

def post_iteration_processing(self, controller, step, **kwargs):
"""
Expand All @@ -39,8 +70,18 @@ def post_iteration_processing(self, controller, step, **kwargs):
None
"""
for lvl in step.levels:
lvl.prob.newton_tol = max(
[min([lvl.status.residual * self.params.ratio, self.params.max_tol]), self.params.min_tol]
SDC_accuracy = (
lvl.status.get('error_embedded_estimate', lvl.status.residual)
if self.params.use_e_tol
else lvl.status.residual
)
SDC_accuracy = self.params.initial_tol if SDC_accuracy is None else SDC_accuracy
tol = max([min([SDC_accuracy * self.params.ratio, self.params.max_tol]), self.params.min_tol])
self.set_tolerance(lvl, tol)
self.log(f'Changed tolerance to {tol:.2e}', step)

def set_tolerance(self, lvl, tol):
lvl.prob.newton_tol = tol

self.log(f'Changed Newton tolerance to {lvl.prob.newton_tol:.2e}', step)
def set_maxiter(self, description, maxiter):
description['problem_params']['newton_maxiter'] = maxiter
77 changes: 65 additions & 12 deletions pySDC/implementations/problem_classes/AllenCahn_2D_FD.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from scipy.sparse.linalg import cg

from pySDC.core.Errors import ParameterError, ProblemError
from pySDC.core.Problem import ptype
from pySDC.core.Problem import ptype, WorkCounter
from pySDC.helpers import problem_helper
from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh, comp2_mesh

Expand Down Expand Up @@ -78,6 +78,7 @@ def __init__(
newton_tol=1e-12,
lin_tol=1e-8,
lin_maxiter=100,
inexact_linear_ratio=None,
radius=0.25,
order=2,
):
Expand All @@ -96,14 +97,19 @@ def __init__(
'nvars',
'nu',
'eps',
'radius',
'order',
localVars=locals(),
readOnly=True,
)
self._makeAttributeAndRegister(
'newton_maxiter',
'newton_tol',
'lin_tol',
'lin_maxiter',
'radius',
'order',
'inexact_linear_ratio',
localVars=locals(),
readOnly=True,
readOnly=False,
)

# compute dx and get discretization matrix A
Expand All @@ -124,6 +130,10 @@ def __init__(
self.newton_ncalls = 0
self.lin_ncalls = 0

self.work_counters['newton'] = WorkCounter()
self.work_counters['rhs'] = WorkCounter()
self.work_counters['linear'] = WorkCounter()

@staticmethod
def __get_A(N, dx):
"""
Expand Down Expand Up @@ -198,6 +208,10 @@ def solve_system(self, rhs, factor, u0, t):
# if g is close to 0, then we are done
res = np.linalg.norm(g, np.inf)

# do inexactness in the linear solver
if self.inexact_linear_ratio:
self.lin_tol = res * self.inexact_linear_ratio

if res < self.newton_tol:
break

Expand All @@ -206,11 +220,15 @@ def solve_system(self, rhs, factor, u0, t):

# newton update: u1 = u0 - g/dg
# u -= spsolve(dg, g)
u -= cg(dg, g, x0=z, tol=self.lin_tol, atol=0)[0]
u -= cg(
dg, g, x0=z, tol=self.lin_tol, maxiter=self.lin_maxiter, atol=0, callback=self.work_counters['linear']
)[0]
# increase iteration count
n += 1
# print(n, res)

self.work_counters['newton']()

# if n == self.newton_maxiter:
# raise ProblemError('Newton did not converge after %i iterations, error is %s' % (n, res))

Expand Down Expand Up @@ -242,9 +260,10 @@ def eval_f(self, u, t):
v = u.flatten()
f[:] = (self.A.dot(v) + 1.0 / self.eps**2 * v * (1.0 - v**self.nu)).reshape(self.nvars)

self.work_counters['rhs']()
return f

def u_exact(self, t):
def u_exact(self, t, u_init=None, t_init=None):
r"""
Routine to compute the exact solution at time :math:`t`.
Expand All @@ -258,13 +277,19 @@ def u_exact(self, t):
me : dtype_u
The exact solution.
"""

assert t == 0, 'ERROR: u_exact only valid for t=0'
me = self.dtype_u(self.init, val=0.0)
for i in range(self.nvars[0]):
for j in range(self.nvars[1]):
r2 = self.xvalues[i] ** 2 + self.xvalues[j] ** 2
me[i, j] = np.tanh((self.radius - np.sqrt(r2)) / (np.sqrt(2) * self.eps))
if t > 0:

def eval_rhs(t, u):
return self.eval_f(u.reshape(self.init[0]), t).flatten()

me[:] = self.generate_scipy_reference_solution(eval_rhs, t, u_init, t_init)

else:
for i in range(self.nvars[0]):
for j in range(self.nvars[1]):
r2 = self.xvalues[i] ** 2 + self.xvalues[j] ** 2
me[i, j] = np.tanh((self.radius - np.sqrt(r2)) / (np.sqrt(2) * self.eps))

return me

Expand Down Expand Up @@ -310,6 +335,7 @@ def eval_f(self, u, t):
f.impl[:] = self.A.dot(v).reshape(self.nvars)
f.expl[:] = (1.0 / self.eps**2 * v * (1.0 - v**self.nu)).reshape(self.nvars)

self.work_counters['rhs']()
return f

def solve_system(self, rhs, factor, u0, t):
Expand Down Expand Up @@ -338,6 +364,7 @@ class context:

def callback(xk):
context.num_iter += 1
self.work_counters['linear']()
return context.num_iter

me = self.dtype_u(self.init)
Expand All @@ -359,6 +386,32 @@ def callback(xk):

return me

def u_exact(self, t, u_init=None, t_init=None):
"""
Routine to compute the exact solution at time t.
Parameters
----------
t : float
Time of the exact solution.
Returns
-------
me : dtype_u
The exact solution.
"""
me = self.dtype_u(self.init, val=0.0)
if t > 0:

def eval_rhs(t, u):
f = self.eval_f(u.reshape(self.init[0]), t)
return (f.impl + f.expl).flatten()

me[:] = self.generate_scipy_reference_solution(eval_rhs, t, u_init, t_init)
else:
me[:] = super().u_exact(t, u_init, t_init)
return me


# noinspection PyUnusedLocal
class allencahn_semiimplicit_v2(allencahn_fullyimplicit):
Expand Down
12 changes: 9 additions & 3 deletions pySDC/implementations/problem_classes/Quench.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ class Quench(ptype):
Maximum number of linear iterations inside the Newton solver.
direct_solver : bool, optional
Indicates if a direct solver should be used.
inexact_linear_ratio : float, optional
Ratio of tolerance of linear solver to the Newton residual, overrides `lintol`
min_lintol : float, optional
Minimal tolerance for the linear solver
reference_sol_type : str, optional
Indicates which method should be used to compute a reference solution.
Choose between ``'scipy'``, ``'SDC'``, or ``'DIRK'``.
Expand Down Expand Up @@ -106,6 +110,7 @@ def __init__(
liniter=99,
direct_solver=True,
inexact_linear_ratio=None,
min_lintol=1e-12,
reference_sol_type='scipy',
):
"""
Expand Down Expand Up @@ -142,6 +147,7 @@ def __init__(
'lintol',
'liniter',
'inexact_linear_ratio',
'min_lintol',
localVars=locals(),
readOnly=False,
)
Expand Down Expand Up @@ -307,11 +313,11 @@ def solve_system(self, rhs, factor, u0, t):
u = self.dtype_u(u0)
res = np.inf
delta = self.dtype_u(self.init, val=0.0)
z = self.dtype_u(self.init, val=0.0)

# construct a preconditioner for the space solver
if not self.direct_solver:
M = inv((self.Id - factor * self.A).toarray())
zero = self.dtype_u(self.init, val=0.0)

for n in range(0, self.newton_maxiter):
# assemble G such that G(u) = 0 at the solution of the step
Expand All @@ -325,7 +331,7 @@ def solve_system(self, rhs, factor, u0, t):
break

if self.inexact_linear_ratio:
self.lintol = max([res * self.inexact_linear_ratio, 1e-12])
self.lintol = max([res * self.inexact_linear_ratio, self.min_lintol])

# assemble Jacobian J of G
J = self.Id - factor * (self.A + self.get_non_linear_Jacobian(u))
Expand All @@ -337,7 +343,7 @@ def solve_system(self, rhs, factor, u0, t):
delta, info = gmres(
J,
G,
x0=z,
x0=zero,
M=M,
tol=self.lintol,
maxiter=self.liniter,
Expand Down
34 changes: 33 additions & 1 deletion pySDC/implementations/problem_classes/polynomial_test_problem.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from pySDC.core.Problem import ptype
from pySDC.implementations.datatype_classes.mesh import mesh
from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh


class polynomial_testequation(ptype):
Expand Down Expand Up @@ -88,3 +88,35 @@ def u_exact(self, t, **kwargs):
me = self.dtype_u(self.init)
me[:] = self.poly(t)
return me


class polynomial_testequation_IMEX(polynomial_testequation):
"""
IMEX version of the polynomial test problem that assigns half the derivative to the implicit part and the other half to the explicit part.
Keep in mind that you still cannot Really perform any solves.
"""

dtype_f = imex_mesh

def eval_f(self, u, t):
"""
Derivative of the polynomial.
Parameters
----------
u : dtype_u
Current values of the numerical solution.
t : float
Current time of the numerical solution is computed.
Returns
-------
f : dtype_f
The right-hand side of the problem.
"""

f = self.dtype_f(self.init)
derivative = self.poly.deriv(m=1)(t)
f.impl[:] = derivative / 2
f.expl[:] = derivative / 2
return f
4 changes: 2 additions & 2 deletions pySDC/projects/Resilience/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -1457,8 +1457,8 @@ def get_reference_value(self, problem, key, op, num_procs=1):
"""
if problem.__name__ == "run_vdp":
if key == 'work_newton' and op == sum:
return 2677
return 3443
elif key == 'e_global_post_run' and op == max:
return 4.375184403937471e-06
return 4.929282266752377e-06

raise NotImplementedError('The reference value you are looking for is not implemented for this strategy!')
Loading

0 comments on commit 39903ea

Please sign in to comment.