Skip to content

Commit

Permalink
format and lint
Browse files Browse the repository at this point in the history
  • Loading branch information
lroberts36 committed Sep 24, 2024
1 parent 83226a3 commit 3c81e68
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 51 deletions.
9 changes: 4 additions & 5 deletions example/poisson_gmg/poisson_package.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,11 @@ std::shared_ptr<StateDescriptor> Initialize(ParameterInput *pin) {
pkg.get(), bicgstab_params, eq);
pkg->AddParam<>("MGBiCGSTABsolver", bicg_solver,
parthenon::Params::Mutability::Mutable);

parthenon::solvers::CGParams cg_params(pin, "poisson/solver_params");
parthenon::solvers::CGSolver<u, rhs, PoissonEquation> cg_solver(
pkg.get(), cg_params, eq);
pkg->AddParam<>("MGCGsolver", cg_solver,
parthenon::Params::Mutability::Mutable);
parthenon::solvers::CGSolver<u, rhs, PoissonEquation> cg_solver(pkg.get(), cg_params,
eq);
pkg->AddParam<>("MGCGsolver", cg_solver, parthenon::Params::Mutability::Mutable);

using namespace parthenon::refinement_ops;
auto mD = Metadata(
Expand Down
50 changes: 27 additions & 23 deletions src/prolong_restrict/pr_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ using ProlongateSharedMinMod = ProlongateSharedGeneral<true, false>;
using ProlongateSharedLinear = ProlongateSharedGeneral<false, false>;
using ProlongatePiecewiseConstant = ProlongateSharedGeneral<false, true>;

enum class MGProlongationType {Constant, Linear, Quadratic, Kwak};
enum class MGProlongationType { Constant, Linear, Quadratic, Kwak };

template <MGProlongationType type>
struct ProlongateSharedMG {
Expand All @@ -299,15 +299,15 @@ struct ProlongateSharedMG {
}

KOKKOS_FORCEINLINE_FUNCTION
static Real QuadraticFactor(int d) {
static Real QuadraticFactor(int d) {
if (d == 0) return 1.0; // Indicates this dimension is not included
if (d == 1 || d == -1) return 30.0 / 32.0;
if (d == 3 || d == -3) return 5.0 / 32.0;
return -3.0 / 32.0;
}

KOKKOS_FORCEINLINE_FUNCTION
static Real LinearFactor(int d, bool up_bound, bool lo_bound) {
static Real LinearFactor(int d, bool up_bound, bool lo_bound) {
if (d == 0) return 1.0; // Indicates this dimension is not included
if (d == 1) return (2.0 + !up_bound) / 4.0;
if (d == -1) return (2.0 + !lo_bound) / 4.0;
Expand All @@ -317,12 +317,12 @@ struct ProlongateSharedMG {
}

KOKKOS_FORCEINLINE_FUNCTION
static Real ConstantFactor(int d) {
static Real ConstantFactor(int d) {
if (d == 0) return 1.0; // Indicates this dimension is not included
if (d == 1 || d == -1) return 1.0;
return 0.0;
}

template <int DIM, TopologicalElement el = TopologicalElement::CC,
TopologicalElement /*cel*/ = TopologicalElement::CC>
KOKKOS_FORCEINLINE_FUNCTION static void
Expand All @@ -341,10 +341,10 @@ struct ProlongateSharedMG {
const int fi = (DIM > 0) ? (i - cib.s) * 2 + ib.s : ib.s;
const int fj = (DIM > 1) ? (j - cjb.s) * 2 + jb.s : jb.s;
const int fk = (DIM > 2) ? (k - ckb.s) * 2 + kb.s : kb.s;
for (int fok = 0; fok < 1 + (DIM > 2); ++fok) {
for (int foj = 0; foj < 1 + (DIM > 1); ++foj) {
for (int foi = 0; foi < 1 + (DIM > 0); ++foi) {

for (int fok = 0; fok < 1 + (DIM > 2); ++fok) {
for (int foj = 0; foj < 1 + (DIM > 1); ++foj) {
for (int foi = 0; foi < 1 + (DIM > 0); ++foi) {
auto &f = fine(element_idx, l, m, n, fk + fok, fj + foj, fi + foi);
f = 0.0;
const bool lo_bound_x = (fi == ib.s);
Expand All @@ -356,23 +356,27 @@ struct ProlongateSharedMG {
for (int ok = -(DIM > 2); ok < 1 + (DIM > 2); ++ok) {
for (int oj = -(DIM > 1); oj < 1 + (DIM > 1); ++oj) {
for (int oi = -(DIM > 0); oi < 1 + (DIM > 0); ++oi) {
const int dx = 4 * oi - foi + 1;
const int dy = (DIM > 1) ? 4 * oj - foj + 1 : 0;
const int dz = (DIM > 2) ? 4 * ok - fok + 1 : 0;
const int dx = 4 * oi - foi + 1;
const int dy = (DIM > 1) ? 4 * oj - foj + 1 : 0;
const int dz = (DIM > 2) ? 4 * ok - fok + 1 : 0;
if constexpr (MGProlongationType::Linear == type) {
f += LinearFactor(dx, lo_bound_x, up_bound_x)
* LinearFactor(dy, lo_bound_y, up_bound_y)
* LinearFactor(dz, lo_bound_z, up_bound_z)
* coarse(element_idx, l, m, n, k + ok, j + oj, i + oi);
f += LinearFactor(dx, lo_bound_x, up_bound_x) *
LinearFactor(dy, lo_bound_y, up_bound_y) *
LinearFactor(dz, lo_bound_z, up_bound_z) *
coarse(element_idx, l, m, n, k + ok, j + oj, i + oi);
} else if constexpr (MGProlongationType::Kwak == type) {
const Real fac = ((dx <= 1) + (dy <= 1 && DIM > 1) + (dz <=1 && DIM > 2)) / (2.0 * DIM);
f += fac * coarse(element_idx, l, m, n, k + ok, j + oj, i + oi);
} else if constexpr(MGProlongationType::Quadratic == type) {
f += QuadraticFactor(dx) * QuadraticFactor(dy) * QuadraticFactor(dz) * coarse(element_idx, l, m, n, k + ok, j + oj, i + oi);
} else {
f += ConstantFactor(dx) * ConstantFactor(dy) * ConstantFactor(dz) * coarse(element_idx, l, m, n, k + ok, j + oj, i + oi);
const Real fac =
((dx <= 1) + (dy <= 1 && DIM > 1) + (dz <= 1 && DIM > 2)) /
(2.0 * DIM);
f += fac * coarse(element_idx, l, m, n, k + ok, j + oj, i + oi);
} else if constexpr (MGProlongationType::Quadratic == type) {
f += QuadraticFactor(dx) * QuadraticFactor(dy) * QuadraticFactor(dz) *
coarse(element_idx, l, m, n, k + ok, j + oj, i + oi);
} else {
f += ConstantFactor(dx) * ConstantFactor(dy) * ConstantFactor(dz) *
coarse(element_idx, l, m, n, k + ok, j + oj, i + oi);
}
}
}
}
}
}
Expand Down
35 changes: 17 additions & 18 deletions src/solvers/cg_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#ifndef SOLVERS_CG_SOLVER_HPP_
#define SOLVERS_CG_SOLVER_HPP_

#include <cstdio>
#include <limits>
#include <memory>
#include <string>
#include <utility>
Expand Down Expand Up @@ -84,9 +86,8 @@ class CGSolver {
return names;
}

CGSolver(StateDescriptor *pkg, CGParams params_in,
equations eq_in = equations(), std::vector<int> shape = {},
const std::string &container = "base")
CGSolver(StateDescriptor *pkg, CGParams params_in, equations eq_in = equations(),
std::vector<int> shape = {}, const std::string &container = "base")
: preconditioner(pkg, params_in.mg_params, eq_in, shape, container),
params_(params_in), iter_counter(0), eqs_(eq_in), container_(container) {
using namespace refinement_ops;
Expand Down Expand Up @@ -124,8 +125,7 @@ class CGSolver {
get_rhs2 = DotProduct<rhs, rhs>(dependence, tl, &rhs2, md);
auto initialize = tl.AddTask(
TaskQualifier::once_per_region | TaskQualifier::local_sync,
zero_u | zero_v | zero_x | zero_p | copy_r | get_rhs2,
"zero factors",
zero_u | zero_v | zero_x | zero_p | copy_r | get_rhs2, "zero factors",
[](CGSolver *solver) {
solver->iter_counter = -1;
solver->ru.val = std::numeric_limits<Real>::max();
Expand All @@ -136,8 +136,7 @@ class CGSolver {
if (params_.print_per_step && Globals::my_rank == 0) {
initialize = tl.AddTask(
TaskQualifier::once_per_region, initialize, "print to screen",
[&](CGSolver *solver, std::shared_ptr<Real> res_tol,
bool relative_residual) {
[&](CGSolver *solver, std::shared_ptr<Real> res_tol, bool relative_residual) {
Real tol =
relative_residual
? *res_tol * std::sqrt(solver->rhs2.val / pmesh->GetTotalCells())
Expand Down Expand Up @@ -175,8 +174,8 @@ class CGSolver {
}

// 2. beta <- r dot u / r dot u {old}
auto get_ru = DotProduct<r, u>(precon, itl, &ru, md);
auto get_ru = DotProduct<r, u>(precon, itl, &ru, md);

// 3. p <- u + beta p
auto correct_p = itl.AddTask(
get_ru, "p <- u + beta p",
Expand All @@ -185,26 +184,26 @@ class CGSolver {
return AddFieldsAndStore<u, p, p>(md, 1.0, beta);
},
this, md);

// 4. v <- A p
auto copy_u = itl.AddTask(correct_p, TF(CopyData<p, u>), md);
auto comm =
AddBoundaryExchangeTasks<BoundaryType::any>(copy_u, itl, md_comm, multilevel);
auto get_v = eqs_.template Ax<u, v>(itl, comm, md);

// 5. alpha <- r dot u / p dot v (calculate denominator)
// 5. alpha <- r dot u / p dot v (calculate denominator)
auto get_pAp = DotProduct<p, v>(get_v, itl, &pAp, md);

// 6. x <- x + alpha p
// 6. x <- x + alpha p
auto correct_x = itl.AddTask(
get_pAp, "x <- x + alpha p",
[](CGSolver *solver, std::shared_ptr<MeshData<Real>> &md) {
Real alpha = solver->ru.val / solver->pAp.val;
return AddFieldsAndStore<x, p, x>(md, 1.0, alpha);
},
this, md);
// 6. r <- r - alpha A p
this, md);

// 6. r <- r - alpha A p
auto correct_r = itl.AddTask(
get_pAp, "r <- r - alpha A p",
[](CGSolver *solver, std::shared_ptr<MeshData<Real>> &md) {
Expand All @@ -225,11 +224,11 @@ class CGSolver {
return TaskStatus::complete;
},
this, pmesh);

auto check = itl.AddTask(
TaskQualifier::completion, get_res | correct_x, "completion",
[](CGSolver *solver, Mesh *pmesh, int max_iter,
std::shared_ptr<Real> res_tol, bool relative_residual) {
[](CGSolver *solver, Mesh *pmesh, int max_iter, std::shared_ptr<Real> res_tol,
bool relative_residual) {
Real rms_res = std::sqrt(solver->residual.val / pmesh->GetTotalCells());
solver->final_residual = rms_res;
solver->final_iteration = solver->iter_counter;
Expand Down
14 changes: 9 additions & 5 deletions src/solvers/mg_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,19 @@ class MGSolver {
Metadata({Metadata::Cell, Metadata::Independent, Metadata::GMGRestrict,
Metadata::GMGProlongate, Metadata::OneCopy},
shape);

if (params_.prolongation == "Linear") {
mres_err.RegisterRefinementOps<ProlongateSharedMG<MGProlongationType::Linear>, RestrictAverage>();
mres_err.RegisterRefinementOps<ProlongateSharedMG<MGProlongationType::Linear>,
RestrictAverage>();
} else if (params_.prolongation == "Kwak") {
mres_err.RegisterRefinementOps<ProlongateSharedMG<MGProlongationType::Kwak>, RestrictAverage>();
mres_err.RegisterRefinementOps<ProlongateSharedMG<MGProlongationType::Kwak>,
RestrictAverage>();
} else if (params_.prolongation == "Quadratic") {
mres_err.RegisterRefinementOps<ProlongateSharedMG<MGProlongationType::Quadratic>, RestrictAverage>();
mres_err.RegisterRefinementOps<ProlongateSharedMG<MGProlongationType::Quadratic>,
RestrictAverage>();
} else if (params_.prolongation == "Constant") {
mres_err.RegisterRefinementOps<ProlongateSharedMG<MGProlongationType::Constant>, RestrictAverage>();
mres_err.RegisterRefinementOps<ProlongateSharedMG<MGProlongationType::Constant>,
RestrictAverage>();
} else if (params_.prolongation == "OldLinear") {
mres_err.RegisterRefinementOps<ProlongateSharedLinear, RestrictAverage>();
} else {
Expand Down

0 comments on commit 3c81e68

Please sign in to comment.