Skip to content

Commit

Permalink
Fix doctests
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 678305003
  • Loading branch information
vroulet authored and OptaxDev committed Sep 25, 2024
1 parent fea4745 commit 8244bd4
Showing 1 changed file with 8 additions and 20 deletions.
28 changes: 8 additions & 20 deletions optax/tree_utils/_state_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,8 @@ def tree_get_all_with_path(
... *[(jax.tree_util.keystr(p), v) for p, v in found_values_with_path],
... sep="\n",
... )
("InjectStatefulHyperparamsState.hyperparams['learning_rate']", Array(1.,
dtype=float32))
("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']",
WrappedScheduleState(count=Array(0, dtype=int32)))
("InjectStatefulHyperparamsState.hyperparams['learning_rate']", Array(1., dtype=float32))
("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", WrappedScheduleState(count=Array(0, dtype=int32)))
Usage with a filtering operation
Expand All @@ -221,8 +219,7 @@ def tree_get_all_with_path(
... *[(jax.tree_util.keystr(p), v) for p, v in found_values_with_path],
... sep="\n",
... )
("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']",
WrappedScheduleState(count=Array(0, dtype=int32)))
("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", WrappedScheduleState(count=Array(0, dtype=int32)))
.. seealso:: :func:`optax.tree_utils.tree_get`,
:func:`optax.tree_utils.tree_set`
Expand Down Expand Up @@ -335,8 +332,7 @@ def tree_get(
>>> state = opt.init(params)
>>> noise_state = optax.tree_utils.tree_get(state, 'AddNoiseState')
>>> print(noise_state)
AddNoiseState(count=Array(0, dtype=int32), rng_key=Array([0, 0],
dtype=uint32))
AddNoiseState(count=Array(0, dtype=int32), rng_key=Array([0, 0], dtype=uint32))
Differentiating between two values by the name of their named tuples.
Expand Down Expand Up @@ -418,12 +414,10 @@ def tree_set(
>>> opt = optax.adam(learning_rate=1.)
>>> state = opt.init(params)
>>> print(state)
(ScaleByAdamState(count=Array(0, dtype=int32), mu=Array([0., 0., 0.],
dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState())
(ScaleByAdamState(count=Array(0, dtype=int32), mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState())
>>> new_state = optax.tree_utils.tree_set(state, count=2.)
>>> print(new_state)
(ScaleByAdamState(count=2.0, mu=Array([0., 0., 0.], dtype=float32),
nu=Array([0., 0., 0.], dtype=float32)), EmptyState())
(ScaleByAdamState(count=2.0, mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState())
Usage with a filtering operation
Expand All @@ -435,19 +429,13 @@ def tree_set(
... )
>>> state = opt.init(params)
>>> print(state)
InjectStatefulHyperparamsState(count=Array(0, dtype=int32),
hyperparams={'learning_rate': Array(1., dtype=float32)},
hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0,
dtype=int32))}, inner_state=(EmptyState(), EmptyState()))
InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(1., dtype=float32)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState()))
>>> filtering = lambda path, value: isinstance(value, jnp.ndarray)
>>> new_state = optax.tree_utils.tree_set(
... state, filtering, learning_rate=jnp.asarray(0.1)
... )
>>> print(new_state)
InjectStatefulHyperparamsState(count=Array(0, dtype=int32),
hyperparams={'learning_rate': Array(0.1, dtype=float32, weak_type=True)},
hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0,
dtype=int32))}, inner_state=(EmptyState(), EmptyState()))
InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(0.1, dtype=float32, weak_type=True)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState()))
.. seealso:: :func:`optax.tree_utils.tree_get_all_with_path`,
:func:`optax.tree_utils.tree_get`
Expand Down

0 comments on commit 8244bd4

Please sign in to comment.