diff --git a/optax/_src/alias.py b/optax/_src/alias.py
index 39e0d96f..ca2366a5 100644
--- a/optax/_src/alias.py
+++ b/optax/_src/alias.py
@@ -282,10 +282,37 @@ def adagrad(
     initial_accumulator_value: float = 0.1,
     eps: float = 1e-7
 ) -> base.GradientTransformation:
-  """The Adagrad optimizer.
+  r"""The Adagrad optimizer.
 
-  Adagrad is an algorithm for gradient based optimization that anneals the
-  learning rate for each parameter during the course of training.
+  AdaGrad is a sub-gradient algorithm for stochastic optimization that adapts
+  the learning rate individually for each feature based on its gradient
+  history.
+
+  The updated parameters adopt the form:
+
+  .. math::
+
+    w_{t+1}^{(i)} = w_{t}^{(i)} - \eta \frac{g_{t}^{(i)}}
+    {\sqrt{\sum_{\tau=1}^{t} (g_{\tau}^{(i)})^2 + \epsilon}}
+
+  where:
+
+  - :math:`w_t^{(i)}` is the parameter :math:`i` at time step :math:`t`,
+  - :math:`\eta` is the learning rate,
+  - :math:`g_t^{(i)}` is the gradient of parameter :math:`i` at time step
+    :math:`t`,
+  - :math:`\epsilon` is a small constant to ensure numerical stability.
+
+  Defining :math:`G = \sum_{\tau=1}^t g_\tau g_\tau^\top`, the update can be
+  written as
+
+  .. math::
+
+    w_{t+1} = w_{t} - \eta \cdot \text{diag}(G + \epsilon I)^{-1/2} \cdot g_t
+
+  where :math:`\text{diag}(G) = (G_{ii})_{i=1}^p` is the vector of diagonal
+  entries of :math:`G \in \mathbb{R}^p` and :math:`I` is the identity matrix
+  in :math:`\mathbb{R}^p`.
 
   .. warning::
     Adagrad's main limit is the monotonic accumulation of squared
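
As a quick, self-contained illustration of the update documented above, here is a minimal sketch in plain JAX. It is not the optax implementation; `adagrad_step` and the sample values are hypothetical, while the defaults (`initial_accumulator_value=0.1`, `eps=1e-7`) mirror the signature shown in the diff.

```python
# Minimal sketch of the per-parameter AdaGrad update from the docstring above.
# Illustrative only -- not the optax implementation.
import jax.numpy as jnp


def adagrad_step(w, g, accum, learning_rate=0.1, eps=1e-7):
  """One step: w <- w - eta * g / sqrt(sum of squared gradients + eps)."""
  accum = accum + g ** 2  # running per-coordinate sum of squared gradients
  w = w - learning_rate * g / jnp.sqrt(accum + eps)
  return w, accum


params = jnp.array([1.0, -2.0])
grads = jnp.array([0.5, 0.3])
# The accumulator starts at `initial_accumulator_value` (0.1), as in the diff.
accum = jnp.full_like(params, 0.1)
params, accum = adagrad_step(params, grads, accum)
```

In practice one would call `optax.adagrad(learning_rate=...)` and drive it through the usual `init`/`update` cycle; the sketch only mirrors the documented math so the docstring can be checked against code.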