Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Testing for goal-oriented error estimation code #69

Merged
merged 42 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
252adb4
Reformat docstrings in go_mesh_seq
jwallwork23 Nov 28, 2023
1c9f8a0
Simplify EventDecorator call setup in go_mesh_seq
jwallwork23 Nov 28, 2023
8fef1b4
Move empty setup functions
jwallwork23 Nov 28, 2023
85dfafc
Add tests for h-enrichment
jwallwork23 Nov 28, 2023
d81508a
Add tests for p-enrichment
jwallwork23 Nov 28, 2023
4125b44
Add tests for value checking
jwallwork23 Nov 28, 2023
6b8fdbb
Create TrivialGoalOrientedBaseClass
jwallwork23 Nov 30, 2023
79828d5
Add tests for check_estimator_convergence
jwallwork23 Nov 30, 2023
75e1bfd
Lint
jwallwork23 Nov 30, 2023
9eaebf6
Add test_update_params for MeshSeq
jwallwork23 Nov 30, 2023
f46bef5
Create empty_adaptor for testing
jwallwork23 Nov 30, 2023
2c285e1
Apply Black formatter
jwallwork23 Nov 30, 2023
d874dfb
Add test_estimator_convergence_check_false
jwallwork23 Nov 30, 2023
0386800
Introduce MeshSeqBaseClass; give it test_update_params
jwallwork23 Nov 30, 2023
55518a6
Define mesh_seq abstractmethod; move test_convergence_noop
jwallwork23 Nov 30, 2023
6b3c3e3
Move test_noconvergence
jwallwork23 Nov 30, 2023
e5f2dc2
Move test_no_late_convergence
jwallwork23 Nov 30, 2023
5946022
Move test_dropout
jwallwork23 Nov 30, 2023
308746d
Add time_partition and mesh kwargs to simplify
jwallwork23 Nov 30, 2023
e177041
Move estimator convergence tests to test_fp_iteration
jwallwork23 Nov 30, 2023
48ae871
Move element convergence tests to test_fp_iteration
jwallwork23 Nov 30, 2023
e214d8c
Introduce TestAdjointMeshSeq
jwallwork23 Nov 30, 2023
2d61d50
Move qoi convergence tests to test_fp_iteration
jwallwork23 Nov 30, 2023
0f07023
Add test_qoi_convergence_check_false
jwallwork23 Nov 30, 2023
5ce4ea9
Pull out set_values and check_convergence methods
jwallwork23 Nov 30, 2023
ec7c21a
Combine convergence check tests
jwallwork23 Nov 30, 2023
c1b100c
Combine convergence_check_false tests
jwallwork23 Nov 30, 2023
49cd31b
Lint
jwallwork23 Nov 30, 2023
4ec4d6b
Move indicators2estimator to GoalOrientedMeshSeq; update tests
jwallwork23 Dec 3, 2023
7f0b31f
Further test coverage for convergence checks
jwallwork23 Dec 3, 2023
14b6c03
Avoid passing qoi_type kwarg in tests
jwallwork23 Dec 3, 2023
62bde66
Have test classes inherit default kwargs and setup
jwallwork23 Dec 3, 2023
1842ce0
Pull out oscillating_adaptors
jwallwork23 Dec 3, 2023
6b1d2d5
Avoid unneccessary inits of default mesh
jwallwork23 Dec 3, 2023
0f59522
Inline TimePartition inits
jwallwork23 Dec 3, 2023
caec891
Add test for convergence criteria any
jwallwork23 Dec 3, 2023
1821abc
Drop unused global_enrichment method
jwallwork23 Dec 3, 2023
0b29b26
Extract get_transfer_function method and test
jwallwork23 Dec 3, 2023
69ec413
Avoid use of empty functions for init
jwallwork23 Dec 3, 2023
38e4a1c
Simplify mesh_seq setup for testing
jwallwork23 Dec 3, 2023
3216fde
Add test for get_form output
jwallwork23 Dec 3, 2023
da89b50
Fix test_p_enrichment_mesh
jwallwork23 Dec 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 0 additions & 50 deletions goalie/error_estimation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
"""
Tools to automate goal-oriented error estimation.
"""
from .time_partition import TimePartition
import firedrake
from firedrake import Function, FunctionSpace
from firedrake.functionspaceimpl import WithGeometry
from firedrake.petsc import PETSc
import ufl
from collections.abc import Iterable
from typing import Dict, Optional, Union


Expand Down Expand Up @@ -60,54 +58,6 @@ def form2indicator(F: ufl.form.Form) -> Function:
return indicator


@PETSc.Log.EventDecorator()
def indicators2estimator(
indicators: Iterable, time_partition: TimePartition, absolute_value: bool = False
) -> float:
r"""
Deduce the error estimator value associated with error indicator fields defined over
a :class:`~.MeshSeq`.

:arg indicators: the list of list of error indicator
:class:`firedrake.function.Function`\s
:arg time_partition: the :class:`~.TimePartition` instance for the problem being
solved
:kwarg absolute_value: toggle whether to take the modulus on each element
"""
if not isinstance(indicators, dict):
raise TypeError(
f"Expected 'indicators' to be a dict, not '{type(indicators)}'."
)
if not isinstance(time_partition, TimePartition):
raise TypeError(
f"Expected 'time_partition' to be a TimePartition, not '{type(time_partition)}'."
)
if not isinstance(absolute_value, bool):
raise TypeError(
f"Expected 'absolute_value' to be a bool, not '{type(absolute_value)}'."
)
estimator = 0
for field, by_field in indicators.items():
if field not in time_partition.fields:
raise ValueError(
f"Key '{field}' does not exist in the TimePartition provided."
)
if isinstance(by_field, Function) or not isinstance(by_field, Iterable):
raise TypeError(
f"Expected values of 'indicators' to be iterables, not '{type(by_field)}'."
)
for by_mesh, dt in zip(by_field, time_partition.timesteps):
if isinstance(by_mesh, Function) or not isinstance(by_mesh, Iterable):
raise TypeError(
f"Expected entries of 'indicators' to be iterables, not '{type(by_mesh)}'."
)
for indicator in by_mesh:
if absolute_value:
indicator.interpolate(abs(indicator))
estimator += dt * indicator.vector().gather().sum()
return estimator


@PETSc.Log.EventDecorator()
def get_dwr_indicator(
F, adjoint_error: Function, test_space: Optional[Union[WithGeometry, Dict]] = None
Expand Down
135 changes: 73 additions & 62 deletions goalie/go_mesh_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
Drivers for goal-oriented error estimation on sequences of meshes.
"""
from .adjoint import AdjointMeshSeq
from .error_estimation import get_dwr_indicator, indicators2estimator
from .error_estimation import get_dwr_indicator
from .log import pyrint
from .utility import AttrDict
from firedrake import Function, FunctionSpace, MeshHierarchy, TransferManager, project
from firedrake.petsc import PETSc
from collections.abc import Callable
from collections.abc import Callable, Iterable
import numpy as np
from typing import Tuple
import ufl
Expand All @@ -18,33 +18,31 @@

class GoalOrientedMeshSeq(AdjointMeshSeq):
"""
An extension of :class:`~.AdjointMeshSeq` to account for
goal-oriented problems.
An extension of :class:`~.AdjointMeshSeq` to account for goal-oriented problems.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.estimator_values = []

@PETSc.Log.EventDecorator("goalie.GoalOrientedMeshSeq.get_enriched_mesh_seq")
@PETSc.Log.EventDecorator()
def get_enriched_mesh_seq(
self, enrichment_method: str = "p", num_enrichments: int = 1
) -> AdjointMeshSeq:
"""
Construct a sequence of globally enriched spaces.

Currently, global enrichment may be
achieved using one of:
Currently, global enrichment may be achieved using one of:
* h-refinement (``enrichment_method = 'h'``);
* p-refinement (``enrichment_method = 'p'``).

The number of refinements may be controlled by
the keyword argument ``num_enrichments``.
The number of refinements may be controlled by the keyword argument
``num_enrichments``.
"""
if enrichment_method not in ("h", "p"):
raise ValueError(f"Enrichment method {enrichment_method} not supported")
raise ValueError(f"Enrichment method '{enrichment_method}' not supported.")
if num_enrichments <= 0:
raise ValueError("A positive number of enrichments is required")
raise ValueError("A positive number of enrichments is required.")

# Apply h-refinement
if enrichment_method == "h":
Expand Down Expand Up @@ -79,63 +77,36 @@ def get_enriched_mesh_seq(

return mesh_seq_e

@PETSc.Log.EventDecorator("goalie.GoalOrientedMeshSeq.global_enrichment")
def global_enrichment(
self, enrichment_method: str = "p", num_enrichments: int = 1, **kwargs
) -> dict:
"""
Solve the forward and adjoint problems
associated with
:meth:`~.GoalOrientedMeshSeq.solver` in a
sequence of globally enriched spaces.

Currently, global enrichment may be
achieved using one of:
* h-refinement (``enrichment_method = 'h'``);
* p-refinement (``enrichment_method = 'p'``).

The number of refinements may be controlled by
the keyword argument ``num_enrichments``.

:kwarg kwargs: keyword arguments to pass to the
:meth:`~.AdjointMeshSeq.solve_adjoint` method
"""
mesh_seq = self.get_enriched_mesh_seq(
enrichment_method=enrichment_method,
num_enrichments=num_enrichments,
)
return mesh_seq.solve_adjoint(**kwargs)
@staticmethod
def _get_transfer_function(enrichment_method):
if enrichment_method == "h":
return TransferManager().prolong
else:
return lambda source, target: target.interpolate(source)

@PETSc.Log.EventDecorator("goalie.GoalOrientedMeshSeq.indicate_errors")
@PETSc.Log.EventDecorator()
def indicate_errors(
self,
enrichment_kwargs: dict = {},
adj_kwargs: dict = {},
indicator_fn: Callable = get_dwr_indicator,
) -> Tuple[dict, AttrDict]:
"""
Compute goal-oriented error indicators for each
subinterval based on solving the adjoint problem
in a globally enriched space.

:kwarg enrichment_kwargs: keyword arguments to pass
to the global enrichment method
:kwarg adj_kwargs: keyword arguments to pass to the
adjoint solver
:kwarg indicator_fn: function for error indication,
which takes the form, adjoint error and enriched
space(s) as arguments
"""
enrichment_method = enrichment_kwargs.get("enrichment_method", "p")
if enrichment_method == "h":
tm = TransferManager()
transfer = tm.prolong
else:

def transfer(source, target):
target.interpolate(source)
Compute goal-oriented error indicators for each subinterval based on solving the
adjoint problem in a globally enriched space.

:kwarg enrichment_kwargs: keyword arguments to pass to the global enrichment
method
:kwarg adj_kwargs: keyword arguments to pass to the adjoint solver
:kwarg indicator_fn: function for error indication, which takes the form,
adjoint error and enriched space(s) as arguments
"""
enrichment_kwargs.setdefault("enrichment_method", "p")
enrichment_kwargs.setdefault("num_enrichments", 1)
mesh_seq_e = self.get_enriched_mesh_seq(**enrichment_kwargs)
transfer = self._get_transfer_function(enrichment_kwargs["enrichment_method"])

# Solve the forward and adjoint problems on the MeshSeq and its enriched version
sols = self.solve_adjoint(**adj_kwargs)
sols_e = mesh_seq_e.solve_adjoint(**adj_kwargs)

Expand Down Expand Up @@ -184,8 +155,8 @@ def transfer(source, target):
forms = mesh_seq_e.form(i, mapping)
if not isinstance(forms, dict):
raise TypeError(
"The function defined by get_form should return a dictionary, not a"
f" {type(forms)}."
"The function defined by get_form should return a dictionary"
f", not type '{type(forms)}'."
)

# Loop over each strongly coupled field
Expand Down Expand Up @@ -215,6 +186,47 @@ def transfer(source, target):

return sols, indicators

@PETSc.Log.EventDecorator()
def indicators2estimator(
self, indicators: Iterable, absolute_value: bool = False
) -> float:
r"""
Deduce the error estimator value associated with error indicator fields defined over
a :class:`~.MeshSeq`.

:arg indicators: the list of list of error indicator
:class:`firedrake.function.Function`\s
:kwarg absolute_value: toggle whether to take the modulus on each element
"""
if not isinstance(indicators, dict):
raise TypeError(
f"Expected 'indicators' to be a dict, not '{type(indicators)}'."
)
if not isinstance(absolute_value, bool):
raise TypeError(
f"Expected 'absolute_value' to be a bool, not '{type(absolute_value)}'."
)
estimator = 0
for field, by_field in indicators.items():
if field not in self.time_partition.fields:
raise ValueError(
f"Key '{field}' does not exist in the TimePartition provided."
)
if isinstance(by_field, Function) or not isinstance(by_field, Iterable):
raise TypeError(
f"Expected values of 'indicators' to be iterables, not '{type(by_field)}'."
)
for by_mesh, dt in zip(by_field, self.time_partition.timesteps):
if isinstance(by_mesh, Function) or not isinstance(by_mesh, Iterable):
raise TypeError(
f"Expected entries of 'indicators' to be iterables, not '{type(by_mesh)}'."
)
for indicator in by_mesh:
if absolute_value:
indicator.interpolate(abs(indicator))
estimator += dt * indicator.vector().gather().sum()
return estimator

def check_estimator_convergence(self):
"""
Check for convergence of the fixed point iteration due to the relative
Expand Down Expand Up @@ -294,8 +306,7 @@ def fixed_point_iteration(
break

# Check for error estimator convergence
ee = indicators2estimator(indicators, self.time_partition)
self.estimator_values.append(ee)
self.estimator_values.append(self.indicators2estimator(indicators))
ee_converged = self.check_estimator_convergence()
if self.params.convergence_criteria == "any" and ee_converged:
self.converged[:] = True
Expand Down
Loading
Loading