diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb index 5a8608f494c3..0ad7bb6c21c8 100644 --- a/docs/pallas/quickstart.ipynb +++ b/docs/pallas/quickstart.ipynb @@ -282,6 +282,36 @@ "On TPUs, programs are executed in a combination of parallel and sequential\n", "(depending on the architecture) so there are slightly different considerations.\n", "\n", + "The TPU version of the above kernel is:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "796f928c", + "metadata": {}, + "outputs": [], + "source": [ + "from jax.experimental.pallas import tpu as pltpu\n", + "\n", + "def iota(size: int):\n", + " return pl.pallas_call(iota_kernel,\n", + " out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),\n", + " out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n", + " grid=(size,))()\n", + "iota(8)" + ] + }, + { + "cell_type": "markdown", + "id": "68f97b4e", + "metadata": {}, + "source": [ + "TPUs distinguish between vector and scalar memory spaces and in this case the\n", + "output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n", + "a scalar. For more details read\n", + "https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html.\n", + "\n", "You can read more details at {ref}`pallas_grid`." ] }, diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md index b8f9254f21d9..fa88133422aa 100644 --- a/docs/pallas/quickstart.md +++ b/docs/pallas/quickstart.md @@ -188,6 +188,24 @@ operations like matrix multiplications really quickly. On TPUs, programs are executed in a combination of parallel and sequential (depending on the architecture) so there are slightly different considerations. +The TPU version of the above kernel is: + +```{code-cell} ipython3 +from jax.experimental.pallas import tpu as pltpu + +def iota(size: int): + return pl.pallas_call(iota_kernel, + out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_shape=jax.ShapeDtypeStruct((size,), jnp.int32), + grid=(size,))() +iota(8) +``` + +TPUs distinguish between vector and scalar memory spaces and in this case the +output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is +a scalar. For more details read +https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html. + You can read more details at {ref}`pallas_grid`. +++