Skip to content

Commit

Permalink
Fix specific boundary segments under MA (#115)
Browse files Browse the repository at this point in the history
Closes #112.

Note that this PR merges into #95 to further improve the testing of
boundary cases with Monge-Ampere.
  • Loading branch information
jwallwork23 authored Sep 3, 2024
1 parent 8112172 commit a823c26
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 71 deletions.
51 changes: 38 additions & 13 deletions movement/monge_ampere.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"""

import abc
from collections.abc import Iterable
from warnings import warn

import firedrake
import firedrake.exceptions as fexc
Expand Down Expand Up @@ -69,9 +71,11 @@ def MongeAmpereMover(mesh, monitor_function, method="relaxation", **kwargs):
:type dtol: :class:`float`
:kwarg pseudo_timestep: pseudo-timestep (only relevant to relaxation method)
:type pseudo_timestep: :class:`float`
:kwarg fix_boundary_nodes: should all boundary nodes remain fixed?
:type fix_boundary_nodes: :class:`bool`
:return: the Monge-Ampere Mover object
:kwarg fixed_boundary_segments: labels corresponding to boundary segments to be fixed
with a zero Dirichlet condition. The 'on_boundary' label indicates the whole
domain boundary
:type fixed_boundary_segments: :class:`list` of :class:`str` or :class:`int`
:return: the Monge-Ampère Mover object
:rtype: :class:`MongeAmpereMover_Relaxation` or
:class:`MongeAmpereMover_QuasiNewton`
"""
Expand Down Expand Up @@ -124,8 +128,10 @@ def __init__(self, mesh, monitor_function, **kwargs):
:type rtol: :class:`float`
:kwarg dtol: divergence tolerance for the residual
:type dtol: :class:`float`
:kwarg fix_boundary_nodes: should all boundary nodes remain fixed?
:type fix_boundary_nodes: :class:`bool`
:kwarg fixed_boundary_segments: labels corresponding to boundary segments to be
fixed with a zero Dirichlet condition. The 'on_boundary' label indicates
the whole domain boundary
:type fixed_boundary_segments: :class:`list` of :class:`str` or :class:`int`
"""
if monitor_function is None:
raise ValueError("Please supply a monitor function.")
Expand All @@ -138,10 +144,29 @@ def __init__(self, mesh, monitor_function, **kwargs):
self.maxiter = kwargs.pop("maxiter", 1000)
self.rtol = kwargs.pop("rtol", 1.0e-08)
self.dtol = kwargs.pop("dtol", 2.0)
self.fix_boundary_nodes = kwargs.pop("fix_boundary_nodes", False)
self.fixed_boundary_segments = kwargs.pop("fixed_boundary_segments", [])
super().__init__(mesh, monitor_function=monitor_function, **kwargs)
self.theta = firedrake.Constant(0.0)

# Handle boundary segments where zero Dirichlet conditions are applied
if self.fixed_boundary_segments == "on_boundary":
self.fixed_boundary_segments = self._all_boundary_segments
elif not isinstance(self.fixed_boundary_segments, Iterable):
self.fixed_boundary_segments = [self.fixed_boundary_segments]
if len(self._all_boundary_segments) == 0:
warn(
"Provided mesh has no boundary segments with Physical ID tags. If the "
"boundaries aren't fully periodic then this will likely cause errors."
)
elif (
len(self.fixed_boundary_segments) == 1
and self.fixed_boundary_segments[0] == "on_boundary"
):
self.fixed_boundary_segments = self._all_boundary_segments
for boundary_tag in self.fixed_boundary_segments:
if boundary_tag not in self._all_boundary_segments:
raise ValueError(f"Provided boundary_tag '{boundary_tag}' is invalid.")

def _create_function_spaces(self):
super()._create_function_spaces()
self.P1 = firedrake.FunctionSpace(self.mesh, "CG", 1)
Expand Down Expand Up @@ -232,7 +257,7 @@ def _l2_projector_bcs(self, boundary_tag):
:rtype: :class:`tuple` of :class:`~.DirichletBC`\s
"""
zero_bc = firedrake.DirichletBC(self.P1_vec, 0, boundary_tag)
if self.fix_boundary_nodes or self.dim == 1:
if (boundary_tag in self.fixed_boundary_segments) or self.dim == 1:
return (zero_bc,)

# If the boundary segment is axis-aligned, it is straightforward to avoid
Expand All @@ -257,7 +282,7 @@ def _l2_projector_bcs(self, boundary_tag):

# Determine the 'corner' vertices which are at the intersection of two boundary
# segments and create a Dirichlet condition for fixing them under mesh movement
facet_indices = set(self.mesh.exterior_facets.unique_markers)
facet_indices = set(self._all_boundary_segments)
ffacet_indices = [
(tag, boundary_tag) for tag in facet_indices.difference([boundary_tag])
]
Expand Down Expand Up @@ -306,11 +331,9 @@ def l2_projector(self):
# Enforce no movement normal to boundary
bcs = [
dirichlet_bc
for boundary_tag in self.mesh.exterior_facets.unique_markers
for boundary_tag in self._all_boundary_segments
for dirichlet_bc in self._l2_projector_bcs(boundary_tag)
]
if not bcs and self.fix_boundary_nodes:
raise ValueError("Cannot fix boundary nodes for periodic meshes.")

# Create solver
problem = firedrake.LinearVariationalProblem(a, L, self._grad_phi, bcs=bcs)
Expand Down Expand Up @@ -367,8 +390,10 @@ def __init__(self, mesh, monitor_function, phi_init=None, H_init=None, **kwargs)
:type rtol: :class:`float`
:kwarg dtol: divergence tolerance for the residual
:type dtol: :class:`float`
:kwarg fix_boundary_nodes: should all boundary nodes remain fixed?
:type fix_boundary_nodes: :class:`bool`
:kwarg fixed_boundary_segments: labels corresponding to boundary segments to be
fixed with a zero Dirichlet condition. The 'on_boundary' label indicates
the whole domain boundary
:type fixed_boundary_segments: :class:`list` of :class:`str` or :class:`int`
"""
self.pseudo_dt = firedrake.Constant(kwargs.pop("pseudo_timestep", 0.1))
super().__init__(mesh, monitor_function=monitor_function, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions movement/mover.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(

self._create_function_spaces()
self._create_functions()
self._all_boundary_segments = self.mesh.exterior_facets.unique_markers

# Utilities
if tangling_check is None:
Expand Down
150 changes: 92 additions & 58 deletions test/test_monge_ampere.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class BaseClasses:
class TestMongeAmpere(unittest.TestCase):
"""
Base class for Monge-ampere unit tests.
Base class for Monge-Ampère unit tests.
"""

def mesh(self, dim=1, n=10, periodic=False):
Expand All @@ -34,7 +34,7 @@ def dummy_monitor(self):

class TestExceptions(BaseClasses.TestMongeAmpere):
"""
Unit tests for exceptions raised by Monge-Ampere movers.
Unit tests for exceptions raised by Monge-Ampère movers.
"""

def test_method_valueerror(self):
Expand Down Expand Up @@ -109,14 +109,6 @@ def test_invalid_plane_valuerror(self):
msg = "Could not determine a plane for the provided points."
self.assertEqual(str(cm.exception), msg)

def test_periodic_fix_boundary_valueerror(self):
mesh = self.mesh(n=3, periodic=True)
mover = MongeAmpereMover_Relaxation(mesh, ring_monitor, fix_boundary_nodes=True)
with self.assertRaises(ValueError) as cm:
mover.move()
msg = "Cannot fix boundary nodes for periodic meshes."
self.assertEqual(str(cm.exception), msg)

def test_curved_notimplementederror(self):
coords = Function(VectorFunctionSpace(UnitTriangleMesh(), "CG", 2))
coords.interpolate(coords.function_space().mesh().coordinates)
Expand All @@ -139,10 +131,16 @@ def test_periodic_plex_valueerror(self):
msg = "Cannot update DMPlex coordinates for periodic meshes."
self.assertEqual(str(cm.exception), msg)

def test_fix_invalid_segment_valueerror(self):
with self.assertRaises(ValueError) as cm:
MongeAmpereMover(self.mesh(1), const_monitor, fixed_boundary_segments=[-1])
msg = "Provided boundary_tag '-1' is invalid."
self.assertEqual(str(cm.exception), msg)


class TestMonitor(BaseClasses.TestMongeAmpere):
"""
Unit tests for monitor functions used by Monge-Ampere movers.
Unit tests for monitor functions used by Monge-Ampère movers.
"""

@parameterized.expand(
Expand Down Expand Up @@ -231,30 +229,30 @@ def test_change_monitor(self, dim, method):

class TestBCs(BaseClasses.TestMongeAmpere):
"""
Unit tests for boundary conditions of Monge-Ampere movers.
Unit tests for boundary conditions of Monge-Ampère movers.
"""

def _test_boundary_preservation(self, mesh, method, fix_boundary):
def _test_boundary_preservation(self, mesh, method, fixed_boundaries):
bnd = assemble(Constant(1.0) * ds(domain=mesh))
bnodes = DirichletBC(mesh.coordinates.function_space(), 0, "on_boundary").nodes
coord_space = mesh.coordinates.function_space()
bnodes = DirichletBC(coord_space, 0, fixed_boundaries or "on_boundary").nodes
bnd_coords = mesh.coordinates.dat.data.copy()[bnodes]

# Adapt to a ring monitor
mover = MongeAmpereMover(
mesh,
ring_monitor,
method=method,
fix_boundary_nodes=fix_boundary,
fixed_boundary_segments=fixed_boundaries,
rtol=1e-3,
)
mover.move()

# Check boundary lengths are preserved
bnd_new = assemble(Constant(1.0) * ds(domain=mover.mesh))
self.assertAlmostEqual(bnd, bnd_new)
self.assertAlmostEqual(bnd, assemble(Constant(1.0) * mover.ds))

# Check boundaries are indeed fixed
if fix_boundary:
if fixed_boundaries:
bnd_coords_new = mover.mesh.coordinates.dat.data[bnodes]
self.assertTrue(np.allclose(bnd_coords, bnd_coords_new))
return mover
Expand All @@ -272,11 +270,11 @@ def _test_boundary_preservation(self, mesh, method, fix_boundary):
def test_periodic(self, dim, method):
"""
Test that periodic unit domains are not given boundary conditions by the
Monge-Ampere movers.
Monge-Ampère movers.
"""
mesh = self.mesh(dim=dim, periodic=True)
volume = assemble(Constant(1.0) * dx(domain=mesh))
mover = self._test_boundary_preservation(mesh, method, False)
mover = self._test_boundary_preservation(mesh, method, [])

# Check the volume of the domain is conserved
self.assertAlmostEqual(assemble(Constant(1.0) * dx(domain=mover.mesh)), volume)
Expand All @@ -293,27 +291,33 @@ def test_initial_guess_valueerror(self):

@parameterized.expand(
[
(1, "relaxation", True),
(1, "relaxation", False),
(1, "quasi_newton", True),
(1, "quasi_newton", False),
(2, "relaxation", True),
(2, "relaxation", False),
(2, "quasi_newton", True),
(2, "quasi_newton", False),
(3, "relaxation", True),
(3, "relaxation", False),
(3, "quasi_newton", True),
(3, "quasi_newton", False),
(1, "relaxation", "on_boundary"),
(1, "relaxation", [1]),
(1, "relaxation", []),
(1, "quasi_newton", "on_boundary"),
(1, "quasi_newton", [1]),
(1, "quasi_newton", []),
(2, "relaxation", "on_boundary"),
(2, "relaxation", [1]),
(2, "relaxation", []),
(2, "quasi_newton", "on_boundary"),
(2, "quasi_newton", [1]),
(2, "quasi_newton", []),
(3, "relaxation", "on_boundary"),
(3, "relaxation", [1]),
(3, "relaxation", []),
(3, "quasi_newton", "on_boundary"),
(3, "quasi_newton", [1]),
(3, "quasi_newton", []),
]
)
def test_boundary_preservation_axis_aligned(self, dim, method, fix_boundary):
def test_boundary_preservation_axis_aligned(self, dim, method, fixed_boundaries):
"""
Test that boundaries of unit domains are preserved by the Monge-Ampere movers.
Test that boundaries of unit domains are preserved by the Monge-Ampère movers.
"""
mesh = self.mesh(dim=dim)
volume = assemble(Constant(1.0) * dx(domain=mesh))
mover = self._test_boundary_preservation(mesh, method, fix_boundary)
mover = self._test_boundary_preservation(mesh, method, fixed_boundaries)

# Check the volume of the domain is conserved
self.assertAlmostEqual(assemble(Constant(1.0) * dx(domain=mover.mesh)), volume)
Expand All @@ -325,20 +329,26 @@ def test_boundary_preservation_axis_aligned(self, dim, method, fix_boundary):

@parameterized.expand(
[
(2, "relaxation", True),
(2, "relaxation", False),
(2, "quasi_newton", True),
(2, "quasi_newton", False),
(3, "relaxation", True),
(3, "relaxation", False),
(3, "quasi_newton", True),
(3, "quasi_newton", False),
(2, "relaxation", "on_boundary"),
(2, "relaxation", [1]),
(2, "relaxation", []),
(2, "quasi_newton", "on_boundary"),
(2, "quasi_newton", [1]),
(2, "quasi_newton", []),
(3, "relaxation", "on_boundary"),
(3, "relaxation", [1]),
(3, "relaxation", []),
(3, "quasi_newton", "on_boundary"),
(3, "quasi_newton", [1]),
(3, "quasi_newton", []),
]
)
def test_boundary_preservation_non_axis_aligned(self, dim, method, fix_boundary):
def test_boundary_preservation_non_axis_aligned(
self, dim, method, fixed_boundaries
):
"""
Test that boundaries of rotated unit domains are preserved by the
Monge-Ampere movers.
Monge-Ampère movers.
"""
mesh = self.mesh(dim=dim)
volume = assemble(Constant(1.0) * dx(domain=mesh))
Expand All @@ -354,29 +364,53 @@ def test_boundary_preservation_non_axis_aligned(self, dim, method, fix_boundary)
raise ValueError(f"Dimension {dim} not supported.")
coords = Function(mesh.coordinates.function_space())
coords.interpolate(ufl.dot(rotation_matrix, mesh.coordinates))
mover = self._test_boundary_preservation(Mesh(coords), method, fix_boundary)
mover = self._test_boundary_preservation(Mesh(coords), method, fixed_boundaries)

# Check the volume of the domain is conserved
self.assertAlmostEqual(assemble(Constant(1.0) * dx(domain=mover.mesh)), volume)

# If boundaries are not fixed then EquationBCs should be used for boundaries of
# the xy-plane
# If boundaries are fixed then they should have a single DirichletBC associated
# with them
# If boundaries are not fixed then they should have two EquationBCs associated
# with them
bcs = mover._l2_projector._problem.bcs
if fix_boundary:
self.assertTrue(len(bcs) == 2 * dim)
if fixed_boundaries == "on_boundary":
# All boundary segments are fixed => one DirichletBC per edge/face
self.assertEqual(len(bcs), 2 * dim)
self.assertTrue(all(isinstance(bc, DirichletBC) for bc in bcs))
elif dim == 2:
self.assertTrue(len(bcs) == 8)
self.assertTrue(all(isinstance(bc, EquationBC) for bc in bcs))
else:
self.assertTrue(len(bcs) == 10)
self.assertEqual(sum(isinstance(bc, EquationBC) for bc in bcs), 8)
self.assertEqual(sum(isinstance(bc, DirichletBC) for bc in bcs), 2)
elif fixed_boundaries == []:
if dim == 2:
# All four boundary edges have two EquationBCs each
self.assertEqual(len(bcs), 8)
self.assertTrue(all(isinstance(bc, EquationBC) for bc in bcs))
else:
# The four non-axis-aligned boundary faces have two EquationBCs each
# There are also two axis-aligned boundary faces, which have a
# DirichletBC each
self.assertEqual(len(bcs), 10)
self.assertEqual(sum(isinstance(bc, DirichletBC) for bc in bcs), 2)
elif fixed_boundaries == [1]:
if dim == 2:
# One of four non-axis-aligned boundary edges is fixed
# => one edge has a single DirichletBC and the other three have two
# EquationBCs each
self.assertEqual(len(bcs), 7)
self.assertEqual(sum(isinstance(bc, DirichletBC) for bc in bcs), 1)
self.assertEqual(sum(isinstance(bc, EquationBC) for bc in bcs), 6)
else:
# One of four non-axis-aligned boundary faces is fixed
# => one of these faces has a single DirichletBC and the other three have
# two EquationBCs each
# There are also two axis-aligned boundary faces, with a single
# DirichletBC each
self.assertEqual(len(bcs), 9)
self.assertEqual(sum(isinstance(bc, DirichletBC) for bc in bcs), 3)
self.assertEqual(sum(isinstance(bc, EquationBC) for bc in bcs), 6)


class TestMisc(BaseClasses.TestMongeAmpere):
"""
Unit tests for other misc. functionality of Monge-Ampere movers.
Unit tests for other misc. functionality of Monge-Ampère movers.
"""

@parameterized.expand(
Expand Down

0 comments on commit a823c26

Please sign in to comment.