Install packages
#install.packages(c("SMCRM","dplyr","tidyr","ggplot2","survival","rpart","rattle","purrr"))
install.packages("randomForestSRC", repos="http://cran.rstudio.com/", dependencies=TRUE)
## package 'randomForestSRC' successfully unpacked and MD5 sums checked
##
## The downloaded binary packages are in
## C:\Users\musta\AppData\Local\Temp\RtmpaKAklc\downloaded_packages
Load Packages
library(SMCRM) # CRM data
library(dplyr) # data wrangling
library(tidyr) # data wrangling
library(ggplot2) # plotting
library(survival) # survival
library(rpart) # DT
library(randomForestSRC) # RF
library(caret) #confusion matrix
library(safeBinaryRegression) #use this to check for perfect separation
# theme for nice plotting
theme_nice <- theme_classic()+
theme(
axis.line.y.left = element_line(colour = "black"),
axis.line.y.right = element_line(colour = "black"),
axis.line.x.bottom = element_line(colour = "black"),
axis.line.x.top = element_line(colour = "black"),
axis.text.y = element_text(colour = "black", size = 12),
axis.text.x = element_text(color = "black", size = 12),
axis.ticks = element_line(color = "black")) +
theme(
axis.ticks.length = unit(-0.25, "cm"),
axis.text.x = element_text(margin=unit(c(0.5,0.5,0.5,0.5), "cm")),
axis.text.y = element_text(margin=unit(c(0.5,0.5,0.5,0.5), "cm")))
Create the dataset
data("acquisitionRetention")
cust_ret<-acquisitionRetention
Review the dataset
str(cust_ret)
## 'data.frame': 500 obs. of 15 variables:
## $ customer : num 1 2 3 4 5 6 7 8 9 10 ...
## $ acquisition: num 1 1 1 0 1 1 1 1 0 0 ...
## $ duration : num 1635 1039 1288 0 1631 ...
## $ profit : num 6134 3524 4081 -638 5446 ...
## $ acq_exp : num 694 460 249 638 589 ...
## $ ret_exp : num 972 450 805 0 920 ...
## $ acq_exp_sq : num 480998 211628 62016 407644 346897 ...
## $ ret_exp_sq : num 943929 202077 648089 0 846106 ...
## $ freq : num 6 11 21 0 2 7 15 13 0 0 ...
## $ freq_sq : num 36 121 441 0 4 49 225 169 0 0 ...
## $ crossbuy : num 5 6 6 0 9 4 5 5 0 0 ...
## $ sow : num 95 22 90 0 80 48 51 23 0 0 ...
## $ industry : num 1 0 0 0 0 1 0 1 0 1 ...
## $ revenue : num 47.2 45.1 29.1 40.6 48.7 ...
## $ employees : num 898 686 1423 181 631 ...
## note: perfect separation concerns with this dataset (several predictors are 0 exactly when acquisition = 0)
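A quick way to see the separation (a sketch, not run here): duration, ret_exp, freq, crossbuy, and sow are all 0 exactly when acquisition = 0, so each one perfectly separates the classes.
#table(cust_ret$acquisition, cust_ret$duration > 0) #: duration > 0 if and only if acquisition = 1
#table(cust_ret$acquisition, cust_ret$freq > 0) #: same pattern for freq, crossbuy, sow, ret_exp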
Review for missing values
cbind(lapply(lapply(cust_ret, is.na), sum)) #: count missing values by column
## [,1]
## customer 0
## acquisition 0
## duration 0
## profit 0
## acq_exp 0
## ret_exp 0
## acq_exp_sq 0
## ret_exp_sq 0
## freq 0
## freq_sq 0
## crossbuy 0
## sow 0
## industry 0
## revenue 0
## employees 0
cust_ret<-na.omit(cust_ret) #: drop rows with missing values; none here, so nothing is removed
Change to factor variables
cust_ret$acquisition <- as.factor(cust_ret$acquisition)
str(cust_ret)
## 'data.frame': 500 obs. of 15 variables:
## $ customer : num 1 2 3 4 5 6 7 8 9 10 ...
## $ acquisition: Factor w/ 2 levels "0","1": 2 2 2 1 2 2 2 2 1 1 ...
## $ duration : num 1635 1039 1288 0 1631 ...
## $ profit : num 6134 3524 4081 -638 5446 ...
## $ acq_exp : num 694 460 249 638 589 ...
## $ ret_exp : num 972 450 805 0 920 ...
## $ acq_exp_sq : num 480998 211628 62016 407644 346897 ...
## $ ret_exp_sq : num 943929 202077 648089 0 846106 ...
## $ freq : num 6 11 21 0 2 7 15 13 0 0 ...
## $ freq_sq : num 36 121 441 0 4 49 225 169 0 0 ...
## $ crossbuy : num 5 6 6 0 9 4 5 5 0 0 ...
## $ sow : num 95 22 90 0 80 48 51 23 0 0 ...
## $ industry : num 1 0 0 0 0 1 0 1 0 1 ...
## $ revenue : num 47.2 45.1 29.1 40.6 48.7 ...
## $ employees : num 898 686 1423 181 631 ...
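industry is also a 0/1 indicator. Leaving it numeric is fine for the forests (a binary numeric variable splits the same way as a two-level factor), but it can be converted in the same way if preferred:
#cust_ret$industry <- as.factor(cust_ret$industry)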
Classification: acquisition Train test split
set.seed(22)
index_train <- sample(1:nrow(cust_ret), size = 0.7 * nrow(cust_ret))
train_df <- cust_ret[index_train,]
test_df <- cust_ret[-index_train,]
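The split above is a simple random sample. Given the class imbalance (roughly 2.5 acquired customers for every non-acquired one), a stratified split is an alternative; a sketch using caret::createDataPartition (caret is already loaded; index_strat and the *_strat names are illustrative):
#set.seed(22)
#index_strat <- caret::createDataPartition(cust_ret$acquisition, p = 0.7, list = FALSE)
#train_strat <- cust_ret[index_strat,]
#test_strat <- cust_ret[-index_strat,]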
Classification: acquisition Build forest on train
forest_acq2, built with all predictors (including those that perfectly separate the classes), produces 0 misclassifications
set.seed(123)
forest_acq2 <- rfsrc(acquisition ~
duration +
profit +
acq_exp +
ret_exp +
freq +
crossbuy +
sow +
industry +
revenue +
employees,
data = train_df,
importance = TRUE,
ntree = 1000)
forest_acq2
## Sample size: 350
## Frequency of class labels: 101, 249
## Number of trees: 1000
## Forest terminal node size: 1
## Average no. of terminal nodes: 2.438
## No. of variables tried at each split: 4
## Total no. of variables: 10
## Resampling used to grow trees: swor
## Resample size used to grow trees: 221
## Analysis: RF-C
## Family: class
## Splitting rule: gini *random*
## Number of random split points: 10
## Imbalanced ratio: 2.4653
## (OOB) Brier score: 8.1e-07
## (OOB) Normalized Brier score: 3.23e-06
## (OOB) AUC: 1
## (OOB) PR-AUC: 1
## (OOB) G-mean: 1
## (OOB) Requested performance error: 0, 0, 0
##
## Confusion matrix:
##
## predicted
## observed 0 1 class.error
## 0 101 0 0
## 1 0 249 0
##
## (OOB) Misclassification rate: 0
# construct linear and non-linear parametric models on training set
#regression_acq <- lm(acquisition ~ duration + profit + acq_exp + ret_exp + freq + crossbuy + sow + industry + revenue + employees,
# data = train_df)
#regression_acq_sq <- lm(acquisition ~ duration + profit + acq_exp + ret_exp + freq + crossbuy + sow + industry + revenue + employees
# + acq_exp_sq + ret_exp_sq + freq_sq,
# data = train_df)
Classification: acquisition Use the safeBinaryRegression package to check for perfect and quasi-complete separation; fitting the glm raises a perfect separation error in R
#glm_perfsep <- glm(acquisition ~ duration + profit + acq_exp + ret_exp + freq + crossbuy + sow + industry + revenue + employees,
# data = train_df, family = binomial)
#glm_perfsep <- glm(acquisition ~ duration + profit + acq_exp + ret_exp + freq + crossbuy + sow + industry + revenue + employees,
# data = train_df, family = binomial, separation = "test")
#stats::glm(acquisition ~ duration + profit + acq_exp + ret_exp + freq + crossbuy + sow + industry + revenue + employees,
# data = train_df, family = binomial)
str(train_df)
## 'data.frame': 350 obs. of 15 variables:
## $ customer : num 486 393 88 330 478 300 380 315 209 357 ...
## $ acquisition: Factor w/ 2 levels "0","1": 1 2 2 2 1 2 1 2 2 1 ...
## $ duration : num 0 938 1030 1147 0 ...
## $ profit : num -473 3645 3786 3512 -754 ...
## $ acq_exp : num 473 649 488 429 754 ...
## $ ret_exp : num 0 413 425 466 0 ...
## $ acq_exp_sq : num 224193 421811 238417 184118 568531 ...
## $ ret_exp_sq : num 0 170883 180651 217492 0 ...
## $ freq : num 0 15 9 8 0 13 0 6 7 0 ...
## $ freq_sq : num 0 225 81 64 0 169 0 36 49 0 ...
## $ crossbuy : num 0 6 9 7 0 8 0 6 4 0 ...
## $ sow : num 0 32 80 60 0 68 0 72 25 0 ...
## $ industry : num 0 0 1 0 0 1 0 1 1 0 ...
## $ revenue : num 24.1 52.2 48.3 55.8 35.5 ...
## $ employees : num 161 746 612 340 359 822 679 732 503 666 ...
Classification: acquisition forest_acq3 with predictors that do not cause perfect separation. Take out the following predictors (each one deterministically reveals acquisition) and grow a forest on the remaining 4 predictors:
duration - if 0 then acquisition = 0
freq - if 0 then acquisition = 0
crossbuy - if 0 then acquisition = 0
profit - if negative then acquisition = 0
sow - if 0 for the firm then acquisition = 0
ret_exp - if 0 then acquisition = 0
set.seed(123)
forest_acq3 <- rfsrc(acquisition ~ acq_exp + industry + revenue + employees, data = train_df,
importance = TRUE, ntree = 1000)
forest_acq3
## Sample size: 350
## Frequency of class labels: 101, 249
## Number of trees: 1000
## Forest terminal node size: 1
## Average no. of terminal nodes: 54.415
## No. of variables tried at each split: 2
## Total no. of variables: 4
## Resampling used to grow trees: swor
## Resample size used to grow trees: 221
## Analysis: RF-C
## Family: class
## Splitting rule: gini *random*
## Number of random split points: 10
## Imbalanced ratio: 2.4653
## (OOB) Brier score: 0.14495895
## (OOB) Normalized Brier score: 0.57983582
## (OOB) AUC: 0.8347648
## (OOB) PR-AUC: 0.66698919
## (OOB) G-mean: 0.69352246
## (OOB) Requested performance error: 0.20857143, 0.47524752, 0.10040161
##
## Confusion matrix:
##
## predicted
## observed 0 1 class.error
## 0 54 47 0.4653
## 1 25 224 0.1004
##
## (OOB) Misclassification rate: 0.2057143
Accuracy (1 - OOB misclassification rate):
1 - 0.2057143
## [1] 0.7942857
Classification: acquisition Tuning the forest hyper-parameters for predictive accuracy
# Establish a list of possible values for hyper-parameters
mtry.values <- seq(2,4,1) ## number of predictors tried at each split, from 2 to 4 in increments of 1
nodesize.values <- seq(2,6,2) ## from 2 to 6 in increments of 2
ntree.values <- seq(4e3,6e3,1e3) ## 4000 to 6000 in increments of 1000
# Create a data frame containing all combinations in hyper_grid object
hyper_grid <- expand.grid(mtry = mtry.values, nodesize = nodesize.values, ntree = ntree.values)
# Create an empty vector to store OOB error values
oob_err <- c()
# Write a loop over the rows of hyper_grid to train the grid of models
for (i in 1:nrow(hyper_grid)) {
# Train a Random Forest model
model <- rfsrc(acquisition ~ acq_exp + industry + revenue + employees,
data = train_df,
mtry = hyper_grid$mtry[i],
nodesize = hyper_grid$nodesize[i],
ntree = hyper_grid$ntree[i])
# Store OOB error for the model
oob_err[i] <- model$err.rate[length(model$err.rate)]
}
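# Caveat: for a classification forest, err.rate is a matrix (overall error
# plus one column per class), so indexing it with length() picks up the final
# class-"1" error rather than the overall OOB error. Assuming column 1 holds
# the overall rate (as in current randomForestSRC builds), a sketch of the
# alternative extraction would be:
# oob_err[i] <- tail(na.omit(model$err.rate[, 1]), 1)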
# Identify optimal set of hyperparameters based on OOB error
opt_i <- which.min(oob_err)
print(hyper_grid[opt_i,])
## mtry nodesize ntree
## 1 2 2 4000
##minimum OOB error at mtry = 2, nodesize = 2, and 4000 trees
Classification: acquisition Rebuild training forest with optimal hyper-params: mtry = 2, nodesize = 2, ntree = 4000
set.seed(123)
forest_acqhyper <- rfsrc(acquisition ~ acq_exp + industry + revenue + employees,
data = train_df,
mtry = 2,
nodesize = 2,
ntree = 4000)
forest_acqhyper
## Sample size: 350
## Frequency of class labels: 101, 249
## Number of trees: 4000
## Forest terminal node size: 2
## Average no. of terminal nodes: 45.375
## No. of variables tried at each split: 2
## Total no. of variables: 4
## Resampling used to grow trees: swor
## Resample size used to grow trees: 221
## Analysis: RF-C
## Family: class
## Splitting rule: gini *random*
## Number of random split points: 10
## Imbalanced ratio: 2.4653
## (OOB) Brier score: 0.14311093
## (OOB) Normalized Brier score: 0.57244374
## (OOB) AUC: 0.83903933
## (OOB) PR-AUC: 0.68099996
## (OOB) G-mean: 0.70939459
## (OOB) Requested performance error: 0.19428571, 0.44554455, 0.09236948
##
## Confusion matrix:
##
## predicted
## observed 0 1 class.error
## 0 56 45 0.4455
## 1 23 226 0.0924
##
## (OOB) Misclassification rate: 0.1942857
Accuracy (1 - OOB misclassification rate):
1-0.1942857
## [1] 0.8057143
Classification: acquisition Predict on the test set
Forest_acq3 (no optimal hyperparameters)
forest_acq3_predict <- predict.rfsrc(forest_acq3, newdata = test_df)
forest_acq3_predict
## Sample size of test (predict) data: 150
## Number of grow trees: 1000
## Average no. of grow terminal nodes: 54.415
## Total no. of grow variables: 4
## Resampling used to grow trees: swor
## Resample size used to grow trees: 95
## Analysis: RF-C
## Family: class
## Imbalanced ratio: 1.459
## Brier score: 0.15900063
## Normalized Brier score: 0.63600251
## AUC: 0.87953583
## PR-AUC: 0.81952264
## G-mean: 0.66902634
## Requested performance error: 0.26, 0.50819672, 0.08988764
##
## Confusion matrix:
##
## predicted
## observed 0 1 class.error
## 0 30 31 0.5082
## 1 8 81 0.0899
##
## Misclassification error: 0.26
Classification: acquisition forest_acqhyper (optimal hyperparameters: mtry = 2, nodesize = 2, ntree = 4000) is slightly better than forest_acq3 on the training set. forest_acq3 OOB accuracy: 79.43. forest_acqhyper OOB accuracy: 80.57.
forest_acqhyper_predict <- predict.rfsrc(forest_acqhyper, newdata = test_df)
forest_acqhyper_predict
## Sample size of test (predict) data: 150
## Number of grow trees: 4000
## Average no. of grow terminal nodes: 45.375
## Total no. of grow variables: 4
## Resampling used to grow trees: swor
## Resample size used to grow trees: 95
## Analysis: RF-C
## Family: class
## Imbalanced ratio: 1.459
## Brier score: 0.15419742
## Normalized Brier score: 0.61678969
## AUC: 0.88285135
## PR-AUC: 0.82715656
## G-mean: 0.67314346
## Requested performance error: 0.25333333, 0.50819672, 0.07865169
##
## Confusion matrix:
##
## predicted
## observed 0 1 class.error
## 0 30 31 0.5082
## 1 7 82 0.0787
##
## Misclassification error: 0.2533333
Accuracy (1 - test misclassification rate):
1-0.253333
## [1] 0.746667
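caret was loaded for confusion matrices, so the table above can be cross-checked (and accuracy, sensitivity, and specificity reported in one call) with a sketch like:
#caret::confusionMatrix(data = forest_acqhyper_predict$class,
#                       reference = test_df$acquisition,
#                       positive = "1")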
Classification: acquisition forest_acqhyper (optimal hyperparameters: mtry = 2, nodesize = 2, ntree = 4000) remains slightly better than forest_acq3. Accuracy summary:
forest_acq3 (training set, OOB): 79.43
forest_acqhyper (training set, OOB): 80.57
forest_acqhyper (test set): 74.67
dt_fit (unpruned decision tree, test set): 75.33
dt_pfit (pruned decision tree, test set): 80.67
New_df <- cbind(test_df, forest_acqhyper_predict$class)
New_df <- select(New_df,customer, acquisition, `forest_acqhyper_predict$class`, acq_exp, industry, revenue,
employees, duration, profit, ret_exp, ret_exp_sq, acq_exp_sq, freq, freq_sq, crossbuy,
sow)
str(New_df)
## 'data.frame': 150 obs. of 16 variables:
## $ customer : num 3 5 7 9 10 11 13 15 18 22 ...
## $ acquisition : Factor w/ 2 levels "0","1": 2 2 2 1 1 2 2 1 2 1 ...
## $ forest_acqhyper_predict$class: Factor w/ 2 levels "0","1": 2 2 2 1 2 2 2 2 2 1 ...
## $ acq_exp : num 249 589 373 285 293 ...
## $ industry : num 0 0 0 0 1 0 0 1 1 0 ...
## $ revenue : num 29.1 48.7 51 29.3 42.1 ...
## $ employees : num 1423 631 911 530 511 ...
## $ duration : num 1288 1631 799 0 0 ...
## $ profit : num 4081 5446 2705 -285 -293 ...
## $ ret_exp : num 805 920 341 0 0 ...
## $ ret_exp_sq : num 648089 846106 116568 0 0 ...
## $ acq_exp_sq : num 62016 346897 139241 81202 85580 ...
## $ freq : num 21 2 15 0 0 10 4 0 11 0 ...
## $ freq_sq : num 441 4 225 0 0 100 16 0 121 0 ...
## $ crossbuy : num 6 9 5 0 0 4 6 0 7 0 ...
## $ sow : num 90 80 51 0 0 107 40 0 63 0 ...
Regression: Duration Identify the observations that the optimal forest classification model got wrong: set keep = 1 if acquisition == forest_acqhyper_predict$class, else 0
keep = ifelse(New_df$acquisition==New_df$`forest_acqhyper_predict$class`,1,0)
table(keep) ## 38 misclassifications, 112 correct classifications; this matches the confusion matrix for forest_acqhyper_predict
## keep
## 0 1
## 38 112
Add the ‘keep’ column to the new data set
New_df <- cbind(New_df,keep)
str(New_df)
## 'data.frame': 150 obs. of 17 variables:
## $ customer : num 3 5 7 9 10 11 13 15 18 22 ...
## $ acquisition : Factor w/ 2 levels "0","1": 2 2 2 1 1 2 2 1 2 1 ...
## $ forest_acqhyper_predict$class: Factor w/ 2 levels "0","1": 2 2 2 1 2 2 2 2 2 1 ...
## $ acq_exp : num 249 589 373 285 293 ...
## $ industry : num 0 0 0 0 1 0 0 1 1 0 ...
## $ revenue : num 29.1 48.7 51 29.3 42.1 ...
## $ employees : num 1423 631 911 530 511 ...
## $ duration : num 1288 1631 799 0 0 ...
## $ profit : num 4081 5446 2705 -285 -293 ...
## $ ret_exp : num 805 920 341 0 0 ...
## $ ret_exp_sq : num 648089 846106 116568 0 0 ...
## $ acq_exp_sq : num 62016 346897 139241 81202 85580 ...
## $ freq : num 21 2 15 0 0 10 4 0 11 0 ...
## $ freq_sq : num 441 4 225 0 0 100 16 0 121 0 ...
## $ crossbuy : num 6 9 5 0 0 4 6 0 7 0 ...
## $ sow : num 90 80 51 0 0 107 40 0 63 0 ...
## $ keep : num 1 1 1 1 0 1 1 0 1 1 ...
Reorder the columns in the New_df to make it easier to read
New_df <- select(New_df,customer, keep, acquisition, 'forest_acqhyper_predict$class', acq_exp, industry, revenue,
employees, duration, profit, ret_exp, ret_exp_sq, acq_exp_sq, freq, freq_sq, crossbuy,
sow)
str(New_df)
## 'data.frame': 150 obs. of 17 variables:
## $ customer : num 3 5 7 9 10 11 13 15 18 22 ...
## $ keep : num 1 1 1 1 0 1 1 0 1 1 ...
## $ acquisition : Factor w/ 2 levels "0","1": 2 2 2 1 1 2 2 1 2 1 ...
## $ forest_acqhyper_predict$class: Factor w/ 2 levels "0","1": 2 2 2 1 2 2 2 2 2 1 ...
## $ acq_exp : num 249 589 373 285 293 ...
## $ industry : num 0 0 0 0 1 0 0 1 1 0 ...
## $ revenue : num 29.1 48.7 51 29.3 42.1 ...
## $ employees : num 1423 631 911 530 511 ...
## $ duration : num 1288 1631 799 0 0 ...
## $ profit : num 4081 5446 2705 -285 -293 ...
## $ ret_exp : num 805 920 341 0 0 ...
## $ ret_exp_sq : num 648089 846106 116568 0 0 ...
## $ acq_exp_sq : num 62016 346897 139241 81202 85580 ...
## $ freq : num 21 2 15 0 0 10 4 0 11 0 ...
## $ freq_sq : num 441 4 225 0 0 100 16 0 121 0 ...
## $ crossbuy : num 6 9 5 0 0 4 6 0 7 0 ...
## $ sow : num 90 80 51 0 0 107 40 0 63 0 ...
Regression: Duration Delete the rows where keep = 0; these are the rows where the optimal forest classification model predicted the wrong class. New_df now has 112 rows, matching the 112 correct classifications made by the optimal forest classification model (forest_acqhyper)
New_df <- subset(New_df, keep!=0)
str(New_df)
## 'data.frame': 112 obs. of 17 variables:
## $ customer : num 3 5 7 9 11 13 18 22 23 33 ...
## $ keep : num 1 1 1 1 1 1 1 1 1 1 ...
## $ acquisition : Factor w/ 2 levels "0","1": 2 2 2 1 2 2 2 1 2 2 ...
## $ forest_acqhyper_predict$class: Factor w/ 2 levels "0","1": 2 2 2 1 2 2 2 1 2 2 ...
## $ acq_exp : num 249 589 373 285 568 ...
## $ industry : num 0 0 0 0 0 0 1 0 1 1 ...
## $ revenue : num 29.1 48.7 51 29.3 43.3 ...
## $ employees : num 1423 631 911 530 926 ...
## $ duration : num 1288 1631 799 0 1009 ...
## $ profit : num 4081 5446 2705 -285 3837 ...
## $ ret_exp : num 805 920 341 0 422 ...
## $ ret_exp_sq : num 648089 846106 116568 0 178312 ...
## $ acq_exp_sq : num 62016 346897 139241 81202 322806 ...
## $ freq : num 21 2 15 0 10 4 11 0 11 14 ...
## $ freq_sq : num 441 4 225 0 100 16 121 0 121 196 ...
## $ crossbuy : num 6 9 5 0 4 6 7 0 4 5 ...
## $ sow : num 90 80 51 0 107 40 63 0 67 66 ...
Regression: Duration Add the New_df dataset back to the original dataset cust_ret. Why? Using New_df alone leaves only 112 observations to train and test on; appending it to the original dataset gives more observations to work with. This creates duplicates, which we eliminate next. My interpretation of the problem is to predict the classification (acquisition) first and then use the accurate predictions to train a regression model for duration. cust_ret: 500 obs, 15 vars. New_df: 112 obs, 17 vars. cust_ret2: 612 obs, 17 vars (with duplicates)
cust_ret2 <- bind_rows(cust_ret, New_df)
str(cust_ret2)
## 'data.frame': 612 obs. of 17 variables:
## $ customer : num 1 2 3 4 5 6 7 8 9 10 ...
## $ acquisition : Factor w/ 2 levels "0","1": 2 2 2 1 2 2 2 2 1 1 ...
## $ duration : num 1635 1039 1288 0 1631 ...
## $ profit : num 6134 3524 4081 -638 5446 ...
## $ acq_exp : num 694 460 249 638 589 ...
## $ ret_exp : num 972 450 805 0 920 ...
## $ acq_exp_sq : num 480998 211628 62016 407644 346897 ...
## $ ret_exp_sq : num 943929 202077 648089 0 846106 ...
## $ freq : num 6 11 21 0 2 7 15 13 0 0 ...
## $ freq_sq : num 36 121 441 0 4 49 225 169 0 0 ...
## $ crossbuy : num 5 6 6 0 9 4 5 5 0 0 ...
## $ sow : num 95 22 90 0 80 48 51 23 0 0 ...
## $ industry : num 1 0 0 0 0 1 0 1 0 1 ...
## $ revenue : num 47.2 45.1 29.1 40.6 48.7 ...
## $ employees : num 898 686 1423 181 631 ...
## $ keep : num NA NA NA NA NA NA NA NA NA NA ...
## $ forest_acqhyper_predict$class: Factor w/ 2 levels "0","1": NA NA NA NA NA NA NA NA NA NA ...
Eliminate duplicates in the new dataset cust_ret2
cust_ret2 <- distinct(cust_ret2, customer, .keep_all = TRUE)
str(cust_ret2)
## 'data.frame': 500 obs. of 17 variables:
## $ customer : num 1 2 3 4 5 6 7 8 9 10 ...
## $ acquisition : Factor w/ 2 levels "0","1": 2 2 2 1 2 2 2 2 1 1 ...
## $ duration : num 1635 1039 1288 0 1631 ...
## $ profit : num 6134 3524 4081 -638 5446 ...
## $ acq_exp : num 694 460 249 638 589 ...
## $ ret_exp : num 972 450 805 0 920 ...
## $ acq_exp_sq : num 480998 211628 62016 407644 346897 ...
## $ ret_exp_sq : num 943929 202077 648089 0 846106 ...
## $ freq : num 6 11 21 0 2 7 15 13 0 0 ...
## $ freq_sq : num 36 121 441 0 4 49 225 169 0 0 ...
## $ crossbuy : num 5 6 6 0 9 4 5 5 0 0 ...
## $ sow : num 95 22 90 0 80 48 51 23 0 0 ...
## $ industry : num 1 0 0 0 0 1 0 1 0 1 ...
## $ revenue : num 47.2 45.1 29.1 40.6 48.7 ...
## $ employees : num 898 686 1423 181 631 ...
## $ keep : num NA NA NA NA NA NA NA NA NA NA ...
## $ forest_acqhyper_predict$class: Factor w/ 2 levels "0","1": NA NA NA NA NA NA NA NA NA NA ...
Remove the ‘keep’ and ‘forest_acqhyper_predict$class’ columns, no longer needed
cust_ret2 <- select(cust_ret2,-keep, -'forest_acqhyper_predict$class')
str(cust_ret2)
## 'data.frame': 500 obs. of 15 variables:
## $ customer : num 1 2 3 4 5 6 7 8 9 10 ...
## $ acquisition: Factor w/ 2 levels "0","1": 2 2 2 1 2 2 2 2 1 1 ...
## $ duration : num 1635 1039 1288 0 1631 ...
## $ profit : num 6134 3524 4081 -638 5446 ...
## $ acq_exp : num 694 460 249 638 589 ...
## $ ret_exp : num 972 450 805 0 920 ...
## $ acq_exp_sq : num 480998 211628 62016 407644 346897 ...
## $ ret_exp_sq : num 943929 202077 648089 0 846106 ...
## $ freq : num 6 11 21 0 2 7 15 13 0 0 ...
## $ freq_sq : num 36 121 441 0 4 49 225 169 0 0 ...
## $ crossbuy : num 5 6 6 0 9 4 5 5 0 0 ...
## $ sow : num 95 22 90 0 80 48 51 23 0 0 ...
## $ industry : num 1 0 0 0 0 1 0 1 0 1 ...
## $ revenue : num 47.2 45.1 29.1 40.6 48.7 ...
## $ employees : num 898 686 1423 181 631 ...
Regression: Duration We are now ready for the second part of the random forest problem: RF regression for duration using the cust_ret2 dataset. The intent is for this dataset to carry the 112 test-set customers that the classification forest classified correctly, together with the training observations. Note, however, that distinct() keeps the first occurrence of each customer, so the 112 appended rows were dropped as duplicates and cust_ret2 contains the same 500 observations as the original cust_ret.
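If the goal is strictly the training rows plus the correctly classified test rows (462 observations), a sketch that builds that set directly (cust_ret_alt is an illustrative name):
#cust_ret_alt <- bind_rows(train_df,
#                          select(New_df, -keep, -'forest_acqhyper_predict$class'))
#nrow(cust_ret_alt) #: 350 + 112 = 462, no duplicates by construction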
Create a new train and test dataset to model duration
set.seed(22)
index_train2 <- sample(1:nrow(cust_ret2), size = 0.7 * nrow(cust_ret2))
train_df2 <- cust_ret2[index_train2,]
test_df2 <- cust_ret2[-index_train2,]
Regression: Duration Create a random forest for the duration target/outcome variable. Exclude: customer (just an id), acquisition (perfect separation issue), and profit (perfect separation issue; if negative, the customer was never acquired and duration = 0). The model is still "too good" (R squared: 0.9925). Run on the full dataset.
set.seed(22)
forest_dur1 <- rfsrc(duration ~ acq_exp + ret_exp + freq + crossbuy + sow + industry
+ revenue + employees,
data = cust_ret2,
importance = TRUE,
ntree = 1000)
forest_dur1
## Sample size: 500
## Number of trees: 1000
## Forest terminal node size: 5
## Average no. of terminal nodes: 43.334
## No. of variables tried at each split: 3
## Total no. of variables: 8
## Resampling used to grow trees: swor
## Resample size used to grow trees: 316
## Analysis: RF-R
## Family: regr
## Splitting rule: mse *random*
## Number of random split points: 10
## (OOB) R squared: 0.99250203
## (OOB) Requested performance error: 2222.77170247
Regression: Duration Forest inference: variable importance (five different tools for examining the relationship between the DV and the IVs). (Code adopted from class lecture)
forest_dur1$importance # values vary a lot
## acq_exp ret_exp freq crossbuy sow industry
## 2537.4467 238395.5197 130500.7750 247432.5515 64404.7632 24.2627
## revenue employees
## 3478.6648 3836.9254
Height of the bar indicates importance of variable
data.frame(importance = forest_dur1$importance) %>%
tibble::rownames_to_column(var = "variable") %>%
ggplot(aes(x = reorder(variable,importance), y = importance)) +
geom_bar(stat = "identity", fill = "orange", color = "black")+
coord_flip() +
labs(x = "Variables", y = "Variable importance")+
theme_nice
Regression: Duration It is difficult to see industry, acq_exp, and revenue, so log-transform the importance values. Importance rank (best to worst): 1. crossbuy, 2. ret_exp, 3. freq, 4. sow, 5. employees, 6. revenue, 7. acq_exp, 8. industry
(Code adopted from class lecture)
forest_dur1$importance %>% log() # log transform
## acq_exp ret_exp freq crossbuy sow industry revenue employees
## 7.838914 12.381686 11.779134 12.418893 11.072943 3.188940 8.154404 8.252427
data.frame(importance = forest_dur1$importance + 100) %>% # add a large +ve constant
log() %>%
tibble::rownames_to_column(var = "variable") %>%
ggplot(aes(x = reorder(variable,importance), y = importance)) +
geom_bar(stat = "identity", fill = "orange", color = "black", width = 0.5)+
coord_flip() +
labs(x = "Variables", y = "Log-transformed variable importance") +
theme_nice
Regression: Duration Minimal depth, via max.subtree. The variable with the smallest value is the most important. Minimal depth rank (best to worst): 1. ret_exp, 2. freq, 3. crossbuy, 4. sow, 5. employees, 6. revenue, 7. acq_exp, 8. industry. Note: the order of 1-3 changed compared to the importance metric; 4-8 stayed the same.
Based on the differences displayed in the pairs table below, there appear to be interactions among pairs of variables. (Code adopted from class lecture)
mindepth <- max.subtree(forest_dur1,
sub.order = TRUE)
# first order depths
print(round(mindepth$order, 3)[,1])
## acq_exp ret_exp freq crossbuy sow industry revenue employees
## 3.963 1.546 1.862 2.135 3.126 8.112 3.813 3.594
# visualize MD
data.frame(md = round(mindepth$order, 3)[,1]) %>%
tibble::rownames_to_column(var = "variable") %>%
ggplot(aes(x = reorder(variable,desc(md)), y = md)) +
geom_bar(stat = "identity", fill = "orange", color = "black", width = 0.2)+
coord_flip() +
labs(x = "Variables", y = "Minimal Depth")+
theme_nice
# interactions
mindepth$sub.order
## acq_exp ret_exp freq crossbuy sow industry revenue
## acq_exp 0.3477729 0.3903966 0.5571006 0.6977877 0.6151290 0.8953952 0.6171253
## ret_exp 0.3314307 0.1336721 0.2281840 0.3560787 0.3244922 0.7658361 0.3303055
## freq 0.3899769 0.2003319 0.1634163 0.4215394 0.3845516 0.7670012 0.3905414
## crossbuy 0.5140994 0.3098250 0.4153235 0.1890967 0.5074607 0.8153121 0.5035354
## sow 0.5663338 0.3508932 0.4984378 0.6266180 0.2760075 0.8683308 0.5772579
## industry 0.8854014 0.8207331 0.8810272 0.9103646 0.8786026 0.7040051 0.8877031
## revenue 0.6149121 0.3730532 0.5379177 0.6608750 0.6028117 0.8854155 0.3374083
## employees 0.5677346 0.3318404 0.4969750 0.6345589 0.5484385 0.8752124 0.5655392
## employees
## acq_exp 0.6054124
## ret_exp 0.3235390
## freq 0.3820103
## crossbuy 0.4933546
## sow 0.5777772
## industry 0.8872569
## revenue 0.5965840
## employees 0.3148966
as.matrix(mindepth$sub.order) %>%
reshape2::melt() %>%
data.frame() %>%
ggplot(aes(x = Var1, y = Var2, fill = value)) +
scale_x_discrete(position = "top") +
geom_tile(color = "white") +
viridis::scale_fill_viridis("Relative min. depth") +
labs(x = "", y = "") +
theme_bw()
# cross-check with vimp
find.interaction(forest_dur1,
method = "vimp",
importance = "permute")
## Pairing crossbuy with ret_exp
## Pairing crossbuy with freq
## Pairing crossbuy with sow
## Pairing crossbuy with employees
## Pairing crossbuy with revenue
## Pairing crossbuy with acq_exp
## Pairing crossbuy with industry
## Pairing ret_exp with freq
## Pairing ret_exp with sow
## Pairing ret_exp with employees
## Pairing ret_exp with revenue
## Pairing ret_exp with acq_exp
## Pairing ret_exp with industry
## Pairing freq with sow
## Pairing freq with employees
## Pairing freq with revenue
## Pairing freq with acq_exp
## Pairing freq with industry
## Pairing sow with employees
## Pairing sow with revenue
## Pairing sow with acq_exp
## Pairing sow with industry
## Pairing employees with revenue
## Pairing employees with acq_exp
## Pairing employees with industry
## Pairing revenue with acq_exp
## Pairing revenue with industry
## Pairing acq_exp with industry
##
## Method: vimp
## No. of variables: 8
## Variables sorted by VIMP?: TRUE
## No. of variables used for pairing: 8
## Total no. of paired interactions: 28
## Monte Carlo replications: 1
## Type of noising up used for VIMP: permute
##
## Var 1 Var 2 Paired Additive Difference
## crossbuy:ret_exp 82477.5224 81620.9510 210448.0661 164098.4734 46349.5928
## crossbuy:freq 82477.5224 46490.6628 154157.5864 128968.1853 25189.4012
## crossbuy:sow 82477.5224 28872.4080 130247.7395 111349.9304 18897.8091
## crossbuy:employees 82477.5224 214.5745 81150.4949 82692.0969 -1541.6020
## crossbuy:revenue 82477.5224 54.7363 82962.5599 82532.2587 430.3012
## crossbuy:acq_exp 82477.5224 86.7756 82942.6887 82564.2980 378.3907
## crossbuy:industry 82477.5224 37.2134 82456.1726 82514.7358 -58.5632
## ret_exp:freq 82019.3400 45891.5615 152079.5746 127910.9014 24168.6732
## ret_exp:sow 82019.3400 27966.7328 128433.1038 109986.0728 18447.0310
## ret_exp:employees 82019.3400 220.2717 81712.6764 82239.6117 -526.9353
## ret_exp:revenue 82019.3400 73.6300 81589.1584 82092.9699 -503.8115
## ret_exp:acq_exp 82019.3400 73.3167 81898.6757 82092.6566 -193.9809
## ret_exp:industry 82019.3400 31.8219 81898.0582 82051.1618 -153.1037
## freq:sow 45144.8145 28314.1148 82487.9165 73458.9293 9028.9871
## freq:employees 45144.8145 213.2832 45396.4305 45358.0977 38.3328
## freq:revenue 45144.8145 94.8675 45485.2214 45239.6820 245.5393
## freq:acq_exp 45144.8145 69.0119 45382.9202 45213.8264 169.0937
## freq:industry 45144.8145 29.2923 45375.1318 45174.1069 201.0249
## sow:employees 27369.1608 210.8855 27249.5684 27580.0463 -330.4779
## sow:revenue 27369.1608 86.0883 27219.9296 27455.2491 -235.3195
## sow:acq_exp 27369.1608 58.5137 27693.0374 27427.6745 265.3630
## sow:industry 27369.1608 29.5542 27512.4573 27398.7150 113.7423
## employees:revenue 206.7848 23.6434 228.7077 230.4282 -1.7205
## employees:acq_exp 206.7848 113.6859 269.3018 320.4708 -51.1690
## employees:industry 206.7848 22.5259 227.8345 229.3108 -1.4763
## revenue:acq_exp 80.1638 63.0323 178.1673 143.1960 34.9713
## revenue:industry 80.1638 18.6218 27.4150 98.7856 -71.3706
## acq_exp:industry 31.1908 -4.6724 80.4609 26.5184 53.9425
Regression: Duration Partial dependence. First, create an lm regression model; 4 variables are significant (ret_exp, crossbuy, sow, revenue)
regression_dur <- lm(duration ~ acq_exp + ret_exp + freq + crossbuy + sow + industry
+ revenue + employees,
data = cust_ret2)
summary(regression_dur)
##
## Call:
## lm(formula = duration ~ acq_exp + ret_exp + freq + crossbuy +
## sow + industry + revenue + employees, data = cust_ret2)
##
## Residuals:
## Min 1Q Median 3Q Max
## -402.26 -35.42 -14.47 50.38 237.87
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) -6.95216 21.50045 -0.323 0.7466
## acq_exp -0.01829 0.02171 -0.843 0.3998
## ret_exp 1.60185 0.02421 66.170 <2e-16 ***
## freq -1.83486 0.98960 -1.854 0.0643 .
## crossbuy 20.93405 2.00224 10.455 <2e-16 ***
## sow 2.32652 0.19139 12.156 <2e-16 ***
## industry 14.03595 7.47271 1.878 0.0609 .
## revenue 0.78699 0.38575 2.040 0.0419 *
## employees 0.02486 0.01704 1.459 0.1453
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 80.62 on 491 degrees of freedom
## Multiple R-squared: 0.9784, Adjusted R-squared: 0.9781
## F-statistic: 2784 on 8 and 491 DF, p-value: < 2.2e-16
Regression: Duration Add the non-linear (squared) variables to a second regression model. All non-linear terms are significant, and each squared term has the opposite sign of its linear term
# regression with non-linear specification similar to publication
regression_dur2 <- lm(duration ~ acq_exp + acq_exp_sq + ret_exp + ret_exp_sq + freq + freq_sq + crossbuy + sow + industry
+ revenue + employees,
data = cust_ret2)
summary(regression_dur2)
##
## Call:
## lm(formula = duration ~ acq_exp + acq_exp_sq + ret_exp + ret_exp_sq +
## freq + freq_sq + crossbuy + sow + industry + revenue + employees,
## data = cust_ret2)
##
## Residuals:
## Min 1Q Median 3Q Max
## -54.590 -9.202 -0.942 8.460 79.204
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 3.018e+01 6.630e+00 4.553 6.69e-06 ***
## acq_exp -5.253e-02 2.190e-02 -2.399 0.01681 *
## acq_exp_sq 6.886e-05 2.123e-05 3.244 0.00126 **
## ret_exp 2.793e+00 1.774e-02 157.505 < 2e-16 ***
## ret_exp_sq -1.224e-03 1.590e-05 -76.956 < 2e-16 ***
## freq 1.065e+01 7.492e-01 14.213 < 2e-16 ***
## freq_sq -1.016e+00 3.780e-02 -26.864 < 2e-16 ***
## crossbuy 4.572e+00 4.601e-01 9.937 < 2e-16 ***
## sow 4.897e-01 4.467e-02 10.963 < 2e-16 ***
## industry -8.777e+00 1.639e+00 -5.355 1.32e-07 ***
## revenue -2.214e-01 8.298e-02 -2.668 0.00789 **
## employees -2.164e-02 3.697e-03 -5.852 8.90e-09 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 17.2 on 488 degrees of freedom
## Multiple R-squared: 0.999, Adjusted R-squared: 0.999
## F-statistic: 4.541e+04 on 11 and 488 DF, p-value: < 2.2e-16
Regression: Duration Train forest on train_df2 dataset
(code adopted from class)
set.seed(22)
forest_dur2 <- rfsrc(duration ~ acq_exp + ret_exp + freq + crossbuy + sow + industry
+ revenue + employees,
data = train_df2,
importance = TRUE,
ntree = 1000)
forest_dur2
## Sample size: 350
## Number of trees: 1000
## Forest terminal node size: 5
## Average no. of terminal nodes: 32.441
## No. of variables tried at each split: 3
## Total no. of variables: 8
## Resampling used to grow trees: swor
## Resample size used to grow trees: 221
## Analysis: RF-R
## Family: regr
## Splitting rule: mse *random*
## Number of random split points: 10
## (OOB) R squared: 0.98911507
## (OOB) Requested performance error: 3115.69143878
# construct linear and non-linear parametric models on training set
regression_linear_dur <- lm(duration ~ acq_exp + ret_exp + freq + crossbuy + sow + industry
+ revenue + employees,
data = train_df2)
regression_nlinear_dur <- lm(duration ~ acq_exp + ret_exp + freq + crossbuy + sow + industry
+ revenue + employees + acq_exp_sq + ret_exp_sq + freq_sq,
data = train_df2)
regression_nlinearint_dur <- lm(duration ~ acq_exp + ret_exp + freq + crossbuy + sow + industry
+ revenue + employees + acq_exp_sq + ret_exp_sq + freq_sq + ret_exp * crossbuy,
data = train_df2)
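Before predicting on the test set, the three nested parametric specifications can be compared in-sample; a quick sketch (lower AIC is better):
#AIC(regression_linear_dur, regression_nlinear_dur, regression_nlinearint_dur)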
Regression: Duration OOB error rates
(code adopted from class)
forest_dur2$err.rate[length(forest_dur2$err.rate)]
## [1] 3115.691
# plot the OOB error rate
data.frame(err.rate = forest_dur2$err.rate) %>%
na.omit() %>%
tibble::rownames_to_column(var = "trees") %>%
mutate(trees = as.numeric(trees)) %>%
ggplot(aes(x = trees, y = err.rate, group = 1))+
geom_line()+
scale_x_continuous(breaks = seq(0,1050,100))+
labs(x = "Number of trees", y = "OOB Error rate")+
theme_nice
Regression: Duration Tuning the forest hyper-parameters for predictive accuracy. Optimal hyperparameters: mtry = 7, nodesize = 4, ntree = 6000
(code adopted from class)
# Establish a list of possible values for hyper-parameters
mtry.values2 <- seq(4,8,1)
nodesize.values2 <- seq(4,10,2)
ntree.values2 <- seq(4e3,6e3,1e3)
# Create a data frame containing all combinations
hyper_grid2 <- expand.grid(mtry = mtry.values2, nodesize = nodesize.values2, ntree = ntree.values2)
# Create an empty vector to store OOB error values
oob_err2 <- c()
# Write a loop over the rows of hyper_grid to train the grid of models
for (i in 1:nrow(hyper_grid2)) {
# Train a Random Forest model
model <- rfsrc(duration ~ acq_exp + ret_exp + freq + crossbuy + sow + industry
+ revenue + employees,
data = train_df2,
mtry = hyper_grid2$mtry[i],
nodesize = hyper_grid2$nodesize[i],
ntree = hyper_grid2$ntree[i])
# Store OOB error for the model
oob_err2[i] <- model$err.rate[length(model$err.rate)]
}
# Identify optimal set of hyperparameters based on OOB error
opt_i2 <- which.min(oob_err2)
print(hyper_grid2[opt_i2,])
## mtry nodesize ntree
## 44 7 4 6000
Regression: Duration Rebuild the training forest with the optimal hyper-params: mtry = 7, nodesize = 4, ntree = 6000
(code adopted from class)
set.seed(22)
forest_hyperdur <- rfsrc(duration ~ acq_exp + ret_exp + freq + crossbuy + sow + industry
+ revenue + employees,
data = train_df2,
mtry = 7,
nodesize = 4,
ntree = 6000)
Regression: Duration Predict on the test dataset test_df2
(code adopted from class)
error.df <-
data.frame(pred1 = predict.rfsrc(forest_dur2,newdata = test_df2)$predicted,
pred2 = predict.rfsrc(forest_hyperdur, newdata = test_df2)$predicted,
pred3 = predict(regression_linear_dur, newdata = test_df2),
pred4 = predict(regression_nlinear_dur, newdata = test_df2),
pred5 = predict(regression_nlinearint_dur, newdata = test_df2),
actual = test_df2$duration,
customer = test_df2$customer) %>%
mutate_at(.funs = funs(abs.error = abs(actual - .),
abs.percent.error = abs(actual - .)/abs(actual)),
.vars = vars(pred1:pred5))
## Warning: `funs()` was deprecated in dplyr 0.8.0.
## Please use a list of either functions or lambdas:
##
## # Simple named list:
## list(mean = mean, median = median)
##
## # Auto named with `tibble::lst()`:
## tibble::lst(mean, median)
##
## # Using lambdas
## list(~ mean(., trim = .2), ~ median(., na.rm = TRUE))
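A modern equivalent of the deprecated funs() call, sketched with dplyr::across() (requires dplyr >= 1.0.0):
#error.df <- error.df %>%
#  mutate(across(pred1:pred5,
#                list(abs.error = ~ abs(actual - .x),
#                     abs.percent.error = ~ abs(actual - .x) / abs(actual)),
#                .names = "{.col}_{.fn}"))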
#mae
error.df %>%
summarise_at(.funs = funs(mae = mean(.)),
.vars = vars(pred1_abs.error:pred5_abs.error))
## pred1_abs.error_mae pred2_abs.error_mae pred3_abs.error_mae
## 1 19.91001 14.00363 55.64751
## pred4_abs.error_mae pred5_abs.error_mae
## 1 11.63739 11.5004
#mape
error.df %>%
summarise_at(.funs = funs(mape = mean(.*100)),
.vars = vars(pred1_abs.percent.error:pred5_abs.percent.error))
## pred1_abs.percent.error_mape pred2_abs.percent.error_mape
## 1 NaN NaN
## pred3_abs.percent.error_mape pred4_abs.percent.error_mape
## 1 Inf Inf
## pred5_abs.percent.error_mape
## 1 Inf
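The NaN and Inf values arise because duration is 0 for customers who were never acquired, so the percent error divides by zero. A sketch of the same MAPE summary restricted to acquired customers (actual > 0):
#error.df %>%
#  filter(actual > 0) %>%
#  summarise_at(.funs = funs(mape = mean(.*100)),
#               .vars = vars(pred1_abs.percent.error:pred5_abs.percent.error))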
# errors by customer portfolio (revenue quartiles)
error.df2 <-
error.df %>%
left_join(test_df2, "customer") %>%
mutate(customer_portfolio = cut(x = rev <- revenue,
breaks = qu <- quantile(rev, probs = seq(0, 1, 0.25)),
labels = names(qu)[-1],
include.lowest = T))
portfolio.mae <-
error.df2 %>%
group_by(customer_portfolio) %>%
summarise_at(.funs = funs(mae = mean(.)),
.vars = vars(pred1_abs.error:pred5_abs.error)) %>%
ungroup()
portfolio.mape <-
error.df2 %>%
group_by(customer_portfolio) %>%
summarise_at(.funs = funs(mape = mean(.*100)),
.vars = vars(pred1_abs.percent.error:pred5_abs.percent.error)) %>%
ungroup()
portfolio.errors <-
portfolio.mae %>%
left_join(portfolio.mape, "customer_portfolio") %>%
gather(key = error_type, value = error, -customer_portfolio) %>%
mutate(error_type2 = ifelse(grepl(pattern = "mae", error_type),"MAE","MAPE"),
model_type = ifelse(grepl(pattern = "pred1", error_type),"Untuned Forest",
ifelse(grepl(pattern = "pred2", error_type),"Tuned Forest",
ifelse(grepl(pattern = "pred3", error_type),"Linear Model",
ifelse(grepl(pattern = "pred4", error_type),"Non-linear Model","Non-linear w interaction")))),
model_type_reordered = factor(model_type, levels = c("Linear Model","Non-linear Model","Non-linear w interaction","Untuned Forest","Tuned Forest")))
ggplot(portfolio.errors, aes(x = customer_portfolio,
y = error,
color = model_type_reordered,
group = model_type_reordered))+
geom_line(size = 1.02)+
geom_point(shape = 15) +
facet_wrap(~error_type2, scales = "free_y")+
scale_color_brewer(palette = "Set2") +
labs(y = "Error", x = "Customer portfolios")+
theme_nice +
theme(legend.position = "top")+
guides(color = guide_legend(title = "Model Type", size = 4,nrow = 2,byrow = TRUE))
## Warning: Removed 4 row(s) containing missing values (geom_path).
## Warning: Removed 7 rows containing missing values (geom_point).
error.df2 %>%
group_by(customer_portfolio) %>%
summarise(mean_retention_expense = mean(ret_exp),
sum_retention_expense = sum(ret_exp))
## # A tibble: 4 x 3
## customer_portfolio mean_retention_expense sum_retention_expense
## <fct> <dbl> <dbl>
## 1 25% 195. 7395.
## 2 50% 266. 9830.
## 3 75% 317. 11711.
## 4 100% 375. 14242.
MAE on the test set:
untuned forest: 19.91
tuned forest: 14.00
linear: 55.65
non-linear (squared variables acq_exp_sq, ret_exp_sq, freq_sq): 11.64
non-linear with interaction (ret_exp * crossbuy): 11.50
error.df %>%
summarise_at(.funs = funs(mae = mean(.)),
.vars = vars(pred1_abs.error:pred5_abs.error))
## pred1_abs.error_mae pred2_abs.error_mae pred3_abs.error_mae
## 1 19.91001 14.00363 55.64751
## pred4_abs.error_mae pred5_abs.error_mae
## 1 11.63739 11.5004
Regression: Duration Partial dependence PDP plot (code adopted from class)
# inspect relationship of ret_exp with predicted duration with PDP
min(forest_dur1$xvar$ret_exp)
## [1] 0
max(forest_dur1$xvar$ret_exp)
## [1] 1094.96
ret_exp_seq = seq(0,145,5)
Regression: Duration Partial dependence PDP plot (code adopted from class)
# extract marginal effect using partial dependence
marginal.effect <- partial(forest_dur1,
partial.xvar = "ret_exp",
partial.values = ret_exp_seq)
means.exp <- marginal.effect$regrOutput$duration %>% colMeans()
Regression: Duration Partial dependence PDP plot (code adopted from class)
marginal.effect.df <-
data.frame(pred.duration = means.exp, ret_exp_seq = ret_exp_seq)
Regression: Duration Partial dependence PDP plot (code adopted from class)
ggplot(marginal.effect.df, aes(x = ret_exp_seq, y = pred.duration)) +
geom_point(shape = 21, color = "purple", size = 2, stroke = 1.2)+
geom_smooth(method = "lm", formula = y ~ poly(x,3), se = FALSE, color = "black")+ # try with other values
labs(x = "Retention in $", y = "Predicted duration") +
scale_x_continuous(breaks = seq(0,150,25))+
theme_nice # the positive effect of ret_exp suggested by the regression coefficients is not clear here
Regression: Duration Partial dependence PDP plot (code adopted from class)
# first check distribution of actual duration and ret_exp
ggplot(cust_ret, aes(x = ret_exp, y = duration)) +
geom_point(shape = 21, col = "purple", size = 3) +
stat_smooth(method = "lm", se = FALSE, color = "black") +
scale_x_continuous(breaks = seq(0,2000,100)) +
scale_y_continuous(breaks = seq(0,2000,200)) +
geom_rug(sides = "b", col = "red", alpha = 0.2) +
labs(y = "Actual duration", x = "Retention in $") +
theme_nice
Regression: Duration Partial dependence PDP plot (code adopted from class)
# repeat over the full range of ret_exp (fewer, more spread-out grid points)
ret_exp_seq2 = seq(0,2000,200)
marginal.effect.new <- partial(forest_dur1,
partial.xvar = "ret_exp",
partial.values = ret_exp_seq2)
means.exp.new <- marginal.effect.new$regrOutput$duration %>% colMeans()
marginal.effect.df.new <-
data.frame(pred.duration = means.exp.new, ret_exp_seq = ret_exp_seq2)
ggplot(marginal.effect.df.new, aes(x = ret_exp_seq, y = pred.duration)) +
geom_point(shape = 21, color = "purple", size = 2, stroke = 1.2)+
geom_path()+
labs(x = "Retention in $", y = "Predicted duration") +
scale_x_continuous(breaks = seq(0,2000,250))+
theme_nice
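The same partial-dependence pattern works for any predictor; a sketch for employees (emp_seq and me_emp are illustrative names):
#emp_seq <- seq(min(forest_dur1$xvar$employees), max(forest_dur1$xvar$employees), length.out = 25)
#me_emp <- partial(forest_dur1, partial.xvar = "employees", partial.values = emp_seq)
#data.frame(pred.duration = colMeans(me_emp$regrOutput$duration), employees = emp_seq) %>%
#  ggplot(aes(x = employees, y = pred.duration)) +
#  geom_point() + geom_path() +
#  labs(x = "Employees", y = "Predicted duration") + theme_nice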
Generate PDP plots for all variables. Random forest model (RF) for duration (not tuned)
plot.variable(forest_dur2, partial=TRUE)
Generate PDP plots for all variables. Random forest model (RF) for duration (tuned)
plot.variable(forest_hyperdur, partial=TRUE)
# Fit Decision Tree
dt.fit<-rpart(acquisition ~ acq_exp + industry + revenue + employees, data = train_df)
# plot tree
rattle::fancyRpartPlot(dt.fit, sub = "Decision Tree to Predict Customer Acquisition")
# display the results
printcp(dt.fit)
##
## Classification tree:
## rpart(formula = acquisition ~ acq_exp + industry + revenue +
## employees, data = train_df)
##
## Variables actually used in tree construction:
## [1] acq_exp employees industry revenue
##
## Root node error: 101/350 = 0.28857
##
## n= 350
##
## CP nsplit rel error xerror xstd
## 1 0.133663 0 1.00000 1.00000 0.083928
## 2 0.049505 2 0.73267 0.90099 0.081248
## 3 0.026403 3 0.68317 0.87129 0.080359
## 4 0.019802 6 0.60396 0.92079 0.081819
## 5 0.010000 8 0.56436 0.95050 0.082641
# detailed summary of splits
summary(dt.fit)
## Call:
## rpart(formula = acquisition ~ acq_exp + industry + revenue +
## employees, data = train_df)
## n= 350
##
## CP nsplit rel error xerror xstd
## 1 0.13366337 0 1.0000000 1.0000000 0.08392763
## 2 0.04950495 2 0.7326733 0.9009901 0.08124847
## 3 0.02640264 3 0.6831683 0.8712871 0.08035938
## 4 0.01980198 6 0.6039604 0.9207921 0.08181871
## 5 0.01000000 8 0.5643564 0.9504950 0.08264129
##
## Variable importance
## employees acq_exp revenue industry
## 52 26 18 5
##
## Node number 1: 350 observations, complexity param=0.1336634
## predicted class=1 expected loss=0.2885714 P(node) =1
## class counts: 101 249
## probabilities: 0.289 0.711
## left son=2 (145 obs) right son=3 (205 obs)
## Primary splits:
## employees < 611 to the left, improve=25.889900, (0 missing)
## industry < 0.5 to the left, improve=10.780300, (0 missing)
## acq_exp < 750.61 to the right, improve= 9.080248, (0 missing)
## revenue < 38.21 to the left, improve= 7.637515, (0 missing)
## Surrogate splits:
## acq_exp < 658.535 to the right, agree=0.606, adj=0.048, (0 split)
## revenue < 34.12 to the left, agree=0.606, adj=0.048, (0 split)
##
## Node number 2: 145 observations, complexity param=0.1336634
## predicted class=0 expected loss=0.4827586 P(node) =0.4142857
## class counts: 75 70
## probabilities: 0.517 0.483
## left son=4 (73 obs) right son=5 (72 obs)
## Primary splits:
## revenue < 41.01 to the left, improve=8.268055, (0 missing)
## employees < 427.5 to the left, improve=7.668977, (0 missing)
## industry < 0.5 to the left, improve=6.434745, (0 missing)
## acq_exp < 750.61 to the right, improve=4.085588, (0 missing)
## Surrogate splits:
## acq_exp < 686.965 to the right, agree=0.552, adj=0.097, (0 split)
## employees < 427.5 to the left, agree=0.545, adj=0.083, (0 split)
## industry < 0.5 to the left, agree=0.538, adj=0.069, (0 split)
##
## Node number 3: 205 observations, complexity param=0.01980198
## predicted class=1 expected loss=0.1268293 P(node) =0.5857143
## class counts: 26 179
## probabilities: 0.127 0.873
## left son=6 (8 obs) right son=7 (197 obs)
## Primary splits:
## acq_exp < 183.57 to the left, improve=4.132035, (0 missing)
## industry < 0.5 to the left, improve=3.991039, (0 missing)
## employees < 872.5 to the left, improve=2.177643, (0 missing)
## revenue < 27.94 to the left, improve=1.495427, (0 missing)
##
## Node number 4: 73 observations, complexity param=0.02640264
## predicted class=0 expected loss=0.3150685 P(node) =0.2085714
## class counts: 50 23
## probabilities: 0.685 0.315
## left son=8 (38 obs) right son=9 (35 obs)
## Primary splits:
## industry < 0.5 to the left, improve=2.714368, (0 missing)
## acq_exp < 750.61 to the right, improve=2.571365, (0 missing)
## employees < 367.5 to the left, improve=2.261235, (0 missing)
## revenue < 22.655 to the left, improve=1.537152, (0 missing)
## Surrogate splits:
## acq_exp < 690.71 to the left, agree=0.603, adj=0.171, (0 split)
## revenue < 26.835 to the left, agree=0.589, adj=0.143, (0 split)
## employees < 393 to the right, agree=0.575, adj=0.114, (0 split)
##
## Node number 5: 72 observations, complexity param=0.04950495
## predicted class=1 expected loss=0.3472222 P(node) =0.2057143
## class counts: 25 47
## probabilities: 0.347 0.653
## left son=10 (25 obs) right son=11 (47 obs)
## Primary splits:
## employees < 425.5 to the left, improve=4.894208, (0 missing)
## industry < 0.5 to the left, improve=2.688889, (0 missing)
## revenue < 58.995 to the left, improve=1.869658, (0 missing)
## acq_exp < 660.695 to the right, improve=1.747263, (0 missing)
## Surrogate splits:
## acq_exp < 649.955 to the right, agree=0.667, adj=0.04, (0 split)
## revenue < 60.835 to the right, agree=0.667, adj=0.04, (0 split)
##
## Node number 6: 8 observations
## predicted class=0 expected loss=0.375 P(node) =0.02285714
## class counts: 5 3
## probabilities: 0.625 0.375
##
## Node number 7: 197 observations
## predicted class=1 expected loss=0.106599 P(node) =0.5628571
## class counts: 21 176
## probabilities: 0.107 0.893
##
## Node number 8: 38 observations
## predicted class=0 expected loss=0.1842105 P(node) =0.1085714
## class counts: 31 7
## probabilities: 0.816 0.184
##
## Node number 9: 35 observations, complexity param=0.02640264
## predicted class=0 expected loss=0.4571429 P(node) =0.1
## class counts: 19 16
## probabilities: 0.543 0.457
## left son=18 (7 obs) right son=19 (28 obs)
## Primary splits:
## acq_exp < 757.085 to the right, improve=3.657143, (0 missing)
## revenue < 34.255 to the right, improve=2.119680, (0 missing)
## employees < 367.5 to the left, improve=1.851429, (0 missing)
## Surrogate splits:
## revenue < 37.42 to the right, agree=0.829, adj=0.143, (0 split)
##
## Node number 10: 25 observations, complexity param=0.01980198
## predicted class=0 expected loss=0.4 P(node) =0.07142857
## class counts: 15 10
## probabilities: 0.600 0.400
## left son=20 (7 obs) right son=21 (18 obs)
## Primary splits:
## acq_exp < 660.695 to the right, improve=3.1111110, (0 missing)
## industry < 0.5 to the left, improve=1.0384620, (0 missing)
## revenue < 47.05 to the left, improve=1.0384620, (0 missing)
## employees < 323 to the left, improve=0.8888889, (0 missing)
## Surrogate splits:
## employees < 323 to the left, agree=0.76, adj=0.143, (0 split)
##
## Node number 11: 47 observations
## predicted class=1 expected loss=0.212766 P(node) =0.1342857
## class counts: 10 37
## probabilities: 0.213 0.787
##
## Node number 18: 7 observations
## predicted class=0 expected loss=0 P(node) =0.02
## class counts: 7 0
## probabilities: 1.000 0.000
##
## Node number 19: 28 observations, complexity param=0.02640264
## predicted class=1 expected loss=0.4285714 P(node) =0.08
## class counts: 12 16
## probabilities: 0.429 0.571
## left son=38 (8 obs) right son=39 (20 obs)
## Primary splits:
## acq_exp < 332.67 to the left, improve=2.3142860, (0 missing)
## employees < 427.5 to the left, improve=1.6937730, (0 missing)
## revenue < 34.255 to the right, improve=0.8642857, (0 missing)
## Surrogate splits:
## revenue < 24.03 to the left, agree=0.75, adj=0.125, (0 split)
##
## Node number 20: 7 observations
## predicted class=0 expected loss=0 P(node) =0.02
## class counts: 7 0
## probabilities: 1.000 0.000
##
## Node number 21: 18 observations
## predicted class=1 expected loss=0.4444444 P(node) =0.05142857
## class counts: 8 10
## probabilities: 0.444 0.556
##
## Node number 38: 8 observations
## predicted class=0 expected loss=0.25 P(node) =0.02285714
## class counts: 6 2
## probabilities: 0.750 0.250
##
## Node number 39: 20 observations
## predicted class=1 expected loss=0.3 P(node) =0.05714286
## class counts: 6 14
## probabilities: 0.300 0.700
# visualize cross validation results
plotcp(dt.fit)
set.seed(123)
# Get predictions on the test set
dt.pred = predict(dt.fit, newdata = test_df, type = "class")
# Get the confusion matrix
dt.pred.confusion <- table(dt.pred, test_df$acquisition)
dt.pred.confusion
##
## dt.pred 0 1
## 0 26 2
## 1 35 87
# Get the accuracy of the tree model
dt.pred.accuracy <- sum(diag(dt.pred.confusion))/sum(dt.pred.confusion)
dt.pred.accuracy
## [1] 0.7533333
# determine where to cut the tree
dt.fit$cptable[which.min(dt.fit$cptable[,"xerror"]),"CP"]
## [1] 0.02640264
# prune the tree to prevent overfitting
dt.pfit<- prune(dt.fit, cp = dt.fit$cptable[which.min(dt.fit$cptable[,"xerror"]),"CP"])
# show results of pruned tree
summary(dt.pfit)
## Call:
## rpart(formula = acquisition ~ acq_exp + industry + revenue +
## employees, data = train_df)
## n= 350
##
## CP nsplit rel error xerror xstd
## 1 0.13366337 0 1.0000000 1.0000000 0.08392763
## 2 0.04950495 2 0.7326733 0.9009901 0.08124847
## 3 0.02640264 3 0.6831683 0.8712871 0.08035938
##
## Variable importance
## employees revenue acq_exp industry
## 72 22 5 1
##
## Node number 1: 350 observations, complexity param=0.1336634
## predicted class=1 expected loss=0.2885714 P(node) =1
## class counts: 101 249
## probabilities: 0.289 0.711
## left son=2 (145 obs) right son=3 (205 obs)
## Primary splits:
## employees < 611 to the left, improve=25.889900, (0 missing)
## industry < 0.5 to the left, improve=10.780300, (0 missing)
## acq_exp < 750.61 to the right, improve= 9.080248, (0 missing)
## revenue < 38.21 to the left, improve= 7.637515, (0 missing)
## Surrogate splits:
## acq_exp < 658.535 to the right, agree=0.606, adj=0.048, (0 split)
## revenue < 34.12 to the left, agree=0.606, adj=0.048, (0 split)
##
## Node number 2: 145 observations, complexity param=0.1336634
## predicted class=0 expected loss=0.4827586 P(node) =0.4142857
## class counts: 75 70
## probabilities: 0.517 0.483
## left son=4 (73 obs) right son=5 (72 obs)
## Primary splits:
## revenue < 41.01 to the left, improve=8.268055, (0 missing)
## employees < 427.5 to the left, improve=7.668977, (0 missing)
## industry < 0.5 to the left, improve=6.434745, (0 missing)
## acq_exp < 750.61 to the right, improve=4.085588, (0 missing)
## Surrogate splits:
## acq_exp < 686.965 to the right, agree=0.552, adj=0.097, (0 split)
## employees < 427.5 to the left, agree=0.545, adj=0.083, (0 split)
## industry < 0.5 to the left, agree=0.538, adj=0.069, (0 split)
##
## Node number 3: 205 observations
## predicted class=1 expected loss=0.1268293 P(node) =0.5857143
## class counts: 26 179
## probabilities: 0.127 0.873
##
## Node number 4: 73 observations
## predicted class=0 expected loss=0.3150685 P(node) =0.2085714
## class counts: 50 23
## probabilities: 0.685 0.315
##
## Node number 5: 72 observations, complexity param=0.04950495
## predicted class=1 expected loss=0.3472222 P(node) =0.2057143
## class counts: 25 47
## probabilities: 0.347 0.653
## left son=10 (25 obs) right son=11 (47 obs)
## Primary splits:
## employees < 425.5 to the left, improve=4.894208, (0 missing)
## industry < 0.5 to the left, improve=2.688889, (0 missing)
## revenue < 58.995 to the left, improve=1.869658, (0 missing)
## acq_exp < 660.695 to the right, improve=1.747263, (0 missing)
## Surrogate splits:
## acq_exp < 649.955 to the right, agree=0.667, adj=0.04, (0 split)
## revenue < 60.835 to the right, agree=0.667, adj=0.04, (0 split)
##
## Node number 10: 25 observations
## predicted class=0 expected loss=0.4 P(node) =0.07142857
## class counts: 15 10
## probabilities: 0.600 0.400
##
## Node number 11: 47 observations
## predicted class=1 expected loss=0.212766 P(node) =0.1342857
## class counts: 10 37
## probabilities: 0.213 0.787
# plot pruned results
rattle::fancyRpartPlot(dt.pfit, sub = "Pruned Decision Tree to Predict Customer Acquisition")
# Get predictions on the test set
dt.ppred = predict(dt.pfit, newdata = test_df, type = "class")
# Get the confusion matrix on the pruned tree
dt.ppred.confusion <- table(dt.ppred, test_df$acquisition)
dt.ppred.confusion
##
## dt.ppred 0 1
## 0 38 6
## 1 23 83
# Get the accuracy of the pruned tree model
dt.ppred.accuracy <- sum(diag(dt.ppred.confusion))/sum(dt.ppred.confusion)
dt.ppred.accuracy
## [1] 0.8066667