library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.4     ✔ readr     2.1.5
## ✔ forcats   1.0.0     ✔ stringr   1.5.1
## ✔ ggplot2   3.5.1     ✔ tibble    3.2.1
## ✔ lubridate 1.9.3     ✔ tidyr     1.3.1
## ✔ purrr     1.0.2     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(correlationfunnel)
## ══ correlationfunnel Tip #1 ════════════════════════════════════════════════════
## Make sure your data is not overly imbalanced prior to using `correlate()`.
## If less than 5% imbalance, consider sampling. :)

Import data

# Load the raw HR employee-attrition CSV.
data <- read_csv(file = "../00_data/WA_Fn-UseC_-HR-Employee-Attrition.csv")
## Rows: 1470 Columns: 35
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (9): Attrition, BusinessTravel, Department, EducationField, Gender, Job...
## dbl (26): Age, DailyRate, DistanceFromHome, Education, EmployeeCount, Employ...
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.

Explore data

# Per-column summary: type, missingness, spread and a small histogram.
data %>% skimr::skim()
Data summary
Name data
Number of rows 1470
Number of columns 35
_______________________
Column type frequency:
character 9
numeric 26
________________________
Group variables None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
Attrition 0 1 2 3 0 2 0
BusinessTravel 0 1 10 17 0 3 0
Department 0 1 5 22 0 3 0
EducationField 0 1 5 16 0 6 0
Gender 0 1 4 6 0 2 0
JobRole 0 1 7 25 0 9 0
MaritalStatus 0 1 6 8 0 3 0
Over18 0 1 1 1 0 1 0
OverTime 0 1 2 3 0 2 0

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
Age 0 1 36.92 9.14 18 30.00 36.0 43.00 60 ▂▇▇▃▂
DailyRate 0 1 802.49 403.51 102 465.00 802.0 1157.00 1499 ▇▇▇▇▇
DistanceFromHome 0 1 9.19 8.11 1 2.00 7.0 14.00 29 ▇▅▂▂▂
Education 0 1 2.91 1.02 1 2.00 3.0 4.00 5 ▂▃▇▆▁
EmployeeCount 0 1 1.00 0.00 1 1.00 1.0 1.00 1 ▁▁▇▁▁
EmployeeNumber 0 1 1024.87 602.02 1 491.25 1020.5 1555.75 2068 ▇▇▇▇▇
EnvironmentSatisfaction 0 1 2.72 1.09 1 2.00 3.0 4.00 4 ▅▅▁▇▇
HourlyRate 0 1 65.89 20.33 30 48.00 66.0 83.75 100 ▇▇▇▇▇
JobInvolvement 0 1 2.73 0.71 1 2.00 3.0 3.00 4 ▁▃▁▇▁
JobLevel 0 1 2.06 1.11 1 1.00 2.0 3.00 5 ▇▇▃▂▁
JobSatisfaction 0 1 2.73 1.10 1 2.00 3.0 4.00 4 ▅▅▁▇▇
MonthlyIncome 0 1 6502.93 4707.96 1009 2911.00 4919.0 8379.00 19999 ▇▅▂▁▂
MonthlyRate 0 1 14313.10 7117.79 2094 8047.00 14235.5 20461.50 26999 ▇▇▇▇▇
NumCompaniesWorked 0 1 2.69 2.50 0 1.00 2.0 4.00 9 ▇▃▂▂▁
PercentSalaryHike 0 1 15.21 3.66 11 12.00 14.0 18.00 25 ▇▅▃▂▁
PerformanceRating 0 1 3.15 0.36 3 3.00 3.0 3.00 4 ▇▁▁▁▂
RelationshipSatisfaction 0 1 2.71 1.08 1 2.00 3.0 4.00 4 ▅▅▁▇▇
StandardHours 0 1 80.00 0.00 80 80.00 80.0 80.00 80 ▁▁▇▁▁
StockOptionLevel 0 1 0.79 0.85 0 0.00 1.0 1.00 3 ▇▇▁▂▁
TotalWorkingYears 0 1 11.28 7.78 0 6.00 10.0 15.00 40 ▇▇▂▁▁
TrainingTimesLastYear 0 1 2.80 1.29 0 2.00 3.0 3.00 6 ▂▇▇▂▃
WorkLifeBalance 0 1 2.76 0.71 1 2.00 3.0 3.00 4 ▁▃▁▇▂
YearsAtCompany 0 1 7.01 6.13 0 3.00 5.0 9.00 40 ▇▂▁▁▁
YearsInCurrentRole 0 1 4.23 3.62 0 2.00 3.0 7.00 18 ▇▃▂▁▁
YearsSinceLastPromotion 0 1 2.19 3.22 0 0.00 1.0 3.00 15 ▇▁▁▁▁
YearsWithCurrManager 0 1 4.12 3.57 0 2.00 3.0 7.00 17 ▇▂▅▁▁

Issues with the data:
- Missing values: none (every column has n_missing = 0).
- Factors stored as numbers: Education, EnvironmentSatisfaction, JobInvolvement, JobSatisfaction, PerformanceRating, RelationshipSatisfaction and WorkLifeBalance are rating categories imported as numeric.
- Zero-variance variables: Over18, EmployeeCount and StandardHours.
- Character variables: convert them to numbers in a recipe step.
- Unbalanced target variable: Attrition.
- ID variable: EmployeeNumber.

# Rating columns that were imported as numbers but are really categories.
# They are converted to factors so downstream steps (binarize(), step_dummy())
# treat them as categorical rather than continuous.
factors_vec <- c(
  "Education", "EnvironmentSatisfaction", "JobInvolvement", "JobSatisfaction",
  "PerformanceRating", "RelationshipSatisfaction", "WorkLifeBalance"
)

data_clean <- data %>%
  # Address factors imported as numeric (all_of() errors if a name is missing).
  mutate(across(all_of(factors_vec), as.factor)) %>%
  # Drop the zero-variance variables identified in the skim summary.
  select(-c(Over18, EmployeeCount, StandardHours)) %>%
  # Recode the outcome "Yes" -> "Left" (keeping "No") and store it as a factor
  # with "Left" as the FIRST level. yardstick metrics treat the first level as
  # the event by default, so `.pred_Left` / roc_curve() below rely on this
  # ordering instead of on an implicit, alphabetical character-to-factor
  # conversion. Assumes Attrition only contains "Yes"/"No" (skim reports 2
  # unique values); any other value would become NA.
  mutate(
    Attrition = factor(Attrition, levels = c("Yes", "No"), labels = c("Left", "No"))
  )

Explore the cleaned data

# Class balance of the outcome.
count(data_clean, Attrition)
## # A tibble: 2 × 2
##   Attrition     n
##   <chr>     <int>
## 1 Left        237
## 2 No         1233
# Bar chart of how many employees fall into each Attrition class.
ggplot(data_clean, aes(x = Attrition)) +
  geom_bar()

Attrition vs. monthly income

# Distribution of MonthlyIncome within each Attrition class.
ggplot(data_clean, aes(x = Attrition, y = MonthlyIncome)) +
  geom_boxplot()

Correlation plot

# Step 1: binarize. Numeric columns are cut into bins and categorical columns
# are expanded into 0/1 indicator columns. EmployeeNumber is an ID and is
# excluded.
data_binarized <- data_clean %>%
  select(!EmployeeNumber) %>%
  binarize()

glimpse(data_binarized)
## Rows: 1,470
## Columns: 120
## $ `Age__-Inf_30`                       <dbl> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, …
## $ Age__30_36                           <dbl> 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, …
## $ Age__36_43                           <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ Age__43_Inf                          <dbl> 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ Attrition__Left                      <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ Attrition__No                        <dbl> 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ `BusinessTravel__Non-Travel`         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ BusinessTravel__Travel_Frequently    <dbl> 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ BusinessTravel__Travel_Rarely        <dbl> 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, …
## $ `DailyRate__-Inf_465`                <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ DailyRate__465_802                   <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ DailyRate__802_1157                  <dbl> 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, …
## $ DailyRate__1157_Inf                  <dbl> 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, …
## $ Department__Human_Resources          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `Department__Research_&_Development` <dbl> 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ Department__Sales                    <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `DistanceFromHome__-Inf_2`           <dbl> 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, …
## $ DistanceFromHome__2_7                <dbl> 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ DistanceFromHome__7_14               <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ DistanceFromHome__14_Inf             <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, …
## $ Education__1                         <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, …
## $ Education__2                         <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ Education__3                         <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, …
## $ Education__4                         <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ Education__5                         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Human_Resources      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Life_Sciences        <dbl> 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, …
## $ EducationField__Marketing            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Medical              <dbl> 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, …
## $ EducationField__Other                <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Technical_Degree     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EnvironmentSatisfaction__1           <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, …
## $ EnvironmentSatisfaction__2           <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EnvironmentSatisfaction__3           <dbl> 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, …
## $ EnvironmentSatisfaction__4           <dbl> 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, …
## $ Gender__Female                       <dbl> 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ Gender__Male                         <dbl> 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, …
## $ `HourlyRate__-Inf_48`                <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, …
## $ HourlyRate__48_66                    <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ HourlyRate__66_83.75                 <dbl> 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, …
## $ HourlyRate__83.75_Inf                <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, …
## $ JobInvolvement__1                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobInvolvement__2                    <dbl> 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobInvolvement__3                    <dbl> 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, …
## $ JobInvolvement__4                    <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, …
## $ JobLevel__1                          <dbl> 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, …
## $ JobLevel__2                          <dbl> 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ JobLevel__3                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobLevel__4                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobLevel__5                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Healthcare_Representative   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ JobRole__Human_Resources             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Laboratory_Technician       <dbl> 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, …
## $ JobRole__Manager                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Manufacturing_Director      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobRole__Research_Director           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Research_Scientist          <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Sales_Executive             <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Sales_Representative        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobSatisfaction__1                   <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ JobSatisfaction__2                   <dbl> 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, …
## $ JobSatisfaction__3                   <dbl> 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, …
## $ JobSatisfaction__4                   <dbl> 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ MaritalStatus__Divorced              <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, …
## $ MaritalStatus__Married               <dbl> 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, …
## $ MaritalStatus__Single                <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, …
## $ `MonthlyIncome__-Inf_2911`           <dbl> 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, …
## $ MonthlyIncome__2911_4919             <dbl> 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, …
## $ MonthlyIncome__4919_8379             <dbl> 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ MonthlyIncome__8379_Inf              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ `MonthlyRate__-Inf_8047`             <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ MonthlyRate__8047_14235.5            <dbl> 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, …
## $ MonthlyRate__14235.5_20461.5         <dbl> 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, …
## $ MonthlyRate__20461.5_Inf             <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ `NumCompaniesWorked__-Inf_1`         <dbl> 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, …
## $ NumCompaniesWorked__1_2              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ NumCompaniesWorked__2_4              <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ NumCompaniesWorked__4_Inf            <dbl> 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, …
## $ OverTime__No                         <dbl> 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, …
## $ OverTime__Yes                        <dbl> 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ `PercentSalaryHike__-Inf_12`         <dbl> 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, …
## $ PercentSalaryHike__12_14             <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, …
## $ PercentSalaryHike__14_18             <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ PercentSalaryHike__18_Inf            <dbl> 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, …
## $ PerformanceRating__3                 <dbl> 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, …
## $ PerformanceRating__4                 <dbl> 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, …
## $ RelationshipSatisfaction__1          <dbl> 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ RelationshipSatisfaction__2          <dbl> 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, …
## $ RelationshipSatisfaction__3          <dbl> 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, …
## $ RelationshipSatisfaction__4          <dbl> 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ StockOptionLevel__0                  <dbl> 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ StockOptionLevel__1                  <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, …
## $ StockOptionLevel__2                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ StockOptionLevel__3                  <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ `TotalWorkingYears__-Inf_6`          <dbl> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, …
## $ TotalWorkingYears__6_10              <dbl> 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ TotalWorkingYears__10_15             <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ TotalWorkingYears__15_Inf            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ `TrainingTimesLastYear__-Inf_2`      <dbl> 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, …
## $ TrainingTimesLastYear__2_3           <dbl> 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, …
## $ TrainingTimesLastYear__3_Inf         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ WorkLifeBalance__1                   <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ WorkLifeBalance__2                   <dbl> 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, …
## $ WorkLifeBalance__3                   <dbl> 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, …
## $ WorkLifeBalance__4                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsAtCompany__-Inf_3`             <dbl> 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsAtCompany__3_5                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ YearsAtCompany__5_9                  <dbl> 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, …
## $ YearsAtCompany__9_Inf                <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsInCurrentRole__-Inf_2`         <dbl> 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsInCurrentRole__2_3              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ YearsInCurrentRole__3_7              <dbl> 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, …
## $ YearsInCurrentRole__7_Inf            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsSinceLastPromotion__-Inf_1`    <dbl> 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, …
## $ YearsSinceLastPromotion__1_3         <dbl> 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, …
## $ YearsSinceLastPromotion__3_Inf       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ `YearsWithCurrManager__-Inf_2`       <dbl> 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsWithCurrManager__2_3            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ YearsWithCurrManager__3_7            <dbl> 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, …
## $ YearsWithCurrManager__7_Inf          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
# Step 2: correlate every binary feature with the Attrition__Left indicator.
data_correlation <- correlate(data_binarized, target = Attrition__Left)

data_correlation
## # A tibble: 120 × 3
##    feature           bin       correlation
##    <fct>             <chr>           <dbl>
##  1 Attrition         Left            1    
##  2 Attrition         No             -1    
##  3 OverTime          Yes             0.246
##  4 OverTime          No             -0.246
##  5 JobLevel          1               0.213
##  6 MonthlyIncome     -Inf_2911       0.207
##  7 StockOptionLevel  0               0.195
##  8 YearsAtCompany    -Inf_3          0.183
##  9 MaritalStatus     Single          0.175
## 10 TotalWorkingYears -Inf_6          0.169
## # ℹ 110 more rows
# Step 3: visualise the feature/target correlations as a funnel plot.
correlationfunnel::plot_correlation_funnel(data_correlation)
## Warning: ggrepel: 72 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps

Model Building

Split data

library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
## ✔ broom        1.0.5     ✔ rsample      1.2.1
## ✔ dials        1.2.1     ✔ tune         1.2.1
## ✔ infer        1.0.7     ✔ workflows    1.1.4
## ✔ modeldata    1.4.0     ✔ workflowsets 1.1.0
## ✔ parsnip      1.2.1     ✔ yardstick    1.3.1
## ✔ recipes      1.1.0
## Warning: package 'modeldata' was built under R version 4.3.3
## Warning: package 'recipes' was built under R version 4.3.3
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter()   masks stats::filter()
## ✖ recipes::fixed()  masks stringr::fixed()
## ✖ dplyr::lag()      masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step()   masks stats::step()
## • Dig deeper into tidy modeling with R at https://www.tmwr.org
# Reproducible train/test split, stratified on the outcome so both sets keep
# a similar Attrition class balance.
set.seed(1234)
data_split <- initial_split(data_clean, strata = Attrition)
data_train <- training(data_split)
data_test  <- testing(data_split)

# Cross-validation folds on the training set only, also stratified.
data_cv <- rsample::vfold_cv(data_train, strata = Attrition)
data_cv
## #  10-fold cross-validation using stratification 
## # A tibble: 10 × 2
##    splits            id    
##    <list>            <chr> 
##  1 <split [990/111]> Fold01
##  2 <split [990/111]> Fold02
##  3 <split [990/111]> Fold03
##  4 <split [990/111]> Fold04
##  5 <split [991/110]> Fold05
##  6 <split [991/110]> Fold06
##  7 <split [991/110]> Fold07
##  8 <split [992/109]> Fold08
##  9 <split [992/109]> Fold09
## 10 <split [992/109]> Fold10

Preprocess data

library(themis)

# Preprocessing recipe for the xgboost model:
# - EmployeeNumber stays in the data but gets the "ID" role, so it is not used
#   as a predictor.
# - step_dummy() one-hot encodes all nominal predictors. xgboost needs numeric
#   input, and SMOTE interpolates between numeric predictor values.
# - step_smote() (themis) oversamples the minority Attrition class. themis
#   documents skip = TRUE by default for this step, so it should only affect
#   the training data, not new data at bake time. Confirm for your version.
xgboost_rec <- recipes::recipe(Attrition ~ ., data = data_train) %>%
  update_role(EmployeeNumber, new_role = "ID") %>%
  step_dummy(all_nominal_predictors()) %>%
  step_smote(Attrition)

# Inspect the preprocessed training set. bake(new_data = NULL) returns the
# processed training data and replaces the superseded juice().
xgboost_rec %>%
  prep() %>%
  bake(new_data = NULL) %>%
  glimpse()
## Rows: 1,848
## Columns: 59
## $ Age                               <dbl> 37, 36, 32, 24, 50, 26, 41, 48, 36, …
## $ DailyRate                         <dbl> 1373, 1218, 1125, 813, 869, 1357, 13…
## $ DistanceFromHome                  <dbl> 2, 9, 16, 1, 3, 25, 12, 1, 9, 6, 6, …
## $ EmployeeNumber                    <dbl> 4, 27, 33, 45, 47, 55, 58, 64, 90, 1…
## $ HourlyRate                        <dbl> 92, 82, 72, 61, 86, 48, 49, 98, 79, …
## $ JobLevel                          <dbl> 1, 1, 1, 1, 1, 1, 5, 3, 1, 1, 1, 2, …
## $ MonthlyIncome                     <dbl> 2090, 3407, 3919, 2293, 2683, 2293, …
## $ MonthlyRate                       <dbl> 2396, 6986, 4681, 3020, 3810, 10558,…
## $ NumCompaniesWorked                <dbl> 6, 7, 1, 2, 1, 1, 1, 9, 0, 4, 1, 1, …
## $ PercentSalaryHike                 <dbl> 15, 23, 22, 16, 14, 12, 12, 13, 17, …
## $ StockOptionLevel                  <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, …
## $ TotalWorkingYears                 <dbl> 7, 10, 10, 6, 3, 1, 23, 23, 2, 7, 1,…
## $ TrainingTimesLastYear             <dbl> 3, 4, 5, 2, 2, 2, 0, 2, 0, 3, 5, 1, …
## $ YearsAtCompany                    <dbl> 0, 5, 10, 2, 3, 1, 22, 1, 1, 3, 1, 6…
## $ YearsInCurrentRole                <dbl> 0, 3, 2, 0, 2, 0, 15, 0, 0, 2, 0, 4,…
## $ YearsSinceLastPromotion           <dbl> 0, 0, 6, 2, 0, 0, 15, 0, 0, 0, 1, 0,…
## $ YearsWithCurrManager              <dbl> 0, 3, 7, 0, 2, 1, 8, 0, 0, 2, 0, 3, …
## $ Attrition                         <fct> Left, Left, Left, Left, Left, Left, …
## $ BusinessTravel_Travel_Frequently  <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ BusinessTravel_Travel_Rarely      <dbl> 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, …
## $ Department_Research...Development <dbl> 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, …
## $ Department_Sales                  <dbl> 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, …
## $ Education_X2                      <dbl> 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ Education_X3                      <dbl> 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, …
## $ Education_X4                      <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ Education_X5                      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField_Life.Sciences      <dbl> 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, …
## $ EducationField_Marketing          <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, …
## $ EducationField_Medical            <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, …
## $ EducationField_Other              <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField_Technical.Degree   <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ EnvironmentSatisfaction_X2        <dbl> 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ EnvironmentSatisfaction_X3        <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, …
## $ EnvironmentSatisfaction_X4        <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, …
## $ Gender_Male                       <dbl> 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, …
## $ JobInvolvement_X2                 <dbl> 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, …
## $ JobInvolvement_X3                 <dbl> 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, …
## $ JobInvolvement_X4                 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole_Human.Resources           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobRole_Laboratory.Technician     <dbl> 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, …
## $ JobRole_Manager                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole_Manufacturing.Director    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole_Research.Director         <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ JobRole_Research.Scientist        <dbl> 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, …
## $ JobRole_Sales.Executive           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ JobRole_Sales.Representative      <dbl> 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ JobSatisfaction_X2                <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobSatisfaction_X3                <dbl> 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, …
## $ JobSatisfaction_X4                <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ MaritalStatus_Married             <dbl> 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, …
## $ MaritalStatus_Single              <dbl> 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, …
## $ OverTime_Yes                      <dbl> 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, …
## $ PerformanceRating_X4              <dbl> 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ RelationshipSatisfaction_X2       <dbl> 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ RelationshipSatisfaction_X3       <dbl> 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, …
## $ RelationshipSatisfaction_X4       <dbl> 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, …
## $ WorkLifeBalance_X2                <dbl> 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, …
## $ WorkLifeBalance_X3                <dbl> 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, …
## $ WorkLifeBalance_X4                <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …

Specify model

# Boosted-tree classifier on the xgboost engine. Only the number of trees is
# tuned; all other hyperparameters keep their defaults.
xgboost_spec <- boost_tree(trees = tune()) %>%
  set_engine("xgboost") %>%
  set_mode("classification")

# Bundle the preprocessing recipe and the model spec into one workflow.
xgboost_workflow <- workflow() %>%
  add_model(xgboost_spec) %>%
  add_recipe(xgboost_rec)

Tune hyperparameters

# Register a foreach parallel backend so tune_grid() can fit resamples in
# parallel. The trailing () is required. Without it, R only prints the
# function's source code and no backend is registered, so tuning runs
# sequentially.
doParallel::registerDoParallel()
# Evaluate 5 candidate `trees` values on every CV fold, keeping the
# out-of-fold predictions so ROC curves can be drawn afterwards.
set.seed(13500)
xgboost_tune <- xgboost_workflow %>%
  tune_grid(
    resamples = data_cv,
    grid      = 5,
    control   = control_grid(save_pred = TRUE)
  )
## Warning: package 'xgboost' was built under R version 4.3.3

Model Evaluation

Identify optimal values for hyperparameters

# Mean cross-validated metrics for each candidate `trees` value.
xgboost_tune %>% collect_metrics()
## # A tibble: 15 × 7
##    trees .metric     .estimator  mean     n std_err .config             
##    <int> <chr>       <chr>      <dbl> <int>   <dbl> <chr>               
##  1   167 accuracy    binary     0.872    10 0.00917 Preprocessor1_Model1
##  2   167 brier_class binary     0.106    10 0.00724 Preprocessor1_Model1
##  3   167 roc_auc     binary     0.802    10 0.0221  Preprocessor1_Model1
##  4   515 accuracy    binary     0.867    10 0.0102  Preprocessor1_Model2
##  5   515 brier_class binary     0.108    10 0.00752 Preprocessor1_Model2
##  6   515 roc_auc     binary     0.806    10 0.0210  Preprocessor1_Model2
##  7   886 accuracy    binary     0.867    10 0.0109  Preprocessor1_Model3
##  8   886 brier_class binary     0.109    10 0.00753 Preprocessor1_Model3
##  9   886 roc_auc     binary     0.806    10 0.0210  Preprocessor1_Model3
## 10  1253 accuracy    binary     0.868    10 0.00955 Preprocessor1_Model4
## 11  1253 brier_class binary     0.109    10 0.00738 Preprocessor1_Model4
## 12  1253 roc_auc     binary     0.805    10 0.0212  Preprocessor1_Model4
## 13  1863 accuracy    binary     0.868    10 0.00955 Preprocessor1_Model5
## 14  1863 brier_class binary     0.109    10 0.00728 Preprocessor1_Model5
## 15  1863 roc_auc     binary     0.804    10 0.0211  Preprocessor1_Model5
# One ROC curve per CV fold, using the saved out-of-fold probabilities for
# the "Left" class.
xgboost_tune %>%
  collect_predictions() %>%
  group_by(id) %>%
  roc_curve(truth = Attrition, .pred_Left) %>%
  autoplot()

Fit the final model on the training set and evaluate it on the test set

# Plug the `trees` value with the best mean CV accuracy into the workflow,
# refit on the full training set, and score once on the held-out test set.
# NOTE(review): the outcome is imbalanced, so accuracy can favour the majority
# class. Consider metric = "roc_auc" for selection instead.
xgboost_last <- finalize_workflow(
  xgboost_workflow,
  select_best(xgboost_tune, metric = "accuracy")
) %>%
  last_fit(split = data_split)

xgboost_last %>% collect_metrics()
## # A tibble: 3 × 4
##   .metric     .estimator .estimate .config             
##   <chr>       <chr>          <dbl> <chr>               
## 1 accuracy    binary         0.854 Preprocessor1_Model1
## 2 roc_auc     binary         0.752 Preprocessor1_Model1
## 3 brier_class binary         0.123 Preprocessor1_Model1
# Confusion matrix of predicted vs. actual Attrition on the test set.
xgboost_last %>%
  collect_predictions() %>%
  yardstick::conf_mat(truth = Attrition, estimate = .pred_class) %>%
  autoplot()

Variable importance

library(vip)
## 
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
## 
##     vi
# Variable-importance plot from the engine-level xgboost fit inside the
# final workflow.
vip(workflows::extract_fit_engine(xgboost_last))