# Load Required Libraries
library(keras)
library(lime)
library(tidyquant)
library(rsample)
library(recipes)
library(yardstick)
library(corrr)
library(dplyr)
library(forcats)
library(ggplot2)
library(jsonlite)
# Location of the IBM Telco churn CSV. Set the CHURN_DATA_PATH environment
# variable to point elsewhere; the author's original location is the default.
churn_data_path <- Sys.getenv(
  "CHURN_DATA_PATH",
  unset = "C:/Users/santi/Documents/R Kaggle Data/IBM Churn Data.csv"
)

# Read the raw data.
# - stringsAsFactors = TRUE keeps text columns as factors (as in the glimpse()
#   output below) regardless of the running R version's default.
# - Empty and single-space cells are read as NA so that numeric columns stay
#   numeric and drop_na() can remove those rows. Some copies of this dataset
#   store missing TotalCharges as " " -- TODO confirm for your file.
churn_data_raw <- read.csv(
  file = churn_data_path,
  header = TRUE,
  sep = ",",
  stringsAsFactors = TRUE,
  na.strings = c("NA", "", " ")
)
# View the Data
# Inspect the column names, types and leading values of the raw data.
churn_data_raw %>% glimpse()
## Observations: 7,043
## Variables: 21
## $ customerID <fct> 7590-VHVEG, 5575-GNVDE, 3668-QPYBK, 7795-CFOCW, 9237…
## $ gender <fct> Female, Male, Male, Male, Female, Female, Male, Fema…
## $ SeniorCitizen <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ Partner <fct> Yes, No, No, No, No, No, No, No, Yes, No, Yes, No, Y…
## $ Dependents <fct> No, No, No, No, No, No, Yes, No, No, Yes, Yes, No, N…
## $ tenure <int> 1, 34, 2, 45, 2, 8, 22, 10, 28, 62, 13, 16, 58, 49, …
## $ PhoneService <fct> No, Yes, Yes, No, Yes, Yes, Yes, No, Yes, Yes, Yes, …
## $ MultipleLines <fct> No phone service, No, No, No phone service, No, Yes,…
## $ InternetService <fct> DSL, DSL, DSL, DSL, Fiber optic, Fiber optic, Fiber …
## $ OnlineSecurity <fct> No, Yes, Yes, Yes, No, No, No, Yes, No, Yes, Yes, No…
## $ OnlineBackup <fct> Yes, No, Yes, No, No, No, Yes, No, No, Yes, No, No i…
## $ DeviceProtection <fct> No, Yes, No, Yes, No, Yes, No, No, Yes, No, No, No i…
## $ TechSupport <fct> No, No, No, Yes, No, No, No, No, Yes, No, No, No int…
## $ StreamingTV <fct> No, No, No, No, No, Yes, Yes, No, Yes, No, No, No in…
## $ StreamingMovies <fct> No, No, No, No, No, Yes, No, No, Yes, No, No, No int…
## $ Contract <fct> Month-to-month, One year, Month-to-month, One year, …
## $ PaperlessBilling <fct> Yes, No, Yes, No, Yes, Yes, Yes, No, Yes, No, Yes, N…
## $ PaymentMethod <fct> Electronic check, Mailed check, Mailed check, Bank t…
## $ MonthlyCharges <dbl> 29.85, 56.95, 53.85, 42.30, 70.70, 99.65, 89.10, 29.…
## $ TotalCharges <dbl> 29.85, 1889.50, 108.15, 1840.75, 151.65, 820.50, 194…
## $ Churn <fct> No, No, Yes, No, Yes, Yes, No, No, Yes, No, No, No, …
# Remove Unnecessary Data (drop customerID and rows with missing values)
# Drop the customerID identifier column, discard any row containing a missing
# value, and move the outcome column Churn to the front.
churn_data_tbl <- select(churn_data_raw, -customerID)
churn_data_tbl <- drop_na(churn_data_tbl)
churn_data_tbl <- select(churn_data_tbl, Churn, everything())

churn_data_tbl %>% glimpse()
## Observations: 7,032
## Variables: 20
## $ Churn <fct> No, No, Yes, No, Yes, Yes, No, No, Yes, No, No, No, …
## $ gender <fct> Female, Male, Male, Male, Female, Female, Male, Fema…
## $ SeniorCitizen <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ Partner <fct> Yes, No, No, No, No, No, No, No, Yes, No, Yes, No, Y…
## $ Dependents <fct> No, No, No, No, No, No, Yes, No, No, Yes, Yes, No, N…
## $ tenure <int> 1, 34, 2, 45, 2, 8, 22, 10, 28, 62, 13, 16, 58, 49, …
## $ PhoneService <fct> No, Yes, Yes, No, Yes, Yes, Yes, No, Yes, Yes, Yes, …
## $ MultipleLines <fct> No phone service, No, No, No phone service, No, Yes,…
## $ InternetService <fct> DSL, DSL, DSL, DSL, Fiber optic, Fiber optic, Fiber …
## $ OnlineSecurity <fct> No, Yes, Yes, Yes, No, No, No, Yes, No, Yes, Yes, No…
## $ OnlineBackup <fct> Yes, No, Yes, No, No, No, Yes, No, No, Yes, No, No i…
## $ DeviceProtection <fct> No, Yes, No, Yes, No, Yes, No, No, Yes, No, No, No i…
## $ TechSupport <fct> No, No, No, Yes, No, No, No, No, Yes, No, No, No int…
## $ StreamingTV <fct> No, No, No, No, No, Yes, Yes, No, Yes, No, No, No in…
## $ StreamingMovies <fct> No, No, No, No, No, Yes, No, No, Yes, No, No, No int…
## $ Contract <fct> Month-to-month, One year, Month-to-month, One year, …
## $ PaperlessBilling <fct> Yes, No, Yes, No, Yes, Yes, Yes, No, Yes, No, Yes, N…
## $ PaymentMethod <fct> Electronic check, Mailed check, Mailed check, Bank t…
## $ MonthlyCharges <dbl> 29.85, 56.95, 53.85, 42.30, 70.70, 99.65, 89.10, 29.…
## $ TotalCharges <dbl> 29.85, 1889.50, 108.15, 1840.75, 151.65, 820.50, 194…
# Split the Data into Training and Test Sets
# Reproducible 80% / 20% split of the cleaned data into training and test sets.
set.seed(100)
train_test_split <- initial_split(churn_data_tbl, prop = 0.8)
train_test_split
## <5626/1406/7032>
train_tbl <- train_test_split %>% training()
test_tbl  <- train_test_split %>% testing()
# Create and Train the Preprocessing Recipe
# Preprocessing recipe, estimated on the training set only:
#   1. bin tenure into groups (cuts = 6),
#   2. log-transform TotalCharges,
#   3. one-hot encode every nominal predictor (the outcome Churn is excluded),
#   4. center and 5. scale all predictors.
# NOTE(review): newer recipes releases deprecate `options = list(cuts = 6)` in
# step_discretize() in favour of `num_breaks = 6` -- TODO confirm for the
# installed version before changing it.
rec_obj <- recipe(Churn ~ ., data = train_tbl) %>%
  step_discretize(tenure, options = list(cuts = 6)) %>%
  step_log(TotalCharges) %>%
  step_dummy(all_nominal(), -all_outcomes()) %>%
  # all_predictors() never selects the outcome, so no -all_outcomes() needed
  step_center(all_predictors()) %>%
  step_scale(all_predictors()) %>%
  prep(data = train_tbl)

# Print Recipe
rec_obj
## Data Recipe
##
## Inputs:
##
## role #variables
## outcome 1
## predictor 19
##
## Training data contained 5626 data points and no missing data.
##
## Operations:
##
## Dummy variables from tenure [trained]
## Log transformation on TotalCharges [trained]
## Dummy variables from gender, Partner, Dependents, tenure, ... [trained]
## Centering for SeniorCitizen, MonthlyCharges, ... [trained]
## Scaling for SeniorCitizen, MonthlyCharges, ... [trained]
# Response Vectors for the Training and Test Sets (1 = churn, 0 = no churn)
# Numeric 0/1 response vectors for keras: 1 where Churn is "Yes", 0 otherwise.
y_train_vec <- as.numeric(pull(train_tbl, Churn) == "Yes")
y_test_vec  <- as.numeric(pull(test_tbl, Churn) == "Yes")
# Predictors (apply the trained recipe to both sets)
# Apply the trained recipe to each split, then keep only the predictor columns.
x_train_tbl <- rec_obj %>% bake(new_data = train_tbl) %>% select(-Churn)
x_test_tbl  <- rec_obj %>% bake(new_data = test_tbl) %>% select(-Churn)

x_train_tbl %>% glimpse()
## Observations: 5,626
## Variables: 35
## $ SeniorCitizen <dbl> -0.434329, -0.434329, -0.434329…
## $ MonthlyCharges <dbl> -0.25311279, -0.35623207, 0.204…
## $ TotalCharges <dbl> 0.38437751, -1.46064716, -1.242…
## $ gender_Male <dbl> 0.9882489, 0.9882489, -1.011710…
## $ Partner_Yes <dbl> -0.9677264, -0.9677264, -0.9677…
## $ Dependents_Yes <dbl> -0.6507747, -0.6507747, -0.6507…
## $ tenure_bin1 <dbl> -0.4549729, 2.1975426, 2.197542…
## $ tenure_bin2 <dbl> -0.4398089, -0.4398089, -0.4398…
## $ tenure_bin3 <dbl> -0.4489849, -0.4489849, -0.4489…
## $ tenure_bin4 <dbl> 2.2339635, -0.4475553, -0.44755…
## $ tenure_bin5 <dbl> -0.4592348, -0.4592348, -0.4592…
## $ tenure_bin6 <dbl> -0.4323038, -0.4323038, -0.4323…
## $ PhoneService_Yes <dbl> 0.3288092, 0.3288092, 0.3288092…
## $ MultipleLines_No.phone.service <dbl> -0.3288092, -0.3288092, -0.3288…
## $ MultipleLines_Yes <dbl> -0.8521545, -0.8521545, -0.8521…
## $ InternetService_Fiber.optic <dbl> -0.8798103, -0.8798103, 1.13640…
## $ InternetService_No <dbl> -0.5280885, -0.5280885, -0.5280…
## $ OnlineSecurity_No.internet.service <dbl> -0.5280885, -0.5280885, -0.5280…
## $ OnlineSecurity_Yes <dbl> 1.578543, 1.578543, -0.633383, …
## $ OnlineBackup_No.internet.service <dbl> -0.5280885, -0.5280885, -0.5280…
## $ OnlineBackup_Yes <dbl> -0.7237004, 1.3815417, -0.72370…
## $ DeviceProtection_No.internet.service <dbl> -0.5280885, -0.5280885, -0.5280…
## $ DeviceProtection_Yes <dbl> 1.3761164, -0.7265536, -0.72655…
## $ TechSupport_No.internet.service <dbl> -0.5280885, -0.5280885, -0.5280…
## $ TechSupport_Yes <dbl> -0.6309046, -0.6309046, -0.6309…
## $ StreamingTV_No.internet.service <dbl> -0.5280885, -0.5280885, -0.5280…
## $ StreamingTV_Yes <dbl> -0.7881721, -0.7881721, -0.7881…
## $ StreamingMovies_No.internet.service <dbl> -0.5280885, -0.5280885, -0.5280…
## $ StreamingMovies_Yes <dbl> -0.7914356, -0.7914356, -0.7914…
## $ Contract_One.year <dbl> 1.9264561, -0.5189956, -0.51899…
## $ Contract_Two.year <dbl> -0.5577284, -0.5577284, -0.5577…
## $ PaperlessBilling_Yes <dbl> -1.2037542, 0.8305868, 0.830586…
## $ PaymentMethod_Credit.card..automatic. <dbl> -0.5236826, -0.5236826, -0.5236…
## $ PaymentMethod_Electronic.check <dbl> -0.7083639, -0.7083639, 1.41145…
## $ PaymentMethod_Mailed.check <dbl> 1.8359758, 1.8359758, -0.544572…
# Build the Artificial Neural Network
# Feed-forward network: two 16-unit ReLU hidden layers, each followed by 10%
# dropout, and one sigmoid output unit giving the predicted probability that
# the response is 1 (churn). Layers are added to `model_keras` in place: the
# pipeline's result is not reassigned, and later code uses `model_keras`.
model_keras <- keras_model_sequential()

model_keras %>%
  # First hidden layer; input_shape is the number of baked predictor columns
  layer_dense(
    units = 16,
    kernel_initializer = "uniform",
    activation = "relu",
    input_shape = ncol(x_train_tbl)
  ) %>%
  # Dropout to prevent overfitting
  layer_dropout(rate = 0.1) %>%
  # Second hidden layer
  layer_dense(
    units = 16,
    kernel_initializer = "uniform",
    activation = "relu"
  ) %>%
  # Dropout to prevent overfitting
  layer_dropout(rate = 0.1) %>%
  # Output layer: single sigmoid unit for binary classification
  layer_dense(
    units = 1,
    kernel_initializer = "uniform",
    activation = "sigmoid"
  ) %>%
  # Compile ANN
  compile(
    optimizer = "adam",
    loss = "binary_crossentropy",
    metrics = c("accuracy")
  )

# Print the model's layer summary. (Printing `keras_model` here would show the
# source of the keras::keras_model() constructor, not this model.)
model_keras
# Fit the keras Model to the Training Data
# Train for 35 epochs in mini-batches of 50 rows. validation_split = 0.30
# reserves 30% of the training rows for validation.
history <- model_keras %>%
  fit(
    x = as.matrix(x_train_tbl),
    y = y_train_vec,
    epochs = 35,
    batch_size = 50,
    validation_split = 0.30
  )
# Print a Summary of the Training History
# Final-epoch loss/accuracy for the training and validation data.
history %>% print()
## Trained on 3,938 samples, validated on 1,688 samples (batch_size=50, epochs=35)
## Final epoch (plot to see history):
## val_loss: 0.434
## val_acc: 0.8009
## loss: 0.3883
## acc: 0.8189
# Plot the Training/Validation History of the keras Model
# Per-epoch loss and accuracy curves.
history %>% plot()

# Predicted Class Probability: predicted probability of churn for each test row.
# predict() replaces predict_proba(). predict_proba() was deprecated and then
# removed from keras Sequential models (TensorFlow 2.6). For this sigmoid
# output, both return the same values.
yhat_keras_prob_vec <- model_keras %>%
  predict(x = as.matrix(x_test_tbl)) %>%
  as.vector()

# Predicted Class: 1 when the churn probability exceeds 0.5, otherwise 0. This
# is the rule the removed predict_classes() applied to a single sigmoid unit.
yhat_keras_class_vec <- as.integer(yhat_keras_prob_vec > 0.5)

# Truth, predicted class and churn probability, in the format yardstick needs.
# "yes" (churn) is the FIRST factor level. yardstick's default event level is
# the first level, so precision, recall and ROC AUC refer to churn without
# options(yardstick.event_first = FALSE) or event_level = "second".
estimates_keras_tbl <- tibble(
  truth      = factor(y_test_vec, levels = c(1, 0), labels = c("yes", "no")),
  estimate   = factor(yhat_keras_class_vec, levels = c(1, 0),
                      labels = c("yes", "no")),
  class_prob = yhat_keras_prob_vec
)
# Confusion Matrix
# Cross-tabulate predicted vs. actual churn on the test set.
conf_mat(estimates_keras_tbl, truth = truth, estimate = estimate)
## Truth
## Prediction no yes
## no 905 188
## yes 110 203
# Accuracy and Kappa
# Default class metrics (accuracy and Cohen's kappa) for the test predictions.
metrics(estimates_keras_tbl, truth = truth, estimate = estimate)
## # A tibble: 2 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy binary 0.788
## 2 kap binary 0.438
# ROC AUC
# Area under the ROC curve, computed from the predicted probability column.
roc_auc(estimates_keras_tbl, truth = truth, class_prob)
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 roc_auc binary 0.837
# Precision and Recall
# Precision and recall for yardstick's event class (the first factor level of
# `truth` under the default event level), stacked as rows of one tidy tibble
# with columns .metric / .estimator / .estimate. Wrapping each metric tibble
# inside tibble() would instead create nested data-frame columns.
bind_rows(
  estimates_keras_tbl %>% precision(truth, estimate),
  estimates_keras_tbl %>% recall(truth, estimate)
)
## # A tibble: 1 x 2
## precision$.metric $.estimator $.estimate recall$.metric $.estimator $.estimate
## <chr> <chr> <dbl> <chr> <chr> <dbl>
## 1 precision binary 0.649 recall binary 0.519
# Set Up the lime::model_type() Method for keras Sequential Models
# S3 method for lime::model_type(): report keras Sequential models as
# classifiers. Always returns the string "classification".
model_type.keras.engine.sequential.Sequential <- function(x, ...) {
  "classification"
}
# Set Up the lime::predict_model() Method for keras Sequential Models
# S3 method for lime::predict_model(): class probabilities from a keras
# Sequential model with a single sigmoid output.
#
# x       -- the fitted keras model.
# newdata -- data frame of predictors, with the same columns as the training
#            predictors.
# type    -- accepted for lime's interface; not used here.
# Returns a data.frame with columns Yes (model output) and No (1 - Yes).
#
# predict() replaces predict_proba(), which was removed from keras Sequential
# models in TensorFlow 2.6. The n x 1 prediction matrix is flattened so that
# each column is a plain numeric vector.
predict_model.keras.engine.sequential.Sequential <- function(x, newdata, type, ...) {
  pred <- as.vector(predict(object = x, x = as.matrix(newdata)))
  data.frame(Yes = pred, No = 1 - pred)
}
# Build a LIME Explainer from the Training Set
# LIME explainer built from the baked training predictors and the fitted keras
# model, with binning of continuous features turned off.
explainer <- lime::lime(
  x_train_tbl,
  model_keras,
  bin_continuous = FALSE
)
# Explain Test-Set Predictions with explain()
# Local explanations for the first 10 test rows: one label per case, the 4
# highest-weighted features each, using a kernel width of 0.5.
explanation <- lime::explain(
  explainer = explainer,
  x = x_test_tbl[1:10, ],
  n_features = 4,
  n_labels = 1,
  kernel_width = 0.5
)
# Per-case bar charts of the LIME feature weights.
explanation %>%
  plot_features() +
  labs(
    title = "LIME Feature Importance Visualization",
    subtitle = "Hold Out (Test) Set, First 10 Cases Shown"
  )

# Heatmap of the same explanations across cases.
explanation %>%
  plot_explanations() +
  labs(
    title = "LIME Feature Importance Heatmap",
    subtitle = "Hold Out (Test) Set, First 10 Cases Shown"
  )

# Correlation Analysis
# Correlation of every baked predictor with the 0/1 churn response (corrr's
# default Pearson method). The result has one row per feature and is ordered
# from weakest to strongest absolute correlation.
corrr_analysis <- x_train_tbl %>%
  mutate(Churn = y_train_vec) %>%
  correlate() %>%
  focus(Churn)

# Older corrr releases name the feature column `rowname`; newer ones (>= 0.4.3)
# call it `term`. Rename whichever is present so both versions work.
names(corrr_analysis)[names(corrr_analysis) %in% c("rowname", "term")] <-
  "feature"

# as_factor() keeps the abs-correlation ordering as the factor level order.
corrr_analysis <- corrr_analysis %>%
  arrange(abs(Churn)) %>%
  mutate(feature = as_factor(feature))
##
## Correlation method: 'pearson'
## Missing treated using: 'pairwise.complete.obs'
# Show the feature/correlation table.
print(corrr_analysis)
## # A tibble: 35 x 2
## feature Churn
## <fct> <dbl>
## 1 gender_Male -0.00539
## 2 MultipleLines_No.phone.service -0.0139
## 3 PhoneService_Yes 0.0139
## 4 tenure_bin3 -0.0162
## 5 MultipleLines_Yes 0.0394
## 6 StreamingTV_Yes 0.0603
## 7 StreamingMovies_Yes 0.0603
## 8 tenure_bin4 -0.0701
## 9 DeviceProtection_Yes -0.0702
## 10 OnlineBackup_Yes -0.0791
## # … with 25 more rows
# Correlation visualization: lollipop chart of each feature's correlation with
# churn, sorted by correlation. Segments run from 0 to the point; positive
# correlations use palette_light() colour 2 and negative ones colour 1. Dashed
# reference lines mark 0 and +/-0.25.
corrr_analysis %>%
  ggplot(aes(x = Churn, y = fct_reorder(feature, desc(Churn)))) +
  geom_point() +
  # Positive Correlations - Contribute to churn
  geom_segment(aes(xend = 0, yend = feature),
               color = palette_light()[[2]],
               data = corrr_analysis %>% filter(Churn > 0)) +
  geom_point(color = palette_light()[[2]],
             data = corrr_analysis %>% filter(Churn > 0)) +
  # Negative Correlations - Prevent churn
  geom_segment(aes(xend = 0, yend = feature),
               color = palette_light()[[1]],
               data = corrr_analysis %>% filter(Churn < 0)) +
  geom_point(color = palette_light()[[1]],
             data = corrr_analysis %>% filter(Churn < 0)) +
  # Vertical reference lines
  geom_vline(xintercept = 0, color = palette_light()[[5]], size = 1, linetype = 2) +
  geom_vline(xintercept = -0.25, color = palette_light()[[5]], size = 1, linetype = 2) +
  geom_vline(xintercept = 0.25, color = palette_light()[[5]], size = 1, linetype = 2) +
  # Aesthetics
  theme_tq() +
  labs(
    title = "Churn Correlation Analysis",
    # Put each phrase on its own line; paste()'s default sep = " " would join
    # them into one run-on sentence.
    subtitle = paste("Positive Correlations Contribute to Churn",
                     "Negative Correlations Prevent Churn",
                     sep = "\n"),
    y = "Feature Importance"
  )
