Load Libraries
library(ISLR2)
library(rpart)
library(randomForest)
## randomForest 4.7-1.2
## Type rfNews() to see new features/changes/bug fixes.
library(gbm)
## Loaded gbm 2.2.2
## This version of gbm is no longer under development. Consider transitioning to gbm3, https://github.com/gbm-developers/gbm3
library(BART)
## Loading required package: nlme
## Loading required package: survival
library(caret)
## Loading required package: ggplot2
##
## Attaching package: 'ggplot2'
## The following object is masked from 'package:randomForest':
##
## margin
## Loading required package: lattice
##
## Attaching package: 'caret'
## The following object is masked from 'package:survival':
##
## cluster
library(rattle)
## Loading required package: tibble
## Loading required package: bitops
## Rattle: A free graphical interface for data science with R.
## Version 5.5.1 Copyright (c) 2006-2021 Togaware Pty Ltd.
## Type 'rattle()' to shake, rattle, and roll your data.
##
## Attaching package: 'rattle'
## The following object is masked from 'package:randomForest':
##
## importance
library(tree)
Consider the Gini index, classification error, and entropy in a simple classification setting with two classes. Create a single plot that displays each of these quantities as a function of $\hat{p}_{m1}$. The x-axis should display $\hat{p}_{m1}$, ranging from 0 to 1, and the y-axis should display the value of the Gini index, classification error, and entropy.
Hint: In a setting with two classes, $\hat{p}_{m1} = 1 - \hat{p}_{m2}$. You could make this plot by hand, but it will be much easier to make in R.
# Sequence of values for pm1 (the proportion of class 1 in the node);
# with two classes, pm2 is simply its complement.
pm1 <- seq(0, 1, 0.01)
pm2 <- 1 - pm1
# Calculate impurity measures
gini <- 1 - pm1^2 - pm2^2
classification_error <- 1 - pmax(pm1, pm2)
# Both terms must use the same log base; base 2 makes entropy peak at 1
# when pm1 = 0.5, which fits the shared y-axis range of [0, 1].
entropy <- -(pm1 * log2(pm1) + pm2 * log2(pm2))
# 0 * log2(0) evaluates to NaN in R; by convention that term is 0.
entropy[is.nan(entropy)] <- 0
plot(pm1, gini, type = 'l', col = 'blue', lwd = 2,
     ylim = c(0, 1),
     xlab = expression(hat(p)[m1]), ylab = 'Impurity')
lines(pm1, classification_error, col = 'red', lwd = 2)
lines(pm1, entropy, col = 'darkgreen', lwd = 2)
legend('topright', legend = c('Gini', 'Classification Error', 'Entropy'),
       col = c('blue', 'red', 'darkgreen'), lwd = 2, cex = 0.8)
In the lab, a classification tree was applied to the Carseats data set after converting Sales into a qualitative response variable. Now we will seek to predict Sales using regression trees and related approaches, treating the response as a quantitative variable.
# Load Carseats and hold out part of the rows for testing (p = 0.7 for training).
data(Carseats)
set.seed(42)
train_index <- createDataPartition(
  y = Carseats$Sales,
  p = 0.7,
  list = FALSE
)
train_df <- Carseats[train_index, ]
test_df <- Carseats[-train_index, ]
# Fit a regression tree with Sales as the quantitative response.
tree_model <- rpart(
  formula = Sales ~ .,
  data = train_df,
  method = "anova"
)
# Draw the fitted tree.
rattle::fancyRpartPlot(tree_model, sub = " ", cex = 0.6)
Based on our plot of the regression tree, we can see that
ShelveLoc = Bad, Medium is the first variable our data is
split on. This indicates that it's one of the most important predictors of
Sales in our model.
It then uses Price, and keeps splitting on several
different variables after that; that first node divides our data
into splits of 79% and 21%.
In the end, our data is split into 17 bins (terminal nodes), each of which uses the average Sales of its training observations as its prediction.
# Evaluate the unpruned regression tree on the held-out rows.
tree_model_preds <- predict(tree_model, newdata = test_df)
tree_sq_err <- (tree_model_preds - test_df$Sales)^2
tree_mse <- mean(tree_sq_err)
cat("Regression Tree MSE: ", tree_mse)
## Regression Tree MSE: 4.045978
# Print the complexity-parameter table (CP, nsplit, rel error, xerror, xstd).
printcp(tree_model)
##
## Regression tree:
## rpart(formula = Sales ~ ., data = train_df, method = "anova")
##
## Variables actually used in tree construction:
## [1] Advertising Age CompPrice Education Population Price
## [7] ShelveLoc
##
## Root node error: 2328.7/281 = 8.2872
##
## n= 281
##
## CP nsplit rel error xerror xstd
## 1 0.275002 0 1.00000 1.00608 0.083098
## 2 0.103378 1 0.72500 0.73540 0.059608
## 3 0.050273 2 0.62162 0.66093 0.054804
## 4 0.050228 3 0.57135 0.69899 0.056246
## 5 0.034056 4 0.52112 0.66312 0.055345
## 6 0.030912 5 0.48706 0.62839 0.051791
## 7 0.023292 7 0.42524 0.64131 0.053724
## 8 0.023285 8 0.40195 0.64856 0.054364
## 9 0.018886 9 0.37866 0.65210 0.054510
## 10 0.013324 10 0.35978 0.64510 0.052367
## 11 0.013317 11 0.34645 0.67176 0.052992
## 12 0.013276 12 0.33314 0.67176 0.052992
## 13 0.012766 13 0.31986 0.65902 0.052634
## 14 0.011850 14 0.30709 0.64886 0.051256
## 15 0.010261 15 0.29524 0.65041 0.051087
## 16 0.010000 16 0.28498 0.64784 0.051131
# Plot the cross-validation results for the tree's complexity-parameter table.
plotcp(tree_model)
At nsplit = 5, our model's xerror reaches its lowest value in the table
(0.62839), which is a good indicator of where to prune the tree.
# Prune at the complexity parameter with the smallest cross-validated error.
best_cp_row <- which.min(tree_model$cptable[, 'xerror'])
pruned_tree <- prune(tree_model, cp = tree_model$cptable[best_cp_row, 'CP'])
rattle::fancyRpartPlot(pruned_tree, sub = ' ', cex = 0.7)
# Score the pruned tree on the same held-out rows as the unpruned tree.
pruned_tree_preds <- predict(pruned_tree, newdata = test_df)
pruned_tree_mse <- mean((pruned_tree_preds - test_df$Sales)^2)
# Label the output as the pruned tree so it cannot be confused with the
# unpruned tree's MSE, which is printed under "Regression Tree MSE".
cat('Pruned Regression Tree MSE: ', pruned_tree_mse)
## Pruned Regression Tree MSE: 4.497167
The test MSE gets worse with pruning (4.50 versus 4.05 for the unpruned tree).
# Bagging: a random forest in which every predictor is a candidate at each split.
set.seed(42)
n_predictors <- ncol(train_df) - 1  # every column except the response, Sales
bag_model <- randomForest(
  Sales ~ .,
  data = train_df,
  mtry = n_predictors,
  importance = TRUE
)
# Test-set mean squared error of the bagged model.
bag_model_preds <- predict(bag_model, newdata = test_df)
bag_model_mse <- mean((bag_model_preds - test_df$Sales)^2)
cat("Bag Model MSE: ", bag_model_mse)
## Bag Model MSE: 2.063321
Bagging is performed here because mtry equals the number of predictors
(every column except Sales). This means that all predictors are
considered at each split of each tree, which is exactly what bagging
is.
# Plot the variable-importance measures stored in the bagged model.
varImpPlot(bag_model)
ShelveLoc and Price are the 2 most
important variables.
# Random forest: 8 candidate predictors are tried at each split.
set.seed(42)
rf_model <- randomForest(
  Sales ~ .,
  data = train_df,
  mtry = 8,
  importance = TRUE
)
# Test-set mean squared error of the random forest.
rf_model_preds <- predict(rf_model, newdata = test_df)
rf_sq_err <- (rf_model_preds - test_df$Sales)^2
rf_model_mse <- mean(rf_sq_err)
cat("RF model MSE: ", rf_model_mse)
## RF model MSE: 2.037389
# Plot the variable-importance measures stored in the random forest.
varImpPlot(rf_model)
Same as with our bagging model. ShelveLoc and
Price are our 2 most important variables. Then it is
CompPrice, Advertising, Age,
Income….
In this case, as we increased mtry from 4 to 8 in steps of 1, the test MSE kept decreasing. This indicates that by increasing the number of features considered at each split, the model was able to make better predictions on the test set. A higher mtry reduces the randomness of the forest, since each split gets to choose from more candidate features.
# Predictor data for BART: every column except the response, Sales.
# A logical mask is used instead of -which(...), because -which(...) would
# silently drop every column if the name were ever absent.
xtrain <- train_df[, names(train_df) != 'Sales']
xtest <- test_df[, names(test_df) != 'Sales']
# Fit BART on the training rows and predict the test rows in the same call.
set.seed(42)
bart_model <- gbart(x.train = xtrain, y.train = train_df$Sales,
                    x.test = xtest)
## *****Calling gbart: type=1
## *****Data:
## data:n,p,np: 281, 14, 119
## y1,yn: 3.685623, 2.175623
## x1,x[n*p]: 111.000000, 1.000000
## xp1,xp[np*p]: 138.000000, 1.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 68 ... 1
## *****burn,nd,thin: 100,1000,1
## *****Prior:beta,alpha,tau,nu,lambda,offset: 2,0.95,0.287616,3,0.196244,7.53438
## *****sigma: 1.003722
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,14,0
## *****printevery: 100
##
## MCMC
## done 0 (out of 1100)
## done 100 (out of 1100)
## done 200 (out of 1100)
## done 300 (out of 1100)
## done 400 (out of 1100)
## done 500 (out of 1100)
## done 600 (out of 1100)
## done 700 (out of 1100)
## done 800 (out of 1100)
## done 900 (out of 1100)
## done 1000 (out of 1100)
## time: 3s
## trcnt,tecnt: 1000,1000
# Test-set mean squared error, using the test predictions stored in the fit.
bart_model_preds <- bart_model$yhat.test.mean
bart_sq_err <- (bart_model_preds - test_df$Sales)^2
bart_mse <- mean(bart_sq_err)
cat("Bart Model MSE: ", bart_mse)
## Bart Model MSE: 1.472166
# Rank predictors by varcount.mean, largest first.
# TRUE is spelled out because T is an ordinary variable that can be reassigned.
bart_model$varcount.mean[order(bart_model$varcount.mean, decreasing = TRUE)]
## Price CompPrice ShelveLoc2 Age ShelveLoc1 US2
## 26.682 19.330 17.565 16.925 16.484 16.134
## US1 Income Urban1 Advertising Urban2 Population
## 16.086 15.923 15.877 15.653 15.279 15.245
## Education ShelveLoc3
## 14.980 14.609
We can see the rank of the variables by their mean usage count across
all trees. Price is used most often, which suggests it is a strong driver of
Sales.
This problem involves the OJ data set which is part of the ISLR2 package.
# Load the OJ data and split it; per the summary output, this split puts
# 800 observations in the training set.
data(OJ)
set.seed(42)
train_index <- createDataPartition(
  y = OJ$Purchase,
  p = 799 / nrow(OJ),
  list = FALSE
)
train_data <- OJ[train_index, ]
test_data <- OJ[-train_index, ]
# Fit a classification tree with Purchase as the response.
oj_tree <- tree(formula = Purchase ~ ., data = train_data)
summary(oj_tree)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = train_data)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "ListPriceDiff" "DiscMM"
## Number of terminal nodes: 8
## Residual mean deviance: 0.7567 = 599.3 / 792
## Misclassification error rate: 0.1675 = 134 / 800
Training Error rate: 0.1675
Terminal Nodes: 8
# Print the fitted tree as text: one line per node with split, n, deviance,
# predicted class and class probabilities.
oj_tree
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 800 1070.00 CH ( 0.61000 0.39000 )
## 2) LoyalCH < 0.5036 348 413.70 MM ( 0.28161 0.71839 )
## 4) LoyalCH < 0.0616725 63 0.00 MM ( 0.00000 1.00000 ) *
## 5) LoyalCH > 0.0616725 285 366.80 MM ( 0.34386 0.65614 )
## 10) PriceDiff < 0.065 108 100.50 MM ( 0.17593 0.82407 ) *
## 11) PriceDiff > 0.065 177 243.30 MM ( 0.44633 0.55367 )
## 22) LoyalCH < 0.279374 60 65.19 MM ( 0.23333 0.76667 ) *
## 23) LoyalCH > 0.279374 117 160.70 CH ( 0.55556 0.44444 ) *
## 3) LoyalCH > 0.5036 452 361.40 CH ( 0.86283 0.13717 )
## 6) LoyalCH < 0.764572 192 220.20 CH ( 0.73958 0.26042 )
## 12) ListPriceDiff < 0.235 80 110.70 CH ( 0.52500 0.47500 )
## 24) DiscMM < 0.1 45 55.80 CH ( 0.68889 0.31111 ) *
## 25) DiscMM > 0.1 35 43.57 MM ( 0.31429 0.68571 ) *
## 13) ListPriceDiff > 0.235 112 76.27 CH ( 0.89286 0.10714 ) *
## 7) LoyalCH > 0.764572 260 97.26 CH ( 0.95385 0.04615 ) *
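Node 3 (note: this is an internal node, not a leaf; terminal nodes are the ones marked with * in the output):
If LoyalCH > 0.5036, the tree predicts CH, which is the class of the
majority of the 452 observations in this node (86%).
It has a deviance of 361.40. The deviance is still fairly high because this first split does not perfectly separate the classes. For an actual terminal node, see node 7: when LoyalCH > 0.764572 the tree predicts CH; the node holds 260 observations, has a deviance of 97.26, and 95.4% of its observations are CH.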
# Draw the tree, then add the text labels on top of the plot.
plot(oj_tree)
text(oj_tree)
It appears as though most of the data is split on the
LoyalCH variable. The first 2 levels of splits are all on
that variable. It's likely that it is the most important variable, with
PriceDiff, ListPriceDiff and DiscMM then used to
further split the data. LoyalCH is used to split the data 4
times out of the 7 splits.
# Classify the held-out purchases and compare them with the true labels.
oj_preds <- predict(oj_tree, newdata = test_data, type = "class")
confusionMatrix(data = as.factor(oj_preds), reference = test_data$Purchase)
## Confusion Matrix and Statistics
##
## Reference
## Prediction CH MM
## CH 144 20
## MM 21 85
##
## Accuracy : 0.8481
## 95% CI : (0.7997, 0.8888)
## No Information Rate : 0.6111
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.6811
##
## Mcnemar's Test P-Value : 1
##
## Sensitivity : 0.8727
## Specificity : 0.8095
## Pos Pred Value : 0.8780
## Neg Pred Value : 0.8019
## Prevalence : 0.6111
## Detection Rate : 0.5333
## Detection Prevalence : 0.6074
## Balanced Accuracy : 0.8411
##
## 'Positive' Class : CH
##
# Test error rate: the share of held-out observations that were misclassified.
oj_misclassified <- oj_preds != test_data$Purchase
test_error <- mean(oj_misclassified)
cat("OJ model test error: ", test_error)
## OJ model test error: 0.1518519
# Cross-validate the tree, with misclassification-based pruning.
cv_oj_tree <- cv.tree(oj_tree, FUN = prune.misclass)
# Save the current graphics settings so they are restored exactly, rather than
# assuming the layout was 1 x 1 beforehand.
old_par <- par(mfrow = c(1, 2))
# Explicit axis labels replace the default raw-expression labels.
plot(cv_oj_tree$size, cv_oj_tree$dev, type = 'b',
     xlab = 'Tree size (terminal nodes)', ylab = 'CV deviance')
plot(cv_oj_tree$k, cv_oj_tree$dev, type = 'b',
     xlab = 'Cost-complexity parameter k', ylab = 'CV deviance')
par(old_par)
Tree size of 8 corresponds to the lowest cross-validated classification error rate.
# Tree size with the smallest cross-validated deviance (first one on ties).
min_dev_idx <- which.min(cv_oj_tree$dev)
best_size <- cv_oj_tree$size[min_dev_idx]
cat("Optimal Tree Size: ", best_size)
## Optimal Tree Size: 8
We can’t produce a pruned tree with 5 terminal nodes because the pruning sequence evaluated by cross-validation only contains trees with 8, 2, or 1 terminal nodes.
# Show the tree sizes that were evaluated during cross-validation.
cv_oj_tree$size
## [1] 8 2 1
# Pruning to best_size would just reproduce the unpruned tree:
# pruned_tree_oj <- prune.misclass(oj_tree, best = best_size)
# So prune down to 2 terminal nodes instead.
pruned_tree_oj <- prune.misclass(oj_tree, best = 2)
plot(pruned_tree_oj)
text(pruned_tree_oj)
# Training error rates for the unpruned and the pruned tree.
oj_tree_train_preds <- predict(oj_tree, newdata = train_data, type = "class")
pruned_tree_oj_train_preds <- predict(pruned_tree_oj, newdata = train_data,
                                      type = "class")
oj_tree_error_rate <- mean(oj_tree_train_preds != train_data$Purchase)
pruned_tree_oj_error_rate <- mean(
  pruned_tree_oj_train_preds != train_data$Purchase
)
cat("Training error rate of pruned tree: ", pruned_tree_oj_error_rate, "\n")
## Training error rate of pruned tree: 0.2
cat("Training error rate of unpruned tree: ", oj_tree_error_rate)
## Training error rate of unpruned tree: 0.1675
The pruned tree’s error rate is higher.
# Test-set error of the pruned tree, reported next to the unpruned tree's.
pruned_pred <- predict(pruned_tree_oj, newdata = test_data, type = 'class')
pruned_misclassified <- pruned_pred != test_data$Purchase
pruned_test_error <- mean(pruned_misclassified)
cat('Unpruned Test Error:', test_error, '\n')
## Unpruned Test Error: 0.1518519
cat('Pruned Test Error:', pruned_test_error, '\n')
## Pruned Test Error: 0.2
Again, the Pruned tree error rate is higher.