Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clear fields and meshes after indicate error calculated per segment #224

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
35 changes: 33 additions & 2 deletions goalie/go_mesh_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
"""

from collections.abc import Iterable
from copy import deepcopy

import numpy as np
import ufl
from animate.interpolation import interpolate
from animate.utility import Mesh
from firedrake import Function, FunctionSpace, MeshHierarchy, TransferManager
from firedrake.petsc import PETSc

Expand Down Expand Up @@ -59,11 +61,14 @@ def get_enriched_mesh_seq(self, enrichment_method="p", num_enrichments=1):
)
meshes = [MeshHierarchy(mesh, num_enrichments)[-1] for mesh in self.meshes]
else:
meshes = self.meshes
meshes = [Mesh(mesh) for mesh in self.meshes]

# Create copy of time_partition
time_partition = deepcopy(self.time_partition)

# Construct object to hold enriched spaces
enriched_mesh_seq = type(self)(
self.time_partition,
time_partition,
meshes,
get_function_spaces=self._get_function_spaces,
get_initial_condition=self._get_initial_condition,
Expand Down Expand Up @@ -163,6 +168,7 @@ def indicate_errors(
FWD_OLD = "forward" if self.steady else "forward_old"
ADJ_NEXT = "adjoint" if self.steady else "adjoint_next"
P0_spaces = [FunctionSpace(mesh, "DG", 0) for mesh in self]

# Loop over each subinterval in reverse
for i in reversed(range(len(self))):
# Solve the adjoint problem on the current subinterval
Expand Down Expand Up @@ -228,6 +234,31 @@ def indicate_errors(
indi.interpolate(abs(indi))
self.indicators[f][i][j].interpolate(ufl.max_value(indi, 1.0e-16))

# discard current subinterval duplicate solution fields
if not self.steady:
for f in self.fields:
self.solutions[f][FWD_OLD].pop(-1)
self.solutions[f][ADJ_NEXT].pop(-1)
enriched_mesh_seq.solutions[f][FWD_OLD].pop(-1)
enriched_mesh_seq.solutions[f][ADJ_NEXT].pop(-1)

# delete current subinterval enriched mesh to reduce the memory footprint
if len(enriched_mesh_seq.meshes) > 1:
for f in self.fields:
enriched_mesh_seq._fs[f].pop(-1)
enriched_mesh_seq.meshes.pop(-1)
enriched_mesh_seq.time_partition.drop_last_subinterval()

# clear empty labels
for f in self.fields:
if self.steady:
self.solutions.labels = ("forward",)
self.solutions[f].pop("forward_old", None)
else:
self.solutions.labels = ("forward", "adjoint")
self.solutions[f].pop("forward_old", None)
self.solutions[f].pop("adjoint_next", None)

Comment on lines +237 to +261
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indicate_errors is already quite a bloated method so it'd be nice if these were put in separate method(s), similarly to the one you introduced for TimePartition.

return self.solutions, self.indicators

@PETSc.Log.EventDecorator()
Expand Down
15 changes: 15 additions & 0 deletions goalie/time_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,21 @@ def __ne__(self, other):
or not self.field_types == other.field_types
)

def drop_last_subinterval(self):
"""
Drop the last of multiple subinterval and reset lists and counters appropriately,
with a single subinterval always maintained for a valid :class:`~.TimePartition`.
"""
if self.num_subintervals > 1:
self.end_time = self.subintervals[-1][0]
self.interval = (self.start_time, self.end_time)
self.num_subintervals -= 1
self.timesteps.pop(-1)
self.subintervals.pop(-1)
self.num_timesteps_per_subinterval.pop(-1)
self.num_timesteps_per_export.pop(-1)
self.num_exports_per_subinterval.pop(-1)


class TimeInterval(TimePartition):
"""
Expand Down
65 changes: 65 additions & 0 deletions test/test_time_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,71 @@ def test_num_field_types_error(self):
self.assertEqual(str(cm.exception), msg)


class TestDropLastSubinterval(unittest.TestCase):
r"""
Tests related to :method:\~.drop_last_subinterval` of :class:`~.TimePartition`
"""

def setUp(self):
self.end_time = 1.0
self.field_names = ["field"]

def get_time_partition(self, n):
split = self.end_time / n
timesteps = [split if i % 2 else split / 2 for i in range(n)]
return TimePartition(self.end_time, n, timesteps, self.field_names)

def test_drop_last_subinterval(self):
n = 5
multi_subintervals = self.get_time_partition(n - 1)
expected_end_time = [1.0, 0.75, 0.5, 0.25, 0.25]
expected_interval = [
(0.0, 1.0),
(0.0, 0.75),
(0.0, 0.5),
(0.0, 0.25),
(0.0, 0.25),
]
expected_num_subintervals = [4, 3, 2, 1, 1]
expected_timesteps = [
[0.125, 0.25, 0.125, 0.25],
[0.125, 0.25, 0.125],
[0.125, 0.25],
[0.125],
[0.125],
]
expected_subintervals = [
[(0.0, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)],
[(0.0, 0.25), (0.25, 0.5), (0.5, 0.75)],
[(0.0, 0.25), (0.25, 0.5)],
[(0.0, 0.25)],
[(0.0, 0.25)],
]
expected_n_time_per_sub = [[2, 1, 2, 1], [2, 1, 2], [2, 1], [2], [2]]
expected_n_time_per_xpt = [[1, 1, 1, 1], [1, 1, 1], [1, 1], [1], [1]]
expected_n_xpt_per_sub = [[3, 2, 3, 2], [3, 2, 3], [3, 2], [3], [3]]
for i in range(n):
self.assertEqual(multi_subintervals.end_time, expected_end_time[i])
self.assertEqual(multi_subintervals.interval, expected_interval[i])
self.assertEqual(
multi_subintervals.num_subintervals, expected_num_subintervals[i]
)
self.assertEqual(multi_subintervals.timesteps, expected_timesteps[i])
self.assertEqual(multi_subintervals.subintervals, expected_subintervals[i])
self.assertEqual(
multi_subintervals.num_timesteps_per_subinterval,
expected_n_time_per_sub[i],
)
self.assertEqual(
multi_subintervals.num_timesteps_per_export, expected_n_time_per_xpt[i]
)
self.assertEqual(
multi_subintervals.num_exports_per_subinterval,
expected_n_xpt_per_sub[i],
)
multi_subintervals.drop_last_subinterval()


class TestStringFormatting(unittest.TestCase):
"""
Test that the :meth:`__str__`` and :meth:`__repr__`` methods work as intended for
Expand Down
Loading