The tree package is used to construct classification and regression trees. Since Sales is a continuous variable, we first recode it as a binary variable. We use the ifelse() function to create a variable called High, which takes the value Yes if Sales exceeds 8 and No otherwise.
library(tree)
library(ISLR)
attach(Carseats)
head(Carseats)
## Sales CompPrice Income Advertising Population Price ShelveLoc Age Education
## 1 9.50 138 73 11 276 120 Bad 42 17
## 2 11.22 111 48 16 260 83 Good 65 10
## 3 10.06 113 35 10 269 80 Medium 59 12
## 4 7.40 117 100 4 466 97 Medium 55 14
## 5 4.15 141 64 3 340 128 Bad 38 13
## 6 10.81 124 113 13 501 72 Bad 78 16
## Urban US
## 1 Yes Yes
## 2 Yes Yes
## 3 Yes Yes
## 4 Yes Yes
## 5 Yes No
## 6 No Yes
High <- ifelse(Sales <= 8, "No", "Yes")
head(High)
## [1] "Yes" "Yes" "Yes" "No" "No" "Yes"
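ifelse() returns a character vector. tree() only fits a classification tree when the response is a factor; a numeric response gives a regression tree. Since R 4.0, data.frame() no longer converts strings to factors by default, so it is safest to wrap the recoded variable in factor(). The minimal base-R sketch below uses a few hypothetical sales values, not the Carseats data, to show the difference:

```r
# Hypothetical Sales values (thousands of units), not taken from Carseats
sales <- c(9.50, 11.22, 7.40, 4.15, 8.00)

# ifelse() yields a character vector; a value of exactly 8 maps to "No"
high_chr <- ifelse(sales <= 8, "No", "Yes")

# Converting to a factor gives a categorical response with two levels
high <- factor(high_chr, levels = c("No", "Yes"))

print(class(high_chr))  # "character"
print(levels(high))     # "No" "Yes"
print(table(high))
```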
To integrate this new variable, we use the data.frame() function to merge High with the rest of the Carseats data.
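# tree() fits a classification tree only for a factor response; since R 4.0,
# data.frame() keeps character columns as character, so convert explicitly
Carseats <- data.frame(Carseats, High = factor(High))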
head(Carseats)
## Sales CompPrice Income Advertising Population Price ShelveLoc Age Education
## 1 9.50 138 73 11 276 120 Bad 42 17
## 2 11.22 111 48 16 260 83 Good 65 10
## 3 10.06 113 35 10 269 80 Medium 59 12
## 4 7.40 117 100 4 466 97 Medium 55 14
## 5 4.15 141 64 3 340 128 Bad 38 13
## 6 10.81 124 113 13 501 72 Bad 78 16
## Urban US High
## 1 Yes Yes Yes
## 2 Yes Yes Yes
## 3 Yes Yes Yes
## 4 Yes Yes No
## 5 Yes No No
## 6 No Yes Yes
We now use the tree() function to fit a classification tree that predicts High from all variables except Sales. The syntax of tree() is quite similar to that of lm().
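The formula High ~ . - Sales follows the usual R model-formula rules. The dot stands for the remaining columns of the data frame, and - Sales removes Sales from the predictors. This matters because High was built from Sales, so keeping Sales would let the tree read off the answer directly. The expansion can be checked with base R's terms() on a small hypothetical data frame, a stand-in for Carseats:

```r
# Small hypothetical data frame with the same kind of columns as Carseats
df <- data.frame(
  Sales = c(9.5, 4.2, 11.1),
  Price = c(120, 128, 83),
  Age   = c(42, 38, 65),
  High  = factor(c("Yes", "No", "Yes"))
)

# Expand "." against df (the response is excluded), then drop Sales
tt <- terms(High ~ . - Sales, data = df)
predictors <- attr(tt, "term.labels")
print(predictors)  # "Price" "Age"
```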
tree.carseats <- tree(High~.-Sales, Carseats)
summary(tree.carseats)
##
## Classification tree:
## tree(formula = High ~ . - Sales, data = Carseats)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Income" "CompPrice" "Population"
## [6] "Advertising" "Age" "US"
## Number of terminal nodes: 27
## Residual mean deviance: 0.4575 = 170.7 / 373
## Misclassification error rate: 0.09 = 36 / 400
The summary() function lists the variables that are used as internal nodes in the tree, the number of terminal nodes and the training error rate. In this case the training error rate is 9% (36 of 400 observations misclassified). The residual mean deviance is the deviance divided by n - |T0|, where |T0| is the number of terminal nodes of the tree T0; here that is 400 - 27 = 373.
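For a classification tree, the deviance is -2 Σ_m Σ_k n_mk log p̂_mk. The sum runs over terminal nodes m and classes k. Here n_mk is the number of training observations of class k in node m, and p̂_mk is the corresponding proportion. A small deviance indicates a tree that fits the training data well.

The sketch below computes the deviance, the residual mean deviance and the training misclassification rate for three terminal nodes. The node counts and the helper name node_deviance() are invented for illustration:

```r
# Deviance contribution of one terminal node, given its class counts:
#   -2 * sum_k n_k * log(n_k / n_node); classes with a zero count contribute 0
node_deviance <- function(counts) {
  p <- counts / sum(counts)
  nz <- counts > 0
  -2 * sum(counts[nz] * log(p[nz]))
}

# Hypothetical terminal nodes: one row per node, columns are counts of No / Yes
leaves <- rbind(c(40, 0), c(5, 15), c(10, 30))
colnames(leaves) <- c("No", "Yes")

n <- sum(leaves)                                # 100 training observations
dev <- sum(apply(leaves, 1, node_deviance))     # total deviance
resid_mean_dev <- dev / (n - nrow(leaves))      # divide by n - |T0|
# Each node predicts its majority class; all other observations are errors
errors <- n - sum(apply(leaves, 1, max))

print(round(c(deviance = dev, resid_mean_dev = resid_mean_dev), 4))
print(errors / n)  # training misclassification rate
```

Applied to the Carseats summary, the same arithmetic gives 170.7 / (400 - 27) = 170.7 / 373. This matches the reported residual mean deviance up to rounding.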
We use the plot() function to display the tree structure and the text() function to display the node labels. The argument pretty = 0 instructs R to include the category names for any qualitative predictors, rather than simply displaying a letter for each category.
par(bg = 'black')
plot(tree.carseats, type = 'uniform', col = 'aquamarine')
text(tree.carseats, pretty = 0, cex = .42, col = 'white')
The most important indicator of Sales appears to be shelving location, since the first branch separates Good locations (to the right) from Bad and Medium locations.
If we just type the name of the tree object (tree.carseats), R prints output corresponding to each branch of the tree. For each branch, R displays the split criterion, the number of observations, the deviance, the overall prediction (Yes or No) and the fraction of observations that take on the values Yes and No. Branches that lead to terminal nodes are marked with asterisks.
To evaluate the tree properly we need the test error rather than the training error. We therefore split the observations into a training set and a test set, build the tree on the training set and evaluate its performance on the test data with the predict() function. The argument type = "class" instructs R to return the actual class prediction. This approach correctly classifies 77% of the test observations.
set.seed(2)
train <- sample(1:nrow(Carseats), 200)
Carseats.test <- Carseats[-train,]
High.test <- High[-train]
tree.carseats <- tree(High~.-Sales, Carseats, subset = train)
tree.pred <- predict(tree.carseats, Carseats.test, type = "class")
table(tree.pred, High.test)
## High.test
## tree.pred No Yes
## No 104 33
## Yes 13 50
(104+50)/200
## [1] 0.77
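Rather than adding the diagonal counts by hand, we can compute accuracy from any confusion matrix as the sum of its diagonal divided by the total. The sketch hard-codes the counts printed above so that it runs on its own. The helper name accuracy() is hypothetical and not part of the tree package. In the live session, mean(tree.pred == High.test) should give the same value.

```r
# Confusion matrix from the unpruned tree above (rows: predicted, columns: actual)
cm <- matrix(c(104, 13, 33, 50), nrow = 2,
             dimnames = list(tree.pred = c("No", "Yes"),
                             High.test = c("No", "Yes")))

# Correct predictions lie on the diagonal
accuracy <- function(cm) sum(diag(cm)) / sum(cm)

acc <- accuracy(cm)
print(acc)      # 0.77
print(1 - acc)  # 0.23, the test error rate
```

The same helper applied to the pruned tree's table further below (97, 20, 25, 58) gives 0.775.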
The cv.tree() function performs cross-validation to determine the optimal level of tree complexity; cost complexity pruning is used to select a sequence of trees for consideration. We use the argument FUN = prune.misclass to indicate that the classification error rate, rather than the default deviance, should guide the cross-validation and pruning process.
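Cost complexity pruning scores each candidate subtree T by its training error plus a penalty α|T|, where |T| is its number of terminal nodes. For a given α it keeps the subtree with the lowest score, so larger values of α favour smaller trees. The sketch below applies that criterion to a hypothetical nested sequence of subtrees; the sizes, error counts and the helper best_size() are invented. It only shows how α trades fit against size. cv.tree() then uses cross-validation to decide which value of α, and hence which size, to keep.

```r
# Hypothetical nested subtrees: number of terminal nodes and training errors
size   <- c(8, 5, 3, 1)
errors <- c(10, 14, 20, 40)

# For a given alpha, pick the subtree minimising errors + alpha * size
best_size <- function(alpha) {
  cost <- errors + alpha * size
  size[which.min(cost)]
}

sapply(c(0, 2, 15), best_size)  # 8 5 1
```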
cv.carseats <- cv.tree(tree.carseats, FUN = prune.misclass)
names(cv.carseats)
## [1] "size" "dev" "k" "method"
cv.carseats
## $size
## [1] 21 19 14 9 8 5 3 2 1
##
## $dev
## [1] 71 70 66 66 70 74 76 87 88
##
## $k
## [1] -Inf 0.0 1.0 1.4 2.0 3.0 4.0 9.0 18.0
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
The cv.tree() function reports the number of terminal nodes of each tree considered (size), the corresponding cross-validation error (dev) and the value of the cost-complexity parameter used (k, which corresponds to the tuning parameter α).
Despite its name, dev here is the number of cross-validation errors. The trees with 14 and 9 terminal nodes tie for the lowest cross-validation error, with 66 errors each, so we choose the simpler nine-node tree.
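When several sizes tie for the minimum cross-validation error, a common convention is to take the smallest of them. cv.tree() does not make this choice for you. The sketch hard-codes the values printed above; in the session the same vectors are cv.carseats$size and cv.carseats$dev.

```r
# cv.tree() output reported above for the Carseats classification tree
size <- c(21, 19, 14, 9, 8, 5, 3, 2, 1)
dev  <- c(71, 70, 66, 66, 70, 74, 76, 87, 88)  # CV misclassification counts

# Among sizes achieving the minimum CV error, prefer the smallest tree
best <- min(size[dev == min(dev)])
print(best)  # 9
```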
We plot the number of cross-validation errors as a function of tree size.
par(bg = 'black')
plot(cv.carseats$size, cv.carseats$dev, type = 'b', col = 'white', axes = FALSE, xlab = "", ylab = "")
box(col="aquamarine")
axis(1, col = "aquamarine", col.ticks = "aquamarine", col.axis = "white", cex.axis = 1)
axis(2, col = "aquamarine", col.ticks = "aquamarine", col.axis = "white", cex.axis = 1)
mtext("Nodes", side = 1, line = 3, col = "aquamarine", cex = 1)
mtext("CV errors", side = 2, line = 3, col = "aquamarine", cex = 1)
We now apply the prune.misclass() function to prune the tree down to nine terminal nodes.
par(bg = 'black')
prune.carseats <- prune.misclass(tree.carseats, best = 9)
plot(prune.carseats, type = 'uniform', col = 'aquamarine')
text(prune.carseats, pretty = 0, cex = .75, col = 'white')
tree.pred <- predict(prune.carseats, Carseats.test, type = "class")
table(tree.pred, High.test)
## High.test
## tree.pred No Yes
## No 97 25
## Yes 20 58
(97+58)/200
## [1] 0.775
Applying predict() to the pruned tree, 77.5% of the test observations are now classified correctly. Pruning therefore produced a more interpretable tree and slightly improved the classification accuracy. If we change the value of best, however, we obtain a pruned tree with lower classification accuracy, for instance with six terminal nodes:
prune.carseats <- prune.misclass(tree.carseats, best = 6)
par(bg = 'black')
plot(prune.carseats, type = 'uniform', col = 'aquamarine')
text(prune.carseats, pretty = 0, col = 'white')
Next, we fit a regression tree to the Boston data set, using half of the observations as a training set.
library(MASS)
set.seed(1)
train <- sample(1:nrow(Boston), nrow(Boston)/2)
tree.boston <- tree(medv~., Boston, subset=train)
summary(tree.boston)
##
## Regression tree:
## tree(formula = medv ~ ., data = Boston, subset = train)
## Variables actually used in tree construction:
## [1] "rm" "lstat" "crim" "age"
## Number of terminal nodes: 7
## Residual mean deviance: 10.38 = 2555 / 246
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -10.1800 -1.7770 -0.1775 0.0000 1.9230 16.5800
The output of summary() indicates that only four of the variables have been used in constructing the tree. In the context of a regression tree, the deviance is simply the sum of squared errors for the tree.
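A regression tree predicts, for each terminal node, the mean response of the training observations in that node. Its deviance is the sum of squared residuals around those node means, and the residual mean deviance divides it by n - |T0|. The sketch uses invented responses and node assignments to show the arithmetic:

```r
# Hypothetical responses and the terminal node each observation falls into
y    <- c(10, 12, 14, 30, 34, 20, 22, 24)
node <- c(1, 1, 1, 2, 2, 3, 3, 3)

# Fitted value for each observation = mean response of its node (12, 32, 22)
fitted <- ave(y, node)
sse <- sum((y - fitted)^2)                                   # deviance
resid_mean_dev <- sse / (length(y) - length(unique(node)))   # n - |T0|
print(c(sse = sse, resid_mean_dev = resid_mean_dev))         # 24 and 4.8
```

For the Boston tree, the training set has 253 observations (half of 506) and |T0| = 7. The denominator is therefore 253 - 7 = 246, as in the summary.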
par(bg = 'black')
plot(tree.boston, type = 'uniform', col = 'aquamarine')
text(tree.boston, pretty = 0, col = 'white')
The tree predicts a median house price of approximately $45,400 for larger homes (rm > 7.553); note that medv is recorded in thousands of dollars.
Now we use the cv.tree() function to see whether pruning the tree will improve performance.
cv.boston <- cv.tree(tree.boston)
par(bg = 'black')
plot(cv.boston$size, cv.boston$dev, type = 'b', col = 'white', axes = FALSE, xlab = "", ylab = "")
box(col="aquamarine")
axis(1, col = "aquamarine", col.ticks = "aquamarine", col.axis = "white", cex.axis = 1)
axis(2, col = "aquamarine", col.ticks = "aquamarine", col.axis = "white", cex.axis = 1)
mtext("Nodes", side = 1, line = 3, col = "aquamarine", cex = 1)
mtext("Deviance", side = 2, line = 3, col = "aquamarine", cex = 1)
In this case, the most complex tree is selected by cross-validation. However, if we wish to prune the tree, we could do so as follows, using the prune.tree() function.
prune.boston <- prune.tree(tree.boston, best = 5)
par(bg = 'black')
plot(prune.boston, type = 'uniform', col = 'aquamarine')
text(prune.boston, pretty = 0, col = 'white')
In keeping with the cross-validation results, we use the unpruned tree to make predictions on the test set.
yhat <- predict(tree.boston, newdata = Boston[-train,])
boston.test <- Boston[-train, "medv"]
par(bg = 'black')
plot(yhat, boston.test, col = 'white', axes = FALSE, xlab = "", ylab = "")
box(col = "aquamarine")
axis(1, col = "aquamarine", col.ticks = "aquamarine", col.axis = "white", cex.axis = 1)
axis(2, col = "aquamarine", col.ticks = "aquamarine", col.axis = "white", cex.axis = 1)
mtext("Predicted medv", side = 1, line = 3, col = "aquamarine", cex = 1)
mtext("Observed medv (test set)", side = 2, line = 3, col = "aquamarine", cex = 1)
abline(0, 1, col = 'white')
mean((yhat - boston.test)^2)
The test set MSE associated with the regression tree is 35.29. The square root of the MSE is therefore around 5.94. Since medv is in thousands of dollars, the model's test predictions are typically within about $5,940 of the true median home value for the suburb.
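Because medv is recorded in units of $1,000, the square root of the test MSE translates directly into dollars. This is a minimal sketch; the helper names mse() and rmse() and the prediction values are hypothetical. In the session, mse(yhat, boston.test) would reproduce mean((yhat - boston.test)^2).

```r
# Test-set MSE and RMSE for a vector of predictions
mse  <- function(pred, obs) mean((pred - obs)^2)
rmse <- function(pred, obs) sqrt(mse(pred, obs))

# Hypothetical predictions and observed medv values (in $1000s)
pred <- c(20, 25, 30)
obs  <- c(22, 24, 33)
print(mse(pred, obs))   # 4.666667
print(rmse(pred, obs))  # about 2.16

# Converting the reported Boston test MSE of 35.29 into dollars
print(sqrt(35.29) * 1000)  # roughly 5940
```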