knitr::opts_chunk$set(echo = TRUE)
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.1.3 ✔ readr 2.1.4
## ✔ forcats 1.0.0 ✔ stringr 1.5.0
## ✔ ggplot2 3.5.0 ✔ tibble 3.2.1
## ✔ lubridate 1.9.2 ✔ tidyr 1.3.0
## ✔ purrr 1.0.2
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(tidytext)
library(correlationfunnel)
## ══ correlationfunnel Tip #1 ════════════════════════════════════════════════════
## Make sure your data is not overly imbalanced prior to using `correlate()`.
## If less than 5% imbalance, consider sampling. :)
library(textrecipes)
## Loading required package: recipes
##
## Attaching package: 'recipes'
##
## The following object is masked from 'package:stringr':
##
## fixed
##
## The following object is masked from 'package:stats':
##
## step
library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.1.1 ──
## ✔ broom 1.0.5 ✔ rsample 1.2.0
## ✔ dials 1.2.0 ✔ tune 1.1.2
## ✔ infer 1.0.6 ✔ workflows 1.1.3
## ✔ modeldata 1.3.0 ✔ workflowsets 1.0.1
## ✔ parsnip 1.1.1 ✔ yardstick 1.3.0
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ recipes::fixed() masks stringr::fixed()
## ✖ dplyr::lag() masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step() masks stats::step()
## • Use tidymodels_prefer() to resolve common conflicts.
library(xgboost)
##
## Attaching package: 'xgboost'
##
## The following object is masked from 'package:dplyr':
##
## slice
The goal is to predict which members of Himalayan expeditions died, based on characteristics of the climber and the expedition.
# Import data
members <- readr::read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2020/2020-09-22/members.csv')
## Rows: 76519 Columns: 21
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (10): expedition_id, member_id, peak_id, peak_name, season, sex, citizen...
## dbl (5): year, age, highpoint_metres, death_height_metres, injury_height_me...
## lgl (6): hired, success, solo, oxygen_used, died, injured
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
skimr::skim(members)
| Name | members |
|---|---|
| Number of rows | 76519 |
| Number of columns | 21 |
| _______________________ | |
| Column type frequency: | |
| character | 10 |
| logical | 6 |
| numeric | 5 |
| ________________________ | |
| Group variables | None |
Variable type: character
| skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
|---|---|---|---|---|---|---|---|
| expedition_id | 0 | 1.00 | 9 | 9 | 0 | 10350 | 0 |
| member_id | 0 | 1.00 | 12 | 12 | 0 | 76518 | 0 |
| peak_id | 0 | 1.00 | 4 | 4 | 0 | 391 | 0 |
| peak_name | 15 | 1.00 | 4 | 25 | 0 | 390 | 0 |
| season | 0 | 1.00 | 6 | 7 | 0 | 5 | 0 |
| sex | 2 | 1.00 | 1 | 1 | 0 | 2 | 0 |
| citizenship | 10 | 1.00 | 2 | 23 | 0 | 212 | 0 |
| expedition_role | 21 | 1.00 | 4 | 25 | 0 | 524 | 0 |
| death_cause | 75413 | 0.01 | 3 | 27 | 0 | 12 | 0 |
| injury_type | 74807 | 0.02 | 3 | 27 | 0 | 11 | 0 |
Variable type: logical
| skim_variable | n_missing | complete_rate | mean | count |
|---|---|---|---|---|
| hired | 0 | 1 | 0.21 | FAL: 60788, TRU: 15731 |
| success | 0 | 1 | 0.38 | FAL: 47320, TRU: 29199 |
| solo | 0 | 1 | 0.00 | FAL: 76398, TRU: 121 |
| oxygen_used | 0 | 1 | 0.24 | FAL: 58286, TRU: 18233 |
| died | 0 | 1 | 0.01 | FAL: 75413, TRU: 1106 |
| injured | 0 | 1 | 0.02 | FAL: 74806, TRU: 1713 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| year | 0 | 1.00 | 2000.36 | 14.78 | 1905 | 1991 | 2004 | 2012 | 2019 | ▁▁▁▃▇ |
| age | 3497 | 0.95 | 37.33 | 10.40 | 7 | 29 | 36 | 44 | 85 | ▁▇▅▁▁ |
| highpoint_metres | 21833 | 0.71 | 7470.68 | 1040.06 | 3800 | 6700 | 7400 | 8400 | 8850 | ▁▁▆▃▇ |
| death_height_metres | 75451 | 0.01 | 6592.85 | 1308.19 | 400 | 5800 | 6600 | 7550 | 8830 | ▁▁▂▇▆ |
| injury_height_metres | 75510 | 0.01 | 7049.91 | 1214.24 | 400 | 6200 | 7100 | 8000 | 8880 | ▁▁▂▇▇ |
data_clean <- members %>%
  # Drop mostly-missing columns (the death/injury columns would also leak the outcome)
  select(-death_cause, -injury_type, -highpoint_metres, -death_height_metres, -injury_height_metres, -peak_id) %>%
  # Drop the remaining rows with missing values
  na.omit() %>%
  # Convert logical variables to factors for modeling
  mutate(across(where(is.logical), as.factor))
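A quick sanity check on the cleaned data (a sketch): confirm that no missing values remain and see how rare deaths are, since only about 1% of members died.
# Any NAs left? (should be all zeros after na.omit())
data_clean %>% summarise(across(everything(), ~ sum(is.na(.x))))
# Class balance of the outcome
data_clean %>% count(died) %>% mutate(prop = n / sum(n))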
# Step 1: Prepare data. binarize() one-hot encodes every column (numeric columns are binned first)
data_binarized_tbl <- data_clean %>%
select(-peak_name, -expedition_id) %>%
binarize()
data_binarized_tbl %>% glimpse()
## Rows: 72,985
## Columns: 55
## $ `member_id__KANG10101-01` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `member_id__-OTHER` <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ `year__-Inf_1992` <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ year__1992_2004 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ year__2004_2012 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ year__2012_Inf <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ season__Autumn <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, …
## $ season__Spring <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, …
## $ season__Winter <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `season__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ sex__F <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ sex__M <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ `age__-Inf_29` <dbl> 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, …
## $ age__29_36 <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ age__36_44 <dbl> 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, …
## $ age__44_Inf <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Australia <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Austria <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Canada <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__China <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__France <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, …
## $ citizenship__Germany <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__India <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Italy <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Japan <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Nepal <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Netherlands <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__New_Zealand <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Poland <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Russia <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__S_Korea <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Spain <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Switzerland <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__UK <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__USA <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, …
## $ citizenship__W_Germany <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, …
## $ `citizenship__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ expedition_role__Climber <dbl> 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, …
## $ expedition_role__Deputy_Leader <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ expedition_role__Exp_Doctor <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `expedition_role__H-A_Worker` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ expedition_role__Leader <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `expedition_role__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, …
## $ hired__FALSE <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ hired__TRUE <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ success__FALSE <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, …
## $ success__TRUE <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, …
## $ solo__FALSE <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ `solo__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ oxygen_used__FALSE <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ oxygen_used__TRUE <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ died__FALSE <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ died__TRUE <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ injured__FALSE <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ injured__TRUE <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
# Step 2: Correlate every binary feature with the target (died == TRUE)
data_corr_tbl <- data_binarized_tbl %>%
correlate(died__TRUE)
## Warning: correlate(): [Data Imbalance Detected] Consider sampling to balance the classes more than 5%
## Column with imbalance: died__TRUE
data_corr_tbl
## # A tibble: 55 × 3
## feature bin correlation
## <fct> <chr> <dbl>
## 1 died FALSE -1
## 2 died TRUE 1
## 3 year -Inf_1992 0.0519
## 4 success FALSE 0.0332
## 5 success TRUE -0.0332
## 6 year 2004_2012 -0.0211
## 7 year 2012_Inf -0.0208
## 8 sex F -0.0168
## 9 sex M 0.0168
## 10 citizenship USA -0.0154
## # ℹ 45 more rows
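All of the correlations with `died` are weak; pre-1992 expedition years and lack of summit success sit at the top. To list the strongest features while dropping the target's own perfectly-correlated rows (a small sketch on the tibble above):
data_corr_tbl %>%
  filter(feature != "died") %>%          # drop the target's own rows (correlation ±1)
  arrange(desc(abs(correlation))) %>%
  slice_head(n = 10)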
# Step 3: Plot
data_corr_tbl %>%
plot_correlation_funnel()
## Warning: ggrepel: 27 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps
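The ggrepel warning only means some funnel labels were dropped to avoid overplotting. If every label is wanted, one option (ggrepel reads this documented global setting) is:
# allow unlimited label overlaps, then redraw the funnel
options(ggrepel.max.overlaps = Inf)
data_corr_tbl %>% plot_correlation_funnel()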
# Split Data
set.seed(7186)
# Draw a small balanced subsample (50 deaths, 50 non-deaths) to keep this example fast
data_clean <- data_clean %>% group_by(died) %>% sample_n(50) %>% ungroup()
data_split <- initial_split(data_clean, strata = died)
data_train <- training(data_split)
data_test <- testing(data_split)
data_cv <- rsample::vfold_cv(data_train, strata = died)
data_cv
## # 10-fold cross-validation using stratification
## # A tibble: 10 × 2
## splits id
## <list> <chr>
## 1 <split [66/8]> Fold01
## 2 <split [66/8]> Fold02
## 3 <split [66/8]> Fold03
## 4 <split [66/8]> Fold04
## 5 <split [66/8]> Fold05
## 6 <split [66/8]> Fold06
## 7 <split [66/8]> Fold07
## 8 <split [68/6]> Fold08
## 9 <split [68/6]> Fold09
## 10 <split [68/6]> Fold10
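Because the folds are stratified on `died`, each analysis set should keep the 50/50 balance of the subsample roughly intact. A quick check on the first fold (rsample's `analysis()` returns the in-fold training rows):
data_cv$splits[[1]] %>% analysis() %>% count(died)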
# Preprocess Data
# themis provides class-imbalance recipe steps (e.g. step_smote());
# it is not needed below because the subsample is already balanced
library(themis)
xgboost_rec <- recipes::recipe(died ~ ., data = data_train) %>%
  # keep member_id for identification only, not as a predictor
  update_role(member_id, new_role = "ID") %>%
  # pool infrequent expedition IDs into an "other" level
  step_other(expedition_id, threshold = 0.1) %>%
  # dummy-encode the remaining nominal predictors
  step_dummy(all_nominal_predictors()) %>%
  # drop zero-variance columns
  step_zv(all_predictors()) %>%
  # center and scale before PCA
  step_normalize(all_numeric_predictors()) %>%
  # keep enough principal components to explain 75% of the variance
  step_pca(all_numeric_predictors(), threshold = .75)
xgboost_rec %>% prep() %>% juice() %>% glimpse()
## Rows: 74
## Columns: 29
## $ member_id <fct> EVER03132-03, EVER19104-19, LHOT19118-01, AMAD79101-13, EVER…
## $ died <fct> FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALS…
## $ PC01 <dbl> -0.4844176, 3.8545913, 1.2898016, -0.9539217, 3.7413495, 0.1…
## $ PC02 <dbl> -1.06008342, -0.43784435, -1.16830405, -0.93907035, -0.46279…
## $ PC03 <dbl> -1.068950544, -0.218785713, 2.656601957, -0.131892610, -0.23…
## $ PC04 <dbl> -0.108661629, 0.217849261, 1.514956068, -0.683087093, 0.2187…
## $ PC05 <dbl> 0.7822710, 0.2234093, -0.2134190, -0.9883744, 0.1595970, -1.…
## $ PC06 <dbl> -1.74201534, -0.22088791, 0.14234296, -0.11199690, -0.157202…
## $ PC07 <dbl> 0.44227671, 0.29370377, 0.08602802, 0.12958151, 0.25078254, …
## $ PC08 <dbl> 0.57417745, 0.30518878, -0.19441605, 0.17364818, 0.29325385,…
## $ PC09 <dbl> -0.497897170, -0.177363446, 0.658362772, 0.121338919, -0.184…
## $ PC10 <dbl> -0.831899216, 0.226221226, 1.438653841, -0.269494973, 0.2817…
## $ PC11 <dbl> 0.431205175, 0.064675671, -1.443053546, -0.472443169, -0.018…
## $ PC12 <dbl> 0.71470782, -0.06794343, -0.46696363, 0.40595462, -0.1060530…
## $ PC13 <dbl> -0.18670097, -0.14648759, 1.30412919, 1.19669572, -0.2788706…
## $ PC14 <dbl> 0.002036603, 0.265949912, -0.371784569, 0.380997718, 0.10496…
## $ PC15 <dbl> 2.00302267, -0.28060510, 0.63035748, -0.27431195, -0.3058309…
## $ PC16 <dbl> 2.0184087, -0.3963151, -0.4961841, 0.4663700, -0.3626071, -0…
## $ PC17 <dbl> -0.40626968, 0.16374806, 0.37858105, 1.00347943, 0.14436802,…
## $ PC18 <dbl> 0.388037545, -0.111247651, -0.316026724, 0.255353425, -0.202…
## $ PC19 <dbl> 0.30802974, 0.04250331, -1.17038843, 0.43781767, 0.11368396,…
## $ PC20 <dbl> -0.82102883, 0.33720337, 0.72109824, -0.10940887, 0.19412460…
## $ PC21 <dbl> 0.83005214, -0.49187448, 0.40383070, -0.59797616, -0.5132374…
## $ PC22 <dbl> 1.251336619, -0.023967708, -0.159738938, 0.229268022, 0.0107…
## $ PC23 <dbl> -0.59632260, -0.07286510, -0.03904975, 0.11119527, -0.022888…
## $ PC24 <dbl> -1.94294456, -0.07985174, -0.33629278, -0.35346526, -0.05820…
## $ PC25 <dbl> -2.443647024, -0.064748761, -0.140813740, -0.033287090, -0.0…
## $ PC26 <dbl> -1.315648371, -0.012674957, 0.041019127, 0.152804624, 0.0014…
## $ PC27 <dbl> -1.4217756858, -0.0842914188, -0.0576309457, 0.0156340184, -…
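With `threshold = .75`, `step_pca()` retains as many components as needed to explain 75% of the variance, 27 here. To inspect the variance captured per component, recipes can tidy the PCA step (it is step number 5 in this recipe; `type = "variance"` is an option of `tidy()` for `step_pca()`):
xgboost_rec %>% prep() %>% tidy(number = 5, type = "variance")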
# Specify Model
library(usemodels)
usemodels::use_xgboost(died ~ ., data = data_train)
## xgboost_recipe <-
## recipe(formula = died ~ ., data = data_train) %>%
## step_zv(all_predictors())
##
## xgboost_spec <-
## boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(),
## loss_reduction = tune(), sample_size = tune()) %>%
## set_mode("classification") %>%
## set_engine("xgboost")
##
## xgboost_workflow <-
## workflow() %>%
## add_recipe(xgboost_recipe) %>%
## add_model(xgboost_spec)
##
## set.seed(86616)
## xgboost_tune <-
## tune_grid(xgboost_workflow, resamples = stop("add your rsample object"), grid = stop("add number of candidate points"))
xgboost_spec <-
  boost_tree(trees = tune()) %>%  # tune only the number of trees; other hyperparameters stay at their defaults
  set_mode("classification") %>%
  set_engine("xgboost")
xgboost_workflow <-
workflow() %>%
add_recipe(xgboost_rec) %>%
add_model(xgboost_spec)
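Before tuning, it is worth confirming which parameters are actually marked with `tune()`. One way (assuming the extractor re-exported by tune is available, as it is after loading tidymodels):
xgboost_workflow %>% extract_parameter_set_dials()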
# Tune Hyperparameters
doParallel::registerDoParallel()  # tune_grid() picks up the registered parallel backend automatically
set.seed(24817)
xgboost_tune <-
tune_grid(xgboost_workflow,
resamples = data_cv,
grid = 5,
control = control_grid(save_pred = TRUE))
# Evaluate
tune::show_best(xgboost_tune, metric = "roc_auc")
## # A tibble: 5 × 7
## trees .metric .estimator mean n std_err .config
## <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 232 roc_auc binary 0.776 10 0.0720 Preprocessor1_Model1
## 2 784 roc_auc binary 0.776 10 0.0720 Preprocessor1_Model2
## 3 1191 roc_auc binary 0.776 10 0.0720 Preprocessor1_Model3
## 4 1209 roc_auc binary 0.776 10 0.0720 Preprocessor1_Model4
## 5 1975 roc_auc binary 0.776 10 0.0720 Preprocessor1_Model5
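The mean ROC AUC is essentially identical across the five candidate tree counts, suggesting the number of trees matters little on this tiny sample. tune's autoplot method makes that easy to see:
autoplot(xgboost_tune)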
# Update the model by selecting the best hyper-parameters
xgboost_fw <- tune::finalize_workflow(xgboost_workflow,
tune::select_best(xgboost_tune, metric = "roc_auc"))
# Fit the model on the entire training data and test it on the test data
data_fit <- tune::last_fit(xgboost_fw, data_split)
## → A | warning: There are new levels in a factor: Gurkarpo Ri, Nemjung, Tawoche, Annapurna South, Tashi Kang, Lhotse Shar, There are new levels in a factor: China, S Korea, Austria, There are new levels in a factor: BC Manager, Leader (Tibetan Staff)
##
## There were issues with some computations   A: x1
tune::collect_metrics(data_fit)
## # A tibble: 2 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 accuracy binary 0.5 Preprocessor1_Model1
## 2 roc_auc binary 0.503 Preprocessor1_Model1
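On a held-out set of only 26 rows (a quarter of the 100-row subsample), an accuracy of 0.5 and an ROC AUC near 0.5 are chance level, so these estimates are noisy at best. A confusion matrix shows where the errors fall (a sketch with yardstick, using the hard-class predictions in `.pred_class`):
tune::collect_predictions(data_fit) %>%
  yardstick::conf_mat(truth = died, estimate = .pred_class)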
tune::collect_predictions(data_fit) %>%
ggplot(aes(died, .pred_TRUE)) +
geom_point(alpha = 0.3, colour = "midnightblue") +  # colour, not fill: the default point shape ignores fill
geom_abline(lty = 2, color = "gray50") +
coord_fixed()
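For a threshold-free view, the test-set ROC curve can be drawn from the same predictions (a sketch; `event_level = "second"` is needed because `TRUE` is the second level of the `died` factor):
tune::collect_predictions(data_fit) %>%
  yardstick::roc_curve(died, .pred_TRUE, event_level = "second") %>%
  autoplot()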