Skip to content

Commit

Permalink
Add double precision tests for safe_increment and fix warnings on flo…
Browse files Browse the repository at this point in the history
…at64_test.py

PiperOrigin-RevId: 673920430
  • Loading branch information
vroulet authored and OptaxDev committed Sep 14, 2024
1 parent 469c878 commit e44155c
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 25 deletions.
33 changes: 20 additions & 13 deletions optax/_src/float64_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,9 @@

from absl.testing import absltest
from absl.testing import parameterized

import chex
import jax
from jax import config
import jax.numpy as jnp

from optax._src import alias
from optax._src import base
from optax._src import clipping
Expand All @@ -40,10 +37,16 @@
('scale_by_stddev', transform.scale_by_stddev, {}),
('adam', transform.scale_by_adam, {}),
('scale', transform.scale, dict(step_size=3.0)),
('add_decayed_weights', transform.add_decayed_weights,
dict(weight_decay=0.1)),
('scale_by_schedule', transform.scale_by_schedule,
dict(step_size_fn=lambda x: x * 0.1)),
(
'add_decayed_weights',
transform.add_decayed_weights,
dict(weight_decay=0.1),
),
(
'scale_by_schedule',
transform.scale_by_schedule,
dict(step_size_fn=lambda x: x * 0.1),
),
('scale_by_trust_ratio', transform.scale_by_trust_ratio, {}),
('add_noise', transform.add_noise, dict(eta=1.0, gamma=0.1, seed=42)),
('apply_every_k', transform.apply_every, {}),
Expand All @@ -69,25 +72,29 @@ def _assert_dtype_equals(self, tree1, tree2):
@chex.all_variants
@parameterized.named_parameters(ALL_MODULES)
def test_mixed_dtype_input_outputs(self, transform_constr, transform_kwargs):
jax.config.update('jax_enable_x64', True)
initial_params = (
jnp.array([1., 2.], dtype=jnp.float32),
jnp.array([3., 4.], dtype=jnp.float64))
jnp.array([1.0, 2.0], dtype=jnp.float32),
jnp.array([3.0, 4.0], dtype=jnp.float64),
)
updates = (
jnp.array([10., 21.], dtype=jnp.float32),
jnp.array([33., 42.], dtype=jnp.float64))
jnp.array([10.0, 21.0], dtype=jnp.float32),
jnp.array([33.0, 42.0], dtype=jnp.float64),
)
scaler = transform_constr(**transform_kwargs)
init_fn = self.variant(scaler.init)
update_fn = self.variant(scaler.update)

initial_state = init_fn(initial_params)
updates, new_state = update_fn(
updates, initial_state, params=initial_params)
updates, initial_state, params=initial_params
)
new_params = update.apply_updates(initial_params, updates)

self._assert_dtype_equals(initial_state, new_state)
self._assert_dtype_equals(initial_params, new_params)
jax.config.update('jax_enable_x64', False)


if __name__ == '__main__':
config.update('jax_enable_x64', True)
absltest.main()
28 changes: 16 additions & 12 deletions optax/_src/numerics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,21 @@ def _invalid_ord_axis_inputs(ord_axis_keepdims):
class NumericsTest(chex.TestCase):

@chex.all_variants
@parameterized.product(
str_dtype=[
"bfloat16",
"float16",
"float32",
"int8",
"int16",
"int32",
]
)
def test_safe_increment(self, str_dtype):
@parameterized.parameters(*(
"bfloat16",
"float16",
"float32",
"float64",
"int8",
"int16",
"int32",
"int64",
))
def test_safe_increment(self, dtype):
"""Tests that safe_increment works for all dtypes."""
dtype = jnp.dtype(str_dtype)
if dtype in ["float64", "int64"]:
jax.config.update("jax_enable_x64", True)
dtype = jnp.dtype(dtype)
inc_fn = self.variant(numerics.safe_increment)

with self.subTest("Increments correctly"):
Expand All @@ -75,6 +77,8 @@ def test_safe_increment(self, str_dtype):
base = jnp.asarray(max_val, dtype=dtype)
incremented = inc_fn(base)
np.testing.assert_array_equal(incremented, base)
if dtype in ["float64", "int64"]:
jax.config.update("jax_enable_x64", False)

@parameterized.product(
str_dtype=[
Expand Down

0 comments on commit e44155c

Please sign in to comment.