# Node impurity measures for a two-class problem, as a function of the
# proportion p of observations in the first class
p <- seq(0, 1, 0.01)
gini <- 2 * p * (1 - p)
classerror <- 1 - pmax(p, 1 - p)
crossentropy <- -(p * log(p) + (1 - p) * log(1 - p))
plot(NA, NA, xlim = c(0, 1), ylim = c(0, 1), xlab = 'p', ylab = 'f')
lines(p, gini, col = '#4daf4a')
lines(p, classerror, col = '#377eb8')
lines(p, crossentropy, col = '#e41a1c')
legend(x = 'top', legend = c('Gini', 'Classification Error', 'Entropy'),
       col = c('#4daf4a', '#377eb8', '#e41a1c'), lty = 1, text.width = 0.25)
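The plot compares the three impurity measures for a two-class node as a function of the proportion p of observations in the first class: the Gini index 2p(1 - p), the classification error 1 - max(p, 1 - p), and the (cross-)entropy -p log(p) - (1 - p) log(1 - p). All three peak at p = 0.5, but the Gini index and entropy are smooth and more sensitive to changes in node purity, which is why they are preferred for growing a tree.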
In the lab, a classification tree was applied to the Carseats data set after converting Sales into a qualitative response variable. Now we will seek to predict Sales using regression trees and related approaches, treating the response as a quantitative variable.
library(ISLR2)
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.1.4 ✔ readr 2.1.5
## ✔ forcats 1.0.0 ✔ stringr 1.5.1
## ✔ ggplot2 3.5.1 ✔ tibble 3.2.1
## ✔ lubridate 1.9.4 ✔ tidyr 1.3.1
## ✔ purrr 1.0.2
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(MASS) # Boston data
##
## Attaching package: 'MASS'
##
## The following object is masked from 'package:dplyr':
##
## select
##
## The following object is masked from 'package:ISLR2':
##
## Boston
library(randomForest) # random forests
## randomForest 4.7-1.2
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
##
## The following object is masked from 'package:dplyr':
##
## combine
##
## The following object is masked from 'package:ggplot2':
##
## margin
library(tree) # trees
library(caret)
## Loading required package: lattice
##
## Attaching package: 'caret'
##
## The following object is masked from 'package:purrr':
##
## lift
library(gridExtra)
##
## Attaching package: 'gridExtra'
##
## The following object is masked from 'package:randomForest':
##
## combine
##
## The following object is masked from 'package:dplyr':
##
## combine
library(gam)
## Loading required package: splines
## Loading required package: foreach
##
## Attaching package: 'foreach'
##
## The following objects are masked from 'package:purrr':
##
## accumulate, when
##
## Loaded gam 1.22-5
attach(Carseats)
str(Carseats)
## 'data.frame': 400 obs. of 11 variables:
## $ Sales : num 9.5 11.22 10.06 7.4 4.15 ...
## $ CompPrice : num 138 111 113 117 141 124 115 136 132 132 ...
## $ Income : num 73 48 35 100 64 113 105 81 110 113 ...
## $ Advertising: num 11 16 10 4 3 13 0 15 0 0 ...
## $ Population : num 276 260 269 466 340 501 45 425 108 131 ...
## $ Price : num 120 83 80 97 128 72 108 120 124 124 ...
## $ ShelveLoc : Factor w/ 3 levels "Bad","Good","Medium": 1 2 3 3 1 1 3 2 3 3 ...
## $ Age : num 42 65 59 55 38 78 71 67 76 76 ...
## $ Education : num 17 10 12 14 13 16 15 10 10 17 ...
## $ Urban : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 1 2 2 1 1 ...
## $ US : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 1 2 1 2 ...
set.seed(2025)
# 80/20 split of Carseats, stratified on the response by createDataPartition()
train.Index <- createDataPartition(Sales, p = 0.8, list = FALSE)
train <- Carseats[train.Index, ]
test <- Carseats[-train.Index, ]
tree.fit <- tree(Sales ~ ., data = train)
summary(tree.fit)
##
## Regression tree:
## tree(formula = Sales ~ ., data = train)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Age" "Income" "CompPrice"
## [6] "Advertising"
## Number of terminal nodes: 19
## Residual mean deviance: 2.643 = 798.1 / 302
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -4.3790 -1.1670 0.0775 0.0000 1.2050 4.3010
plot(tree.fit)
text(tree.fit, pretty = 0, cex = 0.55)
tree.pred <- predict(tree.fit, newdata = test)
(mse <- mean((test$Sales - tree.pred) ^2))
## [1] 3.841789
The most important variables in the tree are ShelveLoc and Price, which appear in the top splits. The test MSE is approximately 3.84.
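One way to check this without reading the plot is to inspect the frame component of the fitted tree object, which lists the splitting variable, node size, and deviance for each node, starting at the root (a small sketch using the tree already fitted above):
# First few nodes of the fitted tree: 'var' is the splitting variable
# ("<leaf>" for terminal nodes), 'n' the node size, 'dev' the node deviance
head(tree.fit$frame[, c("var", "n", "dev", "yval")], 10)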
set.seed(2025)
cv_tree_model <- cv.tree(tree.fit, K = 10)
data.frame(n_leaves = cv_tree_model$size, CV_RSS = cv_tree_model$dev) %>%
mutate(min_CV_RSS = as.numeric(min(CV_RSS) == CV_RSS)) %>%
ggplot(aes(x = n_leaves, y = CV_RSS)) +
geom_line(col = "grey55") +
geom_point(size = 2, aes(col = factor(min_CV_RSS))) +
scale_x_continuous(breaks = seq(1, 17, 2)) +
scale_y_continuous(labels = scales::comma_format()) +
scale_color_manual(values = c("deepskyblue3", "green")) +
theme(legend.position = "none") +
labs(title = "Carseats Dataset - Regression Tree",
subtitle = "Selecting the complexity parameter with cross-validation",
x = "Terminal Nodes",
y = "CV RSS")
From the plot, we can see that cross-validation selects a subtree with 13 terminal nodes, somewhat smaller than the fully grown tree with 19. We verify that below.
which.min(cv_tree_model$dev)
## [1] 6
cv_tree_model$size[6]
## [1] 13
Now we check how the test MSE changes when we prune the tree to 13 terminal nodes by specifying best = 13.
prune.model = prune.tree(tree.fit, best = 13)
prune.pred <- predict(prune.model, test)
mean((prune.pred - test$Sales)^2)
## [1] 3.841629
The test MSE of the pruned tree (3.8416) is essentially identical to that of the unpruned tree (3.8418), so pruning to 13 terminal nodes simplifies the tree without any loss in test performance.
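As a quick sanity check (a sketch using the objects already created above), we can prune the tree to every candidate size returned by cv.tree() and compare the resulting test MSEs directly:
# Test MSE for each candidate subtree size from the cross-validation run (sketch)
sizes <- cv_tree_model$size[cv_tree_model$size > 1]   # skip the root-only tree (size 1)
prune_mse <- sapply(sizes, function(s) {
  pruned <- prune.tree(tree.fit, best = s)
  mean((predict(pruned, newdata = test) - test$Sales)^2)
})
setNames(round(prune_mse, 3), sizes)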
set.seed(2025)
# mtry = 10 considers all 10 predictors at every split, i.e. bagging
rf.model <- randomForest(Sales ~ ., data = train, mtry = 10, ntree = 500, importance = T)
rf.model
##
## Call:
## randomForest(formula = Sales ~ ., data = train, mtry = 10, ntree = 500, importance = T)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 10
##
## Mean of squared residuals: 2.530131
## % Var explained: 69.35
rf.pred <- predict(rf.model, test)
(rf.mse = mean((test$Sales - rf.pred)^2))
## [1] 1.644367
importance(rf.model) %>%
  as.data.frame() %>%
  rownames_to_column("varname") %>%
  arrange(desc(IncNodePurity))
## varname %IncMSE IncNodePurity
## 1 ShelveLoc 79.7926818 816.425095
## 2 Price 69.9106181 737.900601
## 3 CompPrice 28.2415076 258.895130
## 4 Age 24.0131573 255.103750
## 5 Income 13.5697669 164.714906
## 6 Advertising 17.8783436 163.631355
## 7 Population -1.8123672 84.026687
## 8 Education -0.1807842 73.647902
## 9 US 6.4733304 24.361600
## 10 Urban 0.3599708 9.510284
The test MSE obtained with bagging (mtry = 10, so all predictors are considered at every split) is 1.644, much lower than the 3.84 obtained with a single regression tree. The most important variables are ShelveLoc, Price, and CompPrice.
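As a quick check that 500 trees are enough, we can plot the out-of-bag MSE as a function of the number of trees; for regression, randomForest stores this sequence in the mse component of the fitted object (a small sketch).
# OOB MSE after each additional tree in the bagged model
plot(rf.model$mse, type = "l",
     xlab = "Number of trees", ylab = "OOB MSE")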
# Fit a random forest for each value of mtry and record its test MSE
test_MSE <- numeric(10)
for (Mtry in 1:10) {
  set.seed(2025)
  rf_temp <- randomForest(Sales ~ ., data = train, mtry = Mtry, importance = T)
  test_pred <- predict(rf_temp, test)
  test_MSE[Mtry] <- mean((test_pred - test$Sales)^2)
}
data.frame(mtry = 1:10, test_MSE = test_MSE) %>%
mutate(min_test_MSE = as.numeric(min(test_MSE) == test_MSE)) %>%
ggplot(aes(x = mtry, y = test_MSE)) +
geom_line(col = "grey55") +
geom_point(size = 2, aes(col = factor(min_test_MSE))) +
scale_x_continuous(breaks = seq(1, 10), minor_breaks = NULL) +
scale_color_manual(values = c("deepskyblue3", "green")) +
theme(legend.position = "none") +
labs(title = "Carseats Dataset - Random Forests",
subtitle = "Selecting 'mtry' using the test MSE",
x = "mtry",
y = "Test MSE")
By tuning the mtry parameter, we find that the best random forest model uses mtry = 9.
test_MSE[9]
## [1] 1.592166
The corresponding test MSE is about 1.592.
# Refit the tuned model (mtry = 9) to inspect variable importance
rf.model.1 <- randomForest(Sales ~ ., data = train, mtry = 9, importance = T)
importance(rf.model.1) %>%
  as.data.frame() %>%
  rownames_to_column("varname") %>%
  arrange(desc(IncNodePurity))
## varname %IncMSE IncNodePurity
## 1 ShelveLoc 81.6223752 819.459494
## 2 Price 66.1708476 738.717960
## 3 CompPrice 29.7620325 257.958231
## 4 Age 22.7082569 254.026790
## 5 Advertising 19.3180729 171.403524
## 6 Income 12.8022062 167.781751
## 7 Population 1.1874768 92.229784
## 8 Education -1.0015192 71.743042
## 9 US 4.9008990 24.928679
## 10 Urban -0.1782073 9.674885
The test MSE decreased after tuning the mtry parameter of the random forest model, with the best performance achieved at mtry = 9. The most important predictors remain ShelveLoc and Price.
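The same ranking can be visualised with varImpPlot(), which shows the two importance measures (%IncMSE and IncNodePurity) side by side for the tuned model:
# Dotcharts of the two variable-importance measures for the mtry = 9 model
varImpPlot(rf.model.1, main = "Carseats: variable importance (mtry = 9)")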
detach(Carseats)
attach(OJ)
str(OJ)
## 'data.frame': 1070 obs. of 18 variables:
## $ Purchase : Factor w/ 2 levels "CH","MM": 1 1 1 2 1 1 1 1 1 1 ...
## $ WeekofPurchase: num 237 239 245 227 228 230 232 234 235 238 ...
## $ StoreID : num 1 1 1 1 7 7 7 7 7 7 ...
## $ PriceCH : num 1.75 1.75 1.86 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
## $ PriceMM : num 1.99 1.99 2.09 1.69 1.69 1.99 1.99 1.99 1.99 1.99 ...
## $ DiscCH : num 0 0 0.17 0 0 0 0 0 0 0 ...
## $ DiscMM : num 0 0.3 0 0 0 0 0.4 0.4 0.4 0.4 ...
## $ SpecialCH : num 0 0 0 0 0 0 1 1 0 0 ...
## $ SpecialMM : num 0 1 0 0 0 1 1 0 0 0 ...
## $ LoyalCH : num 0.5 0.6 0.68 0.4 0.957 ...
## $ SalePriceMM : num 1.99 1.69 2.09 1.69 1.69 1.99 1.59 1.59 1.59 1.59 ...
## $ SalePriceCH : num 1.75 1.75 1.69 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
## $ PriceDiff : num 0.24 -0.06 0.4 0 0 0.3 -0.1 -0.16 -0.16 -0.16 ...
## $ Store7 : Factor w/ 2 levels "No","Yes": 1 1 1 1 2 2 2 2 2 2 ...
## $ PctDiscMM : num 0 0.151 0 0 0 ...
## $ PctDiscCH : num 0 0 0.0914 0 0 ...
## $ ListPriceDiff : num 0.24 0.24 0.23 0 0 0.3 0.3 0.24 0.24 0.24 ...
## $ STORE : num 1 1 1 1 0 0 0 0 0 0 ...
set.seed(2025)
train.Index <- sample(nrow(OJ), 800)
train.OJ <- OJ[train.Index,]
test.OJ <- OJ[-train.Index,]
tree.fit.OJ <- tree(Purchase ~ ., data = train.OJ)
summary(tree.fit.OJ)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = train.OJ)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "PriceMM" "ListPriceDiff"
## [5] "PctDiscMM"
## Number of terminal nodes: 8
## Residual mean deviance: 0.7291 = 577.4 / 792
## Misclassification error rate: 0.1688 = 135 / 800
The tree uses only five variables: LoyalCH, PriceDiff, PriceMM, ListPriceDiff, and PctDiscMM. The training error rate is 135 / 800 = 0.1688, and the tree has 8 terminal nodes.
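The list of variables actually used in the splits can also be pulled directly from the fitted tree object rather than read off the summary (a small sketch):
# Splitting variables appearing in the tree; "<leaf>" marks terminal nodes
setdiff(unique(as.character(tree.fit.OJ$frame$var)), "<leaf>")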
tree.fit.OJ
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 800 1062.000 CH ( 0.62125 0.37875 )
## 2) LoyalCH < 0.48285 296 328.400 MM ( 0.24324 0.75676 )
## 4) LoyalCH < 0.136698 94 33.080 MM ( 0.04255 0.95745 ) *
## 5) LoyalCH > 0.136698 202 258.100 MM ( 0.33663 0.66337 )
## 10) PriceDiff < 0.05 88 80.360 MM ( 0.17045 0.82955 )
## 20) PriceMM < 2.11 63 69.160 MM ( 0.23810 0.76190 ) *
## 21) PriceMM > 2.11 25 0.000 MM ( 0.00000 1.00000 ) *
## 11) PriceDiff > 0.05 114 157.500 MM ( 0.46491 0.53509 ) *
## 3) LoyalCH > 0.48285 504 437.700 CH ( 0.84325 0.15675 )
## 6) LoyalCH < 0.764572 243 288.100 CH ( 0.72016 0.27984 )
## 12) ListPriceDiff < 0.235 96 133.000 MM ( 0.48958 0.51042 )
## 24) PctDiscMM < 0.196196 78 105.600 CH ( 0.58974 0.41026 ) *
## 25) PctDiscMM > 0.196196 18 7.724 MM ( 0.05556 0.94444 ) *
## 13) ListPriceDiff > 0.235 147 113.200 CH ( 0.87075 0.12925 ) *
## 7) LoyalCH > 0.764572 261 91.200 CH ( 0.95785 0.04215 ) *
We will interpret the terminal node at 11) PriceDiff > 0.05 114 157.500 MM ( 0.46491 0.53509 ) *
The splitting variable is PriceDiff and the split value is 0.05. There are 114 training observations in this node, and the deviance of the observations in the corresponding region is 157.5. The * indicates that this is in fact a terminal node. The predicted class at this node is Purchase = MM: about 53.5% of the observations in the node have Purchase = MM, while the remaining 46.5% have Purchase = CH.
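As a rough cross-check of the numbers reported for node 11, we can recompute the class proportions among the training observations that fall in that region; the cutpoints below are copied from the printed tree, and observations lying exactly on a boundary may be routed slightly differently by tree():
# Region for node 11: 0.136698 <= LoyalCH < 0.48285 and PriceDiff >= 0.05 (approximately)
node11 <- subset(train.OJ,
                 LoyalCH >= 0.136698 & LoyalCH < 0.48285 & PriceDiff >= 0.05)
nrow(node11)                          # should be close to the 114 reported above
prop.table(table(node11$Purchase))    # roughly 46% CH / 54% MM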
plot(tree.fit.OJ)
text(tree.fit.OJ, pretty = 0, cex = 0.55)
LoyalCH is the most important variable: the top three splits all use LoyalCH. If LoyalCH < 0.1367, the tree predicts MM; if LoyalCH > 0.765, the tree predicts CH. For intermediate values of LoyalCH, the prediction also depends on the four remaining variables.
pred.oj <- predict(tree.fit.OJ, test.OJ, type = 'class')
# confusionMatrix() expects the predictions first and the observed classes second
conf_mat <- confusionMatrix(pred.oj, test.OJ$Purchase)
test_error_rate <- 1 - unname(conf_mat$overall['Accuracy'])
print(test_error_rate)
## [1] 0.1851852
The test error rate is 0.185.
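Beyond the overall error rate, the confusion matrix shows how the misclassifications split between the two brands:
# Rows are predicted classes, columns are the observed classes in the test set
conf_mat$table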
cv_oj <- cv.tree(tree.fit.OJ, FUN = prune.tree)  # cost-complexity CV, guided by deviance
plot(cv_oj$size, cv_oj$dev, type = "b", xlab = "Tree Size", ylab = "Deviance")
which.min(cv_oj$dev)
## [1] 2
cv_oj$size[2]
## [1] 7
A tree size of 7 gives the lowest cross-validation deviance.
prune_oj = prune.tree(tree.fit.OJ, best = 7)
summary(prune_oj)
##
## Classification tree:
## snip.tree(tree = tree.fit.OJ, nodes = 10L)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "ListPriceDiff" "PctDiscMM"
## Number of terminal nodes: 7
## Residual mean deviance: 0.7423 = 588.6 / 793
## Misclassification error rate: 0.1688 = 135 / 800
The pruned tree has the same training misclassification rate (0.1688) as the unpruned tree, although its residual mean deviance is slightly higher (0.7423 vs. 0.7291).
unpruned_pred = predict(tree.fit.OJ, test.OJ, type = "class")
unpruned_error = sum(test.OJ$Purchase != unpruned_pred)
unpruned_error/length(unpruned_pred)
## [1] 0.1851852
pruned_pred = predict(prune_oj, test.OJ, type = "class")
pruned_error = sum(test.OJ$Purchase != pruned_pred)
pruned_error/length(pruned_pred)
## [1] 0.1851852
The test error rates are identical. Pruning only collapsed the split below node 10, merging two terminal nodes that both predicted MM, so the pruned tree with 7 terminal nodes makes exactly the same predictions as the unpruned tree with 8.
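We can confirm directly that the two trees make identical predictions on every test observation:
# TRUE if the pruned and unpruned trees agree on all test observations
all(pruned_pred == unpruned_pred)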