diff --git a/firedrake/bcs.py b/firedrake/bcs.py index e6b60a0e4c..1fc6aa5a6a 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -23,7 +23,7 @@ from firedrake.adjoint_utils.dirichletbc import DirichletBCMixin from firedrake.petsc import PETSc -__all__ = ['DirichletBC', 'homogenize', 'EquationBC', 'restricted_function_space'] +__all__ = ['DirichletBC', 'homogenize', 'EquationBC'] class BCBase(object): @@ -692,36 +692,57 @@ def homogenize(bc): raise TypeError("homogenize only takes a DirichletBC or a list/tuple of DirichletBCs") -@PETSc.Log.EventDecorator("CreateFunctionSpace") -def restricted_function_space(V, bcs, name=None): - """Create a :class:`.RestrictedFunctionSpace` from a list of boundary conditions. +def extract_subdomain_ids(bcs): + """Return a tuple of subdomain ids for each component of a MixedFunctionSpace. Parameters ---------- - V : - FunctionSpace object to restrict bcs : A list of boundary conditions. - name : - An optional name for the function space. + + Returns + ------- + A tuple of subdomain ids for each component of a MixedFunctionSpace. """ - if len(V) > 1: - spaces = [restricted_function_space(Vsub, bcs) for Vsub in V] - return firedrake.MixedFunctionSpace(spaces, name=name) - - if not isinstance(bcs, (tuple, list)): - bcs = (bcs,) - - boundary_set = [] - for bc in bcs: - if bc.function_space() != V: - continue - for dbc in bc.dirichlet_bcs(): - if isinstance(dbc.sub_domain, (str, int)): - boundary_set.append(dbc.sub_domain) - else: - boundary_set.extend(dbc.sub_domain) - if len(boundary_set) == 0: + if len(bcs) == 0: + return None + + V = bcs[0].function_space() + while V.parent: + V = V.parent + + _chain = itertools.chain.from_iterable + subdomain_ids = tuple(tuple(_chain(as_tuple(bc.sub_domain) + for bc in bcs if bc.function_space() == Vsub)) + for Vsub in V) + return subdomain_ids + + +def restricted_function_space(V, ids): + """Create a :class:`.RestrictedFunctionSpace` from a tuple of subdomain ids. + + Parameters + ---------- + V : + FunctionSpace object to restrict + ids : + A tuple of subdomain ids. + + Returns + ------- + The RestrictedFunctionSpace. + + """ + if not ids: return V - return firedrake.RestrictedFunctionSpace(V, boundary_set=boundary_set, name=name) + + assert len(ids) == len(V) + spaces = [Vsub if len(boundary_set) == 0 else + firedrake.RestrictedFunctionSpace(Vsub, boundary_set=boundary_set) + for Vsub, boundary_set in zip(V, ids)] + + if len(spaces) == 1: + return spaces[0] + else: + return firedrake.MixedFunctionSpace(spaces) diff --git a/firedrake/eigensolver.py b/firedrake/eigensolver.py index 10dc58abad..ad05815527 100644 --- a/firedrake/eigensolver.py +++ b/firedrake/eigensolver.py @@ -1,6 +1,6 @@ """Specify and solve finite element eigenproblems.""" from firedrake.assemble import assemble -from firedrake.bcs import restricted_function_space +from firedrake.bcs import extract_subdomain_ids, restricted_function_space from firedrake.function import Function from firedrake.ufl_expr import TrialFunction, TestFunction from firedrake import utils @@ -70,7 +70,7 @@ def __init__(self, A, M=None, bcs=None, bc_shift=0.0, restrict=True): M = inner(u, v) * dx if restrict and bcs: # assumed u and v are in the same space here - V_res = restricted_function_space(self.output_space, bcs) + V_res = restricted_function_space(self.output_space, extract_subdomain_ids(bcs)) u_res = TrialFunction(V_res) v_res = TestFunction(V_res) self.M = replace(M, {u: u_res, v: v_res}) diff --git a/firedrake/variational_solver.py b/firedrake/variational_solver.py index 89a80ab593..4a1ac396c5 100644 --- a/firedrake/variational_solver.py +++ b/firedrake/variational_solver.py @@ -11,7 +11,7 @@ ) from firedrake.function import Function from firedrake.ufl_expr import TrialFunction, TestFunction -from firedrake.bcs import DirichletBC, EquationBC, restricted_function_space +from firedrake.bcs import DirichletBC, EquationBC, extract_subdomain_ids, restricted_function_space from firedrake.adjoint_utils import NonlinearVariationalProblemMixin, NonlinearVariationalSolverMixin from ufl import replace @@ -87,7 +87,7 @@ def __init__(self, F, u, bcs=None, J=None, self.restrict = restrict if restrict and bcs: - V_res = restricted_function_space(V, bcs) + V_res = restricted_function_space(V, extract_subdomain_ids(bcs)) bcs = [bc.reconstruct(V=V_res, indices=bc._indices) for bc in bcs] self.u_restrict = Function(V_res).interpolate(u) v_res, u_res = TestFunction(V_res), TrialFunction(V_res) diff --git a/tests/regression/test_restricted_function_space.py b/tests/regression/test_restricted_function_space.py index ae3ea28ed4..51e2139ae6 100644 --- a/tests/regression/test_restricted_function_space.py +++ b/tests/regression/test_restricted_function_space.py @@ -173,16 +173,6 @@ def test_restricted_function_space_coord_change(j): compare_function_space_assembly(new_V, new_V_restricted, [bc]) -def test_restricted_mixed_space(): - mesh = UnitSquareMesh(1, 1) - V = FunctionSpace(mesh, "RT", 1) - Q = FunctionSpace(mesh, "DG", 0) - Z = V * Q - bcs = [DirichletBC(Z.sub(0), 0, [1])] - Z_restricted = restricted_function_space(Z, bcs) - compare_function_space_assembly(Z, Z_restricted, bcs) - - def test_poisson_restricted_mixed_space(): mesh = UnitSquareMesh(1, 1) V = FunctionSpace(mesh, "RT", 1)