Draft: Measure mean/std in MNIST training #72

Draft · wants to merge 4 commits into base: main
68 changes: 54 additions & 14 deletions experiments/mnist/mnist_classifier_from_scratch.py
@@ -16,9 +16,8 @@

The primary aim here is simplicity and minimal dependencies.
"""


import time
from functools import partial

import datasets
import jax
@@ -31,27 +30,60 @@
import jax_scaled_arithmetics as jsa


def print_mean_std(name, v):
data, scale = jsa.lax.get_data_scale(v)
# Always use np.float32 to avoid floating-point errors in descaling + stats.
v = jsa.asarray(data, dtype=np.float32)
m, s = np.mean(v), np.std(v)
# print(data)
print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale:.4f})")


def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
return [(scale * rng.randn(m, n), scale * rng.randn(n)) for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]


def predict(params, inputs):
activations = inputs
for w, b in params[:-1]:
jsa.ops.debug_callback(partial(print_mean_std, "W"), w)
jsa.ops.debug_callback(partial(print_mean_std, "B"), b)
(w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WGrad"), w)

# Matmul + relu
outputs = jnp.dot(activations, w) + b
activations = jnp.maximum(outputs, 0)
jsa.ops.debug_callback(partial(print_mean_std, "Act"), activations)
# activations = jsa.ops.dynamic_rescale_l2_grad(activations)

final_w, final_b = params[-1]
logits = jnp.dot(activations, final_w) + final_b
# Dynamic rescaling of the gradient, as the logits gradient is not properly scaled.
logits = jnp.dot(activations, final_w)
jsa.ops.debug_callback(partial(print_mean_std, "Logits0"), logits)
logits = logits + final_b

jsa.ops.debug_callback(partial(print_mean_std, "Logits1"), logits)
(logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)

logits = jsa.ops.dynamic_rescale_l2_grad(logits)
return logits - logsumexp(logits, axis=1, keepdims=True)
# logits = logits.astype(np.float32)
(logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)

logits = logits - logsumexp(logits, axis=1, keepdims=True)
jsa.ops.debug_callback(partial(print_mean_std, "Logits2"), logits)
# (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
return logits


def loss(params, batch):
inputs, targets = batch
preds = predict(params, inputs)
# jsa.ops.debug_callback(partial(print_mean_std, "Preds"), preds)
loss = jnp.sum(preds * targets, axis=1)
# loss = jsa.ops.dynamic_rescale_l2(loss)
# jsa.ops.debug_callback(partial(print_mean_std, "LOSS1"), loss)
loss = -jnp.mean(loss)
# jsa.ops.debug_callback(partial(print_mean_std, "LOSS2"), loss)
return loss
return -jnp.mean(jnp.sum(preds * targets, axis=1))


@@ -64,7 +96,7 @@ def accuracy(params, batch):

if __name__ == "__main__":
layer_sizes = [784, 1024, 1024, 10]
param_scale = 1.0
param_scale = 2.0
step_size = 0.001
num_epochs = 10
batch_size = 128
@@ -88,18 +120,26 @@ def data_stream():
batches = data_stream()
params = init_random_params(param_scale, layer_sizes)
# Transform parameters to `ScaledArray` and proper dtype.
params = jsa.as_scaled_array(params, scale=scale_dtype(1))
params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)

@jit
# @jit
@jsa.autoscale
def update(params, batch):
grads = grad(loss)(params, batch)
return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
return [
(jsa.ops.dynamic_rescale_l1(w - step_size * dw), jsa.ops.dynamic_rescale_l1(b - step_size * db))
for (w, b), (dw, db) in zip(params, grads)
]

num_batches = 1
num_epochs = 1
for epoch in range(num_epochs):
# print("EPOCH:", epoch)
start_time = time.time()
for _ in range(num_batches):
# print("BATCH...")
batch = next(batches)
# Scaled micro-batch + training dtype cast.
batch = jsa.as_scaled_array(batch, scale=scale_dtype(1))
@@ -111,9 +151,9 @@ def update(params, batch):
epoch_time = time.time() - start_time

# Evaluation in float32, for consistency.
raw_params = jsa.asarray(params, dtype=np.float32)
train_acc = accuracy(raw_params, (train_images, train_labels))
test_acc = accuracy(raw_params, (test_images, test_labels))
print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
print(f"Training set accuracy {train_acc:0.5f}")
print(f"Test set accuracy {test_acc:0.5f}")
# raw_params = jsa.asarray(params, dtype=np.float32)
# train_acc = accuracy(raw_params, (train_images, train_labels))
# test_acc = accuracy(raw_params, (test_images, test_labels))
# print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
# print(f"Training set accuracy {train_acc:0.5f}")
# print(f"Test set accuracy {test_acc:0.5f}")
6 changes: 4 additions & 2 deletions jax_scaled_arithmetics/ops/rescaling.py
@@ -2,6 +2,8 @@
from functools import partial

import jax

# import jax.numpy as jnp
import numpy as np

from jax_scaled_arithmetics.core import ScaledArray, pow2_round
@@ -48,7 +50,7 @@ def dynamic_rescale_max_base(arr: ScaledArray) -> ScaledArray:
data_sq = jax.lax.abs(data)
axes = tuple(range(data.ndim))
# Get MAX norm + pow2 rounding.
norm = jax.lax.reduce_max_p.bind(data_sq, axes=axes)
norm = jax.lax.reduce_max_p.bind(data_sq, axes=axes) + np.float32(1e-3)
norm = pow2_round(norm.astype(scale.dtype))
# Rebalancing based on norm.
return rebalance(arr, norm)
@@ -63,7 +65,7 @@ def dynamic_rescale_l1_base(arr: ScaledArray) -> ScaledArray:
data_sq = jax.lax.abs(data.astype(np.float32))
axes = tuple(range(data.ndim))
# Get L1 norm + pow2 rounding.
norm = jax.lax.reduce_sum_p.bind(data_sq, axes=axes) / data.size
norm = jax.lax.reduce_sum_p.bind(data_sq, axes=axes) / data.size + np.float32(1e-3)
norm = pow2_round(norm.astype(scale.dtype))
# Rebalancing based on norm.
return rebalance(arr, norm)