Predicting class membership for the TidyTuesday Datasaurus Dozen

Explore the data

The Datasaurus Dozen dataset is a collection of 13 sets of x/y data that have very similar summary statistics but look very different when plotted. Our modeling goal is to predict which member of the ‘dozen’ each point belongs to.

Let’s start by reading in the data from the datasauRus package.

library(tidyverse)
## -- Attaching packages --------
## v ggplot2 3.3.2     v purrr   0.3.4
## v tibble  3.0.3     v dplyr   1.0.0
## v tidyr   1.1.0     v stringr 1.4.0
## v readr   1.3.1     v forcats 0.5.0
## -- Conflicts -----------------
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()
library(datasauRus)
## Warning: package 'datasauRus' was built under R version 4.0.3
datasaurus_dozen
## # A tibble: 1,846 x 3
##    dataset     x     y
##    <chr>   <dbl> <dbl>
##  1 dino     55.4  97.2
##  2 dino     51.5  96.0
##  3 dino     46.2  94.5
##  4 dino     42.8  91.4
##  5 dino     40.8  88.3
##  6 dino     38.7  84.9
##  7 dino     35.6  79.9
##  8 dino     33.1  77.6
##  9 dino     29.0  74.5
## 10 dino     26.2  71.4
## # ... with 1,836 more rows
theme_set(theme_light())

These datasets are very different from each other!

datasaurus_dozen %>%
  ggplot(aes(x, y, color = dataset)) +
  geom_point(show.legend = FALSE) +
  facet_wrap(~dataset, ncol = 5)

But their summary statistics are so similar:

datasaurus_dozen %>%
  group_by(dataset) %>%
  summarise(across(c(x, y), list(mean = mean, sd = sd)),
            x_y_cor = cor(x, y))
## `summarise()` ungrouping output (override with `.groups` argument)
## # A tibble: 13 x 6
##    dataset    x_mean  x_sd y_mean  y_sd x_y_cor
##    <chr>       <dbl> <dbl>  <dbl> <dbl>   <dbl>
##  1 away         54.3  16.8   47.8  26.9 -0.0641
##  2 bullseye     54.3  16.8   47.8  26.9 -0.0686
##  3 circle       54.3  16.8   47.8  26.9 -0.0683
##  4 dino         54.3  16.8   47.8  26.9 -0.0645
##  5 dots         54.3  16.8   47.8  26.9 -0.0603
##  6 h_lines      54.3  16.8   47.8  26.9 -0.0617
##  7 high_lines   54.3  16.8   47.8  26.9 -0.0685
##  8 slant_down   54.3  16.8   47.8  26.9 -0.0690
##  9 slant_up     54.3  16.8   47.8  26.9 -0.0686
## 10 star         54.3  16.8   47.8  26.9 -0.0630
## 11 v_lines      54.3  16.8   47.8  26.9 -0.0694
## 12 wide_lines   54.3  16.8   47.8  26.9 -0.0666
## 13 x_shape      54.3  16.8   47.8  26.9 -0.0656
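
As the message above suggests, we could set the .groups argument explicitly to silence it; this returns the same table:

datasaurus_dozen %>%
  group_by(dataset) %>%
  summarise(across(c(x, y), list(mean = mean, sd = sd)),
            x_y_cor = cor(x, y),
            .groups = "drop")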

Let’s explore whether we can use modeling to predict which dataset a point belongs to. This is not a large dataset compared to the number of classes (13), so this isn’t a tutorial that shows best practices for a predictive modeling workflow overall, but it does demonstrate how to evaluate a multiclass model, as well as a bit about how random forest models work.

datasaurus_dozen %>%
  count(dataset)
## # A tibble: 13 x 2
##    dataset        n
##    <chr>      <int>
##  1 away         142
##  2 bullseye     142
##  3 circle       142
##  4 dino         142
##  5 dots         142
##  6 h_lines      142
##  7 high_lines   142
##  8 slant_down   142
##  9 slant_up     142
## 10 star         142
## 11 v_lines      142
## 12 wide_lines   142
## 13 x_shape      142

Build a model

Let’s start out by creating resamples of the Datasaurus Dozen. Notice that we aren’t splitting into testing and training sets, so we won’t have an unbiased estimate of performance on new data. Instead, we will use these resamples to understand the dataset and multiclass model better.

library(tidymodels)
## -- Attaching packages --------
## v broom     0.7.0      v recipes   0.1.13
## v dials     0.0.8      v rsample   0.0.7 
## v infer     0.5.3      v tune      0.1.1 
## v modeldata 0.0.2      v workflows 0.1.2 
## v parsnip   0.1.2      v yardstick 0.0.7
## -- Conflicts -----------------
## x scales::discard() masks purrr::discard()
## x dplyr::filter()   masks stats::filter()
## x recipes::fixed()  masks stringr::fixed()
## x dplyr::lag()      masks stats::lag()
## x yardstick::spec() masks readr::spec()
## x recipes::step()   masks stats::step()
set.seed(123)

dino_folds <- datasaurus_dozen %>%
  mutate(dataset = factor(dataset)) %>%
  bootstraps()

dino_folds
## # Bootstrap sampling 
## # A tibble: 25 x 2
##    splits             id         
##    <list>             <chr>      
##  1 <split [1.8K/672]> Bootstrap01
##  2 <split [1.8K/689]> Bootstrap02
##  3 <split [1.8K/680]> Bootstrap03
##  4 <split [1.8K/674]> Bootstrap04
##  5 <split [1.8K/692]> Bootstrap05
##  6 <split [1.8K/689]> Bootstrap06
##  7 <split [1.8K/689]> Bootstrap07
##  8 <split [1.8K/695]> Bootstrap08
##  9 <split [1.8K/664]> Bootstrap09
## 10 <split [1.8K/671]> Bootstrap10
## # ... with 15 more rows
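
Notice that bootstraps() created 25 resamples, its default. If we preferred cross-validation, we could build the resamples with rsample’s vfold_cv() instead; a minimal sketch (not used in the rest of this post):

set.seed(123)

datasaurus_dozen %>%
  mutate(dataset = factor(dataset)) %>%
  vfold_cv(v = 10, strata = dataset)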

Let’s create a random forest model and set up a model workflow with the model and a formula preprocessor. We are predicting the dataset class (dino vs. circle vs. bullseye vs. …) from x and y. A random forest model can often do a good job of learning complex interactions in predictors.

rf_spec <- rand_forest(trees = 1000) %>%
  set_mode("classification") %>%
  set_engine("ranger")

dino_wf <- workflow() %>%
  add_formula(dataset ~ x + y) %>%
  add_model(rf_spec)


dino_wf
## == Workflow ==================
## Preprocessor: Formula
## Model: rand_forest()
## 
## -- Preprocessor --------------
## dataset ~ x + y
## 
## -- Model ---------------------
## Random Forest Model Specification (classification)
## 
## Main Arguments:
##   trees = 1000
## 
## Computational engine: ranger
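
As a quick sanity check, parsnip’s translate() shows how this specification will map to a call to the underlying ranger engine (output not shown here):

rf_spec %>% translate()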

Let’s fit the random forest model to the bootstrap resamples. Registering a parallel backend first lets the resamples be fit in parallel.

doParallel::registerDoParallel()

dino_rs <- fit_resamples(
  dino_wf,
  resamples = dino_folds,
  control = control_resamples(save_pred = TRUE)
)

dino_rs
## # Resampling results
## # Bootstrap sampling 
## # A tibble: 25 x 5
##    splits            id          .metrics       .notes         .predictions     
##    <list>            <chr>       <list>         <list>         <list>           
##  1 <split [1.8K/672~ Bootstrap01 <tibble [2 x ~ <tibble [0 x ~ <tibble [672 x 1~
##  2 <split [1.8K/689~ Bootstrap02 <tibble [2 x ~ <tibble [0 x ~ <tibble [689 x 1~
##  3 <split [1.8K/680~ Bootstrap03 <tibble [2 x ~ <tibble [0 x ~ <tibble [680 x 1~
##  4 <split [1.8K/674~ Bootstrap04 <tibble [2 x ~ <tibble [0 x ~ <tibble [674 x 1~
##  5 <split [1.8K/692~ Bootstrap05 <tibble [2 x ~ <tibble [0 x ~ <tibble [692 x 1~
##  6 <split [1.8K/689~ Bootstrap06 <tibble [2 x ~ <tibble [0 x ~ <tibble [689 x 1~
##  7 <split [1.8K/689~ Bootstrap07 <tibble [2 x ~ <tibble [0 x ~ <tibble [689 x 1~
##  8 <split [1.8K/695~ Bootstrap08 <tibble [2 x ~ <tibble [0 x ~ <tibble [695 x 1~
##  9 <split [1.8K/664~ Bootstrap09 <tibble [2 x ~ <tibble [0 x ~ <tibble [664 x 1~
## 10 <split [1.8K/671~ Bootstrap10 <tibble [2 x ~ <tibble [0 x ~ <tibble [671 x 1~
## # ... with 15 more rows

We did it!

Evaluate model

How did these models do overall?

collect_metrics(dino_rs)
## # A tibble: 2 x 5
##   .metric  .estimator  mean     n std_err
##   <chr>    <chr>      <dbl> <int>   <dbl>
## 1 accuracy multiclass 0.449    25 0.00337
## 2 roc_auc  hand_till  0.846    25 0.00128
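
Accuracy and ROC AUC are the default metrics that fit_resamples() computes for classification. If we had wanted others computed during resampling, we could have passed a yardstick metric_set(); a minimal sketch (dino_rs_custom is a hypothetical name, not used below):

dino_rs_custom <- fit_resamples(
  dino_wf,
  resamples = dino_folds,
  metrics = metric_set(accuracy, roc_auc, ppv),
  control = control_resamples(save_pred = TRUE)
)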

The accuracy is not great; a multiclass problem like this, especially one with so many classes, is harder than a binary classification problem.
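
For context, with 13 balanced classes, always guessing the same class would yield an accuracy of only about 1/13, so the ~45% accuracy we see here is far better than chance, even if it is not impressive in absolute terms:

1 / n_distinct(datasaurus_dozen$dataset)
## [1] 0.07692308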

Since we saved the predictions with save_pred = TRUE, we can compute other performance metrics. Notice that by default the positive predictive value (like precision) is macro-weighted for multiclass problems.

dino_rs %>%
               collect_predictions() %>%
               group_by(id) %>%
               ppv(dataset, .pred_class)
## # A tibble: 25 x 4
##    id          .metric .estimator .estimate
##    <chr>       <chr>   <chr>          <dbl>
##  1 Bootstrap01 ppv     macro          0.428
##  2 Bootstrap02 ppv     macro          0.422
##  3 Bootstrap03 ppv     macro          0.442
##  4 Bootstrap04 ppv     macro          0.419
##  5 Bootstrap05 ppv     macro          0.448
##  6 Bootstrap06 ppv     macro          0.404
##  7 Bootstrap07 ppv     macro          0.423
##  8 Bootstrap08 ppv     macro          0.423
##  9 Bootstrap09 ppv     macro          0.395
## 10 Bootstrap10 ppv     macro          0.428
## # ... with 15 more rows
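
If micro-weighting fits a problem better, yardstick lets us choose the estimator explicitly (output not shown here):

dino_rs %>%
  collect_predictions() %>%
  group_by(id) %>%
  ppv(dataset, .pred_class, estimator = "micro")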

Next, let’s compute ROC curves for each class.

dino_rs %>%
  collect_predictions() %>%
  group_by(id) %>%
  roc_curve(dataset, .pred_away:.pred_x_shape) %>%
  ggplot(aes(1 - specificity, sensitivity, color = id)) +
  geom_abline(lty = 2, color = "gray80", size = 1.5) +
  geom_path(show.legend = FALSE, alpha = 0.6, size = 1.2) +
  facet_wrap(~.level, ncol = 5) +
  coord_equal()

We have an ROC curve for each class and each resample in this plot. We can also compute a confusion matrix across all the resamples:

dino_rs %>%
  collect_predictions() %>%
  conf_mat(dataset, .pred_class)
##             Truth
## Prediction   away bullseye circle dino dots h_lines high_lines slant_down
##   away        223       84     45   57    9      57         76        127
##   bullseye    123      470     15   95    6      40        100         76
##   circle       96       14    850   99    5      30        153         50
##   dino         55       67     18  145    5      39         81        149
##   dots         23       19     23   33 1213      41         57         46
##   h_lines      52       79     37   59   26     894         42         46
##   high_lines  116      105     73  148    8      27        376         97
##   slant_down  137       51     27  159   11      28         67        325
##   slant_up     77       83     33  147    1      28         66        105
##   star         60       51     39   80   18      29         60         73
##   v_lines      34       64     32   68    8       8         45         77
##   wide_lines  177      138     55  136    1      62         67         99
##   x_shape     153      101     64   79    6      29        123         65
##             Truth
## Prediction   slant_up star v_lines wide_lines x_shape
##   away             88   60       3        111      81
##   bullseye        107   34      42         95      53
##   circle           98   84       4         59      29
##   dino            113   50      23         65      51
##   dots             33   15      13         26      15
##   h_lines          56   35       5         57      34
##   high_lines      130   59      32         77      79
##   slant_down      116   32      41         91      26
##   slant_up        260   29      13         95      52
##   star             39  754       0         34      84
##   v_lines          55   19    1133         34      14
##   wide_lines      199   54      21        394     147
##   x_shape          44   89       1        135     653
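
As an aside, the summary() method for a conf_mat object computes a whole set of classification metrics from these counts at once (output not shown here):

dino_rs %>%
  collect_predictions() %>%
  conf_mat(dataset, .pred_class) %>%
  summary()
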
These counts are easier to take in as a heatmap:

dino_rs %>%
  collect_predictions() %>%
  conf_mat(dataset, .pred_class) %>%
  autoplot(type = "heatmap")

The diagonal dominates that plot, so let’s look at the misclassifications only, filtering out the correct predictions:

dino_rs %>%
  collect_predictions() %>%
  filter(.pred_class != dataset) %>%
  conf_mat(dataset, .pred_class) %>%
  autoplot(type = "heatmap")
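
If we’d rather see the most common confusions as a table than a heatmap, a quick sketch:

dino_rs %>%
  collect_predictions() %>%
  filter(.pred_class != dataset) %>%
  count(dataset, .pred_class, sort = TRUE)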