# Global chunk options: show code, hide package start-up messages
knitr::opts_chunk$set(
    message = FALSE,
    echo = TRUE
)

The goal is to predict attrition, i.e., to identify employees who are likely to leave the company.

Import data

library(tidyverse)

# Location of the HR employee-attrition CSV. Defaults to the original author's
# path; set the HR_ATTRITION_CSV environment variable to read it from elsewhere.
data_path <- Sys.getenv(
    "HR_ATTRITION_CSV",
    unset = "C:/Users/tch30/Desktop/PSU_DAT3100/00_data/WA_Fn-UseC_-HR-Employee-Attrition.csv"
)
if (!file.exists(data_path)) {
    stop("Attrition data not found at '", data_path,
         "'. Set HR_ATTRITION_CSV to the CSV location.", call. = FALSE)
}
data <- read.csv(data_path)

Explore data

# Per-column summary: missingness, distribution and cardinality
data %>% skimr::skim()
Data summary
Name data
Number of rows 1470
Number of columns 35
_______________________
Column type frequency:
character 9
numeric 26
________________________
Group variables None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
Attrition 0 1 2 3 0 2 0
BusinessTravel 0 1 10 17 0 3 0
Department 0 1 5 22 0 3 0
EducationField 0 1 5 16 0 6 0
Gender 0 1 4 6 0 2 0
JobRole 0 1 7 25 0 9 0
MaritalStatus 0 1 6 8 0 3 0
Over18 0 1 1 1 0 1 0
OverTime 0 1 2 3 0 2 0

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
Age 0 1 36.92 9.14 18 30.00 36.0 43.00 60 ▂▇▇▃▂
DailyRate 0 1 802.49 403.51 102 465.00 802.0 1157.00 1499 ▇▇▇▇▇
DistanceFromHome 0 1 9.19 8.11 1 2.00 7.0 14.00 29 ▇▅▂▂▂
Education 0 1 2.91 1.02 1 2.00 3.0 4.00 5 ▂▃▇▆▁
EmployeeCount 0 1 1.00 0.00 1 1.00 1.0 1.00 1 ▁▁▇▁▁
EmployeeNumber 0 1 1024.87 602.02 1 491.25 1020.5 1555.75 2068 ▇▇▇▇▇
EnvironmentSatisfaction 0 1 2.72 1.09 1 2.00 3.0 4.00 4 ▅▅▁▇▇
HourlyRate 0 1 65.89 20.33 30 48.00 66.0 83.75 100 ▇▇▇▇▇
JobInvolvement 0 1 2.73 0.71 1 2.00 3.0 3.00 4 ▁▃▁▇▁
JobLevel 0 1 2.06 1.11 1 1.00 2.0 3.00 5 ▇▇▃▂▁
JobSatisfaction 0 1 2.73 1.10 1 2.00 3.0 4.00 4 ▅▅▁▇▇
MonthlyIncome 0 1 6502.93 4707.96 1009 2911.00 4919.0 8379.00 19999 ▇▅▂▁▂
MonthlyRate 0 1 14313.10 7117.79 2094 8047.00 14235.5 20461.50 26999 ▇▇▇▇▇
NumCompaniesWorked 0 1 2.69 2.50 0 1.00 2.0 4.00 9 ▇▃▂▂▁
PercentSalaryHike 0 1 15.21 3.66 11 12.00 14.0 18.00 25 ▇▅▃▂▁
PerformanceRating 0 1 3.15 0.36 3 3.00 3.0 3.00 4 ▇▁▁▁▂
RelationshipSatisfaction 0 1 2.71 1.08 1 2.00 3.0 4.00 4 ▅▅▁▇▇
StandardHours 0 1 80.00 0.00 80 80.00 80.0 80.00 80 ▁▁▇▁▁
StockOptionLevel 0 1 0.79 0.85 0 0.00 1.0 1.00 3 ▇▇▁▂▁
TotalWorkingYears 0 1 11.28 7.78 0 6.00 10.0 15.00 40 ▇▇▂▁▁
TrainingTimesLastYear 0 1 2.80 1.29 0 2.00 3.0 3.00 6 ▂▇▇▂▃
WorkLifeBalance 0 1 2.76 0.71 1 2.00 3.0 3.00 4 ▁▃▁▇▂
YearsAtCompany 0 1 7.01 6.13 0 3.00 5.0 9.00 40 ▇▂▁▁▁
YearsInCurrentRole 0 1 4.23 3.62 0 2.00 3.0 7.00 18 ▇▃▂▁▁
YearsSinceLastPromotion 0 1 2.19 3.22 0 0.00 1.0 3.00 15 ▇▁▁▁▁
YearsWithCurrManager 0 1 4.12 3.57 0 2.00 3.0 7.00 17 ▇▂▅▁▁
# Integer-coded rating/level columns that should be treated as categories
factors_vec <- names(select(data,
                            Education, EnvironmentSatisfaction,
                            JobInvolvement, PerformanceRating,
                            RelationshipSatisfaction, WorkLifeBalance,
                            JobSatisfaction))

data_clean <- data %>%
    # Over18, EmployeeCount and StandardHours each hold a single value
    select(-Over18, -EmployeeCount, -StandardHours) %>%
    mutate(
        # These codes were imported as numbers; convert them to factors
        across(all_of(factors_vec), as.factor),
        # Relabel the positive class so it reads "Left" instead of "Yes"
        Attrition = if_else(Attrition == "Yes", "Left", Attrition)
    )

Explore data

# Class balance of the target
count(data_clean, Attrition)
##   Attrition    n
## 1      Left  237
## 2        No 1233
# Bar chart of employees per Attrition class
ggplot(data_clean, aes(x = Attrition)) +
    geom_bar()

Attrition vs. monthly income

# Distribution of monthly income within each Attrition class
ggplot(data_clean, aes(x = Attrition, y = MonthlyIncome)) +
    geom_boxplot()

Correlation Funnel

library(correlationfunnel)
library(dplyr)

# Step 1: drop the ID column, then binarize every feature (numeric columns
# are binned, categorical columns one-hot encoded) for the correlation funnel
data_binarized <- binarize(select(data_clean, -EmployeeNumber))

glimpse(data_binarized)
## Rows: 1,470
## Columns: 120
## $ `Age__-Inf_30`                       <dbl> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, …
## $ Age__30_36                           <dbl> 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, …
## $ Age__36_43                           <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ Age__43_Inf                          <dbl> 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ Attrition__Left                      <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ Attrition__No                        <dbl> 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ `BusinessTravel__Non-Travel`         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ BusinessTravel__Travel_Frequently    <dbl> 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ BusinessTravel__Travel_Rarely        <dbl> 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, …
## $ `DailyRate__-Inf_465`                <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ DailyRate__465_802                   <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ DailyRate__802_1157                  <dbl> 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, …
## $ DailyRate__1157_Inf                  <dbl> 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, …
## $ Department__Human_Resources          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `Department__Research_&_Development` <dbl> 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ Department__Sales                    <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `DistanceFromHome__-Inf_2`           <dbl> 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, …
## $ DistanceFromHome__2_7                <dbl> 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ DistanceFromHome__7_14               <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ DistanceFromHome__14_Inf             <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, …
## $ Education__1                         <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, …
## $ Education__2                         <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ Education__3                         <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, …
## $ Education__4                         <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ Education__5                         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Human_Resources      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Life_Sciences        <dbl> 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, …
## $ EducationField__Marketing            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Medical              <dbl> 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, …
## $ EducationField__Other                <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Technical_Degree     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EnvironmentSatisfaction__1           <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, …
## $ EnvironmentSatisfaction__2           <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EnvironmentSatisfaction__3           <dbl> 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, …
## $ EnvironmentSatisfaction__4           <dbl> 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, …
## $ Gender__Female                       <dbl> 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ Gender__Male                         <dbl> 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, …
## $ `HourlyRate__-Inf_48`                <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, …
## $ HourlyRate__48_66                    <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ HourlyRate__66_83.75                 <dbl> 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, …
## $ HourlyRate__83.75_Inf                <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, …
## $ JobInvolvement__1                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobInvolvement__2                    <dbl> 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobInvolvement__3                    <dbl> 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, …
## $ JobInvolvement__4                    <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, …
## $ JobLevel__1                          <dbl> 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, …
## $ JobLevel__2                          <dbl> 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ JobLevel__3                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobLevel__4                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobLevel__5                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Healthcare_Representative   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ JobRole__Human_Resources             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Laboratory_Technician       <dbl> 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, …
## $ JobRole__Manager                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Manufacturing_Director      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobRole__Research_Director           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Research_Scientist          <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Sales_Executive             <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Sales_Representative        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobSatisfaction__1                   <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ JobSatisfaction__2                   <dbl> 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, …
## $ JobSatisfaction__3                   <dbl> 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, …
## $ JobSatisfaction__4                   <dbl> 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ MaritalStatus__Divorced              <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, …
## $ MaritalStatus__Married               <dbl> 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, …
## $ MaritalStatus__Single                <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, …
## $ `MonthlyIncome__-Inf_2911`           <dbl> 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, …
## $ MonthlyIncome__2911_4919             <dbl> 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, …
## $ MonthlyIncome__4919_8379             <dbl> 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ MonthlyIncome__8379_Inf              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ `MonthlyRate__-Inf_8047`             <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ MonthlyRate__8047_14235.5            <dbl> 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, …
## $ MonthlyRate__14235.5_20461.5         <dbl> 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, …
## $ MonthlyRate__20461.5_Inf             <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ `NumCompaniesWorked__-Inf_1`         <dbl> 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, …
## $ NumCompaniesWorked__1_2              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ NumCompaniesWorked__2_4              <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ NumCompaniesWorked__4_Inf            <dbl> 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, …
## $ OverTime__No                         <dbl> 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, …
## $ OverTime__Yes                        <dbl> 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ `PercentSalaryHike__-Inf_12`         <dbl> 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, …
## $ PercentSalaryHike__12_14             <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, …
## $ PercentSalaryHike__14_18             <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ PercentSalaryHike__18_Inf            <dbl> 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, …
## $ PerformanceRating__3                 <dbl> 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, …
## $ PerformanceRating__4                 <dbl> 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, …
## $ RelationshipSatisfaction__1          <dbl> 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ RelationshipSatisfaction__2          <dbl> 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, …
## $ RelationshipSatisfaction__3          <dbl> 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, …
## $ RelationshipSatisfaction__4          <dbl> 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ StockOptionLevel__0                  <dbl> 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ StockOptionLevel__1                  <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, …
## $ StockOptionLevel__2                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ StockOptionLevel__3                  <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ `TotalWorkingYears__-Inf_6`          <dbl> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, …
## $ TotalWorkingYears__6_10              <dbl> 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ TotalWorkingYears__10_15             <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ TotalWorkingYears__15_Inf            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ `TrainingTimesLastYear__-Inf_2`      <dbl> 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, …
## $ TrainingTimesLastYear__2_3           <dbl> 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, …
## $ TrainingTimesLastYear__3_Inf         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ WorkLifeBalance__1                   <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ WorkLifeBalance__2                   <dbl> 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, …
## $ WorkLifeBalance__3                   <dbl> 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, …
## $ WorkLifeBalance__4                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsAtCompany__-Inf_3`             <dbl> 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsAtCompany__3_5                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ YearsAtCompany__5_9                  <dbl> 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, …
## $ YearsAtCompany__9_Inf                <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsInCurrentRole__-Inf_2`         <dbl> 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsInCurrentRole__2_3              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ YearsInCurrentRole__3_7              <dbl> 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, …
## $ YearsInCurrentRole__7_Inf            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsSinceLastPromotion__-Inf_1`    <dbl> 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, …
## $ YearsSinceLastPromotion__1_3         <dbl> 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, …
## $ YearsSinceLastPromotion__3_Inf       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ `YearsWithCurrManager__-Inf_2`       <dbl> 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsWithCurrManager__2_3            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ YearsWithCurrManager__3_7            <dbl> 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, …
## $ YearsWithCurrManager__7_Inf          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
# Step 2: correlate every binarized column with the Attrition__Left flag
data_correlation <- correlate(data_binarized, Attrition__Left)

data_correlation
## # A tibble: 120 × 3
##    feature           bin       correlation
##    <fct>             <chr>           <dbl>
##  1 Attrition         Left            1    
##  2 Attrition         No             -1    
##  3 OverTime          No             -0.246
##  4 OverTime          Yes             0.246
##  5 JobLevel          1               0.213
##  6 MonthlyIncome     -Inf_2911       0.207
##  7 StockOptionLevel  0               0.195
##  8 YearsAtCompany    -Inf_3          0.183
##  9 MaritalStatus     Single          0.175
## 10 TotalWorkingYears -Inf_6          0.169
## # ℹ 110 more rows
# Step 3: plot the correlation funnel
correlationfunnel::plot_correlation_funnel(data_correlation)
## Warning: ggrepel: 73 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps

Split data

library(tidymodels)
## Warning: package 'tidymodels' was built under R version 4.4.3
## Warning: package 'broom' was built under R version 4.4.3
## Warning: package 'parsnip' was built under R version 4.4.3
## Warning: package 'recipes' was built under R version 4.4.3
## Warning: package 'rsample' was built under R version 4.4.3
## Warning: package 'yardstick' was built under R version 4.4.3
set.seed(1235)
# Model on a random 100-row sample. The sample is stored in its own object
# so the full `data` is not overwritten; re-running earlier chunks still sees
# all rows. Modeling starts from the raw `data` (Attrition = "No"/"Yes"), not
# `data_clean`, so the recipes below drop Over18 themselves and the
# predictions carry a `.pred_Yes` column.
data_sample <- data %>% slice_sample(n = 100)

# Stratified train/test split, then stratified 10-fold CV on the training set
data_split <- initial_split(data_sample, strata = Attrition)
data_train <- training(data_split)
data_test <- testing(data_split)

data_cv <- rsample::vfold_cv(data_train, strata = Attrition)
data_cv
## #  10-fold cross-validation using stratification 
## # A tibble: 10 × 2
##    splits         id    
##    <list>         <chr> 
##  1 <split [66/9]> Fold01
##  2 <split [66/9]> Fold02
##  3 <split [67/8]> Fold03
##  4 <split [68/7]> Fold04
##  5 <split [68/7]> Fold05
##  6 <split [68/7]> Fold06
##  7 <split [68/7]> Fold07
##  8 <split [68/7]> Fold08
##  9 <split [68/7]> Fold09
## 10 <split [68/7]> Fold10

Preprocess data

library(themis)
## Warning: package 'themis' was built under R version 4.4.3
# Exploratory preprocessing recipe with a PCA step. It is prepped and
# inspected here; the tuned workflow later uses `xgboost_rec` instead.
data_rec <- recipes::recipe(Attrition ~ ., data = data_train) %>%
    # Keep EmployeeNumber as an identifier rather than a predictor
    update_role(EmployeeNumber, new_role = "ID") %>%
    step_rm(Over18) %>%  # Remove column with only one level
    step_dummy(all_nominal_predictors()) %>%
    # Yeo-Johnson power transform for these (right-skewed) predictors
    step_YeoJohnson(DistanceFromHome, MonthlyIncome, NumCompaniesWorked,
                    PercentSalaryHike, TotalWorkingYears,
                    starts_with("Years")) %>%
    # Drop constant predictors (e.g. EmployeeCount, StandardHours) so
    # step_normalize is not asked to scale by a zero standard deviation
    step_zv(all_predictors()) %>%
    step_normalize(all_numeric_predictors()) %>%
    step_pca(all_numeric_predictors(), threshold = 0.99) %>%
    # Oversample the minority Attrition class with SMOTE
    step_smote(Attrition)

# Inspect the processed training set (bake(new_data = NULL) replaces juice())
data_rec %>% prep() %>% bake(new_data = NULL) %>% glimpse()
## Warning: !  The following columns have zero variance so scaling cannot be used:
##   EmployeeCount and StandardHours.
## ℹ Consider using ?step_zv (`?recipes::step_zv()`) to remove those columns
##   before normalizing.
## Rows: 126
## Columns: 37
## $ EmployeeNumber <int> 495, 1601, 1641, 1055, 1753, 369, 11, 273, 1597, 1598, …
## $ Attrition      <fct> No, No, No, No, No, No, No, No, No, No, No, No, No, No,…
## $ PC01           <dbl> -3.034029407, -4.130699378, 0.492587693, -3.723703038, …
## $ PC02           <dbl> -1.1912994, -0.5664410, -2.2880007, 2.3561060, -0.16946…
## $ PC03           <dbl> -1.351268944, 0.608494938, -0.986935320, 2.807425438, -…
## $ PC04           <dbl> 0.5735383, -0.8322231, 0.5145759, -0.3871477, 1.0583880…
## $ PC05           <dbl> 0.73028382, -1.57129530, 2.40372912, 2.89641232, 2.8478…
## $ PC06           <dbl> -1.57104934, 0.35728435, -1.41383257, -1.17299261, 0.32…
## $ PC07           <dbl> -0.178665988, -0.004827665, -1.670661664, -1.184322815,…
## $ PC08           <dbl> -1.02463989, -0.86406683, -0.64515896, 0.23310583, -1.4…
## $ PC09           <dbl> -0.21735735, 0.21538928, 2.02087644, -0.03812573, -0.15…
## $ PC10           <dbl> 1.433461687, -0.743611489, -0.599106482, -1.902127312, …
## $ PC11           <dbl> 0.12801709, -0.57692964, 0.08555616, 1.92050619, 0.4871…
## $ PC12           <dbl> -0.1104960, -0.6756634, 0.2063425, 1.0534276, -1.115190…
## $ PC13           <dbl> 0.17484519, -0.69566304, -0.29937448, -1.93618017, -0.7…
## $ PC14           <dbl> 1.63430850, 0.90806925, -0.83671619, -0.02438546, -1.13…
## $ PC15           <dbl> -1.272775916, -0.683255747, -0.376926049, 0.259538093, …
## $ PC16           <dbl> 0.75609056, 1.17169373, 0.65047360, -2.00827769, -0.442…
## $ PC17           <dbl> -1.110855628, 1.577984989, -0.635898678, -0.032008839, …
## $ PC18           <dbl> 0.06164218, -0.42807403, 0.25600731, -1.29462238, -0.50…
## $ PC19           <dbl> -0.26973514, 0.95640555, 0.03718603, 1.15419936, 1.0313…
## $ PC20           <dbl> 0.46254535, -0.28111417, -1.20285224, 0.79045956, 0.477…
## $ PC21           <dbl> 0.78931833, -0.01122378, -0.28765802, -0.41828970, 0.42…
## $ PC22           <dbl> -0.32335543, 1.24695882, 0.70471929, 0.83311108, 0.2464…
## $ PC23           <dbl> -0.06194641, -0.32724733, -1.60418143, -0.25572240, 0.3…
## $ PC24           <dbl> 1.24600906, -1.33562468, 1.09597912, 0.58903333, 0.0960…
## $ PC25           <dbl> -0.64852077, -1.05983939, 0.45608021, -0.27719163, -0.3…
## $ PC26           <dbl> -0.17053145, -0.66421551, -0.21847890, -0.89666253, 0.7…
## $ PC27           <dbl> -0.06536499, 0.62183152, -0.69101245, 0.48588291, -0.89…
## $ PC28           <dbl> 1.28111700, 0.80648166, -1.08116001, 0.28330247, -0.566…
## $ PC29           <dbl> -0.17022877, -0.05725846, -0.99373957, 0.42876544, -1.1…
## $ PC30           <dbl> 0.05490964, 0.44001780, 0.26199435, 0.34958966, 0.78315…
## $ PC31           <dbl> 0.560831196, -0.269602967, -0.465919408, -0.097438715, …
## $ PC32           <dbl> -0.41828900, -0.22905176, 0.36434590, -0.34262916, -0.2…
## $ PC33           <dbl> 0.37767841, 0.48326824, 0.30311051, 0.73306492, 0.26772…
## $ PC34           <dbl> -0.74753989, -0.12977548, 0.12101925, 0.23467689, -0.55…
## $ PC35           <dbl> 0.01300861, 0.18138686, -0.42779136, 0.17861471, -0.280…

Specify Model

library(usemodels)
## Warning: package 'usemodels' was built under R version 4.4.3
# Print a template recipe/spec/workflow for an XGBoost model
usemodels::use_xgboost(
    formula = Attrition ~ .,
    data = data_train
)
## xgboost_recipe <- 
##   recipe(formula = Attrition ~ ., data = data_train) %>% 
##   step_zv(all_predictors()) 
## 
## xgboost_spec <- 
##   boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(), 
##     loss_reduction = tune(), sample_size = tune()) %>% 
##   set_mode("classification") %>% 
##   set_engine("xgboost") 
## 
## xgboost_workflow <- 
##   workflow() %>% 
##   add_recipe(xgboost_recipe) %>% 
##   add_model(xgboost_spec) 
## 
## set.seed(78430)
## xgboost_tune <-
##   tune_grid(xgboost_workflow, resamples = stop("add your rsample object"), grid = stop("add number of candidate points"))
# Preprocessing for the tuned XGBoost model (no PCA step)
xgboost_rec <- recipe(Attrition ~ ., data = data_train) %>%
    # Keep EmployeeNumber as an identifier rather than a predictor
    update_role(EmployeeNumber, new_role = "ID") %>%
    step_rm(Over18) %>%  # Remove column with only one level
    step_dummy(all_nominal_predictors()) %>%
    step_YeoJohnson(DistanceFromHome, MonthlyIncome, NumCompaniesWorked,
                    PercentSalaryHike, TotalWorkingYears,
                    starts_with("Years")) %>%
    # Drop constant predictors before normalizing: EmployeeCount and
    # StandardHours, and possibly dummy columns for levels absent from a
    # small resample. The usemodels template also suggests step_zv.
    step_zv(all_predictors()) %>%
    step_normalize(all_numeric_predictors()) %>%
    # Oversample the minority Attrition class with SMOTE
    step_smote(Attrition)

# Boosted trees; number of trees and depth are tuned, the rest use defaults
xgboost_spec <-
    boost_tree(trees = tune(), tree_depth = tune()) %>%
    set_mode("classification") %>%
    set_engine("xgboost")

xgboost_workflow <-
    workflow() %>%
    add_recipe(xgboost_rec) %>%  # Use the newly defined recipe
    add_model(xgboost_spec)

Tune Hyperparameters

library(doParallel)
## Warning: package 'doParallel' was built under R version 4.4.3
## Warning: package 'foreach' was built under R version 4.4.3
# Regular 5 x 5 grid over the two tuned parameters (25 candidates)
tree_grid <- grid_regular(trees(),
                          tree_depth(),
                          levels = 5)

# NOTE(review): tune >= 1.2.1 soft-deprecates foreach backends (see the
# warning this chunk emits); consider switching to a `future` plan.
doParallel::registerDoParallel()

set.seed(65743)
# Evaluate every grid candidate on each CV fold; keep out-of-fold
# predictions for the ROC curves below.
xgboost_tune <-
    tune_grid(xgboost_workflow,
              resamples = data_cv,
              grid = tree_grid,
              control = control_grid(save_pred = TRUE))
## Warning: ! tune detected a parallel backend registered with foreach but no backend
##   registered with future.
## ℹ Support for parallel processing with foreach was soft-deprecated in tune
##   1.2.1.
## ℹ See ?parallelism (`?tune::parallelism()`) to learn more.

Model evaluation

Identify optimal values for hyperparameters

# Mean CV metrics for each hyperparameter candidate
xgboost_tune %>% collect_metrics()
## # A tibble: 15 × 8
##    trees tree_depth .metric     .estimator  mean     n std_err .config          
##    <int>      <int> <chr>       <chr>      <dbl> <int>   <dbl> <chr>            
##  1  1000          1 accuracy    binary     0.761    10  0.0385 Preprocessor1_Mo…
##  2  1000          1 brier_class binary     0.195    10  0.0273 Preprocessor1_Mo…
##  3  1000          1 roc_auc     binary     0.545    10  0.0951 Preprocessor1_Mo…
##  4     1          4 accuracy    binary     0.743    10  0.0406 Preprocessor1_Mo…
##  5     1          4 brier_class binary     0.215    10  0.0120 Preprocessor1_Mo…
##  6     1          4 roc_auc     binary     0.540    10  0.103  Preprocessor1_Mo…
##  7  2000          8 accuracy    binary     0.718    10  0.0386 Preprocessor1_Mo…
##  8  2000          8 brier_class binary     0.185    10  0.0232 Preprocessor1_Mo…
##  9  2000          8 roc_auc     binary     0.538    10  0.102  Preprocessor1_Mo…
## 10   500         11 accuracy    binary     0.718    10  0.0386 Preprocessor1_Mo…
## 11   500         11 brier_class binary     0.185    10  0.0232 Preprocessor1_Mo…
## 12   500         11 roc_auc     binary     0.538    10  0.102  Preprocessor1_Mo…
## 13  1500         15 accuracy    binary     0.718    10  0.0386 Preprocessor1_Mo…
## 14  1500         15 brier_class binary     0.185    10  0.0232 Preprocessor1_Mo…
## 15  1500         15 roc_auc     binary     0.538    10  0.102  Preprocessor1_Mo…
# Per-fold ROC curves for the single candidate with the best mean ROC AUC.
# Without the `parameters` filter, each fold's curve would pool predictions
# from every hyperparameter combination in the grid.
best_auc_params <- select_best(xgboost_tune, metric = "roc_auc")

collect_predictions(xgboost_tune, parameters = best_auc_params) %>%
    group_by(id) %>%
    # "Yes" is the second factor level of Attrition, hence event_level
    roc_curve(Attrition, .pred_Yes, event_level = "second") %>%
    autoplot()

Fit the final model (train on the training split, evaluate on the test split)

library(yardstick)
# Finalize the workflow with the best candidate by ROC AUC (named explicitly
# to avoid select_best's missing-metric warning), then fit on the training
# split and evaluate on the test split.
xgboost_last <- xgboost_workflow %>%
    finalize_workflow(select_best(xgboost_tune, metric = "roc_auc")) %>%
    last_fit(data_split)
## Warning in select_best(xgboost_tune): No value of `metric` was given; "roc_auc"
## will be used.
## Warning: package 'xgboost' was built under R version 4.4.3
# Test-set metrics of the final fit
xgboost_last %>% collect_metrics()
## # A tibble: 3 × 4
##   .metric     .estimator .estimate .config             
##   <chr>       <chr>          <dbl> <chr>               
## 1 accuracy    binary         0.8   Preprocessor1_Model1
## 2 roc_auc     binary         0.643 Preprocessor1_Model1
## 3 brier_class binary         0.169 Preprocessor1_Model1
# Confusion matrix of test-set class predictions
xgboost_last %>%
    collect_predictions() %>%
    yardstick::conf_mat(truth = Attrition, estimate = .pred_class) %>%
    autoplot()

Variable importance

library(vip)
## Warning: package 'vip' was built under R version 4.4.3
# Variable-importance plot from the underlying xgboost engine fit
vip(workflows::extract_fit_engine(xgboost_last))

Conclusion

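On the test set, this model reached an accuracy of 0.800 and an AUC of 0.643. The previous model had an accuracy of 0.851 and AUC of 0.753, so this model performs worse. Note that it was trained on a random sample of only 100 of the 1,470 rows.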