Predicting Customer Churn

Introduction

The aim of this project is to build a model that can predict, with a high degree of accuracy, whether a customer will churn. We will also explore which factors increase the probability of a customer churning.

We have customer data from a telco company, including whether each customer churned. We will use this data to build multiple classification models and select the one that best predicts whether a customer churned.

Loading the required packages and importing the data

knitr::opts_chunk$set(cache = TRUE)
suppressPackageStartupMessages(library(tidyverse))
suppressPackageStartupMessages(library(tidymodels))
suppressPackageStartupMessages(library(Metrics))
suppressPackageStartupMessages(library(ggpubr))
suppressPackageStartupMessages(library(janitor))
suppressPackageStartupMessages(library(vip))
theme_set(theme_pubr()) # Setting a global theme for all plots
customer_churn <- read_csv("data/customer churn.csv") |> 
  # clean column names
  janitor::clean_names()
Rows: 1428 Columns: 14
── Column specification ────────────────────────────────────────────────────────
Delimiter: ","
dbl (14): Call  Failure, Complains, Subscription  Length, Charge  Amount, Se...

ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
glimpse(customer_churn)
Rows: 1,428
Columns: 14
$ call_failure            <dbl> 8, 0, 10, 10, 3, 11, 4, 13, 7, 7, 6, 9, 25, 4,…
$ complains               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ subscription_length     <dbl> 38, 39, 37, 38, 38, 38, 38, 37, 38, 38, 38, 38…
$ charge_amount           <dbl> 0, 0, 0, 0, 0, 1, 0, 2, 0, 1, 0, 0, 3, 1, 0, 1…
$ seconds_of_use          <dbl> 4370, 318, 2453, 4198, 2393, 3775, 2360, 9115,…
$ frequency_of_use        <dbl> 71, 5, 60, 66, 58, 82, 39, 121, 169, 83, 95, 5…
$ frequency_of_sms        <dbl> 5, 7, 359, 1, 2, 32, 285, 144, 0, 2, 7, 8, 54,…
$ distinct_called_numbers <dbl> 17, 4, 24, 35, 33, 28, 18, 43, 44, 25, 12, 17,…
$ age_group               <dbl> 3, 2, 3, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3…
$ tariff_plan             <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…
$ status                  <dbl> 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1…
$ age                     <dbl> 30, 25, 30, 15, 15, 30, 30, 30, 30, 30, 30, 30…
$ customer_value          <dbl> 197.640, 46.035, 1536.520, 240.020, 145.805, 2…
$ churn                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…

Data Dictionary

Column Name               Description
call_failure              number of call failures
complains                 binary (0: no complaint, 1: complaint)
subscription_length       total months of subscription
charge_amount             ordinal attribute (0: lowest amount, 9: highest amount)
seconds_of_use            total seconds of calls
frequency_of_use          total number of calls
frequency_of_sms          total number of text messages
distinct_called_numbers   total number of distinct phone numbers called
age_group                 ordinal attribute (1: younger age, 5: older age)
tariff_plan               binary (1: Pay as you go, 2: contractual)
status                    binary (1: active, 2: non-active)
age                       age of the customer
customer_value            the calculated value of the customer
churn                     class label (1: churn, 0: non-churn)

Exploratory Data Analysis

In this dataset, our variable of interest is churn. In this section we will explore the relationship between this variable and the other features in the dataset.

First, let's convert the churn variable from numeric to a more descriptive factor.

customer_churn <- customer_churn |> 
  mutate(
    churn = factor(ifelse(churn == 1, "churn", "non-churn"))
  )

glimpse(customer_churn)
Rows: 1,428
Columns: 14
$ call_failure            <dbl> 8, 0, 10, 10, 3, 11, 4, 13, 7, 7, 6, 9, 25, 4,…
$ complains               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ subscription_length     <dbl> 38, 39, 37, 38, 38, 38, 38, 37, 38, 38, 38, 38…
$ charge_amount           <dbl> 0, 0, 0, 0, 0, 1, 0, 2, 0, 1, 0, 0, 3, 1, 0, 1…
$ seconds_of_use          <dbl> 4370, 318, 2453, 4198, 2393, 3775, 2360, 9115,…
$ frequency_of_use        <dbl> 71, 5, 60, 66, 58, 82, 39, 121, 169, 83, 95, 5…
$ frequency_of_sms        <dbl> 5, 7, 359, 1, 2, 32, 285, 144, 0, 2, 7, 8, 54,…
$ distinct_called_numbers <dbl> 17, 4, 24, 35, 33, 28, 18, 43, 44, 25, 12, 17,…
$ age_group               <dbl> 3, 2, 3, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3…
$ tariff_plan             <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…
$ status                  <dbl> 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1…
$ age                     <dbl> 30, 25, 30, 15, 15, 30, 30, 30, 30, 30, 30, 30…
$ customer_value          <dbl> 197.640, 46.035, 1536.520, 240.020, 145.805, 2…
$ churn                   <fct> non-churn, non-churn, non-churn, non-churn, no…
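Because factor() sorts levels alphabetically, "churn" becomes the first level. By default, yardstick treats the first level as the event of interest, which is why the later code passes .pred_churn to roc_curve(). The base-R check below uses a few hypothetical 0/1 codes to confirm the level order.

```r
# A few hypothetical churn codes (1 = churn, 0 = non-churn)
churn_codes <- c(0, 1, 0, 0, 1)

# Same conversion as above: factor() sorts the labels alphabetically
churn <- factor(ifelse(churn_codes == 1, "churn", "non-churn"))

# "churn" comes first, so it is the event level by default
levels(churn)
```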

Customer Churn and Subscription Length

Subscription length is the number of months the customer has been with the company.

Customers who have been with the company for a short period of time are more likely to churn than customers who have been with the company for a long time.

The graph also shows that customers who have been with the company for less than 1 year are more likely to churn.

Why is the company losing new customers?

customer_churn |> 
  ggplot(aes(x = subscription_length, color = churn, fill = churn)) +
  geom_density(alpha = 0.7) +
  fill_palette("jco") +
  color_palette("jco")  +
  labs(
    title = "Churn vs Customer Subscription Length",
    x = "Customer Subscription Length in Months",
    y = "Density"
  ) +
  theme_minimal() +
  theme(
    legend.title = element_blank(),
    legend.position = "top"
  )

Customer Value and Customer Churn

Customer value is the calculated value of the customer. Customers with a low value churn more than customers with a high calculated value.

customer_churn |> 
  ggplot(aes(x = customer_value, color = churn, fill = churn))+
  geom_density(alpha = 0.7) +
  fill_palette("jco") +
  color_palette("jco")  +
  labs(
    x = "Customer Value",
    y = "Density"
  ) +
  theme_minimal() +
  theme(
    legend.title = element_blank(),
    legend.position = "top"
  )

Churn & Customer Age

Most of the telco’s customers are in their 30s. There is not much we can say about the relationship between age and churn, other than that younger customers tend to churn more.

customer_churn |> 
  ggplot(aes(x = age, color = churn, fill = churn))+
  geom_histogram(alpha = 0.7, binwidth = 15, position = "dodge") +
  fill_palette("jco") +
  color_palette("jco")  +
  labs(
    x = "Customer age",
    y = "Count"
  ) +
  theme_minimal() +
  theme(
    legend.title = element_blank(),
    legend.position = "top"
  )

Frequency of Use, Seconds of Use & Customer Churn

Customers who make fewer calls are more likely to churn than customers who make calls more frequently.

customer_churn |> 
  ggplot(aes(x = frequency_of_use, y = seconds_of_use, color = churn, shape = churn)) +
  geom_point(alpha = 0.7, size = 2) +
  color_palette("jco") +
  labs(
    title = "Frequency of Use, Seconds of Use & Customer Churn",
    x = "Total Number of Calls",
    y = "Total Seconds of Calls"
  ) +
  theme_minimal() +
  theme(
    legend.title = element_blank(),
    legend.position = "top"
  )

Model Training

We will build several classification models and select the one that best predicts which customers are likely to churn. We would also like to understand which variables are likely to influence customer churn.

The first step is preparing our data for modeling. Variables like age_group, charge_amount, tariff_plan and status are coded as numeric even though they are categorical. We need to recode these variables properly.

customer_churn <- customer_churn |> 
  mutate(
    # age group as an ordered factor (1: younger, 5: older)
    age_group = factor(age_group,
                       ordered = TRUE, 
                       levels = c("1", "2", "3", "4", "5")),
    # charge amount as an ordered factor (0: lowest, 9: highest)
    charge_amount = factor(charge_amount, ordered = TRUE, 
                           levels = c("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")),
    # tariff plan codes to descriptive labels
    tariff_plan = case_when(
      tariff_plan == 1 ~ "Pay As You Go",
      tariff_plan == 2 ~ "Contractual"
    ),
    # status codes to descriptive labels
    status = ifelse(status == 1, "active", "non active")
  )
customer_churn
# A tibble: 1,428 × 14
   call_failure complains subscription_length charge_amount seconds_of_use
          <dbl>     <dbl>               <dbl> <ord>                  <dbl>
 1            8         0                  38 0                       4370
 2            0         0                  39 0                        318
 3           10         0                  37 0                       2453
 4           10         0                  38 0                       4198
 5            3         0                  38 0                       2393
 6           11         0                  38 1                       3775
 7            4         0                  38 0                       2360
 8           13         0                  37 2                       9115
 9            7         0                  38 0                      13773
10            7         0                  38 1                       4515
# ℹ 1,418 more rows
# ℹ 9 more variables: frequency_of_use <dbl>, frequency_of_sms <dbl>,
#   distinct_called_numbers <dbl>, age_group <ord>, tariff_plan <chr>,
#   status <chr>, age <dbl>, customer_value <dbl>, churn <fct>
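age_group and charge_amount are now ordered factors. As we understand the recipes documentation, the dummy-variable step in the logistic regression recipe will therefore encode them with polynomial contrasts, which are R's default for ordered factors, rather than with one indicator column per level. The base-R sketch below applies this encoding to hypothetical age_group values and shows the columns that model.matrix() produces for a five-level ordered factor. The column names that recipes generates may differ.

```r
# Hypothetical age_group codes, encoded exactly as in the mutate() call above
age_group <- factor(c(1, 2, 3, 4, 5, 3), ordered = TRUE,
                    levels = c("1", "2", "3", "4", "5"))

# model.matrix() uses contr.poly for ordered factors by default:
# linear (.L), quadratic (.Q), cubic (.C) and 4th-degree (^4) trend columns
mm <- model.matrix(~ age_group)
colnames(mm)
```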

Data Splitting

We will split our data into training and testing sets. We will further split the training data into cross-validation folds for model selection.

There is a class imbalance in our dataset: there are many more non-churn customers than churned customers. Churned customers make up only 16% of the data.

customer_churn |> 
  count(churn) |> 
  mutate(
    prop = scales::percent(n / sum(n))
  )
# A tibble: 2 × 3
  churn         n prop 
  <fct>     <int> <chr>
1 churn       233 16%  
2 non-churn  1195 84%  

To ensure the churn and non-churn proportions are the same in both the training and testing sets, we will use stratified sampling.

We will reserve 20% of the data for model testing.

set.seed(123) # For reproducibility
customer_split <- initial_split(na.omit(customer_churn), strata = churn, prop = 0.8)

customer_train <- training(customer_split)
customer_test <- testing(customer_split)

Both the training and testing sets have the same 16% churn proportion.

customer_train |> 
  count(churn) |> 
  mutate(
    prop = scales::percent(n / sum(n))
  )
# A tibble: 2 × 3
  churn         n prop 
  <fct>     <int> <chr>
1 churn       186 16%  
2 non-churn   955 84%  
customer_test |> 
  count(churn) |> 
  mutate(
    prop = scales::percent(n / sum(n))
  )
# A tibble: 2 × 3
  churn         n prop 
  <fct>     <int> <chr>
1 churn        47 16%  
2 non-churn   239 84%  
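Here is a simplified base-R sketch of what stratified sampling does: it draws the 80% training share separately within each class. This is not rsample's exact implementation. Still, with the class counts implied by the output above, it produces the same 186/955 training counts and 47/239 testing counts.

```r
set.seed(123)

# Class counts after na.omit(), taken from the split output above:
# 186 + 47 = 233 churn, 955 + 239 = 1194 non-churn
churn <- factor(rep(c("churn", "non-churn"), times = c(233, 1194)))

# Sample 80% of the rows within each class so both sets keep the class mix
rows_by_class <- split(seq_along(churn), churn)
train_rows <- unlist(lapply(rows_by_class, function(idx) {
  sample(idx, size = floor(0.8 * length(idx)))
}))
test_rows <- setdiff(seq_along(churn), train_rows)

table(churn[train_rows])
table(churn[test_rows])
```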

We then create cross-validation folds (10 by default) for model evaluation and selection.

set.seed(234)
customer_cv <- vfold_cv(customer_train, strata = churn)

customer_cv
#  10-fold cross-validation using stratification 
# A tibble: 10 × 2
   splits             id    
   <list>             <chr> 
 1 <split [1026/115]> Fold01
 2 <split [1026/115]> Fold02
 3 <split [1026/115]> Fold03
 4 <split [1026/115]> Fold04
 5 <split [1026/115]> Fold05
 6 <split [1027/114]> Fold06
 7 <split [1028/113]> Fold07
 8 <split [1028/113]> Fold08
 9 <split [1028/113]> Fold09
10 <split [1028/113]> Fold10
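The sketch below shows stratified fold assignment in base R. It assumes a simple scheme that is not necessarily rsample's: shuffle the rows of each class, then deal them out to the 10 folds in turn. With 186 churn and 955 non-churn training rows, this scheme gives assessment sets of 115, 114 and 113 rows, which match the split sizes above.

```r
set.seed(234)

# Training-set class counts from the output above
churn <- factor(rep(c("churn", "non-churn"), times = c(186, 955)))
v <- 10
fold_id <- integer(length(churn))

for (lvl in levels(churn)) {
  idx <- which(churn == lvl)
  # Shuffle this class's rows, then assign them to folds 1..v round-robin
  fold_id[sample(idx)] <- rep_len(seq_len(v), length(idx))
}

# Assessment-set size per fold (analysis size = 1141 minus this)
table(fold_id)
# Churners per fold: 18 or 19, so each fold keeps roughly the 16% churn rate
table(churn, fold_id)["churn", ]
```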

Logistic Regression Model

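The first model to train is a logistic regression model with a lasso penalty (mixture = 1), fitted with the glmnet engine. The amount of regularization (penalty) will be tuned.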

logistic_model <- logistic_reg(penalty = tune(), mixture = 1) |> 
  set_engine("glmnet")

logistic_model
Logistic Regression Model Specification (classification)

Main Arguments:
  penalty = tune()
  mixture = 1

Computational engine: glmnet 
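For reference, the glmnet documentation gives the objective below, which glmnet minimizes to fit penalized logistic regression. Here y_i is 0 or 1, λ corresponds to penalty and α corresponds to mixture. With mixture = 1 the ridge (squared L2) term vanishes and only the lasso (L1) penalty remains. The lasso penalty can shrink some coefficients exactly to zero.

```latex
\min_{\beta_0,\,\beta}\; -\frac{1}{N}\sum_{i=1}^{N}\Big[\,y_i\,(\beta_0 + x_i^\top\beta) - \log\!\big(1 + e^{\beta_0 + x_i^\top\beta}\big)\Big]
\;+\; \lambda\Big[\tfrac{1-\alpha}{2}\,\lVert\beta\rVert_2^2 + \alpha\,\lVert\beta\rVert_1\Big]
```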

We need to define several preprocessing steps for our logistic regression model. The glmnet engine requires numeric inputs, so character/factor variables must be converted into dummy variables. The numeric variables also need to be normalized so that the penalty treats all coefficients on the same scale. We will also remove variables that have zero variance, i.e. variables that contain only a single value.

logistic_recipe <- 
  recipe(churn ~ ., data = customer_train) |> 
  step_dummy(all_nominal_predictors()) |>
  step_zv(all_predictors()) |> 
  step_normalize(all_numeric_predictors())

logistic_recipe
── Recipe ──────────────────────────────────────────────────────────────────────
── Inputs 
Number of variables by role
outcome:    1
predictor: 13
── Operations 
• Dummy variables from: all_nominal_predictors()
• Zero variance filter on: all_predictors()
• Centering and scaling for: all_numeric_predictors()

We combine the model specification and the recipe in a workflow.

logistic_wkflow <- workflow() |> 
  add_model(logistic_model) |> 
  add_recipe(logistic_recipe)


logistic_wkflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: logistic_reg()

── Preprocessor ────────────────────────────────────────────────────────────────
3 Recipe Steps

• step_dummy()
• step_zv()
• step_normalize()

── Model ───────────────────────────────────────────────────────────────────────
Logistic Regression Model Specification (classification)

Main Arguments:
  penalty = tune()
  mixture = 1

Computational engine: glmnet 

Penalty grid

lr_reg_grid <- tibble(penalty = 10^seq(-4, -1, length.out = 30))

lr_reg_grid
# A tibble: 30 × 1
    penalty
      <dbl>
 1 0.0001  
 2 0.000127
 3 0.000161
 4 0.000204
 5 0.000259
 6 0.000329
 7 0.000418
 8 0.000530
 9 0.000672
10 0.000853
# ℹ 20 more rows
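The base-R check below shows the structure of this grid. The exponents are evenly spaced, so each penalty is a constant multiple (about 1.27) of the previous one. On a log scale, small penalties are therefore sampled as densely as large ones.

```r
# Same grid as above: 30 values evenly spaced on the log10 scale from 1e-4 to 1e-1
penalty <- 10^seq(-4, -1, length.out = 30)

# Consecutive values differ by a constant factor of 10^(3/29), about 1.27
ratios <- penalty[-1] / penalty[-length(penalty)]
signif(head(penalty, 3), 3)
range(ratios)
```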
logistic_fit <- logistic_wkflow |> 
  tune_grid(
    grid = lr_reg_grid,
    resamples = customer_cv,
    metrics = metric_set(roc_auc),
    control = control_grid(save_pred = TRUE)
  )

logistic_fit
# Tuning results
# 10-fold cross-validation using stratification 
# A tibble: 10 × 5
   splits             id     .metrics          .notes           .predictions
   <list>             <chr>  <list>            <list>           <list>      
 1 <split [1026/115]> Fold01 <tibble [30 × 5]> <tibble [0 × 3]> <tibble>    
 2 <split [1026/115]> Fold02 <tibble [30 × 5]> <tibble [0 × 3]> <tibble>    
 3 <split [1026/115]> Fold03 <tibble [30 × 5]> <tibble [0 × 3]> <tibble>    
 4 <split [1026/115]> Fold04 <tibble [30 × 5]> <tibble [0 × 3]> <tibble>    
 5 <split [1026/115]> Fold05 <tibble [30 × 5]> <tibble [0 × 3]> <tibble>    
 6 <split [1027/114]> Fold06 <tibble [30 × 5]> <tibble [0 × 3]> <tibble>    
 7 <split [1028/113]> Fold07 <tibble [30 × 5]> <tibble [0 × 3]> <tibble>    
 8 <split [1028/113]> Fold08 <tibble [30 × 5]> <tibble [0 × 3]> <tibble>    
 9 <split [1028/113]> Fold09 <tibble [30 × 5]> <tibble [0 × 3]> <tibble>    
10 <split [1028/113]> Fold10 <tibble [30 × 5]> <tibble [0 × 3]> <tibble>    
collect_metrics(logistic_fit)
# A tibble: 30 × 7
    penalty .metric .estimator  mean     n std_err .config              
      <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
 1 0.0001   roc_auc binary     0.938    10 0.00776 Preprocessor1_Model01
 2 0.000127 roc_auc binary     0.938    10 0.00762 Preprocessor1_Model02
 3 0.000161 roc_auc binary     0.939    10 0.00755 Preprocessor1_Model03
 4 0.000204 roc_auc binary     0.939    10 0.00758 Preprocessor1_Model04
 5 0.000259 roc_auc binary     0.939    10 0.00762 Preprocessor1_Model05
 6 0.000329 roc_auc binary     0.939    10 0.00756 Preprocessor1_Model06
 7 0.000418 roc_auc binary     0.939    10 0.00749 Preprocessor1_Model07
 8 0.000530 roc_auc binary     0.938    10 0.00741 Preprocessor1_Model08
 9 0.000672 roc_auc binary     0.938    10 0.00722 Preprocessor1_Model09
10 0.000853 roc_auc binary     0.938    10 0.00723 Preprocessor1_Model10
# ℹ 20 more rows

Visualizing Logistic Regression metrics

collect_metrics(logistic_fit) |> 
  ggplot(aes(x = penalty, y = mean)) +
  geom_point() +
  geom_line() +
  scale_x_log10(labels = scales::label_number()) +
  theme_minimal()

Which is the best model based on ROC AUC?

best_model <- logistic_fit |> 
  show_best(metric = "roc_auc", n = 5) |> 
  arrange(penalty) |> 
  # among the five best penalties, keep the highest one (simplest model)
  slice(5)

best_model
# A tibble: 1 × 7
   penalty .metric .estimator  mean     n std_err .config              
     <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
1 0.000418 roc_auc binary     0.939    10 0.00749 Preprocessor1_Model07

Fewer irrelevant predictors is better: with a lasso penalty, a higher penalty value shrinks more coefficients to exactly zero. If performance is about the same, we’d prefer to choose a higher penalty value.
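The selection rule above can be reproduced in base R on the ten rounded rows that collect_metrics() printed. The other 20 rows are not shown, so this is only an illustration. tune also provides select_by_one_std_err() and select_by_pct_loss() for choosing a simpler model. Those functions apply different rules from the one used here.

```r
# The ten (rounded) rows printed by collect_metrics() above
res <- data.frame(
  penalty = c(0.0001, 0.000127, 0.000161, 0.000204, 0.000259,
              0.000329, 0.000418, 0.000530, 0.000672, 0.000853),
  mean    = c(0.938, 0.938, 0.939, 0.939, 0.939,
              0.939, 0.939, 0.938, 0.938, 0.938)
)

# Keep the five best mean ROC AUC values, then pick the largest penalty among them
top5 <- head(res[order(res$mean, decreasing = TRUE), ], 5)
max(top5$penalty)
```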

Pulling predictions for the best model

logistic_fit |> 
  collect_predictions(parameters = best_model)
# A tibble: 1,141 × 7
   id     .pred_churn `.pred_non-churn`  .row  penalty churn .config            
   <chr>        <dbl>             <dbl> <int>    <dbl> <fct> <chr>              
 1 Fold01       0.275           0.725       1 0.000418 churn Preprocessor1_Mode…
 2 Fold01       0.179           0.821      21 0.000418 churn Preprocessor1_Mode…
 3 Fold01       0.989           0.0115     36 0.000418 churn Preprocessor1_Mode…
 4 Fold01       0.991           0.00882    42 0.000418 churn Preprocessor1_Mode…
 5 Fold01       0.985           0.0151     50 0.000418 churn Preprocessor1_Mode…
 6 Fold01       0.897           0.103      57 0.000418 churn Preprocessor1_Mode…
 7 Fold01       0.981           0.0188     81 0.000418 churn Preprocessor1_Mode…
 8 Fold01       0.422           0.578      95 0.000418 churn Preprocessor1_Mode…
 9 Fold01       0.984           0.0164    105 0.000418 churn Preprocessor1_Mode…
10 Fold01       0.437           0.563     128 0.000418 churn Preprocessor1_Mode…
# ℹ 1,131 more rows

ROC curve for the best logistic regression model

best_logistic_model <- logistic_fit |> 
  collect_predictions(parameters = best_model) |> 
  roc_curve(churn, .pred_churn) |>
  mutate(
    model = "Logistic Regression"
  )

best_logistic_model
# A tibble: 1,136 × 4
   .threshold specificity sensitivity model              
        <dbl>       <dbl>       <dbl> <chr>              
 1 -Inf           0                 1 Logistic Regression
 2    9.93e-7     0                 1 Logistic Regression
 3    1.51e-6     0.00105           1 Logistic Regression
 4    1.67e-6     0.00209           1 Logistic Regression
 5    2.35e-6     0.00314           1 Logistic Regression
 6    2.61e-6     0.00419           1 Logistic Regression
 7    3.61e-6     0.00524           1 Logistic Regression
 8    3.89e-6     0.00628           1 Logistic Regression
 9    4.20e-6     0.00733           1 Logistic Regression
10    4.88e-6     0.00838           1 Logistic Regression
# ℹ 1,126 more rows
best_logistic_model |> 
  autoplot()
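ROC AUC, the metric used for tuning, is the probability that a randomly chosen churner receives a higher predicted churn probability than a randomly chosen non-churner. The base-R sketch below uses hypothetical scores and computes this probability in two ways: by brute force over all pairs and with the rank (Mann-Whitney) formula.

```r
# Hypothetical predicted churn probabilities and true labels
truth <- factor(c("churn", "churn", "churn",
                  "non-churn", "non-churn", "non-churn", "non-churn"))
score <- c(0.9, 0.8, 0.4, 0.7, 0.3, 0.2, 0.1)

pos <- score[truth == "churn"]
neg <- score[truth == "non-churn"]

# Brute force: share of (churner, non-churner) pairs ranked correctly (ties count 1/2)
auc_pairs <- mean(outer(pos, neg, ">") + 0.5 * outer(pos, neg, "=="))

# Rank formula: (sum of churner ranks - n_pos(n_pos + 1)/2) / (n_pos * n_neg)
r <- rank(score)
n_pos <- length(pos)
n_neg <- length(neg)
auc_rank <- (sum(r[truth == "churn"]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)

c(auc_pairs, auc_rank)  # both equal 11/12
```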

Random Forest Model

The second model to try is a random forest model.

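We will tune three hyperparameters:

  • trees: number of trees in the forest

  • mtry: number of predictors randomly sampled as split candidates at each split

  • min_n: minimum number of data points in a node required for the node to be split further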
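Here is a concrete view of mtry: at every split, a tree considers only a random subset of mtry predictors. The sketch below draws such a subset from the 13 predictor names in this dataset. When mtry is not set, ranger's documented default is the square root of the number of predictors, rounded down.

```r
# The 13 predictors in customer_churn (every column except churn)
predictors <- c("call_failure", "complains", "subscription_length", "charge_amount",
                "seconds_of_use", "frequency_of_use", "frequency_of_sms",
                "distinct_called_numbers", "age_group", "tariff_plan", "status",
                "age", "customer_value")

set.seed(1)
mtry <- 4
# Candidate predictors for one split: a random subset of size mtry
split_candidates <- sample(predictors, mtry)
split_candidates

# ranger's default when mtry is not supplied: floor(sqrt(number of predictors))
floor(sqrt(length(predictors)))  # 3
```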

rf_model <- rand_forest(
  mtry = tune(),
  trees = tune(),
  min_n = tune()
) |> 
  set_engine("ranger") |> 
  set_mode("classification")

rf_model
Random Forest Model Specification (classification)

Main Arguments:
  mtry = tune()
  trees = tune()
  min_n = tune()

Computational engine: ranger 

Create a new workflow for random forest.

Random forests do not require categorical variables to be converted into dummy variables, nor numeric predictors to be scaled.

Instead of a recipe, we will use a formula in our workflow and fit the training data as is.

rf_wkflow <- workflow() |> 
  add_model(rf_model) |> 
  add_formula(churn ~ .)

rf_wkflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Formula
Model: rand_forest()

── Preprocessor ────────────────────────────────────────────────────────────────
churn ~ .

── Model ───────────────────────────────────────────────────────────────────────
Random Forest Model Specification (classification)

Main Arguments:
  mtry = tune()
  trees = tune()
  min_n = tune()

Computational engine: ranger 

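Tuning the random forest model. With grid = 10, tune_grid() automatically generates 10 candidate combinations of mtry, trees and min_n.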

set.seed(345)
rf_fit <- rf_wkflow |> 
  tune_grid(
    grid = 10,
    resamples = customer_cv,
    control = control_grid(save_pred = TRUE),
    metrics = metric_set(roc_auc)
  )
i Creating pre-processing data to finalize unknown parameter: mtry
rf_fit
# Tuning results
# 10-fold cross-validation using stratification 
# A tibble: 10 × 5
   splits             id     .metrics          .notes           .predictions
   <list>             <chr>  <list>            <list>           <list>      
 1 <split [1026/115]> Fold01 <tibble [10 × 7]> <tibble [0 × 3]> <tibble>    
 2 <split [1026/115]> Fold02 <tibble [10 × 7]> <tibble [0 × 3]> <tibble>    
 3 <split [1026/115]> Fold03 <tibble [10 × 7]> <tibble [0 × 3]> <tibble>    
 4 <split [1026/115]> Fold04 <tibble [10 × 7]> <tibble [0 × 3]> <tibble>    
 5 <split [1026/115]> Fold05 <tibble [10 × 7]> <tibble [0 × 3]> <tibble>    
 6 <split [1027/114]> Fold06 <tibble [10 × 7]> <tibble [0 × 3]> <tibble>    
 7 <split [1028/113]> Fold07 <tibble [10 × 7]> <tibble [0 × 3]> <tibble>    
 8 <split [1028/113]> Fold08 <tibble [10 × 7]> <tibble [0 × 3]> <tibble>    
 9 <split [1028/113]> Fold09 <tibble [10 × 7]> <tibble [0 × 3]> <tibble>    
10 <split [1028/113]> Fold10 <tibble [10 × 7]> <tibble [0 × 3]> <tibble>    
collect_metrics(rf_fit)
# A tibble: 10 × 9
    mtry trees min_n .metric .estimator  mean     n std_err .config             
   <int> <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
 1    11    79     3 roc_auc binary     0.976    10 0.00310 Preprocessor1_Model…
 2     2  1531    35 roc_auc binary     0.971    10 0.00390 Preprocessor1_Model…
 3    11   899    27 roc_auc binary     0.971    10 0.00333 Preprocessor1_Model…
 4    12   370    31 roc_auc binary     0.969    10 0.00338 Preprocessor1_Model…
 5     9  1944    36 roc_auc binary     0.970    10 0.00320 Preprocessor1_Model…
 6     8   594    13 roc_auc binary     0.977    10 0.00296 Preprocessor1_Model…
 7     4  1742    16 roc_auc binary     0.978    10 0.00322 Preprocessor1_Model…
 8     5  1210    18 roc_auc binary     0.977    10 0.00291 Preprocessor1_Model…
 9     7  1084    24 roc_auc binary     0.975    10 0.00303 Preprocessor1_Model…
10     1   798     6 roc_auc binary     0.970    10 0.00510 Preprocessor1_Model…
rf_best_model_parameters <- rf_fit |> 
  # select_best() already returns a single row
  select_best(metric = "roc_auc")
rf_best_model_parameters
# A tibble: 1 × 4
   mtry trees min_n .config              
  <int> <int> <int> <chr>                
1     4  1742    16 Preprocessor1_Model07
best_rf_model <- rf_fit |> 
  collect_predictions(parameters = rf_best_model_parameters) |> 
  roc_curve(churn, .pred_churn) |> 
  mutate(
    model = "Random Forest"
  )

best_rf_model
# A tibble: 1,090 × 4
     .threshold specificity sensitivity model        
          <dbl>       <dbl>       <dbl> <chr>        
 1 -Inf              0                1 Random Forest
 2    0              0                1 Random Forest
 3    0.0000359      0.0346           1 Random Forest
 4    0.0000410      0.0356           1 Random Forest
 5    0.0000442      0.0387           1 Random Forest
 6    0.0000478      0.0408           1 Random Forest
 7    0.0000522      0.0429           1 Random Forest
 8    0.0000574      0.0471           1 Random Forest
 9    0.0000638      0.0492           1 Random Forest
10    0.0000765      0.0503           1 Random Forest
# ℹ 1,080 more rows
autoplot(best_rf_model)

best_logistic_model |> 
  bind_rows(best_rf_model) |> 
  ggplot(aes(x = 1 - specificity, y = sensitivity, col = model)) +
  geom_path(lwd = 1.5, alpha = 0.8) +
  geom_abline(lty = 3) +
  coord_equal() +
  scale_color_viridis_d(option = "plasma", end = .6)

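The random forest performs better than the logistic regression model: its best cross-validated ROC AUC is 0.978, compared with 0.939 for the best logistic regression model.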

Last Fit

We have established that the best model is the random forest with the following hyperparameters:

mtry: 4

trees: 1742

min_n: 16

last_rf_model <- rand_forest(mtry = 4, trees = 1742, min_n = 16) |> 
  # impurity-based importance is needed later for the variable importance plot
  set_engine("ranger", importance = "impurity") |> 
  set_mode("classification")

last_rf_model
Random Forest Model Specification (classification)

Main Arguments:
  mtry = 4
  trees = 1742
  min_n = 16

Engine-Specific Arguments:
  importance = impurity

Computational engine: ranger 
last_rf_wkflow <- rf_wkflow |> 
  update_model(last_rf_model)

last_rf_wkflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Formula
Model: rand_forest()

── Preprocessor ────────────────────────────────────────────────────────────────
churn ~ .

── Model ───────────────────────────────────────────────────────────────────────
Random Forest Model Specification (classification)

Main Arguments:
  mtry = 4
  trees = 1742
  min_n = 16

Engine-Specific Arguments:
  importance = impurity

Computational engine: ranger 

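Fitting the final model. last_fit() fits the final workflow on the full training set and evaluates it once on the test set.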

rf_last_fit <- last_rf_wkflow |> 
  last_fit(split = customer_split)

rf_last_fit
# Resampling results
# Manual resampling 
# A tibble: 1 × 6
  splits             id               .metrics .notes   .predictions .workflow 
  <list>             <chr>            <list>   <list>   <list>       <list>    
1 <split [1141/286]> train/test split <tibble> <tibble> <tibble>     <workflow>
collect_metrics(rf_last_fit)
# A tibble: 2 × 4
  .metric  .estimator .estimate .config             
  <chr>    <chr>          <dbl> <chr>               
1 accuracy binary         0.944 Preprocessor1_Model1
2 roc_auc  binary         0.987 Preprocessor1_Model1
collect_predictions(rf_last_fit) |> 
  roc_curve(churn, .pred_churn) |> 
  autoplot()

Confusion Matrix

collect_predictions(rf_last_fit) |> 
  conf_mat(churn, .pred_class)
           Truth
Prediction  churn non-churn
  churn        38         7
  non-churn     9       232

The model performed well on unseen data (the test set): accuracy was 94.4% and ROC AUC was 0.987.

Out of 47 customers who churned, the model correctly identified 38 (about 81%). It also flagged 7 non-churning customers as churners.
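These numbers can be derived by hand from the confusion matrix, treating churn as the positive class. The base-R sketch below recomputes the 0.944 accuracy reported by collect_metrics() and also computes sensitivity, specificity and precision. For these counts, yardstick's sens(), spec() and precision() should return the same values.

```r
# Confusion matrix counts from the test-set output above
tp <- 38   # churners predicted as churn
fn <- 9    # churners predicted as non-churn
fp <- 7    # non-churners predicted as churn
tn <- 232  # non-churners predicted as non-churn

accuracy    <- (tp + tn) / (tp + tn + fp + fn)  # share of all predictions that are correct
sensitivity <- tp / (tp + fn)                   # share of churners the model catches
specificity <- tn / (tn + fp)                   # share of non-churners correctly left alone
precision   <- tp / (tp + fp)                   # share of churn predictions that are right

round(c(accuracy = accuracy, sensitivity = sensitivity,
        specificity = specificity, precision = precision), 3)
```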

Which variables are most important in predicting churned customers?

rf_last_fit |> 
  extract_fit_parsnip() |> 
  vi() |> 
  ggplot(aes(x = Importance, y = fct_reorder(Variable, Importance))) +
  geom_col(fill = "#2297FA") +
  labs(
    y = "",
    title = "Variable Importance Plot") +
  theme_minimal()

The most important predictor was complains, which records whether a customer had complained or not.
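For classification, the impurity importance requested with importance = "impurity" is based on decreases in Gini impurity. Each time a tree splits a node on a variable, the resulting drop in impurity is credited to that variable, and these drops are accumulated across all trees. The sketch below computes the decrease for a single split, using hypothetical churn counts. ranger's exact weighting and normalization may differ.

```r
# Gini impurity of a node from its class counts: 1 - sum(p_k^2)
gini <- function(counts) {
  p <- counts / sum(counts)
  1 - sum(p^2)
}

# Hypothetical node with 20 churners and 80 non-churners, split on complains
parent <- c(churn = 20, non_churn = 80)
left   <- c(churn = 12, non_churn = 3)   # complains == 1
right  <- c(churn = 8,  non_churn = 77)  # complains == 0

# Children's impurity, weighted by their share of the parent's observations
n <- sum(parent)
child_impurity <- sum(left) / n * gini(left) + sum(right) / n * gini(right)

# Impurity decrease credited to complains for this split
decrease <- gini(parent) - child_impurity
decrease
```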