Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MixedFunctionSpace: interpolate and restrict #3868

Merged
merged 14 commits into from
Dec 6, 2024
10 changes: 7 additions & 3 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,10 @@ def __init__(self, V, g, sub_domain, method=None):
warnings.simplefilter('always', DeprecationWarning)
warnings.warn("Selecting a bcs method is deprecated. Only topological association is supported",
DeprecationWarning)
if len(V.boundary_set) and sub_domain not in V.boundary_set:
raise ValueError(f"Sub-domain {sub_domain} not in the boundary set of the restricted space.")
if len(V.boundary_set):
subs = [sub_domain] if type(sub_domain) in {int, str} else sub_domain
if any(sub not in V.boundary_set for sub in subs):
raise ValueError(f"Sub-domain {sub_domain} not in the boundary set of the restricted space.")
super().__init__(V, sub_domain)
if len(V) > 1:
raise ValueError("Cannot apply boundary conditions on mixed spaces directly.\n"
Expand All @@ -311,10 +313,12 @@ def function_arg(self):
return self._function_arg

@PETSc.Log.EventDecorator()
def reconstruct(self, field=None, V=None, g=None, sub_domain=None, use_split=False):
def reconstruct(self, field=None, V=None, g=None, sub_domain=None, use_split=False, indices=()):
fs = self.function_space()
if V is None:
V = fs
for index in indices:
V = V.sub(index)
if g is None:
g = self._original_arg
if sub_domain is None:
Expand Down
5 changes: 2 additions & 3 deletions firedrake/eigensolver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Specify and solve finite element eigenproblems."""
from firedrake.assemble import assemble
from firedrake.bcs import DirichletBC
from firedrake.function import Function
from firedrake.functionspace import RestrictedFunctionSpace
from firedrake.ufl_expr import TrialFunction, TestFunction
Expand Down Expand Up @@ -71,12 +70,12 @@ def __init__(self, A, M=None, bcs=None, bc_shift=0.0, restrict=True):
M = inner(u, v) * dx

if restrict and bcs: # assumed u and v are in the same space here
V_res = RestrictedFunctionSpace(self.output_space, boundary_set=set([bc.sub_domain for bc in bcs]))
V_res = RestrictedFunctionSpace(self.output_space, bcs)
u_res = TrialFunction(V_res)
v_res = TestFunction(V_res)
self.M = replace(M, {u: u_res, v: v_res})
self.A = replace(A, {u: u_res, v: v_res})
self.bcs = [DirichletBC(V_res, bc.function_arg, bc.sub_domain) for bc in bcs]
self.bcs = [bc.reconstruct(V=V_res, indices=bc._indices) for bc in bcs]
self.restricted_space = V_res
else:
self.A = A # LHS
Expand Down
25 changes: 20 additions & 5 deletions firedrake/functionspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,20 +308,35 @@ def rec(eles):


@PETSc.Log.EventDecorator("CreateFunctionSpace")
def RestrictedFunctionSpace(function_space, name=None, boundary_set=[]):
def RestrictedFunctionSpace(function_space, boundary_set=[], name=None):
"""Create a :class:`.RestrictedFunctionSpace`.

Parameters
----------
function_space :
FunctionSpace object to restrict
name :
An optional name for the function space.
boundary_set :
A set of subdomains of the mesh in which Dirichlet boundary conditions
will be applied.
name :
An optional name for the function space.

"""
return impl.WithGeometry.create(impl.RestrictedFunctionSpace(function_space, name=name,
boundary_set=boundary_set),
if len(function_space) > 1:
return MixedFunctionSpace([RestrictedFunctionSpace(Vsub, boundary_set=boundary_set)
for Vsub in function_space], name=name)

if len(boundary_set) > 0 and all(hasattr(bc, "sub_domain") for bc in boundary_set):
bcs = boundary_set
boundary_set = []
for bc in bcs:
if bc.function_space() == function_space:
if type(bc.sub_domain) in {str, int}:
boundary_set.append(bc.sub_domain)
else:
boundary_set.extend(bc.sub_domain)
dham marked this conversation as resolved.
Show resolved Hide resolved

return impl.WithGeometry.create(impl.RestrictedFunctionSpace(function_space,
boundary_set=boundary_set,
name=name),
function_space.mesh())
8 changes: 3 additions & 5 deletions firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,17 +860,16 @@ class RestrictedFunctionSpace(FunctionSpace):
output of the solver.

:arg function_space: The :class:`FunctionSpace` to restrict.
:kwarg boundary_set: A set of subdomains on which a DirichletBC will be applied.
:kwarg name: An optional name for this :class:`RestrictedFunctionSpace`,
useful for later identification.
:kwarg boundary_set: A set of subdomains on which a DirichletBC will be
applied.

Notes
-----
If using this class to solve or similar, a list of DirichletBCs will still
need to be specified on this space and passed into the function.
"""
def __init__(self, function_space, name=None, boundary_set=frozenset()):
def __init__(self, function_space, boundary_set=frozenset(), name=None):
label = ""
for boundary_domain in boundary_set:
label += str(boundary_domain)
Expand All @@ -884,8 +883,7 @@ def __init__(self, function_space, name=None, boundary_set=frozenset()):
label=self._label)
self.function_space = function_space
self.name = name or (function_space.name or "Restricted" + "_"
+ "_".join(sorted(
[str(i) for i in self.boundary_set])))
+ "_".join(sorted(map(str, self.boundary_set))))

def set_shared_data(self):
sdata = get_shared_data(self._mesh, self.ufl_element(), self.boundary_set)
Expand Down
25 changes: 21 additions & 4 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,10 +992,27 @@ def callable():
if numpy.prod(expr.ufl_shape, dtype=int) != V.value_size:
raise RuntimeError('Expression of length %d required, got length %d'
% (V.value_size, numpy.prod(expr.ufl_shape, dtype=int)))
if len(V) > 1:
raise NotImplementedError(
"UFL expressions for mixed functions are not yet supported.")
loops.extend(_interpolator(V, tensor, expr, subset, arguments, access, bcs=bcs))

if len(V) == 1:
loops.extend(_interpolator(V, tensor, expr, subset, arguments, access, bcs=bcs))
elif (hasattr(expr, "subfunctions") and len(expr.subfunctions) == len(V)
and all(sub_expr.ufl_shape == Vsub.value_shape for Vsub, sub_expr in zip(V, expr.subfunctions))):
for Vsub, sub_tensor, sub_expr in zip(V, tensor, expr.subfunctions):
loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs))
elif len(arguments) == 0:
# Unflatten the expression into each of the mixed components
offset = 0
for Vsub, sub_tensor in zip(V, tensor):
if len(Vsub.value_shape) == 0:
sub_expr = expr[offset]
else:
components = [expr[offset + j] for j in range(Vsub.value_size)]
sub_expr = ufl.as_tensor(numpy.reshape(components, Vsub.value_shape))
loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs))
offset += Vsub.value_size
else:
raise NotImplementedError("Cannot interpolate a mixed expression with %d arguments" % len(arguments))

if bcs and len(arguments) == 0:
loops.extend(partial(bc.apply, f) for bc in bcs)

Expand Down
11 changes: 5 additions & 6 deletions firedrake/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,19 +88,18 @@ def __init__(self, F, u, bcs=None, J=None,
self.restrict = restrict

if restrict and bcs:
V_res = RestrictedFunctionSpace(V, boundary_set=set([bc.sub_domain for bc in bcs]))
bcs = [DirichletBC(V_res, bc.function_arg, bc.sub_domain) for bc in bcs]
V_res = RestrictedFunctionSpace(V, bcs)
bcs = [bc.reconstruct(V=V_res, indices=bc._indices) for bc in bcs]
self.u_restrict = Function(V_res).interpolate(u)
v_res, u_res = TestFunction(V_res), TrialFunction(V_res)
F_arg, = F.arguments()
replace_dict = {F_arg: v_res}
replace_dict[self.u] = self.u_restrict
replace_dict = {F_arg: v_res, self.u: self.u_restrict}
self.F = replace(F, replace_dict)
v_arg, u_arg = self.J.arguments()
self.J = replace(self.J, {v_arg: v_res, u_arg: u_res})
self.J = replace(self.J, {v_arg: v_res, u_arg: u_res, self.u: self.u_restrict})
if self.Jp:
v_arg, u_arg = self.Jp.arguments()
self.Jp = replace(self.Jp, {v_arg: v_res, u_arg: u_res})
self.Jp = replace(self.Jp, {v_arg: v_res, u_arg: u_res, self.u: self.u_restrict})
self.restricted_space = V_res
else:
self.u_restrict = u
Expand Down
43 changes: 40 additions & 3 deletions tests/regression/test_restricted_function_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,17 +168,54 @@ def test_restricted_function_space_coord_change(j):
new_mesh = Mesh(Function(V).interpolate(as_vector([x, y])))
new_V = FunctionSpace(new_mesh, "CG", j)
bc = DirichletBC(new_V, 0, 1)
new_V_restricted = RestrictedFunctionSpace(new_V, name="Restricted", boundary_set=[1])
new_V_restricted = RestrictedFunctionSpace(new_V, boundary_set=[1], name="Restricted")

compare_function_space_assembly(new_V, new_V_restricted, [bc])


def test_restricted_mixed_space():
mesh = UnitSquareMesh(1, 1)
V = FunctionSpace(mesh, "RT", 1)
Q = FunctionSpace(mesh, "DG", 0)
Z = V * Q
bcs = [DirichletBC(Z.sub(0), 0, [1])]
Z_restricted = RestrictedFunctionSpace(Z, bcs)
compare_function_space_assembly(Z, Z_restricted, bcs)


def test_poisson_restricted_mixed_space():
mesh = UnitSquareMesh(1, 1)
V = FunctionSpace(mesh, "RT", 1)
Q = FunctionSpace(mesh, "DG", 0)
Z = V*Q

u, p = TrialFunctions(Z)
v, q = TestFunctions(Z)
a = inner(u, v)*dx + inner(p, div(v))*dx + inner(div(u), q)*dx
L = inner(1, q)*dx

bcs = [DirichletBC(Z.sub(0), 0, [1])]

w = Function(Z)
problem = LinearVariationalProblem(a, L, w, bcs=bcs, restrict=False)
solver = LinearVariationalSolver(problem)
solver.solve()

w2 = Function(Z)
problem = LinearVariationalProblem(a, L, w2, bcs=bcs, restrict=True)
solver = LinearVariationalSolver(problem)
solver.solve()

assert errornorm(w.subfunctions[0], w2.subfunctions[0]) < 1.e-12
assert errornorm(w.subfunctions[1], w2.subfunctions[1]) < 1.e-12


@pytest.mark.parametrize(["i", "j"], [(1, 0), (2, 0), (2, 1)])
def test_restricted_mixed_spaces(i, j):
def test_poisson_mixed_restricted_spaces(i, j):
mesh = UnitSquareMesh(1, 1)
DG = FunctionSpace(mesh, "DG", j)
CG = VectorFunctionSpace(mesh, "CG", i)
CG_res = RestrictedFunctionSpace(CG, "Restricted", boundary_set=[4])
CG_res = RestrictedFunctionSpace(CG, boundary_set=[4], name="Restricted")
W = CG * DG
W_res = CG_res * DG
bc = DirichletBC(W.sub(0), 0, 4)
Expand Down
Loading