Attention Augmented Differentiable Forest for Tabular Data

Review of paper by Yingshi Chen, Xiamen University, 2020

The author develops a new “differentiable forest” neural network framework for prediction on tabular data. It is similar to the recently proposed NODE architecture, but adds squeeze-and-excitation-style “tree attention blocks” (TABs), and it shows performance superior to gradient boosted decision trees (e.g. XGBoost, LightGBM, CatBoost) on a number of benchmarks.

What can we learn from this paper?

That neural networks built from decision-tree-like blocks, with an added attention mechanism, can be used effectively on tabular data.

Prerequisites (to better understand the paper, what should one be familiar with?)

Decision trees and gradient boosted decision trees (XGBoost, LightGBM, CatBoost); the softmax, sparsemax, and α-entmax functions; attention mechanisms and squeeze-and-excitation networks (SENet); and the NODE architecture.

Discussion

Neural networks have performed extremely well in recent years on a number of supervised and unsupervised machine learning tasks, especially in computer vision, audio, natural language processing, and other sequential-data domains. They efficiently exploit the structure of their inputs and the associated inductive biases, advancing the state of the art over traditional machine learning techniques such as support vector machines or decision trees.

However, in the case of tabular data, which frequently lacks any predetermined structure, may have unknown correlations, and may be sparse, deep neural networks are often inferior to other techniques, especially gradient boosted decision trees (GBDTs). This “unreasonable ineffectiveness of deep learning on tabular data” is discussed in more detail in this article.

In the last few years, there has been an effort to develop better neural-network-based methods to compete with and improve upon the results of GBDTs on tabular inputs. One somewhat general approach, which is discussed in the above-mentioned article, is the recent TabNet, which uses a sparse learnable mask on the input features and is even capable of self-supervised learning.

Another technique, NODE (Neural Oblivious Decision Ensembles), which is closer to the current paper, was developed in 2019 by Yandex, the Russian search engine company that also created CatBoost. NODE uses differentiable neural trees: instead of the hard splits of an ordinary decision tree (a Heaviside step function), each node applies a soft splitting function such as softmax, sparsemax, or the more general α-entmax. The α-entmax family includes softmax (α = 1) and sparsemax (α = 2) as special cases, but intermediate values such as α = 1.5 are used more often; choosing α > 1 promotes sparsity in the selected features. Like CatBoost, NODE uses oblivious decision trees, i.e. trees in which all internal nodes at the same depth share the same splitting feature and threshold (and are therefore oblivious to the particular input and to the path taken through the tree to reach the current node).
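
To make the soft splitting idea concrete, here is a minimal sketch of a single soft, oblivious tree, assuming plain softmax for feature selection (instead of NODE’s α-entmax) and a sigmoid in place of the Heaviside step; all class and parameter names are illustrative, not taken from either paper’s code:

```python
# Minimal sketch (not the paper's code) of one soft, oblivious decision tree.
# Assumptions: softmax is used for feature selection instead of alpha-entmax,
# and a sigmoid replaces the Heaviside step for the split itself.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftObliviousTree(nn.Module):
    def __init__(self, in_features: int, depth: int = 4, temperature: float = 1.0):
        super().__init__()
        self.depth = depth
        self.temperature = temperature
        # One feature-selection vector and one threshold per depth level:
        # "oblivious" means every node at the same depth shares them.
        self.feature_logits = nn.Parameter(torch.randn(depth, in_features))
        self.thresholds = nn.Parameter(torch.zeros(depth))
        # One response value per leaf (2**depth leaves).
        self.leaf_values = nn.Parameter(torch.randn(2 ** depth))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Soft feature selection: softmax instead of a hard argmax choice.
        feature_weights = F.softmax(self.feature_logits, dim=-1)   # (depth, in)
        selected = x @ feature_weights.t()                         # (batch, depth)
        # Soft split: sigmoid instead of the Heaviside step function.
        p_right = torch.sigmoid((selected - self.thresholds) / self.temperature)
        p_left = 1.0 - p_right
        # Probability of reaching each leaf = product of per-depth decisions.
        leaf_probs = torch.ones(x.size(0), 1, device=x.device, dtype=x.dtype)
        for d in range(self.depth):
            leaf_probs = torch.cat(
                [leaf_probs * p_left[:, d:d + 1], leaf_probs * p_right[:, d:d + 1]],
                dim=1,
            )                                                      # (batch, 2**depth)
        # Tree output = expected leaf value under the soft routing.
        return leaf_probs @ self.leaf_values                       # (batch,)
```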

Since a single differentiable tree is a weak learner, the predictions of many trees are combined in a random-forest-like fashion by averaging over all trees.
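
In code, this forest-level averaging is just a mean over the per-tree predictions; a minimal sketch reusing the illustrative SoftObliviousTree class above (again an assumption, not the author’s actual API):

```python
# Minimal sketch of a plain differentiable forest: average the predictions of
# many soft trees, random-forest style. Uses the illustrative SoftObliviousTree
# from the sketch above; class and parameter names are assumptions.
import torch
import torch.nn as nn

class DifferentiableForest(nn.Module):
    def __init__(self, in_features: int, num_trees: int = 128, depth: int = 4):
        super().__init__()
        self.trees = nn.ModuleList(
            [SoftObliviousTree(in_features, depth) for _ in range(num_trees)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Per-tree predictions stacked into (batch, num_trees), then averaged.
        tree_outputs = torch.stack([tree(x) for tree in self.trees], dim=1)
        return tree_outputs.mean(dim=1)
```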

In a previous paper, the author generalized the NODE architecture, keeping the α-entmax splitting function but dropping the requirement of oblivious trees; that is, in the author’s version, the nodes in each layer do not have to share the same gating function for all inputs.

The current paper appears to replace α-entmax with a simpler softmax splitting function and adds attention between trees in the form of a tree attention block (TAB). The TAB uses a squeeze-and-excitation operation similar to that in SENet: the squeeze operation averages the output of each tree, while the excitation part computes attention weights via fully connected layers. These attention weights are then applied to the predictions of the individual trees.
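
A rough sketch of what such a tree attention block might look like, assuming each tree emits a small response vector (so there is something to squeeze over) and an SENet-style reduction ratio; the class name, shapes, and the final averaging over trees are illustrative assumptions rather than the paper’s exact design:

```python
# Rough sketch of a squeeze-and-excitation style "tree attention block" (TAB),
# treating trees the way SENet treats channels. Shapes, names, and the
# reduction ratio are illustrative assumptions, not the paper's exact code.
import torch
import torch.nn as nn

class TreeAttentionBlock(nn.Module):
    def __init__(self, num_trees: int, reduction: int = 4):
        super().__init__()
        self.excitation = nn.Sequential(
            nn.Linear(num_trees, num_trees // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(num_trees // reduction, num_trees),
            nn.Sigmoid(),
        )

    def forward(self, tree_outputs: torch.Tensor) -> torch.Tensor:
        # tree_outputs: (batch, num_trees, response_dim) per-tree responses.
        # Squeeze: average each tree's response into a single descriptor.
        squeezed = tree_outputs.mean(dim=-1)                 # (batch, num_trees)
        # Excitation: fully connected layers produce per-tree attention weights.
        weights = self.excitation(squeezed)                  # (batch, num_trees)
        # Re-weight each tree's output by its attention weight,
        # then average over trees to get the forest prediction.
        attended = tree_outputs * weights.unsqueeze(-1)      # (batch, num_trees, response_dim)
        return attended.mean(dim=1)                          # (batch, response_dim)
```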

The author compares the performance of the new technique against several gradient boosted decision tree models, as well as NODE and mGBDT, using a number of common tabular datasets such as Higgs, Click, Year Prediction, Microsoft, Yahoo, and EPSILON. The results look quite good in comparison, improving upon the performance of other algorithms in a number of cases.

It is worth noting that the results for all methods are given with default (untuned) hyperparameters. The previous papers also give results with tuned hyperparameters; it would be interesting to see how the suggested model performs in that case.

The author also looks at the memory usage of the new differentiable forest and finds it to be higher than that of the other models on smaller datasets but lower on larger ones: it stays fairly constant, whereas the other models’ memory consumption tends to grow roughly in proportion to the size of the dataset.

Overall, although the suggested approach is still under development and can likely be improved, this is a very interesting paper, especially considering that it was written by a graduate student rather than a big company like Google or Yandex. Hopefully, the model will be further developed and refined to become one of the tools commonly used by machine learning practitioners for tabular datasets.

Original paper link

GitHub repository

Suggested reading
