Add double precision tests for safe_increment and fix warnings on float64_test.py

PiperOrigin-RevId: 674934804
vroulet authored and OptaxDev committed Sep 15, 2024
1 parent 469c878 commit ee63e45
Showing 8 changed files with 47 additions and 35 deletions.
33 changes: 20 additions & 13 deletions optax/_src/float64_test.py
@@ -16,12 +16,9 @@

from absl.testing import absltest
from absl.testing import parameterized

import chex
import jax
from jax import config
import jax.numpy as jnp

from optax._src import alias
from optax._src import base
from optax._src import clipping
@@ -40,10 +37,16 @@
('scale_by_stddev', transform.scale_by_stddev, {}),
('adam', transform.scale_by_adam, {}),
('scale', transform.scale, dict(step_size=3.0)),
('add_decayed_weights', transform.add_decayed_weights,
dict(weight_decay=0.1)),
('scale_by_schedule', transform.scale_by_schedule,
dict(step_size_fn=lambda x: x * 0.1)),
(
'add_decayed_weights',
transform.add_decayed_weights,
dict(weight_decay=0.1),
),
(
'scale_by_schedule',
transform.scale_by_schedule,
dict(step_size_fn=lambda x: x * 0.1),
),
('scale_by_trust_ratio', transform.scale_by_trust_ratio, {}),
('add_noise', transform.add_noise, dict(eta=1.0, gamma=0.1, seed=42)),
('apply_every_k', transform.apply_every, {}),
@@ -69,25 +72,29 @@ def _assert_dtype_equals(self, tree1, tree2):
@chex.all_variants
@parameterized.named_parameters(ALL_MODULES)
def test_mixed_dtype_input_outputs(self, transform_constr, transform_kwargs):
jax.config.update('jax_enable_x64', True)
initial_params = (
jnp.array([1., 2.], dtype=jnp.float32),
jnp.array([3., 4.], dtype=jnp.float64))
jnp.array([1.0, 2.0], dtype=jnp.float32),
jnp.array([3.0, 4.0], dtype=jnp.float64),
)
updates = (
jnp.array([10., 21.], dtype=jnp.float32),
jnp.array([33., 42.], dtype=jnp.float64))
jnp.array([10.0, 21.0], dtype=jnp.float32),
jnp.array([33.0, 42.0], dtype=jnp.float64),
)
scaler = transform_constr(**transform_kwargs)
init_fn = self.variant(scaler.init)
update_fn = self.variant(scaler.update)

initial_state = init_fn(initial_params)
updates, new_state = update_fn(
updates, initial_state, params=initial_params)
updates, initial_state, params=initial_params
)
new_params = update.apply_updates(initial_params, updates)

self._assert_dtype_equals(initial_state, new_state)
self._assert_dtype_equals(initial_params, new_params)
jax.config.update('jax_enable_x64', False)


if __name__ == '__main__':
config.update('jax_enable_x64', True)
absltest.main()
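The test above now enables double precision only for the duration of the test, switching jax_enable_x64 on at the start and back off at the end instead of flipping it globally under __main__. A minimal sketch of that toggle pattern, using a hypothetical helper name (run_with_x64 is not part of this commit):

import jax

def run_with_x64(fn):
  """Runs fn with double precision enabled, then restores the default."""
  jax.config.update('jax_enable_x64', True)
  try:
    return fn()
  finally:
    jax.config.update('jax_enable_x64', False)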
1 change: 1 addition & 0 deletions optax/_src/numerics.py
@@ -111,6 +111,7 @@ def safe_increment(count: chex.Numeric) -> chex.Numeric:
counter stays at ``max_val``.
Examples:
>>> import jax.numpy as jnp
>>> import optax
>>> optax.safe_increment(jnp.asarray(1, dtype=jnp.int32))
Array(2, dtype=int32)
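As a companion to the doctest above, a short hedged sketch of the saturating behavior the docstring describes; it assumes max_val is the dtype's largest representable value, which this hunk does not show:

import jax.numpy as jnp
import optax

# Normal case: the counter advances by one.
optax.safe_increment(jnp.asarray(1, dtype=jnp.int32))  # Array(2, dtype=int32)

# At the dtype's maximum, the counter stays put instead of overflowing.
top = jnp.asarray(jnp.iinfo(jnp.int32).max, dtype=jnp.int32)
optax.safe_increment(top)  # Array(2147483647, dtype=int32)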
28 changes: 16 additions & 12 deletions optax/_src/numerics_test.py
@@ -44,19 +44,21 @@ def _invalid_ord_axis_inputs(ord_axis_keepdims):
class NumericsTest(chex.TestCase):

@chex.all_variants
@parameterized.product(
str_dtype=[
"bfloat16",
"float16",
"float32",
"int8",
"int16",
"int32",
]
)
def test_safe_increment(self, str_dtype):
@parameterized.parameters(*(
"bfloat16",
"float16",
"float32",
"float64",
"int8",
"int16",
"int32",
"int64",
))
def test_safe_increment(self, dtype):
"""Tests that safe_increment works for all dtypes."""
dtype = jnp.dtype(str_dtype)
if dtype in ["float64", "int64"]:
jax.config.update("jax_enable_x64", True)
dtype = jnp.dtype(dtype)
inc_fn = self.variant(numerics.safe_increment)

with self.subTest("Increments correctly"):
@@ -75,6 +77,8 @@ def test_safe_increment(self, str_dtype):
base = jnp.asarray(max_val, dtype=dtype)
incremented = inc_fn(base)
np.testing.assert_array_equal(incremented, base)
if dtype in ["float64", "int64"]:
jax.config.update("jax_enable_x64", False)

@parameterized.product(
str_dtype=[
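To reproduce the new double precision cases outside the test harness, a standalone sketch (assumed, not part of the repository; it presumes max_val is the dtype's largest representable value, as in the saturation subtest above):

import jax
import jax.numpy as jnp
from optax._src import numerics

jax.config.update('jax_enable_x64', True)
for dtype in (jnp.float64, jnp.int64):
  info = jnp.finfo(dtype) if jnp.issubdtype(dtype, jnp.floating) else jnp.iinfo(dtype)
  top = jnp.asarray(info.max, dtype=dtype)
  # safe_increment saturates at the maximum instead of overflowing or wrapping.
  assert numerics.safe_increment(top) == top
jax.config.update('jax_enable_x64', False)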
2 changes: 1 addition & 1 deletion optax/contrib/_acprop.py
@@ -67,7 +67,7 @@ def update_fn(updates, state, params=None):
prediction_error = jtu.tree_map(lambda g, m: g - m, updates, state.mu)
nu = otu.tree_update_moment_per_elem_norm(prediction_error, state.nu, b2, 2)
nu = jtu.tree_map(lambda v: v + eps_root, nu)
count_inc = numerics.safe_int32_increment(state.count)
count_inc = numerics.safe_increment(state.count)

# On initial step, avoid division by zero and force nu_hat to be 1.
initial = state.count == 0
4 changes: 2 additions & 2 deletions optax/schedules/_inject.py
@@ -197,7 +197,7 @@ def update_fn(updates, state, params=None, **extra_args):
).update(updates, state.inner_state, params, **extra_args)

return updates, InjectStatefulHyperparamsState(
count=numerics.safe_int32_increment(state.count),
count=numerics.safe_increment(state.count),
hyperparams=hparams,
hyperparams_states=hyperparams_states,
inner_state=inner_state,
@@ -270,7 +270,7 @@ def update(
**extra_args,
) -> WrappedScheduleState:
del extra_args
new_count = numerics.safe_int32_increment(state.count)
new_count = numerics.safe_increment(state.count)
return WrappedScheduleState(count=new_count)

def __call__(
4 changes: 2 additions & 2 deletions optax/transforms/_accumulation.py
@@ -345,9 +345,9 @@ def _do_update(updates, state, params):

emit = state.mini_step == (k_steps - 1)
new_state = MultiStepsState(
mini_step=numerics.safe_int32_increment(state.mini_step) % k_steps,
mini_step=numerics.safe_increment(state.mini_step) % k_steps,
gradient_step=emit
* numerics.safe_int32_increment(state.gradient_step)
* numerics.safe_increment(state.gradient_step)
+ (1 - emit) * state.gradient_step,
inner_opt_state=jtu.tree_map(
lambda st, nst: jnp.where(emit, nst, st),
2 changes: 1 addition & 1 deletion optax/transforms/_adding.py
@@ -92,7 +92,7 @@ def init_fn(params):

def update_fn(updates, state, params=None): # pylint: disable=missing-docstring
del params
count_inc = numerics.safe_int32_increment(state.count)
count_inc = numerics.safe_increment(state.count)
standard_deviation = jnp.sqrt(eta / count_inc**gamma)

rng_key, sample_key = jax.random.split(state.rng_key)
8 changes: 4 additions & 4 deletions optax/transforms/_conditionality.py
@@ -102,7 +102,7 @@ def reject_update(_):
should_transform_fn(state.step, **condition_kwargs),
do_update, reject_update, operand=None)
return updates, ConditionallyTransformState(
new_inner_state, numerics.safe_int32_increment(state.step))
new_inner_state, numerics.safe_increment(state.step))

return base.GradientTransformationExtraArgs(init_fn, update_fn)

@@ -165,7 +165,7 @@ def reject_update(_):
do_update, reject_update, operand=None)

return updates, ConditionallyMaskState(
step=numerics.safe_int32_increment(state.step),
step=numerics.safe_increment(state.step),
inner_state=new_inner_state,
)

@@ -230,7 +230,7 @@ def update(updates, state, params=None, **extra_args):
jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates]))
notfinite_count = jnp.where(
isfinite, jnp.zeros([], jnp.int32),
numerics.safe_int32_increment(state.notfinite_count))
numerics.safe_increment(state.notfinite_count))

def do_update(_):
return inner.update(updates, inner_state, params, **extra_args)
@@ -247,7 +247,7 @@ def reject_update(_):
last_finite=isfinite,
total_notfinite=jnp.where(
isfinite, state.total_notfinite,
numerics.safe_int32_increment(state.total_notfinite)),
numerics.safe_increment(state.total_notfinite)),
inner_state=new_inner_state)

return base.GradientTransformationExtraArgs(init=init, update=update)
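Across the contrib, schedule, and transform files above, the commit replaces the int32-specific counter helper with the dtype-generic one. A hedged sketch of the migration at a typical call site, assuming safe_increment is a drop-in replacement for safe_int32_increment on int32 counters, as these hunks suggest:

import jax.numpy as jnp
from optax._src import numerics

count = jnp.zeros([], jnp.int32)
# Before: count_inc = numerics.safe_int32_increment(count)
count_inc = numerics.safe_increment(count)  # also handles float and 64-bit counters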
