p <- seq(0, 1, 0.0001)
# Gini index
G <- 2 * p * (1 - p)
# Classification error
E <- 1 - pmax(p, 1 - p)
# Entropy (natural log; NaN at p = 0 and p = 1, which plot() simply drops)
D <- -(p * log(p) + (1 - p) * log(1 - p))
plot(p, D, type = "l", col = "red", xlab = "p", ylab = "Impurity measure")
lines(p, E, col = "green")
lines(p, G, col = "blue")
legend(0.3, 0.15, c("Entropy", "Misclassification error", "Gini"),
       lty = c(1, 1, 1), lwd = c(2.5, 2.5, 2.5), col = c("red", "green", "blue"))
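For a two-class node with proportion p in the first class, the three measures are Gini = 2p(1 - p), classification error = 1 - max(p, 1 - p), and entropy = -p log(p) - (1 - p) log(1 - p); all three are maximized at p = 0.5. As a quick sanity check on the curves above, a minimal sketch reusing the vectors already computed (the entropy endpoints are NaN, so na.rm is needed):
# Verify that each impurity measure peaks at p = 0.5
cat("Gini peak:   ", max(G), "at p =", p[which.max(G)], "\n")
cat("Error peak:  ", max(E), "at p =", p[which.max(E)], "\n")
cat("Entropy peak:", max(D, na.rm = TRUE), "at p =", p[which.max(D)], "\n")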
Question 8: In the lab, a classification tree was applied to the Carseats data set after converting Sales into a qualitative response variable. Now we will seek to predict Sales using regression trees and related approaches, treating the response as a quantitative variable.
set.seed(123)
library(ISLR2)
## Warning: package 'ISLR2' was built under R version 4.3.3
library(tree)
## Warning: package 'tree' was built under R version 4.3.3
library(randomForest)
## Warning: package 'randomForest' was built under R version 4.3.3
## randomForest 4.7-1.1
## Type rfNews() to see new features/changes/bug fixes.
library(BART)
## Warning: package 'BART' was built under R version 4.3.3
## Loading required package: nlme
## Loading required package: survival
## Warning: package 'survival' was built under R version 4.3.3
library(dbarts)
## Warning: package 'dbarts' was built under R version 4.3.3
data(Carseats)
str(Carseats)
## 'data.frame': 400 obs. of 11 variables:
## $ Sales : num 9.5 11.22 10.06 7.4 4.15 ...
## $ CompPrice : num 138 111 113 117 141 124 115 136 132 132 ...
## $ Income : num 73 48 35 100 64 113 105 81 110 113 ...
## $ Advertising: num 11 16 10 4 3 13 0 15 0 0 ...
## $ Population : num 276 260 269 466 340 501 45 425 108 131 ...
## $ Price : num 120 83 80 97 128 72 108 120 124 124 ...
## $ ShelveLoc : Factor w/ 3 levels "Bad","Good","Medium": 1 2 3 3 1 1 3 2 3 3 ...
## $ Age : num 42 65 59 55 38 78 71 67 76 76 ...
## $ Education : num 17 10 12 14 13 16 15 10 10 17 ...
## $ Urban : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 1 2 2 1 1 ...
## $ US : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 1 2 1 2 ...
train_index <- sample(1:nrow(Carseats), size = 0.7 * nrow(Carseats))
train_data <- Carseats[train_index, ]
test_data <- Carseats[-train_index, ]
tree_carseats <- tree(Sales ~ ., data = train_data)
# View summary
summary(tree_carseats)
##
## Regression tree:
## tree(formula = Sales ~ ., data = train_data)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "CompPrice" "Age" "Advertising"
## Number of terminal nodes: 19
## Residual mean deviance: 2.373 = 619.2 / 261
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -4.1570 -1.0160 0.1123 0.0000 0.8903 4.0310
plot(tree_carseats)
text(tree_carseats, pretty = 0,cex = 0.7)
# Predict on test set
yhat <- predict(tree_carseats, newdata = test_data)
# Calculate test MSE
mse_tree <- mean((yhat - test_data$Sales)^2)
cat("Test MSE (Regression Tree):", mse_tree, "\n")
## Test MSE (Regression Tree): 3.602818
The fitted regression tree used five predictors: ShelveLoc, Price, CompPrice, Age, and Advertising. The test MSE for this model is approximately 3.60, which estimates the expected squared prediction error on new data.
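Since Sales is recorded in thousands of units, the square root of the test MSE (here roughly 1.9) puts the error on a more interpretable scale; a minimal sketch using the value computed above:
# RMSE expresses the tree's typical prediction error in the units of Sales
rmse_tree <- sqrt(mse_tree)
cat("Test RMSE (Regression Tree):", round(rmse_tree, 3), "\n")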
cv_tree_carseats <- cv.tree(tree_carseats)
# Plot cross-validation results
plot(cv_tree_carseats$size, cv_tree_carseats$dev, type = "b",
xlab = "Tree Size (Number of Terminal Nodes)",
ylab = "Deviance (CV Error)")
best.size <- cv_tree_carseats$size[which.min(cv_tree_carseats$dev)]
cat("Optimal tree size:", best.size, "\n")
## Optimal tree size: 19
pruned_tree <- prune.tree(tree_carseats, best = best.size)
# Plot pruned tree
plot(pruned_tree)
text(pruned_tree, pretty = 0, cex = 0.7)
# Predict on test set using pruned tree
yhat.pruned <- predict(pruned_tree, newdata = test_data)
# Calculate test MSE
mse_pruned <- mean((yhat.pruned - test_data$Sales)^2)
cat("Test MSE (Pruned Tree):", mse_pruned, "\n")
## Test MSE (Pruned Tree): 3.602818
cat("Test MSE (Original Tree):", mse_tree, "\n")
## Test MSE (Original Tree): 3.602818
Pruning did not change the test MSE because cross-validation selected the full tree (19 terminal nodes), so the "pruned" tree is identical to the original. This suggests the unpruned tree was already an appropriate size for these data.
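If a simpler, more interpretable tree were preferred despite the CV result, one could still prune to a smaller size and check the cost in test error; a sketch, with best = 10 chosen arbitrarily for illustration:
# Force a 10-node tree (arbitrary size) and compare its test MSE to the full tree
small_tree <- prune.tree(tree_carseats, best = 10)
yhat_small <- predict(small_tree, newdata = test_data)
cat("Test MSE (10-node tree):", mean((yhat_small - test_data$Sales)^2), "\n")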
bag_carseats <- randomForest(Sales ~ ., data = train_data, mtry = ncol(train_data) - 1, importance = TRUE)
bag_pred <- predict(bag_carseats, newdata = test_data)
mse_bag <- mean((bag_pred - test_data$Sales)^2)
cat("Test MSE (Bagging):", mse_bag, "\n")
## Test MSE (Bagging): 2.278865
importance(bag_carseats)
## %IncMSE IncNodePurity
## CompPrice 35.7999670 262.298264
## Income 7.8716191 119.388933
## Advertising 20.8395402 152.464416
## Population 2.1760765 66.209944
## Price 63.0665019 621.449901
## ShelveLoc 72.4630954 685.134197
## Age 19.1347198 177.449419
## Education 4.3626732 66.736646
## Urban 0.3123883 9.295726
## US 1.3827536 7.787507
The test MSE obtained with bagging was approximately 2.28, an improvement over the single tree. Ranked by %IncMSE, the most important variables were ShelveLoc, Price, CompPrice, Advertising, and Age.
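The same importance ranking can be shown graphically; a minimal sketch using randomForest's varImpPlot():
# Dot charts of %IncMSE and IncNodePurity for the bagged model
varImpPlot(bag_carseats, main = "Variable Importance (Bagging)")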
p <- ncol(train_data) - 1 # subtract 1 for response variable
rf.carseats <- randomForest(Sales ~ ., data = train_data, mtry = floor(sqrt(p)), importance = TRUE)
rf.pred <- predict(rf.carseats, newdata = test_data)
mse_rf <- mean((rf.pred - test_data$Sales)^2)
cat("Test MSE (Random Forest):", mse_rf, "\n")
## Test MSE (Random Forest): 2.704399
importance(rf.carseats)
## %IncMSE IncNodePurity
## CompPrice 16.7196294 230.02523
## Income 6.2625691 168.90108
## Advertising 16.2090797 183.37645
## Population -1.3086954 129.80261
## Price 42.9458404 488.50809
## ShelveLoc 46.8917226 513.90664
## Age 17.9735745 257.21081
## Education 1.2049797 103.30595
## Urban 0.8486326 19.34474
## US 4.3444583 25.30593
mtry_vals <- 1:p
test_mse_vals <- numeric(length(mtry_vals))
for (i in seq_along(mtry_vals)) {
rf_temp <- randomForest(Sales ~ ., data = train_data, mtry = mtry_vals[i])
pred_temp <- predict(rf_temp, newdata = test_data)
test_mse_vals[i] <- mean((pred_temp - test_data$Sales)^2)
}
# Plot mtry vs test MSE
plot(mtry_vals, test_mse_vals, type = "b",
xlab = "mtry (Number of Variables Tried at Each Split)",
ylab = "Test MSE",
main = "Effect of mtry on Test Error")
The test MSE for the random forest (mtry = sqrt(p)) is approximately 2.70. The most important variables are ShelveLoc, Price, Age, CompPrice, and Advertising. As the number of variables tried at each split increased, the test MSE generally decreased; the error is minimized around m = 5, which balances the variance reduction from decorrelating the trees against the predictive strength of each individual tree.
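The best value of mtry over the grid searched above can be extracted directly; a minimal sketch:
# mtry value with the lowest test MSE among the values tried
best_mtry <- mtry_vals[which.min(test_mse_vals)]
cat("Best mtry:", best_mtry, "with test MSE:", round(min(test_mse_vals), 3), "\n")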
# Response variable (vector)
y_train <- train_data$Sales
y_test <- test_data$Sales
# Predictor matrix (remove Sales)
x_train <- data.matrix(train_data[, setdiff(names(train_data), "Sales")])
x_test <- data.matrix(test_data[, setdiff(names(test_data), "Sales")])
bart.model <- bart(x.train = x_train, y.train = y_train,
x.test = x_test, verbose = FALSE)
bart.pred <- bart.model$yhat.test.mean
# Compute Test MSE
mse_bart <- mean((y_test - bart.pred)^2)
cat("Test MSE (BART):", mse_bart, "\n")
## Test MSE (BART): 1.627498
The test MSE for the BART model is approximately 1.63, the lowest of all the approaches tried, making it the best-performing model for this data set.
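For a side-by-side view, the test MSEs stored above can be collected into one table; a minimal sketch:
# Gather the Carseats test MSEs computed above and sort from best to worst
carseats_results <- data.frame(
  Model = c("Regression tree", "Pruned tree", "Bagging", "Random forest", "BART"),
  TestMSE = round(c(mse_tree, mse_pruned, mse_bag, mse_rf, mse_bart), 3)
)
carseats_results[order(carseats_results$TestMSE), ]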
library(rpart)
## Warning: package 'rpart' was built under R version 4.3.3
library(rpart.plot)
## Warning: package 'rpart.plot' was built under R version 4.3.3
set.seed(1)
data(OJ)
n <- nrow(OJ)
# Indices for training set (70%)
train_indices <- sample(1:n, size = floor(0.7 * n))
# Create training and test sets
train_data <- OJ[train_indices, ]
test_data <- OJ[-train_indices, ]
tree_model <- rpart(Purchase ~ ., data = train_data, method = "class")
# Predict on training set
train_pred <- predict(tree_model, train_data, type = "class")
# Confusion matrix and training error rate
train_cm <- table(train_data$Purchase, train_pred)
train_error <- 1 - sum(diag(train_cm)) / sum(train_cm)
train_cm
## train_pred
## CH MM
## CH 419 41
## MM 64 225
cat("Training error rate:", round(train_error, 4), "\n")
## Training error rate: 0.1402
The training error rate is 0.1402, or 14.02%.
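The confusion matrix above also shows how the training error splits across the two brands; from the counts shown, MM purchases are misclassified more often than CH. A minimal sketch:
# Class-specific training error rates (rows of train_cm are the true classes)
class_error <- 1 - diag(train_cm) / rowSums(train_cm)
round(class_error, 4)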
rpart.plot(tree_model, extra = 104)
cat("Number of terminal nodes:", sum(tree_model$frame$var == "<leaf>"), "\n")
## Number of terminal nodes: 9
The classification tree splits first on LoyalCH, indicating it is the most important predictor. Subsequent splits use PriceDiff, StoreID, WeekofPurchase, SalePriceMM, and SpecialCH. The tree has 9 terminal nodes.
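The split rules leading to each terminal node can also be printed in readable form; a sketch using rpart.plot's rpart.rules() (cover = TRUE adds the share of training observations reaching each leaf):
# One row of rules per terminal node, with the fraction of cases it covers
rpart.rules(tree_model, cover = TRUE)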
summary(tree_model)
## Call:
## rpart(formula = Purchase ~ ., data = train_data, method = "class")
## n= 749
##
## CP nsplit rel error xerror xstd
## 1 0.49134948 0 1.0000000 1.0000000 0.04609874
## 2 0.02076125 1 0.5086505 0.5259516 0.03808642
## 3 0.01903114 3 0.4671280 0.5467128 0.03863523
## 4 0.01730104 6 0.3944637 0.4982699 0.03731816
## 5 0.01384083 7 0.3771626 0.4775087 0.03671313
## 6 0.01000000 8 0.3633218 0.4740484 0.03660979
##
## Variable importance
## LoyalCH PriceDiff SalePriceMM PriceMM StoreID
## 43 9 8 6 6
## DiscMM PctDiscMM WeekofPurchase ListPriceDiff PriceCH
## 5 5 5 3 3
## SpecialCH STORE Store7 SalePriceCH SpecialMM
## 2 2 1 1 1
##
## Node number 1: 749 observations, complexity param=0.4913495
## predicted class=CH expected loss=0.3858478 P(node) =1
## class counts: 460 289
## probabilities: 0.614 0.386
## left son=2 (461 obs) right son=3 (288 obs)
## Primary splits:
## LoyalCH < 0.48285 to the right, improve=121.74400, (0 missing)
## StoreID < 3.5 to the right, improve= 44.51769, (0 missing)
## Store7 splits as RL, improve= 27.82627, (0 missing)
## STORE < 0.5 to the left, improve= 27.82627, (0 missing)
## PriceDiff < 0.015 to the right, improve= 23.14246, (0 missing)
## Surrogate splits:
## StoreID < 3.5 to the right, agree=0.656, adj=0.104, (0 split)
## PriceMM < 1.89 to the right, agree=0.632, adj=0.042, (0 split)
## WeekofPurchase < 232.5 to the right, agree=0.625, adj=0.024, (0 split)
## PriceCH < 1.72 to the right, agree=0.625, adj=0.024, (0 split)
## ListPriceDiff < 0.035 to the right, agree=0.625, adj=0.024, (0 split)
##
## Node number 2: 461 observations, complexity param=0.01903114
## predicted class=CH expected loss=0.1605206 P(node) =0.6154873
## class counts: 387 74
## probabilities: 0.839 0.161
## left son=4 (245 obs) right son=5 (216 obs)
## Primary splits:
## LoyalCH < 0.7645725 to the right, improve=14.985200, (0 missing)
## PriceDiff < 0.015 to the right, improve=13.895340, (0 missing)
## SalePriceMM < 1.84 to the right, improve=10.603240, (0 missing)
## ListPriceDiff < 0.255 to the right, improve= 8.359879, (0 missing)
## DiscMM < 0.15 to the left, improve= 6.437776, (0 missing)
## Surrogate splits:
## WeekofPurchase < 257.5 to the right, agree=0.605, adj=0.157, (0 split)
## PriceMM < 2.04 to the right, agree=0.597, adj=0.139, (0 split)
## SalePriceMM < 1.84 to the right, agree=0.594, adj=0.134, (0 split)
## PriceCH < 1.825 to the right, agree=0.590, adj=0.125, (0 split)
## PriceDiff < 0.015 to the right, agree=0.588, adj=0.120, (0 split)
##
## Node number 3: 288 observations, complexity param=0.02076125
## predicted class=MM expected loss=0.2534722 P(node) =0.3845127
## class counts: 73 215
## probabilities: 0.253 0.747
## left son=6 (128 obs) right son=7 (160 obs)
## Primary splits:
## LoyalCH < 0.280875 to the right, improve=9.683681, (0 missing)
## StoreID < 3.5 to the right, improve=7.378759, (0 missing)
## Store7 splits as RL, improve=6.889547, (0 missing)
## STORE < 0.5 to the left, improve=6.889547, (0 missing)
## PriceDiff < 0.49 to the right, improve=6.357341, (0 missing)
## Surrogate splits:
## STORE < 1.5 to the left, agree=0.622, adj=0.148, (0 split)
## StoreID < 3.5 to the right, agree=0.611, adj=0.125, (0 split)
## SalePriceCH < 1.775 to the left, agree=0.590, adj=0.078, (0 split)
## PriceDiff < 0.325 to the right, agree=0.587, adj=0.070, (0 split)
## WeekofPurchase < 275.5 to the right, agree=0.580, adj=0.055, (0 split)
##
## Node number 4: 245 observations
## predicted class=CH expected loss=0.04081633 P(node) =0.3271028
## class counts: 235 10
## probabilities: 0.959 0.041
##
## Node number 5: 216 observations, complexity param=0.01903114
## predicted class=CH expected loss=0.2962963 P(node) =0.2883845
## class counts: 152 64
## probabilities: 0.704 0.296
## left son=10 (141 obs) right son=11 (75 obs)
## Primary splits:
## PriceDiff < 0.015 to the right, improve=17.636060, (0 missing)
## ListPriceDiff < 0.235 to the right, improve=16.794240, (0 missing)
## SalePriceMM < 1.84 to the right, improve=12.779700, (0 missing)
## DiscMM < 0.15 to the left, improve= 8.545958, (0 missing)
## PctDiscMM < 0.0729725 to the left, improve= 8.545958, (0 missing)
## Surrogate splits:
## SalePriceMM < 1.84 to the right, agree=0.958, adj=0.880, (0 split)
## PctDiscMM < 0.1155095 to the left, agree=0.884, adj=0.667, (0 split)
## DiscMM < 0.15 to the left, agree=0.870, adj=0.627, (0 split)
## PriceMM < 2.04 to the right, agree=0.801, adj=0.427, (0 split)
## ListPriceDiff < 0.18 to the right, agree=0.792, adj=0.400, (0 split)
##
## Node number 6: 128 observations, complexity param=0.02076125
## predicted class=MM expected loss=0.3984375 P(node) =0.1708945
## class counts: 51 77
## probabilities: 0.398 0.602
## left son=12 (56 obs) right son=13 (72 obs)
## Primary splits:
## SalePriceMM < 2.04 to the right, improve=8.672867, (0 missing)
## PriceDiff < 0.05 to the right, improve=5.506200, (0 missing)
## SpecialCH < 0.5 to the right, improve=4.715265, (0 missing)
## STORE < 0.5 to the left, improve=4.380208, (0 missing)
## StoreID < 5.5 to the right, improve=4.380208, (0 missing)
## Surrogate splits:
## PriceDiff < 0.135 to the right, agree=0.898, adj=0.768, (0 split)
## PriceMM < 2.04 to the right, agree=0.805, adj=0.554, (0 split)
## DiscMM < 0.08 to the left, agree=0.781, adj=0.500, (0 split)
## PctDiscMM < 0.038887 to the left, agree=0.781, adj=0.500, (0 split)
## WeekofPurchase < 244 to the right, agree=0.742, adj=0.411, (0 split)
##
## Node number 7: 160 observations
## predicted class=MM expected loss=0.1375 P(node) =0.2136182
## class counts: 22 138
## probabilities: 0.138 0.862
##
## Node number 10: 141 observations
## predicted class=CH expected loss=0.1489362 P(node) =0.188251
## class counts: 120 21
## probabilities: 0.851 0.149
##
## Node number 11: 75 observations, complexity param=0.01903114
## predicted class=MM expected loss=0.4266667 P(node) =0.1001335
## class counts: 32 43
## probabilities: 0.427 0.573
## left son=22 (38 obs) right son=23 (37 obs)
## Primary splits:
## StoreID < 3.5 to the right, improve=6.468582, (0 missing)
## ListPriceDiff < 0.235 to the right, improve=4.800357, (0 missing)
## WeekofPurchase < 240.5 to the left, improve=4.321538, (0 missing)
## DiscMM < 0.47 to the left, improve=3.226667, (0 missing)
## PctDiscMM < 0.227263 to the left, improve=3.226667, (0 missing)
## Surrogate splits:
## Store7 splits as RL, agree=0.840, adj=0.676, (0 split)
## STORE < 0.5 to the left, agree=0.840, adj=0.676, (0 split)
## WeekofPurchase < 238 to the left, agree=0.760, adj=0.514, (0 split)
## SpecialMM < 0.5 to the left, agree=0.733, adj=0.459, (0 split)
## PriceCH < 1.825 to the left, agree=0.680, adj=0.351, (0 split)
##
## Node number 12: 56 observations
## predicted class=CH expected loss=0.3928571 P(node) =0.07476636
## class counts: 34 22
## probabilities: 0.607 0.393
##
## Node number 13: 72 observations, complexity param=0.01730104
## predicted class=MM expected loss=0.2361111 P(node) =0.09612817
## class counts: 17 55
## probabilities: 0.236 0.764
## left son=26 (11 obs) right son=27 (61 obs)
## Primary splits:
## SpecialCH < 0.5 to the right, improve=6.264324, (0 missing)
## StoreID < 3.5 to the right, improve=1.895859, (0 missing)
## PriceDiff < -0.24 to the right, improve=1.768832, (0 missing)
## Store7 splits as RL, improve=1.612379, (0 missing)
## STORE < 0.5 to the left, improve=1.612379, (0 missing)
## Surrogate splits:
## DiscCH < 0.25 to the right, agree=0.875, adj=0.182, (0 split)
## SalePriceCH < 1.49 to the left, agree=0.875, adj=0.182, (0 split)
## PctDiscCH < 0.1366045 to the right, agree=0.875, adj=0.182, (0 split)
##
## Node number 22: 38 observations, complexity param=0.01384083
## predicted class=CH expected loss=0.3684211 P(node) =0.05073431
## class counts: 24 14
## probabilities: 0.632 0.368
## left son=44 (30 obs) right son=45 (8 obs)
## Primary splits:
## WeekofPurchase < 272.5 to the left, improve=2.950877, (0 missing)
## PriceCH < 1.89 to the left, improve=1.455639, (0 missing)
## PriceMM < 2.04 to the left, improve=1.455639, (0 missing)
## LoyalCH < 0.5039495 to the right, improve=1.455639, (0 missing)
## SalePriceCH < 1.89 to the left, improve=1.455639, (0 missing)
## Surrogate splits:
## PriceCH < 1.89 to the left, agree=0.947, adj=0.75, (0 split)
## PriceMM < 2.04 to the left, agree=0.947, adj=0.75, (0 split)
## SalePriceCH < 1.89 to the left, agree=0.947, adj=0.75, (0 split)
## PriceDiff < -0.25 to the right, agree=0.947, adj=0.75, (0 split)
## DiscMM < 0.47 to the left, agree=0.895, adj=0.50, (0 split)
##
## Node number 23: 37 observations
## predicted class=MM expected loss=0.2162162 P(node) =0.0493992
## class counts: 8 29
## probabilities: 0.216 0.784
##
## Node number 26: 11 observations
## predicted class=CH expected loss=0.2727273 P(node) =0.01468625
## class counts: 8 3
## probabilities: 0.727 0.273
##
## Node number 27: 61 observations
## predicted class=MM expected loss=0.147541 P(node) =0.08144192
## class counts: 9 52
## probabilities: 0.148 0.852
##
## Node number 44: 30 observations
## predicted class=CH expected loss=0.2666667 P(node) =0.0400534
## class counts: 22 8
## probabilities: 0.733 0.267
##
## Node number 45: 8 observations
## predicted class=MM expected loss=0.25 P(node) =0.01068091
## class counts: 2 6
## probabilities: 0.250 0.750
print(tree_model)
## n= 749
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 749 289 CH (0.61415220 0.38584780)
## 2) LoyalCH>=0.48285 461 74 CH (0.83947939 0.16052061)
## 4) LoyalCH>=0.7645725 245 10 CH (0.95918367 0.04081633) *
## 5) LoyalCH< 0.7645725 216 64 CH (0.70370370 0.29629630)
## 10) PriceDiff>=0.015 141 21 CH (0.85106383 0.14893617) *
## 11) PriceDiff< 0.015 75 32 MM (0.42666667 0.57333333)
## 22) StoreID>=3.5 38 14 CH (0.63157895 0.36842105)
## 44) WeekofPurchase< 272.5 30 8 CH (0.73333333 0.26666667) *
## 45) WeekofPurchase>=272.5 8 2 MM (0.25000000 0.75000000) *
## 23) StoreID< 3.5 37 8 MM (0.21621622 0.78378378) *
## 3) LoyalCH< 0.48285 288 73 MM (0.25347222 0.74652778)
## 6) LoyalCH>=0.280875 128 51 MM (0.39843750 0.60156250)
## 12) SalePriceMM>=2.04 56 22 CH (0.60714286 0.39285714) *
## 13) SalePriceMM< 2.04 72 17 MM (0.23611111 0.76388889)
## 26) SpecialCH>=0.5 11 3 CH (0.72727273 0.27272727) *
## 27) SpecialCH< 0.5 61 9 MM (0.14754098 0.85245902) *
## 7) LoyalCH< 0.280875 160 22 MM (0.13750000 0.86250000) *
Node 45 is a terminal node containing 8 customers who are moderately loyal to CH (LoyalCH between about 0.48 and 0.76), face a price difference below 0.015, shopped at a store with StoreID of 4 or higher, and purchased in week 273 or later. Despite their moderate CH loyalty, the tree predicts MM for this node, with 6 of the 8 observations (75%) purchasing MM.
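The full sequence of splits leading to node 45 can be printed directly from the fitted rpart object; a minimal sketch using path.rpart():
# Print the split conditions from the root down to node 45
path.rpart(tree_model, nodes = 45)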
test_pred <- predict(tree_model, test_data, type = "class")
test_cm <- table(test_data$Purchase, test_pred)
test_error <- 1 - sum(diag(test_cm)) / sum(test_cm)
cat("Test error rate:", round(test_error, 4), "\n")
## Test error rate: 0.1869
The test error rate is 0.1869, or 18.69%.
library(tree)
oj_tree <- tree(Purchase ~ ., data = train_data)
# Cross-validation
cv_oj <- cv.tree(oj_tree, FUN = prune.misclass)
print(cv_oj)
## $size
## [1] 8 5 2 1
##
## $dev
## [1] 142 140 143 289
##
## $k
## [1] -Inf 4.000000 5.666667 142.000000
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
# optimal size
optimal_size <- cv_oj$size[which.min(cv_oj$dev)]
cat("Optimal tree size:", optimal_size, "\n")
## Optimal tree size: 5
Optimal tree size is 5.
# Plot tree size vs. cross-validated error rate
plot(cv_oj$size, cv_oj$dev, type = "b",
xlab = "Tree Size (Number of Terminal Nodes)",
ylab = "CV Classification Error",
main = "Cross-Validated Error vs. Tree Size")
optimal_index <- which.min(cv_oj$dev)
optimal_size <- cv_oj$size[optimal_index]
lowest_error <- cv_oj$dev[optimal_index]
cat("Optimal tree size:", optimal_size, "\n")
## Optimal tree size: 5
cat("Lowest cross-validated classification error rate:", round(lowest_error, 4), "\n")
## Lowest cross-validated classification error rate: 140
A tree with 5 terminal nodes gives the lowest cross-validated error: 140 misclassified training observations out of 749, or roughly 18.7%.
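Because cv.tree() with FUN = prune.misclass reports $dev as the number of cross-validated misclassifications, dividing by the training-set size converts these to error rates; a minimal sketch:
# Convert CV misclassification counts to error rates for each candidate tree size
data.frame(size = cv_oj$size,
           cv_error_rate = round(cv_oj$dev / nrow(train_data), 4))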
# Prune to the size selected by cross-validation
if (optimal_size == max(cv_oj$size)) {
  # CV selected the full tree; prune to 5 nodes anyway for a simpler model
  pruned_tree <- prune.misclass(oj_tree, best = 5)
  cat("Cross-validation did not suggest pruning. Pruned to 5 nodes instead.\n")
} else {
  pruned_tree <- prune.misclass(oj_tree, best = optimal_size)
  cat("Pruned tree to optimal size:", optimal_size, "\n")
}
## Pruned tree to optimal size: 5
# Plot the pruned tree
plot(pruned_tree)
text(pruned_tree, pretty = 0)
# Predictions on training data (unpruned)
train_pred_unpruned <- predict(oj_tree, train_data, type = "class")
train_cm_unpruned <- table(train_data$Purchase, train_pred_unpruned)
train_error_unpruned <- 1 - sum(diag(train_cm_unpruned)) / sum(train_cm_unpruned)
# Predictions on training data (pruned)
train_pred_pruned <- predict(pruned_tree, train_data, type = "class")
train_cm_pruned <- table(train_data$Purchase, train_pred_pruned)
train_error_pruned <- 1 - sum(diag(train_cm_pruned)) / sum(train_cm_pruned)
cat("Training error rate (Unpruned tree):", round(train_error_unpruned, 4), "\n")
## Training error rate (Unpruned tree): 0.1575
cat("Training error rate (Pruned tree):", round(train_error_pruned, 4), "\n")
## Training error rate (Pruned tree): 0.1736
The pruned tree's training error rate (0.1736) is higher than the unpruned tree's (0.1575), which is expected: pruning removes splits that were fitting the training data.
# Predictions on test data (unpruned)
test_pred_unpruned <- predict(oj_tree, test_data, type = "class")
test_cm_unpruned <- table(test_data$Purchase, test_pred_unpruned)
test_error_unpruned <- 1 - sum(diag(test_cm_unpruned)) / sum(test_cm_unpruned)
# Predictions on test data (pruned)
test_pred_pruned <- predict(pruned_tree, test_data, type = "class")
test_cm_pruned <- table(test_data$Purchase, test_pred_pruned)
test_error_pruned <- 1 - sum(diag(test_cm_pruned)) / sum(test_cm_pruned)
cat("Test error rate (Unpruned tree):", round(test_error_unpruned, 4), "\n")
## Test error rate (Unpruned tree): 0.19
cat("Test error rate (Pruned tree):", round(test_error_pruned, 4), "\n")
## Test error rate (Pruned tree): 0.1776
The unpruned tree's test error rate (0.19) is higher than the pruned tree's (0.1776), so pruning improved performance on the test set, consistent with the smaller tree generalizing better.
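A compact summary of the OJ results computed above; a minimal sketch collecting the stored error rates:
# Training and test error rates for the unpruned and pruned OJ trees
oj_results <- data.frame(
  Tree = c("Unpruned (8 nodes)", "Pruned (5 nodes)"),
  TrainError = round(c(train_error_unpruned, train_error_pruned), 4),
  TestError = round(c(test_error_unpruned, test_error_pruned), 4)
)
oj_results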