Random Forests in R
In this tutorial, we will go over what Random Forests are, what they are used for, and how to use them in an R environment.
Why do we need Random Forests?
You might be familiar with the concept of Decision Trees: a probabilistic predictive model that can be used to classify data in a wide array of applications. Decision Trees are created by observing data points. A probabilistic model is built by looking at the features present in the points labeled with each class, and then associating probabilities with those features.
Decision Trees are very interesting because one can follow the structure of the tree to understand how a class was inferred. However, this kind of model is not without its problems. One of the main ones is overfitting. Overfitting happens when the tree-building process produces an extremely branched and complex tree, which means the model will not generalize well to data it has not seen before.
This can happen when the data points are too varied, or when there are too many features to be analyzed at once. However, cutting down the number of data points or features might make our model worse, so we need a different kind of solution to this problem.
What are Random Forests?
Random Forests are one of the proposed solutions. As one might infer from the name, a Random Forest is composed of multiple Decision Trees. This makes it part of a family of models, composed of other models working in tandem, called ensemble learning models. The main concept behind Random Forests is that if you partition the data that would be used to create a single decision tree, create one tree for each of these partitions, and then use a method to "average" the results of all of these different trees, you should end up with a better model. For classification, this "average" is the mode of the classes predicted by the trees in the forest; for regression, it is the mean of their predictions.
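As a toy illustration of the classification case, taking the mode of a handful of votes in base R could look like this (the votes here are made up for the example):

# Class votes cast by three hypothetical trees for a single observation.
votes <- c("setosa", "virginica", "virginica")
# The forest's prediction is the most frequent vote (the mode), here "virginica".
names(which.max(table(votes)))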
The main mechanism behind Random Forests is bagging, which is short for bootstrap aggregating. Bagging means randomly sampling data from a dataset with replacement. In practice, this means some data points will appear more than once in a given sample while others will not appear at all; on average, about 63% of the unique examples end up in each sample. This helps the tree built from that bag generalize better to some degree. Each tree in the Random Forest is trained on its own bootstrap sample of the training data.
You might be asking yourself what happens to the data that is not present in a given "bag". This data, aptly called out-of-bag (OOB) data, serves as a kind of test set for the tree grown on that bag, and the resulting OOB error estimate gives us built-in validation that our model works!
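To make this concrete, here is a small base-R sketch (no modelling involved) of drawing one bootstrap "bag" of row indices from iris and checking how much of the data ends up in it versus out of bag:

# Draw one bootstrap sample ("bag") of row indices, with replacement.
set.seed(1)
n <- nrow(iris)
bag <- sample(n, size = n, replace = TRUE)
# Fraction of the original rows that appear in the bag: typically around 0.63.
length(unique(bag)) / n
# The rows that were never drawn are the out-of-bag (OOB) data for this bag.
oob <- setdiff(seq_len(n), bag)
length(oob) / n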
Additionally, Random Forests use feature bagging as well, which reduces the risk of overfitting when there is a large number of features for a small amount of data. Without it, a few very strong predictors would be selected in most of the trees, and those trees would become correlated with one another. Sampling the features forces the forest not to focus only on the strongest predictors of the data it was fed, which makes the model generalize better. Traditionally, a dataset with f features will have about ⌈√f⌉ features considered at each split.
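For the iris data used below this works out to 2 features per split. Note that the randomForest package's documented default for classification actually uses the floor rather than the ceiling, but with 4 predictors both give the same number:

# Number of predictor columns in iris (everything except Species).
f <- ncol(iris) - 1
# Features tried at each split for a classification forest: floor(sqrt(4)) = 2.
floor(sqrt(f))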
Using Random Forests in R
Now that you know what Random Forests are, we can move on to using them in R. Conveniently enough, CRAN (R's package repository) has a package ready for us, aptly named randomForest. We first need to install it and then load it into our session.
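Installing from CRAN only needs to be done once:

# Download and install the package from CRAN (one-time step).
install.packages("randomForest")

Once installed, the package is loaded with library():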
library(randomForest)
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
We can now go ahead and create the model. For this example, we will be using the built-in iris dataset. Feel free to try using other datasets!
To create the model, we will use the randomForest function. It has a wide array of parameters for customization, but the simplest approach is just to provide it with a formula and the dataset to learn from. This can be seen in the following code:
# Create the Random Forest model.
# The randomForest function accepts a "formula" structure as its main parameter. In this case, "Species" will be the variable
# to be predicted, while the others will be the predictors.
myLittleForest <- randomForest(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, data = iris)
# Print the summary of our model.
print(myLittleForest)
##
## Call:
## randomForest(formula = Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, data = iris)
## Type of random forest: classification
## Number of trees: 500
## No. of variables tried at each split: 2
##
## OOB estimate of error rate: 4%
## Confusion matrix:
## setosa versicolor virginica class.error
## setosa 50 0 0 0.00
## versicolor 0 47 3 0.06
## virginica 0 3 47 0.06
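The summary above was produced with the default settings (500 trees, 2 variables tried at each split). As a quick sketch of the customization mentioned earlier, the ntree and mtry arguments control those two defaults, and the fitted forest can then be used with predict(); the variable name and the choice of 250 trees below are just for illustration:

# Grow a smaller forest, trying 3 of the 4 predictors at each split.
smallerForest <- randomForest(Species ~ ., data = iris, ntree = 250, mtry = 3)
# Predict the species of the first few rows (in practice you would use unseen data).
predict(smallerForest, newdata = head(iris))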
Another statistic that can be quite informative is the importance of each predictor for the prediction of our data points. This can be obtained with the importance function, as seen in the following code:
print(importance(myLittleForest, type=2))
## MeanDecreaseGini
## Sepal.Length 9.640125
## Sepal.Width 2.208160
## Petal.Length 43.513029
## Petal.Width 43.888339
In this case, it seems that the petal measurements (length and width) are what best distinguish the species: the larger a variable's mean decrease in Gini impurity, the more it contributes to splitting the data into pure classes.
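If you prefer a graphical view, the randomForest package also provides the varImpPlot function, which plots the same importance measures (the plot title below is just an example):

# Plot variable importance for the fitted forest.
varImpPlot(myLittleForest, main = "Variable importance for iris")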
Visualize the tree
The randomForest package does not offer a convenient way to plot one of its individual trees, so to get a visual feel for a single tree we can grow a forest of conditional inference trees with the party package's cforest function and plot one of those instead.
library(party)
## Loading required package: grid
## Loading required package: mvtnorm
## Loading required package: modeltools
## Loading required package: stats4
## Loading required package: strucchange
## Loading required package: zoo
##
## Attaching package: 'zoo'
## The following objects are masked from 'package:base':
##
## as.Date, as.Date.numeric
## Loading required package: sandwich
# Grow a forest of conditional inference trees on the same data.
cf <- cforest(Species ~ ., data = iris)
We can then extract a single tree from the ensemble and wrap it in a BinaryTree object so that it can be plotted:
# Extract the first tree of the ensemble and attach the predictor names.
pt <- prettytree(cf@ensemble[[1]], names(cf@data@get("input")))
# Wrap it in a BinaryTree object so that party's plot method can be used.
nt <- new("BinaryTree")
nt@tree <- pt
nt@data <- cf@data
nt@responses <- cf@responses
# Plot the extracted tree.
plot(nt, type = "simple")