Practical Random Forest and repeated cross validation in R

This document walks through the steps of building a classification model with a random forest in R.



Importing packages required

The following code installs the required packages if they're not yet available, then loads them into the current R session.

## (1) Define the packages that will be needed
## randomForest is needed because caret's method = 'rf' calls it
packages <- c('dplyr', 'ggplot2', 'caret', 'randomForest')

## (2) Install them if not yet installed
installed_packages <- packages %in% rownames(installed.packages())
if (any(!installed_packages)) {
  install.packages(packages[!installed_packages])
}

## (3) Load the packages into R session
invisible(lapply(packages, library, character.only = TRUE))



Loading the dataset

The dataset we'll use is R's built-in iris dataset, which we load with the data() function.

## Get the iris dataset
data("iris")

## View first few rows
head(iris)
##   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
## 1          5.1         3.5          1.4         0.2  setosa
## 2          4.9         3.0          1.4         0.2  setosa
## 3          4.7         3.2          1.3         0.2  setosa
## 4          4.6         3.1          1.5         0.2  setosa
## 5          5.0         3.6          1.4         0.2  setosa
## 6          5.4         3.9          1.7         0.4  setosa
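Before modeling, it can help to confirm the structure of the data and how balanced the classes are. This base R check is a small addition for illustration; the commented values reflect the standard layout of iris (150 rows, four numeric measurements, and 50 flowers per species).

```r
## Load the built-in iris data and inspect its structure
data("iris")
str(iris)              # 150 obs. of 5 variables: 4 numeric columns + the Species factor
table(iris$Species)    # 50 setosa, 50 versicolor, 50 virginica
```

Balanced classes mean that plain accuracy is a reasonable metric for this problem.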



Creating repeated k-fold cross validation

We want the model to perform well on unseen data. To estimate how well it will do, and to choose its tuning parameters, we use repeated k-fold cross-validation. As quoted from Machine Learning Mastery:

Cross-validation is primarily used in applied machine learning to estimate the skill of a machine learning model on unseen data. That is, to use a limited sample in order to estimate how the model is expected to perform in general when used to make predictions on data not used during the training of the model.


The general procedure is as follows:

  1. Shuffle the data randomly.

  2. Split it into k groups (folds). k is a parameter you choose; there is no fixed rule for picking it, but 5 or 10 are typical values.

  3. For each unique group:

     1. Take the group as a hold out or test data set
    
     2. Take the remaining groups as a training data set
    
     3. Fit a model on the training set and evaluate it on the test set
    
     4. Retain the evaluation score and discard the model
  4. Summarize the skill of the model using the sample of model evaluation scores
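The steps above can be sketched in base R. This is an illustration only: the nearest-centroid classifier and the helpers fit_centroids(), predict_centroids() and kfold_accuracy() are hypothetical stand-ins chosen so the example needs no extra packages. It is not how caret implements cross-validation, and a random forest would take the classifier's place in the real workflow.

```r
## Hypothetical stand-in model: one mean vector (centroid) per class
fit_centroids <- function(x, y) {
  do.call(rbind, lapply(split(x, y), colMeans))   # rows = classes
}

## Predict the class whose centroid is closest (squared Euclidean distance)
predict_centroids <- function(centroids, x) {
  xm <- as.matrix(x)
  d <- apply(centroids, 1, function(ctr) rowSums(sweep(xm, 2, ctr)^2))
  factor(rownames(centroids)[max.col(-d, ties.method = "first")],
         levels = rownames(centroids))
}

## k-fold cross-validation following the numbered steps
kfold_accuracy <- function(x, y, k = 5) {
  # Steps 1-2: shuffle the rows and assign each one to one of k folds
  folds <- sample(rep(seq_len(k), length.out = nrow(x)))
  scores <- numeric(k)
  for (i in seq_len(k)) {                           # Step 3: for each fold
    test <- folds == i                              # hold-out fold
    model <- fit_centroids(x[!test, ], y[!test])    # fit on the remaining folds
    pred <- predict_centroids(model, x[test, ])     # evaluate on the hold-out fold
    scores[i] <- mean(pred == y[test])              # keep the score, discard the model
  }
  scores
}

set.seed(123)
scores <- kfold_accuracy(iris[, 1:4], iris$Species, k = 5)
mean(scores)   # Step 4: summarize the model's skill
```

Returning the per-fold scores, rather than only their mean, also lets you look at their spread across folds.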

Repeated k-fold cross-validation simply runs k-fold cross-validation n times, with a different random split into folds each time, and averages the scores across all repeats. Like k, n is chosen by the user. Let's set this up in R with the caret package.

## Set seed for reproducibility
set.seed(123)

## Define repeated cross validation with 5 folds and three repeats
repeat_cv <- trainControl(method='repeatedcv', number=5, repeats=3)
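To make the "repeat n times" part concrete, here is a base R sketch that draws a fresh fold assignment for each repeat. repeated_folds() is a hypothetical helper, not a caret function, and caret's own implementation also stratifies the folds by class, which this sketch skips. With number=5 and repeats=3, the model is fit 15 times (5 folds times 3 repeats) for each candidate tuning value.

```r
## Hypothetical helper: fold assignments for repeated k-fold cross-validation.
## Column r holds the fold id (1..k) of every row in repeat r.
repeated_folds <- function(n, k = 5, repeats = 3) {
  sapply(seq_len(repeats), function(r) sample(rep(seq_len(k), length.out = n)))
}

set.seed(123)
folds <- repeated_folds(nrow(iris), k = 5, repeats = 3)
dim(folds)          # 150 rows, 3 repeats
table(folds[, 1])   # each fold holds 30 rows in every repeat

## Rows that share a fold in repeat 1 are usually spread out in repeat 2
table(repeat1 = folds[, 1], repeat2 = folds[, 2])
```

Because each repeat reshuffles the rows, averaging over repeats reduces the dependence of the estimate on any single fold split.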



Splitting the dataset into training and testing sets

Our goal is a random forest model that also works on unseen data, so we split the data into training and testing sets. We train the model on the training set and use the testing set to measure the model's accuracy on data it has never seen.

## Set seed for reproducibility
set.seed(123)

## Split the data so that we use 70% of it for training
train_index <- createDataPartition(y=iris$Species, p=0.7, list=FALSE)

## Subset the data
training_set <- iris[train_index, ]
testing_set <- iris[-train_index, ]
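createDataPartition() samples within each level of Species, so the class proportions are preserved in both sets. The base R sketch below imitates that idea for illustration only: it will not reproduce caret's exact row indices, and the object names (train_idx, train_sketch, test_sketch) are hypothetical, chosen so they don't overwrite the sets created above.

```r
## Stratified 70/30 split in base R: sample 70% of the row indices within each species
set.seed(123)
rows_by_class <- split(seq_len(nrow(iris)), iris$Species)
train_idx <- unlist(lapply(rows_by_class,
                           function(idx) sample(idx, round(0.7 * length(idx)))),
                    use.names = FALSE)

train_sketch <- iris[train_idx, ]
test_sketch  <- iris[-train_idx, ]
table(train_sketch$Species)   # 35 of each species
nrow(test_sketch)             # 45
```

These counts match the 35 observations per species in the training confusion matrix further down.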



Creating a random forest model

Let's now train a random forest model on the training data we defined.

## Set seed for reproducibility
set.seed(123)

## Train a random forest model
forest <- train(
        
        # Formula. We are using all variables to predict Species
        Species~., 
        
        # Training data; Species is the outcome named in the formula
        data=training_set, 
        
        # `rf` method for random forest
        method='rf', 
        
        # Add repeated cross validation as trControl
        trControl=repeat_cv,
        
        # Accuracy to measure the performance of the model
        metric='Accuracy')

## Print out the details about the model
forest$finalModel
## 
## Call:
##  randomForest(x = x, y = y, mtry = min(param$mtry, ncol(x))) 
##                Type of random forest: classification
##                      Number of trees: 500
## No. of variables tried at each split: 2
## 
##         OOB estimate of  error rate: 4.76%
## Confusion matrix:
##            setosa versicolor virginica class.error
## setosa         35          0         0  0.00000000
## versicolor      0         33         2  0.05714286
## virginica       0          3        32  0.08571429

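This random forest has 500 trees, and caret selected mtry = 2 (two candidate variables tried at each split) based on the cross-validated accuracy. The output shows the out-of-bag (OOB) confusion matrix for the training data: each observation is predicted only by the trees whose bootstrap sample did not include it. Rows are the actual species and columns the predicted species, so the diagonal holds the correct predictions. The model misclassified 2 versicolor as virginica and 3 virginica as versicolor. With 5 errors out of 105 training observations, the OOB accuracy is 100/105, about 95.24%, which matches the 4.76% OOB error rate.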
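The accuracy and class errors can be checked by hand. The matrix below is typed in from the printed OOB confusion matrix (rows = actual, columns = predicted); accuracy is the sum of the diagonal over the total, and each class error is 1 minus the diagonal entry over its row total.

```r
## OOB confusion matrix copied from the randomForest output above
cm <- matrix(c(35,  0,  0,
                0, 33,  2,
                0,  3, 32),
             nrow = 3, byrow = TRUE,
             dimnames = list(actual    = c("setosa", "versicolor", "virginica"),
                             predicted = c("setosa", "versicolor", "virginica")))

accuracy <- sum(diag(cm)) / sum(cm)
round(100 * accuracy, 2)          # 95.24

class_error <- 1 - diag(cm) / rowSums(cm)
round(class_error, 4)             # setosa 0, versicolor 0.0571, virginica 0.0857
```

The class errors reproduce the class.error column printed by randomForest.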



Check variable importance

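We can extract variable importance from the random forest model. Variable importance indicates how much each variable contributed to the model's predictions; higher importance means the variable was more useful for separating the classes. Because the model was trained without importance = TRUE, the score caret reports here is the mean decrease in Gini impurity across all trees. Let's plot it.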

## Get variable importance, and turn into a data frame
var_imp <- varImp(forest, scale=FALSE)$importance
var_imp <- data.frame(variables=row.names(var_imp), importance=var_imp$Overall)

## Create a plot of variable importance
var_imp %>%
        
        ## Sort the data by importance
        arrange(importance) %>%
        
        ## Map variables (ordered by importance) and importance to the axes
        ggplot(aes(x=reorder(variables, importance), y=importance)) + 
        
        ## Plot the bar graph; geom_col() uses the y values as bar heights
        geom_col() + 
        
        ## Flip the graph to make a horizontal bar plot
        coord_flip() + 
        
        ## Add x-axis label
        xlab('Variables') +
        
        ## Add a title
        labs(title='Random forest variable importance') + 
        
        ## Some layout for the plot
        theme_minimal() + 
        theme(axis.text = element_text(size = 10), 
              axis.title = element_text(size = 15), 
              plot.title = element_text(size = 20))

The above plot shows us that petal length and petal width are the most helpful variables in the model.
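For a randomForest classifier, importance is commonly measured as the mean decrease in Gini impurity. This base R sketch computes the impurity decrease for a single split on Petal.Length; gini() and split_decrease() are hypothetical helpers. randomForest accumulates decreases like this over every split that uses a variable and averages them over the trees, so its reported values are on a different scale from this single-split example.

```r
## Gini impurity of a set of class labels: 1 - sum(p_k^2)
gini <- function(y) {
  p <- table(y) / length(y)
  1 - sum(p^2)
}

## Impurity decrease from one split, weighting each child node by its size
split_decrease <- function(y, goes_left) {
  n <- length(y)
  gini(y) -
    (sum(goes_left) / n)  * gini(y[goes_left]) -
    (sum(!goes_left) / n) * gini(y[!goes_left])
}

gini(iris$Species)                                      # 2/3 for three balanced classes
split_decrease(iris$Species, iris$Petal.Length < 2.45)  # 1/3: setosa is split off cleanly
```

A single threshold on petal length already separates all setosa flowers from the rest, which is consistent with the petal measurements ranking highest.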



Performance on testing data

Let’s see how the model performs on testing data.

## Generate predictions
y_hats <- predict(
        
        ## Random forest object
        object=forest, 
        
        ## Data to use for predictions; remove the Species
        newdata=testing_set[, -5])

## Print the accuracy
accuracy <- mean(y_hats == testing_set$Species)*100
cat('Accuracy on testing data: ', round(accuracy, 2), '%',  sep='')
## Accuracy on testing data: 93.33%

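The model also did well on the testing data, with an accuracy of 93.33% (42 of the 45 test observations classified correctly), slightly below its OOB accuracy on the training data.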
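The 93.33% figure corresponds to 42 correct predictions out of 45 test rows, so each misclassification moves the accuracy by about 2.2 percentage points. An exact binomial confidence interval, computed here with base R's binom.test(), gives a sense of how much uncertainty a test set this small leaves. As far as we know, caret's confusionMatrix() reports a similar exact interval for accuracy.

```r
## Exact (Clopper-Pearson) 95% confidence interval for test accuracy,
## using the counts reported above: 42 of 45 test predictions correct
correct <- 42
n_test  <- 45
ci <- binom.test(correct, n_test)$conf.int

round(100 * correct / n_test, 2)   # 93.33
round(100 * ci, 2)                 # lower and upper bounds, in percent
```

A wide interval here is a reminder that the cross-validated estimate from training is usually a more stable guide than one small hold-out set.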