ISLR Chapter 8: Tree-Based Methods

Problem 3.

Consider the Gini index, classification error, and entropy in a simple classification setting with two classes. Create a single plot that displays each of these quantities as a function of \(\hat{p}_{m1}\). The x-axis should display \(\hat{p}_{m1}\), ranging from 0 to 1, and the y-axis should display the value of the Gini index, classification error, and entropy.
Hint: In a setting with two classes, \(\hat{p}_{m1} = 1 − \hat{p}_{m2}\). You could make this plot by hand, but it will be much easier to make in R.

p <- seq(0, 1, 0.01)
class.err <- 1 - pmax(p, 1 - p)
gini.index <- 2 * p * (1 - p)
entropy <- -(p * log(p) + (1 - p) * log(1 - p))  # NaN at p = 0 and p = 1; matplot simply drops those two points
matplot(p, cbind(class.err, gini.index, entropy), type = "l", lty = 1,
        col = c("red", "blue", "green"), xlab = expression(hat(p)[m1]),
        ylab = "Value", main = "Classification Tree Measures")
legend("bottom", lty = 1, title = "Measures", col = c("red", "blue", "green"),
       legend = c("Classification Error", "Gini Index", "Entropy"), box.lty = 1)

Problem 8.

In the lab, a classification tree was applied to the Carseats data set after converting Sales into a qualitative response variable. Now we will seek to predict Sales using regression trees and related approaches, treating the response as a quantitative variable.

library(ISLR)
attach(Carseats)
summary(Carseats)
##      Sales          CompPrice       Income        Advertising    
##  Min.   : 0.000   Min.   : 77   Min.   : 21.00   Min.   : 0.000  
##  1st Qu.: 5.390   1st Qu.:115   1st Qu.: 42.75   1st Qu.: 0.000  
##  Median : 7.490   Median :125   Median : 69.00   Median : 5.000  
##  Mean   : 7.496   Mean   :125   Mean   : 68.66   Mean   : 6.635  
##  3rd Qu.: 9.320   3rd Qu.:135   3rd Qu.: 91.00   3rd Qu.:12.000  
##  Max.   :16.270   Max.   :175   Max.   :120.00   Max.   :29.000  
##    Population        Price        ShelveLoc        Age          Education   
##  Min.   : 10.0   Min.   : 24.0   Bad   : 96   Min.   :25.00   Min.   :10.0  
##  1st Qu.:139.0   1st Qu.:100.0   Good  : 85   1st Qu.:39.75   1st Qu.:12.0  
##  Median :272.0   Median :117.0   Medium:219   Median :54.50   Median :14.0  
##  Mean   :264.8   Mean   :115.8                Mean   :53.32   Mean   :13.9  
##  3rd Qu.:398.5   3rd Qu.:131.0                3rd Qu.:66.00   3rd Qu.:16.0  
##  Max.   :509.0   Max.   :191.0                Max.   :80.00   Max.   :18.0  
##  Urban       US     
##  No :118   No :142  
##  Yes:282   Yes:258  
##                     
##                     
##                     
## 
str(Carseats)
## 'data.frame':    400 obs. of  11 variables:
##  $ Sales      : num  9.5 11.22 10.06 7.4 4.15 ...
##  $ CompPrice  : num  138 111 113 117 141 124 115 136 132 132 ...
##  $ Income     : num  73 48 35 100 64 113 105 81 110 113 ...
##  $ Advertising: num  11 16 10 4 3 13 0 15 0 0 ...
##  $ Population : num  276 260 269 466 340 501 45 425 108 131 ...
##  $ Price      : num  120 83 80 97 128 72 108 120 124 124 ...
##  $ ShelveLoc  : Factor w/ 3 levels "Bad","Good","Medium": 1 2 3 3 1 1 3 2 3 3 ...
##  $ Age        : num  42 65 59 55 38 78 71 67 76 76 ...
##  $ Education  : num  17 10 12 14 13 16 15 10 10 17 ...
##  $ Urban      : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 1 2 2 1 1 ...
##  $ US         : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 1 2 1 2 ...

(a) Split the data set into a training set and a test set.

set.seed(1)

train <- sample(dim(Carseats)[1], dim(Carseats)[1]/2)
Carseats.train <- Carseats[train, ]
Carseats.test <- Carseats[-train, ]
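A quick sanity check that the 400 observations were split evenly:

dim(Carseats.train)  # 200 rows, 11 columns
dim(Carseats.test)   # 200 rows, 11 columns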

(b) Fit a regression tree to the training set. Plot the tree, and interpret the results. What test MSE do you obtain?

library(tree)
carseats.tree <- tree(Sales ~ ., data = Carseats.train)
summary(carseats.tree)
## 
## Regression tree:
## tree(formula = Sales ~ ., data = Carseats.train)
## Variables actually used in tree construction:
## [1] "ShelveLoc"   "Price"       "Age"         "Advertising" "CompPrice"  
## [6] "US"         
## Number of terminal nodes:  18 
## Residual mean deviance:  2.167 = 394.3 / 182 
## Distribution of residuals:
##     Min.  1st Qu.   Median     Mean  3rd Qu.     Max. 
## -3.88200 -0.88200 -0.08712  0.00000  0.89590  4.09900
plot(carseats.tree)
text(carseats.tree, pretty = 0, cex = 0.65)

carseats.pred <- predict(carseats.tree, Carseats.test)
mean((carseats.pred - Carseats.test$Sales)^2)
## [1] 4.922039

Based on the regression tree model, the test MSE is 4.922039.
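Since Sales is recorded in thousands of units, the square root of the MSE is easier to interpret:

sqrt(mean((carseats.pred - Carseats.test$Sales)^2))  # RMSE of about 2.22, i.e. predictions off by roughly 2,220 units on average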

(c) Use cross-validation in order to determine the optimal level of tree complexity. Does pruning the tree improve the test MSE?

carseats.cv <- cv.tree(carseats.tree)
carseats.cv
## $size
##  [1] 18 17 16 15 14 13 12 11 10  8  7  6  5  4  3  2  1
## 
## $dev
##  [1]  831.3437  852.3639  868.6815  862.3400  862.3400  893.4641  911.2580
##  [8]  950.2691  955.2535 1039.1241 1066.6899 1125.0894 1205.5806 1273.2889
## [15] 1302.8607 1349.9273 1620.4687
## 
## $k
##  [1]      -Inf  16.99544  20.56322  25.01730  25.57104  28.01938  30.36962
##  [8]  31.56747  31.80816  40.75445  44.44673  52.57126  76.21881  99.59459
## [15] 116.69889 159.79501 337.60153
## 
## $method
## [1] "deviance"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"
plot(carseats.cv$size, carseats.cv$dev, type = "b", xlab = "Tree Size", ylab = "CV Deviance")

carseats.prune <- prune.tree(carseats.tree, best = 14)
plot(carseats.prune)
text(carseats.prune, pretty = 0, cex = .7)

pred.prune <- predict(carseats.prune, Carseats.test)
mean((Carseats.test$Sales - pred.prune)^2)
## [1] 5.013738

Note that the cross-validation deviance is actually minimized at the full 18-node tree, so cross-validation does not favor pruning here. Consistent with that, pruning to 14 terminal nodes increases the test MSE slightly, to 5.013738.
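To see how the pruning level affects test performance more generally, here is a quick sketch that evaluates the test MSE at each subtree size in the cost-complexity sequence (skipping the one-node stump, which just predicts the training mean). Bear in mind that using the test set to pick the size would be an optimistic selection procedure; the sketch is only to visualize the effect of pruning.

sizes <- sort(carseats.cv$size[carseats.cv$size > 1])
prune.mse <- sapply(sizes, function(s) {
  pruned <- prune.tree(carseats.tree, best = s)
  mean((Carseats.test$Sales - predict(pruned, Carseats.test))^2)
})
plot(sizes, prune.mse, type = "b", xlab = "Terminal Nodes", ylab = "Test MSE")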

(d) Use the bagging approach in order to analyze this data. What test MSE do you obtain? Use the importance() function to determine which variables are most important.

library(randomForest)
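# Setting mtry = 10 (all predictors considered at every split) makes this random forest equivalent to bagging.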
carseats.bag <- randomForest(Sales ~ ., data = Carseats.train, mtry = 10, ntree = 500, importance = T)
carseats.bag
## 
## Call:
##  randomForest(formula = Sales ~ ., data = Carseats.train, mtry = 10,      ntree = 500, importance = T) 
##                Type of random forest: regression
##                      Number of trees: 500
## No. of variables tried at each split: 10
## 
##           Mean of squared residuals: 2.931324
##                     % Var explained: 62.72
pred.bag <- predict(carseats.bag, Carseats.test)
mse.bag <- mean((Carseats.test$Sales - pred.bag)^2)
mse.bag
## [1] 2.657296

With bagging, the test MSE drops to 2.657296, roughly half that of the single regression tree.

importance(carseats.bag)
##                 %IncMSE IncNodePurity
## CompPrice   23.07909904    171.185734
## Income       2.82081527     94.079825
## Advertising 11.43295625     99.098941
## Population  -3.92119532     59.818905
## Price       54.24314632    505.887016
## ShelveLoc   46.26912996    361.962753
## Age         14.24992212    159.740422
## Education   -0.07662320     46.738585
## Urban        0.08530119      8.453749
## US           4.34349223     15.157608

Price and ShelveLoc are by far the most important variables, followed by CompPrice.
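The randomForest package also provides varImpPlot() for a visual ranking of the same two importance measures:

varImpPlot(carseats.bag, main = "Bagging: Variable Importance")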

(e) Use random forests to analyze this data. What test MSE do you obtain? Use the importance() function to determine which variables are most important. Describe the effect of m, the number of variables considered at each split, on the error rate obtained.

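# Here m = 7 of the 10 predictors are considered at each split; randomForest's default for regression would be p/3, about 3.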
carseats.rf <- randomForest(Sales ~ ., data = Carseats.train, mtry = 7, ntree = 500, importance = T)
carseats.rf
## 
## Call:
##  randomForest(formula = Sales ~ ., data = Carseats.train, mtry = 7,      ntree = 500, importance = T) 
##                Type of random forest: regression
##                      Number of trees: 500
## No. of variables tried at each split: 7
## 
##           Mean of squared residuals: 2.886721
##                     % Var explained: 63.29
pred.rf <- predict(carseats.rf, Carseats.test)
mse.rf <- mean((Carseats.test$Sales - pred.rf)^2)
mse.rf
## [1] 2.624779

With the random forest (m = 7), the test MSE decreases slightly further, to 2.624779.

importance(carseats.rf)
##                %IncMSE IncNodePurity
## CompPrice   22.7283830     166.31362
## Income       4.5009562      98.97324
## Advertising 12.8607660     103.27827
## Population   0.7406765      68.03609
## Price       51.5001132     479.69045
## ShelveLoc   44.1661519     368.90657
## Age         15.6557207     160.73614
## Education   -0.6632689      47.46421
## Urban        1.2130655       9.89366
## US           5.1049489      21.26560

Price and ShelveLoc remain the most important variables, followed by CompPrice.
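As for the effect of m: smaller values decorrelate the trees, which lowers the variance of the averaged prediction at the cost of some bias, so the test error is often minimized at an intermediate m rather than at m = 10 (bagging). A hedged sketch for exploring this on the current split (results will vary with the seed; mse.by.m is just an illustrative name):

set.seed(1)
mse.by.m <- sapply(1:10, function(m) {
  fit <- randomForest(Sales ~ ., data = Carseats.train, mtry = m, ntree = 500)
  mean((Carseats.test$Sales - predict(fit, Carseats.test))^2)
})
plot(1:10, mse.by.m, type = "b", xlab = "m (mtry)", ylab = "Test MSE")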

detach(Carseats)

Problem 9.

This problem involves the OJ data set which is part of the ISLR package.

attach(OJ)
summary(OJ)
##  Purchase WeekofPurchase     StoreID        PriceCH         PriceMM     
##  CH:653   Min.   :227.0   Min.   :1.00   Min.   :1.690   Min.   :1.690  
##  MM:417   1st Qu.:240.0   1st Qu.:2.00   1st Qu.:1.790   1st Qu.:1.990  
##           Median :257.0   Median :3.00   Median :1.860   Median :2.090  
##           Mean   :254.4   Mean   :3.96   Mean   :1.867   Mean   :2.085  
##           3rd Qu.:268.0   3rd Qu.:7.00   3rd Qu.:1.990   3rd Qu.:2.180  
##           Max.   :278.0   Max.   :7.00   Max.   :2.090   Max.   :2.290  
##      DiscCH            DiscMM         SpecialCH        SpecialMM     
##  Min.   :0.00000   Min.   :0.0000   Min.   :0.0000   Min.   :0.0000  
##  1st Qu.:0.00000   1st Qu.:0.0000   1st Qu.:0.0000   1st Qu.:0.0000  
##  Median :0.00000   Median :0.0000   Median :0.0000   Median :0.0000  
##  Mean   :0.05186   Mean   :0.1234   Mean   :0.1477   Mean   :0.1617  
##  3rd Qu.:0.00000   3rd Qu.:0.2300   3rd Qu.:0.0000   3rd Qu.:0.0000  
##  Max.   :0.50000   Max.   :0.8000   Max.   :1.0000   Max.   :1.0000  
##     LoyalCH          SalePriceMM     SalePriceCH      PriceDiff       Store7   
##  Min.   :0.000011   Min.   :1.190   Min.   :1.390   Min.   :-0.6700   No :714  
##  1st Qu.:0.325257   1st Qu.:1.690   1st Qu.:1.750   1st Qu.: 0.0000   Yes:356  
##  Median :0.600000   Median :2.090   Median :1.860   Median : 0.2300            
##  Mean   :0.565782   Mean   :1.962   Mean   :1.816   Mean   : 0.1465            
##  3rd Qu.:0.850873   3rd Qu.:2.130   3rd Qu.:1.890   3rd Qu.: 0.3200            
##  Max.   :0.999947   Max.   :2.290   Max.   :2.090   Max.   : 0.6400            
##    PctDiscMM        PctDiscCH       ListPriceDiff       STORE      
##  Min.   :0.0000   Min.   :0.00000   Min.   :0.000   Min.   :0.000  
##  1st Qu.:0.0000   1st Qu.:0.00000   1st Qu.:0.140   1st Qu.:0.000  
##  Median :0.0000   Median :0.00000   Median :0.240   Median :2.000  
##  Mean   :0.0593   Mean   :0.02731   Mean   :0.218   Mean   :1.631  
##  3rd Qu.:0.1127   3rd Qu.:0.00000   3rd Qu.:0.300   3rd Qu.:3.000  
##  Max.   :0.4020   Max.   :0.25269   Max.   :0.440   Max.   :4.000
str(OJ)
## 'data.frame':    1070 obs. of  18 variables:
##  $ Purchase      : Factor w/ 2 levels "CH","MM": 1 1 1 2 1 1 1 1 1 1 ...
##  $ WeekofPurchase: num  237 239 245 227 228 230 232 234 235 238 ...
##  $ StoreID       : num  1 1 1 1 7 7 7 7 7 7 ...
##  $ PriceCH       : num  1.75 1.75 1.86 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
##  $ PriceMM       : num  1.99 1.99 2.09 1.69 1.69 1.99 1.99 1.99 1.99 1.99 ...
##  $ DiscCH        : num  0 0 0.17 0 0 0 0 0 0 0 ...
##  $ DiscMM        : num  0 0.3 0 0 0 0 0.4 0.4 0.4 0.4 ...
##  $ SpecialCH     : num  0 0 0 0 0 0 1 1 0 0 ...
##  $ SpecialMM     : num  0 1 0 0 0 1 1 0 0 0 ...
##  $ LoyalCH       : num  0.5 0.6 0.68 0.4 0.957 ...
##  $ SalePriceMM   : num  1.99 1.69 2.09 1.69 1.69 1.99 1.59 1.59 1.59 1.59 ...
##  $ SalePriceCH   : num  1.75 1.75 1.69 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
##  $ PriceDiff     : num  0.24 -0.06 0.4 0 0 0.3 -0.1 -0.16 -0.16 -0.16 ...
##  $ Store7        : Factor w/ 2 levels "No","Yes": 1 1 1 1 2 2 2 2 2 2 ...
##  $ PctDiscMM     : num  0 0.151 0 0 0 ...
##  $ PctDiscCH     : num  0 0 0.0914 0 0 ...
##  $ ListPriceDiff : num  0.24 0.24 0.23 0 0 0.3 0.3 0.24 0.24 0.24 ...
##  $ STORE         : num  1 1 1 1 0 0 0 0 0 0 ...

(a) Create a training set containing a random sample of 800 observations, and a test set containing the remaining observations.

set.seed(1)

train <- sample(dim(OJ)[1], 800)
oj.train <- OJ[train, ]
oj.test <- OJ[-train, ]

(b) Fit a tree to the training data, with Purchase as the response and the other variables as predictors. Use the summary() function to produce summary statistics about the tree, and describe the results obtained. What is the training error rate? How many terminal nodes does the tree have?

tree.oj <- tree(Purchase ~ ., data = oj.train)
summary(tree.oj)
## 
## Classification tree:
## tree(formula = Purchase ~ ., data = oj.train)
## Variables actually used in tree construction:
## [1] "LoyalCH"       "PriceDiff"     "SpecialCH"     "ListPriceDiff"
## [5] "PctDiscMM"    
## Number of terminal nodes:  9 
## Residual mean deviance:  0.7432 = 587.8 / 791 
## Misclassification error rate: 0.1588 = 127 / 800

The fitted classification tree uses only five of the predictors: LoyalCH, PriceDiff, SpecialCH, ListPriceDiff, and PctDiscMM. It has 9 terminal nodes, and the training (misclassification) error rate is 0.1588.

(c) Type in the name of the tree object in order to get a detailed text output. Pick one of the terminal nodes, and interpret the information displayed.

tree.oj
## node), split, n, deviance, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 800 1073.00 CH ( 0.60625 0.39375 )  
##    2) LoyalCH < 0.5036 365  441.60 MM ( 0.29315 0.70685 )  
##      4) LoyalCH < 0.280875 177  140.50 MM ( 0.13559 0.86441 )  
##        8) LoyalCH < 0.0356415 59   10.14 MM ( 0.01695 0.98305 ) *
##        9) LoyalCH > 0.0356415 118  116.40 MM ( 0.19492 0.80508 ) *
##      5) LoyalCH > 0.280875 188  258.00 MM ( 0.44149 0.55851 )  
##       10) PriceDiff < 0.05 79   84.79 MM ( 0.22785 0.77215 )  
##         20) SpecialCH < 0.5 64   51.98 MM ( 0.14062 0.85938 ) *
##         21) SpecialCH > 0.5 15   20.19 CH ( 0.60000 0.40000 ) *
##       11) PriceDiff > 0.05 109  147.00 CH ( 0.59633 0.40367 ) *
##    3) LoyalCH > 0.5036 435  337.90 CH ( 0.86897 0.13103 )  
##      6) LoyalCH < 0.764572 174  201.00 CH ( 0.73563 0.26437 )  
##       12) ListPriceDiff < 0.235 72   99.81 MM ( 0.50000 0.50000 )  
##         24) PctDiscMM < 0.196197 55   73.14 CH ( 0.61818 0.38182 ) *
##         25) PctDiscMM > 0.196197 17   12.32 MM ( 0.11765 0.88235 ) *
##       13) ListPriceDiff > 0.235 102   65.43 CH ( 0.90196 0.09804 ) *
##      7) LoyalCH > 0.764572 261   91.20 CH ( 0.95785 0.04215 ) *

Consider terminal node 8 (the asterisk marks it as terminal). It contains the 59 training observations with LoyalCH < 0.0356415, has a deviance of 10.14, and predicts Purchase = MM; roughly 98.3% of the observations in this node are MM, so it is a very pure node.

(d) Create a plot of the tree, and interpret the results.

plot(tree.oj)
text(tree.oj, pretty = 0, cex = .7)

Based on the plot, the most important variable is LoyalCH. If LoyalCH is below 0.28 the tree predicts MM, and above 0.76 it predicts CH; for values in between, the prediction also depends on PriceDiff, SpecialCH, ListPriceDiff, and PctDiscMM.

(e) Predict the response on the test data, and produce a confusion matrix comparing the test labels to the predicted test labels. What is the test error rate?

pred.oj <- predict(tree.oj, oj.test, type = 'class')
table(oj.test$Purchase, pred.oj)
##     pred.oj
##       CH  MM
##   CH 160   8
##   MM  38  64

Based on the confusion matrix, 8 + 38 = 46 of the 270 test observations are misclassified, giving a test error rate of about 0.1704.
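The same rate can be computed directly from the predictions:

mean(pred.oj != oj.test$Purchase)
## [1] 0.1703704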

(f) Apply the cv.tree() function to the training set in order to determine the optimal tree size.

cv.oj <- cv.tree(tree.oj, FUN = prune.misclass)
cv.oj
## $size
## [1] 9 8 7 4 2 1
## 
## $dev
## [1] 150 150 149 158 172 315
## 
## $k
## [1]       -Inf   0.000000   3.000000   4.333333  10.500000 151.000000
## 
## $method
## [1] "misclass"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"

(g) Produce a plot with tree size on the x-axis and cross-validated classification error rate on the y-axis.

plot(cv.oj$size, cv.oj$dev, type = 'b', xlab = "Tree Size", ylab = "Cross-Validated Error Rate")

(h) Which tree size corresponds to the lowest cross-validated classification error rate?

Based on the plot, a tree with 7 terminal nodes has the lowest cross-validated classification error rate.
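This can also be read off programmatically from the cv.tree output:

cv.oj$size[which.min(cv.oj$dev)]
## [1] 7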

(i) Produce a pruned tree corresponding to the optimal tree size obtained using cross-validation. If cross-validation does not lead to selection of a pruned tree, then create a pruned tree with five terminal nodes.

oj.prune <- prune.misclass(tree.oj, best = 7)
summary(oj.prune)
## 
## Classification tree:
## snip.tree(tree = tree.oj, nodes = c(4L, 10L))
## Variables actually used in tree construction:
## [1] "LoyalCH"       "PriceDiff"     "ListPriceDiff" "PctDiscMM"    
## Number of terminal nodes:  7 
## Residual mean deviance:  0.7748 = 614.4 / 793 
## Misclassification error rate: 0.1625 = 130 / 800
plot(oj.prune)
text(oj.prune, pretty = 0, cex = .7)

(j) Compare the training error rates between the pruned and unpruned trees. Which is higher?

summary(tree.oj)
## 
## Classification tree:
## tree(formula = Purchase ~ ., data = oj.train)
## Variables actually used in tree construction:
## [1] "LoyalCH"       "PriceDiff"     "SpecialCH"     "ListPriceDiff"
## [5] "PctDiscMM"    
## Number of terminal nodes:  9 
## Residual mean deviance:  0.7432 = 587.8 / 791 
## Misclassification error rate: 0.1588 = 127 / 800
summary(oj.prune)
## 
## Classification tree:
## snip.tree(tree = tree.oj, nodes = c(4L, 10L))
## Variables actually used in tree construction:
## [1] "LoyalCH"       "PriceDiff"     "ListPriceDiff" "PctDiscMM"    
## Number of terminal nodes:  7 
## Residual mean deviance:  0.7748 = 614.4 / 793 
## Misclassification error rate: 0.1625 = 130 / 800

The pruned tree has the higher training error rate (0.1625 vs. 0.1588). This is expected: the unpruned tree, with 9 terminal nodes and 5 variables, fits the training data more closely.

(k) Compare the test error rates between the pruned and unpruned trees. Which is higher?

oj.prune.pred <- predict(oj.prune, newdata = oj.test, type = "class")
table(oj.prune.pred, oj.test$Purchase)
##              
## oj.prune.pred  CH  MM
##            CH 160  36
##            MM   8  66

Based on the table, the pruned tree misclassifies 36 + 8 = 44 of the 270 test observations, a test error rate of about 0.163, slightly lower than the unpruned tree's 0.1704; the unpruned tree's test error rate is the higher of the two.
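As a check, the same rate computed directly from the predictions:

mean(oj.prune.pred != oj.test$Purchase)
## [1] 0.162963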

detach(OJ)