Skip to content

Commit

Permalink
Merge pull request #56 from arvigj/stochastic_gd
Browse files Browse the repository at this point in the history
Add stochastic gradient descent
  • Loading branch information
arvigj authored Dec 8, 2023
2 parents 2c79268 + 1d16421 commit e9bb718
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 3 deletions.
17 changes: 17 additions & 0 deletions non-linear-solver-spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"LBFGS",
"LBFGSB",
"Newton",
"StochasticGradientDescent",
"box_constraints",
"advanced"
],
Expand All @@ -28,6 +29,7 @@
"Newton",
"DenseNewton",
"GradientDescent",
"StochasticGradientDescent",
"L-BFGS",
"BFGS",
"L-BFGS-B",
Expand Down Expand Up @@ -164,6 +166,21 @@
"type": "bool",
"doc": "Use PSD as fallback using second order solvers (i.e., Newton's method)."
},
{
"pointer": "/StochasticGradientDescent",
"default": null,
"type": "object",
"optional": [
"erase_component_probability"
],
"doc": "Options for Stochastic Gradient Descent."
},
{
"pointer": "/StochasticGradientDescent/erase_component_probability",
"default": 0.3,
"type": "float",
"doc": "Probability of erasing a component on the gradient for StochasticGradientDescent."
},
{
"pointer": "/line_search",
"default": null,
Expand Down
8 changes: 7 additions & 1 deletion src/polysolve/nonlinear/Solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ namespace polysolve::nonlinear
solver->add_strategy(std::make_unique<LBFGS>(
solver_params, characteristic_length, logger));
}
else if (solver_name == "StochasticGradientDescent" || solver_name == "stochastic_gradient_descent")
{
solver->add_strategy(std::make_unique<GradientDescent>(
solver_params, true, characteristic_length, logger));
}
else if (solver_name == "GradientDescent" || solver_name == "gradient_descent")
{
// grad descent always there
Expand All @@ -93,7 +98,7 @@ namespace polysolve::nonlinear
throw std::runtime_error("Unrecognized solver type: " + solver_name);

solver->add_strategy(std::make_unique<GradientDescent>(
solver_params, characteristic_length, logger));
solver_params, false, characteristic_length, logger));

solver->set_strategies_iterations(solver_params);
return solver;
Expand All @@ -105,6 +110,7 @@ namespace polysolve::nonlinear
"DenseNewton",
"Newton",
"GradientDescent",
"StochasticGradientDescent",
"L-BFGS"};
}

Expand Down
12 changes: 11 additions & 1 deletion src/polysolve/nonlinear/descent_strategies/GradientDescent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ namespace polysolve::nonlinear
{

GradientDescent::GradientDescent(const json &solver_params_,
const bool is_stochastic,
const double characteristic_length,
spdlog::logger &logger)
: Superclass(solver_params_, characteristic_length, logger)
: Superclass(solver_params_, characteristic_length, logger), is_stochastic_(is_stochastic)
{
if (is_stochastic_)
erase_component_probability_ = solver_params_["StochasticGradientDescent"]["erase_component_probability"];
}

bool GradientDescent::compute_update_direction(
Expand All @@ -18,6 +21,13 @@ namespace polysolve::nonlinear
{
direction = -grad;

if (is_stochastic_)
{
Eigen::VectorXd mask = (Eigen::VectorXd::Random(direction.size()).array() + 1.) / 2.;
for (int i = 0; i < direction.size(); ++i)
direction(i) *= (mask(i) < erase_component_probability_) ? 0. : 1.;
}

return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,20 @@ namespace polysolve::nonlinear
using Superclass = DescentStrategy;

GradientDescent(const json &solver_params_,
const bool is_stochastic,
const double characteristic_length,
spdlog::logger &logger);

std::string name() const override { return "GradientDescent"; }
std::string name() const override { return is_stochastic_ ? "StochasticGradientDescent" : "GradientDescent"; }

bool compute_update_direction(
Problem &objFunc,
const TVector &x,
const TVector &grad,
TVector &direction) override;

private:
bool is_stochastic_ = false;
double erase_component_probability_ = 0;
};
} // namespace polysolve::nonlinear

0 comments on commit e9bb718

Please sign in to comment.