knitr::opts_chunk$set(echo = TRUE,
                      fig.align = "center")
pacman::p_load(rpart, rpart.plot, tidyverse, caret)
theme_set(theme_bw())
# We will use the cats data, which is in the MASS package
cats <-
  MASS::cats |>
  rename(heart_wt = Hwt,
         body_wt = Bwt)
head(cats)
## Sex body_wt heart_wt
## 1 F 2.0 7.0
## 2 F 2.0 7.4
## 3 F 2.0 9.5
## 4 F 2.1 7.2
## 5 F 2.1 7.3
## 6 F 2.1 7.6
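Before going further, it's worth checking how balanced the two sexes are; a minimal sketch using dplyr's count():

# How many cats of each sex are in the data?
cats |>
  count(Sex)

The data contain roughly twice as many males as females, which is worth remembering when judging the tree's accuracy.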
We will use Body Weight and Heart Weight to try to classify cats as male or female. Here is a plot of the data:
# Make a scatterplot of heart weight by body weight, coloring points by sex.
ggplot(
  data = cats,
  mapping = aes(
    x = body_wt,
    y = heart_wt,
    color = Sex
  )
) +
  geom_jitter() +
  labs(
    x = "Body Weight (kg)",
    y = "Heart Weight (g)",
    color = NULL
  ) +
  scale_color_manual(
    values = c("F" = "tomato", "M" = "steelblue"),
    labels = c("F" = "Female", "M" = "Male")
  ) +
  theme(legend.position = "top")
We'll use the rpart() function to create the tree. Its arguments are:

- formula = {response} ~ {explanatory variables}
- data = the data frame containing the variables
- method = "class" for classification, "anova" for regression
- parms = list(split = "information") tells it to use entropy for determining splits
To fully grow the tree, we need to add a couple of additional arguments (a control-object equivalent is sketched after this list):

- minsplit = 0: don't require a minimum number of observations in a node before attempting a split
- minbucket = 0: don't require a minimum number of observations in a terminal node
- cp = -1: keep splitting even when a split doesn't improve the fit
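These three are control parameters, so they can equivalently be bundled into an rpart.control() object and passed to rpart() via its control argument (a minimal sketch; full_control is an illustrative name):

# Bundle the growth settings into a single control object
full_control <- rpart.control(minsplit = 0,
                              minbucket = 0,
                              cp = -1)
# Passing control = full_control to rpart() grows the same full tree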
Create the full decision tree; we'll use it to find where to prune:
RNGversion("4.1.0")
set.seed(5230)
# Build the full decision tree here
cat_tree <-
  rpart(
    formula = Sex ~ body_wt + heart_wt,
    data = cats,
    method = "class",                    # classification tree
    parms = list(split = "information"), # split using entropy
    minsplit = 0,                        # no minimum node size to split
    minbucket = 0,                       # no minimum terminal node size
    cp = -1                              # never stop on complexity grounds
  )
Plot the full decision tree with rpart.plot() and three arguments:

- x = the tree created with rpart()
- type = 5 and extra = 101: you can choose different numbers for these two to see how the plot changes
rpart.plot(x = cat_tree,
           type = 5,
           extra = 101)
We can look at the complexity parameter (cp) table to find where to prune the tree.
# Use the cptable to find the best value of cp to use to prune the full tree
cat_tree$cptable
## CP nsplit rel error xerror xstd
## 1 0.191489362 0 1.00000000 1.0000000 0.1197170
## 2 0.042553191 2 0.61702128 0.7234043 0.1084318
## 3 0.028368794 3 0.57446809 0.8297872 0.1134613
## 4 0.021276596 6 0.48936170 0.8085106 0.1125293
## 5 0.017021277 13 0.31914894 0.8510638 0.1143583
## 6 0.014184397 18 0.23404255 0.8085106 0.1125293
## 7 0.010638298 21 0.19148936 0.8085106 0.1125293
## 8 0.007092199 27 0.12765957 0.8723404 0.1152209
## 9 0.000000000 33 0.08510638 0.8936170 0.1160501
## 10 -1.000000000 39 0.08510638 0.8936170 0.1160501
What we want to find is the first row in the table with

xerror < min(xerror) + xstd

where xstd comes from the row with the minimum xerror (the one-standard-error rule).
cat_tree$cptable |>
  # Need to convert it to a data frame to use any dplyr verbs
  data.frame() |>
  slice_min(xerror) |>
  mutate(cut_off = xerror + xstd)
## CP nsplit rel.error xerror xstd cut_off
## 2 0.04255319 2 0.6170213 0.7234043 0.1084318 0.8318361
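We can apply the rule programmatically as well. Here is a minimal sketch that picks out the first CP-table row whose xerror falls below the cut-off (cp_table, cut_off, and best_cp are illustrative names, not part of the original code):

# Rebuild the cut-off, then keep the first row that falls below it
cp_table <- data.frame(cat_tree$cptable)
cut_off <- with(slice_min(cp_table, xerror), xerror + xstd)
best_cp <-
  cp_table |>
  filter(xerror < cut_off) |>
  slice(1) |>
  pull(CP)
best_cp
## [1] 0.04255319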
# Plotting the CP table
plotcp(cat_tree)
Using our cptable, we need the simplest tree with an xerror below 0.7234 + 0.1084 ≈ 0.832, which is row 2 with a cp of about 0.043.
Once we find the cp value of the tree we want to use, we can use the prune() function to, well, prune the tree. It has 2 arguments:

- tree = the tree we want to prune
- cp = a cp value slightly higher than the CP-table value of the tree we want to keep
Use prune() and plot the result:
# Prune the tree
cats_pruned <-
  prune(
    tree = cat_tree,
    cp = 0.043
  )
# Then plot it:
rpart.plot(
  x = cats_pruned,
  type = 5,
  extra = 101
)
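To see how well the pruned tree does, here is a minimal sketch that predicts on the training data and summarizes the results with caret's confusionMatrix(). This gives in-sample accuracy only, so it will be optimistic; a holdout set would give an honest estimate:

# Predict each cat's sex with the pruned tree and compare to the observed sexes
cat_preds <- predict(cats_pruned, newdata = cats, type = "class")
confusionMatrix(data = cat_preds, reference = cats$Sex)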