From 35638dfcbdf7ab2a4b1dcf575a75c76ba358435b Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 3 Jan 2025 11:26:48 -0600 Subject: [PATCH] ImplicitMatrixContext: handle empty action --- firedrake/assemble.py | 13 ++++++++----- firedrake/matrix_free/operators.py | 27 ++++++++++++++++----------- firedrake/slate/slate.py | 8 ++++++++ 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 60c934b6c7..88d00c6db8 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -311,7 +311,8 @@ def __init__(self, zero_bc_nodes=False, diagonal=False, weight=1.0, - allocation_integral_types=None): + allocation_integral_types=None, + needs_zeroing=False): super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters) self._mat_type = mat_type self._sub_mat_type = sub_mat_type @@ -321,6 +322,7 @@ def __init__(self, self._diagonal = diagonal self._weight = weight self._allocation_integral_types = allocation_integral_types + assert not needs_zeroing def allocate(self): rank = len(self._form.arguments()) @@ -1127,7 +1129,8 @@ def _apply_bc(self, tensor, bc): pass def _check_tensor(self, tensor): - pass + if not isinstance(tensor, op2.Global): + raise TypeError(f"Expecting a op2.Global, got {tensor!r}.") @staticmethod def _as_pyop2_type(tensor, indices=None): @@ -1143,7 +1146,7 @@ class OneFormAssembler(ParloopFormAssembler): Parameters ---------- - form : ufl.Form or slate.TensorBasehe + form : ufl.Form or slate.TensorBase 1-form. Notes @@ -1189,8 +1192,8 @@ def _apply_bc(self, tensor, bc): self._apply_dirichlet_bc(tensor, bc) elif isinstance(bc, EquationBCSplit): bc.zero(tensor) - type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, - zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) + get_assembler(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, + zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) else: raise AssertionError diff --git a/firedrake/matrix_free/operators.py b/firedrake/matrix_free/operators.py index 3ee448730e..29abca838a 100644 --- a/firedrake/matrix_free/operators.py +++ b/firedrake/matrix_free/operators.py @@ -10,6 +10,8 @@ from firedrake.bcs import DirichletBC, EquationBCSplit from firedrake.petsc import PETSc from firedrake.utils import cached_property +from firedrake.function import Function +from firedrake.cofunction import Cofunction __all__ = ("ImplicitMatrixContext", ) @@ -107,23 +109,22 @@ def __init__(self, a, row_bcs=[], col_bcs=[], # create functions from test and trial space to help # with 1-form assembly - test_space, trial_space = [ - a.arguments()[i].function_space() for i in (0, 1) - ] - from firedrake import function, cofunction + test_space, trial_space = ( + arg.function_space() for arg in a.arguments() + ) # Need a cofunction since y receives the assembled result of Ax - self._ystar = cofunction.Cofunction(test_space.dual()) - self._y = function.Function(test_space) - self._x = function.Function(trial_space) - self._xstar = cofunction.Cofunction(trial_space.dual()) + self._ystar = Cofunction(test_space.dual()) + self._y = Function(test_space) + self._x = Function(trial_space) + self._xstar = Cofunction(trial_space.dual()) # These are temporary storage for holding the BC # values during matvec application. _xbc is for # the action and ._ybc is for transpose. if len(self.bcs) > 0: - self._xbc = cofunction.Cofunction(trial_space.dual()) + self._xbc = Cofunction(trial_space.dual()) if len(self.col_bcs) > 0: - self._ybc = cofunction.Cofunction(test_space.dual()) + self._ybc = Cofunction(test_space.dual()) # Get size information from template vecs on test and trial spaces trial_vec = trial_space.dof_dset.layout_vec @@ -135,6 +136,11 @@ def __init__(self, a, row_bcs=[], col_bcs=[], self.action = action(self.a, self._x) self.actionT = action(self.aT, self._y) + # TODO prevent action from returning empty Forms + if self.action.empty(): + self.action = Cofunction(test_space.dual()) + if self.actionT.empty(): + self.action = Cofunction(trial_space.dual()) # For assembling action(f, self._x) self.bcs_action = [] @@ -170,7 +176,6 @@ def __init__(self, a, row_bcs=[], col_bcs=[], @cached_property def _diagonal(self): - from firedrake import Cofunction assert self.on_diag return Cofunction(self._x.function_space().dual()) diff --git a/firedrake/slate/slate.py b/firedrake/slate/slate.py index fd9535c31a..fd7411b72f 100644 --- a/firedrake/slate/slate.py +++ b/firedrake/slate/slate.py @@ -293,6 +293,10 @@ def solve(self, B, decomposition=None): """ return Solve(self, B, decomposition=decomposition) + def empty(self): + """Returns whether the form associated with the tensor is empty.""" + return False + @cached_property def blocks(self): """Returns an object containing the blocks of the tensor defined @@ -938,6 +942,10 @@ def subdomain_data(self): """ return self.form.subdomain_data() + def empty(self): + """Returns whether the form associated with the tensor is empty.""" + return self.form.empty() + def _output_string(self, prec=None): """Creates a string representation of the tensor.""" return ["S", "V", "M"][self.rank] + "_%d" % self.id