library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
## ✔ broom        1.0.6      ✔ recipes      1.0.10
## ✔ dials        1.2.1      ✔ rsample      1.2.1 
## ✔ dplyr        1.1.4      ✔ tibble       3.2.1 
## ✔ ggplot2      3.5.1      ✔ tidyr        1.3.1 
## ✔ infer        1.0.7      ✔ tune         1.2.1 
## ✔ modeldata    1.3.0      ✔ workflows    1.1.4 
## ✔ parsnip      1.2.1      ✔ workflowsets 1.1.0 
## ✔ purrr        1.0.2      ✔ yardstick    1.3.1
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ purrr::discard() masks scales::discard()
## ✖ dplyr::filter()  masks stats::filter()
## ✖ dplyr::lag()     masks stats::lag()
## ✖ recipes::step()  masks stats::step()
## • Learn how to get started at https://www.tidymodels.org/start/
library(ISLR)

# Convert the ISLR Hitters data to a tibble and keep only the players whose
# Salary (the response) is recorded.
Hitters <- Hitters %>%
  as_tibble() %>%
  filter(!is.na(Salary))
# Ridge regression (mixture = 0) fitted with glmnet on the full data set.
# The later tidy()/predict() calls ask this single fit for coefficients and
# predictions at several other penalty values.
ridge_spec <- linear_reg(penalty = 0, mixture = 0) %>%
  set_engine("glmnet") %>%
  set_mode("regression")
ridge_fit <- ridge_spec %>%
  fit(Salary ~ ., data = Hitters)
# Coefficients with no penalty argument; the printed penalty column shows 0,
# the value set in the model spec.
ridge_fit %>% tidy()
## 
## Attaching package: 'Matrix'
## The following objects are masked from 'package:tidyr':
## 
##     expand, pack, unpack
## Loaded glmnet 4.1-8
## # A tibble: 20 × 3
##    term          estimate penalty
##    <chr>            <dbl>   <dbl>
##  1 (Intercept)   81.1           0
##  2 AtBat         -0.682         0
##  3 Hits           2.77          0
##  4 HmRun         -1.37          0
##  5 Runs           1.01          0
##  6 RBI            0.713         0
##  7 Walks          3.38          0
##  8 Years         -9.07          0
##  9 CAtBat        -0.00120       0
## 10 CHits          0.136         0
## 11 CHmRun         0.698         0
## 12 CRuns          0.296         0
## 13 CRBI           0.257         0
## 14 CWalks        -0.279         0
## 15 LeagueN       53.2           0
## 16 DivisionW   -123.            0
## 17 PutOuts        0.264         0
## 18 Assists        0.170         0
## 19 Errors        -3.69          0
## 20 NewLeagueN   -18.1           0
# Coefficients under a very large penalty: nearly all are shrunk toward zero.
ridge_fit %>% tidy(penalty = 11498)
## # A tibble: 20 × 3
##    term         estimate penalty
##    <chr>           <dbl>   <dbl>
##  1 (Intercept) 407.        11498
##  2 AtBat         0.0370    11498
##  3 Hits          0.138     11498
##  4 HmRun         0.525     11498
##  5 Runs          0.231     11498
##  6 RBI           0.240     11498
##  7 Walks         0.290     11498
##  8 Years         1.11      11498
##  9 CAtBat        0.00314   11498
## 10 CHits         0.0117    11498
## 11 CHmRun        0.0876    11498
## 12 CRuns         0.0234    11498
## 13 CRBI          0.0242    11498
## 14 CWalks        0.0250    11498
## 15 LeagueN       0.0866    11498
## 16 DivisionW    -6.23      11498
## 17 PutOuts       0.0165    11498
## 18 Assists       0.00262   11498
## 19 Errors       -0.0206    11498
## 20 NewLeagueN    0.303     11498
# Coefficients under an intermediate penalty.
ridge_fit %>% tidy(penalty = 705)
## # A tibble: 20 × 3
##    term        estimate penalty
##    <chr>          <dbl>   <dbl>
##  1 (Intercept)  54.4        705
##  2 AtBat         0.112      705
##  3 Hits          0.656      705
##  4 HmRun         1.18       705
##  5 Runs          0.937      705
##  6 RBI           0.847      705
##  7 Walks         1.32       705
##  8 Years         2.58       705
##  9 CAtBat        0.0108     705
## 10 CHits         0.0468     705
## 11 CHmRun        0.338      705
## 12 CRuns         0.0937     705
## 13 CRBI          0.0979     705
## 14 CWalks        0.0718     705
## 15 LeagueN      13.7        705
## 16 DivisionW   -54.7        705
## 17 PutOuts       0.119      705
## 18 Assists       0.0161     705
## 19 Errors       -0.704      705
## 20 NewLeagueN    8.61       705
# Coefficients under a small penalty, close to the unpenalised estimates.
ridge_fit %>% tidy(penalty = 50)
## # A tibble: 20 × 3
##    term          estimate penalty
##    <chr>            <dbl>   <dbl>
##  1 (Intercept)   48.2          50
##  2 AtBat         -0.354        50
##  3 Hits           1.95         50
##  4 HmRun         -1.29         50
##  5 Runs           1.16         50
##  6 RBI            0.809        50
##  7 Walks          2.71         50
##  8 Years         -6.20         50
##  9 CAtBat         0.00609      50
## 10 CHits          0.107        50
## 11 CHmRun         0.629        50
## 12 CRuns          0.217        50
## 13 CRBI           0.215        50
## 14 CWalks        -0.149        50
## 15 LeagueN       45.9          50
## 16 DivisionW   -118.           50
## 17 PutOuts        0.250        50
## 18 Assists        0.121        50
## 19 Errors        -3.28         50
## 20 NewLeagueN    -9.42         50
# Plot the fitted glmnet model (coefficient estimates across the penalty path).
autoplot(ridge_fit)

# In-sample predictions at the penalty given in the model spec (0).
ridge_fit %>% predict(new_data = Hitters)
## # A tibble: 263 × 1
##     .pred
##     <dbl>
##  1  442. 
##  2  676. 
##  3 1059. 
##  4  521. 
##  5  543. 
##  6  218. 
##  7   74.7
##  8   96.1
##  9  809. 
## 10  865. 
## # ℹ 253 more rows
# In-sample predictions from the same fit, evaluated at penalty = 500.
ridge_fit %>% predict(new_data = Hitters, penalty = 500)
## # A tibble: 263 × 1
##    .pred
##    <dbl>
##  1  525.
##  2  620.
##  3  895.
##  4  425.
##  5  589.
##  6  179.
##  7  147.
##  8  187.
##  9  841.
## 10  840.
## # ℹ 253 more rows
# Fix the RNG state so the train/test split and the CV folds, and therefore
# every tuning result that depends on them, can be reproduced.
# NOTE(review): the printed outputs further down were produced without a seed
# and will change when this script is re-run.
set.seed(1234)
# Stratify the split on the (numeric) response so that train and test cover
# a similar Salary distribution.
Hitters_split <- initial_split(Hitters, strata = Salary)

Hitters_train <- training(Hitters_split)
Hitters_test <- testing(Hitters_split)

# 10-fold cross-validation on the training set only; the test set is kept
# aside for the final evaluation of each tuned model.
Hitters_fold <- vfold_cv(Hitters_train, v = 10)
# Preprocessing: allow unseen factor levels, dummy-encode factors, drop
# zero-variance columns, then centre and scale every predictor.
ridge_recipe <- recipe(Salary ~ ., data = Hitters_train) %>%
  step_novel(all_nominal_predictors()) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors())

# Ridge (mixture = 0) with the penalty marked for tuning.
ridge_spec <- linear_reg(penalty = tune(), mixture = 0) %>%
  set_engine("glmnet") %>%
  set_mode("regression")

# Bundle the model and the preprocessing into a single tunable workflow.
ridge_workflow <- workflow() %>%
  add_model(ridge_spec) %>%
  add_recipe(ridge_recipe)
# 50 candidate penalties spaced evenly on the log10 scale from 1e-5 to 1e5.
penalty_grid <- penalty(range = c(-5, 5)) %>%
  grid_regular(levels = 50)
penalty_grid
## # A tibble: 50 × 1
##      penalty
##        <dbl>
##  1 0.00001  
##  2 0.0000160
##  3 0.0000256
##  4 0.0000409
##  5 0.0000655
##  6 0.000105 
##  7 0.000168 
##  8 0.000268 
##  9 0.000429 
## 10 0.000687 
## # ℹ 40 more rows
# Resample every candidate penalty over the 10 CV folds.
tune_res <- ridge_workflow %>%
  tune_grid(
    resamples = Hitters_fold,
    grid = penalty_grid
  )

tune_res
## # Tuning results
## # 10-fold cross-validation 
## # A tibble: 10 × 4
##    splits           id     .metrics           .notes          
##    <list>           <chr>  <list>             <list>          
##  1 <split [176/20]> Fold01 <tibble [100 × 5]> <tibble [0 × 3]>
##  2 <split [176/20]> Fold02 <tibble [100 × 5]> <tibble [0 × 3]>
##  3 <split [176/20]> Fold03 <tibble [100 × 5]> <tibble [0 × 3]>
##  4 <split [176/20]> Fold04 <tibble [100 × 5]> <tibble [0 × 3]>
##  5 <split [176/20]> Fold05 <tibble [100 × 5]> <tibble [0 × 3]>
##  6 <split [176/20]> Fold06 <tibble [100 × 5]> <tibble [0 × 3]>
##  7 <split [177/19]> Fold07 <tibble [100 × 5]> <tibble [0 × 3]>
##  8 <split [177/19]> Fold08 <tibble [100 × 5]> <tibble [0 × 3]>
##  9 <split [177/19]> Fold09 <tibble [100 × 5]> <tibble [0 × 3]>
## 10 <split [177/19]> Fold10 <tibble [100 × 5]> <tibble [0 × 3]>
# Cross-validated metrics plotted and tabulated against the penalty.
tune_res %>% autoplot()

tune_res %>% collect_metrics()
## # A tibble: 100 × 7
##      penalty .metric .estimator    mean     n std_err .config              
##        <dbl> <chr>   <chr>        <dbl> <int>   <dbl> <chr>                
##  1 0.00001   rmse    standard   317.       10 40.7    Preprocessor1_Model01
##  2 0.00001   rsq     standard     0.498    10  0.0735 Preprocessor1_Model01
##  3 0.0000160 rmse    standard   317.       10 40.7    Preprocessor1_Model02
##  4 0.0000160 rsq     standard     0.498    10  0.0735 Preprocessor1_Model02
##  5 0.0000256 rmse    standard   317.       10 40.7    Preprocessor1_Model03
##  6 0.0000256 rsq     standard     0.498    10  0.0735 Preprocessor1_Model03
##  7 0.0000409 rmse    standard   317.       10 40.7    Preprocessor1_Model04
##  8 0.0000409 rsq     standard     0.498    10  0.0735 Preprocessor1_Model04
##  9 0.0000655 rmse    standard   317.       10 40.7    Preprocessor1_Model05
## 10 0.0000655 rsq     standard     0.498    10  0.0735 Preprocessor1_Model05
## # ℹ 90 more rows
# Choose the ridge penalty with the lowest cross-validated RMSE.
# yardstick's rsq() is the squared correlation between truth and prediction.
# Shrinking every prediction toward the mean leaves it unchanged, so
# maximising rsq favours heavily over-regularised fits; RMSE penalises that.
# NOTE(review): the printed result below was generated with metric = "rsq"
# and will differ once the script is re-run.
best_penalty <- select_best(tune_res, metric = "rmse")
best_penalty
## # A tibble: 1 × 2
##   penalty .config              
##     <dbl> <chr>                
## 1  15264. Preprocessor1_Model46
# Substitute the selected penalty into the workflow.
ridge_final <- ridge_workflow %>%
  finalize_workflow(best_penalty)

# Refit on the whole training set, then score the held-out test set.
ridge_final_fit <- ridge_final %>%
  fit(data = Hitters_train)
ridge_final_fit %>%
  augment(new_data = Hitters_test) %>%
  rsq(truth = Salary, estimate = .pred)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rsq     standard       0.420
# Preprocessing with the same steps as ridge_recipe: unseen-level handling,
# dummy encoding, zero-variance filter, centring and scaling.
lasso_recipe <- recipe(Salary ~ ., data = Hitters_train) %>%
  step_novel(all_nominal_predictors()) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors())

# Lasso (mixture = 1) with the penalty marked for tuning.
lasso_spec <- linear_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet") %>%
  set_mode("regression")

lasso_workflow <- workflow() %>%
  add_model(lasso_spec) %>%
  add_recipe(lasso_recipe)
# 50 lasso penalties spaced evenly on the log10 scale from 1e-2 to 1e2.
penalty_grid <- penalty(range = c(-2, 2)) %>%
  grid_regular(levels = 50)

# Resample each candidate over the same CV folds used for ridge.
tune_res <- lasso_workflow %>%
  tune_grid(
    resamples = Hitters_fold,
    grid = penalty_grid
  )

tune_res %>% autoplot()

# Choose the lasso penalty with the lowest cross-validated RMSE.
# yardstick's rsq() is the squared correlation between truth and prediction.
# Shrinking predictions toward the mean does not lower it, so it cannot
# detect over-regularisation; RMSE does.
best_penalty <- select_best(tune_res, metric = "rmse")
# Substitute the selected penalty into the lasso workflow.
lasso_final <- lasso_workflow %>%
  finalize_workflow(best_penalty)

# Refit on the whole training set, then score the held-out test set.
lasso_final_fit <- lasso_final %>%
  fit(data = Hitters_train)
lasso_final_fit %>%
  augment(new_data = Hitters_test) %>%
  rsq(truth = Salary, estimate = .pred)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rsq     standard       0.467