Skip to content

Commit

Permalink
Merge branch 'main' into 265_drop-ignores
Browse files Browse the repository at this point in the history
  • Loading branch information
jwallwork23 authored Dec 24, 2024
2 parents 4c42130 + 3f50662 commit 64f8939
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 10 deletions.
22 changes: 19 additions & 3 deletions goalie/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ def _solve_adjoint(
adj_solver_kwargs=None,
get_adj_values=False,
test_checkpoint_qoi=False,
track_coefficients=False,
):
"""
A generator for solving an adjoint problem on a sequence of subintervals.
Expand All @@ -404,6 +405,10 @@ def _solve_adjoint(
:type get_adj_values: :class:`bool`
:kwarg test_checkpoint_qoi: solve over the final subinterval when checkpointing
so that the QoI value can be checked across runs
:kwarg: track_coefficients: if ``True``, coefficients in the variational form
will be stored whenever they change between export times. Only relevant for
goal-oriented error estimation on unsteady problems.
:type track_coefficients: :class:`bool`
:yields: the solution data of the forward and adjoint solves
:ytype: :class:`~.AdjointSolutionData`
"""
Expand Down Expand Up @@ -486,9 +491,20 @@ def wrapped_solver(subinterval, initial_condition_map, **kwargs):
# Initialise the solver generator
solver_gen = wrapped_solver(i, checkpoints[i], **solver_kwargs)

# Annotate tape on current subinterval
for _ in range(tp.num_timesteps_per_subinterval[i]):
next(solver_gen)
# Annotate tape on current subinterval.
# If we are using a goal-oriented approach on an unsteady problem, we need
# to keep track of the coefficients in the variational form to detect their
# changes between export times. In that case, we solve the forward problem
# sequentially between each export time and save changing coefficients.
# Otherwise, solve over the entire subinterval in one go.
if track_coefficients:
for j in range(tp.num_exports_per_subinterval[i] - 1):
for _ in range(tp.num_timesteps_per_export[i]):
next(solver_gen)
self._detect_changing_coefficients(j)
else:
for _ in range(tp.num_timesteps_per_subinterval[i]):
next(solver_gen)
pyadjoint.pause_annotation()

# Final solution is used as the initial condition for the next subinterval
Expand Down
63 changes: 62 additions & 1 deletion goalie/go_mesh_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from collections.abc import Iterable
from copy import deepcopy

import numpy as np
import ufl
Expand All @@ -28,6 +29,8 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.estimator_values = []
self._forms = None
self._prev_form_coeffs = None
self._changed_form_coeffs = None

def read_forms(self, forms_dictionary):
"""
Expand Down Expand Up @@ -65,6 +68,52 @@ def forms(self):
)
return self._forms

@PETSc.Log.EventDecorator()
def _detect_changing_coefficients(self, export_idx):
"""
Detect whether coefficients other than the solution in the variational forms
change over time. If they do, store the changed coefficients so we can update
them in :meth:`~.GoalOrientedMeshSeq.indicate_errors`.
Changed coefficients are stored in a dictionary with the following structure:
``{field: {coeff_idx: {export_timestep_idx: coefficient}}}``, where
``coefficient=forms[field].coefficients()[coeff_idx]`` at export timestep
``export_timestep_idx``.
:arg export_idx: index of the current export timestep within the subinterval
:type export_idx: :class:`int`
"""
if export_idx == 0:
# Copy coefficients at subinterval's first export timestep
self._prev_form_coeffs = {
field: deepcopy(form.coefficients())
for field, form in self.forms.items()
}
self._changed_form_coeffs = {field: {} for field in self.fields}
else:
# Store coefficients that have changed since the previous export timestep
for field in self.fields:
# Coefficients at the current timestep
coeffs = self.forms[field].coefficients()
for coeff_idx, (coeff, init_coeff) in enumerate(
zip(coeffs, self._prev_form_coeffs[field])
):
# Skip solution fields since they are stored separately
if coeff.name().split("_old")[0] in self.time_partition.field_names:
continue
if not np.allclose(
coeff.vector().array(), init_coeff.vector().array()
):
if coeff_idx not in self._changed_form_coeffs[field]:
self._changed_form_coeffs[field][coeff_idx] = {
0: deepcopy(init_coeff)
}
self._changed_form_coeffs[field][coeff_idx][export_idx] = (
deepcopy(coeff)
)
# Use the current coeff for comparison in the next timestep
init_coeff.assign(coeff)

@PETSc.Log.EventDecorator()
def get_enriched_mesh_seq(self, enrichment_method="p", num_enrichments=1):
"""
Expand Down Expand Up @@ -193,7 +242,11 @@ def indicate_errors(

# Initialise adjoint solver generators on the MeshSeq and its enriched version
adj_sol_gen = self._solve_adjoint(**solver_kwargs)
adj_sol_gen_enriched = enriched_mesh_seq._solve_adjoint(**solver_kwargs)
# Track form coefficient changes in the enriched problem if the problem is unsteady
adj_sol_gen_enriched = enriched_mesh_seq._solve_adjoint(
track_coefficients=not self.steady,
**solver_kwargs,
)

FWD, ADJ = "forward", "adjoint"
FWD_OLD = "forward" if self.steady else "forward_old"
Expand Down Expand Up @@ -249,6 +302,14 @@ def indicate_errors(
)
u_star_e[f] -= u_star[f]

# Update other time-dependent form coefficients if they changed
# since the previous export timestep
emseq = enriched_mesh_seq
if not self.steady and emseq._changed_form_coeffs[f]:
for idx, coeffs in emseq._changed_form_coeffs[f].items():
if j in coeffs:
emseq.forms[f].coefficients()[idx].assign(coeffs[j])

# Evaluate error indicator
indi_e = indicator_fn(enriched_mesh_seq.forms[f], u_star_e[f])

Expand Down
82 changes: 82 additions & 0 deletions test/adjoint/test_adjoint_mesh_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
UnitTriangleMesh,
VectorFunctionSpace,
dx,
inner,
solve,
)
from parameterized import parameterized
Expand Down Expand Up @@ -531,3 +532,84 @@ def test_enrichment_transfer(
target = Function(mesh_seq_e.function_spaces["field"][0])
transfer(source, target)
self.assertAlmostEqual(norm(source), norm(target))


class GoalOrientedBaseClass(unittest.TestCase):
"""
Base class for tests with a complete :class:`GoalOrientedMeshSeq`.
"""

def setUp(self):
self.field = "field"
self.time_partition = TimePartition(1.0, 1, 0.5, [self.field])
self.meshes = [UnitSquareMesh(1, 1)]

def go_mesh_seq(self, coeff_diff=0.0):
def get_function_spaces(mesh):
return {self.field: FunctionSpace(mesh, "R", 0)}

def get_initial_condition(mesh_seq):
return {self.field: Function(mesh_seq.function_spaces[self.field][0])}

def get_solver(mesh_seq):
def solver(index):
tp = mesh_seq.time_partition
R = FunctionSpace(mesh_seq[index], "R", 0)
dt = Function(R).assign(tp.timesteps[index])

u, u_ = mesh_seq.fields[self.field]
f = Function(R).assign(1.0001)
v = TestFunction(u.function_space())
F = (u - u_) / dt * v * dx - f * v * dx
mesh_seq.read_forms({self.field: F})

for _ in range(tp.num_timesteps_per_subinterval[index]):
solve(F == 0, u, ad_block_tag=self.field)
yield

u_.assign(u)
f += coeff_diff

return solver

def get_qoi(mesh_seq, i):
def end_time_qoi():
u = mesh_seq.fields[self.field][0]
return inner(u, u) * dx

return end_time_qoi

return GoalOrientedMeshSeq(
self.time_partition,
self.meshes,
get_initial_condition=get_initial_condition,
get_function_spaces=get_function_spaces,
get_solver=get_solver,
get_qoi=get_qoi,
qoi_type="end_time",
)


class TestDetectChangedCoefficients(GoalOrientedBaseClass):
"""
Unit tests for detecting changed coefficients using
:meth:`GoalOrientedMeshSeq._detect_changing_coefficients`.
"""

def test_constant_coefficients(self):
mesh_seq = self.go_mesh_seq()
# Solve over the first (only) subinterval
next(mesh_seq._solve_adjoint(track_coefficients=True))
# Check no coefficients have changed
self.assertEqual(mesh_seq._changed_form_coeffs, {self.field: {}})

def test_changed_coefficients(self):
# Change coefficient f by coeff_diff every timestep
coeff_diff = 1.1
mesh_seq = self.go_mesh_seq(coeff_diff=coeff_diff)
# Solve over the first (only) subinterval
next(mesh_seq._solve_adjoint(track_coefficients=True))
changed_coeffs_dict = mesh_seq._changed_form_coeffs[self.field]
coeff_idx = next(iter(changed_coeffs_dict))
for export_idx, f in changed_coeffs_dict[coeff_idx].items():
self.assertTrue(f.vector().gather() == [1.0001 + export_idx * coeff_diff])
37 changes: 31 additions & 6 deletions test/adjoint/test_fp_iteration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
import unittest
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import numpy as np
import ufl
Expand Down Expand Up @@ -106,7 +106,12 @@ def test_noconvergence(self):
def test_no_late_convergence(self):
self.parameters.drop_out_converged = True
mesh_seq = self.mesh_seq(time_partition=TimePartition(1.0, 2, [0.5, 0.5], []))
mesh_seq.fixed_point_iteration(oscillating_adaptor0, parameters=self.parameters)
with patch("goalie.go_mesh_seq.GoalOrientedMeshSeq.forms") as mock_forms:
mock_forms.return_value = MagicMock()
mesh_seq.fixed_point_iteration(
oscillating_adaptor0,
parameters=self.parameters,
)
expected = [[1, 1], [2, 1], [1, 1], [2, 1], [1, 1], [2, 1]]
self.assertEqual(mesh_seq.element_counts, expected)
self.assertTrue(np.allclose(mesh_seq.converged, [False, False]))
Expand All @@ -116,7 +121,12 @@ def test_no_late_convergence(self):
def test_dropout(self, drop_out_converged):
self.parameters.drop_out_converged = drop_out_converged
mesh_seq = self.mesh_seq(time_partition=TimePartition(1.0, 2, [0.5, 0.5], []))
mesh_seq.fixed_point_iteration(oscillating_adaptor1, parameters=self.parameters)
with patch("goalie.go_mesh_seq.GoalOrientedMeshSeq.forms") as mock_forms:
mock_forms.return_value = MagicMock()
mesh_seq.fixed_point_iteration(
oscillating_adaptor1,
parameters=self.parameters,
)
expected = [[1, 1], [1, 2], [1, 1], [1, 2], [1, 1], [1, 2]]
self.assertEqual(mesh_seq.element_counts, expected)
self.assertTrue(np.allclose(mesh_seq.converged, [True, False]))
Expand Down Expand Up @@ -246,7 +256,12 @@ def check_convergence(self, mesh_seq):
def test_convergence_criteria_all_false(self):
self.parameters.convergence_criteria = "all"
mesh_seq = self.mesh_seq(time_partition=TimePartition(1.0, 1, 0.5, []))
mesh_seq.fixed_point_iteration(empty_adaptor, parameters=self.parameters)
with patch("goalie.go_mesh_seq.GoalOrientedMeshSeq.forms") as mock_forms:
mock_forms.return_value = MagicMock()
mesh_seq.fixed_point_iteration(
empty_adaptor,
parameters=self.parameters,
)
self.assertTrue(np.allclose(mesh_seq.element_counts, 1))
self.assertTrue(np.allclose(mesh_seq.converged, False))
self.assertTrue(np.allclose(mesh_seq.check_convergence, True))
Expand All @@ -258,7 +273,12 @@ def test_convergence_criteria_all_true(self):
get_qoi=constant_qoi,
)
mesh_seq.error_estimate = MagicMock(return_value=1)
mesh_seq.fixed_point_iteration(empty_adaptor, parameters=self.parameters)
with patch("goalie.go_mesh_seq.GoalOrientedMeshSeq.forms") as mock_forms:
mock_forms.return_value = MagicMock()
mesh_seq.fixed_point_iteration(
empty_adaptor,
parameters=self.parameters,
)
self.assertTrue(np.allclose(mesh_seq.element_counts, 1))
self.assertTrue(np.allclose(mesh_seq.converged, True))
self.assertTrue(np.allclose(mesh_seq.check_convergence, True))
Expand All @@ -272,5 +292,10 @@ def test_convergence_criteria_any(self, element, qoi, estimator):
mesh_seq.check_element_count_convergence = MagicMock(return_value=element)
mesh_seq.check_qoi_convergence = MagicMock(return_value=qoi)
mesh_seq.check_estimator_convergence = MagicMock(return_value=estimator)
mesh_seq.fixed_point_iteration(empty_adaptor, parameters=self.parameters)
with patch("goalie.go_mesh_seq.GoalOrientedMeshSeq.forms") as mock_forms:
mock_forms.return_value = MagicMock()
mesh_seq.fixed_point_iteration(
empty_adaptor,
parameters=self.parameters,
)
self.assertTrue(np.allclose(mesh_seq.check_convergence, True))

0 comments on commit 64f8939

Please sign in to comment.