Illustration of decision trees using the tree package

This is an example exercise taken from the book An Introduction to Statistical Learning with Applications in R by Gareth James, Daniela Witten, Trevor Hastie, and Robert Tibshirani.

Fitting a classification tree

library(tree)
## Warning: package 'tree' was built under R version 4.0.4
library(ISLR)

Have a look at the data.

data("Carseats")
head(Carseats)
##   Sales CompPrice Income Advertising Population Price ShelveLoc Age Education
## 1  9.50       138     73          11        276   120       Bad  42        17
## 2 11.22       111     48          16        260    83      Good  65        10
## 3 10.06       113     35          10        269    80    Medium  59        12
## 4  7.40       117    100           4        466    97    Medium  55        14
## 5  4.15       141     64           3        340   128       Bad  38        13
## 6 10.81       124    113          13        501    72       Bad  78        16
##   Urban  US
## 1   Yes Yes
## 2   Yes Yes
## 3   Yes Yes
## 4   Yes Yes
## 5   Yes  No
## 6    No Yes

See the data structure.

str(Carseats)
## 'data.frame':    400 obs. of  11 variables:
##  $ Sales      : num  9.5 11.22 10.06 7.4 4.15 ...
##  $ CompPrice  : num  138 111 113 117 141 124 115 136 132 132 ...
##  $ Income     : num  73 48 35 100 64 113 105 81 110 113 ...
##  $ Advertising: num  11 16 10 4 3 13 0 15 0 0 ...
##  $ Population : num  276 260 269 466 340 501 45 425 108 131 ...
##  $ Price      : num  120 83 80 97 128 72 108 120 124 124 ...
##  $ ShelveLoc  : Factor w/ 3 levels "Bad","Good","Medium": 1 2 3 3 1 1 3 2 3 3 ...
##  $ Age        : num  42 65 59 55 38 78 71 67 76 76 ...
##  $ Education  : num  17 10 12 14 13 16 15 10 10 17 ...
##  $ Urban      : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 1 2 2 1 1 ...
##  $ US         : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 1 2 1 2 ...

Create a binary response, since Sales is continuous. I chose the approximate mean of Sales (7.5) as the cutoff.
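
You can verify that the cutoff is close to the mean of Sales (this assumes the Carseats data frame loaded above):

mean(Carseats$Sales)  # approximately 7.50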

High = ifelse(Carseats$Sales <= 7.5, 'No', 'Yes')
High = as.factor(High)
Carseats = data.frame(Carseats, High)
str(Carseats)
## 'data.frame':    400 obs. of  12 variables:
##  $ Sales      : num  9.5 11.22 10.06 7.4 4.15 ...
##  $ CompPrice  : num  138 111 113 117 141 124 115 136 132 132 ...
##  $ Income     : num  73 48 35 100 64 113 105 81 110 113 ...
##  $ Advertising: num  11 16 10 4 3 13 0 15 0 0 ...
##  $ Population : num  276 260 269 466 340 501 45 425 108 131 ...
##  $ Price      : num  120 83 80 97 128 72 108 120 124 124 ...
##  $ ShelveLoc  : Factor w/ 3 levels "Bad","Good","Medium": 1 2 3 3 1 1 3 2 3 3 ...
##  $ Age        : num  42 65 59 55 38 78 71 67 76 76 ...
##  $ Education  : num  17 10 12 14 13 16 15 10 10 17 ...
##  $ Urban      : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 1 2 2 1 1 ...
##  $ US         : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 1 2 1 2 ...
##  $ High       : Factor w/ 2 levels "No","Yes": 2 2 2 1 1 2 1 2 1 1 ...

Grow a tree. Because the response variable is now High rather than Sales, we exclude Sales from the predictors.

tree.carseats = tree(High ~ . -Sales, Carseats)

When growing a tree, at each split we consider all predictors X1, …, Xp and all possible cutpoints for each predictor, and then choose the predictor and cutpoint that give the greatest improvement in the splitting criterion. For a classification tree such as this one, the tree package uses the node deviance; for a regression tree, the criterion is the residual sum of squares.
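
As a minimal sketch (not the package's internal implementation), the deviance that tree() reports for a node, -2 * sum over classes of n_k * log(n_k / n), and the combined deviance of a candidate split can be computed by hand:

# Node deviance as reported by tree(): -2 * sum_k n_k * log(n_k / n).
node_deviance = function(y) {
  counts = table(y)
  counts = counts[counts > 0]  # treat 0 * log(0) as 0
  -2 * sum(counts * log(counts / length(y)))
}

# Combined deviance of the two children created by a candidate cutpoint;
# the greedy search picks the predictor and cutpoint minimizing this.
split_deviance = function(x, cutpoint, y) {
  node_deviance(y[x < cutpoint]) + node_deviance(y[x >= cutpoint])
}

node_deviance(Carseats$High)                          # 554.5, the root deviance below
split_deviance(Carseats$Price, 126.5, Carseats$High)  # score of one candidate split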

See the default output.

tree.carseats
## node), split, n, deviance, yval, (yprob)
##       * denotes terminal node
## 
##   1) root 400 554.500 No ( 0.50500 0.49500 )  
##     2) ShelveLoc: Bad,Medium 315 424.000 No ( 0.60000 0.40000 )  
##       4) Price < 126.5 216 299.000 Yes ( 0.47685 0.52315 )  
##         8) Advertising < 7.5 128 169.400 No ( 0.62500 0.37500 )  
##          16) Price < 100.5 49  66.270 Yes ( 0.40816 0.59184 )  
##            32) Price < 71 7   0.000 Yes ( 0.00000 1.00000 ) *
##            33) Price > 71 42  58.130 Yes ( 0.47619 0.52381 )  
##              66) CompPrice < 109.5 14  14.550 No ( 0.78571 0.21429 ) *
##              67) CompPrice > 109.5 28  35.160 Yes ( 0.32143 0.67857 ) *
##          17) Price > 100.5 79  87.160 No ( 0.75949 0.24051 )  
##            34) CompPrice < 127 44  21.900 No ( 0.93182 0.06818 )  
##              68) Age < 50 13  14.050 No ( 0.76923 0.23077 ) *
##              69) Age > 50 31   0.000 No ( 1.00000 0.00000 ) *
##            35) CompPrice > 127 35  48.260 No ( 0.54286 0.45714 )  
##              70) Population < 58 6   0.000 No ( 1.00000 0.00000 ) *
##              71) Population > 58 29  39.890 Yes ( 0.44828 0.55172 )  
##               142) CompPrice < 140.5 22  29.770 No ( 0.59091 0.40909 ) *
##               143) CompPrice > 140.5 7   0.000 Yes ( 0.00000 1.00000 ) *
##         9) Advertising > 7.5 88 101.100 Yes ( 0.26136 0.73864 )  
##          18) CompPrice < 123.5 51  69.740 Yes ( 0.43137 0.56863 )  
##            36) Income < 57 13   7.051 No ( 0.92308 0.07692 ) *
##            37) Income > 57 38  43.800 Yes ( 0.26316 0.73684 )  
##              74) Price < 94.5 15   0.000 Yes ( 0.00000 1.00000 ) *
##              75) Price > 94.5 23  31.490 Yes ( 0.43478 0.56522 )  
##               150) Age < 68 17  18.550 Yes ( 0.23529 0.76471 )  
##                 300) CompPrice < 111 8  11.090 Yes ( 0.50000 0.50000 ) *
##                 301) CompPrice > 111 9   0.000 Yes ( 0.00000 1.00000 ) *
##               151) Age > 68 6   0.000 No ( 1.00000 0.00000 ) *
##          19) CompPrice > 123.5 37   9.195 Yes ( 0.02703 0.97297 ) *
##       5) Price > 126.5 99  77.000 No ( 0.86869 0.13131 )  
##        10) CompPrice < 142 73  31.010 No ( 0.94521 0.05479 )  
##          20) Advertising < 14.5 62  10.240 No ( 0.98387 0.01613 ) *
##          21) Advertising > 14.5 11  12.890 No ( 0.72727 0.27273 )  
##            42) Income < 64 5   6.730 Yes ( 0.40000 0.60000 ) *
##            43) Income > 64 6   0.000 No ( 1.00000 0.00000 ) *
##        11) CompPrice > 142 26  33.540 No ( 0.65385 0.34615 )  
##          22) Price < 145 11  14.420 Yes ( 0.36364 0.63636 ) *
##          23) Price > 145 15  11.780 No ( 0.86667 0.13333 ) *
##     3) ShelveLoc: Good 85  72.720 Yes ( 0.15294 0.84706 )  
##       6) Price < 132.5 64  17.800 Yes ( 0.03125 0.96875 ) *
##       7) Price > 132.5 21  29.060 No ( 0.52381 0.47619 )  
##        14) Income < 46 7   0.000 No ( 1.00000 0.00000 ) *
##        15) Income > 46 14  16.750 Yes ( 0.28571 0.71429 )  
##          30) Advertising < 3 8  11.090 Yes ( 0.50000 0.50000 ) *
##          31) Advertising > 3 6   0.000 Yes ( 0.00000 1.00000 ) *

See a summary of results.

summary(tree.carseats)
## 
## Classification tree:
## tree(formula = High ~ . - Sales, data = Carseats)
## Variables actually used in tree construction:
## [1] "ShelveLoc"   "Price"       "Advertising" "CompPrice"   "Age"        
## [6] "Population"  "Income"     
## Number of terminal nodes:  23 
## Residual mean deviance:  0.5117 = 192.9 / 377 
## Misclassification error rate: 0.1125 = 45 / 400
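
Note that the residual mean deviance is the total deviance divided by n minus the number of terminal nodes: 192.9 / (400 - 23) = 0.5117. The misclassification error rate is simply the fraction of training observations the tree gets wrong: 45 / 400.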

Plot the tree.

plot(tree.carseats)
text(tree.carseats, pretty = 0, cex = 0.7)

Here the most important variables are ShelveLoc and Price, because the tree selected them for the first two splits.

Estimate the test error rate using the validation set approach.

Create training and test sets.

set.seed(2)
train = sample(1:nrow(Carseats), 200)
Carseats.test = Carseats[-train, ]
High.test = High[-train]

Grow a tree using the training set.

tree.carseats = tree(High ~ . -Sales, Carseats, subset = train)

Get predictions on the test set.

tree.pred = predict(tree.carseats, Carseats.test, type = 'class')

Compute the confusion matrix.

table(tree.pred, High.test)
##          High.test
## tree.pred No Yes
##       No  81  27
##       Yes 20  72

From the table, the misclassification rate is (20 + 27) / 200 = 23.5%.
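
Equivalently, the rate can be computed directly from the predictions:

mean(tree.pred != High.test)  # (20 + 27) / 200 = 0.235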

Perform cost-complexity pruning by cross-validation (CV), using the misclassification rate.

Next, we consider whether pruning the tree might lead to improved results. The function cv.tree() performs cross-validation in order to determine the optimal level of tree complexity. We use the argument FUN = prune.misclass to indicate that we want the classification error rate, rather than the default deviance, to guide the cross-validation and pruning process.

set.seed(3)
cv.carseats = cv.tree(tree.carseats, FUN = prune.misclass)

Look at what is stored in the result object.

names(cv.carseats)
## [1] "size"   "dev"    "k"      "method"

In the result, size is the number of terminal nodes of each tree considered, dev is the corresponding cross-validation error (here, the number of misclassified observations), and k is the cost-complexity parameter (alpha in ISLR).

Plot the cross-validation error against tree size and against k.

par(mfrow = c(1,2))
plot(cv.carseats$size, cv.carseats$dev, type = 'b')
plot(cv.carseats$k, cv.carseats$dev, type = 'b')

Note that when k is small, size is large, and vice versa.

Get the best size.

cv.carseats$size[which.min(cv.carseats$dev)]
## [1] 14

The best size is 14, so we prune the tree to obtain the subtree with 14 terminal nodes.

Get the pruned tree of the best size.

prune.carseats = prune.misclass(tree.carseats, best = 14)

Plot the pruned tree. It has 14 leaves.

plot(prune.carseats)
text(prune.carseats, pretty = 0)

Get predictions on the test set.

tree.pred = predict(prune.carseats, Carseats.test, type = 'class')

Get the confusion matrix.

table(tree.pred, High.test)
##          High.test
## tree.pred No Yes
##       No  84  28
##       Yes 17  71

The misclassification rate is (17 + 28) / (84 + 71 + 17 + 28) = 22.5%.

Compute the misclassification rate of a larger pruned tree, of size 15.

prune.carseats = prune.misclass(tree.carseats, best = 15)
tree.pred = predict(prune.carseats, Carseats.test, type = 'class')
table(tree.pred, High.test)
##          High.test
## tree.pred No Yes
##       No  88  32
##       Yes 13  67

Here the misclassification rate is (32 + 13) / 200 = 22.5%, no better than at the CV-selected size. In general, using a tree larger than cross-validation suggests (a smaller alpha = k) does not improve test accuracy, and often worsens it.
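
To see the whole pattern, a small sketch (not part of the original exercise) evaluates the test error of the pruned tree at every size considered by cv.tree():

# Test error for each candidate subtree size; size 1 (the root-only
# tree) is skipped since it involves no splits.
sizes = cv.carseats$size[cv.carseats$size > 1]
test.err = sapply(sizes, function(s) {
  pruned = prune.misclass(tree.carseats, best = s)
  pred = predict(pruned, Carseats.test, type = 'class')
  mean(pred != High.test)
})
rbind(size = sizes, test.error = round(test.err, 3))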

Fitting a regression tree

library(MASS)
data(Boston)

We will use the Boston housing data from the MASS package.

str(Boston)
## 'data.frame':    506 obs. of  14 variables:
##  $ crim   : num  0.00632 0.02731 0.02729 0.03237 0.06905 ...
##  $ zn     : num  18 0 0 0 0 0 12.5 12.5 12.5 12.5 ...
##  $ indus  : num  2.31 7.07 7.07 2.18 2.18 2.18 7.87 7.87 7.87 7.87 ...
##  $ chas   : int  0 0 0 0 0 0 0 0 0 0 ...
##  $ nox    : num  0.538 0.469 0.469 0.458 0.458 0.458 0.524 0.524 0.524 0.524 ...
##  $ rm     : num  6.58 6.42 7.18 7 7.15 ...
##  $ age    : num  65.2 78.9 61.1 45.8 54.2 58.7 66.6 96.1 100 85.9 ...
##  $ dis    : num  4.09 4.97 4.97 6.06 6.06 ...
##  $ rad    : int  1 2 2 3 3 3 5 5 5 5 ...
##  $ tax    : num  296 242 242 222 222 222 311 311 311 311 ...
##  $ ptratio: num  15.3 17.8 17.8 18.7 18.7 18.7 15.2 15.2 15.2 15.2 ...
##  $ black  : num  397 397 393 395 397 ...
##  $ lstat  : num  4.98 9.14 4.03 2.94 5.33 ...
##  $ medv   : num  24 21.6 34.7 33.4 36.2 28.7 22.9 27.1 16.5 18.9 ...

Create training and test sets. 50/50 split.

set.seed(1)
train = sample(1:nrow(Boston), nrow(Boston)/2)

Grow a tree using the training set. Median value (medv) is the response.

tree.Boston = tree(medv ~ ., Boston, subset = train)
summary(tree.Boston)
## 
## Regression tree:
## tree(formula = medv ~ ., data = Boston, subset = train)
## Variables actually used in tree construction:
## [1] "rm"    "lstat" "crim"  "age"  
## Number of terminal nodes:  7 
## Residual mean deviance:  10.38 = 2555 / 246 
## Distribution of residuals:
##     Min.  1st Qu.   Median     Mean  3rd Qu.     Max. 
## -10.1800  -1.7770  -0.1775   0.0000   1.9230  16.5800
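
For a regression tree the deviance is simply the residual sum of squares, so the residual mean deviance is the training RSS divided by n minus the number of terminal nodes: 2555 / (253 - 7) = 10.38, where 253 = 506 / 2 is the size of the training set.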

Plot the tree.

plot(tree.Boston)
text(tree.Boston, pretty = 0)

Perform cost-complexity pruning by CV.

cv.boston = cv.tree(tree.Boston)
cv.boston
## $size
## [1] 7 6 5 4 3 2 1
## 
## $dev
## [1]  4380.849  4544.815  5601.055  6171.917  6919.608 10419.472 19630.870
## 
## $k
## [1]       -Inf   203.9641   637.2707   796.1207  1106.4931  3424.7810 10724.5951
## 
## $method
## [1] "deviance"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"
cv.boston$size[which.min(cv.boston$dev)]
## [1] 7

Plot the cross-validation deviance against tree size.

plot(cv.boston$size, cv.boston$dev, type = 'b')

Note: the best size is 7, i.e., the full tree (no pruning is needed here).

If needed, pruning can still be performed by specifying the “best” argument.

prune.boston = prune.tree(tree.Boston, best = 5)
plot(prune.boston)
text(prune.boston, pretty = 0)

Get predictions on the test data.

yhat = predict(tree.Boston, newdata = Boston[-train, ])
boston.test = Boston[-train, 'medv']

Plot the observed values against the predicted values.

plot(yhat, boston.test)
abline(0,1, col = 'red')

Compute the test mean squared error (MSE).

mean((yhat - boston.test)^2)
## [1] 35.28688
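
The square root of the test MSE is about 5.94, so this tree's predictions are typically within roughly $5,940 of the true median home value (medv is recorded in thousands of dollars).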