Support FP16 MNIST classifier AutoScale training. (#67)
Experimental results on this commit:

**FP32 Normal**

Training set accuracy 0.97117
Test set accuracy 0.94070

**FP32 AutoScale**

Training set accuracy 0.97222
Test set accuracy 0.94030

**FP16 Normal**

Training set accuracy 0.96815
Test set accuracy 0.93830

**FP16 AutoScale**

Training set accuracy 0.96772
Test set accuracy 0.93770
balancap authored Jan 3, 2024
1 parent 9f7391b commit d52e870
Showing 1 changed file with 16 additions and 6 deletions: experiments/mnist/mnist_classifier_from_scratch.py
```diff
@@ -21,7 +21,9 @@
 import time
 
 import datasets
+import jax
 import jax.numpy as jnp
+import numpy as np
 import numpy.random as npr
 from jax import grad, jit
 from jax.scipy.special import logsumexp
```
```diff
@@ -60,10 +62,11 @@ def accuracy(params, batch):
 
 if __name__ == "__main__":
     layer_sizes = [784, 1024, 1024, 10]
-    param_scale = 0.1
+    param_scale = 1.0
     step_size = 0.001
     num_epochs = 10
     batch_size = 128
+    training_dtype = np.float16
 
     train_images, train_labels, test_images, test_labels = datasets.mnist()
     num_train = train_images.shape[0]
```
```diff
@@ -80,8 +83,9 @@ def data_stream():
 
     batches = data_stream()
     params = init_random_params(param_scale, layer_sizes)
-    # Transform parameters to `ScaledArray`
+    # Transform parameters to `ScaledArray` and proper dtype.
     params = jsa.as_scaled_array(params)
+    params = jax.tree_map(lambda v: v.astype(training_dtype), params)
 
     @jit
     @jsa.autoscale
```
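For context, the `update` step that `@jit` and `@jsa.autoscale` decorate here is the plain SGD step of the upstream JAX "MNIST classifier from scratch" example, defined inside the `__main__` block of the script. A rough reproduction is sketched below (the exact body in this file may differ slightly; `predict`, `step_size` and `loss` come from that example). Note that the diff never touches its body: the function runs unchanged on `ScaledArray` inputs once decorated.

```python
def loss(params, batch):
    # Negative log-likelihood of the one-hot targets (upstream JAX example).
    inputs, targets = batch
    preds = predict(params, inputs)
    return -jnp.mean(jnp.sum(preds * targets, axis=1))

@jit
@jsa.autoscale
def update(params, batch):
    # Vanilla SGD step on a list of (weights, biases) layer parameters.
    grads = grad(loss)(params, batch)
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]
```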
```diff
@@ -93,13 +97,19 @@ def update(params, batch):
         start_time = time.time()
         for _ in range(num_batches):
             batch = next(batches)
+            # Scaled micro-batch + training dtype cast.
             batch = jsa.as_scaled_array(batch)
-            params = update(params, batch)
+            batch = jax.tree_map(lambda v: v.astype(training_dtype), batch)
+
+            with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN):
+                params = update(params, batch)
+
         epoch_time = time.time() - start_time
 
-        raw_params = jsa.asarray(params)
+        # Evaluation in float32, for consistency.
+        raw_params = jsa.asarray(params, dtype=np.float32)
         train_acc = accuracy(raw_params, (train_images, train_labels))
         test_acc = accuracy(raw_params, (test_images, test_labels))
         print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
-        print(f"Training set accuracy {train_acc}")
-        print(f"Test set accuracy {test_acc}")
+        print(f"Training set accuracy {train_acc:0.5f}")
+        print(f"Test set accuracy {test_acc:0.5f}")
```
