Skip to content

Commit

Permalink
Introduce SolutionData and IndicatorData classes (#114)
Browse files Browse the repository at this point in the history
* #26: SolutionData base class and forward case

* #26: Hook up ForwardSolutionData

* #26: Support indexing and add protection

* #26: AdjointSolutionData classes

* #26: Hook up AdjointSolutionData

* #26: More classes of ForwardSolutionData

* #26: IndicatorData class

* #26: Hook up IndicatorData

* #26: Fix error indicator tests

* #26: Introduce FunctionData base class and pull up methods

* #26: Rename as function_data; fix label bug

* #26: Refactor to consider field type inside FunctionData, rather than outside

* #26: Checks for steady case
  • Loading branch information
jwallwork23 authored Feb 28, 2024
1 parent 6c32ae9 commit 4d141be
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 72 deletions.
32 changes: 7 additions & 25 deletions goalie/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import firedrake
from firedrake.petsc import PETSc
from firedrake.adjoint import pyadjoint
from .function_data import AdjointSolutionData
from .interpolation import project
from .mesh_seq import MeshSeq
from .options import GoalOrientedParameters
Expand Down Expand Up @@ -175,27 +176,7 @@ def get_solve_blocks(
return solve_blocks

def _create_solutions(self):
P = self.time_partition
labels = ("forward", "forward_old", "adjoint")
if not self.steady:
labels += ("adjoint_next",)
self._solutions = AttrDict(
{
field: AttrDict(
{
label: [
[
firedrake.Function(fs, name=f"{field}_{label}")
for j in range(P.num_exports_per_subinterval[i] - 1)
]
for i, fs in enumerate(self.function_spaces[field])
]
for label in labels
}
)
for field in self.fields
}
)
self._solutions = AdjointSolutionData(self.time_partition, self.function_spaces)

@PETSc.Log.EventDecorator()
def solve_adjoint(
Expand All @@ -212,13 +193,14 @@ def solve_adjoint(
computed, the contents of which give values at all exported timesteps, indexed
first by the field label and then by type. The contents of these nested
dictionaries are lists which are indexed first by subinterval and then by
export. For a given exported timestep, the solution types are:
export. For a given exported timestep, the field types are:
* ``'forward'``: the forward solution after taking the timestep;
* ``'forward_old'``: the forward solution before taking the timestep
* ``'forward_old'``: the forward solution before taking the timestep (provided
the problem is not steady-state)
* ``'adjoint'``: the adjoint solution after taking the timestep;
* ``'adjoint_next'``: the adjoint solution before taking the timestep
(backwards).
backwards (provided the problem is not steady-state).
:kwarg solver_kwargs: a dictionary providing parameters to the solver. Any
keyword arguments for the QoI should be included as a subdict with label
Expand Down Expand Up @@ -366,7 +348,7 @@ def wrapped_solver(subinterval, ic, **kwargs):

# Lagged forward solution comes from dependencies
dep = self._dependency(field, i, block)
if dep is not None:
if not self.steady and dep is not None:
sols.forward_old[i][j].assign(dep.saved_output)

# Adjoint action also comes from dependencies
Expand Down
136 changes: 136 additions & 0 deletions goalie/function_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
r"""
Nested dictionaries of solution data :class:`~.Function`\s.
"""
import firedrake.function as ffunc
import firedrake.functionspace as ffs
from .utility import AttrDict
import abc

__all__ = [
"ForwardSolutionData",
"AdjointSolutionData",
"IndicatorData",
]


class FunctionData(abc.ABC):
"""
Abstract base class for classes holding field data.
"""

labels = {}

def __init__(self, time_partition, function_spaces):
r"""
:arg time_partition: the :class:`~.TimePartition` used to discretise the problem
in time
:arg function_spaces: the dictionary of :class:`~.FunctionSpace`\s used to
discretise the problem in space
"""
self.time_partition = time_partition
self.function_spaces = function_spaces
self._data = None

def _create_data(self):
assert self.labels
P = self.time_partition
self._data = AttrDict(
{
field: AttrDict(
{
label: [
[
ffunc.Function(fs, name=f"{field}_{label}")
for j in range(P.num_exports_per_subinterval[i] - 1)
]
for i, fs in enumerate(self.function_spaces[field])
]
for label in self.labels[field_type]
}
)
for field, field_type in zip(P.fields, P.field_types)
}
)

@property
def data(self):
if self._data is None:
self._create_data()
return self._data

def __getitem__(self, key):
return self.data[key]

def items(self):
return self.data.items()


class SolutionData(FunctionData, abc.ABC):
"""
Abstract base class that defines the API for solution data classes.
"""

@property
def solutions(self):
return self.data


class ForwardSolutionData(SolutionData):
"""
Class representing solution data for general forward problems.
"""

def __init__(self, *args, **kwargs):
self.labels = {"steady": ("forward",), "unsteady": ("forward", "forward_old")}
super().__init__(*args, **kwargs)


class AdjointSolutionData(SolutionData):
"""
Class representing solution data for general adjoint problems.
"""

def __init__(self, *args, **kwargs):
self.labels = {
"steady": ("forward", "adjoint"),
"unsteady": ("forward", "forward_old", "adjoint", "adjoint_next"),
}
super().__init__(*args, **kwargs)


class IndicatorData(FunctionData):
"""
Class representing error indicator data.
Note that this class has a single dictionary with the field name as the key, rather
than a doubly-nested dictionary.
"""

def __init__(self, time_partition, meshes):
"""
:arg time_partition: the :class:`~.TimePartition` used to discretise the problem
in time
:arg meshes: the list of meshes used to discretise the problem in space
"""
self.labels = {
field_type: ("error_indicator",) for field_type in ("steady", "unsteady")
}
P0_spaces = [ffs.FunctionSpace(mesh, "DG", 0) for mesh in meshes]
super().__init__(
time_partition, {key: P0_spaces for key in time_partition.fields}
)

def _create_data(self):
assert all(len(labels) == 1 for labels in self.labels.values())
super()._create_data()
P = self.time_partition
self._data = AttrDict(
{
field: self.data[field][self.labels[field_type][0]]
for field, field_type in zip(P.fields, P.field_types)
}
)

@property
def indicators(self):
return self.data
19 changes: 3 additions & 16 deletions goalie/go_mesh_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .adjoint import AdjointMeshSeq
from .error_estimation import get_dwr_indicator
from .function_data import IndicatorData
from .log import pyrint
from .utility import AttrDict
from firedrake import Function, FunctionSpace, MeshHierarchy, TransferManager, project
Expand Down Expand Up @@ -88,21 +89,7 @@ def _get_transfer_function(enrichment_method):
return lambda source, target: target.interpolate(source)

def _create_indicators(self):
P0_spaces = [FunctionSpace(mesh, "DG", 0) for mesh in self]
self._indicators = AttrDict(
{
field: [
[
Function(fs, name=f"{field}_error_indicator")
for _ in range(
self.time_partition.num_exports_per_subinterval[i] - 1
)
]
for i, fs in enumerate(P0_spaces)
]
for field in self.fields
}
)
self._indicators = IndicatorData(self.time_partition, self.meshes)

@property
def indicators(self):
Expand Down Expand Up @@ -203,7 +190,7 @@ def error_estimate(self, absolute_value: bool = False) -> float:
:kwarg absolute_value: toggle whether to take the modulus on each element
"""
assert isinstance(self.indicators, dict)
assert isinstance(self.indicators, IndicatorData)
if not isinstance(absolute_value, bool):
raise TypeError(
f"Expected 'absolute_value' to be a bool, not '{type(absolute_value)}'."
Expand Down
28 changes: 6 additions & 22 deletions goalie/mesh_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from firedrake.adjoint_utils.solving import get_solve_blocks
from firedrake.petsc import PETSc
from firedrake.pyplot import triplot
from .function_data import ForwardSolutionData
from .interpolation import project
from .log import pyrint, debug, warning, info, logger, DEBUG
from .options import AdaptParameters
Expand Down Expand Up @@ -489,25 +490,7 @@ def _dependency(self, field, subinterval, solve_block):
)

def _create_solutions(self):
P = self.time_partition
labels = ("forward", "forward_old")
self._solutions = AttrDict(
{
field: AttrDict(
{
label: [
[
firedrake.Function(fs, name=f"{field}_{label}")
for j in range(P.num_exports_per_subinterval[i] - 1)
]
for i, fs in enumerate(self.function_spaces[field])
]
for label in labels
}
)
for field in self.fields
}
)
self._solutions = ForwardSolutionData(self.time_partition, self.function_spaces)

@property
def solutions(self):
Expand All @@ -527,10 +510,11 @@ def solve_forward(self, solver_kwargs: dict = {}) -> AttrDict:
at all exported timesteps, indexed first by the field label and then by type.
The contents of these nested dictionaries are nested lists which are indexed
first by subinterval and then by export. For a given exported timestep, the
solution types are:
field types are:
* ``'forward'``: the forward solution after taking the timestep;
* ``'forward_old'``: the forward solution before taking the timestep.
* ``'forward_old'``: the forward solution before taking the timestep (provided
the problem is not steady-state).
:kwarg solver_kwargs: a dictionary providing parameters to the solver. Any
keyword arguments for the QoI should be included as a subdict with label
Expand Down Expand Up @@ -594,7 +578,7 @@ def solve_forward(self, solver_kwargs: dict = {}) -> AttrDict:

# Lagged solution comes from dependencies
dep = self._dependency(field, i, block)
if dep is not None:
if not self.steady and dep is not None:
sols.forward_old[i][j].assign(dep.saved_output)

# Transfer the checkpoint between subintervals
Expand Down
22 changes: 13 additions & 9 deletions test/test_error_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
form2indicator,
get_dwr_indicator,
)
from goalie.function_data import IndicatorData
from goalie.go_mesh_seq import GoalOrientedMeshSeq
from goalie.time_partition import TimeInstant, TimePartition
from parameterized import parameterized
Expand Down Expand Up @@ -68,47 +69,47 @@ def mesh_seq(self, time_partition=None):
)

def test_time_partition_wrong_field_error(self):
mesh_seq = self.mesh_seq()
mesh_seq._indicators = {"f": [[self.one]]}
mesh_seq = self.mesh_seq(TimeInstant("field"))
time_partition = TimeInstant("f")
mesh_seq._indicators = IndicatorData(time_partition, mesh_seq.meshes)
with self.assertRaises(ValueError) as cm:
mesh_seq.error_estimate()
msg = "Key 'f' does not exist in the TimePartition provided."
self.assertEqual(str(cm.exception), msg)

def test_absolute_value_type_error(self):
mesh_seq = self.mesh_seq()
mesh_seq._indicators = {"field": [[self.one]]}
with self.assertRaises(TypeError) as cm:
mesh_seq.error_estimate(absolute_value=0)
msg = "Expected 'absolute_value' to be a bool, not '<class 'int'>'."
self.assertEqual(str(cm.exception), msg)

def test_unit_time_instant(self):
mesh_seq = self.mesh_seq(time_partition=TimeInstant("field", time=1.0))
mesh_seq._indicators = {"field": [[form2indicator(self.one * dx)]]}
mesh_seq.indicators["field"][0][0].assign(form2indicator(self.one * dx))
estimator = mesh_seq.error_estimate()
self.assertAlmostEqual(estimator, 1) # 1 * (0.5 + 0.5)

@parameterized.expand([[False], [True]])
def test_unit_time_instant_abs(self, absolute_value):
mesh_seq = self.mesh_seq(time_partition=TimeInstant("field", time=1.0))
mesh_seq._indicators = {"field": [[form2indicator(-self.one * dx)]]}
mesh_seq.indicators["field"][0][0].assign(form2indicator(-self.one * dx))
estimator = mesh_seq.error_estimate(absolute_value=absolute_value)
self.assertAlmostEqual(
estimator, 1 if absolute_value else -1
) # (-)1 * (0.5 + 0.5)

def test_half_time_instant(self):
mesh_seq = self.mesh_seq(time_partition=TimeInstant("field", time=0.5))
mesh_seq._indicators = {"field": [[form2indicator(self.one * dx)]]}
mesh_seq.indicators["field"][0][0].assign(form2indicator(self.one * dx))
estimator = mesh_seq.error_estimate()
self.assertAlmostEqual(estimator, 0.5) # 0.5 * (0.5 + 0.5)

def test_time_partition_same_timestep(self):
mesh_seq = self.mesh_seq(
time_partition=TimePartition(1.0, 2, [0.5, 0.5], ["field"])
)
mesh_seq._indicators = {"field": [[form2indicator(2 * self.one * dx)]]}
mesh_seq.indicators["field"][0][0].assign(form2indicator(2 * self.one * dx))
estimator = mesh_seq.error_estimate()
self.assertAlmostEqual(estimator, 1) # 2 * 0.5 * (0.5 + 0.5)

Expand All @@ -117,7 +118,9 @@ def test_time_partition_different_timesteps(self):
time_partition=TimePartition(1.0, 2, [0.5, 0.25], ["field"])
)
indicator = form2indicator(self.one * dx)
mesh_seq._indicators = {"field": [[indicator], 2 * [indicator]]}
mesh_seq.indicators["field"][0][0].assign(indicator)
mesh_seq.indicators["field"][1][0].assign(indicator)
mesh_seq.indicators["field"][1][1].assign(indicator)
estimator = mesh_seq.error_estimate()
self.assertAlmostEqual(
estimator, 1
Expand All @@ -128,7 +131,8 @@ def test_time_instant_multiple_fields(self):
time_partition=TimeInstant(["field1", "field2"], time=1.0)
)
indicator = form2indicator(self.one * dx)
mesh_seq._indicators = {"field1": [[indicator]], "field2": [[indicator]]}
mesh_seq.indicators["field1"][0][0].assign(indicator)
mesh_seq.indicators["field2"][0][0].assign(indicator)
estimator = mesh_seq.error_estimate()
self.assertAlmostEqual(estimator, 2) # 2 * (1 * (0.5 + 0.5))

Expand Down

0 comments on commit 4d141be

Please sign in to comment.