9.1 Introduction

If one had to choose a classification technique that performs well across a wide range of situations, requires little effort from the analyst, and is readily understandable by the consumer of the analysis, a strong contender would be the tree methodology developed by Breiman et al. (1984). We discuss this classification procedure first; in later sections we show how it can be extended to the prediction of a numerical outcome. The program that Breiman et al. created to implement these procedures was called CART (Classification And Regression Trees). A related procedure is C4.5.

What is a classification tree? Figure 9.1 shows a tree for classifying bank customers who receive a loan offer as either acceptors or nonacceptors, based on information such as their income, education level, and average credit card expenditure.
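At each stage, a CART-style algorithm splits the records into two groups by searching over all predictors and all candidate split values for the split that most reduces node impurity, commonly measured by the Gini index. Below is a minimal sketch of that impurity calculation (illustrative only; the names gini and split_gain are ours, and rpart's internal implementation differs):

# Gini impurity of a node: 0 for a pure node, maximal for a 50/50 mix.
gini <- function(y) {
  p <- table(y) / length(y)  # class proportions in the node
  1 - sum(p^2)
}

# Impurity reduction from splitting on a candidate threshold of x:
# parent impurity minus the size-weighted impurity of the two child nodes.
split_gain <- function(y, x, threshold) {
  left  <- y[x < threshold]
  right <- y[x >= threshold]
  w <- c(length(left), length(right)) / length(y)
  gini(y) - (w[1] * gini(left) + w[2] * gini(right))
}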

Example 1: Riding Mowers

Running a decision tree using rpart

An Introduction to Recursive Partitioning Using the RPART Routines can be found at: https://cran.r-project.org/web/packages/rpart/vignettes/longintro.pdf

The vignette for the plotting function prp() (from the rpart.plot package) can be found at: http://www.milbo.org/rpart-plot/prp.pdf

Both contain numerous worked examples for fitting different decision trees and creating informative plots!

library(rpart)
library(rpart.plot)
library(dplyr)
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
# load the data
mower.df <- read.csv("~/Box/Teaching (jmmejia@iu.edu)/2020 - K-513/Public Files/Public Data/RidingMowers.csv")
glimpse(mower.df)
## Observations: 24
## Variables: 3
## $ Income    <dbl> 60.0, 85.5, 64.8, 61.5, 87.0, 110.1, 108.0, 82.8, 69.0, 93.…
## $ Lot_Size  <dbl> 18.4, 16.8, 21.6, 20.8, 23.6, 19.2, 17.6, 22.4, 20.0, 20.8,…
## $ Ownership <fct> Owner, Owner, Owner, Owner, Owner, Owner, Owner, Owner, Own…
#### Figure 9.7

# use rpart() to fit a classification tree.
# pass rpart.control() to rpart() to limit tree growth: maxdepth caps the
# depth of the tree; minsplit is the minimum node size for a split to be attempted.
class.tree <- rpart(Ownership ~ ., data = mower.df, 
                    control = rpart.control(maxdepth = 2, minsplit = 10), method = "class")
class.tree
## n= 24 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
## 1) root 24 12 Nonowner (0.5000000 0.5000000)  
##   2) Income< 59.7 8  1 Nonowner (0.8750000 0.1250000) *
##   3) Income>=59.7 16  5 Owner (0.3125000 0.6875000)  
##     6) Lot_Size< 19.8 9  4 Nonowner (0.5555556 0.4444444) *
##     7) Lot_Size>=19.8 7  0 Owner (0.0000000 1.0000000) *
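Reading this output: each row is a node, showing the split rule, the number of records reaching the node (n), the number misclassified (loss), the predicted class (yval), and the class proportions (yprob); terminal nodes are starred. For example, node 2 holds the 8 records with Income < 59.7, of which 7 are Nonowners, so it predicts Nonowner with a loss of 1.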
help(rpart)
## plot tree
# use prp() to plot the tree. You can control plotting parameters such as color, shape, 
# and information displayed (which and where).
prp(class.tree, type = 1, extra = 1, split.font = 1, varlen = -10)  

help(prp)
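As an alternative to prp(), the rpart.plot package also exports rpart.plot(), a wrapper with different defaults. A quick sketch (the argument values here are just one reasonable choice):

# extra = 104 displays, for a class model, the fitted class, the class
# probabilities, and the percentage of observations in each node.
rpart.plot(class.tree, type = 1, extra = 104, fallen.leaves = TRUE)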

Example 2: Banks

  1. Load the data and remove some variables
bank.df <- read.csv("~/Box/Teaching (jmmejia@iu.edu)/2020 - K-513/Public Files/Public Data/UniversalBank.csv")
glimpse(bank.df)
## Observations: 5,000
## Variables: 14
## $ ID                 <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,…
## $ Age                <int> 25, 45, 39, 35, 35, 37, 53, 50, 35, 34, 65, 29, 48…
## $ Experience         <int> 1, 19, 15, 9, 8, 13, 27, 24, 10, 9, 39, 5, 23, 32,…
## $ Income             <int> 49, 34, 11, 100, 45, 29, 72, 22, 81, 180, 105, 45,…
## $ ZIP.Code           <int> 91107, 90089, 94720, 94112, 91330, 92121, 91711, 9…
## $ Family             <int> 4, 3, 1, 1, 4, 4, 2, 1, 3, 1, 4, 3, 2, 4, 1, 1, 4,…
## $ CCAvg              <dbl> 1.6, 1.5, 1.0, 2.7, 1.0, 0.4, 1.5, 0.3, 0.6, 8.9, …
## $ Education          <int> 1, 1, 1, 2, 2, 2, 2, 3, 2, 3, 3, 2, 3, 2, 1, 3, 3,…
## $ Mortgage           <int> 0, 0, 0, 0, 0, 155, 0, 0, 104, 0, 0, 0, 0, 0, 0, 0…
## $ Personal.Loan      <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1,…
## $ Securities.Account <int> 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0,…
## $ CD.Account         <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ Online             <int> 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0,…
## $ CreditCard         <int> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0,…
# bank.df <- bank.df[ , -c(1, 5)]  # Drop ID and zip code columns.
bank.df <- select(bank.df, -c(ID, ZIP.Code))
glimpse(bank.df)
## Observations: 5,000
## Variables: 12
## $ Age                <int> 25, 45, 39, 35, 35, 37, 53, 50, 35, 34, 65, 29, 48…
## $ Experience         <int> 1, 19, 15, 9, 8, 13, 27, 24, 10, 9, 39, 5, 23, 32,…
## $ Income             <int> 49, 34, 11, 100, 45, 29, 72, 22, 81, 180, 105, 45,…
## $ Family             <int> 4, 3, 1, 1, 4, 4, 2, 1, 3, 1, 4, 3, 2, 4, 1, 1, 4,…
## $ CCAvg              <dbl> 1.6, 1.5, 1.0, 2.7, 1.0, 0.4, 1.5, 0.3, 0.6, 8.9, …
## $ Education          <int> 1, 1, 1, 2, 2, 2, 2, 3, 2, 3, 3, 2, 3, 2, 1, 3, 3,…
## $ Mortgage           <int> 0, 0, 0, 0, 0, 155, 0, 0, 104, 0, 0, 0, 0, 0, 0, 0…
## $ Personal.Loan      <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1,…
## $ Securities.Account <int> 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0,…
## $ CD.Account         <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ Online             <int> 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0,…
## $ CreditCard         <int> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0,…
  2. Create partitions
  3. Run a classification tree
# partition: randomly draw 60% of the row indices for training; the remaining
# 40% form the validation set
set.seed(1)  
train.index <- sample(c(1:dim(bank.df)[1]), dim(bank.df)[1]*0.6)  
train.df <- bank.df[train.index, ]
valid.df <- bank.df[-train.index, ]
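Note that this simple random split does not guarantee that both partitions have the same proportion of loan acceptors. If a stratified split is preferred, one option is caret's createDataPartition(), which samples within each outcome class; a sketch (assuming the caret package, which is loaded below anyway):

# stratified 60/40 split: sampling within each level of Personal.Loan keeps
# the acceptance rate roughly equal across the two partitions
library(caret)
set.seed(1)
strat.index <- createDataPartition(factor(bank.df$Personal.Loan), p = 0.6, list = FALSE)
train.strat <- bank.df[strat.index, ]
valid.strat <- bank.df[-strat.index, ]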

# classification tree
default.ct <- rpart(Personal.Loan ~ ., data = train.df, method = "class")
# plot tree
prp(default.ct, type = 1, extra = 1, under = TRUE, split.font = 1, varlen = -10)

#### Table 9.3
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
## Create confusion matrix for training and validation data

# classify records in the training data.
# set argument type = "class" in predict() to generate predicted class membership.
default.ct.point.pred.train <- predict(default.ct, train.df, type = "class")
# generate confusion matrix for training data
confusionMatrix(default.ct.point.pred.train, as.factor(train.df$Personal.Loan))
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    0    1
##          0 2718   32
##          1    7  243
##                                           
##                Accuracy : 0.987           
##                  95% CI : (0.9823, 0.9907)
##     No Information Rate : 0.9083          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.9186          
##                                           
##  Mcnemar's Test P-Value : 0.0001215       
##                                           
##             Sensitivity : 0.9974          
##             Specificity : 0.8836          
##          Pos Pred Value : 0.9884          
##          Neg Pred Value : 0.9720          
##              Prevalence : 0.9083          
##          Detection Rate : 0.9060          
##    Detection Prevalence : 0.9167          
##       Balanced Accuracy : 0.9405          
##                                           
##        'Positive' Class : 0               
## 
#  Note:                Accuracy : 0.987           


# classify records in the validation data.
# set argument type = "class" in predict() to generate predicted class membership.
default.ct.point.pred.valid <- predict(default.ct, valid.df, type = "class")
# generate confusion matrix for validation data
confusionMatrix(default.ct.point.pred.valid, as.factor(valid.df$Personal.Loan))
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    0    1
##          0 1787   32
##          1    8  173
##                                           
##                Accuracy : 0.98            
##                  95% CI : (0.9729, 0.9857)
##     No Information Rate : 0.8975          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.8854          
##                                           
##  Mcnemar's Test P-Value : 0.0002762       
##                                           
##             Sensitivity : 0.9955          
##             Specificity : 0.8439          
##          Pos Pred Value : 0.9824          
##          Neg Pred Value : 0.9558          
##              Prevalence : 0.8975          
##          Detection Rate : 0.8935          
##    Detection Prevalence : 0.9095          
##       Balanced Accuracy : 0.9197          
##                                           
##        'Positive' Class : 0               
## 
# Note:                Accuracy : 0.98            
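The validation accuracy (0.98) is only slightly below the training accuracy (0.987), so the default tree does not appear badly overfit. Rather than reading these numbers off the printouts, they can also be pulled directly from the object that confusionMatrix() returns; a sketch:

# confusionMatrix() returns an object whose overall element is a named vector
# containing Accuracy, Kappa, and related statistics.
cm.train <- confusionMatrix(default.ct.point.pred.train, as.factor(train.df$Personal.Loan))
cm.valid <- confusionMatrix(default.ct.point.pred.valid, as.factor(valid.df$Personal.Loan))
cm.train$overall["Accuracy"] - cm.valid$overall["Accuracy"]  # train-validation gap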

Using Cross-Validation

#### Table 9.4

# argument xval refers to the number of folds to use in rpart's built-in
# cross-validation procedure
# argument cp sets the smallest value for the complexity parameter.
cv.ct <- rpart(Personal.Loan ~ ., data = train.df, method = "class", 
               cp = 0.00001, minsplit = 5, xval = 5)
# use printcp() to print the table. 
printcp(cv.ct)
## 
## Classification tree:
## rpart(formula = Personal.Loan ~ ., data = train.df, method = "class", 
##     cp = 1e-05, minsplit = 5, xval = 5)
## 
## Variables actually used in tree construction:
## [1] Age        CCAvg      CD.Account Education  Family     Income     Online    
## 
## Root node error: 275/3000 = 0.091667
## 
## n= 3000 
## 
##           CP nsplit rel error  xerror     xstd
## 1  0.3218182      0  1.000000 1.00000 0.057472
## 2  0.1454545      2  0.356364 0.45818 0.039952
## 3  0.0181818      3  0.210909 0.21818 0.027884
## 4  0.0169697      4  0.192727 0.21091 0.027425
## 5  0.0090909      7  0.141818 0.18182 0.025498
## 6  0.0072727      9  0.123636 0.16000 0.023943
## 7  0.0048485     12  0.101818 0.14545 0.022845
## 8  0.0036364     15  0.087273 0.17818 0.025246
## 9  0.0024242     22  0.061818 0.18182 0.025498
## 10 0.0018182     25  0.054545 0.19273 0.026238
## 11 0.0000100     27  0.050909 0.19273 0.026238
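Note how xerror falls to a minimum of 0.14545 at 12 splits (row 7) and then rises again as the tree keeps growing, the classic signature of overfitting. rpart provides plotcp() to visualize this pattern directly:

# plotcp() plots the cross-validated error against cp/tree size, with a
# reference line at one standard error above the minimum xerror.
plotcp(cv.ct)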

Using Pruning

#### Figure 9.12

# prune by lower cp
help(prune.rpart)
pruned.ct <- prune(cv.ct, 
                   cp = cv.ct$cptable[which.min(cv.ct$cptable[,"xerror"]),"CP"])
length(pruned.ct$frame$var[pruned.ct$frame$var == "<leaf>"])
## [1] 13
prp(pruned.ct, type = 1, extra = 1, split.font = 1, varlen = -10)  
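The pruning above selects the cp with the lowest xerror. A common, more conservative alternative is the one-standard-error rule: take the simplest tree whose xerror lies within one xstd of that minimum. A sketch of that rule (the variable names here are ours):

# one-SE rule: smallest tree with xerror <= min(xerror) + xstd at the minimum
cp.tab <- cv.ct$cptable
min.row <- which.min(cp.tab[, "xerror"])
cutoff <- cp.tab[min.row, "xerror"] + cp.tab[min.row, "xstd"]
se.row <- which(cp.tab[, "xerror"] <= cutoff)[1]  # first (simplest) qualifying row
pruned.se.ct <- prune(cv.ct, cp = cp.tab[se.row, "CP"])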

#### Figure 9.13

set.seed(1)
# minsplit is the minimum number of observations in a node for a split to be
# attempted; xval is the number K of folds in K-fold cross-validation.
cv.ct <- rpart(Personal.Loan ~ ., data = train.df, method = "class", 
               cp = 0.00001, minsplit = 1, xval = 5)
# printcp() prints the cp table of cross-validation errors. For a regression
# tree, R-squared equals 1 minus rel error. xerror ("x" for "cross") is the
# cross-validation error: the overall average of the out-of-sample errors
# across the 5 folds, scaled so that the root node's error equals 1.
printcp(cv.ct)
## 
## Classification tree:
## rpart(formula = Personal.Loan ~ ., data = train.df, method = "class", 
##     cp = 1e-05, minsplit = 1, xval = 5)
## 
## Variables actually used in tree construction:
## [1] Age        CCAvg      CD.Account Education  Experience Family     Income    
## [8] Mortgage   Online    
## 
## Root node error: 275/3000 = 0.091667
## 
## n= 3000 
## 
##           CP nsplit rel error  xerror     xstd
## 1  0.3218182      0  1.000000 1.00000 0.057472
## 2  0.1454545      2  0.356364 0.38182 0.036604
## 3  0.0181818      3  0.210909 0.22182 0.028111
## 4  0.0169697      4  0.192727 0.20000 0.026720
## 5  0.0090909      7  0.141818 0.20364 0.026957
## 6  0.0072727      9  0.123636 0.19273 0.026238
## 7  0.0048485     12  0.101818 0.18545 0.025747
## 8  0.0036364     15  0.087273 0.19273 0.026238
## 9  0.0024242     29  0.036364 0.20000 0.026720
## 10 0.0018182     36  0.018182 0.19273 0.026238
## 11 0.0000100     46  0.000000 0.19273 0.026238
# cp = 0.0154639 lies between the CP values at 4 splits (0.0169697) and
# 7 splits (0.0090909) in the table above, so prune() returns the 4-split tree.
pruned.ct <- prune(cv.ct, cp = 0.0154639)
prp(pruned.ct, type = 1, extra = 1, under = TRUE, split.font = 1, varlen = -10, 
    box.col=ifelse(pruned.ct$frame$var == "<leaf>", 'gray', 'white')) 

Trying a random forest

library(randomForest)
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
## 
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
## 
##     margin
## The following object is masked from 'package:dplyr':
## 
##     combine
## random forest: ntree = number of trees to grow, mtry = number of predictors
## sampled as split candidates at each node, nodesize = minimum size of
## terminal nodes
rf <- randomForest(as.factor(Personal.Loan) ~ ., data = train.df, ntree = 500, 
    mtry = 4, nodesize = 5, importance = TRUE)  

## variable importance plot
varImpPlot(rf, type = 1)
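varImpPlot() with type = 1 plots each predictor's mean decrease in accuracy when its values are permuted. The underlying numbers can be extracted with importance(), for example to sort or report them; a sketch:

# mean decrease in accuracy per predictor, sorted from most to least important
imp <- importance(rf, type = 1)
imp[order(imp, decreasing = TRUE), , drop = FALSE]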

## confusion matrix
rf.pred <- predict(rf, valid.df)
confusionMatrix(rf.pred, factor(valid.df$Personal.Loan))
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    0    1
##          0 1791   28
##          1    4  177
##                                          
##                Accuracy : 0.984          
##                  95% CI : (0.9775, 0.989)
##     No Information Rate : 0.8975         
##     P-Value [Acc > NIR] : < 2.2e-16      
##                                          
##                   Kappa : 0.9083         
##                                          
##  Mcnemar's Test P-Value : 4.785e-05      
##                                          
##             Sensitivity : 0.9978         
##             Specificity : 0.8634         
##          Pos Pred Value : 0.9846         
##          Neg Pred Value : 0.9779         
##              Prevalence : 0.8975         
##          Detection Rate : 0.8955         
##    Detection Prevalence : 0.9095         
##       Balanced Accuracy : 0.9306         
##                                          
##        'Positive' Class : 0              
##
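On the validation data, the random forest (accuracy 0.984) slightly outperforms the single default tree (0.98). When ranked predictions are needed, for example to build a lift chart, predict() can return class probabilities instead of hard class labels; a sketch:

# propensities: the fraction of trees voting for each class
rf.prob <- predict(rf, valid.df, type = "prob")
head(rf.prob[, "1"])  # estimated probability that the customer accepts the loan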