diff --git a/non-linear-solver-spec.json b/non-linear-solver-spec.json index 600ec74..ba418a4 100644 --- a/non-linear-solver-spec.json +++ b/non-linear-solver-spec.json @@ -16,6 +16,7 @@ "LBFGSB", "Newton", "ADAM", + "StochasticADAM", "StochasticGradientDescent", "box_constraints", "advanced" @@ -31,6 +32,7 @@ "DenseNewton", "GradientDescent", "ADAM", + "StochasticADAM", "StochasticGradientDescent", "L-BFGS", "BFGS", @@ -204,6 +206,49 @@ "type": "float", "doc": "Parameter epsilon for ADAM." }, + { + "pointer": "/StochasticADAM", + "default": null, + "type": "object", + "optional": [ + "alpha", + "beta_1", + "beta_2", + "epsilon", + "erase_component_probability" + ], + "doc": "Options for ADAM." + }, + { + "pointer": "/StochasticADAM/alpha", + "default": 0.001, + "type": "float", + "doc": "Parameter alpha for ADAM." + }, + { + "pointer": "/StochasticADAM/beta_1", + "default": 0.9, + "type": "float", + "doc": "Parameter beta_1 for ADAM." + }, + { + "pointer": "/StochasticADAM/beta_2", + "default": 0.999, + "type": "float", + "doc": "Parameter beta_2 for ADAM." + }, + { + "pointer": "/StochasticADAM/epsilon", + "default": 1e-8, + "type": "float", + "doc": "Parameter epsilon for ADAM." + }, + { + "pointer": "/StochasticADAM/erase_component_probability", + "default": 0.3, + "type": "float", + "doc": "Probability of erasing a component on the gradient for ADAM." + }, { "pointer": "/StochasticGradientDescent", "default": null, diff --git a/src/polysolve/nonlinear/Solver.cpp b/src/polysolve/nonlinear/Solver.cpp index ca2b00d..aa2a632 100644 --- a/src/polysolve/nonlinear/Solver.cpp +++ b/src/polysolve/nonlinear/Solver.cpp @@ -89,7 +89,14 @@ namespace polysolve::nonlinear else if (solver_name == "ADAM" || solver_name == "adam") { solver->add_strategy(std::make_unique( - solver_params, characteristic_length, logger)); + solver_params, false, characteristic_length, logger)); + } + else if (solver_name == "StochasticADAM" || solver_name == "stochastic_adam") + { + solver->add_strategy(std::make_unique( + solver_params, true, characteristic_length, logger)); + } + else if (solver_name == "StochasticGradientDescent" || solver_name == "stochastic_gradient_descent") { solver->add_strategy(std::make_unique( @@ -115,6 +122,7 @@ namespace polysolve::nonlinear "DenseNewton", "Newton", "ADAM", + "StochasticADAM", "GradientDescent", "StochasticGradientDescent", "L-BFGS"}; diff --git a/src/polysolve/nonlinear/descent_strategies/ADAM.cpp b/src/polysolve/nonlinear/descent_strategies/ADAM.cpp index 156640f..3233fd0 100644 --- a/src/polysolve/nonlinear/descent_strategies/ADAM.cpp +++ b/src/polysolve/nonlinear/descent_strategies/ADAM.cpp @@ -6,22 +6,26 @@ namespace polysolve::nonlinear { ADAM::ADAM(const json &solver_params, + const bool is_stochastic, const double characteristic_length, spdlog::logger &logger) - : Superclass(solver_params, characteristic_length, logger) + : Superclass(solver_params, characteristic_length, logger), is_stochastic_(is_stochastic) { - alpha = solver_params["ADAM"]["alpha"]; - beta_1 = solver_params["ADAM"]["beta_1"]; - beta_2 = solver_params["ADAM"]["beta_2"]; - epsilon = solver_params["ADAM"]["epsilon"]; + std::string param_name = is_stochastic ? "StochasticADAM" : "ADAM"; + alpha_ = solver_params[param_name]["alpha"]; + beta_1_ = solver_params[param_name]["beta_1"]; + beta_2_ = solver_params[param_name]["beta_2"]; + epsilon_ = solver_params[param_name]["epsilon"]; + if (is_stochastic) + erase_component_probability_ = solver_params["StochasticADAM"]["erase_component_probability"]; } void ADAM::reset(const int ndof) { Superclass::reset(ndof); - m_prev = Eigen::VectorXd::Zero(ndof); - v_prev = Eigen::VectorXd::Zero(ndof); - t = 0; + m_prev_ = Eigen::VectorXd::Zero(ndof); + v_prev_ = Eigen::VectorXd::Zero(ndof); + t_ = 0; } bool ADAM::compute_update_direction( @@ -30,24 +34,33 @@ namespace polysolve::nonlinear const TVector &grad, TVector &direction) { - if (m_prev.size() == 0) - m_prev = Eigen::VectorXd::Zero(x.size()); - if (v_prev.size() == 0) - v_prev = Eigen::VectorXd::Zero(x.size()); + if (m_prev_.size() == 0) + m_prev_ = Eigen::VectorXd::Zero(x.size()); + if (v_prev_.size() == 0) + v_prev_ = Eigen::VectorXd::Zero(x.size()); - TVector m = (beta_1 * m_prev) + ((1 - beta_1) * grad); - TVector v = beta_2 * v_prev; + TVector grad_modified = grad; + + if (is_stochastic_) + { + Eigen::VectorXd mask = (Eigen::VectorXd::Random(direction.size()).array() + 1.) / 2.; + for (int i = 0; i < direction.size(); ++i) + grad_modified(i) *= (mask(i) < erase_component_probability_) ? 0. : 1.; + } + + TVector m = (beta_1_ * m_prev_) + ((1 - beta_1_) * grad_modified); + TVector v = beta_2_ * v_prev_; for (int i = 0; i < v.size(); ++i) - v(i) += (1 - beta_2) * grad(i) * grad(i); + v(i) += (1 - beta_2_) * grad_modified(i) * grad_modified(i); - m = m.array() / (1 - pow(beta_1, t)); - v = v.array() / (1 - pow(beta_2, t)); + m = m.array() / (1 - pow(beta_1_, t_)); + v = v.array() / (1 - pow(beta_2_, t_)); - direction = -alpha * m; + direction = -alpha_ * m; for (int i = 0; i < v.size(); ++i) - direction(i) /= sqrt(v(i) + epsilon); + direction(i) /= sqrt(v(i) + epsilon_); - ++t; + ++t_; return true; } diff --git a/src/polysolve/nonlinear/descent_strategies/ADAM.hpp b/src/polysolve/nonlinear/descent_strategies/ADAM.hpp index a86cb1c..b047aea 100644 --- a/src/polysolve/nonlinear/descent_strategies/ADAM.hpp +++ b/src/polysolve/nonlinear/descent_strategies/ADAM.hpp @@ -13,10 +13,11 @@ namespace polysolve::nonlinear using Superclass = DescentStrategy; ADAM(const json &solver_params, + const bool is_stochastic, const double characteristic_length, spdlog::logger &logger); - std::string name() const override { return "ADAM"; } + std::string name() const override { return is_stochastic_ ? "StochasticADAM" : "ADAM"; } void reset(const int ndof) override; @@ -26,14 +27,19 @@ namespace polysolve::nonlinear const TVector &grad, TVector &direction) override; + bool is_direction_descent() override { return false; } + private: - TVector m_prev; - TVector v_prev; + TVector m_prev_; + TVector v_prev_; + + double beta_1_, beta_2_; + double alpha_; - double beta_1, beta_2; - double alpha; + int t_ = 0; + double epsilon_; - int t = 0; - double epsilon; + bool is_stochastic_; + double erase_component_probability_ = 0; }; } // namespace polysolve::nonlinear