Not everything is linear: the example of Random Forest

Linear regression is great. Unfortunately, not everything in nature is linear. If you drink alcohol, you get drunk. If you take your prescribed drugs, you are healthy. But if you do both at the same time, you will not simply be drunk and healthy: you will probably get very sick. This is an interaction. In general, we talk about interaction when the effect of one feature depends on the value of another, so that the combined effect is not the sum of the individual effects. This is a departure from linearity. There are many ways to try to capture interactions using statistical learning, but today I will focus on Random Forest. Before I explain what a forest is, though, I have to explain what a decision tree is.
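To make "departure from linearity" concrete, here is a minimal numpy sketch on a hypothetical 2x2 design: the four alcohol/pills combinations coded as 0/1, with sickness only when both are present (illustrative numbers, not data from an experiment). An additive linear model cannot reproduce this pattern, while adding a product column, an explicit interaction term, fits it exactly.

```python
import numpy as np

# Hypothetical 2x2 toy design (illustrative only):
# alcohol in {0, 1}, pills in {0, 1}; sick only when both are 1.
alcohol = np.array([0.0, 1.0, 0.0, 1.0])
pills = np.array([0.0, 0.0, 1.0, 1.0])
sick = np.array([0.0, 0.0, 0.0, 1.0])

def fit_residuals(design, y):
    """Least-squares fit of y on the columns of design; return y - X @ beta."""
    beta, *_ = np.linalg.lstsq(design, y, rcond=None)
    return y - design @ beta

ones = np.ones_like(alcohol)
additive = np.column_stack([ones, alcohol, pills])                      # intercept + main effects
with_interaction = np.column_stack([ones, alcohol, pills, alcohol * pills])  # + product term

additive_res = fit_residuals(additive, sick)
interaction_res = fit_residuals(with_interaction, sick)

print(np.round(additive_res, 3))     # every point missed by 0.25 in absolute value
print(np.round(interaction_res, 3))  # zero up to floating-point error
```

The best additive fit predicts -0.25, 0.25, 0.25 and 0.75 for the four groups, so it misses every point by 0.25; no choice of separate alcohol and pills effects can produce "only sick when both".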

“Erik – Prunus sp 02” by Zeynel Cebeci – Own work. Licensed under CC BY-SA 4.0 via Wikimedia Commons – https://commons.wikimedia.org/wiki/File:Erik_-_Prunus_sp_02.JPG#/media/File:Erik_-_Prunus_sp_02.JPG

The good people at www.r2d3.us did a great job of explaining what a decision tree is in a very visual way, so go and have a look at it. (Also, note the subtle Star Wars reference in their name.)

Now you know what a decision tree is. Decision trees have several nice properties: they are easily interpretable and they can capture interactions. Take our toy example again: the blood alcohol level and the number of pills taken are the features, and the adverse outcome is the label (what we want to predict). Let us imagine an experiment where we give some alcohol and some drugs to mice (bioethics forbids us from running this thought experiment on humans). For each mouse, we record its blood alcohol content, the number of pills it took, and whether or not it had an adverse outcome.

Our tree-growing algorithm will first split the data on one of the two features, at a cut-off value that separates the adverse outcomes well from the benign ones, say alcohol > 1 g/L. The part of the data with alcohol > 1 g/L will then be split on the second feature, say pills > 2. In the leaf corresponding to the mix of alcohol and pills, we will find many adverse outcomes, while the other leaves will contain few. This tree is easily interpretable: the leaf we just considered says that a mouse with more than 1 g/L of alcohol that takes more than 2 pills has an estimated probability of 0.8 (12 of the 15 mice with those features) of having an adverse outcome. For each leaf, we can make a similar statement, using the data in the leaf to estimate the probability of an adverse outcome.
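The tree described above can be written as two nested tests, and a leaf's probability estimate is just the fraction of adverse outcomes among the training mice that land in it. This is a minimal sketch: the `Mouse` record, the leaf names and the mice themselves are hypothetical, built only to reproduce the 12 out of 15 count from the text, and the alcohol <= 1 g/L branch is left unsplit because the text does not describe it.

```python
from collections import namedtuple

# Hypothetical record for one mouse in the thought experiment.
Mouse = namedtuple("Mouse", ["alcohol_g_per_l", "pills", "adverse"])

def leaf_of(mouse):
    """Route a mouse through the two splits described in the text."""
    if mouse.alcohol_g_per_l > 1.0:      # first split: alcohol > 1 g/L
        if mouse.pills > 2:              # second split: pills > 2
            return "alcohol>1 & pills>2"
        return "alcohol>1 & pills<=2"
    return "alcohol<=1"                  # not split further in this sketch

def leaf_probability(mice, leaf):
    """Estimate P(adverse | leaf) as the share of adverse outcomes in the leaf."""
    in_leaf = [m for m in mice if leaf_of(m) == leaf]
    if not in_leaf:
        raise ValueError(f"no training mouse falls in leaf {leaf!r}")
    return sum(m.adverse for m in in_leaf) / len(in_leaf)

# Hypothetical mice: 12 adverse out of 15 in the alcohol-and-pills leaf.
mice = ([Mouse(1.5, 3, True)] * 12
        + [Mouse(1.5, 3, False)] * 3
        + [Mouse(0.5, 1, False)] * 5)

print(leaf_probability(mice, "alcohol>1 & pills>2"))  # 0.8
```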

One question left open by r2d3 is when to stop growing the tree to avoid overfitting. For single trees, penalization techniques exist to prune a fully grown tree. A reference on this subject is The Elements of Statistical Learning by Hastie, Tibshirani and Friedman, pages 305-313 (pages 324-332 in the PDF). This book is a general reference for statistical learning that explains many algorithms, including random forest.
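In that reference, the penalization takes the form of cost-complexity (weakest-link) pruning. As a sketch of the criterion in the book's notation: for a subtree $T$ obtained by collapsing internal nodes of the fully grown tree $T_0$, with $|T|$ leaves, $N_m$ observations in leaf $m$ and impurity $Q_m(T)$ of leaf $m$,

$$C_\alpha(T) = \sum_{m=1}^{|T|} N_m \, Q_m(T) + \alpha \, |T|, \qquad \alpha \ge 0.$$

For each $\alpha$ one keeps the subtree minimizing $C_\alpha$: larger values of $\alpha$ favor smaller trees, $\alpha = 0$ gives back the full tree $T_0$, and $\alpha$ itself is chosen by cross-validation.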

Unfortunately, decision trees also have drawbacks. They are not so great at prediction because they have high variance: if you tweak the data a little, you can end up with a very different tree. All the splits are conditional on the previous ones, so if the first split changes, the entire tree below it changes too.
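A small pure-Python sketch of this instability on hypothetical one-feature data. The split is chosen by minimizing the weighted Gini impurity, a common criterion assumed here for illustration. Changing the label of a single observation out of six moves the chosen cut-off from 3.5 to 1.5; in a full tree, every split below the root would change with it.

```python
def gini(labels):
    """Gini impurity of a list of 0/1 labels."""
    if not labels:
        return 0.0
    p = sum(labels) / len(labels)
    return 2 * p * (1 - p)

def best_threshold(x, y):
    """Return the midpoint cut-off that minimizes the weighted Gini impurity."""
    pairs = sorted(zip(x, y))
    n = len(pairs)
    best = None
    for i in range(1, n):
        if pairs[i - 1][0] == pairs[i][0]:
            continue  # no cut-off can separate equal feature values
        left = [label for _, label in pairs[:i]]
        right = [label for _, label in pairs[i:]]
        score = (len(left) * gini(left) + len(right) * gini(right)) / n
        threshold = (pairs[i - 1][0] + pairs[i][0]) / 2
        if best is None or score < best[0]:
            best = (score, threshold)
    return best[1]

x = [1, 2, 3, 4, 5, 6]
y = [0, 1, 0, 1, 1, 1]
y_tweaked = [0, 1, 1, 1, 1, 1]  # a single label changed

print(best_threshold(x, y))          # 3.5
print(best_threshold(x, y_tweaked))  # 1.5
```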

“Forrest from Clérey in winter”. Licensed under CC BY-SA 3.0 via Wikimedia Commons – https://commons.wikimedia.org/wiki/File:Forrest_from_Cl%C3%A9rey_in_winter.JPG#/media/File:Forrest_from_Cl%C3%A9rey_in_winter.JPG

The random forest algorithm, introduced by Leo Breiman in 2001, tries to take advantage of this variance. A forest is simply a collection of trees, each grown according to the same random algorithm. Instead of using the entire sample, each tree is grown on a bootstrap version of the sample, that is, a sample of the same size drawn with replacement from the original one: some data points will be missing and some may be selected more than once. The way the tree is grown is also random: instead of searching all the features for the best cut-off at each node, $m_{try}$ features are selected at random and the best cut-off is chosen among those features only. A node is not split if it has fewer than min_leaf observations. The parameters have widely used default values that are, in my experience, hard to improve on.
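The two sources of randomness can be sketched in a few lines of Python. The helper names `bootstrap_indices` and `candidate_features` are hypothetical, and $m_{try} = 3$ out of 10 features is an arbitrary choice for the example. On average, a bootstrap sample of size $n$ leaves out a fraction $(1 - 1/n)^n \approx e^{-1} \approx 0.37$ of the original observations, which is where out-of-bag observations come from.

```python
import random

def bootstrap_indices(n, rng):
    """Draw n row indices with replacement: some rows repeat, others are left out."""
    return [rng.randrange(n) for _ in range(n)]

def candidate_features(n_features, m_try, rng):
    """Draw m_try distinct feature indices; the node's best cut-off is searched among these only."""
    return rng.sample(range(n_features), m_try)

rng = random.Random(0)
n = 1000
idx = bootstrap_indices(n, rng)
left_out = 1 - len(set(idx)) / n
print(f"fraction left out: {left_out:.3f}")  # typically close to 1/e, about 0.368

print(candidate_features(n_features=10, m_try=3, rng=rng))  # 3 distinct indices in 0..9
```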

For any new data point, we predict the label by aggregating the individual predictions of the trees: we take the mean for regression (quantitative label) and a majority vote for classification (binary label).
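A minimal sketch of the aggregation step, assuming each tree's prediction for the new data point has already been computed. The lists of tree predictions in the usage lines are made up for illustration.

```python
from collections import Counter
from statistics import mean

def aggregate_regression(tree_predictions):
    """Regression: average the trees' numeric predictions."""
    return mean(tree_predictions)

def aggregate_classification(tree_predictions):
    """Classification: return the label predicted by the most trees."""
    return Counter(tree_predictions).most_common(1)[0][0]

print(aggregate_regression([2.0, 3.0, 4.0]))                        # 3.0
print(aggregate_classification(["adverse", "benign", "adverse"]))  # adverse
```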

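Random forests are not as interpretable as a single decision tree. However, the algorithm can be used to rank features by importance using permutations: the importance of a feature is measured by how much the prediction error increases when its values are randomly shuffled across observations.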
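A pure-Python sketch of the permutation idea. To keep it short, a fixed rule (the two-split tree from the mouse example) stands in for a fitted forest, the error measure is the misclassification rate, and the mice, including a body-weight feature that plays no role, are hypothetical. Breiman's original version computes this on each tree's out-of-bag observations; here it is computed on one dataset for simplicity.

```python
import random

def model(row):
    """Hypothetical stand-in for a fitted forest: the two-split rule from the tree example."""
    alcohol, pills, _weight = row
    return int(alcohol > 1.0 and pills > 2)

def error_rate(rows, labels):
    """Fraction of rows the model misclassifies."""
    return sum(model(r) != y for r, y in zip(rows, labels)) / len(rows)

def permutation_importance_sketch(rows, labels, feature, rng):
    """Increase in error after shuffling one feature column across rows."""
    baseline = error_rate(rows, labels)
    column = [r[feature] for r in rows]
    rng.shuffle(column)
    permuted = [r[:feature] + (v,) + r[feature + 1:] for r, v in zip(rows, column)]
    return error_rate(permuted, labels) - baseline

rng = random.Random(42)
# Hypothetical mice: (alcohol in g/L, pills, body weight in g); weight is ignored by the rule.
rows = [(rng.uniform(0, 2), rng.randint(0, 5), rng.uniform(20, 30)) for _ in range(500)]
labels = [model(r) for r in rows]

for name, j in [("alcohol", 0), ("pills", 1), ("weight", 2)]:
    print(name, permutation_importance_sketch(rows, labels, j, rng))
```

Shuffling alcohol or pills breaks the link between that feature and the outcome, so the error goes up; shuffling weight changes nothing, so its importance is exactly zero.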

Another nice feature is that random forest can estimate the generalization error of the model from the training set alone. Because each tree is grown on a bootstrap sample, some observations are left out when growing a given tree; we say that these observations are out-of-bag for that tree. For each observation, we can aggregate the predictions of all the trees for which it is out-of-bag to get an out-of-bag prediction, and compare it with the observed label. Since the observation was not used to grow those trees, this estimate should not suffer from overfitting. It is still not as satisfying as estimating the generalization error on separate test data.
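A pure-Python sketch of the out-of-bag estimate. To stay short, each "tree" is a depth-one tree (a stump) on a single feature, so there is no $m_{try}$ step, and the data are hypothetical: label 1 above 0.5 with 10% of the labels flipped. Each observation only receives votes from the stumps whose bootstrap sample did not contain it.

```python
import random
from collections import Counter

def fit_stump(xs, ys):
    """Depth-one tree: the threshold and side labels with the fewest misclassifications."""
    best = None
    for t in sorted(set(xs)):
        left = [y for x, y in zip(xs, ys) if x <= t]
        right = [y for x, y in zip(xs, ys) if x > t]
        left_label = Counter(left).most_common(1)[0][0]
        right_label = Counter(right).most_common(1)[0][0] if right else left_label
        errors = sum(y != left_label for y in left) + sum(y != right_label for y in right)
        if best is None or errors < best[0]:
            best = (errors, t, left_label, right_label)
    _, t, left_label, right_label = best
    return lambda x: left_label if x <= t else right_label

def oob_error(xs, ys, n_trees, rng):
    """Grow stumps on bootstrap samples; score each row only with trees that did not see it."""
    n = len(xs)
    votes = [Counter() for _ in range(n)]
    for _ in range(n_trees):
        idx = [rng.randrange(n) for _ in range(n)]
        stump = fit_stump([xs[i] for i in idx], [ys[i] for i in idx])
        for i in set(range(n)) - set(idx):  # out-of-bag rows for this tree
            votes[i][stump(xs[i])] += 1
    scored = [i for i in range(n) if votes[i]]  # rows that were out-of-bag at least once
    wrong = sum(votes[i].most_common(1)[0][0] != ys[i] for i in scored)
    return wrong / len(scored), len(scored)

rng = random.Random(1)
xs = [rng.random() for _ in range(200)]
# Hypothetical labels: 1 above 0.5, with 10% of labels flipped as noise.
ys = [int(x > 0.5) if rng.random() > 0.1 else int(x <= 0.5) for x in xs]
err, n_scored = oob_error(xs, ys, n_trees=50, rng=rng)
print(f"OOB error {err:.2f} on {n_scored} observations")
```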

Random forest is computationally efficient: each source of randomness in the growing algorithm saves computation time, since each tree uses only part of the data and each node split searches for the best cut-off among only a few features. More importantly, since each tree in the forest is grown independently of the others, the trees can be grown in parallel.
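A sketch of why the trees parallelize: each tree depends only on the training data and its own random seed. The example grows depth-one regression trees on hypothetical 1-D data with Python's `ThreadPoolExecutor`. Threads keep the example simple, but in CPython, pure-Python CPU-bound work needs processes (`ProcessPoolExecutor`) for a real speedup. In scikit-learn, the same idea is exposed through the `n_jobs` parameter of the random forest estimators.

```python
import random
from concurrent.futures import ThreadPoolExecutor

# Hypothetical 1-D regression data (a jump at x = 1 plus noise), shared read-only by every tree.
_data_rng = random.Random(0)
XS = [_data_rng.uniform(0, 2) for _ in range(100)]
YS = [float(x > 1.0) + _data_rng.gauss(0, 0.1) for x in XS]

def sse(values):
    """Sum of squared deviations from the mean."""
    m = sum(values) / len(values)
    return sum((v - m) ** 2 for v in values)

def grow_tree(seed):
    """Grow one depth-one regression tree on its own bootstrap sample.

    The only input is the seed, so no tree depends on any other tree."""
    rng = random.Random(seed)
    idx = [rng.randrange(len(XS)) for _ in XS]
    pts = sorted((XS[i], YS[i]) for i in idx)
    best = None
    for k in range(1, len(pts)):
        if pts[k - 1][0] == pts[k][0]:
            continue  # duplicated rows from the bootstrap cannot be separated
        left = [y for _, y in pts[:k]]
        right = [y for _, y in pts[k:]]
        score = sse(left) + sse(right)
        if best is None or score < best[0]:
            cut = (pts[k - 1][0] + pts[k][0]) / 2
            best = (score, cut, sum(left) / len(left), sum(right) / len(right))
    return best[1:]  # (cut-off, left mean, right mean)

with ThreadPoolExecutor(max_workers=4) as pool:
    forest = list(pool.map(grow_tree, range(50)))

# Same seeds grown one after the other: identical trees, in the same order.
sequential = [grow_tree(s) for s in range(50)]
print(forest == sequential)  # True
```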

Random forest is implemented in R and in the scikit-learn package for Python. I initially used the R implementation but had to move over to Python, as its implementation offers more flexibility in the way you aggregate the predictions. It is also probably faster in Python.

To summarize, the idea is to average over many noisy estimators to get a better prediction. Random forest is popular for several reasons: it is easy to use since no parameter calibration is required, it achieves very good predictive performance, it is computationally efficient, and implementations are available in popular programming languages.
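The benefit of averaging can be quantified. If the $B$ tree predictions at a point are identically distributed with variance $\sigma^2$ and pairwise correlation $\rho$, their average has variance $\rho\sigma^2 + (1-\rho)\sigma^2/B$ (see the random forest chapter of The Elements of Statistical Learning). Growing more trees shrinks the second term, while decorrelating the trees, which is what the random choice of $m_{try}$ features aims at, shrinks the first. The Monte Carlo check in this sketch uses Gaussian variables built to have exactly that correlation, an assumption made only for the illustration.

```python
import random

def averaged_variance(rho, sigma2, n_trees):
    """Variance of the average of n_trees equicorrelated predictions."""
    return rho * sigma2 + (1 - rho) * sigma2 / n_trees

def simulate(rho, n_trees, reps, rng):
    """Empirical variance of the average of n_trees unit-variance Gaussians with correlation rho."""
    averages = []
    for _ in range(reps):
        shared = rng.gauss(0, 1)  # common component shared by all trees, induces correlation rho
        preds = [rho ** 0.5 * shared + (1 - rho) ** 0.5 * rng.gauss(0, 1) for _ in range(n_trees)]
        averages.append(sum(preds) / n_trees)
    m = sum(averages) / reps
    return sum((a - m) ** 2 for a in averages) / reps

rng = random.Random(0)
print(averaged_variance(rho=0.3, sigma2=1.0, n_trees=10))  # about 0.37
print(simulate(rho=0.3, n_trees=10, reps=20000, rng=rng))  # close to 0.37
```

Even with infinitely many trees, the variance cannot go below $\rho\sigma^2$, which is why decorrelating the trees matters as much as growing many of them.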

Theory

Let us now turn to the theoretical understanding of the algorithm. Compared to the lasso, much less has been accomplished. The best result was obtained recently by Erwan Scornet, Gérard Biau (who happens to be my PhD supervisor) and Jean-Philippe Vert in their article Consistency of random forests. Consistency means that as the number of observations $n$ goes to infinity, the estimate of the regression function converges, in some sense, to the true regression function. It is a property satisfied by many algorithms. An important theorem by Stone (1977) states that the $k$-nearest neighbor algorithm is consistent, provided that $k \to \infty$ and $k/n \to 0$ as $n \to \infty$. For a new observation, the nearest neighbor algorithm takes the $k$ closest observations and aggregates their labels. A key difference between the two algorithms is that the nearest neighbor algorithm picks the observations it uses for prediction without looking at their labels, which makes consistency easier to prove. We expect that using the labels to pick the observations can only help achieve better predictions, but it makes consistency harder to prove.
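A minimal $k$-nearest neighbor regression sketch makes the contrast visible; Euclidean distance, a plain average of the labels, the function name and the toy points are all assumptions for illustration. The neighbors are chosen from the features alone, and the labels only appear in the final average. In a random forest, by contrast, the labels already shape the splits that decide which observations end up together in a leaf.

```python
import math

def knn_predict(train_x, train_y, x_new, k):
    """Average the labels of the k training points closest to x_new."""
    # Step 1: choose neighbors using the features only; train_y is not looked at.
    order = sorted(range(len(train_x)), key=lambda i: math.dist(train_x[i], x_new))
    neighbors = order[:k]
    # Step 2: the labels enter only here, when the neighbors' labels are averaged.
    return sum(train_y[i] for i in neighbors) / k

train_x = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (5.0, 5.0)]
train_y = [1.0, 2.0, 3.0, 10.0]
print(knn_predict(train_x, train_y, (0.1, 0.1), k=3))  # 2.0
```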

The article mentioned above proves the consistency of random forests under somewhat strong assumptions. Nevertheless, this work is important because it keeps the use of labels in the way the trees are grown, which is precisely what makes the algorithm hard to analyze.