Skip to content

Commit

Permalink
Merge pull request #72 from pyroteus/70_sub_class
Browse files Browse the repository at this point in the history
Fix subclass issues with `GoalOrientedMeshSeq`
  • Loading branch information
jwallwork23 authored Dec 4, 2023
2 parents d9ddc0b + f57c890 commit eab630a
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 58 deletions.
40 changes: 24 additions & 16 deletions demos/burgers_oo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

# You may have noticed that the functions :func:`get_form`,
# :func:`get_solver`, :func:`get_initial_condition` and
# :func:`get_qoi` all take a :class:`MeshSeq` as input and return
# a function. If this all feels a lot like writing methods for a
# :class:`MeshSeq` subclass, that's because this is exactly what
# we are doing. The constructors for :class:`MeshSeq` and
# :class:`AdjointMeshSeq` simply take these functions and adopt
# :func:`get_qoi` all take a :class:`MeshSeq`, :class:`AdjointMeshSeq`
# or :class:`GoalOrientedMeshSeq` as input and return a function.
# If this all feels a lot like writing methods for a
# subclass, that's because this is exactly what we are doing.
# The constructors for :class:`MeshSeq`, :class:`AdjointMeshSeq` and
# :class:`GoalOrientedMeshSeq` simply take these functions and adopt
# them as methods. A more natural way to write the subclass yourself.
#
# In the following, we mostly copy the contents from the previous
Expand All @@ -24,7 +25,7 @@
set_log_level(DEBUG)


class BurgersMeshSeq(AdjointMeshSeq):
class BurgersMeshSeq(GoalOrientedMeshSeq):
@staticmethod
def get_function_spaces(mesh):
return {"u": VectorFunctionSpace(mesh, "CG", 2)}
Expand All @@ -35,7 +36,7 @@ def form(index, solutions):
P = self.time_partition

# Define constants
R = FunctionSpace(mesh_seq[index], "R", 0)
R = FunctionSpace(self[index], "R", 0)
dt = Function(R).assign(P.timesteps[index])
nu = Function(R).assign(0.0001)

Expand Down Expand Up @@ -114,20 +115,27 @@ def time_integrated_qoi(t):
end_time = 0.5
dt = 1 / n
num_subintervals = len(meshes)
P = TimePartition(
time_partition = TimePartition(
end_time, num_subintervals, dt, ["u"], num_timesteps_per_export=2
)
mesh_seq = BurgersMeshSeq(P, meshes, qoi_type="end_time")
solutions = mesh_seq.solve_adjoint()
mesh_seq = BurgersMeshSeq(time_partition, meshes, qoi_type="time_integrated")
solutions, indicators = mesh_seq.indicate_errors(
enrichment_kwargs={"enrichment_method": "h"}
)

# Plotting this, we find that the results are identical to those generated previously. ::
# Plotting this, we find that the results are consistent with those generated previously. ::

fig, axes, tcs = plot_snapshots(
solutions, P, "u", "adjoint", levels=np.linspace(0, 0.8, 9)
)
fig.savefig("burgers-oo.jpg")
fig, axes, tcs = plot_indicator_snapshots(indicators, time_partition, "u", levels=50)
fig.savefig("burgers-oo_ee.jpg")

# .. figure:: burgers-oo_ee.jpg
# :figwidth: 90%
# :align: center

fig, axes, tcs = plot_snapshots(solutions, time_partition, "u", "adjoint")
fig.savefig("burgers-oo-time_integrated.jpg")

# .. figure:: burgers-oo.jpg
# .. figure:: burgers-oo-time_integrated.jpg
# :figwidth: 90%
# :align: center

Expand Down
38 changes: 18 additions & 20 deletions goalie/go_mesh_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@ def get_enriched_mesh_seq(
self, enrichment_method: str = "p", num_enrichments: int = 1
) -> AdjointMeshSeq:
"""
Solve the forward and adjoint problems
associated with
:meth:`~.GoalOrientedMeshSeq.solver`
in a sequence of globally enriched spaces.
Construct a sequence of globally enriched spaces.
Currently, global enrichment may be
achieved using one of:
Expand All @@ -55,24 +52,11 @@ def get_enriched_mesh_seq(
else:
meshes = self.meshes

def get_function_spaces(mesh):
"""
Apply p-refinement, if requested.
"""
if enrichment_method == "h":
return self._get_function_spaces(mesh)
enriched_spaces = {}
for label, fs in self.function_spaces.items():
element = fs[0].ufl_element()
element = element.reconstruct(degree=element.degree() + num_enrichments)
enriched_spaces[label] = FunctionSpace(mesh, element)
return enriched_spaces

# Construct enriched AdjointMeshSeq
return AdjointMeshSeq(
# Construct object to hold enriched spaces
mesh_seq_e = self.__class__(
self.time_partition,
meshes,
get_function_spaces=get_function_spaces,
get_function_spaces=self._get_function_spaces,
get_initial_condition=self._get_initial_condition,
get_form=self._get_form,
get_solver=self._get_solver,
Expand All @@ -81,6 +65,20 @@ def get_function_spaces(mesh):
qoi_type=self.qoi_type,
)

# Apply p-refinement
if enrichment_method == "p":
for label, fs in mesh_seq_e.function_spaces.items():
for n, _space in enumerate(fs):
element = _space.ufl_element()
element = element.reconstruct(
degree=element.degree() + num_enrichments
)
mesh_seq_e._fs[label][n] = FunctionSpace(
mesh_seq_e.meshes[n], element
)

return mesh_seq_e

@PETSc.Log.EventDecorator("goalie.GoalOrientedMeshSeq.global_enrichment")
def global_enrichment(
self, enrichment_method: str = "p", num_enrichments: int = 1, **kwargs
Expand Down
49 changes: 29 additions & 20 deletions goalie/mesh_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,26 +63,7 @@ def __init__(self, time_partition: TimePartition, initial_meshes: list, **kwargs
}
self.subintervals = time_partition.subintervals
self.num_subintervals = time_partition.num_subintervals
self.meshes = initial_meshes
if not isinstance(self.meshes, Iterable):
self.meshes = [Mesh(initial_meshes) for subinterval in self.subintervals]
self.element_counts = [self.count_elements()]
self.vertex_counts = [self.count_vertices()]
dim = np.array([mesh.topological_dimension() for mesh in self.meshes])
if dim.min() != dim.max():
raise ValueError("Meshes must all have the same topological dimension.")
self.dim = dim.min()
if logger.level == DEBUG:
for i, mesh in enumerate(self.meshes):
nc = mesh.num_cells()
nv = mesh.num_vertices()
qm = QualityMeasure(mesh)
ar = qm("aspect_ratio")
mar = ar.vector().gather().max()
self.debug(
f"{i}: {nc:7d} cells, {nv:7d} vertices, max aspect ratio {mar:.2f}"
)
debug(100 * "-")
self.set_meshes(initial_meshes)
self._fs = None
self._get_function_spaces = kwargs.get("get_function_spaces")
self._get_initial_condition = kwargs.get("get_initial_condition")
Expand Down Expand Up @@ -136,6 +117,34 @@ def count_elements(self) -> list:
def count_vertices(self) -> list:
return [mesh.num_vertices() for mesh in self] # TODO: make parallel safe

def set_meshes(self, meshes):
"""
Update the meshes associated with the :class:`MeshSeq`, as well as the
associated attributes.
:arg meshes: mesh or list of meshes to use in the sequence
"""
if not isinstance(meshes, Iterable):
meshes = [Mesh(meshes) for subinterval in self.subintervals]
self.meshes = meshes
dim = np.array([mesh.topological_dimension() for mesh in meshes])
if dim.min() != dim.max():
raise ValueError("Meshes must all have the same topological dimension.")
self.dim = dim.min()
self.element_counts = [self.count_elements()]
self.vertex_counts = [self.count_vertices()]
if logger.level == DEBUG:
for i, mesh in enumerate(meshes):
nc = mesh.num_cells()
nv = mesh.num_vertices()
qm = QualityMeasure(mesh)
ar = qm("aspect_ratio")
mar = ar.vector().gather().max()
self.debug(
f"{i}: {nc:7d} cells, {nv:7d} vertices, max aspect ratio {mar:.2f}"
)
debug(100 * "-")

def plot(
self, **kwargs
) -> Tuple[matplotlib.figure.Figure, matplotlib.axes._axes.Axes]:
Expand Down
2 changes: 1 addition & 1 deletion goalie/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def Mesh(arg, **kwargs) -> firedrake.mesh.MeshGeometry:
mesh.boundary_area = bnd_len

# Cell size
if dim == 2 and mesh.coordinates.ufl_element().cell() == ufl.triangle:
if dim == 2 and mesh.coordinates.ufl_element().cell == ufl.triangle:
mesh.delta_x = firedrake.interpolate(ufl.CellDiameter(mesh), P0)

return mesh
Expand Down
4 changes: 3 additions & 1 deletion test/test_error_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ def test_time_partition_different_timesteps(self):
estimator = indicators2estimator(
{"field": [[indicator], 2 * [indicator]]}, time_partition
)
self.assertAlmostEqual(estimator, 1) # 0.5 * (0.5 + 0.5) + 0.25 * 2 * (0.5 + 0.5)
self.assertAlmostEqual(
estimator, 1
) # 0.5 * (0.5 + 0.5) + 0.25 * 2 * (0.5 + 0.5)

def test_time_instant_multiple_fields(self):
time_instant = TimeInstant(["field1", "field2"], time=1.0)
Expand Down

0 comments on commit eab630a

Please sign in to comment.