Today we are going to work on learning about classification trees. The goal is to get you comfortable with the methodology and the coding so you can apply it to your own data set.
We are going to work with a very classic data set that involves classifying irises into different species. To load the data, run the following:
data("iris")
The variable Species in the data set is our response variable. We also have four numeric predictors describing the sizes of the flowers' sepals and petals.
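If you want a quick look at these variables before we start modeling, the standard str and head commands will show the structure of the data and the first few rows:

str(iris)    # variable names and types
head(iris)   # first few rows of the data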
Our job is to try to use these predictors to classify flowers into three groups: setosa, versicolor, and virginica. There are a lot of possible ways such classification could be done, but we are going to use classification trees.
Before we get started, let's divide our data into test and training sets. We use the training set to train, or fit, our models; we also use it for EDA, model comparison, etc. The rest of the data (the test set) is used to make predictions and evaluate our model.
set.seed(100)
# sample 100 row numbers to use for the training set
train.rows <- sample(1:nrow(iris), 100)
train <- iris[train.rows, ]
test <- iris[-train.rows, ]
In classification trees, we work to create groups such that observations that are grouped together have similar values. If we can do that, we can use the values of the response variable in the cluster to help predict what the value of Y would be for any new observation that might fall into the cluster. For instance, suppose everyone in a cluster has a particular type of chest pain, and everyone in that cluster also has heart disease. Then, for a new person coming into our model, if they have that type of chest pain, we can predict they are likely to have heart disease!
The trick is how to define the clusters. We do this using the explanatory variables (features) that we already have in our data. This means that we try to use our X variables to create clusters of data points.
However...in the first step of the process, we don't use X at all! Instead, we initialize a classification tree by assigning all observations (flowers) to one giant cluster. In that cluster, how many observations are in each of the three categories? To estimate these probabilities, we compute
$$\hat{\pi}_{j} = \frac{1}{n}\sum_{i=1}^{n} I(y_i = j),$$ where $I(y_i = j)$ is 1 if observation $i$ is in category $j$ and 0 otherwise. Take a look at the probabilities you compute. With classification trees, these probabilities are very important, because they will be used to make predictions for a new data point that comes into the data set. If the probability of one category is highest, every new flower that comes into the data set right now will be predicted to be that category.
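In R, a quick way to compute these estimated probabilities on the training set (using the train data frame we created above) is:

# proportion of training flowers in each species: the pi-hat values
prop.table(table(train$Species))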
This is not very helpful. We don't want to make the same prediction for everyone - we want to use our predictors! To do this, we need to start to grow branches on our tree.
Now we need to determine how to move our data points out of this one big cluster into two smaller clusters. These clusters will be defined using exactly one of our predictors. Before we get into the math, let's see what that means. You are going to need to install a few packages, including caret, rpart, rpart.plot, and rattle. Once you have installed these packages, load each one with library. Then, run the following code to grow our classification tree:
library(caret)
library(rattle)   # provides fancyRpartPlot()

set.seed(150)
# fit a classification tree, choosing its complexity with 10-fold cross-validation
iris.tree <- train(Species ~ ., data = train, method = "rpart",
                   trControl = trainControl(method = "cv"))
fancyRpartPlot(iris.tree$finalModel)
The code you have just run will produce a classification tree! There is a lot to see here, but let's start off looking at the very first split from one cluster into two.
If you look at the two clusters you have produced, you will see three rows of information. The first row says "setosa", "versicolor", or "virginica"; this is the most common value of Y in that cluster. If categories are tied, the one shown is whichever comes first in the order of the categorical variable's levels. To figure out this order, run
levels(iris$Species)
The values in the second row of a cluster tell you the proportions of the data points in the cluster that belong to each value of Y. Again, these are ordered according to the levels of the categorical variable.
The third row in the cluster tells you the percent of the total data that have landed in this cluster!
All of this information is followed up with a color coding that indicates what prediction you would use for a new observation that fell into a cluster.
And that's it! This is how we interpret and use our tree.
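If you would rather see the same information as text instead of a picture, printing the final model also works:

# text version of the fitted tree: one line per cluster (node),
# showing the split rule, the counts, and the class proportions
iris.tree$finalModel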
Once the tree is grown, we usually compute a classification error rate (CER) to see how many of our original (training) data points we would incorrectly predict using our tree.
Ideally, we like CER values that are as low as possible if we are planning to use a tree for prediction.
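As a sketch, here is one way to compute the training CER for the tree we grew above (using the iris.tree and train objects from earlier):

# predicted species for the training flowers
train.pred <- predict(iris.tree, newdata = train)
# classification error rate: proportion of training flowers we got wrong
mean(train.pred != train$Species)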
If we want to make predictions for our test data set, we can use the predict function in R.
iris.pred <- predict(iris.tree, newdata = test)
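To see how those predictions did, one option (a sketch using caret's confusionMatrix and the test set we built earlier) is:

# compare predicted vs. actual species on the test set
confusionMatrix(iris.pred, test$Species)
# or just the test classification error rate
mean(iris.pred != test$Species)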
Okay, so now we know where we are going. How did we get there? How did we actually grow and create this tree?
The process used is called recursive binary splitting. We start off with all the data in one big cluster, and then split, one split at a time, to grow the tree. There are standards for when and how we split, and when and how we stop. Let's see a few.
There are a few metrics used to grow classification trees, but the most common is the Gini Index. Suppose we are fitting a classification tree where the outcome variable has K levels. K must be at least two, but it can be more than two. Then the Gini Index is
$$G = \sum_{k=1}^{K} \hat{p}_{k} (1- \hat{p}_{k}).$$ Here, $\hat{p}_{k}$ is the proportion of observations in the cluster that belong to class $k$.
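As a quick sketch, the Gini Index of the root node can be computed from the training proportions we found earlier:

# Gini Index when all training flowers sit in one big cluster
p.hat <- prop.table(table(train$Species))
sum(p.hat * (1 - p.hat))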
This Gini Index that you have just computed is the Gini Index of the root node (the single giant cluster that you started with). When we create our first split, our goal is to try to make the Gini Index as small as we can. This is because the Gini Index is a measure of "purity"; it goes to 0 if all observations in a leaf have the same Y value, and increases with more variety. Our goal is to build a good tool for predictions, so we are happiest when one value of Y dominates a cluster.
Okay, great. But...how did we know to split on Petal Length to create our two new clusters? And how did we know that splitting on Petal Length < 2.5 was what we should do? We didn't...we actually tried everything! R chooses the first potential X variable and creates splits on many possible values of that X (well, values like .1, .2, .3...not everything, but a lot of choices!). It computes the Gini Index for each split, and then saves the split point for that X variable with the lowest Gini Index.
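To make this concrete, here is a small sketch of that computation for one candidate split (Petal Length < 2.5, the split described above); the gini helper function is just for illustration:

# Gini Index of a single cluster of species labels
gini <- function(y) {
  p <- prop.table(table(y))
  sum(p * (1 - p))
}

# split the training data on one candidate rule
left  <- train$Species[train$Petal.Length < 2.5]
right <- train$Species[train$Petal.Length >= 2.5]

# weighted Gini Index for this split (weighted by cluster size)
(length(left) / nrow(train)) * gini(left) +
  (length(right) / nrow(train)) * gini(right)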
With that done, R moves on to the next X variable and repeats the whole process. At the end, there are four candidate splits: the best one (lowest Gini Index) for Petal Length, for Petal Width, for Sepal Length, and for Sepal Width. We then choose whichever of these four has the smallest Gini Index and split the data there.
This same process continues to create every split on the tree.