Load Required Libraries

library(keras)
library(lime)
library(tidyquant)
library(rsample)
library(recipes)
library(yardstick)
library(corrr)
library(dplyr)
library(forcats)
library(ggplot2)
library(jsonlite)
library(tidyr)    # provides drop_na(), used when cleaning the data below

churn_data_raw <- read.csv(file = "C:/Users/santi/Documents/R Kaggle Data/IBM Churn Data.csv",
                           header = TRUE, sep = ",",
                           stringsAsFactors = TRUE)  # keep text columns as factors (matches the <fct> columns shown below)

Viewing the Data

glimpse(churn_data_raw)
## Observations: 7,043
## Variables: 21
## $ customerID       <fct> 7590-VHVEG, 5575-GNVDE, 3668-QPYBK, 7795-CFOCW, 9237…
## $ gender           <fct> Female, Male, Male, Male, Female, Female, Male, Fema…
## $ SeniorCitizen    <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ Partner          <fct> Yes, No, No, No, No, No, No, No, Yes, No, Yes, No, Y…
## $ Dependents       <fct> No, No, No, No, No, No, Yes, No, No, Yes, Yes, No, N…
## $ tenure           <int> 1, 34, 2, 45, 2, 8, 22, 10, 28, 62, 13, 16, 58, 49, …
## $ PhoneService     <fct> No, Yes, Yes, No, Yes, Yes, Yes, No, Yes, Yes, Yes, …
## $ MultipleLines    <fct> No phone service, No, No, No phone service, No, Yes,…
## $ InternetService  <fct> DSL, DSL, DSL, DSL, Fiber optic, Fiber optic, Fiber …
## $ OnlineSecurity   <fct> No, Yes, Yes, Yes, No, No, No, Yes, No, Yes, Yes, No…
## $ OnlineBackup     <fct> Yes, No, Yes, No, No, No, Yes, No, No, Yes, No, No i…
## $ DeviceProtection <fct> No, Yes, No, Yes, No, Yes, No, No, Yes, No, No, No i…
## $ TechSupport      <fct> No, No, No, Yes, No, No, No, No, Yes, No, No, No int…
## $ StreamingTV      <fct> No, No, No, No, No, Yes, Yes, No, Yes, No, No, No in…
## $ StreamingMovies  <fct> No, No, No, No, No, Yes, No, No, Yes, No, No, No int…
## $ Contract         <fct> Month-to-month, One year, Month-to-month, One year, …
## $ PaperlessBilling <fct> Yes, No, Yes, No, Yes, Yes, Yes, No, Yes, No, Yes, N…
## $ PaymentMethod    <fct> Electronic check, Mailed check, Mailed check, Bank t…
## $ MonthlyCharges   <dbl> 29.85, 56.95, 53.85, 42.30, 70.70, 99.65, 89.10, 29.…
## $ TotalCharges     <dbl> 29.85, 1889.50, 108.15, 1840.75, 151.65, 820.50, 194…
## $ Churn            <fct> No, No, Yes, No, Yes, Yes, No, No, Yes, No, No, No, …

Removing Unnecessary Data and Missing Values

churn_data_tbl <- churn_data_raw %>%
  select(-customerID) %>%
  drop_na() %>%
  select(Churn, everything())

glimpse(churn_data_tbl)
## Observations: 7,032
## Variables: 20
## $ Churn            <fct> No, No, Yes, No, Yes, Yes, No, No, Yes, No, No, No, …
## $ gender           <fct> Female, Male, Male, Male, Female, Female, Male, Fema…
## $ SeniorCitizen    <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ Partner          <fct> Yes, No, No, No, No, No, No, No, Yes, No, Yes, No, Y…
## $ Dependents       <fct> No, No, No, No, No, No, Yes, No, No, Yes, Yes, No, N…
## $ tenure           <int> 1, 34, 2, 45, 2, 8, 22, 10, 28, 62, 13, 16, 58, 49, …
## $ PhoneService     <fct> No, Yes, Yes, No, Yes, Yes, Yes, No, Yes, Yes, Yes, …
## $ MultipleLines    <fct> No phone service, No, No, No phone service, No, Yes,…
## $ InternetService  <fct> DSL, DSL, DSL, DSL, Fiber optic, Fiber optic, Fiber …
## $ OnlineSecurity   <fct> No, Yes, Yes, Yes, No, No, No, Yes, No, Yes, Yes, No…
## $ OnlineBackup     <fct> Yes, No, Yes, No, No, No, Yes, No, No, Yes, No, No i…
## $ DeviceProtection <fct> No, Yes, No, Yes, No, Yes, No, No, Yes, No, No, No i…
## $ TechSupport      <fct> No, No, No, Yes, No, No, No, No, Yes, No, No, No int…
## $ StreamingTV      <fct> No, No, No, No, No, Yes, Yes, No, Yes, No, No, No in…
## $ StreamingMovies  <fct> No, No, No, No, No, Yes, No, No, Yes, No, No, No int…
## $ Contract         <fct> Month-to-month, One year, Month-to-month, One year, …
## $ PaperlessBilling <fct> Yes, No, Yes, No, Yes, Yes, Yes, No, Yes, No, Yes, N…
## $ PaymentMethod    <fct> Electronic check, Mailed check, Mailed check, Bank t…
## $ MonthlyCharges   <dbl> 29.85, 56.95, 53.85, 42.30, 70.70, 99.65, 89.10, 29.…
## $ TotalCharges     <dbl> 29.85, 1889.50, 108.15, 1840.75, 151.65, 820.50, 194…

Split Into Training and Test Sets

set.seed(100)
train_test_split <- initial_split(churn_data_tbl, prop = 0.8)
train_test_split
## <5626/1406/7032>
train_tbl <- training(train_test_split)
test_tbl  <- testing(train_test_split)
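
A quick optional sanity check (a sketch, not part of the original analysis): the churn rate should be roughly similar in the two splits.

# Proportion of churners in each split
prop.table(table(train_tbl$Churn))
prop.table(table(test_tbl$Churn))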

Determine whether a log transformation improves the correlation with churn

train_tbl %>%
  select(Churn, TotalCharges) %>%
  mutate(
    Churn = Churn %>% as.factor() %>% as.numeric(),
    LogTotalCharges = log(TotalCharges)
  ) %>%
  correlate() %>%
  focus(Churn) %>%
  fashion()
## 
## Correlation method: 'pearson'
## Missing treated using: 'pairwise.complete.obs'
##           rowname Churn
## 1    TotalCharges  -.20
## 2 LogTotalCharges  -.25

Create Recipe

Since the log-transformed TotalCharges correlates more strongly with churn (-.25 vs -.20), the recipe applies step_log() to TotalCharges, discretizes tenure into six bins, dummy-encodes the nominal predictors, and then centers and scales all predictors.

rec_obj <- recipe(Churn ~ ., data = train_tbl) %>%
  step_discretize(tenure, options = list(cuts = 6)) %>%
  step_log(TotalCharges) %>%
  step_dummy(all_nominal(), -all_outcomes()) %>%
  step_center(all_predictors(), -all_outcomes()) %>%
  step_scale(all_predictors(), -all_outcomes()) %>%
  prep(training = train_tbl)
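
To verify what the trained recipe will do at bake() time, its step list can be inspected (a quick sketch; tidy() is the recipes method for summarizing a recipe's steps):

# One row per preprocessing step, in the order they are applied
tidy(rec_obj)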

Response variables for training and testing sets

y_train_vec <- ifelse(pull(train_tbl, Churn) == "Yes", 1, 0)
y_test_vec  <- ifelse(pull(test_tbl, Churn) == "Yes", 1, 0)

Predictors

x_train_tbl <- bake(rec_obj, new_data = train_tbl) %>% select(-Churn)
x_test_tbl  <- bake(rec_obj, new_data = test_tbl) %>% select(-Churn)

glimpse(x_train_tbl)
## Observations: 5,626
## Variables: 35
## $ SeniorCitizen                         <dbl> -0.434329, -0.434329, -0.434329…
## $ MonthlyCharges                        <dbl> -0.25311279, -0.35623207, 0.204…
## $ TotalCharges                          <dbl> 0.38437751, -1.46064716, -1.242…
## $ gender_Male                           <dbl> 0.9882489, 0.9882489, -1.011710…
## $ Partner_Yes                           <dbl> -0.9677264, -0.9677264, -0.9677…
## $ Dependents_Yes                        <dbl> -0.6507747, -0.6507747, -0.6507…
## $ tenure_bin1                           <dbl> -0.4549729, 2.1975426, 2.197542…
## $ tenure_bin2                           <dbl> -0.4398089, -0.4398089, -0.4398…
## $ tenure_bin3                           <dbl> -0.4489849, -0.4489849, -0.4489…
## $ tenure_bin4                           <dbl> 2.2339635, -0.4475553, -0.44755…
## $ tenure_bin5                           <dbl> -0.4592348, -0.4592348, -0.4592…
## $ tenure_bin6                           <dbl> -0.4323038, -0.4323038, -0.4323…
## $ PhoneService_Yes                      <dbl> 0.3288092, 0.3288092, 0.3288092…
## $ MultipleLines_No.phone.service        <dbl> -0.3288092, -0.3288092, -0.3288…
## $ MultipleLines_Yes                     <dbl> -0.8521545, -0.8521545, -0.8521…
## $ InternetService_Fiber.optic           <dbl> -0.8798103, -0.8798103, 1.13640…
## $ InternetService_No                    <dbl> -0.5280885, -0.5280885, -0.5280…
## $ OnlineSecurity_No.internet.service    <dbl> -0.5280885, -0.5280885, -0.5280…
## $ OnlineSecurity_Yes                    <dbl> 1.578543, 1.578543, -0.633383, …
## $ OnlineBackup_No.internet.service      <dbl> -0.5280885, -0.5280885, -0.5280…
## $ OnlineBackup_Yes                      <dbl> -0.7237004, 1.3815417, -0.72370…
## $ DeviceProtection_No.internet.service  <dbl> -0.5280885, -0.5280885, -0.5280…
## $ DeviceProtection_Yes                  <dbl> 1.3761164, -0.7265536, -0.72655…
## $ TechSupport_No.internet.service       <dbl> -0.5280885, -0.5280885, -0.5280…
## $ TechSupport_Yes                       <dbl> -0.6309046, -0.6309046, -0.6309…
## $ StreamingTV_No.internet.service       <dbl> -0.5280885, -0.5280885, -0.5280…
## $ StreamingTV_Yes                       <dbl> -0.7881721, -0.7881721, -0.7881…
## $ StreamingMovies_No.internet.service   <dbl> -0.5280885, -0.5280885, -0.5280…
## $ StreamingMovies_Yes                   <dbl> -0.7914356, -0.7914356, -0.7914…
## $ Contract_One.year                     <dbl> 1.9264561, -0.5189956, -0.51899…
## $ Contract_Two.year                     <dbl> -0.5577284, -0.5577284, -0.5577…
## $ PaperlessBilling_Yes                  <dbl> -1.2037542, 0.8305868, 0.830586…
## $ PaymentMethod_Credit.card..automatic. <dbl> -0.5236826, -0.5236826, -0.5236…
## $ PaymentMethod_Electronic.check        <dbl> -0.7083639, -0.7083639, 1.41145…
## $ PaymentMethod_Mailed.check            <dbl> 1.8359758, 1.8359758, -0.544572…

Building our Artificial Neural Network

model_keras <- keras_model_sequential()
model_keras %>% 
  
  # First hidden layer
  layer_dense(
    units              = 16, 
    kernel_initializer = "uniform", 
    activation         = "relu", 
    input_shape        = ncol(x_train_tbl)) %>% 
  
  # Dropout to prevent overfitting
  layer_dropout(rate = 0.1) %>%
  
  # Second hidden layer
  layer_dense(
    units              = 16, 
    kernel_initializer = "uniform", 
    activation         = "relu") %>% 
  
  # Dropout to prevent overfitting
  layer_dropout(rate = 0.1) %>%
  
  # Output layer
  layer_dense(
    units              = 1, 
    kernel_initializer = "uniform", 
    activation         = "sigmoid") %>% 
  
  # Compile ANN
  compile(
    optimizer = 'adam',
    loss      = 'binary_crossentropy',
    metrics   = c('accuracy')
  )

# Print the compiled model (layers, output shapes, parameter counts)
model_keras

Fit the keras model to the training data

history <- fit(
  object           = model_keras, 
  x                = as.matrix(x_train_tbl), 
  y                = y_train_vec,
  batch_size       = 50, 
  epochs           = 35,
  validation_split = 0.30
)

Plot the training/validation history of our Keras model

plot(history)

Predicted Class

yhat_keras_class_vec <- predict_classes(object = model_keras, x = as.matrix(x_test_tbl)) %>%
       as.vector()

Predicted Class Probability

yhat_keras_prob_vec  <- predict_proba(object = model_keras, x = as.matrix(x_test_tbl)) %>%
  as.vector() 
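
Note: predict_classes() and predict_proba() have been removed from recent releases of the keras R package (TensorFlow 2.6+). A minimal sketch of the same two steps using predict() directly, assuming the usual 0.50 cutoff on the sigmoid output:

# Hedged alternative for newer keras versions
yhat_keras_prob_vec  <- predict(model_keras, as.matrix(x_test_tbl)) %>% as.vector()
yhat_keras_class_vec <- ifelse(yhat_keras_prob_vec > 0.5, 1, 0)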

Format test data and predictions for yardstick metrics

estimates_keras_tbl <- tibble(
  truth      = as.factor(y_test_vec) %>% fct_recode(yes = "1", no = "0"),
  estimate   = as.factor(yhat_keras_class_vec) %>% fct_recode(yes = "1", no = "0"),
  class_prob = yhat_keras_prob_vec
)

estimates_keras_tbl
## # A tibble: 1,406 x 3
##    truth estimate class_prob
##    <fct> <fct>         <dbl>
##  1 no    yes          0.692 
##  2 no    no           0.0444
##  3 yes   yes          0.707 
##  4 no    no           0.0266
##  5 no    no           0.0507
##  6 yes   no           0.264 
##  7 no    no           0.0871
##  8 yes   no           0.257 
##  9 yes   yes          0.707 
## 10 no    no           0.299 
## # … with 1,396 more rows
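
# Treat the second factor level ("yes") as the positive class; newer yardstick
# versions replace this option with the event_level = "second" argument on the
# individual metric functions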
options(yardstick.event_first = FALSE)

Confusion Table

estimates_keras_tbl %>% conf_mat(truth, estimate)
##           Truth
## Prediction  no yes
##        no  905 188
##        yes 110 203

Accuracy

estimates_keras_tbl %>% metrics(truth, estimate)
## # A tibble: 2 x 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.788
## 2 kap      binary         0.438

AUC

estimates_keras_tbl %>% roc_auc(truth, class_prob)
## # A tibble: 1 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary         0.837

Precision and Recall

tibble(
  precision = estimates_keras_tbl %>% precision(truth, estimate),
  recall    = estimates_keras_tbl %>% recall(truth, estimate)
)
## # A tibble: 1 x 2
##   precision$.metric $.estimator $.estimate recall$.metric $.estimator $.estimate
##   <chr>             <chr>            <dbl> <chr>          <chr>            <dbl>
## 1 precision         binary           0.649 recall         binary           0.519
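
The nested columns above are awkward to read; the same two metrics can be bound into one tidy table instead (a sketch using the yardstick calls already shown plus dplyr::bind_rows()):

bind_rows(
  estimates_keras_tbl %>% precision(truth, estimate),
  estimates_keras_tbl %>% recall(truth, estimate)
)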

Set up the lime::model_type() function for keras

model_type.keras.engine.sequential.Sequential <- function(x, ...) {
  return("classification")
}

Set up the lime::predict_model() function for keras

predict_model.keras.engine.sequential.Sequential <- function(x, newdata, type, ...) {
  pred <- predict_proba(object = x, x = as.matrix(newdata))
  return(data.frame(Yes = pred, No = 1 - pred))
}
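
If predict_proba() is unavailable in your keras version, the same wrapper can be written with predict() (a sketch under the same assumptions, i.e. a single sigmoid output column giving the probability of "Yes"):

predict_model.keras.engine.sequential.Sequential <- function(x, newdata, type, ...) {
  pred <- predict(x, as.matrix(newdata))           # probability of churn ("Yes")
  data.frame(Yes = pred[, 1], No = 1 - pred[, 1])
}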

Run lime() on training set

explainer <- lime::lime(
  x              = x_train_tbl,
  model          = model_keras,
  bin_continuous = FALSE
)

Run explain() on the first ten test observations

explanation <- lime::explain(
  x_test_tbl[1:10, ], 
  explainer    = explainer, 
  n_labels     = 1, 
  n_features   = 4,
  kernel_width = 0.5
)



plot_features(explanation) +
  labs(title = "LIME Feature Importance Visualization",
       subtitle = "Hold Out (Test) Set, First 10 Cases Shown")

plot_explanations(explanation) +
  labs(title = "LIME Feature Importance Heatmap",
       subtitle = "Hold Out (Test) Set, First 10 Cases Shown")

Correlation visualization

Note: recent versions of corrr name the first column of correlate() output term rather than rowname; if the rename() call below errors, use rename(feature = term) instead.

corrr_analysis <- x_train_tbl %>%
  mutate(Churn = y_train_vec) %>%
  correlate() %>%
  focus(Churn) %>%
  rename(feature = rowname) %>%
  arrange(abs(Churn)) %>%
  mutate(feature = as_factor(feature))
## 
## Correlation method: 'pearson'
## Missing treated using: 'pairwise.complete.obs'
corrr_analysis
## # A tibble: 35 x 2
##    feature                           Churn
##    <fct>                             <dbl>
##  1 gender_Male                    -0.00539
##  2 MultipleLines_No.phone.service -0.0139 
##  3 PhoneService_Yes                0.0139 
##  4 tenure_bin3                    -0.0162 
##  5 MultipleLines_Yes               0.0394 
##  6 StreamingTV_Yes                 0.0603 
##  7 StreamingMovies_Yes             0.0603 
##  8 tenure_bin4                    -0.0701 
##  9 DeviceProtection_Yes           -0.0702 
## 10 OnlineBackup_Yes               -0.0791 
## # … with 25 more rows
# Correlation visualization
corrr_analysis %>%
  ggplot(aes(x = Churn, y = fct_reorder(feature, desc(Churn)))) +
  geom_point() +
  # Positive Correlations - Contribute to churn
  geom_segment(aes(xend = 0, yend = feature),
               color = palette_light()[[2]],
               data = corrr_analysis %>% filter(Churn > 0)) +
  geom_point(color = palette_light()[[2]],
             data = corrr_analysis %>% filter(Churn > 0)) +
  # Negative Correlations - Prevent churn
  geom_segment(aes(xend = 0, yend = feature),
               color = palette_light()[[1]],
               data = corrr_analysis %>% filter(Churn < 0)) +
  geom_point(color = palette_light()[[1]],
             data = corrr_analysis %>% filter(Churn < 0)) +
  # Vertical lines
  geom_vline(xintercept = 0, color = palette_light()[[5]], size = 1, linetype = 2) +
  geom_vline(xintercept = -0.25, color = palette_light()[[5]], size = 1, linetype = 2) +
  geom_vline(xintercept = 0.25, color = palette_light()[[5]], size = 1, linetype = 2) +
  # Aesthetics
  theme_tq() +
  labs(title = "Churn Correlation Analysis",
       subtitle = paste("Positive Correlations Contribute to Churn,",
                        "Negative Correlations Prevent Churn"),
       y = "Feature Importance")