From 1dd4122d1de0d1fa9747e7e8e8d0ed0807ef2ffa Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 3 Dec 2024 15:07:18 +0000 Subject: [PATCH] Add cofunction handler to form splitter to enable fieldsplit with cofunctions. --- firedrake/formmanipulation.py | 67 ++++++++++++++++++- .../regression/test_fieldsplit_cofunction.py | 60 +++++++++++++++++ 2 files changed, 126 insertions(+), 1 deletion(-) create mode 100644 tests/firedrake/regression/test_fieldsplit_cofunction.py diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index 35a6789107..3bbb8d35ed 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -88,6 +88,7 @@ def argument(self, o): from ufl import split from firedrake import MixedFunctionSpace, FunctionSpace V = o.function_space() + if len(V) == 1: # Not on a mixed space, just return ourselves. return o @@ -127,6 +128,59 @@ def argument(self, o): for j in numpy.ndindex(V_is[i].value_shape)] return self._arg_cache.setdefault(o, as_vector(args)) + def cofunction(self, o): + from firedrake import Cofunction + from ufl.classes import ZeroBaseForm + + V = o.function_space() + + # Not on a mixed space, just return ourselves. + if len(V) == 1: + return o + + # We only need the test space for Cofunction + indices = self.blocks[0] + V_is = V.subfunctions + + try: + indices = tuple(indices) + nidx = len(indices) + except TypeError: # Only one index + indices = (indices, ) + nidx = 1 + + # the cofunction should only be returned on the + # diagonal elements, so if we are off-diagonal + # on a two-form then we return a zero form, + # analogously to the off components of arguments. + if len(self.blocks) == 2: + itest, itrial = self.blocks[0], self.blocks[1] + on_diag = (itest == itrial) + else: + on_diag = True + + # if we are on the diagonal, then return a Cofunction + # in the relevant subspace that points to the data in + # the full space. This means that the right hand side + # of the fieldsplit problem will be correct. + if on_diag: + if nidx == 1: + from firedrake import FunctionSpace + i = indices[0] + W = V_is[i] + W = FunctionSpace(W.mesh(), W.ufl_element()) + c = Cofunction(W.dual(), val=o.subfunctions[i].dat) + else: + from firedrake import MixedFunctionSpace + from pyop2 import MixedDat + W = MixedFunctionSpace([V_is[i] for i in indices]) + c = Cofunction(W.dual(), val=MixedDat(o.subfunctions[i].dat + for i in indices)) + else: + c = ZeroBaseForm(o.arguments()) + + return c + SplitForm = collections.namedtuple("SplitForm", ["indices", "form"]) @@ -160,6 +214,8 @@ def split_form(form, diagonal=False): compiler will remove these in its more complex simplification stages. """ + from firedrake import Cofunction + from ufl import FormSum splitter = ExtractSubBlock() args = form.arguments() shape = tuple(len(a.function_space()) for a in args) @@ -168,7 +224,16 @@ def split_form(form, diagonal=False): assert len(shape) == 2 for idx in numpy.ndindex(shape): f = splitter.split(form, idx) - if len(f.integrals()) > 0: + + # does f actually contain anything? + if isinstance(f, Cofunction): + flen = 1 + elif isinstance(f, FormSum): + flen = len(f.components()) + else: # Form + flen = len(f.integrals()) + + if flen > 0: if diagonal: i, j = idx if i != j: diff --git a/tests/firedrake/regression/test_fieldsplit_cofunction.py b/tests/firedrake/regression/test_fieldsplit_cofunction.py new file mode 100644 index 0000000000..4f61fb24d4 --- /dev/null +++ b/tests/firedrake/regression/test_fieldsplit_cofunction.py @@ -0,0 +1,60 @@ +import firedrake as fd + + +def test_fieldsplit_cofunction(): + """ + Test that fieldsplit preconditioners can be used + with a cofunction on the right hand side. + """ + mesh = fd.UnitSquareMesh(4, 4) + BDM = fd.FunctionSpace(mesh, "BDM", 1) + DG = fd.FunctionSpace(mesh, "DG", 0) + W = BDM*DG + + u, p = fd.TrialFunctions(W) + v, q = fd.TestFunctions(W) + + # simple wave equation scheme + a = (fd.dot(u, v) + fd.div(v)*p + - fd.div(u)*q + p*q)*fd.dx + + x, y = fd.SpatialCoordinate(mesh) + + f = fd.Function(W) + + f.subfunctions[0].project( + fd.as_vector([0.01*y, 0])) + f.subfunctions[1].interpolate( + -10*fd.exp(-(pow(x - 0.5, 2) + pow(y - 0.5, 2)) / 0.02)) + + # compare to plain 1-form + L_check = fd.inner(f, fd.TestFunction(W))*fd.dx + L_cofun = f.riesz_representation() + + # brute force schur complement solver + params = { + 'ksp_converged_reason': None, + 'ksp_type': 'preonly', + 'pc_type': 'fieldsplit', + 'pc_fieldsplit_type': 'schur', + 'pc_fieldsplit_schur_fact_type': 'full', + 'pc_fieldsplit_schur_precondition': 'full', + 'fieldsplit': { + 'ksp_type': 'preonly', + 'pc_type': 'lu' + } + } + + w_check = fd.Function(W) + problem_check = fd.LinearVariationalProblem(a, L_check, w_check) + solver_check = fd.LinearVariationalSolver(problem_check, + solver_parameters=params) + solver_check.solve() + + w_cofun = fd.Function(W) + problem_cofun = fd.LinearVariationalProblem(a, L_cofun, w_cofun) + solver_cofun = fd.LinearVariationalSolver(problem_cofun, + solver_parameters=params) + solver_cofun.solve() + + assert fd.errornorm(w_check, w_cofun) < 1e-14