library(tidyverse)
library(ISLR2) # For Dataset
library(leaps)

theme_set(theme_bw())

Explore Data

I will use Hitters dataset from {ISLR2} package.

Let’s see what we’ve got.

The goal here is to predict variable Salary.

names(Hitters)
##  [1] "AtBat"     "Hits"      "HmRun"     "Runs"      "RBI"       "Walks"    
##  [7] "Years"     "CAtBat"    "CHits"     "CHmRun"    "CRuns"     "CRBI"     
## [13] "CWalks"    "League"    "Division"  "PutOuts"   "Assists"   "Errors"   
## [19] "Salary"    "NewLeague"
dim(Hitters)
## [1] 322  20
skimr::skim(Hitters)
Data summary
Name Hitters
Number of rows 322
Number of columns 20
_______________________
Column type frequency:
factor 3
numeric 17
________________________
Group variables None

Variable type: factor

skim_variable n_missing complete_rate ordered n_unique top_counts
League 0 1 FALSE 2 A: 175, N: 147
Division 0 1 FALSE 2 W: 165, E: 157
NewLeague 0 1 FALSE 2 A: 176, N: 146

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
AtBat 0 1.00 380.93 153.40 16.0 255.25 379.5 512.00 687 ▁▇▇▇▅
Hits 0 1.00 101.02 46.45 1.0 64.00 96.0 137.00 238 ▃▇▆▃▁
HmRun 0 1.00 10.77 8.71 0.0 4.00 8.0 16.00 40 ▇▃▂▁▁
Runs 0 1.00 50.91 26.02 0.0 30.25 48.0 69.00 130 ▅▇▆▃▁
RBI 0 1.00 48.03 26.17 0.0 28.00 44.0 64.75 121 ▃▇▃▃▁
Walks 0 1.00 38.74 21.64 0.0 22.00 35.0 53.00 105 ▅▇▅▂▁
Years 0 1.00 7.44 4.93 1.0 4.00 6.0 11.00 24 ▇▆▃▂▁
CAtBat 0 1.00 2648.68 2324.21 19.0 816.75 1928.0 3924.25 14053 ▇▃▂▁▁
CHits 0 1.00 717.57 654.47 4.0 209.00 508.0 1059.25 4256 ▇▃▁▁▁
CHmRun 0 1.00 69.49 86.27 0.0 14.00 37.5 90.00 548 ▇▁▁▁▁
CRuns 0 1.00 358.80 334.11 1.0 100.25 247.0 526.25 2165 ▇▂▁▁▁
CRBI 0 1.00 330.12 333.22 0.0 88.75 220.5 426.25 1659 ▇▂▁▁▁
CWalks 0 1.00 260.24 267.06 0.0 67.25 170.5 339.25 1566 ▇▂▁▁▁
PutOuts 0 1.00 288.94 280.70 0.0 109.25 212.0 325.00 1378 ▇▃▁▁▁
Assists 0 1.00 106.91 136.85 0.0 7.00 39.5 166.00 492 ▇▂▁▁▁
Errors 0 1.00 8.04 6.37 0.0 3.00 6.0 11.00 32 ▇▅▂▁▁
Salary 59 0.82 535.93 451.12 67.5 190.00 425.0 750.00 2460 ▇▃▁▁▁

Missing Values

Let’s remove rows containing missing values.

visdat::vis_miss(Hitters)

Hitters <- Hitters %>% na.omit()

sum(is.na(Hitters$Salary))
## [1] 0

Subset Selection

Which combination of variables gives the lowest test error at each model size?

Best Subset Selection Method

regfit.full <- leaps::regsubsets(Salary ~ ., Hitters)
summary(regfit.full)
## Subset selection object
## Call: regsubsets.formula(Salary ~ ., Hitters)
## 19 Variables  (and intercept)
##            Forced in Forced out
## AtBat          FALSE      FALSE
## Hits           FALSE      FALSE
## HmRun          FALSE      FALSE
## Runs           FALSE      FALSE
## RBI            FALSE      FALSE
## Walks          FALSE      FALSE
## Years          FALSE      FALSE
## CAtBat         FALSE      FALSE
## CHits          FALSE      FALSE
## CHmRun         FALSE      FALSE
## CRuns          FALSE      FALSE
## CRBI           FALSE      FALSE
## CWalks         FALSE      FALSE
## LeagueN        FALSE      FALSE
## DivisionW      FALSE      FALSE
## PutOuts        FALSE      FALSE
## Assists        FALSE      FALSE
## Errors         FALSE      FALSE
## NewLeagueN     FALSE      FALSE
## 1 subsets of each size up to 8
## Selection Algorithm: exhaustive
##          AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits CHmRun CRuns CRBI
## 1  ( 1 ) " "   " "  " "   " "  " " " "   " "   " "    " "   " "    " "   "*" 
## 2  ( 1 ) " "   "*"  " "   " "  " " " "   " "   " "    " "   " "    " "   "*" 
## 3  ( 1 ) " "   "*"  " "   " "  " " " "   " "   " "    " "   " "    " "   "*" 
## 4  ( 1 ) " "   "*"  " "   " "  " " " "   " "   " "    " "   " "    " "   "*" 
## 5  ( 1 ) "*"   "*"  " "   " "  " " " "   " "   " "    " "   " "    " "   "*" 
## 6  ( 1 ) "*"   "*"  " "   " "  " " "*"   " "   " "    " "   " "    " "   "*" 
## 7  ( 1 ) " "   "*"  " "   " "  " " "*"   " "   "*"    "*"   "*"    " "   " " 
## 8  ( 1 ) "*"   "*"  " "   " "  " " "*"   " "   " "    " "   "*"    "*"   " " 
##          CWalks LeagueN DivisionW PutOuts Assists Errors NewLeagueN
## 1  ( 1 ) " "    " "     " "       " "     " "     " "    " "       
## 2  ( 1 ) " "    " "     " "       " "     " "     " "    " "       
## 3  ( 1 ) " "    " "     " "       "*"     " "     " "    " "       
## 4  ( 1 ) " "    " "     "*"       "*"     " "     " "    " "       
## 5  ( 1 ) " "    " "     "*"       "*"     " "     " "    " "       
## 6  ( 1 ) " "    " "     "*"       "*"     " "     " "    " "       
## 7  ( 1 ) " "    " "     "*"       "*"     " "     " "    " "       
## 8  ( 1 ) "*"    " "     "*"       "*"     " "     " "    " "

Try to include all predictors; I will increase nvmax.

regfit.full <- regsubsets(Salary ~ ., data = Hitters,
    nvmax = 19)
reg.summary <- summary(regfit.full)
reg.summary 
## Subset selection object
## Call: regsubsets.formula(Salary ~ ., data = Hitters, nvmax = 19)
## 19 Variables  (and intercept)
##            Forced in Forced out
## AtBat          FALSE      FALSE
## Hits           FALSE      FALSE
## HmRun          FALSE      FALSE
## Runs           FALSE      FALSE
## RBI            FALSE      FALSE
## Walks          FALSE      FALSE
## Years          FALSE      FALSE
## CAtBat         FALSE      FALSE
## CHits          FALSE      FALSE
## CHmRun         FALSE      FALSE
## CRuns          FALSE      FALSE
## CRBI           FALSE      FALSE
## CWalks         FALSE      FALSE
## LeagueN        FALSE      FALSE
## DivisionW      FALSE      FALSE
## PutOuts        FALSE      FALSE
## Assists        FALSE      FALSE
## Errors         FALSE      FALSE
## NewLeagueN     FALSE      FALSE
## 1 subsets of each size up to 19
## Selection Algorithm: exhaustive
##           AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits CHmRun CRuns CRBI
## 1  ( 1 )  " "   " "  " "   " "  " " " "   " "   " "    " "   " "    " "   "*" 
## 2  ( 1 )  " "   "*"  " "   " "  " " " "   " "   " "    " "   " "    " "   "*" 
## 3  ( 1 )  " "   "*"  " "   " "  " " " "   " "   " "    " "   " "    " "   "*" 
## 4  ( 1 )  " "   "*"  " "   " "  " " " "   " "   " "    " "   " "    " "   "*" 
## 5  ( 1 )  "*"   "*"  " "   " "  " " " "   " "   " "    " "   " "    " "   "*" 
## 6  ( 1 )  "*"   "*"  " "   " "  " " "*"   " "   " "    " "   " "    " "   "*" 
## 7  ( 1 )  " "   "*"  " "   " "  " " "*"   " "   "*"    "*"   "*"    " "   " " 
## 8  ( 1 )  "*"   "*"  " "   " "  " " "*"   " "   " "    " "   "*"    "*"   " " 
## 9  ( 1 )  "*"   "*"  " "   " "  " " "*"   " "   "*"    " "   " "    "*"   "*" 
## 10  ( 1 ) "*"   "*"  " "   " "  " " "*"   " "   "*"    " "   " "    "*"   "*" 
## 11  ( 1 ) "*"   "*"  " "   " "  " " "*"   " "   "*"    " "   " "    "*"   "*" 
## 12  ( 1 ) "*"   "*"  " "   "*"  " " "*"   " "   "*"    " "   " "    "*"   "*" 
## 13  ( 1 ) "*"   "*"  " "   "*"  " " "*"   " "   "*"    " "   " "    "*"   "*" 
## 14  ( 1 ) "*"   "*"  "*"   "*"  " " "*"   " "   "*"    " "   " "    "*"   "*" 
## 15  ( 1 ) "*"   "*"  "*"   "*"  " " "*"   " "   "*"    "*"   " "    "*"   "*" 
## 16  ( 1 ) "*"   "*"  "*"   "*"  "*" "*"   " "   "*"    "*"   " "    "*"   "*" 
## 17  ( 1 ) "*"   "*"  "*"   "*"  "*" "*"   " "   "*"    "*"   " "    "*"   "*" 
## 18  ( 1 ) "*"   "*"  "*"   "*"  "*" "*"   "*"   "*"    "*"   " "    "*"   "*" 
## 19  ( 1 ) "*"   "*"  "*"   "*"  "*" "*"   "*"   "*"    "*"   "*"    "*"   "*" 
##           CWalks LeagueN DivisionW PutOuts Assists Errors NewLeagueN
## 1  ( 1 )  " "    " "     " "       " "     " "     " "    " "       
## 2  ( 1 )  " "    " "     " "       " "     " "     " "    " "       
## 3  ( 1 )  " "    " "     " "       "*"     " "     " "    " "       
## 4  ( 1 )  " "    " "     "*"       "*"     " "     " "    " "       
## 5  ( 1 )  " "    " "     "*"       "*"     " "     " "    " "       
## 6  ( 1 )  " "    " "     "*"       "*"     " "     " "    " "       
## 7  ( 1 )  " "    " "     "*"       "*"     " "     " "    " "       
## 8  ( 1 )  "*"    " "     "*"       "*"     " "     " "    " "       
## 9  ( 1 )  "*"    " "     "*"       "*"     " "     " "    " "       
## 10  ( 1 ) "*"    " "     "*"       "*"     "*"     " "    " "       
## 11  ( 1 ) "*"    "*"     "*"       "*"     "*"     " "    " "       
## 12  ( 1 ) "*"    "*"     "*"       "*"     "*"     " "    " "       
## 13  ( 1 ) "*"    "*"     "*"       "*"     "*"     "*"    " "       
## 14  ( 1 ) "*"    "*"     "*"       "*"     "*"     "*"    " "       
## 15  ( 1 ) "*"    "*"     "*"       "*"     "*"     "*"    " "       
## 16  ( 1 ) "*"    "*"     "*"       "*"     "*"     "*"    " "       
## 17  ( 1 ) "*"    "*"     "*"       "*"     "*"     "*"    "*"       
## 18  ( 1 ) "*"    "*"     "*"       "*"     "*"     "*"    "*"       
## 19  ( 1 ) "*"    "*"     "*"       "*"     "*"     "*"    "*"

We can use plot() to show predictors included at each Cp, BIC, or Adjusted R-squared.

plot(regfit.full, scale = "Cp")

Evaluate Model Performance

Using Cp, BIC, Adjusted R Squared

broom::tidy() also works on objects of class regsubsets.

regfit_tidy <- broom::tidy(regfit.full)
regfit_tidy 
regfit_tidy_prep <- regfit_tidy %>% 
  rowwise() %>% 
  mutate(pred_no = sum(c_across(AtBat:NewLeagueN)), .keep = "unused") %>% 
  ungroup() 
  
regfit_tidy_prep

Plotting \(C_p\), BIC, and Adjusted \(R^2\)

Let’s create an autoplot() method for class “regsubsets”.

#' Plot a model-selection statistic against model size for a "regsubsets" fit
#'
#' Tidies the fit with `broom::tidy()`, counts the predictors selected in each
#' model, and draws the chosen statistic as a point-and-line plot. The best
#' model (lowest value for "mallows_cp"/"BIC", highest value for
#' "r.squared"/"adj.r.squared") is marked with a cross.
#'
#' @param x A fitted object that `broom::tidy()` summarises with one row per
#'   model (here: the result of `leaps::regsubsets()`).
#' @param res Statistic to plot; one of "mallows_cp" (default), "BIC",
#'   "r.squared" or "adj.r.squared".
#' @param ... Unused; accepted so the signature is compatible with the
#'   `autoplot()` generic.
#' @return A ggplot object.
autoplot.regsubsets <- function(x,
                                res = c("mallows_cp", "BIC", "r.squared", "adj.r.squared"),
                                ...) {

  res <- match.arg(res)
  # `res` is a validated string at this point, so convert it directly to a
  # symbol instead of re-capturing the (already overwritten) argument.
  res_sym <- dplyr::sym(res)

  df_params <- broom::tidy(x)
  df_reduced <- df_params %>%
    dplyr::rowwise() %>%
    dplyr::mutate(
      # Number of predictors = count of selected-predictor flags in the row,
      # i.e. every column except the intercept and the fit statistics.
      pred_no = sum(dplyr::c_across(
        !tidyselect::any_of(c("(Intercept)", "r.squared", "adj.r.squared",
                              "BIC", "mallows_cp")))), .keep = "unused") %>%
    dplyr::ungroup()

  # Coordinates of the best model: the minimum for "mallows_cp"/"BIC",
  # the maximum for "r.squared"/"adj.r.squared".
  lower_is_better <- res %in% c("mallows_cp", "BIC")
  best_fun <- if (lower_is_better) min else max

  best_y <- best_fun(df_reduced[[res]])
  # which() drops NA comparisons, matching dplyr::filter() semantics.
  best_x <- df_reduced$pred_no[which(df_reduced[[res]] == best_y)]

  # Reverse the colour scale when lower values are better.
  direction_col <- if (lower_is_better) -1 else 1

  df_reduced %>%
    ggplot2::ggplot(ggplot2::aes(pred_no, !!res_sym, color = !!res_sym)) +
    ggplot2::geom_point(show.legend = FALSE) +
    ggplot2::geom_line(show.legend = FALSE) +
    ggplot2::scale_color_viridis_c(option = "plasma", end = 0.8,
                                   direction = direction_col) +
    ggplot2::annotate("point", x = best_x, y = best_y, shape = 4, size = 5, stroke = 0.5) +
    ggplot2::labs(x = "Number of Predictors")

}
library(patchwork)
library(latex2exp)

p_cp <- autoplot(regfit.full, res = "mallows_cp") +
  labs(y = TeX("C_p"))

p_bic <- autoplot(regfit.full, res = "BIC")

p_adj_rsq <- autoplot(regfit.full, res = "adj.r.squared") +
  labs(y = TeX("Adjusted R^2"))

p_cp + p_bic + p_adj_rsq +
  plot_annotation(title = "Best Subset Selection at Each Model Sizes", 
                  subtitle = TeX("Estimate test error by C_p, BIC, and Adjusted R^2"),
                  caption = "Data from Hitters dataset in ISLR2 package")

lbr::ggsave_mac(here("plot/Hitters_BestSubset.png"))

Using Cross-Validation

I need to build a lot of helper functions.

Predict Method for “regsubsets”

A “regsubsets” object has no predict() method, so we need to write one ourselves.

Luckily, ISLR’s lab demonstration already included predict.regsubsets() in the R-Markdown material.

I will modify it a little bit so that it returns a vector instead of matrix.

#' Predict from one model of a "regsubsets" fit
#'
#' @param object A fitted "regsubsets" object whose stored call has the model
#'   formula as its second element.
#' @param newdata Data frame holding the predictor columns. The response
#'   column may be present or absent.
#' @param id Index of the model to use; passed on to `coef()`.
#' @param ... Unused.
#' @return Numeric vector of predictions, named by the row names of the design
#'   matrix built from `newdata`.
predict.regsubsets <- function(object, newdata, id, ...) {
  # Assumes the second element of the stored call is the formula itself
  # (e.g. a literal `y ~ .`) -- TODO confirm for other ways of calling the fit.
  form <- as.formula(object[["call"]][[2]])
  # Strip the response so the design matrix can be built even when `newdata`
  # has no response column (the usual situation when predicting).
  rhs_terms <- delete.response(terms(form, data = newdata))
  mat <- model.matrix(rhs_terms, newdata)
  coefi <- coef(object, id = id)
  xvars <- names(coefi)
  # drop = FALSE keeps a matrix even for a single row or a single coefficient.
  pred_mat <- mat[, xvars, drop = FALSE] %*% coefi
  # Return a plain named vector instead of a one-column matrix.
  pred <- as.numeric(pred_mat)
  names(pred) <- rownames(mat)
  pred
}

predict(regfit.full, newdata = Hitters, id = 3) %>% head()
##       -Alan Ashby      -Alvin Davis     -Andre Dawson -Andres Galarraga 
##         611.11976         715.34087         950.55323         424.10211 
##  -Alfredo Griffin        -Al Newman 
##         708.86493          59.21692

Model Summary for “regsubsets”

Next, I will implement a broom::glance() method for “regsubsets” objects. It provides a summary of the model at each number of predictors.

#' One-row-per-model summary of a "regsubsets" fit
#'
#' For every model stored in the fit, computes the mean squared error of its
#' predictions on `newdata` and reports it next to the statistics returned by
#' `summary(x)`. Note that `r.squared`, `adj.r.squared`, `mallows_cp` and `BIC`
#' come from `summary(x)` (the fit itself), whereas `MSE` is computed on
#' `newdata`.
#'
#' @param x A fitted "regsubsets" object.
#' @param newdata Data frame containing the response and the predictors.
#' @param ... Passed on to `predict.regsubsets()`.
#' @return A tibble with columns `n_predictors`, `MSE`, `r.squared`,
#'   `adj.r.squared`, `mallows_cp` and `BIC`.
glance.regsubsets <- function(x, newdata, ...) {

  # Summarise once; every statistic below is read from this object.
  fit_summary <- summary(x)

  # Number of models actually stored in the fit. Taking it from the summary
  # (rather than from the number of candidate predictors) keeps the loop in
  # range when the fit was capped with a smaller `nvmax`.
  n_models <- length(fit_summary[["rsq"]])

  # Left-hand side of the formula stored in the call identifies the response.
  y_var <- as.formula(x$call[[2]])[[2]]
  y <- newdata[[y_var]]

  mse <- numeric(n_models)

  for (i in seq_len(n_models)) {
    y_hat_i <- predict.regsubsets(x, newdata, id = i, ...)

    # Mean squared error of model i on `newdata`
    mse[[i]] <- mean((y - y_hat_i)^2)
  }

  tibble::tibble(
    n_predictors = seq_len(n_models),
    MSE = mse,
    r.squared = fit_summary[["rsq"]],
    adj.r.squared = fit_summary[["adjr2"]],
    mallows_cp = fit_summary[["cp"]],
    BIC = fit_summary[["bic"]]
  )

}

broom::glance(regfit.full, newdata = Hitters)

Split Data into Folds

Now, let’s split the data into 10 folds using the vfold_cv() function from the {rsample} package.

library(rsample)
set.seed(123)

Hitters_folds <- vfold_cv(Hitters, v = 10) 

class(Hitters_folds)
## [1] "vfold_cv"   "rset"       "tbl_df"     "tbl"        "data.frame"
Hitters_folds

Analysis (train) and Assessment (hold-out) data can be obtained by analysis() and assessment(), respectively.

Hitters_folds$splits[[1]] %>% analysis()
Hitters_folds$splits[[1]] %>% assessment()

Up next, regsubsets_cv_glance() will compute cross-validation summary statistics using the methods and functions that I’ve just defined above.

#' Cross-validated summary statistics for subset selection
#'
#' Fits `leaps::regsubsets()` on the analysis set of every fold, evaluates each
#' fitted model size on the matching assessment set via `broom::glance()`, and
#' either returns the per-fold results or their mean across folds.
#'
#' @param x Model formula passed to `leaps::regsubsets()`.
#' @param vfold_cv A resampling object with a `splits` column, such as the
#'   result of `rsample::vfold_cv()`.
#' @param return_folds If `TRUE`, return the per-fold table (`id`,
#'   `n_predictors`, `MSE`, `r.squared`) instead of the averages.
#' @param ... Passed on to `leaps::regsubsets()`.
#' @return A tibble: per-fold statistics when `return_folds = TRUE`, otherwise
#'   one row per `n_predictors` with columns `MSE_cv` and `r.squared_cv`.
regsubsets_cv_glance <- function(x, vfold_cv, return_folds = FALSE, ...) {

  n_folds <- nrow(vfold_cv)
  stopifnot(n_folds > 0)

  # Materialise the analysis (train) and assessment (hold-out) data per fold
  vfold_cv <- vfold_cv %>%
    dplyr::mutate(
      analysis_df = purrr::map(splits, rsample::analysis),
      assess_df = purrr::map(splits, rsample::assessment)
    )

  # Upper bound on model size: every column except one (the response)
  n_predictors <- ncol(vfold_cv$analysis_df[[1]]) - 1

  # Fit `regsubsets` to each fold.
  # NOTE(review): keep this call directly in the function body. The predict
  # and glance helpers rebuild the formula from the call stored on the fit,
  # which presumably depends on where regsubsets() is invoked from -- verify
  # before moving the call into a map()/lapply() closure.
  fitted_folds <- vector("list", n_folds)
  for (i in seq_len(n_folds)) {
    fitted_folds[[i]] <-
      leaps::regsubsets(x, data = vfold_cv$analysis_df[[i]], nvmax = n_predictors, ...)
  }
  # Zero-padded fold labels ("Fold01", ...); they become the `id` column below.
  names(fitted_folds) <-
    sprintf(paste0("Fold%0", nchar(n_folds), "d"), seq_len(n_folds))

  params_df <- purrr::map2_dfr(
    .x = fitted_folds,
    .y = vfold_cv$assess_df,
    # Summarise all model sizes of one fold on its hold-out data
    ~broom::glance(.x, newdata = .y),
    .id = "id"
    ) %>%
    dplyr::select(id, n_predictors, MSE, r.squared)

  if (return_folds) return(params_df)

  # Average each statistic over the folds, per model size
  params_df %>%
    dplyr::group_by(n_predictors) %>%
    dplyr::summarise(dplyr::across(MSE:r.squared, mean, .names = "{.col}_cv"))

}

# Cross-validated averages per model size (per-fold results not requested)
regsubsets_cv_glance(Salary ~ ., Hitters_folds, return_folds = FALSE)

Compute Cross-Validation Estimates

MSE_cv is the mean squared error for each number of predictors; each MSE_cv is averaged over the 10 held-out sets.

Hitters_cv_fitted <- regsubsets_cv_glance(Salary ~ ., Hitters_folds) 

Hitters_cv_fitted

Plotting Cross-Validation Estimates

The final helper function, regsubsets_cv_plot(), plots the 10-fold cross-validation MSE at each number of predictors.

#' Plot a cross-validated statistic against the number of predictors
#'
#' @param df_glanced Data frame with a column `n_predictors` and the column
#'   named by `res` (e.g. the output of `regsubsets_cv_glance()`).
#' @param res Statistic to plot: "MSE_cv" (default, lower is better) or
#'   "r.squared_cv" (higher is better).
#' @return A ggplot object with the best model marked by a cross.
regsubsets_cv_plot <- function(df_glanced,
                               res = c("MSE_cv", "r.squared_cv")
                          ) {

  res <- match.arg(res)
  # `res` is a validated string here, so convert it directly to a symbol.
  res_sym <- dplyr::sym(res)

  # Coordinates of the best model: minimum for MSE, maximum otherwise
  lower_is_better <- res %in% c("MSE_cv")
  best_fun <- if (lower_is_better) min else max

  best_y <- best_fun(df_glanced[[res]])
  # which() drops NA comparisons, matching dplyr::filter() semantics.
  best_x <- df_glanced$n_predictors[which(df_glanced[[res]] == best_y)]

  # Reverse the colour scale when lower values are better.
  direction_col <- if (lower_is_better) -1 else 1

  df_glanced %>%
    ggplot2::ggplot(ggplot2::aes(n_predictors, !!res_sym, color = !!res_sym)) +
    ggplot2::geom_point(show.legend = FALSE) +
    ggplot2::geom_line(show.legend = FALSE) +
    ggplot2::scale_color_viridis_c(option = "plasma", end = 0.8,
                                   direction = direction_col) +
    ggplot2::annotate("point", x = best_x, y = best_y, shape = 4, size = 5, stroke = 0.5) +
    ggplot2::labs(x = "Number of Predictors")
}
regsubsets_cv_plot(Hitters_cv_fitted, "MSE_cv") +
  labs(title = "Best Subset Selection by 10-Fold Cross-Validation",
       y = TeX("MSE_{10-fold-CV}")
       )

Final Comparisons

To see the overall picture, all the methods for estimating the test error that we’ve explored so far are plotted together in this final figure. The “X” symbols indicate the best model for each method.

library(patchwork)
library(latex2exp)

p_cp <- autoplot(regfit.full, res = "mallows_cp") +
  labs(y = TeX("C_p"))

p_bic <- autoplot(regfit.full, res = "BIC")

p_adj_rsq <- autoplot(regfit.full, res = "adj.r.squared") +
  labs(y = TeX("Adjusted R^2"))

p_mse_cv <- regsubsets_cv_plot(Hitters_cv_fitted, "MSE_cv") +
  labs(y = TeX("MSE_{10-fold-CV}"))

p_cp + p_bic + p_adj_rsq + p_mse_cv +
  plot_annotation(title = "Best Subset Selection", 
                  subtitle = TeX("Estimate test error by C_p, BIC, Adjusted R^2, and 10-fold Cross-Validation"),
                  caption = "Data from Hitters dataset in ISLR2 package")

lbr::ggsave_mac(here("plot/Hitters_BestSubsetAll.png"))
devtools::session_info()
## ─ Session info ───────────────────────────────────────────────────────────────
##  setting  value                       
##  version  R version 4.1.0 (2021-05-18)
##  os       macOS Big Sur 10.16         
##  system   x86_64, darwin17.0          
##  ui       X11                         
##  language (EN)                        
##  collate  en_US.UTF-8                 
##  ctype    en_US.UTF-8                 
##  tz       Asia/Bangkok                
##  date     2021-10-14                  
## 
## ─ Packages ───────────────────────────────────────────────────────────────────
##  package       * version    date       lib source                             
##  assertthat      0.2.1      2019-03-21 [1] CRAN (R 4.1.0)                     
##  backports       1.2.1      2020-12-09 [1] CRAN (R 4.1.0)                     
##  base64enc       0.1-3      2015-07-28 [1] CRAN (R 4.1.0)                     
##  broom           0.7.9      2021-07-27 [1] CRAN (R 4.1.0)                     
##  bslib           0.2.5.1    2021-05-18 [1] CRAN (R 4.1.0)                     
##  cachem          1.0.5      2021-05-15 [1] CRAN (R 4.1.0)                     
##  callr           3.7.0      2021-04-20 [1] CRAN (R 4.1.0)                     
##  cellranger      1.1.0      2016-07-27 [1] CRAN (R 4.1.0)                     
##  class           7.3-19     2021-05-03 [1] CRAN (R 4.1.0)                     
##  cli             3.0.1      2021-07-17 [1] CRAN (R 4.1.0)                     
##  codetools       0.2-18     2020-11-04 [1] CRAN (R 4.1.0)                     
##  colorspace      2.0-2      2021-06-24 [1] CRAN (R 4.1.0)                     
##  corrplot        0.90       2021-06-30 [1] CRAN (R 4.1.0)                     
##  crayon          1.4.1      2021-02-08 [1] CRAN (R 4.1.0)                     
##  DBI             1.1.1      2021-01-15 [1] CRAN (R 4.1.0)                     
##  dbplyr          2.1.1      2021-04-06 [1] CRAN (R 4.1.0)                     
##  desc            1.3.0      2021-03-05 [1] CRAN (R 4.1.0)                     
##  devtools        2.4.2      2021-06-07 [1] CRAN (R 4.1.0)                     
##  digest          0.6.28     2021-09-23 [1] CRAN (R 4.1.0)                     
##  dplyr         * 1.0.7      2021-06-18 [1] CRAN (R 4.1.0)                     
##  ellipsis        0.3.2      2021-04-29 [1] CRAN (R 4.1.0)                     
##  evaluate        0.14       2019-05-28 [1] CRAN (R 4.1.0)                     
##  fansi           0.5.0      2021-05-25 [1] CRAN (R 4.1.0)                     
##  farver          2.1.0      2021-02-28 [1] CRAN (R 4.1.0)                     
##  fastmap         1.1.0      2021-01-25 [1] CRAN (R 4.1.0)                     
##  forcats       * 0.5.1      2021-01-27 [1] CRAN (R 4.1.0)                     
##  fs              1.5.0      2020-07-31 [1] CRAN (R 4.1.0)                     
##  furrr           0.2.3      2021-06-25 [1] CRAN (R 4.1.0)                     
##  future          1.22.1     2021-08-25 [1] CRAN (R 4.1.0)                     
##  gargle          1.2.0      2021-07-02 [1] CRAN (R 4.1.0)                     
##  generics        0.1.0      2020-10-31 [1] CRAN (R 4.1.0)                     
##  ggplot2       * 3.3.5      2021-06-25 [1] CRAN (R 4.1.0)                     
##  ggrepel         0.9.1      2021-01-15 [1] CRAN (R 4.1.0)                     
##  globals         0.14.0     2020-11-22 [1] CRAN (R 4.1.0)                     
##  glue            1.4.2      2020-08-27 [1] CRAN (R 4.1.0)                     
##  googledrive     2.0.0      2021-07-08 [1] CRAN (R 4.1.0)                     
##  googlesheets4   1.0.0      2021-07-21 [1] CRAN (R 4.1.0)                     
##  gower           0.2.2      2020-06-23 [1] CRAN (R 4.1.0)                     
##  gtable          0.3.0      2019-03-25 [1] CRAN (R 4.1.0)                     
##  haven           2.4.3      2021-08-04 [1] CRAN (R 4.1.0)                     
##  here          * 1.0.1      2020-12-13 [1] CRAN (R 4.1.0)                     
##  highr           0.9        2021-04-16 [1] CRAN (R 4.1.0)                     
##  hms             1.1.0      2021-05-17 [1] CRAN (R 4.1.0)                     
##  htmltools       0.5.2      2021-08-25 [1] CRAN (R 4.1.0)                     
##  httr            1.4.2      2020-07-20 [1] CRAN (R 4.1.0)                     
##  ipred           0.9-11     2021-03-12 [1] CRAN (R 4.1.0)                     
##  ISLR2         * 1.0        2021-07-22 [1] CRAN (R 4.1.0)                     
##  janeaustenr     0.1.5      2017-06-10 [1] CRAN (R 4.1.0)                     
##  jquerylib       0.1.4      2021-04-26 [1] CRAN (R 4.1.0)                     
##  jsonlite        1.7.2      2020-12-09 [1] CRAN (R 4.1.0)                     
##  knitr           1.34       2021-09-09 [1] CRAN (R 4.1.0)                     
##  labeling        0.4.2      2020-10-20 [1] CRAN (R 4.1.0)                     
##  latex2exp     * 0.5.0      2021-03-18 [1] CRAN (R 4.1.0)                     
##  lattice         0.20-44    2021-05-02 [1] CRAN (R 4.1.0)                     
##  lava            1.6.9      2021-03-11 [1] CRAN (R 4.1.0)                     
##  lbr             0.0.0.9000 2021-09-24 [1] Github (Lightbridge-KS/lbr@771323a)
##  leaps         * 3.1        2020-01-16 [1] CRAN (R 4.1.0)                     
##  lifecycle       1.0.1      2021-09-24 [1] CRAN (R 4.1.0)                     
##  listenv         0.8.0      2019-12-05 [1] CRAN (R 4.1.0)                     
##  lubridate       1.7.10     2021-02-26 [1] CRAN (R 4.1.0)                     
##  magrittr        2.0.1      2020-11-17 [1] CRAN (R 4.1.0)                     
##  MASS            7.3-54     2021-05-03 [1] CRAN (R 4.1.0)                     
##  Matrix          1.3-3      2021-05-04 [1] CRAN (R 4.1.0)                     
##  memoise         2.0.0      2021-01-26 [1] CRAN (R 4.1.0)                     
##  modelr          0.1.8      2020-05-19 [1] CRAN (R 4.1.0)                     
##  munsell         0.5.0      2018-06-12 [1] CRAN (R 4.1.0)                     
##  nnet            7.3-16     2021-05-03 [1] CRAN (R 4.1.0)                     
##  openxlsx        4.2.4      2021-06-16 [1] CRAN (R 4.1.0)                     
##  parallelly      1.28.1     2021-09-09 [1] CRAN (R 4.1.0)                     
##  patchwork     * 1.1.1      2020-12-17 [1] CRAN (R 4.1.0)                     
##  pillar          1.6.2      2021-07-29 [1] CRAN (R 4.1.0)                     
##  pkgbuild        1.2.0      2020-12-15 [1] CRAN (R 4.1.0)                     
##  pkgconfig       2.0.3      2019-09-22 [1] CRAN (R 4.1.0)                     
##  pkgload         1.2.1      2021-04-06 [1] CRAN (R 4.1.0)                     
##  plyr            1.8.6      2020-03-03 [1] CRAN (R 4.1.0)                     
##  prettyunits     1.1.1      2020-01-24 [1] CRAN (R 4.1.0)                     
##  pROC            1.17.0.1   2021-01-13 [1] CRAN (R 4.1.0)                     
##  processx        3.5.2      2021-04-30 [1] CRAN (R 4.1.0)                     
##  prodlim         2019.11.13 2019-11-17 [1] CRAN (R 4.1.0)                     
##  ps              1.6.0      2021-02-28 [1] CRAN (R 4.1.0)                     
##  purrr         * 0.3.4      2020-04-17 [1] CRAN (R 4.1.0)                     
##  R6              2.5.1      2021-08-19 [1] CRAN (R 4.1.0)                     
##  Rcpp            1.0.7      2021-07-07 [1] CRAN (R 4.1.0)                     
##  readr         * 2.0.1      2021-08-10 [1] CRAN (R 4.1.0)                     
##  readxl          1.3.1      2019-03-13 [1] CRAN (R 4.1.0)                     
##  recipes         0.1.16     2021-04-16 [1] CRAN (R 4.1.0)                     
##  remotes         2.4.0      2021-06-02 [1] CRAN (R 4.1.0)                     
##  repr            1.1.3      2021-01-21 [1] CRAN (R 4.1.0)                     
##  reprex          2.0.1      2021-08-05 [1] CRAN (R 4.1.0)                     
##  rlang           0.4.11     2021-04-30 [1] CRAN (R 4.1.0)                     
##  rmarkdown       2.11       2021-09-14 [1] CRAN (R 4.1.0)                     
##  rpart           4.1-15     2019-04-12 [1] CRAN (R 4.1.0)                     
##  rprojroot       2.0.2      2020-11-15 [1] CRAN (R 4.1.0)                     
##  rsample       * 0.1.0      2021-05-08 [1] CRAN (R 4.1.0)                     
##  rstudioapi      0.13       2020-11-12 [1] CRAN (R 4.1.0)                     
##  rvest           1.0.1      2021-07-26 [1] CRAN (R 4.1.0)                     
##  sass            0.4.0      2021-05-12 [1] CRAN (R 4.1.0)                     
##  scales          1.1.1      2020-05-11 [1] CRAN (R 4.1.0)                     
##  sessioninfo     1.1.1      2018-11-05 [1] CRAN (R 4.1.0)                     
##  skimr           2.1.3      2021-03-07 [1] CRAN (R 4.1.0)                     
##  SnowballC       0.7.0      2020-04-01 [1] CRAN (R 4.1.0)                     
##  stringi         1.7.4      2021-08-25 [1] CRAN (R 4.1.0)                     
##  stringr       * 1.4.0      2019-02-10 [1] CRAN (R 4.1.0)                     
##  survival        3.2-11     2021-04-26 [1] CRAN (R 4.1.0)                     
##  testthat        3.0.4      2021-07-01 [1] CRAN (R 4.1.0)                     
##  tibble        * 3.1.4      2021-08-25 [1] CRAN (R 4.1.0)                     
##  tidyr         * 1.1.3      2021-03-03 [1] CRAN (R 4.1.0)                     
##  tidyselect      1.1.1      2021-04-30 [1] CRAN (R 4.1.0)                     
##  tidytext        0.3.1      2021-04-10 [1] CRAN (R 4.1.0)                     
##  tidyverse     * 1.3.1      2021-04-15 [1] CRAN (R 4.1.0)                     
##  timeDate        3043.102   2018-02-21 [1] CRAN (R 4.1.0)                     
##  tokenizers      0.2.1      2018-03-29 [1] CRAN (R 4.1.0)                     
##  tzdb            0.1.2      2021-07-20 [1] CRAN (R 4.1.0)                     
##  units           0.7-2      2021-06-08 [1] CRAN (R 4.1.0)                     
##  usethis         2.0.1.9000 2021-09-23 [1] Github (r-lib/usethis@3385e14)     
##  utf8            1.2.2      2021-07-24 [1] CRAN (R 4.1.0)                     
##  vctrs           0.3.8      2021-04-29 [1] CRAN (R 4.1.0)                     
##  viridisLite     0.4.0      2021-04-13 [1] CRAN (R 4.1.0)                     
##  visdat          0.5.3      2019-02-15 [1] CRAN (R 4.1.0)                     
##  withr           2.4.2      2021-04-18 [1] CRAN (R 4.1.0)                     
##  xfun            0.26       2021-09-14 [1] CRAN (R 4.1.0)                     
##  xml2            1.3.2      2020-04-23 [1] CRAN (R 4.1.0)                     
##  yaml            2.2.1      2020-02-01 [1] CRAN (R 4.1.0)                     
##  yardstick       0.0.8      2021-03-28 [1] CRAN (R 4.1.0)                     
##  zip             2.2.0      2021-05-31 [1] CRAN (R 4.1.0)                     
## 
## [1] /Library/Frameworks/R.framework/Versions/4.1/Resources/library
---
title: "Linear Model Subset Selection"
subtitle: "A Tidy Approach"
author: "Kittipos Sirivongrungson"
date: "`r format(Sys.time(), '%d %B %Y')`"
output:
  html_document:
    theme: united
    df_print: paged
    code_folding: "show"
    toc: TRUE
    toc_float: TRUE
    code_download: TRUE
---

```{r setup, include=FALSE}
knitr::opts_knit$set(root.dir = rprojroot::find_rstudio_root_file()) # Set WD to Root
here::i_am("lab_mod/ch6-varselect.Rmd")
library(here)
```


```{r set_up, message=FALSE}
library(tidyverse)
library(ISLR2) # For Dataset
library(leaps)

theme_set(theme_bw())
```


# Explore Data

I will use `Hitters` dataset from `{ISLR2}` package.

Let's see what we've got. 

**The goal here is to predict variable `Salary`.**

```{r}
names(Hitters)
```


```{r}
dim(Hitters)
```


```{r}
skimr::skim(Hitters)
```

## Missing Values

Let's remove rows containing missing values.

```{r}
visdat::vis_miss(Hitters)
```

```{r}
Hitters <- Hitters %>% na.omit()

sum(is.na(Hitters$Salary))
```


# Subset Selection

Which combination of variables gives the lowest test error at each model size?

## Best Subset Selection Method

```{r}
regfit.full <- leaps::regsubsets(Salary ~ ., Hitters)
summary(regfit.full)
```

Try to include all predictors; I will increase `nvmax`.

```{r}
regfit.full <- regsubsets(Salary ~ ., data = Hitters,
    nvmax = 19)
reg.summary <- summary(regfit.full)
reg.summary 
```

We can use `plot()` to show predictors included at each Cp, BIC, or Adjusted R-squared.

```{r}
plot(regfit.full, scale = "Cp")
```

# Evaluate Model Performance

## Using Cp, BIC, Adjusted R Squared

`broom::tidy()` also works on objects of class `regsubsets`.

```{r}
regfit_tidy <- broom::tidy(regfit.full)
regfit_tidy 
```

```{r}
regfit_tidy_prep <- regfit_tidy %>% 
  rowwise() %>% 
  mutate(pred_no = sum(c_across(AtBat:NewLeagueN)), .keep = "unused") %>% 
  ungroup() 
  
regfit_tidy_prep
```



```{r include=FALSE, eval=FALSE}
### How to Do it
# Manual prototype of the C_p plot (this chunk is not evaluated).
# The position of the cross marking the best model is hard-coded below.
regfit_tidy_prep %>% 
  
  ggplot(aes(pred_no, mallows_cp, color = mallows_cp)) +
  geom_point(show.legend = FALSE) +
  geom_line(show.legend = FALSE) +
  scale_color_viridis_c(option = "plasma", end = 0.8, direction = -1) +
  labs(y = TeX("C_p", italic = TRUE), x = "Number of Predictors") +
  annotate("point",  x = 10, y = 5.009317, shape = 4, size = 4, stroke = 0.5)
```


### Plotting $C_p$, BIC, and Adjusted $R^2$

Let's create an `autoplot()` method for class "regsubsets".


```{r autoplot.regsubsets}
#' Plot a model-selection statistic against model size for a "regsubsets" fit
#'
#' Tidies the fit with `broom::tidy()`, counts the predictors selected in each
#' model, and draws the chosen statistic as a point-and-line plot. The best
#' model (lowest value for "mallows_cp"/"BIC", highest value for
#' "r.squared"/"adj.r.squared") is marked with a cross.
#'
#' @param x A fitted object that `broom::tidy()` summarises with one row per
#'   model (here: the result of `leaps::regsubsets()`).
#' @param res Statistic to plot; one of "mallows_cp" (default), "BIC",
#'   "r.squared" or "adj.r.squared".
#' @param ... Unused; accepted so the signature is compatible with the
#'   `autoplot()` generic.
#' @return A ggplot object.
autoplot.regsubsets <- function(x,
                                res = c("mallows_cp", "BIC", "r.squared", "adj.r.squared"),
                                ...) {

  res <- match.arg(res)
  # `res` is a validated string at this point, so convert it directly to a
  # symbol instead of re-capturing the (already overwritten) argument.
  res_sym <- dplyr::sym(res)

  df_params <- broom::tidy(x)
  df_reduced <- df_params %>%
    dplyr::rowwise() %>%
    dplyr::mutate(
      # Number of predictors = count of selected-predictor flags in the row,
      # i.e. every column except the intercept and the fit statistics.
      pred_no = sum(dplyr::c_across(
        !tidyselect::any_of(c("(Intercept)", "r.squared", "adj.r.squared",
                              "BIC", "mallows_cp")))), .keep = "unused") %>%
    dplyr::ungroup()

  # Coordinates of the best model: the minimum for "mallows_cp"/"BIC",
  # the maximum for "r.squared"/"adj.r.squared".
  lower_is_better <- res %in% c("mallows_cp", "BIC")
  best_fun <- if (lower_is_better) min else max

  best_y <- best_fun(df_reduced[[res]])
  # which() drops NA comparisons, matching dplyr::filter() semantics.
  best_x <- df_reduced$pred_no[which(df_reduced[[res]] == best_y)]

  # Reverse the colour scale when lower values are better.
  direction_col <- if (lower_is_better) -1 else 1

  df_reduced %>%
    ggplot2::ggplot(ggplot2::aes(pred_no, !!res_sym, color = !!res_sym)) +
    ggplot2::geom_point(show.legend = FALSE) +
    ggplot2::geom_line(show.legend = FALSE) +
    ggplot2::scale_color_viridis_c(option = "plasma", end = 0.8,
                                   direction = direction_col) +
    ggplot2::annotate("point", x = best_x, y = best_y, shape = 4, size = 5, stroke = 0.5) +
    ggplot2::labs(x = "Number of Predictors")

}

```




```{r}
# patchwork combines ggplots with `+`; latex2exp provides TeX() for labels
library(patchwork)
library(latex2exp)

# One panel per selection statistic, built with autoplot() on regfit.full
p_cp <- autoplot(regfit.full, res = "mallows_cp") +
  labs(y = TeX("C_p"))

p_bic <- autoplot(regfit.full, res = "BIC")

p_adj_rsq <- autoplot(regfit.full, res = "adj.r.squared") +
  labs(y = TeX("Adjusted R^2"))

# Combine the three panels and add a shared title, subtitle and caption
p_cp + p_bic + p_adj_rsq +
  plot_annotation(title = "Best Subset Selection at Each Model Sizes", 
                  subtitle = TeX("Estimate test error by C_p, BIC, and Adjusted R^2"),
                  caption = "Data from Hitters dataset in ISLR2 package")
  
  
# NOTE(review): lbr::ggsave_mac() and here() are not defined in this section;
# presumably this saves the plot printed just above -- confirm before
# reordering these statements.
lbr::ggsave_mac(here("plot/Hitters_BestSubset.png"))
```

## Using Cross-Validation

I need to build a lot of helper functions.

### Predict Method for "regsubsets"

A "regsubsets" object has no `predict()` method, so we need to create one ourselves.

Luckily, `ISLR`'s lab demonstration already included `predict.regsubsets()` in the R-Markdown material. 

I will modify it a little bit so that it returns a vector instead of a matrix.

```{r predict.regsubsets}
#' Predict from one of the models stored in a "regsubsets" fit
#'
#' @param object A fitted "regsubsets" object. The model formula is rebuilt
#'   from the second element of its stored call.
#' @param newdata Data frame used to build the design matrix.
#' @param id Model identifier passed to `coef()`.
#' @param ... Unused; present for compatibility with the `predict()` generic.
#' @return A numeric vector of predictions named by the row names of the
#'   design matrix built from `newdata`.
predict.regsubsets <- function(object, newdata, id, ...) {
  form <- as.formula(object[["call"]][[2]])
  mat <- model.matrix(form, newdata)
  coefi <- coef(object, id = id)
  xvars <- names(coefi)
  # drop = FALSE keeps a matrix even when `newdata` has a single row or only
  # one column is selected, so %*% never relies on vector promotion.
  pred_mat <- mat[, xvars, drop = FALSE] %*% coefi
  # Flatten the one-column result to a plain named vector
  pred <- as.numeric(pred_mat)
  names(pred) <- rownames(mat)
  pred
}

# First few predictions from the model with id = 3
head(predict(regfit.full, newdata = Hitters, id = 3))
```


### Model Summary for "regsubsets"

Next, I will implement a `broom::glance()` method for "regsubsets" objects.
This will provide a model summary for each number of predictors.

```{r glance.regsubsets}
#' One-row-per-model summary of a "regsubsets" fit
#'
#' @param x A fitted "regsubsets" object.
#' @param newdata Data frame on which the mean squared error is computed.
#' @param ... Passed on to `predict.regsubsets()`.
#' @return A tibble with one row per model in `summary(x)` and columns
#'   `n_predictors`, `MSE`, `r.squared`, `adj.r.squared`, `mallows_cp`, `BIC`.
#'   Only `MSE` uses `newdata`; the other statistics are taken unchanged from
#'   `summary(x)`.
glance.regsubsets <- function(x, newdata, ...) {

  # Summarise once and reuse the result
  fit_summary <- summary(x)
  # Iterate over the models that were actually fitted, so a fit with fewer
  # models than predictors does not request a missing model id.
  n_models <- length(fit_summary[["rsq"]])

  # Number of predictors in each model, excluding the intercept column
  which_mat <- fit_summary[["which"]]
  is_predictor <- colnames(which_mat) != "(Intercept)"
  n_predictors <- as.integer(rowSums(which_mat[, is_predictor, drop = FALSE]))

  # Take the response from a model frame built with the same formula that
  # predict.regsubsets() uses, so rows dropped for missing values are dropped
  # from `y` as well and `y` stays aligned with the predictions.
  form <- as.formula(x$call[[2]])
  y <- stats::model.response(stats::model.frame(form, newdata))

  mse <- numeric(n_models)

  for (i in seq_len(n_models)) {
    y_hat_i <- predict.regsubsets(x, newdata, id = i, ...)

    # Mean Squared Error
    mse[[i]] <- mean((y - y_hat_i)^2)
  }

  tibble::tibble(
    n_predictors = n_predictors,
    MSE = mse,
    r.squared = fit_summary[["rsq"]],
    adj.r.squared = fit_summary[["adjr2"]],
    mallows_cp = fit_summary[["cp"]],
    BIC = fit_summary[["bic"]]
  )

}

# Summary of every model size, with MSE computed on the Hitters data
regfit.full %>%
  broom::glance(newdata = Hitters)
```


### Split Data into Folds

Now, let's split the data into 10 folds using the `vfold_cv()` function from the `{rsample}` package.

```{r}
library(rsample)
```

```{r vfold_cv}
# Fix the RNG seed before the random fold assignment
set.seed(123)

Hitters_folds <- Hitters %>%
  vfold_cv(v = 10)

class(Hitters_folds)
Hitters_folds
```

Analysis (train) and Assessment (hold-out) data can be obtained by `analysis()` and `assessment()`, respectively.

```{r}
# Analysis and assessment data for the first split
analysis(Hitters_folds$splits[[1]])
assessment(Hitters_folds$splits[[1]])
```

Up next, `regsubsets_cv_glance()` will compute the cross-validation summary statistics using the methods and functions that I've just defined above.


```{r regsubsets_cv_glance}
#' Cross-validated summary of best-subset fits
#'
#' Fits `leaps::regsubsets()` on the analysis set of every fold and evaluates
#' each fit on the matching assessment set with `broom::glance()`.
#'
#' @param x Model formula, passed to `leaps::regsubsets()`.
#' @param vfold_cv Resampling object with a `splits` column (for example the
#'   result of `rsample::vfold_cv()`).
#' @param return_folds If `TRUE`, return the per-fold results (`id`,
#'   `n_predictors`, `MSE`, `r.squared`) instead of their means.
#' @param ... Passed on to `leaps::regsubsets()`.
#' @return A tibble. By default one row per `n_predictors` with the
#'   across-fold means `MSE_cv` and `r.squared_cv`.
regsubsets_cv_glance <- function(x, vfold_cv, return_folds = FALSE, ...) {

  # Add analysis & assess DF
  vfold_cv <- vfold_cv %>%
    dplyr::mutate(
      analysis_df = purrr::map(splits, rsample::analysis),
      assess_df = purrr::map(splits, rsample::assessment)
    )

  n_folds <- nrow(vfold_cv)
  # Number of columns minus one, used as `nvmax`
  n_predictors <- ncol(vfold_cv$analysis_df[[1]]) - 1

  # Fit `regsubsets` to each fold. The fit stays in a plain `for` loop in this
  # function's body so the calling context of regsubsets() is unchanged.
  fitted_folds <- vector("list", n_folds)
  for (i in seq_len(n_folds)) {
    fit_i <- leaps::regsubsets(x, data = vfold_cv$analysis_df[[i]],
                               nvmax = n_predictors, ...)
    # predict.regsubsets() rebuilds the formula from the second element of the
    # stored call. Put the evaluated formula there so this works no matter how
    # the formula was written or passed by the caller.
    fit_i$call[[2]] <- x
    fitted_folds[[i]] <- fit_i
  }
  # Fold labels are zero-padded to the width of the fold count
  names(fitted_folds) <-
    sprintf(paste0("Fold%0", nchar(n_folds), "d"), seq_len(n_folds))

  params_df <- purrr::map2_dfr(
    .x = fitted_folds,
    .y = vfold_cv$assess_df,
    # Compute multi-models summary at each fold
    ~broom::glance(.x, newdata = .y),
    .id = "id"
  ) %>%
    dplyr::select(id, n_predictors, MSE, r.squared)

  if (return_folds) return(params_df)

  params_df %>%
    dplyr::group_by(n_predictors) %>%
    dplyr::summarise(dplyr::across(MSE:r.squared, mean, .names = "{.col}_cv"))

}

# Averaged results; the flag is named and spelled out instead of positional `F`
regsubsets_cv_glance(Salary ~ ., Hitters_folds, return_folds = FALSE)
```

## Compute Cross-Validation Estimates

`MSE_cv` is the mean squared error for each number of predictors, averaged over the 10 held-out sets.

```{r Hitters_cv_fitted}
# Cross-validated summary averaged over the folds (return_folds defaults to
# FALSE).
# NOTE(review): the formula is written literally as the first argument;
# predict.regsubsets() rebuilds it from a stored call, so confirm before
# changing how this call is written.
Hitters_cv_fitted <- regsubsets_cv_glance(Salary ~ ., Hitters_folds) 

Hitters_cv_fitted
```

## Plotting Cross-Validation Estimates

The final helper function, `regsubsets_cv_plot()`, plots the 10-fold cross-validation MSE for each number of predictors.

```{r regsubsets_cv_plot}
#' Plot a cross-validated statistic against the number of predictors
#'
#' @param df_glanced Data frame with an `n_predictors` column and the column
#'   named by `res`.
#' @param res Statistic to plot: "MSE_cv" (used when not supplied) or
#'   "r.squared_cv".
#' @return A ggplot object with points and a line for the chosen statistic and
#'   a cross (shape 4) on the best value: the minimum for "MSE_cv", the
#'   maximum for "r.squared_cv".
regsubsets_cv_plot <- function(df_glanced,
                               res = c("MSE_cv", "r.squared_cv")) {

  res <- match.arg(res)
  # `res` is now a plain string, so turn it into a column symbol with sym()
  # (ensym() is meant for capturing an argument's expression, not its value).
  res_sym <- dplyr::sym(res)

  # Lower is better for the MSE; higher is better for R-squared.
  lower_is_better <- res %in% c("MSE_cv")
  fun <- if (lower_is_better) min else max

  # Coordinates of the best model
  best_y <- fun(df_glanced[[res]])
  best_x <- df_glanced %>%
    dplyr::filter(!!res_sym == fun(!!res_sym)) %>%
    dplyr::pull(n_predictors)

  # Colour-scale direction is flipped for the "lower is better" statistic.
  direction_col <- if (lower_is_better) -1 else 1

  # Plot
  df_glanced %>%
    ggplot2::ggplot(ggplot2::aes(n_predictors, !!res_sym, color = !!res_sym)) +
    ggplot2::geom_point(show.legend = FALSE) +
    ggplot2::geom_line(show.legend = FALSE) +
    ggplot2::scale_color_viridis_c(option = "plasma", end = 0.8,
                                   direction = direction_col) +
    ggplot2::annotate("point", x = best_x, y = best_y, shape = 4, size = 5,
                      stroke = 0.5) +
    ggplot2::labs(x = "Number of Predictors")
}
```

```{r}
# Cross-validated MSE by number of predictors, with a title and a TeX y label
Hitters_cv_fitted %>%
  regsubsets_cv_plot("MSE_cv") +
  labs(
    title = "Best Subset Selection by 10-Fold Cross-Validation",
    y = TeX("MSE_{10-fold-CV}")
  )
```

# Final Comparisons

To see the overall picture, all of the methods for estimating the test error that we've explored so far are plotted together in this final figure. The "X" symbols indicate the best model for each method.

```{r}
# patchwork combines ggplots with `+`; latex2exp provides TeX() for labels
library(patchwork)
library(latex2exp)

# One panel per selection statistic, built with autoplot() on regfit.full
p_cp <- autoplot(regfit.full, res = "mallows_cp") +
  labs(y = TeX("C_p"))

p_bic <- autoplot(regfit.full, res = "BIC")

p_adj_rsq <- autoplot(regfit.full, res = "adj.r.squared") +
  labs(y = TeX("Adjusted R^2"))

# Fourth panel: cross-validated MSE from Hitters_cv_fitted
p_mse_cv <- regsubsets_cv_plot(Hitters_cv_fitted, "MSE_cv") +
  labs(y = TeX("MSE_{10-fold-CV}"))

# Combine the four panels and add a shared title, subtitle and caption
p_cp + p_bic + p_adj_rsq + p_mse_cv +
  plot_annotation(title = "Best Subset Selection", 
                  subtitle = TeX("Estimate test error by C_p, BIC, Adjusted R^2, and 10-fold Cross-Validation"),
                  caption = "Data from Hitters dataset in ISLR2 package")
  
  
# NOTE(review): lbr::ggsave_mac() and here() are not defined in this section;
# presumably this saves the plot printed just above -- confirm before
# reordering these statements.
lbr::ggsave_mac(here("plot/Hitters_BestSubsetAll.png"))
```


```{r}
# Print session information for reproducibility
devtools::session_info()
```

