This is an example exercise taken from the book “An Introduction to Statistical Learning with Applications in R” by Gareth James, Daniela Witten, Trevor Hastie, and Robert Tibshirani.
Here we will fit classification trees. First, we will load the necessary packages and the Carseats dataset.
require(tree)
require(ISLR)
We load the dataset as well.
data(Carseats)
head(Carseats)
## Sales CompPrice Income Advertising Population Price ShelveLoc Age
## 1 9.50 138 73 11 276 120 Bad 42
## 2 11.22 111 48 16 260 83 Good 65
## 3 10.06 113 35 10 269 80 Medium 59
## 4 7.40 117 100 4 466 97 Medium 55
## 5 4.15 141 64 3 340 128 Bad 38
## 6 10.81 124 113 13 501 72 Bad 78
## Education Urban US
## 1 17 Yes Yes
## 2 10 Yes Yes
## 3 12 Yes Yes
## 4 14 Yes Yes
## 5 13 Yes No
## 6 16 No Yes
str(Carseats)
## 'data.frame': 400 obs. of 11 variables:
## $ Sales : num 9.5 11.22 10.06 7.4 4.15 ...
## $ CompPrice : num 138 111 113 117 141 124 115 136 132 132 ...
## $ Income : num 73 48 35 100 64 113 105 81 110 113 ...
## $ Advertising: num 11 16 10 4 3 13 0 15 0 0 ...
## $ Population : num 276 260 269 466 340 501 45 425 108 131 ...
## $ Price : num 120 83 80 97 128 72 108 120 124 124 ...
## $ ShelveLoc : Factor w/ 3 levels "Bad","Good","Medium": 1 2 3 3 1 1 3 2 3 3 ...
## $ Age : num 42 65 59 55 38 78 71 67 76 76 ...
## $ Education : num 17 10 12 14 13 16 15 10 10 17 ...
## $ Urban : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 1 2 2 1 1 ...
## $ US : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 1 2 1 2 ...
Since Sales is a continuous variable, we begin by recoding it as a binary variable: High takes the value "Yes" if Sales exceeds 8 and "No" otherwise.
attach(Carseats)
High <- ifelse(Sales <= 8, "No", "Yes")
Carseats <- data.frame(Carseats, High)
head(Carseats)
## Sales CompPrice Income Advertising Population Price ShelveLoc Age
## 1 9.50 138 73 11 276 120 Bad 42
## 2 11.22 111 48 16 260 83 Good 65
## 3 10.06 113 35 10 269 80 Medium 59
## 4 7.40 117 100 4 466 97 Medium 55
## 5 4.15 141 64 3 340 128 Bad 38
## 6 10.81 124 113 13 501 72 Bad 78
## Education Urban US High
## 1 17 Yes Yes Yes
## 2 10 Yes Yes Yes
## 3 12 Yes Yes Yes
## 4 14 Yes Yes No
## 5 13 Yes No No
## 6 16 No Yes Yes
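One caveat for newer versions of R: since R 4.0.0, data.frame() no longer converts character columns to factors by default, so High may remain a character vector and the tree() call below may refuse to fit a classification tree. A small defensive sketch, only needed in that case:
# coerce High to a factor if it was left as a character vector (R >= 4.0.0)
if (!is.factor(Carseats$High)) Carseats$High <- as.factor(Carseats$High)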
We now use the tree() function to fit a classification tree in order to predict High using all the variables but Sales.
tree.carseats <- tree(High ~. -Sales, Carseats)
summary(tree.carseats)
##
## Classification tree:
## tree(formula = High ~ . - Sales, data = Carseats)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Income" "CompPrice" "Population"
## [6] "Advertising" "Age" "US"
## Number of terminal nodes: 27
## Residual mean deviance: 0.458 = 171 / 373
## Misclassification error rate: 0.09 = 36 / 400
We will plot the tree structure.
plot(tree.carseats)
text(tree.carseats, pretty = 0)
The tree looks a little crowded, but the most important indicator of Sales appears to be shelving location, since the first branch separates Good locations from Bad and Medium locations.
If we just type the name of the tree object, R prints output corresponding to each branch of the tree. R displays the split criterion (e.g. Price < 92.5), the number of observations in that branch, the deviance, the overall prediction for the branch (Yes or No), and the fraction of observations in that branch that take on the values Yes and No. Branches that lead to terminal nodes are indicated with asterisks.
tree.carseats
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 400 500 No ( 0.59 0.41 )
## 2) ShelveLoc: Bad,Medium 315 400 No ( 0.69 0.31 )
## 4) Price < 92.5 46 60 Yes ( 0.30 0.70 )
## 8) Income < 57 10 10 No ( 0.70 0.30 )
## 16) CompPrice < 110.5 5 0 No ( 1.00 0.00 ) *
## 17) CompPrice > 110.5 5 7 Yes ( 0.40 0.60 ) *
## 9) Income > 57 36 40 Yes ( 0.19 0.81 )
## 18) Population < 207.5 16 20 Yes ( 0.38 0.62 ) *
## 19) Population > 207.5 20 8 Yes ( 0.05 0.95 ) *
## 5) Price > 92.5 269 300 No ( 0.75 0.25 )
## 10) Advertising < 13.5 224 200 No ( 0.82 0.18 )
## 20) CompPrice < 124.5 96 40 No ( 0.94 0.06 )
## 40) Price < 106.5 38 30 No ( 0.84 0.16 )
## 80) Population < 177 12 20 No ( 0.58 0.42 )
## 160) Income < 60.5 6 0 No ( 1.00 0.00 ) *
## 161) Income > 60.5 6 5 Yes ( 0.17 0.83 ) *
## 81) Population > 177 26 8 No ( 0.96 0.04 ) *
## 41) Price > 106.5 58 0 No ( 1.00 0.00 ) *
## 21) CompPrice > 124.5 128 200 No ( 0.73 0.27 )
## 42) Price < 122.5 51 70 Yes ( 0.49 0.51 )
## 84) ShelveLoc: Bad 11 7 No ( 0.91 0.09 ) *
## 85) ShelveLoc: Medium 40 50 Yes ( 0.38 0.62 )
## 170) Price < 109.5 16 7 Yes ( 0.06 0.94 ) *
## 171) Price > 109.5 24 30 No ( 0.58 0.42 )
## 342) Age < 49.5 13 20 Yes ( 0.31 0.69 ) *
## 343) Age > 49.5 11 7 No ( 0.91 0.09 ) *
## 43) Price > 122.5 77 60 No ( 0.88 0.12 )
## 86) CompPrice < 147.5 58 20 No ( 0.97 0.03 ) *
## 87) CompPrice > 147.5 19 30 No ( 0.63 0.37 )
## 174) Price < 147 12 20 Yes ( 0.42 0.58 )
## 348) CompPrice < 152.5 7 6 Yes ( 0.14 0.86 ) *
## 349) CompPrice > 152.5 5 5 No ( 0.80 0.20 ) *
## 175) Price > 147 7 0 No ( 1.00 0.00 ) *
## 11) Advertising > 13.5 45 60 Yes ( 0.44 0.56 )
## 22) Age < 54.5 25 30 Yes ( 0.20 0.80 )
## 44) CompPrice < 130.5 14 20 Yes ( 0.36 0.64 )
## 88) Income < 100 9 10 No ( 0.56 0.44 ) *
## 89) Income > 100 5 0 Yes ( 0.00 1.00 ) *
## 45) CompPrice > 130.5 11 0 Yes ( 0.00 1.00 ) *
## 23) Age > 54.5 20 20 No ( 0.75 0.25 )
## 46) CompPrice < 122.5 10 0 No ( 1.00 0.00 ) *
## 47) CompPrice > 122.5 10 10 No ( 0.50 0.50 )
## 94) Price < 125 5 0 Yes ( 0.00 1.00 ) *
## 95) Price > 125 5 0 No ( 1.00 0.00 ) *
## 3) ShelveLoc: Good 85 90 Yes ( 0.22 0.78 )
## 6) Price < 135 68 50 Yes ( 0.12 0.88 )
## 12) US: No 17 20 Yes ( 0.35 0.65 )
## 24) Price < 109 8 0 Yes ( 0.00 1.00 ) *
## 25) Price > 109 9 10 No ( 0.67 0.33 ) *
## 13) US: Yes 51 20 Yes ( 0.04 0.96 ) *
## 7) Price > 135 17 20 No ( 0.65 0.35 )
## 14) Income < 46 6 0 No ( 1.00 0.00 ) *
## 15) Income > 46 11 20 Yes ( 0.45 0.55 ) *
In order to properly evaluate the performance of a classification tree on these data, we must estimate the test error rather than simply computing the training error. We split the observations into a training set and a test set, build the tree using the training set, and evaluate its performance on the test data.
set.seed(2)
train <- sample(1:nrow(Carseats), 200)
carseats.test <- Carseats[-train, ]
High.test <- High[-train]
tree.carseats.train <- tree(High ~. -Sales, Carseats, subset = train)
tree.pred <- predict(tree.carseats.train, carseats.test, type = "class")
table(tree.pred, High.test)
## High.test
## tree.pred No Yes
## No 86 27
## Yes 30 57
(86 + 57)/200
## [1] 0.715
This approach leads to correct predictions for 71.5% of the locations in the test data set.
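Equivalently, rather than copying the counts out of the confusion matrix by hand, the test accuracy can be computed directly from the objects above:
# fraction of test observations classified correctly
mean(tree.pred == High.test)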
Next, we consider whether pruning the tree might lead to improved results. The cv.tree() function performs cross-validation to determine the optimal level of tree complexity; the argument FUN = prune.misclass indicates that the classification error rate, rather than the deviance, should guide the cross-validation and pruning process.
set.seed(3)
cv.carseats <- cv.tree(tree.carseats.train, FUN = prune.misclass)
names(cv.carseats)
## [1] "size" "dev" "k" "method"
cv.carseats
## $size
## [1] 19 17 14 13 9 7 3 2 1
##
## $dev
## [1] 55 55 53 52 50 56 69 65 80
##
## $k
## [1] -Inf 0.0000 0.6667 1.0000 1.7500 2.0000 4.2500 5.0000 23.0000
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
We plot the cross-validation error rate as a function of both size and k. (Despite its name, dev here holds the number of cross-validation misclassifications, not the deviance, because we supplied FUN = prune.misclass.)
par(mfrow = c(1, 2))
plot(cv.carseats$size, cv.carseats$dev, type = "b")
plot(cv.carseats$k, cv.carseats$dev, type = "b")
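Rather than reading the best size off the plot, the tree size with the lowest cross-validation error can also be extracted programmatically from cv.carseats:
# size of the subtree with the smallest cross-validated error count
best.size <- cv.carseats$size[which.min(cv.carseats$dev)]
best.size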
We now apply the prune.misclass() function to prune the tree down to the nine-node tree, the size with the lowest cross-validation error above.
prune.carseats <- prune.misclass(tree.carseats.train, best = 9)
plot(prune.carseats)
text(prune.carseats, pretty = 0)
How well does this pruned tree perform on the test data set? Once again, we apply the predict() function.
tree.pred1 <- predict(prune.carseats, carseats.test, type = "class")
table(tree.pred1, High.test)
## High.test
## tree.pred1 No Yes
## No 94 24
## Yes 22 60
(94 + 60)/200
## [1] 0.77
Now 77% of the test observations are correctly classified, so not only has the pruning process produced a more interpretable tree, but it has also improved the classification accuracy.
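As a follow-up, one way to see how sensitive the test accuracy is to the amount of pruning is to evaluate pruned trees of several candidate sizes; a short sketch (output not shown) using the objects defined above:
# test accuracy for a grid of candidate tree sizes
sizes <- c(3, 5, 9, 13, 15)
accuracy <- sapply(sizes, function(s) {
  pruned <- prune.misclass(tree.carseats.train, best = s)
  pred <- predict(pruned, carseats.test, type = "class")
  mean(pred == High.test)
})
data.frame(size = sizes, accuracy = accuracy)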