An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale

Review of the paper by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov et al., Google Research, 2020

This paper develops a novel way of using Transformer neural attention models for visual recognition tasks.

What can we learn from this paper?

That a non-convolutional, attention-based model can achieve state-of-the-art results in image classification.

Prerequisites (what should one be familiar with to better understand the paper?)

  • Transformers and neural attention
  • Basics of Computer Vision

Discussion

Until this paper, all of the best-performing modern computer vision models since AlexNet had been based on convolutional architectures. The ability of 2D convolutions to capture local patterns in images resulted in the best models approaching 90% top-1 accuracy on ImageNet and performing excellently on a variety of visual tasks.

Recently, the most exciting innovations have been happening in the NLP domain, where the 2017 Transformer architecture based on neural attention serves as the basis for many new developments, including the most recent BigBird and Linformer models.

It is clear that having global attention could be beneficial for image tasks as well, since some information can only be obtained by comparing different, possibly quite distant parts of the image.

The reason attention models haven’t been doing better until now in computer vision lies both in the difficulty of scaling them (self-attention scales as O(N²) in the number of inputs, so a full set of attention weights between the million pixels of a 1000×1000 image would have about a trillion terms) and, perhaps more importantly, in the fact that, as opposed to words in a text, individual pixels in a picture are not very meaningful by themselves, so connecting them via attention does not accomplish much.
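To put rough numbers on this scaling problem (a back-of-the-envelope calculation of my own, not a figure from the paper):

```python
# Full self-attention over raw pixels: the attention matrix is quadratic
# in the sequence length, so a 1000x1000 image is hopeless.
pixels = 1000 * 1000                # sequence length if every pixel is a token
attention_terms = pixels ** 2       # entries in one attention matrix (per head, per layer)
print(f"{attention_terms:.1e}")     # 1.0e+12 -- about a trillion weights
```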

The new paper suggests applying attention not to individual pixels but to small patches of the image (perhaps 16×16 pixels, as in the title, although the optimal patch size really depends on the dimensions and contents of the images to which the model is applied).
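For comparison, here is the same arithmetic for patch-level attention (my own sanity check, assuming the 224×224 input resolution that is standard for ImageNet-scale models and the 16×16 patch size from the title):

```python
# Patch-level attention: the sequence length is the number of patches.
img_size, patch_size = 224, 16
num_patches = (img_size // patch_size) ** 2   # 14 * 14 = 196 tokens
attention_terms = num_patches ** 2            # 38,416 entries per attention matrix
print(num_patches, attention_terms)           # 196 38416
```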

Examining the figure (taken from the paper), one can see how the Vision Transformer operates. Each patch of the input image is flattened and mapped to an embedding vector by a learned linear projection, and a positional embedding (a vector encoding where the patch was originally located in the image) is added to it. This is necessary because a Transformer treats its inputs as an unordered set, so the positional information helps the model properly evaluate the attention weights. An extra learnable class token (labeled 0 in the figure) is prepended to the sequence as a placeholder for the class to be predicted in the classification task.
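To make the embedding step concrete, here is a minimal PyTorch sketch (my own code, not the authors' Jax/Flax implementation; the class and variable names are mine, and the defaults of 16×16 patches and a 768-dimensional embedding follow the ViT-Base configuration described in the paper):

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split an image into patches, project each flattened patch linearly,
    prepend a learnable class token, and add positional embeddings."""

    def __init__(self, img_size=224, patch_size=16, in_channels=3, dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        self.patch_size = patch_size
        self.proj = nn.Linear(in_channels * patch_size ** 2, dim)            # linear projection
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))                # extra class token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))  # learned positions

    def forward(self, x):                           # x: (batch, channels, height, width)
        b, c, _, _ = x.shape
        p = self.patch_size
        # cut the image into non-overlapping p-by-p patches and flatten each one
        x = x.unfold(2, p, p).unfold(3, p, p)                       # (b, c, h/p, w/p, p, p)
        x = x.permute(0, 2, 3, 1, 4, 5).reshape(b, -1, c * p * p)   # (b, num_patches, c*p*p)
        x = self.proj(x)                                            # (b, num_patches, dim)
        cls = self.cls_token.expand(b, -1, -1)                      # one class token per image
        x = torch.cat([cls, x], dim=1)                              # class token sits at position 0
        return x + self.pos_embed                                   # add positional information
```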

The Transformer encoder, similarly to the original 2017 version, consists of multiple blocks of attention, normalization, and fully-connected layers with residual (skip) connections, as shown in the right part of the picture. In each attention block, multiple heads can capture different patterns of connectivity. If you are interested in learning more about Transformers, I would recommend reading this excellent article by Jay Alammar.
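A correspondingly minimal sketch of one encoder block, again my own PyTorch shorthand rather than the authors' code (the pre-norm ordering and the ViT-Base-like sizes of 12 heads and a 3072-unit MLP come from the paper; `batch_first` requires a reasonably recent PyTorch):

```python
import torch.nn as nn

class EncoderBlock(nn.Module):
    """One pre-norm Transformer encoder block: LayerNorm -> multi-head
    self-attention -> residual, then LayerNorm -> MLP -> residual."""

    def __init__(self, dim=768, heads=12, mlp_dim=3072, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim), nn.Dropout(dropout),
        )

    def forward(self, x):                    # x: (batch, tokens, dim)
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h)     # queries, keys, and values all come from the same tokens
        x = x + attn_out                     # residual (skip) connection around attention
        x = x + self.mlp(self.norm2(x))      # second residual connection around the MLP
        return x
```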

The fully-connected MLP head at the output provides the desired class prediction. Of course, as is now standard, the main model can be pre-trained on a large dataset of images, and then the final MLP head can be fine-tuned to a specific task via the usual transfer learning approach. One notable feature of the new model is that, according to the paper, it not only reaches the same prediction accuracy as convolutional approaches with less pre-training computation, but its performance also keeps improving as it is trained on more and more data, to a greater extent than convolutional models. The authors trained the Vision Transformer on Google's private JFT-300M dataset containing 300 million (!) images, which resulted in state-of-the-art accuracy on a number of benchmarks. One can hope that this pre-trained model will soon be released to the public so that we can all try it out.
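Putting the pieces together, a toy end-to-end model built from the sketches above (an illustration of the overall structure, not the released model; the depth of 12 blocks again matches ViT-Base) might look like this:

```python
import torch
import torch.nn as nn

class VisionTransformer(nn.Module):
    """Patch embedding, a stack of encoder blocks, and a head that reads
    the class token (reuses PatchEmbedding and EncoderBlock from above)."""

    def __init__(self, num_classes=1000, dim=768, depth=12):
        super().__init__()
        self.embed = PatchEmbedding(dim=dim)
        self.blocks = nn.Sequential(*[EncoderBlock(dim=dim) for _ in range(depth)])
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)   # for transfer learning, swap and fine-tune this head

    def forward(self, images):                    # images: (batch, 3, 224, 224)
        tokens = self.blocks(self.embed(images))
        cls_token = self.norm(tokens)[:, 0]       # the class token aggregates global information
        return self.head(cls_token)               # class logits

logits = VisionTransformer()(torch.randn(2, 3, 224, 224))   # -> shape (2, 1000)
```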

It’s definitely exciting to see this new application of neural attention to the computer vision domain. Hopefully, a lot more progress will be achieved in the coming years based on this development!

Original paper link

Github repository from the authors (Jax/Flax)

Github repository not from the authors (PyTorch)

My review of this paper on Medium (a slightly expanded version of this review)

Suggested further reading
