The goal is to predict attrition, i.e., to identify employees who are likely to leave the company.

Import data

library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.4     ✔ readr     2.1.6
## ✔ forcats   1.0.1     ✔ stringr   1.6.0
## ✔ ggplot2   4.0.2     ✔ tibble    3.3.1
## ✔ lubridate 1.9.4     ✔ tidyr     1.3.2
## ✔ purrr     1.2.1     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(correlationfunnel)
## ══ Using correlationfunnel? ════════════════════════════════════════════════════
## You might also be interested in applied data science training for business.
## </> Learn more at - www.business-science.io </>
library(recipes)
## 
## Attaching package: 'recipes'
## 
## The following object is masked from 'package:stringr':
## 
##     fixed
## 
## The following object is masked from 'package:stats':
## 
##     step
library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.4.1 ──
## ✔ broom        1.0.12     ✔ tailor       0.1.0 
## ✔ dials        1.4.2      ✔ tune         2.0.1 
## ✔ infer        1.1.0      ✔ workflows    1.3.0 
## ✔ modeldata    1.5.1      ✔ workflowsets 1.1.1 
## ✔ parsnip      1.4.1      ✔ yardstick    1.3.2 
## ✔ rsample      1.3.2      
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter()   masks stats::filter()
## ✖ recipes::fixed()  masks stringr::fixed()
## ✖ dplyr::lag()      masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step()   masks stats::step()
library(themis)
library(doParallel)
## Loading required package: foreach
## 
## Attaching package: 'foreach'
## 
## The following objects are masked from 'package:purrr':
## 
##     accumulate, when
## 
## Loading required package: iterators
## Loading required package: parallel
# Read the HR attrition CSV from the working directory. No column types are
# supplied, so readr infers them and prints the column specification.
data <- "WA_Fn-UseC_-HR-Employee-Attrition.csv" %>%
  read_csv()
## Rows: 1470 Columns: 35
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (9): Attrition, BusinessTravel, Department, EducationField, Gender, Job...
## dbl (26): Age, DailyRate, DistanceFromHome, Education, EmployeeCount, Employ...
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.

Issues with data

# Rating-scale columns that were imported as numbers but are treated as
# categories in the rest of the analysis.
factors_vec <- c(
  "Education", "EnvironmentSatisfaction", "JobInvolvement", "JobSatisfaction",
  "PerformanceRating", "RelationshipSatisfaction", "WorkLifeBalance"
)

data_clean <- data %>%

  # Remove the zero-variance columns
  select(-c(Over18, EmployeeCount, StandardHours)) %>%

  # Convert the numerically coded rating columns to factors
  mutate(across(all_of(factors_vec), as.factor)) %>%

  # Relabel the outcome: "Yes" becomes "Left", every other value is kept
  mutate(Attrition = if_else(Attrition != "Yes", Attrition, "Left"))

Explore data

# Class balance of the outcome
count(data_clean, Attrition)
## # A tibble: 2 × 2
##   Attrition     n
##   <chr>     <int>
## 1 Left        237
## 2 No         1233
# Bar chart of the number of employees in each Attrition class
ggplot(data_clean, aes(x = Attrition)) +
  geom_bar()

Attrition vs. monthly income

# Distribution of MonthlyIncome within each Attrition class
ggplot(data_clean, aes(x = Attrition, y = MonthlyIncome)) +
  geom_boxplot()

Correlation plot

# step 1: binarize
# Drop the employee identifier, then let binarize() (default settings) turn
# every remaining column into 0/1 indicator columns: bins for numeric
# columns, one column per level for categorical ones.
data_binarized <- binarize(select(data_clean, -EmployeeNumber))

# Inspect the resulting indicator columns
glimpse(data_binarized)
## Rows: 1,470
## Columns: 120
## $ `Age__-Inf_30`                       <dbl> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, …
## $ Age__30_36                           <dbl> 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, …
## $ Age__36_43                           <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ Age__43_Inf                          <dbl> 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ Attrition__Left                      <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ Attrition__No                        <dbl> 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ `BusinessTravel__Non-Travel`         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ BusinessTravel__Travel_Frequently    <dbl> 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ BusinessTravel__Travel_Rarely        <dbl> 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, …
## $ `DailyRate__-Inf_465`                <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ DailyRate__465_802                   <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ DailyRate__802_1157                  <dbl> 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, …
## $ DailyRate__1157_Inf                  <dbl> 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, …
## $ Department__Human_Resources          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `Department__Research_&_Development` <dbl> 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ Department__Sales                    <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `DistanceFromHome__-Inf_2`           <dbl> 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, …
## $ DistanceFromHome__2_7                <dbl> 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ DistanceFromHome__7_14               <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ DistanceFromHome__14_Inf             <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, …
## $ Education__1                         <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, …
## $ Education__2                         <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ Education__3                         <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, …
## $ Education__4                         <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ Education__5                         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Human_Resources      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Life_Sciences        <dbl> 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, …
## $ EducationField__Marketing            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Medical              <dbl> 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, …
## $ EducationField__Other                <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Technical_Degree     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EnvironmentSatisfaction__1           <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, …
## $ EnvironmentSatisfaction__2           <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EnvironmentSatisfaction__3           <dbl> 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, …
## $ EnvironmentSatisfaction__4           <dbl> 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, …
## $ Gender__Female                       <dbl> 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ Gender__Male                         <dbl> 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, …
## $ `HourlyRate__-Inf_48`                <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, …
## $ HourlyRate__48_66                    <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ HourlyRate__66_83.75                 <dbl> 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, …
## $ HourlyRate__83.75_Inf                <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, …
## $ JobInvolvement__1                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobInvolvement__2                    <dbl> 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobInvolvement__3                    <dbl> 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, …
## $ JobInvolvement__4                    <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, …
## $ JobLevel__1                          <dbl> 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, …
## $ JobLevel__2                          <dbl> 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ JobLevel__3                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobLevel__4                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobLevel__5                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Healthcare_Representative   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ JobRole__Human_Resources             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Laboratory_Technician       <dbl> 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, …
## $ JobRole__Manager                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Manufacturing_Director      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobRole__Research_Director           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Research_Scientist          <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Sales_Executive             <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Sales_Representative        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobSatisfaction__1                   <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ JobSatisfaction__2                   <dbl> 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, …
## $ JobSatisfaction__3                   <dbl> 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, …
## $ JobSatisfaction__4                   <dbl> 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ MaritalStatus__Divorced              <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, …
## $ MaritalStatus__Married               <dbl> 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, …
## $ MaritalStatus__Single                <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, …
## $ `MonthlyIncome__-Inf_2911`           <dbl> 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, …
## $ MonthlyIncome__2911_4919             <dbl> 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, …
## $ MonthlyIncome__4919_8379             <dbl> 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ MonthlyIncome__8379_Inf              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ `MonthlyRate__-Inf_8047`             <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ MonthlyRate__8047_14235.5            <dbl> 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, …
## $ MonthlyRate__14235.5_20461.5         <dbl> 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, …
## $ MonthlyRate__20461.5_Inf             <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ `NumCompaniesWorked__-Inf_1`         <dbl> 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, …
## $ NumCompaniesWorked__1_2              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ NumCompaniesWorked__2_4              <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ NumCompaniesWorked__4_Inf            <dbl> 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, …
## $ OverTime__No                         <dbl> 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, …
## $ OverTime__Yes                        <dbl> 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ `PercentSalaryHike__-Inf_12`         <dbl> 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, …
## $ PercentSalaryHike__12_14             <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, …
## $ PercentSalaryHike__14_18             <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ PercentSalaryHike__18_Inf            <dbl> 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, …
## $ PerformanceRating__3                 <dbl> 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, …
## $ PerformanceRating__4                 <dbl> 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, …
## $ RelationshipSatisfaction__1          <dbl> 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ RelationshipSatisfaction__2          <dbl> 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, …
## $ RelationshipSatisfaction__3          <dbl> 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, …
## $ RelationshipSatisfaction__4          <dbl> 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ StockOptionLevel__0                  <dbl> 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ StockOptionLevel__1                  <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, …
## $ StockOptionLevel__2                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ StockOptionLevel__3                  <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ `TotalWorkingYears__-Inf_6`          <dbl> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, …
## $ TotalWorkingYears__6_10              <dbl> 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ TotalWorkingYears__10_15             <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ TotalWorkingYears__15_Inf            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ `TrainingTimesLastYear__-Inf_2`      <dbl> 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, …
## $ TrainingTimesLastYear__2_3           <dbl> 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, …
## $ TrainingTimesLastYear__3_Inf         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ WorkLifeBalance__1                   <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ WorkLifeBalance__2                   <dbl> 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, …
## $ WorkLifeBalance__3                   <dbl> 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, …
## $ WorkLifeBalance__4                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsAtCompany__-Inf_3`             <dbl> 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsAtCompany__3_5                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ YearsAtCompany__5_9                  <dbl> 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, …
## $ YearsAtCompany__9_Inf                <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsInCurrentRole__-Inf_2`         <dbl> 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsInCurrentRole__2_3              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ YearsInCurrentRole__3_7              <dbl> 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, …
## $ YearsInCurrentRole__7_Inf            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsSinceLastPromotion__-Inf_1`    <dbl> 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, …
## $ YearsSinceLastPromotion__1_3         <dbl> 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, …
## $ YearsSinceLastPromotion__3_Inf       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ `YearsWithCurrManager__-Inf_2`       <dbl> 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsWithCurrManager__2_3            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ YearsWithCurrManager__3_7            <dbl> 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, …
## $ YearsWithCurrManager__7_Inf          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
# step 2: correlation
# Correlate every indicator column with the "Left" indicator of the outcome.
data_correlation <- correlate(data_binarized, Attrition__Left)

# Print the feature / bin / correlation table
data_correlation
## # A tibble: 120 × 3
##    feature           bin       correlation
##    <fct>             <chr>           <dbl>
##  1 Attrition         Left            1    
##  2 Attrition         No             -1    
##  3 OverTime          No             -0.246
##  4 OverTime          Yes             0.246
##  5 JobLevel          1               0.213
##  6 MonthlyIncome     -Inf_2911       0.207
##  7 StockOptionLevel  0               0.195
##  8 YearsAtCompany    -Inf_3          0.183
##  9 MaritalStatus     Single          0.175
## 10 TotalWorkingYears -Inf_6          0.169
## # ℹ 110 more rows
# step 3: plot
# Draw the correlation funnel from the feature/bin correlations.
correlationfunnel::plot_correlation_funnel(data_correlation)
## Warning: The `size` argument of `element_line()` is deprecated as of ggplot2 3.4.0.
## ℹ Please use the `linewidth` argument instead.
## ℹ The deprecated feature was likely used in the correlationfunnel package.
##   Please report the issue at
##   <https://github.com/business-science/correlationfunnel/issues>.
## This warning is displayed once per session.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
## Warning: The `size` argument of `element_rect()` is deprecated as of ggplot2 3.4.0.
## ℹ Please use the `linewidth` argument instead.
## ℹ The deprecated feature was likely used in the correlationfunnel package.
##   Please report the issue at
##   <https://github.com/business-science/correlationfunnel/issues>.
## This warning is displayed once per session.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
## Warning: ggrepel: 73 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps

Model building

Split data

# Seed the RNG so the split and the folds below are reproducible
set.seed(1234)

# Hold out a test set, stratified on the outcome
# (prop = 3/4 is the rsample default, written out here)
data_split <- data_clean %>%
  initial_split(prop = 3 / 4, strata = Attrition)
data_test <- testing(data_split)
data_train <- training(data_split)

# Stratified cross-validation folds on the training set
# (v = 10 is the rsample default, written out here)
data_cv <- data_train %>%
  vfold_cv(v = 10, strata = Attrition)

Preprocess data

# Preprocessing recipe: Attrition is the outcome and every other column
# starts out as a predictor.
xgboost_recipes <- recipe(formula = Attrition ~ ., data = data_train) %>%
  # Keep EmployeeNumber in the data but take it out of the predictor set
  update_role(EmployeeNumber, new_role = "ID") %>%
  # Dummy-encode the nominal predictors
  step_dummy(all_nominal_predictors()) %>%
  # Oversample with SMOTE, using Attrition as the class variable
  step_smote(Attrition)

Specify model

# Print usemodels' boilerplate for an xgboost workflow on this formula and
# data; the printed template is the basis for the model spec defined below.
library(usemodels)
use_xgboost(Attrition ~ ., data = data_train)
## xgboost_recipe <- 
##   recipe(formula = Attrition ~ ., data = data_train) %>% 
##   step_zv(all_predictors()) 
## 
## xgboost_spec <- 
##   boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(), 
##     loss_reduction = tune(), sample_size = tune()) %>% 
##   set_mode("classification") %>% 
##   set_engine("xgboost") 
## 
## xgboost_workflow <- 
##   workflow() %>% 
##   add_recipe(xgboost_recipe) %>% 
##   add_model(xgboost_spec) 
## 
## set.seed(59891)
## xgboost_tune <-
##   tune_grid(xgboost_workflow, resamples = stop("add your rsample object"), grid = stop("add number of candidate points"))
# Boosted-tree classifier on the xgboost engine; all six listed
# hyperparameters are left as tune() placeholders to be filled by tuning.
xgboost_spec <- boost_tree(
  trees = tune(),
  min_n = tune(),
  tree_depth = tune(),
  learn_rate = tune(),
  loss_reduction = tune(),
  sample_size = tune()
) %>%
  set_engine("xgboost") %>%
  set_mode("classification")

# Bundle the preprocessing recipe and the model spec into one workflow
xgboost_workflow <- workflow(
  preprocessor = xgboost_recipes,
  spec = xgboost_spec
)

Tune hyperparameters

# Register a doParallel backend with its default settings.
# NOTE(review): newer tune releases moved parallel processing to
# future/mirai; confirm that tune 2.0.1 still picks up a foreach backend.
doParallel::registerDoParallel()

# Fixed seed for the tuning run
set.seed(65743)

# Evaluate a grid of 5 candidates on the cross-validation folds, keeping the
# out-of-fold predictions so they can be collected afterwards.
xgboost_tune <- xgboost_workflow %>%
  tune_grid(
    resamples = data_cv,
    grid = 5,
    control = control_grid(save_pred = TRUE)
  )

Model evaluation

Identify optimal values for hyperparameters

# Resampled performance of every tuning candidate
xgboost_tune %>%
  collect_metrics()
## # A tibble: 15 × 12
##    trees min_n tree_depth learn_rate loss_reduction sample_size .metric    
##    <int> <int>      <int>      <dbl>          <dbl>       <dbl> <chr>      
##  1     1    30         15    0.0750    0.0422             0.625 accuracy   
##  2     1    30         15    0.0750    0.0422             0.625 brier_class
##  3     1    30         15    0.0750    0.0422             0.625 roc_auc    
##  4   500    21          1    0.316     0.0000562          1     accuracy   
##  5   500    21          1    0.316     0.0000562          1     brier_class
##  6   500    21          1    0.316     0.0000562          1     roc_auc    
##  7  1000     2          8    0.0178    0.0000000001       0.5   accuracy   
##  8  1000     2          8    0.0178    0.0000000001       0.5   brier_class
##  9  1000     2          8    0.0178    0.0000000001       0.5   roc_auc    
## 10  1500    11          4    0.001    31.6                0.75  accuracy   
## 11  1500    11          4    0.001    31.6                0.75  brier_class
## 12  1500    11          4    0.001    31.6                0.75  roc_auc    
## 13  2000    40         11    0.00422   0.0000000750       0.875 accuracy   
## 14  2000    40         11    0.00422   0.0000000750       0.875 brier_class
## 15  2000    40         11    0.00422   0.0000000750       0.875 roc_auc    
## # ℹ 5 more variables: .estimator <chr>, mean <dbl>, n <int>, std_err <dbl>,
## #   .config <chr>
# Per-fold ROC curves for a single candidate: the one with the best resampled
# ROC AUC. Without the `parameters` filter, collect_predictions() returns the
# out-of-fold predictions of all five tuning candidates, so grouping by `id`
# alone would pool different models into each fold's curve.
xgboost_tune %>%
  collect_predictions(
    parameters = select_best(xgboost_tune, metric = "roc_auc")
  ) %>%
  group_by(id) %>%
  roc_curve(Attrition, .pred_Left) %>%
  autoplot()

Fit the model for the last time

# Finalize the workflow with the candidate that has the best resampled ROC
# AUC, then fit it once on the training set and evaluate on the test set.
# Accuracy is not used for selection because the outcome is imbalanced
# (237 "Left" vs 1233 "No"): always predicting "No" already scores about 0.84,
# so accuracy barely separates useful candidates from useless ones.
xgboost_last <- xgboost_workflow %>%
  finalize_workflow(select_best(xgboost_tune, metric = "roc_auc")) %>%
  last_fit(data_split)

# Test-set metrics of the final fit
collect_metrics(xgboost_last)
## # A tibble: 3 × 4
##   .metric     .estimator .estimate .config        
##   <chr>       <chr>          <dbl> <chr>          
## 1 accuracy    binary         0.848 pre0_mod0_post0
## 2 roc_auc     binary         0.812 pre0_mod0_post0
## 3 brier_class binary         0.110 pre0_mod0_post0
# Confusion matrix of the final model's test-set class predictions
xgboost_last %>%
  collect_predictions() %>%
  conf_mat(truth = Attrition, estimate = .pred_class) %>%
  autoplot()

Variable importance

library(vip)
## 
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
## 
##     vi
# Pull the underlying engine fit out of the last_fit() result and plot its
# variable-importance scores with vip's default settings.
vip(workflows::extract_fit_engine(xgboost_last))