diff --git a/firedrake/adjoint/__init__.py b/firedrake/adjoint/__init__.py index 8373a4285a..c48b990420 100644 --- a/firedrake/adjoint/__init__.py +++ b/firedrake/adjoint/__init__.py @@ -21,10 +21,11 @@ pause_annotation, continue_annotation, \ stop_annotating, annotate_tape # noqa F401 from pyadjoint.reduced_functional import ReducedFunctional # noqa F401 +from pyadjoint.checkpointing import disk_checkpointing_callback # noqa F401 from firedrake.adjoint_utils.checkpointing import \ enable_disk_checkpointing, pause_disk_checkpointing, \ continue_disk_checkpointing, stop_disk_checkpointing, \ - checkpointable_mesh # noqa F401 + checkpointable_mesh # noqa F401 from firedrake.adjoint_utils import get_solve_blocks # noqa F401 from pyadjoint.verification import taylor_test, taylor_to_dict # noqa F401 diff --git a/firedrake/adjoint_utils/checkpointing.py b/firedrake/adjoint_utils/checkpointing.py index 583d7155eb..0ab337e17b 100644 --- a/firedrake/adjoint_utils/checkpointing.py +++ b/firedrake/adjoint_utils/checkpointing.py @@ -1,5 +1,5 @@ """A module providing support for disk checkpointing of the adjoint tape.""" -from pyadjoint import get_working_tape, OverloadedType +from pyadjoint import get_working_tape, OverloadedType, disk_checkpointing_callback from pyadjoint.tape import TapePackageData from pyop2.mpi import COMM_WORLD import tempfile @@ -10,6 +10,8 @@ from numbers import Number _enable_disk_checkpoint = False _checkpoint_init_data = False +disk_checkpointing_callback["firedrake"] = "Please call enable_disk_checkpointing() "\ + "before checkpointing on the disk." __all__ = ["enable_disk_checkpointing", "disk_checkpointing", "pause_disk_checkpointing", "continue_disk_checkpointing", @@ -204,6 +206,12 @@ def restore_from_checkpoint(self, state): self.init_checkpoint_file = state["init"] self.current_checkpoint_file = state["current"] + def continue_checkpointing(self): + continue_disk_checkpointing() + + def pause_checkpointing(self): + pause_disk_checkpointing() + def checkpointable_mesh(mesh): """Write a mesh to disk and read it back. @@ -251,7 +259,7 @@ def restore(self): pass -class CheckpointFunction(CheckpointBase): +class CheckpointFunction(CheckpointBase, OverloadedType): """Metadata for a Function checkpointed to disk. An object of this class replaces the :class:`~firedrake.Function` stored as @@ -304,6 +312,9 @@ def restore(self): return type(function)(function.function_space(), function.dat, name=self.name(), count=self.count) + def _ad_restore_at_checkpoint(self, checkpoint): + return checkpoint.restore() + def maybe_disk_checkpoint(function): """Checkpoint a Function to disk if disk checkpointing is active.""" diff --git a/firedrake/scripts/firedrake-install b/firedrake/scripts/firedrake-install index 8b0ad3c76d..ea4997852d 100755 --- a/firedrake/scripts/firedrake-install +++ b/firedrake/scripts/firedrake-install @@ -2064,6 +2064,7 @@ with environment(**compiler_env): if options["netgen"]: packages += ["ngsPETSc"] + run_pip(["install", "-U", "ngsPETSc"]) with pipargs("--no-deps"): if options["opencascade"]: