Consider the Gini index, classification error, and entropy in a simple classification setting with two classes. Create a single plot that displays each of these quantities as a function of \(\hat{p}_{m1}\). The x-axis should display \(\hat{p}_{m1}\), ranging from 0 to 1, and the y-axis should display the value of the Gini index, classification error, and entropy.
Hint: In a setting with two classes, \(\hat{p}_{m1} = 1-\hat{p}_{m2}\). You could make this plot by hand, but it will be much easier to make in R.
A) Define the classification error rate, Gini index, and entropy, the three node-impurity measures used for classification trees.
Classification error rate \(E\): when class 1 is the more common class in the node, \(E = 1-\hat{p}_{m1}\); when class 1 is the less common class, \(E = 1-\hat{p}_{m2} = 1-(1-\hat{p}_{m1})\).
Combining both cases, \(E = 1-\max\{\hat{p}_{m1},\hat{p}_{m2}\}\).
Gini index \(G\): \(G = \hat{p}_{m1}(1-\hat{p}_{m1})+\hat{p}_{m2}(1-\hat{p}_{m2})\)
Entropy \(D\): \(D = -\hat{p}_{m1}\log\hat{p}_{m1}-(1-\hat{p}_{m1})\log(1-\hat{p}_{m1})\)
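As a quick check on these formulas (using the natural log for entropy): at \(\hat{p}_{m1}=0.5\) all three measures reach their maximum,
\[E = 0.5,\qquad G = 2(0.5)(0.5) = 0.5,\qquad D = -0.5\log 0.5 - 0.5\log 0.5 = \log 2 \approx 0.693,\]
while at \(\hat{p}_{m1}=0.9\) the node is much purer and all three shrink:
\[E = 0.1,\qquad G = 2(0.9)(0.1) = 0.18,\qquad D \approx 0.325.\]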
library(tidyverse)  # ggplot2, dplyr and tidyr (%>%, ggplot(), pivot_longer()) are used below
p_class_1 <- seq(0, 1, 0.001)
p_class_2 <- 1 - p_class_1
classification_error_E <- 1 - pmax(p_class_1, p_class_2)
gini <- p_class_1 * (1 - p_class_1) + p_class_2 * (1 - p_class_2)
entropy <- -p_class_1 * log(p_class_1) - p_class_2 * log(p_class_2)
errors <- data.frame(p_class_1, p_class_2, classification_error_E, gini, entropy)
errors <- na.omit(errors)  # entropy is NaN at p = 0 and p = 1 (0 * log(0))
errors %>%
  pivot_longer(cols = c(classification_error_E, gini, entropy), names_to = "ErrorType") %>%
  ggplot(aes(x = p_class_1, y = value, col = factor(ErrorType))) +
  geom_line() +
  labs(col = "Error Type", y = "Value", x = "Proportion (of class '1')")
In the lab, a classification tree was applied to the Carseats data set after converting Sales into a qualitative response variable. Now we will seek to predict Sales using regression trees and related approaches, treating the response as a quantitative variable.
(a) Split the data set into a training set and a test set.
library(ISLR)   # Carseats and OJ data sets
library(tree)   # tree(), cv.tree(), prune.tree()
library(caret)  # createDataPartition(), train(), varImp()
set.seed(2)
df_carseats <- Carseats
## Split into training (70%) and test (30%) sets
inTrain <- createDataPartition(df_carseats$Sales, p = 0.7, list = FALSE)
train_cs <- df_carseats[inTrain, ]
train <- train_cs  # alias used by the caret fits below
test_cs <- df_carseats[-inTrain, ]
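A quick sanity check on the split sizes (a minimal sketch, assuming the train_cs and test_cs objects created above):
nrow(train_cs)  # roughly 70% of the 400 Carseats rows (281 in the output below)
nrow(test_cs)   # the remaining rows (119)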
(b) Fit a regression tree to the training set. Plot the tree, and interpret the results. What test MSE do you obtain?
# Regression tree on training set.
tree_carseats = tree(Sales~.,data=train_cs)
summary(tree_carseats)
##
## Regression tree:
## tree(formula = Sales ~ ., data = train_cs)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Age" "Advertising" "CompPrice"
## [6] "US"
## Number of terminal nodes: 17
## Residual mean deviance: 2.444 = 645.2 / 264
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -3.8830 -0.9533 0.0352 0.0000 0.9613 3.9750
plot(tree_carseats)
text(tree_carseats,pretty=1)
# Test MSE.
cs_pred = predict(tree_carseats,test_cs)
cs_mse = mean((cs_pred-test_cs$Sales)^2)
cs_mse
## [1] 4.45589
Using the rpart method to cross-check the tree and the optimal number of nodes:
library(rpart)   # rpart()
library(rattle)  # fancyRpartPlot()
tree.carseats <- rpart(Sales ~ ., data = train_cs, method = "anova", control = rpart.control(minsplit = 17, cp = 0.01))
fancyRpartPlot(tree.carseats)
summary(tree.carseats)
## Call:
## rpart(formula = Sales ~ ., data = train_cs, method = "anova",
## control = rpart.control(minsplit = 17, cp = 0.01))
## n= 281
##
## CP nsplit rel error xerror xstd
## 1 0.24712187 0 1.0000000 1.0121632 0.08401072
## 2 0.10788188 1 0.7528781 0.7701805 0.06132119
## 3 0.06849690 2 0.6449963 0.7205982 0.05836172
## 4 0.03222861 3 0.5764994 0.6943517 0.05425237
## 5 0.03149356 4 0.5442707 0.7476479 0.05837820
## 6 0.03141147 5 0.5127772 0.7476479 0.05837820
## 7 0.02993960 7 0.4499542 0.7526464 0.05837650
## 8 0.02260164 9 0.3900750 0.7302173 0.05886077
## 9 0.01856731 10 0.3674734 0.7111268 0.06074441
## 10 0.01500930 11 0.3489061 0.7117272 0.06265321
## 11 0.01189172 12 0.3338968 0.6641162 0.05704285
## 12 0.01109414 13 0.3220051 0.6762581 0.05563152
## 13 0.01016251 14 0.3109109 0.6617373 0.05405029
## 14 0.01000000 15 0.3007484 0.6604523 0.05416623
##
## Variable importance
## Price ShelveLoc CompPrice Advertising Population Age
## 27 27 13 10 6 6
## US Income Education Urban
## 4 3 2 1
##
## Node number 1: 281 observations, complexity param=0.2471219
## mean=7.569146, MSE=7.916595
## left son=2 (220 obs) right son=3 (61 obs)
## Primary splits:
## ShelveLoc splits as LRL, improve=0.24712190, (0 missing)
## Price < 95 to the right, improve=0.16393610, (0 missing)
## Age < 61.5 to the right, improve=0.07614442, (0 missing)
## Advertising < 9.5 to the left, improve=0.06608259, (0 missing)
## US splits as LR, improve=0.03557244, (0 missing)
## Surrogate splits:
## Price < 168.5 to the left, agree=0.79, adj=0.033, (0 split)
##
## Node number 2: 220 observations, complexity param=0.1078819
## mean=6.832636, MSE=5.672099
## left son=4 (143 obs) right son=5 (77 obs)
## Primary splits:
## Price < 105.5 to the right, improve=0.19232100, (0 missing)
## Advertising < 7.5 to the left, improve=0.08436887, (0 missing)
## ShelveLoc splits as L-R, improve=0.06835498, (0 missing)
## Age < 61.5 to the right, improve=0.06033545, (0 missing)
## Income < 96.5 to the left, improve=0.04663377, (0 missing)
## Surrogate splits:
## CompPrice < 108.5 to the right, agree=0.745, adj=0.273, (0 split)
## Population < 506 to the left, agree=0.659, adj=0.026, (0 split)
##
## Node number 3: 61 observations, complexity param=0.0684969
## mean=10.22541, MSE=6.999399
## left son=6 (42 obs) right son=7 (19 obs)
## Primary splits:
## Price < 107.5 to the right, improve=0.35688240, (0 missing)
## Age < 61.5 to the right, improve=0.23528680, (0 missing)
## Education < 11.5 to the right, improve=0.19586070, (0 missing)
## Advertising < 13.5 to the left, improve=0.09837488, (0 missing)
## Population < 77 to the right, improve=0.06074540, (0 missing)
## Surrogate splits:
## Population < 77 to the right, agree=0.754, adj=0.211, (0 split)
## Education < 11.5 to the right, agree=0.754, adj=0.211, (0 split)
## CompPrice < 113.5 to the right, agree=0.738, adj=0.158, (0 split)
## Income < 29 to the right, agree=0.721, adj=0.105, (0 split)
##
## Node number 4: 143 observations, complexity param=0.03222861
## mean=6.066224, MSE=4.25466
## left son=8 (59 obs) right son=9 (84 obs)
## Primary splits:
## Price < 129.5 to the right, improve=0.11783800, (0 missing)
## Advertising < 10.5 to the left, improve=0.10798530, (0 missing)
## CompPrice < 124.5 to the left, improve=0.10007040, (0 missing)
## ShelveLoc splits as L-R, improve=0.09859884, (0 missing)
## Age < 50.5 to the right, improve=0.06782543, (0 missing)
## Surrogate splits:
## CompPrice < 135.5 to the right, agree=0.692, adj=0.254, (0 split)
## Income < 38 to the left, agree=0.615, adj=0.068, (0 split)
## Population < 487 to the right, agree=0.615, adj=0.068, (0 split)
## Education < 17.5 to the right, agree=0.608, adj=0.051, (0 split)
## Advertising < 20 to the right, agree=0.601, adj=0.034, (0 split)
##
## Node number 5: 77 observations, complexity param=0.03141147
## mean=8.255974, MSE=5.187731
## left son=10 (61 obs) right son=11 (16 obs)
## Primary splits:
## Price < 80.5 to the right, improve=0.1663795, (0 missing)
## Age < 54.5 to the right, improve=0.1510254, (0 missing)
## ShelveLoc splits as L-R, improve=0.1397998, (0 missing)
## CompPrice < 123.5 to the left, improve=0.1365667, (0 missing)
## Income < 102 to the left, improve=0.1335882, (0 missing)
## Surrogate splits:
## CompPrice < 98.5 to the right, agree=0.870, adj=0.375, (0 split)
## Population < 34 to the right, agree=0.831, adj=0.188, (0 split)
##
## Node number 6: 42 observations, complexity param=0.03149356
## mean=9.162381, MSE=5.138932
## left son=12 (10 obs) right son=13 (32 obs)
## Primary splits:
## US splits as LR, improve=0.3245969, (0 missing)
## Advertising < 13.5 to the left, improve=0.3203558, (0 missing)
## Price < 144 to the right, improve=0.2947885, (0 missing)
## Age < 61.5 to the right, improve=0.2062498, (0 missing)
## Population < 345.5 to the left, improve=0.1659646, (0 missing)
## Surrogate splits:
## Advertising < 1 to the left, agree=0.929, adj=0.7, (0 split)
## Age < 76.5 to the right, agree=0.810, adj=0.2, (0 split)
## Income < 32.5 to the left, agree=0.786, adj=0.1, (0 split)
##
## Node number 7: 19 observations
## mean=12.57526, MSE=3.092235
##
## Node number 8: 59 observations, complexity param=0.02260164
## mean=5.221356, MSE=3.780378
## left son=16 (42 obs) right son=17 (17 obs)
## Primary splits:
## CompPrice < 142 to the left, improve=0.22542260, (0 missing)
## ShelveLoc splits as L-R, improve=0.22141920, (0 missing)
## Age < 65 to the right, improve=0.11986160, (0 missing)
## Advertising < 15.5 to the left, improve=0.10363610, (0 missing)
## Income < 30.5 to the right, improve=0.03894448, (0 missing)
## Surrogate splits:
## Price < 153 to the left, agree=0.797, adj=0.294, (0 split)
## Income < 30.5 to the right, agree=0.763, adj=0.176, (0 split)
## Advertising < 23 to the left, agree=0.746, adj=0.118, (0 split)
## Age < 27.5 to the right, agree=0.746, adj=0.118, (0 split)
## Population < 129 to the right, agree=0.729, adj=0.059, (0 split)
##
## Node number 9: 84 observations, complexity param=0.0299396
## mean=6.659643, MSE=3.73428
## left son=18 (18 obs) right son=19 (66 obs)
## Primary splits:
## CompPrice < 115.5 to the left, improve=0.20905150, (0 missing)
## Advertising < 10.5 to the left, improve=0.18851150, (0 missing)
## Age < 52.5 to the right, improve=0.12014750, (0 missing)
## Income < 72 to the left, improve=0.06484556, (0 missing)
## ShelveLoc splits as L-R, improve=0.04568950, (0 missing)
## Surrogate splits:
## Population < 479 to the right, agree=0.845, adj=0.278, (0 split)
##
## Node number 10: 61 observations, complexity param=0.03141147
## mean=7.780164, MSE=4.604582
## left son=20 (25 obs) right son=21 (36 obs)
## Primary splits:
## Age < 60.5 to the right, improve=0.2609391, (0 missing)
## CompPrice < 123.5 to the left, improve=0.2180911, (0 missing)
## ShelveLoc splits as L-R, improve=0.1603595, (0 missing)
## Advertising < 6.5 to the left, improve=0.1155796, (0 missing)
## Population < 157 to the right, improve=0.1089737, (0 missing)
## Surrogate splits:
## Population < 250 to the right, agree=0.672, adj=0.20, (0 split)
## Education < 17.5 to the right, agree=0.623, adj=0.08, (0 split)
## Advertising < 19 to the right, agree=0.607, adj=0.04, (0 split)
## Price < 81.5 to the left, agree=0.607, adj=0.04, (0 split)
## ShelveLoc splits as L-R, agree=0.607, adj=0.04, (0 split)
##
## Node number 11: 16 observations
## mean=10.07, MSE=3.257162
##
## Node number 12: 10 observations
## mean=6.852, MSE=2.193696
##
## Node number 13: 32 observations, complexity param=0.01856731
## mean=9.884375, MSE=3.869962
## left son=26 (25 obs) right son=27 (7 obs)
## Primary splits:
## Advertising < 13.5 to the left, improve=0.33353170, (0 missing)
## Population < 345.5 to the left, improve=0.18606440, (0 missing)
## Price < 137.5 to the right, improve=0.17161020, (0 missing)
## Age < 61.5 to the right, improve=0.16952270, (0 missing)
## CompPrice < 120 to the left, improve=0.09643594, (0 missing)
## Surrogate splits:
## Population < 345.5 to the left, agree=0.844, adj=0.286, (0 split)
##
## Node number 16: 42 observations, complexity param=0.01109414
## mean=4.634048, MSE=2.690024
## left son=32 (16 obs) right son=33 (26 obs)
## Primary splits:
## ShelveLoc splits as L-R, improve=0.2184404, (0 missing)
## Age < 63.5 to the right, improve=0.1698168, (0 missing)
## CompPrice < 123 to the left, improve=0.1223636, (0 missing)
## Income < 89 to the left, improve=0.1123349, (0 missing)
## Advertising < 10.5 to the left, improve=0.1116564, (0 missing)
## Surrogate splits:
## Price < 130.5 to the left, agree=0.690, adj=0.187, (0 split)
## Advertising < 18 to the right, agree=0.667, adj=0.125, (0 split)
## Population < 148.5 to the left, agree=0.667, adj=0.125, (0 split)
## Age < 33.5 to the left, agree=0.667, adj=0.125, (0 split)
## CompPrice < 136.5 to the right, agree=0.643, adj=0.063, (0 split)
##
## Node number 17: 17 observations
## mean=6.672353, MSE=3.516618
##
## Node number 18: 18 observations
## mean=4.967778, MSE=2.431084
##
## Node number 19: 66 observations, complexity param=0.0299396
## mean=7.121061, MSE=3.096134
## left son=38 (47 obs) right son=39 (19 obs)
## Primary splits:
## Advertising < 10.5 to the left, improve=0.33095980, (0 missing)
## Age < 49.5 to the right, improve=0.17763200, (0 missing)
## ShelveLoc splits as L-R, improve=0.11599380, (0 missing)
## Population < 204 to the left, improve=0.10467630, (0 missing)
## CompPrice < 136 to the left, improve=0.09221327, (0 missing)
## Surrogate splits:
## Income < 115.5 to the left, agree=0.773, adj=0.211, (0 split)
## Population < 417 to the left, agree=0.773, adj=0.211, (0 split)
## Age < 30.5 to the right, agree=0.727, adj=0.053, (0 split)
##
## Node number 20: 25 observations
## mean=6.4648, MSE=3.115681
##
## Node number 21: 36 observations, complexity param=0.0150093
## mean=8.693611, MSE=3.60264
## left son=42 (18 obs) right son=43 (18 obs)
## Primary splits:
## Advertising < 6 to the left, improve=0.2574434, (0 missing)
## CompPrice < 117.5 to the left, improve=0.2198174, (0 missing)
## Age < 30 to the right, improve=0.1715830, (0 missing)
## US splits as LR, improve=0.1030712, (0 missing)
## ShelveLoc splits as L-R, improve=0.0934053, (0 missing)
## Surrogate splits:
## US splits as LR, agree=0.861, adj=0.722, (0 split)
## Urban splits as RL, agree=0.750, adj=0.500, (0 split)
## Income < 98 to the left, agree=0.639, adj=0.278, (0 split)
## Price < 95 to the left, agree=0.639, adj=0.278, (0 split)
## Population < 420 to the left, agree=0.611, adj=0.222, (0 split)
##
## Node number 26: 25 observations
## mean=9.2832, MSE=2.791598
##
## Node number 27: 7 observations
## mean=12.03143, MSE=1.820669
##
## Node number 32: 16 observations
## mean=3.656875, MSE=1.507946
##
## Node number 33: 26 observations
## mean=5.235385, MSE=2.46824
##
## Node number 38: 47 observations, complexity param=0.01016251
## mean=6.477447, MSE=2.221959
## left son=76 (39 obs) right son=77 (8 obs)
## Primary splits:
## Age < 37.5 to the right, improve=0.21647690, (0 missing)
## CompPrice < 142.5 to the left, improve=0.19124320, (0 missing)
## ShelveLoc splits as L-R, improve=0.18904880, (0 missing)
## Income < 54 to the right, improve=0.04120671, (0 missing)
## Education < 17.5 to the left, improve=0.04067719, (0 missing)
## Surrogate splits:
## CompPrice < 154 to the left, agree=0.872, adj=0.25, (0 split)
##
## Node number 39: 19 observations
## mean=8.713158, MSE=1.699095
##
## Node number 42: 18 observations, complexity param=0.01189172
## mean=7.730556, MSE=3.411883
## left son=84 (12 obs) right son=85 (6 obs)
## Primary splits:
## CompPrice < 126.5 to the left, improve=0.4307474, (0 missing)
## Education < 16.5 to the left, improve=0.3272920, (0 missing)
## Population < 156 to the right, improve=0.1870630, (0 missing)
## Price < 95.5 to the right, improve=0.1399644, (0 missing)
## ShelveLoc splits as L-R, improve=0.1035401, (0 missing)
## Surrogate splits:
## Age < 30.5 to the right, agree=0.778, adj=0.333, (0 split)
## Education < 15.5 to the left, agree=0.778, adj=0.333, (0 split)
## Population < 80 to the right, agree=0.722, adj=0.167, (0 split)
##
## Node number 43: 18 observations
## mean=9.656667, MSE=1.938444
##
## Node number 76: 39 observations
## mean=6.163333, MSE=1.688068
##
## Node number 77: 8 observations
## mean=8.00875, MSE=1.998786
##
## Node number 84: 12 observations
## mean=6.873333, MSE=1.929339
##
## Node number 85: 6 observations
## mean=9.445, MSE=1.967992
cs.pred = predict(tree.carseats,test_cs)
cs.mse = mean((cs.pred-test_cs$Sales)^2)
cs.mse
## [1] 4.621769
A) Based on the summary and the tree plot, the predictors used in tree construction are ShelveLoc, Price, Age, Advertising, CompPrice, and US; the most important are ShelveLoc and Price. The tree has 17 terminal nodes and
\(Regression\ Tree\ MSE = 4.4559\).
(c) Use cross-validation in order to determine the optimal level of tree complexity. Does pruning the tree improve the test MSE?
set.seed(2)
cv_carseats = cv.tree(tree_carseats)
# plot(cv_carseats$size, cv_carseats$dev, xlab = "Terminal Nodes", ylab = "CV Error", type = "b")
data.frame(terminals = cv_carseats$size, cverror = cv_carseats$dev) %>%
  ggplot(aes(x = terminals, y = cverror)) +
  geom_line() +
  geom_point(size = 2) +
  geom_point(size = 4, aes(x = 16, y = 1325), col = "blue") +  # highlight the CV minimum
  scale_x_continuous(breaks = seq(1, 17, 1)) +
  labs(y = "CV Error", x = "Terminal Nodes")
A) Cross-validation selects 16 terminal nodes as optimal. Now prune the tree to that size:
prune_carseats = prune.tree(tree_carseats,best=16)
prune_pred = predict(prune_carseats,test_cs)
prune_mse = mean((prune_pred-test_cs$Sales)^2)
prune_mse
## [1] 4.407265
The test MSE for the pruned tree is slightly better than for the unpruned tree: \(Pruned\ Tree\ MSE = 4.4073\).
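The optimal size can also be extracted programmatically instead of being read off the plot; a minimal sketch, assuming the cv_carseats object fitted above:
# number of terminal nodes with the lowest cross-validation deviance
cv_carseats$size[which.min(cv_carseats$dev)]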
(d) Use the bagging approach in order to analyze this data. What test MSE do you obtain? Use the importance() function to determine which variables are most important.
## Bagging: a random forest with m = p (all predictors considered at each split)
mtry <- ncol(train_cs) - 1  # p = number of predictors
tunegrid <- expand.grid(.mtry = mtry)
bag_rf = train(Sales ~ ., data = train, method = 'rf', trControl = trainControl("cv", number = 10), tuneGrid = tunegrid, importance = TRUE)
# important predictors
varImp(bag_rf)
## rf variable importance
##
## Overall
## ShelveLocGood 100.000
## Price 96.661
## CompPrice 45.324
## ShelveLocMedium 35.935
## Advertising 33.300
## Age 29.262
## Income 15.798
## USYes 14.748
## Population 5.330
## Education 2.036
## UrbanYes 0.000
##final model mse - validation/train MSE
# bag_rf$finalModel$mse
#test MSE
test_bag_mse <- mean((predict(bag_rf,newdata=test_cs)-test_cs$Sales)^2)
test_bag_mse
## [1] 2.30439
A) The bagging approach reduces the test MSE by almost 50% compared to the single regression tree. The most important predictors again stand out as ShelveLoc and Price, the same as for the regression tree.
\(Bagging\ test\ MSE = 2.3044\)
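Since the question asks for the importance() function, an alternative to caret's varImp() is to fit the bagged model directly with the randomForest package; a minimal sketch, assuming randomForest is installed (numbers will differ slightly from the caret fit above because of different resampling and seeds):
library(randomForest)
set.seed(2)
p <- ncol(train_cs) - 1  # number of predictors
# Bagging is a random forest with mtry = p (all predictors tried at each split)
bag_direct <- randomForest(Sales ~ ., data = train_cs, mtry = p, importance = TRUE)
importance(bag_direct)   # %IncMSE and IncNodePurity for each predictor
mean((predict(bag_direct, newdata = test_cs) - test_cs$Sales)^2)  # test MSE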
(e) Use random forests to analyze this data. What test MSE do you obtain? Use the importance() function to determine which variables are most important. Describe the effect of m, the number of variables considered at each split, on the error rate obtained.
cs_rf=train(Sales~.,data=train,method='rf',trControl = trainControl("cv", number = 10),importance = TRUE)
#test rf MSE
cs_rf$finalModel
##
## Call:
## randomForest(x = x, y = y, mtry = min(param$mtry, ncol(x)), importance = TRUE)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 11
##
## Mean of squared residuals: 2.7107
## % Var explained: 65.76
cs_rf$bestTune
## mtry
## 3 11
varImp(cs_rf)
## rf variable importance
##
## Overall
## Price 100.000
## ShelveLocGood 99.188
## CompPrice 43.288
## ShelveLocMedium 33.365
## Advertising 32.767
## Age 27.657
## Income 12.885
## USYes 12.365
## Education 6.114
## Population 2.097
## UrbanYes 0.000
test_rf_mse_1 <- mean((predict(cs_rf,newdata=test_cs)-test_cs$Sales)^2)
test_rf_mse_1
## [1] 2.246459
\(m = p/2 = 5\)
##use Random Forest with mtry = p/2, p/3, p/4, sqrt(p), and p-1
mtry <- ncol(train_cs)-1
tunegrid <- expand.grid(.mtry=mtry/2)
cs_rf=train(Sales~.,data=train,method='rf',trControl = trainControl("cv", number = 10),tuneGrid=tunegrid,importance = TRUE)
# important predictors
varImp(cs_rf)
## rf variable importance
##
## Overall
## ShelveLocGood 100.000
## Price 91.033
## Advertising 35.868
## CompPrice 34.796
## ShelveLocMedium 33.710
## Age 32.231
## USYes 15.670
## Income 15.505
## Education 6.131
## Population 3.373
## UrbanYes 0.000
#test rf MSE
test_rf_mse <- mean((predict(cs_rf,newdata=test_cs)-test_cs$Sales)^2)
test_rf_mse
## [1] 2.591315
\(m = p/3 \approx 3.33\)
tunegrid <- expand.grid(.mtry=mtry/3)
cs_rf=train(Sales~.,data=train,method='rf',trControl = trainControl("cv", number = 10),tuneGrid=tunegrid,importance = TRUE)
#test rf MSE
test_rf_mse <- mean((predict(cs_rf,newdata=test_cs)-test_cs$Sales)^2)
test_rf_mse
## [1] 3.080292
\(m = p/4 = 2.5\)
tunegrid <- expand.grid(.mtry=mtry/4)
cs_rf=train(Sales~.,data=train,method='rf',trControl = trainControl("cv", number = 10),tuneGrid=tunegrid,importance = TRUE)
#test rf MSE
test_rf_mse <- mean((predict(cs_rf,newdata=test_cs)-test_cs$Sales)^2)
test_rf_mse
## [1] 3.669904
\(m = \sqrt{p} \approx 3.16\)
tunegrid <- expand.grid(.mtry=sqrt(mtry))
cs_rf=train(Sales~.,data=train,method='rf',trControl = trainControl("cv", number = 10),tuneGrid=tunegrid,importance = TRUE)
#test rf MSE
test_rf_mse <- mean((predict(cs_rf,newdata=test_cs)-test_cs$Sales)^2)
test_rf_mse
## [1] 3.067804
\(m = p - 1 = 9\)
tunegrid <- expand.grid(.mtry=mtry-1)
cs_rf=train(Sales~.,data=train,method='rf',trControl = trainControl("cv", number = 10),tuneGrid=tunegrid,importance = TRUE)
#test rf MSE
test_rf_mse <- mean((predict(cs_rf,newdata=test_cs)-test_cs$Sales)^2)
test_rf_mse
## [1] 2.296652
A) The number of variables tried at each split (m) has an inverse relationship with the test MSE here: the larger m is, the lower the test MSE. Random forest with cross-validation selects m = 11 (caret tunes over the 11 dummy-coded predictor columns), and the resulting test MSE is slightly better than bagging. \(Test\ MSE\ Random\ Forest = 2.2465\)
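To see the effect of m more systematically, the test MSE can be tabulated over a grid of mtry values with randomForest directly; a minimal sketch, assuming the randomForest package (results will differ somewhat from the caret runs above because of different resampling and seeds):
library(randomForest)
p <- ncol(train_cs) - 1
m_grid <- c(p, p - 1, floor(p / 2), floor(sqrt(p)), floor(p / 4))
set.seed(2)
test_mse_by_m <- sapply(m_grid, function(m) {
  fit <- randomForest(Sales ~ ., data = train_cs, mtry = m)
  mean((predict(fit, newdata = test_cs) - test_cs$Sales)^2)
})
data.frame(m = m_grid, test_mse = test_mse_by_m)  # test MSE tends to fall as m grows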
This problem involves the OJ data set, which is part of the ISLR package. (a) Create a training set containing a random sample of 800 observations, and a test set containing the remaining observations.
df_oj <- OJ
set.seed(2)
train_index <- sample(1:nrow(OJ), 800)
train_oj <- df_oj[train_index, ]
test_oj <- df_oj[-train_index, ]
(b) Fit a tree to the training data, with Purchase as the response and the other variables as predictors. Use the summary() function to produce summary statistics about the tree, and describe the results obtained. What is the training error rate? How many terminal nodes does the tree have?
tree.OJ = tree(Purchase~.,data=train_oj)
summary(tree.OJ)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = train_oj)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff"
## Number of terminal nodes: 9
## Residual mean deviance: 0.7009 = 554.4 / 791
## Misclassification error rate: 0.1588 = 127 / 800
\(Training\ Misclassification\ Error\ Rate = 15.88\%\). The predictors used in the tree construction are LoyalCH and PriceDiff, and the tree has 9 terminal nodes. The residual mean deviance, a measure of the error remaining after the tree is fit, is 0.70 on this training set.
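The training error rate can also be recomputed directly from the in-sample predictions; a short sketch using the tree.OJ fit above (it should match the summary() figure of about 15.9%):
train_pred_oj <- predict(tree.OJ, newdata = train_oj, type = "class")
mean(train_pred_oj != train_oj$Purchase)  # training misclassification rate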
(c) Type in the name of the tree object in order to get a detailed text output. Pick one of the terminal nodes, and interpret the information displayed.
tree.OJ
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 800 1068.00 CH ( 0.61250 0.38750 )
## 2) LoyalCH < 0.5036 359 422.80 MM ( 0.27577 0.72423 )
## 4) LoyalCH < 0.280875 172 127.60 MM ( 0.12209 0.87791 )
## 8) LoyalCH < 0.035047 56 10.03 MM ( 0.01786 0.98214 ) *
## 9) LoyalCH > 0.035047 116 106.60 MM ( 0.17241 0.82759 ) *
## 5) LoyalCH > 0.280875 187 254.10 MM ( 0.41711 0.58289 )
## 10) PriceDiff < 0.05 73 71.36 MM ( 0.19178 0.80822 ) *
## 11) PriceDiff > 0.05 114 156.30 CH ( 0.56140 0.43860 ) *
## 3) LoyalCH > 0.5036 441 311.80 CH ( 0.88662 0.11338 )
## 6) LoyalCH < 0.737888 168 191.10 CH ( 0.74405 0.25595 )
## 12) PriceDiff < 0.265 93 125.00 CH ( 0.60215 0.39785 )
## 24) PriceDiff < -0.35 12 10.81 MM ( 0.16667 0.83333 ) *
## 25) PriceDiff > -0.35 81 103.10 CH ( 0.66667 0.33333 ) *
## 13) PriceDiff > 0.265 75 41.82 CH ( 0.92000 0.08000 ) *
## 7) LoyalCH > 0.737888 273 65.11 CH ( 0.97436 0.02564 )
## 14) PriceDiff < -0.39 11 12.89 CH ( 0.72727 0.27273 ) *
## 15) PriceDiff > -0.39 262 41.40 CH ( 0.98473 0.01527 ) *
A) The fitted classification tree has 9 terminal nodes (indicated by "*" in the output). Choosing terminal node 9, the branch that leads to it is:
1) root 800 1068.00 CH ( 0.61250 0.38750 )
2) LoyalCH < 0.5036 359 422.80 MM ( 0.27577 0.72423 )
4) LoyalCH < 0.280875 172 127.60 MM ( 0.12209 0.87791 )
9) LoyalCH > 0.035047 116 106.60 MM ( 0.17241 0.82759 ) *
The root node, 1) root 800 1068.00 CH ( 0.61250 0.38750 ), contains all 800 observations with a deviance of 1068; the overall prediction is CH, with class proportions of 61.25% CH and 38.75% MM.
The root then splits at LoyalCH = 0.5036, and its left branch splits again at LoyalCH = 0.280875 and LoyalCH = 0.035047. Node 9 is therefore the subset of purchases with 0.035047 < LoyalCH < 0.280875; it contains 116 observations, the predicted class is MM, and the class proportions are 17.24% CH and 82.76% MM.
(d) Create a plot of the tree, and interpret the results.
plot(tree.OJ)
text(tree.OJ,pretty=0)
A) LoyalCH (customer brand loyalty for Citrus Hill) is the most important variable, followed by PriceDiff. The left branch contains customers who are less loyal to the CH brand and the right branch those who are more loyal to CH; the remaining splits in both branches are driven by PriceDiff.
(e) Predict the response on the test data, and produce a confusion matrix comparing the test labels to the predicted test labels. What is the test error rate?
pred.OJ = predict(tree.OJ, newdata = test_oj, type = "class")
table(pred.OJ,test_oj$Purchase)
##
## pred.OJ CH MM
## CH 148 37
## MM 15 70
\[ Test\ Misclassification\ Rate = \frac{15+37}{270} = 19.26\%\ (higher\ than\ the\ training\ error\ rate) \]
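The same rate can be computed from the confusion matrix without hand arithmetic; a short sketch, assuming the pred.OJ predictions from above:
conf_oj <- table(pred.OJ, test_oj$Purchase)
1 - sum(diag(conf_oj)) / sum(conf_oj)  # test misclassification rate, about 0.1926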
(f) Apply the cv.tree() function to the training set in order to determine the optimal tree size.
set.seed(2)
cv_oj = cv.tree(tree.OJ,FUN=prune.misclass)
cv_oj
## $size
## [1] 9 7 4 2 1
##
## $dev
## [1] 144 144 141 154 310
##
## $k
## [1] -Inf 0.000000 2.666667 7.000000 161.000000
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
A) The tree with 4 terminal nodes has the lowest CV error.
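The optimal size can also be pulled out programmatically; a one-line sketch using the cv_oj object above:
cv_oj$size[which.min(cv_oj$dev)]  # returns 4, the size with the lowest CV deviance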
(g) Produce a plot with tree size on the x-axis and cross-validated classification error rate on the y-axis.
data.frame(terminals = cv_oj$size, cverror = cv_oj$dev) %>%
  ggplot(aes(x = terminals, y = cverror)) +
  geom_line() +
  geom_point(size = 2) +
  geom_point(size = 4, aes(x = 4, y = 141), col = "blue") +  # highlight the CV minimum
  scale_x_continuous(breaks = seq(1, 9, 1)) +
  labs(title = "Cross-validation with pruning on misclassification rate",
       y = "CV Error", x = "Terminal Nodes")
(h) Which tree size corresponds to the lowest cross-validated classification error rate? A) The tree with 4 terminal nodes has the lowest cross-validated classification error rate.
(i) Produce a pruned tree corresponding to the optimal tree size obtained using cross-validation. If cross-validation does not lead to selection of a pruned tree, then create a pruned tree with five terminal nodes.
pruned_tree_oj <- prune.tree(tree.OJ, best = 4)
pruned_tree_oj
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 800 1068.00 CH ( 0.61250 0.38750 )
## 2) LoyalCH < 0.5036 359 422.80 MM ( 0.27577 0.72423 )
## 4) LoyalCH < 0.280875 172 127.60 MM ( 0.12209 0.87791 ) *
## 5) LoyalCH > 0.280875 187 254.10 MM ( 0.41711 0.58289 ) *
## 3) LoyalCH > 0.5036 441 311.80 CH ( 0.88662 0.11338 )
## 6) LoyalCH < 0.737888 168 191.10 CH ( 0.74405 0.25595 ) *
## 7) LoyalCH > 0.737888 273 65.11 CH ( 0.97436 0.02564 ) *
(j) Compare the training error rates between the pruned and unpruned trees. Which is higher?
summary(pruned_tree_oj)
##
## Classification tree:
## snip.tree(tree = tree.OJ, nodes = c(7L, 4L, 6L, 5L))
## Variables actually used in tree construction:
## [1] "LoyalCH"
## Number of terminal nodes: 4
## Residual mean deviance: 0.8014 = 637.9 / 796
## Misclassification error rate: 0.1862 = 149 / 800
\(Training\ Misclassification\ Error\ Rate,\ pruned\ tree = 18.62\%\); \(Training\ Misclassification\ Error\ Rate,\ unpruned\ tree = 15.88\%\)
The training error rate of the pruned tree is higher than that of the unpruned tree.
(k) Compare the test error rates between the pruned and unpruned trees. Which is higher?
pred_prune = predict(pruned_tree_oj, newdata = test_oj, type = "class")
table(pred_prune,test_oj$Purchase)
##
## pred_prune CH MM
## CH 129 31
## MM 34 76
\[ Test\ Misclassification\ Rate,\ Unpruned\ tree = \frac{15+37}{270} = 19.26\% \] \[ Test\ Misclassification\ Rate,\ Pruned\ tree = \frac{34+31}{270} = 24.07\%\ (higher\ than\ the\ unpruned\ tree) \] The test error is higher for the pruned tree; this may be because the sample used for training is small.
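For completeness, both test error rates can be recomputed directly rather than by hand; a short sketch, assuming the pred.OJ and pred_prune predictions from above:
c(unpruned = mean(pred.OJ != test_oj$Purchase),
  pruned   = mean(pred_prune != test_oj$Purchase))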