Heavily borrowed from:
Textbook: An Introduction to Statistical Learning
Textbook: The Elements of Statistical Learning
UCLA example (link)
require(knitr)
## Loading required package: knitr
This notebook is about tree-based methods for regression and classification. These methods involve stratifying, or segmenting, the predictor space into a number of simpler regions. To make a prediction for a given observation, we typically use the mean (for regression) or the mode (for classification) of the training observations in the region to which it belongs. Because the splitting rules used to segment the predictor space can be summarized in a tree, these approaches are known as decision tree methods.
Tree methods are simple and useful for interpretation; however, they typically are not competitive with the best supervised learning methods in terms of prediction accuracy. Hence, in subsequent notebooks we also introduce bagging, random forests, and boosting. Each of these approaches involves producing multiple trees, which are then combined to yield a single consensus prediction. We will see that combining a large number of trees can result in dramatic improvements in prediction accuracy, at the expense of some loss in interpretability.
Decision trees can be applied to both regression and classification problems. We will first consider regression.
We begin with a simple example:
We use the Hitters data from the ISLR library to predict a baseball player's Salary based on the number of Years he has played in the major leagues and the number of Hits he made in the previous year.
The result is a series of splitting rules. The first split segments the data into players with Years < 4.5 (the first region) on the left branch, and the remainder on the right. For the players in the left branch, the predicted salary is the mean response of the training players in that branch. Players with Years >= 4.5 are assigned to the right branch, which is further subdivided by Hits. Players with Years >= 4.5 and Hits < 118 fall into the second region, and players with Years >= 4.5 and Hits >= 118 fall into the third region, each with its own predicted salary. The regions at the ends of the tree are called terminal nodes, or leaves. We might interpret such a tree as saying that Years is the most important factor in determining Salary, and that players with less experience earn lower salaries. Among less experienced players, the number of Hits last year plays little role in Salary; among more experienced players (Years >= 4.5), more Hits last year is associated with a higher Salary. If we code this model, we see that the relationship ends up being slightly more complicated.
library(tree)
library(ISLR)
# Remove observations with missing data (Salary is missing for some players)
Hitters <- na.omit(Hitters)
# Salary is right-skewed, so log-transform it to make it more normally distributed
hist(Hitters$Salary)
Hitters$Salary <- log(Hitters$Salary)
hist(Hitters$Salary)
tree.fit <- tree(Salary~Hits+Years, data=Hitters)
summary(tree.fit)
##
## Regression tree:
## tree(formula = Salary ~ Hits + Years, data = Hitters)
## Number of terminal nodes: 8
## Residual mean deviance: 0.271 = 69.1 / 255
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -2.2400 -0.2980 -0.0365 0.0000 0.3230 2.1500
plot(tree.fit)
text(tree.fit, pretty=0)
Now we discuss how to build a regression tree by stratifying the feature space. In general, the procedure is:
1. Find the variable and split point that best separate the response, i.e., that yield the lowest RSS.
2. Divide the data into two leaves at that node.
3. Within each leaf, find the best variable and split point to separate the outcomes.
4. Continue until the groups are too small or sufficiently 'pure'.
The goal is to find the regions R1, ..., RJ that minimize the RSS: the sum, over all regions, of the squared differences between each training response and the mean response of its region. However, it is computationally infeasible to consider every possible partition of the feature space into J regions. For this reason we take a top-down, greedy approach known as recursive binary splitting. It is top-down because it begins at the top of the tree, where all observations belong to a single region. It is greedy because at each step of the tree-building process, the best split is chosen at that particular step, rather than looking ahead and picking a split that will lead to a better tree in some future step.
Once all the regions have been created, we predict the response for a given test observation using the mean of the training observations in the region to which that test observation belongs.
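One greedy step of recursive binary splitting can be sketched in base R. This minimal sketch assumes a single numeric predictor and a numeric response; rss() and best_split() are hypothetical helpers, and the toy data are made up. It tries every cutpoint between adjacent distinct values of x and keeps the one with the lowest total RSS across the two resulting regions.

```r
# RSS of a region: squared deviations from the region's mean response
rss <- function(v) sum((v - mean(v))^2)

# One greedy split for a single numeric predictor x and response y.
# Candidate cutpoints are midpoints between adjacent distinct values of x,
# so both sides of every candidate split are non-empty.
best_split <- function(x, y) {
  ux <- sort(unique(x))
  if (length(ux) < 2) return(NULL)   # nothing to split on
  cuts <- (head(ux, -1) + tail(ux, -1)) / 2
  total <- sapply(cuts, function(s) rss(y[x < s]) + rss(y[x >= s]))
  list(cutpoint = cuts[which.min(total)], rss = min(total))
}

# Toy data (made up): the response jumps between x = 4 and x = 5
x <- c(1, 2, 3, 4, 5, 6, 7, 8)
y <- c(1.0, 1.2, 0.9, 1.1, 3.0, 3.2, 2.9, 3.1)
best_split(x, y)
```

With several predictors, the same search would be repeated for each one and the predictor/cutpoint pair with the lowest RSS chosen; recursing on each resulting region gives the full top-down procedure. This is the kind of search tree() carries out for you.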
While the model above can produce good predictions on the training data, basic tree methods are likely to overfit the data, leading to poor test performance, because the resulting trees tend to be too complex. A smaller tree with fewer splits often leads to lower variance, easier interpretation, and lower test error, at the cost of a little bias. One possible way to achieve this is to build the tree only so long as the decrease in RSS due to each split exceeds some (high) threshold. While this will certainly reduce tree size, it is too short-sighted: a seemingly worthless split early on in the tree can be followed by a very good split later.
Therefore, a better strategy is to grow a very large tree and then prune it back to obtain a subtree. Intuitively, our goal is to select the subtree that leads to the lowest test error rate. Given a subtree, we could estimate its test error using cross-validation, but doing this for every possible subtree is too cumbersome, since there is an extremely large number of them.
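Cost complexity pruning, also known as weakest link pruning, gives us a way to remedy this problem. Rather than considering every possible subtree, we consider a sequence of trees indexed by a nonnegative tuning parameter alpha. For each value of alpha there is a subtree T that minimizes RSS(T) + alpha|T|, where |T| is the number of terminal nodes of T. When alpha = 0 this is simply the full tree, and larger values of alpha favor smaller trees.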
Revised steps for building a regression tree:
1. Use recursive binary splitting to grow a large tree on the training data, stopping only when each terminal node has fewer than some minimum number of observations.
2. Apply cost complexity pruning to the large tree to obtain a sequence of best subtrees as a function of alpha.
3. Use k-fold cross-validation to choose alpha.
4. Return the subtree from step 2 that corresponds to the chosen value of alpha.
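To see how alpha trades off fit against size, here is a minimal sketch of the cost complexity criterion RSS(T) + alpha|T| applied to a handful of candidate subtrees. The function name and the RSS values are hypothetical; in practice, prune.tree() produces the sequence of subtrees and cv.tree() reports the corresponding cost-complexity parameter as k.

```r
# Cost complexity criterion: for a given alpha, pick the subtree T that
# minimises RSS(T) + alpha * |T|, where |T| is the number of terminal nodes.
# Hypothetical helper; returns the index of the chosen subtree.
cost_complexity_choice <- function(rss, size, alpha) {
  score <- rss + alpha * size
  which.min(score)
}

size <- c(8, 5, 3, 1)         # terminal nodes of each candidate subtree
rss  <- c(60, 64, 70, 120)    # training RSS of each candidate (made-up values)

# Size of the subtree chosen for a few values of alpha
chosen <- sapply(c(0, 2, 5, 30), function(a) size[cost_complexity_choice(rss, size, a)])
chosen
```

As alpha grows, the penalty on the number of terminal nodes dominates and the chosen subtree shrinks, from the full 8-leaf tree at alpha = 0 down to the single-node tree. Step 3 then uses cross-validation to decide which alpha (and hence which subtree) to keep.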
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
split <- createDataPartition(y=Hitters$Salary, p=0.5, list=FALSE)
train <- Hitters[split,]
test <- Hitters[-split,]
#Create tree model
trees <- tree(Salary~., train)
plot(trees)
text(trees, pretty=0)
#Cross validate to see whether pruning the tree will improve performance
cv.trees <- cv.tree(trees)
plot(cv.trees)
It seems that trees of size 7 (7 terminal nodes) result in the lowest cross-validated deviance. We could prune to that size, but that would barely simplify the model, so instead we select a smaller size where the improvement in deviance plateaus: around 4 terminal nodes.
prune.trees <- prune.tree(trees, best=4)
plot(prune.trees)
text(prune.trees, pretty=0)
Use the pruned tree to make predictions on the test set.
yhat <- predict(prune.trees, test)
plot(yhat, test$Salary)
abline(0,1)
# Test MSE (on the log(Salary) scale)
mean((yhat - test$Salary)^2)
## [1] 0.3531
Classification trees are very similar to regression trees, except that they are used to predict a qualitative response rather than a quantitative one. For a regression tree, the predicted response for an observation is given by the mean response of the training observations in the same terminal node. In contrast, for a classification tree, we predict that each observation belongs to the most commonly occurring class of training observations in the region to which it belongs. When interpreting the results of a classification tree, we are often interested not only in the predicted class for each terminal node, but also in the class proportions among the training observations in that region.
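The majority-vote prediction and the class proportions for a single terminal node can be computed with base R's table() and prop.table(). The leaf labels below are hypothetical.

```r
# Class labels of the training observations in one (hypothetical) terminal node
leaf <- c("Yes", "Yes", "No", "Yes", "No", "Yes")

props <- prop.table(table(leaf))   # class proportions in the region
pred  <- names(which.max(props))   # most commonly occurring class = prediction
props
pred
```

Here the node predicts Yes, but the proportions (two thirds Yes, one third No) tell us how confident that prediction is.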
To grow a classification tree, we use the same recursive binary splitting, but RSS can no longer be used as the splitting criterion. A natural alternative is the classification error rate: the fraction of training observations in a region that do not belong to the most common class. While intuitive, it turns out that this measure is not sensitive enough for tree growing.
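In practice, two other measures are preferable; they are quite similar numerically. Writing p_mk for the proportion of training observations in region m that are from class k:
Gini index: G = sum over k of p_mk(1 - p_mk), a measure of total variance across the K classes. It takes a small value when all the p_mk are close to zero or one, which is why it is often called a measure of node purity.
Cross-entropy: D = -sum over k of p_mk log(p_mk). Like the Gini index, it takes a value near zero if the p_mk are all near zero or one.
These two measures are typically used to evaluate the quality of a particular split when growing the tree. Any of the three measures can be used when pruning, but the classification error rate is preferable if prediction accuracy of the final pruned tree is the goal.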
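The three measures can be compared directly in base R. This is a minimal sketch with hypothetical function names, taking a vector of class proportions p for a single node; it uses the natural logarithm for cross-entropy and treats 0 log 0 as 0.

```r
# Node impurity measures for a vector of class proportions p (summing to 1)
classification_error <- function(p) 1 - max(p)
gini                 <- function(p) sum(p * (1 - p))
cross_entropy        <- function(p) {
  p <- p[p > 0]                 # drop zero proportions: 0 * log(0) is taken as 0
  -sum(p * log(p))
}

pure  <- c(1, 0)       # every observation in one class
mixed <- c(0.5, 0.5)   # evenly split between two classes

sapply(list(pure = pure, mixed = mixed), function(p)
  c(error = classification_error(p), gini = gini(p), entropy = cross_entropy(p)))
```

All three measures are zero for a pure node and largest for an evenly mixed one. The difference is that the classification error rate depends only on the proportion of the majority class, while the Gini index and cross-entropy depend on all of the class proportions, which makes them more sensitive to changes in node purity.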
To demonstrate this we will use the Heart dataset. These data contain a binary outcome variable, AHD, for 303 patients who presented with chest pain. The outcome is coded as Yes or No for the presence of heart disease.
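# stringsAsFactors = TRUE reads AHD and the other text columns as factors, so that
# tree() treats AHD as a categorical response (the default became FALSE in R 4.0.0)
Heart <- read.csv('http://www-bcf.usc.edu/~gareth/ISL/Heart.csv', stringsAsFactors = TRUE)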
kable(head(Heart))
X | Age | Sex | ChestPain | RestBP | Chol | Fbs | RestECG | MaxHR | ExAng | Oldpeak | Slope | Ca | Thal | AHD |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
1 | 63 | 1 | typical | 145 | 233 | 1 | 2 | 150 | 0 | 2.3 | 3 | 0 | fixed | No |
2 | 67 | 1 | asymptomatic | 160 | 286 | 0 | 2 | 108 | 1 | 1.5 | 2 | 3 | normal | Yes |
3 | 67 | 1 | asymptomatic | 120 | 229 | 0 | 2 | 129 | 1 | 2.6 | 2 | 2 | reversable | Yes |
4 | 37 | 1 | nonanginal | 130 | 250 | 0 | 0 | 187 | 0 | 3.5 | 3 | 0 | normal | No |
5 | 41 | 0 | nontypical | 130 | 204 | 0 | 2 | 172 | 0 | 1.4 | 1 | 0 | normal | No |
6 | 56 | 1 | nontypical | 120 | 236 | 0 | 0 | 178 | 0 | 0.8 | 1 | 0 | normal | No |
dim(Heart)
## [1] 303 15
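# Drop X, the row-index column from the CSV, so AHD ~ . does not use it as a predictor
Heart$X <- NULL
split <- createDataPartition(y=Heart$AHD, p = 0.5, list=FALSE)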
train <- Heart[split,]
test <- Heart[-split,]
trees <- tree(AHD ~., train)
plot(trees)
This is a fairly complex tree. Let's see whether a pruned version improves the fit, using cross-validation with the misclassification error as the scoring criterion.
cv.trees <- cv.tree(trees, FUN=prune.misclass)
plot(cv.trees)
cv.trees
## $size
## [1] 16 9 5 3 2 1
##
## $dev
## [1] 44 45 42 41 41 81
##
## $k
## [1] -Inf 0.0 1.0 2.5 5.0 37.0
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
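For prune.misclass, the dev column is the cross-validated number of misclassifications. The lowest count (41) occurs for trees with 3 and 2 terminal nodes, with the 5-leaf tree close behind (42); these results vary from run to run, because no seed was set. The code below requests best=4; since no 4-leaf tree appears in the pruning sequence, prune.misclass returns the next largest tree, which has 5 leaves. Again, we use prune.misclass for classification settings.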
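The CV table above can also be read programmatically. The vectors below are copied from the printed cv.trees output; with the object in hand you could use cv.trees$size and cv.trees$dev directly. This sketch picks out every size that attains the minimum cross-validated misclassification count, and then the simplest of them.

```r
# Copied from the cv.trees output above (values will differ on another run)
size <- c(16, 9, 5, 3, 2, 1)
dev  <- c(44, 45, 42, 41, 41, 81)   # CV misclassification counts (prune.misclass)

best_sizes <- size[dev == min(dev)]   # all sizes attaining the minimum
best_sizes
min(best_sizes)                       # the simplest of them
```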
prune.trees <- prune.misclass(trees, best=4)
plot(prune.trees)
text(prune.trees, pretty=0)
tree.pred <- predict(prune.trees, test, type='class')
confusionMatrix(tree.pred, test$AHD)
## Confusion Matrix and Statistics
##
## Reference
## Prediction No Yes
## No 72 24
## Yes 10 45
##
## Accuracy : 0.775
## 95% CI : (0.7, 0.839)
## No Information Rate : 0.543
## P-Value [Acc > NIR] : 2.86e-09
##
## Kappa : 0.539
## Mcnemar's Test P-Value : 0.0258
##
## Sensitivity : 0.878
## Specificity : 0.652
## Pos Pred Value : 0.750
## Neg Pred Value : 0.818
## Prevalence : 0.543
## Detection Rate : 0.477
## Detection Prevalence : 0.636
## Balanced Accuracy : 0.765
##
## 'Positive' Class : No
##
Here we obtain about 77.5% test accuracy, well above the no-information rate of 54.3%. Pretty sweet.
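The headline statistics in the confusionMatrix() output follow directly from the four counts. This base R sketch recomputes them from the counts copied above; caret treats No as the positive class here because it is the first factor level.

```r
# Counts copied from the confusion matrix above
# (rows = predicted class, columns = reference class; filled column by column)
cm <- matrix(c(72, 10, 24, 45), nrow = 2,
             dimnames = list(Prediction = c("No", "Yes"),
                             Reference  = c("No", "Yes")))

accuracy    <- sum(diag(cm)) / sum(cm)             # (72 + 45) / 151
sensitivity <- cm["No", "No"] / sum(cm[, "No"])    # 'No' is the positive class
specificity <- cm["Yes", "Yes"] / sum(cm[, "Yes"])
round(c(accuracy = accuracy, sensitivity = sensitivity, specificity = specificity), 3)
```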
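Something to note here: in the unpruned tree there is a split whose two leaves yield the same predicted value (Yes and Yes). So why is the split done at all? Because it increases node purity. Even though both leaves predict Yes, one of them can be much purer than the other, so we can be more confident about the prediction for a test observation that lands in it. Splits like this appear because trees are grown using purity-sensitive measures such as the Gini index or cross-entropy rather than the classification error rate.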
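A small numeric example shows why such a split can be worthwhile. In this hypothetical node (counts made up for illustration), splitting 10 Yes / 2 No observations into a pure 7 Yes leaf and a mixed 3 Yes / 2 No leaf leaves both leaves predicting Yes and does not change the classification error rate, but it does reduce the size-weighted Gini index. gini(), misclass() and split_impurity() are hypothetical helpers; by default tree() splits on deviance rather than Gini, but deviance decreases for this split too.

```r
gini     <- function(p) sum(p * (1 - p))
misclass <- function(p) 1 - max(p)

# Hypothetical node with 10 "Yes" and 2 "No" observations
parent <- c(Yes = 10, No = 2)
left   <- c(Yes = 7,  No = 0)   # pure leaf
right  <- c(Yes = 3,  No = 2)   # mixed leaf; both leaves still predict "Yes"

# Impurity after a split: child impurities weighted by child size
split_impurity <- function(f, l, r) {
  n <- sum(l) + sum(r)
  sum(l) / n * f(l / sum(l)) + sum(r) / n * f(r / sum(r))
}

c(gini_before     = gini(parent / sum(parent)),
  gini_after      = split_impurity(gini, left, right),
  misclass_before = misclass(parent / sum(parent)),
  misclass_after  = split_impurity(misclass, left, right))
```

The Gini index drops from about 0.278 to 0.2, while the classification error rate stays at 1/6, which is exactly the situation where a purity-based criterion keeps a split that an error-rate criterion would ignore.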
The best model always depends on the problem at hand. If the relationship can be approximated by a linear model, then linear regression will likely dominate. If instead we have a complex, highly non-linear relationship between the features and the response, a decision tree may outperform the classical approaches. Sometimes, too, interpretability matters more than test error rate alone.
Advantages:
Trees are easy to explain, even more so than linear regression.
More closely mirrors human decision making.
Easily displayed graphically.
Can handle qualitative predictors without dummy variables.
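Disadvantages:
Trees generally do not have the same level of predictive accuracy as some other regression and classification approaches.
Trees can be non-robust: a small change in the data can cause a large change in the final estimated tree.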
library(rpart.plot) # use prp() to make cleaner plot with caret
## Loading required package: rpart
library(ISLR)
attach(Carseats)
kable(head(Carseats))
Sales | CompPrice | Income | Advertising | Population | Price | ShelveLoc | Age | Education | Urban | US |
---|---|---|---|---|---|---|---|---|---|---|
9.50 | 138 | 73 | 11 | 276 | 120 | Bad | 42 | 17 | Yes | Yes |
11.22 | 111 | 48 | 16 | 260 | 83 | Good | 65 | 10 | Yes | Yes |
10.06 | 113 | 35 | 10 | 269 | 80 | Medium | 59 | 12 | Yes | Yes |
7.40 | 117 | 100 | 4 | 466 | 97 | Medium | 55 | 14 | Yes | Yes |
4.15 | 141 | 64 | 3 | 340 | 128 | Bad | 38 | 13 | Yes | No |
10.81 | 124 | 113 | 13 | 501 | 72 | Bad | 78 | 16 | No | Yes |
# Change Sales to a qualitative variable by splitting it on the median.
Carseats$Sales <- ifelse(Carseats$Sales <= median(Carseats$Sales), 'Low', 'High')
Carseats$Sales <- factor(Carseats$Sales)
Carseats<-na.omit(Carseats)
# Split data into training and test sets
set.seed(111)
split <- createDataPartition(y=Carseats$Sales, p=0.6, list=FALSE)
train <- Carseats[split,]
test <- Carseats[-split,]
sales.tree <- tree(Sales ~., data=train)
summary(sales.tree)
##
## Classification tree:
## tree(formula = Sales ~ ., data = train)
## Variables actually used in tree construction:
## [1] "Price" "CompPrice" "Age" "Income" "ShelveLoc"
## [6] "Advertising"
## Number of terminal nodes: 19
## Residual mean deviance: 0.414 = 92 / 222
## Misclassification error rate: 0.0996 = 24 / 241
Here we see the training misclassification error rate is about 10% (24 / 241). We use plot() to display the tree structure and text() to display the node labels.
plot(sales.tree)
text(sales.tree, pretty=0)
Let’s see how the full tree handles the test data.
sales.pred <- predict(sales.tree, test, type='class')
confusionMatrix(sales.pred, test$Sales)
## Confusion Matrix and Statistics
##
## Reference
## Prediction High Low
## High 56 12
## Low 23 68
##
## Accuracy : 0.78
## 95% CI : (0.707, 0.842)
## No Information Rate : 0.503
## P-Value [Acc > NIR] : 6.28e-13
##
## Kappa : 0.559
## Mcnemar's Test P-Value : 0.091
##
## Sensitivity : 0.709
## Specificity : 0.850
## Pos Pred Value : 0.824
## Neg Pred Value : 0.747
## Prevalence : 0.497
## Detection Rate : 0.352
## Detection Prevalence : 0.428
## Balanced Accuracy : 0.779
##
## 'Positive' Class : High
##
A test accuracy of about 78% (a test error rate of about 22%) is pretty good! But we could potentially improve it by using cross-validation to decide how far to prune the tree.
set.seed(12)
cv.sales.tree <- cv.tree(sales.tree, FUN=prune.misclass)
plot(cv.sales.tree)
Here we see that the simplest tree with the lowest cross-validated misclassification error has 4 leaves, so we prune the tree to a 4-leaf model.
prune.sales.tree <- prune.misclass(sales.tree, best=4)
prune.pred <- predict(prune.sales.tree, test, type='class')
plot(prune.sales.tree)
text(prune.sales.tree, pretty=0)
confusionMatrix(prune.pred, test$Sales)
## Confusion Matrix and Statistics
##
## Reference
## Prediction High Low
## High 52 20
## Low 27 60
##
## Accuracy : 0.704
## 95% CI : (0.627, 0.774)
## No Information Rate : 0.503
## P-Value [Acc > NIR] : 2.02e-07
##
## Kappa : 0.408
## Mcnemar's Test P-Value : 0.381
##
## Sensitivity : 0.658
## Specificity : 0.750
## Pos Pred Value : 0.722
## Neg Pred Value : 0.690
## Prevalence : 0.497
## Detection Rate : 0.327
## Detection Prevalence : 0.453
## Balanced Accuracy : 0.704
##
## 'Positive' Class : High
##
This doesn't improve our classification: test accuracy actually drops from 78% to about 70%. However, we have greatly simplified the model, from 19 terminal nodes to 4.
# Specify cross validation tuning
fitControl <- trainControl(method = "cv",
number = 10,
classProbs=TRUE,
summaryFunction=twoClassSummary)
set.seed(123123)
sales.tree <- train(Sales ~., train,
trControl=fitControl,
metric='ROC',
method='rpart')
## Loading required package: pROC
## Type 'citation("pROC")' for a citation.
##
## Attaching package: 'pROC'
##
## The following objects are masked from 'package:stats':
##
## cov, smooth, var
sales.tree
## CART
##
## 241 samples
## 10 predictors
## 2 classes: 'High', 'Low'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
##
## Summary of sample sizes: 217, 217, 216, 217, 217, 217, ...
##
## Resampling results across tuning parameters:
##
## cp ROC Sens Spec ROC SD Sens SD Spec SD
## 0.06 0.7 0.7 0.7 0.1 0.2 0.1
## 0.1 0.6 0.7 0.6 0.2 0.2 0.2
## 0.4 0.5 0.3 0.8 0.09 0.3 0.3
##
## ROC was used to select the optimal model using the largest value.
## The final value used for the model was cp = 0.06.
prp(sales.tree$finalModel, extra=2)
sales.tree.pred <- predict(sales.tree, test)
confusionMatrix(sales.tree.pred, test$Sales)
## Confusion Matrix and Statistics
##
## Reference
## Prediction High Low
## High 56 21
## Low 23 59
##
## Accuracy : 0.723
## 95% CI : (0.647, 0.791)
## No Information Rate : 0.503
## P-Value [Acc > NIR] : 1.3e-08
##
## Kappa : 0.446
## Mcnemar's Test P-Value : 0.88
##
## Sensitivity : 0.709
## Specificity : 0.738
## Pos Pred Value : 0.727
## Neg Pred Value : 0.720
## Prevalence : 0.497
## Detection Rate : 0.352
## Detection Prevalence : 0.484
## Balanced Accuracy : 0.723
##
## 'Positive' Class : High
##
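Caret, tuning rpart's complexity parameter cp by 10-fold cross-validated ROC, selects cp = 0.06 and a simpler tree than the full model. Its test accuracy of 72.3% is a small reduction from the full tree's 78%, but slightly better than our 4-leaf pruned tree (70.4%).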