diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 935a270d..9ae9809b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -27,13 +27,6 @@ jobs: uses: actions/setup-python@v2 with: python-version: 3.8 - - name: Install Animate - run: | - . /home/firedrake/firedrake/bin/activate - cd .. - git clone https://github.com/pyroteus/animate.git - cd animate - python -m pip install -e . - name: Install Goalie run: | . /home/firedrake/firedrake/bin/activate diff --git a/demos/burgers-hessian.py b/demos/burgers-hessian.py index b17874f1..32987c73 100644 --- a/demos/burgers-hessian.py +++ b/demos/burgers-hessian.py @@ -27,10 +27,11 @@ def get_form(mesh_seq): def form(index, solutions): u, u_ = solutions["u"] P = mesh_seq.time_partition - dt = Constant(P.timesteps[index]) - # Specify viscosity coefficient - nu = Constant(0.0001) + # Define constants + R = FunctionSpace(mesh_seq[index], "R", 0) + dt = Function(R).assign(P.timesteps[index]) + nu = Function(R).assign(0.0001) # Setup variational problem v = TestFunction(u.function_space()) diff --git a/demos/burgers.py b/demos/burgers.py index f0e8ca2f..6b8c88c8 100644 --- a/demos/burgers.py +++ b/demos/burgers.py @@ -50,17 +50,19 @@ def get_function_spaces(mesh): # # Timestepping information associated with a given subinterval # can be accessed via the :attr:`TimePartition` attribute of -# the :class:`MeshSeq`. :: +# the :class:`MeshSeq`. For technical reasons, we need to create a :class:`Function` +# in the `'R'` space (of real numbers) to hold constants. :: def get_form(mesh_seq): def form(index, solutions): u, u_ = solutions["u"] P = mesh_seq.time_partition - dt = Constant(P.timesteps[index]) - # Specify viscosity coefficient - nu = Constant(0.0001) + # Define constants + R = FunctionSpace(mesh_seq[index], "R", 0) + dt = Function(R).assign(P.timesteps[index]) + nu = Function(R).assign(0.0001) # Setup variational problem v = TestFunction(u.function_space()) diff --git a/demos/burgers1.py b/demos/burgers1.py index c66eba02..d077bf88 100644 --- a/demos/burgers1.py +++ b/demos/burgers1.py @@ -28,10 +28,11 @@ def get_form(mesh_seq): def form(index, solutions): u, u_ = solutions["u"] P = mesh_seq.time_partition - dt = Constant(P.timesteps[index]) - # Specify viscosity coefficient - nu = Constant(0.0001) + # Define constants + R = FunctionSpace(mesh_seq[index], "R", 0) + dt = Function(R).assign(P.timesteps[index]) + nu = Function(R).assign(0.0001) # Setup variational problem v = TestFunction(u.function_space()) diff --git a/demos/burgers2.py b/demos/burgers2.py index 9b8c298b..5215c9cd 100644 --- a/demos/burgers2.py +++ b/demos/burgers2.py @@ -28,10 +28,11 @@ def get_form(mesh_seq): def form(index, solutions): u, u_ = solutions["u"] P = mesh_seq.time_partition - dt = Constant(P.timesteps[index]) - # Specify viscosity coefficient - nu = Constant(0.0001) + # Define constants + R = FunctionSpace(mesh_seq[index], "R", 0) + dt = Function(R).assign(P.timesteps[index]) + nu = Function(R).assign(0.0001) # Setup variational problem v = TestFunction(u.function_space()) diff --git a/demos/burgers_ee.py b/demos/burgers_ee.py index c1903907..9cbf3724 100644 --- a/demos/burgers_ee.py +++ b/demos/burgers_ee.py @@ -43,10 +43,11 @@ def get_form(mesh_seq): def form(index, solutions): u, u_ = solutions["u"] P = mesh_seq.time_partition - dt = Constant(P.timesteps[index]) - # Specify viscosity coefficient - nu = Constant(0.0001) + # Define constants + R = FunctionSpace(mesh_seq[index], "R", 0) + dt = Function(R).assign(P.timesteps[index]) + nu = Function(R).assign(0.0001) # Setup variational problem v = TestFunction(u.function_space()) diff --git a/demos/burgers_oo.py b/demos/burgers_oo.py index f016b1fc..85aa21e8 100644 --- a/demos/burgers_oo.py +++ b/demos/burgers_oo.py @@ -33,10 +33,11 @@ def get_form(self): def form(index, solutions): u, u_ = solutions["u"] P = self.time_partition - dt = Constant(P.timesteps[index]) - # Specify viscosity coefficient - nu = Constant(0.0001) + # Define constants + R = FunctionSpace(mesh_seq[index], "R", 0) + dt = Function(R).assign(P.timesteps[index]) + nu = Function(R).assign(0.0001) # Setup variational problem v = TestFunction(u.function_space()) @@ -84,7 +85,8 @@ def get_initial_condition(self): @annotate_qoi def get_qoi(self, solutions, i): - dt = Constant(self.time_partition[i].timestep) + R = FunctionSpace(self[i], "R", 0) + dt = Function(R).assign(self.time_partition[i].timestep) def end_time_qoi(): u = solutions["u"] diff --git a/demos/burgers_time_integrated.py b/demos/burgers_time_integrated.py index 9dbecbc1..684460c5 100644 --- a/demos/burgers_time_integrated.py +++ b/demos/burgers_time_integrated.py @@ -23,10 +23,11 @@ def get_form(mesh_seq): def form(index, solutions): u, u_ = solutions["u"] P = mesh_seq.time_partition - dt = Constant(P.timesteps[index]) - # Specify viscosity coefficient - nu = Constant(0.0001) + # Define constants + R = FunctionSpace(mesh_seq[index], "R", 0) + dt = Function(R).assign(P.timesteps[index]) + nu = Function(R).assign(0.0001) # Setup variational problem v = TestFunction(u.function_space()) @@ -93,12 +94,13 @@ def solver(index, ic): # \;\mathrm dy\;\mathrm dt. # # Note that in this case we multiply by the timestep. -# It is wrapped in a :class:`Constant` to avoid +# It is wrapped in a :class:`Function` from `'R'` space to avoid # recompilation if the value is changed. :: def get_qoi(mesh_seq, solutions, i): - dt = Constant(mesh_seq.time_partition[i].timestep) + R = FunctionSpace(mesh_seq[i], "R", 0) + dt = Function(R).assign(mesh_seq.time_partition[i].timestep) def time_integrated_qoi(t): u = solutions["u"] diff --git a/demos/gray_scott.py b/demos/gray_scott.py index 4d4f344d..b9df29ea 100644 --- a/demos/gray_scott.py +++ b/demos/gray_scott.py @@ -57,11 +57,12 @@ def form(index, sols): psi_a, psi_b = TestFunctions(mesh_seq.function_spaces["ab"][index]) # Define constants - dt = Constant(mesh_seq.time_partition[index].timestep) - D_a = Constant(8.0e-05) - D_b = Constant(4.0e-05) - gamma = Constant(0.024) - kappa = Constant(0.06) + R = FunctionSpace(mesh_seq[index], "R", 0) + dt = Function(R).assign(mesh_seq.time_partition[index].timestep) + D_a = Function(R).assign(8.0e-05) + D_b = Function(R).assign(4.0e-05) + gamma = Function(R).assign(0.024) + kappa = Function(R).assign(0.06) # Write the two equations in variational form F = ( diff --git a/demos/gray_scott_split.py b/demos/gray_scott_split.py index f76f2495..f028bfad 100644 --- a/demos/gray_scott_split.py +++ b/demos/gray_scott_split.py @@ -57,11 +57,12 @@ def form(index, sols): psi_b = TestFunction(mesh_seq.function_spaces["b"][index]) # Define constants - dt = Constant(mesh_seq.time_partition[index].timestep) - D_a = Constant(8.0e-05) - D_b = Constant(4.0e-05) - gamma = Constant(0.024) - kappa = Constant(0.06) + R = FunctionSpace(mesh_seq[index], "R", 0) + dt = Function(R).assign(mesh_seq.time_partition[index].timestep) + D_a = Function(R).assign(8.0e-05) + D_b = Function(R).assign(4.0e-05) + gamma = Function(R).assign(0.024) + kappa = Function(R).assign(0.06) # Write the two equations in variational form F_a = ( diff --git a/demos/point_discharge2d-goal_oriented.py b/demos/point_discharge2d-goal_oriented.py index f6f942aa..f398c110 100644 --- a/demos/point_discharge2d-goal_oriented.py +++ b/demos/point_discharge2d-goal_oriented.py @@ -36,11 +36,16 @@ def get_form(mesh_seq): def form(index, sols): c, c_ = sols["c"] function_space = mesh_seq.function_spaces["c"][index] - D = Constant(0.1) - u = Constant(as_vector([1, 0])) h = CellSize(mesh_seq[index]) S = source(mesh_seq[index]) + # Define constants + R = FunctionSpace(mesh_seq[index], "R", 0) + D = Function(R).assign(0.1) + u_x = Function(R).assign(1.0) + u_y = Function(R).assign(0.0) + u = as_vector([u_x, u_y]) + # SUPG stabilisation parameter unorm = sqrt(dot(u, u)) tau = 0.5 * h / unorm diff --git a/demos/point_discharge2d-hessian.py b/demos/point_discharge2d-hessian.py index 7b3ef3ba..ee0c0b9e 100644 --- a/demos/point_discharge2d-hessian.py +++ b/demos/point_discharge2d-hessian.py @@ -40,11 +40,16 @@ def get_form(mesh_seq): def form(index, sols): c, c_ = sols["c"] function_space = mesh_seq.function_spaces["c"][index] - D = Constant(0.1) - u = Constant(as_vector([1, 0])) h = CellSize(mesh_seq[index]) S = source(mesh_seq[index]) + # Define constants + R = FunctionSpace(mesh_seq[index], "R", 0) + D = Function(R).assign(0.1) + u_x = Function(R).assign(1.0) + u_y = Function(R).assign(0.0) + u = as_vector([u_x, u_y]) + # SUPG stabilisation parameter unorm = sqrt(dot(u, u)) tau = 0.5 * h / unorm diff --git a/demos/point_discharge2d.py b/demos/point_discharge2d.py index d053d3e9..32e85378 100644 --- a/demos/point_discharge2d.py +++ b/demos/point_discharge2d.py @@ -71,11 +71,16 @@ def get_form(mesh_seq): def form(index, sols): c, c_ = sols["c"] function_space = mesh_seq.function_spaces["c"][index] - D = Constant(0.1) - u = Constant(as_vector([1, 0])) h = CellSize(mesh_seq[index]) S = source(mesh_seq[index]) + # Define constants + R = FunctionSpace(mesh_seq[index], "R", 0) + D = Function(R).assign(0.1) + u_x = Function(R).assign(1.0) + u_y = Function(R).assign(0.0) + u = as_vector([u_x, u_y]) + # SUPG stabilisation parameter unorm = sqrt(dot(u, u)) tau = 0.5 * h / unorm diff --git a/demos/solid_body_rotation.py b/demos/solid_body_rotation.py index e8f656fc..52e15639 100644 --- a/demos/solid_body_rotation.py +++ b/demos/solid_body_rotation.py @@ -151,10 +151,14 @@ def form(index, sols, field="c"): V = mesh_seq.function_spaces[field][index] mesh = mesh_seq[index] + # Define velocity field x, y = SpatialCoordinate(mesh) u = as_vector([-y, x]) - dt = Constant(mesh_seq.time_partition[index].timestep) - theta = Constant(0.5) + + # Define constants + R = FunctionSpace(mesh_seq[index], "R", 0) + dt = Function(R).assign(mesh_seq.time_partition[index].timestep) + theta = Function(R).assign(0.5) psi = TrialFunction(V) phi = TestFunction(V) diff --git a/goalie/adjoint.py b/goalie/adjoint.py index f83a5153..b075821c 100644 --- a/goalie/adjoint.py +++ b/goalie/adjoint.py @@ -3,12 +3,12 @@ """ import firedrake from firedrake.petsc import PETSc -from firedrake_adjoint import pyadjoint +from firedrake.adjoint import pyadjoint from .interpolation import project from .mesh_seq import MeshSeq from .options import GoalOrientedParameters from .time_partition import TimePartition -from .utility import AttrDict +from .utility import AttrDict, norm from .log import pyrint from collections.abc import Callable from functools import wraps @@ -238,8 +238,6 @@ def solve_adjoint( labels = ("forward", "forward_old", "adjoint") if not self.steady: labels += ("adjoint_next",) - if get_adj_values: - labels += ("adj_value",) solutions = AttrDict( { field: AttrDict( @@ -257,6 +255,16 @@ def solve_adjoint( for field in self.fields } ) + if get_adj_values: + for field in self.fields: + solutions[field]["adj_value"] = [] + for i, fs in enumerate(function_spaces[field]): + solutions[field]["adj_value"].append( + [ + firedrake.Cofunction(fs.dual(), name=f"{field}_adj_value") + for j in range(P.num_exports_per_subinterval[i] - 1) + ] + ) @PETSc.Log.EventDecorator("goalie.AdjointMeshSeq.solve_adjoint.evaluate_fwd") @wraps(solver) @@ -275,32 +283,37 @@ def wrapped_solver(subinterval, ic, **kwargs): self.controls = [pyadjoint.Control(init[field]) for field in self.fields] return solver(subinterval, init, **kwargs) - # Clear tape - tape = pyadjoint.get_working_tape() - tape.clear_tape() - # Loop over subintervals in reverse - seeds = None + seeds = {} for i in reversed(range(num_subintervals)): stride = P.num_timesteps_per_export[i] num_exports = P.num_exports_per_subinterval[i] + # Clear tape and start annotation + if not pyadjoint.annotate_tape(): + pyadjoint.continue_annotation() + tape = pyadjoint.get_working_tape() + if tape is not None: + tape.clear_tape() + # Annotate tape on current subinterval checkpoint = wrapped_solver(i, checkpoints[i], **solver_kwargs) + pyadjoint.pause_annotation() # Get seed vector for reverse propagation if i == num_subintervals - 1: if self.qoi_type in ["end_time", "steady"]: + pyadjoint.continue_annotation() qoi = self.get_qoi(checkpoint, i) self.J = qoi(**qoi_kwargs) if np.isclose(float(self.J), 0.0): self.warning("Zero QoI. Is it implemented as intended?") + pyadjoint.pause_annotation() else: - with pyadjoint.stop_annotating(): - for field, fs in function_spaces.items(): - checkpoint[field].block_variable.adj_value = project( - seeds[field], fs[i], adjoint=True - ) + for field, fs in function_spaces.items(): + checkpoint[field].block_variable.adj_value = project( + seeds[field], fs[i], adjoint=True + ) # Update adjoint solver kwargs for field in self.fields: @@ -308,6 +321,7 @@ def wrapped_solver(subinterval, ic, **kwargs): block.adj_kwargs.update(adj_solver_kwargs) # Solve adjoint problem + tape = pyadjoint.get_working_tape() with PETSc.Log.Event("goalie.AdjointMeshSeq.solve_adjoint.evaluate_adj"): m = pyadjoint.enlisting.Enlist(self.controls) with pyadjoint.stop_annotating(): @@ -363,7 +377,7 @@ def wrapped_solver(subinterval, ic, **kwargs): # Adjoint action also comes from dependencies if get_adj_values and dep is not None: - sols.adj_value[i][j].assign(dep.adj_value.function) + sols.adj_value[i][j].assign(dep.adj_value) # The adjoint solution at the 'next' timestep is determined from the # adj_sol attribute of the next solve block @@ -387,32 +401,30 @@ def wrapped_solver(subinterval, ic, **kwargs): ) # Check non-zero adjoint solution/value - if np.isclose(firedrake.norm(solutions[field].adjoint[i][0]), 0.0): + if np.isclose(norm(solutions[field].adjoint[i][0]), 0.0): self.warning( f"Adjoint solution for field '{field}' on {self.th(i)}" " subinterval is zero." ) - if get_adj_values and np.isclose( - firedrake.norm(sols.adj_value[i][0]), 0.0 - ): + if get_adj_values and np.isclose(norm(sols.adj_value[i][0]), 0.0): self.warning( f"Adjoint action for field '{field}' on {self.th(i)}" " subinterval is zero." ) # Get adjoint action on each subinterval - seeds = { - field: firedrake.Function( - function_spaces[field][i], val=control.block_variable.adj_value - ) - for field, control in zip(self.fields, self.controls) - } - for field, seed in seeds.items(): - if not self.steady and np.isclose(firedrake.norm(seed), 0.0): - self.warning( - f"Adjoint action for field '{field}' on {self.th(i)}" - " subinterval is zero." + with pyadjoint.stop_annotating(): + for field, control in zip(self.fields, self.controls): + seeds[field] = firedrake.Cofunction( + function_spaces[field][i].dual() ) + if control.block_variable.adj_value is not None: + seeds[field].assign(control.block_variable.adj_value) + if not self.steady and np.isclose(norm(seeds[field]), 0.0): + self.warning( + f"Adjoint action for field '{field}' on {self.th(i)}" + " subinterval is zero." + ) # Clear the tape to reduce the memory footprint tape.clear_tape() @@ -424,6 +436,8 @@ def wrapped_solver(subinterval, ic, **kwargs): "QoI values computed during checkpointing and annotated" f" run do not match ({J_chk} vs. {self.J})" ) + + tape.clear_tape() return solutions @staticmethod diff --git a/goalie/error_estimation.py b/goalie/error_estimation.py index e48e5197..654b138d 100644 --- a/goalie/error_estimation.py +++ b/goalie/error_estimation.py @@ -33,6 +33,7 @@ def form2indicator(F: ufl.form.Form) -> Function: P0 = FunctionSpace(mesh, "DG", 0) p0test = firedrake.TestFunction(P0) indicator = Function(P0) + mass_term = firedrake.TrialFunction(P0) * p0test * firedrake.dx # Contributions from surface integrals flux_terms = 0 @@ -44,8 +45,6 @@ def form2indicator(F: ufl.form.Form) -> Function: flux_terms += p0test("+") * integral.integrand() * dS flux_terms += p0test("-") * integral.integrand() * dS if flux_terms != 0: - dx = firedrake.dx - mass_term = firedrake.TrialFunction(P0) * p0test * dx sp = { "snes_type": "ksponly", "ksp_type": "preonly", @@ -59,7 +58,15 @@ def form2indicator(F: ufl.form.Form) -> Function: dx = firedrake.dx(integral.subdomain_id()) cell_terms += p0test * integral.integrand() * dx if cell_terms != 0: - indicator += firedrake.assemble(cell_terms) + cell_contrib = Function(P0) + sp = { + "snes_type": "ksponly", + "ksp_type": "preonly", + "pc_type": "lu", + "pc_factor_mat_solver_type": "mumps", + } + firedrake.solve(mass_term == cell_terms, cell_contrib, solver_parameters=sp) + indicator += cell_contrib return indicator diff --git a/goalie/go_mesh_seq.py b/goalie/go_mesh_seq.py index d26538d6..d2e2273f 100644 --- a/goalie/go_mesh_seq.py +++ b/goalie/go_mesh_seq.py @@ -327,9 +327,7 @@ def fixed_point_iteration( if self.params.convergence_criteria == "all": if not converged: self.converged[:] = False - pyrint( - f"Failed to converge in {self.params.maxiter} iterations." - ) + pyrint(f"Failed to converge in {self.params.maxiter} iterations.") else: for i, conv in enumerate(self.converged): if not conv: diff --git a/goalie/interpolation.py b/goalie/interpolation.py index 2fbdebf9..167a0c04 100644 --- a/goalie/interpolation.py +++ b/goalie/interpolation.py @@ -1,7 +1,7 @@ """ Driver functions for mesh-to-mesh data transfer. """ -from .utility import assemble_mass_matrix +from .utility import assemble_mass_matrix, cofunction2function, function2cofunction import firedrake from firedrake.petsc import PETSc from petsc4py import PETSc as petsc4py @@ -29,8 +29,13 @@ def project( seek to project into :kwarg adjoint: apply the transposed projection operator? """ - if not isinstance(source, firedrake.Function): - raise NotImplementedError("Can only currently project Functions.") # TODO + if not isinstance(source, (firedrake.Function, firedrake.Cofunction)): + raise NotImplementedError( + "Can only currently project Functions and Cofunctions." + ) # TODO + adj_value = isinstance(source, firedrake.Cofunction) + if adj_value: + source = cofunction2function(source) Vs = source.function_space() if isinstance(target_space, firedrake.Function): target = target_space @@ -66,7 +71,10 @@ def project( ) # Apply projector - return (_project_adjoint if adjoint else _project)(source, target, **kwargs) + target = (_project_adjoint if adjoint else _project)(source, target, **kwargs) + if adj_value: + target = function2cofunction(target) + return target @PETSc.Log.EventDecorator("goalie.interpolation.project") diff --git a/goalie/mesh_seq.py b/goalie/mesh_seq.py index d868e1e3..f6d95e70 100644 --- a/goalie/mesh_seq.py +++ b/goalie/mesh_seq.py @@ -3,8 +3,8 @@ """ import firedrake from firedrake.petsc import PETSc -from firedrake.adjoint.solving import get_solve_blocks -from pyadjoint import get_working_tape, Block +from firedrake.adjoint_utils.solving import get_solve_blocks +from firedrake.adjoint import pyadjoint from .interpolation import project from .log import pyrint, debug, warning, info, logger, DEBUG from .options import AdaptParameters @@ -328,9 +328,12 @@ def get_solve_blocks(self, field: str, subinterval: int) -> list: :arg field: name of the prognostic solution field :arg subinterval: subinterval index """ + tape = pyadjoint.get_working_tape() + if tape is None: + self.warning("Tape does not exist!") + return [] - # Get all blocks - blocks = get_working_tape().get_blocks() + blocks = tape.get_blocks() if len(blocks) == 0: self.warning("Tape has no blocks!") return blocks @@ -385,9 +388,7 @@ def get_solve_blocks(self, field: str, subinterval: int) -> list: ) return solve_blocks - def _output( - self, field: str, subinterval: int, solve_block: Block - ) -> firedrake.Function: + def _output(self, field, subinterval, solve_block): """ For a given solve block and solution field, get the block's outputs which corresponds to the solution from the current timestep. @@ -430,9 +431,7 @@ def _output( " outputs." ) - def _dependency( - self, field: str, subinterval: int, solve_block: Block - ) -> firedrake.Function: + def _dependency(self, field, subinterval, solve_block): """ For a given solve block and solution field, get the block's dependency which corresponds to the solution from the previous timestep. @@ -500,8 +499,6 @@ def solve_forward(self, solver_kwargs: dict = {}) -> AttrDict: :return solution: an :class:`~.AttrDict` containing solution fields and their lagged versions. """ - from firedrake_adjoint import pyadjoint - num_subintervals = len(self) function_spaces = self.function_spaces P = self.time_partition @@ -528,9 +525,13 @@ def solve_forward(self, solver_kwargs: dict = {}) -> AttrDict: } ) - # Clear tape - tape = pyadjoint.get_working_tape() - tape.clear_tape() + # Start annotating + if pyadjoint.annotate_tape(): + tape = pyadjoint.get_working_tape() + if tape is not None: + tape.clear_tape() + else: + pyadjoint.continue_annotation() # Loop over the subintervals checkpoint = self.initial_condition @@ -588,7 +589,7 @@ def solve_forward(self, solver_kwargs: dict = {}) -> AttrDict: ) # Clear the tape to reduce the memory footprint - tape.clear_tape() + pyadjoint.get_working_tape().clear_tape() return solutions diff --git a/goalie/options.py b/goalie/options.py index a611fbc2..1450d546 100644 --- a/goalie/options.py +++ b/goalie/options.py @@ -1,5 +1,5 @@ from .utility import AttrDict -from firedrake.meshadapt import RiemannianMetric +from animate.adapt import RiemannianMetric __all__ = [ diff --git a/goalie/utility.py b/goalie/utility.py index 4af8c7e4..9b8193cc 100644 --- a/goalie/utility.py +++ b/goalie/utility.py @@ -105,27 +105,50 @@ def assemble_mass_matrix( return firedrake.assemble(lhs).petscmat -@PETSc.Log.EventDecorator("goalie.norm") -def norm(v: firedrake.Function, norm_type: str = "L2", **kwargs) -> float: +def cofunction2function(c): + """ + Map a :class:`Cofunction` to a :class:`Function`. + """ + f = firedrake.Function(c.function_space().dual()) + if isinstance(f.dat.data_with_halos, tuple): + for i, arr in enumerate(f.dat.data_with_halos): + arr[:] = c.dat.data_with_halos[i] + else: + f.dat.data_with_halos[:] = c.dat.data_with_halos + return f + + +def function2cofunction(f): + """ + Map a :class:`Function` to a :class:`Cofunction`. + """ + c = firedrake.Cofunction(f.function_space().dual()) + if isinstance(c.dat.data_with_halos, tuple): + for i, arr in enumerate(c.dat.data_with_halos): + arr[:] = f.dat.data_with_halos[i] + else: + c.dat.data_with_halos[:] = f.dat.data_with_halos + return c + + +@PETSc.Log.EventDecorator() +def norm(v, norm_type="L2", **kwargs): r""" - Overload :func:`firedrake.norms.norm` to - allow for :math:`\ell^p` norms. + Overload :func:`firedrake.norms.norm` to allow for :math:`\ell^p` norms. - Note that this version is case sensitive, - i.e. ``'l2'`` and ``'L2'`` will give + Note that this version is case sensitive, i.e. ``'l2'`` and ``'L2'`` will give different results in general. - :arg v: the :class:`firedrake.function.Function` - to take the norm of - :kwarg norm_type: choose from ``'l1'``, ``'l2'``, - ``'linf'``, ``'L2'``, ``'Linf'``, ``'H1'``, - ``'Hdiv'``, ``'Hcurl'``, or any ``'Lp'`` with - :math:`p >= 1`. - :kwarg condition: a UFL condition for specifying - a subdomain to compute the norm over - :kwarg boundary: should the norm be computed over - the domain boundary? + :arg v: the :class:`firedrake.function.Function` or + :class:`firedrake.cofunction.Cofunction` to take the norm of + :kwarg norm_type: choose from ``'l1'``, ``'l2'``, ``'linf'``, ``'L2'``, ``'Linf'``, + ``'H1'``, ``'Hdiv'``, ``'Hcurl'``, or any ``'Lp'`` with :math:`p >= 1`. + :kwarg condition: a UFL condition for specifying a subdomain to compute the norm + over + :kwarg boundary: should the norm be computed over the domain boundary? """ + if isinstance(v, firedrake.Cofunction): + v = cofunction2function(v) boundary = kwargs.get("boundary", False) condition = kwargs.get("condition", firedrake.Constant(1.0)) norm_codes = {"l1": 0, "l2": 2, "linf": 3} @@ -165,33 +188,32 @@ def norm(v: firedrake.Function, norm_type: str = "L2", **kwargs) -> float: return firedrake.assemble(condition * integrand ** (p / 2) * dX) ** (1 / p) -@PETSc.Log.EventDecorator("goalie.errornorm") -def errornorm(u, uh: firedrake.Function, norm_type: str = "L2", **kwargs) -> float: +@PETSc.Log.EventDecorator() +def errornorm(u, uh, norm_type="L2", **kwargs): r""" - Overload :func:`firedrake.norms.errornorm` - to allow for :math:`\ell^p` norms. + Overload :func:`firedrake.norms.errornorm` to allow for :math:`\ell^p` norms. - Note that this version is case sensitive, - i.e. ``'l2'`` and ``'L2'`` will give + Note that this version is case sensitive, i.e. ``'l2'`` and ``'L2'`` will give different results in general. :arg u: the 'true' value :arg uh: the approximation of the 'truth' - :kwarg norm_type: choose from ``'l1'``, ``'l2'``, - ``'linf'``, ``'L2'``, ``'Linf'``, ``'H1'``, - ``'Hdiv'``, ``'Hcurl'``, or any ``'Lp'`` with - :math:`p >= 1`. - :kwarg boundary: should the norm be computed over - the domain boundary? + :kwarg norm_type: choose from ``'l1'``, ``'l2'``, ``'linf'``, ``'L2'``, ``'Linf'``, + ``'H1'``, ``'Hdiv'``, ``'Hcurl'``, or any ``'Lp'`` with :math:`p >= 1`. + :kwarg boundary: should the norm be computed over the domain boundary? """ - if len(u.ufl_shape) != len(uh.ufl_shape): - raise RuntimeError("Mismatching rank between u and uh.") - + if isinstance(u, firedrake.Cofunction): + u = cofunction2function(u) + if isinstance(uh, firedrake.Cofunction): + uh = cofunction2function(uh) if not isinstance(uh, firedrake.Function): - raise TypeError(f"uh should be a Function, is a {type(uh).__name__}.") + raise TypeError(f"uh should be a Function, is a '{type(uh)}'.") if norm_type[0] == "l": if not isinstance(u, firedrake.Function): - raise TypeError(f"u should be a Function, is a {type(u).__name__}.") + raise TypeError(f"u should be a Function, is a '{type(u)}'.") + + if len(u.ufl_shape) != len(uh.ufl_shape): + raise RuntimeError("Mismatching rank between u and uh.") if isinstance(u, firedrake.Function): degree_u = u.function_space().ufl_element().degree() diff --git a/test/test_error_estimation.py b/test/test_error_estimation.py index a6c4dc10..465f1598 100644 --- a/test/test_error_estimation.py +++ b/test/test_error_estimation.py @@ -51,7 +51,7 @@ def test_cell_integral(self): F = conditional(x + y < 1, 1, 0) * dx indicator = form2indicator(F) self.assertAlmostEqual(indicator.dat.data[0], 0) - self.assertAlmostEqual(indicator.dat.data[1], 0.5) + self.assertAlmostEqual(indicator.dat.data[1], 1) class TestIndicators2Estimator(ErrorEstimationTestCase): @@ -129,7 +129,7 @@ def test_unit_time_instant(self): time_instant = TimeInstant("field", time=1.0) indicator = form2indicator(self.one * dx) estimator = indicators2estimator({"field": [[indicator]]}, time_instant) - self.assertAlmostEqual(estimator, 1.0) + self.assertAlmostEqual(estimator, 2.0) # 1 * (1 + 1) @parameterized.expand([[False], [True]]) def test_unit_time_instant_abs(self, absolute_value): @@ -138,19 +138,21 @@ def test_unit_time_instant_abs(self, absolute_value): estimator = indicators2estimator( {"field": [[indicator]]}, time_instant, absolute_value=absolute_value ) - self.assertAlmostEqual(estimator, 1.0 if absolute_value else -1.0) + self.assertAlmostEqual( + estimator, 2.0 if absolute_value else -2.0 + ) # (-)1 * (1 + 1) def test_half_time_instant(self): time_instant = TimeInstant("field", time=0.5) indicator = form2indicator(self.one * dx) estimator = indicators2estimator({"field": [[indicator]]}, time_instant) - self.assertAlmostEqual(estimator, 0.5) + self.assertAlmostEqual(estimator, 1.0) # 0.5 * (1 + 1) def test_time_partition_same_timestep(self): time_partition = TimePartition(1.0, 2, [0.5, 0.5], ["field"]) indicator = form2indicator(self.one * dx) estimator = indicators2estimator({"field": [2 * [indicator]]}, time_partition) - self.assertAlmostEqual(estimator, 1.0) + self.assertAlmostEqual(estimator, 2.0) # 0.5 * (1 + 1) + 0.5 * (1 + 1) def test_time_partition_different_timesteps(self): time_partition = TimePartition(1.0, 2, [0.5, 0.25], ["field"]) @@ -158,7 +160,7 @@ def test_time_partition_different_timesteps(self): estimator = indicators2estimator( {"field": [[indicator], 2 * [indicator]]}, time_partition ) - self.assertAlmostEqual(estimator, 1.0) + self.assertAlmostEqual(estimator, 2.0) # 0.5 * (1 + 1) + 0.25 * 2 * (1 + 1) def test_time_instant_multiple_fields(self): time_instant = TimeInstant(["field1", "field2"], time=1.0) @@ -166,7 +168,7 @@ def test_time_instant_multiple_fields(self): estimator = indicators2estimator( {"field1": [[indicator]], "field2": [[indicator]]}, time_instant ) - self.assertAlmostEqual(estimator, 2.0) + self.assertAlmostEqual(estimator, 4.0) # 2 * (1 * (1 + 1)) class TestGetDWRIndicator(ErrorEstimationTestCase): @@ -245,32 +247,32 @@ def test_convert_neither(self): adjoint_error = {"field": self.two} test_space = {"field": self.one.function_space()} indicator = get_dwr_indicator(self.F, adjoint_error, test_space=test_space) - self.assertAlmostEqual(indicator.dat.data[0], 1) - self.assertAlmostEqual(indicator.dat.data[1], 1) + self.assertAlmostEqual(indicator.dat.data[0], 2) + self.assertAlmostEqual(indicator.dat.data[1], 2) def test_convert_both(self): test_space = self.one.function_space() indicator = get_dwr_indicator(self.F, self.two, test_space=test_space) - self.assertAlmostEqual(indicator.dat.data[0], 1) - self.assertAlmostEqual(indicator.dat.data[1], 1) + self.assertAlmostEqual(indicator.dat.data[0], 2) + self.assertAlmostEqual(indicator.dat.data[1], 2) def test_convert_test_space(self): adjoint_error = {"field": self.two} test_space = self.one.function_space() indicator = get_dwr_indicator(self.F, adjoint_error, test_space=test_space) - self.assertAlmostEqual(indicator.dat.data[0], 1) - self.assertAlmostEqual(indicator.dat.data[1], 1) + self.assertAlmostEqual(indicator.dat.data[0], 2) + self.assertAlmostEqual(indicator.dat.data[1], 2) def test_convert_adjoint_error(self): test_space = {"Dos": self.one.function_space()} indicator = get_dwr_indicator(self.F, self.two, test_space=test_space) - self.assertAlmostEqual(indicator.dat.data[0], 1) - self.assertAlmostEqual(indicator.dat.data[1], 1) + self.assertAlmostEqual(indicator.dat.data[0], 2) + self.assertAlmostEqual(indicator.dat.data[1], 2) def test_convert_adjoint_error_no_test_space(self): indicator = get_dwr_indicator(self.F, self.two) - self.assertAlmostEqual(indicator.dat.data[0], 1) - self.assertAlmostEqual(indicator.dat.data[1], 1) + self.assertAlmostEqual(indicator.dat.data[0], 2) + self.assertAlmostEqual(indicator.dat.data[1], 2) def test_convert_adjoint_error_mismatch(self): test_space = {"field": self.one.function_space()} diff --git a/test/test_interpolation.py b/test/test_interpolation.py index 1e0400fc..5d32bdc0 100644 --- a/test/test_interpolation.py +++ b/test/test_interpolation.py @@ -23,10 +23,9 @@ def sinusoid(self, source=True): def test_notimplemented_error(self): Vs = FunctionSpace(self.source_mesh, "CG", 1) Vt = FunctionSpace(self.target_mesh, "CG", 1) - source = Function(Vs) with self.assertRaises(NotImplementedError) as cm: - project(2 * source, Vt) - msg = "Can only currently project Functions." + project(2 * Function(Vs), Vt) + msg = "Can only currently project Functions and Cofunctions." self.assertEqual(str(cm.exception), msg) @parameterized.expand([[False], [True]]) diff --git a/test/test_mesh_seq.py b/test/test_mesh_seq.py index e0f1d03d..981f9561 100644 --- a/test/test_mesh_seq.py +++ b/test/test_mesh_seq.py @@ -159,7 +159,7 @@ def setUp(self): self.time_interval, self.mesh, get_function_spaces=self.get_p0_spaces ) - @patch("dolfin_adjoint_common.blocks.solving.GenericSolveBlock") + @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_output_not_function(self, MockSolveBlock): solve_block = MockSolveBlock() block_variable = BlockVariable(1) @@ -169,7 +169,7 @@ def test_output_not_function(self, MockSolveBlock): msg = "Solve block for field 'field' on subinterval 0 has no outputs." self.assertEqual(str(cm.exception), msg) - @patch("dolfin_adjoint_common.blocks.solving.GenericSolveBlock") + @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_output_wrong_function_space(self, MockSolveBlock): solve_block = MockSolveBlock() block_variable = BlockVariable(Function(FunctionSpace(self.mesh, "CG", 1))) @@ -179,7 +179,7 @@ def test_output_wrong_function_space(self, MockSolveBlock): msg = "Solve block for field 'field' on subinterval 0 has no outputs." self.assertEqual(str(cm.exception), msg) - @patch("dolfin_adjoint_common.blocks.solving.GenericSolveBlock") + @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_output_wrong_name(self, MockSolveBlock): solve_block = MockSolveBlock() function_space = FunctionSpace(self.mesh, "DG", 0) @@ -190,7 +190,7 @@ def test_output_wrong_name(self, MockSolveBlock): msg = "Solve block for field 'field' on subinterval 0 has no outputs." self.assertEqual(str(cm.exception), msg) - @patch("dolfin_adjoint_common.blocks.solving.GenericSolveBlock") + @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_output_valid(self, MockSolveBlock): solve_block = MockSolveBlock() function_space = FunctionSpace(self.mesh, "DG", 0) @@ -198,7 +198,7 @@ def test_output_valid(self, MockSolveBlock): solve_block._outputs = [block_variable] self.assertIsNotNone(self.mesh_seq._output("field", 0, solve_block)) - @patch("dolfin_adjoint_common.blocks.solving.GenericSolveBlock") + @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_output_multiple_valid_error(self, MockSolveBlock): solve_block = MockSolveBlock() function_space = FunctionSpace(self.mesh, "DG", 0) @@ -212,7 +212,7 @@ def test_output_multiple_valid_error(self, MockSolveBlock): ) self.assertEqual(str(cm.exception), msg) - @patch("dolfin_adjoint_common.blocks.solving.GenericSolveBlock") + @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_dependency_not_function(self, MockSolveBlock): solve_block = MockSolveBlock() block_variable = BlockVariable(1) @@ -222,7 +222,7 @@ def test_dependency_not_function(self, MockSolveBlock): msg = "Solve block for field 'field' on subinterval 0 has no dependencies." self.assertEqual(str(cm.exception), msg) - @patch("dolfin_adjoint_common.blocks.solving.GenericSolveBlock") + @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_dependency_wrong_function_space(self, MockSolveBlock): solve_block = MockSolveBlock() block_variable = BlockVariable(Function(FunctionSpace(self.mesh, "CG", 1))) @@ -232,7 +232,7 @@ def test_dependency_wrong_function_space(self, MockSolveBlock): msg = "Solve block for field 'field' on subinterval 0 has no dependencies." self.assertEqual(str(cm.exception), msg) - @patch("dolfin_adjoint_common.blocks.solving.GenericSolveBlock") + @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_dependency_wrong_name(self, MockSolveBlock): solve_block = MockSolveBlock() function_space = FunctionSpace(self.mesh, "DG", 0) @@ -243,7 +243,7 @@ def test_dependency_wrong_name(self, MockSolveBlock): msg = "Solve block for field 'field' on subinterval 0 has no dependencies." self.assertEqual(str(cm.exception), msg) - @patch("dolfin_adjoint_common.blocks.solving.GenericSolveBlock") + @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_dependency_valid(self, MockSolveBlock): solve_block = MockSolveBlock() function_space = FunctionSpace(self.mesh, "DG", 0) @@ -251,7 +251,7 @@ def test_dependency_valid(self, MockSolveBlock): solve_block._dependencies = [block_variable] self.assertIsNotNone(self.mesh_seq._dependency("field", 0, solve_block)) - @patch("dolfin_adjoint_common.blocks.solving.GenericSolveBlock") + @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_dependency_multiple_valid_error(self, MockSolveBlock): solve_block = MockSolveBlock() function_space = FunctionSpace(self.mesh, "DG", 0) @@ -265,7 +265,7 @@ def test_dependency_multiple_valid_error(self, MockSolveBlock): ) self.assertEqual(str(cm.exception), msg) - @patch("dolfin_adjoint_common.blocks.solving.GenericSolveBlock") + @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_dependency_steady(self, MockSolveBlock): time_interval = TimeInterval(1.0, 0.5, "field", field_types="steady") mesh_seq = MeshSeq( diff --git a/test/test_options.py b/test/test_options.py index f8ef7d74..4543537d 100644 --- a/test/test_options.py +++ b/test/test_options.py @@ -1,5 +1,5 @@ from firedrake import TensorFunctionSpace -from firedrake.meshadapt import RiemannianMetric +from animate.adapt import RiemannianMetric from goalie.options import * from utility import uniform_mesh import unittest diff --git a/test/test_utility.py b/test/test_utility.py index 1c4a45c5..233e6642 100644 --- a/test/test_utility.py +++ b/test/test_utility.py @@ -194,14 +194,14 @@ def test_shape_error(self): def test_not_function_error(self): with self.assertRaises(TypeError) as cm: - errornorm(self.f, Constant(1.0)) - msg = "uh should be a Function, is a Constant." + errornorm(self.f, 1.0) + msg = "uh should be a Function, is a ''." self.assertEqual(str(cm.exception), msg) def test_not_function_error_lp(self): with self.assertRaises(TypeError) as cm: - errornorm(Constant(1.0), self.f, norm_type="l1") - msg = "u should be a Function, is a Constant." + errornorm(1.0, self.f, norm_type="l1") + msg = "u should be a Function, is a ''." self.assertEqual(str(cm.exception), msg) def test_mixed_space_invalid_norm_error(self): diff --git a/test_adjoint/conftest.py b/test_adjoint/conftest.py index e2059478..aca2a714 100644 --- a/test_adjoint/conftest.py +++ b/test_adjoint/conftest.py @@ -128,37 +128,6 @@ def fin(): request.addfinalizer(fin) -@pytest.fixture(autouse=True) -def handle_taping(): - """ - **Disclaimer: copied from - firedrake/tests/regression/test_adjoint_operators.py - """ - yield - import pyadjoint - - tape = pyadjoint.get_working_tape() - tape.clear_tape() - if not pyadjoint.annotate_tape(): - pyadjoint.continue_annotation() - - -@pytest.fixture(autouse=True, scope="module") -def handle_exit_annotation(): - """ - Since importing firedrake_adjoint modifies a global variable, we need to - pause annotations at the end of the module. - - **Disclaimer: copied from - firedrake/tests/regression/test_adjoint_operators.py - """ - yield - import pyadjoint - - if pyadjoint.annotate_tape(): - pyadjoint.pause_annotation() - - def pytest_runtest_teardown(item, nextitem): """ Clear caches after running a test diff --git a/test_adjoint/examples/burgers.py b/test_adjoint/examples/burgers.py index e9630a59..fea98b83 100644 --- a/test_adjoint/examples/burgers.py +++ b/test_adjoint/examples/burgers.py @@ -39,8 +39,9 @@ def form(i, sols): u, u_ = sols["uv_2d"] dt = self.time_partition[i].timestep fs = self.function_spaces["uv_2d"][i] - dtc = Constant(dt) - nu = Constant(0.0001) + R = FunctionSpace(self[i], "R", 0) + dtc = Function(R).assign(dt) + nu = Function(R).assign(0.0001) v = TestFunction(fs) F = ( inner((u - u_) / dtc, v) * dx @@ -104,7 +105,8 @@ def get_qoi(self, sol, i): norm over the right hand boundary. """ - dtc = Constant(self.time_partition[i].timestep) + R = FunctionSpace(self[i], "R", 0) + dtc = Function(R).assign(self.time_partition[i].timestep) def time_integrated_qoi(t): u = sol["uv_2d"] diff --git a/test_adjoint/examples/point_discharge2d.py b/test_adjoint/examples/point_discharge2d.py index e8492bbe..57d91098 100644 --- a/test_adjoint/examples/point_discharge2d.py +++ b/test_adjoint/examples/point_discharge2d.py @@ -52,15 +52,17 @@ def source(mesh): def get_form(self): """ - Advection-diffusion with SUPG - stabilisation. + Advection-diffusion with SUPG stabilisation. """ def form(i, sols): c, c_ = sols["tracer_2d"] fs = self.function_spaces["tracer_2d"][i] - D = Constant(0.1) - u = Constant(as_vector([1.0, 0.0])) + R = FunctionSpace(self[i], "R", 0) + D = Function(R).assign(0.1) + u_x = Function(R).assign(1.0) + u_y = Function(R).assign(0.0) + u = as_vector([u_x, u_y]) h = CellSize(self[i]) S = source(self[i]) @@ -151,8 +153,9 @@ def analytical_solution(mesh): a given mesh. See [Riadh et al. 2014]. """ x, y = SpatialCoordinate(mesh) - u = Constant(1.0) - D = Constant(0.1) + R = FunctionSpace(mesh, "R", 0) + u = Function(R).assign(1.0) + D = Function(R).assign(0.1) Pe = 0.5 * u / D r = max_value(sqrt((x - src_x) ** 2 + (y - src_y) ** 2), src_r) return 0.5 / (pi * D) * exp(Pe * (x - src_x)) * bessk0(Pe * r) diff --git a/test_adjoint/examples/point_discharge3d.py b/test_adjoint/examples/point_discharge3d.py index 20f893bd..e9124f31 100644 --- a/test_adjoint/examples/point_discharge3d.py +++ b/test_adjoint/examples/point_discharge3d.py @@ -70,8 +70,12 @@ def get_form(self): def form(i, sols): c, c_ = sols["tracer_3d"] fs = self.function_spaces["tracer_3d"][i] - D = Constant(0.1) - u = Constant(as_vector([1.0, 0.0, 0.0])) + R = FunctionSpace(self[i], "R", 0) + D = Function(R).assign(0.1) + u_x = Function(R).assign(1.0) + u_y = Function(R).assign(0.0) + u_z = Function(R).assign(0.0) + u = as_vector([u_x, u_y, u_z]) h = CellSize(self[i]) S = source(self[i]) @@ -164,8 +168,9 @@ def analytical_solution(mesh): a given mesh. """ x, y, z = SpatialCoordinate(mesh) - u = Constant(1.0) - D = Constant(0.1) + R = FunctionSpace(mesh, "R", 0) + u = Function(R).assign(1.0) + D = Function(R).assign(0.1) Pe = 0.5 * u / D r = max_value(sqrt((x - src_x) ** 2 + (y - src_y) ** 2 + (z - src_z) ** 2), src_r) return 0.5 / (pi * D) * exp(Pe * (x - src_x)) * bessk0(Pe * r) diff --git a/test_adjoint/examples/steady_flow_past_cyl.py b/test_adjoint/examples/steady_flow_past_cyl.py index 927793a3..4f77ec59 100644 --- a/test_adjoint/examples/steady_flow_past_cyl.py +++ b/test_adjoint/examples/steady_flow_past_cyl.py @@ -38,7 +38,8 @@ def get_form(self): def form(i, sols): up, up_ = sols["up"] W = self.function_spaces["up"][i] - nu = Constant(1.0) + R = FunctionSpace(self[i], "R", 0) + nu = Function(R).assign(1.0) u, p = split(up) v, q = TestFunctions(W) F = ( diff --git a/test_adjoint/setup_adjoint_tests.py b/test_adjoint/setup_adjoint_tests.py new file mode 100644 index 00000000..52b59470 --- /dev/null +++ b/test_adjoint/setup_adjoint_tests.py @@ -0,0 +1,29 @@ +import pyadjoint +import pytest + + +@pytest.fixture(autouse=True) +def handle_taping(): + """ + **Disclaimer: copied from + firedrake/tests/regression/test_adjoint_operators.py + """ + yield + tape = pyadjoint.get_working_tape() + tape.clear_tape() + + +@pytest.fixture(autouse=True, scope="module") +def handle_annotation(): + """ + Since importing firedrake-adjoint modifies a global variable, we need to + pause annotations at the end of the module. + + **Disclaimer: copied from + firedrake/tests/regression/test_adjoint_operators.py + """ + if not pyadjoint.annotate_tape(): + pyadjoint.continue_annotation() + yield + if pyadjoint.annotate_tape(): + pyadjoint.pause_annotation() diff --git a/test_adjoint/test_adjoint.py b/test_adjoint/test_adjoint.py index a538a839..dbd078ef 100644 --- a/test_adjoint/test_adjoint.py +++ b/test_adjoint/test_adjoint.py @@ -3,6 +3,7 @@ """ from firedrake import * from goalie_adjoint import * +import pyadjoint import pytest import importlib import os @@ -96,9 +97,6 @@ def test_adjoint_same_mesh(problem, qoi_type, debug=False): or as a time integral? :kwarg debug: toggle debugging mode """ - from firedrake_adjoint import pyadjoint - - # Debugging if debug: set_log_level(DEBUG) @@ -137,16 +135,20 @@ def test_adjoint_same_mesh(problem, qoi_type, debug=False): # Solve forward and adjoint without solve_adjoint pyrint("\n--- Adjoint solve on 1 subinterval using pyadjoint\n") + if not pyadjoint.annotate_tape(): + pyadjoint.continue_annotation() + tape = pyadjoint.get_working_tape() + tape.clear_tape() ic = mesh_seq.initial_condition controls = [pyadjoint.Control(value) for key, value in ic.items()] sols = mesh_seq.solver(0, ic) qoi = mesh_seq.get_qoi(sols, 0) J = mesh_seq.J if qoi_type == "time_integrated" else qoi() m = pyadjoint.enlisting.Enlist(controls) - tape = pyadjoint.get_working_tape() - with pyadjoint.stop_annotating(): - with tape.marked_nodes(m): - tape.evaluate_adj(markings=True) + assert pyadjoint.annotate_tape() + pyadjoint.pause_annotation() + with tape.marked_nodes(m): + tape.evaluate_adj(markings=True) # FIXME: Using mixed Functions as Controls not correct J_expected = float(J) @@ -158,7 +160,8 @@ def test_adjoint_same_mesh(problem, qoi_type, debug=False): adj_sols_expected[field] = solve_blocks[0].adj_sol.copy(deepcopy=True) if not steady: dep = mesh_seq._dependency(field, 0, solve_blocks[0]) - adj_values_expected[field] = Function(fs[0], val=dep.adj_value) + adj_values_expected[field] = Cofunction(fs[0].dual()) + adj_values_expected[field].assign(dep.adj_value) # Loop over having one or two subintervals for N in range(1, 2 if steady else 3): @@ -218,6 +221,9 @@ def test_adjoint_same_mesh(problem, qoi_type, debug=False): f"Adjoint values do not match at t=0 (error {err:.4e}.)" ) + tape = pyadjoint.get_working_tape() + tape.clear_tape() + def plot_solutions(problem, qoi_type, debug=True): """ diff --git a/test_adjoint/test_fp_iteration.py b/test_adjoint/test_fp_iteration.py index ff0e5f29..f3147722 100644 --- a/test_adjoint/test_fp_iteration.py +++ b/test_adjoint/test_fp_iteration.py @@ -30,11 +30,10 @@ def solver(index, ic): def get_qoi(mesh_seq, solutions, index): + R = FunctionSpace(mesh_seq[index], "R", 0) + def qoi(): - if mesh_seq.fp_iteration % 2 == 0: - return Constant(1, domain=mesh_seq[index]) * dx - else: - return Constant(2, domain=mesh_seq[index]) * dx + return Function(R).assign(1 if mesh_seq.fp_iteration % 2 == 0 else 2) * dx return qoi @@ -106,7 +105,9 @@ def adaptor(mesh_seq, sols): expected = [[1, 1], [1, 2], [1, 1], [1, 2], [1, 1], [1, 2]] self.assertEqual(mesh_seq.element_counts, expected) self.assertTrue(np.allclose(mesh_seq.converged, [True, False])) - self.assertTrue(np.allclose(mesh_seq.check_convergence, [not drop_out_converged, True])) + self.assertTrue( + np.allclose(mesh_seq.check_convergence, [not drop_out_converged, True]) + ) def test_no_late_convergence(self): mesh1 = UnitSquareMesh(1, 1) @@ -186,7 +187,9 @@ def test_dropout(self, drop_out_converged): time_partition = TimePartition(1.0, 2, [0.5, 0.5], []) ap = GoalOrientedParameters(self.parameters) ap.update({"drop_out_converged": drop_out_converged}) - mesh_seq = self.mesh_seq(time_partition, mesh2, parameters=ap, qoi_type="end_time") + mesh_seq = self.mesh_seq( + time_partition, mesh2, parameters=ap, qoi_type="end_time" + ) def adaptor(mesh_seq, sols, indicators): mesh_seq[1] = mesh1 if mesh_seq.fp_iteration % 2 == 0 else mesh2 @@ -196,7 +199,9 @@ def adaptor(mesh_seq, sols, indicators): expected = [[1, 1], [1, 2], [1, 1], [1, 2], [1, 1], [1, 2]] self.assertEqual(mesh_seq.element_counts, expected) self.assertTrue(np.allclose(mesh_seq.converged, [True, False])) - self.assertTrue(np.allclose(mesh_seq.check_convergence, [not drop_out_converged, True])) + self.assertTrue( + np.allclose(mesh_seq.check_convergence, [not drop_out_converged, True]) + ) def test_no_late_convergence(self): mesh1 = UnitSquareMesh(1, 1) @@ -204,7 +209,9 @@ def test_no_late_convergence(self): time_partition = TimePartition(1.0, 2, [0.5, 0.5], []) ap = GoalOrientedParameters(self.parameters) ap.update({"drop_out_converged": True}) - mesh_seq = self.mesh_seq(time_partition, mesh2, parameters=ap, qoi_type="end_time") + mesh_seq = self.mesh_seq( + time_partition, mesh2, parameters=ap, qoi_type="end_time" + ) def adaptor(mesh_seq, sols, indicators): mesh_seq[0] = mesh1 if mesh_seq.fp_iteration % 2 == 0 else mesh2 @@ -221,7 +228,9 @@ def test_convergence_criteria_all(self): time_partition = TimePartition(1.0, 1, 0.5, []) ap = GoalOrientedParameters(self.parameters) ap.update({"convergence_criteria": "all"}) - mesh_seq = self.mesh_seq(time_partition, mesh, parameters=ap, qoi_type="end_time") + mesh_seq = self.mesh_seq( + time_partition, mesh, parameters=ap, qoi_type="end_time" + ) def adaptor(mesh_seq, sols, indicators): return [False] diff --git a/test_adjoint/test_mesh_seq.py b/test_adjoint/test_mesh_seq.py index 6b500168..093c4fae 100644 --- a/test_adjoint/test_mesh_seq.py +++ b/test_adjoint/test_mesh_seq.py @@ -6,10 +6,9 @@ from goalie.log import * from goalie.mesh_seq import MeshSeq from goalie.time_partition import TimeInterval -import pyadjoint import logging -import pytest import unittest +from setup_adjoint_tests import * class TestGetSolveBlocks(unittest.TestCase): diff --git a/test_adjoint/test_utils.py b/test_adjoint/test_utils.py index 58a81d81..bbe65431 100644 --- a/test_adjoint/test_utils.py +++ b/test_adjoint/test_utils.py @@ -3,6 +3,7 @@ from goalie.adjoint import annotate_qoi import numpy as np import unittest +from setup_adjoint_tests import * class TestAdjointUtils(unittest.TestCase): @@ -20,8 +21,10 @@ def mesh_seq(self, qoi_type="end_time"): def test_annotate_qoi_0args(self): @annotate_qoi def get_qoi(mesh_seq, solution_map, i): + R = FunctionSpace(mesh_seq[i], "R", 0) + def qoi(): - return Constant(1.0, domain=mesh_seq[i]) * dx + return Function(R).assign(1) * dx return qoi @@ -30,8 +33,10 @@ def qoi(): def test_annotate_qoi_1arg(self): @annotate_qoi def get_qoi(mesh_seq, solution_map, i): + R = FunctionSpace(mesh_seq[i], "R", 0) + def qoi(t): - return Constant(1.0, domain=mesh_seq[i]) * dx + return Function(R).assign(1) * dx return qoi @@ -40,8 +45,10 @@ def qoi(t): def test_annotate_qoi_0args_error(self): @annotate_qoi def get_qoi(mesh_seq, solution_map, i): + R = FunctionSpace(mesh_seq[i], "R", 0) + def qoi(): - return Constant(1.0, domain=mesh_seq[i]) * dx + return Function(R).assign(1) * dx return qoi @@ -53,8 +60,10 @@ def qoi(): def test_annotate_qoi_1arg_error(self): @annotate_qoi def get_qoi(mesh_seq, solution_map, i): + R = FunctionSpace(mesh_seq[i], "R", 0) + def qoi(t): - return Constant(1.0, domain=mesh_seq[i]) * dx + return Function(R).assign(1) * dx return qoi @@ -66,8 +75,10 @@ def qoi(t): def test_annotate_qoi_2args_error(self): @annotate_qoi def get_qoi(mesh_seq, solution_map, i): + R = FunctionSpace(mesh_seq[i], "R", 0) + def qoi(t, r): - return Constant(1.0, domain=mesh_seq[i]) * dx + return Function(R).assign(1) * dx return qoi