[Pallas] [Docs] Replace full URLs with label-based cross references
This PR uses the same method to add cross references as the previous PR #23889.

---

The content below is kept for future reference.

#### Useful commands

Build documentation:

```sh
sphinx-build -b html -D nb_execution_mode=off docs docs/build/html -j auto
```
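The same build can also be run in Sphinx's nit-picky mode to surface any cross reference that fails to resolve (optional; not part of this PR's workflow, just a suggestion):

```sh
# -n (nit-picky) makes Sphinx warn about every cross reference
# that does not resolve to a known label.
sphinx-build -b html -n -D nb_execution_mode=off docs docs/build/html -j auto
```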

Create a label in *.md:

```md
(pallas_block_specs_by_example)=
```
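The label is placed directly before the heading it targets, as in the quickstart.md change below (illustrative snippet):

```md
(pallas_block_specs_by_example)=

### Block specs by example
```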

Create a label in *.rst:

```rst
.. _pallas_tpu_noteworthy_properties:
```

Reference a label in *.md:

```md
{ref}`pallas_block_specs_by_example`
```
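For completeness, the equivalent syntax for referencing a label from an *.rst file (not needed in this PR, listed only for future reference):

```rst
:ref:`pallas_tpu_noteworthy_properties`
```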

PiperOrigin-RevId: 681675621
ayaka14732 authored and Google-ML-Automation committed Oct 3, 2024
1 parent 5a2e5a5 commit ee6295b
Showing 8 changed files with 18 additions and 8 deletions.
8 changes: 4 additions & 4 deletions docs/pallas/grid_blockspec.md
@@ -75,7 +75,8 @@ programs write to disjoint places in HBM to avoid these parallel writes.

On TPUs, programs are executed in a combination of parallel and sequential
(depending on the architecture) so there are slightly different considerations.
See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#noteworthy-properties-and-restrictions).

See {ref}`pallas_tpu_noteworthy_properties`.

(pallas_blockspec)=

@@ -88,8 +89,7 @@ to *which block of our inputs and outputs to be operated on*.
This is provided via {class}`jax.experimental.pallas.BlockSpec` objects.

Before we get into the details of `BlockSpec`s, you may want
to revisit the
[Pallas Quickstart BlockSpecs example](https://jax.readthedocs.io/en/latest/pallas/quickstart.html#block-specs-by-example).
to revisit {ref}`pallas_block_specs_by_example` in Pallas Quickstart.

`BlockSpec`s are provided to `pallas_call` via the
`in_specs` and `out_specs`, one for each input and output respectively.
@@ -239,7 +239,7 @@ The output shown below was generated on CPU using `interpret=True`
mode, which at the moment executes the invocation sequentially.
On TPUs, programs are executed in a combination of parallel and sequential,
and this function generates the output shown.
See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#noteworthy-properties-and-restrictions).
See {ref}`pallas_tpu_noteworthy_properties`.

```python
>>> show_program_ids(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2, 10),
2 changes: 2 additions & 0 deletions docs/pallas/quickstart.ipynb
@@ -319,6 +319,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"(pallas_block_specs_by_example)=\n",
"\n",
"### Block specs by example"
]
},
2 changes: 2 additions & 0 deletions docs/pallas/quickstart.md
@@ -209,6 +209,8 @@ You can read more details at {ref}`pallas_grid`.

+++

(pallas_block_specs_by_example)=

### Block specs by example

+++
2 changes: 2 additions & 0 deletions docs/pallas/tpu/details.rst
@@ -59,6 +59,8 @@ ideas described transfer to later generations as well.
* `TPU v4: An Optically Reconfigurable Supercomputer for Machine Learning with Hardware Support for Embeddings <https://dl.acm.org/doi/abs/10.1145/3579371.3589350>`_


.. _pallas_tpu_noteworthy_properties:

Noteworthy properties and restrictions
--------------------------------------

4 changes: 2 additions & 2 deletions docs/pallas/tpu/distributed.ipynb
@@ -11,7 +11,7 @@
"In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`.\n",
"\n",
"Some recommended readings beforehand:\n",
" - [Pallas Pipelining on TPU](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html)\n",
" - [Pallas Pipelining on TPU](pallas_tpu_pipelining)\n",
" - [Collectives with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#collectives-tutorial)"
]
},
@@ -1703,7 +1703,7 @@
"\n",
"### Megacore\n",
"\n",
"Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `\"parallel\"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.\n",
"Certain TPUs contain multiple cores in a [Megacore](pallas_tpu_megacore) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `\"parallel\"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.\n",
"\n",
"### Interaction with XLA\n",
"\n",
4 changes: 2 additions & 2 deletions docs/pallas/tpu/distributed.md
@@ -20,7 +20,7 @@ kernelspec:
In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`.

Some recommended readings beforehand:
- [Pallas Pipelining on TPU](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html)
- [Pallas Pipelining on TPU](pallas_tpu_pipelining)
- [Collectives with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#collectives-tutorial)

```{code-cell} ipython3
@@ -1516,7 +1516,7 @@ print(

### Megacore

Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `"parallel"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.
Certain TPUs contain multiple cores in a [Megacore](pallas_tpu_megacore) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `"parallel"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.

### Interaction with XLA

2 changes: 2 additions & 0 deletions docs/pallas/tpu/pipelining.ipynb
@@ -645,6 +645,8 @@
"id": "KvPFez9N8cKJ"
},
"source": [
"(pallas_tpu_megacore)=\n",
"\n",
"## TPUs in Megacore configuration"
]
},
2 changes: 2 additions & 0 deletions docs/pallas/tpu/pipelining.md
@@ -436,6 +436,8 @@ dimensions.

+++ {"id": "KvPFez9N8cKJ"}

(pallas_tpu_megacore)=

## TPUs in Megacore configuration

+++ {"id": "0f4HAVzQ8n71"}
