
Data Sources:

Heavily borrowed from:

require(knitr)
## Loading required package: knitr

Overview

This notebook is about tree-based methods for regression and classification. These methods stratify, or segment, the predictor space into a number of simpler regions. To make a prediction for a given observation, we typically use the mean (regression) or the mode (classification) of the training observations in the region to which it belongs. Because the splitting rules used to segment the predictor space can be summarized in a tree, these approaches are known as decision tree methods.

Tree methods are simple and useful for interpretation; however, they are typically not competitive with the best supervised learning methods in terms of prediction accuracy. Hence, in subsequent notebooks we also introduce bagging, random forests, and boosting. Each of these approaches involves producing multiple trees, which are then combined to yield a single consensus prediction. We will see that combining a large number of trees can result in dramatic improvements in prediction accuracy, at the expense of some loss in interpretability.

Decision trees can be applied to both regression and classification problems. We will first consider regression.

Decision Tree Basics: Regression

We begin with a simple example:

We use the Hitters data from the ISLR package to predict a baseball player's Salary based on the number of Years he has played in the major leagues and the number of Hits he made in the previous year.

The result is a series of splitting rules. The first split segments the data into players with Years < 4.5, who go to the left branch, and the remainder, who go to the right. The predicted salary for the players in a branch is the mean response value of the training players in that branch. Players with Years >= 4.5 are assigned to the right branch and then further subdivided by Hits. Players with Years < 4.5 form the first region, players with Years >= 4.5 and Hits < 118 fall into the second region, and players with Years >= 4.5 and Hits >= 118 fall into the third region, each with its own predicted salary. These regions are the terminal nodes, or leaves, of the tree. We might interpret this tree as saying that Years is the most important factor in determining Salary, and that players with less experience earn lower salaries. Given that a player is less experienced, the number of Hits he made last year seems to play little role in his Salary, but among more experienced players, those with more Hits tend to earn more. If we code this model, we see that the relationship ends up being slightly more complicated.
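The splitting rules described above can be written out explicitly. The base R sketch below is only an illustration: the helper region() is hypothetical, the cutpoints 4.5 and 118 come from the text, and the small training vectors are made up (not taken from Hitters), with the response standing in for log(Salary).

```r
# Splitting rules of the three-region tree (hypothetical helper;
# cutpoints 4.5 and 118 taken from the text)
region <- function(years, hits) {
  ifelse(years < 4.5, "R1",       # less experienced players
         ifelse(hits < 118, "R2", # experienced, fewer hits last year
                "R3"))            # experienced, more hits last year
}

# Hypothetical training data, for illustration only (not from Hitters)
years  <- c(1, 3, 5, 8, 10, 12)
hits   <- c(90, 150, 100, 80, 130, 160)
salary <- c(5.0, 5.2, 6.0, 6.2, 6.8, 7.0)  # stands in for log(Salary)

# Each region predicts the mean response of its training observations
region_means <- tapply(salary, region(years, hits), mean)
print(region_means)

# Predict for a new player with 6 years and 140 hits (falls in R3)
region_means[region(6, 140)]
```

Note that the player with 3 years and 150 hits lands in R1: once Years < 4.5, Hits is never consulted.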

library(tree)
library(ISLR)
# Remove players with missing values (Salary is NA for some of them)
Hitters <- na.omit(Hitters)
# Salary is right-skewed, so plot it before and after a log transform
hist(Hitters$Salary)

[Figure: histogram of Salary]

Hitters$Salary <- log(Hitters$Salary)
hist(Hitters$Salary)

[Figure: histogram of log(Salary)]

tree.fit <- tree(Salary~Hits+Years, data=Hitters)
summary(tree.fit)
## 
## Regression tree:
## tree(formula = Salary ~ Hits + Years, data = Hitters)
## Number of terminal nodes:  8 
## Residual mean deviance:  0.271 = 69.1 / 255 
## Distribution of residuals:
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
## -2.2400 -0.2980 -0.0365  0.0000  0.3230  2.1500
plot(tree.fit)

[Figure: regression tree for log(Salary) using Hits and Years]

Now we discuss how a regression tree is built by stratifying the feature space. In general, there are four steps.

  1. Find the variable and split point that best separate the response, i.e. the split that yields the lowest RSS.

  2. Divide the data into two nodes at that split.

  3. Within each leaf, find the best variable/split that separates the outcomes.

  4. Continue until the groups are too small or sufficiently ‘pure’.

The goal is to find the regions R1, ..., RJ that minimize the RSS. However, it is computationally infeasible to consider every possible partition of the feature space into J regions. For this reason we take a top-down, greedy approach known as recursive binary splitting. It is top-down because it begins at the top of the tree, where all observations belong to a single region. It is greedy because at each step of the tree-building process, the best split is chosen at that particular step, rather than looking ahead and picking a split that will lead to a better tree in some future step.
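The four steps and the greedy strategy above can be sketched in base R. This is a minimal illustration under stated assumptions, not the algorithm used inside the tree package: all predictors are assumed numeric, candidate cutpoints are midpoints between adjacent distinct values, and a node stops splitting when it has fewer than min_size observations or a constant response. The helpers rss(), best_split(), grow_tree() and predict_tree() and the toy data are hypothetical.

```r
# Residual sum of squares of a node around its mean
rss <- function(y) sum((y - mean(y))^2)

# Greedy step: try every numeric predictor and every candidate cutpoint,
# keeping the split with the lowest total RSS in the two child nodes
best_split <- function(X, y) {
  best <- list(rss = Inf, var = NA, cut = NA)
  for (j in names(X)) {
    vals <- sort(unique(X[[j]]))
    if (length(vals) < 2) next
    cuts <- (head(vals, -1) + tail(vals, -1)) / 2  # midpoints between values
    for (s in cuts) {
      left <- X[[j]] < s
      total <- rss(y[left]) + rss(y[!left])
      if (total < best$rss) best <- list(rss = total, var = j, cut = s)
    }
  }
  best
}

# Top-down recursion: split, then repeat within each child node
grow_tree <- function(X, y, min_size = 5) {
  if (length(y) < min_size || length(unique(y)) == 1) {
    return(list(leaf = TRUE, pred = mean(y)))  # too small or already pure
  }
  sp <- best_split(X, y)
  if (is.na(sp$var)) return(list(leaf = TRUE, pred = mean(y)))
  left <- X[[sp$var]] < sp$cut
  list(leaf = FALSE, var = sp$var, cut = sp$cut,
       left  = grow_tree(X[left, , drop = FALSE],  y[left],  min_size),
       right = grow_tree(X[!left, , drop = FALSE], y[!left], min_size))
}

# Predict by walking down to a leaf and returning its training mean
predict_tree <- function(node, x) {
  while (!node$leaf) {
    node <- if (x[[node$var]] < node$cut) node$left else node$right
  }
  node$pred
}

# Toy data (hypothetical): the response depends only on Years
X <- data.frame(Years = c(1, 2, 3, 6, 7, 8),
                Hits  = c(10, 40, 20, 30, 50, 60))
y <- c(1, 1, 1, 3, 3, 3)

best_split(X, y)$var                            # "Years"
best_split(X, y)$cut                            # 4.5
fit <- grow_tree(X, y, min_size = 2)
predict_tree(fit, list(Years = 2, Hits = 99))   # 1
predict_tree(fit, list(Years = 10, Hits = 5))   # 3
```

Each call to best_split() only looks at the current node, which is exactly what makes the procedure greedy.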

Once all the regions have been created, we predict the response for a given test observation using the mean of the training observations in the region to which it belongs.

Tree Pruning

While the model above can produce good predictions on the training data, basic tree methods are likely to overfit, leading to poor test performance, because the resulting trees tend to be too complex. A smaller tree with fewer splits often leads to lower variance, easier interpretation and lower test error, at the cost of a little bias. One possible way to achieve this is to grow the tree only so long as the decrease in RSS due to each split exceeds some (high) threshold. While this will certainly reduce tree size, it is too short-sighted: a seemingly worthless split early on in a tree can be followed by a very good split later.

Therefore, a better strategy is to grow a large tree and then prune it back to obtain a subtree. Intuitively, our goal is to select the subtree that leads to the lowest test error rate. We would normally estimate the test error using cross-validation, but doing so for every possible subtree is too cumbersome, since there is an extremely large number of possible subtrees.

Cost complexity pruning, also known as weakest link pruning, gives us a way to remedy this problem. Rather than considering every possible subtree, we consider a sequence of trees indexed by a nonnegative tuning parameter alpha.
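For reference, the standard cost-complexity criterion for a regression tree is written below. For each value of alpha there is a subtree T of the large tree T_0 that minimizes it, where |T| is the number of terminal nodes of T, R_m is the region corresponding to the m-th terminal node, and \hat{y}_{R_m} is the mean training response in R_m. When alpha = 0 the minimizer is T_0 itself; as alpha grows, each extra terminal node costs more, so the minimizing subtree gets smaller.

```latex
\sum_{m=1}^{|T|} \sum_{i:\, x_i \in R_m} \left( y_i - \hat{y}_{R_m} \right)^2 + \alpha \, |T|
```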

Revised steps for building a regression tree

  1. Use recursive binary splitting to grow a large tree based on training data, stopping only when each terminal node has fewer than some minimum number of observations.

  2. Apply cost complexity pruning to the large tree in order to obtain a sequence of best subtrees, as a function of alpha.

  3. Use k-fold cross-validation to choose alpha.

  4. Return the subtree from step 2 that corresponds to the chosen value of alpha.
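Steps 2 and 4 amount to minimizing RSS + alpha * |T| over the nested sequence of subtrees. The sketch below assumes that sequence has already been computed and uses hypothetical sizes and RSS values (not from Hitters); it only shows how a larger alpha selects a smaller subtree. In the tree package, the k component returned by cv.tree() plays the role of alpha, and prune.tree() accepts such a value through its k argument.

```r
# Hypothetical pruning sequence: number of terminal nodes |T| and training RSS
subtree_size <- c(1, 2, 4, 8)
subtree_rss  <- c(100, 60, 45, 40)

# Pick the subtree minimizing the cost-complexity criterion RSS + alpha * |T|
choose_subtree <- function(alpha, size = subtree_size, rss = subtree_rss) {
  cost <- rss + alpha * size
  size[which.min(cost)]
}

# Larger alpha penalizes terminal nodes more heavily, giving smaller trees
chosen <- sapply(c(0, 5, 30, 50), choose_subtree)
chosen  # 8 4 2 1
```

Cross-validation (step 3) would then pick one alpha from a grid like this by comparing held-out error rather than training RSS.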

library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
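# createDataPartition() samples at random and no seed is set here, so the split
# (and the results below) will differ from run to run; call set.seed() first
# for reproducible results
split <- createDataPartition(y=Hitters$Salary, p=0.5, list=FALSE)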

train <- Hitters[split,]
test <- Hitters[-split,]

#Create tree model
trees <- tree(Salary~., train)
plot(trees)
text(trees, pretty=0)

[Figure: full regression tree fit to the training set, with node labels]

#Cross validate to see whether pruning the tree will improve performance
cv.trees <- cv.tree(trees)
plot(cv.trees)

[Figure: cross-validated deviance versus tree size]

The cross-validated deviance appears to be lowest for a tree with 7 terminal nodes. Pruning to that size, however, barely simplifies the model, so we instead choose a smaller size at which the improvement in deviance plateaus: around 4 terminal nodes. Note that the best argument of prune.tree() counts terminal nodes, not splits.

prune.trees <- prune.tree(trees, best=4)
plot(prune.trees)
text(prune.trees, pretty=0)

[Figure: pruned regression tree with 4 terminal nodes]

Use the pruned tree to make predictions on the test set.

yhat <- predict(prune.trees, test)
plot(yhat, test$Salary)
abline(0,1)

[Figure: predicted (x-axis) versus observed (y-axis) log(Salary) on the test set, with the line y = x]

mean((yhat - test$Salary)^2)
## [1] 0.3531
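Because Salary was log-transformed, this test MSE is on the log scale. As a rough aid to interpretation (a sketch using only the printed value), its square root gives the typical size of a prediction error in log units, and exponentiating that gives an approximate multiplicative error on the original Salary scale.

```r
test_mse  <- 0.3531           # test MSE printed above (log(Salary) scale)
test_rmse <- sqrt(test_mse)   # typical error in log units, about 0.594
mult_err  <- exp(test_rmse)   # predictions are roughly off by a factor of about 1.8
round(c(rmse = test_rmse, factor = mult_err), 3)
```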

Classification Trees

Classification trees are very similar to regression trees, except that they are used to predict a qualitative response rather than a quantitative one. For a regression tree, the predicted response for an observation is given by the mean response of the training observations in the same terminal node. In contrast, for a classification tree, we predict that each observation belongs to the most commonly occurring class of training observations in the region to which it belongs. When interpreting the results of a classification tree, we are often interested not only in the predicted class for each terminal node, but also in the class proportions among the training observations in that region.

To grow a classification tree, we use the same recursive binary splitting, but RSS can no longer be used as the splitting criterion. A natural alternative is the classification error rate: the fraction of training observations in a region that do not belong to its most common class. While intuitive, it turns out that this measure is not sensitive enough for tree-growing.

In practice, two other measures are preferable, though they are quite similar numerically:

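Gini index: a measure of the total variance across the K classes. It takes a small value when every class proportion in a node is close to zero or one, which is why it is referred to as a measure of node purity.

Cross-entropy: like the Gini index, it takes on a value near zero if the class proportions in the node are all near zero or one.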
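Writing p_mk for the proportion of training observations in the m-th region that belong to class k, the classification error rate is E = 1 - max_k p_mk, the Gini index is G = sum_k p_mk (1 - p_mk), and cross-entropy is D = -sum_k p_mk log(p_mk). The base R sketch below computes all three for a single node's vector of class proportions; the function names are hypothetical.

```r
# Node impurity measures for a vector p of class proportions (summing to 1)
class_error   <- function(p) 1 - max(p)          # E = 1 - max_k p_k
gini          <- function(p) sum(p * (1 - p))    # G = sum_k p_k (1 - p_k)
cross_entropy <- function(p) {
  p <- p[p > 0]                                  # use the convention 0 * log(0) = 0
  -sum(p * log(p))                               # D = -sum_k p_k log(p_k)
}

# A pure node scores 0 on every measure; a 50/50 two-class node scores highest
nodes <- list(pure = c(1, 0), skewed = c(0.9, 0.1), mixed = c(0.5, 0.5))
impurity <- sapply(nodes, function(p)
  c(error = class_error(p), gini = gini(p), entropy = cross_entropy(p)))
round(impurity, 3)
```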

The Gini index or cross-entropy is typically used to evaluate the quality of a split when growing the tree. Any of the three measures can be used when pruning, but the classification error rate is preferable if the prediction accuracy of the final pruned tree is the goal.

To demonstrate this we will use the Heart dataset. These data contain a binary outcome variable AHD for 303 patients who presented with chest pain. The outcomes are coded as Yes or No for presence of heart disease.

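# stringsAsFactors = TRUE keeps AHD (and the other text columns) as factors;
# read.csv() defaults to FALSE since R 4.0.0, and tree() needs a factor
# response to fit a classification tree
Heart <- read.csv('http://www-bcf.usc.edu/~gareth/ISL/Heart.csv', stringsAsFactors = TRUE)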
kable(head(Heart))
X Age Sex ChestPain RestBP Chol Fbs RestECG MaxHR ExAng Oldpeak Slope Ca Thal AHD
1 63 1 typical 145 233 1 2 150 0 2.3 3 0 fixed No
2 67 1 asymptomatic 160 286 0 2 108 1 1.5 2 3 normal Yes
3 67 1 asymptomatic 120 229 0 2 129 1 2.6 2 2 reversable Yes
4 37 1 nonanginal 130 250 0 0 187 0 3.5 3 0 normal No
5 41 0 nontypical 130 204 0 2 172 0 1.4 1 0 normal No
6 56 1 nontypical 120 236 0 0 178 0 0.8 1 0 normal No
dim(Heart)
## [1] 303 15

# Drop the row-index column X so that tree() cannot split on it
Heart$X <- NULL
split <- createDataPartition(y=Heart$AHD, p = 0.5, list=FALSE)
train <- Heart[split,]
test <- Heart[-split,]

trees <- tree(AHD ~., train)
plot(trees)

[Figure: full classification tree for AHD]

So far this is a pretty complex tree. Let's check whether cross-validation, using the misclassification rate as the scoring method, identifies a pruned version with a better fit.

cv.trees <- cv.tree(trees, FUN=prune.misclass)
plot(cv.trees)

[Figure: cross-validated misclassifications versus tree size]

cv.trees
## $size
## [1] 16  9  5  3  2  1
## 
## $dev
## [1] 44 45 42 41 41 81
## 
## $k
## [1] -Inf  0.0  1.0  2.5  5.0 37.0
## 
## $method
## [1] "misclass"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"
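The cv.tree() result is a list, so the tree size with the lowest cross-validated error can also be read off programmatically. The sketch below copies the $size and $dev values printed above into plain vectors; smallest_near_min() is a hypothetical helper implementing one simple rule (the smallest tree whose error is within tol of the minimum), not a tree package function.

```r
# $size and $dev copied from the cv.tree() output above;
# with FUN = prune.misclass, dev counts misclassified observations
cv_size <- c(16, 9, 5, 3, 2, 1)
cv_dev  <- c(44, 45, 42, 41, 41, 81)

# which.min() returns the first minimum; sizes run from largest to smallest,
# so this is the larger of the tied trees
best_size <- cv_size[which.min(cv_dev)]

# Hypothetical rule: smallest tree whose CV error is within tol of the minimum
smallest_near_min <- function(size, dev, tol = 0) {
  min(size[dev <= min(dev) + tol])
}
simple_size <- smallest_near_min(cv_size, cv_dev)

c(best = best_size, simplest = simple_size)   # best 3, simplest 2
```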

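In the output above, the lowest cross-validated misclassification count (41) occurs for trees with 3 and 2 terminal nodes, with the 5-node tree close behind (42). Below we request best=4; since the sequence contains no 4-node tree, prune.misclass() returns the next largest tree in the sequence, the 5-node tree. Let's see what this tree looks like. Again we use prune.misclass for classification settings.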

prune.trees <- prune.misclass(trees, best=4)
plot(prune.trees)
text(prune.trees, pretty=0)

[Figure: pruned classification tree for AHD]

tree.pred <- predict(prune.trees, test, type='class')
confusionMatrix(tree.pred, test$AHD)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction No Yes
##        No  72  24
##        Yes 10  45
##                                       
##                Accuracy : 0.775       
##                  95% CI : (0.7, 0.839)
##     No Information Rate : 0.543       
##     P-Value [Acc > NIR] : 2.86e-09    
##                                       
##                   Kappa : 0.539       
##  Mcnemar's Test P-Value : 0.0258      
##                                       
##             Sensitivity : 0.878       
##             Specificity : 0.652       
##          Pos Pred Value : 0.750       
##          Neg Pred Value : 0.818       
##              Prevalence : 0.543       
##          Detection Rate : 0.477       
##    Detection Prevalence : 0.636       
##       Balanced Accuracy : 0.765       
##                                       
##        'Positive' Class : No          
## 

Here we obtain about 78% test accuracy. Pretty sweet.

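Something to note here: in the unpruned tree there is a split whose two child nodes have the same predicted value (Yes and Yes). So why is the split made at all? Because it increases node purity: the class proportions in the two children differ, so we can be more confident about the predictions in the purer child. The classification error rate does not reward such a split, since the number of misclassified training observations is unchanged, which is one reason the Gini index or cross-entropy is used to grow the tree.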
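A small numeric check of this point, using hypothetical counts rather than the Heart data: a node with 10 Yes and 3 No is split into children with (6 Yes, 0 No) and (4 Yes, 3 No). Both children still predict Yes, so the weighted classification error is unchanged, while the weighted Gini index drops. The helper names are illustrative.

```r
# Impurity measures computed from the class counts in a node
gini        <- function(counts) { p <- counts / sum(counts); sum(p * (1 - p)) }
class_error <- function(counts) 1 - max(counts) / sum(counts)

# Impurity of a split: child impurities weighted by their share of observations
split_impurity <- function(f, left, right) {
  n <- sum(left) + sum(right)
  sum(left) / n * f(left) + sum(right) / n * f(right)
}

# Hypothetical counts; both children have Yes as their majority class
parent <- c(Yes = 10, No = 3)
left   <- c(Yes = 6,  No = 0)
right  <- c(Yes = 4,  No = 3)

err <- c(parent = class_error(parent), split = split_impurity(class_error, left, right))
gin <- c(parent = gini(parent),        split = split_impurity(gini, left, right))
round(rbind(error = err, gini = gin), 3)
# error: 0.231 -> 0.231 (no change); gini: 0.355 -> 0.264 (purer)
```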

Trees vs. Linear Models

Which model is best always depends on the problem at hand. If the relationship between the features and the response is well approximated by a linear model, then linear regression will likely outperform a regression tree. If instead the relationship is highly non-linear and complex, a decision tree may outperform the classical approaches. Improved interpretability may also sometimes be valued above test error alone.

Pros / Cons of Trees

Advantages:

Disadvantages:

Extra Examples

library(rpart.plot) # use prp() to make cleaner plot with caret
## Loading required package: rpart
library(ISLR)
attach(Carseats)
kable(head(Carseats))
Sales CompPrice Income Advertising Population Price ShelveLoc Age Education Urban US
9.50 138 73 11 276 120 Bad 42 17 Yes Yes
11.22 111 48 16 260 83 Good 65 10 Yes Yes
10.06 113 35 10 269 80 Medium 59 12 Yes Yes
7.40 117 100 4 466 97 Medium 55 14 Yes Yes
4.15 141 64 3 340 128 Bad 38 13 Yes No
10.81 124 113 13 501 72 Bad 78 16 No Yes
# Change Sales to a qualitative variable by splitting it on the median.
Carseats$Sales <- factor(ifelse(Carseats$Sales <= median(Carseats$Sales), 'Low', 'High'))

Carseats<-na.omit(Carseats)
# Split data into training and test sets
set.seed(111)
split <- createDataPartition(y=Carseats$Sales, p=0.6, list=FALSE)
train <- Carseats[split,]
test <- Carseats[-split,]

sales.tree <- tree(Sales ~., data=train)
summary(sales.tree)

## 
## Classification tree:
## tree(formula = Sales ~ ., data = train)
## Variables actually used in tree construction:
## [1] "Price"       "CompPrice"   "Age"         "Income"      "ShelveLoc"
## [6] "Advertising"
## Number of terminal nodes:  19
## Residual mean deviance:  0.414 = 92 / 222
## Misclassification error rate:  0.0996 = 24 / 241

Here we see the training error rate is about 10%. We use plot() to display the tree structure and text() to display the node labels.

plot(sales.tree)
text(sales.tree, pretty=0)

[Figure: full classification tree for Sales]

Let’s see how the full tree handles the test data.

sales.pred <- predict(sales.tree, test, type='class')
confusionMatrix(sales.pred, test$Sales)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction High Low
##       High   56  12
##       Low    23  68
##                                         
##                Accuracy : 0.78          
##                  95% CI : (0.707, 0.842)
##     No Information Rate : 0.503         
##     P-Value [Acc > NIR] : 6.28e-13      
##                                         
##                   Kappa : 0.559         
##  Mcnemar's Test P-Value : 0.091         
##                                         
##             Sensitivity : 0.709         
##             Specificity : 0.850         
##          Pos Pred Value : 0.824         
##          Neg Pred Value : 0.747         
##              Prevalence : 0.497         
##          Detection Rate : 0.352         
##    Detection Prevalence : 0.428         
##       Balanced Accuracy : 0.779         
##                                         
##        'Positive' Class : High          
## 

A test accuracy of about 78%, i.e. a test error rate of about 22%, is pretty good! But we could potentially improve it by pruning the tree with cross-validation.
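The summary statistics above follow directly from the confusion matrix. As a quick check, the sketch below re-enters the printed counts by hand (rows are predictions, columns are the true classes, and High is the positive class, as in the output) and recomputes accuracy, error rate, sensitivity and specificity.

```r
# Counts copied from the confusionMatrix() output above
cm <- matrix(c(56, 23, 12, 68), nrow = 2,
             dimnames = list(Prediction = c("High", "Low"),
                             Reference  = c("High", "Low")))

accuracy    <- sum(diag(cm)) / sum(cm)                 # 124 / 159
error_rate  <- 1 - accuracy                            # 35 / 159
sensitivity <- cm["High", "High"] / sum(cm[, "High"])  # 56 / 79
specificity <- cm["Low", "Low"] / sum(cm[, "Low"])     # 68 / 80
round(c(accuracy = accuracy, error = error_rate,
        sensitivity = sensitivity, specificity = specificity), 3)
```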

set.seed(12)
cv.sales.tree <- cv.tree(sales.tree, FUN=prune.misclass)
plot(cv.sales.tree)

[Figure: cross-validated misclassifications versus tree size]

Here we see that the simplest tree with the lowest misclassification error has 4 leaves. We can now prune the tree to a 4-leaf model.

prune.sales.tree <- prune.misclass(sales.tree, best=4)
prune.pred <- predict(prune.sales.tree, test, type='class')

plot(prune.sales.tree)
text(prune.sales.tree, pretty=0)

[Figure: pruned classification tree for Sales]

confusionMatrix(prune.pred, test$Sales)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction High Low
##       High   52  20
##       Low    27  60
##                                         
##                Accuracy : 0.704         
##                  95% CI : (0.627, 0.774)
##     No Information Rate : 0.503         
##     P-Value [Acc > NIR] : 2.02e-07      
##                                         
##                   Kappa : 0.408         
##  Mcnemar's Test P-Value : 0.381         
##                                         
##             Sensitivity : 0.658         
##             Specificity : 0.750         
##          Pos Pred Value : 0.722         
##          Neg Pred Value : 0.690         
##              Prevalence : 0.497         
##          Detection Rate : 0.327         
##    Detection Prevalence : 0.453         
##       Balanced Accuracy : 0.704         
##                                         
##        'Positive' Class : High          
## 

This doesn't improve our classification (test accuracy falls from about 78% to about 70%), but we have greatly simplified the model, from 19 leaves to 4.

# Specify cross-validation tuning
fitControl <- trainControl(method = "cv",
                           number = 10,
                           classProbs = TRUE,
                           summaryFunction = twoClassSummary)

set.seed(123123)
sales.tree <- train(Sales ~ ., data = train,
                    trControl = fitControl,
                    metric = 'ROC',
                    method = 'rpart')
## Loading required package: pROC
## Type 'citation("pROC")' for a citation.
## 
## Attaching package: 'pROC'
## 
## The following objects are masked from 'package:stats':
## 
##     cov, smooth, var
sales.tree
## CART 
## 
## 241 samples
##  10 predictors
##   2 classes: 'High', 'Low' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## 
## Summary of sample sizes: 217, 217, 216, 217, 217, 217, ... 
## 
## Resampling results across tuning parameters:
## 
##   cp    ROC  Sens  Spec  ROC SD  Sens SD  Spec SD
##   0.06  0.7  0.7   0.7   0.1     0.2      0.1    
##   0.1   0.6  0.7   0.6   0.2     0.2      0.2    
##   0.4   0.5  0.3   0.8   0.09    0.3      0.3    
## 
## ROC was used to select the optimal model using  the largest value.
## The final value used for the model was cp = 0.06.
prp(sales.tree$finalModel, extra=2)

[Figure: final rpart tree selected by caret, drawn with prp()]

sales.tree.pred <- predict(sales.tree, test)
confusionMatrix(sales.tree.pred, test$Sales)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction High Low
##       High   56  21
##       Low    23  59
##                                         
##                Accuracy : 0.723         
##                  95% CI : (0.647, 0.791)
##     No Information Rate : 0.503         
##     P-Value [Acc > NIR] : 1.3e-08       
##                                         
##                   Kappa : 0.446         
##  Mcnemar's Test P-Value : 0.88          
##                                         
##             Sensitivity : 0.709         
##             Specificity : 0.738         
##          Pos Pred Value : 0.727         
##          Neg Pred Value : 0.720         
##              Prevalence : 0.497         
##          Detection Rate : 0.352         
##    Detection Prevalence : 0.484         
##       Balanced Accuracy : 0.723         
##                                         
##        'Positive' Class : High          
## 

Caret opts for an even simpler tree (cp = 0.06). Its test accuracy of about 72% is a small reduction from the full tree's 78%, and slightly better than the 4-leaf pruned tree's 70%.