From ab4590ce0a415d750ea92d62701fdba17165b297 Mon Sep 17 00:00:00 2001
From: Ayaka
Date: Fri, 27 Sep 2024 01:30:21 -0700
Subject: [PATCH] [Pallas TPU] Add a note to the Pallas Quickstart
 documentation with instructions for running the existing example on TPU

This fixes https://github.com/jax-ml/jax/issues/22817

This change was originally proposed by @justinjfu in the comments of the
above issue.

This PR is related to https://github.com/jax-ml/jax/pull/23885.

PiperOrigin-RevId: 679487218
---
 docs/pallas/quickstart.ipynb     | 29 +++++++++++++++++++++++++++++
 docs/pallas/quickstart.md        | 17 +++++++++++++++++
 docs/pallas/tpu/pipelining.ipynb |  8 ++++++++
 docs/pallas/tpu/pipelining.md    |  2 ++
 4 files changed, 56 insertions(+)

diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb
index 5a8608f494c3..0e759a493a61 100644
--- a/docs/pallas/quickstart.ipynb
+++ b/docs/pallas/quickstart.ipynb
@@ -282,6 +282,35 @@
     "On TPUs, programs are executed in a combination of parallel and sequential\n",
     "(depending on the architecture) so there are slightly different considerations.\n",
     "\n",
+    "To call the above kernel on TPU, run:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "796f928c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from jax.experimental.pallas import tpu as pltpu\n",
+    "\n",
+    "def iota(size: int):\n",
+    "  return pl.pallas_call(iota_kernel,\n",
+    "                        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),\n",
+    "                        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n",
+    "                        grid=(size,))()\n",
+    "iota(8)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "68f97b4e",
+   "metadata": {},
+   "source": [
+    "TPUs distinguish between vector and scalar memory spaces and in this case the\n",
+    "output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n",
+    "a scalar. For more details read {ref}`pallas_tpu_pipelining`.\n",
+    "\n",
     "You can read more details at {ref}`pallas_grid`."
   ]
  },
diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md
index b8f9254f21d9..a8b13ea38eaf 100644
--- a/docs/pallas/quickstart.md
+++ b/docs/pallas/quickstart.md
@@ -188,6 +188,23 @@ operations like matrix multiplications really quickly.
 On TPUs, programs are executed in a combination of parallel and sequential
 (depending on the architecture) so there are slightly different considerations.
 
+To call the above kernel on TPU, run:
+
+```{code-cell} ipython3
+from jax.experimental.pallas import tpu as pltpu
+
+def iota(size: int):
+  return pl.pallas_call(iota_kernel,
+                        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
+                        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
+                        grid=(size,))()
+iota(8)
+```
+
+TPUs distinguish between vector and scalar memory spaces and in this case the
+output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is
+a scalar. For more details read {ref}`pallas_tpu_pipelining`.
+
 You can read more details at {ref}`pallas_grid`.
 
 +++
diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb
index 2a3aa9d114de..b5f2c652b5a5 100644
--- a/docs/pallas/tpu/pipelining.ipynb
+++ b/docs/pallas/tpu/pipelining.ipynb
@@ -1,5 +1,13 @@
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "7704d3bb",
+   "metadata": {},
+   "source": [
+    "(pallas_tpu_pipelining)="
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {
diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md
index 507eab658a39..19150b3832fa 100644
--- a/docs/pallas/tpu/pipelining.md
+++ b/docs/pallas/tpu/pipelining.md
@@ -11,6 +11,8 @@ kernelspec:
   name: python3
 ---
 
+(pallas_tpu_pipelining)=
+
 +++ {"id": "teoJ_fUwlu0l"}
 
 # Pipelining