diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb
index 1315783c340c..aa355f471a20 100644
--- a/docs/notebooks/shard_map.ipynb
+++ b/docs/notebooks/shard_map.ipynb
@@ -510,6 +510,8 @@
     "the corresponding `PartitionSpec` `spec` as roughly\n",
     "`tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))`.\n",
     "\n",
+    "(shard_map_collectives_tutorial)=\n",
+    "\n",
     "## Collectives tutorial\n",
     "\n",
     "A `shard_map` need not be a pure map: function applications can communicate\n",
diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md
index 96667e709ac6..d77dec652068 100644
--- a/docs/notebooks/shard_map.md
+++ b/docs/notebooks/shard_map.md
@@ -357,6 +357,8 @@ from the shape `shape` of the corresponding argument to `shard_map`-of-`f` and
 the corresponding `PartitionSpec` `spec` as roughly
 `tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))`.
 
+(shard_map_collectives_tutorial)=
+
 ## Collectives tutorial
 
 A `shard_map` need not be a pure map: function applications can communicate
diff --git a/docs/pallas/grid_blockspec.md b/docs/pallas/grid_blockspec.md
index 267199128283..cde200528785 100644
--- a/docs/pallas/grid_blockspec.md
+++ b/docs/pallas/grid_blockspec.md
@@ -75,7 +75,8 @@ programs write to disjoint places in HBM to avoid these parallel writes.
 
 On TPUs, programs are executed in a combination of parallel and sequential
 (depending on the architecture) so there are slightly different considerations.
-See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#noteworthy-properties-and-restrictions).
+
+See {ref}`pallas_tpu_noteworthy_properties`.
 
 (pallas_blockspec)=
 
@@ -88,8 +89,7 @@ to *which block of our inputs and outputs to be operated on*. This is provided
 via {class}`jax.experimental.pallas.BlockSpec` objects.
 
 Before we get into the details of `BlockSpec`s, you may want
-to revisit the
-[Pallas Quickstart BlockSpecs example](https://jax.readthedocs.io/en/latest/pallas/quickstart.html#block-specs-by-example).
+to revisit {ref}`pallas_block_specs_by_example` in Pallas Quickstart.
 
 `BlockSpec`s are provided to `pallas_call` via the
 `in_specs` and `out_specs`, one for each input and output respectively.
@@ -239,7 +239,7 @@ The output shown below was generated on CPU using `interpret=True` mode, which
 at the moment executes the invocation sequentially.
 On TPUs, programs are executed in a combination of parallel and sequential,
 and this function generates the output shown.
-See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#noteworthy-properties-and-restrictions).
+See {ref}`pallas_tpu_noteworthy_properties`.
 
 ```python
 >>> show_program_ids(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2, 10),
diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb
index 0e759a493a61..50464ce8ffd4 100644
--- a/docs/pallas/quickstart.ipynb
+++ b/docs/pallas/quickstart.ipynb
@@ -319,6 +319,8 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
+    "(pallas_block_specs_by_example)=\n",
+    "\n",
     "### Block specs by example"
    ]
   },
diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md
index a8b13ea38eaf..b9acd6497fb5 100644
--- a/docs/pallas/quickstart.md
+++ b/docs/pallas/quickstart.md
@@ -209,6 +209,8 @@ You can read more details at {ref}`pallas_grid`.
 
 +++
 
+(pallas_block_specs_by_example)=
+
 ### Block specs by example
 
 +++
diff --git a/docs/pallas/tpu/details.rst b/docs/pallas/tpu/details.rst
index 4a2d4daa637f..b7ce10d564f6 100644
--- a/docs/pallas/tpu/details.rst
+++ b/docs/pallas/tpu/details.rst
@@ -59,6 +59,8 @@ ideas described transfer to later generations as well.
 * `TPU v4: An Optically Reconfigurable Supercomputer for Machine Learning with
   Hardware Support for Embeddings <https://arxiv.org/abs/2304.01433>`_
 
+.. _pallas_tpu_noteworthy_properties:
+
 Noteworthy properties and restrictions
 --------------------------------------
 
diff --git a/docs/pallas/tpu/distributed.ipynb b/docs/pallas/tpu/distributed.ipynb
index 8552e10d8552..95abf803a780 100644
--- a/docs/pallas/tpu/distributed.ipynb
+++ b/docs/pallas/tpu/distributed.ipynb
@@ -11,8 +11,8 @@
     "In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`.\n",
     "\n",
     "Some recommended readings beforehand:\n",
-    " - [Pallas Pipelining on TPU](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html)\n",
-    " - [Collectives with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#collectives-tutorial)"
+    " - [Pallas Pipelining on TPU](pallas_tpu_pipelining)\n",
+    " - [Collectives with `shard_map`](shard_map_collectives_tutorial)"
    ]
   },
   {
@@ -1703,7 +1703,7 @@
     "\n",
     "### Megacore\n",
     "\n",
-    "Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `\"parallel\"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.\n",
+    "Certain TPUs contain multiple cores in a [Megacore](pallas_tpu_megacore) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `\"parallel\"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.\n",
     "\n",
     "### Interaction with XLA\n",
     "\n",
diff --git a/docs/pallas/tpu/distributed.md b/docs/pallas/tpu/distributed.md
index fc3f929866bd..c71f75ec6040 100644
--- a/docs/pallas/tpu/distributed.md
+++ b/docs/pallas/tpu/distributed.md
@@ -20,8 +20,8 @@ kernelspec:
 In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`.
 
 Some recommended readings beforehand:
- - [Pallas Pipelining on TPU](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html)
- - [Collectives with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#collectives-tutorial)
+ - [Pallas Pipelining on TPU](pallas_tpu_pipelining)
+ - [Collectives with `shard_map`](shard_map_collectives_tutorial)
 
 ```{code-cell} ipython3
 ---
@@ -1516,7 +1516,7 @@ print(
 
 ### Megacore
 
-Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `"parallel"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.
+Certain TPUs contain multiple cores in a [Megacore](pallas_tpu_megacore) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `"parallel"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.
 
 ### Interaction with XLA
 
diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb
index b5f2c652b5a5..9774e08dcda8 100644
--- a/docs/pallas/tpu/pipelining.ipynb
+++ b/docs/pallas/tpu/pipelining.ipynb
@@ -645,6 +645,8 @@
     "id": "KvPFez9N8cKJ"
    },
    "source": [
+    "(pallas_tpu_megacore)=\n",
+    "\n",
     "## TPUs in Megacore configuration"
    ]
   },
diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md
index 19150b3832fa..21865430178d 100644
--- a/docs/pallas/tpu/pipelining.md
+++ b/docs/pallas/tpu/pipelining.md
@@ -436,6 +436,8 @@ dimensions.
 
 +++ {"id": "KvPFez9N8cKJ"}
 
+(pallas_tpu_megacore)=
+
 ## TPUs in Megacore configuration
 
 +++ {"id": "0f4HAVzQ8n71"}