Customer Retention Case Study

Rudy Martinez, Brenda Parnin, Jose Fernandez

11/11/2021


Install Packages

#install.packages(c("SMCRM","dplyr","tidyr","GGally","corrgram","ggplot2","rpart","rattle","randomForestSRC","purrr","knitr","plyr","tidyverse","randomForest","caret"))


Libraries

library("SMCRM")
library("knitr")
library("corrgram")
library("GGally")
library("plyr")    
library("rpart")
library("tidyverse")
library("randomForestSRC")
library("randomForest")
library("caret")

# theme for nice plotting
theme_nice <- theme_classic()+
                theme(
                  axis.line.y.left = element_line(colour = "black"),
                  axis.line.y.right = element_line(colour = "black"),
                  axis.line.x.bottom = element_line(colour = "black"),
                  axis.line.x.top = element_line(colour = "black"),
                  axis.text.y = element_text(colour = "black", size = 12),
                  axis.text.x = element_text(color = "black", size = 12),
                  axis.ticks = element_line(color = "black")) +
                theme(
                  axis.ticks.length = unit(-0.25, "cm"), 
                  axis.text.x = element_text(margin=unit(c(0.5,0.5,0.5,0.5), "cm")), 
                  axis.text.y = element_text(margin=unit(c(0.5,0.5,0.5,0.5), "cm")))


Background

Managing customer retention and acquisition is essential for developing and maintaining customer relationships. The first step is to predict which current customers have a high probability of ending their relationship with the firm, and which prospects have a high probability of being acquired. The second step is to target the predicted at-risk customers, or the prospects most likely to join, with incentives such as pricing offers or communications such as emails. Models that accurately predict retention and acquisition are pivotal for targeting the right customers, thereby decreasing the cost of a marketing campaign and using scarce firm resources more efficiently.


Case Study Tasks

This case study will address the following tasks:

  • Use the acquisitionRetention data set to predict which customers will be acquired (acquisition) and for how long (duration) based on a feature set, using a random forest.
  • Compute variable importance, detect interactions, and optimize hyperparameters for acquired customers.
  • Compare the accuracy of the random forest with decision tree and logistic regression models for acquiring customers.
  • Extra credit: generate PDP plots for all variables.


Data Fields

  • customer: customer number (from 1 to 500)
  • acquisition: 1 if the prospect was acquired, 0 otherwise
  • duration: number of days the customer was a customer of the firm, 0 if acquisition == 0
  • profit: customer lifetime value (CLV) of a given customer, -(Acq_Exp) if the customer is not acquired
  • acq_exp: total dollars spent on trying to acquire this prospect
  • ret_exp: total dollars spent on trying to retain this customer
  • acq_exp_sq: square of the total dollars spent on trying to acquire this prospect
  • ret_exp_sq: square of the total dollars spent on trying to retain this customer
  • freq: number of purchases the customer made during that customer’s lifetime with the firm, 0 if acquisition == 0
  • freq_sq: square of the number of purchases the customer made during that customer’s lifetime with the firm
  • crossbuy: number of product categories the customer purchased from during that customer’s lifetime with the firm, 0 if acquisition = 0
  • sow: Share-of-Wallet; percentage of purchases the customer makes from the given firm given the total amount of purchases across all firms in that category
  • industry: 1 if the customer is in the B2B industry, 0 otherwise
  • revenue: annual sales revenue of the prospect’s firm (in millions of dollar)
  • employees: number of employees in the prospect’s firm


Read in Data

data(acquisitionRetention)
data = acquisitionRetention
str(acquisitionRetention)
## 'data.frame':    500 obs. of  15 variables:
##  $ customer   : num  1 2 3 4 5 6 7 8 9 10 ...
##  $ acquisition: num  1 1 1 0 1 1 1 1 0 0 ...
##  $ duration   : num  1635 1039 1288 0 1631 ...
##  $ profit     : num  6134 3524 4081 -638 5446 ...
##  $ acq_exp    : num  694 460 249 638 589 ...
##  $ ret_exp    : num  972 450 805 0 920 ...
##  $ acq_exp_sq : num  480998 211628 62016 407644 346897 ...
##  $ ret_exp_sq : num  943929 202077 648089 0 846106 ...
##  $ freq       : num  6 11 21 0 2 7 15 13 0 0 ...
##  $ freq_sq    : num  36 121 441 0 4 49 225 169 0 0 ...
##  $ crossbuy   : num  5 6 6 0 9 4 5 5 0 0 ...
##  $ sow        : num  95 22 90 0 80 48 51 23 0 0 ...
##  $ industry   : num  1 0 0 0 0 1 0 1 0 1 ...
##  $ revenue    : num  47.2 45.1 29.1 40.6 48.7 ...
##  $ employees  : num  898 686 1423 181 631 ...
  • There are 500 total observations across 15 variables. customer is merely a row identifier, so we will exclude it from the modeling process (it remains in the data frame only as a join key for the later error analysis); a quick sanity check follows.
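A minimal sketch confirming that customer is a pure running index and carries no predictive information:

# TRUE if customer is exactly the row index 1..500
all(data$customer == seq_len(nrow(data)))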


Data Exploration

#Summary Statistics
summary(data)
##     customer      acquisition       duration          profit       
##  Min.   :  1.0   Min.   :0.000   Min.   :   0.0   Min.   :-1027.0  
##  1st Qu.:125.8   1st Qu.:0.000   1st Qu.:   0.0   1st Qu.: -316.3  
##  Median :250.5   Median :1.000   Median : 957.5   Median : 3369.9  
##  Mean   :250.5   Mean   :0.676   Mean   : 742.5   Mean   : 2403.8  
##  3rd Qu.:375.2   3rd Qu.:1.000   3rd Qu.:1146.2   3rd Qu.: 3931.6  
##  Max.   :500.0   Max.   :1.000   Max.   :1673.0   Max.   : 6134.3  
##     acq_exp           ret_exp         acq_exp_sq          ret_exp_sq     
##  Min.   :   1.21   Min.   :   0.0   Min.   :      1.5   Min.   :      0  
##  1st Qu.: 384.14   1st Qu.:   0.0   1st Qu.: 147562.0   1st Qu.:      0  
##  Median : 491.66   Median : 398.1   Median : 241729.7   Median : 158480  
##  Mean   : 493.35   Mean   : 336.3   Mean   : 271211.1   Mean   : 184000  
##  3rd Qu.: 600.21   3rd Qu.: 514.3   3rd Qu.: 360246.0   3rd Qu.: 264466  
##  Max.   :1027.04   Max.   :1095.0   Max.   :1054811.2   Max.   :1198937  
##       freq          freq_sq          crossbuy           sow        
##  Min.   : 0.00   Min.   :  0.00   Min.   : 0.000   Min.   :  0.00  
##  1st Qu.: 0.00   1st Qu.:  0.00   1st Qu.: 0.000   1st Qu.:  0.00  
##  Median : 6.00   Median : 36.00   Median : 5.000   Median : 44.00  
##  Mean   : 6.22   Mean   : 69.25   Mean   : 4.052   Mean   : 38.88  
##  3rd Qu.:11.00   3rd Qu.:121.00   3rd Qu.: 7.000   3rd Qu.: 66.00  
##  Max.   :21.00   Max.   :441.00   Max.   :11.000   Max.   :116.00  
##     industry        revenue        employees     
##  Min.   :0.000   Min.   :14.49   Min.   :  18.0  
##  1st Qu.:0.000   1st Qu.:33.53   1st Qu.: 503.0  
##  Median :1.000   Median :41.43   Median : 657.5  
##  Mean   :0.522   Mean   :40.54   Mean   : 671.5  
##  3rd Qu.:1.000   3rd Qu.:47.52   3rd Qu.: 826.0  
##  Max.   :1.000   Max.   :65.10   Max.   :1461.0
#Correlation Visual
ggcorr(data, method = c("everything", "pearson")) 

#Correlation Values
corr_results = cor(data)
round(corr_results, 2)
##             customer acquisition duration profit acq_exp ret_exp acq_exp_sq
## customer        1.00        0.05     0.04   0.04   -0.03    0.02      -0.04
## acquisition     0.05        1.00     0.94   0.96    0.00    0.87      -0.08
## duration        0.04        0.94     1.00   0.98    0.01    0.98      -0.06
## profit          0.04        0.96     0.98   1.00    0.04    0.95      -0.04
## acq_exp        -0.03        0.00     0.01   0.04    1.00    0.01       0.97
## ret_exp         0.02        0.87     0.98   0.95    0.01    1.00      -0.06
## acq_exp_sq     -0.04       -0.08    -0.06  -0.04    0.97   -0.06       1.00
## ret_exp_sq     -0.01        0.63     0.83   0.78    0.03    0.92      -0.02
## freq            0.04        0.78     0.71   0.75    0.00    0.69      -0.06
## freq_sq         0.03        0.57     0.50   0.54   -0.01    0.51      -0.05
## crossbuy        0.06        0.87     0.83   0.86    0.03    0.78      -0.04
## sow             0.01        0.85     0.81   0.83    0.03    0.74      -0.03
## industry        0.10        0.24     0.21   0.23    0.01    0.18       0.03
## revenue         0.00        0.25     0.23   0.24    0.06    0.20       0.04
## employees       0.02        0.48     0.43   0.47   -0.04    0.41      -0.06
##             ret_exp_sq  freq freq_sq crossbuy   sow industry revenue employees
## customer         -0.01  0.04    0.03     0.06  0.01     0.10    0.00      0.02
## acquisition       0.63  0.78    0.57     0.87  0.85     0.24    0.25      0.48
## duration          0.83  0.71    0.50     0.83  0.81     0.21    0.23      0.43
## profit            0.78  0.75    0.54     0.86  0.83     0.23    0.24      0.47
## acq_exp           0.03  0.00   -0.01     0.03  0.03     0.01    0.06     -0.04
## ret_exp           0.92  0.69    0.51     0.78  0.74     0.18    0.20      0.41
## acq_exp_sq       -0.02 -0.06   -0.05    -0.04 -0.03     0.03    0.04     -0.06
## ret_exp_sq        1.00  0.51    0.38     0.58  0.53     0.10    0.13      0.29
## freq              0.51  1.00    0.94     0.69  0.66     0.16    0.15      0.43
## freq_sq           0.38  0.94    1.00     0.52  0.48     0.10    0.10      0.36
## crossbuy          0.58  0.69    0.52     1.00  0.75     0.22    0.19      0.42
## sow               0.53  0.66    0.48     0.75  1.00     0.21    0.23      0.41
## industry          0.10  0.16    0.10     0.22  0.21     1.00    0.03      0.00
## revenue           0.13  0.15    0.10     0.19  0.23     0.03    1.00      0.05
## employees         0.29  0.43    0.36     0.42  0.41     0.00    0.05      1.00
  • The following variables are highly correlated with one another. This finding will be taken into consideration when building our models; a programmatic cross-check follows this list.
    • acquisition
    • duration
    • profit
    • ret_exp
    • acq_exp_sq
    • ret_exp_sq
    • freq
    • freq_sq
    • crossbuy
    • sow
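As a cross-check, a small sketch using caret::findCorrelation (caret is already loaded; the 0.75 cutoff is our own choice) to flag the columns involved in high pairwise correlations:

# columns with pairwise |correlation| above 0.75 (cutoff is arbitrary)
caret::findCorrelation(corr_results, cutoff = 0.75, names = TRUE)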


# create box plots to show if we need to remove any variables
par(mfrow = c(2, 5))
boxplot(duration ~ acquisition, data, xlab = "acquisition", ylab = "duration")
boxplot(profit ~ acquisition, data, xlab = "acquisition", ylab = "profit")
boxplot(ret_exp ~ acquisition, data, xlab = "acquisition", ylab = "ret_exp")
boxplot(acq_exp_sq ~ acquisition, data, xlab = "acquisition", ylab = "acq_exp_sq")
boxplot(ret_exp_sq ~ acquisition, data, xlab = "acquisition", ylab = "ret_exp_sq")
boxplot(freq ~ acquisition, data, xlab = "acquisition", ylab = "freq")
boxplot(freq_sq ~ acquisition, data, xlab = "acquisition", ylab = "freq_sq")
boxplot(crossbuy ~ acquisition, data, xlab = "acquisition", ylab = "crossbuy")
boxplot(sow ~ acquisition, data, xlab = "acquisition", ylab = "sow")


Refactor Variables

data$acquisition = as.factor(data$acquisition)


Data Cleaning

#Check for Null Values
sum(is.na(data))
## [1] 0
#Check for Duplicates
sum(duplicated(data))
## [1] 0
  • There are no null or duplicate values in our dataset.


Random Forest - acquisition target

Use the acquisitionRetention data set to predict which customers will be acquired (acquisition) and for how long (duration) based on a feature set, using a random forest.

set.seed(123)

idx.train = sample(1:nrow(data), size = 0.8 * nrow(data))
train.df = data[idx.train,]
test.df = data[-idx.train,]
set.seed(123)

forest1 = rfsrc(acquisition ~ acq_exp + industry + revenue + employees, # only pre-acquisition variables; the rest are post-acquisition outcomes or near-duplicates of the target
                            data = train.df,
                            importance = TRUE, 
                            ntree = 1000)

forest1
##                          Sample size: 400
##            Frequency of class labels: 126, 274
##                      Number of trees: 1000
##            Forest terminal node size: 1
##        Average no. of terminal nodes: 62.454
## No. of variables tried at each split: 2
##               Total no. of variables: 4
##        Resampling used to grow trees: swor
##     Resample size used to grow trees: 253
##                             Analysis: RF-C
##                               Family: class
##                       Splitting rule: gini *random*
##        Number of random split points: 10
##                     Imbalanced ratio: 2.1746
##                    (OOB) Brier score: 0.14316217
##         (OOB) Normalized Brier score: 0.57264866
##                            (OOB) AUC: 0.855347
##                         (OOB) PR-AUC: 0.71186989
##                         (OOB) G-mean: 0.70911114
##    (OOB) Requested performance error: 0.205, 0.44444444, 0.09489051
## 
## Confusion matrix:
## 
##           predicted
##   observed  0   1 class.error
##          0 70  56      0.4444
##          1 26 248      0.0949
## 
##       (OOB) Misclassification rate: 20.5%
  • Based on our initial model results, the model achieved an out-of-bag (OOB) accuracy of 79.5% (a 20.5% OOB misclassification rate). Note that the class error is much higher for non-acquired prospects (44.4%) than for acquired ones (9.5%), reflecting the 2.17 imbalance ratio; a quick recomputation of the OOB accuracy follows below.
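A minimal sketch, assuming the fitted rfsrc object exposes the documented class.oob element, recomputing the OOB accuracy directly from the out-of-bag class assignments:

# recompute OOB accuracy from the out-of-bag class labels
oob_tab = table(observed = train.df$acquisition, predicted = forest1$class.oob)
sum(diag(oob_tab)) / sum(oob_tab)  # ~0.795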


Random Forest Model Predictions - acquisition

#predicted class labels for all 500 prospects - used for classification problems
#(predicting over the full data set so that every row receives a label and
# the cbind below stays aligned with the 500-row data frame)
rf1_pred = predict(forest1, newdata = data)$class

#temporary df to bind predictions to the original dataframe
temp1_df = cbind(data, rf1_pred)

#new dataframe filtered to rf1_pred == 1, signaling predicted customer acquisition
data_new = temp1_df %>% filter(rf1_pred == 1)


Random Forest - duration target

set.seed(123)

forest2 = rfsrc(duration ~ acq_exp + industry + revenue + employees, 
                 data = data_new,
                 importance = TRUE,
                 ntree = 1000)

forest2
##                          Sample size: 355
##                      Number of trees: 1000
##            Forest terminal node size: 5
##        Average no. of terminal nodes: 46.168
## No. of variables tried at each split: 2
##               Total no. of variables: 4
##        Resampling used to grow trees: swor
##     Resample size used to grow trees: 224
##                             Analysis: RF-R
##                               Family: regr
##                       Splitting rule: mse *random*
##        Number of random split points: 10
##                      (OOB) R squared: 0.31046814
##    (OOB) Requested performance error: 203944.8272326


Compute variable importance

forest2$importance
##   acq_exp  industry   revenue employees 
## 209115.02  31077.07 111540.77 214058.95
data.frame(importance = forest2$importance) %>%
  tibble::rownames_to_column(var = "variable") %>%
  ggplot(aes(x = reorder(variable,importance), y = importance)) +
    geom_bar(stat = "identity", fill = "steelblue", color = "black")+
    coord_flip() +
     labs(x = "Variables", y = "Variable importance")+
     theme_nice

  • The variables above are displayed in order of importance for predicting duration: employees, acq_exp, revenue, and industry. When analyzing variable importance, the variable with the highest value is the most important; a permutation-based cross-check follows below.
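As a cross-check, a small sketch using randomForestSRC::vimp to recompute importance with permutation noising, which may differ slightly from the importance stored at fit time:

# permutation-based variable importance for the duration forest
vimp(forest2, importance = "permute")$importance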


mindepth = max.subtree(forest2,
                        sub.order = TRUE)

# first order depths
print(round(mindepth$order, 3)[,1])
##   acq_exp  industry   revenue employees 
##     1.398     2.251     1.093     0.671
# visualize minimal depth
data.frame(md = round(mindepth$order, 3)[,1]) %>%
  tibble::rownames_to_column(var = "variable") %>%
  ggplot(aes(x = reorder(variable,desc(md)), y = md)) +
    geom_bar(stat = "identity", fill = "steelblue", color = "black", width = 0.2)+
    coord_flip() +
     labs(x = "Variables", y = "Minimal Depth")+
     theme_nice

  • We again observe the variables in order of importance for predicting duration: employees, acq_exp, revenue, and industry. Note that when analyzing minimal depth, the variable with the smallest value is the most important: the most important variables tend to split closest to the root of the tree.


# min depth
mindepth$sub.order
##             acq_exp  industry    revenue  employees
## acq_exp   0.1131823 0.4490810 0.17765787 0.16078760
## industry  0.2335417 0.1788152 0.24009811 0.23241239
## revenue   0.1495405 0.3939603 0.08634847 0.14618496
## employees 0.1344936 0.2746964 0.13511553 0.05449049
as.matrix(mindepth$sub.order) %>%
  reshape2::melt() %>%
  data.frame() %>%
  ggplot(aes(x = Var1, y = Var2, fill = value)) +
    scale_x_discrete(position = "top") +
    geom_tile(color = "white") +
    viridis::scale_fill_viridis("Relative min. depth") +
    labs(x = "", y = "") +
    theme_bw()

  • The heatmap above visualizes relative minimal depth between pairs of variables; as with first-order minimal depth, smaller values indicate stronger effects.


Detect interactions

# cross-check with vimp
find.interaction(forest2,
                      method = "vimp",
                      importance = "permute")
## Pairing employees with acq_exp 
## Pairing employees with revenue 
## Pairing employees with industry 
## Pairing acq_exp with revenue 
## Pairing acq_exp with industry 
## Pairing revenue with industry 
## 
##                               Method: vimp
##                     No. of variables: 4
##            Variables sorted by VIMP?: TRUE
##    No. of variables used for pairing: 4
##     Total no. of paired interactions: 6
##             Monte Carlo replications: 1
##     Type of noising up used for VIMP: permute
## 
##                        Var 1     Var 2   Paired Additive Difference
## employees:acq_exp  59035.048 24579.654 90978.97 83614.70   7364.264
## employees:revenue  59035.048  8390.153 76497.17 67425.20   9071.967
## employees:industry 59035.048  8519.683 68559.44 67554.73   1004.706
## acq_exp:revenue    23814.565  8390.153 38964.75 32204.72   6760.036
## acq_exp:industry   23814.565  8519.683 33569.86 32334.25   1235.609
## revenue:industry    8715.785  8519.683 19329.09 17235.47   2093.619
  • Based on the interaction results above, several pairings achieve a relatively large positive difference between their Paired and Additive importance. The find.interaction documentation notes that a large positive or negative difference indicates an association worth pursuing, provided the univariate VIMP of each paired variable is reasonably large; here employees:revenue and employees:acq_exp show the largest differences. A sketch of pursuing such a pairing follows.
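One way to pursue the strongest pairing (a minimal sketch; the derived emp_x_acq feature is our own construction, not part of the original case study) is to add the product term as an explicit feature and compare importance:

# add an explicit employees x acq_exp product term and refit the duration forest
data_int = transform(data_new, emp_x_acq = employees * acq_exp)
forest2_int = rfsrc(duration ~ acq_exp + industry + revenue + employees + emp_x_acq,
                    data = data_int,
                    importance = TRUE,
                    ntree = 1000)
forest2_int$importance  # compare against forest2$importance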


set.seed(123)

idx.train_new = sample(1:nrow(data_new), size = 0.8 * nrow(data_new))
train.df_new = data_new[idx.train_new,]
test.df_new = data_new[-idx.train_new,]
  • To prepare for modeling, we will split our data following an 80/20 approach.


Untuned Random Forest Model without Interactions

set.seed(123)

forest.no_interaction.untuned = rfsrc(duration ~ acq_exp + industry + revenue + employees, 
                            data = train.df_new, 
                            importance = TRUE, 
                            ntree = 1000)

Optimize hyperparameters for acquired customers without interactions

# Establish a list of possible values for hyper-parameters
# (with only 4 predictors, rfsrc caps any mtry above 4 at 4)
mtry.values <- seq(4,6,1)
nodesize.values <- seq(4,8,2)
ntree.values <- seq(4e3,6e3,1e3)

# Create a data frame containing all combinations 
hyper_grid = expand.grid(mtry = mtry.values, nodesize = nodesize.values, ntree = ntree.values)

# Create an empty vector to store OOB error values
oob_err = c()

# Write a loop over the rows of hyper_grid to train the grid of models
for (i in 1:nrow(hyper_grid)) {

   # Train a Random Forest model
   model = rfsrc(duration ~ acq_exp + industry + revenue + employees, 
                 data = train.df_new,
                 mtry = hyper_grid$mtry[i],
                 nodesize = hyper_grid$nodesize[i],
                 ntree = hyper_grid$ntree[i])  
  
                          
    # Store OOB error for the model                      
    oob_err[i] <- model$err.rate[length(model$err.rate)]
}

# Identify optimal set of hyperparameters based on OOB error
opt_i <- which.min(oob_err)
print(hyper_grid[opt_i,])
##    mtry nodesize ntree
## 27    6        8  6000
  • The grid search above identifies the combination of hyperparameters with the lowest OOB error. With these results, we can now proceed to our final random forest model; an alternative built-in tuner is sketched after this list.
    • mtry: 6 (effectively 4, since rfsrc caps mtry at the 4 available predictors)
    • nodesize: 8
    • ntree: 6000
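Alternatively, a sketch using randomForestSRC's built-in tuner, which searches mtry and nodesize (but not ntree) against fast OOB error:

# built-in OOB-error tuner over mtry and nodesize
tuned = tune(duration ~ acq_exp + industry + revenue + employees,
             data = train.df_new)
tuned$optimal  # optimal nodesize and mtry found by tune()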


Tuned Random Forest Model with Optimal Parameters

set.seed(123)

forest.hyper = rfsrc(duration ~ acq_exp + industry + revenue + employees, 
                     data = train.df_new,
                     mtry = 6,
                     nodesize = 8,
                     ntree = 6000)

Additional Models for duration Prediction

#Linear Regression (glm with its default gaussian family; duration is
# continuous, so this is ordinary least squares, not logistic regression)
regression.linear = glm(duration ~ acq_exp + industry + revenue + employees, data = train.df_new)

#Decision Tree Model
dt.model = rpart(duration ~ acq_exp + industry + revenue + employees, 
                             data = train.df_new)


Model Predictions Comparison on Test Set

error.df = 
  data.frame(pred1 = predict(forest.no_interaction.untuned, newdata = test.df_new)$predicted,
             pred2 = predict(forest.hyper, newdata = test.df_new)$predicted,
             pred3 = predict(regression.linear, newdata = test.df_new),
             pred4 = predict(dt.model, newdata = test.df_new),
             actual = test.df_new$duration, 
             customer = test.df_new$customer) %>%
  mutate_at(.funs = list(abs.error = ~ abs(actual - .),
                         abs.percent.error = ~ abs(actual - .)/abs(actual)),
            .vars = vars(pred1:pred4))
#mae
error.df %>%
  summarise_at(.funs = list(mae = ~ mean(.)), 
               .vars = vars(pred1_abs.error:pred4_abs.error))
##   pred1_abs.error_mae pred2_abs.error_mae pred3_abs.error_mae
## 1            344.4241            350.4862            381.9821
##   pred4_abs.error_mae
## 1            405.8812
error.df2 =
  error.df %>%
  left_join(test.df_new, by = "customer") %>%
  mutate(customer_portfolio = cut(revenue, 
               breaks = quantile(revenue, probs = seq(0, 1, 0.25)),
               labels = c("25%", "50%", "75%", "100%"),
               include.lowest = TRUE)) 

portfolio.mae = 
  error.df2 %>%
  group_by(customer_portfolio) %>%
  summarise_at(.funs = list(mae = ~ mean(.)), 
               .vars = vars(pred1_abs.error:pred4_abs.error)) %>%
  ungroup()


portfolio.errors = 
  portfolio.mae %>%
  gather(key = error_type, value = error, -customer_portfolio) %>%
  mutate(error_type2 = "MAE",  # every error gathered here is an MAE
         model_type = case_when(
           grepl("pred1", error_type) ~ "Untuned Forest",
           grepl("pred2", error_type) ~ "Tuned Forest",
           grepl("pred3", error_type) ~ "Linear Regression",
           TRUE                       ~ "Decision Tree")) 


ggplot(portfolio.errors, aes(x = customer_portfolio, 
                             y = error, 
                             color = model_type, 
                             group = model_type))+
  geom_line(size = 1.02)+
  geom_point(shape = 15) +
  scale_color_brewer(palette = "Set1") +
  labs(y = "Error", x = "Customer portfolios")+
  theme_nice +
  theme(legend.position = "top")+
  guides(color = guide_legend(title = "Model Type", size = 4,nrow = 2,byrow = TRUE))

  • Based on the above results, the Untuned and Tuned Random Forest models maintained the lowest MAE across every revenue quartile of the customer base, with the Tuned Random Forest performing slightly better. The Decision Tree and Linear Regression models performed similarly to one another, but neither matched the two Random Forests.


Compare the accuracy of the random forest model with decision tree and logistic regression models for acquiring customers (acquisition)

# Establish a list of possible values for hyper-parameters
mtry.values <- seq(4,6,1)
nodesize.values <- seq(4,8,2)
ntree.values <- seq(4e3,6e3,1e3)

# Create a data frame containing all combinations 
hyper_grid = expand.grid(mtry = mtry.values, nodesize = nodesize.values, ntree = ntree.values)

# Create an empty vector to store OOB error values
oob_err = c()

# Write a loop over the rows of hyper_grid to train the grid of models
for (i in 1:nrow(hyper_grid)) {

   # Train a Random Forest model
   model = rfsrc(acquisition ~ acq_exp + industry + revenue + employees, 
                 data = train.df,
                 mtry = hyper_grid$mtry[i],
                 nodesize = hyper_grid$nodesize[i],
                 ntree = hyper_grid$ntree[i])  
  
                          
    # Store overall OOB error for the model (the first column of the
    # classification err.rate matrix is the all-class error rate)
    oob_err[i] <- model$err.rate[nrow(model$err.rate), 1]
}

# Identify optimal set of hyperparameters based on OOB error
opt_i <- which.min(oob_err)
print(hyper_grid[opt_i,])
##   mtry nodesize ntree
## 5    5        6  4000
set.seed(123)

# mtry = 4 is equivalent in effect to the selected mtry = 5, since rfsrc
# caps mtry at the 4 available predictors
forest_acquisition = rfsrc(acquisition ~ acq_exp + industry + revenue + employees, 
                     data = train.df,
                     mtry = 4,
                     nodesize = 6,
                     ntree = 4000)

logistic.regression.acquisition = glm(acquisition ~ acq_exp + industry + revenue + employees, data = train.df, family = "binomial")

decision.tree.acquisition = rpart(acquisition ~ acq_exp + industry + revenue + employees, data = train.df)
pred1_acq = predict(forest_acquisition,newdata = test.df)$class

pred2_acq = predict(logistic.regression.acquisition, newdata = test.df, type = "response") # probabilities, not log-odds
pred2_acq = ifelse(pred2_acq > 0.50, 1, 0)

pred3_acq = predict(decision.tree.acquisition, newdata = test.df, type = "class")


Accuracy Random Forest

confusionMatrix(as.factor(pred1_acq), test.df$acquisition, positive = '1')
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  0  1
##          0 22  6
##          1 14 58
##                                           
##                Accuracy : 0.8             
##                  95% CI : (0.7082, 0.8733)
##     No Information Rate : 0.64            
##     P-Value [Acc > NIR] : 0.0003862       
##                                           
##                   Kappa : 0.5438          
##                                           
##  Mcnemar's Test P-Value : 0.1175249       
##                                           
##             Sensitivity : 0.9062          
##             Specificity : 0.6111          
##          Pos Pred Value : 0.8056          
##          Neg Pred Value : 0.7857          
##              Prevalence : 0.6400          
##          Detection Rate : 0.5800          
##    Detection Prevalence : 0.7200          
##       Balanced Accuracy : 0.7587          
##                                           
##        'Positive' Class : 1               
## 


Accuracy Logistic Regression

confusionMatrix(as.factor(pred2_acq), test.df$acquisition, positive = '1')
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  0  1
##          0 24 10
##          1 12 54
##                                           
##                Accuracy : 0.78            
##                  95% CI : (0.6861, 0.8567)
##     No Information Rate : 0.64            
##     P-Value [Acc > NIR] : 0.001834        
##                                           
##                   Kappa : 0.5167          
##                                           
##  Mcnemar's Test P-Value : 0.831170        
##                                           
##             Sensitivity : 0.8438          
##             Specificity : 0.6667          
##          Pos Pred Value : 0.8182          
##          Neg Pred Value : 0.7059          
##              Prevalence : 0.6400          
##          Detection Rate : 0.5400          
##    Detection Prevalence : 0.6600          
##       Balanced Accuracy : 0.7552          
##                                           
##        'Positive' Class : 1               
## 


Accuracy Decision Tree

confusionMatrix(as.factor(pred3_acq), test.df$acquisition, positive = '1')
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  0  1
##          0 24 14
##          1 12 50
##                                           
##                Accuracy : 0.74            
##                  95% CI : (0.6427, 0.8226)
##     No Information Rate : 0.64            
##     P-Value [Acc > NIR] : 0.02196         
##                                           
##                   Kappa : 0.4425          
##                                           
##  Mcnemar's Test P-Value : 0.84452         
##                                           
##             Sensitivity : 0.7812          
##             Specificity : 0.6667          
##          Pos Pred Value : 0.8065          
##          Neg Pred Value : 0.6316          
##              Prevalence : 0.6400          
##          Detection Rate : 0.5000          
##    Detection Prevalence : 0.6200          
##       Balanced Accuracy : 0.7240          
##                                           
##        'Positive' Class : 1               
## 


Final Results

  • Based on the test-set accuracy scores, the Tuned Random Forest model performed best (accuracy 0.80), followed closely by Logistic Regression (0.78) and the Decision Tree (0.74).


PDP Plot - Tuned Random Forest - acquisition

plot.variable(forest_acquisition, partial=TRUE)


PDP Plot - Tuned Random Forest - duration

plot.variable(forest.hyper, partial=TRUE)
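To focus on a single predictor rather than plotting all of them, a small sketch using plot.variable's xvar.names argument:

# partial dependence of predicted duration on employees only
plot.variable(forest.hyper, xvar.names = "employees", partial = TRUE)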