Introduction
At Regork, there is growing concern about the number of customers who are leaving the company. If the company cannot retain its customers, it will suffer substantial losses in monthly revenue.
To determine what keeps customers with Regork, we investigated the most important determinants of customer retention. We trained several machine learning models (logistic regression, decision trees, bagged trees, and a random forest) to identify the predictor variables most strongly associated with customer status.
This analysis can substantially benefit the company, potentially preventing upwards of $36,000 in lost revenue per month. Our recommendation is to encourage new customers to sign longer contracts and to make sure customers are satisfied with the company’s services.
Data Preparation & Exploratory Data Analysis
library(tidymodels)
library(tidyverse)
library(baguette)
library(vip)
library(pdp)
library(readr)
library(ggplot2)
library(rpart.plot)
library(ranger)
library(gmodels)
library(ROCR)
library(VIM)
library(corrplot)
library(randomForest)
library(caTools)
# Load the customer retention data once, with readr (returns a tibble).
# The file was previously also read with base read.csv(), whose result was
# immediately overwritten by this call. That redundant second read is gone.
cr1 <- read_csv("customer_retention.csv")

# Treat the outcome as a factor and keep complete cases only, so that every
# model below sees the same rows.
cr <- cr1 %>%
  mutate(Status = factor(Status)) %>%
  na.omit()
# Outcome balance: how many customers are Current vs Left, as counts and
# proportions.
cr %>%
  count(Status) %>%
  mutate(prop = n / sum(n))
## # A tibble: 2 × 3
## Status n prop
## <fct> <int> <dbl>
## 1 Current 5132 0.734
## 2 Left 1856 0.266

# The same breakdown as a bar chart, each bar labelled with its percentage.
cr %>%
  count(Status, name = "Number") %>%
  mutate(Percent = Number / sum(Number) * 100) %>%
  ggplot(aes(x = Status, y = Percent)) +
  geom_col(aes(fill = Status)) +
  labs(title = "Percentage by Status") +
  theme(plot.title = element_text(hjust = 0.5)) +
  geom_text(aes(label = sprintf("%.1f%%", Percent)))
# Boxplots of charges and tenure by phone service, contract type and Status.
# grid.arrange() lives in gridExtra, which none of the library() calls at the
# top attach. It is therefore called with an explicit namespace; a bare
# grid.arrange() fails with "could not find function".

# Large red point overlaid on a boxplot to mark the group mean (the box's
# centre line is the median).
mean_marker <- function() {
  stat_summary(fun = mean, geom = "point", shape = 20, size = 8,
               color = "red", fill = "red")
}

p1 <- ggplot(data = cr,
             aes(x = PhoneService, y = MonthlyCharges, fill = PhoneService)) +
  geom_boxplot()
p2 <- ggplot(data = cr,
             aes(x = PhoneService, y = TotalCharges, fill = PhoneService)) +
  geom_boxplot()
gridExtra::grid.arrange(p1, p2, ncol = 2)

p3 <- ggplot(data = cr,
             aes(x = Contract, y = MonthlyCharges, fill = Contract)) +
  geom_boxplot() +
  mean_marker()
p4 <- ggplot(data = cr,
             aes(x = Contract, y = TotalCharges, fill = Contract)) +
  geom_boxplot() +
  mean_marker()
p5 <- ggplot(data = cr,
             aes(x = Contract, y = Tenure, fill = Contract)) +
  geom_boxplot() +
  mean_marker()
gridExtra::grid.arrange(p3, p4, p5, ncol = 2)

# Status was already converted to a factor on load, so no factor() wrapper is
# needed. This also gives the x axis the plain label "Status".
p6 <- ggplot(data = cr, aes(x = Status, y = Tenure, fill = Status)) +
  geom_boxplot()
p7 <- ggplot(data = cr, aes(x = Status, y = MonthlyCharges, fill = Status)) +
  geom_boxplot()
p8 <- ggplot(data = cr, aes(x = Status, y = TotalCharges, fill = Status)) +
  geom_boxplot()
gridExtra::grid.arrange(p6, p7, p8, ncol = 2)
# Shared styling for the "share of each Status within a category" bar charts:
# no plot background, grid lines or panel border, a centred title, and
# Helvetica text at `size` points.
status_fill_theme <- function(size) {
  theme(
    plot.background = element_blank(),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    panel.border = element_blank(),
    plot.title = element_text(hjust = 0.5),
    text = element_text(size = size, family = "Helvetica")
  )
}

# Stacked 100% bar chart of Status (Current = blue, Left = red) within each
# level of `var`.
#   data      - data frame containing `var` and the factor column Status
#   var       - column or expression to put on the x axis (tidy-evaluated)
#   title     - plot title
#   text_size - base text size passed to status_fill_theme()
#   x_label   - x-axis label (blank by default)
#   flip      - if TRUE, draw horizontal bars via coord_flip()
# Returns the ggplot object; at top level it is auto-printed.
plot_status_share <- function(data, var, title, text_size = 14,
                              x_label = " ", flip = FALSE) {
  p <- ggplot(data = data, aes(x = {{ var }}, fill = Status)) +
    geom_bar(position = "fill") +
    scale_fill_manual(values = c("blue", "red")) +
    status_fill_theme(text_size) +
    labs(x = x_label, title = title)
  if (flip) {
    p <- p + coord_flip()
  }
  p
}

plot_status_share(cr, Tenure, "Status by Tenure")
plot_status_share(cr, Contract, "Status by Contract type", flip = TRUE)
plot_status_share(cr, factor(SeniorCitizen), "Non-Senior vs Senior citizen")
plot_status_share(cr, PaperlessBilling, "Status by having paperless billing")
plot_status_share(cr, PaymentMethod, "Status by Payment Method",
                  text_size = 12, x_label = "", flip = TRUE)
plot_status_share(cr, Partner, "Status by having a partner", text_size = 12)
plot_status_share(cr, Dependents, "Status by having dependents",
                  text_size = 12, flip = TRUE)
plot_status_share(cr, InternetService, "Status by Internet service",
                  text_size = 12, flip = TRUE)
# Share of each Status within the levels of the four add-on service columns,
# laid out in a 2x2 grid. gridExtra is not attached by the library() calls at
# the top, so grid.arrange() is namespace-qualified. Status is already a
# factor, so it is mapped directly.
p9 <- ggplot(data = cr, aes(x = OnlineSecurity, fill = Status)) +
  geom_bar(position = "fill") +
  coord_flip() +
  labs(title = "Online security")
p10 <- ggplot(data = cr, aes(x = OnlineBackup, fill = Status)) +
  geom_bar(position = "fill") +
  coord_flip() +
  labs(title = "Online backup")
p11 <- ggplot(data = cr, aes(x = DeviceProtection, fill = Status)) +
  geom_bar(position = "fill") +
  coord_flip() +
  labs(title = "Device protection")
p12 <- ggplot(data = cr, aes(x = TechSupport, fill = Status)) +
  geom_bar(position = "fill") +
  coord_flip() +
  labs(title = "Tech support")
gridExtra::grid.arrange(p9, p10, p11, p12, nrow = 2)
# Tenure against monthly charges, coloured by Status, with a separate linear
# trend line per Status group.
ggplot(cr, aes(x = MonthlyCharges, y = Tenure, color = factor(Status))) +
  geom_point(alpha = 0.5) +
  geom_smooth(method = lm) +
  labs(title = " ")

# Tenure against total charges on a log10 x axis (total charges are strongly
# right-skewed).
ggplot(cr, aes(x = TotalCharges, y = Tenure, color = factor(Status))) +
  geom_point(alpha = 0.5) +
  geom_smooth(method = lm) +
  scale_x_log10() +
  labs(title = "Log transformation")
# Reproducible 70/30 train/test split, stratified on Status.
set.seed(123)
cr_split <- initial_split(cr, prop = 0.7, strata = "Status")
cr_train <- training(cr_split)
cr_test <- testing(cr_split)

# Both partitions should keep the overall ~73/27 class balance.
count(cr_train, Status) %>%
  mutate(prop = n / sum(n))
## # A tibble: 2 × 3
## Status n prop
## <fct> <int> <dbl>
## 1 Current 3592 0.734
## 2 Left 1299 0.266
count(cr_test, Status) %>%
  mutate(prop = n / sum(n))
## # A tibble: 2 × 3
## Status n prop
## <fct> <int> <dbl>
## 1 Current 1540 0.734
## 2 Left 557 0.266
# Estimate out-of-sample performance of a main-effects logistic regression
# (all predictors) with 5-fold cross-validation on the training split.
set.seed(123)
kfold <- vfold_cv(cr_train, v = 5)
lr_mod <- logistic_reg()
results <- lr_mod %>%
  fit_resamples(preprocessor = Status ~ ., resamples = kfold)

# Mean accuracy across the 5 folds.
collect_metrics(results) %>% filter(.metric == "accuracy")
## # A tibble: 1 × 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 accuracy binary 0.797 5 0.00705 Preprocessor1_Model1

# Mean ROC AUC across the 5 folds.
show_best(results, metric = "roc_auc")
## # A tibble: 1 × 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 roc_auc binary 0.844 5 0.00974 Preprocessor1_Model1
# Fit the full logistic regression on the whole training split.
final_fit <- fit(lr_mod, Status ~ ., data = cr_train)

# Coefficient table (log-odds scale) with standard errors and p-values.
tidy(final_fit)
## # A tibble: 31 × 5
## term estimate std.error statistic p.value
## <chr> <dbl> <dbl> <dbl> <dbl>
## 1 (Intercept) 1.02 0.969 1.05 2.95e- 1
## 2 GenderMale -0.0461 0.0779 -0.592 5.54e- 1
## 3 SeniorCitizen 0.258 0.101 2.55 1.08e- 2
## 4 PartnerYes -0.141 0.0925 -1.52 1.29e- 1
## 5 DependentsYes -0.0475 0.108 -0.441 6.59e- 1
## 6 Tenure -0.0646 0.00762 -8.47 2.35e-17
## 7 PhoneServiceYes -0.0417 0.774 -0.0539 9.57e- 1
## 8 MultipleLinesNo phone service NA NA NA NA
## 9 MultipleLinesYes 0.442 0.211 2.10 3.56e- 2
## 10 InternetServiceFiber optic 1.48 0.950 1.56 1.18e- 1
## # … with 21 more rows

# Odds ratios from the underlying glm object. The NA entries are terms glm
# could not estimate (aliased). Presumably the "No phone service" / "No
# internet service" levels duplicate PhoneService / InternetService; verify.
final_fit$fit %>%
  coef() %>%
  exp()
## (Intercept) GenderMale
## 2.7593894 0.9549166
## SeniorCitizen PartnerYes
## 1.2949103 0.8688686
## DependentsYes Tenure
## 0.9536487 0.9374419
## PhoneServiceYes MultipleLinesNo phone service
## 0.9591253 NA
## MultipleLinesYes InternetServiceFiber optic
## 1.5562922 4.4143178
## InternetServiceNo OnlineSecurityNo internet service
## 0.2214383 NA
## OnlineSecurityYes OnlineBackupNo internet service
## 0.7591602 NA
## OnlineBackupYes DeviceProtectionNo internet service
## 0.9028811 NA
## DeviceProtectionYes TechSupportNo internet service
## 1.1801326 NA
## TechSupportYes StreamingTVNo internet service
## 0.8583097 NA
## StreamingTVYes StreamingMoviesNo internet service
## 1.6073481 NA
## StreamingMoviesYes ContractOne year
## 1.7542279 0.4452648
## ContractTwo year PaperlessBillingYes
## 0.2497494 1.3698551
## PaymentMethodCredit card (automatic) PaymentMethodElectronic check
## 1.0223826 1.3813743
## PaymentMethodMailed check MonthlyCharges
## 0.9410159 0.9689548
## TotalCharges
## 1.0003906

# Confusion matrix of class predictions on the held-out test split.
predict(final_fit, cr_test) %>%
  bind_cols(select(cr_test, Status)) %>%
  conf_mat(truth = Status, estimate = .pred_class)
## Truth
## Prediction Current Left
## Current 1362 225
## Left 178 332

# Variable-importance plot for the fitted model.
vip::vip(final_fit)
# Single-predictor model: probability of leaving as a function of
# MonthlyCharges only.
lr_fit1 <- fit(lr_mod, Status ~ MonthlyCharges, data = cr_train)

# Predicted class probabilities for each training row.
predict(lr_fit1, cr_train, type = "prob")
## # A tibble: 4,891 × 2
## .pred_Current .pred_Left
## <dbl> <dbl>
## 1 0.836 0.164
## 2 0.768 0.232
## 3 0.807 0.193
## 4 0.836 0.164
## 5 0.787 0.213
## 6 0.604 0.396
## 7 0.574 0.426
## 8 0.855 0.145
## 9 0.599 0.401
## 10 0.661 0.339
## # … with 4,881 more rows
Machine Learning
Logistic Regression
# Fitted probability of leaving vs MonthlyCharges for the single-predictor
# model, one point per training row.
p1 <- predict(lr_fit1, cr_train, type = "prob") %>%
  mutate(MonthlyCharges = cr_train$MonthlyCharges) %>%
  ggplot(aes(x = MonthlyCharges, y = .pred_Left)) +
  geom_point(alpha = .2, color = "orange") +
  scale_y_continuous("Probability of Leaving", limits = c(0, 1)) +
  ggtitle("Predicted Probabilities Regarding Monthly Charges") +
  theme_classic()
p1

# Odds ratios: multiplicative change in the odds of leaving per unit of
# MonthlyCharges.
lr_fit1$fit %>%
  coef() %>%
  exp()
## (Intercept) MonthlyCharges
## 0.1216563 1.0160933
# Two-predictor model: monthly charges plus the senior-citizen indicator.
lr_fit2 <- lr_mod %>%
  fit(Status ~ MonthlyCharges + SeniorCitizen, data = cr_train)
lr_fit2 %>% predict(cr_train, type = "prob")
## # A tibble: 4,891 × 2
## .pred_Current .pred_Left
## <dbl> <dbl>
## 1 0.843 0.157
## 2 0.785 0.215
## 3 0.818 0.182
## 4 0.843 0.157
## 5 0.801 0.199
## 6 0.647 0.353
## 7 0.622 0.378
## 8 0.859 0.141
## 9 0.643 0.357
## 10 0.696 0.304
## # … with 4,881 more rows

# SeniorCitizen is stored as a number (its glm term carries no level suffix;
# presumably a 0/1 flag). Mapped directly, ggplot would draw a continuous
# colour gradient. Converting it to a factor gives two discrete groups,
# matching the factor(SeniorCitizen) treatment in the exploratory plots.
p2 <- lr_fit2 %>%
  predict(cr_train, type = "prob") %>%
  mutate(MonthlyCharges = cr_train$MonthlyCharges,
         SeniorCitizen = factor(cr_train$SeniorCitizen)) %>%
  ggplot(aes(MonthlyCharges, .pred_Left, color = SeniorCitizen)) +
  geom_point(alpha = .2, size = 1.5) +
  scale_y_continuous("Probability of Leaving", limits = c(0, 1)) +
  ggtitle("Predicted Probabilities Regarding Monthly Charges") +
  theme_classic()
p2
# Two-predictor model: total charges plus the senior-citizen indicator.
lr_fit3 <- lr_mod %>%
  fit(Status ~ TotalCharges + SeniorCitizen, data = cr_train)

# SeniorCitizen is numeric (presumably a 0/1 flag), so it is converted to a
# factor to get two discrete colours instead of a continuous gradient.
p3 <- lr_fit3 %>%
  predict(cr_train, type = "prob") %>%
  mutate(TotalCharges = cr_train$TotalCharges,
         SeniorCitizen = factor(cr_train$SeniorCitizen)) %>%
  ggplot(aes(TotalCharges, .pred_Left, color = SeniorCitizen)) +
  geom_point(alpha = .2, size = 1.5) +
  scale_y_continuous("Probability of Leaving", limits = c(0, 1)) +
  ggtitle("Predicted Probabilities Regarding Total Charges") +
  theme_classic()
p3
# Single-predictor model on TotalCharges: fitted probability of leaving for
# each training row.
lr_fit4 <- fit(lr_mod, Status ~ TotalCharges, data = cr_train)
p4 <- predict(lr_fit4, cr_train, type = "prob") %>%
  mutate(TotalCharges = cr_train$TotalCharges) %>%
  ggplot(aes(x = TotalCharges, y = .pred_Left)) +
  geom_point(alpha = .2, color = "orange") +
  scale_y_continuous("Probability of Leaving", limits = c(0, 1)) +
  ggtitle("Predicted Probabilities Regarding Total Charges") +
  theme_classic()
p4

# MonthlyCharges + Tenure model; points are shaded by tenure (continuous).
lr_fit5 <- fit(lr_mod, Status ~ MonthlyCharges + Tenure, data = cr_train)
p5 <- predict(lr_fit5, cr_train, type = "prob") %>%
  mutate(MonthlyCharges = cr_train$MonthlyCharges,
         Tenure = cr_train$Tenure) %>%
  ggplot(aes(x = MonthlyCharges, y = .pred_Left, color = Tenure)) +
  geom_point(alpha = .2, size = 1.5) +
  scale_y_continuous("Probability of Leaving", limits = c(0, 1)) +
  ggtitle("Predicted Probabilities Regarding Monthly Charges") +
  theme_classic()
p5
# Plot a fitted model's probability of leaving against MonthlyCharges,
# coloured by one categorical column.
#   fit   - parsnip logistic-regression fit using MonthlyCharges and `group`
#   data  - rows to predict on (here the training split)
#   group - name of the colouring column, as a string
# Returns the ggplot object.
plot_left_prob_by <- function(fit, data, group) {
  fit %>%
    predict(data, type = "prob") %>%
    bind_cols(select(data, MonthlyCharges, all_of(group))) %>%
    ggplot(aes(MonthlyCharges, .pred_Left, color = .data[[group]])) +
    geom_point(alpha = .2, size = .8) +
    scale_y_continuous("Probability of Leaving", limits = c(0, 1)) +
    labs(color = group) +
    ggtitle("Predicted Probabilities Regarding Monthly Charges") +
    theme_classic()
}

# MonthlyCharges plus one demographic column at a time.
lr_fit6 <- lr_mod %>%
  fit(Status ~ MonthlyCharges + Gender, data = cr_train)
p6 <- plot_left_prob_by(lr_fit6, cr_train, "Gender")
p6

lr_fit7 <- lr_mod %>%
  fit(Status ~ MonthlyCharges + Partner, data = cr_train)
p7 <- plot_left_prob_by(lr_fit7, cr_train, "Partner")
p7

lr_fit8 <- lr_mod %>%
  fit(Status ~ MonthlyCharges + Dependents, data = cr_train)
p8 <- plot_left_prob_by(lr_fit8, cr_train, "Dependents")
p8
# Charges-plus-demographics model, scored by accuracy on the held-out test
# split.
lr_fit9 <- fit(
  lr_mod,
  Status ~ MonthlyCharges + Dependents + Partner + SeniorCitizen + Gender,
  data = cr_train
)
predict(lr_fit9, cr_test) %>%
  bind_cols(select(cr_test, Status)) %>%
  accuracy(truth = Status, estimate = .pred_class)
## # A tibble: 1 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy binary 0.727
# All-predictor logistic regression, scored by accuracy on the test split.
# This is the same model already fit above as `final_fit`: same spec, formula
# and training data, and glm fitting is deterministic. It is therefore reused
# rather than refit. The odds ratios below are correspondingly identical.
lr_fit10 <- final_fit
lr_fit10 %>%
  predict(cr_test) %>%
  bind_cols(cr_test %>% select(Status)) %>%
  accuracy(truth = Status, estimate = .pred_class)
## # A tibble: 1 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy binary 0.808
exp(coef(lr_fit10$fit))
## (Intercept) GenderMale
## 2.7593894 0.9549166
## SeniorCitizen PartnerYes
## 1.2949103 0.8688686
## DependentsYes Tenure
## 0.9536487 0.9374419
## PhoneServiceYes MultipleLinesNo phone service
## 0.9591253 NA
## MultipleLinesYes InternetServiceFiber optic
## 1.5562922 4.4143178
## InternetServiceNo OnlineSecurityNo internet service
## 0.2214383 NA
## OnlineSecurityYes OnlineBackupNo internet service
## 0.7591602 NA
## OnlineBackupYes DeviceProtectionNo internet service
## 0.9028811 NA
## DeviceProtectionYes TechSupportNo internet service
## 1.1801326 NA
## TechSupportYes StreamingTVNo internet service
## 0.8583097 NA
## StreamingTVYes StreamingMoviesNo internet service
## 1.6073481 NA
## StreamingMoviesYes ContractOne year
## 1.7542279 0.4452648
## ContractTwo year PaperlessBillingYes
## 0.2497494 1.3698551
## PaymentMethodCredit card (automatic) PaymentMethodElectronic check
## 1.0223826 1.3813743
## PaymentMethodMailed check MonthlyCharges
## 0.9410159 0.9689548
## TotalCharges
## 1.0003906
# Preprocessing recipe with transformed numeric predictors.
# Yeo-Johnson runs first to reduce skew, and normalization runs second, so the
# final numeric predictors really are centred and scaled. With the previous
# order (normalize, then Yeo-Johnson), the transform undid the standardization.
# NOTE(review): `cr_std` is not referenced anywhere else in the visible script.
cr_recipe <- recipe(Status ~ ., data = cr_train)
cr_std <- cr_recipe %>%
  step_YeoJohnson(all_numeric_predictors()) %>%
  step_normalize(all_numeric_predictors())
Decision Tree
# Train/test split for the decision-tree models. It uses the same seed and
# arguments as `cr_split`, so it selects the same rows.
set.seed(123)
churn_split <- initial_split(cr, prop = 0.7, strata = "Status")
churn_train <- training(churn_split)
churn_test  <- testing(churn_split)
# Baseline classification tree with rpart's default settings.
# It is fit on the training split only, so that churn_test stays held out;
# previously the split was created just above but the tree was fit on all of
# `cr`. The recipe template only fixes column roles, so building it from
# churn_train leaves the predictors unchanged.
dt_mod <- decision_tree(mode = "classification") %>%
  set_engine("rpart")
model_recipe <- recipe(Status ~ ., data = churn_train)
dt_fit <- workflow() %>%
  add_recipe(model_recipe) %>%
  add_model(dt_mod) %>%
  fit(data = churn_train)

# extract_fit_engine() returns the underlying rpart object without depending
# on the workflow's internal list layout (the old $fit$fit$fit chain).
rpart.plot::rpart.plot(extract_fit_engine(dt_fit))
# 5-fold cross-validation of the baseline tree.
# The folds are drawn from the training split only, so test rows never enter
# model assessment. They are stratified on Status, like the random-forest
# folds, to keep the class balance in every fold.
set.seed(123)
kfold <- vfold_cv(churn_train, v = 5, strata = Status)
dt_results <- fit_resamples(dt_mod, model_recipe, kfold)
collect_metrics(dt_results)
# NOTE: the output below came from folds drawn over all of `cr`; re-knit to
# refresh it.
## # A tibble: 2 × 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 accuracy binary 0.787 5 0.00447 Preprocessor1_Model1
## 2 roc_auc binary 0.802 5 0.00532 Preprocessor1_Model1
# Tunable tree: cost complexity, maximum depth and minimum node size are all
# searched over a regular 5 x 5 x 5 grid (125 candidates) with the folds in
# `kfold`.
dt_mod <- decision_tree(
  cost_complexity = tune(),
  tree_depth = tune(),
  min_n = tune(),
  mode = "classification"
) %>%
  set_engine("rpart")

dt_hyper_grid <- grid_regular(cost_complexity(), tree_depth(), min_n(),
                              levels = 5)

set.seed(123)
dt_results <- tune_grid(dt_mod, model_recipe,
                        resamples = kfold, grid = dt_hyper_grid)

# Five best candidates by mean cross-validated ROC AUC.
show_best(dt_results, metric = "roc_auc", n = 5)
## # A tibble: 5 × 9
## cost_complexity tree_depth min_n .metric .estima…¹ mean n std_err .config
## <dbl> <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 0.0000000001 8 21 roc_auc binary 0.820 5 0.00698 Prepro…
## 2 0.0000000178 8 21 roc_auc binary 0.820 5 0.00698 Prepro…
## 3 0.00000316 8 21 roc_auc binary 0.820 5 0.00698 Prepro…
## 4 0.0000000001 8 30 roc_auc binary 0.819 5 0.00741 Prepro…
## 5 0.0000000178 8 30 roc_auc binary 0.819 5 0.00741 Prepro…
## # … with abbreviated variable name ¹.estimator
# Lock in the hyperparameters with the best cross-validated ROC AUC, refit on
# the training split (not all of `cr`, so churn_test stays unseen), and plot
# the top 20 variable importances.
dt_best_model <- select_best(dt_results, metric = "roc_auc")
dt_final_wf <- workflow() %>%
  add_recipe(model_recipe) %>%
  add_model(dt_mod) %>%
  finalize_workflow(dt_best_model)
dt_final_fit <- dt_final_wf %>%
  fit(data = churn_train)
dt_final_fit %>%
  extract_fit_parsnip() %>%
  vip(20)
# Bagged rpart trees (5 bootstrap members per fit), assessed on the same folds.
# Bagging draws random bootstrap samples, so the seed is set immediately
# before resampling to make the metrics reproducible.
bag_mod <- bag_tree() %>%
  set_engine("rpart", times = 5) %>%
  set_mode("classification")
set.seed(123)
bag_results <- fit_resamples(bag_mod, model_recipe, kfold)
collect_metrics(bag_results)
# NOTE: the output below is from the earlier unseeded run; re-knit to
# refresh it.
## # A tibble: 2 × 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 accuracy binary 0.759 5 0.00551 Preprocessor1_Model1
## 2 roc_auc binary 0.759 5 0.00597 Preprocessor1_Model1
# Training-only folds for the second round of tree tuning.
# They are seeded, so the fold assignment no longer depends on how much
# randomness the preceding chunks consumed, and stratified on Status for
# consistent class balance across folds.
set.seed(123)
kfold <- vfold_cv(churn_train, v = 5, strata = Status)

# Full metric table from the *previous* tuning run; `dt_results` is
# reassigned by the tune_grid() call that follows.
collect_metrics(dt_results)
## # A tibble: 250 × 9
## cost_complexity tree_depth min_n .metric .esti…¹ mean n std_err .config
## <dbl> <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 0.0000000001 1 2 accuracy binary 0.734 5 0.00564 Prepro…
## 2 0.0000000001 1 2 roc_auc binary 0.5 5 0 Prepro…
## 3 0.0000000178 1 2 accuracy binary 0.734 5 0.00564 Prepro…
## 4 0.0000000178 1 2 roc_auc binary 0.5 5 0 Prepro…
## 5 0.00000316 1 2 accuracy binary 0.734 5 0.00564 Prepro…
## 6 0.00000316 1 2 roc_auc binary 0.5 5 0 Prepro…
## 7 0.000562 1 2 accuracy binary 0.734 5 0.00564 Prepro…
## 8 0.000562 1 2 roc_auc binary 0.5 5 0 Prepro…
## 9 0.1 1 2 accuracy binary 0.734 5 0.00564 Prepro…
## 10 0.1 1 2 roc_auc binary 0.5 5 0 Prepro…
## # … with 240 more rows, and abbreviated variable name ¹.estimator
# Second tuning pass over the same 5 x 5 x 5 grid, now on the training-only
# folds in `kfold`.
dt_mod <- decision_tree(
  min_n = tune(),
  tree_depth = tune(),
  cost_complexity = tune(),
  mode = "classification"
) %>%
  set_engine("rpart")

dt_hyper_grid <- grid_regular(cost_complexity(), tree_depth(), min_n(),
                              levels = 5)

set.seed(123)
dt_results <- tune_grid(dt_mod, model_recipe,
                        resamples = kfold, grid = dt_hyper_grid)
show_best(dt_results, metric = "roc_auc", n = 5)
## # A tibble: 5 × 9
## cost_complexity tree_depth min_n .metric .estima…¹ mean n std_err .config
## <dbl> <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 0.0000000001 8 30 roc_auc binary 0.823 5 0.00705 Prepro…
## 2 0.0000000178 8 30 roc_auc binary 0.823 5 0.00705 Prepro…
## 3 0.00000316 8 30 roc_auc binary 0.823 5 0.00705 Prepro…
## 4 0.0000000001 8 40 roc_auc binary 0.821 5 0.00764 Prepro…
## 5 0.0000000178 8 40 roc_auc binary 0.821 5 0.00764 Prepro…
## # … with abbreviated variable name ¹.estimator

# Finalize with the best ROC AUC candidate and fit on the training split.
# NOTE(review): this reassigns `final_fit`, which earlier held the logistic
# regression fit.
best_model <- select_best(dt_results, metric = "roc_auc")
final_wf <- workflow() %>%
  add_recipe(model_recipe) %>%
  add_model(dt_mod) %>%
  finalize_workflow(best_model)
final_fit <- fit(final_wf, data = churn_train)

# Top 20 variable importances of the tuned tree.
final_fit %>%
  extract_fit_parsnip() %>%
  vip(20)
Random Forest
# Split for the random forest. It uses the same seed and arguments as
# `cr_split`, so it selects the same training and test rows, under the names
# used by this section.
set.seed(123)
split <- initial_split(cr, prop = 0.7, strata = "Status")
train <- training(split)
test  <- testing(split)

# Outcome Status modelled on every other column.
model_recipe <- recipe(Status ~ ., data = train)
# Random forest (ranger engine, impurity importance kept for the vip plot),
# with number of trees, mtry and min_n tuned over a regular 5 x 5 x 5 grid
# using stratified 5-fold CV on the training split.
rf_mod <- rand_forest(
  mode = "classification",
  trees = tune(),
  mtry = tune(),
  min_n = tune()
) %>%
  set_engine("ranger", importance = "impurity")

# mtry is the number of predictors sampled at each split, so it cannot exceed
# the number of predictors: every column except Status, per `Status ~ .`. 19
# predictor variables appear in the logistic-regression coefficients. The old
# upper bound of 50 put three of the five mtry levels above that count, so
# those candidates could only be capped back down, wasting tuning runs.
n_predictors <- ncol(train) - 1L
rf_hyper_grid <- grid_regular(
  trees(range = c(50, 200)),
  mtry(range = c(4, n_predictors)),
  min_n(range = c(4, 20)),
  levels = 5
)

set.seed(123)
kfold <- vfold_cv(train, v = 5, strata = Status)
rf_results <- tune_grid(rf_mod, model_recipe, resamples = kfold,
                        grid = rf_hyper_grid)
show_best(rf_results, metric = "roc_auc")
# NOTE: the output below came from the previous mtry grid (4 to 50); re-knit
# to refresh it.
## # A tibble: 5 × 9
## mtry trees min_n .metric .estimator mean n std_err .config
## <int> <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 4 200 20 roc_auc binary 0.842 5 0.00474 Preprocessor1_Model1…
## 2 4 125 20 roc_auc binary 0.841 5 0.00456 Preprocessor1_Model1…
## 3 4 125 16 roc_auc binary 0.841 5 0.00449 Preprocessor1_Model0…
## 4 4 162 20 roc_auc binary 0.841 5 0.00479 Preprocessor1_Model1…
## 5 4 87 20 roc_auc binary 0.841 5 0.00482 Preprocessor1_Model1…
# Finalize the forest with the best cross-validated ROC AUC settings and fit
# it on the full training split.
rf_best_hyperparameters <- select_best(rf_results, metric = "roc_auc")
final_rf_wf <- workflow() %>%
  add_recipe(model_recipe) %>%
  add_model(rf_mod) %>%
  finalize_workflow(rf_best_hyperparameters)
rf_final_fit <- fit(final_rf_wf, data = train)

# Confusion matrix of class predictions on the held-out test split.
predict(rf_final_fit, test) %>%
  bind_cols(select(test, Status)) %>%
  conf_mat(truth = Status, estimate = .pred_class)
## Truth
## Prediction Current Left
## Current 1387 266
## Left 153 291

# Ten most important predictors by ranger's impurity importance.
rf_final_fit %>%
  extract_fit_parsnip() %>%
  vip(num_features = 10)