Hint: In a setting with two classes, $\hat{p}_{m1} = 1 - \hat{p}_{m2}$. You could make this plot by hand, but it will be much easier to make in R.
pm1 = seq(0, 1, 0.01)
gini.index = 2*pm1*(1-pm1)
c.error = 1 - pmax(pm1, 1-pm1)
# cross-entropy is NaN at pm1 = 0 and 1 (0*log(0)); lines() simply leaves those two points out
crossentropy = -(pm1*log(pm1) + (1-pm1)*log(1-pm1))
plot(NA, NA, xlim=c(0,1), ylim=c(0,1), xlab='pm1', ylab='f')
# 'lty' (not 'ls') sets the line type; solid lines match the legend below
lines(pm1, gini.index, lty=1)
lines(pm1, c.error, col='blue', lty=1)
lines(pm1, crossentropy, col='red', lty=1)
legend(x='top', legend=c('gini.index','classification error','cross entropy'),
       col=c('black','blue','red'), lty=1, text.width=0.22)
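A minimal sketch of the three impurity measures as one function (the name `impurity` is hypothetical), using the usual convention 0·log(0) = 0 so that the cross-entropy is defined at the endpoints instead of returning NaN:

```r
# Node impurity measures for a two-class node, as functions of p = p_m1.
impurity <- function(p) {
  # x*log(x) with the convention 0*log(0) = 0
  xlogx <- function(x) ifelse(x > 0, x * log(x), 0)
  data.frame(
    p             = p,
    gini          = 2 * p * (1 - p),
    class.error   = 1 - pmax(p, 1 - p),
    cross.entropy = -(xlogx(p) + xlogx(1 - p))
  )
}

impurity(c(0, 0.5, 1))
```

At $\hat{p}_{m1} = 0.5$ all three curves peak: the Gini index and the classification error both equal 0.5, and the cross-entropy equals log(2) ≈ 0.693.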
library(ISLR)  # provides the OJ data set
data <- OJ
str(data)
## 'data.frame': 1070 obs. of 18 variables:
## $ Purchase : Factor w/ 2 levels "CH","MM": 1 1 1 2 1 1 1 1 1 1 ...
## $ WeekofPurchase: num 237 239 245 227 228 230 232 234 235 238 ...
## $ StoreID : num 1 1 1 1 7 7 7 7 7 7 ...
## $ PriceCH : num 1.75 1.75 1.86 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
## $ PriceMM : num 1.99 1.99 2.09 1.69 1.69 1.99 1.99 1.99 1.99 1.99 ...
## $ DiscCH : num 0 0 0.17 0 0 0 0 0 0 0 ...
## $ DiscMM : num 0 0.3 0 0 0 0 0.4 0.4 0.4 0.4 ...
## $ SpecialCH : num 0 0 0 0 0 0 1 1 0 0 ...
## $ SpecialMM : num 0 1 0 0 0 1 1 0 0 0 ...
## $ LoyalCH : num 0.5 0.6 0.68 0.4 0.957 ...
## $ SalePriceMM : num 1.99 1.69 2.09 1.69 1.69 1.99 1.59 1.59 1.59 1.59 ...
## $ SalePriceCH : num 1.75 1.75 1.69 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
## $ PriceDiff : num 0.24 -0.06 0.4 0 0 0.3 -0.1 -0.16 -0.16 -0.16 ...
## $ Store7 : Factor w/ 2 levels "No","Yes": 1 1 1 1 2 2 2 2 2 2 ...
## $ PctDiscMM : num 0 0.151 0 0 0 ...
## $ PctDiscCH : num 0 0 0.0914 0 0 ...
## $ ListPriceDiff : num 0.24 0.24 0.23 0 0 0.3 0.3 0.24 0.24 0.24 ...
## $ STORE : num 1 1 1 1 0 0 0 0 0 0 ...
set.seed(123)
index = sample(nrow(data),800)
train.set = data[index,]
test.set = data[-index,]
library(rpart)
## fit the tree
oj.fit=rpart(Purchase ~., data=train.set,
method= "class",
control=rpart.control(minsplit=15, cp=0.01))
summary(oj.fit)
## Call:
## rpart(formula = Purchase ~ ., data = train.set, method = "class",
## control = rpart.control(minsplit = 15, cp = 0.01))
## n= 800
##
## CP nsplit rel error xerror xstd
## 1 0.49201278 0 1.0000000 1.0000000 0.04410089
## 2 0.03514377 1 0.5079872 0.5335463 0.03672576
## 3 0.02555911 2 0.4728435 0.5335463 0.03672576
## 4 0.01277955 4 0.4217252 0.4504792 0.03443205
## 5 0.01000000 7 0.3833866 0.4728435 0.03508854
##
## Variable importance
## LoyalCH StoreID PriceDiff SalePriceMM WeekofPurchase
## 45 9 9 6 6
## PriceMM DiscMM PctDiscMM PriceCH ListPriceDiff
## 5 5 4 4 3
## SalePriceCH STORE SpecialCH
## 2 1 1
##
## Node number 1: 800 observations, complexity param=0.4920128
## predicted class=CH expected loss=0.39125 P(node) =1
## class counts: 487 313
## probabilities: 0.609 0.391
## left son=2 (450 obs) right son=3 (350 obs)
## Primary splits:
## LoyalCH < 0.5036 to the right, improve=134.49530, (0 missing)
## StoreID < 3.5 to the right, improve= 40.88655, (0 missing)
## STORE < 0.5 to the left, improve= 20.84871, (0 missing)
## Store7 splits as RL, improve= 20.84871, (0 missing)
## PriceDiff < 0.015 to the right, improve= 19.14298, (0 missing)
## Surrogate splits:
## StoreID < 3.5 to the right, agree=0.660, adj=0.223, (0 split)
## WeekofPurchase < 246.5 to the right, agree=0.625, adj=0.143, (0 split)
## PriceCH < 1.825 to the right, agree=0.600, adj=0.086, (0 split)
## PriceMM < 1.89 to the right, agree=0.596, adj=0.077, (0 split)
## ListPriceDiff < 0.035 to the right, agree=0.581, adj=0.043, (0 split)
##
## Node number 2: 450 observations, complexity param=0.03514377
## predicted class=CH expected loss=0.1355556 P(node) =0.5625
## class counts: 389 61
## probabilities: 0.864 0.136
## left son=4 (423 obs) right son=5 (27 obs)
## Primary splits:
## PriceDiff < -0.39 to the right, improve=18.543390, (0 missing)
## DiscMM < 0.72 to the left, improve= 9.309254, (0 missing)
## SalePriceMM < 1.435 to the right, improve= 9.309254, (0 missing)
## PctDiscMM < 0.3342595 to the left, improve= 9.309254, (0 missing)
## LoyalCH < 0.7645725 to the right, improve= 8.822549, (0 missing)
## Surrogate splits:
## DiscMM < 0.72 to the left, agree=0.967, adj=0.444, (0 split)
## SalePriceMM < 1.435 to the right, agree=0.967, adj=0.444, (0 split)
## PctDiscMM < 0.3342595 to the left, agree=0.967, adj=0.444, (0 split)
## SalePriceCH < 2.075 to the left, agree=0.949, adj=0.148, (0 split)
##
## Node number 3: 350 observations, complexity param=0.02555911
## predicted class=MM expected loss=0.28 P(node) =0.4375
## class counts: 98 252
## probabilities: 0.280 0.720
## left son=6 (180 obs) right son=7 (170 obs)
## Primary splits:
## LoyalCH < 0.2761415 to the right, improve=14.991900, (0 missing)
## StoreID < 3.5 to the right, improve= 6.562913, (0 missing)
## Store7 splits as RL, improve= 4.617311, (0 missing)
## STORE < 0.5 to the left, improve= 4.617311, (0 missing)
## SpecialCH < 0.5 to the right, improve= 4.512108, (0 missing)
## Surrogate splits:
## STORE < 1.5 to the left, agree=0.629, adj=0.235, (0 split)
## StoreID < 1.5 to the left, agree=0.589, adj=0.153, (0 split)
## PriceCH < 1.875 to the left, agree=0.589, adj=0.153, (0 split)
## SalePriceCH < 1.875 to the left, agree=0.586, adj=0.147, (0 split)
## SalePriceMM < 1.84 to the left, agree=0.571, adj=0.118, (0 split)
##
## Node number 4: 423 observations
## predicted class=CH expected loss=0.09929078 P(node) =0.52875
## class counts: 381 42
## probabilities: 0.901 0.099
##
## Node number 5: 27 observations
## predicted class=MM expected loss=0.2962963 P(node) =0.03375
## class counts: 8 19
## probabilities: 0.296 0.704
##
## Node number 6: 180 observations, complexity param=0.02555911
## predicted class=MM expected loss=0.4222222 P(node) =0.225
## class counts: 76 104
## probabilities: 0.422 0.578
## left son=12 (106 obs) right son=13 (74 obs)
## Primary splits:
## PriceDiff < 0.05 to the right, improve=12.110850, (0 missing)
## SalePriceMM < 2.04 to the right, improve=11.572070, (0 missing)
## DiscMM < 0.25 to the left, improve= 5.760121, (0 missing)
## PctDiscMM < 0.1345485 to the left, improve= 5.760121, (0 missing)
## ListPriceDiff < 0.18 to the right, improve= 5.597236, (0 missing)
## Surrogate splits:
## SalePriceMM < 1.94 to the right, agree=0.933, adj=0.838, (0 split)
## DiscMM < 0.08 to the left, agree=0.822, adj=0.568, (0 split)
## PctDiscMM < 0.038887 to the left, agree=0.822, adj=0.568, (0 split)
## ListPriceDiff < 0.135 to the right, agree=0.800, adj=0.514, (0 split)
## PriceMM < 2.04 to the right, agree=0.783, adj=0.473, (0 split)
##
## Node number 7: 170 observations
## predicted class=MM expected loss=0.1294118 P(node) =0.2125
## class counts: 22 148
## probabilities: 0.129 0.871
##
## Node number 12: 106 observations, complexity param=0.01277955
## predicted class=CH expected loss=0.4245283 P(node) =0.1325
## class counts: 61 45
## probabilities: 0.575 0.425
## left son=24 (8 obs) right son=25 (98 obs)
## Primary splits:
## LoyalCH < 0.3084325 to the left, improve=3.118983, (0 missing)
## WeekofPurchase < 247.5 to the right, improve=2.489639, (0 missing)
## SpecialMM < 0.5 to the left, improve=2.454538, (0 missing)
## PriceCH < 1.755 to the right, improve=2.048863, (0 missing)
## PriceMM < 2.04 to the right, improve=1.514675, (0 missing)
##
## Node number 13: 74 observations
## predicted class=MM expected loss=0.2027027 P(node) =0.0925
## class counts: 15 59
## probabilities: 0.203 0.797
##
## Node number 24: 8 observations
## predicted class=CH expected loss=0 P(node) =0.01
## class counts: 8 0
## probabilities: 1.000 0.000
##
## Node number 25: 98 observations, complexity param=0.01277955
## predicted class=CH expected loss=0.4591837 P(node) =0.1225
## class counts: 53 45
## probabilities: 0.541 0.459
## left son=50 (46 obs) right son=51 (52 obs)
## Primary splits:
## LoyalCH < 0.442144 to the right, improve=3.071463, (0 missing)
## WeekofPurchase < 248.5 to the right, improve=2.208454, (0 missing)
## SpecialMM < 0.5 to the left, improve=2.011796, (0 missing)
## STORE < 0.5 to the left, improve=1.624324, (0 missing)
## StoreID < 5.5 to the right, improve=1.624324, (0 missing)
## Surrogate splits:
## WeekofPurchase < 255 to the left, agree=0.622, adj=0.196, (0 split)
## SalePriceCH < 1.755 to the right, agree=0.571, adj=0.087, (0 split)
## STORE < 2.5 to the right, agree=0.571, adj=0.087, (0 split)
## PriceMM < 2.205 to the right, agree=0.561, adj=0.065, (0 split)
## DiscCH < 0.115 to the left, agree=0.561, adj=0.065, (0 split)
##
## Node number 50: 46 observations
## predicted class=CH expected loss=0.326087 P(node) =0.0575
## class counts: 31 15
## probabilities: 0.674 0.326
##
## Node number 51: 52 observations, complexity param=0.01277955
## predicted class=MM expected loss=0.4230769 P(node) =0.065
## class counts: 22 30
## probabilities: 0.423 0.577
## left son=102 (8 obs) right son=103 (44 obs)
## Primary splits:
## SpecialCH < 0.5 to the right, improve=2.020979, (0 missing)
## STORE < 1.5 to the left, improve=1.724009, (0 missing)
## SpecialMM < 0.5 to the left, improve=1.680070, (0 missing)
## WeekofPurchase < 245 to the right, improve=1.384615, (0 missing)
## StoreID < 5.5 to the right, improve=1.319751, (0 missing)
## Surrogate splits:
## DiscCH < 0.27 to the right, agree=0.942, adj=0.625, (0 split)
## SalePriceCH < 1.54 to the left, agree=0.942, adj=0.625, (0 split)
## PctDiscCH < 0.149059 to the right, agree=0.942, adj=0.625, (0 split)
## SalePriceMM < 1.64 to the left, agree=0.923, adj=0.500, (0 split)
## DiscMM < 0.42 to the right, agree=0.904, adj=0.375, (0 split)
##
## Node number 102: 8 observations
## predicted class=CH expected loss=0.25 P(node) =0.01
## class counts: 6 2
## probabilities: 0.750 0.250
##
## Node number 103: 44 observations
## predicted class=MM expected loss=0.3636364 P(node) =0.055
## class counts: 16 28
## probabilities: 0.364 0.636
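The lowest relative error, 0.3833866, is reached at cp = 0.01 with 7 splits. The `rel error` column is scaled by the error of the root node (313/800 = 0.39125), so the actual training error rate is 0.3833866 × 0.39125 ≈ 0.15, i.e. 120 of the 800 training observations are misclassified. The first split is on LoyalCH, which also has by far the largest variable importance (45), so it is the most important variable in the data set. The tree has 8 terminal nodes (7 splits plus one).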
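The conversion from rpart's relative error to an absolute training error rate can be checked by hand. A small sketch using only numbers copied from the output above:

```r
# Numbers copied from the rpart output above.
root.error <- 313 / 800     # root node predicts CH for everyone, so the 313 MM are errors
rel.error  <- 0.3833866     # "rel error" of the final tree (nsplit = 7)
rel.error * root.error      # absolute training error rate, about 0.15

# The same figure from the terminal nodes: sum of the misclassified counts (loss) per leaf.
leaf.loss <- c(42, 8, 0, 15, 2, 16, 15, 22)   # nodes 4, 5, 24, 50, 102, 103, 13, 7
sum(leaf.loss) / 800        # 120 / 800
length(leaf.loss)           # number of terminal nodes = nsplit + 1
```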
oj.fit
## n= 800
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 800 313 CH (0.60875000 0.39125000)
## 2) LoyalCH>=0.5036 450 61 CH (0.86444444 0.13555556)
## 4) PriceDiff>=-0.39 423 42 CH (0.90070922 0.09929078) *
## 5) PriceDiff< -0.39 27 8 MM (0.29629630 0.70370370) *
## 3) LoyalCH< 0.5036 350 98 MM (0.28000000 0.72000000)
## 6) LoyalCH>=0.2761415 180 76 MM (0.42222222 0.57777778)
## 12) PriceDiff>=0.05 106 45 CH (0.57547170 0.42452830)
## 24) LoyalCH< 0.3084325 8 0 CH (1.00000000 0.00000000) *
## 25) LoyalCH>=0.3084325 98 45 CH (0.54081633 0.45918367)
## 50) LoyalCH>=0.442144 46 15 CH (0.67391304 0.32608696) *
## 51) LoyalCH< 0.442144 52 22 MM (0.42307692 0.57692308)
## 102) SpecialCH>=0.5 8 2 CH (0.75000000 0.25000000) *
## 103) SpecialCH< 0.5 44 16 MM (0.36363636 0.63636364) *
## 13) PriceDiff< 0.05 74 15 MM (0.20270270 0.79729730) *
## 7) LoyalCH< 0.2761415 170 22 MM (0.12941176 0.87058824) *
Terminal nodes are marked with an asterisk, and the tree has 8 of them. Each line shows the split criterion that defines the node, followed by the number of observations in the node, the number of those observations that are misclassified (loss), the predicted class (yval) and the class proportions (CH, MM). For example, node 2 is defined by LoyalCH>=0.5036, and node 4 by PriceDiff>=-0.39 within node 2. Node 4 is a terminal node: it contains 423 training observations, 42 of which are misclassified; it predicts CH, since 90.1% of its observations purchased CH.
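To count terminal nodes programmatically rather than counting asterisks, one option is the `frame` component of an rpart fit: leaf rows have `var == "<leaf>"` and the row names are the node numbers printed above. The sketch below uses the `kyphosis` data shipped with rpart purely as a stand-in, so it runs without the OJ data:

```r
library(rpart)

# Stand-in example on the kyphosis data (not the OJ data).
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis, method = "class")

# Terminal nodes are the rows of fit$frame whose split variable is "<leaf>";
# n is the number of observations in the node, yval the (coded) predicted class.
leaves <- fit$frame[fit$frame$var == "<leaf>", c("n", "yval")]
leaves
nrow(leaves)   # number of terminal nodes
```

Applied to the tree above, `sum(oj.fit$frame$var == "<leaf>")` counts its terminal nodes the same way.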
plotcp(oj.fit)
The plot gives a graphical representation of the cross-validated error summary (the CP table above). Among the candidate CP values we can select the one with the lowest cross-validated error and use it for pruning. Here the lowest xerror (0.4504792) occurs at cp = 0.01277955 with 4 splits, i.e. a tree with 5 terminal nodes.
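Selecting the CP with the lowest cross-validated error and pruning to it can be scripted with rpart's `prune()`. A sketch, again on `kyphosis` as a stand-in data set; with the OJ tree, `oj.fit` would take the place of `fit`:

```r
library(rpart)

set.seed(123)   # xerror comes from random cross-validation folds
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis, method = "class",
             control = rpart.control(cp = 0.001))

cp.tab  <- fit$cptable
best    <- which.min(cp.tab[, "xerror"])   # row with the lowest cross-validated error
best.cp <- cp.tab[best, "CP"]
pruned  <- prune(fit, cp = best.cp)        # drop splits whose complexity is below this CP

cp.tab[best, ]
sum(pruned$frame$var == "<leaf>")          # terminal nodes of the pruned tree
```

Because `xerror` depends on the random folds, the selected CP can change with the seed.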
pred.data <- predict(oj.fit, test.set)   # class probabilities, one column per class
y.hat <- ifelse(pred.data[, "CH"] >= 0.5, 'CH', 'MM') # 0.5 is used as the cut-off; an optimal cut-off could also be chosen and the confusion matrix recomputed accordingly
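One way to choose the cut-off mentioned in the comment is to scan a grid of thresholds and keep the one with the highest accuracy. A sketch with a hypothetical helper `best_cutoff`; accuracy is only one possible criterion, and the cut-off should be tuned on training or validation data rather than on `test.set`, otherwise the reported test error becomes optimistic:

```r
# Hypothetical helper (not part of rpart or caret): choose the probability cut-off
# that maximises classification accuracy.
#   prob     - predicted probability of the positive class
#   truth    - observed classes (factor or character)
#   positive - label predicted when prob >= cut-off; negative - label otherwise
best_cutoff <- function(prob, truth, positive, negative,
                        grid = seq(0.05, 0.95, by = 0.05)) {
  acc <- sapply(grid, function(cut) {
    pred <- ifelse(prob >= cut, positive, negative)
    mean(pred == as.character(truth))
  })
  # which.max() returns the first (smallest) cut-off among ties
  list(cutoff = grid[which.max(acc)], accuracy = max(acc))
}

# Usage with the objects above (not run here):
# best_cutoff(predict(oj.fit, train.set)[, "CH"], train.set$Purchase, "CH", "MM")
```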
library(caret)  # for confusionMatrix()
confusionMatrix(as.factor(y.hat), as.factor(test.set$Purchase))
## Confusion Matrix and Statistics
##
## Reference
## Prediction CH MM
## CH 141 24
## MM 25 80
##
## Accuracy : 0.8185
## 95% CI : (0.7673, 0.8626)
## No Information Rate : 0.6148
## P-Value [Acc > NIR] : 3.407e-13
##
## Kappa : 0.6175
##
## Mcnemar's Test P-Value : 1
##
## Sensitivity : 0.8494
## Specificity : 0.7692
## Pos Pred Value : 0.8545
## Neg Pred Value : 0.7619
## Prevalence : 0.6148
## Detection Rate : 0.5222
## Detection Prevalence : 0.6111
## Balanced Accuracy : 0.8093
##
## 'Positive' Class : CH
##
# test error rate
# (FN + FP) / total number of observations in the test data
# FN: falsely predicted negative
# FP: falsely predicted positive
(test.error = (25+24)/nrow(test.set))   # equals 1 - Accuracy
## [1] 0.1814815
library(tree)
## Warning: package 'tree' was built under R version 4.0.5
## Registered S3 method overwritten by 'tree':
## method from
## print.tree cli
set.seed(123)
cv_tree <- tree(Purchase ~., data=train.set)
set.seed(123)
(op <- cv.tree(cv_tree, FUN = prune.tree))
## $size
## [1] 8 7 6 5 4 3 2 1
##
## $dev
## [1] 736.9128 711.0227 703.7038 727.7832 727.7832 753.5331 792.3054
## [8] 1072.5980
##
## $k
## [1] -Inf 12.03823 14.92474 25.76707 26.02613 38.91686 50.61655
## [8] 298.68751
##
## $method
## [1] "deviance"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
plot(op)
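The size with the lowest cross-validated deviance can also be read off the `cv.tree()` result directly instead of from the plot. A sketch using the `size` and `dev` values printed above (with the object itself, `op$size[which.min(op$dev)]` does the same):

```r
# size and dev copied from the cv.tree() output above
size <- c(8, 7, 6, 5, 4, 3, 2, 1)
dev  <- c(736.9128, 711.0227, 703.7038, 727.7832, 727.7832, 753.5331, 792.3054, 1072.5980)

# which.min() returns the first minimum, i.e. the largest size among ties
best.size <- size[which.min(dev)]
best.size   # 6
```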
From the plot we can see that size 6 has the lowest cross-validated deviance (703.7). Deviance is a measure of the error remaining in the tree after construction, so size 6 could be the optimal size.

##### (g) Produce a plot with tree size on the x-axis and cross-validated classification error rate on the y-axis
set.seed(123)
oj.fit2=tree(Purchase~.,data=train.set) # build the tree with the tree package this time (rpart was used earlier)
OJ.tree.cv = cv.tree(oj.fit2,K = 10,FUN = prune.misclass)
plot(OJ.tree.cv)
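Tree size 5 corresponds to the lowest cross-validated classification error rate. Note that with `FUN = prune.misclass` the y-axis of `plot(OJ.tree.cv)` shows the number of cross-validated misclassifications; dividing by the 800 training observations turns it into a rate.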
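To put the error rate itself on the y-axis, as part (g) asks, the counts in `$dev` can be divided by the number of training observations (each observation is held out exactly once in K-fold cross-validation). A sketch with a hypothetical helper `cv_error_rate`:

```r
# Hypothetical helper: convert cross-validated misclassification counts
# (cv.tree(..., FUN = prune.misclass)$dev) into error rates.
cv_error_rate <- function(cv, n) {
  data.frame(size = cv$size, error.rate = cv$dev / n)
}

# Usage with the objects above (not run here):
# rates <- cv_error_rate(OJ.tree.cv, nrow(train.set))
# plot(rates$size, rates$error.rate, type = "b",
#      xlab = "Tree size", ylab = "CV classification error rate")
```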
library(kableExtra) #for kable
##
## Attaching package: 'kableExtra'
## The following object is masked from 'package:dplyr':
##
## group_rows
oj.fit2=prune.misclass(oj.fit2,best = 5)
OJ.pred.train=predict(oj.fit2,train.set,type = 'class')
table(train.set[,'Purchase'],OJ.pred.train)
## OJ.pred.train
## CH MM
## CH 442 45
## MM 87 226
table(train.set[,'Purchase'],OJ.pred.train)/nrow(train.set)
## OJ.pred.train
## CH MM
## CH 0.55250 0.05625
## MM 0.10875 0.28250
OJ.pred.test=predict(oj.fit2,test.set,type = 'class')
table(test.set[,'Purchase'],OJ.pred.test)
## OJ.pred.test
## CH MM
## CH 150 16
## MM 34 70
table(test.set[,'Purchase'],OJ.pred.test)/nrow(test.set)
## OJ.pred.test
## CH MM
## CH 0.55555556 0.05925926
## MM 0.12592593 0.25925926
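The training and test error rates of the pruned tree follow from the two confusion tables as one minus the proportion on the diagonal. A sketch with the counts copied from the output above (`error_rate` is a hypothetical helper):

```r
# Confusion tables copied from the output above (rows = observed, columns = predicted).
train.tab <- matrix(c(442, 87, 45, 226), nrow = 2)
test.tab  <- matrix(c(150, 34, 16, 70), nrow = 2)

# Hypothetical helper: share of observations off the diagonal
error_rate <- function(tab) 1 - sum(diag(tab)) / sum(tab)

error_rate(train.tab)   # 132 / 800
error_rate(test.tab)    # 50 / 270
```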
plot(oj.fit2)
text(oj.fit2)
summary(oj.fit2)
##
## Classification tree:
## snip.tree(tree = oj.fit2, nodes = c(4L, 7L))
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff"
## Number of terminal nodes: 5
## Residual mean deviance: 0.826 = 656.6 / 795
## Misclassification error rate: 0.165 = 132 / 800
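The unpruned tree (the rpart fit with 8 terminal nodes) has a test error rate of 49/270 ≈ 18.1%, while the pruned tree with 5 terminal nodes has a test error rate of 50/270 ≈ 18.5%, so pruning leaves test performance essentially unchanged. On the training data, the unpruned tree's error rate is 120/800 = 15.0% (its relative error of 0.3833866 times the root error of 0.39125), lower than the pruned tree's 132/800 = 16.5%. The unpruned tree fits the training data more closely without a meaningful gain on the test data, which suggests that its extra splits mostly fit noise; the pruned tree is simpler and about as accurate. (The two trees were built with different packages, rpart and tree, so the comparison is approximate.)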