Recursive partitioning is a fundamental tool in data mining. It helps us explore the structure of a set of data while developing easy-to-visualize decision rules for predicting a categorical outcome (classification tree) or a continuous outcome (regression tree).
Classification and regression trees (as described by Breiman, Friedman, Olshen, and Stone) can be generated through the rpart package. Detailed information on rpart is available in An Introduction to Recursive Partitioning Using the RPART Routines.
library(ISLR)
package ‘ISLR’ was built under R version 3.6.3
library(rpart)
library(caret)
Loading required package: lattice
Loading required package: ggplot2
Registered S3 method overwritten by 'dplyr':
method from
print.rowwise_df
Registered S3 method overwritten by 'data.table':
method from
print.data.table
We first use classification trees to analyze the Carseats data set. In these data, Sales is a continuous variable, and so we begin by recoding it as a binary variable. We use the ifelse() function to create a variable, called High, which takes on a value of Yes if the Sales variable exceeds 8, and takes on a value of No otherwise.
High=as.factor(ifelse(Carseats$Sales<=8, "No", "Yes"))
Finally, we use the data.frame() function to merge High with the rest of the Carseats data. We also remove the Sales variable from the original Carseats data to avoid perfect collinearity in our model. Any Sales value higher than 8 corresponds to a High value of ‘Yes’.
Carseats=data.frame(Carseats[,-1],High)
To grow a tree, use rpart(formula, data=, method=, control=), where formula is in the format outcome ~ predictor1+predictor2+predictor3+etc., data= specifies the data frame, method="class" is used for a classification tree ("anova" for a regression tree), and control= allows for optional parameters for controlling tree growth.
##fit the tree
tree.carseats=rpart(High~., data=Carseats, method="class", control=rpart.control(minsplit=15, cp=0.01))
For example, control=rpart.control(minsplit=15, cp=0.01) requires that the minimum number of observations in a node be 15 before attempting a split and that a split must decrease the overall lack of fit by a factor of 0.01 (cost complexity factor) before being attempted.
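As a quick illustration (a sketch, not part of the original analysis, assuming the Carseats data frame built above), loosening both controls grows a much deeper tree; comparing the number of terminal nodes makes the effect visible:
## sketch: loosen the controls and compare tree sizes
tree.deep=rpart(High~., data=Carseats, method="class", control=rpart.control(minsplit=2, cp=0.001))
sum(tree.deep$frame$var=="<leaf>")      ## terminal nodes in the looser tree
sum(tree.carseats$frame$var=="<leaf>")  ## terminal nodes in the tree fit above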
The summary() function lists the variables that are used as internal nodes in the tree and the number of terminal nodes.
summary(tree.carseats)
Call:
rpart(formula = High ~ ., data = Carseats, method = "class",
control = rpart.control(minsplit = 15, cp = 0.01))
n= 400
CP nsplit rel error xerror xstd
1 0.28658537 0 1.0000000 1.0000000 0.05997967
2 0.10975610 1 0.7134146 0.7134146 0.05547692
3 0.04573171 2 0.6036585 0.6036585 0.05262923
4 0.03658537 4 0.5121951 0.5731707 0.05170717
5 0.02743902 5 0.4756098 0.5731707 0.05170717
6 0.02439024 8 0.3902439 0.5670732 0.05151537
7 0.01219512 9 0.3658537 0.5548780 0.05112415
8 0.01000000 12 0.3292683 0.5609756 0.05132104
Variable importance
Price ShelveLoc CompPrice Age Advertising Income Population Education
30 27 11 11 10 7 3 1
Node number 1: 400 observations, complexity param=0.2865854
predicted class=No expected loss=0.41 P(node) =1
class counts: 236 164
probabilities: 0.590 0.410
left son=2 (315 obs) right son=3 (85 obs)
Primary splits:
ShelveLoc splits as LRL, improve=28.991900, (0 missing)
Price < 92.5 to the right, improve=19.463880, (0 missing)
Advertising < 6.5 to the left, improve=17.277980, (0 missing)
Age < 61.5 to the right, improve= 9.264442, (0 missing)
Income < 60.5 to the left, improve= 7.249032, (0 missing)
Node number 2: 315 observations, complexity param=0.1097561
predicted class=No expected loss=0.3111111 P(node) =0.7875
class counts: 217 98
probabilities: 0.689 0.311
left son=4 (269 obs) right son=5 (46 obs)
Primary splits:
Price < 92.5 to the right, improve=15.930580, (0 missing)
Advertising < 7.5 to the left, improve=11.432570, (0 missing)
ShelveLoc splits as L-R, improve= 7.543912, (0 missing)
Age < 50.5 to the right, improve= 6.369905, (0 missing)
Income < 60.5 to the left, improve= 5.984509, (0 missing)
Surrogate splits:
CompPrice < 95.5 to the right, agree=0.873, adj=0.13, (0 split)
Node number 3: 85 observations, complexity param=0.03658537
predicted class=Yes expected loss=0.2235294 P(node) =0.2125
class counts: 19 66
probabilities: 0.224 0.776
left son=6 (12 obs) right son=7 (73 obs)
Primary splits:
Price < 142.5 to the right, improve=7.745608, (0 missing)
US splits as LR, improve=5.112440, (0 missing)
Income < 35 to the left, improve=4.529433, (0 missing)
Advertising < 6 to the left, improve=3.739996, (0 missing)
Education < 15.5 to the left, improve=2.565856, (0 missing)
Surrogate splits:
CompPrice < 154.5 to the right, agree=0.882, adj=0.167, (0 split)
Node number 4: 269 observations, complexity param=0.04573171
predicted class=No expected loss=0.2453532 P(node) =0.6725
class counts: 203 66
probabilities: 0.755 0.245
left son=8 (224 obs) right son=9 (45 obs)
Primary splits:
Advertising < 13.5 to the left, improve=10.400090, (0 missing)
Age < 49.5 to the right, improve= 8.083998, (0 missing)
ShelveLoc splits as L-R, improve= 7.023150, (0 missing)
CompPrice < 124.5 to the left, improve= 6.749986, (0 missing)
Price < 126.5 to the right, improve= 5.646063, (0 missing)
Node number 5: 46 observations, complexity param=0.02439024
predicted class=Yes expected loss=0.3043478 P(node) =0.115
class counts: 14 32
probabilities: 0.304 0.696
left son=10 (10 obs) right son=11 (36 obs)
Primary splits:
Income < 57 to the left, improve=4.000483, (0 missing)
ShelveLoc splits as L-R, improve=3.189762, (0 missing)
Advertising < 9.5 to the left, improve=1.388592, (0 missing)
Price < 80.5 to the right, improve=1.388592, (0 missing)
Age < 64.5 to the right, improve=1.172885, (0 missing)
Node number 6: 12 observations
predicted class=No expected loss=0.25 P(node) =0.03
class counts: 9 3
probabilities: 0.750 0.250
Node number 7: 73 observations
predicted class=Yes expected loss=0.1369863 P(node) =0.1825
class counts: 10 63
probabilities: 0.137 0.863
Node number 8: 224 observations, complexity param=0.02743902
predicted class=No expected loss=0.1830357 P(node) =0.56
class counts: 183 41
probabilities: 0.817 0.183
left son=16 (96 obs) right son=17 (128 obs)
Primary splits:
CompPrice < 124.5 to the left, improve=4.881696, (0 missing)
Age < 49.5 to the right, improve=3.960418, (0 missing)
ShelveLoc splits as L-R, improve=3.654633, (0 missing)
Price < 126.5 to the right, improve=3.234428, (0 missing)
Advertising < 6.5 to the left, improve=2.371276, (0 missing)
Surrogate splits:
Price < 115.5 to the left, agree=0.741, adj=0.396, (0 split)
Age < 50.5 to the right, agree=0.634, adj=0.146, (0 split)
Population < 405 to the right, agree=0.629, adj=0.135, (0 split)
Education < 11.5 to the left, agree=0.585, adj=0.031, (0 split)
Income < 22.5 to the left, agree=0.580, adj=0.021, (0 split)
Node number 9: 45 observations, complexity param=0.04573171
predicted class=Yes expected loss=0.4444444 P(node) =0.1125
class counts: 20 25
probabilities: 0.444 0.556
left son=18 (20 obs) right son=19 (25 obs)
Primary splits:
Age < 54.5 to the right, improve=6.722222, (0 missing)
CompPrice < 121.5 to the left, improve=4.629630, (0 missing)
ShelveLoc splits as L-R, improve=3.250794, (0 missing)
Income < 99.5 to the left, improve=3.050794, (0 missing)
Price < 127 to the right, improve=2.933429, (0 missing)
Surrogate splits:
Population < 363.5 to the left, agree=0.667, adj=0.25, (0 split)
Income < 39 to the left, agree=0.644, adj=0.20, (0 split)
Advertising < 17.5 to the left, agree=0.644, adj=0.20, (0 split)
CompPrice < 106.5 to the left, agree=0.622, adj=0.15, (0 split)
Price < 135.5 to the right, agree=0.622, adj=0.15, (0 split)
Node number 10: 10 observations
predicted class=No expected loss=0.3 P(node) =0.025
class counts: 7 3
probabilities: 0.700 0.300
Node number 11: 36 observations
predicted class=Yes expected loss=0.1944444 P(node) =0.09
class counts: 7 29
probabilities: 0.194 0.806
Node number 16: 96 observations
predicted class=No expected loss=0.0625 P(node) =0.24
class counts: 90 6
probabilities: 0.938 0.062
Node number 17: 128 observations, complexity param=0.02743902
predicted class=No expected loss=0.2734375 P(node) =0.32
class counts: 93 35
probabilities: 0.727 0.273
left son=34 (107 obs) right son=35 (21 obs)
Primary splits:
Price < 109.5 to the right, improve=9.764582, (0 missing)
ShelveLoc splits as L-R, improve=6.320022, (0 missing)
Age < 49.5 to the right, improve=2.575061, (0 missing)
Income < 108.5 to the right, improve=1.799546, (0 missing)
CompPrice < 143.5 to the left, improve=1.741982, (0 missing)
Node number 18: 20 observations
predicted class=No expected loss=0.25 P(node) =0.05
class counts: 15 5
probabilities: 0.750 0.250
Node number 19: 25 observations
predicted class=Yes expected loss=0.2 P(node) =0.0625
class counts: 5 20
probabilities: 0.200 0.800
Node number 34: 107 observations, complexity param=0.01219512
predicted class=No expected loss=0.1869159 P(node) =0.2675
class counts: 87 20
probabilities: 0.813 0.187
left son=68 (65 obs) right son=69 (42 obs)
Primary splits:
Price < 126.5 to the right, improve=2.9643900, (0 missing)
CompPrice < 147.5 to the left, improve=2.2337090, (0 missing)
ShelveLoc splits as L-R, improve=2.2125310, (0 missing)
Age < 49.5 to the right, improve=2.1458210, (0 missing)
Income < 60.5 to the left, improve=0.8025853, (0 missing)
Surrogate splits:
CompPrice < 129.5 to the right, agree=0.664, adj=0.143, (0 split)
Advertising < 3.5 to the right, agree=0.664, adj=0.143, (0 split)
Population < 53.5 to the right, agree=0.645, adj=0.095, (0 split)
Age < 77.5 to the left, agree=0.636, adj=0.071, (0 split)
US splits as RL, agree=0.626, adj=0.048, (0 split)
Node number 35: 21 observations, complexity param=0.02743902
predicted class=Yes expected loss=0.2857143 P(node) =0.0525
class counts: 6 15
probabilities: 0.286 0.714
left son=70 (5 obs) right son=71 (16 obs)
Primary splits:
ShelveLoc splits as L-R, improve=6.6964290, (0 missing)
CompPrice < 129.5 to the left, improve=2.4380950, (0 missing)
Income < 50.5 to the right, improve=1.7142860, (0 missing)
Advertising < 9 to the left, improve=1.7142860, (0 missing)
US splits as LR, improve=0.7936508, (0 missing)
Surrogate splits:
Income < 109 to the right, agree=0.857, adj=0.4, (0 split)
CompPrice < 126.5 to the left, agree=0.810, adj=0.2, (0 split)
Population < 395.5 to the right, agree=0.810, adj=0.2, (0 split)
Node number 68: 65 observations
predicted class=No expected loss=0.09230769 P(node) =0.1625
class counts: 59 6
probabilities: 0.908 0.092
Node number 69: 42 observations, complexity param=0.01219512
predicted class=No expected loss=0.3333333 P(node) =0.105
class counts: 28 14
probabilities: 0.667 0.333
left son=138 (22 obs) right son=139 (20 obs)
Primary splits:
Age < 49.5 to the right, improve=5.4303030, (0 missing)
CompPrice < 137.5 to the left, improve=2.1000000, (0 missing)
Advertising < 5.5 to the left, improve=1.8666670, (0 missing)
ShelveLoc splits as L-R, improve=1.4291670, (0 missing)
Population < 382 to the right, improve=0.8578431, (0 missing)
Surrogate splits:
Income < 46.5 to the left, agree=0.595, adj=0.15, (0 split)
Education < 12.5 to the left, agree=0.595, adj=0.15, (0 split)
CompPrice < 131.5 to the right, agree=0.571, adj=0.10, (0 split)
Advertising < 5.5 to the left, agree=0.571, adj=0.10, (0 split)
Population < 221.5 to the left, agree=0.571, adj=0.10, (0 split)
Node number 70: 5 observations
predicted class=No expected loss=0 P(node) =0.0125
class counts: 5 0
probabilities: 1.000 0.000
Node number 71: 16 observations
predicted class=Yes expected loss=0.0625 P(node) =0.04
class counts: 1 15
probabilities: 0.062 0.938
Node number 138: 22 observations
predicted class=No expected loss=0.09090909 P(node) =0.055
class counts: 20 2
probabilities: 0.909 0.091
Node number 139: 20 observations, complexity param=0.01219512
predicted class=Yes expected loss=0.4 P(node) =0.05
class counts: 8 12
probabilities: 0.400 0.600
left son=278 (14 obs) right son=279 (6 obs)
Primary splits:
CompPrice < 137 to the left, improve=2.7428570, (0 missing)
Population < 315 to the right, improve=1.2190480, (0 missing)
Advertising < 5 to the left, improve=0.9333333, (0 missing)
Age < 33.5 to the right, improve=0.9333333, (0 missing)
Urban splits as LR, improve=0.6329670, (0 missing)
Surrogate splits:
Advertising < 9 to the left, agree=0.80, adj=0.333, (0 split)
Age < 26.5 to the right, agree=0.80, adj=0.333, (0 split)
Education < 16.5 to the left, agree=0.75, adj=0.167, (0 split)
Node number 278: 14 observations
predicted class=No expected loss=0.4285714 P(node) =0.035
class counts: 8 6
probabilities: 0.571 0.429
Node number 279: 6 observations
predicted class=Yes expected loss=0 P(node) =0.015
class counts: 0 6
probabilities: 0.000 1.000
One of the most attractive properties of trees is that they can be graphically displayed. We use the fancyRpartPlot function in the rattle library to display the tree structure; fancyRpartPlot plots a fancy RPart decision tree using the pretty rpart plotter. You can read more about fancyRpartPlot here.
library(rattle)
package ‘rattle’ was built under R version 3.6.2
Rattle: A free graphical interface for data science with R.
Version 5.3.0 Copyright (c) 2006-2018 Togaware Pty Ltd.
Type 'rattle()' to shake, rattle, and roll your data.
fancyRpartPlot(tree.carseats)
The most important indicator of Sales appears to be shelving location, since the first branch differentiates Good locations from Bad and Medium locations.
If we just type the name of the tree object, R prints output corresponding to each branch of the tree. R displays the split criterion (e.g. Price>=92.5), the number of observations in that branch, the deviance, the overall prediction for the branch (Yes or No), and the fraction of observations in that branch that take on values of Yes and No. Branches that lead to terminal nodes are indicated using asterisks.
tree.carseats
n= 400
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 400 164 No (0.59000000 0.41000000)
2) ShelveLoc=Bad,Medium 315 98 No (0.68888889 0.31111111)
4) Price>=92.5 269 66 No (0.75464684 0.24535316)
8) Advertising< 13.5 224 41 No (0.81696429 0.18303571)
16) CompPrice< 124.5 96 6 No (0.93750000 0.06250000) *
17) CompPrice>=124.5 128 35 No (0.72656250 0.27343750)
34) Price>=109.5 107 20 No (0.81308411 0.18691589)
68) Price>=126.5 65 6 No (0.90769231 0.09230769) *
69) Price< 126.5 42 14 No (0.66666667 0.33333333)
138) Age>=49.5 22 2 No (0.90909091 0.09090909) *
139) Age< 49.5 20 8 Yes (0.40000000 0.60000000)
278) CompPrice< 137 14 6 No (0.57142857 0.42857143) *
279) CompPrice>=137 6 0 Yes (0.00000000 1.00000000) *
35) Price< 109.5 21 6 Yes (0.28571429 0.71428571)
70) ShelveLoc=Bad 5 0 No (1.00000000 0.00000000) *
71) ShelveLoc=Medium 16 1 Yes (0.06250000 0.93750000) *
9) Advertising>=13.5 45 20 Yes (0.44444444 0.55555556)
18) Age>=54.5 20 5 No (0.75000000 0.25000000) *
19) Age< 54.5 25 5 Yes (0.20000000 0.80000000) *
5) Price< 92.5 46 14 Yes (0.30434783 0.69565217)
10) Income< 57 10 3 No (0.70000000 0.30000000) *
11) Income>=57 36 7 Yes (0.19444444 0.80555556) *
3) ShelveLoc=Good 85 19 Yes (0.22352941 0.77647059)
6) Price>=142.5 12 3 No (0.75000000 0.25000000) *
7) Price< 142.5 73 10 Yes (0.13698630 0.86301370) *
In order to properly evaluate the performance of a classification tree on these data, we must estimate the test error. Thankfully, the rpart package gives us some tools to help. We can't split the observations into a training set and a test set, since we have so few observations, but we can use some built-in functions within rpart to examine the cross-validation error.
The rpart package's plotcp function plots the complexity parameter table for an rpart tree fit on the training dataset. You don't need to supply any additional validation datasets when using the plotcp function.
To validate the model we use the printcp and plotcp functions. CP stands for the complexity parameter of the tree. These functions display the set of optimal prunings based on the cp value.
We prune the tree to avoid overfitting the data. The convention is to choose a small tree, specifically the one with the lowest cross-validated error, which the printcp() function reports as xerror.
printcp(tree.carseats)
Classification tree:
rpart(formula = High ~ ., data = Carseats, method = "class",
control = rpart.control(minsplit = 15, cp = 0.01))
Variables actually used in tree construction:
[1] Advertising Age CompPrice Income Price ShelveLoc
Root node error: 164/400 = 0.41
n= 400
CP nsplit rel error xerror xstd
1 0.286585 0 1.00000 1.00000 0.059980
2 0.109756 1 0.71341 0.71341 0.055477
3 0.045732 2 0.60366 0.60366 0.052629
4 0.036585 4 0.51220 0.57317 0.051707
5 0.027439 5 0.47561 0.57317 0.051707
6 0.024390 8 0.39024 0.56707 0.051515
7 0.012195 9 0.36585 0.55488 0.051124
8 0.010000 12 0.32927 0.56098 0.051321
plotcp(tree.carseats)
plotcp() provides a graphical representation of the cross-validated error summary. The cross-validated error is plotted against the cp values (placed at the geometric means of adjacent cp intervals), so you can see how the error changes until it reaches its minimum.
From the above list of cp values, we can select the one having the least cross-validated error and use it to prune the tree. In this case, I think the tree at cp = 0.0121951 (row 7 of the table, with 9 splits) is best, since its cross-validated error of 0.55488 is the smallest.
To select this programmatically, you can use the following line, which returns the cp value associated with the minimum cross-validated error.
tree.carseats$cptable[which.min(tree.carseats$cptable[,"xerror"]),"CP"]
[1] 0.01219512
Next, we consider whether pruning the tree might lead to improved results. The function prune.rpart determines a nested sequence of subtrees of the supplied rpart object by recursively snipping off the least important splits, based on the complexity parameter (cp).
carseats.prune=prune(tree.carseats,cp=tree.carseats$cptable[which.min(tree.carseats$cptable[,"xerror"]),"CP"])
fancyRpartPlot(carseats.prune, uniform=TRUE, main="Pruned Classification Tree")
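For a rough sense of how the pruned tree fits the data it was grown on (a sketch, not part of the original analysis; this is resubstitution accuracy, not an estimate of test error), we can compare its predictions against High using caret's confusionMatrix() function:
## sketch: confusion matrix of the pruned tree on the full data set (optimistic)
pred.prune=predict(carseats.prune, type="class")
confusionMatrix(pred.prune, Carseats$High)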
### Fitting Regression Trees
Here we fit a regression tree to the Boston data set using caret. Again, we only have 506 observations, so it'll be important to estimate the test error with cross-validation.
library(MASS)
set.seed(1)
# define training control
train_control <- trainControl(method="repeatedcv", number=10, repeats=3)
##fit the model
tree.boston=train(medv~., data=Boston, trControl=train_control, method='rpart')
There were missing values in resampled performance measures.
tree.boston
CART
506 samples
13 predictor
No pre-processing
Resampling: Cross-Validated (10 fold, repeated 3 times)
Summary of sample sizes: 458, 455, 455, 455, 456, 455, ...
Resampling results across tuning parameters:
cp RMSE Rsquared MAE
0.07165784 5.759065 0.6064638 4.141770
0.17117244 6.638677 0.4720946 4.905092
0.45274420 8.278053 0.3253506 6.029187
RMSE was used to select the optimal model using the smallest value.
The final value used for the model was cp = 0.07165784.
tree.boston$finalModel
n= 506
node), split, n, deviance, yval
* denotes terminal node
1) root 506 42716.300 22.53281
2) rm< 6.941 430 17317.320 19.93372
4) lstat>=14.4 175 3373.251 14.95600 *
5) lstat< 14.4 255 6632.217 23.34980 *
3) rm>=6.941 76 6059.419 37.23816 *
Notice that the output indicates that the final value used for the model was cp = 0.07165784. The caret package implements the rpart method with cp as the tuning parameter. By default, caret will prune your tree based on a run it makes over a default parameter grid (even if you don't supply any tuneGrid or trControl while training your model).
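If you want more control over that pruning, you can supply your own grid of cp values through tuneGrid; here is a sketch reusing the train_control object defined above (the cp range is just an illustrative guess):
## sketch: tune over an explicit grid of cp values instead of caret's default
cp.grid=expand.grid(cp=seq(0.001, 0.1, length.out=20))
tree.boston.grid=train(medv~., data=Boston, method='rpart', trControl=train_control, tuneGrid=cp.grid)
tree.boston.grid$bestTune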
Also notice that the output of the finalModel object indicates that only two of the variables have been used in constructing the tree. In the context of a regression tree, the deviance is simply the sum of squared errors for the tree. We now plot the tree. We need to use the rpart.plot function in the rpart.plot library, since the output of caret's train function doesn't work with fancyRpartPlot().
library(rpart.plot)
package ‘rpart.plot’ was built under R version 3.6.2
rpart.plot(tree.boston$finalModel)
The variable lstat measures the percentage of individuals with lower socioeconomic status. The tree indicates that lower values of lstat correspond to more expensive houses. The tree predicts a median house price of $37,200 for larger homes in suburbs in which residents have high socioeconomic status (rm>=6.941 and lstat<14.4).
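As a quick illustration (the suburb below is a hypothetical observation, not part of the original write-up), pushing a made-up large home in a high-status suburb through the fitted model lands it in that high-priced leaf:
## sketch: predict medv for a hypothetical suburb with a large home and low lstat
new_suburb=Boston[1,]    ## use an existing row as a template
new_suburb$rm=7.2        ## rm >= 6.941, so the prediction falls in the 37.2 leaf
new_suburb$lstat=5       ## low lstat (not used on this branch of the tree)
predict(tree.boston, newdata=new_suburb)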
Here we apply bagging and random forests to the Boston data, using caret in R. The exact results obtained in this section may depend on the version of R and the version of the randomForest package installed on your computer. We'll use the caret workflow, which invokes the randomForest() function from the randomForest package, to automatically select the optimal number of predictor variables randomly sampled as candidates at each split (mtry), and to fit the final random forest model that best explains our data.
Here, even though I don’t want to, I’ll split the data into training and test sets, just because I need this to run in a reasonable amount of time. In practice these models take a while to run.
##Split into training and test
inTrain=createDataPartition(Boston$medv,p=0.5,list=FALSE)
train=Boston[inTrain,]
##fit random forest
boston.rf=train(medv~.,data=train,method='rf',trControl = trainControl("cv", number = 10),importance = TRUE)
##best tuning parameter
boston.rf$bestTune
##final model
boston.rf$finalModel
Call:
randomForest(x = x, y = y, mtry = param$mtry, importance = TRUE)
Type of random forest: regression
Number of trees: 500
No. of variables tried at each split: 7
Mean of squared residuals: 10.02197
% Var explained: 86.02
By default, 500 trees are trained. The optimal number of variables sampled at each split is 7.
Recall that bagging is simply a special case of a random forest with m = p. Therefore, this function can be used to perform both random forests and bagging.
Growing a bagged random forest proceeds in exactly the same way, except that we fix the mtry argument at 13 to tell R to use every variable in the data set (see the sketch below). By default, the randomForest() function considers \(p/3\) variables at each split when building a random forest of regression trees, and \(\sqrt{p}\) variables when building a random forest of classification trees.
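A sketch of how bagging could be fit with the same caret workflow (reusing the train split above); fixing mtry at 13 makes every predictor a candidate at each split:
## sketch: bagging = a random forest with mtry equal to the number of predictors (13)
boston.bag=train(medv~., data=train, method='rf', tuneGrid=expand.grid(mtry=13),
                 trControl=trainControl("cv", number=10), importance=TRUE)
boston.bag$finalModel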
Using the importance=TRUE argument allows us to view the importance of each variable.
varImp(boston.rf)
rf variable importance
The measure of variable importance is based upon the mean decrease in prediction accuracy on the out-of-bag samples when a given variable is excluded from the model. Plots of these importance measures can be produced using the plot() function.
plot(varImp(boston.rf))
The results indicate that across all of the trees considered in the random forest, the wealth level of the community (lstat) and the house size (rm) are by far the two most important variables.
Here we use the gbm method in the train function to fit boosted regression trees to the Boston data set. We run gbm with the option distribution="gaussian" since this is a regression problem; if it were a binary classification problem, we would use distribution="bernoulli". When calling gbm() directly, the argument n.trees=5000 indicates that we want 5000 trees and the option interaction.depth=4 limits the depth of each tree; here we let caret tune n.trees and interaction.depth over its default grid.
# Using caret with the default grid to optimize tune parameters automatically
# GBM Tuning parameters:
# n.trees (# Boosting Iterations)
# interaction.depth (Max Tree Depth)
# shrinkage (Shrinkage)
# n.minobsinnode (Min. Terminal Node Size)
metric <- "RMSE"
trainControl <- trainControl(method="cv", number=10)
gbm.boston<-train(medv~.,data=train,distribution='gaussian',method='gbm',trControl=trainControl,verbose=FALSE, metric=metric,bag.fraction=0.75)
print(gbm.boston)
Stochastic Gradient Boosting
254 samples
13 predictor
No pre-processing
Resampling: Cross-Validated (10 fold)
Summary of sample sizes: 227, 229, 230, 229, 227, 229, ...
Resampling results across tuning parameters:
interaction.depth n.trees RMSE Rsquared MAE
1 50 3.590602 0.8439661 2.650186
1 100 3.343940 0.8598463 2.469341
1 150 3.252807 0.8684595 2.382491
2 50 3.252500 0.8661918 2.398781
2 100 3.142982 0.8761107 2.260648
2 150 3.073783 0.8812904 2.198436
3 50 3.181965 0.8759233 2.322446
3 100 3.103228 0.8839721 2.218162
3 150 3.075176 0.8870174 2.166229
Tuning parameter 'shrinkage' was held constant at a value of 0.1
Tuning parameter 'n.minobsinnode' was held constant at a value of 10
RMSE was used to select the optimal model using the smallest value.
The final values used for the model were n.trees = 150, interaction.depth = 2, shrinkage = 0.1 and n.minobsinnode = 10.
The summary() function produces a relative influence plot and also outputs the relative influence statistics.
summary(gbm.boston)
We see that lstat and rm are by far the most important variables. We can also produce partial dependence plots for these two variables. These partial dependence plots illustrate the marginal effect of the selected variables on the response after integrating out the other variables. In this case, as we might expect, median house prices are increasing with rm and decreasing with lstat.
par(mfrow=c(1,2))
plot(gbm.boston$finalModel ,i="rm")
plot(gbm.boston$finalModel ,i="lstat")
We now use the boosted model to predict medv on the test set:
yhat.boost=predict(gbm.boston,newdata=Boston[-inTrain,],n.trees=5000)
mean((yhat.boost - Boston[-inTrain,]$medv)^2)
[1] 23.40865
The test MSE obtained is 23.40865, similar to the test MSE for random forests and superior to that for bagging. If we want to, we can perform boosting with a different value of the shrinkage parameter \(\lambda\). The value used above (held constant by caret's default grid) is 0.1, but this is easily modified. Here we take \(\lambda\) = 0.2.
boost.boston=train(medv~., data=train, method='gbm', distribution="gaussian", verbose=F,
                   tuneGrid=expand.grid(n.trees=5000, interaction.depth=4, shrinkage=0.2, n.minobsinnode=10))
yhat.boost=predict(boost.boston,newdata=Boston[-inTrain,],n.trees=5000)
mean((yhat.boost - Boston[-inTrain,]$medv)^2)
[1] 23.74338
In this case, using \(\lambda\) = 0.2 leads to a slightly higher test MSE than \(\lambda\) = 0.1.
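As a final check on the comparison with the random forest mentioned above (a sketch; the exact number will vary with the random partition and seed), its test MSE can be computed on the same held-out observations:
## sketch: test MSE for the random forest on the held-out half of Boston
yhat.rf=predict(boston.rf, newdata=Boston[-inTrain,])
mean((yhat.rf - Boston[-inTrain,]$medv)^2)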