library(tidyverse)
library(openintro)
library(ISLR)
library(ISLR2)
library(tree)
library(rpart)
library(caret)
library(randomForest)
library(BART)
Hint: In a setting with two classes, \(\hat{p}_{m1} = 1 - \hat{p}_{m2}\). You could make this plot by
hand, but it will be much easier to make in R.
# Node impurity measures for a two-class problem, plotted against
# p = p_hat_m1 (so p_hat_m2 = 1 - p).
p <- seq(0, 1, 0.001)
gini.index <- 2 * p * (1 - p)
class.error <- 1 - pmax(p, 1 - p)
cross.entropy <- -(p * log(p) + (1 - p) * log(1 - p))
# 0 * log(0) evaluates to NaN in R, but the limit of p * log(p) as p -> 0 is 0;
# set those points explicitly so the curve reaches both endpoints.
cross.entropy[is.nan(cross.entropy)] <- 0
# lty = 1 is passed explicitly: matplot() otherwise cycles line types (1:5),
# which would not match the solid lines drawn in the legend below.
matplot(p,
        cbind(gini.index, class.error, cross.entropy),
        xlab = expression(hat(p)[m1]),
        ylab = "Value",
        col = c("blue", "gray", "red"),
        type = "l",
        lty = 1,
        lwd = 1)
legend("topright",
       legend = c("Gini Index", "Classification Error", "Cross Entropy"),
       col = c("blue", "gray", "red"),
       lty = 1,
       lwd = 1,
       cex = 0.7)
##
## Regression tree:
## tree(formula = Sales ~ ., data = strain)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Income" "Age" "Population"
## [6] "Education" "CompPrice" "Advertising"
## Number of terminal nodes: 18
## Residual mean deviance: 2.132 = 388.1 / 182
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -4.08000 -0.92870 0.06244 0.00000 0.87020 3.71700
## [1] 4.395357
The test MSE is 4.395357.
# Prune the Carseats regression tree to the size with the smallest
# cross-validated deviance (which.min keeps the first minimum), then draw it
# with full factor-level labels.
prune.car <- prune.tree(
  tree.seats,
  best = cv.seats$size[which.min(cv.seats$dev)]
)
plot(prune.car)
text(prune.car, pretty = 0)
## [1] 4.658628
Pruning the tree to 14 terminal nodes increased the test MSE from 4.395357 to 4.658628, so pruning does not improve test performance here.
Use the importance() function to determine which variables are most
important.
# Bagging: a random forest that considers every predictor at each split.
# mtry is derived from the data (all columns except the response Sales), so
# the model remains a bagged ensemble even if the predictor set changes; for
# this training set that is the 10 predictors listed in the importance table.
set.seed(123)
bag.seats <- randomForest(Sales ~ ., data = strain,
                          mtry = ncol(strain) - 1, ntree = 500,
                          importance = TRUE)
bagseat.pred <- predict(bag.seats, newdata = stest)
mean((bagseat.pred - stest$Sales)^2)
## [1] 2.76144
## %IncMSE IncNodePurity
## CompPrice 20.3414969 158.911610
## Income 6.6237140 90.369331
## Advertising 5.7777253 72.793558
## Population -2.2001506 55.786278
## Price 44.3578602 380.255094
## ShelveLoc 48.3345635 387.886972
## Age 18.6296851 187.107660
## Education 2.6619834 55.987493
## Urban 0.9276070 8.152320
## US 0.4202302 5.900097
The test MSE is 2.76144. The most important variables are
ShelveLoc, Price, CompPrice, and
Age.
Use the importance() function to determine which variables are most
important. Describe the effect of \(m\), the number of variables considered at
each split, on the error rate obtained.
# Random forests with an increasing number of candidate predictors per split
# (m = mtry). Every fit reuses seed 123, and for each m the test MSE and the
# %IncMSE variable importance are printed.
mtry_vals <- c(3, 5, 7, 9)
for (m in mtry_vals) {
  set.seed(123)
  rf.seats <- randomForest(
    Sales ~ .,
    data = strain,
    mtry = m,
    ntree = 500,
    importance = TRUE
  )
  pred <- predict(rf.seats, newdata = stest)
  mse <- mean((stest$Sales - pred)^2)
  cat("mtry =", m, "- Test MSE =", round(mse, 3), "\n")
  cat("Variable Importance (based on %IncMSE):\n")
  print(randomForest::importance(rf.seats)[, "%IncMSE"])
  cat("------------------------------------------------------------\n")
}
## mtry = 3 - Test MSE = 3.533
## Variable Importance (based on %IncMSE):
## CompPrice Income Advertising Population Price ShelveLoc
## 11.95200717 5.20336607 7.04990655 2.81702037 31.00971810 31.20191070
## Age Education Urban US
## 16.38791734 0.70334051 0.04221302 1.67637748
## ------------------------------------------------------------
## mtry = 5 - Test MSE = 3.017
## Variable Importance (based on %IncMSE):
## CompPrice Income Advertising Population Price ShelveLoc
## 15.6874802 6.1782093 7.4822778 0.1519680 35.6880806 42.1336259
## Age Education Urban US
## 18.3059283 1.8670124 -0.1467837 0.3849865
## ------------------------------------------------------------
## mtry = 7 - Test MSE = 2.832
## Variable Importance (based on %IncMSE):
## CompPrice Income Advertising Population Price ShelveLoc
## 18.1621394 6.5536181 5.5340650 -0.1318886 40.2052050 44.2028844
## Age Education Urban US
## 19.7420288 2.4063248 1.2543161 2.1656941
## ------------------------------------------------------------
## mtry = 9 - Test MSE = 2.746
## Variable Importance (based on %IncMSE):
## CompPrice Income Advertising Population Price ShelveLoc
## 21.05555659 6.65634770 6.89407868 -0.07229015 44.96551093 47.51448378
## Age Education Urban US
## 20.84240270 2.80047063 -2.15163575 1.90015802
## ------------------------------------------------------------
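The test MSE decreases steadily as \(m\) increases from 3 to 9 (from 3.533 to 2.746), and
\(m = 9\) is marginally better than bagging (\(m = 10\), test MSE 2.761). This suggests that
considering more predictors at each split reduces the bias of the model, and that this
outweighs the variance reduction gained by decorrelating the trees. ShelveLoc,
Price, CompPrice, and Age
continue to have the highest variable importance for every value of \(m\).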
# BART on the Carseats split. The predictors are every column except the
# response Sales; x.test is supplied so gbart() also returns test-set
# predictions alongside the training fit.
set.seed(123)
x_train <- strain[, names(strain) != "Sales"]
y_train <- strain[["Sales"]]
x_test <- stest[, names(stest) != "Sales"]
y_test <- stest[["Sales"]]
bart_model <- gbart(
  x.train = x_train,
  y.train = y_train,
  x.test = x_test
)
## *****Calling gbart: type=1
## *****Data:
## data:n,p,np: 200, 14, 200
## y1,yn: 3.230000, 3.070000
## x1,x[n*p]: 104.000000, 1.000000
## xp1,xp[np*p]: 138.000000, 1.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 63 ... 1
## *****burn,nd,thin: 100,1000,1
## *****Prior:beta,alpha,tau,nu,lambda,offset: 2,0.95,0.260569,3,0.191523,7.43
## *****sigma: 0.991574
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,14,0
## *****printevery: 100
##
## MCMC
## done 0 (out of 1100)
## done 100 (out of 1100)
## done 200 (out of 1100)
## done 300 (out of 1100)
## done 400 (out of 1100)
## done 500 (out of 1100)
## done 600 (out of 1100)
## done 700 (out of 1100)
## done 800 (out of 1100)
## done 900 (out of 1100)
## done 1000 (out of 1100)
## time: 2s
## trcnt,tecnt: 1000,1000
## [1] 1.622453
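BART produces a test MSE of 1.622453, the lowest of all the methods fitted to the Carseats data (regression tree 4.395, pruned tree 4.659, bagging 2.761, random forests 2.746–3.533).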
This problem involves the OJ data set, which is part of
the ISLR2 package.
Use the summary() function to produce summary
statistics about the tree, and describe the results obtained. What is
the training error rate? How many terminal nodes does the tree
have?
##
## Classification tree:
## tree(formula = Purchase ~ ., data = OJtrain)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "SpecialCH" "ListPriceDiff"
## [5] "PctDiscMM"
## Number of terminal nodes: 9
## Residual mean deviance: 0.7432 = 587.8 / 791
## Misclassification error rate: 0.1588 = 127 / 800
Five variables were used to construct the tree (LoyalCH,
PriceDiff, SpecialCH,
ListPriceDiff and PctDiscMM). The training
error rate is 0.1588 (127 of the 800 training observations are misclassified), and the tree has 9 terminal nodes.
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 800 1073.00 CH ( 0.60625 0.39375 )
## 2) LoyalCH < 0.5036 365 441.60 MM ( 0.29315 0.70685 )
## 4) LoyalCH < 0.280875 177 140.50 MM ( 0.13559 0.86441 )
## 8) LoyalCH < 0.0356415 59 10.14 MM ( 0.01695 0.98305 ) *
## 9) LoyalCH > 0.0356415 118 116.40 MM ( 0.19492 0.80508 ) *
## 5) LoyalCH > 0.280875 188 258.00 MM ( 0.44149 0.55851 )
## 10) PriceDiff < 0.05 79 84.79 MM ( 0.22785 0.77215 )
## 20) SpecialCH < 0.5 64 51.98 MM ( 0.14062 0.85938 ) *
## 21) SpecialCH > 0.5 15 20.19 CH ( 0.60000 0.40000 ) *
## 11) PriceDiff > 0.05 109 147.00 CH ( 0.59633 0.40367 ) *
## 3) LoyalCH > 0.5036 435 337.90 CH ( 0.86897 0.13103 )
## 6) LoyalCH < 0.764572 174 201.00 CH ( 0.73563 0.26437 )
## 12) ListPriceDiff < 0.235 72 99.81 MM ( 0.50000 0.50000 )
## 24) PctDiscMM < 0.196196 55 73.14 CH ( 0.61818 0.38182 ) *
## 25) PctDiscMM > 0.196196 17 12.32 MM ( 0.11765 0.88235 ) *
## 13) ListPriceDiff > 0.235 102 65.43 CH ( 0.90196 0.09804 ) *
## 7) LoyalCH > 0.764572 261 91.20 CH ( 0.95785 0.04215 ) *
Node 9 is a terminal node containing 118 observations. It is the right child of
node 4 (LoyalCH < 0.280875) and is defined by LoyalCH > 0.0356415, so it covers
0.0356415 < LoyalCH < 0.280875. Its predicted class is MM: just over 80% (0.80508) of
the observations in this node are MM and just under 20% (0.19492) are CH.
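LoyalCH is by far the most important variable: it is used for the root split and for both
splits at the next level. SpecialCH, PriceDiff,
PctDiscMM, and ListPriceDiff appear only in lower-level splits and are less
important.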
# Test-set predictions and confusion matrix for the unpruned OJ tree.
# caret::confusionMatrix() takes the predicted classes as `data` and the true
# labels as `reference`; naming the arguments keeps the class-specific
# statistics (sensitivity, specificity, PPV/NPV, prevalence) correct.
# NOTE(review): the knitted output below was produced with the two arguments
# reversed, so its Prediction/Reference axes and class-specific statistics are
# transposed; accuracy (and hence the test error rate) is unaffected. Re-knit.
# NOTE(review): predict() on a tree is deterministic; set.seed(1) is presumably
# kept only so any later random step sees the same RNG state.
set.seed(1)
treeOJ.pred <- predict(tree.OJ, newdata = OJtest, type = "class")
confusionMatrix(data = treeOJ.pred, reference = OJtest$Purchase)
## Confusion Matrix and Statistics
##
## Reference
## Prediction CH MM
## CH 160 8
## MM 38 64
##
## Accuracy : 0.8296
## 95% CI : (0.7794, 0.8725)
## No Information Rate : 0.7333
## P-Value [Acc > NIR] : 0.0001259
##
## Kappa : 0.6154
##
## Mcnemar's Test P-Value : 1.904e-05
##
## Sensitivity : 0.8081
## Specificity : 0.8889
## Pos Pred Value : 0.9524
## Neg Pred Value : 0.6275
## Prevalence : 0.7333
## Detection Rate : 0.5926
## Detection Prevalence : 0.6222
## Balanced Accuracy : 0.8485
##
## 'Positive' Class : CH
##
## Accuracy
## 0.1703704
The test error rate is 0.1703704 (46 of the 270 test observations are misclassified).
Apply the cv.tree() function to the training set in order to
determine the optimal tree size.
## $size
## [1] 9 8 7 4 2 1
##
## $dev
## [1] 145 145 146 146 167 315
##
## $k
## [1] -Inf 0.000000 3.000000 4.333333 10.500000 151.000000
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
## [1] 9
A tree with 9 terminal nodes attains the lowest cross-validated classification error (145 misclassifications, tied with the 8-node tree). Because this is the size of the original unpruned tree, cross-validation does not lead to the selection of a pruned tree, so a pruned tree with five terminal nodes is created instead.
##
## Classification tree:
## tree(formula = Purchase ~ ., data = OJtrain)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "SpecialCH" "ListPriceDiff"
## [5] "PctDiscMM"
## Number of terminal nodes: 9
## Residual mean deviance: 0.7432 = 587.8 / 791
## Misclassification error rate: 0.1588 = 127 / 800
##
## Classification tree:
## snip.tree(tree = tree.OJ, nodes = c(4L, 12L, 5L))
## Variables actually used in tree construction:
## [1] "LoyalCH" "ListPriceDiff"
## Number of terminal nodes: 5
## Residual mean deviance: 0.8239 = 655 / 795
## Misclassification error rate: 0.205 = 164 / 800
The pruned tree has a higher training error rate (0.205) than the unpruned tree (0.1588). This is expected, because removing splits can never lower the training error.
# Unpruned OJ tree: test-set predictions and confusion matrix, for comparison
# with the pruned tree below.
# caret::confusionMatrix() takes the predicted classes as `data` and the true
# labels as `reference`; naming the arguments keeps the class-specific
# statistics (sensitivity, specificity, PPV/NPV, prevalence) correct.
# NOTE(review): the knitted output below was produced with the two arguments
# reversed, so its Prediction/Reference axes and class-specific statistics are
# transposed; accuracy (and hence the test error rate) is unaffected. Re-knit.
# NOTE(review): predict() on a tree is deterministic; set.seed(1) is presumably
# kept only so any later random step sees the same RNG state.
set.seed(1)
treeOJ.pred <- predict(tree.OJ, newdata = OJtest, type = "class")
confusionMatrix(data = treeOJ.pred, reference = OJtest$Purchase)
## Confusion Matrix and Statistics
##
## Reference
## Prediction CH MM
## CH 160 8
## MM 38 64
##
## Accuracy : 0.8296
## 95% CI : (0.7794, 0.8725)
## No Information Rate : 0.7333
## P-Value [Acc > NIR] : 0.0001259
##
## Kappa : 0.6154
##
## Mcnemar's Test P-Value : 1.904e-05
##
## Sensitivity : 0.8081
## Specificity : 0.8889
## Pos Pred Value : 0.9524
## Neg Pred Value : 0.6275
## Prevalence : 0.7333
## Detection Rate : 0.5926
## Detection Prevalence : 0.6222
## Balanced Accuracy : 0.8485
##
## 'Positive' Class : CH
##
# Test error rate of the unpruned tree: the share of test observations whose
# predicted class differs from the observed Purchase (equivalently
# 1 - accuracy). Computed directly rather than by rebuilding a confusion
# matrix, which also avoids carrying a misleading "Accuracy" name on an
# error rate.
test_error_OJ <- mean(treeOJ.pred != OJtest$Purchase)
cat("Test error rate for unpruned tree:", test_error_OJ, "\n")
## Test error rate for unpruned tree: 0.1703704
# Test-set predictions and confusion matrix for the pruned OJ tree (prune.OJ).
# caret::confusionMatrix() takes the predicted classes as `data` and the true
# labels as `reference`; naming the arguments keeps the class-specific
# statistics (sensitivity, specificity, PPV/NPV, prevalence) correct.
# NOTE(review): the knitted output below was produced with the two arguments
# reversed, so its Prediction/Reference axes and class-specific statistics are
# transposed; accuracy is unaffected. Re-knit to refresh it.
pruneOJ.pred <- predict(prune.OJ, newdata = OJtest, type = "class")
confusionMatrix(data = pruneOJ.pred, reference = OJtest$Purchase)
## Confusion Matrix and Statistics
##
## Reference
## Prediction CH MM
## CH 136 32
## MM 21 81
##
## Accuracy : 0.8037
## 95% CI : (0.7512, 0.8494)
## No Information Rate : 0.5815
## P-Value [Acc > NIR] : 7.709e-15
##
## Kappa : 0.5911
##
## Mcnemar's Test P-Value : 0.1696
##
## Sensitivity : 0.8662
## Specificity : 0.7168
## Pos Pred Value : 0.8095
## Neg Pred Value : 0.7941
## Prevalence : 0.5815
## Detection Rate : 0.5037
## Detection Prevalence : 0.6222
## Balanced Accuracy : 0.7915
##
## 'Positive' Class : CH
##
# Test error rate of the pruned tree: the share of test observations whose
# predicted class differs from the observed Purchase (equivalently
# 1 - accuracy). Computed directly rather than by rebuilding a confusion
# matrix, which also avoids carrying a misleading "Accuracy" name on an
# error rate.
test_error_prune <- mean(pruneOJ.pred != OJtest$Purchase)
cat("Test error rate for pruned tree:", test_error_prune, "\n")
## Test error rate for pruned tree: 0.1962963
The unpruned tree has a lower test error rate (0.1703704) than the pruned tree (0.1962963).