Improving robustness of log and exp, with proper special values o… #84

Merged: 1 commit, Jan 15, 2024
22 changes: 18 additions & 4 deletions experiments/mnist/mnist_classifier_from_scratch.py
@@ -25,11 +25,24 @@
import jax.numpy as jnp
import numpy as np
import numpy.random as npr
-from jax import grad, jit
-from jax.scipy.special import logsumexp
+from jax import grad, jit, lax

import jax_scaled_arithmetics as jsa

+# from jax.scipy.special import logsumexp


+def logsumexp(a, axis=None, keepdims=False):
+    dims = (axis,)
+    amax = jnp.max(a, axis=dims, keepdims=keepdims)
+    # FIXME: not proper scale propagation, introducing NaNs
+    # amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
+    amax = lax.stop_gradient(amax)
+    out = lax.sub(a, amax)
+    out = lax.exp(out)
+    out = lax.add(lax.log(jnp.sum(out, axis=dims, keepdims=keepdims)), amax)
+    return out


def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
    return [(scale * rng.randn(m, n), scale * rng.randn(n)) for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])]
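Note on the inlined `logsumexp` above: it re-implements `jax.scipy.special.logsumexp` with the usual shift-by-max trick, presumably so that the primitives it lowers to are exactly the ones with scaled translations, and `stop_gradient` keeps the shift out of the backward pass. A standalone sketch of the same trick in plain JAX (no scaled arithmetic, illustrative values, not code from this PR):

```python
import jax.numpy as jnp
from jax import lax


def stable_logsumexp(a, axis, keepdims=True):
    # Shift by the (gradient-stopped) row maximum so the exp stays in range.
    amax = lax.stop_gradient(jnp.max(a, axis=(axis,), keepdims=keepdims))
    return jnp.log(jnp.sum(jnp.exp(a - amax), axis=(axis,), keepdims=keepdims)) + amax


x = jnp.array([[1000.0, 1001.0, 1002.0]])  # naive exp(x) overflows float32
print(stable_logsumexp(x, axis=1))  # ~[[1002.41]], no inf/NaN
```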
@@ -46,7 +59,8 @@ def predict(params, inputs):
    logits = jnp.dot(activations, final_w) + final_b
    # Dynamic rescaling of the gradient, as logits gradient not properly scaled.
    logits = jsa.ops.dynamic_rescale_l2_grad(logits)
-    return logits - logsumexp(logits, axis=1, keepdims=True)
+    logits = logits - logsumexp(logits, axis=1, keepdims=True)
+    return logits


def loss(params, batch):
@@ -88,7 +102,7 @@ def data_stream():
    batches = data_stream()
    params = init_random_params(param_scale, layer_sizes)
    # Transform parameters to `ScaledArray` and proper dtype.
-    params = jsa.as_scaled_array(params, scale=scale_dtype(1))
+    params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
    params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)

    @jit
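Brief context for the `scale_dtype(param_scale)` change above: `init_random_params` draws weights as `param_scale * randn(...)`, so seeding the `ScaledArray` conversion with `param_scale` instead of a fixed 1 lets the stored scale match the actual weight magnitude from the first step. A rough standalone sketch of that conversion, assuming `as_scaled_array` accepts a single NumPy array and preserves the represented values; the `param_scale` value and shape below are illustrative, not the experiment's config:

```python
import numpy as np

import jax_scaled_arithmetics as jsa

param_scale = np.float32(0.1)  # illustrative value only
w = param_scale * np.random.randn(784, 512).astype(np.float32)

# Attach scale metadata matching the weights' magnitude, as in the diff above.
sw = jsa.as_scaled_array(w, scale=param_scale)
print(type(sw), sw.scale)
```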
15 changes: 13 additions & 2 deletions jax_scaled_arithmetics/lax/scaled_ops_common.py
@@ -261,12 +261,23 @@ def scaled_op_default_translation(

@core.register_scaled_lax_op
def scaled_exp(val: ScaledArray) -> ScaledArray:
-    return scaled_op_default_translation(lax.exp_p, [val])
+    assert isinstance(val, ScaledArray)
+    # Estimate in FP32, to avoid NaN when "descaling" the array.
+    # Otherwise: issues for representing properly 0 and +-Inf.
+    arr = val.to_array(dtype=np.float32).astype(val.dtype)
+    scale = np.array(1, dtype=val.scale.dtype)
+    return ScaledArray(lax.exp(arr), scale)
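A pure-NumPy illustration of the comment above, using the same float16 payload and scale as the new unit test: descaling directly in the low-precision dtype turns the zero entry into NaN, while the FP32 round-trip keeps 0 as 0 and maps the overflow to -inf, so `exp` produces proper special values. This is a sketch of the numeric effect, not library code:

```python
import numpy as np

data = np.array([0.0, -1.0], dtype=np.float16)
scale = np.float32(32768 * 16)  # 524288, above the float16 maximum of 65504

# Descale in float16: the scale itself becomes inf, and 0 * inf = NaN.
print(data * scale.astype(np.float16))  # [ nan -inf]

# Descale in float32, then cast back: 0 stays 0, the overflow becomes -inf.
arr = (data.astype(np.float32) * scale).astype(np.float16)
print(arr, np.exp(arr))  # [  0. -inf] [1. 0.]
```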


@core.register_scaled_lax_op
def scaled_log(val: ScaledArray) -> ScaledArray:
-    return scaled_op_default_translation(lax.log_p, [val])
+    assert isinstance(val, ScaledArray)
+    # Log of data & scale components.
+    log_data = lax.log(val.data)
+    log_scale = lax.log(val.scale).astype(val.dtype)
+    data = log_data + log_scale
+    scale = np.array(1, dtype=val.scale.dtype)
+    return ScaledArray(data, scale)
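The split above relies on log(scale * data) = log(data) + log(scale), so the (possibly huge) product never has to be materialized in the low-precision dtype, and the result is returned with scale 1 since log strongly compresses the range. A quick NumPy check of the identity with the test's scale value (illustrative sketch, not part of the diff):

```python
import numpy as np

data = np.array([1.0, 2.0], dtype=np.float16)
scale = np.float32(32768 * 16)

reference = np.log(data.astype(np.float32) * scale)  # log of the descaled values
split = np.log(data).astype(np.float32) + np.log(scale)  # log of data & scale components
print(np.allclose(reference, split, atol=1e-3))  # True (up to float16 rounding)
```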


@core.register_scaled_lax_op
19 changes: 18 additions & 1 deletion tests/lax/test_scaled_ops_l2.py
@@ -61,7 +61,24 @@ def test__scaled_unary_op__proper_result_and_scaling(self, prim, dtype, expected
        assert out.dtype == val.dtype
        assert out.scale.dtype == val.scale.dtype
        npt.assert_almost_equal(out.scale, expected_scale)
-        npt.assert_array_almost_equal(out, expected_output)
+        # FIXME: higher precision for `log`?
+        npt.assert_array_almost_equal(out, expected_output, decimal=3)

+    def test__scaled_exp__large_scale_zero_values(self):
+        scaled_op, _ = find_registered_scaled_op(lax.exp_p)
+        # Scaled array, with values < 0 and scale overflowing in float16.
+        val = scaled_array(np.array([0, -1, -2, -32768], np.float16), np.float32(32768 * 16))
+        out = scaled_op(val)
+        # Zero value should not be a NaN!
+        npt.assert_array_almost_equal(out, [1, 0, 0, 0], decimal=2)

+    def test__scaled_log__zero_large_values_large_scale(self):
+        scaled_op, _ = find_registered_scaled_op(lax.log_p)
+        # 0 + large values => proper log values, without NaN/overflow.
+        val = scaled_array(np.array([0, 1], np.float16), np.float32(32768 * 16))
+        out = scaled_op(val)
+        # No NaN value + not overflowing!
+        npt.assert_array_almost_equal(out, lax.log(val.to_array(np.float32)), decimal=2)


class ScaledTranslationBinaryOpsTests(chex.TestCase):
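Both new tests use `np.float32(32768 * 16)` as the scale: 524288 is well above the float16 maximum of 65504, so the scale cannot survive a cast to the data dtype, which is exactly the overflow scenario being exercised. A quick check of that premise:

```python
import numpy as np

scale = np.float32(32768 * 16)
print(scale)  # 524288.0
print(np.finfo(np.float16).max)  # 65504.0
print(scale.astype(np.float16))  # inf -- the overflow the new tests exercise
```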
11 changes: 10 additions & 1 deletion tests/lax/test_scipy_integration.py
@@ -3,16 +3,25 @@
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized
+from jax import lax

from jax_scaled_arithmetics.core import autoscale, scaled_array


-class ScaledTranslationPrimitivesTests(chex.TestCase):
+class ScaledScipyHighLevelMethodsTests(chex.TestCase):
    def setUp(self):
        super().setUp()
        # Use random state for reproducibility!
        self.rs = np.random.RandomState(42)

+    def test__lax_full_like__zero_scale(self):
+        def fn(a):
+            return lax.full_like(a, 0)

+        a = scaled_array(np.random.rand(3, 5).astype(np.float32), np.float32(1))
+        autoscale(fn)(a)
+        # FIXME/TODO: what should be the expected result?

    @parameterized.parameters(
        {"dtype": np.float32},
        {"dtype": np.float16},