diff --git a/firedrake/mg/embedded.py b/firedrake/mg/embedded.py index 1ea27a6663..752ec6ff13 100644 --- a/firedrake/mg/embedded.py +++ b/firedrake/mg/embedded.py @@ -2,9 +2,7 @@ import ufl import finat.ufl import weakref -from functools import reduce from enum import IntEnum -from operator import and_ from firedrake.petsc import PETSc from firedrake.embedding import get_embedding_dg_element @@ -69,18 +67,18 @@ def __init__(self, *, native_transfers=None, use_averaging=True): self.use_averaging = use_averaging self.caches = {} - def is_native(self, element): + def is_native(self, element, op): if element in self.native_transfers.keys(): - return True + return self.native_transfers[element][op] is not None if isinstance(element.cell, ufl.TensorProductCell) and len(element.sub_elements) > 0: - return reduce(and_, map(self.is_native, element.sub_elements)) + return all(self.is_native(e, op) for e in element.sub_elements) return (element.family() in native_families) and not (element.variant() in non_native_variants) def _native_transfer(self, element, op): try: return self.native_transfers[element][op] except KeyError: - if self.is_native(element): + if self.is_native(element, op): ops = firedrake.prolong, firedrake.restrict, firedrake.inject return self.native_transfers.setdefault(element, ops)[op] return None @@ -248,7 +246,7 @@ def op(self, source, target, transfer_op): if not self.requires_transfer(source_element, transfer_op, source, target): return - if self.is_native(source_element) and self.is_native(target_element): + if all(self.is_native(e, transfer_op) for e in (source_element, target_element)): self._native_transfer(source_element, transfer_op)(source, target) elif type(source_element) is finat.ufl.MixedElement: assert type(target_element) is finat.ufl.MixedElement @@ -313,7 +311,7 @@ def restrict(self, source, target): if not self.requires_transfer(source_element, Op.RESTRICT, source, target): return - if self.is_native(source_element) and self.is_native(target_element): + if all(self.is_native(e, Op.RESTRICT) for e in (source_element, target_element)): self._native_transfer(source_element, Op.RESTRICT)(source, target) elif type(source_element) is finat.ufl.MixedElement: assert type(target_element) is finat.ufl.MixedElement diff --git a/tests/multigrid/test_custom_transfer.py b/tests/multigrid/test_custom_transfer.py index ecb4afc93f..b9760ac7c5 100644 --- a/tests/multigrid/test_custom_transfer.py +++ b/tests/multigrid/test_custom_transfer.py @@ -172,7 +172,8 @@ def prolong_Q(fine, coarse): assert count_Q == -2 -def test_custom_transfer_setting(): +@pytest.mark.parametrize("mode", ("full", pytest.param("partial", marks=pytest.mark.skipcomplexnoslate))) +def test_custom_transfer_setting(mode): mesh = UnitIntervalMesh(2) mh = MeshHierarchy(mesh, 1) mesh = mh[-1] @@ -194,7 +195,12 @@ def myprolong(coarse, fine): options = {"ksp_type": "preonly", "pc_type": "mg"} - transfer = TransferManager(native_transfers={V.ufl_element(): (myprolong, restrict, inject)}) + if mode == "partial": + transfer_ops = (myprolong, None, None) + else: + transfer_ops = (myprolong, restrict, inject) + + transfer = TransferManager(native_transfers={V.ufl_element(): transfer_ops}) problem = LinearVariationalProblem(a, L, uh) solver = LinearVariationalSolver(problem, solver_parameters=options) solver.set_transfer_manager(transfer)