[Pallas TPU] Add a note in the Pallas Quickstart documentation about how to run the existing example on TPU, and explain the difference in memory handling between TPUs and GPUs

This fixes #22817

This change was originally proposed by @justinjfu in the comments of the above issue.

This PR is related to #23885.

PiperOrigin-RevId: 678420788
ayaka14732 authored and Google-ML-Automation committed Sep 26, 2024
1 parent 9f4e8d0 commit d7bd3a3
Showing 1 changed file with 18 additions and 0 deletions.
docs/pallas/quickstart.md

@@ -188,6 +188,24 @@ operations like matrix multiplications really quickly.
On TPUs, programs are executed in a combination of parallel and sequential
fashion (depending on the architecture), so there are slightly different
considerations.

The TPU version of the above kernel is:

```{code-cell} ipython3
from jax.experimental.pallas import tpu as pltpu

def iota(size: int):
  return pl.pallas_call(iota_kernel,
                        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
                        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
                        grid=(size,))()

iota(8)
```
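The snippet above relies on `iota_kernel`, `pl`, `jax`, and `jnp` as defined earlier in the quickstart. As a point of reference for what the sequential grid execution computes, here is a pure-NumPy model (a sketch for illustration, not Pallas itself): each of the `size` grid steps writes its own program index into the output.

```python
import numpy as np

def iota_reference(size: int) -> np.ndarray:
    """Pure-NumPy model of the iota kernel's grid semantics."""
    out = np.empty((size,), dtype=np.int32)
    # On TPU the grid is walked sequentially; step i corresponds to one
    # kernel invocation in which the program index is i.
    for i in range(size):
        out[i] = i  # mirrors the kernel body: o_ref[i] = i
    return out

print(iota_reference(8))  # [0 1 2 3 4 5 6 7]
```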

TPUs distinguish between vector and scalar memory spaces, and in this case the
output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is
a scalar. For more details, read
https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html.

You can read more details at {ref}`pallas_grid`.

+++
