Deep Gradient Compression: Reducing the communication bandwidth for distributed training (paper)
- Gradient exchange is costly and dwarfs the savings in computation time
- Network bandwidth becomes a bottleneck for scaling up distributed training, even worse on mobile devices (federated learning)
- Deep Gradient Compression (DGC) solves the communication bandwidth problem by compressing the gradients
- Async SGD accelerates training by removing synchronization and updating parameters immediately once a node completes backprop
- Quantizing gradients to low-precision values can reduce communication bandwidth (1-bit SGD, QSGD, and TernGrad). DoReFa-Net uses 1-bit weights with 2-bit gradients
- Gradient Sparsification can be done in many ways:
- Threshold quantization: send only gradients larger than a predefined constant threshold
- Choose a fixed proportion of positive and negative gradient updates
- Gradient dropping: sparsify gradients with a single threshold on the absolute value (requires adding layer normalization)
- Automatically tuning the compression rate depending on local gradient activity, which achieved a high compression ratio with negligible degradation of accuracy
- Compared to previous work, Deep Gradient Compression:
- Pushes the gradient compression ratio further (270-600×)
- Does not require altering model structure
- Results in no loss in accuracy
- Gradients larger than a threshold are transmitted, reducing bandwidth
- The rest of the gradients are accumulated locally
- Over time, the accumulated gradients grow large enough to be transmitted
- Local gradient accumulation is equivalent to increasing the batch size over time (see the sketch below)
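A minimal NumPy sketch of this accumulate-and-send loop. Function and variable names are my own; the paper selects what to send by gradient magnitude, and I implement that here as a top-k cutoff:

```python
import numpy as np

def sparsify_with_accumulation(gradient, residual, sparsity=0.999):
    """One step of sparsification with local accumulation (sketch)."""
    # Add this step's gradient to what was withheld in earlier steps.
    accumulated = residual + gradient
    # Transmit only the top (1 - sparsity) fraction of entries by magnitude.
    k = max(1, int(accumulated.size * (1.0 - sparsity)))
    threshold = np.partition(np.abs(accumulated).ravel(), -k)[-k]
    mask = np.abs(accumulated) >= threshold
    sparse_update = np.where(mask, accumulated, 0.0)  # sent over the network
    new_residual = np.where(mask, 0.0, accumulated)   # kept for next iteration
    return sparse_update, new_residual
```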
Momentum correction and local gradient clipping mitigate the accuracy loss caused by sparse updates
- Momentum SGD:
- Cannot be applied directly in place of vanilla SGD: sparsification ignores the discounting factor between sparse update intervals, which hurts convergence
- When gradient sparsity is high, the update interval increases dramatically, so this side effect significantly harms model performance
- The fix: locally accumulate the velocity instead of the raw gradient, and use the accumulated velocity for the subsequent sparsification and communication (sketched below)
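A sketch of momentum correction, following the paper's update rule u_k = m·u_{k-1} + g_k, v_k = v_{k-1} + u_k. Buffer names are assumptions; `accumulation` then feeds the sparsifier from the previous sketch:

```python
import numpy as np

def momentum_corrected_accumulate(gradient, velocity, accumulation, momentum=0.9):
    """Momentum correction: accumulate the velocity, not the raw gradient (sketch)."""
    # Update the momentum (velocity) buffer locally every iteration ...
    velocity = momentum * velocity + gradient
    # ... and accumulate the velocity instead of the raw gradient.
    accumulation = accumulation + velocity
    # `accumulation` is then sparsified and transmitted as before,
    # with the transmitted entries zeroed out locally.
    return velocity, accumulation
```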
- Local gradient clipping:
- Widely adopted to avoid exploding gradients
- Rescales the gradients whenever the sum of their L2-norms exceeds a threshold
- DGC accumulates gradients over iterations on each node independently, and performs gradient clipping locally before adding the current gradient to the previous accumulation
- The clipping threshold is scaled by N^{-1/2}, the current node's fraction of the global threshold if all N nodes had identical gradient distributions (see the sketch below)
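A sketch of local clipping with the N^{-1/2} scaling; the function name and signature are assumptions:

```python
import numpy as np

def clip_local_gradient(gradient, global_threshold, num_nodes):
    """Clip a node's local gradient before accumulation (sketch)."""
    # Each node uses N^(-1/2) of the global L2 clipping threshold,
    # its fair share if all N nodes saw identically distributed gradients.
    local_threshold = global_threshold * num_nodes ** -0.5
    norm = np.linalg.norm(gradient)
    if norm > local_threshold:
        gradient = gradient * (local_threshold / norm)
    return gradient
```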
- Updates of small gradients are delayed, so they are stale and outdated by the time they are applied
- Staleness can slow down convergence and degrade model performance
Momentum factor masking and warm-up training mitigate staleness
- Momentum Factor Masking:
- Instead of searching for a new momentum coefficient (as suggested in previous work), simply apply the same mask to both the accumulated gradients and the momentum factor (sketched below)
- The mask stops the momentum of delayed gradients, preventing stale momentum from carrying the weights in the wrong direction
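A sketch of the masking step, reusing the buffers from the earlier sketches (names are assumptions):

```python
import numpy as np

def apply_momentum_factor_mask(accumulation, velocity, threshold):
    """Zero both buffers for the entries transmitted this round (sketch)."""
    # Entries selected for transmission this round.
    mask = np.abs(accumulation) >= threshold
    # Clearing the accumulation is the usual reset; clearing the velocity
    # is the momentum factor mask, which stops stale momentum from pushing
    # the just-updated weights in an outdated direction.
    accumulation = np.where(mask, 0.0, accumulation)
    velocity = np.where(mask, 0.0, velocity)
    return accumulation, velocity
```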
- Warm-up Training:
- Gradients are more diverse and aggressive in the early stages of training and sparsifying them limits the range of variation of the model
- The remaining accumulated gradients may outweigh the latest gradients and misguide the optimization direction
- A less aggressive learning rate helps slow down the changing speed of the network at the start of training
- Less aggressive gradient sparsity reduces the number of extreme gradients being delayed
- Exponentially increasing the gradient sparsity helps the training adapt to higher sparsity (see the schedule sketch below)
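A sketch of such an exponential warm-up schedule. The defaults are assumptions; with them, the density (1 − sparsity) shrinks by a constant factor each epoch, ramping sparsity from 75% up to a final 99.9%:

```python
def sparsity_schedule(epoch, warmup_epochs=4, initial=0.75, final=0.999):
    """Exponentially ramp sparsity during warm-up, then hold it fixed (sketch)."""
    if epoch >= warmup_epochs:
        return final
    # Density shrinks by a constant per-epoch factor so that it
    # reaches the final density exactly at the end of warm-up.
    ratio = ((1.0 - final) / (1.0 - initial)) ** (1.0 / warmup_epochs)
    return 1.0 - (1.0 - initial) * ratio ** epoch

# Example: epochs 0..4 yield roughly 75%, 93.7%, 98.4%, 99.6%, 99.9%.
```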
- Image classification tasks
- The learning curve of Gradient Dropping is worse than the baseline due to gradient staleness
- With momentum correction, the learning curve converges slightly faster, and the accuracy is much closer to the baseline
- With momentum factor masking and warm-up training techniques, gradient staleness is eliminated and the learning curve closely follows the baseline
- Deep Gradient Compression achieves 75× better compression than TernGrad with no loss of accuracy; for ResNet-50, the compression ratio is slightly lower (277× vs. 597×) with a slight increase in accuracy
- Language modeling
- For language modeling, the training loss with Deep Gradient Compression closely matches the baseline, as does the validation perplexity
- Deep Gradient Compression compresses the gradient by 462× with a slight reduction in perplexity
- Speech recognition
- The learning curves show the same improvements from the Deep Gradient Compression techniques as in the image classification experiments
- The model trained with Deep Gradient Compression achieves better recognition on both clean and noisy speech, even with the gradient size compressed by 608×
- Hierarchically calculating the threshold significantly reduces top-k selection time for gradients
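A sketch of sampling-based hierarchical threshold selection; the sample fraction and the refinement pass are my assumptions about the procedure:

```python
import numpy as np

def hierarchical_threshold(grad, sparsity=0.999, sample_frac=0.01, rng=None):
    """Estimate the top-k threshold from a small random sample (sketch)."""
    if rng is None:
        rng = np.random.default_rng()
    flat = np.abs(grad).ravel()
    # Coarse pass: top-k on a ~1% sample instead of the full gradient.
    sample = rng.choice(flat, size=max(1, int(flat.size * sample_frac)))
    k_sample = max(1, int(sample.size * (1.0 - sparsity)))
    coarse = np.partition(sample, -k_sample)[-k_sample]
    # Refinement pass: exact top-k, but only among coarse-pass survivors.
    survivors = flat[flat >= coarse]
    k = max(1, int(flat.size * (1.0 - sparsity)))
    if survivors.size > k:
        return np.partition(survivors, -k)[-k]
    return coarse
```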
- Deep Gradient Compression benefits even more when the communication-to-computation ratio of the model is higher and the network bandwidth is lower
- Deep Gradient Compression (DGC) compresses the gradient by 270-600× for a wide range of CNNs and RNNs