Hint: In a setting with two classes, pm1 = 1 - pm2. You could make this plot by hand, but it will be much easier to make in R.
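For K = 2 classes, writing p for pm1, the three node impurity measures from the chapter reduce to: classification error E = 1 - max(p, 1 - p), Gini index G = 2p(1 - p), and cross-entropy D = -p*log(p) - (1 - p)*log(1 - p). These are the three curves computed below.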
pm <- seq(0, 1, 0.001)
# Classification error: 1 - max(pm, 1 - pm)
class_error <- 1 - pmax(pm, 1 - pm)
# Gini index: 2 * pm * (1 - pm) in the two-class case
gini_index <- 2 * pm * (1 - pm)
# Cross-entropy: NaN at pm = 0 and pm = 1 (0 * log(0)); those points are simply not drawn
cross_entropy <- -(pm * log(pm) + (1 - pm) * log(1 - pm))
matplot(pm, cbind(class_error, gini_index, cross_entropy), type = "l", lty = 1,
        col = c("black", "green", "orange"), xlab = "pm1", ylab = "Split Criterion")
legend("topright", lty = 1, col = c("black", "green", "orange"),
       legend = c("Classification Error Rate", "Gini Index", "Cross-entropy"))
suppressMessages(library(ISLR))   # for the OJ data set
suppressMessages(library(tree))   # for fitting and pruning classification trees
oj = OJ                           # 1070 orange juice purchases: Citrus Hill (CH) vs Minute Maid (MM)
names(oj) = tolower(names(oj))
set.seed(1000)
index = sample(1:nrow(oj), 800)   # random training sample of 800 observations
train = oj[index,]
test = oj[-index,]                # remaining 270 observations held out for testing
purchase.test = oj$purchase[-index]
tree.oj = tree(purchase ~ ., train)   # grow a classification tree on the training set
summary(tree.oj)
##
## Classification tree:
## tree(formula = purchase ~ ., data = train)
## Variables actually used in tree construction:
## [1] "loyalch" "pricediff" "weekofpurchase"
## Number of terminal nodes: 7
## Residual mean deviance: 0.7848 = 622.4 / 793
## Misclassification error rate: 0.175 = 140 / 800
# From the summary above: 140 of 800 training observations misclassified = 17.5%
tree.train.error = 17.5
tree.train.accuracy = 100 - tree.train.error
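Alternatively, rather than transcribing the number by hand, the error can be pulled out of the summary object; a small sketch, assuming summary.tree's misclass component, which holds the misclassified count and n for classification trees:
miss = summary(tree.oj)$misclass   # assumption: c(#misclassified, n) for classification trees
round(100 * miss[1] / miss[2], 2)
## [1] 17.5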
tree.oj
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 800 1069.000 CH ( 0.61125 0.38875 )
## 2) loyalch < 0.482389 297 319.600 MM ( 0.22896 0.77104 )
## 4) loyalch < 0.0356415 55 9.996 MM ( 0.01818 0.98182 ) *
## 5) loyalch > 0.0356415 242 285.500 MM ( 0.27686 0.72314 )
## 10) pricediff < 0.31 188 197.200 MM ( 0.21809 0.78191 )
## 20) weekofpurchase < 274.5 166 185.600 MM ( 0.24699 0.75301 ) *
## 21) weekofpurchase > 274.5 22 0.000 MM ( 0.00000 1.00000 ) *
## 11) pricediff > 0.31 54 74.790 MM ( 0.48148 0.51852 ) *
## 3) loyalch > 0.482389 503 447.300 CH ( 0.83698 0.16302 )
## 6) loyalch < 0.753545 235 284.500 CH ( 0.70638 0.29362 )
## 12) pricediff < 0.015 72 98.420 MM ( 0.43056 0.56944 ) *
## 13) pricediff > 0.015 163 149.500 CH ( 0.82822 0.17178 ) *
## 7) loyalch > 0.753545 268 104.000 CH ( 0.95149 0.04851 ) *
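Each row gives the split rule, the number of observations in the branch, the deviance, the overall prediction (yval), and the class proportions, with * marking terminal nodes. For example, terminal node 7 holds the 268 customers with loyalch > 0.753545: the tree predicts CH there, and about 95% of those customers did in fact buy CH.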
plot(tree.oj)
text(tree.oj, pretty=0)   # pretty=0 shows category names unabbreviated in split labels
pred = predict(tree.oj, test, type="class")   # predicted purchase class on the held-out set
table(pred, purchase.test)
## purchase.test
## pred CH MM
## CH 123 12
## MM 41 94
tree.test.accuracy = round(mean(pred == purchase.test)*100,2)
tree.test.accuracy
## [1] 80.37
tree.test.error = round(mean(pred != purchase.test)*100,2)
tree.test.error
## [1] 19.63
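As a cross-check, the same accuracy can be read off the confusion matrix, whose diagonal holds the correct predictions:
conf = table(pred, purchase.test)
sum(diag(conf)) / sum(conf)   # (123 + 94) / 270
## [1] 0.8037037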
set.seed(1000)
cv.oj = cv.tree(tree.oj, FUN = prune.misclass)   # 10-fold CV, guided by misclassification error
cv.oj
## $size
## [1] 7 4 2 1
##
## $dev
## [1] 157 157 158 311
##
## $k
## [1] -Inf 0 5 161
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
plot(cv.oj$size, cv.oj$dev, type="b", xlab="Tree size (terminal nodes)", ylab="CV misclassification count")
prune.oj = prune.misclass(tree.oj, best=2)   # prune back to 2 terminal nodes
summary(prune.oj)
##
## Classification tree:
## snip.tree(tree = tree.oj, nodes = 2:3)
## Variables actually used in tree construction:
## [1] "loyalch"
## Number of terminal nodes: 2
## Residual mean deviance: 0.961 = 766.9 / 798
## Misclassification error rate: 0.1875 = 150 / 800
plot(prune.oj)
text(prune.oj, pretty=0)
# From the summary above: 150 of 800 training observations misclassified = 18.75%
prune.train.error = 18.75
prune.train.accuracy = 100 - prune.train.error
prune.pred = predict(prune.oj, test, type="class")
table(prune.pred, purchase.test)
## purchase.test
## prune.pred CH MM
## CH 138 28
## MM 26 78
prune.test.accuracy = round(mean(prune.pred == purchase.test)*100,2)
prune.test.accuracy
## [1] 80
prune.test.error = round(mean(prune.pred != purchase.test)*100,2)
prune.test.error
## [1] 20
error.df = data.frame(
  Model_type = c("Unpruned", "Pruned (2 terminal nodes)"),
  Training_Error_rate = c(tree.train.error, prune.train.error),
  Test_Error_rate = c(tree.test.error, prune.test.error)
)
error.df
##                  Model_type Training_Error_rate Test_Error_rate
## 1                  Unpruned               17.50           19.63
## 2 Pruned (2 terminal nodes)               18.75           20.00
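Pruning to two terminal nodes nudges both the training error (17.5% to 18.75%) and the test error (19.63% to 20%) upward, so the unpruned tree generalizes marginally better here; the trade-off is that the pruned tree is far easier to interpret, splitting on loyalch alone.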