#install.packages(c("SMCRM","dplyr","tidyr","GGally","corrgram", "ggplot2","rpart","rattle","randomForestSRC","purrr"))
library('SMCRM')
## Warning: package 'SMCRM' was built under R version 4.0.3
library("knitr")
library('corrgram')
## Warning: package 'corrgram' was built under R version 4.0.5
library('GGally')
## Warning: package 'GGally' was built under R version 4.0.5
library('plyr')
library('rpart')
library('tidyverse')
library('randomForestSRC')
## Warning: package 'randomForestSRC' was built under R version 4.0.5
library('randomForest')
## Warning: package 'randomForest' was built under R version 4.0.5
library('caret')
## Warning: package 'caret' was built under R version 4.0.5
# theme for nice plotting
theme_nice <- theme_classic()+
theme(
axis.line.y.left = element_line(colour = "black"),
axis.line.y.right = element_line(colour = "black"),
axis.line.x.bottom = element_line(colour = "black"),
axis.line.x.top = element_line(colour = "black"),
axis.text.y = element_text(colour = "black", size = 12),
axis.text.x = element_text(color = "black", size = 12),
axis.ticks = element_line(color = "black")) +
theme(
axis.ticks.length = unit(-0.25, "cm"),
axis.text.x = element_text(margin=unit(c(0.5,0.5,0.5,0.5), "cm")),
axis.text.y = element_text(margin=unit(c(0.5,0.5,0.5,0.5), "cm")))
Managing customer retention and acquisition is essential for developing and maintaining customer relationships. The first step to cure customer retention and acquisition is to predict which customers have a high probability of ending their relationship with the firm and the probability of acquiring a new customer. The second step is to target the predicted at-risk current customers or new customers with high likelihood of joining using incentives such as pricing offers or communications such as emails. Models that accurately predict customer retention and acquisition are pivotal in targeting the right customers, thereby decreasing the cost of the marketing campaign and using scarce firm resources more efficiently.
This case study will address the following tasks:
acquisitionRetention data set to predict which customers will be acquired and for how long (duration) based on a feature set using a random forest.customer: customer number (from 1 to 500)acquisition: 1 if the prospect was acquired, 0 otherwiseduration: number of days the customer was a customer of the firm, 0 if acquisition == 0profit: customer lifetime value (CLV) of a given customer, -(Acq_Exp) if the customer is not acquiredacq_exp: total dollars spent on trying to acquire this prospectret_exp: total dollars spent on trying to retain this customeracq_exp_sq: square of the total dollars spent on trying to acquire this prospectret_exp_sq: square of the total dollars spent on trying to retain this customerfreq: number of purchases the customer made during that customer’s lifetime with the firm, 0 if acquisition == 0freq_sq: square of the number of purchases the customer made during that customer’s lifetime with the firmcrossbuy: number of product categories the customer purchased from during that customer’s lifetime with the firm, 0 if acquisition = 0sow: Share-of-Wallet; percentage of purchases the customer makes from the given firm given the total amount of purchases across all firms in that categoryindustry: 1 if the customer is in the B2B industry, 0 otherwiserevenue: annual sales revenue of the prospect’s firm (in millions of dollar)employees: number of employees in the prospect’s firmdata(acquisitionRetention)
data = acquisitionRetention
str(acquisitionRetention)
## 'data.frame': 500 obs. of 15 variables:
## $ customer : num 1 2 3 4 5 6 7 8 9 10 ...
## $ acquisition: num 1 1 1 0 1 1 1 1 0 0 ...
## $ duration : num 1635 1039 1288 0 1631 ...
## $ profit : num 6134 3524 4081 -638 5446 ...
## $ acq_exp : num 694 460 249 638 589 ...
## $ ret_exp : num 972 450 805 0 920 ...
## $ acq_exp_sq : num 480998 211628 62016 407644 346897 ...
## $ ret_exp_sq : num 943929 202077 648089 0 846106 ...
## $ freq : num 6 11 21 0 2 7 15 13 0 0 ...
## $ freq_sq : num 36 121 441 0 4 49 225 169 0 0 ...
## $ crossbuy : num 5 6 6 0 9 4 5 5 0 0 ...
## $ sow : num 95 22 90 0 80 48 51 23 0 0 ...
## $ industry : num 1 0 0 0 0 1 0 1 0 1 ...
## $ revenue : num 47.2 45.1 29.1 40.6 48.7 ...
## $ employees : num 898 686 1423 181 631 ...
customer is a field that is not necessary; therefore, we will remove this from our modeling process.#Summary Statistics
summary(data)
## customer acquisition duration profit
## Min. : 1.0 Min. :0.000 Min. : 0.0 Min. :-1027.0
## 1st Qu.:125.8 1st Qu.:0.000 1st Qu.: 0.0 1st Qu.: -316.3
## Median :250.5 Median :1.000 Median : 957.5 Median : 3369.9
## Mean :250.5 Mean :0.676 Mean : 742.5 Mean : 2403.8
## 3rd Qu.:375.2 3rd Qu.:1.000 3rd Qu.:1146.2 3rd Qu.: 3931.6
## Max. :500.0 Max. :1.000 Max. :1673.0 Max. : 6134.3
## acq_exp ret_exp acq_exp_sq ret_exp_sq
## Min. : 1.21 Min. : 0.0 Min. : 1.5 Min. : 0
## 1st Qu.: 384.14 1st Qu.: 0.0 1st Qu.: 147562.0 1st Qu.: 0
## Median : 491.66 Median : 398.1 Median : 241729.7 Median : 158480
## Mean : 493.35 Mean : 336.3 Mean : 271211.1 Mean : 184000
## 3rd Qu.: 600.21 3rd Qu.: 514.3 3rd Qu.: 360246.0 3rd Qu.: 264466
## Max. :1027.04 Max. :1095.0 Max. :1054811.2 Max. :1198937
## freq freq_sq crossbuy sow
## Min. : 0.00 Min. : 0.00 Min. : 0.000 Min. : 0.00
## 1st Qu.: 0.00 1st Qu.: 0.00 1st Qu.: 0.000 1st Qu.: 0.00
## Median : 6.00 Median : 36.00 Median : 5.000 Median : 44.00
## Mean : 6.22 Mean : 69.25 Mean : 4.052 Mean : 38.88
## 3rd Qu.:11.00 3rd Qu.:121.00 3rd Qu.: 7.000 3rd Qu.: 66.00
## Max. :21.00 Max. :441.00 Max. :11.000 Max. :116.00
## industry revenue employees
## Min. :0.000 Min. :14.49 Min. : 18.0
## 1st Qu.:0.000 1st Qu.:33.53 1st Qu.: 503.0
## Median :1.000 Median :41.43 Median : 657.5
## Mean :0.522 Mean :40.54 Mean : 671.5
## 3rd Qu.:1.000 3rd Qu.:47.52 3rd Qu.: 826.0
## Max. :1.000 Max. :65.10 Max. :1461.0
#Correlation Visual
ggcorr(data, method = c("everything", "pearson"))
#Correlation Values
corr_results = cor(data)
round(corr_results, 2)
## customer acquisition duration profit acq_exp ret_exp acq_exp_sq
## customer 1.00 0.05 0.04 0.04 -0.03 0.02 -0.04
## acquisition 0.05 1.00 0.94 0.96 0.00 0.87 -0.08
## duration 0.04 0.94 1.00 0.98 0.01 0.98 -0.06
## profit 0.04 0.96 0.98 1.00 0.04 0.95 -0.04
## acq_exp -0.03 0.00 0.01 0.04 1.00 0.01 0.97
## ret_exp 0.02 0.87 0.98 0.95 0.01 1.00 -0.06
## acq_exp_sq -0.04 -0.08 -0.06 -0.04 0.97 -0.06 1.00
## ret_exp_sq -0.01 0.63 0.83 0.78 0.03 0.92 -0.02
## freq 0.04 0.78 0.71 0.75 0.00 0.69 -0.06
## freq_sq 0.03 0.57 0.50 0.54 -0.01 0.51 -0.05
## crossbuy 0.06 0.87 0.83 0.86 0.03 0.78 -0.04
## sow 0.01 0.85 0.81 0.83 0.03 0.74 -0.03
## industry 0.10 0.24 0.21 0.23 0.01 0.18 0.03
## revenue 0.00 0.25 0.23 0.24 0.06 0.20 0.04
## employees 0.02 0.48 0.43 0.47 -0.04 0.41 -0.06
## ret_exp_sq freq freq_sq crossbuy sow industry revenue employees
## customer -0.01 0.04 0.03 0.06 0.01 0.10 0.00 0.02
## acquisition 0.63 0.78 0.57 0.87 0.85 0.24 0.25 0.48
## duration 0.83 0.71 0.50 0.83 0.81 0.21 0.23 0.43
## profit 0.78 0.75 0.54 0.86 0.83 0.23 0.24 0.47
## acq_exp 0.03 0.00 -0.01 0.03 0.03 0.01 0.06 -0.04
## ret_exp 0.92 0.69 0.51 0.78 0.74 0.18 0.20 0.41
## acq_exp_sq -0.02 -0.06 -0.05 -0.04 -0.03 0.03 0.04 -0.06
## ret_exp_sq 1.00 0.51 0.38 0.58 0.53 0.10 0.13 0.29
## freq 0.51 1.00 0.94 0.69 0.66 0.16 0.15 0.43
## freq_sq 0.38 0.94 1.00 0.52 0.48 0.10 0.10 0.36
## crossbuy 0.58 0.69 0.52 1.00 0.75 0.22 0.19 0.42
## sow 0.53 0.66 0.48 0.75 1.00 0.21 0.23 0.41
## industry 0.10 0.16 0.10 0.22 0.21 1.00 0.03 0.00
## revenue 0.13 0.15 0.10 0.19 0.23 0.03 1.00 0.05
## employees 0.29 0.43 0.36 0.42 0.41 0.00 0.05 1.00
acuisitiondurationprofitret_expacq_exp_sqret_exp_sqfreqfreq_sqcrossbuysow# create box plots to show if we need to remove any variables
par(mfrow = c(2, 5))
boxplot(duration ~ acquisition, data, xlab = "acquisition", ylab = "duration")
boxplot(profit ~ acquisition, data, xlab = "acquisition", ylab = "profit")
boxplot(ret_exp ~ acquisition, data, xlab = "acquisition", ylab = "ret_exp")
boxplot(acq_exp_sq ~ acquisition, data, xlab = "acquisition", ylab = "acq_exp_sq")
boxplot(ret_exp_sq ~ acquisition, data, xlab = "acquisition", ylab = "ret_exp_sq")
boxplot(freq ~ acquisition, data, xlab = "acquisition", ylab = "freq")
boxplot(freq_sq ~ acquisition, data, xlab = "acquisition", ylab = "freq_sq")
boxplot(crossbuy ~ acquisition, data, xlab = "acquisition", ylab = "crossbuy")
boxplot(sow ~ acquisition, data, xlab = "acquisition", ylab = "sow")
data$acquisition = as.factor(data$acquisition)
#Check for Null Values
sum(is.na(data))
## [1] 0
#Check for Duplicates
sum(duplicated(data))
## [1] 0
acquisition targetacquisitionRetention data set to predict which customers will be acquired (acquisition) and for how long (duration) based on a feature set using a random forest.set.seed(123)
idx.train = sample(1:nrow(data), size = 0.8 * nrow(data))
train.df = data[idx.train,]
test.df = data[-idx.train,]
set.seed(123)
forest1 = rfsrc(acquisition ~ acq_exp + industry + revenue + employees, #Only include variables based on prior analysis.
data = train.df,
importance = TRUE,
ntree = 1000)
forest1
## Sample size: 400
## Frequency of class labels: 126, 274
## Number of trees: 1000
## Forest terminal node size: 1
## Average no. of terminal nodes: 62.454
## No. of variables tried at each split: 2
## Total no. of variables: 4
## Resampling used to grow trees: swor
## Resample size used to grow trees: 253
## Analysis: RF-C
## Family: class
## Splitting rule: gini *random*
## Number of random split points: 10
## (OOB) Normalized Brier score: 57.26487
## (OOB) AUC: 85.5347
## (OOB) Error rate: 0.205, 0.44444444, 0.09489051
##
## Confusion matrix:
##
## predicted
## observed 0 1 class.error
## 0 70 56 0.4444
## 1 26 248 0.0949
##
## Overall (OOB) error rate: 20.5%
acquisition#predicted class labels - used for classification problems
rf1_pred = predict(forest1, test.df)$class
#temporary df to bind prediction to original dataframe
temp1_df = cbind(data,rf1_pred)
#new dataframe to filter the new dataframe for a value of 1, signaling successful customer acquisition
data_new = temp1_df %>% filter(rf1_pred == 1)
summary(data_new)
## customer acquisition duration profit
## Min. : 1.0 0:115 Min. : 0.0 Min. :-1016.2
## 1st Qu.:124.5 1:240 1st Qu.: 0.0 1st Qu.: -273.7
## Median :252.0 Median : 957.0 Median : 3381.7
## Mean :250.1 Mean : 741.8 Mean : 2413.0
## 3rd Qu.:376.0 3rd Qu.:1149.0 3rd Qu.: 3938.5
## Max. :500.0 Max. :1673.0 Max. : 6134.3
## acq_exp ret_exp acq_exp_sq ret_exp_sq
## Min. : 1.21 Min. : 0.0 Min. : 1.5 Min. : 0
## 1st Qu.: 373.73 1st Qu.: 0.0 1st Qu.: 139674.5 1st Qu.: 0
## Median : 482.33 Median : 396.4 Median : 232642.2 Median : 157165
## Mean : 484.71 Mean : 335.9 Mean : 263725.9 Mean : 183315
## 3rd Qu.: 597.66 3rd Qu.: 522.8 3rd Qu.: 357203.5 3rd Qu.: 273279
## Max. :1016.18 Max. :1095.0 Max. :1032621.8 Max. :1198937
## freq freq_sq crossbuy sow
## Min. : 0.000 Min. : 0.00 Min. : 0.00 Min. : 0.00
## 1st Qu.: 0.000 1st Qu.: 0.00 1st Qu.: 0.00 1st Qu.: 0.00
## Median : 6.000 Median : 36.00 Median : 5.00 Median : 43.00
## Mean : 6.259 Mean : 70.23 Mean : 4.09 Mean : 38.42
## 3rd Qu.:11.000 3rd Qu.:121.00 3rd Qu.: 7.00 3rd Qu.: 65.00
## Max. :21.000 Max. :441.00 Max. :11.00 Max. :115.00
## industry revenue employees rf1_pred
## Min. :0.0000 Min. :14.49 Min. : 18.0 0: 0
## 1st Qu.:0.0000 1st Qu.:33.52 1st Qu.: 492.0 1:355
## Median :1.0000 Median :41.78 Median : 654.0
## Mean :0.5155 Mean :40.65 Mean : 668.5
## 3rd Qu.:1.0000 3rd Qu.:47.54 3rd Qu.: 837.5
## Max. :1.0000 Max. :65.10 Max. :1423.0
predicted_duration <- predict(forest1,data)$predicted
pred_data <- cbind(data,predicted_duration)
total_actual_duration <- sum(pred_data$duration)
total_predicted_duration <- sum(pred_data$predicted_duration)
cat("Total Actual Duration = ", total_actual_duration)
## Total Actual Duration = 371227
duration targetset.seed(123)
forest2 = rfsrc(duration~ acq_exp + industry + revenue + employees,
data = data_new,
importance = TRUE,
ntree = 1000)
forest2
## Sample size: 355
## Number of trees: 1000
## Forest terminal node size: 5
## Average no. of terminal nodes: 46.168
## No. of variables tried at each split: 2
## Total no. of variables: 4
## Resampling used to grow trees: swor
## Resample size used to grow trees: 224
## Analysis: RF-R
## Family: regr
## Splitting rule: mse *random*
## Number of random split points: 10
## (OOB) R squared: 0.3104681
## (OOB) Error rate: 203944.8
forest2$importance
## acq_exp industry revenue employees
## 25461.154 9928.581 7935.848 58257.926
data.frame(importance = forest2$importance) %>%
tibble::rownames_to_column(var = "variable") %>%
ggplot(aes(x = reorder(variable,importance), y = importance)) +
geom_bar(stat = "identity", fill = "steelblue", color = "black")+
coord_flip() +
labs(x = "Variables", y = "Variable importance")+
theme_nice
duration: employees, acq_exp, revenue, and industry. It is important to note that when analyzing variable importance, the variable with the highest value is the most important.mindepth = max.subtree(forest2,
sub.order = TRUE)
# first order depths
print(round(mindepth$order, 3)[,1])
## acq_exp industry revenue employees
## 1.398 2.251 1.093 0.671
# visualize minimal depth
data.frame(md = round(mindepth$order, 3)[,1]) %>%
tibble::rownames_to_column(var = "variable") %>%
ggplot(aes(x = reorder(variable,desc(md)), y = md)) +
geom_bar(stat = "identity", fill = "steelblue", color = "black", width = 0.2)+
coord_flip() +
labs(x = "Variables", y = "Minimal Depth")+
theme_nice
duration: employees, acq_exp, revenue, and industry. It is important to note that when analyzing minimal depth, the variable with the least value is the most important. The most important variable is going to split closest to the tree.# min depth
mindepth$sub.order
## acq_exp industry revenue employees
## acq_exp 0.1131823 0.4490810 0.17765787 0.16078760
## industry 0.2335417 0.1788152 0.24009811 0.23241239
## revenue 0.1495405 0.3939603 0.08634847 0.14618496
## employees 0.1344936 0.2746964 0.13511553 0.05449049
as.matrix(mindepth$sub.order) %>%
reshape2::melt() %>%
data.frame() %>%
ggplot(aes(x = Var1, y = Var2, fill = value)) +
scale_x_discrete(position = "top") +
geom_tile(color = "white") +
viridis::scale_fill_viridis("Relative min. depth") +
labs(x = "", y = "") +
theme_bw()
# cross-check with vimp
find.interaction(forest2,
method = "vimp",
importance = "permute")
## Pairing employees with acq_exp
## Pairing employees with industry
## Pairing employees with revenue
## Pairing acq_exp with industry
## Pairing acq_exp with revenue
## Pairing industry with revenue
##
## Method: vimp
## No. of variables: 4
## Variables sorted by VIMP?: TRUE
## No. of variables used for pairing: 4
## Total no. of paired interactions: 6
## Monte Carlo replications: 1
## Type of noising up used for VIMP: permute
##
## Var 1 Var 2 Paired Additive Difference
## employees:acq_exp 58257.926 25461.154 91683.48 83719.08 7964.4042
## employees:industry 58257.926 9928.581 68739.34 68186.51 552.8377
## employees:revenue 58257.926 7935.848 76350.82 66193.77 10157.0484
## acq_exp:industry 25461.154 9928.581 36705.08 35389.74 1315.3427
## acq_exp:revenue 25461.154 7935.848 41240.90 33397.00 7843.8941
## industry:revenue 9928.581 7935.848 21138.88 17864.43 3274.4543
set.seed(123)
idx.train_new = sample(1:nrow(data_new), size = 0.8 * nrow(data_new))
train.df_new = data_new[idx.train_new,]
test.df_new = data_new[-idx.train_new,]
forest.no_interaction.untuned = rfsrc(duration ~ acq_exp + industry + revenue + employees,
data = train.df_new,
importance = TRUE,
ntree = 1000)
# Establish a list of possible values for hyper-parameters
mtry.values <- seq(4,6,1)
nodesize.values <- seq(4,8,2)
ntree.values <- seq(4e3,6e3,1e3)
# Create a data frame containing all combinations
hyper_grid = expand.grid(mtry = mtry.values, nodesize = nodesize.values, ntree = ntree.values)
# Create an empty vector to store OOB error values
oob_err = c()
# Write a loop over the rows of hyper_grid to train the grid of models
for (i in 1:nrow(hyper_grid)) {
# Train a Random Forest model
model = rfsrc(duration ~ acq_exp + industry + revenue + employees,
data = train.df_new,
mtry = hyper_grid$mtry[i],
nodesize = hyper_grid$nodesize[i],
ntree = hyper_grid$ntree[i])
# Store OOB error for the model
oob_err[i] <- model$err.rate[length(model$err.rate)]
}
# Identify optimal set of hyperparmeters based on OOB error
opt_i <- which.min(oob_err)
print(hyper_grid[opt_i,])
## mtry nodesize ntree
## 27 6 8 6000
set.seed(123)
forest.hyper = rfsrc(duration ~ acq_exp + industry + revenue + employees,
data = train.df_new,
mtry = 6,
nodesize = 8,
ntree = 6000)
duration Prediction#Logistic Regression
regression.logistic = glm(duration ~ acq_exp + industry + revenue + employees, data = train.df_new)
#Decision Tree Model
dt.model = rpart(duration ~ acq_exp + industry + revenue + employees,
data = train.df_new)
error.df =
data.frame(pred1 = predict(forest.no_interaction.untuned,newdata = test.df_new)$predicted,
pred2 = predict(forest.hyper, newdata = test.df_new)$predicted,
pred3 = predict(regression.logistic, newdata = test.df_new),
pred4 = predict(dt.model, newdata = test.df_new),
actual = test.df_new$duration,
customer = test.df_new$customer) %>%
mutate_at(.funs = funs(abs.error = abs(actual - .),
abs.percent.error = abs(actual - .)/abs(actual)),
.vars = vars(pred1:pred4))
## Warning: `funs()` is deprecated as of dplyr 0.8.0.
## Please use a list of either functions or lambdas:
##
## # Simple named list:
## list(mean = mean, median = median)
##
## # Auto named with `tibble::lst()`:
## tibble::lst(mean, median)
##
## # Using lambdas
## list(~ mean(., trim = .2), ~ median(., na.rm = TRUE))
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_warnings()` to see where this warning was generated.
#mae
error.df %>%
summarise_at(.funs = funs(mae = mean(.)),
.vars = vars(pred1_abs.error:pred4_abs.error))
error.df2 =
error.df %>%
left_join(test.df_new, "customer") %>%
mutate(customer_portfolio = cut(x = rev <- revenue,
breaks = qu <- quantile(rev, probs = seq(0, 1, 0.25)),
labels = names(qu)[-1],
include.lowest = T))
portfolio.mae =
error.df2 %>%
group_by(customer_portfolio) %>%
summarise_at(.funs = funs(mae = mean(.)),
.vars = vars(pred1_abs.error:pred4_abs.error)) %>%
ungroup()
portfolio.errors =
portfolio.mae %>%
gather(key = error_type, value = error, -customer_portfolio) %>%
mutate(error_type2 = ifelse(grepl(pattern = "mae", error_type),"MAE","MAE"),
model_type = ifelse(grepl(pattern = "pred1", error_type),"Untuned Forest",
ifelse(grepl(pattern = "pred2", error_type),"Tuned Forest",
ifelse(grepl(pattern = "pred3", error_type),"Logistic Regression", "Decision Tree"))))
ggplot(portfolio.errors, aes(x = customer_portfolio,
y = error,
color = model_type,
group = model_type))+
geom_line(size = 1.02)+
geom_point(shape = 15) +
scale_color_brewer(palette = "Set1") +
labs(y = "Error", x = "Customer portfolios")+
theme_nice +
theme(legend.position = "top")+
guides(color = guide_legend(title = "Model Type", size = 4,nrow = 2,byrow = TRUE))
acquisition# Establish a list of possible values for hyper-parameters
mtry.values <- seq(4,6,1)
nodesize.values <- seq(4,8,2)
ntree.values <- seq(4e3,6e3,1e3)
# Create a data frame containing all combinations
hyper_grid = expand.grid(mtry = mtry.values, nodesize = nodesize.values, ntree = ntree.values)
# Create an empty vector to store OOB error values
oob_err = c()
# Write a loop over the rows of hyper_grid to train the grid of models
for (i in 1:nrow(hyper_grid)) {
# Train a Random Forest model
model = rfsrc(acquisition ~ acq_exp + industry + revenue + employees,
data = train.df,
mtry = hyper_grid$mtry[i],
nodesize = hyper_grid$nodesize[i],
ntree = hyper_grid$ntree[i])
# Store OOB error for the model
oob_err[i] <- model$err.rate[length(model$err.rate)]
}
# Identify optimal set of hyperparmeters based on OOB error
opt_i <- which.min(oob_err)
print(hyper_grid[opt_i,])
## mtry nodesize ntree
## 5 5 6 4000
set.seed(123)
forest_acquisition = rfsrc(acquisition ~ acq_exp + industry + revenue + employees,
data = train.df,
mtry = 4,
nodesize = 6,
ntree = 4000)
logistic.regression.acquisition = glm(acquisition ~ acq_exp + industry + revenue + employees, data = train.df, family = "binomial")
decision.tree.acquisition = rpart(acquisition ~ acq_exp + industry + revenue + employees, data = train.df)
pred1_acq = predict(forest_acquisition,newdata = test.df)$class
pred2_acq = predict(logistic.regression.acquisition, newdata = test.df)
pred2_acq = ifelse(pred2_acq > 0.50, 1, 0)
pred3_acq = predict(decision.tree.acquisition, newdata = test.df, type = "class")
confusionMatrix(as.factor(pred1_acq), test.df$acquisition, positive = '1')
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 22 6
## 1 14 58
##
## Accuracy : 0.8
## 95% CI : (0.7082, 0.8733)
## No Information Rate : 0.64
## P-Value [Acc > NIR] : 0.0003862
##
## Kappa : 0.5438
##
## Mcnemar's Test P-Value : 0.1175249
##
## Sensitivity : 0.9062
## Specificity : 0.6111
## Pos Pred Value : 0.8056
## Neg Pred Value : 0.7857
## Prevalence : 0.6400
## Detection Rate : 0.5800
## Detection Prevalence : 0.7200
## Balanced Accuracy : 0.7587
##
## 'Positive' Class : 1
##
confusionMatrix(as.factor(pred2_acq), test.df$acquisition, positive = '1')
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 24 10
## 1 12 54
##
## Accuracy : 0.78
## 95% CI : (0.6861, 0.8567)
## No Information Rate : 0.64
## P-Value [Acc > NIR] : 0.001834
##
## Kappa : 0.5167
##
## Mcnemar's Test P-Value : 0.831170
##
## Sensitivity : 0.8438
## Specificity : 0.6667
## Pos Pred Value : 0.8182
## Neg Pred Value : 0.7059
## Prevalence : 0.6400
## Detection Rate : 0.5400
## Detection Prevalence : 0.6600
## Balanced Accuracy : 0.7552
##
## 'Positive' Class : 1
##
confusionMatrix(as.factor(pred3_acq), test.df$acquisition, positive = '1')
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 24 14
## 1 12 50
##
## Accuracy : 0.74
## 95% CI : (0.6427, 0.8226)
## No Information Rate : 0.64
## P-Value [Acc > NIR] : 0.02196
##
## Kappa : 0.4425
##
## Mcnemar's Test P-Value : 0.84452
##
## Sensitivity : 0.7812
## Specificity : 0.6667
## Pos Pred Value : 0.8065
## Neg Pred Value : 0.6316
## Prevalence : 0.6400
## Detection Rate : 0.5000
## Detection Prevalence : 0.6200
## Balanced Accuracy : 0.7240
##
## 'Positive' Class : 1
##
acquisitionplot.variable(forest_acquisition, partial=TRUE)
durationplot.variable(forest.hyper, partial=TRUE)