Skip to content

Commit

Permalink
Merge pull request #1013 from jungtaekkim:main
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 652837562
  • Loading branch information
OptaxDev committed Jul 16, 2024
2 parents 460a64c + c31ef5c commit ed6062d
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions optax/contrib/_dog.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ def dowg(
r"""Distance over weighted Gradients optimizer.
Examples:
>>> import optax
>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
Expand All @@ -306,13 +307,15 @@ def dowg(
>>> opt_state = solver.init(params)
>>> for _ in range(5):
... value, grad = jax.value_and_grad(f)(params)
... params, opt_state = solver.update(grad, opt_state, params, value=value)
... updates, opt_state = solver.update(
... grad, opt_state, params, value=value)
... params = optax.apply_updates(params, updates)
... print('Objective function: ', f(params))
Objective function: 9.973327e-05
Objective function: 7.0334883
Objective function: 14.074293
Objective function: 49.897446
Objective function: 42.62062
Objective function: 13.925367
Objective function: 13.872763
Objective function: 13.775433
Objective function: 13.596172
Objective function: 13.268837
References:
Khaled et al., `DoWG Unleashed: An Efficient Universal Parameter-Free
Expand Down

0 comments on commit ed6062d

Please sign in to comment.