Batch Normalization Biases Deep Residual Networks Towards Shallow Paths

Review of the paper by Soham De and Samuel L. Smith, DeepMind, 2020

This paper examines the theoretical reasons why batch normalization helps in deep residual networks and proposes a simpler alternative.

What can we learn from this paper?

Batch normalization suppresses the signal flowing through the main (residual) branch of each residual block, the branch containing the weights, in favor of the skip connection (the identity branch). This shrinks the effective depth of the network (the number of residual blocks a typical path through the network actually traverses) to just tens instead of hundreds or thousands of layers, which makes deep networks easier to train. A similar effect can be achieved by replacing the batch normalization layers in each block with a single learnable scalar initialized at zero or a small value.
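In symbols (a sketch of my own, paraphrasing the paper's analysis at initialization and assuming a pre-activation block of the form \(x_{l+1} = x_l + f_l(\mathrm{BN}(x_l))\)): batch normalization keeps the residual branch at roughly unit variance while the variance of the skip path keeps accumulating, so

\[
\operatorname{Var}(x_{l+1}) \approx \operatorname{Var}(x_l) + 1 \approx l + 1,
\]

and the residual branch of block \(l\) contributes only about a \(1/\sqrt{l+1}\) fraction of the signal's standard deviation. Without normalization, both branches contribute equally and the variance roughly doubles at every block.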

Prerequisites (what does one need to be familiar with to understand the paper?)

  • Batch normalization
  • Residual neural networks

Motivation

To better understand the inner workings of residual architectures and to explore a more computationally efficient alternative to batch normalization that preserves the accuracy of trained models.

Results

The paper expands upon an article by Veit et al., 2016, which established that deep ResNet-type architectures with skip connections behave like ensembles of many network paths of different lengths, and that the shorter paths dominate the gradient flow during training. The current paper shows that this happens because of the batch normalization (BN) layers in each residual block. The authors also note that deep networks that do not normalize their residual branches, such as the original Transformer, are notoriously difficult to train.
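To make this concrete, here is a small numerical sketch (my own illustration, not code from the paper): with a BN layer at the start of every residual branch, the skip path accumulates variance linearly with depth while each branch stays at roughly unit variance, so deeper blocks contribute an ever smaller share of the signal; without BN, both paths contribute equally and the activation variance explodes.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, dim, depth = 1024, 256, 100

def residual_branch(x, w, use_bn):
    """One residual branch: optional batch normalization, then a variance-preserving linear map."""
    if use_bn:
        x = (x - x.mean(axis=0)) / (x.std(axis=0) + 1e-5)  # normalize each feature over the batch
    return x @ w

for use_bn in (True, False):
    x = rng.standard_normal((batch, dim))
    for block in range(1, depth + 1):
        w = rng.standard_normal((dim, dim)) / np.sqrt(dim)  # scaled to preserve variance on average
        r = residual_branch(x, w, use_bn)
        x = x + r                                           # skip connection adds the branch back in
        if block in (1, 10, 100):
            share = r.var() / x.var()                       # fraction of the variance coming from the branch
            print(f"BN={use_bn}, block {block:3d}: Var(x) = {x.var():.3g}, branch share = {share:.3f}")
```

With BN, the printed branch share falls roughly as \(1/(l+1)\), matching the claim that the residual branch is downscaled by about \(1/\sqrt{l}\) in standard deviation; without BN, the share stays near 0.5 while the variance roughly doubles at every block.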

As an alternative to batch normalization, the authors introduce SkipInit: a trainable scalar multiplier placed at the end of each residual branch and initialized at zero or a small value (at most \(1/\sqrt{d}\), where \(d\) is the number of residual blocks). The paper shows that networks with SkipInit perform similarly to networks with batch normalization or with Fixup initialization (a more complex scheme for avoiding BN layers) on CIFAR-10 and ImageNet. Performance does deteriorate at large batch sizes, however: SkipInit replicates BN's bias towards the skip connection, but not its beneficial conditioning of the loss landscape, which is what allows training with very large batches.
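As a concrete sketch of the idea (my own PyTorch approximation, not the authors' code; the block structure, channel handling, and names are assumptions), each residual branch simply ends in a learnable scalar initialized at zero, so every block computes the identity function at initialization:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SkipInitBlock(nn.Module):
    """Pre-activation residual block with the BN layers replaced by one learnable scalar."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        # SkipInit: a single trainable multiplier per block, initialized at zero,
        # so the block is the identity function at initialization.
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.conv2(F.relu(self.conv1(F.relu(x))))
        return x + self.alpha * residual
```

Initializing `alpha` at zero (or at any value of at most \(1/\sqrt{d}\)) keeps the network dominated by its skip connections early in training, which is the property of batch normalization that the paper identifies as making very deep ResNets trainable.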

It is also noted that batch normalization has a regularizing effect on the network. When extra regularization is added to SkipInit (dropout and additional biases), both top-1 and top-5 accuracy on ImageNet were slightly better than with BN layers for the ResNet50-V2 architecture.

While it remains for future research to decide whether SkipInit is a worthwhile replacement in practice, I believe this paper provides a lot of insight into the inner workings of residual networks. Reading it, together with some of the papers it references, can significantly deepen the understanding of deep residual architectures for readers who are not already intimately familiar with them.

Original paper link

Further reading
