From 1cde7043a965eef4c4f3817006159d16dd9fae3e Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Mon, 23 Sep 2024 13:22:48 -0400 Subject: [PATCH] Replace deprecated typing.Hashable, typing.Callable, and typing.Mapping with equivalents from collections.abc. --- docs/ext/coverage_check.py | 3 ++- examples/cifar10_resnet.ipynb | 3 ++- examples/meta_learning.ipynb | 3 ++- optax/_src/alias.py | 3 ++- optax/_src/alias_test.py | 3 ++- optax/_src/base.py | 3 ++- optax/_src/factorized.py | 3 ++- optax/_src/linear_algebra.py | 3 ++- optax/_src/linesearch.py | 3 ++- optax/_src/linesearch_test.py | 3 ++- optax/_src/utils.py | 3 ++- optax/_src/wrappers.py | 2 +- optax/contrib/_acprop.py | 3 ++- optax/contrib/_cocob.py | 3 ++- optax/contrib/_dog.py | 3 ++- optax/contrib/_sam.py | 3 ++- optax/losses/_ranking.py | 3 ++- optax/monte_carlo/control_variates.py | 3 ++- optax/monte_carlo/stochastic_gradient_estimators.py | 3 ++- optax/schedules/_inject.py | 3 ++- optax/transforms/_accumulation.py | 3 ++- optax/transforms/_adding.py | 3 ++- optax/transforms/_combining.py | 3 ++- optax/transforms/_masking.py | 3 ++- optax/tree_utils/_random.py | 3 ++- optax/tree_utils/_random_test.py | 2 +- optax/tree_utils/_state_utils.py | 3 ++- 27 files changed, 52 insertions(+), 27 deletions(-) diff --git a/docs/ext/coverage_check.py b/docs/ext/coverage_check.py index 4046d212..0626b64a 100644 --- a/docs/ext/coverage_check.py +++ b/docs/ext/coverage_check.py @@ -16,7 +16,8 @@ import inspect import types -from typing import Any, Mapping, Sequence, Tuple +from collections.abc import Mapping +from typing import Any, Sequence, Tuple import optax from sphinx import application diff --git a/examples/cifar10_resnet.ipynb b/examples/cifar10_resnet.ipynb index 7b5706d4..a9f9374e 100644 --- a/examples/cifar10_resnet.ipynb +++ b/examples/cifar10_resnet.ipynb @@ -44,7 +44,8 @@ ], "source": [ "import functools\n", - "from typing import Any, Callable, Sequence, Tuple, Optional, Dict\n", + "from collections.abc import Callable\n", + "from typing import Any, Sequence, Tuple, Optional, Dict\n", "\n", "from flax import linen as nn\n", "\n", diff --git a/examples/meta_learning.ipynb b/examples/meta_learning.ipynb index e30e4292..8df9648c 100644 --- a/examples/meta_learning.ipynb +++ b/examples/meta_learning.ipynb @@ -45,7 +45,8 @@ }, "outputs": [], "source": [ - "from typing import Callable, Iterator, Tuple\n", + "from collections.abc import Callable\n", + "from typing import Iterator, Tuple\n", "import chex\n", "import jax\n", "import jax.numpy as jnp\n", diff --git a/optax/_src/alias.py b/optax/_src/alias.py index c7660298..90abc698 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -15,7 +15,8 @@ """Aliases for popular optimizers.""" import functools -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, Optional, Union import jax.numpy as jnp diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index a2444c68..d00c34f5 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -14,7 +14,8 @@ # ============================================================================== """Tests for `alias.py`.""" -from typing import Any, Callable, Union +from collections.abc import Callable +from typing import Any, Union from absl.testing import absltest from absl.testing import parameterized diff --git a/optax/_src/base.py b/optax/_src/base.py index 1ed155f9..c56d3aa1 100644 --- a/optax/_src/base.py +++ b/optax/_src/base.py @@ -14,7 +14,8 @@ # ============================================================================== """Base interfaces and datatypes.""" -from typing import Any, Callable, NamedTuple, Optional, Protocol, runtime_checkable, Sequence, Union +from collections.abc import Callable +from typing import Any, NamedTuple, Optional, Protocol, runtime_checkable, Sequence, Union import chex import jax diff --git a/optax/_src/factorized.py b/optax/_src/factorized.py index 7379c9ac..22ce7eef 100644 --- a/optax/_src/factorized.py +++ b/optax/_src/factorized.py @@ -15,7 +15,8 @@ """Factorized optimizers.""" import dataclasses -from typing import NamedTuple, Optional, Callable +from collections.abc import Callable +from typing import NamedTuple, Optional import chex import jax diff --git a/optax/_src/linear_algebra.py b/optax/_src/linear_algebra.py index a84bd3aa..cc2412cc 100644 --- a/optax/_src/linear_algebra.py +++ b/optax/_src/linear_algebra.py @@ -15,7 +15,8 @@ """Linear algebra utilities used in optimisation.""" import functools -from typing import Callable, Optional, Union +from collections.abc import Callable +from typing import Optional, Union import chex import jax diff --git a/optax/_src/linesearch.py b/optax/_src/linesearch.py index aef22cf9..94cf1078 100644 --- a/optax/_src/linesearch.py +++ b/optax/_src/linesearch.py @@ -15,7 +15,8 @@ """Line-searches.""" import functools -from typing import Any, Callable, NamedTuple, Optional, Union +from collections.abc import Callable +from typing import Any, NamedTuple, Optional, Union import chex import jax diff --git a/optax/_src/linesearch_test.py b/optax/_src/linesearch_test.py index 640dfdc0..b63f4ac0 100644 --- a/optax/_src/linesearch_test.py +++ b/optax/_src/linesearch_test.py @@ -19,7 +19,8 @@ import io import itertools import math -from typing import Callable, Optional +from collections.abc import Callable +from typing import Optional from absl.testing import absltest from absl.testing import parameterized diff --git a/optax/_src/utils.py b/optax/_src/utils.py index 86e81dea..ef7f37a5 100644 --- a/optax/_src/utils.py +++ b/optax/_src/utils.py @@ -17,7 +17,8 @@ import functools import inspect import operator -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Callable +from typing import Any, Optional, Sequence, Union import chex from etils import epy diff --git a/optax/_src/wrappers.py b/optax/_src/wrappers.py index 262c80a6..f1e57237 100644 --- a/optax/_src/wrappers.py +++ b/optax/_src/wrappers.py @@ -15,7 +15,7 @@ """Transformation wrappers.""" import functools -from typing import Callable +from collections.abc import Callable import chex import jax.numpy as jnp diff --git a/optax/contrib/_acprop.py b/optax/contrib/_acprop.py index a5b6788b..2fe62664 100644 --- a/optax/contrib/_acprop.py +++ b/optax/contrib/_acprop.py @@ -18,7 +18,8 @@ Asynchronous Update for Adaptive Gradient Methods" by Zhuang et al. (https://arxiv.org/abs/2110.05454). """ -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, Optional, Union import jax import jax.numpy as jnp diff --git a/optax/contrib/_cocob.py b/optax/contrib/_cocob.py index edf13a02..ad5cdbab 100644 --- a/optax/contrib/_cocob.py +++ b/optax/contrib/_cocob.py @@ -18,7 +18,8 @@ Networks without Learning Rates Through Coin Betting" by Francesco Orabona and Tatiana Tommasi. """ -from typing import Any, Callable, NamedTuple, Optional, Union +from collections.abc import Callable +from typing import Any, NamedTuple, Optional, Union import jax import jax.numpy as jnp diff --git a/optax/contrib/_dog.py b/optax/contrib/_dog.py index c8249ab8..e46be349 100644 --- a/optax/contrib/_dog.py +++ b/optax/contrib/_dog.py @@ -22,7 +22,8 @@ Gradient Descent Method`_, 2023. """ -from typing import Any, Callable, NamedTuple, Optional, Union +from collections.abc import Callable +from typing import Any, NamedTuple, Optional, Union import chex import jax diff --git a/optax/contrib/_sam.py b/optax/contrib/_sam.py index 41ae7382..5bfdce96 100644 --- a/optax/contrib/_sam.py +++ b/optax/contrib/_sam.py @@ -47,7 +47,8 @@ """ # pytype: skip-file -from typing import Callable, Optional +from collections.abc import Callable +from typing import Optional import chex import jax import jax.numpy as jnp diff --git a/optax/losses/_ranking.py b/optax/losses/_ranking.py index 7138c93e..c451296b 100644 --- a/optax/losses/_ranking.py +++ b/optax/losses/_ranking.py @@ -47,7 +47,8 @@ [-0.755, 0.09, 0.665] """ -from typing import Callable, Optional +from collections.abc import Callable +from typing import Optional import chex import jax diff --git a/optax/monte_carlo/control_variates.py b/optax/monte_carlo/control_variates.py index 95d5261f..58f6a724 100644 --- a/optax/monte_carlo/control_variates.py +++ b/optax/monte_carlo/control_variates.py @@ -53,7 +53,8 @@ For examples, see `control_delta_method` and `moving_avg_baseline`. """ -from typing import Any, Callable, Sequence +from collections.abc import Callable +from typing import Any, Sequence import chex import jax diff --git a/optax/monte_carlo/stochastic_gradient_estimators.py b/optax/monte_carlo/stochastic_gradient_estimators.py index fd2ba605..b45d3bc4 100644 --- a/optax/monte_carlo/stochastic_gradient_estimators.py +++ b/optax/monte_carlo/stochastic_gradient_estimators.py @@ -29,7 +29,8 @@ """ import math -from typing import Any, Callable, Sequence +from collections.abc import Callable +from typing import Any, Sequence import chex import jax diff --git a/optax/schedules/_inject.py b/optax/schedules/_inject.py index 52521b4e..5ec61e38 100644 --- a/optax/schedules/_inject.py +++ b/optax/schedules/_inject.py @@ -16,7 +16,8 @@ import functools import inspect -from typing import Callable, Iterable, NamedTuple, Optional, Union +from collections.abc import Callable +from typing import Iterable, NamedTuple, Optional, Union import warnings import chex diff --git a/optax/transforms/_accumulation.py b/optax/transforms/_accumulation.py index 25ae2123..1d902af2 100644 --- a/optax/transforms/_accumulation.py +++ b/optax/transforms/_accumulation.py @@ -14,7 +14,8 @@ # ============================================================================== """Gradient transformations for accumulating gradients across updates.""" -from typing import Any, Callable, NamedTuple, Optional, Protocol, Union +from collections.abc import Callable +from typing import Any, NamedTuple, Optional, Protocol, Union import chex import jax diff --git a/optax/transforms/_adding.py b/optax/transforms/_adding.py index 57806859..655ec807 100644 --- a/optax/transforms/_adding.py +++ b/optax/transforms/_adding.py @@ -14,7 +14,8 @@ # ============================================================================== """Additive components in gradient transformations.""" -from typing import Any, Callable, NamedTuple, Optional, Union +from collections.abc import Callable +from typing import Any, NamedTuple, Optional, Union import chex import jax diff --git a/optax/transforms/_combining.py b/optax/transforms/_combining.py index 5a74a65a..a4fb7c08 100644 --- a/optax/transforms/_combining.py +++ b/optax/transforms/_combining.py @@ -14,7 +14,8 @@ # ============================================================================== """Flexibly compose gradient transformations.""" -from typing import Callable, NamedTuple, Union, Mapping, Hashable +from collections.abc import Callable, Hashable, Mapping +from typing import NamedTuple, Union import jax diff --git a/optax/transforms/_masking.py b/optax/transforms/_masking.py index 82f34e7f..754d62ed 100644 --- a/optax/transforms/_masking.py +++ b/optax/transforms/_masking.py @@ -14,7 +14,8 @@ # ============================================================================== """Wrappers that mask out part of the parameters when applying a transform.""" -from typing import Any, Callable, NamedTuple, Union +from collections.abc import Callable +from typing import Any, NamedTuple, Union import jax diff --git a/optax/tree_utils/_random.py b/optax/tree_utils/_random.py index b4f87501..db541af0 100644 --- a/optax/tree_utils/_random.py +++ b/optax/tree_utils/_random.py @@ -14,7 +14,8 @@ # ============================================================================== """Utilities to generate random pytrees.""" -from typing import Callable, Optional +from collections.abc import Callable +from typing import Optional import chex import jax diff --git a/optax/tree_utils/_random_test.py b/optax/tree_utils/_random_test.py index 9dfb0a07..55569e80 100644 --- a/optax/tree_utils/_random_test.py +++ b/optax/tree_utils/_random_test.py @@ -14,7 +14,7 @@ # ============================================================================== """Tests for optax.tree_utils._random.""" -from typing import Callable +from collections.abc import Callable from absl.testing import absltest from absl.testing import parameterized diff --git a/optax/tree_utils/_state_utils.py b/optax/tree_utils/_state_utils.py index bdc07cf1..8c68057f 100644 --- a/optax/tree_utils/_state_utils.py +++ b/optax/tree_utils/_state_utils.py @@ -17,7 +17,8 @@ import dataclasses import functools import typing -from typing import Any, Callable, Optional, Protocol, Tuple, Union, cast +from collections.abc import Callable +from typing import Any, Optional, Protocol, Tuple, Union, cast import jax from optax._src import base