diff --git a/README.md b/README.md index 2775411..dff4424 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# JAX Scalify: end-to-end scaled Arithmetics +# JAX Scalify: end-to-end scaled arithmetic **JAX Scalify** is a thin library implementing numerically stable scaled arithmetics, allowing easy training and inference of deep neural networks in low precision (BF16, FP16, FP8). diff --git a/examples/autoscale-quickstart.ipynb b/examples/autoscale-quickstart.ipynb deleted file mode 100644 index f273d37..0000000 --- a/examples/autoscale-quickstart.ipynb +++ /dev/null @@ -1,269 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "7c85dead-5274-487c-91ff-7137fbaca393", - "metadata": {}, - "source": [ - "# JAX Scaled Arithmetics / AutoScale quickstart\n", - "\n", - "**JAX Scaled Arithmetics** is a thin library implementing numerically stable scaled arithmetics, allowing easy training and inference of\n", - "deep neural networks in low precision (BF16, FP16, FP8) with full scale propagation." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "30940677-4296-40fa-b418-351fcfb62098", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import jax\n", - "import jax_scalify as jsa" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e0e729aa-7a81-4001-8a34-9a00ec822948", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "f374e654-97e4-43ef-902a-a890d36a52b9", - "metadata": {}, - "outputs": [], - "source": [ - "# `scalify` interpreter is tracing the graph, adding scale propagation where necessary.\n", - "@jsa.scalify\n", - "def fn(a, b):\n", - " return a + b" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "8c59245d-27e5-41a7-bfef-f40849a7b550", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "INPUTS: [1. 2.] [3. 6.]\n", - "OUTPUT: [4. 8.] \n" - ] - } - ], - "source": [ - "# Let's start with standard JAX inputs\n", - "a = np.array([1, 2], np.float16)\n", - "b = np.array([3, 6], np.float16)\n", - "out = fn(a, b)\n", - "\n", - "print(\"INPUTS:\", a, b)\n", - "# No scaled arithmetics => \"normal\" JAX mode.\n", - "print(\"OUTPUT:\", out, type(out))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e60cf138-d92d-4ab9-89d4-bacc9e28c39f", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "e7efaa2e-00a1-40e8-9bbb-685edc975636", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "SCALED inputs: ScaledArray(data=array([1., 2.], dtype=float16), scale=1.0) ScaledArray(data=array([1.5, 3. ], dtype=float16), scale=2.0)\n", - "UNSCALED inputs: [1. 2.] [3. 
6.]\n" - ] - } - ], - "source": [ - "# Let's create input scaled arrays.\n", - "# NOTE: scale dtype does not have to match core data dtype.\n", - "sa = jsa.as_scaled_array(a, scale=np.float32(1))\n", - "sb = jsa.as_scaled_array(b, scale=np.float32(2))\n", - "\n", - "print(\"SCALED inputs:\", sa, sb)\n", - "# `as_scaled_array` does not change the value of tensor:\n", - "print(\"UNSCALED inputs:\", np.asarray(sa), np.asarray(sb))" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "1f457243-a0b8-4e4d-b45d-7444d0566b37", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "SCALED OUTPUT: ScaledArray(data=DeviceArray([2., 4.], dtype=float16), scale=DeviceArray(2., dtype=float32))\n", - "No scale rounding: ScaledArray(data=DeviceArray([1.789, 3.578], dtype=float16), scale=DeviceArray(2.236068, dtype=float32))\n" - ] - } - ], - "source": [ - "# Running `fn` on scaled arrays triggers `scalify` graph transformation\n", - "sout = fn(sa, sb)\n", - "# NOTE: by default, scale propagation is using power-of-2.\n", - "print(\"SCALED OUTPUT:\", sout)\n", - "\n", - "# To choose a different scale rounding:\n", - "with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.NONE):\n", - " print(\"No scale rounding:\", fn(sa, sb))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c2429c10-00d9-44f8-b0b6-a1fdf13ed264", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "307ee27d-6ed2-4ab6-a152-83947dbf47fd", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "RESCALED OUTPUT: ScaledArray(data=DeviceArray([0.5, 1. ], dtype=float16), scale=DeviceArray(8., dtype=float32))\n" - ] - }, - { - "data": { - "text/plain": [ - "functools.partial(, )" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# JAX Scaled Arithmetics offers basic dynamic rescaling methods. 
e.g.: max, l1, l2\n", - "sout_rescaled = jsa.ops.dynamic_rescale_max(sout)\n", - "print(\"RESCALED OUTPUT:\", sout_rescaled)\n", - "\n", - "# Equivalent methods are available to dynamically rescale gradients:\n", - "jsa.ops.dynamic_rescale_l1_grad" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "32930d15-d7ff-41d1-85be-eee958bb0741", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# NOTE: in normal JAX mode, these rescale operations are no-ops:\n", - "jsa.ops.dynamic_rescale_max(a) is a" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ea5942e7-0279-4dc4-a720-b8c7323ab6a1", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "9920f44a-26e2-4e20-89c3-4e2b2548239f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ScaledArray(data=DeviceArray([16., 20.], dtype=float32), scale=1.0)\n" - ] - } - ], - "source": [ - "import ml_dtypes\n", - "# Minimal FP8 simulated support is provided using jax.lax.reduce_precision and ml_dtypes.\n", - "# Similarly to `dynamic_rescale`, `cast_ml_dtype(_grad)` are available to cast in forward and backward passes\n", - "sc = jsa.as_scaled_array(np.array([17., 19.]), scale=np.float32(1))\n", - "\n", - "@jsa.scalify\n", - "def cast_fn(v):\n", - " return jsa.ops.cast_ml_dtype(v, ml_dtypes.float8_e4m3fn)\n", - "\n", - "sc_fp8 = cast_fn(sc)\n", - "print(sc_fp8)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1bd7c1d5-4ea2-4ded-a066-818d9146b8a6", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.10" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/scalify-quickstart.ipynb b/examples/scalify-quickstart.ipynb new file mode 100644 index 0000000..32781ba --- /dev/null +++ b/examples/scalify-quickstart.ipynb @@ -0,0 +1,672 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "7c85dead-5274-487c-91ff-7137fbaca393", + "metadata": {}, + "source": [ + "# JAX Scalify: Quickstart on end-to-end scaled arithmetic\n", + "\n", + "**JAX Scalify** is a library implementing general scaled arithmetic in JAX, allowing end-to-end scale propagation in computational graphs and easy training/inference of deep neural networks in low precision (mainly FP16 & FP8).\n", + "\n", + "JAX Scalify supports converting any computational graph into a scaled computational graph, i.e. a function with `ScaledArray` inputs/outputs.\n", + "\n", + "```python\n", + "@dataclass\n", + "class ScaledArray:\n", + " # Main `data` component, with \"low precision\"\n", + " data: Array\n", + " # Scale, usually scalar, represented in E8M0 or FP32.\n", + " scale: Array\n", + "```\n", + "It fully decouples scale propagation from model definition, allowing easy experimentation and debugging with low precision formats such as FP16 and FP8.\n",
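+ "\n", + "As an illustrative sketch of these semantics (hypothetical values, not taken from the library's documentation):\n", + "\n", + "```python\n", + "# data=[0.5, 1.0] with scale=2.0 represents the tensor [1.0, 2.0].\n", + "sa = jsa.ScaledArray(data=np.array([0.5, 1.0], np.float16), scale=np.float32(2))\n", + "np.asarray(sa) # ~ [1.0, 2.0]\n", + "```"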
+ ] + }, + { + "cell_type": "markdown", + "id": "39019611", + "metadata": {}, + "source": [ + "## Scaled array representation\n", + "\n", + "In Scalify, every tensor is a `ScaledArray`. This systematic approach simplifies the use of FP16 and FP8 in LLM training, decoupling the scale and numerical stability questions from the high-level model definition. \n", + "\n", + "Below, we present the basics of `ScaledArray` construction, and how it helps represent very large or very small tensors." + ] + }, + { + "cell_type": "code", + "execution_count": 226, + "id": "30940677-4296-40fa-b418-351fcfb62098", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import numpy.testing as npt\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jax_scalify as jsa" + ] + }, + { + "cell_type": "code", + "execution_count": 227, + "id": "e0e729aa-7a81-4001-8a34-9a00ec822948", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Normal `a`: [1. 2.]\n", + "Scaled `a`: ScaledArray(data=array([1., 2.], dtype=float16), scale=1.0) ~ [1. 2.]\n" + ] + } + ], + "source": [ + "# Let's start at the beginning: convert an array to a ScaledArray.\n", + "a = np.array([1, 2], np.float16)\n", + "# Analogue of `np.asarray`, which in addition takes the scale to use.\n", + "# NOTE: the scale dtype does not have to match the core data dtype. np.float32 is commonly used.\n", + "sa = jsa.as_scaled_array(a, scale=np.float32(1))\n", + "assert isinstance(sa, jsa.ScaledArray)\n", + "\n", + "# `a` and `sa` represent the same formal tensor.\n", + "print(\"Normal `a`:\", a)\n", + "print(\"Scaled `a`:\", sa, \" ~ \", np.asarray(sa))" + ] + }, + { + "cell_type": "code", + "execution_count": 228, + "id": "5f624725", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Normal `a`: [1. 2.]\n", + "Scaled `a`: ScaledArray(data=array([2., 4.], dtype=float16), scale=0.5) ~ [1. 2.]\n" + ] + } + ], + "source": [ + "# Scalify preserves the semantics of arrays and computational graphs.\n", + "# Passing a different scale does not change the \"value\" of a represented tensor.\n", + "sa = jsa.as_scaled_array(a, scale=np.float32(0.5))\n", + "# `a` and `sa` still represent the same formal tensor.\n", + "print(\"Normal `a`:\", a)\n", + "print(\"Scaled `a`:\", sa, \" ~ \", np.asarray(sa))" + ] + }, + { + "cell_type": "code", + "execution_count": 229, + "id": "c49c5c55", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "<< Scaled Arrays with large values >>\n", + "Normal `a` FP32: [131072. 262144.]\n", + "Normal `a` FP16: [inf inf]\n", + "Scaled `a` FP16: ScaledArray(data=array([1., 2.], dtype=float16), scale=131072.0) ~ [131072. 262144.]\n", + "\n", + "<< Scaled Arrays with small values >>\n", + "Normal `a` FP32: [2.9802322e-08 5.9604645e-08]\n", + "Normal `a` FP16: [0.e+00 6.e-08]\n", + "Scaled `a` FP16: ScaledArray(data=array([0.001953, 0.003906], dtype=float16), scale=1.5258789e-05) ~ [2.9802322e-08 5.9604645e-08]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_392367/4076835521.py:5: RuntimeWarning: overflow encountered in cast\n", + " a_fp16 = a.astype(np.float16)\n" + ] + } + ], + "source": [ + "# Why use scaled arrays? 
=> to represent very \"small\" or \"large\" tensors.\n", + "# Large FP32 tensor.\n", + "a = np.array([2, 4], np.float32) * 256**2\n", + "# Overflowing to Inf in FP16\n", + "a_fp16 = a.astype(np.float16)\n", + "# \"Properly\" represented with a large scale.\n", + "sa_fp16 = jsa.as_scaled_array(a, scale=np.float32(256**2 * 2)).astype(np.float16)\n", + "\n", + "print(\"<< Scaled Arrays with large values >>\")\n", + "print(\"Normal `a` FP32:\", a)\n", + "print(\"Normal `a` FP16:\", a_fp16)\n", + "# FP16 scaled representation does not overflow. \n", + "print(\"Scaled `a` FP16:\", sa_fp16, \" ~ \", np.asarray(sa_fp16, dtype=np.float32))\n", + "\n", + "a = np.array([2, 4], np.float32) * (256*32)**-2\n", + "a_fp16 = a.astype(np.float16)\n", + "sa_fp16 = jsa.as_scaled_array(a, scale=np.float32(256**-2)).astype(np.float16)\n", + "\n", + "print(\"\\n<< Scaled Arrays with small values >>\")\n", + "print(\"Normal `a` FP32:\", a)\n", + "# Underflow + loss of precision in the subnormal range.\n", + "print(\"Normal `a` FP16:\", a_fp16)\n", + "# FP16 scaled representation does not underflow.\n", + "# NOTE: the scale factor does not need to be \"perfect\" to keep an accurate representation.\n", + "print(\"Scaled `a` FP16:\", sa_fp16, \" ~ \", np.asarray(sa_fp16, dtype=np.float32))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a018d505", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "e91afff9", + "metadata": {}, + "source": [ + "### Scaled array and FP8 formats\n", + "\n", + "How does it work with FP8? The same way `:)`\n", + "The latest generation of GPUs supports two FP8 formats defined by the OCP FP8 specification (https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1):\n", + "* `float8_e4m3fn`: 4 exponent bits, 3 mantissa bits, no infinities;\n", + "* `float8_e5m2`: 5 exponent bits, 2 mantissa bits, with infinities.\n", + "\n", + "JAX and `ml_dtypes` additionally expose `fnuz` variants (e.g. `float8_e5m2fnuz`, inspected below), which give up infinities and negative zero.\n", + "\n", + "**Note:** there is still ongoing IEEE standardization work on FP8 formats (see https://github.com/P3109/Public/blob/main/Shared%20Reports/P3109%20WG%20Interim%20Report.pdf).
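\n\nAs a rough sketch (an assumption for illustration, not the library's actual code), the FP8 pseudo-cast used later in this notebook can be emulated on FP32/FP16 storage with `jax.lax.reduce_precision`:\n\n```python\nimport jax\nimport ml_dtypes\n\ndef pseudo_cast_fp8(v, dtype=ml_dtypes.float8_e4m3fn):\n    # Round v to the target format's exponent/mantissa budget,\n    # while keeping the original storage dtype.\n    fi = ml_dtypes.finfo(dtype)\n    return jax.lax.reduce_precision(v, exponent_bits=fi.nexp, mantissa_bits=fi.nmant)\n```\n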
" + ] + }, + { + "cell_type": "code", + "execution_count": 230, + "id": "aa737550", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "FP8-E4M3: Machine parameters for float8_e4m3fn\n", + "---------------------------------------------------------------\n", + "precision = 1 resolution = 1.00e-01\n", + "machep = -3 eps = 1.25e-01\n", + "negep = -4 epsneg = 6.25e-02\n", + "minexp = -6 tiny = 1.56e-02\n", + "maxexp = 9 max = 4.48e+02\n", + "nexp = 4 min = -max\n", + "smallest_normal = 1.56e-02 smallest_subnormal = 1.95e-03\n", + "---------------------------------------------------------------\n", + "\n", + "FP8-E5M2: Machine parameters for float8_e5m2fnuz\n", + "---------------------------------------------------------------\n", + "precision = 1 resolution = 1.00e-01\n", + "machep = -2 eps = 2.50e-01\n", + "negep = -3 epsneg = 1.25e-01\n", + "minexp = -15 tiny = 3.05e-05\n", + "maxexp = 16 max = 5.73e+04\n", + "nexp = 5 min = -max\n", + "smallest_normal = 3.05e-05 smallest_subnormal = 7.63e-06\n", + "---------------------------------------------------------------\n", + "\n" + ] + } + ], + "source": [ + "# FP8 formats properties\n", + "print(\"FP8-E4M3:\", jnp.finfo(jnp.float8_e4m3fn))\n", + "print(\"FP8-E5M2:\", jnp.finfo(jnp.float8_e5m2fnuz))" + ] + }, + { + "cell_type": "code", + "execution_count": 231, + "id": "70e85309", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Normal `a` FP32: [400. 448. 512.]\n", + "Normal `a` FP8-E4M3: [384 448 nan]\n", + "Scaled `a` FP8-E4M3: ScaledArray(data=Array([3, 3.5, 4], dtype=float8_e4m3fn), scale=128.0) ~ [384. 448. 512.]\n" + ] + } + ], + "source": [ + "\n", + "a = jnp.array([400, 448, 512], np.float32)\n", + "# Overflowing to NaN as no Inf available.\n", + "a_fp8_e4m3 = a.astype(jnp.float8_e4m3fn)\n", + "# Scaled representation, without overflowing.\n", + "as_fp8_e4m3 = jsa.as_scaled_array(a, scale=np.float32(128)).astype(jnp.float8_e4m3fn)\n", + "\n", + "print(\"Normal `a` FP32:\", a)\n", + "# NOTE: the loss of precision due to 3-bit mantissa.\n", + "print(\"Normal `a` FP8-E4M3:\", a_fp8_e4m3)\n", + "print(\"Scaled `a` FP8-E4M3:\", as_fp8_e4m3, \" ~ \", np.asarray(as_fp8_e4m3, dtype=np.float32))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ab192562", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 252, + "id": "8b93d946", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scaled input in FP32: ScaledArray(data=array([8.5, 9.5]), scale=2.0)\n", + "Pseudo-cast to ML dtypes: ScaledArray(data=Array([ 8., 10.], dtype=float32), scale=2.0)\n" + ] + } + ], + "source": [ + "import ml_dtypes\n", + "# Minimal FP8 simulated support is provided using jax.lax.reduce_precision and ml_dtypes.\n", + "# Similarly to `dynamic_rescale`, `cast_ml_dtype(_grad)` are available to cast in forward and backward passes\n", + "sc = jsa.as_scaled_array(np.array([17., 19.]), scale=np.float32(2))\n", + "\n", + "@jsa.scalify\n", + "def cast_fn(v):\n", + " return jsa.ops.cast_ml_dtype(v, ml_dtypes.float8_e4m3fn)\n", + "\n", + "sc_fp8 = cast_fn(sc)\n", + "print(\"Scaled input in FP32:\", sc)\n", + "# NOTE: still using FP32 (or FP16) as underlying storage.\n", + "print(\"Pseudo-cast to ML dtypes:\", sc_fp8)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ed59571", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + 
"cell_type": "markdown", + "id": "8442121f", + "metadata": {}, + "source": [ + "## `scalify` transform: end-to-end scale propagation\n", + "\n", + "The `scalify` transform performs end-to-end scale propagation, with application of \"unit-scaling\" type rules. `scalify` for now only supports a subset of [LAX operators](../docs/operators.md), and will raise an error if unsupported operations are used in the computational graph." + ] + }, + { + "cell_type": "code", + "execution_count": 232, + "id": "f374e654-97e4-43ef-902a-a890d36a52b9", + "metadata": {}, + "outputs": [], + "source": [ + "# `scalify` transform is tracing the graph, adding scale propagation where necessary.\n", + "@jsa.scalify\n", + "def fn(a, b):\n", + " return a + b" + ] + }, + { + "cell_type": "code", + "execution_count": 233, + "id": "8c59245d-27e5-41a7-bfef-f40849a7b550", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INPUTS: [1. 2.] [3. 6.]\n", + "OUTPUT: [4. 8.] float16 \n" + ] + } + ], + "source": [ + "# Let's start with standard JAX inputs\n", + "a = np.array([1, 2], np.float16)\n", + "b = np.array([3, 6], np.float16)\n", + "# The function `fn` is unchanged with unscaled inputs. \n", + "out = fn(a, b)\n", + "\n", + "print(\"INPUTS:\", a, b)\n", + "# \"Unscaled\" inputs => \"normal\" JAX mode with unscaled outputs\n", + "print(\"OUTPUT:\", out, out.dtype, type(out))" + ] + }, + { + "cell_type": "code", + "execution_count": 234, + "id": "e7efaa2e-00a1-40e8-9bbb-685edc975636", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scaled inputs:\n", + "\tScaledArray(data=array([0.5, 1. ], dtype=float16), scale=2.0)\n", + "\tScaledArray(data=array([0.75, 1.5 ], dtype=float16), scale=4.0)\n", + "Equivalent input arrays: [1. 2.] [3. 6.]\n" + ] + } + ], + "source": [ + "# Let's create input scaled arrays.\n", + "sa = jsa.as_scaled_array(a, scale=np.float32(2))\n", + "sb = jsa.as_scaled_array(b, scale=np.float32(4))\n", + "\n", + "print(f\"Scaled inputs:\\n\\t{sa}\\n\\t{sb}\")\n", + "# `as_scaled_array` does not change the semantic: same underlying tensor represented.\n", + "print(\"Equivalent input arrays:\", np.asarray(sa), np.asarray(sb))" + ] + }, + { + "cell_type": "code", + "execution_count": 235, + "id": "1f457243-a0b8-4e4d-b45d-7444d0566b37", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scaled output: ScaledArray(data=Array([1., 2.], dtype=float16), scale=Array(4., dtype=float32))\n", + "Equivalent unscaled output: [4. 
8.]\n", + "\n", + "Scaled output without scale rounding: ScaledArray(data=Array([0.8945, 1.789 ], dtype=float16), scale=Array(4.472136, dtype=float32))\n" + ] + } + ], + "source": [ + "# Running `fn` on scaled arrays triggers the `scalify` graph transformation & scale propagation\n", + "sout = fn(sa, sb)\n", + "# NOTE: by default, scale propagation uses power-of-2 scales.\n", + "assert isinstance(sout, jsa.ScaledArray)\n", + "print(\"Scaled output:\", sout)\n", + "print(\"Equivalent unscaled output:\", np.asarray(sout))\n", + "\n", + "# To choose a different scale rounding:\n", + "with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.NONE):\n", + " print(\"\\nScaled output without scale rounding:\", fn(sa, sb))" + ] + }, + { + "cell_type": "markdown", + "id": "02a53217", + "metadata": {}, + "source": [ + "### Why use unit-scaling rules in `scalify`?\n", + "\n", + "In this section, we show how the unit-scaling rules implemented in `scalify` propagate scales near-optimally through operations. We use a simple FP16 matrix multiplication as an example, where `scalify` avoids the overflow occurring in the unscaled FP16 model.\n",
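+ "\n", + "Concretely, for a matmul with i.i.d. Gaussian inputs and reduction dimension K, a unit-scaling style rule predicts an output scale of roughly `scale_a * scale_b * sqrt(K)`, rounded to a power of two. This rule statement is our gloss, inferred from the run below rather than quoted from the library source:\n", + "\n", + "```python\n", + "# With ascale=128, bscale=64 and K=1024, as in the cells below:\n", + "out_scale = 128 * 64 * 1024 ** 0.5 # = 262144.0, the output scale printed below\n", + "```"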
+ ] + }, + { + "cell_type": "code", + "execution_count": 236, + "id": "c2429c10-00d9-44f8-b0b6-a1fdf13ed264", + "metadata": {}, + "outputs": [], + "source": [ + "# `scalify` scale propagation uses `unit-scaling` static scale propagation rules.\n", + "@jsa.scalify\n", + "def matmul_fn(a, b):\n", + " return a @ b" + ] + }, + { + "cell_type": "code", + "execution_count": 237, + "id": "384be44f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INPUTS std: 128.61899 62.915386\n", + "OUTPUT std: 251969.02\n" + ] + } + ], + "source": [ + "# Large reduction axis K.\n", + "M, N, K = 16, 8, 1024\n", + "ascale = 128\n", + "bscale = 64\n", + "# IID Gaussian inputs.\n", + "a = np.random.randn(M, K).astype(np.float32) * ascale\n", + "b = np.random.randn(K, N).astype(np.float32) * bscale\n", + "\n", + "# The function `matmul_fn` is unchanged with unscaled inputs. \n", + "out = matmul_fn(a, b)\n", + "\n", + "print(\"INPUTS std:\", np.std(a), np.std(b))\n", + "# Large matmul output standard deviation.\n", + "print(\"OUTPUT std:\", np.std(out))" + ] + }, + { + "cell_type": "code", + "execution_count": 238, + "id": "4b2759ba", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INPUTS data std: 1.0048358 0.9830529\n", + "INPUTS scales: 128.0 64.0\n", + "OUTPUT data std and scale: 0.9611855 262144.0\n" + ] + } + ], + "source": [ + "# Let's create equivalent input scaled arrays.\n", + "sa = jsa.as_scaled_array(a, scale=np.float32(ascale))\n", + "sb = jsa.as_scaled_array(b, scale=np.float32(bscale))\n", + "\n", + "# Scale propagation in `matmul`\n", + "sout = matmul_fn(sa, sb)\n", + "\n", + "print(\"INPUTS data std:\", np.std(sa.data), np.std(sb.data))\n", + "print(\"INPUTS scales:\", sa.scale, sb.scale)\n", + "# The large scale is incorporated in `scale`, keeping the main `data` std ~ 1.\n", + "print(\"OUTPUT data std and scale:\", np.std(sout.data), sout.scale)" + ] + }, + { + "cell_type": "code", + "execution_count": 239, + "id": "1d4df895", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Are INPUTS finite? True True\n", + "How many OUTPUT values finite? (vs nb entries) 28 128\n" + ] + } + ], + "source": [ + "# How about the same matmul in FP16\n", + "a_fp16 = a.astype(np.float16)\n", + "b_fp16 = b.astype(np.float16)\n", + "out_fp16 = matmul_fn(a_fp16, b_fp16)\n", + "\n", + "# Finite inputs, but overflowing output.\n", + "print(\"Are INPUTS finite?\", np.all(np.isfinite(a_fp16)), np.all(np.isfinite(b_fp16)))\n", + "print(\"How many OUTPUT values finite? (vs nb entries)\", np.sum(np.isfinite(out_fp16)), out_fp16.size)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 240, + "id": "ab671a43", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INPUTS data std: 1.005 0.983\n", + "INPUTS scales: 128.0 64.0\n", + "OUTPUT data std and scale: 0.961 262144.0\n", + "Scalify FP16 matmul relative error (mean/max) 0.00057976914 0.057415348\n" + ] + } + ], + "source": [ + "# Let's create equivalent input scaled arrays.\n", + "sa_fp16 = sa.astype(np.float16)\n", + "sb_fp16 = sb.astype(np.float16)\n", + "\n", + "# Scale propagation in `matmul` FP16 => not overflowing.\n", + "sout_fp16 = matmul_fn(sa_fp16, sb_fp16)\n", + "\n", + "print(\"INPUTS data std:\", np.std(sa_fp16.data), np.std(sb_fp16.data))\n", + "print(\"INPUTS scales:\", sa_fp16.scale, sb_fp16.scale)\n", + "# The large scale is incorporated in `scale`, keeping the main `data` std ~ 1.\n", + "print(\"OUTPUT data std and scale:\", np.std(sout_fp16.data), sout_fp16.scale)\n", + "\n", + "# Relative error vs FP32 matmul\n", + "rel_error = np.abs(np.asarray(sout_fp16, dtype=np.float32) - out) / out\n", + "print(\"Scalify FP16 matmul relative error (mean/max)\", np.mean(rel_error), np.max(rel_error))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ebcc3c8e", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "d7d8f873", + "metadata": {}, + "source": [ + "### `scalify` dynamic rescaling\n", + "\n", + "As is well known, neural-network activations, weights and gradients do not perfectly follow a Gaussian distribution. As a consequence, `scalify` provides a way to dynamically rescale tensors at any point in the computational graph.
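\n\nA rough sketch of what an L2-based dynamic rescale could compute (an illustrative assumption, not the library's implementation):\n\n```python\ndef dynamic_rescale_l2_sketch(sa):\n    # Estimate the RMS of the data and round it to a power of two.\n    rms = np.sqrt(np.mean(np.float32(sa.data) ** 2))\n    rho = np.float32(2.0 ** np.round(np.log2(rms)))\n    # Fold rho into the scale: the represented tensor is unchanged.\n    return jsa.ScaledArray(data=sa.data / rho, scale=sa.scale * rho)\n```\n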
" + ] + }, + { + "cell_type": "code", + "execution_count": 241, + "id": "b45af678", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INPUT a: 65.71072\n", + "Static scaled INPUT a: 16.42768 4.0\n" + ] + } + ], + "source": [ + "a = np.random.randn(1024).astype(np.float32) * 64\n", + "sa_in = jsa.as_scaled_array(a, scale=np.float32(4))\n", + "\n", + "print(\"INPUT a:\", np.std(a))\n", + "print(\"Static scaled INPUT a:\", np.std(sa_in.data), sa_in.scale)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 242, + "id": "307ee27d-6ed2-4ab6-a152-83947dbf47fd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dynamic (re)scaled INPUT a: 1.02673 64.0\n" + ] + } + ], + "source": [ + "# Dynamic rescaling of scaled array, using L2 norm (rounded to power-of-two).\n", + "sa_out = jsa.ops.dynamic_rescale_l2(sa_in)\n", + "print(\"Dynamic (re)scaled INPUT a:\", np.std(sa_out.data), sa_out.scale)\n", + "\n", + "# `dynamic_rescale` operations do not change the semantic of the tensor.\n", + "npt.assert_array_almost_equal(np.asarray(sa_out), a)" + ] + }, + { + "cell_type": "code", + "execution_count": 245, + "id": "32930d15-d7ff-41d1-85be-eee958bb0741", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 245, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# NOTE: in normal JAX mode, these rescale operations are no-ops:\n", + "jsa.ops.dynamic_rescale_max(a) is a" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1bd7c1d5-4ea2-4ded-a066-818d9146b8a6", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}