Use float32 scale in AutoScale MNIST training example.
In some configurations, `float16` does not have enough dynamic range to represent the scale,
meaning the scale factor can quickly overflow. Using `float32` (or `bfloat16`, which has the
same dynamic range) leaves far more headroom.
balancap committed Jan 11, 2024
1 parent 7329bc7 commit 4e7d44a
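
For intuition (not part of the commit), a minimal NumPy sketch of the dynamic-range gap described above: `float16` tops out around 65504, so a scale factor that keeps growing overflows to `inf` within a handful of steps, while `float32` can represent values up to roughly 3.4e38.

    import numpy as np

    scale16 = np.float16(1.0)
    scale32 = np.float32(1.0)
    for _ in range(20):
        # Grow the scale by 8x per step; NumPy may warn once float16 overflows.
        scale16 = np.float16(scale16 * 8)
        scale32 = np.float32(scale32 * 8)
    print(scale16)  # inf: 8**20 (~1.15e18) exceeds the float16 maximum of ~65504
    print(scale32)  # ~1.15e18, still comfortably representable in float32
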
Showing 4 changed files with 23 additions and 6 deletions.
experiments/mnist/mnist_classifier_from_scratch.py (7 additions, 5 deletions)
@@ -68,7 +68,9 @@ def accuracy(params, batch):
step_size = 0.001
num_epochs = 10
batch_size = 128

training_dtype = np.float16
+scale_dtype = np.float32

train_images, train_labels, test_images, test_labels = datasets.mnist()
num_train = train_images.shape[0]
@@ -86,8 +88,8 @@ def data_stream():
batches = data_stream()
params = init_random_params(param_scale, layer_sizes)
# Transform parameters to `ScaledArray` and proper dtype.
-params = jsa.as_scaled_array(params)
-params = jax.tree_map(lambda v: v.astype(training_dtype), params)
+params = jsa.as_scaled_array(params, scale=scale_dtype(1))
+params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)

@jit
@jsa.autoscale
@@ -100,10 +102,10 @@ def update(params, batch):
for _ in range(num_batches):
    batch = next(batches)
    # Scaled micro-batch + training dtype cast.
-    batch = jsa.as_scaled_array(batch)
-    batch = jax.tree_map(lambda v: v.astype(training_dtype), batch)
+    batch = jsa.as_scaled_array(batch, scale=scale_dtype(1))
+    batch = jax.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)

-    with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN):
+    with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
        params = update(params, batch)

epoch_time = time.time() - start_time
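
Taken together, the updated example follows roughly the sketch below, which only rearranges the `jsa` calls visible in this diff; the `import jax_scaled_arithmetics as jsa` alias and the toy parameter pytree are assumptions for illustration.

    import numpy as np
    import jax
    import jax_scaled_arithmetics as jsa  # assumed alias matching the `jsa.` calls above

    training_dtype = np.float16
    scale_dtype = np.float32

    # Toy stand-in for the MLP parameters: a pytree of (weights, biases) pairs.
    params = [(np.ones((8, 8), np.float32), np.zeros((8,), np.float32))]

    # 1. Wrap every leaf as a ScaledArray carrying an explicit float32 scale of 1.
    params = jsa.as_scaled_array(params, scale=scale_dtype(1))
    # 2. Cast only the `data` part to float16; `is_leaf` keeps tree_map from
    #    recursing into the (data, scale) fields of each ScaledArray.
    params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)

    # 3. Keep scale propagation in float32 inside the autoscale interpreter as well:
    # with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
    #     params = update(params, batch)
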
jax_scaled_arithmetics/core/datatype.py (6 additions, 0 deletions)
@@ -94,6 +94,12 @@ def aval(self) -> ShapedArray:
"""Abstract value of the scaled array, i.e. shape and dtype."""
return ShapedArray(self.data.shape, self.data.dtype)

def astype(self, dtype) -> "ScaledArray":
"""Convert the ScaledArray to a dtype.
NOTE: only impacting `data` field, not the `scale` tensor.
"""
return ScaledArray(self.data.astype(dtype), self.scale)


def make_scaled_scalar(val: Array) -> ScaledArray:
"""Make a scaled scalar (array), from a single value.
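
A small usage sketch of the new `astype` method, mirroring the unit test added below (the `from jax_scaled_arithmetics.core import ScaledArray` import path is an assumption):

    import numpy as np
    from jax_scaled_arithmetics.core import ScaledArray  # assumed import path

    arr = ScaledArray(data=np.array([1.0, 2.0], dtype=np.float16), scale=np.array(1, dtype=np.int32))
    out = arr.astype(np.float32)
    print(out.data.dtype)   # float32: only the `data` field is converted
    print(out.scale.dtype)  # int32: the `scale` tensor keeps its original dtype
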
jax_scaled_arithmetics/lax/base_scaling_primitives.py (0 additions, 1 deletion)
@@ -189,7 +189,6 @@ def get_data_scale_abstract_eval(values: core.ShapedArray) -> core.ShapedArray:
return (values.data, values.scale)
# Use array dtype for scale by default.
scale_dtype = get_scale_dtype() or values.dtype
-print(scale_dtype)
return values, core.ShapedArray((), dtype=scale_dtype)


tests/core/test_datatype.py (10 additions, 0 deletions)
@@ -103,6 +103,16 @@ def test__scaled_array__numpy_array_interface(self, npapi):
    assert isinstance(out, np.ndarray)
    npt.assert_array_equal(out, sarr.data * sarr.scale)

+@parameterized.parameters(
+    {"npapi": np},
+    {"npapi": jnp},
+)
+def test__scaled_array__astype(self, npapi):
+    arr_in = ScaledArray(data=npapi.array([1.0, 2.0], dtype=np.float16), scale=npapi.array(1, dtype=np.int32))
+    arr_out = arr_in.astype(np.float32)
+    assert arr_out.dtype == np.float32
+    assert arr_out.scale.dtype == arr_in.scale.dtype

@parameterized.parameters(
    {"val": 0.25},
)
