Skip to content

Commit

Permalink
Improving robustness of log and exp, with proper special values o…
Browse files Browse the repository at this point in the history
…utput.

Making sure that `exp` of `0` is `1` and `log` of `0` is `-inf`.
Using a custom `logsumexp` in MNIST example until an additional scale propagation
bug is solved.

NOTE: additional robustness means MNIST training converges when initialization scale > 1.
  • Loading branch information
balancap committed Jan 15, 2024
1 parent ebfc951 commit e371883
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 8 deletions.
22 changes: 18 additions & 4 deletions experiments/mnist/mnist_classifier_from_scratch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,24 @@
import jax.numpy as jnp
import numpy as np
import numpy.random as npr
from jax import grad, jit
from jax.scipy.special import logsumexp
from jax import grad, jit, lax

import jax_scaled_arithmetics as jsa

# from jax.scipy.special import logsumexp


def logsumexp(a, axis=None, keepdims=False):
dims = (axis,)
amax = jnp.max(a, axis=dims, keepdims=keepdims)
# FIXME: not proper scale propagation, introducing NaNs
# amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
amax = lax.stop_gradient(amax)
out = lax.sub(a, amax)
out = lax.exp(out)
out = lax.add(lax.log(jnp.sum(out, axis=dims, keepdims=keepdims)), amax)
return out


def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
return [(scale * rng.randn(m, n), scale * rng.randn(n)) for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])]
Expand All @@ -46,7 +59,8 @@ def predict(params, inputs):
logits = jnp.dot(activations, final_w) + final_b
# Dynamic rescaling of the gradient, as logits gradient not properly scaled.
logits = jsa.ops.dynamic_rescale_l2_grad(logits)
return logits - logsumexp(logits, axis=1, keepdims=True)
logits = logits - logsumexp(logits, axis=1, keepdims=True)
return logits


def loss(params, batch):
Expand Down Expand Up @@ -88,7 +102,7 @@ def data_stream():
batches = data_stream()
params = init_random_params(param_scale, layer_sizes)
# Transform parameters to `ScaledArray` and proper dtype.
params = jsa.as_scaled_array(params, scale=scale_dtype(1))
params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)

@jit
Expand Down
15 changes: 13 additions & 2 deletions jax_scaled_arithmetics/lax/scaled_ops_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,23 @@ def scaled_op_default_translation(

@core.register_scaled_lax_op
def scaled_exp(val: ScaledArray) -> ScaledArray:
return scaled_op_default_translation(lax.exp_p, [val])
assert isinstance(val, ScaledArray)
# Estimate in FP32, to avoid NaN when "descaling" the array.
# Otherwise: issues for representing properly 0 and +-Inf.
arr = val.to_array(dtype=np.float32).astype(val.dtype)
scale = np.array(1, dtype=val.scale.dtype)
return ScaledArray(lax.exp(arr), scale)


@core.register_scaled_lax_op
def scaled_log(val: ScaledArray) -> ScaledArray:
return scaled_op_default_translation(lax.log_p, [val])
assert isinstance(val, ScaledArray)
# Log of data & scale components.
log_data = lax.log(val.data)
log_scale = lax.log(val.scale).astype(val.dtype)
data = log_data + log_scale
scale = np.array(1, dtype=val.scale.dtype)
return ScaledArray(data, scale)


@core.register_scaled_lax_op
Expand Down
19 changes: 18 additions & 1 deletion tests/lax/test_scaled_ops_l2.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,24 @@ def test__scaled_unary_op__proper_result_and_scaling(self, prim, dtype, expected
assert out.dtype == val.dtype
assert out.scale.dtype == val.scale.dtype
npt.assert_almost_equal(out.scale, expected_scale)
npt.assert_array_almost_equal(out, expected_output)
# FIXME: higher precision for `log`?
npt.assert_array_almost_equal(out, expected_output, decimal=3)

def test__scaled_exp__large_scale_zero_values(self):
scaled_op, _ = find_registered_scaled_op(lax.exp_p)
# Scaled array, with values < 0 and scale overflowing in float16.
val = scaled_array(np.array([0, -1, -2, -32768], np.float16), np.float32(32768 * 16))
out = scaled_op(val)
# Zero value should not be a NaN!
npt.assert_array_almost_equal(out, [1, 0, 0, 0], decimal=2)

def test__scaled_log__zero_large_values_large_scale(self):
scaled_op, _ = find_registered_scaled_op(lax.log_p)
# 0 + large values => proper log values, without NaN/overflow.
val = scaled_array(np.array([0, 1], np.float16), np.float32(32768 * 16))
out = scaled_op(val)
# No NaN value + not overflowing!
npt.assert_array_almost_equal(out, lax.log(val.to_array(np.float32)), decimal=2)


class ScaledTranslationBinaryOpsTests(chex.TestCase):
Expand Down
11 changes: 10 additions & 1 deletion tests/lax/test_scipy_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,25 @@
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized
from jax import lax

from jax_scaled_arithmetics.core import autoscale, scaled_array


class ScaledTranslationPrimitivesTests(chex.TestCase):
class ScaledScipyHighLevelMethodsTests(chex.TestCase):
def setUp(self):
super().setUp()
# Use random state for reproducibility!
self.rs = np.random.RandomState(42)

def test__lax_full_like__zero_scale(self):
def fn(a):
return lax.full_like(a, 0)

a = scaled_array(np.random.rand(3, 5).astype(np.float32), np.float32(1))
autoscale(fn)(a)
# FIMXE/TODO: what should be the expected result?

@parameterized.parameters(
{"dtype": np.float32},
{"dtype": np.float16},
Expand Down

0 comments on commit e371883

Please sign in to comment.