Predicting class membership for the TidyTuesday Datasaurus Dozen
The Datasaurus Dozen dataset is a collection of 13 sets of x/y data that have very similar summary statistics but look very different when plotted. Our modeling goal is to predict which member of the ‘dozen’ each point belongs to.
Let’s start by reading in the data from the datasauRus package.
library(tidyverse)
## -- Attaching packages --------
## v ggplot2 3.3.2 v purrr 0.3.4
## v tibble 3.0.3 v dplyr 1.0.0
## v tidyr 1.1.0 v stringr 1.4.0
## v readr 1.3.1 v forcats 0.5.0
## -- Conflicts -----------------
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
library(datasauRus)
## Warning: package 'datasauRus' was built under R version 4.0.3
datasaurus_dozen
## # A tibble: 1,846 x 3
## dataset x y
## <chr> <dbl> <dbl>
## 1 dino 55.4 97.2
## 2 dino 51.5 96.0
## 3 dino 46.2 94.5
## 4 dino 42.8 91.4
## 5 dino 40.8 88.3
## 6 dino 38.7 84.9
## 7 dino 35.6 79.9
## 8 dino 33.1 77.6
## 9 dino 29.0 74.5
## 10 dino 26.2 71.4
## # ... with 1,836 more rows
theme_set(theme_light())
These datasets are very different from each other!
datasaurus_dozen %>%
ggplot(aes(x, y, color = dataset)) +
geom_point(show.legend = FALSE) +
facet_wrap(~dataset, ncol = 5)
But their summary statistics are so similar:
datasaurus_dozen %>%
group_by(dataset) %>%
summarise(across(c(x, y), list(mean = mean, sd = sd)),
x_y_cor = cor(x, y))
## `summarise()` ungrouping output (override with `.groups` argument)
## # A tibble: 13 x 6
## dataset x_mean x_sd y_mean y_sd x_y_cor
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 away 54.3 16.8 47.8 26.9 -0.0641
## 2 bullseye 54.3 16.8 47.8 26.9 -0.0686
## 3 circle 54.3 16.8 47.8 26.9 -0.0683
## 4 dino 54.3 16.8 47.8 26.9 -0.0645
## 5 dots 54.3 16.8 47.8 26.9 -0.0603
## 6 h_lines 54.3 16.8 47.8 26.9 -0.0617
## 7 high_lines 54.3 16.8 47.8 26.9 -0.0685
## 8 slant_down 54.3 16.8 47.8 26.9 -0.0690
## 9 slant_up 54.3 16.8 47.8 26.9 -0.0686
## 10 star 54.3 16.8 47.8 26.9 -0.0630
## 11 v_lines 54.3 16.8 47.8 26.9 -0.0694
## 12 wide_lines 54.3 16.8 47.8 26.9 -0.0666
## 13 x_shape 54.3 16.8 47.8 26.9 -0.0656
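The Datasaurus Dozen was built as a modern take on Anscombe’s quartet, which ships with base R as the `anscombe` data frame (columns `x1` to `x4` and `y1` to `y4`). As a side note, here is a minimal base R sketch of the same summary check on the quartet. It assumes only the built-in `datasets` package.

```r
# Anscombe's quartet: four x/y pairs stored as columns x1..x4 and y1..y4
quartet_stats <- sapply(1:4, function(i) {
  x <- anscombe[[paste0("x", i)]]
  y <- anscombe[[paste0("y", i)]]
  c(x_mean = mean(x), x_sd = sd(x),
    y_mean = mean(y), y_sd = sd(y),
    x_y_cor = cor(x, y))
})
colnames(quartet_stats) <- paste0("set", 1:4)

# One column per set; the rows are nearly identical across sets
round(quartet_stats, 2)
```

Just like the dozen, all four sets share the same means and standard deviations and almost the same correlation, even though their scatterplots look nothing alike.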
Let’s explore whether we can use modeling to predict which dataset a point belongs to. This is not a large dataset compared to the number of classes (13), so this isn’t a tutorial that shows best practices for a predictive modeling workflow overall. It does, however, demonstrate how to evaluate a multiclass model, as well as a bit about how random forest models work.
datasaurus_dozen %>%
count(dataset)
## # A tibble: 13 x 2
## dataset n
## <chr> <int>
## 1 away 142
## 2 bullseye 142
## 3 circle 142
## 4 dino 142
## 5 dots 142
## 6 h_lines 142
## 7 high_lines 142
## 8 slant_down 142
## 9 slant_up 142
## 10 star 142
## 11 v_lines 142
## 12 wide_lines 142
## 13 x_shape 142
Let’s start out by creating resamples of the Datasaurus Dozen. Notice that we aren’t splitting into testing and training sets, so we won’t have an unbiased estimate of performance on new data. Instead, we will use these resamples to understand the dataset and multiclass model better.
library(tidymodels)
## -- Attaching packages --------
## v broom 0.7.0 v recipes 0.1.13
## v dials 0.0.8 v rsample 0.0.7
## v infer 0.5.3 v tune 0.1.1
## v modeldata 0.0.2 v workflows 0.1.2
## v parsnip 0.1.2 v yardstick 0.0.7
## -- Conflicts -----------------
## x scales::discard() masks purrr::discard()
## x dplyr::filter() masks stats::filter()
## x recipes::fixed() masks stringr::fixed()
## x dplyr::lag() masks stats::lag()
## x yardstick::spec() masks readr::spec()
## x recipes::step() masks stats::step()
set.seed(123)
dino_folds <- datasaurus_dozen %>%
mutate(dataset = factor(dataset)) %>%
bootstraps()
dino_folds
## # Bootstrap sampling
## # A tibble: 25 x 2
## splits id
## <list> <chr>
## 1 <split [1.8K/672]> Bootstrap01
## 2 <split [1.8K/689]> Bootstrap02
## 3 <split [1.8K/680]> Bootstrap03
## 4 <split [1.8K/674]> Bootstrap04
## 5 <split [1.8K/692]> Bootstrap05
## 6 <split [1.8K/689]> Bootstrap06
## 7 <split [1.8K/689]> Bootstrap07
## 8 <split [1.8K/695]> Bootstrap08
## 9 <split [1.8K/664]> Bootstrap09
## 10 <split [1.8K/671]> Bootstrap10
## # ... with 15 more rows
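Each bootstrap resample draws 1,846 rows with replacement for the analysis set. The rows that are never drawn (the out-of-bag rows) become the assessment set, which is why the assessment sizes above hover around 680. The chance that a given row is never drawn is (1 - 1/n)^n, and that approaches 1/e ≈ 0.368 as n grows. Here is a minimal base R sketch of that arithmetic and of one simulated draw. It does not use rsample, so this seed will not reproduce the splits shown above.

```r
n <- 1846  # rows in datasaurus_dozen

# Probability that a given row is never picked in n draws with replacement
p_oob <- (1 - 1 / n)^n
expected_oob <- n * p_oob  # expected size of the assessment (out-of-bag) set

# One simulated bootstrap draw of row indices (not rsample's algorithm or seed)
set.seed(123)
analysis_idx <- sample.int(n, size = n, replace = TRUE)
oob_idx <- setdiff(seq_len(n), analysis_idx)

c(p_oob = p_oob, expected_oob = expected_oob, simulated_oob = length(oob_idx))
```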
Let’s create a random forest model and set up a model workflow with the model and a formula preprocessor. We are predicting the dataset class (dino vs. circle vs. bullseye vs. …) from x and y. A random forest model can often do a good job of learning complex interactions among predictors.
rf_spec <- rand_forest(trees = 1000) %>%
set_mode('classification') %>%
set_engine('ranger')
dino_wf <- workflow() %>%
add_formula(dataset ~ x + y) %>%
add_model(rf_spec)
dino_wf
## == Workflow ==================
## Preprocessor: Formula
## Model: rand_forest()
##
## -- Preprocessor --------------
## dataset ~ x + y
##
## -- Model ---------------------
## Random Forest Model Specification (classification)
##
## Main Arguments:
## trees = 1000
##
## Computational engine: ranger
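A random forest grows many decision trees. Each tree is fit on a bootstrap sample of the rows and considers a random subset of the predictors at each split, and the forest then combines the trees’ predictions. The sketch below shows only that combining step, in its simplest form: majority voting over per-tree class predictions, using made-up votes. Treat it as a simplified picture. As used here, the ranger engine returns class probabilities (the .pred_away through .pred_x_shape columns in the predictions) by averaging over trees, rather than counting hard votes.

```r
# Hypothetical per-tree predictions: rows are observations, columns are 5 trees
classes <- c("circle", "dino", "star")
votes <- matrix(
  c("dino",   "dino",   "star",   "dino",   "circle",
    "circle", "circle", "circle", "star",   "circle",
    "star",   "dino",   "star",   "circle", "star"),
  nrow = 3, byrow = TRUE
)

# Fraction of trees voting for each class, one row per observation
vote_share <- t(apply(votes, 1, function(v) {
  tabulate(match(v, classes), nbins = length(classes)) / length(v)
}))
colnames(vote_share) <- classes

# Predicted class = class with the largest vote share
pred_class <- classes[max.col(vote_share, ties.method = "first")]
vote_share
pred_class
```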
Let’s fit the random forest model to the bootstrap resamples. Registering a parallel backend with doParallel first lets the resamples be fit in parallel.
doParallel::registerDoParallel()
dino_rs <- fit_resamples(
dino_wf,
resamples = dino_folds,
control = control_resamples(save_pred = TRUE)
)
dino_rs
## # Resampling results
## # Bootstrap sampling
## # A tibble: 25 x 5
## splits id .metrics .notes .predictions
## <list> <chr> <list> <list> <list>
## 1 <split [1.8K/672~ Bootstrap01 <tibble [2 x ~ <tibble [0 x ~ <tibble [672 x 1~
## 2 <split [1.8K/689~ Bootstrap02 <tibble [2 x ~ <tibble [0 x ~ <tibble [689 x 1~
## 3 <split [1.8K/680~ Bootstrap03 <tibble [2 x ~ <tibble [0 x ~ <tibble [680 x 1~
## 4 <split [1.8K/674~ Bootstrap04 <tibble [2 x ~ <tibble [0 x ~ <tibble [674 x 1~
## 5 <split [1.8K/692~ Bootstrap05 <tibble [2 x ~ <tibble [0 x ~ <tibble [692 x 1~
## 6 <split [1.8K/689~ Bootstrap06 <tibble [2 x ~ <tibble [0 x ~ <tibble [689 x 1~
## 7 <split [1.8K/689~ Bootstrap07 <tibble [2 x ~ <tibble [0 x ~ <tibble [689 x 1~
## 8 <split [1.8K/695~ Bootstrap08 <tibble [2 x ~ <tibble [0 x ~ <tibble [695 x 1~
## 9 <split [1.8K/664~ Bootstrap09 <tibble [2 x ~ <tibble [0 x ~ <tibble [664 x 1~
## 10 <split [1.8K/671~ Bootstrap10 <tibble [2 x ~ <tibble [0 x ~ <tibble [671 x 1~
## # ... with 15 more rows
We did it!
How did these models do overall?
collect_metrics(dino_rs)
## # A tibble: 2 x 5
## .metric .estimator mean n std_err
## <chr> <chr> <dbl> <int> <dbl>
## 1 accuracy multiclass 0.449 25 0.00337
## 2 roc_auc hand_till 0.846 25 0.00128
The accuracy is not great; a multiclass problem like this, especially one with so many classes, is harder than a binary classification problem.
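For context, there are 13 equally sized classes with 142 points each, so a model that guessed at random or always predicted the same class would be right only about 1/13 of the time. Here is a quick base R check of how the mean resampled accuracy reported above compares to that baseline. The 0.449 is copied by hand from the collect_metrics() output.

```r
n_classes <- 13
chance_accuracy <- 1 / n_classes  # balanced classes: expected accuracy of a blind guess
rf_accuracy <- 0.449              # mean accuracy from collect_metrics() above

c(chance = round(chance_accuracy, 3),
  ratio = round(rf_accuracy / chance_accuracy, 1))
```

So even this "not great" accuracy is close to six times better than chance.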
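Since we saved the predictions with save_pred = TRUE, we can compute other performance metrics. Notice that, by default, the positive predictive value is macro-averaged for multiclass problems. It is computed for each class and then averaged with equal weight, unlike accuracy, which is computed over all predictions at once.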
dino_rs %>%
collect_predictions() %>%
group_by(id) %>%
ppv(dataset, .pred_class)
## # A tibble: 25 x 4
## id .metric .estimator .estimate
## <chr> <chr> <chr> <dbl>
## 1 Bootstrap01 ppv macro 0.428
## 2 Bootstrap02 ppv macro 0.422
## 3 Bootstrap03 ppv macro 0.442
## 4 Bootstrap04 ppv macro 0.419
## 5 Bootstrap05 ppv macro 0.448
## 6 Bootstrap06 ppv macro 0.404
## 7 Bootstrap07 ppv macro 0.423
## 8 Bootstrap08 ppv macro 0.423
## 9 Bootstrap09 ppv macro 0.395
## 10 Bootstrap10 ppv macro 0.428
## # ... with 15 more rows
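Macro-averaging means computing PPV separately for each class, then taking the unweighted mean across classes. For each class, that class is treated as the “positive” one and all other classes as negative. Below is a minimal base R sketch on a small made-up confusion matrix, with rows as predictions and columns as truth, matching conf_mat()’s layout. Note that macro_ppv() is a hypothetical helper, not a yardstick function.

```r
# Made-up 3-class confusion matrix: rows = predicted, columns = truth
cm <- matrix(
  c(8, 2, 0,
    1, 5, 3,
    1, 3, 7),
  nrow = 3, byrow = TRUE,
  dimnames = list(Prediction = c("a", "b", "c"), Truth = c("a", "b", "c"))
)

macro_ppv <- function(cm) {
  # PPV per class = correct predictions of that class / all predictions of that class
  # (a class that is never predicted would give NaN here)
  per_class <- diag(cm) / rowSums(cm)
  mean(per_class)  # unweighted mean across classes
}

# Micro-averaging pools the counts instead; for single-label multiclass
# data this equals overall accuracy
micro_ppv <- sum(diag(cm)) / sum(cm)

diag(cm) / rowSums(cm)
macro_ppv(cm)
micro_ppv
```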
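Next, let’s compute ROC curves for each class. For multiclass data, roc_curve() takes a one-vs-all approach, treating each class in turn as the event and all the other classes as non-events.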
dino_rs %>% collect_predictions() %>%
group_by(id) %>%
roc_curve(dataset, .pred_away:.pred_x_shape) %>%
ggplot(aes(1 - specificity, sensitivity, color = id)) +
geom_abline(lty = 2, color = "gray80", size = 1.5) +
geom_path(show.legend = FALSE, alpha = 0.6, size = 1.2) +
facet_wrap(~.level, ncol = 5) +
coord_equal()
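Each panel treats one dataset as the positive class and pools the other twelve as negative. The area under such a one-vs-all curve equals the probability that a randomly chosen positive point receives a higher predicted probability for that class than a randomly chosen negative point. That probability can be computed from ranks using the Mann-Whitney statistic. Below is a base R sketch with made-up scores; ova_auc() is a hypothetical helper.

```r
# One-vs-all AUC via the rank-sum (Mann-Whitney) formula
ova_auc <- function(prob, is_positive) {
  r <- rank(prob)  # average ranks handle tied scores
  n_pos <- sum(is_positive)
  n_neg <- sum(!is_positive)
  (sum(r[is_positive]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Made-up predicted probabilities for the "dino" class and the true labels
pred_dino <- c(0.9, 0.8, 0.35, 0.6, 0.2, 0.1)
truth     <- c("dino", "dino", "dino", "star", "circle", "star")

ova_auc(pred_dino, truth == "dino")
```

Here, 8 of the 9 positive/negative pairs are ordered correctly. The one miss is the dino point scored 0.35, which ranks below the star point scored 0.6.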
We can also compute a confusion matrix, pooling the held-out predictions from all 25 resamples.
dino_rs %>%
collect_predictions() %>%
conf_mat(dataset, .pred_class)
## Truth
## Prediction away bullseye circle dino dots h_lines high_lines slant_down
## away 223 84 45 57 9 57 76 127
## bullseye 123 470 15 95 6 40 100 76
## circle 96 14 850 99 5 30 153 50
## dino 55 67 18 145 5 39 81 149
## dots 23 19 23 33 1213 41 57 46
## h_lines 52 79 37 59 26 894 42 46
## high_lines 116 105 73 148 8 27 376 97
## slant_down 137 51 27 159 11 28 67 325
## slant_up 77 83 33 147 1 28 66 105
## star 60 51 39 80 18 29 60 73
## v_lines 34 64 32 68 8 8 45 77
## wide_lines 177 138 55 136 1 62 67 99
## x_shape 153 101 64 79 6 29 123 65
## Truth
## Prediction slant_up star v_lines wide_lines x_shape
## away 88 60 3 111 81
## bullseye 107 34 42 95 53
## circle 98 84 4 59 29
## dino 113 50 23 65 51
## dots 33 15 13 26 15
## h_lines 56 35 5 57 34
## high_lines 130 59 32 77 79
## slant_down 116 32 41 91 26
## slant_up 260 29 13 95 52
## star 39 754 0 34 84
## v_lines 55 19 1133 34 14
## wide_lines 199 54 21 394 147
## x_shape 44 89 1 135 653
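Comparing each diagonal entry with its column total gives per-class recall (sensitivity). For example, dots and v_lines are identified correctly far more often than dino or away. The base R sketch below works through two of the truth columns, with the counts copied by hand from the output above; recall() is a hypothetical helper.

```r
# Prediction rows, in the order printed by conf_mat() above
pred_levels <- c("away", "bullseye", "circle", "dino", "dots", "h_lines",
                 "high_lines", "slant_down", "slant_up", "star", "v_lines",
                 "wide_lines", "x_shape")

# Truth columns for two classes, copied from the confusion matrix
truth_dino <- c(57, 95, 99, 145, 33, 59, 148, 159, 147, 80, 68, 136, 79)
truth_dots <- c(9, 6, 5, 5, 1213, 26, 8, 11, 1, 18, 8, 1, 6)
names(truth_dino) <- pred_levels
names(truth_dots) <- pred_levels

# Recall = points correctly predicted as the class / all points truly in the class
recall <- function(truth_col, class) unname(truth_col[class] / sum(truth_col))

round(c(dino = recall(truth_dino, "dino"), dots = recall(truth_dots, "dots")), 3)
```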
This is easier to read as a heatmap.
dino_rs %>%
collect_predictions() %>%
conf_mat(dataset, .pred_class) %>%
autoplot(type = "heatmap")
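The correct predictions on the diagonal dominate that heatmap, so let’s filter them out to see which datasets are most often confused with each other.
dino_rs %>%
collect_predictions() %>%
filter(.pred_class != dataset) %>%
conf_mat(dataset, .pred_class) %>%
autoplot(type = "heatmap")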
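Filtering out the correct predictions is equivalent to zeroing the diagonal of the confusion matrix, so the color scale reflects only the mistakes. The same idea in base R is sketched below. It uses a hypothetical helper, top_confusions(), on a made-up matrix to list the largest off-diagonal cells.

```r
# Made-up 3-class confusion matrix: rows = predicted, columns = truth
cm <- matrix(
  c(8, 2, 0,
    1, 5, 3,
    1, 3, 7),
  nrow = 3, byrow = TRUE,
  dimnames = list(Prediction = c("a", "b", "c"), Truth = c("a", "b", "c"))
)

top_confusions <- function(cm, n = 3) {
  off <- cm
  diag(off) <- 0                                 # drop correct predictions
  idx <- order(off, decreasing = TRUE)[seq_len(n)]
  rc <- arrayInd(idx, dim(off))                  # linear index -> (row, col)
  data.frame(
    truth = colnames(off)[rc[, 2]],
    predicted = rownames(off)[rc[, 1]],
    n = off[idx],
    stringsAsFactors = FALSE
  )
}

top_confusions(cm)
```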