An Introduction to Statistical Learning (2nd ed.)

Chapter 06

Regularization

We will be using the Hitters data set from the ISLR package. We wish to predict baseball players' Salary from several other characteristics included in the data set. Since Salary is our outcome, we need to remove the rows where it is missing; otherwise, we won't be able to fit the models.

library(tidymodels)
library(ISLR)
library(ggcorrplot)

theme_set(theme_bw())
hitters <- tibble(Hitters)

glimpse(hitters)
Rows: 322
Columns: 20
$ AtBat     <int> 293, 315, 479, 496, 321, 594, 185, 298, 323, 401, 574, 202, ~
$ Hits      <int> 66, 81, 130, 141, 87, 169, 37, 73, 81, 92, 159, 53, 113, 60,~
$ HmRun     <int> 1, 7, 18, 20, 10, 4, 1, 0, 6, 17, 21, 4, 13, 0, 7, 3, 20, 2,~
$ Runs      <int> 30, 24, 66, 65, 39, 74, 23, 24, 26, 49, 107, 31, 48, 30, 29,~
$ RBI       <int> 29, 38, 72, 78, 42, 51, 8, 24, 32, 66, 75, 26, 61, 11, 27, 1~
$ Walks     <int> 14, 39, 76, 37, 30, 35, 21, 7, 8, 65, 59, 27, 47, 22, 30, 11~
$ Years     <int> 1, 14, 3, 11, 2, 11, 2, 3, 2, 13, 10, 9, 4, 6, 13, 3, 15, 5,~
$ CAtBat    <int> 293, 3449, 1624, 5628, 396, 4408, 214, 509, 341, 5206, 4631,~
$ CHits     <int> 66, 835, 457, 1575, 101, 1133, 42, 108, 86, 1332, 1300, 467,~
$ CHmRun    <int> 1, 69, 63, 225, 12, 19, 1, 0, 6, 253, 90, 15, 41, 4, 36, 3, ~
$ CRuns     <int> 30, 321, 224, 828, 48, 501, 30, 41, 32, 784, 702, 192, 205, ~
$ CRBI      <int> 29, 414, 266, 838, 46, 336, 9, 37, 34, 890, 504, 186, 204, 1~
$ CWalks    <int> 14, 375, 263, 354, 33, 194, 24, 12, 8, 866, 488, 161, 203, 2~
$ League    <fct> A, N, A, N, N, A, N, A, N, A, A, N, N, A, N, A, N, A, A, N, ~
$ Division  <fct> E, W, W, E, E, W, E, W, W, E, E, W, E, E, E, W, W, W, W, W, ~
$ PutOuts   <int> 446, 632, 880, 200, 805, 282, 76, 121, 143, 0, 238, 304, 211~
$ Assists   <int> 33, 43, 82, 11, 40, 421, 127, 283, 290, 0, 445, 45, 11, 151,~
$ Errors    <int> 20, 10, 14, 3, 4, 25, 7, 9, 19, 0, 22, 11, 7, 6, 8, 0, 10, 1~
$ Salary    <dbl> NA, 475.000, 480.000, 500.000, 91.500, 750.000, 70.000, 100.~
$ NewLeague <fct> A, N, A, N, N, A, A, A, N, A, A, N, N, A, N, A, N, A, A, N, ~
sum(is.na(hitters$Salary))
[1] 59
# Remove the rows where Salary is missing
hitters <- hitters %>% 
  filter(!is.na(Salary))

sum(is.na(hitters$Salary))
[1] 0
hitters %>% 
  select(where(is.numeric)) %>% 
  cor() %>% 
  ggcorrplot(hc.order = TRUE, type = "lower", lab = TRUE)

ggplot(hitters, aes(Salary)) + 
  geom_histogram(fill = "lightgrey", color = 'black') 

ggplot(hitters, aes(Salary, color = NewLeague)) + 
  geom_density() 

Ridge Regression

We will use the glmnet package to perform ridge regression. parsnip does not have a dedicated function for a ridge regression model specification; instead, use linear_reg() and set mixture = 0. The mixture argument is the proportion of lasso regularization in the penalty: mixture = 0 gives pure ridge regularization, mixture = 1 gives pure lasso regularization, and values between 0 and 1 combine the two. With the glmnet engine we also need to supply a penalty, the overall amount of regularization, before the model can be fit. We set it to 0 for now. This is not the best value, and we will see how to choose it shortly. glmnet fits the model for a whole sequence of penalty values at once, which is why tidy() below can report coefficients for penalties other than the one in the specification.
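For reference, the glmnet documentation describes the objective it minimizes for a Gaussian response as shown below, where, in parsnip's terms, penalty is the λ and mixture is the α. Setting α = 0 leaves only the squared (ridge) term and α = 1 only the absolute-value (lasso) term. Because of the 1/(2n) factor and the 1/2 on the ridge term, glmnet's λ is not on the same scale as the λ in the book's ridge and lasso formulas.

```latex
\min_{\beta_0,\,\beta}\;
  \frac{1}{2n}\sum_{i=1}^{n}\bigl(y_i - \beta_0 - x_i^{\top}\beta\bigr)^2
  + \lambda\Bigl[\tfrac{1-\alpha}{2}\,\lVert\beta\rVert_2^2 + \alpha\,\lVert\beta\rVert_1\Bigr]
```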

ridge_model <- linear_reg(mixture = 0, penalty = 0) %>% 
  set_engine("glmnet") %>% 
  set_mode("regression")
ridge_fit <- 
  fit(
    ridge_model,
    Salary ~ ., 
    data = hitters
  )

tidy(ridge_fit)
# A tibble: 20 x 3
   term          estimate penalty
   <chr>            <dbl>   <dbl>
 1 (Intercept)   81.1           0
 2 AtBat         -0.682         0
 3 Hits           2.77          0
 4 HmRun         -1.37          0
 5 Runs           1.01          0
 6 RBI            0.713         0
 7 Walks          3.38          0
 8 Years         -9.07          0
 9 CAtBat        -0.00120       0
10 CHits          0.136         0
11 CHmRun         0.698         0
12 CRuns          0.296         0
13 CRBI           0.257         0
14 CWalks        -0.279         0
15 LeagueN       53.2           0
16 DivisionW   -123.            0
17 PutOuts        0.264         0
18 Assists        0.170         0
19 Errors        -3.69          0
20 NewLeagueN   -18.1           0
tidy(ridge_fit, penalty = 50) %>% slice(1:10)
# A tibble: 10 x 3
   term        estimate penalty
   <chr>          <dbl>   <dbl>
 1 (Intercept) 48.2          50
 2 AtBat       -0.354        50
 3 Hits         1.95         50
 4 HmRun       -1.29         50
 5 Runs         1.16         50
 6 RBI          0.809        50
 7 Walks        2.71         50
 8 Years       -6.20         50
 9 CAtBat       0.00609      50
10 CHits        0.107        50
tidy(ridge_fit, penalty = 705) %>% slice(1:10)
# A tibble: 10 x 3
   term        estimate penalty
   <chr>          <dbl>   <dbl>
 1 (Intercept)  54.4        705
 2 AtBat         0.112      705
 3 Hits          0.656      705
 4 HmRun         1.18       705
 5 Runs          0.937      705
 6 RBI           0.847      705
 7 Walks         1.32       705
 8 Years         2.58       705
 9 CAtBat        0.0108     705
10 CHits         0.0468     705
ridge_fit %>% 
  extract_fit_engine() %>% 
  plot(xvar = "lambda")
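The plot above is glmnet's own base-graphics plot. As a sketch, assuming broom's tidy() method for raw glmnet objects (which returns one row per term and penalty along the path, with columns term, step, estimate, lambda and dev.ratio), the same coefficient paths can be drawn with ggplot2. The model is refit here so the block runs on its own.

```r
library(tidymodels)  # parsnip, broom, dplyr and ggplot2
library(ISLR)

hitters <- tibble(Hitters) %>% filter(!is.na(Salary))

ridge_fit <- linear_reg(mixture = 0, penalty = 0) %>%
  set_engine("glmnet") %>%
  set_mode("regression") %>%
  fit(Salary ~ ., data = hitters)

# tidy() on the underlying glmnet object gives the coefficients for every
# lambda in the path that glmnet computed, not just one penalty
ridge_path <- ridge_fit %>%
  extract_fit_engine() %>%
  tidy() %>%
  filter(term != "(Intercept)")

ggplot(ridge_path, aes(log10(lambda), estimate, color = term)) +
  geom_line() +
  labs(x = "log10(penalty)", y = "Coefficient estimate")
```

The x-axis uses log10 to match the tuning grids later in the chapter; glmnet's own plot uses the natural log of lambda.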

Tuning ridge with tidymodels

set.seed(2021)
hitters_split <- initial_split(hitters, prop = .75)
hitters_train <- training(hitters_split)
hitters_test <- testing(hitters_split)

set.seed(2021)
hitters_cv <- vfold_cv(hitters_train, v = 10)
hitters_rec <- 
  recipe(Salary ~ ., data = hitters_train) %>% 
  step_novel(all_nominal_predictors()) %>% 
  step_dummy(all_nominal_predictors()) %>% 
  step_zv(all_numeric_predictors()) %>% 
  step_normalize(all_numeric_predictors()) 
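To check what the recipe produces, one can prep() it on the training data and bake() the result. The sketch below rebuilds the split and recipe so it runs on its own. After preprocessing, every predictor, including the dummy columns, should have mean 0 and standard deviation 1. If step_dummy() creates an indicator for the extra "new" level added by step_novel(), that column is all zeros in the training data, and step_zv() removes it before normalization.

```r
library(tidymodels)
library(ISLR)

hitters <- tibble(Hitters) %>% filter(!is.na(Salary))
set.seed(2021)
hitters_split <- initial_split(hitters, prop = .75)
hitters_train <- training(hitters_split)

hitters_rec <- recipe(Salary ~ ., data = hitters_train) %>%
  step_novel(all_nominal_predictors()) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_numeric_predictors()) %>%
  step_normalize(all_numeric_predictors())

# prep() estimates the dummy levels, the zero-variance filter and the
# means/sds from the training set; bake(new_data = NULL) returns the
# processed training set itself
baked_train <- hitters_rec %>%
  prep(training = hitters_train) %>%
  bake(new_data = NULL)

glimpse(baked_train)
```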
tuned_ridge_model <- 
  linear_reg(mixture = 0, penalty = tune()) %>% 
  set_engine("glmnet") %>% 
  set_mode("regression")
ridge_wf <- 
  workflow() %>% 
  add_recipe(hitters_rec) %>% 
  add_model(tuned_ridge_model)
ridge_wf %>% parameters()
Collection of 1 parameters for tuning

 identifier    type    object
    penalty penalty nparam[+]
ridge_wf %>% pull_dials_object("penalty")
Amount of Regularization (quantitative)
Transformer:  log-10 
Range (transformed scale): [-10, 0]
penalty_grid <- grid_regular(penalty(range = c(-5, 5)), levels = 50)
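Because penalty() works on the log10 scale (see the Transformer line in the output above), range = c(-5, 5) describes penalties from 10^-5 to 10^5, not from -5 to 5. A quick check of the grid:

```r
library(tidymodels)  # loads dials

# range is given in log10 units; the grid values are back-transformed
penalty_grid <- grid_regular(penalty(range = c(-5, 5)), levels = 50)

range(penalty_grid$penalty)

# consecutive penalties are evenly spaced on the log10 scale
head(diff(log10(penalty_grid$penalty)))
```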
fit_tuned_ridge <- 
  tune_grid(
    ridge_wf,
    resamples = hitters_cv,
    grid = penalty_grid,
    control = control_grid(save_pred = T)
  )
fit_tuned_ridge %>% collect_metrics()
# A tibble: 100 x 7
     penalty .metric .estimator    mean     n std_err .config              
       <dbl> <chr>   <chr>        <dbl> <int>   <dbl> <chr>                
 1 0.00001   rmse    standard   340.       10 31.3    Preprocessor1_Model01
 2 0.00001   rsq     standard     0.460    10  0.0621 Preprocessor1_Model01
 3 0.0000160 rmse    standard   340.       10 31.3    Preprocessor1_Model02
 4 0.0000160 rsq     standard     0.460    10  0.0621 Preprocessor1_Model02
 5 0.0000256 rmse    standard   340.       10 31.3    Preprocessor1_Model03
 6 0.0000256 rsq     standard     0.460    10  0.0621 Preprocessor1_Model03
 7 0.0000409 rmse    standard   340.       10 31.3    Preprocessor1_Model04
 8 0.0000409 rsq     standard     0.460    10  0.0621 Preprocessor1_Model04
 9 0.0000655 rmse    standard   340.       10 31.3    Preprocessor1_Model05
10 0.0000655 rsq     standard     0.460    10  0.0621 Preprocessor1_Model05
# ... with 90 more rows
autoplot(fit_tuned_ridge)

best_ridge <- fit_tuned_ridge %>% select_best(metric = "rmse")

The plot shows that the amount of regularization affects the two performance metrics differently. Note also that there are ranges of the penalty, such as the very small values at the start of the grid, where changing the amount of regularization has no meaningful influence on the resampled performance.

final_ridge_wf <- 
  ridge_wf %>% 
  finalize_workflow(best_ridge)

ridge_lastfit <- last_fit(
  final_ridge_wf,
  split = hitters_split
)

ridge_lastfit %>% collect_predictions() %>% 
  ggplot(aes(Salary, .pred)) +
  geom_point() + 
  geom_smooth()

ridge_lastfit %>% collect_metrics()
# A tibble: 2 x 4
  .metric .estimator .estimate .config             
  <chr>   <chr>          <dbl> <chr>               
1 rmse    standard     333.    Preprocessor1_Model1
2 rsq     standard       0.453 Preprocessor1_Model1

Lasso

The lasso uses the same setup as ridge regression; the only change is setting mixture = 1 so that the model uses only lasso regularization. Unlike ridge, the lasso can shrink some coefficients exactly to zero, so it also performs variable selection.
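To see the variable selection directly, the sketch below fits an untuned lasso on the full data and counts, for each penalty along glmnet's path, how many coefficients are not exactly zero. It assumes broom's tidy() method for glmnet objects, whose return_zeros argument keeps the zero coefficients so they can be counted.

```r
library(tidymodels)
library(ISLR)

hitters <- tibble(Hitters) %>% filter(!is.na(Salary))

# mixture = 1 gives a pure lasso fit; glmnet computes the whole lambda path
lasso_fit <- linear_reg(mixture = 1, penalty = 0) %>%
  set_engine("glmnet") %>%
  set_mode("regression") %>%
  fit(Salary ~ ., data = hitters)

# number of non-zero slope coefficients at each lambda, largest lambda first
lasso_sparsity <- lasso_fit %>%
  extract_fit_engine() %>%
  tidy(return_zeros = TRUE) %>%
  filter(term != "(Intercept)") %>%
  group_by(lambda) %>%
  summarise(n_nonzero = sum(estimate != 0)) %>%
  arrange(desc(lambda))

lasso_sparsity
```

At the largest lambda in the path every slope is zero, and predictors enter the model as the penalty decreases.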

tuned_lasso_model <- 
  linear_reg(mixture = 1, penalty = tune()) %>% 
  set_engine("glmnet") %>% 
  set_mode("regression")

lasso_wf <- 
  workflow() %>% 
  add_recipe(hitters_rec) %>% 
  add_model(tuned_lasso_model)
fit_tuned_lasso <- 
  tune_grid(
    lasso_wf,
    resamples = hitters_cv,
    grid = penalty_grid,
    metrics = NULL,
    control = control_grid(save_pred = T)
  )
fit_tuned_lasso %>% collect_metrics()
# A tibble: 100 x 7
     penalty .metric .estimator    mean     n std_err .config              
       <dbl> <chr>   <chr>        <dbl> <int>   <dbl> <chr>                
 1 0.00001   rmse    standard   345.       10 33.5    Preprocessor1_Model01
 2 0.00001   rsq     standard     0.448    10  0.0640 Preprocessor1_Model01
 3 0.0000160 rmse    standard   345.       10 33.5    Preprocessor1_Model02
 4 0.0000160 rsq     standard     0.448    10  0.0640 Preprocessor1_Model02
 5 0.0000256 rmse    standard   345.       10 33.5    Preprocessor1_Model03
 6 0.0000256 rsq     standard     0.448    10  0.0640 Preprocessor1_Model03
 7 0.0000409 rmse    standard   345.       10 33.5    Preprocessor1_Model04
 8 0.0000409 rsq     standard     0.448    10  0.0640 Preprocessor1_Model04
 9 0.0000655 rmse    standard   345.       10 33.5    Preprocessor1_Model05
10 0.0000655 rsq     standard     0.448    10  0.0640 Preprocessor1_Model05
# ... with 90 more rows
autoplot(fit_tuned_lasso)

best_lasso <- fit_tuned_lasso %>% select_best(metric = "rmse")
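The ridge section finished by finalizing the workflow and evaluating it on the test set with last_fit(). The lasso can be handled the same way. The sketch below repeats the setup so it runs on its own, then lists the coefficients the tuned lasso keeps, read from the .workflow column of the last_fit() result. Which predictors survive depends on the selected penalty, so no particular result is implied here.

```r
library(tidymodels)
library(ISLR)

hitters <- tibble(Hitters) %>% filter(!is.na(Salary))
set.seed(2021)
hitters_split <- initial_split(hitters, prop = .75)
hitters_train <- training(hitters_split)
set.seed(2021)
hitters_cv <- vfold_cv(hitters_train, v = 10)

hitters_rec <- recipe(Salary ~ ., data = hitters_train) %>%
  step_novel(all_nominal_predictors()) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_numeric_predictors()) %>%
  step_normalize(all_numeric_predictors())

lasso_wf <- workflow() %>%
  add_recipe(hitters_rec) %>%
  add_model(linear_reg(mixture = 1, penalty = tune()) %>% set_engine("glmnet"))

fit_tuned_lasso <- tune_grid(
  lasso_wf,
  resamples = hitters_cv,
  grid = grid_regular(penalty(range = c(-5, 5)), levels = 50)
)
best_lasso <- select_best(fit_tuned_lasso, metric = "rmse")

# same finalize / last_fit pattern as for ridge
lasso_lastfit <- lasso_wf %>%
  finalize_workflow(best_lasso) %>%
  last_fit(split = hitters_split)

collect_metrics(lasso_lastfit)

# coefficients of the model refit on the training set at the chosen
# penalty; terms shrunk exactly to zero are filtered out
lasso_lastfit$.workflow[[1]] %>%
  extract_fit_parsnip() %>%
  tidy() %>%
  filter(estimate != 0)
```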

Elastic net

Tuning mixture and penalty

tuned_elastic_model <- 
  linear_reg(mixture = tune(), penalty = tune()) %>% 
  set_engine("glmnet") %>% 
  set_mode("regression")

elastic_wf <- 
  workflow() %>% 
  add_recipe(hitters_rec) %>% 
  add_model(tuned_elastic_model)
elastic_wf %>% parameters()
Collection of 2 parameters for tuning

 identifier    type    object
    penalty penalty nparam[+]
    mixture mixture nparam[+]
elastic_wf %>% pull_dials_object("penalty")
Amount of Regularization (quantitative)
Transformer:  log-10 
Range (transformed scale): [-10, 0]
elastic_wf %>% pull_dials_object("mixture")
Proportion of lasso Penalty (quantitative)
Range: [0.05, 1]
# glmnet penalties must be non-negative, so the penalty values start at 0
elastic_grid <- crossing(mixture = seq(0, 1, by = .2),
                         penalty = 0:400)
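The grid above places the penalty on its original scale in steps of 1. As a hypothetical alternative that matches the ridge and lasso sections, the penalty can be placed on the log10 scale with grid_regular(). The range c(-1, 3) (0.1 to 1000) and the 20 × 5 levels are arbitrary choices for illustration, and the mixture range is set explicitly to [0, 1] so that pure ridge is included.

```r
library(tidymodels)  # loads dials

# 20 penalties evenly spaced in log10 units, crossed with 5 mixture values
elastic_log_grid <- grid_regular(
  penalty(range = c(-1, 3)),
  mixture(range = c(0, 1)),
  levels = c(20, 5)
)

elastic_log_grid %>% count(mixture)
range(elastic_log_grid$penalty)
```

Such a grid could be passed to tune_grid() through its grid argument in place of elastic_grid.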
fit_tuned_elastic <- 
  tune_grid(
    elastic_wf,
    resamples = hitters_cv,
    grid = elastic_grid,
    control = control_grid(save_pred = T)
  )
fit_tuned_elastic %>% collect_metrics()
autoplot(fit_tuned_elastic)

fit_tuned_elastic %>% show_best(metric = "rmse")
# A tibble: 5 x 8
  penalty mixture .metric .estimator  mean     n std_err .config                
    <int>   <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                  
1     102     0.2 rmse    standard    336.    10    29.6 Preprocessor1_Model0514
2     103     0.2 rmse    standard    336.    10    29.6 Preprocessor1_Model0515
3     101     0.2 rmse    standard    336.    10    29.6 Preprocessor1_Model0513
4     104     0.2 rmse    standard    336.    10    29.6 Preprocessor1_Model0516
5     100     0.2 rmse    standard    336.    10    29.6 Preprocessor1_Model0512
fit_tuned_elastic %>% show_best(metric = "rsq")
# A tibble: 5 x 8
  penalty mixture .metric .estimator  mean     n std_err .config                
    <int>   <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                  
1       5     0.2 rsq     standard   0.468    10  0.0642 Preprocessor1_Model0417
2       6     0.2 rsq     standard   0.467    10  0.0641 Preprocessor1_Model0418
3     107     0.2 rsq     standard   0.467    10  0.0561 Preprocessor1_Model0519
4     106     0.2 rsq     standard   0.467    10  0.0561 Preprocessor1_Model0518
5       4     0.4 rsq     standard   0.467    10  0.0641 Preprocessor1_Model0822
best_elastic <- fit_tuned_elastic %>% select_best(metric = "rmse")

Principal Component Analysis

I will treat principal component regression as a linear model with PCA transformations in the preprocessing. Within the tidymodels framework this is still a single workflow: the principal components are computed in the recipe, and an ordinary least squares model is fit to them.
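To make the idea concrete outside tidymodels, here is a base R sketch of principal component regression with prcomp() and lm(): the predictors (factors expanded by model.matrix()) are standardized and rotated into principal components, and Salary is regressed on the first k component scores. The helper pcr_fitted() is hypothetical. With all components kept, PCR spans the same column space as the original predictors, so its fitted values match ordinary least squares; any benefit comes from keeping fewer components.

```r
library(ISLR)

hitters <- Hitters[!is.na(Hitters$Salary), ]

# design matrix with dummy columns for the factors, intercept column dropped
x <- model.matrix(Salary ~ ., data = hitters)[, -1]
y <- hitters$Salary

# standardize the predictors and compute all principal components
pcs <- prcomp(x, center = TRUE, scale. = TRUE)

# hypothetical helper: fitted values of a PCR model with the first k PCs
pcr_fitted <- function(k) {
  scores <- pcs$x[, seq_len(k), drop = FALSE]
  fitted(lm(y ~ scores))
}

ols_fitted <- fitted(lm(y ~ x))

# keeping every component reproduces the least squares fit
max(abs(pcr_fitted(ncol(x)) - ols_fitted))
```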

lm_model <- 
  linear_reg() %>% 
  set_mode("regression") %>% 
  set_engine("lm")
pca_recipe <- 
  recipe(formula = Salary ~ ., data = hitters_train) %>% 
  step_novel(all_nominal_predictors()) %>% 
  step_dummy(all_nominal_predictors()) %>% 
  step_zv(all_predictors()) %>% 
  step_normalize(all_predictors()) %>%
  step_pca(all_predictors(), threshold = tune(), num_comp = tune())

pca_wf <- 
  workflow() %>% 
  add_recipe(pca_recipe) %>% 
  add_model(lm_model)
pca_wf %>% parameters()
Collection of 2 parameters for tuning

 identifier      type    object
   num_comp  num_comp nparam[+]
  threshold threshold nparam[+]
pca_wf %>% pull_dials_object("num_comp")
# Components (quantitative)
Range: [1, 4]
pca_wf %>% pull_dials_object("threshold")
Threshold (quantitative)
Range: [0, 1]
pca_grid <- crossing(num_comp = 0:10,
                     threshold = seq(0,1, by=.1)
                     )
tune_pca <- 
  tune_grid(
    pca_wf,
    resamples = hitters_cv,
    grid = pca_grid,
    control = control_grid(save_pred = T)
  )
tune_pca %>% collect_metrics()
# A tibble: 242 x 8
   num_comp threshold .metric .estimator    mean     n std_err .config          
      <int>     <dbl> <chr>   <chr>        <dbl> <int>   <dbl> <chr>            
 1        0       0   rmse    standard   346.       10 33.6    Preprocessor001_~
 2        0       0   rsq     standard     0.446    10  0.0639 Preprocessor001_~
 3        0       0.1 rmse    standard   346.       10 33.6    Preprocessor002_~
 4        0       0.1 rsq     standard     0.446    10  0.0639 Preprocessor002_~
 5        0       0.2 rmse    standard   346.       10 33.6    Preprocessor003_~
 6        0       0.2 rsq     standard     0.446    10  0.0639 Preprocessor003_~
 7        0       0.3 rmse    standard   346.       10 33.6    Preprocessor004_~
 8        0       0.3 rsq     standard     0.446    10  0.0639 Preprocessor004_~
 9        0       0.4 rmse    standard   346.       10 33.6    Preprocessor005_~
10        0       0.4 rsq     standard     0.446    10  0.0639 Preprocessor005_~
# ... with 232 more rows
autoplot(tune_pca) +  scale_x_continuous(breaks = 0:10)

tune_pca %>% show_best(metric = "rmse")
# A tibble: 5 x 8
  num_comp threshold .metric .estimator  mean     n std_err .config             
     <int>     <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
1        1       0.8 rmse    standard    337.    10    29.8 Preprocessor020_Mod~
2        2       0.8 rmse    standard    337.    10    29.8 Preprocessor031_Mod~
3        3       0.8 rmse    standard    337.    10    29.8 Preprocessor042_Mod~
4        4       0.8 rmse    standard    337.    10    29.8 Preprocessor053_Mod~
5        5       0.8 rmse    standard    337.    10    29.8 Preprocessor064_Mod~
tune_pca %>% show_best(metric = "rsq")
# A tibble: 5 x 8
  num_comp threshold .metric .estimator  mean     n std_err .config             
     <int>     <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
1        1       0.9 rsq     standard   0.460    10  0.0541 Preprocessor021_Mod~
2        2       0.9 rsq     standard   0.460    10  0.0541 Preprocessor032_Mod~
3        3       0.9 rsq     standard   0.460    10  0.0541 Preprocessor043_Mod~
4        4       0.9 rsq     standard   0.460    10  0.0541 Preprocessor054_Mod~
5        5       0.9 rsq     standard   0.460    10  0.0541 Preprocessor065_Mod~
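All five best configurations share the same threshold and have identical metrics regardless of num_comp. This is consistent with the recipes documentation for step_pca(), which states that a threshold value overrides num_comp (the exception being num_comp = 0, which the same documentation describes as leaving the predictors untransformed, in line with the identical rows at the top of collect_metrics()). The sketch below, using a hypothetical helper n_pcs(), counts the components produced for a few settings on the training data.

```r
library(tidymodels)
library(ISLR)

hitters <- tibble(Hitters) %>% filter(!is.na(Salary))
set.seed(2021)
hitters_train <- training(initial_split(hitters, prop = .75))

# hypothetical helper: number of PC columns step_pca() creates for a setting
n_pcs <- function(num_comp, threshold) {
  recipe(Salary ~ ., data = hitters_train) %>%
    step_dummy(all_nominal_predictors()) %>%
    step_normalize(all_predictors()) %>%
    step_pca(all_predictors(), num_comp = num_comp, threshold = threshold) %>%
    prep(training = hitters_train) %>%
    bake(new_data = NULL) %>%
    select(starts_with("PC")) %>%
    ncol()
}

n_pcs(num_comp = 1, threshold = 0.8)
n_pcs(num_comp = 5, threshold = 0.8)
n_pcs(num_comp = 5, threshold = NA)
```

If this holds, tuning only one of the two arguments would give the same information with a much smaller grid.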
(best_pca_rmse <- tune_pca %>% select_best(metric = "rmse"))
# A tibble: 1 x 3
  num_comp threshold .config               
     <int>     <dbl> <chr>                 
1        1       0.8 Preprocessor020_Model1
(best_pca_rsq <- tune_pca %>% select_best(metric = "rsq"))
# A tibble: 1 x 3
  num_comp threshold .config               
     <int>     <dbl> <chr>                 
1        1       0.9 Preprocessor021_Model1
final_pca_wf <- 
  pca_wf %>% 
  finalize_workflow(best_pca_rmse)

final_pca_fit <- 
  last_fit(
    final_pca_wf,
    split = hitters_split
  )
final_pca_fit %>% collect_metrics()
# A tibble: 2 x 4
  .metric .estimator .estimate .config             
  <chr>   <chr>          <dbl> <chr>               
1 rmse    standard     342.    Preprocessor1_Model1
2 rsq     standard       0.432 Preprocessor1_Model1

ISLR tidymodels Labs

sessionInfo()
R version 4.1.0 (2021-05-18)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19043)

Matrix products: default

locale:
[1] LC_COLLATE=Spanish_Mexico.1252  LC_CTYPE=Spanish_Mexico.1252   
[3] LC_MONETARY=Spanish_Mexico.1252 LC_NUMERIC=C                   
[5] LC_TIME=Spanish_Mexico.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] vctrs_0.3.8        rlang_0.4.11       glmnet_4.1-2       Matrix_1.3-3      
 [5] ggcorrplot_0.1.3   ISLR_1.2           yardstick_0.0.8    workflowsets_0.0.2
 [9] workflows_0.2.3    tune_0.1.6         tidyr_1.1.3        tibble_3.1.3      
[13] rsample_0.1.0      recipes_0.1.16     purrr_0.3.4        parsnip_0.1.7     
[17] modeldata_0.1.1    infer_0.5.4        ggplot2_3.3.5      dplyr_1.0.7       
[21] dials_0.0.9        scales_1.1.1       broom_0.7.8        tidymodels_0.1.3  

loaded via a namespace (and not attached):
 [1] nlme_3.1-152       lubridate_1.7.10   DiceDesign_1.9     tools_4.1.0       
 [5] backports_1.2.1    bslib_0.2.5.1      utf8_1.2.2         R6_2.5.0          
 [9] rpart_4.1-15       DBI_1.1.1          mgcv_1.8-35        colorspace_2.0-2  
[13] nnet_7.3-16        withr_2.4.2        tidyselect_1.1.1   compiler_4.1.0    
[17] cli_3.0.0          labeling_0.4.2     sass_0.4.0         stringr_1.4.0     
[21] digest_0.6.27      rmarkdown_2.9      pkgconfig_2.0.3    htmltools_0.5.1.1 
[25] parallelly_1.26.1  lhs_1.1.1          highr_0.9          rstudioapi_0.13   
[29] shape_1.4.6        jquerylib_0.1.4    generics_0.1.0     farver_2.1.0      
[33] jsonlite_1.7.2     magrittr_2.0.1     Rcpp_1.0.7         munsell_0.5.0     
[37] fansi_0.5.0        GPfit_1.0-8        lifecycle_1.0.0    furrr_0.2.3       
[41] stringi_1.6.2      pROC_1.17.0.1      yaml_2.2.1         MASS_7.3-54       
[45] plyr_1.8.6         grid_4.1.0         parallel_4.1.0     listenv_0.8.0     
[49] crayon_1.4.1       lattice_0.20-44    splines_4.1.0      knitr_1.33        
[53] pillar_1.6.2       reshape2_1.4.4     codetools_0.2-18   glue_1.4.2        
[57] evaluate_0.14      foreach_1.5.1      gtable_0.3.0       future_1.21.0     
[61] assertthat_0.2.1   xfun_0.24          gower_0.2.2        prodlim_2019.11.13
[65] class_7.3-19       survival_3.2-11    timeDate_3043.102  iterators_1.0.13  
[69] hardhat_0.1.6      lava_1.6.9         globals_0.14.0     ellipsis_0.3.2    
[73] ipred_0.9-11