From 3c21e0583a0fd3eb79809d184164748bc89c29a2 Mon Sep 17 00:00:00 2001
From: Paul Balanca
Date: Thu, 18 Jan 2024 09:20:25 +0000
Subject: [PATCH 1/2] wip

---
 .../mnist/mnist_classifier_from_scratch.py | 54 ++++++++++++++++++-
 jax_scaled_arithmetics/ops/rescaling.py    |  2 +-
 2 files changed, 53 insertions(+), 3 deletions(-)

diff --git a/experiments/mnist/mnist_classifier_from_scratch.py b/experiments/mnist/mnist_classifier_from_scratch.py
index 5c5c719..a0e6597 100644
--- a/experiments/mnist/mnist_classifier_from_scratch.py
+++ b/experiments/mnist/mnist_classifier_from_scratch.py
@@ -28,9 +28,17 @@ from jax import grad, jit, lax
 
 import jax_scaled_arithmetics as jsa
+from functools import partial
 
 # from jax.scipy.special import logsumexp
 
 
+def print_mean_std(name, v):
+    data, scale = jsa.lax.get_data_scale(v)
+    # Always use np.float32, to avoid floating-point errors in descaling + stats.
+    data = jsa.asarray(data, dtype=np.float32)
+    m, s, mn, mx = np.mean(data), np.std(data), np.min(data), np.max(data)
+    print(f"{name}: MEAN({m:.5f}) / STD({s:.5f}) / MIN({mn:.5f}) / MAX({mx:.5f}) / SCALE({scale:.5f})")
+
 def logsumexp(a, axis=None, keepdims=False):
     dims = (axis,)
@@ -47,6 +55,27 @@ def logsumexp(a, axis=None, keepdims=False):
 
 def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
     return [(scale * rng.randn(m, n), scale * rng.randn(n)) for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])]
 
+# jax.nn.logsumexp
+
+def one_hot_dot(logits, mask, axis: int):
+    size = logits.shape[axis]
+
+    mask = jsa.lax.rebalance(mask, np.float32(1./8.))
+
+    jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
+    (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
+    jsa.ops.debug_callback(partial(print_mean_std, "Mask"), mask)
+
+    r = jnp.sum(logits * mask, axis=axis)
+    jsa.ops.debug_callback(partial(print_mean_std, "Out"), r)
+    print("SIZE:", size, jsa.core.pow2_round_down(np.float32(size)))
+    (r,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "OutGrad"), r)
+
+
+    return r
+
+
+
 def predict(params, inputs):
     activations = inputs
@@ -58,14 +87,32 @@ def predict(params, inputs):
     final_w, final_b = params[-1]
     logits = jnp.dot(activations, final_w) + final_b
     # Dynamic rescaling of the gradient, as logits gradient not properly scaled.
-    logits = jsa.ops.dynamic_rescale_l2_grad(logits)
-    logits = logits - logsumexp(logits, axis=1, keepdims=True)
+    # logits = jsa.ops.dynamic_rescale_l2_grad(logits)
+
+    # print("LOGITS", logits)
+    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad0"), logits)
+    logsumlogits = logsumexp(logits, axis=1, keepdims=True)
+    # (logsumlogits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsLogSumGrad"), logsumlogits)
+    logits = logits - logsumlogits
+    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad1"), logits)
+
+    # logits = jsa.ops.dynamic_rescale_l1_grad(logits)
     return logits
 
 
 def loss(params, batch):
     inputs, targets = batch
     preds = predict(params, inputs)
+    loss = one_hot_dot(preds, targets, axis=1)
+    # loss = jnp.sum(preds * targets, axis=1)
+    # loss = jsa.ops.dynamic_rescale_l1_grad(loss)
+    (loss,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LossGrad2"), loss)
+    loss = -jnp.mean(loss)
+    jsa.ops.debug_callback(partial(print_mean_std, "Loss"), loss)
+    (loss,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LossGrad"), loss)
+
+
+    return loss
 
     return -jnp.mean(jnp.sum(preds * targets, axis=1))
 
@@ -111,6 +158,9 @@ def update(params, batch):
         grads = grad(loss)(params, batch)
         return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
 
+    num_batches = 2
+    num_epochs = 1
+
     for epoch in range(num_epochs):
         start_time = time.time()
         for _ in range(num_batches):
diff --git a/jax_scaled_arithmetics/ops/rescaling.py b/jax_scaled_arithmetics/ops/rescaling.py
index f2e0325..e0daf86 100644
--- a/jax_scaled_arithmetics/ops/rescaling.py
+++ b/jax_scaled_arithmetics/ops/rescaling.py
@@ -51,7 +51,7 @@ def dynamic_rescale_max_base(arr: ScaledArray) -> ScaledArray:
     data_sq = jax.lax.abs(data)
     axes = tuple(range(data.ndim))
     # Get MAX norm + pow2 rounding.
-    norm = jax.lax.reduce_max_p.bind(data_sq, axes=axes)
+    norm = jax.lax.reduce_max_p.bind(data_sq, axes=axes) / 64
     norm = jax.lax.max(pow2_round(norm).astype(scale.dtype), eps.astype(scale.dtype))
     # Rebalancing based on norm.
     return rebalance(arr, norm)
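
Note: patch 1 splits the fused `logits - logsumexp(...)` expression into an explicit `logsumlogits` intermediate so the gradient flowing through each piece can be probed separately by the debug callbacks. For reference, a minimal plain-JAX sketch of the same max-subtraction trick (no scaled arithmetic; `stable_logsumexp` is a name local to this sketch, not a library function):

    import jax
    import jax.numpy as jnp
    import numpy as np

    def stable_logsumexp(a, axis=1, keepdims=True):
        # Subtract the row max so exp() cannot overflow; stop_gradient keeps
        # the backward pass from routing gradient through the piecewise-constant max.
        amax = jax.lax.stop_gradient(jnp.max(a, axis=axis, keepdims=True))
        out = jnp.log(jnp.sum(jnp.exp(a - amax), axis=axis, keepdims=keepdims))
        return out + (amax if keepdims else jnp.squeeze(amax, axis=axis))

    logits = np.random.randn(128, 10).astype(np.float32)
    log_probs = logits - stable_logsumexp(logits)  # log-softmax
    assert np.allclose(jnp.exp(log_probs).sum(axis=1), 1.0, atol=1e-5)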
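
Note: the rescaling.py change divides the max norm by 64 before power-of-two rounding, so (assuming `rebalance` divides the data by the given norm and folds it into the scale) the rebalanced data keeps a max magnitude near 64 rather than near 1. A sketch of power-of-two round-down, assuming `jsa.core.pow2_round_down` (called in `one_hot_dot` above) follows standard `frexp` semantics:

    import numpy as np

    def pow2_round_down(x: np.float32) -> np.float32:
        # frexp: x = mantissa * 2**exponent, with mantissa in [0.5, 1),
        # hence 2**(exponent - 1) is the largest power of two <= x.
        _, exponent = np.frexp(x)
        return np.float32(np.ldexp(1.0, exponent - 1))

    assert pow2_round_down(np.float32(10.0)) == np.float32(8.0)
    assert pow2_round_down(np.float32(64.0)) == np.float32(64.0)

Keeping scales as exact powers of two means the descaling multiplications are lossless in floating point.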
From 556641b2defc3ad84a8c8b6184d42d36f3fbb8ad Mon Sep 17 00:00:00 2001
From: Paul Balanca
Date: Thu, 18 Jan 2024 09:47:37 +0000
Subject: [PATCH 2/2] wip

---
 notebooks/log-softmax-analysis.ipynb | 335 +++++++++++++++++++++++++++
 1 file changed, 335 insertions(+)
 create mode 100644 notebooks/log-softmax-analysis.ipynb

diff --git a/notebooks/log-softmax-analysis.ipynb b/notebooks/log-softmax-analysis.ipynb
new file mode 100644
index 0000000..7f9b080
--- /dev/null
+++ b/notebooks/log-softmax-analysis.ipynb
@@ -0,0 +1,335 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 36,
+   "id": "40b36be1-307a-437b-a401-8411407993f0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "import jax\n",
+    "import jax.lax as lax\n",
+    "import jax.nn\n",
+    "import jax.numpy as jnp\n",
+    "import jax_scaled_arithmetics as jsa"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 52,
+   "id": "fddd3b72-ce99-4e96-bae4-b4a884e749e7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "B = 128\n",
+    "N = 10"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 53,
+   "id": "9462a6fe-5ad4-459d-8be5-b24f0f8fe7af",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "act = np.random.randn(B, N).astype(np.float32)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 63,
+   "id": "80c21df4-2715-4404-ab5d-3ddceb34f4e3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def logsumexp(a, axis=None, keepdims=True):\n",
+    "    dims = (axis,)\n",
+    "    amax = jnp.max(a, axis=dims, keepdims=keepdims)\n",
+    "    # FIXME: not proper scale propagation, introducing NaNs\n",
+    "    # amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))\n",
+    "    amax = lax.stop_gradient(amax)\n",
+    "    out = lax.sub(a, amax)\n",
+    "    out = lax.exp(out)\n",
+    "    out = lax.add(lax.log(jnp.sum(out, axis=dims, keepdims=keepdims)), amax)\n",
+    "    return out\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 64,
+   "id": "84349e45-5ce7-4fd2-9449-9d30c48de291",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(128, 1)"
+      ]
+     },
+     "execution_count": 64,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "def fn(act):\n",
+    "    return logsumexp(act, axis=1)\n",
+    "\n",
+    "fn(act).shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 58,
+   "id": "44aeddb0-948f-4da7-b55e-fa90f3548a97",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Tracedwith\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{ lambda ; a:f32[128,10]. let\n",
+       "    b:f32[128] = reduce_max[axes=(1,)] a\n",
+       "    c:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] b\n",
+       "    d:f32[128,1] = stop_gradient c\n",
+       "    e:f32[128,10] = sub a d\n",
+       "    f:f32[128,10] = exp e\n",
+       "    g:f32[128] = reduce_sum[axes=(1,)] f\n",
+       "    h:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] g\n",
+       "    i:f32[128,1] = log h\n",
+       "    j:f32[128,1] = add i d\n",
+       "  in (j,) }"
+      ]
+     },
+     "execution_count": 58,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "jax.make_jaxpr(fn)(act)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 47,
+   "id": "15cf4d4e-e367-4bc5-836c-d9e5b13ea3c9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "out, fn_vjp = jax.vjp(fn, act)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 48,
+   "id": "6f0d8366-1e44-4795-babc-2e009ed111d7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def fn_with_grad(in_act, out_grad):\n",
+    "    out_act, fn_vjp = jax.vjp(fn, in_act)\n",
+    "    return out_act, fn_vjp(out_grad)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 66,
+   "id": "097a08b4-1a61-4dd6-84f1-f4105a53d9e2",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{ lambda ; a:f32[128,10] b:f32[128,1]. let\n",
+       "    c:f32[128] = reduce_max[axes=(1,)] a\n",
+       "    d:f32[128,1] = reshape[dimensions=None new_sizes=(128, 1)] c\n",
+       "    e:bool[128,10] = eq a d\n",
+       "    f:f32[128,10] = convert_element_type[new_dtype=float32 weak_type=False] e\n",
+       "    _:f32[128] = reduce_sum[axes=(1,)] f\n",
+       "    g:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] c\n",
+       "    h:f32[128,1] = stop_gradient g\n",
+       "    i:f32[128,10] = sub a h\n",
+       "    j:f32[128,10] = exp i\n",
+       "    k:f32[128] = reduce_sum[axes=(1,)] j\n",
+       "    l:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] k\n",
+       "    m:f32[128,1] = log l\n",
+       "    n:f32[128,1] = add m h\n",
+       "    o:f32[128,1] = div b l\n",
+       "    p:f32[128] = reduce_sum[axes=(1,)] o\n",
+       "    q:f32[128,10] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 10)] p\n",
+       "    r:f32[128,10] = mul q j\n",
+       "  in (n, r) }"
+      ]
+     },
+     "execution_count": 66,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "jax.make_jaxpr(fn_with_grad)(act, act[:, :1])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "07782137-deb5-4a34-805c-209d68f86880",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1a709ffe-3bca-4a5b-b96c-50cebf8f4dd1",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 72,
+   "id": "084098be-2729-4080-b484-fc98fc2febd9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def fn2(x, y):\n",
+    "    return x * y\n",
+    "\n",
+    "\n",
+    "def fn2_with_grad(in_act, out_grad):\n",
+    "    out_act, fn_vjp = jax.vjp(fn2, in_act, in_act)\n",
+    "    return out_act, fn_vjp(out_grad)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 71,
+   "id": "c922d288-e938-4de0-a46a-edb48df6d3c0",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{ lambda ; a:f32[128,10] b:f32[128,10]. let\n",
+       "    c:f32[128,10] = mul a a\n",
+       "    d:f32[128,10] = mul a b\n",
+       "    e:f32[128,10] = mul b a\n",
+       "  in (c, e, d) }"
+      ]
+     },
+     "execution_count": 71,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "jax.make_jaxpr(fn2_with_grad)(act, act)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9c2f08c3-05b5-4661-81a4-1abfc1a4e625",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 79,
+   "id": "bd63a6ea-2ab0-4f0f-81cd-8ce4e25ef3b9",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "((128, 10), dtype('float32'))"
+      ]
+     },
+     "execution_count": 79,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "act.shape, act.dtype"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 75,
+   "id": "4e8e0182-2bf6-415b-b45e-fa12ac355db5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def fn3(x):\n",
+    "    return jax.grad(lambda x: jnp.mean(x))(x)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 81,
+   "id": "34dcb443-e0e6-4ded-9ce9-b4bb6660d553",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{ lambda ; a:f32[128,10]. let\n",
+       "    b:f32[] = reduce_sum[axes=(0, 1)] a\n",
+       "    _:f32[] = div b 1280.0\n",
+       "    c:f32[] = div 1.0 1280.0\n",
+       "    d:f32[128,10] = broadcast_in_dim[broadcast_dimensions=() shape=(128, 10)] c\n",
+       "  in (d,) }"
+      ]
+     },
+     "execution_count": 81,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "jax.make_jaxpr(fn3)(act)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c5228dc2-8058-4c9a-9089-3e44a9b1eeba",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
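
Note: the notebook builds the combined forward/backward graph with `jax.vjp` and inspects it with `jax.make_jaxpr`. Reading the second jaxpr above, the backward pass of logsumexp reduces to `exp(a - amax) * sum(cotangent / sumexp)`, i.e. `softmax(a)` times the summed cotangent, which is where the gradient scale mismatch being chased originates. The same probe as a self-contained script (plain JAX, all names local to this sketch):

    import jax
    import jax.numpy as jnp
    import numpy as np

    def log_softmax(a):
        amax = jax.lax.stop_gradient(jnp.max(a, axis=1, keepdims=True))
        return a - (jnp.log(jnp.sum(jnp.exp(a - amax), axis=1, keepdims=True)) + amax)

    act = np.random.randn(128, 10).astype(np.float32)

    # Forward + backward in one function, mirroring fn_with_grad above.
    def fwd_bwd(a, ct):
        out, vjp_fn = jax.vjp(log_softmax, a)
        return out, vjp_fn(ct)

    out, (grad_act,) = fwd_bwd(act, np.ones_like(act) / act.shape[0])
    print("grad mean/std:", float(grad_act.mean()), float(grad_act.std()))
    print(jax.make_jaxpr(fwd_bwd)(act, np.ones_like(act)))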