This is an example exercise taken from the book “An Introduction to Statistical Learning with Applications in R” by Gareth James, Daniela Witten, Trevor Hastie, and Robert Tibshirani.

Here we will fit classification trees. First, we load the required packages.

require(tree)
require(ISLR)
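
If either package is not installed, it can be obtained from CRAN with:

install.packages(c("tree", "ISLR"))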

Next, we load the Carseats dataset and examine its structure.

data(Carseats)
head(Carseats)
##   Sales CompPrice Income Advertising Population Price ShelveLoc Age
## 1  9.50       138     73          11        276   120       Bad  42
## 2 11.22       111     48          16        260    83      Good  65
## 3 10.06       113     35          10        269    80    Medium  59
## 4  7.40       117    100           4        466    97    Medium  55
## 5  4.15       141     64           3        340   128       Bad  38
## 6 10.81       124    113          13        501    72       Bad  78
##   Education Urban  US
## 1        17   Yes Yes
## 2        10   Yes Yes
## 3        12   Yes Yes
## 4        14   Yes Yes
## 5        13   Yes  No
## 6        16    No Yes
str(Carseats)
## 'data.frame':    400 obs. of  11 variables:
##  $ Sales      : num  9.5 11.22 10.06 7.4 4.15 ...
##  $ CompPrice  : num  138 111 113 117 141 124 115 136 132 132 ...
##  $ Income     : num  73 48 35 100 64 113 105 81 110 113 ...
##  $ Advertising: num  11 16 10 4 3 13 0 15 0 0 ...
##  $ Population : num  276 260 269 466 340 501 45 425 108 131 ...
##  $ Price      : num  120 83 80 97 128 72 108 120 124 124 ...
##  $ ShelveLoc  : Factor w/ 3 levels "Bad","Good","Medium": 1 2 3 3 1 1 3 2 3 3 ...
##  $ Age        : num  42 65 59 55 38 78 71 67 76 76 ...
##  $ Education  : num  17 10 12 14 13 16 15 10 10 17 ...
##  $ Urban      : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 1 2 2 1 1 ...
##  $ US         : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 1 2 1 2 ...

Since Sales is a continuous variable, we begin by recoding it as a binary factor, High, which takes the value Yes if Sales exceeds 8 and No otherwise. Encoding High as a factor ensures that tree() fits a classification tree rather than a regression tree.

attach(Carseats)
High <- as.factor(ifelse(Sales <= 8, "No", "Yes"))
Carseats <- data.frame(Carseats, High)
head(Carseats)
##   Sales CompPrice Income Advertising Population Price ShelveLoc Age
## 1  9.50       138     73          11        276   120       Bad  42
## 2 11.22       111     48          16        260    83      Good  65
## 3 10.06       113     35          10        269    80    Medium  59
## 4  7.40       117    100           4        466    97    Medium  55
## 5  4.15       141     64           3        340   128       Bad  38
## 6 10.81       124    113          13        501    72       Bad  78
##   Education Urban  US High
## 1        17   Yes Yes  Yes
## 2        10   Yes Yes  Yes
## 3        12   Yes Yes  Yes
## 4        14   Yes Yes   No
## 5        13   Yes  No   No
## 6        16    No Yes  Yes
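
Before fitting a tree, it can be useful to check how the two classes are balanced; a quick sketch:

table(Carseats$High)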

We now use the tree() function to fit a classification tree in order to predict High using all the variables but Sales.

tree.carseats <- tree(High ~ . - Sales, Carseats)
summary(tree.carseats)
## 
## Classification tree:
## tree(formula = High ~ . - Sales, data = Carseats)
## Variables actually used in tree construction:
## [1] "ShelveLoc"   "Price"       "Income"      "CompPrice"   "Population" 
## [6] "Advertising" "Age"         "US"         
## Number of terminal nodes:  27 
## Residual mean deviance:  0.458 = 171 / 373 
## Misclassification error rate: 0.09 = 36 / 400
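
The summary reports a training misclassification error rate of 9% (36 of 400 observations). As a quick sanity check, this figure can be reproduced directly from the fitted tree; a minimal sketch:

train.pred <- predict(tree.carseats, Carseats, type = "class")  # predicted classes on the training data
mean(train.pred != Carseats$High)                               # training misclassification rate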

We will plot the tree structure.

plot(tree.carseats)
text(tree.carseats, pretty = 0)

(Figure: the unpruned classification tree.)

The tree looks a little crowded, but the most important indicator of Sales appears to be shelving location, since the first branch differentiates Good locations from Bad and Medium locations.
If we just type the name of the tree object, R prints output corresponding to each branch of the tree. R displays the split criterion (e.g. Price < 92.5), the number of observations in that branch, the deviance, the overall prediction for the branch (Yes or No), and the fraction of observations in that branch that take on values of Yes and No. Branches that lead to terminal nodes are indicated using asterisks.

tree.carseats
## node), split, n, deviance, yval, (yprob)
##       * denotes terminal node
## 
##   1) root 400 500 No ( 0.59 0.41 )  
##     2) ShelveLoc: Bad,Medium 315 400 No ( 0.69 0.31 )  
##       4) Price < 92.5 46  60 Yes ( 0.30 0.70 )  
##         8) Income < 57 10  10 No ( 0.70 0.30 )  
##          16) CompPrice < 110.5 5   0 No ( 1.00 0.00 ) *
##          17) CompPrice > 110.5 5   7 Yes ( 0.40 0.60 ) *
##         9) Income > 57 36  40 Yes ( 0.19 0.81 )  
##          18) Population < 207.5 16  20 Yes ( 0.38 0.62 ) *
##          19) Population > 207.5 20   8 Yes ( 0.05 0.95 ) *
##       5) Price > 92.5 269 300 No ( 0.75 0.25 )  
##        10) Advertising < 13.5 224 200 No ( 0.82 0.18 )  
##          20) CompPrice < 124.5 96  40 No ( 0.94 0.06 )  
##            40) Price < 106.5 38  30 No ( 0.84 0.16 )  
##              80) Population < 177 12  20 No ( 0.58 0.42 )  
##               160) Income < 60.5 6   0 No ( 1.00 0.00 ) *
##               161) Income > 60.5 6   5 Yes ( 0.17 0.83 ) *
##              81) Population > 177 26   8 No ( 0.96 0.04 ) *
##            41) Price > 106.5 58   0 No ( 1.00 0.00 ) *
##          21) CompPrice > 124.5 128 200 No ( 0.73 0.27 )  
##            42) Price < 122.5 51  70 Yes ( 0.49 0.51 )  
##              84) ShelveLoc: Bad 11   7 No ( 0.91 0.09 ) *
##              85) ShelveLoc: Medium 40  50 Yes ( 0.38 0.62 )  
##               170) Price < 109.5 16   7 Yes ( 0.06 0.94 ) *
##               171) Price > 109.5 24  30 No ( 0.58 0.42 )  
##                 342) Age < 49.5 13  20 Yes ( 0.31 0.69 ) *
##                 343) Age > 49.5 11   7 No ( 0.91 0.09 ) *
##            43) Price > 122.5 77  60 No ( 0.88 0.12 )  
##              86) CompPrice < 147.5 58  20 No ( 0.97 0.03 ) *
##              87) CompPrice > 147.5 19  30 No ( 0.63 0.37 )  
##               174) Price < 147 12  20 Yes ( 0.42 0.58 )  
##                 348) CompPrice < 152.5 7   6 Yes ( 0.14 0.86 ) *
##                 349) CompPrice > 152.5 5   5 No ( 0.80 0.20 ) *
##               175) Price > 147 7   0 No ( 1.00 0.00 ) *
##        11) Advertising > 13.5 45  60 Yes ( 0.44 0.56 )  
##          22) Age < 54.5 25  30 Yes ( 0.20 0.80 )  
##            44) CompPrice < 130.5 14  20 Yes ( 0.36 0.64 )  
##              88) Income < 100 9  10 No ( 0.56 0.44 ) *
##              89) Income > 100 5   0 Yes ( 0.00 1.00 ) *
##            45) CompPrice > 130.5 11   0 Yes ( 0.00 1.00 ) *
##          23) Age > 54.5 20  20 No ( 0.75 0.25 )  
##            46) CompPrice < 122.5 10   0 No ( 1.00 0.00 ) *
##            47) CompPrice > 122.5 10  10 No ( 0.50 0.50 )  
##              94) Price < 125 5   0 Yes ( 0.00 1.00 ) *
##              95) Price > 125 5   0 No ( 1.00 0.00 ) *
##     3) ShelveLoc: Good 85  90 Yes ( 0.22 0.78 )  
##       6) Price < 135 68  50 Yes ( 0.12 0.88 )  
##        12) US: No 17  20 Yes ( 0.35 0.65 )  
##          24) Price < 109 8   0 Yes ( 0.00 1.00 ) *
##          25) Price > 109 9  10 No ( 0.67 0.33 ) *
##        13) US: Yes 51  20 Yes ( 0.04 0.96 ) *
##       7) Price > 135 17  20 No ( 0.65 0.35 )  
##        14) Income < 46 6   0 No ( 1.00 0.00 ) *
##        15) Income > 46 11  20 Yes ( 0.45 0.55 ) *

In order to properly evaluate the performance of a classification tree on these data, we must estimate the test error rather than simply computing the training error. We split the observations into a training set and a test set, build the tree using the training set, and evaluate its performance on the test data.

set.seed(2)
train <- sample(1:nrow(Carseats), 200)
carseats.test <- Carseats[-train, ]
High.test <- High[-train]
tree.carseats.train <- tree(High ~ . - Sales, Carseats, subset = train)
tree.pred <- predict(tree.carseats.train, carseats.test, type = "class")
table(tree.pred, High.test)
##          High.test
## tree.pred No Yes
##       No  86  27
##       Yes 30  57
(86 + 57)/200
## [1] 0.715

This approach leads to correct predictions for around 71.5% of the locations in the test data set.
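
Equivalently, the test accuracy and error rate can be computed directly from the prediction vector, without adding up cells of the confusion table by hand:

mean(tree.pred == High.test)  # proportion correctly classified
mean(tree.pred != High.test)  # test error rate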
Next, we consider whether pruning the tree might lead to improved results. The cv.tree() function performs cross-validation in order to determine the optimal level of tree complexity; the argument FUN = prune.misclass indicates that we want the classification error rate, rather than the default deviance, to guide the cross-validation and pruning process.

set.seed(3)
cv.carseats <- cv.tree(tree.carseats.train, FUN = prune.misclass)
names(cv.carseats)
## [1] "size"   "dev"    "k"      "method"
cv.carseats
## $size
## [1] 19 17 14 13  9  7  3  2  1
## 
## $dev
## [1] 55 55 53 52 50 56 69 65 80
## 
## $k
## [1]    -Inf  0.0000  0.6667  1.0000  1.7500  2.0000  4.2500  5.0000 23.0000
## 
## $method
## [1] "misclass"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"

Note that, despite its name, dev corresponds here to the cross-validation error (the number of misclassified observations), and k is the cost-complexity tuning parameter. The tree with nine terminal nodes has the lowest cross-validation error, with 50 errors. We plot the error rate as a function of both size and k.

par(mfrow = c(1, 2))
plot(cv.carseats$size, cv.carseats$dev, type = "b")
plot(cv.carseats$k, cv.carseats$dev, type = "b")

(Figure: cross-validation error rate plotted against tree size and against k.)
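
Rather than reading the best size off the plot, we can also pick it out programmatically; a small sketch, which should return 9 for the results above:

best.size <- cv.carseats$size[which.min(cv.carseats$dev)]  # size with the lowest CV error
best.size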

We now apply the prune.misclass() function in order to prune the tree down to the nine-node tree identified by cross-validation.

prune.carseats <- prune.misclass(tree.carseats.train, best = 9)
plot(prune.carseats)
text(prune.carseats, pretty = 0)

(Figure: the pruned nine-node classification tree.)

How well does this pruned tree perform on the test data set? Once again, we apply the predict() function.

tree.pred1 <- predict(prune.carseats, carseats.test, type = "class")
table(tree.pred1, High.test)
##           High.test
## tree.pred1 No Yes
##        No  94  24
##        Yes 22  60
(94 + 60)/200
## [1] 0.77

Now 77% of the test observations are correctly classified, so not only has the pruning process produced a more interpretable tree, but it has also improved the classification accuracy.
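
If we increase the value of best, we obtain a larger pruned tree, and we can compare its test accuracy in exactly the same way. A quick sketch (the object names prune.carseats15 and tree.pred2 are just illustrative):

prune.carseats15 <- prune.misclass(tree.carseats.train, best = 15)          # larger pruned tree
tree.pred2 <- predict(prune.carseats15, carseats.test, type = "class")      # predictions on the test set
mean(tree.pred2 == High.test)                                               # test accuracy

In the textbook's version of this exercise, this larger tree yields slightly lower test accuracy than the nine-node tree.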