diff --git a/.github/workflows/test_suite.yml b/.github/workflows/test_suite.yml index a383d378..8b5265cd 100644 --- a/.github/workflows/test_suite.yml +++ b/.github/workflows/test_suite.yml @@ -22,10 +22,9 @@ jobs: python $(which firedrake-clean) export GITHUB_ACTIONS_TEST_RUN=1 python -m coverage erase - python -m coverage run -a --source=goalie -m pytest -v --durations=20 test - python -m coverage run -a --source=goalie -m pytest -v --durations=10 test_adjoint + python -m coverage run --source=goalie -m pytest -v --durations=20 test python -m coverage report changed-files-patterns: | **/*.py **/*.msh - **/*.geo \ No newline at end of file + **/*.geo diff --git a/Makefile b/Makefile index 2cf64768..b1955acb 100644 --- a/Makefile +++ b/Makefile @@ -18,14 +18,12 @@ lint: test: lint @echo "Running test suite..." @cd test && make - @cd test_adjoint && make @echo "PASS" coverage: @echo "Generating coverage report..." @python3 -m coverage erase - @python3 -m coverage run -a --source=goalie -m pytest -v test - @python3 -m coverage run -a --source=goalie -m pytest -v test_adjoint + @python3 -m coverage run --source=goalie -m pytest -v test @python3 -m coverage html demo: diff --git a/goalie/mesh_seq.py b/goalie/mesh_seq.py index 1a89eb46..9d9e275b 100644 --- a/goalie/mesh_seq.py +++ b/goalie/mesh_seq.py @@ -203,7 +203,7 @@ def plot(self, fig=None, axes=None, **kwargs): from matplotlib.pyplot import subplots if self.dim != 2: - raise ValueError("MeshSeq plotting only supported in 2D") + raise ValueError("MeshSeq plotting only supported in 2D.") # Process kwargs interior_kw = {"edgecolor": "k"} @@ -336,7 +336,7 @@ def _outputs_consistent(self): self.debug( "Current and lagged solutions are equal. Does the" " solver yield before updating lagged solutions?" - ) + ) # noqa break assert isinstance(method_map, dict), f"get_{method} should return a dict" mesh_seq_fields = set(self.fields) diff --git a/test/Makefile b/test/Makefile index c828879d..9b969384 100644 --- a/test/Makefile +++ b/test/Makefile @@ -1,8 +1,8 @@ all: run run: - @echo "Running all tests..." - @python3 -m pytest -v -n auto --durations=20 . + @echo "Running all non-adjoint tests..." + @python3 -m pytest -v -n auto --durations=20 test_*.py @echo "Done." clean: @@ -13,3 +13,4 @@ clean: @rm -rf *.jpg *.png @rm -rf outputs* @echo "Done." + @cd adjoint && make clean diff --git a/test_adjoint/Makefile b/test/adjoint/Makefile similarity index 100% rename from test_adjoint/Makefile rename to test/adjoint/Makefile diff --git a/test/adjoint/conftest.py b/test/adjoint/conftest.py new file mode 100644 index 00000000..92101a44 --- /dev/null +++ b/test/adjoint/conftest.py @@ -0,0 +1,51 @@ +""" +Global pytest configuration for adjoint tests. + +**Disclaimer: some functions copied from firedrake/src/tests/conftest.py +""" + +import pyadjoint +import pytest + + +@pytest.fixture(scope="module", autouse=True) +def check_empty_tape(request): + """ + Check that the tape is empty at the end of each module + + **Disclaimer: copied from firedrake/src/tests/conftest.py + """ + + def fin(): + tape = pyadjoint.get_working_tape() + if tape is not None: + assert len(tape.get_blocks()) == 0 + + request.addfinalizer(fin) + + +@pytest.fixture(autouse=True) +def handle_taping(): + """ + **Disclaimer: copied from + firedrake/tests/regression/test_adjoint_operators.py + """ + yield + tape = pyadjoint.get_working_tape() + tape.clear_tape() + + +@pytest.fixture(autouse=True, scope="module") +def handle_annotation(): + """ + Since importing firedrake-adjoint modifies a global variable, we need to + pause annotations at the end of the module. + + **Disclaimer: copied from + firedrake/tests/regression/test_adjoint_operators.py + """ + if not pyadjoint.annotate_tape(): + pyadjoint.continue_annotation() + yield + if pyadjoint.annotate_tape(): + pyadjoint.pause_annotation() diff --git a/test_adjoint/examples/burgers.py b/test/adjoint/examples/burgers.py similarity index 100% rename from test_adjoint/examples/burgers.py rename to test/adjoint/examples/burgers.py diff --git a/test_adjoint/examples/mesh-with-hole.geo b/test/adjoint/examples/mesh-with-hole.geo similarity index 100% rename from test_adjoint/examples/mesh-with-hole.geo rename to test/adjoint/examples/mesh-with-hole.geo diff --git a/test_adjoint/examples/mesh-with-hole.msh b/test/adjoint/examples/mesh-with-hole.msh similarity index 100% rename from test_adjoint/examples/mesh-with-hole.msh rename to test/adjoint/examples/mesh-with-hole.msh diff --git a/test_adjoint/examples/point_discharge2d.py b/test/adjoint/examples/point_discharge2d.py similarity index 100% rename from test_adjoint/examples/point_discharge2d.py rename to test/adjoint/examples/point_discharge2d.py diff --git a/test_adjoint/examples/point_discharge3d.py b/test/adjoint/examples/point_discharge3d.py similarity index 100% rename from test_adjoint/examples/point_discharge3d.py rename to test/adjoint/examples/point_discharge3d.py diff --git a/test_adjoint/examples/steady_flow_past_cyl.py b/test/adjoint/examples/steady_flow_past_cyl.py similarity index 100% rename from test_adjoint/examples/steady_flow_past_cyl.py rename to test/adjoint/examples/steady_flow_past_cyl.py diff --git a/test_adjoint/test_adjoint.py b/test/adjoint/test_adjoint.py similarity index 100% rename from test_adjoint/test_adjoint.py rename to test/adjoint/test_adjoint.py diff --git a/test_adjoint/test_mesh_seq.py b/test/adjoint/test_adjoint_mesh_seq.py similarity index 97% rename from test_adjoint/test_mesh_seq.py rename to test/adjoint/test_adjoint_mesh_seq.py index 58f74b85..9ad23659 100644 --- a/test_adjoint/test_mesh_seq.py +++ b/test/adjoint/test_adjoint_mesh_seq.py @@ -359,8 +359,7 @@ def test_enrichment_error(self): mesh_seq = self.go_mesh_seq(self.get_function_spaces_decorator("R", 0, 0)) with self.assertRaises(ValueError) as cm: mesh_seq.get_enriched_mesh_seq(enrichment_method="q") - msg = "Enrichment method 'q' not supported." - self.assertEqual(str(cm.exception), msg) + self.assertEqual(str(cm.exception), "Enrichment method 'q' not supported.") def test_num_enrichments_error(self): mesh_seq = self.go_mesh_seq(self.get_function_spaces_decorator("R", 0, 0)) @@ -369,6 +368,16 @@ def test_num_enrichments_error(self): msg = "A positive number of enrichments is required." self.assertEqual(str(cm.exception), msg) + def test_form_error(self): + mesh_seq = self.go_mesh_seq(self.get_function_spaces_decorator("R", 0, 0)) + with self.assertRaises(AttributeError) as cm: + mesh_seq.forms() + msg = ( + "Forms have not been read in. Use read_forms({'field_name': F}) in" + " get_solver to read in the forms." + ) + self.assertEqual(str(cm.exception), msg) + def test_h_enrichment_error(self): end_time = 1.0 num_subintervals = 2 diff --git a/test_adjoint/test_demos.py b/test/adjoint/test_demos.py similarity index 98% rename from test_adjoint/test_demos.py rename to test/adjoint/test_demos.py index b3e843d1..0a4236a9 100644 --- a/test_adjoint/test_demos.py +++ b/test/adjoint/test_demos.py @@ -13,7 +13,7 @@ from goalie.log import * cwd = os.path.abspath(os.path.dirname(__file__)) -demo_dir = os.path.abspath(os.path.join(cwd, "..", "demos")) +demo_dir = os.path.abspath(os.path.join(cwd, "..", "..", "demos")) all_demos = glob.glob(os.path.join(demo_dir, "*.py")) # Modifications dictionary to cut down run time of demos: diff --git a/test_adjoint/test_fp_iteration.py b/test/adjoint/test_fp_iteration.py similarity index 100% rename from test_adjoint/test_fp_iteration.py rename to test/adjoint/test_fp_iteration.py diff --git a/test_adjoint/test_utils.py b/test/adjoint/test_utils.py similarity index 100% rename from test_adjoint/test_utils.py rename to test/adjoint/test_utils.py diff --git a/test/test_mesh_seq.py b/test/test_mesh_seq.py index 7c2287e6..ccef4fdc 100644 --- a/test/test_mesh_seq.py +++ b/test/test_mesh_seq.py @@ -5,32 +5,51 @@ import re import unittest -from firedrake import Function, FunctionSpace, UnitCubeMesh, UnitSquareMesh +from firedrake import ( + Function, + FunctionSpace, + UnitCubeMesh, + UnitIntervalMesh, + UnitSquareMesh, +) from parameterized import parameterized from goalie.mesh_seq import MeshSeq from goalie.time_partition import TimeInterval, TimePartition -class TestGeneric(unittest.TestCase): +class BaseClasses: """ - Generic unit tests for :class:`MeshSeq`. + Base classes for mesh sequence unit testing. """ - def setUp(self): - self.time_partition = TimePartition(1.0, 2, [0.5, 0.5], ["field"]) - self.time_interval = TimeInterval(1.0, [0.5], ["field"]) + class MeshSeqTestCase(unittest.TestCase): + """ + Test case with a simple setUp method and mesh constructor. + """ - def test_setitem(self): - mesh1 = UnitSquareMesh(1, 1, diagonal="left") - mesh2 = UnitSquareMesh(1, 1, diagonal="right") - mesh_seq = MeshSeq(self.time_interval, [mesh1]) - self.assertEqual(mesh_seq[0], mesh1) - mesh_seq[0] = mesh2 - self.assertEqual(mesh_seq[0], mesh2) + def setUp(self): + self.time_partition = TimePartition(1.0, 2, [0.5, 0.5], ["field"]) + self.time_interval = TimeInterval(1.0, [0.5], ["field"]) + + def trivial_mesh(self, dim): + try: + return { + 1: UnitIntervalMesh(1), + 2: UnitSquareMesh(1, 1), + 3: UnitCubeMesh(1, 1, 1), + }[dim] + except KeyError: + raise ValueError(f"Dimension {dim} not supported.") from None - def test_inconsistent_dim(self): - meshes = [UnitSquareMesh(1, 1), UnitCubeMesh(1, 1, 1)] + +class TestExceptions(BaseClasses.MeshSeqTestCase): + """ + Unit tests for exceptions raised by :class:`MeshSeq`. + """ + + def test_inconsistent_dim_error(self): + meshes = [self.trivial_mesh(2), self.trivial_mesh(3)] with self.assertRaises(ValueError) as cm: MeshSeq(self.time_partition, meshes) msg = "Meshes must all have the same topological dimension." @@ -38,44 +57,39 @@ def test_inconsistent_dim(self): @parameterized.expand(["get_function_spaces", "get_solver"]) def test_notimplemented_error(self, function_name): - mesh_seq = MeshSeq(self.time_interval, UnitSquareMesh(1, 1)) + mesh_seq = MeshSeq(self.time_interval, self.trivial_mesh(2)) with self.assertRaises(NotImplementedError) as cm: if function_name == "get_function_spaces": getattr(mesh_seq, function_name)(mesh_seq[0]) else: getattr(mesh_seq, function_name)() - msg = f"'{function_name}' needs implementing." - self.assertEqual(str(cm.exception), msg) + self.assertEqual(str(cm.exception), f"'{function_name}' needs implementing.") @parameterized.expand(["get_function_spaces", "get_initial_condition"]) def test_return_dict_error(self, method): - mesh = UnitSquareMesh(1, 1) kwargs = {method: lambda _: 0} with self.assertRaises(AssertionError) as cm: - MeshSeq(self.time_interval, mesh, **kwargs) - msg = f"{method} should return a dict" - self.assertEqual(str(cm.exception), msg) + MeshSeq(self.time_interval, self.trivial_mesh(2), **kwargs) + self.assertEqual(str(cm.exception), f"{method} should return a dict") @parameterized.expand(["get_function_spaces", "get_initial_condition"]) def test_missing_field_error(self, method): - mesh = UnitSquareMesh(1, 1) kwargs = {method: lambda _: {}} with self.assertRaises(AssertionError) as cm: - MeshSeq(self.time_interval, mesh, **kwargs) + MeshSeq(self.time_interval, self.trivial_mesh(2), **kwargs) msg = "missing fields {'field'} in " + f"{method}" self.assertEqual(str(cm.exception), msg) @parameterized.expand(["get_function_spaces", "get_initial_condition"]) def test_unexpected_field_error(self, method): - mesh = UnitSquareMesh(1, 1) kwargs = {method: lambda _: {"field": None, "extra_field": None}} with self.assertRaises(AssertionError) as cm: - MeshSeq(self.time_interval, mesh, **kwargs) + MeshSeq(self.time_interval, self.trivial_mesh(2), **kwargs) msg = "unexpected fields {'extra_field'} in " + f"{method}" self.assertEqual(str(cm.exception), msg) def test_solver_generator_error(self): - mesh = UnitSquareMesh(1, 1) + mesh = self.trivial_mesh(2) f_space = FunctionSpace(mesh, "CG", 1) kwargs = { "get_function_spaces": lambda _: {"field": f_space}, @@ -84,8 +98,32 @@ def test_solver_generator_error(self): } with self.assertRaises(AssertionError) as cm: MeshSeq(self.time_interval, mesh, **kwargs) - msg = "solver should yield" - self.assertEqual(str(cm.exception), msg) + self.assertEqual(str(cm.exception), "solver should yield") + + @parameterized.expand([1, 3]) + def test_plot_dim_error(self, dim): + mesh_seq = MeshSeq(self.time_interval, self.trivial_mesh(dim)) + with self.assertRaises(ValueError) as cm: + mesh_seq.plot() + self.assertEqual(str(cm.exception), "MeshSeq plotting only supported in 2D.") + + +class TestGeneric(BaseClasses.MeshSeqTestCase): + """ + Generic unit tests for :class:`MeshSeq`. + """ + + def setUp(self): + self.time_partition = TimePartition(1.0, 2, [0.5, 0.5], ["field"]) + self.time_interval = TimeInterval(1.0, [0.5], ["field"]) + + def test_setitem(self): + mesh1 = UnitSquareMesh(1, 1, diagonal="left") + mesh2 = UnitSquareMesh(1, 1, diagonal="right") + mesh_seq = MeshSeq(self.time_interval, [mesh1]) + self.assertEqual(mesh_seq[0], mesh1) + mesh_seq[0] = mesh2 + self.assertEqual(mesh_seq[0], mesh2) def test_counting_2d(self): mesh_seq = MeshSeq(self.time_interval, [UnitSquareMesh(3, 3)]) @@ -98,16 +136,12 @@ def test_counting_3d(self): self.assertEqual(mesh_seq.count_vertices(), [64]) -class TestStringFormatting(unittest.TestCase): +class TestStringFormatting(BaseClasses.MeshSeqTestCase): """ Test that the :meth:`__str__` and :meth:`__repr__` methods work as intended for Goalie's :class:`MeshSeq` object. """ - def setUp(self): - self.time_partition = TimePartition(1.0, 2, [0.5, 0.5], ["field"]) - self.time_interval = TimeInterval(1.0, [0.5], ["field"]) - def test_mesh_seq_time_interval_str(self): mesh_seq = MeshSeq(self.time_interval, [UnitSquareMesh(1, 1)]) got = re.sub("#[0-9]*", "?", str(mesh_seq)) diff --git a/test_adjoint/conftest.py b/test_adjoint/conftest.py deleted file mode 100644 index daf0543e..00000000 --- a/test_adjoint/conftest.py +++ /dev/null @@ -1,170 +0,0 @@ -""" -Global pytest configuration. - -**Disclaimer: some functions copied from firedrake/src/tests/conftest.py -""" - -from subprocess import check_call - -import pyadjoint -import pytest - - -def parallel(item): - """ - Run a test in parallel. - - **Disclaimer: copied from firedrake/src/tests/conftest.py - - :arg item: the test item to run. - """ - from mpi4py import MPI - - if MPI.COMM_WORLD.size > 1: - raise RuntimeError("Parallel test can't be run within parallel environment") - marker = item.get_closest_marker("parallel") - if marker is None: - raise RuntimeError("Parallel test doesn't have parallel marker") - nprocs = marker.kwargs.get("nprocs", 3) - if nprocs < 2: - raise RuntimeError("Need at least two processes to run parallel test") - - # Only spew tracebacks on rank 0. - # Run xfailing tests to ensure that errors are reported to calling process. - call = [ - "mpiexec", - "-n", - "1", - "python", - "-m", - "pytest", - "--runxfail", - "-s", - "-q", - "%s::%s" % (item.fspath, item.name), - ] - call.extend( - [ - ":", - "-n", - "%d" % (nprocs - 1), - "python", - "-m", - "pytest", - "--runxfail", - "--tb=no", - "-q", - "%s::%s" % (item.fspath, item.name), - ] - ) - check_call(call) - - -def pytest_configure(config): - """ - Register an additional marker. - - **Disclaimer: copied from firedrake/src/tests/conftest.py - """ - config.addinivalue_line( - "markers", - "parallel(nprocs): mark test to run in parallel on nprocs processors", - ) - config.addinivalue_line( - "markers", - "slow: mark test as slow to run", - ) - - -def pytest_runtest_setup(item): - """ - **Disclaimer: copied from firedrake/src/tests/conftest.py - """ - if item.get_closest_marker("parallel"): - from mpi4py import MPI - - if MPI.COMM_WORLD.size > 1: - # Turn on source hash checking - from functools import partial - - from firedrake import parameters - - def _reset(check): - parameters["pyop2_options"]["check_src_hashes"] = check - - # Reset to current value when test is cleaned up - item.addfinalizer( - partial(_reset, parameters["pyop2_options"]["check_src_hashes"]) - ) - - parameters["pyop2_options"]["check_src_hashes"] = True - else: - # Blow away function arg in "master" process, to ensure - # this test isn't run on only one process. - item.obj = lambda *args, **kwargs: True - - -def pytest_runtest_call(item): - """ - **Disclaimer: copied from firedrake/src/tests/conftest.py - """ - from mpi4py import MPI - - if item.get_closest_marker("parallel") and MPI.COMM_WORLD.size == 1: - # Spawn parallel processes to run test - parallel(item) - - -@pytest.fixture(scope="module", autouse=True) -def check_empty_tape(request): - """ - Check that the tape is empty at the end of each module - - **Disclaimer: copied from firedrake/src/tests/conftest.py - """ - - def fin(): - tape = pyadjoint.get_working_tape() - if tape is not None: - assert len(tape.get_blocks()) == 0 - - request.addfinalizer(fin) - - -def pytest_runtest_teardown(item, nextitem): - """ - Clear caches after running a test - """ - from firedrake.tsfc_interface import clear_cache - from pyop2.caching import clear_memory_cache - from pyop2.mpi import COMM_WORLD - - clear_cache() - clear_memory_cache(COMM_WORLD) - - -@pytest.fixture(autouse=True) -def handle_taping(): - """ - **Disclaimer: copied from - firedrake/tests/regression/test_adjoint_operators.py - """ - yield - tape = pyadjoint.get_working_tape() - tape.clear_tape() - - -@pytest.fixture(autouse=True, scope="module") -def handle_annotation(): - """ - Since importing firedrake-adjoint modifies a global variable, we need to - pause annotations at the end of the module. - - **Disclaimer: copied from - firedrake/tests/regression/test_adjoint_operators.py - """ - if not pyadjoint.annotate_tape(): - pyadjoint.continue_annotation() - yield - if pyadjoint.annotate_tape(): - pyadjoint.pause_annotation()