Intro

Background

This is a fictional dataset created by IBM data scientists. Their goal in creating it is to uncover the factors that lead to employee attrition.

The specific questions for this case are: how do we classify attrition using a neural network, how well does the resulting model predict on our test data, and which features are important for understanding employee attrition?

Dataset

The dataset comes from Kaggle. It is a public dataset, which means everyone can access it.

Data Preparation

Import Libraries

In this step we prepare and wrangle the data. To support the wrangling, preprocessing, modelling, and visualization work, we need to import several packages:

library(tidyverse)
library(ggplot2)
library(GGally)
library(reshape)
library(keras)
library(rsample)
library(recipes)
library(yardstick)
library(caret)
library(plotly)
library(lime)
library(scales)
library(corrr)
library(tidyquant)

options(scipen = 100)

Read Data

Import the dataset from CSV

employee <- read.csv("data/WA_Fn-UseC_-HR-Employee-Attrition.csv")
glimpse(employee)
## Observations: 1,470
## Variables: 35
## $ Age                      <int> 41, 49, 37, 33, 27, 32, 59, 30, 38, 36, 35, …
## $ Attrition                <fct> Yes, No, Yes, No, No, No, No, No, No, No, No…
## $ BusinessTravel           <fct> Travel_Rarely, Travel_Frequently, Travel_Rar…
## $ DailyRate                <int> 1102, 279, 1373, 1392, 591, 1005, 1324, 1358…
## $ Department               <fct> Sales, Research & Development, Research & De…
## $ DistanceFromHome         <int> 1, 8, 2, 3, 2, 2, 3, 24, 23, 27, 16, 15, 26,…
## $ Education                <int> 2, 1, 2, 4, 1, 2, 3, 1, 3, 3, 3, 2, 1, 2, 3,…
## $ EducationField           <fct> Life Sciences, Life Sciences, Other, Life Sc…
## $ EmployeeCount            <int> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,…
## $ EmployeeNumber           <int> 1, 2, 4, 5, 7, 8, 10, 11, 12, 13, 14, 15, 16…
## $ EnvironmentSatisfaction  <int> 2, 3, 4, 4, 1, 4, 3, 4, 4, 3, 1, 4, 1, 2, 3,…
## $ Gender                   <fct> Female, Male, Male, Female, Male, Male, Fema…
## $ HourlyRate               <int> 94, 61, 92, 56, 40, 79, 81, 67, 44, 94, 84, …
## $ JobInvolvement           <int> 3, 2, 2, 3, 3, 3, 4, 3, 2, 3, 4, 2, 3, 3, 2,…
## $ JobLevel                 <int> 2, 2, 1, 1, 1, 1, 1, 1, 3, 2, 1, 2, 1, 1, 1,…
## $ JobRole                  <fct> Sales Executive, Research Scientist, Laborat…
## $ JobSatisfaction          <int> 4, 2, 3, 3, 2, 4, 1, 3, 3, 3, 2, 3, 3, 4, 3,…
## $ MaritalStatus            <fct> Single, Married, Single, Married, Married, S…
## $ MonthlyIncome            <int> 5993, 5130, 2090, 2909, 3468, 3068, 2670, 26…
## $ MonthlyRate              <int> 19479, 24907, 2396, 23159, 16632, 11864, 996…
## $ NumCompaniesWorked       <int> 8, 1, 6, 1, 9, 0, 4, 1, 0, 6, 0, 0, 1, 0, 5,…
## $ Over18                   <fct> Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y,…
## $ OverTime                 <fct> Yes, No, Yes, Yes, No, No, Yes, No, No, No, …
## $ PercentSalaryHike        <int> 11, 23, 15, 11, 12, 13, 20, 22, 21, 13, 13, …
## $ PerformanceRating        <int> 3, 4, 3, 3, 3, 3, 4, 4, 4, 3, 3, 3, 3, 3, 3,…
## $ RelationshipSatisfaction <int> 1, 4, 2, 3, 4, 3, 1, 2, 2, 2, 3, 4, 4, 3, 2,…
## $ StandardHours            <int> 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, …
## $ StockOptionLevel         <int> 0, 1, 0, 0, 1, 0, 3, 1, 0, 2, 1, 0, 1, 1, 0,…
## $ TotalWorkingYears        <int> 8, 10, 7, 8, 6, 8, 12, 1, 10, 17, 6, 10, 5, …
## $ TrainingTimesLastYear    <int> 0, 3, 3, 3, 3, 2, 3, 2, 2, 3, 5, 3, 1, 2, 4,…
## $ WorkLifeBalance          <int> 1, 3, 3, 3, 3, 2, 2, 3, 3, 2, 3, 3, 2, 3, 3,…
## $ YearsAtCompany           <int> 6, 10, 0, 8, 2, 7, 1, 1, 9, 7, 5, 9, 5, 2, 4…
## $ YearsInCurrentRole       <int> 4, 7, 0, 7, 2, 7, 0, 0, 7, 7, 4, 5, 2, 2, 2,…
## $ YearsSinceLastPromotion  <int> 0, 1, 0, 3, 2, 3, 0, 0, 1, 7, 0, 0, 4, 1, 0,…
## $ YearsWithCurrManager     <int> 5, 7, 0, 0, 2, 6, 0, 0, 8, 7, 3, 8, 3, 2, 3,…

Data Wrangling

After importing the data, we preview what we have. I use the summary() function to get a summary of each feature:

summary(employee)
##       Age        Attrition            BusinessTravel   DailyRate     
##  Min.   :18.00   No :1233   Non-Travel       : 150   Min.   : 102.0  
##  1st Qu.:30.00   Yes: 237   Travel_Frequently: 277   1st Qu.: 465.0  
##  Median :36.00              Travel_Rarely    :1043   Median : 802.0  
##  Mean   :36.92                                       Mean   : 802.5  
##  3rd Qu.:43.00                                       3rd Qu.:1157.0  
##  Max.   :60.00                                       Max.   :1499.0  
##                                                                      
##                   Department  DistanceFromHome   Education    
##  Human Resources       : 63   Min.   : 1.000   Min.   :1.000  
##  Research & Development:961   1st Qu.: 2.000   1st Qu.:2.000  
##  Sales                 :446   Median : 7.000   Median :3.000  
##                               Mean   : 9.193   Mean   :2.913  
##                               3rd Qu.:14.000   3rd Qu.:4.000  
##                               Max.   :29.000   Max.   :5.000  
##                                                               
##           EducationField EmployeeCount EmployeeNumber   EnvironmentSatisfaction
##  Human Resources : 27    Min.   :1     Min.   :   1.0   Min.   :1.000          
##  Life Sciences   :606    1st Qu.:1     1st Qu.: 491.2   1st Qu.:2.000          
##  Marketing       :159    Median :1     Median :1020.5   Median :3.000          
##  Medical         :464    Mean   :1     Mean   :1024.9   Mean   :2.722          
##  Other           : 82    3rd Qu.:1     3rd Qu.:1555.8   3rd Qu.:4.000          
##  Technical Degree:132    Max.   :1     Max.   :2068.0   Max.   :4.000          
##                                                                                
##     Gender      HourlyRate     JobInvolvement    JobLevel    
##  Female:588   Min.   : 30.00   Min.   :1.00   Min.   :1.000  
##  Male  :882   1st Qu.: 48.00   1st Qu.:2.00   1st Qu.:1.000  
##               Median : 66.00   Median :3.00   Median :2.000  
##               Mean   : 65.89   Mean   :2.73   Mean   :2.064  
##               3rd Qu.: 83.75   3rd Qu.:3.00   3rd Qu.:3.000  
##               Max.   :100.00   Max.   :4.00   Max.   :5.000  
##                                                              
##                       JobRole    JobSatisfaction  MaritalStatus MonthlyIncome  
##  Sales Executive          :326   Min.   :1.000   Divorced:327   Min.   : 1009  
##  Research Scientist       :292   1st Qu.:2.000   Married :673   1st Qu.: 2911  
##  Laboratory Technician    :259   Median :3.000   Single  :470   Median : 4919  
##  Manufacturing Director   :145   Mean   :2.729                  Mean   : 6503  
##  Healthcare Representative:131   3rd Qu.:4.000                  3rd Qu.: 8379  
##  Manager                  :102   Max.   :4.000                  Max.   :19999  
##  (Other)                  :215                                                 
##   MonthlyRate    NumCompaniesWorked Over18   OverTime   PercentSalaryHike
##  Min.   : 2094   Min.   :0.000      Y:1470   No :1054   Min.   :11.00    
##  1st Qu.: 8047   1st Qu.:1.000               Yes: 416   1st Qu.:12.00    
##  Median :14236   Median :2.000                          Median :14.00    
##  Mean   :14313   Mean   :2.693                          Mean   :15.21    
##  3rd Qu.:20462   3rd Qu.:4.000                          3rd Qu.:18.00    
##  Max.   :26999   Max.   :9.000                          Max.   :25.00    
##                                                                          
##  PerformanceRating RelationshipSatisfaction StandardHours StockOptionLevel
##  Min.   :3.000     Min.   :1.000            Min.   :80    Min.   :0.0000  
##  1st Qu.:3.000     1st Qu.:2.000            1st Qu.:80    1st Qu.:0.0000  
##  Median :3.000     Median :3.000            Median :80    Median :1.0000  
##  Mean   :3.154     Mean   :2.712            Mean   :80    Mean   :0.7939  
##  3rd Qu.:3.000     3rd Qu.:4.000            3rd Qu.:80    3rd Qu.:1.0000  
##  Max.   :4.000     Max.   :4.000            Max.   :80    Max.   :3.0000  
##                                                                           
##  TotalWorkingYears TrainingTimesLastYear WorkLifeBalance YearsAtCompany  
##  Min.   : 0.00     Min.   :0.000         Min.   :1.000   Min.   : 0.000  
##  1st Qu.: 6.00     1st Qu.:2.000         1st Qu.:2.000   1st Qu.: 3.000  
##  Median :10.00     Median :3.000         Median :3.000   Median : 5.000  
##  Mean   :11.28     Mean   :2.799         Mean   :2.761   Mean   : 7.008  
##  3rd Qu.:15.00     3rd Qu.:3.000         3rd Qu.:3.000   3rd Qu.: 9.000  
##  Max.   :40.00     Max.   :6.000         Max.   :4.000   Max.   :40.000  
##                                                                          
##  YearsInCurrentRole YearsSinceLastPromotion YearsWithCurrManager
##  Min.   : 0.000     Min.   : 0.000          Min.   : 0.000      
##  1st Qu.: 2.000     1st Qu.: 0.000          1st Qu.: 2.000      
##  Median : 3.000     Median : 1.000          Median : 3.000      
##  Mean   : 4.229     Mean   : 2.188          Mean   : 4.123      
##  3rd Qu.: 7.000     3rd Qu.: 3.000          3rd Qu.: 7.000      
##  Max.   :18.000     Max.   :15.000          Max.   :17.000      
## 

We have a reference from the data source stating that some features contain ordinal level encodings:

Education
1. ‘Below College’
2. ‘College’
3. ‘Bachelor’
4. ‘Master’
5. ‘Doctor’

EnvironmentSatisfaction
1. ‘Low’
2. ‘Medium’
3. ‘High’
4. ‘Very High’

JobInvolvement
1. ‘Low’
2. ‘Medium’
3. ‘High’
4. ‘Very High’

JobSatisfaction
1. ‘Low’
2. ‘Medium’
3. ‘High’
4. ‘Very High’

PerformanceRating
1. ‘Low’
2. ‘Good’
3. ‘Excellent’
4. ‘Outstanding’

RelationshipSatisfaction
1. ‘Low’
2. ‘Medium’
3. ‘High’
4. ‘Very High’

WorkLifeBalance
1. ‘Bad’
2. ‘Good’
3. ‘Better’
4. ‘Best’

Based on the reference above, some of our features have mismatched data types, so we will convert them to the correct type:

employee <- employee %>%
  mutate(
    Education = as.factor(Education),
    EnvironmentSatisfaction = as.factor(EnvironmentSatisfaction),
    JobInvolvement = as.factor(JobInvolvement),
    JobLevel = as.factor(JobLevel),
    JobSatisfaction = as.factor(JobSatisfaction),
    PerformanceRating = as.factor(PerformanceRating),
    RelationshipSatisfaction = as.factor(RelationshipSatisfaction),
    StockOptionLevel = as.factor(StockOptionLevel),
    WorkLifeBalance = as.factor(WorkLifeBalance)
  )

summary(employee)
##       Age        Attrition            BusinessTravel   DailyRate     
##  Min.   :18.00   No :1233   Non-Travel       : 150   Min.   : 102.0  
##  1st Qu.:30.00   Yes: 237   Travel_Frequently: 277   1st Qu.: 465.0  
##  Median :36.00              Travel_Rarely    :1043   Median : 802.0  
##  Mean   :36.92                                       Mean   : 802.5  
##  3rd Qu.:43.00                                       3rd Qu.:1157.0  
##  Max.   :60.00                                       Max.   :1499.0  
##                                                                      
##                   Department  DistanceFromHome Education
##  Human Resources       : 63   Min.   : 1.000   1:170    
##  Research & Development:961   1st Qu.: 2.000   2:282    
##  Sales                 :446   Median : 7.000   3:572    
##                               Mean   : 9.193   4:398    
##                               3rd Qu.:14.000   5: 48    
##                               Max.   :29.000            
##                                                         
##           EducationField EmployeeCount EmployeeNumber   EnvironmentSatisfaction
##  Human Resources : 27    Min.   :1     Min.   :   1.0   1:284                  
##  Life Sciences   :606    1st Qu.:1     1st Qu.: 491.2   2:287                  
##  Marketing       :159    Median :1     Median :1020.5   3:453                  
##  Medical         :464    Mean   :1     Mean   :1024.9   4:446                  
##  Other           : 82    3rd Qu.:1     3rd Qu.:1555.8                          
##  Technical Degree:132    Max.   :1     Max.   :2068.0                          
##                                                                                
##     Gender      HourlyRate     JobInvolvement JobLevel
##  Female:588   Min.   : 30.00   1: 83          1:543   
##  Male  :882   1st Qu.: 48.00   2:375          2:534   
##               Median : 66.00   3:868          3:218   
##               Mean   : 65.89   4:144          4:106   
##               3rd Qu.: 83.75                  5: 69   
##               Max.   :100.00                          
##                                                       
##                       JobRole    JobSatisfaction  MaritalStatus MonthlyIncome  
##  Sales Executive          :326   1:289           Divorced:327   Min.   : 1009  
##  Research Scientist       :292   2:280           Married :673   1st Qu.: 2911  
##  Laboratory Technician    :259   3:442           Single  :470   Median : 4919  
##  Manufacturing Director   :145   4:459                          Mean   : 6503  
##  Healthcare Representative:131                                  3rd Qu.: 8379  
##  Manager                  :102                                  Max.   :19999  
##  (Other)                  :215                                                 
##   MonthlyRate    NumCompaniesWorked Over18   OverTime   PercentSalaryHike
##  Min.   : 2094   Min.   :0.000      Y:1470   No :1054   Min.   :11.00    
##  1st Qu.: 8047   1st Qu.:1.000               Yes: 416   1st Qu.:12.00    
##  Median :14236   Median :2.000                          Median :14.00    
##  Mean   :14313   Mean   :2.693                          Mean   :15.21    
##  3rd Qu.:20462   3rd Qu.:4.000                          3rd Qu.:18.00    
##  Max.   :26999   Max.   :9.000                          Max.   :25.00    
##                                                                          
##  PerformanceRating RelationshipSatisfaction StandardHours StockOptionLevel
##  3:1244            1:276                    Min.   :80    0:631           
##  4: 226            2:303                    1st Qu.:80    1:596           
##                    3:459                    Median :80    2:158           
##                    4:432                    Mean   :80    3: 85           
##                                             3rd Qu.:80                    
##                                             Max.   :80                    
##                                                                           
##  TotalWorkingYears TrainingTimesLastYear WorkLifeBalance YearsAtCompany  
##  Min.   : 0.00     Min.   :0.000         1: 80           Min.   : 0.000  
##  1st Qu.: 6.00     1st Qu.:2.000         2:344           1st Qu.: 3.000  
##  Median :10.00     Median :3.000         3:893           Median : 5.000  
##  Mean   :11.28     Mean   :2.799         4:153           Mean   : 7.008  
##  3rd Qu.:15.00     3rd Qu.:3.000                         3rd Qu.: 9.000  
##  Max.   :40.00     Max.   :6.000                         Max.   :40.000  
##                                                                          
##  YearsInCurrentRole YearsSinceLastPromotion YearsWithCurrManager
##  Min.   : 0.000     Min.   : 0.000          Min.   : 0.000      
##  1st Qu.: 2.000     1st Qu.: 0.000          1st Qu.: 2.000      
##  Median : 3.000     Median : 1.000          Median : 3.000      
##  Mean   : 4.229     Mean   : 2.188          Mean   : 4.123      
##  3rd Qu.: 7.000     3rd Qu.: 3.000          3rd Qu.: 7.000      
##  Max.   :18.000     Max.   :15.000          Max.   :17.000      
## 

Attrition is our target variable. We set the ‘Yes’ label as the positive class, since our priority is to identify which observations are labelled ‘Yes’. To do this, we reorder the factor levels so that ‘Yes’ comes before ‘No’:

employee <- employee %>% 
  mutate(
    Attrition = factor(Attrition, levels = c("Yes", "No"))
  )
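
A quick sanity check that the reordering took effect (a minimal sketch):

levels(employee$Attrition)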

Check for NA or missing values inside the dataset:

table(is.na(employee))
## 
## FALSE 
## 51450
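
If any values had been missing, a per-column count would show where they are (a quick sketch using base R):

colSums(is.na(employee))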

The result shows that we do not have any NA or missing values in the dataset. Finally, we arrive at our starting dataset:

str(employee)
## 'data.frame':    1470 obs. of  35 variables:
##  $ Age                     : int  41 49 37 33 27 32 59 30 38 36 ...
##  $ Attrition               : Factor w/ 2 levels "Yes","No": 1 2 1 2 2 2 2 2 2 2 ...
##  $ BusinessTravel          : Factor w/ 3 levels "Non-Travel","Travel_Frequently",..: 3 2 3 2 3 2 3 3 2 3 ...
##  $ DailyRate               : int  1102 279 1373 1392 591 1005 1324 1358 216 1299 ...
##  $ Department              : Factor w/ 3 levels "Human Resources",..: 3 2 2 2 2 2 2 2 2 2 ...
##  $ DistanceFromHome        : int  1 8 2 3 2 2 3 24 23 27 ...
##  $ Education               : Factor w/ 5 levels "1","2","3","4",..: 2 1 2 4 1 2 3 1 3 3 ...
##  $ EducationField          : Factor w/ 6 levels "Human Resources",..: 2 2 5 2 4 2 4 2 2 4 ...
##  $ EmployeeCount           : int  1 1 1 1 1 1 1 1 1 1 ...
##  $ EmployeeNumber          : int  1 2 4 5 7 8 10 11 12 13 ...
##  $ EnvironmentSatisfaction : Factor w/ 4 levels "1","2","3","4": 2 3 4 4 1 4 3 4 4 3 ...
##  $ Gender                  : Factor w/ 2 levels "Female","Male": 1 2 2 1 2 2 1 2 2 2 ...
##  $ HourlyRate              : int  94 61 92 56 40 79 81 67 44 94 ...
##  $ JobInvolvement          : Factor w/ 4 levels "1","2","3","4": 3 2 2 3 3 3 4 3 2 3 ...
##  $ JobLevel                : Factor w/ 5 levels "1","2","3","4",..: 2 2 1 1 1 1 1 1 3 2 ...
##  $ JobRole                 : Factor w/ 9 levels "Healthcare Representative",..: 8 7 3 7 3 3 3 3 5 1 ...
##  $ JobSatisfaction         : Factor w/ 4 levels "1","2","3","4": 4 2 3 3 2 4 1 3 3 3 ...
##  $ MaritalStatus           : Factor w/ 3 levels "Divorced","Married",..: 3 2 3 2 2 3 2 1 3 2 ...
##  $ MonthlyIncome           : int  5993 5130 2090 2909 3468 3068 2670 2693 9526 5237 ...
##  $ MonthlyRate             : int  19479 24907 2396 23159 16632 11864 9964 13335 8787 16577 ...
##  $ NumCompaniesWorked      : int  8 1 6 1 9 0 4 1 0 6 ...
##  $ Over18                  : Factor w/ 1 level "Y": 1 1 1 1 1 1 1 1 1 1 ...
##  $ OverTime                : Factor w/ 2 levels "No","Yes": 2 1 2 2 1 1 2 1 1 1 ...
##  $ PercentSalaryHike       : int  11 23 15 11 12 13 20 22 21 13 ...
##  $ PerformanceRating       : Factor w/ 2 levels "3","4": 1 2 1 1 1 1 2 2 2 1 ...
##  $ RelationshipSatisfaction: Factor w/ 4 levels "1","2","3","4": 1 4 2 3 4 3 1 2 2 2 ...
##  $ StandardHours           : int  80 80 80 80 80 80 80 80 80 80 ...
##  $ StockOptionLevel        : Factor w/ 4 levels "0","1","2","3": 1 2 1 1 2 1 4 2 1 3 ...
##  $ TotalWorkingYears       : int  8 10 7 8 6 8 12 1 10 17 ...
##  $ TrainingTimesLastYear   : int  0 3 3 3 3 2 3 2 2 3 ...
##  $ WorkLifeBalance         : Factor w/ 4 levels "1","2","3","4": 1 3 3 3 3 2 2 3 3 2 ...
##  $ YearsAtCompany          : int  6 10 0 8 2 7 1 1 9 7 ...
##  $ YearsInCurrentRole      : int  4 7 0 7 2 7 0 0 7 7 ...
##  $ YearsSinceLastPromotion : int  0 1 0 3 2 3 0 0 1 7 ...
##  $ YearsWithCurrManager    : int  5 7 0 0 2 6 0 0 8 7 ...

Exploratory Data Analysis

We check whether there is class imbalance by looking at the proportions of the target variable Attrition:

prop.table(table(employee$Attrition))
## 
##       Yes        No 
## 0.1612245 0.8387755

We find a class imbalance in our target variable: the ‘No’ class is far larger than the ‘Yes’ class, about 84% compared to 16%.

Based on this finding, we need to downsample or upsample the dataset.
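
To make the gap visible, here is a minimal ggplot sketch of the class proportions (percent() comes from the scales package loaded earlier):

employee %>%
  count(Attrition) %>%
  mutate(prop = n / sum(n)) %>%
  ggplot(aes(x = Attrition, y = prop, fill = Attrition)) +
  geom_col(show.legend = FALSE) +
  scale_y_continuous(labels = percent) +  # format the axis as percentages
  labs(y = "Proportion of employees") +
  theme_minimal()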

Modelling

Cross-Validation

We will split the data into a training set, a validation set, and a testing set. As a first step, we split the dataset into training and testing subsets.

set.seed(100)
initial_split <- initial_split(employee, prop = 0.8, strata = "Attrition")

set.seed(100)
train_split <- initial_split(training(initial_split), prop = 0.8, strata = "Attrition")

We split our training dataset further into a training set and a validation set, with roughly 80% for training and 20% for validation. Besides that, we upsample the minority class inside the preprocessing recipe so the training data becomes balanced.

rec <- recipe(Attrition ~ ., training(train_split)) %>% 
  step_rm(StandardHours, EmployeeCount, EmployeeNumber, Over18) %>% 
  step_nzv(all_predictors()) %>% 
  step_upsample(Attrition, ratio = 1/1, seed = 100) %>% 
  step_range(all_numeric(), min = 0, max = 1, -Attrition) %>%
  # step_center(all_numeric()) %>%
  # step_scale(all_numeric()) %>%
  step_dummy(all_nominal(), -Attrition, one_hot = FALSE) %>% 
  prep(strings_as_factors = FALSE)

data_train <- juice(rec)
data_val <- bake(rec, testing(train_split))
data_test <- bake(rec, testing(initial_split))
initial_split
## <1177/293/1470>
prop.table(table(data_train$Attrition))
## 
## Yes  No 
## 0.5 0.5
prop.table(table(data_val$Attrition))
## 
##       Yes        No 
## 0.1581197 0.8418803
prop.table(table(data_test$Attrition))
## 
##       Yes        No 
## 0.1604096 0.8395904

We can see that we upsample only the training data, leaving the real distributions for the validation and test data. Next we adjust the data into the structure keras expects before feeding it to the model.

train_y <- as.numeric(data_train$Attrition)-1
train_x <- data_train %>% 
  select(-Attrition) %>% 
  data.matrix()

val_y <- as.numeric(data_val$Attrition)-1
val_x <- data_val %>% 
  select(-Attrition) %>% 
  data.matrix()

test_y <- as.numeric(data_test$Attrition)-1
test_x <- data_test %>% 
  select(-Attrition) %>% 
  data.matrix()
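
Before training, it is worth sanity-checking the shapes keras will receive (a quick sketch):

dim(train_x)     # rows x one-hot encoded predictors
length(train_y)  # must match nrow(train_x)
table(train_y)   # 0 = "Yes" (positive class), 1 = "No"; balanced after upsampling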

Neural Network

Architecture: Multilayer Perceptron

A neural network is inspired by the biological neural network of the brain. It consists of an input layer, one or more hidden layers, and an output layer. Data is fed into the input layer, processed through the hidden layers, and converted into specific values, such as a probability, in the output layer. The MLP is trained with back-propagation: it repeatedly adjusts the weight of each connection between neurons in order to minimize the loss function and improve performance.

We will build several layers. The first and second are dense layers with the relu activation function. A dropout layer could be inserted to help prevent the model from overfitting (left commented out below). The last layer is a dense layer with the sigmoid activation function, which squashes the output into the range [0, 1] so it can be read as the probability that an observation belongs to a particular class.

input_n <- ncol(train_x)

model <- keras_model_sequential() %>%
  layer_dense(input_shape = input_n,
              units = 32,
              activation = "relu") %>%
  layer_dense(units = 16,
              activation = "relu") %>%
  # layer_dropout(rate = 0.1) %>%
  # layer_batch_normalization() %>%
  layer_dense(units = 1,
              activation = "sigmoid")

model %>%
  compile(optimizer = "adam",
          metrics = "accuracy",
          loss = "binary_crossentropy")

model
## Model
## Model: "sequential"
## ________________________________________________________________________________
## Layer (type)                        Output Shape                    Param #     
## ================================================================================
## dense (Dense)                       (None, 32)                      2016        
## ________________________________________________________________________________
## dense_1 (Dense)                     (None, 16)                      528         
## ________________________________________________________________________________
## dense_2 (Dense)                     (None, 1)                       17          
## ================================================================================
## Total params: 2,561
## Trainable params: 2,561
## Non-trainable params: 0
## ________________________________________________________________________________

Model Fitting

set.seed(100)

history <- model %>%
  fit(
    x = train_x,
    y = train_y,
    batch_size = 124,
    epochs = 10,
    seed = 100,
    verbose = 1,
    validation_data = list(
      val_x,
      val_y
    )
  )

plot(history)
## `geom_smooth()` using formula 'y ~ x'

Our model reaches about 80% accuracy on the training dataset and about 67% on the validation dataset. The gap of roughly 13 percentage points is still acceptable, so we can conclude that the model is not badly overfit.
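
We can also read the validation loss and accuracy directly instead of eyeballing the history plot (a minimal sketch using keras' evaluate()):

model %>%
  evaluate(val_x, val_y, verbose = 0)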

Model Evaluation

Performance

pred_test <- as_tibble(predict(model, test_x)) %>%
  set_names("value") %>%
  # note: the model outputs P("No"), because "Yes" was coded as 0 in train_y
  mutate(class = if_else(value > 0.5, "No", "Yes")) %>%
  mutate(class = factor(class, levels = levels(data_test$Attrition))) %>%
  set_names(paste0("pred_", colnames(.)))
## Warning: `as_tibble.matrix()` requires a matrix with column names or a `.name_repair` argument. Using compatibility `.name_repair`.
## This warning is displayed once per session.
pred_test <- data_test %>%
  select(Attrition) %>%
  bind_cols(pred_test)

summary(pred_test$pred_class)
## Yes  No 
## 128 165
pred_test
## # A tibble: 293 x 3
##    Attrition pred_value pred_class
##    <fct>          <dbl> <fct>     
##  1 No             0.294 Yes       
##  2 No             0.935 No        
##  3 No             0.512 No        
##  4 No             0.677 No        
##  5 Yes            0.552 No        
##  6 Yes            0.897 No        
##  7 No             0.186 Yes       
##  8 No             0.580 No        
##  9 No             0.863 No        
## 10 No             0.540 No        
## # … with 283 more rows

We check the confusion matrix on the test dataset.

pred_test %>%
  conf_mat(Attrition, pred_class) %>%
  autoplot(type = "heatmap")

# metrics summary
pred_test %>%
  summarise(
    accuracy = accuracy_vec(Attrition, pred_class),
    sensitivity = sens_vec(Attrition, pred_class),
    specificity = spec_vec(Attrition, pred_class),
    precision = precision_vec(Attrition, pred_class)
  )
## # A tibble: 1 x 4
##   accuracy sensitivity specificity precision
##      <dbl>       <dbl>       <dbl>     <dbl>
## 1    0.512       0.340       0.545     0.125
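
Since precision and sensitivity are both low here, the F1 score (their harmonic mean) condenses them into a single number. A quick sketch using yardstick's f_meas_vec() on the columns created above:

pred_test %>%
  summarise(f1 = f_meas_vec(Attrition, pred_class))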

ROC Curve

pred_test %>%
  roc_curve(Attrition, pred_value) %>%
  autoplot()

pred_test %>% 
  roc_auc(Attrition, pred_value)
## # A tibble: 1 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary         0.545

Sensitivity - Specificity Curve

pred_test_roc <- pred_test %>%
  roc_curve(Attrition, pred_value)

p <- pred_test_roc %>%
  mutate_if(~ is.numeric(.), ~ round(.,4)) %>%
  gather(metric, value, -.threshold) %>%
  ggplot(aes(.threshold, value)) +
  geom_line(aes(colour = metric)) +
  labs(x = "Probability Threshold to be Classified as Positive", y = "Value", colour = "Metrics") +
  theme_minimal()

ggplotly(p)
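
One practical use of this curve is picking a classification threshold. A minimal sketch that finds the threshold where sensitivity and specificity are closest, reusing the pred_test_roc table computed above:

pred_test_roc %>%
  mutate(gap = abs(sensitivity - specificity)) %>%  # distance between the two curves
  arrange(gap) %>%
  slice(1)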

Precision - Recall Curve

pred_test %>%
  pr_curve(Attrition, pred_value) %>%
  autoplot()

pred_test_pr <- pred_test %>%
  pr_curve(Attrition, pred_value)

p <- pred_test_pr %>%
  mutate_if(~ is.numeric(.), ~ round(.,4)) %>%
  gather(metric, value, -.threshold) %>%
  ggplot(aes(.threshold, value)) +
  geom_line(aes(colour = metric)) +
  labs(x = "Probability Threshold to be Classified as Positive", y = "Value", colour = "Metrics") +
  theme_minimal()

ggplotly(p)

Model Audit

The goal of this step is to understand how our model decides the classification for each observation. Why do we need this? Because neural networks are “black box” models: sophisticated, but hard to interpret. I will use the lime package to interpret how the model works.

# choose explanation data
data_explain <- testing(initial_split)

get_features <- function(x) {
  matrix <- data.matrix(bake(rec, x, -Attrition))
  matrix
}

lime_model <- as_classifier(model, labels = levels(data_explain$Attrition))

set.seed(100)
explainer <- lime(
  x = data_explain,
  model = lime_model,
  preprocess = get_features
)

# get lime explanation
explanation <- explain(
  x = data_explain[1:4,],
  explainer = explainer,
  n_labels = 1,
  n_features = 4
)

# plot feature explanation

plot_features(explanation) + 
  labs(title = "LIME Feature Importance Visualization")

In the plot above I use the lime package to understand which features matter most when the model classifies each case. The package lets us visualize each of the first 4 cases (observations) from the test data. The top four features for each case are shown; note that they are not the same for every case. Blue bars mean the feature supports the model’s conclusion, and red bars mean it contradicts it. A few important features, based on how often they appear in the first 4 cases:

  • Work Life Balance
  • Business Travel Frequency
  • Marital Status

One thing to be careful about with the LIME visualization is that we are only explaining a sample of the data, in our case the first 4 test observations. We are therefore gaining a very localized understanding of how the model works. However, we also want to know, from a global perspective, what drives feature importance.

We can perform a correlation analysis on the training set as well to help glean what features correlate globally to “Attrition”. We’ll use the corrr package, which performs tidy correlations:

# Feature correlations to Attrition (train_y codes "Yes" as 0, "No" as 1)
corrr_analysis <- data.frame(train_x) %>%
  mutate(Attrition = train_y) %>%
  correlate() %>%
  focus(Attrition) %>%
  rename(feature = rowname) %>%
  arrange(abs(Attrition)) %>%
  mutate(feature = as_factor(feature)) 
## 
## Correlation method: 'pearson'
## Missing treated using: 'pairwise.complete.obs'
corrr_analysis <- corrr_analysis %>% 
  mutate(absAttrition = abs(Attrition)) %>% 
  arrange(desc(absAttrition)) %>% 
  slice(1:20) %>% 
  select(-absAttrition)

corrr_analysis
## # A tibble: 20 x 2
##    feature                           Attrition
##    <fct>                                 <dbl>
##  1 TotalWorkingYears                     0.304
##  2 StockOptionLevel_X1                   0.279
##  3 YearsInCurrentRole                    0.277
##  4 Age                                   0.263
##  5 MonthlyIncome                         0.254
##  6 MaritalStatus_Single                 -0.252
##  7 YearsAtCompany                        0.251
##  8 OverTime_Yes                         -0.246
##  9 YearsWithCurrManager                  0.239
## 10 JobRole_Manager                       0.180
## 11 JobRole_Research.Director             0.166
## 12 MaritalStatus_Married                 0.165
## 13 JobRole_Sales.Representative         -0.164
## 14 JobLevel_X2                           0.161
## 15 StockOptionLevel_X2                   0.156
## 16 BusinessTravel_Travel_Frequently     -0.143
## 17 Department_Research...Development     0.142
## 18 WorkLifeBalance_X3                    0.138
## 19 JobRole_Laboratory.Technician        -0.136
## 20 JobLevel_X5                           0.132
# Correlation visualization
corrr_analysis %>%
  ggplot(aes(x = Attrition, y = fct_reorder(feature, desc(Attrition)))) +
  geom_point() +
  # Positive correlations - associated with staying ("No" is coded as 1)
  geom_segment(aes(xend = 0, yend = feature), 
               color = palette_light()[[2]], 
               data = corrr_analysis %>% filter(Attrition > 0)) +
  geom_point(color = palette_light()[[2]], 
             data = corrr_analysis %>% filter(Attrition > 0)) +
  # Negative correlations - associated with attrition
  geom_segment(aes(xend = 0, yend = feature), 
               color = palette_light()[[1]], 
               data = corrr_analysis %>% filter(Attrition < 0)) +
  geom_point(color = palette_light()[[1]], 
             data = corrr_analysis %>% filter(Attrition < 0)) +
  # Vertical lines
  geom_vline(xintercept = 0, color = palette_light()[[5]], size = 1, linetype = 2) +
  geom_vline(xintercept = -0.25, color = palette_light()[[5]], size = 1, linetype = 2) +
  geom_vline(xintercept = 0.25, color = palette_light()[[5]], size = 1, linetype = 2) +
  # Aesthetics
  theme_tq() +
  labs(title = "Attrition Correlation Analysis",
       subtitle = paste("Positive correlations (associated with staying),",
                        "negative correlations (associated with attrition)"),
       y = "Feature Importance")

The correlation analysis helps us quickly discern features that the LIME analysis may be excluding. We can see that the following features are strongly correlated with the target (magnitude around 0.25 or above):

Associated with staying, i.e. lower attrition (positive correlation, red):

  • Total Working Years
  • Stock Option Level = 1 (true)
  • Years in Current Role
  • Age
  • Monthly Income

Associated with leaving, i.e. higher attrition (negative correlation, black):

  • Marital Status = Single
  • Over Time = Yes
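
Because train_y codes ‘Yes’ as 0 and ‘No’ as 1, flipping the sign of the correlation gives the direction with respect to attrition itself. A quick sketch on the corrr_analysis table built above:

corrr_analysis %>%
  mutate(cor_with_attrition = -Attrition) %>%  # positive now means "more likely to leave"
  arrange(desc(cor_with_attrition)) %>%
  head(5)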

Conclusion

The specific questions for this case were: how do we classify attrition using a neural network, how well does the resulting model predict on our test data, and which features are important for understanding employee attrition?

For model evaluation we use the confusion matrix and its derived metrics; the result is:

# metrics summary
pred_test %>%
  summarise(
    accuracy = accuracy_vec(Attrition, pred_class),
    sensitivity = sens_vec(Attrition, pred_class),
    specificity = spec_vec(Attrition, pred_class),
    precision = precision_vec(Attrition, pred_class)
  )
## # A tibble: 1 x 4
##   accuracy sensitivity specificity precision
##      <dbl>       <dbl>       <dbl>     <dbl>
## 1    0.512       0.340       0.545     0.125

We can conclude that our neural network does not perform well in this case: the model is mainly good at predicting the “No” label, which means we cannot anticipate which employees will leave. Several factors may keep the neural network from performing well on this data, but the major ones are the small total number of observations and the class imbalance.

Nevertheless, we can still uncover which features or variables are important for HR in understanding why employees leave. Based on the interpretation in the Model Audit, the important ones are:

  • Total Working Years
  • Stock Option level
  • Years in Current Role
  • Age
  • Monthly Income
  • Marital Status
  • Over Time

These are the variables and features most important for understanding employee attrition.