Skip to content

Commit

Permalink
ImplicitMatrixContext: handle empty action
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 3, 2025
1 parent 7f40504 commit 35638df
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 16 deletions.
13 changes: 8 additions & 5 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,8 @@ def __init__(self,
zero_bc_nodes=False,
diagonal=False,
weight=1.0,
allocation_integral_types=None):
allocation_integral_types=None,
needs_zeroing=False):
super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters)
self._mat_type = mat_type
self._sub_mat_type = sub_mat_type
Expand All @@ -321,6 +322,7 @@ def __init__(self,
self._diagonal = diagonal
self._weight = weight
self._allocation_integral_types = allocation_integral_types
assert not needs_zeroing

def allocate(self):
rank = len(self._form.arguments())
Expand Down Expand Up @@ -1127,7 +1129,8 @@ def _apply_bc(self, tensor, bc):
pass

def _check_tensor(self, tensor):
pass
if not isinstance(tensor, op2.Global):
raise TypeError(f"Expecting a op2.Global, got {tensor!r}.")

@staticmethod
def _as_pyop2_type(tensor, indices=None):
Expand All @@ -1143,7 +1146,7 @@ class OneFormAssembler(ParloopFormAssembler):
Parameters
----------
form : ufl.Form or slate.TensorBasehe
form : ufl.Form or slate.TensorBase
1-form.
Notes
Expand Down Expand Up @@ -1189,8 +1192,8 @@ def _apply_bc(self, tensor, bc):
self._apply_dirichlet_bc(tensor, bc)
elif isinstance(bc, EquationBCSplit):
bc.zero(tensor)
type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor)
get_assembler(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor)
else:
raise AssertionError

Expand Down
27 changes: 16 additions & 11 deletions firedrake/matrix_free/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from firedrake.bcs import DirichletBC, EquationBCSplit
from firedrake.petsc import PETSc
from firedrake.utils import cached_property
from firedrake.function import Function
from firedrake.cofunction import Cofunction


__all__ = ("ImplicitMatrixContext", )
Expand Down Expand Up @@ -107,23 +109,22 @@ def __init__(self, a, row_bcs=[], col_bcs=[],

# create functions from test and trial space to help
# with 1-form assembly
test_space, trial_space = [
a.arguments()[i].function_space() for i in (0, 1)
]
from firedrake import function, cofunction
test_space, trial_space = (
arg.function_space() for arg in a.arguments()
)
# Need a cofunction since y receives the assembled result of Ax
self._ystar = cofunction.Cofunction(test_space.dual())
self._y = function.Function(test_space)
self._x = function.Function(trial_space)
self._xstar = cofunction.Cofunction(trial_space.dual())
self._ystar = Cofunction(test_space.dual())
self._y = Function(test_space)
self._x = Function(trial_space)
self._xstar = Cofunction(trial_space.dual())

# These are temporary storage for holding the BC
# values during matvec application. _xbc is for
# the action and ._ybc is for transpose.
if len(self.bcs) > 0:
self._xbc = cofunction.Cofunction(trial_space.dual())
self._xbc = Cofunction(trial_space.dual())
if len(self.col_bcs) > 0:
self._ybc = cofunction.Cofunction(test_space.dual())
self._ybc = Cofunction(test_space.dual())

# Get size information from template vecs on test and trial spaces
trial_vec = trial_space.dof_dset.layout_vec
Expand All @@ -135,6 +136,11 @@ def __init__(self, a, row_bcs=[], col_bcs=[],

self.action = action(self.a, self._x)
self.actionT = action(self.aT, self._y)
# TODO prevent action from returning empty Forms
if self.action.empty():
self.action = Cofunction(test_space.dual())
if self.actionT.empty():
self.action = Cofunction(trial_space.dual())

# For assembling action(f, self._x)
self.bcs_action = []
Expand Down Expand Up @@ -170,7 +176,6 @@ def __init__(self, a, row_bcs=[], col_bcs=[],

@cached_property
def _diagonal(self):
from firedrake import Cofunction
assert self.on_diag
return Cofunction(self._x.function_space().dual())

Expand Down
8 changes: 8 additions & 0 deletions firedrake/slate/slate.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,10 @@ def solve(self, B, decomposition=None):
"""
return Solve(self, B, decomposition=decomposition)

def empty(self):
"""Returns whether the form associated with the tensor is empty."""
return False

@cached_property
def blocks(self):
"""Returns an object containing the blocks of the tensor defined
Expand Down Expand Up @@ -938,6 +942,10 @@ def subdomain_data(self):
"""
return self.form.subdomain_data()

def empty(self):
"""Returns whether the form associated with the tensor is empty."""
return self.form.empty()

def _output_string(self, prec=None):
"""Creates a string representation of the tensor."""
return ["S", "V", "M"][self.rank] + "_%d" % self.id
Expand Down

0 comments on commit 35638df

Please sign in to comment.