Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix preconditioner when using EquationBC #3842

Merged
merged 2 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,9 @@ def extract_form(self, form_type):
# DirichletBC is directly used in assembly.
return self

def _as_nonlinear_variational_problem_arg(self):
return self


class EquationBC(object):
r'''Construct and store EquationBCSplit objects (for `F`, `J`, and `Jp`).
Expand Down Expand Up @@ -549,12 +552,15 @@ def extract_form(self, form_type):
return getattr(self, f"_{form_type}")

@PETSc.Log.EventDecorator()
def reconstruct(self, V, subu, u, field):
def reconstruct(self, V, subu, u, field, is_linear):
_F = self._F.reconstruct(field=field, V=V, subu=subu, u=u)
_J = self._J.reconstruct(field=field, V=V, subu=subu, u=u)
_Jp = self._Jp.reconstruct(field=field, V=V, subu=subu, u=u)
if all([_F is not None, _J is not None, _Jp is not None]):
return EquationBC(_F, _J, _Jp, Jp_eq_J=self.Jp_eq_J, is_linear=self.is_linear)
return EquationBC(_F, _J, _Jp, Jp_eq_J=self.Jp_eq_J, is_linear=is_linear)

def _as_nonlinear_variational_problem_arg(self):
return self


class EquationBCSplit(BCBase):
Expand Down Expand Up @@ -645,6 +651,20 @@ def reconstruct(self, field=None, V=None, subu=None, u=None, row_field=None, col
ebc.add(bc_temp)
return ebc

def _as_nonlinear_variational_problem_arg(self):
# NonlinearVariationalProblem expects EquationBC, not EquationBCSplit.
# -- This method is required when NonlinearVariationalProblem is constructed inside PC.
if len(self.f.arguments()) != 2:
raise NotImplementedError(f"Not expecting a form of rank {len(self.f.arguments())} (!= 2)")
J = self.f
Vcol = J.arguments()[-1].function_space()
u = firedrake.Function(Vcol)
F = ufl_expr.action(J, u)
Vrow = self._function_space
sub_domain = self.sub_domain
bcs = tuple(bc._as_nonlinear_variational_problem_arg() for bc in self.bcs)
return EquationBC(F == 0, u, sub_domain, bcs=bcs, J=J, V=Vrow)


@PETSc.Log.EventDecorator()
def homogenize(bc):
Expand Down
6 changes: 4 additions & 2 deletions firedrake/preconditioners/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ def form(self, obj, *args):
if P.getType() == "python":
ctx = P.getPythonContext()
a = ctx.a
bcs = tuple(ctx.row_bcs)
bcs = tuple(ctx.bcs)
else:
ctx = get_appctx(pc.getDM())
a = ctx.Jp or ctx.J
bcs = tuple(ctx._problem.bcs)
bcs = ctx.bcs_Jp
if len(args):
a = a(*args)
return a, bcs
Expand All @@ -121,6 +121,8 @@ def new_snes_ctx(pc, op, bcs, mat_type, fcp=None, options_prefix=None):
old_appctx = get_appctx(dm).appctx
u = Function(op.arguments()[-1].function_space())
F = action(op, u)
if bcs:
bcs = tuple(bc._as_nonlinear_variational_problem_arg() for bc in bcs)
nprob = NonlinearVariationalProblem(F, u,
bcs=bcs,
J=op,
Expand Down
2 changes: 1 addition & 1 deletion firedrake/solving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def split(self, fields):
if isinstance(bc, DirichletBC):
bc_temp = bc.reconstruct(field=field, V=V, g=bc.function_arg, sub_domain=bc.sub_domain)
elif isinstance(bc, EquationBC):
bc_temp = bc.reconstruct(field, V, subu, u)
bc_temp = bc.reconstruct(V, subu, u, field, False)
if bc_temp is not None:
bcs.append(bc_temp)
new_problem = NLVP(F, subu, bcs=bcs, J=J, Jp=Jp,
Expand Down
43 changes: 43 additions & 0 deletions tests/equation_bcs/test_equation_bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,46 @@ def test_EquationBC_mixedpoisson_matfree_fieldsplit():
err.append(nonlinear_poisson_mixed(solver_parameters, mesh_num, porder))

assert abs(math.log2(err[0][0]) - math.log2(err[1][0]) - (porder+1)) < 0.05


def test_equation_bcs_pc():
mesh = UnitSquareMesh(2**6, 2**6)
CG = FunctionSpace(mesh, "CG", 3)
R = FunctionSpace(mesh, "R", 0)
V = CG * R
f = Function(V)
u, l = split(f)
v, w = split(TestFunction(V))
x, y = SpatialCoordinate(mesh)
exact = cos(2 * pi * x) * cos(2 * pi * y)
g = Function(CG).interpolate(8 * pi**2 * exact)
F = inner(grad(u), grad(v)) * dx + inner(l, w) * dx - inner(g, v) * dx
bc = EquationBC(inner((u - exact), v) * ds == 0, f, (1, 2, 3, 4), V=V.sub(0))
params = {
"mat_type": "matfree",
"ksp_type": "fgmres",
"pc_type": "fieldsplit",
"pc_fieldsplit_type": "schur",
"pc_fieldsplit_schur_fact_type": "full",
"pc_fieldsplit_0_fields": "0",
"pc_fieldsplit_1_fields": "1",
"fieldsplit_0": {
"ksp_type": "preonly",
"pc_type": "python",
"pc_python_type": "firedrake.AssembledPC",
"assembled": {
"ksp_type": "cg",
"pc_type": "asm",
},
},
"fieldsplit_1": {
"ksp_type": "gmres",
"max_it": 1,
"convergence_test": "skip",
}
}
problem = NonlinearVariationalProblem(F, f, bcs=[bc])
solver = NonlinearVariationalSolver(problem, solver_parameters=params)
solver.solve()
error = assemble(inner(u - exact, u - exact) * dx)**0.5
assert error < 1.e-7
Loading