Illustration of decision trees
Using the tree package
This is an example exercise taken from the book An Introduction to Statistical Learning with Applications in R by Gareth James, Daniela Witten, Trevor Hastie, and Robert Tibshirani.
library(tree)
## Warning: package 'tree' was built under R version 4.0.4
library(ISLR)
data("Carseats")
head(Carseats)
## Sales CompPrice Income Advertising Population Price ShelveLoc Age Education
## 1 9.50 138 73 11 276 120 Bad 42 17
## 2 11.22 111 48 16 260 83 Good 65 10
## 3 10.06 113 35 10 269 80 Medium 59 12
## 4 7.40 117 100 4 466 97 Medium 55 14
## 5 4.15 141 64 3 340 128 Bad 38 13
## 6 10.81 124 113 13 501 72 Bad 78 16
## Urban US
## 1 Yes Yes
## 2 Yes Yes
## 3 Yes Yes
## 4 Yes Yes
## 5 Yes No
## 6 No Yes
str(Carseats)
## 'data.frame': 400 obs. of 11 variables:
## $ Sales : num 9.5 11.22 10.06 7.4 4.15 ...
## $ CompPrice : num 138 111 113 117 141 124 115 136 132 132 ...
## $ Income : num 73 48 35 100 64 113 105 81 110 113 ...
## $ Advertising: num 11 16 10 4 3 13 0 15 0 0 ...
## $ Population : num 276 260 269 466 340 501 45 425 108 131 ...
## $ Price : num 120 83 80 97 128 72 108 120 124 124 ...
## $ ShelveLoc : Factor w/ 3 levels "Bad","Good","Medium": 1 2 3 3 1 1 3 2 3 3 ...
## $ Age : num 42 65 59 55 38 78 71 67 76 76 ...
## $ Education : num 17 10 12 14 13 16 15 10 10 17 ...
## $ Urban : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 1 2 2 1 1 ...
## $ US : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 1 2 1 2 ...
# Create a binary response: High = 'Yes' if Sales exceeds 7.5 (thousand units), 'No' otherwise
High = ifelse(Carseats$Sales <= 7.5, 'No', 'Yes')
High = as.factor(High)
# Attach the new variable to the data frame
Carseats = data.frame(Carseats, High)
str(Carseats)
## 'data.frame': 400 obs. of 12 variables:
## $ Sales : num 9.5 11.22 10.06 7.4 4.15 ...
## $ CompPrice : num 138 111 113 117 141 124 115 136 132 132 ...
## $ Income : num 73 48 35 100 64 113 105 81 110 113 ...
## $ Advertising: num 11 16 10 4 3 13 0 15 0 0 ...
## $ Population : num 276 260 269 466 340 501 45 425 108 131 ...
## $ Price : num 120 83 80 97 128 72 108 120 124 124 ...
## $ ShelveLoc : Factor w/ 3 levels "Bad","Good","Medium": 1 2 3 3 1 1 3 2 3 3 ...
## $ Age : num 42 65 59 55 38 78 71 67 76 76 ...
## $ Education : num 17 10 12 14 13 16 15 10 10 17 ...
## $ Urban : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 1 2 2 1 1 ...
## $ US : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 1 2 1 2 ...
## $ High : Factor w/ 2 levels "No","Yes": 2 2 2 1 1 2 1 2 1 1 ...
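The two classes are roughly balanced; the counts agree with the class proportions (0.505, 0.495) reported at the root node of the tree printed below:
table(Carseats$High)
## 
##  No Yes
## 202 198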
We now fit a classification tree to predict High instead of Sales. Sales is excluded from the predictors, since High is derived directly from it.
tree.carseats = tree(High ~ . -Sales, Carseats)
When growing a tree, at each split we consider all predictors X1, …, Xp and all possible cutpoints for each predictor, and then choose the predictor and cutpoint that give the best fit: the lowest residual sum of squares for a regression tree, or the lowest deviance (or Gini index) for a classification tree such as this one.
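To make the greedy search concrete, here is a rough sketch that scores every candidate cutpoint of a single predictor, Price, by the total deviance of the two child nodes; node_deviance and split_deviance are illustrative helpers, not functions from the tree package.
# Deviance of one node: -2 * n * sum(p_k * log(p_k)) over the observed classes
node_deviance = function(g) {
  p = table(g) / length(g)
  p = p[p > 0]
  -2 * length(g) * sum(p * log(p))
}
# Total deviance of the two children produced by splitting x at 'cut'
split_deviance = function(cut, x, y) {
  node_deviance(y[x < cut]) + node_deviance(y[x >= cut])
}
# Candidate cutpoints: midpoints between consecutive observed values of Price
vals = sort(unique(Carseats$Price))
cuts = (head(vals, -1) + tail(vals, -1)) / 2
devs = sapply(cuts, split_deviance, x = Carseats$Price, y = Carseats$High)
cuts[which.min(devs)]   # best single cutpoint for Price under the deviance criterion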
tree.carseats
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 400 554.500 No ( 0.50500 0.49500 )
## 2) ShelveLoc: Bad,Medium 315 424.000 No ( 0.60000 0.40000 )
## 4) Price < 126.5 216 299.000 Yes ( 0.47685 0.52315 )
## 8) Advertising < 7.5 128 169.400 No ( 0.62500 0.37500 )
## 16) Price < 100.5 49 66.270 Yes ( 0.40816 0.59184 )
## 32) Price < 71 7 0.000 Yes ( 0.00000 1.00000 ) *
## 33) Price > 71 42 58.130 Yes ( 0.47619 0.52381 )
## 66) CompPrice < 109.5 14 14.550 No ( 0.78571 0.21429 ) *
## 67) CompPrice > 109.5 28 35.160 Yes ( 0.32143 0.67857 ) *
## 17) Price > 100.5 79 87.160 No ( 0.75949 0.24051 )
## 34) CompPrice < 127 44 21.900 No ( 0.93182 0.06818 )
## 68) Age < 50 13 14.050 No ( 0.76923 0.23077 ) *
## 69) Age > 50 31 0.000 No ( 1.00000 0.00000 ) *
## 35) CompPrice > 127 35 48.260 No ( 0.54286 0.45714 )
## 70) Population < 58 6 0.000 No ( 1.00000 0.00000 ) *
## 71) Population > 58 29 39.890 Yes ( 0.44828 0.55172 )
## 142) CompPrice < 140.5 22 29.770 No ( 0.59091 0.40909 ) *
## 143) CompPrice > 140.5 7 0.000 Yes ( 0.00000 1.00000 ) *
## 9) Advertising > 7.5 88 101.100 Yes ( 0.26136 0.73864 )
## 18) CompPrice < 123.5 51 69.740 Yes ( 0.43137 0.56863 )
## 36) Income < 57 13 7.051 No ( 0.92308 0.07692 ) *
## 37) Income > 57 38 43.800 Yes ( 0.26316 0.73684 )
## 74) Price < 94.5 15 0.000 Yes ( 0.00000 1.00000 ) *
## 75) Price > 94.5 23 31.490 Yes ( 0.43478 0.56522 )
## 150) Age < 68 17 18.550 Yes ( 0.23529 0.76471 )
## 300) CompPrice < 111 8 11.090 Yes ( 0.50000 0.50000 ) *
## 301) CompPrice > 111 9 0.000 Yes ( 0.00000 1.00000 ) *
## 151) Age > 68 6 0.000 No ( 1.00000 0.00000 ) *
## 19) CompPrice > 123.5 37 9.195 Yes ( 0.02703 0.97297 ) *
## 5) Price > 126.5 99 77.000 No ( 0.86869 0.13131 )
## 10) CompPrice < 142 73 31.010 No ( 0.94521 0.05479 )
## 20) Advertising < 14.5 62 10.240 No ( 0.98387 0.01613 ) *
## 21) Advertising > 14.5 11 12.890 No ( 0.72727 0.27273 )
## 42) Income < 64 5 6.730 Yes ( 0.40000 0.60000 ) *
## 43) Income > 64 6 0.000 No ( 1.00000 0.00000 ) *
## 11) CompPrice > 142 26 33.540 No ( 0.65385 0.34615 )
## 22) Price < 145 11 14.420 Yes ( 0.36364 0.63636 ) *
## 23) Price > 145 15 11.780 No ( 0.86667 0.13333 ) *
## 3) ShelveLoc: Good 85 72.720 Yes ( 0.15294 0.84706 )
## 6) Price < 132.5 64 17.800 Yes ( 0.03125 0.96875 ) *
## 7) Price > 132.5 21 29.060 No ( 0.52381 0.47619 )
## 14) Income < 46 7 0.000 No ( 1.00000 0.00000 ) *
## 15) Income > 46 14 16.750 Yes ( 0.28571 0.71429 )
## 30) Advertising < 3 8 11.090 Yes ( 0.50000 0.50000 ) *
## 31) Advertising > 3 6 0.000 Yes ( 0.00000 1.00000 ) *
summary(tree.carseats)
##
## Classification tree:
## tree(formula = High ~ . - Sales, data = Carseats)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Advertising" "CompPrice" "Age"
## [6] "Population" "Income"
## Number of terminal nodes: 23
## Residual mean deviance: 0.5117 = 192.9 / 377
## Misclassification error rate: 0.1125 = 45 / 400
plot(tree.carseats)
text(tree.carseats, pretty = 0, cex = 0.7)
Here the most important variables are ShelveLoc and Price, since they were chosen by the tree for the first two splits.
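The split order can also be read programmatically from the fitted object's frame component, which lists the splitting variable, node size, and deviance for each node in the printed order (row names are node numbers; rows with var equal to <leaf> are terminal nodes):
head(tree.carseats$frame[, c('var', 'n', 'dev')])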
set.seed(2)
# Split the data: 200 observations for training, the remaining 200 for testing
train = sample(1:nrow(Carseats), 200)
Carseats.test = Carseats[-train, ]
High.test = High[-train]
# Fit on the training set, then predict class labels on the test set
tree.carseats = tree(High ~ . -Sales, Carseats, subset = train)
tree.pred = predict(tree.carseats, Carseats.test, type = 'class')
table(tree.pred, High.test)
## High.test
## tree.pred No Yes
## No 81 27
## Yes 20 72
From the table, the test misclassification rate is (20 + 27) / 200 = 23.5%.
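Equivalently, the test error rate can be computed directly from the predictions; the value matches the 23.5% obtained from the table:
mean(tree.pred != High.test)   # proportion of misclassified test observations
## [1] 0.235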
Next, we consider whether pruning the tree might lead to improved results. The function cv.tree() performs cross-validation in order to determine the optimal level of tree complexity. We use the argument FUN = prune.misclass to indicate that the classification error rate, rather than the default deviance, should guide the cross-validation and pruning process.
set.seed(3)
cv.carseats = cv.tree(tree.carseats, FUN = prune.misclass)
names(cv.carseats)
## [1] "size" "dev" "k" "method"
In the result, size is the number of terminal nodes of each candidate tree, dev is the cross-validation error (here the number of misclassifications, because we used prune.misclass), and k is the cost-complexity parameter (alpha in ISLR's notation).
par(mfrow = c(1,2))
plot(cv.carseats$size, cv.carseats$dev, type = 'b')
plot(cv.carseats$k, cv.carseats$dev, type = 'b')
Note that when k is small, size is large, and vice versa.
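To see this relationship directly, the three vectors can be printed side by side:
cbind(size = cv.carseats$size, k = cv.carseats$k, dev = cv.carseats$dev)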
cv.carseats$size[which.min(cv.carseats$dev)]
## [1] 14
The best size is 14, so we prune the tree back to 14 terminal nodes.
prune.carseats = prune.misclass(tree.carseats, best = 14)
plot(prune.carseats)
text(prune.carseats, pretty = 0)
tree.pred = predict(prune.carseats, Carseats.test, type = 'class')
table(tree.pred, High.test)
## High.test
## tree.pred No Yes
## No 84 28
## Yes 17 71
The misclassification rate is (17 + 28) / (84 + 71 + 17 + 28) = 22.5%.
prune.carseats = prune.misclass(tree.carseats, best = 15)
tree.pred = predict(prune.carseats, Carseats.test, type = 'class')
table(tree.pred, High.test)
## High.test
## tree.pred No Yes
## No 88 32
## Yes 13 67
With best = 15, the test error rate is (13 + 32) / 200 = 22.5% as well: keeping a slightly larger tree (a smaller alpha = k) gives no further improvement on the test data here.
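To trace the test error across all of the candidate sizes reported by cv.tree(), one possible sketch is the following loop (the size-1, root-only tree is skipped):
sizes = rev(cv.carseats$size)
sizes = sizes[sizes >= 2]                     # skip the root-only tree (size 1)
test.err = sapply(sizes, function(s) {
  pruned = prune.misclass(tree.carseats, best = s)
  pred   = predict(pruned, Carseats.test, type = 'class')
  mean(pred != High.test)                     # test misclassification rate for this size
})
cbind(size = sizes, test.err)
We now turn to a regression tree, predicting medv in the Boston housing data from the MASS package.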
library(MASS)
data(Boston)
str(Boston)
## 'data.frame': 506 obs. of 14 variables:
## $ crim : num 0.00632 0.02731 0.02729 0.03237 0.06905 ...
## $ zn : num 18 0 0 0 0 0 12.5 12.5 12.5 12.5 ...
## $ indus : num 2.31 7.07 7.07 2.18 2.18 2.18 7.87 7.87 7.87 7.87 ...
## $ chas : int 0 0 0 0 0 0 0 0 0 0 ...
## $ nox : num 0.538 0.469 0.469 0.458 0.458 0.458 0.524 0.524 0.524 0.524 ...
## $ rm : num 6.58 6.42 7.18 7 7.15 ...
## $ age : num 65.2 78.9 61.1 45.8 54.2 58.7 66.6 96.1 100 85.9 ...
## $ dis : num 4.09 4.97 4.97 6.06 6.06 ...
## $ rad : int 1 2 2 3 3 3 5 5 5 5 ...
## $ tax : num 296 242 242 222 222 222 311 311 311 311 ...
## $ ptratio: num 15.3 17.8 17.8 18.7 18.7 18.7 15.2 15.2 15.2 15.2 ...
## $ black : num 397 397 393 395 397 ...
## $ lstat : num 4.98 9.14 4.03 2.94 5.33 ...
## $ medv : num 24 21.6 34.7 33.4 36.2 28.7 22.9 27.1 16.5 18.9 ...
set.seed(1)
train = sample(1:nrow(Boston), nrow(Boston)/2)
tree.Boston = tree(medv ~ ., Boston, subset = train)
summary(tree.Boston)
##
## Regression tree:
## tree(formula = medv ~ ., data = Boston, subset = train)
## Variables actually used in tree construction:
## [1] "rm" "lstat" "crim" "age"
## Number of terminal nodes: 7
## Residual mean deviance: 10.38 = 2555 / 246
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -10.1800 -1.7770 -0.1775 0.0000 1.9230 16.5800
plot(tree.Boston)
text(tree.Boston, pretty = 0)
cv.boston = cv.tree(tree.Boston)
cv.boston
## $size
## [1] 7 6 5 4 3 2 1
##
## $dev
## [1] 4380.849 4544.815 5601.055 6171.917 6919.608 10419.472 19630.870
##
## $k
## [1] -Inf 203.9641 637.2707 796.1207 1106.4931 3424.7810 10724.5951
##
## $method
## [1] "deviance"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
cv.boston$size[which.min(cv.boston$dev)]
## [1] 7
plot(cv.boston$size, cv.boston$dev, type = 'b')
Note: the best size is 7, i.e., the full tree, so no pruning is indicated.
If pruning is nevertheless desired, it can be performed with prune.tree() by specifying the best argument.
prune.boston = prune.tree(tree.Boston, best = 5)
plot(prune.boston)
text(prune.boston, pretty = 0)
yhat = predict(tree.Boston, newdata = Boston[-train, ])
boston.test = Boston[-train, 'medv']
plot(yhat, boston.test)
abline(0,1, col = 'red')
mean((yhat - boston.test)^2)
## [1] 35.28688
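The square root of the test MSE is roughly 5.94; since medv is measured in thousands of dollars, the regression tree's test predictions are typically within about $5,940 of the true median home value:
sqrt(mean((yhat - boston.test)^2))   # test-set RMSE, in $1000s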