diff --git a/firedrake/bcs.py b/firedrake/bcs.py index 0b9b492bfc..311a5eb8bb 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -313,10 +313,12 @@ def function_arg(self): return self._function_arg @PETSc.Log.EventDecorator() - def reconstruct(self, field=None, V=None, g=None, sub_domain=None, use_split=False): + def reconstruct(self, field=None, V=None, g=None, sub_domain=None, use_split=False, indices=()): fs = self.function_space() if V is None: V = fs + for index in indices: + V = V.sub(index) if g is None: g = self._original_arg if sub_domain is None: diff --git a/firedrake/eigensolver.py b/firedrake/eigensolver.py index f33618f950..fab3766ae0 100644 --- a/firedrake/eigensolver.py +++ b/firedrake/eigensolver.py @@ -70,12 +70,12 @@ def __init__(self, A, M=None, bcs=None, bc_shift=0.0, restrict=True): M = inner(u, v) * dx if restrict and bcs: # assumed u and v are in the same space here - V_res = RestrictedFunctionSpace(self.output_space, boundary_set=set(bc.sub_domain for bc in bcs)) + V_res = RestrictedFunctionSpace(self.output_space, bcs) u_res = TrialFunction(V_res) v_res = TestFunction(V_res) self.M = replace(M, {u: u_res, v: v_res}) self.A = replace(A, {u: u_res, v: v_res}) - self.bcs = [bc.reconstruct(V=V_res) for bc in bcs] + self.bcs = [bc.reconstruct(V=V_res, indices=bc._indices) for bc in bcs] self.restricted_space = V_res else: self.A = A # LHS diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 9de5f048e8..d02b4dcfc0 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -995,22 +995,23 @@ def callable(): if len(V) == 1: loops.extend(_interpolator(V, tensor, expr, subset, arguments, access, bcs=bcs)) - else: - if len(arguments) > 0: - raise NotImplementedError("Cannot interpolate a mixed expression with %d arguments" % len(arguments)) + elif (hasattr(expr, "subfunctions") and len(expr.subfunctions) == len(V) + and all(sub_expr.ufl_shape == Vsub.value_shape for Vsub, sub_expr in zip(V, expr.subfunctions))): + for Vsub, sub_tensor, sub_expr in zip(V, tensor, expr.subfunctions): + loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs)) + elif len(arguments) == 0: + # Unflatten the expression into each of the mixed components offset = 0 - for Vsub, usub in zip(V, tensor): - shape = Vsub.value_shape - rank = len(shape) - components = [expr[offset + j] for j in range(Vsub.value_size)] - if rank == 0: - Vexpr = components[0] - elif rank == 1: - Vexpr = ufl.as_vector(components) + for Vsub, sub_tensor in zip(V, tensor): + if len(Vsub.value_shape) == 0: + sub_expr = expr[offset] else: - Vexpr = ufl.as_tensor(numpy.reshape(components, Vsub.value_shape).tolist()) - loops.extend(_interpolator(Vsub, usub, Vexpr, subset, arguments, access, bcs=bcs)) + components = [expr[offset + j] for j in range(Vsub.value_size)] + sub_expr = ufl.as_tensor(numpy.reshape(components, Vsub.value_shape)) + loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs)) offset += Vsub.value_size + else: + raise NotImplementedError("Cannot interpolate a mixed expression with %d arguments" % len(arguments)) if bcs and len(arguments) == 0: loops.extend(partial(bc.apply, f) for bc in bcs) diff --git a/firedrake/variational_solver.py b/firedrake/variational_solver.py index 65a7078d09..bdea03c2d8 100644 --- a/firedrake/variational_solver.py +++ b/firedrake/variational_solver.py @@ -22,12 +22,6 @@ "NonlinearVariationalSolver"] -def get_sub(V, indices): - for i in indices: - V = V.sub(i) - return V - - def check_pde_args(F, J, Jp): if not isinstance(F, (ufl.BaseForm, slate.slate.TensorBase)): raise TypeError("Provided residual is a '%s', not a BaseForm or Slate Tensor" % type(F).__name__) @@ -94,19 +88,17 @@ def __init__(self, F, u, bcs=None, J=None, self.restrict = restrict if restrict and bcs: - V_res = RestrictedFunctionSpace(V, boundary_set=bcs) - bcs = [bc.reconstruct(V=get_sub(V_res, bc._indices)) for bc in bcs] + V_res = RestrictedFunctionSpace(V, bcs) + bcs = [bc.reconstruct(V=V_res, indices=bc._indices) for bc in bcs] self.u_restrict = Function(V_res).interpolate(u) v_res, u_res = TestFunction(V_res), TrialFunction(V_res) F_arg, = F.arguments() - replace_dict = {F_arg: v_res} - replace_dict[self.u] = self.u_restrict - self.F = replace(F, replace_dict) + self.F = replace(F, {F_arg: v_res, self.u: self.u_restrict}) v_arg, u_arg = self.J.arguments() - self.J = replace(self.J, {v_arg: v_res, u_arg: u_res}) + self.J = replace(self.J, {v_arg: v_res, u_arg: u_res, self.u: self.u_restrict}) if self.Jp: v_arg, u_arg = self.Jp.arguments() - self.Jp = replace(self.Jp, {v_arg: v_res, u_arg: u_res}) + self.Jp = replace(self.Jp, {v_arg: v_res, u_arg: u_res, self.u: self.u_restrict}) self.restricted_space = V_res else: self.u_restrict = u