Supervised Contrastive Learning

Review of paper by Prannay Khosla, Piotr Teterwak, Chen Wang et al., Google Research, 2020

The authors took the contrastive loss, which has recently been shown to be very effective for learning deep neural network representations in the self-supervised setting, and applied it to supervised learning. With ResNet-50 and ResNet-200, this approach achieved better results than training with the standard cross-entropy loss.

What can we learn from this paper?

That using the standard cross-entropy loss may not be optimal when training deep neural networks for image classification tasks.

Prerequisites (to understand the paper, what does one need to be familiar with?)

  • Neural networks
  • Loss functions

Discussion

The main idea behind the contrastive loss is that inputs from the same class should produce similar representations at the output of the neural network (before the final softmax layer), while inputs from different classes should produce clearly different representations. As a result, the members of each class form a separate cluster in the representation space.

Minimizing the contrastive loss increases the dot products between normalized representations of images from the same class while decreasing the dot products between representations of images from different classes. The exact formula can be found in the paper. It adapts the unsupervised formulation, in which the only images treated as "the same class" are different augmentations of the same image, to the supervised case, and it allows for multiple instances of the same class in a batch.
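For concreteness, here is a minimal NumPy sketch of a supervised contrastive loss of the kind described above, assuming a batch of projection outputs with integer class labels. For each anchor, every other sample in the batch appears in the softmax denominator, and every other sample with the same label counts as a positive. This is an illustration written from the description in the paper, not the authors' code. The function name `supcon_loss`, the default temperature of 0.1, and averaging (rather than summing) over anchors are assumptions.

```python
import numpy as np

def supcon_loss(features: np.ndarray, labels: np.ndarray, temperature: float = 0.1) -> float:
    """Supervised contrastive loss over one batch (illustrative sketch).

    features: (n, d) projection outputs, one row per augmented view.
    labels:   (n,) integer class labels; views of the same image share a label.

    With A(i) = all samples except i and P(i) = samples in A(i) with i's label:
        L_i = -1/|P(i)| * sum_{p in P(i)} log( exp(z_i.z_p/t) / sum_{a in A(i)} exp(z_i.z_a/t) )
    Returns the mean of L_i over anchors with at least one positive.
    """
    # L2-normalize so dot products are cosine similarities
    z = features / np.linalg.norm(features, axis=1, keepdims=True)
    logits = z @ z.T / temperature
    n = z.shape[0]
    not_self = ~np.eye(n, dtype=bool)
    # Exclude each anchor from its own denominator
    logits = np.where(not_self, logits, -np.inf)
    # Numerically stable log-softmax over A(i)
    row_max = logits.max(axis=1, keepdims=True)
    log_denom = row_max + np.log(np.exp(logits - row_max).sum(axis=1, keepdims=True))
    log_prob = logits - log_denom
    # Positives: same label, anchor itself excluded
    positives = (labels[:, None] == labels[None, :]) & not_self
    num_pos = positives.sum(axis=1)
    valid = num_pos > 0
    # Average log-probability of the positives for each anchor
    mean_log_prob_pos = np.where(positives, log_prob, 0.0).sum(axis=1)[valid] / num_pos[valid]
    return float(-mean_log_prob_pos.mean())

# Toy batch (hypothetical data): two tight clusters in 2-D
feats = np.array([[1.0, 0.0], [1.0, 0.01], [0.0, 1.0], [0.01, 1.0]])
print(supcon_loss(feats, np.array([0, 0, 1, 1])))  # labels follow the clusters: small loss
print(supcon_loss(feats, np.array([0, 1, 0, 1])))  # labels cut across clusters: large loss
```

When the labels agree with the clusters, each anchor's positive dominates the softmax and the loss is close to zero. When positives sit in the other cluster, the loss is large, which is the pressure that pulls same-class representations together during training.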

In the first stage of training, the chosen network architecture (the paper used ResNet-50 and ResNet-200) was trained on ImageNet with the contrastive loss. To compute the loss, an extra fully connected projection network with one hidden layer was added to the end of each ResNet, reducing the representation size from 2048 to 128. Two separate augmentations of each image were used as inputs.
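The projection step can be sketched in NumPy as follows. This is a shape-level illustration under stated assumptions, not the paper's code: the hidden width of 2048, the ReLU nonlinearity, and the weight initialization are choices made here, and random vectors stand in for ResNet features. The output is L2-normalized because the loss compares normalized representations.

```python
import numpy as np

rng = np.random.default_rng(0)

# 2048-d encoder output -> hidden layer (width 2048 is an assumption) -> 128-d projection
ENC_DIM, HIDDEN_DIM, PROJ_DIM = 2048, 2048, 128

# Randomly initialized one-hidden-layer MLP (stand-in for trained weights)
W1 = rng.normal(0.0, ENC_DIM ** -0.5, size=(ENC_DIM, HIDDEN_DIM))
b1 = np.zeros(HIDDEN_DIM)
W2 = rng.normal(0.0, HIDDEN_DIM ** -0.5, size=(HIDDEN_DIM, PROJ_DIM))
b2 = np.zeros(PROJ_DIM)

def project(r: np.ndarray) -> np.ndarray:
    """Map encoder outputs r of shape (n, 2048) to unit-norm projections of shape (n, 128)."""
    h = np.maximum(r @ W1 + b1, 0.0)                      # hidden layer with ReLU (assumed)
    z = h @ W2 + b2                                       # linear output layer
    return z / np.linalg.norm(z, axis=1, keepdims=True)  # normalize to the unit sphere

# Two augmented views of each of 4 images -> 8 encoder outputs
# (rows 2k and 2k+1 are the two views of image k; random stand-ins for ResNet features)
r = rng.normal(size=(8, ENC_DIM))
z = project(r)
print(z.shape)
```

Only `z` feeds the contrastive loss; the 2048-d encoder output `r` is what survives into the second stage once the projection network is discarded.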

After training with the contrastive loss, the projection network was discarded and a new randomly initialized fully connected layer was trained (with all other weights fixed) using the standard cross-entropy loss.
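A minimal sketch of this second stage, assuming the encoder outputs have already been computed and are kept fixed: only a fresh weight matrix and bias are fit with softmax cross-entropy. Full-batch gradient descent, the learning rate, the epoch count, and the synthetic clustered features are illustrative assumptions, and the helper names `train_linear_classifier` and `predict` are hypothetical.

```python
import numpy as np

def train_linear_classifier(feats, labels, num_classes, lr=0.5, epochs=200, seed=0):
    """Fit a linear softmax classifier with cross-entropy on frozen features.

    Only W and b are updated; whatever produced `feats` is never touched,
    mirroring the 'all other weights fixed' setup.
    """
    rng = np.random.default_rng(seed)
    n, d = feats.shape
    W = rng.normal(0.0, 0.01, size=(d, num_classes))  # new, randomly initialized layer
    b = np.zeros(num_classes)
    onehot = np.eye(num_classes)[labels]
    for _ in range(epochs):
        logits = feats @ W + b
        logits -= logits.max(axis=1, keepdims=True)  # stabilize the softmax
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - onehot) / n  # gradient of mean cross-entropy w.r.t. logits
        W -= lr * feats.T @ grad
        b -= lr * grad.sum(axis=0)
    return W, b

def predict(feats, W, b):
    return np.argmax(feats @ W + b, axis=1)

# Hypothetical frozen features: 3 well-separated clusters in 16 dimensions,
# standing in for the encoder's 2048-d outputs after contrastive training.
rng = np.random.default_rng(1)
labels = np.repeat(np.arange(3), 100)
centers = 3.0 * np.eye(3, 16)
feats = centers[labels] + 0.3 * rng.normal(size=(300, 16))

W, b = train_linear_classifier(feats, labels, num_classes=3)
print("train accuracy:", (predict(feats, W, b) == labels).mean())
```

If the first stage has produced tight, separated class clusters, even this simple linear layer can separate them, which is why the second stage only needs to train one layer.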

The resulting network showed a significant improvement in top-1 accuracy on ImageNet: about 2% higher than the authors' own implementation of cross-entropy training, and better than other SOTA results with these architectures. It was also noticeably more robust to image corruptions, as shown by its Mean Corruption Error (mCE) results, and less sensitive to changes in training hyperparameters (optimizers, data augmentation techniques, etc.).
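mCE can be computed as sketched below, assuming the usual definition from the ImageNet-C benchmark. For each corruption type, the model's top-1 errors over the five severity levels are summed and divided by the corresponding sum for an AlexNet baseline. The resulting corruption errors (CE) are then averaged over the corruption types (15 in ImageNet-C). Lower is better. The error rates below are made up purely for illustration and are not from the paper.

```python
import numpy as np

def corruption_error(model_err, baseline_err):
    """CE for one corruption type: the model's top-1 error summed over severity
    levels, divided by the baseline model's summed error (AlexNet in ImageNet-C)."""
    return np.sum(model_err) / np.sum(baseline_err)

def mean_corruption_error(model_errs, baseline_errs):
    """mCE: the mean of the per-corruption CE values, reported as a percentage."""
    ces = [corruption_error(m, b) for m, b in zip(model_errs, baseline_errs)]
    return 100.0 * float(np.mean(ces))

# Hypothetical error rates for 2 corruption types x 5 severities (not from the paper)
model = np.array([[0.20, 0.25, 0.30, 0.35, 0.40],
                  [0.10, 0.15, 0.20, 0.25, 0.30]])
alexnet = np.array([[0.40, 0.50, 0.60, 0.70, 0.80],
                    [0.40, 0.45, 0.50, 0.55, 0.60]])
print(mean_corruption_error(model, alexnet))
```

Here the two corruption errors are 0.5 and 0.4, so the mCE comes out to 45 (up to floating-point rounding). Normalizing by a fixed baseline keeps easy and hard corruption types from dominating the average.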

The paper discusses many details of the proposed implementation, and a GitHub repository is available. The paper has generated a lot of interest among researchers since it was published a few weeks ago, and it looks like a very promising new approach.

Original paper link

GitHub repository
