In this example, we’ll build a classification decision tree to analyze whether a particular individual will have an affair, based on demographics and other data.
Start by loading the packages we will use in the analysis. The rpart package is the main workhorse for building and analyzing decision trees. The caret package is used for splitting the data into training and test sets. The remaining packages will enable us to output a much more beautifully constructed decision tree plot.
library(rpart)
library(rpart.plot)
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
library(readr)
library(rattle)
## Loading required package: RGtk2
## Rattle: A free graphical interface for data mining with R.
## Version 3.5.0 Copyright (c) 2006-2015 Togaware Pty Ltd.
## Type 'rattle()' to shake, rattle, and roll your data.
library(RColorBrewer)
Decision trees can be powerful tools for data analysis and prediction. They’re also often easier to interpret than linear regression models, which is especially handy for those of us whose main role is not data analysis. However, as trees grow they become less interpretable, though they can still offer greater predictive performance than regression (of course, this is not always the case).
The tree starts at the top and finds the variable and split point that best partition the data into nodes. It does this by recursive binary splitting, evaluating candidate splits with either the Gini index or the cross-entropy measure. The Gini index is defined as:
\[G = \sum_{k=1}^K \hat{p}_{mk}(1 - \hat{p}_{mk})\]
and is also referred to as a measure of node purity; a smaller value indicates a node contains observations primarily from a single class.
Cross-entropy is similar to the Gini index in that it will take a small value if the node is pure. It is defined as:
\[D = -\sum_{k=1}^K \hat{p}_{mk} \log \hat{p}_{mk}\]
The Gini index or cross-entropy measure dictates where each node split occurs: splits are chosen to keep the resulting nodes as pure as possible, i.e., to reduce the total value of the chosen measure.
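To make these measures concrete, here is a minimal R sketch (my addition, not part of the original analysis) computing both for a hypothetical node containing 75% non-cheaters and 25% cheaters:

# Hypothetical class proportions within a single node
p <- c(No = 0.75, Yes = 0.25)

gini <- sum(p * (1 - p))       # G = sum of p_mk * (1 - p_mk)
entropy <- -sum(p * log(p))    # D = -sum of p_mk * log(p_mk)

gini      # 0.375  -- smaller values indicate a purer node
entropy   # roughly 0.562

A node containing only one class would score 0 on both measures.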
Start at the top, or root, of the tree: 75% of the individuals will not have an affair, with 25% committing an affair. If the individual’s rating of their marriage is equal to or above 2.5, you move left; otherwise you move right. To the left, only 21% of those who rated their marriage 2.5 or above will commit an affair, so that terminal node falls in the “no affair” bucket. To the right of the root node, those who rated their marriage below 2.5 are split more evenly between cheaters and non-cheaters, so another split occurs, this time on the occupation variable. If an individual’s occupation is rated 2.5 or below, you move left; the terminal node there states only 37% of individuals with an occupation rated 2.5 or below will commit an affair.
Start by loading the data into R. I used the readr package, but you could use the base read.csv or read.table. The Fair data is also available in the Ecdat package.
fair <- read_csv('C:/Users/aarschle1/Google Drive/Projects/Stats/decisiontrees/classification/fair.csv')
Check out the first few rows of the data using the head() function.
head(fair)
## sex age ym child religious education occupation rate nbaffairs
## 1 male 37 10.00 no 3 18 7 4 0
## 2 female 27 4.00 no 4 14 6 4 0
## 3 female 32 15.00 yes 1 12 1 4 0
## 4 male 57 15.00 yes 5 18 6 5 0
## 5 male 22 0.75 no 2 17 6 3 0
## 6 female 32 1.50 no 2 17 5 5 0
We can see the data is made up of several factors and continuous variables. The religious, occupation, and rate columns are recorded on ordinal scales. To make these variables more useful in the analysis and prediction, we will perform some feature engineering to create new categories.
Before that, let’s plot a histogram of the number of affairs in the past year (nbaffairs) to see how the data is distributed.
hist(fair$nbaffairs, freq=FALSE, main="Density plot")
curve(dnorm(x, mean=mean(fair$nbaffairs), sd=sd(fair$nbaffairs)), add = TRUE, col='darkblue', lwd=2)
Wow, roughly 75% of the individuals in the data did not have an affair within the past year! The remaining 25% had between 1 and 12 affairs in the past year. Since we are only interested in analyzing and predicting the probability of an individual having an affair in the future, we can change the number of affairs to a binary factor of Yes and No (or 1 and 0).
fair$past.affairs <- ifelse(fair$nbaffairs == 0, "No", "Yes")
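As a quick sanity check (my addition), we can confirm the resulting class balance:

# Proportion of individuals in each class of the new outcome variable
prop.table(table(fair$past.affairs))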
We can also change the other binary variables in the data to factors using the as.factor function. Character variables need to be converted to factors in order to be used in the analysis.
fair$sex.f <- as.factor(fair$sex)
fair$child.f <- as.factor(fair$child)
Let’s take a look at a histogram of the age variable to get a sense of the distribution of the individuals’ ages.
hist(fair$age, freq=FALSE, main="Density plot")
curve(dnorm(x, mean=mean(fair$age), sd=sd(fair$age)), add = TRUE, col='darkblue', lwd=2)
Judging by the histogram, the ages of the individuals are much more normally distributed, centered around the age of 30. Since we want to treat age as a categorical variable, we can split the continuous age variable into categories, similar to survey questions that ask for an individual’s age within a particular range. To do this, let’s find all the unique values in the age column.
sort(unique(fair$age))
## [1] 17.5 22.0 27.0 32.0 37.0 42.0 47.0 52.0 57.0
The lowest age is 17.5 and the highest is 57. This can be cut into roughly five-year groups using the cut() function. One could also use a ten-year range; however, a five-year range in this case keeps individuals’ ages in similar groups.
fair$age.bins <- cut(fair$age, 7, labels=c('17-22','23-28','29-34','35-40','41-46','47-52','53-57'))
The education variable can also be considered a factor, as it represents the individual’s years of education. We can classify the years of education using the cut() function as above. First, let’s see the unique values of the education variable to help determine the appropriate classifications.
sort(unique(fair$education))
## [1] 9 12 14 16 17 18 20
The lowest number of years of education is 9 and the highest is 20. We can create categories following standard U.S. educational conventions: 12 years of education typically means the individual completed high school, 16 is the typical number of years it takes to complete a bachelor’s degree, and so on.
fair$edu.f <- cut(fair$education, 7, labels=c('No High-School Degree', 'High-School Completed', 'Some College', 'Bachelors Degree', 'Post-Baccalaureate', 'Masters Degree', 'Post-Doc'))
Let’s now take a look at the factors based on a rising scale. According to the data set information, occupation is coded according to the Hollingshead classification. We can create classifications for each occupation rating based on the names designated by Hollingshead.
fair$occ.f <- cut(fair$occupation, 7, labels=c('Menial Service', 'Unskilled', 'Semiskilled', 'Skilled', 'Clerical', 'Semiprofessional', 'Minor Professional'))
The other two scale variables are more straightforward. The rate variable is a scale from 1 - 5 representing the self-rating of the individual’s marriage, with 1 being very unhappy and 5 being very happy.
Religion is similar as it is also on a 1 - 5 scale with 1 being anti-religious to 5 being very religious.
fair$rating.f <- cut(fair$rate, 5, labels=c('Very Unhappy', 'Unhappy', 'Average', 'Happy', 'Very Happy'))
fair$religion.f <- cut(fair$religious, 5, labels=c('Not at all', 'Somewhat Against', 'Not sure', 'Somewhat Religious', 'Very Religious'))
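At this point it can be worth a quick structural check (my addition) that the engineered columns look as intended, using the column names created above:

# Inspect the engineered outcome and predictor columns
str(fair[, c('past.affairs', 'sex.f', 'child.f', 'age.bins',
             'edu.f', 'occ.f', 'rating.f', 'religion.f')])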
Now that we’ve explored the data and created features from the raw data set, we can build our classification decision tree!
We split the data into training and test sets using the caret package. 70% of the data will go to training with the remaining 30% designated for testing the model.
inTrain <- createDataPartition(y = fair$past.affairs,
p = 0.7,
list = FALSE)
training <- fair[ inTrain,]
testing <- fair[-inTrain,]
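Because createDataPartition samples within each level of the outcome, the class balance should be roughly the same in both sets. A quick check (my addition) might look like:

# Sizes of the two sets and the class balance within each
dim(training); dim(testing)
prop.table(table(training$past.affairs))
prop.table(table(testing$past.affairs))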
Before building the tree, set a seed so we can replicate the results later. This will come in handy when we prune the tree.
set.seed(1)
Creating the tree is straightforward with the rpart package, which allows us to use a formula call to build the tree. The past.affairs variable created earlier is used as the y variable. The data is the training set, and since we are predicting categories, or classes, the method is set to class.
tree <- rpart(past.affairs ~ sex.f + child.f + age.bins + occ.f + rating.f + ym + religion.f + edu.f, data = training, method='class')
We can then examine the tree using the fancyRpartPlot function, which is included in the rattle package.
fancyRpartPlot(tree)
Our tree is definitely a lot deeper than the previous example. Splits were found on individuals’ ages, religious convictions, education, and other variables. An individual’s rating of their marriage appears to be the most important factor, along with age. Somewhat expectedly, those who rated their marriage positively and considered themselves religious were among the least likely to have an affair, whereas more educated individuals with no children were among the most likely.
We can use the printcp function to find more information about the tree.
printcp(tree)
##
## Classification tree:
## rpart(formula = past.affairs ~ sex.f + child.f + age.bins + occ.f +
## rating.f + ym + religion.f + edu.f, data = training, method = "class")
##
## Variables actually used in tree construction:
## [1] age.bins child.f edu.f occ.f rating.f religion.f
##
## Root node error: 105/421 = 0.24941
##
## n= 421
##
## CP nsplit rel error xerror xstd
## 1 0.066667 0 1.00000 1.00000 0.084549
## 2 0.019048 2 0.86667 0.92381 0.082286
## 3 0.014286 4 0.82857 0.88571 0.081068
## 4 0.011905 6 0.80000 1.00952 0.084816
## 5 0.010000 10 0.75238 1.00952 0.084816
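The rpart package also provides a plotcp function, which plots the cross-validated error against the complexity parameter and can be a helpful visual companion to the table above (my addition):

# Cross-validated error versus cp, useful when choosing a value for pruning
plotcp(tree)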
As we noticed when viewing the tree, marriage rating and age, followed by education and religious conviction, appear to be the most important indicators of an individual’s propensity to cheat.
We can use the variable.importance element of the tree object to see how important each variable was in the tree’s construction.
tree$variable.importance
## age.bins rating.f occ.f edu.f religion.f ym
## 13.612347 12.829227 8.603591 6.157185 4.381092 2.954118
## child.f sex.f
## 2.438692 1.189296
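A quick bar plot (my addition) makes the relative importance easier to see:

# Visualize the importance scores reported above
barplot(tree$variable.importance, las = 2, cex.names = 0.7,
        main = "Variable importance")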
With the tree built, we can check its predictive power by using it on the test data set we created earlier.
tree.pred <- predict(tree, testing, type = "class")
We construct a quick table to show us the number of correctly predicted observations on the test data set.
table(tree.pred, testing$past.affairs)
##
## tree.pred No Yes
## No 121 37
## Yes 14 8
(121 + 8) / nrow(testing)
## [1] 0.7166667
The tree was able to accurately predict about 72% of the individuals’ fidelity status.
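Since caret is already loaded, its confusionMatrix function gives a fuller summary (accuracy, sensitivity, specificity, and more). Note that past.affairs is stored as a character column here, so it needs to be coerced to a factor first (my addition):

# caret's confusionMatrix expects factors with matching levels
confusionMatrix(tree.pred, factor(testing$past.affairs, levels = levels(tree.pred)))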
We can prune the tree using a cost-complexity parameter (cp) of 0.05 (it’s generally best to choose a cp value between the top two rows printed by the printcp function) to see if the results are any better.
pruned <- prune(tree, cp = 0.05)
fancyRpartPlot(pruned)
The pruned tree is definitely easier to interpret, with only marriage rating and age in the decisions.
printcp(pruned)
##
## Classification tree:
## rpart(formula = past.affairs ~ sex.f + child.f + age.bins + occ.f +
## rating.f + ym + religion.f + edu.f, data = training, method = "class")
##
## Variables actually used in tree construction:
## [1] age.bins rating.f
##
## Root node error: 105/421 = 0.24941
##
## n= 421
##
## CP nsplit rel error xerror xstd
## 1 0.066667 0 1.00000 1.00000 0.084549
## 2 0.050000 2 0.86667 0.92381 0.082286
Does it improve prediction? We can predict again using the new pruned tree on the testing data set.
pruned.pred <- predict(pruned, testing, type = 'class')
We create a new table to display the number of correctly classified individuals.
table(pruned.pred, testing$past.affairs)
##
## pruned.pred No Yes
## No 134 42
## Yes 1 3
(134 + 3) / nrow(testing)
## [1] 0.7611111
The pruned tree correctly classifies about 76% of the observations, a modest improvement over the full tree. By playing around with the cp values and setting tree controls with the rpart package we could likely improve this even further.
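As a sketch of what that tuning might look like (the parameter values below are illustrative guesses, not recommendations from this analysis):

# 1. Let cross-validation pick cp: take the cp with the lowest xerror
best.cp <- tree$cptable[which.min(tree$cptable[, "xerror"]), "CP"]
pruned.cv <- prune(tree, cp = best.cp)

# 2. Regrow the tree with different control settings via rpart.control
#    (minsplit and cp values here are arbitrary starting points)
tree.ctrl <- rpart(past.affairs ~ sex.f + child.f + age.bins + occ.f +
                     rating.f + ym + religion.f + edu.f,
                   data = training, method = 'class',
                   control = rpart.control(minsplit = 10, cp = 0.005))

Either tree could then be evaluated on the testing set exactly as we did above, comparing accuracy against the pruned tree.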