knitr::opts_chunk$set(cache = TRUE)
suppressPackageStartupMessages(library(tidyverse))
suppressPackageStartupMessages(library(tidymodels))
suppressPackageStartupMessages(library(Metrics))
suppressPackageStartupMessages(library(ggpubr))
suppressPackageStartupMessages(library(janitor))
suppressPackageStartupMessages(library(vip))
theme_set(theme_pubr()) # Set a global theme for all plots
Predicting Customer Churn
Introduction
The aim of this project is to build a model that predicts, with a high degree of accuracy, whether a customer will churn. We will also explore which factors increase the probability that a customer churns.
We have customer data from a telecom company, including whether each customer churned. We will use this data to build several classification models and select the one that best predicts churn.
Loading the required packages and importing the data
customer_churn <- read_csv("data/customer churn.csv") |>
# clean column names
janitor::clean_names()
Rows: 1428 Columns: 14
── Column specification ────────────────────────────────────────────────────────
Delimiter: ","
dbl (14): Call Failure, Complains, Subscription Length, Charge Amount, Se...
ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
glimpse(customer_churn)
Rows: 1,428
Columns: 14
$ call_failure <dbl> 8, 0, 10, 10, 3, 11, 4, 13, 7, 7, 6, 9, 25, 4,…
$ complains <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ subscription_length <dbl> 38, 39, 37, 38, 38, 38, 38, 37, 38, 38, 38, 38…
$ charge_amount <dbl> 0, 0, 0, 0, 0, 1, 0, 2, 0, 1, 0, 0, 3, 1, 0, 1…
$ seconds_of_use <dbl> 4370, 318, 2453, 4198, 2393, 3775, 2360, 9115,…
$ frequency_of_use <dbl> 71, 5, 60, 66, 58, 82, 39, 121, 169, 83, 95, 5…
$ frequency_of_sms <dbl> 5, 7, 359, 1, 2, 32, 285, 144, 0, 2, 7, 8, 54,…
$ distinct_called_numbers <dbl> 17, 4, 24, 35, 33, 28, 18, 43, 44, 25, 12, 17,…
$ age_group <dbl> 3, 2, 3, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3…
$ tariff_plan <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…
$ status <dbl> 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1…
$ age <dbl> 30, 25, 30, 15, 15, 30, 30, 30, 30, 30, 30, 30…
$ customer_value <dbl> 197.640, 46.035, 1536.520, 240.020, 145.805, 2…
$ churn <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
Data Dictionary
| Column Name | Description |
|---|---|
| call_failure | number of call failures |
| complains | binary (0: no complaint, 1: complaint) |
| subscription_length | total months of subscription |
| charge_amount | ordinal attribute (0: lowest amount, 9: highest amount) |
| seconds_of_use | total seconds of calls |
| frequency_of_use | total number of calls |
| frequency_of_sms | total number of text messages |
| distinct_called_numbers | total number of distinct phone numbers called |
| age_group | ordinal attribute (1: younger age, 5: older age) |
| tariff_plan | binary (1: pay as you go, 2: contractual) |
| status | binary (1: active, 2: non-active) |
| age | age of the customer |
| customer_value | the calculated value of the customer |
| churn | class label (1: churn, 0: non-churn) |
Exploratory Data Analysis
In this dataset, our variable of interest is churn. In this section we will explore the relationship between this variable and the other features in the dataset.
First, let's convert the churn variable from numeric codes to a more descriptive factor.
customer_churn <- customer_churn |>
mutate(
churn = factor(ifelse(churn == 1, "churn", "non-churn"))
)
glimpse(customer_churn)
Rows: 1,428
Columns: 14
$ call_failure <dbl> 8, 0, 10, 10, 3, 11, 4, 13, 7, 7, 6, 9, 25, 4,…
$ complains <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ subscription_length <dbl> 38, 39, 37, 38, 38, 38, 38, 37, 38, 38, 38, 38…
$ charge_amount <dbl> 0, 0, 0, 0, 0, 1, 0, 2, 0, 1, 0, 0, 3, 1, 0, 1…
$ seconds_of_use <dbl> 4370, 318, 2453, 4198, 2393, 3775, 2360, 9115,…
$ frequency_of_use <dbl> 71, 5, 60, 66, 58, 82, 39, 121, 169, 83, 95, 5…
$ frequency_of_sms <dbl> 5, 7, 359, 1, 2, 32, 285, 144, 0, 2, 7, 8, 54,…
$ distinct_called_numbers <dbl> 17, 4, 24, 35, 33, 28, 18, 43, 44, 25, 12, 17,…
$ age_group <dbl> 3, 2, 3, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3…
$ tariff_plan <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…
$ status <dbl> 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1…
$ age <dbl> 30, 25, 30, 15, 15, 30, 30, 30, 30, 30, 30, 30…
$ customer_value <dbl> 197.640, 46.035, 1536.520, 240.020, 145.805, 2…
$ churn <fct> non-churn, non-churn, non-churn, non-churn, no…
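Because factor() sorts levels alphabetically, "churn" becomes the first level of the new factor. That ordering matters later: yardstick metrics such as roc_curve() treat the first level as the event by default (controlled by their event_level argument), which is why .pred_churn is the probability column passed to them. A small base R check on a few made-up codes, with an explicit-levels alternative that does not rely on alphabetical order:

```r
# Made-up churn codes (1 = churn, 0 = non-churn), recoded as in the chunk above
churn_codes <- c(0, 1, 0, 0, 1)
churn <- factor(ifelse(churn_codes == 1, "churn", "non-churn"))

# Levels are sorted alphabetically, so "churn" comes first
levels(churn)

# Setting the levels explicitly gives the same factor without relying on sorting
churn_explicit <- factor(ifelse(churn_codes == 1, "churn", "non-churn"),
                         levels = c("churn", "non-churn"))
identical(churn, churn_explicit)
```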
Customer Churn and Subscription Length
Subscription length is the number of months the customer has been with the company.
Customers who have been with the company for a short period of time are more likely to churn than customers who have been with the company for a long time.
In particular, the graph suggests that customers who have been with the company for less than a year are more likely to churn.
Why is the company losing new customers?
customer_churn |>
ggplot(aes(x = subscription_length, color = churn, fill = churn)) +
geom_density(alpha = 0.7) +
fill_palette("jco") +
color_palette("jco") +
labs(
title = "Churn vs Customer Subscription Length",
x = "Customer Subscription Length in Months",
y = "Density"
) +
theme_minimal() +
theme(
legend.title = element_blank(),
legend.position = "top"
)
Customer Value and Customer Churn
Customer value is the calculated value of the customer. Customers with a low value churn more than customers with a high calculated value.
customer_churn |>
ggplot(aes(x = customer_value, color = churn, fill = churn))+
geom_density(alpha = 0.7) +
fill_palette("jco") +
color_palette("jco") +
labs(
x = "Customer Value",
y = "Density"
) +
theme_minimal() +
theme(
legend.title = element_blank(),
legend.position = "top"
)
Churn & Customer Age
Most of the telco's customers are in their 30s. There is not much we can say about the relationship between age and churn, except that younger customers tend to churn more.
customer_churn |>
ggplot(aes(x = age, color = churn, fill = churn))+
geom_histogram(alpha = 0.7, binwidth = 15, position = "dodge") +
fill_palette("jco") +
color_palette("jco") +
labs(
x = "Customer age",
y = "Count"
) +
theme_minimal() +
theme(
legend.title = element_blank(),
legend.position = "top"
)
Frequency of Use, Seconds of Use & Customer Churn
Customers who make fewer calls are more likely to churn than customers who make calls more frequently.
customer_churn |>
ggplot(aes(x = frequency_of_use, y = seconds_of_use, color = churn, shape = churn)) +
geom_point(alpha = 0.7, size = 2) +
color_palette("jco") +
labs(
title = "Frequency of Use, Seconds of Use & Customer Churn",
x = "Total Number of Calls",
y = "Total Seconds of Calls"
) +
theme_minimal() +
theme(
legend.title = element_blank(),
legend.position = "top"
)
Model Training
We will build several classification models and select the one that does the best job of predicting which customers are likely to churn. We would also like to understand which variables influence customer churn.
The first step is preparing the data for modeling. Variables such as age_group, tariff_plan, and status are coded as numbers, but they are categorical. We need to recode these variables properly.
customer_churn <- customer_churn |>
mutate(
#age group as factor
age_group = factor(age_group,
ordered = TRUE,
levels = c("1", "2", "3", "4", "5")),
charge_amount = factor(charge_amount, ordered = TRUE,
levels = c("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")),
tariff_plan = case_when(
tariff_plan == 1 ~ "Pay As You Go",
tariff_plan == 2 ~ "Contractual"
),
status = ifelse(status == 1, "active", "non active")
)
customer_churn
# A tibble: 1,428 × 14
call_failure complains subscription_length charge_amount seconds_of_use
<dbl> <dbl> <dbl> <ord> <dbl>
1 8 0 38 0 4370
2 0 0 39 0 318
3 10 0 37 0 2453
4 10 0 38 0 4198
5 3 0 38 0 2393
6 11 0 38 1 3775
7 4 0 38 0 2360
8 13 0 37 2 9115
9 7 0 38 0 13773
10 7 0 38 1 4515
# ℹ 1,418 more rows
# ℹ 9 more variables: frequency_of_use <dbl>, frequency_of_sms <dbl>,
# distinct_called_numbers <dbl>, age_group <ord>, tariff_plan <chr>,
# status <chr>, age <dbl>, customer_value <dbl>, churn <fct>
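Declaring the levels explicitly, as done for age_group and charge_amount above, keeps every category (even ones absent from the data) and fixes their order. The flip side is that any code outside the declared levels silently becomes NA, so it is worth checking for missing values after recoding. A small base R illustration with made-up codes:

```r
# A few made-up charge amount codes (the full scale runs from 0 to 9)
charge_codes <- c(0, 0, 1, 2, 3)

charge_amount <- factor(charge_codes, ordered = TRUE,
                        levels = as.character(0:9))

# All ten declared levels are kept, even those not present in the data
nlevels(charge_amount)

# Ordered factors support comparisons that follow the level order
charge_amount[1] < charge_amount[4]

# A code outside the declared levels becomes NA without any warning
out_of_range <- factor(10, ordered = TRUE, levels = as.character(0:9))
is.na(out_of_range)
```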
Data Splitting
We will split our data into training and testing sets, and further split the training data into cross-validation folds for model selection.
The dataset has a class imbalance: there are many more non-churn customers than churn customers. Churned customers make up 16% of the data.
customer_churn |>
count(churn) |>
mutate(
prop = scales::percent(n / sum(n))
)
# A tibble: 2 × 3
churn n prop
<fct> <int> <chr>
1 churn 233 16%
2 non-churn 1195 84%
To ensure the training and testing sets have the same proportions of churn and non-churn customers, we will use stratified sampling.
We will reserve 20% of the data for model testing.
set.seed(123) # For reproducibility
customer_split <- initial_split(na.omit(customer_churn), strata = churn, prop = 0.8)
customer_train <- training(customer_split)
customer_test <- testing(customer_split)
Both the training and testing sets have the same 16% churn proportion.
customer_train |>
count(churn) |>
mutate(
prop = scales::percent(n / sum(n))
)
# A tibble: 2 × 3
churn n prop
<fct> <int> <chr>
1 churn 186 16%
2 non-churn 955 84%
customer_test |>
count(churn) |>
mutate(
prop = scales::percent(n / sum(n))
)
# A tibble: 2 × 3
churn n prop
<fct> <int> <chr>
1 churn 47 16%
2 non-churn 239 84%
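The idea behind strata = churn is to sample the same proportion within each class. The base R sketch below is a simplification for illustration, not how rsample implements it; it uses the class counts implied by the outputs above (233 churn and 1,194 non-churn rows after na.omit()) and, with this rounding, yields the same 186/955 and 47/239 counts:

```r
# Base R sketch of a stratified split (an illustration, not rsample's algorithm)
stratified_split <- function(y, prop = 0.8) {
  # Group row indices by class, then sample the same proportion within each class
  by_class <- split(seq_along(y), y)
  train_idx <- unlist(lapply(by_class, function(idx) {
    idx[sample.int(length(idx), size = round(prop * length(idx)))]
  }), use.names = FALSE)
  sort(train_idx)
}

# Class counts taken from the split outputs above
y <- rep(c("churn", "non-churn"), times = c(233, 1194))

set.seed(123)
train_idx <- stratified_split(y)
train_y <- y[train_idx]
test_y  <- y[-train_idx]

table(train_y)
table(test_y)
```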
We then create cross-validation folds for model evaluation and selection.
set.seed(234)
customer_cv <- vfold_cv(customer_train, strata = churn)
customer_cv
# 10-fold cross-validation using stratification
# A tibble: 10 × 2
splits id
<list> <chr>
1 <split [1026/115]> Fold01
2 <split [1026/115]> Fold02
3 <split [1026/115]> Fold03
4 <split [1026/115]> Fold04
5 <split [1026/115]> Fold05
6 <split [1027/114]> Fold06
7 <split [1028/113]> Fold07
8 <split [1028/113]> Fold08
9 <split [1028/113]> Fold09
10 <split [1028/113]> Fold10
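Stratified folds follow the same idea: shuffle the rows within each class and deal them into the folds in turn. The sketch below is again a simplified illustration rather than rsample's algorithm; with the training counts above (186 churn, 955 non-churn) it produces held-out fold sizes of 115, 114 and 113, and every fold keeps roughly 16% churners:

```r
# Deal the shuffled rows of each class into v folds in turn
stratified_folds <- function(y, v = 10) {
  folds <- integer(length(y))
  for (cls in unique(y)) {
    idx <- which(y == cls)
    idx <- idx[sample.int(length(idx))]             # shuffle rows within the class
    folds[idx] <- rep_len(seq_len(v), length(idx))  # assign folds 1, 2, ..., v, 1, 2, ...
  }
  folds
}

# Training-set class counts from the output above
y_train <- rep(c("churn", "non-churn"), times = c(186, 955))

set.seed(234)
fold_id <- stratified_folds(y_train)

# Held-out (assessment) size of each fold
fold_sizes <- as.vector(table(fold_id))
fold_sizes

# Share of churners in each held-out fold
churn_share <- as.vector(tapply(y_train == "churn", fold_id, mean))
round(churn_share, 3)
```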
Logistic Regression Model
The first model is a penalized logistic regression. With mixture = 1, the glmnet engine applies a pure lasso penalty, and we will tune the strength of that penalty.
logistic_model <- logistic_reg(penalty = tune(), mixture = 1) |>
set_engine("glmnet")
logistic_model
Logistic Regression Model Specification (classification)
Main Arguments:
penalty = tune()
mixture = 1
Computational engine: glmnet
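For reference, the glmnet engine fits the logistic model by minimizing a penalized negative log-likelihood, with penalty playing the role of λ and mixture the role of α (here y_i is 1 for churn and 0 otherwise). With mixture = 1 the ridge term vanishes and only the lasso (L1) term remains, which can shrink coefficients exactly to zero; because the penalty acts on coefficient sizes, it is also sensitive to the scale of each predictor.

```latex
\min_{\beta_0,\,\beta}\;
-\frac{1}{N}\sum_{i=1}^{N}\Big[y_i\,(\beta_0 + x_i^\top\beta) - \log\big(1 + e^{\beta_0 + x_i^\top\beta}\big)\Big]
+ \lambda\Big[\tfrac{1-\alpha}{2}\,\lVert\beta\rVert_2^2 + \alpha\,\lVert\beta\rVert_1\Big],
\qquad \alpha = 1 \;\Rightarrow\; \text{penalty term} = \lambda\,\lVert\beta\rVert_1
```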
We need to define several preprocessing steps for the logistic regression model. The glmnet engine needs an all-numeric predictor matrix, so character and factor variables must be converted into dummy variables. The numeric predictors also need to be normalized, because the lasso penalty depends on the scale of each coefficient. Finally, we remove zero-variance predictors, i.e. columns that take only a single value.
logistic_recipe <-
recipe(churn ~ ., data = customer_train) |>
step_dummy(all_nominal_predictors()) |>
step_zv(all_predictors()) |>
step_normalize(all_numeric_predictors())
logistic_recipe
── Recipe ──────────────────────────────────────────────────────────────────────
── Inputs
Number of variables by role
outcome: 1
predictor: 13
── Operations
• Dummy variables from: all_nominal_predictors()
• Zero variance filter on: all_predictors()
• Centering and scaling for: all_numeric_predictors()
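The three recipe steps do roughly what the following base R sketch does on a small made-up table, in the same order: model.matrix() creates indicator columns, a variance check drops constant columns, and scale() centers and scales. Unlike the recipe, this sketch does not store the training means and standard deviations for reuse on new data; the workflow handles that when the recipe is estimated on the training set.

```r
# Made-up data with a categorical, a numeric and a constant column
toy <- data.frame(
  tariff_plan    = factor(c("Pay As You Go", "Contractual", "Pay As You Go", "Pay As You Go")),
  seconds_of_use = c(4370, 318, 2453, 4198),
  constant_col   = c(1, 1, 1, 1)
)

# 1. Dummy variables: expand the factor into indicator columns, drop the intercept
X <- model.matrix(~ ., data = toy)[, -1, drop = FALSE]

# 2. Zero-variance filter: drop columns that take a single value
X <- X[, apply(X, 2, var) > 0, drop = FALSE]

# 3. Normalize: center each column to mean 0 and scale to standard deviation 1
X <- scale(X)

colnames(X)
```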
We combine the model specification and the recipe in a workflow.
logistic_wkflow <- workflow() |>
add_model(logistic_model) |>
add_recipe(logistic_recipe)
logistic_wkflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: logistic_reg()
── Preprocessor ────────────────────────────────────────────────────────────────
3 Recipe Steps
• step_dummy()
• step_zv()
• step_normalize()
── Model ───────────────────────────────────────────────────────────────────────
Logistic Regression Model Specification (classification)
Main Arguments:
penalty = tune()
mixture = 1
Computational engine: glmnet
Penalty grid
lr_reg_grid <- tibble(penalty = 10^seq(-4, -1, length.out = 30))
lr_reg_grid
# A tibble: 30 × 1
penalty
<dbl>
1 0.0001
2 0.000127
3 0.000161
4 0.000204
5 0.000259
6 0.000329
7 0.000418
8 0.000530
9 0.000672
10 0.000853
# ℹ 20 more rows
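The grid is evenly spaced on the log10 scale, so each value is a constant multiple (10^(3/29), about 1.27) of the previous one, and small penalties get as many candidates as large ones. A quick check that reproduces the first values printed above:

```r
# 30 penalty values evenly spaced on the log10 scale between 1e-4 and 1e-1
penalty <- 10^seq(-4, -1, length.out = 30)

# Consecutive values differ by the same factor
ratios <- penalty[-1] / penalty[-length(penalty)]
range(ratios)

# First three values, rounded as in the printed grid
signif(penalty[1:3], 3)
```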
logistic_fit <- logistic_wkflow |>
tune_grid(
grid = lr_reg_grid,
resamples = customer_cv,
metrics = metric_set(roc_auc),
control = control_grid(save_pred = TRUE)
)
logistic_fit
# Tuning results
# 10-fold cross-validation using stratification
# A tibble: 10 × 5
splits id .metrics .notes .predictions
<list> <chr> <list> <list> <list>
1 <split [1026/115]> Fold01 <tibble [30 × 5]> <tibble [0 × 3]> <tibble>
2 <split [1026/115]> Fold02 <tibble [30 × 5]> <tibble [0 × 3]> <tibble>
3 <split [1026/115]> Fold03 <tibble [30 × 5]> <tibble [0 × 3]> <tibble>
4 <split [1026/115]> Fold04 <tibble [30 × 5]> <tibble [0 × 3]> <tibble>
5 <split [1026/115]> Fold05 <tibble [30 × 5]> <tibble [0 × 3]> <tibble>
6 <split [1027/114]> Fold06 <tibble [30 × 5]> <tibble [0 × 3]> <tibble>
7 <split [1028/113]> Fold07 <tibble [30 × 5]> <tibble [0 × 3]> <tibble>
8 <split [1028/113]> Fold08 <tibble [30 × 5]> <tibble [0 × 3]> <tibble>
9 <split [1028/113]> Fold09 <tibble [30 × 5]> <tibble [0 × 3]> <tibble>
10 <split [1028/113]> Fold10 <tibble [30 × 5]> <tibble [0 × 3]> <tibble>
collect_metrics(logistic_fit)
# A tibble: 30 × 7
penalty .metric .estimator mean n std_err .config
<dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
1 0.0001 roc_auc binary 0.938 10 0.00776 Preprocessor1_Model01
2 0.000127 roc_auc binary 0.938 10 0.00762 Preprocessor1_Model02
3 0.000161 roc_auc binary 0.939 10 0.00755 Preprocessor1_Model03
4 0.000204 roc_auc binary 0.939 10 0.00758 Preprocessor1_Model04
5 0.000259 roc_auc binary 0.939 10 0.00762 Preprocessor1_Model05
6 0.000329 roc_auc binary 0.939 10 0.00756 Preprocessor1_Model06
7 0.000418 roc_auc binary 0.939 10 0.00749 Preprocessor1_Model07
8 0.000530 roc_auc binary 0.938 10 0.00741 Preprocessor1_Model08
9 0.000672 roc_auc binary 0.938 10 0.00722 Preprocessor1_Model09
10 0.000853 roc_auc binary 0.938 10 0.00723 Preprocessor1_Model10
# ℹ 20 more rows
Visualizing Logistic Regression metrics
collect_metrics(logistic_fit) |>
ggplot(aes(x = penalty, y = mean)) +
geom_point() +
geom_line() +
scale_x_log10(labels = scales::label_number()) +
labs(x = "Penalty", y = "Mean ROC AUC") +
theme_minimal()
Which is the best model based on ROC AUC?
best_model <- logistic_fit |>
show_best(metric = "roc_auc") |> # top 5 candidates by mean ROC AUC
arrange(penalty) |>
# the last row has the largest penalty among them, i.e. the simplest model
slice(5)
best_model
# A tibble: 1 × 7
penalty .metric .estimator mean n std_err .config
<dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
1 0.000418 roc_auc binary 0.939 10 0.00749 Preprocessor1_Model07
A model with fewer irrelevant predictors is simpler and easier to interpret. Since these candidates perform about the same, we prefer the one with the higher penalty, which shrinks more coefficients to zero.
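The selection rule can be written out in base R. This sketch uses only the ten rows printed by collect_metrics() above, with their rounded means, so it illustrates the rule rather than re-running show_best() on all 30 candidates; on these rows it picks the same penalty, 0.000418. tune also offers select_by_one_std_err(), a related rule that picks the simplest model within one standard error of the best one.

```r
# First ten rows of collect_metrics(logistic_fit) as printed above (rounded means)
lr_metrics <- data.frame(
  penalty = c(0.0001, 0.000127, 0.000161, 0.000204, 0.000259,
              0.000329, 0.000418, 0.000530, 0.000672, 0.000853),
  mean    = c(0.938, 0.938, 0.939, 0.939, 0.939,
              0.939, 0.939, 0.938, 0.938, 0.938)
)

# Keep the five best candidates by mean ROC AUC (show_best() returns 5 rows by default)
top5 <- head(lr_metrics[order(-lr_metrics$mean), ], 5)

# Among them, pick the largest penalty, i.e. the simplest model
chosen_penalty <- max(top5$penalty)
chosen_penalty
```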
Pulling predictions for the best model
logistic_fit |>
collect_predictions(parameters = best_model)
# A tibble: 1,141 × 7
id .pred_churn `.pred_non-churn` .row penalty churn .config
<chr> <dbl> <dbl> <int> <dbl> <fct> <chr>
1 Fold01 0.275 0.725 1 0.000418 churn Preprocessor1_Mode…
2 Fold01 0.179 0.821 21 0.000418 churn Preprocessor1_Mode…
3 Fold01 0.989 0.0115 36 0.000418 churn Preprocessor1_Mode…
4 Fold01 0.991 0.00882 42 0.000418 churn Preprocessor1_Mode…
5 Fold01 0.985 0.0151 50 0.000418 churn Preprocessor1_Mode…
6 Fold01 0.897 0.103 57 0.000418 churn Preprocessor1_Mode…
7 Fold01 0.981 0.0188 81 0.000418 churn Preprocessor1_Mode…
8 Fold01 0.422 0.578 95 0.000418 churn Preprocessor1_Mode…
9 Fold01 0.984 0.0164 105 0.000418 churn Preprocessor1_Mode…
10 Fold01 0.437 0.563 128 0.000418 churn Preprocessor1_Mode…
# ℹ 1,131 more rows
ROC curve for the best logistic regression model
best_logistic_model <- logistic_fit |>
collect_predictions(parameters = best_model) |>
roc_curve(churn, .pred_churn) |>
mutate(
model = "Logistic Regression"
)
best_logistic_model
# A tibble: 1,136 × 4
.threshold specificity sensitivity model
<dbl> <dbl> <dbl> <chr>
1 -Inf 0 1 Logistic Regression
2 9.93e-7 0 1 Logistic Regression
3 1.51e-6 0.00105 1 Logistic Regression
4 1.67e-6 0.00209 1 Logistic Regression
5 2.35e-6 0.00314 1 Logistic Regression
6 2.61e-6 0.00419 1 Logistic Regression
7 3.61e-6 0.00524 1 Logistic Regression
8 3.89e-6 0.00628 1 Logistic Regression
9 4.20e-6 0.00733 1 Logistic Regression
10 4.88e-6 0.00838 1 Logistic Regression
# ℹ 1,126 more rows
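roc_curve() sweeps a classification threshold over the predicted churn probabilities and records sensitivity and specificity at each one; ROC AUC is the area under that curve. A minimal base R sketch on made-up scores (yardstick's own handling of thresholds and ties may differ in detail) computes the curve and its area, and checks it against the equivalent rank-based reading of AUC: the probability that a random churner gets a higher score than a random non-churner.

```r
# Made-up predicted churn probabilities for 3 churners and 4 non-churners
score    <- c(0.9, 0.8, 0.4, 0.7, 0.3, 0.2, 0.1)
is_churn <- c(TRUE, TRUE, TRUE, FALSE, FALSE, FALSE, FALSE)

# Sensitivity and specificity when predicting churn for every score >= threshold
roc_points <- function(score, is_event) {
  thresholds <- c(-Inf, sort(unique(score)), Inf)
  data.frame(
    threshold   = thresholds,
    sensitivity = sapply(thresholds, function(t) mean(score[is_event] >= t)),
    specificity = sapply(thresholds, function(t) mean(score[!is_event] < t))
  )
}

roc <- roc_points(score, is_churn)

# Area under the curve via the trapezoidal rule on (1 - specificity, sensitivity)
x <- 1 - roc$specificity
y <- roc$sensitivity
o <- order(x, y)
auc_trapezoid <- sum(diff(x[o]) * (head(y[o], -1) + tail(y[o], -1)) / 2)

# Same quantity from ranks (Mann-Whitney form)
n1 <- sum(is_churn)
n0 <- sum(!is_churn)
auc_rank <- (sum(rank(score)[is_churn]) - n1 * (n1 + 1) / 2) / (n1 * n0)

c(trapezoid = auc_trapezoid, rank = auc_rank)
```

Both numbers come out as 11/12 here: 11 of the 12 churner/non-churner pairs are ranked correctly.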
best_logistic_model |>
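autoplot()
Random Forest Model
The second model is a random forest.
We will tune three hyperparameters:
trees: the number of trees in the ensemble
mtry: the number of predictors randomly sampled as split candidates at each split
min_n: the minimum number of data points in a node required for the node to be split further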
rf_model <- rand_forest(
mtry = tune(),
trees = tune(),
min_n = tune()
) |>
set_engine("ranger") |>
set_mode("classification")
rf_model
Random Forest Model Specification (classification)
Main Arguments:
mtry = tune()
trees = tune()
min_n = tune()
Computational engine: ranger
Next, we create a new workflow for the random forest.
Unlike logistic regression, a random forest does not require categorical variables to be converted into dummy variables or numeric predictors to be scaled.
Instead of a recipe, we will therefore use a formula in the workflow and fit the training data as is.
rf_wkflow <- workflow() |>
add_model(rf_model) |>
add_formula(churn ~ .)
rf_wkflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Formula
Model: rand_forest()
── Preprocessor ────────────────────────────────────────────────────────────────
churn ~ .
── Model ───────────────────────────────────────────────────────────────────────
Random Forest Model Specification (classification)
Main Arguments:
mtry = tune()
trees = tune()
min_n = tune()
Computational engine: ranger
Tuning the random forest over a grid of 10 automatically generated candidate combinations
set.seed(345)
rf_fit <- rf_wkflow |>
tune_grid(
grid = 10,
resamples = customer_cv,
control = control_grid(save_pred = TRUE),
metrics = metric_set(roc_auc)
)
i Creating pre-processing data to finalize unknown parameter: mtry
rf_fit
# Tuning results
# 10-fold cross-validation using stratification
# A tibble: 10 × 5
splits id .metrics .notes .predictions
<list> <chr> <list> <list> <list>
1 <split [1026/115]> Fold01 <tibble [10 × 7]> <tibble [0 × 3]> <tibble>
2 <split [1026/115]> Fold02 <tibble [10 × 7]> <tibble [0 × 3]> <tibble>
3 <split [1026/115]> Fold03 <tibble [10 × 7]> <tibble [0 × 3]> <tibble>
4 <split [1026/115]> Fold04 <tibble [10 × 7]> <tibble [0 × 3]> <tibble>
5 <split [1026/115]> Fold05 <tibble [10 × 7]> <tibble [0 × 3]> <tibble>
6 <split [1027/114]> Fold06 <tibble [10 × 7]> <tibble [0 × 3]> <tibble>
7 <split [1028/113]> Fold07 <tibble [10 × 7]> <tibble [0 × 3]> <tibble>
8 <split [1028/113]> Fold08 <tibble [10 × 7]> <tibble [0 × 3]> <tibble>
9 <split [1028/113]> Fold09 <tibble [10 × 7]> <tibble [0 × 3]> <tibble>
10 <split [1028/113]> Fold10 <tibble [10 × 7]> <tibble [0 × 3]> <tibble>
collect_metrics(rf_fit)
# A tibble: 10 × 9
mtry trees min_n .metric .estimator mean n std_err .config
<int> <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
1 11 79 3 roc_auc binary 0.976 10 0.00310 Preprocessor1_Model…
2 2 1531 35 roc_auc binary 0.971 10 0.00390 Preprocessor1_Model…
3 11 899 27 roc_auc binary 0.971 10 0.00333 Preprocessor1_Model…
4 12 370 31 roc_auc binary 0.969 10 0.00338 Preprocessor1_Model…
5 9 1944 36 roc_auc binary 0.970 10 0.00320 Preprocessor1_Model…
6 8 594 13 roc_auc binary 0.977 10 0.00296 Preprocessor1_Model…
7 4 1742 16 roc_auc binary 0.978 10 0.00322 Preprocessor1_Model…
8 5 1210 18 roc_auc binary 0.977 10 0.00291 Preprocessor1_Model…
9 7 1084 24 roc_auc binary 0.975 10 0.00303 Preprocessor1_Model…
10 1 798 6 roc_auc binary 0.970 10 0.00510 Preprocessor1_Model…
rf_best_model_parameters <- rf_fit |>
select_best(metric = "roc_auc") |>
slice(1)
rf_best_model_parameters
# A tibble: 1 × 4
mtry trees min_n .config
<int> <int> <int> <chr>
1 4 1742 16 Preprocessor1_Model07
best_rf_model <- rf_fit |>
collect_predictions(parameters = rf_best_model_parameters) |>
roc_curve(churn, .pred_churn) |>
mutate(
model = "Random Forest"
)
best_rf_model
# A tibble: 1,090 × 4
.threshold specificity sensitivity model
<dbl> <dbl> <dbl> <chr>
1 -Inf 0 1 Random Forest
2 0 0 1 Random Forest
3 0.0000359 0.0346 1 Random Forest
4 0.0000410 0.0356 1 Random Forest
5 0.0000442 0.0387 1 Random Forest
6 0.0000478 0.0408 1 Random Forest
7 0.0000522 0.0429 1 Random Forest
8 0.0000574 0.0471 1 Random Forest
9 0.0000638 0.0492 1 Random Forest
10 0.0000765 0.0503 1 Random Forest
# ℹ 1,080 more rows
autoplot(best_rf_model)
best_logistic_model |>
bind_rows(best_rf_model) |>
ggplot(aes(x = 1 - specificity, y = sensitivity, col = model)) +
geom_path(lwd = 1.5, alpha = 0.8) +
geom_abline(lty = 3) +
coord_equal() +
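scale_color_viridis_d(option = "plasma", end = .6)
Random forest performs better than logistic regression: its best cross-validated ROC AUC (0.978) is higher than that of the best logistic regression model (0.939).
Last Fit
We have established that the best model is the random forest with the following hyperparameters: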
mtry: 4
trees: 1742
min_n: 16
last_rf_model <- rand_forest( mtry = 4, trees = 1742, min_n = 16) |>
set_engine("ranger", importance = "impurity") |>
set_mode("classification")
last_rf_model
Random Forest Model Specification (classification)
Main Arguments:
mtry = 4
trees = 1742
min_n = 16
Engine-Specific Arguments:
importance = impurity
Computational engine: ranger
last_rf_wkflow <- rf_wkflow |>
update_model(last_rf_model)
last_rf_wkflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Formula
Model: rand_forest()
── Preprocessor ────────────────────────────────────────────────────────────────
churn ~ .
── Model ───────────────────────────────────────────────────────────────────────
Random Forest Model Specification (classification)
Main Arguments:
mtry = 4
trees = 1742
min_n = 16
Engine-Specific Arguments:
importance = impurity
Computational engine: ranger
Fitting the final model
rf_last_fit <- last_rf_wkflow |>
last_fit(split = customer_split)
rf_last_fit
# Resampling results
# Manual resampling
# A tibble: 1 × 6
splits id .metrics .notes .predictions .workflow
<list> <chr> <list> <list> <list> <list>
1 <split [1141/286]> train/test split <tibble> <tibble> <tibble> <workflow>
collect_metrics(rf_last_fit)
# A tibble: 2 × 4
.metric .estimator .estimate .config
<chr> <chr> <dbl> <chr>
1 accuracy binary 0.944 Preprocessor1_Model1
2 roc_auc binary 0.987 Preprocessor1_Model1
collect_predictions(rf_last_fit) |>
roc_curve(churn, .pred_churn) |>
autoplot()
Confusion Matrix
collect_predictions(rf_last_fit) |>
conf_mat(churn, .pred_class)
           Truth
Prediction churn non-churn
churn 38 7
non-churn 9 232
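The model performed well on unseen data (the test set), with an accuracy of 0.944 and a ROC AUC of 0.987.
Of the 47 customers in the test set who churned, the model correctly identified 38; it missed 9 churners and flagged 7 non-churners as churners.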
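The test-set accuracy of 0.944 can be recovered from this confusion matrix. Because only about 16% of customers churn, accuracy alone can flatter a model, so it helps to also look at sensitivity (the share of churners caught) and at the accuracy of a trivial model that always predicts non-churn. The arithmetic, using the counts above:

```r
# Confusion matrix counts from the test-set output above
tp <- 38   # churners predicted as churn
fn <- 9    # churners predicted as non-churn
fp <- 7    # non-churners predicted as churn
tn <- 232  # non-churners predicted as non-churn

total <- tp + tn + fp + fn

accuracy    <- (tp + tn) / total
sensitivity <- tp / (tp + fn)  # share of churners that were caught
specificity <- tn / (tn + fp)  # share of non-churners correctly predicted
precision   <- tp / (tp + fp)  # share of churn predictions that were right

# Accuracy of a trivial model that always predicts "non-churn"
baseline_accuracy <- (tn + fp) / total

round(c(accuracy = accuracy, sensitivity = sensitivity,
        specificity = specificity, precision = precision,
        baseline = baseline_accuracy), 3)
```

On these counts sensitivity is about 0.81, and the always-non-churn baseline would already reach about 0.84 accuracy, which puts the 0.944 in context.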
Which variables are most important in predicting churned customers?
rf_last_fit |>
extract_fit_parsnip() |>
vi() |>
ggplot(aes(x = Importance, y = fct_reorder(Variable, Importance))) +
geom_col(fill = "#2297FA") +
labs(
y = "",
title = "Variable Importance Plot") +
theme_minimal()
The most important predictor was complains, which records whether a customer had complained.
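With importance = "impurity", ranger scores each variable by how much the splits on it reduce node impurity (the Gini index for classification), accumulated over all trees. A minimal sketch of the impurity decrease for a single split, on made-up data, shows why a variable that separates churners cleanly, like a complaint flag, can score highly; ranger's exact bookkeeping may differ from this simplified version.

```r
# Gini impurity of a set of class labels: 1 minus the sum of squared class proportions
gini <- function(y) {
  p <- prop.table(table(y))
  1 - sum(p^2)
}

# Impurity decrease of a split: parent impurity minus size-weighted child impurities
impurity_decrease <- function(y, goes_left) {
  w_left <- mean(goes_left)
  gini(y) - w_left * gini(y[goes_left]) - (1 - w_left) * gini(y[!goes_left])
}

# Made-up node: 3 churners and 5 non-churners, split on a complaint flag
y <- c("churn", "churn", "churn",
       "non-churn", "non-churn", "non-churn", "non-churn", "non-churn")
complains <- c(1, 1, 0, 0, 0, 0, 0, 0)

gini(y)
impurity_decrease(y, complains == 1)
```

Here the parent node has impurity 0.469, the "complained" child is pure, and the split reduces impurity by 25/96 (about 0.26).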