diff --git a/docs/ext/coverage_check.py b/docs/ext/coverage_check.py index 4046d2123..0f47e8bea 100644 --- a/docs/ext/coverage_check.py +++ b/docs/ext/coverage_check.py @@ -14,9 +14,10 @@ # ============================================================================== """Asserts all public symbols are covered in the docs.""" +from collections.abc import Mapping import inspect import types -from typing import Any, Mapping, Sequence, Tuple +from typing import Any, Sequence, Tuple import optax from sphinx import application diff --git a/examples/cifar10_resnet.ipynb b/examples/cifar10_resnet.ipynb index 7b5706d4d..a9f9374e1 100644 --- a/examples/cifar10_resnet.ipynb +++ b/examples/cifar10_resnet.ipynb @@ -44,7 +44,8 @@ ], "source": [ "import functools\n", - "from typing import Any, Callable, Sequence, Tuple, Optional, Dict\n", + "from collections.abc import Callable\n", + "from typing import Any, Sequence, Tuple, Optional, Dict\n", "\n", "from flax import linen as nn\n", "\n", diff --git a/examples/meta_learning.ipynb b/examples/meta_learning.ipynb index e30e42920..8df9648c2 100644 --- a/examples/meta_learning.ipynb +++ b/examples/meta_learning.ipynb @@ -45,7 +45,8 @@ }, "outputs": [], "source": [ - "from typing import Callable, Iterator, Tuple\n", + "from collections.abc import Callable\n", + "from typing import Iterator, Tuple\n", "import chex\n", "import jax\n", "import jax.numpy as jnp\n", diff --git a/optax/_src/alias.py b/optax/_src/alias.py index c7660298f..59c7ec63b 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -14,11 +14,11 @@ # ============================================================================== """Aliases for popular optimizers.""" +from collections.abc import Callable import functools -from typing import Any, Callable, Optional, Union +from typing import Any, Optional, Union import jax.numpy as jnp - from optax._src import base from optax._src import clipping from optax._src import combine diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index d0e2e35b5..ef0f5ea9e 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -12,20 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Tests for `alias.py`.""" +"""Tests for methods defined in `alias.py`.""" -from typing import Any, Callable, Union +from collections.abc import Callable +from typing import Any, Union from absl.testing import absltest from absl.testing import parameterized - import chex import jax from jax import flatten_util import jax.numpy as jnp import jax.random as jrd import numpy as np - from optax._src import alias from optax._src import base from optax._src import linesearch as _linesearch @@ -36,8 +35,6 @@ from optax.schedules import _inject from optax.transforms import _accumulation import optax.tree_utils as otu - - import scipy.optimize as scipy_optimize from sklearn import datasets from sklearn import linear_model @@ -81,7 +78,7 @@ dict(opt_name='rprop', opt_kwargs=dict(learning_rate=1e-1)), dict(opt_name='sm3', opt_kwargs=dict(learning_rate=1.0)), dict(opt_name='yogi', opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name='polyak_sgd', opt_kwargs=dict(max_learning_rate=1.)) + dict(opt_name='polyak_sgd', opt_kwargs=dict(max_learning_rate=1.0)), ) @@ -111,8 +108,9 @@ def _setup_rosenbrock(dtype): final_params = jnp.array([a, a**2], dtype=dtype) def objective(params): - return (numerics.abs_sq(a - params[0]) + - b * numerics.abs_sq(params[1] - params[0]**2)) + return numerics.abs_sq(a - params[0]) + b * numerics.abs_sq( + params[1] - params[0] ** 2 + ) return initial_params, final_params, objective @@ -179,7 +177,8 @@ def step(params, state): @chex.all_variants @parameterized.product(_OPTIMIZERS_UNDER_TEST) def test_optimizers_can_be_wrapped_in_inject_hyperparams( - self, opt_name, opt_kwargs): + self, opt_name, opt_kwargs + ): """Checks that optimizers can be wrapped in inject_hyperparams.""" # See also https://github.com/google-deepmind/optax/issues/412. opt_factory = getattr(alias, opt_name) @@ -189,7 +188,8 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams( # argument to be specified in order to be jittable. See issue # https://github.com/google-deepmind/optax/issues/412. 
opt_inject = _inject.inject_hyperparams( - opt_factory, static_args=('min_dim_size_to_factor',))(**opt_kwargs) + opt_factory, static_args=('min_dim_size_to_factor',) + )(**opt_kwargs) else: opt_inject = _inject.inject_hyperparams(opt_factory)(**opt_kwargs) @@ -198,7 +198,7 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams( state = self.variant(opt.init)(params) if opt_name == 'polyak_sgd': - update_kwargs = {'value': jnp.array(0.)} + update_kwargs = {'value': jnp.array(0.0)} else: update_kwargs = {} updates, new_state = self.variant(opt.update)( @@ -207,13 +207,15 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams( state_inject = self.variant(opt_inject.init)(params) updates_inject, new_state_inject = self.variant(opt_inject.update)( - grads, state_inject, params, **update_kwargs) + grads, state_inject, params, **update_kwargs + ) with self.subTest('Equality of updates.'): chex.assert_trees_all_close(updates_inject, updates, rtol=1e-4) with self.subTest('Equality of new optimizer states.'): chex.assert_trees_all_close( - new_state_inject.inner_state, new_state, rtol=1e-4) + new_state_inject.inner_state, new_state, rtol=1e-4 + ) @parameterized.product( params_dtype=('bfloat16', 'float32', 'complex64', None), @@ -233,8 +235,7 @@ def test_explicit_dtype(self, params_dtype, state_dtype, opt_name): params_dtype = jax.dtypes.canonicalize_dtype(params_dtype) params = jnp.array([0.0, 0.0], dtype=params_dtype) state_has_lower_dtype = ( - jnp.promote_types(params_dtype, jnp.dtype(state_dtype)) - == params_dtype + jnp.promote_types(params_dtype, jnp.dtype(state_dtype)) == params_dtype ) if state_dtype is None or state_has_lower_dtype: state = opt.init(params) @@ -266,9 +267,7 @@ def test_explicit_dtype(self, params_dtype, state_dtype, opt_name): @chex.variants( with_jit=True, without_jit=True, with_device=True, with_pmap=True ) - @parameterized.product( - _OPTIMIZERS_UNDER_TEST, dtype=('bfloat16', 'float32') - ) + @parameterized.product(_OPTIMIZERS_UNDER_TEST, dtype=('bfloat16', 'float32')) def test_preserve_dtype(self, opt_name, opt_kwargs, dtype): """Test that the optimizers return updates of same dtype as params.""" # When debugging this test, note that operations like @@ -318,6 +317,7 @@ def test_gradient_accumulation(self, opt_name, opt_kwargs, dtype): updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs) chex.assert_trees_all_equal(updates, jnp.zeros_like(grads)) + ########################## # ALGORITHM SPECIFIC TESTS ########################## @@ -336,6 +336,7 @@ def _run_lbfgs_solver( ) -> tuple[chex.ArrayTree, base.OptState]: """Run LBFGS solver by iterative calls to grad transform and apply_updates.""" value_and_grad_fun = jax.value_and_grad(fun) + def stopping_criterion(carry): _, _, count, grad = carry return (otu.tree_l2_norm(grad) >= tol) & (count < maxiter) @@ -347,7 +348,7 @@ def step(carry): grad, state, params, value=value, grad=grad, value_fn=fun ) params = update.apply_updates(params, updates) - return params, state, count+1, grad + return params, state, count + 1, grad init_state = opt.init(init_params) init_grad = jax.grad(fun)(init_params) @@ -492,7 +493,7 @@ def _plain_lbfgs( dws = dws[1:] # Pop left. 
dus = dus[1:] - grad_norm = jnp.sqrt(jnp.sum(g ** 2)) + grad_norm = jnp.sqrt(jnp.sum(g**2)) if grad_norm <= tol: break @@ -729,13 +730,14 @@ def fun(x): sol_arr, _ = _run_lbfgs_solver(opt, fun, init_array, maxiter=3) sol_tree, _ = _run_lbfgs_solver(opt, fun, init_tree, maxiter=3) sol_tree = jnp.stack((sol_tree[0], sol_tree[1])) - chex.assert_trees_all_close(sol_arr, sol_tree, rtol=5*1e-5, atol=5*1e-5) + chex.assert_trees_all_close(sol_arr, sol_tree, rtol=5 * 1e-5, atol=5 * 1e-5) @parameterized.product(scale_init_precond=[True, False]) def test_multiclass_logreg(self, scale_init_precond): data = datasets.make_classification( n_samples=10, n_features=5, n_classes=3, n_informative=3, random_state=0 ) + def fun(params): inputs, labels = data weights, bias = params @@ -868,8 +870,10 @@ def test_steep_objective(self): tol = 1e-5 n = 2 mat = jnp.eye(n) * 1e4 + def fun(x): return jnp.mean((mat @ x) ** 2) + opt = alias.lbfgs() sol, _ = _run_lbfgs_solver(opt, fun, init_params=jnp.ones(n), tol=tol) chex.assert_trees_all_close(sol, jnp.zeros(n), atol=tol, rtol=tol) diff --git a/optax/_src/base.py b/optax/_src/base.py index 1ed155f9e..c56d3aa1f 100644 --- a/optax/_src/base.py +++ b/optax/_src/base.py @@ -14,7 +14,8 @@ # ============================================================================== """Base interfaces and datatypes.""" -from typing import Any, Callable, NamedTuple, Optional, Protocol, runtime_checkable, Sequence, Union +from collections.abc import Callable +from typing import Any, NamedTuple, Optional, Protocol, runtime_checkable, Sequence, Union import chex import jax diff --git a/optax/_src/factorized.py b/optax/_src/factorized.py index 2f040e66d..bd542b628 100644 --- a/optax/_src/factorized.py +++ b/optax/_src/factorized.py @@ -14,30 +14,29 @@ # ============================================================================== """Factorized optimizers.""" +from collections.abc import Callable import dataclasses -from typing import NamedTuple, Optional, Callable +from typing import NamedTuple, Optional import chex import jax import jax.numpy as jnp import numpy as np - from optax._src import base from optax._src import numerics + # pylint:disable=no-value-for-parameter def _decay_rate_pow(i: int, exponent: float = 0.8) -> chex.Array: """Second-order moment decay schedule.""" t = jnp.array(i + 1, jnp.float32) - return 1.0 - t**(-exponent) + return 1.0 - t ** (-exponent) def _factored_dims( - shape: base.Shape, - factored: bool, - min_dim_size_to_factor: int + shape: base.Shape, factored: bool, min_dim_size_to_factor: int ) -> Optional[tuple[int, int]]: """Whether to use a factored second moment estimator. @@ -47,8 +46,8 @@ def _factored_dims( Args: shape: an input shape factored: whether to use factored second-moment estimator for 2d vars. - min_dim_size_to_factor: only factor accumulator if two array dimensions - have at least this size. + min_dim_size_to_factor: only factor accumulator if two array dimensions have + at least this size. Returns: None or a tuple of ints @@ -64,6 +63,7 @@ def _factored_dims( @dataclasses.dataclass class _UpdateResult: """Opaque containter that is not traversed by jax.tree.map.""" + update: chex.Array # the update to apply to params v_row: chex.Array # used for factored params. v_col: chex.Array # used for factored params. @@ -72,6 +72,7 @@ class _UpdateResult: class FactoredState(NamedTuple): """Overall state of the gradient transformation.""" + count: chex.Array # number of update steps. v_row: chex.ArrayTree # Tree of factored params. 
v_col: chex.ArrayTree # Tree of factored params. @@ -84,7 +85,8 @@ def scale_by_factored_rms( step_offset: int = 0, min_dim_size_to_factor: int = 128, epsilon: float = 1e-30, - decay_rate_fn: Callable[[int, float], chex.Array] = _decay_rate_pow): + decay_rate_fn: Callable[[int, float], chex.Array] = _decay_rate_pow, +): """Scaling by a factored estimate of the gradient rms (as in Adafactor). This is a so-called "1+epsilon" scaling algorithms, that is extremely memory @@ -120,7 +122,8 @@ def _to_state(count: chex.Array, result_tree): count=count, v_row=jax.tree.map(lambda o: o.v_row, result_tree), v_col=jax.tree.map(lambda o: o.v_col, result_tree), - v=jax.tree.map(lambda o: o.v, result_tree)) + v=jax.tree.map(lambda o: o.v, result_tree), + ) def init_fn(params): """Initialise the optimiser's state.""" @@ -136,16 +139,17 @@ def _init(param): update=jnp.zeros((1,), dtype=dtype), v_row=jnp.zeros(vr_shape, dtype=dtype), v_col=jnp.zeros(vc_shape, dtype=dtype), - v=jnp.zeros((1,), dtype=dtype)) + v=jnp.zeros((1,), dtype=dtype), + ) else: return _UpdateResult( update=jnp.zeros((1,), dtype=dtype), v_row=jnp.zeros((1,), dtype=dtype), v_col=jnp.zeros((1,), dtype=dtype), - v=jnp.zeros(param.shape, dtype=dtype)) + v=jnp.zeros(param.shape, dtype=dtype), + ) - return _to_state( - jnp.zeros([], jnp.int32), jax.tree.map(_init, params)) + return _to_state(jnp.zeros([], jnp.int32), jax.tree.map(_init, params)) def update_fn(grads, state, params): """Apply gradient transformation.""" @@ -165,34 +169,40 @@ def _update(grad, v_row, v_col, v, param, step): if factored_dims is not None: d1, d0 = factored_dims grad_sqr = numerics.abs_sq(grad) + epsilon - new_v_row = ( - decay_rate_t * v_row + - (1. - decay_rate_t) * jnp.mean(grad_sqr, axis=d0)) - new_v_col = ( - decay_rate_t * v_col + - (1. - decay_rate_t) * jnp.mean(grad_sqr, axis=d1)) + new_v_row = decay_rate_t * v_row + (1.0 - decay_rate_t) * jnp.mean( + grad_sqr, axis=d0 + ) + new_v_col = decay_rate_t * v_col + (1.0 - decay_rate_t) * jnp.mean( + grad_sqr, axis=d1 + ) new_v_row = new_v_row.astype(dtype) new_v_col = new_v_col.astype(dtype) - reduced_d1 = d1-1 if d1 > d0 else d1 + reduced_d1 = d1 - 1 if d1 > d0 else d1 row_col_mean = jnp.mean(new_v_row, axis=reduced_d1, keepdims=True) row_factor = (new_v_row / row_col_mean) ** -0.5 col_factor = (new_v_col) ** -0.5 update = ( - grad * - jnp.expand_dims(row_factor, axis=d0) * - jnp.expand_dims(col_factor, axis=d1)) + grad + * jnp.expand_dims(row_factor, axis=d0) + * jnp.expand_dims(col_factor, axis=d1) + ) else: grad_sqr = numerics.abs_sq(grad) + epsilon - new_v = decay_rate_t * v + (1. - decay_rate_t) * grad_sqr + new_v = decay_rate_t * v + (1.0 - decay_rate_t) * grad_sqr new_v = new_v.astype(dtype) - update = grad * (new_v)**-0.5 + update = grad * (new_v) ** -0.5 return _UpdateResult(update, new_v_row, new_v_col, new_v) # Transform grad and compute new per-parameter stats. output = jax.tree.map( lambda *args: _update(*args, state.count), - grads, state.v_row, state.v_col, state.v, params) + grads, + state.v_row, + state.v_col, + state.v, + params, + ) # Unpack updates / stats and return. 
updates = jax.tree.map(lambda o: o.update, output) diff --git a/optax/_src/linear_algebra.py b/optax/_src/linear_algebra.py index a84bd3aa4..9bf482cbf 100644 --- a/optax/_src/linear_algebra.py +++ b/optax/_src/linear_algebra.py @@ -14,8 +14,9 @@ # ============================================================================== """Linear algebra utilities used in optimisation.""" +from collections.abc import Callable import functools -from typing import Callable, Optional, Union +from typing import Optional, Union import chex import jax @@ -33,8 +34,9 @@ def _normalize_tree(x): def global_norm(updates: base.PyTree) -> chex.Array: """Compute the global norm across a nested structure of tensors.""" - return jnp.sqrt(sum( - jnp.sum(numerics.abs_sq(x)) for x in jax.tree.leaves(updates))) + return jnp.sqrt( + sum(jnp.sum(numerics.abs_sq(x)) for x in jax.tree.leaves(updates)) + ) def _power_iteration_cond_fun(error_tolerance, num_iters, loop_vars): @@ -69,12 +71,12 @@ def power_iteration( Args: matrix: a square matrix, either as an array or a callable implementing a matrix-vector product. - v0: initial vector approximating the dominiant eigenvector. If ``matrix`` - is an array of size (n, n), v0 must be a vector of size (n,). If instead + v0: initial vector approximating the dominant eigenvector. If ``matrix`` is + an array of size (n, n), v0 must be a vector of size (n,). If instead ``matrix`` is a callable, then v0 must be a tree with the same structure as the input of this callable. If this argument is None and ``matrix`` is - an array, then a random vector sampled from a uniform distribution in - [-1, 1] is used as initial vector. + an array, then a random vector sampled from a uniform distribution in [-1, + 1] is used as initial vector. num_iters: Number of power iterations. error_tolerance: Iterative exit condition. The procedure stops when the relative error of the estimate of the dominant eigenvalue is below this @@ -83,8 +85,8 @@ def power_iteration( lax.Precision.DEFAULT (better step time, but not precise); b) lax.Precision.HIGH (increased precision, slower); c) lax.Precision.HIGHEST (best possible precision, slowest). - key: random key for the initialization of ``v0`` when not given - explicitly. When this argument is None, `jax.random.PRNGKey(0)` is used. + key: random key for the initialization of ``v0`` when not given explicitly. + When this argument is None, `jax.random.PRNGKey(0)` is used. Returns: A pair (eigenvalue, eigenvector), where eigenvalue is the dominant @@ -129,19 +131,21 @@ def _body_fun(loop_vars): return eigvec, z, eig, iter_num + 1 init_vars = (v0, mvp(v0), jnp.asarray(0.0), jnp.asarray(0)) - _, unormalized_eigenvector, dominant_eigenvalue, _ = ( - jax.lax.while_loop(cond_fun, _body_fun, init_vars) + _, unormalized_eigenvector, dominant_eigenvalue, _ = jax.lax.while_loop( + cond_fun, _body_fun, init_vars ) normalized_eigenvector = _normalize_tree(unormalized_eigenvector) return dominant_eigenvalue, normalized_eigenvector -def matrix_inverse_pth_root(matrix: chex.Array, - p: int, - num_iters: int = 100, - ridge_epsilon: float = 1e-6, - error_tolerance: float = 1e-6, - precision: lax.Precision = lax.Precision.HIGHEST): +def matrix_inverse_pth_root( + matrix: chex.Array, + p: int, + num_iters: int = 100, + ridge_epsilon: float = 1e-6, + error_tolerance: float = 1e-6, + precision: lax.Precision = lax.Precision.HIGHEST, +): """Computes `matrix^(-1/p)`, where `p` is a positive integer.
This function uses the Coupled newton iterations algorithm for @@ -159,10 +163,10 @@ def matrix_inverse_pth_root(matrix: chex.Array, num_iters: Maximum number of iterations. ridge_epsilon: Ridge epsilon added to make the matrix positive definite. error_tolerance: Error indicator, useful for early termination. - precision: precision XLA related flag, the available options are: - a) lax.Precision.DEFAULT (better step time, but not precise); - b) lax.Precision.HIGH (increased precision, slower); - c) lax.Precision.HIGHEST (best possible precision, slowest). + precision: precision XLA related flag, the available options are: a) + lax.Precision.DEFAULT (better step time, but not precise); b) + lax.Precision.HIGH (increased precision, slower); c) lax.Precision.HIGHEST + (best possible precision, slowest). Returns: matrix^(-1/p) @@ -174,8 +178,8 @@ def matrix_inverse_pth_root(matrix: chex.Array, alpha = jnp.asarray(-1.0 / p, jnp.float32) identity = jnp.eye(matrix_size, dtype=jnp.float32) max_ev, _ = power_iteration( - matrix=matrix, num_iters=100, - error_tolerance=1e-6, precision=precision) + matrix=matrix, num_iters=100, error_tolerance=1e-6, precision=precision + ) ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16) def _unrolled_mat_pow_1(mat_m): @@ -189,14 +193,12 @@ def _unrolled_mat_pow_2(mat_m): def _unrolled_mat_pow_4(mat_m): """Computes mat_m^4.""" mat_pow_2 = _unrolled_mat_pow_2(mat_m) - return jnp.matmul( - mat_pow_2, mat_pow_2, precision=precision) + return jnp.matmul(mat_pow_2, mat_pow_2, precision=precision) def _unrolled_mat_pow_8(mat_m): """Computes mat_m^8.""" mat_pow_4 = _unrolled_mat_pow_4(mat_m) - return jnp.matmul( - mat_pow_4, mat_pow_4, precision=precision) + return jnp.matmul(mat_pow_4, mat_pow_4, precision=precision) def mat_power(mat_m, p): """Computes mat_m^p, for p == 1, 2, 4 or 8. @@ -211,18 +213,19 @@ def mat_power(mat_m, p): # We unrolled the loop for performance reasons. exponent = jnp.round(jnp.log2(p)) return lax.switch( - jnp.asarray(exponent, jnp.int32), [ + jnp.asarray(exponent, jnp.int32), + [ _unrolled_mat_pow_1, _unrolled_mat_pow_2, _unrolled_mat_pow_4, _unrolled_mat_pow_8, - ], (mat_m)) + ], + (mat_m), + ) def _iter_condition(state): - (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, - run_step) = state - error_above_threshold = jnp.logical_and( - error > error_tolerance, run_step) + (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, run_step) = state + error_above_threshold = jnp.logical_and(error > error_tolerance, run_step) return jnp.logical_and(i < num_iters, error_above_threshold) def _iter_body(state): @@ -233,11 +236,17 @@ def _iter_body(state): new_error = jnp.max(jnp.abs(new_mat_m - identity)) # sometimes error increases after an iteration before decreasing and # converging. 1.2 factor is used to bound the maximal allowed increase.
- return (i + 1, new_mat_m, new_mat_h, mat_h, new_error, - new_error < error * 1.2) + return ( + i + 1, + new_mat_m, + new_mat_h, + mat_h, + new_error, + new_error < error * 1.2, + ) if matrix_size == 1: - resultant_mat_h = (matrix + ridge_epsilon)**alpha + resultant_mat_h = (matrix + ridge_epsilon) ** alpha error = 0 else: damped_matrix = matrix + ridge_epsilon * identity @@ -247,9 +256,11 @@ def _iter_body(state): new_error = jnp.max(jnp.abs(new_mat_m_0 - identity)) new_mat_h_0 = identity * jnp.power(z, 1.0 / p) init_state = tuple( - [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True]) + [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True] + ) _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop( - _iter_condition, _iter_body, init_state) + _iter_condition, _iter_body, init_state + ) error = jnp.max(jnp.abs(mat_m - identity)) is_converged = jnp.asarray(convergence, old_mat_h.dtype) resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h diff --git a/optax/_src/linesearch.py b/optax/_src/linesearch.py index aef22cf94..38f88b5c8 100644 --- a/optax/_src/linesearch.py +++ b/optax/_src/linesearch.py @@ -14,8 +14,9 @@ # ============================================================================== """Line-searches.""" +from collections.abc import Callable import functools -from typing import Any, Callable, NamedTuple, Optional, Union +from typing import Any, NamedTuple, Optional, Union import chex import jax diff --git a/optax/_src/linesearch_test.py b/optax/_src/linesearch_test.py index 640dfdc0b..f9d040424 100644 --- a/optax/_src/linesearch_test.py +++ b/optax/_src/linesearch_test.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for `linesearch.py`.""" +"""Tests for methods defined in `linesearch.py`.""" +from collections.abc import Callable import contextlib import functools import io import itertools import math -from typing import Callable, Optional +from typing import Optional from absl.testing import absltest from absl.testing import parameterized @@ -375,7 +376,7 @@ def _check_linesearch_conditions( default_opt_args = dict( slope_rtol=1e-4, curv_rtol=0.9, - tol=0., + tol=0.0, ) opt_args = default_opt_args | opt_args slope_rtol, curv_rtol, tol = ( @@ -384,16 +385,16 @@ def _check_linesearch_conditions( opt_args['tol'], ) with self.subTest('Check decrease conditions'): - sufficient_decrease_error = ( - value_final - (value_init + slope_rtol * final_lr * slope_init + tol) + sufficient_decrease_error = value_final - ( + value_init + slope_rtol * final_lr * slope_init + tol ) self.assertTrue( (sufficient_decrease_error <= 0) or potentially_failed, f'Sufficent decrease error: {sufficient_decrease_error}', ) with self.subTest('Check curvature conditions'): - small_curvature_error = ( - jnp.abs(slope_final) - (curv_rtol * jnp.abs(slope_init) + tol) + small_curvature_error = jnp.abs(slope_final) - ( + curv_rtol * jnp.abs(slope_init) + tol ) self.assertTrue( @@ -428,9 +429,14 @@ def test_linesearch_with_jax_variants(self): @parameterized.product( problem_name=[ - 'polynomial', 'exponential', 'sinusoidal', - 'rosenbrock', 'himmelblau', 'matyas', 'eggholder' - ], + 'polynomial', + 'exponential', + 'sinusoidal', + 'rosenbrock', + 'himmelblau', + 'matyas', + 'eggholder', + ], seed=[0, 1], ) def test_linesearch(self, problem_name: str, seed: int): @@ -438,7 +444,7 @@ def test_linesearch(self, 
problem_name: str, seed: int): # Fixed tolerances, we check the behavior in standard conditions slope_rtol = 1e-4 curv_rtol = 0.9 - tol = 0. + tol = 0.0 key = jrd.PRNGKey(seed) params_key, precond_key = jrd.split(key, 2) @@ -450,14 +456,14 @@ def test_linesearch(self, problem_name: str, seed: int): # Mimics a preconditioning by a diagonal matrix with non-negative entries # (non-negativity ensures that we keep a descent direction) - init_updates = -precond_vec*jax.grad(fn)(init_params) + init_updates = -precond_vec * jax.grad(fn)(init_params) opt_args = dict( max_linesearch_steps=30, slope_rtol=slope_rtol, curv_rtol=curv_rtol, tol=tol, - max_learning_rate=None + max_learning_rate=None, ) opt = _linesearch.scale_by_zoom_linesearch(**opt_args) @@ -526,6 +532,7 @@ def test_failure_too_small_max_stepsize(self): if jax.default_backend() in ['tpu', 'gpu']: return else: + def fn(x): return jnp.dot(x, x) @@ -561,6 +568,7 @@ def test_failure_not_enough_iter(self): if jax.default_backend() in ['tpu', 'gpu']: return else: + def fn(x): return jnp.dot(x, x) @@ -611,6 +619,7 @@ def test_failure_flat_fun(self): if jax.default_backend() in ['tpu', 'gpu']: return else: + def fun_flat(x): return jnp.exp(-1 / x**2) @@ -634,6 +643,7 @@ def test_failure_inf_value(self): if jax.default_backend() in ['tpu', 'gpu']: return else: + def fun_inf(x): return jnp.log(x) @@ -663,13 +673,12 @@ def fn(x): u = -1.95 * w opt = _linesearch.scale_by_zoom_linesearch(max_linesearch_steps=20) - _, final_state = _run_linesearch( - opt, fn, w, u, stepsize_guess=1.0 - ) + _, final_state = _run_linesearch(opt, fn, w, u, stepsize_guess=1.0) decrease_error = otu.tree_get(final_state, 'decrease_error') curvature_error = otu.tree_get(final_state, 'curvature_error') success = (decrease_error <= 0.0) and (curvature_error <= 0.0) self.assertTrue(success) + if __name__ == '__main__': absltest.main() diff --git a/optax/_src/utils.py b/optax/_src/utils.py index 86e81dea2..5a3768367 100644 --- a/optax/_src/utils.py +++ b/optax/_src/utils.py @@ -14,10 +14,11 @@ # ============================================================================== """Utility functions for testing.""" +from collections.abc import Callable import functools import inspect import operator -from typing import Any, Callable, Optional, Sequence, Union +from typing import Any, Optional, Sequence, Union import chex from etils import epy @@ -28,6 +29,7 @@ from optax._src import linear_algebra from optax._src import numerics + with epy.lazy_imports(): import jax.scipy.stats.norm as multivariate_normal # pylint: disable=g-import-not-at-top,ungrouped-imports diff --git a/optax/_src/wrappers.py b/optax/_src/wrappers.py index 262c80a68..aaaa157ae 100644 --- a/optax/_src/wrappers.py +++ b/optax/_src/wrappers.py @@ -14,12 +14,11 @@ # ============================================================================== """Transformation wrappers.""" +from collections.abc import Callable import functools -from typing import Callable import chex import jax.numpy as jnp - from optax._src import base from optax.transforms import _accumulation from optax.transforms import _conditionality diff --git a/optax/contrib/_acprop.py b/optax/contrib/_acprop.py index a5b6788b1..2fe626644 100644 --- a/optax/contrib/_acprop.py +++ b/optax/contrib/_acprop.py @@ -18,7 +18,8 @@ Asynchronous Update for Adaptive Gradient Methods" by Zhuang et al. (https://arxiv.org/abs/2110.05454). 
""" -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, Optional, Union import jax import jax.numpy as jnp diff --git a/optax/contrib/_cocob.py b/optax/contrib/_cocob.py index 7fcb77a27..8f05c1312 100644 --- a/optax/contrib/_cocob.py +++ b/optax/contrib/_cocob.py @@ -18,7 +18,8 @@ Networks without Learning Rates Through Coin Betting" by Francesco Orabona and Tatiana Tommasi. """ -from typing import Any, Callable, NamedTuple, Optional, Union +from collections.abc import Callable +from typing import Any, NamedTuple, Optional, Union import jax import jax.numpy as jnp diff --git a/optax/contrib/_dog.py b/optax/contrib/_dog.py index b3967141f..79966fbf4 100644 --- a/optax/contrib/_dog.py +++ b/optax/contrib/_dog.py @@ -22,7 +22,8 @@ Gradient Descent Method`_, 2023. """ -from typing import Any, Callable, NamedTuple, Optional, Union +from collections.abc import Callable +from typing import Any, NamedTuple, Optional, Union import chex import jax diff --git a/optax/contrib/_sam.py b/optax/contrib/_sam.py index 41ae73828..5bfdce96d 100644 --- a/optax/contrib/_sam.py +++ b/optax/contrib/_sam.py @@ -47,7 +47,8 @@ """ # pytype: skip-file -from typing import Callable, Optional +from collections.abc import Callable +from typing import Optional import chex import jax import jax.numpy as jnp diff --git a/optax/losses/_ranking.py b/optax/losses/_ranking.py index 7138c93e5..c451296b9 100644 --- a/optax/losses/_ranking.py +++ b/optax/losses/_ranking.py @@ -47,7 +47,8 @@ [-0.755, 0.09, 0.665] """ -from typing import Callable, Optional +from collections.abc import Callable +from typing import Optional import chex import jax diff --git a/optax/monte_carlo/control_variates.py b/optax/monte_carlo/control_variates.py index 95d5261f2..58f6a7247 100644 --- a/optax/monte_carlo/control_variates.py +++ b/optax/monte_carlo/control_variates.py @@ -53,7 +53,8 @@ For examples, see `control_delta_method` and `moving_avg_baseline`. """ -from typing import Any, Callable, Sequence +from collections.abc import Callable +from typing import Any, Sequence import chex import jax diff --git a/optax/monte_carlo/stochastic_gradient_estimators.py b/optax/monte_carlo/stochastic_gradient_estimators.py index fd2ba605d..ae76a2af5 100644 --- a/optax/monte_carlo/stochastic_gradient_estimators.py +++ b/optax/monte_carlo/stochastic_gradient_estimators.py @@ -28,8 +28,9 @@ Monte Carlo Gradient Estimation in Machine Learning. JMLR, 2020. 
""" +from collections.abc import Callable import math -from typing import Any, Callable, Sequence +from typing import Any, Sequence import chex import jax diff --git a/optax/schedules/_inject.py b/optax/schedules/_inject.py index 52521b4ed..43bd2feef 100644 --- a/optax/schedules/_inject.py +++ b/optax/schedules/_inject.py @@ -14,9 +14,10 @@ # ============================================================================== """Utilities to inject dynamically changing hyper-parameters.""" +from collections.abc import Callable import functools import inspect -from typing import Callable, Iterable, NamedTuple, Optional, Union +from typing import Iterable, NamedTuple, Optional, Union import warnings import chex diff --git a/optax/transforms/_accumulation.py b/optax/transforms/_accumulation.py index b73b35dfe..3f182bf4e 100644 --- a/optax/transforms/_accumulation.py +++ b/optax/transforms/_accumulation.py @@ -14,7 +14,8 @@ # ============================================================================== """Gradient transformations for accumulating gradients across updates.""" -from typing import Any, Callable, NamedTuple, Optional, Protocol, Union +from collections.abc import Callable +from typing import Any, NamedTuple, Optional, Protocol, Union import chex import jax diff --git a/optax/transforms/_adding.py b/optax/transforms/_adding.py index 578068599..655ec807d 100644 --- a/optax/transforms/_adding.py +++ b/optax/transforms/_adding.py @@ -14,7 +14,8 @@ # ============================================================================== """Additive components in gradient transformations.""" -from typing import Any, Callable, NamedTuple, Optional, Union +from collections.abc import Callable +from typing import Any, NamedTuple, Optional, Union import chex import jax diff --git a/optax/transforms/_combining.py b/optax/transforms/_combining.py index 5a74a65a5..a4fb7c084 100644 --- a/optax/transforms/_combining.py +++ b/optax/transforms/_combining.py @@ -14,7 +14,8 @@ # ============================================================================== """Flexibly compose gradient transformations.""" -from typing import Callable, NamedTuple, Union, Mapping, Hashable +from collections.abc import Callable, Hashable, Mapping +from typing import NamedTuple, Union import jax diff --git a/optax/transforms/_masking.py b/optax/transforms/_masking.py index 82f34e7f6..754d62ed5 100644 --- a/optax/transforms/_masking.py +++ b/optax/transforms/_masking.py @@ -14,7 +14,8 @@ # ============================================================================== """Wrappers that mask out part of the parameters when applying a transform.""" -from typing import Any, Callable, NamedTuple, Union +from collections.abc import Callable +from typing import Any, NamedTuple, Union import jax diff --git a/optax/tree_utils/_random.py b/optax/tree_utils/_random.py index b4f875017..db541af03 100644 --- a/optax/tree_utils/_random.py +++ b/optax/tree_utils/_random.py @@ -14,7 +14,8 @@ # ============================================================================== """Utilities to generate random pytrees.""" -from typing import Callable, Optional +from collections.abc import Callable +from typing import Optional import chex import jax diff --git a/optax/tree_utils/_random_test.py b/optax/tree_utils/_random_test.py index 9dfb0a073..8995c5329 100644 --- a/optax/tree_utils/_random_test.py +++ b/optax/tree_utils/_random_test.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Tests for optax.tree_utils._random.""" +"""Tests for methods defined in optax.tree_utils._random.""" -from typing import Callable +from collections.abc import Callable from absl.testing import absltest from absl.testing import parameterized diff --git a/optax/tree_utils/_state_utils.py b/optax/tree_utils/_state_utils.py index bdc07cf17..08f018391 100644 --- a/optax/tree_utils/_state_utils.py +++ b/optax/tree_utils/_state_utils.py @@ -14,10 +14,11 @@ # ============================================================================== """Tools for mapping over optimizer states.""" +from collections.abc import Callable import dataclasses import functools import typing -from typing import Any, Callable, Optional, Protocol, Tuple, Union, cast +from typing import Any, Optional, Protocol, Tuple, Union, cast import jax from optax._src import base @@ -59,6 +60,7 @@ class NamedTupleKey: .. versionadded:: 0.2.2 """ + tuple_name: str name: str @@ -197,8 +199,10 @@ def tree_get_all_with_path( ... *[(jax.tree_util.keystr(p), v) for p, v in found_values_with_path], ... sep="\n", ... ) - ("InjectStatefulHyperparamsState.hyperparams['learning_rate']", Array(1., dtype=float32)) - ("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", WrappedScheduleState(count=Array(0, dtype=int32))) + ("InjectStatefulHyperparamsState.hyperparams['learning_rate']", Array(1., + dtype=float32)) + ("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", + WrappedScheduleState(count=Array(0, dtype=int32))) Usage with a filtering operation @@ -217,7 +221,8 @@ def tree_get_all_with_path( ... *[(jax.tree_util.keystr(p), v) for p, v in found_values_with_path], ... sep="\n", ... ) - ("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", WrappedScheduleState(count=Array(0, dtype=int32))) + ("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", + WrappedScheduleState(count=Array(0, dtype=int32))) .. seealso:: :func:`optax.tree_utils.tree_get`, :func:`optax.tree_utils.tree_set` @@ -226,10 +231,10 @@ def tree_get_all_with_path( tree: tree to search in. key: keyword or field to search in tree for. filtering: optional callable to further filter values in tree that match the - key. ``filtering(path: Key_Path, value: Any) -> bool: ...`` - takes as arguments both the path to the value (as returned by - :func:`optax.tree_utils.tree_get_all_with_path`) and the - value that match the given key. + key. ``filtering(path: Key_Path, value: Any) -> bool: ...`` takes as + arguments both the path to the value (as returned by + :func:`optax.tree_utils.tree_get_all_with_path`) and the value that match + the given key. Returns: values_with_path @@ -272,7 +277,7 @@ def tree_get( Raises a ``KeyError`` if multiple values of ``key`` are found in ``tree``. Generally, you may first get all pairs ``(path_to_value, value)`` for a given - ``key`` using :func:`optax.tree_utils.tree_get_all_with_path`. You may then + ``key`` using :func:`optax.tree_utils.tree_get_all_with_path`. 
You may then define a filtering operation ``filtering(path: Key_Path, value: Any) -> bool: ...`` that enables you to select the specific values you wanted to fetch by looking at the type of the @@ -330,7 +335,8 @@ def tree_get( >>> state = opt.init(params) >>> noise_state = optax.tree_utils.tree_get(state, 'AddNoiseState') >>> print(noise_state) - AddNoiseState(count=Array(0, dtype=int32), rng_key=Array([0, 0], dtype=uint32)) + AddNoiseState(count=Array(0, dtype=int32), rng_key=Array([0, 0], + dtype=uint32)) Differentiating between two values by the name of their named tuples. @@ -354,10 +360,10 @@ def tree_get( key: keyword or field to search in ``tree`` for. default: default value to return if ``key`` is not found in ``tree``. filtering: optional callable to further filter values in ``tree`` that match - the ``key``. ``filtering(path: Key_Path, value: Any) -> bool: ...`` - takes as arguments both the path to the value (as returned by - :func:`optax.tree_utils.tree_get_all_with_path`) and the - value that match the given key. + the ``key``. ``filtering(path: Key_Path, value: Any) -> bool: ...`` takes + as arguments both the path to the value (as returned by + :func:`optax.tree_utils.tree_get_all_with_path`) and the value that match + the given key. Returns: value @@ -412,10 +418,12 @@ def tree_set( >>> opt = optax.adam(learning_rate=1.) >>> state = opt.init(params) >>> print(state) - (ScaleByAdamState(count=Array(0, dtype=int32), mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState()) + (ScaleByAdamState(count=Array(0, dtype=int32), mu=Array([0., 0., 0.], + dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState()) >>> new_state = optax.tree_utils.tree_set(state, count=2.) >>> print(new_state) - (ScaleByAdamState(count=2.0, mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState()) + (ScaleByAdamState(count=2.0, mu=Array([0., 0., 0.], dtype=float32), + nu=Array([0., 0., 0.], dtype=float32)), EmptyState()) Usage with a filtering operation @@ -427,13 +435,19 @@ def tree_set( ... ) >>> state = opt.init(params) >>> print(state) - InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(1., dtype=float32)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState())) + InjectStatefulHyperparamsState(count=Array(0, dtype=int32), + hyperparams={'learning_rate': Array(1., dtype=float32)}, + hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, + dtype=int32))}, inner_state=(EmptyState(), EmptyState())) >>> filtering = lambda path, value: isinstance(value, jnp.ndarray) >>> new_state = optax.tree_utils.tree_set( ... state, filtering, learning_rate=jnp.asarray(0.1) ... ) >>> print(new_state) - InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(0.1, dtype=float32, weak_type=True)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState())) + InjectStatefulHyperparamsState(count=Array(0, dtype=int32), + hyperparams={'learning_rate': Array(0.1, dtype=float32, weak_type=True)}, + hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, + dtype=int32))}, inner_state=(EmptyState(), EmptyState())) .. 
seealso:: :func:`optax.tree_utils.tree_get_all_with_path`, :func:`optax.tree_utils.tree_get` @@ -441,11 +455,10 @@ def tree_set( Args: tree: pytree whose values are to be replaced. filtering: optional callable to further filter values in ``tree`` that match - the keys to replace. - ``filtering(path: Key_Path, value: Any) -> bool: ...`` - takes as arguments both the path to the value (as returned by - :func:`optax.tree_utils.tree_get_all_with_path`) and the - value that match a given key. + the keys to replace. ``filtering(path: Key_Path, value: Any) -> bool: + ...`` takes as arguments both the path to the value (as returned by + :func:`optax.tree_utils.tree_get_all_with_path`) and the value that match + a given key. **kwargs: dictionary of keys with values to replace in ``tree``. Returns: