From 33afe10f78792a8ce5ae7cf50b6902e5ec1993c3 Mon Sep 17 00:00:00 2001 From: Kevin Chen Date: Wed, 18 Dec 2024 13:36:59 -0800 Subject: [PATCH] Add multi-device execution support in ONNX Signed-off-by: Kevin Chen --- docs/proposals/ONNXMultiDeviceProposal.md | 179 ++++++++++++++++++++++ onnx/onnx.in.proto | 71 +++++++++ 2 files changed, 250 insertions(+) create mode 100644 docs/proposals/ONNXMultiDeviceProposal.md diff --git a/docs/proposals/ONNXMultiDeviceProposal.md b/docs/proposals/ONNXMultiDeviceProposal.md new file mode 100644 index 00000000000..209fbac97c2 --- /dev/null +++ b/docs/proposals/ONNXMultiDeviceProposal.md @@ -0,0 +1,179 @@ + + + + +# ONNX Multi-Device Proposal + +## Background + +The recent trend in increasingly larger models has spurred an interest in distributed inference. A key performance bottleneck for inference for these large models has been the memory limits of GPUs and other accelerators as well as communication bandwidth. Thus, efficient distributed inference typically requires parallelization of the computation across multiple devices taking memory and bandwidth into account. + +Our goal is to extend ONNX so that it can serve as a representation of a parallelized model. This is driven by the current state-of-the-art techniques used for distributed inference (eg., see [GSPMD: General and Scalable Parallelization for ML Computation Graphs](https://arxiv.org/pdf/2105.04663.pdf)). In particular, two techniques of interest are tensor parallelism and pipelining. In tensor parallelism (also known as horizontal parallelism or operator parallelism), the computation of a single operator (node) in the graph is parallelized across multiple devices by sharding its inputs, In pipeline parallelism, different subgraphs are assigned to different devices. + + +## Design + +See [this commit](https://github.com/kevinch-nv/onnx/commit/07e97452096b28ba7c46fec6927d195907431e07) for the proposed additions to the ONNX spec. + +The key point of this design is that all multi-device specific annotations are at the node level, and do not affect the main computational graph. This means: + - All communication operations required for multi-device execution are implicit + - A backend may choose to ignore the annotations if the provided configurations are either not supported or not available + +### Sharding Specification + +Sharding refers to modifying a tensor into multiple parts to be sent across multiple devices. A tensor may be sharded across any of its axis. + +Modification of a tensor generally falls into two categories: splitting and duplication. + +#### Sharding as a Split + +For example, consider the following 2x2 tensor: + +`[[1, 2], [3, 4]]` + +If a sharding across axis 0 is specified over two devices, then: +- Device 0 will receive a tensor of shape 1x2 with data `[[1, 2]]` +- Device 1 will receive a tensor of shape 1x2 with data `[[3, 4]]` + +The corresponding ShardingSpecProto for the above will look like: +``` +{ + device = [0, 1] + sharded_dim =[ + { + axis = 0 + simple_sharding = + [ + { + num_shards = 2 + } + ] + } + ] +} +``` + +If a sharding across axis 1 is specified over two devices, then: +- Device 0 will receive a tensor of shape 2x1 with data `[[1], [3]]` +- Device 1 will receive a tensor of shape 2x1 with data `[[2], [4]]` + +The corresponding ShardingSpecProto for the above will look like: +``` +{ + device = [0, 1] + sharded_dim =[ + { + axis = 1 + simple_sharding = + [ + { + num_shards = 2 + } + ] + } + ] +} +``` + +If a sharding across axis 0 and axis 1 is specified over four devices, then: +- Device 0 will receive a tensor of shape 1x1 with data `[[1]]` +- Device 1 will receive a tensor of shape 1x1 with data `[[2]]` +- Device 2 will receive a tensor of shape 1x1 with data `[[3]]` +- Device 3 will receive a tensor of shape 1x1 with data `[[4]]` + +The corresponding ShardingSpecProto for the above will look like: +``` +{ + device = [0, 1, 2, 3] + sharded_dim =[ + { + axis = 0 + simple_sharding = + [ + { + num_shards = 2 + } + ] + } + { + axis = 1 + simple_sharding = + [ + { + num_shards = 2 + } + ] + } + ] +} +``` + +A key observation in the above example shows how indexing is performed when multiple sharding axes are provided. In general, the splitting is done as: + +``` +split_tensors = [] +for a in range(num_shards_a): + a_index = a * input.shape[axis0] / num_shards_a + for b in range(num_shards_b): + b_index = b * input.shape[axis1] / num_shards_b + split = input[a_index : a_index + num_shards_a, b_index : b_index + num_shards_b] + split_tensors.append(split) +``` + +Note that the above examples assume that the num_shards are evenly divisible into the axis that's being sharded. While this is not a hard restriction, it is up to the backend on how to handle non-evenly divisble cases. + + +#### Sharding as a Broadcast + +There may be cases where data in a tensor must be duplicated across multiple devices to ensure that operations stay functionaly correct. + +For example consider replicating the same 2x2 tensor across two devices. We can do so by providing the following ShardingSpecProto: + +``` +{ + device = [-1] // keys into device_map + device_map = {-1: [0, 1]} + sharded_dim =[] +} +``` + +It is also possible to mix splitting and broadcasting, consider the following ShardingSpecProto: + +``` +{ + device = [-1, -2] // keys into device_map + device_map = {-1: [0, 1], -2: [2, 3]} + sharded_dim =[ + { + axis = 0 + simple_sharding = + [ + { + num_shards = 2 + } + ] + } + ] +} +``` + +On device 0 and 1, the following 1x2 tensor is produced: `[[1,2]]` +On device 2 and 3, the following 1x2 tensor is produced: `[[2,3]]` + +#### Pipeline Parallelism + +Pipeline stages are represented as an optional integer value in a node's NodeConfigurationProto. It is a hint to the backend on how to run a model in a pipelined fashion across multiple devices. For example, consider the following diagram: + +``` +Nodes below have a pipeline id of 1: + +A -> B -> C -> D -> E + | Nodes below have a pipeline id of 2: + F -> G -> H -> I -> J -> K + +``` + +It is possible to have both pipeline and tensor parallel annotations in the same ONNX graph. + diff --git a/onnx/onnx.in.proto b/onnx/onnx.in.proto index d30e9393cc1..a0633237b7d 100644 --- a/onnx/onnx.in.proto +++ b/onnx/onnx.in.proto @@ -231,6 +231,63 @@ message NodeProto { // Named metadata values; keys should be distinct. repeated StringStringEntryProto metadata_props = 9; + + // Configuration of multi-device annotations. + repeated NodeConfigurationProto configuration = 10; +} + +// Multi-device configuration proto for NodeProto. +message NodeConfigurationProto { + // ID of the configuration. + string configuration_id = 1; + // Sharding spec for the node. + repeated ShardingSpecProto sharding_spec = 2; + // Pipeline stage of this node. + optional int pipeline_stage = 3; +} + +// ShardingSpecProto: This describes the sharding spec for a specific +// input/output of a node. +message ShardingSpecProto { + // Identifies the input/output of the node that is being sharded. + // It is called `logical tensor` in subsequent descriptions. + string tensor_name = 1; + + // The following is the list of devices across which the logical + // tensor is sharded or replicated. + repeated int64 device = 2; + + // Each element v in above field devices may represent either a + // device or a set of devices (when we want the same shard/tensor + // to be replicated across a subset of devices), as indicated by + // the following optional map. If the map contains an entry for v, + // then v represents a device group, and the map indicates the set + // of devices in that group. + optional map index_to_device_group_map = 3; + + // The following is the sharded-shape of the tensor, consisting of + // the sharding-spec for each axis of the tensor. + repeated ShardedDimProto sharded_dim = 4; +} + +// ShardedDimProto: This describes the sharding spec for a single +// axis of a sharded tensor. +message ShardedDimProto { + int32 axis = 1; // the axis this sharding corresponds to + // The common-case is described by a single instance of SimpleShardedDimProto + // We use multiple instances to handle cases produced when a sharded + // tensor is reshaped, fusing multiple axes into one. + repeated SimpleShardedDimProto simple_sharding = 2; +} + +// SimpleShardedDimProto: Indicates that N blocks are divided into M shards. +// Here, N is allowed to be symbolic, which M is required to be a constant. +message SimpledShardedDimProto { + oneof dim { + int64 dim_value = 1; + string dim_param = 2; + } + optional int32 num_shards = 3; } // Training information @@ -430,8 +487,22 @@ message ModelProto { // One FunctionProto can reference other FunctionProto in the model, however, recursive reference // is not allowed. repeated FunctionProto functions = 25; + + // Describes different target configurations for a multi-device use case. + // A model can describe multiple multi-device configurations for execution. + repeated ConfigurationProto configuration = 26; }; +// ConfigurationProto describes a multi-device configuration for a model. +message ConfigurationProto { + // Name of the configuration. + string name = 1; + // Name of the device. + string device = 2; + // Number of devices inside this configuration. + int32 num_devices = 3; +} + // StringStringEntryProto follows the pattern for cross-proto-version maps. // See https://developers.google.com/protocol-buffers/docs/proto3#maps message StringStringEntryProto {