Elad Hoffer, Ron Banner et al. (2018)
Norm matters: efficient and accurate normalization schemes in deep networks

Paper Link: https://arxiv.org/pdf/1803.01814.pdf

Batch Normalization has been widely used over the years because it speeds up convergence and improves accuracy. However, the reasons behind its benefits have remained poorly understood, and it has several shortcomings that hinder its use in certain settings (time series, reinforcement learning, GANs).

Consequences of the scale invariance of Batch-Normalization

Denote a channel's weight vector by $w$ and its direction by $\hat{w} = w / ||w||_{2}$, the channel input by $x$, and batch normalization by $BN$. Then

$$BN(||w||_{2}\hat{w}x) = BN(\hat{w}x).$$
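The invariance can be checked numerically. The sketch below is a minimal illustration, assuming NumPy, a random mini-batch, and a simplified batch norm without the learnable scale/shift and with `eps = 0` (real frameworks add a small `eps > 0`, which makes the invariance approximate rather than exact). Here $\hat{w}x$ is read as the pre-activations `x @ w_hat` of one channel.

```python
import numpy as np

def batch_norm(z, eps=0.0):
    # Normalize one channel's pre-activations over the batch to zero mean, unit variance.
    # No learnable scale/shift; eps = 0 so the scale invariance is exact.
    return (z - z.mean()) / np.sqrt(z.var() + eps)

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 10))        # mini-batch: 64 samples, 10 input features
w = rng.normal(size=10)              # weight vector of a single channel
w_hat = w / np.linalg.norm(w)        # its direction, w / ||w||_2

reference = batch_norm(x @ w_hat)
for scale in [0.01, 1.0, 100.0]:
    out = batch_norm(x @ (scale * w_hat))
    # the maximum deviation is at floating-point rounding level for every scale
    print(scale, np.max(np.abs(out - reference)))
```

Rescaling the weights by any positive factor leaves the normalized outputs unchanged, so nothing downstream of BN can "see" $||w||_2$.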

This invariance to the weight-vector norm means that, when BN is applied after a layer, the norm of that layer's weights has no influence on the inputs to subsequent layers. The same holds for the weights of convolutional layers. The gradient, however, is scaled by $1/||w||_{2}$:

$$\frac{\partial BN(||w||_2 \hat{w}x)}{\partial (||w||_2 \hat{w})} = \frac{1}{||w||_2} \frac{\partial BN(\hat{w}x)}{\partial (\hat{w})}.$$
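Through the chain rule, the same $1/||w||_2$ factor appears in the gradient of any loss computed from the BN output. The sketch below checks this with central finite differences, so it does not rely on a hand-derived gradient. The squared-error loss on the BN output, the random targets `y`, and the helper names are hypothetical choices for illustration; NumPy and `eps = 0` are assumed.

```python
import numpy as np

def batch_norm(z):
    # zero-mean, unit-variance normalization over the batch (eps omitted)
    return (z - z.mean()) / z.std()

def loss(w, x, y):
    # hypothetical squared-error loss on the batch-normalized outputs
    return 0.5 * np.mean((batch_norm(x @ w) - y) ** 2)

def numerical_grad(f, w, h=1e-5):
    # central finite differences, one coordinate at a time
    grad = np.zeros_like(w)
    for i in range(w.size):
        e = np.zeros_like(w)
        e[i] = h
        grad[i] = (f(w + e) - f(w - e)) / (2 * h)
    return grad

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 10))
y = rng.normal(size=64)
w = 5.0 * rng.normal(size=10)        # deliberately far from unit norm
rho = np.linalg.norm(w)
w_hat = w / rho

g_w = numerical_grad(lambda v: loss(v, x, y), w)
g_w_hat = numerical_grad(lambda v: loss(v, x, y), w_hat)
scaled_match = np.allclose(g_w, g_w_hat / rho, rtol=1e-4, atol=1e-9)
print("||w|| =", rho, " grad(w) == grad(w_hat) / ||w||:", scaled_match)
```

The loss value itself is identical at `w` and `w_hat`; only the gradient magnitude changes, shrinking as the weight norm grows.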

When a layer is rescaling invariant, the key feature of the weight vector is its direction.

During training, the weights are typically updated with some variant of stochastic gradient
descent, using the gradient of the loss at mini-batch $t$ and a learning rate $\eta$:

$$w_{t+1} = w_t - \eta \nabla L_t (w_t)$$

Taking the scalar product of this update with itself, using $\nabla L_t(w_t) = \rho_t^{-1} \nabla L_t(\hat{w}_t)$ from the gradient scaling above, and denoting $\rho_t = ||w_t||_2$, we get

$$\rho^{2}_{t+1} = \rho^{2}_t - 2 \eta \hat{w}_t^T \nabla L_t (\hat{w}_t) + \eta^2 \rho_t^{-2} ||\nabla L_t (\hat{w}_t)||^2$$
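The norm recursion is an exact identity for one SGD step, which the sketch below checks. It assumes NumPy, a hypothetical squared-error loss on the BN output, `eps = 0`, and a full-batch step; `loss_grad` uses the standard batch-norm backward pass and is an illustrative helper, not code from the paper.

```python
import numpy as np

def loss_grad(w, x, y):
    """Hypothetical loss L(w) = 0.5*mean((BN(x @ w) - y)**2) and its exact gradient (BN without eps)."""
    z = x @ w
    sigma = z.std()                                        # batch standard deviation
    n = (z - z.mean()) / sigma                             # batch-normalized outputs
    dn = (n - y) / len(y)                                  # dL/dn
    dz = (dn - dn.mean() - n * np.mean(dn * n)) / sigma    # standard BN backward pass
    return 0.5 * np.mean((n - y) ** 2), x.T @ dz

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 10))
y = rng.normal(size=64)
w = 3.0 * rng.normal(size=10)        # w_t, deliberately not unit norm
eta = 0.5

rho = np.linalg.norm(w)              # rho_t = ||w_t||_2
w_hat = w / rho
_, g = loss_grad(w, x, y)            # gradient at w_t
_, g_hat = loss_grad(w_hat, x, y)    # gradient at the unit vector w_hat_t

w_next = w - eta * g                 # one SGD step
lhs = np.linalg.norm(w_next) ** 2
rhs = rho**2 - 2 * eta * (w_hat @ g_hat) + eta**2 * rho**-2 * (g_hat @ g_hat)
print("||w_{t+1}||^2:", lhs, " recursion:", rhs)
print("w_hat . grad:", w_hat @ g_hat)
```

For this BN-based loss the printed dot product is numerically zero: because the loss does not change when $w$ is rescaled, its gradient is orthogonal to $w$, so the middle term vanishes and plain SGD steps can only increase $\rho_t$.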

Factoring out $\rho_t^2$ and expanding the square root to first order in $\eta$:

$$\rho_{t+1} = \rho_t \sqrt{1 - 2 \eta \rho_t^{-2} \hat{w}_t^T \nabla L_t (\hat{w}_t) + \eta^2 \rho_t^{-4} ||\nabla L_t (\hat{w}_t)||^2}$$

$$ = \rho_t - \eta \rho_t^{-1} \hat{w}_t^T \nabla L_t (\hat{w}_t) + O(\eta^2)$$
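The expansion step is pure algebra, so it can be checked with an arbitrary unit vector standing in for $\hat{w}_t$ and an arbitrary vector standing in for $\nabla L_t(\hat{w}_t)$ (both hypothetical random draws, NumPy assumed). If the first-order formula is right, the error divided by $\eta^2$ should settle to a constant as $\eta$ shrinks.

```python
import numpy as np

rng = np.random.default_rng(1)
rho = 2.0                                    # rho_t = ||w_t||_2 (arbitrary choice)
w_hat = rng.normal(size=10)
w_hat /= np.linalg.norm(w_hat)               # unit direction w_hat_t
g = rng.normal(size=10)                      # arbitrary stand-in for grad L_t(w_hat_t)

def rho_exact(eta):
    # square-root form of ||w_{t+1}||_2
    return rho * np.sqrt(1 - 2 * eta * rho**-2 * (w_hat @ g) + eta**2 * rho**-4 * (g @ g))

def rho_first_order(eta):
    # first-order expansion: rho_t - eta * rho_t^{-1} * w_hat_t^T g
    return rho - eta / rho * (w_hat @ g)

ratios = []
for eta in [1e-1, 1e-2, 1e-3]:
    err = abs(rho_exact(eta) - rho_first_order(eta))
    ratios.append(err / eta**2)
    print(f"eta={eta:g}  error={err:.2e}  error/eta^2={ratios[-1]:.4f}")
```

The constant that `error/eta^2` approaches is $\rho_t^{-3}\big(||g||^2 - (\hat{w}_t^T g)^2\big)/2$, the second-order coefficient of the square-root expansion.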

Substituting this into $\hat{w}_{t+1} = w_{t+1} / \rho_{t+1}$ and keeping terms up to first order in $\eta$ gives the update of the weight direction:

$$\hat{w}_{t+1} = \hat{w}_t - \eta \rho_t^{-2} (I - \hat{w}_t \hat{w}_t^T) \nabla L_t (\hat{w}_t) + O(\eta^2)$$
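The direction update can be checked the same way. The derivation only uses the step $w_{t+1} = w_t - \eta \rho_t^{-1} \nabla L_t(\hat{w}_t)$, so the sketch below uses an arbitrary vector (a hypothetical random draw, not orthogonal to $\hat{w}_t$) for the gradient, which makes the projection $(I - \hat{w}_t \hat{w}_t^T)$ matter. NumPy is assumed. With a correct first-order formula, dividing $\eta$ by 10 should divide the error by about 100.

```python
import numpy as np

rng = np.random.default_rng(2)
rho = 2.0                                   # rho_t = ||w_t||_2 (arbitrary choice)
w_hat = rng.normal(size=10)
w_hat /= np.linalg.norm(w_hat)              # unit direction w_hat_t
g = rng.normal(size=10)                     # arbitrary stand-in for grad L_t(w_hat_t)

def exact_direction(eta):
    # SGD step w_{t+1} = w_t - eta * grad L_t(w_t), with grad L_t(w_t) = g / rho, then renormalize
    w_next = rho * w_hat - eta * g / rho
    return w_next / np.linalg.norm(w_next)

def first_order_direction(eta):
    # w_hat_t - eta * rho^-2 * (I - w_hat_t w_hat_t^T) g
    projector = np.eye(w_hat.size) - np.outer(w_hat, w_hat)
    return w_hat - eta / rho**2 * (projector @ g)

errors = {}
for eta in [1e-1, 1e-2, 1e-3]:
    errors[eta] = np.linalg.norm(exact_direction(eta) - first_order_direction(eta))
    print(f"eta={eta:g}  error={errors[eta]:.2e}  error/eta^2={errors[eta] / eta**2:.4f}")
```

Only the component of the gradient orthogonal to $\hat{w}_t$ moves the direction to first order; the radial component changes $\rho_t$ instead.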

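Therefore, the size of the step taken by the weight direction is approximately proportional to

$$||\hat{w}_{t+1} - \hat{w}_t||_2 \propto \frac{\eta}{||w_t||_2^2}$$

that is, the learning rate acting on the direction is effectively $\eta / ||w_t||_2^2$ rather than $\eta$.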
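The $\eta / ||w_t||_2^2$ scaling can be seen directly: keep the direction fixed, vary the norm, take one SGD step, and measure how far the direction moves. This sketch assumes NumPy, the same hypothetical BN squared-error loss as a stand-in objective, `eps = 0`, and a full-batch step.

```python
import numpy as np

def loss_grad(w, x, y):
    """Hypothetical loss L(w) = 0.5*mean((BN(x @ w) - y)**2) and its exact gradient (BN without eps)."""
    z = x @ w
    sigma = z.std()                                        # batch standard deviation
    n = (z - z.mean()) / sigma                             # batch-normalized outputs
    dn = (n - y) / len(y)                                  # dL/dn
    dz = (dn - dn.mean() - n * np.mean(dn * n)) / sigma    # standard BN backward pass
    return 0.5 * np.mean((n - y) ** 2), x.T @ dz

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 10))
y = rng.normal(size=64)
w_hat = rng.normal(size=10)
w_hat /= np.linalg.norm(w_hat)       # fixed direction w_hat_t
eta = 1e-3

normalized_steps = []
for rho in [1.0, 2.0, 4.0, 8.0]:
    w = rho * w_hat                  # same direction, norm ||w_t||_2 = rho
    _, g = loss_grad(w, x, y)
    w_next = w - eta * g             # one plain SGD step
    step = np.linalg.norm(w_next / np.linalg.norm(w_next) - w_hat)
    normalized_steps.append(step * rho**2 / eta)
    print(f"||w||={rho:>4}  direction step={step:.3e}  step*||w||^2/eta={normalized_steps[-1]:.5f}")
```

Doubling $||w||_2$ cuts the direction step by roughly a factor of four, while `step * ||w||^2 / eta` stays essentially constant (close to $||\nabla L_t(\hat{w}_t)||_2$).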


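  • Weight decay is not needed for its regularizing effect if the learning rate is controlled appropriately: on layers followed by BN, weight decay mainly keeps $||w||_2$ small and thereby increases the effective step size $\eta / ||w||_2^2$. The authors show experimentally that the results obtained with weight decay can be reproduced without it by adjusting the learning rate.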

  • Consequently, BatchNorm, weight decay and the learning rate are no longer independent hyperparameters: tuning one can stand in for tuning another, with very nearly the same results.
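The coupling between weight norm and learning rate can be illustrated with an exact scaling argument: for a loss that is invariant to $||w||_2$, training from $c\,w_0$ with learning rate $c^2\eta$ produces exactly the same sequence of directions as training from $w_0$ with $\eta$. The sketch below assumes NumPy, the same hypothetical BN squared-error loss, full-batch steps, and an illustrative factor `c = 10`; it shows only this scaling relation, not the authors' learning-rate correction for weight decay.

```python
import numpy as np

def loss_grad(w, x, y):
    """Hypothetical loss L(w) = 0.5*mean((BN(x @ w) - y)**2) and its exact gradient (BN without eps)."""
    z = x @ w
    sigma = z.std()                                        # batch standard deviation
    n = (z - z.mean()) / sigma                             # batch-normalized outputs
    dn = (n - y) / len(y)                                  # dL/dn
    dz = (dn - dn.mean() - n * np.mean(dn * n)) / sigma    # standard BN backward pass
    return 0.5 * np.mean((n - y) ** 2), x.T @ dz

def direction_trajectory(w0, eta, steps, x, y):
    """Run plain gradient descent (same batch every step) and record w / ||w|| after each step."""
    w = w0.copy()
    directions = []
    for _ in range(steps):
        _, g = loss_grad(w, x, y)
        w = w - eta * g
        directions.append(w / np.linalg.norm(w))
    return np.array(directions), w

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 10))
y = rng.normal(size=64)
w0 = rng.normal(size=10)

c = 10.0                                         # illustrative rescaling factor
dirs_a, w_a = direction_trajectory(w0, eta=0.1, steps=50, x=x, y=y)
dirs_b, w_b = direction_trajectory(c * w0, eta=0.1 * c**2, steps=50, x=x, y=y)
print("max direction difference:", np.max(np.abs(dirs_a - dirs_b)))
print("norm ratio after training:", np.linalg.norm(w_b) / np.linalg.norm(w_a))
```

Since the gradient at $c\,w$ is $\nabla L(w)/c$, each update of the rescaled run is exactly $c$ times the update of the original run. Weight decay, by keeping $||w||_2$ small, therefore acts on these layers much like a larger learning rate.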


@ARTICLE{2018arXiv180301814H,
       author = {{Hoffer}, Elad and {Banner}, Ron and {Golan}, Itay and {Soudry}, Daniel},
        title = "{Norm matters: efficient and accurate normalization schemes in deep networks}",
      journal = {arXiv e-prints},
     keywords = {Statistics - Machine Learning, Computer Science - Machine Learning},
         year = "2018",
        month = "Mar",
          eid = {arXiv:1803.01814},
        pages = {arXiv:1803.01814},
archivePrefix = {arXiv},
       eprint = {1803.01814},
 primaryClass = {stat.ML},
       adsurl = {https://ui.adsabs.harvard.edu/\#abs/2018arXiv180301814H},
      adsnote = {Provided by the SAO/NASA Astrophysics Data System}
}