From 1cde7043a965eef4c4f3817006159d16dd9fae3e Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Mon, 23 Sep 2024 13:22:48 -0400 Subject: [PATCH] Replace deprecated typing.Hashable, typing.Callable, and typing.Mapping with equivalents from collections.abc. --- docs/ext/coverage_check.py | 3 ++- examples/cifar10_resnet.ipynb | 3 ++- examples/meta_learning.ipynb | 3 ++- optax/_src/alias.py | 3 ++- optax/_src/alias_test.py | 3 ++- optax/_src/base.py | 3 ++- optax/_src/factorized.py | 3 ++- optax/_src/linear_algebra.py | 3 ++- optax/_src/linesearch.py | 3 ++- optax/_src/linesearch_test.py | 3 ++- optax/_src/utils.py | 3 ++- optax/_src/wrappers.py | 2 +- optax/contrib/_acprop.py | 3 ++- optax/contrib/_cocob.py | 3 ++- optax/contrib/_dog.py | 3 ++- optax/contrib/_sam.py | 3 ++- optax/losses/_ranking.py | 3 ++- optax/monte_carlo/control_variates.py | 3 ++- optax/monte_carlo/stochastic_gradient_estimators.py | 3 ++- optax/schedules/_inject.py | 3 ++- optax/transforms/_accumulation.py | 3 ++- optax/transforms/_adding.py | 3 ++- optax/transforms/_combining.py | 3 ++- optax/transforms/_masking.py | 3 ++- optax/tree_utils/_random.py | 3 ++- optax/tree_utils/_random_test.py | 2 +- optax/tree_utils/_state_utils.py | 3 ++- 27 files changed, 52 insertions(+), 27 deletions(-) diff --git a/docs/ext/coverage_check.py b/docs/ext/coverage_check.py index 4046d2123..0626b64a8 100644 --- a/docs/ext/coverage_check.py +++ b/docs/ext/coverage_check.py @@ -16,7 +16,8 @@ import inspect import types -from typing import Any, Mapping, Sequence, Tuple +from collections.abc import Mapping +from typing import Any, Sequence, Tuple import optax from sphinx import application diff --git a/examples/cifar10_resnet.ipynb b/examples/cifar10_resnet.ipynb index 7b5706d4d..a9f9374e1 100644 --- a/examples/cifar10_resnet.ipynb +++ b/examples/cifar10_resnet.ipynb @@ -44,7 +44,8 @@ ], "source": [ "import functools\n", - "from typing import Any, Callable, Sequence, Tuple, Optional, Dict\n", + "from collections.abc import Callable\n", + "from typing import Any, Sequence, Tuple, Optional, Dict\n", "\n", "from flax import linen as nn\n", "\n", diff --git a/examples/meta_learning.ipynb b/examples/meta_learning.ipynb index e30e42920..8df9648c2 100644 --- a/examples/meta_learning.ipynb +++ b/examples/meta_learning.ipynb @@ -45,7 +45,8 @@ }, "outputs": [], "source": [ - "from typing import Callable, Iterator, Tuple\n", + "from collections.abc import Callable\n", + "from typing import Iterator, Tuple\n", "import chex\n", "import jax\n", "import jax.numpy as jnp\n", diff --git a/optax/_src/alias.py b/optax/_src/alias.py index c7660298f..90abc698c 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -15,7 +15,8 @@ """Aliases for popular optimizers.""" import functools -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, Optional, Union import jax.numpy as jnp diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index a2444c689..d00c34f55 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -14,7 +14,8 @@ # ============================================================================== """Tests for `alias.py`.""" -from typing import Any, Callable, Union +from collections.abc import Callable +from typing import Any, Union from absl.testing import absltest from absl.testing import parameterized diff --git a/optax/_src/base.py b/optax/_src/base.py index 1ed155f9e..c56d3aa1f 100644 --- a/optax/_src/base.py +++ b/optax/_src/base.py @@ -14,7 +14,8 @@ # ============================================================================== """Base interfaces and datatypes.""" -from typing import Any, Callable, NamedTuple, Optional, Protocol, runtime_checkable, Sequence, Union +from collections.abc import Callable +from typing import Any, NamedTuple, Optional, Protocol, runtime_checkable, Sequence, Union import chex import jax diff --git a/optax/_src/factorized.py b/optax/_src/factorized.py index 7379c9ac8..22ce7eef0 100644 --- a/optax/_src/factorized.py +++ b/optax/_src/factorized.py @@ -15,7 +15,8 @@ """Factorized optimizers.""" import dataclasses -from typing import NamedTuple, Optional, Callable +from collections.abc import Callable +from typing import NamedTuple, Optional import chex import jax diff --git a/optax/_src/linear_algebra.py b/optax/_src/linear_algebra.py index a84bd3aa4..cc2412cca 100644 --- a/optax/_src/linear_algebra.py +++ b/optax/_src/linear_algebra.py @@ -15,7 +15,8 @@ """Linear algebra utilities used in optimisation.""" import functools -from typing import Callable, Optional, Union +from collections.abc import Callable +from typing import Optional, Union import chex import jax diff --git a/optax/_src/linesearch.py b/optax/_src/linesearch.py index aef22cf94..94cf1078c 100644 --- a/optax/_src/linesearch.py +++ b/optax/_src/linesearch.py @@ -15,7 +15,8 @@ """Line-searches.""" import functools -from typing import Any, Callable, NamedTuple, Optional, Union +from collections.abc import Callable +from typing import Any, NamedTuple, Optional, Union import chex import jax diff --git a/optax/_src/linesearch_test.py b/optax/_src/linesearch_test.py index 640dfdc0b..b63f4ac0e 100644 --- a/optax/_src/linesearch_test.py +++ b/optax/_src/linesearch_test.py @@ -19,7 +19,8 @@ import io import itertools import math -from typing import Callable, Optional +from collections.abc import Callable +from typing import Optional from absl.testing import absltest from absl.testing import parameterized diff --git a/optax/_src/utils.py b/optax/_src/utils.py index 86e81dea2..ef7f37a58 100644 --- a/optax/_src/utils.py +++ b/optax/_src/utils.py @@ -17,7 +17,8 @@ import functools import inspect import operator -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Callable +from typing import Any, Optional, Sequence, Union import chex from etils import epy diff --git a/optax/_src/wrappers.py b/optax/_src/wrappers.py index 262c80a68..f1e572374 100644 --- a/optax/_src/wrappers.py +++ b/optax/_src/wrappers.py @@ -15,7 +15,7 @@ """Transformation wrappers.""" import functools -from typing import Callable +from collections.abc import Callable import chex import jax.numpy as jnp diff --git a/optax/contrib/_acprop.py b/optax/contrib/_acprop.py index a5b6788b1..2fe626644 100644 --- a/optax/contrib/_acprop.py +++ b/optax/contrib/_acprop.py @@ -18,7 +18,8 @@ Asynchronous Update for Adaptive Gradient Methods" by Zhuang et al. (https://arxiv.org/abs/2110.05454). """ -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, Optional, Union import jax import jax.numpy as jnp diff --git a/optax/contrib/_cocob.py b/optax/contrib/_cocob.py index edf13a021..ad5cdbab6 100644 --- a/optax/contrib/_cocob.py +++ b/optax/contrib/_cocob.py @@ -18,7 +18,8 @@ Networks without Learning Rates Through Coin Betting" by Francesco Orabona and Tatiana Tommasi. """ -from typing import Any, Callable, NamedTuple, Optional, Union +from collections.abc import Callable +from typing import Any, NamedTuple, Optional, Union import jax import jax.numpy as jnp diff --git a/optax/contrib/_dog.py b/optax/contrib/_dog.py index c8249ab8a..e46be349f 100644 --- a/optax/contrib/_dog.py +++ b/optax/contrib/_dog.py @@ -22,7 +22,8 @@ Gradient Descent Method`_, 2023. """ -from typing import Any, Callable, NamedTuple, Optional, Union +from collections.abc import Callable +from typing import Any, NamedTuple, Optional, Union import chex import jax diff --git a/optax/contrib/_sam.py b/optax/contrib/_sam.py index 41ae73828..5bfdce96d 100644 --- a/optax/contrib/_sam.py +++ b/optax/contrib/_sam.py @@ -47,7 +47,8 @@ """ # pytype: skip-file -from typing import Callable, Optional +from collections.abc import Callable +from typing import Optional import chex import jax import jax.numpy as jnp diff --git a/optax/losses/_ranking.py b/optax/losses/_ranking.py index 7138c93e5..c451296b9 100644 --- a/optax/losses/_ranking.py +++ b/optax/losses/_ranking.py @@ -47,7 +47,8 @@ [-0.755, 0.09, 0.665] """ -from typing import Callable, Optional +from collections.abc import Callable +from typing import Optional import chex import jax diff --git a/optax/monte_carlo/control_variates.py b/optax/monte_carlo/control_variates.py index 95d5261f2..58f6a7247 100644 --- a/optax/monte_carlo/control_variates.py +++ b/optax/monte_carlo/control_variates.py @@ -53,7 +53,8 @@ For examples, see `control_delta_method` and `moving_avg_baseline`. """ -from typing import Any, Callable, Sequence +from collections.abc import Callable +from typing import Any, Sequence import chex import jax diff --git a/optax/monte_carlo/stochastic_gradient_estimators.py b/optax/monte_carlo/stochastic_gradient_estimators.py index fd2ba605d..b45d3bc40 100644 --- a/optax/monte_carlo/stochastic_gradient_estimators.py +++ b/optax/monte_carlo/stochastic_gradient_estimators.py @@ -29,7 +29,8 @@ """ import math -from typing import Any, Callable, Sequence +from collections.abc import Callable +from typing import Any, Sequence import chex import jax diff --git a/optax/schedules/_inject.py b/optax/schedules/_inject.py index 52521b4ed..5ec61e38c 100644 --- a/optax/schedules/_inject.py +++ b/optax/schedules/_inject.py @@ -16,7 +16,8 @@ import functools import inspect -from typing import Callable, Iterable, NamedTuple, Optional, Union +from collections.abc import Callable +from typing import Iterable, NamedTuple, Optional, Union import warnings import chex diff --git a/optax/transforms/_accumulation.py b/optax/transforms/_accumulation.py index 25ae2123c..1d902af21 100644 --- a/optax/transforms/_accumulation.py +++ b/optax/transforms/_accumulation.py @@ -14,7 +14,8 @@ # ============================================================================== """Gradient transformations for accumulating gradients across updates.""" -from typing import Any, Callable, NamedTuple, Optional, Protocol, Union +from collections.abc import Callable +from typing import Any, NamedTuple, Optional, Protocol, Union import chex import jax diff --git a/optax/transforms/_adding.py b/optax/transforms/_adding.py index 578068599..655ec807d 100644 --- a/optax/transforms/_adding.py +++ b/optax/transforms/_adding.py @@ -14,7 +14,8 @@ # ============================================================================== """Additive components in gradient transformations.""" -from typing import Any, Callable, NamedTuple, Optional, Union +from collections.abc import Callable +from typing import Any, NamedTuple, Optional, Union import chex import jax diff --git a/optax/transforms/_combining.py b/optax/transforms/_combining.py index 5a74a65a5..a4fb7c084 100644 --- a/optax/transforms/_combining.py +++ b/optax/transforms/_combining.py @@ -14,7 +14,8 @@ # ============================================================================== """Flexibly compose gradient transformations.""" -from typing import Callable, NamedTuple, Union, Mapping, Hashable +from collections.abc import Callable, Hashable, Mapping +from typing import NamedTuple, Union import jax diff --git a/optax/transforms/_masking.py b/optax/transforms/_masking.py index 82f34e7f6..754d62ed5 100644 --- a/optax/transforms/_masking.py +++ b/optax/transforms/_masking.py @@ -14,7 +14,8 @@ # ============================================================================== """Wrappers that mask out part of the parameters when applying a transform.""" -from typing import Any, Callable, NamedTuple, Union +from collections.abc import Callable +from typing import Any, NamedTuple, Union import jax diff --git a/optax/tree_utils/_random.py b/optax/tree_utils/_random.py index b4f875017..db541af03 100644 --- a/optax/tree_utils/_random.py +++ b/optax/tree_utils/_random.py @@ -14,7 +14,8 @@ # ============================================================================== """Utilities to generate random pytrees.""" -from typing import Callable, Optional +from collections.abc import Callable +from typing import Optional import chex import jax diff --git a/optax/tree_utils/_random_test.py b/optax/tree_utils/_random_test.py index 9dfb0a073..55569e80e 100644 --- a/optax/tree_utils/_random_test.py +++ b/optax/tree_utils/_random_test.py @@ -14,7 +14,7 @@ # ============================================================================== """Tests for optax.tree_utils._random.""" -from typing import Callable +from collections.abc import Callable from absl.testing import absltest from absl.testing import parameterized diff --git a/optax/tree_utils/_state_utils.py b/optax/tree_utils/_state_utils.py index bdc07cf17..8c68057f0 100644 --- a/optax/tree_utils/_state_utils.py +++ b/optax/tree_utils/_state_utils.py @@ -17,7 +17,8 @@ import dataclasses import functools import typing -from typing import Any, Callable, Optional, Protocol, Tuple, Union, cast +from collections.abc import Callable +from typing import Any, Optional, Protocol, Tuple, Union, cast import jax from optax._src import base