Tidymodels Classification

Harold Nelson

2024-04-08

Setup

Load tidyverse and tidymodels. Load the dataframe cdc2.

Solution

library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.4     ✔ readr     2.1.5
## ✔ forcats   1.0.0     ✔ stringr   1.5.1
## ✔ ggplot2   3.5.0     ✔ tibble    3.2.1
## ✔ lubridate 1.9.3     ✔ tidyr     1.3.1
## ✔ purrr     1.0.2     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
## ✔ broom        1.0.5      ✔ rsample      1.2.1 
## ✔ dials        1.2.1      ✔ tune         1.2.0 
## ✔ infer        1.0.7      ✔ workflows    1.1.4 
## ✔ modeldata    1.3.0      ✔ workflowsets 1.1.0 
## ✔ parsnip      1.2.1      ✔ yardstick    1.3.1 
## ✔ recipes      1.0.10     
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter()   masks stats::filter()
## ✖ recipes::fixed()  masks stringr::fixed()
## ✖ dplyr::lag()      masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step()   masks stats::step()
## • Use tidymodels_prefer() to resolve common conflicts.
load("cdc2.Rdata")

Split the data

Put 70% of the data in training. Set the strata to gender. Create cdc2_training and cdc2_test.

Solution

set.seed(123)
cdc2_split <- initial_split(cdc2, 
                  prop = .7, 
                  strata = gender)

cdc2_training <- cdc2_split %>%
 training()

cdc2_test <- cdc2_split %>%
 testing()

Verify the Split

Make sure that there are about 7/3 times as many observations in train as in test.

Examine the distributions of gender in the training and test dataframes.

Solution

nrow(cdc2_training)/nrow(cdc2_test)
## [1] 2.332833
7/3
## [1] 2.333333
table(cdc2_training$gender)/nrow(cdc2_training)
## 
##         m         f 
## 0.4783882 0.5216118
table(cdc2_test$gender)/nrow(cdc2_test)
## 
##         m         f 
## 0.4783333 0.5216667

Create A Logistic Regression Model

Copy the code from the DataCamp course.

Solution

logistic_model <- logistic_reg() %>% 
  # Set the engine
  set_engine('glm') %>% 
  # Set the mode
  set_mode('classification')
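
To see how parsnip will translate this specification into an engine call to stats::glm(), translate() can be used. This is an optional check, not part of the original exercise, and the output is omitted here.

# Optional: show the underlying glm() call the specification maps to
logistic_model %>% 
  translate()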

Fit to training data

Use the logistic regression model to predict gender based on weight and smoke100. Print the model.

Solution

logistic_fit <- logistic_model %>% 
  fit(gender ~ weight + smoke100,
      data = cdc2_training)

# Print model fit object
logistic_fit
## parsnip model object
## 
## 
## Call:  stats::glm(formula = gender ~ weight + smoke100, family = stats::binomial, 
##     data = data)
## 
## Coefficients:
## (Intercept)       weight     smoke100  
##     5.74932     -0.03252     -0.41336  
## 
## Degrees of Freedom: 13996 Total (i.e. Null);  13994 Residual
## Null Deviance:       19380 
## Residual Deviance: 15630     AIC: 15640
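
For a tidier view of the estimated coefficients (with standard errors and p-values), broom's tidy() can also be applied to the parsnip fit. This is an optional extra step, and the output is omitted here.

# Optional: coefficient table as a tibble
tidy(logistic_fit)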

Create Predictions - Class

Make class_preds for the test data.

Solution

class_preds <- predict(logistic_fit, new_data = cdc2_test,
                       type = "class")

Create Predictions - Probability

Make prob_preds (the predicted class probabilities) for the test data.

Solution

prob_preds <- predict(logistic_fit, new_data = cdc2_test,
                       type = "prob")

Combine

Create a dataframe with the actual values and the predicted class and probability values. Display the first few rows of the dataframe.

Solution

# Combine test set results
cdc2_results <- cdc2_test %>% 
  select(gender) %>% 
  bind_cols(class_preds, prob_preds)

head(cdc2_results)
##    gender .pred_class   .pred_m   .pred_f
## 4       f           f 0.1889616 0.8110384
## 5       f           f 0.2949580 0.7050420
## 9       f           f 0.3874440 0.6125560
## 14      m           m 0.5479315 0.4520685
## 18      m           m 0.6637623 0.3362377
## 25      f           f 0.1277014 0.8722986

Confusion?

Create the confusion matrix for our model.

Solution

conf_mat(cdc2_results, truth = gender,
         estimate = .pred_class)
##           Truth
## Prediction    m    f
##          m 1981  742
##          f  889 2388

Accuracy

Calculate the accuracy.

accuracy(cdc2_results, truth = gender,
         estimate = .pred_class)
## # A tibble: 1 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.728
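
As a sanity check, the accuracy can be recomputed by hand from the confusion matrix above: it is the number of correct predictions divided by the total number of test observations.

# Hand check against the confusion matrix:
# (1981 + 2388) / (1981 + 742 + 889 + 2388) = 4369 / 6000, about 0.728
(1981 + 2388) / (1981 + 742 + 889 + 2388)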

Calculate the sensitivity.

Solution

sens(cdc2_results, truth = gender,
     estimate = .pred_class)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 sens    binary         0.690
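
This too can be checked by hand: with “m” as the event, sensitivity is the share of true m’s that were predicted to be m.

# Hand check: true m's predicted as m, divided by all true m's
# 1981 / (1981 + 889) = 1981 / 2870, about 0.690
1981 / (1981 + 889)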

Calculate the specificity.

Solution

spec(cdc2_results, truth = gender,
     estimate = .pred_class)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 spec    binary         0.763
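
And by hand once more: specificity is the share of true f’s that were predicted to be f.

# Hand check: true f's predicted as f, divided by all true f's
# 2388 / (742 + 2388) = 2388 / 3130, about 0.763
2388 / (742 + 2388)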

Create a custom metric function

Include accuracy, sensitivity, and specificity. Display the results.

Solution

cdc2_metrics <- metric_set(accuracy, sens, spec)

cdc2_metrics(cdc2_results, truth = gender,
                estimate = .pred_class)
## # A tibble: 3 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.728
## 2 sens     binary         0.690
## 3 spec     binary         0.763

Confusion Matrix

Create the confusion matrix and use summary() to display the results.

Solution

conf_mat(cdc2_results,
         truth = gender,
         estimate = .pred_class) %>% 
  # Pass to the summary() function
  summary()
## # A tibble: 13 × 3
##    .metric              .estimator .estimate
##    <chr>                <chr>          <dbl>
##  1 accuracy             binary         0.728
##  2 kap                  binary         0.454
##  3 sens                 binary         0.690
##  4 spec                 binary         0.763
##  5 ppv                  binary         0.728
##  6 npv                  binary         0.729
##  7 mcc                  binary         0.455
##  8 j_index              binary         0.453
##  9 bal_accuracy         binary         0.727
## 10 detection_prevalence binary         0.454
## 11 precision            binary         0.728
## 12 recall               binary         0.690
## 13 f_meas               binary         0.708
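
Two of these can be verified from the earlier results: bal_accuracy is the mean of sensitivity and specificity, and f_meas is the harmonic mean of precision and recall.

# Hand checks for two of the summary metrics
(0.690 + 0.763) / 2                    # bal_accuracy, about 0.727
2 * 0.728 * 0.690 / (0.728 + 0.690)    # f_meas, about 0.708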

Visualization

Create a mosaic plot to visualize the confusion matrix.

# Create a confusion matrix
conf_mat(cdc2_results,
         truth = gender,
         estimate = .pred_class) %>% 
  # Create a mosaic plot
  autoplot(type = 'mosaic')
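
The same confusion matrix object also supports a heatmap display via autoplot(type = 'heatmap'), if a mosaic plot is not what you want. This is an optional alternative, and the plot is not shown here.

# Optional alternative: heatmap view of the confusion matrix
conf_mat(cdc2_results,
         truth = gender,
         estimate = .pred_class) %>% 
  autoplot(type = 'heatmap')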

ROC

Create an ROC plot for our model.

Note that in our case, the “positive” value of gender is “m” because “m” is the first level of the gender factor, and yardstick treats the first factor level as the event (the positive class) by default.
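
A quick way to confirm which level will be treated as the event is to inspect the factor levels; under yardstick's default event_level = "first", the first level is the positive class. This is an optional check, output omitted.

# The first factor level is treated as the event by default
levels(cdc2_results$gender)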

# Calculate metrics across thresholds
threshold_df <- cdc2_results %>% 
  roc_curve(truth = gender, .pred_m)

# View results
head(threshold_df)
## # A tibble: 6 × 3
##   .threshold specificity sensitivity
##        <dbl>       <dbl>       <dbl>
## 1  -Inf         0               1   
## 2     0.0466    0               1   
## 3     0.0528    0.000319        1   
## 4     0.0574    0.000639        1   
## 5     0.0591    0.000639        1.00
## 6     0.0597    0.000958        1.00
# Plot ROC curve
threshold_df %>% 
  autoplot()
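
The area under this ROC curve can also be computed directly with roc_auc() on the same results; its value should agree with the roc_auc reported by last_fit() in the next section. Output is omitted here.

# Area under the ROC curve for the test-set predictions
roc_auc(cdc2_results, truth = gender, .pred_m)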

Streamlining

Train the model with last_fit() and then look at the metrics.

Solution

# Train model with last_fit()
cdc2_last_fit <- logistic_model %>% 
  last_fit(gender ~ weight + smoke100,
           split = cdc2_split)

# View test set metrics
cdc2_last_fit %>% 
  collect_metrics()
## # A tibble: 3 × 4
##   .metric     .estimator .estimate .config             
##   <chr>       <chr>          <dbl> <chr>               
## 1 accuracy    binary         0.728 Preprocessor1_Model1
## 2 roc_auc     binary         0.801 Preprocessor1_Model1
## 3 brier_class binary         0.186 Preprocessor1_Model1
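
Note that last_fit() now also reports the Brier score (brier_class) alongside accuracy and ROC AUC. Assuming it follows the usual binary probability-metric interface in yardstick (the event-level probability column passed after truth), it can be computed directly on the earlier results as well; output omitted.

# Brier score from the class probabilities (assumed binary interface)
brier_class(cdc2_results, truth = gender, .pred_m)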

Results

Get the results from last_fit() and examine them with head().

Solution

# Collect predictions
last_fit_results <- cdc2_last_fit %>% 
  collect_predictions()

# View results
head(last_fit_results)
## # A tibble: 6 × 7
##   .pred_class .pred_m .pred_f id                .row gender .config             
##   <fct>         <dbl>   <dbl> <chr>            <int> <fct>  <chr>               
## 1 f             0.189   0.811 train/test split     4 f      Preprocessor1_Model1
## 2 f             0.295   0.705 train/test split     5 f      Preprocessor1_Model1
## 3 f             0.387   0.613 train/test split     9 f      Preprocessor1_Model1
## 4 m             0.548   0.452 train/test split    14 m      Preprocessor1_Model1
## 5 m             0.664   0.336 train/test split    18 m      Preprocessor1_Model1
## 6 f             0.128   0.872 train/test split    25 f      Preprocessor1_Model1

Metrics

Create a metric set and use it to examine the last_fit results.

Remember that “m” is the positive value.

# Custom metrics function
last_fit_metrics <- metric_set(accuracy, sens,
                               spec, roc_auc)

# Calculate metrics
last_fit_metrics(last_fit_results,
                 truth = gender,
                 estimate = .pred_class,
                 .pred_m)
## # A tibble: 4 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.728
## 2 sens     binary         0.690
## 3 spec     binary         0.763
## 4 roc_auc  binary         0.801

A Function

Define a function to accept the formula as a parameter and return the basic metrics. Run it with the formula gender ~ weight + height.

Solution

# Fit the logistic model with last_fit() for a given formula
# and return the test-set metrics
perform_logistic_regression <- function(formula) {
  model <- logistic_model %>% 
    last_fit(formula, split = cdc2_split)
  
  metrics <- model %>% 
    collect_metrics()
  
  return(metrics)
}

perform_logistic_regression(gender ~ weight + height)
## # A tibble: 3 × 4
##   .metric     .estimator .estimate .config             
##   <chr>       <chr>          <dbl> <chr>               
## 1 accuracy    binary         0.845 Preprocessor1_Model1
## 2 roc_auc     binary         0.927 Preprocessor1_Model1
## 3 brier_class binary         0.107 Preprocessor1_Model1
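
The same function works with any formula; for example, adding smoke100 to the predictors is a one-line call (illustrative only, output not shown).

# Illustrative reuse with a different set of predictors
perform_logistic_regression(gender ~ weight + height + smoke100)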