Skip to content

Commit

Permalink
Replace empty Jacobians with ZeroBaseForm
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 2, 2025
1 parent 2286596 commit e0c7ba1
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 46 deletions.
8 changes: 3 additions & 5 deletions firedrake/adjoint_utils/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import wraps
from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape, no_annotations
from firedrake.adjoint_utils.blocks import NonlinearVariationalSolveBlock
from firedrake.ufl_expr import derivative, adjoint
from ufl import replace


Expand All @@ -11,7 +12,6 @@ def _ad_annotate_init(init):
@no_annotations
@wraps(init)
def wrapper(self, *args, **kwargs):
from firedrake import derivative, adjoint, TrialFunction
init(self, *args, **kwargs)
self._ad_F = self.F
self._ad_u = self.u_restrict
Expand All @@ -20,10 +20,8 @@ def wrapper(self, *args, **kwargs):
try:
# Some forms (e.g. SLATE tensors) are not currently
# differentiable.
dFdu = derivative(self.F,
self.u_restrict,
TrialFunction(self.u_restrict.function_space()))
self._ad_adj_F = adjoint(dFdu)
dFdu = derivative(self.F, self.u_restrict)
self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True)
except (TypeError, NotImplementedError):
self._ad_adj_F = None
self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear}
Expand Down
11 changes: 8 additions & 3 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,10 +577,15 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
@staticmethod
def update_tensor(assembled_base_form, tensor):
if isinstance(tensor, (firedrake.Function, firedrake.Cofunction)):
assembled_base_form.dat.copy(tensor.dat)
if isinstance(assembled_base_form, ufl.ZeroBaseForm):
tensor.dat.zero()
else:
assembled_base_form.dat.copy(tensor.dat)
elif isinstance(tensor, matrix.MatrixBase):
# Uses the PETSc copy method.
assembled_base_form.petscmat.copy(tensor.petscmat)
if isinstance(assembled_base_form, ufl.ZeroBaseForm):
tensor.petscmat.zero()
else:
assembled_base_form.petscmat.copy(tensor.petscmat)
else:
raise NotImplementedError("Cannot update tensor of type %s" % type(tensor))

Expand Down
44 changes: 23 additions & 21 deletions firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,22 @@
import numpy
import collections

from ufl import as_vector
from ufl import as_vector, split
from ufl.classes import Zero, FixedIndex, ListTensor
from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.corealg.map_dag import MultiFunction, map_expr_dags

from firedrake.petsc import PETSc
from firedrake.ufl_expr import Argument
from firedrake.functionspace import MixedFunctionSpace


def subspace(V, indices):
if len(indices) == 1:
W = V[indices[0]].collapse()
else:
W = MixedFunctionSpace([V[i] for i in indices])
return W


class ExtractSubBlock(MultiFunction):
Expand All @@ -26,9 +35,11 @@ def indexed(self, o, child, multiindex):
indices = multiindex.indices()
if isinstance(child, ListTensor) and all(isinstance(i, FixedIndex) for i in indices):
if len(indices) == 1:
return child.ufl_operands[indices[0]._value]
return child[indices[0]]
elif len(indices) == len(child.ufl_operands) and all(k == int(i) for k, i in enumerate(indices)):
return child
else:
return ListTensor(*(child.ufl_operands[i._value] for i in multiindex.indices()))
return ListTensor(*(child[i] for i in indices))
return self.expr(o, child, multiindex)

index_inliner = IndexInliner()
Expand Down Expand Up @@ -85,8 +96,6 @@ def coefficient_derivative(self, o, expr, coefficients, arguments, cds):

@PETSc.Log.EventDecorator()
def argument(self, o):
from ufl import split
from firedrake import MixedFunctionSpace, FunctionSpace
V = o.function_space()
if len(V) == 1:
# Not on a mixed space, just return ourselves.
Expand All @@ -95,36 +104,29 @@ def argument(self, o):
if o in self._arg_cache:
return self._arg_cache[o]

V_is = V.subfunctions
indices = self.blocks[o.number()]

try:
indices = tuple(indices)
nidx = len(indices)
except TypeError:
# Only one index provided.
indices = (indices, )
nidx = 1

if nidx == 1:
W = V_is[indices[0]]
W = FunctionSpace(W.mesh(), W.ufl_element())
a = (Argument(W, o.number(), part=o.part()), )
else:
W = MixedFunctionSpace([V_is[i] for i in indices])
a = split(Argument(W, o.number(), part=o.part()))
W = subspace(V, indices)
a = Argument(W, o.number(), part=o.part())
a = (a, ) if len(indices) == 1 else split(a)

args = []
for i in range(len(V_is)):
for i in range(len(V)):
if i in indices:
c = indices.index(i)
a_ = a[c]
if len(a_.ufl_shape) == 0:
args += [a_]
args.append(a_)
else:
args += [a_[j] for j in numpy.ndindex(a_.ufl_shape)]
args.extend(a_[j] for j in numpy.ndindex(a_.ufl_shape))
else:
args += [Zero()
for j in numpy.ndindex(V_is[i].value_shape)]
args.extend(Zero() for j in numpy.ndindex(V[i].value_shape))
return self._arg_cache.setdefault(o, as_vector(args))


Expand Down Expand Up @@ -168,7 +170,7 @@ def split_form(form, diagonal=False):
assert len(shape) == 2
for idx in numpy.ndindex(shape):
f = splitter.split(form, idx)
if len(f.integrals()) > 0:
if not f.empty():
if diagonal:
i, j = idx
if i != j:
Expand Down
16 changes: 12 additions & 4 deletions firedrake/matrix_free/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import numpy

from pyop2.mpi import internal_comm, temp_internal_comm
from firedrake.ufl_expr import adjoint, action
from firedrake.formmanipulation import ExtractSubBlock
from firedrake.ufl_expr import adjoint, action, TestFunction, TrialFunction
from firedrake.formmanipulation import ExtractSubBlock, subspace
from firedrake.bcs import DirichletBC, EquationBCSplit
from firedrake.petsc import PETSc
from firedrake.utils import cached_property
from ufl.form import ZeroBaseForm
from ufl.algorithms import expand_derivatives


__all__ = ("ImplicitMatrixContext", )
Expand Down Expand Up @@ -383,8 +385,14 @@ def createSubMatrix(self, mat, row_is, col_is, target=None):
splitter = ExtractSubBlock()
asub = splitter.split(self.a,
argument_indices=(row_inds, col_inds))
Wrow = asub.arguments()[0].function_space()
Wcol = asub.arguments()[1].function_space()
asub = expand_derivatives(asub)
if asub.empty():
Wrow = subspace(self.a.arguments()[0].function_space(), row_inds)
Wcol = subspace(self.a.arguments()[1].function_space(), col_inds)
asub = ZeroBaseForm((TestFunction(Wrow), TrialFunction(Wcol)))
else:
Wrow = asub.arguments()[0].function_space()
Wcol = asub.arguments()[1].function_space()

row_bcs = []
col_bcs = []
Expand Down
2 changes: 1 addition & 1 deletion firedrake/preconditioners/massinv.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class MassInvPC(AssembledPC):
context, keyed on ``"mu"``.
"""
def form(self, pc, test, trial):
_, bcs = super(MassInvPC, self).form(pc, test, trial)
_, bcs = super(MassInvPC, self).form(pc)

appctx = self.get_appctx(pc)
mu = appctx.get("mu", 1.0)
Expand Down
14 changes: 9 additions & 5 deletions firedrake/solving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
from firedrake.formmanipulation import ExtractSubBlock
from firedrake.utils import cached_property
from firedrake.logging import warning
from firedrake.ufl_expr import TestFunction, TrialFunction
from ufl.form import ZeroBaseForm


def _make_reasons(reasons):
return dict([(getattr(reasons, r), r)
for r in dir(reasons) if not r.startswith('_')])
return {getattr(reasons, r): r
for r in dir(reasons) if not r.startswith('_')}


KSPReasons = _make_reasons(PETSc.KSP.ConvergedReason())
Expand Down Expand Up @@ -333,7 +335,7 @@ def split(self, fields):
# Split it apart to shove in the form.
subsplit = split(subu)
# Permutation from field indexing to indexing of pieces
field_renumbering = dict([f, i] for i, f in enumerate(field))
field_renumbering = {f: i for i, f in enumerate(field)}
vec = []
for i, u in enumerate(us):
if i in field:
Expand All @@ -344,8 +346,7 @@ def split(self, fields):
if u.ufl_shape == ():
vec.append(u)
else:
for idx in numpy.ndindex(u.ufl_shape):
vec.append(u[idx])
vec.extend(u[idx] for idx in numpy.ndindex(u.ufl_shape))

# So now we have a new representation for the solution
# vector in the old problem. For the fields we're going
Expand All @@ -359,6 +360,9 @@ def split(self, fields):
u = as_vector(vec)
F = replace(F, {problem.u_restrict: u})
J = replace(J, {problem.u_restrict: u})
if J.empty():
# Handle zero Jacobian
J = ZeroBaseForm((TestFunction(V), TrialFunction(V)))
if problem.Jp is not None:
Jp = splitter.split(problem.Jp, argument_indices=(field, field))
Jp = replace(Jp, {problem.u_restrict: u})
Expand Down
11 changes: 4 additions & 7 deletions tests/firedrake/slate/test_assemble_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,10 @@ def test_matrix_subblocks(mesh):
refs = dict(split_form(A.form))
_A = A.blocks
for x, y in indices:
ref = assemble(refs[x, y]).M.values
block = _A[x, y]
assert np.allclose(assemble(block).M.values, ref, rtol=1e-14)
if not block.form.empty():
ref = assemble(refs[x, y]).M.values
assert np.allclose(assemble(block).M.values, ref, rtol=1e-14)

# Mixed blocks
A0101 = _A[:2, :2]
Expand All @@ -267,17 +268,13 @@ def test_matrix_subblocks(mesh):
A0101_10 = _A0101[1, 0]
A1212_00 = _A1212[0, 0]
A1212_11 = _A1212[1, 1]
A1212_01 = _A1212[0, 1]
A1212_10 = _A1212[1, 0]

items = [(A0101_00, refs[(0, 0)]),
(A0101_11, refs[(1, 1)]),
(A0101_01, refs[(0, 1)]),
(A0101_10, refs[(1, 0)]),
(A1212_00, refs[(1, 1)]),
(A1212_11, refs[(2, 2)]),
(A1212_01, refs[(1, 2)]),
(A1212_10, refs[(2, 1)])]
(A1212_11, refs[(2, 2)])]

# Test assembly of blocks of mixed blocks
for tensor, form in items:
Expand Down

0 comments on commit e0c7ba1

Please sign in to comment.