# Echo the source of every chunk in the knitted document.
knitr::opts_chunk$set(list(echo = TRUE))
library(tidyverse)
## Warning: package 'ggplot2' was built under R version 4.3.2
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.1.3 ✔ readr 2.1.4
## ✔ forcats 1.0.0 ✔ stringr 1.5.0
## ✔ ggplot2 3.5.0 ✔ tibble 3.2.1
## ✔ lubridate 1.9.2 ✔ tidyr 1.3.0
## ✔ purrr 1.0.2
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(tidytext)
## Warning: package 'tidytext' was built under R version 4.3.2
library(correlationfunnel)
## Warning: package 'correlationfunnel' was built under R version 4.3.2
## ══ Using correlationfunnel? ════════════════════════════════════════════════════
## You might also be interested in applied data science training for business.
## </> Learn more at - www.business-science.io </>
library(textrecipes)
## Warning: package 'textrecipes' was built under R version 4.3.2
## Loading required package: recipes
##
## Attaching package: 'recipes'
##
## The following object is masked from 'package:stringr':
##
## fixed
##
## The following object is masked from 'package:stats':
##
## step
library(tidymodels)
## Warning: package 'tidymodels' was built under R version 4.3.2
## ── Attaching packages ────────────────────────────────────── tidymodels 1.1.1 ──
## ✔ broom 1.0.5 ✔ rsample 1.2.0
## ✔ dials 1.2.0 ✔ tune 1.1.2
## ✔ infer 1.0.6 ✔ workflows 1.1.3
## ✔ modeldata 1.3.0 ✔ workflowsets 1.0.1
## ✔ parsnip 1.1.1 ✔ yardstick 1.3.0
## Warning: package 'dials' was built under R version 4.3.2
## Warning: package 'scales' was built under R version 4.3.2
## Warning: package 'infer' was built under R version 4.3.2
## Warning: package 'modeldata' was built under R version 4.3.2
## Warning: package 'parsnip' was built under R version 4.3.2
## Warning: package 'tune' was built under R version 4.3.2
## Warning: package 'workflows' was built under R version 4.3.2
## Warning: package 'workflowsets' was built under R version 4.3.2
## Warning: package 'yardstick' was built under R version 4.3.2
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ recipes::fixed() masks stringr::fixed()
## ✖ dplyr::lag() masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step() masks stats::step()
## • Search for functions across packages at https://www.tidymodels.org/find/
library(xgboost)
## Warning: package 'xgboost' was built under R version 4.3.2
##
## Attaching package: 'xgboost'
##
## The following object is masked from 'package:dplyr':
##
## slice
library(ggplot2)
library(themis)
## Warning: package 'themis' was built under R version 4.3.3
library(usemodels)
## Warning: package 'usemodels' was built under R version 4.3.2
The research question is whether an expedition member died; because the outcome is binary, a classification model can be built for this data set.
# Import data ----
# One row per expedition member (TidyTuesday 2020-09-22, Himalayan expeditions).
members <- readr::read_csv(
  file = "https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2020/2020-09-22/members.csv"
)
## Rows: 76519 Columns: 21
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (10): expedition_id, member_id, peak_id, peak_name, season, sex, citizen...
## dbl (5): year, age, highpoint_metres, death_height_metres, injury_height_me...
## lgl (6): hired, success, solo, oxygen_used, died, injured
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
print(members)
## # A tibble: 76,519 × 21
## expedition_id member_id peak_id peak_name year season sex age
## <chr> <chr> <chr> <chr> <dbl> <chr> <chr> <dbl>
## 1 AMAD78301 AMAD78301-01 AMAD Ama Dablam 1978 Autumn M 40
## 2 AMAD78301 AMAD78301-02 AMAD Ama Dablam 1978 Autumn M 41
## 3 AMAD78301 AMAD78301-03 AMAD Ama Dablam 1978 Autumn M 27
## 4 AMAD78301 AMAD78301-04 AMAD Ama Dablam 1978 Autumn M 40
## 5 AMAD78301 AMAD78301-05 AMAD Ama Dablam 1978 Autumn M 34
## 6 AMAD78301 AMAD78301-06 AMAD Ama Dablam 1978 Autumn M 25
## 7 AMAD78301 AMAD78301-07 AMAD Ama Dablam 1978 Autumn M 41
## 8 AMAD78301 AMAD78301-08 AMAD Ama Dablam 1978 Autumn M 29
## 9 AMAD79101 AMAD79101-03 AMAD Ama Dablam 1979 Spring M 35
## 10 AMAD79101 AMAD79101-04 AMAD Ama Dablam 1979 Spring M 37
## # ℹ 76,509 more rows
## # ℹ 13 more variables: citizenship <chr>, expedition_role <chr>, hired <lgl>,
## # highpoint_metres <dbl>, success <lgl>, solo <lgl>, oxygen_used <lgl>,
## # died <lgl>, death_cause <chr>, death_height_metres <dbl>, injured <lgl>,
## # injury_type <chr>, injury_height_metres <dbl>
members %>% skimr::skim()
| Name | members |
| Number of rows | 76519 |
| Number of columns | 21 |
| _______________________ | |
| Column type frequency: | |
| character | 10 |
| logical | 6 |
| numeric | 5 |
| ________________________ | |
| Group variables | None |
Variable type: character
| skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
|---|---|---|---|---|---|---|---|
| expedition_id | 0 | 1.00 | 9 | 9 | 0 | 10350 | 0 |
| member_id | 0 | 1.00 | 12 | 12 | 0 | 76518 | 0 |
| peak_id | 0 | 1.00 | 4 | 4 | 0 | 391 | 0 |
| peak_name | 15 | 1.00 | 4 | 25 | 0 | 390 | 0 |
| season | 0 | 1.00 | 6 | 7 | 0 | 5 | 0 |
| sex | 2 | 1.00 | 1 | 1 | 0 | 2 | 0 |
| citizenship | 10 | 1.00 | 2 | 23 | 0 | 212 | 0 |
| expedition_role | 21 | 1.00 | 4 | 25 | 0 | 524 | 0 |
| death_cause | 75413 | 0.01 | 3 | 27 | 0 | 12 | 0 |
| injury_type | 74807 | 0.02 | 3 | 27 | 0 | 11 | 0 |
Variable type: logical
| skim_variable | n_missing | complete_rate | mean | count |
|---|---|---|---|---|
| hired | 0 | 1 | 0.21 | FAL: 60788, TRU: 15731 |
| success | 0 | 1 | 0.38 | FAL: 47320, TRU: 29199 |
| solo | 0 | 1 | 0.00 | FAL: 76398, TRU: 121 |
| oxygen_used | 0 | 1 | 0.24 | FAL: 58286, TRU: 18233 |
| died | 0 | 1 | 0.01 | FAL: 75413, TRU: 1106 |
| injured | 0 | 1 | 0.02 | FAL: 74806, TRU: 1713 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| year | 0 | 1.00 | 2000.36 | 14.78 | 1905 | 1991 | 2004 | 2012 | 2019 | ▁▁▁▃▇ |
| age | 3497 | 0.95 | 37.33 | 10.40 | 7 | 29 | 36 | 44 | 85 | ▁▇▅▁▁ |
| highpoint_metres | 21833 | 0.71 | 7470.68 | 1040.06 | 3800 | 6700 | 7400 | 8400 | 8850 | ▁▁▆▃▇ |
| death_height_metres | 75451 | 0.01 | 6592.85 | 1308.19 | 400 | 5800 | 6600 | 7550 | 8830 | ▁▁▂▇▆ |
| injury_height_metres | 75510 | 0.01 | 7049.91 | 1214.24 | 400 | 6200 | 7100 | 8000 | 8880 | ▁▁▂▇▇ |
Issues with the data:

* Missing values
* Variables to treat as factors or numeric: sex, highpoint_metres, citizenship, expedition_role, season, death_height_metres, injury_height_metres
* Imbalanced logical variables (solo and injured are nearly constant): hired, success, solo, oxygen_used, injured
* Character variables: convert them to numbers in the recipe steps
* Imbalanced target variable: died
* ID variable: member_id
# Modelling table: drop outcome-detail / height / redundant-id columns,
# keep complete rows only, de-duplicate members, and turn every character
# and logical column into a factor -- except member_id, which stays a
# character identifier.
factors_vec <- members %>%
  select(
    -c(death_cause, injury_type, highpoint_metres, death_height_metres,
       injury_height_metres, peak_id, expedition_id)
  ) %>%
  drop_na() %>%
  distinct(member_id, .keep_all = TRUE) %>%
  mutate(
    across(where(is.character) | where(is.logical), factor),
    member_id = as.character(member_id)
  )
factors_vec %>% skimr::skim()
| Name | factors_vec |
| Number of rows | 72984 |
| Number of columns | 14 |
| _______________________ | |
| Column type frequency: | |
| character | 1 |
| factor | 11 |
| numeric | 2 |
| ________________________ | |
| Group variables | None |
Variable type: character
| skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
|---|---|---|---|---|---|---|---|
| member_id | 0 | 1 | 12 | 12 | 0 | 72984 | 0 |
Variable type: factor
| skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
|---|---|---|---|---|---|
| peak_name | 0 | 1 | FALSE | 390 | Eve: 20994, Cho: 8608, Ama: 8235, Man: 4510 |
| season | 0 | 1 | FALSE | 4 | Spr: 36150, Aut: 34186, Win: 2011, Sum: 637 |
| sex | 0 | 1 | FALSE | 2 | M: 66150, F: 6834 |
| citizenship | 0 | 1 | FALSE | 207 | Nep: 14367, USA: 6318, Jap: 6188, UK: 5071 |
| expedition_role | 0 | 1 | FALSE | 483 | Cli: 43315, H-A: 13033, Lea: 9884, Exp: 1411 |
| hired | 0 | 1 | FALSE | 2 | FAL: 59006, TRU: 13978 |
| success | 0 | 1 | FALSE | 2 | FAL: 44913, TRU: 28071 |
| solo | 0 | 1 | FALSE | 2 | FAL: 72868, TRU: 116 |
| oxygen_used | 0 | 1 | FALSE | 2 | FAL: 55215, TRU: 17769 |
| died | 0 | 1 | FALSE | 2 | FAL: 72055, TRU: 929 |
| injured | 0 | 1 | FALSE | 2 | FAL: 71333, TRU: 1651 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| year | 0 | 1 | 2001.00 | 14.12 | 1905 | 1992 | 2004 | 2012 | 2019 | ▁▁▁▃▇ |
| age | 0 | 1 | 37.34 | 10.39 | 7 | 29 | 36 | 44 | 85 | ▁▇▅▁▁ |
# Class balance of the outcome.
count(factors_vec, died)
## # A tibble: 2 × 2
## died n
## <fct> <int>
## 1 FALSE 72055
## 2 TRUE 929
# Bar chart of how many members died vs. survived.
ggplot(factors_vec, aes(x = died)) +
  geom_bar()
# Distribution of expedition year for members who died vs. survived.
ggplot(factors_vec, aes(x = died, y = year)) +
  geom_boxplot()
Correlation plot
# Step 1: Binarize ----
# Drop the row identifier first: member_id is unique per row, so binarize()
# would otherwise one-hot encode it into meaningless `member_id__<id>` /
# `member_id__-OTHER` columns that then show up in the correlation funnel.
data_binarized <- factors_vec %>%
  select(-member_id) %>%
  binarize()
glimpse(data_binarized)
## Rows: 72,984
## Columns: 69
## $ `member_id__ACHN15301-01` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `member_id__-OTHER` <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ peak_name__Ama_Dablam <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ peak_name__Annapurna_I <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ peak_name__Annapurna_IV <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ peak_name__Baruntse <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ peak_name__Cho_Oyu <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ peak_name__Dhaulagiri_I <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ peak_name__Everest <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ peak_name__Himlung_Himal <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ peak_name__Kangchenjunga <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ peak_name__Lhotse <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ peak_name__Makalu <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ peak_name__Manaslu <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ peak_name__Pumori <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `peak_name__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `year__-Inf_1992` <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ year__1992_2004 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ year__2004_2012 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ year__2012_Inf <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ season__Autumn <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, …
## $ season__Spring <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, …
## $ season__Winter <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `season__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ sex__F <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ sex__M <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ `age__-Inf_29` <dbl> 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, …
## $ age__29_36 <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ age__36_44 <dbl> 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, …
## $ age__44_Inf <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Australia <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Austria <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Canada <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__China <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__France <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, …
## $ citizenship__Germany <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__India <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Italy <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Japan <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Nepal <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Netherlands <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__New_Zealand <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Poland <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Russia <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__S_Korea <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Spain <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__Switzerland <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__UK <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ citizenship__USA <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, …
## $ citizenship__W_Germany <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, …
## $ `citizenship__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ expedition_role__Climber <dbl> 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, …
## $ expedition_role__Deputy_Leader <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ expedition_role__Exp_Doctor <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `expedition_role__H-A_Worker` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ expedition_role__Leader <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `expedition_role__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, …
## $ hired__FALSE <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ hired__TRUE <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ success__FALSE <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, …
## $ success__TRUE <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, …
## $ solo__FALSE <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ `solo__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ oxygen_used__FALSE <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ oxygen_used__TRUE <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ died__FALSE <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ died__TRUE <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ injured__FALSE <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ injured__TRUE <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
# Step 2: Correlation ----
# Correlate every binary feature with the died__TRUE indicator.
data_correlation <- correlate(data_binarized, target = died__TRUE)
## Warning: correlate(): [Data Imbalance Detected] Consider sampling to balance the classes more than 5%
## Column with imbalance: died__TRUE
print(data_correlation)
## # A tibble: 69 × 3
## feature bin correlation
## <fct> <chr> <dbl>
## 1 died FALSE -1
## 2 died TRUE 1
## 3 year -Inf_1992 0.0519
## 4 peak_name Annapurna_I 0.0336
## 5 success FALSE 0.0332
## 6 success TRUE -0.0332
## 7 peak_name Dhaulagiri_I 0.0290
## 8 peak_name Ama_Dablam -0.0281
## 9 peak_name Cho_Oyu -0.0241
## 10 year 2004_2012 -0.0211
## # ℹ 59 more rows
# Step 3: Plot ----
correlationfunnel::plot_correlation_funnel(data_correlation)
## Warning: ggrepel: 41 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps
library(tidymodels)
# Subsample to 5,000 rows to keep tuning fast, then make a stratified
# train/test split on the (rare) outcome.
# Seed BEFORE subsampling: the subset feeds the split, the CV folds and
# every metric below, so an unseeded draw makes the analysis
# non-reproducible from run to run.
set.seed(879)
factors_v <- factors_vec %>% slice_sample(n = 5000)
data_split <- initial_split(factors_v, strata = died)
data_train <- training(data_split)
data_test <- testing(data_split)
# 10-fold cross-validation on the training set, stratified on the outcome.
set.seed(891)
data_folds <- vfold_cv(data_train, v = 10, strata = died)
print(data_folds)
## # 10-fold cross-validation using stratification
## # A tibble: 10 × 2
## splits id
## <list> <chr>
## 1 <split [3375/375]> Fold01
## 2 <split [3375/375]> Fold02
## 3 <split [3375/375]> Fold03
## 4 <split [3375/375]> Fold04
## 5 <split [3375/375]> Fold05
## 6 <split [3375/375]> Fold06
## 7 <split [3375/375]> Fold07
## 8 <split [3375/375]> Fold08
## 9 <split [3375/375]> Fold09
## 10 <split [3375/375]> Fold10
library(themis)
library(embed)
## Warning: package 'embed' was built under R version 4.3.2
# Preprocessing for xgboost:
# - keep member_id as an identifier rather than a predictor,
# - pool infrequent peaks / citizenships / roles into an "other" level,
# - dummy-encode nominal predictors and drop zero-variance columns,
# - centre and scale numeric predictors,
# - oversample the minority `died` class with SMOTE.
xgboost_recipe <- recipe(died ~ ., data = data_train) %>%
  update_role(member_id, new_role = "id") %>%
  step_other(peak_name, citizenship, expedition_role, threshold = 0.05) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_numeric_predictors()) %>%
  step_smote(died)
# Preview the fully preprocessed (SMOTE-balanced) training data.
# bake(new_data = NULL) is the current replacement for the superseded juice().
xgboost_recipe %>%
  prep() %>%
  bake(new_data = NULL) %>%
  glimpse()
## Rows: 7,390
## Columns: 25
## $ member_id <fct> AMAD91105-02, EVER18103-12, CHOY92104-02, A…
## $ year <dbl> -0.70330474, 1.22018746, -0.63206428, 0.579…
## $ age <dbl> -0.97119022, 0.46009908, -0.49409379, -0.58…
## $ died <fct> FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, F…
## $ peak_name_Cho.Oyu <dbl> -0.370157, -0.370157, 2.700836, -0.370157, …
## $ peak_name_Everest <dbl> -0.6313679, 1.5834401, -0.6313679, -0.63136…
## $ peak_name_Manaslu <dbl> -0.2672256, -0.2672256, -0.2672256, -0.2672…
## $ peak_name_other <dbl> -0.8346537, -0.8346537, -0.8346537, -0.8346…
## $ season_Spring <dbl> 0.9972039, 0.9972039, 0.9972039, -1.0025365…
## $ season_Summer <dbl> -0.08979068, -0.08979068, -0.08979068, -0.0…
## $ season_Winter <dbl> -0.1721805, -0.1721805, -0.1721805, -0.1721…
## $ sex_M <dbl> -3.131182, 0.319283, 0.319283, 0.319283, 0.…
## $ citizenship_Japan <dbl> -0.2985677, -0.2985677, -0.2985677, -0.2985…
## $ citizenship_Nepal <dbl> -0.4978494, -0.4978494, -0.4978494, -0.4978…
## $ citizenship_UK <dbl> -0.2812749, -0.2812749, -0.2812749, -0.2812…
## $ citizenship_USA <dbl> -0.3006811, -0.3006811, -0.3006811, -0.3006…
## $ citizenship_other <dbl> -1.0138290, 0.9860967, 0.9860967, 0.9860967…
## $ expedition_role_H.A.Worker <dbl> -0.469305, -0.469305, -0.469305, -0.469305,…
## $ expedition_role_Leader <dbl> -0.3871873, -0.3871873, -0.3871873, 2.58204…
## $ expedition_role_other <dbl> -0.328827, -0.328827, 3.040302, -0.328827, …
## $ hired_TRUE. <dbl> -0.490332, -0.490332, -0.490332, -0.490332,…
## $ success_TRUE. <dbl> -0.7863251, 1.2713994, -0.7863251, 1.271399…
## $ solo_TRUE. <dbl> -0.02309709, -0.02309709, -0.02309709, -0.0…
## $ oxygen_used_TRUE. <dbl> -0.5590011, 1.7884281, -0.5590011, -0.55900…
## $ injured_TRUE. <dbl> -0.1358797, -0.1358797, -0.1358797, -0.1358…
# Boosted-tree classifier; trees, min_n, tree_depth and learn_rate are tuned.
xgboost_spec <- boost_tree(
  mode = "classification",
  trees = tune(),
  min_n = tune(),
  tree_depth = tune(),
  learn_rate = tune()
) %>%
  set_engine("xgboost")
# Bundle the preprocessing recipe with the model specification.
xgboost_workflow <- workflow(preprocessor = xgboost_recipe, spec = xgboost_spec)
# Register a parallel foreach backend for tuning.
doParallel::registerDoParallel()

# Evaluate 5 candidate hyperparameter combinations on every CV fold and keep
# the out-of-fold predictions. control_grid() is the control constructor
# meant for tune_grid(); control_resamples() is the one for fit_resamples().
set.seed(87654)
xgboost_tune <- tune_grid(
  xgboost_workflow,
  resamples = data_folds,
  grid = 5,
  control = control_grid(save_pred = TRUE)
)
# Top candidates ranked by cross-validated ROC AUC.
xgboost_tune %>% tune::show_best(metric = "roc_auc")
## # A tibble: 5 × 10
## trees min_n tree_depth learn_rate .metric .estimator mean n std_err
## <int> <int> <int> <dbl> <chr> <chr> <dbl> <int> <dbl>
## 1 1261 19 10 0.00338 roc_auc binary 0.770 10 0.0277
## 2 474 30 14 0.0215 roc_auc binary 0.766 10 0.0364
## 3 1961 13 3 0.0763 roc_auc binary 0.740 10 0.0277
## 4 110 35 8 0.00300 roc_auc binary 0.735 10 0.0406
## 5 1153 3 5 0.149 roc_auc binary 0.709 10 0.0262
## # ℹ 1 more variable: .config <chr>
# Plug the best-ROC-AUC hyperparameters into the workflow.
xgboost_fw <- xgboost_tune %>%
  tune::select_best(metric = "roc_auc") %>%
  tune::finalize_workflow(x = xgboost_workflow, parameters = .)

# Fit on the training portion of data_split and score the held-out test set.
data_fit <- xgboost_fw %>% tune::last_fit(split = data_split)

# NOTE(review): read accuracy against the majority-class rate. Roughly 98.7%
# of rows in factors_vec have died == FALSE, so an accuracy below that is
# worse than always predicting FALSE.
data_fit %>% tune::collect_metrics()
## # A tibble: 2 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 accuracy binary 0.928 Preprocessor1_Model1
## 2 roc_auc binary 0.600 Preprocessor1_Model1
# Test-set predicted probability of death, split by the true outcome.
# `died` is a discrete axis, so the points are jittered horizontally. They
# are coloured via `color`, because the default point shape ignores `fill`.
# The old y = x reference line and coord_fixed() were removed: neither has a
# meaning against a discrete x axis.
tune::collect_predictions(data_fit) %>%
  ggplot(aes(died, .pred_TRUE)) +
  geom_jitter(width = 0.2, height = 0, alpha = 0.3, color = "midnightblue") +
  labs(x = "Died (observed)", y = "Predicted probability of death")