Chapter 5: Classification using Decision Trees

knitr::opts_chunk$set(echo = TRUE,
                      fig.align = "center")

# Load the packages used in this chapter:
pacman::p_load(rpart, rpart.plot, tidyverse, caret)

# Changing the default theme
theme_set(theme_bw())

Decision Trees

Example: To Eat or not to Eat?

Can we determine if a mushroom is safe to eat based on the other features?

Step 1: Mushroom Data

The mushroom_full.csv file has 23 variables on 8124 different mushrooms: 22 features plus type, which records whether the mushroom is edible or poisonous. Type is the variable we want to classify.

Step 2: Exploring and preparing the data

# There are two data sets:
#   mushroom.csv contains just a few of the features
#   mushroom_full.csv contains all 23 variables on the 8124 mushrooms

# To use the smaller file instead, remove the # at the start of the next line
# (and comment out the full data set code below):
# mushroom <- read.csv("mushroom.csv", stringsAsFactors = TRUE)

# Here we use the full data set with all 23 variables,
# renaming the response type to edible:
mushroom <- 
   read.csv("mushroom_full.csv", 
            stringsAsFactors = TRUE) |> 
   rename(edible = type)

skimr::skim(mushroom) 
Data summary
Name                     mushroom
Number of rows           8124
Number of columns        23
Column type frequency:
  factor                 23
Group variables          None

Variable type: factor

skim_variable             n_missing  complete_rate  ordered  n_unique  top_counts
edible                    0          1              FALSE    2         edi: 4208, poi: 3916
cap_shape                 0          1              FALSE    6         con: 3656, fla: 3152, kno: 828, bel: 452
cap_surface               0          1              FALSE    4         sca: 3244, smo: 2556, fib: 2320, gro: 4
cap_color                 0          1              FALSE    10        bro: 2284, gra: 1840, red: 1500, yel: 1072
bruises                   0          1              FALSE    2         no: 4748, yes: 3376
odor                      0          1              FALSE    9         non: 3528, fou: 2160, fis: 576, spi: 576
gill_attachment           0          1              FALSE    2         fre: 7914, att: 210
gill_spacing              0          1              FALSE    2         clo: 6812, cro: 1312
gill_size                 0          1              FALSE    2         bro: 5612, nar: 2512
gill_color                0          1              FALSE    12        buf: 1728, pin: 1492, whi: 1202, bro: 1048
stalk_shape               0          1              FALSE    2         tap: 4608, enl: 3516
stalk_root                0          1              FALSE    5         bul: 3776, mis: 2480, equ: 1120, clu: 556
stalk_surface_above_ring  0          1              FALSE    4         smo: 5176, sil: 2372, fib: 552, sca: 24
stalk_surface_below_ring  0          1              FALSE    4         smo: 4936, sil: 2304, fib: 600, sca: 284
stalk_color_above_ring    0          1              FALSE    9         whi: 4464, pin: 1872, gra: 576, bro: 448
stalk_color_below_ring    0          1              FALSE    9         whi: 4384, pin: 1872, gra: 576, bro: 512
veil_type                 0          1              FALSE    1         par: 8124
veil_color                0          1              FALSE    4         whi: 7924, bro: 96, ora: 96, yel: 8
ring_number               0          1              FALSE    3         one: 7488, two: 600, non: 36
ring_type                 0          1              FALSE    5         pen: 3968, eva: 2776, lar: 1296, fla: 48
spore_print_color         0          1              FALSE    9         whi: 2388, bro: 1968, bla: 1872, cho: 1632
population                0          1              FALSE    6         sev: 4040, sol: 1712, sca: 1248, num: 400
habitat                   0          1              FALSE    7         woo: 3148, gra: 2148, pat: 1144, lea: 832

Since all the features are factors, we should check to see how many groups each feature has:

skimr::skim(mushroom) |>  
  ggplot(
    mapping = aes(
      x = reorder(skim_variable, factor.n_unique),
      y = factor.n_unique
    )
  ) +
  
  geom_col(
    color = "black", 
    fill = "steelblue"
  ) + 
  
  theme(
    axis.text.x = element_text(angle = 90, vjust = 0.5, hjust=1)
  ) + 
  
  labs(
    x = "Mushroom Feature",
    y = "Number of groups"
  ) + 
  
  scale_y_continuous(
    breaks = seq(0, 12, by = 2),
    expand = c(0, 0, 0.05, 0)
  )
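If skimr isn't available, a base-R sketch can count the groups (factor levels) in each column with sapply() and nlevels(). The small data frame below is made up for illustration; on the real data the same idea would be sapply(mushroom, nlevels).

```r
# Made-up stand-in for the mushroom data (hypothetical values)
toy <- data.frame(
  odor      = factor(c("none", "foul", "almond", "none")),
  veil_type = factor(c("partial", "partial", "partial", "partial"))
)

# Number of groups (factor levels) in each column, smallest first
n_groups <- sort(sapply(toy, nlevels))
n_groups
```

A column that shows up with only 1 group here is a candidate for removal.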

When using the full data set, we should remove veil_type, since every mushroom has the same value (partial). A feature with only 1 group doesn't vary, so it can't help tell edible and poisonous mushrooms apart.

mushroom <- 
  mushroom |> 
  select(-veil_type)
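Rather than naming constant columns one at a time, a sketch with dplyr's where() can drop every column that takes only one distinct value. The data frame is made up for illustration; on the real data the response edible has 2 values, so it would be kept.

```r
library(dplyr)

# Made-up stand-in for the mushroom data (hypothetical values)
toy <- data.frame(
  odor      = factor(c("none", "foul", "almond", "none")),
  veil_type = factor(c("partial", "partial", "partial", "partial"))
)

# Keep only the columns that take more than one distinct value
toy_clean <- toy |> select(where(~ n_distinct(.x) > 1))
names(toy_clean)
```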

Step 3: Forming the full tree

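Build the classification tree using rpart(). Setting method = "class" builds a classification tree, and parms = list(split = "information") chooses splits by information gain (entropy) instead of the default Gini index.

The last 3 arguments tell rpart to grow the tree as “fully” as it can:

  1. minsplit = 0: a node can be split no matter how few observations it contains (the default is 20)

  2. minbucket = 0: no minimum number of observations is required in a terminal node (leaf)

  3. cp = -1: a split is kept even if it does not improve the fit (the default complexity parameter is 0.01)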
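These three growth settings are not rpart()'s own arguments: rpart() passes them on to rpart.control(), so they can also be bundled into a control object. A sketch on the built-in iris data (a stand-in chosen only for illustration) compares a tree grown with the defaults to one grown as fully as possible:

```r
library(rpart)

# Bundle the "grow fully" settings into a control object
full_ctrl <- rpart.control(minsplit = 0, minbucket = 0, cp = -1)

set.seed(1)   # the cp table's cross-validation uses random folds
default_tree <- rpart(Species ~ ., data = iris, method = "class")

set.seed(1)
full_tree <- rpart(Species ~ ., data = iris, method = "class",
                   control = full_ctrl)

# Number of splits in the largest tree of each cp table
c(default = max(default_tree$cptable[, "nsplit"]),
  full    = max(full_tree$cptable[, "nsplit"]))
```

The fully grown tree keeps splitting until its leaves are as pure as the data allow, which is why it needs pruning afterward.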

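# rpart uses random folds for the cross-validation in the cp table,
# so fix the random number generator for reproducible results
RNGversion("4.1.0")
set.seed(5230)
edible_tree <- 
  rpart(
    formula = edible ~ .,                # predict edible from every other feature
    data = mushroom,
    method = "class",                    # classification tree
    parms = list(split = "information"), # split on information gain (entropy)
    minsplit = 0,                        # grow the tree as fully as possible
    minbucket = 0,
    cp = -1
  )  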
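To see what split = "information" measures, here is a minimal sketch of information gain: the entropy of the parent node minus the weighted entropy of its two children. The helpers entropy() and info_gain() and the six-mushroom data are hypothetical, made up for illustration. Like rpart, it scores a binary split (each observation goes left or right); it uses log base 2, which only rescales the values and doesn't change which split wins.

```r
# Hypothetical helper: entropy (in bits) of a vector of class labels
entropy <- function(y) {
  p <- table(y) / length(y)
  p <- p[p > 0]                      # treat 0 * log2(0) as 0
  -sum(p * log2(p))
}

# Hypothetical helper: information gain of a binary split
# goes_left is TRUE for observations sent to the left child
info_gain <- function(y, goes_left) {
  w_left <- mean(goes_left)          # share of observations going left
  child_entropy <- w_left * entropy(y[goes_left]) +
    (1 - w_left) * entropy(y[!goes_left])
  entropy(y) - child_entropy
}

# Made-up data: six mushrooms and their odor
edible <- c("edible", "edible", "edible", "poisonous", "poisonous", "poisonous")
odor   <- c("none",   "none",   "almond", "foul",      "foul",      "none")

gain_foul <- info_gain(edible, odor == "foul")   # foul vs. everything else
gain_none <- info_gain(edible, odor == "none")   # none vs. everything else
c(foul = gain_foul, none = gain_none)
```

Splitting off the foul mushrooms creates a pure left child, so it has the larger gain; a tree grown with the information criterion evaluates candidate splits of every feature this way and keeps the best one.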

# display the cp table
edible_tree |> 
  pluck("cptable") |> 
  data.frame()
##              CP nsplit   rel.error      xerror         xstd
## 1  0.9693564862      0 1.000000000 1.000000000 0.0115008931
## 2  0.0183861083      1 0.030643514 0.030643514 0.0027766205
## 3  0.0051072523      2 0.012257406 0.012257406 0.0017639698
## 4  0.0010214505      4 0.002042901 0.002042901 0.0007219188
## 5  0.0005107252      5 0.001021450 0.001021450 0.0005105995
## 6 -1.0000000000      7 0.000000000 0.001021450 0.0005105995

Step 4: Pruning the tree

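Because we set the algorithm to grow the tree fully, it is likely to overfit the data, so we need to prune it. The cp table displayed above (printcp() prints the same table along with a short summary) shows at which point the tree should be pruned.

The xerror column is the error of the tree estimated using cross-validation (the x stands for cross), and xstd is the standard error of that estimate.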

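Using the one-standard-error rule, we want the smallest tree (fewest splits, nsplit) whose xerror is less than the lowest xerror plus the xstd of that same row. Then we set the cp argument to be slightly larger than the CP value for that tree (but still below the CP value in the row above it).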

# Finding the xerror cutoff: the lowest xerror plus its standard error (xstd)
xcutoff <- 
  edible_tree |> 
  pluck("cptable") |> 
  data.frame() |> 
  slice_min(xerror) |>    # ties are all returned, so keep only the first row
  slice(1) |> 
  mutate(cut_off = xerror + xstd) |> 
  pull(cut_off)

# Now we find the rows with an xerror below the xcutoff
edible_tree |> 
  pluck("cptable") |> 
  data.frame() |> 
  filter(xerror < xcutoff)
##              CP nsplit  rel.error     xerror         xstd
## 5  0.0005107252      5 0.00102145 0.00102145 0.0005105995
## 6 -1.0000000000      7 0.00000000 0.00102145 0.0005105995
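The two steps above can be wrapped into one function. choose_cp_1se() is a hypothetical helper (not part of rpart); it assumes the cp table's rows are ordered by nsplit, as rpart prints them, and returns a cp halfway between the chosen tree's CP and the CP of the row above it, one way of landing "slightly larger" than the CP value.

```r
# Hypothetical helper (not part of rpart): pick a cp for prune() using the
# one-standard-error rule. Assumes the rows of cptable are ordered by nsplit.
choose_cp_1se <- function(cptable) {
  best   <- which.min(cptable[, "xerror"])          # first row with the lowest xerror
  cutoff <- cptable[best, "xerror"] + cptable[best, "xstd"]
  i      <- which(cptable[, "xerror"] < cutoff)[1]  # smallest tree under the cutoff
  if (i == 1) return(unname(cptable[1, "CP"]))      # root-only tree
  # Midpoint between this row's CP and the CP of the row above it
  unname((cptable[i, "CP"] + cptable[i - 1, "CP"]) / 2)
}

# Usage with the CP, nsplit, xerror, and xstd values printed above
cp_tab <- cbind(
  CP     = c(0.9693564862, 0.0183861083, 0.0051072523, 0.0010214505, 0.0005107252, -1),
  nsplit = c(0, 1, 2, 4, 5, 7),
  xerror = c(1, 0.030643514, 0.012257406, 0.002042901, 0.001021450, 0.001021450),
  xstd   = c(0.0115008931, 0.0027766205, 0.0017639698, 0.0007219188, 0.0005105995, 0.0005105995)
)
cp_choice <- choose_cp_1se(cp_tab)
cp_choice
```

On the real tree this would be prune(edible_tree, cp = choose_cp_1se(edible_tree$cptable)); any cp between about 0.000511 and 0.00102 should select the same 5-split tree as cp = 0.00052.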

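The smallest tree below the cutoff has 5 splits and a CP of about 0.000511, so we cut the tree at cp = 0.00052, slightly above that value. Instead of having to refit the entire tree, we can use the prune() function with 2 arguments:

  • the fully grown tree
  • the cp value to cut the tree at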
edible_pruned <- 
  prune(
    edible_tree,
    cp = 0.00052
  )

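# Now let's plot the pruned tree:
rpart.plot(
  x = edible_pruned,
  type = 5,     # show the split variable names in the interior nodes
  extra = 101   # show class counts and the percent of observations in each node
)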

Estimating the classification error

If you want to use resubstitution (predicting the same data used to build the tree; not recommended) to estimate the accuracy of the model, you can just calculate a confusion matrix with caret's confusionMatrix():

confusionMatrix(
  # predict() without new data returns predictions for the training data;
  # type = "class" returns the predicted class instead of class probabilities
  data = predict(edible_pruned, type = "class"),
  # the actual class of each mushroom
  reference = mushroom$edible
)
## Confusion Matrix and Statistics
## 
##            Reference
## Prediction  edible poisonous
##   edible      4208         4
##   poisonous      0      3912
##                                           
##                Accuracy : 0.9995          
##                  95% CI : (0.9987, 0.9999)
##     No Information Rate : 0.518           
##     P-Value [Acc > NIR] : <2e-16          
##                                           
##                   Kappa : 0.999           
##                                           
##  Mcnemar's Test P-Value : 0.1336          
##                                           
##             Sensitivity : 1.0000          
##             Specificity : 0.9990          
##          Pos Pred Value : 0.9991          
##          Neg Pred Value : 1.0000          
##              Prevalence : 0.5180          
##          Detection Rate : 0.5180          
##    Detection Prevalence : 0.5185          
##       Balanced Accuracy : 0.9995          
##                                           
##        'Positive' Class : edible          
## 

Using resubstitution, we estimate the accuracy to be 99.95%! That is much, much higher than the no information rate of 51.8%, the accuracy we would get by always predicting the most common class (edible).
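Both numbers can be checked by hand from the counts in the confusion matrix: accuracy is the share of mushrooms on the diagonal, and the no information rate is the share in the most common actual class. A short sketch with the counts copied from the output above:

```r
# Counts copied from the confusion matrix above
# (rows = prediction, columns = reference; matrix() fills column by column)
cm <- matrix(
  c(4208, 0,       # reference = edible
    4,    3912),   # reference = poisonous
  nrow = 2,
  dimnames = list(
    Prediction = c("edible", "poisonous"),
    Reference  = c("edible", "poisonous")
  )
)

accuracy <- sum(diag(cm)) / sum(cm)      # correctly classified / total
nir      <- max(colSums(cm)) / sum(cm)   # always predict the most common class
round(c(accuracy = accuracy, nir = nir), 4)
```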

But resubstitution evaluates the tree on the same data used to build it, so it tends to underestimate the error rate (overestimate the accuracy) compared to testing the tree on a new data set. So what can we do to get a less biased estimate of the true error?
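One answer is cross-validation: fit the tree on part of the data and count its errors on the part it never saw, rotating through all the parts. A minimal sketch of 10-fold cross-validation done by hand, using the built-in iris data as a stand-in (the seed, the number of folds, and the data set are choices made for illustration, not the chapter's):

```r
library(rpart)

set.seed(2024)                                     # the fold assignment is random
k <- 10
n <- nrow(iris)
fold <- sample(rep(1:k, length.out = n))           # each row gets one of k fold labels

n_wrong <- 0
for (j in 1:k) {
  train <- iris[fold != j, ]
  test  <- iris[fold == j, ]
  fit   <- rpart(Species ~ ., data = train, method = "class")
  pred  <- predict(fit, newdata = test, type = "class")
  n_wrong <- n_wrong + sum(pred != test$Species)   # errors on rows this fit never saw
}
cv_error <- n_wrong / n

# Resubstitution error of a tree fit to all the data, for comparison
full_fit    <- rpart(Species ~ ., data = iris, method = "class")
resub_error <- mean(predict(full_fit, type = "class") != iris$Species)

c(cv = cv_error, resubstitution = resub_error)
```

Every row is predicted exactly once by a tree that did not see it during fitting, which is what keeps the estimate from being as optimistic as resubstitution.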

The xerror column gives us the error rate estimated by k-fold cross-validation (rpart uses 10 folds by default, controlled by the xval argument of rpart.control()). It is relative to the no information error rate:

\[\textrm{xerror} = \frac{\textrm{CV error}}{1 - \textrm{NIR}}\]

So to estimate the error rate itself using k-fold cross-validation, we multiply xerror by the no information error rate:

\[\textrm{CV error} = (1 - \textrm{NIR})\times\textrm{xerror}\]
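The no information error rate can also be pulled from the fitted tree instead of the confusion matrix. A sketch on the built-in iris data as a stand-in; it assumes that, for a classification tree with the default loss, the dev entry of the root node in fit$frame is the number of observations misclassified at the root, so dividing by the root's n gives 1 - NIR.

```r
library(rpart)

set.seed(2024)                               # xerror depends on the random CV folds
fit <- rpart(Species ~ ., data = iris, method = "class")

# Assumed: the root's dev is the number misclassified at the root,
# so this ratio is the no information error rate (1 - NIR)
root_error <- fit$frame$dev[1] / fit$frame$n[1]

cp_df <- data.frame(fit$cptable)
cp_df$cv_error <- cp_df$xerror * root_error  # CV error = (1 - NIR) * xerror
cp_df
```

With three equally sized species, the root error is 100/150, so each row's cv_error is two thirds of its xerror.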

We find the xerror from the CP table:

edible_pruned |> 
  pluck("cptable") |> 
  data.frame() |> 
  # Getting the last row
  slice_tail(n = 1)
##        CP nsplit  rel.error     xerror         xstd
## 5 0.00052      5 0.00102145 0.00102145 0.0005105995

The xerror is 0.00102. The no information error rate is 1 - 0.518 = 0.482. The estimated error rate is:

\[\textrm{CV error} = 0.482 \times 0.00102 \approx 0.00049\]

This is essentially the same as our resubstitution error rate of 1 - 0.9995 = 0.0005.
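As a check with unrounded values, assume the root node error is 3916/8124 (the 3916 poisonous mushrooms misclassified when every mushroom is called edible). Multiplying xerror by 3916 then suggests that cross-validation misclassified about 4 mushrooms, the same count as resubstitution:

```r
# Values copied from the pruned tree's cp table and the class counts above
xerror      <- 0.00102145   # relative cross-validated error of the pruned tree
n           <- 8124         # number of mushrooms
root_errors <- 3916         # errors at the root (every poisonous mushroom called edible)

cv_error    <- xerror * root_errors / n   # (1 - NIR) * xerror
cv_wrong    <- xerror * root_errors       # implied number of CV misclassifications
resub_error <- 4 / n                      # 4 mushrooms misclassified by resubstitution

round(c(cv_error = cv_error, cv_wrong = cv_wrong, resub_error = resub_error), 6)
```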

Let’s compare it to the full tree:

edible_tree |> 
  pluck("cptable") |> 
  data.frame() |> 
  # Getting the last row
  slice_tail(n = 1)
##   CP nsplit rel.error     xerror         xstd
## 6 -1      7         0 0.00102145 0.0005105995

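Since the full tree and the pruned tree have the same xerror, we end up with the same cross-validated estimate of the error rate, even though the full tree uses 2 more splits (and has a resubstitution error of 0).

The extra complexity bought no improvement here. In general, adding complexity beyond a certain point tends to increase the cross-validated error rate, because the extra splits fit noise in the training data rather than real patterns!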