# Chapter 8 page 332: 3,8,9
# 3
# Compare three node-impurity measures for a two-class problem as a function
# of p, the proportion of observations falling in the first class.
p <- seq(0, 1, by = 0.01)
# Gini index: summing p_k * (1 - p_k) over both classes gives 2 * p * (1 - p).
gini.index <- 2 * p * (1 - p)
# Classification error: 1 - max(p, 1 - p). The two candidates must be passed
# to pmax() as separate vectors; pmax(p, 1, -p) is identically 1 and would
# make this curve flat at zero.
class.error <- 1 - pmax(p, 1 - p)
# Cross-entropy: -(p * log(p) + (1 - p) * log(1 - p)).
cross.entropy <- -(p * log(p) + (1 - p) * log(1 - p))
# 0 * log(0) evaluates to NaN in R; the limiting value is 0, so use that at
# p = 0 and p = 1 rather than letting the endpoints drop out of the plot.
cross.entropy[is.nan(cross.entropy)] <- 0
matplot(p, cbind(gini.index, class.error, cross.entropy),
        col = c("pink", "yellow", "blue"))

# 8
library(MASS)
## Warning: package 'MASS' was built under R version 4.1.3
library(ISLR)
## Warning: package 'ISLR' was built under R version 4.1.3
# Draw half of the Carseats row indices at random for training; the remaining
# rows form the test set. The seed makes the draw reproducible.
set.seed(1)
train <- sample(seq_len(nrow(Carseats)), size = nrow(Carseats) / 2)
Carseats.train <- Carseats[train, ]
Carseats.test <- Carseats[-train, ]
library(tree)
## Warning: package 'tree' was built under R version 4.1.3
## Registered S3 method overwritten by 'tree':
## method from
## print.tree cli
# Grow a regression tree for Sales on every other column, using the training
# rows only, then inspect and draw it.
tree.carseats <- tree(Sales ~ ., data = Carseats.train)
summary(tree.carseats)
##
## Regression tree:
## tree(formula = Sales ~ ., data = Carseats.train)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Age" "Advertising" "CompPrice"
## [6] "US"
## Number of terminal nodes: 18
## Residual mean deviance: 2.167 = 394.3 / 182
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -3.88200 -0.88200 -0.08712 0.00000 0.89590 4.09900
plot(tree.carseats)
text(tree.carseats, pretty = 0)

# Test-set mean squared error of the unpruned tree.
yhat <- predict(tree.carseats, newdata = Carseats.test)
mean((Carseats.test$Sales - yhat)^2)
## [1] 4.922039
# Cross-validate the pruning sequence of the Carseats tree and plot the
# cross-validated deviance against tree size.
cv.carseats <- cv.tree(tree.carseats)
plot(cv.carseats$size, cv.carseats$dev, type = "b")
# which.min() returns a position within the pruning sequence, not a tree size
# (cv.tree lists sizes largest-first), so look the size up before marking the
# minimum; plotting the raw index puts the marker at the wrong x-coordinate.
tree.min <- which.min(cv.carseats$dev)
points(cv.carseats$size[tree.min], cv.carseats$dev[tree.min],
       col = "pink", cex = 2, pch = 20)

# Prune the regression tree back to 8 terminal nodes and draw the result.
prune.carseats <- prune.tree(tree.carseats, best = 8)
plot(prune.carseats)
text(prune.carseats, pretty = 0)

# Test-set mean squared error of the pruned tree.
yhat <- predict(prune.carseats, newdata = Carseats.test)
mean((Carseats.test$Sales - yhat)^2)
## [1] 5.113254
# 9
# Draw 800 OJ row indices at random for training; every other row goes to the
# test set. The seed makes the draw reproducible.
set.seed(1)
train <- sample(seq_len(nrow(OJ)), size = 800)
OJ.train <- OJ[train, ]
OJ.test <- OJ[-train, ]
# Fit a classification tree for Purchase on every other column of the
# training set, then inspect and draw it.
tree.oj <- tree(Purchase ~ ., data = OJ.train)
summary(tree.oj)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = OJ.train)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "SpecialCH" "ListPriceDiff"
## [5] "PctDiscMM"
## Number of terminal nodes: 9
## Residual mean deviance: 0.7432 = 587.8 / 791
## Misclassification error rate: 0.1588 = 127 / 800
plot(tree.oj)
text(tree.oj, pretty = 0)

# Confusion matrix on the test set: rows are predicted classes, columns are
# the observed Purchase values.
tree.pred <- predict(tree.oj, OJ.test, type = "class")
table(tree.pred, OJ.test$Purchase)
##
## tree.pred CH MM
## CH 160 38
## MM 8 64
# Cross-validate the pruning sequence of the OJ tree, guiding the pruning by
# misclassification rather than deviance.
cv.oj <- cv.tree(tree.oj, FUN = prune.misclass)
cv.oj
## $size
## [1] 9 8 7 4 2 1
##
## $dev
## [1] 150 150 149 158 172 315
##
## $k
## [1] -Inf 0.000000 3.000000 4.333333 10.500000 151.000000
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
# With FUN = prune.misclass the `dev` component holds cross-validated
# misclassification counts (see $method above), so label the axis as such
# instead of as deviance.
plot(cv.oj$size, cv.oj$dev, type = "b",
     xlab = "Tree size", ylab = "CV misclassifications")

# Prune the OJ classification tree to 2 terminal nodes and draw it.
# NOTE(review): the cross-validation output recorded for cv.oj has its
# smallest error (149) at size 7, not 2 -- confirm that best = 2 is the
# intended size.
prune.oj <- prune.misclass(tree.oj, best = 2)
plot(prune.oj)
text(prune.oj, pretty = 0)

# Training-set summaries: full tree first, pruned tree second.
summary(tree.oj)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = OJ.train)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "SpecialCH" "ListPriceDiff"
## [5] "PctDiscMM"
## Number of terminal nodes: 9
## Residual mean deviance: 0.7432 = 587.8 / 791
## Misclassification error rate: 0.1588 = 127 / 800
summary(prune.oj)
##
## Classification tree:
## snip.tree(tree = tree.oj, nodes = 3:2)
## Variables actually used in tree construction:
## [1] "LoyalCH"
## Number of terminal nodes: 2
## Residual mean deviance: 0.9768 = 779.5 / 798
## Misclassification error rate: 0.205 = 164 / 800
# Confusion matrix of the pruned tree on the test set.
prune.pred <- predict(prune.oj, OJ.test, type = "class")
table(prune.pred, OJ.test$Purchase)
##
## prune.pred CH MM
## CH 142 24
## MM 26 78