library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.4     ✔ readr     2.1.5
## ✔ forcats   1.0.0     ✔ stringr   1.5.1
## ✔ ggplot2   3.5.1     ✔ tibble    3.2.1
## ✔ lubridate 1.9.3     ✔ tidyr     1.3.1
## ✔ purrr     1.0.2     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(correlationfunnel)
## ══ Using correlationfunnel? ════════════════════════════════════════════════════
## You might also be interested in applied data science training for business.
## </> Learn more at - www.business-science.io </>

Import data

# Load the HR employee attrition CSV; column types are left for readr to guess.
data <- readr::read_csv(
    file = "../00_data/WA_Fn-UseC_-HR-Employee-Attrition.csv"
)
## Rows: 1470 Columns: 35
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (9): Attrition, BusinessTravel, Department, EducationField, Gender, Job...
## dbl (26): Age, DailyRate, DistanceFromHome, Education, EmployeeCount, Employ...
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.

Explore data

# Summarise every column (missingness, distinct values, distribution) to spot
# data-quality issues before cleaning.
skimr::skim(data)
Data summary
Name data
Number of rows 1470
Number of columns 35
_______________________
Column type frequency:
character 9
numeric 26
________________________
Group variables None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
Attrition 0 1 2 3 0 2 0
BusinessTravel 0 1 10 17 0 3 0
Department 0 1 5 22 0 3 0
EducationField 0 1 5 16 0 6 0
Gender 0 1 4 6 0 2 0
JobRole 0 1 7 25 0 9 0
MaritalStatus 0 1 6 8 0 3 0
Over18 0 1 1 1 0 1 0
OverTime 0 1 2 3 0 2 0

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
Age 0 1 36.92 9.14 18 30.00 36.0 43.00 60 ▂▇▇▃▂
DailyRate 0 1 802.49 403.51 102 465.00 802.0 1157.00 1499 ▇▇▇▇▇
DistanceFromHome 0 1 9.19 8.11 1 2.00 7.0 14.00 29 ▇▅▂▂▂
Education 0 1 2.91 1.02 1 2.00 3.0 4.00 5 ▂▃▇▆▁
EmployeeCount 0 1 1.00 0.00 1 1.00 1.0 1.00 1 ▁▁▇▁▁
EmployeeNumber 0 1 1024.87 602.02 1 491.25 1020.5 1555.75 2068 ▇▇▇▇▇
EnvironmentSatisfaction 0 1 2.72 1.09 1 2.00 3.0 4.00 4 ▅▅▁▇▇
HourlyRate 0 1 65.89 20.33 30 48.00 66.0 83.75 100 ▇▇▇▇▇
JobInvolvement 0 1 2.73 0.71 1 2.00 3.0 3.00 4 ▁▃▁▇▁
JobLevel 0 1 2.06 1.11 1 1.00 2.0 3.00 5 ▇▇▃▂▁
JobSatisfaction 0 1 2.73 1.10 1 2.00 3.0 4.00 4 ▅▅▁▇▇
MonthlyIncome 0 1 6502.93 4707.96 1009 2911.00 4919.0 8379.00 19999 ▇▅▂▁▂
MonthlyRate 0 1 14313.10 7117.79 2094 8047.00 14235.5 20461.50 26999 ▇▇▇▇▇
NumCompaniesWorked 0 1 2.69 2.50 0 1.00 2.0 4.00 9 ▇▃▂▂▁
PercentSalaryHike 0 1 15.21 3.66 11 12.00 14.0 18.00 25 ▇▅▃▂▁
PerformanceRating 0 1 3.15 0.36 3 3.00 3.0 3.00 4 ▇▁▁▁▂
RelationshipSatisfaction 0 1 2.71 1.08 1 2.00 3.0 4.00 4 ▅▅▁▇▇
StandardHours 0 1 80.00 0.00 80 80.00 80.0 80.00 80 ▁▁▇▁▁
StockOptionLevel 0 1 0.79 0.85 0 0.00 1.0 1.00 3 ▇▇▁▂▁
TotalWorkingYears 0 1 11.28 7.78 0 6.00 10.0 15.00 40 ▇▇▂▁▁
TrainingTimesLastYear 0 1 2.80 1.29 0 2.00 3.0 3.00 6 ▂▇▇▂▃
WorkLifeBalance 0 1 2.76 0.71 1 2.00 3.0 3.00 4 ▁▃▁▇▂
YearsAtCompany 0 1 7.01 6.13 0 3.00 5.0 9.00 40 ▇▂▁▁▁
YearsInCurrentRole 0 1 4.23 3.62 0 2.00 3.0 7.00 18 ▇▃▂▁▁
YearsSinceLastPromotion 0 1 2.19 3.22 0 0.00 1.0 3.00 15 ▇▁▁▁▁
YearsWithCurrManager 0 1 4.12 3.57 0 2.00 3.0 7.00 17 ▇▂▅▁▁

Issues with the data:

- Missing values: none found (every column has `n_missing = 0`).
- Factors stored as numeric variables: Education, EnvironmentSatisfaction, JobInvolvement, JobSatisfaction, PerformanceRating, RelationshipSatisfaction, WorkLifeBalance.
- Zero-variance variables: Over18, EmployeeCount, StandardHours.
- Character variables: convert to numbers in a recipe step.
- Unbalanced target variable: Attrition.
- ID variable: EmployeeNumber.

# Level-coded columns that were imported as numeric but are really categorical.
factors_vec <- c(
    "Education", "EnvironmentSatisfaction", "JobInvolvement",
    "JobSatisfaction", "PerformanceRating", "RelationshipSatisfaction",
    "WorkLifeBalance", "JobLevel", "StockOptionLevel"
)

data_clean <- data %>%
    # treat the level-coded numeric columns as factors
    mutate(across(all_of(factors_vec), as.factor)) %>%
    # remove the constant (zero-variance) columns
    select(-Over18, -EmployeeCount, -StandardHours) %>%
    # relabel the target: "Yes" becomes "Left", every other value is kept
    mutate(Attrition = if_else(Attrition == "Yes", "Left", Attrition))

Explore data

# Class balance of the target variable
count(data_clean, Attrition)
## # A tibble: 2 × 2
##   Attrition     n
##   <chr>     <int>
## 1 Left        237
## 2 No         1233
# Bar chart of the number of employees in each Attrition class
ggplot(data_clean, aes(x = Attrition)) +
    geom_bar()

attrition vs monthly income

# Distribution of monthly income within each Attrition class
ggplot(data_clean, aes(x = Attrition, y = MonthlyIncome)) +
    geom_boxplot()

correlation plot

# step 1: binarize
# Drop the ID column, then turn every remaining column into 0/1 indicator
# columns (numeric columns are split into bins first).
data_binarized <- binarize(select(data_clean, -EmployeeNumber))

glimpse(data_binarized)
## Rows: 1,470
## Columns: 120
## $ `Age__-Inf_30`                       <dbl> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, …
## $ Age__30_36                           <dbl> 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, …
## $ Age__36_43                           <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ Age__43_Inf                          <dbl> 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ Attrition__Left                      <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ Attrition__No                        <dbl> 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ `BusinessTravel__Non-Travel`         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ BusinessTravel__Travel_Frequently    <dbl> 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ BusinessTravel__Travel_Rarely        <dbl> 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, …
## $ `DailyRate__-Inf_465`                <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ DailyRate__465_802                   <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ DailyRate__802_1157                  <dbl> 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, …
## $ DailyRate__1157_Inf                  <dbl> 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, …
## $ Department__Human_Resources          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `Department__Research_&_Development` <dbl> 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ Department__Sales                    <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `DistanceFromHome__-Inf_2`           <dbl> 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, …
## $ DistanceFromHome__2_7                <dbl> 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ DistanceFromHome__7_14               <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ DistanceFromHome__14_Inf             <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, …
## $ Education__1                         <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, …
## $ Education__2                         <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ Education__3                         <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, …
## $ Education__4                         <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ Education__5                         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Human_Resources      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Life_Sciences        <dbl> 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, …
## $ EducationField__Marketing            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Medical              <dbl> 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, …
## $ EducationField__Other                <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Technical_Degree     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EnvironmentSatisfaction__1           <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, …
## $ EnvironmentSatisfaction__2           <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EnvironmentSatisfaction__3           <dbl> 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, …
## $ EnvironmentSatisfaction__4           <dbl> 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, …
## $ Gender__Female                       <dbl> 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ Gender__Male                         <dbl> 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, …
## $ `HourlyRate__-Inf_48`                <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, …
## $ HourlyRate__48_66                    <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ HourlyRate__66_83.75                 <dbl> 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, …
## $ HourlyRate__83.75_Inf                <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, …
## $ JobInvolvement__1                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobInvolvement__2                    <dbl> 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobInvolvement__3                    <dbl> 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, …
## $ JobInvolvement__4                    <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, …
## $ JobLevel__1                          <dbl> 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, …
## $ JobLevel__2                          <dbl> 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ JobLevel__3                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobLevel__4                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobLevel__5                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Healthcare_Representative   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ JobRole__Human_Resources             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Laboratory_Technician       <dbl> 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, …
## $ JobRole__Manager                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Manufacturing_Director      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobRole__Research_Director           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Research_Scientist          <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Sales_Executive             <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Sales_Representative        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobSatisfaction__1                   <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ JobSatisfaction__2                   <dbl> 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, …
## $ JobSatisfaction__3                   <dbl> 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, …
## $ JobSatisfaction__4                   <dbl> 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ MaritalStatus__Divorced              <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, …
## $ MaritalStatus__Married               <dbl> 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, …
## $ MaritalStatus__Single                <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, …
## $ `MonthlyIncome__-Inf_2911`           <dbl> 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, …
## $ MonthlyIncome__2911_4919             <dbl> 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, …
## $ MonthlyIncome__4919_8379             <dbl> 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ MonthlyIncome__8379_Inf              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ `MonthlyRate__-Inf_8047`             <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ MonthlyRate__8047_14235.5            <dbl> 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, …
## $ MonthlyRate__14235.5_20461.5         <dbl> 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, …
## $ MonthlyRate__20461.5_Inf             <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ `NumCompaniesWorked__-Inf_1`         <dbl> 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, …
## $ NumCompaniesWorked__1_2              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ NumCompaniesWorked__2_4              <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ NumCompaniesWorked__4_Inf            <dbl> 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, …
## $ OverTime__No                         <dbl> 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, …
## $ OverTime__Yes                        <dbl> 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ `PercentSalaryHike__-Inf_12`         <dbl> 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, …
## $ PercentSalaryHike__12_14             <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, …
## $ PercentSalaryHike__14_18             <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ PercentSalaryHike__18_Inf            <dbl> 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, …
## $ PerformanceRating__3                 <dbl> 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, …
## $ PerformanceRating__4                 <dbl> 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, …
## $ RelationshipSatisfaction__1          <dbl> 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ RelationshipSatisfaction__2          <dbl> 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, …
## $ RelationshipSatisfaction__3          <dbl> 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, …
## $ RelationshipSatisfaction__4          <dbl> 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ StockOptionLevel__0                  <dbl> 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ StockOptionLevel__1                  <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, …
## $ StockOptionLevel__2                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ StockOptionLevel__3                  <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ `TotalWorkingYears__-Inf_6`          <dbl> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, …
## $ TotalWorkingYears__6_10              <dbl> 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ TotalWorkingYears__10_15             <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ TotalWorkingYears__15_Inf            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ `TrainingTimesLastYear__-Inf_2`      <dbl> 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, …
## $ TrainingTimesLastYear__2_3           <dbl> 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, …
## $ TrainingTimesLastYear__3_Inf         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ WorkLifeBalance__1                   <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ WorkLifeBalance__2                   <dbl> 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, …
## $ WorkLifeBalance__3                   <dbl> 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, …
## $ WorkLifeBalance__4                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsAtCompany__-Inf_3`             <dbl> 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsAtCompany__3_5                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ YearsAtCompany__5_9                  <dbl> 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, …
## $ YearsAtCompany__9_Inf                <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsInCurrentRole__-Inf_2`         <dbl> 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsInCurrentRole__2_3              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ YearsInCurrentRole__3_7              <dbl> 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, …
## $ YearsInCurrentRole__7_Inf            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsSinceLastPromotion__-Inf_1`    <dbl> 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, …
## $ YearsSinceLastPromotion__1_3         <dbl> 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, …
## $ YearsSinceLastPromotion__3_Inf       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ `YearsWithCurrManager__-Inf_2`       <dbl> 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsWithCurrManager__2_3            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ YearsWithCurrManager__3_7            <dbl> 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, …
## $ YearsWithCurrManager__7_Inf          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
# step 2: correlation
# Correlate every indicator column with the "Left" indicator of the target.
data_correlation <- correlate(data_binarized, target = Attrition__Left)

data_correlation
## # A tibble: 120 × 3
##    feature           bin       correlation
##    <fct>             <chr>           <dbl>
##  1 Attrition         Left            1    
##  2 Attrition         No             -1    
##  3 OverTime          Yes             0.246
##  4 OverTime          No             -0.246
##  5 JobLevel          1               0.213
##  6 MonthlyIncome     -Inf_2911       0.207
##  7 StockOptionLevel  0               0.195
##  8 YearsAtCompany    -Inf_3          0.183
##  9 MaritalStatus     Single          0.175
## 10 TotalWorkingYears -Inf_6          0.169
## # ℹ 110 more rows
# step 3: visualise the correlations as a funnel plot
correlationfunnel::plot_correlation_funnel(data_correlation)
## Warning: ggrepel: 72 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps

Model Building

Split data

library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
## ✔ broom        1.0.5     ✔ rsample      1.2.1
## ✔ dials        1.2.1     ✔ tune         1.2.1
## ✔ infer        1.0.7     ✔ workflows    1.1.4
## ✔ modeldata    1.4.0     ✔ workflowsets 1.1.0
## ✔ parsnip      1.2.1     ✔ yardstick    1.3.1
## ✔ recipes      1.1.0
## Warning: package 'modeldata' was built under R version 4.3.3
## Warning: package 'recipes' was built under R version 4.3.3
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter()   masks stats::filter()
## ✖ recipes::fixed()  masks stringr::fixed()
## ✖ dplyr::lag()      masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step()   masks stats::step()
## • Use suppressPackageStartupMessages() to eliminate package startup messages
# Fix the RNG seed so the random train/test split and the CV fold assignment
# below are the same on every run; without it, every downstream metric changes
# between renders.
set.seed(1234)
# presumably a down-sampling switch for quick test runs; kept disabled
# data_clean <- data_clean %>% sample_n(100)

# Train/test split, stratified on the target so both parts keep the same
# Attrition class proportions.
data_split <- initial_split(data_clean, strata = Attrition)
data_train <- training(data_split)
data_test <- testing(data_split)


# Stratified cross-validation folds built from the training set only
# (rsample's default number of folds).
data_cv <- rsample::vfold_cv(data_train, strata = Attrition)

data_cv
## #  10-fold cross-validation using stratification 
## # A tibble: 10 × 2
##    splits            id    
##    <list>            <chr> 
##  1 <split [990/111]> Fold01
##  2 <split [990/111]> Fold02
##  3 <split [990/111]> Fold03
##  4 <split [990/111]> Fold04
##  5 <split [991/110]> Fold05
##  6 <split [991/110]> Fold06
##  7 <split [991/110]> Fold07
##  8 <split [992/109]> Fold08
##  9 <split [992/109]> Fold09
## 10 <split [992/109]> Fold10

Preprocess data

library(themis)

# Preprocessing recipe for the boosted-tree model, defined on the training set.
xgboost_rec <- recipes::recipe(Attrition ~ ., data = data_train) %>%
    # keep the employee identifier in the data but out of the predictor set
    update_role(EmployeeNumber, new_role = "ID") %>%
    # expand nominal predictors into indicator columns
    step_dummy(all_nominal_predictors()) %>%
    # Yeo-Johnson transformation of the selected numeric columns
    step_YeoJohnson(
        DistanceFromHome,
        MonthlyIncome,
        NumCompaniesWorked,
        PercentSalaryHike,
        TotalWorkingYears,
        starts_with("Years")
    ) %>%
    # normalize all numeric predictors (the indicator columns included)
    step_normalize(all_numeric_predictors()) %>%
    # SMOTE step on the target to balance the Attrition classes
    step_smote(Attrition)

# Inspect the processed training data produced by the recipe
glimpse(juice(prep(xgboost_rec)))
## Rows: 1,848
## Columns: 64
## $ Age                               <dbl> -0.96288315, -0.08301234, -0.3029800…
## $ DailyRate                         <dbl> -1.73653458, 1.04853941, -0.24783135…
## $ DistanceFromHome                  <dbl> 1.46553946, 0.36168312, -0.05930279,…
## $ EmployeeNumber                    <dbl> 19, 27, 31, 33, 42, 45, 47, 55, 58, …
## $ HourlyRate                        <dbl> -0.77377444, 0.78849611, 0.83731707,…
## $ MonthlyIncome                     <dbl> -1.48144316, -0.58066180, -0.8158669…
## $ MonthlyRate                       <dbl> -0.18890361, -1.02073577, 0.39090894…
## $ NumCompaniesWorked                <dbl> 1.08105974, 1.48217820, 0.07118277, …
## $ PercentSalaryHike                 <dbl> -0.0902833, 1.6669571, -1.4764400, 1…
## $ TotalWorkingYears                 <dbl> -0.58179808, 0.06140606, -0.23531869…
## $ TrainingTimesLastYear             <dbl> 0.9165158, 0.9165158, -0.6407124, 1.…
## $ YearsAtCompany                    <dbl> -0.30625938, -0.06705598, -0.3062593…
## $ YearsInCurrentRole                <dbl> -0.45417380, -0.09863849, -0.4541738…
## $ YearsSinceLastPromotion           <dbl> -1.09508987, -1.09508987, 0.07678726…
## $ YearsWithCurrManager              <dbl> -0.0593181, -0.0593181, -0.0593181, …
## $ Attrition                         <fct> Left, Left, Left, Left, Left, Left, …
## $ BusinessTravel_Travel_Frequently  <dbl> -0.4637201, -0.4637201, -0.4637201, …
## $ BusinessTravel_Travel_Rarely      <dbl> 0.617320, 0.617320, 0.617320, -1.618…
## $ Department_Research...Development <dbl> 0.7256494, -1.3768243, 0.7256494, 0.…
## $ Department_Sales                  <dbl> -0.655347, 1.524523, -0.655347, -0.6…
## $ Education_X2                      <dbl> -0.4881123, -0.4881123, -0.4881123, …
## $ Education_X3                      <dbl> 1.2606124, -0.7925448, -0.7925448, -…
## $ Education_X4                      <dbl> -0.6075092, 1.6445705, -0.6075092, -…
## $ Education_X5                      <dbl> -0.1889852, -0.1889852, -0.1889852, …
## $ EducationField_Life.Sciences      <dbl> 1.1887762, 1.1887762, -0.8404372, 1.…
## $ EducationField_Marketing          <dbl> -0.3446657, -0.3446657, -0.3446657, …
## $ EducationField_Medical            <dbl> -0.6652713, -0.6652713, 1.5017810, -…
## $ EducationField_Other              <dbl> -0.2503684, -0.2503684, -0.2503684, …
## $ EducationField_Technical.Degree   <dbl> -0.3245396, -0.3245396, -0.3245396, …
## $ EnvironmentSatisfaction_X2        <dbl> -0.4994892, -0.4994892, 2.0002271, 2…
## $ EnvironmentSatisfaction_X3        <dbl> 1.4985813, 1.4985813, -0.6666917, -0…
## $ EnvironmentSatisfaction_X4        <dbl> -0.6652713, -0.6652713, -0.6652713, …
## $ Gender_Male                       <dbl> 0.8139654, 0.8139654, 0.8139654, -1.…
## $ JobInvolvement_X2                 <dbl> 1.7239569, 1.7239569, -0.5795341, -0…
## $ JobInvolvement_X3                 <dbl> -1.2297657, -1.2297657, 0.8124244, -…
## $ JobInvolvement_X4                 <dbl> -0.3228284, -0.3228284, -0.3228284, …
## $ JobLevel_X2                       <dbl> -0.7639647, -0.7639647, -0.7639647, …
## $ JobLevel_X3                       <dbl> -0.4091432, -0.4091432, -0.4091432, …
## $ JobLevel_X4                       <dbl> -0.2721746, -0.2721746, -0.2721746, …
## $ JobLevel_X5                       <dbl> -0.2202892, -0.2202892, -0.2202892, …
## $ JobRole_Human.Resources           <dbl> -0.1990579, -0.1990579, -0.1990579, …
## $ JobRole_Laboratory.Technician     <dbl> 2.1748735, -0.4593792, -0.4593792, -…
## $ JobRole_Manager                   <dbl> -0.2740931, -0.2740931, -0.2740931, …
## $ JobRole_Manufacturing.Director    <dbl> -0.3363671, -0.3363671, -0.3363671, …
## $ JobRole_Research.Director         <dbl> -0.218015, -0.218015, -0.218015, -0.…
## $ JobRole_Research.Scientist        <dbl> -0.5023249, -0.5023249, 1.9889352, 1…
## $ JobRole_Sales.Executive           <dbl> -0.5277239, -0.5277239, -0.5277239, …
## $ JobRole_Sales.Representative      <dbl> -0.2483152, 4.0234821, -0.2483152, -…
## $ JobSatisfaction_X2                <dbl> -0.4881123, -0.4881123, -0.4881123, …
## $ JobSatisfaction_X3                <dbl> 1.5478921, -0.6454531, -0.6454531, -…
## $ JobSatisfaction_X4                <dbl> -0.6723805, -0.6723805, -0.6723805, …
## $ MaritalStatus_Married             <dbl> -0.8967647, -0.8967647, -0.8967647, …
## $ MaritalStatus_Single              <dbl> 1.4340376, 1.4340376, 1.4340376, 1.4…
## $ OverTime_Yes                      <dbl> 1.568448, -0.636994, -0.636994, 1.56…
## $ PerformanceRating_X4              <dbl> -0.4241471, 2.3555312, -0.4241471, 2…
## $ RelationshipSatisfaction_X2       <dbl> 1.9136157, 1.9136157, -0.5220963, 1.…
## $ RelationshipSatisfaction_X3       <dbl> -0.6766549, -0.6766549, 1.4765158, -…
## $ RelationshipSatisfaction_X4       <dbl> -0.6341783, -0.6341783, -0.6341783, …
## $ StockOptionLevel_X1               <dbl> -0.8263567, -0.8263567, -0.8263567, …
## $ StockOptionLevel_X2               <dbl> -0.3397008, -0.3397008, -0.3397008, …
## $ StockOptionLevel_X3               <dbl> -0.2503684, -0.2503684, -0.2503684, …
## $ WorkLifeBalance_X2                <dbl> -0.5613606, -0.5613606, -0.5613606, …
## $ WorkLifeBalance_X3                <dbl> 0.8124244, 0.8124244, 0.8124244, 0.8…
## $ WorkLifeBalance_X4                <dbl> -0.334693, -0.334693, -0.334693, -0.…

Specify model

# XGBoost classifier; the four arguments set to tune() are left open for
# hyperparameter tuning.
xgboost_spec <- boost_tree(
    mtry = tune(),
    trees = tune(),
    tree_depth = tune(),
    learn_rate = tune()
) %>%
    set_engine("xgboost") %>%
    set_mode("classification")

# Bundle the model specification and the preprocessing recipe.
xgboost_workflow <- workflow() %>%
    add_model(xgboost_spec) %>%
    add_recipe(xgboost_rec)

Tune hyperparameters

# Tune the hyperparameters marked with tune() across the CV folds.
# `grid = 5` lets tune_grid() generate 5 candidate combinations covering all
# tuned parameters (mtry, trees, tree_depth, learn_rate), so no hand-built
# grid object is needed here.

# Register a parallel backend for the resampling fits.
doParallel::registerDoParallel()

# Seed the candidate generation and model fitting.
set.seed(13500)
xgboost_tune <-
  tune_grid(xgboost_workflow,
            resamples = data_cv,
            grid = 5,
            # keep the held-out predictions so ROC curves can be drawn later
            control = control_grid(save_pred = TRUE))
## i Creating pre-processing data to finalize unknown parameter: mtry
## Warning: package 'xgboost' was built under R version 4.3.3

Model Evaluation

Identify optimal values for hyperparameters

# Resampled performance of every tuning candidate
xgboost_tune %>% collect_metrics()
## # A tibble: 15 × 10
##     mtry trees tree_depth learn_rate .metric     .estimator  mean     n std_err
##    <int> <int>      <int>      <dbl> <chr>       <chr>      <dbl> <int>   <dbl>
##  1     3  1777         11    0.00371 accuracy    binary     0.863    10 0.00552
##  2     3  1777         11    0.00371 brier_class binary     0.100    10 0.00318
##  3     3  1777         11    0.00371 roc_auc     binary     0.813    10 0.0147 
##  4    18   892          4    0.128   accuracy    binary     0.865    10 0.00418
##  5    18   892          4    0.128   brier_class binary     0.114    10 0.00429
##  6    18   892          4    0.128   roc_auc     binary     0.792    10 0.0144 
##  7    33  1557          3    0.0230  accuracy    binary     0.860    10 0.00649
##  8    33  1557          3    0.0230  brier_class binary     0.107    10 0.00476
##  9    33  1557          3    0.0230  roc_auc     binary     0.797    10 0.0186 
## 10    40   244          8    0.0727  accuracy    binary     0.865    10 0.00458
## 11    40   244          8    0.0727  brier_class binary     0.110    10 0.00369
## 12    40   244          8    0.0727  roc_auc     binary     0.792    10 0.00848
## 13    53   633         14    0.00265 accuracy    binary     0.847    10 0.00834
## 14    53   633         14    0.00265 brier_class binary     0.122    10 0.00396
## 15    53   633         14    0.00265 roc_auc     binary     0.744    10 0.0192 
## # ℹ 1 more variable: .config <chr>
# Per-fold ROC curves for the selected candidate only. Without the
# `parameters` filter, collect_predictions() returns the held-out predictions
# of every tuning candidate, so each fold's curve would blend several
# different models. The candidate is chosen with the same metric used for the
# final fit.
xgboost_tune %>%
  collect_predictions(
    parameters = select_best(xgboost_tune, metric = "accuracy")
  ) %>%
  group_by(id) %>%
  roc_curve(Attrition, .pred_Left) %>%
  autoplot()

Fit the model for the last time

# Plug the best candidate (by resampled accuracy) into the workflow, refit on
# the full training set and evaluate once on the held-out test set.
# NOTE(review): the classes are imbalanced (237 "Left" vs 1233 "No"), so
# accuracy may favour the majority class; consider selecting on roc_auc.
xgboost_last <- last_fit(
  finalize_workflow(
    xgboost_workflow,
    select_best(xgboost_tune, metric = "accuracy")
  ),
  split = data_split
)

xgboost_last %>% collect_metrics()
## # A tibble: 3 × 4
##   .metric     .estimator .estimate .config             
##   <chr>       <chr>          <dbl> <chr>               
## 1 accuracy    binary        0.892  Preprocessor1_Model1
## 2 roc_auc     binary        0.798  Preprocessor1_Model1
## 3 brier_class binary        0.0977 Preprocessor1_Model1
# Confusion matrix of the final model's test-set class predictions
xgboost_last %>%
  collect_predictions() %>%
  yardstick::conf_mat(truth = Attrition, estimate = .pred_class) %>%
  autoplot()

Variable importance

library(vip)
## 
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
## 
##     vi
# Variable-importance plot from the underlying fitted engine object
vip(workflows::extract_fit_engine(xgboost_last))

Conclusion

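The previous model had an accuracy of 0.815 and an AUC of 0.890. In the run shown above, this XGBoost model reached a test-set accuracy of 0.892 and an AUC of 0.798, i.e. higher accuracy but lower AUC.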