Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Jan 9, 2024
1 parent 091d7cd commit dbb1c1d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
17 changes: 13 additions & 4 deletions experiments/mnist/mnist_classifier_from_scratch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
The primary aim here is simplicity and minimal dependencies.
"""
import time
from functools import partial

import datasets
import jax
Expand All @@ -28,8 +29,6 @@

import jax_scaled_arithmetics as jsa

# from functools import partial


def print_mean_std(name, v):
data, scale = jsa.lax.get_data_scale(v)
Expand Down Expand Up @@ -58,19 +57,29 @@ def predict(params, inputs):
final_w, final_b = params[-1]
logits = jnp.dot(activations, final_w) + final_b

# jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
# (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)

logits = jsa.ops.dynamic_rescale_l2_grad(logits)
# logits = logits.astype(np.float32)
# (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)

return logits - logsumexp(logits, axis=1, keepdims=True)
logits = logits - logsumexp(logits, axis=1, keepdims=True)
jsa.ops.debug_callback(partial(print_mean_std, "Logits2"), logits)
(logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
return logits


def loss(params, batch):
inputs, targets = batch
preds = predict(params, inputs)
jsa.ops.debug_callback(partial(print_mean_std, "Preds"), preds)
loss = jnp.sum(preds * targets, axis=1)
# loss = jsa.ops.dynamic_rescale_l2(loss)
jsa.ops.debug_callback(partial(print_mean_std, "LOSS1"), loss)
loss = -jnp.mean(loss)
jsa.ops.debug_callback(partial(print_mean_std, "LOSS2"), loss)
return loss
return -jnp.mean(jnp.sum(preds * targets, axis=1))


Expand Down
3 changes: 2 additions & 1 deletion jax_scaled_arithmetics/ops/rescaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from functools import partial

import jax
import jax.numpy as jnp

# import jax.numpy as jnp
import numpy as np

from jax_scaled_arithmetics.core import ScaledArray, pow2_round
Expand Down

0 comments on commit dbb1c1d

Please sign in to comment.