Skip to content

Commit

Permalink
Split Cofunction
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 2, 2025
2 parents bb04bb0 + bfb7a19 commit d82039d
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 28 deletions.
56 changes: 37 additions & 19 deletions firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,22 @@
import numpy
import collections

from ufl import as_vector, split, ZeroBaseForm
from ufl.classes import Zero, FixedIndex, ListTensor
from ufl import as_vector, split
from ufl.classes import Zero, FixedIndex, ListTensor, ZeroBaseForm
from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.algorithms import expand_derivatives
from ufl.corealg.map_dag import MultiFunction, map_expr_dags

from pyop2 import MixedDat
from pyop2.utils import as_tuple

from firedrake.petsc import PETSc
from firedrake.ufl_expr import Argument
from firedrake.functionspace import MixedFunctionSpace, FunctionSpace
from firedrake.cofunction import Cofunction
from firedrake.functionspace import FunctionSpace, MixedFunctionSpace, DualSpace


def subspace(V, indices):
try:
indices = tuple(indices)
except TypeError:
# Only one index provided.
indices = (indices, )
if len(indices) == 1:
W = V[indices[0]]
W = FunctionSpace(W.mesh(), W.ufl_element())
Expand Down Expand Up @@ -66,7 +65,7 @@ def split(self, form, argument_indices):
"""
args = form.arguments()
self._arg_cache = {}
self.blocks = dict(enumerate(argument_indices))
self.blocks = dict(enumerate(map(as_tuple, argument_indices)))
if len(args) == 0:
# Functional can't be split
return form
Expand All @@ -75,11 +74,13 @@ def split(self, form, argument_indices):
assert (idx[0] == 0 for idx in self.blocks.values())
return form
f = map_integrand_dags(self, form)
f = expand_derivatives(f)
if f.empty():
f = ZeroBaseForm(tuple(Argument(subspace(arg.function_space(), indices),
# TODO find a way to distinguish empty Forms avoiding expand_derivatives
if expand_derivatives(f).empty():
# Get ZeroBaseForm with the right shape
f = ZeroBaseForm(tuple(Argument(subspace(arg.function_space(),
self.blocks[arg.number()]),
arg.number(), part=arg.part())
for arg, indices in zip(form.arguments(), argument_indices)))
for arg in form.arguments()))
return f

expr = MultiFunction.reuse_if_untouched
Expand Down Expand Up @@ -109,6 +110,7 @@ def coefficient_derivative(self, o, expr, coefficients, arguments, cds):
@PETSc.Log.EventDecorator()
def argument(self, o):
V = o.function_space()

if len(V) == 1:
# Not on a mixed space, just return ourselves.
return o
Expand All @@ -118,12 +120,6 @@ def argument(self, o):

indices = self.blocks[o.number()]

try:
indices = tuple(indices)
except TypeError:
# Only one index provided.
indices = (indices, )

W = subspace(V, indices)
a = Argument(W, o.number(), part=o.part())
a = (a, ) if len(W) == 1 else split(a)
Expand All @@ -141,6 +137,28 @@ def argument(self, o):
args.extend(Zero() for j in numpy.ndindex(V[i].value_shape))
return self._arg_cache.setdefault(o, as_vector(args))

def cofunction(self, o):
V = o.function_space()

if len(V) == 1:
# Not on a mixed space, just return ourselves.
return o

try:
indices, = set(self.blocks.values())
except ValueError:
raise ValueError("Cofunction found on an off-diagonal block")

if len(indices) == 1:
i = indices[0]
W = V[i]
W = DualSpace(W.mesh(), W.ufl_element())
c = Cofunction(W, val=o.dat[i])
else:
W = MixedFunctionSpace([V[i] for i in indices])
c = Cofunction(W, val=MixedDat(o.dat[i] for i in indices))
return c


SplitForm = collections.namedtuple("SplitForm", ["indices", "form"])

Expand Down
18 changes: 12 additions & 6 deletions scripts/firedrake-install
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,21 @@ from glob import iglob
from itertools import chain
import re
import importlib


class InstallError(Exception):
# Exception for generic install problems.
pass


try:
from pkg_resources.extern.packaging.version import Version, InvalidVersion
except ModuleNotFoundError:
from packaging.version import Version, InvalidVersion
try:
from packaging.version import Version, InvalidVersion
except ModuleNotFoundError:
raise InstallError("Neither setuptools or packaging found. Please "
"install one of these packages before trying again.")

osname = platform.uname().system
arch = platform.uname().machine
Expand Down Expand Up @@ -52,11 +63,6 @@ firedrake_apps = {
}


class InstallError(Exception):
# Exception for generic install problems.
pass


class FiredrakeConfiguration(dict):
"""A dictionary extended to facilitate the storage of Firedrake
configuration information."""
Expand Down
7 changes: 6 additions & 1 deletion tests/firedrake/regression/test_linesmoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def backend(request):
return request.param


def test_linesmoother(mesh, S1family, expected, backend):
@pytest.mark.parametrize("rhs", ["form_rhs", "cofunc_rhs"])
def test_linesmoother(mesh, S1family, expected, backend, rhs):
base_cell = mesh._base_mesh.ufl_cell()
S2family = "DG" if base_cell.is_simplex() else "DQ"
DGfamily = "DG" if mesh.ufl_cell().is_simplex() else "DQ"
Expand Down Expand Up @@ -86,6 +87,10 @@ def test_linesmoother(mesh, S1family, expected, backend):
f = exp(-rsq)

L = inner(f, q)*dx(degree=2*(degree+1))
if rhs == 'cofunc_rhs':
L = assemble(L)
elif rhs != 'form_rhs':
raise ValueError("Unknown right hand side type")

w0 = Function(W)
problem = LinearVariationalProblem(a, L, w0, bcs=bcs, aP=aP, form_compiler_parameters={"mode": "vanilla"})
Expand Down
7 changes: 6 additions & 1 deletion tests/firedrake/regression/test_matrix_free.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def test_matrixfree_action(a, V, bcs):

@pytest.mark.parametrize("preassembled", [False, True],
ids=["variational", "preassembled"])
@pytest.mark.parametrize("rhs", ["form_rhs", "cofunc_rhs"])
@pytest.mark.parametrize("parameters",
[{"ksp_type": "preonly",
"pc_type": "python",
Expand Down Expand Up @@ -168,7 +169,7 @@ def test_matrixfree_action(a, V, bcs):
"fieldsplit_1_fieldsplit_1_pc_type": "python",
"fieldsplit_1_fieldsplit_1_pc_python_type": "firedrake.AssembledPC",
"fieldsplit_1_fieldsplit_1_assembled_pc_type": "lu"}])
def test_fieldsplitting(mesh, preassembled, parameters):
def test_fieldsplitting(mesh, preassembled, parameters, rhs):
V = FunctionSpace(mesh, "CG", 1)
P = FunctionSpace(mesh, "DG", 0)
Q = VectorFunctionSpace(mesh, "DG", 1)
Expand All @@ -185,6 +186,10 @@ def test_fieldsplitting(mesh, preassembled, parameters):
a = inner(u, v)*dx

L = inner(expect, v)*dx
if rhs == 'cofunc_rhs':
L = assemble(L)
elif rhs != 'form_rhs':
raise ValueError("Unknown right hand side type")

f = Function(W)

Expand Down
7 changes: 6 additions & 1 deletion tests/firedrake/regression/test_nullspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,8 @@ def test_nullspace_mixed_multiple_components():

@pytest.mark.parallel(nprocs=2)
@pytest.mark.parametrize("aux_pc", [False, True], ids=["PC(mu)", "PC(DG0-mu)"])
def test_near_nullspace_mixed(aux_pc):
@pytest.mark.parametrize("rhs", ["form_rhs", "cofunc_rhs"])
def test_near_nullspace_mixed(aux_pc, rhs):
# test nullspace and nearnullspace for a mixed Stokes system
# this is tested on the SINKER case of May and Moresi https://doi.org/10.1016/j.pepi.2008.07.036
# fails in parallel if nullspace is copied to fieldsplit_1_Mp_ksp solve (see PR #3488)
Expand Down Expand Up @@ -323,6 +324,10 @@ def test_near_nullspace_mixed(aux_pc):

f = as_vector((0, -9.8*conditional(inside_box, 2, 1)))
L = inner(f, v)*dx
if rhs == 'cofunc_rhs':
L = assemble(L)
elif rhs != 'form_rhs':
raise ValueError("Unknown right hand side type")

bcs = [DirichletBC(W[0].sub(0), 0, (1, 2)), DirichletBC(W[0].sub(1), 0, (3, 4))]

Expand Down

0 comments on commit d82039d

Please sign in to comment.