library(tidyverse)
library(correlationfunnel)
library(skimr)
library(tidymodels)
library(themis)
library(xgboost)
library(vip)

Import Data

# Import the raw records table for the TidyTuesday 2021-05-25 release.
records <- readr::read_csv(
  file = "https://raw.githubusercontent.com/rfordatascience/tidytuesday/main/data/2021/2021-05-25/records.csv"
)

Explore Data

# Per-column summary (missingness, ranges, distributions) of the raw records.
# Passed as a named argument rather than piped so skimr still reports the
# data name as "records".
skimr::skim(data = records)
Data summary
Name records
Number of rows 2334
Number of columns 9
_______________________
Column type frequency:
character 6
Date 1
numeric 2
________________________
Group variables None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
track 0 1 12 21 0 16 0
type 0 1 9 10 0 2 0
shortcut 0 1 2 3 0 2 0
player 0 1 2 10 0 65 0
system_played 0 1 3 4 0 2 0
time_period 0 1 3 9 0 1577 0

Variable type: Date

skim_variable n_missing complete_rate min max median n_unique
date 0 1 1997-02-15 2021-02-25 2004-06-18 1096

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
time 0 1 90.62 66.67 14.59 39.03 86.19 120.16 375.83 ▇▆▁▁▁
record_duration 0 1 220.75 429.08 0.00 6.00 51.00 198.75 3659.00 ▇▁▁▁▁

Clean Data

# Modelling table: categorical columns become factors (shortcut with an
# explicit No/Yes level order), the record date is split into numeric year and
# month, and the raw date plus the character time_period column are dropped.
data_clean <- records %>%
  mutate(
    shortcut = factor(shortcut, levels = c("No", "Yes")),
    across(c(type, system_played, track, player), as.factor),
    year  = lubridate::year(date),
    month = lubridate::month(date)
  ) %>%
  select(-c(time_period, date))

# Class balance of the target variable.
count(data_clean, shortcut)
## # A tibble: 2 × 2
##   shortcut     n
##   <fct>    <int>
## 1 No        1447
## 2 Yes        887

Explore Data

Record Time by Shortcut Use

# Distribution of record times, split by whether a shortcut was used.
ggplot(data_clean, aes(x = shortcut, y = time)) +
  geom_boxplot() +
  labs(
    title = "Record Time by Shortcut Use",
    x = "Shortcut",
    y = "Time (seconds)"
  )

Record Duration by Shortcut Use

# Distribution of record durations, split by whether a shortcut was used.
ggplot(data_clean, aes(x = shortcut, y = record_duration)) +
  geom_boxplot() +
  labs(
    title = "Record Duration by Shortcut Use",
    x = "Shortcut",
    y = "Duration (days)"
  )

Correlation Funnel

# Binarize every column (numeric -> bins, factor -> one-hot) so each binary
# feature can be correlated against the "shortcut used" indicator.
data_binarized <- data_clean %>%
  select(-player) %>%      # too many levels to binarize usefully
  binarize()

# binarize() names one-hot columns "<column>__<level>". Name the target level
# explicitly instead of taking the last "shortcut" column: the old approach
# relied on column order and would silently target "No" if the factor levels
# of `shortcut` were ever reordered.
target_col <- "shortcut__Yes"
if (!target_col %in% names(data_binarized)) {
  stop(
    "Binarized target column '", target_col, "' not found; ",
    "check the levels of `shortcut`.",
    call. = FALSE
  )
}

data_correlation <- data_binarized %>%
  correlate(target = !!sym(target_col))

data_correlation %>%
  plot_correlation_funnel() +
  labs(title = "Correlation Funnel: Shortcut vs Non-Shortcut Records")

Model Building

Split Data

# Stratified 3/4 train / 1/4 test split, then stratified 5-fold CV on the
# training portion. Both steps are random, so the seed is set first and the
# call order is kept fixed for reproducibility.
set.seed(1234)
data_split <- data_clean %>% initial_split(prop = 3 / 4, strata = shortcut)
data_train <- data_split %>% training()
data_test  <- data_split %>% testing()
data_cv    <- data_train %>% rsample::vfold_cv(v = 5, strata = shortcut)
data_cv
## #  5-fold cross-validation using stratification 
## # A tibble: 5 × 2
##   splits             id   
##   <list>             <chr>
## 1 <split [1400/350]> Fold1
## 2 <split [1400/350]> Fold2
## 3 <split [1400/350]> Fold3
## 4 <split [1400/350]> Fold4
## 5 <split [1400/350]> Fold5

Preprocess Data

# Preprocessing recipe for the XGBoost classifier, estimated on training data.
xgboost_rec <- recipe(shortcut ~ ., data = data_train) %>%
  # Pool infrequent player/track levels (below 2% of rows) into "other".
  step_other(player, track, threshold = 0.02) %>%
  # xgboost needs numeric input: expand the remaining factors into indicators.
  step_dummy(all_nominal_predictors()) %>%
  # Drop predictors that are constant in the training data.
  step_zv(all_predictors()) %>%
  # Centre/scale numerics; trees do not need this, but SMOTE's
  # nearest-neighbour distances would otherwise be dominated by wide-range
  # columns such as record_duration.
  step_normalize(all_numeric_predictors()) %>%
  # Synthetic oversampling of the minority class. NOTE(review): themis
  # documents skip = TRUE as the default, i.e. it applies only when prepping
  # on training data, not when baking new data.
  step_smote(shortcut)

# Inspect the processed training set. bake(new_data = NULL) returns the prepped
# training data and replaces the superseded juice().
xgboost_rec %>% prep() %>% bake(new_data = NULL) %>% glimpse()
## Rows: 2,170
## Columns: 36
## $ time                        <dbl> 0.5572584, 0.5278139, 0.4674526, 0.4658331…
## $ record_duration             <dbl> -0.48587649, -0.49784466, -0.51460011, -0.…
## $ year                        <dbl> -1.1513448, -1.1513448, -1.1513448, -1.151…
## $ month                       <dbl> -1.3040588, -1.3040588, -0.7263113, -0.726…
## $ track_Bowser.s.Castle       <dbl> -0.1766741, -0.1766741, -0.1766741, -0.176…
## $ track_Choco.Mountain        <dbl> -0.268817, -0.268817, -0.268817, -0.268817…
## $ track_D.K..s.Jungle.Parkway <dbl> -0.2901974, -0.2901974, -0.2901974, -0.290…
## $ track_Frappe.Snowland       <dbl> -0.2901974, -0.2901974, -0.2901974, -0.290…
## $ track_Kalimari.Desert       <dbl> -0.2820196, -0.2820196, -0.2820196, -0.282…
## $ track_Koopa.Troopa.Beach    <dbl> -0.1947767, -0.1947767, -0.1947767, -0.194…
## $ track_Luigi.Raceway         <dbl> 3.97717, 3.97717, 3.97717, 3.97717, 3.9771…
## $ track_Mario.Raceway         <dbl> -0.2626581, -0.2626581, -0.2626581, -0.262…
## $ track_Moo.Moo.Farm          <dbl> -0.1850929, -0.1850929, -0.1850929, -0.185…
## $ track_Rainbow.Road          <dbl> -0.2970825, -0.2970825, -0.2970825, -0.297…
## $ track_Royal.Raceway         <dbl> -0.2589109, -0.2589109, -0.2589109, -0.258…
## $ track_Sherbet.Land          <dbl> -0.2448048, -0.2448048, -0.2448048, -0.244…
## $ track_Toad.s.Turnpike       <dbl> -0.3038614, -0.3038614, -0.3038614, -0.303…
## $ track_Wario.Stadium         <dbl> -0.316038, -0.316038, -0.316038, -0.316038…
## $ track_Yoshi.Valley          <dbl> -0.2736709, -0.2736709, -0.2736709, -0.273…
## $ type_Three.Lap              <dbl> 0.9726595, 0.9726595, 0.9726595, 0.9726595…
## $ player_Booth                <dbl> -0.2395183, -0.2395183, -0.2395183, -0.239…
## $ player_Dan                  <dbl> -0.3072134, -0.3072134, -0.3072134, -0.307…
## $ player_Jonathan             <dbl> -0.1469258, -0.1469258, -0.1469258, -0.146…
## $ player_Karlo                <dbl> -0.1448844, -0.1448844, -0.1448844, -0.144…
## $ player_Lacey                <dbl> -0.1867369, -0.1867369, -0.1867369, -0.186…
## $ player_Launspach            <dbl> -0.1660801, -0.1660801, -0.1660801, 6.0177…
## $ player_MJ                   <dbl> -0.3105416, -0.3105416, -0.3105416, -0.310…
## $ player_MR                   <dbl> -0.4241905, -0.4241905, -0.4241905, -0.424…
## $ player_Penev                <dbl> -0.4279337, -0.4279337, -0.4279337, -0.427…
## $ player_Peter.E              <dbl> -0.1963508, -0.1963508, -0.1963508, -0.196…
## $ player_Sami                 <dbl> -0.1605509, -0.1605509, -0.1605509, -0.160…
## $ player_VAJ                  <dbl> -0.1915956, -0.1915956, -0.1915956, -0.191…
## $ player_Zwartjes             <dbl> -0.1899879, -0.1899879, -0.1899879, -0.189…
## $ player_other                <dbl> 2.180335, 2.180335, 2.180335, -0.458383, -…
## $ system_played_PAL           <dbl> -1.5333448, -1.5333448, -1.5333448, -1.533…
## $ shortcut                    <fct> No, No, No, No, No, No, No, No, No, No, No…

Specify Model

# Boosted-tree classifier (xgboost engine) with six hyperparameters marked for
# tuning, bundled with the preprocessing recipe into one workflow.
xgboost_spec <- boost_tree(
  mode           = "classification",
  engine         = "xgboost",
  trees          = tune(),
  min_n          = tune(),
  tree_depth     = tune(),
  learn_rate     = tune(),
  loss_reduction = tune(),
  sample_size    = tune()
)

xgboost_workflow <- workflow(
  preprocessor = xgboost_rec,
  spec         = xgboost_spec
)

Tune Hyperparameters

# Evaluate 5 automatically generated hyperparameter candidates on the CV
# folds. A doParallel foreach backend is registered so resamples may run in
# parallel; save_pred = TRUE keeps the out-of-fold predictions.
doParallel::registerDoParallel()
set.seed(65743)

xgboost_tune <- xgboost_workflow %>%
  tune_grid(
    resamples = data_cv,
    grid      = 5,
    control   = control_grid(save_pred = TRUE)
  )

Model Evaluation

Classification Metrics (accuracy, roc_auc & brier_class)

# Cross-validated metrics averaged over folds, one row per candidate x metric.
xgboost_tune %>% collect_metrics()
## # A tibble: 15 × 12
##    trees min_n tree_depth learn_rate loss_reduction sample_size .metric    
##    <int> <int>      <int>      <dbl>          <dbl>       <dbl> <chr>      
##  1     1    30         15    0.0750    0.0422             0.625 accuracy   
##  2     1    30         15    0.0750    0.0422             0.625 brier_class
##  3     1    30         15    0.0750    0.0422             0.625 roc_auc    
##  4   500    21          1    0.316     0.0000562          1     accuracy   
##  5   500    21          1    0.316     0.0000562          1     brier_class
##  6   500    21          1    0.316     0.0000562          1     roc_auc    
##  7  1000     2          8    0.0178    0.0000000001       0.5   accuracy   
##  8  1000     2          8    0.0178    0.0000000001       0.5   brier_class
##  9  1000     2          8    0.0178    0.0000000001       0.5   roc_auc    
## 10  1500    11          4    0.001    31.6                0.75  accuracy   
## 11  1500    11          4    0.001    31.6                0.75  brier_class
## 12  1500    11          4    0.001    31.6                0.75  roc_auc    
## 13  2000    40         11    0.00422   0.0000000750       0.875 accuracy   
## 14  2000    40         11    0.00422   0.0000000750       0.875 brier_class
## 15  2000    40         11    0.00422   0.0000000750       0.875 roc_auc    
## # ℹ 5 more variables: .estimator <chr>, mean <dbl>, n <int>, std_err <dbl>,
## #   .config <chr>

ROC Curve

# Per-fold ROC curves for the best hyperparameter set (by CV roc_auc).
#
# Without `parameters`, collect_predictions() returns out-of-fold predictions
# for EVERY grid candidate, so each fold's curve would blend several models.
#
# `shortcut` has levels c("No", "Yes"), and yardstick treats the FIRST level
# as the event by default. Because the score is .pred_Yes, event_level must be
# "second"; otherwise the curve is mirrored below the diagonal.
collect_predictions(
  xgboost_tune,
  parameters = select_best(xgboost_tune, metric = "roc_auc")
) %>%
  group_by(id) %>%
  roc_curve(shortcut, .pred_Yes, event_level = "second") %>%
  autoplot() +
  labs(title = "ROC Curve: Shortcut Classification")

Fit Final Model

# Plug the best-by-roc_auc hyperparameters into the workflow, refit on the
# full training set, and evaluate once on the held-out test split.
xgboost_last <- finalize_workflow(
  xgboost_workflow,
  parameters = select_best(xgboost_tune, metric = "roc_auc")
) %>%
  last_fit(split = data_split)

xgboost_last %>% collect_metrics()
## # A tibble: 3 × 4
##   .metric     .estimator .estimate .config        
##   <chr>       <chr>          <dbl> <chr>          
## 1 accuracy    binary         0.731 pre0_mod0_post0
## 2 roc_auc     binary         0.834 pre0_mod0_post0
## 3 brier_class binary         0.157 pre0_mod0_post0

Confusion Matrix

# Heatmap of actual vs predicted classes on the test split from last_fit().
xgboost_last %>%
  collect_predictions() %>%
  yardstick::conf_mat(truth = shortcut, estimate = .pred_class) %>%
  autoplot(type = "heatmap") +
  labs(title = "Confusion Matrix: Test Set Predictions")

Variable Importance

# Feature importance (top 15) taken from the fitted xgboost engine object.
vip(
  workflows::extract_fit_engine(xgboost_last),
  num_features = 15
) +
  labs(title = "Top 15 Most Important Features")

Predict on Test Dataset

# Pull the trained workflow (recipe + model) out of the last_fit() result.
final_model <- extract_workflow(xgboost_last)

# Score the test set and bind, column-wise: predicted class, class
# probabilities, then the original test-set columns.
test_predictions <- bind_cols(
  predict(final_model, data_test),
  predict(final_model, data_test, type = "prob"),
  data_test
)

# Preview the first 20 rows of selected prediction and input columns.
test_predictions %>%
  select(shortcut, .pred_class, .pred_Yes, .pred_No, track, type, time) %>%
  head(20)
## # A tibble: 20 × 7
##    shortcut .pred_class .pred_Yes .pred_No track         type       time
##    <fct>    <fct>           <dbl>    <dbl> <fct>         <fct>     <dbl>
##  1 No       No             0.0427    0.957 Luigi Raceway Three Lap  133.
##  2 No       No             0.101     0.899 Luigi Raceway Three Lap  130.
##  3 No       No             0.169     0.831 Luigi Raceway Three Lap  125.
##  4 No       No             0.157     0.843 Luigi Raceway Three Lap  123.
##  5 No       No             0.172     0.828 Luigi Raceway Three Lap  121.
##  6 No       No             0.158     0.842 Luigi Raceway Three Lap  120.
##  7 No       No             0.101     0.899 Luigi Raceway Three Lap  120.
##  8 No       No             0.104     0.896 Luigi Raceway Three Lap  120.
##  9 No       No             0.135     0.865 Luigi Raceway Three Lap  120.
## 10 No       No             0.118     0.882 Luigi Raceway Three Lap  120.
## 11 No       No             0.103     0.897 Luigi Raceway Three Lap  120.
## 12 No       No             0.116     0.884 Luigi Raceway Three Lap  119.
## 13 No       No             0.122     0.878 Luigi Raceway Three Lap  119.
## 14 No       No             0.114     0.886 Luigi Raceway Three Lap  119.
## 15 No       No             0.0899    0.910 Luigi Raceway Three Lap  119.
## 16 No       No             0.108     0.892 Luigi Raceway Three Lap  118.
## 17 No       No             0.0989    0.901 Luigi Raceway Three Lap  118.
## 18 No       No             0.0782    0.922 Luigi Raceway Three Lap  118.
## 19 No       No             0.0780    0.922 Luigi Raceway Three Lap  118.
## 20 No       No             0.103     0.897 Luigi Raceway Three Lap  118.
# Summary of predictions vs actuals: one row per (actual, predicted) pair with
# its count, flagged by whether the prediction matched the actual class.
count(test_predictions, shortcut, .pred_class) %>%
  mutate(correct = shortcut == .pred_class)
## # A tibble: 4 × 4
##   shortcut .pred_class     n correct
##   <fct>    <fct>       <int> <lgl>  
## 1 No       No            270 TRUE   
## 2 No       Yes            92 FALSE  
## 3 Yes      No             65 FALSE  
## 4 Yes      Yes           157 TRUE