From 48667468e87fd1d484b1f9bbbe6ada616442dba0 Mon Sep 17 00:00:00 2001 From: ksagiyam Date: Mon, 7 Oct 2024 16:33:50 +0100 Subject: [PATCH] io: fix saving/loading of HDiv/HCurl functions on a high-order mesh --- firedrake/checkpointing.py | 221 ++++++++++++++++----- tests/firedrake/output/test_io_function.py | 80 +++++++- 2 files changed, 246 insertions(+), 55 deletions(-) diff --git a/firedrake/checkpointing.py b/firedrake/checkpointing.py index 5070a96315..39d0d4f9dc 100644 --- a/firedrake/checkpointing.py +++ b/firedrake/checkpointing.py @@ -7,7 +7,7 @@ from firedrake.cython import hdf5interface as h5i from firedrake.cython import dmcommon from firedrake.petsc import PETSc, OptionsManager -from firedrake.mesh import MeshTopology, ExtrudedMeshTopology, DEFAULT_MESH_NAME, make_mesh_from_coordinates, DistributedMeshOverlapType +from firedrake.mesh import MeshGeometry, MeshTopology, ExtrudedMeshTopology, DEFAULT_MESH_NAME, make_mesh_from_coordinates, DistributedMeshOverlapType from firedrake.functionspace import FunctionSpace from firedrake import functionspaceimpl as impl from firedrake.functionspacedata import get_global_numbering, create_element @@ -20,6 +20,7 @@ import numpy as np import os import h5py +from typing import Optional, Union __all__ = ["DumbCheckpoint", "HDF5File", "FILE_READ", "FILE_CREATE", "FILE_UPDATE", "CheckpointFile"] @@ -896,25 +897,47 @@ def _save_function_space_topology(self, tV): topology_dm.setName(base_tmesh_name) @PETSc.Log.EventDecorator("SaveFunction") - def save_function(self, f, idx=None, name=None, timestepping_info={}): - r"""Save a :class:`~.Function`. + def save_function( + self, + f: Function, + idx: Optional[int] = None, + name: Optional[str] = None, + timestepping_info: Optional[dict] = {}, + affine_coordinates: Optional[Union[MeshGeometry, Function]] = None, + affine_quadrature_degree: Optional[int] = None, + ) -> None: + """Save a :class:`~.Function`. - :arg f: the :class:`~.Function` to save. - :kwarg idx: optional timestepping index. A function can + Parameters + ---------- + f + `Function` to save. + idx + Optional timestepping index. A function can either be saved in timestepping mode or in normal mode (non-timestepping); for each function of interest, this method must always be called with the idx parameter set or never be called with the idx parameter set. - :kwarg name: optional alternative name to save the function under. - :kwarg timestepping_info: optional (requires idx) additional information + name + Optional alternative name to save the function under. + timestepping_info + Optional (requires idx) additional information such as time, timestepping that can be stored along a function for each index. + affine_coordinates + Representation of a fictitious affine mesh onto which + the function is mapped before saving; only significant for + HDiv/HCurl functions defined on high-order mesh. + affine_quadrature_degree + Quadrature degree to be used when mapping onto the affine mesh; + only significant for HDiv/HCurl functions defined on high-order mesh. + """ V = f.function_space() mesh = V.mesh() if name: g = Function(V, val=f.dat, name=name) - return self.save_function(g, idx=idx, timestepping_info=timestepping_info) + return self.save_function(g, idx=idx, timestepping_info=timestepping_info, affine_coordinates=affine_coordinates, affine_quadrature_degree=affine_quadrature_degree) # -- Save function space -- self._save_function_space(V) # -- Save function -- @@ -926,7 +949,7 @@ def save_function(self, f, idx=None, name=None, timestepping_info={}): path = os.path.join(base_path, str(i)) self.require_group(path) self.set_attr(path, PREFIX + "_function", fsub.name()) - self.save_function(fsub, idx=idx, timestepping_info=timestepping_info) + self.save_function(fsub, idx=idx, timestepping_info=timestepping_info, affine_coordinates=affine_coordinates, affine_quadrature_degree=affine_quadrature_degree) self._update_mixed_function_name_mixed_function_space_name_map(mesh.name, {f.name(): V_name}) else: tf = f.topological @@ -940,10 +963,32 @@ def save_function(self, f, idx=None, name=None, timestepping_info={}): path = self._path_to_function_embedded(tmesh.name, mesh.name, V_name, f.name()) self.require_group(path) method = get_embedding_method_for_checkpointing(element) - _V = FunctionSpace(mesh, _element) + if mesh.coordinates.function_space().ufl_element().embedded_subdegree > 1: + # Handle non-affine mesh; this is only relevant when embedding into a DG space. + if affine_coordinates is None: + raise ValueError("Must provide affine_coordinates to save functions on high-order mesh") + if affine_quadrature_degree is None: + raise ValueError("Must provide affine_quadrature_degree to save functions on high-order mesh") + if isinstance(affine_coordinates, MeshGeometry): + affine_coordinates = affine_coordinates.coordinates + else: + if not isinstance(affine_coordinates, Function): + raise ValueError("affine_coordinates must be {MeshGeometry, Function}") + if affine_coordinates.function_space().mesh().topology is not tmesh: + raise ValueError(f"affine_coordinates.function_space().mesh().topology ({affine_coordinates.function_space().mesh().topology}) is not f.mesh().topology ({tmesh})") + if affine_coordinates.function_space().mesh() is not mesh: + affine_coordinate_V = FunctionSpace(mesh, affine_coordinates.function_space().ufl_element()) + affine_coordinates = Function(affine_coordinate_V, val=affine_coordinates.topological) + if affine_coordinates.topological.name() == mesh.coordinates.topological.name(): + raise ValueError(f"affine_coordinate.name() ({affine_coordinates.name()}) == mesh.coordinates.topological.name() ({mesh.coordinates.topological.name()})") + self._save_ufl_element(path, PREFIX_EMBEDDED + "_affine_coordinate_element", affine_coordinates.topological.function_space().ufl_element()) + self.set_attr(path, PREFIX_EMBEDDED + "_affine_coordinates", affine_coordinates.topological.name()) + self.set_attr(path, PREFIX_EMBEDDED + "_affine_quadrature_degree", affine_quadrature_degree) + self._save_function_topology(affine_coordinates.topological) _name = "_".join([PREFIX_EMBEDDED, f.name()]) + _V = FunctionSpace(mesh, _element) _f = Function(_V, name=_name) - self._project_function_for_checkpointing(_f, f, method) + self._project_function_for_checkpointing(_f, f, method, affine_coordinates=affine_coordinates, affine_quadrature_degree=affine_quadrature_degree) self.save_function(_f, idx=idx, timestepping_info=timestepping_info) self.set_attr(path, PREFIX_EMBEDDED + "_function", _name) else: @@ -1045,35 +1090,41 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter path = self._path_to_topology_extruded(tmesh_name) if path in self.h5pyfile: # -- Load mesh topology -- - base_tmesh_name = self.get_attr(path, PREFIX_EXTRUDED + "_base_mesh") - base_tmesh = self._load_mesh_topology(base_tmesh_name, reorder, distribution_parameters) - base_tmesh.init() - periodic = self.get_attr(path, PREFIX_EXTRUDED + "_periodic") if self.has_attr(path, PREFIX_EXTRUDED + "_periodic") else False - variable_layers = self.get_attr(path, PREFIX_EXTRUDED + "_variable_layers") - if variable_layers: - cell = base_tmesh.ufl_cell() - element = finat.ufl.VectorElement("DP" if cell.is_simplex() else "DQ", cell, 0, dim=2) - _ = self._load_function_space_topology(base_tmesh, element) - base_tmesh_key = self._generate_mesh_key_from_names(base_tmesh.name, - base_tmesh._distribution_name, - base_tmesh._permutation_name) - sd_key = self._get_shared_data_key_for_checkpointing(base_tmesh, element) - _, _, lsf = self._function_load_utils[base_tmesh_key + sd_key] - nroots, _, _ = lsf.getGraph() - layers_a = np.empty(nroots, dtype=utils.IntType) - layers_a_iset = PETSc.IS().createGeneral(layers_a, comm=self._comm) - layers_a_iset.setName("_".join([PREFIX_EXTRUDED, "layers_iset"])) - self.viewer.pushGroup(path) - layers_a_iset.load(self.viewer) - self.viewer.popGroup() - layers_a = layers_a_iset.getIndices() - layers = np.empty((base_tmesh.cell_set.total_size, 2), dtype=utils.IntType) - unit = MPI._typedict[np.dtype(utils.IntType).char] - lsf.bcastBegin(unit, layers_a, layers, MPI.REPLACE) - lsf.bcastEnd(unit, layers_a, layers, MPI.REPLACE) + if topology is None: + base_tmesh_name = self.get_attr(path, PREFIX_EXTRUDED + "_base_mesh") + base_tmesh = self._load_mesh_topology(base_tmesh_name, reorder, distribution_parameters) + base_tmesh.init() + periodic = self.get_attr(path, PREFIX_EXTRUDED + "_periodic") if self.has_attr(path, PREFIX_EXTRUDED + "_periodic") else False + variable_layers = self.get_attr(path, PREFIX_EXTRUDED + "_variable_layers") + if variable_layers: + cell = base_tmesh.ufl_cell() + element = finat.ufl.VectorElement("DP" if cell.is_simplex() else "DQ", cell, 0, dim=2) + _ = self._load_function_space_topology(base_tmesh, element) + base_tmesh_key = self._generate_mesh_key_from_names(base_tmesh.name, + base_tmesh._distribution_name, + base_tmesh._permutation_name) + sd_key = self._get_shared_data_key_for_checkpointing(base_tmesh, element) + _, _, lsf = self._function_load_utils[base_tmesh_key + sd_key] + nroots, _, _ = lsf.getGraph() + layers_a = np.empty(nroots, dtype=utils.IntType) + layers_a_iset = PETSc.IS().createGeneral(layers_a, comm=self._comm) + layers_a_iset.setName("_".join([PREFIX_EXTRUDED, "layers_iset"])) + self.viewer.pushGroup(path) + layers_a_iset.load(self.viewer) + self.viewer.popGroup() + layers_a = layers_a_iset.getIndices() + layers = np.empty((base_tmesh.cell_set.total_size, 2), dtype=utils.IntType) + unit = MPI._typedict[np.dtype(utils.IntType).char] + lsf.bcastBegin(unit, layers_a, layers, MPI.REPLACE) + lsf.bcastEnd(unit, layers_a, layers, MPI.REPLACE) + else: + layers = self.get_attr(path, PREFIX_EXTRUDED + "_layers") + tmesh = ExtrudedMeshTopology(base_tmesh, layers, periodic=periodic, name=tmesh_name) else: - layers = self.get_attr(path, PREFIX_EXTRUDED + "_layers") - tmesh = ExtrudedMeshTopology(base_tmesh, layers, periodic=periodic, name=tmesh_name) + if topology.name != tmesh_name: + raise RuntimeError(f"Got wrong mesh topology (f{topology.name}): expecting f{tmesh_name}") + tmesh = topology + base_tmesh = topology._base_mesh # -- Load mesh -- path = self._path_to_mesh(tmesh_name, name) coord_element = self._load_ufl_element(path, PREFIX + "_coordinate_element") @@ -1301,14 +1352,29 @@ def _load_function_space_topology(self, tmesh, element): return impl.FunctionSpace(tmesh, element) @PETSc.Log.EventDecorator("LoadFunction") - def load_function(self, mesh, name, idx=None): - r"""Load a :class:`~.Function` defined on `mesh`. + def load_function( + self, + mesh: MeshGeometry, + name: str, + idx: Optional[int] = None + ) -> Function: + """Load a :class:`~.Function` defined on ``mesh``. - :arg mesh: the mesh on which the function is defined. - :arg name: the name of the :class:`~.Function` to load. - :kwarg idx: optional timestepping index. A function can + Parameters + ---------- + mesh + mesh on which the function is defined. + name + name of the `Function` to load. + idx + Optional timestepping index. A function can be loaded with idx only when it was saved with idx. - :returns: the loaded :class:`~.Function`. + + Returns + ------- + Function + Loaded `Function`. + """ tmesh = mesh.topology if name in self._get_mixed_function_name_mixed_function_space_name_map(mesh.name): @@ -1341,7 +1407,19 @@ def load_function(self, mesh, name, idx=None): method = get_embedding_method_for_checkpointing(element) assert _element == _f.function_space().ufl_element() f = Function(V, name=name) - self._project_function_for_checkpointing(f, _f, method) + if mesh.coordinates.function_space().ufl_element().embedded_subdegree > 1 and \ + self.has_attr(path, PREFIX_EMBEDDED + "_affine_coordinates"): + # Handle non-affine mesh; this is only relevant when embedding into a DG space. + affine_coord_element = self._load_ufl_element(path, PREFIX_EMBEDDED + "_affine_coordinate_element") + affine_coord_name = self.get_attr(path, PREFIX_EMBEDDED + "_affine_coordinates") + affine_quadrature_degree = self.get_attr(path, PREFIX_EMBEDDED + "_affine_quadrature_degree") + affine_coordinates = self._load_function_topology(tmesh, affine_coord_element, affine_coord_name) + affine_coordinate_V = FunctionSpace(mesh, affine_coordinates.function_space().ufl_element()) + affine_coordinates = Function(affine_coordinate_V, val=affine_coordinates.topological) + else: + affine_coordinates = None + affine_quadrature_degree = None + self._project_function_for_checkpointing(f, _f, method, affine_coordinates=affine_coordinates, affine_quadrature_degree=affine_quadrature_degree) return f else: tf_name = self.get_attr(path, PREFIX + "_vec") @@ -1637,13 +1715,52 @@ def _is_mixed_function_space(self, mesh_name, V_name): return True return False - def _project_function_for_checkpointing(self, f, _f, method): - if method == "project": - getattr(f, method)(_f, solver_parameters={"ksp_rtol": 1.e-16}) - elif method == "interpolate": - getattr(f, method)(_f) + def _project_function_for_checkpointing(self, target, source, method, affine_coordinates=None, affine_quadrature_degree=None): + if affine_coordinates: + if affine_quadrature_degree is None: + raise ValueError("Need affine_quadrature_degree to save/load HDiv/HCurl functions on high-order mesh") + # Need to map to/from the representation on a fictitious + # affine mesh represented by affine_coordinates. + K = firedrake.grad(affine_coordinates) # K = (\partial X /\partial x) = F^-1. + from_elem = source.function_space().ufl_element() + to_elem = target.function_space().ufl_element() + if to_elem.mapping() == "identity": + if from_elem.mapping() == "covariant Piola": + source = firedrake.transpose(firedrake.inv(K)) * source + elif from_elem.mapping() == "contravariant Piola": + source = 1. / firedrake.det(K) * K * source + else: + raise NotImplementedError(f"Unsupported pair: ({from_elem.mapping()}, {to_elem.mapping()})") + V = target.function_space() + u = firedrake.TrialFunction(V) + v = firedrake.TestFunction(V) + # Solve projection problem on the fictitious affine mesh. + a = firedrake.inner(u, v) * firedrake.det(K) * firedrake.dx(degree=affine_quadrature_degree) + L = firedrake.inner(source, v) * firedrake.det(K) * firedrake.dx(degree=affine_quadrature_degree) + firedrake.solve(a == L, target, solver_parameters={"ksp_rtol": 1.e-16}) + elif from_elem.mapping() == "identity": + if to_elem.mapping() == "covariant Piola": + source = firedrake.transpose(K) * source + elif to_elem.mapping() == "contravariant Piola": + source = firedrake.det(K) * firedrake.inv(K) * source + else: + raise NotImplementedError(f"Unsupported pair: ({from_elem.mapping()}, {to_elem.mapping()})") + V = target.function_space() + u = firedrake.TrialFunction(V) + v = firedrake.TestFunction(V) + # Solve projection problem on the high-order mesh. + a = firedrake.inner(u, v) * firedrake.dx(degree=affine_quadrature_degree) + L = firedrake.inner(source, v) * firedrake.dx(degree=affine_quadrature_degree) + firedrake.solve(a == L, target, solver_parameters={"ksp_rtol": 1.e-16}) + else: + raise NotImplementedError(f"Unsupported pair: ({from_elem.mapping()}, {to_elem.mapping()})") else: - raise ValueError(f"Unknown method for projecting: {method}") + if method == "project": + getattr(target, method)(source, solver_parameters={"ksp_rtol": 1.e-16}) + elif method == "interpolate": + getattr(target, method)(source) + else: + raise ValueError(f"Unknown method for projecting: {method}") @property def h5pyfile(self): diff --git a/tests/firedrake/output/test_io_function.py b/tests/firedrake/output/test_io_function.py index 64be091afe..0a1cdcdfe3 100644 --- a/tests/firedrake/output/test_io_function.py +++ b/tests/firedrake/output/test_io_function.py @@ -14,6 +14,7 @@ mesh_name = "m" extruded_mesh_name = "m_extruded" func_name = "f" +affine_quadrature_degree = 4 def _initialise_function(f, _f, method): @@ -127,10 +128,14 @@ def _get_expr(V): raise ValueError(f"Invalid shape {shape}") -def _load_check_save_functions(filename, func_name, comm, method, mesh_name, variable_layers=False): +def _load_check_save_functions(filename, func_name, comm, method, mesh_name, variable_layers=False, high_order_mesh=False): # Load with CheckpointFile(filename, "r", comm=comm) as afile: - meshB = afile.load_mesh(mesh_name) + if high_order_mesh: + meshB = afile.load_mesh(mesh_name + "_ho") + flat = afile.load_mesh(mesh_name, topology=meshB.topology) + else: + meshB = afile.load_mesh(mesh_name) fB = afile.load_function(meshB, func_name) # Check if variable_layers: @@ -144,7 +149,11 @@ def _load_check_save_functions(filename, func_name, comm, method, mesh_name, var assert assemble(inner(fB - fBe, fB - fBe) * dx) < 5.e-12 # Save with CheckpointFile(filename, 'w', comm=comm) as afile: - afile.save_function(fB) + if high_order_mesh: + afile.save_function(fB, affine_coordinates=flat, affine_quadrature_degree=affine_quadrature_degree) + afile.save_mesh(flat) + else: + afile.save_function(fB) @pytest.mark.parallel(nprocs=2) @@ -597,6 +606,71 @@ def test_io_function_extrusion_periodic(tmpdir): _load_check_save_functions(filename, func_name, comm, method, extruded_mesh_name) +@pytest.mark.parallel(nprocs=2) +@pytest.mark.parametrize('quadrilateral_family_degree', [(False, "RTF", 2), + (False, "RTE", 2), + (True, "RTCF", 2), + (True, "RTCE", 2),]) +def test_io_function_high_order_coordinates_2d(quadrilateral_family_degree, tmpdir): + quadrilateral, family, degree = quadrilateral_family_degree + filename = join(str(tmpdir), "test_io_function_high_order_coordinates_dump.h5") + filename = COMM_WORLD.bcast(filename, root=0) + flat = UnitSquareMesh(4, 4, quadrilateral=quadrilateral, comm=COMM_WORLD, name=mesh_name) + x, y = SpatialCoordinate(flat) + R = sqrt(2.0) + r = sqrt(x**2 + y**2) + coordV = VectorFunctionSpace(flat, "CG", 2) + bc = DirichletBC(coordV, R * as_vector([x / r, y / r]), (2, 4)) + coords = Function(coordV, name=mesh_name + "_ho_coordinates").interpolate(flat.coordinates) + bc.apply(coords) + mesh = Mesh(coords, name=mesh_name + "_ho") + V = FunctionSpace(mesh, family, degree) + f = Function(V, name=func_name) + method = get_embedding_method_for_checkpointing(V.ufl_element()) + _initialise_function(f, _get_expr(V), method) + with CheckpointFile(filename, 'w', comm=COMM_WORLD) as afile: + afile.save_function(f, affine_coordinates=flat, affine_quadrature_degree=affine_quadrature_degree) + afile.save_mesh(flat) + ntimes = COMM_WORLD.size + for i in range(ntimes): + mycolor = (COMM_WORLD.rank > ntimes - 1 - i) + comm = COMM_WORLD.Split(color=mycolor, key=COMM_WORLD.rank) + if mycolor == 0: + _load_check_save_functions(filename, func_name, comm, method, mesh_name, high_order_mesh=True) + + +@pytest.mark.parallel(nprocs=2) +@pytest.mark.parametrize('family_degree', [("RTCE", 1), + ("RTCF", 1),]) +def test_io_function_high_order_coordinates_extrusion_periodic(family_degree, tmpdir): + family, degree = family_degree + filename = join(str(tmpdir), "test_io_function_high_order_coordinates_extrusion_periodic_dump.h5") + filename = COMM_WORLD.bcast(filename, root=0) + m = 1 # num. element in radial direction + n = 4 # num. element in circumferential direction + mesh = IntervalMesh(m, 1.0, 2.0, name=mesh_name) + flat = ExtrudedMesh(mesh, layers=n, layer_height=2 * pi / n, extrusion_type="uniform", periodic=True, name=extruded_mesh_name) + elem = flat.coordinates.ufl_element().reconstruct(degree=2) + coordV = FunctionSpace(flat, elem) + x, y = SpatialCoordinate(flat) + coord = Function(coordV, name=extruded_mesh_name + "_ho_coordinates").interpolate(as_vector([x * cos(y), x * sin(y)])) + extm = make_mesh_from_coordinates(coord.topological, name=extruded_mesh_name + "_ho") + extm._base_mesh = mesh + V = FunctionSpace(extm, family, degree) + method = get_embedding_method_for_checkpointing(V.ufl_element()) + f = Function(V, name=func_name) + _initialise_function(f, _get_expr(V), method) + with CheckpointFile(filename, 'w', comm=COMM_WORLD) as afile: + afile.save_function(f, affine_coordinates=flat, affine_quadrature_degree=affine_quadrature_degree) + afile.save_mesh(flat) + ntimes = COMM_WORLD.size + for i in range(ntimes): + mycolor = (COMM_WORLD.rank > ntimes - 1 - i) + comm = COMM_WORLD.Split(color=mycolor, key=COMM_WORLD.rank) + if mycolor == 0: + _load_check_save_functions(filename, func_name, comm, method, extruded_mesh_name, high_order_mesh=True) + + @pytest.mark.parallel(nprocs=2) @pytest.mark.parametrize("cell_family_degree", [("triangle", "P", 1), ("quadrilateral", "Q", 1)])