If one had to choose a classification technique that performs well across a wide range of situations without requiring much effort from the analyst while being readily understandable by the consumer of the analysis, a strong contender would be the tree methodology developed by Breiman et al. (1984). We discuss this classification procedure first; in later sections we show how the procedure can be extended to the prediction of a numerical outcome. The program that Breiman et al. created to implement these procedures was called CART (Classification And Regression Trees). A related procedure is called C4.5.
What is a classification tree? Figure 9.1 shows a tree for classifying bank customers who receive a loan offer as either acceptors or nonacceptors, based on information such as their income, education level, and average credit card expenditure.
*An Introduction to Recursive Partitioning Using the RPART Routines* can be found at: https://cran.r-project.org/web/packages/rpart/vignettes/longintro.pdf
The vignette for the companion plotting package (rpart.plot) can be found at: http://www.milbo.org/rpart-plot/prp.pdf
Both contain numerous examples for practicing with different decision trees and creating engaging plots!
library(rpart)
library(rpart.plot)
library(dplyr)
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
# load the data
# stringsAsFactors = TRUE keeps the character column Ownership as a factor
# (as shown in the glimpse() output) on R >= 4.0, where read.csv() no longer
# converts strings to factors by default. The factor response makes the
# categorical outcome explicit for the classification tree below.
mower.df <- read.csv("~/Box/Teaching (jmmejia@iu.edu)/2020 - K-513/Public Files/Public Data/RidingMowers.csv",
                     stringsAsFactors = TRUE)
glimpse(mower.df)
## Observations: 24
## Variables: 3
## $ Income <dbl> 60.0, 85.5, 64.8, 61.5, 87.0, 110.1, 108.0, 82.8, 69.0, 93.…
## $ Lot_Size <dbl> 18.4, 16.8, 21.6, 20.8, 23.6, 19.2, 17.6, 22.4, 20.0, 20.8,…
## $ Ownership <fct> Owner, Owner, Owner, Owner, Owner, Owner, Owner, Owner, Own…
#### Figure 9.7
# Fit a classification tree (method = "class") that predicts Ownership from
# every other column of mower.df. The rpart.control() settings keep the tree
# small: splits are only attempted on nodes with at least 10 records, and the
# tree may grow no deeper than 2 levels.
class.tree <- rpart(Ownership ~ ., data = mower.df, method = "class",
                    control = rpart.control(minsplit = 10, maxdepth = 2))
class.tree
## n= 24
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 24 12 Nonowner (0.5000000 0.5000000)
## 2) Income< 59.7 8 1 Nonowner (0.8750000 0.1250000) *
## 3) Income>=59.7 16 5 Owner (0.3125000 0.6875000)
## 6) Lot_Size< 19.8 9 4 Nonowner (0.5555556 0.4444444) *
## 7) Lot_Size>=19.8 7 0 Owner (0.0000000 1.0000000) *
?rpart
## plot tree
# prp() draws the fitted tree. Its arguments control the node style (type),
# the extra information printed in each node (extra), the font used for the
# split labels (split.font) and how variable names are abbreviated (varlen).
prp(class.tree, type = 1, extra = 1,
    split.font = 1, varlen = -10)
?prp
# Load the Universal Bank data and preview its columns and types.
bank.df <- read.csv(file = "~/Box/Teaching (jmmejia@iu.edu)/2020 - K-513/Public Files/Public Data/UniversalBank.csv")
bank.df %>% glimpse()
## Observations: 5,000
## Variables: 14
## $ ID <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,…
## $ Age <int> 25, 45, 39, 35, 35, 37, 53, 50, 35, 34, 65, 29, 48…
## $ Experience <int> 1, 19, 15, 9, 8, 13, 27, 24, 10, 9, 39, 5, 23, 32,…
## $ Income <int> 49, 34, 11, 100, 45, 29, 72, 22, 81, 180, 105, 45,…
## $ ZIP.Code <int> 91107, 90089, 94720, 94112, 91330, 92121, 91711, 9…
## $ Family <int> 4, 3, 1, 1, 4, 4, 2, 1, 3, 1, 4, 3, 2, 4, 1, 1, 4,…
## $ CCAvg <dbl> 1.6, 1.5, 1.0, 2.7, 1.0, 0.4, 1.5, 0.3, 0.6, 8.9, …
## $ Education <int> 1, 1, 1, 2, 2, 2, 2, 3, 2, 3, 3, 2, 3, 2, 1, 3, 3,…
## $ Mortgage <int> 0, 0, 0, 0, 0, 155, 0, 0, 104, 0, 0, 0, 0, 0, 0, 0…
## $ Personal.Loan <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1,…
## $ Securities.Account <int> 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0,…
## $ CD.Account <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ Online <int> 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0,…
## $ CreditCard <int> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0,…
# Drop the ID and ZIP.Code columns so they are not picked up as predictors
# by the `Personal.Loan ~ .` formulas below.
# (Base-R equivalent: bank.df <- bank.df[ , -c(1, 5)])
bank.df <- bank.df %>% select(-ID, -ZIP.Code)
bank.df %>% glimpse()
## Observations: 5,000
## Variables: 12
## $ Age <int> 25, 45, 39, 35, 35, 37, 53, 50, 35, 34, 65, 29, 48…
## $ Experience <int> 1, 19, 15, 9, 8, 13, 27, 24, 10, 9, 39, 5, 23, 32,…
## $ Income <int> 49, 34, 11, 100, 45, 29, 72, 22, 81, 180, 105, 45,…
## $ Family <int> 4, 3, 1, 1, 4, 4, 2, 1, 3, 1, 4, 3, 2, 4, 1, 1, 4,…
## $ CCAvg <dbl> 1.6, 1.5, 1.0, 2.7, 1.0, 0.4, 1.5, 0.3, 0.6, 8.9, …
## $ Education <int> 1, 1, 1, 2, 2, 2, 2, 3, 2, 3, 3, 2, 3, 2, 1, 3, 3,…
## $ Mortgage <int> 0, 0, 0, 0, 0, 155, 0, 0, 104, 0, 0, 0, 0, 0, 0, 0…
## $ Personal.Loan <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1,…
## $ Securities.Account <int> 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0,…
## $ CD.Account <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ Online <int> 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0,…
## $ CreditCard <int> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0,…
# partition: a random 60% of the row indices form the training set and the
# remaining rows form the validation set. The seed makes the split repeatable.
set.seed(1)
train.index <- sample(seq_len(nrow(bank.df)), nrow(bank.df) * 0.6)
train.df <- bank.df[train.index, ]
valid.df <- bank.df[-train.index, ]
# classification tree fitted with rpart's default control settings, followed
# by a plot with the node counts drawn under each box (under = TRUE).
default.ct <- rpart(Personal.Loan ~ ., data = train.df, method = "class")
prp(default.ct, type = 1, extra = 1, under = TRUE,
    split.font = 1, varlen = -10)
#### Table 9.3
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
## Create confusion matrix for training and validation data
# classify records in the training data.
# set argument type = "class" in predict() to generate predicted class membership.
default.ct.point.pred.train <- predict(default.ct, train.df, type = "class")
# generate confusion matrix for training data.
# Personal.Loan is an integer (0/1) column, so it is converted to a factor to
# be compared with the predicted classes. By default confusionMatrix() treats
# the first level ("0") as the 'Positive' class; pass positive = "1" to report
# sensitivity/specificity for the loan acceptors instead (assuming 1 = accepted).
confusionMatrix(default.ct.point.pred.train, as.factor(train.df$Personal.Loan))
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 2718 32
## 1 7 243
##
## Accuracy : 0.987
## 95% CI : (0.9823, 0.9907)
## No Information Rate : 0.9083
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.9186
##
## Mcnemar's Test P-Value : 0.0001215
##
## Sensitivity : 0.9974
## Specificity : 0.8836
## Pos Pred Value : 0.9884
## Neg Pred Value : 0.9720
## Prevalence : 0.9083
## Detection Rate : 0.9060
## Detection Prevalence : 0.9167
## Balanced Accuracy : 0.9405
##
## 'Positive' Class : 0
##
# Note: Accuracy : 0.987
# classify records in the validation data.
# set argument type = "class" in predict() to generate predicted class membership.
default.ct.point.pred.valid <- predict(default.ct, valid.df, type = "class")
# generate confusion matrix for validation data.
# As with the training data, Personal.Loan is converted from integer to factor;
# the default 'Positive' class is the first level ("0").
confusionMatrix(default.ct.point.pred.valid, as.factor(valid.df$Personal.Loan))
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 1787 32
## 1 8 173
##
## Accuracy : 0.98
## 95% CI : (0.9729, 0.9857)
## No Information Rate : 0.8975
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.8854
##
## Mcnemar's Test P-Value : 0.0002762
##
## Sensitivity : 0.9955
## Specificity : 0.8439
## Pos Pred Value : 0.9824
## Neg Pred Value : 0.9558
## Prevalence : 0.8975
## Detection Rate : 0.8935
## Detection Prevalence : 0.9095
## Balanced Accuracy : 0.9197
##
## 'Positive' Class : 0
##
# Note: Accuracy : 0.98
#### Table 9.4
# Grow a large tree and let rpart cross-validate every pruning level:
#   cp = 0.00001  smallest complexity-parameter value to consider, so the tree
#                 is allowed to grow far before splits are rejected;
#   minsplit = 5  minimum number of observations in a node for a split to be
#                 attempted;
#   xval = 5      number of folds in rpart's built-in cross-validation.
# NOTE(review): no set.seed() immediately precedes this call, so the xerror
# column depends on the random-number state left by earlier chunks. Consider
# seeding here (as the Figure 9.13 chunk does) to make this chunk reproducible
# on its own.
cv.ct <- rpart(Personal.Loan ~ ., data = train.df, method = "class",
               cp = 0.00001, minsplit = 5, xval = 5)
# printcp() prints the cp table (CP, nsplit, rel error, xerror, xstd).
printcp(cv.ct)
##
## Classification tree:
## rpart(formula = Personal.Loan ~ ., data = train.df, method = "class",
## cp = 1e-05, minsplit = 5, xval = 5)
##
## Variables actually used in tree construction:
## [1] Age CCAvg CD.Account Education Family Income Online
##
## Root node error: 275/3000 = 0.091667
##
## n= 3000
##
## CP nsplit rel error xerror xstd
## 1 0.3218182 0 1.000000 1.00000 0.057472
## 2 0.1454545 2 0.356364 0.45818 0.039952
## 3 0.0181818 3 0.210909 0.21818 0.027884
## 4 0.0169697 4 0.192727 0.21091 0.027425
## 5 0.0090909 7 0.141818 0.18182 0.025498
## 6 0.0072727 9 0.123636 0.16000 0.023943
## 7 0.0048485 12 0.101818 0.14545 0.022845
## 8 0.0036364 15 0.087273 0.17818 0.025246
## 9 0.0024242 22 0.061818 0.18182 0.025498
## 10 0.0018182 25 0.054545 0.19273 0.026238
## 11 0.0000100 27 0.050909 0.19273 0.026238
#### Figure 9.12
# Prune the full tree back to the cp value whose row in the cp table has the
# lowest cross-validated error (xerror).
help(prune.rpart)
pruned.ct <- prune(cv.ct,
    cp = cv.ct$cptable[which.min(cv.ct$cptable[, "xerror"]), "CP"])
# Number of terminal nodes (leaves) in the pruned tree: rpart marks them with
# var == "<leaf>" in the frame, so count the matches directly.
sum(pruned.ct$frame$var == "<leaf>")
## [1] 13
prp(pruned.ct, type = 1, extra = 1, split.font = 1, varlen = -10)
#### Figure 9.13
# Refit the full tree, this time with minsplit = 1, and cross-validate it.
#   minsplit  the minimum number of observations in a node for a split to be
#             attempted;
#   xval      the number K of folds in a K-fold cross-validation.
# The seed is set first so the cross-validation results can be reproduced.
set.seed(1)
cv.ct <- rpart(Personal.Loan ~ ., data = train.df, method = "class",
               cp = 0.00001, minsplit = 1, xval = 5)
# Print the cp table of cross-validation errors.
#   rel error  training error relative to the root node error. (For a
#              regression tree, 1 - rel error is the R-squared.)
#   xerror     relative cross-validation error ("x" stands for "cross"): a
#              scaled version of the overall average of the 5 out-of-sample
#              errors across the 5 folds.
printcp(cv.ct)
##
## Classification tree:
## rpart(formula = Personal.Loan ~ ., data = train.df, method = "class",
## cp = 1e-05, minsplit = 1, xval = 5)
##
## Variables actually used in tree construction:
## [1] Age CCAvg CD.Account Education Experience Family Income
## [8] Mortgage Online
##
## Root node error: 275/3000 = 0.091667
##
## n= 3000
##
## CP nsplit rel error xerror xstd
## 1 0.3218182 0 1.000000 1.00000 0.057472
## 2 0.1454545 2 0.356364 0.38182 0.036604
## 3 0.0181818 3 0.210909 0.22182 0.028111
## 4 0.0169697 4 0.192727 0.20000 0.026720
## 5 0.0090909 7 0.141818 0.20364 0.026957
## 6 0.0072727 9 0.123636 0.19273 0.026238
## 7 0.0048485 12 0.101818 0.18545 0.025747
## 8 0.0036364 15 0.087273 0.19273 0.026238
## 9 0.0024242 29 0.036364 0.20000 0.026720
## 10 0.0018182 36 0.018182 0.19273 0.026238
## 11 0.0000100 46 0.000000 0.19273 0.026238
# Best-pruned tree: the smallest tree in the pruning sequence whose
# cross-validated error (xerror) is within one standard error (xstd) of the
# minimum xerror. The cp is derived from cv.ct's own cp table instead of a
# hard-coded constant, so it stays consistent with the current data split and
# fitted tree.
best.cp <- local({
  cptab <- cv.ct$cptable
  min.row <- which.min(cptab[, "xerror"])
  cutoff <- cptab[min.row, "xerror"] + cptab[min.row, "xstd"]
  # cp-table rows run from the smallest tree (fewest splits) to the largest,
  # so the first row under the cutoff is the smallest qualifying tree.
  best.row <- which(cptab[, "xerror"] <= cutoff)[1]
  cptab[best.row, "CP"]
})
pruned.ct <- prune(cv.ct, cp = best.cp)
# Shade the leaves gray and the internal (splitting) nodes white.
prp(pruned.ct, type = 1, extra = 1, under = TRUE, split.font = 1, varlen = -10,
    box.col = ifelse(pruned.ct$frame$var == "<leaf>", "gray", "white"))
library(randomForest)
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
##
## margin
## The following object is masked from 'package:dplyr':
##
## combine
## random forest
# 500 trees on the training data with mtry = 4 candidate predictors per split
# and nodesize = 5; importance = TRUE asks randomForest to compute the
# variable-importance measures drawn by varImpPlot() below.
rf <- randomForest(as.factor(Personal.Loan) ~ ., data = train.df,
                   ntree = 500, mtry = 4, nodesize = 5,
                   importance = TRUE)
## variable importance plot
varImpPlot(rf, type = 1)
## confusion matrix on the validation data
rf.pred <- predict(rf, valid.df)
confusionMatrix(rf.pred, as.factor(valid.df$Personal.Loan))
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 1791 28
## 1 4 177
##
## Accuracy : 0.984
## 95% CI : (0.9775, 0.989)
## No Information Rate : 0.8975
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.9083
##
## Mcnemar's Test P-Value : 4.785e-05
##
## Sensitivity : 0.9978
## Specificity : 0.8634
## Pos Pred Value : 0.9846
## Neg Pred Value : 0.9779
## Prevalence : 0.8975
## Detection Rate : 0.8955
## Detection Prevalence : 0.9095
## Balanced Accuracy : 0.9306
##
## 'Positive' Class : 0
##