library(ISLR2)
library(MASS)
library(class)
library(rpart)
library(tree)
library(randomForest)
library(BART)
Consider the Gini index, classification error, and entropy in a simple classification setting with two classes. Create a single plot that displays each of these quantities as a function of $\hat{p}_{m1}$. The x-axis should display $\hat{p}_{m1}$, ranging from 0 to 1, and the y-axis should display the value of the Gini index, classification error, and entropy.
Hint: In a setting with two classes, $\hat{p}_{m1} = 1 - \hat{p}_{m2}$. You could make this plot by hand, but it will be much easier to make in R.
p = seq(0,1,0.01)
gini = 2*(1-p) * p
classerror = 1-pmax(p,1-p)
crossentropy = -(p*log(p)+(1-p)*log(1-p))
plot(NA,NA,xlim = c(0,1),ylim = c(0,1),xlab = 'p',ylab = 'Value')
lines(p,gini,col = 'green')
lines(p,classerror,col = 'blue')
lines(p,crossentropy,col = 'red')
legend(x = 'topright',legend=c('Gini','Classification error','Entropy'),
col =c ('green','blue','red'),lty = 1,text.width = 0.25)
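One detail of the entropy curve: at p = 0 and p = 1 the expression evaluates 0 * log(0), which R returns as NaN, so lines() simply skips those two endpoints. The following is a minimal sketch, assuming the usual convention that 0 log 0 = 0; the helper name impurity() is hypothetical and not part of the exercise.

```r
# Hypothetical helper: two-class impurity measures as a function of p = p_m1.
# Assumes the convention 0 * log(0) = 0 so entropy is defined on all of [0, 1].
impurity <- function(p) {
  q <- 1 - p
  xlogx <- function(x) ifelse(x > 0, x * log(x), 0)
  data.frame(
    p = p,
    gini = 2 * p * q,
    classerror = 1 - pmax(p, q),
    entropy = -(xlogx(p) + xlogx(q))
  )
}

res <- impurity(seq(0, 1, 0.01))
res[c(1, 51, 101), ]   # the two endpoints and the midpoint p = 0.5
```

All three measures are 0 at the endpoints and peak at p = 0.5: the Gini index and classification error at 0.5, and the entropy (with natural logs) at log(2), about 0.693.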
In the lab, a classification tree was applied to the Carseats data set after converting Sales into a qualitative response variable. Now we will seek to predict Sales using regression trees and related approaches, treating the response as a quantitative variable.
attach(Carseats)
(a) Split the data set into a training set and a test set.
Normally, I’d split this 60/40 or 70/30, but we’ll do 50/50 to match the lab and other work posted online.
set.seed(1)
train=sample(1:nrow(Carseats),nrow(Carseats)/2)
test=-train
car.train = Carseats[train,]
car.test = Carseats[test,]
nrow(car.train)/nrow(Carseats)
## [1] 0.5
nrow(car.test)/nrow(Carseats)
## [1] 0.5
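The split ratio is easy to change. As a minimal sketch, the hypothetical helper split_indices() below (not part of the lab) works only from the row count and shows what the 70/30 alternative would look like for the 400 rows of Carseats:

```r
# Hypothetical helper: draw train/test row indices for a given training fraction.
split_indices <- function(n, train_frac = 0.5, seed = 1) {
  set.seed(seed)
  train <- sample(seq_len(n), size = round(train_frac * n))
  list(train = train, test = setdiff(seq_len(n), train))
}

idx <- split_indices(400, train_frac = 0.7)   # Carseats has 400 rows
lengths(idx)                                  # number of rows in each part
```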
(b) Fit a regression tree to the training set. Plot the tree, and interpret the results. What test MSE do you obtain?
tree.carseats = tree(Sales~.,car.train)
summary(tree.carseats)
##
## Regression tree:
## tree(formula = Sales ~ ., data = car.train)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Age" "Advertising" "CompPrice"
## [6] "US"
## Number of terminal nodes: 18
## Residual mean deviance: 2.167 = 394.3 / 182
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -3.88200 -0.88200 -0.08712 0.00000 0.89590 4.09900
tree.pred=predict(tree.carseats,Carseats[-train,])
mean((tree.pred-Carseats[-train,'Sales'])^2)
## [1] 4.922039
The test MSE is about 4.922. Now we plot the tree, shrinking the text size a little to reduce clutter:
plot(tree.carseats)
text(tree.carseats, cex = .7)
(c) Use cross-validation in order to determine the optimal level of tree complexity. Does pruning the tree improve the test MSE?
set.seed(1)
cv.car=cv.tree(tree.carseats)
plot(cv.car$size,cv.car$dev,xlab = "Size of Tree",ylab = "Deviance",type = "b")
plot(cv.car)
Having seen that the best tree included all the nodes, I checked other analysts’ work to see if they came to the same conclusion. The community appears split, with some saying 17 and some saying fewer; while they built their models the same way, they used different train/test splits and different seeds. As our model shows 17 nodes is correct, it is unlikely that any amount of pruning would reduce the MSE. Since there is a good dip at 15, I will check whether the MSE improves at 15 anyway.
prune.tree.carseats=prune.tree(tree.carseats,best=15)
plot(prune.tree.carseats)
text(prune.tree.carseats, cex=.7)
And to validate the MSE does not improve:
tree.pred=predict(prune.tree.carseats,Carseats[test,])
mean((tree.pred-Carseats[test,'Sales'])^2)
## [1] 4.924799
Admittedly, the MSE was a lot closer than I expected (4.925 versus 4.922), but it did not improve. Even in the other analysts’ work I reviewed, where they found ‘best’ sizes other than 17, pruning still did not improve the MSE.
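Rather than checking a single size, one could sweep over every subtree size and compare test MSEs directly. The block below is a sketch under the same seed and 50/50 split as above; the names train_idx, fit, sizes, and test_mse are introduced here for illustration, and the leaf count is read from the fitted tree's frame. Sizes start at 2 because a single-node tree is handled differently by the tree package.

```r
library(ISLR2)
library(tree)

set.seed(1)
train_idx <- sample(1:nrow(Carseats), nrow(Carseats) / 2)
fit <- tree(Sales ~ ., Carseats[train_idx, ])

# Number of terminal nodes in the unpruned tree.
n_leaves <- sum(fit$frame$var == "<leaf>")
sizes <- 2:n_leaves

# Test MSE for each requested subtree size. If no subtree of exactly the
# requested size exists, prune.tree() returns a neighbouring one instead.
test_mse <- sapply(sizes, function(k) {
  pruned <- prune.tree(fit, best = k)
  pred <- predict(pruned, Carseats[-train_idx, ])
  mean((pred - Carseats$Sales[-train_idx])^2)
})

data.frame(size = sizes, test_mse = round(test_mse, 3))
```

Keep in mind that choosing the size by looking at test MSE uses the test set for tuning, so this is a diagnostic rather than a substitute for cross-validation.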
(d) Use the bagging approach in order to analyze this data. What test MSE do you obtain? Use the importance() function to determine which variables are most important.
set.seed(1)
carseats.rf=randomForest(Sales~.,data=car.train, mtry = ncol(Carseats)-1,importance=T,ntree=100)
tree.pred=predict(carseats.rf,Carseats[test,])
mean((tree.pred-Carseats[test,'Sales'])^2)
## [1] 2.616711
The test MSE drops substantially with bagging, from about 4.92 for the single tree to about 2.62.
importance(carseats.rf)
## %IncMSE IncNodePurity
## CompPrice 11.4033686 169.11991
## Income 1.4751526 92.34927
## Advertising 5.3836500 96.15208
## Population -2.0667575 62.75983
## Price 27.2224905 492.32337
## ShelveLoc 22.7210914 363.27721
## Age 9.1141420 154.10317
## Education -0.6365854 46.57603
## Urban -0.6581541 10.55949
## US 2.1808079 16.20339
Price and ShelveLoc have the largest values on both measures and are the most important variables.
(e) Use random forests to analyze this data. What test MSE do you obtain? Use the importance() function to determine which variables are most important. Describe the effect of m, the number of variables considered at each split, on the error rate obtained.
carseats.rf2 = randomForest(Sales ~ ., data = car.train, mtry = 5, ntree = 500,
importance = T)
rf.pred = predict(carseats.rf2, car.test)
mean((car.test$Sales - rf.pred)^2)
## [1] 2.704281
The test MSE is slightly worse than with bagging (about 2.70 versus 2.62).
importance(carseats.rf2)
## %IncMSE IncNodePurity
## CompPrice 19.086069 160.09548
## Income 1.636380 114.64347
## Advertising 11.095354 110.54483
## Population -1.277371 79.95358
## Price 45.409780 453.69460
## ShelveLoc 40.346149 327.86347
## Age 12.143250 164.17656
## Education 0.833331 57.10072
## Urban 1.252085 10.94739
## US 4.485707 22.89812
The top variables are again Price and ShelveLoc. By IncNodePurity, Age is now third, where CompPrice was third under bagging (the two are very close); by %IncMSE, CompPrice is third under both models.
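To describe the effect of m, one option is to refit the forest for every value of mtry and compare test MSEs. This is a sketch under the same seed and split as above; train_idx, p, and mtry_mse are names introduced here, and the seed is reset before each fit so that only mtry changes between runs.

```r
library(ISLR2)
library(randomForest)

set.seed(1)
train_idx <- sample(1:nrow(Carseats), nrow(Carseats) / 2)
p <- ncol(Carseats) - 1   # 10 predictors; mtry = p is bagging

# Test MSE of a 500-tree forest for each value of mtry.
mtry_mse <- sapply(1:p, function(m) {
  set.seed(1)
  rf <- randomForest(Sales ~ ., data = Carseats[train_idx, ],
                     mtry = m, ntree = 500)
  pred <- predict(rf, newdata = Carseats[-train_idx, ])
  mean((pred - Carseats$Sales[-train_idx])^2)
})

plot(1:p, mtry_mse, type = "b", xlab = "mtry", ylab = "Test MSE")
```

With a few dominant predictors such as Price and ShelveLoc, one would typically expect very small values of m to do worse, since many splits never get to consider those variables, and the error to level off as m grows. The exact shape depends on the split and the seed.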
(f) Now analyze the data using BART, and report your results.
x=Carseats[, 2:11]
y=Carseats[, "Sales"]
xtrain = x[train, ]
ytrain = y[train]
xtest = x[test, ]
ytest = y[test]
set.seed(1)
bartfit = gbart(xtrain , ytrain , x.test = xtest)
## *****Calling gbart: type=1
## *****Data:
## data:n,p,np: 200, 14, 200
## y1,yn: 2.781850, 1.091850
## x1,x[n*p]: 107.000000, 1.000000
## xp1,xp[np*p]: 111.000000, 1.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 63 ... 1
## *****burn,nd,thin: 100,1000,1
## *****Prior:beta,alpha,tau,nu,lambda,offset: 2,0.95,0.273474,3,0.23074,7.57815
## *****sigma: 1.088371
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,14,0
## *****printevery: 100
##
## MCMC
## done 0 (out of 1100)
## done 100 (out of 1100)
## done 200 (out of 1100)
## done 300 (out of 1100)
## done 400 (out of 1100)
## done 500 (out of 1100)
## done 600 (out of 1100)
## done 700 (out of 1100)
## done 800 (out of 1100)
## done 900 (out of 1100)
## done 1000 (out of 1100)
## time: 3s
## trcnt,tecnt: 1000,1000
Now we’ll compute the test error:
yhat.bart = bartfit$yhat.test.mean
mean((ytest - yhat.bart)^2)
## [1] 1.450842
Using BART, the test MSE drops even further, to about 1.45. Here is the number of times each variable appeared in the collection of trees, averaged over the posterior draws:
ord = order (bartfit$varcount.mean , decreasing = T)
bartfit$varcount.mean[ord]
## Price CompPrice ShelveLoc2 US2 ShelveLoc1 US1
## 24.396 18.427 18.323 17.580 17.471 17.233
## Education Age Urban1 Urban2 Income Population
## 16.524 16.503 16.331 15.945 15.693 15.518
## ShelveLoc3 Advertising
## 15.440 13.818
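Because gbart expands each factor into dummy columns (ShelveLoc1, ShelveLoc2, ShelveLoc3, and so on), the counts above are per column rather than per original variable. The sketch below collapses them, using the values copied from the output above and assuming that a trailing digit in a name marks a dummy column; that naming rule is an assumption and would misfire on a predictor whose real name ends in a digit.

```r
# Posterior mean split counts copied from the gbart output above.
varcount <- c(Price = 24.396, CompPrice = 18.427, ShelveLoc2 = 18.323,
              US2 = 17.580, ShelveLoc1 = 17.471, US1 = 17.233,
              Education = 16.524, Age = 16.503, Urban1 = 16.331,
              Urban2 = 15.945, Income = 15.693, Population = 15.518,
              ShelveLoc3 = 15.440, Advertising = 13.818)

# Assumption: stripping trailing digits recovers the original variable name.
base_name <- sub("[0-9]+$", "", names(varcount))

# Sum the counts of all columns belonging to the same original variable.
by_variable <- sort(sapply(split(varcount, base_name), sum), decreasing = TRUE)
by_variable
```

This is only a rough view: a factor with more levels contributes more columns, so its summed count is not strictly comparable with that of a numeric predictor.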
detach(Carseats)
This problem involves the OJ data set which is part of the ISLR2 package.
attach(OJ)
(a) Create a training set containing a random sample of 800 observations, and a test set containing the remaining observations.
set.seed(1)
index=sample(1:nrow(OJ), 800)
train = OJ[index, ]
test = OJ[-index, ]
nrow(train)
## [1] 800
nrow(train) / nrow(OJ)
## [1] 0.7476636
nrow(test)
## [1] 270
nrow(test)/ nrow(OJ)
## [1] 0.2523364
Using 800 observations for training gives roughly a 75/25 split.
(b) Fit a tree to the training data, with Purchase as the response and the other variables as predictors. Use the summary() function to produce summary statistics about the tree, and describe the results obtained. What is the training error rate? How many terminal nodes does the tree have?
OJ.tree = tree(Purchase ~ ., train)
summary(OJ.tree)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = train)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "SpecialCH" "ListPriceDiff"
## [5] "PctDiscMM"
## Number of terminal nodes: 9
## Residual mean deviance: 0.7432 = 587.8 / 791
## Misclassification error rate: 0.1588 = 127 / 800
From the summary, the tree has 9 terminal nodes and a training misclassification error rate of 0.1588 (127 of 800).
(c) Type in the name of the tree object in order to get a detailed text output. Pick one of the terminal nodes, and interpret the information displayed.
OJ.tree
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 800 1073.00 CH ( 0.60625 0.39375 )
## 2) LoyalCH < 0.5036 365 441.60 MM ( 0.29315 0.70685 )
## 4) LoyalCH < 0.280875 177 140.50 MM ( 0.13559 0.86441 )
## 8) LoyalCH < 0.0356415 59 10.14 MM ( 0.01695 0.98305 ) *
## 9) LoyalCH > 0.0356415 118 116.40 MM ( 0.19492 0.80508 ) *
## 5) LoyalCH > 0.280875 188 258.00 MM ( 0.44149 0.55851 )
## 10) PriceDiff < 0.05 79 84.79 MM ( 0.22785 0.77215 )
## 20) SpecialCH < 0.5 64 51.98 MM ( 0.14062 0.85938 ) *
## 21) SpecialCH > 0.5 15 20.19 CH ( 0.60000 0.40000 ) *
## 11) PriceDiff > 0.05 109 147.00 CH ( 0.59633 0.40367 ) *
## 3) LoyalCH > 0.5036 435 337.90 CH ( 0.86897 0.13103 )
## 6) LoyalCH < 0.764572 174 201.00 CH ( 0.73563 0.26437 )
## 12) ListPriceDiff < 0.235 72 99.81 MM ( 0.50000 0.50000 )
## 24) PctDiscMM < 0.196196 55 73.14 CH ( 0.61818 0.38182 ) *
## 25) PctDiscMM > 0.196196 17 12.32 MM ( 0.11765 0.88235 ) *
## 13) ListPriceDiff > 0.235 102 65.43 CH ( 0.90196 0.09804 ) *
## 7) LoyalCH > 0.764572 261 91.20 CH ( 0.95785 0.04215 ) *
The first terminal node listed is
8) LoyalCH < 0.0356415 59 10.14 MM ( 0.01695 0.98305 ) *.
It is a descendant of the splits
2) LoyalCH < 0.5036 365 441.60 MM ( 0.29315 0.70685 ) and
4) LoyalCH < 0.280875 177 140.50 MM ( 0.13559 0.86441 ).
The root, root 800 1073.00 CH ( 0.60625 0.39375 ), means there are 800 observations, the deviance is 1073, the overall prediction is CH, and the class proportions are 60.6% CH to 39.4% MM. Reading the terminal node the same way: it contains 59 observations, its deviance is 10.14, its prediction is MM, and 98.3% of its observations purchased MM against 1.7% CH. In other words, customers with LoyalCH below about 0.036 almost always buy MM.
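The deviance column can be reproduced by hand. The sketch below uses the classification-tree deviance, -2 times the sum over classes of n_k log(n_k / n); the helper name node_deviance() is hypothetical, and the class counts are inferred by multiplying each node's n by its printed proportions (59 × 0.01695 ≈ 1 and 800 × 0.60625 = 485).

```r
# Hypothetical helper: deviance of a classification-tree node from its
# class counts, -2 * sum(n_k * log(n_k / n)), skipping empty classes.
node_deviance <- function(counts) {
  counts <- counts[counts > 0]
  -2 * sum(counts * log(counts / sum(counts)))
}

node_deviance(c(CH = 1, MM = 58))      # node 8: 59 observations
node_deviance(c(CH = 485, MM = 315))   # root: 800 observations
```

These come out at roughly 10.14 and 1073, in line with the printed deviances for node 8 and the root.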
(d) Create a plot of the tree, and interpret the results.
plot(OJ.tree)
text(OJ.tree, pretty = 0, cex = 0.7)
LoyalCH occupies the top two levels of the tree, highlighting the importance of this variable. The node we interpreted previously is at the bottom left. Coming down from the top, if you are less loyal to CH (LoyalCH below 0.5036) you go left. At the next level, if your loyalty is below 0.28 you go left again and end up in one of two terminal nodes, both predicting MM. If your loyalty is between 0.28 and 0.5036 you go right, where the price difference comes into play: when PriceDiff is above 0.05 the prediction is CH, and when it is below 0.05 SpecialCH becomes the determining factor (MM if SpecialCH is below 0.5, CH otherwise).
(e) Predict the response on the test data, and produce a confusion matrix comparing the test labels to the predicted test labels. What is the test error rate?
test_pred = predict(OJ.tree, test, type = "class")
table(test_pred, test_actual = test$Purchase)
## test_actual
## test_pred CH MM
## CH 160 38
## MM 8 64
(8+38)/(160+8+38+64)
## [1] 0.1703704
The test error rate is about 0.1704.
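The same figure can be computed from the confusion matrix without typing the cells by hand, and the matrix also gives an error rate per actual class. This is a small sketch using the counts printed above; conf_mat, test_error, and class_error are names introduced here.

```r
# Confusion matrix copied from the table above (rows: predicted, cols: actual).
conf_mat <- matrix(c(160, 8, 38, 64), nrow = 2,
                   dimnames = list(test_pred = c("CH", "MM"),
                                   test_actual = c("CH", "MM")))

# Overall error rate: one minus the share of observations on the diagonal.
test_error <- 1 - sum(diag(conf_mat)) / sum(conf_mat)

# Error rate within each actual class (column-wise).
class_error <- 1 - diag(conf_mat) / colSums(conf_mat)

test_error
class_error
```

The errors are not symmetric: about 37% of actual MM purchases (38 of 102) are predicted as CH, against about 5% of actual CH purchases (8 of 168) predicted as MM.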
(f) Apply the cv.tree() function to the training set in order to determine the optimal tree size.
set.seed(1)
tree_model = OJ.tree
cv_tree_model = cv.tree(tree_model, K = 10, FUN = prune.misclass)
cv_tree_model
## $size
## [1] 9 8 7 4 2 1
##
## $dev
## [1] 145 145 146 146 167 315
##
## $k
## [1] -Inf 0.000000 3.000000 4.333333 10.500000 151.000000
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
plot(cv_tree_model)
(g) Produce a plot with tree size on the x-axis and cross-validated classification error rate on the y-axis.
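# With FUN = prune.misclass, $dev holds the number of CV misclassifications,
# so divide by the training-set size to get an error rate.
cv_error_rate = cv_tree_model$dev / nrow(train)
plot(cv_tree_model$size, cv_error_rate, xlab = "Size of the Tree", ylab = "CV Classification Error Rate", type = "b")
points(4, min(cv_error_rate), col = "red")
points(9, min(cv_error_rate), col = "green")
Cross-validation favors size 9 (tied with size 8 at 145 errors). However, marking the minimum CV error at both size 9 and size 4 shows there is very little difference: size 4 has 146 errors. Since a later part asks us to prune, we will prune to 4 terminal nodes there, because it simplifies the tree at very little cost.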
(h) Which tree size corresponds to the lowest cross-validated classification error rate?
As stated above, 9, although size 8 ties it at 145 cross-validated misclassifications and sizes 7 and 4 trail by a single error (146).
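The reasoning for preferring a smaller tree can be written down as a simple rule: take the smallest size whose CV error count is within some tolerance of the minimum. The sketch below uses the $size and $dev values printed above; the tolerance of one error is a hypothetical choice made for illustration, not something cv.tree() provides.

```r
# Values copied from the cv.tree() output above.
cv_size   <- c(9, 8, 7, 4, 2, 1)
cv_errors <- c(145, 145, 146, 146, 167, 315)

# Sizes tied for the lowest CV error count.
best_sizes <- cv_size[cv_errors == min(cv_errors)]

# Hypothetical rule: accept up to one extra CV error for a simpler tree.
tolerance <- 1
smallest_ok <- min(cv_size[cv_errors <= min(cv_errors) + tolerance])

best_sizes
smallest_ok
```

With this tolerance the rule picks size 4, the size used for pruning here.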
(i) Produce a pruned tree corresponding to the optimal tree size obtained using cross-validation. If cross-validation does not lead to selection of a pruned tree, then create a pruned tree with five terminal nodes.
pruned_tree_model <- prune.tree(tree_model, best = 4)
pruned_tree_model
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 800 1073.0 CH ( 0.60625 0.39375 )
## 2) LoyalCH < 0.5036 365 441.6 MM ( 0.29315 0.70685 )
## 4) LoyalCH < 0.280875 177 140.5 MM ( 0.13559 0.86441 ) *
## 5) LoyalCH > 0.280875 188 258.0 MM ( 0.44149 0.55851 ) *
## 3) LoyalCH > 0.5036 435 337.9 CH ( 0.86897 0.13103 )
## 6) LoyalCH < 0.764572 174 201.0 CH ( 0.73563 0.26437 ) *
## 7) LoyalCH > 0.764572 261 91.2 CH ( 0.95785 0.04215 ) *
We used the 4 terminal nodes chosen earlier in the problem. We could have used 5 by changing best = 4 to best = 5.
(j) Compare the training error rates between the pruned and unpruned trees. Which is higher?
tree_model:
mean(predict(tree_model, type = "class") != train$Purchase)
## [1] 0.15875
pruned_tree_model:
mean(predict(pruned_tree_model, type = "class") != train$Purchase)
## [1] 0.205
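The pruned tree has the higher training error (0.205 versus 0.159). This is expected: the full tree contains every split of the pruned tree plus more, so it fits the training data at least as well. Even if 4 had been the best size under cross-validation, it would not necessarily have corresponded to a lower training error.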
(k) Compare the test error rates between the pruned and unpruned trees. Which is higher?
tree_model:
mean(predict(tree_model, type = "class", newdata = test) != test$Purchase)
## [1] 0.1703704
pruned_tree_model:
mean(predict(pruned_tree_model, type = "class", newdata = test) != test$Purchase)
## [1] 0.1851852
The test error rate is also higher for the pruned tree (0.185 versus 0.170), so the full tree is still better. It would be worth further exploration to see if a tree with the “best” number of nodes would have resulted in a better error rate, but for us pruning did not help.
detach(OJ)