This is a fictional dataset created by IBM data scientists. Its purpose is to uncover the factors that lead to employee attrition.
In this case, the specific questions are: how can we classify attrition with a neural network, and how well do the model's predictions match the test data? In addition, we try to uncover which features are important for understanding employee attrition.
The dataset comes from Kaggle. It is a public dataset, which means anyone can access it.
In this step we prepare and wrangle the data. To support this process and the modelling, we import several packages: some for wrangling, some for data processing and modelling, and some for visualization.
library(tidyverse)
library(ggplot2)
library(GGally)
library(reshape)
library(keras)
library(rsample)
library(recipes)
library(yardstick)
library(caret)
library(plotly)
library(lime)
library(scales)
library(corrr)
library(tidyquant)
options(scipen = 100)
Import the dataset from CSV
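# stringsAsFactors = TRUE keeps text columns as factors (the default became FALSE in R 4.0)
employee <- read.csv("data/WA_Fn-UseC_-HR-Employee-Attrition.csv", stringsAsFactors = TRUE)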
glimpse(employee)
## Observations: 1,470
## Variables: 35
## $ Age <int> 41, 49, 37, 33, 27, 32, 59, 30, 38, 36, 35, …
## $ Attrition <fct> Yes, No, Yes, No, No, No, No, No, No, No, No…
## $ BusinessTravel <fct> Travel_Rarely, Travel_Frequently, Travel_Rar…
## $ DailyRate <int> 1102, 279, 1373, 1392, 591, 1005, 1324, 1358…
## $ Department <fct> Sales, Research & Development, Research & De…
## $ DistanceFromHome <int> 1, 8, 2, 3, 2, 2, 3, 24, 23, 27, 16, 15, 26,…
## $ Education <int> 2, 1, 2, 4, 1, 2, 3, 1, 3, 3, 3, 2, 1, 2, 3,…
## $ EducationField <fct> Life Sciences, Life Sciences, Other, Life Sc…
## $ EmployeeCount <int> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,…
## $ EmployeeNumber <int> 1, 2, 4, 5, 7, 8, 10, 11, 12, 13, 14, 15, 16…
## $ EnvironmentSatisfaction <int> 2, 3, 4, 4, 1, 4, 3, 4, 4, 3, 1, 4, 1, 2, 3,…
## $ Gender <fct> Female, Male, Male, Female, Male, Male, Fema…
## $ HourlyRate <int> 94, 61, 92, 56, 40, 79, 81, 67, 44, 94, 84, …
## $ JobInvolvement <int> 3, 2, 2, 3, 3, 3, 4, 3, 2, 3, 4, 2, 3, 3, 2,…
## $ JobLevel <int> 2, 2, 1, 1, 1, 1, 1, 1, 3, 2, 1, 2, 1, 1, 1,…
## $ JobRole <fct> Sales Executive, Research Scientist, Laborat…
## $ JobSatisfaction <int> 4, 2, 3, 3, 2, 4, 1, 3, 3, 3, 2, 3, 3, 4, 3,…
## $ MaritalStatus <fct> Single, Married, Single, Married, Married, S…
## $ MonthlyIncome <int> 5993, 5130, 2090, 2909, 3468, 3068, 2670, 26…
## $ MonthlyRate <int> 19479, 24907, 2396, 23159, 16632, 11864, 996…
## $ NumCompaniesWorked <int> 8, 1, 6, 1, 9, 0, 4, 1, 0, 6, 0, 0, 1, 0, 5,…
## $ Over18 <fct> Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y,…
## $ OverTime <fct> Yes, No, Yes, Yes, No, No, Yes, No, No, No, …
## $ PercentSalaryHike <int> 11, 23, 15, 11, 12, 13, 20, 22, 21, 13, 13, …
## $ PerformanceRating <int> 3, 4, 3, 3, 3, 3, 4, 4, 4, 3, 3, 3, 3, 3, 3,…
## $ RelationshipSatisfaction <int> 1, 4, 2, 3, 4, 3, 1, 2, 2, 2, 3, 4, 4, 3, 2,…
## $ StandardHours <int> 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, …
## $ StockOptionLevel <int> 0, 1, 0, 0, 1, 0, 3, 1, 0, 2, 1, 0, 1, 1, 0,…
## $ TotalWorkingYears <int> 8, 10, 7, 8, 6, 8, 12, 1, 10, 17, 6, 10, 5, …
## $ TrainingTimesLastYear <int> 0, 3, 3, 3, 3, 2, 3, 2, 2, 3, 5, 3, 1, 2, 4,…
## $ WorkLifeBalance <int> 1, 3, 3, 3, 3, 2, 2, 3, 3, 2, 3, 3, 2, 3, 3,…
## $ YearsAtCompany <int> 6, 10, 0, 8, 2, 7, 1, 1, 9, 7, 5, 9, 5, 2, 4…
## $ YearsInCurrentRole <int> 4, 7, 0, 7, 2, 7, 0, 0, 7, 7, 4, 5, 2, 2, 2,…
## $ YearsSinceLastPromotion <int> 0, 1, 0, 3, 2, 3, 0, 0, 1, 7, 0, 0, 4, 1, 0,…
## $ YearsWithCurrManager <int> 5, 7, 0, 0, 2, 6, 0, 0, 8, 7, 3, 8, 3, 2, 3,…
After importing the data, we preview it. The summary() function gives a summary of each feature.
summary(employee)
## Age Attrition BusinessTravel DailyRate
## Min. :18.00 No :1233 Non-Travel : 150 Min. : 102.0
## 1st Qu.:30.00 Yes: 237 Travel_Frequently: 277 1st Qu.: 465.0
## Median :36.00 Travel_Rarely :1043 Median : 802.0
## Mean :36.92 Mean : 802.5
## 3rd Qu.:43.00 3rd Qu.:1157.0
## Max. :60.00 Max. :1499.0
##
## Department DistanceFromHome Education
## Human Resources : 63 Min. : 1.000 Min. :1.000
## Research & Development:961 1st Qu.: 2.000 1st Qu.:2.000
## Sales :446 Median : 7.000 Median :3.000
## Mean : 9.193 Mean :2.913
## 3rd Qu.:14.000 3rd Qu.:4.000
## Max. :29.000 Max. :5.000
##
## EducationField EmployeeCount EmployeeNumber EnvironmentSatisfaction
## Human Resources : 27 Min. :1 Min. : 1.0 Min. :1.000
## Life Sciences :606 1st Qu.:1 1st Qu.: 491.2 1st Qu.:2.000
## Marketing :159 Median :1 Median :1020.5 Median :3.000
## Medical :464 Mean :1 Mean :1024.9 Mean :2.722
## Other : 82 3rd Qu.:1 3rd Qu.:1555.8 3rd Qu.:4.000
## Technical Degree:132 Max. :1 Max. :2068.0 Max. :4.000
##
## Gender HourlyRate JobInvolvement JobLevel
## Female:588 Min. : 30.00 Min. :1.00 Min. :1.000
## Male :882 1st Qu.: 48.00 1st Qu.:2.00 1st Qu.:1.000
## Median : 66.00 Median :3.00 Median :2.000
## Mean : 65.89 Mean :2.73 Mean :2.064
## 3rd Qu.: 83.75 3rd Qu.:3.00 3rd Qu.:3.000
## Max. :100.00 Max. :4.00 Max. :5.000
##
## JobRole JobSatisfaction MaritalStatus MonthlyIncome
## Sales Executive :326 Min. :1.000 Divorced:327 Min. : 1009
## Research Scientist :292 1st Qu.:2.000 Married :673 1st Qu.: 2911
## Laboratory Technician :259 Median :3.000 Single :470 Median : 4919
## Manufacturing Director :145 Mean :2.729 Mean : 6503
## Healthcare Representative:131 3rd Qu.:4.000 3rd Qu.: 8379
## Manager :102 Max. :4.000 Max. :19999
## (Other) :215
## MonthlyRate NumCompaniesWorked Over18 OverTime PercentSalaryHike
## Min. : 2094 Min. :0.000 Y:1470 No :1054 Min. :11.00
## 1st Qu.: 8047 1st Qu.:1.000 Yes: 416 1st Qu.:12.00
## Median :14236 Median :2.000 Median :14.00
## Mean :14313 Mean :2.693 Mean :15.21
## 3rd Qu.:20462 3rd Qu.:4.000 3rd Qu.:18.00
## Max. :26999 Max. :9.000 Max. :25.00
##
## PerformanceRating RelationshipSatisfaction StandardHours StockOptionLevel
## Min. :3.000 Min. :1.000 Min. :80 Min. :0.0000
## 1st Qu.:3.000 1st Qu.:2.000 1st Qu.:80 1st Qu.:0.0000
## Median :3.000 Median :3.000 Median :80 Median :1.0000
## Mean :3.154 Mean :2.712 Mean :80 Mean :0.7939
## 3rd Qu.:3.000 3rd Qu.:4.000 3rd Qu.:80 3rd Qu.:1.0000
## Max. :4.000 Max. :4.000 Max. :80 Max. :3.0000
##
## TotalWorkingYears TrainingTimesLastYear WorkLifeBalance YearsAtCompany
## Min. : 0.00 Min. :0.000 Min. :1.000 Min. : 0.000
## 1st Qu.: 6.00 1st Qu.:2.000 1st Qu.:2.000 1st Qu.: 3.000
## Median :10.00 Median :3.000 Median :3.000 Median : 5.000
## Mean :11.28 Mean :2.799 Mean :2.761 Mean : 7.008
## 3rd Qu.:15.00 3rd Qu.:3.000 3rd Qu.:3.000 3rd Qu.: 9.000
## Max. :40.00 Max. :6.000 Max. :4.000 Max. :40.000
##
## YearsInCurrentRole YearsSinceLastPromotion YearsWithCurrManager
## Min. : 0.000 Min. : 0.000 Min. : 0.000
## 1st Qu.: 2.000 1st Qu.: 0.000 1st Qu.: 2.000
## Median : 3.000 Median : 1.000 Median : 3.000
## Mean : 4.229 Mean : 2.188 Mean : 4.123
## 3rd Qu.: 7.000 3rd Qu.: 3.000 3rd Qu.: 7.000
## Max. :18.000 Max. :15.000 Max. :17.000
##
According to the data source, some features are coded as ordered levels:
Education
1. ‘Below College’
2. ‘College’
3. ‘Bachelor’
4. ‘Master’
5. ‘Doctor’
EnvironmentSatisfaction
1. ‘Low’
2. ‘Medium’
3. ‘High’
4. ‘Very High’
JobInvolvement
1. ‘Low’
2. ‘Medium’
3. ‘High’
4. ‘Very High’
JobSatisfaction
1. ‘Low’
2. ‘Medium’
3. ‘High’
4. ‘Very High’
PerformanceRating
1. ‘Low’
2. ‘Good’
3. ‘Excellent’
4. ‘Outstanding’
RelationshipSatisfaction
1. ‘Low’
2. ‘Medium’
3. ‘High’
4. ‘Very High’
WorkLifeBalance
1. ‘Bad’
2. ‘Good’
3. ‘Better’
4. ‘Best’
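Based on the reference above, several features have a mismatched data type: they are stored as integers but represent categories. We convert them to factors, together with JobLevel and StockOptionLevel, which are also categorical levels.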
employee <- employee %>%
mutate(
Education = as.factor(Education),
EnvironmentSatisfaction = as.factor(EnvironmentSatisfaction),
JobInvolvement = as.factor(JobInvolvement),
JobLevel = as.factor(JobLevel),
JobSatisfaction = as.factor(JobSatisfaction),
PerformanceRating = as.factor(PerformanceRating),
RelationshipSatisfaction = as.factor(RelationshipSatisfaction),
StockOptionLevel = as.factor(StockOptionLevel),
WorkLifeBalance = as.factor(WorkLifeBalance)
)
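The conversion above keeps the numeric codes (1, 2, 3, ...) as factor levels. If readable levels are preferred, one option is to map the codes to the labels from the reference list. A minimal sketch, assuming the code-to-label mapping listed above; `label_scale()` is a hypothetical helper and the toy vectors only stand in for the real columns:

```r
# Hypothetical helper: turn numeric survey codes into a factor whose levels
# carry the labels from the dataset description, in scale order.
label_scale <- function(x, labels) {
  factor(x, levels = seq_along(labels), labels = labels)
}

education_labels <- c("Below College", "College", "Bachelor", "Master", "Doctor")
satisfaction_labels <- c("Low", "Medium", "High", "Very High")

# Toy codes standing in for employee$Education and employee$JobSatisfaction
edu <- label_scale(c(2, 1, 2, 4, 1), education_labels)
sat <- label_scale(c(4, 2, 3, 3, 2), satisfaction_labels)

print(edu)
print(table(sat))
```

Codes that never occur are kept as empty levels; for example PerformanceRating only contains 3 and 4, so ‘Low’ and ‘Good’ would have no observations.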
summary(employee)
## Age Attrition BusinessTravel DailyRate
## Min. :18.00 No :1233 Non-Travel : 150 Min. : 102.0
## 1st Qu.:30.00 Yes: 237 Travel_Frequently: 277 1st Qu.: 465.0
## Median :36.00 Travel_Rarely :1043 Median : 802.0
## Mean :36.92 Mean : 802.5
## 3rd Qu.:43.00 3rd Qu.:1157.0
## Max. :60.00 Max. :1499.0
##
## Department DistanceFromHome Education
## Human Resources : 63 Min. : 1.000 1:170
## Research & Development:961 1st Qu.: 2.000 2:282
## Sales :446 Median : 7.000 3:572
## Mean : 9.193 4:398
## 3rd Qu.:14.000 5: 48
## Max. :29.000
##
## EducationField EmployeeCount EmployeeNumber EnvironmentSatisfaction
## Human Resources : 27 Min. :1 Min. : 1.0 1:284
## Life Sciences :606 1st Qu.:1 1st Qu.: 491.2 2:287
## Marketing :159 Median :1 Median :1020.5 3:453
## Medical :464 Mean :1 Mean :1024.9 4:446
## Other : 82 3rd Qu.:1 3rd Qu.:1555.8
## Technical Degree:132 Max. :1 Max. :2068.0
##
## Gender HourlyRate JobInvolvement JobLevel
## Female:588 Min. : 30.00 1: 83 1:543
## Male :882 1st Qu.: 48.00 2:375 2:534
## Median : 66.00 3:868 3:218
## Mean : 65.89 4:144 4:106
## 3rd Qu.: 83.75 5: 69
## Max. :100.00
##
## JobRole JobSatisfaction MaritalStatus MonthlyIncome
## Sales Executive :326 1:289 Divorced:327 Min. : 1009
## Research Scientist :292 2:280 Married :673 1st Qu.: 2911
## Laboratory Technician :259 3:442 Single :470 Median : 4919
## Manufacturing Director :145 4:459 Mean : 6503
## Healthcare Representative:131 3rd Qu.: 8379
## Manager :102 Max. :19999
## (Other) :215
## MonthlyRate NumCompaniesWorked Over18 OverTime PercentSalaryHike
## Min. : 2094 Min. :0.000 Y:1470 No :1054 Min. :11.00
## 1st Qu.: 8047 1st Qu.:1.000 Yes: 416 1st Qu.:12.00
## Median :14236 Median :2.000 Median :14.00
## Mean :14313 Mean :2.693 Mean :15.21
## 3rd Qu.:20462 3rd Qu.:4.000 3rd Qu.:18.00
## Max. :26999 Max. :9.000 Max. :25.00
##
## PerformanceRating RelationshipSatisfaction StandardHours StockOptionLevel
## 3:1244 1:276 Min. :80 0:631
## 4: 226 2:303 1st Qu.:80 1:596
## 3:459 Median :80 2:158
## 4:432 Mean :80 3: 85
## 3rd Qu.:80
## Max. :80
##
## TotalWorkingYears TrainingTimesLastYear WorkLifeBalance YearsAtCompany
## Min. : 0.00 Min. :0.000 1: 80 Min. : 0.000
## 1st Qu.: 6.00 1st Qu.:2.000 2:344 1st Qu.: 3.000
## Median :10.00 Median :3.000 3:893 Median : 5.000
## Mean :11.28 Mean :2.799 4:153 Mean : 7.008
## 3rd Qu.:15.00 3rd Qu.:3.000 3rd Qu.: 9.000
## Max. :40.00 Max. :6.000 Max. :40.000
##
## YearsInCurrentRole YearsSinceLastPromotion YearsWithCurrManager
## Min. : 0.000 Min. : 0.000 Min. : 0.000
## 1st Qu.: 2.000 1st Qu.: 0.000 1st Qu.: 2.000
## Median : 3.000 Median : 1.000 Median : 3.000
## Mean : 4.229 Mean : 2.188 Mean : 4.123
## 3rd Qu.: 7.000 3rd Qu.: 3.000 3rd Qu.: 7.000
## Max. :18.000 Max. :15.000 Max. :17.000
##
Attrition is our target variable. We set ‘Yes’ as the positive class, because our priority is to identify the observations labelled ‘Yes’. To do so, we reorder the factor levels so that ‘Yes’ comes before ‘No’; yardstick treats the first level as the event by default.
employee <- employee %>%
mutate(
Attrition = factor(Attrition, levels = c("Yes", "No"))
)
Check for NA (missing) values in the dataset:
table(is.na(employee))
##
## FALSE
## 51450
There are no NA or missing values in the dataset, so this is our final starting dataset:
str(employee)
## 'data.frame': 1470 obs. of 35 variables:
## $ Age : int 41 49 37 33 27 32 59 30 38 36 ...
## $ Attrition : Factor w/ 2 levels "Yes","No": 1 2 1 2 2 2 2 2 2 2 ...
## $ BusinessTravel : Factor w/ 3 levels "Non-Travel","Travel_Frequently",..: 3 2 3 2 3 2 3 3 2 3 ...
## $ DailyRate : int 1102 279 1373 1392 591 1005 1324 1358 216 1299 ...
## $ Department : Factor w/ 3 levels "Human Resources",..: 3 2 2 2 2 2 2 2 2 2 ...
## $ DistanceFromHome : int 1 8 2 3 2 2 3 24 23 27 ...
## $ Education : Factor w/ 5 levels "1","2","3","4",..: 2 1 2 4 1 2 3 1 3 3 ...
## $ EducationField : Factor w/ 6 levels "Human Resources",..: 2 2 5 2 4 2 4 2 2 4 ...
## $ EmployeeCount : int 1 1 1 1 1 1 1 1 1 1 ...
## $ EmployeeNumber : int 1 2 4 5 7 8 10 11 12 13 ...
## $ EnvironmentSatisfaction : Factor w/ 4 levels "1","2","3","4": 2 3 4 4 1 4 3 4 4 3 ...
## $ Gender : Factor w/ 2 levels "Female","Male": 1 2 2 1 2 2 1 2 2 2 ...
## $ HourlyRate : int 94 61 92 56 40 79 81 67 44 94 ...
## $ JobInvolvement : Factor w/ 4 levels "1","2","3","4": 3 2 2 3 3 3 4 3 2 3 ...
## $ JobLevel : Factor w/ 5 levels "1","2","3","4",..: 2 2 1 1 1 1 1 1 3 2 ...
## $ JobRole : Factor w/ 9 levels "Healthcare Representative",..: 8 7 3 7 3 3 3 3 5 1 ...
## $ JobSatisfaction : Factor w/ 4 levels "1","2","3","4": 4 2 3 3 2 4 1 3 3 3 ...
## $ MaritalStatus : Factor w/ 3 levels "Divorced","Married",..: 3 2 3 2 2 3 2 1 3 2 ...
## $ MonthlyIncome : int 5993 5130 2090 2909 3468 3068 2670 2693 9526 5237 ...
## $ MonthlyRate : int 19479 24907 2396 23159 16632 11864 9964 13335 8787 16577 ...
## $ NumCompaniesWorked : int 8 1 6 1 9 0 4 1 0 6 ...
## $ Over18 : Factor w/ 1 level "Y": 1 1 1 1 1 1 1 1 1 1 ...
## $ OverTime : Factor w/ 2 levels "No","Yes": 2 1 2 2 1 1 2 1 1 1 ...
## $ PercentSalaryHike : int 11 23 15 11 12 13 20 22 21 13 ...
## $ PerformanceRating : Factor w/ 2 levels "3","4": 1 2 1 1 1 1 2 2 2 1 ...
## $ RelationshipSatisfaction: Factor w/ 4 levels "1","2","3","4": 1 4 2 3 4 3 1 2 2 2 ...
## $ StandardHours : int 80 80 80 80 80 80 80 80 80 80 ...
## $ StockOptionLevel : Factor w/ 4 levels "0","1","2","3": 1 2 1 1 2 1 4 2 1 3 ...
## $ TotalWorkingYears : int 8 10 7 8 6 8 12 1 10 17 ...
## $ TrainingTimesLastYear : int 0 3 3 3 3 2 3 2 2 3 ...
## $ WorkLifeBalance : Factor w/ 4 levels "1","2","3","4": 1 3 3 3 3 2 2 3 3 2 ...
## $ YearsAtCompany : int 6 10 0 8 2 7 1 1 9 7 ...
## $ YearsInCurrentRole : int 4 7 0 7 2 7 0 0 7 7 ...
## $ YearsSinceLastPromotion : int 0 1 0 3 2 3 0 0 1 7 ...
## $ YearsWithCurrManager : int 5 7 0 0 2 6 0 0 8 7 ...
We check for class imbalance by looking at the proportions of the target variable, Attrition:
prop.table(table(employee$Attrition))
##
## Yes No
## 0.1612245 0.8387755
The target variable is imbalanced: the ‘No’ class is much larger than the ‘Yes’ class, with a wide gap of about 84% versus 16%.
Based on this finding, we need to rebalance the training data by downsampling or upsampling.
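The imbalance also matters for evaluation: a trivial model that always predicts ‘No’ already reaches about 84% accuracy while never finding a single employee who leaves. A quick check in R, using the class counts from summary(employee) (1233 ‘No’, 237 ‘Yes’):

```r
# Class counts from summary(employee): 237 "Yes", 1233 "No"
counts <- c(Yes = 237, No = 1233)
shares <- counts / sum(counts)
print(round(shares, 4))

# Always predicting "No" is correct for every "No" row and wrong for every "Yes" row
baseline_accuracy <- shares[["No"]]
# ... and it never identifies a leaver: sensitivity for "Yes" is 0 / 237
baseline_sensitivity <- 0 / counts[["Yes"]]
cat("Always-'No' accuracy:", round(baseline_accuracy, 4),
    "| sensitivity:", baseline_sensitivity, "\n")
```

This is why sensitivity, specificity and precision are reported alongside accuracy later on.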
We split the data into training, validation, and testing sets. The first step is to split the dataset into training and testing data, using 80% for training and stratifying by Attrition.
set.seed(100)
initial_split <- initial_split(employee, prop = 0.8, strata = "Attrition")
set.seed(100)
train_split <- initial_split(training(initial_split), prop = 0.8, strata = "Attrition")
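We then split the training data again into a training set (about 80%) and a validation set (about 20%). The recipe below removes uninformative columns, drops near-zero-variance predictors, upsamples the training data so that both classes have the same number of rows, rescales numeric predictors to [0, 1], and creates dummy variables.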
rec <- recipe(Attrition ~ ., training(train_split)) %>%
step_rm(StandardHours, EmployeeCount, EmployeeNumber, Over18) %>%
step_nzv(all_predictors()) %>%
themis::step_upsample(Attrition, over_ratio = 1, seed = 100) %>% # recent recipes versions moved this step to the themis package
step_range(all_numeric(), min = 0, max = 1, -Attrition) %>%
# step_center(all_numeric()) %>%
# step_scale(all_numeric()) %>%
step_dummy(all_nominal(), -Attrition, one_hot = FALSE) %>%
prep(strings_as_factors = FALSE)
data_train <- juice(rec)
data_val <- bake(rec, testing(train_split))
data_test <- bake(rec, testing(initial_split))
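step_upsample() balances the classes by random oversampling: rows of the minority class (‘Yes’) are resampled with replacement until that class matches the majority class (a 1:1 ratio). Because the step is skipped by default when baking new data, only the training set is affected. A minimal base R sketch of the same idea; `upsample_rows()` and the toy data are hypothetical and not the recipes/themis implementation:

```r
# Random oversampling sketch: keep every original row and add resampled
# minority-class rows until all classes have as many rows as the largest one.
upsample_rows <- function(df, target, seed = 100) {
  set.seed(seed)
  groups <- split(seq_len(nrow(df)), df[[target]])   # row indices per class
  n_max <- max(lengths(groups))
  idx <- unlist(lapply(groups, function(rows) {
    n_extra <- n_max - length(rows)
    if (n_extra == 0) return(rows)
    # sample.int() avoids the sample(x) pitfall when a class has a single row
    c(rows, rows[sample.int(length(rows), n_extra, replace = TRUE)])
  }), use.names = FALSE)
  df[idx, , drop = FALSE]
}

toy <- data.frame(
  x = 1:10,
  Attrition = factor(c(rep("No", 8), rep("Yes", 2)), levels = c("Yes", "No"))
)
balanced <- upsample_rows(toy, "Attrition")
print(table(balanced$Attrition))
```

Duplicating rows adds no new information, so it mainly changes how much weight the loss gives to the minority class.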
initial_split
## <1177/293/1470>
prop.table(table(data_train$Attrition))
##
## Yes No
## 0.5 0.5
prop.table(table(data_val$Attrition))
##
## Yes No
## 0.1581197 0.8418803
prop.table(table(data_test$Attrition))
##
## Yes No
## 0.1604096 0.8395904
The proportions show that only the training data are upsampled, while the validation and test data keep the real class distribution. Next, we convert the data into the structure keras expects: a numeric vector for the target and a numeric matrix for the predictors.
train_y <- as.numeric(data_train$Attrition)-1
train_x <- data_train %>%
select(-Attrition) %>%
data.matrix()
val_y <- as.numeric(data_val$Attrition)-1
val_x <- data_val %>%
select(-Attrition) %>%
data.matrix()
test_y <- as.numeric(data_test$Attrition)-1
test_x <- data_test %>%
select(-Attrition) %>%
data.matrix()
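It helps to spell out what this encoding does, because it decides how the model output must be read. A small sketch with toy values (the probabilities are example numbers, not new results):

```r
# Target with the level order used in this analysis: "Yes" first, "No" second
attrition <- factor(c("Yes", "No", "No", "Yes"), levels = c("Yes", "No"))

# as.numeric() returns the level index (Yes = 1, No = 2); minus 1 gives Yes = 0, No = 1
y <- as.numeric(attrition) - 1
print(y)

# The sigmoid output therefore estimates P(Attrition == "No") ...
p_no <- c(0.294, 0.935, 0.552)   # example model outputs
# ... so the probability of leaving is its complement
p_yes <- 1 - p_no
pred_class <- ifelse(p_no > 0.5, "No", "Yes")
print(pred_class)
```

This is why the prediction step maps values above 0.5 to ‘No’ rather than ‘Yes’.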
A neural network is inspired by the biological neural networks of the brain. A multilayer perceptron (MLP), the type used here, consists of an input layer, one or more hidden layers, and an output layer. The data enter through the input layer, are transformed by the hidden layers, and are converted in the output layer into a specific value, such as a probability. The MLP is trained with back-propagation: the prediction error is propagated backwards through the network to compute how much each connection weight contributed to the loss, and the optimizer then adjusts the weights to minimize the loss function and improve performance.
We build three dense layers. The first two have 32 and 16 units and use the ReLU activation function, which passes positive values through and sets negative values to zero. A dropout layer (and batch normalization) could be added to reduce overfitting; both are left commented out in the code below. The last layer has a single unit with a sigmoid activation, which squashes its output into the range [0, 1] so it can be read as the probability of belonging to the class coded as 1 in train_y (‘No’).
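A minimal sketch of what such a network computes in one forward pass, written in base R with made-up sizes, weights and inputs (illustrative only, not the trained keras model). Back-propagation would then compute the gradient of the binary cross-entropy loss with respect to W1, b1, W2 and b2; that step is left to keras.

```r
# Illustrative forward pass of a tiny MLP (3 inputs -> 4 ReLU units -> 1 sigmoid unit).
# Weights are random placeholders, not taken from the trained keras model.
relu <- function(z) pmax(z, 0)
sigmoid <- function(z) 1 / (1 + exp(-z))

set.seed(100)
n_in <- 3
n_hidden <- 4
W1 <- matrix(rnorm(n_in * n_hidden, sd = 0.5), nrow = n_in)  # input -> hidden weights
b1 <- rep(0, n_hidden)                                       # hidden biases
W2 <- matrix(rnorm(n_hidden, sd = 0.5), nrow = n_hidden)     # hidden -> output weights
b2 <- 0                                                      # output bias

# Two toy observations, already scaled to [0, 1]
x <- matrix(c(0.2, 0.5, 0.9,
              0.1, 0.0, 0.3), nrow = 2, byrow = TRUE)

# Hidden layer: weighted sum plus bias (added to every row), then ReLU
h <- relu(sweep(x %*% W1, 2, b1, "+"))
# Output layer: weighted sum plus bias, then sigmoid -> value in (0, 1)
p <- sigmoid(h %*% W2 + b2)

# Binary cross-entropy for toy labels; back-propagation minimises this quantity
y <- c(0, 1)
loss <- -mean(y * log(p) + (1 - y) * log(1 - p))
print(p)
print(loss)
```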
input_n <- ncol(train_x)
model <- keras_model_sequential() %>%
layer_dense(input_shape = input_n,
units = 32,
activation = "relu") %>%
layer_dense(units = 16,
activation = "relu") %>%
# layer_dropout(rate = 0.1) %>%
# layer_batch_normalization() %>%
layer_dense(units = 1,
activation = "sigmoid")
model %>%
compile(optimizer = "adam",
metrics = "accuracy",
loss = "binary_crossentropy")
model
## Model
## Model: "sequential"
## ________________________________________________________________________________
## Layer (type) Output Shape Param #
## ================================================================================
## dense (Dense) (None, 32) 2016
## ________________________________________________________________________________
## dense_1 (Dense) (None, 16) 528
## ________________________________________________________________________________
## dense_2 (Dense) (None, 1) 17
## ================================================================================
## Total params: 2,561
## Trainable params: 2,561
## Non-trainable params: 0
## ________________________________________________________________________________
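The Param # column can be checked by hand: a dense layer has one weight per input-unit pair plus one bias per unit, i.e. (inputs + 1) * units. A short R check; the input width of 62 is inferred from the first layer's 2016 parameters (2016 / 32 = 63 = 62 + 1), it is not printed in the output above:

```r
# Parameters of a dense layer: one weight per (input, unit) pair plus one bias per unit
dense_params <- function(n_in, units) (n_in + 1) * units

n_in <- 62  # inferred from the summary: 2016 / 32 - 1
layers <- c(dense_params(n_in, 32), dense_params(32, 16), dense_params(16, 1))
print(layers)
print(sum(layers))
```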
set.seed(100)
history <- model %>%
fit(
x = train_x,
y = train_y,
batch_size = 124,
epochs = 10,
seed = 100,
verbose = 1,
validation_data = list(
val_x,
val_y
)
)
plot(history)
## `geom_smooth()` using formula 'y ~ x'
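Our model reaches about 80% accuracy on the training set and 67% on the validation set. The gap of about 13 percentage points is noticeable, so we should be cautious about concluding that the model does not overfit. Keep in mind, too, that the training set is upsampled to a 50/50 balance while the validation set keeps the original class ratio, where always predicting ‘No’ would already score about 84% accuracy.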
pred_test <- as_tibble(predict(model, test_x)) %>%
set_names("value") %>%
mutate(class = if_else(value > 0.5, "No", "Yes")) %>%
mutate(class = factor(class, levels = levels(data_test$Attrition))) %>%
set_names(paste0("pred_", colnames(.)))
## Warning: `as_tibble.matrix()` requires a matrix with column names or a `.name_repair` argument. Using compatibility `.name_repair`.
## This warning is displayed once per session.
pred_test <- data_test %>%
select(Attrition) %>%
bind_cols(pred_test)
summary(pred_test$pred_class)
## Yes No
## 128 165
pred_test
## # A tibble: 293 x 3
## Attrition pred_value pred_class
## <fct> <dbl> <fct>
## 1 No 0.294 Yes
## 2 No 0.935 No
## 3 No 0.512 No
## 4 No 0.677 No
## 5 Yes 0.552 No
## 6 Yes 0.897 No
## 7 No 0.186 Yes
## 8 No 0.580 No
## 9 No 0.863 No
## 10 No 0.540 No
## # … with 283 more rows
Next, we check the confusion matrix on the test set.
pred_test %>%
conf_mat(Attrition, pred_class) %>%
autoplot(type = "heatmap")
# metrics summary
pred_test %>%
summarise(
accuracy = accuracy_vec(Attrition, pred_class),
sensitivity = sens_vec(Attrition, pred_class),
specificity = spec_vec(Attrition, pred_class),
precision = precision_vec(Attrition, pred_class)
)
## # A tibble: 1 x 4
## accuracy sensitivity specificity precision
## <dbl> <dbl> <dbl> <dbl>
## 1 0.512 0.340 0.545 0.125
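The four metrics follow directly from the confusion matrix counts. A sketch in R; the counts are not printed as text above but reconstructed from the outputs (47 actual ‘Yes’ and 246 actual ‘No’ among the 293 test rows, 128 predicted ‘Yes’ and 165 predicted ‘No’), so treat them as a consistency check rather than new results:

```r
# Confusion-matrix counts reconstructed from the reported rates and class totals
tp <- 16   # actual Yes, predicted Yes
fn <- 31   # actual Yes, predicted No
fp <- 112  # actual No,  predicted Yes
tn <- 134  # actual No,  predicted No

metrics <- c(
  accuracy    = (tp + tn) / (tp + fn + fp + tn),
  sensitivity = tp / (tp + fn),   # share of actual "Yes" that were caught
  specificity = tn / (tn + fp),   # share of actual "No" correctly kept as "No"
  precision   = tp / (tp + fp)    # share of "Yes" predictions that were right
)
print(round(metrics, 3))
```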
pred_test %>%
roc_curve(Attrition, pred_value) %>%
autoplot()
pred_test %>%
roc_auc(Attrition, pred_value)
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 roc_auc binary 0.545
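A note on orientation: pred_value is the sigmoid output, which estimates P(‘No’) because train_y codes ‘No’ as 1, while yardstick treats the first factor level, ‘Yes’, as the event. The roc_curve(), roc_auc() and pr_curve() calls in this section therefore score ‘Yes’ with a probability that points the other way. Passing `1 - pred_value` as the estimate (or, in recent yardstick versions, setting `event_level = "second"`) fixes the orientation. Since flipping a score turns an AUC of A into 1 - A, the AUC for ‘Yes’ would be about 1 - 0.545 = 0.455, no better than chance. The base R sketch below, with a hypothetical `auc_rank()` helper and toy data, computes the rank-based (Mann-Whitney) AUC and shows this symmetry:

```r
# Rank-based (Mann-Whitney) AUC: probability that a randomly chosen positive
# case gets a higher score than a randomly chosen negative case.
auc_rank <- function(truth_positive, score) {
  r <- rank(score)                       # average ranks handle ties
  n_pos <- sum(truth_positive)
  n_neg <- sum(!truth_positive)
  (sum(r[truth_positive]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

is_yes <- c(TRUE, TRUE, FALSE, FALSE, FALSE)
p_no   <- c(0.30, 0.60, 0.40, 0.80, 0.90)   # toy sigmoid outputs = P("No")

print(auc_rank(is_yes, 1 - p_no))   # score oriented towards "Yes": 5/6
print(auc_rank(is_yes, p_no))       # same scores, wrong orientation: 1 - 5/6
```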
pred_test_roc <- pred_test %>%
roc_curve(Attrition, pred_value)
p <- pred_test_roc %>%
mutate_if(~ is.numeric(.), ~ round(.,4)) %>%
gather(metric, value, -.threshold) %>%
ggplot(aes(.threshold, value)) +
geom_line(aes(colour = metric)) +
labs(x = "Probability Threshold to be Classified as Positive", y = "Value", colour = "Metrics") +
theme_minimal()
ggplotly(p)
pred_test %>%
pr_curve(Attrition, pred_value) %>%
autoplot()
pred_test_pr <- pred_test %>%
pr_curve(Attrition, pred_value)
p <- pred_test_pr %>%
mutate_if(~ is.numeric(.), ~ round(.,4)) %>%
gather(metric, value, -.threshold) %>%
ggplot(aes(.threshold, value)) +
geom_line(aes(colour = metric)) +
labs(x = "Probability Threshold to be Classified as Positive", y = "Value", colour = "Metrics") +
theme_minimal()
ggplotly(p)
The goal of this step is to understand how the model decides the class of each observation. We need this because neural networks are “black box” models: their many weights make it hard to see why a particular prediction was made. I use the lime package to interpret how the model works.
# choose explanation data
data_explain <- testing(initial_split)
get_features <- function(x) {
matrix <- data.matrix(bake(rec, x, -Attrition))
matrix
}
lime_model <- as_classifier(model, labels = levels(data_explain$Attrition))
set.seed(100)
explainer <- lime(
x = data_explain,
model = lime_model,
preprocess = get_features
)
# get lime explanation
explanation <- explain(
x = data_explain[1:4,],
explainer = explainer,
n_labels = 1,
n_features = 4
)
# plot feature explanation
plot_features(explanation) +
labs(title = "LIME Feature Importance Visualization")
In the plot above, the lime package shows which features matter most to the model's classification decisions. It visualizes each of the first 4 cases (observations) from the test data, with the top four features for each case. Note that these features are not the same for every case. Blue bars mean that the feature supports the model's conclusion, and red bars mean that it contradicts it. A few features are important based on how often they appear in the first 4 cases:
One thing to be careful about with the LIME visualization is that it explains only a sample of the data, in our case the first 4 test observations. We therefore gain a very localized understanding of how our model works. We also want to know, from a global perspective, what drives feature importance.
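To see why the explanation is local, here is a minimal sketch of the idea behind LIME in base R: perturb one observation, query the model, weight the perturbed samples by proximity, and fit a weighted linear model whose slopes act as local feature effects. The black-box function and all parameters are made up for illustration; the lime package's actual sampling, feature selection and distance computation differ.

```r
# Hypothetical toy "black box": probability rises with x1^2 and falls with x2
black_box <- function(X) plogis(3 * X[, 1]^2 - 2 * X[, 2])

# Minimal local-surrogate sketch in the spirit of LIME (not the lime package's code)
local_surrogate <- function(x0, f, n = 2000, kernel_width = 0.5, seed = 100) {
  set.seed(seed)
  p <- length(x0)
  # 1. Perturb the observation of interest with Gaussian noise
  Z <- sweep(matrix(rnorm(n * p), ncol = p), 2, x0, "+")
  # 2. Ask the black box for predictions on the perturbed points
  y <- f(Z)
  # 3. Weight each perturbed point by its closeness to x0 (Gaussian kernel)
  d2 <- rowSums(sweep(Z, 2, x0)^2)
  w <- exp(-d2 / kernel_width^2)
  # 4. Fit a weighted linear model; its slopes are the local explanation
  fit <- lm(y ~ Z, weights = w)
  coef(fit)[-1]
}

coefs <- local_surrogate(c(1, 0), black_box)
print(coefs)
```

The slopes describe the model only around c(1, 0). Near x1 = 0 the same black box is almost flat in x1, so an explanation there would rank the features differently, which is why a handful of LIME cases does not give a global picture.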
We can perform a correlation analysis on the training set as well to help glean what features correlate globally to “Attrition”. We’ll use the corrr package, which performs tidy correlations:
# Feature correlations to Attrition (train_y: 1 = "No", 0 = "Yes")
corrr_analysis <- data.frame(train_x) %>%
mutate(Attrition = train_y) %>%
correlate() %>%
focus(Attrition) %>%
dplyr::rename(feature = rowname) %>% # newer corrr versions call this column `term`
arrange(abs(Attrition)) %>%
mutate(feature = as_factor(feature))
##
## Correlation method: 'pearson'
## Missing treated using: 'pairwise.complete.obs'
corrr_analysis <- corrr_analysis %>%
mutate(absAttrition = abs(Attrition)) %>%
arrange(desc(absAttrition)) %>%
slice(1:20) %>%
select(-absAttrition)
corrr_analysis
## # A tibble: 20 x 2
## feature Attrition
## <fct> <dbl>
## 1 TotalWorkingYears 0.304
## 2 StockOptionLevel_X1 0.279
## 3 YearsInCurrentRole 0.277
## 4 Age 0.263
## 5 MonthlyIncome 0.254
## 6 MaritalStatus_Single -0.252
## 7 YearsAtCompany 0.251
## 8 OverTime_Yes -0.246
## 9 YearsWithCurrManager 0.239
## 10 JobRole_Manager 0.180
## 11 JobRole_Research.Director 0.166
## 12 MaritalStatus_Married 0.165
## 13 JobRole_Sales.Representative -0.164
## 14 JobLevel_X2 0.161
## 15 StockOptionLevel_X2 0.156
## 16 BusinessTravel_Travel_Frequently -0.143
## 17 Department_Research...Development 0.142
## 18 WorkLifeBalance_X3 0.138
## 19 JobRole_Laboratory.Technician -0.136
## 20 JobLevel_X5 0.132
# Correlation visualization
# train_y codes "No" as 1, so positive correlations point towards staying
# and negative correlations towards attrition
corrr_analysis %>%
ggplot(aes(x = Attrition, y = fct_reorder(feature, desc(Attrition)))) +
geom_point() +
# Positive correlations - associated with staying ("No")
geom_segment(aes(xend = 0, yend = feature),
color = palette_light()[[2]],
data = corrr_analysis %>% filter(Attrition > 0)) +
geom_point(color = palette_light()[[2]],
data = corrr_analysis %>% filter(Attrition > 0)) +
# Negative correlations - associated with attrition ("Yes")
geom_segment(aes(xend = 0, yend = feature),
color = palette_light()[[1]],
data = corrr_analysis %>% filter(Attrition < 0)) +
geom_point(color = palette_light()[[1]],
data = corrr_analysis %>% filter(Attrition < 0)) +
# Vertical lines
geom_vline(xintercept = 0, color = palette_light()[[5]], size = 1, linetype = 2) +
geom_vline(xintercept = -0.25, color = palette_light()[[5]], size = 1, linetype = 2) +
geom_vline(xintercept = 0.25, color = palette_light()[[5]], size = 1, linetype = 2) +
# Aesthetics
theme_tq() +
labs(title = "Attrition Correlation Analysis",
subtitle = paste("Positive correlations (associated with staying),",
"Negative correlations (associated with attrition)"),
y = "Feature Importance")
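The correlation analysis helps us quickly discern which features the LIME analysis may be leaving out. The following features are moderately correlated with Attrition (absolute correlation of about 0.25 or more). Because train_y codes ‘No’ as 1, a positive correlation means the feature is associated with staying and a negative one with leaving:
Decreases Likelihood of Attrition (positive correlation, red): - Total Working Years - Stock Option Level = 1 - Years in Current Role - Age - Monthly Income - Years at Company
Increases Likelihood of Attrition (negative correlation, black): - Marital Status = Single - Over Time = Yes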
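Because train_y codes ‘No’ as 1, the sign of each correlation refers to staying rather than leaving. A small base R sketch with simulated data (the overtime effect is made up) shows that flipping the 0/1 coding of the outcome flips the sign of the Pearson correlation while keeping its magnitude:

```r
set.seed(100)
n <- 500
overtime_yes <- rbinom(n, 1, 0.3)                              # dummy: 1 = OverTime "Yes"
left <- rbinom(n, 1, ifelse(overtime_yes == 1, 0.35, 0.10))    # simulated: 1 = employee left

r_yes_coded_1 <- cor(overtime_yes, left)       # outcome coded 1 = "Yes" (left)
r_no_coded_1  <- cor(overtime_yes, 1 - left)   # outcome coded 1 = "No", as in train_y
print(c(yes_coded_1 = r_yes_coded_1, no_coded_1 = r_no_coded_1))
```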
Recall the questions posed at the start: how well can a neural network classify employee attrition on the test data, and which features help explain attrition?
To evaluate the results we use the confusion matrix, and the result is:
## # A tibble: 1 x 4
## accuracy sensitivity specificity precision
## <dbl> <dbl> <dbl> <dbl>
## 1 0.512 0.340 0.545 0.125
We conclude that our neural network does not work well in this case. On the test set it reaches only 51% accuracy, below the roughly 84% that always predicting ‘No’ would give, and it catches only 34% of the employees who actually left (sensitivity) with a precision of 12.5%, so it cannot reliably anticipate which employees will leave. Several factors may explain this, but we believe the main ones are that the number of observations is small and the data are imbalanced.
Even so, the analysis helps uncover which features matter for HR in understanding why employees leave. Based on the model interpretation (LIME and the correlation analysis), we conclude that the features listed in the correlation analysis above are the most important ones for understanding employee attrition.