8.3.1 Fitting Classification Trees
The tree library is used to construct classification and regression trees.
library(tree)
We first use classification trees to analyze the Carseats data set. In these data, Sales is a continuous variable, and so we begin by encoding it as a binary variable.
We use the ifelse() function to create a variable, called High, which takes on a value of Yes if Sales exceeds 8, and takes on a value of No otherwise.
The general form is ifelse(condition, value if TRUE, value if FALSE).
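For readers unfamiliar with ifelse(), a minimal standalone illustration with toy values (not part of the Carseats analysis):
# ifelse() is vectorized: element-wise, it returns the second argument
# where the condition is TRUE and the third argument where it is FALSE.
ifelse(c(1, 5, 10) > 3, "big", "small")
## [1] "small" "big"   "big"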
library(ISLR)
attach(Carseats)
High = ifelse(Sales<=8,"No","Yes")
Carseats = data.frame(Carseats,High) # merge High with the rest of the data
We now use the tree() function to fit a classification tree in order to predict High using all variables but Sales.
Carseats$High = as.factor(Carseats$High)
tree.carseats = tree(High~.-Sales,Carseats)
summary(tree.carseats)
##
## Classification tree:
## tree(formula = High ~ . - Sales, data = Carseats)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Income" "CompPrice" "Population"
## [6] "Advertising" "Age" "US"
## Number of terminal nodes: 27
## Residual mean deviance: 0.4575 = 170.7 / 373
## Misclassification error rate: 0.09 = 36 / 400
We see that the training error rate is 9%. For classification trees, the deviance is given by \[-2\sum_{m}\sum_{k}n_{mk}\log\hat{p}_{mk},\] where \(n_{mk}\) is the number of observations in the mth terminal node that belong to the kth class, and \(\hat{p}_{mk}\) is the proportion of training observations in that node from the kth class.
A small deviance indicates a tree that provides a good fit to the (training) data. The residual mean deviance reported is simply the deviance divided by \(n - |T_{0}|\), which in this case is 400 - 27 = 373.
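As a quick sanity check, the reported residual mean deviance can be recomputed from the fitted tree; a minimal sketch that relies on the tree object's internal frame (leaf rows carry their node deviance):
# Sum the deviance over the terminal nodes and divide by n - |T0|.
leaf = tree.carseats$frame$var == "<leaf>"
dev.total = sum(tree.carseats$frame$dev[leaf])   # about 170.7
dev.total/(nrow(Carseats) - sum(leaf))           # about 0.4575, matching summary()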
One of the most attractive properties of trees is that they can be graphically displayed.
We use:
- The plot() function to display the tree structure.
- The text() function to display the node labels.
- The argument pretty=0, which instructs R to include the category names for any qualitative predictors, rather than simply displaying a letter for each category.
plot(tree.carseats)
text(tree.carseats,pretty=0)
The most important indicator of Sales appears to be shelving location, since the first branch differentiates Good locations from Bad and Medium locations.
In order to properly evaluate the performance of a classification tree on these data, we must estimate the test error rather than simply computing the training error.
We split the observations into a training set and a test set, build the tree using the training set, and evaluate its performance on the test data. The predict() function can be used for this purpose.
set.seed(1)
train = sample(1:nrow(Carseats),200)
Carseats.test = Carseats[-train,]
High.test = High[-train]
tree.carseats = tree(High~.-Sales,Carseats,subset=train)
tree.pred = predict(tree.carseats,Carseats.test,type='class')
table(tree.pred,High.test)
##          High.test
## tree.pred No Yes
##       No  84  37
##       Yes 35  44
(84+44)/200
## [1] 0.64
The argument type='class' instructs R to return the actual class prediction. This approach leads to correct predictions for 64% of the locations in the test data set.
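The same accuracy can be computed without reading the counts off the table, by comparing the predicted and true labels directly; a one-line sketch:
# Proportion of test observations classified correctly (0.64 here).
mean(tree.pred == High.test)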
Next, we consider whether pruning the tree might lead to improved results. The function cv.tree() performs cross-validation in order to determine the optimal level of tree complexity; cost complexity pruning is used in order to select a sequence of trees for consideration.
We use FUN=prune.misclass in order to indicate that we want the classification error rate to guide the cross-validation and pruning process, rather than the default for the cv.tree() function, which is deviance.
set.seed(1)
cv.carseats = cv.tree(tree.carseats,FUN=prune.misclass)
names(cv.carseats)
## [1] "size"   "dev"    "k"      "method"
cv.carseats
## $size
## [1] 20 18 10 8 6 4 2 1
##
## $dev
## [1] 66 66 59 56 53 58 75 85
##
## $k
## [1] -Inf 0.0 0.5 1.5 2.0 4.0 12.0 19.0
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
Note that, despite the name, dev corresponds to the cross-validation error rate in this instance. The tree with 6 terminal nodes results in the lowest cross-validation error rate, with 53 cross-validation errors. Below we extract this size programmatically and then plot the error rate as a function of both size and k.
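A short sketch for extracting that size without inspecting the printout by eye:
# which.min() returns the first minimum; because $size is stored in
# decreasing order, ties would resolve to the largest such size.
best.size = cv.carseats$size[which.min(cv.carseats$dev)]
best.size   # 6 for this run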
par(mfrow=c(1,2))
plot(cv.carseats$size,cv.carseats$dev,type='b')
plot(cv.carseats$k,cv.carseats$dev,type='b')
We now apply the prune.misclass() function in order to prune the tree and obtain the 6-node tree.
prune.carseats = prune.misclass(tree.carseats,best=6)
plot(prune.carseats)
text(prune.carseats,pretty=0)
Once again, we apply the predict() function.
tree.pred = predict(prune.carseats,Carseats.test,type='class')
table(tree.pred,High.test)
##          High.test
## tree.pred No Yes
##       No  86  32
##       Yes 33  49
(86+49)/200
## [1] 0.675
Now 67.5% of the test observations are correctly classified, so not only has the pruning process produced a more interpretable tree, but it has also improved the classification accuracy.
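The effect of the pruning size on test accuracy can also be checked across several values from the cost-complexity sequence; a small sketch using sizes that appear in cv.carseats$size above:
# Prune to each size, predict on the test set, and print the accuracy.
for (b in c(2, 4, 6, 8)) {
  pr = prune.misclass(tree.carseats, best = b)
  pred = predict(pr, Carseats.test, type = "class")
  cat("size =", b, "accuracy =", mean(pred == High.test), "\n")
}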
If we increase the value of best, we obtain a larger pruned tree with lower classification accuracy:
prune.carseats = prune.misclass(tree.carseats,best=15)
plot(prune.carseats)
text(prune.carseats,pretty=0)
tree.pred = predict(prune.carseats,Carseats.test,type='class')
table(tree.pred,High.test)
##          High.test
## tree.pred No Yes
##       No  84  37
##       Yes 35  44
(84+44)/200
## [1] 0.64
8.3.2 Fitting Regression Trees
Here we fit a regression tree to the Boston data set. First, we create a training set, and fit the tree to the training data.
library(MASS)
set.seed(1)
train = sample(1:nrow(Boston),nrow(Boston)/2)
tree.boston = tree(medv~.,Boston,subset=train)
summary(tree.boston)
##
## Regression tree:
## tree(formula = medv ~ ., data = Boston, subset = train)
## Variables actually used in tree construction:
## [1] "rm" "lstat" "crim" "age"
## Number of terminal nodes: 7
## Residual mean deviance: 10.38 = 2555 / 246
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -10.1800 -1.7770 -0.1775 0.0000 1.9230 16.5800
In the context of a regression tree, the deviance is simply the sum of squared errors for the tree.
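As with the classification tree, the reported residual mean deviance can be reproduced from the fitted object; a small sketch relying on the tree's internal frame:
# Sum of squared errors over the terminal nodes, divided by n - |T|
# (253 - 7 = 246 here), reproduces the residual mean deviance above.
leaf = tree.boston$frame$var == "<leaf>"
sse = sum(tree.boston$frame$dev[leaf])   # about 2555
sse/(length(train) - sum(leaf))          # about 10.38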
We now plot the tree.
plot(tree.boston)
text(tree.boston,pretty=0)
The variable rm measures the average number of rooms per dwelling. The tree indicates that larger values of rm correspond to more expensive houses. The tree predicts a median house price of $45,380 (medv = 45.38, recorded in thousands of dollars) for larger homes (rm > 7.553).
Now we use the cv.tree() function to see whether pruning the tree will improve performance.
cv.boston = cv.tree(tree.boston)
plot(cv.boston$size,cv.boston$dev,type='b')
In this case, the most complex tree is selected by cross-validation. However, if we wish to prune the tree, we could do so as follows, using the prune.tree() function:
prune.boston = prune.tree(tree.boston,best=5)
plot(prune.boston)
text(prune.boston,pretty=0)
In keeping with the cross-validation results, we use the unpruned tree to make predictions on the test set.
yhat = predict(tree.boston,newdata=Boston[-train,])
boston.test = Boston[-train,'medv']
plot(yhat,boston.test)
abline(0,1)
mean((yhat-boston.test)^2)
## [1] 35.28688
In other words, the test MSE associated with the regression tree is 35.28688. The square root of the MSE is therefore around 5.94, indicating that this model leads to test predictions that are, on average, within around $5,940 of the true median home value for the suburb.
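The dollar figure quoted above is just the root of the test MSE; computed directly:
# medv is recorded in thousands of dollars, so this is roughly $5,940.
sqrt(mean((yhat - boston.test)^2))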
8.3.3 Bagging and Random Forests
Here we apply bagging and random forests to the Boston data, using the randomForest package in R. Recall that bagging is simply a special case of a random forest with m = p. Therefore, the randomForest() function can be used to perform both random forests and bagging.
Here we perform bagging as follows:
library(randomForest)
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
library(MASS)
set.seed(1)
train = sample(1:nrow(Boston),nrow(Boston)/2)
boston.test = Boston[-train,'medv']
bag.boston = randomForest(medv~.,data=Boston,subset=train,mtry=13,importance=TRUE)
bag.boston
##
## Call:
## randomForest(formula = medv ~ ., data = Boston, mtry = 13, importance = TRUE, subset = train)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 13
##
## Mean of squared residuals: 11.33119
## % Var explained: 85.26
mtry=13 indicates that all 13 predictors should be considered for each split of the tree—in other words, that bagging should be done.
How well does this bagged model perform on the test set?
yhat.bag = predict(bag.boston,newdata=Boston[-train,])
plot(yhat.bag,boston.test)
abline(0,1)
mean((yhat.bag-boston.test)^2)
## [1] 23.4579
The test set MSE associated with the bagged regression tree is 23.4579, smaller than that obtained using an optimally-pruned single tree.
We could change the number of trees grown by randomForest() using the ntree argument:
bag.boston = randomForest(medv~.,data=Boston,subset=train,mtry=13,ntree=25)
yhat.bag = predict(bag.boston,newdata=Boston[-train,])
mean((yhat.bag-boston.test)^2)
## [1] 22.99145
Growing a random forest proceeds in exactly the same way, except that we use a smaller value of the mtry argument. By default, randomForest() uses p/3 variables when building a random forest of regression trees, and \(\sqrt{p}\) variables when building a random forest of classification trees.
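For the Boston data these defaults work out as follows (a quick check; randomForest() floors both quantities):
p = ncol(Boston) - 1   # 13 predictors once the response medv is excluded
floor(p/3)             # 4, the default mtry for a regression forest
floor(sqrt(p))         # 3, the default mtry for a classification forest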
Here we use mtry=6:
set.seed(1)
rf.boston = randomForest(medv~.,data=Boston,subset=train,mtry=6,importance=TRUE)
yhat.rf = predict(rf.boston,newdata=Boston[-train,])
mean((yhat.rf-boston.test)^2)
## [1] 19.62021
The test set MSE is 19.62021; this indicates that random forests yielded an improvement over bagging in this case.
Using the importance() function, we can view the importance of each variable.
importance(rf.boston)
##           %IncMSE IncNodePurity
## crim 16.697017 1076.08786
## zn 3.625784 88.35342
## indus 4.968621 609.53356
## chas 1.061432 52.21793
## nox 13.518179 709.87339
## rm 32.343305 7857.65451
## age 13.272498 612.21424
## dis 9.032477 714.94674
## rad 2.878434 95.80598
## tax 9.118801 364.92479
## ptratio 8.467062 823.93341
## black 7.579482 275.62272
## lstat 27.129817 6027.63740
Two measures of variable importance are reported:
- The former is based upon the mean decrease of accuracy in predictions on the out-of-bag samples when a given variable is excluded from the model.
- The latter is a measure of the total decrease in node impurity that results from splits over that variable, averaged over all trees.
In the case of regression trees, the node impurity is measured by the training RSS, and for classification trees by the deviance. Plots of these importance measures can be produced using the varImpPlot() function.
varImpPlot(rf.boston)
The results indicate that, across all of the trees considered in the random forest, the wealth level of the community (lstat) and the house size (rm) are by far the two most important variables.
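The same ranking can be read off numerically by sorting the importance matrix; a minimal sketch:
# Order the variables by the permutation-based measure (%IncMSE).
imp = importance(rf.boston)
imp[order(imp[, "%IncMSE"], decreasing = TRUE), ]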
8.3.4 Boosting
Here we use the gbm package, and within it the gbm() function, to fit boosted regression trees to the Boston data set.
We run gbm() with the following options:
- distribution='gaussian', since this is a regression problem (distribution='bernoulli' would be used for a binary classification problem).
- n.trees=5000 indicates that we want 5000 trees.
- interaction.depth=4 limits the depth of each tree.
library(gbm)
## Loaded gbm 2.1.8
set.seed(1)
boost.boston = gbm(medv~.,data=Boston[train,],distribution='gaussian',n.trees=5000,
interaction.depth=4)
summary(boost.boston)
##             var    rel.inf
## rm rm 43.9919329
## lstat lstat 33.1216941
## crim crim 4.2604167
## dis dis 4.0111090
## nox nox 3.4353017
## black black 2.8267554
## age age 2.6113938
## ptratio ptratio 2.5403035
## tax tax 1.4565654
## indus indus 0.8008740
## rad rad 0.6546400
## zn zn 0.1446149
## chas chas 0.1443986
We see that lstat and rm are by far the most important variables. We can also produce partial dependence plots for these two variables. These plots illustrate the marginal effect of the selected variables on the response after integrating out the other variables.
In this case, as we might expect, median house prices are increasing with rm and decreasing with lstat.
par(mfrow=c(1,2))
plot(boost.boston,i='rm')
plot(boost.boston,i='lstat')
We now use the boosted model to predict medv on the test set:
yhat.boost = predict(boost.boston,newdata=Boston[-train,],
n.trees=5000)
mean((yhat.boost-boston.test)^2)
## [1] 18.84709
The test MSE obtained is 18.84709, similar to the test MSE for random forests and superior to that for bagging.
If we want to, we can perform boosting with a different value of the shrinkage parameter \(\lambda\). The default value is 0.001, but this is easily modified. Here we take \(\lambda\)=0.2:
boost.boston = gbm(medv~.,data=Boston[train,],distribution='gaussian',n.trees=5000,
interaction.depth=4,shrinkage=0.2,verbose=F)
yhat.boost = predict(boost.boston,newdata=Boston[-train,],n.trees=5000)
mean((yhat.boost-boston.test)^2)
## [1] 18.33455
In this case, using \(\lambda\)=0.2 leads to a slightly lower test MSE than \(\lambda\)=0.001.
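If desired, several shrinkage values can be compared in the same way; a sketch over an illustrative grid (each fit keeps n.trees=5000 and interaction.depth=4; output not shown):
# Fit one boosted model per shrinkage value and report its test MSE.
for (lam in c(0.001, 0.01, 0.1, 0.2)) {
  fit = gbm(medv~., data = Boston[train,], distribution = 'gaussian',
            n.trees = 5000, interaction.depth = 4, shrinkage = lam, verbose = F)
  pred = predict(fit, newdata = Boston[-train,], n.trees = 5000)
  cat("shrinkage =", lam, "test MSE =", round(mean((pred - boston.test)^2), 2), "\n")
}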