Introduction

Task
The task of customer churn modeling is to identify customers at high risk of leaving. These predictions can then guide targeted retention programs to minimize customer churn.

What Can It Achieve
A targeted retention program guided by such a model can be much more efficient than a non-targeted program, because it allows us to identify and focus on customers who are at risk of leaving. For example, in this demo, among the top 20 customers predicted to churn, all indeed churned; among the top 10% of customers with the highest predicted risk, 80% indeed churned. A company can cover about 60% of leaving customers by reaching out to just 25% of its customers, which makes the program more than twice as effective as a non-targeted one.

Overview
What follows is the complete workflow of a customer churn model, from Data Cleaning all the way to Model Evaluation. If you would rather check out its implications for the business directly, please jump to Business Impact Analysis.

Demo Data

This demo uses the Telco Customer Churn dataset (7,043 customers, 21 columns), read in below from 00_DemoData/Telco_Customer_Churn.csv.

Techniques Used

The main techniques used in this demo are as follows (the package setup assumed by the code is sketched after the list):

  • Data preprocessing: tidymodels
  • Model training: H2O.ai (AutoML), ensemble methods
  • Model explanation: LIME (local interpretable model-agnostic explanations)
  • Code organization: custom functions for clean code and easy maintenance & updates
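
All the code in this demo assumes a setup chunk, run earlier and not shown here, that loads the required packages. A plausible sketch, with package names inferred from the function calls used throughout:

# assumed setup chunk -- package names inferred from the calls in this demo
library(tidyverse)   # read_csv, dplyr verbs, ggplot2, purrr
library(tidymodels)  # initial_split, recipe, step_* functions
library(naniar)      # miss_var_summary
library(kableExtra)  # kbl, kable_material_dark, kable_classic_2
library(tidyquant)   # theme_tq, scale_color_tq; also attaches PerformanceAnalytics (skewness)
library(h2o)         # h2o.init, h2o.automl and friends
library(lime)        # local model explanations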

Data Cleaning

In this section, we look at the whole dataset for any serious problems before splitting and tucking the test set away for model checking.

Data Fields

There are 21 columns in the dataset.

  • ID field
    • CustomerID
  • Target field
    • Churn
  • Demographic info
    • gender
    • SeniorCitizen
    • Partner
    • Dependents
  • Service used
    • PhoneService
    • MultipleLines
    • InternetService
    • OnlineSecurity
    • OnlineBackup
    • DeviceProtection
    • TechSupport
    • StreamingTV
    • StreamingMovies
  • Customer account info
    • tenure
    • Contract
    • PaperlessBilling
    • PaymentMethod
    • MonthlyCharges
    • TotalCharges
# read data
raw_tbl <- read_csv(here::here("00_DemoData","Telco_Customer_Churn.csv"))

raw_tbl |> 
  colnames()
##  [1] "customerID"       "gender"           "SeniorCitizen"    "Partner"         
##  [5] "Dependents"       "tenure"           "PhoneService"     "MultipleLines"   
##  [9] "InternetService"  "OnlineSecurity"   "OnlineBackup"     "DeviceProtection"
## [13] "TechSupport"      "StreamingTV"      "StreamingMovies"  "Contract"        
## [17] "PaperlessBilling" "PaymentMethod"    "MonthlyCharges"   "TotalCharges"    
## [21] "Churn"

A First Glance

  • Missing
    • Most data is complete
    • The only feature with missing values is TotalCharges, with 11 values (0.16%) missing (i.e., 99.8% complete)
  • Character variables
    • 17 variables were identified as characters
    • customerID should be kept as character; it serves as an ID and won’t participate in the analysis
    • The remaining 16 variables should be changed to factors, each with 2-4 levels
  • Numeric variables
    • 4 variables were identified as numeric
    • SeniorCitizen should be changed to factor (this matters because we will standardize all numerical variables later)
raw_tbl |> 
  visdat::vis_dat()

miss_var_summary(raw_tbl) |> 
  kbl() |> 
  kable_material_dark()
variable n_miss pct_miss
TotalCharges 11 0.1561834
customerID 0 0.0000000
gender 0 0.0000000
SeniorCitizen 0 0.0000000
Partner 0 0.0000000
Dependents 0 0.0000000
tenure 0 0.0000000
PhoneService 0 0.0000000
MultipleLines 0 0.0000000
InternetService 0 0.0000000
OnlineSecurity 0 0.0000000
OnlineBackup 0 0.0000000
DeviceProtection 0 0.0000000
TechSupport 0 0.0000000
StreamingTV 0 0.0000000
StreamingMovies 0 0.0000000
Contract 0 0.0000000
PaperlessBilling 0 0.0000000
PaymentMethod 0 0.0000000
MonthlyCharges 0 0.0000000
Churn 0 0.0000000

Let’s make the changes suggested above and mark the dataset with these minor changes as a silver_tbl.

# get the bronze, silver and gold idea from Databricks' naming convention. :)

silver_tbl <- raw_tbl |> 
  # recode SeniorCitizen, this will change it to character type
  mutate(SeniorCitizen = recode(SeniorCitizen,
                                "0"="No","1"="Yes")) |> 
  # then change all character variables to factor
  mutate_if(is.character, as.factor)  
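
A quick sanity check along these lines (not shown in the original) can confirm the recode took effect:

# SeniorCitizen should now be a factor with levels No/Yes
silver_tbl |> 
  count(SeniorCitizen)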

Histograms

# use custom functions for concise code and easy maintenance/update
source(here::here("00_Functions","plot_hist_facet.R"))

silver_tbl |> 
  select(-customerID) |> 
  plot_hist_facet()

Numeric variables

We can see from the above histograms that the numeric variables need to be transformed (for skewness) and standardized (for consistent scale). These will be done in the next section Data Preprocessing.

Factors

For factor variables, we need to check:

    1. What unique levels/categories does each have?
    2. What proportions of observations fall into each category?
    3. Are there any issues we should pay attention to (e.g., categories to merge, missing values)?
    4. Is the target class imbalanced?

According to the results, our factor variables look OK. We don’t need to change or merge any categories. Two things to note for consideration during training:

    1. For service variables such as OnlineBackup, in addition to “Yes” and “No”, there is an additional category “No internet service”. This is because customers who do not use Internet Service naturally won’t use Online Backup. Such additional categories may offer useful information, so we will keep them as they are. However, we might also consider merging them with “No” to improve results.
    2. For the target variable Churn, 27% of the observations fall into the minority category (“Yes”), which is considered a mild imbalance (a quick check follows below). We can take this into account in our model training and check whether balancing the classes leads to better results (more on imbalanced classification).
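
To quantify the imbalance noted in point 2, the class proportions can be checked directly (a sketch; the original chunk below runs summary() instead):

silver_tbl |> 
  count(Churn) |> 
  mutate(prop = n / sum(n))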
# for technical audience only, run and check the results in console

silver_tbl |> 
  select_if(is.factor) |> 
  summary() 

Data Preprocessing

Initial Split

First of all, let’s split the data and keep the test dataset safely away for a clean test later after training.

The initial_split function (tidymodels) allows us to split the data using stratified sampling, ensuring that the resulting datasets are comparable on the stratification variable. Here, after splitting, the training and test datasets should have nearly equal proportions of churn.

# Split the data into training and test

set.seed(1357)
# set churn as strata in stratified samples
silver_tbl_split <- initial_split(data = silver_tbl, prop = 0.75, strata = Churn)
train_silver_tbl <- training(silver_tbl_split)
test_silver_tbl <- testing(silver_tbl_split)
data size
original 7043
train 5281
test 1762
# Double-check that the stratified sampling worked: the proportions of churn should be consistent across the three datasets
get_churn_pct(silver_tbl) # original
## 
##        No       Yes 
## 0.7346301 0.2653699
get_churn_pct(train_silver_tbl) # train
## 
##        No       Yes 
## 0.7347093 0.2652907
get_churn_pct(test_silver_tbl) # test
## 
##        No       Yes 
## 0.7343927 0.2656073
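
get_churn_pct() is one of the custom helper functions kept in 00_Functions; its definition is not shown in this post, but a minimal sketch consistent with the output above would be:

# proportion of customers in each Churn category
get_churn_pct <- function(data) {
  data$Churn |> 
    table() |> 
    prop.table()
}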

Preprocessing Steps

Set up recipe

Initiate the recipe, set the roles and the dataset that will be used to train the recipe.

gold_recipe <- recipe(Churn ~ ., data = train_silver_tbl) |> # set Target and Features
  # set the ID field
  update_role(customerID, new_role = "ID") |> 
  # exclude features with no variance at all
  step_zv(all_predictors())

gold_recipe
## Recipe
## 
## Inputs:
## 
##       role #variables
##         ID          1
##    outcome          1
##  predictor         19
## 
## Operations:
## 
## Zero variance filter on all_predictors()

Skewness

Identify variables with serious skewness and transform them accordingly (more on the Yeo-Johnson transformation; the formula is given below for reference).
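
For reference, the Yeo-Johnson transformation of a value $y$ with parameter $\lambda$ (estimated from the training data by step_YeoJohnson()) is:

$$
\psi(y, \lambda) =
\begin{cases}
\left[(y+1)^{\lambda} - 1\right] / \lambda & \lambda \neq 0,\; y \ge 0 \\
\log(y+1) & \lambda = 0,\; y \ge 0 \\
-\left[(1-y)^{2-\lambda} - 1\right] / (2-\lambda) & \lambda \neq 2,\; y < 0 \\
-\log(1-y) & \lambda = 2,\; y < 0
\end{cases}
$$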

source(here::here("00_Functions","plot_hist_facet.R"))

# get histograms for all numeric variables
train_silver_tbl |> 
  select_if(is.numeric) |>
  plot_hist_facet(ncol=2)

# identify variables with high skewness
skewed_feature_names <- train_silver_tbl |> 
  select_if(is.numeric) |> 
  map_df(skewness) |> 
  pivot_longer(cols=everything(), values_to="skewness") |> 
  filter(abs(skewness) > 0.8) |> # handle both positive & negative skewness
  pull(name) |>  # pull the names to form a vector for later use
  as.character()

print(glue::glue("variables with abs(skewness) > 0.8 identified and transformed: {skewed_feature_names}")) 
## variables with abs(skewness) > 0.8 identified and transformed: TotalCharges
# transform to adjust skewness in the recipe
gold_recipe <- gold_recipe |> 
  step_YeoJohnson(all_of(skewed_feature_names))

# bake the recipe on training set to check for performance
gold_recipe |> 
  prep() |> 
  bake(train_silver_tbl) |> 
  select(all_of(skewed_feature_names)) |> 
  plot_hist_facet()

Standardize: Center and Scale

The numeric variables in the data have different ranges, so we standardize them (center and scale) to a consistent scale.

train_silver_tbl |> 
  select_if(is.numeric) |> 
  summary()
##      tenure      MonthlyCharges    TotalCharges   
##  Min.   : 0.00   Min.   : 18.25   Min.   :  18.8  
##  1st Qu.: 9.00   1st Qu.: 36.85   1st Qu.: 411.2  
##  Median :29.00   Median : 70.50   Median :1424.5  
##  Mean   :32.44   Mean   : 65.12   Mean   :2295.5  
##  3rd Qu.:56.00   3rd Qu.: 89.95   3rd Qu.:3783.6  
##  Max.   :72.00   Max.   :118.75   Max.   :8672.5  
##                                   NA's   :7
# add standardization steps
gold_recipe <- gold_recipe |> 
  step_center(all_numeric(),-all_outcomes()) |> 
  step_scale(all_numeric(), -all_outcomes())

# check performance on the training data
gold_recipe |> 
  prep() |> 
  bake(train_silver_tbl) |> 
  select_if(is.numeric) |> 
  summary()
##      tenure        MonthlyCharges     TotalCharges     
##  Min.   :-1.3247   Min.   :-1.5657   Min.   :-1.94440  
##  1st Qu.:-0.9572   1st Qu.:-0.9444   1st Qu.:-0.78361  
##  Median :-0.1404   Median : 0.1797   Median : 0.02132  
##  Mean   : 0.0000   Mean   : 0.0000   Mean   : 0.00000  
##  3rd Qu.: 0.9621   3rd Qu.: 0.8294   3rd Qu.: 0.86227  
##  Max.   : 1.6155   Max.   : 1.7914   Max.   : 1.76242  
##                                      NA's   :7

Missing & Imputation

We only have one variable with minor missingness. If there were more serious concerns about missing data, a more thorough investigation with naniar would be warranted (more on k-nearest neighbor imputation).
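
A couple of naniar helpers that such an investigation might start with (a sketch; not run here, since missingness is trivial in this data):

# missingness per variable, visualized
naniar::gg_miss_var(train_silver_tbl)

# whether missing values co-occur across variables (needs 2+ variables with missings)
naniar::gg_miss_upset(train_silver_tbl)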

# identify variables with missing (only one)
missing_feature_names <- miss_var_summary(train_silver_tbl) |>
  filter(n_miss > 0) |> 
  pull(variable) |> 
  as.character()

missing_feature_names
## [1] "TotalCharges"
# check the summary statistics of these variables before applying imputation
gold_recipe |> 
  prep() |> 
  bake(train_silver_tbl) |> 
  select(all_of(missing_feature_names)) |> 
  summary()
##   TotalCharges     
##  Min.   :-1.94440  
##  1st Qu.:-0.78361  
##  Median : 0.02132  
##  Mean   : 0.00000  
##  3rd Qu.: 0.86227  
##  Max.   : 1.76242  
##  NA's   :7
# apply imputations to the recipe
gold_recipe <- gold_recipe |> 
  step_impute_knn(all_predictors(),neighbors = 5)

# check the performance of imputations
gold_recipe |> 
  prep() |> 
  bake(train_silver_tbl) |> 
  select(all_of(missing_feature_names)) |> 
  summary()
##   TotalCharges      
##  Min.   :-1.944398  
##  1st Qu.:-0.783515  
##  Median : 0.021291  
##  Mean   :-0.000201  
##  3rd Qu.: 0.861957  
##  Max.   : 1.762422

Dummy

To include factor variables in the models, they need to be transformed into dummy variables (one 0/1 indicator column per non-reference level).

gold_recipe <- gold_recipe |> 
  step_dummy(all_nominal_predictors()) # careful, using all_nominal() would also dummy customerID

gold_recipe |> 
  prep() |> 
  bake(train_silver_tbl) |> 
  head(10) |> 
  kbl() |> 
  kable_classic_2(full_width=F)
customerID tenure MonthlyCharges TotalCharges Churn gender_Male SeniorCitizen_Yes Partner_Yes Dependents_Yes PhoneService_Yes MultipleLines_No.phone.service MultipleLines_Yes InternetService_Fiber.optic InternetService_No OnlineSecurity_No.internet.service OnlineSecurity_Yes OnlineBackup_No.internet.service OnlineBackup_Yes DeviceProtection_No.internet.service DeviceProtection_Yes TechSupport_No.internet.service TechSupport_Yes StreamingTV_No.internet.service StreamingTV_Yes StreamingMovies_No.internet.service StreamingMovies_Yes Contract_One.year Contract_Two.year PaperlessBilling_Yes PaymentMethod_Credit.card..automatic. PaymentMethod_Electronic.check PaymentMethod_Mailed.check
5575-GNVDE 0.0637314 -0.2729744 0.2432016 No 1 0 0 0 1 0 0 0 0 0 1 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 1
7795-CFOCW 0.5129207 -0.7623491 0.2219888 No 1 0 0 0 0 1 0 0 0 0 1 0 0 0 1 0 1 0 0 0 0 1 0 0 0 0 0
1452-KIOVK -0.4262934 0.8009777 0.2687167 No 1 0 0 1 1 0 1 1 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 1 1 0 0
6713-OKOMC -0.9163181 -1.1815746 -0.9467320 No 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1
6388-TABGU 1.2071224 -0.2996979 0.7838781 No 1 0 0 1 1 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0
9763-GRSKD -0.7938119 -0.5068053 -0.5782664 No 1 0 1 1 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 1
7469-LKBCI -0.6713057 -1.5423423 -0.9061227 No 1 0 0 0 1 0 0 0 1 1 0 1 0 1 0 1 0 1 0 1 0 0 1 0 1 0 0
8091-TTVAX 1.0437808 1.1767774 1.2794175 No 1 0 1 0 1 0 1 1 0 0 0 0 0 0 1 0 0 0 1 0 1 1 0 0 1 0 0
3655-SNQYZ 1.4929702 1.6076944 1.6506188 No 0 0 1 1 1 0 1 1 0 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 0
8191-XWSZG 0.7987685 -1.4855548 -0.2189716 No 0 0 0 0 1 0 0 0 1 1 0 1 0 1 0 1 0 1 0 1 0 1 0 0 0 0 1

Multicollinearity

Check for multicollinearity among the predictors and remove one variable from each highly correlated pair (absolute correlation > 0.9).

gold_recipe <- gold_recipe |> 
  step_corr(all_predictors(),threshold = 0.9) 

# train the recipe
gold_recipe_trained <-gold_recipe |> 
  prep() 

# check what variables would be removed due to multicollinearity
# PCA is avoided here because the model would be hard to interpret; feature transformation can be considered for real data if necessary and suitable
print(glue::glue("Predictors with > 0.9 correlation and will be removed: {gold_recipe_trained$steps[[7]]$removals}"))
## Predictors with > 0.9 correlation and will be removed: PhoneService_Yes
## Predictors with > 0.9 correlation and will be removed: InternetService_No
## Predictors with > 0.9 correlation and will be removed: OnlineSecurity_No.internet.service
## Predictors with > 0.9 correlation and will be removed: OnlineBackup_No.internet.service
## Predictors with > 0.9 correlation and will be removed: DeviceProtection_No.internet.service
## Predictors with > 0.9 correlation and will be removed: TechSupport_No.internet.service
## Predictors with > 0.9 correlation and will be removed: StreamingTV_No.internet.service
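
Hard-coding steps[[7]] is brittle if the recipe changes; an alternative sketch that looks the step up via recipes' tidy() method:

# locate the correlation-filter step, then list the variables it removed
corr_step_number <- gold_recipe_trained |> 
  tidy() |> 
  filter(type == "corr") |> 
  pull(number)

gold_recipe_trained |> 
  tidy(number = corr_step_number) |> 
  pull(terms)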

Preprocess Data

The recipe is trained on the training dataset, then applied to both the training and test data to create the datasets ready for modeling.

train_gold_tbl<- gold_recipe |> 
  prep() |> 
  bake(train_silver_tbl)

test_gold_tbl <- gold_recipe |> 
  prep() |> 
  bake(test_silver_tbl)
# for technical audience only, skim the processed data, run and check the results in console
train_gold_tbl |> 
  skimr::skim()

# notice that all processing (dummying, standardization, imputation, etc.) is applied successfully to the test set too
test_gold_tbl |> 
  skimr::skim()

EDA on preprocessed data

In this section, we take a quick look at all features in light of their relationship with the target (Churn). This is done for two purposes:

    1. Feature quality: get an idea of whether we need to collect more data for model training.
    2. Initial insights: discuss the features with the highest correlations with business stakeholders.

Correlation plot

The features are sorted based on their correlation with the target (Churn). Features at the top are negatively correlated with (i.e., preventing) Churn, while features at the bottom are positively correlated with (i.e., leading to) Churn.

source(here::here("00_Functions","plot_cor.R"))

train_gold_tbl |> 
  select(-customerID) |> 
  plot_cor(target=Churn,fct_reorder = TRUE, fct_rev = TRUE)
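
plot_cor() is a custom function; the core computation behind such a plot, sketched here under the assumption that Churn is recoded to 0/1 for the correlation, is roughly:

# correlation of every feature with the target, sorted from negative to positive
train_gold_tbl |> 
  select(-customerID) |> 
  mutate(Churn = as.numeric(Churn == "Yes")) |> 
  cor() |> 
  as_tibble(rownames = "feature") |> 
  select(feature, cor_with_churn = Churn) |> 
  filter(feature != "Churn") |> 
  arrange(cor_with_churn)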

Distribution & Correlation

source(here::here("00_Functions","plot_ggpairs.R"))

# the variable names come from the silver tbl, because factors in the gold tbl are dummies
# this may lead to errors if any numeric variables were deleted in data preprocessing due to multicollinearity
numeric_variables <- train_silver_tbl |> 
  select_if(is.numeric) |> 
  colnames()
  
train_gold_tbl |> 
  select(all_of(numeric_variables), Churn) |> 
  plot_ggpairs(color=Churn)

Cross Table

We can also make a cross table to explore any potential differences across the churn vs not-churn group over all numeric and categorical features. For each feature, a test result is available for EDA purposes.

library(crosstable)
train_gold_tbl %>% 
  crosstable(train_gold_tbl %>% select(-customerID,-Churn) %>% names(),
             by = Churn,
             percent_digits = 0, percent_pattern = "{n} ({p_col})",
             test = TRUE) %>% 
  as_flextable(show_id = FALSE)

Model Training

Here, h2o.ai is used for model training: 96 models (including ensemble models) were trained within a 30-minute training budget (see the full list here). The top-performing model was checked with diagnostic plots first, then used for prediction.

h2o.init()
##  Connection successful!
## 
## R is connected to the H2O cluster: 
##     H2O cluster uptime:         2 hours 15 minutes 
##     H2O cluster timezone:       Australia/Sydney 
##     H2O data parsing timezone:  UTC 
##     H2O cluster version:        3.34.0.3 
##     H2O cluster version age:    7 months and 20 days !!! 
##     H2O cluster name:           H2O_started_from_R_ningw_qda835 
##     H2O cluster total nodes:    1 
##     H2O cluster total memory:   3.21 GB 
##     H2O cluster total cores:    8 
##     H2O cluster allowed cores:  8 
##     H2O cluster healthy:        TRUE 
##     H2O Connection ip:          localhost 
##     H2O Connection port:        54321 
##     H2O Connection proxy:       NA 
##     H2O Internal Security:      FALSE 
##     H2O API Extensions:         Amazon S3, Algos, AutoML, Core V3, TargetEncoder, Core V4 
##     R Version:                  R version 4.1.2 (2021-11-01)
h2o.no_progress() #not to show progress in RMarkdown results
# to reverse this, use:
# h2o.show_progress()

# Here, in order to keep lime::explain() understandable (e.g., original values before scaling; categorical before dummy), we will use SILVER tables for training
# make sure to double-check the gold_recipe steps in advance to confirm they can all be handled by h2o.automl()


# Get the three dataset ready for h2o: Training, Validation, and Test
h2o_split <- h2o.splitFrame(data=as.h2o(train_silver_tbl), ratio = 0.75, seed = 1234) #silver
train_h2o_tbl <- h2o_split[[1]]
valid_h2o_tbl <- h2o_split[[2]]
test_h2o_tbl <- as.h2o(test_silver_tbl) #silver


# Set target, features and ID
target = "Churn"
ID = "customerID"
features = setdiff(names(train_silver_tbl),c(target,ID))
# train the models
h2o_models <- h2o.automl(x=features,y=target,training_frame=train_h2o_tbl,
                         validation_frame=valid_h2o_tbl,leaderboard_frame=test_h2o_tbl,
                         balance_classes=TRUE,  # for mild imbalanced class
                         nfolds = 5, max_runtime_secs = 1800) 

# check model performance, sorted by AUC
all_models_tbl <- h2o_models@leaderboard |> 
  as_tibble() |> 
  arrange(desc(auc)) 

# save the model list for future reference
write_csv(all_models_tbl, file=here::here("00_Models_auto","all_models_list.csv"))

# ------------ check out the top3 models before saving
source(here::here("00_Functions","h2o_performance_plot.R"))
h2o_models@leaderboard |> 
  plot_h2o_performance(newdata = test_silver_tbl)

# ----------- save the top3 models

# take the top performing models
top3 <- h2o_models@leaderboard |> 
  get_model_name_by_position(1:3) 

saveRDS(top3, file=here::here("00_Models_auto","top3.RData"))

# save the top 3 models
for (i in 1:3) {
  model_name <- h2o_models@leaderboard |> 
                get_model_name_by_position(i)
      
  model <- h2o.getModel(model_name)
  
  h2o.saveModel(model,path=here::here("00_Models_auto"))
}
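
get_model_name_by_position() is another custom helper; a minimal sketch consistent with how it is used here:

# pull the model_id(s) at the given leaderboard position(s)
get_model_name_by_position <- function(leaderboard, position) {
  leaderboard |> 
    as_tibble() |> 
    slice(position) |> 
    pull(model_id)
}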

Model Evaluation

96 models were trained; please see their performance below, sorted by AUC (descending).

Metrics

# load the model list
all_models_tbl <- read_csv(file=here::here("00_Models_auto","all_models_list.csv"))
top3 <- readRDS(file=here::here("00_Models_auto","top3.RData"))

all_models_tbl |> 
  kbl() |> 
  kable_material_dark()
model_id auc logloss aucpr mean_per_class_error rmse mse
StackedEnsemble_BestOfFamily_6_AutoML_1_20220122_200013 0.8429603 0.4186760 0.6860444 0.2487417 0.3678451 0.1353100
StackedEnsemble_BestOfFamily_4_AutoML_1_20220122_200013 0.8429050 0.4184483 0.6874397 0.2499009 0.3675854 0.1351190
DeepLearning_grid_1_AutoML_1_20220122_200013_model_10 0.8426284 0.7858194 0.6785094 0.2469683 0.5179997 0.2683237
DeepLearning_grid_3_AutoML_1_20220122_200013_model_6 0.8424954 0.7665510 0.6795821 0.2365091 0.4976893 0.2476947
GBM_lr_annealing_selection_AutoML_1_20220122_200013_select_model 0.8417540 0.4218677 0.6813518 0.2360781 0.3695331 0.1365547
StackedEnsemble_Best1000_1_AutoML_1_20220122_200013 0.8415162 0.4208542 0.6860456 0.2485370 0.3685001 0.1357923
StackedEnsemble_AllModels_6_AutoML_1_20220122_200013 0.8413602 0.4192623 0.6849275 0.2512880 0.3683571 0.1356870
StackedEnsemble_AllModels_4_AutoML_1_20220122_200013 0.8413255 0.4192549 0.6849544 0.2495145 0.3683081 0.1356509
GBM_grid_1_AutoML_1_20220122_200013_model_31 0.8412793 0.4215656 0.6820301 0.2426716 0.3689393 0.1361162
StackedEnsemble_AllModels_3_AutoML_1_20220122_200013 0.8410456 0.4197545 0.6844148 0.2505152 0.3688374 0.1360410
DeepLearning_grid_3_AutoML_1_20220122_200013_model_4 0.8410151 2.7808195 0.6633470 0.2392139 0.7593069 0.5765470
StackedEnsemble_AllModels_2_AutoML_1_20220122_200013 0.8409300 0.4204980 0.6830169 0.2497407 0.3683970 0.1357164
StackedEnsemble_BestOfFamily_3_AutoML_1_20220122_200013 0.8407988 0.4203761 0.6828417 0.2497407 0.3684030 0.1357208
GBM_grid_1_AutoML_1_20220122_200013_model_12 0.8407971 0.4200798 0.6854729 0.2410798 0.3690265 0.1361805
GBM_grid_1_AutoML_1_20220122_200013_model_22 0.8405990 0.4205572 0.6793044 0.2522424 0.3682675 0.1356209
StackedEnsemble_BestOfFamily_2_AutoML_1_20220122_200013 0.8404660 0.4207655 0.6809188 0.2507183 0.3684894 0.1357844
StackedEnsemble_AllModels_1_AutoML_1_20220122_200013 0.8404536 0.4209126 0.6815473 0.2507183 0.3685602 0.1358366
StackedEnsemble_BestOfFamily_1_AutoML_1_20220122_200013 0.8403752 0.4208745 0.6803333 0.2511047 0.3685890 0.1358578
GLM_1_AutoML_1_20220122_200013 0.8393712 0.4227888 0.6767564 0.2440554 0.3694661 0.1365052
DeepLearning_grid_2_AutoML_1_20220122_200013_model_6 0.8393679 0.6618479 0.6703972 0.2457628 0.4813143 0.2316635
GBM_grid_1_AutoML_1_20220122_200013_model_32 0.8391665 0.4249049 0.6833515 0.2527428 0.3703698 0.1371738
GBM_grid_1_AutoML_1_20220122_200013_model_36 0.8390013 0.4238496 0.6792855 0.2410996 0.3698086 0.1367584
DeepLearning_grid_3_AutoML_1_20220122_200013_model_3 0.8388230 1.3903065 0.6698004 0.2433965 0.6723960 0.4521164
DeepLearning_grid_2_AutoML_1_20220122_200013_model_10 0.8388131 0.5510382 0.6698477 0.2410319 0.4363939 0.1904397
GBM_grid_1_AutoML_1_20220122_200013_model_10 0.8385976 0.4229516 0.6805662 0.2514911 0.3699705 0.1368782
DeepLearning_grid_2_AutoML_1_20220122_200013_model_7 0.8382640 2.9501760 0.6606921 0.2427608 0.7659212 0.5866354
GBM_grid_1_AutoML_1_20220122_200013_model_26 0.8382219 0.4227443 0.6814261 0.2418741 0.3696969 0.1366758
GBM_grid_1_AutoML_1_20220122_200013_model_28 0.8381278 0.4260265 0.6750600 0.2472622 0.3712230 0.1378065
DeepLearning_grid_3_AutoML_1_20220122_200013_model_5 0.8381097 0.9286804 0.6685972 0.2449207 0.5492471 0.3016724
GBM_grid_1_AutoML_1_20220122_200013_model_40 0.8380882 0.4236667 0.6794013 0.2484461 0.3702305 0.1370706
GBM_grid_1_AutoML_1_20220122_200013_model_2 0.8378405 0.4261121 0.6694961 0.2467387 0.3712516 0.1378277
DeepLearning_grid_2_AutoML_1_20220122_200013_model_4 0.8377530 1.5358004 0.6719779 0.2430118 0.6548181 0.4287868
GBM_grid_1_AutoML_1_20220122_200013_model_20 0.8368084 0.4250418 0.6753893 0.2503765 0.3705319 0.1372939
GBM_grid_1_AutoML_1_20220122_200013_model_15 0.8365921 0.4256363 0.6764487 0.2544700 0.3706331 0.1373689
GBM_grid_1_AutoML_1_20220122_200013_model_18 0.8365327 0.4271538 0.6718058 0.2519006 0.3719059 0.1383140
DeepLearning_grid_1_AutoML_1_20220122_200013_model_13 0.8355130 0.4371138 0.6746673 0.2441016 0.3741200 0.1399658
StackedEnsemble_BestOfFamily_5_AutoML_1_20220122_200013 0.8354280 0.4248692 0.6735632 0.2503996 0.3706099 0.1373517
DeepLearning_grid_2_AutoML_1_20220122_200013_model_2 0.8353083 3.7995294 0.6225588 0.2381917 0.8028009 0.6444893
GBM_1_AutoML_1_20220122_200013 0.8340062 0.4285271 0.6680368 0.2447836 0.3718222 0.1382517
GBM_grid_1_AutoML_1_20220122_200013_model_41 0.8338692 0.4312231 0.6701641 0.2464184 0.3739051 0.1398050
DeepLearning_grid_1_AutoML_1_20220122_200013_model_5 0.8338106 1.5570297 0.6580370 0.2528782 0.5751254 0.3307693
StackedEnsemble_AllModels_5_AutoML_1_20220122_200013 0.8333953 0.4284951 0.6590979 0.2577428 0.3730240 0.1391469
GBM_grid_1_AutoML_1_20220122_200013_model_42 0.8332136 0.4412795 0.6758361 0.2395557 0.3780005 0.1428844
DeepLearning_grid_1_AutoML_1_20220122_200013_model_11 0.8331352 0.6551105 0.6674346 0.2549687 0.4698987 0.2208047
GBM_grid_1_AutoML_1_20220122_200013_model_11 0.8330493 0.4346830 0.6671945 0.2498283 0.3740896 0.1399430
DeepLearning_grid_1_AutoML_1_20220122_200013_model_12 0.8326604 3.0211977 0.6321095 0.2439894 0.7589914 0.5760679
GBM_grid_1_AutoML_1_20220122_200013_model_5 0.8326481 0.4324501 0.6715337 0.2527180 0.3743213 0.1401164
GBM_grid_1_AutoML_1_20220122_200013_model_33 0.8325448 0.4313642 0.6675994 0.2483520 0.3723110 0.1386154
GBM_grid_1_AutoML_1_20220122_200013_model_39 0.8318100 0.4326108 0.6657893 0.2421696 0.3727869 0.1389701
DeepLearning_grid_2_AutoML_1_20220122_200013_model_5 0.8313146 4.7937817 0.6382740 0.2512186 0.8471850 0.7177224
GBM_grid_1_AutoML_1_20220122_200013_model_24 0.8308383 0.4357828 0.6652116 0.2601520 0.3758469 0.1412609
GBM_grid_1_AutoML_1_20220122_200013_model_21 0.8305914 0.4366144 0.6622726 0.2530565 0.3747984 0.1404738
GBM_grid_1_AutoML_1_20220122_200013_model_25 0.8301984 0.4344138 0.6627115 0.2529905 0.3741087 0.1399573
DeepLearning_grid_2_AutoML_1_20220122_200013_model_9 0.8299656 0.7057789 0.6328349 0.2545146 0.4889228 0.2390455
GBM_grid_1_AutoML_1_20220122_200013_model_9 0.8287841 0.4381046 0.6655288 0.2552164 0.3754303 0.1409479
GBM_5_AutoML_1_20220122_200013 0.8285851 0.4386119 0.6676289 0.2632218 0.3755338 0.1410257
GBM_grid_1_AutoML_1_20220122_200013_model_35 0.8284406 0.4359689 0.6655123 0.2556721 0.3748381 0.1405036
DeepLearning_1_AutoML_1_20220122_200013 0.8275018 1.2476622 0.6470500 0.2677925 0.6213484 0.3860738
GBM_grid_1_AutoML_1_20220122_200013_model_17 0.8264401 0.4530946 0.6537547 0.2395309 0.3819990 0.1459232
GBM_grid_1_AutoML_1_20220122_200013_model_1 0.8262692 0.4436528 0.6637998 0.2584446 0.3778539 0.1427736
GBM_grid_1_AutoML_1_20220122_200013_model_3 0.8249985 0.4468257 0.6619466 0.2547375 0.3780430 0.1429165
DeepLearning_grid_2_AutoML_1_20220122_200013_model_1 0.8246633 1.3766840 0.6332877 0.2539466 0.5874531 0.3451012
GBM_2_AutoML_1_20220122_200013 0.8245477 0.4441974 0.6630663 0.2604014 0.3771702 0.1422573
GBM_4_AutoML_1_20220122_200013 0.8244164 0.4487594 0.6564965 0.2639004 0.3791206 0.1437324
GBM_grid_1_AutoML_1_20220122_200013_model_29 0.8242967 0.4485780 0.6609483 0.2580582 0.3789506 0.1436036
GBM_3_AutoML_1_20220122_200013 0.8237881 0.4471456 0.6544126 0.2593528 0.3780531 0.1429241
GBM_grid_1_AutoML_1_20220122_200013_model_34 0.8235380 0.4488503 0.6482430 0.2635405 0.3782595 0.1430802
XRT_1_AutoML_1_20220122_200013 0.8234017 0.4499874 0.6476233 0.2572177 0.3822107 0.1460850
GBM_grid_1_AutoML_1_20220122_200013_model_19 0.8223870 0.4512313 0.6373500 0.2535569 0.3824016 0.1462310
GBM_grid_1_AutoML_1_20220122_200013_model_6 0.8215936 0.4451277 0.6542411 0.2660636 0.3771840 0.1422678
GBM_grid_1_AutoML_1_20220122_200013_model_30 0.8212914 0.4508938 0.6472986 0.2595114 0.3803616 0.1446750
DeepLearning_grid_2_AutoML_1_20220122_200013_model_11 0.8207151 3.4537916 0.6015543 0.2540143 0.7656181 0.5861711
GBM_grid_1_AutoML_1_20220122_200013_model_16 0.8185767 0.4876131 0.6221230 0.2600100 0.3956478 0.1565372
DeepLearning_grid_1_AutoML_1_20220122_200013_model_9 0.8184809 3.3531349 0.5776392 0.2575381 0.7432531 0.5524252
GBM_grid_1_AutoML_1_20220122_200013_model_13 0.8177478 0.4557151 0.6298769 0.2707912 0.3841826 0.1475963
GBM_grid_1_AutoML_1_20220122_200013_model_37 0.8173655 0.4690020 0.6341797 0.2671518 0.3868436 0.1496480
GBM_grid_1_AutoML_1_20220122_200013_model_14 0.8162930 0.4539915 0.6410161 0.2699689 0.3815077 0.1455481
DRF_1_AutoML_1_20220122_200013 0.8154607 0.5044749 0.6323613 0.2674226 0.3876678 0.1502863
DeepLearning_grid_1_AutoML_1_20220122_200013_model_4 0.8154459 0.5299915 0.6124203 0.2650415 0.3913375 0.1531450
GBM_grid_1_AutoML_1_20220122_200013_model_23 0.8124777 0.4658437 0.6502248 0.2669718 0.3835722 0.1471276
GBM_grid_1_AutoML_1_20220122_200013_model_7 0.8077047 0.4908590 0.6258822 0.2658324 0.3952258 0.1562034
DeepLearning_grid_1_AutoML_1_20220122_200013_model_2 0.8027145 0.9314588 0.6095837 0.2795182 0.4982648 0.2482679
GBM_grid_1_AutoML_1_20220122_200013_model_38 0.7985921 0.5026499 0.6012272 0.2830173 0.3984007 0.1587231
GBM_grid_1_AutoML_1_20220122_200013_model_4 0.7946943 0.5013130 0.6091147 0.2782897 0.3969791 0.1575924
DeepLearning_grid_2_AutoML_1_20220122_200013_model_3 0.7895390 3.4013958 0.5965377 0.2778620 0.5151853 0.2654159
GBM_grid_1_AutoML_1_20220122_200013_model_27 0.7879579 0.5209375 0.6018495 0.2914999 0.4030699 0.1624654
DeepLearning_grid_1_AutoML_1_20220122_200013_model_7 0.7878687 2.7998577 0.5373631 0.2766516 0.6072039 0.3686966
DeepLearning_grid_1_AutoML_1_20220122_200013_model_3 0.7874658 2.4600664 0.5309107 0.2815146 0.5832148 0.3401395
GBM_grid_1_AutoML_1_20220122_200013_model_43 0.7864594 0.5589513 0.5357356 0.2829430 0.4329170 0.1874171
DeepLearning_grid_1_AutoML_1_20220122_200013_model_6 0.7843606 3.1758276 0.5519104 0.2665128 0.5104468 0.2605560
DeepLearning_grid_1_AutoML_1_20220122_200013_model_1 0.7830875 2.4934803 0.5505257 0.2912918 0.5143025 0.2645071
DeepLearning_grid_1_AutoML_1_20220122_200013_model_8 0.7671931 7.7873074 0.4526219 0.2585767 0.7439776 0.5535026
DeepLearning_grid_2_AutoML_1_20220122_200013_model_8 0.7616877 4.6353891 0.5300179 0.2752431 0.5133562 0.2635346
DeepLearning_grid_3_AutoML_1_20220122_200013_model_2 0.7503071 5.1601774 0.4794421 0.2961301 0.6206716 0.3852332
GBM_grid_1_AutoML_1_20220122_200013_model_8 0.7193317 0.6642304 0.4777992 0.2939041 0.4442664 0.1973727
DeepLearning_grid_3_AutoML_1_20220122_200013_model_1 0.6652829 3.9819507 0.3566597 0.3437331 0.8526132 0.7269493

Diagnostic plots

Let’s check the top model first.

Confusion Matrix

# get the top model
top_model <- h2o.loadModel(path=here::here("00_Models_auto",top3[[1]]))

# check its performance on the test data
top_model_perf <- h2o.performance(top_model,newdata=test_h2o_tbl)

# check a particular metric
h2o.confusionMatrix(top_model_perf) 
## Confusion Matrix (vertical: actual; across: predicted)  for max f1 @ threshold = 0.292101226561922:
##          No Yes    Error       Rate
## No     1035 259 0.200155  =259/1294
## Yes      97 371 0.207265    =97/468
## Totals 1132 630 0.202043  =356/1762

Precision & Recall

The following graph shows how precision and recall change with threshold. There is a tradeoff between the two. The threshold that achieves the highest F1 is 0.29 (marked by the dashed line).

top_model_metrics <- top_model_perf |> 
  h2o.metric() |> 
  as_tibble()

top_model_metrics |> 
  ggplot(aes(x=threshold)) +
  geom_line(aes(y=precision), color="blue",size=1) +
  geom_line(aes(y=recall), color="red3",size=1) +
  geom_vline(aes(xintercept = h2o.find_threshold_by_max_metric(top_model_perf,metric="f1")), color="grey30",size=1.1, linetype="dashed") +
  theme_tq() +
  labs(y="", title="Precision(blue) and Recall(red)", 
       subtitle = "Threshold for maximum F1 marked by the dashed line")

top_model_metrics |> 
  ggplot(aes(x=precision, y= recall)) +
  geom_line(alpha=0.7, size=1) +
  theme_tq()+
  scale_color_tq() +
  labs(
    title = "Precision vs Recall"
  )

ROC Curve

top_model_metrics |> 
  ggplot(aes(x=fpr, y=tpr)) +
  geom_line() +
  theme_tq() +
  scale_color_tq() +
  labs(
    title = "ROC Curve"
  )
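
For a single-number summary, the area under this curve can be read off the performance object directly:

# AUC of the top model on the test data
h2o.auc(top_model_perf)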

test_predictions <- top_model |> 
  h2o.predict(newdata = test_h2o_tbl) |> 
  as_tibble()

# combine predictions with the actual Churn outcome (rows are in the same order)
test_predicted <- test_predictions |> 
  bind_cols(test_gold_tbl |> select(Churn, customerID)) 

Business Impact Analysis

The Gain Tables

The following table shows the churn situation in the test data. Among all 1762 customers in this dataset, 26.6% left the company. That is to say, if we conducted a retention program at random, for every 4 customers targeted, we would reach only about 1 customer who is actually likely to churn, which is probably not very cost-effective.

Total churn Total customers % churn
468 1762 26.6


The following table shows the 20 customers the model predicted to be most likely to leave. Without the guidance of this model, a guess based on the overall 26.6% churn rate would expect about 5 out of every 20 customers to leave. Among these 20 customers identified by the model, however, 20 out of 20 indeed left (note: normally the model makes some mistakes; we are just very lucky this time).

Main takeaway: a focused customer retention program guided by such a model is much more likely to reach customers who are at high risk of churn.

Predicted probability of churn Churn customerID
0.8659771 Yes 5419-JPRRN
0.8470649 Yes 1400-MMYXY
0.8433873 Yes 1875-QIVME
0.8395100 Yes 7181-BQYBV
0.8360232 Yes 4822-RVYBB
0.8276673 Yes 5299-SJCZT
0.8274601 Yes 5192-EBGOV
0.8254721 Yes 0781-LKXBR
0.8249862 Yes 3389-YGYAI
0.8206488 Yes 8098-LLAZX
0.8196879 Yes 9305-CDSKC
0.8183148 Yes 1069-XAIEM
0.8165900 Yes 4587-VVTOX
0.8154116 Yes 0655-RBDUG
0.8106557 Yes 7274-RTAPZ
0.8097051 Yes 8603-IJWDN
0.8053700 Yes 6108-OQZDQ
0.8046836 Yes 3811-VBYBZ
0.8037927 Yes 2968-SSGAA
0.8023322 Yes 4078-SAYYN


If we keep going down the list until we have covered the top 10% of customers in the test dataset, the model achieves 80% precision among these 176 customers; that is, 80% of them indeed left the company.

Total Churned % Churned
176 141 80%

The Gain Plot

As shown in the following graph (the calculation behind it is sketched after this list), guided by the model,

  • by reaching out to 25% of the customers, a retention program will cover about 60% of all churning customers;
  • by reaching out to 50% of the customers, about 90% of all churning customers will be covered.
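
The gain curve is computed from the ranked predictions. A sketch of the underlying calculation, assuming (per h2o's convention for binomial models) that the predicted churn probability sits in the Yes column of test_predicted:

# cumulative share of churners covered as we move down the ranked list
gain_tbl <- test_predicted |> 
  arrange(desc(Yes)) |> 
  mutate(
    pct_customers_reached = row_number() / n(),
    pct_churners_covered  = cumsum(Churn == "Yes") / sum(Churn == "Yes")
  )

gain_tbl |> 
  ggplot(aes(x = pct_customers_reached, y = pct_churners_covered)) +
  geom_line(size = 1) +
  geom_abline(linetype = "dashed") + # baseline: a non-targeted program
  theme_tq() +
  labs(title = "Gain Plot",
       x = "% of customers reached", y = "% of churning customers covered")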

Prediction Explained

Machine learning models can sometimes feel like a black box. However, if we can give end users a clear explanation of why a certain customer is predicted to churn, it helps in at least two ways:

  • It helps build trust in the prediction, because end users can then validate or help improve the model.
  • It helps end users decide how best to act on the prediction. For example, if the type of contract has an impact, would I be able to discuss with the customer and negotiate a different type of contract?

Please see the following graph for an example in which the prediction for a customer (ID: 5419-JPRRN) is explained by its 10 most influential predictors. The features/predictors are ranked by weight, with the feature at the top being the most influential. Features that support the prediction are colored dark blue, while features that contradict the prediction are colored red (more on the analysis behind the explanation).
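
The explanation plot itself comes from lime (which supports h2o models out of the box); a sketch of the call sequence, with column selections and arguments as assumptions since the original chunk is not shown:

# build an explainer on the training features (original, pre-dummy values keep it readable)
explainer <- lime::lime(
  train_silver_tbl |> select(-customerID, -Churn),
  model = top_model
)

# explain one customer's prediction with its 10 most influential features
explanation <- lime::explain(
  test_silver_tbl |> filter(customerID == "5419-JPRRN") |> select(-customerID, -Churn),
  explainer,
  n_labels   = 1,
  n_features = 10
)

lime::plot_features(explanation)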