library(rstanarm)
library(ggplot2)
library(bayesplot)
library(broom)
options(dplyr.summarise.inform = FALSE)

#read data
data <- read.csv('churn_data.csv', header=TRUE)
#remove column 'X', an index column
data$X <- NULL

The dataset is from Kaggle: https://www.kaggle.com/blastchar/telco-customer-churn

The data was preprocessed, including one-hot encoding of the categorical variables, in an earlier exercise: https://github.com/jiajianwoo/telco_churn/blob/main/telco%20churn.ipynb

Train-test split

#convert every column to factor except the numeric ones (columns 39-41: tenure, MonthlyCharges, TotalCharges)
factor_cols <- names(data)[-c(39,40,41)]
data[,factor_cols] <- lapply(data[,factor_cols], factor)

#train-test split
set.seed(123)
idx <- sort(sample(nrow(data), nrow(data)*0.7))
train <- data[idx,]
test <- data[-idx,]
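Selecting the numeric columns by position (`-c(39,40,41)`) silently breaks if the column order ever changes. Below is a minimal sketch of the same preprocessing that selects them by name instead. It assumes that tenure, MonthlyCharges and TotalCharges are the only numeric columns, which is what the model summary suggests. The data frame `toy` and its value ranges are hypothetical stand-ins for `data`, with the 7043 rows implied by the 4930/2113 split.

```r
# Hypothetical stand-in for the churn data: 7043 rows, a few columns
set.seed(123)
n <- 7043
toy <- data.frame(
  gender         = sample(0:1, n, replace = TRUE),
  Churn          = sample(0:1, n, replace = TRUE, prob = c(0.73, 0.27)),
  tenure         = sample(0:72, n, replace = TRUE),
  MonthlyCharges = runif(n, 18, 119),
  TotalCharges   = runif(n, 18, 8700)
)

# Convert every column except the named numeric ones to factor
numeric_cols <- c("tenure", "MonthlyCharges", "TotalCharges")
factor_cols  <- setdiff(names(toy), numeric_cols)
toy[factor_cols] <- lapply(toy[factor_cols], factor)

# 70/30 split: the size 7043 * 0.7 = 4930.1 is truncated to 4930
idx       <- sort(sample(nrow(toy), nrow(toy) * 0.7))
train_toy <- toy[idx, ]
test_toy  <- toy[-idx, ]

# The churn rate should be similar in both parts
c(train = mean(train_toy$Churn == "1"), test = mean(test_toy$Churn == "1"))
```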

Building the model

The Bayesian logistic regression model was fitted with the code below; the fitted model is loaded from stan_churn.rds instead of being refitted.

model <- stan_glm(Churn ~ ., data = train, family = binomial(link = "logit"), seed = 123, chains = 4)

#load model
my_model <- readRDS('stan_churn.rds')
summary(my_model)
## 
## Model Info:
##  function:     stan_glm
##  family:       binomial [logit]
##  formula:      Churn ~ .
##  algorithm:    sampling
##  sample:       4000 (posterior sample size)
##  priors:       see help('prior_summary')
##  observations: 4930
##  predictors:   41
## 
## Estimates:
##                                            mean   sd    10%   50%   90%
## (Intercept)                                0.8   11.2 -13.7   1.0  15.0
## gender1                                    0.0    0.1  -0.1   0.0   0.1
## SeniorCitizen1                             0.2    0.1   0.0   0.2   0.3
## Partner1                                  -0.1    0.1  -0.2  -0.1   0.0
## Dependents1                               -0.2    0.1  -0.3  -0.2   0.0
## PhoneService1                              0.1    6.2  -7.8   0.0   8.0
## PaperlessBilling1                          0.4    0.1   0.2   0.4   0.5
## MultipleLines_No1                         -0.2    3.3  -4.5  -0.2   4.1
## MultipleLines_No.phone.service1           -0.1    6.1  -7.8  -0.1   7.7
## MultipleLines_Yes1                         0.1    3.4  -4.2   0.1   4.5
## InternetService_DSL1                      -0.6    3.6  -5.2  -0.5   4.1
## InternetService_Fiber.optic1               0.9    3.6  -3.8   0.9   5.6
## InternetService_No1                       -0.2    6.0  -8.0  -0.2   7.4
## OnlineSecurity_No1                         0.2    3.6  -4.4   0.3   4.8
## OnlineSecurity_No.internet.service1       -0.3    5.7  -7.6  -0.3   7.1
## OnlineSecurity_Yes1                       -0.1    3.6  -4.7   0.0   4.5
## OnlineBackup_No1                           0.1    3.6  -4.4   0.2   4.8
## OnlineBackup_No.internet.service1         -0.3    5.9  -7.7  -0.3   7.2
## OnlineBackup_Yes1                          0.1    3.6  -4.5   0.1   4.7
## DeviceProtection_No1                       0.1    3.5  -4.4   0.0   4.7
## DeviceProtection_No.internet.service1     -0.1    5.8  -7.5   0.0   7.4
## DeviceProtection_Yes1                      0.1    3.5  -4.3   0.1   4.7
## TechSupport_No1                            0.2    3.6  -4.3   0.3   4.8
## TechSupport_No.internet.service1          -0.2    5.8  -7.7  -0.2   7.4
## TechSupport_Yes1                           0.0    3.6  -4.6   0.0   4.6
## StreamingTV_No1                           -0.1    3.6  -4.9  -0.1   4.5
## StreamingTV_No.internet.service1          -0.4    5.5  -7.5  -0.4   6.7
## StreamingTV_Yes1                           0.4    3.6  -4.3   0.4   5.0
## StreamingMovies_No1                       -0.3    3.5  -4.8  -0.2   4.2
## StreamingMovies_No.internet.service1      -0.3    5.8  -7.7  -0.3   7.2
## StreamingMovies_Yes1                       0.3    3.6  -4.3   0.3   4.8
## Contract_Month.to.month1                   0.6    3.2  -3.6   0.7   4.7
## Contract_One.year1                        -0.1    3.2  -4.3   0.0   4.0
## Contract_Two.year1                        -0.9    3.2  -5.2  -0.8   3.2
## PaymentMethod_Bank.transfer..automatic.1  -0.1    2.9  -4.0  -0.1   3.6
## PaymentMethod_Credit.card..automatic.1    -0.2    2.9  -4.1  -0.2   3.5
## PaymentMethod_Electronic.check1            0.2    2.9  -3.7   0.2   3.9
## PaymentMethod_Mailed.check1               -0.2    2.9  -4.1  -0.2   3.5
## tenure                                    -0.1    0.0  -0.1  -0.1   0.0
## MonthlyCharges                             0.0    0.0  -0.1   0.0   0.0
## TotalCharges                               0.0    0.0   0.0   0.0   0.0
## 
## Fit Diagnostics:
##            mean   sd   10%   50%   90%
## mean_PPD 0.3    0.0  0.3   0.3   0.3  
## 
## The mean_ppd is the sample average posterior predictive distribution of the outcome variable (for details see help('summary.stanreg')).
## 
## MCMC diagnostics
##                                          mcse Rhat n_eff
## (Intercept)                              0.2  1.0  4365 
## gender1                                  0.0  1.0  8097 
## SeniorCitizen1                           0.0  1.0  6954 
## Partner1                                 0.0  1.0  6140 
## Dependents1                              0.0  1.0  6728 
## PhoneService1                            0.1  1.0  5552 
## PaperlessBilling1                        0.0  1.0  6716 
## MultipleLines_No1                        0.1  1.0  4314 
## MultipleLines_No.phone.service1          0.1  1.0  4988 
## MultipleLines_Yes1                       0.1  1.0  4367 
## InternetService_DSL1                     0.1  1.0  4455 
## InternetService_Fiber.optic1             0.1  1.0  4468 
## InternetService_No1                      0.1  1.0  8242 
## OnlineSecurity_No1                       0.1  1.0  4690 
## OnlineSecurity_No.internet.service1      0.1  1.0  7421 
## OnlineSecurity_Yes1                      0.1  1.0  4695 
## OnlineBackup_No1                         0.1  1.0  4520 
## OnlineBackup_No.internet.service1        0.1  1.0  7061 
## OnlineBackup_Yes1                        0.1  1.0  4516 
## DeviceProtection_No1                     0.0  1.0  5343 
## DeviceProtection_No.internet.service1    0.1  1.0  7580 
## DeviceProtection_Yes1                    0.0  1.0  5291 
## TechSupport_No1                          0.1  1.0  4546 
## TechSupport_No.internet.service1         0.1  1.0  8880 
## TechSupport_Yes1                         0.1  1.0  4524 
## StreamingTV_No1                          0.1  1.0  4560 
## StreamingTV_No.internet.service1         0.1  1.0  7217 
## StreamingTV_Yes1                         0.1  1.0  4630 
## StreamingMovies_No1                      0.0  1.0  5024 
## StreamingMovies_No.internet.service1     0.1  1.0  8134 
## StreamingMovies_Yes1                     0.0  1.0  5066 
## Contract_Month.to.month1                 0.1  1.0  3437 
## Contract_One.year1                       0.1  1.0  3435 
## Contract_Two.year1                       0.1  1.0  3443 
## PaymentMethod_Bank.transfer..automatic.1 0.1  1.0  2593 
## PaymentMethod_Credit.card..automatic.1   0.1  1.0  2606 
## PaymentMethod_Electronic.check1          0.1  1.0  2604 
## PaymentMethod_Mailed.check1              0.1  1.0  2599 
## tenure                                   0.0  1.0  4654 
## MonthlyCharges                           0.0  1.0  8521 
## TotalCharges                             0.0  1.0  4560 
## mean_PPD                                 0.0  1.0  3853 
## log-posterior                            0.1  1.0  1569 
## 
## For each parameter, mcse is Monte Carlo standard error, n_eff is a crude measure of effective sample size, and Rhat is the potential scale reduction factor on split chains (at convergence Rhat=1).

All Rhat values are 1.0, well below the usual 1.1 threshold, which suggests that the MCMC chains have converged.
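Rhat compares the variance between chains with the variance within them. Below is a minimal sketch of the classic split-Rhat on simulated draws. The function `split_rhat` is a hypothetical helper, and Stan's own implementation may differ in details; newer versions use a rank-normalised variant.

```r
# Classic split-Rhat for one parameter; draws is an iterations x chains matrix
split_rhat <- function(draws) {
  n_half <- floor(nrow(draws) / 2)
  # Split each chain into a first and second half, treated as separate chains
  halves <- cbind(draws[1:n_half, , drop = FALSE],
                  draws[(nrow(draws) - n_half + 1):nrow(draws), , drop = FALSE])
  n <- nrow(halves)
  W <- mean(apply(halves, 2, var))      # mean within-chain variance
  B <- n * var(colMeans(halves))        # between-chain variance
  var_plus <- (n - 1) / n * W + B / n   # pooled estimate of the posterior variance
  sqrt(var_plus / W)
}

set.seed(1)
mixed <- matrix(rnorm(1000 * 4), ncol = 4)  # 4 well-mixed chains
stuck <- mixed
stuck[, 4] <- stuck[, 4] + 3                 # one chain exploring a different region
c(mixed = split_rhat(mixed), stuck = split_rhat(stuck))
```

Well-mixed chains give a value very close to 1, while the stuck chain pushes Rhat far above 1.1.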

#prior summary
prior_summary(my_model)
## Priors for model 'my_model' 
## ------
## Intercept (after predictors centered)
##  ~ normal(location = 0, scale = 2.5)
## 
## Coefficients
##   Specified prior:
##     ~ normal(location = [0,0,0,...], scale = [2.5,2.5,2.5,...])
##   Adjusted prior:
##     ~ normal(location = [0,0,0,...], scale = [5.00,6.72,5.00,...])
## ------
## See help('prior_summary.stanreg') for more details

By default, the model used weakly informative normal priors: normal(0, 2.5) for the intercept (after centring the predictors) and normal(0, 2.5) for each coefficient, with each coefficient's prior scale adjusted to 2.5 divided by the standard deviation of the corresponding predictor.
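The adjusted scales in prior_summary() come from dividing the specified scale of 2.5 by each predictor's standard deviation. Below is a minimal sketch with hypothetical 0/1 columns; `adjusted_scale` is our own helper, not part of rstanarm.

```r
# Coefficient prior scale after autoscaling: specified scale / sd(predictor)
adjusted_scale <- function(x, scale = 2.5) scale / sd(x)

balanced   <- rep(c(0, 1), each = 5000)            # hypothetical 50/50 dummy
unbalanced <- rep(c(0, 1), times = c(8400, 1600))  # hypothetical dummy with 16% ones

round(c(balanced   = adjusted_scale(balanced),
        unbalanced = adjusted_scale(unbalanced)), 2)
```

A balanced dummy gets a scale of about 5.00, matching the first adjusted value above. A rarer dummy has a smaller standard deviation and therefore gets a wider prior.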

MCMC Diagnostics

Trace plots

#trace plots, one group of parameters at a time
chains <- as.array(my_model)
mcmc_trace(chains, pars = c("(Intercept)", "gender1", "SeniorCitizen1", "Partner1", "Dependents1", "PaperlessBilling1"))

mcmc_trace(chains, pars = c("PhoneService1", "MultipleLines_No1", "MultipleLines_No.phone.service1", "MultipleLines_Yes1", "InternetService_DSL1", "InternetService_Fiber.optic1", "InternetService_No1", "OnlineSecurity_No1", "OnlineSecurity_No.internet.service1"))

mcmc_trace(chains, pars = c("OnlineSecurity_Yes1", "OnlineBackup_No1", "OnlineBackup_No.internet.service1", "OnlineBackup_Yes1"))

mcmc_trace(chains, pars = c("DeviceProtection_No1", "DeviceProtection_No.internet.service1", "DeviceProtection_Yes1", "TechSupport_No1", "TechSupport_No.internet.service1", "TechSupport_Yes1", "StreamingTV_No1", "StreamingTV_No.internet.service1", "StreamingTV_Yes1"))

mcmc_trace(chains, pars = c("StreamingMovies_No1", "StreamingMovies_No.internet.service1", "StreamingMovies_Yes1", "Contract_Month.to.month1", "Contract_One.year1", "Contract_Two.year1"))

mcmc_trace(chains, pars = c("PaymentMethod_Bank.transfer..automatic.1", "PaymentMethod_Credit.card..automatic.1", "PaymentMethod_Electronic.check1", "PaymentMethod_Mailed.check1", "tenure", "MonthlyCharges", "TotalCharges"))

The chains appear stationary and well mixed for all parameters, which is consistent with the Rhat values and with convergence.

Effective sample size and autocorrelation

Plot the ratio of effective sample size to total posterior sample size (n_eff/N) for each parameter.

#ratio of effective sample size to total posterior sample size
plot(my_model, "neff")

Autocorrelation plots for each chain, for the parameters with the smallest effective sample sizes (the four PaymentMethod dummies, with n_eff of about 2600):

#autocorrelation by chain for the PaymentMethod dummies
plot(my_model, "acf_bar", pars = c("PaymentMethod_Bank.transfer..automatic.1", "PaymentMethod_Credit.card..automatic.1", "PaymentMethod_Electronic.check1", "PaymentMethod_Mailed.check1"))
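Lower effective sample sizes go hand in hand with higher autocorrelation. Below is a rough sketch of that relationship. It assumes an AR(1) series as a stand-in for a sampler's output and uses a crude estimator, ESS = N / (1 + 2 × sum of autocorrelations) truncated at the first negative autocorrelation, rather than Stan's exact method. `crude_ess` is a hypothetical helper.

```r
# Crude effective sample size: N / (1 + 2 * sum of positive-lag autocorrelations),
# summing only up to (not including) the first negative estimate
crude_ess <- function(x, max_lag = 200) {
  rho <- acf(x, lag.max = max_lag, plot = FALSE)$acf[-1]  # drop lag 0
  first_neg <- which(rho < 0)[1]
  if (!is.na(first_neg)) rho <- rho[seq_len(first_neg - 1)]
  length(x) / (1 + 2 * sum(rho))
}

set.seed(42)
N   <- 10000
phi <- 0.5
chain <- numeric(N)
for (t in 2:N) chain[t] <- phi * chain[t - 1] + rnorm(1)  # autocorrelated AR(1) chain

c(estimate = crude_ess(chain), theory = N * (1 - phi) / (1 + phi))
```

For an AR(1) chain with coefficient phi, the theoretical ESS is N(1 - phi)/(1 + phi), about 3333 here. The estimate should land near that value, well below N.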

Posterior credible intervals

#90% posterior credible intervals (posterior_interval() uses prob = 0.9 by default)
posterior_interval(my_model)
##                                                     5%           95%
## (Intercept)                              -1.761988e+01 18.8848840227
## gender1                                  -1.520748e-01  0.1043979439
## SeniorCitizen1                            6.953939e-03  0.3332686049
## Partner1                                 -2.139346e-01  0.0788384355
## Dependents1                              -3.586040e-01 -0.0092779570
## PhoneService1                            -9.985130e+00 10.3180839552
## PaperlessBilling1                         2.145312e-01  0.5050588808
## MultipleLines_No1                        -5.672103e+00  5.1714394375
## MultipleLines_No.phone.service1          -1.023252e+01 10.1007777213
## MultipleLines_Yes1                       -5.404277e+00  5.5258822701
## InternetService_DSL1                     -6.627327e+00  5.3744729927
## InternetService_Fiber.optic1             -5.240770e+00  6.9747135249
## InternetService_No1                      -1.004219e+01  9.8478133590
## OnlineSecurity_No1                       -5.744289e+00  6.1417179354
## OnlineSecurity_No.internet.service1      -9.538161e+00  8.9695882925
## OnlineSecurity_Yes1                      -6.032107e+00  5.8260121074
## OnlineBackup_No1                         -5.787875e+00  5.9764303260
## OnlineBackup_No.internet.service1        -1.000031e+01  9.3502093965
## OnlineBackup_Yes1                        -5.812522e+00  5.9340801724
## DeviceProtection_No1                     -5.644141e+00  6.0126701048
## DeviceProtection_No.internet.service1    -9.647152e+00  9.5277668929
## DeviceProtection_Yes1                    -5.636103e+00  5.9851759990
## TechSupport_No1                          -5.563119e+00  6.0969826041
## TechSupport_No.internet.service1         -9.666904e+00  9.6120238455
## TechSupport_Yes1                         -5.961297e+00  5.7973525552
## StreamingTV_No1                          -6.101489e+00  5.7328332986
## StreamingTV_No.internet.service1         -9.359116e+00  8.6138058706
## StreamingTV_Yes1                         -5.662560e+00  6.2652277642
## StreamingMovies_No1                      -6.151946e+00  5.5307482049
## StreamingMovies_No.internet.service1     -9.899906e+00  9.2596518226
## StreamingMovies_Yes1                     -5.622371e+00  6.1039437236
## Contract_Month.to.month1                 -4.840971e+00  5.8633447986
## Contract_One.year1                       -5.532907e+00  5.2244609588
## Contract_Two.year1                       -6.422275e+00  4.4187688655
## PaymentMethod_Bank.transfer..automatic.1 -4.940747e+00  4.6740863038
## PaymentMethod_Credit.card..automatic.1   -5.094611e+00  4.5548561912
## PaymentMethod_Electronic.check1          -4.679101e+00  4.9691557363
## PaymentMethod_Mailed.check1              -5.039166e+00  4.6047794920
## tenure                                   -6.840312e-02 -0.0449753565
## MonthlyCharges                           -8.454801e-02  0.0285303590
## TotalCharges                              1.663618e-04  0.0004335374
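These intervals are on the log-odds scale. Because exp() is monotone, exponentiating the bounds gives intervals for the odds ratios. The short sketch below uses the bounds reported above for a few predictors whose 90% intervals exclude zero.

```r
# 90% interval bounds (log-odds) copied from the posterior_interval() output above
log_odds <- rbind(
  SeniorCitizen1    = c(6.953939e-03, 0.3332686049),
  Dependents1       = c(-3.586040e-01, -0.0092779570),
  PaperlessBilling1 = c(2.145312e-01, 0.5050588808),
  tenure            = c(-6.840312e-02, -0.0449753565)
)
colnames(log_odds) <- c("5%", "95%")

# exp() preserves the ordering of quantiles, so this is the odds-ratio interval
odds_ratio <- exp(log_odds)
round(odds_ratio, 3)
```

Paperless billing, for example, multiplies the odds of churn by roughly 1.24 to 1.66. Each additional month of tenure multiplies them by roughly 0.93 to 0.96. Both interpretations hold the other predictors fixed.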

Posterior distribution of logistic regression coefficients

mcmc_areas(
  chains,
  pars = c("gender1", "SeniorCitizen1", "Partner1", "Dependents1", "PaperlessBilling1"),
  prob = 0.95,
  point_est = "mean"
)

mcmc_areas(
  chains,
  pars = c("(Intercept)", "PhoneService1", "MultipleLines_No1", "MultipleLines_No.phone.service1", "MultipleLines_Yes1", "InternetService_DSL1", "InternetService_Fiber.optic1", "InternetService_No1"),
  prob = 0.95,
  point_est = "mean"
)

mcmc_areas(
  chains,
  pars = c("OnlineSecurity_No1", "OnlineSecurity_No.internet.service1", "OnlineSecurity_Yes1", "OnlineBackup_No1", "OnlineBackup_No.internet.service1", "OnlineBackup_Yes1"),
  prob = 0.95,
  point_est = "mean"
)

mcmc_areas(
  chains,
  pars = c("DeviceProtection_No1", "DeviceProtection_No.internet.service1", "DeviceProtection_Yes1", "TechSupport_No1", "TechSupport_No.internet.service1", "TechSupport_Yes1", "StreamingTV_No1", "StreamingTV_No.internet.service1", "StreamingTV_Yes1"),
  prob = 0.95,
  point_est = "mean"
)

mcmc_areas(
  chains,
  pars = c("StreamingMovies_No1", "StreamingMovies_No.internet.service1", "StreamingMovies_Yes1", "Contract_Month.to.month1", "Contract_One.year1", "Contract_Two.year1"),
  prob = 0.95,
  point_est = "mean"
)

mcmc_areas(
  chains,
  pars = c("PaymentMethod_Bank.transfer..automatic.1", "PaymentMethod_Credit.card..automatic.1", "PaymentMethod_Electronic.check1", "PaymentMethod_Mailed.check1"),
  prob = 0.95,
  point_est = "mean"
)

mcmc_areas(
  chains,
  pars = c("tenure", "MonthlyCharges"),
  prob = 0.95,
  point_est = "mean"
)

mcmc_areas(
  chains,
  pars = "TotalCharges",
  prob = 0.95,
  point_est = "mean"
)
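The area plots can be summarised by the posterior probability that each coefficient is positive, which is the share of draws above zero. The sketch below runs on a hypothetical simulated draws matrix. With the fitted model, `as.matrix(my_model)` returns the draws with one column per parameter.

```r
# Hypothetical draws: one clearly positive coefficient, one centred at zero
set.seed(7)
draws <- cbind(
  clearly_positive = rnorm(4000, mean = 0.36, sd = 0.09),
  centred_at_zero  = rnorm(4000, mean = 0,    sd = 3.5)
)

# Posterior probability that each coefficient is greater than zero
prob_positive <- colMeans(draws > 0)
round(prob_positive, 3)
```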

Posterior prediction

#prediction: expected churn probability for each posterior draw (rows) and training customer (columns)
predictions <- posterior_epred(my_model)
#summary of the first posterior draw, i.e. the predicted churn probabilities of all training customers
summary(predictions[1,])
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
## 0.00243 0.04848 0.19355 0.26749 0.46563 0.85723

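Each row of predictions is one posterior draw and each column is one training customer. The summary above therefore describes the first draw across all training customers, and its mean (0.27) is close to the overall churn rate in the training set.

We can plot the posterior churn probability distribution for the customer in the first row of the training set, which is the first column of predictions.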

#posterior distribution of the churn probability for the customer in row 1 (first column)
ggplot(data.frame(prob = predictions[, 1]), aes(x = prob)) +
  geom_density()
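posterior_epred() (equivalently posterior_linpred(..., transform = TRUE)) returns one row per posterior draw and one column per observation. Each entry is the inverse logit of the linear predictor. The sketch below rebuilds such a matrix by hand from hypothetical coefficient draws and a hypothetical two-column design matrix.

```r
set.seed(3)
S <- 4000                                                   # posterior draws
X <- cbind(intercept = 1, tenure = c(1, 12, 24, 48, 72))    # hypothetical customers
beta <- cbind(rnorm(S, 0, 0.3), rnorm(S, -0.05, 0.01))      # hypothetical coefficient draws

# Inverse logit of the linear predictor: S x (number of customers)
epred <- plogis(beta %*% t(X))
dim(epred)

customer_1 <- epred[, 1]   # all draws for the first customer
first_draw <- epred[1, ]   # a single draw across all customers
summary(customer_1)
```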

Posterior predictive check: distribution of churn outcomes for all customers in the training set
pp_check(my_model, "dens_overlay")

The dark blue line (y) is the density of the observed outcomes (0 = stayed, 1 = churned) in the training set. Each light blue line (y_rep) is the density of one set of outcomes simulated from the posterior predictive distribution.

The simulated datasets look very similar to each other and to the observed data, and most of the mass sits at 0. A customer is more likely to stay and keep using the service than to churn.

Joint distribution of the mean (the churn rate) and standard deviation of the outcomes in each simulated training set, compared with the observed values
pp_check(my_model, "stat_2d")
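pp_check() simulates replicated outcome vectors and summarises each one. The sketch below reproduces the statistics behind the stat_2d plot on hypothetical data. For a 0/1 outcome, the standard deviation is a fixed function of the mean, so the points in that plot lie on a single curve and the mean (the churn rate) carries all the information. For simplicity, the sketch reuses one fixed vector of probabilities for every replicate instead of drawing a fresh posterior sample each time.

```r
set.seed(11)
n_obs <- 500                             # hypothetical number of customers
p     <- plogis(rnorm(n_obs, -1.2, 1))   # hypothetical churn probabilities
y     <- rbinom(n_obs, 1, p)             # hypothetical observed outcomes

# 200 replicated datasets (rows), each a 0/1 vector of length n_obs
y_rep <- t(replicate(200, rbinom(n_obs, 1, p)))

rep_stats <- cbind(mean = rowMeans(y_rep), sd = apply(y_rep, 1, sd))
obs_stats <- c(mean = mean(y), sd = sd(y))

# For 0/1 data, sd = sqrt(m * (1 - m) * n / (n - 1)) where m is the mean
implied_sd <- sqrt(rep_stats[, "mean"] * (1 - rep_stats[, "mean"]) * n_obs / (n_obs - 1))

# Share of replicated churn rates at least as large as the observed one
ppp_mean <- mean(rep_stats[, "mean"] >= obs_stats["mean"])
```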

Compare it to the frequentist model:

frequentist <- glm(Churn~., data= train, family = binomial)
tidy(frequentist)
## # A tibble: 24 x 5
##    term                         estimate std.error statistic  p.value
##    <chr>                           <dbl>     <dbl>     <dbl>    <dbl>
##  1 (Intercept)                   -2.14      0.298     -7.19  6.70e-13
##  2 gender1                       -0.0247    0.0768    -0.322 7.48e- 1
##  3 SeniorCitizen1                 0.171     0.100      1.70  8.92e- 2
##  4 Partner1                      -0.0694    0.0910    -0.763 4.46e- 1
##  5 Dependents1                   -0.178     0.106     -1.68  9.27e- 2
##  6 PhoneService1                  0.489     0.957      0.510 6.10e- 1
##  7 PaperlessBilling1              0.355     0.0883     4.02  5.71e- 5
##  8 MultipleLines_No1             -0.383     0.210     -1.82  6.81e- 2
##  9 InternetService_DSL1           2.43      2.46       0.986 3.24e- 1
## 10 InternetService_Fiber.optic1   4.03      3.41       1.18  2.37e- 1
## # … with 14 more rows
mean(predict(frequentist, type="response"))
## [1] 0.2744422

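The average churn probability is similar under both approaches. The glm mean of 0.274 equals the observed churn rate in the training set, which is a property of maximum-likelihood logistic regression with an intercept. The Bayesian mean_PPD rounds to 0.3.

Note that tidy() lists only 24 terms. glm() cannot estimate coefficients for dummy columns that are exact linear combinations of other columns. For example, the three MultipleLines dummies always sum to 1 and so duplicate the intercept. Those coefficients are NA and are dropped. The Bayesian model keeps all 41 terms because the priors make them identifiable, which is consistent with the very wide posteriors of those dummies.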
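The gap between the 41 Bayesian terms and the 24 glm() terms comes from exact collinearity among the one-hot columns. The small sketch below uses hypothetical data, with made-up column names, to show how glm() handles a three-level variable when all of its dummies are kept.

```r
# Hypothetical toy data: all three dummies of a 3-level variable are kept
set.seed(5)
n <- 300
level <- sample(c("No", "Yes", "NoService"), n, replace = TRUE)
toy <- data.frame(
  y            = rbinom(n, 1, 0.3),
  ML_No        = as.integer(level == "No"),
  ML_Yes       = as.integer(level == "Yes"),
  ML_NoService = as.integer(level == "NoService")
)

# The dummies sum to 1 for every row, duplicating the intercept,
# so glm() reports NA for the last (aliased) one
fit <- glm(y ~ ., data = toy, family = binomial)
coef(fit)

# summary() (and therefore broom::tidy()) omits the aliased coefficient
nrow(summary(fit)$coefficients)
```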

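Prediction on the test set

#simulated churn outcomes (0/1): one row per posterior draw, one column per test customer
test_pred <- posterior_predict(my_model, newdata = test)
#summary of the first draw: the share of test customers simulated to churn
summary(test_pred[1,])
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##  0.0000  0.0000  0.0000  0.2631  1.0000  1.0000
#mean churn probability of each customer = share of draws in which that customer churns
test_mean <- colMeans(test_pred)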

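Confusion matrix

Rows are the observed Churn values (0 = stayed, 1 = churned). Columns show whether a customer's mean predicted churn probability exceeds 0.5.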

table(test$Churn, test_mean > 0.5)
##    
##     FALSE TRUE
##   0  1428  169
##   1   227  289
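The confusion matrix can be condensed into the usual classification metrics. The sketch below computes them from the counts in the table above.

```r
# Counts copied from the confusion matrix above
# (rows: observed Churn, columns: predicted churn at the 0.5 threshold)
cm <- matrix(c(1428, 227, 169, 289), nrow = 2,
             dimnames = list(observed = c("0", "1"), predicted = c("FALSE", "TRUE")))

tp <- cm["1", "TRUE"];  fn <- cm["1", "FALSE"]
fp <- cm["0", "TRUE"];  tn <- cm["0", "FALSE"]

metrics <- c(
  accuracy    = (tp + tn) / sum(cm),
  recall      = tp / (tp + fn),   # share of actual churners that were flagged
  precision   = tp / (tp + fp),   # share of flagged customers who actually churned
  specificity = tn / (tn + fp)    # share of staying customers correctly left unflagged
)
round(metrics, 3)
```

With these counts, accuracy is about 0.81, but only about 56% of the customers who actually churned are flagged. Lowering the 0.5 threshold would trade precision for recall.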

How are the mean churn probabilities of the test customers distributed?

#Distribution of mean probability of each customer in test set
hist(test_mean,
     main = "Distribution of the mean probability each customer \nin test set will churn")

plot(density(test_mean),
     main = "Density curve showing the mean probability each customer \nin test set will churn")

As in the training set, customers are more likely to stay than to leave the service. Identifying and retaining the customers most likely to churn is still important to maximise revenue.