The goal is to predict attrition, i.e. to identify employees who are likely to leave the company.

Import Data

library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.4     ✔ readr     2.1.5
## ✔ forcats   1.0.0     ✔ stringr   1.5.1
## ✔ ggplot2   3.5.1     ✔ tibble    3.2.1
## ✔ lubridate 1.9.4     ✔ tidyr     1.3.1
## ✔ purrr     1.0.4     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(correlationfunnel)
## ══ Using correlationfunnel? ════════════════════════════════════════════════════
## You might also be interested in applied data science training for business.
## </> Learn more at - www.business-science.io </>
# Read the raw HR attrition data set (column types are guessed by readr)
data <- readr::read_csv(
  file = "../00_data/WA_Fn-UseC_-HR-Employee-Attrition.csv"
)
## Rows: 1470 Columns: 35
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (9): Attrition, BusinessTravel, Department, EducationField, Gender, Job...
## dbl (26): Age, DailyRate, DistanceFromHome, Education, EmployeeCount, Employ...
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.

Explore Data

# Per-column summary: missingness, ranges, quantiles and mini histograms
data %>% skimr::skim()
Data summary
Name data
Number of rows 1470
Number of columns 35
_______________________
Column type frequency:
character 9
numeric 26
________________________
Group variables None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
Attrition 0 1 2 3 0 2 0
BusinessTravel 0 1 10 17 0 3 0
Department 0 1 5 22 0 3 0
EducationField 0 1 5 16 0 6 0
Gender 0 1 4 6 0 2 0
JobRole 0 1 7 25 0 9 0
MaritalStatus 0 1 6 8 0 3 0
Over18 0 1 1 1 0 1 0
OverTime 0 1 2 3 0 2 0

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
Age 0 1 36.92 9.14 18 30.00 36.0 43.00 60 ▂▇▇▃▂
DailyRate 0 1 802.49 403.51 102 465.00 802.0 1157.00 1499 ▇▇▇▇▇
DistanceFromHome 0 1 9.19 8.11 1 2.00 7.0 14.00 29 ▇▅▂▂▂
Education 0 1 2.91 1.02 1 2.00 3.0 4.00 5 ▂▃▇▆▁
EmployeeCount 0 1 1.00 0.00 1 1.00 1.0 1.00 1 ▁▁▇▁▁
EmployeeNumber 0 1 1024.87 602.02 1 491.25 1020.5 1555.75 2068 ▇▇▇▇▇
EnvironmentSatisfaction 0 1 2.72 1.09 1 2.00 3.0 4.00 4 ▅▅▁▇▇
HourlyRate 0 1 65.89 20.33 30 48.00 66.0 83.75 100 ▇▇▇▇▇
JobInvolvement 0 1 2.73 0.71 1 2.00 3.0 3.00 4 ▁▃▁▇▁
JobLevel 0 1 2.06 1.11 1 1.00 2.0 3.00 5 ▇▇▃▂▁
JobSatisfaction 0 1 2.73 1.10 1 2.00 3.0 4.00 4 ▅▅▁▇▇
MonthlyIncome 0 1 6502.93 4707.96 1009 2911.00 4919.0 8379.00 19999 ▇▅▂▁▂
MonthlyRate 0 1 14313.10 7117.79 2094 8047.00 14235.5 20461.50 26999 ▇▇▇▇▇
NumCompaniesWorked 0 1 2.69 2.50 0 1.00 2.0 4.00 9 ▇▃▂▂▁
PercentSalaryHike 0 1 15.21 3.66 11 12.00 14.0 18.00 25 ▇▅▃▂▁
PerformanceRating 0 1 3.15 0.36 3 3.00 3.0 3.00 4 ▇▁▁▁▂
RelationshipSatisfaction 0 1 2.71 1.08 1 2.00 3.0 4.00 4 ▅▅▁▇▇
StandardHours 0 1 80.00 0.00 80 80.00 80.0 80.00 80 ▁▁▇▁▁
StockOptionLevel 0 1 0.79 0.85 0 0.00 1.0 1.00 3 ▇▇▁▂▁
TotalWorkingYears 0 1 11.28 7.78 0 6.00 10.0 15.00 40 ▇▇▂▁▁
TrainingTimesLastYear 0 1 2.80 1.29 0 2.00 3.0 3.00 6 ▂▇▇▂▃
WorkLifeBalance 0 1 2.76 0.71 1 2.00 3.0 3.00 4 ▁▃▁▇▂
YearsAtCompany 0 1 7.01 6.13 0 3.00 5.0 9.00 40 ▇▂▁▁▁
YearsInCurrentRole 0 1 4.23 3.62 0 2.00 3.0 7.00 18 ▇▃▂▁▁
YearsSinceLastPromotion 0 1 2.19 3.22 0 0.00 1.0 3.00 15 ▇▁▁▁▁
YearsWithCurrManager 0 1 4.12 3.57 0 2.00 3.0 7.00 17 ▇▂▅▁▁
# Columns stored as integer codes that actually represent categorical levels
factors_vec <- data %>%
  select(
    Education, EnvironmentSatisfaction, JobInvolvement, JobSatisfaction,
    PerformanceRating, RelationshipSatisfaction, WorkLifeBalance,
    JobLevel, StockOptionLevel
  ) %>%
  names()

data_clean <- data %>%
  # Address factors imported as numeric
  mutate(across(all_of(factors_vec), as.factor)) %>%
  # Drop zero-variance variables (a single value in every row per skim())
  select(-c(Over18, EmployeeCount, StandardHours)) %>%
  # Recode Attrition: "Yes" -> "Left", other values ("No") kept as-is.
  # Store it as a factor with "Left" as the FIRST level: yardstick treats
  # the first level as the event by default, and the ROC code downstream
  # scores `.pred_Left`. Pinning the levels makes that explicit instead of
  # relying on alphabetical ordering of strings.
  mutate(
    Attrition = if_else(Attrition == "Yes", "Left", Attrition),
    Attrition = factor(Attrition, levels = c("Left", "No"))
  )

Explore Attrition

# Class balance of the outcome
count(data_clean, Attrition)
## # A tibble: 2 × 2
##   Attrition     n
##   <chr>     <int>
## 1 Left        237
## 2 No         1233
# Bar chart of the number of employees in each Attrition class
ggplot(data_clean, aes(x = Attrition)) +
  geom_bar()

Attrition vs Monthly Income

# Distribution of monthly income within each Attrition class
ggplot(data_clean, aes(x = Attrition, y = MonthlyIncome)) +
  geom_boxplot()

Correlation Plot

# Step 1: binarize -- one-hot encode categorical columns and cut numeric
# columns into bins (see the `Column__low_high` names in the output).
# EmployeeNumber is an identifier, so it is excluded first.
data_binarized <- binarize(select(data_clean, -EmployeeNumber))

glimpse(data_binarized)
## Rows: 1,470
## Columns: 120
## $ `Age__-Inf_30`                       <dbl> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, …
## $ Age__30_36                           <dbl> 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, …
## $ Age__36_43                           <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ Age__43_Inf                          <dbl> 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ Attrition__Left                      <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ Attrition__No                        <dbl> 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ `BusinessTravel__Non-Travel`         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ BusinessTravel__Travel_Frequently    <dbl> 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ BusinessTravel__Travel_Rarely        <dbl> 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, …
## $ `DailyRate__-Inf_465`                <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ DailyRate__465_802                   <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ DailyRate__802_1157                  <dbl> 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, …
## $ DailyRate__1157_Inf                  <dbl> 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, …
## $ Department__Human_Resources          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `Department__Research_&_Development` <dbl> 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ Department__Sales                    <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `DistanceFromHome__-Inf_2`           <dbl> 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, …
## $ DistanceFromHome__2_7                <dbl> 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ DistanceFromHome__7_14               <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ DistanceFromHome__14_Inf             <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, …
## $ Education__1                         <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, …
## $ Education__2                         <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ Education__3                         <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, …
## $ Education__4                         <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ Education__5                         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Human_Resources      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Life_Sciences        <dbl> 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, …
## $ EducationField__Marketing            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Medical              <dbl> 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, …
## $ EducationField__Other                <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Technical_Degree     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EnvironmentSatisfaction__1           <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, …
## $ EnvironmentSatisfaction__2           <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EnvironmentSatisfaction__3           <dbl> 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, …
## $ EnvironmentSatisfaction__4           <dbl> 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, …
## $ Gender__Female                       <dbl> 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ Gender__Male                         <dbl> 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, …
## $ `HourlyRate__-Inf_48`                <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, …
## $ HourlyRate__48_66                    <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ HourlyRate__66_83.75                 <dbl> 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, …
## $ HourlyRate__83.75_Inf                <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, …
## $ JobInvolvement__1                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobInvolvement__2                    <dbl> 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobInvolvement__3                    <dbl> 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, …
## $ JobInvolvement__4                    <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, …
## $ JobLevel__1                          <dbl> 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, …
## $ JobLevel__2                          <dbl> 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ JobLevel__3                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobLevel__4                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobLevel__5                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Healthcare_Representative   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ JobRole__Human_Resources             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Laboratory_Technician       <dbl> 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, …
## $ JobRole__Manager                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Manufacturing_Director      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobRole__Research_Director           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Research_Scientist          <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Sales_Executive             <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Sales_Representative        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobSatisfaction__1                   <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ JobSatisfaction__2                   <dbl> 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, …
## $ JobSatisfaction__3                   <dbl> 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, …
## $ JobSatisfaction__4                   <dbl> 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ MaritalStatus__Divorced              <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, …
## $ MaritalStatus__Married               <dbl> 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, …
## $ MaritalStatus__Single                <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, …
## $ `MonthlyIncome__-Inf_2911`           <dbl> 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, …
## $ MonthlyIncome__2911_4919             <dbl> 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, …
## $ MonthlyIncome__4919_8379             <dbl> 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ MonthlyIncome__8379_Inf              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ `MonthlyRate__-Inf_8047`             <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ MonthlyRate__8047_14235.5            <dbl> 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, …
## $ MonthlyRate__14235.5_20461.5         <dbl> 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, …
## $ MonthlyRate__20461.5_Inf             <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ `NumCompaniesWorked__-Inf_1`         <dbl> 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, …
## $ NumCompaniesWorked__1_2              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ NumCompaniesWorked__2_4              <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ NumCompaniesWorked__4_Inf            <dbl> 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, …
## $ OverTime__No                         <dbl> 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, …
## $ OverTime__Yes                        <dbl> 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ `PercentSalaryHike__-Inf_12`         <dbl> 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, …
## $ PercentSalaryHike__12_14             <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, …
## $ PercentSalaryHike__14_18             <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ PercentSalaryHike__18_Inf            <dbl> 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, …
## $ PerformanceRating__3                 <dbl> 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, …
## $ PerformanceRating__4                 <dbl> 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, …
## $ RelationshipSatisfaction__1          <dbl> 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ RelationshipSatisfaction__2          <dbl> 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, …
## $ RelationshipSatisfaction__3          <dbl> 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, …
## $ RelationshipSatisfaction__4          <dbl> 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ StockOptionLevel__0                  <dbl> 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ StockOptionLevel__1                  <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, …
## $ StockOptionLevel__2                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ StockOptionLevel__3                  <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ `TotalWorkingYears__-Inf_6`          <dbl> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, …
## $ TotalWorkingYears__6_10              <dbl> 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ TotalWorkingYears__10_15             <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ TotalWorkingYears__15_Inf            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ `TrainingTimesLastYear__-Inf_2`      <dbl> 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, …
## $ TrainingTimesLastYear__2_3           <dbl> 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, …
## $ TrainingTimesLastYear__3_Inf         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ WorkLifeBalance__1                   <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ WorkLifeBalance__2                   <dbl> 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, …
## $ WorkLifeBalance__3                   <dbl> 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, …
## $ WorkLifeBalance__4                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsAtCompany__-Inf_3`             <dbl> 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsAtCompany__3_5                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ YearsAtCompany__5_9                  <dbl> 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, …
## $ YearsAtCompany__9_Inf                <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsInCurrentRole__-Inf_2`         <dbl> 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsInCurrentRole__2_3              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ YearsInCurrentRole__3_7              <dbl> 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, …
## $ YearsInCurrentRole__7_Inf            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsSinceLastPromotion__-Inf_1`    <dbl> 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, …
## $ YearsSinceLastPromotion__1_3         <dbl> 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, …
## $ YearsSinceLastPromotion__3_Inf       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ `YearsWithCurrManager__-Inf_2`       <dbl> 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsWithCurrManager__2_3            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ YearsWithCurrManager__3_7            <dbl> 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, …
## $ YearsWithCurrManager__7_Inf          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
# Step 2: correlation of every binary column with the Attrition__Left flag
data_correlation <- correlate(data_binarized, target = Attrition__Left)

data_correlation
## # A tibble: 120 × 3
##    feature           bin       correlation
##    <fct>             <chr>           <dbl>
##  1 Attrition         Left            1    
##  2 Attrition         No             -1    
##  3 OverTime          No             -0.246
##  4 OverTime          Yes             0.246
##  5 JobLevel          1               0.213
##  6 MonthlyIncome     -Inf_2911       0.207
##  7 StockOptionLevel  0               0.195
##  8 YearsAtCompany    -Inf_3          0.183
##  9 MaritalStatus     Single          0.175
## 10 TotalWorkingYears -Inf_6          0.169
## # ℹ 110 more rows
# Step 3: Plot the correlations as a funnel, strongest features at the top
correlationfunnel::plot_correlation_funnel(data_correlation)
## Warning: ggrepel: 73 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps

## Model Building

Split Data

library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
## ✔ broom        1.0.7     ✔ rsample      1.2.1
## ✔ dials        1.4.0     ✔ tune         1.2.1
## ✔ infer        1.0.7     ✔ workflows    1.1.4
## ✔ modeldata    1.4.0     ✔ workflowsets 1.1.0
## ✔ parsnip      1.3.0     ✔ yardstick    1.3.2
## ✔ recipes      1.1.1
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter()   masks stats::filter()
## ✖ recipes::fixed()  masks stringr::fixed()
## ✖ dplyr::lag()      masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step()   masks stats::step()
## • Learn how to get started at https://www.tidymodels.org/start/
# Seed for a reproducible train/test split and CV fold assignment
set.seed(1234)

# Train/test split (initial_split's default proportion), stratified on the
# outcome so both sets keep the same share of "Left" employees.
data_split <- initial_split(data_clean, strata = Attrition)
data_train <- training(data_split)
data_test <- testing(data_split)

# Cross-validation folds on the training set (also stratified), for tuning
data_cv <- rsample::vfold_cv(data_train, strata = Attrition)
data_cv
## #  10-fold cross-validation using stratification 
## # A tibble: 10 × 2
##    splits            id    
##    <list>            <chr> 
##  1 <split [990/111]> Fold01
##  2 <split [990/111]> Fold02
##  3 <split [990/111]> Fold03
##  4 <split [990/111]> Fold04
##  5 <split [991/110]> Fold05
##  6 <split [991/110]> Fold06
##  7 <split [991/110]> Fold07
##  8 <split [992/109]> Fold08
##  9 <split [992/109]> Fold09
## 10 <split [992/109]> Fold10
library(themis)
## Warning: package 'themis' was built under R version 4.4.3
# Preprocessing recipe:
#   - EmployeeNumber kept as an "ID" column rather than a predictor
#   - nominal predictors converted to dummy variables
#   - Yeo-Johnson transform for the right-skewed count/tenure/income columns
#   - all numeric predictors centred/scaled, then reduced by PCA to the
#     components explaining 99% of the variance
#   - SMOTE oversampling of the minority Attrition class
xgboost_recipe <-
  recipes::recipe(Attrition ~ ., data = data_train) %>%
  update_role(EmployeeNumber, new_role = "ID") %>%
  step_dummy(all_nominal_predictors()) %>%
  step_YeoJohnson(
    DistanceFromHome, MonthlyIncome, NumCompaniesWorked,
    PercentSalaryHike, TotalWorkingYears, starts_with("Years")
  ) %>%
  step_normalize(all_numeric_predictors()) %>%
  step_pca(all_numeric_predictors(), threshold = 0.99) %>%
  step_smote(Attrition)

# Preview the processed training data; bake(new_data = NULL) returns the
# same retained training set as the superseded juice()
xgboost_recipe %>%
  prep() %>%
  bake(new_data = NULL) %>%
  glimpse()
## Rows: 1,848
## Columns: 55
## $ EmployeeNumber <dbl> 19, 27, 31, 33, 42, 47, 55, 58, 64, 65, 90, 118, 137, 1…
## $ Attrition      <fct> Left, Left, Left, Left, Left, Left, Left, Left, Left, L…
## $ PC01           <dbl> -2.5408074, -1.1373048, -1.6165450, -0.7146510, -1.9523…
## $ PC02           <dbl> -1.0750269, 1.8861792, -1.0954849, -1.0907458, 2.180957…
## $ PC03           <dbl> -0.43971157, 1.32357749, -0.95149664, -2.41103759, 2.24…
## $ PC04           <dbl> -1.34414052, -0.96050860, -1.56475884, 0.24168915, 0.18…
## $ PC05           <dbl> -2.0002222, -2.5297942, 0.7492029, -2.1753056, 2.687550…
## $ PC06           <dbl> -0.01695946, 0.31653822, -0.92773875, 1.40466526, 1.396…
## $ PC07           <dbl> -0.36271712, 0.16217554, -0.47505146, 0.75583638, -0.04…
## $ PC08           <dbl> 2.23583978, -0.87866085, 0.26925166, -2.39633041, 0.473…
## $ PC09           <dbl> 0.64670653, -0.04543999, -1.66535729, 0.20726907, -0.14…
## $ PC10           <dbl> 0.78925511, 1.34308972, -0.30122077, 0.07338142, 0.8990…
## $ PC11           <dbl> -0.876412194, -3.247754656, -0.006523447, -2.395912600,…
## $ PC12           <dbl> 0.74295506, -0.02192118, -0.91219547, 1.06797897, -0.84…
## $ PC13           <dbl> 1.55811726, 2.17286307, 0.80868236, 1.46329352, -0.7973…
## $ PC14           <dbl> 0.14255426, -1.15507431, 1.45972495, 1.36343587, -0.887…
## $ PC15           <dbl> -0.51038097, -2.17807250, 0.40425699, -1.41373786, -2.0…
## $ PC16           <dbl> -1.48860724, -0.51209227, 0.34268316, 0.53585198, 1.335…
## $ PC17           <dbl> -1.00279800, 1.21081549, -0.38836101, -0.10545648, 1.28…
## $ PC18           <dbl> -0.79053703, -0.07181857, 0.34748319, 1.05026066, -0.25…
## $ PC19           <dbl> -0.64810826, 0.16461037, 0.40342647, 0.30939573, -1.709…
## $ PC20           <dbl> -1.68992645, 0.62045310, 0.58469091, 0.67257274, 1.3170…
## $ PC21           <dbl> -0.84775520, -1.06666492, 1.86717611, 0.80000089, -0.45…
## $ PC22           <dbl> 0.56939618, -0.61420691, 0.43486881, 0.69950105, -3.159…
## $ PC23           <dbl> 0.47817791, 0.86316036, -0.42580172, -0.53319176, -0.49…
## $ PC24           <dbl> 0.49808821, -0.17973295, -0.84614627, -1.00624090, -0.0…
## $ PC25           <dbl> 0.78620071, 0.20493504, 0.09117069, -0.18327167, -2.509…
## $ PC26           <dbl> -0.127908095, 1.299112284, -0.095477728, 0.724729929, -…
## $ PC27           <dbl> 0.62601740, -0.98396158, -0.60288799, 0.52831676, -2.11…
## $ PC28           <dbl> -0.6806088, 0.4130443, -1.5616478, -0.5402558, 0.597098…
## $ PC29           <dbl> 0.24233079, -0.07014478, 0.92811689, -0.41890605, 1.514…
## $ PC30           <dbl> -0.43537978, 0.46354129, 0.94194097, -1.14419845, -0.82…
## $ PC31           <dbl> 1.2513049, -0.7568263, 0.1233751, 1.8231411, 0.4716849,…
## $ PC32           <dbl> -0.83558512, 0.59227146, -0.24156444, 1.16704391, 0.353…
## $ PC33           <dbl> -0.4477988, -0.7770123, -0.4835247, -0.7985294, -0.2408…
## $ PC34           <dbl> -0.1030229, -1.5638596, 0.4217827, 0.5491512, -1.430093…
## $ PC35           <dbl> -0.87421858, 0.37285651, 0.54985090, 0.80555081, -0.193…
## $ PC36           <dbl> 2.13676028, 1.75303488, -1.30700219, 1.19592865, 1.6970…
## $ PC37           <dbl> -0.59557223, 1.26862302, 0.92549208, -0.74904142, -0.70…
## $ PC38           <dbl> 0.76338383, -0.60478939, -0.53455128, 0.71756491, -1.16…
## $ PC39           <dbl> -1.19548022, -2.00890111, -0.47083477, -0.56457379, -2.…
## $ PC40           <dbl> 0.04868816, -0.65311107, 0.17374232, 0.30621479, -0.435…
## $ PC41           <dbl> 0.96241115, 1.19578777, -0.36893444, -1.18588089, 0.678…
## $ PC42           <dbl> 0.95907329, 0.49094719, -0.07078001, 0.19491976, -1.185…
## $ PC43           <dbl> -0.89127353, -0.80372168, -0.15338440, 0.08430098, 0.07…
## $ PC44           <dbl> -0.30105375, -0.31402824, -0.54386701, -0.38452616, 1.1…
## $ PC45           <dbl> -0.5067550766, 0.6858554310, 0.6060915011, 0.3147418547…
## $ PC46           <dbl> -0.174188501, 1.132482844, 0.914117550, 1.340247553, 0.…
## $ PC47           <dbl> -0.14292336, -0.37286358, -0.33491796, -0.13965105, -0.…
## $ PC48           <dbl> 0.50964670, 0.07342574, 0.09677421, 0.47488762, 0.05715…
## $ PC49           <dbl> 1.047428e-02, -1.880170e-01, 9.256232e-02, 4.470833e-01…
## $ PC50           <dbl> 0.15706359, -0.16317621, -0.18612135, 0.03997567, 0.490…
## $ PC51           <dbl> 0.31598906, -0.04902101, -0.21210018, -0.48368012, -0.9…
## $ PC52           <dbl> -0.23026881, -0.26587024, 0.87425856, 0.92635176, -0.51…
## $ PC53           <dbl> 0.071992418, -0.152065915, 0.296425363, 0.321754122, 0.…

Specify Model

# Boosted-tree classifier on the xgboost engine; the number of trees and
# the tree depth are placeholders to be filled in by tuning.
xgboost_spec <- boost_tree(
  mode = "classification",
  engine = "xgboost",
  trees = tune(),
  tree_depth = tune()
)

# Bundle the preprocessing recipe and the model spec into one workflow
xgboost_workflow <- workflow(
  preprocessor = xgboost_recipe,
  spec = xgboost_spec
)

Tune Hyperparameters

# Only `trees` and `tree_depth` are marked tune() in the model spec, so a
# grid may contain only those parameters. `grid = 5` lets tune_grid()
# generate 5 candidate combinations itself, so no hand-built grid is
# needed. (cost_complexity() is a decision_tree() parameter and does not
# apply to boost_tree().)

# Register a foreach parallel backend so resamples can be fitted in parallel
doParallel::registerDoParallel()

set.seed(65447)
xgboost_tune <-
  tune_grid(
    xgboost_workflow,
    resamples = data_cv,
    grid = 5,
    # keep out-of-fold predictions for collect_predictions() / ROC curves
    control = control_grid(save_pred = TRUE)
  )

Model Evaluation

Identify Optimal Values for Hyperparameters

# Mean cross-validated metrics for every tuning candidate
xgboost_tune %>% collect_metrics()
## # A tibble: 15 × 8
##    trees tree_depth .metric     .estimator  mean     n std_err .config          
##    <int>      <int> <chr>       <chr>      <dbl> <int>   <dbl> <chr>            
##  1  1759          3 accuracy    binary     0.859    10 0.00850 Preprocessor1_Mo…
##  2  1759          3 brier_class binary     0.119    10 0.00815 Preprocessor1_Mo…
##  3  1759          3 roc_auc     binary     0.779    10 0.0213  Preprocessor1_Mo…
##  4   798          6 accuracy    binary     0.847    10 0.00994 Preprocessor1_Mo…
##  5   798          6 brier_class binary     0.125    10 0.00798 Preprocessor1_Mo…
##  6   798          6 roc_auc     binary     0.773    10 0.0216  Preprocessor1_Mo…
##  7   219          8 accuracy    binary     0.856    10 0.00732 Preprocessor1_Mo…
##  8   219          8 brier_class binary     0.120    10 0.00559 Preprocessor1_Mo…
##  9   219          8 roc_auc     binary     0.765    10 0.0212  Preprocessor1_Mo…
## 10   890         11 accuracy    binary     0.852    10 0.00811 Preprocessor1_Mo…
## 11   890         11 brier_class binary     0.122    10 0.00673 Preprocessor1_Mo…
## 12   890         11 roc_auc     binary     0.765    10 0.0216  Preprocessor1_Mo…
## 13  1246         13 accuracy    binary     0.844    10 0.00888 Preprocessor1_Mo…
## 14  1246         13 brier_class binary     0.124    10 0.00659 Preprocessor1_Mo…
## 15  1246         13 roc_auc     binary     0.762    10 0.0278  Preprocessor1_Mo…
# Per-fold ROC curves for the candidate with the best mean CV ROC AUC.
# Without `parameters`, collect_predictions() returns the out-of-fold
# predictions of ALL tuning candidates, so each fold's curve would pool
# several different models together.
collect_predictions(
  xgboost_tune,
  parameters = select_best(xgboost_tune, metric = "roc_auc")
) %>%
  group_by(id) %>%
  roc_curve(Attrition, .pred_Left) %>%
  autoplot()

Final Fit: Train on the Full Training Set and Evaluate on the Test Set

# Refit on the full training set with the hyperparameters that maximised
# cross-validated ROC AUC, then evaluate once on the held-out test set.
# ROC AUC is threshold-independent. Accuracy is a poor selection criterion
# here because the classes are imbalanced (237 "Left" vs 1233 "No"), which
# rewards predicting the majority class.
xgboost_last <- xgboost_workflow %>%
  finalize_workflow(select_best(xgboost_tune, metric = "roc_auc")) %>%
  last_fit(data_split)

collect_metrics(xgboost_last)
## # A tibble: 3 × 4
##   .metric     .estimator .estimate .config             
##   <chr>       <chr>          <dbl> <chr>               
## 1 accuracy    binary         0.835 Preprocessor1_Model1
## 2 roc_auc     binary         0.775 Preprocessor1_Model1
## 3 brier_class binary         0.138 Preprocessor1_Model1
# Test-set confusion matrix: observed Attrition vs. predicted class
xgboost_last %>%
  collect_predictions() %>%
  yardstick::conf_mat(truth = Attrition, estimate = .pred_class) %>%
  autoplot()

Variable Importance

library(vip)
## Warning: package 'vip' was built under R version 4.4.3
## 
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
## 
##     vi
# Variable importance from the underlying xgboost fit. The recipe includes
# step_pca(), so the model's predictors are principal components
# (PC01, PC02, ...) rather than the original employee attributes.
vip(workflows::extract_fit_engine(xgboost_last))

Conclusion

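The previous model had an accuracy of 0.851 and an AUC of 0.753. This model reaches a test-set accuracy of 0.835 and an AUC of 0.775 (see the final-fit metrics above): slightly lower accuracy, but better discrimination between employees who leave and those who stay.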