# Load the HR employee-attrition CSV (path is relative to the project folder).
data <- read_csv(
  file = "../PSU_DAT3000_IntroToDA/00_data/WA_Fn-UseC_-HR-Employee-Attrition.csv"
)
## Rows: 1470 Columns: 35
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (9): Attrition, BusinessTravel, Department, EducationField, Gender, Job...
## dbl (26): Age, DailyRate, DistanceFromHome, Education, EmployeeCount, Employ...
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Explore data: summarise every column, grouped by column type
# (missing counts, unique values, and distribution statistics).
skimr::skim(data)
| Name | data |
| Number of rows | 1470 |
| Number of columns | 35 |
| _______________________ | |
| Column type frequency: | |
| character | 9 |
| numeric | 26 |
| ________________________ | |
| Group variables | None |
Variable type: character
| skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
|---|---|---|---|---|---|---|---|
| Attrition | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
| BusinessTravel | 0 | 1 | 10 | 17 | 0 | 3 | 0 |
| Department | 0 | 1 | 5 | 22 | 0 | 3 | 0 |
| EducationField | 0 | 1 | 5 | 16 | 0 | 6 | 0 |
| Gender | 0 | 1 | 4 | 6 | 0 | 2 | 0 |
| JobRole | 0 | 1 | 7 | 25 | 0 | 9 | 0 |
| MaritalStatus | 0 | 1 | 6 | 8 | 0 | 3 | 0 |
| Over18 | 0 | 1 | 1 | 1 | 0 | 1 | 0 |
| OverTime | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| Age | 0 | 1 | 36.92 | 9.14 | 18 | 30.00 | 36.0 | 43.00 | 60 | ▂▇▇▃▂ |
| DailyRate | 0 | 1 | 802.49 | 403.51 | 102 | 465.00 | 802.0 | 1157.00 | 1499 | ▇▇▇▇▇ |
| DistanceFromHome | 0 | 1 | 9.19 | 8.11 | 1 | 2.00 | 7.0 | 14.00 | 29 | ▇▅▂▂▂ |
| Education | 0 | 1 | 2.91 | 1.02 | 1 | 2.00 | 3.0 | 4.00 | 5 | ▂▃▇▆▁ |
| EmployeeCount | 0 | 1 | 1.00 | 0.00 | 1 | 1.00 | 1.0 | 1.00 | 1 | ▁▁▇▁▁ |
| EmployeeNumber | 0 | 1 | 1024.87 | 602.02 | 1 | 491.25 | 1020.5 | 1555.75 | 2068 | ▇▇▇▇▇ |
| EnvironmentSatisfaction | 0 | 1 | 2.72 | 1.09 | 1 | 2.00 | 3.0 | 4.00 | 4 | ▅▅▁▇▇ |
| HourlyRate | 0 | 1 | 65.89 | 20.33 | 30 | 48.00 | 66.0 | 83.75 | 100 | ▇▇▇▇▇ |
| JobInvolvement | 0 | 1 | 2.73 | 0.71 | 1 | 2.00 | 3.0 | 3.00 | 4 | ▁▃▁▇▁ |
| JobLevel | 0 | 1 | 2.06 | 1.11 | 1 | 1.00 | 2.0 | 3.00 | 5 | ▇▇▃▂▁ |
| JobSatisfaction | 0 | 1 | 2.73 | 1.10 | 1 | 2.00 | 3.0 | 4.00 | 4 | ▅▅▁▇▇ |
| MonthlyIncome | 0 | 1 | 6502.93 | 4707.96 | 1009 | 2911.00 | 4919.0 | 8379.00 | 19999 | ▇▅▂▁▂ |
| MonthlyRate | 0 | 1 | 14313.10 | 7117.79 | 2094 | 8047.00 | 14235.5 | 20461.50 | 26999 | ▇▇▇▇▇ |
| NumCompaniesWorked | 0 | 1 | 2.69 | 2.50 | 0 | 1.00 | 2.0 | 4.00 | 9 | ▇▃▂▂▁ |
| PercentSalaryHike | 0 | 1 | 15.21 | 3.66 | 11 | 12.00 | 14.0 | 18.00 | 25 | ▇▅▃▂▁ |
| PerformanceRating | 0 | 1 | 3.15 | 0.36 | 3 | 3.00 | 3.0 | 3.00 | 4 | ▇▁▁▁▂ |
| RelationshipSatisfaction | 0 | 1 | 2.71 | 1.08 | 1 | 2.00 | 3.0 | 4.00 | 4 | ▅▅▁▇▇ |
| StandardHours | 0 | 1 | 80.00 | 0.00 | 80 | 80.00 | 80.0 | 80.00 | 80 | ▁▁▇▁▁ |
| StockOptionLevel | 0 | 1 | 0.79 | 0.85 | 0 | 0.00 | 1.0 | 1.00 | 3 | ▇▇▁▂▁ |
| TotalWorkingYears | 0 | 1 | 11.28 | 7.78 | 0 | 6.00 | 10.0 | 15.00 | 40 | ▇▇▂▁▁ |
| TrainingTimesLastYear | 0 | 1 | 2.80 | 1.29 | 0 | 2.00 | 3.0 | 3.00 | 6 | ▂▇▇▂▃ |
| WorkLifeBalance | 0 | 1 | 2.76 | 0.71 | 1 | 2.00 | 3.0 | 3.00 | 4 | ▁▃▁▇▂ |
| YearsAtCompany | 0 | 1 | 7.01 | 6.13 | 0 | 3.00 | 5.0 | 9.00 | 40 | ▇▂▁▁▁ |
| YearsInCurrentRole | 0 | 1 | 4.23 | 3.62 | 0 | 2.00 | 3.0 | 7.00 | 18 | ▇▃▂▁▁ |
| YearsSinceLastPromotion | 0 | 1 | 2.19 | 3.22 | 0 | 0.00 | 1.0 | 3.00 | 15 | ▇▁▁▁▁ |
| YearsWithCurrManager | 0 | 1 | 4.12 | 3.57 | 0 | 2.00 | 3.0 | 7.00 | 17 | ▇▂▅▁▁ |
# Columns that were imported as numeric but hold categorical codes
factors_vec <- c(
  "Education", "EnvironmentSatisfaction", "JobInvolvement", "JobSatisfaction",
  "PerformanceRating", "RelationshipSatisfaction", "WorkLifeBalance"
)

# Clean the data
data_clean <- data %>%
  mutate(
    # Treat the coded columns as factors rather than numbers
    across(all_of(factors_vec), as.factor),
    # Relabel the positive outcome so the two classes read "Left" / "No"
    Attrition = if_else(Attrition == "Yes", "Left", Attrition)
  ) %>%
  # Drop the zero-variance variables with a single negated vector
  select(-c(Over18, EmployeeCount, StandardHours))
# Explore data: how many employees fall in each outcome class
count(data_clean, Attrition)
## # A tibble: 2 × 2
## Attrition n
## <chr> <int>
## 1 Left 237
## 2 No 1233
# Bar chart of the outcome classes
ggplot(data_clean, aes(x = Attrition)) +
  geom_bar()
Attrition vs. monthly income
# Compare the monthly-income distribution of the two outcome classes
ggplot(data_clean, aes(x = Attrition, y = MonthlyIncome)) +
  geom_boxplot()
Correlation plot
# Step 1: Binarize every column into 0/1 indicators.
# The employee ID is removed first because it is an identifier, not a feature.
data_binarized <- correlationfunnel::binarize(
  select(data_clean, -EmployeeNumber)
)

glimpse(data_binarized)
## Rows: 1,470
## Columns: 120
## $ `Age__-Inf_30` <dbl> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, …
## $ Age__30_36 <dbl> 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, …
## $ Age__36_43 <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ Age__43_Inf <dbl> 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ Attrition__Left <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ Attrition__No <dbl> 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ `BusinessTravel__Non-Travel` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ BusinessTravel__Travel_Frequently <dbl> 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ BusinessTravel__Travel_Rarely <dbl> 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, …
## $ `DailyRate__-Inf_465` <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ DailyRate__465_802 <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ DailyRate__802_1157 <dbl> 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, …
## $ DailyRate__1157_Inf <dbl> 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, …
## $ Department__Human_Resources <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `Department__Research_&_Development` <dbl> 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ Department__Sales <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `DistanceFromHome__-Inf_2` <dbl> 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, …
## $ DistanceFromHome__2_7 <dbl> 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ DistanceFromHome__7_14 <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ DistanceFromHome__14_Inf <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, …
## $ Education__1 <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, …
## $ Education__2 <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ Education__3 <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, …
## $ Education__4 <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ Education__5 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Human_Resources <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Life_Sciences <dbl> 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, …
## $ EducationField__Marketing <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Medical <dbl> 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, …
## $ EducationField__Other <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EducationField__Technical_Degree <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EnvironmentSatisfaction__1 <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, …
## $ EnvironmentSatisfaction__2 <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ EnvironmentSatisfaction__3 <dbl> 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, …
## $ EnvironmentSatisfaction__4 <dbl> 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, …
## $ Gender__Female <dbl> 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ Gender__Male <dbl> 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, …
## $ `HourlyRate__-Inf_48` <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, …
## $ HourlyRate__48_66 <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ HourlyRate__66_83.75 <dbl> 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, …
## $ HourlyRate__83.75_Inf <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, …
## $ JobInvolvement__1 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobInvolvement__2 <dbl> 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobInvolvement__3 <dbl> 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, …
## $ JobInvolvement__4 <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, …
## $ JobLevel__1 <dbl> 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, …
## $ JobLevel__2 <dbl> 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ JobLevel__3 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobLevel__4 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobLevel__5 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Healthcare_Representative <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ JobRole__Human_Resources <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Laboratory_Technician <dbl> 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, …
## $ JobRole__Manager <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Manufacturing_Director <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ JobRole__Research_Director <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Research_Scientist <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Sales_Executive <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobRole__Sales_Representative <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ JobSatisfaction__1 <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ JobSatisfaction__2 <dbl> 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, …
## $ JobSatisfaction__3 <dbl> 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, …
## $ JobSatisfaction__4 <dbl> 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, …
## $ MaritalStatus__Divorced <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, …
## $ MaritalStatus__Married <dbl> 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, …
## $ MaritalStatus__Single <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, …
## $ `MonthlyIncome__-Inf_2911` <dbl> 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, …
## $ MonthlyIncome__2911_4919 <dbl> 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, …
## $ MonthlyIncome__4919_8379 <dbl> 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ MonthlyIncome__8379_Inf <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ `MonthlyRate__-Inf_8047` <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ MonthlyRate__8047_14235.5 <dbl> 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, …
## $ MonthlyRate__14235.5_20461.5 <dbl> 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, …
## $ MonthlyRate__20461.5_Inf <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ `NumCompaniesWorked__-Inf_1` <dbl> 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, …
## $ NumCompaniesWorked__1_2 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ NumCompaniesWorked__2_4 <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ NumCompaniesWorked__4_Inf <dbl> 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, …
## $ OverTime__No <dbl> 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, …
## $ OverTime__Yes <dbl> 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, …
## $ `PercentSalaryHike__-Inf_12` <dbl> 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, …
## $ PercentSalaryHike__12_14 <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, …
## $ PercentSalaryHike__14_18 <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ PercentSalaryHike__18_Inf <dbl> 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, …
## $ PerformanceRating__3 <dbl> 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, …
## $ PerformanceRating__4 <dbl> 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, …
## $ RelationshipSatisfaction__1 <dbl> 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ RelationshipSatisfaction__2 <dbl> 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, …
## $ RelationshipSatisfaction__3 <dbl> 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, …
## $ RelationshipSatisfaction__4 <dbl> 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ StockOptionLevel__0 <dbl> 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ StockOptionLevel__1 <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, …
## $ StockOptionLevel__2 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ StockOptionLevel__3 <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ `TotalWorkingYears__-Inf_6` <dbl> 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, …
## $ TotalWorkingYears__6_10 <dbl> 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, …
## $ TotalWorkingYears__10_15 <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ TotalWorkingYears__15_Inf <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ `TrainingTimesLastYear__-Inf_2` <dbl> 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, …
## $ TrainingTimesLastYear__2_3 <dbl> 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, …
## $ TrainingTimesLastYear__3_Inf <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ WorkLifeBalance__1 <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ WorkLifeBalance__2 <dbl> 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, …
## $ WorkLifeBalance__3 <dbl> 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, …
## $ WorkLifeBalance__4 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsAtCompany__-Inf_3` <dbl> 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsAtCompany__3_5 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ YearsAtCompany__5_9 <dbl> 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, …
## $ YearsAtCompany__9_Inf <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsInCurrentRole__-Inf_2` <dbl> 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsInCurrentRole__2_3 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ YearsInCurrentRole__3_7 <dbl> 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, …
## $ YearsInCurrentRole__7_Inf <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ `YearsSinceLastPromotion__-Inf_1` <dbl> 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, …
## $ YearsSinceLastPromotion__1_3 <dbl> 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, …
## $ YearsSinceLastPromotion__3_Inf <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ `YearsWithCurrManager__-Inf_2` <dbl> 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, …
## $ YearsWithCurrManager__2_3 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, …
## $ YearsWithCurrManager__3_7 <dbl> 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, …
## $ YearsWithCurrManager__7_Inf <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
# Step 2: Correlate every binary indicator with the "Left" outcome indicator
data_correlation <- correlationfunnel::correlate(
  data_binarized,
  target = Attrition__Left
)

data_correlation
## # A tibble: 120 × 3
## feature bin correlation
## <fct> <chr> <dbl>
## 1 Attrition Left 1
## 2 Attrition No -1
## 3 OverTime No -0.246
## 4 OverTime Yes 0.246
## 5 JobLevel 1 0.213
## 6 MonthlyIncome -Inf_2911 0.207
## 7 StockOptionLevel 0 0.195
## 8 YearsAtCompany -Inf_3 0.183
## 9 MaritalStatus Single 0.175
## 10 TotalWorkingYears -Inf_6 0.169
## # ℹ 110 more rows
# Step 3: Plot the correlations as a funnel
correlationfunnel::plot_correlation_funnel(data_correlation)
## Warning: The `size` argument of `element_line()` is deprecated as of ggplot2 3.4.0.
## ℹ Please use the `linewidth` argument instead.
## ℹ The deprecated feature was likely used in the correlationfunnel package.
## Please report the issue at
## <https://github.com/business-science/correlationfunnel/issues>.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
## Warning: The `size` argument of `element_rect()` is deprecated as of ggplot2 3.4.0.
## ℹ Please use the `linewidth` argument instead.
## ℹ The deprecated feature was likely used in the correlationfunnel package.
## Please report the issue at
## <https://github.com/business-science/correlationfunnel/issues>.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
## Warning: ggrepel: 73 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps
# Split data
set.seed(1234)

# Optional down-sampling for quicker experiments (left disabled):
# data_clean <- data_clean %>% sample_n(500)

# Train/test split, stratified on the outcome
data_split <- data_clean %>%
  rsample::initial_split(strata = Attrition)
data_train <- rsample::training(data_split)
data_test <- rsample::testing(data_split)

# Cross-validation folds on the training set, also stratified on the outcome
data_cv <- data_train %>%
  rsample::vfold_cv(strata = Attrition)
data_cv
## # 10-fold cross-validation using stratification
## # A tibble: 10 × 2
## splits id
## <list> <chr>
## 1 <split [990/111]> Fold01
## 2 <split [990/111]> Fold02
## 3 <split [990/111]> Fold03
## 4 <split [990/111]> Fold04
## 5 <split [991/110]> Fold05
## 6 <split [991/110]> Fold06
## 7 <split [991/110]> Fold07
## 8 <split [992/109]> Fold08
## 9 <split [992/109]> Fold09
## 10 <split [992/109]> Fold10
Preprocess data
# Preprocessing recipe, estimated on the training set only.
xgboost_rec <- recipes::recipe(Attrition ~ ., data = data_train) %>%
  # Keep the employee identifier in the data but out of the predictor set
  update_role(EmployeeNumber, new_role = "ID") %>%
  # Turn categorical predictors into indicator columns
  step_dummy(all_nominal_predictors()) %>%
  # Centre and scale so PCA is not dominated by large-valued columns
  step_normalize(all_numeric_predictors()) %>%
  # Replace the predictors with the components explaining 99% of the variance;
  # the model (and its variable importance) then works on PC columns.
  step_pca(all_numeric_predictors(), threshold = 0.99) %>%
  # Synthesise minority-class rows to balance the outcome
  step_smote(Attrition)

# Inspect the processed training data
glimpse(bake(prep(xgboost_rec), new_data = NULL))
## Rows: 1,848
## Columns: 51
## $ EmployeeNumber <dbl> 4, 27, 33, 45, 47, 55, 58, 64, 90, 133, 137, 142, 147, …
## $ Attrition <fct> Left, Left, Left, Left, Left, Left, Left, Left, Left, L…
## $ PC01 <dbl> -2.7008796, -1.3851016, -0.9898314, -2.7123949, -1.4082…
## $ PC02 <dbl> -1.1582242, 2.0758272, -1.1769684, -1.0635118, 3.547456…
## $ PC03 <dbl> -1.4071746, -1.5566383, -2.1704958, 1.5138776, 1.230319…
## $ PC04 <dbl> -1.8426942, -0.2189249, 3.3214414, 0.1627416, -0.940169…
## $ PC05 <dbl> -0.25210132, 2.70012402, 2.40029506, 0.16116327, -0.034…
## $ PC06 <dbl> 2.59254625, 2.81147695, 0.70726636, -2.44417469, 2.2332…
## $ PC07 <dbl> -0.55593527, -1.24958628, 0.01531976, -2.16787791, 0.21…
## $ PC08 <dbl> -0.89240997, -0.07652722, 0.56130410, 0.58312834, -0.90…
## $ PC09 <dbl> 1.19431200, 1.00186769, 1.07540726, -0.81843662, -0.096…
## $ PC10 <dbl> -2.08259705, 0.56337423, -0.15568277, -0.81644957, -0.7…
## $ PC11 <dbl> -0.3934252, 0.5379347, 1.1038419, 0.3179483, -0.8389709…
## $ PC12 <dbl> 0.6217673, -2.5481178, -0.7844686, 1.3383124, -0.806984…
## $ PC13 <dbl> -1.40210098, -1.04469438, 2.16121634, 0.96844715, 0.238…
## $ PC14 <dbl> 1.786546612, 0.947038464, 0.728867736, -1.203721858, 0.…
## $ PC15 <dbl> 1.50900812, -0.82277088, -1.29183287, -0.65406491, 0.60…
## $ PC16 <dbl> -0.19437803, -0.91896213, -1.64225604, -0.35247632, 1.4…
## $ PC17 <dbl> -0.13999669, 0.57010089, 0.94631147, 0.60202385, -0.970…
## $ PC18 <dbl> -0.74594330, 0.87513463, 0.19960168, -1.13921071, -0.21…
## $ PC19 <dbl> -0.1848677, -0.9562717, 0.4086475, 1.2973671, 2.3679767…
## $ PC20 <dbl> -0.55339508, -1.35106209, -0.01057609, -0.69481221, -1.…
## $ PC21 <dbl> -0.13065253, 1.45537716, 0.22364100, -0.18896653, 0.134…
## $ PC22 <dbl> 0.79779388, -0.96024363, -0.91304176, -1.98595539, -0.1…
## $ PC23 <dbl> 3.22852202, 0.23213193, 0.76844435, -0.19681532, 1.3233…
## $ PC24 <dbl> 2.25403847, -0.73294485, -0.28739630, 0.70231139, -0.27…
## $ PC25 <dbl> 1.12730949, -1.30313515, 0.92367530, 0.64488095, 0.4799…
## $ PC26 <dbl> 0.77564433, -0.38089595, -0.65786318, 0.50247084, 0.599…
## $ PC27 <dbl> -1.221687488, 0.874353237, -1.016860596, -0.144609586, …
## $ PC28 <dbl> -2.7410678, -0.5790063, -0.9980385, -0.4497550, 0.20913…
## $ PC29 <dbl> 0.2509422, 2.2249652, -0.8511565, 0.5617327, 2.2427028,…
## $ PC30 <dbl> -0.5160910, -0.6846793, -0.3356748, -1.2007427, -2.5159…
## $ PC31 <dbl> -0.786023708, -0.440781389, -0.610681242, 0.353466990, …
## $ PC32 <dbl> 1.3231294983, 0.6065014272, 0.7773408307, -0.5604557323…
## $ PC33 <dbl> 0.5301238, 1.6180841, -1.2078405, -1.0254676, 0.4020246…
## $ PC34 <dbl> 0.81150211, 0.33071364, -0.64331436, 0.41004008, 0.1788…
## $ PC35 <dbl> -1.40474186, -2.42122089, -1.68214959, -0.01447111, -1.…
## $ PC36 <dbl> -0.078536567, -0.123303251, 0.035522194, 0.524214319, 1…
## $ PC37 <dbl> -0.16825126, 0.33602570, -0.39537080, 0.30034723, -2.44…
## $ PC38 <dbl> -0.22667694, -0.24944872, -0.80922279, -0.59753526, 0.5…
## $ PC39 <dbl> 0.18055939, -1.14697557, 0.86855130, 0.33490275, 0.4211…
## $ PC40 <dbl> 0.04497181, 0.06581771, -0.82818604, -0.37140836, -0.30…
## $ PC41 <dbl> 0.34111038, -1.10608772, -1.08420254, 0.45253468, 0.106…
## $ PC42 <dbl> 0.194363520, 0.141264439, 0.405442315, -0.143984092, -0…
## $ PC43 <dbl> 0.16081450, -0.01588223, 0.13145777, -0.97414364, 0.766…
## $ PC44 <dbl> -0.06758401, -0.19587807, -0.62324861, -0.01742433, 0.9…
## $ PC45 <dbl> -0.183378946, 0.232482848, 0.236647628, 0.396652061, -0…
## $ PC46 <dbl> 0.22878570, -0.15130989, -0.03590376, 0.35287800, -0.22…
## $ PC47 <dbl> 0.118950505, 0.388201362, 0.246572558, -0.202709146, -0…
## $ PC48 <dbl> -0.31263005, -0.21284450, -0.99732266, -0.06094232, 0.1…
## $ PC49 <dbl> -0.114339812, -0.200857771, 0.954179455, -0.510904679, …
# Specify model: boosted trees with six hyperparameters left open for tuning
xgboost_spec <- boost_tree(
  trees = tune(),
  min_n = tune(),
  tree_depth = tune(),
  learn_rate = tune(),
  loss_reduction = tune(),
  sample_size = tune()
) %>%
  set_mode("classification") %>%
  set_engine("xgboost")

# Bundle the preprocessing recipe and the model specification
xgboost_workflow <- workflow(
  preprocessor = xgboost_rec,
  spec = xgboost_spec
)
# Tune hyperparameters

# NOTE(review): `tree_grid` is built here but is not passed to `tune_grid()`
# below, which asks for 5 automatically generated candidates instead.
tree_grid <- grid_regular(trees(), tree_depth(), levels = 5)

# NOTE(review): recent tune releases parallelise through future rather than
# foreach -- confirm this backend registration still has an effect.
doParallel::registerDoParallel()

set.seed(65743)
xgboost_tune <- xgboost_workflow %>%
  tune_grid(
    resamples = data_cv,
    grid = 5,
    # Keep the held-out predictions so ROC curves can be drawn afterwards
    control = control_grid(save_pred = TRUE)
  )
# Model evaluation: resampled metrics for every tuning candidate
xgboost_tune %>%
  collect_metrics()
## # A tibble: 15 × 12
## trees min_n tree_depth learn_rate loss_reduction sample_size .metric
## <int> <int> <int> <dbl> <dbl> <dbl> <chr>
## 1 1 30 15 0.0750 0.0422 0.625 accuracy
## 2 1 30 15 0.0750 0.0422 0.625 brier_class
## 3 1 30 15 0.0750 0.0422 0.625 roc_auc
## 4 500 21 1 0.316 0.0000562 1 accuracy
## 5 500 21 1 0.316 0.0000562 1 brier_class
## 6 500 21 1 0.316 0.0000562 1 roc_auc
## 7 1000 2 8 0.0178 0.0000000001 0.5 accuracy
## 8 1000 2 8 0.0178 0.0000000001 0.5 brier_class
## 9 1000 2 8 0.0178 0.0000000001 0.5 roc_auc
## 10 1500 11 4 0.001 31.6 0.75 accuracy
## 11 1500 11 4 0.001 31.6 0.75 brier_class
## 12 1500 11 4 0.001 31.6 0.75 roc_auc
## 13 2000 40 11 0.00422 0.0000000750 0.875 accuracy
## 14 2000 40 11 0.00422 0.0000000750 0.875 brier_class
## 15 2000 40 11 0.00422 0.0000000750 0.875 roc_auc
## # ℹ 5 more variables: .estimator <chr>, mean <dbl>, n <int>, std_err <dbl>,
## # .config <chr>
# Per-fold ROC curves for a single tuning candidate.
# `collect_predictions()` on a tuning result returns the held-out predictions
# of every candidate in the grid, so grouping by fold alone would pool all
# candidates into each curve. Restrict the predictions to the candidate with
# the best mean accuracy (the same criterion used for the final fit) first.
xgboost_tune %>%
  collect_predictions(
    parameters = select_best(xgboost_tune, metric = "accuracy")
  ) %>%
  group_by(id) %>%
  roc_curve(Attrition, .pred_Left) %>%
  autoplot()
# Fit the model one last time: plug the best-accuracy hyperparameters into the
# workflow, train on the full training set, and evaluate on the test set.
xgboost_last <- last_fit(
  finalize_workflow(
    xgboost_workflow,
    select_best(xgboost_tune, metric = "accuracy")
  ),
  data_split
)

collect_metrics(xgboost_last)
## # A tibble: 3 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 accuracy binary 0.848 pre0_mod0_post0
## 2 roc_auc binary 0.761 pre0_mod0_post0
## 3 brier_class binary 0.120 pre0_mod0_post0
# Confusion matrix of the test-set predictions
xgboost_last %>%
  collect_predictions() %>%
  yardstick::conf_mat(truth = Attrition, estimate = .pred_class) %>%
  autoplot()
# Variable importance of the final fitted engine
vip(workflows::extract_fit_engine(xgboost_last))
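The XGBoost model predicted employee attrition with 84.8% accuracy and a ROC AUC of 0.76 on the held-out test set. Because about 84% of employees in the data stayed (1,233 of 1,470), the accuracy is only slightly above what always predicting "No" would achieve, so the ROC AUC is the more informative measure of fit here. Key factors such as income, job satisfaction, and work environment appear to be important drivers of employee turnover; note, however, that the variable-importance plot ranks principal components rather than the original variables (the recipe applies PCA), so individual drivers are best read from the correlation funnel. Overall, the model provides useful insights that can help organizations better understand and reduce attrition.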