diff --git a/docs/ShardingFormalism.md b/docs/proposals/ShardingFormalism.md
similarity index 50%
rename from docs/ShardingFormalism.md
rename to docs/proposals/ShardingFormalism.md
index a4973b913bb..416ea4670c8 100644
--- a/docs/ShardingFormalism.md
+++ b/docs/proposals/ShardingFormalism.md
@@ -86,32 +86,125 @@ _Add, And, BitShift, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, Equal, Great
 **Constraints on input sharding**
 
 * For any non-broadcast axis, the sharding spec of the two (or more) inputs must be identical
-* Any broadcast axis of size 1 (in the unsharded original tensor) must be replicated across all devices that participate in the parallel computation (that is, all devices identified in the node's sharding spec).
+* Any broadcast axis of size 1 (in the unsharded original tensor) must be replicated across all devices
+that participate in the parallel computation (that is, all devices identified in the node's sharding spec).
+* The case where there are two or more broadcast axes is more involved. Some conditions must be satisfied
+to ensure that the natural output (without extra communication ops) has a proper (complete) sharding.
+The constraint is that the sharding specs of the multiple broadcast axes must be *composable*,
+as illustrated below.
 
 **Inference of output sharding**
 
-* The sharding spec for any axes of the output is the same as the sharding spec for the axes of the
-corresponding input axes in the case of non-broadcast. In the case of broadcast, the output axes
-derives the sharding spec from the corresponding input axes with a size other than 1, if any.
-In the special case where all corresponding input axes have a size of 1, the output axis inherits
+* The sharding spec for any axis of the output is the same as the sharding spec of the corresponding
+input axes in the case of non-broadcast.
+* In the case of a single broadcast axis, the output axis derives its sharding spec from the corresponding
+input axes with a size other than 1, if any.
+* In the special case where all corresponding input axes have a size of 1, the output axis inherits
 the same sharding (that is, replicated across all devices of the node op).
-
-_Note_: The above can be generalized, but the generalization is hard to describe in words.
-TODO: either add example figures or code to describe more complex scenarios.
+* In the case of two or more broadcast axes, each output axis derives its sharding spec from the corresponding
+input axes with a size other than 1, if any. However, the device assignment is inferred by composing the
+sharding specs of all broadcast axes: each output shard resides in the intersection of the sets of
+devices that contain the corresponding input shards used to compute that output shard. See below for
+an illustration.
+
+**Composing Sharding Specs on Different Axes**
+
+Consider an `Add (Input1, Input2)` op where `Input1` has shape `[M, 1]` and
+`Input2` has shape `[1, N]`. The output has shape `[M, N]`, as a result of broadcasting.
+
+The figure below shows how we can use sharding for both the `M` and `N` axes:
+
+![Composing sharding specs on different axes](images/composing_broadcast_axes.png)
+
+Note that in this example both the `M` and `N` axes are split into two shards each.
+This means that the output itself has 4 shards, as shown in the figure.
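+
+The following NumPy sketch (illustrative only; the concrete sizes `M = 4` and `N = 6` and
+the two-way split of each axis are assumptions made for this example) shows that each of
+the four output shards is computed from exactly one shard of each input:
+
+```python
+import numpy as np
+
+# Assumed sizes for illustration.
+M, N = 4, 6
+input1 = np.arange(M, dtype=np.float32).reshape(M, 1)   # shape [M, 1]
+input2 = np.arange(N, dtype=np.float32).reshape(1, N)   # shape [1, N]
+
+# Shard Input1 two ways along axis 0 and Input2 two ways along axis 1.
+input1_shards = np.split(input1, 2, axis=0)   # shards i = 0, 1
+input2_shards = np.split(input2, 2, axis=1)   # shards j = 0, 1
+
+# Output shard (i, j) depends only on Input1 shard i and Input2 shard j, so it can
+# be computed on any device that holds both of those input shards.
+output_shards = {(i, j): input1_shards[i] + input2_shards[j]
+                 for i in range(2) for j in range(2)}
+
+# Stitching the four output shards together reproduces the unsharded Add.
+reassembled = np.block([[output_shards[(0, 0)], output_shards[(0, 1)]],
+                        [output_shards[(1, 0)], output_shards[(1, 1)]]])
+assert np.array_equal(reassembled, input1 + input2)
+```
+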
+In this example, we want each output shard to reside on a single device, as described by
+the following sharding spec:
+```
+{
+  device = [0, 1, 2, 3]
+  sharded_dim = [
+    {
+      axis = 0
+      simple_sharding =
+        [
+          {
+            num_shards = 2
+          }
+        ]
+    }
+    {
+      axis = 1
+      simple_sharding =
+        [
+          {
+            num_shards = 2
+          }
+        ]
+    }
+  ]
+}
+```
+To produce this output, however, we need to ensure that each input shard is
+available on two devices, as shown in the figure above. In particular,
+the first shard of `Input1` is needed by both devices 0 and 1, as it is used
+to compute the first two output shards. Likewise, the first shard of `Input2`
+is needed by both devices 0 and 2.
+
+Thus, the sharding spec for `Input1` is shown below:
+
+```
+{
+  device = [-1, -2]  // keys into device_map
+  device_map = {-1: [0, 1], -2: [2, 3]}
+  sharded_dim = [
+    {
+      axis = 0
+      simple_sharding =
+        [
+          {
+            num_shards = 2
+          }
+        ]
+    }
+  ]
+}
+```
+The sharding spec for `Input2` is analogous, as shown in the figure above.
+
+This leads to the following constraint on input sharding and inference rule
+for output sharding in the presence of two broadcast axes:
+* The (inferred) set of devices for `output-shard[i,j]` is the intersection of the sets of devices
+for `input-1-shard[i]` and `input-2-shard[j]`. If this intersection is empty, then the input
+sharding specs are not compatible (for broadcast composition).
+
+This rule extends naturally to the case of more than two broadcast axes.
 
 ### Reduction ops
 
 **Constraints on input sharding**
 
 * No constraints on input sharding.
-* Sharding along non-reduction axes is straightforward, since parallel iteration over the non-reduction
-axes is possible.
-* Sharding along reduction axes can be supported, but it requires an implicit collective-reduce operation.
+* Sharding along non-reduction axes is straightforward. It indicates
+parallelization of the iteration over the non-reduction axes.
+* Sharding along reduction axes is permitted. It indicates parallelization of the reduction
+loop, but this involves performing the reduction in two steps. In the first step, the
+reduction is done locally on each shard, and in the second step the reduction is done
+across the different shards. This can typically be mapped to a collective-reduce operation.
 
 **Inference of output sharding**
 
 * Non-reduction axes inherit the sharding of the corresponding axes of the input.
-* Two natural possibilities exist for the reduction axes, if they are sharded. The result can be
-broadcast to all devices containing some shard along the reduction axes, or just to the devices
-containing a distinguished shard (say, the first one). As a default, we assume a broadcast (the
-first option).
+* Since the size of the reduction axis is one after the reduction, it cannot be used
+for any meaningful sharding. The axis may even be omitted from the output shape,
+depending on the value of the attribute `keepdims`. If the axis is retained, it
+is treated as having no sharding.
+
+In the case where the inputs are sharded only along one or more reduction axes,
+there will be no sharded axis in the inferred output sharding specification.
+However, there is still a choice as to whether the computed output is replicated
+on all the devices that participate in this operation, or whether it is stored
+only on some distinguished device. Collective-reduce operations typically
+support both variations. The default inferred output specification is to
+broadcast the computed result to all devices that participate in the particular
+reduction (the first option).
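+
+The following NumPy sketch (illustrative only; the `[M, K]` input shape, the two-way
+sharding of the reduction axis `K`, and `keepdims=1` are assumptions made for this
+example) shows the two-step reduction producing the same result as the unsharded
+`ReduceSum`:
+
+```python
+import numpy as np
+
+# Assumed sizes for illustration; each device holds one [M, K/2] shard of the input.
+M, K = 3, 8
+x = np.arange(M * K, dtype=np.float32).reshape(M, K)
+shards = np.split(x, 2, axis=1)
+
+# Step 1: each device reduces its own shard locally.
+partials = [s.sum(axis=1, keepdims=True) for s in shards]
+
+# Step 2: the partial results are combined across shards. In a distributed runtime
+# this step maps to a collective-reduce (e.g. an all-reduce, if the result is to be
+# replicated on all participating devices).
+result = partials[0] + partials[1]
+
+assert np.array_equal(result, x.sum(axis=1, keepdims=True))
+```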
 
 ### MatMul-like ops
 
@@ -119,13 +212,25 @@ List of operations: MatMul, Gemm, quantized variations of these ops, special cas
 The constraints for these ops follow analogous cases above. Consider the simple case of matrix
 multiplication of two matrices of dimensions `[M, K]` and `[K, N]` producing an output matrix of
 dimension `[M, N]`.
+This operation is essentially a broadcast-reduction operation, where the first
+input is interpreted as having the shape `[M, K, 1]`, the second input is interpreted as having
+the shape `[1, K, N]`, and we perform a broadcast element-wise multiplication, followed
+by a reduce-sum along the `K` axis. The constraints and inference for the operation follow
+from the corresponding rules for broadcast and reduction described above.
+
 Axis 0 of the first input (with value `M`) is conceptually broadcast to the second input.
 Hence, its constraints and handling are similar to the treatment of broadcast axes for n-ary
-elementwise ops.
+elementwise ops. Specifically, since only the first input has this axis, the partitioning of
+this axis is not constrained by the partitioning of the second input. Furthermore, the output
+matrix inherits the partitioning of the corresponding axis from the partitioning of axis
+0 of the first input.
+
 Axis 1 of the second input (with value `N`) is also handled similarly.
 
-The axes with size value `K` represent reduction axes. The corresponding two axes must have
-compatible sharding.
+The two axes with size value `K` (the _reduction_ axes) are both required to
+have the same sharding (similar to non-broadcast axes in a binary operation above).
+
+The output device assignment follows the rules described above for broadcast axes.
 
 ### Pooling and Convolution ops
diff --git a/docs/proposals/images/composing_broadcast_axes.png b/docs/proposals/images/composing_broadcast_axes.png
new file mode 100644
index 00000000000..d1470567314
Binary files /dev/null and b/docs/proposals/images/composing_broadcast_axes.png differ
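For reference, the broadcast-plus-reduction interpretation of MatMul described above can be
checked with a short NumPy sketch (illustrative only; the concrete sizes and the two-way split
of the reduction axis `K` are assumptions made for this example):

```python
import numpy as np

# Assumed sizes for illustration.
M, K, N = 4, 6, 2
a = np.random.rand(M, K).astype(np.float32)   # first input,  [M, K]
b = np.random.rand(K, N).astype(np.float32)   # second input, [K, N]

# MatMul viewed as a broadcast element-wise multiply of [M, K, 1] and [1, K, N],
# followed by a reduce-sum along the K axis.
product = a[:, :, np.newaxis] * b[np.newaxis, :, :]   # shape [M, K, N]
assert np.allclose(product.sum(axis=1), a @ b, atol=1e-5)

# With K sharded two ways (identically in both inputs, as required above), each
# device computes a partial [M, N] result locally; the partials are then combined
# with a collective-reduce.
a_shards = np.split(a, 2, axis=1)
b_shards = np.split(b, 2, axis=0)
partials = [ak @ bk for ak, bk in zip(a_shards, b_shards)]
assert np.allclose(partials[0] + partials[1], a @ b, atol=1e-5)
```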