knitr::opts_chunk$set(echo = TRUE,
fig.align = "center")
# Load them here:
pacman::p_load(rpart, rpart.plot, tidyverse, caret)
# Changing the default theme
theme_set(theme_bw())
Can we determine if a mushroom is safe to eat based on the other features?
The mushroom_full.csv file has 23 variables on 8124 different mushrooms. The variable we want to classify is type, which records whether the mushroom is edible or poisonous.
# There are two data sets:
# The first (mushroom.csv) contains just a few variables (4 features)
# The second (mushroom_full.csv) contains all 23 variables about the 8124 mushrooms
# To use the smaller file instead, remove the # at the beginning of the line below:
# mushroom <- read.csv("mushroom.csv", stringsAsFactors = TRUE)
# Here we read the full file and rename the outcome column from type to edible:
mushroom <-
  read.csv("mushroom_full.csv",
           stringsAsFactors = TRUE) |>
  rename(edible = type)
skimr::skim(mushroom)
Name | mushroom |
Number of rows | 8124 |
Number of columns | 23 |
_______________________ | |
Column type frequency: | |
factor | 23 |
________________________ | |
Group variables | None |
Variable type: factor
skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
---|---|---|---|---|---|
edible | 0 | 1 | FALSE | 2 | edi: 4208, poi: 3916 |
cap_shape | 0 | 1 | FALSE | 6 | con: 3656, fla: 3152, kno: 828, bel: 452 |
cap_surface | 0 | 1 | FALSE | 4 | sca: 3244, smo: 2556, fib: 2320, gro: 4 |
cap_color | 0 | 1 | FALSE | 10 | bro: 2284, gra: 1840, red: 1500, yel: 1072 |
bruises | 0 | 1 | FALSE | 2 | no: 4748, yes: 3376 |
odor | 0 | 1 | FALSE | 9 | non: 3528, fou: 2160, fis: 576, spi: 576 |
gill_attachment | 0 | 1 | FALSE | 2 | fre: 7914, att: 210 |
gill_spacing | 0 | 1 | FALSE | 2 | clo: 6812, cro: 1312 |
gill_size | 0 | 1 | FALSE | 2 | bro: 5612, nar: 2512 |
gill_color | 0 | 1 | FALSE | 12 | buf: 1728, pin: 1492, whi: 1202, bro: 1048 |
stalk_shape | 0 | 1 | FALSE | 2 | tap: 4608, enl: 3516 |
stalk_root | 0 | 1 | FALSE | 5 | bul: 3776, mis: 2480, equ: 1120, clu: 556 |
stalk_surface_above_ring | 0 | 1 | FALSE | 4 | smo: 5176, sil: 2372, fib: 552, sca: 24 |
stalk_surface_below_ring | 0 | 1 | FALSE | 4 | smo: 4936, sil: 2304, fib: 600, sca: 284 |
stalk_color_above_ring | 0 | 1 | FALSE | 9 | whi: 4464, pin: 1872, gra: 576, bro: 448 |
stalk_color_below_ring | 0 | 1 | FALSE | 9 | whi: 4384, pin: 1872, gra: 576, bro: 512 |
veil_type | 0 | 1 | FALSE | 1 | par: 8124 |
veil_color | 0 | 1 | FALSE | 4 | whi: 7924, bro: 96, ora: 96, yel: 8 |
ring_number | 0 | 1 | FALSE | 3 | one: 7488, two: 600, non: 36 |
ring_type | 0 | 1 | FALSE | 5 | pen: 3968, eva: 2776, lar: 1296, fla: 48 |
spore_print_color | 0 | 1 | FALSE | 9 | whi: 2388, bro: 1968, bla: 1872, cho: 1632 |
population | 0 | 1 | FALSE | 6 | sev: 4040, sol: 1712, sca: 1248, num: 400 |
habitat | 0 | 1 | FALSE | 7 | woo: 3148, gra: 2148, pat: 1144, lea: 832 |
Since all the features are factors, we should check to see how many groups each feature has:
skimr::skim(mushroom) |>
ggplot(
mapping = aes(
x = reorder(skim_variable, factor.n_unique),
y = factor.n_unique
)
) +
geom_col(
color = "black",
fill = "steelblue"
) +
theme(
axis.text.x = element_text(angle = 90, vjust = 0.5, hjust=1)
) +
labs(
x = "Mushroom Feature",
y = "Number of groups"
) +
scale_y_continuous(
breaks = seq(0, 12, by = 2),
expand = c(0, 0, 0.05, 0)
)
When using the full data set, we should remove veil_type because it has only one level (every mushroom is "par"). A column that never changes is a constant, not a variable, so it cannot help classify the mushrooms.
mushroom <-
mushroom %>%
select(-veil_type)
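The same check can be done programmatically instead of reading it off the skim output. The sketch below is a minimal base R illustration on a small made-up data frame (`toy` and its values are hypothetical, not the mushroom data): it counts the observed levels of each factor and keeps only the columns with at least two.

```r
# Hypothetical toy data standing in for the mushroom data frame
toy <- data.frame(
  edible    = factor(c("edible", "poisonous", "edible", "poisonous")),
  odor      = factor(c("none", "foul", "almond", "foul")),
  veil_type = factor(c("partial", "partial", "partial", "partial"))
)

# Number of observed levels in each factor column
n_levels <- vapply(toy, function(col) nlevels(droplevels(col)), integer(1))
print(n_levels)

# Keep only the columns that actually vary (two or more levels)
toy_clean <- toy[, n_levels >= 2, drop = FALSE]
print(names(toy_clean))
```

Applied to the real data, the same idea should flag veil_type as the only single-level column, which matches the bar chart above.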
Build the classification tree using rpart():
formula = outcome ~ predictors
data = data set
method = "class"
parms = list(split = "information")
The last 3 arguments are used to have rpart grow the tree as “fully” as it can:
minsplit = 0
minbucket = 0
cp = -1
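For intuition about the split = "information" setting, the sketch below computes the entropy of the root node from the class counts reported in the skim output (4208 edible, 3916 poisonous) and compares it with the Gini impurity, which is rpart's default criterion for classification. The log base and scaling are choices made for this illustration; rpart's internal scaling may differ, but the ranking of candidate splits does not depend on it.

```r
# Class counts at the root node, taken from the skim output above
counts <- c(edible = 4208, poisonous = 3916)
p <- counts / sum(counts)

# Entropy (the impurity behind split = "information"), measured in bits
entropy <- -sum(p * log2(p))

# Gini impurity (rpart's default split criterion), for comparison
gini <- 1 - sum(p^2)

print(round(c(entropy = entropy, gini = gini), 4))
```

Both measures are near their two-class maximum (1 bit and 0.5) because the classes are almost balanced. A good split is one whose child nodes have much lower impurity than this.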
RNGversion("4.1.0")
set.seed(5230)
edible_tree <-
rpart(
formula = edible ~ .,
data = mushroom,
method = "class",
parms = list(split = "information"),
minsplit = 0,
minbucket = 0,
cp = -1
)
# display the cp table
edible_tree |>
pluck("cptable") |>
data.frame()
## CP nsplit rel.error xerror xstd
## 1 0.9693564862 0 1.000000000 1.000000000 0.0115008931
## 2 0.0183861083 1 0.030643514 0.030643514 0.0027766205
## 3 0.0051072523 2 0.012257406 0.012257406 0.0017639698
## 4 0.0010214505 4 0.002042901 0.002042901 0.0007219188
## 5 0.0005107252 5 0.001021450 0.001021450 0.0005105995
## 6 -1.0000000000 7 0.000000000 0.001021450 0.0005105995
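Since the algorithm is set to fully grow the tree, we need to prune it to prevent overfitting. The CP table, stored in the cptable element of the fitted tree and also printed by printcp(), shows at which point the tree should be pruned.
The xerror column is the cross-validated error of each tree (the x stands for cross-validation), and xstd is its standard error.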
We want the smallest tree (fewest splits, nsplit) whose xerror is below the lowest xerror plus its xstd. This is often called the one-standard-error rule. Then we set the cp argument to be slightly larger than the CP value for the corresponding tree.
# Finding the xerror cutoff:
xcutoff <-
  edible_tree |>
  pluck("cptable") |>
  data.frame() |>
  slice_min(xerror, with_ties = FALSE) |>
  mutate(cut_off = xerror + xstd) |>
  pull(cut_off)
# Now we find the rows with an xerror below the xcutoff
edible_tree$cptable |>
data.frame() |>
filter(xerror < xcutoff)
## CP nsplit rel.error xerror xstd
## 5 0.0005107252 5 0.00102145 0.00102145 0.0005105995
## 6 -1.0000000000 7 0.00000000 0.00102145 0.0005105995
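The selection can also be followed by hand. The sketch below is a base R illustration that re-enters the CP table values printed above (the data frame `cp_table` is built here only for the example), computes the cutoff, and picks the row with the fewest splits below it.

```r
# CP table values copied from the output above
cp_table <- data.frame(
  CP     = c(0.9693564862, 0.0183861083, 0.0051072523,
             0.0010214505, 0.0005107252, -1),
  nsplit = c(0, 1, 2, 4, 5, 7),
  xerror = c(1, 0.030643514, 0.012257406,
             0.002042901, 0.001021450, 0.001021450),
  xstd   = c(0.0115008931, 0.0027766205, 0.0017639698,
             0.0007219188, 0.0005105995, 0.0005105995)
)

# Cutoff: lowest xerror plus its standard error
best   <- which.min(cp_table$xerror)
cutoff <- cp_table$xerror[best] + cp_table$xstd[best]

# Smallest tree (fewest splits) whose xerror is below the cutoff
candidates <- cp_table[cp_table$xerror < cutoff, ]
chosen     <- candidates[which.min(candidates$nsplit), ]

print(cutoff)
print(chosen)
```

The cutoff works out to about 0.00153, and the chosen row is the tree with 5 splits and CP of about 0.00051, so any cp slightly above that value (and below the CP of the 4-split row) selects it.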
Instead of having to refit the entire tree, we can use the prune() function with 2 arguments: the fully grown tree and the cp value to prune at.
edible_pruned <-
prune(
edible_tree,
cp = 0.00052
)
# Now let's print the pruned tree:
rpart.plot(
x = edible_pruned,
type = 5,
extra = 101
)
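To see how the cp passed to prune() maps onto rows of the CP table, here is a minimal sketch on the kyphosis data that ships with rpart (not the mushroom data). The growth settings minsplit = 2, minbucket = 1, and cp = 0 are this example's own choices for growing a large tree on numeric predictors, and the seed is arbitrary.

```r
library(rpart)

set.seed(1)  # the cross-validation folds inside rpart() are random

# Grow a deliberately large tree on rpart's bundled kyphosis data
big_tree <- rpart(
  Kyphosis ~ Age + Number + Start,
  data = kyphosis,
  method = "class",
  control = rpart.control(minsplit = 2, minbucket = 1, cp = 0)
)
cp_tab <- big_tree$cptable

# Pick a cp halfway between the CP values of rows 1 and 2 ...
cp_mid     <- mean(cp_tab[1:2, "CP"])
small_tree <- prune(big_tree, cp = cp_mid)

# ... so the pruned tree should have the number of splits listed in row 2
target_splits <- cp_tab[2, "nsplit"]
n_leaves      <- sum(small_tree$frame$var == "<leaf>")
print(c(target_splits = target_splits, leaves = n_leaves))
```

A binary tree with s splits has s + 1 leaves, so the two printed numbers should differ by exactly one. This is why a cp slightly larger than a row's CP value recovers the tree on that row.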
If you want to use resubstitution (not recommended) to estimate the accuracy of the model, you can just calculate a confusion matrix:
confusionMatrix(
# Need to use the predict function to get the predicted class from the tree
data = predict(edible_pruned, type = "class"), # type = "class" returns the predicted class labels
# Need to give the actual group
reference = mushroom$edible
)
## Confusion Matrix and Statistics
##
## Reference
## Prediction edible poisonous
## edible 4208 4
## poisonous 0 3912
##
## Accuracy : 0.9995
## 95% CI : (0.9987, 0.9999)
## No Information Rate : 0.518
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.999
##
## Mcnemar's Test P-Value : 0.1336
##
## Sensitivity : 1.0000
## Specificity : 0.9990
## Pos Pred Value : 0.9991
## Neg Pred Value : 1.0000
## Prevalence : 0.5180
## Detection Rate : 0.5180
## Detection Prevalence : 0.5185
## Balanced Accuracy : 0.9995
##
## 'Positive' Class : edible
##
Using resubstitution, we estimate the accuracy to be 99.95%! Compared to the no information rate of 51.8%, that is much, much higher!
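The headline statistics can be recomputed directly from the four counts in the confusion matrix. The sketch below is a base R illustration that re-enters those counts by hand (the matrix `cm` is built only for this example) rather than calling confusionMatrix().

```r
# Counts copied from the confusion matrix above
# (rows = predicted class, columns = actual class)
cm <- matrix(
  c(4208, 0, 4, 3912),
  nrow = 2,
  dimnames = list(
    prediction = c("edible", "poisonous"),
    reference  = c("edible", "poisonous")
  )
)

accuracy    <- sum(diag(cm)) / sum(cm)
sensitivity <- cm["edible", "edible"] / sum(cm[, "edible"])
specificity <- cm["poisonous", "poisonous"] / sum(cm[, "poisonous"])
nir         <- max(colSums(cm)) / sum(cm)   # always predict the larger class

print(round(c(accuracy = accuracy, sensitivity = sensitivity,
              specificity = specificity, nir = nir), 4))
```

All 4 errors are poisonous mushrooms predicted to be edible, which is the more costly kind of mistake for this problem even though the overall accuracy is very high.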
But resubstitution gives optimistic estimates: because the tree is tested on the same data it was trained on, the estimated error is usually lower than the error we would see on a new data set. So what can we do to get a less biased estimate of the true error?
The xerror column gives us the relative error rate estimated by k-fold cross-validation. It is relative to the no information error rate, 1 - NIR:
\[\textrm{xerror} = \frac{\textrm{CV error}}{1 - \textrm{NIR}}\]
So if we want to estimate using k-fold cross-validation what the error rate is, we can find it by:
\[\textrm{CV error} = (1 - \textrm{NIR})\times(\text{xerror})\]
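To make the relationship concrete, the sketch below runs a hand-rolled 10-fold cross-validation on rpart's bundled kyphosis data (not the mushroom data) with rpart's default settings. The seed and fold assignment are choices made for this example. rpart computes xerror internally for every row of the CP table, so this is only an analogue of one xerror value, not a reproduction of it.

```r
library(rpart)

set.seed(5230)
k <- 10
n <- nrow(kyphosis)

# Randomly assign each row to one of k folds
fold <- sample(rep(seq_len(k), length.out = n))

# For each fold: fit on the other folds, then predict the held-out rows
wrong <- logical(n)
for (i in seq_len(k)) {
  held_out <- fold == i
  fit <- rpart(Kyphosis ~ Age + Number + Start,
               data = kyphosis[!held_out, ],
               method = "class")
  pred <- predict(fit, newdata = kyphosis[held_out, ], type = "class")
  wrong[held_out] <- as.character(pred) != as.character(kyphosis$Kyphosis[held_out])
}

cv_error   <- mean(wrong)                              # CV error
root_error <- 1 - max(table(kyphosis$Kyphosis)) / n    # 1 - NIR
rel_error  <- cv_error / root_error                    # analogous to xerror

print(c(cv_error = cv_error, root_error = root_error, rel_error = rel_error))
```

Multiplying rel_error by root_error recovers cv_error, which is the same rearrangement used in the second formula above.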
We find the xerror from the CP table:
edible_pruned |>
pluck("cptable") |>
data.frame() |>
# Getting the last row
slice_tail(n = 1)
## CP nsplit rel.error xerror xstd
## 5 0.00052 5 0.00102145 0.00102145 0.0005105995
The xerror is 0.001. The no information error rate is 1 - 0.518 = 0.482. The estimated error rate is:
\[\textrm{CV error} = 0.482 \times (0.001) = 0.00048\]
This is essentially the same as the resubstitution error rate of 0.0005.
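The calculation above uses rounded inputs. The sketch below repeats it in base R with the unrounded xerror from the CP table and the class counts from the skim output, and also expresses the result as a number of mushrooms. The variable names are introduced here only for illustration.

```r
n_edible    <- 4208
n_poisonous <- 3916
n_total     <- n_edible + n_poisonous

nir        <- max(n_edible, n_poisonous) / n_total   # no information rate
root_error <- 1 - nir                                # no information error rate
xerror     <- 0.00102145   # last row of the pruned tree's CP table

cv_error <- root_error * xerror
print(cv_error)             # estimated misclassification rate
print(cv_error * n_total)   # the same estimate as a count of mushrooms
```

With unrounded inputs the estimate is about 0.00049, which corresponds to roughly 4 of the 8124 mushrooms, the same count of errors seen in the resubstitution confusion matrix.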
Let’s compare it to the full tree:
edible_tree |>
pluck("cptable") |>
data.frame() |>
# Getting the last row
slice_tail(n = 1)
## CP nsplit rel.error xerror xstd
## 6 -1 7 0 0.00102145 0.0005105995
Since the full tree and the pruned tree have the same xerror, we end up with the same estimated error rate using cross-validation.
In general, though, increasing the complexity of a tree beyond what the data support will often lead to higher cross-validated error rates!