Set up

knitr::opts_chunk$set(echo = TRUE,
                      fig.align = "center")
pacman::p_load(rpart, rpart.plot, tidyverse, caret)

theme_set(theme_bw())

# We will use the cats data set from the MASS package
cats <- 
  MASS::cats |> 
  rename(heart_wt = Hwt,
         body_wt = Bwt)
head(cats)
##   Sex body_wt heart_wt
## 1   F     2.0      7.0
## 2   F     2.0      7.4
## 3   F     2.0      9.5
## 4   F     2.1      7.2
## 5   F     2.1      7.3
## 6   F     2.1      7.6

Using the Tree Method for Classification

Plot of Cat Data

We will use Body Weight and Heart Weight to try to classify cats as male or female. Here is a plot of the data:

# Scatterplot of heart weight vs. body weight, with points colored by sex
ggplot(
  data = cats,
  mapping = aes(
    x = body_wt, 
    y = heart_wt,
    color = Sex
  )
) + 
  geom_jitter() + 
  labs(
    x = "Body Weight (kg)",
    y = "Heart Weight (g)",
    color = NULL
  ) + 
  
  scale_color_manual(
    values = c("F" = "tomato", "M" = "steelblue"),
    labels = c("F" = "Female", "M" = "Male")
  ) + 
  
  theme(legend.position = "top")

Grow the full decision tree

We’ll use the rpart() function to create the tree. Its arguments are:

  1. formula = {response} ~ {explanatory variables}

  2. data = the data set containing those variables

  3. method = "class" for classification, "anova" for regression

  4. parms = list(split = "information") tells it to use entropy for determining splits

To fully grow the tree, we need to add a few additional arguments:

  1. minsplit = 0 (attempt a split no matter how few observations are in a node)

  2. minbucket = 0 (allow terminal nodes with any number of observations)

  3. cp = -1 (keep every split, regardless of how little it improves the fit)

Create the full decision tree that we will later prune:

RNGversion("4.1.0")
set.seed(5230)

# Build the full decision tree here
cat_tree <- 
  rpart(
    formula = Sex ~ body_wt + heart_wt,
    data = cats,
    method = "class",
    parms = list(split = "information"),
    minsplit = 0, 
    minbucket = 0,
    cp = -1
  )  

Plot the full decision tree with rpart.plot(), which takes three arguments:

  1. x = the tree created with rpart()

For the next two arguments, you can choose different numbers to see how the plot changes (an example with different settings follows the plot below):

  2. type = 5

  3. extra = 101

rpart.plot(x = cat_tree,
           type = 5,
           extra = 101)
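
Here is one such variation, just as a quick illustration (see ?rpart.plot for the full list of type and extra codes):

# type = 2 draws the split labels below the nodes; extra = 104 shows
# per-class proportions plus the percentage of observations in each node
rpart.plot(x = cat_tree,
           type = 2,
           extra = 104)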

Pruning the full decision tree

We can look at the complexity parameter (cp) table to find where to prune the tree.

# Use the cptable to find the best value of cp to use to prune the full tree
cat_tree$cptable
##              CP nsplit  rel error    xerror      xstd
## 1   0.191489362      0 1.00000000 1.0000000 0.1197170
## 2   0.042553191      2 0.61702128 0.7234043 0.1084318
## 3   0.028368794      3 0.57446809 0.8297872 0.1134613
## 4   0.021276596      6 0.48936170 0.8085106 0.1125293
## 5   0.017021277     13 0.31914894 0.8510638 0.1143583
## 6   0.014184397     18 0.23404255 0.8085106 0.1125293
## 7   0.010638298     21 0.19148936 0.8085106 0.1125293
## 8   0.007092199     27 0.12765957 0.8723404 0.1152209
## 9   0.000000000     33 0.08510638 0.8936170 0.1160501
## 10 -1.000000000     39 0.08510638 0.8936170 0.1160501

What we want to find is the first row in the table (i.e., the simplest tree) with:

xerror < min(xerror) + xstd

where xstd comes from the row with the minimum xerror. This is the one-standard-error (1-SE) rule for choosing cp.

cat_tree$cptable |> 
  # Need to convert it to a data frame to use any dplyr verbs
  data.frame() |> 
  # Keep the row with the smallest xerror
  slice_min(xerror) |> 
  # The cutoff is that row's xerror plus its xstd
  mutate(cut_off = xerror + xstd)
##           CP nsplit rel.error    xerror      xstd   cut_off
## 2 0.04255319      2 0.6170213 0.7234043 0.1084318 0.8318361
# Plotting the CP table
plotcp(cat_tree)

Using our cptable, we need the simplest tree with an xerror below 0.723 + 0.108 ≈ 0.832, which is row 2, with a cp of about 0.043.
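
We can also pick that row programmatically. Here is a minimal sketch applying the 1-SE rule above (the name cp_pick is ours, just for illustration):

# Find the first (simplest) row whose xerror is below the 1-SE cutoff
cp_pick <- 
  cat_tree$cptable |> 
  data.frame() |> 
  mutate(cut_off = min(xerror) + xstd[which.min(xerror)]) |> 
  filter(xerror < cut_off) |> 
  slice(1)

cp_pick$CP
## [1] 0.04255319

Because the cptable rows are ordered from simplest to most complex, slice(1) returns the simplest qualifying tree.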

Once we find the cp value of the tree we want to use, we can use the prune() function to, well, prune the tree.

It has 2 arguments:

  1. tree = the tree we want to prune

  2. cp = a value slightly higher than the CP listed for the row we want (and lower than the CP of the row above it)

Use prune() and plot it:

# Prune the tree, using a cp just above row 2's value of 0.0426
cats_pruned <- 
  prune(
    tree = cat_tree,
    cp = 0.043
  )

# Then plot it:
rpart.plot(
  x = cats_pruned,
  type = 5,
  extra = 101
)
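
As an optional final check, caret is already loaded, so we can compare the pruned tree's predictions to the actual sexes with a confusion matrix. This is a minimal sketch evaluated on the training data, so the accuracy it reports will be optimistic:

# Predicted classes for the training data from the pruned tree
cat_preds <- predict(cats_pruned, type = "class")

# Cross-tabulate predictions against the true sexes
confusionMatrix(data = cat_preds, reference = cats$Sex)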