The goal is to build a classification model that predicts whether an expedition member died.
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.1.4 ✔ readr 2.1.5
## ✔ forcats 1.0.0 ✔ stringr 1.5.1
## ✔ ggplot2 3.5.1 ✔ tibble 3.2.1
## ✔ lubridate 1.9.3 ✔ tidyr 1.3.1
## ✔ purrr 1.0.2
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(readr)
library(correlationfunnel)
## ══ Using correlationfunnel? ════════════════════════════════════════════════════
## You might also be interested in applied data science training for business.
## </> Learn more at - www.business-science.io </>
# Expedition member records from the TidyTuesday 2020-09-22 data set.
data1 <- readr::read_csv(
  file = "https://raw.githubusercontent.com/rfordatascience/tidytuesday/main/data/2020/2020-09-22/members.csv"
)
## Rows: 76519 Columns: 21
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (10): expedition_id, member_id, peak_id, peak_name, season, sex, citizen...
## dbl (5): year, age, highpoint_metres, death_height_metres, injury_height_me...
## lgl (6): hired, success, solo, oxygen_used, died, injured
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Summarise the type, missingness and distribution of every column.
data1 %>%
  skimr::skim()
Name | data1 |
Number of rows | 76519 |
Number of columns | 21 |
_______________________ | |
Column type frequency: | |
character | 10 |
logical | 6 |
numeric | 5 |
________________________ | |
Group variables | None |
Variable type: character
skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
---|---|---|---|---|---|---|---|
expedition_id | 0 | 1.00 | 9 | 9 | 0 | 10350 | 0 |
member_id | 0 | 1.00 | 12 | 12 | 0 | 76518 | 0 |
peak_id | 0 | 1.00 | 4 | 4 | 0 | 391 | 0 |
peak_name | 15 | 1.00 | 4 | 25 | 0 | 390 | 0 |
season | 0 | 1.00 | 6 | 7 | 0 | 5 | 0 |
sex | 2 | 1.00 | 1 | 1 | 0 | 2 | 0 |
citizenship | 10 | 1.00 | 2 | 23 | 0 | 212 | 0 |
expedition_role | 21 | 1.00 | 4 | 25 | 0 | 524 | 0 |
death_cause | 75413 | 0.01 | 3 | 27 | 0 | 12 | 0 |
injury_type | 74807 | 0.02 | 3 | 27 | 0 | 11 | 0 |
Variable type: logical
skim_variable | n_missing | complete_rate | mean | count |
---|---|---|---|---|
hired | 0 | 1 | 0.21 | FAL: 60788, TRU: 15731 |
success | 0 | 1 | 0.38 | FAL: 47320, TRU: 29199 |
solo | 0 | 1 | 0.00 | FAL: 76398, TRU: 121 |
oxygen_used | 0 | 1 | 0.24 | FAL: 58286, TRU: 18233 |
died | 0 | 1 | 0.01 | FAL: 75413, TRU: 1106 |
injured | 0 | 1 | 0.02 | FAL: 74806, TRU: 1713 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
year | 0 | 1.00 | 2000.36 | 14.78 | 1905 | 1991 | 2004 | 2012 | 2019 | ▁▁▁▃▇ |
age | 3497 | 0.95 | 37.33 | 10.40 | 7 | 29 | 36 | 44 | 85 | ▁▇▅▁▁ |
highpoint_metres | 21833 | 0.71 | 7470.68 | 1040.06 | 3800 | 6700 | 7400 | 8400 | 8850 | ▁▁▆▃▇ |
death_height_metres | 75451 | 0.01 | 6592.85 | 1308.19 | 400 | 5800 | 6600 | 7550 | 8830 | ▁▁▂▇▆ |
injury_height_metres | 75510 | 0.01 | 7049.91 | 1214.24 | 400 | 6200 | 7100 | 8000 | 8880 | ▁▁▂▇▇ |
Issues with the data: - Missing values: peak_name, sex, citizenship, expedition_role, death_cause, injury_type, age, highpoint_metres, death_height_metres, injury_height_metres - Categorical variables stored as numbers: none - Zero-variance variables: none - Character variables: convert them to numbers in a recipe step (step_dummy) - Unbalanced target variable: died (1,106 of 76,519 members died) - ID variable: member_id
# Build the modelling data set from the raw members table.
data_clean1 <- data1 %>%
  # drop columns that are mostly missing (highpoint_metres is ~71% complete,
  # the others only ~1-2%); death_cause and death_height_metres also describe
  # the outcome itself, so keeping them would leak it into the model
  select(-c(death_cause, injury_type, highpoint_metres,
            death_height_metres, injury_height_metres)) %>%
  # peak_id carries essentially the same information as peak_name
  select(-peak_id) %>%
  # remove the remaining rows with NA (age, sex, citizenship, ...)
  na.omit() %>%
  # keep a single row per member_id
  distinct(member_id, .keep_all = TRUE) %>%
  # encode the logical outcome as a factor with "died" as the FIRST level:
  # yardstick treats the first level as the event by default, so metrics and
  # roc_curve(died, .pred_died) rely on this explicit ordering
  mutate(
    died = factor(if_else(died, "died", "no"), levels = c("died", "no"))
  ) %>%
  # convert the remaining character and logical columns to factors
  mutate(across(where(is.character), as.factor)) %>%
  mutate(across(where(is.logical), as.factor))

data_clean1 %>% count(died)
## # A tibble: 2 × 2
## died n
## <chr> <int>
## 1 died 929
## 2 no 72055
# Class balance of the outcome.
ggplot(data_clean1, aes(x = died)) +
  geom_bar()
Died vs. age
# Age distribution for each outcome class.
ggplot(data_clean1, aes(x = died, y = age)) +
  geom_boxplot()
Died vs. year
# Expedition year distribution for each outcome class.
ggplot(data_clean1, aes(x = died, y = year)) +
  geom_boxplot()
Correlation plot
# step 1: binarize -- turn every column (except the member_id identifier)
# into 0/1 indicator features, binning the numeric ones
data_binarized1 <- data_clean1 %>%
  select(!member_id) %>%
  binarize()

glimpse(data_binarized1)
## Rows: 72,984
## Columns: 69
## $ expedition_id__EVER88101 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `expedition_id__-OTHER` <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ peak_name__Ama_Dablam <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ peak_name__Annapurna_I <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ peak_name__Annapurna_IV <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ peak_name__Baruntse <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ peak_name__Cho_Oyu <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ peak_name__Dhaulagiri_I <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ peak_name__Everest <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ peak_name__Himlung_Himal <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ peak_name__Kangchenjunga <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ peak_name__Lhotse <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ peak_name__Makalu <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ peak_name__Manaslu <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ peak_name__Pumori <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `peak_name__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `year__-Inf_1992` <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ year__1992_2004 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ year__2004_2012 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ year__2012_Inf <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ season__Autumn <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, …
## $ season__Spring <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, …
## $ season__Winter <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `season__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ sex__F <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ sex__M <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ `age__-Inf_29` <dbl> 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, …
## $ age__29_36 <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ age__36_44 <dbl> 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, …
## $ age__44_Inf <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Australia <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Austria <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Canada <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__China <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__France <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, …
## $ citizenship__Germany <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__India <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Italy <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Japan <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Nepal <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Netherlands <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__New_Zealand <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Poland <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Russia <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__S_Korea <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Spain <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Switzerland <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__UK <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__USA <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, …
## $ citizenship__W_Germany <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, …
## $ `citizenship__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ expedition_role__Climber <dbl> 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, …
## $ expedition_role__Deputy_Leader <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ expedition_role__Exp_Doctor <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `expedition_role__H-A_Worker` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ expedition_role__Leader <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `expedition_role__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, …
## $ hired__FALSE <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ hired__TRUE <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ success__FALSE <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, …
## $ success__TRUE <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, …
## $ solo__FALSE <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ `solo__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ oxygen_used__FALSE <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ oxygen_used__TRUE <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ died__died <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ died__no <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ injured__FALSE <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ injured__TRUE <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
# step 2: correlate every binary feature with the died__died indicator
data_correlation1 <- correlate(data_binarized1, target = died__died)
## Warning: correlate(): [Data Imbalance Detected] Consider sampling to balance the classes more than 5%
## Column with imbalance: died__died
print(data_correlation1)
## # A tibble: 69 × 3
## feature bin correlation
## <fct> <chr> <dbl>
## 1 died died 1
## 2 died no -1
## 3 year -Inf_1992 0.0519
## 4 peak_name Annapurna_I 0.0336
## 5 success FALSE 0.0332
## 6 success TRUE -0.0332
## 7 peak_name Dhaulagiri_I 0.0290
## 8 peak_name Ama_Dablam -0.0281
## 9 peak_name Cho_Oyu -0.0241
## 10 year 2004_2012 -0.0211
## # ℹ 59 more rows
# step 3: visualise the feature/target correlations as a funnel
correlationfunnel::plot_correlation_funnel(data = data_correlation1)
## Warning: ggrepel: 41 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps
library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
## ✔ broom 1.0.6 ✔ rsample 1.2.1
## ✔ dials 1.3.0 ✔ tune 1.2.1
## ✔ infer 1.0.7 ✔ workflows 1.1.4
## ✔ modeldata 1.4.0 ✔ workflowsets 1.1.0
## ✔ parsnip 1.2.1 ✔ yardstick 1.3.2
## ✔ recipes 1.1.0
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ recipes::fixed() masks stringr::fixed()
## ✖ dplyr::lag() masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step() masks stats::step()
## • Search for functions across packages at https://www.tidymodels.org/find/
set.seed(1236)

# Optional: downsample to 100 members per class for quick experiments.
# data_clean1 <- data_clean1 %>%
#   group_by(died) %>%
#   sample_n(100)

# Train/test split (rsample default 3/4 for training), stratified on the
# imbalanced outcome so both sets contain deaths.
data_split1 <- initial_split(data_clean1, strata = died)
data_train1 <- training(data_split1)
data_test1 <- testing(data_split1)

# 10-fold cross-validation on the training set, also stratified on died.
data_cv1 <- rsample::vfold_cv(data_train1, v = 10, strata = died)
data_cv1
## # 10-fold cross-validation using stratification
## # A tibble: 10 × 2
## splits id
## <list> <chr>
## 1 <split [49264/5474]> Fold01
## 2 <split [49264/5474]> Fold02
## 3 <split [49264/5474]> Fold03
## 4 <split [49264/5474]> Fold04
## 5 <split [49264/5474]> Fold05
## 6 <split [49264/5474]> Fold06
## 7 <split [49264/5474]> Fold07
## 8 <split [49264/5474]> Fold08
## 9 <split [49265/5473]> Fold09
## 10 <split [49265/5473]> Fold10
library(themis)
# Preprocessing recipe for the xgboost model.
# - member_id and expedition_id are identifiers, not predictors. expedition_id
#   has ~10k distinct values; as a predictor, step_other() collapsed it to a
#   single "expedition_id_other" dummy, and it lets the model key on specific
#   expeditions. Both columns stay in the data but get the "ID" role, so they
#   are not used for modelling.
# - step_other() pools levels seen in < 2% of training rows.
# - step_dummy() makes all nominal predictors numeric, as step_smote() needs.
# - step_smote() synthesises minority-class ("died") rows to balance the
#   outcome in the training data.
xgboost_rec1 <- recipes::recipe(died ~ ., data = data_train1) %>%
  update_role(member_id, expedition_id, new_role = "ID") %>%
  step_other(peak_name, expedition_role, citizenship, threshold = .02) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_smote(died)

# Inspect the processed training set; bake(new_data = NULL) supersedes juice().
xgboost_rec1 %>%
  prep() %>%
  bake(new_data = NULL) %>%
  glimpse()
## Rows: 108,094
## Columns: 40
## $ member_id <fct> AMAD78301-01, AMAD78301-02, AMAD78301-03, A…
## $ year <dbl> 1978, 1978, 1978, 1978, 1978, 1978, 1979, 1…
## $ age <dbl> 40, 41, 27, 40, 34, 29, 35, 37, 44, 28, 32,…
## $ died <fct> no, no, no, no, no, no, no, no, no, no, no,…
## $ expedition_id_other <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…
## $ peak_name_Annapurna.I <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ peak_name_Baruntse <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ peak_name_Cho.Oyu <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ peak_name_Dhaulagiri.I <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ peak_name_Everest <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ peak_name_Lhotse <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ peak_name_Makalu <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ peak_name_Manaslu <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ peak_name_Pumori <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ peak_name_other <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ season_Spring <dbl> 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1…
## $ season_Summer <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ season_Winter <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ sex_M <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…
## $ citizenship_China <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ citizenship_France <dbl> 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ citizenship_Germany <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ citizenship_India <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ citizenship_Italy <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ citizenship_Japan <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ citizenship_Nepal <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ citizenship_S.Korea <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ citizenship_Spain <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ citizenship_Switzerland <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ citizenship_UK <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ citizenship_USA <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1…
## $ citizenship_other <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0…
## $ expedition_role_H.A.Worker <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ expedition_role_Leader <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ expedition_role_other <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1…
## $ hired_TRUE. <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ success_TRUE. <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1…
## $ solo_TRUE. <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ oxygen_used_TRUE. <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ injured_TRUE. <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
# Boosted-tree classifier using the xgboost engine; only `trees` is tuned.
xgboost_spec <- boost_tree(trees = tune()) %>%
  set_engine("xgboost") %>%
  set_mode("classification")
# Bundle the model specification and the preprocessing recipe.
xgboost_workflow <- workflow() %>%
  add_model(xgboost_spec) %>%
  add_recipe(xgboost_rec1)
# Register a parallel backend for tune_grid().
doParallel::registerDoParallel()

# Evaluate 5 candidate values of `trees` on every CV fold, keeping the
# out-of-fold predictions so ROC curves can be drawn afterwards.
set.seed(10891)
xgboost_tune <- xgboost_workflow %>%
  tune_grid(
    resamples = data_cv1,
    grid = 5,
    control = control_grid(save_pred = TRUE)
  )

xgboost_tune %>% collect_metrics()
## # A tibble: 15 × 7
## trees .metric .estimator mean n std_err .config
## <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 253 accuracy binary 0.987 10 0.000477 Preprocessor1_Model1
## 2 253 brier_class binary 0.0126 10 0.000436 Preprocessor1_Model1
## 3 253 roc_auc binary 0.753 10 0.00692 Preprocessor1_Model1
## 4 604 accuracy binary 0.986 10 0.000574 Preprocessor1_Model2
## 5 604 brier_class binary 0.0133 10 0.000470 Preprocessor1_Model2
## 6 604 roc_auc binary 0.731 10 0.00805 Preprocessor1_Model2
## 7 1191 accuracy binary 0.985 10 0.000634 Preprocessor1_Model3
## 8 1191 brier_class binary 0.0140 10 0.000495 Preprocessor1_Model3
## 9 1191 roc_auc binary 0.721 10 0.00830 Preprocessor1_Model3
## 10 1522 accuracy binary 0.984 10 0.000629 Preprocessor1_Model4
## 11 1522 brier_class binary 0.0143 10 0.000500 Preprocessor1_Model4
## 12 1522 roc_auc binary 0.718 10 0.00840 Preprocessor1_Model4
## 13 1734 accuracy binary 0.984 10 0.000611 Preprocessor1_Model5
## 14 1734 brier_class binary 0.0144 10 0.000498 Preprocessor1_Model5
## 15 1734 roc_auc binary 0.718 10 0.00851 Preprocessor1_Model5
# One ROC curve per CV fold, computed from the saved out-of-fold predictions.
xgboost_tune %>%
  collect_predictions() %>%
  group_by(id) %>%
  roc_curve(truth = died, .pred_died) %>%
  autoplot()
# Pick the best tuning candidate, refit it on the full training set and
# evaluate it once on the held-out test set.
# Selection uses ROC AUC rather than accuracy: only ~1.3% of members died
# (929 of 72,984 after cleaning), so predicting "no" for everyone already
# scores ~0.987 accuracy, above every tuned candidate. Accuracy therefore
# does not measure how well the model finds deaths.
xgboost_best <- select_best(xgboost_tune, metric = "roc_auc")

xgboost_last <- xgboost_workflow %>%
  finalize_workflow(xgboost_best) %>%
  last_fit(data_split1)

collect_metrics(xgboost_last)
## # A tibble: 3 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 accuracy binary 0.986 Preprocessor1_Model1
## 2 roc_auc binary 0.747 Preprocessor1_Model1
## 3 brier_class binary 0.0129 Preprocessor1_Model1
# Confusion matrix of the test-set predictions.
xgboost_last %>%
  collect_predictions() %>%
  yardstick::conf_mat(truth = died, estimate = .pred_class) %>%
  autoplot()
library(vip)
##
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
##
## vi
# Variable-importance plot for the underlying xgboost fit.
vip(workflows::extract_fit_engine(xgboost_last))
The baseline model had an accuracy of 0.986 and a ROC AUC of 0.747. Only 929 of 72,984 members died, so always predicting "no" would already give an accuracy of about 0.987. ROC AUC is therefore the more informative of the two numbers.
Feature transformation: normalizing the numeric predictors left accuracy unchanged and slightly decreased the AUC (0.743).
Feature transformation: a Yeo-Johnson transformation changed neither accuracy nor AUC.
Feature selection: PCA did not improve the model.
Hyperparameter tuning: adding a grid to tune the hyperparameters gave the same accuracy but a lower AUC (0.741).