diff --git a/.gitignore b/.gitignore index f53c4ff..ab6f551 100644 --- a/.gitignore +++ b/.gitignore @@ -80,6 +80,7 @@ target/ # Jupyter Notebook .ipynb_checkpoints +.jupyter_ystore.db # IPython profile_default/ diff --git a/docs/JAX FP8 matmul tutorial.ipynb b/docs/JAX FP8 matmul tutorial.ipynb new file mode 100644 index 0000000..e631843 --- /dev/null +++ b/docs/JAX FP8 matmul tutorial.ipynb @@ -0,0 +1,751 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c29689be-e7f4-40fb-9942-8f8944364239", + "metadata": {}, + "source": [ + "# JAX FP8 (fused) matmul tutorial" + ] + }, + { + "cell_type": "markdown", + "id": "f878aaba-ce22-42d3-89e3-ad7f22b6f75c", + "metadata": {}, + "source": [ + "## Quickstart: FP8 in deep learning\n", + "\n", + "The latest generation of machine learning hardware ([Nvidia H100](https://www.nvidia.com/en-gb/data-center/h100/), [AMD MI300X](https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html), [Graphcore C600](https://www.graphcore.ai/products/c600), ...) have integrated direct **FP8** support in the hardware, improving energy efficiency and throughput.\n", + "\n", + "As shown the low precision ML literature, two distinct formats are necessary to support to achieve similar accuracy to `bfloat16` (or `float16`) training: **`E4M3`** and **`E5M2`** `float8` formats. As presented below, the two formats differ in the trade-off between precision (i.e. mantissa bits) and dynamic range (i.e. exponent bits). In short, `E4M3` is used for storing **weights** and **activations** whereas `E5M2` for representing backward **gradients** (which require a higher dynamic range).\n", + "\n", + "![image](img/fp-formats.webp)\n", + "\n", + "Note that **different variations** of `E4M3` and `E5M2` exist in the literature, depending on whether infinities, NaN or negative zero have special encodings reserved (see below in the references). The Python library [`ml_dtypes`](https://github.com/jax-ml/ml_dtypes) implements these different 8-bit floating point representations as NumPy extensions.\n", + "\n", + "These two new FP8 formats introduced a major hardware difference compared to FP16 and BF16 support: FP8 hardware needs to support mixed input matrix multiplication (i.e. `E4M3 @ E5M2`) for the model training backward pass. \n", + "\n", + "In this tutorial notebook, we investigate how the ML stack JAX + XLA handles the specificities of **FP8 matmuls**, while still generating an optimal fused kernel call including:\n", + "* FP8 inputs scaling;\n", + "* FP8 output scaling & clamping;\n", + "* Non-linearity & bias fusing;\n", + "* Abs-max output capture;" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51775bad-18ad-49b7-9371-930b3704a294", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "361686c2-535b-461b-abda-e88dc943de98", + "metadata": {}, + "source": [ + "## FP8 E4M3 and E5M2 format datatypes\n", + "\n", + "`E4M3` and `E5M2` datatype formats have been integrated in major ML frameworks (e.g. PyTorch and JAX), and can be used as any other classic NumPy dtype. [`ml_dtypes`](https://github.com/jax-ml/ml_dtypes) provides floating point information for these FP8 formats, showing in particular the small dynamic range of `E4M3` datatype (i.e. ±448) compared to `E5M2` (i.e. ±57344)." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "fb62c752-f7ba-4714-8605-88e2afcff88f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Machine parameters for float8_e4m3fn\n", + "---------------------------------------------------------------\n", + "precision = 1 resolution = 1.00e-01\n", + "machep = -3 eps = 1.25e-01\n", + "negep = -4 epsneg = 6.25e-02\n", + "minexp = -6 tiny = 1.56e-02\n", + "maxexp = 9 max = 4.48e+02\n", + "nexp = 4 min = -max\n", + "smallest_normal = 1.56e-02 smallest_subnormal = 1.95e-03\n", + "---------------------------------------------------------------\n", + "\n", + "Machine parameters for float8_e5m2\n", + "---------------------------------------------------------------\n", + "precision = 1 resolution = 1.00e-01\n", + "machep = -2 eps = 2.50e-01\n", + "negep = -3 epsneg = 1.25e-01\n", + "minexp = -14 tiny = 6.10e-05\n", + "maxexp = 16 max = 5.73e+04\n", + "nexp = 5 min = -max\n", + "smallest_normal = 6.10e-05 smallest_subnormal = 1.53e-05\n", + "---------------------------------------------------------------\n", + "\n" + ] + } + ], + "source": [ + "import ml_dtypes\n", + "import numpy as np\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "# Note: using the E4M3 format without +-infinity encodings.\n", + "print(ml_dtypes.finfo(jnp.float8_e4m3fn))\n", + "# Note: E5M3 format with infinities and NaNs encodings, in line with FP16 IEEE standard.\n", + "print(ml_dtypes.finfo(jnp.float8_e5m2))" + ] + }, + { + "cell_type": "markdown", + "id": "c53b08a5-c67d-40c2-8ffa-61bf3a985733", + "metadata": {}, + "source": [ + "## FP8 matmul in JAX: the simple case\n", + "\n", + "With FP8 datatypes added in JAX, basic FP8 matrix multiplication is supported out-of-the-box. As highlighted above, it also means support for **mixed** `E4M3 @ E5M2` FP8 matmuls." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "9be90f27-5520-45f6-a42d-b309572e6e91", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-09-25 14:11:11.258006: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "E4M3 @ E4M3 FP8 matmul output: ShapedArray(float8_e4m3fn[32,128])\n", + "E4M3 @ E5M2 FP8 matmul output: ShapedArray(float8_e5m2[32,128])\n" + ] + } + ], + "source": [ + "key = jax.random.PRNGKey(4352)\n", + "# Random FP8 inputs.\n", + "a = jax.random.normal(key, (32, 64), jnp.float8_e4m3fn)\n", + "b = jax.random.normal(key, (128, 64), jnp.float8_e4m3fn)\n", + "\n", + "# E4M3 matrix multiplication (NOTE: transpose to reduce on last axis on both inputs).\n", + "c = jax.lax.dot(a, b.T)\n", + "print(\"E4M3 @ E4M3 FP8 matmul output:\", c.aval)\n", + "\n", + "# E4M3/E5M2 mixed matrix multiplication (NOTE: transpose to reduce on last axis on both inputs).\n", + "c = jax.random.normal(key, (128, 64), jnp.float8_e5m2)\n", + "d = jax.lax.dot(a, c.T)\n", + "# Note: default output dtype is E5M2.\n", + "print(\"E4M3 @ E5M2 FP8 matmul output:\", d.aval)" + ] + }, + { + "cell_type": "markdown", + "id": "d08b0c24-1ac5-458b-a9ae-e269ec34862e", + "metadata": {}, + "source": [ + "### FP8 matmul compiled HLO\n", + "\n", + "Let's have a look at the compiled HLO module generated by JAX + XLA on latest generation GPUs: the XLA compiler recognizes an FP8 matrix multiplication and generates (on GPUs) a `custom_call` to the target **`__cublas$lt$matmul$f8`**, mapping to the FP8 [**`cublasLtMatmul`**](https://docs.nvidia.com/cuda/cublas/#cublasltmatmul) API (note: it will work similarly on other hardware platforms)." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "7edfa758-bf4e-49fa-8c5d-5dc9c0c2c346", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "HloModule jit_matmul_fn, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0})->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs=\"f27e70a56b27e0bfb1ec7095f85081ca\"}\n", + "\n", + "ENTRY %main.5 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64]) -> f8e4m3fn[32,128] {\n", + " %constant_1 = f32[] constant(1)\n", + " %Arg_1.2.0 = f8e4m3fn[128,64]{1,0} parameter(1)\n", + " %Arg_0.1.0 = f8e4m3fn[32,64]{1,0} parameter(0)\n", + " %cublas-gemm.1.0 = (f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f8e4m3fn[128,64]{1,0} %Arg_1.2.0, f32[] %constant_1, f32[] %constant_1, f32[] %constant_1, /*index=5*/f32[] %constant_1), custom_call_target=\"__cublas$lt$matmul$f8\"\n", + " ROOT %get-tuple-element.1 = f8e4m3fn[32,128]{1,0} get-tuple-element((f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) %cublas-gemm.1.0), index=0\n", + "}\n", + "\n", + "\n" + ] + } + ], + "source": [ + "from jax_scalify.utils import print_hlo_module, parse_hlo_module\n", + "\n", + "def matmul_fn(a_fp8, b_fp8):\n", + " # FP8 x FP8 -> FP8 matmul\n", + " return jax.lax.dot(a_fp8, b_fp8.T)\n", + "\n", + "# AOT compilation with JAX, inspecting the (final) HLO module generated.\n", + "fn_compiled = jax.jit(matmul_fn).lower(a, b).compile()\n", + "# (Human readable) optimized Hlo module generated by XLA (ignoring GEMM backend config).\n", + "print_hlo_module(fn_compiled, backend_cfg=False)" + ] + }, + { + "cell_type": "markdown", + "id": "34929c8d-939d-4949-8f3f-ae3fd1c1c522", + "metadata": {}, + "source": [ + "One first interesting aspect of the custom call **`__cublas$lt$matmul$f8`** is that it takes **6 input arguments**: the first two are the classic matmul inputs, and the other four are FP32 scalars (set to a constant `%constant_1 = f32[] constant(1)` in this case).\n", + "\n", + "The field `backend_config` in **`__cublas$lt$matmul$f8`** provides additional metadata passed to the GEMM API." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "72d805ea-89b6-457d-9558-ff31fdd23d35", + "metadata": {}, + "outputs": [ + { + "data": { + "application/json": { + "force_earliest_schedule": false, + "gemm_backend_config": { + "alpha_imag": 0, + "alpha_real": 1, + "beta": 0, + "damax_output": false, + "dot_dimension_numbers": { + "lhs_batch_dimensions": [], + "lhs_contracting_dimensions": [ + "1" + ], + "rhs_batch_dimensions": [], + "rhs_contracting_dimensions": [ + "1" + ] + }, + "epilogue": "DEFAULT", + "grad_x": false, + "grad_y": false, + "lhs_stride": "2048", + "precision_config": { + "algorithm": "ALG_UNSET", + "operand_precision": [ + "DEFAULT", + "DEFAULT" + ] + }, + "rhs_stride": "8192", + "selected_algorithm": "5" + }, + "operation_queue_id": "0", + "wait_on_operation_queues": [] + }, + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": { + "application/json": { + "expanded": true, + "root": "root" + } + }, + "output_type": "execute_result" + } + ], + "source": [ + "from IPython.display import JSON\n", + "\n", + "hlo_module = parse_hlo_module(fn_compiled)\n", + "# Let's extract the `backend_config` dict from the FP8 matmul call.\n", + "backend_config = next((m.backend_config for m in hlo_module if \"__cublas$lt$matmul$f8\" in m.cmd))\n", + "JSON(backend_config, expanded=True)" + ] + }, + { + "cell_type": "markdown", + "id": "cc872a8d-1365-41cc-ab48-bb064a83c519", + "metadata": {}, + "source": [ + "A couple of fields are of interest to us for FP8 matmuls:\n", + "\n", + "* **`alpha_real`**, **`alpha_imag`** and **`beta`**: constant scaling factors which can be integrated into the matrix multiplication:\n", + "$$\n", + "D = \\alpha \\cdot (A @ B) + \\beta \\cdot C\n", + "$$\n", + "**Note:** these are different from the scalar FP32 tensors presented above! \n", + "* **`epilogue`**: enum field describing fusing of post-matmul operation such as adding bias or non-linearity (see below).\n", + "* **`damax_output`**: a new FP8 matmul feature: computation of the absolute reduce-max of the output (useful for output re-scaling)." + ] + }, + { + "cell_type": "markdown", + "id": "b8234ebd-6f63-4f71-bb6c-b0d9b7cd2ddc", + "metadata": {}, + "source": [ + "## Fused FP8 matmul in JAX: from simple to complicated!\n", + "\n", + "As presented above, the FP8 XLA custom target **`__cublas$lt$matmul$f8`** has an extended API & config allowing **fusing** multiple operations in the GEMM kernel. More specifically:\n", + "* Scaling of input & output tensors;\n", + "* Capturing absolute-maximum of the output (usually called `damax`);\n", + "* Post-matmul bias or/and non-linearity;\n", + "\n", + "We present below how to generate the proper fused matmul call directly from JAX (and checking the result in the compiled HLO!). Starting with inputs & outputs scaling, following the interface of **`__cublas$lt$matmul$f8`**. \n", + "\n", + "Let's first try with a naive implementation:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "1ed9d08e-b18a-4fe7-bcba-72b95ddf6e68", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "<<< JAX compilation error >>>\n", + "Input dtypes ('float8_e4m3fn', 'float32') have no available implicit dtype promotion path. To avoid unintended promotion, 8-bit floats do not support implicit promotion. If you'd like your inputs to be promoted to another type, you can do so explicitly using e.g. x.astype('float32')\n" + ] + } + ], + "source": [ + "def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, c_scale):\n", + " # First try: can we just scale the input with an FP32 scalar?\n", + " a_fp8 = a_fp8 * a_scale\n", + " out = jax.lax.dot(a_fp8, b_fp8.T)\n", + " return out\n", + "\n", + "# `__cublas$lt$matmul$f8` expecting FP32 scales.\n", + "scale_aval = jax.core.ShapedArray((), jnp.float32)\n", + "try:\n", + " fn_compiled = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval).compile()\n", + "except Exception as e:\n", + " # Issue: JAX does not support implicit mixed-multiplication FP8 x FP32\n", + " print(f\"<<< JAX compilation error >>>\\n{e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "57c2a7de-2e89-40b2-b38c-9b8876472d8a", + "metadata": {}, + "source": [ + "### FP8 matmul with scaled inputs & outputs\n", + "\n", + "JAX and XLA do not allow implicit conversion between FP8 and FP32, meaning that we need to write something more explicit for the XLA compiler to pattern match and generate the fused call. More specifically, as presented in [XLA FP8 RFC](https://github.com/openxla/xla/discussions/22), one needs to adopt a dequantization/quantization type of semantics:\n", + "* Upcast inputs to `float32` and then scale;\n", + "* Scale output, clamp to `float8` range (not optional!) and then downcast to `float8`;\n", + "\n", + "As presented below, when using this pattern, the XLA compiler is able to fuse all the operations into a single call of `__cublas$lt$matmul$f8`." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b9a608d7-6cf8-457b-8275-bdcacc9b06fe", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs=\"230c40ffa1e1e3ba7f06e4a65ac9e2bd\"}\n", + "\n", + "ENTRY %main.22 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> f8e4m3fn[32,128] {\n", + " %constant_1 = f32[] constant(1)\n", + " %Arg_4.5.0 = f32[] parameter(4)\n", + " %Arg_3.4.0 = f32[] parameter(3)\n", + " %Arg_2.3.0 = f32[] parameter(2)\n", + " %Arg_1.2.0 = f8e4m3fn[128,64]{1,0} parameter(1)\n", + " %Arg_0.1.0 = f8e4m3fn[32,64]{1,0} parameter(0)\n", + " %cublas-gemm.clone.1.0 = (f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f8e4m3fn[128,64]{1,0} %Arg_1.2.0, f32[] %Arg_2.3.0, f32[] %Arg_3.4.0, f32[] %constant_1, /*index=5*/f32[] %Arg_4.5.0), custom_call_target=\"__cublas$lt$matmul$f8\"\n", + " ROOT %get-tuple-element.1 = f8e4m3fn[32,128]{1,0} get-tuple-element((f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) %cublas-gemm.clone.1.0), index=0\n", + "}\n", + "\n", + "\n" + ] + } + ], + "source": [ + "e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max\n", + "\n", + "# XLA requires a \"dequantize/quantize\" pattern to properly support scaled FP8 inputs/outputs. \n", + "def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):\n", + " # Dequantize x and y\n", + " a_fp32 = a_fp8.astype(jnp.float32) * a_scale\n", + " b_fp32 = b_fp8.astype(jnp.float32) * b_scale\n", + " \n", + " # Do the matmul (NOTE: adding transpose to reduce on last axis).\n", + " d_fp32 = jax.lax.dot(a_fp32, b_fp32.T)\n", + " \n", + " # Rescale & clamp to -max/+max FP8 E4M3 values.\n", + " d_fp32 = d_fp32 * d_scale\n", + " # NOTE: clamping is NOT optional for proper pattern matching!\n", + " d_fp32 = jax.lax.clamp(jnp.float32(-e4m3_max), d_fp32, jnp.float32(e4m3_max))\n", + " # (Re)Quantize the scaled matmul output.\n", + " return d_fp32.astype(jnp.float8_e4m3fn)\n", + "\n", + "# AOT compilation with JAX, inspecting the (final) HLO module generated.\n", + "fn_compiled = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval).compile()\n", + "# (Human readable) optimized Hlo module generated by XLA (ignoring GEMM backend config)\n", + "print_hlo_module(fn_compiled, backend_cfg=False)" + ] + }, + { + "cell_type": "markdown", + "id": "437fefcf-bfb2-42aa-a899-0a57416a6a5e", + "metadata": {}, + "source": [ + "### Adding non-linearity `relu` to the FP8 matmul\n", + "\n", + "Can we get XLA to fuse a post-matmul non-linearity `relu` function as well?" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "44f28bbb-d4c6-4170-a736-76d667d73f97", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs=\"f1fb5db9dad54941d7d17e04fdbe9515\"}\n", + "\n", + "ENTRY %main.28 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> f8e4m3fn[32,128] {\n", + " %constant_1_0 = f32[] constant(1)\n", + " %Arg_4.5.0 = f32[] parameter(4)\n", + " %Arg_3.4.0 = f32[] parameter(3)\n", + " %Arg_2.3.0 = f32[] parameter(2)\n", + " %Arg_1.2.0 = f8e4m3fn[128,64]{1,0} parameter(1)\n", + " %Arg_0.1.0 = f8e4m3fn[32,64]{1,0} parameter(0)\n", + " %cublas-gemm.2.clone.1.0 = (f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f8e4m3fn[128,64]{1,0} %Arg_1.2.0, f32[] %Arg_2.3.0, f32[] %Arg_3.4.0, f32[] %constant_1_0, /*index=5*/f32[] %Arg_4.5.0), custom_call_target=\"__cublas$lt$matmul$f8\"\n", + " ROOT %get-tuple-element.1 = f8e4m3fn[32,128]{1,0} get-tuple-element((f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) %cublas-gemm.2.clone.1.0), index=0\n", + "}\n", + "\n", + "\n" + ] + } + ], + "source": [ + "e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max\n", + "\n", + "# XLA requires a \"dequantize/quantize\" pattern to properly support scaled FP8 inputs/outputs. \n", + "def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):\n", + " # Dequantize x and y\n", + " a_fp32 = a_fp8.astype(jnp.float32) * a_scale\n", + " b_fp32 = b_fp8.astype(jnp.float32) * b_scale\n", + " \n", + " # Do the matmul (NOTE: adding transpose to simplify HLO).\n", + " d_fp32 = jax.lax.dot(a_fp32, b_fp32.T)\n", + " # ReLU non-linearity. Note: applied before scaling.\n", + " d_fp32 = jax.nn.relu(d_fp32)\n", + " \n", + " # Rescale & clamp to -max/+max FP8 E4M3 values.\n", + " d_fp32 = d_fp32 * d_scale\n", + " # NOTE: clamping is NOT optional for proper pattern matching!\n", + " d_fp32 = jax.lax.clamp(jnp.float32(-e4m3_max), d_fp32, jnp.float32(e4m3_max))\n", + " # (Re)Quantize the scaled matmul output.\n", + " return d_fp32.astype(jnp.float8_e4m3fn)\n", + "\n", + "# AOT compilation with JAX, inspecting the (final) HLO module generated.\n", + "fn_compiled = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval).compile()\n", + "# (Human readable) optimized Hlo module generated by XLA (ignoring GEMM backend config)\n", + "print_hlo_module(fn_compiled, backend_cfg=False)" + ] + }, + { + "cell_type": "markdown", + "id": "46c3829e-e308-4706-b14a-ecd3fc25d01e", + "metadata": {}, + "source": [ + "As shown in the `backend_config` below, the `epilogue` is changed to `RELU`." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "2ca21eae-8b0c-454b-b670-1ef0d5935a5c", + "metadata": {}, + "outputs": [ + { + "data": { + "application/json": { + "force_earliest_schedule": false, + "gemm_backend_config": { + "alpha_imag": 0, + "alpha_real": 1, + "beta": 0, + "damax_output": false, + "dot_dimension_numbers": { + "lhs_batch_dimensions": [], + "lhs_contracting_dimensions": [ + "1" + ], + "rhs_batch_dimensions": [], + "rhs_contracting_dimensions": [ + "1" + ] + }, + "epilogue": "RELU", + "grad_x": false, + "grad_y": false, + "lhs_stride": "2048", + "precision_config": { + "algorithm": "ALG_UNSET", + "operand_precision": [ + "DEFAULT", + "DEFAULT" + ] + }, + "rhs_stride": "8192", + "selected_algorithm": "2" + }, + "operation_queue_id": "0", + "wait_on_operation_queues": [] + }, + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": { + "application/json": { + "expanded": true, + "root": "root" + } + }, + "output_type": "execute_result" + } + ], + "source": [ + "hlo_module = parse_hlo_module(fn_compiled)\n", + "backend_config = next((m.backend_config for m in hlo_module if \"__cublas$lt$matmul$f8\" in m.cmd))\n", + "# the `epilogue` is set to `RELU`\n", + "JSON(backend_config, expanded=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2893288d-7f2a-42e1-8541-afeed1d63a85", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "83661e1e-b662-41c9-9546-0a389894be0e", + "metadata": {}, + "source": [ + "### Extracting the `abs-max` of the output\n", + "\n", + "Delayed rescaling is a common technique in FP8 training, using the output **abs-max scaling** in the next training iteration. The benefit of delayed rescaling is that it can also be merged directly in the FP8 matmul kernel, as shown below, with very small performance impact." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "a65cf3be-c465-49ae-9e90-2ada54dba84a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->(f8e4m3fn[32,128]{1,0}, f32[])}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true,true}, frontend_attributes={fingerprint_before_lhs=\"352d42558cb7282e2f79d89bdc48e6d6\"}\n", + "\n", + "ENTRY %main.36 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> (f8e4m3fn[32,128], f32[]) {\n", + " %constant_1_0 = f32[] constant(1)\n", + " %Arg_4.5.0 = f32[] parameter(4)\n", + " %Arg_3.4.0 = f32[] parameter(3)\n", + " %Arg_2.3.0 = f32[] parameter(2)\n", + " %Arg_1.2.0 = f8e4m3fn[128,64]{1,0} parameter(1)\n", + " %Arg_0.1.0 = f8e4m3fn[32,64]{1,0} parameter(0)\n", + " %cublas-gemm.2.clone.1.0 = (f8e4m3fn[32,128]{1,0}, f32[], s8[33554432]{0}) custom-call(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f8e4m3fn[128,64]{1,0} %Arg_1.2.0, f32[] %Arg_2.3.0, f32[] %Arg_3.4.0, f32[] %constant_1_0, /*index=5*/f32[] %Arg_4.5.0), custom_call_target=\"__cublas$lt$matmul$f8\"\n", + " %get-tuple-element.1.0 = f32[] get-tuple-element((f8e4m3fn[32,128]{1,0}, f32[], s8[33554432]{0}) %cublas-gemm.2.clone.1.0), index=1\n", + " %get-tuple-element.4 = f8e4m3fn[32,128]{1,0} get-tuple-element((f8e4m3fn[32,128]{1,0}, f32[], s8[33554432]{0}) %cublas-gemm.2.clone.1.0), index=0\n", + " ROOT %tuple.35.0 = (f8e4m3fn[32,128]{1,0}, f32[]) tuple(f8e4m3fn[32,128]{1,0} %get-tuple-element.4, f32[] %get-tuple-element.1.0)\n", + "}\n", + "\n", + "\n" + ] + } + ], + "source": [ + "e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max\n", + "\n", + "# XLA requires a \"dequantize/quantize\" pattern to properly support scaled FP8 inputs/outputs. \n", + "def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):\n", + " # Dequantize x and y\n", + " a_fp32 = a_fp8.astype(jnp.float32) * a_scale\n", + " b_fp32 = b_fp8.astype(jnp.float32) * b_scale\n", + " \n", + " # Do the matmul (NOTE: adding transpose to simplify HLO).\n", + " d_fp32 = jax.lax.dot(a_fp32, b_fp32.T)\n", + " # ReLU non-linearity. Note: needs to be before the scaling.\n", + " d_fp32 = jax.nn.relu(d_fp32)\n", + " # Delayed rescaling: capture the raw output scaling for latter.\n", + " out_scale = jnp.max(jnp.abs(d_fp32))\n", + "\n", + " # Rescale & clamp to -max/+max FP8 E4M3 values.\n", + " d_fp32 = d_fp32 * d_scale\n", + " # NOTE: clamping is NOT optional for proper pattern matching!\n", + " d_fp32 = jax.lax.clamp(jnp.float32(-e4m3_max), d_fp32, jnp.float32(e4m3_max))\n", + " # (Re)Quantize the scaled matmul output.\n", + " return d_fp32.astype(jnp.float8_e4m3fn), out_scale\n", + "\n", + "# AOT compilation with JAX, inspecting the (final) HLO module generated.\n", + "fn_compiled = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval).compile()\n", + "# (Human readable) optimized Hlo module generated by XLA (ignoring GEMM backend config)\n", + "print_hlo_module(fn_compiled, backend_cfg=False)" + ] + }, + { + "cell_type": "markdown", + "id": "976dca48-2638-4fd9-bebc-f14dad02b00a", + "metadata": {}, + "source": [ + "As shown in the `backend_config` below, the `damax_output` is changed to `true`, meaning that the **`__cublas$lt$matmul$f8`** method is also computing the `abs-max` of the matmul output (prior to converting back to FP8)." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "20d4d088-6563-44c2-86a1-ab2c34fe4e8e", + "metadata": {}, + "outputs": [ + { + "data": { + "application/json": { + "force_earliest_schedule": false, + "gemm_backend_config": { + "alpha_imag": 0, + "alpha_real": 1, + "beta": 0, + "damax_output": true, + "dot_dimension_numbers": { + "lhs_batch_dimensions": [], + "lhs_contracting_dimensions": [ + "1" + ], + "rhs_batch_dimensions": [], + "rhs_contracting_dimensions": [ + "1" + ] + }, + "epilogue": "RELU", + "grad_x": false, + "grad_y": false, + "lhs_stride": "2048", + "precision_config": { + "algorithm": "ALG_UNSET", + "operand_precision": [ + "DEFAULT", + "DEFAULT" + ] + }, + "rhs_stride": "8192", + "selected_algorithm": "5" + }, + "operation_queue_id": "0", + "wait_on_operation_queues": [] + }, + "text/plain": [ + "" + ] + }, + "execution_count": 10, + "metadata": { + "application/json": { + "expanded": true, + "root": "root" + } + }, + "output_type": "execute_result" + } + ], + "source": [ + "hlo_module = parse_hlo_module(fn_compiled)\n", + "backend_config = next((m.backend_config for m in hlo_module if \"__cublas$lt$matmul$f8\" in m.cmd))\n", + "# the `epilogue` is set to `RELU` & `damax_output` to `true`\n", + "JSON(backend_config, expanded=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e63950c6-4924-409b-bd4e-776b4082350c", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "b500ade4-df2d-4c12-ba9e-07e06c902197", + "metadata": {}, + "source": [ + "### Additional notebook improvements & clarifications\n", + "\n", + "* Fusing Linear layer `bias` add;\n", + "* Fusing `jax.nn.gelu` activation layer;\n", + "* FP8 peak flops & performance;" + ] + }, + { + "cell_type": "markdown", + "id": "7dbebeda-dddc-49ce-88ce-50dfbbe2581b", + "metadata": {}, + "source": [ + "### References\n", + "\n", + "* [8-bit Numerical Formats for Deep Neural Networks](https://arxiv.org/abs/2206.02915)\n", + "* [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433)\n", + "* [FP8-LM: Training FP8 Large Language Models](https://arxiv.org/pdf/2310.18313)\n", + "* [Training and inference of large language models\n", + "using 8-bit floating point](https://openreview.net/pdf?id=nErbvDkucY)\n", + "* [OCP 8-bit Floating Point Specification (OFP8)](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1)\n", + "* [IEEE Working Group P3109 Interim Report\n", + "on 8-bit Binary Floating-point Formats](https://github.com/P3109/Public/blob/main/Shared%20Reports/P3109%20WG%20Interim%20Report.pdf)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e612f74-e5e2-4c4e-8053-adc39f533a99", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/img/fp-formats.webp b/docs/img/fp-formats.webp new file mode 100644 index 0000000..c169c17 Binary files /dev/null and b/docs/img/fp-formats.webp differ