### Tree-Based Models ### Regression Trees
library(ggplot2)
library(MASS)
library(dplyr)
##
## Attaching package: 'dplyr'
## The following object is masked from 'package:MASS':
##
## select
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
library(tree)
library(ISLR)
library(tree)
# Fix the RNG seed so the random train/test split below is reproducible.
set.seed(1)
# Draw half of the row indices of Boston as the training set.
train <- sample(seq_len(nrow(Boston)), size = nrow(Boston) / 2)
# Grow a regression tree for medv on all other columns, training rows only.
tree.boston <- tree(formula = medv ~ ., data = Boston, subset = train)
# Report the variables used, the number of terminal nodes and the deviance.
summary(tree.boston)
##
## Regression tree:
## tree(formula = medv ~ ., data = Boston, subset = train)
## Variables actually used in tree construction:
## [1] "rm" "lstat" "crim" "age"
## Number of terminal nodes: 7
## Residual mean deviance: 10.38 = 2555 / 246
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -10.1800 -1.7770 -0.1775 0.0000 1.9230 16.5800
# Plotting a regression tree
# Draw the tree skeleton first, then add the split and leaf labels.
plot(tree.boston)
text(tree.boston, pretty = 0)

# Using cross-validation to select complexity
# cv.tree() cross-validates the sequence of pruned subtrees of tree.boston.
cv.boston <- cv.tree(tree.boston)
# Cross-validated deviance against the number of terminal nodes.
plot(cv.boston$size, cv.boston$dev, type = "b")
# favors the 7 node tree

# Pruning the tree
# Prune back to 5 terminal nodes and plot the smaller tree.
prune.boston <- prune.tree(tree.boston, best = 5)
plot(prune.boston)
text(prune.boston, pretty = 0)
# Making predictions
# Predict medv for the held-out rows using the unpruned tree.
yhat <- predict(tree.boston, newdata = Boston[-train, ])
# Observed medv for the same held-out rows.
boston.test <- Boston$medv[-train]
# Predicted vs. observed values, with the y = x reference line.
plot(yhat, boston.test)
abline(a = 0, b = 1)
# Test mean squared error.
mean((yhat - boston.test)^2)
## [1] 35.28688
# Classification Trees
# Uses the Carseats data from the book, plus a new variable that separates
# high sales from the rest.
# High is "Yes" when Sales is greater than 8 and "No" otherwise. It is stored
# as a factor: since R 4.0 data.frame() no longer turns strings into factors,
# and tree() fits a classification tree only when the response is a factor.
# Sales is read via Carseats$Sales so nothing has to be attached beforehand.
High <- factor(ifelse(Carseats$Sales <= 8, "No", "Yes"))
Carseats <- data.frame(Carseats, High)
# Attach exactly once, after High has been added, so that the single
# detach(Carseats) later in the script restores the search path.
attach(Carseats)
## The following object is masked _by_ .GlobalEnv:
##
## High
# Fitting a classification tree
# Model High on every column except Sales (High was derived from Sales).
tree.carseats <- tree(formula = High ~ . - Sales, data = Carseats)
# Variables used, number of terminal nodes and training misclassification rate.
summary(tree.carseats)
##
## Classification tree:
## tree(formula = High ~ . - Sales, data = Carseats)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Income" "CompPrice" "Population"
## [6] "Advertising" "Age" "US"
## Number of terminal nodes: 27
## Residual mean deviance: 0.4575 = 170.7 / 373
## Misclassification error rate: 0.09 = 36 / 400
# Plotting a classification tree
# Draw the tree skeleton, then label the splits and leaves.
plot(tree.carseats)
text(tree.carseats, pretty = 0)
# Displaying branches with text
# Print every node: split rule, n, deviance, predicted class and class
# probabilities.
print(tree.carseats)
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 400 541.500 No ( 0.59000 0.41000 )
## 2) ShelveLoc: Bad,Medium 315 390.600 No ( 0.68889 0.31111 )
## 4) Price < 92.5 46 56.530 Yes ( 0.30435 0.69565 )
## 8) Income < 57 10 12.220 No ( 0.70000 0.30000 )
## 16) CompPrice < 110.5 5 0.000 No ( 1.00000 0.00000 ) *
## 17) CompPrice > 110.5 5 6.730 Yes ( 0.40000 0.60000 ) *
## 9) Income > 57 36 35.470 Yes ( 0.19444 0.80556 )
## 18) Population < 207.5 16 21.170 Yes ( 0.37500 0.62500 ) *
## 19) Population > 207.5 20 7.941 Yes ( 0.05000 0.95000 ) *
## 5) Price > 92.5 269 299.800 No ( 0.75465 0.24535 )
## 10) Advertising < 13.5 224 213.200 No ( 0.81696 0.18304 )
## 20) CompPrice < 124.5 96 44.890 No ( 0.93750 0.06250 )
## 40) Price < 106.5 38 33.150 No ( 0.84211 0.15789 )
## 80) Population < 177 12 16.300 No ( 0.58333 0.41667 )
## 160) Income < 60.5 6 0.000 No ( 1.00000 0.00000 ) *
## 161) Income > 60.5 6 5.407 Yes ( 0.16667 0.83333 ) *
## 81) Population > 177 26 8.477 No ( 0.96154 0.03846 ) *
## 41) Price > 106.5 58 0.000 No ( 1.00000 0.00000 ) *
## 21) CompPrice > 124.5 128 150.200 No ( 0.72656 0.27344 )
## 42) Price < 122.5 51 70.680 Yes ( 0.49020 0.50980 )
## 84) ShelveLoc: Bad 11 6.702 No ( 0.90909 0.09091 ) *
## 85) ShelveLoc: Medium 40 52.930 Yes ( 0.37500 0.62500 )
## 170) Price < 109.5 16 7.481 Yes ( 0.06250 0.93750 ) *
## 171) Price > 109.5 24 32.600 No ( 0.58333 0.41667 )
## 342) Age < 49.5 13 16.050 Yes ( 0.30769 0.69231 ) *
## 343) Age > 49.5 11 6.702 No ( 0.90909 0.09091 ) *
## 43) Price > 122.5 77 55.540 No ( 0.88312 0.11688 )
## 86) CompPrice < 147.5 58 17.400 No ( 0.96552 0.03448 ) *
## 87) CompPrice > 147.5 19 25.010 No ( 0.63158 0.36842 )
## 174) Price < 147 12 16.300 Yes ( 0.41667 0.58333 )
## 348) CompPrice < 152.5 7 5.742 Yes ( 0.14286 0.85714 ) *
## 349) CompPrice > 152.5 5 5.004 No ( 0.80000 0.20000 ) *
## 175) Price > 147 7 0.000 No ( 1.00000 0.00000 ) *
## 11) Advertising > 13.5 45 61.830 Yes ( 0.44444 0.55556 )
## 22) Age < 54.5 25 25.020 Yes ( 0.20000 0.80000 )
## 44) CompPrice < 130.5 14 18.250 Yes ( 0.35714 0.64286 )
## 88) Income < 100 9 12.370 No ( 0.55556 0.44444 ) *
## 89) Income > 100 5 0.000 Yes ( 0.00000 1.00000 ) *
## 45) CompPrice > 130.5 11 0.000 Yes ( 0.00000 1.00000 ) *
## 23) Age > 54.5 20 22.490 No ( 0.75000 0.25000 )
## 46) CompPrice < 122.5 10 0.000 No ( 1.00000 0.00000 ) *
## 47) CompPrice > 122.5 10 13.860 No ( 0.50000 0.50000 )
## 94) Price < 125 5 0.000 Yes ( 0.00000 1.00000 ) *
## 95) Price > 125 5 0.000 No ( 1.00000 0.00000 ) *
## 3) ShelveLoc: Good 85 90.330 Yes ( 0.22353 0.77647 )
## 6) Price < 135 68 49.260 Yes ( 0.11765 0.88235 )
## 12) US: No 17 22.070 Yes ( 0.35294 0.64706 )
## 24) Price < 109 8 0.000 Yes ( 0.00000 1.00000 ) *
## 25) Price > 109 9 11.460 No ( 0.66667 0.33333 ) *
## 13) US: Yes 51 16.880 Yes ( 0.03922 0.96078 ) *
## 7) Price > 135 17 22.070 No ( 0.64706 0.35294 )
## 14) Income < 46 6 0.000 No ( 1.00000 0.00000 ) *
## 15) Income > 46 11 15.160 Yes ( 0.45455 0.54545 ) *
# Testing and training
# Split the 400 observations into 200 training rows and 200 test rows.
set.seed(2)
train <- sample(seq_len(nrow(Carseats)), 200)
# The comma matters: Carseats[-train, ] drops the training ROWS, whereas
# Carseats[-train] would drop columns and keep all 400 rows.
carseats.test <- Carseats[-train, ]
dim(carseats.test)
## [1] 200 12
# True labels of the held-out rows.
high.test <- High[-train]
# Refit the classification tree on the training rows only.
tree.carseats <- tree(High ~ . - Sales, Carseats, subset = train)
# Predict a class label for every row, then keep the held-out rows.
tree.pred.all <- predict(tree.carseats, newdata = Carseats, type = "class")
length(tree.pred.all)
## [1] 400
tree.pred <- tree.pred.all[-train]
# confusion matrix (rows: predicted class, columns: true class)
cm <- table(tree.pred, high.test)
cm
## high.test
## tree.pred No Yes
## No 103 31
## Yes 14 52
# accuracy: the diagonal of the confusion matrix holds the correct predictions
sum(diag(cm)) / sum(cm)
## [1] 0.775
# misclassification error rate
1 - sum(diag(cm)) / sum(cm)
## [1] 0.225
# Cross-validation for pruning
# Seed the RNG so the cross-validation result is reproducible.
set.seed(3)
# FUN = prune.misclass: prune by misclassification rather than by deviance.
cv.carseats <- cv.tree(tree.carseats, FUN = prune.misclass)
names(cv.carseats)
## [1] "size" "dev" "k" "method"
# size: number of terminal nodes of each candidate tree
# dev:  cross-validation error of each candidate tree
# k:    cost-complexity parameter (alpha)
print(cv.carseats)
## $size
## [1] 21 19 14 9 8 5 3 2 1
##
## $dev
## [1] 74 76 81 81 75 77 78 85 81
##
## $k
## [1] -Inf 0.0 1.0 1.4 2.0 3.0 4.0 9.0 18.0
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
# Plot the cross-validation error against tree size and against k, side by side.
par(mfrow = c(1, 2))
plot(cv.carseats$size, cv.carseats$dev, type = "b")
plot(cv.carseats$k, cv.carseats$dev, type = "b")
# Back to a single panel, then prune to the desired complexity (best = 9).
par(mfrow = c(1, 1))
prune.carseats <- prune.misclass(tree.carseats, best = 9)
plot(prune.carseats)
text(prune.carseats, pretty = 0)
# prediction and confusion matrix for the pruned tree
# Predict a class for every row, then keep only the held-out rows.
tree.pred2 <- predict(prune.carseats, newdata = Carseats, type = "class")
test.pred <- tree.pred2[-train]
# rows: predicted class, columns: true class
cm <- table(test.pred, high.test)
cm
## high.test
## test.pred No Yes
## No 97 25
## Yes 20 58
# accuracy: the diagonal of the confusion matrix holds the correct predictions
sum(diag(cm)) / sum(cm)
## [1] 0.775
# misclassification error rate
1 - sum(diag(cm)) / sum(cm)
## [1] 0.225
# What if we do less pruning?
# Keep a larger tree (best = 15) and repeat the evaluation on the test rows.
prune.carseats <- prune.misclass(tree.carseats, best = 15)
plot(prune.carseats)
text(prune.carseats, pretty = 0)
# prediction and confusion matrix
# Predict a class for every row, then keep only the held-out rows.
tree.pred2 <- predict(prune.carseats, newdata = Carseats, type = "class")
test.pred <- tree.pred2[-train]
# rows: predicted class, columns: true class
cm <- table(test.pred, high.test)
cm
## high.test
## test.pred No Yes
## No 102 30
## Yes 15 53
# accuracy: the diagonal of the confusion matrix holds the correct predictions
sum(diag(cm)) / sum(cm)
## [1] 0.775
# misclassification error rate
1 - sum(diag(cm)) / sum(cm)
## [1] 0.225
# Remove Carseats from the search path again.
detach(Carseats)
#Improvements on Trees: Bagging/Random Forests/Boosting
#Improvements on Trees
#Use the Boston data:
library(ISLR)
library(MASS)
data(Boston)
#The randomForest package can be used for bagging and random forests:
# package to be used for random forests and bagging
library(randomForest)
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:dplyr':
##
## combine
## The following object is masked from 'package:ggplot2':
##
## margin
# Create the testing and training sets
# Seed the RNG, then draw half of the Boston row indices for training.
set.seed(1)
train <- sample(seq_len(nrow(Boston)), nrow(Boston) %/% 2)
# Observed medv on the held-out rows; used for every test-error figure below.
boston.test <- Boston$medv[-train]
# Fitting the bagged model
# Bagging is the special case m = p: all 13 predictors are candidates at
# every split (mtry = 13).
# By default 500 trees are grown.
bag.boston <- randomForest(
  medv ~ .,
  data = Boston,
  subset = train,
  mtry = 13,
  importance = TRUE
)
# Test error
# Predict the held-out rows, plot predicted vs. observed, report the test MSE.
yhat.bag <- predict(bag.boston, newdata = Boston[-train, ])
plot(yhat.bag, boston.test)
abline(a = 0, b = 1)
mean((yhat.bag - boston.test)^2)
## [1] 23.4579
# Decrease the number of trees
# Same bagged model (mtry = 13) but with only 25 trees.
bag.boston2 <- randomForest(
  medv ~ .,
  data = Boston,
  subset = train,
  mtry = 13,
  ntree = 25,
  importance = TRUE
)
# Test MSE of the 25-tree bagged model.
yhat.bag2 <- predict(bag.boston2, newdata = Boston[-train, ])
mean((yhat.bag2 - boston.test)^2)
## [1] 22.79211
# Random Forest
# A random forest considers only a subset of the predictors at each split.
# The default is p/3; here mtry = 6 is used.
set.seed(1)
rf.boston <- randomForest(
  medv ~ .,
  data = Boston,
  subset = train,
  mtry = 6,
  importance = TRUE
)
# Test MSE of the random forest.
yhat.rf <- predict(rf.boston, newdata = Boston[-train, ])
mean((yhat.rf - boston.test)^2)
## [1] 19.62021
# Variable Importance
# Table of the two importance measures (%IncMSE, IncNodePurity) per predictor.
print(importance(rf.boston))
## %IncMSE IncNodePurity
## crim 16.697017 1076.08786
## zn 3.625784 88.35342
## indus 4.968621 609.53356
## chas 1.061432 52.21793
## nox 13.518179 709.87339
## rm 32.343305 7857.65451
## age 13.272498 612.21424
## dis 9.032477 714.94674
## rad 2.878434 95.80598
## tax 9.118801 364.92479
## ptratio 8.467062 823.93341
## black 7.579482 275.62272
## lstat 27.129817 6027.63740
# The same importance measures as a plot.
varImpPlot(rf.boston)
# Boosting
# You will need a new package:
library(gbm)
## Loaded gbm 2.1.5
# Fitting a boosted tree
# Seed the RNG so the fit is reproducible.
set.seed(1)
# 5000 trees of interaction depth 4 with squared-error loss, fitted on the
# training rows. shrinkage (lambda) is spelled out as 0.001 because the
# test-error section of this script describes this model as shrinkage = 0.001
# and later compares it with a shrinkage = 0.1 fit.
# NOTE(review): gbm's default shrinkage appears to differ between package
# releases (0.001 in older ones, 0.1 in newer ones) -- confirm for the
# installed version. The recorded outputs below were produced without an
# explicit value and may change.
boost.boston <- gbm(medv ~ ., data = Boston[train, ], distribution = "gaussian",
                    n.trees = 5000, interaction.depth = 4, shrinkage = 0.001)
# Variable Importance
# Relative influence of each predictor, as a plot and as a printed table.
print(summary(boost.boston))
## var rel.inf
## rm rm 43.9919329
## lstat lstat 33.1216941
## crim crim 4.2604167
## dis dis 4.0111090
## nox nox 3.4353017
## black black 2.8267554
## age age 2.6113938
## ptratio ptratio 2.5403035
## tax tax 1.4565654
## indus indus 0.8008740
## rad rad 0.6546400
## zn zn 0.1446149
## chas chas 0.1443986
# Partial Dependence Plots
# Marginal effect of the selected variable on the response, shown side by side
# for rm and lstat.
par(mfrow = c(1, 2))
plot(boost.boston, i = "rm")
plot(boost.boston, i = "lstat")
# Test Error with Different Lambda
# Test error of the first boosted model; boost.boston is meant to be the
# slow-learning fit (lambda / shrinkage = 0.001).
yhat.boost <- predict(
  boost.boston,
  newdata = Boston[-train, ],
  n.trees = 5000
)
mean((yhat.boost - boston.test)^2)
## [1] 18.84709
# Change the learning rate: (lambda) shrinkage = 0.1
# Same boosted model as before (5000 trees, interaction depth 4) but with a
# larger shrinkage. verbose is spelled FALSE rather than F, because F is an
# ordinary variable that can be reassigned.
boost.boston2 <- gbm(medv ~ ., data = Boston[train, ], distribution = "gaussian",
                     n.trees = 5000, interaction.depth = 4,
                     shrinkage = 0.1, verbose = FALSE)
# Test MSE of the faster-learning model, using all 5000 trees.
yhat.boost2 <- predict(boost.boston2, newdata = Boston[-train, ],
                       n.trees = 5000)
mean((yhat.boost2 - boston.test)^2)
## [1] 18.18255