Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Nov 19, 2024
1 parent 562852f commit d6ef63a
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 29 deletions.
4 changes: 3 additions & 1 deletion firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,10 +313,12 @@ def function_arg(self):
return self._function_arg

@PETSc.Log.EventDecorator()
def reconstruct(self, field=None, V=None, g=None, sub_domain=None, use_split=False):
def reconstruct(self, field=None, V=None, g=None, sub_domain=None, use_split=False, indices=()):
fs = self.function_space()
if V is None:
V = fs
for index in indices:
V = V.sub(index)
if g is None:
g = self._original_arg
if sub_domain is None:
Expand Down
4 changes: 2 additions & 2 deletions firedrake/eigensolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ def __init__(self, A, M=None, bcs=None, bc_shift=0.0, restrict=True):
M = inner(u, v) * dx

if restrict and bcs: # assumed u and v are in the same space here
V_res = RestrictedFunctionSpace(self.output_space, boundary_set=set(bc.sub_domain for bc in bcs))
V_res = RestrictedFunctionSpace(self.output_space, bcs)
u_res = TrialFunction(V_res)
v_res = TestFunction(V_res)
self.M = replace(M, {u: u_res, v: v_res})
self.A = replace(A, {u: u_res, v: v_res})
self.bcs = [bc.reconstruct(V=V_res) for bc in bcs]
self.bcs = [bc.reconstruct(V=V_res, indices=bc._indices) for bc in bcs]
self.restricted_space = V_res
else:
self.A = A # LHS
Expand Down
27 changes: 14 additions & 13 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,22 +995,23 @@ def callable():

if len(V) == 1:
loops.extend(_interpolator(V, tensor, expr, subset, arguments, access, bcs=bcs))
else:
if len(arguments) > 0:
raise NotImplementedError("Cannot interpolate a mixed expression with %d arguments" % len(arguments))
elif (hasattr(expr, "subfunctions") and len(expr.subfunctions) == len(V)
and all(sub_expr.ufl_shape == Vsub.value_shape for Vsub, sub_expr in zip(V, expr.subfunctions))):
for Vsub, sub_tensor, sub_expr in zip(V, tensor, expr.subfunctions):
loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs))
elif len(arguments) == 0:
# Unflatten the expression into each of the mixed components
offset = 0
for Vsub, usub in zip(V, tensor):
shape = Vsub.value_shape
rank = len(shape)
components = [expr[offset + j] for j in range(Vsub.value_size)]
if rank == 0:
Vexpr = components[0]
elif rank == 1:
Vexpr = ufl.as_vector(components)
for Vsub, sub_tensor in zip(V, tensor):
if len(Vsub.value_shape) == 0:
sub_expr = expr[offset]
else:
Vexpr = ufl.as_tensor(numpy.reshape(components, Vsub.value_shape).tolist())
loops.extend(_interpolator(Vsub, usub, Vexpr, subset, arguments, access, bcs=bcs))
components = [expr[offset + j] for j in range(Vsub.value_size)]
sub_expr = ufl.as_tensor(numpy.reshape(components, Vsub.value_shape))
loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs))
offset += Vsub.value_size
else:
raise NotImplementedError("Cannot interpolate a mixed expression with %d arguments" % len(arguments))

if bcs and len(arguments) == 0:
loops.extend(partial(bc.apply, f) for bc in bcs)
Expand Down
18 changes: 5 additions & 13 deletions firedrake/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,6 @@
"NonlinearVariationalSolver"]


def get_sub(V, indices):
for i in indices:
V = V.sub(i)
return V


def check_pde_args(F, J, Jp):
if not isinstance(F, (ufl.BaseForm, slate.slate.TensorBase)):
raise TypeError("Provided residual is a '%s', not a BaseForm or Slate Tensor" % type(F).__name__)
Expand Down Expand Up @@ -94,19 +88,17 @@ def __init__(self, F, u, bcs=None, J=None,
self.restrict = restrict

if restrict and bcs:
V_res = RestrictedFunctionSpace(V, boundary_set=bcs)
bcs = [bc.reconstruct(V=get_sub(V_res, bc._indices)) for bc in bcs]
V_res = RestrictedFunctionSpace(V, bcs)
bcs = [bc.reconstruct(V=V_res, indices=bc._indices) for bc in bcs]
self.u_restrict = Function(V_res).interpolate(u)
v_res, u_res = TestFunction(V_res), TrialFunction(V_res)
F_arg, = F.arguments()
replace_dict = {F_arg: v_res}
replace_dict[self.u] = self.u_restrict
self.F = replace(F, replace_dict)
self.F = replace(F, {F_arg: v_res, self.u: self.u_restrict})
v_arg, u_arg = self.J.arguments()
self.J = replace(self.J, {v_arg: v_res, u_arg: u_res})
self.J = replace(self.J, {v_arg: v_res, u_arg: u_res, self.u: self.u_restrict})
if self.Jp:
v_arg, u_arg = self.Jp.arguments()
self.Jp = replace(self.Jp, {v_arg: v_res, u_arg: u_res})
self.Jp = replace(self.Jp, {v_arg: v_res, u_arg: u_res, self.u: self.u_restrict})
self.restricted_space = V_res
else:
self.u_restrict = u
Expand Down

0 comments on commit d6ef63a

Please sign in to comment.