Skip to content

Commit

Permalink
Backtracking linesearch.
Browse files Browse the repository at this point in the history
Supports stochastic and deterministic cases.
The deterministic case uses an a priori minimal number of function/gradient evaluations. However, there are still two function call sites for readability.
Could totally make it a single call site but the readability will suffer.

PiperOrigin-RevId: 605625086
  • Loading branch information
vroulet authored and OptaxDev committed Feb 9, 2024
1 parent a6d2b5c commit 4d90b40
Show file tree
Hide file tree
Showing 2 changed files with 456 additions and 0 deletions.
201 changes: 201 additions & 0 deletions optax/_src/linesearch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Linesearches."""

import functools
from typing import Any, Callable, NamedTuple, Optional, Union

import jax
import jax.numpy as jnp
from optax._src import base
import optax.tree_utils as optax_tu


class BacktrackingLinesearchState(NamedTuple):
  """State carried between calls of the backtracking linesearch.

  Attributes:
    tuned_learning_rate: learning rate selected by the last search; the next
      search starts from this value scaled by the increase factor.
    base_opt_state: state of the wrapped base optimizer.
    value_ref: function value stored by the last search (infinite right after
      initialization).
    grad_ref: gradient stored by the last search, same structure as the
      parameters (zeros right after initialization).
  """

  tuned_learning_rate: Union[float, jax.Array]
  base_opt_state: base.OptState
  value_ref: Union[float, jax.Array]
  grad_ref: base.Updates


def backtracking_linesearch(
    base_opt: base.GradientTransformation,
    coef: float = 1e-4,
    decrease_factor: float = 0.8,
    increase_factor: float = 1.5,
    max_learning_rate: float = 1.0,
    max_backtracking_steps: int = 15,
    atol: float = 0.0,
    rtol: float = 0.0,
    recycle_computations: bool = False,
) -> base.GradientTransformationExtraArgs:
  r"""Backtracking linesearch, a.k.a. Armijo linesearch.

  Selects learning rate such that

  .. math::
    \begin{align*}
    f(w + lr \cdot u) \leq (1-\delta)f(w)
    + lr \cdot c \langle u, \nabla f(w) \rangle + \epsilon
    \end{align*}

  where f is the function to optimize, lr is the learning_rate to find,
  u is the update direction computed by the base_opt, c is a coefficient (coef)
  measuring the decrease of the function in terms of the slope (scalar
  product between grads and updates) in the given direction, delta is a
  relative tolerance (rtol), epsilon is an absolute tolerance (atol).

  We start by a given guess of a learning rate and decrease it by some factor
  until the criterion above is met.

  .. warning::
    The base optimizer needs to return a descent direction, i.e., such that
    <u, g> < 0 for u the updates and g the gradient. This is the case for, e.g.,
    a simple sgd but not for e.g., adam. If the updates are not a descent
    direction the linesearch is doomed to fail.

  References:
    Vaswani et al, `Painless Stochastic Gradient: Interpolation,
    Line-Search, and Convergence Rates
    <https://arxiv.org/abs/1905.09997>`_, 2019

    Nocedal & Wright, Numerical Optimization, 1999

  Args:
    base_opt: base optimizer providing updates along which the linesearch is
      performed.
    coef: sufficient decrease must be coef * lr * <grads, updates>, see formula
      above.
    decrease_factor: decreasing factor to reduce learning rate.
    increase_factor: increasing factor to increase learning rate guess.
    max_learning_rate: maximal learning rate (learning rate clipped to this).
    max_backtracking_steps: maximal number of backtracking (decrease) steps
      performed after the initial trial of the candidate learning rate.
    atol: absolute tolerance at which the condition needs to be satisfied.
    rtol: relative tolerance at which the condition needs to be satisfied.
    recycle_computations: whether to recycle last computed value and grads to
      initialize the next search. If the function does not change from iteration
      to iteration (no varying data for example), this can save up to half the
      computations.

  Returns:
    (init_fn, update_fn): initialization and update functions
  """

  def init_fn(params: base.Params) -> BacktrackingLinesearchState:
    """Builds the initial state: unit learning rate, no reference value."""
    return BacktrackingLinesearchState(
        tuned_learning_rate=jnp.array(1.0),
        base_opt_state=base_opt.init(params),
        # An infinite reference value marks "nothing to recycle yet".
        value_ref=jnp.inf,
        # params is a pytree: zeros must be created leaf by leaf.
        grad_ref=jax.tree_util.tree_map(jnp.zeros_like, params),
    )

  def check_condition(learning_rate, slope, value, next_value):
    """Returns whether the sufficient decrease criterion is satisfied."""
    violation = next_value - (1 - rtol) * value - learning_rate * coef * slope
    # A NaN violation (e.g. diverging step) must count as a rejection.
    violation = jnp.where(jnp.isnan(violation), jnp.inf, violation)
    return violation <= atol

  def update_fn(
      grad: base.Updates,
      state: BacktrackingLinesearchState,
      params: base.Params,
      *,
      value_fn: Callable[..., Union[jax.Array, float]],
      value: Optional[Union[float, jax.Array]] = None,
      fn_kwargs: Optional[dict[str, Any]] = None,
      **extra_kwargs,
  ):
    """Compute scaled updates guaranteeing decrease of current objective.

    Args:
      grad: gradient of the function at the current params. Replaced by the
        gradient stored in the state when recycle_computations is True and a
        finite reference value is available.
      state: current state.
      params: current parameters.
      value_fn: function returning only the value of the function to search.
      value: value of the function at the current params. If None and
        computations are not recycled, it is computed by calling value_fn.
      fn_kwargs: additional keyword arguments for the function to minimize.
      **extra_kwargs: additional keyword arguments that may be used by other
        transforms when the linesearch is combined with other transforms through
        optax.

    Returns:
      updates: updates for the params (next_params = params + updates).
      state: updated state.
    """
    del extra_kwargs
    fn_kwargs = fn_kwargs or {}

    if recycle_computations:
      # Reuse the value/gradient stored by the previous search, unless this is
      # the first call (reference value still infinite).
      value, grad = jax.lax.cond(
          jnp.isinf(state.value_ref),
          lambda p, f_kw: jax.value_and_grad(value_fn)(p, **f_kw),
          lambda *_: (state.value_ref, state.grad_ref),
          params,
          fn_kwargs,
      )
    elif value is None:
      # The criterion needs the current value; compute it if not supplied.
      value = value_fn(params, **fn_kwargs)

    updates, base_opt_state = base_opt.update(
        grad, state.base_opt_state, params
    )

    slope = optax_tu.tree_vdot(updates, grad)

    def cond_fn(carry):
      accept, iter_num = carry[-2], carry[-1]
      # Iteration 0 tries the candidate as is; iterations 1 to
      # max_backtracking_steps each shrink it once more.
      return ~accept & (iter_num <= max_backtracking_steps)

    def body_fn(carry):
      learning_rate, _, prev_grad, _, iter_num = carry
      learning_rate = jnp.where(
          iter_num > 0, decrease_factor * learning_rate, learning_rate
      )
      next_params = optax_tu.tree_add_scalar_mul(params, learning_rate, updates)
      if recycle_computations:
        # We evaluate value_fn and get its jvp operator so that we can
        # compute the gradient by transposing the jvp.
        value_fn_ = functools.partial(value_fn, **fn_kwargs)
        next_value, jvp_value_fn = jax.linearize(value_fn_, next_params)
        accept = check_condition(learning_rate, slope, value, next_value)
        # The gradient at next_params is needed whenever the search stops
        # here: either the step is accepted or this is the last allowed
        # iteration. Otherwise the stored value and gradient would refer to
        # different points (the gradient staying at zeros).
        search_ends = accept | (iter_num == max_backtracking_steps)
        next_grad = jax.lax.cond(
            search_ends,
            lambda p: jax.linear_transpose(jvp_value_fn, p)(1.0)[0],
            lambda *_: prev_grad,
            next_params,
        )
      else:
        next_value = value_fn(next_params, **fn_kwargs)
        accept = check_condition(learning_rate, slope, value, next_value)
        next_grad = prev_grad
      return learning_rate, next_value, next_grad, accept, iter_num + 1

    # Guess candidate learning rate
    learning_rate = jnp.minimum(
        increase_factor * state.tuned_learning_rate, max_learning_rate
    )

    # grad is a pytree: build the placeholder gradient leaf by leaf.
    next_value = value
    next_grad = jax.tree_util.tree_map(jnp.zeros_like, grad)
    init_carry = (learning_rate, next_value, next_grad, False, 0)

    learning_rate, next_value, next_grad, *_ = jax.lax.while_loop(
        cond_fn, body_fn, init_carry
    )

    new_updates = optax_tu.tree_scalar_mul(learning_rate, updates)
    return new_updates, BacktrackingLinesearchState(
        tuned_learning_rate=learning_rate,
        base_opt_state=base_opt_state,
        value_ref=next_value,
        grad_ref=next_grad,
    )

  return base.GradientTransformationExtraArgs(init_fn, update_fn)
Loading

0 comments on commit 4d90b40

Please sign in to comment.