Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parameter management utilities #163

Draft
wants to merge 5 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions resolve/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ add_subdirectory(utilities)
# C++ files
set(ReSolve_SRC
LinSolver.cpp
LinSolverDirect.cpp
LinSolverIterative.cpp
GramSchmidt.cpp
LinSolverIterativeFGMRES.cpp
LinSolverDirectCpuILU0.cpp
Expand All @@ -36,23 +38,25 @@ set(ReSolve_CUDASDK_SRC

# C++ code that links to ROCm libraries
set(ReSolve_ROCM_SRC
LinSolverDirectRocSolverRf.cpp
LinSolverDirectRocSparseILU0.cpp
LinSolverDirectRocSolverRf.cpp
LinSolverDirectRocSparseILU0.cpp
)

# Header files to be installed
set(ReSolve_HEADER_INSTALL
Common.hpp
cusolver_defs.hpp
LinSolver.hpp
LinSolverDirect.hpp
LinSolverIterative.hpp
LinSolverIterativeFGMRES.hpp
LinSolverDirectCpuILU0.hpp
SystemSolver.hpp
GramSchmidt.hpp
MemoryUtils.hpp)

set(ReSolve_KLU_HEADER_INSTALL
LinSolverDirectKLU.hpp
LinSolverDirectKLU.hpp
)

set(ReSolve_LUSOL_HEADER_INSTALL
Expand Down
6 changes: 1 addition & 5 deletions resolve/Common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@ namespace ReSolve {
constexpr double EPSMAC = 1.0e-16;


// TODO: let cmake manage these. combined with the todo above relating to cstdint
// this is related to resolve/lusol/lusol_precision.f90. whatever is here should
// have an equivalent there

// NOTE: i'd love to make this std::float64_t but we're not on c++23
/// @todo Provide CMake option to se these types at config time
using real_type = double;
using index_type = std::int32_t;

Expand Down
176 changes: 8 additions & 168 deletions resolve/LinSolver.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
/**
* @file LinSolverIterative.cpp
* @author Kasia Swirydowicz (kasia.swirydowicz@pnnl.gov)
* @author Slaven Peles (peless@ornl.gov)
* @brief Implementation of linear solver base class.
*
*/

#include <resolve/matrix/Sparse.hpp>
#include <resolve/utilities/logger/Logger.hpp>

Expand All @@ -23,174 +31,6 @@ namespace ReSolve
return 1.0;
}

//
// Direct solver methods implementations
//

LinSolverDirect::LinSolverDirect()
{
L_ = nullptr;
U_ = nullptr;
P_ = nullptr;
Q_ = nullptr;
}

LinSolverDirect::~LinSolverDirect()
{
}

int LinSolverDirect::setup(matrix::Sparse* A,
matrix::Sparse* /* L */,
matrix::Sparse* /* U */,
index_type* /* P */,
index_type* /* Q */,
vector_type* /* rhs */)
{
if (A == nullptr) {
return 1;
}
this->A_ = A;
return 0;
}

int LinSolverDirect::analyze()
{
return 1;
} //the same as symbolic factorization

int LinSolverDirect::factorize()
{
return 1;
}

int LinSolverDirect::refactorize()
{
return 1;
}

matrix::Sparse* LinSolverDirect::getLFactor()
{
return nullptr;
}

matrix::Sparse* LinSolverDirect::getUFactor()
{
return nullptr;
}

index_type* LinSolverDirect::getPOrdering()
{
return nullptr;
}

index_type* LinSolverDirect::getQOrdering()
{
return nullptr;
}

void LinSolverDirect::setPivotThreshold(real_type tol)
{
pivot_threshold_tol_ = tol;
}

void LinSolverDirect::setOrdering(int ordering)
{
ordering_ = ordering;
}

void LinSolverDirect::setHaltIfSingular(bool is_halt)
{
halt_if_singular_ = is_halt;
}

real_type LinSolverDirect::getMatrixConditionNumber()
{
out::error() << "Solver does not implement returning system matrix condition number.\n";
return -1.0;
}

//
// Iterative solver methods implementations
//

LinSolverIterative::LinSolverIterative()
{
}

LinSolverIterative::~LinSolverIterative()
{
}

int LinSolverIterative::setup(matrix::Sparse* A)
{
if (A == nullptr) {
return 1;
}
this->A_ = A;
return 0;
}

real_type LinSolverIterative::getFinalResidualNorm() const
{
return final_residual_norm_;
}

real_type LinSolverIterative::getInitResidualNorm() const
{
return initial_residual_norm_;
}

index_type LinSolverIterative::getNumIter() const
{
return total_iters_;
}


real_type LinSolverIterative::getTol()
{
return tol_;
}

index_type LinSolverIterative::getMaxit()
{
return maxit_;
}

index_type LinSolverIterative::getRestart()
{
return restart_;
}

index_type LinSolverIterative::getConvCond()
{
return conv_cond_;
}

bool LinSolverIterative::getFlexible()
{
return flexible_;
}

int LinSolverIterative::setOrthogonalization(GramSchmidt* /* gs */)
{
out::error() << "Solver does not implement setting orthogonalization.\n";
return 1;
}

void LinSolverIterative::setTol(real_type new_tol)
{
this->tol_ = new_tol;
}

void LinSolverIterative::setMaxit(index_type new_maxit)
{
this->maxit_ = new_maxit;
}

void LinSolverIterative::setConvCond(index_type new_conv_cond)
{
this->conv_cond_ = new_conv_cond;
}
}


Expand Down
132 changes: 48 additions & 84 deletions resolve/LinSolver.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
/**
* @file LinSolverIterative.hpp
* @author Kasia Swirydowicz (kasia.swirydowicz@pnnl.gov)
* @author Slaven Peles (peless@ornl.gov)
* @brief Declaration of linear solver base class.
*
*/
#pragma once

#include <map>
#include <string>

#include "Common.hpp"

namespace ReSolve
Expand All @@ -21,9 +31,13 @@ namespace ReSolve

// Forward declaration of MatrixHandler class
class MatrixHandler;

class GramSchmidt;

class SolverParameters;

/**
* @brief Base class for all linear solvers.
*
*/
class LinSolver
{
protected:
Expand All @@ -34,94 +48,44 @@ namespace ReSolve
virtual ~LinSolver();

real_type evaluateResidual();

virtual int setCliParam(const std::string /* id */, const std::string /* value */)
{
return 1;
}
pelesh marked this conversation as resolved.
Show resolved Hide resolved

virtual int getCliParam(const std::string /* id */, std::string& /* value */)
{
return 1;
}

virtual int getCliParam(const std::string /* id */, index_type& /* value */)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could these functions be pure virtual? This is possible if and only if they are always overridden in derived classes; otherwise a default implementation like this is required.

{
return 1;
}

virtual int getCliParam(const std::string /* id */, real_type& /* value */)
{
return 1;
}

virtual int getCliParam(const std::string /* id */, bool& /* value */)
{
return 1;
}

virtual int printCliParam(const std::string /* id */)
{
return 1;
}

protected:
matrix::Sparse* A_{nullptr};
real_type* rhs_{nullptr};
real_type* sol_{nullptr};

MatrixHandler* matrix_handler_{nullptr};
VectorHandler* vector_handler_{nullptr};
};

class LinSolverDirect : public LinSolver
{
public:
LinSolverDirect();
virtual ~LinSolverDirect();
virtual int setup(matrix::Sparse* A = nullptr,
matrix::Sparse* L = nullptr,
matrix::Sparse* U = nullptr,
index_type* P = nullptr,
index_type* Q = nullptr,
vector_type* rhs = nullptr);

virtual int analyze(); //the same as symbolic factorization
virtual int factorize();
virtual int refactorize();
virtual int solve(vector_type* rhs, vector_type* x) = 0;
virtual int solve(vector_type* x) = 0;

virtual matrix::Sparse* getLFactor();
virtual matrix::Sparse* getUFactor();
virtual index_type* getPOrdering();
virtual index_type* getQOrdering();

virtual void setPivotThreshold(real_type tol);
virtual void setOrdering(int ordering);
virtual void setHaltIfSingular(bool is_halt);

virtual real_type getMatrixConditionNumber();

protected:
matrix::Sparse* L_{nullptr};
matrix::Sparse* U_{nullptr};
index_type* P_{nullptr};
index_type* Q_{nullptr};

int ordering_{1}; // 0 = AMD, 1 = COLAMD, 2 = user provided P, Q
real_type pivot_threshold_tol_{0.1};
bool halt_if_singular_{false};
std::map<std::string, int> params_list_;
};

class LinSolverIterative : public LinSolver
{
public:
LinSolverIterative();
virtual ~LinSolverIterative();
virtual int setup(matrix::Sparse* A);
virtual int resetMatrix(matrix::Sparse* A) = 0;
virtual int setupPreconditioner(std::string type, LinSolverDirect* LU_solver) = 0;

virtual int solve(vector_type* rhs, vector_type* init_guess) = 0;

virtual real_type getFinalResidualNorm() const;
virtual real_type getInitResidualNorm() const;
virtual index_type getNumIter() const;

virtual int setOrthogonalization(GramSchmidt* gs);

real_type getTol();
index_type getMaxit();
index_type getRestart();
index_type getConvCond();
bool getFlexible();

void setTol(real_type new_tol);
void setMaxit(index_type new_maxit);
virtual int setRestart(index_type new_restart) = 0;
void setConvCond(index_type new_conv_cond);
virtual int setFlexible(bool new_flexible) = 0;

protected:
real_type initial_residual_norm_;
real_type final_residual_norm_;
index_type total_iters_;

real_type tol_{1e-14};
index_type maxit_{100};
index_type restart_{10};
index_type conv_cond_{0};
bool flexible_{true}; // if can be run as "normal" GMRES if needed, set flexible_ to false. Default is true of course.
};
}
} // namespace ReSolve
Loading