Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

Motivation(s)

Stochastic gradient descent is an effective way of training deep networks, assuming the learning rate schedule and the initial values for the model parameters are appropriately tuned. Tuning these hyperparameters is further complicated by the fact that the distribution of each layer's inputs changes during training, even though the training and test data come from the same distribution. This change in the distribution of network activations due to the change in network parameters during training is called internal covariate shift.

To combat internal covariate shift, one solution is to linearly transform the inputs to each layer so that they are decorrelated and have zero mean and unit variance. Consider a simplified layer \(x = u + b\) with input \(u\) and learned bias \(b\), normalized by subtracting the mean: \(\newcommand{\E}[2][]{\mathop{\mathrm{E}_{#1}}\left[ #2 \right]} \hat{x} = x - \E{x}\), where \(\mathcal{X} = \left\{ x_1, \ldots, x_N \right\}\) is the set of values of \(x\) over the training set and \(\E{x} = \frac{1}{N} \sum_{i = 1}^N x_i\). The corresponding gradient descent update is \(\Delta b = -\frac{\partial l}{\partial b}\), and the next iteration yields

\[u + (b + \Delta b) - \E{u + \left( b + \Delta b \right)} = x - \E{x} + \Delta b - \E{\Delta b}.\]

Suppose the gradient descent step ignores the dependence of \(\E{x}\) on \(b\) and defines

\[\frac{\partial l}{\partial b} = \frac{\partial l}{\partial x} \frac{\partial x}{\partial b} = \frac{\partial l}{\partial x} \qquad \text{instead of} \qquad \frac{\partial l}{\partial b} = \frac{\partial l}{\partial \hat{x}} \frac{\partial \hat{x}}{\partial b} = \frac{\partial l}{\partial \hat{x}} \left( 1 - \frac{\partial \E{x}}{\partial b} \right).\]

Because \(\Delta b\) is a constant shared by every training example,

\[\E{\Delta b} = \E{-\frac{\partial l}{\partial b}} = \Delta b,\]

so the update to \(b\) and the subsequent change in normalization cancel, leaving the output of the layer, and hence the loss, unchanged. (Had the dependence been accounted for, \(\frac{\partial \E{x}}{\partial b} = 1\) would make the gradient vanish and \(b\) would not move at all.)

As training continues, \(b\) grows indefinitely while the loss remains fixed. To address this issue when \(\hat{\mathbf{x}} = f(\mathbf{x}, \mathcal{X})\) is an arbitrary normalizing transformation, backpropagation needs to account for both Jacobians \(\frac{\partial \hat{\mathbf{x}}}{\partial \mathbf{x}}\) and \(\frac{\partial \hat{\mathbf{x}}}{\partial \mathcal{X}}\) after every parameter update, which is prohibitively expensive when the transformation whitens over the entire training set.
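This pathology is easy to reproduce. The toy simulation below (my own, not from the paper) stands in for the naive gradient \(\frac{\partial l}{\partial b}\) with a constant; the normalized output never changes while \(b\) drifts without bound:

    import numpy as np

    rng = np.random.default_rng(0)
    u = rng.normal(size=1000)      # fixed layer inputs over the training set
    b = 0.0                        # learned bias
    lr = 0.1

    for step in range(5):
        x = u + b
        x_hat = x - x.mean()       # normalization by mean subtraction
        # Stand-in for the naive dl/db that treats E[x] as a constant;
        # any nonzero constant exhibits the same behavior.
        grad_b = -1.0
        b -= lr * grad_b           # b grows by lr at every step
        # x_hat equals u - mean(u) regardless of b (up to float error)
        print(step, round(b, 2), np.abs(x_hat - (u - u.mean())).max())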

Proposed Solution(s)

Instead of whitening the features in the layer inputs and outputs jointly, the authors propose batch normalization (BN): independently transform each scalar feature to have zero mean and unit variance. Rather than using the entire training set \(\mathcal{X} = \left\{ \mathbf{x}_1, \ldots, \mathbf{x}_N \right\}\) to produce the desired statistics, each mini-batch \(\mathcal{B}_j = \left\{ \mathbf{x}_{jm + 1}, \ldots, \mathbf{x}_{jm + m} \right\}\) of size \(m\) generates estimates of the sample mean and biased sample variance of each activation.

For a layer with input \(\mathbf{x} = \left( x^{(1)}, \ldots, x^{(d)} \right)\), normalize each dimension

\[\newcommand{\Var}[2][]{\mathop{\mathrm{Var}_{#1}}\left[ #2 \right]} \hat{x}^{(k)} = \frac{ x^{(k)} - \E[\mathcal{B}_j]{x^{(k)}} }{ \sqrt{\Var[\mathcal{B}_j]{x^{(k)}} + \epsilon} }.\]

To ensure the transformation can represent the identity mapping, the authors introduce a pair of learned parameters for each activation \(x^{(k)}\) to scale and shift the normalized value such that

\[y^{(k)} = \gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}.\]
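A minimal NumPy sketch of this training-time transform (variable names are my own): each of the \(d\) dimensions is normalized over a mini-batch of size \(m\), then scaled and shifted by the learned parameters:

    import numpy as np

    def batch_norm_train(x, gamma, beta, eps=1e-5):
        """x: (m, d) mini-batch; gamma, beta: (d,) learned parameters."""
        mu = x.mean(axis=0)                    # per-dimension sample mean
        var = x.var(axis=0)                    # per-dimension biased variance
        x_hat = (x - mu) / np.sqrt(var + eps)  # zero mean, unit variance
        return gamma * x_hat + beta            # restore representational power

    x = np.random.default_rng(0).normal(3.0, 2.0, size=(32, 8))
    y = batch_norm_train(x, gamma=np.ones(8), beta=np.zeros(8))
    print(y.mean(axis=0).round(6), y.var(axis=0).round(6))  # ~0 and ~1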

BN makes backpropagation through a layer invariant to the scale of its parameters. Such normalization has been demonstrated to speed up convergence [LeCunBOMuller98]:

  • A nonzero mean in the input variables creates a very large eigenvalue in the input covariance matrix (hence a large condition number), which corresponds to the inputs varying much more along some directions of the input space than others.

  • If the inputs are correlated, enforcing unit variance will not make the error surface spherical, but it will reduce its eccentricity.

To make the inference step deterministic, the normalization step for a particular activation

\[\hat{x}^{(k)} = \frac{ x^{(k)} - \E[\mathcal{X}]{x^{(k)}} }{ \sqrt{\Var[\mathcal{X}]{x^{(k)}} + \epsilon} }\]

uses the entire training population statistics. Since the true population mean is unknown, it is estimated as the sample mean

\[\begin{split}\E[\mathcal{X}]{x^{(k)}} &\gets \frac{1}{N} \sum_i x_i^{(k)}\\ &= \frac{1}{N / m} \sum_j \frac{1}{m} \sum_{\mathbf{x}_i \in \mathcal{B}_j} x_i^{(k)}\\ &= \frac{1}{N / m} \sum_j \mu_{\mathcal{B}_j}.\end{split}\]

Given the sample mean, the unbiased sample variance can be obtained using Bessel’s correction factor \(\frac{m}{m - 1}\). Assuming the underlying variance of each mini-batch is the same, it is estimated as the pooled variance

\[\begin{split}\Var[\mathcal{X}]{x^{(k)}} &\gets \frac{ \sum_j \left( \left\vert \mathcal{B}_j \right\vert - 1 \right) \frac{m}{m - 1} \Var[\mathcal{B}_j]{x^{(k)}} }{ \sum_j \left( \left\vert \mathcal{B}_j \right\vert - 1 \right) }\\ &= \frac{1}{N / m} \sum_j \frac{m}{m - 1} \sigma_{\mathcal{B}_j}^2.\end{split}\]
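The aggregation above is straightforward to implement; the sketch below (my own illustration, assuming every mini-batch has the same size \(m\)) recovers the population estimates from the saved mini-batch moments:

    import numpy as np

    def population_stats(batch_means, batch_vars, m):
        """batch_means, batch_vars: per-activation moments of each mini-batch,
        each an array of shape (d,); all mini-batches have size m."""
        mean = np.mean(batch_means, axis=0)                # mean of batch means
        var = (m / (m - 1)) * np.mean(batch_vars, axis=0)  # pooled, Bessel-corrected
        return mean, var

    rng = np.random.default_rng(0)
    m, d = 32, 8
    batches = [rng.normal(3.0, 2.0, size=(m, d)) for _ in range(100)]
    mean, var = population_stats([b.mean(axis=0) for b in batches],
                                 [b.var(axis=0) for b in batches], m)
    print(mean.round(2), var.round(2))  # close to the true 3.0 and 4.0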

Evaluation(s)

The authors studied the evolution of input distributions over the course of training on the MNIST dataset. BN helps the network train faster and achieve higher accuracy by keeping those distributions more stable, thereby reducing internal covariate shift.

For a large dataset like ImageNet, adding BN on top of Inception yielded modest benefits. In order to reduce the number of epochs by an order of magnitude while increasing accuracy by three percent, the network and its hyperparameters need to be modified as follows:

  • Increase the learning rates.

  • Remove dropout.

  • Reduce \(L_2\) weight regularization.

  • Accelerate the learning rate decay.

  • Remove Local Response Normalization.

  • Shuffle training examples more thoroughly to prevent the same examples from always appearing in a mini-batch together.

Future Direction(s)

  • Does Bessel’s correction matter, or will the learned weights account for it?

  • What happens to a network’s convergence when larger weights lead to smaller gradients?

  • The authors conjecture that BN may lead the layer Jacobians to have singular values close to one. How?

Question(s)

  • To what extent did dropout improve the accuracy of Modified BN-Inception in single-network classification?

  • Since a Gaussian variable has the largest entropy among all random variables of equal variance, how did the kurtosis and negentropy of each dimension compare in terms of nongaussianity?

  • I am not convinced of interpreting batch renormalization as a null-space projection. Since \(r\) and \(d\) gradually increase over many training steps, batch renormalization is essentially scaling the gradients. A more reasonable hypothesis is that this hack tries to imitate how neurons amplify the signals they receive.

Analysis

BN is an effective technique to reduce internal covariate shift and achieve a stable distribution of activation values throughout training. Its effectiveness diminishes when the training mini-batches are small or do not consist of independent samples [Iof17]. For small mini-batches, the estimates of the mean and variance become less accurate. Furthermore, it is common to bias the mini-batch sampling toward sets of examples that are known to be related. As a result, the model learns to exploit the fact that each image in such a mini-batch has a counterpart with the same label, and this does not translate to classifying images individually at inference.

During training, the upper layers see inputs normalized with mini-batch statistics, whereas at inference those inputs are normalized with the population statistics, so the upper layers operate on representations different from the ones they were trained on. A naive solution is to use the moving averages \(\mu\) and \(\sigma^2\) to perform the normalization during training, but this causes the model parameters to blow up because the gradient optimization and the normalization counteract each other. To handle the preceding issues, [Iof17] proposes batch renormalization, an extension to BN that allows a model with a batch size smaller than 32 to train faster and achieve higher accuracy. However, the dependencies between training cases still make batch renormalization inappropriate for recurrent neural networks and online learning tasks. Moreover, it remains uncertain whether a batch size of 32 will itself prove problematic in the future.

[IS15] asserts that BN should be applied before the nonlinearity since that is where matching the first and second moments is more likely to result in a stable distribution. Yet the authors' actual implementation oddly applies BN after the ReLU. In addition, the authors claim that dropout can be removed or weighted less, but only present evidence for the latter. They should have reported the accuracy of their top ensemble network without dropout because its margin over the previous state of the art is barely 1%.

An alternative technique one should be wary of is weight normalization (WN). WN decouples the direction of the weights from their norm by reparameterizing the optimization problem [SK16]. Even though WN trains faster and reaches a higher training accuracy, its final test accuracy is significantly lower than that of BN [GG17]. The CIFAR-10 results presented in [SK16] are of little practical significance because the seventeen-layer network used there does not need normalization to attain a final accuracy comparable to that of a normalized network. A more serious issue is WN's assumption that its weight matrices are approximately orthogonal. This is invalid because each gradient update increases the correlations between different neurons. Consequently, WN is prone to overfitting even with dropout and weight decay.
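For concreteness, a minimal sketch of the WN reparameterization \(\mathbf{w} = \frac{g}{\lVert \mathbf{v} \rVert} \mathbf{v}\) described in [SK16] (the code itself is my own):

    import numpy as np

    def weight_norm(v, g):
        """Reparameterize a weight vector: the norm g and the direction
        v / ||v|| are optimized separately via the chain rule."""
        return g * v / np.linalg.norm(v)

    v = np.array([3.0, 4.0])
    print(np.linalg.norm(weight_norm(v, g=2.0)))  # exactly g, regardless of v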

Notes

Batch Normalizing Transform

Given a mini-batch \(\mathcal{B} = \left\{ x_1, \ldots, x_m \right\}\) for a particular activation, the normalizing transform consists of

\[\begin{split}\DeclareMathOperator{\BN}{\mathrm{BN}} \mu_{\mathcal{B}} &\gets \frac{1}{m} \sum_{i = 1}^m x_i\\ \sigma_{\mathcal{B}} &\gets \sqrt{\epsilon + m^{-1} \sum_{i = 1}^m (x_i - \mu_{\mathcal{B}})^2}\\ \hat{x}_i &\gets \frac{x_i - \mu_{\mathcal{B}}}{\sigma_{\mathcal{B}}}\\ y_i &\gets \gamma \hat{x}_i + \beta \equiv \BN_{\gamma, \beta}(x_i)\end{split}\]

where \(\epsilon\) is a constant. The gradients with respect to the parameters of the BN transform are

\[\begin{split}\frac{\partial l}{\partial \gamma} &= \sum_{i = 1}^m \frac{\partial l}{\partial y_i} \frac{\partial y_i}{\partial \gamma} = \sum_i \frac{\partial l}{\partial y_i} \hat{x}_i \\\\ \frac{\partial l}{\partial \beta} &= \sum_{i = 1}^m \frac{\partial l}{\partial y_i} \frac{\partial y_i}{\partial \beta} = \sum_i \frac{\partial l}{\partial y_i} \\\\ \frac{\partial l}{\partial \sigma_{\mathcal{B}}} &= \sum_{i = 1}^m \frac{\partial l}{\partial \hat{x}_i} \frac{\partial \hat{x}_i}{\partial \sigma_{\mathcal{B}}} = \sum_{i = 1}^m \frac{\partial l}{\partial \hat{x}_i} (x_i - \mu_{\mathcal{B}}) \frac{-1}{\sigma_{\mathcal{B}}^2} \\\\ \frac{\partial l}{\partial \hat{x}_i} &= \frac{\partial l}{\partial y_i} \frac{\partial y_i}{\partial \hat{x}_i} = \frac{\partial l}{\partial y_i} \gamma \\\\ \frac{\partial l}{\partial \mu_{\mathcal{B}}} &= \sum_{i = 1}^m \frac{\partial l}{\partial \hat{x}_i} \frac{\partial \hat{x}_i}{\partial \mu_{\mathcal{B}}} = \sum_{i = 1}^m \frac{\partial l}{\partial \hat{x}_i} \frac{-1}{\sigma_{\mathcal{B}}}\end{split}\]
\[\begin{split}\frac{\partial \hat{x}_i}{\partial \mu_{\mathcal{B}}} &= \frac{-1}{\sigma_{\mathcal{B}}} + \left( x_i - \mu_{\mathcal{B}} \right) \frac{-1}{\sigma_{\mathcal{B}}^2} \frac{1}{2 \sigma_{\mathcal{B}}} \frac{2}{m} \sum_{j = 1}^m (x_j - \mu_{\mathcal{B}}) (-1)\\ &= \frac{-1}{\sigma_{\mathcal{B}}} + \frac{\partial \hat{x}_i}{\partial \sigma_{\mathcal{B}}} \frac{1}{\sigma_{\mathcal{B}}} \left( \mu_{\mathcal{B}} - \frac{1}{m} \sum_{j = 1}^m x_j \right)\\ &= \frac{-1}{\sigma_{\mathcal{B}}}\end{split}\]
\[\begin{split}\frac{\partial l}{\partial x_i} &= \sum_{j = 1}^m \frac{\partial l}{\partial y_j} \frac{\partial y_j}{\partial \hat{x}_j} \frac{\partial \hat{x}_j}{\partial x_i}\\ &= \sum_{j = 1}^m \frac{\partial l}{\partial \hat{x}_j} \left[ \left( \mathbb{I}_i(j) - \frac{1}{m} \right) \sigma_{\mathcal{B}}^{-1} + (x_j - \mu_{\mathcal{B}}) \frac{-1}{\sigma_{\mathcal{B}}^2} \frac{1}{2 \sigma_{\mathcal{B}}} \frac{2}{m} (x_i - \mu_{\mathcal{B}}) \right]\\ &= \frac{\partial l}{\partial \hat{x}_i} \frac{1}{\sigma_{\mathcal{B}}} + \frac{\partial l}{\partial \mu_{\mathcal{B}}} \frac{1}{m} + \frac{\partial l}{\partial \sigma_{\mathcal{B}}} \frac{x_i - \mu_{\mathcal{B}}}{m \sigma_{\mathcal{B}}}\end{split}\]
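These formulas translate directly into a backward pass. The sketch below (my own, not a reference implementation) computes the gradients for a single activation over a mini-batch, and could be verified with a finite-difference check against the forward transform:

    import numpy as np

    def batch_norm_backward(dl_dy, x, gamma, eps=1e-5):
        """dl_dy, x: shape (m,) for one activation; gamma: scalar."""
        m = x.shape[0]
        mu = x.mean()
        sigma = np.sqrt(eps + x.var())          # matches the definition above
        x_hat = (x - mu) / sigma

        dl_dgamma = np.sum(dl_dy * x_hat)
        dl_dbeta = np.sum(dl_dy)
        dl_dxhat = dl_dy * gamma
        dl_dsigma = np.sum(dl_dxhat * (x - mu) * (-1.0 / sigma**2))
        dl_dmu = np.sum(dl_dxhat * (-1.0 / sigma))
        dl_dx = (dl_dxhat / sigma
                 + dl_dmu / m
                 + dl_dsigma * (x - mu) / (m * sigma))
        return dl_dx, dl_dgamma, dl_dbeta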

Batch-Normalized Convolutional Networks

Recall that fully-connected and convolutional layers can be formulated as an affine transformation followed by an element-wise nonlinearity:

\[\mathbf{z} = \mathop{g}(\mathbf{x}) = \mathop{g}(\mathbf{W} \mathbf{u} + \mathbf{b})\]

where \(\mathbf{W}\) and \(\mathbf{b}\) are learned parameters of the model, and \(\mathop{g}\) is the nonlinearity. The BN transform is applied as

\[\BN(\mathbf{x}) = \BN(\mathbf{W} \mathbf{u} + \mathbf{b}) = \BN(\mathbf{W} \mathbf{u})\]

where \(\mathbf{b}\) is subsumed by \(\beta\). The authors reason that BN should be applied to \(\mathbf{x}\) instead of \(\mathbf{u}\) because the former is more Gaussian than the latter. This phenomenon follows from the Central Limit Theorem, which states that the distribution of a sum of independent random variables tends toward a Gaussian distribution [Lug][HyvarinenO00].

For convolutional layers, \(\mathcal{B}\) is the set of all values in a feature map across both the elements of a mini-batch and spatial locations. For a mini-batch of size \(m\) and feature maps of size \(p \times q\), the effective mini-batch size is \(m' = \left\vert \mathcal{B} \right\vert = mpq\). The pair of parameters \(\gamma^{(k)}\) and \(\beta^{(k)}\) is now per feature map rather than per activation.
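A sketch of the convolutional case (my own, assuming an NCHW layout): the moments are taken jointly over the batch and spatial axes, so each of the \(c\) feature maps gets a single mean, variance, \(\gamma\), and \(\beta\):

    import numpy as np

    def batch_norm_conv(x, gamma, beta, eps=1e-5):
        """x: (m, c, p, q) feature maps; gamma, beta: (c,) per feature map."""
        mu = x.mean(axis=(0, 2, 3), keepdims=True)  # effective batch size m*p*q
        var = x.var(axis=(0, 2, 3), keepdims=True)
        x_hat = (x - mu) / np.sqrt(var + eps)
        return gamma.reshape(1, -1, 1, 1) * x_hat + beta.reshape(1, -1, 1, 1)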

Batch Renormalization

Let \(\mu\) be an estimate of the mean of \(x\), and \(\sigma\) be an estimate of its standard deviation, computed as

\[\mu \gets \mu + \alpha (\mu_{\mathcal{B}} - \mu) \qquad \text{and} \qquad \sigma \gets \sigma + \alpha (\sigma_{\mathcal{B}} - \sigma)\]

where \(\alpha\) is the decay rate of the exponential moving averages. The result of normalizing \(x\) using the mini-batch statistics is related to its moving averages by an affine transform

\[\hat{x} \gets \frac{x - \mu_{\mathcal{B}}}{\sigma_{\mathcal{B}}} r + d = \frac{x - \mu}{\sigma} \qquad \text{where} \qquad r = \frac{\sigma_{\mathcal{B}}}{\sigma}, d = \frac{\mu_{\mathcal{B}} - \mu}{\sigma}.\]

Batch normalization is the special case \(r = 1\) and \(d = 0\). Renormalization instead computes \(r\) and \(d\) from the current statistics and treats them as constants during backpropagation. The scheme used in [Iof17] clips them as

\[\begin{split}\DeclareMathOperator{\stopgrad}{\text{stop_gradient}} \newcommand{\clip}[2][]{\mathop{\mathrm{clip}_{#1}}\left( #2 \right)} r &\gets \stopgrad\left( \clip[1 / r_\max, r_\max]{\frac{\sigma_{\mathcal{B}}}{\sigma}} \right)\\ d &\gets \stopgrad\left( \clip[-d_\max, d_\max]{\frac{\mu_{\mathcal{B}} - \mu}{\sigma}} \right).\end{split}\]

Here \(r_\max = 1\) and \(d_\max = 0\) initially, which recovers plain BN, and each is then gradually relaxed on its own schedule, akin to a learning rate schedule. The gradients with respect to the parameters of batch renormalization are

\[\begin{split}\frac{\partial l}{\partial \sigma_{\mathcal{B}}} &= \sum_{i = 1}^m \frac{\partial l}{\partial \hat{x}_i} \frac{\partial \hat{x}_i}{\partial \sigma_{\mathcal{B}}} = \sum_{i = 1}^m \frac{\partial l}{\partial \hat{x}_i} (x_i - \mu_{\mathcal{B}}) \frac{-r}{\sigma_{\mathcal{B}}^2} \\\\ \frac{\partial l}{\partial \mu_{\mathcal{B}}} &= \sum_{i = 1}^m \frac{\partial l}{\partial \hat{x}_i} \frac{\partial \hat{x}_i}{\partial \mu_{\mathcal{B}}} = \sum_{i = 1}^m \frac{\partial l}{\partial \hat{x}_i} \frac{-r}{\sigma_{\mathcal{B}}}\end{split}\]
\[\begin{split}\frac{\partial l}{\partial x_i} &= \sum_{j = 1}^m \frac{\partial l}{\partial \hat{x}_j} \left[ \left( \mathbb{I}_i(j) - \frac{1}{m} \right) \frac{r}{\sigma_{\mathcal{B}}} + (x_j - \mu_{\mathcal{B}}) \frac{-r}{\sigma_{\mathcal{B}}^2} \frac{1}{2 \sigma_{\mathcal{B}}} \frac{2}{m} (x_i - \mu_{\mathcal{B}}) \right]\\ &= \frac{\partial l}{\partial \hat{x}_i} \frac{r}{\sigma_{\mathcal{B}}} + \frac{\partial l}{\partial \mu_{\mathcal{B}}} \frac{1}{m} + \frac{\partial l}{\partial \sigma_{\mathcal{B}}} \frac{x_i - \mu_{\mathcal{B}}}{m \sigma_{\mathcal{B}}}\end{split}\]

The other gradients are the same as batch normalization, and the inference step now uses the moving averages instead of the population statistics.
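Putting the pieces together, a training-step sketch of batch renormalization (my own reading of [Iof17]; names are illustrative): \(r\) and \(d\) are computed from the current statistics, clipped, and held fixed in the backward pass, and the moving averages are updated alongside. With the initial \(r_\max = 1\) and \(d_\max = 0\), it reduces to plain BN.

    import numpy as np

    def batch_renorm_train_step(x, gamma, beta, mu, sigma, alpha=0.01,
                                r_max=1.0, d_max=0.0, eps=1e-5):
        """x: (m, d) mini-batch; mu, sigma: running moving averages (d,).
        Assumes r_max >= 1 so the clip bounds are ordered."""
        mu_b = x.mean(axis=0)
        sigma_b = np.sqrt(x.var(axis=0) + eps)
        # r and d correct the mini-batch statistics toward the moving averages;
        # no gradient flows through them (stop_gradient in [Iof17]).
        r = np.clip(sigma_b / sigma, 1.0 / r_max, r_max)
        d = np.clip((mu_b - mu) / sigma, -d_max, d_max)
        x_hat = (x - mu_b) / sigma_b * r + d
        y = gamma * x_hat + beta
        # exponentially-decayed moving averages (see above)
        mu = mu + alpha * (mu_b - mu)
        sigma = sigma + alpha * (sigma_b - sigma)
        return y, mu, sigma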

References

GG17

Igor Gitman and Boris Ginsburg. Comparison of batch normalization and weight normalization algorithms for the large-scale image classification. arXiv preprint arXiv:1709.08145, 2017.

HyvarinenO00

Aapo Hyvärinen and Erkki Oja. Independent component analysis: algorithms and applications. Neural networks, 13(4):411–430, 2000.

Iof17

Sergey Ioffe. Batch renormalization: towards reducing minibatch dependence in batch-normalized models. arXiv preprint arXiv:1702.03275, 2017.

IS15

Sergey Ioffe and Christian Szegedy. Batch normalization: accelerating deep network training by reducing internal covariate shift. In International Conference on Machine Learning, 448–456. 2015.

Lug

Michael Lugo. A proof of the central limit theorem. https://www.stat.berkeley.edu/~mlugo/stat134-f11/clt-proof.pdf. Accessed on 2017-12-25.

SK16

Tim Salimans and Diederik P Kingma. Weight normalization: a simple reparameterization to accelerate training of deep neural networks. In Advances in Neural Information Processing Systems, 901–909. 2016.