Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity

Review of paper by William Fedus, Barret Zoph, and Noam Shazeer, Google Brain, 2021.

Modern deep learning models, especially in natural language processing, usually achieve better accuracy by increasing the number of parameters (often combined with training on larger datasets), which comes at a huge computational cost. In this paper, to improve computational efficiency, the authors replace the feed-forward (fully connected) layers in the blocks of their Transformer model with sets of many alternative sub-networks (experts), of which only one is chosen for each input token in each such layer. This makes it possible to increase the size of the model as desired (within the available memory) by adding experts, while the computational cost per input token stays constant.
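To make the idea concrete, here is a minimal plain-Python sketch of one expert layer processing a single token. The dimensions, the random weights, and all helper names (switch_ffn, make_expert, and so on) are illustrative assumptions, not the authors' code. The router scores every expert, but only the expert with the highest probability is evaluated, and its output is scaled by that probability.

```python
import math
import random

def softmax(logits):
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def matvec(weights, x):
    # weights is a list of rows; returns weights @ x.
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]

def relu(v):
    return [max(0.0, a) for a in v]

def make_expert(d_model, d_ff, rng):
    # Hypothetical expert: a two-layer ReLU feed-forward network without biases.
    w_in = [[rng.gauss(0, 0.1) for _ in range(d_model)] for _ in range(d_ff)]
    w_out = [[rng.gauss(0, 0.1) for _ in range(d_ff)] for _ in range(d_model)]
    return w_in, w_out

def expert_forward(expert, x):
    w_in, w_out = expert
    return matvec(w_out, relu(matvec(w_in, x)))

def switch_ffn(x, router_weights, experts):
    """Route one token to a single expert and scale its output by the gate value."""
    probs = softmax(matvec(router_weights, x))       # one score per expert
    i = max(range(len(probs)), key=probs.__getitem__)  # top-1 expert
    y = expert_forward(experts[i], x)                # only expert i's weights are used
    return i, [probs[i] * v for v in y]

rng = random.Random(0)
d_model, d_ff, num_experts = 4, 8, 16
experts = [make_expert(d_model, d_ff, rng) for _ in range(num_experts)]
router = [[rng.gauss(0, 1) for _ in range(d_model)] for _ in range(num_experts)]
token = [rng.gauss(0, 1) for _ in range(d_model)]

chosen, out = switch_ffn(token, router, experts)
print(chosen, len(out))  # index of the chosen expert, and d_model output values
```

Adding experts to the list increases the number of stored weights but not the feed-forward work done per token; only the router's cost grows, by one score per extra expert.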

What can we learn from this paper?

That a simplified mixture-of-experts approach, which routes each token to a single expert, can produce state-of-the-art deep learning models that are significantly more computationally efficient.

Prerequisites (to better understand the paper, what should one be familiar with?)


Modern large-scale deep learning models take advantage of being trained on massive datasets, distilling relevant information from this data into their trained weights. It therefore stands to reason that, at prediction time, only a small fraction of these weights is actually useful for any particular input, even though under the standard “one big model” framework all of them must be applied to the input via matrix operations to obtain the prediction.

Furthermore, although there is currently no complete understanding of how deep neural networks achieve their high performance, it is reasonable to assume that during training a lot of computational effort is spent on balancing the weights that are irrelevant to a particular input so that they do not degrade the network’s predictions for that input.

To reduce this waste of computational power, which matters all the more given the growing concern about the carbon footprint of deep learning models, researchers have recently been looking at mixture-of-experts models. In such a model, the input space of the network (or of a particular layer in it) is divided into regions, either explicitly or by training a dedicated gating function, and a separate model or part of the network, called an expert, is trained for each region. At prediction time, only the experts relevant to the current input are used to calculate the output.

The sparsely-gated mixture-of-experts layer that this work builds on was proposed in 2017 by Shazeer (one of the authors of this paper) et al. Each expert layer of that model combined several experts, selected by the top k values of a softmax gating function, because the authors believed at the time that routing to a single expert would not give the gating function meaningful gradients and would therefore not be trainable. The Switch Transformer, however, was successfully trained with each token routed to just one expert per layer, which simplifies the architecture and improves computational efficiency.
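The difference between the two routing rules can be sketched as follows. top_k_gates follows the 2017 recipe of keeping the k largest gating logits and applying a softmax over them, but omits the tunable noise that the original noisy top-k gating adds; switch_gate applies the softmax over all experts and keeps only the largest probability. Both function names and the example logits are hypothetical.

```python
import math

def softmax(logits):
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_gates(logits, k):
    """Top-k gating (noise term omitted): softmax over the k largest logits only."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    weights = softmax([logits[i] for i in top])
    return dict(zip(top, weights))  # expert index -> gate weight

def switch_gate(logits):
    """Switch routing: softmax over all experts, keep only the most probable one."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return {best: probs[best]}

logits = [2.0, 1.0, 0.5, -1.0]
print(top_k_gates(logits, 2))  # approximately {0: 0.731, 1: 0.269}
print(switch_gate(logits))     # approximately {0: 0.609}
```

In the Switch version the retained probability multiplies the chosen expert's output, so the router still receives a gradient through that single value even though only one expert is evaluated.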

For training the new model, the authors used Google’s Mesh-TensorFlow framework (described in a 2018 paper), which is essentially a layer on top of TensorFlow designed for distributing computations across many devices.

To maximize efficiency, the load must be distributed approximately evenly among the experts. This is encouraged by an auxiliary load-balancing loss that is minimized when the load distribution is uniform. In addition, each expert has a maximum capacity (the number of tokens it can process per batch); tokens routed to an expert that is already full are not processed by any expert and are passed directly to the next layer via the residual connection. This is not optimal in terms of prediction accuracy, but it limits the amount of required memory and computational power.
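A rough sketch of both mechanisms for one batch of tokens follows. It assumes the formulation as we read it in the paper: expert capacity = capacity factor * tokens per batch / number of experts (rounded down here, which is an assumption), and an auxiliary loss alpha * N * sum_i(f_i * P_i), where N is the number of experts, f_i is the fraction of tokens whose top choice is expert i, P_i is the average router probability for expert i, and alpha = 0.01. The function name and the example logits are hypothetical.

```python
import math

def softmax(logits):
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def route_with_capacity(router_logits, num_experts, capacity_factor=1.0, alpha=0.01):
    """Top-1 routing with a per-expert capacity, plus the auxiliary load-balancing loss."""
    num_tokens = len(router_logits)
    capacity = int(capacity_factor * num_tokens / num_experts)  # rounded down (assumption)
    probs = [softmax(l) for l in router_logits]
    choices = [max(range(num_experts), key=p.__getitem__) for p in probs]

    assignment, load = [], [0] * num_experts
    for e in choices:
        if load[e] < capacity:
            load[e] += 1
            assignment.append(e)
        else:
            # Overflow: no expert processes this token; only the residual path carries it.
            assignment.append(None)

    # f_i: fraction of tokens whose top-1 choice is expert i (counted before dropping).
    f = [choices.count(i) / num_tokens for i in range(num_experts)]
    # P_i: mean router probability assigned to expert i over the batch.
    P = [sum(p[i] for p in probs) / num_tokens for i in range(num_experts)]
    aux_loss = alpha * num_experts * sum(fi * Pi for fi, Pi in zip(f, P))
    return assignment, aux_loss

# Four tokens, two experts, capacity 2: three tokens prefer expert 0.
assignment, loss = route_with_capacity([[2, 0], [3, 0], [1, 0], [0, 2]], num_experts=2)
print(assignment)      # [0, 0, None, 1]
print(round(loss, 4))  # 0.0117
```

When routing is perfectly balanced, f_i and P_i both equal 1/N and the loss reduces to alpha (0.01 here); in the unbalanced example the third token overflows expert 0 and the loss comes out slightly higher.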

To ensure both training stability and efficiency, the authors used several additional techniques: keeping tensors in lower precision (bfloat16), which makes computation and the transfer of tensors between devices cheaper, while selectively casting the router’s inputs to float32 for the router computations performed locally on each device; scaling down the default Transformer initialization by a factor of 10 to reduce variance early in training and thus improve stability; and applying higher dropout rates to the expert layers during fine-tuning to counter the overfitting that the large number of parameters in these layers may cause.
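The initialization change can be sketched as below. As we understand the paper, weights are drawn from a truncated normal distribution with mean 0 and standard deviation sqrt(s / n), where n is the fan-in and s is a scale hyperparameter lowered from 1.0 to 0.1; reducing s tenfold therefore shrinks the standard deviation by a factor of sqrt(10), about 3.16. Truncating at two standard deviations and the function name are assumptions of this sketch.

```python
import math
import random

def truncated_normal_init(fan_in, fan_out, scale, seed=0):
    """Return a fan_out x fan_in weight matrix with std sqrt(scale / fan_in)."""
    rng = random.Random(seed)
    std = math.sqrt(scale / fan_in)

    def sample():
        # Rejection sampling: redraw until the value lies within 2 std (assumed bound).
        while True:
            v = rng.gauss(0.0, std)
            if abs(v) <= 2 * std:
                return v

    return [[sample() for _ in range(fan_in)] for _ in range(fan_out)]

w_default = truncated_normal_init(fan_in=768, fan_out=4, scale=1.0)  # default scale
w_switch = truncated_normal_init(fan_in=768, fan_out=4, scale=0.1)   # reduced by 10x
# With the same seed both matrices use the same standard-normal draws,
# so every weight shrinks by exactly sqrt(0.1).
print(w_switch[0][0] / w_default[0][0])  # ~0.316
```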

Like T5, the text-to-text Transformer for which the C4 corpus was originally created and the model used for performance comparisons in this paper, the Switch Transformer was pre-trained on C4, a corpus derived from the Common Crawl dataset. To match the T5-Base model (223 million parameters) and the T5-Large model (739 million parameters) in the number of floating point operations (FLOPs) per sequence, the authors built a Switch-Base model and a Switch-Large model with 7.4 billion and 26.3 billion parameters, respectively.
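A back-of-the-envelope check shows where the extra parameters come from. The sketch assumes T5’s original dimensions (12 encoder and 12 decoder layers with d_model = 768 and d_ff = 3072 for Base; 24 + 24 layers with 1024 and 4096 for Large), 128 experts, and an expert layer in place of every other feed-forward layer; these settings are assumptions for illustration, and router weights are ignored. Because each token still passes through only one expert per layer, per-token FLOPs stay close to those of the corresponding T5 model, while every replaced feed-forward layer adds 127 extra copies of its weights.

```python
def ffn_params(d_model, d_ff):
    # Two weight matrices (d_model x d_ff and d_ff x d_model), no biases.
    return 2 * d_model * d_ff

def switch_params(dense_params, d_model, d_ff, num_ffn_layers, num_experts, expert_every=2):
    """Estimate parameters after replacing every expert_every-th FFN with num_experts experts.

    Each replaced FFN adds (num_experts - 1) extra copies of its weights;
    the small router matrices (d_model x num_experts) are ignored.
    """
    expert_layers = num_ffn_layers // expert_every
    return dense_params + expert_layers * (num_experts - 1) * ffn_params(d_model, d_ff)

# Assumed dimensions (see lead-in); FFN layer counts cover encoder + decoder.
base = switch_params(223e6, 768, 3072, num_ffn_layers=24, num_experts=128)
large = switch_params(739e6, 1024, 4096, num_ffn_layers=48, num_experts=128)
print(f"Switch-Base  ~ {base / 1e9:.1f}B parameters")   # ~7.4B
print(f"Switch-Large ~ {large / 1e9:.1f}B parameters")  # ~26.3B
```

Under these assumptions the estimates land at about 7.4B and 26.3B, in line with the figures reported for Switch-Base and Switch-Large.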

Compared with T5 models of equivalent FLOPs, the fine-tuned Switch Transformer models showed significant performance improvements across a variety of popular NLP benchmarks (GLUE, SQuAD, SuperGLUE, etc.), presumably because of their higher parameter count, which did not come at an increased computational cost.

Although the paper mentions trillion-parameter models in its title, the main results are achieved with the smaller models described above. When scaled to hundreds of billions or over a trillion parameters, the models exhibited sporadic training instability, which was apparently alleviated by reducing other dimensions such as width, depth, and the number of attention heads, but at the cost of reduced performance.

The authors also examine a number of other topics related to Switch Transformers, such as their scaling properties, their distillation into smaller dense models (which significantly reduced prediction accuracy relative to the large model but still exceeded that of the same small model trained from scratch), multilingual learning, and various issues related to data, model, and expert parallelism across multiple devices.

While much work remains to be done to perfect the use of mixture-of-experts techniques in deep neural networks, and thus reduce their computational load while preserving accuracy, this paper is an important step that shows how this can be achieved. As the authors point out in Appendix D, even a model with just two experts shows a noticeable benefit, which suggests that the approach is useful even with modest hardware such as a single GPU. For those who want to adapt it for their own use, the source code of the Switch Transformer is available.

Original paper link

Source code

Suggested further reading
