library(dplyr)
## Warning: package 'dplyr' was built under R version 4.2.3
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
library(ggplot2)
## Warning: package 'ggplot2' was built under R version 4.2.3
library(rsample)
library(tidyverse)
## Warning: package 'tidyverse' was built under R version 4.2.3
## Warning: package 'tibble' was built under R version 4.2.3
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ forcats 1.0.0 ✔ stringr 1.5.0
## ✔ lubridate 1.9.2 ✔ tibble 3.2.1
## ✔ purrr 1.0.1 ✔ tidyr 1.3.0
## ✔ readr 2.1.4
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(caret)
## Warning: package 'caret' was built under R version 4.2.3
## Loading required package: lattice
##
## Attaching package: 'caret'
##
## The following object is masked from 'package:purrr':
##
## lift
#Model interpretability packages
library(vip)
## Warning: package 'vip' was built under R version 4.2.3
##
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
##
## vi
library(ISLR)
## Warning: package 'ISLR' was built under R version 4.2.3
# Take the `Default` data set as the working copy and preview its first rows.
data <- Default
head(data, n = 6L)
## default student balance income
## 1 No No 729.5265 44361.625
## 2 No Yes 817.1804 12106.135
## 3 No No 1073.5492 31767.139
## 4 No No 529.2506 35704.494
## 5 No No 785.6559 38463.496
## 6 No Yes 919.5885 7491.559
# Convert ordered factor columns into plain (unordered) factors.
# across() + where() replaces the superseded mutate_if(); columns that are
# not ordered factors are left untouched.
df <- data %>%
  mutate(across(where(is.ordered), \(col) factor(col, ordered = FALSE)))
head(df)
## default student balance income
## 1 No No 729.5265 44361.625
## 2 No Yes 817.1804 12106.135
## 3 No No 1073.5492 31767.139
## 4 No No 529.2506 35704.494
## 5 No No 785.6559 38463.496
## 6 No Yes 919.5885 7491.559
# Create training (70%) and test (30%) sets, stratified on `default`.
# The split is taken from `df` (the frame with ordered factors converted)
# instead of the raw `data`, so the conversion step is actually used.
# The seed is fixed so the partition is reproducible.
set.seed(123)
churn_split <- initial_split(df, prop = 0.7, strata = "default")
churn_train <- training(churn_split)
churn_test <- testing(churn_split)
# Simple logistic regression ----
# Three binomial GLMs fitted on the training split:
#   model1 - student status only
#   model2 - balance only
#   model3 - every other column in churn_train
model1 <- glm(default ~ student, data = churn_train, family = "binomial")
model2 <- glm(default ~ balance, data = churn_train, family = "binomial")
model3 <- glm(default ~ ., data = churn_train, family = "binomial")
# Exponentiate the fitted coefficients of each model so they read as
# multiplicative effects on the odds rather than log-odds.
model1 %>% coef() %>% exp()
## (Intercept) studentYes
## 0.02991275 1.46985060
model2 %>% coef() %>% exp()
## (Intercept) balance
## 2.357365e-05 1.005502e+00
model3 %>% coef() %>% exp()
## (Intercept) studentYes balance income
## 2.149248e-05 4.964063e-01 1.005730e+00 1.000000e+00
# 95% confidence intervals for each model's coefficients. These are NOT
# exponentiated, so they are on the log-odds scale, unlike the odds-scale
# point estimates printed by exp(coef(...)).
model1 %>% confint()
## Waiting for profiling to be done...
## 2.5 % 97.5 %
## (Intercept) -3.6796696 -3.3479134
## studentYes 0.1096575 0.6548841
model2 %>% confint()
## Waiting for profiling to be done...
## 2.5 % 97.5 %
## (Intercept) -11.534993033 -9.843361105
## balance 0.004990585 0.006018433
model3 %>% confint()
## Waiting for profiling to be done...
## 2.5 % 97.5 %
## (Intercept) -1.192878e+01 -9.636720e+00
## studentYes -1.251043e+00 -1.467951e-01
## balance 5.192736e-03 6.272405e-03
## income -1.902737e-05 1.938777e-05
# Random forest using student status as the only predictor, fitted through
# caret::train(). No trControl is supplied, so caret's default resampling
# scheme is used.
# With a single predictor the only valid mtry is 1. Requesting mtry = 2 made
# randomForest emit "invalid mtry: reset to within valid range" on every fit
# and silently fall back to 1, so asking for 1 directly gives the same model
# without the warnings.
set.seed(123)
cv_model1 <- train(
  as.factor(default) ~ student,
  data = churn_train,
  method = "rf",
  tuneGrid = data.frame(mtry = 1)
)
# Random forest using balance as the only predictor, fitted through
# caret::train(). No trControl is supplied, so caret's default resampling
# scheme is used.
# With a single predictor the only valid mtry is 1. Requesting mtry = 2 made
# randomForest emit "invalid mtry: reset to within valid range" on every fit
# and silently fall back to 1, so asking for 1 directly gives the same model
# without the warnings.
set.seed(123)
cv_model2 <- train(
  as.factor(default) ~ balance,
  data = churn_train,
  method = "rf",
  tuneGrid = data.frame(mtry = 1)
)
# Random forest on all remaining columns of churn_train, fitted through
# caret::train() with a single candidate value mtry = 2 (two predictors
# tried at each split). The seed is reset so the resamples line up with the
# other cv_model* fits.
set.seed(123)
cv_model3 <- train(
  as.factor(default) ~ .,
  method = "rf",
  tuneGrid = data.frame(mtry = 2),
  data = churn_train
)
# Extract out-of-sample performance measures: collect the resampling results
# of the three caret fits and print their accuracy summary side by side.
cv_resamples <- resamples(
  list(model1 = cv_model1, model2 = cv_model2, model3 = cv_model3)
)
summary(cv_resamples)$statistics$Accuracy
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## model1 0.9590004 0.9644087 0.9680608 0.9670734 0.9694323 0.9731286 0
## model2 0.9500195 0.9549164 0.9581558 0.9572415 0.9594907 0.9638554 0
## model3 0.9632556 0.9659179 0.9678444 0.9680037 0.9700576 0.9720802 0
# Predict classes on the held-out test set. Scoring the random forest on the
# rows it was trained on reports a trivially perfect fit (accuracy = 1), so
# churn_test is used to get an honest estimate.
pred_class <- predict(cv_model3, churn_test)
# Confusion matrix with "Yes" (a default) as the first level, which
# confusionMatrix() treats as the positive class.
confusionMatrix(
  data = relevel(pred_class, ref = "Yes"),
  reference = relevel(churn_test$default, ref = "Yes")
)
## (printed output omitted: re-run to display the test-set confusion matrix)
library(ROCR)
## Warning: package 'ROCR' was built under R version 4.2.3
# Compute predicted probabilities of the "Yes" class on the held-out test
# set; ROC curves drawn on the training rows would overstate performance.
m1_prob <- predict(cv_model1, churn_test, type = "prob")$Yes
m3_prob <- predict(cv_model3, churn_test, type = "prob")$Yes
# True-positive rate vs false-positive rate for cv_model1 (perf1) and
# cv_model3 (perf2).
perf1 <- prediction(m1_prob, churn_test$default) %>%
  performance(measure = "tpr", x.measure = "fpr")
perf2 <- prediction(m3_prob, churn_test$default) %>%
  performance(measure = "tpr", x.measure = "fpr")
# Overlay both curves; line types are set explicitly to match the legend.
plot(perf1, col = "black", lty = 2)
plot(perf2, add = TRUE, col = "blue", lty = 1)
legend(0.8, 0.2, legend = c("cv_model1", "cv_model3"),
       col = c("black", "blue"), lty = c(2, 1), cex = 0.6)