Skip to content

Commit

Permalink
Merge pull request #1092 from carlosgmartin:safe_increment_unsigned
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 683647626
  • Loading branch information
OptaxDev committed Oct 8, 2024
2 parents 06ce57a + 921f30e commit d9212fb
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 12 deletions.
9 changes: 2 additions & 7 deletions optax/_src/numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,20 +130,15 @@ def safe_increment(count: chex.Numeric) -> chex.Numeric:
count_dtype = jnp.asarray(count).dtype
if jnp.issubdtype(count_dtype, jnp.integer):
max_value = jnp.iinfo(count_dtype).max
if jnp.issubdtype(count_dtype, jnp.unsignedinteger):
# The comparison count < max_value appears to convert its arguments into
# signed integers so we get overflow errors with unsigned integers.
raise ValueError(
f'Unsigned integers like {count_dtype} cannot be incremented safely.'
)
elif jnp.issubdtype(count_dtype, jnp.floating):
max_value = jnp.finfo(count_dtype).max
else:
raise ValueError(
f'Cannot safely increment count with dtype {count_dtype},'
' valid dtypes are subdtypes of "jnp.integer" or "jnp.floating".'
)
one = jnp.array(1, dtype=count_dtype)
max_value = jnp.array(max_value, count_dtype)
one = jnp.array(1, count_dtype)
return jnp.where(count < max_value, count + one, max_value)


Expand Down
11 changes: 6 additions & 5 deletions optax/_src/numerics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,14 @@ class NumericsTest(chex.TestCase):
"int16",
"int32",
"int64",
"uint8",
"uint16",
"uint32",
"uint64",
))
def test_safe_increment(self, dtype):
"""Tests that safe_increment works for all dtypes."""
if dtype in ["float64", "int64"]:
if dtype in ["float64", "int64", "uint64"]:
jax.config.update("jax_enable_x64", True)
dtype = jnp.dtype(dtype)
inc_fn = self.variant(numerics.safe_increment)
Expand All @@ -77,14 +81,11 @@ def test_safe_increment(self, dtype):
base = jnp.asarray(max_val, dtype=dtype)
incremented = inc_fn(base)
np.testing.assert_array_equal(incremented, base)
if dtype in ["float64", "int64"]:
if dtype in ["float64", "int64", "uint64"]:
jax.config.update("jax_enable_x64", False)

@parameterized.product(
str_dtype=[
"uint8",
"uint16",
"uint32",
"bool",
"complex64",
]
Expand Down

0 comments on commit d9212fb

Please sign in to comment.