We will use the Hitters data set from the ISLR package to predict baseball players’ Salary from the other characteristics recorded in the data set. Because Salary is the outcome, we first need to remove the rows where it is missing; otherwise, we won’t be able to fit the models.
library(tidymodels)
library(ISLR)
library(ggcorrplot)
theme_set(theme_bw())
hitters <- tibble(Hitters)
glimpse(hitters)
Rows: 322
Columns: 20
$ AtBat <int> 293, 315, 479, 496, 321, 594, 185, 298, 323, 401, 574, 202, ~
$ Hits <int> 66, 81, 130, 141, 87, 169, 37, 73, 81, 92, 159, 53, 113, 60,~
$ HmRun <int> 1, 7, 18, 20, 10, 4, 1, 0, 6, 17, 21, 4, 13, 0, 7, 3, 20, 2,~
$ Runs <int> 30, 24, 66, 65, 39, 74, 23, 24, 26, 49, 107, 31, 48, 30, 29,~
$ RBI <int> 29, 38, 72, 78, 42, 51, 8, 24, 32, 66, 75, 26, 61, 11, 27, 1~
$ Walks <int> 14, 39, 76, 37, 30, 35, 21, 7, 8, 65, 59, 27, 47, 22, 30, 11~
$ Years <int> 1, 14, 3, 11, 2, 11, 2, 3, 2, 13, 10, 9, 4, 6, 13, 3, 15, 5,~
$ CAtBat <int> 293, 3449, 1624, 5628, 396, 4408, 214, 509, 341, 5206, 4631,~
$ CHits <int> 66, 835, 457, 1575, 101, 1133, 42, 108, 86, 1332, 1300, 467,~
$ CHmRun <int> 1, 69, 63, 225, 12, 19, 1, 0, 6, 253, 90, 15, 41, 4, 36, 3, ~
$ CRuns <int> 30, 321, 224, 828, 48, 501, 30, 41, 32, 784, 702, 192, 205, ~
$ CRBI <int> 29, 414, 266, 838, 46, 336, 9, 37, 34, 890, 504, 186, 204, 1~
$ CWalks <int> 14, 375, 263, 354, 33, 194, 24, 12, 8, 866, 488, 161, 203, 2~
$ League <fct> A, N, A, N, N, A, N, A, N, A, A, N, N, A, N, A, N, A, A, N, ~
$ Division <fct> E, W, W, E, E, W, E, W, W, E, E, W, E, E, E, W, W, W, W, W, ~
$ PutOuts <int> 446, 632, 880, 200, 805, 282, 76, 121, 143, 0, 238, 304, 211~
$ Assists <int> 33, 43, 82, 11, 40, 421, 127, 283, 290, 0, 445, 45, 11, 151,~
$ Errors <int> 20, 10, 14, 3, 4, 25, 7, 9, 19, 0, 22, 11, 7, 6, 8, 0, 10, 1~
$ Salary <dbl> NA, 475.000, 480.000, 500.000, 91.500, 750.000, 70.000, 100.~
$ NewLeague <fct> A, N, A, N, N, A, A, A, N, A, A, N, N, A, N, A, N, A, A, N, ~
sum(is.na(hitters$Salary))
[1] 59
# Remove rows with a missing Salary
hitters <- hitters %>%
filter(!is.na(Salary))
sum(is.na(hitters$Salary))
[1] 0
# Correlation matrix of the numeric columns; predicates must be wrapped in where()
hitters %>%
select(where(is.numeric)) %>%
cor() %>%
ggcorrplot(hc.order = TRUE, type = "lower", lab = TRUE)
ggplot(hitters, aes(Salary)) +
geom_histogram(fill = "lightgrey", color = 'black')
ggplot(hitters, aes(Salary, color = NewLeague)) +
geom_density()
We will use the glmnet package to perform ridge regression. parsnip does not have a dedicated function for a ridge regression model specification; instead, use linear_reg() and set mixture = 0. The mixture argument controls the blend of the two types of regularization: mixture = 0 specifies pure ridge regularization, mixture = 1 specifies pure lasso regularization, and a value between 0 and 1 uses both. With the glmnet engine we also need to set a penalty in order to fit the model. We set it to 0 for now; this is not the best value, and we will see how to select a better one shortly.
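To make the roles of mixture and penalty concrete, here is a minimal sketch of the elastic net penalty term in the parameterization documented for glmnet, where penalty plays the role of lambda and mixture the role of alpha. The helper name elastic_net_penalty and the coefficient vector are hypothetical and exist only for illustration; glmnet computes this internally.

```r
# Elastic net penalty term as parameterized by glmnet:
#   lambda * [ (1 - alpha) / 2 * sum(beta^2) + alpha * sum(|beta|) ]
# `penalty` corresponds to lambda and `mixture` to alpha.
# elastic_net_penalty() is a hypothetical helper, not part of parsnip or glmnet.
elastic_net_penalty <- function(beta, penalty, mixture) {
  stopifnot(penalty >= 0, mixture >= 0, mixture <= 1)
  ridge_part <- (1 - mixture) / 2 * sum(beta^2)
  lasso_part <- mixture * sum(abs(beta))
  penalty * (ridge_part + lasso_part)
}

beta <- c(1, -2)  # made-up coefficients (intercept excluded)

pen_ridge <- elastic_net_penalty(beta, penalty = 1, mixture = 0)    # ridge only
pen_lasso <- elastic_net_penalty(beta, penalty = 1, mixture = 1)    # lasso only
pen_mixed <- elastic_net_penalty(beta, penalty = 1, mixture = 0.5)  # both
```

With mixture = 0 only the squared coefficients contribute, with mixture = 1 only their absolute values do, and intermediate values weight the two.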
ridge_model <- linear_reg(mixture = 0, penalty = 0) %>%
set_engine("glmnet") %>%
set_mode("regression")
ridge_fit <-
fit(
ridge_model,
Salary ~ .,
data = hitters
)
tidy(ridge_fit)
# A tibble: 20 x 3
term estimate penalty
<chr> <dbl> <dbl>
1 (Intercept) 81.1 0
2 AtBat -0.682 0
3 Hits 2.77 0
4 HmRun -1.37 0
5 Runs 1.01 0
6 RBI 0.713 0
7 Walks 3.38 0
8 Years -9.07 0
9 CAtBat -0.00120 0
10 CHits 0.136 0
11 CHmRun 0.698 0
12 CRuns 0.296 0
13 CRBI 0.257 0
14 CWalks -0.279 0
15 LeagueN 53.2 0
16 DivisionW -123. 0
17 PutOuts 0.264 0
18 Assists 0.170 0
19 Errors -3.69 0
20 NewLeagueN -18.1 0
tidy(ridge_fit, penalty = 50) %>% slice(1:10)
# A tibble: 10 x 3
term estimate penalty
<chr> <dbl> <dbl>
1 (Intercept) 48.2 50
2 AtBat -0.354 50
3 Hits 1.95 50
4 HmRun -1.29 50
5 Runs 1.16 50
6 RBI 0.809 50
7 Walks 2.71 50
8 Years -6.20 50
9 CAtBat 0.00609 50
10 CHits 0.107 50
tidy(ridge_fit, penalty = 705) %>% slice(1:10)
# A tibble: 10 x 3
term estimate penalty
<chr> <dbl> <dbl>
1 (Intercept) 54.4 705
2 AtBat 0.112 705
3 Hits 0.656 705
4 HmRun 1.18 705
5 Runs 0.937 705
6 RBI 0.847 705
7 Walks 1.32 705
8 Years 2.58 705
9 CAtBat 0.0108 705
10 CHits 0.0468 705
ridge_fit %>%
extract_fit_engine() %>%
plot(xvar = "lambda")
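The tidy() calls above show the coefficient estimates changing with the penalty. As a rough illustration of why, the sketch below computes ridge coefficients in closed form on the built-in mtcars data. That data set is an assumption made purely to keep the example self-contained; it is not the Hitters analysis. The lambda values here are also on a different scale from glmnet's penalty, because glmnet scales its loss differently, so only the qualitative shrinkage carries over.

```r
# Closed-form ridge on standardized predictors and a centered outcome:
#   beta(lambda) = (X'X + lambda * I)^-1 X'y
# ridge_coef() and the lambda values are hypothetical, for illustration only.
X <- scale(as.matrix(mtcars[, c("wt", "hp", "disp")]))
y <- mtcars$mpg - mean(mtcars$mpg)

ridge_coef <- function(lambda) {
  p <- ncol(X)
  drop(solve(crossprod(X) + lambda * diag(p), crossprod(X, y)))
}

lambdas <- c(0, 10, 100, 1000)
coef_path <- sapply(lambdas, ridge_coef)
colnames(coef_path) <- paste0("lambda_", lambdas)

# Overall size of the coefficient vector at each lambda
l2_norm <- sqrt(colSums(coef_path^2))

coef_path
l2_norm
```

The overall size of the coefficient vector decreases as lambda grows, although an individual coefficient may change sign or temporarily grow along the way, as AtBat and Years do in the output above.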
set.seed(2021)
hitters_split <- initial_split(hitters, prop = .75)
hitters_train <- training(hitters_split)
hitters_test <- testing(hitters_split)
set.seed(2021)
hitters_cv <- vfold_cv(hitters_train, v = 10)
hitters_rec <-
recipe(Salary ~ ., data = hitters_train) %>%
step_novel(all_nominal_predictors()) %>%
step_dummy(all_nominal_predictors()) %>%
step_zv(all_numeric_predictors()) %>%
step_normalize(all_numeric_predictors())
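Normalizing matters here because the ridge and lasso penalties depend on the scale of the predictors. step_normalize() estimates the means and standard deviations from the data the recipe is trained on and then applies those same values to new data. A hand-rolled sketch of that idea with made-up numbers (the object names are hypothetical):

```r
# Toy training and test sets (made-up values)
train <- data.frame(x = c(1, 2, 3, 4, 5))
test  <- data.frame(x = c(6, 7))

# Statistics are estimated on the training data only
mu    <- mean(train$x)
sigma <- sd(train$x)

# The same training statistics are applied to both sets
train_norm <- (train$x - mu) / sigma
test_norm  <- (test$x - mu) / sigma
```

The test values are not re-centered with their own mean, so they need not have mean 0 after the transformation.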
tuned_ridge_model <-
linear_reg(mixture = 0, penalty = tune()) %>%
set_engine("glmnet") %>%
set_mode("regression")
ridge_wf <-
workflow() %>%
add_recipe(hitters_rec) %>%
add_model(tuned_ridge_model)
ridge_wf %>% parameters()
Collection of 1 parameters for tuning
identifier type object
penalty penalty nparam[+]
ridge_wf %>% pull_dials_object("penalty")
Amount of Regularization (quantitative)
Transformer: log-10
Range (transformed scale): [-10, 0]
penalty_grid <- grid_regular(penalty(range = c(-5, 5)), levels = 50)
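Because the penalty parameter is defined on a log-10 scale, range = c(-5, 5) means penalties from 10^-5 to 10^5, not from -5 to 5. Here is a base R sketch of an equivalent set of values. It assumes the 50 levels are spaced evenly on the transformed scale, which is consistent with the penalty values reported by collect_metrics():

```r
# 50 values evenly spaced on the log10 scale between 10^-5 and 10^5.
# penalty_values is a hypothetical name used only for this illustration.
penalty_values <- 10^seq(-5, 5, length.out = 50)

range(penalty_values)
head(signif(penalty_values, 3))
```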
fit_tuned_ridge <-
tune_grid(
ridge_wf,
resamples = hitters_cv,
grid = penalty_grid,
control = control_grid(save_pred = TRUE)
)
fit_tuned_ridge %>% collect_metrics()
# A tibble: 100 x 7
penalty .metric .estimator mean n std_err .config
<dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
1 0.00001 rmse standard 340. 10 31.3 Preprocessor1_Model01
2 0.00001 rsq standard 0.460 10 0.0621 Preprocessor1_Model01
3 0.0000160 rmse standard 340. 10 31.3 Preprocessor1_Model02
4 0.0000160 rsq standard 0.460 10 0.0621 Preprocessor1_Model02
5 0.0000256 rmse standard 340. 10 31.3 Preprocessor1_Model03
6 0.0000256 rsq standard 0.460 10 0.0621 Preprocessor1_Model03
7 0.0000409 rmse standard 340. 10 31.3 Preprocessor1_Model04
8 0.0000409 rsq standard 0.460 10 0.0621 Preprocessor1_Model04
9 0.0000655 rmse standard 340. 10 31.3 Preprocessor1_Model05
10 0.0000655 rsq standard 0.460 10 0.0621 Preprocessor1_Model05
# ... with 90 more rows
autoplot(fit_tuned_ridge)
best_ridge <- fit_tuned_ridge %>% select_best(metric = "rmse")
In the autoplot we see that the amount of regularization affects the two performance metrics differently. Note that there are ranges of the penalty over which the amount of regularization has no meaningful influence on the coefficient estimates.
final_ridge_wf <-
ridge_wf %>%
finalize_workflow(best_ridge)
ridge_lastfit <- last_fit(
final_ridge_wf,
split = hitters_split
)
ridge_lastfit %>% collect_predictions() %>%
ggplot(aes(Salary, .pred)) +
geom_point() +
geom_smooth()
ridge_lastfit %>% collect_metrics()
# A tibble: 2 x 4
.metric .estimator .estimate .config
<chr> <chr> <dbl> <chr>
1 rmse standard 333. Preprocessor1_Model1
2 rsq standard 0.453 Preprocessor1_Model1
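For reference, the two metrics reported here can be computed by hand. The sketch below follows the definitions yardstick documents for rmse (root mean squared error) and rsq (squared correlation between truth and prediction). The helper names and the toy vectors are made up for illustration.

```r
# Hypothetical helpers mirroring the documented definitions of the two metrics
rmse_by_hand <- function(truth, estimate) sqrt(mean((truth - estimate)^2))
rsq_by_hand  <- function(truth, estimate) cor(truth, estimate)^2

# Made-up salaries and predictions
truth    <- c(100, 250, 400, 550)
estimate <- c(120, 240, 380, 600)

rmse_by_hand(truth, estimate)
rsq_by_hand(truth, estimate)
```

Because rsq is a squared correlation, it measures how linearly related the predictions are to the truth rather than how close they are; rmse is on the scale of Salary itself.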
For the lasso, we reuse the same recipe, resamples, and penalty grid and change only the model specification: mixture = 1 specifies pure lasso regularization, just as mixture = 0 specified pure ridge regularization.
tuned_lasso_model <-
linear_reg(mixture = 1, penalty = tune()) %>%
set_engine("glmnet") %>%
set_mode("regression")
lasso_wf <-
workflow() %>%
add_recipe(hitters_rec) %>%
add_model(tuned_lasso_model)
fit_tuned_lasso <-
tune_grid(
lasso_wf,
resamples = hitters_cv,
grid = penalty_grid,
metrics = NULL,
control = control_grid(save_pred = TRUE)
)
fit_tuned_lasso %>% collect_metrics()
# A tibble: 100 x 7
penalty .metric .estimator mean n std_err .config
<dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
1 0.00001 rmse standard 345. 10 33.5 Preprocessor1_Model01
2 0.00001 rsq standard 0.448 10 0.0640 Preprocessor1_Model01
3 0.0000160 rmse standard 345. 10 33.5 Preprocessor1_Model02
4 0.0000160 rsq standard 0.448 10 0.0640 Preprocessor1_Model02
5 0.0000256 rmse standard 345. 10 33.5 Preprocessor1_Model03
6 0.0000256 rsq standard 0.448 10 0.0640 Preprocessor1_Model03
7 0.0000409 rmse standard 345. 10 33.5 Preprocessor1_Model04
8 0.0000409 rsq standard 0.448 10 0.0640 Preprocessor1_Model04
9 0.0000655 rmse standard 345. 10 33.5 Preprocessor1_Model05
10 0.0000655 rsq standard 0.448 10 0.0640 Preprocessor1_Model05
# ... with 90 more rows
autoplot(fit_tuned_lasso)
best_lasso <- fit_tuned_lasso %>% select_best(metric = "rmse")
tuned_elastic_model <-
linear_reg(mixture = tune(), penalty = tune()) %>%
set_engine("glmnet") %>%
set_mode("regression")
elastic_wf <-
workflow() %>%
add_recipe(hitters_rec) %>%
add_model(tuned_elastic_model)
elastic_wf %>% parameters()
Collection of 2 parameters for tuning
identifier type object
penalty penalty nparam[+]
mixture mixture nparam[+]
elastic_wf %>% pull_dials_object("penalty")
Amount of Regularization (quantitative)
Transformer: log-10
Range (transformed scale): [-10, 0]
elastic_wf %>% pull_dials_object("mixture")
Proportion of lasso Penalty (quantitative)
Range: [0.05, 1]
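# Note: unlike penalty_grid, these penalty values are on the raw scale, not log10.
# Values below 0 are not meaningful as penalties.
elastic_grid <- crossing(mixture = seq(0, 1, by = .2),
penalty = -5:400)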
fit_tuned_elastic <-
tune_grid(
elastic_wf,
resamples = hitters_cv,
grid = elastic_grid,
control = control_grid(save_pred = TRUE)
)
fit_tuned_elastic %>% collect_metrics() %>% View()
autoplot(fit_tuned_elastic)
fit_tuned_elastic %>% show_best(metric = "rmse")
# A tibble: 5 x 8
penalty mixture .metric .estimator mean n std_err .config
<int> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
1 102 0.2 rmse standard 336. 10 29.6 Preprocessor1_Model0514
2 103 0.2 rmse standard 336. 10 29.6 Preprocessor1_Model0515
3 101 0.2 rmse standard 336. 10 29.6 Preprocessor1_Model0513
4 104 0.2 rmse standard 336. 10 29.6 Preprocessor1_Model0516
5 100 0.2 rmse standard 336. 10 29.6 Preprocessor1_Model0512
fit_tuned_elastic %>% show_best(metric = "rsq")
# A tibble: 5 x 8
penalty mixture .metric .estimator mean n std_err .config
<int> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
1 5 0.2 rsq standard 0.468 10 0.0642 Preprocessor1_Model0417
2 6 0.2 rsq standard 0.467 10 0.0641 Preprocessor1_Model0418
3 107 0.2 rsq standard 0.467 10 0.0561 Preprocessor1_Model0519
4 106 0.2 rsq standard 0.467 10 0.0561 Preprocessor1_Model0518
5 4 0.4 rsq standard 0.467 10 0.0641 Preprocessor1_Model0822
best_elastic <- fit_tuned_elastic %>% select_best(metric = "rmse")
We will treat principal component regression as a linear model with a PCA transformation in the preprocessing. Within the tidymodels framework this is still essentially a single model: the PCA step lives in the recipe, and the regression is an ordinary lm fit.
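As a rough picture of what "a linear model with PCA in the preprocessing" means, the sketch below performs principal component regression by hand in base R on the built-in mtcars data. That data set is chosen only so the example runs on its own; it is not the Hitters analysis. The sketch keeps enough components to reach a variance threshold, analogous in spirit to the threshold argument of step_pca(). The variable names and the 0.8 cutoff are hypothetical.

```r
predictors <- mtcars[, c("wt", "hp", "disp", "drat", "qsec")]

# Center and scale, then rotate onto principal components
pca <- prcomp(predictors, center = TRUE, scale. = TRUE)

# Cumulative proportion of variance explained by the first k components
cum_var <- cumsum(pca$sdev^2) / sum(pca$sdev^2)

# Smallest number of components covering at least 80% of the variance
n_comp <- which(cum_var >= 0.8)[1]

# Regress the outcome on the retained component scores
pcr_data <- data.frame(mpg = mtcars$mpg,
                       pca$x[, seq_len(n_comp), drop = FALSE])
pcr_fit <- lm(mpg ~ ., data = pcr_data)

coef(pcr_fit)
```

In a workflow, the recipe estimates the rotation on each analysis set and applies it to the matching assessment set, so the components are re-estimated within every resample rather than once on the full data as in this sketch.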
lm_model <-
linear_reg() %>%
set_mode("regression") %>%
set_engine("lm")
pca_recipe <-
recipe(formula = Salary ~ ., data = hitters_train) %>%
step_novel(all_nominal_predictors()) %>%
step_dummy(all_nominal_predictors()) %>%
step_zv(all_predictors()) %>%
step_normalize(all_predictors()) %>%
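# Note: in step_pca(), a supplied threshold overrides num_comp,
# so results can be identical across num_comp values at a given threshold
step_pca(all_predictors(), threshold = tune(), num_comp = tune())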
pca_wf <-
workflow() %>%
add_recipe(pca_recipe) %>%
add_model(lm_model)
pca_wf %>% parameters()
Collection of 2 parameters for tuning
identifier type object
num_comp num_comp nparam[+]
threshold threshold nparam[+]
pca_wf %>% pull_dials_object("num_comp")
# Components (quantitative)
Range: [1, 4]
pca_wf %>% pull_dials_object("threshold")
Threshold (quantitative)
Range: [0, 1]
pca_grid <- crossing(num_comp = 0:10,
threshold = seq(0,1, by=.1)
)
tune_pca <-
tune_grid(
pca_wf,
resamples = hitters_cv,
grid = pca_grid,
control = control_grid(save_pred = TRUE)
)
tune_pca %>% collect_metrics()
# A tibble: 242 x 8
num_comp threshold .metric .estimator mean n std_err .config
<int> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
1 0 0 rmse standard 346. 10 33.6 Preprocessor001_~
2 0 0 rsq standard 0.446 10 0.0639 Preprocessor001_~
3 0 0.1 rmse standard 346. 10 33.6 Preprocessor002_~
4 0 0.1 rsq standard 0.446 10 0.0639 Preprocessor002_~
5 0 0.2 rmse standard 346. 10 33.6 Preprocessor003_~
6 0 0.2 rsq standard 0.446 10 0.0639 Preprocessor003_~
7 0 0.3 rmse standard 346. 10 33.6 Preprocessor004_~
8 0 0.3 rsq standard 0.446 10 0.0639 Preprocessor004_~
9 0 0.4 rmse standard 346. 10 33.6 Preprocessor005_~
10 0 0.4 rsq standard 0.446 10 0.0639 Preprocessor005_~
# ... with 232 more rows
autoplot(tune_pca) + scale_x_continuous(breaks = 0:10)
tune_pca %>% show_best(metric = "rmse")
# A tibble: 5 x 8
num_comp threshold .metric .estimator mean n std_err .config
<int> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
1 1 0.8 rmse standard 337. 10 29.8 Preprocessor020_Mod~
2 2 0.8 rmse standard 337. 10 29.8 Preprocessor031_Mod~
3 3 0.8 rmse standard 337. 10 29.8 Preprocessor042_Mod~
4 4 0.8 rmse standard 337. 10 29.8 Preprocessor053_Mod~
5 5 0.8 rmse standard 337. 10 29.8 Preprocessor064_Mod~
tune_pca %>% show_best(metric = "rsq")
# A tibble: 5 x 8
num_comp threshold .metric .estimator mean n std_err .config
<int> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
1 1 0.9 rsq standard 0.460 10 0.0541 Preprocessor021_Mod~
2 2 0.9 rsq standard 0.460 10 0.0541 Preprocessor032_Mod~
3 3 0.9 rsq standard 0.460 10 0.0541 Preprocessor043_Mod~
4 4 0.9 rsq standard 0.460 10 0.0541 Preprocessor054_Mod~
5 5 0.9 rsq standard 0.460 10 0.0541 Preprocessor065_Mod~
(best_pca_rmse <- tune_pca %>% select_best(metric = "rmse"))
# A tibble: 1 x 3
num_comp threshold .config
<int> <dbl> <chr>
1 1 0.8 Preprocessor020_Model1
(best_pca_rsq <- tune_pca %>% select_best(metric = "rsq"))
# A tibble: 1 x 3
num_comp threshold .config
<int> <dbl> <chr>
1 1 0.9 Preprocessor021_Model1
final_pca_wf <-
pca_wf %>%
finalize_workflow(best_pca_rmse)
final_pca_fit <-
last_fit(
final_pca_wf,
split = hitters_split
)
final_pca_fit %>% collect_metrics()
# A tibble: 2 x 4
.metric .estimator .estimate .config
<chr> <chr> <dbl> <chr>
1 rmse standard 342. Preprocessor1_Model1
2 rsq standard 0.432 Preprocessor1_Model1
An Introduction to Statistical Learning
sessionInfo()
R version 4.1.0 (2021-05-18)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19043)
Matrix products: default
locale:
[1] LC_COLLATE=Spanish_Mexico.1252 LC_CTYPE=Spanish_Mexico.1252
[3] LC_MONETARY=Spanish_Mexico.1252 LC_NUMERIC=C
[5] LC_TIME=Spanish_Mexico.1252
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] vctrs_0.3.8 rlang_0.4.11 glmnet_4.1-2 Matrix_1.3-3
[5] ggcorrplot_0.1.3 ISLR_1.2 yardstick_0.0.8 workflowsets_0.0.2
[9] workflows_0.2.3 tune_0.1.6 tidyr_1.1.3 tibble_3.1.3
[13] rsample_0.1.0 recipes_0.1.16 purrr_0.3.4 parsnip_0.1.7
[17] modeldata_0.1.1 infer_0.5.4 ggplot2_3.3.5 dplyr_1.0.7
[21] dials_0.0.9 scales_1.1.1 broom_0.7.8 tidymodels_0.1.3
loaded via a namespace (and not attached):
[1] nlme_3.1-152 lubridate_1.7.10 DiceDesign_1.9 tools_4.1.0
[5] backports_1.2.1 bslib_0.2.5.1 utf8_1.2.2 R6_2.5.0
[9] rpart_4.1-15 DBI_1.1.1 mgcv_1.8-35 colorspace_2.0-2
[13] nnet_7.3-16 withr_2.4.2 tidyselect_1.1.1 compiler_4.1.0
[17] cli_3.0.0 labeling_0.4.2 sass_0.4.0 stringr_1.4.0
[21] digest_0.6.27 rmarkdown_2.9 pkgconfig_2.0.3 htmltools_0.5.1.1
[25] parallelly_1.26.1 lhs_1.1.1 highr_0.9 rstudioapi_0.13
[29] shape_1.4.6 jquerylib_0.1.4 generics_0.1.0 farver_2.1.0
[33] jsonlite_1.7.2 magrittr_2.0.1 Rcpp_1.0.7 munsell_0.5.0
[37] fansi_0.5.0 GPfit_1.0-8 lifecycle_1.0.0 furrr_0.2.3
[41] stringi_1.6.2 pROC_1.17.0.1 yaml_2.2.1 MASS_7.3-54
[45] plyr_1.8.6 grid_4.1.0 parallel_4.1.0 listenv_0.8.0
[49] crayon_1.4.1 lattice_0.20-44 splines_4.1.0 knitr_1.33
[53] pillar_1.6.2 reshape2_1.4.4 codetools_0.2-18 glue_1.4.2
[57] evaluate_0.14 foreach_1.5.1 gtable_0.3.0 future_1.21.0
[61] assertthat_0.2.1 xfun_0.24 gower_0.2.2 prodlim_2019.11.13
[65] class_7.3-19 survival_3.2-11 timeDate_3043.102 iterators_1.0.13
[69] hardhat_0.1.6 lava_1.6.9 globals_0.14.0 ellipsis_0.3.2
[73] ipred_0.9-11