Skip to content

Commit

Permalink
Fix double precision tests
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 673920430
  • Loading branch information
vroulet authored and OptaxDev committed Sep 12, 2024
1 parent c0e4228 commit 9965925
Showing 1 changed file with 49 additions and 13 deletions.
62 changes: 49 additions & 13 deletions optax/_src/float64_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@

from absl.testing import absltest
from absl.testing import parameterized

import chex
import jax
from jax import config
import jax.numpy as jnp

import numpy as np
from optax._src import alias
from optax._src import base
from optax._src import clipping
from optax._src import numerics
from optax._src import transform
from optax._src import update

Expand All @@ -40,10 +39,16 @@
('scale_by_stddev', transform.scale_by_stddev, {}),
('adam', transform.scale_by_adam, {}),
('scale', transform.scale, dict(step_size=3.0)),
('add_decayed_weights', transform.add_decayed_weights,
dict(weight_decay=0.1)),
('scale_by_schedule', transform.scale_by_schedule,
dict(step_size_fn=lambda x: x * 0.1)),
(
'add_decayed_weights',
transform.add_decayed_weights,
dict(weight_decay=0.1),
),
(
'scale_by_schedule',
transform.scale_by_schedule,
dict(step_size_fn=lambda x: x * 0.1),
),
('scale_by_trust_ratio', transform.scale_by_trust_ratio, {}),
('add_noise', transform.add_noise, dict(eta=1.0, gamma=0.1, seed=42)),
('apply_every_k', transform.apply_every, {}),
Expand All @@ -69,25 +74,56 @@ def _assert_dtype_equals(self, tree1, tree2):
@chex.all_variants
@parameterized.named_parameters(ALL_MODULES)
def test_mixed_dtype_input_outputs(self, transform_constr, transform_kwargs):
if not jax.config.jax_enable_x64:
raise self.skipTest('jax_enable_x64 is not set')
initial_params = (
jnp.array([1., 2.], dtype=jnp.float32),
jnp.array([3., 4.], dtype=jnp.float64))
jnp.array([1.0, 2.0], dtype=jnp.float32),
jnp.array([3.0, 4.0], dtype=jnp.float64),
)
updates = (
jnp.array([10., 21.], dtype=jnp.float32),
jnp.array([33., 42.], dtype=jnp.float64))
jnp.array([10.0, 21.0], dtype=jnp.float32),
jnp.array([33.0, 42.0], dtype=jnp.float64),
)
scaler = transform_constr(**transform_kwargs)
init_fn = self.variant(scaler.init)
update_fn = self.variant(scaler.update)

initial_state = init_fn(initial_params)
updates, new_state = update_fn(
updates, initial_state, params=initial_params)
updates, initial_state, params=initial_params
)
new_params = update.apply_updates(initial_params, updates)

self._assert_dtype_equals(initial_state, new_state)
self._assert_dtype_equals(initial_params, new_params)

@chex.all_variants
@parameterized.product(str_dtype=['float64', 'int64'])
def test_safe_increment(self, str_dtype):
"""Tests that safe_increment works for all dtypes."""
if not jax.config.jax_enable_x64:
raise self.skipTest('jax_enable_x64 is not set')
dtype = jnp.dtype(str_dtype)
inc_fn = self.variant(numerics.safe_increment)

with self.subTest('Increments correctly'):
x = jnp.asarray(3, dtype=dtype)
incremented = inc_fn(x)
expected = jnp.asarray(4, dtype=dtype)
np.testing.assert_array_equal(incremented, expected)

with self.subTest('Avoids overflow'):
if jnp.issubdtype(dtype, jnp.integer):
max_val = jnp.iinfo(dtype).max
elif jnp.issubdtype(dtype, jnp.floating):
max_val = jnp.finfo(dtype).max
else:
raise ValueError(f'Unsupported dtype: {dtype}')
x = jnp.asarray(max_val, dtype=dtype)
incremented = inc_fn(x)
np.testing.assert_array_equal(incremented, x)


if __name__ == '__main__':
config.update('jax_enable_x64', True)
jax.config.update('jax_enable_x64', True)
absltest.main()

0 comments on commit 9965925

Please sign in to comment.