Predicting a qualitative response for an observation can be referred to as classifying that observation, since it involves assigning the observation to a category or class. Logistic regression is a supervised classification algorithm used to predict a dependent variable that is categorical or discrete; it models the probability of class membership using the sigmoid (logistic) function. Churned customers are those who have decided to end their relationship with their existing company. In our case study, we will be working on a churn dataset.
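To make the sigmoid idea concrete, here is a small illustration (not part of the case study code): the logistic function squashes any linear predictor onto the 0-1 range, so its output can be read as a probability such as P(churn = 1).
# Illustrative only: the logistic (sigmoid) function behind logistic regression
sigmoid <- function(x) 1 / (1 + exp(-x))
round(sigmoid(c(-4, 0, 4)), 3)## [1] 0.018 0.500 0.982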
XYZ is a service-providing company that provides customers with a one-year subscription plan for their product. The company wants to know if the customers will renew the subscription for the coming year or not.
Build a logistic regression model on the given dataset to determine whether the customer will churn or not.
Tech stack
* Language - R
* Libraries - tidymodels, ranger (random forests), xgboost
source("00_Scripts/plot_cor.R")## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.2 ──
## ✔ stringr 1.4.1 ✔ forcats 0.5.2
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ foreach::accumulate() masks purrr::accumulate()
## ✖ lubridate::as.difftime() masks base::as.difftime()
## ✖ readr::col_factor() masks scales::col_factor()
## ✖ lubridate::date() masks base::date()
## ✖ purrr::discard() masks scales::discard()
## ✖ magrittr::extract() masks tidyr::extract()
## ✖ plotly::filter() masks dplyr::filter(), stats::filter()
## ✖ xts::first() masks dplyr::first()
## ✖ stringr::fixed() masks recipes::fixed()
## ✖ lubridate::intersect() masks base::intersect()
## ✖ dplyr::lag() masks stats::lag()
## ✖ xts::last() masks dplyr::last()
## ✖ magrittr::set_names() masks purrr::set_names()
## ✖ lubridate::setdiff() masks base::setdiff()
## ✖ readr::spec() masks yardstick::spec()
## ✖ lubridate::union() masks base::union()
## ✖ foreach::when() masks purrr::when()
source("00_Scripts/plot_hist_facet.R")
source("00_Scripts/plot_ggpairs.R")## Registered S3 method overwritten by 'GGally':
## method from
## +.gg ggplot2
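The helper scripts themselves are not reproduced here. As a rough sketch (an assumption about their contents, not the actual 00_Scripts code), a faceted histogram helper such as plot_hist_facet() might look like this:
# Assumed sketch of a plot_hist_facet()-style helper: convert every column to numeric
# (characters/factors become integer codes) and facet one histogram per variable
plot_hist_facet <- function(data, bins = 30, fill = "#2C3E50") {
  data %>%
    mutate(across(where(is.character), as.factor)) %>%
    mutate(across(everything(), as.numeric)) %>%
    pivot_longer(cols = everything(), names_to = "variable", values_to = "value") %>%
    ggplot(aes(x = value)) +
    geom_histogram(bins = bins, fill = fill) +
    facet_wrap(~ variable, scales = "free") +
    theme_tq()
}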
We have been provided with data from XYZ, a service-providing company that offers customers a one-year subscription plan for their product. The company wants to know whether customers will renew the subscription for the coming year. We use readr to import this data for analysis. Our data has 16 variables and 2000 entries.
# Read in the data
data <- read_csv("00_data/data_regression.csv")## Rows: 2000 Columns: 16
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (4): phone_no, gender, multi_screen, mail_subscribed
## dbl (12): year, customer_id, age, no_of_days_subscribed, weekly_mins_watched...
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
The CSV consists of around 2000 rows and 16 columns.
Features:
1. Year
2. Customer_id - unique id
3. Phone_no - customer phone number
4. Gender - Male/Female
5. Age
6. No of days subscribed - the number of days since the subscription
7. Multi-screen - whether the customer has a single or multiple screen subscription
8. Mail subscription - whether the customer receives mails or not
9. Weekly mins watched - number of minutes watched weekly
10. Minimum daily mins - minimum minutes watched per day
11. Maximum daily mins - maximum minutes watched per day
12. Weekly nights max mins - number of minutes watched at night time
13. Videos watched - total number of videos watched
14. Maximum_days_inactive - maximum number of days the customer was inactive
15. Customer support calls - number of customer support calls
16. Churn - 1 = Yes, 0 = No
In addition to knowing the features in this dataset, it is important to inspect and clean the data. Cleaning involves, among other things, taking care of missing values and converting our dependent variable into a factor to help with analysis.
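Before dropping anything, it is worth a quick check of how much is actually missing (a small sketch using base R on the imported data):
# Count missing values per column to see what drop_na() will remove
colSums(is.na(data))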
At this stage let's drop rows with missing values and remove the customer id, phone number and year columns, since they are not important for our analysis. Our final dataset has 1918 records and 13 features.
churn_clean <- data %>%
 mutate(churn = factor(churn)) %>%
 drop_na() %>%
 select(-customer_id, -phone_no, -year)
print(head(churn_clean))## # A tibble: 6 × 13
## gender age no_of_d…¹ multi…² mail_…³ weekl…⁴ minim…⁵ maxim…⁶ weekl…⁷ video…⁸
## <chr> <dbl> <dbl> <chr> <chr> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 Female 36 62 no no 148. 12.2 16.8 82 1
## 2 Female 39 149 no no 294. 7.7 33.4 87 3
## 3 Female 65 126 no no 87.3 11.9 9.89 91 1
## 4 Female 24 131 no yes 321. 9.5 36.4 102 4
## 5 Female 40 191 no no 243 10.9 27.5 83 7
## 6 Male 61 205 no yes 264. 7.8 29.9 64 5
## # … with 3 more variables: maximum_days_inactive <dbl>,
## # customer_support_calls <dbl>, churn <fct>, and abbreviated variable names
## # ¹no_of_days_subscribed, ²multi_screen, ³mail_subscribed,
## # ⁴weekly_mins_watched, ⁵minimum_daily_mins, ⁶maximum_daily_mins,
## # ⁷weekly_max_night_mins, ⁸videos_watched
# Visualize data
To understand this data we use our histogram helper to get more insight into the remaining variables. From these histograms, it appears that gender, multi_screen, mail_subscribed and churn are categorical variables, with the rest of the features being numeric.
plot_hist_facet(churn_clean)
Let's use our correlation function to look at the correlations between our features:
plot_ggpairs( churn_clean)
This chart shows that there are three categorical (character) variables, namely gender, multi_screen and mail_subscribed. There are more males than females in our data set. Also, most respondents did not have a multi-screen subscription - only 194 of the 1918 did - and the majority of our clients did not subscribe to mail - only 547 of the 1918 did.
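These counts can be verified directly, for example:
# Tally the categorical variables mentioned above
churn_clean %>% count(gender)
churn_clean %>% count(multi_screen)
churn_clean %>% count(mail_subscribed)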
For the numeric data we see that age is skewed to the right, meaning there could be outliers at higher ages; the mean age is 38.6 and the median age is 37. The mean number of days subscribed is around 100 days and the median is 99, so the number of days subscribed looks roughly normally distributed (the mean and median are close). The weekly minutes watched is also roughly normally distributed, with a median of 269.9 minutes. The minimum daily minutes, maximum daily minutes and weekly maximum night minutes watched are similarly distributed, with medians of 10.20, 30.59 and 101 minutes respectively.
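The summary statistics quoted above can be reproduced with a quick summary() of the numeric columns:
# Five-number summaries plus means for the numeric features
churn_clean %>%
 select(where(is.numeric)) %>%
 summary()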
hist(churn_clean$videos_watched, main = "Histogram of Number of Videos Watched", col = "#32CD32", xlab = "Number of Videos")
Let's look at the maximum number of days of inactivity. It looks like the maximum number of days inactive is 3 days.
hist(churn_clean$maximum_days_inactive, main = "Histogram of Maximum Days of Inactivity", col = "#85FF00", xlab = "Number of Inactive Days")
Next, did the customer churn or not? It looks like most of our clients did not churn - only 253 clients churned - and therefore our data is imbalanced. We may need to address this class imbalance in the modelling section below.
plot(churn_clean$churn)
The following code looks at the class imbalance as counts and as proportions. I then use the second index from the class balance table, i.e. the share of people who churned, as an upsampling ratio for later use.
class_bal_table <- table(churn_clean$churn)
prop_tab <- prop.table(class_bal_table)
upsample_ratio <- (class_bal_table[2] / sum(class_bal_table))
print(prop_tab)##
## 0 1
## 0.8680918 0.1319082
print(class_bal_table)##
## 0 1
## 1665 253
It is always a good idea to observe the data types of the variables we are working with. I generally separate the variable names out into factors, integer/numeric columns and character vectors:
As described above we have one factor variable, churn, while age, no_of_days_subscribed, weekly_mins_watched, minimum_daily_mins, maximum_daily_mins, weekly_max_night_mins, videos_watched, maximum_days_inactive and customer_support_calls are numeric variables. Gender, multi_screen and mail_subscribed are character variables.
factors <- names(select_if(churn_clean, is.factor))
numbers <- names(select_if(churn_clean, is.numeric))
characters <- names(select_if(churn_clean, is.character))
print(factors)## [1] "churn"
print(numbers)## [1] "age" "no_of_days_subscribed" "weekly_mins_watched"
## [4] "minimum_daily_mins" "maximum_daily_mins" "weekly_max_night_mins"
## [7] "videos_watched" "maximum_days_inactive" "customer_support_calls"
print(characters)## [1] "gender" "multi_screen" "mail_subscribed"
We now start developing our machine learning model, beginning by dividing the data into a training and a test sample. This is the simplest way to assess our model's accuracy and future performance on unseen data: we treat the test data as unseen data, which lets us evaluate whether the model is fit to be released into the wild or not. We will use an 80/20 split.
# Partition into training and hold out test / validation sample
set.seed(123)
split <- rsample::initial_split(churn_clean, prop=0.8)
train_data <- rsample::training(split)
test_data <- rsample::testing(split)
Let us use the recipes package to prepare our data for the model. The first part of the recipe specifies the model formula and the training data; we then add preprocessing steps, much like adding the specific ingredients before baking. We will start by upsampling our data, given the class imbalance described above. We will then convert the categorical variables to dummy (0/1) variables, get rid of features that have zero variance using step_zv(), and finally scale and centre the data with step_normalize() so the model is trained on normalized inputs.
churn_rec <-
recipe(churn ~ ., data=train_data) %>%
themis::step_upsample(churn, over_ratio = as.numeric(upsample_ratio)) %>%
 #Upsample the minority class, i.e. those that churned
step_dummy(all_nominal(), -all_outcomes()) %>%
#Automatically created dummy variables for all categorical variables (nominal)
step_zv(all_predictors()) %>%
#Get rid of features that have zero variance
step_normalize(all_predictors()) #ML models train better when the data is centered and scaled
print(churn_rec) #Terminology is to use recipe## Recipe
##
## Inputs:
##
## role #variables
## outcome 1
## predictor 12
##
## Operations:
##
## Up-sampling based on churn
## Dummy variables from all_nominal(), -all_outcomes()
## Zero variance filter on all_predictors()
## Centering and scaling for all_predictors()
We then use the parsnip package to create a basic logistic regression as our baseline model. Logistic regression is a nice generalized linear model that most people have encountered, and it is very relevant to our case. Once we have a recipe and a parsnip model specification, we need to create a workflow structure for our modelling. We use the steps below:
In TidyModels you have to create an instance of the model in memory before working with it:
lr_mod <-
parsnip::logistic_reg() %>%
set_engine("glm")
print(lr_mod)## Logistic Regression Model Specification (classification)
##
## Computational engine: glm
The next step is to create the model workflow, connecting the newly instantiated model to the recipe:
# Create model workflow
churn_wf <-
workflow() %>%
add_model(lr_mod) %>%
add_recipe(churn_rec)
print(churn_wf)## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: logistic_reg()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 4 Recipe Steps
##
## • step_upsample()
## • step_dummy()
## • step_zv()
## • step_normalize()
##
## ── Model ───────────────────────────────────────────────────────────────────────
## Logistic Regression Model Specification (classification)
##
## Computational engine: glm
The next step is fitting the model to our data:
# Create the model fit
churn_fit <-
churn_wf %>%
 fit(data = train_data)
The final step is to use extract_fit_parsnip() to retrieve the fitted model from the workflow and tidy() its coefficient estimates:
churn_fitted <- churn_fit %>%
extract_fit_parsnip() %>%
tidy()
print(churn_fitted)## # A tibble: 13 × 5
## term estimate std.error statistic p.value
## <chr> <dbl> <dbl> <dbl> <dbl>
## 1 (Intercept) -2.29 0.100 -22.8 3.83e-115
## 2 age 0.146 0.0781 1.87 6.09e- 2
## 3 no_of_days_subscribed -0.0405 0.0841 -0.482 6.30e- 1
## 4 weekly_mins_watched 276. 270. 1.02 3.06e- 1
## 5 minimum_daily_mins 0.575 0.239 2.40 1.62e- 2
## 6 maximum_daily_mins -276. 270. -1.02 3.07e- 1
## 7 weekly_max_night_mins 0.0366 0.0849 0.431 6.67e- 1
## 8 videos_watched -0.117 0.0908 -1.29 1.97e- 1
## 9 maximum_days_inactive -0.276 0.240 -1.15 2.52e- 1
## 10 customer_support_calls 0.612 0.0764 8.01 1.12e- 15
## 11 gender_Male 0.105 0.0850 1.23 2.17e- 1
## 12 multi_screen_yes 0.520 0.0631 8.24 1.69e- 16
## 13 mail_subscribed_yes -0.341 0.0918 -3.72 2.01e- 4
As an optional step I have created a plot to visualise coefficient significance. This only works with linear and generalized linear models, which report a p value for each coefficient term. The visualisation code is contained hereunder:
# Add significance column to tibble using mutate
churn_fitted <- churn_fitted %>%
mutate(Significance = ifelse(p.value < 0.05, "Significant", "Insignificant")) %>%
arrange(desc(p.value))
#Create a ggplot object to visualise significance
plot <- churn_fitted %>%
 ggplot(mapping = aes(x=term, y=p.value, fill=Significance)) +
geom_col() + theme(axis.text.x = element_text(
face="bold", color="#0070BA",
size=8, angle=90)
) + labs(y="P value", x="Terms",
title="P value significance chart",
subtitle="A chart to represent the significant variables in the model",
caption="Produced by Erick Yegon 10th Dec 2022")
print("Creating plot of P values")## [1] "Creating plot of P values"
#print(plot) + theme_tq_green()
plotly::ggplotly(plot)
#print(ggplotly(plot))
#ggsave("Figures/p_val_plot.png", plot) #Save the plotNow we will assess how well the model predicts on the test (holdout) data to evaluate if we want to productionise the model, or abandon it at this stage. This is implemented below:
class_pred <- predict(churn_fit, test_data) #Get the class label predictions
prob_pred <- predict(churn_fit, test_data, type="prob") #Get the probability predictions
lr_predictions <- data.frame(class_pred, prob_pred) %>%
setNames(c("LR_Class", "LR_NotChurnedProb", "LR_ChurnedProb")) #Combined into tibble and rename
churned_preds <- test_data %>%
bind_cols(lr_predictions)
print(tail(lr_predictions))## LR_Class LR_NotChurnedProb LR_ChurnedProb
## 379 0 0.8421420 0.15785802
## 380 0 0.8862254 0.11377460
## 381 0 0.8418378 0.15816217
## 382 0 0.9419147 0.05808528
## 383 0 0.8348782 0.16512177
## 384 0 0.6377448 0.36225517
Let's use the yardstick tools in our TidyModels arsenal; they are useful for generating quick summary statistics and evaluation metrics. I will grab the ROC curve and the area under it to show how well the model fits:
roc_plot <-
churned_preds %>%
roc_curve(truth = churn, LR_NotChurnedProb) %>%
autoplot
roc_plot +
 theme_tq_green()
ROC plots are great, but they only show the trade-off between sensitivity (how well the model identifies churned customers) and specificity (how well it identifies customers who did not churn). It would also be useful to look at overall accuracy and balanced accuracy on a confusion matrix, as is standard for binomial classification problems.
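The plot does not print the AUC value itself; it can be pulled with yardstick's roc_auc() using the same truth and probability columns as above:
# Area under the ROC curve for the logistic regression predictions
churned_preds %>%
 yardstick::roc_auc(truth = churn, LR_NotChurnedProb)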
I use the caret package and its confusion matrix functions to do this:
cm <- caret::confusionMatrix(churned_preds$churn,
churned_preds$LR_Class)
print(cm)## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 324 9
## 1 40 11
##
## Accuracy : 0.8724
## 95% CI : (0.8348, 0.9041)
## No Information Rate : 0.9479
## P-Value [Acc > NIR] : 1
##
## Kappa : 0.254
##
## Mcnemar's Test P-Value : 1.822e-05
##
## Sensitivity : 0.8901
## Specificity : 0.5500
## Pos Pred Value : 0.9730
## Neg Pred Value : 0.2157
## Prevalence : 0.9479
## Detection Rate : 0.8438
## Detection Prevalence : 0.8672
## Balanced Accuracy : 0.7201
##
## 'Positive' Class : 0
##
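As a quick sanity check, the headline statistics above can be reproduced from the four cell counts (remembering that the 'positive' class here is 0, i.e. not churned):
# Recompute the caret statistics from the confusion matrix cells
tp <- 324; fn <- 40   # reference class 0 (the 'positive' class)
fp <- 9;  tn <- 11    # reference class 1
sens <- tp / (tp + fn)              # 324/364 = 0.8901 (Sensitivity)
spec <- tn / (tn + fp)              # 11/20   = 0.5500 (Specificity)
(sens + spec) / 2                   # 0.7201 (Balanced Accuracy)
(tp + tn) / (tp + fn + fp + tn)     # 335/384 = 0.8724 (Accuracy)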
Let's use the ConfusionTableR package to visualise the confusion matrix and to flatten binary and multi-class confusion matrix results.
library(ConfusionTableR)
ConfusionTableR::binary_visualiseR(
train_labels = churned_preds$churn,
truth_labels = churned_preds$LR_Class,
class_label1 = "Not Churned",
class_label2 = "Churned",
quadrant_col1 = "#28ACB4",
quadrant_col2 = "#4397D2",
custom_title = "Churned Customer Confusion Matrix",
text_col= "black",
cm_stat_size = 1.2,
round_dig = 2)
## Save the information in a database

As this is a binary classification problem, there is the potential to store the outputs of the model in a database. The ConfusionTableR package can do this for us: in a binary classification model we use binary_class_cm() to output the results to a list, from which the confusion matrix can be flattened into a single record. This works very much like the broom package does for linear regression outputs.
cm_binary_class <- ConfusionTableR::binary_class_cm(
train_labels = churned_preds$churn,
truth_labels = churned_preds$LR_Class)## [INFO] Building a record level confusion matrix to store in dataset
## [INFO] Build finished and to expose record level cm use the record_level_cm list item
# Expose the record level confusion matrix
glimpse(cm_binary_class$record_level_cm)## Rows: 1
## Columns: 23
## $ Pred_0_Ref_0 <int> 324
## $ Pred_1_Ref_0 <int> 40
## $ Pred_0_Ref_1 <int> 9
## $ Pred_1_Ref_1 <int> 11
## $ Accuracy <dbl> 0.8723958
## $ Kappa <dbl> 0.2540438
## $ AccuracyLower <dbl> 0.8348301
## $ AccuracyUpper <dbl> 0.9040866
## $ AccuracyNull <dbl> 0.9479167
## $ AccuracyPValue <dbl> 1
## $ McnemarPValue <dbl> 1.82153e-05
## $ Sensitivity <dbl> 0.8901099
## $ Specificity <dbl> 0.55
## $ Pos.Pred.Value <dbl> 0.972973
## $ Neg.Pred.Value <dbl> 0.2156863
## $ Precision <dbl> 0.972973
## $ Recall <dbl> 0.8901099
## $ F1 <dbl> 0.9296987
## $ Prevalence <dbl> 0.9479167
## $ Detection.Rate <dbl> 0.84375
## $ Detection.Prevalence <dbl> 0.8671875
## $ Balanced.Accuracy <dbl> 0.7200549
## $ cm_ts <dttm> 2022-12-13 17:18:31
We will next look at how to improve our models with model selection, K-fold cross validation and hyperparameter tuning.
Let's save our work so far for later use:
save.image(file="00_data/churned_data.rdata")
Now let's read back in what we saved:
load(file="00_data/churned_data.rdata")
The first step will involve cross validation. The essence of cross validation is that we take sub-samples of the training dataset to emulate how well the model will perform on unseen data samples when out in the wild (production).
The folds take a sample of the training set and each randomly selected fold acts as the test sample. We then use a final hold out validation set to finally test the model. This will be shown in the following section.
set.seed(123)
#Set a random seed for replication of results
ten_fold <- vfold_cv(train_data, v=10)
We will now fit the previously trained logistic regression workflow to these resamples to get cross-validated performance estimates:
set.seed(123)
lr_fit_rs <-
churn_wf %>%
 fit_resamples(ten_fold)
We will now collect the metrics using the tune package and the collect_metrics function:
# To collect the resamples you need to call collect_metrics to average out the accuracy for that model
collected_mets <- tune::collect_metrics(lr_fit_rs)
print(collected_mets)## # A tibble: 2 × 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 accuracy binary 0.870 10 0.00785 Preprocessor1_Model1
## 2 roc_auc binary 0.788 10 0.0164 Preprocessor1_Model1
# Now I can compare the accuracy from the previous test set I had already generated a confusion matrix for
accuracy_resamples <- collected_mets$mean[1] * 100
accuracy_validation_set <- as.numeric(cm$overall[1] * 100)
cat(paste0("The true accuracy of the model is between the resample testing: ",
 round(accuracy_resamples,2), "\nThe validation sample: ",
 round(accuracy_validation_set,2), "."))## The true accuracy of the model is between the resample testing: 86.97
## The validation sample: 87.24.
This shows that the true accuracy value is somewhere between the reported results from the resampling method and those in our validation sample.
Let us move on from the logistic regression and aim to build a random forest, and later a decision tree. Other options in parsnip would be to use a gradient boosted tree to amp up the results further.
The first step, as with the logistic regression example, is to define and instantiate the model:
rf_mod <-
rand_forest(trees=500) %>%
set_engine("ranger") %>%
set_mode("classification")
print(rf_mod)## Random Forest Model Specification (classification)
##
## Main Arguments:
## trees = 500
##
## Computational engine: ranger
Then we are going to fit the model to the previous training data:
rf_fit <-
rf_mod %>%
fit(churn ~ ., data = train_data)
print(rf_fit)## parsnip model object
##
## Ranger result
##
## Call:
## ranger::ranger(x = maybe_data_frame(x), y = y, num.trees = ~500, num.threads = 1, verbose = FALSE, seed = sample.int(10^5, 1), probability = TRUE)
##
## Type: Probability estimation
## Number of trees: 500
## Sample size: 1534
## Number of independent variables: 12
## Mtry: 3
## Target node size: 10
## Variable importance mode: none
## Splitrule: gini
## OOB prediction error (Brier s.): 0.06541508
We will get a more robust estimate of this model's performance by fitting it to the resamples object created earlier with rsample:
#Create workflow step
rf_wf <-
workflow() %>%
add_model(rf_mod) %>%
 add_formula(churn ~ .) #The model formula (outcome ~ predictors) is supplied via the add_formula method
set.seed(123)
rf_fit_rs <-
rf_wf %>%
fit_resamples(ten_fold)
print(rf_fit_rs)## # Resampling results
## # 10-fold cross-validation
## # A tibble: 10 × 4
## splits id .metrics .notes
## <list> <chr> <list> <list>
## 1 <split [1380/154]> Fold01 <tibble [2 × 4]> <tibble [0 × 3]>
## 2 <split [1380/154]> Fold02 <tibble [2 × 4]> <tibble [0 × 3]>
## 3 <split [1380/154]> Fold03 <tibble [2 × 4]> <tibble [0 × 3]>
## 4 <split [1380/154]> Fold04 <tibble [2 × 4]> <tibble [0 × 3]>
## 5 <split [1381/153]> Fold05 <tibble [2 × 4]> <tibble [0 × 3]>
## 6 <split [1381/153]> Fold06 <tibble [2 × 4]> <tibble [0 × 3]>
## 7 <split [1381/153]> Fold07 <tibble [2 × 4]> <tibble [0 × 3]>
## 8 <split [1381/153]> Fold08 <tibble [2 × 4]> <tibble [0 × 3]>
## 9 <split [1381/153]> Fold09 <tibble [2 × 4]> <tibble [0 × 3]>
## 10 <split [1381/153]> Fold10 <tibble [2 × 4]> <tibble [0 × 3]>
The next step is to collect the resample metrics:
# Collect the metrics using another model with resampling
rf_resample_mean_preds <- tune::collect_metrics(rf_fit_rs)
print(rf_resample_mean_preds)## # A tibble: 2 × 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 accuracy binary 0.922 10 0.00462 Preprocessor1_Model1
## 2 roc_auc binary 0.870 10 0.0116 Preprocessor1_Model1
We have improved our accuracy to about 92% with the random forest. The model's predictive power appears to max out at around 92%, most likely because this is dummy data and most of the features contained in the model have only a weak association with the outcome variable.
We are now going to create a decision tree and tune its hyperparameters using the dials package. The dials package contains a set of hyperparameter tuning helpers and is useful for creating quick hyperparameter grids and optimising over them.
Like all the other steps, the first thing to do is build the decision tree specification. Note - we use set_mode("classification") because the thing we are predicting is a factor. If it were a continuous variable, you would switch this to regression; the model development process for regression is otherwise identical to classification.
tune_tree <-
decision_tree(
cost_complexity = tune(), #tune() is a placeholder for an empty grid
tree_depth = tune() #we will fill these in the next section
) %>%
set_engine("rpart") %>%
set_mode("classification")
print(tune_tree)## Decision Tree Model Specification (classification)
##
## Main Arguments:
## cost_complexity = tune()
## tree_depth = tune()
##
## Computational engine: rpart
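As the note above says, switching to a continuous outcome would only require changing the mode. A sketch of the (hypothetical) regression counterpart:
# Hypothetical regression version of the same specification - only the mode changes
reg_tree <-
 decision_tree(
 cost_complexity = tune(),
 tree_depth = tune()
 ) %>%
 set_engine("rpart") %>%
 set_mode("regression")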
The next step is to fill in these blank values for cost complexity and tree depth - see the parsnip documentation for what these mean, but in short the cost complexity parameter penalises additional splits, while the tree depth controls how far down the tree can grow.
We will now create the object:
grid_tree_tune <- grid_regular(dials::cost_complexity(),
dials::tree_depth(),
levels = 10)
print(head(grid_tree_tune,20))## # A tibble: 20 × 2
## cost_complexity tree_depth
## <dbl> <int>
## 1 0.0000000001 1
## 2 0.000000001 1
## 3 0.00000001 1
## 4 0.0000001 1
## 5 0.000001 1
## 6 0.00001 1
## 7 0.0001 1
## 8 0.001 1
## 9 0.01 1
## 10 0.1 1
## 11 0.0000000001 2
## 12 0.000000001 2
## 13 0.00000001 2
## 14 0.0000001 2
## 15 0.000001 2
## 16 0.00001 2
## 17 0.0001 2
## 18 0.001 2
## 19 0.01 2
## 20 0.1 2
Tuning and model training usually benefit from using the full potential of your machine. The next steps show how to register the cores on your machine and use them in parallel for training the model and grid searching:
all_cores <- parallel::detectCores(logical = FALSE)-1
print(all_cores)## [1] 3
#Registers all cores and subtracts one, so you have some time to work
cl <- makePSOCKcluster(all_cores)
print(cl)## socket cluster with 3 nodes on host 'localhost'
#Makes an in memory cluster to utilise your cores
registerDoParallel(cl)
#Registers that we want to do parallel processing
Next, I will create the model workflow, as we have done a few times before:
set.seed(123)
tree_wf <- workflow() %>%
add_model(tune_tree) %>%
add_formula(churn ~ .)
# Make the decision tree workflow - always postfix with wf for convention
# Add the registered model
# Add the formula of the outcome class you are predicting against all IVs
tree_pred_tuned <-
tree_wf %>%
tune::tune_grid(
resamples = ten_fold, #This is the 10 fold cross validation variable we created earlier
grid = grid_tree_tune #This is the tuning grid
 )
This ggplot helps to visualise how the tuning has gone and shows where the best tree depth occurs across the range of cost complexity values:
tune_plot <- tree_pred_tuned %>%
collect_metrics() %>% #Collect metrics from tuning
mutate(tree_depth = factor(tree_depth)) %>%
ggplot(aes(cost_complexity, mean, color = tree_depth)) +
geom_line(size = 1, alpha = 0.7) +
geom_point(size = 1.5) +
facet_wrap(~ .metric, scales = "free", nrow = 2) +
scale_x_log10(labels = scales::label_number()) +
scale_color_viridis_d(option = "plasma", begin = .9, end = 0) + theme_tq()## Warning: Using `size` aesthetic for lines was deprecated in ggplot2 3.4.0.
## ℹ Please use `linewidth` instead.
print(tune_plot)
ggsave(filename="Figures/hyperparameter_tree.png", tune_plot)## Saving 7 x 5 in image
This shows that you only need a depth of 4 to get the optimal accuracy. However, the tune package helps us out with this as well.
The tune package allows us to select the best candidate model, i.e. the optimal set of hyperparameters:
# To get the best ROC - area under the curve value we will use the following:
tree_pred_tuned %>%
tune::show_best("roc_auc")## # A tibble: 5 × 8
## cost_complexity tree_depth .metric .estimator mean n std_err .config
## <dbl> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 0.0000000001 7 roc_auc binary 0.857 10 0.00960 Preprocesso…
## 2 0.000000001 7 roc_auc binary 0.857 10 0.00960 Preprocesso…
## 3 0.00000001 7 roc_auc binary 0.857 10 0.00960 Preprocesso…
## 4 0.0000001 7 roc_auc binary 0.857 10 0.00960 Preprocesso…
## 5 0.000001 7 roc_auc binary 0.857 10 0.00960 Preprocesso…
# Select the best tree
best_tree <- tree_pred_tuned %>%
tune::select_best("roc_auc")
print(best_tree)## # A tibble: 1 × 3
## cost_complexity tree_depth .config
## <dbl> <int> <chr>
## 1 0.0000000001 7 Preprocessor1_Model041
The next step is to use the best tree to make our predictions.
final_wf <-
tree_wf %>%
finalize_workflow(best_tree) #Finalise workflow passes in our best tree
print(final_wf)## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Formula
## Model: decision_tree()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## churn ~ .
##
## ── Model ───────────────────────────────────────────────────────────────────────
## Decision Tree Model Specification (classification)
##
## Main Arguments:
## cost_complexity = 1e-10
## tree_depth = 7
##
## Computational engine: rpart
Fit this finalised tree to the training data:
final_tree_pred <-
final_wf %>%
fit(data = train_data)
print(final_tree_pred)## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Formula
## Model: decision_tree()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## churn ~ .
##
## ── Model ───────────────────────────────────────────────────────────────────────
## n= 1534
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 1534 202 0 (0.86831812 0.13168188)
## 2) customer_support_calls< 3.5 1417 142 0 (0.89978829 0.10021171)
## 4) weekly_mins_watched< 385.425 1314 92 0 (0.92998478 0.07001522)
## 8) multi_screen=no 1200 50 0 (0.95833333 0.04166667)
## 16) weekly_mins_watched< 348.975 1098 31 0 (0.97176685 0.02823315) *
## 17) weekly_mins_watched>=348.975 102 19 0 (0.81372549 0.18627451)
## 34) mail_subscribed=yes 31 0 0 (1.00000000 0.00000000) *
## 35) mail_subscribed=no 71 19 0 (0.73239437 0.26760563)
## 70) minimum_daily_mins< 12.85 59 12 0 (0.79661017 0.20338983) *
## 71) minimum_daily_mins>=12.85 12 5 1 (0.41666667 0.58333333) *
## 9) multi_screen=yes 114 42 0 (0.63157895 0.36842105)
## 18) minimum_daily_mins< 13.05 94 22 0 (0.76595745 0.23404255)
## 36) videos_watched>=2.5 75 3 0 (0.96000000 0.04000000) *
## 37) videos_watched< 2.5 19 0 1 (0.00000000 1.00000000) *
## 19) minimum_daily_mins>=13.05 20 0 1 (0.00000000 1.00000000) *
## 5) weekly_mins_watched>=385.425 103 50 0 (0.51456311 0.48543689)
## 10) mail_subscribed=yes 33 4 0 (0.87878788 0.12121212) *
## 11) mail_subscribed=no 70 24 1 (0.34285714 0.65714286)
## 22) weekly_mins_watched< 454.2 60 24 1 (0.40000000 0.60000000)
## 44) minimum_daily_mins< 7.3 11 3 0 (0.72727273 0.27272727) *
## 45) minimum_daily_mins>=7.3 49 16 1 (0.32653061 0.67346939)
## 90) weekly_mins_watched< 413.4 31 14 1 (0.45161290 0.54838710)
## 180) weekly_max_night_mins>=108.5 8 2 0 (0.75000000 0.25000000) *
## 181) weekly_max_night_mins< 108.5 23 8 1 (0.34782609 0.65217391) *
## 91) weekly_mins_watched>=413.4 18 2 1 (0.11111111 0.88888889) *
## 23) weekly_mins_watched>=454.2 10 0 1 (0.00000000 1.00000000) *
## 3) customer_support_calls>=3.5 117 57 1 (0.48717949 0.51282051)
## 6) weekly_mins_watched>=248.775 64 18 0 (0.71875000 0.28125000)
## 12) minimum_daily_mins< 13.05 54 12 0 (0.77777778 0.22222222) *
## 13) minimum_daily_mins>=13.05 10 4 1 (0.40000000 0.60000000) *
## 7) weekly_mins_watched< 248.775 53 11 1 (0.20754717 0.79245283) *
We will look at global variable importance. To look at local, customer-level importance you could use the LIME package.
plot <- final_tree_pred %>%
extract_fit_parsnip() %>%
vip(aesthetics = list(color = "black", fill = "#26ACB5")) + theme_tq()
print(ggplotly(plot))
ggsave("Figures/VarImp.png", plot)## Saving 7 x 5 in image
This matches what we saw from the logistic regression significance plot, where these were among the significant terms: weekly minutes watched, maximum daily minutes and customer support calls are very important variables.
The last step is to create the final predictions from the tuned decision tree:
# Create the final prediction
final_fit <-
final_wf %>%
last_fit(split)
final_fit_fitted_metrics <- final_fit %>%
collect_metrics()
print(final_fit_fitted_metrics)## # A tibble: 2 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 accuracy binary 0.935 Preprocessor1_Model1
## 2 roc_auc binary 0.871 Preprocessor1_Model1
#Create the final predictions
final_fit_predictions <- final_fit %>%
collect_predictions()
print(final_fit_predictions)## # A tibble: 384 × 7
## id .pred_0 .pred_1 .row .pred_class churn .config
## <chr> <dbl> <dbl> <int> <fct> <fct> <chr>
## 1 train/test split 0.208 0.792 3 1 1 Preprocessor1_Model1
## 2 train/test split 0.972 0.0282 7 0 0 Preprocessor1_Model1
## 3 train/test split 0.972 0.0282 15 0 0 Preprocessor1_Model1
## 4 train/test split 0.417 0.583 21 1 0 Preprocessor1_Model1
## 5 train/test split 0.208 0.792 22 1 1 Preprocessor1_Model1
## 6 train/test split 0.778 0.222 25 0 1 Preprocessor1_Model1
## 7 train/test split 0 1 36 1 1 Preprocessor1_Model1
## 8 train/test split 0.972 0.0282 43 0 0 Preprocessor1_Model1
## 9 train/test split 0.972 0.0282 47 0 0 Preprocessor1_Model1
## 10 train/test split 0.972 0.0282 50 0 0 Preprocessor1_Model1
## # … with 374 more rows
You could do something similar by viewing this object with the confusion matrix helpers, but I will view it on a ROC plot:
roc_plot <- final_fit_predictions %>%
roc_curve(churn, `.pred_0`) %>%
autoplot()
print(roc_plot)
ggsave(filename = "Figures/tuned_tree.png", plot=roc_plot)## Saving 7 x 5 in image
One last point to note - to inspect any of the tuning parameters and hyperparameters for the models you can use the args function to return these - examples below:
args(decision_tree)## function (mode = "unknown", engine = "rpart", cost_complexity = NULL,
## tree_depth = NULL, min_n = NULL)
## NULL
args(logistic_reg)## function (mode = "classification", engine = "glm", penalty = NULL,
## mixture = NULL)
## NULL
args(rand_forest)## function (mode = "unknown", engine = "ranger", mtry = NULL, trees = NULL,
## min_n = NULL)
## NULL
Finally we will look at implementing a really powerful model: Extreme Gradient Boosting (XGBoost).
Starting from the recipe we created earlier to preprocess the churn data, we are going to try to improve our model using tuning, resampling and a more powerful model training technique, gradient boosting.
# Use churned recipe
# Prepare the recipe
churn_rec_preped <- churn_rec %>%
prep()
#Bake the recipe
churn_folds <-
recipes::bake(
churn_rec_preped,
new_data = training(split)
) %>%
 rsample::vfold_cv(v = 10)
xgboost_model <-
parsnip::boost_tree(
mode = "classification",
trees=1000,
min_n = tune(),
tree_depth = tune(),
learn_rate = tune(),
loss_reduction = tune()
) %>%
set_engine("xgboost")The next step, as we covered earlier, is to create a grid search in dials to go over each one of these hyperparameters to tune the model.
xgboost_params <-
dials::parameters(
min_n(),
tree_depth(),
learn_rate(),
loss_reduction()
)
xgboost_grid <-
dials::grid_max_entropy(
xgboost_params, size = 100 #Indicates the size of the search space
)
xgboost_grid## # A tibble: 100 × 4
## min_n tree_depth learn_rate loss_reduction
## <int> <int> <dbl> <dbl>
## 1 24 2 2.65e- 8 4.05e-6
## 2 16 15 5.79e-10 1.64e-2
## 3 19 10 3.90e- 6 2.90e-1
## 4 38 1 5.27e- 5 1.02e-8
## 5 6 3 4.21e- 7 7.09e-3
## 6 38 2 1.18e- 2 3.25e+0
## 7 15 4 4.13e- 3 6.12e-9
## 8 12 9 9.50e- 5 1.48e+1
## 9 27 3 9.47e- 6 6.24e-9
## 10 28 2 4.84e- 5 3.07e-4
## # … with 90 more rows
The next step is to set up the workflow for the grid.
xgboost_wf <-
workflows::workflow() %>%
add_model(xgboost_model) %>%
 add_formula(churn ~ .)
The next step is to tune the model using the tuning grid:
xgboost_tuned <- tune::tune_grid(
object = xgboost_wf,
 resamples = churn_folds,
grid = xgboost_grid,
metrics = yardstick::metric_set(accuracy, roc_auc),
control = tune::control_grid(verbose=TRUE)
)
Let's now get the best hyperparameters for the model:
xgboost_tuned %>%
tune::show_best(metric="roc_auc")## # A tibble: 5 × 10
## min_n tree_depth learn_…¹ loss_r…² .metric .esti…³ mean n std_err .config
## <int> <int> <dbl> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 4 14 2.42e-3 1.29e- 2 roc_auc binary 0.879 10 0.0152 Prepro…
## 2 4 9 3.52e-9 8.64e- 2 roc_auc binary 0.876 10 0.0140 Prepro…
## 3 4 12 5.95e-9 4.13e- 3 roc_auc binary 0.876 10 0.0139 Prepro…
## 4 3 7 1.06e-5 4.51e+ 0 roc_auc binary 0.875 10 0.0128 Prepro…
## 5 4 5 3.67e-5 6.80e-10 roc_auc binary 0.873 10 0.0128 Prepro…
## # … with abbreviated variable names ¹learn_rate, ²loss_reduction, ³.estimator
Now let's select the most performant set of hyperparameters for the model:
xgboost_best_params <- xgboost_tuned %>%
 tune::select_best(metric="roc_auc")
The final stage is finalizing the model to use the best parameters:
xgboost_model_final <- xgboost_model %>%
 tune::finalize_model(xgboost_best_params)
# Create training set
train_proc <- bake(churn_rec_preped,
new_data = training(split))
train_prediction <- xgboost_model_final %>%
fit(
formula = churn ~ .,
data = train_proc
) %>%
predict(new_data = train_proc) %>%
bind_cols(training(split))
xgboost_score_train <-
train_prediction %>%
yardstick::metrics(churn, .pred_class)
# Create testing set
test_proc <- bake(churn_rec_preped,
new_data = testing(split))
test_prediction <- xgboost_model_final %>%
fit(
formula = churn ~ .,
data = train_proc
) %>%
predict(new_data = test_proc)
# Bind test predictions to labels
test_prediction <- cbind(test_prediction, testing(split))
xgboost_score <-
test_prediction %>%
yardstick::metrics(churn, .pred_class)
print(xgboost_score)## # A tibble: 2 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy binary 0.935
## 2 kap binary 0.700
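yardstick::metrics() on hard class labels only returns accuracy and kappa. If you also wanted the test-set AUC, one option (a sketch that refits the finalised model once and predicts class probabilities) would be:
# Fit the finalised xgboost model once, then predict class probabilities on the test set
xgb_final_fit <- xgboost_model_final %>%
 fit(formula = churn ~ ., data = train_proc)

test_probs <- xgb_final_fit %>%
 predict(new_data = test_proc, type = "prob") %>%
 bind_cols(testing(split) %>% select(churn))

test_probs %>%
 yardstick::roc_auc(truth = churn, .pred_0)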
In the final step we will use ConfusionTableR to flatten the classification results so they can be stored in a CSV file:
library(data.table)##
## Attaching package: 'data.table'
## The following objects are masked from 'package:xts':
##
## first, last
## The following objects are masked from 'package:lubridate':
##
## hour, isoweek, mday, minute, month, quarter, second, wday, week,
## yday, year
## The following object is masked from 'package:purrr':
##
## transpose
## The following objects are masked from 'package:dplyr':
##
## between, first, last
library(ConfusionTableR)
cm_outputs <- ConfusionTableR::binary_class_cm(train_labels = test_prediction$.pred_class,
truth_labels = test_prediction$churn)## [INFO] Building a record level confusion matrix to store in dataset
## [INFO] Build finished and to expose record level cm use the record_level_cm list item
cm_outputs ## $confusion_matrix
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 324 16
## 1 9 35
##
## Accuracy : 0.9349
## 95% CI : (0.9054, 0.9574)
## No Information Rate : 0.8672
## P-Value [Acc > NIR] : 1.533e-05
##
## Kappa : 0.6999
##
## Mcnemar's Test P-Value : 0.2301
##
## Sensitivity : 0.9730
## Specificity : 0.6863
## Pos Pred Value : 0.9529
## Neg Pred Value : 0.7955
## Prevalence : 0.8672
## Detection Rate : 0.8438
## Detection Prevalence : 0.8854
## Balanced Accuracy : 0.8296
##
## 'Positive' Class : 0
##
##
## $record_level_cm
## Pred_0_Ref_0 Pred_1_Ref_0 Pred_0_Ref_1 Pred_1_Ref_1 Accuracy Kappa
## 1 324 9 16 35 0.9348958 0.699925
## AccuracyLower AccuracyUpper AccuracyNull AccuracyPValue McnemarPValue
## 1 0.9053911 0.9574284 0.8671875 1.533058e-05 0.2301393
## Sensitivity Specificity Pos.Pred.Value Neg.Pred.Value Precision Recall
## 1 0.972973 0.6862745 0.9529412 0.7954545 0.9529412 0.972973
## F1 Prevalence Detection.Rate Detection.Prevalence Balanced.Accuracy
## 1 0.9628529 0.8671875 0.84375 0.8854167 0.8296237
## cm_ts
## 1 2022-12-13 18:11:32
##
## $cm_tbl
## PredLabel Freq
## 1 Pred_0_Ref_0 324
## 2 Pred_1_Ref_0 9
## 3 Pred_0_Ref_1 16
## 4 Pred_1_Ref_1 35
##
## $last_run
## [1] "2022-12-13 18:11:32 EAT"
#data.table::fwrite(cm_outputs$record_level_cm,
# file=paste0("xgboost_results_", as.character(Sys.time()),
# ".csv"))