Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fieldsplit: replace empty Forms with ZeroBaseForm #3947

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ jobs:
--install defcon \
--install gadopt \
--install asQ \
--package-branch ufl pbrubeck/simplify-indexed \
|| (cat firedrake-install.log && /bin/false)
- name: Install test dependencies
run: |
Expand Down
13 changes: 8 additions & 5 deletions firedrake/adjoint_utils/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import wraps
from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape, no_annotations
from firedrake.adjoint_utils.blocks import NonlinearVariationalSolveBlock
from firedrake.ufl_expr import derivative, adjoint
from ufl import replace


Expand All @@ -11,7 +12,6 @@ def _ad_annotate_init(init):
@no_annotations
@wraps(init)
def wrapper(self, *args, **kwargs):
from firedrake import derivative, adjoint, TrialFunction
init(self, *args, **kwargs)
self._ad_F = self.F
self._ad_u = self.u_restrict
Expand All @@ -20,10 +20,13 @@ def wrapper(self, *args, **kwargs):
try:
# Some forms (e.g. SLATE tensors) are not currently
# differentiable.
dFdu = derivative(self.F,
self.u_restrict,
TrialFunction(self.u_restrict.function_space()))
self._ad_adj_F = adjoint(dFdu)
dFdu = derivative(self.F, self.u_restrict)
try:
self._ad_adj_F = adjoint(dFdu)
except ValueError:
# Try again without expanding derivatives,
# as dFdu might have been simplied to an empty Form
self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True)
pbrubeck marked this conversation as resolved.
Show resolved Hide resolved
except (TypeError, NotImplementedError):
self._ad_adj_F = None
self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear}
Expand Down
22 changes: 13 additions & 9 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,10 +577,15 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
@staticmethod
def update_tensor(assembled_base_form, tensor):
if isinstance(tensor, (firedrake.Function, firedrake.Cofunction)):
assembled_base_form.dat.copy(tensor.dat)
if isinstance(assembled_base_form, ufl.ZeroBaseForm):
tensor.dat.zero()
else:
assembled_base_form.dat.copy(tensor.dat)
elif isinstance(tensor, matrix.MatrixBase):
# Uses the PETSc copy method.
assembled_base_form.petscmat.copy(tensor.petscmat)
if isinstance(assembled_base_form, ufl.ZeroBaseForm):
tensor.petscmat.zeroEntries()
else:
assembled_base_form.petscmat.copy(tensor.petscmat)
else:
raise NotImplementedError("Cannot update tensor of type %s" % type(tensor))

Expand Down Expand Up @@ -1138,7 +1143,7 @@ class OneFormAssembler(ParloopFormAssembler):

Parameters
----------
form : ufl.Form or slate.TensorBasehe
form : ufl.Form or slate.TensorBase
1-form.

Notes
Expand Down Expand Up @@ -2127,14 +2132,13 @@ def iter_active_coefficients(form, kinfo):

@staticmethod
def iter_constants(form, kinfo):
"""Yield the form constants"""
"""Yield the form constants referenced in ``kinfo``."""
if isinstance(form, slate.TensorBase):
for const in form.constants():
yield const
all_constants = form.constants()
else:
all_constants = extract_firedrake_constants(form)
for constant_index in kinfo.constant_numbers:
yield all_constants[constant_index]
for constant_index in kinfo.constant_numbers:
yield all_constants[constant_index]

@staticmethod
def index_function_spaces(form, indices):
Expand Down
8 changes: 4 additions & 4 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,10 +634,10 @@ def reconstruct(self, field=None, V=None, subu=None, u=None, row_field=None, col
return
rank = len(self.f.arguments())
splitter = ExtractSubBlock()
if rank == 1:
form = splitter.split(self.f, argument_indices=(row_field, ))
elif rank == 2:
form = splitter.split(self.f, argument_indices=(row_field, col_field))
form = splitter.split(self.f, argument_indices=(row_field, col_field)[:rank])
if isinstance(form, ufl.ZeroBaseForm) or form.empty():
# form is empty, do nothing
return
if u is not None:
form = firedrake.replace(form, {self.u: u})
if action_x is not None:
Expand Down
116 changes: 35 additions & 81 deletions firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@
import numpy
import collections

from ufl import as_vector, FormSum, Form, split
from ufl import as_vector, split
from ufl.classes import Zero, FixedIndex, ListTensor, ZeroBaseForm
from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.algorithms import expand_derivatives
from ufl.corealg.map_dag import MultiFunction, map_expr_dags

from pyop2 import MixedDat
from pyop2.utils import as_tuple

from firedrake.petsc import PETSc
from firedrake.ufl_expr import Argument
from firedrake.cofunction import Cofunction
from firedrake.functionspace import FunctionSpace, MixedFunctionSpace, DualSpace


class ExtractSubBlock(MultiFunction):
Expand All @@ -30,13 +30,19 @@ def indexed(self, o, child, multiindex):
indices = multiindex.indices()
if isinstance(child, ListTensor) and all(isinstance(i, FixedIndex) for i in indices):
if len(indices) == 1:
return child.ufl_operands[indices[0]._value]
return child[indices[0]]
elif len(indices) == len(child.ufl_operands) and all(k == int(i) for k, i in enumerate(indices)):
return child
else:
return ListTensor(*(child.ufl_operands[i._value] for i in multiindex.indices()))
return ListTensor(*(child[i] for i in indices))
return self.expr(o, child, multiindex)

index_inliner = IndexInliner()

def _subspace_argument(self, a):
return type(a)(a.function_space()[self.blocks[a.number()]].collapse(),
a.number(), part=a.part())

@PETSc.Log.EventDecorator()
def split(self, form, argument_indices):
"""Split a form.
Expand All @@ -52,15 +58,19 @@ def split(self, form, argument_indices):
"""
args = form.arguments()
self._arg_cache = {}
self.blocks = dict(enumerate(argument_indices))
self.blocks = dict(enumerate(map(as_tuple, argument_indices)))
if len(args) == 0:
# Functional can't be split
return form
if all(len(a.function_space()) == 1 for a in args):
assert (len(idx) == 1 for idx in self.blocks.values())
assert (idx[0] == 0 for idx in self.blocks.values())
return form
# TODO find a way to distinguish empty Forms avoiding expand_derivatives
ksagiyam marked this conversation as resolved.
Show resolved Hide resolved
f = map_integrand_dags(self, form)
if expand_derivatives(f).empty():
# Get ZeroBaseForm with the right shape
f = ZeroBaseForm(tuple(map(self._subspace_argument, form.arguments())))
return f

expr = MultiFunction.reuse_if_untouched
Expand Down Expand Up @@ -98,77 +108,34 @@ def argument(self, o):
if o in self._arg_cache:
return self._arg_cache[o]

V_is = V.subfunctions
indices = self.blocks[o.number()]

# Only one index provided.
if isinstance(indices, int):
indices = (indices, )
a = self._subspace_argument(o)
asplit = (a, ) if len(indices) == 1 else split(a)

if len(indices) == 1:
W = V_is[indices[0]]
W = FunctionSpace(W.mesh(), W.ufl_element())
a = (Argument(W, o.number(), part=o.part()), )
else:
W = MixedFunctionSpace([V_is[i] for i in indices])
a = split(Argument(W, o.number(), part=o.part()))
args = []
for i in range(len(V_is)):
for i in range(len(V)):
if i in indices:
c = indices.index(i)
a_ = a[c]
if len(a_.ufl_shape) == 0:
args += [a_]
else:
args += [a_[j] for j in numpy.ndindex(a_.ufl_shape)]
asub = asplit[indices.index(i)]
args.extend(asub[j] for j in numpy.ndindex(asub.ufl_shape))
else:
args += [Zero()
for j in numpy.ndindex(V_is[i].value_shape)]
args.extend(Zero() for j in numpy.ndindex(V[i].value_shape))
return self._arg_cache.setdefault(o, as_vector(args))

def cofunction(self, o):
V = o.function_space()

# Not on a mixed space, just return ourselves.
if len(V) == 1:
# Not on a mixed space, just return ourselves.
return o

# We only need the test space for Cofunction
indices = self.blocks[0]
V_is = V.subfunctions

# Only one index provided.
if isinstance(indices, int):
indices = (indices, )

# for two-forms, the cofunction should only
# be returned for the diagonal blocks, so
# if we are asked for an off-diagonal block
# then we return a zero form, analogously to
# the off components of arguments.
if len(self.blocks) == 2:
itest, itrial = self.blocks
on_diag = (itest == itrial)
W = V[indices].collapse()
if len(W) == 1:
return Cofunction(W, val=o.dat[indices[0]])
else:
on_diag = True

# if we are on the diagonal, then return a Cofunction
# in the relevant subspace that points to the data in
# the full space. This means that the right hand side
# of the fieldsplit problem will be correct.
if on_diag:
if len(indices) == 1:
i = indices[0]
W = V_is[i]
W = DualSpace(W.mesh(), W.ufl_element())
c = Cofunction(W, val=o.subfunctions[i].dat)
else:
W = MixedFunctionSpace([V_is[i] for i in indices])
c = Cofunction(W, val=MixedDat(o.dat[i] for i in indices))
else:
c = ZeroBaseForm(o.arguments())

return c
return Cofunction(W, val=MixedDat(o.dat[i] for i in indices))


SplitForm = collections.namedtuple("SplitForm", ["indices", "form"])
Expand Down Expand Up @@ -207,28 +174,15 @@ def split_form(form, diagonal=False):
args = form.arguments()
shape = tuple(len(a.function_space()) for a in args)
forms = []
rank = len(shape)
if diagonal:
assert len(shape) == 2
assert rank == 2
rank = 1
for idx in numpy.ndindex(shape):
if diagonal:
i, j = idx
if i != j:
continue
f = splitter.split(form, idx)

# does f actually contain anything?
if isinstance(f, Cofunction):
flen = 1
elif isinstance(f, FormSum):
flen = len(f.components())
elif isinstance(f, Form):
flen = len(f.integrals())
else:
raise ValueError(
"ExtractSubBlock.split should have returned an instance of "
"either Form, FormSum, or Cofunction")

if flen > 0:
if diagonal:
i, j = idx
if i != j:
continue
idx = (i, )
forms.append(SplitForm(indices=idx, form=f))
forms.append(SplitForm(indices=idx[:rank], form=f))
return tuple(forms)
15 changes: 13 additions & 2 deletions firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,14 @@ def __iter__(self):
return iter(self.subfunctions)

def __getitem__(self, i):
from firedrake.functionspace import MixedFunctionSpace
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should cache the mixed subspaces like we do for the non-mixed subspaces. Otherwise we get a "different" space every time we call with the same indices.

if isinstance(i, (tuple, list)):
# Return a subspace
if len(i) == 1:
return self[i[0]]
else:
return MixedFunctionSpace([self[isub] for isub in i])

return self.subfunctions[i]

def __mul__(self, other):
Expand Down Expand Up @@ -944,6 +952,9 @@ def __hash__(self):
def local_to_global_map(self, bcs, lgmap=None):
return lgmap or self.dof_dset.lgmap

def collapse(self):
return type(self)(self.function_space.collapse(), boundary_set=self.boundary_set)


class MixedFunctionSpace(object):
r"""A function space on a mixed finite element.
Expand Down Expand Up @@ -1236,16 +1247,16 @@ class ProxyRestrictedFunctionSpace(RestrictedFunctionSpace):
r"""A :class:`RestrictedFunctionSpace` that one can attach extra properties to.

:arg function_space: The function space to be restricted.
:kwarg name: The name of the restricted function space.
:kwarg boundary_set: The boundary domains on which boundary conditions will
be specified
:kwarg name: The name of the restricted function space.

.. warning::

Users should not build a :class:`ProxyRestrictedFunctionSpace` directly,
it is mostly used as an internal implementation detail.
"""
def __new__(cls, function_space, name=None, boundary_set=frozenset()):
def __new__(cls, function_space, boundary_set=frozenset(), name=None):
topology = function_space._mesh.topology
self = super(ProxyRestrictedFunctionSpace, cls).__new__(cls)
if function_space._mesh is not topology:
Expand Down
Loading
Loading