Log softmax grad analysis #91

Closed · wants to merge 2 commits
54 changes: 52 additions & 2 deletions experiments/mnist/mnist_classifier_from_scratch.py
@@ -28,9 +28,17 @@
from jax import grad, jit, lax

import jax_scaled_arithmetics as jsa
from functools import partial

# from jax.scipy.special import logsumexp

def print_mean_std(name, v):
    data, scale = jsa.lax.get_data_scale(v)
    # Always use np.float32, to avoid floating point errors in descaling + stats.
    data = jsa.asarray(data, dtype=np.float32)
    m, s, mn, mx = np.mean(data), np.std(data), np.min(data), np.max(data)
    print(f"{name}: MEAN({m:.5f}) / STD({s:.5f}) / MIN({mn:.5f}) / MAX({mx:.5f}) / SCALE({scale:.5f})")


def logsumexp(a, axis=None, keepdims=False):
    dims = (axis,)
@@ -47,6 +55,27 @@ def logsumexp(a, axis=None, keepdims=False):
def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
    return [(scale * rng.randn(m, n), scale * rng.randn(n)) for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]

# jax.nn.logsumexp

def one_hot_dot(logits, mask, axis: int):
    size = logits.shape[axis]

    mask = jsa.lax.rebalance(mask, np.float32(1./8.))

    jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
    (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
    jsa.ops.debug_callback(partial(print_mean_std, "Mask"), mask)

    r = jnp.sum(logits * mask, axis=axis)
    jsa.ops.debug_callback(partial(print_mean_std, "Out"), r)
    print("SIZE:", size, jsa.core.pow2_round_down(np.float32(size)))
    (r,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "OutGrad"), r)
    return r

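# Reviewer note (not part of the commit): for a one-hot `mask`, one_hot_dot reduces to selecting
# the target-class logit per row; a plain-JAX reference, ignoring the scaled-arithmetic hooks,
# would simply be:
#
#     def one_hot_dot_ref(logits, mask, axis=1):
#         return jnp.sum(logits * mask, axis=axis)
#
# The jsa.lax.rebalance and debug_callback / debug_callback_grad calls above appear to only adjust
# the ScaledArray scale of the mask and log per-tensor statistics on the forward and backward
# passes, without changing the computed value.
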
def predict(params, inputs):
    activations = inputs
@@ -58,14 +87,32 @@ def predict(params, inputs):
    final_w, final_b = params[-1]
    logits = jnp.dot(activations, final_w) + final_b
    # Dynamic rescaling of the gradient, as logits gradient not properly scaled.
    logits = jsa.ops.dynamic_rescale_l2_grad(logits)
    logits = logits - logsumexp(logits, axis=1, keepdims=True)
    # logits = jsa.ops.dynamic_rescale_l2_grad(logits)

    # print("LOGITS", logits)
    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad0"), logits)
    logsumlogits = logsumexp(logits, axis=1, keepdims=True)
    # (logsumlogits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsLogSumGrad"), logsumlogits)
    logits = logits - logsumlogits
    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad1"), logits)

    # logits = jsa.ops.dynamic_rescale_l1_grad(logits)
    return logits


def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    loss = one_hot_dot(preds, targets, axis=1)
    # loss = jnp.sum(preds * targets, axis=1)
    # loss = jsa.ops.dynamic_rescale_l1_grad(loss)
    (loss,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LossGrad2"), loss)
    loss = -jnp.mean(loss)
    jsa.ops.debug_callback(partial(print_mean_std, "Loss"), loss)
    (loss,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LossGrad"), loss)
    return loss
    return -jnp.mean(jnp.sum(preds * targets, axis=1))


@@ -111,6 +158,9 @@ def update(params, batch):
    grads = grad(loss)(params, batch)
    return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]

num_batches = 2
num_epochs = 1

for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
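Reviewer note: a minimal plain-JAX check of the gradient being analysed here (standalone, outside the scaled-arithmetic machinery; shapes and seed chosen arbitrarily for illustration). For a mean cross-entropy loss over one-hot targets, the gradient with respect to the logits is (softmax(logits) - onehot) / batch_size, so its entries shrink with the batch size, which is consistent with the comment in the diff that the logits gradient is not properly scaled.

import jax
import jax.numpy as jnp
import numpy as np

def cross_entropy(logits, onehot):
    # log-softmax followed by one-hot selection and mean reduction, as in the script above.
    logp = logits - jax.scipy.special.logsumexp(logits, axis=1, keepdims=True)
    return -jnp.mean(jnp.sum(logp * onehot, axis=1))

B, C = 128, 10
rng = np.random.RandomState(0)
logits = jnp.asarray(rng.randn(B, C).astype(np.float32))
onehot = jnp.asarray(np.eye(C, dtype=np.float32)[rng.randint(0, C, size=B)])

g_autodiff = jax.grad(cross_entropy)(logits, onehot)
g_analytic = (jax.nn.softmax(logits, axis=1) - onehot) / B
print(np.max(np.abs(np.asarray(g_autodiff - g_analytic))))  # agreement up to float32 rounding
print(np.std(np.asarray(g_analytic)))                       # gradient entries are O(1/batch_size)
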
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/ops/rescaling.py
@@ -51,7 +51,7 @@ def dynamic_rescale_max_base(arr: ScaledArray) -> ScaledArray:
    data_sq = jax.lax.abs(data)
    axes = tuple(range(data.ndim))
    # Get MAX norm + pow2 rounding.
    norm = jax.lax.reduce_max_p.bind(data_sq, axes=axes)
    norm = jax.lax.reduce_max_p.bind(data_sq, axes=axes) / (64)
    norm = jax.lax.max(pow2_round(norm).astype(scale.dtype), eps.astype(scale.dtype))
    # Rebalancing based on norm.
    return rebalance(arr, norm)
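Reviewer note on the rescaling.py change: a sketch of the intended effect, assuming rebalance(arr, norm) folds norm into the ScaledArray scale and divides the data by it. Dividing the max norm by 64 (a power of two) before pow2 rounding lowers the computed scale by 2**6, so the rebalanced data ends up with a max around 64 rather than around 1. The pow2_round below is a hypothetical stand-in (round to the nearest power of two), not the library implementation.

import numpy as np

def pow2_round(x):
    # Hypothetical stand-in: round a positive float to the nearest power of two.
    return np.float32(2.0 ** np.round(np.log2(x)))

data_max = np.float32(0.37)              # example max(|data|) of a scaled tensor
old_norm = pow2_round(data_max)          # 0.5       -> data rescaled to max ~1
new_norm = pow2_round(data_max / 64)     # 0.0078125 -> data rescaled to max ~64
print(old_norm, new_norm, old_norm / new_norm)  # ratio is 2**6 = 64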