library(tidyverse)
library(correlationfunnel)
library(skimr)
library(tidymodels)
library(themis)
library(xgboost)
library(vip)
records <- readr::read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/main/data/2021/2021-05-25/records.csv')
skimr::skim(records)
| Name | records |
| Number of rows | 2334 |
| Number of columns | 9 |
| _______________________ | |
| Column type frequency: | |
| character | 6 |
| Date | 1 |
| numeric | 2 |
| ________________________ | |
| Group variables | None |
Variable type: character
| skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
|---|---|---|---|---|---|---|---|
| track | 0 | 1 | 12 | 21 | 0 | 16 | 0 |
| type | 0 | 1 | 9 | 10 | 0 | 2 | 0 |
| shortcut | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
| player | 0 | 1 | 2 | 10 | 0 | 65 | 0 |
| system_played | 0 | 1 | 3 | 4 | 0 | 2 | 0 |
| time_period | 0 | 1 | 3 | 9 | 0 | 1577 | 0 |
Variable type: Date
| skim_variable | n_missing | complete_rate | min | max | median | n_unique |
|---|---|---|---|---|---|---|
| date | 0 | 1 | 1997-02-15 | 2021-02-25 | 2004-06-18 | 1096 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| time | 0 | 1 | 90.62 | 66.67 | 14.59 | 39.03 | 86.19 | 120.16 | 375.83 | ▇▆▁▁▁ |
| record_duration | 0 | 1 | 220.75 | 429.08 | 0.00 | 6.00 | 51.00 | 198.75 | 3659.00 | ▇▁▁▁▁ |
data_clean <- records %>%
mutate(
shortcut = factor(shortcut, levels = c("Yes", "No")),
type = as.factor(type),
system_played = as.factor(system_played),
track = as.factor(track),
player = as.factor(player),
year = lubridate::year(date),
month = lubridate::month(date)
) %>%
select(-time_period)
data_clean %>% count(shortcut)
## # A tibble: 2 × 2
## shortcut n
## <fct> <int>
## 1 Yes 887
## 2 No 1447
Shortcut vs time
data_clean %>%
ggplot(aes(shortcut, time)) +
geom_boxplot() +
labs(title = "Record Time by Shortcut Use",
x = "Shortcut", y = "Time (seconds)")
Record duration by shortcut
data_clean %>%
ggplot(aes(shortcut, record_duration)) +
geom_boxplot() +
labs(title = "Record Duration by Shortcut Use",
x = "Shortcut", y = "Duration (days)")
data_binarized <- data_clean %>%
select(-player, -date) %>% # too many levels
binarize()
target_col <- names(data_binarized) %>%
str_subset("shortcut") %>%
tail(1)
data_correlation <- data_binarized %>%
correlate(target = !!sym(target_col))
data_correlation %>%
plot_correlation_funnel() +
labs(title = "Correlation Funnel: Shortcut vs Non-Shortcut Records")
set.seed(1234)
data_split <- initial_split(data_clean, strata = shortcut)
data_train <- training(data_split)
data_test <- testing(data_split)
data_cv <- rsample::vfold_cv(data_train, v = 5, strata = shortcut)
data_cv
## # 5-fold cross-validation using stratification
## # A tibble: 5 × 2
## splits id
## <list> <chr>
## 1 <split [1400/350]> Fold1
## 2 <split [1400/350]> Fold2
## 3 <split [1400/350]> Fold3
## 4 <split [1400/350]> Fold4
## 5 <split [1400/350]> Fold5
xgboost_rec <- recipe(shortcut ~ ., data = data_train) %>%
step_date(date, features = c("year", "month", "dow"), keep_original_cols = FALSE) %>% #new
step_other(player, track, threshold = 0.02) %>%
step_dummy(all_nominal_predictors()) %>%
step_zv(all_predictors()) %>%
step_YeoJohnson(all_numeric_predictors()) %>%
step_normalize(all_numeric_predictors()) %>%
step_pca(all_numeric_predictors(), threshold = 0.99) %>%
step_smote(shortcut)
xgboost_rec %>% prep() %>% juice() %>% glimpse()
## Rows: 2,170
## Columns: 47
## $ PC01 <dbl> -2.1876834, -2.2632955, -1.7254933, -2.4023462, -3.1400292, -…
## $ PC02 <dbl> -1.1228353, -0.8316566, -0.4627191, -0.1444717, 0.8537208, 0.…
## $ PC03 <dbl> 0.99070379, 1.50915458, 1.93730500, 1.23434359, 0.81721464, 0…
## $ PC04 <dbl> -0.85890483, -0.17524868, -0.12625116, -3.32966799, -7.313321…
## $ PC05 <dbl> -0.9582896, -1.5674133, -0.3093152, -0.1254826, -0.6817842, -…
## $ PC06 <dbl> 0.40670063, -0.05305086, 2.43873039, 1.57982381, -0.16291975,…
## $ PC07 <dbl> 0.04461873, -0.42909521, 0.26899858, 0.45666508, 0.04405820, …
## $ PC08 <dbl> -1.31164702, -1.13083206, -0.25865428, 1.49237857, -0.2140163…
## $ PC09 <dbl> -2.22063106, -2.24101416, -1.48067856, -1.10749287, -0.537362…
## $ PC10 <dbl> 0.74173212, 0.83119916, 1.78863514, 1.86710250, 1.22879117, 1…
## $ PC11 <dbl> -0.05408829, -0.65264261, 1.71480938, 1.74075925, -0.45536791…
## $ PC12 <dbl> 0.9362634, 1.0227750, 0.3441453, -0.1630744, -0.1022205, -0.2…
## $ PC13 <dbl> -2.18227093, -1.55127724, -1.27081629, 0.11387525, 0.66772981…
## $ PC14 <dbl> 1.10745431, 0.83516525, 0.21361742, -1.16512655, -0.94573513,…
## $ PC15 <dbl> 1.7392548, 1.5396184, 1.6394793, 1.1635840, 0.3097178, 0.3831…
## $ PC16 <dbl> -0.53502820, -0.58919581, 0.58116252, 0.23124119, 0.14635628,…
## $ PC17 <dbl> 0.13985762, 0.22464177, -0.32780242, -1.02288695, -0.51111810…
## $ PC18 <dbl> -0.75035981, -0.05902456, 0.59917932, 0.80777852, 0.61419764,…
## $ PC19 <dbl> -0.308926022, 0.054154753, -0.199837598, 0.001941767, 0.03711…
## $ PC20 <dbl> -0.357770562, -0.791436129, -0.292344520, 1.037868351, 0.5789…
## $ PC21 <dbl> 0.28657931, 0.52561099, -0.28787006, -0.71368816, 0.02296620,…
## $ PC22 <dbl> -0.38360126, -0.47325487, -0.86872933, -0.70097201, -0.866665…
## $ PC23 <dbl> -0.150292282, 0.162285698, 0.999151910, 0.660060137, 0.036811…
## $ PC24 <dbl> -2.6688950, -3.1269659, -2.5405070, -1.6981455, -1.2039885, -…
## $ PC25 <dbl> 0.12337295, 0.25980946, -0.47355352, -0.74930346, 0.16690881,…
## $ PC26 <dbl> 0.65915441, 0.73384086, -0.35464337, -0.73070971, -1.01481807…
## $ PC27 <dbl> -1.21169070, -1.11611755, 0.62668612, 0.13188801, -0.61470754…
## $ PC28 <dbl> 0.78833464, 1.07412913, 0.01453208, -0.19953186, -0.33355949,…
## $ PC29 <dbl> -0.13199778, 0.35005005, -0.90130412, -0.33798561, -0.4345172…
## $ PC30 <dbl> -0.03344156, -0.13793595, -0.21861782, -0.78335085, -0.508913…
## $ PC31 <dbl> 0.425832831, 0.714616792, -0.262568295, 0.711130667, 0.577108…
## $ PC32 <dbl> 0.019385678, -0.365893347, 0.456966532, 0.704027121, 0.003800…
## $ PC33 <dbl> -0.32431889, -1.03930151, -0.55551642, -0.14636430, 0.0808248…
## $ PC34 <dbl> 1.1330323, 1.8126224, 1.8361464, 2.6566832, 2.2696875, 2.0933…
## $ PC35 <dbl> -0.66776630, -0.55350719, 1.70570911, 1.56335007, 0.31922601,…
## $ PC36 <dbl> 0.08872137, 0.16330051, -0.05006357, 0.32931915, -0.13038917,…
## $ PC37 <dbl> -0.2188337, 0.3219134, -2.1732515, -1.6767641, -0.6626660, -0…
## $ PC38 <dbl> 0.42895341, 0.77459780, 0.38465808, 1.83091067, 1.10773444, 1…
## $ PC39 <dbl> 0.67887718, 0.18281037, 1.51702982, 1.63376954, 0.47640436, 0…
## $ PC40 <dbl> 0.31017534, 0.63385631, 0.27044521, 1.00438731, 0.43375564, 0…
## $ PC41 <dbl> 0.16652852, 0.32076690, -0.84838705, -0.41907785, 0.68460201,…
## $ PC42 <dbl> 0.6194183396, 0.7425598274, 0.1724111750, 0.4924469961, -1.20…
## $ PC43 <dbl> 0.22715094, 0.08053056, -0.31173955, 0.88147316, 0.84586574, …
## $ PC44 <dbl> -0.23344982, -0.20963089, 0.10379835, 4.02650095, 0.84915192,…
## $ PC45 <dbl> 0.11290837, 0.17186255, 0.27639909, -0.31403043, -0.01806689,…
## $ PC46 <dbl> -0.047702204, 0.579215235, 0.437543028, -0.226492145, -0.0062…
## $ shortcut <fct> No, No, No, No, No, No, No, No, No, No, No, No, No, No, No, N…
xgboost_spec <-
boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(),
learn_rate = tune(), loss_reduction = tune(), sample_size = tune()) %>%
set_mode("classification") %>%
set_engine("xgboost")
xgboost_workflow <-
workflow() %>%
add_recipe(xgboost_rec) %>%
add_model(xgboost_spec)
doParallel::registerDoParallel()
set.seed(65743)
# cant use tune grid because it throws errors
xgboost_tune <-
tune_grid(xgboost_workflow,
resamples = data_cv,
grid = 5,
control = control_grid(save_pred = TRUE))
collect_metrics(xgboost_tune)
## # A tibble: 15 × 12
## trees min_n tree_depth learn_rate loss_reduction sample_size .metric
## <int> <int> <int> <dbl> <dbl> <dbl> <chr>
## 1 1 30 15 0.0750 0.0422 0.625 accuracy
## 2 1 30 15 0.0750 0.0422 0.625 brier_class
## 3 1 30 15 0.0750 0.0422 0.625 roc_auc
## 4 500 21 1 0.316 0.0000562 1 accuracy
## 5 500 21 1 0.316 0.0000562 1 brier_class
## 6 500 21 1 0.316 0.0000562 1 roc_auc
## 7 1000 2 8 0.0178 0.0000000001 0.5 accuracy
## 8 1000 2 8 0.0178 0.0000000001 0.5 brier_class
## 9 1000 2 8 0.0178 0.0000000001 0.5 roc_auc
## 10 1500 11 4 0.001 31.6 0.75 accuracy
## 11 1500 11 4 0.001 31.6 0.75 brier_class
## 12 1500 11 4 0.001 31.6 0.75 roc_auc
## 13 2000 40 11 0.00422 0.0000000750 0.875 accuracy
## 14 2000 40 11 0.00422 0.0000000750 0.875 brier_class
## 15 2000 40 11 0.00422 0.0000000750 0.875 roc_auc
## # ℹ 5 more variables: .estimator <chr>, mean <dbl>, n <int>, std_err <dbl>,
## # .config <chr>
collect_predictions(xgboost_tune) %>%
group_by(id) %>%
roc_curve(shortcut, .pred_Yes) %>%
autoplot() +
labs(title = "ROC Curve: Shortcut Classification")
xgboost_last <- xgboost_workflow %>%
finalize_workflow(select_best(xgboost_tune, metric = "roc_auc")) %>%
last_fit(data_split)
collect_metrics(xgboost_last)
## # A tibble: 3 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 accuracy binary 0.659 pre0_mod0_post0
## 2 roc_auc binary 0.753 pre0_mod0_post0
## 3 brier_class binary 0.201 pre0_mod0_post0
collect_predictions(xgboost_last) %>%
yardstick::conf_mat(shortcut, .pred_class) %>%
autoplot(type = "heatmap") +
labs(title = "Confusion Matrix: Test Set Predictions")
xgboost_last %>%
workflows::extract_fit_engine() %>%
vip(num_features = 15) +
labs(title = "Top 15 Most Important Features")
# Extract the fitted workflow
final_model <- xgboost_last %>% extract_workflow()
# Predict class and probabilities on test data
test_predictions <- final_model %>%
predict(data_test) %>%
bind_cols(
final_model %>% predict(data_test, type = "prob"),
data_test
)
# View predictions
test_predictions %>%
select(shortcut, .pred_class, .pred_Yes, .pred_No, track, type, time) %>%
head(20)
## # A tibble: 20 × 7
## shortcut .pred_class .pred_Yes .pred_No track type time
## <fct> <fct> <dbl> <dbl> <fct> <fct> <dbl>
## 1 No No 0.369 0.631 Luigi Raceway Three Lap 133.
## 2 No No 0.433 0.567 Luigi Raceway Three Lap 130.
## 3 No No 0.344 0.656 Luigi Raceway Three Lap 125.
## 4 No No 0.214 0.786 Luigi Raceway Three Lap 123.
## 5 No No 0.324 0.676 Luigi Raceway Three Lap 121.
## 6 No No 0.389 0.611 Luigi Raceway Three Lap 120.
## 7 No Yes 0.503 0.497 Luigi Raceway Three Lap 120.
## 8 No No 0.297 0.703 Luigi Raceway Three Lap 120.
## 9 No No 0.473 0.527 Luigi Raceway Three Lap 120.
## 10 No No 0.337 0.663 Luigi Raceway Three Lap 120.
## 11 No No 0.198 0.802 Luigi Raceway Three Lap 120.
## 12 No No 0.347 0.653 Luigi Raceway Three Lap 119.
## 13 No No 0.260 0.740 Luigi Raceway Three Lap 119.
## 14 No No 0.0704 0.930 Luigi Raceway Three Lap 119.
## 15 No No 0.0846 0.915 Luigi Raceway Three Lap 119.
## 16 No No 0.279 0.721 Luigi Raceway Three Lap 118.
## 17 No No 0.244 0.756 Luigi Raceway Three Lap 118.
## 18 No No 0.288 0.712 Luigi Raceway Three Lap 118.
## 19 No No 0.303 0.697 Luigi Raceway Three Lap 118.
## 20 No No 0.0476 0.952 Luigi Raceway Three Lap 118.
# Summary of predictions vs actuals
test_predictions %>%
count(shortcut, .pred_class) %>%
mutate(correct = shortcut == .pred_class)
## # A tibble: 4 × 4
## shortcut .pred_class n correct
## <fct> <fct> <int> <lgl>
## 1 Yes Yes 141 TRUE
## 2 Yes No 81 FALSE
## 3 No Yes 118 FALSE
## 4 No No 244 TRUE
The previous Apply 7 model had an accuracy of 0.731 and an AUC of 0.83.
Feature transformation: Added Yeo-Johnson transformation, normalization, and PCA to the numeric predictors. Algorithm tuning: Implemented a tuning grid evaluating tree depth and number of trees.
These alterations resulted in 0.659 & 0.753 decreased