Skip to content

Commit

Permalink
Add cofunction handler to form splitter to enable fieldsplit with cof…
Browse files Browse the repository at this point in the history
…unctions.
  • Loading branch information
JHopeCollins committed Dec 16, 2024
1 parent 4c93354 commit 1dd4122
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 1 deletion.
67 changes: 66 additions & 1 deletion firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def argument(self, o):
from ufl import split
from firedrake import MixedFunctionSpace, FunctionSpace
V = o.function_space()

if len(V) == 1:
# Not on a mixed space, just return ourselves.
return o
Expand Down Expand Up @@ -127,6 +128,59 @@ def argument(self, o):
for j in numpy.ndindex(V_is[i].value_shape)]
return self._arg_cache.setdefault(o, as_vector(args))

def cofunction(self, o):
from firedrake import Cofunction
from ufl.classes import ZeroBaseForm

V = o.function_space()

# Not on a mixed space, just return ourselves.
if len(V) == 1:
return o

# We only need the test space for Cofunction
indices = self.blocks[0]
V_is = V.subfunctions

try:
indices = tuple(indices)
nidx = len(indices)
except TypeError: # Only one index
indices = (indices, )
nidx = 1

# the cofunction should only be returned on the
# diagonal elements, so if we are off-diagonal
# on a two-form then we return a zero form,
# analogously to the off components of arguments.
if len(self.blocks) == 2:
itest, itrial = self.blocks[0], self.blocks[1]
on_diag = (itest == itrial)
else:
on_diag = True

# if we are on the diagonal, then return a Cofunction
# in the relevant subspace that points to the data in
# the full space. This means that the right hand side
# of the fieldsplit problem will be correct.
if on_diag:
if nidx == 1:
from firedrake import FunctionSpace
i = indices[0]
W = V_is[i]
W = FunctionSpace(W.mesh(), W.ufl_element())
c = Cofunction(W.dual(), val=o.subfunctions[i].dat)
else:
from firedrake import MixedFunctionSpace
from pyop2 import MixedDat
W = MixedFunctionSpace([V_is[i] for i in indices])
c = Cofunction(W.dual(), val=MixedDat(o.subfunctions[i].dat
for i in indices))
else:
c = ZeroBaseForm(o.arguments())

return c


SplitForm = collections.namedtuple("SplitForm", ["indices", "form"])

Expand Down Expand Up @@ -160,6 +214,8 @@ def split_form(form, diagonal=False):
compiler will remove these in its more complex simplification
stages.
"""
from firedrake import Cofunction
from ufl import FormSum
splitter = ExtractSubBlock()
args = form.arguments()
shape = tuple(len(a.function_space()) for a in args)
Expand All @@ -168,7 +224,16 @@ def split_form(form, diagonal=False):
assert len(shape) == 2
for idx in numpy.ndindex(shape):
f = splitter.split(form, idx)
if len(f.integrals()) > 0:

# does f actually contain anything?
if isinstance(f, Cofunction):
flen = 1
elif isinstance(f, FormSum):
flen = len(f.components())
else: # Form
flen = len(f.integrals())

if flen > 0:
if diagonal:
i, j = idx
if i != j:
Expand Down
60 changes: 60 additions & 0 deletions tests/firedrake/regression/test_fieldsplit_cofunction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import firedrake as fd


def test_fieldsplit_cofunction():
"""
Test that fieldsplit preconditioners can be used
with a cofunction on the right hand side.
"""
mesh = fd.UnitSquareMesh(4, 4)
BDM = fd.FunctionSpace(mesh, "BDM", 1)
DG = fd.FunctionSpace(mesh, "DG", 0)
W = BDM*DG

u, p = fd.TrialFunctions(W)
v, q = fd.TestFunctions(W)

# simple wave equation scheme
a = (fd.dot(u, v) + fd.div(v)*p
- fd.div(u)*q + p*q)*fd.dx

x, y = fd.SpatialCoordinate(mesh)

f = fd.Function(W)

f.subfunctions[0].project(
fd.as_vector([0.01*y, 0]))
f.subfunctions[1].interpolate(
-10*fd.exp(-(pow(x - 0.5, 2) + pow(y - 0.5, 2)) / 0.02))

# compare to plain 1-form
L_check = fd.inner(f, fd.TestFunction(W))*fd.dx
L_cofun = f.riesz_representation()

# brute force schur complement solver
params = {
'ksp_converged_reason': None,
'ksp_type': 'preonly',
'pc_type': 'fieldsplit',
'pc_fieldsplit_type': 'schur',
'pc_fieldsplit_schur_fact_type': 'full',
'pc_fieldsplit_schur_precondition': 'full',
'fieldsplit': {
'ksp_type': 'preonly',
'pc_type': 'lu'
}
}

w_check = fd.Function(W)
problem_check = fd.LinearVariationalProblem(a, L_check, w_check)
solver_check = fd.LinearVariationalSolver(problem_check,

Check failure on line 50 in tests/firedrake/regression/test_fieldsplit_cofunction.py

View workflow job for this annotation

GitHub Actions / Firedrake complex

test_fieldsplit_cofunction

ufl.algorithms.check_arities.ArityMismatch: Failure to conjugate test function in complex Form
Raw output
def test_fieldsplit_cofunction():
        """
        Test that fieldsplit preconditioners can be used
        with a cofunction on the right hand side.
        """
        mesh = fd.UnitSquareMesh(4, 4)
        BDM = fd.FunctionSpace(mesh, "BDM", 1)
        DG = fd.FunctionSpace(mesh, "DG", 0)
        W = BDM*DG
    
        u, p = fd.TrialFunctions(W)
        v, q = fd.TestFunctions(W)
    
        # simple wave equation scheme
        a = (fd.dot(u, v) + fd.div(v)*p
             - fd.div(u)*q + p*q)*fd.dx
    
        x, y = fd.SpatialCoordinate(mesh)
    
        f = fd.Function(W)
    
        f.subfunctions[0].project(
            fd.as_vector([0.01*y, 0]))
        f.subfunctions[1].interpolate(
            -10*fd.exp(-(pow(x - 0.5, 2) + pow(y - 0.5, 2)) / 0.02))
    
        # compare to plain 1-form
        L_check = fd.inner(f, fd.TestFunction(W))*fd.dx
        L_cofun = f.riesz_representation()
    
        # brute force schur complement solver
        params = {
            'ksp_converged_reason': None,
            'ksp_type': 'preonly',
            'pc_type': 'fieldsplit',
            'pc_fieldsplit_type': 'schur',
            'pc_fieldsplit_schur_fact_type': 'full',
            'pc_fieldsplit_schur_precondition': 'full',
            'fieldsplit': {
                'ksp_type': 'preonly',
                'pc_type': 'lu'
            }
        }
    
        w_check = fd.Function(W)
        problem_check = fd.LinearVariationalProblem(a, L_check, w_check)
>       solver_check = fd.LinearVariationalSolver(problem_check,
                                                  solver_parameters=params)

tests/firedrake/regression/test_fieldsplit_cofunction.py:50: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
petsc4py/PETSc/Log.pyx:188: in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
    ???
petsc4py/PETSc/Log.pyx:189: in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
    ???
/usr/lib/python3.12/contextlib.py:81: in inner
    return func(*args, **kwds)
firedrake/adjoint_utils/variational_solver.py:44: in wrapper
    init(self, problem, *args, **kwargs)
firedrake/variational_solver.py:243: in __init__
    ctx.set_jacobian(self.snes)
firedrake/solving_utils.py:288: in set_jacobian
    snes.setJacobian(self.form_jacobian, J=self._jac.petscmat,
/usr/lib/python3.12/functools.py:995: in __get__
    val = self.func(instance)
firedrake/solving_utils.py:493: in _jac
    return self._assembler_jac.allocate()
firedrake/assemble.py:1316: in allocate
    self._make_maps_and_regions())
firedrake/assemble.py:1345: in _make_maps_and_regions
    elif any(local_kernel.indices == (None, None) for assembler in self._all_assemblers for local_kernel, _ in assembler.local_kernels):
firedrake/assemble.py:1345: in <genexpr>
    elif any(local_kernel.indices == (None, None) for assembler in self._all_assemblers for local_kernel, _ in assembler.local_kernels):
/usr/lib/python3.12/functools.py:995: in __get__
    val = self.func(instance)
firedrake/assemble.py:1054: in local_kernels
    kernels = tsfc_interface.compile_form(
petsc4py/PETSc/Log.pyx:188: in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
    ???
petsc4py/PETSc/Log.pyx:189: in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
    ???
pyop2/caching.py:550: in wrapper
    value = func(*args, **kwargs)
petsc4py/PETSc/Log.pyx:188: in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
    ???
petsc4py/PETSc/Log.pyx:189: in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
    ???
pyop2/caching.py:550: in wrapper
    value = func(*args, **kwargs)
petsc4py/PETSc/Log.pyx:188: in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
    ???
petsc4py/PETSc/Log.pyx:189: in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
    ???
firedrake/tsfc_interface.py:220: in compile_form
    tsfc_kernel = TSFCKernel(
firedrake/tsfc_interface.py:95: in __init__
    tree = tsfc_compile_form(form, prefix=name, parameters=parameters,
petsc4py/PETSc/Log.pyx:188: in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
    ???
petsc4py/PETSc/Log.pyx:189: in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
    ???
pyop2/caching.py:550: in wrapper
    value = func(*args, **kwargs)
petsc4py/PETSc/Log.pyx:188: in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
    ???
petsc4py/PETSc/Log.pyx:189: in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
    ???
pyop2/caching.py:550: in wrapper
    value = func(*args, **kwargs)
tsfc/driver.py:68: in compile_form
    fd = ufl_utils.compute_form_data(form, complex_mode=complex_mode)
tsfc/ufl_utils.py:56: in compute_form_data
    fd = ufl_compute_form_data(
../firedrake_venv/src/ufl/ufl/algorithms/compute_form_data.py:427: in compute_form_data
    check_form_arity(preprocessed_form, self.original_form.arguments(), complex_mode)
../firedrake_venv/src/ufl/ufl/algorithms/check_arities.py:213: in check_form_arity
    check_integrand_arity(itg.integrand(), arguments, complex_mode)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

expr = Product(Sum(Sum(Sum(Product(Indexed(ListTensor(Indexed(ComponentTensor(IndexSum(Product(Indexed(ComponentTensor(Produc...ectorElement(FiniteElement('Lagrange', triangle, 1), dim=2), 3106))), MultiIndex((FixedIndex(1), FixedIndex(0))))))))))
arguments = (Argument(WithGeometry(FunctionSpace(<firedrake.mesh.MeshTopology object at 0x7f6680cba000>, FiniteElement('Brezzi-Dou...Marini', triangle, 1), name=None), Mesh(VectorElement(FiniteElement('Lagrange', triangle, 1), dim=2), 3106)), 1, None))
complex_mode = True

    def check_integrand_arity(expr, arguments, complex_mode=False):
        """Check the arity of an integrand."""
        arguments = tuple(sorted(set(arguments), key=lambda x: (x.number(), x.part())))
        rules = ArityChecker(arguments)
        arg_tuples = map_expr_dag(rules, expr, compress=False)
        args = tuple(a[0] for a in arg_tuples)
        if args != arguments:
            raise ArityMismatch(f"Integrand arguments {args} differ from form arguments {arguments}.")
        if complex_mode:
            # Check that the test function is conjugated and that any
            # trial function is not conjugated. Further arguments are
            # treated as trial funtions (i.e. no conjugation) but this
            # might not be correct.
            for arg, conj in arg_tuples:
                if arg.number() == 0 and not conj:
>                   raise ArityMismatch("Failure to conjugate test function in complex Form")
E                   ufl.algorithms.check_arities.ArityMismatch: Failure to conjugate test function in complex Form

../firedrake_venv/src/ufl/ufl/algorithms/check_arities.py:205: ArityMismatch
solver_parameters=params)
solver_check.solve()

w_cofun = fd.Function(W)
problem_cofun = fd.LinearVariationalProblem(a, L_cofun, w_cofun)
solver_cofun = fd.LinearVariationalSolver(problem_cofun,
solver_parameters=params)
solver_cofun.solve()

assert fd.errornorm(w_check, w_cofun) < 1e-14

0 comments on commit 1dd4122

Please sign in to comment.