
Merge pull request #22520 from sharadmv:pallas-matmul-docs
PiperOrigin-RevId: 653777947
jax authors committed Jul 18, 2024
2 parents 6cc4298 + c504fa6 commit 984eca2
Showing 2 changed files with 51 additions and 37 deletions.
65 changes: 36 additions & 29 deletions docs/pallas/tpu/matmul.ipynb
@@ -6,7 +6,7 @@
"id": "-z6pOJwvn-_j"
},
"source": [
"# Pallas TPU - Your First Matmul\n",
"# Matrix Multiplication\n",
"\n",
"In this guide, we'll write a matrix multiplication routine using Pallas. We'll also go over how to think about matmul performance on TPU and how to template a matmul kernel to fuse in operations."
]
@@ -232,6 +232,9 @@
"\n",
"Let's think about how to analyze matrix multiplication performance.\n",
"\n",
"> Quick note: in this guide, we'll be discussing floating point operations, but want to make the distinction between FLOPs vs FLOP/s.\n",
" When we say \"FLOPs\" we mean \"floating point operations\", as in a number of operations. When we say \"FLOP/s\", we refer to \"floating point operations *per second*\", as in a *rate* of performing floating point operations.\n",
"\n",
"The number of FLOPs in a `(m, k) x (k, n)` matrix multiplication are (approximately) `2 * m * k * n`. (Technically it is `n * m * (2k - 1)` but for large enough `k` our approximation is sufficient.)\n",
"\n",
"The minimum amount of memory usage for a matrix multiply (assuming float32) is the total size of the inputs (copying into VMEM) plus the size of the output (copying into HBM). Thus the minimum bandwidth usage is `(m * k + k * n + m * n) * 4 bytes/float32`. Memory usage can be greater if we re-read the inputs multiple times, which is often the case.\n",
@@ -273,9 +276,9 @@
"id": "agCtb2GMQazl"
},
"source": [
"Now that we can calculate the amount of FLOPs and (minimum) memory bandwidth usage of a matrix multiplication, let's now see how many FLOP/s we can execute and how much memory bandwidth we actually have on a particular TPU.\n",
"Now that we can calculate the total number of FLOPs and (minimum) memory bandwidth usage of a matrix multiplication, let's now see how many FLOP/s we can execute and how much memory bandwidth we actually have on a particular TPU.\n",
"\n",
"This notebook was run on TPU v5e so we will utilize numbers for that specific chip (if you are running this notebook, your numbers may differ). TPU v5es have [197 TFLOPs of bf16 compute and 819 GB/s of memory bandwidth](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v5e). We can compute the minimum arithmetic intensity (ratio of FLOPs to memory bandwidth) needed to be compute bound in this analytical model on this chip by taking their ratio (about 240 FLOPs/byte on TPU v5e)."
"This notebook was run on TPU v5e so we will utilize numbers for that specific chip (if you are running this notebook, your numbers may differ). TPU v5es have [197 TFLOP/s of bf16 compute and 819 GB/s of memory bandwidth](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v5e). We can compute the minimum arithmetic intensity (ratio of FLOP/s to memory bandwidth) needed to be compute bound in this analytical model on this chip by taking their ratio (about 240 FLOPs/byte on TPU v5e)."
]
},
{
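
Concretely, the threshold works out as below (a sketch using the peak numbers quoted above; `v5e_membw` is our name for the bandwidth figure, while `v5e_bf16_flops` is the name the guide itself uses later):

```python
v5e_bf16_flops = 197e12  # peak bf16 FLOP/s on TPU v5e, from the spec linked above
v5e_membw = 819e9        # HBM bandwidth in bytes/s

# Minimum FLOPs a kernel must perform per byte moved to be compute bound:
print(v5e_bf16_flops / v5e_membw)  # ~240.5 FLOPs/byte

# At the whole-matrix scale, a 1024 x 1024 x 1024 bf16 matmul clears the bar:
flops = 2 * 1024 * 1024 * 1024     # 2 * m * k * n
bytes_moved = 3 * 1024 * 1024 * 2  # two inputs + one output, 2 bytes each
print(flops / bytes_moved)         # ~341 FLOPs/byte
```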
@@ -451,6 +454,8 @@
" grid=(m // bm, n // bn, k // bk),\n",
" ),\n",
" out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n",
" compiler_params=dict(mosaic=dict(\n",
" dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n",
" )(x, y)"
]
},
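
The `compiler_params` lines added here are the substantive change in this commit: they tell the Mosaic TPU compiler which grid dimensions it may treat as order-independent. Roughly (a sketch of the reasoning, not new API surface):

```python
# The grid is (m // bm, n // bn, k // bk). The first two dimensions index
# independent output blocks, so their iterations can run in any order
# ("parallel"). The last dimension accumulates partial products over k,
# so it carries a cross-iteration dependency and must run in order
# ("arbitrary").
compiler_params = dict(mosaic=dict(
    dimension_semantics=("parallel", "parallel", "arbitrary")))
```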
@@ -477,9 +482,9 @@
"source": [
"## Performance of pipelined kernels\n",
"\n",
"Our above analysis about FLOPS vs memory usage applies at a coarse scale i.e. when we are looking at the the size of a the total matrix multiplication. However, remember that in practice, we are pipelining the execution of a blocked matrix multiplication, meaning we have a loop in which we are doing matrix multiplication with smaller blocks.\n",
"Our above analysis about FLOPs vs memory usage applies at a coarse scale i.e. when we are looking at the the size of a the total matrix multiplication. However, remember that in practice, we are pipelining the execution of a blocked matrix multiplication, meaning we have a loop in which we are doing matrix multiplication with smaller blocks.\n",
"\n",
"This means that we actually care about the FLOPS vs memory bandwidth usage of each individual instance of the kernel, not the global FLOPS vs memory bandwidth usage. Therefore, the block sizes `bm`, `bk`, `bn` are extremely important for performance. Even if we have the largest matrices in the world, if we pick very small `bm`, `bk`, and `bn`, we will be memory bound because each time we invoke the kernel we will have too few FLOPs to hide the memory transfers happening in the background.\n",
"This means that we actually care about the FLOPs vs memory bandwidth usage of each individual instance of the kernel, not the global FLOPs vs memory bandwidth usage. Therefore, the block sizes `bm`, `bk`, `bn` are extremely important for performance. Even if we have the largest matrices in the world, if we pick very small `bm`, `bk`, and `bn`, we will be memory bound because each time we invoke the kernel we will have too few FLOPs to hide the memory transfers happening in the background.\n",
"\n",
"The intuition should therefore be: to be compute bound, make the blocks as big as possible! There are two main constraints:\n",
"\n",
@@ -507,33 +512,33 @@
"----- 1024 x 1024 x 1024 -----\n",
"Matmul time: 0.00029766598949208854\n",
"Matmul FLOP/s: 7214407167121.377\n",
"FLOPs utilization: 3.6621%\n",
"FLOP/s utilization: 3.6621%\n",
"\n",
"----- 4096 x 4096 x 4096 -----\n",
"Matmul time: 0.011771515250438824\n",
"Matmul FLOP/s: 11675553278230.387\n",
"FLOPs utilization: 5.9267%\n",
"FLOP/s utilization: 5.9267%\n",
"\n",
"----- 8192 x 8192 x 8192 -----\n",
"Matmul time: 0.09183577066054567\n",
"Matmul FLOP/s: 11972585626140.668\n",
"FLOPs utilization: 6.0775%\n",
"FLOP/s utilization: 6.0775%\n",
"\n",
"================bm=512, bk=1024, bn=1024===================\n",
"----- 1024 x 1024 x 1024 -----\n",
"Matmul time: 0.00012708659982308746\n",
"Matmul FLOP/s: 16897797651282.135\n",
"FLOPs utilization: 8.5776%\n",
"FLOP/s utilization: 8.5776%\n",
"\n",
"----- 4096 x 4096 x 4096 -----\n",
"Matmul time: 0.00088908776990138\n",
"Matmul FLOP/s: 154584235803001.88\n",
"FLOPs utilization: 78.4692%\n",
"FLOP/s utilization: 78.4692%\n",
"\n",
"----- 8192 x 8192 x 8192 -----\n",
"Matmul time: 0.006099433819763363\n",
"Matmul FLOP/s: 180264539343531.62\n",
"FLOPs utilization: 91.5048%\n",
"FLOP/s utilization: 91.5048%\n",
"\n"
]
}
@@ -562,7 +567,7 @@
" print(\"Matmul time: \", time)\n",
" mm_flops = matmul_flops(m, k, n) / time\n",
" print(\"Matmul FLOP/s: \", mm_flops)\n",
" print(f\"FLOPs utilization: {mm_flops / v5e_bf16_flops * 100:.4f}%\")\n",
" print(f\"FLOP/s utilization: {mm_flops / v5e_bf16_flops * 100:.4f}%\")\n",
" print()\n",
"\n",
"print(\"================bm=128, bk=128, bn=128===================\")\n",
@@ -606,17 +611,17 @@
"----- 1024 x 1024 x 1024 -----\n",
"Matmul time: 0.00011943008983507753\n",
"Matmul FLOP/s: 17981093801113.996\n",
"FLOPs utilization: 9.1275%\n",
"FLOP/s utilization: 9.1275%\n",
"\n",
"----- 4096 x 4096 x 4096 -----\n",
"Matmul time: 0.0008272899803705514\n",
"Matmul FLOP/s: 166131533963991.34\n",
"FLOPs utilization: 84.3307%\n",
"FLOP/s utilization: 84.3307%\n",
"\n",
"----- 8192 x 8192 x 8192 -----\n",
"Matmul time: 0.006047147869830951\n",
"Matmul FLOP/s: 181823175395037.44\n",
"FLOPs utilization: 92.2960%\n",
"FLOP/s utilization: 92.2960%\n",
"\n"
]
}
@@ -722,6 +727,8 @@
" grid=(m // bm, n // bn, k // bk),\n",
" ),\n",
" out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n",
" compiler_params=dict(mosaic=dict(\n",
" dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n",
" )(x, y)"
]
},
@@ -755,33 +762,33 @@
"----- 1024 x 1024 x 1024 -----\n",
"Matmul time: 0.0003029372810851783\n",
"Matmul FLOP/s: 7088872126624.065\n",
"FLOPs utilization: 3.5984%\n",
"FLOP/s utilization: 3.5984%\n",
"\n",
"----- 4096 x 4096 x 4096 -----\n",
"Matmul time: 0.012017967159627005\n",
"Matmul FLOP/s: 11436123235026.848\n",
"FLOPs utilization: 5.8051%\n",
"FLOP/s utilization: 5.8051%\n",
"\n",
"----- 8192 x 8192 x 8192 -----\n",
"Matmul time: 0.09500920018996112\n",
"Matmul FLOP/s: 11572685861765.383\n",
"FLOPs utilization: 5.8745%\n",
"FLOP/s utilization: 5.8745%\n",
"\n",
"================bm=512, bk=1024, bn=1024===================\n",
"----- 1024 x 1024 x 1024 -----\n",
"Matmul time: 0.00012131539988331496\n",
"Matmul FLOP/s: 17701657415839.363\n",
"FLOPs utilization: 8.9856%\n",
"FLOP/s utilization: 8.9856%\n",
"\n",
"----- 4096 x 4096 x 4096 -----\n",
"Matmul time: 0.0008790623804088682\n",
"Matmul FLOP/s: 156347213275211.03\n",
"FLOPs utilization: 79.3641%\n",
"FLOP/s utilization: 79.3641%\n",
"\n",
"----- 8192 x 8192 x 8192 -----\n",
"Matmul time: 0.006107717020204291\n",
"Matmul FLOP/s: 180020067095253.78\n",
"FLOPs utilization: 91.3807%\n",
"FLOP/s utilization: 91.3807%\n",
"\n"
]
}
@@ -804,7 +811,7 @@
" print(\"Matmul time: \", time)\n",
" mm_flops = matmul_flops(m, k, n) / time\n",
" print(\"Matmul FLOP/s: \", mm_flops)\n",
" print(f\"FLOPs utilization: {mm_flops / v5e_bf16_flops * 100:.4f}%\")\n",
" print(f\"FLOP/s utilization: {mm_flops / v5e_bf16_flops * 100:.4f}%\")\n",
" print()\n",
"\n",
"print(\"================bm=128, bk=128, bn=128===================\")\n",
@@ -929,33 +936,33 @@
"----- 1024 x 1024 x 1024 -----\n",
"Matmul time: 0.00030103540048003196\n",
"Matmul FLOP/s: 7133658182976.541\n",
"FLOPs utilization: 3.6211%\n",
"FLOP/s utilization: 3.6211%\n",
"\n",
"----- 4096 x 4096 x 4096 -----\n",
"Matmul time: 0.011807117109419778\n",
"Matmul FLOP/s: 11640348122095.826\n",
"FLOPs utilization: 5.9088%\n",
"FLOP/s utilization: 5.9088%\n",
"\n",
"----- 8192 x 8192 x 8192 -----\n",
"Matmul time: 0.09181861146935262\n",
"Matmul FLOP/s: 11974823079773.941\n",
"FLOPs utilization: 6.0786%\n",
"FLOP/s utilization: 6.0786%\n",
"\n",
"================bm=512, bk=1024, bn=1024===================\n",
"----- 1024 x 1024 x 1024 -----\n",
"Matmul time: 0.00012622540001757442\n",
"Matmul FLOP/s: 17013086492108.6\n",
"FLOPs utilization: 8.6361%\n",
"FLOP/s utilization: 8.6361%\n",
"\n",
"----- 4096 x 4096 x 4096 -----\n",
"Matmul time: 0.000896632740041241\n",
"Matmul FLOP/s: 153283442968721.44\n",
"FLOPs utilization: 77.8089%\n",
"FLOP/s utilization: 77.8089%\n",
"\n",
"----- 8192 x 8192 x 8192 -----\n",
"Matmul time: 0.006130605939542875\n",
"Matmul FLOP/s: 179347953304919.88\n",
"FLOPs utilization: 91.0396%\n",
"FLOP/s utilization: 91.0396%\n",
"\n"
]
}
@@ -979,7 +986,7 @@
" print(\"Matmul time: \", time)\n",
" mm_flops = matmul_flops(m, k, n) / time\n",
" print(\"Matmul FLOP/s: \", mm_flops)\n",
" print(f\"FLOPs utilization: {mm_flops / v5e_bf16_flops * 100:.4f}%\")\n",
" print(f\"FLOP/s utilization: {mm_flops / v5e_bf16_flops * 100:.4f}%\")\n",
" print()\n",
"\n",
"\n",
23 changes: 15 additions & 8 deletions docs/pallas/tpu/matmul.md
@@ -14,7 +14,7 @@ kernelspec:

+++ {"id": "-z6pOJwvn-_j"}

-# Pallas TPU - Your First Matmul
+# Matrix Multiplication

In this guide, we'll write a matrix multiplication routine using Pallas. We'll also go over how to think about matmul performance on TPU and how to template a matmul kernel to fuse in operations.

@@ -180,6 +180,9 @@ np.testing.assert_array_equal(x @ y, matmul(x, y))

Let's think about how to analyze matrix multiplication performance.

+> Quick note: in this guide, we'll be discussing floating point operations, but we first want to make the distinction between FLOPs and FLOP/s.
+When we say "FLOPs" we mean "floating point operations", as in a number of operations. When we say "FLOP/s", we refer to "floating point operations *per second*", as in a *rate* of performing floating point operations.

The number of FLOPs in a `(m, k) x (k, n)` matrix multiplication is (approximately) `2 * m * k * n`. (Technically it is `n * m * (2k - 1)` but for large enough `k` our approximation is sufficient.)

The minimum amount of memory usage for a matrix multiply (assuming float32) is the total size of the inputs (copying into VMEM) plus the size of the output (copying into HBM). Thus the minimum bandwidth usage is `(m * k + k * n + m * n) * 4 bytes/float32`. Memory usage can be greater if we re-read the inputs multiple times, which is often the case.
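
Together, these two quantities give a roofline-style lower bound on runtime. A sketch under the same analytical model (TPU v5e peak numbers assumed; real kernels will be somewhat slower):

```python
def min_matmul_time(m: int, k: int, n: int, bytes_per_elem: int,
                    peak_flops: float = 197e12, membw: float = 819e9) -> float:
  # A matmul can finish no faster than either its compute time or its
  # minimum memory-transfer time.
  flop_time = 2 * m * k * n / peak_flops
  membw_time = (m * k + k * n + m * n) * bytes_per_elem / membw
  return max(flop_time, membw_time)

print(min_matmul_time(8192, 8192, 8192, 2))  # ~0.0056 s for a bf16 8192^3 matmul
```

This agrees with the ~6 ms timings in the benchmark outputs shown earlier (about 91–92% utilization of the bound).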
@@ -202,9 +205,9 @@ print(matmul_membw(1024, 1024, 1024, jnp.float32))

+++ {"id": "agCtb2GMQazl"}

-Now that we can calculate the amount of FLOPs and (minimum) memory bandwidth usage of a matrix multiplication, let's now see how many FLOP/s we can execute and how much memory bandwidth we actually have on a particular TPU.
+Now that we can calculate the total number of FLOPs and (minimum) memory bandwidth usage of a matrix multiplication, let's see how many FLOP/s we can execute and how much memory bandwidth we actually have on a particular TPU.

-This notebook was run on TPU v5e so we will utilize numbers for that specific chip (if you are running this notebook, your numbers may differ). TPU v5es have [197 TFLOPs of bf16 compute and 819 GB/s of memory bandwidth](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v5e). We can compute the minimum arithmetic intensity (ratio of FLOPs to memory bandwidth) needed to be compute bound in this analytical model on this chip by taking their ratio (about 240 FLOPs/byte on TPU v5e).
+This notebook was run on TPU v5e so we will utilize numbers for that specific chip (if you are running this notebook, your numbers may differ). TPU v5es have [197 TFLOP/s of bf16 compute and 819 GB/s of memory bandwidth](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v5e). We can compute the minimum arithmetic intensity (ratio of FLOP/s to memory bandwidth) needed to be compute bound in this analytical model on this chip by taking their ratio (about 240 FLOPs/byte on TPU v5e).

```{code-cell} ipython3
:id: WUydNX2-K6Oy
@@ -308,6 +311,8 @@ def matmul(
grid=(m // bm, n // bn, k // bk),
),
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
+      compiler_params=dict(mosaic=dict(
+          dimension_semantics=("parallel", "parallel", "arbitrary"))),
)(x, y)
```
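
A hypothetical usage sketch (this assumes `matmul` exposes `bm`/`bk`/`bn` as keyword arguments, as the benchmark section headers suggest):

```python
import jax
import jax.numpy as jnp
import numpy as np

k1, k2 = jax.random.split(jax.random.key(0), 2)
x = jax.random.normal(k1, (4096, 4096), dtype=jnp.bfloat16)
y = jax.random.normal(k2, (4096, 4096), dtype=jnp.bfloat16)

# `matmul` is the Pallas kernel defined in the cell above (kwargs assumed).
z = matmul(x, y, bm=512, bk=1024, bn=1024)
np.testing.assert_array_equal(x @ y, z)
```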

@@ -325,9 +330,9 @@ np.testing.assert_array_equal(x @ y, matmul(x, y))

## Performance of pipelined kernels

-Our above analysis about FLOPS vs memory usage applies at a coarse scale i.e. when we are looking at the the size of a the total matrix multiplication. However, remember that in practice, we are pipelining the execution of a blocked matrix multiplication, meaning we have a loop in which we are doing matrix multiplication with smaller blocks.
+Our above analysis about FLOPs vs memory usage applies at a coarse scale, i.e., when we are looking at the size of the total matrix multiplication. However, remember that in practice, we are pipelining the execution of a blocked matrix multiplication, meaning we have a loop in which we are doing matrix multiplication with smaller blocks.

-This means that we actually care about the FLOPS vs memory bandwidth usage of each individual instance of the kernel, not the global FLOPS vs memory bandwidth usage. Therefore, the block sizes `bm`, `bk`, `bn` are extremely important for performance. Even if we have the largest matrices in the world, if we pick very small `bm`, `bk`, and `bn`, we will be memory bound because each time we invoke the kernel we will have too few FLOPs to hide the memory transfers happening in the background.
+This means that we actually care about the FLOPs vs memory bandwidth usage of each individual instance of the kernel, not the global FLOPs vs memory bandwidth usage. Therefore, the block sizes `bm`, `bk`, `bn` are extremely important for performance. Even if we have the largest matrices in the world, if we pick very small `bm`, `bk`, and `bn`, we will be memory bound because each time we invoke the kernel we will have too few FLOPs to hide the memory transfers happening in the background.

The intuition should therefore be: to be compute bound, make the blocks as big as possible! There are two main constraints:

@@ -365,7 +370,7 @@ def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
print("Matmul time: ", time)
mm_flops = matmul_flops(m, k, n) / time
print("Matmul FLOP/s: ", mm_flops)
print(f"FLOPs utilization: {mm_flops / v5e_bf16_flops * 100:.4f}%")
print(f"FLOP/s utilization: {mm_flops / v5e_bf16_flops * 100:.4f}%")
print()
print("================bm=128, bk=128, bn=128===================")
@@ -472,6 +477,8 @@ def matmul(
grid=(m // bm, n // bn, k // bk),
),
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
+      compiler_params=dict(mosaic=dict(
+          dimension_semantics=("parallel", "parallel", "arbitrary"))),
)(x, y)
```

@@ -505,7 +512,7 @@ def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
print("Matmul time: ", time)
mm_flops = matmul_flops(m, k, n) / time
print("Matmul FLOP/s: ", mm_flops)
print(f"FLOPs utilization: {mm_flops / v5e_bf16_flops * 100:.4f}%")
print(f"FLOP/s utilization: {mm_flops / v5e_bf16_flops * 100:.4f}%")
print()
print("================bm=128, bk=128, bn=128===================")
@@ -621,7 +628,7 @@ def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
print("Matmul time: ", time)
mm_flops = matmul_flops(m, k, n) / time
print("Matmul FLOP/s: ", mm_flops)
print(f"FLOPs utilization: {mm_flops / v5e_bf16_flops * 100:.4f}%")
print(f"FLOP/s utilization: {mm_flops / v5e_bf16_flops * 100:.4f}%")
print()
