library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
## ✔ broom 1.0.6 ✔ recipes 1.0.10
## ✔ dials 1.2.1 ✔ rsample 1.2.1
## ✔ dplyr 1.1.4 ✔ tibble 3.2.1
## ✔ ggplot2 3.5.1 ✔ tidyr 1.3.1
## ✔ infer 1.0.7 ✔ tune 1.2.1
## ✔ modeldata 1.3.0 ✔ workflows 1.1.4
## ✔ parsnip 1.2.1 ✔ workflowsets 1.1.0
## ✔ purrr 1.0.2 ✔ yardstick 1.3.1
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ purrr::discard() masks scales::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ✖ recipes::step() masks stats::step()
## • Learn how to get started at https://www.tidymodels.org/start/
library(ISLR)
# Convert the ISLR Hitters data to a tibble and drop players whose Salary
# (the outcome modelled below) is missing.
Hitters <- Hitters %>%
  as_tibble() %>%
  filter(!is.na(Salary))

# Ridge specification: glmnet engine with mixture = 0 and a nominal penalty
# of 0; other penalty values are queried from the fitted object later on.
ridge_spec <- linear_reg(penalty = 0, mixture = 0) %>%
  set_engine("glmnet") %>%
  set_mode("regression")

# Fit Salary on all remaining columns and list the coefficient estimates.
ridge_fit <- ridge_spec %>%
  fit(Salary ~ ., data = Hitters)
tidy(ridge_fit)
##
## Attaching package: 'Matrix'
## The following objects are masked from 'package:tidyr':
##
## expand, pack, unpack
## Loaded glmnet 4.1-8
## # A tibble: 20 × 3
## term estimate penalty
## <chr> <dbl> <dbl>
## 1 (Intercept) 81.1 0
## 2 AtBat -0.682 0
## 3 Hits 2.77 0
## 4 HmRun -1.37 0
## 5 Runs 1.01 0
## 6 RBI 0.713 0
## 7 Walks 3.38 0
## 8 Years -9.07 0
## 9 CAtBat -0.00120 0
## 10 CHits 0.136 0
## 11 CHmRun 0.698 0
## 12 CRuns 0.296 0
## 13 CRBI 0.257 0
## 14 CWalks -0.279 0
## 15 LeagueN 53.2 0
## 16 DivisionW -123. 0
## 17 PutOuts 0.264 0
## 18 Assists 0.170 0
## 19 Errors -3.69 0
## 20 NewLeagueN -18.1 0
# Coefficient estimates of the same fit, reported at penalty = 11498.
ridge_fit %>%
  tidy(penalty = 11498)
## # A tibble: 20 × 3
## term estimate penalty
## <chr> <dbl> <dbl>
## 1 (Intercept) 407. 11498
## 2 AtBat 0.0370 11498
## 3 Hits 0.138 11498
## 4 HmRun 0.525 11498
## 5 Runs 0.231 11498
## 6 RBI 0.240 11498
## 7 Walks 0.290 11498
## 8 Years 1.11 11498
## 9 CAtBat 0.00314 11498
## 10 CHits 0.0117 11498
## 11 CHmRun 0.0876 11498
## 12 CRuns 0.0234 11498
## 13 CRBI 0.0242 11498
## 14 CWalks 0.0250 11498
## 15 LeagueN 0.0866 11498
## 16 DivisionW -6.23 11498
## 17 PutOuts 0.0165 11498
## 18 Assists 0.00262 11498
## 19 Errors -0.0206 11498
## 20 NewLeagueN 0.303 11498
# Coefficient estimates of the same fit, reported at penalty = 705.
ridge_fit %>%
  tidy(penalty = 705)
## # A tibble: 20 × 3
## term estimate penalty
## <chr> <dbl> <dbl>
## 1 (Intercept) 54.4 705
## 2 AtBat 0.112 705
## 3 Hits 0.656 705
## 4 HmRun 1.18 705
## 5 Runs 0.937 705
## 6 RBI 0.847 705
## 7 Walks 1.32 705
## 8 Years 2.58 705
## 9 CAtBat 0.0108 705
## 10 CHits 0.0468 705
## 11 CHmRun 0.338 705
## 12 CRuns 0.0937 705
## 13 CRBI 0.0979 705
## 14 CWalks 0.0718 705
## 15 LeagueN 13.7 705
## 16 DivisionW -54.7 705
## 17 PutOuts 0.119 705
## 18 Assists 0.0161 705
## 19 Errors -0.704 705
## 20 NewLeagueN 8.61 705
# Coefficient estimates of the same fit, reported at penalty = 50.
ridge_fit %>%
  tidy(penalty = 50)
## # A tibble: 20 × 3
## term estimate penalty
## <chr> <dbl> <dbl>
## 1 (Intercept) 48.2 50
## 2 AtBat -0.354 50
## 3 Hits 1.95 50
## 4 HmRun -1.29 50
## 5 Runs 1.16 50
## 6 RBI 0.809 50
## 7 Walks 2.71 50
## 8 Years -6.20 50
## 9 CAtBat 0.00609 50
## 10 CHits 0.107 50
## 11 CHmRun 0.629 50
## 12 CRuns 0.217 50
## 13 CRBI 0.215 50
## 14 CWalks -0.149 50
## 15 LeagueN 45.9 50
## 16 DivisionW -118. 50
## 17 PutOuts 0.250 50
## 18 Assists 0.121 50
## 19 Errors -3.28 50
## 20 NewLeagueN -9.42 50
# Draw the default autoplot for the fitted ridge model.
autoplot(ridge_fit)

# Predict Salary for the rows the model was fitted on; no penalty override
# is passed in this call.
ridge_fit %>%
  predict(new_data = Hitters)
## # A tibble: 263 × 1
## .pred
## <dbl>
## 1 442.
## 2 676.
## 3 1059.
## 4 521.
## 5 543.
## 6 218.
## 7 74.7
## 8 96.1
## 9 809.
## 10 865.
## # ℹ 253 more rows
# Same predictions, this time requested at penalty = 500.
ridge_fit %>%
  predict(new_data = Hitters, penalty = 500)
## # A tibble: 263 × 1
## .pred
## <dbl>
## 1 525.
## 2 620.
## 3 895.
## 4 425.
## 5 589.
## 6 179.
## 7 147.
## 8 187.
## 9 841.
## 10 840.
## # ℹ 253 more rows
# Fix the RNG seed: initial_split() and vfold_cv() both partition the rows at
# random, so without a seed the train/test split, the folds, and every tuning
# result and metric derived from them change on each run.
set.seed(1234)

# Hold out a test set, stratifying the split on the outcome (Salary).
Hitters_split <- initial_split(Hitters, strata = "Salary")
Hitters_train <- training(Hitters_split)
Hitters_test <- testing(Hitters_split)

# 10-fold cross-validation folds built from the training portion only; these
# are the resamples used for penalty tuning.
Hitters_fold <- vfold_cv(Hitters_train, v = 10)
# Preprocessing, estimated on the training data:
#   - step_novel():     applied to the nominal predictors before dummy coding
#   - step_dummy():     convert nominal predictors to indicator columns
#   - step_zv():        drop zero-variance predictors
#   - step_normalize(): centre and scale every predictor
ridge_recipe <- recipe(formula = Salary ~ ., data = Hitters_train) %>%
  step_novel(all_nominal_predictors()) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors())

# Ridge model (mixture = 0) whose penalty is left as a tuning parameter.
ridge_spec <- linear_reg(mixture = 0, penalty = tune()) %>%
  set_engine("glmnet") %>%
  set_mode("regression")

# Bundle the recipe and the model specification into a single workflow.
ridge_workflow <- workflow(preprocessor = ridge_recipe, spec = ridge_spec)
# Regular grid of 50 candidate penalties over penalty(range = c(-5, 5)); the
# printed grid starts at 1e-05.
penalty_grid <- penalty(range = c(-5, 5)) %>%
  grid_regular(levels = 50)
penalty_grid
## # A tibble: 50 × 1
## penalty
## <dbl>
## 1 0.00001
## 2 0.0000160
## 3 0.0000256
## 4 0.0000409
## 5 0.0000655
## 6 0.000105
## 7 0.000168
## 8 0.000268
## 9 0.000429
## 10 0.000687
## # ℹ 40 more rows
# Evaluate every candidate penalty in the grid on each cross-validation fold.
tune_res <- ridge_workflow %>%
  tune_grid(resamples = Hitters_fold, grid = penalty_grid)
tune_res
## # Tuning results
## # 10-fold cross-validation
## # A tibble: 10 × 4
## splits id .metrics .notes
## <list> <chr> <list> <list>
## 1 <split [176/20]> Fold01 <tibble [100 × 5]> <tibble [0 × 3]>
## 2 <split [176/20]> Fold02 <tibble [100 × 5]> <tibble [0 × 3]>
## 3 <split [176/20]> Fold03 <tibble [100 × 5]> <tibble [0 × 3]>
## 4 <split [176/20]> Fold04 <tibble [100 × 5]> <tibble [0 × 3]>
## 5 <split [176/20]> Fold05 <tibble [100 × 5]> <tibble [0 × 3]>
## 6 <split [176/20]> Fold06 <tibble [100 × 5]> <tibble [0 × 3]>
## 7 <split [177/19]> Fold07 <tibble [100 × 5]> <tibble [0 × 3]>
## 8 <split [177/19]> Fold08 <tibble [100 × 5]> <tibble [0 × 3]>
## 9 <split [177/19]> Fold09 <tibble [100 × 5]> <tibble [0 × 3]>
## 10 <split [177/19]> Fold10 <tibble [100 × 5]> <tibble [0 × 3]>
# Plot the tuning results.
tune_res %>%
  autoplot()

# Tabulate the resampled metrics for every candidate penalty.
tune_res %>%
  collect_metrics()
## # A tibble: 100 × 7
## penalty .metric .estimator mean n std_err .config
## <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 0.00001 rmse standard 317. 10 40.7 Preprocessor1_Model01
## 2 0.00001 rsq standard 0.498 10 0.0735 Preprocessor1_Model01
## 3 0.0000160 rmse standard 317. 10 40.7 Preprocessor1_Model02
## 4 0.0000160 rsq standard 0.498 10 0.0735 Preprocessor1_Model02
## 5 0.0000256 rmse standard 317. 10 40.7 Preprocessor1_Model03
## 6 0.0000256 rsq standard 0.498 10 0.0735 Preprocessor1_Model03
## 7 0.0000409 rmse standard 317. 10 40.7 Preprocessor1_Model04
## 8 0.0000409 rsq standard 0.498 10 0.0735 Preprocessor1_Model04
## 9 0.0000655 rmse standard 317. 10 40.7 Preprocessor1_Model05
## 10 0.0000655 rsq standard 0.498 10 0.0735 Preprocessor1_Model05
## # ℹ 90 more rows
# Choose the candidate penalty that performs best on the "rsq" metric.
best_penalty <- tune_res %>%
  select_best(metric = "rsq")
best_penalty
## # A tibble: 1 × 2
## penalty .config
## <dbl> <chr>
## 1 15264. Preprocessor1_Model46
# Plug the selected penalty into the workflow and refit on the training set.
ridge_final <- ridge_workflow %>%
  finalize_workflow(best_penalty)
ridge_final_fit <- ridge_final %>%
  fit(data = Hitters_train)

# Score the finalized ridge model on the held-out test set (R-squared).
ridge_final_fit %>%
  augment(new_data = Hitters_test) %>%
  rsq(truth = Salary, estimate = .pred)
## # A tibble: 1 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rsq standard 0.420
# Preprocessing for the lasso, estimated on the training data:
#   - step_novel():     applied to the nominal predictors before dummy coding
#   - step_dummy():     convert nominal predictors to indicator columns
#   - step_zv():        drop zero-variance predictors
#   - step_normalize(): centre and scale every predictor
lasso_recipe <- recipe(formula = Salary ~ ., data = Hitters_train) %>%
  step_novel(all_nominal_predictors()) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors())

# Lasso model (mixture = 1) whose penalty is left as a tuning parameter.
lasso_spec <- linear_reg(mixture = 1, penalty = tune()) %>%
  set_engine("glmnet") %>%
  set_mode("regression")

# Bundle the recipe and the model specification into a single workflow.
lasso_workflow <- workflow(preprocessor = lasso_recipe, spec = lasso_spec)
# Regular grid of 50 candidate penalties over penalty(range = c(-2, 2)).
# NOTE: this reassigns penalty_grid and tune_res, replacing the ridge values.
penalty_grid <- penalty(range = c(-2, 2)) %>%
  grid_regular(levels = 50)

# Evaluate every candidate penalty on the same cross-validation folds.
tune_res <- lasso_workflow %>%
  tune_grid(resamples = Hitters_fold, grid = penalty_grid)

# Plot the tuning results.
tune_res %>%
  autoplot()

# Choose the lasso penalty that performs best on the "rsq" metric
# (this reassigns best_penalty).
best_penalty <- tune_res %>%
  select_best(metric = "rsq")

# Plug the selected penalty into the workflow and refit on the training set.
lasso_final <- lasso_workflow %>%
  finalize_workflow(best_penalty)
lasso_final_fit <- lasso_final %>%
  fit(data = Hitters_train)

# Score the finalized lasso model on the held-out test set (R-squared).
lasso_final_fit %>%
  augment(new_data = Hitters_test) %>%
  rsq(truth = Salary, estimate = .pred)
## # A tibble: 1 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rsq standard 0.467