diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 60c934b6c7..88d00c6db8 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -311,7 +311,8 @@ def __init__(self, zero_bc_nodes=False, diagonal=False, weight=1.0, - allocation_integral_types=None): + allocation_integral_types=None, + needs_zeroing=False): super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters) self._mat_type = mat_type self._sub_mat_type = sub_mat_type @@ -321,6 +322,7 @@ def __init__(self, self._diagonal = diagonal self._weight = weight self._allocation_integral_types = allocation_integral_types + assert not needs_zeroing def allocate(self): rank = len(self._form.arguments()) @@ -1127,7 +1129,8 @@ def _apply_bc(self, tensor, bc): pass def _check_tensor(self, tensor): - pass + if not isinstance(tensor, op2.Global): + raise TypeError(f"Expecting a op2.Global, got {tensor!r}.") @staticmethod def _as_pyop2_type(tensor, indices=None): @@ -1143,7 +1146,7 @@ class OneFormAssembler(ParloopFormAssembler): Parameters ---------- - form : ufl.Form or slate.TensorBasehe + form : ufl.Form or slate.TensorBase 1-form. Notes @@ -1189,8 +1192,8 @@ def _apply_bc(self, tensor, bc): self._apply_dirichlet_bc(tensor, bc) elif isinstance(bc, EquationBCSplit): bc.zero(tensor) - type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, - zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) + get_assembler(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, + zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) else: raise AssertionError diff --git a/firedrake/matrix_free/operators.py b/firedrake/matrix_free/operators.py index 3ee448730e..29abca838a 100644 --- a/firedrake/matrix_free/operators.py +++ b/firedrake/matrix_free/operators.py @@ -10,6 +10,8 @@ from firedrake.bcs import DirichletBC, EquationBCSplit from firedrake.petsc import PETSc from firedrake.utils import cached_property +from firedrake.function import Function +from firedrake.cofunction import Cofunction __all__ = ("ImplicitMatrixContext", ) @@ -107,23 +109,22 @@ def __init__(self, a, row_bcs=[], col_bcs=[], # create functions from test and trial space to help # with 1-form assembly - test_space, trial_space = [ - a.arguments()[i].function_space() for i in (0, 1) - ] - from firedrake import function, cofunction + test_space, trial_space = ( + arg.function_space() for arg in a.arguments() + ) # Need a cofunction since y receives the assembled result of Ax - self._ystar = cofunction.Cofunction(test_space.dual()) - self._y = function.Function(test_space) - self._x = function.Function(trial_space) - self._xstar = cofunction.Cofunction(trial_space.dual()) + self._ystar = Cofunction(test_space.dual()) + self._y = Function(test_space) + self._x = Function(trial_space) + self._xstar = Cofunction(trial_space.dual()) # These are temporary storage for holding the BC # values during matvec application. _xbc is for # the action and ._ybc is for transpose. if len(self.bcs) > 0: - self._xbc = cofunction.Cofunction(trial_space.dual()) + self._xbc = Cofunction(trial_space.dual()) if len(self.col_bcs) > 0: - self._ybc = cofunction.Cofunction(test_space.dual()) + self._ybc = Cofunction(test_space.dual()) # Get size information from template vecs on test and trial spaces trial_vec = trial_space.dof_dset.layout_vec @@ -135,6 +136,11 @@ def __init__(self, a, row_bcs=[], col_bcs=[], self.action = action(self.a, self._x) self.actionT = action(self.aT, self._y) + # TODO prevent action from returning empty Forms + if self.action.empty(): + self.action = Cofunction(test_space.dual()) + if self.actionT.empty(): + self.action = Cofunction(trial_space.dual()) # For assembling action(f, self._x) self.bcs_action = [] @@ -170,7 +176,6 @@ def __init__(self, a, row_bcs=[], col_bcs=[], @cached_property def _diagonal(self): - from firedrake import Cofunction assert self.on_diag return Cofunction(self._x.function_space().dual())