From e651e6b3f64d0bb9ee47b9f56d8f341509b98ce8 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Thu, 28 Sep 2023 08:07:24 +0100 Subject: [PATCH 01/23] Fix imports --- goalie/adjoint.py | 2 +- goalie/mesh_seq.py | 2 +- goalie/options.py | 2 +- test/test_options.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/goalie/adjoint.py b/goalie/adjoint.py index f83a5153..2cf40617 100644 --- a/goalie/adjoint.py +++ b/goalie/adjoint.py @@ -3,7 +3,7 @@ """ import firedrake from firedrake.petsc import PETSc -from firedrake_adjoint import pyadjoint +from firedrake.adjoint import pyadjoint from .interpolation import project from .mesh_seq import MeshSeq from .options import GoalOrientedParameters diff --git a/goalie/mesh_seq.py b/goalie/mesh_seq.py index d868e1e3..927f8e4d 100644 --- a/goalie/mesh_seq.py +++ b/goalie/mesh_seq.py @@ -3,7 +3,7 @@ """ import firedrake from firedrake.petsc import PETSc -from firedrake.adjoint.solving import get_solve_blocks +from firedrake.adjoint_utils.solving import get_solve_blocks from pyadjoint import get_working_tape, Block from .interpolation import project from .log import pyrint, debug, warning, info, logger, DEBUG diff --git a/goalie/options.py b/goalie/options.py index a611fbc2..1450d546 100644 --- a/goalie/options.py +++ b/goalie/options.py @@ -1,5 +1,5 @@ from .utility import AttrDict -from firedrake.meshadapt import RiemannianMetric +from animate.adapt import RiemannianMetric __all__ = [ diff --git a/test/test_options.py b/test/test_options.py index f8ef7d74..4543537d 100644 --- a/test/test_options.py +++ b/test/test_options.py @@ -1,5 +1,5 @@ from firedrake import TensorFunctionSpace -from firedrake.meshadapt import RiemannianMetric +from animate.adapt import RiemannianMetric from goalie.options import * from utility import uniform_mesh import unittest From 7ae42e101488da7595ca7b52aa5de2057c3df609 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Thu, 28 Sep 2023 08:10:29 +0100 Subject: [PATCH 02/23] Fix imports for TestBlockLogic mocking --- test/test_mesh_seq.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/test/test_mesh_seq.py b/test/test_mesh_seq.py index e0f1d03d..981f9561 100644 --- a/test/test_mesh_seq.py +++ b/test/test_mesh_seq.py @@ -159,7 +159,7 @@ def setUp(self): self.time_interval, self.mesh, get_function_spaces=self.get_p0_spaces ) - @patch("dolfin_adjoint_common.blocks.solving.GenericSolveBlock") + @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_output_not_function(self, MockSolveBlock): solve_block = MockSolveBlock() block_variable = BlockVariable(1) @@ -169,7 +169,7 @@ def test_output_not_function(self, MockSolveBlock): msg = "Solve block for field 'field' on subinterval 0 has no outputs." self.assertEqual(str(cm.exception), msg) - @patch("dolfin_adjoint_common.blocks.solving.GenericSolveBlock") + @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_output_wrong_function_space(self, MockSolveBlock): solve_block = MockSolveBlock() block_variable = BlockVariable(Function(FunctionSpace(self.mesh, "CG", 1))) @@ -179,7 +179,7 @@ def test_output_wrong_function_space(self, MockSolveBlock): msg = "Solve block for field 'field' on subinterval 0 has no outputs." self.assertEqual(str(cm.exception), msg) - @patch("dolfin_adjoint_common.blocks.solving.GenericSolveBlock") + @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_output_wrong_name(self, MockSolveBlock): solve_block = MockSolveBlock() function_space = FunctionSpace(self.mesh, "DG", 0) @@ -190,7 +190,7 @@ def test_output_wrong_name(self, MockSolveBlock): msg = "Solve block for field 'field' on subinterval 0 has no outputs." self.assertEqual(str(cm.exception), msg) - @patch("dolfin_adjoint_common.blocks.solving.GenericSolveBlock") + @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_output_valid(self, MockSolveBlock): solve_block = MockSolveBlock() function_space = FunctionSpace(self.mesh, "DG", 0) @@ -198,7 +198,7 @@ def test_output_valid(self, MockSolveBlock): solve_block._outputs = [block_variable] self.assertIsNotNone(self.mesh_seq._output("field", 0, solve_block)) - @patch("dolfin_adjoint_common.blocks.solving.GenericSolveBlock") + @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_output_multiple_valid_error(self, MockSolveBlock): solve_block = MockSolveBlock() function_space = FunctionSpace(self.mesh, "DG", 0) @@ -212,7 +212,7 @@ def test_output_multiple_valid_error(self, MockSolveBlock): ) self.assertEqual(str(cm.exception), msg) - @patch("dolfin_adjoint_common.blocks.solving.GenericSolveBlock") + @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_dependency_not_function(self, MockSolveBlock): solve_block = MockSolveBlock() block_variable = BlockVariable(1) @@ -222,7 +222,7 @@ def test_dependency_not_function(self, MockSolveBlock): msg = "Solve block for field 'field' on subinterval 0 has no dependencies." self.assertEqual(str(cm.exception), msg) - @patch("dolfin_adjoint_common.blocks.solving.GenericSolveBlock") + @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_dependency_wrong_function_space(self, MockSolveBlock): solve_block = MockSolveBlock() block_variable = BlockVariable(Function(FunctionSpace(self.mesh, "CG", 1))) @@ -232,7 +232,7 @@ def test_dependency_wrong_function_space(self, MockSolveBlock): msg = "Solve block for field 'field' on subinterval 0 has no dependencies." self.assertEqual(str(cm.exception), msg) - @patch("dolfin_adjoint_common.blocks.solving.GenericSolveBlock") + @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_dependency_wrong_name(self, MockSolveBlock): solve_block = MockSolveBlock() function_space = FunctionSpace(self.mesh, "DG", 0) @@ -243,7 +243,7 @@ def test_dependency_wrong_name(self, MockSolveBlock): msg = "Solve block for field 'field' on subinterval 0 has no dependencies." self.assertEqual(str(cm.exception), msg) - @patch("dolfin_adjoint_common.blocks.solving.GenericSolveBlock") + @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_dependency_valid(self, MockSolveBlock): solve_block = MockSolveBlock() function_space = FunctionSpace(self.mesh, "DG", 0) @@ -251,7 +251,7 @@ def test_dependency_valid(self, MockSolveBlock): solve_block._dependencies = [block_variable] self.assertIsNotNone(self.mesh_seq._dependency("field", 0, solve_block)) - @patch("dolfin_adjoint_common.blocks.solving.GenericSolveBlock") + @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_dependency_multiple_valid_error(self, MockSolveBlock): solve_block = MockSolveBlock() function_space = FunctionSpace(self.mesh, "DG", 0) @@ -265,7 +265,7 @@ def test_dependency_multiple_valid_error(self, MockSolveBlock): ) self.assertEqual(str(cm.exception), msg) - @patch("dolfin_adjoint_common.blocks.solving.GenericSolveBlock") + @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_dependency_steady(self, MockSolveBlock): time_interval = TimeInterval(1.0, 0.5, "field", field_types="steady") mesh_seq = MeshSeq( From 193e10f0ed29ca285231edc07bf8ff4aac293ea4 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Thu, 28 Sep 2023 08:32:42 +0100 Subject: [PATCH 03/23] Move adjoint test setup to separate file --- test_adjoint/conftest.py | 31 ----------------------------- test_adjoint/setup_adjoint_tests.py | 29 +++++++++++++++++++++++++++ test_adjoint/test_mesh_seq.py | 3 +-- test_adjoint/test_utils.py | 1 + 4 files changed, 31 insertions(+), 33 deletions(-) create mode 100644 test_adjoint/setup_adjoint_tests.py diff --git a/test_adjoint/conftest.py b/test_adjoint/conftest.py index e2059478..aca2a714 100644 --- a/test_adjoint/conftest.py +++ b/test_adjoint/conftest.py @@ -128,37 +128,6 @@ def fin(): request.addfinalizer(fin) -@pytest.fixture(autouse=True) -def handle_taping(): - """ - **Disclaimer: copied from - firedrake/tests/regression/test_adjoint_operators.py - """ - yield - import pyadjoint - - tape = pyadjoint.get_working_tape() - tape.clear_tape() - if not pyadjoint.annotate_tape(): - pyadjoint.continue_annotation() - - -@pytest.fixture(autouse=True, scope="module") -def handle_exit_annotation(): - """ - Since importing firedrake_adjoint modifies a global variable, we need to - pause annotations at the end of the module. - - **Disclaimer: copied from - firedrake/tests/regression/test_adjoint_operators.py - """ - yield - import pyadjoint - - if pyadjoint.annotate_tape(): - pyadjoint.pause_annotation() - - def pytest_runtest_teardown(item, nextitem): """ Clear caches after running a test diff --git a/test_adjoint/setup_adjoint_tests.py b/test_adjoint/setup_adjoint_tests.py new file mode 100644 index 00000000..52b59470 --- /dev/null +++ b/test_adjoint/setup_adjoint_tests.py @@ -0,0 +1,29 @@ +import pyadjoint +import pytest + + +@pytest.fixture(autouse=True) +def handle_taping(): + """ + **Disclaimer: copied from + firedrake/tests/regression/test_adjoint_operators.py + """ + yield + tape = pyadjoint.get_working_tape() + tape.clear_tape() + + +@pytest.fixture(autouse=True, scope="module") +def handle_annotation(): + """ + Since importing firedrake-adjoint modifies a global variable, we need to + pause annotations at the end of the module. + + **Disclaimer: copied from + firedrake/tests/regression/test_adjoint_operators.py + """ + if not pyadjoint.annotate_tape(): + pyadjoint.continue_annotation() + yield + if pyadjoint.annotate_tape(): + pyadjoint.pause_annotation() diff --git a/test_adjoint/test_mesh_seq.py b/test_adjoint/test_mesh_seq.py index 6b500168..093c4fae 100644 --- a/test_adjoint/test_mesh_seq.py +++ b/test_adjoint/test_mesh_seq.py @@ -6,10 +6,9 @@ from goalie.log import * from goalie.mesh_seq import MeshSeq from goalie.time_partition import TimeInterval -import pyadjoint import logging -import pytest import unittest +from setup_adjoint_tests import * class TestGetSolveBlocks(unittest.TestCase): diff --git a/test_adjoint/test_utils.py b/test_adjoint/test_utils.py index 58a81d81..60aa1d6c 100644 --- a/test_adjoint/test_utils.py +++ b/test_adjoint/test_utils.py @@ -3,6 +3,7 @@ from goalie.adjoint import annotate_qoi import numpy as np import unittest +from setup_adjoint_tests import * class TestAdjointUtils(unittest.TestCase): From b1bc1964403b9687fd8a9226eacfef21b7f678d8 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Thu, 28 Sep 2023 08:51:23 +0100 Subject: [PATCH 04/23] Further import fixes --- goalie/mesh_seq.py | 19 ++++++++----------- test_adjoint/test_adjoint.py | 4 +--- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/goalie/mesh_seq.py b/goalie/mesh_seq.py index 927f8e4d..2a29fb47 100644 --- a/goalie/mesh_seq.py +++ b/goalie/mesh_seq.py @@ -4,7 +4,7 @@ import firedrake from firedrake.petsc import PETSc from firedrake.adjoint_utils.solving import get_solve_blocks -from pyadjoint import get_working_tape, Block +import pyadjoint from .interpolation import project from .log import pyrint, debug, warning, info, logger, DEBUG from .options import AdaptParameters @@ -328,9 +328,12 @@ def get_solve_blocks(self, field: str, subinterval: int) -> list: :arg field: name of the prognostic solution field :arg subinterval: subinterval index """ + tape = pyadjoint.get_working_tape() + if tape is None: + self.warning("Tape does not exist!") + return [] - # Get all blocks - blocks = get_working_tape().get_blocks() + blocks = tape.get_blocks() if len(blocks) == 0: self.warning("Tape has no blocks!") return blocks @@ -385,9 +388,7 @@ def get_solve_blocks(self, field: str, subinterval: int) -> list: ) return solve_blocks - def _output( - self, field: str, subinterval: int, solve_block: Block - ) -> firedrake.Function: + def _output(self, field, subinterval, solve_block): """ For a given solve block and solution field, get the block's outputs which corresponds to the solution from the current timestep. @@ -430,9 +431,7 @@ def _output( " outputs." ) - def _dependency( - self, field: str, subinterval: int, solve_block: Block - ) -> firedrake.Function: + def _dependency(self, field, subinterval, solve_block): """ For a given solve block and solution field, get the block's dependency which corresponds to the solution from the previous timestep. @@ -500,8 +499,6 @@ def solve_forward(self, solver_kwargs: dict = {}) -> AttrDict: :return solution: an :class:`~.AttrDict` containing solution fields and their lagged versions. """ - from firedrake_adjoint import pyadjoint - num_subintervals = len(self) function_spaces = self.function_spaces P = self.time_partition diff --git a/test_adjoint/test_adjoint.py b/test_adjoint/test_adjoint.py index a538a839..27b22641 100644 --- a/test_adjoint/test_adjoint.py +++ b/test_adjoint/test_adjoint.py @@ -3,6 +3,7 @@ """ from firedrake import * from goalie_adjoint import * +import pyadjoint import pytest import importlib import os @@ -96,9 +97,6 @@ def test_adjoint_same_mesh(problem, qoi_type, debug=False): or as a time integral? :kwarg debug: toggle debugging mode """ - from firedrake_adjoint import pyadjoint - - # Debugging if debug: set_log_level(DEBUG) From f1f8bbc9f59e4841852ba9bb170536e4ef47256c Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Thu, 28 Sep 2023 09:52:04 +0100 Subject: [PATCH 05/23] Fix tape handling --- goalie/adjoint.py | 12 ++++++++++-- goalie/mesh_seq.py | 14 +++++++++----- test_adjoint/test_adjoint.py | 15 +++++++++++---- 3 files changed, 30 insertions(+), 11 deletions(-) diff --git a/goalie/adjoint.py b/goalie/adjoint.py index 2cf40617..d94dd00e 100644 --- a/goalie/adjoint.py +++ b/goalie/adjoint.py @@ -276,8 +276,12 @@ def wrapped_solver(subinterval, ic, **kwargs): return solver(subinterval, init, **kwargs) # Clear tape - tape = pyadjoint.get_working_tape() - tape.clear_tape() + if pyadjoint.annotate_tape(): + tape = pyadjoint.get_working_tape() + if tape is not None: + tape.clear_tape() + else: + pyadjoint.continue_annotation() # Loop over subintervals in reverse seeds = None @@ -308,6 +312,7 @@ def wrapped_solver(subinterval, ic, **kwargs): block.adj_kwargs.update(adj_solver_kwargs) # Solve adjoint problem + tape = pyadjoint.get_working_tape() with PETSc.Log.Event("goalie.AdjointMeshSeq.solve_adjoint.evaluate_adj"): m = pyadjoint.enlisting.Enlist(self.controls) with pyadjoint.stop_annotating(): @@ -424,6 +429,9 @@ def wrapped_solver(subinterval, ic, **kwargs): "QoI values computed during checkpointing and annotated" f" run do not match ({J_chk} vs. {self.J})" ) + + if pyadjoint.annotate_tape(): + pyadjoint.pause_annotation() return solutions @staticmethod diff --git a/goalie/mesh_seq.py b/goalie/mesh_seq.py index 2a29fb47..f6d95e70 100644 --- a/goalie/mesh_seq.py +++ b/goalie/mesh_seq.py @@ -4,7 +4,7 @@ import firedrake from firedrake.petsc import PETSc from firedrake.adjoint_utils.solving import get_solve_blocks -import pyadjoint +from firedrake.adjoint import pyadjoint from .interpolation import project from .log import pyrint, debug, warning, info, logger, DEBUG from .options import AdaptParameters @@ -525,9 +525,13 @@ def solve_forward(self, solver_kwargs: dict = {}) -> AttrDict: } ) - # Clear tape - tape = pyadjoint.get_working_tape() - tape.clear_tape() + # Start annotating + if pyadjoint.annotate_tape(): + tape = pyadjoint.get_working_tape() + if tape is not None: + tape.clear_tape() + else: + pyadjoint.continue_annotation() # Loop over the subintervals checkpoint = self.initial_condition @@ -585,7 +589,7 @@ def solve_forward(self, solver_kwargs: dict = {}) -> AttrDict: ) # Clear the tape to reduce the memory footprint - tape.clear_tape() + pyadjoint.get_working_tape().clear_tape() return solutions diff --git a/test_adjoint/test_adjoint.py b/test_adjoint/test_adjoint.py index 27b22641..57420e86 100644 --- a/test_adjoint/test_adjoint.py +++ b/test_adjoint/test_adjoint.py @@ -135,16 +135,20 @@ def test_adjoint_same_mesh(problem, qoi_type, debug=False): # Solve forward and adjoint without solve_adjoint pyrint("\n--- Adjoint solve on 1 subinterval using pyadjoint\n") + if not pyadjoint.annotate_tape(): + pyadjoint.continue_annotation() + tape = pyadjoint.get_working_tape() + tape.clear_tape() ic = mesh_seq.initial_condition controls = [pyadjoint.Control(value) for key, value in ic.items()] sols = mesh_seq.solver(0, ic) qoi = mesh_seq.get_qoi(sols, 0) J = mesh_seq.J if qoi_type == "time_integrated" else qoi() m = pyadjoint.enlisting.Enlist(controls) - tape = pyadjoint.get_working_tape() - with pyadjoint.stop_annotating(): - with tape.marked_nodes(m): - tape.evaluate_adj(markings=True) + assert pyadjoint.annotate_tape() + pyadjoint.pause_annotation() + with tape.marked_nodes(m): + tape.evaluate_adj(markings=True) # FIXME: Using mixed Functions as Controls not correct J_expected = float(J) @@ -216,6 +220,9 @@ def test_adjoint_same_mesh(problem, qoi_type, debug=False): f"Adjoint values do not match at t=0 (error {err:.4e}.)" ) + tape = pyadjoint.get_working_tape() + tape.clear_tape() + def plot_solutions(problem, qoi_type, debug=True): """ From 339232161d66f02b597c20d3983704ee43dffafe Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Thu, 28 Sep 2023 09:52:32 +0100 Subject: [PATCH 06/23] Apply Black formatter --- goalie/go_mesh_seq.py | 4 +--- test_adjoint/test_fp_iteration.py | 20 +++++++++++++++----- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/goalie/go_mesh_seq.py b/goalie/go_mesh_seq.py index d26538d6..d2e2273f 100644 --- a/goalie/go_mesh_seq.py +++ b/goalie/go_mesh_seq.py @@ -327,9 +327,7 @@ def fixed_point_iteration( if self.params.convergence_criteria == "all": if not converged: self.converged[:] = False - pyrint( - f"Failed to converge in {self.params.maxiter} iterations." - ) + pyrint(f"Failed to converge in {self.params.maxiter} iterations.") else: for i, conv in enumerate(self.converged): if not conv: diff --git a/test_adjoint/test_fp_iteration.py b/test_adjoint/test_fp_iteration.py index ff0e5f29..c8115cb0 100644 --- a/test_adjoint/test_fp_iteration.py +++ b/test_adjoint/test_fp_iteration.py @@ -106,7 +106,9 @@ def adaptor(mesh_seq, sols): expected = [[1, 1], [1, 2], [1, 1], [1, 2], [1, 1], [1, 2]] self.assertEqual(mesh_seq.element_counts, expected) self.assertTrue(np.allclose(mesh_seq.converged, [True, False])) - self.assertTrue(np.allclose(mesh_seq.check_convergence, [not drop_out_converged, True])) + self.assertTrue( + np.allclose(mesh_seq.check_convergence, [not drop_out_converged, True]) + ) def test_no_late_convergence(self): mesh1 = UnitSquareMesh(1, 1) @@ -186,7 +188,9 @@ def test_dropout(self, drop_out_converged): time_partition = TimePartition(1.0, 2, [0.5, 0.5], []) ap = GoalOrientedParameters(self.parameters) ap.update({"drop_out_converged": drop_out_converged}) - mesh_seq = self.mesh_seq(time_partition, mesh2, parameters=ap, qoi_type="end_time") + mesh_seq = self.mesh_seq( + time_partition, mesh2, parameters=ap, qoi_type="end_time" + ) def adaptor(mesh_seq, sols, indicators): mesh_seq[1] = mesh1 if mesh_seq.fp_iteration % 2 == 0 else mesh2 @@ -196,7 +200,9 @@ def adaptor(mesh_seq, sols, indicators): expected = [[1, 1], [1, 2], [1, 1], [1, 2], [1, 1], [1, 2]] self.assertEqual(mesh_seq.element_counts, expected) self.assertTrue(np.allclose(mesh_seq.converged, [True, False])) - self.assertTrue(np.allclose(mesh_seq.check_convergence, [not drop_out_converged, True])) + self.assertTrue( + np.allclose(mesh_seq.check_convergence, [not drop_out_converged, True]) + ) def test_no_late_convergence(self): mesh1 = UnitSquareMesh(1, 1) @@ -204,7 +210,9 @@ def test_no_late_convergence(self): time_partition = TimePartition(1.0, 2, [0.5, 0.5], []) ap = GoalOrientedParameters(self.parameters) ap.update({"drop_out_converged": True}) - mesh_seq = self.mesh_seq(time_partition, mesh2, parameters=ap, qoi_type="end_time") + mesh_seq = self.mesh_seq( + time_partition, mesh2, parameters=ap, qoi_type="end_time" + ) def adaptor(mesh_seq, sols, indicators): mesh_seq[0] = mesh1 if mesh_seq.fp_iteration % 2 == 0 else mesh2 @@ -221,7 +229,9 @@ def test_convergence_criteria_all(self): time_partition = TimePartition(1.0, 1, 0.5, []) ap = GoalOrientedParameters(self.parameters) ap.update({"convergence_criteria": "all"}) - mesh_seq = self.mesh_seq(time_partition, mesh, parameters=ap, qoi_type="end_time") + mesh_seq = self.mesh_seq( + time_partition, mesh, parameters=ap, qoi_type="end_time" + ) def adaptor(mesh_seq, sols, indicators): return [False] From ebb33b0b73f4fa03616144abd80ec1148a029d6a Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Thu, 28 Sep 2023 13:14:26 +0100 Subject: [PATCH 07/23] Animate is already installed in Docker image --- .github/workflows/build.yml | 7 ------- 1 file changed, 7 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 935a270d..9ae9809b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -27,13 +27,6 @@ jobs: uses: actions/setup-python@v2 with: python-version: 3.8 - - name: Install Animate - run: | - . /home/firedrake/firedrake/bin/activate - cd .. - git clone https://github.com/pyroteus/animate.git - cd animate - python -m pip install -e . - name: Install Goalie run: | . /home/firedrake/firedrake/bin/activate From dd42702e46f0dea3aa49dcccf868225e9f6f0f7c Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Fri, 6 Oct 2023 09:19:10 +0100 Subject: [PATCH 08/23] Various adjoint updates --- goalie/adjoint.py | 76 +++++++++++++++++++++++++---------------------- 1 file changed, 41 insertions(+), 35 deletions(-) diff --git a/goalie/adjoint.py b/goalie/adjoint.py index d94dd00e..b075821c 100644 --- a/goalie/adjoint.py +++ b/goalie/adjoint.py @@ -8,7 +8,7 @@ from .mesh_seq import MeshSeq from .options import GoalOrientedParameters from .time_partition import TimePartition -from .utility import AttrDict +from .utility import AttrDict, norm from .log import pyrint from collections.abc import Callable from functools import wraps @@ -238,8 +238,6 @@ def solve_adjoint( labels = ("forward", "forward_old", "adjoint") if not self.steady: labels += ("adjoint_next",) - if get_adj_values: - labels += ("adj_value",) solutions = AttrDict( { field: AttrDict( @@ -257,6 +255,16 @@ def solve_adjoint( for field in self.fields } ) + if get_adj_values: + for field in self.fields: + solutions[field]["adj_value"] = [] + for i, fs in enumerate(function_spaces[field]): + solutions[field]["adj_value"].append( + [ + firedrake.Cofunction(fs.dual(), name=f"{field}_adj_value") + for j in range(P.num_exports_per_subinterval[i] - 1) + ] + ) @PETSc.Log.EventDecorator("goalie.AdjointMeshSeq.solve_adjoint.evaluate_fwd") @wraps(solver) @@ -275,36 +283,37 @@ def wrapped_solver(subinterval, ic, **kwargs): self.controls = [pyadjoint.Control(init[field]) for field in self.fields] return solver(subinterval, init, **kwargs) - # Clear tape - if pyadjoint.annotate_tape(): - tape = pyadjoint.get_working_tape() - if tape is not None: - tape.clear_tape() - else: - pyadjoint.continue_annotation() - # Loop over subintervals in reverse - seeds = None + seeds = {} for i in reversed(range(num_subintervals)): stride = P.num_timesteps_per_export[i] num_exports = P.num_exports_per_subinterval[i] + # Clear tape and start annotation + if not pyadjoint.annotate_tape(): + pyadjoint.continue_annotation() + tape = pyadjoint.get_working_tape() + if tape is not None: + tape.clear_tape() + # Annotate tape on current subinterval checkpoint = wrapped_solver(i, checkpoints[i], **solver_kwargs) + pyadjoint.pause_annotation() # Get seed vector for reverse propagation if i == num_subintervals - 1: if self.qoi_type in ["end_time", "steady"]: + pyadjoint.continue_annotation() qoi = self.get_qoi(checkpoint, i) self.J = qoi(**qoi_kwargs) if np.isclose(float(self.J), 0.0): self.warning("Zero QoI. Is it implemented as intended?") + pyadjoint.pause_annotation() else: - with pyadjoint.stop_annotating(): - for field, fs in function_spaces.items(): - checkpoint[field].block_variable.adj_value = project( - seeds[field], fs[i], adjoint=True - ) + for field, fs in function_spaces.items(): + checkpoint[field].block_variable.adj_value = project( + seeds[field], fs[i], adjoint=True + ) # Update adjoint solver kwargs for field in self.fields: @@ -368,7 +377,7 @@ def wrapped_solver(subinterval, ic, **kwargs): # Adjoint action also comes from dependencies if get_adj_values and dep is not None: - sols.adj_value[i][j].assign(dep.adj_value.function) + sols.adj_value[i][j].assign(dep.adj_value) # The adjoint solution at the 'next' timestep is determined from the # adj_sol attribute of the next solve block @@ -392,32 +401,30 @@ def wrapped_solver(subinterval, ic, **kwargs): ) # Check non-zero adjoint solution/value - if np.isclose(firedrake.norm(solutions[field].adjoint[i][0]), 0.0): + if np.isclose(norm(solutions[field].adjoint[i][0]), 0.0): self.warning( f"Adjoint solution for field '{field}' on {self.th(i)}" " subinterval is zero." ) - if get_adj_values and np.isclose( - firedrake.norm(sols.adj_value[i][0]), 0.0 - ): + if get_adj_values and np.isclose(norm(sols.adj_value[i][0]), 0.0): self.warning( f"Adjoint action for field '{field}' on {self.th(i)}" " subinterval is zero." ) # Get adjoint action on each subinterval - seeds = { - field: firedrake.Function( - function_spaces[field][i], val=control.block_variable.adj_value - ) - for field, control in zip(self.fields, self.controls) - } - for field, seed in seeds.items(): - if not self.steady and np.isclose(firedrake.norm(seed), 0.0): - self.warning( - f"Adjoint action for field '{field}' on {self.th(i)}" - " subinterval is zero." + with pyadjoint.stop_annotating(): + for field, control in zip(self.fields, self.controls): + seeds[field] = firedrake.Cofunction( + function_spaces[field][i].dual() ) + if control.block_variable.adj_value is not None: + seeds[field].assign(control.block_variable.adj_value) + if not self.steady and np.isclose(norm(seeds[field]), 0.0): + self.warning( + f"Adjoint action for field '{field}' on {self.th(i)}" + " subinterval is zero." + ) # Clear the tape to reduce the memory footprint tape.clear_tape() @@ -430,8 +437,7 @@ def wrapped_solver(subinterval, ic, **kwargs): f" run do not match ({J_chk} vs. {self.J})" ) - if pyadjoint.annotate_tape(): - pyadjoint.pause_annotation() + tape.clear_tape() return solutions @staticmethod From c907eead22e9f30a1dd6dd1a37047ebc3f4bc341 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Fri, 13 Oct 2023 07:48:11 +0100 Subject: [PATCH 09/23] Proposed new way to compute cell contributions to error estimator --- goalie/error_estimation.py | 13 ++++++++++--- test/test_error_estimation.py | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/goalie/error_estimation.py b/goalie/error_estimation.py index e48e5197..654b138d 100644 --- a/goalie/error_estimation.py +++ b/goalie/error_estimation.py @@ -33,6 +33,7 @@ def form2indicator(F: ufl.form.Form) -> Function: P0 = FunctionSpace(mesh, "DG", 0) p0test = firedrake.TestFunction(P0) indicator = Function(P0) + mass_term = firedrake.TrialFunction(P0) * p0test * firedrake.dx # Contributions from surface integrals flux_terms = 0 @@ -44,8 +45,6 @@ def form2indicator(F: ufl.form.Form) -> Function: flux_terms += p0test("+") * integral.integrand() * dS flux_terms += p0test("-") * integral.integrand() * dS if flux_terms != 0: - dx = firedrake.dx - mass_term = firedrake.TrialFunction(P0) * p0test * dx sp = { "snes_type": "ksponly", "ksp_type": "preonly", @@ -59,7 +58,15 @@ def form2indicator(F: ufl.form.Form) -> Function: dx = firedrake.dx(integral.subdomain_id()) cell_terms += p0test * integral.integrand() * dx if cell_terms != 0: - indicator += firedrake.assemble(cell_terms) + cell_contrib = Function(P0) + sp = { + "snes_type": "ksponly", + "ksp_type": "preonly", + "pc_type": "lu", + "pc_factor_mat_solver_type": "mumps", + } + firedrake.solve(mass_term == cell_terms, cell_contrib, solver_parameters=sp) + indicator += cell_contrib return indicator diff --git a/test/test_error_estimation.py b/test/test_error_estimation.py index a6c4dc10..601967a8 100644 --- a/test/test_error_estimation.py +++ b/test/test_error_estimation.py @@ -51,7 +51,7 @@ def test_cell_integral(self): F = conditional(x + y < 1, 1, 0) * dx indicator = form2indicator(F) self.assertAlmostEqual(indicator.dat.data[0], 0) - self.assertAlmostEqual(indicator.dat.data[1], 0.5) + self.assertAlmostEqual(indicator.dat.data[1], 1) class TestIndicators2Estimator(ErrorEstimationTestCase): From ece851777e52b7edadb4236d25f3d0bfcd2c15eb Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Fri, 13 Oct 2023 08:04:20 +0100 Subject: [PATCH 10/23] Fix error estimation test values --- test/test_error_estimation.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/test/test_error_estimation.py b/test/test_error_estimation.py index 601967a8..465f1598 100644 --- a/test/test_error_estimation.py +++ b/test/test_error_estimation.py @@ -129,7 +129,7 @@ def test_unit_time_instant(self): time_instant = TimeInstant("field", time=1.0) indicator = form2indicator(self.one * dx) estimator = indicators2estimator({"field": [[indicator]]}, time_instant) - self.assertAlmostEqual(estimator, 1.0) + self.assertAlmostEqual(estimator, 2.0) # 1 * (1 + 1) @parameterized.expand([[False], [True]]) def test_unit_time_instant_abs(self, absolute_value): @@ -138,19 +138,21 @@ def test_unit_time_instant_abs(self, absolute_value): estimator = indicators2estimator( {"field": [[indicator]]}, time_instant, absolute_value=absolute_value ) - self.assertAlmostEqual(estimator, 1.0 if absolute_value else -1.0) + self.assertAlmostEqual( + estimator, 2.0 if absolute_value else -2.0 + ) # (-)1 * (1 + 1) def test_half_time_instant(self): time_instant = TimeInstant("field", time=0.5) indicator = form2indicator(self.one * dx) estimator = indicators2estimator({"field": [[indicator]]}, time_instant) - self.assertAlmostEqual(estimator, 0.5) + self.assertAlmostEqual(estimator, 1.0) # 0.5 * (1 + 1) def test_time_partition_same_timestep(self): time_partition = TimePartition(1.0, 2, [0.5, 0.5], ["field"]) indicator = form2indicator(self.one * dx) estimator = indicators2estimator({"field": [2 * [indicator]]}, time_partition) - self.assertAlmostEqual(estimator, 1.0) + self.assertAlmostEqual(estimator, 2.0) # 0.5 * (1 + 1) + 0.5 * (1 + 1) def test_time_partition_different_timesteps(self): time_partition = TimePartition(1.0, 2, [0.5, 0.25], ["field"]) @@ -158,7 +160,7 @@ def test_time_partition_different_timesteps(self): estimator = indicators2estimator( {"field": [[indicator], 2 * [indicator]]}, time_partition ) - self.assertAlmostEqual(estimator, 1.0) + self.assertAlmostEqual(estimator, 2.0) # 0.5 * (1 + 1) + 0.25 * 2 * (1 + 1) def test_time_instant_multiple_fields(self): time_instant = TimeInstant(["field1", "field2"], time=1.0) @@ -166,7 +168,7 @@ def test_time_instant_multiple_fields(self): estimator = indicators2estimator( {"field1": [[indicator]], "field2": [[indicator]]}, time_instant ) - self.assertAlmostEqual(estimator, 2.0) + self.assertAlmostEqual(estimator, 4.0) # 2 * (1 * (1 + 1)) class TestGetDWRIndicator(ErrorEstimationTestCase): @@ -245,32 +247,32 @@ def test_convert_neither(self): adjoint_error = {"field": self.two} test_space = {"field": self.one.function_space()} indicator = get_dwr_indicator(self.F, adjoint_error, test_space=test_space) - self.assertAlmostEqual(indicator.dat.data[0], 1) - self.assertAlmostEqual(indicator.dat.data[1], 1) + self.assertAlmostEqual(indicator.dat.data[0], 2) + self.assertAlmostEqual(indicator.dat.data[1], 2) def test_convert_both(self): test_space = self.one.function_space() indicator = get_dwr_indicator(self.F, self.two, test_space=test_space) - self.assertAlmostEqual(indicator.dat.data[0], 1) - self.assertAlmostEqual(indicator.dat.data[1], 1) + self.assertAlmostEqual(indicator.dat.data[0], 2) + self.assertAlmostEqual(indicator.dat.data[1], 2) def test_convert_test_space(self): adjoint_error = {"field": self.two} test_space = self.one.function_space() indicator = get_dwr_indicator(self.F, adjoint_error, test_space=test_space) - self.assertAlmostEqual(indicator.dat.data[0], 1) - self.assertAlmostEqual(indicator.dat.data[1], 1) + self.assertAlmostEqual(indicator.dat.data[0], 2) + self.assertAlmostEqual(indicator.dat.data[1], 2) def test_convert_adjoint_error(self): test_space = {"Dos": self.one.function_space()} indicator = get_dwr_indicator(self.F, self.two, test_space=test_space) - self.assertAlmostEqual(indicator.dat.data[0], 1) - self.assertAlmostEqual(indicator.dat.data[1], 1) + self.assertAlmostEqual(indicator.dat.data[0], 2) + self.assertAlmostEqual(indicator.dat.data[1], 2) def test_convert_adjoint_error_no_test_space(self): indicator = get_dwr_indicator(self.F, self.two) - self.assertAlmostEqual(indicator.dat.data[0], 1) - self.assertAlmostEqual(indicator.dat.data[1], 1) + self.assertAlmostEqual(indicator.dat.data[0], 2) + self.assertAlmostEqual(indicator.dat.data[1], 2) def test_convert_adjoint_error_mismatch(self): test_space = {"field": self.one.function_space()} From 93a759ebe89119745ce9235b0b621a3cdde03ce5 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Fri, 13 Oct 2023 08:25:24 +0100 Subject: [PATCH 11/23] Rework utility and test to avoid Constant --- goalie/utility.py | 10 +++++----- test/test_utility.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/goalie/utility.py b/goalie/utility.py index 4af8c7e4..262aa679 100644 --- a/goalie/utility.py +++ b/goalie/utility.py @@ -184,14 +184,14 @@ def errornorm(u, uh: firedrake.Function, norm_type: str = "L2", **kwargs) -> flo :kwarg boundary: should the norm be computed over the domain boundary? """ - if len(u.ufl_shape) != len(uh.ufl_shape): - raise RuntimeError("Mismatching rank between u and uh.") - if not isinstance(uh, firedrake.Function): - raise TypeError(f"uh should be a Function, is a {type(uh).__name__}.") + raise TypeError(f"uh should be a Function, is a '{type(uh)}'.") if norm_type[0] == "l": if not isinstance(u, firedrake.Function): - raise TypeError(f"u should be a Function, is a {type(u).__name__}.") + raise TypeError(f"u should be a Function, is a '{type(u)}'.") + + if len(u.ufl_shape) != len(uh.ufl_shape): + raise RuntimeError("Mismatching rank between u and uh.") if isinstance(u, firedrake.Function): degree_u = u.function_space().ufl_element().degree() diff --git a/test/test_utility.py b/test/test_utility.py index 1c4a45c5..233e6642 100644 --- a/test/test_utility.py +++ b/test/test_utility.py @@ -194,14 +194,14 @@ def test_shape_error(self): def test_not_function_error(self): with self.assertRaises(TypeError) as cm: - errornorm(self.f, Constant(1.0)) - msg = "uh should be a Function, is a Constant." + errornorm(self.f, 1.0) + msg = "uh should be a Function, is a ''." self.assertEqual(str(cm.exception), msg) def test_not_function_error_lp(self): with self.assertRaises(TypeError) as cm: - errornorm(Constant(1.0), self.f, norm_type="l1") - msg = "u should be a Function, is a Constant." + errornorm(1.0, self.f, norm_type="l1") + msg = "u should be a Function, is a ''." self.assertEqual(str(cm.exception), msg) def test_mixed_space_invalid_norm_error(self): From ea79a104004d4657bc1e776ac6af06f7a52ffe05 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Fri, 13 Oct 2023 09:23:21 +0100 Subject: [PATCH 12/23] Add functions to map between Functions and Cofunctions; use in norm --- goalie/utility.py | 78 ++++++++++++++++++++++++++++++----------------- 1 file changed, 50 insertions(+), 28 deletions(-) diff --git a/goalie/utility.py b/goalie/utility.py index 262aa679..9b8193cc 100644 --- a/goalie/utility.py +++ b/goalie/utility.py @@ -105,27 +105,50 @@ def assemble_mass_matrix( return firedrake.assemble(lhs).petscmat -@PETSc.Log.EventDecorator("goalie.norm") -def norm(v: firedrake.Function, norm_type: str = "L2", **kwargs) -> float: +def cofunction2function(c): + """ + Map a :class:`Cofunction` to a :class:`Function`. + """ + f = firedrake.Function(c.function_space().dual()) + if isinstance(f.dat.data_with_halos, tuple): + for i, arr in enumerate(f.dat.data_with_halos): + arr[:] = c.dat.data_with_halos[i] + else: + f.dat.data_with_halos[:] = c.dat.data_with_halos + return f + + +def function2cofunction(f): + """ + Map a :class:`Function` to a :class:`Cofunction`. + """ + c = firedrake.Cofunction(f.function_space().dual()) + if isinstance(c.dat.data_with_halos, tuple): + for i, arr in enumerate(c.dat.data_with_halos): + arr[:] = f.dat.data_with_halos[i] + else: + c.dat.data_with_halos[:] = f.dat.data_with_halos + return c + + +@PETSc.Log.EventDecorator() +def norm(v, norm_type="L2", **kwargs): r""" - Overload :func:`firedrake.norms.norm` to - allow for :math:`\ell^p` norms. + Overload :func:`firedrake.norms.norm` to allow for :math:`\ell^p` norms. - Note that this version is case sensitive, - i.e. ``'l2'`` and ``'L2'`` will give + Note that this version is case sensitive, i.e. ``'l2'`` and ``'L2'`` will give different results in general. - :arg v: the :class:`firedrake.function.Function` - to take the norm of - :kwarg norm_type: choose from ``'l1'``, ``'l2'``, - ``'linf'``, ``'L2'``, ``'Linf'``, ``'H1'``, - ``'Hdiv'``, ``'Hcurl'``, or any ``'Lp'`` with - :math:`p >= 1`. - :kwarg condition: a UFL condition for specifying - a subdomain to compute the norm over - :kwarg boundary: should the norm be computed over - the domain boundary? + :arg v: the :class:`firedrake.function.Function` or + :class:`firedrake.cofunction.Cofunction` to take the norm of + :kwarg norm_type: choose from ``'l1'``, ``'l2'``, ``'linf'``, ``'L2'``, ``'Linf'``, + ``'H1'``, ``'Hdiv'``, ``'Hcurl'``, or any ``'Lp'`` with :math:`p >= 1`. + :kwarg condition: a UFL condition for specifying a subdomain to compute the norm + over + :kwarg boundary: should the norm be computed over the domain boundary? """ + if isinstance(v, firedrake.Cofunction): + v = cofunction2function(v) boundary = kwargs.get("boundary", False) condition = kwargs.get("condition", firedrake.Constant(1.0)) norm_codes = {"l1": 0, "l2": 2, "linf": 3} @@ -165,25 +188,24 @@ def norm(v: firedrake.Function, norm_type: str = "L2", **kwargs) -> float: return firedrake.assemble(condition * integrand ** (p / 2) * dX) ** (1 / p) -@PETSc.Log.EventDecorator("goalie.errornorm") -def errornorm(u, uh: firedrake.Function, norm_type: str = "L2", **kwargs) -> float: +@PETSc.Log.EventDecorator() +def errornorm(u, uh, norm_type="L2", **kwargs): r""" - Overload :func:`firedrake.norms.errornorm` - to allow for :math:`\ell^p` norms. + Overload :func:`firedrake.norms.errornorm` to allow for :math:`\ell^p` norms. - Note that this version is case sensitive, - i.e. ``'l2'`` and ``'L2'`` will give + Note that this version is case sensitive, i.e. ``'l2'`` and ``'L2'`` will give different results in general. :arg u: the 'true' value :arg uh: the approximation of the 'truth' - :kwarg norm_type: choose from ``'l1'``, ``'l2'``, - ``'linf'``, ``'L2'``, ``'Linf'``, ``'H1'``, - ``'Hdiv'``, ``'Hcurl'``, or any ``'Lp'`` with - :math:`p >= 1`. - :kwarg boundary: should the norm be computed over - the domain boundary? + :kwarg norm_type: choose from ``'l1'``, ``'l2'``, ``'linf'``, ``'L2'``, ``'Linf'``, + ``'H1'``, ``'Hdiv'``, ``'Hcurl'``, or any ``'Lp'`` with :math:`p >= 1`. + :kwarg boundary: should the norm be computed over the domain boundary? """ + if isinstance(u, firedrake.Cofunction): + u = cofunction2function(u) + if isinstance(uh, firedrake.Cofunction): + uh = cofunction2function(uh) if not isinstance(uh, firedrake.Function): raise TypeError(f"uh should be a Function, is a '{type(uh)}'.") if norm_type[0] == "l": From dea223c96572761d7b189e0ad222b3dee9491dd7 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Fri, 13 Oct 2023 09:23:53 +0100 Subject: [PATCH 13/23] Use Function-Cofunction mapping when projecting --- goalie/interpolation.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/goalie/interpolation.py b/goalie/interpolation.py index 2fbdebf9..167a0c04 100644 --- a/goalie/interpolation.py +++ b/goalie/interpolation.py @@ -1,7 +1,7 @@ """ Driver functions for mesh-to-mesh data transfer. """ -from .utility import assemble_mass_matrix +from .utility import assemble_mass_matrix, cofunction2function, function2cofunction import firedrake from firedrake.petsc import PETSc from petsc4py import PETSc as petsc4py @@ -29,8 +29,13 @@ def project( seek to project into :kwarg adjoint: apply the transposed projection operator? """ - if not isinstance(source, firedrake.Function): - raise NotImplementedError("Can only currently project Functions.") # TODO + if not isinstance(source, (firedrake.Function, firedrake.Cofunction)): + raise NotImplementedError( + "Can only currently project Functions and Cofunctions." + ) # TODO + adj_value = isinstance(source, firedrake.Cofunction) + if adj_value: + source = cofunction2function(source) Vs = source.function_space() if isinstance(target_space, firedrake.Function): target = target_space @@ -66,7 +71,10 @@ def project( ) # Apply projector - return (_project_adjoint if adjoint else _project)(source, target, **kwargs) + target = (_project_adjoint if adjoint else _project)(source, target, **kwargs) + if adj_value: + target = function2cofunction(target) + return target @PETSc.Log.EventDecorator("goalie.interpolation.project") From 73aa881a17808aa8e7814ff584e64a52f8c6778a Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Fri, 13 Oct 2023 09:24:15 +0100 Subject: [PATCH 14/23] Update error message in interpolation test --- test/test_interpolation.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test_interpolation.py b/test/test_interpolation.py index 1e0400fc..5d32bdc0 100644 --- a/test/test_interpolation.py +++ b/test/test_interpolation.py @@ -23,10 +23,9 @@ def sinusoid(self, source=True): def test_notimplemented_error(self): Vs = FunctionSpace(self.source_mesh, "CG", 1) Vt = FunctionSpace(self.target_mesh, "CG", 1) - source = Function(Vs) with self.assertRaises(NotImplementedError) as cm: - project(2 * source, Vt) - msg = "Can only currently project Functions." + project(2 * Function(Vs), Vt) + msg = "Can only currently project Functions and Cofunctions." self.assertEqual(str(cm.exception), msg) @parameterized.expand([[False], [True]]) From 6f276982bd8638ce48722dc29bb684e7ccd6e89d Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Fri, 13 Oct 2023 09:26:24 +0100 Subject: [PATCH 15/23] Use Cofunction in test_adjoint --- test_adjoint/test_adjoint.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test_adjoint/test_adjoint.py b/test_adjoint/test_adjoint.py index 57420e86..dbd078ef 100644 --- a/test_adjoint/test_adjoint.py +++ b/test_adjoint/test_adjoint.py @@ -160,7 +160,8 @@ def test_adjoint_same_mesh(problem, qoi_type, debug=False): adj_sols_expected[field] = solve_blocks[0].adj_sol.copy(deepcopy=True) if not steady: dep = mesh_seq._dependency(field, 0, solve_blocks[0]) - adj_values_expected[field] = Function(fs[0], val=dep.adj_value) + adj_values_expected[field] = Cofunction(fs[0].dual()) + adj_values_expected[field].assign(dep.adj_value) # Loop over having one or two subintervals for N in range(1, 2 if steady else 3): From 23b5e634b6a13997715d9bb34d41d4787d9c0b04 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Fri, 13 Oct 2023 09:35:13 +0100 Subject: [PATCH 16/23] Replace Constants with Functions from R-space --- demos/burgers-hessian.py | 7 ++++--- demos/burgers.py | 10 ++++++---- demos/burgers1.py | 7 ++++--- demos/burgers2.py | 7 ++++--- demos/burgers_ee.py | 7 ++++--- demos/burgers_oo.py | 10 ++++++---- demos/burgers_time_integrated.py | 12 ++++++----- demos/gray_scott.py | 11 +++++----- demos/gray_scott_split.py | 11 +++++----- demos/point_discharge2d-goal_oriented.py | 9 +++++++-- demos/point_discharge2d-hessian.py | 9 +++++++-- demos/point_discharge2d.py | 9 +++++++-- demos/solid_body_rotation.py | 8 ++++++-- test_adjoint/examples/burgers.py | 8 +++++--- test_adjoint/examples/point_discharge2d.py | 15 ++++++++------ test_adjoint/examples/point_discharge3d.py | 13 ++++++++---- test_adjoint/examples/steady_flow_past_cyl.py | 3 ++- test_adjoint/test_fp_iteration.py | 7 +++---- test_adjoint/test_utils.py | 20 ++++++++++++++----- 19 files changed, 117 insertions(+), 66 deletions(-) diff --git a/demos/burgers-hessian.py b/demos/burgers-hessian.py index b17874f1..32987c73 100644 --- a/demos/burgers-hessian.py +++ b/demos/burgers-hessian.py @@ -27,10 +27,11 @@ def get_form(mesh_seq): def form(index, solutions): u, u_ = solutions["u"] P = mesh_seq.time_partition - dt = Constant(P.timesteps[index]) - # Specify viscosity coefficient - nu = Constant(0.0001) + # Define constants + R = FunctionSpace(mesh_seq[index], "R", 0) + dt = Function(R).assign(P.timesteps[index]) + nu = Function(R).assign(0.0001) # Setup variational problem v = TestFunction(u.function_space()) diff --git a/demos/burgers.py b/demos/burgers.py index f0e8ca2f..6b8c88c8 100644 --- a/demos/burgers.py +++ b/demos/burgers.py @@ -50,17 +50,19 @@ def get_function_spaces(mesh): # # Timestepping information associated with a given subinterval # can be accessed via the :attr:`TimePartition` attribute of -# the :class:`MeshSeq`. :: +# the :class:`MeshSeq`. For technical reasons, we need to create a :class:`Function` +# in the `'R'` space (of real numbers) to hold constants. :: def get_form(mesh_seq): def form(index, solutions): u, u_ = solutions["u"] P = mesh_seq.time_partition - dt = Constant(P.timesteps[index]) - # Specify viscosity coefficient - nu = Constant(0.0001) + # Define constants + R = FunctionSpace(mesh_seq[index], "R", 0) + dt = Function(R).assign(P.timesteps[index]) + nu = Function(R).assign(0.0001) # Setup variational problem v = TestFunction(u.function_space()) diff --git a/demos/burgers1.py b/demos/burgers1.py index c66eba02..d077bf88 100644 --- a/demos/burgers1.py +++ b/demos/burgers1.py @@ -28,10 +28,11 @@ def get_form(mesh_seq): def form(index, solutions): u, u_ = solutions["u"] P = mesh_seq.time_partition - dt = Constant(P.timesteps[index]) - # Specify viscosity coefficient - nu = Constant(0.0001) + # Define constants + R = FunctionSpace(mesh_seq[index], "R", 0) + dt = Function(R).assign(P.timesteps[index]) + nu = Function(R).assign(0.0001) # Setup variational problem v = TestFunction(u.function_space()) diff --git a/demos/burgers2.py b/demos/burgers2.py index 9b8c298b..5215c9cd 100644 --- a/demos/burgers2.py +++ b/demos/burgers2.py @@ -28,10 +28,11 @@ def get_form(mesh_seq): def form(index, solutions): u, u_ = solutions["u"] P = mesh_seq.time_partition - dt = Constant(P.timesteps[index]) - # Specify viscosity coefficient - nu = Constant(0.0001) + # Define constants + R = FunctionSpace(mesh_seq[index], "R", 0) + dt = Function(R).assign(P.timesteps[index]) + nu = Function(R).assign(0.0001) # Setup variational problem v = TestFunction(u.function_space()) diff --git a/demos/burgers_ee.py b/demos/burgers_ee.py index c1903907..9cbf3724 100644 --- a/demos/burgers_ee.py +++ b/demos/burgers_ee.py @@ -43,10 +43,11 @@ def get_form(mesh_seq): def form(index, solutions): u, u_ = solutions["u"] P = mesh_seq.time_partition - dt = Constant(P.timesteps[index]) - # Specify viscosity coefficient - nu = Constant(0.0001) + # Define constants + R = FunctionSpace(mesh_seq[index], "R", 0) + dt = Function(R).assign(P.timesteps[index]) + nu = Function(R).assign(0.0001) # Setup variational problem v = TestFunction(u.function_space()) diff --git a/demos/burgers_oo.py b/demos/burgers_oo.py index f016b1fc..85aa21e8 100644 --- a/demos/burgers_oo.py +++ b/demos/burgers_oo.py @@ -33,10 +33,11 @@ def get_form(self): def form(index, solutions): u, u_ = solutions["u"] P = self.time_partition - dt = Constant(P.timesteps[index]) - # Specify viscosity coefficient - nu = Constant(0.0001) + # Define constants + R = FunctionSpace(mesh_seq[index], "R", 0) + dt = Function(R).assign(P.timesteps[index]) + nu = Function(R).assign(0.0001) # Setup variational problem v = TestFunction(u.function_space()) @@ -84,7 +85,8 @@ def get_initial_condition(self): @annotate_qoi def get_qoi(self, solutions, i): - dt = Constant(self.time_partition[i].timestep) + R = FunctionSpace(self[i], "R", 0) + dt = Function(R).assign(self.time_partition[i].timestep) def end_time_qoi(): u = solutions["u"] diff --git a/demos/burgers_time_integrated.py b/demos/burgers_time_integrated.py index 9dbecbc1..684460c5 100644 --- a/demos/burgers_time_integrated.py +++ b/demos/burgers_time_integrated.py @@ -23,10 +23,11 @@ def get_form(mesh_seq): def form(index, solutions): u, u_ = solutions["u"] P = mesh_seq.time_partition - dt = Constant(P.timesteps[index]) - # Specify viscosity coefficient - nu = Constant(0.0001) + # Define constants + R = FunctionSpace(mesh_seq[index], "R", 0) + dt = Function(R).assign(P.timesteps[index]) + nu = Function(R).assign(0.0001) # Setup variational problem v = TestFunction(u.function_space()) @@ -93,12 +94,13 @@ def solver(index, ic): # \;\mathrm dy\;\mathrm dt. # # Note that in this case we multiply by the timestep. -# It is wrapped in a :class:`Constant` to avoid +# It is wrapped in a :class:`Function` from `'R'` space to avoid # recompilation if the value is changed. :: def get_qoi(mesh_seq, solutions, i): - dt = Constant(mesh_seq.time_partition[i].timestep) + R = FunctionSpace(mesh_seq[i], "R", 0) + dt = Function(R).assign(mesh_seq.time_partition[i].timestep) def time_integrated_qoi(t): u = solutions["u"] diff --git a/demos/gray_scott.py b/demos/gray_scott.py index 4d4f344d..b9df29ea 100644 --- a/demos/gray_scott.py +++ b/demos/gray_scott.py @@ -57,11 +57,12 @@ def form(index, sols): psi_a, psi_b = TestFunctions(mesh_seq.function_spaces["ab"][index]) # Define constants - dt = Constant(mesh_seq.time_partition[index].timestep) - D_a = Constant(8.0e-05) - D_b = Constant(4.0e-05) - gamma = Constant(0.024) - kappa = Constant(0.06) + R = FunctionSpace(mesh_seq[index], "R", 0) + dt = Function(R).assign(mesh_seq.time_partition[index].timestep) + D_a = Function(R).assign(8.0e-05) + D_b = Function(R).assign(4.0e-05) + gamma = Function(R).assign(0.024) + kappa = Function(R).assign(0.06) # Write the two equations in variational form F = ( diff --git a/demos/gray_scott_split.py b/demos/gray_scott_split.py index f76f2495..f028bfad 100644 --- a/demos/gray_scott_split.py +++ b/demos/gray_scott_split.py @@ -57,11 +57,12 @@ def form(index, sols): psi_b = TestFunction(mesh_seq.function_spaces["b"][index]) # Define constants - dt = Constant(mesh_seq.time_partition[index].timestep) - D_a = Constant(8.0e-05) - D_b = Constant(4.0e-05) - gamma = Constant(0.024) - kappa = Constant(0.06) + R = FunctionSpace(mesh_seq[index], "R", 0) + dt = Function(R).assign(mesh_seq.time_partition[index].timestep) + D_a = Function(R).assign(8.0e-05) + D_b = Function(R).assign(4.0e-05) + gamma = Function(R).assign(0.024) + kappa = Function(R).assign(0.06) # Write the two equations in variational form F_a = ( diff --git a/demos/point_discharge2d-goal_oriented.py b/demos/point_discharge2d-goal_oriented.py index f6f942aa..f398c110 100644 --- a/demos/point_discharge2d-goal_oriented.py +++ b/demos/point_discharge2d-goal_oriented.py @@ -36,11 +36,16 @@ def get_form(mesh_seq): def form(index, sols): c, c_ = sols["c"] function_space = mesh_seq.function_spaces["c"][index] - D = Constant(0.1) - u = Constant(as_vector([1, 0])) h = CellSize(mesh_seq[index]) S = source(mesh_seq[index]) + # Define constants + R = FunctionSpace(mesh_seq[index], "R", 0) + D = Function(R).assign(0.1) + u_x = Function(R).assign(1.0) + u_y = Function(R).assign(0.0) + u = as_vector([u_x, u_y]) + # SUPG stabilisation parameter unorm = sqrt(dot(u, u)) tau = 0.5 * h / unorm diff --git a/demos/point_discharge2d-hessian.py b/demos/point_discharge2d-hessian.py index 7b3ef3ba..ee0c0b9e 100644 --- a/demos/point_discharge2d-hessian.py +++ b/demos/point_discharge2d-hessian.py @@ -40,11 +40,16 @@ def get_form(mesh_seq): def form(index, sols): c, c_ = sols["c"] function_space = mesh_seq.function_spaces["c"][index] - D = Constant(0.1) - u = Constant(as_vector([1, 0])) h = CellSize(mesh_seq[index]) S = source(mesh_seq[index]) + # Define constants + R = FunctionSpace(mesh_seq[index], "R", 0) + D = Function(R).assign(0.1) + u_x = Function(R).assign(1.0) + u_y = Function(R).assign(0.0) + u = as_vector([u_x, u_y]) + # SUPG stabilisation parameter unorm = sqrt(dot(u, u)) tau = 0.5 * h / unorm diff --git a/demos/point_discharge2d.py b/demos/point_discharge2d.py index d053d3e9..32e85378 100644 --- a/demos/point_discharge2d.py +++ b/demos/point_discharge2d.py @@ -71,11 +71,16 @@ def get_form(mesh_seq): def form(index, sols): c, c_ = sols["c"] function_space = mesh_seq.function_spaces["c"][index] - D = Constant(0.1) - u = Constant(as_vector([1, 0])) h = CellSize(mesh_seq[index]) S = source(mesh_seq[index]) + # Define constants + R = FunctionSpace(mesh_seq[index], "R", 0) + D = Function(R).assign(0.1) + u_x = Function(R).assign(1.0) + u_y = Function(R).assign(0.0) + u = as_vector([u_x, u_y]) + # SUPG stabilisation parameter unorm = sqrt(dot(u, u)) tau = 0.5 * h / unorm diff --git a/demos/solid_body_rotation.py b/demos/solid_body_rotation.py index e8f656fc..52e15639 100644 --- a/demos/solid_body_rotation.py +++ b/demos/solid_body_rotation.py @@ -151,10 +151,14 @@ def form(index, sols, field="c"): V = mesh_seq.function_spaces[field][index] mesh = mesh_seq[index] + # Define velocity field x, y = SpatialCoordinate(mesh) u = as_vector([-y, x]) - dt = Constant(mesh_seq.time_partition[index].timestep) - theta = Constant(0.5) + + # Define constants + R = FunctionSpace(mesh_seq[index], "R", 0) + dt = Function(R).assign(mesh_seq.time_partition[index].timestep) + theta = Function(R).assign(0.5) psi = TrialFunction(V) phi = TestFunction(V) diff --git a/test_adjoint/examples/burgers.py b/test_adjoint/examples/burgers.py index e9630a59..fea98b83 100644 --- a/test_adjoint/examples/burgers.py +++ b/test_adjoint/examples/burgers.py @@ -39,8 +39,9 @@ def form(i, sols): u, u_ = sols["uv_2d"] dt = self.time_partition[i].timestep fs = self.function_spaces["uv_2d"][i] - dtc = Constant(dt) - nu = Constant(0.0001) + R = FunctionSpace(self[i], "R", 0) + dtc = Function(R).assign(dt) + nu = Function(R).assign(0.0001) v = TestFunction(fs) F = ( inner((u - u_) / dtc, v) * dx @@ -104,7 +105,8 @@ def get_qoi(self, sol, i): norm over the right hand boundary. """ - dtc = Constant(self.time_partition[i].timestep) + R = FunctionSpace(self[i], "R", 0) + dtc = Function(R).assign(self.time_partition[i].timestep) def time_integrated_qoi(t): u = sol["uv_2d"] diff --git a/test_adjoint/examples/point_discharge2d.py b/test_adjoint/examples/point_discharge2d.py index e8492bbe..57d91098 100644 --- a/test_adjoint/examples/point_discharge2d.py +++ b/test_adjoint/examples/point_discharge2d.py @@ -52,15 +52,17 @@ def source(mesh): def get_form(self): """ - Advection-diffusion with SUPG - stabilisation. + Advection-diffusion with SUPG stabilisation. """ def form(i, sols): c, c_ = sols["tracer_2d"] fs = self.function_spaces["tracer_2d"][i] - D = Constant(0.1) - u = Constant(as_vector([1.0, 0.0])) + R = FunctionSpace(self[i], "R", 0) + D = Function(R).assign(0.1) + u_x = Function(R).assign(1.0) + u_y = Function(R).assign(0.0) + u = as_vector([u_x, u_y]) h = CellSize(self[i]) S = source(self[i]) @@ -151,8 +153,9 @@ def analytical_solution(mesh): a given mesh. See [Riadh et al. 2014]. """ x, y = SpatialCoordinate(mesh) - u = Constant(1.0) - D = Constant(0.1) + R = FunctionSpace(mesh, "R", 0) + u = Function(R).assign(1.0) + D = Function(R).assign(0.1) Pe = 0.5 * u / D r = max_value(sqrt((x - src_x) ** 2 + (y - src_y) ** 2), src_r) return 0.5 / (pi * D) * exp(Pe * (x - src_x)) * bessk0(Pe * r) diff --git a/test_adjoint/examples/point_discharge3d.py b/test_adjoint/examples/point_discharge3d.py index 20f893bd..e9124f31 100644 --- a/test_adjoint/examples/point_discharge3d.py +++ b/test_adjoint/examples/point_discharge3d.py @@ -70,8 +70,12 @@ def get_form(self): def form(i, sols): c, c_ = sols["tracer_3d"] fs = self.function_spaces["tracer_3d"][i] - D = Constant(0.1) - u = Constant(as_vector([1.0, 0.0, 0.0])) + R = FunctionSpace(self[i], "R", 0) + D = Function(R).assign(0.1) + u_x = Function(R).assign(1.0) + u_y = Function(R).assign(0.0) + u_z = Function(R).assign(0.0) + u = as_vector([u_x, u_y, u_z]) h = CellSize(self[i]) S = source(self[i]) @@ -164,8 +168,9 @@ def analytical_solution(mesh): a given mesh. """ x, y, z = SpatialCoordinate(mesh) - u = Constant(1.0) - D = Constant(0.1) + R = FunctionSpace(mesh, "R", 0) + u = Function(R).assign(1.0) + D = Function(R).assign(0.1) Pe = 0.5 * u / D r = max_value(sqrt((x - src_x) ** 2 + (y - src_y) ** 2 + (z - src_z) ** 2), src_r) return 0.5 / (pi * D) * exp(Pe * (x - src_x)) * bessk0(Pe * r) diff --git a/test_adjoint/examples/steady_flow_past_cyl.py b/test_adjoint/examples/steady_flow_past_cyl.py index 927793a3..4f77ec59 100644 --- a/test_adjoint/examples/steady_flow_past_cyl.py +++ b/test_adjoint/examples/steady_flow_past_cyl.py @@ -38,7 +38,8 @@ def get_form(self): def form(i, sols): up, up_ = sols["up"] W = self.function_spaces["up"][i] - nu = Constant(1.0) + R = FunctionSpace(self[i], "R", 0) + nu = Function(R).assign(1.0) u, p = split(up) v, q = TestFunctions(W) F = ( diff --git a/test_adjoint/test_fp_iteration.py b/test_adjoint/test_fp_iteration.py index c8115cb0..f3147722 100644 --- a/test_adjoint/test_fp_iteration.py +++ b/test_adjoint/test_fp_iteration.py @@ -30,11 +30,10 @@ def solver(index, ic): def get_qoi(mesh_seq, solutions, index): + R = FunctionSpace(mesh_seq[index], "R", 0) + def qoi(): - if mesh_seq.fp_iteration % 2 == 0: - return Constant(1, domain=mesh_seq[index]) * dx - else: - return Constant(2, domain=mesh_seq[index]) * dx + return Function(R).assign(1 if mesh_seq.fp_iteration % 2 == 0 else 2) * dx return qoi diff --git a/test_adjoint/test_utils.py b/test_adjoint/test_utils.py index 60aa1d6c..bbe65431 100644 --- a/test_adjoint/test_utils.py +++ b/test_adjoint/test_utils.py @@ -21,8 +21,10 @@ def mesh_seq(self, qoi_type="end_time"): def test_annotate_qoi_0args(self): @annotate_qoi def get_qoi(mesh_seq, solution_map, i): + R = FunctionSpace(mesh_seq[i], "R", 0) + def qoi(): - return Constant(1.0, domain=mesh_seq[i]) * dx + return Function(R).assign(1) * dx return qoi @@ -31,8 +33,10 @@ def qoi(): def test_annotate_qoi_1arg(self): @annotate_qoi def get_qoi(mesh_seq, solution_map, i): + R = FunctionSpace(mesh_seq[i], "R", 0) + def qoi(t): - return Constant(1.0, domain=mesh_seq[i]) * dx + return Function(R).assign(1) * dx return qoi @@ -41,8 +45,10 @@ def qoi(t): def test_annotate_qoi_0args_error(self): @annotate_qoi def get_qoi(mesh_seq, solution_map, i): + R = FunctionSpace(mesh_seq[i], "R", 0) + def qoi(): - return Constant(1.0, domain=mesh_seq[i]) * dx + return Function(R).assign(1) * dx return qoi @@ -54,8 +60,10 @@ def qoi(): def test_annotate_qoi_1arg_error(self): @annotate_qoi def get_qoi(mesh_seq, solution_map, i): + R = FunctionSpace(mesh_seq[i], "R", 0) + def qoi(t): - return Constant(1.0, domain=mesh_seq[i]) * dx + return Function(R).assign(1) * dx return qoi @@ -67,8 +75,10 @@ def qoi(t): def test_annotate_qoi_2args_error(self): @annotate_qoi def get_qoi(mesh_seq, solution_map, i): + R = FunctionSpace(mesh_seq[i], "R", 0) + def qoi(t, r): - return Constant(1.0, domain=mesh_seq[i]) * dx + return Function(R).assign(1) * dx return qoi From 63a620ebe4dddb28a559aab270b0cb7ad6dbb07d Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Sat, 14 Oct 2023 09:05:22 +0100 Subject: [PATCH 17/23] Rework project for using Cofunction --- goalie/interpolation.py | 72 +++++++++++++++++++++++------------------ 1 file changed, 40 insertions(+), 32 deletions(-) diff --git a/goalie/interpolation.py b/goalie/interpolation.py index 167a0c04..77486e96 100644 --- a/goalie/interpolation.py +++ b/goalie/interpolation.py @@ -3,6 +3,7 @@ """ from .utility import assemble_mass_matrix, cofunction2function, function2cofunction import firedrake +from firedrake.functionspaceimpl import WithGeometry from firedrake.petsc import PETSc from petsc4py import PETSc as petsc4py from typing import Union @@ -12,66 +13,73 @@ __all__ = ["project"] -def project( - source: firedrake.Function, - target_space: Union[firedrake.FunctionSpace, firedrake.Function], - adjoint: bool = False, - **kwargs, -) -> firedrake.Function: +def project(source, target_space, adjoint=False, **kwargs): """ Overload :func:`firedrake.projection.project` to account for the case of two mixed function spaces defined on different meshes and for the adjoint projection operator. Extra keyword arguments are passed to :func:`firedrake.projection.project`. - :arg source: the :class:`firedrake.function.Function` to be projected + :arg source: the :class:`firedrake.function.Function` or + :class:`firedrake.cofunction.Cofunction` to be projected :arg target_space: the :class:`firedrake.functionspaceimpl.FunctionSpace` which we - seek to project into + seek to project into, or the :class:`firedrake.function.Function` or + :class:`firedrake.cofunction.Cofunction` to use as the target :kwarg adjoint: apply the transposed projection operator? """ if not isinstance(source, (firedrake.Function, firedrake.Cofunction)): raise NotImplementedError( "Can only currently project Functions and Cofunctions." - ) # TODO + ) + + # If the input is a Cofunction then record this, map to a Function for the + # projection and then map back to a Cofunction afterwards adj_value = isinstance(source, firedrake.Cofunction) if adj_value: source = cofunction2function(source) Vs = source.function_space() - if isinstance(target_space, firedrake.Function): + + # Account for cases where target_space is not a FunctionSpace + if isinstance(target_space, firedrake.WithGeometry): + target = firedrake.Function(target_space) + elif isinstance(target_space, firedrake.Function): target = target_space - Vt = target.function_space() + elif isinstance(target_space, firedrake.Cofunction): + target = cofunction2function(target_space) else: - Vt = target_space - target = firedrake.Function(Vt) + raise TypeError( + "Second argument must be a FunctionSpace, Function, or Cofunction." + ) + Vt = target.function_space() # Account for the case where the meshes match source_mesh = ufl.domain.extract_unique_domain(source) target_mesh = ufl.domain.extract_unique_domain(target) if source_mesh == target_mesh: if Vs == Vt: - return target.assign(source) + target.assign(source) elif not adjoint: - return target.project(source, **kwargs) - - # Check validity of function spaces - space1, space2 = ("target", "source") if adjoint else ("source", "target") - if hasattr(Vs, "num_sub_spaces"): - if not hasattr(Vt, "num_sub_spaces"): - raise ValueError( - f"{space1} space has multiple components but {space2} space does not.".capitalize() - ) - if Vs.num_sub_spaces() != Vt.num_sub_spaces(): + target.project(source, **kwargs) + else: + space1, space2 = ("target", "source") if adjoint else ("source", "target") + if hasattr(Vs, "num_sub_spaces"): + if not hasattr(Vt, "num_sub_spaces"): + raise ValueError( + f"{space1.capitalize()} space has multiple components but {space2}" + " space does not." + ) + if Vs.num_sub_spaces() != Vt.num_sub_spaces(): + raise ValueError( + f"Inconsistent numbers of components in {space1} and {space2}" + f" spaces: {Vs.num_sub_spaces()} vs. {Vt.num_sub_spaces()}." + ) + elif hasattr(Vt, "num_sub_spaces"): raise ValueError( - f"Inconsistent numbers of components in {space1} and {space2} spaces:" - f" {Vs.num_sub_spaces()} vs. {Vt.num_sub_spaces()}." + f"{space2.capitalize()} space has multiple components but {space1}" + " space does not." ) - elif hasattr(Vt, "num_sub_spaces"): - raise ValueError( - f"{space2} space has multiple components but {space1} space does not.".capitalize() - ) + target = (_project_adjoint if adjoint else _project)(source, target, **kwargs) - # Apply projector - target = (_project_adjoint if adjoint else _project)(source, target, **kwargs) if adj_value: target = function2cofunction(target) return target From 779c5cd74f6b2257908d38fa9ba7fd1c4a401b5a Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Sat, 14 Oct 2023 09:37:30 +0100 Subject: [PATCH 18/23] Remove adjoint kwarg of project; deduce from input --- goalie/interpolation.py | 81 ++++++++++++++------------- test/test_interpolation.py | 110 +++++++++++++++++++++++++++---------- 2 files changed, 126 insertions(+), 65 deletions(-) diff --git a/goalie/interpolation.py b/goalie/interpolation.py index 77486e96..33daca38 100644 --- a/goalie/interpolation.py +++ b/goalie/interpolation.py @@ -13,10 +13,10 @@ __all__ = ["project"] -def project(source, target_space, adjoint=False, **kwargs): +def project(source, target_space, **kwargs): """ Overload :func:`firedrake.projection.project` to account for the case of two mixed - function spaces defined on different meshes and for the adjoint projection operator. + function spaces defined on different meshes and for the adjoint projection operator when applied to :class:`firedrake.cofunction.Cofunction`\s. Extra keyword arguments are passed to :func:`firedrake.projection.project`. @@ -25,7 +25,6 @@ def project(source, target_space, adjoint=False, **kwargs): :arg target_space: the :class:`firedrake.functionspaceimpl.FunctionSpace` which we seek to project into, or the :class:`firedrake.function.Function` or :class:`firedrake.cofunction.Cofunction` to use as the target - :kwarg adjoint: apply the transposed projection operator? """ if not isinstance(source, (firedrake.Function, firedrake.Cofunction)): raise NotImplementedError( @@ -34,13 +33,12 @@ def project(source, target_space, adjoint=False, **kwargs): # If the input is a Cofunction then record this, map to a Function for the # projection and then map back to a Cofunction afterwards - adj_value = isinstance(source, firedrake.Cofunction) - if adj_value: + adjoint = isinstance(source, firedrake.Cofunction) + if adjoint: source = cofunction2function(source) - Vs = source.function_space() # Account for cases where target_space is not a FunctionSpace - if isinstance(target_space, firedrake.WithGeometry): + if isinstance(target_space, WithGeometry): target = firedrake.Function(target_space) elif isinstance(target_space, firedrake.Function): target = target_space @@ -50,37 +48,17 @@ def project(source, target_space, adjoint=False, **kwargs): raise TypeError( "Second argument must be a FunctionSpace, Function, or Cofunction." ) - Vt = target.function_space() - # Account for the case where the meshes match - source_mesh = ufl.domain.extract_unique_domain(source) - target_mesh = ufl.domain.extract_unique_domain(target) - if source_mesh == target_mesh: - if Vs == Vt: - target.assign(source) - elif not adjoint: - target.project(source, **kwargs) + # Choose appropriate transfer method + if source.function_space() == target.function_space(): + target.assign(source) + elif adjoint: + target = _project_adjoint(source, target, **kwargs) else: - space1, space2 = ("target", "source") if adjoint else ("source", "target") - if hasattr(Vs, "num_sub_spaces"): - if not hasattr(Vt, "num_sub_spaces"): - raise ValueError( - f"{space1.capitalize()} space has multiple components but {space2}" - " space does not." - ) - if Vs.num_sub_spaces() != Vt.num_sub_spaces(): - raise ValueError( - f"Inconsistent numbers of components in {space1} and {space2}" - f" spaces: {Vs.num_sub_spaces()} vs. {Vt.num_sub_spaces()}." - ) - elif hasattr(Vt, "num_sub_spaces"): - raise ValueError( - f"{space2.capitalize()} space has multiple components but {space1}" - " space does not." - ) - target = (_project_adjoint if adjoint else _project)(source, target, **kwargs) + target = _project(source, target, **kwargs) - if adj_value: + # Map back to Cofunction in the adjoint case + if adjoint: target = function2cofunction(target) return target @@ -102,9 +80,24 @@ def _project( :arg source: the `Function` to be projected :arg target: the `Function` which we seek to project onto """ + Vs = source.function_space() + Vt = target.function_space() + if hasattr(Vs, "num_sub_spaces"): + if not hasattr(Vt, "num_sub_spaces"): + raise ValueError( + "Source space has multiple components but target space does not." + ) + if Vs.num_sub_spaces() != Vt.num_sub_spaces(): + raise ValueError( + "Inconsistent numbers of components in source and target spaces:" + f" {Vs.num_sub_spaces()} vs. {Vt.num_sub_spaces()}." + ) + elif hasattr(Vt, "num_sub_spaces"): + raise ValueError( + "Target space has multiple components but source space does not." + ) assert isinstance(target, firedrake.Function) - if hasattr(target.function_space(), "num_sub_spaces"): - assert hasattr(source.function_space(), "num_sub_spaces") + if hasattr(Vt, "num_sub_spaces"): for s, t in zip(source.subfunctions, target.subfunctions): t.project(s, **kwargs) else: @@ -137,6 +130,20 @@ def _project_adjoint( Vt = target_b.function_space() assert isinstance(source_b, firedrake.Function) Vs = source_b.function_space() + if hasattr(Vs, "num_sub_spaces"): + if not hasattr(Vt, "num_sub_spaces"): + raise ValueError( + "Source space has multiple components but target space does not." + ) + if Vs.num_sub_spaces() != Vt.num_sub_spaces(): + raise ValueError( + "Inconsistent numbers of components in target and source spaces:" + f" {Vs.num_sub_spaces()} vs. {Vt.num_sub_spaces()}." + ) + elif hasattr(Vt, "num_sub_spaces"): + raise ValueError( + "Target space has multiple components but source space does not." + ) # Get subspaces if hasattr(Vs, "num_sub_spaces"): diff --git a/test/test_interpolation.py b/test/test_interpolation.py index 5d32bdc0..4ff91b8b 100644 --- a/test/test_interpolation.py +++ b/test/test_interpolation.py @@ -3,6 +3,7 @@ """ from firedrake import * from goalie import * +from goalie.utility import function2cofunction from parameterized import parameterized import unittest @@ -28,30 +29,40 @@ def test_notimplemented_error(self): msg = "Can only currently project Functions and Cofunctions." self.assertEqual(str(cm.exception), msg) - @parameterized.expand([[False], [True]]) - def test_no_sub_source_space(self, adjoint): + def test_no_sub_source_space(self): Vs = FunctionSpace(self.source_mesh, "CG", 1) Vt = FunctionSpace(self.target_mesh, "CG", 1) Vt = Vt * Vt with self.assertRaises(ValueError) as cm: - project(Function(Vs), Function(Vt), adjoint=adjoint) - if adjoint: - msg = "Source space has multiple components but target space does not." - else: - msg = "Target space has multiple components but source space does not." + project(Function(Vs), Function(Vt)) + msg = "Target space has multiple components but source space does not." + self.assertEqual(str(cm.exception), msg) + + def test_no_sub_source_space_adjoint(self): + Vs = FunctionSpace(self.source_mesh, "CG", 1) + Vt = FunctionSpace(self.target_mesh, "CG", 1) + Vt = Vt * Vt + with self.assertRaises(ValueError) as cm: + project(Cofunction(Vs.dual()), Cofunction(Vt.dual())) + msg = "Source space has multiple components but target space does not." + self.assertEqual(str(cm.exception), msg) + + def test_no_sub_target_space(self): + Vs = FunctionSpace(self.source_mesh, "CG", 1) + Vt = FunctionSpace(self.target_mesh, "CG", 1) + Vs = Vs * Vs + with self.assertRaises(ValueError) as cm: + project(Function(Vs), Function(Vt)) + msg = "Source space has multiple components but target space does not." self.assertEqual(str(cm.exception), msg) - @parameterized.expand([[False], [True]]) - def test_no_sub_target_space(self, adjoint): + def test_no_sub_target_space_adjoint(self): Vs = FunctionSpace(self.source_mesh, "CG", 1) Vt = FunctionSpace(self.target_mesh, "CG", 1) Vs = Vs * Vs with self.assertRaises(ValueError) as cm: - project(Function(Vs), Function(Vt), adjoint=adjoint) - if adjoint: - msg = "Target space has multiple components but source space does not." - else: - msg = "Source space has multiple components but target space does not." + project(Cofunction(Vs.dual()), Cofunction(Vt.dual())) + msg = "Target space has multiple components but source space does not." self.assertEqual(str(cm.exception), msg) def test_wrong_number_sub_spaces(self): @@ -64,17 +75,24 @@ def test_wrong_number_sub_spaces(self): msg = "Inconsistent numbers of components in source and target spaces: 3 vs. 2." self.assertEqual(str(cm.exception), msg) - @parameterized.expand([[False], [True]]) - def test_project_same_space(self, adjoint): + def test_project_same_space(self): Vs = FunctionSpace(self.source_mesh, "CG", 1) source = interpolate(self.sinusoid(), Vs) target = Function(Vs) - project(source, target, adjoint=adjoint) + project(source, target) + expected = source + self.assertAlmostEqual(errornorm(expected, target), 0) + + def test_project_same_space_adjoint(self): + Vs = FunctionSpace(self.source_mesh, "CG", 1) + source = interpolate(self.sinusoid(), Vs) + source = function2cofunction(source) + target = Cofunction(Vs.dual()) + project(source, target) expected = source self.assertAlmostEqual(errornorm(expected, target), 0) - @parameterized.expand([[False], [True]]) - def test_project_same_space_mixed(self, adjoint): + def test_project_same_space_mixed(self): P1 = FunctionSpace(self.source_mesh, "CG", 1) Vs = P1 * P1 source = Function(Vs) @@ -82,23 +100,42 @@ def test_project_same_space_mixed(self, adjoint): s1.interpolate(self.sinusoid()) s2.interpolate(-self.sinusoid()) target = Function(Vs) - project(source, target, adjoint=adjoint) + project(source, target) + expected = source + self.assertAlmostEqual(errornorm(expected, target), 0) + + def test_project_same_space_mixed_adjoint(self): + P1 = FunctionSpace(self.source_mesh, "CG", 1) + Vs = P1 * P1 + source = Function(Vs) + s1, s2 = source.subfunctions + s1.interpolate(self.sinusoid()) + s2.interpolate(-self.sinusoid()) + source = function2cofunction(source) + target = Cofunction(Vs.dual()) + project(source, target) expected = source self.assertAlmostEqual(errornorm(expected, target), 0) - @parameterized.expand([[False], [True]]) - def test_project_same_mesh(self, adjoint): + def test_project_same_mesh(self): Vs = FunctionSpace(self.source_mesh, "CG", 1) Vt = FunctionSpace(self.source_mesh, "DG", 0) source = interpolate(self.sinusoid(), Vs) target = Function(Vt) - project(source, target, adjoint=adjoint) - expected = Function(Vt) - expected.project(source) + project(source, target) + expected = Function(Vt).project(source) + self.assertAlmostEqual(errornorm(expected, target), 0) + + def test_project_same_mesh_adjoint(self): + Vs = FunctionSpace(self.source_mesh, "CG", 1) + Vt = FunctionSpace(self.source_mesh, "DG", 0) + source = interpolate(self.sinusoid(), Vs) + target = Cofunction(Vt.dual()) + project(function2cofunction(source), target) + expected = function2cofunction(Function(Vt).project(source)) self.assertAlmostEqual(errornorm(expected, target), 0) - @parameterized.expand([[False], [True]]) - def test_project_same_mesh_mixed(self, adjoint): + def test_project_same_mesh_mixed(self): P1 = FunctionSpace(self.source_mesh, "CG", 1) P0 = FunctionSpace(self.source_mesh, "DG", 0) Vs = P1 * P1 @@ -108,7 +145,24 @@ def test_project_same_mesh_mixed(self, adjoint): s1.interpolate(self.sinusoid()) s2.interpolate(-self.sinusoid()) target = Function(Vt) - project(source, target, adjoint=adjoint) + project(source, target) + expected = Function(Vt) + e1, e2 = expected.subfunctions + e1.project(s1) + e2.project(s2) + self.assertAlmostEqual(errornorm(expected, target), 0) + + def test_project_same_mesh_mixed_adjoint(self): + P1 = FunctionSpace(self.source_mesh, "CG", 1) + P0 = FunctionSpace(self.source_mesh, "DG", 0) + Vs = P1 * P1 + Vt = P0 * P0 + source = Function(Vs) + s1, s2 = source.subfunctions + s1.interpolate(self.sinusoid()) + s2.interpolate(-self.sinusoid()) + target = Cofunction(Vt.dual()) + project(function2cofunction(source), target) expected = Function(Vt) e1, e2 = expected.subfunctions e1.project(s1) From 096c4c651e6c7e65300ecc1de1e5c5c716603348 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Sat, 14 Oct 2023 09:42:50 +0100 Subject: [PATCH 19/23] Lint --- goalie/interpolation.py | 4 +--- test/test_interpolation.py | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/goalie/interpolation.py b/goalie/interpolation.py index 33daca38..8bd5e2f2 100644 --- a/goalie/interpolation.py +++ b/goalie/interpolation.py @@ -6,15 +6,13 @@ from firedrake.functionspaceimpl import WithGeometry from firedrake.petsc import PETSc from petsc4py import PETSc as petsc4py -from typing import Union -import ufl __all__ = ["project"] def project(source, target_space, **kwargs): - """ + r""" Overload :func:`firedrake.projection.project` to account for the case of two mixed function spaces defined on different meshes and for the adjoint projection operator when applied to :class:`firedrake.cofunction.Cofunction`\s. diff --git a/test/test_interpolation.py b/test/test_interpolation.py index 4ff91b8b..53473ad7 100644 --- a/test/test_interpolation.py +++ b/test/test_interpolation.py @@ -4,7 +4,6 @@ from firedrake import * from goalie import * from goalie.utility import function2cofunction -from parameterized import parameterized import unittest From 56f4515166177fb26e743eb97e257477eec25c37 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Mon, 23 Oct 2023 08:08:21 +0100 Subject: [PATCH 20/23] Only apply adjoint projection to Cofunctions --- goalie/interpolation.py | 91 +++++++++++++++++------------------------ 1 file changed, 38 insertions(+), 53 deletions(-) diff --git a/goalie/interpolation.py b/goalie/interpolation.py index 8bd5e2f2..46c65bc2 100644 --- a/goalie/interpolation.py +++ b/goalie/interpolation.py @@ -14,7 +14,8 @@ def project(source, target_space, **kwargs): r""" Overload :func:`firedrake.projection.project` to account for the case of two mixed - function spaces defined on different meshes and for the adjoint projection operator when applied to :class:`firedrake.cofunction.Cofunction`\s. + function spaces defined on different meshes and for the adjoint projection operator + when applied to :class:`firedrake.cofunction.Cofunction`\s. Extra keyword arguments are passed to :func:`firedrake.projection.project`. @@ -28,43 +29,24 @@ def project(source, target_space, **kwargs): raise NotImplementedError( "Can only currently project Functions and Cofunctions." ) - - # If the input is a Cofunction then record this, map to a Function for the - # projection and then map back to a Cofunction afterwards - adjoint = isinstance(source, firedrake.Cofunction) - if adjoint: - source = cofunction2function(source) - - # Account for cases where target_space is not a FunctionSpace if isinstance(target_space, WithGeometry): target = firedrake.Function(target_space) - elif isinstance(target_space, firedrake.Function): + elif isinstance(target_space, (firedrake.Cofunction, firedrake.Function)): target = target_space - elif isinstance(target_space, firedrake.Cofunction): - target = cofunction2function(target_space) else: raise TypeError( "Second argument must be a FunctionSpace, Function, or Cofunction." ) - - # Choose appropriate transfer method - if source.function_space() == target.function_space(): - target.assign(source) - elif adjoint: - target = _project_adjoint(source, target, **kwargs) + if isinstance(source, firedrake.Cofunction): + return _project_adjoint(source, target, **kwargs) + elif source.function_space() == target.function_space(): + return target.assign(source) else: - target = _project(source, target, **kwargs) - - # Map back to Cofunction in the adjoint case - if adjoint: - target = function2cofunction(target) - return target + return _project(source, target, **kwargs) @PETSc.Log.EventDecorator("goalie.interpolation.project") -def _project( - source: firedrake.Function, target: firedrake.Function, **kwargs -) -> firedrake.Function: +def _project(source, target, **kwargs): """ Apply a mesh-to-mesh conservative projection to some source :class:`firedrake.function.Function`, mapping to a target @@ -104,13 +86,11 @@ def _project( @PETSc.Log.EventDecorator("goalie.interpolation.project_adjoint") -def _project_adjoint( - target_b: firedrake.Function, source_b: firedrake.Function, **kwargs -) -> firedrake.Function: +def _project_adjoint(target_b, source_b, **kwargs): """ Apply the adjoint of a mesh-to-mesh conservative projection to some seed - :class:`firedrake.function.Function`, mapping to an output - :class:`firedrake.function.Function`. + :class:`firedrake.cofunction.Cofunction`, mapping to an output + :class:`firedrake.cofunction.Cofunction`. The notation used here is in terms of the adjoint of standard projection. However, this function may also be interpreted as a projector in its own right, @@ -118,15 +98,20 @@ def _project_adjoint( Extra keyword arguments are passed to :func:`firedrake.projection.project`. - :arg target_b: seed :class:`firedrake.function.Function` from the target space of - the forward projection - :arg source_b: the :class:`firedrake.function.Function` from the source space of - the forward projection + :arg target_b: seed :class:`firedrake.cofunction.Cofunction` from the target space + of the forward projection + :arg source_b: the :class:`firedrake.cofunction.Cofunction` from the source space + of the forward projection """ from firedrake.supermeshing import assemble_mixed_mass_matrix + # Map to Functions to apply the adjoint projection + if not isinstance(target_b, firedrake.Function): + target_b = cofunction2function(target_b) + if not isinstance(source_b, firedrake.Function): + source_b = cofunction2function(source_b) + Vt = target_b.function_space() - assert isinstance(source_b, firedrake.Function) Vs = source_b.function_space() if hasattr(Vs, "num_sub_spaces"): if not hasattr(Vt, "num_sub_spaces"): @@ -138,28 +123,28 @@ def _project_adjoint( "Inconsistent numbers of components in target and source spaces:" f" {Vs.num_sub_spaces()} vs. {Vt.num_sub_spaces()}." ) + target_b_split = target_b.subfunctions + source_b_split = source_b.subfunctions elif hasattr(Vt, "num_sub_spaces"): raise ValueError( "Target space has multiple components but source space does not." ) - - # Get subspaces - if hasattr(Vs, "num_sub_spaces"): - assert hasattr(Vt, "num_sub_spaces") - target_b_split = target_b.subfunctions - source_b_split = source_b.subfunctions else: target_b_split = [target_b] source_b_split = [source_b] # Apply adjoint projection operator to each component - for i, (t_b, s_b) in enumerate(zip(target_b_split, source_b_split)): - ksp = petsc4py.KSP().create() - ksp.setOperators(assemble_mass_matrix(t_b.function_space())) - mixed_mass = assemble_mixed_mass_matrix(Vt[i], Vs[i]) - with t_b.dat.vec_ro as tb, s_b.dat.vec_wo as sb: - residual = tb.copy() - ksp.solveTranspose(tb, residual) - mixed_mass.mult(residual, sb) # NOTE: mixed mass already transposed - - return source_b + if Vs == Vt: + source_b.assign(target_b) + else: + for i, (t_b, s_b) in enumerate(zip(target_b_split, source_b_split)): + ksp = petsc4py.KSP().create() + ksp.setOperators(assemble_mass_matrix(t_b.function_space())) + mixed_mass = assemble_mixed_mass_matrix(Vt[i], Vs[i]) + with t_b.dat.vec_ro as tb, s_b.dat.vec_wo as sb: + residual = tb.copy() + ksp.solveTranspose(tb, residual) + mixed_mass.mult(residual, sb) # NOTE: already transposed above + + # Map back to a Cofunction + return function2cofunction(source_b) From 4168902ba24684ee3d120cd4d4254721dc57bceb Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Mon, 23 Oct 2023 08:42:52 +0100 Subject: [PATCH 21/23] Combine solves in error indicator --- goalie/error_estimation.py | 44 +++++++++++++++----------------------- 1 file changed, 17 insertions(+), 27 deletions(-) diff --git a/goalie/error_estimation.py b/goalie/error_estimation.py index 654b138d..8c29dcc3 100644 --- a/goalie/error_estimation.py +++ b/goalie/error_estimation.py @@ -32,42 +32,32 @@ def form2indicator(F: ufl.form.Form) -> Function: mesh = F.ufl_domain() P0 = FunctionSpace(mesh, "DG", 0) p0test = firedrake.TestFunction(P0) - indicator = Function(P0) - mass_term = firedrake.TrialFunction(P0) * p0test * firedrake.dx + f = ufl.FacetArea(mesh) + h = ufl.CellVolume(mesh) - # Contributions from surface integrals - flux_terms = 0 + rhs = 0 for integral in F.integrals_by_type("exterior_facet"): ds = firedrake.ds(integral.subdomain_id()) - flux_terms += p0test * integral.integrand() * ds + rhs += p0test * integral.integrand() * ds for integral in F.integrals_by_type("interior_facet"): dS = firedrake.dS(integral.subdomain_id()) - flux_terms += p0test("+") * integral.integrand() * dS - flux_terms += p0test("-") * integral.integrand() * dS - if flux_terms != 0: - sp = { - "snes_type": "ksponly", - "ksp_type": "preonly", - "pc_type": "jacobi", - } - firedrake.solve(mass_term == flux_terms, indicator, solver_parameters=sp) - - # Contributions from volume integrals - cell_terms = 0 + rhs += p0test("+") * integral.integrand() * dS + rhs += p0test("-") * integral.integrand() * dS for integral in F.integrals_by_type("cell"): dx = firedrake.dx(integral.subdomain_id()) - cell_terms += p0test * integral.integrand() * dx - if cell_terms != 0: - cell_contrib = Function(P0) - sp = { + rhs += p0test * integral.integrand() * dx + + assert rhs != 0 + indicator = Function(P0) + firedrake.solve( + firedrake.TrialFunction(P0) * p0test * firedrake.dx == rhs, + indicator, + solver_parameters={ "snes_type": "ksponly", "ksp_type": "preonly", - "pc_type": "lu", - "pc_factor_mat_solver_type": "mumps", - } - firedrake.solve(mass_term == cell_terms, cell_contrib, solver_parameters=sp) - indicator += cell_contrib - + "pc_type": "jacobi", + }, + ) return indicator From 1b825f094e069c837968744c9142b4a0d070cbd0 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Mon, 23 Oct 2023 21:04:33 +0100 Subject: [PATCH 22/23] Multiply by CellVolume in error indicator --- goalie/error_estimation.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/goalie/error_estimation.py b/goalie/error_estimation.py index 8c29dcc3..ae87dd0f 100644 --- a/goalie/error_estimation.py +++ b/goalie/error_estimation.py @@ -32,20 +32,19 @@ def form2indicator(F: ufl.form.Form) -> Function: mesh = F.ufl_domain() P0 = FunctionSpace(mesh, "DG", 0) p0test = firedrake.TestFunction(P0) - f = ufl.FacetArea(mesh) h = ufl.CellVolume(mesh) rhs = 0 for integral in F.integrals_by_type("exterior_facet"): ds = firedrake.ds(integral.subdomain_id()) - rhs += p0test * integral.integrand() * ds + rhs += h * p0test * integral.integrand() * ds for integral in F.integrals_by_type("interior_facet"): dS = firedrake.dS(integral.subdomain_id()) - rhs += p0test("+") * integral.integrand() * dS - rhs += p0test("-") * integral.integrand() * dS + rhs += h("+") * p0test("+") * integral.integrand() * dS + rhs += h("-") * p0test("-") * integral.integrand() * dS for integral in F.integrals_by_type("cell"): dx = firedrake.dx(integral.subdomain_id()) - rhs += p0test * integral.integrand() * dx + rhs += h * p0test * integral.integrand() * dx assert rhs != 0 indicator = Function(P0) From f55dce1fa401b6a37f0bf957b15ae5033c0835e8 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Mon, 23 Oct 2023 21:04:56 +0100 Subject: [PATCH 23/23] Update tests to account for multiplication by CellVolume --- test/test_error_estimation.py | 44 +++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/test/test_error_estimation.py b/test/test_error_estimation.py index 465f1598..87a286ee 100644 --- a/test/test_error_estimation.py +++ b/test/test_error_estimation.py @@ -37,21 +37,21 @@ def test_form_type_error(self): def test_exterior_facet_integral(self): F = self.one * ds(1) - self.one * ds(2) indicator = form2indicator(F) - self.assertAlmostEqual(indicator.dat.data[0], -2) - self.assertAlmostEqual(indicator.dat.data[1], 2) + self.assertAlmostEqual(indicator.dat.data[0], -1) + self.assertAlmostEqual(indicator.dat.data[1], 1) def test_interior_facet_integral(self): F = avg(self.one) * dS indicator = form2indicator(F) - self.assertAlmostEqual(indicator.dat.data[0], 2 * sqrt(2)) - self.assertAlmostEqual(indicator.dat.data[1], 2 * sqrt(2)) + self.assertAlmostEqual(indicator.dat.data[0], sqrt(2)) + self.assertAlmostEqual(indicator.dat.data[1], sqrt(2)) def test_cell_integral(self): x, y = SpatialCoordinate(self.mesh) F = conditional(x + y < 1, 1, 0) * dx indicator = form2indicator(F) self.assertAlmostEqual(indicator.dat.data[0], 0) - self.assertAlmostEqual(indicator.dat.data[1], 1) + self.assertAlmostEqual(indicator.dat.data[1], 0.5) class TestIndicators2Estimator(ErrorEstimationTestCase): @@ -129,7 +129,7 @@ def test_unit_time_instant(self): time_instant = TimeInstant("field", time=1.0) indicator = form2indicator(self.one * dx) estimator = indicators2estimator({"field": [[indicator]]}, time_instant) - self.assertAlmostEqual(estimator, 2.0) # 1 * (1 + 1) + self.assertAlmostEqual(estimator, 1) # 1 * (0.5 + 0.5) @parameterized.expand([[False], [True]]) def test_unit_time_instant_abs(self, absolute_value): @@ -139,20 +139,20 @@ def test_unit_time_instant_abs(self, absolute_value): {"field": [[indicator]]}, time_instant, absolute_value=absolute_value ) self.assertAlmostEqual( - estimator, 2.0 if absolute_value else -2.0 - ) # (-)1 * (1 + 1) + estimator, 1 if absolute_value else -1 + ) # (-)1 * (0.5 + 0.5) def test_half_time_instant(self): time_instant = TimeInstant("field", time=0.5) indicator = form2indicator(self.one * dx) estimator = indicators2estimator({"field": [[indicator]]}, time_instant) - self.assertAlmostEqual(estimator, 1.0) # 0.5 * (1 + 1) + self.assertAlmostEqual(estimator, 0.5) # 0.5 * (0.5 + 0.5) def test_time_partition_same_timestep(self): time_partition = TimePartition(1.0, 2, [0.5, 0.5], ["field"]) indicator = form2indicator(self.one * dx) estimator = indicators2estimator({"field": [2 * [indicator]]}, time_partition) - self.assertAlmostEqual(estimator, 2.0) # 0.5 * (1 + 1) + 0.5 * (1 + 1) + self.assertAlmostEqual(estimator, 1) # 2 * 0.5 * (0.5 + 0.5) def test_time_partition_different_timesteps(self): time_partition = TimePartition(1.0, 2, [0.5, 0.25], ["field"]) @@ -160,7 +160,7 @@ def test_time_partition_different_timesteps(self): estimator = indicators2estimator( {"field": [[indicator], 2 * [indicator]]}, time_partition ) - self.assertAlmostEqual(estimator, 2.0) # 0.5 * (1 + 1) + 0.25 * 2 * (1 + 1) + self.assertAlmostEqual(estimator, 1) # 0.5 * (0.5 + 0.5) + 0.25 * 2 * (0.5 + 0.5) def test_time_instant_multiple_fields(self): time_instant = TimeInstant(["field1", "field2"], time=1.0) @@ -168,7 +168,7 @@ def test_time_instant_multiple_fields(self): estimator = indicators2estimator( {"field1": [[indicator]], "field2": [[indicator]]}, time_instant ) - self.assertAlmostEqual(estimator, 4.0) # 2 * (1 * (1 + 1)) + self.assertAlmostEqual(estimator, 2) # 2 * (1 * (0.5 + 0.5)) class TestGetDWRIndicator(ErrorEstimationTestCase): @@ -247,32 +247,32 @@ def test_convert_neither(self): adjoint_error = {"field": self.two} test_space = {"field": self.one.function_space()} indicator = get_dwr_indicator(self.F, adjoint_error, test_space=test_space) - self.assertAlmostEqual(indicator.dat.data[0], 2) - self.assertAlmostEqual(indicator.dat.data[1], 2) + self.assertAlmostEqual(indicator.dat.data[0], 1) + self.assertAlmostEqual(indicator.dat.data[1], 1) def test_convert_both(self): test_space = self.one.function_space() indicator = get_dwr_indicator(self.F, self.two, test_space=test_space) - self.assertAlmostEqual(indicator.dat.data[0], 2) - self.assertAlmostEqual(indicator.dat.data[1], 2) + self.assertAlmostEqual(indicator.dat.data[0], 1) + self.assertAlmostEqual(indicator.dat.data[1], 1) def test_convert_test_space(self): adjoint_error = {"field": self.two} test_space = self.one.function_space() indicator = get_dwr_indicator(self.F, adjoint_error, test_space=test_space) - self.assertAlmostEqual(indicator.dat.data[0], 2) - self.assertAlmostEqual(indicator.dat.data[1], 2) + self.assertAlmostEqual(indicator.dat.data[0], 1) + self.assertAlmostEqual(indicator.dat.data[1], 1) def test_convert_adjoint_error(self): test_space = {"Dos": self.one.function_space()} indicator = get_dwr_indicator(self.F, self.two, test_space=test_space) - self.assertAlmostEqual(indicator.dat.data[0], 2) - self.assertAlmostEqual(indicator.dat.data[1], 2) + self.assertAlmostEqual(indicator.dat.data[0], 1) + self.assertAlmostEqual(indicator.dat.data[1], 1) def test_convert_adjoint_error_no_test_space(self): indicator = get_dwr_indicator(self.F, self.two) - self.assertAlmostEqual(indicator.dat.data[0], 2) - self.assertAlmostEqual(indicator.dat.data[1], 2) + self.assertAlmostEqual(indicator.dat.data[0], 1) + self.assertAlmostEqual(indicator.dat.data[1], 1) def test_convert_adjoint_error_mismatch(self): test_space = {"field": self.one.function_space()}