Chapter 08 (page 332): 3, 8, 9
Consider the Gini index, classification error, and entropy in a simple classification setting with two classes. Create a single plot that displays each of these quantities as a function of p̂_m1. The x-axis should display p̂_m1, ranging from 0 to 1, and the y-axis should display the value of the Gini index, classification error, and entropy. Hint: In a setting with two classes, p̂_m1 = 1 − p̂_m2. You could make this plot by hand, but it will be much easier to make in R.
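In the two-class case, writing p for p̂_m1 (so p̂_m2 = 1 − p), the three measures reduce to: Gini index G = 2p(1 − p), classification error E = 1 − max(p, 1 − p), and entropy D = −p log(p) − (1 − p) log(1 − p). These are the formulas implemented below.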
p = seq(0, 1, 0.01)
Gini = 2 * p * (1 - p)
class.err = 1 - pmax(p, 1 - p)
# entropy is NaN at p = 0 and p = 1 (0 * log(0) in R); matplot simply skips those points
entropy = -(p * log(p) + (1 - p) * log(1 - p))
matplot(p, cbind(Gini, class.err, entropy), type = "l", lty = 1,
        col = c("red", "green", "blue"), ylab = "Impurity measure")
legend("topright", legend = c("Gini", "Classification error", "Entropy"),
       col = c("red", "green", "blue"), lty = 1)
In the lab, a classification tree was applied to the Carseats data set after converting Sales into a qualitative response variable. Now we will seek to predict Sales using regression trees and related approaches, treating the response as a quantitative variable.
library(ISLR)
## Warning: package 'ISLR' was built under R version 3.6.3
data("Carseats")
# split the observations: half for training, half for testing
train = sample(dim(Carseats)[1], dim(Carseats)[1]/2)
traincar = Carseats[train, ]
testcar = Carseats[-train, ]
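No seed is set before the split above, so the MSE values reported in the rest of this exercise vary from run to run. A minimal way to make the split reproducible (the seed value is an arbitrary choice):

set.seed(1)  # arbitrary seed, for reproducibility only
train = sample(nrow(Carseats), nrow(Carseats) / 2)
traincar = Carseats[train, ]
testcar = Carseats[-train, ]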
Reported MSE: 1.766985 in the original run; the exact value depends on the random train/test split.
library(tree)
## Warning: package 'tree' was built under R version 3.6.3
car.tree = tree(Sales ~ ., data = traincar)
summary(car.tree)
##
## Regression tree:
## tree(formula = Sales ~ ., data = traincar)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Age" "Income" "Population"
## [6] "Advertising" "CompPrice"
## Number of terminal nodes: 16
## Residual mean deviance: 2.36 = 434.3 / 184
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -4.84200 -0.88490 0.03929 0.00000 1.11400 3.94300
plot(car.tree)
text(car.tree, pretty = 0)
# note: this computes the MSE on the training set
car.pred <- predict(car.tree, traincar)
mean((car.pred - traincar$Sales)^2)
## [1] 2.171461
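Part (b) of the exercise asks for the test MSE, so the prediction should be made on the held-out set instead; a minimal sketch (the object name car.test.pred is illustrative, and the value depends on the random split):

car.test.pred <- predict(car.tree, newdata = testcar)  # predict on the test set
mean((car.test.pred - testcar$Sales)^2)                # test MSE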
Cross-validation is used next to choose the tree complexity, and the tree is then pruned to that size. In the run shown below the pruned tree's test MSE is 4.9178; whether pruning improves the test MSE depends on the random split (an earlier run reported 5.381467).
set.seed(1)
cv.car = cv.tree(car.tree)
par(mfrow = c(1, 2))
plot(cv.car$size, cv.car$dev, type = "b")
plot(cv.car$k, cv.car$dev, type = "b")
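The size with the smallest cross-validation deviance can also be read off programmatically rather than from the plots; a small sketch (best.size is an illustrative name):

best.size <- cv.car$size[which.min(cv.car$dev)]  # tree size with the lowest CV deviance
best.size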
library(tree)
prune.car <- prune.tree(car.tree, best = 11)
plot(prune.car)
text(prune.car, pretty = 0)
prune.pred = predict(prune.car, testcar)
mean((testcar$Sales - prune.pred)^2)
## [1] 4.9178
Bagging (a random forest with mtry = 10, i.e. all predictors considered at each split) lowers the test MSE to 2.407913 in the run shown below (an earlier run reported 3.080276). Price and ShelveLoc are the most important variables.
library(randomForest)
## Warning: package 'randomForest' was built under R version 3.6.3
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
set.seed(1)
# mtry = 10 considers all 10 predictors at every split, i.e. bagging
bag.car = randomForest(Sales ~ ., data = traincar, mtry = 10, importance = TRUE)
bag.car
##
## Call:
## randomForest(formula = Sales ~ ., data = traincar, mtry = 10, importance = TRUE)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 10
##
## Mean of squared residuals: 2.687721
## % Var explained: 68.81
bag.predict = predict(bag.car, newdata = testcar)
mean((bag.predict - testcar$Sales)^2)
## [1] 2.407913
importance(bag.car)
## %IncMSE IncNodePurity
## CompPrice 23.5369516 157.087753
## Income 6.7126892 77.661018
## Advertising 18.3777112 138.771448
## Population 1.1615591 53.900099
## Price 55.0070628 529.010544
## ShelveLoc 55.2464535 470.325180
## Age 19.7020453 181.323744
## Education 1.6094616 42.338034
## Urban 0.9951195 9.188756
## US 2.2685034 9.481156
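The two importance measures can also be visualized with varImpPlot() from the randomForest package:

varImpPlot(bag.car)  # dot charts of %IncMSE and IncNodePurity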
With a random forest using mtry = 5 (a subset of the predictors at each split), the test MSE in the run shown below is 2.54431, slightly higher than bagging in this case (an earlier run reported 2.971218). Price and ShelveLoc remain the most important variables.
set.seed(1)
rf.car = randomForest(Sales ~ ., data = traincar, mtry = 5, importance = TRUE)
rf.predict = predict(rf.car, newdata = testcar)
mean((rf.predict - testcar$Sales)^2)
## [1] 2.54431
importance(rf.car)
## %IncMSE IncNodePurity
## CompPrice 15.9608098 145.37607
## Income 4.8632695 99.63118
## Advertising 16.7668132 146.49785
## Population 1.8207667 83.73252
## Price 40.6098614 454.00011
## ShelveLoc 47.3691355 420.61094
## Age 20.2536012 219.09658
## Education -0.0273898 57.01777
## Urban -1.4761741 12.22659
## US 2.9978563 19.89020
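To describe the effect of m (the number of variables considered at each split) on the error, the forest can be refit over a range of mtry values and the test MSE recorded each time; a minimal sketch (mtry.grid and test.mse are illustrative names):

mtry.grid <- 1:10                      # Carseats has 10 predictors
test.mse <- rep(NA, length(mtry.grid))
for (m in mtry.grid) {
  fit <- randomForest(Sales ~ ., data = traincar, mtry = m, ntree = 500)
  pred <- predict(fit, newdata = testcar)
  test.mse[m] <- mean((pred - testcar$Sales)^2)
}
plot(mtry.grid, test.mse, type = "b", xlab = "mtry", ylab = "Test MSE")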
library(ISLR)
data(OJ)
set.seed(1)
train = sample(dim(OJ)[1], 800)
oj.train = OJ[train, ]
oj.test = OJ[-train, ]
The training misclassification error rate is 0.1588 and the tree has 9 terminal nodes, as shown in the summary below.
oj.tree = tree(Purchase ~ ., data = oj.train)
summary(oj.tree)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = oj.train)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "SpecialCH" "ListPriceDiff"
## [5] "PctDiscMM"
## Number of terminal nodes: 9
## Residual mean deviance: 0.7432 = 587.8 / 791
## Misclassification error rate: 0.1588 = 127 / 800
Node 20 is marked with a *, indicating a terminal node. Its split criterion is SpecialCH < 0.5; the node contains 64 observations with a deviance of 51.98, and its predicted class is MM (about 86% of the observations in this node bought MM).
oj.tree
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 800 1073.00 CH ( 0.60625 0.39375 )
## 2) LoyalCH < 0.5036 365 441.60 MM ( 0.29315 0.70685 )
## 4) LoyalCH < 0.280875 177 140.50 MM ( 0.13559 0.86441 )
## 8) LoyalCH < 0.0356415 59 10.14 MM ( 0.01695 0.98305 ) *
## 9) LoyalCH > 0.0356415 118 116.40 MM ( 0.19492 0.80508 ) *
## 5) LoyalCH > 0.280875 188 258.00 MM ( 0.44149 0.55851 )
## 10) PriceDiff < 0.05 79 84.79 MM ( 0.22785 0.77215 )
## 20) SpecialCH < 0.5 64 51.98 MM ( 0.14062 0.85938 ) *
## 21) SpecialCH > 0.5 15 20.19 CH ( 0.60000 0.40000 ) *
## 11) PriceDiff > 0.05 109 147.00 CH ( 0.59633 0.40367 ) *
## 3) LoyalCH > 0.5036 435 337.90 CH ( 0.86897 0.13103 )
## 6) LoyalCH < 0.764572 174 201.00 CH ( 0.73563 0.26437 )
## 12) ListPriceDiff < 0.235 72 99.81 MM ( 0.50000 0.50000 )
## 24) PctDiscMM < 0.196197 55 73.14 CH ( 0.61818 0.38182 ) *
## 25) PctDiscMM > 0.196197 17 12.32 MM ( 0.11765 0.88235 ) *
## 13) ListPriceDiff > 0.235 102 65.43 CH ( 0.90196 0.09804 ) *
## 7) LoyalCH > 0.764572 261 91.20 CH ( 0.95785 0.04215 ) *
LoyalCH is the most important variable in the tree; the first three splits all use it. If LoyalCH < 0.28 the tree predicts MM, and if LoyalCH > 0.76 it predicts CH; for intermediate values the prediction also depends on PriceDiff, SpecialCH, ListPriceDiff, and PctDiscMM.
plot(oj.tree)
text(oj.tree, pretty = 0)
pred.oj = predict(oj.tree, oj.test, type = "class")
table(oj.test$Purchase, pred.oj)
## pred.oj
## CH MM
## CH 160 8
## MM 38 64
# test error rate of the unpruned tree (same predictions as pred.oj above)
pred.unprune = predict(oj.tree, oj.test, type = "class")
misclass.unprune = sum(oj.test$Purchase != pred.unprune)
misclass.unprune/length(pred.unprune)
## [1] 0.1703704
# cross-validation with deviance as the pruning criterion;
# FUN = prune.misclass would use the classification error rate instead
oj.cv = cv.tree(oj.tree, FUN = prune.tree)
oj.cv
## $size
## [1] 9 8 7 6 5 4 3 2 1
##
## $dev
## [1] 685.6493 698.8799 702.8083 702.8083 714.1093 725.4734 780.2099
## [8] 790.0301 1074.2062
##
## $k
## [1] -Inf 12.62207 13.94616 14.35384 26.21539 35.74964 43.07317
## [8] 45.67120 293.15784
##
## $method
## [1] "deviance"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
plot(oj.cv$size, oj.cv$dev, type = "b", xlab = "Tree size", ylab = "Deviance")
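Since this is a classification problem, the cross-validation can also be run with the misclassification rate as the pruning criterion, using prune.misclass from the tree package; a small sketch (oj.cv.misclass is an illustrative name, and with this choice the dev component holds the cross-validated misclassification counts):

oj.cv.misclass <- cv.tree(oj.tree, FUN = prune.misclass)
plot(oj.cv.misclass$size, oj.cv.misclass$dev, type = "b",
     xlab = "Tree size", ylab = "CV misclassifications")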
In the output above the lowest cross-validation deviance is attained by the full tree with 9 terminal nodes, so cross-validation does not select a pruned tree here; we therefore prune to a smaller tree (6 terminal nodes) for comparison.
oj.prune = prune.tree(oj.tree, best = 6)
summary(oj.prune)
##
## Classification tree:
## snip.tree(tree = oj.tree, nodes = c(10L, 4L, 12L))
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff" "ListPriceDiff"
## Number of terminal nodes: 6
## Residual mean deviance: 0.7919 = 628.8 / 794
## Misclassification error rate: 0.1788 = 143 / 800
The pruned tree's training error rate (0.1788) is higher than the unpruned tree's (0.1588).
pred.prune = predict(oj.prune, oj.test, type = "class")
misclass.prune = sum(oj.test$Purchase != pred.prune)
misclass.prune/length(pred.prune)
## [1] 0.1851852
The pruned tree's test error rate (0.1852) is also slightly higher than that of the unpruned tree (0.1704).