library(tidyverse)
library(correlationfunnel)
library(skimr)
library(tidymodels)
library(themis)
library(xgboost)
library(vip)
# Load the TidyTuesday (2021-05-25) Mario Kart 64 world-records table and
# print a skimr overview of every column.
records <- readr::read_csv(
  "https://raw.githubusercontent.com/rfordatascience/tidytuesday/main/data/2021/2021-05-25/records.csv"
)
skimr::skim(records)
| Name | records |
|:---|:---|
| Number of rows | 2334 |
| Number of columns | 9 |
| _______________________ | |
| Column type frequency: | |
| character | 6 |
| Date | 1 |
| numeric | 2 |
| ________________________ | |
| Group variables | None |
Variable type: character
| skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
|---|---|---|---|---|---|---|---|
| track | 0 | 1 | 12 | 21 | 0 | 16 | 0 |
| type | 0 | 1 | 9 | 10 | 0 | 2 | 0 |
| shortcut | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
| player | 0 | 1 | 2 | 10 | 0 | 65 | 0 |
| system_played | 0 | 1 | 3 | 4 | 0 | 2 | 0 |
| time_period | 0 | 1 | 3 | 9 | 0 | 1577 | 0 |
Variable type: Date
| skim_variable | n_missing | complete_rate | min | max | median | n_unique |
|---|---|---|---|---|---|---|
| date | 0 | 1 | 1997-02-15 | 2021-02-25 | 2004-06-18 | 1096 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| time | 0 | 1 | 90.62 | 66.67 | 14.59 | 39.03 | 86.19 | 120.16 | 375.83 | ▇▆▁▁▁ |
| record_duration | 0 | 1 | 220.75 | 429.08 | 0.00 | 6.00 | 51.00 | 198.75 | 3659.00 | ▇▁▁▁▁ |
# Modelling table:
# * categorical character columns become factors, with `shortcut` given an
#   explicit level order so "No" is the first level,
# * the record date is split into numeric `year` and `month` features,
# * the free-text `time_period` and the raw `date` are dropped.
data_clean <- records %>%
  mutate(
    shortcut = factor(shortcut, levels = c("No", "Yes")),
    across(c(type, system_played, track, player), as.factor),
    year = lubridate::year(date),
    month = lubridate::month(date)
  ) %>%
  select(!c(time_period, date))
# Class balance of the outcome.
count(data_clean, shortcut)
## # A tibble: 2 × 2
## shortcut n
## <fct> <int>
## 1 No 1447
## 2 Yes 887
Record time by shortcut use
# Distribution of record times, split by whether a shortcut was used.
ggplot(data_clean, aes(x = shortcut, y = time)) +
  geom_boxplot() +
  labs(
    title = "Record Time by Shortcut Use",
    x = "Shortcut",
    y = "Time (seconds)"
  )
Record duration by shortcut use
# Distribution of how long each record stood, split by shortcut use.
ggplot(data_clean, aes(x = shortcut, y = record_duration)) +
  geom_boxplot() +
  labs(
    title = "Record Duration by Shortcut Use",
    x = "Shortcut",
    y = "Duration (days)"
  )
# Correlation funnel: binarize() turns every column into 0/1 indicator
# columns, which are then correlated with the "shortcut used" indicator.
data_binarized <- data_clean %>%
  select(-player) %>% # too many levels for a readable funnel
  binarize()

# Select the target indicator by name, not by column position. The previous
# `str_subset("shortcut") %>% tail(1)` silently depended on the order in which
# binarize() emits the level columns. binarize() names indicators
# `<column>__<level>`; fail loudly if that naming ever changes.
target_col <- "shortcut__Yes"
stopifnot(target_col %in% names(data_binarized))

data_correlation <- data_binarized %>%
  correlate(target = !!sym(target_col))

data_correlation %>%
  plot_correlation_funnel() +
  labs(title = "Correlation Funnel: Shortcut vs Non-Shortcut Records")
# Reproducible train/test split (rsample's default 3/4 training share),
# stratified on the outcome. The training set is then resampled into
# 5 stratified cross-validation folds for tuning.
set.seed(1234)
data_split <- data_clean %>% initial_split(strata = shortcut)
data_train <- data_split %>% training()
data_test <- data_split %>% testing()
data_cv <- data_train %>% rsample::vfold_cv(v = 5, strata = shortcut)
data_cv
## # 5-fold cross-validation using stratification
## # A tibble: 5 × 2
## splits id
## <list> <chr>
## 1 <split [1400/350]> Fold1
## 2 <split [1400/350]> Fold2
## 3 <split [1400/350]> Fold3
## 4 <split [1400/350]> Fold4
## 5 <split [1400/350]> Fold5
# Preprocessing recipe for the boosted-tree model:
# * `player` and `track` levels rarer than 2% are pooled into "other",
# * nominal predictors are dummy encoded (reference-level coding),
# * zero-variance columns are removed,
# * numeric predictors are centred and scaled, which also puts features on
#   a common scale for SMOTE's nearest-neighbour search,
# * the minority `shortcut` class is oversampled with SMOTE up to parity.
xgboost_rec <-
  recipe(shortcut ~ ., data = data_train) %>%
  step_other(player, track, threshold = 0.02) %>%
  step_dummy(all_nominal_predictors(), one_hot = FALSE) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_numeric_predictors()) %>%
  step_smote(shortcut, over_ratio = 1)
# Inspect the fully prepped training data. bake(new_data = NULL) returns the
# processed training set, including the SMOTE-generated rows. It replaces
# juice(), which recipes has superseded.
xgboost_rec %>%
  prep() %>%
  bake(new_data = NULL) %>%
  glimpse()
## Rows: 2,170
## Columns: 36
## $ time <dbl> 0.5572584, 0.5278139, 0.4674526, 0.4658331…
## $ record_duration <dbl> -0.48587649, -0.49784466, -0.51460011, -0.…
## $ year <dbl> -1.1513448, -1.1513448, -1.1513448, -1.151…
## $ month <dbl> -1.3040588, -1.3040588, -0.7263113, -0.726…
## $ track_Bowser.s.Castle <dbl> -0.1766741, -0.1766741, -0.1766741, -0.176…
## $ track_Choco.Mountain <dbl> -0.268817, -0.268817, -0.268817, -0.268817…
## $ track_D.K..s.Jungle.Parkway <dbl> -0.2901974, -0.2901974, -0.2901974, -0.290…
## $ track_Frappe.Snowland <dbl> -0.2901974, -0.2901974, -0.2901974, -0.290…
## $ track_Kalimari.Desert <dbl> -0.2820196, -0.2820196, -0.2820196, -0.282…
## $ track_Koopa.Troopa.Beach <dbl> -0.1947767, -0.1947767, -0.1947767, -0.194…
## $ track_Luigi.Raceway <dbl> 3.97717, 3.97717, 3.97717, 3.97717, 3.9771…
## $ track_Mario.Raceway <dbl> -0.2626581, -0.2626581, -0.2626581, -0.262…
## $ track_Moo.Moo.Farm <dbl> -0.1850929, -0.1850929, -0.1850929, -0.185…
## $ track_Rainbow.Road <dbl> -0.2970825, -0.2970825, -0.2970825, -0.297…
## $ track_Royal.Raceway <dbl> -0.2589109, -0.2589109, -0.2589109, -0.258…
## $ track_Sherbet.Land <dbl> -0.2448048, -0.2448048, -0.2448048, -0.244…
## $ track_Toad.s.Turnpike <dbl> -0.3038614, -0.3038614, -0.3038614, -0.303…
## $ track_Wario.Stadium <dbl> -0.316038, -0.316038, -0.316038, -0.316038…
## $ track_Yoshi.Valley <dbl> -0.2736709, -0.2736709, -0.2736709, -0.273…
## $ type_Three.Lap <dbl> 0.9726595, 0.9726595, 0.9726595, 0.9726595…
## $ player_Booth <dbl> -0.2395183, -0.2395183, -0.2395183, -0.239…
## $ player_Dan <dbl> -0.3072134, -0.3072134, -0.3072134, -0.307…
## $ player_Jonathan <dbl> -0.1469258, -0.1469258, -0.1469258, -0.146…
## $ player_Karlo <dbl> -0.1448844, -0.1448844, -0.1448844, -0.144…
## $ player_Lacey <dbl> -0.1867369, -0.1867369, -0.1867369, -0.186…
## $ player_Launspach <dbl> -0.1660801, -0.1660801, -0.1660801, 6.0177…
## $ player_MJ <dbl> -0.3105416, -0.3105416, -0.3105416, -0.310…
## $ player_MR <dbl> -0.4241905, -0.4241905, -0.4241905, -0.424…
## $ player_Penev <dbl> -0.4279337, -0.4279337, -0.4279337, -0.427…
## $ player_Peter.E <dbl> -0.1963508, -0.1963508, -0.1963508, -0.196…
## $ player_Sami <dbl> -0.1605509, -0.1605509, -0.1605509, -0.160…
## $ player_VAJ <dbl> -0.1915956, -0.1915956, -0.1915956, -0.191…
## $ player_Zwartjes <dbl> -0.1899879, -0.1899879, -0.1899879, -0.189…
## $ player_other <dbl> 2.180335, 2.180335, 2.180335, -0.458383, -…
## $ system_played_PAL <dbl> -1.5333448, -1.5333448, -1.5333448, -1.533…
## $ shortcut <fct> No, No, No, No, No, No, No, No, No, No, No…
# Boosted-tree classification model with six hyperparameters marked for
# tuning, fitted through the xgboost engine.
xgboost_spec <- boost_tree(
  trees = tune(),
  tree_depth = tune(),
  min_n = tune(),
  loss_reduction = tune(),
  sample_size = tune(),
  learn_rate = tune()
) %>%
  set_engine("xgboost") %>%
  set_mode("classification")
# Bundle the preprocessing recipe and the model spec into one workflow.
xgboost_workflow <- workflow(
  preprocessor = xgboost_rec,
  spec = xgboost_spec
)
# Register a foreach parallel backend for tuning.
# NOTE(review): newer tune releases favour future/mirai for parallelism;
# confirm this backend is actually picked up by the installed tune version.
doParallel::registerDoParallel()

# Evaluate a 5-candidate grid over the cross-validation folds, keeping the
# held-out predictions so they can be inspected afterwards.
set.seed(65743)
xgboost_tune <- xgboost_workflow %>%
  tune_grid(
    resamples = data_cv,
    grid = 5,
    control = control_grid(save_pred = TRUE)
  )

xgboost_tune %>% collect_metrics()
## # A tibble: 15 × 12
## trees min_n tree_depth learn_rate loss_reduction sample_size .metric
## <int> <int> <int> <dbl> <dbl> <dbl> <chr>
## 1 1 30 15 0.0750 0.0422 0.625 accuracy
## 2 1 30 15 0.0750 0.0422 0.625 brier_class
## 3 1 30 15 0.0750 0.0422 0.625 roc_auc
## 4 500 21 1 0.316 0.0000562 1 accuracy
## 5 500 21 1 0.316 0.0000562 1 brier_class
## 6 500 21 1 0.316 0.0000562 1 roc_auc
## 7 1000 2 8 0.0178 0.0000000001 0.5 accuracy
## 8 1000 2 8 0.0178 0.0000000001 0.5 brier_class
## 9 1000 2 8 0.0178 0.0000000001 0.5 roc_auc
## 10 1500 11 4 0.001 31.6 0.75 accuracy
## 11 1500 11 4 0.001 31.6 0.75 brier_class
## 12 1500 11 4 0.001 31.6 0.75 roc_auc
## 13 2000 40 11 0.00422 0.0000000750 0.875 accuracy
## 14 2000 40 11 0.00422 0.0000000750 0.875 brier_class
## 15 2000 40 11 0.00422 0.0000000750 0.875 roc_auc
## # ℹ 5 more variables: .estimator <chr>, mean <dbl>, n <int>, std_err <dbl>,
## # .config <chr>
# Per-fold cross-validated ROC curves for the best candidate (by ROC AUC).
# collect_predictions() returns held-out predictions for every grid
# candidate. Without the `parameters` filter, each fold's curve would pool
# five different models.
collect_predictions(
  xgboost_tune,
  parameters = select_best(xgboost_tune, metric = "roc_auc")
) %>%
  group_by(id) %>%
  # `shortcut` has levels c("No", "Yes"), and yardstick treats the FIRST
  # level as the event by default. event_level = "second" is needed so the
  # curve describes `.pred_Yes`; otherwise the plotted curve is inverted.
  roc_curve(shortcut, .pred_Yes, event_level = "second") %>%
  autoplot() +
  labs(title = "ROC Curve: Shortcut Classification")
# Lock in the hyperparameters with the best cross-validated ROC AUC. Then fit
# once on the full training set and evaluate on the held-out test set.
xgboost_last <- xgboost_workflow %>%
  finalize_workflow(
    parameters = select_best(xgboost_tune, metric = "roc_auc")
  ) %>%
  last_fit(split = data_split)

xgboost_last %>% collect_metrics()
## # A tibble: 3 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 accuracy binary 0.731 pre0_mod0_post0
## 2 roc_auc binary 0.834 pre0_mod0_post0
## 3 brier_class binary 0.157 pre0_mod0_post0
# Heatmap of the test-set confusion matrix (true class vs predicted class).
xgboost_last %>%
  collect_predictions() %>%
  yardstick::conf_mat(truth = shortcut, estimate = .pred_class) %>%
  autoplot(type = "heatmap") +
  labs(title = "Confusion Matrix: Test Set Predictions")
# Variable-importance plot for the 15 strongest predictors of the final
# xgboost model (the underlying engine object, not the parsnip wrapper).
vip(workflows::extract_fit_engine(xgboost_last), num_features = 15) +
  labs(title = "Top 15 Most Important Features")
# Pull the workflow fitted by last_fit() and score the test set. The result
# has class predictions first, then class probabilities, then the original
# test columns.
final_model <- extract_workflow(xgboost_last)

test_predictions <- bind_cols(
  predict(final_model, data_test),
  predict(final_model, data_test, type = "prob"),
  data_test
)
# First 20 test rows: truth, prediction, probabilities and a few context columns.
head(
  select(
    test_predictions,
    shortcut, .pred_class, .pred_Yes, .pred_No, track, type, time
  ),
  n = 20
)
## # A tibble: 20 × 7
## shortcut .pred_class .pred_Yes .pred_No track type time
## <fct> <fct> <dbl> <dbl> <fct> <fct> <dbl>
## 1 No No 0.0427 0.957 Luigi Raceway Three Lap 133.
## 2 No No 0.101 0.899 Luigi Raceway Three Lap 130.
## 3 No No 0.169 0.831 Luigi Raceway Three Lap 125.
## 4 No No 0.157 0.843 Luigi Raceway Three Lap 123.
## 5 No No 0.172 0.828 Luigi Raceway Three Lap 121.
## 6 No No 0.158 0.842 Luigi Raceway Three Lap 120.
## 7 No No 0.101 0.899 Luigi Raceway Three Lap 120.
## 8 No No 0.104 0.896 Luigi Raceway Three Lap 120.
## 9 No No 0.135 0.865 Luigi Raceway Three Lap 120.
## 10 No No 0.118 0.882 Luigi Raceway Three Lap 120.
## 11 No No 0.103 0.897 Luigi Raceway Three Lap 120.
## 12 No No 0.116 0.884 Luigi Raceway Three Lap 119.
## 13 No No 0.122 0.878 Luigi Raceway Three Lap 119.
## 14 No No 0.114 0.886 Luigi Raceway Three Lap 119.
## 15 No No 0.0899 0.910 Luigi Raceway Three Lap 119.
## 16 No No 0.108 0.892 Luigi Raceway Three Lap 118.
## 17 No No 0.0989 0.901 Luigi Raceway Three Lap 118.
## 18 No No 0.0782 0.922 Luigi Raceway Three Lap 118.
## 19 No No 0.0780 0.922 Luigi Raceway Three Lap 118.
## 20 No No 0.103 0.897 Luigi Raceway Three Lap 118.
# Cross-tabulate actual vs predicted class and flag the correct cells.
test_predictions %>%
  count(shortcut, .pred_class, name = "n") %>%
  mutate(correct = .pred_class == shortcut)
## # A tibble: 4 × 4
## shortcut .pred_class n correct
## <fct> <fct> <int> <lgl>
## 1 No No 270 TRUE
## 2 No Yes 92 FALSE
## 3 Yes No 65 FALSE
## 4 Yes Yes 157 TRUE