Skip to content

Commit

Permalink
Simplify JAX Scalify MNIST examples using jax_scalify.tree methods.
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Jun 28, 2024
1 parent e7e0562 commit 26b54e0
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 19 deletions.
4 changes: 2 additions & 2 deletions examples/mnist/mnist_classifier_from_scratch.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def data_stream():
params = init_random_params(param_scale, layer_sizes)
# Transform parameters to `ScaledArray` and proper dtype.
params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)
params = jsa.tree.astype(params, training_dtype)

@jit
@jsa.scalify
Expand All @@ -119,7 +119,7 @@ def update(params, batch):
batch = next(batches)
# Scaled micro-batch + training dtype cast.
batch = jsa.as_scaled_array(batch, scale=scale_dtype(1))
batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)
batch = jsa.tree.astype(batch, training_dtype)

with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
params = update(params, batch)
Expand Down
5 changes: 2 additions & 3 deletions examples/mnist/mnist_classifier_from_scratch_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import time

import datasets
import jax
import jax.numpy as jnp
import ml_dtypes
import numpy as np
Expand Down Expand Up @@ -133,7 +132,7 @@ def data_stream():
params = init_random_params(param_scale, layer_sizes)
# Transform parameters to `ScaledArray` and proper dtype.
params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)
params = jsa.tree.astype(params, training_dtype)

@jit
@jsa.scalify
Expand All @@ -147,7 +146,7 @@ def update(params, batch):
batch = next(batches)
# Scaled micro-batch + training dtype cast.
batch = jsa.as_scaled_array(batch, scale=scale_dtype(1))
batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)
batch = jsa.tree.astype(batch, training_dtype)

with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
params = update(params, batch)
Expand Down
19 changes: 8 additions & 11 deletions examples/mnist/mnist_classifier_mlp_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def update(model, optimizer, model_state, opt_state, batch):
key = jax.random.PRNGKey(42)
use_scalify: bool = True

# training_dtype = np.dtype(np.float16)
training_dtype = np.dtype(np.float16)
optimizer_dtype = np.dtype(np.float16)
scale_dtype = np.float32

train_images, train_labels, test_images, test_labels = datasets.mnist()
Expand All @@ -102,27 +102,24 @@ def data_stream():
model_state = model.init(key, np.zeros((batch_size, mnist_img_size), dtype=training_dtype))
# Optimizer & optimizer state.
# opt = optax.sgd(learning_rate=step_size)
opt = optax.adam(learning_rate=step_size, eps=1e-5)
opt = optax.adam(learning_rate=step_size, eps=2**-16)
opt_state = opt.init(model_state)
# Freeze model, optimizer (with step size).
update_fn = partial(update, model, opt)

if use_scalify:
# Transform parameters to `ScaledArray` and proper dtype.
# Transform parameters to `ScaledArray`.
model_state = jsa.as_scaled_array(model_state, scale=scale_dtype(1.0))
opt_state = jsa.as_scaled_array(opt_state, scale=scale_dtype(0.0001))

model_state = jax.tree_util.tree_map(
lambda v: v.astype(training_dtype), model_state, is_leaf=jsa.core.is_scaled_leaf
)
# Scalify the update function as well.
update_fn = jsa.scalify(update_fn)
else:
model_state = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), model_state)
# Convert the model state (weights) & optimizer state to proper dtype.
model_state = jsa.tree.astype(model_state, training_dtype)
opt_state = jsa.tree.astype(opt_state, optimizer_dtype, floating_only=True)

print(f"Using Scalify: {use_scalify}")
print(f"Training data format: {training_dtype.name}")
# print(f"Optimizer data format: {training_dtype.name}")
print(f"Optimizer data format: {optimizer_dtype.name}")
print("")

update_fn = jax.jit(update_fn)
Expand All @@ -134,7 +131,7 @@ def data_stream():
for _ in range(num_batches):
batch = next(batches)
# Scaled micro-batch + training dtype cast.
batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch)
batch = jsa.tree.astype(batch, training_dtype)
if use_scalify:
batch = jsa.as_scaled_array(batch, scale=scale_dtype(1))
with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
Expand Down
9 changes: 6 additions & 3 deletions jax_scalify/core/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,13 @@ def make_scaled_scalar(val: Array, scale_dtype: Optional[DTypeLike] = None) -> S


def is_scaled_leaf(val: Any) -> bool:
"""Is input a JAX PyTree (scaled) leaf, including ScaledArray.
"""Is input a normal JAX PyTree leaf (i.e. `Array`) or `ScaledArray1.
This function is useful for JAX PyTree handling where the user wants
to keep the ScaledArray datastructures (i.e. not flattened as a pair of arrays).
This function is useful for JAX PyTree handling with `jax.tree` methods where
the user wants to keep the ScaledArray data structures (i.e. not flattened as a
pair of arrays).
See `jax_scalify.tree` for PyTree `jax.tree` methods compatible with `ScaledArray`.
"""
# TODO: check Numpy scalars as well?
return np.isscalar(val) or isinstance(val, (Array, np.ndarray, ScaledArray))
Expand Down

0 comments on commit 26b54e0

Please sign in to comment.