diff --git a/movement/laplacian_smoothing.py b/movement/laplacian_smoothing.py index 1ee752e..1ae619a 100644 --- a/movement/laplacian_smoothing.py +++ b/movement/laplacian_smoothing.py @@ -1,4 +1,5 @@ import firedrake +import firedrake.exceptions as fexc import numpy as np import ufl from firedrake.petsc import PETSc @@ -86,7 +87,10 @@ def move(self, time, update_boundary_velocity=None, boundary_conditions=None): # Solve on computational mesh self.mesh.coordinates.assign(self.xi) - self._solver.solve() + try: + self._solver.solve() + except fexc.ConvergenceError as conv_err: + self._convergence_error(exception=conv_err) # Update mesh coordinates self.displacement[:] = self.v.dat.data_with_halos * self.dt diff --git a/movement/monge_ampere.py b/movement/monge_ampere.py index 715cee2..520f9de 100644 --- a/movement/monge_ampere.py +++ b/movement/monge_ampere.py @@ -2,6 +2,7 @@ from warnings import warn import firedrake +import firedrake.exceptions as fexc import numpy as np import ufl from firedrake.petsc import PETSc @@ -95,7 +96,7 @@ def __init__(self, mesh, monitor_function, **kwargs): self.rtol = kwargs.pop("rtol", 1.0e-08) self.dtol = kwargs.pop("dtol", 2.0) self.fix_boundary_nodes = kwargs.pop("fix_boundary_nodes", False) - super().__init__(mesh, monitor_function=monitor_function) + super().__init__(mesh, monitor_function=monitor_function, **kwargs) # Create function spaces self.P0 = firedrake.FunctionSpace(self.mesh, "DG", 0) @@ -234,7 +235,8 @@ def l2_projector(self): bbc = None # Periodic case else: warn( - "Have you checked that all straight line segments are uniquely tagged?" + "Have you checked that all straight line segments are uniquely" + " tagged?" ) corners = [(i, j) for i in edges for j in edges.difference([i])] bbc = firedrake.DirichletBC(self.P1_vec, 0, corners) @@ -359,12 +361,13 @@ def equidistributor(self): @PETSc.Log.EventDecorator() def move(self): - """ + r""" Run the relaxation method to convergence and update the mesh. :return: the iteration count :rtype: :class:`int` """ + # Take iterations of the relaxed system until reaching convergence for i in range(self.maxiter): self.l2_projector.solve() self._update_coordinates() @@ -388,24 +391,21 @@ def move(self): f" Residual {residual:10.4e}" f" Variation (σ/μ) {cv:10.4e}" ) - plural = "s" if i != 0 else "" if residual < self.rtol: - PETSc.Sys.Print(f"Converged in {i+1} iteration{plural}.") + self._convergence_message(i + 1) break if residual > self.dtol * initial_norm: - raise firedrake.ConvergenceError( - f"Diverged after {i+1} iteration{plural}." - ) + self._divergence_error(i + 1) if i == self.maxiter - 1: - raise firedrake.ConvergenceError( - f"Failed to converge in {i+1} iteration{plural}." - ) + self._convergence_error(i + 1) # Apply pseudotimestepper and equidistributor self.pseudotimestepper.solve() self.equidistributor.solve() self.phi_old.assign(self.phi) self.sigma_old.assign(self.sigma) + + # Update mesh coordinates accordingly self._update_coordinates() return i @@ -552,22 +552,20 @@ def monitor(snes, i, rnorm): @PETSc.Log.EventDecorator() def move(self): - """ + r""" Run the quasi-Newton method to convergence and update the mesh. :return: the iteration count :rtype: :class:`int` """ + # Solve equidistribution problem, handling convergence errors according to + # desired behaviour try: self.equidistributor.solve() - i = self.snes.getIterationNumber() - plural = "s" if i != 1 else "" - PETSc.Sys.Print(f"Converged in {i} iteration{plural}.") - except firedrake.ConvergenceError: - i = self.snes.getIterationNumber() - plural = "s" if i != 1 else "" - raise firedrake.ConvergenceError( - f"Failed to converge in {i} iteration{plural}." - ) + self._convergence_message(self.snes.getIterationNumber()) + except fexc.ConvergenceError as conv_err: + self._convergence_error(self.snes.getIterationNumber(), exception=conv_err) + + # Update mesh coordinates accordingly self._update_coordinates() - return i + return self.snes.getIterationNumber() diff --git a/movement/mover.py b/movement/mover.py index cbd1ab9..3a80f98 100644 --- a/movement/mover.py +++ b/movement/mover.py @@ -1,18 +1,40 @@ +from warnings import warn + import firedrake +import firedrake.exceptions as fexc import numpy as np from firedrake.cython.dmcommon import create_section +from firedrake.petsc import PETSc __all__ = ["PrimeMover"] -class PrimeMover(object): +class PrimeMover: """ Base class for all mesh movers. """ - def __init__(self, mesh, monitor_function=None, **kwargs): + def __init__( + self, mesh, monitor_function=None, raise_convergence_errors=True, **kwargs + ): + r""" + :arg mesh: the physical mesh + :type mesh: :class:`firedrake.mesh.MeshGeometry` + :arg monitor_function: a Python function which takes a mesh as input + :type monitor_function: :class:`~.Callable` + :kwarg raise_convergence_errors: convergence error handling behaviour: if `True` + then :class:`~.ConvergenceError`\s are raised, else warnings are raised and + the program is allowed to continue + :kwarg raise_convergence_errors: :class:`bool` + """ self.mesh = firedrake.Mesh(mesh.coordinates.copy(deepcopy=True)) self.monitor_function = monitor_function + if not raise_convergence_errors: + warn( + f"{type(self)}.move called with raise_convergence_errors=False." + " Beware: this option can produce poor quality meshes!" + ) + self.raise_convergence_errors = raise_convergence_errors self.dim = self.mesh.topological_dimension() self.gdim = self.mesh.geometric_dimension() self.plex = self.mesh.topology_dm @@ -33,6 +55,66 @@ def __init__(self, mesh, monitor_function=None, **kwargs): ) self.v = firedrake.Function(self.coord_space, name="Mesh velocity") + def _convergence_message(self, iterations=None): + """ + Report solver convergence. + + :kwarg iterations: number of iterations before reaching convergence + :type iterations: :class:`int` + """ + msg = "Solver converged" + if iterations: + msg += f" in {iterations} iteration{plural(iterations)}" + PETSc.Sys.Print(f"{msg}.") + + def _exception(self, msg, exception=None, error_type=fexc.ConvergenceError): + """ + Raise an error or warning as indicated by the :attr:`raise_convergence_error` + option. + + :arg msg: message for the error/warning report + :type msg: :class:`str` + :kwarg exception: original exception that was triggered + :type exception: :class:`~.Exception` object + :kwarg error_type: error class to use + :type error_type: :class:`~.Exception` class + """ + exc_type = error_type if self.raise_convergence_errors else Warning + if exception: + raise exc_type(msg) from exception + else: + raise exc_type(msg) + + def _convergence_error(self, iterations=None, exception=None): + """ + Raise an error or warning for a solver fail as indicated by the + :attr:`raise_convergence_error` option. + + :kwarg iterations: number of iterations before failure + :type iterations: :class:`int` + :kwarg exception: original exception that was triggered + :type exception: :class:`~.Exception` + """ + msg = "Solver failed to converge" + if iterations: + msg += f" in {iterations} iteration{plural(iterations)}" + self._exception(f"{msg}.", exception=exception) + + def _divergence_error(self, iterations=None, exception=None): + """ + Raise an error or warning for a solver divergence as indicated by the + :attr:`raise_convergence_error` option. + + :kwarg iterations: number of iterations before failure + :type iterations: :class:`int` + :kwarg exception: original exception that was triggered + :type exception: :class:`~.Exception` + """ + msg = "Solver diverged" + if iterations: + msg += f" after {iterations} iteration{plural(iterations)}" + self._exception(f"{msg}.", exception=exception) + def _get_coordinate_section(self): entity_dofs = np.zeros(self.dim + 1, dtype=np.int32) entity_dofs[0] = self.gdim @@ -91,11 +173,13 @@ def adapt(self): """ Alias of `move`. """ - from warnings import warn - warn( "`adapt` is deprecated (use `move` instead)", DeprecationWarning, stacklevel=2, ) return self.move() + + +def plural(iterations): + return "s" if iterations != 1 else "" diff --git a/movement/spring.py b/movement/spring.py index ed2e24b..ce060d3 100644 --- a/movement/spring.py +++ b/movement/spring.py @@ -262,7 +262,10 @@ def move(self, time, update_boundary_displacement=None, boundary_conditions=None # Assemble and solve the linear system K = self.assemble_stiffness_matrix(boundary_conditions=boundary_conditions) - self.displacement = np.linalg.solve(K, self._forcing.flatten()) * self.dt + try: + self.displacement = np.linalg.solve(K, self._forcing.flatten()) * self.dt + except Exception as conv_err: + self._convergence_error(exception=conv_err) # Update mesh coordinates shape = self.mesh.coordinates.dat.data_with_halos.shape diff --git a/test/test_monge_ampere.py b/test/test_monge_ampere.py index 6f3506f..3fb8cca 100644 --- a/test/test_monge_ampere.py +++ b/test/test_monge_ampere.py @@ -165,17 +165,21 @@ def test_maxiter_convergenceerror(self, method): mover = MongeAmpereMover(mesh, ring_monitor, method=method, maxiter=1) with self.assertRaises(ConvergenceError) as cm: mover.move() - self.assertEqual(str(cm.exception), "Failed to converge in 1 iteration.") + self.assertEqual(str(cm.exception), "Solver failed to converge in 1 iteration.") - def test_divergence_convergenceerror(self): + @parameterized.expand([(True,), (False,)]) + def test_divergence_convergenceerror(self, raise_errors): """ - Test that the mesh mover raises a :class:`~.ConvergenceError` if it diverges. + Test that divergence of the mesh mover raises a :class:`~.ConvergenceError` if + `raise_errors=True` and a :class:`~.Warning` otherwise. """ mesh = self.mesh(2, n=4) - mover = MongeAmpereMover_Relaxation(mesh, ring_monitor, dtol=1.0e-08) - with self.assertRaises(ConvergenceError) as cm: + mover = MongeAmpereMover_Relaxation( + mesh, ring_monitor, dtol=1.0e-08, raise_convergence_errors=raise_errors + ) + with self.assertRaises(ConvergenceError if raise_errors else Warning) as cm: mover.move() - self.assertEqual(str(cm.exception), "Diverged after 1 iteration.") + self.assertEqual(str(cm.exception), "Solver diverged after 1 iteration.") def test_initial_guess_valueerror(self): mesh = self.mesh(2, n=2)