From 38fd784a5011ded8fb741c746525222564975fda Mon Sep 17 00:00:00 2001 From: ksagiyam <46749170+ksagiyam@users.noreply.github.com> Date: Wed, 6 Nov 2024 16:55:11 +0000 Subject: [PATCH] Fix preconditioner when using EquationBC (#3842) --------- Co-authored-by: Aaron Baier-Reinio --- firedrake/bcs.py | 24 ++++++++++++-- firedrake/preconditioners/base.py | 6 ++-- firedrake/solving_utils.py | 2 +- tests/equation_bcs/test_equation_bcs.py | 43 +++++++++++++++++++++++++ 4 files changed, 70 insertions(+), 5 deletions(-) diff --git a/firedrake/bcs.py b/firedrake/bcs.py index ae20ab0e2c..ee38b524e4 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -459,6 +459,9 @@ def extract_form(self, form_type): # DirichletBC is directly used in assembly. return self + def _as_nonlinear_variational_problem_arg(self): + return self + class EquationBC(object): r'''Construct and store EquationBCSplit objects (for `F`, `J`, and `Jp`). @@ -549,12 +552,15 @@ def extract_form(self, form_type): return getattr(self, f"_{form_type}") @PETSc.Log.EventDecorator() - def reconstruct(self, V, subu, u, field): + def reconstruct(self, V, subu, u, field, is_linear): _F = self._F.reconstruct(field=field, V=V, subu=subu, u=u) _J = self._J.reconstruct(field=field, V=V, subu=subu, u=u) _Jp = self._Jp.reconstruct(field=field, V=V, subu=subu, u=u) if all([_F is not None, _J is not None, _Jp is not None]): - return EquationBC(_F, _J, _Jp, Jp_eq_J=self.Jp_eq_J, is_linear=self.is_linear) + return EquationBC(_F, _J, _Jp, Jp_eq_J=self.Jp_eq_J, is_linear=is_linear) + + def _as_nonlinear_variational_problem_arg(self): + return self class EquationBCSplit(BCBase): @@ -645,6 +651,20 @@ def reconstruct(self, field=None, V=None, subu=None, u=None, row_field=None, col ebc.add(bc_temp) return ebc + def _as_nonlinear_variational_problem_arg(self): + # NonlinearVariationalProblem expects EquationBC, not EquationBCSplit. + # -- This method is required when NonlinearVariationalProblem is constructed inside PC. + if len(self.f.arguments()) != 2: + raise NotImplementedError(f"Not expecting a form of rank {len(self.f.arguments())} (!= 2)") + J = self.f + Vcol = J.arguments()[-1].function_space() + u = firedrake.Function(Vcol) + F = ufl_expr.action(J, u) + Vrow = self._function_space + sub_domain = self.sub_domain + bcs = tuple(bc._as_nonlinear_variational_problem_arg() for bc in self.bcs) + return EquationBC(F == 0, u, sub_domain, bcs=bcs, J=J, V=Vrow) + @PETSc.Log.EventDecorator() def homogenize(bc): diff --git a/firedrake/preconditioners/base.py b/firedrake/preconditioners/base.py index e7b809024e..0bdfc97a37 100644 --- a/firedrake/preconditioners/base.py +++ b/firedrake/preconditioners/base.py @@ -95,11 +95,11 @@ def form(self, obj, *args): if P.getType() == "python": ctx = P.getPythonContext() a = ctx.a - bcs = tuple(ctx.row_bcs) + bcs = tuple(ctx.bcs) else: ctx = get_appctx(pc.getDM()) a = ctx.Jp or ctx.J - bcs = tuple(ctx._problem.bcs) + bcs = ctx.bcs_Jp if len(args): a = a(*args) return a, bcs @@ -121,6 +121,8 @@ def new_snes_ctx(pc, op, bcs, mat_type, fcp=None, options_prefix=None): old_appctx = get_appctx(dm).appctx u = Function(op.arguments()[-1].function_space()) F = action(op, u) + if bcs: + bcs = tuple(bc._as_nonlinear_variational_problem_arg() for bc in bcs) nprob = NonlinearVariationalProblem(F, u, bcs=bcs, J=op, diff --git a/firedrake/solving_utils.py b/firedrake/solving_utils.py index a9dda71038..9e843016b5 100644 --- a/firedrake/solving_utils.py +++ b/firedrake/solving_utils.py @@ -369,7 +369,7 @@ def split(self, fields): if isinstance(bc, DirichletBC): bc_temp = bc.reconstruct(field=field, V=V, g=bc.function_arg, sub_domain=bc.sub_domain) elif isinstance(bc, EquationBC): - bc_temp = bc.reconstruct(field, V, subu, u) + bc_temp = bc.reconstruct(V, subu, u, field, False) if bc_temp is not None: bcs.append(bc_temp) new_problem = NLVP(F, subu, bcs=bcs, J=J, Jp=Jp, diff --git a/tests/equation_bcs/test_equation_bcs.py b/tests/equation_bcs/test_equation_bcs.py index 0037537ae4..087b07aa36 100644 --- a/tests/equation_bcs/test_equation_bcs.py +++ b/tests/equation_bcs/test_equation_bcs.py @@ -354,3 +354,46 @@ def test_EquationBC_mixedpoisson_matfree_fieldsplit(): err.append(nonlinear_poisson_mixed(solver_parameters, mesh_num, porder)) assert abs(math.log2(err[0][0]) - math.log2(err[1][0]) - (porder+1)) < 0.05 + + +def test_equation_bcs_pc(): + mesh = UnitSquareMesh(2**6, 2**6) + CG = FunctionSpace(mesh, "CG", 3) + R = FunctionSpace(mesh, "R", 0) + V = CG * R + f = Function(V) + u, l = split(f) + v, w = split(TestFunction(V)) + x, y = SpatialCoordinate(mesh) + exact = cos(2 * pi * x) * cos(2 * pi * y) + g = Function(CG).interpolate(8 * pi**2 * exact) + F = inner(grad(u), grad(v)) * dx + inner(l, w) * dx - inner(g, v) * dx + bc = EquationBC(inner((u - exact), v) * ds == 0, f, (1, 2, 3, 4), V=V.sub(0)) + params = { + "mat_type": "matfree", + "ksp_type": "fgmres", + "pc_type": "fieldsplit", + "pc_fieldsplit_type": "schur", + "pc_fieldsplit_schur_fact_type": "full", + "pc_fieldsplit_0_fields": "0", + "pc_fieldsplit_1_fields": "1", + "fieldsplit_0": { + "ksp_type": "preonly", + "pc_type": "python", + "pc_python_type": "firedrake.AssembledPC", + "assembled": { + "ksp_type": "cg", + "pc_type": "asm", + }, + }, + "fieldsplit_1": { + "ksp_type": "gmres", + "max_it": 1, + "convergence_test": "skip", + } + } + problem = NonlinearVariationalProblem(F, f, bcs=[bc]) + solver = NonlinearVariationalSolver(problem, solver_parameters=params) + solver.solve() + error = assemble(inner(u - exact, u - exact) * dx)**0.5 + assert error < 1.e-7