library(tidyverse)
library(tidyquant)
attrition <- readr::read_csv('../data/WA_Fn-UseC_-HR-Employee-Attrition.csv')
attrition %>% skimr::skim()
Data summary | |
---|---|
Name | Piped data |
Number of rows | 1470 |
Number of columns | 35 |
Column type frequency: | |
character | 9 |
numeric | 26 |
Group variables | None |
Variable type: character
skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
---|---|---|---|---|---|---|---|
Attrition | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
BusinessTravel | 0 | 1 | 10 | 17 | 0 | 3 | 0 |
Department | 0 | 1 | 5 | 22 | 0 | 3 | 0 |
EducationField | 0 | 1 | 5 | 16 | 0 | 6 | 0 |
Gender | 0 | 1 | 4 | 6 | 0 | 2 | 0 |
JobRole | 0 | 1 | 7 | 25 | 0 | 9 | 0 |
MaritalStatus | 0 | 1 | 6 | 8 | 0 | 3 | 0 |
Over18 | 0 | 1 | 1 | 1 | 0 | 1 | 0 |
OverTime | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
Age | 0 | 1 | 36.92 | 9.14 | 18 | 30.00 | 36.0 | 43.00 | 60 | ▂▇▇▃▂ |
DailyRate | 0 | 1 | 802.49 | 403.51 | 102 | 465.00 | 802.0 | 1157.00 | 1499 | ▇▇▇▇▇ |
DistanceFromHome | 0 | 1 | 9.19 | 8.11 | 1 | 2.00 | 7.0 | 14.00 | 29 | ▇▅▂▂▂ |
Education | 0 | 1 | 2.91 | 1.02 | 1 | 2.00 | 3.0 | 4.00 | 5 | ▂▃▇▆▁ |
EmployeeCount | 0 | 1 | 1.00 | 0.00 | 1 | 1.00 | 1.0 | 1.00 | 1 | ▁▁▇▁▁ |
EmployeeNumber | 0 | 1 | 1024.87 | 602.02 | 1 | 491.25 | 1020.5 | 1555.75 | 2068 | ▇▇▇▇▇ |
EnvironmentSatisfaction | 0 | 1 | 2.72 | 1.09 | 1 | 2.00 | 3.0 | 4.00 | 4 | ▅▅▁▇▇ |
HourlyRate | 0 | 1 | 65.89 | 20.33 | 30 | 48.00 | 66.0 | 83.75 | 100 | ▇▇▇▇▇ |
JobInvolvement | 0 | 1 | 2.73 | 0.71 | 1 | 2.00 | 3.0 | 3.00 | 4 | ▁▃▁▇▁ |
JobLevel | 0 | 1 | 2.06 | 1.11 | 1 | 1.00 | 2.0 | 3.00 | 5 | ▇▇▃▂▁ |
JobSatisfaction | 0 | 1 | 2.73 | 1.10 | 1 | 2.00 | 3.0 | 4.00 | 4 | ▅▅▁▇▇ |
MonthlyIncome | 0 | 1 | 6502.93 | 4707.96 | 1009 | 2911.00 | 4919.0 | 8379.00 | 19999 | ▇▅▂▁▂ |
MonthlyRate | 0 | 1 | 14313.10 | 7117.79 | 2094 | 8047.00 | 14235.5 | 20461.50 | 26999 | ▇▇▇▇▇ |
NumCompaniesWorked | 0 | 1 | 2.69 | 2.50 | 0 | 1.00 | 2.0 | 4.00 | 9 | ▇▃▂▂▁ |
PercentSalaryHike | 0 | 1 | 15.21 | 3.66 | 11 | 12.00 | 14.0 | 18.00 | 25 | ▇▅▃▂▁ |
PerformanceRating | 0 | 1 | 3.15 | 0.36 | 3 | 3.00 | 3.0 | 3.00 | 4 | ▇▁▁▁▂ |
RelationshipSatisfaction | 0 | 1 | 2.71 | 1.08 | 1 | 2.00 | 3.0 | 4.00 | 4 | ▅▅▁▇▇ |
StandardHours | 0 | 1 | 80.00 | 0.00 | 80 | 80.00 | 80.0 | 80.00 | 80 | ▁▁▇▁▁ |
StockOptionLevel | 0 | 1 | 0.79 | 0.85 | 0 | 0.00 | 1.0 | 1.00 | 3 | ▇▇▁▂▁ |
TotalWorkingYears | 0 | 1 | 11.28 | 7.78 | 0 | 6.00 | 10.0 | 15.00 | 40 | ▇▇▂▁▁ |
TrainingTimesLastYear | 0 | 1 | 2.80 | 1.29 | 0 | 2.00 | 3.0 | 3.00 | 6 | ▂▇▇▂▃ |
WorkLifeBalance | 0 | 1 | 2.76 | 0.71 | 1 | 2.00 | 3.0 | 3.00 | 4 | ▁▃▁▇▂ |
YearsAtCompany | 0 | 1 | 7.01 | 6.13 | 0 | 3.00 | 5.0 | 9.00 | 40 | ▇▂▁▁▁ |
YearsInCurrentRole | 0 | 1 | 4.23 | 3.62 | 0 | 2.00 | 3.0 | 7.00 | 18 | ▇▃▂▁▁ |
YearsSinceLastPromotion | 0 | 1 | 2.19 | 3.22 | 0 | 0.00 | 1.0 | 3.00 | 15 | ▇▁▁▁▁ |
YearsWithCurrManager | 0 | 1 | 4.12 | 3.57 | 0 | 2.00 | 3.0 | 7.00 | 17 | ▇▂▅▁▁ |
data <- attrition %>%
  # Remove zero-variance variables
  select(-Over18, -EmployeeCount, -StandardHours) %>%
  # Convert character columns to factors
  mutate(across(where(is.character), factor)) %>%
  # Convert ordinal codes that were read as numeric to factors
  mutate(across(c(Education, JobLevel, StockOptionLevel), factor))
**Identify variables that are correlated with the target variable.**
data %>% count(Attrition)
## # A tibble: 2 × 2
## Attrition n
## <fct> <int>
## 1 No 1233
## 2 Yes 237
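Only 237 of 1,470 employees left, so the classes are fairly imbalanced. As a quick extra check (my addition, not in the original code), the same count expressed as proportions:

data %>%
  count(Attrition) %>%
  # roughly 84% "No" vs 16% "Yes"
  mutate(prop = n / sum(n))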
Monthly income looks like the strongest predictor among the income-related variables (MonthlyIncome, MonthlyRate, DailyRate, HourlyRate).
data %>%
ggplot(aes(Attrition, MonthlyIncome)) +
geom_boxplot()
data %>%
ggplot(aes(Attrition, MonthlyRate)) +
geom_boxplot()
data %>% ggplot(aes(Attrition, DailyRate)) +
geom_boxplot()
data %>% ggplot(aes(Attrition, HourlyRate)) +
geom_boxplot()
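To put a rough number behind that visual impression, here is a small check I added (not part of the original analysis) comparing group medians across the four income-related variables:

# Median of each income-related variable by attrition status
data %>%
  select(Attrition, MonthlyIncome, MonthlyRate, DailyRate, HourlyRate) %>%
  pivot_longer(-Attrition, names_to = "variable", values_to = "value") %>%
  group_by(variable, Attrition) %>%
  summarise(median = median(value), .groups = "drop")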
data %>%
ggplot(aes(Attrition, DistanceFromHome)) +
geom_boxplot()
data %>%
# Transform
count(Attrition, WorkLifeBalance) %>%
# Plot
ggplot(aes(Attrition, WorkLifeBalance, fill = n)) +
geom_tile()
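Since the Yes group is much smaller, raw counts in the tile plot are dominated by the No class. A variant I sketched (not in the original) fills by the within-group proportion instead:

data %>%
  count(Attrition, WorkLifeBalance) %>%
  # Proportion of each WorkLifeBalance level within each attrition group
  group_by(Attrition) %>%
  mutate(prop = n / sum(n)) %>%
  ungroup() %>%
  ggplot(aes(Attrition, WorkLifeBalance, fill = prop)) +
  geom_tile()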
data %>%
count(Department, Attrition) %>%
ggplot(aes(Attrition, n, fill = Attrition)) +
geom_col(show.legend = FALSE) +
facet_wrap(vars(Department), scales = "free")
data %>%
count(JobLevel, Attrition) %>%
ggplot(aes(Attrition, n, fill = Attrition)) +
geom_col(show.legend = FALSE) +
facet_wrap(vars(JobLevel), scales = "free")
data %>%
count(JobRole, Attrition) %>%
ggplot(aes(Attrition, n, fill = Attrition)) +
geom_col(show.legend = FALSE) +
facet_wrap(vars(JobRole), scales = "free")
# Keep a reduced set of variables for modeling
data_tidy <- data %>%
  select(Age, Attrition, Department, DistanceFromHome, Education, EnvironmentSatisfaction, Gender:MonthlyRate)
library(tidymodels)
set.seed(123)
data_split <- initial_split(data_tidy, strata = Attrition)  # stratified train/test split
data_train <- training(data_split)
data_test <- testing(data_split)
set.seed(124)
data_folds <- vfold_cv(data_train, strata = Attrition)  # 10-fold CV, stratified on Attrition
data_folds
## # 10-fold cross-validation using stratification
## # A tibble: 10 × 2
## splits id
## <list> <chr>
## 1 <split [990/111]> Fold01
## 2 <split [990/111]> Fold02
## 3 <split [990/111]> Fold03
## 4 <split [990/111]> Fold04
## 5 <split [991/110]> Fold05
## 6 <split [991/110]> Fold06
## 7 <split [991/110]> Fold07
## 8 <split [992/109]> Fold08
## 9 <split [992/109]> Fold09
## 10 <split [992/109]> Fold10
library(embed)
data_rec <- recipe(Attrition ~ ., data = data_train) %>%
  # Likelihood (effect) encode JobRole, which has 9 levels, with a GLM
  step_lencode_glm(JobRole, outcome = vars(Attrition)) %>%
  # Dummy encode the remaining nominal predictors
  step_dummy(all_nominal_predictors())
data_rec
# Check the encoding assigned to previously unseen JobRole levels
prep(data_rec) %>%
  tidy(number = 1) %>%
  filter(level == "..new")
## # A tibble: 1 × 4
## level value terms id
## <chr> <dbl> <chr> <chr>
## 1 ..new 2.02 JobRole lencode_glm_qxh67
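As an extra sanity check (not in the original post), the full set of learned JobRole encodings can be inspected the same way:

# All JobRole encodings learned by step_lencode_glm, largest first
prep(data_rec) %>%
  tidy(number = 1) %>%
  arrange(desc(value))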
# Tune the number of trees, minimum node size and mtry; keep the learning rate fixed
xgb_spec <-
  boost_tree(
    trees = tune(),
    min_n = tune(),
    mtry = tune(),
    learn_rate = .01
  ) %>%
  set_engine("xgboost") %>%
  set_mode("classification")
xgb_wf <- workflow(data_rec, xgb_spec)
library(finetune)
doParallel::registerDoParallel()
set.seed(125)
# Tune with ANOVA racing, which drops clearly inferior candidates early
xgb_rs <-
  tune_race_anova(
    xgb_wf,
    resamples = data_folds,
    grid = 25,
    control = control_race(verbose_elim = TRUE)
  )
# I decided to increase the grid to 25 candidates
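# If you would rather control the candidate values yourself than pass grid = 25,
# a space-filling grid along these lines (my addition, not part of the original
# analysis) could be built with dials and supplied to the grid argument:
set.seed(126)
xgb_grid <- grid_latin_hypercube(
  trees(range = c(500L, 2000L)),
  min_n(),
  finalize(mtry(), select(data_train, -Attrition)),
  size = 25
)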
xgb_rs
## # Tuning results
## # 10-fold cross-validation using stratification
## # A tibble: 10 × 5
## splits id .order .metrics .notes
## <list> <chr> <int> <list> <list>
## 1 <split [990/111]> Fold02 1 <tibble [50 × 7]> <tibble [0 × 3]>
## 2 <split [990/111]> Fold03 2 <tibble [50 × 7]> <tibble [0 × 3]>
## 3 <split [992/109]> Fold08 3 <tibble [50 × 7]> <tibble [0 × 3]>
## 4 <split [990/111]> Fold04 4 <tibble [46 × 7]> <tibble [0 × 3]>
## 5 <split [991/110]> Fold07 5 <tibble [36 × 7]> <tibble [0 × 3]>
## 6 <split [992/109]> Fold10 6 <tibble [32 × 7]> <tibble [0 × 3]>
## 7 <split [992/109]> Fold09 7 <tibble [28 × 7]> <tibble [0 × 3]>
## 8 <split [990/111]> Fold01 8 <tibble [26 × 7]> <tibble [0 × 3]>
## 9 <split [991/110]> Fold05 9 <tibble [24 × 7]> <tibble [0 × 3]>
## 10 <split [991/110]> Fold06 10 <tibble [16 × 7]> <tibble [0 × 3]>
# Visualize the race: which candidates survived across resamples
plot_race(xgb_rs)
collect_metrics(xgb_rs)
## # A tibble: 16 × 9
## mtry trees min_n .metric .estimator mean n std_err .config
## <int> <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 1 940 38 accuracy binary 0.839 10 0.00106 Preprocessor1_Mode…
## 2 1 940 38 roc_auc binary 0.762 10 0.0312 Preprocessor1_Mode…
## 3 2 860 11 accuracy binary 0.846 10 0.00689 Preprocessor1_Mode…
## 4 2 860 11 roc_auc binary 0.765 10 0.0277 Preprocessor1_Mode…
## 5 4 392 8 accuracy binary 0.847 10 0.00555 Preprocessor1_Mode…
## 6 4 392 8 roc_auc binary 0.761 10 0.0265 Preprocessor1_Mode…
## 7 4 1549 25 accuracy binary 0.843 10 0.00950 Preprocessor1_Mode…
## 8 4 1549 25 roc_auc binary 0.768 10 0.0268 Preprocessor1_Mode…
## 9 10 526 20 accuracy binary 0.838 10 0.00514 Preprocessor1_Mode…
## 10 10 526 20 roc_auc binary 0.756 10 0.0286 Preprocessor1_Mode…
## 11 11 1010 32 accuracy binary 0.840 10 0.00702 Preprocessor1_Mode…
## 12 11 1010 32 roc_auc binary 0.757 10 0.0298 Preprocessor1_Mode…
## 13 13 667 28 accuracy binary 0.840 10 0.00542 Preprocessor1_Mode…
## 14 13 667 28 roc_auc binary 0.758 10 0.0284 Preprocessor1_Mode…
## 15 15 1061 27 accuracy binary 0.845 10 0.00893 Preprocessor1_Mode…
## 16 15 1061 27 roc_auc binary 0.762 10 0.0272 Preprocessor1_Mode…
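To list only the top candidates for a single metric (a convenience step not shown in the original), show_best() works on the racing results as well:

# Best configurations ranked by ROC AUC
show_best(xgb_rs, metric = "roc_auc")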
# Finalize the workflow with the best configuration by accuracy,
# then fit on the training set and evaluate on the test set
xgb_last <- xgb_wf %>%
  finalize_workflow(select_best(xgb_rs, metric = "accuracy")) %>%
  last_fit(data_split)
collect_metrics(xgb_last)
## # A tibble: 2 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 accuracy binary 0.840 Preprocessor1_Model1
## 2 roc_auc binary 0.694 Preprocessor1_Model1
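ROC AUC drops from roughly 0.76 on the resamples to 0.69 on the test set. A quick look at the test-set ROC curve (my addition, not in the original) helps show where the model separates the classes:

# ROC curve on the held-out test set; "No" is the first factor level,
# so yardstick treats it as the event by default
collect_predictions(xgb_last) %>%
  roc_curve(Attrition, .pred_No) %>%
  autoplot()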
collect_predictions(xgb_last) %>%
conf_mat(Attrition, .pred_class)
## Truth
## Prediction No Yes
## No 303 53
## Yes 6 7
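The confusion matrix shows the model catches only 7 of the 60 employees who actually left. Treating "Yes" as the event of interest, class-level metrics (a follow-up I added, not in the original) make that explicit:

# Sensitivity and specificity with "Yes" (the second factor level) as the event
collect_predictions(xgb_last) %>%
  sensitivity(Attrition, .pred_class, event_level = "second")
collect_predictions(xgb_last) %>%
  specificity(Attrition, .pred_class, event_level = "second")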
library(vip)
# Variable importance from the fitted xgboost booster
xgb_last %>%
  extract_fit_engine() %>%
  vip::vip()
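If you prefer the importance scores as a table rather than a plot (an extra step, not in the original):

# Importance scores as a tibble
xgb_last %>%
  extract_fit_engine() %>%
  vip::vi()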