library(ggplot2)
library(ISLR2)
library(tree)
library(randomForest)
## randomForest 4.7-1.1
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
##
## margin
#install.packages("BART")
library(BART)
## Loading required package: nlme
## Loading required package: nnet
## Loading required package: survival
Consider the Gini index, classification error, and entropy in a simple classification setting with two classes. Create a single plot that displays each of these quantities as a function of p̂_m1. The x-axis should display p̂_m1, ranging from 0 to 1, and the y-axis should display the value of the Gini index, classification error, and entropy. Hint: In a setting with two classes, p̂_m1 = 1 − p̂_m2. You could make this plot by hand, but it will be much easier to make in R.
# p1 values ranging from 0 to 1 and p2 based on p1
p1 <- seq(0, 1, by = 0.01)
p2 <- 1 - p1
# Gini index: G = 2 * p1 * p2 for two classes
gini <- 2 * p1 * p2
# Classification error: E = 1 - max(p1, p2)
classification_error <- 1 - pmax(p1, p2)
# Entropy: D = -p1 * log(p1) - p2 * log(p2)
# (NaN at p1 = 0 and p1 = 1, which is why ggplot drops 2 rows below)
entropy <- -p1 * log(p1) - p2 * log(p2)
#combine
q3 <- data.frame(p1, gini, classification_error, entropy)
library(ggplot2)
ggplot(q3, aes(x = p1)) +
geom_line(aes(y = gini, color = "Gini"), linewidth = 1) +
geom_line(aes(y = classification_error, color = "Classification Error"), linewidth = 1) +
geom_line(aes(y = entropy, color = "Entropy"), linewidth = 1) +
scale_color_manual(values = c("deeppink2", "deepskyblue3", "darkolivegreen4")) +
labs(x = "p1", y = "Value", color = "Measure") +
theme_minimal()
## Warning: Removed 2 rows containing missing values (`geom_line()`).
In the lab, a classification tree was applied to the Carseats data set after converting Sales into a qualitative response variable. Now we will seek to predict Sales using regression trees and related approaches, treating the response as a quantitative variable.
# 80/20 Split
set.seed(299)
index = sample(nrow(Carseats), 0.8 * nrow(Carseats), replace = FALSE)
carseat_train = Carseats[index,]
carseat_test = Carseats[-index,]
# Growing a tree
tree.carseats = tree(Sales ~ ., data = carseat_train)
summary(tree.carseats)
##
## Regression tree:
## tree(formula = Sales ~ ., data = carseat_train)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Age" "Income" "CompPrice"
## Number of terminal nodes: 15
## Residual mean deviance: 2.712 = 827.2 / 305
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -4.2100 -1.0900 -0.1153 0.0000 1.0740 5.5600
# Plot the tree
plot(tree.carseats)
text(tree.carseats, cex = 0.7)
#prediction
preds_carseats = predict(tree.carseats, newdata = carseat_test)
#compute MSE
mean((preds_carseats - carseat_test$Sales)^2)
## [1] 4.559933
The test MSE is 4.559933. Five variables were used to grow the tree: “ShelveLoc”, “Price”, “Age”, “Income”, and “CompPrice”. The diagram shows that the tree has 15 terminal nodes.
set.seed(29)
cv.carseats = cv.tree(tree.carseats)
best_size = cv.carseats$size[which.min(cv.carseats$dev)]
best_size #returns 14
## [1] 14
# Plot the estimated test error rate
plot(cv.carseats$size, cv.carseats$dev, type = "b")
prune.carseats = prune.tree(tree.carseats,best=12)
preds_carseats_pruned = predict(prune.carseats, newdata=carseat_test)
mean((preds_carseats_pruned - carseat_test$Sales)^2)
## [1] 4.43234
The test MSE is 4.43234, compared to the unpruned test MSE of 4.559933, so the pruned tree is only marginally better. I chose 12 terminal nodes based on the plot, even though cross-validation identified 14 terminal nodes as having the lowest deviance. Since the deviance at 12 is very close to the deviance at 14, I chose the smaller tree to keep the model simpler and reduce the risk of overfitting.
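To compare the deviances near sizes 12 and 14 numerically rather than only by eye, the cross-validation results can be printed as a table; a minimal sketch using the cv.carseats object above (note that cv.tree() only reports the sizes in its pruning sequence, so not every size between 1 and 15 necessarily appears):
# cross-validated deviance for each tree size visited by cv.tree()
cbind(size = cv.carseats$size, dev = cv.carseats$dev)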
plot(prune.carseats)
text(prune.carseats, cex=.7)
# bagging is a special case of a random forest with m = p; Carseats has
# p = 10 predictors, so mtry = 10 uses all of them at every split
set.seed(29)
bag.carseats = randomForest(Sales ~ ., data = carseat_train, mtry = 10, importance = TRUE)
bag.carseats
##
## Call:
## randomForest(formula = Sales ~ ., data = carseat_train, mtry = 10, importance = TRUE)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 10
##
## Mean of squared residuals: 2.484017
## % Var explained: 69.13
#importance
importance(bag.carseats)
## %IncMSE IncNodePurity
## CompPrice 37.3400885 269.96766
## Income 11.7315196 138.52454
## Advertising 22.5506682 178.44597
## Population -2.3531224 69.61192
## Price 73.9002656 747.84858
## ShelveLoc 82.8804058 784.39183
## Age 22.7964285 227.21429
## Education 3.9945178 68.81012
## Urban -2.7669306 10.09019
## US 0.2676108 11.36861
preds_carseats_bag = predict(bag.carseats, newdata=carseat_test)
mean((preds_carseats_bag - carseat_test$Sales)^2)
## [1] 2.22704
The bagged model performs substantially better than the single-tree models, with a test MSE of 2.22704. Looking at the variable importance measures, ShelveLoc and Price are the most important predictors.
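The importance table is easier to read as a plot; varImpPlot() from the randomForest package draws dotcharts of both measures for the bagged model fit above:
# dotcharts of %IncMSE and IncNodePurity for the bagged model
varImpPlot(bag.carseats)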
set.seed(29)
rf.carseats <- randomForest(Sales ~ ., data = carseat_train, importance = TRUE)
rf.carseats
##
## Call:
## randomForest(formula = Sales ~ ., data = carseat_train, importance = TRUE)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 3
##
## Mean of squared residuals: 2.909607
## % Var explained: 63.85
#importance
importance(rf.carseats)
## %IncMSE IncNodePurity
## CompPrice 16.217737 233.81378
## Income 6.261280 200.10340
## Advertising 18.007064 226.76061
## Population -3.117077 147.96513
## Price 46.326503 583.44896
## ShelveLoc 51.069175 588.26217
## Age 16.958881 285.91018
## Education 2.853488 106.69116
## Urban -2.331087 20.97769
## US 4.312647 38.84065
preds_carseats_rf = predict(rf.carseats, newdata=carseat_test)
mean((preds_carseats_rf - carseat_test$Sales)^2)
## [1] 2.601283
The random forest with the default mtry (p/3 for regression, here floor(10/3) = 3 variables per split) performed slightly worse than the bagged model: the test MSE for the bagged model was about 2.23 versus 2.60 for the random forest. The same variables are most important in this model as in the bagged model: ShelveLoc and Price.
set.seed(29)
rf.carseats2 <- randomForest(Sales ~ ., data = carseat_train, mtry=7, importance = TRUE)
preds_carseats_rf2 = predict(rf.carseats2, newdata=carseat_test)
mean((preds_carseats_rf2 - carseat_test$Sales)^2)
## [1] 2.207894
It looks like using mtry = 7 beats the bagged model's test MSE (2.208 vs. 2.227), though just barely.
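Rather than trying a single alternative, the effect of mtry can be traced across its full range; a minimal sketch that refits the forest for mtry = 1 through 10 on the same train/test split (this takes a few seconds to run):
set.seed(29)
mtry_grid <- 1:10
mse_by_mtry <- sapply(mtry_grid, function(m) {
  fit <- randomForest(Sales ~ ., data = carseat_train, mtry = m)
  mean((predict(fit, newdata = carseat_test) - carseat_test$Sales)^2)
})
# test MSE as a function of the number of variables tried at each split
plot(mtry_grid, mse_by_mtry, type = "b", xlab = "mtry", ylab = "Test MSE")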
# create matrices of predictors for BART
# note: columns 2:10 are CompPrice through Urban, so the US variable
# (column 11) is not included in the BART fit
x <- Carseats[,2:10]
y <- Carseats[,"Sales"]
xtrain <- x[index, ]
ytrain <- y[index]
xtest <- x[-index, ]
ytest <- y[-index]
set.seed(29)
bartfit_carseats <- gbart(xtrain, ytrain, x.test = xtest)
## *****Calling gbart: type=1
## *****Data:
## data:n,p,np: 320, 12, 80
## y1,yn: 4.163875, -1.856125
## x1,x[n*p]: 131.000000, 0.000000
## xp1,xp[np*p]: 138.000000, 1.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 69 ... 1
## *****burn,nd,thin: 100,1000,1
## *****Prior:beta,alpha,tau,nu,lambda,offset: 2,0.95,0.287616,3,0.200623,7.53613
## *****sigma: 1.014858
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,12,0
## *****printevery: 100
##
## MCMC
## done 0 (out of 1100)
## done 100 (out of 1100)
## done 200 (out of 1100)
## done 300 (out of 1100)
## done 400 (out of 1100)
## done 500 (out of 1100)
## done 600 (out of 1100)
## done 700 (out of 1100)
## done 800 (out of 1100)
## done 900 (out of 1100)
## done 1000 (out of 1100)
## time: 2s
## trcnt,tecnt: 1000,1000
yhat.bart.carseats <- bartfit_carseats$yhat.test.mean
mean((ytest - yhat.bart.carseats)^2)
## [1] 1.476019
# how many times each variable appears in the collection of trees
ord <- order(bartfit_carseats$varcount.mean, decreasing = T)
bartfit_carseats$varcount.mean[ord]
## Price CompPrice Age ShelveLoc2 Population Income
## 29.128 22.750 19.442 19.373 19.154 18.957
## Education ShelveLoc1 Urban2 Urban1 ShelveLoc3 Advertising
## 18.764 18.615 18.546 17.731 17.284 16.661
The test mean squared error using BART is 1.476, which is the lowest MSE of any model so far.
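For reference, the test MSEs of all the Carseats models can be collected in one table; a minimal sketch, assuming the prediction objects from the chunks above are still in the workspace:
data.frame(
  model = c("regression tree", "pruned tree (12 nodes)", "bagging (mtry = 10)",
            "random forest (mtry = 3)", "random forest (mtry = 7)", "BART"),
  test_mse = c(
    mean((preds_carseats - carseat_test$Sales)^2),
    mean((preds_carseats_pruned - carseat_test$Sales)^2),
    mean((preds_carseats_bag - carseat_test$Sales)^2),
    mean((preds_carseats_rf - carseat_test$Sales)^2),
    mean((preds_carseats_rf2 - carseat_test$Sales)^2),
    mean((ytest - yhat.bart.carseats)^2)
  )
)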
This problem involves the OJ data set which is part of the ISLR2 package.
# 80/20 Split
set.seed(299)
index = sample(nrow(OJ), 0.8 * nrow(OJ), replace = FALSE)
OJ_train = OJ[index,]
OJ_test = OJ[-index,]
# Growing a tree
set.seed(29)
tree.OJ = tree(Purchase ~ ., data = OJ_train)
summary(tree.OJ)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = OJ_train)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "ListPriceDiff"
## Number of terminal nodes: 9
## Residual mean deviance: 0.6957 = 589.2 / 847
## Misclassification error rate: 0.1519 = 130 / 856
There are 9 terminal nodes and the training misclassification error rate is 15.19%.
tree.OJ
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 856 1143.00 CH ( 0.61215 0.38785 )
## 2) LoyalCH < 0.48285 331 356.70 MM ( 0.22961 0.77039 )
## 4) LoyalCH < 0.276142 186 122.70 MM ( 0.10215 0.89785 )
## 8) LoyalCH < 0.051325 69 10.45 MM ( 0.01449 0.98551 ) *
## 9) LoyalCH > 0.051325 117 100.50 MM ( 0.15385 0.84615 ) *
## 5) LoyalCH > 0.276142 145 194.30 MM ( 0.39310 0.60690 )
## 10) PriceDiff < 0.05 61 63.20 MM ( 0.21311 0.78689 ) *
## 11) PriceDiff > 0.05 84 116.30 CH ( 0.52381 0.47619 ) *
## 3) LoyalCH > 0.48285 525 437.70 CH ( 0.85333 0.14667 )
## 6) LoyalCH < 0.764572 242 281.60 CH ( 0.73140 0.26860 )
## 12) PriceDiff < 0.215 114 157.70 CH ( 0.52632 0.47368 )
## 24) PriceDiff < -0.165 32 33.62 MM ( 0.21875 0.78125 )
## 48) ListPriceDiff < 0.115 13 17.94 CH ( 0.53846 0.46154 ) *
## 49) ListPriceDiff > 0.115 19 0.00 MM ( 0.00000 1.00000 ) *
## 25) PriceDiff > -0.165 82 106.50 CH ( 0.64634 0.35366 ) *
## 13) PriceDiff > 0.215 128 75.02 CH ( 0.91406 0.08594 ) *
## 7) LoyalCH > 0.764572 283 99.34 CH ( 0.95760 0.04240 ) *
For terminal node 8: this node contains the subset of observations with LoyalCH < 0.051325 (reached after the earlier splits at 0.48285 and 0.276142). There are 69 observations in this node and the deviance is 10.45. The predicted class is MM; within this node, about 98.6% of the observations purchased Minute Maid and about 1.4% purchased Citrus Hill.
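The same figures can be pulled programmatically from the tree object's frame component, whose row names are the node numbers:
# n, deviance, fitted class, and class probabilities for terminal node 8
tree.OJ$frame["8", ]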
plot(tree.OJ)
text(tree.OJ, cex=.7)
Looking at the diagram, we see that LoyalCH divides the tree into two diverging branches. Most predictions on the left, with low LoyalCH, are MM, and most predictions on the right, with high LoyalCH, are CH. PriceDiff and ListPriceDiff also influence the predictions.
preds_OJ = predict(tree.OJ, newdata=OJ_test,type="class")
table(preds_OJ, OJ_test$Purchase)
##
## preds_OJ CH MM
## CH 115 36
## MM 14 49
1 - mean(preds_OJ == OJ_test$Purchase)
## [1] 0.2336449
The test error rate is approximately 23.4%.
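This matches what the confusion matrix gives directly: 36 + 14 = 50 misclassified test observations out of 214.
# test error rate recomputed from the confusion matrix counts
(36 + 14) / (115 + 36 + 14 + 49)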
set.seed(29)
cv.OJ = cv.tree(tree.OJ)
best_size = cv.OJ$size[which.min(cv.OJ$dev)]
best_size #returns 8
## [1] 8
plot(cv.OJ$size, cv.OJ$dev, type = "b")
(h) Which tree size corresponds to the lowest cross-validated classification error rate? Size 8 has the lowest cross-validated deviance (note that cv.tree() was run here with its default deviance criterion rather than the misclassification rate). However, looking at the plot, a tree of size 5 appears to perform almost as well.
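To cross-validate on the misclassification rate itself rather than on deviance, cv.tree() can be rerun with FUN = prune.misclass; a minimal sketch (its dev component then reports the number of cross-validated misclassifications):
set.seed(29)
cv.OJ.misclass <- cv.tree(tree.OJ, FUN = prune.misclass)
plot(cv.OJ.misclass$size, cv.OJ.misclass$dev, type = "b",
     xlab = "Tree size", ylab = "CV misclassifications")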
set.seed(29)
prune.OJ = prune.tree(tree.OJ,best=5)
preds_OJ_pruned = predict(prune.OJ, newdata=OJ_test, type="class")
table(preds_OJ_pruned, OJ_test$Purchase)
##
## preds_OJ_pruned CH MM
## CH 111 33
## MM 18 52
1 - mean(preds_OJ_pruned == OJ_test$Purchase)
## [1] 0.2383178
summary(tree.OJ)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = OJ_train)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "ListPriceDiff"
## Number of terminal nodes: 9
## Residual mean deviance: 0.6957 = 589.2 / 847
## Misclassification error rate: 0.1519 = 130 / 856
summary(prune.OJ)
##
## Classification tree:
## snip.tree(tree = tree.OJ, nodes = c(4L, 5L, 12L))
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff"
## Number of terminal nodes: 5
## Residual mean deviance: 0.7627 = 649.1 / 851
## Misclassification error rate: 0.1787 = 153 / 856
The pruned tree has a higher training error rate (17.87% compared to 15.19%).
1 - mean(preds_OJ == OJ_test$Purchase)
## [1] 0.2336449
1 - mean(preds_OJ_pruned == OJ_test$Purchase)
## [1] 0.2383178
The test error rate is slightly higher for the pruned tree, though they are close.
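To summarize the comparison, a minimal sketch that tabulates training and test error rates for the unpruned and pruned trees, reusing the fitted trees and test predictions from above:
data.frame(
  tree = c("unpruned (9 nodes)", "pruned (5 nodes)"),
  train_error = c(1 - mean(predict(tree.OJ, newdata = OJ_train, type = "class") == OJ_train$Purchase),
                  1 - mean(predict(prune.OJ, newdata = OJ_train, type = "class") == OJ_train$Purchase)),
  test_error = c(1 - mean(preds_OJ == OJ_test$Purchase),
                 1 - mean(preds_OJ_pruned == OJ_test$Purchase))
)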