diff --git a/docs/proposals/ShardingFormalism.md b/docs/proposals/ShardingFormalism.md
index 290fd4b7d1a..27b5dd1ae56 100644
--- a/docs/proposals/ShardingFormalism.md
+++ b/docs/proposals/ShardingFormalism.md
@@ -94,11 +94,17 @@ The constraint is that the sharding specs of the multiple broadcast axes must be
 which is illustrated down below.
 
 **Inference of output sharding**
-* The sharding spec for any axes of the output is the same as the sharding spec for the axes of the
-corresponding input axes in the case of non-broadcast. In the case of broadcast, the output axes
-derives the sharding spec from the corresponding input axes with a size other than 1, if any.
-In the special case where all corresponding input axes have a size of 1, the output axis inherits
+* In the non-broadcast case, the sharding spec for any axis of the output is the same as the
+sharding spec of the corresponding input axes.
+* In the case of a single broadcast axis, the output axis derives the sharding spec from the
+corresponding input axes with a size other than 1, if any.
+* In the special case where all corresponding input axes have a size of 1, the output axis inherits
 the same sharding (that is, replicated across all devices of the node op).
+* In the case of two or more broadcast axes, the output axis likewise derives the sharding spec from
+the corresponding input axes with a size other than 1, if any. However, the device assignment is
+inferred by composing the sharding specs of all broadcast axes: each output shard resides in the
+intersection of the sets of devices that contain the corresponding input shards used to compute
+that output shard. See below for an illustration of this.
 
 **Composing Sharding Specs on Different Axes**
 
@@ -177,16 +183,28 @@ This rule is extended to the case of more than two broadcast axes accordingly.
 
 **Constraints on input sharding**
 * No constraints on input sharding.
-* Sharding along non-reduction axes is straightforward, since parallel iteration over the non-reduction
-axes is possible.
-* Sharding along reduction axes can be supported, but it requires an implicit collective-reduce operation.
+* Sharding along non-reduction axes is straightforward: it simply indicates that the iteration
+over the non-reduction axes is parallelized.
+* Sharding along reduction axes is permitted. It indicates parallelization of the reduction loop,
+but it involves performing the reduction in two steps: first, the reduction is done locally on
+each shard, and then the partial results are reduced across the different shards. The second step
+can typically be mapped to a collective-reduce operation, as sketched below.
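+
+As a rough illustration (not part of the normative proposal), the following NumPy sketch mimics this
+two-step evaluation for a `ReduceSum` whose reduction axis is split across two devices; the shard
+sizes and the two-device setup are arbitrary choices made for the example:
+
+```python
+import numpy as np
+
+# A 4x6 input whose reduction axis (axis 1) is split into two 4x3 shards,
+# standing in for the pieces held by two different devices.
+x = np.arange(24, dtype=np.float32).reshape(4, 6)
+shard0, shard1 = x[:, :3], x[:, 3:]
+
+# Step 1: each device reduces its own shard locally.
+partial0 = shard0.sum(axis=1)   # on device 0
+partial1 = shard1.sum(axis=1)   # on device 1
+
+# Step 2: the partial results are combined across the shards
+# (the step a collective-reduce such as AllReduce would perform).
+result = partial0 + partial1
+
+assert np.allclose(result, x.sum(axis=1))
+```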
 
 **Inference of output sharding**
 * Non-reduction axes inherit the sharding of the corresponding axes of the input.
-* Two natural possibilities exist for the reduction axes, if they are sharded. The result can be
-broadcast to all devices containing some shard along the reduction axes, or just to the devices
-containing a distinguished shard (say, the first one). As a default, we assume a broadcast (the
-first option).
+* Since the size of the reduction axis is one after the reduction, it cannot be used
+for any meaningful sharding. The axis may even be omitted from the output shape,
+depending on the value of the attribute `keepdims`. If the axis is retained, it
+is treated as having no sharding.
+
+In the case where the inputs are sharded only along one or more reduction axes,
+there will be no sharded axis in the inferred output sharding specification.
+However, there is still a choice as to whether the computed output is replicated
+on all the devices that participate in this operation, or whether it is stored
+only on some distinguished device. Collective-reduce operations typically
+support both variations. The default inferred output specification is to
+broadcast the computed result to all devices that participate in the particular
+reduction (the first option).
 
 ### MatMul-like ops
 
@@ -194,6 +212,11 @@ List of operations: MatMul, Gemm, quantized variations of these ops, special cas
 The constraints for these ops follow analogous cases above. Consider the simple case of matrix
 multiplication of two matrices of dimensions `[M, K]` and `[K, N]` producing an output matrix
 of dimension `[M, N]`.
+This operation is essentially a broadcast-reduction operation: the first input is interpreted
+to have the shape `[M, K, 1]`, the second input is interpreted to have the shape `[1, K, N]`,
+and we perform a broadcast element-wise multiplication followed by a reduce-sum along the `K`
+axis. The constraints and inference for the operation follow from the corresponding rules for
+broadcast and reduction described above.
 
 Axis 0 of the first input (with value `M`) is conceptually broadcast to the second input.
 Hence, its constraints and handling are similar to the treatment of broadcast axes for n-ary
@@ -204,8 +227,10 @@ matrix will inherit the partitioning for the corresponding axis from the partiti
 
 Axis 1 of the second input (with value `N`) is also handled similarly.
 
-The axes with size value `K` represent _reduction_ axes. The corresponding two axes must have
-a reduction-compatible sharding. This means that the two axes must have the same sharding.
+The two axes with size value `K` (the _reduction_ axes) are both required to
+have the same sharding (similar to non-broadcast axes in a binary operation above).
+
+The output device assignment follows the rules described above for broadcast axes.
 
 ### Pooling and Convolution ops
 
diff --git a/docs/proposals/images/composing_broadcast_axes.png b/docs/proposals/images/composing_broadcast_axes.png
index 45d3fd2ffb2..d1470567314 100644
Binary files a/docs/proposals/images/composing_broadcast_axes.png and b/docs/proposals/images/composing_broadcast_axes.png differ
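
As an informal illustration of the broadcast-reduction reading of `MatMul` described above (a sketch
only, not part of the proposed document text), the following NumPy snippet checks that reshaping the
inputs to `[M, K, 1]` and `[1, K, N]`, multiplying element-wise with broadcasting, and reduce-summing
over the `K` axis reproduces ordinary matrix multiplication; the concrete sizes are arbitrary:

```python
import numpy as np

M, K, N = 4, 5, 3
a = np.random.rand(M, K)   # first input,  shape [M, K]
b = np.random.rand(K, N)   # second input, shape [K, N]

# Broadcast element-wise multiplication of [M, K, 1] and [1, K, N]
# yields an [M, K, N] tensor; reduce-sum over the K axis gives [M, N].
c = (a.reshape(M, K, 1) * b.reshape(1, K, N)).sum(axis=1)

assert np.allclose(c, a @ b)
```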