Commit

Improve explanation
gramalingam committed Nov 20, 2024
1 parent 9123089 commit 36689e5
Showing 2 changed files with 38 additions and 13 deletions.
51 changes: 38 additions & 13 deletions docs/proposals/ShardingFormalism.md
@@ -94,11 +94,17 @@ The constraint is that the sharding specs of the multiple broadcast axes must be
which is illustrated down below.

**Inference of output sharding**
* The sharding spec for any axis of the output is the same as the sharding spec of the
corresponding input axes in the case of non-broadcast.
* In the case of a single broadcast axis, the output axis derives its sharding spec from the
corresponding input axes with a size other than 1, if any.
* In the special case where all corresponding input axes have a size of 1, the output axis inherits
the same sharding (that is, it is replicated across all devices of the node op).
* In the case of two or more broadcast axes, the output axis derives its sharding spec from the
corresponding input axes with a size other than 1, if any. However, the device assignment is
inferred by composing the sharding specs of all broadcast axes: each output shard resides in the
intersection of the sets of devices that contain the corresponding input shards used to compute
that output shard. See the illustration below and the sketch that follows this list.
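
To make the intersection rule concrete, here is a minimal Python sketch. The `AxisSharding` encoding, which maps each shard index of a broadcast axis to the set of device ids holding that shard, and the `compose_broadcast_axes` helper are purely illustrative and not part of this proposal.

```python
# Hypothetical encoding (not part of the proposal): the sharding of one broadcast
# axis maps each shard index to the set of device ids holding that shard.
from itertools import product

AxisSharding = dict[int, set[int]]

def compose_broadcast_axes(axis_specs: list[AxisSharding]) -> dict[tuple[int, ...], set[int]]:
    """Device assignment when two or more broadcast axes are sharded: each output
    shard (one shard index per broadcast axis) resides on the intersection of the
    device sets of the input shards used to compute it."""
    result: dict[tuple[int, ...], set[int]] = {}
    for combo in product(*(sorted(spec) for spec in axis_specs)):
        result[combo] = set.intersection(*(spec[i] for spec, i in zip(axis_specs, combo)))
    return result

# Example: a 2x2 device grid; axis A is sharded across rows, axis B across columns.
axis_a = {0: {0, 1}, 1: {2, 3}}
axis_b = {0: {0, 2}, 1: {1, 3}}
print(compose_broadcast_axes([axis_a, axis_b]))
# -> {(0, 0): {0}, (0, 1): {1}, (1, 0): {2}, (1, 1): {3}}
```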

**Composing Sharding Specs on Different Axes**

@@ -177,23 +183,40 @@ This rule is extended to the case of more than two broadcast axes accordingly.

**Constraints on input sharding**
* No constraints on input sharding.
* Sharding along non-reduction axes is straightforward: it indicates parallelization of the
iteration over the non-reduction axes.
* Sharding along reduction axes is permitted. It indicates parallelization of the reduction
loop, but this involves performing the reduction in two steps: in the first step, the
reduction is done locally on each shard, and in the second step the reduction is done
across the different shards. The second step can typically be mapped to a collective-reduce
operation (see the sketch after this list).
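
As a rough, single-process illustration of the two-step reduction described above, the following NumPy snippet uses a list of arrays to stand in for the per-device shards; no actual collective communication is involved, and the function name is invented for illustration.

```python
import numpy as np

def sharded_reduce_sum(shards: list[np.ndarray], axis: int) -> np.ndarray:
    """Reduce-sum over an axis that is sharded across devices, in two steps.

    Step 1: each device reduces its local shard along the reduction axis.
    Step 2: the partial results are combined across shards; on real hardware this
    step is typically realized by a collective-reduce (e.g. an all-reduce).
    """
    partials = [np.sum(shard, axis=axis) for shard in shards]  # step 1: local
    return np.sum(partials, axis=0)                            # step 2: cross-shard

# A [4, 6] tensor sharded into two [4, 3] pieces along axis 1 (the reduction axis).
full = np.arange(24, dtype=np.float64).reshape(4, 6)
shards = [full[:, :3], full[:, 3:]]
assert np.array_equal(sharded_reduce_sum(shards, axis=1), full.sum(axis=1))
```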

**Inference of output sharding**
* Non-reduction axes inherit the sharding of the corresponding axes of the input.
* Since the size of a reduction axis is one after the reduction, it cannot be used
for any meaningful sharding. The axis may even be omitted from the output shape,
depending on the value of the attribute `keepdims`. If the axis is retained, it
is treated as having no sharding.

In the case where the inputs are sharded only along one or more reduction axes,
there will be no sharded axis in the inferred output sharding specification.
However, there is still a choice as to whether the computed output is replicated
on all the devices that participate in this operation, or whether it is stored
only on some distinguished device. Collective-reduce operations typically
support both variations. The default inferred output specification is to
broadcast the computed result to all devices that participate in the particular
reduction (that is, the first option).
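
A small, hypothetical sketch of this inference rule follows; the `AxisSpec` encoding and the `infer_reduce_output_spec` helper are invented for illustration and are not part of the proposal.

```python
from dataclasses import dataclass

@dataclass
class AxisSpec:
    """Illustrative per-axis sharding: num_shards == 1 means the axis is unsharded."""
    num_shards: int
    devices: list[int]   # devices that hold (shards of) this axis

def infer_reduce_output_spec(input_axes: list[AxisSpec],
                             reduce_axes: set[int],
                             keepdims: bool) -> tuple[list[AxisSpec], list[int]]:
    """Non-reduction axes inherit the input sharding; reduction axes are dropped
    (keepdims=False) or retained with no sharding (keepdims=True). By default the
    reduced result is broadcast to every device participating in the reduction."""
    participating = sorted({d for spec in input_axes for d in spec.devices})
    out = []
    for i, spec in enumerate(input_axes):
        if i in reduce_axes:
            if keepdims:
                out.append(AxisSpec(1, participating))  # size-1 axis, unsharded
        else:
            out.append(spec)                            # inherited as-is
    return out, participating

# Example: an [M, K] input sharded only along the reduction axis K over devices 0..3.
inp = [AxisSpec(1, [0, 1, 2, 3]), AxisSpec(4, [0, 1, 2, 3])]
out_spec, devices = infer_reduce_output_spec(inp, reduce_axes={1}, keepdims=False)
# out_spec contains no sharded axis; the result is replicated on devices [0, 1, 2, 3].
```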

### MatMul-like ops

List of operations: MatMul, Gemm, quantized variations of these ops, special cases of EinSum

The constraints for these ops follow analogous cases above. Consider the simple case of matrix multiplication
of two matrices of dimensions `[M, K]` and `[K, N]` producing an output matrix of dimension `[M, N]`.
This operation is essentially a broadcast-reduction operation, where the first
input is interpreted to have the shape `[M, K, 1]` and the second input is interpreted to have
the shape `[1, K, N]`, and we perform a broadcast element-wise multiplication, followed
by a reduce-sum along the `K` axis. The constraints and inference for this operation follow
from the corresponding rules for broadcast and reduction described above.
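
This interpretation can be checked with a few lines of NumPy; the snippet below is only a sanity check of the broadcast-then-reduce view of matrix multiplication, not a proposed implementation.

```python
import numpy as np

M, K, N = 4, 5, 3
a = np.random.rand(M, K)
b = np.random.rand(K, N)

# Interpret the inputs as [M, K, 1] and [1, K, N], broadcast-multiply element-wise,
# then reduce-sum along the K axis (axis 1 of the broadcast result).
broadcast_reduce = (a[:, :, None] * b[None, :, :]).sum(axis=1)

assert np.allclose(broadcast_reduce, a @ b)
```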

Axis 0 of the first input (with value `M`) is conceptually broadcast to the second input.
Hence, its constraints and handling are similar to the treatment of broadcast axes for n-ary
@@ -204,8 +227,10 @@ matrix will inherit the partitioning for the corresponding axis from the partiti

Axis 1 of the second input (with value `N`) is also handled similarly.

The two axes with size value `K` (the _reduction_ axes) are both required to
have the same sharding (similar to non-broadcast axes in a binary operation above).

The output device assignment follows the rules described above for broadcast axes.
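
Putting the MatMul constraints together, the following hypothetical sketch (the tuple encoding and the `check_matmul_sharding` helper are illustrative only) checks that the two `K` axes carry the same sharding and propagates the `M` and `N` shardings to the output.

```python
# Each per-axis sharding is represented here simply as a tuple
# (num_shards, devices); this encoding is illustrative only.

def check_matmul_sharding(a_axes, b_axes):
    """Check and propagate sharding for MatMul of [M, K] x [K, N] inputs.

    The two K axes (the reduction axes) must carry identical sharding; the
    M and N axes are handled like broadcast axes, so the output axes [M, N]
    simply inherit their shardings. Output device assignment then follows
    the broadcast-axis composition rule (intersection of device sets).
    """
    (m_spec, k_spec_a) = a_axes
    (k_spec_b, n_spec) = b_axes
    if k_spec_a != k_spec_b:
        raise ValueError("the two K (reduction) axes must have the same sharding")
    return [m_spec, n_spec]

# Example: K split into 4 shards over devices 0..3 in both inputs.
k = (4, (0, 1, 2, 3))
m = (2, (0, 1, 2, 3))
n = (1, (0, 1, 2, 3))
output_spec = check_matmul_sharding([m, k], [k, n])   # -> [m, n]
```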

### Pooling and Convolution ops

Binary file modified docs/proposals/images/composing_broadcast_axes.png
