Skip to content

Commit

Permalink
Merge pull request #69 from pyroteus/68_go_test
Browse files Browse the repository at this point in the history
Testing for goal-oriented error estimation code
  • Loading branch information
jwallwork23 authored Dec 8, 2023
2 parents 1c1ed07 + da89b50 commit ea8787f
Show file tree
Hide file tree
Showing 10 changed files with 536 additions and 388 deletions.
50 changes: 0 additions & 50 deletions goalie/error_estimation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
"""
Tools to automate goal-oriented error estimation.
"""
from .time_partition import TimePartition
import firedrake
from firedrake import Function, FunctionSpace
from firedrake.functionspaceimpl import WithGeometry
from firedrake.petsc import PETSc
import ufl
from collections.abc import Iterable
from typing import Dict, Optional, Union


Expand Down Expand Up @@ -60,54 +58,6 @@ def form2indicator(F: ufl.form.Form) -> Function:
return indicator


@PETSc.Log.EventDecorator()
def indicators2estimator(
indicators: Iterable, time_partition: TimePartition, absolute_value: bool = False
) -> float:
r"""
Deduce the error estimator value associated with error indicator fields defined over
a :class:`~.MeshSeq`.
:arg indicators: the list of list of error indicator
:class:`firedrake.function.Function`\s
:arg time_partition: the :class:`~.TimePartition` instance for the problem being
solved
:kwarg absolute_value: toggle whether to take the modulus on each element
"""
if not isinstance(indicators, dict):
raise TypeError(
f"Expected 'indicators' to be a dict, not '{type(indicators)}'."
)
if not isinstance(time_partition, TimePartition):
raise TypeError(
f"Expected 'time_partition' to be a TimePartition, not '{type(time_partition)}'."
)
if not isinstance(absolute_value, bool):
raise TypeError(
f"Expected 'absolute_value' to be a bool, not '{type(absolute_value)}'."
)
estimator = 0
for field, by_field in indicators.items():
if field not in time_partition.fields:
raise ValueError(
f"Key '{field}' does not exist in the TimePartition provided."
)
if isinstance(by_field, Function) or not isinstance(by_field, Iterable):
raise TypeError(
f"Expected values of 'indicators' to be iterables, not '{type(by_field)}'."
)
for by_mesh, dt in zip(by_field, time_partition.timesteps):
if isinstance(by_mesh, Function) or not isinstance(by_mesh, Iterable):
raise TypeError(
f"Expected entries of 'indicators' to be iterables, not '{type(by_mesh)}'."
)
for indicator in by_mesh:
if absolute_value:
indicator.interpolate(abs(indicator))
estimator += dt * indicator.vector().gather().sum()
return estimator


@PETSc.Log.EventDecorator()
def get_dwr_indicator(
F, adjoint_error: Function, test_space: Optional[Union[WithGeometry, Dict]] = None
Expand Down
135 changes: 73 additions & 62 deletions goalie/go_mesh_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
Drivers for goal-oriented error estimation on sequences of meshes.
"""
from .adjoint import AdjointMeshSeq
from .error_estimation import get_dwr_indicator, indicators2estimator
from .error_estimation import get_dwr_indicator
from .log import pyrint
from .utility import AttrDict
from firedrake import Function, FunctionSpace, MeshHierarchy, TransferManager, project
from firedrake.petsc import PETSc
from collections.abc import Callable
from collections.abc import Callable, Iterable
import numpy as np
from typing import Tuple
import ufl
Expand All @@ -18,33 +18,31 @@

class GoalOrientedMeshSeq(AdjointMeshSeq):
"""
An extension of :class:`~.AdjointMeshSeq` to account for
goal-oriented problems.
An extension of :class:`~.AdjointMeshSeq` to account for goal-oriented problems.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.estimator_values = []

@PETSc.Log.EventDecorator("goalie.GoalOrientedMeshSeq.get_enriched_mesh_seq")
@PETSc.Log.EventDecorator()
def get_enriched_mesh_seq(
self, enrichment_method: str = "p", num_enrichments: int = 1
) -> AdjointMeshSeq:
"""
Construct a sequence of globally enriched spaces.
Currently, global enrichment may be
achieved using one of:
Currently, global enrichment may be achieved using one of:
* h-refinement (``enrichment_method = 'h'``);
* p-refinement (``enrichment_method = 'p'``).
The number of refinements may be controlled by
the keyword argument ``num_enrichments``.
The number of refinements may be controlled by the keyword argument
``num_enrichments``.
"""
if enrichment_method not in ("h", "p"):
raise ValueError(f"Enrichment method {enrichment_method} not supported")
raise ValueError(f"Enrichment method '{enrichment_method}' not supported.")
if num_enrichments <= 0:
raise ValueError("A positive number of enrichments is required")
raise ValueError("A positive number of enrichments is required.")

# Apply h-refinement
if enrichment_method == "h":
Expand Down Expand Up @@ -79,63 +77,36 @@ def get_enriched_mesh_seq(

return mesh_seq_e

@PETSc.Log.EventDecorator("goalie.GoalOrientedMeshSeq.global_enrichment")
def global_enrichment(
self, enrichment_method: str = "p", num_enrichments: int = 1, **kwargs
) -> dict:
"""
Solve the forward and adjoint problems
associated with
:meth:`~.GoalOrientedMeshSeq.solver` in a
sequence of globally enriched spaces.
Currently, global enrichment may be
achieved using one of:
* h-refinement (``enrichment_method = 'h'``);
* p-refinement (``enrichment_method = 'p'``).
The number of refinements may be controlled by
the keyword argument ``num_enrichments``.
:kwarg kwargs: keyword arguments to pass to the
:meth:`~.AdjointMeshSeq.solve_adjoint` method
"""
mesh_seq = self.get_enriched_mesh_seq(
enrichment_method=enrichment_method,
num_enrichments=num_enrichments,
)
return mesh_seq.solve_adjoint(**kwargs)
@staticmethod
def _get_transfer_function(enrichment_method):
if enrichment_method == "h":
return TransferManager().prolong
else:
return lambda source, target: target.interpolate(source)

@PETSc.Log.EventDecorator("goalie.GoalOrientedMeshSeq.indicate_errors")
@PETSc.Log.EventDecorator()
def indicate_errors(
self,
enrichment_kwargs: dict = {},
adj_kwargs: dict = {},
indicator_fn: Callable = get_dwr_indicator,
) -> Tuple[dict, AttrDict]:
"""
Compute goal-oriented error indicators for each
subinterval based on solving the adjoint problem
in a globally enriched space.
:kwarg enrichment_kwargs: keyword arguments to pass
to the global enrichment method
:kwarg adj_kwargs: keyword arguments to pass to the
adjoint solver
:kwarg indicator_fn: function for error indication,
which takes the form, adjoint error and enriched
space(s) as arguments
"""
enrichment_method = enrichment_kwargs.get("enrichment_method", "p")
if enrichment_method == "h":
tm = TransferManager()
transfer = tm.prolong
else:

def transfer(source, target):
target.interpolate(source)
Compute goal-oriented error indicators for each subinterval based on solving the
adjoint problem in a globally enriched space.
:kwarg enrichment_kwargs: keyword arguments to pass to the global enrichment
method
:kwarg adj_kwargs: keyword arguments to pass to the adjoint solver
:kwarg indicator_fn: function for error indication, which takes the form,
adjoint error and enriched space(s) as arguments
"""
enrichment_kwargs.setdefault("enrichment_method", "p")
enrichment_kwargs.setdefault("num_enrichments", 1)
mesh_seq_e = self.get_enriched_mesh_seq(**enrichment_kwargs)
transfer = self._get_transfer_function(enrichment_kwargs["enrichment_method"])

# Solve the forward and adjoint problems on the MeshSeq and its enriched version
sols = self.solve_adjoint(**adj_kwargs)
sols_e = mesh_seq_e.solve_adjoint(**adj_kwargs)

Expand Down Expand Up @@ -184,8 +155,8 @@ def transfer(source, target):
forms = mesh_seq_e.form(i, mapping)
if not isinstance(forms, dict):
raise TypeError(
"The function defined by get_form should return a dictionary, not a"
f" {type(forms)}."
"The function defined by get_form should return a dictionary"
f", not type '{type(forms)}'."
)

# Loop over each strongly coupled field
Expand Down Expand Up @@ -215,6 +186,47 @@ def transfer(source, target):

return sols, indicators

@PETSc.Log.EventDecorator()
def indicators2estimator(
self, indicators: Iterable, absolute_value: bool = False
) -> float:
r"""
Deduce the error estimator value associated with error indicator fields defined over
a :class:`~.MeshSeq`.
:arg indicators: the list of list of error indicator
:class:`firedrake.function.Function`\s
:kwarg absolute_value: toggle whether to take the modulus on each element
"""
if not isinstance(indicators, dict):
raise TypeError(
f"Expected 'indicators' to be a dict, not '{type(indicators)}'."
)
if not isinstance(absolute_value, bool):
raise TypeError(
f"Expected 'absolute_value' to be a bool, not '{type(absolute_value)}'."
)
estimator = 0
for field, by_field in indicators.items():
if field not in self.time_partition.fields:
raise ValueError(
f"Key '{field}' does not exist in the TimePartition provided."
)
if isinstance(by_field, Function) or not isinstance(by_field, Iterable):
raise TypeError(
f"Expected values of 'indicators' to be iterables, not '{type(by_field)}'."
)
for by_mesh, dt in zip(by_field, self.time_partition.timesteps):
if isinstance(by_mesh, Function) or not isinstance(by_mesh, Iterable):
raise TypeError(
f"Expected entries of 'indicators' to be iterables, not '{type(by_mesh)}'."
)
for indicator in by_mesh:
if absolute_value:
indicator.interpolate(abs(indicator))
estimator += dt * indicator.vector().gather().sum()
return estimator

def check_estimator_convergence(self):
"""
Check for convergence of the fixed point iteration due to the relative
Expand Down Expand Up @@ -294,8 +306,7 @@ def fixed_point_iteration(
break

# Check for error estimator convergence
ee = indicators2estimator(indicators, self.time_partition)
self.estimator_values.append(ee)
self.estimator_values.append(self.indicators2estimator(indicators))
ee_converged = self.check_estimator_convergence()
if self.params.convergence_criteria == "any" and ee_converged:
self.converged[:] = True
Expand Down
Loading

0 comments on commit ea8787f

Please sign in to comment.