From f8800aa9eac370f746c34f7abc5f97adc7997ddf Mon Sep 17 00:00:00 2001 From: Vincent Roulet Date: Thu, 7 Nov 2024 11:11:07 -0800 Subject: [PATCH] Revisiting linesearches and LBFGS. For backtracking linesearch: - Add debugging option for backtracking_linesearch. - Add info entry in BacktrackingLinesearchState to potentially help debugging by looking at outputs (could be useful for example in vmap setting, and mimics the setup for the zoom linesearch). - Adding mechanism to prevent the linesearch to make a step if that would end up getting NaNs or infinite values in the function. For zoom_linesearch: - Simplifies a bit the debugging information for the zoom linesearch and added prints of some relevant values for debugging. - Added a note in the zoom linesearch that using curv_tol=inf, would let this method make an efficient alternative to the backtracking linesearch using polynomial interpolation strategies. - Most importantly, added an option to define the initial guess for the linesearch. Looking up Nocedal and Wright, this initial guess should always be one for Newton or quasi-Newton methods. Could be refined for other methods (for now, for such other methods like gradient descent, we may simply keep the previous learning rate). This largely improved the performance in the public notebook. For lbfgs: - Use clipped gradient step for the very first step (when scale_init_precond=True). The scale of the preconditioner for the very first iteration is not detailed anywhere in the literature I've seen. But using such clipped gradient step ensures to capture approximately the right scale. This made for example one of the tests pass without any further modifications of the default hyperparameters of the objective. - Revised the notebook in view fo these changes. Added some tips and an example of benchmark. PiperOrigin-RevId: 694183028 --- examples/lbfgs.ipynb | 168 +++++++++++++++++++++++---- optax/_src/alias.py | 22 +++- optax/_src/alias_test.py | 39 +++---- optax/_src/linesearch.py | 243 +++++++++++++++++++++++++++------------ optax/_src/transform.py | 18 ++- 5 files changed, 369 insertions(+), 121 deletions(-) diff --git a/examples/lbfgs.ipynb b/examples/lbfgs.ipynb index 0fbe5819..eac62488 100644 --- a/examples/lbfgs.ipynb +++ b/examples/lbfgs.ipynb @@ -11,7 +11,7 @@ "L-BFGS is a classical optimization method that uses past gradients and parameters information to iteratively refine a solution to a minimization problem. In this notebook, we illustrate\n", "1. how to use L-BFGS as a simple gradient transformation,\n", "2. how to wrap L-BFGS in a solver, and how linesearches are incorporated,\n", - "3. how to debug the solver if needed,\n" + "3. how to debug the solver if needed.\n" ] }, { @@ -146,7 +146,7 @@ "\n", "where $c_1$ is some constant set to $10^{-4}$ by default. Consider for example the update direction to be $u_k = -g_k$, i.e., moving along the negative gradient direction. In that case the criterion above reduces to $f(w_k - \\eta_k g_k) \\leq f(w_k) - c_1 \\eta_k ||g_k||_2^2$. The criterion amounts then to choosing the stepsize such that it decreases the objective by an amount proportional to the squared gradient norm.\n", "\n", - "As long as the update direction is a *descent direction*, that is, $\\langle u_k, g_k\\rangle < 0$ the above criterion is guaranteed to be satisfied by some sufficiently small stepsize.\n", + "As long as the update direction is a *descent direction*, that is, $\\langle u_k, g_k\\rangle \u003c 0$ the above criterion is guaranteed to be satisfied by some sufficiently small stepsize.\n", "A simple linesearch technique to ensure a sufficient decrease is then to decrease a candidate stepsize by a constant factor up until the criterion is satisfied. This amounts to the backtracking linesearch implemented in {py:func}`optax.scale_by_backtracking_linesearch` and briefly reviewed below.\n", "\n", "#### Small curvature (Strong wolfe criterion)\n", @@ -286,7 +286,7 @@ }, "outputs": [], "source": [ - "def run_lbfgs(init_params, fun, opt, max_iter, tol):\n", + "def run_opt(init_params, fun, opt, max_iter, tol):\n", " value_and_grad_fun = optax.value_and_grad_from_state(fun)\n", "\n", " def step(carry):\n", @@ -303,7 +303,7 @@ " iter_num = otu.tree_get(state, 'count')\n", " grad = otu.tree_get(state, 'grad')\n", " err = otu.tree_l2_norm(grad)\n", - " return (iter_num == 0) | ((iter_num < max_iter) & (err >= tol))\n", + " return (iter_num == 0) | ((iter_num \u003c max_iter) \u0026 (err \u003e= tol))\n", "\n", " init_carry = (init_params, opt.init(init_params))\n", " final_params, final_state = jax.lax.while_loop(\n", @@ -338,7 +338,7 @@ " f'Initial value: {fun(init_params):.2e} '\n", " f'Initial gradient norm: {otu.tree_l2_norm(jax.grad(fun)(init_params)):.2e}'\n", ")\n", - "final_params, _ = run_lbfgs(init_params, fun, opt, max_iter=50, tol=1e-3)\n", + "final_params, _ = run_opt(init_params, fun, opt, max_iter=100, tol=1e-3)\n", "print(\n", " f'Final value: {fun(final_params):.2e}, '\n", " f'Final gradient norm: {otu.tree_l2_norm(jax.grad(fun)(final_params)):.2e}'\n", @@ -395,7 +395,7 @@ " f'Initial value: {fun(init_params):.2e} '\n", " f'Initial gradient norm: {otu.tree_l2_norm(jax.grad(fun)(init_params)):.2e}'\n", ")\n", - "final_params, _ = run_lbfgs(init_params, fun, opt, max_iter=50, tol=1e-3)\n", + "final_params, _ = run_opt(init_params, fun, opt, max_iter=100, tol=1e-3)\n", "print(\n", " f'Final value: {fun(final_params):.2e}, '\n", " f'Final gradient norm: {otu.tree_l2_norm(jax.grad(fun)(final_params)):.2e}'\n", @@ -408,9 +408,8 @@ "id": "KZIu7UDveO6D" }, "source": [ - "## Debugging solver\n", - "\n", - "In some cases, L-BFGS with a linesearch as a solver will fail. Most of the times, the culprit goes down to the linesearch. To debug the solver in such cases, we provide a `verbose` option to the `optax.scale_by_zoom_linesearch`. We show below how to proceed." + "## Debugging\n", + "\n" ] }, { @@ -419,7 +418,11 @@ "id": "LV8CslWpoDDq" }, "source": [ - "First we try to minimize the [Zakharov function](https://www.sfu.ca/~ssurjano/zakharov.html) without any changes. You'll observe that the final value is larger than the initial value which points out that the solver failed, and probably because the linesearch did not find a stepsize that ensured a sufficient decrease." + "### Accessing debug information\n", + "\n", + "In some cases, L-BFGS with a linesearch as a solver will fail. Most of the times, the culprit goes down to the linesearch. To debug the solver in such cases, we provide a `verbose` option to the `optax.scale_by_zoom_linesearch`. We show below how to proceed.\n", + "\n", + "To demonstrate such bug, we try to minimize the [Zakharov function](https://www.sfu.ca/~ssurjano/zakharov.html) and set the `scale_init_precond` option to `False` (by choosing the default option `scale_init_precond=True`, the algorithm would actually run fine, we just want to showcase the possibility to use debugging in the linesearch here). You'll observe that the final value is is the same as the initial value which points out that the solver failed." ] }, { @@ -436,14 +439,14 @@ " sum2 = (0.5 * ii * w).sum()\n", " return sum1 + sum2**2 + sum2**4\n", "\n", - "opt = optax.chain(print_info(), optax.lbfgs())\n", + "opt = optax.lbfgs(scale_init_precond=False)\n", "\n", "init_params = jnp.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e4])\n", "print(\n", " f'Initial value: {fun(init_params)} '\n", " f'Initial gradient norm: {otu.tree_l2_norm(jax.grad(fun)(init_params))}'\n", ")\n", - "final_params, _ = run_lbfgs(init_params, fun, opt, max_iter=50, tol=1e-3)\n", + "final_params, _ = run_opt(init_params, fun, opt, max_iter=50, tol=1e-3)\n", "print(\n", " f'Final value: {fun(final_params)}, '\n", " f'Final gradient norm: {otu.tree_l2_norm(jax.grad(fun)(final_params))}'\n", @@ -456,7 +459,11 @@ "id": "uwcbY5UXohZB" }, "source": [ - "We can change the linesearch used in lbfgs as part of its arguments. Here we keep the default number of linesearch steps (15) and set the verbose option to `True`." + "The default implementation of the linesearch in the code is\n", + "```\n", + "scale_by_zoom_linesearch(max_linesearch_steps=20, initial_guess_strategy='one')\n", + "```\n", + "To debug we can set the verbose option of the linesearch to `True`." ] }, { @@ -467,9 +474,9 @@ }, "outputs": [], "source": [ - "opt = optax.chain(print_info(), optax.lbfgs(\n", + "opt = optax.chain(print_info(), optax.lbfgs(scale_init_precond=False,\n", " linesearch=optax.scale_by_zoom_linesearch(\n", - " max_linesearch_steps=15, verbose=True\n", + " max_linesearch_steps=20, verbose=True, initial_guess_strategy='one'\n", " )\n", "))\n", "\n", @@ -478,7 +485,7 @@ " f'Initial value: {fun(init_params):.2e} '\n", " f'Initial gradient norm: {otu.tree_l2_norm(jax.grad(fun)(init_params)):.2e}'\n", ")\n", - "final_params, _ = run_lbfgs(init_params, fun, opt, max_iter=50, tol=1e-3)\n", + "final_params, _ = run_opt(init_params, fun, opt, max_iter=100, tol=1e-3)\n", "print(\n", " f'Final value: {fun(final_params):.2e}, '\n", " f'Final gradient norm: {otu.tree_l2_norm(jax.grad(fun)(final_params)):.2e}'\n", @@ -491,7 +498,7 @@ "id": "nCgpjzCbo7p9" }, "source": [ - "As expected, the linesearch failed at the very first step taking a stepsize that did not ensure a sufficient decrease. Multiple information is displayed. For example, the slope (derivative along the update direction) at the first step si extremely large which explains the difficulties to find an appropriate stepsize. As pointed out in the log above, the first thing to try is to use a larger number of linesearch steps." + "As expected, the linesearch failed at the very first step taking a stepsize that did not ensure a sufficient decrease. Multiple information is displayed. For example, the slope (derivative along the update direction) at the first step is extremely large which explains the difficulties to find an appropriate stepsize. As pointed out in the log above, the first thing to try is to use a larger number of linesearch steps." ] }, { @@ -502,9 +509,9 @@ }, "outputs": [], "source": [ - "opt = optax.chain(print_info(), optax.lbfgs(\n", + "opt = optax.chain(print_info(), optax.lbfgs(scale_init_precond=False,\n", " linesearch=optax.scale_by_zoom_linesearch(\n", - " max_linesearch_steps=50, verbose=True\n", + " max_linesearch_steps=50, verbose=True, initial_guess_strategy='one'\n", " )\n", "))\n", "\n", @@ -513,7 +520,7 @@ " f'Initial value: {fun(init_params):.2e} '\n", " f'Initial gradient norm: {otu.tree_l2_norm(jax.grad(fun)(init_params)):.2e}'\n", ")\n", - "final_params, _ = run_lbfgs(init_params, fun, opt, max_iter=50, tol=1e-3)\n", + "final_params, _ = run_opt(init_params, fun, opt, max_iter=100, tol=1e-3)\n", "print(\n", " f'Final value: {fun(final_params):.2e}, '\n", " f'Final gradient norm: {otu.tree_l2_norm(jax.grad(fun)(final_params)):.2e}'\n", @@ -526,9 +533,128 @@ "id": "na-7s1Q2o1Rc" }, "source": [ - "By simply taking a maximum of 50 steps of the linesearch instead of 15, we ensured that the first stepsize taken provided a sufficient decrease and the solver worked well.\n", + "By simply taking a maximum of 50 steps of the linesearch instead of 20, we ensured that the first stepsize taken provided a sufficient decrease and the solver worked well.\n", "Additional debugging information can be found in the source code accessible from the docs of {py:func}`optax.scale_by_zoom_linesearch`." ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "74ZbgzcKoJ0J" + }, + "source": [ + "### Tips\n", + "\n", + "- **LBFGS**\n", + " - Selecting a higher `memory_size` in lbfgs may improve performance at a memory and computational cost. No real gains may be perceived after some value.\n", + " - `scale_init_precond=True` is standard. It captures a similar scale as other well-known optimization methods like Barzilai Borwein.\n", + "\n", + "- **Zoom linesearch**\n", + " - Remember there are two conditions to be met (sufficient decrease and small curvature). If the algorithm takes too many linesearch steps, you may try\n", + " setting `curv_rtol = jnp.inf`, effectively ignoring the small curvature condition. The resulting algorithm will essentially perform a backtracking linesearch using interpolation mechanisms to search for a valid stepsize (so that would be a potentially faster algorithm than the current implementation of `scale_by_backtracking_linesearch`).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "T-oGa3P2sCbH" + }, + "source": [ + "## Contributing and benchmarking\n", + "\n", + "Numerous other linesearch could be implemented, as well as other solvers for medium scale problems without stochasticity. Contributions are welcome.\n", + "\n", + "If you want to contribute a new solver for medium scale problems like LBFGS, benchmarks would be highly appreciated. We provide below an example of benchmark (which could also be used if you want to test some hyperparameters of the algorithm). We take here the classical Rosenbroke function, but it could be better to expand such benchmarks to e.g. the set of test functions given by [Andrei, 2008](https://camo.ici.ro/journal/vol10/v10a10.pdf)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "MagDCuGjsB5x" + }, + "outputs": [], + "source": [ + "import time\n", + "num_fun_calls = 0\n", + "\n", + "def register_call():\n", + " global num_fun_calls\n", + " num_fun_calls += 1\n", + "\n", + "def test_hparams(lbfgs_hparams, linesearch_hparams, dimension=512):\n", + " global num_fun_calls\n", + " num_fun_calls = 0\n", + "\n", + " def fun(x):\n", + " jax.debug.callback(register_call)\n", + " return jnp.sum((x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)\n", + "\n", + " opt = optax.chain(optax.lbfgs(**lbfgs_hparams,\n", + " linesearch=optax.scale_by_zoom_linesearch(**linesearch_hparams)\n", + " )\n", + " )\n", + "\n", + " init_params = jnp.arange(dimension, dtype=jnp.float32)\n", + "\n", + " tic = time.time()\n", + " final_params, _ = run_opt(\n", + " init_params, fun, opt, max_iter=500, tol=5*1e-5\n", + " )\n", + " final_params = jax.block_until_ready(final_params)\n", + " time_run = time.time() - tic\n", + "\n", + " final_value = fun(final_params)\n", + " final_grad_norm = otu.tree_l2_norm(jax.grad(fun)(final_params))\n", + " return final_value, final_grad_norm, num_fun_calls, time_run\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7CXMxWsztGf5" + }, + "outputs": [], + "source": [ + "import copy\n", + "import matplotlib.pyplot as plt\n", + "\n", + "default_lbfgs_hparams = {'memory_size': 15, 'scale_init_precond': True}\n", + "default_linesearch_hparams = {\n", + " 'max_linesearch_steps': 15,\n", + " 'initial_guess_strategy': 'one'\n", + "}\n", + "\n", + "memory_sizes = [int(2**i) for i in range(7)]\n", + "times = []\n", + "calls = []\n", + "values = []\n", + "grad_norms = []\n", + "for m in memory_sizes:\n", + " lbfgs_hparams = copy.deepcopy(default_lbfgs_hparams)\n", + " lbfgs_hparams['memory_size'] = m\n", + " v, g, n, t = test_hparams(lbfgs_hparams, default_linesearch_hparams, dimension=1024)\n", + " values.append(v)\n", + " grad_norms.append(g)\n", + " calls.append(n)\n", + " times.append(t)\n", + "\n", + "fig, axs = plt.subplots(1, 4, figsize=(16, 4))\n", + "axs[0].plot(memory_sizes, values)\n", + "axs[0].set_ylabel('Final values')\n", + "axs[0].set_yscale('log')\n", + "axs[1].plot(memory_sizes, grad_norms)\n", + "axs[1].set_ylabel('Final gradient norms')\n", + "axs[1].set_yscale('log')\n", + "axs[2].plot(memory_sizes, calls)\n", + "axs[2].set_ylabel('Number of function calls')\n", + "axs[3].plot(memory_sizes, times)\n", + "axs[3].set_ylabel('Run times')\n", + "for i in range(4):\n", + " axs[i].set_xlabel('Memory size')\n", + "plt.tight_layout()" + ] } ], "metadata": { diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 118910c5..a0dbf4db 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -2393,7 +2393,9 @@ def lbfgs( scale_init_precond: bool = True, linesearch: Optional[ base.GradientTransformationExtraArgs - ] = _linesearch.scale_by_zoom_linesearch(max_linesearch_steps=15), + ] = _linesearch.scale_by_zoom_linesearch( + max_linesearch_steps=20, initial_guess_strategy='one' + ), ) -> base.GradientTransformationExtraArgs: r"""L-BFGS optimizer. @@ -2453,7 +2455,7 @@ def lbfgs( memory_size: number of past updates to keep in memory to approximate the Hessian inverse. scale_init_precond: whether to use a scaled identity as the initial - preconditioner, see formula above. + preconditioner, see formula of :math:`\gamma_k` above. linesearch: an instance of :class:`optax.GradientTransformationExtraArgs` such as :func:`optax.scale_by_zoom_linesearch` that computes a learning rate, a.k.a. stepsize, to satisfy some criterion such as a @@ -2480,9 +2482,9 @@ def lbfgs( ... ) ... params = optax.apply_updates(params, updates) ... print('Objective function: ', f(params)) - Objective function: 0.0 - Objective function: 0.0 - Objective function: 0.0 + Objective function: 7.5166864 + Objective function: 7.460699e-14 + Objective function: 2.6505726e-28 Objective function: 0.0 Objective function: 0.0 @@ -2504,6 +2506,16 @@ def lbfgs( zoom linesearch). See example above for best use in a non-stochastic setting, where we can recycle gradients computed by the linesearch using :func:`optax.value_and_grad_from_state`. + + .. note:: + For the very first step of the algorithm, the scaling of the identity + matrix (:math:`\gamma_k` in the formula above) is not defined. + We consider scaling by a capped reciprocal of the gradient norm. This avoids + wasting many iterations of linesearch for the first iteration by taking + into account the order of magnitude of the gradients. In other words we + constraint the trust-region of the first step to be a Euclidean ball of + radius 1 at the first iteration. The choice of :math:`\gamma_0` is not + detailed in the references above, so this is a heuristic choice. """ if learning_rate is None: base_scaling = transform.scale(-1.0) diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index 0969386b..b03471b9 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -27,7 +27,6 @@ import numpy as np from optax._src import alias from optax._src import base -from optax._src import linesearch as _linesearch from optax._src import numerics from optax._src import transform from optax._src import update @@ -317,7 +316,7 @@ def test_gradient_accumulation(self, opt_name, opt_kwargs, dtype): # LBFGS -def _run_lbfgs_solver( +def _run_opt( opt: base.GradientTransformationExtraArgs, fun: Callable[[chex.ArrayTree], jnp.ndarray], init_params: chex.ArrayTree, @@ -407,7 +406,7 @@ def _plain_preconditioning( m = len(dws) if m == 0: - return updates + return identity_scale * updates dws = jnp.array(dws) dus = jnp.array(dus) @@ -465,9 +464,12 @@ def _plain_lbfgs( dus = [] for it in range(maxiter): - if scale_init_precond and it > 0: - identity_scale = jnp.vdot(dus[-1], dws[-1]) - identity_scale /= jnp.sum(dus[-1] ** 2) + if scale_init_precond: + if it == 0: + identity_scale = jnp.minimum(1.0, 1.0 / jnp.sqrt(jnp.sum(g**2))) + else: + identity_scale = jnp.vdot(dus[-1], dws[-1]) + identity_scale /= jnp.sum(dus[-1] ** 2) else: identity_scale = 1.0 @@ -688,7 +690,7 @@ def test_against_plain_implementation( scale_init_precond=scale_init_precond, linesearch=None, ) - lbfgs_sol, _ = _run_lbfgs_solver( + lbfgs_sol, _ = _run_opt( opt, fun, init_params, maxiter=maxiter, tol=tol ) expected_lbfgs_sol = _plain_lbfgs( @@ -719,8 +721,8 @@ def fun(x): init_tree = (init_array[0], init_array[1]) opt = alias.lbfgs() - sol_arr, _ = _run_lbfgs_solver(opt, fun, init_array, maxiter=3) - sol_tree, _ = _run_lbfgs_solver(opt, fun, init_tree, maxiter=3) + sol_arr, _ = _run_opt(opt, fun, init_array, maxiter=3) + sol_tree, _ = _run_opt(opt, fun, init_tree, maxiter=3) sol_tree = jnp.stack((sol_tree[0], sol_tree[1])) chex.assert_trees_all_close(sol_arr, sol_tree, rtol=5 * 1e-5, atol=5 * 1e-5) @@ -744,7 +746,7 @@ def fun(params): init_params = (weights_init, biases_init) opt = alias.lbfgs(scale_init_precond=scale_init_precond) - sol, _ = _run_lbfgs_solver(opt, fun, init_params, tol=1e-3) + sol, _ = _run_opt(opt, fun, init_params, tol=1e-3) # Check optimality conditions. self.assertLessEqual(otu.tree_l2_norm(jax.grad(fun)(sol)), 1e-2) @@ -766,7 +768,7 @@ def fun(weights): init_params = jnp.zeros(inputs.shape[1]) opt = alias.lbfgs(scale_init_precond=scale_init_precond) - sol, _ = _run_lbfgs_solver(opt, fun, init_params, tol=1e-6) + sol, _ = _run_opt(opt, fun, init_params, tol=1e-6) # Check optimality conditions. self.assertLessEqual(otu.tree_l2_norm(jax.grad(fun)(sol)), 1e-2) @@ -803,15 +805,8 @@ def test_against_scipy(self, problem_name: str): init_params = problem['init'] jnp_fun, np_fun = problem['fun'], problem['numpy_fun'] - if problem_name == 'zakharov': - opt = alias.lbfgs( - linesearch=_linesearch.scale_by_zoom_linesearch( - max_linesearch_steps=30 - ) - ) - else: - opt = alias.lbfgs() - optax_sol, _ = _run_lbfgs_solver( + opt = alias.lbfgs() + optax_sol, _ = _run_opt( opt, jnp_fun, init_params, maxiter=500, tol=tol ) scipy_sol = scipy_optimize.minimize(np_fun, init_params, method='BFGS').x @@ -845,7 +840,7 @@ def test_minimize_bad_initialization(self): jnp_fun, np_fun = problem['fun'], problem['numpy_fun'] minimum = problem['minimum'] opt = alias.lbfgs() - optax_sol, _ = _run_lbfgs_solver(opt, jnp_fun, init_params, tol=tol) + optax_sol, _ = _run_opt(opt, jnp_fun, init_params, tol=tol) scipy_sol = scipy_optimize.minimize( fun=np_fun, jac=jax.grad(np_fun), @@ -867,7 +862,7 @@ def fun(x): return jnp.mean((mat @ x) ** 2) opt = alias.lbfgs() - sol, _ = _run_lbfgs_solver(opt, fun, init_params=jnp.ones(n), tol=tol) + sol, _ = _run_opt(opt, fun, init_params=jnp.ones(n), tol=tol) chex.assert_trees_all_close(sol, jnp.zeros(n), atol=tol, rtol=tol) diff --git a/optax/_src/linesearch.py b/optax/_src/linesearch.py index c077ed89..02e2ef4f 100644 --- a/optax/_src/linesearch.py +++ b/optax/_src/linesearch.py @@ -27,6 +27,20 @@ import optax.tree_utils as otu +class BacktrackingLinesearchInfo(NamedTuple): + """Information about the backtracking linesearch step, for debugging. + + Attributes: + num_linesearch_steps: number of linesearch steps. + decrease_error: error of the decrease criterion at the end of the + linesearch. A positive value indicates that + the linesearch failed to find a stepsize that ensures a sufficient + decrease. A null value indicates it succeeded in finding such a stepsize. + """ + num_linesearch_steps: int + decrease_error: Union[float, chex.Numeric] + + class ScaleByBacktrackingLinesearchState(NamedTuple): """State for :func:`optax.scale_by_backtracking_linesearch`. @@ -39,20 +53,22 @@ class ScaleByBacktrackingLinesearchState(NamedTuple): line-search if the line-search is instantiated with store_grad = True. Otherwise it is None. Can be reused using :func:`optax.value_and_grad_from_state`. + info: information about the backtracking linesearch step, for debugging. """ learning_rate: Union[float, jax.Array] value: Union[float, jax.Array] - grad: Optional[base.Updates] = None + grad: Optional[base.Updates] + info: BacktrackingLinesearchInfo -class BacktrackingSearchState(NamedTuple): +class BacktrackingLineSearchState(NamedTuple): """State during the inner loop of a backtracking line-search.""" learning_rate: Union[float, jax.Array] new_value: Union[float, jax.Array] new_grad: base.Updates - accepted: bool + decrease_error: chex.Numeric iter_num: Union[int, jax.Array] @@ -65,6 +81,7 @@ def scale_by_backtracking_linesearch( atol: float = 0.0, rtol: float = 0.0, store_grad: bool = False, + verbose: bool = False, ) -> base.GradientTransformationExtraArgs: r"""Backtracking line-search ensuring sufficient decrease (Armijo criterion). @@ -108,6 +125,7 @@ def scale_by_backtracking_linesearch( that, we can directly reuse the value and the gradient computed at the end of the linesearch for the next iteration using :func:`optax.value_and_grad_from_state`. See the example above. + verbose: whether to print debugging information. Returns: A :class:`GradientTransformationExtraArgs`, where the ``update`` function @@ -237,14 +255,25 @@ def init_fn(params: base.Params) -> ScaleByBacktrackingLinesearchState: learning_rate=jnp.array(1.0), value=jnp.array(jnp.inf), grad=grad, + info=BacktrackingLinesearchInfo( + num_linesearch_steps=0, + decrease_error=jnp.array(jnp.inf), + ), ) - def _check_criterion(learning_rate, slope, value, new_value): - violation = ( - new_value - (1 + rtol) * value - learning_rate * slope_rtol * slope + def _compute_decrease_error( + stepsize: chex.Numeric, + slope: chex.Numeric, + value: chex.Numeric, + new_value: chex.Numeric, + ) -> chex.Numeric: + decrease_error = ( + new_value - (1.0 + rtol) * value - stepsize * slope_rtol * slope + ) + decrease_error = jnp.where( + jnp.isnan(decrease_error), jnp.inf, decrease_error ) - violation = jnp.where(jnp.isnan(violation), jnp.inf, violation) - return violation <= atol + return jnp.maximum(decrease_error, 0.0) def update_fn( updates: base.Updates, @@ -293,16 +322,16 @@ def update_fn( slope = otu.tree_vdot(updates, grad) def cond_fn( - search_state: BacktrackingSearchState, - ) -> Union[int, jax._src.basearray.Array]: + search_state: BacktrackingLineSearchState, + ): """Whether to stop the line-search inner loop.""" - accepted = search_state.accepted + decrease_error = search_state.decrease_error iter_num = search_state.iter_num - return (~accepted) & (iter_num <= max_backtracking_steps) + return (~(decrease_error <= atol)) & (iter_num <= max_backtracking_steps) def body_fn( - search_state: BacktrackingSearchState, - ) -> BacktrackingSearchState: + search_state: BacktrackingLineSearchState, + ) -> BacktrackingLineSearchState: """Line-search inner loop step.""" learning_rate = search_state.learning_rate new_grad = search_state.new_grad @@ -320,11 +349,13 @@ def body_fn( # compute the gradient by transposing the jvp. new_value, jvp_value_fn = jax.linearize(value_fn_, new_params) - accepted = _check_criterion(learning_rate, slope, value, new_value) + decrease_error = _compute_decrease_error( + learning_rate, slope, value, new_value + ) # If the line-search ends, we get the gradient for the new round of # line-search. new_grad = jax.lax.cond( - accepted | (iter_num == max_backtracking_steps), + (decrease_error <= atol) | (iter_num == max_backtracking_steps), lambda p: jax.linear_transpose(jvp_value_fn, p)(1.0)[0], lambda *_: new_grad, new_params, @@ -332,12 +363,14 @@ def body_fn( else: # Here we just compute the value and leave the gradient as is new_value = value_fn_(new_params) - accepted = _check_criterion(learning_rate, slope, value, new_value) - search_state = BacktrackingSearchState( + decrease_error = _compute_decrease_error( + learning_rate, slope, value, new_value + ) + search_state = BacktrackingLineSearchState( learning_rate=learning_rate, new_value=new_value, new_grad=new_grad, - accepted=accepted, + decrease_error=decrease_error, iter_num=iter_num + 1, ) return search_state @@ -347,12 +380,11 @@ def body_fn( learning_rate = jnp.minimum( increase_factor * state.learning_rate, max_learning_rate ) - - search_state = BacktrackingSearchState( + search_state = BacktrackingLineSearchState( learning_rate=learning_rate, new_value=value, new_grad=otu.tree_zeros_like(params), - accepted=False, + decrease_error=jnp.array(jnp.inf), iter_num=0, ) search_state = jax.lax.while_loop(cond_fn, body_fn, search_state) @@ -361,15 +393,43 @@ def body_fn( # otu.tree_zeros_like(params)) new_grad = search_state.new_grad if store_grad else None new_value = search_state.new_value - new_learning_rate = search_state.learning_rate + # If the decrease error is infinite, we avoid making any step (which would + # result in nan or infinite values): we set the learning rate to 0. + new_learning_rate = jnp.where( + jnp.isinf(search_state.decrease_error), 0., search_state.learning_rate + ) + if verbose: + # We print information only if the linesearch failed. + _cond_print( + search_state.decrease_error > atol, + "INFO: optax.scale_by_backtracking_linesearch:\n" + "Backtracking linesearch failed to find a stepsize ensuring sufficent" + " decrease.\n" + "Value at current params: {value},\n" + "Slope along update direction: {slope}\n" + "Stepsize: {stepsize}\n" + "Decrease Error: {decrease_error}", + stepsize=search_state.learning_rate, + decrease_error=search_state.decrease_error, + value=value, + slope=slope, + ) + _cond_print( + jnp.isinf(search_state.decrease_error), + "Using a stepsize of 0 to avoid infinite or nan values.", + ) # At the end, we just scale the updates with the learning rate found. new_updates = otu.tree_scalar_mul(new_learning_rate, updates) - + info = BacktrackingLinesearchInfo( + num_linesearch_steps=search_state.iter_num, + decrease_error=search_state.decrease_error, + ) new_state = ScaleByBacktrackingLinesearchState( learning_rate=new_learning_rate, value=new_value, grad=new_grad, + info=info ) return new_updates, new_state @@ -380,7 +440,7 @@ def _cond_print(condition, message, **kwargs): """Prints message if condition is true.""" jax.lax.cond( condition, - lambda _: jax.debug.print(message, **kwargs), + lambda _: jax.debug.print(message, **kwargs, ordered=True), lambda _: None, None, ) @@ -641,7 +701,7 @@ def _value_and_slope_on_line( slope_step = otu.tree_vdot(grad_step, updates) return step, value_step, grad_step, slope_step - def _decrease_error( + def _compute_decrease_error( stepsize: chex.Numeric, value_step: chex.Numeric, slope_step: chex.Numeric, @@ -684,7 +744,7 @@ def _decrease_error( ) return decrease_error - def _curvature_error( + def _compute_curvature_error( slope_step: chex.Numeric, slope_init: chex.Numeric ) -> chex.Numeric: """Compute curvature error.""" @@ -710,11 +770,34 @@ def _try_safe_step( [state.stepsize, state.value, state.grad], ) if verbose: - too_small_int = jnp.abs(state.low - state.high) <= interval_threshold - _cond_print(too_small_int, FLAG_INTERVAL_TOO_SMALL) + jax.debug.print( + "INFO: optax.scale_by_zoom_linesearch:\n" + "Value at current params: {value_init}\n" + "Slope along update direction: {slope_init}\n" + "Stepsize reached: {stepsize}\n" + "Decrease Error: {decrease_error}\n" + "Curvature Error: {curvature_error}", + value_init=state.value_init, + slope_init=state.slope_init, + stepsize=state.stepsize, + decrease_error=state.decrease_error, + curvature_error=state.curvature_error, + ordered=True, + ) + interval_length = jnp.abs(state.low - state.high) + too_small_int = interval_length <= interval_threshold + _cond_print( + too_small_int, + FLAG_INTERVAL_TOO_SMALL + " Interval length: {interval_length}.", + interval_length=interval_length, + ) jax.lax.cond( state.safe_stepsize > 0.0, - lambda *_: jax.debug.print(FLAG_CURVATURE_COND_NOT_SATISFIED), + lambda _: jax.debug.print( + FLAG_CURVATURE_COND_NOT_SATISFIED + + " Stepsize ensuring sufficient decrease: {safe_stepsize}.", + safe_stepsize=state.safe_stepsize, + ), _failure_diagnostic, state, ) @@ -761,10 +844,10 @@ def _search_interval( value_and_grad_fn, params, new_stepsize, updates, fn_kwargs ) - decrease_error = _decrease_error( + decrease_error = _compute_decrease_error( new_stepsize, new_value_step, new_slope_step, value_init, slope_init ) - curvature_error = _curvature_error(new_slope_step, slope_init) + curvature_error = _compute_curvature_error(new_slope_step, slope_init) new_error = jnp.maximum(decrease_error, curvature_error) # If the new point satisfies at least the decrease error we keep it @@ -824,7 +907,20 @@ def _search_interval( if verbose: _cond_print( (max_stepsize_reached & ~interval_found), - FLAG_INTERVAL_NOT_FOUND + "\n" + FLAG_CURVATURE_COND_NOT_SATISFIED, + "INFO: optax.scale_by_zoom_linesearch:\n" + "Value at current params: {value_init}\n" + "Slope along update direction: {slope_init}\n" + "Stepsize reached: {stepsize}\n" + "Decrease Error: {decrease_error}\n" + "Curvature Error: {curvature_error}" + + FLAG_INTERVAL_NOT_FOUND + + "\n" + + FLAG_CURVATURE_COND_NOT_SATISFIED, + value_init=value_init, + slope_init=slope_init, + stepsize=new_stepsize, + decrease_error=decrease_error, + curvature_error=curvature_error, ) failed = (iter_num + 1 >= max_linesearch_steps) & (~done) @@ -932,10 +1028,10 @@ def _zoom_into_interval( value_and_grad_fn, params, middle, updates, fn_kwargs ) - decrease_error = _decrease_error( + decrease_error = _compute_decrease_error( middle, value_middle, slope_middle, value_init, slope_init ) - curvature_error = _curvature_error(slope_middle, slope_init) + curvature_error = _compute_curvature_error(slope_middle, slope_init) new_error = jnp.maximum(decrease_error, curvature_error) # If the new point satisfies at least the decrease error we keep it in case @@ -1038,20 +1134,8 @@ def _zoom_into_interval( def _failure_diagnostic(state: ZoomLinesearchState) -> None: """Prints failure diagnostics.""" - stepsize = state.stepsize jax.debug.print(FLAG_NO_STEPSIZE_FOUND) - jax.debug.print( - "INFO: optax.zoom_linesearch: " - "Iter: {} " - "Stepsize: {} " - "Decrease Error: {} " - "Curvature Error: {} ", - state.count, - stepsize, - state.decrease_error, - state.curvature_error, - ordered=True, - ) + stepsize = state.stepsize slope_init = state.slope_init is_descent_dir = slope_init < 0.0 @@ -1063,15 +1147,13 @@ def _failure_diagnostic(state: ZoomLinesearchState) -> None: ) _cond_print( is_descent_dir, - WARNING_PREAMBLE - + "Consider augmenting the maximal number of linesearch iterations.", + "Consider augmenting the maximal number of linesearch iterations.", ) eps = jnp.finfo(jnp.float32).eps below_eps = stepsize < eps _cond_print( below_eps & is_descent_dir, - WARNING_PREAMBLE - + "Computed stepsize (={stepsize}) " + "Computed stepsize (={stepsize}) " "is below machine precision (={eps}), " "consider passing to higher precision like x64, using " "jax.config.update('jax_enable_x64', True).", @@ -1082,8 +1164,7 @@ def _failure_diagnostic(state: ZoomLinesearchState) -> None: high_slope = abs_slope_init > 1e16 _cond_print( high_slope & is_descent_dir, - WARNING_PREAMBLE - + "Very large absolute slope at stepsize=0. " + "Very large absolute slope at stepsize=0. " "(|slope|={abs_slope_init}). " "The objective is badly conditioned. " "Consider reparameterizing objective (e.g., normalizing parameters) " @@ -1094,14 +1175,12 @@ def _failure_diagnostic(state: ZoomLinesearchState) -> None: outside_domain = jnp.isinf(state.decrease_error) _cond_print( outside_domain, - WARNING_PREAMBLE - + "Cannot even make a step without getting Inf or Nan. " + "Cannot even make a step without getting Inf or Nan. " + "The linesearch won't make a step and the optimizer is stuck.", ) _cond_print( ~outside_domain, - WARNING_PREAMBLE - + "Making an unsafe step, not decreasing enough the objective. " + "Making an unsafe step, not decreasing enough the objective. " "Convergence of the solver is compromised as it does not reduce" " values.", ) @@ -1112,9 +1191,20 @@ def init_fn( *, value: chex.Numeric, grad: base.Updates, - stepsize_guess: chex.Numeric = 1.0, + prev_stepsize: chex.Numeric = 1.0, + initial_guess_strategy: str = "one", ) -> ZoomLinesearchState: """Initializes the linesearch state.""" + + if initial_guess_strategy == "one": + stepsize_guess = jnp.asarray(1.0) + elif initial_guess_strategy == "keep": + stepsize_guess = prev_stepsize + else: + raise ValueError( + f"Unknown initial guess strategy: {initial_guess_strategy}" + ) + slope = otu.tree_vdot(updates, grad) return ZoomLinesearchState( count=jnp.asarray(0, dtype=jnp.int32), @@ -1241,6 +1331,7 @@ def scale_by_zoom_linesearch( curv_rtol: float = 0.9, approx_dec_rtol: Optional[float] = 1e-6, stepsize_precision: float = 1e-5, + initial_guess_strategy: str = "keep", verbose: bool = False, ) -> base.GradientTransformationExtraArgs: r"""Linesearch ensuring sufficient decrease and small curvature. @@ -1321,6 +1412,13 @@ def scale_by_zoom_linesearch( interval is reduced below ``stepsize_precision`` and a stepsize satisfying a sufficient decrease has been found, the algorithm selects that stepsize even if the curvature condition is not satisfied. + initial_guess_strategy: initial guess for the learning rate used to start + the linesearch. Can be either ``one`` or ``keep``. If ``one``, the initial + guess is set to 1. If ``keep``, the initial guess is set to the learning + rate of the previous step. We recommend to use ``keep`` if this linesearch + is used in combination with SGD. We recommend to use ``one`` if this + linesearch is used in combination with Newton methods or quasi-Newton + methods such as L-BFGS. verbose: whether to print additional debugging information in case the linesearch fails. @@ -1405,6 +1503,13 @@ def scale_by_zoom_linesearch( with Guaranteed Descent `_, 2006 + .. note:: + You may consider ignoring the small curvature criterion by setting + ``curv_rtol=jnp.inf``. The resulting algorithm will amount essentially to a + backtracking linesearch with interpolation strategies. This can be + sufficient in practice and avoids having the linesearch spend many + iterations trying to satisfy the small curvature criterion. + .. seealso:: :func:`optax.value_and_grad_from_state` to make this method more efficient for non-stochastic objectives. """ @@ -1476,13 +1581,13 @@ def update_fn( del remaining_kwargs value_and_grad_fn = jax.value_and_grad(value_fn) - stepsize_guess = state.learning_rate init_state = init_ls( updates, params, value=value, grad=grad, - stepsize_guess=stepsize_guess, + prev_stepsize=state.learning_rate, + initial_guess_strategy=initial_guess_strategy, ) final_state = jax.lax.while_loop( @@ -1510,27 +1615,21 @@ def update_fn( # Flags to print errors, used for debugging, tested -WARNING_PREAMBLE = "WARNING: optax.zoom_linesearch: " FLAG_INTERVAL_NOT_FOUND = ( - WARNING_PREAMBLE - + "No interval satisfying curvature condition." + "No interval satisfying curvature condition. " "Consider increasing maximal possible stepsize of the linesearch." ) FLAG_INTERVAL_TOO_SMALL = ( - WARNING_PREAMBLE - + "Length of searched interval has been reduced below threshold." + "Length of searched interval has been reduced below threshold." ) FLAG_CURVATURE_COND_NOT_SATISFIED = ( - WARNING_PREAMBLE - + "Returning stepsize with sufficient decrease " + "Returning stepsize with sufficient decrease " "but curvature condition not satisfied." ) FLAG_NO_STEPSIZE_FOUND = ( - WARNING_PREAMBLE - + "Linesearch failed, no stepsize satisfying sufficient decrease found." + "Linesearch failed, no stepsize satisfying sufficient decrease found." ) FLAG_NOT_A_DESCENT_DIRECTION = ( - WARNING_PREAMBLE - + "The linesearch failed because the provided direction " + "The linesearch failed because the provided direction " "is not a descent direction. " ) diff --git a/optax/_src/transform.py b/optax/_src/transform.py index 69b24bdc..5872594c 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -1605,7 +1605,7 @@ def scale_by_lbfgs( memory_size: number of past parameters, gradients/updates to keep in memory to approximate the Hessian inverse. scale_init_precond: whether to use a scaled identity as the initial - preconditioner, see formula above. + preconditioner, see formula of :math:`\gamma_k` above. Returns: A :class:`optax.GradientTransformation` object. @@ -1618,6 +1618,16 @@ def scale_by_lbfgs( Liu et al., `On the limited memory BFGS method for large scale optimization `_ , 1989. + + .. note:: + For the very first step of the algorithm, the scaling of the identity + matrix (:math:`\gamma_k` in the formula above) is not defined. + We consider scaling by a capped reciprocal of the gradient norm. This avoids + wasting many iterations of linesearch for the first iteration by taking + into account the order of magnitude of the gradients. In other words we + constraint the trust-region of the first step to be a euclidean ball of + radius 1 at the first iteration. The choice of :math:`\gamma_0` is not + detailed in the references above, so this is a heuristic choice. """ if memory_size < 1: raise ValueError('memory_size must be >= 1') @@ -1688,6 +1698,12 @@ def update_fn( identity_scale = jnp.where( denominator > 0.0, numerator / denominator, 1.0 ) + # For the very first step of the algorithm, we consider scaling by a + # capped reciprocal of the gradient norm, see note in the docstring. + capped_inv_norm = jnp.minimum(1.0, 1.0/otu.tree_l2_norm(updates)) + identity_scale = jnp.where( + state.count > 0, identity_scale, capped_inv_norm + ) else: identity_scale = 1.0