Consider the Gini index, classification error, and entropy in a
simple classification setting with two classes. Create a single plot
that displays each of these quantities as a function of $\hat{p}_{m1}$. The x-axis
should display $\hat{p}_{m1}$, ranging from 0 to 1, and the y-axis should display
the value of the Gini index, classification error, and entropy.
Hint: In a setting with two classes, $\hat{p}_{m1} = 1 - \hat{p}_{m2}$. You could make
this plot by hand, but it will be much easier to make in R.
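Substituting the hint into the usual definitions of the three measures gives closed forms in the single variable $p = \hat{p}_{m1}$. These are the quantities the R code computes (natural log for entropy, with the convention $0 \log 0 = 0$):

```latex
\begin{aligned}
E &= 1 - \max(\hat{p}_{m1}, \hat{p}_{m2}) = \min(p,\; 1 - p) \\
G &= \sum_{k=1}^{2} \hat{p}_{mk}(1 - \hat{p}_{mk}) = p(1 - p) + (1 - p)p = 2p(1 - p) \\
D &= -\sum_{k=1}^{2} \hat{p}_{mk} \log \hat{p}_{mk} = -p \log p - (1 - p) \log(1 - p)
\end{aligned}
```

All three are symmetric about $p = 1/2$ and peak there, with $E = G = 1/2$ and $D = \log 2 \approx 0.693$. They are zero at $p = 0$ and $p = 1$, where the node is pure.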
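PosProb = seq(0, 1, 0.01)              # p_m1 on a grid from 0 to 1
NegProb = 1 - PosProb                  # p_m2 = 1 - p_m1 with two classes
ClassError = pmin(PosProb, NegProb)    # E = 1 - max(p_m1, p_m2)
gini = PosProb * (1 - PosProb) + NegProb * (1 - NegProb)   # G = 2p(1 - p)
# Entropy with 0 * log(0) taken as 0 (otherwise R returns NaN at p = 0 and p = 1)
ent = -ifelse(PosProb > 0, PosProb * log(PosProb), 0) -
  ifelse(NegProb > 0, NegProb * log(NegProb), 0)
plot(PosProb, ent, type = "l", xlab = expression(hat(p)[m1]), ylab = "Value")
lines(PosProb, gini, col = "blue")
lines(PosProb, ClassError, col = "red")
legend("bottom", legend = c("Entropy", "Gini index", "Classification error"),
       col = c("black", "blue", "red"), lty = 1)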
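The same three measures can also be written as small reusable R functions, which makes it easy to check the peak values numerically. This is a minimal sketch. The function names `class_error`, `gini_index`, and `entropy` are hypothetical helpers, not part of any package:

```r
# Two-class node impurity measures as functions of p = p_m1.
# Entropy uses the natural log and treats 0 * log(0) as 0.
class_error <- function(p) pmin(p, 1 - p)
gini_index  <- function(p) 2 * p * (1 - p)
entropy     <- function(p) {
  plogp <- function(q) ifelse(q > 0, q * log(q), 0)
  -plogp(p) - plogp(1 - p)
}

# Evaluate all three at a few points, including the pure nodes p = 0 and p = 1
p <- c(0, 0.1, 0.5, 0.9, 1)
round(rbind(error = class_error(p), gini = gini_index(p), entropy = entropy(p)), 3)
```

Entropy is the largest of the three everywhere except at the pure nodes. Classification error is piecewise linear, so it is less sensitive to changes in node purity, which is why the Gini index or entropy is usually preferred for growing a tree.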
In the lab, a classification tree was applied to the
Carseats data set after converting Sales into
a qualitative response variable. Now we will seek to predict
Sales using regression trees and related approaches,
treating the response as a quantitative variable.
Q8(a) Split the data set into a training set and a
test set.
A8(a)
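library(ISLR)     # Carseats data
library(rpart)
library(caret)    # createDataPartition()
set.seed(1)
attach(Carseats)
library(tree)     # tree(), cv.tree(), prune.tree()
# 75/25 split, stratified on quantiles of Sales: 301 training and 99 test rows
inTrain = createDataPartition(Carseats$Sales, p = .75, list = FALSE)
train = Carseats[inTrain, ]
test = Carseats[-inTrain, ]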
Q8(b) Fit a regression tree to the training set.
Plot the tree, and interpret the results. What test MSE do you
obtain?
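A8(b) Using the tree() function, the variables actually used in tree
construction are ShelveLoc, Price, Income, Age, Advertising, and
CompPrice. The tree has 16 terminal nodes, and the test MSE is 4.9358.
The first split is on ShelveLoc. Stores with a Good shelving location
have a much higher mean Sales (10.23 thousand units) than stores with a
Bad or Medium location (6.82 thousand units). Price is used for the next
split on both branches, so shelving location and price are the main
drivers of predicted sales.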
Following the course lab method, using the tree()
function:
set.seed(1)
tree.car = tree(Sales ~ ., data = train)
dev.new(width=5, height=24, unit="in")
plot(tree.car)
text(tree.car, pretty = 0, cex=.55)
tree.pred = predict(tree.car, test)
mean((test$Sales - tree.pred) ^ 2)
## [1] 4.935788
summary(tree.car)
##
## Regression tree:
## tree(formula = Sales ~ ., data = train)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Income" "Age" "Advertising"
## [6] "CompPrice"
## Number of terminal nodes: 16
## Residual mean deviance: 2.295 = 654 / 285
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -3.5490 -1.0160 -0.1187 0.0000 0.9648 3.5840
Q8(c) Use cross-validation in order to determine the
optimal level of tree complexity. Does pruning the tree improve the test
MSE?
A8(c) Cross-validation selects a tree with 12 terminal nodes as the
optimal level of complexity. The pruned tree has a test MSE of 4.9708,
slightly higher than the 4.9358 of the unpruned tree, so pruning did not
improve the test MSE.
set.seed(1)
cv.car=cv.tree(tree.car)
plot(cv.car$size, cv.car$dev,xlab = "Size of Tree",ylab = "Deviance")
summary(cv.car)
## Length Class Mode
## size 15 -none- numeric
## dev 15 -none- numeric
## k 15 -none- numeric
## method 1 -none- character
prune.car=prune.tree(tree.car, best=12)
summary(prune.car)
##
## Regression tree:
## snip.tree(tree = tree.car, nodes = c(9L, 10L, 45L))
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Income" "Advertising" "CompPrice"
## Number of terminal nodes: 12
## Residual mean deviance: 2.634 = 761.3 / 289
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -4.993 -1.028 -0.093 0.000 0.902 4.372
plot(prune.car)
text(prune.car)
prune.car
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 301 2343.00 7.554
## 2) ShelveLoc: Bad,Medium 236 1322.00 6.817
## 4) Price < 94.5 40 167.70 9.232
## 8) Income < 56 8 32.57 7.019 *
## 9) Income > 56 32 86.16 9.786 *
## 5) Price > 94.5 196 873.70 6.324
## 10) ShelveLoc: Bad 60 173.50 5.128 *
## 11) ShelveLoc: Medium 136 576.40 6.852
## 22) Advertising < 6.5 75 270.50 6.033
## 44) Price < 127 45 95.50 6.726 *
## 45) Price > 127 30 120.90 4.993 *
## 23) Advertising > 6.5 61 193.60 7.860
## 46) Price < 127 37 94.60 8.538
## 92) CompPrice < 121.5 12 21.30 6.961 *
## 93) CompPrice > 121.5 25 29.12 9.295 *
## 47) Price > 127 24 55.76 6.815 *
## 3) ShelveLoc: Good 65 427.20 10.230
## 6) Price < 142.5 54 251.40 10.920
## 12) Price < 99.5 15 44.25 12.870 *
## 13) Price > 99.5 39 128.30 10.170
## 26) Advertising < 0.5 11 17.59 8.369 *
## 27) Advertising > 0.5 28 61.00 10.880 *
## 7) Price > 142.5 11 23.61 6.838 *
prune.pred=predict(prune.car, test)
mean((test$Sales - prune.pred) ^ 2)
## [1] 4.970822
Q8(d) Use the bagging approach in order to analyze
this data. What test MSE do you obtain? Use the
importance() function to determine which variables are most
important.
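A8(d) With the bagging approach, my test MSE is 3.6088. This is lower
than the test MSE of both the unpruned regression tree (4.9358) and the
pruned tree (4.9708). The most important predictors are ShelveLoc and
Price, which have much greater importance than the other variables by
both %IncMSE and IncNodePurity. Note that the call below does not set
mtry, so randomForest() uses its regression default of floor(p/3) = 3 of
the p = 10 predictors at each split. Strictly, bagging requires
mtry = 10 (all predictors), so this fit is a random forest with m = 3.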
library(randomForest)
set.seed(1)
bag.carseats = randomForest(Sales ~ ., data = train, ntree = 500,
importance = TRUE)
bag.pred = predict(bag.carseats, newdata = test)
bag.mse = mean((bag.pred - test$Sales)^2)
bag.mse
## [1] 3.608817
importance(bag.carseats)
## %IncMSE IncNodePurity
## CompPrice 14.8699983 196.12444
## Income 5.2932823 164.18484
## Advertising 16.0552041 205.98253
## Population -2.8993949 134.26420
## Price 46.0090231 581.67363
## ShelveLoc 53.9203166 549.45519
## Age 16.0708797 236.32475
## Education 1.5229240 102.94677
## Urban -0.8830921 20.33000
## US 3.9642113 32.40947
varImpPlot (bag.carseats)
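To make the definition of bagging concrete, here is a minimal from-scratch sketch using rpart. It runs on hypothetical simulated data rather than Carseats, and the helper `bag_rpart` is illustrative only. Each tree is grown deep on a bootstrap resample with every predictor available at every split, and the test predictions are averaged:

```r
library(rpart)  # recommended package distributed with R

set.seed(1)
# Hypothetical simulated regression data standing in for Carseats
n <- 200
sim <- data.frame(x1 = runif(n), x2 = runif(n), x3 = rnorm(n))
sim$y <- 3 * sim$x1 - 2 * sim$x2^2 + rnorm(n, sd = 0.3)
train_idx <- sample(n, 150)
sim_train <- sim[train_idx, ]
sim_test  <- sim[-train_idx, ]

# Bagging: grow B deep regression trees on bootstrap resamples of the
# training data (all predictors are split candidates) and average them.
bag_rpart <- function(formula, data, newdata, B = 100) {
  deep <- rpart.control(cp = 0, minsplit = 5)
  preds <- sapply(seq_len(B), function(b) {
    boot <- data[sample(nrow(data), replace = TRUE), ]
    fit <- rpart(formula, data = boot, method = "anova", control = deep)
    predict(fit, newdata = newdata)
  })
  rowMeans(preds)  # one averaged prediction per test row
}

bag_pred <- bag_rpart(y ~ ., sim_train, sim_test)
single_pred <- predict(rpart(y ~ ., data = sim_train), sim_test)
c(bagged = mean((sim_test$y - bag_pred)^2),
  single_tree = mean((sim_test$y - single_pred)^2))
```

With the randomForest package, bagging on the Carseats training data corresponds to passing `mtry = ncol(train) - 1` (that is, 10) to `randomForest(Sales ~ ., data = train, ...)`.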
Q8(e) Use random forests to analyze this data. What
test MSE do you obtain? Use the importance() function to
determine which variables are most important. Describe the effect of m,
the number of variables considered at each split, on the error rate
obtained.
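A8(e) The parameter m (mtry in randomForest()) is the number of
predictors randomly sampled as split candidates at each split. For
regression, randomForest() defaults to floor(p/3), which is 3 here
because Carseats has p = 10 predictors. With mtry = 7 I get a test MSE of
3.2733, and with mtry = 12 a test MSE of 3.1626, the lowest so far.
Because mtry cannot exceed p, randomForest() resets mtry = 12 to 10 (with
a warning), so the second fit considers every predictor at each split
and is effectively bagging. Together with the default-mtry fit in part
(d) (m = 3, test MSE 3.6088), this shows the test MSE decreasing as m
increases on this split, although the difference between m = 7 and
m = 10 is small. The top predictors by importance are ShelveLoc and
Price, followed by CompPrice and Advertising.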
set.seed(1)
# Random forest with m = 7 of the 10 predictors tried at each split
randomforest.carseats = randomForest(Sales ~ ., data = train, mtry = 7,
                                     ntree = 500, importance = TRUE)
randomforest.predict = predict(randomforest.carseats, test)
randomforest.mse = mean((test$Sales - randomforest.predict)^2)
randomforest.mse
## [1] 3.273323
set.seed(1)
# mtry = 12 exceeds p = 10, so randomForest() resets it to 10 (bagging)
randomforest.carseats2 = randomForest(Sales ~ ., data = train, mtry = 12,
                                      ntree = 500, importance = TRUE)
randomforest.predict2 = predict(randomforest.carseats2, test)
randomforest.mse2 = mean((test$Sales - randomforest.predict2)^2)
randomforest.mse2
## [1] 3.162588
importance(randomforest.carseats2)
## %IncMSE IncNodePurity
## CompPrice 30.326955 215.44587
## Income 9.084429 104.67827
## Advertising 24.695116 183.50245
## Population 1.369278 70.96353
## Price 68.468514 742.34978
## ShelveLoc 80.392723 695.38606
## Age 19.959245 183.40868
## Education 3.721950 60.19692
## Urban -1.279238 10.64221
## US 4.147063 15.31079
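The role of m can be made explicit with two small helpers. These are hypothetical functions that mirror, as I understand it, how randomForest() picks its default mtry and clamps out-of-range values. They are not the package's own code:

```r
# Hypothetical helpers mirroring randomForest()'s mtry handling
default_mtry <- function(p, regression = TRUE) {
  # regression: floor(p / 3), at least 1; classification: floor(sqrt(p))
  if (regression) max(floor(p / 3), 1) else floor(sqrt(p))
}
effective_mtry <- function(mtry, p) {
  if (mtry < 1 || mtry > p) warning("mtry outside [1, p]; clamping")
  max(1, min(p, round(mtry)))
}

p <- 10  # Carseats predictors: every column except Sales
c(default = default_mtry(p),
  m7  = effective_mtry(7, p),
  m12 = suppressWarnings(effective_mtry(12, p)))
```

Under these rules, the three Carseats ensembles above use m = 3, 7, and 10. The last value equals p, which is exactly bagging.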
Q8(f) Now analyze the data using BART, and report
your results
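A8(f) The test MSE using BART is 1.641, the lowest of all the methods
tried on this problem. Across the BART trees, Price appears in the most
splitting rules (25.4 per posterior draw on average), followed by
CompPrice (18.8).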
library(BART)
set.seed(1)
# Columns 2:11 are the 10 predictors (column 1 is Sales)
bart.car = gbart(train[, 2:11], train$Sales,
                 x.test = test[, 2:11])
## *****Calling gbart: type=1
## *****Data:
## data:n,p,np: 301, 14, 99
## y1,yn: 1.945947, 2.155947
## x1,x[n*p]: 138.000000, 1.000000
## xp1,xp[np*p]: 136.000000, 1.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 68 ... 1
## *****burn,nd,thin: 100,1000,1
## *****Prior:beta,alpha,tau,nu,lambda,offset: 2,0.95,0.287616,3,0.19762,7.55405
## *****sigma: 1.007235
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,14,0
## *****printevery: 100
##
## MCMC
## done 0 (out of 1100)
## done 100 (out of 1100)
## done 200 (out of 1100)
## done 300 (out of 1100)
## done 400 (out of 1100)
## done 500 (out of 1100)
## done 600 (out of 1100)
## done 700 (out of 1100)
## done 800 (out of 1100)
## done 900 (out of 1100)
## done 1000 (out of 1100)
## time: 3s
## trcnt,tecnt: 1000,1000
bart.car$varcount.mean
## CompPrice Income Advertising Population Price ShelveLoc1
## 18.759 14.019 14.747 14.044 25.431 16.830
## ShelveLoc2 ShelveLoc3 Age Education Urban1 Urban2
## 18.585 15.455 17.088 14.607 14.627 16.170
## US1 US2
## 16.000 17.670
mean((test$Sales - bart.car$yhat.test.mean)^2)
## [1] 1.641002
detach(Carseats)
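For a side-by-side view, the test MSEs reported in parts (b) through (f) can be collected in one named vector. The values are copied from the outputs above. The labels are my own, with the part (d) fit labelled by its effective m = 3 and the m = 12 fit labelled as bagging:

```r
# Test MSEs reported above for the Carseats regression problem
carseats_mse <- c(
  tree        = 4.9358,
  pruned_tree = 4.9708,
  rf_mtry3    = 3.6088,  # part (d): randomForest with default mtry
  rf_mtry7    = 3.2733,
  bagging     = 3.1626,  # part (e): mtry = 12, reset to p = 10
  BART        = 1.6410
)
sort(carseats_mse)
# Relative reduction in test MSE versus the single unpruned tree
round(1 - carseats_mse / carseats_mse["tree"], 3)
```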
This problem involves the OJ data set, which is part of
the ISLR2 package.
library(ISLR2)
attach(OJ)
Q9(a) Create a training set containing a random
sample of 800 observations, and a test set containing the remaining
observations.
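A9(a) The outputs reported in parts (b) through (k) were produced with
OJ[inTrain, ], which reused the 301 Carseats indices from Q8(a) instead
of inTrainOJ. As a result, the training set in those outputs has 301
observations (see the 35 / 301 in the tree summary) and the test set has
769. The corrected code uses inTrainOJ, which gives 800 training and 270
test observations. Rerunning it will change the numbers in later parts.
set.seed(1)
inTrainOJ = sample(nrow(OJ), 800)   # 800 random row indices
trainOJ = OJ[inTrainOJ, ]
testOJ = OJ[-inTrainOJ, ]           # the remaining 270 observations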
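A quick size check right after splitting would catch this kind of index mix-up. Below is a minimal sketch with a hypothetical helper `split_rows`, run on a stand-in data frame with the same 1070 rows as OJ:

```r
# Hypothetical helper: random train/test split with sanity checks
split_rows <- function(df, n_train, seed = 1) {
  set.seed(seed)
  idx <- sample(nrow(df), n_train)
  out <- list(train = df[idx, , drop = FALSE],
              test  = df[-idx, , drop = FALSE])
  # Fails loudly if the indices came from a different data set
  stopifnot(nrow(out$train) == n_train,
            nrow(out$train) + nrow(out$test) == nrow(df),
            length(intersect(rownames(out$train), rownames(out$test))) == 0)
  out
}

stand_in <- data.frame(id = seq_len(1070))  # stand-in with OJ's row count
parts <- split_rows(stand_in, 800)
sapply(parts, nrow)
```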
Q9(b) Fit a tree to the training data, with
Purchase as the response and the other variables as
predictors. Use the summary() function to produce summary
statistics about the tree, and describe the results obtained. What is
the training error rate? How many terminal nodes does the tree
have?
A9(b) The training error rate, listed in the summary as the
misclassification error rate, is 0.1163 (35 / 301, about 11.6%). The tree
has 18 terminal nodes, and the variables used in the tree are LoyalCH,
SalePriceMM, StoreID,
WeekofPurchase, ListPriceDiff,
DiscMM, and STORE.
set.seed(1)
OJ_tree=tree(Purchase ~ ., data = trainOJ)
summary(OJ_tree)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = trainOJ)
## Variables actually used in tree construction:
## [1] "LoyalCH" "SalePriceMM" "StoreID" "WeekofPurchase"
## [5] "ListPriceDiff" "DiscMM" "STORE"
## Number of terminal nodes: 18
## Residual mean deviance: 0.5139 = 145.4 / 283
## Misclassification error rate: 0.1163 = 35 / 301
Q9(c) Type in the name of the tree object in order
to get a detailed text output. Pick one of the terminal nodes, and
interpret the information displayed.
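A9(c) Node 18 (SalePriceMM < 2.04) is not a terminal node, since it is
split further on StoreID, so I interpret its terminal child, node 36
(marked with *). Node 36 contains the 24 training observations with
0.0671 < LoyalCH < 0.4955, SalePriceMM < 2.04, and StoreID < 3.5. Its
deviance is 8.314 and its predicted class is MM. About 4.2% of these
customers bought CH and 95.8% bought MM, so this is a nearly pure node.
Its parent, node 18, holds 42 observations, of which 16.7% bought CH and
83.3% bought MM.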
OJ_tree
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 301 384.100 CH ( 0.66445 0.33555 )
## 2) LoyalCH < 0.705326 146 192.400 MM ( 0.36986 0.63014 )
## 4) LoyalCH < 0.495536 94 102.300 MM ( 0.23404 0.76596 )
## 8) LoyalCH < 0.067098 19 0.000 MM ( 0.00000 1.00000 ) *
## 9) LoyalCH > 0.067098 75 90.770 MM ( 0.29333 0.70667 )
## 18) SalePriceMM < 2.04 42 37.850 MM ( 0.16667 0.83333 )
## 36) StoreID < 3.5 24 8.314 MM ( 0.04167 0.95833 ) *
## 37) StoreID > 3.5 18 22.910 MM ( 0.33333 0.66667 )
## 74) SalePriceMM < 1.74 12 16.640 CH ( 0.50000 0.50000 ) *
## 75) SalePriceMM > 1.74 6 0.000 MM ( 0.00000 1.00000 ) *
## 19) SalePriceMM > 2.04 33 45.470 MM ( 0.45455 0.54545 )
## 38) LoyalCH < 0.308432 12 13.500 CH ( 0.75000 0.25000 )
## 76) WeekofPurchase < 257.5 5 0.000 CH ( 1.00000 0.00000 ) *
## 77) WeekofPurchase > 257.5 7 9.561 CH ( 0.57143 0.42857 ) *
## 39) LoyalCH > 0.308432 21 25.130 MM ( 0.28571 0.71429 )
## 78) WeekofPurchase < 265.5 12 6.884 MM ( 0.08333 0.91667 ) *
## 79) WeekofPurchase > 265.5 9 12.370 CH ( 0.55556 0.44444 ) *
## 5) LoyalCH > 0.495536 52 69.290 CH ( 0.61538 0.38462 )
## 10) ListPriceDiff < 0.165 16 17.990 MM ( 0.25000 0.75000 )
## 20) WeekofPurchase < 247 6 0.000 MM ( 0.00000 1.00000 ) *
## 21) WeekofPurchase > 247 10 13.460 MM ( 0.40000 0.60000 ) *
## 11) ListPriceDiff > 0.165 36 38.140 CH ( 0.77778 0.22222 )
## 22) DiscMM < 0.35 31 27.390 CH ( 0.83871 0.16129 )
## 44) SalePriceMM < 2.155 11 0.000 CH ( 1.00000 0.00000 ) *
## 45) SalePriceMM > 2.155 20 22.490 CH ( 0.75000 0.25000 )
## 90) LoyalCH < 0.6176 14 18.250 CH ( 0.64286 0.35714 ) *
## 91) LoyalCH > 0.6176 6 0.000 CH ( 1.00000 0.00000 ) *
## 23) DiscMM > 0.35 5 6.730 MM ( 0.40000 0.60000 ) *
## 3) LoyalCH > 0.705326 155 68.700 CH ( 0.94194 0.05806 )
## 6) STORE < 1.5 66 10.360 CH ( 0.98485 0.01515 )
## 12) LoyalCH < 0.994319 61 0.000 CH ( 1.00000 0.00000 ) *
## 13) LoyalCH > 0.994319 5 5.004 CH ( 0.80000 0.20000 ) *
## 7) STORE > 1.5 89 53.810 CH ( 0.91011 0.08989 )
## 14) WeekofPurchase < 269.5 64 48.230 CH ( 0.87500 0.12500 ) *
## 15) WeekofPurchase > 269.5 25 0.000 CH ( 1.00000 0.00000 ) *
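The deviance column can be reproduced from the class counts in each node. This sketch assumes tree() reports the multinomial deviance $-2 \sum_k n_k \log(n_k / n)$ for a classification node. The values it gives match the printout for nodes 18 and 36. The helper `node_deviance` is hypothetical:

```r
# Deviance of a classification-tree node from its class counts:
# D = -2 * sum_k n_k * log(n_k / n), with 0 * log(0) taken as 0
node_deviance <- function(counts) {
  n <- sum(counts)
  counts <- counts[counts > 0]
  -2 * sum(counts * log(counts / n))
}

# Node 36: 24 observations, yprob (0.04167, 0.95833), i.e. 1 CH and 23 MM
node_deviance(c(CH = 1, MM = 23))
# Node 18: 42 observations, yprob (0.16667, 0.83333), i.e. 7 CH and 35 MM
node_deviance(c(CH = 7, MM = 35))
```

A pure node such as node 8 (19 observations, all MM) has deviance 0, which is why the pure leaves in the printout show 0.000.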
Q9(d) Create a plot of the tree, and interpret the
results.
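A9(d) The tree has 18 terminal nodes. LoyalCH (customer brand loyalty
for Citrus Hill) is by far the most important variable. The root split
and the two splits below it on the left branch all use LoyalCH, and
customers with LoyalCH > 0.705 are almost all predicted to buy CH. Many
splits produce two terminal nodes with the same predicted class (for
example nodes 12 and 13, 14 and 15, 20 and 21, and 90 and 91). These
splits increase node purity but do not change the classification, so the
tree could be pruned without changing its training predictions.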
plot(OJ_tree, uniform=FALSE,
main="Classification Tree")
text(OJ_tree, pretty = 0,cex=0.55)
Q9(e) Predict the response on the test data, and
produce a confusion matrix comparing the test labels to the predicted
test labels. What is the test error rate?
A9(e) The test error rate is 24.97% (192 of the 769 test observations
are misclassified).
set.seed(1)
OJ.pred = predict(OJ_tree, testOJ, type = "class")
table(testOJ$Purchase, OJ.pred)
## OJ.pred
## CH MM
## CH 361 92
## MM 100 216
mean(OJ.pred!=testOJ$Purchase)
## [1] 0.2496749
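# caret's confusionMatrix(data, reference) expects predictions first and
# true labels second. This call passes them the other way round, so the
# "Prediction" rows in the output are the true labels. Accuracy and Kappa
# are unaffected, but Sensitivity/Pos Pred Value and Specificity/Neg Pred
# Value are interchanged.
confusionMatrix(testOJ$Purchase, OJ.pred)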
## Confusion Matrix and Statistics
##
## Reference
## Prediction CH MM
## CH 361 92
## MM 100 216
##
## Accuracy : 0.7503
## 95% CI : (0.7182, 0.7806)
## No Information Rate : 0.5995
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.4823
##
## Mcnemar's Test P-Value : 0.6134
##
## Sensitivity : 0.7831
## Specificity : 0.7013
## Pos Pred Value : 0.7969
## Neg Pred Value : 0.6835
## Prevalence : 0.5995
## Detection Rate : 0.4694
## Detection Prevalence : 0.5891
## Balanced Accuracy : 0.7422
##
## 'Positive' Class : CH
##
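The test error rate and the CH-as-positive metrics can be computed directly from the table in part (e), with rows as true classes and columns as predictions. This minimal sketch uses only the counts printed above:

```r
# Confusion matrix from part (e): rows = true class, columns = predicted class
cm <- matrix(c(361, 100, 92, 216), nrow = 2,
             dimnames = list(truth = c("CH", "MM"), predicted = c("CH", "MM")))
test_error <- 1 - sum(diag(cm)) / sum(cm)
# Treating CH as the positive class
recall_CH    <- cm["CH", "CH"] / sum(cm["CH", ])  # sensitivity: true CH predicted CH
precision_CH <- cm["CH", "CH"] / sum(cm[, "CH"])  # positive predictive value
round(c(test_error = test_error, sensitivity = recall_CH, ppv = precision_CH), 4)
```

With this orientation, the sensitivity for CH is about 0.7969 and the positive predictive value is about 0.7831. These are the two values the swapped caret call reports under each other's names.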
Q9(f) Apply the cv.tree() function to
the training set in order to determine the optimal tree size.
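A9(f) The optimal tree size is 4. It has the smallest cross-validated
misclassification count, $dev = 54. With FUN = prune.misclass, dev is the
number of cross-validation misclassifications rather than the deviance.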
set.seed(1)
cv.OJ=cv.tree(OJ_tree, FUN=prune.misclass)
summary(cv.OJ)
## Length Class Mode
## size 7 -none- numeric
## dev 7 -none- numeric
## k 7 -none- numeric
## method 1 -none- character
names(cv.OJ)
## [1] "size" "dev" "k" "method"
cv.OJ
## $size
## [1] 18 9 7 4 3 2 1
##
## $dev
## [1] 56 56 58 54 68 69 103
##
## $k
## [1] -Inf 0 1 2 8 12 38
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
Q9(g) Produce a plot with tree size on the x-axis
and cross-validated classification error rate on the y-axis.
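A9(g) Because dev holds misclassification counts, dividing it by the
number of training observations gives the cross-validated error rate:
plot(cv.OJ$size, cv.OJ$dev / nrow(trainOJ), type = "b",
     xlab = "Tree Size", ylab = "CV Classification Error Rate")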
Q9(h) Which tree size corresponds to the lowest
cross-validated classification error rate?
A9(h) Tree size 4 corresponds to the lowest
cross-validated classification error.
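The choice of size can also be made programmatically from the cv.OJ values printed in part (f). This sketch assumes 301 training observations, the size used in the recorded run (see the 35 / 301 in the tree summary). It picks the smallest tree among those tied at the minimum count:

```r
# Cross-validation results printed for cv.OJ in part (f)
size <- c(18, 9, 7, 4, 3, 2, 1)
dev  <- c(56, 56, 58, 54, 68, 69, 103)  # CV misclassification counts
n_train <- 301                          # training rows in the recorded run
cv_rate <- dev / n_train
# Smallest tree among those achieving the minimum CV error
best_size <- min(size[dev == min(dev)])
best_size
round(min(cv_rate), 4)
```

Preferring the smallest tree on ties is a deliberate design choice. which.min() alone would return the first minimum in cv.tree's ordering, which runs from the largest tree to the smallest.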
Q9(i) Produce a pruned tree corresponding to the
optimal tree size obtained using cross-validation. If cross-validation
does not lead to selection of a pruned tree, then create a pruned tree
with five terminal nodes.
A9(i) Pruned tree with 4 terminal nodes, the optimal size from cross-validation.
set.seed(1)
prune.OJ=prune.tree(OJ_tree, best=4)
summary(prune.OJ)
##
## Classification tree:
## snip.tree(tree = OJ_tree, nodes = c(11L, 10L, 3L, 4L))
## Variables actually used in tree construction:
## [1] "LoyalCH" "ListPriceDiff"
## Number of terminal nodes: 4
## Residual mean deviance: 0.7647 = 227.1 / 297
## Misclassification error rate: 0.1429 = 43 / 301
Q9(j) Compare the training error rates between the
pruned and unpruned trees. Which is higher?
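A9(j) The training error rate is 11.63% for the unpruned tree and 14.29%
for the pruned tree, so the pruned tree has the higher training error.
This is expected. The pruned tree is a subtree of the unpruned one with
fewer splits, so it fits the training data less closely, while the larger
unpruned tree adapts more to the training set and may overfit.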
set.seed(1)
unpruned.predict = predict(OJ_tree, newdata = trainOJ, type="class")
table(trainOJ$Purchase, unpruned.predict)
## unpruned.predict
## CH MM
## CH 190 10
## MM 25 76
mean(trainOJ$Purchase!=unpruned.predict)
## [1] 0.1162791
pruned.predict = predict(prune.OJ, newdata = trainOJ, type="class")
table(trainOJ$Purchase, pruned.predict)
## pruned.predict
## CH MM
## CH 174 26
## MM 17 84
mean(trainOJ$Purchase!=pruned.predict)
## [1] 0.1428571
Q9(k) Compare the test error rates between the
pruned and unpruned trees. Which is higher?
A9(k) The unpruned tree has the higher test error rate: 24.97% (computed
in part A9(e)) versus 19.90% for the pruned tree. Pruning reduced the
test error, which is consistent with the unpruned tree overfitting the
training data.
test.prune.pred=predict(prune.OJ, testOJ, type='class')
table(testOJ$Purchase, test.prune.pred)
## test.prune.pred
## CH MM
## CH 354 99
## MM 54 262
mean(testOJ$Purchase!=test.prune.pred)
## [1] 0.1989597