diff --git a/optax/_src/linesearch.py b/optax/_src/linesearch.py
new file mode 100644
index 000000000..6e076a2c8
--- /dev/null
+++ b/optax/_src/linesearch.py
@@ -0,0 +1,201 @@
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Linesearches."""
+
+import functools
+from typing import Any, Callable, NamedTuple, Optional, Union
+
+import jax
+import jax.numpy as jnp
+from optax._src import base
+import optax.tree_utils as optax_tu
+
+
+class BacktrackingLinesearchState(NamedTuple):
+  tuned_learning_rate: Union[float, jax.Array]
+  base_opt_state: base.OptState
+  value_ref: Union[float, jax.Array]
+  grad_ref: base.Updates
+
+
+def backtracking_linesearch(
+    base_opt: base.GradientTransformation,
+    coef: float = 1e-4,
+    decrease_factor: float = 0.8,
+    increase_factor: float = 1.5,
+    max_learning_rate: float = 1.0,
+    max_backtracking_steps: int = 15,
+    atol: float = 0.0,
+    rtol: float = 0.0,
+    recycle_computations: bool = False,
+):
+  r"""Backtracking linesearch, a.k.a. Armijo linesearch.
+
+  Selects a learning rate such that
+
+  .. math::
+    \begin{align*}
+      f(w + lr \cdot u) \leq (1-\delta) f(w)
+        + c \cdot lr \cdot \langle u, \nabla f(w) \rangle + \epsilon
+    \end{align*}
+
+  where :math:`f` is the function to optimize, :math:`lr` is the learning
+  rate to find, :math:`u` is the update direction computed by the base_opt,
+  :math:`c` is a coefficient (``coef``) measuring the decrease of the function
+  in terms of the slope (the scalar product between the gradient and the
+  updates) in the given direction, :math:`\delta` is a relative tolerance
+  (``rtol``) and :math:`\epsilon` is an absolute tolerance (``atol``).
+
+  We start from a given guess of a learning rate and decrease it by some
+  factor until the criterion above is met.
+
+  .. warning::
+    The base optimizer needs to return a descent direction, i.e., such that
+    :math:`\langle u, \nabla f(w) \rangle < 0` for :math:`u` the updates and
+    :math:`\nabla f(w)` the gradient. This is the case for, e.g., a simple
+    sgd but not for, e.g., adam. If the updates are not a descent direction
+    the linesearch is doomed to fail.
+
+  References:
+    Vaswani et al, https://arxiv.org/abs/1905.09997, 2019
+
+    Nocedal & Wright, Numerical Optimization, 1999
+
+  Args:
+    base_opt: base optimizer providing updates along which the linesearch is
+      performed.
+    coef: sufficient decrease must be coef * lr * <grad, updates>, see the
+      formula above.
+    decrease_factor: decreasing factor to reduce the learning rate.
+    increase_factor: increasing factor to increase the learning rate guess.
+    max_learning_rate: maximal learning rate (the learning rate is clipped to
+      this value).
+    max_backtracking_steps: maximal number of iterations for the search.
+    atol: absolute tolerance at which the condition needs to be satisfied.
+    rtol: relative tolerance at which the condition needs to be satisfied.
+    recycle_computations: whether to recycle the last computed value and grads
+      to initialize the next search. If the function does not change from
+      iteration to iteration (no varying data for example), this can save up
+      to half the computations.
+
+  Returns:
+    (init_fn, update_fn): initialization and update functions
+  """
+
+  def init_fn(params: base.Params):
+    return BacktrackingLinesearchState(
+        tuned_learning_rate=jnp.array(1.0),
+        base_opt_state=base_opt.init(params),
+        value_ref=jnp.inf,
+        grad_ref=jnp.zeros_like(params),
+    )
+
+  def check_condition(learning_rate, slope, value, next_value):
+    violation = next_value - (1 - rtol) * value - learning_rate * coef * slope
+    violation = jnp.where(jnp.isnan(violation), jnp.inf, violation)
+    return violation <= atol
+
+  def update_fn(
+      grad: base.Updates,
+      state: BacktrackingLinesearchState,
+      params: base.Params,
+      *,
+      value_fn: Callable[..., Union[jax.Array, float]],
+      value: Optional[Union[float, jax.Array]] = None,
+      fn_kwargs: Optional[dict[str, Any]] = None,
+      **extra_kwargs,
+  ):
+    """Compute scaled updates guaranteeing decrease of current objective.
+
+    Args:
+      grad: gradient of the function at the current params.
+      state: current state.
+      params: current parameters.
+      value_fn: function returning only the value of the function to minimize.
+      value: value of the function at the current params.
+      fn_kwargs: additional keyword arguments for the function to minimize.
+      **extra_kwargs: additional keyword arguments that may be used by other
+        transforms when the linesearch is combined with them through optax.
+
+    Returns:
+      updates: updates for the params (next_params = params + updates).
+      state: updated state.
+    """
+    del extra_kwargs
+    fn_kwargs = fn_kwargs or {}
+
+    if recycle_computations:
+      value, grad = jax.lax.cond(
+          jnp.isinf(state.value_ref),
+          lambda p, f_kw: jax.value_and_grad(value_fn)(p, **f_kw),
+          lambda *_: (state.value_ref, state.grad_ref),
+          params,
+          fn_kwargs,
+      )
+    updates, base_opt_state = base_opt.update(
+        grad, state.base_opt_state, params
+    )
+
+    slope = optax_tu.tree_vdot(updates, grad)
+
+    def cond_fn(carry):
+      accept, iter_num = carry[-2], carry[-1]
+      return ~accept & (iter_num <= max_backtracking_steps)
+
+    def body_fn(carry):
+      learning_rate, _, next_grad, _, iter_num = carry
+      learning_rate = jnp.where(
+          iter_num > 0, decrease_factor * learning_rate, learning_rate
+      )
+      next_params = optax_tu.tree_add_scalar_mul(
+          params, learning_rate, updates
+      )
+      if recycle_computations:
+        # We evaluate value_fn and get its jvp operator so that we can
+        # compute the gradient by transposing the jvp.
+        value_fn_ = functools.partial(value_fn, **fn_kwargs)
+        next_value, jvp_value_fn = jax.linearize(value_fn_, next_params)
+        accept = check_condition(learning_rate, slope, value, next_value)
+        # If the step has been accepted, we get the gradient for the next
+        # run of the linesearch.
+        next_grad = jax.lax.cond(
+            accept,
+            lambda p: jax.linear_transpose(jvp_value_fn, p)(1.0)[0],
+            lambda *_: next_grad,
+            next_params,
+        )
+      else:
+        next_value = value_fn(next_params, **fn_kwargs)
+        accept = check_condition(learning_rate, slope, value, next_value)
+      return learning_rate, next_value, next_grad, accept, iter_num + 1
+
+    # Guess candidate learning rate.
+    learning_rate = jnp.minimum(
+        increase_factor * state.tuned_learning_rate, max_learning_rate
+    )
+
+    next_value, next_grad = value, jnp.zeros_like(grad)
+    init_carry = (learning_rate, next_value, next_grad, False, 0)
+
+    learning_rate, next_value, next_grad, *_ = jax.lax.while_loop(
+        cond_fn, body_fn, init_carry
+    )
+
+    new_updates = optax_tu.tree_scalar_mul(learning_rate, updates)
+    return new_updates, BacktrackingLinesearchState(
+        tuned_learning_rate=learning_rate,
+        base_opt_state=base_opt_state,
+        value_ref=next_value,
+        grad_ref=next_grad,
+    )
+
+  return base.GradientTransformationExtraArgs(init_fn, update_fn)
diff --git a/optax/_src/linesearch_test.py b/optax/_src/linesearch_test.py
new file mode 100644
index 000000000..4e9916b6f
--- /dev/null
+++ b/optax/_src/linesearch_test.py
@@ -0,0 +1,255 @@
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `linesearch.py`."""
+
+import itertools
+import math
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import chex
+import jax
+import jax.numpy as jnp
+import jax.random as jrd
+import numpy as np
+from optax._src import alias
+from optax._src import base
+from optax._src import linesearch
+from optax._src import update
+import optax.tree_utils as optax_tu
+
+
+# pytype: disable=attribute-error
+class BacktrackingLinesearchTest(chex.TestCase):
+
+  def get_fun(self, name):
+    """Common ill-behaved functions."""
+
+    def rosenbrock(x):
+      return jnp.sum(100.0 * jnp.diff(x) ** 2 + (1.0 - x[:-1]) ** 2)
+
+    def himmelblau(x):
+      return (x[0] ** 2 + x[1] - 11.0) ** 2 + (x[0] + x[1] ** 2 - 7.0) ** 2
+
+    def matyas(x):
+      return 0.26 * (x[0] ** 2 + x[1] ** 2) - 0.48 * x[0] * x[1]
+
+    def eggholder(x):
+      return -(x[1] + 47) * jnp.sin(
+          jnp.sqrt(jnp.abs(x[0] / 2.0 + x[1] + 47.0))
+      ) - x[0] * jnp.sin(jnp.sqrt(jnp.abs(x[0] - (x[1] + 47.0))))
+
+    funs = dict(
+        rosenbrock=rosenbrock,
+        himmelblau=himmelblau,
+        matyas=matyas,
+        eggholder=eggholder,
+    )
+    return funs[name]
+
+  def check_decrease_conditions(
+      self, fun, init_params, descent_dir, final_params, final_state, opt_args
+  ):
+    """Check decrease conditions."""
+    init_value, init_grad = jax.value_and_grad(fun)(init_params)
+    final_value = fun(final_params)
+    final_lr = final_state[0]
+
+    slope = optax_tu.tree_vdot(descent_dir, init_grad)
+    coef, atol, rtol = opt_args['coef'], opt_args['atol'], opt_args['rtol']
+    sufficient_decrease = (
+        final_value <= (1 - rtol) * init_value + coef * final_lr * slope + atol
+    )
+    self.assertTrue(sufficient_decrease)
+
+  @chex.all_variants
+  @parameterized.product(
+      name_fun_and_init_params=[
+          ('rosenbrock', np.zeros(2)),
+          ('himmelblau', np.ones(2)),
+          ('matyas', np.ones(2) * 6.0),
+          ('eggholder', np.ones(2) * 100.0),
+      ],
+      increase_factor=[1.0, 1.5, None],
+      coef=[1e-4, 0.0],
+      atol=[1e-4, 0.0],
+      rtol=[1e-4, 0.0],
+  )
+  def test_linesearch_one_step(
+      self,
+      name_fun_and_init_params,
+      increase_factor,
+      coef,
+      atol,
+      rtol,
+  ):
+    # Choosing increase_factor=inf amounts to always guessing the maximal
+    # stepsize.
+    increase_factor = increase_factor or math.inf
+    name_fun, init_params = name_fun_and_init_params
+    fn = self.get_fun(name_fun)
+
+    # The following descent direction amounts to armijo_gd.
+    # The test is kept as is to illustrate alternative ways to set up a
+    # linesearch.
+    descent_dir = -jax.grad(fn)(init_params)
+    base_init_fn = lambda params: base.EmptyState()
+    base_update_fn = lambda updates, state, params: (descent_dir, state)
+    base_opt = base.GradientTransformation(base_init_fn, base_update_fn)
+
+    opt_args = dict(
+        coef=coef,
+        decrease_factor=0.8,
+        increase_factor=increase_factor,
+        max_learning_rate=1.0,
+        max_backtracking_steps=30,
+        atol=atol,
+        rtol=rtol,
+    )
+
+    # With recycled computations (the value and grad are computed inside the
+    # linesearch loop).
+    solver = linesearch.backtracking_linesearch(
+        base_opt, recycle_computations=True, **opt_args
+    )
+    state = solver.init(init_params)
+    update_fn = self.variant(solver.update)
+    final_updates, final_state1 = update_fn(
+        jnp.zeros_like(init_params), state, init_params, value_fn=fn
+    )
+    final_params1 = update.apply_updates(init_params, final_updates)
+    self.check_decrease_conditions(
+        fn, init_params, descent_dir, final_params1, final_state1, opt_args
+    )
+
+    # Without recycling gradients.
+    value, grad = jax.value_and_grad(fn)(init_params)
+    solver = linesearch.backtracking_linesearch(
+        base_opt, recycle_computations=False, **opt_args
+    )
+    state = solver.init(init_params)
+    update_fn = self.variant(solver.update)
+    final_updates, final_state2 = update_fn(
+        grad, state, init_params, value_fn=fn, value=value
+    )
+    final_params2 = update.apply_updates(init_params, final_updates)
+    self.check_decrease_conditions(
+        fn, init_params, descent_dir, final_params2, final_state2, opt_args
+    )
+    chex.assert_trees_all_close(final_params1, final_params2)
+
+  @chex.all_variants
+  @parameterized.product(
+      name_fun_and_init_params=[
+          ('rosenbrock', np.zeros(2)),
+          ('himmelblau', np.ones(2)),
+          ('matyas', np.ones(2) * 6.0),
+          ('eggholder', np.ones(2) * 100.0),
+      ],
+  )
+  def test_sgd_with_linesearch(
+      self,
+      name_fun_and_init_params,
+  ):
+    name_fun, init_params = name_fun_and_init_params
+    fn = self.get_fun(name_fun)
+    base_opt = alias.sgd(learning_rate=1.0)
+    max_iter = 20
+
+    # With recycling.
+    solver = linesearch.backtracking_linesearch(
+        base_opt=base_opt,
+        recycle_computations=True,
+    )
+    state = solver.init(init_params)
+    update_fn = self.variant(solver.update)
+    params = init_params
+    for _ in range(max_iter):
+      updates, state = update_fn(
+          jnp.zeros_like(params),
+          state,
+          params,
+          value_fn=fn,
+      )
+      params = update.apply_updates(params, updates)
+    final_params1 = params
+    init_value, final_value = fn(init_params), fn(final_params1)
+    self.assertLessEqual(final_value, init_value)
+
+    # Without recycling.
+    solver = linesearch.backtracking_linesearch(
+        base_opt=base_opt,
+        recycle_computations=False,
+    )
+    state = solver.init(init_params)
+    update_fn = self.variant(solver.update)
+    params = init_params
+    for _ in range(max_iter):
+      updates, state = update_fn(
+          jnp.zeros_like(params),
+          state,
+          params,
+          value_fn=fn,
+      )
+      params = update.apply_updates(params, updates)
+    final_params2 = params
+    init_value, final_value = fn(init_params), fn(final_params2)
+    self.assertLessEqual(final_value, init_value)
+    chex.assert_trees_all_close(final_params1, final_params2)
+
+  @chex.all_variants
+  def test_armijo_sgd(self):
+    def fn(params, x, y):
+      return jnp.sum((x.dot(params) - y) ** 2)
+
+    key = jrd.PRNGKey(0)
+    x_key, y_key, params_key = jrd.split(key, 3)
+    d, m, n = 2, 16, 8
+    xs = jrd.normal(x_key, (n, m, d))
+    true_pred = jrd.normal(params_key, (d,))
+    ys = jnp.stack([x.dot(true_pred) for x in xs])
+    ys = ys + jrd.normal(y_key, (n, m))
+    xs_iter = itertools.cycle(iter(xs))
+    ys_iter = itertools.cycle(iter(ys))
+    init_params = jnp.zeros((d,))
+
+    solver = linesearch.backtracking_linesearch(
+        base_opt=alias.sgd(learning_rate=1.0)
+    )
+    num_passes = 8
+    init_value = fn(init_params, xs, ys)
+
+    state = solver.init(init_params)
+    update_fn = self.variant(solver.update)
+
+    params = init_params
+    for _ in range(num_passes):
+      x, y = next(xs_iter), next(ys_iter)
+      value, grad = jax.value_and_grad(fn)(params, **{'x': x, 'y': y})
+      updates, state = update_fn(
+          grad,
+          state,
+          params,
+          value=value,
+          value_fn=fn,
+          fn_kwargs={'x': x, 'y': y},
+      )
+      params = update.apply_updates(params, updates)
+    final_params1 = params
+    final_value = fn(final_params1, xs, ys)
+    self.assertLessEqual(final_value.item(), init_value.item())
+
+
+if __name__ == '__main__':
+  absltest.main()
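
For context when reviewing, below is a minimal usage sketch of the new transform (not part of the diff). It mirrors the pattern exercised in test_sgd_with_linesearch; the quadratic objective fn is a hypothetical stand-in, not something defined in this PR:

import jax.numpy as jnp
import optax
from optax._src import linesearch  # module added in this PR


def fn(params):
  # Illustrative objective, chosen only for this sketch.
  return jnp.sum((params - 1.0) ** 2)


# Wrap a plain sgd direction in the backtracking linesearch.
solver = linesearch.backtracking_linesearch(
    base_opt=optax.sgd(learning_rate=1.0),
    recycle_computations=True,
)

params = jnp.zeros(3)
state = solver.init(params)
for _ in range(10):
  # With recycle_computations=True the value and gradient are computed
  # inside the search, so a zero placeholder gradient is passed in,
  # exactly as in the tests above.
  updates, state = solver.update(
      jnp.zeros_like(params), state, params, value_fn=fn
  )
  params = optax.apply_updates(params, updates)

With recycle_computations=False one would instead compute value and grad via jax.value_and_grad(fn)(params) and pass both to solver.update, as test_armijo_sgd does.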