9.1 Introduction

If one had to choose a classification technique that performs well across a wide range of situations, requires little effort from the analyst, and is readily understood by the consumer of the analysis, a strong contender would be the tree methodology developed by Breiman et al. (1984). We discuss the classification procedure first; in later sections we show how it can be extended to predicting a numerical outcome. The program Breiman et al. created to implement these procedures was called CART (Classification And Regression Trees). A related procedure is C4.5.

What is a classification tree? Figure 9.1 shows a tree for classifying bank customers who receive a loan offer as either acceptors or nonacceptors, based on information such as their income, education level, and average credit card expenditure.

Example 1: Riding Mowers

Running a decision tree using rpart

The rpart vignette, An Introduction to Recursive Partitioning Using the RPART Routines, is available at: https://cran.r-project.org/web/packages/rpart/vignettes/longintro.pdf

The vignette for the plotting function prp() in the rpart.plot package is available at: http://www.milbo.org/rpart-plot/prp.pdf

Both contain numerous examples for practicing different kinds of decision trees and creating engaging plots!

library(rpart)
library(rpart.plot)
library(dplyr)
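# load the data; stringsAsFactors = TRUE keeps Ownership a factor, as glimpse() shows
# (read.csv() stopped converting strings to factors by default in R 4.0.0)
mower.df <- read.csv("~/Downloads/RidingMowers.csv", stringsAsFactors = TRUE)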
glimpse(mower.df)
## Observations: 24
## Variables: 3
## $ Income    <dbl> 60.0, 85.5, 64.8, 61.5, 87.0, 110.1, 108.0, 82.8, 69.0…
## $ Lot_Size  <dbl> 18.4, 16.8, 21.6, 20.8, 23.6, 19.2, 17.6, 22.4, 20.0, …
## $ Ownership <fct> Owner, Owner, Owner, Owner, Owner, Owner, Owner, Owner…


#### Figure 9.7

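# use rpart() with method = "class" to fit a classification tree.
# rpart.control() limits tree growth: maxdepth = 2 allows at most two levels of splits,
# and minsplit = 10 requires at least 10 records in a node before a split is attempted.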
class.tree <- rpart(Ownership ~ ., data = mower.df, 
                    control = rpart.control(maxdepth = 2, minsplit = 10), method = "class")
class.tree
## n= 24 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
## 1) root 24 12 Nonowner (0.5000000 0.5000000)  
##   2) Income< 59.7 8  1 Nonowner (0.8750000 0.1250000) *
##   3) Income>=59.7 16  5 Owner (0.3125000 0.6875000)  
##     6) Lot_Size< 19.8 9  4 Nonowner (0.5555556 0.4444444) *
##     7) Lot_Size>=19.8 7  0 Owner (0.0000000 1.0000000) *
help(rpart)
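For classification, rpart chooses splits by default with the Gini index (`parms = list(split = "gini")`). The sketch below recomputes, by hand, the impurity reduction for the first split in the output above (Income < 59.7), using only the class counts implied by the printed node summaries: the root has 12 nonowners and 12 owners, node 2 has 7 nonowners and 1 owner, and node 3 has 5 nonowners and 11 owners. The helper `gini()` is hypothetical and only illustrates the criterion; it is not rpart's internal code, and rpart may report its own improvement values on a different scale.

```r
# Gini impurity of a node from its class counts: 1 - sum(p_k^2)
gini <- function(counts) {
  p <- counts / sum(counts)
  1 - sum(p^2)
}

# Class counts (Nonowner, Owner) read from the printed tree
root  <- c(12, 12)   # node 1: 24 records
left  <- c(7, 1)     # node 2: Income < 59.7, 8 records
right <- c(5, 11)    # node 3: Income >= 59.7, 16 records

# Impurity after the split: child impurities weighted by node size
n <- sum(root)
after <- sum(left) / n * gini(left) + sum(right) / n * gini(right)

gini(root)           # 0.5
after                # 0.359375
gini(root) - after   # impurity reduction: 0.140625
```

A split is attractive when this reduction is large; rpart evaluates candidate splits on every predictor and keeps the best one at each node.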
# plot tree
# use prp() to plot the tree. You can control plotting parameters such as color, shape, 
# and information displayed (which and where).
prp(class.tree, type = 1, extra = 1, split.font = 1, varlen = -10)  

help(prp)
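Each path from the root to a leaf of class.tree is an if/then rule. The hypothetical function `classify_mower()` below hard-codes the three rules from the printed tree to show how the fitted tree assigns a class to a new record. The thresholds are copied as printed (rounded), and the new records are made up for illustration; in practice you would call predict(class.tree, newdata, type = "class") instead.

```r
# Hypothetical hand-coded version of the depth-2 tree printed above.
# Rules (from the rpart output):
#   Income <  59.7                    -> Nonowner (node 2)
#   Income >= 59.7 & Lot_Size <  19.8 -> Nonowner (node 6)
#   Income >= 59.7 & Lot_Size >= 19.8 -> Owner    (node 7)
classify_mower <- function(income, lot_size) {
  ifelse(income < 59.7, "Nonowner",
         ifelse(lot_size < 19.8, "Nonowner", "Owner"))
}

# Made-up records, for illustration only
new.df <- data.frame(Income = c(50, 75, 75), Lot_Size = c(22, 18, 21))
classify_mower(new.df$Income, new.df$Lot_Size)
# [1] "Nonowner" "Nonowner" "Owner"
```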

Example 2: Banks

  1. Load the data and remove the ID and ZIP code variables
bank.df <- read.csv("~/Downloads/UniversalBank.csv")
glimpse(bank.df)
## Observations: 5,000
## Variables: 14
## $ ID                 <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14…
## $ Age                <int> 25, 45, 39, 35, 35, 37, 53, 50, 35, 34, 65, 2…
## $ Experience         <int> 1, 19, 15, 9, 8, 13, 27, 24, 10, 9, 39, 5, 23…
## $ Income             <int> 49, 34, 11, 100, 45, 29, 72, 22, 81, 180, 105…
## $ ZIP.Code           <int> 91107, 90089, 94720, 94112, 91330, 92121, 917…
## $ Family             <int> 4, 3, 1, 1, 4, 4, 2, 1, 3, 1, 4, 3, 2, 4, 1, …
## $ CCAvg              <dbl> 1.6, 1.5, 1.0, 2.7, 1.0, 0.4, 1.5, 0.3, 0.6, …
## $ Education          <int> 1, 1, 1, 2, 2, 2, 2, 3, 2, 3, 3, 2, 3, 2, 1, …
## $ Mortgage           <int> 0, 0, 0, 0, 0, 155, 0, 0, 104, 0, 0, 0, 0, 0,…
## $ Personal.Loan      <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ Securities.Account <int> 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, …
## $ CD.Account         <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ Online             <int> 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, …
## $ CreditCard         <int> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
# Drop the ID and zip code columns.
# (Base-R equivalent: bank.df <- bank.df[ , -c(1, 5)])
bank.df <- select(bank.df, -c(ID, ZIP.Code))
glimpse(bank.df)
## Observations: 5,000
## Variables: 12
## $ Age                <int> 25, 45, 39, 35, 35, 37, 53, 50, 35, 34, 65, 2…
## $ Experience         <int> 1, 19, 15, 9, 8, 13, 27, 24, 10, 9, 39, 5, 23…
## $ Income             <int> 49, 34, 11, 100, 45, 29, 72, 22, 81, 180, 105…
## $ Family             <int> 4, 3, 1, 1, 4, 4, 2, 1, 3, 1, 4, 3, 2, 4, 1, …
## $ CCAvg              <dbl> 1.6, 1.5, 1.0, 2.7, 1.0, 0.4, 1.5, 0.3, 0.6, …
## $ Education          <int> 1, 1, 1, 2, 2, 2, 2, 3, 2, 3, 3, 2, 3, 2, 1, …
## $ Mortgage           <int> 0, 0, 0, 0, 0, 155, 0, 0, 104, 0, 0, 0, 0, 0,…
## $ Personal.Loan      <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ Securities.Account <int> 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, …
## $ CD.Account         <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ Online             <int> 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, …
## $ CreditCard         <int> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
  2. Create partitions
  3. Run classification tree
# partition: 60% of records for training, the remaining 40% for validation
set.seed(1)
train.index <- sample(seq_len(nrow(bank.df)), nrow(bank.df) * 0.6)
train.df <- bank.df[train.index, ]
valid.df <- bank.df[-train.index, ]

# classification tree
default.ct <- rpart(Personal.Loan ~ ., data = train.df, method = "class")
# plot tree
prp(default.ct, type = 1, extra = 1, under = TRUE, split.font = 1, varlen = -10)
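The partition step draws 60% of the row indices without replacement for training and takes the remaining rows (negative indexing) for validation, so the two sets are disjoint and together cover all 5,000 records. The check below reproduces that logic on a hypothetical stand-in data frame, `toy.df`, of the same size. Which rows are drawn depends on the seed and on R's sampling algorithm, which changed in R 3.6.0, so results from older R versions may differ.

```r
set.seed(1)
toy.df <- data.frame(id = 1:5000)   # hypothetical stand-in for bank.df (5,000 rows)

n <- nrow(toy.df)
train.index <- sample(seq_len(n), n * 0.6)          # 3,000 indices, no replacement
train.df <- toy.df[train.index, , drop = FALSE]
valid.df <- toy.df[-train.index, , drop = FALSE]    # every row not drawn

nrow(train.df)                                # 3000
nrow(valid.df)                                # 2000
length(intersect(train.df$id, valid.df$id))   # 0: no record is in both sets
```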

#### Table 9.3
library(caret)
# Create confusion matrices for the training and validation data

# classify records in the training data.
# set argument type = "class" in predict() to generate predicted class membership.
default.ct.point.pred.train <- predict(default.ct, train.df, type = "class")
# generate confusion matrix for training data
confusionMatrix(default.ct.point.pred.train, as.factor(train.df$Personal.Loan))
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    0    1
##          0 2718   32
##          1    7  243
##                                           
##                Accuracy : 0.987           
##                  95% CI : (0.9823, 0.9907)
##     No Information Rate : 0.9083          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.9186          
##                                           
##  Mcnemar's Test P-Value : 0.0001215       
##                                           
##             Sensitivity : 0.9974          
##             Specificity : 0.8836          
##          Pos Pred Value : 0.9884          
##          Neg Pred Value : 0.9720          
##              Prevalence : 0.9083          
##          Detection Rate : 0.9060          
##    Detection Prevalence : 0.9167          
##       Balanced Accuracy : 0.9405          
##                                           
##        'Positive' Class : 0               
## 
# Note: training accuracy is 0.987
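caret's confusionMatrix() treats the first factor level ("0", nonacceptor) as the positive class unless its `positive` argument is set, which is why the output reports 'Positive' Class : 0. The main statistics can be recomputed from the four counts of the training matrix above; this base-R sketch copies those counts by hand and reproduces caret's values.

```r
# Training confusion matrix from the output above
# (rows = predicted class, columns = actual class; filled column by column)
cm <- matrix(c(2718, 7, 32, 243), nrow = 2,
             dimnames = list(Prediction = c("0", "1"), Reference = c("0", "1")))

accuracy    <- sum(diag(cm)) / sum(cm)          # (2718 + 243) / 3000
sensitivity <- cm["0", "0"] / sum(cm[, "0"])    # 2718 / 2725, positive class "0"
specificity <- cm["1", "1"] / sum(cm[, "1"])    # 243 / 275
balanced    <- (sensitivity + specificity) / 2
prevalence  <- sum(cm[, "0"]) / sum(cm)         # 2725 / 3000

round(c(accuracy = accuracy, sensitivity = sensitivity,
        specificity = specificity, balanced = balanced,
        prevalence = prevalence), 4)
# accuracy 0.987, sensitivity 0.9974, specificity 0.8836,
# balanced 0.9405, prevalence 0.9083
```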


# classify records in the validation data.
# set argument type = "class" in predict() to generate predicted class membership.
default.ct.point.pred.valid <- predict(default.ct, valid.df, type = "class")
# generate confusion matrix for validation data
confusionMatrix(default.ct.point.pred.valid, as.factor(valid.df$Personal.Loan))
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    0    1
##          0 1787   32
##          1    8  173
##                                           
##                Accuracy : 0.98            
##                  95% CI : (0.9729, 0.9857)
##     No Information Rate : 0.8975          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.8854          
##                                           
##  Mcnemar's Test P-Value : 0.0002762       
##                                           
##             Sensitivity : 0.9955          
##             Specificity : 0.8439          
##          Pos Pred Value : 0.9824          
##          Neg Pred Value : 0.9558          
##              Prevalence : 0.8975          
##          Detection Rate : 0.8935          
##    Detection Prevalence : 0.9095          
##       Balanced Accuracy : 0.9197          
##                                           
##        'Positive' Class : 0               
## 
# Note: validation accuracy is (1787 + 173) / 2000 = 0.98, slightly below the 0.987 training accuracy.
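In this example the class of interest is usually the acceptors (Personal.Loan = 1). Because caret used "0" as the positive class, its Sensitivity describes nonacceptors and its Specificity describes acceptors; choosing "1" as the positive class swaps the two. The sketch below shows this with the validation counts copied from the output above; `recall()` is a hypothetical helper, and with caret itself you would pass `positive = "1"` to confusionMatrix().

```r
# Validation confusion matrix from the output above
# (rows = predicted class, columns = actual class; filled column by column)
cm.valid <- matrix(c(1787, 8, 32, 173), nrow = 2,
                   dimnames = list(Prediction = c("0", "1"), Reference = c("0", "1")))

# Recall for a chosen positive class: correct predictions of that class
# divided by all records that actually belong to it
recall <- function(cm, cls) cm[cls, cls] / sum(cm[, cls])

recall(cm.valid, "0")                 # 1787 / 1795: caret's Sensitivity (positive = "0")
recall(cm.valid, "1")                 # 173 / 205: Sensitivity if acceptors are positive
sum(diag(cm.valid)) / sum(cm.valid)   # accuracy: 1960 / 2000 = 0.98
```

The tree misses 32 of the 205 actual acceptors in the validation data, a weakness that the overall accuracy of 0.98 hides because acceptors are only about 10% of the records.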

Using Cross Validation

#### Table 9.4

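# argument xval sets the number of folds in rpart's built-in
# cross-validation procedure; folds are assigned at random, so xerror and
# xstd can vary between runs unless set.seed() is called first.
# argument cp sets the smallest value of the complexity parameter to consider.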
cv.ct <- rpart(Personal.Loan ~ ., data = train.df, method = "class", 
               cp = 0.00001, minsplit = 5, xval = 5)
# use printcp() to print the table. 
printcp(cv.ct)
## 
## Classification tree:
## rpart(formula = Personal.Loan ~ ., data = train.df, method = "class", 
##     cp = 1e-05, minsplit = 5, xval = 5)
## 
## Variables actually used in tree construction:
## [1] Age        CCAvg      CD.Account Education  Family     Income    
## [7] Online    
## 
## Root node error: 275/3000 = 0.091667
## 
## n= 3000 
## 
##           CP nsplit rel error  xerror     xstd
## 1  0.3218182      0  1.000000 1.00000 0.057472
## 2  0.1454545      2  0.356364 0.45818 0.039952
## 3  0.0181818      3  0.210909 0.21818 0.027884
## 4  0.0169697      4  0.192727 0.21091 0.027425
## 5  0.0090909      7  0.141818 0.18182 0.025498
## 6  0.0072727      9  0.123636 0.16000 0.023943
## 7  0.0048485     12  0.101818 0.14545 0.022845
## 8  0.0036364     15  0.087273 0.17818 0.025246
## 9  0.0024242     22  0.061818 0.18182 0.025498
## 10 0.0018182     25  0.054545 0.19273 0.026238
## 11 0.0000100     27  0.050909 0.19273 0.026238
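In the cp table, rel error and xerror are relative to the root node error (275/3000 = 0.091667, the error from predicting "0" for every record). Multiplying by the root node error turns them into misclassification rates. The sketch copies a few rows of the table above into a data frame to make the conversion concrete; it does not refit anything.

```r
root.error <- 275 / 3000   # "Root node error" from printcp()

# Selected rows of the cp table above (values copied by hand)
cp.rows <- data.frame(
  nsplit    = c(0, 3, 12, 27),
  rel.error = c(1.000000, 0.210909, 0.101818, 0.050909),
  xerror    = c(1.00000, 0.21818, 0.14545, 0.19273)
)

# Convert relative errors to misclassification rates
cp.rows$train.rate <- cp.rows$rel.error * root.error   # error on the training data
cp.rows$cv.rate    <- cp.rows$xerror * root.error      # cross-validated error
round(cp.rows, 4)
```

Training error keeps falling as the tree grows (about 0.0093 at 12 splits and 0.0047 at 27), while the cross-validated error is lowest at 12 splits (about 0.0133) and then rises: the overfitting pattern that pruning addresses.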

Using Pruning

#### Figure 9.12

# prune back to the cp value of the row with the lowest cross-validation error (xerror)
pruned.ct <- prune(cv.ct, 
                   cp = cv.ct$cptable[which.min(cv.ct$cptable[,"xerror"]),"CP"])
sum(pruned.ct$frame$var == "<leaf>")  # number of terminal nodes (leaves)
## [1] 13
prp(pruned.ct, type = 1, extra = 1, split.font = 1, varlen = -10)  
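The prune() call above looks up the row of cv.ct$cptable with the smallest xerror and passes that row's CP. On the cp table printed earlier this is row 7 (xerror 0.14545, 12 splits), and a binary tree with 12 splits has 13 leaves, matching the `[1] 13` output. The sketch repeats the lookup on a matrix copied from that table; `cptab` is a hypothetical stand-in that keeps only three of the cptable's columns.

```r
# Hypothetical stand-in for cv.ct$cptable, copied from the printcp() output above
cptab <- cbind(
  CP     = c(0.3218182, 0.1454545, 0.0181818, 0.0169697, 0.0090909, 0.0072727,
             0.0048485, 0.0036364, 0.0024242, 0.0018182, 0.0000100),
  nsplit = c(0, 2, 3, 4, 7, 9, 12, 15, 22, 25, 27),
  xerror = c(1.00000, 0.45818, 0.21818, 0.21091, 0.18182, 0.16000,
             0.14545, 0.17818, 0.18182, 0.19273, 0.19273)
)

best <- which.min(cptab[, "xerror"])  # row with the lowest cross-validated error
cptab[best, "CP"]                     # cp value passed to prune(): 0.0048485
cptab[best, "nsplit"] + 1             # leaves in a binary tree = splits + 1: 13
```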

#### Figure 9.13

set.seed(1)
# minsplit is the minimum number of observations in a node for a split to be attempted.
# xval is the number K of folds in K-fold cross-validation.
cv.ct <- rpart(Personal.Loan ~ ., data = train.df, method = "class", cp = 0.00001, minsplit = 1, xval = 5)
# Print the cp table of cross-validation errors.
# rel error: training error relative to the root node error
#   (for a regression tree, R-squared is 1 minus rel error).
# xerror: relative cross-validation error ("x" stands for "cross"), the
#   out-of-sample error across the 5 folds, scaled by the root node error.
printcp(cv.ct)
## 
## Classification tree:
## rpart(formula = Personal.Loan ~ ., data = train.df, method = "class", 
##     cp = 1e-05, minsplit = 1, xval = 5)
## 
## Variables actually used in tree construction:
## [1] Age        CCAvg      CD.Account Education  Experience Family    
## [7] Income     Mortgage   Online    
## 
## Root node error: 275/3000 = 0.091667
## 
## n= 3000 
## 
##           CP nsplit rel error  xerror     xstd
## 1  0.3218182      0  1.000000 1.00000 0.057472
## 2  0.1454545      2  0.356364 0.38182 0.036604
## 3  0.0181818      3  0.210909 0.22182 0.028111
## 4  0.0169697      4  0.192727 0.20000 0.026720
## 5  0.0090909      7  0.141818 0.20364 0.026957
## 6  0.0072727      9  0.123636 0.19273 0.026238
## 7  0.0048485     12  0.101818 0.18545 0.025747
## 8  0.0036364     15  0.087273 0.19273 0.026238
## 9  0.0024242     29  0.036364 0.20000 0.026720
## 10 0.0018182     36  0.018182 0.19273 0.026238
## 11 0.0000100     46  0.000000 0.19273 0.026238
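# prune with a fixed cp. In the cp table above, 0.0154639 lies between the CP values
# of rows 5 (0.0090909) and 4 (0.0169697), so prune() keeps the 7-split tree of row 5.
# Re-check the value against printcp() if the cross-validation folds change.
pruned.ct <- prune(cv.ct, cp = 0.0154639)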
prp(pruned.ct, type = 1, extra = 1, under = TRUE, split.font = 1, varlen = -10, 
    box.col=ifelse(pruned.ct$frame$var == "<leaf>", 'gray', 'white'))
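Instead of a hard-coded cp, the cp table can be read with the one-standard-error rule of Breiman et al. (1984): take the minimum xerror, add its xstd, and choose the smallest tree whose xerror falls at or below that threshold. In an rpart cp table, the tree in row i is the one kept for cp values from CP_i up to (but not including) CP_{i-1}, so the selected row's own CP is a convenient value to pass to prune(). The sketch applies the rule to the second cp table printed above; `cptab` is a hypothetical stand-in for cv.ct$cptable (four of its columns copied by hand) and `one_se_row()` is a hypothetical helper, not an rpart function.

```r
# Hypothetical stand-in for cv.ct$cptable from the second printcp() output above
cptab <- cbind(
  CP     = c(0.3218182, 0.1454545, 0.0181818, 0.0169697, 0.0090909, 0.0072727,
             0.0048485, 0.0036364, 0.0024242, 0.0018182, 0.0000100),
  nsplit = c(0, 2, 3, 4, 7, 9, 12, 15, 29, 36, 46),
  xerror = c(1.00000, 0.38182, 0.22182, 0.20000, 0.20364, 0.19273,
             0.18545, 0.19273, 0.20000, 0.19273, 0.19273),
  xstd   = c(0.057472, 0.036604, 0.028111, 0.026720, 0.026957, 0.026238,
             0.025747, 0.026238, 0.026720, 0.026238, 0.026238)
)

# One-standard-error rule: threshold = minimum xerror + its xstd.
# Rows are ordered by increasing nsplit, so the first row under the
# threshold is the smallest qualifying tree.
one_se_row <- function(cptab) {
  best <- which.min(cptab[, "xerror"])
  threshold <- cptab[best, "xerror"] + cptab[best, "xstd"]
  which(cptab[, "xerror"] <= threshold)[1]
}

pick <- one_se_row(cptab)
cptab[pick, "nsplit"]   # 4 splits, i.e. 5 leaves
cptab[pick, "CP"]       # 0.0169697: a cp for prune(cv.ct, cp = ...) that keeps this tree
```

On this run the rule selects the 4-split tree, whose xerror of 0.20000 is under 0.18545 + 0.025747 = 0.211197. Because xerror and xstd depend on the random fold assignment, the selected row can change with the seed.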