Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

io: fix saving/loading of HDiv/HCurl functions on a high-order mesh #3838

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 169 additions & 52 deletions firedrake/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from firedrake.cython import hdf5interface as h5i
from firedrake.cython import dmcommon
from firedrake.petsc import PETSc, OptionsManager
from firedrake.mesh import MeshTopology, ExtrudedMeshTopology, DEFAULT_MESH_NAME, make_mesh_from_coordinates, DistributedMeshOverlapType
from firedrake.mesh import MeshGeometry, MeshTopology, ExtrudedMeshTopology, DEFAULT_MESH_NAME, make_mesh_from_coordinates, DistributedMeshOverlapType
from firedrake.functionspace import FunctionSpace
from firedrake import functionspaceimpl as impl
from firedrake.functionspacedata import get_global_numbering, create_element
Expand All @@ -20,6 +20,7 @@
import numpy as np
import os
import h5py
from typing import Optional, Union


__all__ = ["DumbCheckpoint", "HDF5File", "FILE_READ", "FILE_CREATE", "FILE_UPDATE", "CheckpointFile"]
Expand Down Expand Up @@ -896,25 +897,47 @@ def _save_function_space_topology(self, tV):
topology_dm.setName(base_tmesh_name)

@PETSc.Log.EventDecorator("SaveFunction")
def save_function(self, f, idx=None, name=None, timestepping_info={}):
r"""Save a :class:`~.Function`.
def save_function(
self,
f: Function,
idx: Optional[int] = None,
name: Optional[str] = None,
timestepping_info: Optional[dict] = {},
affine_coordinates: Optional[Union[MeshGeometry, Function]] = None,
affine_quadrature_degree: Optional[int] = None,
) -> None:
"""Save a :class:`~.Function`.

:arg f: the :class:`~.Function` to save.
:kwarg idx: optional timestepping index. A function can
Parameters
----------
f
`Function` to save.
idx
Optional timestepping index. A function can
either be saved in timestepping mode or in normal
mode (non-timestepping); for each function of interest,
this method must always be called with the idx parameter
set or never be called with the idx parameter set.
:kwarg name: optional alternative name to save the function under.
:kwarg timestepping_info: optional (requires idx) additional information
name
Optional alternative name to save the function under.
timestepping_info
Optional (requires idx) additional information
such as time, timestepping that can be stored along a function for
each index.
affine_coordinates
Representation of a fictitious affine mesh onto which
the function is mapped before saving; only significant for
HDiv/HCurl functions defined on high-order mesh.
affine_quadrature_degree
Quadrature degree to be used when mapping onto the affine mesh;
only significant for HDiv/HCurl functions defined on high-order mesh.

"""
V = f.function_space()
mesh = V.mesh()
if name:
g = Function(V, val=f.dat, name=name)
return self.save_function(g, idx=idx, timestepping_info=timestepping_info)
return self.save_function(g, idx=idx, timestepping_info=timestepping_info, affine_coordinates=affine_coordinates, affine_quadrature_degree=affine_quadrature_degree)
# -- Save function space --
self._save_function_space(V)
# -- Save function --
Expand All @@ -926,7 +949,7 @@ def save_function(self, f, idx=None, name=None, timestepping_info={}):
path = os.path.join(base_path, str(i))
self.require_group(path)
self.set_attr(path, PREFIX + "_function", fsub.name())
self.save_function(fsub, idx=idx, timestepping_info=timestepping_info)
self.save_function(fsub, idx=idx, timestepping_info=timestepping_info, affine_coordinates=affine_coordinates, affine_quadrature_degree=affine_quadrature_degree)
self._update_mixed_function_name_mixed_function_space_name_map(mesh.name, {f.name(): V_name})
else:
tf = f.topological
Expand All @@ -940,10 +963,32 @@ def save_function(self, f, idx=None, name=None, timestepping_info={}):
path = self._path_to_function_embedded(tmesh.name, mesh.name, V_name, f.name())
self.require_group(path)
method = get_embedding_method_for_checkpointing(element)
_V = FunctionSpace(mesh, _element)
if mesh.coordinates.function_space().ufl_element().embedded_subdegree > 1:
# Handle non-affine mesh; this is only relevant when embedding into a DG space.
if affine_coordinates is None:
raise ValueError("Must provide affine_coordinates to save functions on high-order mesh")
if affine_quadrature_degree is None:
raise ValueError("Must provide affine_quadrature_degree to save functions on high-order mesh")
if isinstance(affine_coordinates, MeshGeometry):
affine_coordinates = affine_coordinates.coordinates
else:
if not isinstance(affine_coordinates, Function):
raise ValueError("affine_coordinates must be {MeshGeometry, Function}")
if affine_coordinates.function_space().mesh().topology is not tmesh:
raise ValueError(f"affine_coordinates.function_space().mesh().topology ({affine_coordinates.function_space().mesh().topology}) is not f.mesh().topology ({tmesh})")
if affine_coordinates.function_space().mesh() is not mesh:
affine_coordinate_V = FunctionSpace(mesh, affine_coordinates.function_space().ufl_element())
affine_coordinates = Function(affine_coordinate_V, val=affine_coordinates.topological)
if affine_coordinates.topological.name() == mesh.coordinates.topological.name():
raise ValueError(f"affine_coordinate.name() ({affine_coordinates.name()}) == mesh.coordinates.topological.name() ({mesh.coordinates.topological.name()})")
self._save_ufl_element(path, PREFIX_EMBEDDED + "_affine_coordinate_element", affine_coordinates.topological.function_space().ufl_element())
self.set_attr(path, PREFIX_EMBEDDED + "_affine_coordinates", affine_coordinates.topological.name())
self.set_attr(path, PREFIX_EMBEDDED + "_affine_quadrature_degree", affine_quadrature_degree)
self._save_function_topology(affine_coordinates.topological)
_name = "_".join([PREFIX_EMBEDDED, f.name()])
_V = FunctionSpace(mesh, _element)
_f = Function(_V, name=_name)
self._project_function_for_checkpointing(_f, f, method)
self._project_function_for_checkpointing(_f, f, method, affine_coordinates=affine_coordinates, affine_quadrature_degree=affine_quadrature_degree)
self.save_function(_f, idx=idx, timestepping_info=timestepping_info)
self.set_attr(path, PREFIX_EMBEDDED + "_function", _name)
else:
Expand Down Expand Up @@ -1045,35 +1090,41 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter
path = self._path_to_topology_extruded(tmesh_name)
if path in self.h5pyfile:
# -- Load mesh topology --
base_tmesh_name = self.get_attr(path, PREFIX_EXTRUDED + "_base_mesh")
base_tmesh = self._load_mesh_topology(base_tmesh_name, reorder, distribution_parameters)
base_tmesh.init()
periodic = self.get_attr(path, PREFIX_EXTRUDED + "_periodic") if self.has_attr(path, PREFIX_EXTRUDED + "_periodic") else False
variable_layers = self.get_attr(path, PREFIX_EXTRUDED + "_variable_layers")
if variable_layers:
cell = base_tmesh.ufl_cell()
element = finat.ufl.VectorElement("DP" if cell.is_simplex() else "DQ", cell, 0, dim=2)
_ = self._load_function_space_topology(base_tmesh, element)
base_tmesh_key = self._generate_mesh_key_from_names(base_tmesh.name,
base_tmesh._distribution_name,
base_tmesh._permutation_name)
sd_key = self._get_shared_data_key_for_checkpointing(base_tmesh, element)
_, _, lsf = self._function_load_utils[base_tmesh_key + sd_key]
nroots, _, _ = lsf.getGraph()
layers_a = np.empty(nroots, dtype=utils.IntType)
layers_a_iset = PETSc.IS().createGeneral(layers_a, comm=self._comm)
layers_a_iset.setName("_".join([PREFIX_EXTRUDED, "layers_iset"]))
self.viewer.pushGroup(path)
layers_a_iset.load(self.viewer)
self.viewer.popGroup()
layers_a = layers_a_iset.getIndices()
layers = np.empty((base_tmesh.cell_set.total_size, 2), dtype=utils.IntType)
unit = MPI._typedict[np.dtype(utils.IntType).char]
lsf.bcastBegin(unit, layers_a, layers, MPI.REPLACE)
lsf.bcastEnd(unit, layers_a, layers, MPI.REPLACE)
if topology is None:
base_tmesh_name = self.get_attr(path, PREFIX_EXTRUDED + "_base_mesh")
base_tmesh = self._load_mesh_topology(base_tmesh_name, reorder, distribution_parameters)
base_tmesh.init()
periodic = self.get_attr(path, PREFIX_EXTRUDED + "_periodic") if self.has_attr(path, PREFIX_EXTRUDED + "_periodic") else False
variable_layers = self.get_attr(path, PREFIX_EXTRUDED + "_variable_layers")
if variable_layers:
cell = base_tmesh.ufl_cell()
element = finat.ufl.VectorElement("DP" if cell.is_simplex() else "DQ", cell, 0, dim=2)
_ = self._load_function_space_topology(base_tmesh, element)
base_tmesh_key = self._generate_mesh_key_from_names(base_tmesh.name,
base_tmesh._distribution_name,
base_tmesh._permutation_name)
sd_key = self._get_shared_data_key_for_checkpointing(base_tmesh, element)
_, _, lsf = self._function_load_utils[base_tmesh_key + sd_key]
nroots, _, _ = lsf.getGraph()
layers_a = np.empty(nroots, dtype=utils.IntType)
layers_a_iset = PETSc.IS().createGeneral(layers_a, comm=self._comm)
layers_a_iset.setName("_".join([PREFIX_EXTRUDED, "layers_iset"]))
self.viewer.pushGroup(path)
layers_a_iset.load(self.viewer)
self.viewer.popGroup()
layers_a = layers_a_iset.getIndices()
layers = np.empty((base_tmesh.cell_set.total_size, 2), dtype=utils.IntType)
unit = MPI._typedict[np.dtype(utils.IntType).char]
lsf.bcastBegin(unit, layers_a, layers, MPI.REPLACE)
lsf.bcastEnd(unit, layers_a, layers, MPI.REPLACE)
else:
layers = self.get_attr(path, PREFIX_EXTRUDED + "_layers")
tmesh = ExtrudedMeshTopology(base_tmesh, layers, periodic=periodic, name=tmesh_name)
else:
layers = self.get_attr(path, PREFIX_EXTRUDED + "_layers")
tmesh = ExtrudedMeshTopology(base_tmesh, layers, periodic=periodic, name=tmesh_name)
if topology.name != tmesh_name:
raise RuntimeError(f"Got wrong mesh topology (f{topology.name}): expecting f{tmesh_name}")
tmesh = topology
base_tmesh = topology._base_mesh
# -- Load mesh --
path = self._path_to_mesh(tmesh_name, name)
coord_element = self._load_ufl_element(path, PREFIX + "_coordinate_element")
Expand Down Expand Up @@ -1301,14 +1352,29 @@ def _load_function_space_topology(self, tmesh, element):
return impl.FunctionSpace(tmesh, element)

@PETSc.Log.EventDecorator("LoadFunction")
def load_function(self, mesh, name, idx=None):
r"""Load a :class:`~.Function` defined on `mesh`.
def load_function(
self,
mesh: MeshGeometry,
name: str,
idx: Optional[int] = None
) -> Function:
"""Load a :class:`~.Function` defined on ``mesh``.

:arg mesh: the mesh on which the function is defined.
:arg name: the name of the :class:`~.Function` to load.
:kwarg idx: optional timestepping index. A function can
Parameters
----------
mesh
mesh on which the function is defined.
name
name of the `Function` to load.
idx
Optional timestepping index. A function can
be loaded with idx only when it was saved with idx.
:returns: the loaded :class:`~.Function`.

Returns
-------
Function
Loaded `Function`.

"""
tmesh = mesh.topology
if name in self._get_mixed_function_name_mixed_function_space_name_map(mesh.name):
Expand Down Expand Up @@ -1341,7 +1407,19 @@ def load_function(self, mesh, name, idx=None):
method = get_embedding_method_for_checkpointing(element)
assert _element == _f.function_space().ufl_element()
f = Function(V, name=name)
self._project_function_for_checkpointing(f, _f, method)
if mesh.coordinates.function_space().ufl_element().embedded_subdegree > 1 and \
self.has_attr(path, PREFIX_EMBEDDED + "_affine_coordinates"):
# Handle non-affine mesh; this is only relevant when embedding into a DG space.
affine_coord_element = self._load_ufl_element(path, PREFIX_EMBEDDED + "_affine_coordinate_element")
affine_coord_name = self.get_attr(path, PREFIX_EMBEDDED + "_affine_coordinates")
affine_quadrature_degree = self.get_attr(path, PREFIX_EMBEDDED + "_affine_quadrature_degree")
affine_coordinates = self._load_function_topology(tmesh, affine_coord_element, affine_coord_name)
affine_coordinate_V = FunctionSpace(mesh, affine_coordinates.function_space().ufl_element())
affine_coordinates = Function(affine_coordinate_V, val=affine_coordinates.topological)
else:
affine_coordinates = None
affine_quadrature_degree = None
self._project_function_for_checkpointing(f, _f, method, affine_coordinates=affine_coordinates, affine_quadrature_degree=affine_quadrature_degree)
return f
else:
tf_name = self.get_attr(path, PREFIX + "_vec")
Expand Down Expand Up @@ -1637,13 +1715,52 @@ def _is_mixed_function_space(self, mesh_name, V_name):
return True
return False

def _project_function_for_checkpointing(self, f, _f, method):
if method == "project":
getattr(f, method)(_f, solver_parameters={"ksp_rtol": 1.e-16})
elif method == "interpolate":
getattr(f, method)(_f)
def _project_function_for_checkpointing(self, target, source, method, affine_coordinates=None, affine_quadrature_degree=None):
if affine_coordinates:
if affine_quadrature_degree is None:
raise ValueError("Need affine_quadrature_degree to save/load HDiv/HCurl functions on high-order mesh")
# Need to map to/from the representation on a fictitious
# affine mesh represented by affine_coordinates.
K = firedrake.grad(affine_coordinates) # K = (\partial X /\partial x) = F^-1.
from_elem = source.function_space().ufl_element()
to_elem = target.function_space().ufl_element()
if to_elem.mapping() == "identity":
if from_elem.mapping() == "covariant Piola":
source = firedrake.transpose(firedrake.inv(K)) * source
elif from_elem.mapping() == "contravariant Piola":
source = 1. / firedrake.det(K) * K * source
else:
raise NotImplementedError(f"Unsupported pair: ({from_elem.mapping()}, {to_elem.mapping()})")
V = target.function_space()
u = firedrake.TrialFunction(V)
v = firedrake.TestFunction(V)
# Solve projection problem on the fictitious affine mesh.
a = firedrake.inner(u, v) * firedrake.det(K) * firedrake.dx(degree=affine_quadrature_degree)
L = firedrake.inner(source, v) * firedrake.det(K) * firedrake.dx(degree=affine_quadrature_degree)
firedrake.solve(a == L, target, solver_parameters={"ksp_rtol": 1.e-16})
elif from_elem.mapping() == "identity":
if to_elem.mapping() == "covariant Piola":
source = firedrake.transpose(K) * source
elif to_elem.mapping() == "contravariant Piola":
source = firedrake.det(K) * firedrake.inv(K) * source
else:
raise NotImplementedError(f"Unsupported pair: ({from_elem.mapping()}, {to_elem.mapping()})")
V = target.function_space()
u = firedrake.TrialFunction(V)
v = firedrake.TestFunction(V)
# Solve projection problem on the high-order mesh.
a = firedrake.inner(u, v) * firedrake.dx(degree=affine_quadrature_degree)
L = firedrake.inner(source, v) * firedrake.dx(degree=affine_quadrature_degree)
firedrake.solve(a == L, target, solver_parameters={"ksp_rtol": 1.e-16})
else:
raise NotImplementedError(f"Unsupported pair: ({from_elem.mapping()}, {to_elem.mapping()})")
else:
raise ValueError(f"Unknown method for projecting: {method}")
if method == "project":
getattr(target, method)(source, solver_parameters={"ksp_rtol": 1.e-16})
elif method == "interpolate":
getattr(target, method)(source)
else:
raise ValueError(f"Unknown method for projecting: {method}")

@property
def h5pyfile(self):
Expand Down
Loading
Loading