Supports stochastic and deterministic cases. The deterministic case uses an a priori minimal number of function/gradient evaluations. There are still two function call sites for readability; they could be merged into a single call site, but readability would suffer.

PiperOrigin-RevId: 605625086
Showing 2 changed files with 456 additions and 0 deletions.
@@ -0,0 +1,201 @@
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Linesearches."""

import functools
from typing import Any, Callable, NamedTuple, Optional, Union

import jax
import jax.numpy as jnp
from optax._src import base
import optax.tree_utils as optax_tu


class BacktrackingLinesearchState(NamedTuple):
  tuned_learning_rate: Union[float, jax.Array]
  base_opt_state: base.OptState
  value_ref: Union[float, jax.Array]
  grad_ref: base.Updates

def backtracking_linesearch(
    base_opt: base.GradientTransformation,
    coef: float = 1e-4,
    decrease_factor: float = 0.8,
    increase_factor: float = 1.5,
    max_learning_rate: float = 1.0,
    max_backtracking_steps: int = 15,
    atol: float = 0.0,
    rtol: float = 0.0,
    recycle_computations: bool = False,
):
r"""Backtracking linesearch, a.k.a. Armijo linesearch. | ||
Selects learning rate such that | ||
.. math:: | ||
\begin{align*} | ||
f(w + lr \cdot u) \leq (1-\delta)f(w) | ||
+ c \langle u, \nabla f(w) \rangle + \epsilon | ||
\end{align*} | ||
where f is the function to optimize, lr is the learning_rate to find, | ||
u is the updatte direction computed by the base_opt, c is a coefficient (coef) | ||
measuring the decrease of the function in terms of the slope (scalar | ||
product between grads and updates) in the given direction, delta is a | ||
relative tolerance (rtol), epsilon is an | ||
absolute tolerance (atol). | ||
We start by a given guess of a learning rate and decrease it by some factor | ||
until the criterion above is met. | ||
.. warning:: | ||
The base optimizer needs to return a descent direction, i.e., such that | ||
<u, g> < 0 for u the updates and g the gradient. This is the case for, e.g., | ||
a simple sgd but not for e.g., adam. If the updates are not a descent | ||
direction the linesearch is doomed to fail. | ||
References: | ||
Vaswani et al, <Painless Stochastic Gradient: Interpolation, | ||
Line-Search, and Convergence Rates>_https://arxiv.org/abs/1905.09997, 2019 | ||
Nocedal & Wright, Numerical Optimization, 1999 | ||
Args: | ||
base_opt: base optimizer providing updates along which the linesearch is | ||
performed. | ||
coef: sufficient decrease must be coef * lr * <grads, updates>, see formula | ||
above. | ||
decrease_factor: decreasing factor to reduce learning rate. | ||
increase_factor: increasing factor to increase learning rate guess. | ||
max_learning_rate: maximal learning rate (learning rate clipped to this). | ||
max_backtracking_steps: maximal number of iterations for the search. | ||
atol: absolute tolerance at which the condition needs to be satisfied. | ||
rtol: relative tolerance at which the condition needs to be satisfied. | ||
recycle_computations: whether to recycle last computed value and grads to | ||
initialize the next search. If the function does not change from iteration | ||
to iteration (no varying data for example), this can save up to half the | ||
computations. | ||
Returns: | ||
(init_fn, update_fn): initalization and update functions | ||
""" | ||

  def init_fn(params: base.Params):
    return BacktrackingLinesearchState(
        tuned_learning_rate=jnp.array(1.0),
        base_opt_state=base_opt.init(params),
        value_ref=jnp.inf,
        grad_ref=jax.tree_util.tree_map(jnp.zeros_like, params),
    )

  def check_condition(learning_rate, slope, value, next_value):
    violation = next_value - (1 - rtol) * value - learning_rate * coef * slope
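    # NaN objective values (for instance from a step that overshot into an
    # unstable region) are mapped to +inf so that the candidate step is
    # rejected and the learning rate keeps being decreased.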
    violation = jnp.where(jnp.isnan(violation), jnp.inf, violation)
    return violation <= atol

  def update_fn(
      grad: base.Updates,
      state: BacktrackingLinesearchState,
      params: base.Params,
      *,
      value_fn: Callable[..., Union[jax.Array, float]],
      value: Optional[Union[float, jax.Array]] = None,
      fn_kwargs: Optional[dict[str, Any]] = None,
      **extra_kwargs,
  ):
"""Compute scaled updates guaranteeing decrease of current objective. | ||
Args: | ||
grad: gradient of the function at the current params. | ||
state: current state. | ||
params: current parameters. | ||
value_fn: function returning uniquely value of the function to search. | ||
value: value of the function at the current params. | ||
fn_kwargs: additional keyword arguments for the function to minimize. | ||
**extra_kwargs: additional keyword arguments that may be used by other | ||
transforms when the linesearch is combined with other transforms through | ||
optax. | ||
Returns: | ||
updates: updates for the params (next_params = params + updates). | ||
state: updated state. | ||
""" | ||
    del extra_kwargs
    fn_kwargs = fn_kwargs or {}

    if recycle_computations:
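      # On the first call value_ref is still inf, so the value and gradient
      # are recomputed here; on later calls the pair cached from the last
      # accepted step is reused, saving one function and one gradient
      # evaluation per update.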
      value, grad = jax.lax.cond(
          jnp.isinf(state.value_ref),
          lambda p, f_kw: jax.value_and_grad(value_fn)(p, **f_kw),
          lambda *_: (state.value_ref, state.grad_ref),
          params,
          fn_kwargs,
      )
    updates, base_opt_state = base_opt.update(
        grad, state.base_opt_state, params
    )

    slope = optax_tu.tree_vdot(updates, grad)

    def cond_fn(carry):
      accept, iter_num = carry[-2], carry[-1]
      return ~accept & (iter_num <= max_backtracking_steps)

    def body_fn(carry):
      learning_rate, _, next_grad, _, iter_num = carry
      learning_rate = jnp.where(
          iter_num > 0, decrease_factor * learning_rate, learning_rate
      )
      next_params = optax_tu.tree_add_scalar_mul(
          params, learning_rate, updates
      )
      if recycle_computations:
        # We evaluate value_fn and get its jvp operator so that we can
        # compute the gradient by transposing the jvp.
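        # For a scalar-valued function, jax.linearize returns the value
        # together with the linear jvp map v -> <grad, v>; transposing that
        # map with jax.linear_transpose and applying it to the cotangent 1.0
        # recovers the gradient without a second forward pass.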
        value_fn_ = functools.partial(value_fn, **fn_kwargs)
        next_value, jvp_value_fn = jax.linearize(value_fn_, next_params)
        accept = check_condition(learning_rate, slope, value, next_value)
        # If the step has been accepted, we get the gradient for the next
        # run of linesearch.
        next_grad = jax.lax.cond(
            accept,
            lambda p: jax.linear_transpose(jvp_value_fn, p)(1.0)[0],
            lambda *_: next_grad,
            next_params,
        )
      else:
        next_value = value_fn(next_params, **fn_kwargs)
        accept = check_condition(learning_rate, slope, value, next_value)
      return learning_rate, next_value, next_grad, accept, iter_num + 1

    # Guess candidate learning rate
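    # The last accepted learning rate is increased by increase_factor so the
    # search can recover larger steps over time, then clipped to
    # max_learning_rate.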
    learning_rate = jnp.minimum(
        increase_factor * state.tuned_learning_rate, max_learning_rate
    )

    next_value = value
    next_grad = jax.tree_util.tree_map(jnp.zeros_like, grad)
    init_carry = (learning_rate, next_value, next_grad, False, 0)

    learning_rate, next_value, next_grad, *_ = jax.lax.while_loop(
        cond_fn, body_fn, init_carry
    )

    new_updates = optax_tu.tree_scalar_mul(learning_rate, updates)
    return new_updates, BacktrackingLinesearchState(
        tuned_learning_rate=learning_rate,
        base_opt_state=base_opt_state,
        value_ref=next_value,
        grad_ref=next_grad,
    )

  return base.GradientTransformationExtraArgs(init_fn, update_fn)
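For reference, a minimal usage sketch based on the signatures above might look as follows. It wraps optax.sgd (whose updates are a descent direction) and minimizes a hypothetical quadratic loss; the loss, parameter shapes, and step count are illustrative assumptions, not part of this commit. Since recycle_computations defaults to False, the current value is passed explicitly to update.

import jax
import jax.numpy as jnp
import optax

# Hypothetical objective: a simple quadratic with its minimum at params = 1.
def loss_fn(params):
  return jnp.sum((params - 1.0) ** 2)

# The linesearch rescales the sgd direction at every step.
solver = backtracking_linesearch(base_opt=optax.sgd(learning_rate=1.0))

params = jnp.zeros(3)
state = solver.init(params)

for _ in range(10):
  value, grad = jax.value_and_grad(loss_fn)(params)
  updates, state = solver.update(
      grad, state, params, value_fn=loss_fn, value=value
  )
  params = optax.apply_updates(params, updates)

With recycle_computations=True, the objective must stay the same across calls, since the value and gradient cached from the last accepted step are reused to start the next search.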