Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fieldsplit: replace empty Forms with ZeroBaseForm #3947

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ jobs:
--install defcon \
--install gadopt \
--install asQ \
--package-branch ufl pbrubeck/simplify-indexed \
|| (cat firedrake-install.log && /bin/false)
- name: Install test dependencies
run: |
Expand Down
13 changes: 8 additions & 5 deletions firedrake/adjoint_utils/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import wraps
from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape, no_annotations
from firedrake.adjoint_utils.blocks import NonlinearVariationalSolveBlock
from firedrake.ufl_expr import derivative, adjoint
from ufl import replace


Expand All @@ -11,7 +12,6 @@ def _ad_annotate_init(init):
@no_annotations
@wraps(init)
def wrapper(self, *args, **kwargs):
from firedrake import derivative, adjoint, TrialFunction
init(self, *args, **kwargs)
self._ad_F = self.F
self._ad_u = self.u_restrict
Expand All @@ -20,10 +20,13 @@ def wrapper(self, *args, **kwargs):
try:
# Some forms (e.g. SLATE tensors) are not currently
# differentiable.
dFdu = derivative(self.F,
self.u_restrict,
TrialFunction(self.u_restrict.function_space()))
self._ad_adj_F = adjoint(dFdu)
dFdu = derivative(self.F, self.u_restrict)
try:
self._ad_adj_F = adjoint(dFdu)
except ValueError:
# Try again without expanding derivatives,
# as dFdu might have been simplied to an empty Form
self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True)
pbrubeck marked this conversation as resolved.
Show resolved Hide resolved
except (TypeError, NotImplementedError):
self._ad_adj_F = None
self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear}
Expand Down
22 changes: 13 additions & 9 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,10 +577,15 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
@staticmethod
def update_tensor(assembled_base_form, tensor):
if isinstance(tensor, (firedrake.Function, firedrake.Cofunction)):
assembled_base_form.dat.copy(tensor.dat)
if isinstance(assembled_base_form, ufl.ZeroBaseForm):
tensor.dat.zero()
else:
assembled_base_form.dat.copy(tensor.dat)
elif isinstance(tensor, matrix.MatrixBase):
# Uses the PETSc copy method.
assembled_base_form.petscmat.copy(tensor.petscmat)
if isinstance(assembled_base_form, ufl.ZeroBaseForm):
tensor.petscmat.zeroEntries()
else:
assembled_base_form.petscmat.copy(tensor.petscmat)
else:
raise NotImplementedError("Cannot update tensor of type %s" % type(tensor))

Expand Down Expand Up @@ -1138,7 +1143,7 @@ class OneFormAssembler(ParloopFormAssembler):

Parameters
----------
form : ufl.Form or slate.TensorBasehe
form : ufl.Form or slate.TensorBase
1-form.

Notes
Expand Down Expand Up @@ -2127,14 +2132,13 @@ def iter_active_coefficients(form, kinfo):

@staticmethod
def iter_constants(form, kinfo):
"""Yield the form constants"""
"""Yield the form constants referenced in ``kinfo``."""
if isinstance(form, slate.TensorBase):
for const in form.constants():
yield const
all_constants = form.constants()
else:
all_constants = extract_firedrake_constants(form)
for constant_index in kinfo.constant_numbers:
yield all_constants[constant_index]
for constant_index in kinfo.constant_numbers:
yield all_constants[constant_index]

@staticmethod
def index_function_spaces(form, indices):
Expand Down
8 changes: 4 additions & 4 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,10 +634,10 @@ def reconstruct(self, field=None, V=None, subu=None, u=None, row_field=None, col
return
rank = len(self.f.arguments())
splitter = ExtractSubBlock()
if rank == 1:
form = splitter.split(self.f, argument_indices=(row_field, ))
elif rank == 2:
form = splitter.split(self.f, argument_indices=(row_field, col_field))
form = splitter.split(self.f, argument_indices=(row_field, col_field)[:rank])
if isinstance(form, ufl.ZeroBaseForm) or form.empty():
# form is empty, do nothing
return
if u is not None:
form = firedrake.replace(form, {self.u: u})
if action_x is not None:
Expand Down
123 changes: 48 additions & 75 deletions firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,30 @@
import numpy
import collections

from ufl import as_vector, FormSum, Form, split
from ufl import as_vector, split
from ufl.classes import Zero, FixedIndex, ListTensor, ZeroBaseForm
from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.algorithms import expand_derivatives
from ufl.corealg.map_dag import MultiFunction, map_expr_dags

from pyop2 import MixedDat
from pyop2.utils import as_tuple

from firedrake.petsc import PETSc
from firedrake.ufl_expr import Argument
from firedrake.cofunction import Cofunction
from firedrake.functionspace import FunctionSpace, MixedFunctionSpace, DualSpace


def subspace(V, indices):
if len(indices) == 1:
W = V[indices[0]]
W = FunctionSpace(W.mesh(), W.ufl_element())
else:
W = MixedFunctionSpace([V[i] for i in indices])
return W


class ExtractSubBlock(MultiFunction):

"""Extract a sub-block from a form."""
Expand All @@ -30,9 +41,11 @@ def indexed(self, o, child, multiindex):
indices = multiindex.indices()
if isinstance(child, ListTensor) and all(isinstance(i, FixedIndex) for i in indices):
if len(indices) == 1:
return child.ufl_operands[indices[0]._value]
return child[indices[0]]
elif len(indices) == len(child.ufl_operands) and all(k == int(i) for k, i in enumerate(indices)):
return child
else:
return ListTensor(*(child.ufl_operands[i._value] for i in multiindex.indices()))
return ListTensor(*(child[i] for i in indices))
return self.expr(o, child, multiindex)

index_inliner = IndexInliner()
Expand All @@ -52,15 +65,22 @@ def split(self, form, argument_indices):
"""
args = form.arguments()
self._arg_cache = {}
self.blocks = dict(enumerate(argument_indices))
self.blocks = dict(enumerate(map(as_tuple, argument_indices)))
if len(args) == 0:
# Functional can't be split
return form
if all(len(a.function_space()) == 1 for a in args):
assert (len(idx) == 1 for idx in self.blocks.values())
assert (idx[0] == 0 for idx in self.blocks.values())
return form
# TODO find a way to distinguish empty Forms avoiding expand_derivatives
ksagiyam marked this conversation as resolved.
Show resolved Hide resolved
f = map_integrand_dags(self, form)
if expand_derivatives(f).empty():
# Get ZeroBaseForm with the right shape
f = ZeroBaseForm(tuple(Argument(subspace(arg.function_space(),
self.blocks[arg.number()]),
arg.number(), part=arg.part())
for arg in form.arguments()))
return f

expr = MultiFunction.reuse_if_untouched
Expand Down Expand Up @@ -98,76 +118,42 @@ def argument(self, o):
if o in self._arg_cache:
return self._arg_cache[o]

V_is = V.subfunctions
indices = self.blocks[o.number()]

# Only one index provided.
if isinstance(indices, int):
indices = (indices, )
W = subspace(V, indices)
a = Argument(W, o.number(), part=o.part())
a = (a, ) if len(W) == 1 else split(a)

if len(indices) == 1:
W = V_is[indices[0]]
W = FunctionSpace(W.mesh(), W.ufl_element())
a = (Argument(W, o.number(), part=o.part()), )
else:
W = MixedFunctionSpace([V_is[i] for i in indices])
a = split(Argument(W, o.number(), part=o.part()))
args = []
for i in range(len(V_is)):
for i in range(len(V)):
if i in indices:
c = indices.index(i)
a_ = a[c]
if len(a_.ufl_shape) == 0:
args += [a_]
args.append(a_)
else:
args += [a_[j] for j in numpy.ndindex(a_.ufl_shape)]
args.extend(a_[j] for j in numpy.ndindex(a_.ufl_shape))
else:
args += [Zero()
for j in numpy.ndindex(V_is[i].value_shape)]
args.extend(Zero() for j in numpy.ndindex(V[i].value_shape))
return self._arg_cache.setdefault(o, as_vector(args))

def cofunction(self, o):
V = o.function_space()

# Not on a mixed space, just return ourselves.
if len(V) == 1:
# Not on a mixed space, just return ourselves.
return o

# We only need the test space for Cofunction
# We only need the test space for Cofunction
pbrubeck marked this conversation as resolved.
Show resolved Hide resolved
indices = self.blocks[0]
V_is = V.subfunctions

# Only one index provided.
if isinstance(indices, int):
indices = (indices, )

# for two-forms, the cofunction should only
# be returned for the diagonal blocks, so
# if we are asked for an off-diagonal block
# then we return a zero form, analogously to
# the off components of arguments.
if len(self.blocks) == 2:
itest, itrial = self.blocks
on_diag = (itest == itrial)
else:
on_diag = True

# if we are on the diagonal, then return a Cofunction
# in the relevant subspace that points to the data in
# the full space. This means that the right hand side
# of the fieldsplit problem will be correct.
if on_diag:
if len(indices) == 1:
i = indices[0]
W = V_is[i]
W = DualSpace(W.mesh(), W.ufl_element())
c = Cofunction(W, val=o.subfunctions[i].dat)
else:
W = MixedFunctionSpace([V_is[i] for i in indices])
c = Cofunction(W, val=MixedDat(o.dat[i] for i in indices))
if len(indices) == 1:
i = indices[0]
W = V[i]
W = DualSpace(W.mesh(), W.ufl_element())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use the subspace function here too? It will need a switch adding for primal/dual but I think it would be nicer to use subspace for both.

I think it would also be nice to lift subspace out, possibly into one of the functionspace files. It's useful more generally - see for example line 678 in slate.py where we have the same if/else block to construct the subspace.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking to extend FunctionSpace.__getitem__ to allow a list index return the subspace, since this is how sub-arrays are extracted in numpy.

Copy link
Contributor Author

@pbrubeck pbrubeck Jan 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The if statement can be avoided by using FunctionSpace.collapse()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was actually thinking exactly the same thing the other day.

It would be great to have it for Function too so that you can get a Mixed subfunction viewing the relevant subcomponents. ExtractSubBlock.cofunction would then basically just be calling cofunction[idxs].
There are definitely other places where that would be useful too, for example in the EnsembleFunction I've been working on here.

What does FunctionSpace.collapse() do? It just looks like it would return a copy of the FunctionSpace.

Copy link
Contributor Author

@pbrubeck pbrubeck Jan 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ksagiyam what do you think about this __getitem__ idea? Currenty, indexing by a scalar returns IndexedProxyFunctionSpace, but indexing by a list of a single item would return the collapsed version of this space.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

__getitem__ for Function is implemented in UFL so the changes would be needed there I think.

Copy link
Contributor Author

@pbrubeck pbrubeck Jan 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not too worried for the Function case, that one seems more involved, as they are Expressions in the UFL sense, and extracting a component naturally means a scalar compoment, but here it seems that we want a MixedFunctionSpace (possibly vector-valued) component.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could also be added to sub or subfunctions, although they both also have issues.
sub explicitly says it is for bcs, and subfunctions is already a tuple not a method so it would probably have to become a local class with __getitem__ overridden so we could deal with single or multiple indices.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds slightly odd to me. Can you instead rewrite collapse() so that it would take a list of indices as an optional argument?

@JHopeCollins collapse() basically forgets parents. They do something nontrivial in ProxyFunctionSpace and in MixedFunctionSpace.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok thanks.

It does strip the parent information, I don't think the name collapse would be where I would look for this functionality - it would be more intuitive in subfunctions or sub.

c = Cofunction(W, val=o.dat[i])
else:
c = ZeroBaseForm(o.arguments())

W = MixedFunctionSpace([V[i] for i in indices])
c = Cofunction(W, val=MixedDat(o.dat[i] for i in indices))
return c


Expand Down Expand Up @@ -207,28 +193,15 @@ def split_form(form, diagonal=False):
args = form.arguments()
shape = tuple(len(a.function_space()) for a in args)
forms = []
rank = len(shape)
if diagonal:
assert len(shape) == 2
assert rank == 2
rank = 1
for idx in numpy.ndindex(shape):
if diagonal:
i, j = idx
if i != j:
continue
f = splitter.split(form, idx)

# does f actually contain anything?
if isinstance(f, Cofunction):
flen = 1
elif isinstance(f, FormSum):
flen = len(f.components())
elif isinstance(f, Form):
flen = len(f.integrals())
else:
raise ValueError(
"ExtractSubBlock.split should have returned an instance of "
"either Form, FormSum, or Cofunction")

if flen > 0:
if diagonal:
i, j = idx
if i != j:
continue
idx = (i, )
forms.append(SplitForm(indices=idx, form=f))
forms.append(SplitForm(indices=idx[:rank], form=f))
return tuple(forms)
28 changes: 17 additions & 11 deletions firedrake/matrix_free/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from firedrake.bcs import DirichletBC, EquationBCSplit
from firedrake.petsc import PETSc
from firedrake.utils import cached_property
from firedrake.function import Function
from firedrake.cofunction import Cofunction
from ufl.form import ZeroBaseForm


__all__ = ("ImplicitMatrixContext", )
Expand Down Expand Up @@ -107,23 +110,22 @@ def __init__(self, a, row_bcs=[], col_bcs=[],

# create functions from test and trial space to help
# with 1-form assembly
test_space, trial_space = [
a.arguments()[i].function_space() for i in (0, 1)
]
from firedrake import function, cofunction
test_space, trial_space = (
arg.function_space() for arg in a.arguments()
)
# Need a cofunction since y receives the assembled result of Ax
self._ystar = cofunction.Cofunction(test_space.dual())
self._y = function.Function(test_space)
self._x = function.Function(trial_space)
self._xstar = cofunction.Cofunction(trial_space.dual())
self._ystar = Cofunction(test_space.dual())
self._y = Function(test_space)
self._x = Function(trial_space)
self._xstar = Cofunction(trial_space.dual())

# These are temporary storage for holding the BC
# values during matvec application. _xbc is for
# the action and ._ybc is for transpose.
if len(self.bcs) > 0:
self._xbc = cofunction.Cofunction(trial_space.dual())
self._xbc = Cofunction(trial_space.dual())
if len(self.col_bcs) > 0:
self._ybc = cofunction.Cofunction(test_space.dual())
self._ybc = Cofunction(test_space.dual())

# Get size information from template vecs on test and trial spaces
trial_vec = trial_space.dof_dset.layout_vec
Expand All @@ -135,6 +137,11 @@ def __init__(self, a, row_bcs=[], col_bcs=[],

self.action = action(self.a, self._x)
self.actionT = action(self.aT, self._y)
# TODO prevent action from returning empty Forms
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this to do in this PR or at a later date? If it won't be part of this PR please can you open an issue for it to keep a record?

if self.action.empty():
self.action = ZeroBaseForm(self.a.arguments()[:-1])
if self.actionT.empty():
self.actionT = ZeroBaseForm(self.aT.arguments()[:-1])

# For assembling action(f, self._x)
self.bcs_action = []
Expand Down Expand Up @@ -170,7 +177,6 @@ def __init__(self, a, row_bcs=[], col_bcs=[],

@cached_property
def _diagonal(self):
from firedrake import Cofunction
assert self.on_diag
return Cofunction(self._x.function_space().dual())

Expand Down
2 changes: 1 addition & 1 deletion firedrake/preconditioners/massinv.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class MassInvPC(AssembledPC):
context, keyed on ``"mu"``.
"""
def form(self, pc, test, trial):
_, bcs = super(MassInvPC, self).form(pc, test, trial)
_, bcs = super(MassInvPC, self).form(pc)

appctx = self.get_appctx(pc)
mu = appctx.get("mu", 1.0)
Expand Down
Loading
Loading