Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour

Motivation(s)

The supervised learning regime consists of minimizing a loss \(L(w)\) of the form

\[L(w) = \frac{1}{\left\vert \mathcal{X} \right\vert} \sum_{x \in \mathcal{X}} l(x, w).\]

Here \(w\) are the network weights, \(\mathcal{X}\) is a labeled training set, \(l(x, w)\) is the loss computed from samples \(x \in \mathcal{X}\) and their labels \(y\). A typical method to minimize this function is minibatch stochastic gradient descent (SGD):

\[w_{t + 1} = w_t - \eta \frac{1}{n} \sum_{x \in \mathcal{B}} \nabla l(x, w_t)\]

where \(\mathcal{B}\) is a minibatch sampled from \(\mathcal{X}\), \(n = \left\vert \mathcal{B} \right\vert\) is the minibatch size, \(\eta\) is the learning rate, and \(t\) is the iteration index. After \(k\) iterations the weights become

\[w_{t + k} = w_t - \eta \frac{1}{n} \sum_{j < k} \sum_{x \in \mathcal{B}_j} \nabla l(x, w_{t + j}).\]

A brute-force way to speed up training is to divide the batch over many workers. Consequently, a larger batch size is needed in order to fully utilize each worker. Taking a single step with the large minibatch \(\bigcup_j \mathcal{B}_j\) of size \(kn\) and learning rate \(\hat{\eta}\) yields

\[\hat{w}_{t + 1} = w_t - \hat{\eta} \frac{1}{kn} \sum_{j < k} \sum_{x \in \mathcal{B}_j} \nabla l(x, w_t).\]

Assuming that \(\nabla l(x, w_t) \approx \nabla l(x, w_{t + j})\) for \(j < k\), setting \(\hat{\eta} = k \eta\) would yield \(\hat{w}_{t + 1} \approx w_{t + k}\). This interpretation gives the intuition behind the linear scaling rule: when the minibatch size is multiplied by \(k\), multiply the learning rate by \(k\) while keeping the number of epochs and other hyperparameters unchanged.
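
The approximation behind the linear scaling rule can be checked numerically on a toy problem. The following is a minimal sketch, not the paper's experiment: the per-sample loss \(l(x, w) = \frac{1}{2}(w - x)^2\), the Gaussian data, and the values of `k`, `n`, and `eta` are all illustrative assumptions. It compares \(k\) small steps at learning rate \(\eta\) against one large step at learning rate \(k \eta\).

```python
import random

def grad(x, w):
    # Gradient of the illustrative per-sample loss l(x, w) = 0.5 * (w - x)**2.
    return w - x

def small_steps(w, batches, eta):
    # k SGD steps, one per minibatch of size n, each with learning rate eta.
    for batch in batches:
        w = w - eta * sum(grad(x, w) for x in batch) / len(batch)
    return w

def large_step(w, batches, eta):
    # One SGD step on the union of the k minibatches with learning rate k * eta.
    k = len(batches)
    union = [x for batch in batches for x in batch]
    return w - (k * eta) * sum(grad(x, w) for x in union) / len(union)

random.seed(0)
k, n, eta = 8, 32, 0.001  # hypothetical values chosen for the demo
batches = [[random.gauss(1.0, 1.0) for _ in range(n)] for _ in range(k)]

w_small = small_steps(0.0, batches, eta)
w_large = large_step(0.0, batches, eta)
print(w_small, w_large, abs(w_small - w_large))
```

The two results agree closely here because the weights barely move over the \(k\) small steps, which is exactly the assumption \(\nabla l(x, w_t) \approx \nabla l(x, w_{t + j})\). Raising `eta` in this sketch widens the gap.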

The linear scaling rule is also applicable when using batch normalization (BN) with large minibatches. Recall that BN breaks the independence of each sample’s loss:

\[\begin{split}L(w) &= \frac{1}{\left\vert \mathcal{X}^n \right\vert} \sum_{\mathcal{B} \in \mathcal{X}^n} L(\mathcal{B}, w)\\ &= \frac{1}{\left\vert \mathcal{X}^n \right\vert} \sum_{\mathcal{B} \in \mathcal{X}^n} \frac{1}{n} \sum_{x \in \mathcal{B}} l_\mathcal{B}(x, w).\end{split}\]

Here \(\mathcal{X}^n\) denotes all distinct subsets with \(n\) elements in the power set of \(\mathcal{X}\). When viewing \(\mathcal{B}\) as a single sample in \(\mathcal{X}^n\), the loss of each \(\mathcal{B}\) is computed independently. Since changing the per-worker minibatch sample size \(n\) alters the underlying loss function \(L\), BN statistics should not be aggregated across all workers. Therefore, the case of large minibatch training with BN is analogous to the foregoing per-sample loss formulation: a total minibatch size of \(kn\) can be viewed as a minibatch of \(k\) samples with each sample \(\mathcal{B}_j\) independently selected from \(\mathcal{X}^n\). The aforementioned assumption now becomes \(\nabla L(\mathcal{B}_j, w_t) \approx \nabla L(\mathcal{B}_j, w_{t + j})\).

In practice, the above assumptions do not hold in the initial training epochs when the network is changing rapidly, which negatively impacts the model accuracy. One strategy to mitigate this issue is a constant warmup phase: use a less aggressive constant learning rate for the first few training epochs. However, given a large enough \(k\), this constant warmup is not sufficient to solve the optimization problem. Furthermore, the abrupt transition from the low learning rate causes the training error to spike.

Proposed Solution(s)

The authors propose an alternative warmup that gradually ramps up the learning rate from a small to a large value. With a batch size of \(kn\), the authors start from a learning rate of \(\eta\) and increment it by a constant amount at each iteration such that it reaches \(\hat{\eta} = k \eta\) after \(5\) epochs. After the warmup phase, the training uses the original learning rate schedule.
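
The gradual warmup described above can be written as a small schedule function. This is a minimal sketch under assumptions: the function name `warmup_lr`, its parameters, and the example values are hypothetical, and the iteration index is taken to start at zero.

```python
def warmup_lr(iteration, iters_per_epoch, eta, k, warmup_epochs=5):
    """Learning rate at a given (0-indexed) iteration of gradual warmup."""
    warmup_iters = warmup_epochs * iters_per_epoch
    if iteration >= warmup_iters:
        # Warmup is over: hand off to the regular schedule, which starts at k * eta.
        return k * eta
    # Constant increment per iteration so the rate reaches k * eta at warmup_iters.
    step = (k * eta - eta) / warmup_iters
    return eta + step * iteration

# Illustrative values only: 100 iterations per epoch, eta = 0.1, k = 32.
start = warmup_lr(0, 100, 0.1, 32)
middle = warmup_lr(250, 100, 0.1, 32)
end = warmup_lr(500, 100, 0.1, 32)
print(start, middle, end)
```

Because the increment is constant, the rate is halfway between \(\eta\) and \(k \eta\) at the midpoint of the warmup, and there is no jump when the original schedule takes over.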

Evaluation(s)

The authors verified on 1000-way ImageNet classification that linear scaling with gradual warmup yields training and validation error curves that closely match the small minibatch baseline. They used a single random shuffling of the training data (per epoch) that is divided among all \(k\) workers. The per-worker loss is normalized using \(\frac{1}{kn}\) because allreduce performs summing, not averaging.
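
The \(\frac{1}{kn}\) normalization can be illustrated with plain lists standing in for per-sample gradients. This is a sketch only: `allreduce_sum` is a hypothetical single-process stand-in for a real summing allreduce, and the gradient values are made up.

```python
def local_contribution(per_sample_grads, k):
    # Each worker scales the sum of its n per-sample gradients by 1 / (k * n),
    # not by 1 / n, because the subsequent allreduce adds the k contributions.
    n = len(per_sample_grads)
    return sum(per_sample_grads) / (k * n)

def allreduce_sum(values):
    # Stand-in for a summing allreduce: every worker would receive this total.
    return sum(values)

# k = 4 workers with n = 2 scalar "gradients" each (made-up numbers).
worker_grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
k = len(worker_grads)

aggregated = allreduce_sum([local_contribution(g, k) for g in worker_grads])
global_mean = sum(x for g in worker_grads for x in g) / sum(len(g) for g in worker_grads)
print(aggregated, global_mean)
```

Normalizing by \(\frac{1}{n}\) instead would make the summed result \(k\) times too large, which silently multiplies the effective learning rate by \(k\).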

Their infrastructure consists of 32 servers, each equipped with eight P100 GPUs and 3.2 TB of NVMe storage. Backpropagation on a P100 takes 120 ms for ResNet-50, which translates to a peak bandwidth requirement of roughly 15 Gbps, so interserver communication uses ConnectX-4 50 Gbps Ethernet network cards with Wedge100 Ethernet switches. All models were trained for 90 epochs (about one hour at \(kn = 8192\)) using a local batch size of \(n = 32\). The reference learning rate \(\eta = 0.1 \frac{kn}{256}\) was multiplied by \(\frac{1}{10}\) at the \({30}^\text{th}\), \({60}^\text{th}\), and \({80}^\text{th}\) epoch. For BN layers, they initialized \(\gamma = 1\) except for each residual block’s last BN where \(\gamma = 0\). This modification forces the forward/backward signal to initially propagate through the identity shortcut of ResNet, which seems to ease optimization at the start of training for both small and large minibatches.
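
The reference learning rate and its step decay can be sketched as follows. The function names are hypothetical, warmup is deliberately left out, and the sketch assumes 0-indexed epochs with the drop applied from the milestone epoch onward.

```python
def reference_lr(k, n, base_lr=0.1, base_batch=256):
    # Linear scaling rule: eta = 0.1 * kn / 256.
    return base_lr * (k * n) / base_batch

def lr_at_epoch(epoch, k, n, milestones=(30, 60, 80)):
    # Multiply the rate by 1/10 for every milestone epoch already reached.
    drops = sum(1 for m in milestones if epoch >= m)
    return reference_lr(k, n) * (0.1 ** drops)

# 256 workers with a local batch size of 32 give kn = 8192.
schedule = [lr_at_epoch(e, k=256, n=32) for e in (0, 30, 60, 85)]
print(schedule)
```

With \(kn = 8192\) the rate starts at 3.2 and ends near 0.0032, while the \(kn = 256\) baseline follows the same shape starting from 0.1.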

The comparison between no warmup and gradual warmup suggests that large minibatch sizes are challenged by optimization difficulties in early training. The authors observed no generalization issues when transferring across datasets (from ImageNet to COCO) and across tasks (from classification to detection/segmentation) using models trained with large minibatches. However, when \(kn > 8192\), the training and validation error curves start to diverge towards lower accuracy.

Future Direction(s)

  • How does the accuracy change when batch normalization is replaced with smoothed layer normalization?

  • Why is large batch training sensitive to the initialization of the weights?

  • Do the three-phase generalization dynamics still hold when training neural networks from scratch?

Question(s)

  • Scaling the cross-entropy loss is not equivalent to scaling the learning rate, but why would one want to boost the cross-entropy loss?

  • A momentum correction factor is needed when the learning rate delays the scaling of the momentum decay factor by one iteration. Why is the delayed scaling formulation useful?

Analysis

Optimization and generalization of large minibatch training using gradual warmup with linear scaling matches that of small minibatch training. Incidentally, the training error curves can be used as a reliable proxy for success well before training finishes.

Although the gradual warmup may seem like an arbitrary hack, there is precedent for it. [WVJ94] demonstrated that there are in general three distinct phases of learning. At the beginning, the network has hardly learned anything and is still very biased. This phase of training is dominated by the approximation error. As training continues, the approximation error will decrease at the cost of increasing the complexity error. If the network is trained long enough, the complexity error will dominate, which implies that early stopping is a mechanism to detect when phase three starts. Unfortunately, this analysis is for networks where only the output weights are being trained.

The authors assert that optimization difficulty is the main issue with large minibatch training, rather than poor generalization. They should have evaluated the sharpness of the local optima because their results only hold when \(kn \leq 8\text{K}\) and [KMN+16] states that simply increasing the batch size will lead to a solution that has poor generalization.

One point that deserves mention is that the square root scaling rule that was justified theoretically in [Kri14] works poorly in practice.

The implementation details of gradient aggregation are very useful for future system designs. The authors claimed that collective communication was not a bottleneck for their allreduce implementation. Their scheme consists of three phases:

  1. For each server, buffers from its GPUs are summed into a single buffer.

  2. The resulting buffers are shared and summed across all servers.

  3. The final results are broadcast to each GPU.
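
The three phases listed above can be mimicked in a single process with nested lists. This is only a toy sketch of the data flow: the function name and the layout `buffers[server][gpu]` are assumptions, and neither NCCL nor the network is modeled.

```python
def hierarchical_allreduce(buffers):
    """buffers[s][g] is the gradient buffer (list of floats) on GPU g of server s."""
    # Phase 1: sum the GPU buffers within each server into one buffer per server.
    per_server = [[sum(vals) for vals in zip(*gpus)] for gpus in buffers]
    # Phase 2: sum the per-server buffers across all servers.
    total = [sum(vals) for vals in zip(*per_server)]
    # Phase 3: broadcast the final result to every GPU on every server.
    return [[list(total) for _ in gpus] for gpus in buffers]

buffers = [
    [[1.0, 2.0], [3.0, 4.0]],  # server 0, two GPUs
    [[5.0, 6.0], [7.0, 8.0]],  # server 1, two GPUs
]
result = hierarchical_allreduce(buffers)
print(result)
```

Every GPU ends up with the same elementwise sum, and only one buffer per server takes part in the interserver phase.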

Phases (1) and (3) are handled by NCCL, while the interserver allreduce in phase (2) uses the recursive halving and doubling algorithm instead of the bucket (ring) algorithm. Given a buffer of \(b\) bytes and a cluster of \(p\) servers, both algorithms send and receive \(2 \frac{p - 1}{p} b\) bytes of data per server. Halving/doubling takes \(2 \log_2 p\) communication steps while the ring algorithm requires \(2 (p - 1)\) steps. A generalized version of the halving/doubling algorithm that supports non-power-of-two cluster sizes is the binary blocks algorithm: servers are partitioned into power-of-two blocks and two additional communication steps are used, one immediately after the intrablock reduce-scatter and one before the intrablock allgather.
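
The cost formulas for the two interserver algorithms are easy to tabulate. The sketch below only evaluates those formulas. The function name is hypothetical, it is restricted to power-of-two \(p\), and the 100 MB buffer size is an illustrative assumption.

```python
def allreduce_cost(p, b):
    """Bytes sent per server and step counts for p servers and a b-byte buffer."""
    if p < 1 or (p & (p - 1)) != 0:
        raise ValueError("this sketch assumes p is a power of two")
    log2_p = p.bit_length() - 1  # exact log2 for a power of two
    return {
        "bytes_per_server": 2 * (p - 1) / p * b,  # same for both algorithms
        "halving_doubling_steps": 2 * log2_p,
        "ring_steps": 2 * (p - 1),
    }

# 32 servers and a hypothetical 100 MB gradient buffer.
cost = allreduce_cost(32, 100_000_000)
print(cost)
```

Both algorithms move the same volume of data, so the difference lies in the number of steps: 10 versus 62 for 32 servers. Fewer steps matter most when per-message latency is the dominant cost.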

A subsequent work [YGG17] demonstrates that gradual warmup with their LARS (Layer-wise Adaptive Rate Scaling) technique successfully scaled to \(kn = 16\text{K}\). They used Intel’s KNL processors (i.e. 1600 CPUs) to attain a top-1 accuracy of 75.3% in 31 minutes [YZH+17a]. However, scaling to \(kn = 32\text{K}\) incurred a loss of 0.5% in accuracy. As an aside, [YZH+17a] is a superset of [YZH+17b].

Contrary to the results of [YGG17], the experiments in [CPS17] illustrate that LARS is unnecessary when scaling to \(kn = 32\text{K}\). [CPS17] asserts that the classical 3-step 10-fold decrease is not as good as polynomial decay with a power of one for the learning rate decay schedule. Furthermore, the weight decay (regularization term) matters with a large learning rate. They set the weight decay to \(0.00005\) throughout the warmup and for the majority of the training phase. During the last 5-7% (4-7 epochs) of training, the learning rate is decayed with a power of two polynomial, the weight decay is doubled to \(0.0001\), and data augmentation is disabled. These techniques enabled ResNet-50 to be trained using \(kn = 32\text{K}\) without loss of accuracy. They achieved the highest top-1 accuracy with \(kn = 8\text{K}\) followed by \(kn = 16\text{K}\), both of which surpassed the previous state of the art. Interestingly, at \(kn = 8\text{K}\), ResNet-50 attained 73% accuracy after 5 warmup epochs and 24 training epochs. For \(kn \geq 16\text{K}\), a longer warmup period (7-10 epochs) is needed and a constant learning rate is used followed by linear decay. The accuracy degrades by 1% beyond \(kn = 32\text{K}\).
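
For reference, polynomial decay has a simple closed form. The sketch below is the generic textbook formula, not the exact schedule of [CPS17]: the function name, the `end_lr` parameter, and the example values are assumptions.

```python
def polynomial_decay(step, total_steps, base_lr, power=1.0, end_lr=0.0):
    # Interpolate from base_lr down to end_lr; power = 1 is linear decay,
    # power = 2 decays faster early on and flattens out toward the end.
    frac = min(max(step / total_steps, 0.0), 1.0)
    return (base_lr - end_lr) * (1.0 - frac) ** power + end_lr

# Illustrative values: base rate 3.2, queried halfway through 100 steps.
linear_mid = polynomial_decay(50, 100, 3.2, power=1.0)
quadratic_mid = polynomial_decay(50, 100, 3.2, power=2.0)
print(linear_mid, quadratic_mid)
```

Unlike the 3-step schedule, the rate changes at every step, so there are no abrupt drops.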

References

CPS17

Valeriu Codreanu, Damian Podareanu, and Vikram Saletore. Scale out for large minibatch sgd: residual network training on imagenet-1k with improved accuracy and reduced time to train. arXiv preprint arXiv:1711.04291, 2017.

GDollarG+17

Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, and Kaiming He. Accurate, large minibatch sgd: training imagenet in 1 hour. arXiv preprint arXiv:1706.02677, 2017.

WVJ94

Changfeng Wang, Santosh S Venkatesh, and J Stephen Judd. Optimal stopping and effective machine complexity in learning. In Advances in neural information processing systems, 303–310. 1994.

YGG17

Yang You, Igor Gitman, and Boris Ginsburg. Large batch training of convolutional networks. arXiv preprint, 2017.

YZH+17a

Yang You, Zhao Zhang, C Hsieh, James Demmel, and Kurt Keutzer. Imagenet training in minutes. CoRR, abs/1709.05011, 2017.

YZH+17b

Yang You, Zhao Zhang, Cho-Jui Hsieh, James Demmel, and Kurt Keutzer. 100-epoch imagenet training with alexnet in 24 minutes. arXiv preprint, 2017.