Import Data

library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.3     ✔ readr     2.1.4
## ✔ forcats   1.0.0     ✔ stringr   1.5.0
## ✔ ggplot2   3.4.3     ✔ tibble    3.2.1
## ✔ lubridate 1.9.3     ✔ tidyr     1.3.0
## ✔ purrr     1.0.2     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
data <- read_csv("../00_data/EmployeeDataSet.csv")
## Rows: 1470 Columns: 35
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (9): Attrition, BusinessTravel, Department, EducationField, Gender, Job...
## dbl (26): Age, DailyRate, DistanceFromHome, Education, EmployeeCount, Employ...
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.

Goal: Predict attrition on employees who are likely to leave the company

Issues with data:

# Cleaning Data

 # Adressing all numeric variables
factors_vec <- data %>% select(Education, EnvironmentSatisfaction, JobInvolvement, JobSatisfaction, PerformanceRating, RelationshipSatisfaction, WorkLifeBalance, JobLevel, StockOptionLevel) %>% names()

 # Adressing factors imported to R as numeric variable
data_clean <- data %>%
    # Adressing one variable: mutate(Education = Education %>% as.factor()) %>%
     mutate(across(all_of(factors_vec), as.factor)) %>%
    
    # Drop zero-variance variables
    select(-c(Over18,EmployeeCount,StandardHours)) %>%
    
    # Recode Attrition
    mutate(Attrition = if_else(Attrition == "Yes", "Left", Attrition))

Explore Data

# Adressing unnbalanced target variable
data_clean %>% count(Attrition)
## # A tibble: 2 × 2
##   Attrition     n
##   <chr>     <int>
## 1 Left        237
## 2 No         1233
data_clean %>%
    ggplot(aes(Attrition)) +
    geom_bar()

Attrition vs. MonthlyIncome

# Relationship between pay and attrition
data_clean %>%
    ggplot(aes(Attrition, MonthlyIncome)) +
    geom_boxplot()

Relationship in all variables with Correlation plot

library(correlationfunnel)
## ══ Using correlationfunnel? ════════════════════════════════════════════════════
## You might also be interested in applied data science training for business.
## </> Learn more at - www.business-science.io </>
 # Step 1: Binarize
data_binarized <- data_clean %>%
    select(-EmployeeNumber) %>% # ID variable
    binarize()

data_binarized %>% glimpse()
## Rows: 1,470
## Columns: 120
## $ `Age__-Inf_30`                       <dbl> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, …
## $ Age__30_36                           <dbl> 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, …
## $ Age__36_43                           <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ Age__43_Inf                          <dbl> 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ Attrition__Left                      <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ Attrition__No                        <dbl> 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ `BusinessTravel__Non-Travel`         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ BusinessTravel__Travel_Frequently    <dbl> 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ BusinessTravel__Travel_Rarely        <dbl> 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, …
## $ `DailyRate__-Inf_465`                <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ DailyRate__465_802                   <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ DailyRate__802_1157                  <dbl> 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, …
## $ DailyRate__1157_Inf                  <dbl> 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, …
## $ Department__Human_Resources          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `Department__Research_&_Development` <dbl> 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ Department__Sales                    <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `DistanceFromHome__-Inf_2`           <dbl> 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, …
## $ DistanceFromHome__2_7                <dbl> 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ DistanceFromHome__7_14               <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ DistanceFromHome__14_Inf             <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, …
## $ Education__1                         <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, …
## $ Education__2                         <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ Education__3                         <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, …
## $ Education__4                         <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ Education__5                         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Human_Resources      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Life_Sciences        <dbl> 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, …
## $ EducationField__Marketing            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Medical              <dbl> 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, …
## $ EducationField__Other                <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Technical_Degree     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EnvironmentSatisfaction__1           <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, …
## $ EnvironmentSatisfaction__2           <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EnvironmentSatisfaction__3           <dbl> 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, …
## $ EnvironmentSatisfaction__4           <dbl> 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, …
## $ Gender__Female                       <dbl> 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ Gender__Male                         <dbl> 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, …
## $ `HourlyRate__-Inf_48`                <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, …
## $ HourlyRate__48_66                    <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ HourlyRate__66_83.75                 <dbl> 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, …
## $ HourlyRate__83.75_Inf                <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, …
## $ JobInvolvement__1                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobInvolvement__2                    <dbl> 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobInvolvement__3                    <dbl> 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, …
## $ JobInvolvement__4                    <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, …
## $ JobLevel__1                          <dbl> 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, …
## $ JobLevel__2                          <dbl> 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ JobLevel__3                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobLevel__4                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobLevel__5                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Healthcare_Representative   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ JobRole__Human_Resources             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Laboratory_Technician       <dbl> 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, …
## $ JobRole__Manager                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Manufacturing_Director      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobRole__Research_Director           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Research_Scientist          <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Sales_Executive             <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Sales_Representative        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobSatisfaction__1                   <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ JobSatisfaction__2                   <dbl> 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, …
## $ JobSatisfaction__3                   <dbl> 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, …
## $ JobSatisfaction__4                   <dbl> 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ MaritalStatus__Divorced              <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, …
## $ MaritalStatus__Married               <dbl> 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, …
## $ MaritalStatus__Single                <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, …
## $ `MonthlyIncome__-Inf_2911`           <dbl> 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, …
## $ MonthlyIncome__2911_4919             <dbl> 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, …
## $ MonthlyIncome__4919_8379             <dbl> 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ MonthlyIncome__8379_Inf              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ `MonthlyRate__-Inf_8047`             <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ MonthlyRate__8047_14235.5            <dbl> 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, …
## $ MonthlyRate__14235.5_20461.5         <dbl> 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, …
## $ MonthlyRate__20461.5_Inf             <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ `NumCompaniesWorked__-Inf_1`         <dbl> 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, …
## $ NumCompaniesWorked__1_2              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ NumCompaniesWorked__2_4              <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ NumCompaniesWorked__4_Inf            <dbl> 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, …
## $ OverTime__No                         <dbl> 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, …
## $ OverTime__Yes                        <dbl> 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ `PercentSalaryHike__-Inf_12`         <dbl> 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, …
## $ PercentSalaryHike__12_14             <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, …
## $ PercentSalaryHike__14_18             <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ PercentSalaryHike__18_Inf            <dbl> 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, …
## $ PerformanceRating__3                 <dbl> 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, …
## $ PerformanceRating__4                 <dbl> 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, …
## $ RelationshipSatisfaction__1          <dbl> 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ RelationshipSatisfaction__2          <dbl> 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, …
## $ RelationshipSatisfaction__3          <dbl> 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, …
## $ RelationshipSatisfaction__4          <dbl> 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ StockOptionLevel__0                  <dbl> 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ StockOptionLevel__1                  <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, …
## $ StockOptionLevel__2                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ StockOptionLevel__3                  <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ `TotalWorkingYears__-Inf_6`          <dbl> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, …
## $ TotalWorkingYears__6_10              <dbl> 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ TotalWorkingYears__10_15             <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ TotalWorkingYears__15_Inf            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ `TrainingTimesLastYear__-Inf_2`      <dbl> 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, …
## $ TrainingTimesLastYear__2_3           <dbl> 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, …
## $ TrainingTimesLastYear__3_Inf         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ WorkLifeBalance__1                   <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ WorkLifeBalance__2                   <dbl> 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, …
## $ WorkLifeBalance__3                   <dbl> 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, …
## $ WorkLifeBalance__4                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsAtCompany__-Inf_3`             <dbl> 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsAtCompany__3_5                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ YearsAtCompany__5_9                  <dbl> 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, …
## $ YearsAtCompany__9_Inf                <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsInCurrentRole__-Inf_2`         <dbl> 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsInCurrentRole__2_3              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ YearsInCurrentRole__3_7              <dbl> 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, …
## $ YearsInCurrentRole__7_Inf            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsSinceLastPromotion__-Inf_1`    <dbl> 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, …
## $ YearsSinceLastPromotion__1_3         <dbl> 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, …
## $ YearsSinceLastPromotion__3_Inf       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ `YearsWithCurrManager__-Inf_2`       <dbl> 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsWithCurrManager__2_3            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ YearsWithCurrManager__3_7            <dbl> 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, …
## $ YearsWithCurrManager__7_Inf          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
# Step 2: Correlation
data_correlation <- data_binarized %>%
    correlate(Attrition__Left) # Target variable

data_correlation
## # A tibble: 120 × 3
##    feature           bin       correlation
##    <fct>             <chr>           <dbl>
##  1 Attrition         Left            1    
##  2 Attrition         No             -1    
##  3 OverTime          Yes             0.246
##  4 OverTime          No             -0.246
##  5 JobLevel          1               0.213
##  6 MonthlyIncome     -Inf_2911       0.207
##  7 StockOptionLevel  0               0.195
##  8 YearsAtCompany    -Inf_3          0.183
##  9 MaritalStatus     Single          0.175
## 10 TotalWorkingYears -Inf_6          0.169
## # ℹ 110 more rows
# Step 3: Plot
data_correlation %>%
    correlationfunnel::plot_correlation_funnel()
## Warning: ggrepel: 72 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps

Model Building

Split Data

library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.1.1 ──
## ✔ broom        1.0.5     ✔ rsample      1.2.0
## ✔ dials        1.2.0     ✔ tune         1.1.2
## ✔ infer        1.0.6     ✔ workflows    1.1.3
## ✔ modeldata    1.3.0     ✔ workflowsets 1.0.1
## ✔ parsnip      1.1.1     ✔ yardstick    1.3.0
## ✔ recipes      1.0.8
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter()   masks stats::filter()
## ✖ recipes::fixed()  masks stringr::fixed()
## ✖ dplyr::lag()      masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step()   masks stats::step()
## • Learn how to get started at https://www.tidymodels.org/start/
set.seed(1234)
data_clean <- data_clean %>% sample_n(500) 

data_split <- initial_split(data_clean, strata = Attrition)
data_training <- training(data_split)
data_test <- testing(data_split)

 # Cross Validation
data_cv <- rsample::vfold_cv(data_training, strata = Attrition)
data_cv
## #  10-fold cross-validation using stratification 
## # A tibble: 10 × 2
##    splits           id    
##    <list>           <chr> 
##  1 <split [336/38]> Fold01
##  2 <split [336/38]> Fold02
##  3 <split [336/38]> Fold03
##  4 <split [336/38]> Fold04
##  5 <split [336/38]> Fold05
##  6 <split [337/37]> Fold06
##  7 <split [337/37]> Fold07
##  8 <split [337/37]> Fold08
##  9 <split [337/37]> Fold09
## 10 <split [338/36]> Fold10

Preprocess Data

  # Solving the unbalanced target variable
library(themis)

xgboost_rec <- recipes::recipe(Attrition ~ ., data = data_training) %>%
    update_role(EmployeeNumber, new_role = "ID") %>%
    step_dummy(all_nominal_predictors()) %>%
    # Include if model gets improved by this: 
    # step_YeoJohnson(DistanceFromHome, MonthlyIncome, NumCompaniesWorked, PercentSalaryHike,     TotalWorkingYears, starts_with("Years")) %>%
    step_normalize(all_numeric_predictors()) %>%
   # step_pca(all_nominal_predictors(), threshold = .99) %>%
    step_smote(Attrition)
    
xgboost_rec %>% prep() %>% juice() %>% glimpse()
## Rows: 638
## Columns: 64
## $ Age                               <dbl> -0.224477519, -1.210889173, -0.55328…
## $ DailyRate                         <dbl> -0.4781680, 0.8517798, 1.1385815, -0…
## $ DistanceFromHome                  <dbl> 0.69063582, -0.08658925, -0.86381432…
## $ EmployeeNumber                    <dbl> 1010, 796, 1692, 161, 840, 787, 1372…
## $ HourlyRate                        <dbl> -1.30338779, -1.35243367, 1.44318167…
## $ MonthlyIncome                     <dbl> -0.57365753, -0.24406941, -1.0629387…
## $ MonthlyRate                       <dbl> -0.5691381, -1.5384076, 1.4742092, -…
## $ NumCompaniesWorked                <dbl> -0.7012057, 1.3444335, -0.7012057, 2…
## $ PercentSalaryHike                 <dbl> 2.45237158, 0.50446236, -0.88690137,…
## $ TotalWorkingYears                 <dbl> -0.8187896, -0.6863383, -1.3485946, …
## $ TrainingTimesLastYear             <dbl> -0.6390327, -0.6390327, -0.6390327, …
## $ YearsAtCompany                    <dbl> -0.5334685, -0.5334685, -1.0683673, …
## $ YearsInCurrentRole                <dbl> -0.63688360, -0.34886127, -1.2129282…
## $ YearsSinceLastPromotion           <dbl> -0.69141891, -0.36200404, -0.6914189…
## $ YearsWithCurrManager              <dbl> -0.61497409, -0.61497409, -1.1814773…
## $ Attrition                         <fct> Left, Left, Left, Left, Left, Left, …
## $ BusinessTravel_Travel_Frequently  <dbl> -0.5126489, -0.5126489, -0.5126489, …
## $ BusinessTravel_Travel_Rarely      <dbl> 0.6696344, 0.6696344, 0.6696344, 0.6…
## $ Department_Research...Development <dbl> 0.7118329, -1.4010679, 0.7118329, 0.…
## $ Department_Sales                  <dbl> -0.6487871, 1.5372164, -0.6487871, -…
## $ Education_X2                      <dbl> -0.4876199, -0.4876199, -0.4876199, …
## $ Education_X3                      <dbl> -0.7462174, 1.3365088, -0.7462174, -…
## $ Education_X4                      <dbl> 1.5571423, -0.6404849, 1.5571423, 1.…
## $ Education_X5                      <dbl> -0.1738448, -0.1738448, -0.1738448, …
## $ EducationField_Life.Sciences      <dbl> -0.8172223, -0.8172223, 1.2203854, 1…
## $ EducationField_Marketing          <dbl> -0.355201, -0.355201, -0.355201, -0.…
## $ EducationField_Medical            <dbl> -0.7033305, -0.7033305, -0.7033305, …
## $ EducationField_Other              <dbl> 3.9946488, -0.2496656, -0.2496656, -…
## $ EducationField_Technical.Degree   <dbl> -0.2949171, 3.3817166, -0.2949171, -…
## $ EnvironmentSatisfaction_X2        <dbl> -0.4834221, -0.4834221, -0.4834221, …
## $ EnvironmentSatisfaction_X3        <dbl> 1.5372164, -0.6487871, -0.6487871, -…
## $ EnvironmentSatisfaction_X4        <dbl> -0.6738217, 1.4801040, 1.4801040, -0…
## $ Gender_Male                       <dbl> 0.8172223, 0.8172223, 0.8172223, -1.…
## $ JobInvolvement_X2                 <dbl> 1.8555727, 1.8555727, -0.5374762, -0…
## $ JobInvolvement_X3                 <dbl> -1.2985095, -1.2985095, 0.7680546, 0…
## $ JobInvolvement_X4                 <dbl> -0.3358465, -0.3358465, -0.3358465, …
## $ JobLevel_X2                       <dbl> -0.7549177, 1.3211059, -0.7549177, -…
## $ JobLevel_X3                       <dbl> -0.4322046, -0.4322046, -0.4322046, …
## $ JobLevel_X4                       <dbl> -0.2496656, -0.2496656, -0.2496656, …
## $ JobLevel_X5                       <dbl> -0.2435796, -0.2435796, -0.2435796, …
## $ JobRole_Human.Resources           <dbl> -0.1969388, -0.1969388, -0.1969388, …
## $ JobRole_Laboratory.Technician     <dbl> 2.1774148, -0.4580322, 2.1774148, -0…
## $ JobRole_Manager                   <dbl> -0.2729705, -0.2729705, -0.2729705, …
## $ JobRole_Manufacturing.Director    <dbl> -0.3407442, -0.3407442, -0.3407442, …
## $ JobRole_Research.Director         <dbl> -0.2496656, -0.2496656, -0.2496656, …
## $ JobRole_Research.Scientist        <dbl> -0.4876199, -0.4876199, -0.4876199, …
## $ JobRole_Sales.Executive           <dbl> -0.5209427, 1.9144644, -0.5209427, -…
## $ JobRole_Sales.Representative      <dbl> -0.2310368, -0.2310368, -0.2310368, …
## $ JobSatisfaction_X2                <dbl> 2.0278715, -0.4918094, 2.0278715, 2.…
## $ JobSatisfaction_X3                <dbl> -0.6822162, -0.6822162, -0.6822162, …
## $ JobSatisfaction_X4                <dbl> -0.6822162, -0.6822162, -0.6822162, …
## $ MaritalStatus_Married             <dbl> -0.9365072, -0.9365072, -0.9365072, …
## $ MaritalStatus_Single              <dbl> -0.6280644, 1.5879363, 1.5879363, -0…
## $ OverTime_Yes                      <dbl> 1.5471208, -0.6446337, -0.6446337, 1…
## $ PerformanceRating_X4              <dbl> 2.4050971, -0.4146719, -0.4146719, -…
## $ RelationshipSatisfaction_X2       <dbl> -0.4959911, -0.4959911, -0.4959911, …
## $ RelationshipSatisfaction_X3       <dbl> -0.7075774, 1.4094942, -0.7075774, -…
## $ RelationshipSatisfaction_X4       <dbl> 1.6418718, -0.6074324, -0.6074324, -…
## $ StockOptionLevel_X1               <dbl> 1.1612862, -0.8588117, -0.8588117, -…
## $ StockOptionLevel_X2               <dbl> -0.355201, -0.355201, -0.355201, -0.…
## $ StockOptionLevel_X3               <dbl> -0.2785709, -0.2785709, -0.2785709, …
## $ WorkLifeBalance_X2                <dbl> -0.5868564, 1.6994383, -0.5868564, -…
## $ WorkLifeBalance_X3                <dbl> -1.1612862, -1.1612862, 0.8588117, 0…
## $ WorkLifeBalance_X4                <dbl> -0.3646599, -0.3646599, -0.3646599, …

Specify Model

xgboost_spec <- 
  boost_tree(trees = tune(), tree_depth = tune()) %>% 
  set_mode("classification") %>% 
  set_engine("xgboost") 

xgboost_workflow <- 
  workflow() %>% 
  add_recipe(xgboost_rec) %>% 
  add_model(xgboost_spec) 

Tune Hyperparameters

tree_grid <- grid_regular(trees(),
                          tree_depth(),
                          levels = 5)

doParallel::registerDoParallel()

set.seed(45034)
xgboost_tune <-
  tune_grid(xgboost_workflow, 
            resamples = data_cv, 
            grid = 5,
            control = control_grid(save_pred = TRUE))

Model Evaluation

Identifying optimal values for hyperparameters

collect_metrics(xgboost_tune)
## # A tibble: 10 × 8
##    trees tree_depth .metric  .estimator  mean     n std_err .config             
##    <int>      <int> <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
##  1  1311          2 accuracy binary     0.859    10  0.0129 Preprocessor1_Model1
##  2  1311          2 roc_auc  binary     0.696    10  0.0269 Preprocessor1_Model1
##  3   641          5 accuracy binary     0.848    10  0.0113 Preprocessor1_Model2
##  4   641          5 roc_auc  binary     0.704    10  0.0279 Preprocessor1_Model2
##  5   805          8 accuracy binary     0.856    10  0.0141 Preprocessor1_Model3
##  6   805          8 roc_auc  binary     0.689    10  0.0318 Preprocessor1_Model3
##  7  1883         10 accuracy binary     0.856    10  0.0129 Preprocessor1_Model4
##  8  1883         10 roc_auc  binary     0.697    10  0.0264 Preprocessor1_Model4
##  9   215         13 accuracy binary     0.848    10  0.0123 Preprocessor1_Model5
## 10   215         13 roc_auc  binary     0.707    10  0.0304 Preprocessor1_Model5
collect_predictions(xgboost_tune) %>%
    group_by(id) %>%
    roc_curve(Attrition, .pred_Left) %>%
    autoplot()

Fit the model for the last time

xgboost_last <- xgboost_workflow %>%
    finalize_workflow(select_best(xgboost_tune, metric = "accuracy")) %>%
    last_fit(data_split)

collect_metrics(xgboost_last)
## # A tibble: 2 × 4
##   .metric  .estimator .estimate .config             
##   <chr>    <chr>          <dbl> <chr>               
## 1 accuracy binary         0.825 Preprocessor1_Model1
## 2 roc_auc  binary         0.736 Preprocessor1_Model1
collect_predictions(xgboost_last) %>%
    yardstick::conf_mat(Attrition, .pred_class) %>%
    autoplot()

library(vip)
## 
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
## 
##     vi
xgboost_last %>%
    workflows::extract_fit_engine() %>%
    vip()

Conlusion

The previous model had accuracy of 0.731 and AUC of 0.633