From 2e6df020f59848faf3657348352f9b50a1ad5b27 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Mon, 16 Sep 2024 14:25:20 -0700 Subject: [PATCH] [Pallas] Add block-sparse kernel tutorial --- docs/_static/pallas/sparse/block_coo.svg | 1 + docs/_static/pallas/sparse/prefetch_map.svg | 1 + docs/_static/pallas/sparse/sparse_matmul.svg | 1 + docs/conf.py | 2 + docs/pallas/tpu/index.rst | 2 + docs/pallas/tpu/sparse.ipynb | 728 +++++++++++++++++++ docs/pallas/tpu/sparse.md | 571 +++++++++++++++ 7 files changed, 1306 insertions(+) create mode 100644 docs/_static/pallas/sparse/block_coo.svg create mode 100644 docs/_static/pallas/sparse/prefetch_map.svg create mode 100644 docs/_static/pallas/sparse/sparse_matmul.svg create mode 100644 docs/pallas/tpu/sparse.ipynb create mode 100644 docs/pallas/tpu/sparse.md diff --git a/docs/_static/pallas/sparse/block_coo.svg b/docs/_static/pallas/sparse/block_coo.svg new file mode 100644 index 000000000000..474dfcb64d7a --- /dev/null +++ b/docs/_static/pallas/sparse/block_coo.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/sparse/prefetch_map.svg b/docs/_static/pallas/sparse/prefetch_map.svg new file mode 100644 index 000000000000..08fdd2c1cf39 --- /dev/null +++ b/docs/_static/pallas/sparse/prefetch_map.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/sparse/sparse_matmul.svg b/docs/_static/pallas/sparse/sparse_matmul.svg new file mode 100644 index 000000000000..06a24317cfe1 --- /dev/null +++ b/docs/_static/pallas/sparse/sparse_matmul.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 1a7bf32842f0..ed6fcfd0dc8b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -134,6 +134,7 @@ def _do_not_evaluate_in_jax( 'pallas/quickstart.md', 'pallas/tpu/pipelining.md', 'pallas/tpu/distributed.md', + 'pallas/tpu/sparse.md', 'pallas/tpu/matmul.md', 'jep/9407-type-promotion.md', 'autodidax.md', @@ -224,6 +225,7 @@ def _do_not_evaluate_in_jax( 'pallas/quickstart.*', 'pallas/tpu/pipelining.*', 'pallas/tpu/distributed.*', + 'pallas/tpu/sparse.*', 'pallas/tpu/matmul.*', 'sharded-computation.*', 'distributed_data_loading.*' diff --git a/docs/pallas/tpu/index.rst b/docs/pallas/tpu/index.rst index eba986c2cfe8..20abad5f610e 100644 --- a/docs/pallas/tpu/index.rst +++ b/docs/pallas/tpu/index.rst @@ -9,4 +9,6 @@ TPU specific documentation. details pipelining matmul + sparse distributed + diff --git a/docs/pallas/tpu/sparse.ipynb b/docs/pallas/tpu/sparse.ipynb new file mode 100644 index 000000000000..9bf847df8489 --- /dev/null +++ b/docs/pallas/tpu/sparse.ipynb @@ -0,0 +1,728 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "ZHuzXqQ-9JUQ" + }, + "source": [ + "# Scalar Prefetch and Block-Sparse Computation\n", + "\n", + "In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 56, + "status": "ok", + "timestamp": 1726001133029, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "ibeIs_6QFMAM", + "outputId": "d72edb91-4529-4650-c9e9-b96788608635" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running on TPU v5 lite\n" + ] + } + ], + "source": [ + "import functools\n", + "import timeit\n", + "import numpy as np\n", + "import jax\n", + "from jax import numpy as jnp\n", + "from jax import lax\n", + "from jax.experimental import checkify\n", + "from jax.experimental import pallas as pl\n", + "from jax.experimental.pallas import tpu as pltpu\n", + "\n", + "assert \"TPU\" in jax.devices()[0].device_kind, \"Please run this notebook with TPU devices.\"\n", + "print(\"Running on\", jax.devices()[0].device_kind)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FIDGpPTEIcOa" + }, + "source": [ + "## Dynamic Block Indexing with Scalar Prefetch\n", + "\n", + "We will be exploiting the \"scalar prefetch\" feature of Pallas to enable us to write sparse kernels. Scalar prefetch allows you to pass in a small amount of data into SMEM (\"scalar\") that is loaded before the invocation of the kernel (\"prefetch\"). This metadata is then passed into the `index_map` for each BlockSpec, allowing the you to perform data-dependent indexing calculations. The main goal of this tutorial is to go over common programming patterns that utilize this feature.\n", + "\n", + "To enable scalar prefetch, use the `PrefetchScalarGridSpec` in place of the standard `pl.GridSpec`:\n", + "\n", + "```python\n", + "class PrefetchScalarGridSpec:\n", + " def __init__(self,\n", + " num_scalar_prefetch: int,\n", + " grid: tuple[int, ...],\n", + " in_specs: PyTree[BlockSpec],\n", + " out_specs: PyTree[BlockSpec],\n", + " scratch_shapes: tuple[MemorySpace, ...]):\n", + " ...\n", + "```\n", + "\n", + "There are two additional agruments for `pltpu.PrefetchScalarGridSpec` compared to `pl.GridSpec`:\n", + "- `num_scalar_prefetch`, which indicates the number of scalar prefetch values.\n", + "- `scratch_shapes`, which allocates a list of temporary \"scratchpad\" Refs for storing intermediate calculations (generally useful but unrelated to this tutorial).\n", + "\n", + "We are primarily interested in `num_scalar_prefetch`. When this is set to a non-zero value, it changes the call signature of the kernel and index maps to expect additional prefetch values. The prefetch `Ref`s passed in to the `index_map` and kernel are all allocated in SMEM and are not partitioned into blocks as they do not have a BlockSpec defined. Moreover, the order of arguments to both `index_map` and kernel are always fixed and described below:\n", + "\n", + "- Each `BlockSpec`'s `index_map` now expects the prefetch `Ref`s to come after the grid indices:\n", + "```python\n", + "def index_map(*grid_indices, *prefetch_refs):\n", + " ...\n", + "```\n", + "\n", + "- The user-defined kernel expects prefetch `Ref`s to come before the input `Ref`s. 
Additionally, the scratch refs come after the output `Ref`s.\n", + "```python\n", + "def kernel(*prefetch_refs, *input_refs, *output_refs, *scratch_refs):\n", + " ...\n", + "```\n", + "\n", + "- When calling a new kernel using `pallas_call`, the function returned by `pallas_call` also expects the scalar prefetch arguments to come before the inputs, e.g.\n", + "```python\n", + "kernel = pl.pallas_call(...)\n", + "result = kernel(*prefetch_args, *input_args)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pA8RmHEA2HN3" + }, + "source": [ + "## Example: Block Dynamic Slice with Scalar Prefetch\n", + "\n", + "Let's begin with a basic example that demonstrates how to use the scalar prefetch feature. We will implement a block-aligned dynamic slice kernel which simply extracts a block out of larger array based on user-specified indices:\n", + "\n", + "1. Outside of the kernel, we compute the block index to extract as: `block_idx = (start[0] // size[0], start[1] // size[1])`\n", + "\n", + "2. We pass `block_idx` as a scalar prefetch argument into `pallas_call`.\n", + "\n", + "3. In our index map, we use the block index to select the corresponding block by returning `(block_idx[0], block_idx[1])`.\n", + "\n", + "Of course, this kernel is limited in that our slice sizes must fit inside of a kernel block (limited by VMEM size) and we can only start on size-aligned indices. A more advanced kernel would decouple the kernel block size with the slice size and allow non-aligned start indices." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 143, + "status": "ok", + "timestamp": 1726003877561, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "FWeTBlEYlCGD", + "outputId": "4b04a441-c97c-4d0d-d167-c60d4d31fd2e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Error |result - lax.dynamic_slice| = 0\n" + ] + } + ], + "source": [ + "def dynamic_slice_kernel(indices, x_ref, o_ref):\n", + " del indices\n", + " o_ref[...] 
= x_ref[...]\n", + "\n", + "@checkify.checkify\n", + "@functools.partial(jax.jit, static_argnums=(2,))\n", + "def block_dynamic_slice(x, starts, sizes):\n", + " grid_spec = pltpu.PrefetchScalarGridSpec(\n", + " num_scalar_prefetch=1,\n", + " grid=(1, 1),\n", + " in_specs=[pl.BlockSpec(\n", + " sizes,\n", + " lambda i, j, block_idx: (block_idx[0], block_idx[1]))],\n", + " out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)),\n", + " )\n", + "\n", + " kernel = pl.pallas_call(\n", + " dynamic_slice_kernel,\n", + " grid_spec=grid_spec,\n", + " out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype),\n", + " )\n", + " # Checkify inserts a runtime assert that starts are divisible by block size.\n", + " checkify.check(starts[0] % sizes[0] == 0, \"Starts must be divisible by size.\")\n", + " checkify.check(starts[1] % sizes[1] == 0, \"Starts must be divisible by size.\")\n", + " block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]])\n", + " return kernel(block_idx, x)\n", + "\n", + "shape = (512, 512)\n", + "x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape)\n", + "err, result = block_dynamic_slice(x, starts=(128, 256), sizes=(128, 128))\n", + "err.throw()\n", + "ref = lax.dynamic_slice(x, start_indices=(128, 256), slice_sizes=(128, 128))\n", + "diff = jnp.max(jnp.abs(result - ref))\n", + "print(\"Error |result - lax.dynamic_slice| =\", diff)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "K2dod4lkoifa" + }, + "source": [ + "## Sparse Kernels: Representing Sparse Data\n", + "\n", + "Before we dive into implementing sparse kernels, let's first review how sparse matrices are represented. While there are several popular formats for storing sparse matrices, we will be following a blocked variant of the coordinate-list format (COO) in which we will store a matrix as a list of `(block_index, block_data)` pairs. All blocks that are not explicitly stored in the list are assumed to be zero, meaning we can save a significant amount of memory if there are many zero blocks in the matrix.\n", + "\n", + "The following figure demonstrates how we convert a 4x4 dense matrix (left) into a block-COO format (right) with a block size of 2x2. Note that in the sparse format, we can avoid explicitly storing the upper-right block which consists of all zero elements.\n", + "\n", + "![block_coo](../../_static/pallas/sparse/block_coo.svg)\n", + "\n", + "We will use the following helper function to sample a block-sparse matrix. It returns a dense matrix used for checking our results, as well as a list of block data and indices for each axis." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1gLiSvgIYUEx" + }, + "outputs": [], + "source": [ + "def generate_block_sparse_mat(key, M, N, blk_M, blk_N, p=0.2, dtype=jnp.float32):\n", + " \"\"\"Returns a sampled matrix and its block-sparse representation.\n", + "\n", + " Args:\n", + " key: RNG Key.\n", + " M: Major array dimension.\n", + " N: Minor array dimension.\n", + " blk_M: Block size along M dimension.\n", + " blk_N: Block size along N dimension.\n", + " p: Probability that a block will be non-zero.\n", + " dtype: dtype of the sampled matrix.\n", + "\n", + " Returns:\n", + " dense_mat: A (M, N) dense sampled array.\n", + " block_data: A (num_blocks, blk_M, blk_N) array of data blocks representing\n", + " the non-zero blocks of the matrix.\n", + " indices_i: A (num_blocks,) array of block indices for the first axis.\n", + " indices_j: A (num_blocks,) array of block indices for the second axis.\n", + " \"\"\"\n", + " mask_key, blocks_key = jax.random.split(key)\n", + " num_blocks = (M // blk_M, N // blk_N)\n", + " # We first sample a block mask, denoting which blocks are nonzero.\n", + " block_mask = jax.random.bernoulli(mask_key, p=p, shape=num_blocks)\n", + " num_blocks = jnp.sum(block_mask)\n", + " indices = jnp.where(block_mask)\n", + " # For each non-zero block, we sample a block of random values.\n", + " block_data = jax.random.uniform(blocks_key,\n", + " shape=(num_blocks, blk_M, blk_N),\n", + " dtype=dtype)\n", + " # For checking purposes, create the dense version of the sparse matrix.\n", + " dense_mat = jnp.zeros((M, N), dtype=dtype)\n", + " for blk in range(num_blocks):\n", + " idx_i = indices[0][blk]\n", + " idx_j = indices[1][blk]\n", + " slice_i = slice(idx_i * blk_M, (idx_i + 1) * blk_M)\n", + " slice_j = slice(idx_j * blk_N, (idx_j + 1) * blk_N)\n", + " dense_mat = dense_mat.at[slice_i, slice_j].set(block_data[blk])\n", + " return dense_mat, block_data, indices[0], indices[1]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eFyoZSTOH9Fk" + }, + "source": [ + "## Example: Sparse @ Dense Matrix Multiplication\n", + "\n", + "In our first example, we will multiple a sparse LHS matrix with a dense RHS matrix to produce a dense output.\n", + "\n", + "We will structure our kernel grid with 2 loops - the outer loop over the columns of the RHS/output, and inner loop over the sparse blocks of the LHS. During each inner loop iteration, we load one block from the LHS and lookup the corresponding block on in the RHS using the block index of the contracting dimension (K). We multiply the two blocks together and accumulate into the correct output block. One outer loop iteration will compute a result for an entire column as depicted by the following diagram:\n", + "\n", + "![sparse_matmul](../../_static/pallas/sparse/sparse_matmul.svg)\n", + "\n", + "It is important that we group the block indices by row (e.g. `[0, 0, 1, 2, 3, 3]`) before we pass them into the kernel for two reasons. First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and it is easy to do so if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output `Ref` on non-consecutive iterations, and therefore we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we loading the same output block on consecutive iterations and keep the block in VMEM. 
When we change output block Pallas will finally store the output into HBM and assume we never touch it again. Failure to access output blocks consecutively will result in incorrect values even though the kernel is otherwise logically correct." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 673, + "status": "ok", + "timestamp": 1725919879291, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "WfyV2WWhjsyA", + "outputId": "fa4d4fff-bc6b-4dc9-ac14-63276ca14131" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mean |result - ref|: 0\n" + ] + } + ], + "source": [ + "M = N = K = 16384\n", + "blk_M = blk_N = blk_K = 512\n", + "\n", + "\n", + "def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.\n", + " x_ref, y_ref, _, o_ref, # Kernel inputs.\n", + " accum_scratch,\n", + " ):\n", + " \"\"\"A DSD (Dense = Sparse @ Dense) matmul kernel.\"\"\"\n", + " del idxs_k_ref\n", + " blk_idx = pl.program_id(1)\n", + " is_start = blk_idx == 0\n", + " changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])\n", + " @pl.when(is_start | changed_blocks)\n", + " def _():\n", + " accum_scratch[...] = jnp.zeros_like(accum_scratch)\n", + " accum_scratch[...] += jnp.dot(x_ref[0, :, :], y_ref[...], preferred_element_type=jnp.float32)\n", + "\n", + " next_block_change = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.minimum(blk_idx+1, num_blocks)])\n", + " is_end = blk_idx == (num_blocks - 1)\n", + " @pl.when(is_end | next_block_change)\n", + " def _():\n", + " o_ref[...] = accum_scratch[...].astype(o_ref.dtype)\n", + "\n", + "\n", + "def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n", + " del j, blk_idxs_i, blk_idxs_k\n", + " return (blk_idx, 0, 0)\n", + "def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n", + " del blk_idxs_i\n", + " return (blk_idxs_k[blk_idx], j)\n", + "def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n", + " del blk_idxs_k\n", + " return (blk_idxs_i[blk_idx], j)\n", + "\n", + "(X_dense, X_blocks, indices_i, indices_k) = generate_block_sparse_mat(\n", + " jax.random.key(0), M, K, blk_M, blk_K, p=0.1, dtype=jnp.bfloat16)\n", + "num_blocks = X_blocks.shape[0]\n", + "Y = jax.random.uniform(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)\n", + "zeros = jnp.zeros((M, N), dtype=jnp.bfloat16)\n", + "out_shape = jax.ShapeDtypeStruct((M, N), dtype=jnp.bfloat16)\n", + "\n", + "grid_spec = pltpu.PrefetchScalarGridSpec(\n", + " num_scalar_prefetch=2,\n", + " # Note that while num_blocks is static here, Pallas does support\n", + " # dynamic grid sizes.\n", + " grid=(M // blk_M, num_blocks),\n", + " in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),\n", + " pl.BlockSpec((blk_K, blk_N), y_map),\n", + " # Placeholder for a zeros-array used by input_output_aliases.\n", + " pl.BlockSpec((blk_M, blk_N), o_map),\n", + " ],\n", + " out_specs=pl.BlockSpec((blk_M, blk_N), o_map),\n", + " scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]\n", + ")\n", + "kernel = pl.pallas_call(\n", + " dsd_kernel,\n", + " grid_spec=grid_spec,\n", + " out_shape=out_shape,\n", + " # We use input-output aliases to zero-out o_ref for blocks that we never\n", + " # visit. 
By passing in an array of zeros we avoid having o_ref start with\n", + " # uninitialized values.\n", + " input_output_aliases={4: 0}, # Map zeros to o_ref.\n", + ")\n", + "args = (indices_i, indices_k, X_blocks, Y, zeros)\n", + "result = kernel(*args)\n", + "\n", + "ref = X_dense @ Y\n", + "diff = jnp.abs(ref - result)\n", + "print('mean |result - ref|:', jnp.mean(diff))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2KDgPKF2tUjq" + }, + "source": [ + "We can do a quick benchmark to compare the performance of our sparse kernel compared to a dense matmul in JAX. On a TPU v5 lite chip, this kernel achieves a roughly ~6x speed increase compared to the theoretical 10x from the sparsity factor.\n", + "\n", + "There are a few main tips for performance here, mainly centered around reducing the communication overhead between HBM/VMEM:\n", + "- Using `dtype=jnp.bfloat16` is critical for performance since it reduces memory bandwidth by half.\n", + "- Using larger block sizes also helps, since matrix multiply is an $O(N^3)$ compute and $O(N^2)$ memory operation. As $N$ grows larger, the kernel becomes compute-bound. However, a counter-argument to this in practice is that smaller block sizes also enables data to be more sparse, so this is a parameter that should be selected carefully." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 6576, + "status": "ok", + "timestamp": 1725919886762, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "CkzjqnekpZbx", + "outputId": "1ae9031e-705a-4d05-f8b9-d09623918300" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sparse Kernel: 8.136 ms (avg over 100 trials)\n", + "Reference: 46.953 ms (avg over 100 trials)\n" + ] + } + ], + "source": [ + "# Benchmark Sparse Pallas kernel vs reference JAX implementation\n", + "\n", + "def benchmark(f, ntrials: int = 100):\n", + " def run(*args, **kwargs):\n", + " # Compile function first\n", + " jax.block_until_ready(f(*args, **kwargs))\n", + " # Time function\n", + " result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),\n", + " number=ntrials)\n", + " time = result / ntrials\n", + " return time\n", + " return run\n", + "\n", + "\n", + "N_trials = 100\n", + "\n", + "PallasImpl = lambda *args: kernel(*args)\n", + "time = benchmark(PallasImpl, N_trials)(indices_i, indices_k, X_blocks, Y, zeros)\n", + "print(\"Sparse Kernel: %.3f ms (avg over %d trials)\" % (time * 1000, N_trials))\n", + "\n", + "RefImpl = jax.jit(lambda x, y: x @ y)\n", + "time = benchmark(RefImpl, N_trials)(X_dense, Y)\n", + "print(\"Reference: %.3f ms (avg over %d trials)\" % (time * 1000, N_trials))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Q1KKd5vTCwnB" + }, + "source": [ + "## Sparse Access Patterns on Dense Data\n", + "\n", + "In our previous example we considered the case when the data itself is sparse. This manifested itself in the kernel structure as a dimension in the kernel grid that was dynamic and looped over the number of nonzero blocks (`num_blocks`).\n", + "\n", + "A second useful programming pattern emerges when the underlying is data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid as indicated by a block-sparse mask. 
This type of programming pattern is commonly arises when using masks in many machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out.\n", + "\n", + "The main performance consideration with dealing with a sparse access pattern on dense data is the interaction with pipelining. On any given kernel iteration, the Pallas pipeline emitter will attempt to prefetch the next block of data by calling the `index_map` for each `BlockSpec` on the next iteration of the grid. However, if our computation is sparse we may be skipping the computation for the next block in the grid, so we need some method to tell the pipeline instead begin fetching the *next block that we are not skipping*. In order to do this, we need to construct *prefetch maps* which contains indices to the next non-skipped block of data for each kernel input. The following diagram illustrates how a prefetch map could be constructed for a block-sparse mask that is stored in a COO-like format.\n", + "\n", + "![prefetch_map](../../_static/pallas/sparse/prefetch_map.svg)\n", + "\n", + "*Left: A sparse access pattern, where the color blue denotes blocks with non-zero masks that we need to compute. Right: The prefetch map, where each element of the array contains the index of the next non-zero block data.*\n", + "\n", + "Once the prefetch map has been constructed, we can pass the map as a scalar prefetch argument and query it in the `index_map` function of the BlockSpec.\n", + "\n", + "```python\n", + "def mask_index_map(prefetch_map, i, j, ...):\n", + " next_nonzero_block = prefetch_map[i, j]\n", + " return (next_nonzero_block, 0, 0)\n", + "```\n", + "\n", + "We can construct similar index maps for the other inputs to the kernel. For dense inputs you will most likely need to construct prefetch maps which point to the next non-zero block index in the grid. Our next example will provide an example of using these prefetch maps." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ii7rzL5YIA8-" + }, + "source": [ + "## Example: Dense @ Dense Matrix Multiplication with a Block-Sparse Output Mask" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ecjiqWfA2RlV" + }, + "source": [ + "In our next example we will cover dense matrix multiplication fused with a sparse output mask using a prefetch mask to improve pipelining performance. We will use the mask to selectively skip computing output blocks that are zeroed-out, therefore saving on computation costs.\n", + "\n", + "As we will be working with a sparse mask, we will begin by implementing a function that converts an `N x M` mask stored in dense format into a block-sparse format. We additionally need to compute prefetch maps to help the pipeline emitter know which block to fetch next. In total, our `sparsify_mask` function computes:\n", + "- A `block_mask` of shape `(num_N_blocks, num_M_blocks)` indicating if a block is all-zeros (value `0`) or contains non-zero elements (value `1`). 
If the `block_mask` has a value of 0 we can skip computing the block in the kernel.\n", + "- A `prefetch_mask` array of shape `(num_N_blocks, num_M_blocks)` consisting of indices into `mask_data` for the next non-zero block.\n", + "- A `prefetch_i` array of shape `(num_N_blocks, num_M_blocks)` consisting of the next non-masked `i` index of the mask.\n", + "- A `prefetch_j` array of shape `(num_N_blocks, num_M_blocks)` consisting of the next non-masked `j` index of the mask.\n", + "- A `mask_data` array of shape `(num_blocks, blk_N, blk_M)` containing data for non-zero blocks of the mask." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "19zGcliL2SJy" + }, + "outputs": [], + "source": [ + "def sparsify_mask(mask: jax.Array,\n", + " block_shape: tuple[int, int]):\n", + " \"\"\"Preprocesses a mask into a sparse reprentation.\n", + "\n", + " Args:\n", + " mask: A boolean array of shape [M, N]\n", + " block_shape: The size of a single block.\n", + "\n", + " Returns:\n", + " block_mask: A block_shape array of booleans indicating whether a block\n", + " is all-zeros (0) or contains non-zero elements (1).\n", + " prefetch_mask: A block_shape array of integers indicating the index of the\n", + " next non-zero block.\n", + " mask_data: A (num_blocks, block_shape) array containing\n", + " the data for non-zero blocks of the mask.\n", + " \"\"\"\n", + " M, N = mask.shape\n", + " bm, bn = block_shape\n", + "\n", + " block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype)\n", + " mask_types_finder = []\n", + " mask_data = []\n", + " mask_type_idxs = []\n", + "\n", + " next_mask_type_idx = 0\n", + " prefetch_mask = jnp.zeros_like(block_mask)\n", + " next_i = (M // bm) - 1\n", + " next_j = (N // bn) - 1\n", + " prefetch_i = jnp.zeros_like(block_mask)\n", + " prefetch_j = jnp.zeros_like(block_mask)\n", + " for i in range(M // bm, -1, -1):\n", + " for j in range(N // bn, -1, -1):\n", + " mask_block = mask[i * bm :(i + 1) * bm,\n", + " j * bn :(j + 1) * bn]\n", + " is_nonzero = jnp.any(mask_block)\n", + " if is_nonzero:\n", + " try:\n", + " type_index = mask_types_finder.index(str(mask_block))\n", + " except ValueError:\n", + " type_index = len(mask_types_finder)\n", + " mask_types_finder.append(str(mask_block))\n", + " mask_data.append(mask_block)\n", + " next_mask_type_idx = type_index\n", + " next_i = i\n", + " next_j = j\n", + " else:\n", + " type_index = -1\n", + " mask_type_idxs.append(type_index)\n", + " block_mask = block_mask.at[i, j].set(is_nonzero)\n", + " prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx)\n", + " prefetch_i = prefetch_i.at[i, j].set(next_i)\n", + " prefetch_j = prefetch_j.at[i, j].set(next_j)\n", + " return block_mask, prefetch_mask, prefetch_i, prefetch_j, jnp.stack(mask_data)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "w4b7ckKq67Xw" + }, + "source": [ + "In terms of the structure of the kernel, we use the same grid pattern as the standard matrix multiplication kernel we covered in previous tutorials with a 3 loops over the `N`, `M`, and `K` dimensions. Within the kernel itself, we first check the `block_mask` to see if the mask for the current output block was all zeros. If the mask is all zeros, we can skip computation and move onto the next block; otherwise we need to compute the matrix multiplication and then mask the result." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 5374, + "status": "ok", + "timestamp": 1725919713252, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "4YQ9OmbTCSjT", + "outputId": "2d752609-34f2-4059-e8ba-4d80afe8cb26" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mean |result - ref|: 1.0252e-05\n" + ] + } + ], + "source": [ + "M = N = K = 16384\n", + "blk_M = blk_N = 512\n", + "blk_K = 1024\n", + "\n", + "def sparse_mask_matmul(\n", + " block_mask_ref, prefetch_mask, prefetch_i, prefetch_j, # Scalar prefetch inputs.\n", + " x_ref, y_ref, mask_ref, o_ref, # Kernel inputs.\n", + " accum_scratch\n", + " ):\n", + " del prefetch_mask, prefetch_i, prefetch_j\n", + " i, j, k = pl.program_id(0), pl.program_id(1), pl.program_id(2)\n", + " should_compute = block_mask_ref[i, j] != 0\n", + " @pl.when(k == 0)\n", + " def _():\n", + " o_ref[...] = jnp.zeros_like(o_ref)\n", + " accum_scratch[...] = jnp.zeros_like(accum_scratch[...])\n", + "\n", + " # We only compute the output for blocks with non-zero masks.\n", + " # Otherwise we skip the computation entirely.\n", + " @pl.when(should_compute)\n", + " def _():\n", + " result = jnp.dot(x_ref[...], y_ref[...], preferred_element_type=jnp.float32)\n", + " accum_scratch[...] += result\n", + " @pl.when(k == pl.num_programs(2) - 1)\n", + " def _():\n", + " o_ref[...] = (mask_ref[0, ...] * accum_scratch[...]).astype(o_ref.dtype)\n", + "\n", + "X = jax.random.normal(jax.random.key(0), shape=(M, K), dtype=jnp.bfloat16)\n", + "Y = jax.random.normal(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)\n", + "mask = jnp.ones((M, N), dtype=jnp.int32)\n", + "mask = jnp.tril(mask)\n", + "block_mask, prefetch_mask, prefetch_i, prefetch_j, sparse_mask_data = sparsify_mask(mask, (blk_M, blk_N))\n", + "\n", + "def x_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):\n", + " del prefetch_mask, prefetch_j\n", + " # Zero-out the k index if the mask is zero, to avoid constantly fetching\n", + " # new blocks in the inner loop for blocks we are skipping.\n", + " k_fetch = (block_mask[i, j] != 0) * k\n", + " return (prefetch_i[i, j], k_fetch)\n", + "\n", + "def y_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):\n", + " del prefetch_mask, prefetch_i\n", + " k_fetch = (block_mask[i, j] != 0) * k\n", + " return (k_fetch, prefetch_j[i, j])\n", + "\n", + "def mask_map(i, j, k, block_mask, prefetch_mask, *_):\n", + " del k, block_mask\n", + " return (prefetch_mask[i, j], 0, 0)\n", + "\n", + "def o_map(i, j, k, *_):\n", + " del k\n", + " return (i, j)\n", + "\n", + "grid_spec = pltpu.PrefetchScalarGridSpec(\n", + " num_scalar_prefetch=4,\n", + " grid=(M // blk_M, N // blk_N, K // blk_K),\n", + " in_specs=[pl.BlockSpec((blk_M, blk_K), x_map),\n", + " pl.BlockSpec((blk_K, blk_N), y_map),\n", + " pl.BlockSpec((1, blk_M, blk_N), mask_map)],\n", + " out_specs=pl.BlockSpec((blk_M, blk_N), o_map),\n", + " scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]\n", + ")\n", + "kernel = pl.pallas_call(\n", + " sparse_mask_matmul,\n", + " grid_spec=grid_spec,\n", + " out_shape=jax.ShapeDtypeStruct((M, N), jnp.bfloat16),\n", + ")\n", + "args = (block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)\n", + "result = kernel(*args)\n", + "\n", + "ref = mask * (X @ Y)\n", + "diff = jnp.abs(ref - result)\n", + "print('mean |result - ref|:', jnp.mean(diff))" + ] + }, + { + 
"cell_type": "markdown", + "metadata": { + "id": "uutNGgjZGGhB" + }, + "source": [ + "Now let's compare performance versus a naive dense implementation. On TPU v5 lite, we achieve around a ~1.8x speed increase with the sparse kernel, compared to a theoretical best-case of 2x from using a lower triangular mask and only visiting half of the possible outputs.\n", + "\n", + "We would generally expect performance to get closer to the theoretical peak as our inputs get larger, since a few of the main reasons why we don't exactly reach theoretical performance are:\n", + "- We skip slightly less than half of computation since the blocks along the diagonal are mixed 0s and 1s, and for mixed blocks we need to compute the entire block. With larger inputs, our overhead for mixed blocks becomes smaller relative to the overall computation.\n", + "- The pipeline bubble also becomes accounts for a less percentage of the overall runtime as inputs become larger." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 8877, + "status": "ok", + "timestamp": 1725917397452, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "MAT9JjGNvsx8", + "outputId": "a32d56fb-a71b-4007-c6a5-e5270dcaa6cf" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sparse Kernel: 28.648 ms (avg over 100 trials)\n", + "Reference: 49.988 ms (avg over 100 trials)\n" + ] + } + ], + "source": [ + "N_trials = 100\n", + "\n", + "PallasImpl = lambda *args: kernel(*args)\n", + "time = benchmark(PallasImpl, N_trials)(block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)\n", + "print(\"Sparse Kernel: %.3f ms (avg over %d trials)\" % (time * 1000, N_trials))\n", + "\n", + "RefImpl = jax.jit(lambda mask, x, y: mask * (x @ y))\n", + "time = benchmark(RefImpl, N_trials)(mask, X, Y)\n", + "print(\"Reference: %.3f ms (avg over %d trials)\" % (time * 1000, N_trials))" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst", + "main_language": "python" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/pallas/tpu/sparse.md b/docs/pallas/tpu/sparse.md new file mode 100644 index 000000000000..465ae92e4239 --- /dev/null +++ b/docs/pallas/tpu/sparse.md @@ -0,0 +1,571 @@ +--- +jupytext: + formats: ipynb,md:myst + main_language: python + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.1 +kernelspec: + display_name: Python 3 + name: python3 +--- + ++++ {"id": "ZHuzXqQ-9JUQ"} + +# Scalar Prefetch and Block-Sparse Computation + +In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory. 
+
+```{code-cell}
+---
+executionInfo:
+  elapsed: 56
+  status: ok
+  timestamp: 1726001133029
+  user:
+    displayName: Justin Fu
+    userId: '17543197034567316452'
+  user_tz: 420
+id: ibeIs_6QFMAM
+outputId: d72edb91-4529-4650-c9e9-b96788608635
+---
+import functools
+import timeit
+import numpy as np
+import jax
+from jax import numpy as jnp
+from jax import lax
+from jax.experimental import checkify
+from jax.experimental import pallas as pl
+from jax.experimental.pallas import tpu as pltpu
+
+assert "TPU" in jax.devices()[0].device_kind, "Please run this notebook with TPU devices."
+print("Running on", jax.devices()[0].device_kind)
+```
+
++++ {"id": "FIDGpPTEIcOa"}
+
+## Dynamic Block Indexing with Scalar Prefetch
+
+We will be exploiting the "scalar prefetch" feature of Pallas to enable us to write sparse kernels. Scalar prefetch allows you to pass a small amount of data into SMEM ("scalar") that is loaded before the invocation of the kernel ("prefetch"). This metadata is then passed into the `index_map` for each BlockSpec, allowing you to perform data-dependent indexing calculations. The main goal of this tutorial is to go over common programming patterns that utilize this feature.
+
+To enable scalar prefetch, use `PrefetchScalarGridSpec` in place of the standard `pl.GridSpec`:
+
+```python
+class PrefetchScalarGridSpec:
+  def __init__(self,
+               num_scalar_prefetch: int,
+               grid: tuple[int, ...],
+               in_specs: PyTree[BlockSpec],
+               out_specs: PyTree[BlockSpec],
+               scratch_shapes: tuple[MemorySpace, ...]):
+    ...
+```
+
+There are two additional arguments for `pltpu.PrefetchScalarGridSpec` compared to `pl.GridSpec`:
+- `num_scalar_prefetch`, which indicates the number of scalar prefetch values.
+- `scratch_shapes`, which allocates a list of temporary "scratchpad" Refs for storing intermediate calculations (generally useful but unrelated to this tutorial).
+
+We are primarily interested in `num_scalar_prefetch`. When this is set to a non-zero value, it changes the call signature of the kernel and index maps to expect additional prefetch values. The prefetch `Ref`s passed in to the `index_map` and kernel are all allocated in SMEM and are not partitioned into blocks as they do not have a BlockSpec defined. Moreover, the order of arguments to both `index_map` and kernel is always fixed and described below:
+
+- Each `BlockSpec`'s `index_map` now expects the prefetch `Ref`s to come after the grid indices:
+```python
+def index_map(*grid_indices, *prefetch_refs):
+    ...
+```
+
+- The user-defined kernel expects prefetch `Ref`s to come before the input `Ref`s. Additionally, the scratch refs come after the output `Ref`s.
+```python
+def kernel(*prefetch_refs, *input_refs, *output_refs, *scratch_refs):
+    ...
+```
+
+- When calling a new kernel using `pallas_call`, the function returned by `pallas_call` also expects the scalar prefetch arguments to come before the inputs, e.g.
+```python
+kernel = pl.pallas_call(...)
+result = kernel(*prefetch_args, *input_args)
+```
+
++++ {"id": "pA8RmHEA2HN3"}
+
+## Example: Block Dynamic Slice with Scalar Prefetch
+
+Let's begin with a basic example that demonstrates how to use the scalar prefetch feature. We will implement a block-aligned dynamic slice kernel which simply extracts a block out of a larger array based on user-specified indices:
+
+1. Outside of the kernel, we compute the block index to extract as: `block_idx = (start[0] // size[0], start[1] // size[1])`
+
+2. We pass `block_idx` as a scalar prefetch argument into `pallas_call`.
+
+3. 
In our index map, we use the block index to select the corresponding block by returning `(block_idx[0], block_idx[1])`. + +Of course, this kernel is limited in that our slice sizes must fit inside of a kernel block (limited by VMEM size) and we can only start on size-aligned indices. A more advanced kernel would decouple the kernel block size with the slice size and allow non-aligned start indices. + +```{code-cell} +--- +executionInfo: + elapsed: 143 + status: ok + timestamp: 1726003877561 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: FWeTBlEYlCGD +outputId: 4b04a441-c97c-4d0d-d167-c60d4d31fd2e +--- +def dynamic_slice_kernel(indices, x_ref, o_ref): + del indices + o_ref[...] = x_ref[...] + +@checkify.checkify +@functools.partial(jax.jit, static_argnums=(2,)) +def block_dynamic_slice(x, starts, sizes): + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(1, 1), + in_specs=[pl.BlockSpec( + sizes, + lambda i, j, block_idx: (block_idx[0], block_idx[1]))], + out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)), + ) + + kernel = pl.pallas_call( + dynamic_slice_kernel, + grid_spec=grid_spec, + out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype), + ) + # Checkify inserts a runtime assert that starts are divisible by block size. + checkify.check(starts[0] % sizes[0] == 0, "Starts must be divisible by size.") + checkify.check(starts[1] % sizes[1] == 0, "Starts must be divisible by size.") + block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]]) + return kernel(block_idx, x) + +shape = (512, 512) +x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape) +err, result = block_dynamic_slice(x, starts=(128, 256), sizes=(128, 128)) +err.throw() +ref = lax.dynamic_slice(x, start_indices=(128, 256), slice_sizes=(128, 128)) +diff = jnp.max(jnp.abs(result - ref)) +print("Error |result - lax.dynamic_slice| =", diff) +``` + ++++ {"id": "K2dod4lkoifa"} + +## Sparse Kernels: Representing Sparse Data + +Before we dive into implementing sparse kernels, let's first review how sparse matrices are represented. While there are several popular formats for storing sparse matrices, we will be following a blocked variant of the coordinate-list format (COO) in which we will store a matrix as a list of `(block_index, block_data)` pairs. All blocks that are not explicitly stored in the list are assumed to be zero, meaning we can save a significant amount of memory if there are many zero blocks in the matrix. + +The following figure demonstrates how we convert a 4x4 dense matrix (left) into a block-COO format (right) with a block size of 2x2. Note that in the sparse format, we can avoid explicitly storing the upper-right block which consists of all zero elements. + +![block_coo](../../_static/pallas/sparse/block_coo.svg) + +We will use the following helper function to sample a block-sparse matrix. It returns a dense matrix used for checking our results, as well as a list of block data and indices for each axis. + +```{code-cell} +:id: 1gLiSvgIYUEx + +def generate_block_sparse_mat(key, M, N, blk_M, blk_N, p=0.2, dtype=jnp.float32): + """Returns a sampled matrix and its block-sparse representation. + + Args: + key: RNG Key. + M: Major array dimension. + N: Minor array dimension. + blk_M: Block size along M dimension. + blk_N: Block size along N dimension. + p: Probability that a block will be non-zero. + dtype: dtype of the sampled matrix. + + Returns: + dense_mat: A (M, N) dense sampled array. 
+    block_data: A (num_blocks, blk_M, blk_N) array of data blocks representing
+      the non-zero blocks of the matrix.
+    indices_i: A (num_blocks,) array of block indices for the first axis.
+    indices_j: A (num_blocks,) array of block indices for the second axis.
+  """
+  mask_key, blocks_key = jax.random.split(key)
+  num_blocks = (M // blk_M, N // blk_N)
+  # We first sample a block mask, denoting which blocks are nonzero.
+  block_mask = jax.random.bernoulli(mask_key, p=p, shape=num_blocks)
+  num_blocks = jnp.sum(block_mask)
+  indices = jnp.where(block_mask)
+  # For each non-zero block, we sample a block of random values.
+  block_data = jax.random.uniform(blocks_key,
+                                  shape=(num_blocks, blk_M, blk_N),
+                                  dtype=dtype)
+  # For checking purposes, create the dense version of the sparse matrix.
+  dense_mat = jnp.zeros((M, N), dtype=dtype)
+  for blk in range(num_blocks):
+    idx_i = indices[0][blk]
+    idx_j = indices[1][blk]
+    slice_i = slice(idx_i * blk_M, (idx_i + 1) * blk_M)
+    slice_j = slice(idx_j * blk_N, (idx_j + 1) * blk_N)
+    dense_mat = dense_mat.at[slice_i, slice_j].set(block_data[blk])
+  return dense_mat, block_data, indices[0], indices[1]
+```
+
++++ {"id": "eFyoZSTOH9Fk"}
+
+## Example: Sparse @ Dense Matrix Multiplication
+
+In our first example, we will multiply a sparse LHS matrix with a dense RHS matrix to produce a dense output.
+
+We will structure our kernel grid with 2 loops: the outer loop over the columns of the RHS/output, and the inner loop over the sparse blocks of the LHS. During each inner loop iteration, we load one block from the LHS and look up the corresponding block in the RHS using the block index of the contracting dimension (K). We multiply the two blocks together and accumulate into the correct output block. One outer loop iteration will compute a result for an entire column as depicted by the following diagram:
+
+![sparse_matmul](../../_static/pallas/sparse/sparse_matmul.svg)
+
+It is important that we group the block indices by row (e.g. `[0, 0, 1, 2, 3, 3]`) before we pass them into the kernel for two reasons. First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and it is easy to do so if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output `Ref` on non-consecutive iterations, and therefore we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we are loading the same output block on consecutive iterations and keep the block in VMEM. Once we move on to a different output block, Pallas will store the finished block into HBM and assume we never touch it again. Failure to access output blocks consecutively will result in incorrect values even though the kernel is otherwise logically correct.
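+
+As a side note, `generate_block_sparse_mat` above already returns its indices grouped this way, since `jnp.where` yields indices in row-major order. If the `(row, column)` block indices came from elsewhere in an arbitrary order, a grouping step along the following lines could be used (a hypothetical NumPy sketch, not part of the kernel; `block_data` would have to be permuted with the same `order`):
+
+```python
+import numpy as np
+
+indices_i = np.array([3, 0, 2, 0, 1, 3])  # row block indices (arbitrary order)
+indices_k = np.array([1, 2, 0, 0, 3, 2])  # column block indices
+order = np.lexsort((indices_k, indices_i))  # sort by row first, then by column
+indices_i, indices_k = indices_i[order], indices_k[order]
+print(indices_i)  # [0 0 1 2 3 3] -- row indices are now grouped
+```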
+
+```{code-cell}
+---
+executionInfo:
+  elapsed: 673
+  status: ok
+  timestamp: 1725919879291
+  user:
+    displayName: Justin Fu
+    userId: '17543197034567316452'
+  user_tz: 420
+id: WfyV2WWhjsyA
+outputId: fa4d4fff-bc6b-4dc9-ac14-63276ca14131
+---
+M = N = K = 16384
+blk_M = blk_N = blk_K = 512
+
+
+def dsd_kernel(idxs_i_ref, idxs_k_ref,  # Scalar prefetch inputs.
+               x_ref, y_ref, _, o_ref,  # Kernel inputs.
+               accum_scratch,
+               ):
+  """A DSD (Dense = Sparse @ Dense) matmul kernel."""
+  del idxs_k_ref
+  blk_idx = pl.program_id(1)
+  is_start = blk_idx == 0
+  changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])
+  @pl.when(is_start | changed_blocks)
+  def _():
+    accum_scratch[...] = jnp.zeros_like(accum_scratch)
+  accum_scratch[...] += jnp.dot(x_ref[0, :, :], y_ref[...], preferred_element_type=jnp.float32)
+
+  next_block_change = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.minimum(blk_idx+1, num_blocks)])
+  is_end = blk_idx == (num_blocks - 1)
+  @pl.when(is_end | next_block_change)
+  def _():
+    o_ref[...] = accum_scratch[...].astype(o_ref.dtype)
+
+
+def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
+  del j, blk_idxs_i, blk_idxs_k
+  return (blk_idx, 0, 0)
+def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
+  del blk_idxs_i
+  return (blk_idxs_k[blk_idx], j)
+def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
+  del blk_idxs_k
+  return (blk_idxs_i[blk_idx], j)
+
+(X_dense, X_blocks, indices_i, indices_k) = generate_block_sparse_mat(
+    jax.random.key(0), M, K, blk_M, blk_K, p=0.1, dtype=jnp.bfloat16)
+num_blocks = X_blocks.shape[0]
+Y = jax.random.uniform(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)
+zeros = jnp.zeros((M, N), dtype=jnp.bfloat16)
+out_shape = jax.ShapeDtypeStruct((M, N), dtype=jnp.bfloat16)
+
+grid_spec = pltpu.PrefetchScalarGridSpec(
+    num_scalar_prefetch=2,
+    # Note that while num_blocks is static here, Pallas does support
+    # dynamic grid sizes.
+    grid=(M // blk_M, num_blocks),
+    in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),
+              pl.BlockSpec((blk_K, blk_N), y_map),
+              # Placeholder for a zeros-array used by input_output_aliases.
+              pl.BlockSpec((blk_M, blk_N), o_map),
+              ],
+    out_specs=pl.BlockSpec((blk_M, blk_N), o_map),
+    scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]
+)
+kernel = pl.pallas_call(
+  dsd_kernel,
+  grid_spec=grid_spec,
+  out_shape=out_shape,
+  # We use input-output aliases to zero-out o_ref for blocks that we never
+  # visit. By passing in an array of zeros we avoid having o_ref start with
+  # uninitialized values.
+  input_output_aliases={4: 0},  # Map zeros to o_ref.
+)
+args = (indices_i, indices_k, X_blocks, Y, zeros)
+result = kernel(*args)
+
+ref = X_dense @ Y
+diff = jnp.abs(ref - result)
+print('mean |result - ref|:', jnp.mean(diff))
+```
+
++++ {"id": "2KDgPKF2tUjq"}
+
+We can do a quick benchmark to compare the performance of our sparse kernel against a dense matmul in JAX. On a TPU v5 lite chip, this kernel achieves roughly a 6x speed increase compared to the theoretical 10x from the sparsity factor.
+
+There are a few main tips for performance here, mainly centered around reducing the communication overhead between HBM and VMEM:
+- Using `dtype=jnp.bfloat16` is critical for performance since it halves the required memory bandwidth.
+- Using larger block sizes also helps, since matrix multiply is an $O(N^3)$ compute and $O(N^2)$ memory operation. As $N$ grows larger, the kernel becomes compute-bound. However, a counter-argument to this in practice is that smaller block sizes also enable data to be more sparse, so this is a parameter that should be selected carefully (a rough estimate of this trade-off is sketched below).
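+
+As a rough illustration of the block-size trade-off (a back-of-the-envelope estimate only, not used by the kernel):
+
+```python
+# For one (blk x blk) @ (blk x blk) block matmul, FLOPs grow as blk**3 while
+# bytes moved grow as blk**2, so arithmetic intensity grows linearly with blk.
+def arithmetic_intensity(blk, bytes_per_elem=2):  # bfloat16 = 2 bytes
+  flops = 2 * blk**3                          # one multiply-add per output MAC
+  bytes_moved = 3 * blk**2 * bytes_per_elem   # read two blocks, write one
+  return flops / bytes_moved
+
+for blk in (128, 256, 512):
+  print(f"blk={blk}: ~{arithmetic_intensity(blk):.0f} FLOPs/byte")
+```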
+
+```{code-cell}
+---
+executionInfo:
+  elapsed: 6576
+  status: ok
+  timestamp: 1725919886762
+  user:
+    displayName: Justin Fu
+    userId: '17543197034567316452'
+  user_tz: 420
+id: CkzjqnekpZbx
+outputId: 1ae9031e-705a-4d05-f8b9-d09623918300
+---
+# Benchmark Sparse Pallas kernel vs reference JAX implementation
+
+def benchmark(f, ntrials: int = 100):
+  def run(*args, **kwargs):
+    # Compile function first
+    jax.block_until_ready(f(*args, **kwargs))
+    # Time function
+    result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),
+                           number=ntrials)
+    time = result / ntrials
+    return time
+  return run
+
+
+N_trials = 100
+
+PallasImpl = lambda *args: kernel(*args)
+time = benchmark(PallasImpl, N_trials)(indices_i, indices_k, X_blocks, Y, zeros)
+print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, N_trials))
+
+RefImpl = jax.jit(lambda x, y: x @ y)
+time = benchmark(RefImpl, N_trials)(X_dense, Y)
+print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, N_trials))
+```
+
++++ {"id": "Q1KKd5vTCwnB"}
+
+## Sparse Access Patterns on Dense Data
+
+In our previous example we considered the case when the data itself is sparse. This manifested itself in the kernel structure as a dimension in the kernel grid that was dynamic and looped over the number of nonzero blocks (`num_blocks`).
+
+A second useful programming pattern emerges when the underlying data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid as indicated by a block-sparse mask. This type of programming pattern commonly arises when using masks in many machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out.
+
+The main performance consideration when dealing with a sparse access pattern on dense data is the interaction with pipelining. On any given kernel iteration, the Pallas pipeline emitter will attempt to prefetch the next block of data by calling the `index_map` for each `BlockSpec` on the next iteration of the grid. However, if our computation is sparse we may be skipping the computation for the next block in the grid, so we need some method to tell the pipeline to instead begin fetching the *next block that we are not skipping*. In order to do this, we need to construct *prefetch maps* which contain indices to the next non-skipped block of data for each kernel input. The following diagram illustrates how a prefetch map could be constructed for a block-sparse mask that is stored in a COO-like format.
+
+![prefetch_map](../../_static/pallas/sparse/prefetch_map.svg)
+
+*Left: A sparse access pattern, where the color blue denotes blocks with non-zero masks that we need to compute. Right: The prefetch map, where each element of the array contains the index of the next non-zero block data.*
+
+Once the prefetch map has been constructed, we can pass the map as a scalar prefetch argument and query it in the `index_map` function of the BlockSpec. Note that, as described earlier, the prefetch `Ref`s come after the grid indices:
+
+```python
+def mask_index_map(i, j, ..., prefetch_map):
+  next_nonzero_block = prefetch_map[i, j]
+  return (next_nonzero_block, 0, 0)
+```
+
+We can construct similar index maps for the other inputs to the kernel. For dense inputs you will most likely need to construct prefetch maps which point to the next non-zero block index in the grid. Our next example will demonstrate how to use these prefetch maps.
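+
+Before moving on, here is a minimal NumPy sketch of how such a prefetch map can be computed for a row-major grid traversal (the helper name and the clamping convention for trailing all-zero cells are our own choices for illustration; the `sparsify_mask` function in the next example uses the same backwards scan):
+
+```python
+import numpy as np
+
+def make_prefetch_map(block_mask):
+  """For each grid cell, the index (into the list of stored non-zero blocks)
+  of the next non-zero block in row-major order."""
+  block_mask = np.asarray(block_mask, dtype=bool)
+  flat = block_mask.reshape(-1)
+  num_data_blocks = int(flat.sum())
+  prefetch = np.zeros(flat.size, dtype=np.int32)
+  data_idx = num_data_blocks
+  # Scan backwards so each cell sees the closest non-zero block at or after it.
+  for cell in range(flat.size - 1, -1, -1):
+    if flat[cell]:
+      data_idx -= 1
+    # Cells after the last non-zero block are clamped to the final block; the
+    # kernel skips those grid steps anyway, so any valid index works there.
+    prefetch[cell] = min(data_idx, num_data_blocks - 1)
+  return prefetch.reshape(block_mask.shape)
+
+print(make_prefetch_map([[1, 0, 0],
+                         [0, 1, 0],
+                         [0, 0, 1]]))
+```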
+
++++ {"id": "ii7rzL5YIA8-"}
+
+## Example: Dense @ Dense Matrix Multiplication with a Block-Sparse Output Mask
+
++++ {"id": "ecjiqWfA2RlV"}
+
+In our next example we will cover dense matrix multiplication fused with a sparse output mask, using a prefetch map to improve pipelining performance. We will use the mask to selectively skip computing output blocks that are zeroed-out, therefore saving on computation costs.
+
+As we will be working with a sparse mask, we will begin by implementing a function that converts an `M x N` mask stored in dense format into a block-sparse format. We additionally need to compute prefetch maps to help the pipeline emitter know which block to fetch next. In total, our `sparsify_mask` function computes:
+- A `block_mask` of shape `(num_M_blocks, num_N_blocks)` indicating if a block is all-zeros (value `0`) or contains non-zero elements (value `1`). If the `block_mask` has a value of 0 we can skip computing the block in the kernel.
+- A `prefetch_mask` array of shape `(num_M_blocks, num_N_blocks)` consisting of indices into `mask_data` for the next non-zero block.
+- A `prefetch_i` array of shape `(num_M_blocks, num_N_blocks)` consisting of the next non-masked `i` index of the mask.
+- A `prefetch_j` array of shape `(num_M_blocks, num_N_blocks)` consisting of the next non-masked `j` index of the mask.
+- A `mask_data` array of shape `(num_blocks, blk_M, blk_N)` containing data for non-zero blocks of the mask.
+
+```{code-cell}
+:id: 19zGcliL2SJy
+
+def sparsify_mask(mask: jax.Array,
+                  block_shape: tuple[int, int]):
+  """Preprocesses a mask into a sparse representation.
+
+  Args:
+    mask: A boolean array of shape [M, N].
+    block_shape: The size of a single block.
+
+  Returns:
+    block_mask: A (M // bm, N // bn) array of booleans, where (bm, bn) is
+      block_shape, indicating whether a block is all-zeros (0) or contains
+      non-zero elements (1).
+    prefetch_mask: A (M // bm, N // bn) array of integers indicating the index
+      into mask_data of the next non-zero block.
+    prefetch_i: A (M // bm, N // bn) array of integers indicating the row index
+      of the next non-zero block.
+    prefetch_j: A (M // bm, N // bn) array of integers indicating the column
+      index of the next non-zero block.
+    mask_data: A (num_blocks, bm, bn) array containing the data for the
+      non-zero blocks of the mask.
+  """
+  M, N = mask.shape
+  bm, bn = block_shape
+
+  block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype)
+  mask_types_finder = []
+  mask_data = []
+  mask_type_idxs = []
+
+  next_mask_type_idx = 0
+  prefetch_mask = jnp.zeros_like(block_mask)
+  next_i = (M // bm) - 1
+  next_j = (N // bn) - 1
+  prefetch_i = jnp.zeros_like(block_mask)
+  prefetch_j = jnp.zeros_like(block_mask)
+  for i in range(M // bm, -1, -1):
+    for j in range(N // bn, -1, -1):
+      mask_block = mask[i * bm :(i + 1) * bm,
+                        j * bn :(j + 1) * bn]
+      is_nonzero = jnp.any(mask_block)
+      if is_nonzero:
+        try:
+          type_index = mask_types_finder.index(str(mask_block))
+        except ValueError:
+          type_index = len(mask_types_finder)
+          mask_types_finder.append(str(mask_block))
+          mask_data.append(mask_block)
+        next_mask_type_idx = type_index
+        next_i = i
+        next_j = j
+      else:
+        type_index = -1
+      mask_type_idxs.append(type_index)
+      block_mask = block_mask.at[i, j].set(is_nonzero)
+      prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx)
+      prefetch_i = prefetch_i.at[i, j].set(next_i)
+      prefetch_j = prefetch_j.at[i, j].set(next_j)
+  return block_mask, prefetch_mask, prefetch_i, prefetch_j, jnp.stack(mask_data)
+```
+
++++ {"id": "w4b7ckKq67Xw"}
+
+In terms of the structure of the kernel, we use the same grid pattern as the standard matrix multiplication kernel we covered in previous tutorials with 3 loops over the `N`, `M`, and `K` dimensions. 
Within the kernel itself, we first check the `block_mask` to see if the mask for the current output block was all zeros. If the mask is all zeros, we can skip computation and move onto the next block; otherwise we need to compute the matrix multiplication and then mask the result. + +```{code-cell} +--- +executionInfo: + elapsed: 5374 + status: ok + timestamp: 1725919713252 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: 4YQ9OmbTCSjT +outputId: 2d752609-34f2-4059-e8ba-4d80afe8cb26 +--- +M = N = K = 16384 +blk_M = blk_N = 512 +blk_K = 1024 + +def sparse_mask_matmul( + block_mask_ref, prefetch_mask, prefetch_i, prefetch_j, # Scalar prefetch inputs. + x_ref, y_ref, mask_ref, o_ref, # Kernel inputs. + accum_scratch + ): + del prefetch_mask, prefetch_i, prefetch_j + i, j, k = pl.program_id(0), pl.program_id(1), pl.program_id(2) + should_compute = block_mask_ref[i, j] != 0 + @pl.when(k == 0) + def _(): + o_ref[...] = jnp.zeros_like(o_ref) + accum_scratch[...] = jnp.zeros_like(accum_scratch[...]) + + # We only compute the output for blocks with non-zero masks. + # Otherwise we skip the computation entirely. + @pl.when(should_compute) + def _(): + result = jnp.dot(x_ref[...], y_ref[...], preferred_element_type=jnp.float32) + accum_scratch[...] += result + @pl.when(k == pl.num_programs(2) - 1) + def _(): + o_ref[...] = (mask_ref[0, ...] * accum_scratch[...]).astype(o_ref.dtype) + +X = jax.random.normal(jax.random.key(0), shape=(M, K), dtype=jnp.bfloat16) +Y = jax.random.normal(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16) +mask = jnp.ones((M, N), dtype=jnp.int32) +mask = jnp.tril(mask) +block_mask, prefetch_mask, prefetch_i, prefetch_j, sparse_mask_data = sparsify_mask(mask, (blk_M, blk_N)) + +def x_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j): + del prefetch_mask, prefetch_j + # Zero-out the k index if the mask is zero, to avoid constantly fetching + # new blocks in the inner loop for blocks we are skipping. + k_fetch = (block_mask[i, j] != 0) * k + return (prefetch_i[i, j], k_fetch) + +def y_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j): + del prefetch_mask, prefetch_i + k_fetch = (block_mask[i, j] != 0) * k + return (k_fetch, prefetch_j[i, j]) + +def mask_map(i, j, k, block_mask, prefetch_mask, *_): + del k, block_mask + return (prefetch_mask[i, j], 0, 0) + +def o_map(i, j, k, *_): + del k + return (i, j) + +grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=4, + grid=(M // blk_M, N // blk_N, K // blk_K), + in_specs=[pl.BlockSpec((blk_M, blk_K), x_map), + pl.BlockSpec((blk_K, blk_N), y_map), + pl.BlockSpec((1, blk_M, blk_N), mask_map)], + out_specs=pl.BlockSpec((blk_M, blk_N), o_map), + scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)] +) +kernel = pl.pallas_call( + sparse_mask_matmul, + grid_spec=grid_spec, + out_shape=jax.ShapeDtypeStruct((M, N), jnp.bfloat16), +) +args = (block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data) +result = kernel(*args) + +ref = mask * (X @ Y) +diff = jnp.abs(ref - result) +print('mean |result - ref|:', jnp.mean(diff)) +``` + ++++ {"id": "uutNGgjZGGhB"} + +Now let's compare performance versus a naive dense implementation. On TPU v5 lite, we achieve around a ~1.8x speed increase with the sparse kernel, compared to a theoretical best-case of 2x from using a lower triangular mask and only visiting half of the possible outputs. 
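+
+As a rough sanity check on that 2x figure (illustrative arithmetic only), a lower-triangular mask keeps `n * (n + 1) / 2` of the `n**2` output blocks, where `n` is the number of block-rows:
+
+```python
+n = 16384 // 512                 # number of block-rows/columns in the output
+kept = n * (n + 1) // 2          # blocks with at least one non-zero mask entry
+print(f"fraction of blocks computed: {kept / n**2:.3f}")  # ~0.516
+print(f"ideal speedup: {n**2 / kept:.2f}x")               # ~1.94x
+```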
+
+We would generally expect performance to get closer to the theoretical peak as our inputs get larger, since the main reasons why we don't exactly reach theoretical performance are:
+- We skip slightly less than half of the computation, since the blocks along the diagonal are mixed 0s and 1s, and for mixed blocks we need to compute the entire block. With larger inputs, our overhead for mixed blocks becomes smaller relative to the overall computation.
+- The pipeline bubble also accounts for a smaller percentage of the overall runtime as inputs become larger.
+
+```{code-cell}
+---
+executionInfo:
+  elapsed: 8877
+  status: ok
+  timestamp: 1725917397452
+  user:
+    displayName: Justin Fu
+    userId: '17543197034567316452'
+  user_tz: 420
+id: MAT9JjGNvsx8
+outputId: a32d56fb-a71b-4007-c6a5-e5270dcaa6cf
+---
+N_trials = 100
+
+PallasImpl = lambda *args: kernel(*args)
+time = benchmark(PallasImpl, N_trials)(block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)
+print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, N_trials))
+
+RefImpl = jax.jit(lambda mask, x, y: mask * (x @ y))
+time = benchmark(RefImpl, N_trials)(mask, X, Y)
+print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, N_trials))
+```