Tree-Based Models
The following code obtains two classification trees for the Iris dataset.
Functions rpart() and rpartXse() use the standard formula interface that most modeling functions in R use. This means specifying the abstract functional form of the model we are trying to obtain in the first argument and the available training data in the second.
# install.packages("devtools")
# library(devtools)
# install_github("ltorgo/DMwR2", ref="master")
library(DMwR2)
## Registered S3 method overwritten by 'quantmod':
## method from
## as.zoo.data.frame zoo
set.seed(1234)
data(iris)
ct1 <- rpartXse(Species ~ ., iris)
ct2 <- rpartXse(Species ~ ., iris, se=0)
The first tree is obtained with the default parameters of rpartXse(), which means using 1-SE post-pruning. The second tree is obtained with a less "aggressive" pruning by specifying 0-SE pruning, which corresponds to selecting the subtree with the lowest estimated error from the initial, overly large tree. Note the call to set.seed(): the error estimates used for pruning come from cross-validation, which involves random sampling, so fixing the seed ensures you get the same trees.
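The X-SE selection rule described above can be sketched using only the rpart package. This is a minimal illustration and not necessarily how rpartXse() is implemented internally: the helper names choose_row() and n_leaves() are hypothetical, and the control values (cp = 0, minsplit = 6) are assumptions chosen to grow a large initial tree.

```r
library(rpart)

set.seed(1234)  # the cross-validation performed inside rpart() is random

# Grow a deliberately large tree; cp = 0 disables the complexity-based stopping rule.
# (cp and minsplit values here are illustrative assumptions.)
full <- rpart(Species ~ ., data = iris,
              control = rpart.control(cp = 0, minsplit = 6))

# The complexity table has columns CP, nsplit, rel error, xerror and xstd,
# with rows ordered from the smallest subtree to the largest.
cpt <- full$cptable

# Hypothetical helper: pick the smallest subtree whose cross-validated error
# is within `se` standard errors of the lowest cross-validated error.
choose_row <- function(cpt, se) {
  best <- which.min(cpt[, "xerror"])
  threshold <- cpt[best, "xerror"] + se * cpt[best, "xstd"]
  which(cpt[, "xerror"] <= threshold)[1]
}

row_1se <- choose_row(cpt, se = 1)  # 1-SE rule
row_0se <- choose_row(cpt, se = 0)  # lowest estimated error

tree_1se <- prune(full, cp = cpt[row_1se, "CP"])
tree_0se <- prune(full, cp = cpt[row_0se, "CP"])

# Hypothetical helper: count the terminal nodes of an rpart tree.
n_leaves <- function(tree) sum(tree$frame$var == "<leaf>")
c(one_se = n_leaves(tree_1se), zero_se = n_leaves(tree_0se))
```

Because the 1-SE threshold is never below the minimum error, the 1-SE tree can never be larger than the 0-SE tree, which is why the former is described as the more aggressive pruning.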
Printing the two objects (ct1 and ct2) gives a textual representation of the trees, but a graphical representation is usually easier to read. Package rpart.plot provides powerful graphical visualizations for rpart trees. In particular, function prp() can be used to plot the trees with many graphical variants, accessible through the large number of parameters of this function. The resulting diagram shows the two trees from the example above.
library(rpart.plot)
## Loading required package: rpart
par(mfrow=c(1,2))
prp(ct1, type=0, extra=101) # left tree
## Warning: Cannot retrieve the data used to build the model (so cannot determine roundint and is.binary for the variables).
## To silence this warning:
## Call prp with roundint=FALSE,
## or rebuild the rpart model with model=TRUE.
prp(ct2, type=0, extra=101) # right tree
## Warning: Cannot retrieve the data used to build the model (so cannot determine roundint and is.binary for the variables).
## To silence this warning:
## Call prp with roundint=FALSE,
## or rebuild the rpart model with model=TRUE.
Regression trees are obtained using the same procedure. The rpart() function decides whether to build a classification or a regression tree depending on the type of the target variable indicated in the formula. Trees obtained with this package can also be used to obtain predictions for a set of test cases. Package rpart follows the standard R convention of providing a predict() method for the objects produced by rpart(). This method takes the model as the first argument and the test cases (a data frame) as the second, as in the example below:
set.seed(1234)
rndSample <- sample(1:nrow(iris), 100)
tr <- iris[rndSample, ]
ts <- iris[- rndSample, ]
ct <- rpartXse(Species ~ ., tr, se=0.5)
psi <- predict(ct, ts)
head(psi)
## setosa versicolor virginica
## 1 1 0 0
## 2 1 0 0
## 5 1 0 0
## 6 1 0 0
## 8 1 0 0
## 9 1 0 0
ps2 <- predict(ct, ts, type="class")
head(ps2)
## 1 2 5 6 8 9
## setosa setosa setosa setosa setosa setosa
## Levels: setosa versicolor virginica
(cm <- table(ps2, ts$Species))
##
## ps2 setosa versicolor virginica
## setosa 18 0 0
## versicolor 0 18 3
## virginica 0 1 10
100*(1-sum(diag(cm))/sum(cm)) # the error rate
## [1] 8
If called without any further argument with a classification tree in the first argument, the predict() method of these rpart objects will output a matrix with the estimated probabilities of each class for each test case. With the extra argument type="class" we get a vector with the predicted classes for all test cases (which are the classes with the highest probability). To obtain predictions with regression trees, we use the same predict() function. However, in this case, the argument type="class" does not make sense, as the predictions are numbers and there are no class probabilities.
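The contrast between the two kinds of predictions can be sketched with plain rpart() calls. This is an illustrative example only: it uses rpart() with its default settings instead of rpartXse(), the variable names are hypothetical, and Sepal.Length is chosen as a numeric target purely for demonstration.

```r
library(rpart)

set.seed(1234)
idx   <- sample(seq_len(nrow(iris)), 100)
train <- iris[idx, ]
test  <- iris[-idx, ]

# Factor target (Species): rpart() builds a classification tree.
ctree   <- rpart(Species ~ ., data = train)
probs   <- predict(ctree, test)                  # matrix of class probabilities
classes <- predict(ctree, test, type = "class")  # factor of predicted classes

# The predicted class is the column of `probs` with the highest probability.
top <- colnames(probs)[max.col(probs, ties.method = "first")]
all(top == as.character(classes))

# Numeric target (Sepal.Length): rpart() builds a regression tree.
rtree <- rpart(Sepal.Length ~ ., data = train)
preds <- predict(rtree, test)                    # plain numeric vector

# The method chosen by rpart() is stored in the fitted object.
c(classification = ctree$method, regression = rtree$method)

# Mean absolute error of the regression predictions on the test cases.
mean(abs(preds - test$Sepal.Length))
```

For the regression tree there is one number per test case, so error measures such as the mean absolute error take the place of the confusion matrix and error rate used for classification.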