library(rstanarm)
library(ggplot2)
library(bayesplot)
library(broom)
options(dplyr.summarise.inform = FALSE)
# Load the preprocessed churn dataset from the working directory
data <- read.csv(file = "churn_data.csv", header = TRUE)
# Drop `X`, an index column that carries no information
data[["X"]] <- NULL
Dataset is from Kaggle: https://www.kaggle.com/blastchar/telco-customer-churn
The data has previously been preprocessed in another exercise: https://github.com/jiajianwoo/telco_churn/blob/main/telco%20churn.ipynb
# Convert every column except positions 39-41 to a factor.
# NOTE(review): presumably positions 39-41 are the numeric columns
# (tenure, MonthlyCharges, TotalCharges) -- confirm against the CSV layout.
numeric_positions <- c(39, 40, 41)
factor_cols <- names(data)[-numeric_positions]
data[factor_cols] <- lapply(data[factor_cols], factor)

# Train-test split: a fixed seed keeps the 70% training sample reproducible
set.seed(123)
n_rows <- nrow(data)
idx <- sort(sample(n_rows, n_rows * 0.7))
train <- data[idx, ]
test <- data[-idx, ]
The logistic regression model was built using this code: model <- stan_glm(Churn ~ ., data = train, family = binomial(link = "logit"), seed = 123, chains = 4)
# Restore the previously fitted model from disk instead of refitting it
model_path <- "stan_churn.rds"
my_model <- readRDS(model_path)
# Print posterior estimates, fit diagnostics and MCMC diagnostics
summary(my_model)
##
## Model Info:
## function: stan_glm
## family: binomial [logit]
## formula: Churn ~ .
## algorithm: sampling
## sample: 4000 (posterior sample size)
## priors: see help('prior_summary')
## observations: 4930
## predictors: 41
##
## Estimates:
## mean sd 10% 50% 90%
## (Intercept) 0.8 11.2 -13.7 1.0 15.0
## gender1 0.0 0.1 -0.1 0.0 0.1
## SeniorCitizen1 0.2 0.1 0.0 0.2 0.3
## Partner1 -0.1 0.1 -0.2 -0.1 0.0
## Dependents1 -0.2 0.1 -0.3 -0.2 0.0
## PhoneService1 0.1 6.2 -7.8 0.0 8.0
## PaperlessBilling1 0.4 0.1 0.2 0.4 0.5
## MultipleLines_No1 -0.2 3.3 -4.5 -0.2 4.1
## MultipleLines_No.phone.service1 -0.1 6.1 -7.8 -0.1 7.7
## MultipleLines_Yes1 0.1 3.4 -4.2 0.1 4.5
## InternetService_DSL1 -0.6 3.6 -5.2 -0.5 4.1
## InternetService_Fiber.optic1 0.9 3.6 -3.8 0.9 5.6
## InternetService_No1 -0.2 6.0 -8.0 -0.2 7.4
## OnlineSecurity_No1 0.2 3.6 -4.4 0.3 4.8
## OnlineSecurity_No.internet.service1 -0.3 5.7 -7.6 -0.3 7.1
## OnlineSecurity_Yes1 -0.1 3.6 -4.7 0.0 4.5
## OnlineBackup_No1 0.1 3.6 -4.4 0.2 4.8
## OnlineBackup_No.internet.service1 -0.3 5.9 -7.7 -0.3 7.2
## OnlineBackup_Yes1 0.1 3.6 -4.5 0.1 4.7
## DeviceProtection_No1 0.1 3.5 -4.4 0.0 4.7
## DeviceProtection_No.internet.service1 -0.1 5.8 -7.5 0.0 7.4
## DeviceProtection_Yes1 0.1 3.5 -4.3 0.1 4.7
## TechSupport_No1 0.2 3.6 -4.3 0.3 4.8
## TechSupport_No.internet.service1 -0.2 5.8 -7.7 -0.2 7.4
## TechSupport_Yes1 0.0 3.6 -4.6 0.0 4.6
## StreamingTV_No1 -0.1 3.6 -4.9 -0.1 4.5
## StreamingTV_No.internet.service1 -0.4 5.5 -7.5 -0.4 6.7
## StreamingTV_Yes1 0.4 3.6 -4.3 0.4 5.0
## StreamingMovies_No1 -0.3 3.5 -4.8 -0.2 4.2
## StreamingMovies_No.internet.service1 -0.3 5.8 -7.7 -0.3 7.2
## StreamingMovies_Yes1 0.3 3.6 -4.3 0.3 4.8
## Contract_Month.to.month1 0.6 3.2 -3.6 0.7 4.7
## Contract_One.year1 -0.1 3.2 -4.3 0.0 4.0
## Contract_Two.year1 -0.9 3.2 -5.2 -0.8 3.2
## PaymentMethod_Bank.transfer..automatic.1 -0.1 2.9 -4.0 -0.1 3.6
## PaymentMethod_Credit.card..automatic.1 -0.2 2.9 -4.1 -0.2 3.5
## PaymentMethod_Electronic.check1 0.2 2.9 -3.7 0.2 3.9
## PaymentMethod_Mailed.check1 -0.2 2.9 -4.1 -0.2 3.5
## tenure -0.1 0.0 -0.1 -0.1 0.0
## MonthlyCharges 0.0 0.0 -0.1 0.0 0.0
## TotalCharges 0.0 0.0 0.0 0.0 0.0
##
## Fit Diagnostics:
## mean sd 10% 50% 90%
## mean_PPD 0.3 0.0 0.3 0.3 0.3
##
## The mean_ppd is the sample average posterior predictive distribution of the outcome variable (for details see help('summary.stanreg')).
##
## MCMC diagnostics
## mcse Rhat n_eff
## (Intercept) 0.2 1.0 4365
## gender1 0.0 1.0 8097
## SeniorCitizen1 0.0 1.0 6954
## Partner1 0.0 1.0 6140
## Dependents1 0.0 1.0 6728
## PhoneService1 0.1 1.0 5552
## PaperlessBilling1 0.0 1.0 6716
## MultipleLines_No1 0.1 1.0 4314
## MultipleLines_No.phone.service1 0.1 1.0 4988
## MultipleLines_Yes1 0.1 1.0 4367
## InternetService_DSL1 0.1 1.0 4455
## InternetService_Fiber.optic1 0.1 1.0 4468
## InternetService_No1 0.1 1.0 8242
## OnlineSecurity_No1 0.1 1.0 4690
## OnlineSecurity_No.internet.service1 0.1 1.0 7421
## OnlineSecurity_Yes1 0.1 1.0 4695
## OnlineBackup_No1 0.1 1.0 4520
## OnlineBackup_No.internet.service1 0.1 1.0 7061
## OnlineBackup_Yes1 0.1 1.0 4516
## DeviceProtection_No1 0.0 1.0 5343
## DeviceProtection_No.internet.service1 0.1 1.0 7580
## DeviceProtection_Yes1 0.0 1.0 5291
## TechSupport_No1 0.1 1.0 4546
## TechSupport_No.internet.service1 0.1 1.0 8880
## TechSupport_Yes1 0.1 1.0 4524
## StreamingTV_No1 0.1 1.0 4560
## StreamingTV_No.internet.service1 0.1 1.0 7217
## StreamingTV_Yes1 0.1 1.0 4630
## StreamingMovies_No1 0.0 1.0 5024
## StreamingMovies_No.internet.service1 0.1 1.0 8134
## StreamingMovies_Yes1 0.0 1.0 5066
## Contract_Month.to.month1 0.1 1.0 3437
## Contract_One.year1 0.1 1.0 3435
## Contract_Two.year1 0.1 1.0 3443
## PaymentMethod_Bank.transfer..automatic.1 0.1 1.0 2593
## PaymentMethod_Credit.card..automatic.1 0.1 1.0 2606
## PaymentMethod_Electronic.check1 0.1 1.0 2604
## PaymentMethod_Mailed.check1 0.1 1.0 2599
## tenure 0.0 1.0 4654
## MonthlyCharges 0.0 1.0 8521
## TotalCharges 0.0 1.0 4560
## mean_PPD 0.0 1.0 3853
## log-posterior 0.1 1.0 1569
##
## For each parameter, mcse is Monte Carlo standard error, n_eff is a crude measure of effective sample size, and Rhat is the potential scale reduction factor on split chains (at convergence Rhat=1).
Rhat < 1.1 for every parameter, so the MCMC chains appear to have converged.
# Show the priors used when the model was fitted
# (both the specified and the adjusted scales)
prior_summary(object = my_model)
## Priors for model 'my_model'
## ------
## Intercept (after predictors centered)
## ~ normal(location = 0, scale = 2.5)
##
## Coefficients
## Specified prior:
## ~ normal(location = [0,0,0,...], scale = [2.5,2.5,2.5,...])
## Adjusted prior:
## ~ normal(location = [0,0,0,...], scale = [5.00,6.72,5.00,...])
## ------
## See help('prior_summary.stanreg') for more details
By default, the model used weakly informative normal priors for all coefficients, and the prior scale of each coefficient was adjusted according to the standard deviation of the corresponding predictor.
# Trace plots ----
# Extract the posterior draws once. The array is reused by every trace plot
# below and by later plots that read `chains`; re-extracting it before each
# plot was redundant work.
chains <- as.array(my_model)

# Intercept and the binary demographic / billing indicators
mcmc_trace(
  chains,
  pars = c(
    "(Intercept)", "gender1", "SeniorCitizen1", "Partner1",
    "Dependents1", "PaperlessBilling1"
  )
)

# Phone service, multiple lines, internet service, start of online security
mcmc_trace(
  chains,
  pars = c(
    "PhoneService1", "MultipleLines_No1", "MultipleLines_No.phone.service1",
    "MultipleLines_Yes1", "InternetService_DSL1",
    "InternetService_Fiber.optic1", "InternetService_No1",
    "OnlineSecurity_No1", "OnlineSecurity_No.internet.service1"
  )
)

# Remaining online security level and online backup
mcmc_trace(
  chains,
  pars = c(
    "OnlineSecurity_Yes1", "OnlineBackup_No1",
    "OnlineBackup_No.internet.service1", "OnlineBackup_Yes1"
  )
)

# Device protection, tech support and streaming TV
mcmc_trace(
  chains,
  pars = c(
    "DeviceProtection_No1", "DeviceProtection_No.internet.service1",
    "DeviceProtection_Yes1", "TechSupport_No1",
    "TechSupport_No.internet.service1", "TechSupport_Yes1",
    "StreamingTV_No1", "StreamingTV_No.internet.service1",
    "StreamingTV_Yes1"
  )
)

# Streaming movies and contract type
mcmc_trace(
  chains,
  pars = c(
    "StreamingMovies_No1", "StreamingMovies_No.internet.service1",
    "StreamingMovies_Yes1", "Contract_Month.to.month1",
    "Contract_One.year1", "Contract_Two.year1"
  )
)

# Payment method and the numeric predictors
mcmc_trace(
  chains,
  pars = c(
    "PaymentMethod_Bank.transfer..automatic.1",
    "PaymentMethod_Credit.card..automatic.1",
    "PaymentMethod_Electronic.check1", "PaymentMethod_Mailed.check1",
    "tenure", "MonthlyCharges", "TotalCharges"
  )
)
The chains are stationary and have converged.
Plot the ratio of effective sample size to total posterior sample size for each parameter.
# Ratio of effective sample size to total posterior sample size, per parameter
plot(my_model, plotfun = "neff")
Check each chain's autocorrelation plot for the parameters with the smallest effective sample sizes.
# Autocorrelation bar plots, one per PaymentMethod coefficient (these have the
# lowest n_eff among the coefficients in the model summary)
acf_pars <- c(
  "PaymentMethod_Bank.transfer..automatic.1",
  "PaymentMethod_Credit.card..automatic.1",
  "PaymentMethod_Electronic.check1",
  "PaymentMethod_Mailed.check1"
)
for (acf_par in acf_pars) {
  # Inside a loop the plot object must be printed explicitly to be drawn
  print(plot(my_model, "acf_bar", pars = acf_par))
}
# Central 90% posterior credible interval (5% and 95% quantiles) for every
# parameter; 0.9 is the default probability mass, spelled out here
posterior_interval(my_model, prob = 0.9)
## 5% 95%
## (Intercept) -1.761988e+01 18.8848840227
## gender1 -1.520748e-01 0.1043979439
## SeniorCitizen1 6.953939e-03 0.3332686049
## Partner1 -2.139346e-01 0.0788384355
## Dependents1 -3.586040e-01 -0.0092779570
## PhoneService1 -9.985130e+00 10.3180839552
## PaperlessBilling1 2.145312e-01 0.5050588808
## MultipleLines_No1 -5.672103e+00 5.1714394375
## MultipleLines_No.phone.service1 -1.023252e+01 10.1007777213
## MultipleLines_Yes1 -5.404277e+00 5.5258822701
## InternetService_DSL1 -6.627327e+00 5.3744729927
## InternetService_Fiber.optic1 -5.240770e+00 6.9747135249
## InternetService_No1 -1.004219e+01 9.8478133590
## OnlineSecurity_No1 -5.744289e+00 6.1417179354
## OnlineSecurity_No.internet.service1 -9.538161e+00 8.9695882925
## OnlineSecurity_Yes1 -6.032107e+00 5.8260121074
## OnlineBackup_No1 -5.787875e+00 5.9764303260
## OnlineBackup_No.internet.service1 -1.000031e+01 9.3502093965
## OnlineBackup_Yes1 -5.812522e+00 5.9340801724
## DeviceProtection_No1 -5.644141e+00 6.0126701048
## DeviceProtection_No.internet.service1 -9.647152e+00 9.5277668929
## DeviceProtection_Yes1 -5.636103e+00 5.9851759990
## TechSupport_No1 -5.563119e+00 6.0969826041
## TechSupport_No.internet.service1 -9.666904e+00 9.6120238455
## TechSupport_Yes1 -5.961297e+00 5.7973525552
## StreamingTV_No1 -6.101489e+00 5.7328332986
## StreamingTV_No.internet.service1 -9.359116e+00 8.6138058706
## StreamingTV_Yes1 -5.662560e+00 6.2652277642
## StreamingMovies_No1 -6.151946e+00 5.5307482049
## StreamingMovies_No.internet.service1 -9.899906e+00 9.2596518226
## StreamingMovies_Yes1 -5.622371e+00 6.1039437236
## Contract_Month.to.month1 -4.840971e+00 5.8633447986
## Contract_One.year1 -5.532907e+00 5.2244609588
## Contract_Two.year1 -6.422275e+00 4.4187688655
## PaymentMethod_Bank.transfer..automatic.1 -4.940747e+00 4.6740863038
## PaymentMethod_Credit.card..automatic.1 -5.094611e+00 4.5548561912
## PaymentMethod_Electronic.check1 -4.679101e+00 4.9691557363
## PaymentMethod_Mailed.check1 -5.039166e+00 4.6047794920
## tenure -6.840312e-02 -0.0449753565
## MonthlyCharges -8.454801e-02 0.0285303590
## TotalCharges 1.663618e-04 0.0004335374
# Posterior density plots ----
# Draw the posterior densities of the given parameters from `chains`, shading
# the central 95% interval and marking the posterior mean.
plot_posterior_areas <- function(pars) {
  mcmc_areas(chains, pars = pars, prob = 0.95, point_est = "mean")
}

# Binary demographic / billing indicators
plot_posterior_areas(
  c("gender1", "SeniorCitizen1", "Partner1", "Dependents1", "PaperlessBilling1")
)

# Intercept, phone service, multiple lines and internet service
plot_posterior_areas(c(
  "(Intercept)", "PhoneService1", "MultipleLines_No1",
  "MultipleLines_No.phone.service1", "MultipleLines_Yes1",
  "InternetService_DSL1", "InternetService_Fiber.optic1",
  "InternetService_No1"
))

# Online security and online backup
plot_posterior_areas(c(
  "OnlineSecurity_No1", "OnlineSecurity_No.internet.service1",
  "OnlineSecurity_Yes1", "OnlineBackup_No1",
  "OnlineBackup_No.internet.service1", "OnlineBackup_Yes1"
))

# Device protection, tech support and streaming TV
plot_posterior_areas(c(
  "DeviceProtection_No1", "DeviceProtection_No.internet.service1",
  "DeviceProtection_Yes1", "TechSupport_No1",
  "TechSupport_No.internet.service1", "TechSupport_Yes1",
  "StreamingTV_No1", "StreamingTV_No.internet.service1", "StreamingTV_Yes1"
))

# Streaming movies and contract type
plot_posterior_areas(c(
  "StreamingMovies_No1", "StreamingMovies_No.internet.service1",
  "StreamingMovies_Yes1", "Contract_Month.to.month1",
  "Contract_One.year1", "Contract_Two.year1"
))

# Payment method
plot_posterior_areas(c(
  "PaymentMethod_Bank.transfer..automatic.1",
  "PaymentMethod_Credit.card..automatic.1",
  "PaymentMethod_Electronic.check1", "PaymentMethod_Mailed.check1"
))

# Numeric predictors; TotalCharges gets its own plot
plot_posterior_areas(c("tenure", "MonthlyCharges"))
plot_posterior_areas("TotalCharges")
# Prediction ----
# Posterior draws of the expected churn probability for every training row.
# The result is a matrix with one row per posterior draw and one column per
# observation. posterior_epred() is the replacement rstanarm recommends for
# the deprecated posterior_linpred(..., transform = TRUE) and provides
# equivalent functionality.
predictions <- posterior_epred(my_model)
# Summary of the posterior churn probability for row 1 of the training set.
# `predictions` holds draws in rows and observations in columns, so a single
# customer is a column; `predictions[1, ]` would instead summarise one
# posterior draw across all customers.
summary(predictions[, 1])
# NOTE(review): the output recorded below was produced with `predictions[1,]`
# (one draw across all customers); re-run to refresh it.
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 0.00243 0.04848 0.19355 0.26749 0.46563 0.85723
This summarises the posterior distribution of the predicted churn probability for row 1 in the training set.
We can plot and see the posterior churn probability distribution for the customer corresponding to the first row.
# Posterior distribution of the churn probability for row 1 of the training set.
# Observations are the columns of `predictions`, so customer 1 is column 1.
# The column is named explicitly rather than relying on the name data.frame()
# would derive from the indexing expression.
row1_draws <- data.frame(churn_prob = predictions[, 1])
ggplot(row1_draws, aes(x = churn_prob)) +
  geom_density()
# Posterior predictive check: overlay the densities of replicated outcomes
# (y_rep) on the density of the observed outcome (y)
pp_check(my_model, plotfun = "dens_overlay")
The y (dark blue) line represents the distribution of the observed outcome in the training set; the y_rep (light blue) lines represent the distributions of outcomes simulated from the posterior predictive distribution.
It seems like everyone has a very similar probability distribution to churn. It is more likely that a customer would stay and keep using the service provided.
# Posterior predictive check on two test statistics at once: each replicated
# dataset is plotted as a point and compared with the observed data
pp_check(my_model, plotfun = "stat_2d")
#distribution of churn probability predicted for train set
Compare it to the frequentist model:
# Maximum-likelihood logistic regression on the same training data, used as
# the frequentist counterpart of the Bayesian model
frequentist <- glm(
  Churn ~ .,
  family = binomial(link = "logit"),
  data = train
)
# Coefficient table (estimate, std.error, statistic, p.value) as a tibble
tidy(frequentist)
## # A tibble: 24 x 5
## term estimate std.error statistic p.value
## <chr> <dbl> <dbl> <dbl> <dbl>
## 1 (Intercept) -2.14 0.298 -7.19 6.70e-13
## 2 gender1 -0.0247 0.0768 -0.322 7.48e- 1
## 3 SeniorCitizen1 0.171 0.100 1.70 8.92e- 2
## 4 Partner1 -0.0694 0.0910 -0.763 4.46e- 1
## 5 Dependents1 -0.178 0.106 -1.68 9.27e- 2
## 6 PhoneService1 0.489 0.957 0.510 6.10e- 1
## 7 PaperlessBilling1 0.355 0.0883 4.02 5.71e- 5
## 8 MultipleLines_No1 -0.383 0.210 -1.82 6.81e- 2
## 9 InternetService_DSL1 2.43 2.46 0.986 3.24e- 1
## 10 InternetService_Fiber.optic1 4.03 3.41 1.18 2.37e- 1
## # … with 14 more rows
# Average fitted churn probability over the training set
train_probs <- predict(frequentist, type = "response")
mean(train_probs)
## [1] 0.2744422
The average customer churn rate is quite similar with both approaches.
# Draw from the posterior predictive distribution for the held-out customers.
# The result contains simulated 0/1 churn outcomes, with one row per posterior
# draw and one column per row of `test`. The former `transform = TRUE` was
# dropped: posterior_predict() has no such argument, so it had no effect.
test_pred <- posterior_predict(my_model, newdata = test)
# Summary of the first posterior draw across all test customers; its mean is
# the share of customers simulated to churn in that draw.
summary(test_pred[1, ])
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 0.0000 0.0000 0.0000 0.2631 1.0000 1.0000
# Per-customer churn probability: average the simulated 0/1 outcomes over
# the posterior draws (columns of `test_pred` are customers)
test_mean <- colMeans(test_pred)
# Confusion table: observed churn (rows) against predicted churn at a 0.5
# probability cut-off (columns)
table(test[["Churn"]], test_mean > 0.5)
##
## FALSE TRUE
## 0 1428 169
## 1 227 289
# Histogram of the per-customer mean churn probability in the test set
hist(
  x = test_mean,
  main = "Distribution of the mean probability each customer \nin test set will churn"
)
# Kernel density estimate of the same per-customer probabilities
test_density <- density(test_mean)
plot(
  test_density,
  main = "Density curve showing the mean probability each customer \nin test set will churn"
)
Just like in the training set, customers are more likely to continue than to leave the service, but retaining the customers who are more likely to cut ties is still important to maximise revenue.