Tree-based methods for classification and regression involve stratifying or segmenting the predictor space into a number of simple regions. These approaches are known as decision tree methods and can be applied to both regression and classification problems. When we are predicting a quantitative response (i.e., estimating a numeric value), we use a regression tree. When we are predicting a qualitative response (i.e., classifying an observation), we use a classification tree. Decision trees are non-parametric methods that partition the data into smaller, more “pure” or homogeneous groups called nodes. A simple way to define purity in classification is by maximizing accuracy or, equivalently, minimizing the misclassification error, so that each node ends up containing a large proportion of a single class. Each split is expressed as an if/then rule, and cases satisfying the rule are placed in the corresponding node. For each split, two determinations are made: the predictor variable used for the split, called the splitting variable, and the set of values of that predictor sent to the left and right nodes, called the split point. For the regression problem, we instead want to minimize the mean squared error (MSE).
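To make these two criteria concrete, here is a minimal sketch (using small made-up vectors, not the Carseats data analyzed below) of how node purity and node fit are measured:
# Hypothetical classification node: 8 of 10 cases are "Yes", so the node predicts "Yes"
node.class <- c(rep("Yes", 8), rep("No", 2))
mean(node.class != "Yes")         # misclassification error within the node = 0.2
# Hypothetical regression node: predict the node mean and measure the fit with the MSE
node.y <- c(5.1, 6.3, 7.2, 6.8)
mean((node.y - mean(node.y))^2)   # mean squared error within the node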
When would you use a CART model rather than a standard linear regression model (LM), a generalized linear model (GLM), or a generalized additive model (GAM)? The recursive structure of CART models is ideal for uncovering complex dependencies among predictor variables. If a response variable depends strongly on a predictor variable in a nonlinear fashion, a CART model will be better at detecting this relationship than interaction terms in LMs, GLMs, or GAMs. When there is good reason to suspect non-additive interactions among variables, or there are far too many variables to consider, try a decision tree model. As an example, we will use the Carseats data set found in the ISLR package. In Part I we will fit a classification tree model, and in Part II we will fit a regression tree model.
Load packages and look at the data.
# For decision tree model
library(rpart)
# For data visualization
library(rpart.plot)
# Contains the data
library(ISLR)
# Get the list of data sets contained in package
d <- data(package = "ISLR")
d$results[, "Item"]
## [1] "Auto" "Caravan" "Carseats" "College" "Credit" "Default"
## [7] "Hitters" "Khan" "NCI60" "OJ" "Portfolio" "Smarket"
## [13] "Wage" "Weekly"
data(Carseats)
# Get the variable names
names(Carseats)
## [1] "Sales" "CompPrice" "Income" "Advertising" "Population"
## [6] "Price" "ShelveLoc" "Age" "Education" "Urban"
## [11] "US"
dim(Carseats)
## [1] 400 11
We will use a classification tree to analyze the Carseats data set, a simulated data set containing sales of child car seats at 400 different stores. There are 400 observations and 11 variables in the data set. We are interested in predicting Sales based on the other variables in the data set. However, since Sales is a continuous variable, we need to recode it as a binary variable. This new variable, High, takes on a value of Yes if the Sales variable exceeds 8, and No otherwise. Because High is a binary variable, this is a classification problem and requires the use of a classification tree.
# Creates a new binary variable, High.
High = ifelse(Carseats$Sales <=8, "No", "Yes")
# Add High to the data set.
Carseats=data.frame(Carseats,High)
# Remove the Sales variable from the data.
Carseats.H <- Carseats[,-1]
# Code High as a factor variable
Carseats.H$High = as.factor(Carseats$High)
class(Carseats.H$High)
## [1] "factor"
A good classifier is one for which the test error rate is smallest. To properly evaluate the performance of a classification tree, we must estimate the test error rather than simply computing the training error, and then compare the test error rates of the competing models. Thus, we split the data into a training set and a test set, build the tree using the training set, and evaluate its performance on the test data. Later, we will use the predict() function for this purpose; in the case of a classification tree, the argument type = "class" instructs R to return the actual class prediction.
set.seed(234)
train = sample(1:nrow(Carseats.H), 200)
Carseats.train=Carseats.H[train,]
Carseats.test=Carseats.H[-train,]
High.test=High[-train]
First, we build a classification tree using the training set to predict High using all the remaining variables (remember that High was derived from Sales, which has been removed from the data). The tree model (see below) shows the variables that are actually used to construct the tree; the algorithm has determined that the other variables did not contribute to the predictive power of the model. From the model we can see that the most important indicator of Sales appears to be shelving location, since the first branch separates Good locations from Bad and Medium locations. The next most important indicator is Price, since the second branch separates a price greater than or equal to $92.5 from a price less than $92.5.
# The cp (complexity parameter) determines how deep the tree will grow. Here it is assigned a small value, which leaves room for a later decision on pruning. That is, we want a cp value (giving a more parsimonious tree) that minimizes the xerror (cross-validation error).
fit.tree = rpart(High ~ ., data=Carseats.train, method = "class", cp=0.008)
fit.tree
## n= 200
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 200 90 No (0.55000000 0.45000000)
## 2) ShelveLoc=Bad,Medium 154 52 No (0.66233766 0.33766234)
## 4) Price>=92.5 132 35 No (0.73484848 0.26515152)
## 8) Advertising< 13.5 106 19 No (0.82075472 0.17924528) *
## 9) Advertising>=13.5 26 10 Yes (0.38461538 0.61538462)
## 18) Age>=44 15 6 No (0.60000000 0.40000000) *
## 19) Age< 44 11 1 Yes (0.09090909 0.90909091) *
## 5) Price< 92.5 22 5 Yes (0.22727273 0.77272727) *
## 3) ShelveLoc=Good 46 8 Yes (0.17391304 0.82608696)
## 6) Price>=135 7 2 No (0.71428571 0.28571429) *
## 7) Price< 135 39 3 Yes (0.07692308 0.92307692) *
Note: The method argument depends on the type of response variable: "class" for a categorical/factor response, "anova" for a numeric response, "poisson" for count data, and "exp" for survival data. The cp argument is assigned a numeric value that determines how deep the tree will grow: the smaller the value (the closer to 0), the larger the tree. The default value is 0.01, which renders a more heavily pruned tree than the value used here.
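For reference, the same tuning values can also be passed through rpart.control(); the call below is intended to be equivalent to the fit above, with minsplit and maxdepth shown at their defaults purely to illustrate other available controls (ctrl and fit.tree.ctrl are just illustrative names):
# Equivalent specification via rpart.control(); minsplit = 20 and maxdepth = 30 are the defaults
ctrl <- rpart.control(cp = 0.008, minsplit = 20, maxdepth = 30)
fit.tree.ctrl <- rpart(High ~ ., data = Carseats.train, method = "class", control = ctrl)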
Decision Tree terminology and output items:
Root Node: Represents the entire sample population and is further divided into two or more homogeneous groups.
split: The criterion used to divide a node into two or more sub-nodes.
n: The number of observations in a node.
loss: This is the total number of rows that will be misclassified if the predicted class for the node is applied to all rows.
yval: The overall prediction for the node (here, Yes or No). For a regression tree, this is the mean response value for that subset.
yprob: The fraction of observations in that branch that take on the values Yes and No.
Pruning: The process of removing sub-nodes from a decision node. Pruning is a fix for the problem of overfitting.
Complexity: The complexity parameter (cp) is used to control whether a split contributes enough to the model fit to be worth making: a split is attempted only if it improves the fit by at least the value of cp. The best cp value is the one with the lowest cross-validation error (xerror).
Next, take a look at a plot of the tree. The tree selected uses 4 variables and contains 5 splits. If you look at the plot and at the node descriptions, you will notice that splits have occurred on the variables ShelveLoc, Price, Advertising, and Age. Using the output table (above) and the plot (below), let’s interpret the tree model. Nodes 2 and 3 were formed by splitting node 1, the root node, on the predictor variable ShelveLoc. The split point is ShelveLoc=Bad,Medium; that is, node 2 consists of all rows with ShelveLoc equal to Bad or Medium, and node 3 consists of all rows with ShelveLoc equal to Good. The predicted class for node 2 is No, where No indicates sales less than or equal to $8k. The expected loss is 52 - this is the total number of rows that will be misclassified if the predicted class for the node is applied to all rows. Specifically, out of the 154 cases in node 2, 52 (34%) will be misclassified and 102 (66%) will be classified correctly. The predicted class for node 3 is Yes, where Yes indicates sales greater than $8k. The expected loss is 8 - that is, 8 rows will be misclassified if the predicted class for the node is applied to all rows. Again, out of the 46 cases in node 3, 8 (17%) will be misclassified and 38 (83%) will be classified correctly.
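These figures are simple ratios of the loss to the node size and can be reproduced directly from the training data and the counts in the rpart output above; a quick check:
# Root node: class counts of High in the training set (n = 200, loss = 90, yprob = 0.55/0.45)
table(Carseats.train$High)
# Node 2: 52 of 154 rows misclassified; node 3: 8 of 46 rows misclassified
c(node2 = 52/154, node3 = 8/46)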
# Visualizing the unpruned tree
rpart.plot(fit.tree)
# Checking the order of variable importance
fit.tree$variable.importance
## ShelveLoc Price Advertising Age CompPrice Income
## 16.8994918 14.5382371 8.5411152 3.2895105 1.7804529 1.4952320
## Education Population
## 0.5980928 0.2990464
We then use the decision tree model that was built using the training data to predict the response variable, High, in the test data set; here we use the predict() function for this purpose. It is important to note that while this model may produce good predictions on the training set, it is likely to overfit the data, leading to poor test set performance. That is, the resulting tree may be too complex. A smaller tree with fewer splits might lead to lower variance and better interpretation at the cost of a little bias. As a general rule, as we use more flexible methods, the variance will increase and the bias will decrease. Pruning reduces model flexibility. Later we will consider whether pruning this tree improves performance as measured by the test set error rate; a lower test error rate means better performance.
pred.tree = predict(fit.tree, Carseats.test, type = "class")
To determine how well the decision-tree model performs on the test set, we run a cross-tabulation of the predicted versus actual values. From this, we see that the misclassification error rate is 24%, calculated as (32+16)/200. This means that the model leads to correct predictions for 76% of the observations in the test data set; this is calculated as (110+42)/200.
table(pred.tree,High.test)
## High.test
## pred.tree No Yes
## No 110 32
## Yes 16 42
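These rates can also be computed directly from the confusion matrix rather than by hand; a minimal sketch that rebuilds the same table and reads the rates off its diagonal (conf.mat is just an illustrative name):
# Misclassification error rate and accuracy from the confusion matrix
conf.mat <- table(pred.tree, High.test)
1 - sum(diag(conf.mat)) / sum(conf.mat)   # error rate, (32 + 16) / 200 = 0.24
sum(diag(conf.mat)) / sum(conf.mat)       # accuracy, (110 + 42) / 200 = 0.76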
Next, we consider whether pruning the tree might lead to improved results, that is, whether pruning will produce a lower misclassification error rate and, therefore, a higher percentage of correct predictions. Pruning selects the cp (complexity parameter) value associated with a shorter tree that minimizes the cross-validated error rate (xerror). We refer to the CP table below to select the cp value that produces the lowest xerror; we can read it off the table or request it explicitly. Here the cp value with the lowest xerror is 0.033333, which corresponds to a tree with 3 splits.
#plotcp(fit.tree)
printcp(fit.tree)
##
## Classification tree:
## rpart(formula = High ~ ., data = Carseats.train, method = "class",
## cp = 0.008)
##
## Variables actually used in tree construction:
## [1] Advertising Age Price ShelveLoc
##
## Root node error: 90/200 = 0.45
##
## n= 200
##
## CP nsplit rel error xerror xstd
## 1 0.333333 0 1.00000 1.00000 0.078174
## 2 0.133333 1 0.66667 0.66667 0.072008
## 3 0.066667 2 0.53333 0.60000 0.069761
## 4 0.033333 3 0.46667 0.54444 0.067582
## 5 0.008000 5 0.40000 0.66667 0.072008
# Explicitly request the cp value with the lowest xerror
fit.tree$cptable[which.min(fit.tree$cptable[,"xerror"]),"CP"]
## [1] 0.03333333
Next, we prune the tree that was built on the training data, using this best cp value (the one with the lowest cross-validation error). The pruned tree model contains three variables with 3 splits.
bestcp <-fit.tree$cptable[which.min(fit.tree$cptable[,"xerror"]),"CP"]
pruned.tree <- prune(fit.tree, cp = bestcp)
rpart.plot(pruned.tree)
We now apply the pruned tree to the test data. Only 74% of the test observations are correctly classified, so the misclassification error rate is 26%. In contrast, our unpruned tree has a misclassification error rate of 24%. The pruned model therefore results in a higher test error rate, and we would select the larger, unpruned tree, since its test error rate is lower than that of the pruned tree. Remember that a good classifier is one for which the test error is smallest. This example shows that pruning is not always effective in reducing the test error. It is important to note that we could still choose the pruned tree if interpretability is more important than a slightly lower test error rate.
# Apply the pruned tree to the test data
pred.prune = predict(pruned.tree, Carseats.test, type="class")
table(pred.prune, High.test)
## High.test
## pred.prune No Yes
## No 96 22
## Yes 30 52
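The two models can also be compared head to head on the test set; a short sketch using the prediction vectors already created:
# Test misclassification error of the unpruned versus the pruned classification tree
c(unpruned = mean(pred.tree != High.test), pruned = mean(pred.prune != High.test))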
In this section, we will use a regression tree to analyze the Carseats data set. As before, we are interested in predicting Sales based on the other variables in the data set. Because Sales is a continuous variable, this is a regression problem and requires the use of a regression tree. As before, we split the data into a training set and a test set, build the tree using the training set, and evaluate its performance on the test data.
# Remove the High variable from the data.
Carseats.S <- Carseats[,-12]
set.seed(234)
train = sample(1:nrow(Carseats.S), 200)
Carseats.train=Carseats.S[train,]
Carseats.test=Carseats.S[-train,]
# Build the regression tree on the training set
fit.tree = rpart(Sales ~ ., data=Carseats.train, method="anova", cp=0.008)
#summary(fit.tree)
fit.tree
## n= 200
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 200 1596.13500 7.720450
## 2) ShelveLoc=Bad,Medium 154 908.83350 6.902727
## 4) Price>=92.5 132 665.46350 6.458030
## 8) ShelveLoc=Bad 41 186.98880 4.909024
## 16) CompPrice< 137.5 31 93.80490 4.369677
## 32) Price>=102.5 24 46.87750 3.850417
## 64) Income>=78 10 15.49221 2.933000 *
## 65) Income< 78 14 16.95694 4.505714 *
## 33) Price< 102.5 7 18.26940 6.150000 *
## 17) CompPrice>=137.5 10 56.21109 6.581000 *
## 9) ShelveLoc=Medium 91 335.77520 7.155934
## 18) Advertising< 13.5 70 217.43040 6.699429
## 36) CompPrice< 141.5 55 129.23440 6.286727
## 72) Price>=124.5 20 26.05848 5.414000
## 144) Income< 85 12 7.11050 4.735000 *
## 145) Income>=85 8 5.11675 6.432500 *
## 73) Price< 124.5 35 79.23827 6.785429
## 146) CompPrice< 122 21 36.81240 6.230000 *
## 147) CompPrice>=122 14 26.22957 7.618571 *
## 37) CompPrice>=141.5 15 44.47989 8.212667 *
## 19) Advertising>=13.5 21 55.13098 8.677619
## 38) Age>=42 13 27.06840 7.970000 *
## 39) Age< 42 8 10.97535 9.827500 *
## 5) Price< 92.5 22 60.64398 9.570909
## 10) Price>=75 15 34.02537 8.978667 *
## 11) Price< 75 7 10.08320 10.840000 *
## 3) ShelveLoc=Good 46 239.58290 10.458040
## 6) Price>=135 7 16.43160 7.260000 *
## 7) Price< 135 39 138.70900 11.032050
## 14) Price>=97.5 26 91.14694 10.438460
## 28) Advertising< 11.5 18 52.84876 9.887222 *
## 29) Advertising>=11.5 8 20.52209 11.678750 *
## 15) Price< 97.5 13 20.07889 12.219230 *
Next, take a look at a plot of the tree. The regression tree contains 6 variables with 16 splits. Splits have occurred on the variables ShelveLoc, Price, CompPrice, Advertising, Income, and Age.
rpart.plot(fit.tree)
fit.tree$variable.importance
## ShelveLoc Price CompPrice Advertising Income Age
## 590.418012 389.383611 153.349660 100.448780 55.043771 31.917276
## Population Education US Urban
## 26.786291 19.240978 8.182210 2.135904
pred.tree = predict(fit.tree, Carseats.test)
# Calculate the mean squared error
mse <- mean((pred.tree - Carseats.test$Sales)^2)
mse
## [1] 4.537886
We prune or trim the tree by choosing the best cp value, the one with the lowest cross-validation error (xerror), from the CP table. From the table below, the cp value 0.0396043 has the lowest xerror and renders a tree with 4 splits. Alternatively, instead of printing the table, we can explicitly ask for the best cp value. A plot of the pruned tree appears below.
# Finding the best CP value
printcp(fit.tree)
##
## Regression tree:
## rpart(formula = Sales ~ ., data = Carseats.train, method = "anova",
## cp = 0.008)
##
## Variables actually used in tree construction:
## [1] Advertising Age CompPrice Income Price ShelveLoc
##
## Root node error: 1596.1/200 = 7.9807
##
## n= 200
##
## CP nsplit rel error xerror xstd
## 1 0.2805017 0 1.00000 1.01274 0.086707
## 2 0.1144803 1 0.71950 0.73668 0.064211
## 3 0.0894032 2 0.60502 0.70898 0.068519
## 4 0.0529042 3 0.51561 0.63450 0.060518
## 5 0.0396043 4 0.46271 0.62745 0.062296
## 6 0.0273887 5 0.42311 0.65640 0.069145
## 7 0.0231639 6 0.39572 0.63721 0.067014
## 8 0.0179546 7 0.37255 0.62940 0.066112
## 9 0.0172186 8 0.35460 0.63026 0.067400
## 10 0.0149973 9 0.33738 0.62986 0.067925
## 11 0.0111370 10 0.32238 0.65073 0.068184
## 12 0.0107054 11 0.31125 0.64495 0.067733
## 13 0.0103597 12 0.30054 0.64443 0.067775
## 14 0.0101472 13 0.29018 0.64781 0.067436
## 15 0.0090396 14 0.28003 0.64474 0.067337
## 16 0.0086655 15 0.27099 0.63910 0.066231
## 17 0.0080000 16 0.26233 0.63740 0.066212
bestcp <- fit.tree$cptable[which.min(fit.tree$cptable[,"xerror"]),"CP"]
bestcp
## [1] 0.03960432
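As a visual alternative to scanning the CP table, the plotcp() function (commented out earlier for the classification tree) plots the cross-validation error against cp, which makes the minimum easy to spot:
# Plot xerror against cp; a common rule of thumb is to pick the leftmost cp below the dashed line
plotcp(fit.tree)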
# Prune the tree with the best cp value (the lowest cross-validation error - xerror)
pruned.tree <- prune(fit.tree, cp = bestcp)
# Visualizing the pruned tree
rpart.plot(pruned.tree)
# Checking the order of variable importance
pruned.tree$variable.importance
## ShelveLoc Price CompPrice Population
## 590.418012 267.168273 16.611453 3.480476
A good estimator is one for which the test error is smallest. To determine how well the decision-tree model performs on the test set, we calculate the mean squared error (MSE). The MSE is a standard performance metric for regression models; a lower MSE indicates better predictive performance.
# Use the test data to evaluate performance of pruned regression tree
pred.prune = predict(pruned.tree, Carseats.test)
# Calculate the MSE for the pruned tree
mse <- mean((pred.prune - Carseats.test$Sales)^2)
mse
## [1] 5.180412
The MSE value for the pruned tree is larger than the MSE value for the original tree. In this case, pruning the tree did not reduce the MSE, so our first, unpruned tree is the better predictor. Again, this example shows that pruning is not always effective in reducing the test error. Although the unpruned tree renders the smallest MSE, it is more complicated to interpret; if interpretability is more important, then we would choose the pruned model.
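As a side check, taking the square root of each MSE gives the RMSE on the original scale of Sales, which can be easier to interpret; a brief sketch using the prediction vectors created above:
# RMSE of the unpruned and pruned regression trees, in the same units as Sales
c(unpruned = sqrt(mean((pred.tree - Carseats.test$Sales)^2)),
  pruned   = sqrt(mean((pred.prune - Carseats.test$Sales)^2)))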
If there is high non-linearity and a complex relationship between the dependent and independent variables, a tree model can outperform classical regression models (such as LMs, GLMs, or GAMs).
Trees are very easy to explain, and are arguably even easier to explain and interpret than linear regression.
It is believed that decision trees more closely mirror human decision-making than do classical regression approaches.
Trees can be displayed graphically, and are easily interpreted even by a non-expert.
Trees can handle data of different types, including continuous, categorical, ordinal, and binary. Transformations of the data are not required. For example, trees can easily handle qualitative predictors without the need to create dummy variables.
Trees can be useful for detecting important variables, interactions, and identifying outliers. The larger the number of variables, the more valuable is the exploration using decision trees.
Records with missing values are omitted by default.
Trees generally do not have the same level of predictive accuracy as some of the other regression and classification approaches.
Tree models have a tendency to overfit, that is, the noise is fitted along with the signal, which can lead to over-interpretation.
Trees can be very non-robust. That is, a small change in the data can cause a large change in the final estimated tree.
Note: The predictive performance of trees can be substantially improved by aggregating many trees using methods like bagging, random forests, and boosting.
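As a pointer only, and assuming the randomForest package is installed, a random forest for the classification problem from Part I could be fit roughly as follows (object names such as rf.fit are illustrative; tuning and evaluation of these methods are beyond the scope of this tutorial):
# Sketch (not run above): a random forest, i.e., an ensemble of decision trees
library(randomForest)
set.seed(234)
rf.fit <- randomForest(High ~ ., data = Carseats.H, ntree = 500, importance = TRUE)
rf.fit               # printing the fit reports the out-of-bag (OOB) error estimate
varImpPlot(rf.fit)   # variable importance plot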
Explanation of the Decision Tree Model https://webfocusinfocenter.informationbuilders.com/wfappent/TLs/TL_rstat/source/DecisionTree47.htm
Tree Based Algorithms: A Complete Tutorial from Scratch (in R & Python) https://www.analyticsvidhya.com/blog/2016/04/tree-based-algorithms-complete-tutorial-scratch-in-python/
An Introduction to Statistical Learning: with Applications in R by Gareth James, Daniela Witten, Trevor Hastie and Rob Tibshirani. 2017 edition.