Consider the Gini index, classification error, and entropy in a simple classification setting with two classes. Create a single plot that displays each of these quantities as a function of \(\hat{p}_{m1}\). The x-axis should display \(\hat{p}_{m1}\), ranging from 0 to 1, and the y-axis should display the value of the Gini index, classification error, and entropy.
**Hint: In a setting with two classes, \(\hat{p}_{m1} = 1 - \hat{p}_{m2}\). You could make this plot by hand, but it will be much easier to make in R.**
In the three measures below, \(\hat{p}_{mk}\) represents the proportion of training observations in the \(m\)th region that are from the \(k\)th class.
The Classification Error Rate is the fraction of training observations in the region whose class is not the region's most common class:
\(E = 1 - \max_k(\hat{p}_{mk})\)
The Gini Index is a measure of total variance across the K classes and serves as a measure of node purity. It takes on small values if all of the \(\hat{p}_{mk}\)'s are close to zero or one.
\(G = \sum_{k=1}^{K}\hat{p}_{mk}(1 - \hat{p}_{mk})\)
Entropy is another measure of node purity and will take on a value near zero if the \(\hat{p}_{mk}\)’s are all near zero or near 1.
\(D = -\sum_{k=1}^{K}\hat{p}_{mk}\log\hat{p}_{mk}\)
# Grid of p-hat_m1 values (with two classes, p-hat_m2 = 1 - p-hat_m1)
p = seq(0, 1, 0.01)
# Classification error: E = 1 - max(p, 1 - p)
class.err = 1 - pmax(p, 1 - p)
# Gini index with two classes: G = 2 * p * (1 - p)
gini.index = 2 * p * (1 - p)
# Entropy: D = -(p * log(p) + (1 - p) * log(1 - p)); evaluates to NaN at p = 0 and p = 1, so those points are not drawn
entropy = - (p * log(p) + (1 - p) * log(1 - p))
matplot(p, cbind(class.err, gini.index, entropy), col = c("cadetblue2", "salmon", "seagreen2"), pch = 16, main = "Classification Tree Measures", xlab = "p-hat_m1", ylab = "Splitting Criterion")
legend("bottom", pch = 16, title = "Measures", col = c("cadetblue2", "salmon", "seagreen2"), legend = c("Classification Error", "Gini Index", "Entropy"), box.lty = 1)
In the lab, a classification tree was applied to the Carseats data set after converting Sales into a qualitative response variable. Now we will seek to predict Sales using regression trees and related approaches, treating the response as a quantitative variable.
library(ISLR)
library(caret)
library(tree)
library(rpart)
library(rattle)
library(randomForest)
library(gbm)
library(rpart.plot)
attach(Carseats)
summary(Carseats)
## Sales CompPrice Income Advertising
## Min. : 0.000 Min. : 77 Min. : 21.00 Min. : 0.000
## 1st Qu.: 5.390 1st Qu.:115 1st Qu.: 42.75 1st Qu.: 0.000
## Median : 7.490 Median :125 Median : 69.00 Median : 5.000
## Mean : 7.496 Mean :125 Mean : 68.66 Mean : 6.635
## 3rd Qu.: 9.320 3rd Qu.:135 3rd Qu.: 91.00 3rd Qu.:12.000
## Max. :16.270 Max. :175 Max. :120.00 Max. :29.000
## Population Price ShelveLoc Age Education
## Min. : 10.0 Min. : 24.0 Bad : 96 Min. :25.00 Min. :10.0
## 1st Qu.:139.0 1st Qu.:100.0 Good : 85 1st Qu.:39.75 1st Qu.:12.0
## Median :272.0 Median :117.0 Medium:219 Median :54.50 Median :14.0
## Mean :264.8 Mean :115.8 Mean :53.32 Mean :13.9
## 3rd Qu.:398.5 3rd Qu.:131.0 3rd Qu.:66.00 3rd Qu.:16.0
## Max. :509.0 Max. :191.0 Max. :80.00 Max. :18.0
## Urban US
## No :118 No :142
## Yes:282 Yes:258
##
##
##
##
str(Carseats)
## 'data.frame': 400 obs. of 11 variables:
## $ Sales : num 9.5 11.22 10.06 7.4 4.15 ...
## $ CompPrice : num 138 111 113 117 141 124 115 136 132 132 ...
## $ Income : num 73 48 35 100 64 113 105 81 110 113 ...
## $ Advertising: num 11 16 10 4 3 13 0 15 0 0 ...
## $ Population : num 276 260 269 466 340 501 45 425 108 131 ...
## $ Price : num 120 83 80 97 128 72 108 120 124 124 ...
## $ ShelveLoc : Factor w/ 3 levels "Bad","Good","Medium": 1 2 3 3 1 1 3 2 3 3 ...
## $ Age : num 42 65 59 55 38 78 71 67 76 76 ...
## $ Education : num 17 10 12 14 13 16 15 10 10 17 ...
## $ Urban : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 1 2 2 1 1 ...
## $ US : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 1 2 1 2 ...
For this exercise, I will split the data 75/25 into training and test sets using the caret function createDataPartition().
set.seed(12345)
carseat.inTrain <- createDataPartition(Carseats$Sales, p = 0.75, list = FALSE, times = 1)
carseat.train <- Carseats[carseat.inTrain,]
carseat.test <- Carseats[-carseat.inTrain,]
dim(carseat.train)
## [1] 301 11
dim(carseat.test)
## [1] 99 11
The summary of the Regression Tree shows that the variables ShelveLoc, Price, Age, Income, CompPrice, and Advertising were used in the tree construction, and that the tree has 17 Terminal Nodes with a Residual Mean Deviance of 2.33. For a regression tree, the deviance is the sum of squared errors, and the Residual Mean Deviance is that sum divided by the number of observations minus the number of Terminal Nodes (661.6 / 284 here).
Based on the plotted tree, it appears that the most important indicator of Sales is shelving location (ShelveLoc), as the first branch separates Good locations from Bad and Medium locations.
carseats.tree = tree(Sales~., data = carseat.train)
summary(carseats.tree)
##
## Regression tree:
## tree(formula = Sales ~ ., data = carseat.train)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Age" "Income" "CompPrice"
## [6] "Advertising"
## Number of terminal nodes: 17
## Residual mean deviance: 2.33 = 661.6 / 284
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -4.86400 -0.96450 0.02561 0.00000 0.95590 4.12500
plot(carseats.tree)
text(carseats.tree, pretty = 0)
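As a quick check, the Residual Mean Deviance reported above can be reconstructed from the values in the summary (total deviance of 661.6, 301 training observations, 17 Terminal Nodes):
# Residual mean deviance = total deviance / (n - number of terminal nodes)
661.6 / (301 - 17) # approximately 2.33, matching the summary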
For this Regression Tree Model, the Test MSE is 4.718901.
yhat = predict(carseats.tree, newdata = carseat.test)
mean((yhat - carseat.test$Sales)^2)
## [1] 4.718901
With Cross-Validation, the tree with 14 Terminal Nodes is selected.
set.seed(12345)
carseat.cv.tree = cv.tree(carseats.tree)
carseat.cv.tree
## $size
## [1] 17 15 14 13 12 11 10 9 7 6 5 4 3 2 1
##
## $dev
## [1] 1302.702 1313.222 1257.392 1268.506 1278.734 1273.294 1289.776 1313.307
## [9] 1335.252 1293.245 1475.498 1488.995 1463.871 1739.302 2243.539
##
## $k
## [1] -Inf 26.30625 29.65542 30.10488 32.08957 34.96283 37.58095
## [8] 41.65668 49.04988 67.91041 101.46757 108.79151 119.94385 290.87068
## [15] 517.57588
##
## $method
## [1] "deviance"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
plot(carseat.cv.tree$size, carseat.cv.tree$dev, type = "b")
dev.min <- which.min(carseat.cv.tree$dev)
dev.min
## [1] 3
carseat.cv.tree$dev
## [1] 1302.702 1313.222 1257.392 1268.506 1278.734 1273.294 1289.776 1313.307
## [9] 1335.252 1293.245 1475.498 1488.995 1463.871 1739.302 2243.539
carseat.cv.tree$size
## [1] 17 15 14 13 12 11 10 9 7 6 5 4 3 2 1
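The index returned by which.min() is a position in the $dev vector, so it can be mapped back to the corresponding tree size directly:
# Tree size (number of terminal nodes) with the lowest cross-validated deviance
carseat.cv.tree$size[which.min(carseat.cv.tree$dev)] # returns 14 for this fit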
In the pruned tree plotted below, the most prominent splits involve two variables: ShelveLoc (Good) and Price.
carseat.prune.tree = prune.tree(carseats.tree, best = 14)
plot(carseat.prune.tree)
text(carseat.prune.tree, pretty = 0)
The Test MSE for the Cross-Validated (pruned) Tree Model is 4.899015. Pruning the tree does not improve the Test MSE, since the unpruned tree achieved a lower Test MSE of 4.718901.
cv.yhat = predict(carseat.prune.tree, newdata = carseat.test)
mean((cv.yhat - carseat.test$Sales)^2)
## [1] 4.899015
For this part, I will utilize the randomForest() function to perform bagging, since bagging is a special case of a random forest with m = p. Rather than accepting the default value of mtry, I set mtry = 10 so that all ten predictors are considered at each split.
With the Bagging Method, the model explains 67.25% of the Variance of Sales on the Training Dataset. ShelveLoc is the most important variable in the bagged model. In the importance table below, %IncMSE reports the mean increase in out-of-bag MSE when a given variable is randomly permuted, and IncNodePurity reports the total decrease in node RSS from splits on that variable, averaged over all trees.
set.seed(12345)
library(randomForest)
carseats.bag = randomForest(Sales ~ ., data = carseat.train, mtry = 10, importance = TRUE)
carseats.bag
##
## Call:
## randomForest(formula = Sales ~ ., data = carseat.train, mtry = 10, importance = TRUE)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 10
##
## Mean of squared residuals: 2.420585
## % Var explained: 67.25
importance(carseats.bag)
## %IncMSE IncNodePurity
## CompPrice 33.0793077 246.951534
## Income 9.2319869 125.288088
## Advertising 15.1560230 117.841555
## Population -2.6914261 81.154554
## Price 70.3054532 628.822089
## ShelveLoc 76.2680282 668.952116
## Age 22.1351604 222.162003
## Education 0.7707461 55.563048
## Urban -2.8709649 8.168139
## US 2.0388363 9.157165
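The same importance information can also be displayed graphically with randomForest's varImpPlot() function; a quick sketch:
# Dot charts of %IncMSE and IncNodePurity for the bagged model
varImpPlot(carseats.bag, main = "Variable Importance - Bagging")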
With the Bagging Method, we obtain a Test MSE of 3.201505. This Test MSE is lower than the 4.718901 obtained by the unpruned Regression Tree Model.
bag.yhat= predict(carseats.bag, newdata = carseat.test)
mean((bag.yhat - carseat.test$Sales)^2)
## [1] 3.201505
With the Random Forest Method, the model explains 62.09% of the Variance of Sales on the Training Dataset. Here Price is the most important variable by %IncMSE. As before, %IncMSE reports the mean increase in out-of-bag MSE when a given variable is permuted, and IncNodePurity reports the total decrease in node RSS attributable to splits on that variable.
set.seed(12345)
library(randomForest)
carseats.rf = randomForest(Sales ~ ., data = carseat.train, importance = TRUE, na.action = na.omit)
carseats.rf
##
## Call:
## randomForest(formula = Sales ~ ., data = carseat.train, importance = TRUE, na.action = na.omit)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 3
##
## Mean of squared residuals: 2.801971
## % Var explained: 62.09
importance(carseats.rf)
## %IncMSE IncNodePurity
## CompPrice 18.383237 217.97747
## Income 4.289270 163.86297
## Advertising 12.326383 173.31046
## Population -1.083167 150.67935
## Price 48.392095 520.07883
## ShelveLoc 45.881227 499.70560
## Age 14.309680 238.64574
## Education 1.433016 91.29355
## Urban -4.652921 18.78589
## US 5.682563 35.44425
With the Random Forest Model, we obtain a Test MSE of 3.565068, slightly higher than the Test MSE from the Bagging Method. The Random Forest tried 3 variables at each split, whereas the bagged model considered all 10 variables at each split. In this case, considering more variables at each split (bagging) resulted in a lower Test MSE.
rf.yhat= predict(carseats.rf, newdata = carseat.test)
mean((rf.yhat - carseat.test$Sales)^2)
## [1] 3.565068
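To see how the Test MSE moves as the number of variables tried at each split varies between the random-forest default (m = 3) and bagging (m = 10), one could refit the model over a grid of mtry values. A minimal sketch, reusing the carseat.train/carseat.test split from above:
# Test MSE as a function of mtry (number of predictors tried at each split)
set.seed(12345)
mtry.mse = sapply(1:10, function(m) {
  fit = randomForest(Sales ~ ., data = carseat.train, mtry = m)
  pred = predict(fit, newdata = carseat.test)
  mean((pred - carseat.test$Sales)^2)
})
plot(1:10, mtry.mse, type = "b", xlab = "mtry", ylab = "Test MSE")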
detach(Carseats)
This problem involves the OJ data set which is part of the ISLR package.
library(ISLR)
summary(OJ)
## Purchase WeekofPurchase StoreID PriceCH PriceMM
## CH:653 Min. :227.0 Min. :1.00 Min. :1.690 Min. :1.690
## MM:417 1st Qu.:240.0 1st Qu.:2.00 1st Qu.:1.790 1st Qu.:1.990
## Median :257.0 Median :3.00 Median :1.860 Median :2.090
## Mean :254.4 Mean :3.96 Mean :1.867 Mean :2.085
## 3rd Qu.:268.0 3rd Qu.:7.00 3rd Qu.:1.990 3rd Qu.:2.180
## Max. :278.0 Max. :7.00 Max. :2.090 Max. :2.290
## DiscCH DiscMM SpecialCH SpecialMM
## Min. :0.00000 Min. :0.0000 Min. :0.0000 Min. :0.0000
## 1st Qu.:0.00000 1st Qu.:0.0000 1st Qu.:0.0000 1st Qu.:0.0000
## Median :0.00000 Median :0.0000 Median :0.0000 Median :0.0000
## Mean :0.05186 Mean :0.1234 Mean :0.1477 Mean :0.1617
## 3rd Qu.:0.00000 3rd Qu.:0.2300 3rd Qu.:0.0000 3rd Qu.:0.0000
## Max. :0.50000 Max. :0.8000 Max. :1.0000 Max. :1.0000
## LoyalCH SalePriceMM SalePriceCH PriceDiff Store7
## Min. :0.000011 Min. :1.190 Min. :1.390 Min. :-0.6700 No :714
## 1st Qu.:0.325257 1st Qu.:1.690 1st Qu.:1.750 1st Qu.: 0.0000 Yes:356
## Median :0.600000 Median :2.090 Median :1.860 Median : 0.2300
## Mean :0.565782 Mean :1.962 Mean :1.816 Mean : 0.1465
## 3rd Qu.:0.850873 3rd Qu.:2.130 3rd Qu.:1.890 3rd Qu.: 0.3200
## Max. :0.999947 Max. :2.290 Max. :2.090 Max. : 0.6400
## PctDiscMM PctDiscCH ListPriceDiff STORE
## Min. :0.0000 Min. :0.00000 Min. :0.000 Min. :0.000
## 1st Qu.:0.0000 1st Qu.:0.00000 1st Qu.:0.140 1st Qu.:0.000
## Median :0.0000 Median :0.00000 Median :0.240 Median :2.000
## Mean :0.0593 Mean :0.02731 Mean :0.218 Mean :1.631
## 3rd Qu.:0.1127 3rd Qu.:0.00000 3rd Qu.:0.300 3rd Qu.:3.000
## Max. :0.4020 Max. :0.25269 Max. :0.440 Max. :4.000
str(OJ)
## 'data.frame': 1070 obs. of 18 variables:
## $ Purchase : Factor w/ 2 levels "CH","MM": 1 1 1 2 1 1 1 1 1 1 ...
## $ WeekofPurchase: num 237 239 245 227 228 230 232 234 235 238 ...
## $ StoreID : num 1 1 1 1 7 7 7 7 7 7 ...
## $ PriceCH : num 1.75 1.75 1.86 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
## $ PriceMM : num 1.99 1.99 2.09 1.69 1.69 1.99 1.99 1.99 1.99 1.99 ...
## $ DiscCH : num 0 0 0.17 0 0 0 0 0 0 0 ...
## $ DiscMM : num 0 0.3 0 0 0 0 0.4 0.4 0.4 0.4 ...
## $ SpecialCH : num 0 0 0 0 0 0 1 1 0 0 ...
## $ SpecialMM : num 0 1 0 0 0 1 1 0 0 0 ...
## $ LoyalCH : num 0.5 0.6 0.68 0.4 0.957 ...
## $ SalePriceMM : num 1.99 1.69 2.09 1.69 1.69 1.99 1.59 1.59 1.59 1.59 ...
## $ SalePriceCH : num 1.75 1.75 1.69 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
## $ PriceDiff : num 0.24 -0.06 0.4 0 0 0.3 -0.1 -0.16 -0.16 -0.16 ...
## $ Store7 : Factor w/ 2 levels "No","Yes": 1 1 1 1 2 2 2 2 2 2 ...
## $ PctDiscMM : num 0 0.151 0 0 0 ...
## $ PctDiscCH : num 0 0 0.0914 0 0 ...
## $ ListPriceDiff : num 0.24 0.24 0.23 0 0 0.3 0.3 0.24 0.24 0.24 ...
## $ STORE : num 1 1 1 1 0 0 0 0 0 0 ...
set.seed(12345)
oj.inTrain <- createDataPartition(OJ$Purchase, p = 0.746, list = FALSE, times = 1) # p = 0.746 is chosen so the stratified split returns a training set of 800 observations
oj.train <- OJ[oj.inTrain,]
oj.test <- OJ[-oj.inTrain,]
dim(oj.train)
## [1] 800 18
dim(oj.test)
## [1] 270 18
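Because createDataPartition() samples within each level of the outcome, the CH/MM balance of the full data set (653 CH vs. 417 MM) should be roughly preserved in both partitions. A quick check one could run:
# Class proportions in the training and test sets; both should be close to the full-data split
prop.table(table(oj.train$Purchase))
prop.table(table(oj.test$Purchase))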
For this first tree model, I am fitting the categorical variable Purchase against all predictor variables in the dataset. For this tree, the variables that were used to construct the tree include LoyalCH, PriceDiff, and SalePriceMM. With the tree, we obtain a Training Error Rate of 16.5% or 0.165. The tree has 8 Terminal Nodes.
library(rpart)
library(rattle)
set.seed(12345)
oj.tree = tree(Purchase~., data = oj.train, method = "class")
summary(oj.tree)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = oj.train, method = "class")
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "SalePriceMM"
## Number of terminal nodes: 8
## Residual mean deviance: 0.7631 = 604.4 / 792
## Misclassification error rate: 0.165 = 132 / 800
For this problem, I will explain Terminal Node 7. We can identify that this node is a terminal node because its line in the printout ends with an asterisk.
The path to Node 7 begins at the root split on LoyalCH. If LoyalCH > 0.5036, the observation moves down the right branch, where LoyalCH is then compared against 0.764879. If LoyalCH > 0.764879, the splitting stops and the observation falls into Terminal Node 7, whose predicted class is CH. Node 7 contains 251 observations, of which 95.219% have a true class of CH and 4.781% have a true class of MM.
oj.tree
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 800 1070.00 CH ( 0.61000 0.39000 )
## 2) LoyalCH < 0.5036 358 433.20 MM ( 0.29330 0.70670 )
## 4) LoyalCH < 0.276142 166 122.10 MM ( 0.12048 0.87952 )
## 8) LoyalCH < 0.0356415 56 0.00 MM ( 0.00000 1.00000 ) *
## 9) LoyalCH > 0.0356415 110 104.30 MM ( 0.18182 0.81818 ) *
## 5) LoyalCH > 0.276142 192 263.60 MM ( 0.44271 0.55729 )
## 10) PriceDiff < 0.05 73 76.78 MM ( 0.21918 0.78082 ) *
## 11) PriceDiff > 0.05 119 161.90 CH ( 0.57983 0.42017 ) *
## 3) LoyalCH > 0.5036 442 347.40 CH ( 0.86652 0.13348 )
## 6) LoyalCH < 0.764879 191 213.10 CH ( 0.75393 0.24607 )
## 12) PriceDiff < -0.165 31 37.35 MM ( 0.29032 0.70968 ) *
## 13) PriceDiff > -0.165 160 138.70 CH ( 0.84375 0.15625 )
## 26) SalePriceMM < 2.125 88 96.71 CH ( 0.76136 0.23864 ) *
## 27) SalePriceMM > 2.125 72 30.90 CH ( 0.94444 0.05556 ) *
## 7) LoyalCH > 0.764879 251 96.39 CH ( 0.95219 0.04781 ) *
In the tree plotted below, half of the Terminal Nodes classify observations as MM while the other half classify them as CH. The three variables that drive the classification are LoyalCH, PriceDiff, and SalePriceMM. Most of the MM classifications occur on the left-hand side of the tree, where LoyalCH < 0.5036; there is one exception on that side, a Terminal Node that classifies observations as CH. On the right-hand side (LoyalCH > 0.5036), most of the Terminal Nodes classify observations as CH, with the exception of one Terminal Node that classifies observations as MM.
plot(oj.tree)
text(oj.tree, pretty = TRUE)
The Test Error Rate for this Classification Tree is 15.93%.
oj.tree.pred = predict(oj.tree, newdata = oj.test, type = "class")
table(oj.tree.pred, oj.test$Purchase)
##
## oj.tree.pred CH MM
## CH 149 27
## MM 16 78
(27+16)/270
## [1] 0.1592593
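Equivalently, the Test Error Rate can be computed directly from the predictions rather than read off the confusion matrix:
# Fraction of misclassified test observations; equals (27 + 16) / 270 above
mean(oj.tree.pred != oj.test$Purchase)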
set.seed(12345)
oj.cv.tree = cv.tree(oj.tree, FUN = prune.misclass)
oj.cv.tree
## $size
## [1] 8 6 4 2 1
##
## $dev
## [1] 148 148 169 176 312
##
## $k
## [1] -Inf 0.0 6.5 9.5 148.0
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
plot(oj.cv.tree)
plot(oj.cv.tree$size, oj.cv.tree$dev, type = "b", xlab = "Tree Size", ylab = "Cross-Validation Error Rate")
Based on the plot above, trees with 6 and 8 Terminal Nodes tie for the lowest Cross-Validation Error (148 misclassified observations), so the smaller tree with 6 Terminal Nodes is selected for pruning.
oj.prune.tree = prune.misclass(oj.tree, best = 6)
summary(oj.prune.tree)
##
## Classification tree:
## snip.tree(tree = oj.tree, nodes = c(4L, 13L))
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff"
## Number of terminal nodes: 6
## Residual mean deviance: 0.7976 = 633.3 / 794
## Misclassification error rate: 0.165 = 132 / 800
plot(oj.prune.tree)
text(oj.prune.tree, pretty = 0)
The Training Error Rate of the Pruned Tree is 16.5%, the same as the 16.5% Training Error Rate of the Unpruned Tree. The key difference between the two models is that the Pruned Tree leaves out the variable SalePriceMM and has only 6 Terminal Nodes versus the 8 Terminal Nodes of the Unpruned Tree.
The Test Error Rate for the Pruned Tree is 15.93%, and the Test Error Rate for the Unpruned Tree is also 15.93%. There is no difference in Test Error Rate between the Pruned and Unpruned Trees.
oj.prune.tree.pred = predict(oj.prune.tree, newdata = oj.test, type = "class")
table(oj.prune.tree.pred, oj.test$Purchase)
##
## oj.prune.tree.pred CH MM
## CH 149 27
## MM 16 78
(27+16)/270
## [1] 0.1592593
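For a direct side-by-side comparison, the two Test Error Rates can also be computed from the stored predictions:
# Test Error Rates for the unpruned and pruned trees (both 0.1592593 here)
c(unpruned = mean(oj.tree.pred != oj.test$Purchase),
  pruned = mean(oj.prune.tree.pred != oj.test$Purchase))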