library(ggplot2)

# Two-class node impurity measures as a function of the class-1 proportion
pm1 = seq(0, 1, 0.01)
pm2 = 1 - pm1
class_error = 1 - pmax(pm1, pm2)
gini = pm1*(1 - pm1) + pm2*(1 - pm2)
entropy = -pm1*log(pm1) - pm2*log(pm2)
df_proportion = data.frame(pm1,pm2,class_error,gini,entropy)
ggplot(data = df_proportion) +
geom_line(aes(x = pm1, y = class_error, col = 'Classification Error')) +
geom_line(aes(x = pm1, y = gini, col = 'Gini Index')) +
geom_line(aes(x = pm1, y = entropy, col = 'Entropy')) +
labs(y = 'Function Value', col = 'Function') +
theme_minimal()
## Warning: Removed 2 rows containing missing values or values outside the scale range
## (`geom_line()`).
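The warning is expected here: at pm1 = 0 and pm1 = 1 the entropy expression involves log(0), which evaluates to NaN, so ggplot drops those two endpoint rows. One way to avoid it (a minimal sketch, not part of the original figure) is to apply the convention 0 * log(0) = 0:
# Entropy with the 0 * log(0) = 0 convention, so the endpoints are defined
xlogx = function(p) ifelse(p > 0, p * log(p), 0)
entropy_safe = -xlogx(pm1) - xlogx(pm2)
range(entropy_safe)  # finite over the whole [0, 1] grid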
library(tree)

OJ = ISLR2::OJ
set.seed(1)
# Split the 1,070 observations into a training set of 800 and a test set of 270
index = sample(1:nrow(OJ), 800)
OJ_train = OJ[index,]
OJ_test = OJ[-index,]
oj_tree = tree(Purchase ~ ., data = OJ_train)
summary(oj_tree)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = OJ_train)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "SpecialCH" "ListPriceDiff"
## [5] "PctDiscMM"
## Number of terminal nodes: 9
## Residual mean deviance: 0.7432 = 587.8 / 791
## Misclassification error rate: 0.1588 = 127 / 800
The training error rate is 15.88%, the tree has 9 terminal nodes, and of the 17 predictors in the data only 5 were used to construct the tree: LoyalCH, PriceDiff, SpecialCH, ListPriceDiff, and PctDiscMM.
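If we want these programmatically rather than reading them off the summary, they can be recovered from the tree's frame (a small sketch; it assumes the tree package's convention of marking leaf rows with "<leaf>" in the var column):
# Split variables actually used in the fitted tree
used_vars = setdiff(unique(as.character(oj_tree$frame$var)), "<leaf>")
used_vars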
oj_tree
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 800 1073.00 CH ( 0.60625 0.39375 )
## 2) LoyalCH < 0.5036 365 441.60 MM ( 0.29315 0.70685 )
## 4) LoyalCH < 0.280875 177 140.50 MM ( 0.13559 0.86441 )
## 8) LoyalCH < 0.0356415 59 10.14 MM ( 0.01695 0.98305 ) *
## 9) LoyalCH > 0.0356415 118 116.40 MM ( 0.19492 0.80508 ) *
## 5) LoyalCH > 0.280875 188 258.00 MM ( 0.44149 0.55851 )
## 10) PriceDiff < 0.05 79 84.79 MM ( 0.22785 0.77215 )
## 20) SpecialCH < 0.5 64 51.98 MM ( 0.14062 0.85938 ) *
## 21) SpecialCH > 0.5 15 20.19 CH ( 0.60000 0.40000 ) *
## 11) PriceDiff > 0.05 109 147.00 CH ( 0.59633 0.40367 ) *
## 3) LoyalCH > 0.5036 435 337.90 CH ( 0.86897 0.13103 )
## 6) LoyalCH < 0.764572 174 201.00 CH ( 0.73563 0.26437 )
## 12) ListPriceDiff < 0.235 72 99.81 MM ( 0.50000 0.50000 )
## 24) PctDiscMM < 0.196196 55 73.14 CH ( 0.61818 0.38182 ) *
## 25) PctDiscMM > 0.196196 17 12.32 MM ( 0.11765 0.88235 ) *
## 13) ListPriceDiff > 0.235 102 65.43 CH ( 0.90196 0.09804 ) *
## 7) LoyalCH > 0.764572 261 91.20 CH ( 0.95785 0.04215 ) *
Node 8 is a terminal node, as indicated by the trailing asterisk on the line 8) LoyalCH < 0.0356415 59 10.14 MM ( 0.01695 0.98305 ) *. 59 observations fall into this node and its deviance is 10.14. When LoyalCH < 0.0356, the model predicts a Minute Maid purchase with an estimated probability of 98.3%, versus 1.7% for Citrus Hill, so the node is nearly pure.
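As a quick sanity check, for a classification tree the node deviance is -2 * sum over classes of n_k * log(n_k / n). A short sketch reproducing the 10.14 from the printed counts (1 CH and 58 MM purchases out of 59):
# Node 8 deviance: -2 * (n_CH * log(p_CH) + n_MM * log(p_MM))
n_ch = 1; n_mm = 58; n = n_ch + n_mm
-2 * (n_ch * log(n_ch / n) + n_mm * log(n_mm / n))  # approximately 10.14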
plot(oj_tree)
text(oj_tree, pretty = 0, cex = 0.6)
Customers with very low LoyalCH scores are predicted to purchase Minute Maid. In the moderate-loyalty branch a split occurs on PriceDiff: when PriceDiff > 0.05 these customers are predicted to purchase Citrus Hill, while for PriceDiff < 0.05 the tree splits again on SpecialCH, with the SpecialCH >= 0.5 node leaning toward Citrus Hill (about 60/40) and the rest predicted to purchase Minute Maid. On the other side of the tree, customers with a high LoyalCH score are predicted to purchase Citrus Hill, as expected. Within that branch a split occurs on ListPriceDiff: when ListPriceDiff >= 0.235 loyal customers stay with Citrus Hill, and when it is smaller the tree splits again on PctDiscMM, where even loyal customers are predicted to purchase Minute Maid if its percentage discount is large enough.
preds = predict(oj_tree, OJ_test, type = 'class')
table(preds, test_actual = OJ_test$Purchase)
## test_actual
## preds CH MM
## CH 160 38
## MM 8 64
mean(preds == OJ_test$Purchase)
## [1] 0.8296296
1 - mean(preds == OJ_test$Purchase)
## [1] 0.1703704
We get a test error rate of 17.04%, or 82.96% accuracy. Of the 270 test observations we misclassify 46, and the majority of these errors (38) are observations predicted as Citrus Hill that were actually Minute Maid. This could be partly due to the imbalance between Citrus Hill and Minute Maid in the data.
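To check the imbalance hypothesis, a quick look at the class counts in the training set (a sketch using the objects defined above; the root node of the tree already showed roughly a 61% / 39% CH / MM split):
# Class balance in the training data
table(OJ_train$Purchase)
prop.table(table(OJ_train$Purchase))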
set.seed(1)
oj_cv = cv.tree(oj_tree, K = 10, FUN = prune.misclass)
oj_cv
## $size
## [1] 9 8 7 4 2 1
##
## $dev
## [1] 145 145 146 146 167 315
##
## $k
## [1] -Inf 0.000000 3.000000 4.333333 10.500000 151.000000
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
plot(oj_cv$size, oj_cv$dev/ nrow(OJ_train), type = 'b',
xlab = 'Tree Size', ylab = 'Error Rate')
Tree sizes 8 and 9 give the lowest cross-validated misclassification count of 145 (an error rate of roughly 145/800 = 0.18), but I will choose tree size 8 for simplicity.
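The same choice can be made programmatically instead of reading it off the plot (a sketch; it picks the smallest size among those tied for the minimum CV deviance):
# Smallest tree size achieving the minimum cross-validated misclassification count
best_size = min(oj_cv$size[oj_cv$dev == min(oj_cv$dev)])
best_size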
oj_pruned = prune.tree(oj_tree, best = 8)
oj_pruned
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 800 1073.00 CH ( 0.60625 0.39375 )
## 2) LoyalCH < 0.5036 365 441.60 MM ( 0.29315 0.70685 )
## 4) LoyalCH < 0.280875 177 140.50 MM ( 0.13559 0.86441 )
## 8) LoyalCH < 0.0356415 59 10.14 MM ( 0.01695 0.98305 ) *
## 9) LoyalCH > 0.0356415 118 116.40 MM ( 0.19492 0.80508 ) *
## 5) LoyalCH > 0.280875 188 258.00 MM ( 0.44149 0.55851 )
## 10) PriceDiff < 0.05 79 84.79 MM ( 0.22785 0.77215 ) *
## 11) PriceDiff > 0.05 109 147.00 CH ( 0.59633 0.40367 ) *
## 3) LoyalCH > 0.5036 435 337.90 CH ( 0.86897 0.13103 )
## 6) LoyalCH < 0.764572 174 201.00 CH ( 0.73563 0.26437 )
## 12) ListPriceDiff < 0.235 72 99.81 MM ( 0.50000 0.50000 )
## 24) PctDiscMM < 0.196196 55 73.14 CH ( 0.61818 0.38182 ) *
## 25) PctDiscMM > 0.196196 17 12.32 MM ( 0.11765 0.88235 ) *
## 13) ListPriceDiff > 0.235 102 65.43 CH ( 0.90196 0.09804 ) *
## 7) LoyalCH > 0.764572 261 91.20 CH ( 0.95785 0.04215 ) *
Unpruned:
mean(predict(oj_tree, type = 'class') != OJ_train$Purchase)
## [1] 0.15875
Pruned:
mean(predict(oj_pruned, type = 'class') != OJ_train$Purchase)
## [1] 0.1625
The pruned tree has a slightly higher training error rate of 0.1625 compared to the unpruned rate of 0.1588. This makes sense: pruning removes a split, reducing the flexibility (and variance) of the model, so the fit to the training data cannot improve.
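To confirm the pruned tree really is smaller, the terminal-node counts of the two trees can be compared (a sketch; it counts the leaf rows in each tree's frame):
# Terminal-node counts: 9 for the unpruned tree, 8 for the pruned tree
sum(oj_tree$frame$var == "<leaf>")
sum(oj_pruned$frame$var == "<leaf>")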
Unpruned:
mean(predict(oj_tree, type = 'class', newdata = OJ_test) != OJ_test$Purchase)
## [1] 0.1703704
Pruned:
mean(predict(oj_pruned, type = 'class', newdata = OJ_test) != OJ_test$Purchase)
## [1] 0.162963
Both test error rates are at or above the corresponding training error rates, but the test error for the pruned tree (0.163) is lower than for the unpruned tree (0.170). The difference is small, and with a different random seed the unpruned tree could just as easily come out ahead, but with more data the gap might become more pronounced.
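Collecting the four error rates in one place makes the comparison easier to see (a sketch reusing the fitted objects above; the helper err() is just for illustration):
# Training and test misclassification rates for both trees
err = function(fit, data) mean(predict(fit, newdata = data, type = 'class') != data$Purchase)
data.frame(
  tree  = c('unpruned', 'pruned'),
  train = c(err(oj_tree, OJ_train), err(oj_pruned, OJ_train)),
  test  = c(err(oj_tree, OJ_test),  err(oj_pruned, OJ_test))
)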