This week we will talk about shrinkage and hyperparameter tuning.
We will use the Hitters data set from the ISLR package. It can be loaded with the following code.
The vast majority of variables are numerical, with the remainder being factors.
library(tidyverse)
library(tidymodels)
library(ISLR)
data("Hitters")
Hitters %>%
str()
## 'data.frame': 322 obs. of 20 variables:
## $ AtBat : int 293 315 479 496 321 594 185 298 323 401 ...
## $ Hits : int 66 81 130 141 87 169 37 73 81 92 ...
## $ HmRun : int 1 7 18 20 10 4 1 0 6 17 ...
## $ Runs : int 30 24 66 65 39 74 23 24 26 49 ...
## $ RBI : int 29 38 72 78 42 51 8 24 32 66 ...
## $ Walks : int 14 39 76 37 30 35 21 7 8 65 ...
## $ Years : int 1 14 3 11 2 11 2 3 2 13 ...
## $ CAtBat : int 293 3449 1624 5628 396 4408 214 509 341 5206 ...
## $ CHits : int 66 835 457 1575 101 1133 42 108 86 1332 ...
## $ CHmRun : int 1 69 63 225 12 19 1 0 6 253 ...
## $ CRuns : int 30 321 224 828 48 501 30 41 32 784 ...
## $ CRBI : int 29 414 266 838 46 336 9 37 34 890 ...
## $ CWalks : int 14 375 263 354 33 194 24 12 8 866 ...
## $ League : Factor w/ 2 levels "A","N": 1 2 1 2 2 1 2 1 2 1 ...
## $ Division : Factor w/ 2 levels "E","W": 1 2 2 1 1 2 1 2 2 1 ...
## $ PutOuts : int 446 632 880 200 805 282 76 121 143 0 ...
## $ Assists : int 33 43 82 11 40 421 127 283 290 0 ...
## $ Errors : int 20 10 14 3 4 25 7 9 19 0 ...
## $ Salary : num NA 475 480 500 91.5 750 70 100 75 1100 ...
## $ NewLeague: Factor w/ 2 levels "A","N": 1 2 1 2 2 1 1 1 2 1 ...
Remove all rows where the salary is NA and split the data into testing and training data sets.
Hitters %>%
filter(!is.na(Salary)) -> Hitters_narm
set.seed(1234)
Hitters_split <- initial_split(Hitters_narm)
Hitters_train <- training(Hitters_split)
Hitters_test <- testing(Hitters_split)
We use linear_reg() with mixture = 0 to specify a ridge regression model: mixture = 0 is L2 regularization (ridge), while mixture = 1 is L1 regularization (lasso). Let’s set the penalty to 0 to see what happens.
ridge_spec0 <- linear_reg(mixture = 0, penalty = 0) %>%
set_mode("regression") %>%
set_engine("glmnet")
ridge_spec0
## Linear Regression Model Specification (regression)
##
## Main Arguments:
## penalty = 0
## mixture = 0
##
## Computational engine: glmnet
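As an aside, the lasso counterpart of this specification just flips mixture to 1. The sketch below is only for comparison; we stick with ridge regression for the rest of this post.
# Sketch only (not used below): the lasso version of the same specification.
lasso_spec0 <- linear_reg(mixture = 1, penalty = 0) %>%
  set_mode("regression") %>%
  set_engine("glmnet")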
Put in our ingredients and get a recipe. Because ridge regression is scale sensitive, we must ensure that the variables are on the same scale by using step_normalize(all_predictors()).
ridge_rec <- recipe(Salary ~ ., data = Hitters_train) %>%
step_novel(all_nominal_predictors()) %>% # Novel Factor Levels
step_dummy(all_nominal_predictors()) %>%
step_zv(all_predictors()) %>% # remove variables that contain only a single value
step_normalize(all_predictors()) # center and scale each column
ridge_rec
## Data Recipe
##
## Inputs:
##
## role #variables
## outcome 1
## predictor 19
##
## Operations:
##
## Novel factor level assignment for all_nominal_predictors()
## Dummy variables from all_nominal_predictors()
## Zero variance filter on all_predictors()
## Centering and scaling for all_predictors()
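If you want to verify that the preprocessing does what we expect, you can prep() the recipe on the training data and bake() it. This is only an optional sanity check; the workflow below handles these steps automatically.
# Optional sanity check: the baked predictors should be centered and scaled.
ridge_rec %>%
  prep(training = Hitters_train) %>%
  bake(new_data = NULL) %>%
  summary()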
ridge_wf0 <- workflow() %>%
add_model(ridge_spec0) %>%
add_recipe(ridge_rec)
ridge_fit0 <- fit(ridge_wf0, data = Hitters_train)
ridge_fit0 %>% tidy()
## Loading required package: Matrix
##
## Attaching package: 'Matrix'
## The following objects are masked from 'package:tidyr':
##
## expand, pack, unpack
## Loaded glmnet 4.1-1
## # A tibble: 20 x 3
## term estimate penalty
## <chr> <dbl> <dbl>
## 1 (Intercept) 569. 0
## 2 AtBat -122. 0
## 3 Hits 118. 0
## 4 HmRun -23.7 0
## 5 Runs 23.4 0
## 6 RBI 12.9 0
## 7 Walks 105. 0
## 8 Years -58.6 0
## 9 CAtBat -11.7 0
## 10 CHits 103. 0
## 11 CHmRun 63.7 0
## 12 CRuns 121. 0
## 13 CRBI 103. 0
## 14 CWalks -110. 0
## 15 PutOuts 69.9 0
## 16 Assists 31.5 0
## 17 Errors -16.5 0
## 18 League_N 25.9 0
## 19 Division_W -55.4 0
## 20 NewLeague_N -17.2 0
If \(\lambda = 0\) we don’t have any penalization, so we still get the standard OLS estimates. A high root mean squared error means that the predictions are, on average, far from the observed salaries.
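As a reminder, the metric we compute below is
\[
\text{RMSE} = \sqrt{\frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2},
\]
where \(y_i\) is the observed salary and \(\hat{y}_i\) the predicted salary.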
augment(ridge_fit0, new_data = Hitters_test) %>%
rmse(truth = Salary, estimate = .pred)
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 228.
The plot of the predictions against the observed salaries shows there is room for improvement, so we can try to optimize the penalty term, that is, tune the hyperparameter of the ridge regression model.
augment(ridge_fit0, new_data = Hitters_test) %>%
ggplot(aes(Salary, .pred)) +
geom_abline(slope = 1, intercept = 0) +
geom_point() +
theme_bw()
Now, we set penalty = tune() in order to find the best hyperparameter.
ridge_spec <- linear_reg(mixture = 0, penalty = tune()) %>%
set_mode("regression") %>%
set_engine("glmnet")Look at the output below, the main arguments show penalty = tune(). Because the hyperparameter will be automatically turning.
ridge_wf <- workflow() %>%
add_model(ridge_spec) %>%
add_recipe(ridge_rec)
ridge_wf
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: linear_reg()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 4 Recipe Steps
##
## • step_novel()
## • step_dummy()
## • step_zv()
## • step_normalize()
##
## ── Model ───────────────────────────────────────────────────────────────────────
## Linear Regression Model Specification (regression)
##
## Main Arguments:
## penalty = tune()
## mixture = 0
##
## Computational engine: glmnet
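If you want to double-check which parameters are marked for tuning, you can extract the parameter set from the workflow. The sketch below assumes a reasonably recent tidymodels release where this function is called extract_parameter_set_dials(); older versions used parameters() instead.
# List the tuning parameters attached to the workflow (here: penalty).
ridge_wf %>%
  extract_parameter_set_dials()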
Create the cross-validation folds to use in the tune_grid() call below; by default vfold_cv() creates 10 folds.
set.seed(123)
Hitters_fold <- vfold_cv(Hitters_train)
Next, create a regular grid of 100 candidate penalty values. Note that the range passed to penalty() is in transformed units; the default transformation is \(\log_{10}\), so range = c(0, 5) corresponds to penalties from \(10^0 = 1\) up to \(10^5\).
# Side note (commented out): with more than one tuning parameter you could build a
# grid over several parameters at once; grid_max_entropy() tries to spread the
# candidate points out so that they do not overlap.
# penalty_grid <- grid_max_entropy(list(p1 = threshold(), p2 = threshold()), size = 100)
# penalty_grid %>%
#   ggplot(aes(p1, p2)) +
#   geom_point()
penalty_grid <- grid_regular(penalty(range = c(0, 5)), levels = 100)
Show the penalty values in descending order.
penalty_grid %>%
arrange(desc(penalty))
## # A tibble: 100 x 1
## penalty
## <dbl>
## 1 100000
## 2 89022.
## 3 79248.
## 4 70548.
## 5 62803.
## 6 55908.
## 7 49770.
## 8 44306.
## 9 39442.
## 10 35112.
## # … with 90 more rows
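A quick way to confirm that the range was interpreted on the \(\log_{10}\) scale is to look at the smallest and largest penalty in the grid.
# The grid should run from 10^0 = 1 up to 10^5 = 100000.
range(penalty_grid$penalty)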
tune_res <- tune_grid(
object = ridge_wf,
resamples = Hitters_fold,
grid = penalty_grid) # control = control_grid(verbose = TRUE) prints progress as each model is fit
tune_res
## # Tuning results
## # 10-fold cross-validation
## # A tibble: 10 x 4
## splits id .metrics .notes
## <list> <chr> <list> <list>
## 1 <split [178/20]> Fold01 <tibble [200 × 5]> <tibble [0 × 1]>
## 2 <split [178/20]> Fold02 <tibble [200 × 5]> <tibble [0 × 1]>
## 3 <split [178/20]> Fold03 <tibble [200 × 5]> <tibble [0 × 1]>
## 4 <split [178/20]> Fold04 <tibble [200 × 5]> <tibble [0 × 1]>
## 5 <split [178/20]> Fold05 <tibble [200 × 5]> <tibble [0 × 1]>
## 6 <split [178/20]> Fold06 <tibble [200 × 5]> <tibble [0 × 1]>
## 7 <split [178/20]> Fold07 <tibble [200 × 5]> <tibble [0 × 1]>
## 8 <split [178/20]> Fold08 <tibble [200 × 5]> <tibble [0 × 1]>
## 9 <split [179/19]> Fold09 <tibble [200 × 5]> <tibble [0 × 1]>
## 10 <split [179/19]> Fold10 <tibble [200 × 5]> <tibble [0 × 1]>
Display the rmse and rsq for each penalty value.
tune_res %>%
collect_metrics()
## # A tibble: 200 x 7
## penalty .metric .estimator mean n std_err .config
## <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 1 rmse standard 360. 10 35.8 Preprocessor1_Model001
## 2 1 rsq standard 0.415 10 0.0533 Preprocessor1_Model001
## 3 1.12 rmse standard 360. 10 35.8 Preprocessor1_Model002
## 4 1.12 rsq standard 0.415 10 0.0533 Preprocessor1_Model002
## 5 1.26 rmse standard 360. 10 35.8 Preprocessor1_Model003
## 6 1.26 rsq standard 0.415 10 0.0533 Preprocessor1_Model003
## 7 1.42 rmse standard 360. 10 35.8 Preprocessor1_Model004
## 8 1.42 rsq standard 0.415 10 0.0533 Preprocessor1_Model004
## 9 1.59 rmse standard 360. 10 35.8 Preprocessor1_Model005
## 10 1.59 rsq standard 0.415 10 0.0533 Preprocessor1_Model005
## # … with 190 more rows
Display the five penalty values with the best (lowest) rmse.
tune_res %>%
show_best(metric = "rmse")
## # A tibble: 5 x 7
## penalty .metric .estimator mean n std_err .config
## <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 335. rmse standard 358. 10 37.9 Preprocessor1_Model051
## 2 298. rmse standard 358. 10 37.9 Preprocessor1_Model050
## 3 266. rmse standard 358. 10 37.9 Preprocessor1_Model049
## 4 376. rmse standard 358. 10 38.0 Preprocessor1_Model052
## 5 236. rmse standard 358. 10 37.8 Preprocessor1_Model048
We can see that when the amount of regularization is in the low hundreds (roughly 200 to 400, as show_best() indicated), the rmse is low and the rsq is high on average. The best hyperparameter for this model therefore lies in that region.
tune_res %>%
autoplot()
The penalty that gives the best cross-validated rmse is 335.1603.
best_rmse <- select_best(tune_res, metric = "rmse")
best_rmse
## # A tibble: 1 x 2
## penalty .config
## <dbl> <chr>
## 1 335. Preprocessor1_Model051
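select_best() simply picks the numerically best penalty. A common alternative, shown here only as a sketch, is the one-standard-error rule, which chooses the most regularized model whose performance is within one standard error of the best.
# Pick the largest (most regularized) penalty within one standard error of the best rmse.
select_by_one_std_err(tune_res, desc(penalty), metric = "rmse")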
ridge_final <- finalize_workflow(ridge_wf, best_rmse)
ridge_final_fit <- fit(ridge_final, data = Hitters_train)
ridge_fit0 %>% summary()
## Length Class Mode
## pre 2 stage_pre list
## fit 2 stage_fit list
## post 1 stage_post list
## trained 1 -none- logical
ridge_final_fit %>% tidy()
## # A tibble: 20 x 3
## term estimate penalty
## <chr> <dbl> <dbl>
## 1 (Intercept) 569. 335.
## 2 AtBat 7.19 335.
## 3 Hits 34.5 335.
## 4 HmRun -3.63 335.
## 5 Runs 25.1 335.
## 6 RBI 19.2 335.
## 7 Walks 47.0 335.
## 8 Years 3.73 335.
## 9 CAtBat 25.9 335.
## 10 CHits 39.4 335.
## 11 CHmRun 35.2 335.
## 12 CRuns 40.3 335.
## 13 CRBI 42.7 335.
## 14 CWalks 11.6 335.
## 15 PutOuts 43.0 335.
## 16 Assists 7.03 335.
## 17 Errors -4.81 335.
## 18 League_N 8.49 335.
## 19 Division_W -34.7 335.
## 20 NewLeague_N 1.02 335.
When we tune the hyperparameter, the test rmse below does not show an obvious reduction compared to the rmse of 228 we obtained with penalty = 0. For this split, tuning the penalty does not improve the ridge regression model.
augment(ridge_final_fit, new_data = Hitters_test) %>%
rmse(truth = Salary, estimate = .pred)
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 239.
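As a side note, tune also provides last_fit(), which fits the finalized workflow on the training data and evaluates it on the test data in one step; this sketch should reproduce the rmse above.
# Fit on the training set and evaluate on the test set in a single call.
last_fit(ridge_final, split = Hitters_split) %>%
  collect_metrics()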
The points are still scattered away from the diagonal, so the tuned model does not provide much improvement in the predictions.
augment(ridge_final_fit, new_data = Hitters_test) %>%
ggplot(aes(Salary, .pred)) +
geom_abline(slope = 1, intercept = 0) +
geom_point() +
theme_bw()
Compare the squared residuals for \(\lambda = 0\) and the tuned \(\lambda\) (335.1603).
augment(ridge_final_fit, new_data = Hitters_test) %>%
mutate(RS = (Salary-.pred)^2) %>%
select(RS)
## # A tibble: 65 x 1
## RS
## <dbl>
## 1 26254.
## 2 116410.
## 3 13400.
## 4 36214.
## 5 2656.
## 6 131.
## 7 21090.
## 8 258741.
## 9 6745.
## 10 5451.
## # … with 55 more rows
augment(ridge_fit0, new_data = Hitters_test) %>%
mutate(RS = (Salary-.pred)^2) %>%
select(RS)
## # A tibble: 65 x 1
## RS
## <dbl>
## 1 50222.
## 2 170727.
## 3 2234.
## 4 21880.
## 5 595.
## 6 4028.
## 7 16033.
## 8 273343.
## 9 10518.
## 10 53933.
## # … with 55 more rows
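Rather than scanning the squared residuals row by row, a quick summary makes the comparison easier. This sketch just averages the squared residuals of both models on the test set.
# Average the squared residuals of both models on the test set.
bind_rows(
  augment(ridge_fit0, new_data = Hitters_test) %>% mutate(model = "penalty = 0"),
  augment(ridge_final_fit, new_data = Hitters_test) %>% mutate(model = "penalty = 335")
) %>%
  group_by(model) %>%
  summarise(mean_RS = mean((Salary - .pred)^2))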