Skip to content

Commit

Permalink
#220 modify _fs directly, unit test for TimePartition drop_last_subin…
Browse files Browse the repository at this point in the history
…terval
  • Loading branch information
acse-ej321 committed Nov 9, 2024
1 parent 4c0bcb6 commit 1b35879
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 11 deletions.
16 changes: 14 additions & 2 deletions goalie/go_mesh_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
Drivers for goal-oriented error estimation on sequences of meshes.
"""

import os
from collections.abc import Iterable
from copy import deepcopy

import numpy as np
import psutil
import ufl
from animate.interpolation import interpolate
from animate.utility import Mesh
Expand Down Expand Up @@ -168,6 +170,14 @@ def indicate_errors(
FWD_OLD = "forward" if self.steady else "forward_old"
ADJ_NEXT = "adjoint" if self.steady else "adjoint_next"
P0_spaces = [FunctionSpace(mesh, "DG", 0) for mesh in self]

process = psutil.Process(os.getpid()) # ej321

def mem(i):
print(
f"Memory usage {i}: {process.memory_full_info().uss / 1024**2:.0f} MB"
) # ej321

# Loop over each subinterval in reverse
for i in reversed(range(len(self))):
# Solve the adjoint problem on the current subinterval
Expand Down Expand Up @@ -242,14 +252,16 @@ def indicate_errors(
enriched_mesh_seq.solutions[f][ADJ_NEXT].pop(-1)

# delete current subinterval enriched mesh to reduce the memory footprint
for f in self.fields:
enriched_mesh_seq._fs[f].pop(-1)

enriched_mesh_seq.meshes.pop(-1)
enriched_mesh_seq.time_partition.drop_last_subinterval()
enriched_mesh_seq._update_function_spaces()

# clear empty labels
for f in self.fields:
if self.steady:
self.solutions.labels = "forward"
self.solutions.labels = ("forward",)
self.solutions[f].pop("forward_old", None)
else:
self.solutions.labels = ("forward", "adjoint")
Expand Down
20 changes: 11 additions & 9 deletions goalie/time_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,16 +294,18 @@ def __ne__(self, other):

def drop_last_subinterval(self):
"""
Drop the last subinterval and reset lists and counters appropriately.
Drop the last of multiple subinterval and reset lists and counters appropriately,
with a single subinterval always maintained for a valid :class:`~.TimePartition`.
"""
self.end_time = self.subintervals[-1][0]
self.interval = (self.start_time, self.end_time)
self.num_subintervals -= 1
self.timesteps.pop(-1)
self.subintervals.pop(-1)
self.num_timesteps_per_subinterval.pop(-1)
self.num_timesteps_per_export.pop(-1)
self.num_exports_per_subinterval.pop(-1)
if self.num_subintervals > 1:
self.end_time = self.subintervals[-1][0]
self.interval = (self.start_time, self.end_time)
self.num_subintervals -= 1
self.timesteps.pop(-1)
self.subintervals.pop(-1)
self.num_timesteps_per_subinterval.pop(-1)
self.num_timesteps_per_export.pop(-1)
self.num_exports_per_subinterval.pop(-1)


class TimeInterval(TimePartition):
Expand Down
65 changes: 65 additions & 0 deletions test/test_time_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,71 @@ def test_num_field_types_error(self):
self.assertEqual(str(cm.exception), msg)


class TestDropLastSubinterval(unittest.TestCase):
r"""
Tests related to :method:\~.drop_last_subinterval` of :class:`~.TimePartition`
"""

def setUp(self):
self.end_time = 1.0
self.field_names = ["field"]

def get_time_partition(self, n):
split = self.end_time / n
timesteps = [split if i % 2 else split / 2 for i in range(n)]
return TimePartition(self.end_time, n, timesteps, self.field_names)

def test_drop_last_subinterval(self):
n = 5
multi_subintervals = self.get_time_partition(n - 1)
expected_end_time = [1.0, 0.75, 0.5, 0.25, 0.25]
expected_interval = [
(0.0, 1.0),
(0.0, 0.75),
(0.0, 0.5),
(0.0, 0.25),
(0.0, 0.25),
]
expected_num_subintervals = [4, 3, 2, 1, 1]
expected_timesteps = [
[0.125, 0.25, 0.125, 0.25],
[0.125, 0.25, 0.125],
[0.125, 0.25],
[0.125],
[0.125],
]
expected_subintervals = [
[(0.0, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)],
[(0.0, 0.25), (0.25, 0.5), (0.5, 0.75)],
[(0.0, 0.25), (0.25, 0.5)],
[(0.0, 0.25)],
[(0.0, 0.25)],
]
expected_n_time_per_sub = [[2, 1, 2, 1], [2, 1, 2], [2, 1], [2], [2]]
expected_n_time_per_xpt = [[1, 1, 1, 1], [1, 1, 1], [1, 1], [1], [1]]
expected_n_xpt_per_sub = [[3, 2, 3, 2], [3, 2, 3], [3, 2], [3], [3]]
for i in range(n):
self.assertEqual(multi_subintervals.end_time, expected_end_time[i])
self.assertEqual(multi_subintervals.interval, expected_interval[i])
self.assertEqual(
multi_subintervals.num_subintervals, expected_num_subintervals[i]
)
self.assertEqual(multi_subintervals.timesteps, expected_timesteps[i])
self.assertEqual(multi_subintervals.subintervals, expected_subintervals[i])
self.assertEqual(
multi_subintervals.num_timesteps_per_subinterval,
expected_n_time_per_sub[i],
)
self.assertEqual(
multi_subintervals.num_timesteps_per_export, expected_n_time_per_xpt[i]
)
self.assertEqual(
multi_subintervals.num_exports_per_subinterval,
expected_n_xpt_per_sub[i],
)
multi_subintervals.drop_last_subinterval()


class TestStringFormatting(unittest.TestCase):
"""
Test that the :meth:`__str__`` and :meth:`__repr__`` methods work as intended for
Expand Down

0 comments on commit 1b35879

Please sign in to comment.