Start by running all of the code in the first OkCupid presentation…
Now, instead of using the lm() function to create linear models, we'll use rpart() to create decision trees. rpart stands for recursive partitioning. By partitioning, we mean splitting our data into subsets (or branches). We then compute the mean value of the predicted variable (in our case, age) across all of the observations in each branch. Our prediction for every observation on that branch is simply the mean value of age for that branch. By "recursive", we mean that we can keep doing this: splitting each branch into smaller branches, each of those smaller branches into yet smaller branches, and so on. At each opportunity, rpart() aims to choose the split that most reduces the prediction error. It also stops branching when additional splits add little to the in-sample prediction accuracy.
Let’s load the necessary packages and look at a few examples. Here is a decision tree for predicting age from job and offspring:
library(rpart); library(rpart.plot)
fit <- rpart(age ~ job+offspring,data=train)
prp(fit, type=1, fallen.leaves=TRUE, extra=1, cex=0.7)
At each junction, the left branch is the branch for which the statement at the junction is true.
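To see that each leaf's prediction really is just the mean age of the training observations that land in it, here is a quick check (a sketch using rpart's fit$where component, which records the leaf each training observation falls into):

# Quick check: predictions should equal the mean age within each leaf.
# fit$where gives the leaf (row of fit$frame) for each training observation.
leaf.means <- tapply(train$age, fit$where, mean)
head(cbind(prediction = predict(fit),
           leaf.mean  = leaf.means[as.character(fit$where)]))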
We can also make a decision tree based on a continuous variable. Here is a decision tree based on income:
fit <- rpart(age ~ income,data=train)
prp(fit, type=1, fallen.leaves=TRUE, extra=1, cex=0.7)
In what situations would a decision tree make more sense than a linear model? When would a linear model make more sense?
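One way to explore these questions empirically (a sketch only; it assumes a held-out data frame called test was created alongside train in the earlier presentation) is to compare out-of-sample errors:

# Sketch: compare out-of-sample RMSE of a single-variable tree vs. a linear model.
# Assumes a held-out data frame `test` with the same columns as `train`.
tree.fit <- rpart(age ~ income, data = train)
lm.fit   <- lm(age ~ income, data = train)
sqrt(mean((test$age - predict(tree.fit, newdata = test))^2))
sqrt(mean((test$age - predict(lm.fit,   newdata = test))^2))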
Next, let's make a tree based on job, education level and offspring and use the printcp() function to look at how the relative error depends on the number of splits (junctions). Note that the "root node error" is essentially the variance of age in the training set (it divides the sum of squared deviations by n rather than n - 1). The relative error of each model (each with a different number of splits) is its mean squared error divided by the mean squared error of the zero-split model, which just predicts the overall mean age.
fit <- rpart(age ~ job+ed.level+offspring,data=train)
prp(fit, type=1, fallen.leaves=TRUE, extra=1, cex=0.7)
printcp(fit)
##
## Regression tree:
## rpart(formula = age ~ job + ed.level + offspring, data = train)
##
## Variables actually used in tree construction:
## [1] ed.level job offspring
##
## Root node error: 2706156/29973 = 90.286
##
## n= 29973
##
## CP nsplit rel error xerror xstd
## 1 0.174941 0 1.00000 1.00008 0.0111162
## 2 0.041692 1 0.82506 0.82545 0.0096478
## 3 0.020932 2 0.78337 0.78405 0.0094590
## 4 0.010000 3 0.76243 0.76316 0.0093473
var(train$age)
## [1] 90.28948
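As a check on the "rel error" column (a sketch, computed from the tree fitted above): it is the tree's sum of squared residuals divided by the sum of squared deviations from the overall mean age, i.e. the error of the no-split tree.

# Sketch: reproduce the final "rel error" value from printcp() by hand.
sse.tree <- sum((train$age - predict(fit))^2)      # residual SSE of the fitted tree
sse.root <- sum((train$age - mean(train$age))^2)   # SSE of the zero-split tree
sse.tree / sse.root                                # should match the last "rel error" row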
The leftmost column, "CP", shows the minimum value of the complexity parameter that would yield this model (rather than a more complex one). Any potential split that does not improve the fit by at least a factor of cp is not made. In short, a lower cp value produces a more complex tree.
Why might we want to limit our model this way? Why not keep branching so long as our in-sample prediction error keeps going down?
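Rather than refitting from scratch at each cp, we could also grow one deep tree and prune it back with rpart's prune() function (a sketch; the object names here are just illustrative):

# Sketch: grow a deep tree once, then prune it back to a larger cp.
big.fit    <- rpart(age ~ job + ed.level + offspring, cp = 0.001, data = train)
pruned.fit <- prune(big.fit, cp = 0.02)
printcp(pruned.fit)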
Let's look at a few models we could make using job, ed.level and offspring with different cp values. First, let's try a cp value of 0.03: how many nodes should this model have?
fit <- rpart(age ~ job+ed.level+offspring,cp=0.03,data=train)
prp(fit, type=1, fallen.leaves=TRUE, extra=1, cex=0.7)
Now, let’s try a cp value of 0.003.
fit <- rpart(age ~ job+ed.level+offspring,cp=0.003,data=train)
prp(fit, type=1, fallen.leaves=TRUE, extra=1, cex=0.5)
printcp(fit)
##
## Regression tree:
## rpart(formula = age ~ job + ed.level + offspring, data = train,
## cp = 0.003)
##
## Variables actually used in tree construction:
## [1] ed.level job offspring
##
## Root node error: 2706156/29973 = 90.286
##
## n= 29973
##
## CP nsplit rel error xerror xstd
## 1 0.1749410 0 1.00000 1.00003 0.0111156
## 2 0.0416922 1 0.82506 0.82641 0.0096575
## 3 0.0209321 2 0.78337 0.78503 0.0094690
## 4 0.0078741 3 0.76243 0.76404 0.0093541
## 5 0.0059809 4 0.75456 0.75727 0.0092494
## 6 0.0057475 5 0.74858 0.75223 0.0092242
## 7 0.0043651 6 0.74283 0.74547 0.0091576
## 8 0.0038052 7 0.73847 0.74371 0.0091453
## 9 0.0030000 8 0.73466 0.73850 0.0090811
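The xerror column reports cross-validated relative error. A common rule of thumb (a sketch, not part of the original presentation) is to pick the cp with the smallest xerror from the cp table and prune to it:

# Sketch: choose cp by cross-validated error and prune the tree accordingly.
best.cp <- fit$cptable[which.min(fit$cptable[, "xerror"]), "CP"]
pruned  <- prune(fit, cp = best.cp)
prp(pruned, type=1, fallen.leaves=TRUE, extra=1, cex=0.7)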