IBM HR Analytics Employee Attrition & Performance: this is a fictional data set created by IBM data scientists. The goal is to build a classification model to predict employee attrition (Attrition).
library(tidyverse)
library(tidyquant)
Attrition <- readr::read_csv("../00_data/WA_Fn-UseC_-HR-Employee-Attrition.csv")
Attrition %>% skimr::skim()
Name | Piped data |
---|---|
Number of rows | 1470 |
Number of columns | 35 |
Column type frequency: | |
character | 9 |
numeric | 26 |
Group variables | None |
Variable type: character
skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
---|---|---|---|---|---|---|---|
Attrition | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
BusinessTravel | 0 | 1 | 10 | 17 | 0 | 3 | 0 |
Department | 0 | 1 | 5 | 22 | 0 | 3 | 0 |
EducationField | 0 | 1 | 5 | 16 | 0 | 6 | 0 |
Gender | 0 | 1 | 4 | 6 | 0 | 2 | 0 |
JobRole | 0 | 1 | 7 | 25 | 0 | 9 | 0 |
MaritalStatus | 0 | 1 | 6 | 8 | 0 | 3 | 0 |
Over18 | 0 | 1 | 1 | 1 | 0 | 1 | 0 |
OverTime | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
Age | 0 | 1 | 36.92 | 9.14 | 18 | 30.00 | 36.0 | 43.00 | 60 | ▂▇▇▃▂ |
DailyRate | 0 | 1 | 802.49 | 403.51 | 102 | 465.00 | 802.0 | 1157.00 | 1499 | ▇▇▇▇▇ |
DistanceFromHome | 0 | 1 | 9.19 | 8.11 | 1 | 2.00 | 7.0 | 14.00 | 29 | ▇▅▂▂▂ |
Education | 0 | 1 | 2.91 | 1.02 | 1 | 2.00 | 3.0 | 4.00 | 5 | ▂▃▇▆▁ |
EmployeeCount | 0 | 1 | 1.00 | 0.00 | 1 | 1.00 | 1.0 | 1.00 | 1 | ▁▁▇▁▁ |
EmployeeNumber | 0 | 1 | 1024.87 | 602.02 | 1 | 491.25 | 1020.5 | 1555.75 | 2068 | ▇▇▇▇▇ |
EnvironmentSatisfaction | 0 | 1 | 2.72 | 1.09 | 1 | 2.00 | 3.0 | 4.00 | 4 | ▅▅▁▇▇ |
HourlyRate | 0 | 1 | 65.89 | 20.33 | 30 | 48.00 | 66.0 | 83.75 | 100 | ▇▇▇▇▇ |
JobInvolvement | 0 | 1 | 2.73 | 0.71 | 1 | 2.00 | 3.0 | 3.00 | 4 | ▁▃▁▇▁ |
JobLevel | 0 | 1 | 2.06 | 1.11 | 1 | 1.00 | 2.0 | 3.00 | 5 | ▇▇▃▂▁ |
JobSatisfaction | 0 | 1 | 2.73 | 1.10 | 1 | 2.00 | 3.0 | 4.00 | 4 | ▅▅▁▇▇ |
MonthlyIncome | 0 | 1 | 6502.93 | 4707.96 | 1009 | 2911.00 | 4919.0 | 8379.00 | 19999 | ▇▅▂▁▂ |
MonthlyRate | 0 | 1 | 14313.10 | 7117.79 | 2094 | 8047.00 | 14235.5 | 20461.50 | 26999 | ▇▇▇▇▇ |
NumCompaniesWorked | 0 | 1 | 2.69 | 2.50 | 0 | 1.00 | 2.0 | 4.00 | 9 | ▇▃▂▂▁ |
PercentSalaryHike | 0 | 1 | 15.21 | 3.66 | 11 | 12.00 | 14.0 | 18.00 | 25 | ▇▅▃▂▁ |
PerformanceRating | 0 | 1 | 3.15 | 0.36 | 3 | 3.00 | 3.0 | 3.00 | 4 | ▇▁▁▁▂ |
RelationshipSatisfaction | 0 | 1 | 2.71 | 1.08 | 1 | 2.00 | 3.0 | 4.00 | 4 | ▅▅▁▇▇ |
StandardHours | 0 | 1 | 80.00 | 0.00 | 80 | 80.00 | 80.0 | 80.00 | 80 | ▁▁▇▁▁ |
StockOptionLevel | 0 | 1 | 0.79 | 0.85 | 0 | 0.00 | 1.0 | 1.00 | 3 | ▇▇▁▂▁ |
TotalWorkingYears | 0 | 1 | 11.28 | 7.78 | 0 | 6.00 | 10.0 | 15.00 | 40 | ▇▇▂▁▁ |
TrainingTimesLastYear | 0 | 1 | 2.80 | 1.29 | 0 | 2.00 | 3.0 | 3.00 | 6 | ▂▇▇▂▃ |
WorkLifeBalance | 0 | 1 | 2.76 | 0.71 | 1 | 2.00 | 3.0 | 3.00 | 4 | ▁▃▁▇▂ |
YearsAtCompany | 0 | 1 | 7.01 | 6.13 | 0 | 3.00 | 5.0 | 9.00 | 40 | ▇▂▁▁▁ |
YearsInCurrentRole | 0 | 1 | 4.23 | 3.62 | 0 | 2.00 | 3.0 | 7.00 | 18 | ▇▃▂▁▁ |
YearsSinceLastPromotion | 0 | 1 | 2.19 | 3.22 | 0 | 0.00 | 1.0 | 3.00 | 15 | ▇▁▁▁▁ |
YearsWithCurrManager | 0 | 1 | 4.12 | 3.57 | 0 | 2.00 | 3.0 | 7.00 | 17 | ▇▂▅▁▁ |
Notes about the data: Over18, EmployeeCount, and StandardHours have zero variance, and several integer-coded columns (Education, JobLevel, StockOptionLevel, WorkLifeBalance) are really ordinal categories, so clean these up before modeling.
data <- Attrition %>%
# Recode the target variable
mutate(Attrition = if_else(Attrition == "Yes", "Left", "No")) %>%
# Remove zero variance variables
select(-Over18, -EmployeeCount, -StandardHours) %>%
# Convert character to factor
mutate(across(where(is.character), factor)) %>%
# Convert factors imported as numeric
mutate(across(c(Education, JobLevel, StockOptionLevel, WorkLifeBalance), factor))
skimr::skim(data)
Name | data |
---|---|
Number of rows | 1470 |
Number of columns | 32 |
Column type frequency: | |
factor | 12 |
numeric | 20 |
Group variables | None |
Variable type: factor
skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
---|---|---|---|---|---|
Attrition | 0 | 1 | FALSE | 2 | No: 1233, Lef: 237 |
BusinessTravel | 0 | 1 | FALSE | 3 | Tra: 1043, Tra: 277, Non: 150 |
Department | 0 | 1 | FALSE | 3 | Res: 961, Sal: 446, Hum: 63 |
Education | 0 | 1 | FALSE | 5 | 3: 572, 4: 398, 2: 282, 1: 170 |
EducationField | 0 | 1 | FALSE | 6 | Lif: 606, Med: 464, Mar: 159, Tec: 132 |
Gender | 0 | 1 | FALSE | 2 | Mal: 882, Fem: 588 |
JobLevel | 0 | 1 | FALSE | 5 | 1: 543, 2: 534, 3: 218, 4: 106 |
JobRole | 0 | 1 | FALSE | 9 | Sal: 326, Res: 292, Lab: 259, Man: 145 |
MaritalStatus | 0 | 1 | FALSE | 3 | Mar: 673, Sin: 470, Div: 327 |
OverTime | 0 | 1 | FALSE | 2 | No: 1054, Yes: 416 |
StockOptionLevel | 0 | 1 | FALSE | 4 | 0: 631, 1: 596, 2: 158, 3: 85 |
WorkLifeBalance | 0 | 1 | FALSE | 4 | 3: 893, 2: 344, 4: 153, 1: 80 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
Age | 0 | 1 | 36.92 | 9.14 | 18 | 30.00 | 36.0 | 43.00 | 60 | ▂▇▇▃▂ |
DailyRate | 0 | 1 | 802.49 | 403.51 | 102 | 465.00 | 802.0 | 1157.00 | 1499 | ▇▇▇▇▇ |
DistanceFromHome | 0 | 1 | 9.19 | 8.11 | 1 | 2.00 | 7.0 | 14.00 | 29 | ▇▅▂▂▂ |
EmployeeNumber | 0 | 1 | 1024.87 | 602.02 | 1 | 491.25 | 1020.5 | 1555.75 | 2068 | ▇▇▇▇▇ |
EnvironmentSatisfaction | 0 | 1 | 2.72 | 1.09 | 1 | 2.00 | 3.0 | 4.00 | 4 | ▅▅▁▇▇ |
HourlyRate | 0 | 1 | 65.89 | 20.33 | 30 | 48.00 | 66.0 | 83.75 | 100 | ▇▇▇▇▇ |
JobInvolvement | 0 | 1 | 2.73 | 0.71 | 1 | 2.00 | 3.0 | 3.00 | 4 | ▁▃▁▇▁ |
JobSatisfaction | 0 | 1 | 2.73 | 1.10 | 1 | 2.00 | 3.0 | 4.00 | 4 | ▅▅▁▇▇ |
MonthlyIncome | 0 | 1 | 6502.93 | 4707.96 | 1009 | 2911.00 | 4919.0 | 8379.00 | 19999 | ▇▅▂▁▂ |
MonthlyRate | 0 | 1 | 14313.10 | 7117.79 | 2094 | 8047.00 | 14235.5 | 20461.50 | 26999 | ▇▇▇▇▇ |
NumCompaniesWorked | 0 | 1 | 2.69 | 2.50 | 0 | 1.00 | 2.0 | 4.00 | 9 | ▇▃▂▂▁ |
PercentSalaryHike | 0 | 1 | 15.21 | 3.66 | 11 | 12.00 | 14.0 | 18.00 | 25 | ▇▅▃▂▁ |
PerformanceRating | 0 | 1 | 3.15 | 0.36 | 3 | 3.00 | 3.0 | 3.00 | 4 | ▇▁▁▁▂ |
RelationshipSatisfaction | 0 | 1 | 2.71 | 1.08 | 1 | 2.00 | 3.0 | 4.00 | 4 | ▅▅▁▇▇ |
TotalWorkingYears | 0 | 1 | 11.28 | 7.78 | 0 | 6.00 | 10.0 | 15.00 | 40 | ▇▇▂▁▁ |
TrainingTimesLastYear | 0 | 1 | 2.80 | 1.29 | 0 | 2.00 | 3.0 | 3.00 | 6 | ▂▇▇▂▃ |
YearsAtCompany | 0 | 1 | 7.01 | 6.13 | 0 | 3.00 | 5.0 | 9.00 | 40 | ▇▂▁▁▁ |
YearsInCurrentRole | 0 | 1 | 4.23 | 3.62 | 0 | 2.00 | 3.0 | 7.00 | 18 | ▇▃▂▁▁ |
YearsSinceLastPromotion | 0 | 1 | 2.19 | 3.22 | 0 | 0.00 | 1.0 | 3.00 | 15 | ▇▁▁▁▁ |
YearsWithCurrManager | 0 | 1 | 4.12 | 3.57 | 0 | 2.00 | 3.0 | 7.00 | 17 | ▇▂▅▁▁ |
**Identify variables that are correlated with the target variable.**
data %>% count(Attrition)
## # A tibble: 2 × 2
## Attrition n
## <fct> <int>
## 1 Left 237
## 2 No 1233
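Before looking at individual plots, a quick way to screen the numeric columns is their point-biserial correlation with the target. This is a sketch that was not part of the original output; it simply ranks the numeric predictors by the absolute value of their correlation with leaving.
# Sketch: correlate each numeric column with a 0/1 indicator for leaving
data %>%
  mutate(left = as.integer(Attrition == "Left")) %>%
  select(where(is.numeric)) %>%
  pivot_longer(-left, names_to = "variable", values_to = "value") %>%
  group_by(variable) %>%
  summarise(correlation = cor(value, left)) %>%
  arrange(desc(abs(correlation)))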
The more satisfied an employee is, the less likely they are to leave. Not surprising!
data %>%
ggplot(aes(Attrition, JobSatisfaction)) +
geom_boxplot()
Monthly income looks like the best predictor among the pay variables; the daily, hourly, and monthly rates show much less separation between the two groups (a combined, faceted view follows the individual plots below).
data %>%
ggplot(aes(Attrition, MonthlyIncome)) +
geom_boxplot()
data %>%
ggplot(aes(Attrition, MonthlyRate)) +
geom_boxplot()
data %>%
ggplot(aes(Attrition, DailyRate)) +
geom_boxplot()
data %>%
ggplot(aes(Attrition, HourlyRate)) +
geom_boxplot()
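To compare the four pay variables side by side on their own scales, a faceted version of the same box plots can help. This is an optional sketch, not part of the original analysis; free y scales are used because the variables sit on very different ranges.
# Sketch: all four pay variables in one faceted plot
data %>%
  select(Attrition, MonthlyIncome, MonthlyRate, DailyRate, HourlyRate) %>%
  pivot_longer(-Attrition, names_to = "pay_variable", values_to = "value") %>%
  ggplot(aes(Attrition, value)) +
  geom_boxplot() +
  facet_wrap(~ pay_variable, scales = "free_y")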
Distance from home also appears to matter: employees who left tend to live farther from work.
data %>%
ggplot(aes(Attrition, DistanceFromHome)) +
geom_boxplot()
Overtime is very telling. No more overtime if you want to keep employees!
data %>%
count(Attrition, OverTime) %>%
pivot_wider(names_from = Attrition, values_from = n, values_fill = 0) %>%
mutate(pct_left = (Left / (No + Left)) * 100) %>%
ggplot(aes(pct_left, fct_reorder(OverTime, pct_left))) +
geom_col() +
labs(y = "Overtime", x = "Percent Left")
Work-life balance also seems relevant. Category 1 (the lowest rating) stands out, with more than 30% leaving.
data %>%
count(Attrition, WorkLifeBalance) %>%
pivot_wider(names_from = Attrition, values_from = n, values_fill = 0) %>%
mutate(pct_left = (Left / (No + Left)) * 100) %>%
ggplot(aes(pct_left, fct_reorder(WorkLifeBalance, pct_left))) +
geom_col() +
labs(y = "Work Life Balance", x = "Percent Left")
Level 0 presumably means no stock options, and that makes sense: with no stock options, there is less incentive to stay.
data %>%
count(Attrition, StockOptionLevel) %>%
pivot_wider(names_from = Attrition, values_from = n, values_fill = 0) %>%
mutate(pct_left = (Left / (No + Left)) * 100) %>%
ggplot(aes(pct_left, fct_reorder(StockOptionLevel, pct_left))) +
geom_col() +
labs(y = "Stock Option Levels", x = "Percent Left")
There are many more variables in the dataset; I decided to leave them all in the model and let the algorithm sort out which ones matter.
# Optional shortcut for quick experimentation (not used here):
# data <- data %>% group_by(Attrition) %>% sample_n(50)
library(tidymodels)
set.seed(123)
attrition_split <- initial_split(data, strata = Attrition)
attrition_train <- training(attrition_split)
attrition_test <- testing(attrition_split)
set.seed(234)
attrition_folds <- bootstraps(attrition_train, strata = Attrition)
attrition_folds
## # Bootstrap sampling using stratification
## # A tibble: 25 × 2
## splits id
## <list> <chr>
## 1 <split [1101/404]> Bootstrap01
## 2 <split [1101/401]> Bootstrap02
## 3 <split [1101/388]> Bootstrap03
## 4 <split [1101/406]> Bootstrap04
## 5 <split [1101/421]> Bootstrap05
## 6 <split [1101/407]> Bootstrap06
## 7 <split [1101/407]> Bootstrap07
## 8 <split [1101/396]> Bootstrap08
## 9 <split [1101/407]> Bootstrap09
## 10 <split [1101/407]> Bootstrap10
## # ℹ 15 more rows
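Since step_smote() below exists to correct class imbalance, it is worth confirming how imbalanced the training split actually is. A quick check, not shown in the original output:
# Sketch: class balance in the training split
attrition_train %>%
  count(Attrition) %>%
  mutate(prop = n / sum(n))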
library(themis)
attritions_rec <-
  recipe(Attrition ~ ., data = attrition_train) %>%
  # Keep EmployeeNumber as an identifier rather than a predictor
  update_role(EmployeeNumber, new_role = "id") %>%
  # One-hot encode all categorical predictors
  step_dummy(all_nominal_predictors(), one_hot = TRUE) %>%
  # Drop zero-variance columns before scaling (normalizing them would produce NaNs)
  step_zv(all_numeric_predictors()) %>%
  # Center and scale the numeric predictors
  step_normalize(all_numeric_predictors()) %>%
  # Oversample the minority class ("Left") with SMOTE
  step_smote(Attrition)
attritions_rec %>% prep() %>% juice() %>% glimpse()
## Rows: 1,848
## Columns: 67
## $ Age <dbl> 0.445029949, -0.107743036, -0.328852…
## $ DailyRate <dbl> 0.75926794, 1.04787368, -0.24338822,…
## $ DistanceFromHome <dbl> -0.99648786, -0.02331289, -0.3882535…
## $ EmployeeNumber <dbl> 1, 27, 31, 33, 42, 45, 47, 58, 65, 9…
## $ EnvironmentSatisfaction <dbl> -0.6544585, 0.2622830, -0.6544585, -…
## $ HourlyRate <dbl> 1.35598232, 0.76221199, 0.81169285, …
## $ JobInvolvement <dbl> 0.364117, -1.018272, 0.364117, -2.40…
## $ JobSatisfaction <dbl> 1.1416683, -1.5712313, -1.5712313, -…
## $ MonthlyIncome <dbl> -0.1190701, -0.6660728, -0.7606243, …
## $ MonthlyRate <dbl> 0.7355336, -1.0430438, 0.3971298, -1…
## $ NumCompaniesWorked <dbl> 2.1148603, 1.7142304, -0.2889193, -0…
## $ PercentSalaryHike <dbl> -1.16117737, 2.14019457, -1.16117737…
## $ PerformanceRating <dbl> -0.4256357, 2.3472929, -0.4256357, 2…
## $ RelationshipSatisfaction <dbl> -1.5621678, -0.6433928, 0.2753821, -…
## $ TotalWorkingYears <dbl> -0.42529085, -0.17074058, -0.4252908…
## $ TrainingTimesLastYear <dbl> -2.1655389, 0.9491030, -0.6082180, 1…
## $ YearsAtCompany <dbl> -0.166711069, -0.329000094, -0.49128…
## $ YearsInCurrentRole <dbl> -0.05982354, -0.33773798, -0.6156524…
## $ YearsSinceLastPromotion <dbl> -0.67855380, -0.67855380, -0.3663566…
## $ YearsWithCurrManager <dbl> 0.25954528, -0.30743756, -0.30743756…
## $ Attrition <fct> Left, Left, Left, Left, Left, Left, …
## $ BusinessTravel_Non.Travel <dbl> -0.3463115, -0.3463115, -0.3463115, …
## $ BusinessTravel_Travel_Frequently <dbl> -0.4781076, -0.4781076, -0.4781076, …
## $ BusinessTravel_Travel_Rarely <dbl> 0.6440419, 0.6440419, 0.6440419, -1.…
## $ Department_Human.Resources <dbl> -0.1915458, -0.1915458, -0.1915458, …
## $ Department_Research...Development <dbl> -1.3879626, -1.3879626, 0.7198261, 0…
## $ Department_Sales <dbl> 1.5049928, 1.5049928, -0.6638515, -0…
## $ Education_X1 <dbl> -0.3544752, -0.3544752, 2.8185098, 2…
## $ Education_X2 <dbl> 2.0232402, -0.4938078, -0.4938078, -…
## $ Education_X3 <dbl> -0.7759335, -0.7759335, -0.7759335, …
## $ Education_X4 <dbl> -0.6201258, 1.6111113, -0.6201258, -…
## $ Education_X5 <dbl> -0.1990579, -0.1990579, -0.1990579, …
## $ EducationField_Human.Resources <dbl> -0.135958, -0.135958, -0.135958, -0.…
## $ EducationField_Life.Sciences <dbl> 1.1999719, 1.1999719, -0.8325959, 1.…
## $ EducationField_Marketing <dbl> -0.3593235, -0.3593235, -0.3593235, …
## $ EducationField_Medical <dbl> -0.6666917, -0.6666917, 1.4985813, -…
## $ EducationField_Other <dbl> -0.2357081, -0.2357081, -0.2357081, …
## $ EducationField_Technical.Degree <dbl> -0.3279453, -0.3279453, -0.3279453, …
## $ Gender_Female <dbl> 1.2415007, -0.8047452, -0.8047452, 1…
## $ Gender_Male <dbl> -1.2415007, 0.8047452, 0.8047452, -1…
## $ JobLevel_X1 <dbl> -0.7639647, 1.3077721, 1.3077721, 1.…
## $ JobLevel_X2 <dbl> 1.3310481, -0.7506053, -0.7506053, -…
## $ JobLevel_X3 <dbl> -0.4136678, -0.4136678, -0.4136678, …
## $ JobLevel_X4 <dbl> -0.2891115, -0.2891115, -0.2891115, …
## $ JobLevel_X5 <dbl> -0.2225444, -0.2225444, -0.2225444, …
## $ JobRole_Healthcare.Representative <dbl> -0.312439, -0.312439, -0.312439, -0.…
## $ JobRole_Human.Resources <dbl> -0.1672895, -0.1672895, -0.1672895, …
## $ JobRole_Laboratory.Technician <dbl> -0.4608275, -0.4608275, -0.4608275, …
## $ JobRole_Manager <dbl> -0.2816729, -0.2816729, -0.2816729, …
## $ JobRole_Manufacturing.Director <dbl> -0.3430154, -0.3430154, -0.3430154, …
## $ JobRole_Research.Director <dbl> -0.2378454, -0.2378454, -0.2378454, …
## $ JobRole_Research.Scientist <dbl> -0.4966502, -0.4966502, 2.0116610, 2…
## $ JobRole_Sales.Executive <dbl> 1.8982673, -0.5263177, -0.5263177, -…
## $ JobRole_Sales.Representative <dbl> -0.2544367, 3.9266804, -0.2544367, -…
## $ MaritalStatus_Divorced <dbl> -0.5263177, -0.5263177, -0.5263177, …
## $ MaritalStatus_Married <dbl> -0.9387492, -0.9387492, -0.9387492, …
## $ MaritalStatus_Single <dbl> 1.4765158, 1.4765158, 1.4765158, 1.4…
## $ OverTime_No <dbl> -1.6332712, 0.6117121, 0.6117121, -1…
## $ OverTime_Yes <dbl> 1.6332712, -0.6117121, -0.6117121, 1…
## $ StockOptionLevel_X0 <dbl> 1.1689738, 1.1689738, 1.1689738, 1.1…
## $ StockOptionLevel_X1 <dbl> -0.8279138, -0.8279138, -0.8279138, …
## $ StockOptionLevel_X2 <dbl> -0.3593235, -0.3593235, -0.3593235, …
## $ StockOptionLevel_X3 <dbl> -0.2441691, -0.2441691, -0.2441691, …
## $ WorkLifeBalance_X1 <dbl> 3.9904868, -0.2503684, -0.2503684, -…
## $ WorkLifeBalance_X2 <dbl> -0.5725453, -0.5725453, -0.5725453, …
## $ WorkLifeBalance_X3 <dbl> -1.2113117, 0.8248015, 0.8248015, 0.…
## $ WorkLifeBalance_X4 <dbl> -0.3313295, -0.3313295, -0.3313295, …
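The processed training data has 1,848 rows instead of 1,101 because step_smote() upsamples the minority class to match the majority. A quick count should confirm the classes are now balanced (a sketch; output not shown here).
# Sketch: class balance after SMOTE
attritions_rec %>%
  prep() %>%
  juice() %>%
  count(Attrition)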
xgb_spec <-
  boost_tree(
    trees = tune(),   # number of boosting rounds (tuned)
    min_n = tune(),   # minimum observations in a terminal node (tuned)
    mtry = tune(),    # predictors sampled at each split (tuned)
    learn_rate = 0.01 # small, fixed learning rate
  ) %>%
  set_engine("xgboost") %>%
  set_mode("classification")
xgb_wf <- workflow(attritions_rec, xgb_spec)
doParallel::registerDoParallel()
set.seed(345)
xgb_rs <- tune_grid(
xgb_wf,
resamples = attrition_folds,
grid = 5,
control = control_grid(verbose = TRUE, save_pred = TRUE)
)
xgb_rs
## # Tuning results
## # Bootstrap sampling using stratification
## # A tibble: 25 × 5
## splits id .metrics .notes .predictions
## <list> <chr> <list> <list> <list>
## 1 <split [1101/404]> Bootstrap01 <tibble [10 × 7]> <tibble> <tibble>
## 2 <split [1101/401]> Bootstrap02 <tibble [10 × 7]> <tibble> <tibble>
## 3 <split [1101/388]> Bootstrap03 <tibble [10 × 7]> <tibble> <tibble>
## 4 <split [1101/406]> Bootstrap04 <tibble [10 × 7]> <tibble> <tibble>
## 5 <split [1101/421]> Bootstrap05 <tibble [10 × 7]> <tibble> <tibble>
## 6 <split [1101/407]> Bootstrap06 <tibble [10 × 7]> <tibble> <tibble>
## 7 <split [1101/407]> Bootstrap07 <tibble [10 × 7]> <tibble> <tibble>
## 8 <split [1101/396]> Bootstrap08 <tibble [10 × 7]> <tibble> <tibble>
## 9 <split [1101/407]> Bootstrap09 <tibble [10 × 7]> <tibble> <tibble>
## 10 <split [1101/407]> Bootstrap10 <tibble [10 × 7]> <tibble> <tibble>
## # ℹ 15 more rows
collect_metrics(xgb_rs)
## # A tibble: 10 × 9
## mtry trees min_n .metric .estimator mean n std_err .config
## <int> <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 14 606 28 accuracy binary 0.852 25 0.00379 Preprocessor1_Mode…
## 2 14 606 28 roc_auc binary 0.807 25 0.00521 Preprocessor1_Mode…
## 3 15 1196 24 accuracy binary 0.861 25 0.00303 Preprocessor1_Mode…
## 4 15 1196 24 roc_auc binary 0.812 25 0.00484 Preprocessor1_Mode…
## 5 27 1331 37 accuracy binary 0.855 25 0.00332 Preprocessor1_Mode…
## 6 27 1331 37 roc_auc binary 0.811 25 0.00455 Preprocessor1_Mode…
## 7 43 1641 7 accuracy binary 0.861 25 0.00301 Preprocessor1_Mode…
## 8 43 1641 7 roc_auc binary 0.800 25 0.00456 Preprocessor1_Mode…
## 9 60 194 13 accuracy binary 0.834 25 0.00437 Preprocessor1_Mode…
## 10 60 194 13 roc_auc binary 0.775 25 0.00670 Preprocessor1_Mode…
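Accuracy alone can look flattering on imbalanced data, so it helps to rank the candidates by ROC AUC as well before settling on one. A convenience sketch using tune's show_best():
# Sketch: top candidates ranked by ROC AUC
show_best(xgb_rs, metric = "roc_auc")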
collect_predictions(xgb_rs) %>%
group_by(id) %>%
roc_curve(Attrition, .pred_Left) %>%
autoplot()
# conf_mat_resampled(xgb_rs, tidy = FALSE, parameters = tibble(mtry = 12, trees = 666, min_n = 28)) %>%
# autoplot()
xgb_last <- xgb_wf %>%
finalize_workflow(select_best(xgb_rs, metric = "accuracy")) %>%
last_fit(attrition_split)
xgb_last
## # Resampling results
## # Manual resampling
## # A tibble: 1 × 6
## splits id .metrics .notes .predictions .workflow
## <list> <chr> <list> <list> <list> <list>
## 1 <split [1101/369]> train/test split <tibble> <tibble> <tibble> <workflow>
collect_metrics(xgb_last)
## # A tibble: 2 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 accuracy binary 0.851 Preprocessor1_Model1
## 2 roc_auc binary 0.803 Preprocessor1_Model1
collect_predictions(xgb_last) %>%
conf_mat(Attrition, .pred_class) %>%
autoplot()
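The confusion matrix hints at how well the model catches leavers specifically; class-level metrics make that explicit. A sketch (attrition_metrics is just a name chosen here; yardstick treats the first factor level, "Left", as the event by default):
# Sketch: sensitivity, specificity, and precision on the test set
attrition_metrics <- metric_set(sens, spec, precision)
collect_predictions(xgb_last) %>%
  attrition_metrics(truth = Attrition, estimate = .pred_class)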
library(vip)
xgb_last %>%
extract_fit_engine() %>%
vip()