Question 3

Consider the Gini index, classification error, and cross-entropy in a simple classification setting with two classes. Create a single plot that displays each of these quantities as a function of pm1. The x-axis should display pm1, ranging from 0 to 1, and the y-axis should display the value of the Gini index, classification error, and entropy. Hint: In a setting with two classes, pm1 = 1 - pm2. You could make this plot by hand, but it will be much easier to make in R.

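p=seq(0,1,0.0001)   # pm1; with two classes, pm2 = 1 - pm1
#Gini
G=2*p*(1-p)
#Classification Error
E=1-pmax(p,1-p)
#Entropy (natural log), using the convention 0*log(0) = 0 so the endpoints are 0 rather than NaN
D=-(ifelse(p>0, p*log(p), 0) + ifelse(p<1, (1-p)*log(1-p), 0))

# type="l" draws curves instead of 10,001 individual points
plot(p,D,type="l",col="red",lwd=2.5,xlab="pm1",ylab="Value")
lines(p,E,col='green',lwd=2.5)
lines(p,G,col='blue',lwd=2.5)
legend(0.3,0.15,c("Entropy","Misclassification","Gini"),lty=c(1,1,1),lwd=c(2.5,2.5,2.5),col=c('red','green','blue'))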
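As a quick numeric check of the curves, the sketch below wraps the same three formulas in small helper functions (`gini`, `class_error` and `entropy` are illustrative names, not from any package) and evaluates them at a few values of pm1. All three are 0 at pm1 = 0 and pm1 = 1 and peak at pm1 = 0.5, where the Gini index and the classification error both equal 0.5 and the natural-log entropy equals log(2), about 0.693.

```r
# Two-class node impurity measures as functions of p = pm1 (proportion of class 1)
gini <- function(p) 2 * p * (1 - p)
class_error <- function(p) 1 - pmax(p, 1 - p)
# Natural-log cross-entropy, with the convention 0 * log(0) = 0
entropy <- function(p) {
  term <- function(q) ifelse(q > 0, q * log(q), 0)
  -(term(p) + term(1 - p))
}

check_p <- c(0, 0.25, 0.5, 1)
round(rbind(gini = gini(check_p),
            error = class_error(check_p),
            entropy = entropy(check_p)), 4)
```

The entropy curve sits above the other two, which is why it sets the y-axis range when it is drawn first with `plot()`.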

Question 8

In the lab, a classification tree was applied to the Carseats data set after converting Sales into a qualitative response variable. Now we will seek to predict Sales using regression trees and related approaches, treating the response as a quantitative variable.

library(ISLR)
## Warning: package 'ISLR' was built under R version 4.4.3
library(rpart)
library(caret)
## Loading required package: ggplot2
## Warning: package 'ggplot2' was built under R version 4.4.3
## Loading required package: lattice
library(randomForest)
## randomForest 4.7-1.2
## Type rfNews() to see new features/changes/bug fixes.
## 
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
## 
##     margin
library(BART)
## Warning: package 'BART' was built under R version 4.4.3
## Loading required package: nlme
## Loading required package: survival
## 
## Attaching package: 'survival'
## The following object is masked from 'package:caret':
## 
##     cluster
#Load the data
attach(Carseats)
#Make a copy of the data
df=Carseats
sum(is.na(df))
## [1] 0
str(df)
## 'data.frame':    400 obs. of  11 variables:
##  $ Sales      : num  9.5 11.22 10.06 7.4 4.15 ...
##  $ CompPrice  : num  138 111 113 117 141 124 115 136 132 132 ...
##  $ Income     : num  73 48 35 100 64 113 105 81 110 113 ...
##  $ Advertising: num  11 16 10 4 3 13 0 15 0 0 ...
##  $ Population : num  276 260 269 466 340 501 45 425 108 131 ...
##  $ Price      : num  120 83 80 97 128 72 108 120 124 124 ...
##  $ ShelveLoc  : Factor w/ 3 levels "Bad","Good","Medium": 1 2 3 3 1 1 3 2 3 3 ...
##  $ Age        : num  42 65 59 55 38 78 71 67 76 76 ...
##  $ Education  : num  17 10 12 14 13 16 15 10 10 17 ...
##  $ Urban      : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 1 2 2 1 1 ...
##  $ US         : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 1 2 1 2 ...

(a) Split the data set into a training set and a test set.

set.seed(123) # for reproducibility
trainIndex <- createDataPartition(df$Sales, p = 0.7, list = FALSE)
trainingData <- df[trainIndex, ]
testingData <- df[-trainIndex, ]

cat("Training data dimensions:", dim(trainingData), "\n")
## Training data dimensions: 281 11

(b) Fit a regression tree to the training set. Plot the tree, and interpret the results. What test MSE do you obtain?

##fit the tree on the training data
tree.carseats=rpart(Sales~., data=trainingData, method="anova", control=rpart.control(minsplit=15, cp=0.01))

summary(tree.carseats)
## Call:
## rpart(formula = Sales ~ ., data = trainingData, method = "anova", 
##     control = rpart.control(minsplit = 15, cp = 0.01))
##   n= 281 
## 
##            CP nsplit rel error    xerror       xstd
## 1  0.26851707      0 1.0000000 1.0111972 0.08289838
## 2  0.09667155      1 0.7314829 0.7470703 0.05902195
## 3  0.04877915      2 0.6348114 0.6995455 0.05510559
## 4  0.04108607      3 0.5860322 0.6649484 0.05538292
## 5  0.03346422      4 0.5449462 0.6181919 0.05017698
## 6  0.03053826      6 0.4780177 0.6139168 0.05000140
## 7  0.03031023      7 0.4474795 0.6103279 0.05196480
## 8  0.02611699      9 0.3868590 0.6079532 0.05137699
## 9  0.02491877     10 0.3607420 0.5897843 0.05087016
## 10 0.01930149     11 0.3358232 0.5779932 0.04771566
## 11 0.01325635     12 0.3165217 0.5619368 0.04586024
## 12 0.01046366     13 0.3032654 0.5398775 0.04530565
## 13 0.01000000     14 0.2928017 0.5566052 0.04721160
## 
## Variable importance
##   ShelveLoc       Price   CompPrice         Age Advertising  Population 
##          37          28          15           9           4           3 
##      Income   Education 
##           3           1 
## 
## Node number 1: 281 observations,    complexity param=0.2685171
##   mean=7.492456, MSE=7.690786 
##   left son=2 (222 obs) right son=3 (59 obs)
##   Primary splits:
##       ShelveLoc   splits as  LRL,       improve=0.26851710, (0 missing)
##       Price       < 94.5  to the right, improve=0.12196070, (0 missing)
##       Advertising < 13.5  to the left,  improve=0.06940121, (0 missing)
##       Age         < 61.5  to the right, improve=0.05387884, (0 missing)
##       Income      < 68.5  to the left,  improve=0.03495081, (0 missing)
##   Surrogate splits:
##       Income      < 21.5  to the right, agree=0.794, adj=0.017, (0 split)
##       Advertising < 19.5  to the left,  agree=0.794, adj=0.017, (0 split)
##       Price       < 168.5 to the left,  agree=0.794, adj=0.017, (0 split)
## 
## Node number 2: 222 observations,    complexity param=0.09667155
##   mean=6.751622, MSE=5.636254 
##   left son=4 (146 obs) right son=5 (76 obs)
##   Primary splits:
##       Price       < 105.5 to the right, improve=0.16696760, (0 missing)
##       ShelveLoc   splits as  L-R,       improve=0.12140820, (0 missing)
##       Advertising < 7.5   to the left,  improve=0.07412538, (0 missing)
##       Age         < 68.5  to the right, improve=0.06763535, (0 missing)
##       Income      < 98.5  to the left,  improve=0.05174565, (0 missing)
##   Surrogate splits:
##       CompPrice  < 116.5 to the right, agree=0.752, adj=0.276, (0 split)
##       Income     < 23    to the right, agree=0.667, adj=0.026, (0 split)
##       Population < 17.5  to the right, agree=0.662, adj=0.013, (0 split)
## 
## Node number 3: 59 observations,    complexity param=0.04108607
##   mean=10.28, MSE=5.585888 
##   left son=6 (42 obs) right son=7 (17 obs)
##   Primary splits:
##       Price       < 107.5 to the right, improve=0.26941850, (0 missing)
##       Advertising < 13.5  to the left,  improve=0.13064220, (0 missing)
##       Age         < 61.5  to the right, improve=0.11225400, (0 missing)
##       Population  < 345.5 to the left,  improve=0.09959193, (0 missing)
##       Education   < 11.5  to the right, improve=0.09094690, (0 missing)
##   Surrogate splits:
##       Population < 457.5 to the left,  agree=0.746, adj=0.118, (0 split)
##       CompPrice  < 105.5 to the right, agree=0.729, adj=0.059, (0 split)
##       Income     < 30    to the right, agree=0.729, adj=0.059, (0 split)
##       Age        < 73.5  to the left,  agree=0.729, adj=0.059, (0 split)
##       Education  < 17.5  to the left,  agree=0.729, adj=0.059, (0 split)
## 
## Node number 4: 146 observations,    complexity param=0.04877915
##   mean=6.051712, MSE=4.593309 
##   left son=8 (44 obs) right son=9 (102 obs)
##   Primary splits:
##       ShelveLoc   splits as  L-R,       improve=0.15719290, (0 missing)
##       CompPrice   < 124.5 to the left,  improve=0.13356190, (0 missing)
##       Price       < 137.5 to the right, improve=0.11499490, (0 missing)
##       Advertising < 10.5  to the left,  improve=0.07920236, (0 missing)
##       Age         < 64.5  to the right, improve=0.07657450, (0 missing)
## 
## Node number 5: 76 observations,    complexity param=0.03346422
##   mean=8.096184, MSE=4.890887 
##   left son=10 (61 obs) right son=11 (15 obs)
##   Primary splits:
##       Price       < 80.5  to the right, improve=0.1769107, (0 missing)
##       CompPrice   < 118.5 to the left,  improve=0.1655515, (0 missing)
##       Age         < 68.5  to the right, improve=0.1432999, (0 missing)
##       Income      < 102.5 to the left,  improve=0.1384654, (0 missing)
##       Advertising < 9.5   to the left,  improve=0.1365648, (0 missing)
##   Surrogate splits:
##       CompPrice  < 98.5  to the right, agree=0.868, adj=0.333, (0 split)
##       Population < 34    to the right, agree=0.829, adj=0.133, (0 split)
##       Education  < 17.5  to the left,  agree=0.816, adj=0.067, (0 split)
## 
## Node number 6: 42 observations,    complexity param=0.03053826
##   mean=9.499524, MSE=4.6455 
##   left son=12 (34 obs) right son=13 (8 obs)
##   Primary splits:
##       Advertising < 13.5  to the left,  improve=0.3382515, (0 missing)
##       Price       < 142.5 to the right, improve=0.2187682, (0 missing)
##       Population  < 345.5 to the left,  improve=0.1840546, (0 missing)
##       CompPrice   < 121.5 to the left,  improve=0.1746411, (0 missing)
##       US          splits as  LR,        improve=0.1348331, (0 missing)
##   Surrogate splits:
##       Population < 345.5 to the left,  agree=0.833, adj=0.125, (0 split)
## 
## Node number 7: 17 observations
##   mean=12.20824, MSE=2.686167 
## 
## Node number 8: 44 observations,    complexity param=0.02491877
##   mean=4.757955, MSE=3.30423 
##   left son=16 (37 obs) right son=17 (7 obs)
##   Primary splits:
##       CompPrice  < 144   to the left,  improve=0.37040830, (0 missing)
##       Age        < 60    to the right, improve=0.14999280, (0 missing)
##       Price      < 143.5 to the right, improve=0.14515880, (0 missing)
##       Population < 283   to the left,  improve=0.09732505, (0 missing)
##       Income     < 108   to the left,  improve=0.07934779, (0 missing)
## 
## Node number 9: 102 observations,    complexity param=0.03031023
##   mean=6.609804, MSE=4.11588 
##   left son=18 (24 obs) right son=19 (78 obs)
##   Primary splits:
##       Price       < 137.5 to the right, improve=0.1395049, (0 missing)
##       CompPrice   < 115.5 to the left,  improve=0.1392917, (0 missing)
##       Advertising < 1.5   to the left,  improve=0.1264752, (0 missing)
##       Income      < 57.5  to the left,  improve=0.1142586, (0 missing)
##       Age         < 41.5  to the right, improve=0.1097995, (0 missing)
##   Surrogate splits:
##       CompPrice < 146.5 to the right, agree=0.784, adj=0.083, (0 split)
## 
## Node number 10: 61 observations,    complexity param=0.03346422
##   mean=7.634918, MSE=4.433897 
##   left son=20 (19 obs) right son=21 (42 obs)
##   Primary splits:
##       Age         < 63.5  to the right, improve=0.29164580, (0 missing)
##       CompPrice   < 118.5 to the left,  improve=0.27933670, (0 missing)
##       ShelveLoc   splits as  L-R,       improve=0.15817220, (0 missing)
##       Advertising < 6.5   to the left,  improve=0.12382570, (0 missing)
##       Income      < 100.5 to the left,  improve=0.08312164, (0 missing)
##   Surrogate splits:
##       Income    < 23    to the left,  agree=0.738, adj=0.158, (0 split)
##       Price     < 82    to the left,  agree=0.721, adj=0.105, (0 split)
##       CompPrice < 98.5  to the left,  agree=0.705, adj=0.053, (0 split)
##       Education < 17.5  to the right, agree=0.705, adj=0.053, (0 split)
##       US        splits as  LR,        agree=0.705, adj=0.053, (0 split)
## 
## Node number 11: 15 observations
##   mean=9.972, MSE=2.365376 
## 
## Node number 12: 34 observations,    complexity param=0.01046366
##   mean=8.891471, MSE=3.404365 
##   left son=24 (8 obs) right son=25 (26 obs)
##   Primary splits:
##       Price      < 142.5 to the right, improve=0.1953644, (0 missing)
##       Income     < 47    to the left,  improve=0.1566892, (0 missing)
##       Population < 309   to the left,  improve=0.1249829, (0 missing)
##       Age        < 63.5  to the right, improve=0.1084243, (0 missing)
##       CompPrice  < 121.5 to the left,  improve=0.1077935, (0 missing)
##   Surrogate splits:
##       CompPrice < 153.5 to the right, agree=0.824, adj=0.250, (0 split)
##       Age       < 73.5  to the right, agree=0.824, adj=0.250, (0 split)
##       Income    < 99.5  to the right, agree=0.794, adj=0.125, (0 split)
## 
## Node number 13: 8 observations
##   mean=12.08375, MSE=1.670748 
## 
## Node number 16: 37 observations
##   mean=4.276757, MSE=2.061233 
## 
## Node number 17: 7 observations
##   mean=7.301429, MSE=2.181184 
## 
## Node number 18: 24 observations,    complexity param=0.02611699
##   mean=5.24375, MSE=4.893423 
##   left son=36 (8 obs) right son=37 (16 obs)
##   Primary splits:
##       Age         < 62.5  to the right, improve=0.4805916, (0 missing)
##       US          splits as  LR,        improve=0.3063699, (0 missing)
##       CompPrice   < 147.5 to the left,  improve=0.2591207, (0 missing)
##       Income      < 62.5  to the left,  improve=0.2576809, (0 missing)
##       Advertising < 3     to the left,  improve=0.2189801, (0 missing)
##   Surrogate splits:
##       Population < 42.5  to the left,  agree=0.750, adj=0.250, (0 split)
##       CompPrice  < 126   to the left,  agree=0.708, adj=0.125, (0 split)
## 
## Node number 19: 78 observations,    complexity param=0.03031023
##   mean=7.030128, MSE=3.125778 
##   left son=38 (32 obs) right son=39 (46 obs)
##   Primary splits:
##       CompPrice   < 124.5 to the left,  improve=0.29711820, (0 missing)
##       Advertising < 11.5  to the left,  improve=0.15046820, (0 missing)
##       Age         < 49.5  to the right, improve=0.12311780, (0 missing)
##       Income      < 57    to the left,  improve=0.11554350, (0 missing)
##       Population  < 417   to the left,  improve=0.06191788, (0 missing)
##   Surrogate splits:
##       Price      < 111.5 to the left,  agree=0.718, adj=0.312, (0 split)
##       Income     < 46.5  to the left,  agree=0.654, adj=0.156, (0 split)
##       Population < 64    to the left,  agree=0.654, adj=0.156, (0 split)
##       Age        < 31    to the left,  agree=0.654, adj=0.156, (0 split)
## 
## Node number 20: 19 observations
##   mean=5.944211, MSE=3.075056 
## 
## Node number 21: 42 observations,    complexity param=0.01930149
##   mean=8.399762, MSE=3.170498 
##   left son=42 (16 obs) right son=43 (26 obs)
##   Primary splits:
##       CompPrice   < 112.5 to the left,  improve=0.31325010, (0 missing)
##       Advertising < 9.5   to the left,  improve=0.21887860, (0 missing)
##       ShelveLoc   splits as  L-R,       improve=0.12205630, (0 missing)
##       Urban       splits as  RL,        improve=0.11541270, (0 missing)
##       Price       < 102.5 to the left,  improve=0.06975165, (0 missing)
##   Surrogate splits:
##       Age        < 60.5  to the right, agree=0.690, adj=0.187, (0 split)
##       Income     < 72.5  to the right, agree=0.667, adj=0.125, (0 split)
##       Population < 183.5 to the right, agree=0.643, adj=0.063, (0 split)
## 
## Node number 24: 8 observations
##   mean=7.42125, MSE=2.730511 
## 
## Node number 25: 26 observations
##   mean=9.343846, MSE=2.74197 
## 
## Node number 36: 8 observations
##   mean=3.075, MSE=3.5504 
## 
## Node number 37: 16 observations
##   mean=6.328125, MSE=2.037328 
## 
## Node number 38: 32 observations
##   mean=5.874687, MSE=2.063987 
## 
## Node number 39: 46 observations,    complexity param=0.01325635
##   mean=7.833913, MSE=2.289619 
##   left son=78 (19 obs) right son=79 (27 obs)
##   Primary splits:
##       Price       < 127   to the right, improve=0.2720068, (0 missing)
##       Advertising < 10.5  to the left,  improve=0.1473440, (0 missing)
##       Population  < 327.5 to the left,  improve=0.1430243, (0 missing)
##       Income      < 33    to the left,  improve=0.1171523, (0 missing)
##       CompPrice   < 138.5 to the left,  improve=0.1077796, (0 missing)
##   Surrogate splits:
##       Income     < 33    to the left,  agree=0.674, adj=0.211, (0 split)
##       Education  < 16.5  to the right, agree=0.652, adj=0.158, (0 split)
##       Age        < 53    to the right, agree=0.630, adj=0.105, (0 split)
##       CompPrice  < 145.5 to the right, agree=0.609, adj=0.053, (0 split)
##       Population < 145.5 to the left,  agree=0.609, adj=0.053, (0 split)
## 
## Node number 42: 16 observations
##   mean=7.129375, MSE=2.605056 
## 
## Node number 43: 26 observations
##   mean=9.181538, MSE=1.914128 
## 
## Node number 78: 19 observations
##   mean=6.893158, MSE=2.573832 
## 
## Node number 79: 27 observations
##   mean=8.495926, MSE=1.028565
#Plot the tree
library(rattle)
## Warning: package 'rattle' was built under R version 4.4.3
## Loading required package: tibble
## Loading required package: bitops
## Rattle: A free graphical interface for data science with R.
## Version 5.5.1 Copyright (c) 2006-2021 Togaware Pty Ltd.
## Type 'rattle()' to shake, rattle, and roll your data.
## 
## Attaching package: 'rattle'
## The following object is masked from 'package:randomForest':
## 
##     importance
fancyRpartPlot(tree.carseats, sub = "", cex = 0.6)

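Shelving location (ShelveLoc) is the most influential variable: the root split separates Good locations (59 observations, mean Sales 10.28) from Bad and Medium locations (222 observations, mean Sales 6.75). Price provides the next split on both branches, with lower prices associated with higher sales, and the variable importance scores rank ShelveLoc (37) and Price (28) well ahead of the other predictors.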
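Each terminal node of a regression tree predicts the mean Sales of the training observations that fall into it. The sketch below checks this on a hypothetical tree fitted to the full Carseats data with rpart's default settings (so its nodes differ from the tree above). It relies on `fit$where`, which records the row of `fit$frame` (the leaf) that each training observation lands in.

```r
library(ISLR)
library(rpart)

# Illustrative tree on all of Carseats (not the tree fitted above)
fit <- rpart(Sales ~ ., data = Carseats, method = "anova")

# Mean Sales of the observations in each leaf, grouped by leaf row in fit$frame
leaf_means <- tapply(Carseats$Sales, fit$where, mean)
# The fitted value rpart stores for those same leaves
leaf_yval <- fit$frame$yval[as.integer(names(leaf_means))]

all.equal(as.numeric(leaf_means), leaf_yval)
```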

#What test MSE do you obtain?
# Predict on test data
predictions <- predict(tree.carseats, newdata = testingData)

# Calculate Test MSE
mse <- mean((predictions - testingData$Sales)^2)
cat("Test MSE:", mse, "\n")
## Test MSE: 4.638469
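A test MSE is easier to judge next to a baseline. The sketch below, under its own assumed split (seed 1, 280 training rows, default rpart settings, so the numbers will not match the ones above), compares a regression tree with a null model that always predicts the training mean, and reports the tree's RMSE on the scale of Sales.

```r
library(ISLR)
library(rpart)

# Hypothetical split for illustration only
set.seed(1)
idx <- sample(nrow(Carseats), 280)
tr <- Carseats[idx, ]
te <- Carseats[-idx, ]

fit <- rpart(Sales ~ ., data = tr, method = "anova")
mse_tree <- mean((predict(fit, newdata = te) - te$Sales)^2)
# Null model: predict the training mean for every test observation
mse_null <- mean((mean(tr$Sales) - te$Sales)^2)

c(tree = mse_tree, null = mse_null, rmse_tree = sqrt(mse_tree))
```

The square root of the MSE is in the same units as Sales, which makes the typical size of a prediction error easier to read.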

(c) Use cross-validation in order to determine the optimal level of tree complexity. Does pruning the tree improve the test MSE?

#Using cross-validation
printcp(tree.carseats)
## 
## Regression tree:
## rpart(formula = Sales ~ ., data = trainingData, method = "anova", 
##     control = rpart.control(minsplit = 15, cp = 0.01))
## 
## Variables actually used in tree construction:
## [1] Advertising Age         CompPrice   Price       ShelveLoc  
## 
## Root node error: 2161.1/281 = 7.6908
## 
## n= 281 
## 
##          CP nsplit rel error  xerror     xstd
## 1  0.268517      0   1.00000 1.01120 0.082898
## 2  0.096672      1   0.73148 0.74707 0.059022
## 3  0.048779      2   0.63481 0.69955 0.055106
## 4  0.041086      3   0.58603 0.66495 0.055383
## 5  0.033464      4   0.54495 0.61819 0.050177
## 6  0.030538      6   0.47802 0.61392 0.050001
## 7  0.030310      7   0.44748 0.61033 0.051965
## 8  0.026117      9   0.38686 0.60795 0.051377
## 9  0.024919     10   0.36074 0.58978 0.050870
## 10 0.019301     11   0.33582 0.57799 0.047716
## 11 0.013256     12   0.31652 0.56194 0.045860
## 12 0.010464     13   0.30327 0.53988 0.045306
## 13 0.010000     14   0.29280 0.55661 0.047212
plotcp(tree.carseats)

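The lowest cross-validated error (xerror = 0.53988) is reached at CP = 0.010464 with 13 splits, that is, 14 terminal nodes. The full tree (CP = 0.01, 14 splits) has a lower training error (rel error = 0.29280) but a higher cross-validated error (0.55661), so cross-validation favors the 13-split tree. Because xerror comes from random 10-fold cross-validation, these values change with the seed.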

tree.carseats$cptable[which.min(tree.carseats$cptable[,"xerror"]),"CP"]
## [1] 0.01046366
# pruning the tree 
carseats.prune=prune(tree.carseats,cp=tree.carseats$cptable[which.min(tree.carseats$cptable[,"xerror"]),"CP"])
fancyRpartPlot(carseats.prune, sub = "", cex = 0.6, uniform=TRUE, main="Pruned Regression Tree")

#Does pruning the tree improve the test MSE?
# Predict using pruned tree
predictions_pruned <- predict(carseats.prune, newdata = testingData)

# Compute test MSE for pruned tree
mse_pruned <- mean((predictions_pruned - testingData$Sales)^2)

# Print the result
cat("Test MSE (original tree):", mse, "\n")
## Test MSE (original tree): 4.638469
cat("Test MSE (pruned tree):", mse_pruned, "\n")
## Test MSE (pruned tree): 4.820072

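Pruning did not improve the test MSE in this case: the pruned tree's test MSE (4.82) is slightly higher than the unpruned tree's (4.64). The two trees differ by a single split, so the gap is small and may reflect this particular train/test split.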
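A common alternative to picking the minimum xerror is the one-standard-error rule: choose the smallest tree whose cross-validated error is within one xstd of the minimum (the dashed horizontal line drawn by `plotcp()` marks this threshold). The sketch below applies it to a hypothetical tree fitted to the full Carseats data with the same control settings; the seed and data are assumptions, so the selected size may differ from the analysis above.

```r
library(ISLR)
library(rpart)

set.seed(1)  # xerror comes from random cross-validation folds
fit <- rpart(Sales ~ ., data = Carseats, method = "anova",
             control = rpart.control(minsplit = 15, cp = 0.01))

cp_tab <- fit$cptable
best_row <- which.min(cp_tab[, "xerror"])
threshold <- cp_tab[best_row, "xerror"] + cp_tab[best_row, "xstd"]

# cptable rows are ordered by number of splits, so the first qualifying row is the smallest tree
se_row <- min(which(cp_tab[, "xerror"] <= threshold))
fit_1se <- prune(fit, cp = cp_tab[se_row, "CP"])

cp_tab[c(best_row, se_row), c("CP", "nsplit", "xerror")]
```

The one-SE tree is never larger than the minimum-xerror tree, trading a statistically insignificant increase in CV error for a simpler model.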

(d) Use the bagging approach in order to analyze this data. What test MSE do you obtain? Use the importance() function to determine which variables are most important.

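##Split into training and test (50/50)
# Note: this split is drawn before set.seed() is called below, so it depends on the RNG state left by the earlier chunks; call set.seed() first to make it reproducible on its own.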
inTrain=createDataPartition(df$Sales,p=0.5,list=FALSE)
train=df[inTrain,]
test=df[-inTrain,]
# Set seed for reproducibility
set.seed(123)

# Number of predictors
p <- ncol(train) - 1  # subtract 1 for the response variable

# Fit bagged model (mtry = p)
bagging.model <- randomForest(Sales ~ ., data = train, mtry = p, importance = TRUE)

# Predict on test data
bagging.pred <- predict(bagging.model, newdata = test)

# Compute test MSE
mse_bagging <- mean((bagging.pred - test$Sales)^2)
cat("Test MSE for the Bagging:", mse_bagging, "\n")
## Test MSE for the Bagging: 2.86157
# Use the importance() function to determine which variables are most important
library(randomForest)

# Show variable importance
randomForest::importance(bagging.model)
##                %IncMSE IncNodePurity
## CompPrice   20.6512551    129.600105
## Income       7.5062813     71.974410
## Advertising 17.1874846    131.326252
## Population   3.8035368     88.503247
## Price       55.2342436    435.411234
## ShelveLoc   56.9337181    447.490435
## Age         14.0201967     96.064386
## Education   -0.7150610     38.142492
## Urban       -0.5990693      4.951484
## US           3.1852843      7.125072
# Plot variable importance
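varImpPlot(bagging.model)

ShelveLoc and Price are by far the most important variables by both measures (%IncMSE of about 57 and 55, IncNodePurity of about 447 and 435), followed by CompPrice, Advertising and Age. Education and Urban have slightly negative %IncMSE, so permuting them does not hurt the model's predictions.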
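randomForest's %IncMSE permutes each variable in every tree's out-of-bag samples, averages the resulting increase in MSE over trees, and by default scales it by its standard deviation. The sketch below swaps in a simpler test-set version of the same idea: shuffle one predictor at a time in held-out data and record how much the test MSE rises. The split (seed 1, 200 training rows) and `ntree = 300` are assumptions for illustration, so the values will not match the table above.

```r
library(ISLR)
library(randomForest)

set.seed(1)
idx <- sample(nrow(Carseats), 200)
tr <- Carseats[idx, ]
te <- Carseats[-idx, ]

# Bagging: consider all 10 predictors at every split
bag <- randomForest(Sales ~ ., data = tr, mtry = ncol(tr) - 1, ntree = 300)
base_mse <- mean((predict(bag, newdata = te) - te$Sales)^2)

predictors <- setdiff(names(te), "Sales")
perm_increase <- sapply(predictors, function(v) {
  te_perm <- te
  te_perm[[v]] <- sample(te_perm[[v]])  # break the link between v and Sales
  mean((predict(bag, newdata = te_perm) - te$Sales)^2) - base_mse
})

sort(perm_increase, decreasing = TRUE)
```

Values near zero (or negative) mean the model barely uses that variable, the same reading as a small %IncMSE.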

(e) Use random forests to analyze this data. What test MSE do you obtain? Use the importance() function to determine which variables are most important. Describe the effect of m, the number of variables considered at each split, on the error rate obtained.

##fit random forest
set.seed(123)
Carseats.rf=train(Sales~.,data=train,method='rf',trControl = trainControl("cv", number = 10),importance = TRUE)

##best tuning parameter
Carseats.rf$bestTune
##   mtry
## 3   11
set.seed(123)
##final model
Carseats.rf$finalModel
## 
## Call:
##  randomForest(x = x, y = y, mtry = param$mtry, importance = TRUE) 
##                Type of random forest: regression
##                      Number of trees: 500
## No. of variables tried at each split: 11
## 
##           Mean of squared residuals: 2.742871
##                     % Var explained: 63.33

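By default, 500 trees are grown. The optimal number of variables sampled at each split is mtry = 11. Because caret dummy-codes the factors, the design has 11 predictor columns (see the importance table below), so mtry = 11 considers every predictor at each split and the final model is effectively bagging.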

#Use the importance to determine which variables are most important
randomForest::importance(Carseats.rf$finalModel)
##                     %IncMSE IncNodePurity
## CompPrice       22.70396742    132.621290
## Income           6.24288585     71.691182
## Advertising     17.17014169    129.390252
## Population       5.28984739     90.545716
## Price           51.11405427    446.249324
## ShelveLocGood   57.60366192    346.575965
## ShelveLocMedium 27.43929884    107.195435
## Age             11.71714823     93.236072
## Education        2.17564082     38.588585
## UrbanYes        -0.13924360      5.416655
## USYes           -0.01240649      6.032990
plot(varImp(Carseats.rf))

# What test MSE do you obtain?
rf.pred <- predict(Carseats.rf, newdata=test)
mse_rf <- mean((rf.pred - test$Sales)^2)
cat("Test MSE:", mse_rf, "\n")
## Test MSE: 2.951448

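10-fold cross-validation compared caret's default grid of mtry = 2, 6 and 11 and selected mtry = 11, the largest value tried. This model gives a test MSE of 2.95, and its out-of-bag estimate explains 63.3% of the variance in Sales. As for the effect of m: the cross-validated error was lowest when all predictors were considered at each split, so decorrelating the trees with a smaller m did not help on this data. One plausible explanation is that ShelveLoc and Price dominate, and a small m often forces splits on much weaker predictors. The cross-validated RMSE for each mtry is stored in Carseats.rf$results and can be plotted with plot(Carseats.rf).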
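To see the effect of m more directly than with a three-value grid, the sketch below fits one random forest per value of mtry from 1 to 10 and plots the test MSE. It assumes its own split (seed 1, 200 training rows) and `ntree = 300` for speed. With the formula interface, randomForest keeps factors as single columns, so there are 10 predictors here rather than caret's 11 dummy columns.

```r
library(ISLR)
library(randomForest)

set.seed(1)
idx <- sample(nrow(Carseats), 200)
tr <- Carseats[idx, ]
te <- Carseats[-idx, ]

m_values <- 1:(ncol(tr) - 1)  # 1 to 10 predictors considered per split
test_mse <- sapply(m_values, function(m) {
  fit <- randomForest(Sales ~ ., data = tr, mtry = m, ntree = 300)
  mean((predict(fit, newdata = te) - te$Sales)^2)
})

plot(m_values, test_mse, type = "b", pch = 19,
     xlab = "mtry (m)", ylab = "Test MSE",
     main = "Random forest test MSE by mtry")
```

The right end of the curve (m = 10) is bagging; small m gives more decorrelated but individually weaker trees.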

(f) Now analyze the data using BART, and report your results.

#Prepare the data
# Extract X and y for training and test
x.train <- train[, setdiff(names(train), "Sales")]
y.train <- train$Sales

x.test <- test[, setdiff(names(test), "Sales")]
y.test <- test$Sales
#Fit the BART model
set.seed(123)
bart.model <- wbart(x.train, y.train, x.test)
## *****Into main of wbart
## *****Data:
## data:n,p,np: 201, 14, 199
## y1,yn: 3.754328, 2.244328
## x1,x[n*p]: 111.000000, 1.000000
## xp1,xp[np*p]: 138.000000, 1.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 62 ... 1
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,0.251200,3.000000,0.197037
## *****sigma: 1.005748
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,14,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 100
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
## 
## MCMC
## done 0 (out of 1100)
## done 100 (out of 1100)
## done 200 (out of 1100)
## done 300 (out of 1100)
## done 400 (out of 1100)
## done 500 (out of 1100)
## done 600 (out of 1100)
## done 700 (out of 1100)
## done 800 (out of 1100)
## done 900 (out of 1100)
## done 1000 (out of 1100)
## time: 5s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
#Evaluate test MSE
# Predicted values on test data
yhat.test <- bart.model$yhat.test.mean

# Test MSE
mse.bart <- mean((yhat.test - y.test)^2)
cat("Test MSE (BART):", mse.bart, "\n")
## Test MSE (BART): 1.454383
#Variable importance plot
varcount <- colMeans(bart.model$varcount)
barplot(sort(varcount, decreasing = TRUE), 
        main = "Variable Importance (BART)",
        col = "steelblue", las = 2, 
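        ylab = "Avg. Number of Splits")

BART gives the lowest test MSE of the three ensemble methods: 1.45, compared with 2.86 for bagging and 2.95 for the random forest, all on the same 50/50 split. (The single-tree results in (b) and (c) use a different 70/30 split, so their test MSEs are not directly comparable.) The bar plot ranks the 14 dummy-coded predictors by how often BART splits on them, averaged over the posterior draws.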
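Because BART is Bayesian, `yhat.test` holds one row per kept posterior draw and one column per test observation, so it also yields uncertainty intervals. The sketch below refits a smaller BART model on an assumed split (seed 1, 200 training rows, `ndpost = 500`) and computes 95% posterior intervals for each test prediction. These are intervals for the mean function f(x) and exclude the noise term, so their coverage of the observed Sales values will usually fall well short of 95%.

```r
library(ISLR)
library(BART)

set.seed(1)
idx <- sample(nrow(Carseats), 200)
x <- Carseats[, setdiff(names(Carseats), "Sales")]
y <- Carseats$Sales

fit <- wbart(x[idx, ], y[idx], x[-idx, ], ndpost = 500, nskip = 100)

# 2.5% and 97.5% posterior quantiles of f(x) for each test observation
ci <- apply(fit$yhat.test, 2, quantile, probs = c(0.025, 0.975))
# Share of observed test Sales values inside their interval
coverage <- mean(y[-idx] >= ci[1, ] & y[-idx] <= ci[2, ])
coverage
```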

Question 9

This problem involves the OJ data set, which is part of the ISLR2 package.

library(ISLR2)
## Warning: package 'ISLR2' was built under R version 4.4.3
## 
## Attaching package: 'ISLR2'
## The following objects are masked from 'package:ISLR':
## 
##     Auto, Credit
#Load the data
attach(OJ)
d1 = OJ

# Check for missing values
sum(is.na(d1))
## [1] 0
#See data structure
str(d1)
## 'data.frame':    1070 obs. of  18 variables:
##  $ Purchase      : Factor w/ 2 levels "CH","MM": 1 1 1 2 1 1 1 1 1 1 ...
##  $ WeekofPurchase: num  237 239 245 227 228 230 232 234 235 238 ...
##  $ StoreID       : num  1 1 1 1 7 7 7 7 7 7 ...
##  $ PriceCH       : num  1.75 1.75 1.86 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
##  $ PriceMM       : num  1.99 1.99 2.09 1.69 1.69 1.99 1.99 1.99 1.99 1.99 ...
##  $ DiscCH        : num  0 0 0.17 0 0 0 0 0 0 0 ...
##  $ DiscMM        : num  0 0.3 0 0 0 0 0.4 0.4 0.4 0.4 ...
##  $ SpecialCH     : num  0 0 0 0 0 0 1 1 0 0 ...
##  $ SpecialMM     : num  0 1 0 0 0 1 1 0 0 0 ...
##  $ LoyalCH       : num  0.5 0.6 0.68 0.4 0.957 ...
##  $ SalePriceMM   : num  1.99 1.69 2.09 1.69 1.69 1.99 1.59 1.59 1.59 1.59 ...
##  $ SalePriceCH   : num  1.75 1.75 1.69 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
##  $ PriceDiff     : num  0.24 -0.06 0.4 0 0 0.3 -0.1 -0.16 -0.16 -0.16 ...
##  $ Store7        : Factor w/ 2 levels "No","Yes": 1 1 1 1 2 2 2 2 2 2 ...
##  $ PctDiscMM     : num  0 0.151 0 0 0 ...
##  $ PctDiscCH     : num  0 0 0.0914 0 0 ...
##  $ ListPriceDiff : num  0.24 0.24 0.23 0 0 0.3 0.3 0.24 0.24 0.24 ...
##  $ STORE         : num  1 1 1 1 0 0 0 0 0 0 ...

(a) Create a training set containing a random sample of 800 observations, and a test set containing the remaining observations.

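# Randomly choose 800 row indices for the training set
# Note: no seed is set before sample(), so this split (and every result below) changes from run to run; call set.seed() first to make it reproducible.
train_indices <- sample(1:nrow(d1), 800)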

# Create training and test sets
OJ_train <- OJ[train_indices, ]
OJ_test <- OJ[-train_indices, ]

(b) Fit a tree to the training data, with Purchase as the response and the other variables as predictors. Use the summary() function to produce summary statistics about the tree, and describe the results obtained. What is the training error rate? How many terminal nodes does the tree have?

library(tree)
## Warning: package 'tree' was built under R version 4.4.3
set.seed(123)
# Fit the decision tree
tree.oj <- tree(Purchase~., data=OJ_train)

# Summary statistics
summary(tree.oj)
## 
## Classification tree:
## tree(formula = Purchase ~ ., data = OJ_train)
## Variables actually used in tree construction:
## [1] "LoyalCH"     "SalePriceMM" "PriceMM"     "PriceDiff"  
## Number of terminal nodes:  8 
## Residual mean deviance:  0.7239 = 573.4 / 792 
## Misclassification error rate: 0.1662 = 133 / 800

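The tree has 8 terminal nodes and uses only four variables: LoyalCH, SalePriceMM, PriceMM and PriceDiff. Its training misclassification error rate is 0.166 (133 of 800 observations), and its residual mean deviance is 0.724, the total deviance of 573.4 divided by 800 - 8 = 792.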

#What is the training error rate?
#Predictions
predict_oj <- predict(tree.oj, OJ_train, type="class")
#Training error
training_error <- mean(predict_oj!= OJ_train$Purchase)
cat("Training error is:", training_error, "\n")
## Training error is: 0.16625

(c) Type in the name of the tree object in order to get a detailed text output. Pick one of the terminal nodes, and interpret the information displayed.

set.seed(123)
tree.oj
## node), split, n, deviance, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 800 1062.00 CH ( 0.62125 0.37875 )  
##    2) LoyalCH < 0.482304 300  326.00 MM ( 0.23333 0.76667 )  
##      4) LoyalCH < 0.13004 96   39.28 MM ( 0.05208 0.94792 ) *
##      5) LoyalCH > 0.13004 204  255.30 MM ( 0.31863 0.68137 )  
##       10) SalePriceMM < 2.04 115  112.30 MM ( 0.19130 0.80870 )  
##         20) PriceMM < 2.11 88   98.97 MM ( 0.25000 0.75000 ) *
##         21) PriceMM > 2.11 27    0.00 MM ( 0.00000 1.00000 ) *
##       11) SalePriceMM > 2.04 89  123.30 MM ( 0.48315 0.51685 ) *
##    3) LoyalCH > 0.482304 500  415.70 CH ( 0.85400 0.14600 )  
##      6) LoyalCH < 0.754144 237  276.50 CH ( 0.72996 0.27004 )  
##       12) PriceDiff < 0.265 144  192.50 CH ( 0.61111 0.38889 )  
##         24) PriceDiff < -0.35 16   15.44 MM ( 0.18750 0.81250 ) *
##         25) PriceDiff > -0.35 128  163.40 CH ( 0.66406 0.33594 ) *
##       13) PriceDiff > 0.265 93   54.54 CH ( 0.91398 0.08602 ) *
##      7) LoyalCH > 0.754144 263   78.44 CH ( 0.96578 0.03422 ) *

Node 13 is a terminal node with 93 observations for which LoyalCH is between 0.482304 and 0.754144 and PriceDiff is greater than 0.265. The predicted class is Citrus Hill (CH): 91.4% of the training observations in this node bought CH and 8.6% bought Minute Maid (MM). The node's deviance is 54.54, and it does not split further.
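The deviance reported for a node of a classification tree is -2 times the sum over classes of n_k log(p_k), where n_k is the number of node observations in class k and p_k = n_k / n. For node 13, n = 93 and yprob = (0.91398, 0.08602) imply 85 CH and 8 MM purchases, and the sketch below reproduces the printed deviance from those counts.

```r
# Class counts in node 13, recovered from n = 93 and yprob = (0.91398, 0.08602)
n_k <- c(CH = 85, MM = 8)
p_k <- n_k / sum(n_k)

# Multinomial deviance of the node
node_dev <- -2 * sum(n_k * log(p_k))
round(node_dev, 2)  # 54.54
```

A pure node (like node 21, with yprob 0 and 1) has deviance 0, which is why that line in the output shows 0.00.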

(d) Create a plot of the tree, and interpret the results

library(tree)
plot(tree.oj)
text(tree.oj, pretty = 0)

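The most important predictor of Purchase is LoyalCH (customer brand loyalty for Citrus Hill). It defines the root split and both splits on the second level: customers with LoyalCH below about 0.48 are mostly predicted to buy MM, and those above it are mostly predicted to buy CH. Price variables (SalePriceMM, PriceMM, PriceDiff) only refine the prediction further down the tree.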

(e) Predict the response on the test data, and produce a confusion matrix comparing the test labels to the predicted test labels. What is the test error rate?

set.seed(123)
#Predict the response on the test data
predict_test <- predict(tree.oj, newdata=OJ_test, type="class" )
#Test error rate
test_error_oj <- mean(predict_test != OJ_test$Purchase)
cat("Test Error Rate is:", test_error_oj, "\n")
## Test Error Rate is: 0.2111111
#Confusion matrix
confusion_matrix <- table(Predicted = predict_test, Actual = OJ_test$Purchase)
print(confusion_matrix)
##          Actual
## Predicted  CH  MM
##        CH 131  32
##        MM  25  82
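The test error rate can be read directly off the confusion matrix: it is the share of observations off the diagonal, (32 + 25) / 270. The sketch below re-enters the printed counts by hand and also computes the fraction of each actual class that was predicted correctly.

```r
# Confusion matrix from part (e): rows = predicted, columns = actual
cm <- matrix(c(131, 25, 32, 82), nrow = 2,
             dimnames = list(Predicted = c("CH", "MM"), Actual = c("CH", "MM")))

test_error <- 1 - sum(diag(cm)) / sum(cm)
test_error  # 57 / 270 = 0.2111111

# Per-class accuracy: correct predictions divided by actual class totals
diag(cm) / colSums(cm)
```

The tree is noticeably better at identifying CH buyers (131 of 156) than MM buyers (82 of 114).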

(f) Apply the cv.tree() function to the training set in order to determine the optimal tree size.

# Perform cross-validation to determine optimal tree size
cv_oj <- cv.tree(tree.oj, FUN = prune.misclass)

# Print the cross-validation results
print(cv_oj)
## $size
## [1] 8 5 2 1
## 
## $dev
## [1] 154 154 152 303
## 
## $k
## [1]       -Inf   0.000000   3.333333 160.000000
## 
## $method
## [1] "misclass"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"

With FUN = prune.misclass, $dev holds the number of cross-validated misclassifications. The minimum (152) occurs for the tree with 2 terminal nodes, while the trees with 5 and 8 terminal nodes are close behind (154 each). The single-node tree is much worse (303).

(g) Produce a plot with tree size on the x-axis and cross-validated classification error rate on the y-axis.

# Create the cross-validation plot; cv_oj$dev counts misclassifications, so divide by n to get a rate
plot(cv_oj$size, cv_oj$dev / nrow(OJ_train), type = "b", pch = 19, col = "blue",
     xlab = "Tree Size (Number of Terminal Nodes)",
     ylab = "CV Misclassification Error Rate",
     main = "Cross-Validation: Tree Size vs Error Rate")

(h) Which tree size corresponds to the lowest cross-validated classification error rate?
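The lowest cross-validated classification error rate corresponds to a tree size of 2 (152 misclassifications, an error rate of 152/800 = 0.19). Trees with 5 and 8 terminal nodes are close behind (154 misclassifications, 0.1925), so the differences between these sizes are small.

The pruned tree below uses five terminal nodes, the fallback size suggested by the exercise, rather than the CV-optimal size; pruning with best = 2 would give the CV-optimal tree.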

(i) Produce a pruned tree corresponding to the optimal tree size obtained using cross-validation. If cross-validation does not lead to selection of a pruned tree, then create a pruned tree with five terminal nodes.

# Prune the tree 
pruned_oj <- prune.tree(tree.oj, best = 5)

# Plot the pruned tree
plot(pruned_oj)
text(pruned_oj, pretty = 0)

(j) Compare the training error rates between the pruned and unpruned trees. Which is higher?

# Predictions on training data for unpruned tree
pred_unpruned <- predict(tree.oj, newdata = OJ_train, type = "class")
train_error_unpruned <- mean(pred_unpruned != OJ_train$Purchase)

# Predictions on training data for pruned tree
pred_pruned <- predict(pruned_oj, newdata = OJ_train, type = "class")
train_error_pruned <- mean(pred_pruned != OJ_train$Purchase)

# Print results
cat("Training error rate (Unpruned tree):", train_error_unpruned, "\n")
## Training error rate (Unpruned tree): 0.16625
cat("Training error rate (Pruned tree):", train_error_pruned, "\n")
## Training error rate (Pruned tree): 0.17875

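The pruned tree has a slightly higher training error rate (0.179) than the unpruned tree (0.166). This is expected: merging leaves can never reduce the number of misclassified training observations, so a pruned tree cannot fit the training data better than the full tree.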

(k) Compare the test error rates between the pruned and unpruned trees. Which is higher?

# Predict on test data using unpruned tree
pred_test_unpruned <- predict(tree.oj, newdata = OJ_test, type = "class")
test_error_unpruned <- mean(pred_test_unpruned != OJ_test$Purchase)

# Predict on test data using pruned tree
pred_test_pruned <- predict(pruned_oj, newdata = OJ_test, type = "class")
test_error_pruned <- mean(pred_test_pruned != OJ_test$Purchase)

# Print the test error rates
cat("Test error rate (Unpruned tree):", test_error_unpruned, "\n")
## Test error rate (Unpruned tree): 0.2111111
cat("Test error rate (Pruned tree):", test_error_pruned, "\n")
## Test error rate (Pruned tree): 0.2259259

The pruned tree also has a slightly higher test error rate (0.226) than the unpruned tree (0.211), that is, 61 versus 57 misclassified test observations out of 270. In conclusion, pruning to five terminal nodes did not reduce the misclassification error rate in this case.