In this tutorial, we will construct a regression tree to predict the salaries of baseball players based on multiple features.
Tree-based methods, while simple and useful for interpretation, are typically not competitive with the best supervised learning approaches such as polynomial regression, local regression, or generalized additive models (GAMs). However, tree-based ensemble methods such as bagging, random forests, and boosting can make up for this shortfall. By combining a large number of trees instead of relying on a single one, these methods can yield dramatic improvements in prediction accuracy, although the improvement may come at the expense of some loss in interpretability. Decision trees can be used for both classification and regression. We will obviously be focusing on the latter.
In a regression tree, observations are arranged or segmented into regions of the predictor space. For example, using the “Hitters” data set, which contains various statistics on baseball players, a tree might look something like the following:
library(rpart)
library(rpart.plot)

# Load the Hitters data and fit a regression tree predicting Salary from Years and Hits
hitters <- read.csv('/Users/czar.yobero/R/Hitters.csv')
reg.tree <- rpart(Salary ~ Years + Hits, data = hitters)
rpart.plot(reg.tree, type = 4)
We can interpret the above regression tree as follows: the most important factor in determining a player’s salary is Years (i.e. experience). Players who have been in the league for at least 4.5 years tend to make more than players who have been in the league for less time than that. Hits also plays a role in determining a player’s salary (after all, we did include it in our model). As you might expect if you’re a baseball fan, players who have more hits tend to make more than players who have fewer hits. We can check how much importance each variable carries in predicting the dependent variable using the following code:
reg.tree$variable.importance
## Years Hits
## 15696202 14259635
It is evident that Years plays a relatively more important role than Hits in determining a player’s salary (at least based on our simple example).
In any decision tree, observations are grouped into regions \(R_{1}, R_{2}, \ldots, R_{n}\). These regions are commonly referred to as terminal nodes or leaves of the tree. Decision trees are generally plotted upside down, in the sense that the terminal nodes are at the bottom of the tree. The points along the tree where the predictor space is split are referred to as internal nodes.
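To make this concrete, we can inspect the nodes of the small tree we fit above. The following is a quick sketch using the frame component that rpart stores in the fitted object: rows whose var entry is "<leaf>" are terminal nodes, and the remaining rows are the internal nodes.

# Inspect the nodes of reg.tree: rpart stores one row per node in $frame.
# Rows with var == "<leaf>" are terminal nodes (leaves); the rest are
# internal nodes where the predictor space is split.
reg.tree$frame[, c("var", "n", "yval")]

# Number of terminal nodes, i.e. |T|
sum(reg.tree$frame$var == "<leaf>")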
We are now ready to begin building a regression tree. Generally speaking, there are two main steps involved in constructing such a tree: first, we divide the predictor space (the set of possible values for \(X_{1}, X_{2}, \ldots, X_{p}\)) into \(J\) distinct, non-overlapping regions \(R_{1}, R_{2}, \ldots, R_{J}\); second, for every observation that falls into region \(R_{j}\), we make the same prediction, namely the mean of the response values of the training observations in \(R_{j}\).
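The second step is easy to see with the tree we fit earlier. As an illustrative check (relying on rpart keeping the leaf assignments in $where and the response in $y by default), the fitted value for every observation in a leaf is simply the mean salary of the training observations in that leaf:

# Mean Salary of the training observations within each terminal node ...
tapply(reg.tree$y, reg.tree$where, mean)

# ... matches the single fitted value the tree predicts for that node
tapply(predict(reg.tree), reg.tree$where, unique)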
The question now is, “How do we go about constructing the regions \(R_{1}, R_{2}, \ldots, R_{J}\)?” The answer is that we choose to divide the predictor space into high-dimensional rectangles (or boxes) for simplicity and for ease of interpretation of the resulting predictive model. The goal is to find boxes \(R_{1}, \ldots, R_{J}\) that minimize the residual sum of squares (RSS), given by
\[ \begin{aligned} \sum_{j=1}^{J} \sum_{i \in R_{j}} (y_{i} - \hat{y}_{R_{j}})^{2} \end{aligned} \]
where \(\hat{y}_{R_{j}}\) is the mean response for the training observations within the \(j\)th box. Unfortunately, it is computationally difficult, nay infeasible, to consider every possible partition of the feature space into \(J\) boxes. As a result, we take a top-down, greedy approach known as recursive binary splitting.
To perform recursive binary splitting, we first select the predictor \(X_{j}\) and the cutpoint \(s\) such that splitting the predictor space into the regions \(\{X|X_{j} < s\}\) and \(\{X|X_{j} \geq s\}\) leads to the greatest possible reduction in RSS. The notation \(\{X|X_{j} < s\}\) denotes the region of predictor space in which \(X_{j}\) takes on a value less than \(s\). In other words, we consider all predictors \(X_{1}, X_{2}, \ldots, X_{p}\) and all possible values of the cutpoint \(s\) for each of the predictors, and then choose the predictor and cutpoint such that the resulting tree has the lowest RSS. In more detail, for any \(j\) and \(s\) we define the pair of half-planes
\[ \begin{aligned} R_{1}(j,s) = \{X | X_{j} < s \} \qquad \text{and} \qquad R_{2}(j, s) = \{X | X_{j} \geq s \} \end{aligned} \]
and we seek the values of \(j\) and \(s\) that minimize
\[ \begin{aligned} \sum_{i:\, x_{i} \in R_{1}(j,s)} (y_{i} - \hat{y}_{R_{1}})^{2} + \sum_{i:\, x_{i} \in R_{2}(j,s)} (y_{i} - \hat{y}_{R_{2}})^{2} \end{aligned} \]
where \(\hat{y}_{R_{1}}\) is the mean response for the training observations in \(R_{1}(j,s)\), and \(\hat{y}_{R_{2}}\) is the mean response for the training observations in \(R_{2}(j, s)\).
Next, we repeat the process and look for the best predictor and best cutpoint to split the data further so as to minimize the RSS within each of the resulting regions. However, this time around, rather than splitting the entire predictor space, we split one of the two previously identified regions. We now have three regions. Again, we look to split one of these regions further so as to minimize the RSS. This process continues until a stopping criterion is reached.
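To make the greedy step concrete, here is a minimal sketch (not how rpart implements it internally) of finding the best cutpoint \(s\) for a single predictor: we scan the midpoints between consecutive observed values and keep the cutpoint whose two half-planes give the lowest total RSS. The function best_split and the choice of Years as the predictor are purely illustrative.

# Illustrative only: one greedy step of recursive binary splitting for a
# single predictor x. Candidate cutpoints are midpoints between consecutive
# observed values; we keep the one minimizing the total RSS of the two
# resulting half-planes.
best_split <- function(x, y) {
  vals <- sort(unique(x))
  candidates <- (head(vals, -1) + tail(vals, -1)) / 2
  rss <- sapply(candidates, function(s) {
    left  <- y[x < s]
    right <- y[x >= s]
    sum((left - mean(left))^2) + sum((right - mean(right))^2)
  })
  list(s = candidates[which.min(rss)], rss = min(rss))
}

# Example: the best single split on Years for predicting Salary
complete <- na.omit(hitters[, c("Years", "Salary")])
best_split(complete$Years, complete$Salary)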
Although the process above may produce good predictions on the training set, it is likely to overfit the data, leading to poor test set performance. One solution to this problem is to produce a smaller tree with fewer splits (i.e. fewer \(R_{j}\)’s). Such a tree might lead to lower variance and better interpretation at the expense of a little bias.
An alternative to the process above is to build the tree only so long as the decrease in RSS due to each split exceeds some high threshold. This strategy will result in smaller trees, but it is too short-sighted, since a seemingly worthless split early on in the tree might be followed by a very good split. Therefore, a superior strategy is to grow a very large tree \(T_{0}\) and then prune it back in order to obtain a subtree. But just how do we determine the best way to prune the tree? The answer is that we need a way to select a small set of subtrees for consideration. This is done through cost complexity pruning, also known as weakest link pruning.
Rather than considering every possible subtree, we consider a sequence of trees indexed by a non-negative tuning parameter \(\alpha\). Thus, the algorithm for building such a tree is as follows:
1. Use recursive binary splitting to grow a large tree on the training data, stopping only when each terminal node has fewer than some minimum number of observations.
2. Apply cost complexity pruning to the large tree in order to obtain a sequence of best subtrees, as a function of \(\alpha\).
3. Return the subtree from step 2 that corresponds to the value of \(\alpha\) chosen (typically via cross-validation).
For each value of \(\alpha\) there corresponds a subtree \(T \subset T_{0}\) such that
\[ \begin{aligned} \sum_{m=1}^{|T|} \sum_{i:\, x_{i} \in R_{m}} (y_{i} - \hat{y}_{R_{m}})^{2} + \alpha |T| \end{aligned} \]
is as small as possible. Here \(|T|\) denotes the number of terminal nodes of the tree \(T\), \(R_{m}\) denotes the subset of predictor space corresponding to the \(m\)th terminal node, and \(\hat{y}_{R_{m}}\) is the predicted response associated with \(R_{m}\) (i.e. the mean of the training observations in \(R_{m}\)).
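In rpart, cost complexity pruning is carried out through the fitted object's cp table together with prune(); the complexity parameter cp is essentially a rescaled version of \(\alpha\). The following is a sketch (the object names big.tree and pruned.tree, and the control settings, are just for illustration): grow a large tree on the Hitters data by relaxing the default stopping rules, then prune back to the subtree whose cp has the lowest cross-validated error.

# Sketch of cost complexity pruning with rpart: first grow a large tree T_0
# by relaxing the default stopping rules (cp = 0), then consult the cp table,
# whose xerror column holds the cross-validated error for the nested sequence
# of subtrees, and prune back to the cp with the smallest xerror.
big.tree <- rpart(Salary ~ Years + Hits, data = hitters,
                  control = rpart.control(cp = 0, minsplit = 10))
printcp(big.tree)
best.cp <- big.tree$cptable[which.min(big.tree$cptable[, "xerror"]), "CP"]
pruned.tree <- prune(big.tree, cp = best.cp)
rpart.plot(pruned.tree, type = 4)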
Let us now extend our baseball salaries example to build a regression tree.
library(rpart)

# Split the data in half to create a training set, then fit a larger
# regression tree using more of the available features
set.seed(1984)
train <- sample(1:nrow(hitters), nrow(hitters)/2)
tree.baseball <- rpart(Salary ~ Hits + HmRun + Runs + RBI + Walks + Years + Errors, subset = train, data = hitters)
summary(tree.baseball)
## Call:
## rpart(formula = Salary ~ Hits + HmRun + Runs + RBI + Walks +
## Years + Errors, data = hitters, subset = train)
## n=132 (29 observations deleted due to missingness)
##
## CP nsplit rel error xerror xstd
## 1 0.24729229 0 1.0000000 1.0075685 0.1937932
## 2 0.17431727 1 0.7527077 0.8197971 0.1970009
## 3 0.08513476 2 0.5783904 0.8668485 0.1878783
## 4 0.03302001 3 0.4932557 0.8095991 0.1849311
## 5 0.01875615 4 0.4602357 0.7640463 0.1658833
## 6 0.01540453 5 0.4414795 0.7682444 0.1656465
## 7 0.01198533 6 0.4260750 0.7600164 0.1658597
## 8 0.01000000 7 0.4140897 0.7586088 0.1694030
##
## Variable importance
## Hits Years RBI Runs HmRun Walks Errors
## 22 22 17 16 11 10 1
##
## Node number 1: 132 observations, complexity param=0.2472923
## mean=557.9758, MSE=245671.6
## left son=2 (50 obs) right son=3 (82 obs)
## Primary splits:
## Years < 4.5 to the left, improve=0.2472923, (0 missing)
## Walks < 66.5 to the left, improve=0.2091453, (0 missing)
## Runs < 60 to the left, improve=0.1663612, (0 missing)
## Hits < 143 to the left, improve=0.1621618, (0 missing)
## RBI < 72.5 to the left, improve=0.1408579, (0 missing)
## Surrogate splits:
## Hits < 60.5 to the left, agree=0.682, adj=0.16, (0 split)
## RBI < 16 to the left, agree=0.682, adj=0.16, (0 split)
## HmRun < 2.5 to the left, agree=0.674, adj=0.14, (0 split)
## Walks < 19 to the left, agree=0.674, adj=0.14, (0 split)
## Runs < 29.5 to the left, agree=0.667, adj=0.12, (0 split)
##
## Node number 2: 50 observations, complexity param=0.01875615
## mean=242.3267, MSE=112784.3
## left son=4 (42 obs) right son=5 (8 obs)
## Primary splits:
## Walks < 12.5 to the right, improve=0.10785840, (0 missing)
## Hits < 44.5 to the right, improve=0.10589680, (0 missing)
## HmRun < 1.5 to the right, improve=0.10210490, (0 missing)
## RBI < 14.5 to the right, improve=0.10016790, (0 missing)
## Runs < 23.5 to the right, improve=0.07569985, (0 missing)
## Surrogate splits:
## RBI < 20.5 to the right, agree=0.92, adj=0.500, (0 split)
## Hits < 32 to the right, agree=0.90, adj=0.375, (0 split)
## Runs < 10.5 to the right, agree=0.90, adj=0.375, (0 split)
##
## Node number 3: 82 observations, complexity param=0.1743173
## mean=750.4448, MSE=228903.4
## left son=6 (58 obs) right son=7 (24 obs)
## Primary splits:
## Hits < 142.5 to the left, improve=0.3011641, (0 missing)
## Runs < 60 to the left, improve=0.2595740, (0 missing)
## RBI < 81 to the left, improve=0.2506971, (0 missing)
## Walks < 43.5 to the left, improve=0.2365055, (0 missing)
## HmRun < 16.5 to the left, improve=0.1498388, (0 missing)
## Surrogate splits:
## Runs < 70 to the left, agree=0.890, adj=0.625, (0 split)
## RBI < 67.5 to the left, agree=0.829, adj=0.417, (0 split)
## Walks < 50.5 to the left, agree=0.793, adj=0.292, (0 split)
## HmRun < 27 to the left, agree=0.780, adj=0.250, (0 split)
##
## Node number 4: 42 observations, complexity param=0.01198533
## mean=194.1905, MSE=34714.64
## left son=8 (31 obs) right son=9 (11 obs)
## Primary splits:
## Years < 3.5 to the left, improve=0.2665734, (0 missing)
## HmRun < 24.5 to the left, improve=0.1625136, (0 missing)
## RBI < 89.5 to the left, improve=0.1597607, (0 missing)
## Runs < 74.5 to the left, improve=0.1528585, (0 missing)
## Hits < 112.5 to the left, improve=0.1474041, (0 missing)
## Surrogate splits:
## Hits < 159.5 to the left, agree=0.762, adj=0.091, (0 split)
## Walks < 64.5 to the left, agree=0.762, adj=0.091, (0 split)
## Errors < 21 to the left, agree=0.762, adj=0.091, (0 split)
##
## Node number 5: 8 observations
## mean=495.0416, MSE=446620.6
##
## Node number 6: 58 observations, complexity param=0.03302001
## mean=581.5489, MSE=84082.62
## left son=12 (37 obs) right son=13 (21 obs)
## Primary splits:
## Hits < 104.5 to the left, improve=0.2195694, (0 missing)
## Runs < 33.5 to the left, improve=0.1914736, (0 missing)
## Walks < 20.5 to the left, improve=0.1371105, (0 missing)
## Years < 6.5 to the left, improve=0.1191623, (0 missing)
## Errors < 4.5 to the left, improve=0.1157724, (0 missing)
## Surrogate splits:
## Runs < 49.5 to the left, agree=0.931, adj=0.810, (0 split)
## RBI < 54 to the left, agree=0.776, adj=0.381, (0 split)
## Errors < 7.5 to the left, agree=0.759, adj=0.333, (0 split)
## HmRun < 17 to the left, agree=0.741, adj=0.286, (0 split)
## Walks < 66.5 to the left, agree=0.690, adj=0.143, (0 split)
##
## Node number 7: 24 observations, complexity param=0.08513476
## mean=1158.61, MSE=343350.4
## left son=14 (14 obs) right son=15 (10 obs)
## Primary splits:
## RBI < 81 to the left, improve=0.3350325, (0 missing)
## Walks < 70.5 to the left, improve=0.1915065, (0 missing)
## HmRun < 15.5 to the left, improve=0.1835997, (0 missing)
## Years < 7.5 to the left, improve=0.1288186, (0 missing)
## Errors < 5.5 to the left, improve=0.1015978, (0 missing)
## Surrogate splits:
## HmRun < 15.5 to the left, agree=0.833, adj=0.6, (0 split)
## Hits < 185 to the left, agree=0.708, adj=0.3, (0 split)
## Runs < 84 to the left, agree=0.708, adj=0.3, (0 split)
## Walks < 60.5 to the left, agree=0.625, adj=0.1, (0 split)
## Years < 12 to the left, agree=0.625, adj=0.1, (0 split)
##
## Node number 8: 31 observations
## mean=136.8871, MSE=3426.479
##
## Node number 9: 11 observations
## mean=355.6818, MSE=87556.92
##
## Node number 12: 37 observations, complexity param=0.01540453
## mean=479.1847, MSE=66810.12
## left son=24 (18 obs) right son=25 (19 obs)
## Primary splits:
## Years < 8.5 to the left, improve=0.20208470, (0 missing)
## Walks < 42.5 to the left, improve=0.17910550, (0 missing)
## Runs < 33.5 to the left, improve=0.13689250, (0 missing)
## Errors < 2.5 to the left, improve=0.08896843, (0 missing)
## HmRun < 8.5 to the left, improve=0.06583124, (0 missing)
## Surrogate splits:
## Hits < 79.5 to the right, agree=0.622, adj=0.222, (0 split)
## Errors < 4.5 to the left, agree=0.622, adj=0.222, (0 split)
## HmRun < 3.5 to the left, agree=0.595, adj=0.167, (0 split)
## Runs < 32.5 to the right, agree=0.595, adj=0.167, (0 split)
## Walks < 21.5 to the left, agree=0.595, adj=0.167, (0 split)
##
## Node number 13: 21 observations
## mean=761.9047, MSE=63524.92
##
## Node number 14: 14 observations
## mean=871.9626, MSE=171947.7
##
## Node number 15: 10 observations
## mean=1559.917, MSE=307233.8
##
## Node number 24: 18 observations
## mean=359.8056, MSE=25916.61
##
## Node number 25: 19 observations
## mean=592.2807, MSE=79259.33
Now let's plot our tree.
library(rpart.plot)
rpart.plot(tree.baseball)
According to our tree, baseball players with more experience, hits, and RBIs (runs batted in) tend to have higher salaries than their peers.
The most important factors in determining salary (according to the features used in our model) are:
tree.baseball$variable.importance
## Hits Years RBI Runs HmRun Walks Errors
## 9209438.4 9183650.9 7111305.5 6501789.9 4581610.2 3927343.5 503275.6
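Since the tree was fit on only half of the data, a natural follow-up (shown here only as a sketch; the exact numbers depend on the random split) is to predict Salary for the held-out players and compute the test-set mean squared error:

# Evaluate the tree on the held-out half of the data (dropping rows whose
# Salary is missing) by computing the test-set mean squared error
test <- hitters[-train, ]
test <- test[!is.na(test$Salary), ]
pred <- predict(tree.baseball, newdata = test)
mean((pred - test$Salary)^2)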