On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima

Motivation(s)

The problem deep learning is trying to solve can be represented as

\[\min_{\mathbf{x} \in \mathbb{R}^n} f(\mathbf{x}) \triangleq \frac{1}{M} \sum_{i = 1}^M f_i(\mathbf{x})\]

where \(f_i\) is a loss function for data point \(i \in \{1, \ldots, M\}\) and \(\mathbf{x}\) is the vector of weights being optimized. One way to optimize this function is to use stochastic gradient descent (SGD):

\[\mathbf{x}_{k + 1} \gets \mathbf{x}_k - \alpha_k \left( \frac{1}{\left\vert B_k \right\vert} \sum_{i \in B_k} \nabla f_i(\mathbf{x}_k) \right)\]

where \(B_k \subset \{1, \ldots, M\}\) is the batch sampled from the data and \(\alpha_k\) is the learning rate. SGD and its variants are employed in a small-batch (SB) regime, where \(\left\vert B_k \right\vert \in \{32, 64, \ldots, 512 \} \ll M\). These methods can be interpreted as gradient descent using noisy gradients and have guarantees of

  • convergence to minimizers of strongly-convex functions and to stationary points for non-convex functions,

  • saddle-point avoidance, and

  • robustness to input data.
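
As a concrete illustration of the update rule above, here is a minimal numpy sketch of mini-batch SGD on a toy least-squares loss; the function names, toy loss, batch size, and learning rate are illustrative assumptions, not the paper's setup.

```python
import numpy as np

def sgd(grad_fi, x0, M, batch_size=64, lr=0.1, steps=500, seed=0):
    """Mini-batch SGD: x_{k+1} = x_k - lr * mean_{i in B_k} grad f_i(x_k)."""
    rng = np.random.default_rng(seed)
    x = x0.copy()
    for _ in range(steps):
        batch = rng.choice(M, size=batch_size, replace=False)  # B_k subset of {1, ..., M}
        g = np.mean([grad_fi(i, x) for i in batch], axis=0)    # noisy estimate of grad f(x)
        x = x - lr * g
    return x

# Toy loss: f_i(x) = 0.5 * (a_i^T x - b_i)^2, so grad f_i(x) = (a_i^T x - b_i) * a_i.
M, n = 1000, 10
rng = np.random.default_rng(1)
A_data, b = rng.normal(size=(M, n)), rng.normal(size=M)
x_min = sgd(lambda i, x: (A_data[i] @ x - b[i]) * A_data[i], np.zeros(n), M)
```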

Several attempts at increasing the batch size have resulted in poor generalization: models trained with large-batch (LB) methods often perform worse on the test set than those trained with small-batch methods. This generalization gap can be as high as 5% even for smaller networks.

Proposed Solution(s)

The authors observe that

The lack of generalization ability is due to the fact that LB methods tend to converge to sharp minimizers of the training function. These minimizers are characterized by a significant number of large positive eigenvalues in \(\nabla^2 f(\mathbf{x})\), and tend to generalize less well. In contrast, SB methods converge to flat minimizers characterized by numerous small eigenvalues of \(\nabla^2 f(\mathbf{x})\). We have observed that the loss function landscape of deep neural networks is such that LB methods are attracted to regions with sharp minimizers and, unlike SB methods, are unable to escape the basins of attraction of these minimizers.

Intuitively, a flat minimizer can be described with low precision, whereas a sharp minimizer requires high precision. Thus, one explanation for the generalization gap comes from minimum description length (MDL) theory, which states that statistical models requiring fewer bits to describe (i.e. of lower complexity) generalize better.

Given the prohibitive cost of computing the eigenvalues, the authors propose a \((\mathcal{C}_\epsilon, \mathbf{A})\)-sharpness measure at a given local minimizer. Here \(\mathcal{C}_\epsilon\) denotes a box around the solution over which the maximization of \(f\) is performed, and \(\mathbf{A} \in \mathbb{R}^{n \times p}\) represents a random manifold. The columns of \(\mathbf{A}\) are randomly generated, and \(p\) determines the dimension of the manifold. To make the measure invariant to problem dimension and sparsity, define the constraint set \(\mathcal{C}_\epsilon\) as

\[\mathcal{C}_\epsilon = \left\{ \mathbf{z} \in \mathbb{R}^p \colon -\epsilon \left( \left\vert \left( \mathbf{A}^\dagger \mathbf{x} \right)_i \right\vert + 1 \right) \leq z_i \leq \epsilon \left( \left\vert \left( \mathbf{A}^\dagger \mathbf{x} \right)_i \right\vert + 1 \right) \quad \forall i \in \left\{ 1, \ldots, p \right\} \right\}\]

where \(\mathbf{A}^\dagger\) indicates the pseudo-inverse of \(\mathbf{A}\), and \(\epsilon > 0\) controls the size of the box. The \((\mathcal{C}_\epsilon, \mathbf{A})\)-sharpness of \(f\) at \(\mathbf{x} \in \mathbb{R}^n\) is defined as

\[\mathop{\phi}_{\mathbf{x}, f}\left( \epsilon, \mathbf{A} \right) = \frac{ \max_{\mathbf{y} \in \mathcal{C}_\epsilon} f(\mathbf{x} + \mathbf{A} \mathbf{y}) - f(\mathbf{x}) }{ 1 + f(\mathbf{x}) } \times 100.\]
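
As a rough, self-contained sketch of how this measure could be computed (assuming scipy is available; the inner maximization over \(\mathcal{C}_\epsilon\) uses a few L-BFGS-B iterations in the spirit of the paper's inexact maximization, and the toy quadratics at the end are purely illustrative):

```python
import numpy as np
from scipy.optimize import minimize

def sharpness(f, x, eps, A=None, maxiter=10, seed=0):
    """(C_eps, A)-sharpness of f at x: maximize f(x + A y) over the box C_eps."""
    n = x.size
    A = np.eye(n) if A is None else A           # A = I_n corresponds to the full space
    # Box half-widths: eps * (|(A^+ x)_i| + 1) for each manifold coordinate.
    half = eps * (np.abs(np.linalg.pinv(A) @ x) + 1.0)
    y0 = np.random.default_rng(seed).uniform(-half, half)   # random start inside the box
    res = minimize(lambda y: -f(x + A @ y), y0,              # maximize f by minimizing -f
                   method="L-BFGS-B", bounds=list(zip(-half, half)),
                   options={"maxiter": maxiter})
    f_x = f(x)
    return 100.0 * (-res.fun - f_x) / (1.0 + f_x)

# A flatter quadratic yields a much smaller sharpness value than a sharper one.
f_flat, f_sharp = (lambda z: 0.5 * z @ z), (lambda z: 50.0 * z @ z)
x0 = np.zeros(5)
print(sharpness(f_flat, x0, eps=1e-3), sharpness(f_sharp, x0, eps=1e-3))
```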

Evaluation(s)

The authors reproduced the generalization gap for several network configurations over datasets such as MNIST, TIMIT, CIFAR-10, and CIFAR-100. The training and testing accuracy curves for the SB and LB methods demonstrate that the generalization gap is not due to over-fitting or over-training as commonly observed in statistics: those phenomena manifest themselves as a testing accuracy curve that peaks at some iterate and then decays as the model learns idiosyncrasies of the training data.

The authors empirically verified, for the full space \(\mathbf{A} = \mathbf{I}_n\) and for random subspaces of dimension \(p = 100\), that when \(\epsilon \in \left\{ 10^{-3}, 5 \times 10^{-4} \right\}\) the SB and LB regimes differ by at least an order of magnitude in sharpness. This view is reinforced by parametric plots showing that the LB minima are strikingly sharper than the SB minima along a one-dimensional manifold. The sharpness of the minimizer did not change with data augmentation or conservative training, even though the generalization gap decreased.

To help direct LB solutions toward flat minimizers, the authors investigated warm-starting the LB methods. Monitoring the ratio of \(\left\Vert \mathbf{x}^\ast_s - \mathbf{x}_0 \right\Vert_2\) to \(\left\Vert \mathbf{x}^\ast_l - \mathbf{x}_0 \right\Vert_2\) reveals that LB methods tend to be attracted to minimizers close to the starting point \(\mathbf{x}_0\), whereas SB methods move away and locate minimizers that are farther away. Both methods yield similar sharpness values near the initial point, but as the loss function decreases, the sharpness of the LB iterates rapidly increases, whereas for the SB method the sharpness stays relatively constant at first and then decreases. This suggests SB methods have an exploration phase followed by convergence to a flat minimizer. Furthermore, both sharp and flat minimizers have very similar loss function values. Thus, after the SB method has ended its exploration phase and discovered a flat minimizer, the LB method is able to converge towards it, leading to good testing accuracy.

Future Direction(s)

  • When fine-tuning a network, are LB methods safe to use?

Question(s)

  • Would reinitializing the dimensions that exhibit high sharpness enable LB training?

Analysis

Warm-started large-batch training is able to close the generalization gap. Here warm-starting means the training starts with small batches for a few epochs and then transitions to large batches. On a side note, [WM03] and [DPBB17] may seem relevant, but do not bother reading them. Each offers very little insight and is essentially a convoluted rehash of several papers.
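
A minimal sketch of what such a warm-start schedule might look like; the epoch counts and batch sizes are illustrative placeholders, not values from the paper.

```python
def warm_start_batch_schedule(total_epochs, warm_epochs=5,
                              small_batch=256, large_batch=8192):
    """Batch size per epoch: a few SB epochs for the exploration phase,
    then LB epochs once the iterates have settled near a flat region."""
    return [small_batch if epoch < warm_epochs else large_batch
            for epoch in range(total_epochs)]

# warm_start_batch_schedule(8, warm_epochs=3)
# -> [256, 256, 256, 8192, 8192, 8192, 8192, 8192]
```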

The paper would be even more interesting if the authors had looked into why robust training and adversarial training had zero effect on the generalization gap. Subsequent work has tried to use the sharpness measure as a way to capture generalization behavior [NBMS17]. However, sharpness is sensitive to the scaling of the parameters and is not a capacity control measure on its own, as it can be artificially changed by rescaling the network. The statistical capacity of a model class can be thought of in terms of the number of examples required to ensure generalization, i.e. that the population (or test) error is close to the training error, even when minimizing the training error. This roughly corresponds to the maximum number of examples on which one can obtain a small training error even with random labels. When sharpness is considered in the context of a PAC-Bayes analysis, combined with a norm of the weights, the empirical results indicate it could suffice as a complexity measure.
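
To see why sharpness by itself is not scale-invariant, note that in a two-layer ReLU network the reparameterization \((\mathbf{W}_1, \mathbf{w}_2) \to (a \mathbf{W}_1, \mathbf{w}_2 / a)\) leaves the function, and hence the training error, unchanged while stretching the loss surface. A tiny numpy check of this well-known observation (the network shape and values are arbitrary):

```python
import numpy as np

relu = lambda z: np.maximum(z, 0.0)

# f(x) = w2 . relu(W1 x): rescaling (W1, w2) -> (a*W1, w2/a) gives the same function,
# yet any sharpness measure on the raw weights can be driven up or down by choosing a.
rng = np.random.default_rng(0)
W1, w2, x = rng.normal(size=(4, 3)), rng.normal(size=4), rng.normal(size=3)
a = 100.0
print(np.allclose(w2 @ relu(W1 @ x), (w2 / a) @ relu((a * W1) @ x)))  # True
```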

One of the interesting points is that the sharpness measure does not increase rapidly along all (or even most) directions. The authors observed that it rises steeply only along a low-dimensional subspace (e.g. 5% of the whole space), and in most other directions it is relatively flat. In addition, the sharpness measure is closely related to the spectrum of \(\nabla^2 f(\mathbf{x})\). When \(\epsilon\) is small enough, the largest eigenvalue of \(\nabla^2 f(\mathbf{x})\) is related to \(\phi\) for the full space, and for the random subspaces, \(\phi\) approximates a Ritz value of \(\nabla^2 f(\mathbf{x})\) projected onto the column space of \(\mathbf{A}\).
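
A rough sketch of that connection, assuming the gradient vanishes at the minimizer, loosely replacing the box \(\mathcal{C}_\epsilon\) with a Euclidean ball of radius \(\epsilon\), and taking \(\mathbf{A}\) to have orthonormal columns: a second-order Taylor expansion gives

\[f(\mathbf{x} + \mathbf{A} \mathbf{y}) - f(\mathbf{x}) \approx \frac{1}{2} \mathbf{y}^\top \mathbf{A}^\top \nabla^2 f(\mathbf{x}) \mathbf{A} \mathbf{y}, \qquad \max_{\left\Vert \mathbf{y} \right\Vert_2 \leq \epsilon} \frac{1}{2} \mathbf{y}^\top \mathbf{A}^\top \nabla^2 f(\mathbf{x}) \mathbf{A} \mathbf{y} = \frac{\epsilon^2}{2} \lambda_{\max}\left( \mathbf{A}^\top \nabla^2 f(\mathbf{x}) \mathbf{A} \right),\]

so that

\[\mathop{\phi}_{\mathbf{x}, f}\left( \epsilon, \mathbf{A} \right) \approx \frac{\epsilon^2 \lambda_{\max}\left( \mathbf{A}^\top \nabla^2 f(\mathbf{x}) \mathbf{A} \right)}{2 \left( 1 + f(\mathbf{x}) \right)} \times 100,\]

which for \(\mathbf{A} = \mathbf{I}_n\) is proportional to the largest eigenvalue of the Hessian, and for a random subspace to the largest eigenvalue of the Hessian restricted to the column space of \(\mathbf{A}\) (a Ritz value).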

A later work [HHS17] proposes an alternative model to explain the generalization gap phenomenon. They argue that the initial learning phase can be described as a high-dimensional “random walk on a random potential” process, with an “ultra-slow” logarithmic increase in the distance of the weights from their initialization. During this initial phase, to reach a minimum of “width” \(d\), the weight vector \(\mathbf{w}_t\) has to travel at least a distance \(d\), and since that distance grows only logarithmically in the number of iterations, doing so takes on the order of \(\exp(d)\) iterations. Thus, to reach wide (“flat”) minima, a high diffusion rate and a large number of training iterations (weight updates) are necessary.

Their experiments indicate that the distance between the current weights and the initialization point can be a good measure for deciding when to decrease the learning rate. This is in contrast to current practice, where the learning rate is decreased once the validation error appears to plateau, for fear of overfitting. [HHS17] asserts that one should instead continue the optimization at the same learning rate even when the training error decreases while the validation error plateaus. They propose a regime-adaptation scheme in which each time period of \(e\) epochs in the original regime is scaled by \(\frac{B_L}{B_S}\). Although this scheme yields better accuracy, training for more epochs defeats the purpose of scaling up batch sizes. The other proposals, such as square-root learning-rate scaling \(\sqrt{\frac{B_L}{B_S}}\) (see derivation in the appendix) and Ghost Batch Normalization (GBN), were evaluated using a batch size no larger than 4096. Furthermore, the assessment lacked any comparisons against batch normalization and smoothed layer normalization.
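
A small sketch of these hyper-parameter adaptation rules; the helper function and the example numbers are hypothetical, with only the scaling factors taken from [HHS17].

```python
import math

def scale_hyperparams(lr_small, epochs_small, b_small, b_large, rule="sqrt"):
    """Adapt SB hyper-parameters when moving to a larger batch size B_L.

    - "linear": scale the learning rate by B_L / B_S.
    - "sqrt":   scale the learning rate by sqrt(B_L / B_S).
    Regime adaptation: stretch each phase of the schedule by B_L / B_S so the
    number of weight updates (and hence the diffusion distance) is preserved.
    """
    ratio = b_large / b_small
    lr_factor = ratio if rule == "linear" else math.sqrt(ratio)
    return lr_small * lr_factor, int(round(epochs_small * ratio))

print(scale_hyperparams(0.1, 30, 256, 4096))  # -> (0.4, 480)
```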

References

DPBB17

Laurent Dinh, Razvan Pascanu, Samy Bengio, and Yoshua Bengio. Sharp minima can generalize for deep nets. arXiv preprint arXiv:1703.04933, 2017.

HHS17

Elad Hoffer, Itay Hubara, and Daniel Soudry. Train longer, generalize better: closing the generalization gap in large batch training of neural networks. In Advances in Neural Information Processing Systems, 1729–1739. 2017.

KMN+16

Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, and Ping Tak Peter Tang. On large-batch training for deep learning: generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.

NBMS17

Behnam Neyshabur, Srinadh Bhojanapalli, David McAllester, and Nati Srebro. Exploring generalization in deep learning. In Advances in Neural Information Processing Systems, 5949–5958. 2017.

WM03

D Randall Wilson and Tony R Martinez. The general inefficiency of batch training for gradient descent learning. Neural Networks, 16(10):1429–1451, 2003.