Skip to content

Commit

Permalink
Fix preconditioner when using EquationBC (#3842)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Aaron Baier-Reinio <baierreinio@maths.ox.ac.uk>
  • Loading branch information
ksagiyam and ABaierReinio authored Nov 6, 2024
1 parent 35ba765 commit 38fd784
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 5 deletions.
24 changes: 22 additions & 2 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,9 @@ def extract_form(self, form_type):
# DirichletBC is directly used in assembly.
return self

def _as_nonlinear_variational_problem_arg(self):
return self


class EquationBC(object):
r'''Construct and store EquationBCSplit objects (for `F`, `J`, and `Jp`).
Expand Down Expand Up @@ -549,12 +552,15 @@ def extract_form(self, form_type):
return getattr(self, f"_{form_type}")

@PETSc.Log.EventDecorator()
def reconstruct(self, V, subu, u, field):
def reconstruct(self, V, subu, u, field, is_linear):
_F = self._F.reconstruct(field=field, V=V, subu=subu, u=u)
_J = self._J.reconstruct(field=field, V=V, subu=subu, u=u)
_Jp = self._Jp.reconstruct(field=field, V=V, subu=subu, u=u)
if all([_F is not None, _J is not None, _Jp is not None]):
return EquationBC(_F, _J, _Jp, Jp_eq_J=self.Jp_eq_J, is_linear=self.is_linear)
return EquationBC(_F, _J, _Jp, Jp_eq_J=self.Jp_eq_J, is_linear=is_linear)

def _as_nonlinear_variational_problem_arg(self):
return self


class EquationBCSplit(BCBase):
Expand Down Expand Up @@ -645,6 +651,20 @@ def reconstruct(self, field=None, V=None, subu=None, u=None, row_field=None, col
ebc.add(bc_temp)
return ebc

def _as_nonlinear_variational_problem_arg(self):
# NonlinearVariationalProblem expects EquationBC, not EquationBCSplit.
# -- This method is required when NonlinearVariationalProblem is constructed inside PC.
if len(self.f.arguments()) != 2:
raise NotImplementedError(f"Not expecting a form of rank {len(self.f.arguments())} (!= 2)")
J = self.f
Vcol = J.arguments()[-1].function_space()
u = firedrake.Function(Vcol)
F = ufl_expr.action(J, u)
Vrow = self._function_space
sub_domain = self.sub_domain
bcs = tuple(bc._as_nonlinear_variational_problem_arg() for bc in self.bcs)
return EquationBC(F == 0, u, sub_domain, bcs=bcs, J=J, V=Vrow)


@PETSc.Log.EventDecorator()
def homogenize(bc):
Expand Down
6 changes: 4 additions & 2 deletions firedrake/preconditioners/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ def form(self, obj, *args):
if P.getType() == "python":
ctx = P.getPythonContext()
a = ctx.a
bcs = tuple(ctx.row_bcs)
bcs = tuple(ctx.bcs)
else:
ctx = get_appctx(pc.getDM())
a = ctx.Jp or ctx.J
bcs = tuple(ctx._problem.bcs)
bcs = ctx.bcs_Jp
if len(args):
a = a(*args)
return a, bcs
Expand All @@ -121,6 +121,8 @@ def new_snes_ctx(pc, op, bcs, mat_type, fcp=None, options_prefix=None):
old_appctx = get_appctx(dm).appctx
u = Function(op.arguments()[-1].function_space())
F = action(op, u)
if bcs:
bcs = tuple(bc._as_nonlinear_variational_problem_arg() for bc in bcs)
nprob = NonlinearVariationalProblem(F, u,
bcs=bcs,
J=op,
Expand Down
2 changes: 1 addition & 1 deletion firedrake/solving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def split(self, fields):
if isinstance(bc, DirichletBC):
bc_temp = bc.reconstruct(field=field, V=V, g=bc.function_arg, sub_domain=bc.sub_domain)
elif isinstance(bc, EquationBC):
bc_temp = bc.reconstruct(field, V, subu, u)
bc_temp = bc.reconstruct(V, subu, u, field, False)
if bc_temp is not None:
bcs.append(bc_temp)
new_problem = NLVP(F, subu, bcs=bcs, J=J, Jp=Jp,
Expand Down
43 changes: 43 additions & 0 deletions tests/equation_bcs/test_equation_bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,46 @@ def test_EquationBC_mixedpoisson_matfree_fieldsplit():
err.append(nonlinear_poisson_mixed(solver_parameters, mesh_num, porder))

assert abs(math.log2(err[0][0]) - math.log2(err[1][0]) - (porder+1)) < 0.05


def test_equation_bcs_pc():
mesh = UnitSquareMesh(2**6, 2**6)
CG = FunctionSpace(mesh, "CG", 3)
R = FunctionSpace(mesh, "R", 0)
V = CG * R
f = Function(V)
u, l = split(f)
v, w = split(TestFunction(V))
x, y = SpatialCoordinate(mesh)
exact = cos(2 * pi * x) * cos(2 * pi * y)
g = Function(CG).interpolate(8 * pi**2 * exact)
F = inner(grad(u), grad(v)) * dx + inner(l, w) * dx - inner(g, v) * dx
bc = EquationBC(inner((u - exact), v) * ds == 0, f, (1, 2, 3, 4), V=V.sub(0))
params = {
"mat_type": "matfree",
"ksp_type": "fgmres",
"pc_type": "fieldsplit",
"pc_fieldsplit_type": "schur",
"pc_fieldsplit_schur_fact_type": "full",
"pc_fieldsplit_0_fields": "0",
"pc_fieldsplit_1_fields": "1",
"fieldsplit_0": {
"ksp_type": "preonly",
"pc_type": "python",
"pc_python_type": "firedrake.AssembledPC",
"assembled": {
"ksp_type": "cg",
"pc_type": "asm",
},
},
"fieldsplit_1": {
"ksp_type": "gmres",
"max_it": 1,
"convergence_test": "skip",
}
}
problem = NonlinearVariationalProblem(F, f, bcs=[bc])
solver = NonlinearVariationalSolver(problem, solver_parameters=params)
solver.solve()
error = assemble(inner(u - exact, u - exact) * dx)**0.5
assert error < 1.e-7

0 comments on commit 38fd784

Please sign in to comment.