rm(list = ls())
library(rpart)
set.seed(161)
n <- 1272
Generate two independent normal random variables, X1 and X2. When X1 is less than 2, Y should be 1; when X1 is greater than or equal to 2 and X2 is less than -.5, Y should also be 1. Otherwise, Y should be 0.
x1 <- rnorm(n = n, mean = 3.5)
x2 <- rnorm(n = n, mean = 1.3)
x1.b <- ifelse(x1 < 2, 1, 0)
x2.b <- ifelse(x1 >= 2 & x2 < -.5, 1, 0)
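As a quick sanity check on the setup, here is a minimal sketch that regenerates x1 and x2 with the same seed (so it runs on its own) and shows three things: the two indicators never fire together, together they amount to the single rule "X1 < 2 or X2 < -.5", and that rule should capture roughly 10% of points in theory. The names overlap, in_region, same_rule and p_region are illustrative, not from the original; the calculation assumes, as in the code, that X1 and X2 are independent with standard deviation 1.

```r
# Regenerate x1 and x2 exactly as above
set.seed(161)
n <- 1272
x1 <- rnorm(n = n, mean = 3.5)
x2 <- rnorm(n = n, mean = 1.3)
x1.b <- ifelse(x1 < 2, 1, 0)
x2.b <- ifelse(x1 >= 2 & x2 < -.5, 1, 0)

# The indicators never overlap, and their union is the simple OR rule
overlap   <- any(x1.b == 1 & x2.b == 1)
in_region <- x1.b == 1 | x2.b == 1
same_rule <- identical(in_region, x1 < 2 | x2 < -.5)

# With independent sd = 1 normals, the Y = 1 region has probability
# P(x1 < 2) + P(x1 >= 2) * P(x2 < -0.5)
p_x1 <- pnorm(2, mean = 3.5)
p_x2 <- pnorm(-.5, mean = 1.3)
p_region <- p_x1 + (1 - p_x1) * p_x2

round(p_region, 4)   # theoretical share, about 0.1003
mean(in_region)      # share observed in this sample
```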
But let's add a little random variability to Y, and then add some variables that play no part in generating Y (X3, X4, X5). X3 is correlated with X1, so it is related to Y only through X1.
y <- ifelse(x1.b == 1 | x2.b == 1, .9, .1) # P(Y = 1): .9 inside the rule, .1 outside it
y <- rbinom(n = n, size = 1, prob = y)
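The .9/.1 probabilities also set the best error rate we can expect, on average, from any classifier. A minimal sketch under the same assumptions (independent sd = 1 normals; p_y1, bayes_error and rule_error are illustrative names): even the true rule mislabels about 10% of cases, and the expected share of Y = 1 is roughly 0.18. The share in a given sample will wander around that value by chance.

```r
# Probability of landing in the Y = 1 region (same normal model as above)
p_x1 <- pnorm(2, mean = 3.5)
p_x2 <- pnorm(-.5, mean = 1.3)
p_region <- p_x1 + (1 - p_x1) * p_x2

# P(Y = 1) = .9 * P(region) + .1 * P(outside region)
p_y1 <- .9 * p_region + .1 * (1 - p_region)

# The true rule is wrong with probability .1 both inside and outside the region
bayes_error <- .1 * p_region + .1 * (1 - p_region)

round(p_y1, 3)         # about 0.18
round(bayes_error, 3)  # 0.1

# Empirical check, regenerating the sample exactly as in the code above
set.seed(161)
n <- 1272
x1 <- rnorm(n = n, mean = 3.5)
x2 <- rnorm(n = n, mean = 1.3)
in_region <- x1 < 2 | x2 < -.5
y <- rbinom(n = n, size = 1, prob = ifelse(in_region, .9, .1))
mean(y)                                          # observed share of 1s
rule_error <- mean(y != as.integer(in_region))   # error rate of the true rule
rule_error
```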
x3 <- rnorm(n = n, mean = x1) # correlated with x1
x4 <- rnorm(n = n, mean = 2)
x5 <- rnorm(n = n, mean = .4)
Run the tree. Ugh. That is one ugly, overfit, and useless tree.
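# cp = 0 turns off the complexity-based stopping rule; xval = 10 requests 10-fold cross-validation
mod <- rpart(y ~ x1 + x2 + x3 + x4 + x5, method = "class",
             control = rpart.control(xval = 10, cp = 0, minsplit = 20))
plot(mod, margin = .1)
text(mod, cex = .8)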
Can we prune our tree with cross-validation and get a slightly less complex tree?
printcp(mod)
Classification tree:
rpart(formula = y ~ x1 + x2 + x3 + x4 + x5, method = "class",
    control = rpart.control(xval = 10, cp = 0, minsplit = 20))

Variables actually used in tree construction:
[1] x1 x2 x3 x4 x5

Root node error: 203/1272 = 0.15959

n= 1272

          CP nsplit rel error  xerror     xstd
1 0.31527094      0   1.00000 1.00000 0.064342
2 0.15270936      1   0.68473 0.69458 0.055157
3 0.00394089      2   0.53202 0.56158 0.050184
4 0.00164204      7   0.51232 0.60099 0.051736
5 0.00089566     10   0.50739 0.63547 0.053037
6 0.00000000     21   0.49754 0.64039 0.053219
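The choice of tree size can be read straight off this table. A small sketch, with the numbers copied from the printcp output above (cp_table, best_min, best_1se and abs_cv_error are illustrative names): it finds the row with the lowest xerror, applies the common one-standard-error rule (the smallest tree whose xerror is within one xstd of that minimum, which is the horizontal line plotcp draws), and converts relative error back to an absolute misclassification rate by multiplying by the root node error.

```r
# cptable values copied from the printcp output
cp_table <- matrix(
  c(0.31527094,  0, 1.00000, 1.00000, 0.064342,
    0.15270936,  1, 0.68473, 0.69458, 0.055157,
    0.00394089,  2, 0.53202, 0.56158, 0.050184,
    0.00164204,  7, 0.51232, 0.60099, 0.051736,
    0.00089566, 10, 0.50739, 0.63547, 0.053037,
    0.00000000, 21, 0.49754, 0.64039, 0.053219),
  ncol = 5, byrow = TRUE,
  dimnames = list(NULL, c("CP", "nsplit", "rel error", "xerror", "xstd"))
)
root_error <- 203 / 1272

# Rule 1: the row with the smallest cross-validated error
best_min <- which.min(cp_table[, "xerror"])

# Rule 2 (1-SE): the first (smallest) tree whose xerror is within
# one standard error of that minimum
threshold <- cp_table[best_min, "xerror"] + cp_table[best_min, "xstd"]
best_1se <- which(cp_table[, "xerror"] <= threshold)[1]

cp_table[c(best_min, best_1se), "nsplit"]   # both rules pick 2 splits

# rel error and xerror are relative to the root node error
abs_cv_error <- cp_table[best_min, "xerror"] * root_error
round(abs_cv_error, 3)   # about 0.090
```

Both rules land on two splits here, and the cross-validated misclassification rate of that tree is about 9%, in line with the 10% label noise built into Y.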
plotcp(mod, upper = "split")
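Hmm, cross-validation says we should have just two splits: xerror is lowest at nsplit = 2 (0.56158), and no smaller tree is within one standard error of that minimum (0.56158 + 0.050184 = 0.61176). What does it look like?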
fit2 <- prune(mod, cp = .021) # any cp above 0.00394 and below 0.1527 keeps exactly two splits
plot(fit2, margin = .5)
text(fit2)
Great. With cross-validation, we recovered the truth: two splits, one for each part of the rule that generated Y.
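Rather than hand-picking cp = .021, cp can be chosen from the cptable itself. A sketch, assuming the same simulated data (regenerated here so the block runs on its own; dat, cpt, cp_use, fit_best and in_region are illustrative names): it takes the row with the lowest xerror, prunes at the geometric mean of that row's CP and the CP of the next-smaller tree (any value strictly between the two keeps the same number of splits, and .021 is one such value), then cross-tabulates the pruned tree's fitted classes against the rule that generated Y. Because the 10-fold assignment is random, xerror and therefore the chosen row may differ slightly from the printout above depending on the R version's sampling defaults.

```r
library(rpart)

# Regenerate the simulated data in the same order as above
set.seed(161)
n <- 1272
x1 <- rnorm(n = n, mean = 3.5)
x2 <- rnorm(n = n, mean = 1.3)
in_region <- x1 < 2 | x2 < -.5          # the rule that generated Y
y <- rbinom(n = n, size = 1, prob = ifelse(in_region, .9, .1))
x3 <- rnorm(n = n, mean = x1)
x4 <- rnorm(n = n, mean = 2)
x5 <- rnorm(n = n, mean = .4)
dat <- data.frame(y, x1, x2, x3, x4, x5)

mod <- rpart(y ~ x1 + x2 + x3 + x4 + x5, data = dat, method = "class",
             control = rpart.control(xval = 10, cp = 0, minsplit = 20))

# Row with the lowest cross-validated error
cpt <- mod$cptable
best_row <- which.min(cpt[, "xerror"])

# Any cp strictly between this row's CP and the previous row's CP keeps
# exactly cpt[best_row, "nsplit"] splits; use their geometric mean
upper <- if (best_row > 1) cpt[best_row - 1, "CP"] else 1
cp_use <- sqrt(cpt[best_row, "CP"] * upper)
fit_best <- prune(mod, cp = cp_use)

# Count internal nodes and list the variables the pruned tree splits on
is_split <- fit_best$frame$var != "<leaf>"
n_splits <- sum(is_split)
split_vars <- unique(as.character(fit_best$frame$var[is_split]))

# Cross-tabulate fitted classes against the true region
pred <- predict(fit_best, type = "class")
table(predicted = pred, true_region = in_region)
```

If the pruned tree has recovered the generating rule, split_vars should be x1 and x2, print(fit_best) should show cutpoints near 2 and -.5, and the off-diagonal cells of the table should be small.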