From 69d428efb406876dbfab034e3319f29584b40a88 Mon Sep 17 00:00:00 2001
From: Paul Balanca
Date: Thu, 11 Jan 2024 13:30:51 +0000
Subject: [PATCH] Use `float32` scale in AutoScale MNIST training example.
 (#80)

In some configurations, `float16` does not have enough dynamic range to
represent scaling, meaning the scale factor can quickly overflow. Using
`float32` (or equivalently `bfloat16`) should allow a bit more margin.
---
 experiments/mnist/mnist_classifier_from_scratch.py | 12 +++++++-----
 jax_scaled_arithmetics/core/datatype.py            |  6 ++++++
 .../lax/base_scaling_primitives.py                 |  1 -
 tests/core/test_datatype.py                        | 10 ++++++++++
 4 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/experiments/mnist/mnist_classifier_from_scratch.py b/experiments/mnist/mnist_classifier_from_scratch.py
index d123da0..7e78cc6 100644
--- a/experiments/mnist/mnist_classifier_from_scratch.py
+++ b/experiments/mnist/mnist_classifier_from_scratch.py
@@ -68,7 +68,9 @@ def accuracy(params, batch):
     step_size = 0.001
     num_epochs = 10
     batch_size = 128
+    training_dtype = np.float16
+    scale_dtype = np.float32

     train_images, train_labels, test_images, test_labels = datasets.mnist()
     num_train = train_images.shape[0]
@@ -86,8 +88,8 @@ def data_stream():
     batches = data_stream()
     params = init_random_params(param_scale, layer_sizes)
     # Transform parameters to `ScaledArray` and proper dtype.
-    params = jsa.as_scaled_array(params)
-    params = jax.tree_map(lambda v: v.astype(training_dtype), params)
+    params = jsa.as_scaled_array(params, scale=scale_dtype(1))
+    params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)

     @jit
     @jsa.autoscale
@@ -100,10 +102,10 @@ def update(params, batch):
         for _ in range(num_batches):
             batch = next(batches)
             # Scaled micro-batch + training dtype cast.
-            batch = jsa.as_scaled_array(batch)
-            batch = jax.tree_map(lambda v: v.astype(training_dtype), batch)
+            batch = jsa.as_scaled_array(batch, scale=scale_dtype(1))
+            batch = jax.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)

-            with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN):
+            with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
                 params = update(params, batch)
         epoch_time = time.time() - start_time
diff --git a/jax_scaled_arithmetics/core/datatype.py b/jax_scaled_arithmetics/core/datatype.py
index 0cad268..bc08040 100644
--- a/jax_scaled_arithmetics/core/datatype.py
+++ b/jax_scaled_arithmetics/core/datatype.py
@@ -94,6 +94,12 @@ def aval(self) -> ShapedArray:
         """Abstract value of the scaled array, i.e. shape and dtype."""
         return ShapedArray(self.data.shape, self.data.dtype)

+    def astype(self, dtype) -> "ScaledArray":
+        """Convert the ScaledArray to a dtype.
+        NOTE: only impacting `data` field, not the `scale` tensor.
+        """
+        return ScaledArray(self.data.astype(dtype), self.scale)
+

 def make_scaled_scalar(val: Array) -> ScaledArray:
     """Make a scaled scalar (array), from a single value.
diff --git a/jax_scaled_arithmetics/lax/base_scaling_primitives.py b/jax_scaled_arithmetics/lax/base_scaling_primitives.py
index 81878f8..defa590 100644
--- a/jax_scaled_arithmetics/lax/base_scaling_primitives.py
+++ b/jax_scaled_arithmetics/lax/base_scaling_primitives.py
@@ -189,7 +189,6 @@ def get_data_scale_abstract_eval(values: core.ShapedArray) -> core.ShapedArray:
         return (values.data, values.scale)
     # Use array dtype for scale by default.
     scale_dtype = get_scale_dtype() or values.dtype
-    print(scale_dtype)
     return values, core.ShapedArray((), dtype=scale_dtype)
diff --git a/tests/core/test_datatype.py b/tests/core/test_datatype.py
index fee2f03..3aef8c4 100644
--- a/tests/core/test_datatype.py
+++ b/tests/core/test_datatype.py
@@ -103,6 +103,16 @@ def test__scaled_array__numpy_array_interface(self, npapi):
         assert isinstance(out, np.ndarray)
         npt.assert_array_equal(out, sarr.data * sarr.scale)

+    @parameterized.parameters(
+        {"npapi": np},
+        {"npapi": jnp},
+    )
+    def test__scaled_array__astype(self, npapi):
+        arr_in = ScaledArray(data=npapi.array([1.0, 2.0], dtype=np.float16), scale=npapi.array(1, dtype=np.int32))
+        arr_out = arr_in.astype(np.float32)
+        assert arr_out.dtype == np.float32
+        assert arr_out.scale.dtype == arr_in.scale.dtype
+
     @parameterized.parameters(
         {"val": 0.25},
     )
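
Illustration (not part of the patch): a standalone NumPy sketch of the
dynamic-range issue the commit message describes. A pow2 rescaling scheme
doubles or halves the scale factor, and `float16` runs out of range after
only a handful of doublings, while `float32` has ample headroom. The loop
bound of 20 is arbitrary, chosen only to cross the `float16` limit.

    import numpy as np

    scale16 = np.float16(1.0)
    scale32 = np.float32(1.0)
    for _ in range(20):
        # Each step mimics one pow2 scale increase.
        scale16 = np.float16(scale16 * 2)
        scale32 = np.float32(scale32 * 2)

    print(scale16)  # inf -- float16 maxes out at 65504, so 2**16 already overflows
    print(scale32)  # 1048576.0 -- float32 max is ~3.4e38, plenty of margin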