From 8244bd40dd390d72b0623986e662ad497c1bc361 Mon Sep 17 00:00:00 2001 From: Vincent Roulet Date: Tue, 24 Sep 2024 10:28:11 -0700 Subject: [PATCH] Fix doctests PiperOrigin-RevId: 678305003 --- optax/tree_utils/_state_utils.py | 28 ++++++++-------------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/optax/tree_utils/_state_utils.py b/optax/tree_utils/_state_utils.py index 08f01839..4e7d6a8f 100644 --- a/optax/tree_utils/_state_utils.py +++ b/optax/tree_utils/_state_utils.py @@ -199,10 +199,8 @@ def tree_get_all_with_path( ... *[(jax.tree_util.keystr(p), v) for p, v in found_values_with_path], ... sep="\n", ... ) - ("InjectStatefulHyperparamsState.hyperparams['learning_rate']", Array(1., - dtype=float32)) - ("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", - WrappedScheduleState(count=Array(0, dtype=int32))) + ("InjectStatefulHyperparamsState.hyperparams['learning_rate']", Array(1., dtype=float32)) + ("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", WrappedScheduleState(count=Array(0, dtype=int32))) Usage with a filtering operation @@ -221,8 +219,7 @@ def tree_get_all_with_path( ... *[(jax.tree_util.keystr(p), v) for p, v in found_values_with_path], ... sep="\n", ... ) - ("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", - WrappedScheduleState(count=Array(0, dtype=int32))) + ("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", WrappedScheduleState(count=Array(0, dtype=int32))) .. seealso:: :func:`optax.tree_utils.tree_get`, :func:`optax.tree_utils.tree_set` @@ -335,8 +332,7 @@ def tree_get( >>> state = opt.init(params) >>> noise_state = optax.tree_utils.tree_get(state, 'AddNoiseState') >>> print(noise_state) - AddNoiseState(count=Array(0, dtype=int32), rng_key=Array([0, 0], - dtype=uint32)) + AddNoiseState(count=Array(0, dtype=int32), rng_key=Array([0, 0], dtype=uint32)) Differentiating between two values by the name of their named tuples. @@ -418,12 +414,10 @@ def tree_set( >>> opt = optax.adam(learning_rate=1.) >>> state = opt.init(params) >>> print(state) - (ScaleByAdamState(count=Array(0, dtype=int32), mu=Array([0., 0., 0.], - dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState()) + (ScaleByAdamState(count=Array(0, dtype=int32), mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState()) >>> new_state = optax.tree_utils.tree_set(state, count=2.) >>> print(new_state) - (ScaleByAdamState(count=2.0, mu=Array([0., 0., 0.], dtype=float32), - nu=Array([0., 0., 0.], dtype=float32)), EmptyState()) + (ScaleByAdamState(count=2.0, mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState()) Usage with a filtering operation @@ -435,19 +429,13 @@ def tree_set( ... ) >>> state = opt.init(params) >>> print(state) - InjectStatefulHyperparamsState(count=Array(0, dtype=int32), - hyperparams={'learning_rate': Array(1., dtype=float32)}, - hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, - dtype=int32))}, inner_state=(EmptyState(), EmptyState())) + InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(1., dtype=float32)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState())) >>> filtering = lambda path, value: isinstance(value, jnp.ndarray) >>> new_state = optax.tree_utils.tree_set( ... state, filtering, learning_rate=jnp.asarray(0.1) ... ) >>> print(new_state) - InjectStatefulHyperparamsState(count=Array(0, dtype=int32), - hyperparams={'learning_rate': Array(0.1, dtype=float32, weak_type=True)}, - hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, - dtype=int32))}, inner_state=(EmptyState(), EmptyState())) + InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(0.1, dtype=float32, weak_type=True)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState())) .. seealso:: :func:`optax.tree_utils.tree_get_all_with_path`, :func:`optax.tree_utils.tree_get`