diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb
index 5a8608f494c3..0e759a493a61 100644
--- a/docs/pallas/quickstart.ipynb
+++ b/docs/pallas/quickstart.ipynb
@@ -282,6 +282,35 @@
     "On TPUs, programs are executed in a combination of parallel and sequential\n",
     "(depending on the architecture) so there are slightly different considerations.\n",
     "\n",
+    "To call the above kernel on TPU, run:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "796f928c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from jax.experimental.pallas import tpu as pltpu\n",
+    "\n",
+    "def iota(size: int):\n",
+    "  return pl.pallas_call(iota_kernel,\n",
+    "                        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),\n",
+    "                        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n",
+    "                        grid=(size,))()\n",
+    "iota(8)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "68f97b4e",
+   "metadata": {},
+   "source": [
+    "TPUs distinguish between vector and scalar memory spaces and in this case the\n",
+    "output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n",
+    "a scalar. For more details read {ref}`pallas_tpu_pipelining`.\n",
+    "\n",
     "You can read more details at {ref}`pallas_grid`."
    ]
   },
diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md
index b8f9254f21d9..a8b13ea38eaf 100644
--- a/docs/pallas/quickstart.md
+++ b/docs/pallas/quickstart.md
@@ -188,6 +188,23 @@ operations like matrix multiplications really quickly.
 On TPUs, programs are executed in a combination of parallel and sequential
 (depending on the architecture) so there are slightly different considerations.
 
+To call the above kernel on TPU, run:
+
+```{code-cell} ipython3
+from jax.experimental.pallas import tpu as pltpu
+
+def iota(size: int):
+  return pl.pallas_call(iota_kernel,
+                        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
+                        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
+                        grid=(size,))()
+iota(8)
+```
+
+TPUs distinguish between vector and scalar memory spaces and in this case the
+output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is
+a scalar. For more details read {ref}`pallas_tpu_pipelining`.
+
 You can read more details at {ref}`pallas_grid`.
 
 +++
diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb
index 2a3aa9d114de..b5f2c652b5a5 100644
--- a/docs/pallas/tpu/pipelining.ipynb
+++ b/docs/pallas/tpu/pipelining.ipynb
@@ -1,5 +1,13 @@
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "7704d3bb",
+   "metadata": {},
+   "source": [
+    "(pallas_tpu_pipelining)="
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {
diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md
index 507eab658a39..19150b3832fa 100644
--- a/docs/pallas/tpu/pipelining.md
+++ b/docs/pallas/tpu/pipelining.md
@@ -11,6 +11,8 @@ kernelspec:
   name: python3
 ---
 
+(pallas_tpu_pipelining)=
+
 +++ {"id": "teoJ_fUwlu0l"}
 
 # Pipelining
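
For reference, a minimal self-contained sketch of what the added quickstart cell executes: the `iota_kernel` body is reproduced from earlier in the quickstart (each sequential grid step writes its program id into the output), and it assumes an actual TPU backend, since `TPUMemorySpace.SMEM` is TPU-specific.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def iota_kernel(o_ref):
  # Each of the `size` sequential grid steps writes its own index
  # into one element of the output ref.
  i = pl.program_id(0)
  o_ref[i] = i

def iota(size: int):
  return pl.pallas_call(iota_kernel,
                        # `i` is a scalar, so the output is placed in SMEM.
                        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
                        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
                        grid=(size,))()

iota(8)  # expected: Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
```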