Consider the Gini index, classification error, and entropy in a simple classification setting with two classes. Create a single plot that displays each of these quantities as a function of \(\hat{p}_{m1}\). The x-axis should display \(\hat{p}_{m1}\), ranging from 0 to 1, and the y-axis should display the value of the Gini index, classification error, and entropy.
Hint: In a setting with two classes, \(\hat{p}_{m1} = 1 - \hat{p}_{m2}\). You could make this plot by hand, but it will be much easier to make in R.
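In the two-class case, writing \(p = \hat{p}_{m1}\), the three measures reduce to the classification error \(E = 1 - \max(p, 1 - p)\), the Gini index \(G = 2p(1 - p)\), and the entropy \(D = -p\log p - (1 - p)\log(1 - p)\); the code below simply evaluates these expressions over a grid of \(p\) values.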
p <- seq(0, 1, 0.01)
class.err <- 1 - pmax(p, 1 - p)
gini.index <- 2 * p * (1 - p)
# Entropy is NaN at p = 0 and p = 1 (0 * log(0)); matplot simply drops those points
entropy <- -(p * log(p) + (1 - p) * log(1 - p))
matplot(p, cbind(class.err, gini.index, entropy), type = "l", lty = 1,
        col = c("red", "blue", "green"), xlab = expression(hat(p)[m1]),
        ylab = "Value", main = "Classification Tree Measures")
legend("bottom", lty = 1, title = "Measures", col = c("red", "blue", "green"),
       legend = c("Classification Error", "Gini Index", "Entropy"), box.lty = 1)
In the lab, a classification tree was applied to the Carseats data set after converting Sales into a qualitative response variable. Now we will seek to predict Sales using regression trees and related approaches, treating the response as a quantitative variable.
library(ISLR)
attach(Carseats)
summary(Carseats)
## Sales CompPrice Income Advertising
## Min. : 0.000 Min. : 77 Min. : 21.00 Min. : 0.000
## 1st Qu.: 5.390 1st Qu.:115 1st Qu.: 42.75 1st Qu.: 0.000
## Median : 7.490 Median :125 Median : 69.00 Median : 5.000
## Mean : 7.496 Mean :125 Mean : 68.66 Mean : 6.635
## 3rd Qu.: 9.320 3rd Qu.:135 3rd Qu.: 91.00 3rd Qu.:12.000
## Max. :16.270 Max. :175 Max. :120.00 Max. :29.000
## Population Price ShelveLoc Age Education
## Min. : 10.0 Min. : 24.0 Bad : 96 Min. :25.00 Min. :10.0
## 1st Qu.:139.0 1st Qu.:100.0 Good : 85 1st Qu.:39.75 1st Qu.:12.0
## Median :272.0 Median :117.0 Medium:219 Median :54.50 Median :14.0
## Mean :264.8 Mean :115.8 Mean :53.32 Mean :13.9
## 3rd Qu.:398.5 3rd Qu.:131.0 3rd Qu.:66.00 3rd Qu.:16.0
## Max. :509.0 Max. :191.0 Max. :80.00 Max. :18.0
## Urban US
## No :118 No :142
## Yes:282 Yes:258
##
##
##
##
str(Carseats)
## 'data.frame': 400 obs. of 11 variables:
## $ Sales : num 9.5 11.22 10.06 7.4 4.15 ...
## $ CompPrice : num 138 111 113 117 141 124 115 136 132 132 ...
## $ Income : num 73 48 35 100 64 113 105 81 110 113 ...
## $ Advertising: num 11 16 10 4 3 13 0 15 0 0 ...
## $ Population : num 276 260 269 466 340 501 45 425 108 131 ...
## $ Price : num 120 83 80 97 128 72 108 120 124 124 ...
## $ ShelveLoc : Factor w/ 3 levels "Bad","Good","Medium": 1 2 3 3 1 1 3 2 3 3 ...
## $ Age : num 42 65 59 55 38 78 71 67 76 76 ...
## $ Education : num 17 10 12 14 13 16 15 10 10 17 ...
## $ Urban : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 1 2 2 1 1 ...
## $ US : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 1 2 1 2 ...
(a) Split the data set into a training set and a test set.
set.seed(1)
train <- sample(dim(Carseats)[1], dim(Carseats)[1]/2)
Carseats.train <- Carseats[train, ]
Carseats.test <- Carseats[-train, ]
(b) Fit a regression tree to the training set. Plot the tree, and interpret the results. What test MSE do you obtain?
library(tree)
carseats.tree <- tree(Sales ~ ., data = Carseats.train)
summary(carseats.tree)
##
## Regression tree:
## tree(formula = Sales ~ ., data = Carseats.train)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Age" "Advertising" "CompPrice"
## [6] "US"
## Number of terminal nodes: 18
## Residual mean deviance: 2.167 = 394.3 / 182
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -3.88200 -0.88200 -0.08712 0.00000 0.89590 4.09900
plot(carseats.tree)
text(carseats.tree, pretty = 0, cex =.65)
carseats.pred <- predict(carseats.tree, Carseats.test)
mean((carseats.pred - Carseats.test$Sales)^2)
## [1] 4.922039
Based on the regression tree model, the test MSE is 4.922039.
(c) Use cross-validation in order to determine the optimal level of tree complexity. Does pruning the tree improve the test MSE?
carseats.cv <- cv.tree(carseats.tree)
carseats.cv
## $size
## [1] 18 17 16 15 14 13 12 11 10 8 7 6 5 4 3 2 1
##
## $dev
## [1] 831.3437 852.3639 868.6815 862.3400 862.3400 893.4641 911.2580
## [8] 950.2691 955.2535 1039.1241 1066.6899 1125.0894 1205.5806 1273.2889
## [15] 1302.8607 1349.9273 1620.4687
##
## $k
## [1] -Inf 16.99544 20.56322 25.01730 25.57104 28.01938 30.36962
## [8] 31.56747 31.80816 40.75445 44.44673 52.57126 76.21881 99.59459
## [15] 116.69889 159.79501 337.60153
##
## $method
## [1] "deviance"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
plot(carseats.cv$size, carseats.cv$dev, type = "b", xlab = "Tree Size", ylab = "CV Deviance")
carseats.prune <- prune.tree(carseats.tree, best = 14)
plot(carseats.prune)
text(carseats.prune, pretty = 0, cex = .7)
pred.prune <- predict(carseats.prune, Carseats.test)
mean((Carseats.test$Sales - pred.prune)^2)
## [1] 5.013738
Pruning to 14 terminal nodes (chosen from the plot, where the deviance curve begins to flatten) increases the test MSE slightly to 5.014, so pruning does not improve the test MSE here; indeed, cross-validation attains its lowest deviance at the full 18-node tree.
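For completeness, the size minimizing the cross-validated deviance can be pulled out of carseats.cv programmatically; a short sketch using the objects above:
# Subtree size with the smallest cross-validated deviance
# (per the $dev values printed above, this is the full 18-node tree)
best.size <- carseats.cv$size[which.min(carseats.cv$dev)]
best.size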
(d) Use the bagging approach in order to analyze this data. What test MSE do you obtain? Use the importance() function to determine which variables are most important.
library(randomForest)
carseats.bag <- randomForest(Sales ~ ., data = Carseats.train, mtry = 10, ntree = 500, importance = T)
carseats.bag
##
## Call:
## randomForest(formula = Sales ~ ., data = Carseats.train, mtry = 10, ntree = 500, importance = T)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 10
##
## Mean of squared residuals: 2.931324
## % Var explained: 62.72
pred.bag <- predict(carseats.bag, Carseats.test)
mse.bag <- mean((Carseats.test$Sales - pred.bag)^2)
mse.bag
## [1] 2.657296
With bagging (all m = p = 10 predictors considered at every split), the test MSE drops to 2.657, roughly half the error of the pruned single tree.
importance(carseats.bag)
## %IncMSE IncNodePurity
## CompPrice 23.07909904 171.185734
## Income 2.82081527 94.079825
## Advertising 11.43295625 99.098941
## Population -3.92119532 59.818905
## Price 54.24314632 505.887016
## ShelveLoc 46.26912996 361.962753
## Age 14.24992212 159.740422
## Education -0.07662320 46.738585
## Urban 0.08530119 8.453749
## US 4.34349223 15.157608
Price, ShelveLoc, and CompPrice stand out as the most important variables by both %IncMSE and IncNodePurity.
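The same ranking can be visualized with varImpPlot() from the randomForest package; a small optional addition using the fitted bagging model above:
# Dot charts of %IncMSE and IncNodePurity for the bagged trees
varImpPlot(carseats.bag, main = "Variable Importance (Bagging)")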
(e) Use random forests to analyze this data. What test MSE do you obtain? Use the importance() function to determine which variables are most important. Describe the effect of m, the number of variables considered at each split, on the error rate obtained.
carseats.rf <- randomForest(Sales ~ ., data = Carseats.train, mtry = 7, ntree = 500, importance = T)
carseats.rf
##
## Call:
## randomForest(formula = Sales ~ ., data = Carseats.train, mtry = 7, ntree = 500, importance = T)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 7
##
## Mean of squared residuals: 2.886721
## % Var explained: 63.29
pred.rf <- predict(carseats.rf, Carseats.test)
mse.rf <- mean((Carseats.test$Sales - pred.rf)^2)
mse.rf
## [1] 2.624779
With a random forest using m = 7 variables at each split, the test MSE decreases slightly further to 2.625. Considering m < p variables at each split decorrelates the individual trees, and here moving from m = 10 (bagging) to m = 7 gives a modest improvement; a sketch exploring the full range of m follows the importance table below.
importance(carseats.rf)
## %IncMSE IncNodePurity
## CompPrice 22.7283830 166.31362
## Income 4.5009562 98.97324
## Advertising 12.8607660 103.27827
## Population 0.7406765 68.03609
## Price 51.5001132 479.69045
## ShelveLoc 44.1661519 368.90657
## Age 15.6557207 160.73614
## Education -0.6632689 47.46421
## Urban 1.2130655 9.89366
## US 5.1049489 21.26560
Price, ShelveLoc, and CompPrice remain the most important variables.
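To see the effect of m directly, one could refit the forest over a grid of mtry values and track the test MSE; a minimal sketch reusing the training and test sets above (the grid of m values and ntree = 500 are arbitrary choices, not part of the original analysis):
# Test MSE as a function of m, the number of variables tried at each split
mtry.values <- 1:10
mse.by.m <- sapply(mtry.values, function(m) {
  fit <- randomForest(Sales ~ ., data = Carseats.train, mtry = m, ntree = 500)
  mean((Carseats.test$Sales - predict(fit, Carseats.test))^2)
})
plot(mtry.values, mse.by.m, type = "b", xlab = "m (mtry)", ylab = "Test MSE")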
detach(Carseats)
This problem involves the OJ data set which is part of the ISLR package.
attach(OJ)
summary(OJ)
## Purchase WeekofPurchase StoreID PriceCH PriceMM
## CH:653 Min. :227.0 Min. :1.00 Min. :1.690 Min. :1.690
## MM:417 1st Qu.:240.0 1st Qu.:2.00 1st Qu.:1.790 1st Qu.:1.990
## Median :257.0 Median :3.00 Median :1.860 Median :2.090
## Mean :254.4 Mean :3.96 Mean :1.867 Mean :2.085
## 3rd Qu.:268.0 3rd Qu.:7.00 3rd Qu.:1.990 3rd Qu.:2.180
## Max. :278.0 Max. :7.00 Max. :2.090 Max. :2.290
## DiscCH DiscMM SpecialCH SpecialMM
## Min. :0.00000 Min. :0.0000 Min. :0.0000 Min. :0.0000
## 1st Qu.:0.00000 1st Qu.:0.0000 1st Qu.:0.0000 1st Qu.:0.0000
## Median :0.00000 Median :0.0000 Median :0.0000 Median :0.0000
## Mean :0.05186 Mean :0.1234 Mean :0.1477 Mean :0.1617
## 3rd Qu.:0.00000 3rd Qu.:0.2300 3rd Qu.:0.0000 3rd Qu.:0.0000
## Max. :0.50000 Max. :0.8000 Max. :1.0000 Max. :1.0000
## LoyalCH SalePriceMM SalePriceCH PriceDiff Store7
## Min. :0.000011 Min. :1.190 Min. :1.390 Min. :-0.6700 No :714
## 1st Qu.:0.325257 1st Qu.:1.690 1st Qu.:1.750 1st Qu.: 0.0000 Yes:356
## Median :0.600000 Median :2.090 Median :1.860 Median : 0.2300
## Mean :0.565782 Mean :1.962 Mean :1.816 Mean : 0.1465
## 3rd Qu.:0.850873 3rd Qu.:2.130 3rd Qu.:1.890 3rd Qu.: 0.3200
## Max. :0.999947 Max. :2.290 Max. :2.090 Max. : 0.6400
## PctDiscMM PctDiscCH ListPriceDiff STORE
## Min. :0.0000 Min. :0.00000 Min. :0.000 Min. :0.000
## 1st Qu.:0.0000 1st Qu.:0.00000 1st Qu.:0.140 1st Qu.:0.000
## Median :0.0000 Median :0.00000 Median :0.240 Median :2.000
## Mean :0.0593 Mean :0.02731 Mean :0.218 Mean :1.631
## 3rd Qu.:0.1127 3rd Qu.:0.00000 3rd Qu.:0.300 3rd Qu.:3.000
## Max. :0.4020 Max. :0.25269 Max. :0.440 Max. :4.000
str(OJ)
## 'data.frame': 1070 obs. of 18 variables:
## $ Purchase : Factor w/ 2 levels "CH","MM": 1 1 1 2 1 1 1 1 1 1 ...
## $ WeekofPurchase: num 237 239 245 227 228 230 232 234 235 238 ...
## $ StoreID : num 1 1 1 1 7 7 7 7 7 7 ...
## $ PriceCH : num 1.75 1.75 1.86 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
## $ PriceMM : num 1.99 1.99 2.09 1.69 1.69 1.99 1.99 1.99 1.99 1.99 ...
## $ DiscCH : num 0 0 0.17 0 0 0 0 0 0 0 ...
## $ DiscMM : num 0 0.3 0 0 0 0 0.4 0.4 0.4 0.4 ...
## $ SpecialCH : num 0 0 0 0 0 0 1 1 0 0 ...
## $ SpecialMM : num 0 1 0 0 0 1 1 0 0 0 ...
## $ LoyalCH : num 0.5 0.6 0.68 0.4 0.957 ...
## $ SalePriceMM : num 1.99 1.69 2.09 1.69 1.69 1.99 1.59 1.59 1.59 1.59 ...
## $ SalePriceCH : num 1.75 1.75 1.69 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
## $ PriceDiff : num 0.24 -0.06 0.4 0 0 0.3 -0.1 -0.16 -0.16 -0.16 ...
## $ Store7 : Factor w/ 2 levels "No","Yes": 1 1 1 1 2 2 2 2 2 2 ...
## $ PctDiscMM : num 0 0.151 0 0 0 ...
## $ PctDiscCH : num 0 0 0.0914 0 0 ...
## $ ListPriceDiff : num 0.24 0.24 0.23 0 0 0.3 0.3 0.24 0.24 0.24 ...
## $ STORE : num 1 1 1 1 0 0 0 0 0 0 ...
(a) Create a training set containing a random sample of 800 observations, and a test set containing the remaining observations.
set.seed(1)
train <- sample(dim(OJ)[1], 800)
oj.train <- OJ[train, ]
oj.test <- OJ[-train, ]
(b) Fit a tree to the training data, with Purchase as the response and the other variables as predictors. Use the summary() function to produce summary statistics about the tree, and describe the results obtained. What is the training error rate? How many terminal nodes does the tree have?
tree.oj = tree(Purchase ~ ., data = oj.train)
summary(tree.oj)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = oj.train)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "SpecialCH" "ListPriceDiff"
## [5] "PctDiscMM"
## Number of terminal nodes: 9
## Residual mean deviance: 0.7432 = 587.8 / 791
## Misclassification error rate: 0.1588 = 127 / 800
Based on the classification tree summary, only five of the predictors are actually used: LoyalCH, PriceDiff, SpecialCH, ListPriceDiff, and PctDiscMM. The tree has 9 terminal nodes, and the training (misclassification) error rate is 0.1588 (127 of 800 training observations).
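As a quick check, the training error rate reported by summary() can be reproduced directly from the fitted tree:
# Reproduce the training misclassification rate (127 / 800 = 0.1588)
mean(predict(tree.oj, oj.train, type = "class") != oj.train$Purchase)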
(c) Type in the name of the tree object in order to get a detailed text output. Pick one of the terminal nodes, and interpret the information displayed.
tree.oj
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 800 1073.00 CH ( 0.60625 0.39375 )
## 2) LoyalCH < 0.5036 365 441.60 MM ( 0.29315 0.70685 )
## 4) LoyalCH < 0.280875 177 140.50 MM ( 0.13559 0.86441 )
## 8) LoyalCH < 0.0356415 59 10.14 MM ( 0.01695 0.98305 ) *
## 9) LoyalCH > 0.0356415 118 116.40 MM ( 0.19492 0.80508 ) *
## 5) LoyalCH > 0.280875 188 258.00 MM ( 0.44149 0.55851 )
## 10) PriceDiff < 0.05 79 84.79 MM ( 0.22785 0.77215 )
## 20) SpecialCH < 0.5 64 51.98 MM ( 0.14062 0.85938 ) *
## 21) SpecialCH > 0.5 15 20.19 CH ( 0.60000 0.40000 ) *
## 11) PriceDiff > 0.05 109 147.00 CH ( 0.59633 0.40367 ) *
## 3) LoyalCH > 0.5036 435 337.90 CH ( 0.86897 0.13103 )
## 6) LoyalCH < 0.764572 174 201.00 CH ( 0.73563 0.26437 )
## 12) ListPriceDiff < 0.235 72 99.81 MM ( 0.50000 0.50000 )
## 24) PctDiscMM < 0.196197 55 73.14 CH ( 0.61818 0.38182 ) *
## 25) PctDiscMM > 0.196197 17 12.32 MM ( 0.11765 0.88235 ) *
## 13) ListPriceDiff > 0.235 102 65.43 CH ( 0.90196 0.09804 ) *
## 7) LoyalCH > 0.764572 261 91.20 CH ( 0.95785 0.04215 ) *
Consider terminal node 8 (marked with an asterisk). It is reached by the splits LoyalCH < 0.5036, LoyalCH < 0.280875, and finally LoyalCH < 0.0356415; it contains 59 training observations with a deviance of 10.14, it predicts Purchase = MM, and about 98.3% of the observations in this node are MM (only 1.7% are CH).
(d) Create a plot of the tree, and interpret the results.
plot(tree.oj)
text(tree.oj, pretty = 0, cex = .7)
Based on the plot, LoyalCH is the most important variable: the top three splits all use it. When LoyalCH is below roughly 0.28 the tree predicts MM, when it is above roughly 0.76 it predicts CH, and for intermediate values the prediction also depends on PriceDiff, SpecialCH, ListPriceDiff, and PctDiscMM.
(e) Predict the response on the test data, and produce a confusion matrix comparing the test labels to the predicted test labels. What is the test error rate?
pred.oj <- predict(tree.oj, oj.test, type = 'class')
table(oj.test$Purchase, pred.oj)
## pred.oj
## CH MM
## CH 160 8
## MM 38 64
Based on the confusion matrix, the test error rate is (8 + 38) / 270 ≈ 0.1704.
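The same rate can also be computed directly from the prediction vector rather than summing cells by hand:
# Test misclassification rate of the unpruned tree: (8 + 38) / 270
mean(pred.oj != oj.test$Purchase)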
(f) Apply the cv.tree() function to the training set in order to determine the optimal tree size.
cv.oj <- cv.tree(tree.oj, FUN = prune.misclass)
cv.oj
## $size
## [1] 9 8 7 4 2 1
##
## $dev
## [1] 150 150 149 158 172 315
##
## $k
## [1] -Inf 0.000000 3.000000 4.333333 10.500000 151.000000
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
(g) Produce a plot with tree size on the x-axis and cross-validated classification error rate on the y-axis.
plot(cv.oj$size, cv.oj$dev, type = 'b', xlab = "Tree Size", ylab = "Cross-Validated Error Rate")
(h) Which tree size corresponds to the lowest cross-validated classification error rate?
Based on the plot and the $dev values above, a tree with 7 terminal nodes has the lowest cross-validated classification error (149 misclassified CV observations).
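The optimal size can also be extracted from cv.oj directly rather than read off the plot:
# Tree size attaining the minimum cross-validated classification error
cv.oj$size[which.min(cv.oj$dev)]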
(i) Produce a pruned tree corresponding to the optimal tree size obtained using cross-validation. If cross-validation does not lead to selection of a pruned tree, then create a pruned tree with five terminal nodes.
oj.prune <- prune.misclass(tree.oj, best = 7)
summary(oj.prune)
##
## Classification tree:
## snip.tree(tree = tree.oj, nodes = c(4L, 10L))
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "ListPriceDiff" "PctDiscMM"
## Number of terminal nodes: 7
## Residual mean deviance: 0.7748 = 614.4 / 793
## Misclassification error rate: 0.1625 = 130 / 800
plot(oj.prune)
text(oj.prune, pretty = 0, cex = .7)
(j) Compare the training error rates between the pruned and unpruned trees. Which is higher?
summary(tree.oj)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = oj.train)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "SpecialCH" "ListPriceDiff"
## [5] "PctDiscMM"
## Number of terminal nodes: 9
## Residual mean deviance: 0.7432 = 587.8 / 791
## Misclassification error rate: 0.1588 = 127 / 800
summary(oj.prune)
##
## Classification tree:
## snip.tree(tree = tree.oj, nodes = c(4L, 10L))
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "ListPriceDiff" "PctDiscMM"
## Number of terminal nodes: 7
## Residual mean deviance: 0.7748 = 614.4 / 793
## Misclassification error rate: 0.1625 = 130 / 800
Comparing the two summaries, the pruned tree has the higher training error rate: 0.1625 (130/800) versus 0.1588 (127/800) for the unpruned tree with 9 terminal nodes and 5 variables.
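These figures can also be reproduced directly from the fitted trees:
# Training misclassification rates: unpruned vs. pruned
c(unpruned = mean(predict(tree.oj, oj.train, type = "class") != oj.train$Purchase),
  pruned = mean(predict(oj.prune, oj.train, type = "class") != oj.train$Purchase))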
(k) Compare the test error rates between the pruned and unpruned trees. Which is higher?
oj.prune.pred = predict(oj.prune, newdata = oj.test, type = "class")
table(oj.prune.pred, oj.test$Purchase)
##
## oj.prune.pred CH MM
## CH 160 36
## MM 8 66
From this table, the test error rate of the pruned tree is (36 + 8) / 270 ≈ 0.163, slightly lower than the unpruned tree's 0.170; that is, the unpruned tree has the higher test error rate.
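The corresponding test error rates can likewise be computed directly from the prediction vectors above:
# Test misclassification rates: unpruned vs. pruned
c(unpruned = mean(pred.oj != oj.test$Purchase),
  pruned = mean(oj.prune.pred != oj.test$Purchase))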
detach(OJ)