[Pallas] [Docs] Replace full urls with label-based cross references #24091

Merged 1 commit on Oct 3, 2024
2 changes: 2 additions & 0 deletions docs/notebooks/shard_map.ipynb
@@ -510,6 +510,8 @@
"the corresponding `PartitionSpec` `spec` as roughly\n",
"`tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))`.\n",
"\n",
"(shard_map_collectives_tutorial)=\n",
"\n",
"## Collectives tutorial\n",
"\n",
"A `shard_map` need not be a pure map: function applications can communicate\n",
2 changes: 2 additions & 0 deletions docs/notebooks/shard_map.md
@@ -357,6 +357,8 @@ from the shape `shape` of the corresponding argument to `shard_map`-of-`f` and
the corresponding `PartitionSpec` `spec` as roughly
`tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))`.

(shard_map_collectives_tutorial)=

## Collectives tutorial

A `shard_map` need not be a pure map: function applications can communicate
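The hunk above quotes the rule for deriving per-shard block shapes from the global `shape` and the corresponding `PartitionSpec` `spec`. As a minimal illustrative sketch (not part of this PR), the snippet below evaluates that formula for a hypothetical 4x2 mesh; the `mesh_shape` dict stands in for `mesh.shape`, and the concrete sizes and axis names are made up for the example:

```python
# Sketch of the shard-shape formula quoted in the diff above.
shape = (8, 6)                 # global shape of the argument to shard_map-of-f
mesh_shape = {"x": 4, "y": 2}  # stand-in for mesh.shape: axis name -> axis size
spec = ("x", None)             # PartitionSpec: dim 0 split over "x", dim 1 replicated

# Each dimension is divided by the size of the mesh axis it is mapped to,
# or left unchanged when the spec entry is None.
shard_shape = tuple(
    sz // (1 if n is None else mesh_shape[n]) for sz, n in zip(shape, spec)
)
print(shard_shape)  # (2, 6)
```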
8 changes: 4 additions & 4 deletions docs/pallas/grid_blockspec.md
@@ -75,7 +75,8 @@ programs write to disjoint places in HBM to avoid these parallel writes.

On TPUs, programs are executed in a combination of parallel and sequential
(depending on the architecture) so there are slightly different considerations.
See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#noteworthy-properties-and-restrictions).

See {ref}`pallas_tpu_noteworthy_properties`.

(pallas_blockspec)=

@@ -88,8 +89,7 @@ to *which block of our inputs and outputs to be operated on*.
This is provided via {class}`jax.experimental.pallas.BlockSpec` objects.

Before we get into the details of `BlockSpec`s, you may want
to revisit the
[Pallas Quickstart BlockSpecs example](https://jax.readthedocs.io/en/latest/pallas/quickstart.html#block-specs-by-example).
to revisit {ref}`pallas_block_specs_by_example` in Pallas Quickstart.

`BlockSpec`s are provided to `pallas_call` via the
`in_specs` and `out_specs`, one for each input and output respectively.
@@ -239,7 +239,7 @@ The output shown below was generated on CPU using `interpret=True`
mode, which at the moment executes the invocation sequentially.
On TPUs, programs are executed in a combination of parallel and sequential,
and this function generates the output shown.
See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#noteworthy-properties-and-restrictions).
See {ref}`pallas_tpu_noteworthy_properties`.

```python
>>> show_program_ids(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2, 10),
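The hunk above refers to the docs' `show_program_ids` helper and to `interpret=True` execution. The following is a hedged, self-contained sketch of the same idea: a hypothetical `program_id_kernel` that writes each invocation's grid coordinates into its block, assuming the `pl.BlockSpec(block_shape, index_map)` signature of recent JAX versions. It is illustrative only and is not the helper used in the docs.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def program_id_kernel(o_ref):
    # Each grid invocation writes its (i, j) coordinates into its own block.
    i, j = pl.program_id(0), pl.program_id(1)
    o_ref[...] = jnp.full(o_ref.shape, 10 * i + j, dtype=jnp.int32)

out = pl.pallas_call(
    program_id_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 6), jnp.int32),
    grid=(4, 2),                                          # 4 x 2 invocations
    out_specs=pl.BlockSpec((2, 3), lambda i, j: (i, j)),  # one (2, 3) block each
    interpret=True,   # sequential CPU execution, as described in the hunk above
)()
print(out)
```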
2 changes: 2 additions & 0 deletions docs/pallas/quickstart.ipynb
@@ -319,6 +319,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"(pallas_block_specs_by_example)=\n",
"\n",
"### Block specs by example"
]
},
2 changes: 2 additions & 0 deletions docs/pallas/quickstart.md
@@ -209,6 +209,8 @@ You can read more details at {ref}`pallas_grid`.

+++

(pallas_block_specs_by_example)=

### Block specs by example

+++
2 changes: 2 additions & 0 deletions docs/pallas/tpu/details.rst
@@ -59,6 +59,8 @@ ideas described transfer to later generations as well.
* `TPU v4: An Optically Reconfigurable Supercomputer for Machine Learning with Hardware Support for Embeddings <https://dl.acm.org/doi/abs/10.1145/3579371.3589350>`_


.. _pallas_tpu_noteworthy_properties:

Noteworthy properties and restrictions
--------------------------------------

6 changes: 3 additions & 3 deletions docs/pallas/tpu/distributed.ipynb
@@ -11,8 +11,8 @@
"In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`.\n",
"\n",
"Some recommended readings beforehand:\n",
" - [Pallas Pipelining on TPU](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html)\n",
" - [Collectives with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#collectives-tutorial)"
" - [Pallas Pipelining on TPU](pallas_tpu_pipelining)\n",
" - [Collectives with `shard_map`](shard_map_collectives_tutorial)"
]
},
{
@@ -1703,7 +1703,7 @@
"\n",
"### Megacore\n",
"\n",
"Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `\"parallel\"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.\n",
"Certain TPUs contain multiple cores in a [Megacore](pallas_tpu_megacore) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `\"parallel\"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.\n",
"\n",
"### Interaction with XLA\n",
"\n",
6 changes: 3 additions & 3 deletions docs/pallas/tpu/distributed.md
@@ -20,8 +20,8 @@ kernelspec:
In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`.

Some recommended readings beforehand:
- [Pallas Pipelining on TPU](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html)
- [Collectives with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#collectives-tutorial)
- [Pallas Pipelining on TPU](pallas_tpu_pipelining)
- [Collectives with `shard_map`](shard_map_collectives_tutorial)

```{code-cell} ipython3
---
@@ -1516,7 +1516,7 @@ print(

### Megacore

Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `"parallel"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.
Certain TPUs contain multiple cores in a [Megacore](pallas_tpu_megacore) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `"parallel"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.

### Interaction with XLA

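The Megacore paragraph above describes gating work per core via `pl.program_id(axis)` and `@pl.when(core_index == i)`. Below is a minimal sketch of that gating pattern, runnable with the CPU interpreter; the kernel name is hypothetical and `num_cores` is hard-coded for illustration. On real Megacore hardware you would size the axis with `jax.devices()[0].num_cores` and also mark it `"parallel"` via the TPU dimension semantics, as the paragraph notes.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

num_cores = 2  # on real hardware: jax.devices()[0].num_cores

def core_gated_kernel(x_ref, o_ref):
    core_index = pl.program_id(0)  # grid axis 0 ranges over the cores

    # Only "core" 0 runs this branch; the other core skips it.
    @pl.when(core_index == 0)
    def _():
        o_ref[...] = x_ref[...] * 2.0

    # "Core" 1 handles the other branch.
    @pl.when(core_index == 1)
    def _():
        o_ref[...] = x_ref[...] + 1.0

x = jnp.ones((num_cores, 128), jnp.float32)
out = pl.pallas_call(
    core_gated_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    grid=(num_cores,),
    in_specs=[pl.BlockSpec((1, 128), lambda i: (i, 0))],
    out_specs=pl.BlockSpec((1, 128), lambda i: (i, 0)),
    interpret=True,  # CPU interpreter; on TPU, also mark this axis "parallel"
)(x)
```

In interpret mode the two branches simply run sequentially; on a Megacore TPU with the axis marked `"parallel"`, each branch would execute on its own core.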
2 changes: 2 additions & 0 deletions docs/pallas/tpu/pipelining.ipynb
@@ -645,6 +645,8 @@
"id": "KvPFez9N8cKJ"
},
"source": [
"(pallas_tpu_megacore)=\n",
"\n",
"## TPUs in Megacore configuration"
]
},
2 changes: 2 additions & 0 deletions docs/pallas/tpu/pipelining.md
@@ -436,6 +436,8 @@ dimensions.

+++ {"id": "KvPFez9N8cKJ"}

(pallas_tpu_megacore)=

## TPUs in Megacore configuration

+++ {"id": "0f4HAVzQ8n71"}