Merge pull request #1068 from carlosgmartin:typing_hashable_deprecated
PiperOrigin-RevId: 678250596
OptaxDev committed Sep 24, 2024
2 parents f9807cc + 1cde704 commit 9785171
Showing 27 changed files with 217 additions and 151 deletions.
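
Most of the changes follow a single pattern: abstract types such as Callable and Mapping (and, judging by the branch name, Hashable) are now imported from collections.abc rather than typing, where those aliases have been deprecated since Python 3.9. A minimal before/after sketch of the pattern (illustrative only, not copied from any one file in the diff):

# Before: deprecated typing aliases.
from typing import Any, Callable, Mapping, Optional

# After: abstract base classes come from collections.abc;
# typing keeps only what collections.abc does not provide.
from collections.abc import Callable, Mapping
from typing import Any, Optional

def scale_values(fn: Callable[[float], float], table: Mapping[str, float]) -> dict[str, float]:
  # Toy helper (hypothetical, not part of the PR) showing the annotations in use.
  return {k: fn(v) for k, v in table.items()}
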
3 changes: 2 additions & 1 deletion docs/ext/coverage_check.py
@@ -14,9 +14,10 @@
# ==============================================================================
"""Asserts all public symbols are covered in the docs."""

from collections.abc import Mapping
import inspect
import types
from typing import Any, Mapping, Sequence, Tuple
from typing import Any, Sequence, Tuple

import optax
from sphinx import application
3 changes: 2 additions & 1 deletion examples/cifar10_resnet.ipynb
@@ -44,7 +44,8 @@
],
"source": [
"import functools\n",
"from typing import Any, Callable, Sequence, Tuple, Optional, Dict\n",
"from collections.abc import Callable\n",
"from typing import Any, Sequence, Tuple, Optional, Dict\n",
"\n",
"from flax import linen as nn\n",
"\n",
3 changes: 2 additions & 1 deletion examples/meta_learning.ipynb
@@ -45,7 +45,8 @@
},
"outputs": [],
"source": [
"from typing import Callable, Iterator, Tuple\n",
"from collections.abc import Callable\n",
"from typing import Iterator, Tuple\n",
"import chex\n",
"import jax\n",
"import jax.numpy as jnp\n",
4 changes: 2 additions & 2 deletions optax/_src/alias.py
@@ -14,11 +14,11 @@
# ==============================================================================
"""Aliases for popular optimizers."""

from collections.abc import Callable
import functools
from typing import Any, Callable, Optional, Union
from typing import Any, Optional, Union

import jax.numpy as jnp

from optax._src import base
from optax._src import clipping
from optax._src import combine
48 changes: 26 additions & 22 deletions optax/_src/alias_test.py
@@ -12,20 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `alias.py`."""
"""Tests for methods defined in `alias.py`."""

from typing import Any, Callable, Union
from collections.abc import Callable
from typing import Any, Union

from absl.testing import absltest
from absl.testing import parameterized

import chex
import jax
from jax import flatten_util
import jax.numpy as jnp
import jax.random as jrd
import numpy as np

from optax._src import alias
from optax._src import base
from optax._src import linesearch as _linesearch
@@ -36,8 +35,6 @@
from optax.schedules import _inject
from optax.transforms import _accumulation
import optax.tree_utils as otu


import scipy.optimize as scipy_optimize
from sklearn import datasets
from sklearn import linear_model
@@ -81,7 +78,7 @@
dict(opt_name='rprop', opt_kwargs=dict(learning_rate=1e-1)),
dict(opt_name='sm3', opt_kwargs=dict(learning_rate=1.0)),
dict(opt_name='yogi', opt_kwargs=dict(learning_rate=1e-1)),
dict(opt_name='polyak_sgd', opt_kwargs=dict(max_learning_rate=1.))
dict(opt_name='polyak_sgd', opt_kwargs=dict(max_learning_rate=1.0)),
)


@@ -111,8 +108,9 @@ def _setup_rosenbrock(dtype):
final_params = jnp.array([a, a**2], dtype=dtype)

def objective(params):
return (numerics.abs_sq(a - params[0]) +
b * numerics.abs_sq(params[1] - params[0]**2))
return numerics.abs_sq(a - params[0]) + b * numerics.abs_sq(
params[1] - params[0] ** 2
)

return initial_params, final_params, objective

@@ -179,7 +177,8 @@ def step(params, state):
@chex.all_variants
@parameterized.product(_OPTIMIZERS_UNDER_TEST)
def test_optimizers_can_be_wrapped_in_inject_hyperparams(
self, opt_name, opt_kwargs):
self, opt_name, opt_kwargs
):
"""Checks that optimizers can be wrapped in inject_hyperparams."""
# See also https://github.com/google-deepmind/optax/issues/412.
opt_factory = getattr(alias, opt_name)
@@ -189,7 +188,8 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams(
# argument to be specified in order to be jittable. See issue
# https://github.com/google-deepmind/optax/issues/412.
opt_inject = _inject.inject_hyperparams(
opt_factory, static_args=('min_dim_size_to_factor',))(**opt_kwargs)
opt_factory, static_args=('min_dim_size_to_factor',)
)(**opt_kwargs)
else:
opt_inject = _inject.inject_hyperparams(opt_factory)(**opt_kwargs)

@@ -198,7 +198,7 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams(

state = self.variant(opt.init)(params)
if opt_name == 'polyak_sgd':
update_kwargs = {'value': jnp.array(0.)}
update_kwargs = {'value': jnp.array(0.0)}
else:
update_kwargs = {}
updates, new_state = self.variant(opt.update)(
@@ -207,13 +207,15 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams(

state_inject = self.variant(opt_inject.init)(params)
updates_inject, new_state_inject = self.variant(opt_inject.update)(
grads, state_inject, params, **update_kwargs)
grads, state_inject, params, **update_kwargs
)

with self.subTest('Equality of updates.'):
chex.assert_trees_all_close(updates_inject, updates, rtol=1e-4)
with self.subTest('Equality of new optimizer states.'):
chex.assert_trees_all_close(
new_state_inject.inner_state, new_state, rtol=1e-4)
new_state_inject.inner_state, new_state, rtol=1e-4
)

@parameterized.product(
params_dtype=('bfloat16', 'float32', 'complex64', None),
@@ -233,8 +235,7 @@ def test_explicit_dtype(self, params_dtype, state_dtype, opt_name):
params_dtype = jax.dtypes.canonicalize_dtype(params_dtype)
params = jnp.array([0.0, 0.0], dtype=params_dtype)
state_has_lower_dtype = (
jnp.promote_types(params_dtype, jnp.dtype(state_dtype))
== params_dtype
jnp.promote_types(params_dtype, jnp.dtype(state_dtype)) == params_dtype
)
if state_dtype is None or state_has_lower_dtype:
state = opt.init(params)
@@ -266,9 +267,7 @@ def test_explicit_dtype(self, params_dtype, state_dtype, opt_name):
@chex.variants(
with_jit=True, without_jit=True, with_device=True, with_pmap=True
)
@parameterized.product(
_OPTIMIZERS_UNDER_TEST, dtype=('bfloat16', 'float32')
)
@parameterized.product(_OPTIMIZERS_UNDER_TEST, dtype=('bfloat16', 'float32'))
def test_preserve_dtype(self, opt_name, opt_kwargs, dtype):
"""Test that the optimizers return updates of same dtype as params."""
# When debugging this test, note that operations like
@@ -318,6 +317,7 @@ def test_gradient_accumulation(self, opt_name, opt_kwargs, dtype):
updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs)
chex.assert_trees_all_equal(updates, jnp.zeros_like(grads))


##########################
# ALGORITHM SPECIFIC TESTS
##########################
@@ -336,6 +336,7 @@ def _run_lbfgs_solver(
) -> tuple[chex.ArrayTree, base.OptState]:
"""Run LBFGS solver by iterative calls to grad transform and apply_updates."""
value_and_grad_fun = jax.value_and_grad(fun)

def stopping_criterion(carry):
_, _, count, grad = carry
return (otu.tree_l2_norm(grad) >= tol) & (count < maxiter)
@@ -347,7 +348,7 @@ def step(carry):
grad, state, params, value=value, grad=grad, value_fn=fun
)
params = update.apply_updates(params, updates)
return params, state, count+1, grad
return params, state, count + 1, grad

init_state = opt.init(init_params)
init_grad = jax.grad(fun)(init_params)
@@ -492,7 +493,7 @@ def _plain_lbfgs(
dws = dws[1:] # Pop left.
dus = dus[1:]

grad_norm = jnp.sqrt(jnp.sum(g ** 2))
grad_norm = jnp.sqrt(jnp.sum(g**2))

if grad_norm <= tol:
break
@@ -729,13 +730,14 @@ def fun(x):
sol_arr, _ = _run_lbfgs_solver(opt, fun, init_array, maxiter=3)
sol_tree, _ = _run_lbfgs_solver(opt, fun, init_tree, maxiter=3)
sol_tree = jnp.stack((sol_tree[0], sol_tree[1]))
chex.assert_trees_all_close(sol_arr, sol_tree, rtol=5*1e-5, atol=5*1e-5)
chex.assert_trees_all_close(sol_arr, sol_tree, rtol=5 * 1e-5, atol=5 * 1e-5)

@parameterized.product(scale_init_precond=[True, False])
def test_multiclass_logreg(self, scale_init_precond):
data = datasets.make_classification(
n_samples=10, n_features=5, n_classes=3, n_informative=3, random_state=0
)

def fun(params):
inputs, labels = data
weights, bias = params
@@ -868,8 +870,10 @@ def test_steep_objective(self):
tol = 1e-5
n = 2
mat = jnp.eye(n) * 1e4

def fun(x):
return jnp.mean((mat @ x) ** 2)

opt = alias.lbfgs()
sol, _ = _run_lbfgs_solver(opt, fun, init_params=jnp.ones(n), tol=tol)
chex.assert_trees_all_close(sol, jnp.zeros(n), atol=tol, rtol=tol)
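
For orientation, the _run_lbfgs_solver helper above drives optax.lbfgs() by repeatedly calling the transform's update with the current value, gradient, and objective (passed via value_fn), then applying the updates. A minimal sketch of that loop outside the test harness, using only the public optax API exercised in the diff (the objective below is a toy example, not taken from the tests):

import jax
import jax.numpy as jnp
import optax

def fun(x):
  # Simple quadratic with minimum at x = 1 (toy objective).
  return jnp.sum((x - 1.0) ** 2)

opt = optax.lbfgs()
params = jnp.zeros(3)
state = opt.init(params)
value_and_grad_fun = jax.value_and_grad(fun)

for _ in range(10):
  value, grad = value_and_grad_fun(params)
  updates, state = opt.update(
      grad, state, params, value=value, grad=grad, value_fn=fun
  )
  params = optax.apply_updates(params, updates)
# params should now be close to jnp.ones(3).
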
3 changes: 2 additions & 1 deletion optax/_src/base.py
@@ -14,7 +14,8 @@
# ==============================================================================
"""Base interfaces and datatypes."""

from typing import Any, Callable, NamedTuple, Optional, Protocol, runtime_checkable, Sequence, Union
from collections.abc import Callable
from typing import Any, NamedTuple, Optional, Protocol, runtime_checkable, Sequence, Union

import chex
import jax
64 changes: 37 additions & 27 deletions optax/_src/factorized.py
@@ -14,30 +14,29 @@
# ==============================================================================
"""Factorized optimizers."""

from collections.abc import Callable
import dataclasses
from typing import NamedTuple, Optional, Callable
from typing import NamedTuple, Optional

import chex
import jax
import jax.numpy as jnp
import numpy as np

from optax._src import base
from optax._src import numerics


# pylint:disable=no-value-for-parameter


def _decay_rate_pow(i: int, exponent: float = 0.8) -> chex.Array:
"""Second-order moment decay schedule."""
t = jnp.array(i + 1, jnp.float32)
return 1.0 - t**(-exponent)
return 1.0 - t ** (-exponent)


def _factored_dims(
shape: base.Shape,
factored: bool,
min_dim_size_to_factor: int
shape: base.Shape, factored: bool, min_dim_size_to_factor: int
) -> Optional[tuple[int, int]]:
"""Whether to use a factored second moment estimator.
Expand All @@ -47,8 +46,8 @@ def _factored_dims(
Args:
shape: an input shape
factored: whether to use factored second-moment estimator for 2d vars.
min_dim_size_to_factor: only factor accumulator if two array dimensions
have at least this size.
min_dim_size_to_factor: only factor accumulator if two array dimensions have
at least this size.
Returns:
None or a tuple of ints
Expand All @@ -64,6 +63,7 @@ def _factored_dims(
@dataclasses.dataclass
class _UpdateResult:
"""Opaque containter that is not traversed by jax.tree.map."""

update: chex.Array # the update to apply to params
v_row: chex.Array # used for factored params.
v_col: chex.Array # used for factored params.
@@ -72,6 +72,7 @@ class _UpdateResult:

class FactoredState(NamedTuple):
"""Overall state of the gradient transformation."""

count: chex.Array # number of update steps.
v_row: chex.ArrayTree # Tree of factored params.
v_col: chex.ArrayTree # Tree of factored params.
@@ -84,7 +85,8 @@ def scale_by_factored_rms(
step_offset: int = 0,
min_dim_size_to_factor: int = 128,
epsilon: float = 1e-30,
decay_rate_fn: Callable[[int, float], chex.Array] = _decay_rate_pow):
decay_rate_fn: Callable[[int, float], chex.Array] = _decay_rate_pow,
):
"""Scaling by a factored estimate of the gradient rms (as in Adafactor).
This is a so-called "1+epsilon" scaling algorithm that is extremely memory
@@ -120,7 +122,8 @@ def _to_state(count: chex.Array, result_tree):
count=count,
v_row=jax.tree.map(lambda o: o.v_row, result_tree),
v_col=jax.tree.map(lambda o: o.v_col, result_tree),
v=jax.tree.map(lambda o: o.v, result_tree))
v=jax.tree.map(lambda o: o.v, result_tree),
)

def init_fn(params):
"""Initialise the optimiser's state."""
@@ -136,16 +139,17 @@ def _init(param):
update=jnp.zeros((1,), dtype=dtype),
v_row=jnp.zeros(vr_shape, dtype=dtype),
v_col=jnp.zeros(vc_shape, dtype=dtype),
v=jnp.zeros((1,), dtype=dtype))
v=jnp.zeros((1,), dtype=dtype),
)
else:
return _UpdateResult(
update=jnp.zeros((1,), dtype=dtype),
v_row=jnp.zeros((1,), dtype=dtype),
v_col=jnp.zeros((1,), dtype=dtype),
v=jnp.zeros(param.shape, dtype=dtype))
v=jnp.zeros(param.shape, dtype=dtype),
)

return _to_state(
jnp.zeros([], jnp.int32), jax.tree.map(_init, params))
return _to_state(jnp.zeros([], jnp.int32), jax.tree.map(_init, params))

def update_fn(grads, state, params):
"""Apply gradient transformation."""
Expand All @@ -165,34 +169,40 @@ def _update(grad, v_row, v_col, v, param, step):
if factored_dims is not None:
d1, d0 = factored_dims
grad_sqr = numerics.abs_sq(grad) + epsilon
new_v_row = (
decay_rate_t * v_row +
(1. - decay_rate_t) * jnp.mean(grad_sqr, axis=d0))
new_v_col = (
decay_rate_t * v_col +
(1. - decay_rate_t) * jnp.mean(grad_sqr, axis=d1))
new_v_row = decay_rate_t * v_row + (1.0 - decay_rate_t) * jnp.mean(
grad_sqr, axis=d0
)
new_v_col = decay_rate_t * v_col + (1.0 - decay_rate_t) * jnp.mean(
grad_sqr, axis=d1
)
new_v_row = new_v_row.astype(dtype)
new_v_col = new_v_col.astype(dtype)
reduced_d1 = d1-1 if d1 > d0 else d1
reduced_d1 = d1 - 1 if d1 > d0 else d1
row_col_mean = jnp.mean(new_v_row, axis=reduced_d1, keepdims=True)
row_factor = (new_v_row / row_col_mean) ** -0.5
col_factor = (new_v_col) ** -0.5
update = (
grad *
jnp.expand_dims(row_factor, axis=d0) *
jnp.expand_dims(col_factor, axis=d1))
grad
* jnp.expand_dims(row_factor, axis=d0)
* jnp.expand_dims(col_factor, axis=d1)
)
else:
grad_sqr = numerics.abs_sq(grad) + epsilon
new_v = decay_rate_t * v + (1. - decay_rate_t) * grad_sqr
new_v = decay_rate_t * v + (1.0 - decay_rate_t) * grad_sqr
new_v = new_v.astype(dtype)
update = grad * (new_v)**-0.5
update = grad * (new_v) ** -0.5

return _UpdateResult(update, new_v_row, new_v_col, new_v)

# Transform grad and compute new per-parameter stats.
output = jax.tree.map(
lambda *args: _update(*args, state.count),
grads, state.v_row, state.v_col, state.v, params)
grads,
state.v_row,
state.v_col,
state.v,
params,
)

# Unpack updates / stats and return.
updates = jax.tree.map(lambda o: o.update, output)
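
For context on the file above: scale_by_factored_rms is the factored second-moment transform behind Adafactor; it keeps per-row and per-column statistics (v_row, v_col) for large matrices instead of a full second-moment tensor. A minimal usage sketch, assuming only the public optax API (in practice it is usually composed into a full optimizer such as optax.adafactor):

import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((256, 128)), 'b': jnp.zeros(128)}
tx = optax.scale_by_factored_rms()  # 'w' is factored (both dims >= 128); 'b' is not.
state = tx.init(params)

grads = jax.tree.map(jnp.ones_like, params)  # placeholder gradients
updates, state = tx.update(grads, state, params)

# The transform only rescales; a step size (and sign) still has to be applied,
# e.g. optax.chain(optax.scale_by_factored_rms(), optax.scale(-1e-2)),
# or simply optax.adafactor(learning_rate=1e-2).
new_params = optax.apply_updates(params, jax.tree.map(lambda u: -1e-2 * u, updates))
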
(Diffs for the remaining changed files are not shown here.)
