Gradient Starvation: A Learning Proclivity in Neural Networks

Review of the paper by Mohammad Pezeshki, Sékou-Oumar Kaba, Yoshua Bengio, et al. (Mila, Université de Montréal, McGill University), 2020

In this paper, the authors examine in detail the phenomenon of gradient starvation, originally introduced by the same research group in 2018, for neural networks trained with the common cross-entropy loss. Gradient starvation occurs when the presence of easy-to-learn features in a dataset prevents the learning of other, equally informative features, which can result in trained models that rely on only those few features and therefore lack robustness. The authors propose a new Spectral Decoupling regularization method to combat this problem.

What can we learn from this paper?

That special care, such as using the suggested Spectral Decoupling regularization, should be taken to ensure that a neural network trained with the cross-entropy loss relies on all useful features when making predictions, and not just on the ones easiest to learn.

Prerequisites (to better understand the paper, what should one be familiar with?)

Basics of training neural networks with the cross-entropy loss and gradient descent; common regularization techniques such as weight decay, dropout, and batch normalization; and, for the theoretical part, the neural tangent kernel framework.

Discussion

An example of gradient starvation is shown in the picture, where a simple neural network classifier has been trained with the cross-entropy loss to distinguish between two classes (black lines show the decision boundaries). When the classes are linearly separable (right panel), the decision boundary collapses to nearly a straight line, ignoring all other structure in the input distribution and making the predictions far less robust to small perturbations of the inputs.
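This setup is easy to reproduce. Below is a minimal sketch using scikit-learn's make_moons and a small MLP; the vertical shift and the layer sizes are my own illustrative choices, not the paper's exact configuration.

```python
from sklearn.datasets import make_moons
from sklearn.neural_network import MLPClassifier

# Generate the "two moons" dataset.
X, y = make_moons(n_samples=500, noise=0.1, random_state=0)

# Shift one class upward so the two classes become linearly separable
# (the regime in which the learned decision boundary collapses to
# nearly a straight line).
X[y == 0, 1] += 1.0

# A small MLP trained with cross-entropy (log-loss); plotting its
# decision boundary reproduces the effect described above.
clf = MLPClassifier(hidden_layer_sizes=(100, 100), max_iter=2000, random_state=0)
clf.fit(X, y)
print(f"Training accuracy: {clf.score(X, y):.3f}")
```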

The reason gradient starvation happens in classification is that once a particular sample has been classified correctly using the easiest-to-learn features, its contribution to the loss and to the gradient becomes vanishingly small, effectively halting any further learning from that sample via gradient descent. The authors observe that training longer, changing the optimization algorithm, or applying existing regularizers such as weight decay, dropout, or batch normalization does not alleviate the problem.
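This decay is easy to check numerically: for the binary cross-entropy, the gradient of the loss with respect to the logit z is sigmoid(z) − y, which tends to zero as a sample becomes confidently correct. A quick NumPy illustration (mine, not from the paper):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# For binary cross-entropy, d(loss)/d(logit) = sigmoid(z) - y.
# Once a sample of class y = 1 is confidently correct (large logit),
# its gradient contribution all but vanishes.
y = 1.0
for z in [0.0, 2.0, 5.0, 10.0]:
    grad = sigmoid(z) - y
    print(f"logit = {z:5.1f}   per-sample gradient = {grad:+.6f}")
```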

In order to analyze the issue mathematically, the authors use the recently developed neural tangent kernel (NTK) framework. Under this framework (originally applied to the case of the squared-error loss), an overparameterized neural network does not significantly change its weights during training: with a very large number of weights, even a small change in each of them is sufficient to approximate the desired outputs. The output of the network can then be approximated by a Taylor expansion around the initial weight values, with only the first (linear) term being significant. The network is thus effectively linear in its weights, and the training outputs (as can be proven mathematically under certain overparameterization assumptions) converge with exponential decay towards the target outputs, with the decay governed by the neural tangent kernel matrix. The rate of convergence along each eigendirection of this matrix is set by the corresponding eigenvalue, so the directions with the largest eigenvalues (the easiest features) are learned first.
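In compact form, using standard NTK notation (chosen here for exposition, not copied from the paper):

```latex
% Linearization of the network around its initialization w_0
% (valid in the heavily overparameterized regime):
f(x; w) \approx f(x; w_0) + \nabla_w f(x; w_0)^{\top} (w - w_0)

% Under gradient flow on the squared-error loss, the vector u(t) of
% training-set outputs then obeys a linear ODE governed by the
% neural tangent kernel \Theta = J J^{\top}, with J the Jacobian at w_0:
\dot{u}(t) = -\Theta \, \bigl( u(t) - y \bigr)
\qquad\Longrightarrow\qquad
u(t) = y + e^{-\Theta t} \bigl( u(0) - y \bigr)
```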

Since the authors consider the cross-entropy loss (normally used for classification) instead of the squared-error loss (often used in regression), the resulting differential equation for the training dynamics is no longer linear. Using mathematical tricks (including a change of variables via the Legendre transform), the paper arrives at its expression (19) and, under certain simplifying assumptions, shows that an increase in the strength of one feature reduces the response of the network to the other features, preventing their learning and thus causing gradient starvation.
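Schematically, the nonlinearity arises because the cross-entropy gradient passes through a sigmoid; in the notation of the linear ODE above (a simplified sketch, not the paper's exact expression (19)):

```latex
% For the binary cross-entropy, the per-sample gradient with respect
% to the output is \sigma(u) - y instead of u - y, so the same NTK
% argument yields a nonlinear ODE:
\dot{u}(t) = -\Theta \, \bigl( \sigma(u(t)) - y \bigr)
```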

While the assumptions above may not always be true, the key insight is that, when weight decay regularization is used, different input features end up coupled with each other in the dynamical equation of the system, which is likely to cause some of them to dominate others.

To resolve this problem, the paper proposes a novel Spectral Decoupling (SD) regularization technique, in which the weight decay term in the loss function is replaced by an L2 penalty on the network's logits (its outputs before the final sigmoid). With this penalty the features become decoupled, so gradient starvation should not occur.
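In a deep-learning framework this amounts to a one-line change to the loss. A minimal PyTorch sketch (the function name and the value of the coefficient lam are mine; as described in the paper, the penalty replaces weight decay rather than supplementing it, so the optimizer's weight_decay should be set to zero):

```python
import torch
import torch.nn.functional as F

def spectral_decoupling_loss(logits: torch.Tensor,
                             targets: torch.Tensor,
                             lam: float = 0.01) -> torch.Tensor:
    """Cross-entropy plus an L2 penalty on the logits themselves,
    used in place of the usual weight-decay penalty on the weights."""
    ce = F.binary_cross_entropy_with_logits(logits, targets)
    sd_penalty = 0.5 * lam * (logits ** 2).mean()
    return ce + sd_penalty

# Example usage with dummy data:
logits = torch.randn(8, requires_grad=True)
targets = torch.randint(0, 2, (8,)).float()
loss = spectral_decoupling_loss(logits, targets)
loss.backward()
```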

To verify that the new regularization approach indeed prevents gradient starvation, the authors consider several tasks for which it is known to occur, including the “two moons” problem as in the picture above, Colored MNIST with a spurious color bias, and the CelebA celebrity faces dataset with a gender bias. In all cases, the new regularization term significantly improved generalization, apparently by preventing gradient starvation. The paper cites several other works where a similar effect was achieved, but each time by providing the model with some additional information about the features or input data. Spectral Decoupling needs no such information (which is important, since in practice it may not be available): it simply learns from all the features and thereby generalizes better.

While it seems clear that the new approach prevents gradient starvation, one question that comes to mind is whether Spectral Decoupling is guaranteed to provide the same degree of regularization as other methods such as weight decay. Since the penalty is applied only to the logits, it seems possible, unless there is a mathematical reason to the contrary, that a sufficiently overparameterized network could satisfy this term while remaining largely unregularized and overfitting the training set. Hopefully, this will be addressed in future work.

Original paper link

The authors’ blog post about the paper

GitHub repository

Suggested further reading
