Skip to content

Commit

Permalink
TransferManager: allow partial custom transfers (#3796)
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck authored Oct 16, 2024
1 parent cb93d10 commit e85bdb2
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
14 changes: 6 additions & 8 deletions firedrake/mg/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
import ufl
import finat.ufl
import weakref
from functools import reduce
from enum import IntEnum
from operator import and_
from firedrake.petsc import PETSc
from firedrake.embedding import get_embedding_dg_element

Expand Down Expand Up @@ -69,18 +67,18 @@ def __init__(self, *, native_transfers=None, use_averaging=True):
self.use_averaging = use_averaging
self.caches = {}

def is_native(self, element):
def is_native(self, element, op):
if element in self.native_transfers.keys():
return True
return self.native_transfers[element][op] is not None
if isinstance(element.cell, ufl.TensorProductCell) and len(element.sub_elements) > 0:
return reduce(and_, map(self.is_native, element.sub_elements))
return all(self.is_native(e, op) for e in element.sub_elements)
return (element.family() in native_families) and not (element.variant() in non_native_variants)

def _native_transfer(self, element, op):
try:
return self.native_transfers[element][op]
except KeyError:
if self.is_native(element):
if self.is_native(element, op):
ops = firedrake.prolong, firedrake.restrict, firedrake.inject
return self.native_transfers.setdefault(element, ops)[op]
return None
Expand Down Expand Up @@ -248,7 +246,7 @@ def op(self, source, target, transfer_op):
if not self.requires_transfer(source_element, transfer_op, source, target):
return

if self.is_native(source_element) and self.is_native(target_element):
if all(self.is_native(e, transfer_op) for e in (source_element, target_element)):
self._native_transfer(source_element, transfer_op)(source, target)
elif type(source_element) is finat.ufl.MixedElement:
assert type(target_element) is finat.ufl.MixedElement
Expand Down Expand Up @@ -313,7 +311,7 @@ def restrict(self, source, target):
if not self.requires_transfer(source_element, Op.RESTRICT, source, target):
return

if self.is_native(source_element) and self.is_native(target_element):
if all(self.is_native(e, Op.RESTRICT) for e in (source_element, target_element)):
self._native_transfer(source_element, Op.RESTRICT)(source, target)
elif type(source_element) is finat.ufl.MixedElement:
assert type(target_element) is finat.ufl.MixedElement
Expand Down
10 changes: 8 additions & 2 deletions tests/multigrid/test_custom_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ def prolong_Q(fine, coarse):
assert count_Q == -2


def test_custom_transfer_setting():
@pytest.mark.parametrize("mode", ("full", pytest.param("partial", marks=pytest.mark.skipcomplexnoslate)))
def test_custom_transfer_setting(mode):
mesh = UnitIntervalMesh(2)
mh = MeshHierarchy(mesh, 1)
mesh = mh[-1]
Expand All @@ -194,7 +195,12 @@ def myprolong(coarse, fine):
options = {"ksp_type": "preonly",
"pc_type": "mg"}

transfer = TransferManager(native_transfers={V.ufl_element(): (myprolong, restrict, inject)})
if mode == "partial":
transfer_ops = (myprolong, None, None)
else:
transfer_ops = (myprolong, restrict, inject)

transfer = TransferManager(native_transfers={V.ufl_element(): transfer_ops})
problem = LinearVariationalProblem(a, L, uh)
solver = LinearVariationalSolver(problem, solver_parameters=options)
solver.set_transfer_manager(transfer)
Expand Down

0 comments on commit e85bdb2

Please sign in to comment.