This piece of code uses the UC Irvine Bank Marketing dataset (Moro, S., Rita, P., and Cortez, P. (2012). Bank Marketing. UCI Machine Learning Repository. https://doi.org/10.24432/C5K306).
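The code below assumes bank-full.csv already sits in the working directory. If it does not, the file can be fetched from the UCI archive first; here is a minimal sketch (the zip URL is an assumption, so check it against the dataset page linked above):
# Sketch: download and extract bank-full.csv if it is missing.
# The archive URL is assumed; verify it on the UCI dataset page.
if (!file.exists("bank-full.csv")) {
  download.file("https://archive.ics.uci.edu/ml/machine-learning-databases/00222/bank.zip",
                destfile = "bank.zip", mode = "wb")
  unzip("bank.zip", files = "bank-full.csv")
}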
library(readr)
library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.1.1 ──
## ✔ broom 1.0.5 ✔ recipes 1.0.8
## ✔ dials 1.2.0 ✔ rsample 1.2.0
## ✔ dplyr 1.1.2 ✔ tibble 3.2.1
## ✔ ggplot2 3.4.3 ✔ tidyr 1.3.0.9000
## ✔ infer 1.0.4 ✔ tune 1.1.2
## ✔ modeldata 1.2.0 ✔ workflows 1.1.3
## ✔ parsnip 1.1.1 ✔ workflowsets 1.0.1
## ✔ purrr 1.0.2 ✔ yardstick 1.2.0
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ purrr::discard() masks scales::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step() masks stats::step()
## • Search for functions across packages at https://www.tidymodels.org/find/
install.packages("tidyverse")
## Installing tidyverse [2.0.0] ...
## OK [linked cache in 0.21 milliseconds]
## * Installed 1 package in 1.9 seconds.
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ forcats 1.0.0 ✔ stringr 1.5.0
## ✔ lubridate 1.9.2
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ scales::col_factor() masks readr::col_factor()
## ✖ purrr::discard() masks scales::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ stringr::fixed() masks recipes::fixed()
## ✖ dplyr::lag() masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
# Read the dataset and convert the target variable to a factor
bank_df <- read_csv2("bank-full.csv")
## ℹ Using "','" as decimal and "'.'" as grouping mark. Use `read_delim()` for more control.
## Rows: 45211 Columns: 17
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ";"
## chr (10): job, marital, education, default, housing, loan, contact, month, p...
## dbl (7): age, balance, day, duration, campaign, pdays, previous
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
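As the message above suggests, the full column specification can be retrieved with spec(). Note that yardstick::spec() masks readr::spec() (see the conflicts listed earlier), so the readr version has to be called explicitly; a quick sketch:
# Sketch: inspect the column types readr guessed for the file
readr::spec(bank_df)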
bank_df$y <- as.factor(bank_df$y)
# Split data into train and test
set.seed(421)
split <- initial_split(bank_df, prop = 0.8, strata = y)
train <- split %>%
  training()
test <- split %>%
  testing()
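Because the "yes" class is much rarer than "no", the split is stratified on y so both partitions keep a similar class balance. A quick sanity check, sketched here with dplyr:
# Sketch: class proportions in the training and test sets should be close
train %>% count(y) %>% mutate(prop = n / sum(n))
test %>% count(y) %>% mutate(prop = n / sum(n))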
# Train a logistic regression model
# double(1) evaluates to 0, so mixture = 0 and penalty = 0 here, i.e. the
# coefficients are reported with no regularization applied
model <- logistic_reg(mixture = double(1), penalty = double(1)) %>%
  set_engine("glmnet") %>%
  set_mode("classification") %>%
  fit(y ~ ., data = train)
# Model summary
tidy(model)
## Loading required package: Matrix
##
## Attaching package: 'Matrix'
##
## The following objects are masked from 'package:tidyr':
##
## expand, pack, unpack
##
## Loaded glmnet 4.1-8
## # A tibble: 43 × 3
## term estimate penalty
## <chr> <dbl> <dbl>
## 1 (Intercept) -2.59 0
## 2 age -0.000477 0
## 3 jobblue-collar -0.183 0
## 4 jobentrepreneur -0.206 0
## 5 jobhousemaid -0.270 0
## 6 jobmanagement -0.0190 0
## 7 jobretired 0.360 0
## 8 jobself-employed -0.101 0
## 9 jobservices -0.105 0
## 10 jobstudent 0.415 0
## # ℹ 33 more rows
# Class Predictions
pred_class <- predict(model,
                      new_data = test,
                      type = "class")
# Class Probabilities
pred_proba <- predict(model,
                      new_data = test,
                      type = "prob")
results <- test %>%
  select(y) %>%
  bind_cols(pred_class, pred_proba)
print(results, n = 20)
## # A tibble: 9,043 × 4
## y .pred_class .pred_no .pred_yes
## <fct> <fct> <dbl> <dbl>
## 1 no no 0.975 0.0253
## 2 no no 0.979 0.0209
## 3 no no 0.987 0.0135
## 4 no no 0.987 0.0132
## 5 no no 0.975 0.0252
## 6 no no 0.975 0.0255
## 7 no no 0.853 0.147
## 8 no no 0.986 0.0138
## 9 no no 0.982 0.0184
## 10 no no 0.948 0.0520
## 11 yes yes 0.266 0.734
## 12 yes yes 0.464 0.536
## 13 no no 0.982 0.0178
## 14 no no 0.982 0.0181
## 15 no no 0.973 0.0269
## 16 no no 0.982 0.0183
## 17 no no 0.992 0.00822
## 18 no no 0.988 0.0121
## 19 no no 0.975 0.0251
## 20 no no 0.990 0.0101
## # ℹ 9,023 more rows
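Since results also carries the class probabilities, a threshold-free metric such as the ROC AUC can be added. A small sketch with yardstick (with the default event level, the first factor level "no" is treated as the event, so the .pred_no column is supplied):
# Sketch: area under the ROC curve for the test set
roc_auc(results, truth = y, .pred_no)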
# Create confusion matrix
conf_mat(results, truth = y,
         estimate = .pred_class)
## Truth
## Prediction no yes
## no 7838 738
## yes 147 320
accuracy(results, truth = y, estimate = .pred_class)
## # A tibble: 1 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy binary 0.902
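An accuracy of 0.902 should be read against the class imbalance: the confusion matrix above shows that a large share of the "yes" cases are misclassified. Here is a sketch of a broader metric set (with the default event level, sensitivity refers to the "no" class and specificity to "yes"):
# Sketch: report accuracy, sensitivity and specificity together
class_metrics <- metric_set(accuracy, sensitivity, specificity)
class_metrics(results, truth = y, estimate = .pred_class)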
The plot looks like this:
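One common way to produce such a plot is autoplot() on the yardstick confusion matrix; it is an assumption that this is the exact plot referred to above:
# Sketch: heatmap view of the confusion matrix (assumed to match the plot above)
conf_mat(results, truth = y, estimate = .pred_class) %>%
  autoplot(type = "heatmap")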