The goal is to predict attrition, i.e., to identify employees who are likely to leave the company.

Import Data

# Load the raw HR attrition data set from the project's data folder.
data <- read_csv(file = "../00_data/WA-HR-Employee-Attrition.csv")
## Rows: 1470 Columns: 35
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (9): Attrition, BusinessTravel, Department, EducationField, Gender, Job...
## dbl (26): Age, DailyRate, DistanceFromHome, Education, EmployeeCount, Employ...
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.

Clean Data

# Summarise every column (type, missingness, distribution) to spot data issues.
skimr::skim(data)
Data summary
Name data
Number of rows 1470
Number of columns 35
_______________________
Column type frequency:
character 9
numeric 26
________________________
Group variables None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
Attrition 0 1 2 3 0 2 0
BusinessTravel 0 1 10 17 0 3 0
Department 0 1 5 22 0 3 0
EducationField 0 1 5 16 0 6 0
Gender 0 1 4 6 0 2 0
JobRole 0 1 7 25 0 9 0
MaritalStatus 0 1 6 8 0 3 0
Over18 0 1 1 1 0 1 0
OverTime 0 1 2 3 0 2 0

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
Age 0 1 36.92 9.14 18 30.00 36.0 43.00 60 ▂▇▇▃▂
DailyRate 0 1 802.49 403.51 102 465.00 802.0 1157.00 1499 ▇▇▇▇▇
DistanceFromHome 0 1 9.19 8.11 1 2.00 7.0 14.00 29 ▇▅▂▂▂
Education 0 1 2.91 1.02 1 2.00 3.0 4.00 5 ▂▃▇▆▁
EmployeeCount 0 1 1.00 0.00 1 1.00 1.0 1.00 1 ▁▁▇▁▁
EmployeeNumber 0 1 1024.87 602.02 1 491.25 1020.5 1555.75 2068 ▇▇▇▇▇
EnvironmentSatisfaction 0 1 2.72 1.09 1 2.00 3.0 4.00 4 ▅▅▁▇▇
HourlyRate 0 1 65.89 20.33 30 48.00 66.0 83.75 100 ▇▇▇▇▇
JobInvolvement 0 1 2.73 0.71 1 2.00 3.0 3.00 4 ▁▃▁▇▁
JobLevel 0 1 2.06 1.11 1 1.00 2.0 3.00 5 ▇▇▃▂▁
JobSatisfaction 0 1 2.73 1.10 1 2.00 3.0 4.00 4 ▅▅▁▇▇
MonthlyIncome 0 1 6502.93 4707.96 1009 2911.00 4919.0 8379.00 19999 ▇▅▂▁▂
MonthlyRate 0 1 14313.10 7117.79 2094 8047.00 14235.5 20461.50 26999 ▇▇▇▇▇
NumCompaniesWorked 0 1 2.69 2.50 0 1.00 2.0 4.00 9 ▇▃▂▂▁
PercentSalaryHike 0 1 15.21 3.66 11 12.00 14.0 18.00 25 ▇▅▃▂▁
PerformanceRating 0 1 3.15 0.36 3 3.00 3.0 3.00 4 ▇▁▁▁▂
RelationshipSatisfaction 0 1 2.71 1.08 1 2.00 3.0 4.00 4 ▅▅▁▇▇
StandardHours 0 1 80.00 0.00 80 80.00 80.0 80.00 80 ▁▁▇▁▁
StockOptionLevel 0 1 0.79 0.85 0 0.00 1.0 1.00 3 ▇▇▁▂▁
TotalWorkingYears 0 1 11.28 7.78 0 6.00 10.0 15.00 40 ▇▇▂▁▁
TrainingTimesLastYear 0 1 2.80 1.29 0 2.00 3.0 3.00 6 ▂▇▇▂▃
WorkLifeBalance 0 1 2.76 0.71 1 2.00 3.0 3.00 4 ▁▃▁▇▂
YearsAtCompany 0 1 7.01 6.13 0 3.00 5.0 9.00 40 ▇▂▁▁▁
YearsInCurrentRole 0 1 4.23 3.62 0 2.00 3.0 7.00 18 ▇▃▂▁▁
YearsSinceLastPromotion 0 1 2.19 3.22 0 0.00 1.0 3.00 15 ▇▁▁▁▁
YearsWithCurrManager 0 1 4.12 3.57 0 2.00 3.0 7.00 17 ▇▂▅▁▁

Issues with data:

- Missing values: none (every column has `n_missing = 0`).
- Factors imported as numeric variables: Education, EnvironmentSatisfaction, JobInvolvement, JobSatisfaction, PerformanceRating, RelationshipSatisfaction, WorkLifeBalance.
- Zero-variance variables: Over18, EmployeeCount, StandardHours.
- Character variables: convert to numbers in the recipe steps.
- Unbalanced target variable.
- ID variable: EmployeeNumber.

# Coded survey/rating columns that were read in as numbers but are really
# categorical.
factors_vec <- c(
    "Education", "EnvironmentSatisfaction", "JobInvolvement",
    "JobSatisfaction", "PerformanceRating", "RelationshipSatisfaction",
    "WorkLifeBalance"
)

data_clean <- data %>%

    # Treat the coded columns as factors rather than numbers
    mutate(across(all_of(factors_vec), as.factor)) %>%

    # Remove the columns that hold a single value for every row
    select(-Over18, -EmployeeCount, -StandardHours) %>%

    # Relabel the positive class: "Yes" becomes "Left", "No" stays as is
    mutate(
        Attrition = if_else(Attrition == "Yes", true = "Left", false = Attrition)
    )

Explore Data

# Class balance of the target variable
count(data_clean, Attrition)
## # A tibble: 2 × 2
##   Attrition     n
##   <chr>     <int>
## 1 Left        237
## 2 No         1233
# Bar chart of the number of employees in each Attrition class
ggplot(data_clean, aes(x = Attrition)) +
    geom_bar()

Attrition vs. Monthly Income

# Distribution of monthly income within each Attrition class
ggplot(data_clean, aes(x = Attrition, y = MonthlyIncome)) +
    geom_boxplot()

Correlation Plot

# Step 1: Binarize
# Drop the ID column, then convert every remaining column into 0/1 indicator
# columns (one per bin or per level).
data_binarized <- binarize(select(data_clean, -EmployeeNumber))

glimpse(data_binarized)
## Rows: 1,470
## Columns: 120
## $ `Age__-Inf_30`                       <dbl> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, …
## $ Age__30_36                           <dbl> 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, …
## $ Age__36_43                           <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ Age__43_Inf                          <dbl> 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ Attrition__Left                      <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ Attrition__No                        <dbl> 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ `BusinessTravel__Non-Travel`         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ BusinessTravel__Travel_Frequently    <dbl> 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ BusinessTravel__Travel_Rarely        <dbl> 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, …
## $ `DailyRate__-Inf_465`                <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ DailyRate__465_802                   <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ DailyRate__802_1157                  <dbl> 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, …
## $ DailyRate__1157_Inf                  <dbl> 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, …
## $ Department__Human_Resources          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `Department__Research_&_Development` <dbl> 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ Department__Sales                    <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `DistanceFromHome__-Inf_2`           <dbl> 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, …
## $ DistanceFromHome__2_7                <dbl> 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ DistanceFromHome__7_14               <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ DistanceFromHome__14_Inf             <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, …
## $ Education__1                         <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, …
## $ Education__2                         <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ Education__3                         <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, …
## $ Education__4                         <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ Education__5                         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Human_Resources      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Life_Sciences        <dbl> 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, …
## $ EducationField__Marketing            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Medical              <dbl> 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, …
## $ EducationField__Other                <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Technical_Degree     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EnvironmentSatisfaction__1           <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, …
## $ EnvironmentSatisfaction__2           <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EnvironmentSatisfaction__3           <dbl> 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, …
## $ EnvironmentSatisfaction__4           <dbl> 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, …
## $ Gender__Female                       <dbl> 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ Gender__Male                         <dbl> 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, …
## $ `HourlyRate__-Inf_48`                <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, …
## $ HourlyRate__48_66                    <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ HourlyRate__66_83.75                 <dbl> 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, …
## $ HourlyRate__83.75_Inf                <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, …
## $ JobInvolvement__1                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobInvolvement__2                    <dbl> 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobInvolvement__3                    <dbl> 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, …
## $ JobInvolvement__4                    <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, …
## $ JobLevel__1                          <dbl> 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, …
## $ JobLevel__2                          <dbl> 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ JobLevel__3                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobLevel__4                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobLevel__5                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Healthcare_Representative   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ JobRole__Human_Resources             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Laboratory_Technician       <dbl> 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, …
## $ JobRole__Manager                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Manufacturing_Director      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobRole__Research_Director           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Research_Scientist          <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Sales_Executive             <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Sales_Representative        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobSatisfaction__1                   <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ JobSatisfaction__2                   <dbl> 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, …
## $ JobSatisfaction__3                   <dbl> 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, …
## $ JobSatisfaction__4                   <dbl> 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ MaritalStatus__Divorced              <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, …
## $ MaritalStatus__Married               <dbl> 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, …
## $ MaritalStatus__Single                <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, …
## $ `MonthlyIncome__-Inf_2911`           <dbl> 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, …
## $ MonthlyIncome__2911_4919             <dbl> 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, …
## $ MonthlyIncome__4919_8379             <dbl> 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ MonthlyIncome__8379_Inf              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ `MonthlyRate__-Inf_8047`             <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ MonthlyRate__8047_14235.5            <dbl> 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, …
## $ MonthlyRate__14235.5_20461.5         <dbl> 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, …
## $ MonthlyRate__20461.5_Inf             <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ `NumCompaniesWorked__-Inf_1`         <dbl> 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, …
## $ NumCompaniesWorked__1_2              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ NumCompaniesWorked__2_4              <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ NumCompaniesWorked__4_Inf            <dbl> 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, …
## $ OverTime__No                         <dbl> 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, …
## $ OverTime__Yes                        <dbl> 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ `PercentSalaryHike__-Inf_12`         <dbl> 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, …
## $ PercentSalaryHike__12_14             <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, …
## $ PercentSalaryHike__14_18             <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ PercentSalaryHike__18_Inf            <dbl> 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, …
## $ PerformanceRating__3                 <dbl> 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, …
## $ PerformanceRating__4                 <dbl> 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, …
## $ RelationshipSatisfaction__1          <dbl> 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ RelationshipSatisfaction__2          <dbl> 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, …
## $ RelationshipSatisfaction__3          <dbl> 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, …
## $ RelationshipSatisfaction__4          <dbl> 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ StockOptionLevel__0                  <dbl> 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ StockOptionLevel__1                  <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, …
## $ StockOptionLevel__2                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ StockOptionLevel__3                  <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ `TotalWorkingYears__-Inf_6`          <dbl> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, …
## $ TotalWorkingYears__6_10              <dbl> 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ TotalWorkingYears__10_15             <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ TotalWorkingYears__15_Inf            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ `TrainingTimesLastYear__-Inf_2`      <dbl> 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, …
## $ TrainingTimesLastYear__2_3           <dbl> 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, …
## $ TrainingTimesLastYear__3_Inf         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ WorkLifeBalance__1                   <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ WorkLifeBalance__2                   <dbl> 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, …
## $ WorkLifeBalance__3                   <dbl> 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, …
## $ WorkLifeBalance__4                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsAtCompany__-Inf_3`             <dbl> 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsAtCompany__3_5                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ YearsAtCompany__5_9                  <dbl> 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, …
## $ YearsAtCompany__9_Inf                <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsInCurrentRole__-Inf_2`         <dbl> 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsInCurrentRole__2_3              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ YearsInCurrentRole__3_7              <dbl> 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, …
## $ YearsInCurrentRole__7_Inf            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsSinceLastPromotion__-Inf_1`    <dbl> 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, …
## $ YearsSinceLastPromotion__1_3         <dbl> 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, …
## $ YearsSinceLastPromotion__3_Inf       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ `YearsWithCurrManager__-Inf_2`       <dbl> 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsWithCurrManager__2_3            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ YearsWithCurrManager__3_7            <dbl> 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, …
## $ YearsWithCurrManager__7_Inf          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
# Step 2: Correlation
# Correlate every indicator column with the "employee left" indicator.
data_correlation <- correlationfunnel::correlate(
    data_binarized,
    target = Attrition__Left
)

data_correlation
## # A tibble: 120 × 3
##    feature           bin       correlation
##    <fct>             <chr>           <dbl>
##  1 Attrition         Left            1    
##  2 Attrition         No             -1    
##  3 OverTime          No             -0.246
##  4 OverTime          Yes             0.246
##  5 JobLevel          1               0.213
##  6 MonthlyIncome     -Inf_2911       0.207
##  7 StockOptionLevel  0               0.195
##  8 YearsAtCompany    -Inf_3          0.183
##  9 MaritalStatus     Single          0.175
## 10 TotalWorkingYears -Inf_6          0.169
## # ℹ 110 more rows
# Step 3: Plot
# Funnel chart of the correlations computed in step 2.
correlationfunnel::plot_correlation_funnel(data_correlation)
## Warning: ggrepel: 73 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps

Model Building

Split Data

# Fix the RNG so the train/test split and the folds are reproducible.
set.seed(1234)
# Tip: for quick experiments, work on a subsample, e.g. sample_n(data_clean, 100)

# Hold out a test set, stratified on the target.
data_split <- data_clean %>% initial_split(strata = Attrition)
data_train <- data_split %>% training()
data_test <- data_split %>% testing()

# Stratified cross-validation folds of the training data, used for tuning.
data_cv <- data_train %>% rsample::vfold_cv(strata = Attrition)
data_cv
## #  10-fold cross-validation using stratification 
## # A tibble: 10 × 2
##    splits            id    
##    <list>            <chr> 
##  1 <split [990/111]> Fold01
##  2 <split [990/111]> Fold02
##  3 <split [990/111]> Fold03
##  4 <split [990/111]> Fold04
##  5 <split [991/110]> Fold05
##  6 <split [991/110]> Fold06
##  7 <split [991/110]> Fold07
##  8 <split [992/109]> Fold08
##  9 <split [992/109]> Fold09
## 10 <split [992/109]> Fold10

Preprocess Data

library(themis)
## Warning: package 'themis' was built under R version 4.2.3
# Preprocessing recipe for the xgboost model, defined on the training set.
xgboost_rec <- recipes::recipe(Attrition ~ ., data = data_train) %>%
    # Keep EmployeeNumber in the data but take it out of the predictor set
    update_role(EmployeeNumber, new_role = "ID") %>%
    # Encode categorical predictors as indicator columns
    step_dummy(all_nominal_predictors()) %>%
    # Center and scale all numeric predictors ahead of the PCA
    step_normalize(all_numeric_predictors()) %>%
    # Replace the predictors with the principal components that together
    # explain 99% of the variance
    step_pca(all_numeric_predictors(), threshold = 0.99) %>%
    # Oversample the minority Attrition class with SMOTE
    step_smote(Attrition)

# Inspect the preprocessed training data. bake(new_data = NULL) is the
# current replacement for the superseded juice().
xgboost_rec %>% prep() %>% bake(new_data = NULL) %>% glimpse()
## Rows: 1,848
## Columns: 51
## $ EmployeeNumber <dbl> 4, 27, 33, 45, 47, 55, 58, 64, 90, 133, 137, 142, 147, …
## $ Attrition      <fct> Left, Left, Left, Left, Left, Left, Left, Left, Left, L…
## $ PC01           <dbl> -2.7008796, -1.3851016, -0.9898314, -2.7123949, -1.4082…
## $ PC02           <dbl> -1.1582242, 2.0758272, -1.1769684, -1.0635118, 3.547456…
## $ PC03           <dbl> -1.4071746, -1.5566383, -2.1704958, 1.5138776, 1.230319…
## $ PC04           <dbl> -1.8426942, -0.2189249, 3.3214414, 0.1627416, -0.940169…
## $ PC05           <dbl> -0.25210132, 2.70012402, 2.40029506, 0.16116327, -0.034…
## $ PC06           <dbl> 2.59254625, 2.81147695, 0.70726636, -2.44417469, 2.2332…
## $ PC07           <dbl> -0.55593527, -1.24958628, 0.01531976, -2.16787791, 0.21…
## $ PC08           <dbl> -0.89240997, -0.07652722, 0.56130410, 0.58312834, -0.90…
## $ PC09           <dbl> 1.19431200, 1.00186769, 1.07540726, -0.81843662, -0.096…
## $ PC10           <dbl> -2.08259705, 0.56337423, -0.15568277, -0.81644957, -0.7…
## $ PC11           <dbl> -0.3934252, 0.5379347, 1.1038419, 0.3179483, -0.8389709…
## $ PC12           <dbl> 0.6217673, -2.5481178, -0.7844686, 1.3383124, -0.806984…
## $ PC13           <dbl> -1.40210098, -1.04469438, 2.16121634, 0.96844715, 0.238…
## $ PC14           <dbl> 1.786546612, 0.947038464, 0.728867736, -1.203721858, 0.…
## $ PC15           <dbl> 1.50900812, -0.82277088, -1.29183287, -0.65406491, 0.60…
## $ PC16           <dbl> -0.19437803, -0.91896213, -1.64225604, -0.35247632, 1.4…
## $ PC17           <dbl> -0.13999669, 0.57010089, 0.94631147, 0.60202385, -0.970…
## $ PC18           <dbl> -0.74594330, 0.87513463, 0.19960168, -1.13921071, -0.21…
## $ PC19           <dbl> -0.1848677, -0.9562717, 0.4086475, 1.2973671, 2.3679767…
## $ PC20           <dbl> -0.55339508, -1.35106209, -0.01057609, -0.69481221, -1.…
## $ PC21           <dbl> -0.13065253, 1.45537716, 0.22364100, -0.18896653, 0.134…
## $ PC22           <dbl> 0.79779388, -0.96024363, -0.91304176, -1.98595539, -0.1…
## $ PC23           <dbl> 3.22852202, 0.23213193, 0.76844435, -0.19681532, 1.3233…
## $ PC24           <dbl> 2.25403847, -0.73294485, -0.28739630, 0.70231139, -0.27…
## $ PC25           <dbl> 1.12730949, -1.30313515, 0.92367530, 0.64488095, 0.4799…
## $ PC26           <dbl> 0.77564433, -0.38089595, -0.65786318, 0.50247084, 0.599…
## $ PC27           <dbl> -1.221687488, 0.874353237, -1.016860596, -0.144609586, …
## $ PC28           <dbl> -2.7410678, -0.5790063, -0.9980385, -0.4497550, 0.20913…
## $ PC29           <dbl> 0.2509422, 2.2249652, -0.8511565, 0.5617327, 2.2427028,…
## $ PC30           <dbl> -0.5160910, -0.6846793, -0.3356748, -1.2007427, -2.5159…
## $ PC31           <dbl> -0.786023708, -0.440781389, -0.610681242, 0.353466990, …
## $ PC32           <dbl> 1.3231294983, 0.6065014272, 0.7773408307, -0.5604557323…
## $ PC33           <dbl> 0.5301238, 1.6180841, -1.2078405, -1.0254676, 0.4020246…
## $ PC34           <dbl> 0.81150211, 0.33071364, -0.64331436, 0.41004008, 0.1788…
## $ PC35           <dbl> -1.40474186, -2.42122089, -1.68214959, -0.01447111, -1.…
## $ PC36           <dbl> -0.078536567, -0.123303251, 0.035522194, 0.524214319, 1…
## $ PC37           <dbl> -0.16825126, 0.33602570, -0.39537080, 0.30034723, -2.44…
## $ PC38           <dbl> -0.22667694, -0.24944872, -0.80922279, -0.59753526, 0.5…
## $ PC39           <dbl> 0.18055939, -1.14697557, 0.86855130, 0.33490275, 0.4211…
## $ PC40           <dbl> 0.04497181, 0.06581771, -0.82818604, -0.37140836, -0.30…
## $ PC41           <dbl> 0.34111038, -1.10608772, -1.08420254, 0.45253468, 0.106…
## $ PC42           <dbl> 0.194363520, 0.141264439, 0.405442315, -0.143984092, -0…
## $ PC43           <dbl> 0.16081450, -0.01588223, 0.13145777, -0.97414364, 0.766…
## $ PC44           <dbl> -0.06758401, -0.19587807, -0.62324861, -0.01742433, 0.9…
## $ PC45           <dbl> -0.183378946, 0.232482848, 0.236647628, 0.396652061, -0…
## $ PC46           <dbl> 0.22878570, -0.15130989, -0.03590376, 0.35287800, -0.22…
## $ PC47           <dbl> 0.118950505, 0.388201362, 0.246572558, -0.202709146, -0…
## $ PC48           <dbl> -0.31263005, -0.21284450, -0.99732266, -0.06094232, 0.1…
## $ PC49           <dbl> -0.114339812, -0.200857771, 0.954179455, -0.510904679, …

Specify Model

library(usemodels)
## Warning: package 'usemodels' was built under R version 4.2.3
# Print boilerplate xgboost recipe/spec/workflow code to adapt below.
usemodels::use_xgboost(formula = Attrition ~ ., data = data_train)
## xgboost_recipe <- 
##   recipe(formula = Attrition ~ ., data = data_train) %>% 
##   step_zv(all_predictors()) 
## 
## xgboost_spec <- 
##   boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(), 
##     loss_reduction = tune(), sample_size = tune()) %>% 
##   set_mode("classification") %>% 
##   set_engine("xgboost") 
## 
## xgboost_workflow <- 
##   workflow() %>% 
##   add_recipe(xgboost_recipe) %>% 
##   add_model(xgboost_spec) 
## 
## set.seed(3682)
## xgboost_tune <-
##   tune_grid(xgboost_workflow, resamples = stop("add your rsample object"), grid = stop("add number of candidate points"))
# Boosted-tree classifier on the xgboost engine; the number of trees and the
# tree depth are left as placeholders to be tuned.
xgboost_spec <-
  boost_tree(mode = "classification", trees = tune(), tree_depth = tune()) %>%
  set_engine("xgboost")

# Bundle the model with the preprocessing recipe.
xgboost_workflow <-
  workflow() %>%
  add_model(xgboost_spec) %>%
  add_recipe(xgboost_rec)

Tune Hyperparameters

library(parsnip)
# Register a parallel backend for the resampling fits.
doParallel::registerDoParallel()

# Tune trees and tree_depth over 5 automatically generated candidates,
# evaluated on the cross-validation folds. A regular grid used to be built
# here but was never passed to tune_grid(), so it has been removed. To tune
# over a regular grid instead, pass
# grid_regular(trees(), tree_depth(), levels = 5) as `grid`.
set.seed(65743)
xgboost_tune <-
  tune_grid(xgboost_workflow,
            resamples = data_cv,
            grid = 5,
            # keep the out-of-fold predictions; they are needed for ROC curves
            control = control_grid(save_pred = TRUE))

Model Evaluation

Identify Optimal Values for Hyperparameters

# Cross-validated accuracy and ROC AUC for every tuned candidate
xgboost_tune %>% collect_metrics()
## # A tibble: 10 × 8
##    trees tree_depth .metric  .estimator  mean     n std_err .config             
##    <int>      <int> <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
##  1  1741          3 accuracy binary     0.840    10 0.00896 Preprocessor1_Model1
##  2  1741          3 roc_auc  binary     0.796    10 0.0198  Preprocessor1_Model1
##  3   885          5 accuracy binary     0.837    10 0.0118  Preprocessor1_Model2
##  4   885          5 roc_auc  binary     0.804    10 0.0209  Preprocessor1_Model2
##  5   325          7 accuracy binary     0.846    10 0.0115  Preprocessor1_Model3
##  6   325          7 roc_auc  binary     0.795    10 0.0235  Preprocessor1_Model3
##  7  1312         12 accuracy binary     0.837    10 0.0121  Preprocessor1_Model4
##  8  1312         12 roc_auc  binary     0.797    10 0.0205  Preprocessor1_Model4
##  9   555         15 accuracy binary     0.842    10 0.0101  Preprocessor1_Model5
## 10   555         15 roc_auc  binary     0.794    10 0.0191  Preprocessor1_Model5
# Per-fold ROC curves for the selected candidate only. Without `parameters`,
# collect_predictions() returns the predictions of every tuned candidate, and
# grouping by fold would pool several different models into each curve.
collect_predictions(
    xgboost_tune,
    parameters = select_best(xgboost_tune, metric = "accuracy")
) %>%
    group_by(id) %>%
    roc_curve(truth = Attrition, .pred_Left) %>%
    autoplot()

Fit the Model for the Last Time

# Plug the candidate with the best cross-validated accuracy into the
# workflow, then fit and evaluate it once on the initial train/test split.
# NOTE(review): the target is unbalanced, so selecting on "roc_auc" may be a
# better choice than accuracy -- confirm.
xgboost_last <- last_fit(
    finalize_workflow(
        xgboost_workflow,
        select_best(xgboost_tune, metric = "accuracy")
    ),
    data_split
)

xgboost_last %>% collect_metrics()
## # A tibble: 2 × 4
##   .metric  .estimator .estimate .config             
##   <chr>    <chr>          <dbl> <chr>               
## 1 accuracy binary         0.846 Preprocessor1_Model1
## 2 roc_auc  binary         0.742 Preprocessor1_Model1
# Confusion matrix of the final model's predictions
autoplot(
    yardstick::conf_mat(
        collect_predictions(xgboost_last),
        truth = Attrition,
        estimate = .pred_class
    )
)

Variable Importance

library(vip)
## Warning: package 'vip' was built under R version 4.2.3
## 
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
## 
##     vi
# Pull the fitted engine object out of the final fit and plot its variable
# importance.
vip(workflows::extract_fit_engine(xgboost_last))

Conclusion

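The previous model had an accuracy of 0.851 and an AUC of 0.753. The final model here reaches an accuracy of 0.846 and an AUC of 0.742, so it does not improve on the previous model.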