library(tidyverse)
library(caret)
library(e1071)
library(glmnet)
library(MLmetrics)
library(caretEnsemble)
library(kernlab)

data(GermanCredit)

In this series of tutorials, we will walk through the “caret” package for machine learning in R. We will start with the raw data, explore it, and then pre-process it. Later we will train and tune models, evaluate their performance, and look for ways to improve them, with a focus on situations of class imbalance. In this tutorial we will use machine learning for classification, with the “GermanCredit” dataset from the “caret” package.

Throughout this series of tutorials, we will cover:

There are a few sources from which this tutorial draws influence and structure. The first is the GitHub documentation on “caret” from its creator, Max Kuhn. The second is a very well-written and comprehensive tutorial by Selva Prabhakaran on Machine Learning Plus. The third is a helpful resource on dealing with class imbalance, which often arises in classification problems.

Here are some steps we will take to make this data more like what you may encounter in the real world:


General preparation of the dataset for machine learning

# Select variables
GermanCredit <- GermanCredit %>%
  dplyr::select(Class, Duration, Amount, Age, ResidenceDuration, NumberExistingCredits,
                NumberPeopleMaintenance, Telephone, ForeignWorker, Housing.Rent,
                Housing.Own, Housing.ForFree, Property.RealEstate,
                Property.Insurance, Property.CarOther, Property.Unknown) %>%
  dplyr::rename("EmploymentDuration" = "Duration")

# Simulate missing data for the variables Age and Employment Duration
n <- nrow(GermanCredit)
agePct <- 3
durationPct <- 7

# Generate rows that will hold missing data
set.seed(355)
ageMissingPctRows <- sample(1:n, round(agePct/100 * n, 0))

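# Note: re-using seed 355 means these rows begin with the same 30 rows drawn
# for Age, so every row missing Age is also missing EmploymentDuration
set.seed(355)
durationMissingPctRows <- sample(1:n, round(durationPct/100 * n, 0))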

# Make values NA's
GermanCredit[ageMissingPctRows, "Age"] <- NA
GermanCredit[durationMissingPctRows, "EmploymentDuration"] <- NA
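Because set.seed(355) is called before both draws, and sample() without replacement generates its indices one at a time, the first 30 indices of the second draw are exactly the 30 indices of the first. The base R sketch below reproduces this with n = 1000 (the number of rows in GermanCredit) and, as one possible alternative, seeds once and draws the two samples in sequence so that any overlap is left to chance. The object names are illustrative only.

```r
# Reproduce the two draws from the tutorial with n = 1000 rows
n <- 1000
set.seed(355)
ageRows <- sample(1:n, round(3 / 100 * n))        # 30 rows
set.seed(355)
durationRows <- sample(1:n, round(7 / 100 * n))   # 70 rows

# Every Age row is also a Duration row: they are its first 30 draws
all(ageRows %in% durationRows)
identical(ageRows, durationRows[1:30])

# Alternative: seed once, then draw the two samples one after the other
set.seed(355)
ageRowsIndep <- sample(1:n, round(3 / 100 * n))
durationRowsIndep <- sample(1:n, round(7 / 100 * n))
length(intersect(ageRowsIndep, durationRowsIndep))  # overlap now occurs only by chance
```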

# Code certain variables as factors
GermanCredit <- GermanCredit %>%
  mutate(across(.cols = c("ResidenceDuration", "NumberExistingCredits",
                          "NumberPeopleMaintenance", "Telephone",
                          "ForeignWorker"), .fns = factor))

Let’s take a look at our dataset now:

summary(GermanCredit)
##   Class     EmploymentDuration     Amount           Age       
##  Bad :300   Min.   : 4.00      Min.   :  250   Min.   :19.00  
##  Good:700   1st Qu.:12.00      1st Qu.: 1366   1st Qu.:27.00  
##             Median :18.00      Median : 2320   Median :33.00  
##             Mean   :20.79      Mean   : 3271   Mean   :35.35  
##             3rd Qu.:24.00      3rd Qu.: 3972   3rd Qu.:41.75  
##             Max.   :72.00      Max.   :18424   Max.   :75.00  
##             NA's   :70                         NA's   :30     
##  ResidenceDuration NumberExistingCredits NumberPeopleMaintenance Telephone
##  1:130             1:633                 1:845                   0:404    
##  2:308             2:333                 2:155                   1:596    
##  3:149             3: 28                                                  
##  4:413             4:  6                                                  
##                                                                           
##                                                                           
##                                                                           
##  ForeignWorker  Housing.Rent    Housing.Own    Housing.ForFree
##  0: 37         Min.   :0.000   Min.   :0.000   Min.   :0.000  
##  1:963         1st Qu.:0.000   1st Qu.:0.000   1st Qu.:0.000  
##                Median :0.000   Median :1.000   Median :0.000  
##                Mean   :0.179   Mean   :0.713   Mean   :0.108  
##                3rd Qu.:0.000   3rd Qu.:1.000   3rd Qu.:0.000  
##                Max.   :1.000   Max.   :1.000   Max.   :1.000  
##                                                               
##  Property.RealEstate Property.Insurance Property.CarOther Property.Unknown
##  Min.   :0.000       Min.   :0.000      Min.   :0.000     Min.   :0.000   
##  1st Qu.:0.000       1st Qu.:0.000      1st Qu.:0.000     1st Qu.:0.000   
##  Median :0.000       Median :0.000      Median :0.000     Median :0.000   
##  Mean   :0.282       Mean   :0.232      Mean   :0.332     Mean   :0.154   
##  3rd Qu.:1.000       3rd Qu.:0.000      3rd Qu.:1.000     3rd Qu.:0.000   
##  Max.   :1.000       Max.   :1.000      Max.   :1.000     Max.   :1.000   
## 

“Class” is our response variable, with a class balance of 70/30 (700 “Good” vs. 300 “Bad”). EmploymentDuration and Age now contain missing values, which we will address later. Apart from the continuous variables EmploymentDuration, Amount, and Age, the predictors are categorical, but they are coded in different ways. For example, “Telephone” and “ForeignWorker” are single 0/1 factors, whereas the variable “Housing” is split across three 0/1 indicator columns: “Housing.Rent”, “Housing.Own”, and “Housing.ForFree” (and “Property” likewise across four). We will address this during pre-processing.


Visualization of feature distribution by class

Caret gives us the very useful featurePlot() function, a shortcut for producing lattice graphs that lets us compare the distribution of each continuous predictor across the levels of the class variable. Let’s look at a couple of examples.

featurePlot(x = GermanCredit[,c("EmploymentDuration", "Age")],
            y = GermanCredit$Class,
            plot = "box")

featurePlot(x = GermanCredit[,c("EmploymentDuration", "Age")],
            y = GermanCredit$Class,
            plot = "density")

The “Property” indicator columns are also stored as numbers, so we can plot them in the same way.

featurePlot(x = GermanCredit[, c("Property.RealEstate", "Property.Insurance",
                                 "Property.CarOther", "Property.Unknown")],
            y = GermanCredit$Class,
            plot = "density")

Another very helpful function is nearZeroVar(), which identifies predictors that either have a single unique value (true “zero variance” predictors) or have both very few unique values relative to the number of samples and a very large ratio of the frequency of the most common value to that of the second most common value. Let’s run it with the default cutoffs (freqCut = 95/5, uniqueCut = 10) and then with a more aggressive frequency cutoff:

nearZeroVar(GermanCredit, freqCut = 95/5, uniqueCut = 10)
## [1] 9
nearZeroVar(GermanCredit, freqCut = 80/20, uniqueCut = 10)
## [1]  7  9 10 12 16

This function returns the column indices of the variables that are considered near zero variance under the given configuration. In the first run, the only flagged column is 9, “ForeignWorker”. Under the looser 80/20 cutoff, columns 7 (“NumberPeopleMaintenance”), 10, 12, and 16 are flagged as well; the last three are individual levels of the “Housing” and “Property” variables (“Housing.Rent”, “Housing.ForFree”, and “Property.Unknown”). Notice that “NumberExistingCredits” is not flagged, even though its levels 3 and 4 are very rare: the frequency ratio only compares the most common value (633) with the second most common (333), so the rare levels are never taken into account.
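To see why “ForeignWorker” and “NumberPeopleMaintenance” are flagged while “NumberExistingCredits” is not, the rule described above can be sketched in base R using the level counts from the summary() output. The helper nzvCheck() is hypothetical and only mirrors the documented rule; it is not caret's own code.

```r
# Hypothetical helper mirroring the near-zero-variance rule for one predictor
nzvCheck <- function(x, freqCut = 95 / 5, uniqueCut = 10) {
  counts <- sort(table(x), decreasing = TRUE)
  counts <- counts[counts > 0]                  # ignore unused factor levels
  nUnique <- length(counts)
  freqRatio <- if (nUnique > 1) counts[[1]] / counts[[2]] else NA
  percentUnique <- 100 * nUnique / length(x)
  nzv <- nUnique <= 1 || (freqRatio > freqCut && percentUnique <= uniqueCut)
  list(freqRatio = freqRatio, percentUnique = percentUnique, nzv = nzv)
}

# Level counts taken from the summary() output above
foreignWorker   <- factor(rep(c("1", "0"), times = c(963, 37)))
maintenance     <- factor(rep(c("1", "2"), times = c(845, 155)))
existingCredits <- factor(rep(c("1", "2", "3", "4"), times = c(633, 333, 28, 6)))

nzvCheck(foreignWorker)$nzv                         # 963 / 37 is about 26 > 19
nzvCheck(maintenance)$nzv                           # 845 / 155 is about 5.5 < 19
nzvCheck(maintenance, freqCut = 80 / 20)$nzv        # but > 4 under the looser cutoff
nzvCheck(existingCredits, freqCut = 80 / 20)$nzv    # 633 / 333 is about 1.9
```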

Either way, we will drop the variable “ForeignWorker” due to the relative lack of variation, and will merge levels 2, 3, and 4 of the variable NumberExistingCredits.

GermanCredit <- dplyr::select(GermanCredit, -ForeignWorker)
GermanCredit$NumberExistingCredits <- fct_collapse(GermanCredit$NumberExistingCredits,
                                                   "2+" = c("2", "3", "4"))


Pre-processing: imputation of missing data, one-hot encoding, and normalization

Let’s move on to the other pre-processing functions. The first thing we will do is split our data into two parts: a training set and a test set. Caret provides the createDataPartition() function for this, which samples within each class of the response so that the training and test sets keep roughly the same class proportions.

set.seed(355)
trainIndex <- createDataPartition(GermanCredit$Class, p = 0.7, list = FALSE)
trainingSet <- GermanCredit[trainIndex,]
testSet <- GermanCredit[-trainIndex,]
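The idea behind a stratified split can be sketched in a few lines of base R: sample the same proportion p of row indices within each class. The stratifiedIndex() helper below is hypothetical and is not how createDataPartition() is implemented; the class counts mirror GermanCredit's 300/700 split.

```r
# Hypothetical helper: sample a proportion p of the rows within each class
stratifiedIndex <- function(y, p = 0.7) {
  idx <- split(seq_along(y), y)                  # row numbers for each class
  keep <- lapply(idx, function(i) i[sample.int(length(i), round(p * length(i)))])
  sort(unlist(keep, use.names = FALSE))
}

set.seed(355)
y <- factor(rep(c("Bad", "Good"), times = c(300, 700)))
trainIdx <- stratifiedIndex(y, p = 0.7)
table(y[trainIdx])     # Bad 210, Good 490
table(y[-trainIdx])    # Bad 90, Good 210
```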

Let’s summarize the training set by itself.

summary(trainingSet)
##   Class     EmploymentDuration     Amount           Age       
##  Bad :210   Min.   : 4.00      Min.   :  250   Min.   :19.00  
##  Good:490   1st Qu.:12.00      1st Qu.: 1357   1st Qu.:27.00  
##             Median :18.00      Median : 2250   Median :33.00  
##             Mean   :20.76      Mean   : 3220   Mean   :35.52  
##             3rd Qu.:24.00      3rd Qu.: 3933   3rd Qu.:42.00  
##             Max.   :60.00      Max.   :18424   Max.   :75.00  
##             NA's   :45                         NA's   :24     
##  ResidenceDuration NumberExistingCredits NumberPeopleMaintenance Telephone
##  1: 91             1 :447                1:593                   0:274    
##  2:215             2+:253                2:107                   1:426    
##  3:102                                                                    
##  4:292                                                                    
##                                                                           
##                                                                           
##                                                                           
##   Housing.Rent     Housing.Own     Housing.ForFree  Property.RealEstate
##  Min.   :0.0000   Min.   :0.0000   Min.   :0.0000   Min.   :0.0000     
##  1st Qu.:0.0000   1st Qu.:0.0000   1st Qu.:0.0000   1st Qu.:0.0000     
##  Median :0.0000   Median :1.0000   Median :0.0000   Median :0.0000     
##  Mean   :0.1843   Mean   :0.7029   Mean   :0.1129   Mean   :0.2743     
##  3rd Qu.:0.0000   3rd Qu.:1.0000   3rd Qu.:0.0000   3rd Qu.:1.0000     
##  Max.   :1.0000   Max.   :1.0000   Max.   :1.0000   Max.   :1.0000     
##                                                                        
##  Property.Insurance Property.CarOther Property.Unknown
##  Min.   :0.0000     Min.   :0.0000    Min.   :0.00    
##  1st Qu.:0.0000     1st Qu.:0.0000    1st Qu.:0.00    
##  Median :0.0000     Median :0.0000    Median :0.00    
##  Mean   :0.2243     Mean   :0.3414    Mean   :0.16    
##  3rd Qu.:0.0000     3rd Qu.:1.0000    3rd Qu.:0.00    
##  Max.   :1.0000     Max.   :1.0000    Max.   :1.00    
## 

Next, we are going to pre-process the data. Your best friend for this process will be the preProcess() function. We will use this to impute missing data first.

The preProcess() function takes an argument, “method”, with many different processing options. For imputation, the options are “knnImpute”, “bagImpute”, and “medianImpute”. Let’s use “bagImpute”, which fits a bagged tree model for each predictor using all of the other predictors, on the training set. At the end, we will apply the same fitted transformation to the test data.

set.seed(355)
bagMissing <- preProcess(trainingSet, method = "bagImpute")
trainingSet <- predict(bagMissing, newdata = trainingSet)
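The key idea here is that the imputation model is fitted on the training set only and then reused, unchanged, on the test set. A minimal sketch of that pattern in base R follows; it swaps in the simpler median imputation rather than bagged trees, and the helper names fitMedianImpute() and applyMedianImpute() and the toy data are hypothetical.

```r
# Learn one median per numeric column from the training data only
fitMedianImpute <- function(df) {
  numericCols <- names(df)[vapply(df, is.numeric, logical(1))]
  vapply(df[numericCols], median, numeric(1), na.rm = TRUE)
}

# Fill NAs in any data frame using the medians learned above
applyMedianImpute <- function(medians, df) {
  for (col in names(medians)) {
    df[[col]][is.na(df[[col]])] <- medians[[col]]
  }
  df
}

train <- data.frame(Age = c(25, NA, 40, 31), Amount = c(1000, 2500, NA, 4000))
test  <- data.frame(Age = c(NA, 52), Amount = c(NA, 3000))

medians <- fitMedianImpute(train)        # Age = 31, Amount = 2500
out <- applyMedianImpute(medians, test)  # test NAs filled with training medians
out
```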

Next, we will use what is known as “one-hot encoding” to turn the factor predictors into 0/1 dummy columns, one per level. “Housing” and “Property” are already in exactly this format; what we want is to transform the other factors into the same format. Calling predict() on the dummyVars() object returns a matrix of the predictors only; the response variable is omitted.

dummyModel <- dummyVars(Class ~ ., data = trainingSet)
trainingSetX <- as.data.frame(predict(dummyModel, newdata = trainingSet))
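For comparison, base R's model.matrix() can produce the same kind of full one-hot encoding when every factor is given a contrasts = FALSE coding, which keeps one column per level. This is a sketch on toy data; the column names it produces may differ slightly from those of dummyVars().

```r
# Toy data with two factor predictors
df <- data.frame(
  ResidenceDuration = factor(c("1", "4", "2", "4")),
  Telephone = factor(c("0", "1", "1", "0"))
)

# contrasts = FALSE keeps a column for every level of every factor
oneHot <- model.matrix(
  ~ ., data = df,
  contrasts.arg = lapply(df, contrasts, contrasts = FALSE)
)[, -1]   # drop the intercept column

oneHot    # 3 ResidenceDuration columns + 2 Telephone columns
```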

Next, we will rescale these variables to lie between 0 and 1. When all of the predictors are continuous, one of my preferred approaches is to standardize them into Z-scores, which can be done by passing c("center", "scale") to the “method” argument. Specifying method = "range", however, rescales each variable to the 0-1 interval.

rangeModel <- preProcess(trainingSetX, method = "range")
trainingSetX <- predict(rangeModel, newdata = trainingSetX)
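The 0-1 range transformation is x' = (x - min) / (max - min), where min and max are learned from the training data and then reused for any new data. A base R sketch with hypothetical helpers fitRange() and applyRange() is shown below; note that test values outside the training range end up outside [0, 1].

```r
# Learn min and max per column from the training data
fitRange <- function(df) lapply(df, function(x) c(min = min(x), max = max(x)))

# Apply x' = (x - min) / (max - min); a constant column would divide by zero
applyRange <- function(params, df) {
  for (col in names(params)) {
    p <- params[[col]]
    df[[col]] <- (df[[col]] - p[["min"]]) / (p[["max"]] - p[["min"]])
  }
  df
}

train <- data.frame(Amount = c(250, 1000, 18424), Age = c(19, 35, 75))
test  <- data.frame(Amount = c(500, 20000), Age = c(50, 80))

params <- fitRange(train)
scaledTrain <- applyRange(params, train)  # every column spans exactly [0, 1]
scaledTest  <- applyRange(params, test)   # Amount 20000 and Age 80 land above 1
```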

Now we will build the final training set by combining these transformed predictors with the original response variable.

trainingSet <- cbind(trainingSet$Class, trainingSetX)
names(trainingSet)[1] <- "Class"

But remember, all we did was transform the training set. We need to transform the test set as well. We’ll use the same three procedures: the imputation of missing values using the “bagMissing” model object, the one-hot encoding using the “dummyModel” object, and the normalization using the “rangeModel” object.

testSet_imputed <- predict(bagMissing, testSet)
testSet_dummy <- predict(dummyModel, testSet_imputed)
testSet_range <- predict(rangeModel, testSet_dummy)
testSet_range <- data.frame(testSet_range)
testSet <- cbind(testSet$Class, testSet_range)
names(testSet) <- names(trainingSet)   # proactive step that can prevent errors later


Removing low information features

The next thing we need to consider is low-information features. Including uninformative features adds noise and can degrade model performance. Personally, I like to let domain knowledge guide this step. Another option, however, is Recursive Feature Elimination (RFE), implemented by the rfe() function, with its settings defined by the rfeControl() function.

Recursive Feature Elimination works by fitting a model to the training data, ranking the predictors by importance, and then refitting the model on progressively smaller subsets of the top-ranked predictors. Resampling (here, 10-fold cross-validation) is used to estimate performance at each subset size, and the size with the best resampled performance is selected. We will see an example of a variable importance plot later; many algorithms provide methods for ranking features from most to least important. Here, the model used inside RFE is a Random Forest (via rfFuncs), and the subset sizes (i.e. the number of top-ranked features to keep) are provided explicitly.

subsets <- c(1:5, 10, 15, 20)

set.seed(355)

rfeCtrl <- rfeControl(functions = rfFuncs,
                      method = "cv",
                      verbose = FALSE)

rfProfile <- rfe(x = trainingSet[,2:21], 
                y = trainingSet$Class, 
                sizes = subsets,
                rfeControl = rfeCtrl)

rfProfile
## 
## Recursive feature selection
## 
## Outer resampling method: Cross-Validated (10 fold) 
## 
## Resampling performance over subset size:
## 
##  Variables Accuracy  Kappa AccuracySD KappaSD Selected
##          1   0.7071 0.1022    0.02259 0.07996         
##          2   0.6843 0.1346    0.03836 0.10482         
##          3   0.6900 0.1102    0.03930 0.11867         
##          4   0.6829 0.1194    0.04249 0.13929         
##          5   0.7029 0.1527    0.03614 0.11216         
##         10   0.7129 0.1791    0.04688 0.15503         
##         15   0.7314 0.2344    0.04893 0.14366        *
##         20   0.7214 0.2093    0.05724 0.17351         
## 
## The top 5 variables (out of 15):
##    EmploymentDuration, Amount, Age, Property.RealEstate, Housing.Own

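The best resampled accuracy came from the 15-variable subset (0.731, vs. 0.721 with all 20), but that difference is small relative to the standard deviations across folds (around 0.05), so we will not eliminate any features in this example. Still, RFE is a technique well worth being aware of!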
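To make the procedure concrete, here is a simplified RFE sketch in base R on simulated data. Several assumptions differ from the rfe() call above: logistic regression (glm()) stands in for the Random Forest, the absolute z statistic stands in for variable importance, and 5-fold cross-validation is used. Predictors are ranked within each fold, so the held-out rows never influence the ranking, and nested top-k subsets are then scored on the held-out rows. All names and data here are hypothetical.

```r
# Simulated data: x1 and x2 carry signal, x3 and x4 are pure noise
set.seed(1)
n <- 500
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n), x4 = rnorm(n))
y <- factor(ifelse(2 * X$x1 + X$x2 + rnorm(n) > 0, "Good", "Bad"))
dat <- data.frame(Class = y, X)

# Rank predictors by |z| from a logistic regression on all of them
rankPredictors <- function(d) {
  fit <- glm(Class ~ ., data = d, family = binomial)
  z <- abs(summary(fit)$coefficients[-1, "z value"])   # drop the intercept
  names(sort(z, decreasing = TRUE))
}

# Fit on a subset of predictors and score accuracy on held-out rows
holdoutAccuracy <- function(trainD, testD, vars) {
  fit <- glm(reformulate(vars, response = "Class"), data = trainD, family = binomial)
  prob <- predict(fit, newdata = testD, type = "response")   # P(Class == "Good")
  pred <- ifelse(prob > 0.5, "Good", "Bad")
  mean(pred == testD$Class)
}

k <- 5
sizes <- 1:4
folds <- sample(rep(1:k, length.out = n))
acc <- matrix(NA_real_, nrow = k, ncol = length(sizes),
              dimnames = list(NULL, paste0("size", sizes)))
for (f in 1:k) {
  trainD <- dat[folds != f, ]
  testD  <- dat[folds == f, ]
  ranking <- rankPredictors(trainD)          # rank inside the fold only
  for (s in sizes) {
    acc[f, s] <- holdoutAccuracy(trainD, testD, ranking[1:s])
  }
}

colMeans(acc)          # resampled accuracy for each subset size
rankPredictors(dat)    # final ranking using all rows
```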


Visualization of feature importance

Next, we will use the train() function to fit actual models and then examine their performance. train() is an incredibly powerful function: together with a control object (created by trainControl()), it handles hyperparameter tuning, resampling of the model (such as cross-validation), and selection of the optimal model.

To get an idea of the full scope of models that can be trained in “caret”, see the following list:

names(getModelInfo())
##   [1] "ada"                 "AdaBag"              "AdaBoost.M1"        
##   [4] "adaboost"            "amdai"               "ANFIS"              
##   [7] "avNNet"              "awnb"                "awtan"              
##  [10] "bag"                 "bagEarth"            "bagEarthGCV"        
##  [13] "bagFDA"              "bagFDAGCV"           "bam"                
##  [16] "bartMachine"         "bayesglm"            "binda"              
##  [19] "blackboost"          "blasso"              "blassoAveraged"     
##  [22] "bridge"              "brnn"                "BstLm"              
##  [25] "bstSm"               "bstTree"             "C5.0"               
##  [28] "C5.0Cost"            "C5.0Rules"           "C5.0Tree"           
##  [31] "cforest"             "chaid"               "CSimca"             
##  [34] "ctree"               "ctree2"              "cubist"             
##  [37] "dda"                 "deepboost"           "DENFIS"             
##  [40] "dnn"                 "dwdLinear"           "dwdPoly"            
##  [43] "dwdRadial"           "earth"               "elm"                
##  [46] "enet"                "evtree"              "extraTrees"         
##  [49] "fda"                 "FH.GBML"             "FIR.DM"             
##  [52] "foba"                "FRBCS.CHI"           "FRBCS.W"            
##  [55] "FS.HGD"              "gam"                 "gamboost"           
##  [58] "gamLoess"            "gamSpline"           "gaussprLinear"      
##  [61] "gaussprPoly"         "gaussprRadial"       "gbm_h2o"            
##  [64] "gbm"                 "gcvEarth"            "GFS.FR.MOGUL"       
##  [67] "GFS.LT.RS"           "GFS.THRIFT"          "glm.nb"             
##  [70] "glm"                 "glmboost"            "glmnet_h2o"         
##  [73] "glmnet"              "glmStepAIC"          "gpls"               
##  [76] "hda"                 "hdda"                "hdrda"              
##  [79] "HYFIS"               "icr"                 "J48"                
##  [82] "JRip"                "kernelpls"           "kknn"               
##  [85] "knn"                 "krlsPoly"            "krlsRadial"         
##  [88] "lars"                "lars2"               "lasso"              
##  [91] "lda"                 "lda2"                "leapBackward"       
##  [94] "leapForward"         "leapSeq"             "Linda"              
##  [97] "lm"                  "lmStepAIC"           "LMT"                
## [100] "loclda"              "logicBag"            "LogitBoost"         
## [103] "logreg"              "lssvmLinear"         "lssvmPoly"          
## [106] "lssvmRadial"         "lvq"                 "M5"                 
## [109] "M5Rules"             "manb"                "mda"                
## [112] "Mlda"                "mlp"                 "mlpKerasDecay"      
## [115] "mlpKerasDecayCost"   "mlpKerasDropout"     "mlpKerasDropoutCost"
## [118] "mlpML"               "mlpSGD"              "mlpWeightDecay"     
## [121] "mlpWeightDecayML"    "monmlp"              "msaenet"            
## [124] "multinom"            "mxnet"               "mxnetAdam"          
## [127] "naive_bayes"         "nb"                  "nbDiscrete"         
## [130] "nbSearch"            "neuralnet"           "nnet"               
## [133] "nnls"                "nodeHarvest"         "null"               
## [136] "OneR"                "ordinalNet"          "ordinalRF"          
## [139] "ORFlog"              "ORFpls"              "ORFridge"           
## [142] "ORFsvm"              "ownn"                "pam"                
## [145] "parRF"               "PART"                "partDSA"            
## [148] "pcaNNet"             "pcr"                 "pda"                
## [151] "pda2"                "penalized"           "PenalizedLDA"       
## [154] "plr"                 "pls"                 "plsRglm"            
## [157] "polr"                "ppr"                 "pre"                
## [160] "PRIM"                "protoclass"          "qda"                
## [163] "QdaCov"              "qrf"                 "qrnn"               
## [166] "randomGLM"           "ranger"              "rbf"                
## [169] "rbfDDA"              "Rborist"             "rda"                
## [172] "regLogistic"         "relaxo"              "rf"                 
## [175] "rFerns"              "RFlda"               "rfRules"            
## [178] "ridge"               "rlda"                "rlm"                
## [181] "rmda"                "rocc"                "rotationForest"     
## [184] "rotationForestCp"    "rpart"               "rpart1SE"           
## [187] "rpart2"              "rpartCost"           "rpartScore"         
## [190] "rqlasso"             "rqnc"                "RRF"                
## [193] "RRFglobal"           "rrlda"               "RSimca"             
## [196] "rvmLinear"           "rvmPoly"             "rvmRadial"          
## [199] "SBC"                 "sda"                 "sdwd"               
## [202] "simpls"              "SLAVE"               "slda"               
## [205] "smda"                "snn"                 "sparseLDA"          
## [208] "spikeslab"           "spls"                "stepLDA"            
## [211] "stepQDA"             "superpc"             "svmBoundrangeString"
## [214] "svmExpoString"       "svmLinear"           "svmLinear2"         
## [217] "svmLinear3"          "svmLinearWeights"    "svmLinearWeights2"  
## [220] "svmPoly"             "svmRadial"           "svmRadialCost"      
## [223] "svmRadialSigma"      "svmRadialWeights"    "svmSpectrumString"  
## [226] "tan"                 "tanSearch"           "treebag"            
## [229] "vbmpRadial"          "vglmAdjCat"          "vglmContRatio"      
## [232] "vglmCumulative"      "widekernelpls"       "WM"                 
## [235] "wsrf"                "xgbDART"             "xgbLinear"          
## [238] "xgbTree"             "xyf"

As you can see, there are MANY options. Note, however, that most of these methods depend on additional packages, which must be installed alongside “caret” before they can be used. We will start by training a Random Forest classifier on our training set and looking at what we can do with it. Note that we are not passing a control object to train() yet, so it falls back on its defaults (25 bootstrap resamples, as the output below shows); we will come back and set the controls ourselves later.

set.seed(355)
rf <- train(Class ~., data = trainingSet, method = "rf")
rf
## Random Forest 
## 
## 700 samples
##  20 predictor
##   2 classes: 'Bad', 'Good' 
## 
## No pre-processing
## Resampling: Bootstrapped (25 reps) 
## Summary of sample sizes: 700, 700, 700, 700, 700, 700, ... 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##    2    0.7111715  0.1123744
##   11    0.6985914  0.1905956
##   20    0.6941395  0.1895292
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 2.

Random Forest is one of several classifiers that provide a measure of variable importance. Others include linear models (where the absolute value of the t-statistic for each model parameter is used to rank variables by importance), partial least squares, recursive partitioning, bagged or boosted trees, and multivariate adaptive regression splines. As a next step, we will look at a plot of the variable importance for the Random Forest we just trained.

varimp_RF <- varImp(rf)
plot(varimp_RF, main = "German Credit Variable Importance (Random Forest)")

Some caution should be taken in interpreting Random Forest variable importance, because it tends to rank continuous variables higher than categorical ones (continuous variables offer many more candidate split points). See my YouTube video “When Should You Use Random Forests?” for a more detailed example of this. That said, Amount stands out as a disproportionately important feature.


Definitions of metrics of performance for classification problems: Sensitivity, Specificity, etc.

Let’s now use the random forest to predict on our test data. By default, predict() returns class predictions (i.e. a factor of “Good”s and “Bad”s); later we will ask for class probabilities instead with type = "prob".

fitted <- predict(rf, testSet)
fitted[1:10]
##  [1] Good Good Good Good Good Good Good Good Good Good
## Levels: Bad Good

Now we will create a confusion matrix:

confusionMatrix(reference = testSet$Class, data = fitted, mode = "everything", positive = "Good")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Bad Good
##       Bad    9    8
##       Good  81  202
##                                           
##                Accuracy : 0.7033          
##                  95% CI : (0.6481, 0.7545)
##     No Information Rate : 0.7             
##     P-Value [Acc > NIR] : 0.4782          
##                                           
##                   Kappa : 0.0806          
##                                           
##  Mcnemar's Test P-Value : 2.312e-14       
##                                           
##             Sensitivity : 0.9619          
##             Specificity : 0.1000          
##          Pos Pred Value : 0.7138          
##          Neg Pred Value : 0.5294          
##               Precision : 0.7138          
##                  Recall : 0.9619          
##                      F1 : 0.8195          
##              Prevalence : 0.7000          
##          Detection Rate : 0.6733          
##    Detection Prevalence : 0.9433          
##       Balanced Accuracy : 0.5310          
##                                           
##        'Positive' Class : Good            
## 

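It is worth taking a moment to interpret some of this output. With “Good” as the positive class, sensitivity (recall) is the share of truly good credits that the model labels Good: 202 of 210. The positive predictive value (precision) is the share of Good predictions that are correct: 202 of 283, barely above the 70% base rate. Specificity, the share of truly bad credits labeled Bad, is only 9 of 90. Finally, the overall accuracy of 0.703 is essentially the No Information Rate of 0.70 (p = 0.48), which is what we would get by predicting “Good” for every applicant.

In short, the model has high sensitivity but abysmal specificity. Such is life for practitioners of machine learning.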
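The metrics reported by confusionMatrix() can be recomputed directly from the four cell counts of the table above, with “Good” as the positive class. The base R sketch below does this; the rounded values match the output above.

```r
# Cell counts from the confusion matrix above ("Good" is the positive class)
TP <- 202  # predicted Good, actually Good
FN <- 8    # predicted Bad,  actually Good
FP <- 81   # predicted Good, actually Bad
TN <- 9    # predicted Bad,  actually Bad

sensitivity <- TP / (TP + FN)                           # recall
specificity <- TN / (TN + FP)
ppv         <- TP / (TP + FP)                           # precision
npv         <- TN / (TN + FN)
f1          <- 2 * ppv * sensitivity / (ppv + sensitivity)
balancedAcc <- (sensitivity + specificity) / 2
accuracy    <- (TP + TN) / (TP + FN + FP + TN)

round(c(sensitivity = sensitivity, specificity = specificity, ppv = ppv,
        npv = npv, f1 = f1, balancedAccuracy = balancedAcc,
        accuracy = accuracy), 4)
```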

Next, rather than relying on the default settings of the train() function, we need to actually tune the hyperparameters ourselves!

The train() function accepts a control object, created with trainControl(), that specifies the resampling scheme and how performance is summarized. The hyperparameter search itself can be specified in two ways: through a tuning length or through a tuning grid. The “tuneLength” argument sets how many candidate values of each tuning parameter train() generates automatically, while “tuneGrid” lets the user specify exactly which hyperparameter values to try.

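twoClassCtrl <- trainControl(
  method = "repeatedcv",               # repeated k-fold cross-validation
  number = 5,                          # 5 folds...
  repeats = 5,                         # ...repeated 5 times
  savePredictions = "final",           # keep held-out predictions for the best model
  classProbs = TRUE,                   # class probabilities, needed for ROC AUC
  summaryFunction = twoClassSummary    # reports ROC AUC, Sens and Spec
)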
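The control above uses method = "repeatedcv" with number = 5 and repeats = 5. As a rough base R sketch, this means the rows are shuffled into 5 folds, 5 separate times, giving 25 held-out sets in which every row is held out exactly once per repeat. The repeatedFolds() helper is hypothetical; unlike this sketch, caret's own fold helpers stratify the folds by the outcome for classification.

```r
# Hypothetical helper: held-out row indices for repeated k-fold CV
repeatedFolds <- function(n, number = 5, repeats = 5) {
  out <- list()
  for (r in seq_len(repeats)) {
    foldId <- sample(rep(seq_len(number), length.out = n))   # fresh shuffle per repeat
    for (f in seq_len(number)) {
      out[[sprintf("Fold%d.Rep%d", f, r)]] <- which(foldId == f)
    }
  }
  out
}

set.seed(355)
heldOut <- repeatedFolds(700)
length(heldOut)                   # 25 resamples
lengths(heldOut)[1:5]             # 140 held-out rows in each fold
table(table(unlist(heldOut)))     # every row is held out exactly 5 times
```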

The “summaryFunction” argument to trainControl() is an important one. Specifying “twoClassSummary” reports the area under the ROC curve along with Sensitivity and Specificity, so models are compared on the tradeoff between those two. Alternatively, specifying “prSummary” (which relies on the “MLmetrics” package) reports the area under the precision-recall curve along with Precision and Recall (that is, Positive Predictive Value and Sensitivity). Your knowledge of the domain should be the guiding force behind this choice. However, as a general rule, for very imbalanced datasets the Precision/Recall tradeoff is often preferable to the Sensitivity/Specificity tradeoff. This particular case is borderline: with a 70/30 split the classes are imbalanced, but you will see much worse in the real world. We will stick with twoClassSummary for now.
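With twoClassSummary, models are ranked (via metric = "ROC" in the train() calls below) by the area under the ROC curve, which equals the probability that a randomly chosen positive case receives a higher predicted probability than a randomly chosen negative case. The base R sketch below computes it from ranks (the Mann-Whitney form) on toy probabilities, alongside sensitivity, specificity, and precision at a 0.5 cutoff. The rocAuc() helper and the data are hypothetical; caret itself relies on other packages for these calculations.

```r
# Toy predicted probabilities of the positive class ("Good")
prob  <- c(0.9, 0.8, 0.4, 0.7, 0.3)
truth <- factor(c("Good", "Good", "Good", "Bad", "Bad"), levels = c("Bad", "Good"))
isPos <- truth == "Good"

# ROC AUC from ranks; average ranks handle ties
rocAuc <- function(prob, isPos) {
  r <- rank(prob)
  nPos <- sum(isPos)
  nNeg <- sum(!isPos)
  (sum(r[isPos]) - nPos * (nPos + 1) / 2) / (nPos * nNeg)
}
auc <- rocAuc(prob, isPos)    # 5/6: share of (Good, Bad) pairs ranked correctly

# Threshold-based metrics at a 0.5 cutoff
pred <- ifelse(prob > 0.5, "Good", "Bad")
m <- c(sensitivity = mean(pred[isPos] == "Good"),   # = recall, 2/3
       specificity = mean(pred[!isPos] == "Bad"),   # 1/2
       precision   = mean(isPos[pred == "Good"]))   # 2/3
m
```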

Now that we’ve created a training control, we’ll also use the “tuneLength” argument to the train() function, and set metric = "ROC" so that candidate models are compared by their resampled ROC AUC.

set.seed(355)
rfTL <- train(Class ~., data = trainingSet, method = "rf", metric = "ROC", trControl = twoClassCtrl, tuneLength = 10)

fittedTL <- predict(rfTL, testSet)
confusionMatrix(reference = testSet$Class, data = fittedTL, mode = "everything", positive = "Good")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Bad Good
##       Bad   22   23
##       Good  68  187
##                                           
##                Accuracy : 0.6967          
##                  95% CI : (0.6412, 0.7482)
##     No Information Rate : 0.7             
##     P-Value [Acc > NIR] : 0.5781          
##                                           
##                   Kappa : 0.1574          
##                                           
##  Mcnemar's Test P-Value : 3.979e-06       
##                                           
##             Sensitivity : 0.8905          
##             Specificity : 0.2444          
##          Pos Pred Value : 0.7333          
##          Neg Pred Value : 0.4889          
##               Precision : 0.7333          
##                  Recall : 0.8905          
##                      F1 : 0.8043          
##              Prevalence : 0.7000          
##          Detection Rate : 0.6233          
##    Detection Prevalence : 0.8500          
##       Balanced Accuracy : 0.5675          
##                                           
##        'Positive' Class : Good            
## 

This is a little more balanced (specificity rose from 0.100 to 0.244, while sensitivity fell from 0.962 to 0.890), but it is not exactly a huge improvement. Next, let’s make changes using tuneGrid. Before doing this, it may be helpful to look at which parameters we are actually tuning.

modelLookup("rf")
##   model parameter                         label forReg forClass probModel
## 1    rf      mtry #Randomly Selected Predictors   TRUE     TRUE      TRUE

There is only one parameter to tune here: “mtry”, the number of predictors randomly sampled as split candidates at each split of a tree. Let’s take this for a spin…

rfGrid <- data.frame(mtry = c(3, 5, 7, 9, 10, 11, 12, 13, 15, 17, 19))  # tuneGrid requires a data frame input

set.seed(355)
rfTG <- train(Class ~., data = trainingSet, method = "rf", metric = "ROC", trControl = twoClassCtrl, tuneGrid = rfGrid)

fittedTG <- predict(rfTG, testSet)
confusionMatrix(reference = testSet$Class, data = fittedTG, mode = "everything", positive = "Good")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Bad Good
##       Bad   24   24
##       Good  66  186
##                                           
##                Accuracy : 0.7             
##                  95% CI : (0.6447, 0.7513)
##     No Information Rate : 0.7             
##     P-Value [Acc > NIR] : 0.5284          
##                                           
##                   Kappa : 0.1758          
##                                           
##  Mcnemar's Test P-Value : 1.548e-05       
##                                           
##             Sensitivity : 0.8857          
##             Specificity : 0.2667          
##          Pos Pred Value : 0.7381          
##          Neg Pred Value : 0.5000          
##               Precision : 0.7381          
##                  Recall : 0.8857          
##                      F1 : 0.8052          
##              Prevalence : 0.7000          
##          Detection Rate : 0.6200          
##    Detection Prevalence : 0.8400          
##       Balanced Accuracy : 0.5762          
##                                           
##        'Positive' Class : Good            
## 


Using non-standard sampling methods to correct for class imbalance

Another consideration is the sampling method. The two primary techniques for correcting class imbalance are “down-sampling” (randomly discarding rows from the majority class until the classes are balanced) and “up-sampling” (randomly duplicating rows from the minority class), although “caret” also offers a couple of hybrid approaches, “SMOTE” and “ROSE”, which require installation of the “DMwR” and “ROSE” packages, respectively.
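Both basic techniques are simple to sketch in base R. The hypothetical helpers downSampleIdx() and upSampleIdx() below return row indices that balance the classes; caret provides its own downSample() and upSample() functions for this purpose. The class counts mirror the 210/490 split of our training set.

```r
# Down-sampling: keep a random subset of each class, the size of the smallest class
downSampleIdx <- function(y) {
  idx <- split(seq_along(y), y)                # row numbers for each class
  nMin <- min(lengths(idx))
  unlist(lapply(idx, function(i) i[sample.int(length(i), nMin)]),
         use.names = FALSE)
}

# Up-sampling: keep every row and add duplicates until each class matches the largest
upSampleIdx <- function(y) {
  idx <- split(seq_along(y), y)
  nMax <- max(lengths(idx))
  unlist(lapply(idx, function(i) {
    extra <- i[sample.int(length(i), nMax - length(i), replace = TRUE)]
    c(i, extra)
  }), use.names = FALSE)
}

set.seed(355)
y <- factor(rep(c("Bad", "Good"), times = c(210, 490)))
downTab <- table(y[downSampleIdx(y)])   # Bad 210, Good 210
upTab   <- table(y[upSampleIdx(y)])     # Bad 490, Good 490
```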

The sampling method can be specified in the control, in which case it is applied only to the training portion of each resample; let’s try that next.

downCtrl <- trainControl(
  method = "boot",
  number = 5,
  savePredictions = "final",
  classProbs = TRUE,
  summaryFunction = twoClassSummary,
  sampling = "down"
)

upCtrl <- trainControl(
  method = "boot",
  number = 5,
  savePredictions = "final",
  classProbs = TRUE,
  summaryFunction = twoClassSummary,
  sampling = "up"
)

Now let’s look at model performance again, starting with the down-sampled training case:

set.seed(355)
rfDown <- train(Class ~., data = trainingSet, method = "rf", metric = "ROC", trControl = downCtrl, tuneLength = 10)

fittedDown <- predict(rfDown, testSet)
confusionMatrix(reference = testSet$Class, data = fittedDown, mode = "everything", positive = "Good")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Bad Good
##       Bad   50   91
##       Good  40  119
##                                           
##                Accuracy : 0.5633          
##                  95% CI : (0.5052, 0.6203)
##     No Information Rate : 0.7             
##     P-Value [Acc > NIR] : 1               
##                                           
##                   Kappa : 0.1052          
##                                           
##  Mcnemar's Test P-Value : 1.251e-05       
##                                           
##             Sensitivity : 0.5667          
##             Specificity : 0.5556          
##          Pos Pred Value : 0.7484          
##          Neg Pred Value : 0.3546          
##               Precision : 0.7484          
##                  Recall : 0.5667          
##                      F1 : 0.6450          
##              Prevalence : 0.7000          
##          Detection Rate : 0.3967          
##    Detection Prevalence : 0.5300          
##       Balanced Accuracy : 0.5611          
##                                           
##        'Positive' Class : Good            
## 
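
As a check on how these headline statistics are defined, the sketch below recomputes a few of them by hand from the four cell counts of the down-sampled confusion matrix, with “Good” as the positive class. It is plain base R and the variable names are illustrative.

```r
# Cell counts copied from the down-sampled confusion matrix above
# (rows = predicted class, columns = reference class)
cm <- matrix(c(50, 40, 91, 119), nrow = 2,
             dimnames = list(Prediction = c("Bad", "Good"),
                             Reference  = c("Bad", "Good")))

positive <- "Good"
negative <- "Bad"
tp <- cm[positive, positive]  # Good creditors predicted Good
fn <- cm[negative, positive]  # Good creditors predicted Bad
fp <- cm[positive, negative]  # Bad creditors predicted Good
tn <- cm[negative, negative]  # Bad creditors predicted Bad

sensitivity <- tp / (tp + fn)   # also reported as Recall
specificity <- tn / (tn + fp)
precision   <- tp / (tp + fp)   # also reported as Pos Pred Value
f1          <- 2 * precision * sensitivity / (precision + sensitivity)
balanced    <- (sensitivity + specificity) / 2

round(c(Sensitivity = sensitivity, Specificity = specificity,
        Precision = precision, F1 = f1, BalancedAccuracy = balanced), 4)
# Sensitivity 0.5667, Specificity 0.5556, Precision 0.7484, F1 0.6450, BalancedAccuracy 0.5611
```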

Then let’s look at the model performance for the up-sampled training case:

set.seed(355)
rfUp <- train(Class ~., data = trainingSet, method = "rf", metric = "ROC", trControl = upCtrl, tuneLength = 10)

fittedUp <- predict(rfUp, testSet)
confusionMatrix(reference = testSet$Class, data = fittedUp, mode = "everything", positive = "Good")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Bad Good
##       Bad   33   46
##       Good  57  164
##                                           
##                Accuracy : 0.6567          
##                  95% CI : (0.5999, 0.7103)
##     No Information Rate : 0.7             
##     P-Value [Acc > NIR] : 0.9541          
##                                           
##                   Kappa : 0.153           
##                                           
##  Mcnemar's Test P-Value : 0.3245          
##                                           
##             Sensitivity : 0.7810          
##             Specificity : 0.3667          
##          Pos Pred Value : 0.7421          
##          Neg Pred Value : 0.4177          
##               Precision : 0.7421          
##                  Recall : 0.7810          
##                      F1 : 0.7610          
##              Prevalence : 0.7000          
##          Detection Rate : 0.5467          
##    Detection Prevalence : 0.7367          
##       Balanced Accuracy : 0.5738          
##                                           
##        'Positive' Class : Good            
## 
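
Kappa compares the observed accuracy with the accuracy expected by chance, given how often each class is predicted and how often it actually occurs. A minimal base-R sketch recomputing it from the up-sampled confusion matrix counts (variable names are illustrative):

```r
# Cell counts copied from the up-sampled confusion matrix above
# (rows = predicted class, columns = reference class)
cm <- matrix(c(33, 57, 46, 164), nrow = 2,
             dimnames = list(Prediction = c("Bad", "Good"),
                             Reference  = c("Bad", "Good")))

n <- sum(cm)
p_observed  <- sum(diag(cm)) / n                     # plain accuracy
p_expected  <- sum(rowSums(cm) * colSums(cm)) / n^2  # agreement expected by chance
cohen_kappa <- (p_observed - p_expected) / (1 - p_expected)

round(c(Accuracy = p_observed, Chance = p_expected, Kappa = cohen_kappa), 4)
# Accuracy 0.6567, Chance 0.5947, Kappa 0.1530
```

Note that this chance baseline (about 0.595) is not the No Information Rate (0.7), which is the accuracy of always predicting the majority class.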


Altering boundaries for classifier thresholds

An entirely different approach to the class imbalance problem is to have predict() return class probabilities (type = "prob") instead of class labels, and then choose our own cutoff for assigning a class rather than simply predicting whichever class is more probable. For example, if the algorithm assigns a probability of 33% or greater that an observation falls into the negative class (Bad), we can classify it as Bad.

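# Predicted class probabilities (one column per class) from the rfTL model
fittedProb <- predict(rfTL, testSet, type = "prob")
fittedProb <- fittedProb$Bad
# Call an applicant "Bad" whenever P(Bad) >= 0.333; keep the same levels as the reference
fittedProb <- factor(ifelse(fittedProb >= 0.333, "Bad", "Good"),
                     levels = levels(testSet$Class))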
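
Wrapping the cutoff logic in a small function makes it easy to try other values. This is a minimal sketch: classify_with_cutoff() is a hypothetical helper, not part of “caret”, and the probabilities are made up for illustration. It pins the factor levels so that both classes survive even when a cutoff assigns every observation to one class, which keeps the result compatible with confusionMatrix().

```r
# Hypothetical helper: label an observation as classes[1] ("Bad") when its
# predicted probability of that class is at least `cutoff`
classify_with_cutoff <- function(prob_bad, cutoff = 0.5, classes = c("Bad", "Good")) {
  factor(ifelse(prob_bad >= cutoff, classes[1], classes[2]), levels = classes)
}

# Hypothetical predicted probabilities of "Bad" for five applicants
p_bad <- c(0.10, 0.30, 0.35, 0.45, 0.60)

classify_with_cutoff(p_bad, cutoff = 0.5)    # only the last applicant is "Bad"
classify_with_cutoff(p_bad, cutoff = 0.333)  # the last three applicants are "Bad"
```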

Now we will create a confusion matrix:

confusionMatrix(reference = testSet$Class, data = fittedProb, mode = "everything", positive = "Good")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Bad Good
##       Bad   37   62
##       Good  53  148
##                                         
##                Accuracy : 0.6167        
##                  95% CI : (0.559, 0.672)
##     No Information Rate : 0.7           
##     P-Value [Acc > NIR] : 0.9992        
##                                         
##                   Kappa : 0.1127        
##                                         
##  Mcnemar's Test P-Value : 0.4557        
##                                         
##             Sensitivity : 0.7048        
##             Specificity : 0.4111        
##          Pos Pred Value : 0.7363        
##          Neg Pred Value : 0.3737        
##               Precision : 0.7363        
##                  Recall : 0.7048        
##                      F1 : 0.7202        
##              Prevalence : 0.7000        
##          Detection Rate : 0.4933        
##    Detection Prevalence : 0.6700        
##       Balanced Accuracy : 0.5579        
##                                         
##        'Positive' Class : Good          
## 

While this is not the most accurate approach we have tried so far, it trades some Sensitivity for a higher Specificity (0.4111) than the up-sampled model achieved (0.3667), giving a more even balance between the two; only the down-sampled model is more evenly balanced. Choosing a cutoff of one third is roughly equivalent to deciding that approving a bad creditor (classifying Bad as Good) is twice as costly an error as rejecting a good creditor (classifying Good as Bad). In fact, when there is class imbalance, the two kinds of error are often not equally costly. Note that this assessment (i.e. one error being 2X or 3X as bad as the other kind) should ALWAYS be informed primarily by knowledge of the underlying domain.
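
The link between a cutoff and a cost ratio can be made explicit. If p is the model’s probability that an applicant is Bad, predicting Bad has an expected cost of (1 - p) times the cost of rejecting a good creditor, while predicting Good has an expected cost of p times the cost of approving a bad one; predicting Bad is the cheaper choice whenever p is at least cost_good_as_bad / (cost_good_as_bad + cost_bad_as_good). The function below is a hypothetical helper encoding that rule, and it assumes the predicted probabilities are reasonably well calibrated, which is not guaranteed for a Random Forest.

```r
# Hypothetical helper: cutoff on P(Bad) that minimises expected misclassification cost
#   cost_bad_as_good: cost of approving a bad creditor (calling "Bad" "Good")
#   cost_good_as_bad: cost of rejecting a good creditor (calling "Good" "Bad")
cutoff_from_costs <- function(cost_bad_as_good, cost_good_as_bad = 1) {
  cost_good_as_bad / (cost_good_as_bad + cost_bad_as_good)
}

cutoff_from_costs(1)  # 0.5 (equal costs, the usual rule)
cutoff_from_costs(2)  # 0.3333333 (bad-as-good twice as costly)
cutoff_from_costs(3)  # 0.25 (bad-as-good three times as costly)
```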


Training and resampling multiple models

Our next order of business is to evaluate model performance across several different methods, instead of sticking to the Random Forest alone! We will compare the Random Forest with two other approaches: “glmnet” (the Elastic Net) and “svmRadial” (a Support Vector Machine with a radial basis function kernel). We will define a list of the methods we want to use, and train them all with the caretList() function from the “caretEnsemble” package.

Specify the training protocol (5-fold cross-validation repeated 5 times, which gives 25 resamples per model):

# Repeated k-fold cross-validation, scored with ROC, Sens and Spec
methodCtrl <- trainControl(
  method = "repeatedcv",
  number = 5,    # 5 folds
  repeats = 5,   # repeated 5 times
  savePredictions = "final",
  classProbs = TRUE,
  summaryFunction = twoClassSummary
)
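
For intuition about these settings, here is a minimal base-R sketch of how 5-fold cross-validation repeated 5 times partitions a data set into 25 held-out sets. repeated_cv_folds() is a hypothetical helper that assigns folds purely at random and returns the held-out rows; “caret” itself creates stratified folds (see createMultiFolds()), so the class balance within each fold is preserved.

```r
# Hypothetical helper: held-out row indices for k-fold CV repeated `repeats` times
repeated_cv_folds <- function(n, k = 5, repeats = 5) {
  held_out <- list()
  for (r in seq_len(repeats)) {
    # Shuffle a balanced set of fold labels 1..k across the n rows
    fold_id <- sample(rep(seq_len(k), length.out = n))
    for (f in seq_len(k)) {
      held_out[[sprintf("Fold%d.Rep%d", f, r)]] <- which(fold_id == f)
    }
  }
  held_out
}

set.seed(355)
folds <- repeated_cv_folds(n = 100)
length(folds)        # 25 held-out sets in total
lengths(folds[1:5])  # each fold holds out 20 of the 100 rows
```

Within each repeat every row is held out exactly once; the repeats simply reshuffle the folds so that the performance estimate depends less on one particular split.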

Specify the list of methods:

methodList <- c("rf", "glmnet", "svmRadial")

Train the ensemble of models:

set.seed(355)
ensemble <- caretList(Class ~ ., data = trainingSet, metric = "ROC", trControl = methodCtrl, methodList = methodList)

Now we will compare model performance using the resamples() function:

resampledList <- resamples(ensemble)
summary(resampledList)
## 
## Call:
## summary.resamples(object = resampledList)
## 
## Models: rf, glmnet, svmRadial 
## Number of resamples: 25 
## 
## ROC 
##                Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## rf        0.6014334 0.6443149 0.6700680 0.6797959 0.7247328 0.7852284    0
## glmnet    0.5586735 0.6394558 0.6517250 0.6541108 0.6852527 0.7387026    0
## svmRadial 0.5971817 0.6409135 0.6639942 0.6744121 0.7045675 0.7534014    0
## 
## Sens 
##                 Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## rf        0.07142857 0.0952381 0.1190476 0.1304762 0.1666667 0.2142857    0
## glmnet    0.07142857 0.0952381 0.1190476 0.1190476 0.1428571 0.1904762    0
## svmRadial 0.00000000 0.0952381 0.1428571 0.1457143 0.1666667 0.3095238    0
## 
## Spec 
##                Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## rf        0.9489796 0.9591837 0.9795918 0.9730612 0.9795918 1.0000000    0
## glmnet    0.9285714 0.9489796 0.9693878 0.9669388 0.9795918 1.0000000    0
## svmRadial 0.9183673 0.9387755 0.9591837 0.9567347 0.9693878 0.9897959    0

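Because the Random Forest has the highest mean and median ROC across the 25 resamples, it is probably the preferred approach of these three, although its edge over svmRadial is small and the ranges of the two overlap considerably. Note that twoClassSummary() treats the first factor level (“Bad”) as the event of interest, so Sens and Spec here describe how well bad and good creditors, respectively, are identified; the very low Sens values show that, without any sampling adjustment, all three models rarely flag bad creditors.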
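
The ROC column reports the area under the ROC curve (AUC) for each held-out resample. One way to read an AUC is as the probability that a randomly chosen event receives a higher score than a randomly chosen non-event. The minimal base-R sketch below computes it that way on made-up scores; auc_rank() is an illustrative helper, not how “caret” computes the metric internally.

```r
# Hypothetical helper: AUC as the share of (event, non-event) pairs in which
# the event scores higher, counting ties as one half
auc_rank <- function(scores, is_event) {
  pos <- scores[is_event]
  neg <- scores[!is_event]
  wins <- outer(pos, neg, ">") + 0.5 * outer(pos, neg, "==")
  mean(wins)
}

# Hypothetical predicted probabilities of "Bad" and true labels
p_bad  <- c(0.9, 0.8, 0.4, 0.7, 0.3)
is_bad <- c(TRUE, TRUE, TRUE, FALSE, FALSE)
auc_rank(p_bad, is_bad)  # 0.8333333 (5 of the 6 pairs are ranked correctly)
```

An AUC of 0.5 corresponds to random ranking and 1 to perfect separation, so mean values around 0.65 to 0.68 indicate only moderate discrimination for all three models.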

Whew! That’s a lot that we’ve done in this tutorial series. Remember that the machine learning universe is a large one, and you can spend months on a particular classification problem! A good rule of thumb is to think about what the bare minimum output you could create in two days or so would look like. There are often diminishing marginal returns to the time spent on these problems, in particular if you don’t know what you’re doing or aren’t working on the right problem. Hopefully you now have an understanding of how to use the “caret” package to solve machine learning problems you face in the real world!