Goal: is to predict attrition, employees who are likely to leave the company.
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.1.4 ✔ readr 2.1.5
## ✔ forcats 1.0.0 ✔ stringr 1.5.1
## ✔ ggplot2 3.5.1 ✔ tibble 3.2.1
## ✔ lubridate 1.9.3 ✔ tidyr 1.3.1
## ✔ purrr 1.0.2
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(correlationfunnel)
## ══ Using correlationfunnel? ════════════════════════════════════════════════════
## You might also be interested in applied data science training for business.
## </> Learn more at - www.business-science.io </>
data <- read_csv("../00_data/WA_Fn-UseC_-HR-Employee-Attrition.csv")
## Rows: 1470 Columns: 35
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (9): Attrition, BusinessTravel, Department, EducationField, Gender, Job...
## dbl (26): Age, DailyRate, DistanceFromHome, Education, EmployeeCount, Employ...
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
skimr::skim(data)
Name | data |
Number of rows | 1470 |
Number of columns | 35 |
_______________________ | |
Column type frequency: | |
character | 9 |
numeric | 26 |
________________________ | |
Group variables | None |
Variable type: character
skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
---|---|---|---|---|---|---|---|
Attrition | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
BusinessTravel | 0 | 1 | 10 | 17 | 0 | 3 | 0 |
Department | 0 | 1 | 5 | 22 | 0 | 3 | 0 |
EducationField | 0 | 1 | 5 | 16 | 0 | 6 | 0 |
Gender | 0 | 1 | 4 | 6 | 0 | 2 | 0 |
JobRole | 0 | 1 | 7 | 25 | 0 | 9 | 0 |
MaritalStatus | 0 | 1 | 6 | 8 | 0 | 3 | 0 |
Over18 | 0 | 1 | 1 | 1 | 0 | 1 | 0 |
OverTime | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
Age | 0 | 1 | 36.92 | 9.14 | 18 | 30.00 | 36.0 | 43.00 | 60 | ▂▇▇▃▂ |
DailyRate | 0 | 1 | 802.49 | 403.51 | 102 | 465.00 | 802.0 | 1157.00 | 1499 | ▇▇▇▇▇ |
DistanceFromHome | 0 | 1 | 9.19 | 8.11 | 1 | 2.00 | 7.0 | 14.00 | 29 | ▇▅▂▂▂ |
Education | 0 | 1 | 2.91 | 1.02 | 1 | 2.00 | 3.0 | 4.00 | 5 | ▂▃▇▆▁ |
EmployeeCount | 0 | 1 | 1.00 | 0.00 | 1 | 1.00 | 1.0 | 1.00 | 1 | ▁▁▇▁▁ |
EmployeeNumber | 0 | 1 | 1024.87 | 602.02 | 1 | 491.25 | 1020.5 | 1555.75 | 2068 | ▇▇▇▇▇ |
EnvironmentSatisfaction | 0 | 1 | 2.72 | 1.09 | 1 | 2.00 | 3.0 | 4.00 | 4 | ▅▅▁▇▇ |
HourlyRate | 0 | 1 | 65.89 | 20.33 | 30 | 48.00 | 66.0 | 83.75 | 100 | ▇▇▇▇▇ |
JobInvolvement | 0 | 1 | 2.73 | 0.71 | 1 | 2.00 | 3.0 | 3.00 | 4 | ▁▃▁▇▁ |
JobLevel | 0 | 1 | 2.06 | 1.11 | 1 | 1.00 | 2.0 | 3.00 | 5 | ▇▇▃▂▁ |
JobSatisfaction | 0 | 1 | 2.73 | 1.10 | 1 | 2.00 | 3.0 | 4.00 | 4 | ▅▅▁▇▇ |
MonthlyIncome | 0 | 1 | 6502.93 | 4707.96 | 1009 | 2911.00 | 4919.0 | 8379.00 | 19999 | ▇▅▂▁▂ |
MonthlyRate | 0 | 1 | 14313.10 | 7117.79 | 2094 | 8047.00 | 14235.5 | 20461.50 | 26999 | ▇▇▇▇▇ |
NumCompaniesWorked | 0 | 1 | 2.69 | 2.50 | 0 | 1.00 | 2.0 | 4.00 | 9 | ▇▃▂▂▁ |
PercentSalaryHike | 0 | 1 | 15.21 | 3.66 | 11 | 12.00 | 14.0 | 18.00 | 25 | ▇▅▃▂▁ |
PerformanceRating | 0 | 1 | 3.15 | 0.36 | 3 | 3.00 | 3.0 | 3.00 | 4 | ▇▁▁▁▂ |
RelationshipSatisfaction | 0 | 1 | 2.71 | 1.08 | 1 | 2.00 | 3.0 | 4.00 | 4 | ▅▅▁▇▇ |
StandardHours | 0 | 1 | 80.00 | 0.00 | 80 | 80.00 | 80.0 | 80.00 | 80 | ▁▁▇▁▁ |
StockOptionLevel | 0 | 1 | 0.79 | 0.85 | 0 | 0.00 | 1.0 | 1.00 | 3 | ▇▇▁▂▁ |
TotalWorkingYears | 0 | 1 | 11.28 | 7.78 | 0 | 6.00 | 10.0 | 15.00 | 40 | ▇▇▂▁▁ |
TrainingTimesLastYear | 0 | 1 | 2.80 | 1.29 | 0 | 2.00 | 3.0 | 3.00 | 6 | ▂▇▇▂▃ |
WorkLifeBalance | 0 | 1 | 2.76 | 0.71 | 1 | 2.00 | 3.0 | 3.00 | 4 | ▁▃▁▇▂ |
YearsAtCompany | 0 | 1 | 7.01 | 6.13 | 0 | 3.00 | 5.0 | 9.00 | 40 | ▇▂▁▁▁ |
YearsInCurrentRole | 0 | 1 | 4.23 | 3.62 | 0 | 2.00 | 3.0 | 7.00 | 18 | ▇▃▂▁▁ |
YearsSinceLastPromotion | 0 | 1 | 2.19 | 3.22 | 0 | 0.00 | 1.0 | 3.00 | 15 | ▇▁▁▁▁ |
YearsWithCurrManager | 0 | 1 | 4.12 | 3.57 | 0 | 2.00 | 3.0 | 7.00 | 17 | ▇▂▅▁▁ |
Issues with data * Missing Values * Factors or numeric variables * Education, EnivironmentSatisfaction, JobInvolvement, JobSatisfaction, PerformanceRating, RelationshipSatisfaction, WorkLifeBalance * Zero variance variables * Over18, EmployCount, StandardHours * Character variables: Convert them to numbers in the recipes steps * Unbalanced target variable: Attrition * ID variable: EmployeeNumber
factors_vec <- data %>% select(Education, EnvironmentSatisfaction, JobInvolvement, JobSatisfaction, PerformanceRating, RelationshipSatisfaction, WorkLifeBalance, JobLevel, StockOptionLevel) %>% names()
data_clean <- data %>%
# Address factors imported as numeric
mutate(across(all_of(factors_vec), as.factor)) %>%
# Drop zero-variance variables
select(-c(Over18, EmployeeCount, StandardHours)) %>%
# Recode Attrition
mutate(Attrition = if_else(Attrition == "Yes", "Left", Attrition))
data_clean %>% count(Attrition)
## # A tibble: 2 × 2
## Attrition n
## <chr> <int>
## 1 Left 237
## 2 No 1233
data_clean %>%
ggplot(aes(Attrition)) +
geom_bar()
attrition vs. monthly income
data_clean %>%
ggplot(aes(Attrition, MonthlyIncome)) +
geom_boxplot()
correlation plot
# Step 1: Binarize
data_binarized <- data_clean %>%
select(-EmployeeNumber) %>%
binarize()
data_binarized %>% glimpse()
## Rows: 1,470
## Columns: 120
## $ `Age__-Inf_30` <dbl> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, …
## $ Age__30_36 <dbl> 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, …
## $ Age__36_43 <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ Age__43_Inf <dbl> 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ Attrition__Left <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ Attrition__No <dbl> 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ `BusinessTravel__Non-Travel` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ BusinessTravel__Travel_Frequently <dbl> 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ BusinessTravel__Travel_Rarely <dbl> 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, …
## $ `DailyRate__-Inf_465` <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ DailyRate__465_802 <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ DailyRate__802_1157 <dbl> 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, …
## $ DailyRate__1157_Inf <dbl> 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, …
## $ Department__Human_Resources <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `Department__Research_&_Development` <dbl> 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ Department__Sales <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `DistanceFromHome__-Inf_2` <dbl> 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, …
## $ DistanceFromHome__2_7 <dbl> 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ DistanceFromHome__7_14 <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ DistanceFromHome__14_Inf <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, …
## $ Education__1 <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, …
## $ Education__2 <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ Education__3 <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, …
## $ Education__4 <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ Education__5 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Human_Resources <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Life_Sciences <dbl> 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, …
## $ EducationField__Marketing <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Medical <dbl> 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, …
## $ EducationField__Other <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Technical_Degree <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EnvironmentSatisfaction__1 <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, …
## $ EnvironmentSatisfaction__2 <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EnvironmentSatisfaction__3 <dbl> 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, …
## $ EnvironmentSatisfaction__4 <dbl> 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, …
## $ Gender__Female <dbl> 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ Gender__Male <dbl> 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, …
## $ `HourlyRate__-Inf_48` <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, …
## $ HourlyRate__48_66 <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ HourlyRate__66_83.75 <dbl> 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, …
## $ HourlyRate__83.75_Inf <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, …
## $ JobInvolvement__1 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobInvolvement__2 <dbl> 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobInvolvement__3 <dbl> 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, …
## $ JobInvolvement__4 <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, …
## $ JobLevel__1 <dbl> 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, …
## $ JobLevel__2 <dbl> 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ JobLevel__3 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobLevel__4 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobLevel__5 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Healthcare_Representative <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ JobRole__Human_Resources <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Laboratory_Technician <dbl> 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, …
## $ JobRole__Manager <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Manufacturing_Director <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobRole__Research_Director <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Research_Scientist <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Sales_Executive <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Sales_Representative <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobSatisfaction__1 <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ JobSatisfaction__2 <dbl> 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, …
## $ JobSatisfaction__3 <dbl> 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, …
## $ JobSatisfaction__4 <dbl> 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ MaritalStatus__Divorced <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, …
## $ MaritalStatus__Married <dbl> 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, …
## $ MaritalStatus__Single <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, …
## $ `MonthlyIncome__-Inf_2911` <dbl> 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, …
## $ MonthlyIncome__2911_4919 <dbl> 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, …
## $ MonthlyIncome__4919_8379 <dbl> 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ MonthlyIncome__8379_Inf <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ `MonthlyRate__-Inf_8047` <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ MonthlyRate__8047_14235.5 <dbl> 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, …
## $ MonthlyRate__14235.5_20461.5 <dbl> 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, …
## $ MonthlyRate__20461.5_Inf <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ `NumCompaniesWorked__-Inf_1` <dbl> 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, …
## $ NumCompaniesWorked__1_2 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ NumCompaniesWorked__2_4 <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ NumCompaniesWorked__4_Inf <dbl> 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, …
## $ OverTime__No <dbl> 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, …
## $ OverTime__Yes <dbl> 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ `PercentSalaryHike__-Inf_12` <dbl> 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, …
## $ PercentSalaryHike__12_14 <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, …
## $ PercentSalaryHike__14_18 <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ PercentSalaryHike__18_Inf <dbl> 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, …
## $ PerformanceRating__3 <dbl> 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, …
## $ PerformanceRating__4 <dbl> 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, …
## $ RelationshipSatisfaction__1 <dbl> 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ RelationshipSatisfaction__2 <dbl> 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, …
## $ RelationshipSatisfaction__3 <dbl> 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, …
## $ RelationshipSatisfaction__4 <dbl> 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ StockOptionLevel__0 <dbl> 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ StockOptionLevel__1 <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, …
## $ StockOptionLevel__2 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ StockOptionLevel__3 <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ `TotalWorkingYears__-Inf_6` <dbl> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, …
## $ TotalWorkingYears__6_10 <dbl> 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ TotalWorkingYears__10_15 <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ TotalWorkingYears__15_Inf <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ `TrainingTimesLastYear__-Inf_2` <dbl> 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, …
## $ TrainingTimesLastYear__2_3 <dbl> 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, …
## $ TrainingTimesLastYear__3_Inf <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ WorkLifeBalance__1 <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ WorkLifeBalance__2 <dbl> 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, …
## $ WorkLifeBalance__3 <dbl> 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, …
## $ WorkLifeBalance__4 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsAtCompany__-Inf_3` <dbl> 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsAtCompany__3_5 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ YearsAtCompany__5_9 <dbl> 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, …
## $ YearsAtCompany__9_Inf <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsInCurrentRole__-Inf_2` <dbl> 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsInCurrentRole__2_3 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ YearsInCurrentRole__3_7 <dbl> 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, …
## $ YearsInCurrentRole__7_Inf <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsSinceLastPromotion__-Inf_1` <dbl> 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, …
## $ YearsSinceLastPromotion__1_3 <dbl> 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, …
## $ YearsSinceLastPromotion__3_Inf <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ `YearsWithCurrManager__-Inf_2` <dbl> 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsWithCurrManager__2_3 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ YearsWithCurrManager__3_7 <dbl> 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, …
## $ YearsWithCurrManager__7_Inf <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
# Step 2: Correlation
data_correlation <- data_binarized %>%
correlate(Attrition__Left)
data_correlation
## # A tibble: 120 × 3
## feature bin correlation
## <fct> <chr> <dbl>
## 1 Attrition Left 1
## 2 Attrition No -1
## 3 OverTime Yes 0.246
## 4 OverTime No -0.246
## 5 JobLevel 1 0.213
## 6 MonthlyIncome -Inf_2911 0.207
## 7 StockOptionLevel 0 0.195
## 8 YearsAtCompany -Inf_3 0.183
## 9 MaritalStatus Single 0.175
## 10 TotalWorkingYears -Inf_6 0.169
## # ℹ 110 more rows
# Step 3: Plot
data_correlation %>%
correlationfunnel::plot_correlation_funnel()
## Warning: ggrepel: 72 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps
library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
## ✔ broom 1.0.7 ✔ rsample 1.2.1
## ✔ dials 1.4.0 ✔ tune 1.2.1
## ✔ infer 1.0.7 ✔ workflows 1.1.4
## ✔ modeldata 1.4.0 ✔ workflowsets 1.1.0
## ✔ parsnip 1.3.0 ✔ yardstick 1.3.2
## ✔ recipes 1.1.0
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ recipes::fixed() masks stringr::fixed()
## ✖ dplyr::lag() masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step() masks stats::step()
## • Use suppressPackageStartupMessages() to eliminate package startup messages
set.seed(1234)
# data_clean <- data_clean %>% sample_n(200)
data_split <- initial_split(data_clean, strata = Attrition)
data_train <- training(data_split)
data_test <- testing(data_split)
data_cv <- rsample::vfold_cv(data_train, strata = Attrition)
data_cv
## # 10-fold cross-validation using stratification
## # A tibble: 10 × 2
## splits id
## <list> <chr>
## 1 <split [990/111]> Fold01
## 2 <split [990/111]> Fold02
## 3 <split [990/111]> Fold03
## 4 <split [990/111]> Fold04
## 5 <split [991/110]> Fold05
## 6 <split [991/110]> Fold06
## 7 <split [991/110]> Fold07
## 8 <split [992/109]> Fold08
## 9 <split [992/109]> Fold09
## 10 <split [992/109]> Fold10
library(themis)
xgboost_rec <- recipes::recipe(Attrition ~ ., data = data_train) %>%
update_role(EmployeeNumber, new_role = "ID") %>%
step_dummy(all_nominal_predictors()) %>%
step_normalize(all_numeric_predictors()) %>%
step_pca(all_numeric_predictors(), threshold = .99) %>%
step_smote(Attrition)
xgboost_rec %>% prep() %>% juice() %>% glimpse()
## Rows: 1,848
## Columns: 55
## $ EmployeeNumber <dbl> 4, 27, 33, 45, 47, 55, 58, 64, 90, 133, 137, 142, 147, …
## $ Attrition <fct> Left, Left, Left, Left, Left, Left, Left, Left, Left, L…
## $ PC01 <dbl> -2.4866153, -1.2706453, -0.7324540, -2.3254766, -1.3338…
## $ PC02 <dbl> -1.5740328, 1.6147295, -1.2705724, -1.4395422, 3.001711…
## $ PC03 <dbl> -0.4071661, -0.2353137, -2.5874581, 1.8732588, 2.134789…
## $ PC04 <dbl> -2.1621159, -1.8105866, 1.3721663, 0.9392024, -1.396647…
## $ PC05 <dbl> 0.45861920, -2.77363603, -3.25156361, -0.07203097, 0.26…
## $ PC06 <dbl> 1.615563727, 2.253536625, 1.083269298, -2.292998738, 2.…
## $ PC07 <dbl> 1.4883213, 2.2164428, 0.3940496, 1.0582713, 0.3149397, …
## $ PC08 <dbl> -2.28935947, -0.65896251, 0.03078843, 2.04345583, -0.79…
## $ PC09 <dbl> -0.06772915, 0.29111306, 0.99191980, -0.32085252, -0.78…
## $ PC10 <dbl> 0.29335993, -1.11306746, -0.47440345, 0.53996616, 0.391…
## $ PC11 <dbl> 0.1659217504, 0.5061105005, 1.9420598055, 1.1025003461,…
## $ PC12 <dbl> -4.675975e-01, 4.111882e-01, 9.276777e-02, -6.238264e-0…
## $ PC13 <dbl> 2.85396775, -1.07113382, -0.93685907, 1.18721592, 0.568…
## $ PC14 <dbl> 1.38057389, 2.51085129, -0.30440657, -1.42427939, 0.993…
## $ PC15 <dbl> -0.81419960, 0.12003963, -0.69553953, 1.16236038, -0.60…
## $ PC16 <dbl> -0.9618830, 1.3746140, 1.9681632, 0.7935656, 0.3307197,…
## $ PC17 <dbl> -0.44532481, -1.23972263, -1.02729524, 0.67439574, 1.70…
## $ PC18 <dbl> 0.380636453, -0.449154973, -0.330690090, 0.018486927, 1…
## $ PC19 <dbl> -0.07432895, -0.02816621, 1.55639820, 1.18698194, 0.337…
## $ PC20 <dbl> 0.190055041, -0.943213996, -0.710395599, -0.068611621, …
## $ PC21 <dbl> 0.93576699, -0.87120310, 0.46111288, 1.09552009, -1.274…
## $ PC22 <dbl> 0.7723501, 1.7592785, 0.4209123, -0.1822564, -0.9687064…
## $ PC23 <dbl> 0.59058788, -0.88406532, 1.19129021, 2.03861336, -0.793…
## $ PC24 <dbl> 0.34406058, 0.35199172, 1.17268274, 0.97714776, 1.60921…
## $ PC25 <dbl> -2.840061341, 0.081000861, -0.558817300, 0.739376717, -…
## $ PC26 <dbl> 0.00670827, -1.13119414, 0.55055528, -0.45185739, -0.26…
## $ PC27 <dbl> -2.587450157, 0.562290099, 0.551239258, 0.061306846, -1…
## $ PC28 <dbl> -1.007635212, 0.495167051, -0.840768378, -0.178329865, …
## $ PC29 <dbl> 0.25019205, 0.28404406, 0.36937699, 0.08934139, 0.29380…
## $ PC30 <dbl> 0.29512988, 1.17011590, 0.04668851, -0.47334117, -0.732…
## $ PC31 <dbl> 0.4689843, -1.9467875, 1.2227697, -1.1700914, -2.281479…
## $ PC32 <dbl> 2.737601575, 0.654485660, 0.583235651, -0.383841872, 0.…
## $ PC33 <dbl> -0.18667937, -0.72183346, -0.84744782, 0.02174618, -1.6…
## $ PC34 <dbl> -1.98922408, -0.61994292, -0.54052521, -0.69205872, -1.…
## $ PC35 <dbl> -0.23337720, 0.30718959, 0.17335990, 0.01494822, 0.3464…
## $ PC36 <dbl> 0.679982603, 0.811624734, -1.000920714, 0.278472471, 0.…
## $ PC37 <dbl> 0.65511540, 1.05547283, -0.46401625, -0.39019346, 0.279…
## $ PC38 <dbl> -0.56693010, 1.16820836, -0.81323056, -1.05320897, 0.09…
## $ PC39 <dbl> -1.424101065, -2.485640251, -1.698241592, -0.106716279,…
## $ PC40 <dbl> 0.26031526, -0.32560543, 0.35430932, -0.33156991, 2.329…
## $ PC41 <dbl> -0.12539323, -0.22236295, -0.55833599, -0.40448086, 1.3…
## $ PC42 <dbl> 0.43903304, -0.20509255, 0.82174524, 0.44922386, 0.6396…
## $ PC43 <dbl> 0.13223427, -1.08633953, 0.74411180, 0.25789769, 0.1192…
## $ PC44 <dbl> 0.29528009, 0.27204883, 0.10531949, -0.03792183, -0.316…
## $ PC45 <dbl> -0.09218376, 0.08453865, 0.94289152, 0.31490531, 0.3020…
## $ PC46 <dbl> -0.216241476, 1.157640429, 1.101114969, -0.568406306, -…
## $ PC47 <dbl> 0.27488823, 0.21013194, 0.55499237, -0.21569969, -0.667…
## $ PC48 <dbl> -0.076056764, 0.159877279, 0.178370860, 0.957661221, -1…
## $ PC49 <dbl> 0.040037002, 0.020037544, 0.366118639, -0.274663728, -0…
## $ PC50 <dbl> -0.189295083, 0.281183292, 0.363495234, 0.272166154, -0…
## $ PC51 <dbl> -0.12195720, 0.30732072, 0.14659593, -0.40381741, 0.144…
## $ PC52 <dbl> 0.01363750, -0.11063079, 1.20800782, -0.41780709, 0.295…
## $ PC53 <dbl> -0.21278128, -0.20296817, 0.07684094, -0.10573568, -0.2…
xgboost_spec <-
boost_tree(trees = tune(), tree_depth = tune()) %>%
set_mode("classification") %>%
set_engine("xgboost")
xgboost_workflow <-
workflow() %>%
add_recipe(xgboost_rec) %>%
add_model(xgboost_spec)
tree_grid <- grid_regular(trees(),
tree_depth(),
levels = 5)
doParallel::registerDoParallel()
set.seed(65743)
xgboost_tune <-
tune_grid(xgboost_workflow,
resamples = data_cv,
grid = 5,
control = control_grid(save_pred = TRUE))
collect_metrics(xgboost_tune)
## # A tibble: 15 × 8
## trees tree_depth .metric .estimator mean n std_err .config
## <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 1741 3 accuracy binary 0.855 10 0.0123 Preprocessor1_Mo…
## 2 1741 3 brier_class binary 0.120 10 0.0107 Preprocessor1_Mo…
## 3 1741 3 roc_auc binary 0.819 10 0.0218 Preprocessor1_Mo…
## 4 885 5 accuracy binary 0.862 10 0.0111 Preprocessor1_Mo…
## 5 885 5 brier_class binary 0.120 10 0.0104 Preprocessor1_Mo…
## 6 885 5 roc_auc binary 0.820 10 0.0198 Preprocessor1_Mo…
## 7 325 7 accuracy binary 0.846 10 0.0123 Preprocessor1_Mo…
## 8 325 7 brier_class binary 0.122 10 0.0106 Preprocessor1_Mo…
## 9 325 7 roc_auc binary 0.806 10 0.0233 Preprocessor1_Mo…
## 10 1312 12 accuracy binary 0.846 10 0.00978 Preprocessor1_Mo…
## 11 1312 12 brier_class binary 0.117 10 0.00803 Preprocessor1_Mo…
## 12 1312 12 roc_auc binary 0.819 10 0.0186 Preprocessor1_Mo…
## 13 555 15 accuracy binary 0.846 10 0.0104 Preprocessor1_Mo…
## 14 555 15 brier_class binary 0.116 10 0.00945 Preprocessor1_Mo…
## 15 555 15 roc_auc binary 0.821 10 0.0198 Preprocessor1_Mo…
collect_predictions(xgboost_tune) %>%
group_by(id) %>%
roc_curve(Attrition, .pred_Left) %>%
autoplot()
xgboost_last <- xgboost_workflow %>%
finalize_workflow(select_best(xgboost_tune, metric = "accuracy")) %>%
last_fit(data_split)
collect_metrics(xgboost_last)
## # A tibble: 3 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 accuracy binary 0.854 Preprocessor1_Model1
## 2 roc_auc binary 0.788 Preprocessor1_Model1
## 3 brier_class binary 0.121 Preprocessor1_Model1
collect_predictions(xgboost_last) %>%
yardstick::conf_mat(Attrition, .pred_class) %>%
autoplot()
library(vip)
##
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
##
## vi
xgboost_last %>%
workflows::extract_fit_engine() %>%
vip()
The previous model had accuracy of 0.851 and AUC of 0.753.