[Pallas TPU] Add a note in the Pallas Quickstart documentation about the instructions for running the existing example on TPU

This fixes #22817

This change was originally proposed by @justinjfu in the comments of the above issue.

This PR is related to #23885.

PiperOrigin-RevId: 679487218
ayaka14732 authored and Google-ML-Automation committed Sep 27, 2024
1 parent 5a1549c commit ab4590c
Showing 4 changed files with 56 additions and 0 deletions.
29 changes: 29 additions & 0 deletions docs/pallas/quickstart.ipynb
@@ -282,6 +282,35 @@
"On TPUs, programs are executed in a combination of parallel and sequential\n",
"(depending on the architecture) so there are slightly different considerations.\n",
"\n",
"To call the above kernel on TPU, run:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "796f928c",
"metadata": {},
"outputs": [],
"source": [
"from jax.experimental.pallas import tpu as pltpu\n",
"\n",
"def iota(size: int):\n",
" return pl.pallas_call(iota_kernel,\n",
" out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),\n",
" out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n",
" grid=(size,))()\n",
"iota(8)"
]
},
{
"cell_type": "markdown",
"id": "68f97b4e",
"metadata": {},
"source": [
"TPUs distinguish between vector and scalar memory spaces and in this case the\n",
"output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n",
"a scalar. For more details read {ref}`pallas_tpu_pipelining`.\n",
"\n",
"You can read more details at {ref}`pallas_grid`."
]
},
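The new cell calls `iota_kernel`, which is defined earlier in the quickstart and is not visible in this hunk. For context, here is a minimal self-contained sketch of the full example; the kernel body and the imports are assumed from the surrounding quickstart, not taken from this diff:

```python
# Sketch of running the quickstart's iota example on TPU.
# The kernel body below is assumed from earlier in the quickstart;
# only the pallas_call wrapper comes from this commit's diff.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def iota_kernel(o_ref):
    # One grid step per output element: write the step index into slot i.
    i = pl.program_id(0)
    o_ref[i] = i

def iota(size: int):
    # Each step writes a single scalar, so the output is placed in SMEM,
    # the TPU's scalar memory space.
    return pl.pallas_call(
        iota_kernel,
        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
        grid=(size,),
    )()

iota(8)  # expected: Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
```

Each of the `size` grid steps writes its `program_id` into one scalar slot of the output, which is what the added note means by placing the output in scalar memory.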
17 changes: 17 additions & 0 deletions docs/pallas/quickstart.md
@@ -188,6 +188,23 @@ operations like matrix multiplications really quickly.
On TPUs, programs are executed in a combination of parallel and sequential
(depending on the architecture) so there are slightly different considerations.

To call the above kernel on TPU, run:

```{code-cell} ipython3
from jax.experimental.pallas import tpu as pltpu
def iota(size: int):
return pl.pallas_call(iota_kernel,
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
grid=(size,))()
iota(8)
```

TPUs distinguish between vector and scalar memory spaces and in this case the
output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is
a scalar. For more details read {ref}`pallas_tpu_pipelining`.

You can read more details at {ref}`pallas_grid`.

+++
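The added note contrasts scalar and vector memory on TPU. For comparison, a kernel whose output is a whole vector block can stay in the default (vector) memory space and needs no `memory_space` override; the following is only a sketch in the spirit of the quickstart's earlier `add_vectors` example, and its kernel name and body are assumptions rather than part of this diff:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_vectors_kernel(x_ref, y_ref, o_ref):
    # Reads and writes whole blocks, so the default (vector) memory space
    # is sufficient; no memory_space= override is needed.
    # (Assumed example, not taken from this diff.)
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

add_vectors(jnp.arange(8), jnp.arange(8))  # expected: [0, 2, 4, ..., 14]
```

The iota example above is the opposite case: each grid step produces a single scalar index, so its output has to live in `TPUMemorySpace.SMEM`, as the added note explains.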
8 changes: 8 additions & 0 deletions docs/pallas/tpu/pipelining.ipynb
@@ -1,5 +1,13 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "7704d3bb",
"metadata": {},
"source": [
"(pallas_tpu_pipelining)="
]
},
{
"cell_type": "markdown",
"metadata": {
2 changes: 2 additions & 0 deletions docs/pallas/tpu/pipelining.md
@@ -11,6 +11,8 @@ kernelspec:
name: python3
---

(pallas_tpu_pipelining)=

+++ {"id": "teoJ_fUwlu0l"}

# Pipelining
